gr-libs 0.1.8__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (68)
  1. gr_libs/__init__.py +3 -1
  2. gr_libs/_evaluation/__init__.py +1 -0
  3. evaluation/analyze_results_cross_alg_cross_domain.py → gr_libs/_evaluation/_analyze_results_cross_alg_cross_domain.py +81 -88
  4. evaluation/generate_experiments_results.py → gr_libs/_evaluation/_generate_experiments_results.py +6 -6
  5. evaluation/generate_task_specific_statistics_plots.py → gr_libs/_evaluation/_generate_task_specific_statistics_plots.py +11 -14
  6. evaluation/get_plans_images.py → gr_libs/_evaluation/_get_plans_images.py +3 -4
  7. evaluation/increasing_and_decreasing_.py → gr_libs/_evaluation/_increasing_and_decreasing_.py +3 -1
  8. gr_libs/_version.py +2 -2
  9. gr_libs/all_experiments.py +294 -0
  10. gr_libs/environment/__init__.py +14 -1
  11. gr_libs/environment/{utils → _utils}/utils.py +1 -1
  12. gr_libs/environment/environment.py +257 -22
  13. gr_libs/metrics/__init__.py +2 -0
  14. gr_libs/metrics/metrics.py +166 -31
  15. gr_libs/ml/__init__.py +1 -6
  16. gr_libs/ml/base/__init__.py +3 -1
  17. gr_libs/ml/base/rl_agent.py +68 -3
  18. gr_libs/ml/neural/__init__.py +1 -3
  19. gr_libs/ml/neural/deep_rl_learner.py +227 -67
  20. gr_libs/ml/neural/utils/__init__.py +1 -2
  21. gr_libs/ml/planner/mcts/{utils → _utils}/tree.py +1 -1
  22. gr_libs/ml/planner/mcts/mcts_model.py +71 -34
  23. gr_libs/ml/sequential/__init__.py +0 -1
  24. gr_libs/ml/sequential/{lstm_model.py → _lstm_model.py} +11 -14
  25. gr_libs/ml/tabular/__init__.py +1 -3
  26. gr_libs/ml/tabular/tabular_q_learner.py +27 -9
  27. gr_libs/ml/tabular/tabular_rl_agent.py +22 -9
  28. gr_libs/ml/utils/__init__.py +2 -9
  29. gr_libs/ml/utils/format.py +13 -90
  30. gr_libs/ml/utils/math.py +3 -2
  31. gr_libs/ml/utils/other.py +2 -2
  32. gr_libs/ml/utils/storage.py +41 -94
  33. gr_libs/odgr_executor.py +268 -0
  34. gr_libs/problems/consts.py +2 -2
  35. gr_libs/recognizer/_utils/__init__.py +0 -0
  36. gr_libs/recognizer/{utils → _utils}/format.py +2 -2
  37. gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +116 -36
  38. gr_libs/recognizer/graml/{gr_dataset.py → _gr_dataset.py} +11 -11
  39. gr_libs/recognizer/graml/graml_recognizer.py +172 -29
  40. gr_libs/recognizer/recognizer.py +59 -10
  41. gr_libs/tutorials/draco_panda_tutorial.py +58 -0
  42. gr_libs/tutorials/draco_parking_tutorial.py +56 -0
  43. {tutorials → gr_libs/tutorials}/gcdraco_panda_tutorial.py +5 -9
  44. {tutorials → gr_libs/tutorials}/gcdraco_parking_tutorial.py +3 -7
  45. {tutorials → gr_libs/tutorials}/graml_minigrid_tutorial.py +2 -2
  46. {tutorials → gr_libs/tutorials}/graml_panda_tutorial.py +5 -10
  47. {tutorials → gr_libs/tutorials}/graml_parking_tutorial.py +5 -9
  48. {tutorials → gr_libs/tutorials}/graml_point_maze_tutorial.py +2 -1
  49. {tutorials → gr_libs/tutorials}/graql_minigrid_tutorial.py +2 -2
  50. {gr_libs-0.1.8.dist-info → gr_libs-0.2.2.dist-info}/METADATA +84 -29
  51. gr_libs-0.2.2.dist-info/RECORD +71 -0
  52. {gr_libs-0.1.8.dist-info → gr_libs-0.2.2.dist-info}/WHEEL +1 -1
  53. gr_libs-0.2.2.dist-info/top_level.txt +2 -0
  54. tests/test_draco.py +14 -0
  55. tests/test_gcdraco.py +2 -2
  56. tests/test_graml.py +4 -4
  57. tests/test_graql.py +1 -1
  58. evaluation/create_minigrid_map_image.py +0 -38
  59. evaluation/file_system.py +0 -53
  60. evaluation/generate_experiments_results_new_ver1.py +0 -238
  61. evaluation/generate_experiments_results_new_ver2.py +0 -331
  62. gr_libs/ml/neural/utils/penv.py +0 -60
  63. gr_libs/recognizer/utils/__init__.py +0 -1
  64. gr_libs-0.1.8.dist-info/RECORD +0 -70
  65. gr_libs-0.1.8.dist-info/top_level.txt +0 -4
  66. /gr_libs/environment/{utils → _utils}/__init__.py +0 -0
  67. /gr_libs/ml/planner/mcts/{utils → _utils}/__init__.py +0 -0
  68. /gr_libs/ml/planner/mcts/{utils → _utils}/node.py +0 -0

gr_libs/ml/neural/deep_rl_learner.py

@@ -1,24 +1,26 @@
- from collections import OrderedDict
  import gc
+ from collections import OrderedDict
  from types import MethodType
- from typing import List, Tuple
- import numpy as np
+
  import cv2
+ import numpy as np

  from gr_libs.environment.environment import EnvProperty

  if __name__ != "__main__":
  from gr_libs.ml.utils.storage import get_agent_model_dir
  from gr_libs.ml.utils.format import random_subset_with_order
- from stable_baselines3 import SAC, PPO, TD3
- from stable_baselines3.common.base_class import BaseAlgorithm
- from gr_libs.ml.utils import device
- import gymnasium as gym
+
+ import os

  # built-in python modules
  import random
- import os
- import sys
+
+ import gymnasium as gym
+ from stable_baselines3 import PPO, SAC, TD3
+ from stable_baselines3.common.base_class import BaseAlgorithm
+
+ from gr_libs.ml.utils import device

  # TODO do we need this?
  NETWORK_SETUP = {
@@ -42,7 +44,6 @@ NETWORK_SETUP = {
  ("normalize_kwargs", {"norm_obs": False, "norm_reward": False}),
  ]
  ),
- # "tqc": OrderedDict([('batch_size', 256), ('buffer_size', 1000000), ('ent_coef', 'auto'), ('env_wrapper', ['sb3_contrib.common.wrappers.TimeFeatureWrapper']), ('gamma', 0.95), ('learning_rate', 0.001), ('learning_starts', 1000), ('n_timesteps', 25000.0), ('normalize', False), ('policy', 'MultiInputPolicy'), ('policy_kwargs', 'dict(net_arch=[64, 64])'), ('replay_buffer_class', 'HerReplayBuffer'), ('replay_buffer_kwargs', "dict( goal_selection_strategy='future', n_sampled_goal=4 )"), ('normalize_kwargs',{'norm_obs':False,'norm_reward':False})]),
  PPO: OrderedDict(
  [
  ("batch_size", 256),
@@ -68,6 +69,22 @@ NETWORK_SETUP = {


  class DeepRLAgent:
+ """
+ Deep Reinforcement Learning Agent, wrapping a SB3 agent and adding functionality,
+ needed for GR framework executions such as observation generation and video recording.
+ Supports SAC, PPO and TD3 algorithms.
+ Can be loaded from rl_zoo or trained from scratch.
+
+ Args:
+ domain_name (str): The domain name.
+ problem_name (str): The problem name.
+ num_timesteps (float): The number of timesteps for training.
+ env_prop (EnvProperty): The environment property.
+ algorithm (BaseAlgorithm, optional): The algorithm to use. Defaults to SAC.
+ reward_threshold (float, optional): The reward threshold. Defaults to 450.
+ exploration_rate (float, optional): The exploration rate. Defaults to None.
+ """
+
  def __init__(
  self,
  domain_name: str,
@@ -78,7 +95,18 @@ class DeepRLAgent:
  reward_threshold: float = 450,
  exploration_rate=None,
  ):
- # Need to change reward threshold to change according to which task the agent is training on, becuase it changes from task to task.
+ """
+ Initialize the DeepRLLearner object.
+
+ Args:
+ domain_name (str): The name of the domain.
+ problem_name (str): The name of the problem.
+ num_timesteps (float): The number of timesteps.
+ env_prop (EnvProperty): The environment property.
+ algorithm (BaseAlgorithm, optional): The algorithm to use. Defaults to SAC.
+ reward_threshold (float, optional): The reward threshold. Defaults to 450.
+ exploration_rate (float, optional): The exploration rate. Defaults to None.
+ """
  env_kwargs = {"id": problem_name, "render_mode": "rgb_array"}
  assert algorithm in [SAC, PPO, TD3]

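The new class and constructor docstrings above describe the wrapper's public surface. A minimal usage sketch under the documented signature, assuming DeepRLAgent is importable from gr_libs.ml.neural.deep_rl_learner and using a hypothetical timestep budget; env_prop must be a concrete EnvProperty for the chosen domain:

    from stable_baselines3 import SAC

    from gr_libs.environment.environment import EnvProperty
    from gr_libs.ml.neural.deep_rl_learner import DeepRLAgent


    def train_and_record(env_prop: EnvProperty, domain: str, problem: str, video_path: str):
        # Wrap an SB3 learner for the given domain/problem (SAC is the documented default).
        agent = DeepRLAgent(
            domain_name=domain,
            problem_name=problem,
            num_timesteps=200_000,  # hypothetical training budget
            env_prop=env_prop,
            algorithm=SAC,
        )
        agent.learn()                           # loads a saved model if present, otherwise trains and saves
        agent.try_recording_video(video_path)   # records a rollout video, retrying on flaky resets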
@@ -110,7 +138,8 @@ class DeepRLAgent:
  "seed": 0,
  "buffer_size": 1,
  }
- # second support: models saved with SB3's model.save, which is saved as a formatted .pth file.
+ # second support: models saved with SB3's model.save, which is saved as a
+ # formatted .pth file.
  else:
  self.model_kwargs = {}
  self._model_file_path = os.path.join(
@@ -122,9 +151,17 @@ class DeepRLAgent:
  self.num_timesteps = num_timesteps

  def save_model(self):
+ """Save the model to a file."""
  self._model.save(self._model_file_path)

  def try_recording_video(self, video_path, desired=None):
+ """
+ Try recording a video of the agent's performance.
+
+ Args:
+ video_path (str): The path to save the video.
+ desired (optional): The desired goal. Defaults to None.
+ """
  num_tries = 0
  while True:
  if num_tries >= 10:
@@ -132,20 +169,26 @@ class DeepRLAgent:
  try:
  self.record_video(video_path, desired)
  break
- except Exception as e:
+ except Exception:
  num_tries += 1
  # print(f"sequence to {self.problem_name} is:\n\t{steps}\ngenerating image at {img_path}.")
  print(f"generated sequence video at {video_path}.")

  def record_video(self, video_path, desired=None):
- """Record a video of the agent's performance."""
+ """
+ Record a video of the agent's performance.
+
+ Args:
+ video_path (str): The path to save the video.
+ desired (optional): The desired goal. Defaults to None.
+ """
  fourcc = cv2.VideoWriter_fourcc("m", "p", "4", "v")
  fps = 30.0
  # if is_gc:
- # assert goal_idx != None
+ # assert goal_idx is not None
  # self.reset_with_goal_idx(goal_idx)
  # else:
- # assert goal_idx == None
+ # assert goal_idx is None
  self.env.reset()
  frame_size = (
  self.env.render(mode="rgb_array").shape[1],
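record_video keeps the same OpenCV pattern as before: an mp4v writer at 30 fps, frame size taken from the rendered image, and an RGB-to-BGR conversion before each write. A standalone sketch of that pattern, independent of the agent class (function name and frame source are illustrative):

    import cv2


    def write_frames_to_mp4(frames, video_path, fps=30.0):
        """Write an iterable of HxWx3 RGB uint8 frames (e.g. env.render output) to an .mp4 file."""
        frames = iter(frames)
        first = next(frames)
        height, width = first.shape[0], first.shape[1]
        writer = cv2.VideoWriter(
            video_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
        )
        writer.write(cv2.cvtColor(first, cv2.COLOR_RGB2BGR))  # OpenCV expects BGR
        for frame in frames:
            writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        writer.release()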
@@ -186,34 +229,32 @@ class DeepRLAgent:
  obs, desired, success_done
  )
  video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
- if general_done == False != success_done == True:
+ if general_done == False and success_done == True:
  assert (
  desired is not None
- ), f"general_done is false but success_done is true, and desired is None. This should never happen, since the \
- environment will say 'done' is false (general_done) while the observation will be close to the goal (success_done) \
- only in case we incorporated a 'desired' when generating the observation."
- elif general_done == True != success_done == False:
+ ), f"general_done is false but success_done is true, and desired is None. \
+ This should never happen, since the environment will say 'done' is false \
+ (general_done) while the observation will be close to the goal (success_done) \
+ only in case we incorporated a 'desired' when generating the observation."
+ elif general_done == True and success_done == False:
  raise Exception("general_done is true but success_done is false")
  self.env.close()
  video_writer.release()

  def load_model(self):
+ """Load the model from a file."""
  self._model = self.algorithm.load(
  self._model_file_path, env=self.env, device=device, **self.model_kwargs
  )

  def learn(self):
+ """Train the agent."""
  if os.path.exists(self._model_file_path):
  print(f"Loading pre-existing model in {self._model_file_path}")
  self.load_model()
  else:
- # Stop training when the model reaches the reward threshold
- # callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=self.reward_threshold, verbose=1)
- # eval_callback = EvalCallback(self.env, best_model_save_path="./logs/",
- # log_path="./logs/", eval_freq=500, callback_on_new_best=callback_on_best, verbose=1, render=True)
- # self._model.learn(total_timesteps=self.num_timesteps, progress_bar=True, callback=eval_callback)
  print(f"No existing model in {self._model_file_path}, starting learning")
- if self.exploration_rate != None:
+ if self.exploration_rate is not None:
  self._model = self.algorithm(
  "MultiInputPolicy",
  self.env,
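The conditions rewritten near the top of the hunk above previously relied on Python comparison chaining: general_done == False != success_done == True parses as a conjunction of pairwise comparisons, which happens to match the intended check for booleans but hides the intent. A small illustration:

    # The 0.1.8 form parses as a chain of pairwise comparisons:
    general_done, success_done = False, True

    old = general_done == False != success_done == True
    # i.e. (general_done == False) and (False != success_done) and (success_done == True)
    new = general_done == False and success_done == True   # the 0.2.2 spelling

    assert old == new  # equivalent for booleans, but the explicit conjunction states the intent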
@@ -228,15 +269,30 @@ class DeepRLAgent:
  self.save_model()

  def safe_env_reset(self):
+ """
+ Reset the environment safely.
+
+ Returns:
+ The initial observation.
+ """
  try:
  obs = self.env.reset()
- except Exception as e:
+ except Exception:
  kwargs = {"id": self.problem_name, "render_mode": "rgb_array"}
  self.env = self.env_prop.create_vec_env(kwargs)
  obs = self.env.reset()
  return obs

  def get_mean_and_std_dev(self, observation):
+ """
+ Get the mean and standard deviation of the action distribution.
+
+ Args:
+ observation: The observation.
+
+ Returns:
+ The mean and standard deviation of the action distribution.
+ """
  if self.algorithm == SAC:
  tensor_observation, _ = self._model.actor.obs_to_tensor(observation)

@@ -266,9 +322,20 @@ class DeepRLAgent:
  assert False
  return actor_means, log_std_dev

- # fits agents that generated observations in the form of: list of tuples, each tuple a single step\frame with size 2, comprised of obs and action.
- # the function squashes the 2d array of obs and action in a 1d array, concatenating their values together for training.
  def simplify_observation(self, observation):
+ """
+ Simplifies the given observation by concatenating the last dimension of each observation and action.
+ fits agents that generated observations in the form of: list of tuples, each tuple a single
+ step\frame with size 2, comprised of obs and action.
+ the function squashes the 2d array of obs and action in a 1d array, concatenating their
+ values together for training.
+
+ Args:
+ observation (list): List of tuples containing observation and action.
+
+ Returns:
+ list: List of simplified observations.
+ """
  return [
  np.concatenate(
  (
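The new simplify_observation docstring describes flattening each (observation, action) pair into a single 1-D vector for training. A minimal standalone sketch of that squashing step; shapes and the helper name are illustrative only:

    import numpy as np


    def flatten_step(obs, action):
        # Concatenate one observation and its action into a single 1-D training vector.
        return np.concatenate((np.asarray(obs).ravel(), np.asarray(action).ravel()))


    trace = [(np.zeros(6), np.array([0.1, -0.2]))]         # one (obs, action) step
    flat = [flatten_step(obs, act) for obs, act in trace]  # list of 1-D vectors of length 8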
@@ -280,6 +347,17 @@ class DeepRLAgent:
  ]

  def add_random_optimalism(self, observations, action, constant_initial_action):
+ """
+ Adds random optimalism to the given action based on the length of observations.
+
+ Parameters:
+ observations (list): List of observations.
+ action (ndarray): Action to modify.
+ constant_initial_action (float): Initial action value.
+
+ Returns:
+ ndarray: Modified action.
+ """
  if len(observations) > 3:
  for i in range(0, len(action[0])):
  action[0][i] += random.uniform(
@@ -287,6 +365,7 @@
  )
  else: # just walk in a specific random direction to enable diverse plans
  action = np.array(np.array([constant_initial_action]), None)
+ return action

  def generate_partial_observation(
  self,
@@ -297,6 +376,20 @@ class DeepRLAgent:
  fig_path=None,
  random_optimalism=True,
  ):
+ """
+ Generates a partial observation by selecting a subset of steps from a full observation.
+
+ Args:
+ action_selection_method (str): The method used for selecting actions.
+ percentage (float): The percentage of steps to include in the partial observation.
+ is_consecutive (bool): Whether the selected steps should be consecutive or not.
+ save_fig (bool, optional): Whether to save a figure of the observation. Defaults to False.
+ fig_path (str, optional): The path to save the figure. Defaults to None.
+ random_optimalism (bool, optional): Whether to apply random optimalism during observation generation. Defaults to True.
+
+ Returns:
+ list: A partial observation consisting of a subset of steps from the full observation.
+ """
  steps = self.generate_observation(
  action_selection_method,
  save_fig=save_fig,
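Per the docstring above, the partial observation keeps only percentage of the generated steps, either as one consecutive window or as a scattered, order-preserving sample (the actual selection is delegated to random_subset_with_order from gr_libs.ml.utils.format). A hedged illustration of that kind of subsampling, not the library's implementation:

    import random


    def ordered_subset(steps, k, is_consecutive):
        # Keep k steps while preserving their original order.
        if is_consecutive:                                    # one contiguous window
            start = random.randint(0, len(steps) - k)
            return steps[start:start + k]
        idx = sorted(random.sample(range(len(steps)), k))     # scattered but still ordered
        return [steps[i] for i in idx]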
@@ -315,25 +408,39 @@ class DeepRLAgent:
  fig_path=None,
  with_dict=False,
  desired=None,
- ) -> List[
- Tuple[np.ndarray, np.ndarray]
- ]: # TODO make sure to add a linter to alert when a method doesn't accept or return the type it should
- if save_fig == False:
+ ) -> list[tuple[np.ndarray, np.ndarray]]:
+ """
+ Generates observations by interacting with the environment.
+
+ Args:
+ action_selection_method (MethodType): The method used for action selection.
+ random_optimalism (bool): Flag indicating whether to add random optimalism to the actions.
+ save_fig (bool, optional): Flag indicating whether to save a figure. Defaults to False.
+ fig_path (str, optional): The path to save the figure. Required if save_fig is True. Defaults to None.
+ with_dict (bool, optional): Flag indicating whether to include the observation as a dictionary. Defaults to False.
+ desired (Any, optional): The desired goal for the observation. Defaults to None.
+
+ Returns:
+ list[tuple[np.ndarray, np.ndarray]]: A list of tuples containing the observation and the corresponding action.
+ """
+ if save_fig is False:
  assert (
- fig_path == None
+ fig_path is None
  ), "You can't specify a vid path when you don't even save the figure."
  else:
  assert (
- fig_path != None
+ fig_path is not None
  ), "You need to specify a vid path when you save the figure."
- # The try-except is a bug fix for the env not being reset properly in panda. If someone wants to check why and provide a robust solution they're welcome.
+ # The try-except is a bug fix for the env not being reset properly in panda.
+ # If someone wants to check why and provide a robust solution they're welcome.
  obs = self.safe_env_reset()
  self.env_prop.change_goal_to_specific_desired(obs, desired)
  observations = []
  is_successful_observation_made = False
  num_of_insuccessful_attempts = 0
  while not is_successful_observation_made:
- is_successful_observation_made = True # start as true, if this isn't the case (crash/death/truncation instead of success)
+ # start as true, if this isn't the case (crash/death/truncation instead of success)
+ is_successful_observation_made = True
  if random_optimalism:
  constant_initial_action = self.env.action_space.sample()
  while True:
@@ -343,9 +450,8 @@ class DeepRLAgent:
  action_selection_method != stochastic_amplified_selection
  )
  action, _states = self._model.predict(obs, deterministic=deterministic)
- if (
- random_optimalism
- ): # get the right direction and then start inserting noise to still get a relatively optimal plan
+ if random_optimalism:
+ # get the right direction and then start inserting noise to still get a relatively optimal plan
  self.add_random_optimalism(obs, action, constant_initial_action)
  if with_dict:
  observations.append((obs, action))
@@ -353,22 +459,31 @@ class DeepRLAgent:
  observations.append((obs["observation"], action))
  obs, reward, done, info = self.env.step(action)
  self.env_prop.change_goal_to_specific_desired(obs, desired)
- general_done = self.env_prop.is_done(done)
+ general_done = bool(self.env_prop.is_done(done))
  success_done = self.env_prop.is_success(info)
- success_done = self.env_prop.change_done_by_specific_desired(
- obs, desired, success_done
+ success_done = bool(
+ self.env_prop.change_done_by_specific_desired(
+ obs, desired, success_done
+ )
  )
- if general_done == True and success_done == False:
- # it could be that the stochasticity inserted into the actions made the agent die/crash. we don't want this observation: it's an insuccessful attempt.
+ if general_done is True and success_done is False:
+ # it could be that the stochasticity inserted into the actions made the agent die/crash.
+ # we don't want this observation: it's an insuccessful attempt.
  num_of_insuccessful_attempts += 1
- # print(f"for agent for problem {self.problem_name}, its done {len(observations)} steps, and got to a situation where general_done != success_done, for the {num_of_insuccessful_attempts} time.")
+ # print(f"for agent for problem {self.problem_name}, its done
+ # {len(observations)} steps, and got to a situation where
+ # general_done != success_done, for the {num_of_insuccessful_attempts} time.")
  if num_of_insuccessful_attempts > 50:
  # print(f"got more then 10 insuccessful attempts!")
  assert (
- general_done == success_done
- ), f"failed on goal: {obs['desired']}" # we want to make sure the episode is done only when the agent has actually succeeded with the task.
+ general_done
+ == success_done
+ # we want to make sure the episode is done only
+ # when the agent has actually succeeded with the task.
+ ), f"failed on goal: {obs['desired']}"
  else:
- # try again by breaking inner loop. everything is set up to be like the beginning of the function.
+ # try again by breaking inner loop.
+ # everything is set up to be like the beginning of the function.
  is_successful_observation_made = False
  obs = self.safe_env_reset()
  self.env_prop.change_goal_to_specific_desired(obs, desired)
@@ -376,20 +491,21 @@ class DeepRLAgent:
  []
  ) # we want to re-accumulate the observations from scratch, have another try
  break
- elif general_done == False and success_done == False:
+ elif general_done is False and success_done is False:
  continue
- elif general_done == True and success_done == True:
+ elif general_done is True and success_done is True:
  if num_of_insuccessful_attempts > 0:
  pass # print(f"after {num_of_insuccessful_attempts}, finally I succeeded!")
  break
- elif general_done == False and success_done == True:
- # The environment will say 'done' is false (general_done) while the observation will be close to the goal (success_done)
- # only in case we incorporated a 'desired' when generating the observation.
+ elif general_done is False and success_done is True:
+ # The environment will say 'done' is false (general_done) while the observation
+ # will be close to the goal (success_done) only in case we incorporated a 'desired'
+ # when generating the observation.
  assert (
  desired is not None
  ), f"general_done is false but success_done is true, and desired is None. This should never happen, since the \
- environment will say 'done' is false (general_done) while the observation will be close to the goal (success_done) \
- only in case we incorporated a 'desired' when generating the observation."
+ environment will say 'done' is false (general_done) while the observation will be close to the goal (success_done) \
+ only in case we incorporated a 'desired' when generating the observation."
  break

  if save_fig:
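Taken together, the hunks above refactor generate_observation's retry loop around the (general_done, success_done) pair: success keeps the trace, an episode that ends without success discards it and resets, and anything else keeps stepping. A condensed sketch of that control flow, with the environment-specific calls stubbed out as hypothetical helpers:

    def collect_successful_trace(reset_env, step_env, policy, max_failures=50):
        """Collect (obs, action) pairs, restarting whenever an episode ends without success."""
        failures = 0
        while True:
            obs, trace = reset_env(), []
            while True:
                action = policy(obs)
                trace.append((obs, action))
                obs, general_done, success_done = step_env(action)
                if success_done:          # reached the goal (possibly via a 'desired' override)
                    return trace
                if general_done:          # crash/death/truncation: discard and retry from scratch
                    failures += 1
                    assert failures <= max_failures, "episodes keep ending without success"
                    break
                # neither flag set: keep stepping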
@@ -400,6 +516,23 @@ class DeepRLAgent:


  class GCDeepRLAgent(DeepRLAgent):
+ """
+ A class representing a Goal Conditioned Deep Reinforcement Learning Agent.
+
+ This agent extends the functionality of the base DeepRLAgent class by providing methods for generating partial observations and observations with goal-directed goals or problems.
+
+ Args:
+ DeepRLAgent (class): The base class for DeepRLAgent.
+
+ Attributes:
+ env (object): The environment in which the agent operates.
+ env_prop (object): The environment properties.
+
+ Methods:
+ generate_partial_observation: Generates a partial observation based on a given percentage of steps.
+ generate_observation: Generates an observation with optional goal-directed goals or problems.
+ """
+
  def generate_partial_observation(
  self,
  action_selection_method,
@@ -411,6 +544,22 @@ class GCDeepRLAgent(DeepRLAgent):
  fig_path=None,
  random_optimalism=True,
  ):
+ """
+ Generates a partial observation based on a given percentage of steps.
+
+ Args:
+ action_selection_method (MethodType): The method for selecting actions.
+ percentage (float): The percentage of steps to include in the partial observation.
+ is_consecutive (bool): Whether the steps should be consecutive or randomly selected.
+ goal_directed_problem (str, optional): The goal-directed problem. Defaults to None.
+ goal_directed_goal (object, optional): The goal-directed goal. Defaults to None.
+ save_fig (bool, optional): Whether to save a figure. Defaults to False.
+ fig_path (str, optional): The path to save the figure. Defaults to None.
+ random_optimalism (bool, optional): Whether to use random optimalism. Defaults to True.
+
+ Returns:
+ list: A random subset of steps from the full observation.
+ """
  steps = self.generate_observation(
  action_selection_method,
  save_fig=save_fig,
@@ -423,8 +572,6 @@ class GCDeepRLAgent(DeepRLAgent):
  steps, (int)(percentage * len(steps)), is_consecutive
  )

- # TODO move the goal_directed_goal and/or goal_directed_problem mechanism to be a property of the env_property, so deep_rl_learner doesn't depend on it and holds this logic so heavily.
- # Generate observation with goal_directed_goal or goal_directed_problem is only possible for a GC agent, otherwise - the agent can't act optimally to that new goal.
  def generate_observation(
  self,
  action_selection_method: MethodType,
@@ -435,16 +582,31 @@ class GCDeepRLAgent(DeepRLAgent):
  fig_path=None,
  with_dict=False,
  ):
+ """
+ Generates an observation with optional goal-directed goals or problems.
+
+ Args:
+ action_selection_method (MethodType): The method for selecting actions.
+ random_optimalism (bool): Whether to use random optimalism.
+ goal_directed_problem (str, optional): The goal-directed problem. Defaults to None.
+ goal_directed_goal (object, optional): The goal-directed goal. Defaults to None.
+ save_fig (bool, optional): Whether to save a figure. Defaults to False.
+ fig_path (str, optional): The path to save the figure. Defaults to None.
+ with_dict (bool, optional): Whether to include a dictionary in the observation. Defaults to False.
+
+ Returns:
+ list: The generated observation.
+ """
  if save_fig:
  assert (
- fig_path != None
+ fig_path is not None
  ), "You need to specify a vid path when you save the figure."
  else:
- assert fig_path == None
- # goal_directed_problem employs the GC agent in a new env with a static, predefined goal, and has him generate an observation sequence in it.
+ assert fig_path is None
+
  if goal_directed_problem:
  assert (
- goal_directed_goal == None
+ goal_directed_goal is None
  ), "can't give goal directed goal and also goal directed problem for the sake of sequence generation by a general agent"
  kwargs = {"id": goal_directed_problem, "render_mode": "rgb_array"}
  self.env = self.env_prop.create_vec_env(kwargs)
@@ -457,11 +619,9 @@ class GCDeepRLAgent(DeepRLAgent):
  with_dict=with_dict,
  )
  self.env = orig_env
- # goal_directed_goal employs the agent in the same env on which it trained - with goals that change with every episode sampled from the goal space.
- # but we manually change the 'desired' part of the observation to be the goal_directed_goal and edit the id_success and is_done accordingly.
  else:
  assert (
- goal_directed_problem == None
+ goal_directed_problem is None
  ), "can't give goal directed goal and also goal directed problem for the sake of sequence generation by a general agent"
  observations = super().generate_observation(
  action_selection_method=action_selection_method,
@@ -470,5 +630,5 @@ class GCDeepRLAgent(DeepRLAgent):
  fig_path=fig_path,
  with_dict=with_dict,
  desired=goal_directed_goal,
- ) # TODO tutorial on how to use the deepRLAgent for sequence generation and examination and plotting of the sequence
+ )
  return observations
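The assertions above enforce that a GCDeepRLAgent rollout is parameterized by at most one of goal_directed_problem (run in a separate, fixed-goal environment) or goal_directed_goal (keep the training environment and pin the 'desired' goal). A hedged usage sketch; the environment id, goal value, and action-selection callable below are hypothetical:

    def generate_goal_directed_sequences(gc_agent, select_action):
        """Sketch: pass exactly one of goal_directed_problem / goal_directed_goal per call."""
        # Option 1: roll out in a separate, fixed-goal variant of the environment.
        seq_a = gc_agent.generate_observation(
            action_selection_method=select_action,
            random_optimalism=True,
            goal_directed_problem="PandaReachFixedGoal-v3",  # hypothetical registered id
        )
        # Option 2: keep the training environment, but pin the 'desired' goal manually.
        seq_b = gc_agent.generate_observation(
            action_selection_method=select_action,
            random_optimalism=True,
            goal_directed_goal=[0.1, 0.1, 0.1],              # hypothetical goal value
        )
        return seq_a, seq_b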

gr_libs/ml/neural/utils/__init__.py

@@ -1,2 +1 @@
- from gr_libs.ml.neural.utils.dictlist import DictList
- from gr_libs.ml.neural.utils.penv import ParallelEnv
+ """ utility functions for GR algorithms that use neural networks """

gr_libs/ml/planner/mcts/{utils → _utils}/tree.py

@@ -102,5 +102,5 @@ class Tree:
  def show(self):
  lines = ""
  for edge, node in self.iter(identifier=None, depth=0, last_node_flags=[]):
- lines += "{}{}\n".format(edge, node)
+ lines += f"{edge}{node}\n"
  print(lines)