gr-libs 0.2.2__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30) hide show
  1. gr_libs/_evaluation/_generate_experiments_results.py +0 -141
  2. gr_libs/_version.py +2 -2
  3. gr_libs/all_experiments.py +73 -107
  4. gr_libs/environment/environment.py +22 -2
  5. gr_libs/evaluation/generate_experiments_results.py +100 -0
  6. gr_libs/ml/neural/deep_rl_learner.py +17 -20
  7. gr_libs/odgr_executor.py +20 -25
  8. gr_libs/problems/consts.py +568 -290
  9. gr_libs/recognizer/_utils/__init__.py +1 -0
  10. gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +12 -1
  11. gr_libs/recognizer/graml/graml_recognizer.py +16 -8
  12. gr_libs/tutorials/gcdraco_panda_tutorial.py +6 -2
  13. gr_libs/tutorials/gcdraco_parking_tutorial.py +3 -1
  14. gr_libs/tutorials/graml_minigrid_tutorial.py +16 -12
  15. gr_libs/tutorials/graml_panda_tutorial.py +6 -2
  16. gr_libs/tutorials/graml_parking_tutorial.py +3 -1
  17. gr_libs/tutorials/graml_point_maze_tutorial.py +15 -2
  18. {gr_libs-0.2.2.dist-info → gr_libs-0.2.5.dist-info}/METADATA +27 -16
  19. {gr_libs-0.2.2.dist-info → gr_libs-0.2.5.dist-info}/RECORD +26 -25
  20. {gr_libs-0.2.2.dist-info → gr_libs-0.2.5.dist-info}/WHEEL +1 -1
  21. tests/test_odgr_executor_expertbasedgraml.py +14 -0
  22. tests/test_odgr_executor_gcdraco.py +14 -0
  23. tests/test_odgr_executor_gcgraml.py +14 -0
  24. tests/test_odgr_executor_graql.py +14 -0
  25. gr_libs/_evaluation/_analyze_results_cross_alg_cross_domain.py +0 -260
  26. gr_libs/_evaluation/_generate_task_specific_statistics_plots.py +0 -497
  27. gr_libs/_evaluation/_get_plans_images.py +0 -61
  28. gr_libs/_evaluation/_increasing_and_decreasing_.py +0 -106
  29. /gr_libs/{_evaluation → evaluation}/__init__.py +0 -0
  30. {gr_libs-0.2.2.dist-info → gr_libs-0.2.5.dist-info}/top_level.txt +0 -0
@@ -1,141 +0,0 @@
1
- import copy
2
- import os
3
-
4
- import dill
5
- import matplotlib.pyplot as plt
6
- import numpy as np
7
-
8
- from gr_libs.ml.utils.storage import (
9
- get_experiment_results_path,
10
- set_global_storage_configs,
11
- )
12
-
13
-
14
- def gen_graph(
15
- graph_name,
16
- x_label_str,
17
- tasks,
18
- panda_env,
19
- minigrid_env,
20
- parking_env,
21
- maze_env,
22
- percentage,
23
- ):
24
-
25
- fragmented_accuracies = {
26
- "graml": {
27
- #'panda': [],
28
- #'minigrid': [],
29
- #'point_maze': [],
30
- "parking": []
31
- },
32
- "graql": {
33
- #'panda': [],
34
- #'minigrid': [],
35
- #'point_maze': [],
36
- "parking": []
37
- },
38
- }
39
-
40
- continuing_accuracies = copy.deepcopy(fragmented_accuracies)
41
-
42
- # domains_envs = [('minigrid', minigrid_env), ('point_maze', maze_env), ('parking', parking_env)]
43
- domains_envs = [("parking", parking_env)]
44
-
45
- for partial_obs_type, accuracies, is_same_learn in zip(
46
- ["fragmented", "continuing"],
47
- [fragmented_accuracies, continuing_accuracies],
48
- [False, True],
49
- ):
50
- for domain, env in domains_envs:
51
- for task in tasks:
52
- set_global_storage_configs(
53
- recognizer_str="graml",
54
- is_fragmented=partial_obs_type,
55
- is_inference_same_length_sequences=True,
56
- is_learn_same_length_sequences=is_same_learn,
57
- )
58
- graml_res_file_path = (
59
- f"{get_experiment_results_path(domain, env, task)}.pkl"
60
- )
61
- set_global_storage_configs(
62
- recognizer_str="graql", is_fragmented=partial_obs_type
63
- )
64
- graql_res_file_path = (
65
- f"{get_experiment_results_path(domain, env, task)}.pkl"
66
- )
67
- if os.path.exists(graml_res_file_path):
68
- with open(graml_res_file_path, "rb") as results_file:
69
- results = dill.load(results_file)
70
- accuracies["graml"][domain].append(
71
- results[percentage]["accuracy"]
72
- )
73
- else:
74
- assert False, f"no file for {graml_res_file_path}"
75
- if os.path.exists(graql_res_file_path):
76
- with open(graql_res_file_path, "rb") as results_file:
77
- results = dill.load(results_file)
78
- accuracies["graql"][domain].append(
79
- results[percentage]["accuracy"]
80
- )
81
- else:
82
- assert False, f"no file for {graql_res_file_path}"
83
-
84
- def plot_accuracies(accuracies, partial_obs_type):
85
- plt.figure(figsize=(10, 6))
86
- colors = plt.cm.get_cmap(
87
- "tab10", len(accuracies["graml"]) * len(accuracies["graml"]["parking"])
88
- )
89
-
90
- # Define different line styles for each algorithm
91
- line_styles = {"graml": "-", "graql": "--"}
92
- x_vals = np.arange(3, 8)
93
- plt.xticks(x_vals)
94
- plt.yticks(np.linspace(0, 1, 6))
95
- plt.ylim([0, 1])
96
- # Plot each domain-env pair's accuracies with different line styles for each algorithm
97
- for alg in ["graml", "graql"]:
98
- for idx, (domain, acc_values) in enumerate(accuracies[alg].items()):
99
- if acc_values and len(acc_values) > 0: # Only plot if there are values
100
- x_values = np.arange(3, len(acc_values) + 3)
101
- plt.plot(
102
- x_values,
103
- acc_values,
104
- marker="o",
105
- linestyle=line_styles[alg],
106
- color=colors(idx),
107
- label=f"{alg}-{domain}-{partial_obs_type}-{percentage}",
108
- )
109
-
110
- # Set labels, title, and grid
111
- plt.xlabel(x_label_str)
112
- plt.ylabel("Accuracy")
113
- plt.grid(True)
114
-
115
- # Add legend to differentiate between domain-env pairs
116
- plt.legend()
117
-
118
- # Save the figure
119
- fig_path = os.path.join(f"{graph_name}_{partial_obs_type}.png")
120
- plt.savefig(fig_path)
121
- print(f"Accuracies figure saved at: {fig_path}")
122
-
123
- print(f"fragmented_accuracies: {fragmented_accuracies}")
124
- plot_accuracies(fragmented_accuracies, "fragmented")
125
- print(f"continuing_accuracies: {continuing_accuracies}")
126
- plot_accuracies(continuing_accuracies, "continuing")
127
-
128
-
129
- if __name__ == "__main__":
130
- # gen_graph("increasing_base_goals", "Number of base goals", ['L1', 'L2', 'L3', 'L4', 'L5'], panda_env='gd_agent', minigrid_env='obstacles', parking_env='gd_agent', maze_env='obstacles')
131
- # gen_graph("increasing_dynamic_goals", "Number of dynamic goals", ['L1', 'L2', 'L3', 'L4', 'L5'], panda_env='gc_agent', minigrid_env='lava_crossing', parking_env='gc_agent', maze_env='four_rooms')
132
- gen_graph(
133
- "base_problems",
134
- "Number of goals",
135
- ["L111", "L222", "L333", "L444", "L555"],
136
- panda_env="gd_agent",
137
- minigrid_env="obstacles",
138
- parking_env="gc_agent",
139
- maze_env="obstacles",
140
- percentage="0.7",
141
- )
gr_libs/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.2.2'
21
- __version_tuple__ = version_tuple = (0, 2, 2)
20
+ __version__ = version = '0.2.5'
21
+ __version_tuple__ = version_tuple = (0, 2, 5)
@@ -1,67 +1,43 @@
1
1
  """ executes odgr_executor parallely on a set of problems defined in consts.py """
2
2
 
3
+ import argparse
3
4
  import concurrent.futures
4
5
  import os
5
6
  import subprocess
6
7
  import sys
7
- import threading
8
8
 
9
9
  import dill
10
10
  import numpy as np
11
11
 
12
12
  from gr_libs.ml.utils.storage import get_experiment_results_path
13
13
 
14
- # Define the lists
15
- # domains = ['minigrid', 'point_maze', 'parking', 'panda']
16
- # envs = {
17
- # 'minigrid': ['obstacles', 'lava_crossing'],
18
- # 'point_maze': ['four_rooms', 'lava_crossing'],
19
- # 'parking': ['gc_agent', 'gd_agent'],
20
- # 'panda': ['gc_agent', 'gd_agent']
21
- # }
22
- # tasks = {
23
- # 'minigrid': ['L111', 'L222', 'L333', 'L444', 'L555'],
24
- # 'point_maze': ['L111', 'L222', 'L333', 'L444', 'L555'],
25
- # 'parking': ['L111', 'L222', 'L333', 'L444', 'L555'],
26
- # 'panda': ['L111', 'L222', 'L333', 'L444', 'L555']
27
- # }
28
- configs = {
29
- "minigrid": {
30
- "MiniGrid-SimpleCrossingS13N4": ["L1", "L2", "L3", "L4", "L5"],
31
- "MiniGrid-LavaCrossingS9N2": ["L1", "L2", "L3", "L4", "L5"],
32
- }
33
- # 'point_maze': {
34
- # 'PointMaze-FourRoomsEnvDense-11x11': ['L1', 'L2', 'L3', 'L4', 'L5'],
35
- # 'PointMaze-ObstaclesEnvDense-11x11': ['L1', 'L2', 'L3', 'L4', 'L5']
36
- # }
37
- # 'parking': {
38
- # 'Parking-S-14-PC-': ['L1', 'L2', 'L3', 'L4', 'L5'],
39
- # 'Parking-S-14-PC-': ['L1', 'L2', 'L3', 'L4', 'L5']
40
- # }
41
- # 'panda': {
42
- # 'PandaMyReachDense': ['L1', 'L2', 'L3', 'L4', 'L5'],
43
- # 'PandaMyReachDense': ['L1', 'L2', 'L3', 'L4', 'L5']
44
- # }
45
- }
46
- # for minigrid:
47
- # TODO assert these instead i the beggingning of the code before beginning
48
- # with the actual threading
49
- recognizers = ["ExpertBasedGraml", "Graql"]
50
- # recognizers = ['Graql']
51
-
52
- # for point_maze:
53
- # recognizers = ['ExpertBasedGraml']
54
- # recognizers = ['Draco']
55
-
56
- # for parking:
57
- # recognizers = ['GCGraml']
58
- # recognizers = ['GCDraco']
14
+ parser = argparse.ArgumentParser()
15
+ parser.add_argument("--domains", nargs="+", required=True, help="List of domains")
16
+ parser.add_argument(
17
+ "--envs",
18
+ nargs="+",
19
+ required=True,
20
+ help="List of environments (same order as domains)",
21
+ )
22
+ parser.add_argument(
23
+ "--tasks", nargs="+", required=True, help="List of tasks (e.g. L1 L2 L3 L4 L5)"
24
+ )
25
+ parser.add_argument(
26
+ "--recognizers", nargs="+", required=True, help="List of recognizers"
27
+ )
28
+ parser.add_argument(
29
+ "--n", type=int, default=5, help="Number of times to execute each task"
30
+ )
31
+ args = parser.parse_args()
59
32
 
60
- # for panda:
61
- # recognizers = ['GCGraml']
62
- # recognizers = ['GCDraco']
33
+ # Build configs dynamically
34
+ configs = {}
35
+ for domain, env in zip(args.domains, args.envs):
36
+ configs.setdefault(domain, {})
37
+ configs[domain][env] = args.tasks
63
38
 
64
- n = 5 # Number of times to execute each task
39
+ recognizers = args.recognizers
40
+ n = args.n
65
41
 
66
42
 
67
43
  # Function to read results from the result file
@@ -97,40 +73,31 @@ def run_experiment(domain, env, task, recognizer, i, generate_new=False):
97
73
  Returns:
98
74
  tuple: A tuple containing the experiment details and the results.
99
75
  """
100
- cmd = f"python gr_libs/odgr_executor.py --domain {domain} --recognizer \
101
- {recognizer} --env_name {env} --task {task} --collect_stats"
102
- print(f"Starting execution: {cmd}")
76
+ cmd = f"python gr_libs/odgr_executor.py --domain {domain} --recognizer {recognizer} --env_name {env} --task {task} --collect_stats --experiment_num {i}"
103
77
  try:
104
78
  res_file_path = get_experiment_results_path(domain, env, task, recognizer)
105
- res_file_path_txt = os.path.join(res_file_path, "res.txt")
106
- i_res_file_path_txt = os.path.join(res_file_path, f"res_{i}.txt")
107
- res_file_path_pkl = os.path.join(res_file_path, "res.pkl")
108
79
  i_res_file_path_pkl = os.path.join(res_file_path, f"res_{i}.pkl")
80
+ i_res_file_path_txt = os.path.join(res_file_path, f"res_{i}.txt")
109
81
  if generate_new or (
110
82
  not os.path.exists(i_res_file_path_txt)
111
83
  or not os.path.exists(i_res_file_path_pkl)
112
84
  ):
113
- if os.path.exists(i_res_file_path_txt) or os.path.exists(
114
- i_res_file_path_pkl
115
- ):
116
- i_res_file_path_txt = i_res_file_path_txt.replace(f"_{i}", f"_{i}_new")
117
- i_res_file_path_pkl = i_res_file_path_pkl.replace(f"_{i}", f"_{i}_new")
118
- process = subprocess.Popen(cmd, shell=True)
119
- process.wait()
85
+ process = subprocess.Popen(
86
+ cmd,
87
+ shell=True,
88
+ stdout=subprocess.PIPE,
89
+ stderr=subprocess.PIPE,
90
+ text=True,
91
+ )
92
+ stdout, stderr = process.communicate()
120
93
  if process.returncode != 0:
121
- print(f"Execution failed: {cmd}")
122
- print(f"Error: {result.stderr}")
94
+ print(f"Execution failed: {cmd}\nSTDOUT:\n{stdout}\nSTDERR:\n{stderr}")
123
95
  return None
124
96
  else:
125
97
  print(f"Finished execution successfully: {cmd}")
126
- file_lock = threading.Lock()
127
- with file_lock:
128
- os.rename(res_file_path_pkl, i_res_file_path_pkl)
129
- os.rename(res_file_path_txt, i_res_file_path_txt)
130
98
  else:
131
99
  print(
132
- f"File {i_res_file_path_txt} already exists. Skipping execution \
133
- of {cmd}"
100
+ f"File {i_res_file_path_txt} already exists. Skipping execution of {cmd}"
134
101
  )
135
102
  return ((domain, env, task, recognizer), read_results(i_res_file_path_pkl))
136
103
  except Exception as e:
@@ -252,43 +219,42 @@ for key, percentage_dict in compiled_accuracies.items():
252
219
  std_dev = np.std(accuracies)
253
220
  compiled_summary[key][percentage][is_cons] = (avg_accuracy, std_dev)
254
221
 
255
- # Write different summary results to different files
222
+ # Write different summary results to different files, one per recognizer
256
223
  if not os.path.exists(os.path.join("outputs", "summaries")):
257
224
  os.makedirs(os.path.join("outputs", "summaries"))
258
- detailed_summary_file_path = os.path.join(
259
- "outputs",
260
- "summaries",
261
- f"detailed_summary_{''.join(configs.keys())}_{recognizers[0]}.txt",
262
- )
263
- compiled_summary_file_path = os.path.join(
264
- "outputs",
265
- "summaries",
266
- f"compiled_summary_{''.join(configs.keys())}_{recognizers[0]}.txt",
267
- )
268
- with open(detailed_summary_file_path, "w") as f:
269
- for key, percentage_dict in detailed_summary.items():
270
- domain, env, task, recognizer = key
271
- f.write(f"{domain}\t{env}\t{task}\t{recognizer}\n")
272
- for percentage, cons_info in percentage_dict.items():
273
- for is_cons, (avg_accuracy, std_dev) in cons_info.items():
274
- f.write(
275
- f"\t\t{percentage}\t{is_cons}\t{avg_accuracy:.4f}\t{std_dev:.4f}\n"
276
- )
277
225
 
278
- with open(compiled_summary_file_path, "w") as f:
279
- for key, percentage_dict in compiled_summary.items():
280
- for percentage, cons_info in percentage_dict.items():
281
- for is_cons, (avg_accuracy, std_dev) in cons_info.items():
282
- f.write(
283
- f"{key[0]}\t{key[1]}\t{percentage}\t{is_cons}\t{avg_accuracy:.4f}\t{std_dev:.4f}\n"
284
- )
285
- domain, recognizer = key
286
- f.write(f"{domain}\t{recognizer}\n")
287
- for percentage, cons_info in percentage_dict.items():
288
- for is_cons, (avg_accuracy, std_dev) in cons_info.items():
289
- f.write(
290
- f"\t\t{percentage}\t{is_cons}\t{avg_accuracy:.4f}\t{std_dev:.4f}\n"
291
- )
226
+ for recognizer in recognizers:
227
+ compiled_summary_file_path = os.path.join(
228
+ "outputs",
229
+ "summaries",
230
+ f"compiled_summary_{''.join(configs.keys())}_{recognizer}.txt",
231
+ )
232
+ with open(compiled_summary_file_path, "w") as f:
233
+ for key, percentage_dict in compiled_summary.items():
234
+ domain, recog = key
235
+ if recog != recognizer:
236
+ continue # Only write results for this recognizer
237
+ for percentage, cons_info in percentage_dict.items():
238
+ for is_cons, (avg_accuracy, std_dev) in cons_info.items():
239
+ f.write(
240
+ f"{domain}\t{recog}\t{percentage}\t{is_cons}\t{avg_accuracy:.4f}\t{std_dev:.4f}\n"
241
+ )
242
+ print(f"Compiled summary results written to {compiled_summary_file_path}")
292
243
 
293
- print(f"Detailed summary results written to {detailed_summary_file_path}")
294
- print(f"Compiled summary results written to {compiled_summary_file_path}")
244
+ detailed_summary_file_path = os.path.join(
245
+ "outputs",
246
+ "summaries",
247
+ f"detailed_summary_{''.join(configs.keys())}_{recognizer}.txt",
248
+ )
249
+ with open(detailed_summary_file_path, "w") as f:
250
+ for key, percentage_dict in detailed_summary.items():
251
+ domain, env, task, recog = key
252
+ if recog != recognizer:
253
+ continue # Only write results for this recognizer
254
+ f.write(f"{domain}\t{env}\t{task}\t{recog}\n")
255
+ for percentage, cons_info in percentage_dict.items():
256
+ for is_cons, (avg_accuracy, std_dev) in cons_info.items():
257
+ f.write(
258
+ f"\t\t{percentage}\t{is_cons}\t{avg_accuracy:.4f}\t{std_dev:.4f}\n"
259
+ )
260
+ print(f"Detailed summary results written to {detailed_summary_file_path}")
@@ -1,8 +1,10 @@
1
1
  """ environment.py """
2
2
 
3
3
  import os
4
+ import sys
4
5
  from abc import abstractmethod
5
6
  from collections import namedtuple
7
+ from contextlib import contextmanager
6
8
 
7
9
  import gymnasium as gym
8
10
  import numpy as np
@@ -23,6 +25,23 @@ LSTMProperties = namedtuple(
23
25
  )
24
26
 
25
27
 
28
+ @contextmanager
29
+ def suppress_output():
30
+ """
31
+ Context manager to suppress stdout and stderr (including C/C++ prints).
32
+ """
33
+ with open(os.devnull, "w") as devnull:
34
+ old_stdout = sys.stdout
35
+ old_stderr = sys.stderr
36
+ sys.stdout = devnull
37
+ sys.stderr = devnull
38
+ try:
39
+ yield
40
+ finally:
41
+ sys.stdout = old_stdout
42
+ sys.stderr = old_stderr
43
+
44
+
26
45
  class EnvProperty:
27
46
  """
28
47
  Base class for environment properties.
@@ -135,9 +154,10 @@ class EnvProperty:
135
154
 
136
155
  def create_vec_env(self, kwargs):
137
156
  """
138
- Create a vectorized environment.
157
+ Create a vectorized environment, suppressing prints from gym/pybullet/panda-gym.
139
158
  """
140
- env = gym.make(**kwargs)
159
+ with suppress_output():
160
+ env = gym.make(**kwargs)
141
161
  return DummyVecEnv([lambda: env])
142
162
 
143
163
  @abstractmethod
@@ -0,0 +1,100 @@
1
+ import argparse
2
+ import os
3
+
4
+ import dill
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+
8
+ from gr_libs.ml.utils.storage import get_experiment_results_path
9
+
10
+
11
+ def load_results(domain, env, task, recognizer, n_runs, percentage, cons_type):
12
+ # Collect accuracy for a single task and recognizer
13
+ accs = []
14
+ res_dir = get_experiment_results_path(domain, env, task, recognizer)
15
+ if not os.path.exists(res_dir):
16
+ return accs
17
+ for i in range(n_runs):
18
+ res_file = os.path.join(res_dir, f"res_{i}.pkl")
19
+ if not os.path.exists(res_file):
20
+ continue
21
+ with open(res_file, "rb") as f:
22
+ results = dill.load(f)
23
+ if percentage in results and cons_type in results[percentage]:
24
+ acc = results[percentage][cons_type].get("accuracy")
25
+ if acc is not None:
26
+ accs.append(acc)
27
+ return accs
28
+
29
+
30
+ def main():
31
+ parser = argparse.ArgumentParser()
32
+ parser.add_argument("--domain", required=True)
33
+ parser.add_argument("--env", required=True)
34
+ parser.add_argument("--tasks", nargs="+", required=True)
35
+ parser.add_argument("--recognizers", nargs="+", required=True)
36
+ parser.add_argument("--n_runs", type=int, default=5)
37
+ parser.add_argument("--percentage", required=True)
38
+ parser.add_argument(
39
+ "--cons_type", choices=["consecutive", "non_consecutive"], required=True
40
+ )
41
+ parser.add_argument("--graph_name", type=str, default="experiment_results")
42
+ args = parser.parse_args()
43
+
44
+ plt.figure(figsize=(7, 5))
45
+ has_data = False
46
+ missing_recognizers = []
47
+
48
+ for recognizer in args.recognizers:
49
+ x_vals = []
50
+ y_means = []
51
+ y_sems = []
52
+ for task in args.tasks:
53
+ accs = load_results(
54
+ args.domain,
55
+ args.env,
56
+ task,
57
+ recognizer,
58
+ args.n_runs,
59
+ args.percentage,
60
+ args.cons_type,
61
+ )
62
+ if accs:
63
+ x_vals.append(task)
64
+ y_means.append(np.mean(accs))
65
+ y_sems.append(np.std(accs) / np.sqrt(len(accs)))
66
+ if x_vals:
67
+ has_data = True
68
+ x_ticks = np.arange(len(x_vals))
69
+ plt.plot(x_ticks, y_means, marker="o", label=recognizer)
70
+ plt.fill_between(
71
+ x_ticks,
72
+ np.array(y_means) - np.array(y_sems),
73
+ np.array(y_means) + np.array(y_sems),
74
+ alpha=0.2,
75
+ )
76
+ plt.xticks(x_ticks, x_vals)
77
+ else:
78
+ print(
79
+ f"Warning: No data found for recognizer '{recognizer}' in {args.domain} / {args.env} / {args.percentage} / {args.cons_type}"
80
+ )
81
+ missing_recognizers.append(recognizer)
82
+
83
+ if not has_data:
84
+ raise RuntimeError(
85
+ f"No data found for any recognizer in {args.domain} / {args.env} / {args.percentage} / {args.cons_type}. "
86
+ f"Missing recognizers: {', '.join(missing_recognizers)}"
87
+ )
88
+
89
+ plt.xlabel("Task")
90
+ plt.ylabel("Accuracy")
91
+ plt.title(f"{args.domain} - {args.env} ({args.percentage}, {args.cons_type})")
92
+ plt.legend()
93
+ plt.grid(True)
94
+ fig_path = f"{args.graph_name}_{'_'.join(args.recognizers)}_{args.domain}_{args.env}_{args.percentage}_{args.cons_type}.png"
95
+ plt.savefig(fig_path)
96
+ print(f"Figure saved at: {fig_path}")
97
+
98
+
99
+ if __name__ == "__main__":
100
+ main()
@@ -5,7 +5,7 @@ from types import MethodType
5
5
  import cv2
6
6
  import numpy as np
7
7
 
8
- from gr_libs.environment.environment import EnvProperty
8
+ from gr_libs.environment.environment import EnvProperty, suppress_output
9
9
 
10
10
  if __name__ != "__main__":
11
11
  from gr_libs.ml.utils.storage import get_agent_model_dir
@@ -184,12 +184,7 @@ class DeepRLAgent:
184
184
  """
185
185
  fourcc = cv2.VideoWriter_fourcc("m", "p", "4", "v")
186
186
  fps = 30.0
187
- # if is_gc:
188
- # assert goal_idx is not None
189
- # self.reset_with_goal_idx(goal_idx)
190
- # else:
191
- # assert goal_idx is None
192
- self.env.reset()
187
+ self.safe_env_reset()
193
188
  frame_size = (
194
189
  self.env.render(mode="rgb_array").shape[1],
195
190
  self.env.render(mode="rgb_array").shape[0],
@@ -198,7 +193,7 @@ class DeepRLAgent:
198
193
  video_writer = cv2.VideoWriter(video_path, fourcc, fps, frame_size)
199
194
  general_done, success_done = False, False
200
195
  gc.collect()
201
- obs = self.env.reset()
196
+ obs = self.safe_env_reset()
202
197
  self.env_prop.change_goal_to_specific_desired(obs, desired)
203
198
  counter = 0
204
199
  while not (general_done or success_done):
@@ -209,17 +204,11 @@ class DeepRLAgent:
209
204
  general_done = general_done[0]
210
205
  self.env_prop.change_goal_to_specific_desired(obs, desired)
211
206
  if "success" in info[0].keys():
212
- success_done = info[0][
213
- "success"
214
- ] # make sure the agent actually reached the goal within the max time
207
+ success_done = info[0]["success"]
215
208
  elif "is_success" in info[0].keys():
216
- success_done = info[0][
217
- "is_success"
218
- ] # make sure the agent actually reached the goal within the max time
209
+ success_done = info[0]["is_success"]
219
210
  elif "step_task_completions" in info[0].keys():
220
- success_done = (
221
- len(info[0]["step_task_completions"]) == 1
222
- ) # bug of dummyVecEnv, it removes the episode_task_completions from the info dict.
211
+ success_done = len(info[0]["step_task_completions"]) == 1
223
212
  else:
224
213
  raise NotImplementedError(
225
214
  "no other option for any of the environments."
@@ -270,17 +259,17 @@ class DeepRLAgent:
270
259
 
271
260
  def safe_env_reset(self):
272
261
  """
273
- Reset the environment safely.
262
+ Reset the environment safely, suppressing output.
274
263
 
275
264
  Returns:
276
265
  The initial observation.
277
266
  """
278
267
  try:
279
- obs = self.env.reset()
268
+ obs = suppress_env_reset(self.env)
280
269
  except Exception:
281
270
  kwargs = {"id": self.problem_name, "render_mode": "rgb_array"}
282
271
  self.env = self.env_prop.create_vec_env(kwargs)
283
- obs = self.env.reset()
272
+ obs = suppress_env_reset(self.env)
284
273
  return obs
285
274
 
286
275
  def get_mean_and_std_dev(self, observation):
@@ -632,3 +621,11 @@ class GCDeepRLAgent(DeepRLAgent):
632
621
  desired=goal_directed_goal,
633
622
  )
634
623
  return observations
624
+
625
+
626
+ def suppress_env_reset(env):
627
+ """
628
+ Utility function to suppress prints during env.reset().
629
+ """
630
+ with suppress_output():
631
+ return env.reset()