gr-libs 0.2.2__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. gr_libs/_evaluation/_generate_experiments_results.py +0 -141
  2. gr_libs/_version.py +2 -2
  3. gr_libs/all_experiments.py +73 -107
  4. gr_libs/environment/environment.py +22 -2
  5. gr_libs/evaluation/generate_experiments_results.py +100 -0
  6. gr_libs/ml/neural/deep_rl_learner.py +17 -20
  7. gr_libs/odgr_executor.py +20 -25
  8. gr_libs/problems/consts.py +568 -290
  9. gr_libs/recognizer/_utils/__init__.py +1 -0
  10. gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +12 -1
  11. gr_libs/recognizer/graml/graml_recognizer.py +16 -8
  12. gr_libs/tutorials/gcdraco_panda_tutorial.py +6 -2
  13. gr_libs/tutorials/gcdraco_parking_tutorial.py +3 -1
  14. gr_libs/tutorials/graml_minigrid_tutorial.py +16 -12
  15. gr_libs/tutorials/graml_panda_tutorial.py +6 -2
  16. gr_libs/tutorials/graml_parking_tutorial.py +3 -1
  17. gr_libs/tutorials/graml_point_maze_tutorial.py +15 -2
  18. {gr_libs-0.2.2.dist-info → gr_libs-0.2.5.dist-info}/METADATA +27 -16
  19. {gr_libs-0.2.2.dist-info → gr_libs-0.2.5.dist-info}/RECORD +26 -25
  20. {gr_libs-0.2.2.dist-info → gr_libs-0.2.5.dist-info}/WHEEL +1 -1
  21. tests/test_odgr_executor_expertbasedgraml.py +14 -0
  22. tests/test_odgr_executor_gcdraco.py +14 -0
  23. tests/test_odgr_executor_gcgraml.py +14 -0
  24. tests/test_odgr_executor_graql.py +14 -0
  25. gr_libs/_evaluation/_analyze_results_cross_alg_cross_domain.py +0 -260
  26. gr_libs/_evaluation/_generate_task_specific_statistics_plots.py +0 -497
  27. gr_libs/_evaluation/_get_plans_images.py +0 -61
  28. gr_libs/_evaluation/_increasing_and_decreasing_.py +0 -106
  29. /gr_libs/{_evaluation → evaluation}/__init__.py +0 -0
  30. {gr_libs-0.2.2.dist-info → gr_libs-0.2.5.dist-info}/top_level.txt +0 -0
gr_libs/_evaluation/_generate_task_specific_statistics_plots.py (deleted)
@@ -1,497 +0,0 @@
- import argparse
- import os
-
- import dill
- import matplotlib.pyplot as plt
- import numpy as np
- import torch
-
- from gr_libs.metrics.metrics import measure_average_sequence_distance
- from gr_libs.ml.utils import get_embeddings_result_path
- from gr_libs.ml.utils.storage import (
-     get_graql_experiment_confidence_path,
-     set_global_storage_configs,
- )
-
-
- def get_tasks_embeddings_dir_path(env_name):
-     return os.path.join("../gr_libs", get_embeddings_result_path(env_name))
-
-
- def get_figures_dir_path(domain_name, env_name):
-     return os.path.join("../gr_libs", "figures", domain_name, env_name)
-
-
- def similarities_vector_to_std_deviation_units_vector(
-     ref_dict: dict, relative_to_largest
- ):
-     """
-     Calculate the number of standard deviation units every other element is
-     from the largest/smallest element in the vector.
-
-     Parameters:
-     - vector: list or numpy array of numbers.
-     - relative_to_largest: boolean, if True, measure in relation to the largest element,
-       if False, measure in relation to the smallest element.
-
-     Returns:
-     - List of number of standard deviation units for each element in the vector.
-     """
-     vector = np.array(list(ref_dict.values()))
-     mean = np.mean(vector)  # for the future maybe another method for measurement
-     std_dev = np.std(vector)
-
-     # Determine the reference element (largest or smallest)
-     if relative_to_largest:
-         reference_value = np.max(vector)
-     else:
-         reference_value = np.min(vector)
-     for goal, value in ref_dict.items():
-         ref_dict[goal] = abs(value - reference_value) / std_dev
-     return ref_dict
-
-
- def analyze_and_produce_plots(
-     recognizer_type: str,
-     domain_name: str,
-     env_name: str,
-     fragmented_status: str,
-     inf_same_length_status: str,
-     learn_same_length_status: str,
- ):
-     if recognizer_type == "graml":
-         assert os.path.exists(
-             get_embeddings_result_path(domain_name)
-         ), "Embeddings weren't made for this environment, run graml_main.py with this environment first."
-         tasks_embedding_dicts = {}
-         tasks_plans_dict = {}
-         goals_similarity_dict = {}
-         plans_similarity_dict = {}
-
-         embeddings_dir_path = get_tasks_embeddings_dir_path(domain_name)
-         for embeddings_file_name in [
-             filename
-             for filename in os.listdir(embeddings_dir_path)
-             if "embeddings" in filename
-         ]:
-             with open(
-                 os.path.join(embeddings_dir_path, embeddings_file_name), "rb"
-             ) as emb_file:
-                 splitted_name = embeddings_file_name.split("_")
-                 goal, percentage = splitted_name[0], splitted_name[1]
-                 with open(
-                     os.path.join(
-                         embeddings_dir_path, f"{goal}_{percentage}_plans_dict.pkl"
-                     ),
-                     "rb",
-                 ) as plan_file:
-                     tasks_plans_dict[f"{goal}_{percentage}"] = dill.load(plan_file)
-                 tasks_embedding_dicts[f"{goal}_{percentage}"] = dill.load(emb_file)
-
-         for goal_percentage, embedding_dict in tasks_embedding_dicts.items():
-             goal, percentage = goal_percentage.split("_")
-             similarities = {
-                 dynamic_goal: []
-                 for dynamic_goal in embedding_dict.keys()
-                 if "true" not in dynamic_goal
-             }
-             real_goal_embedding = embedding_dict[f"{goal}_true"]
-             for dynamic_goal, goal_embedding in embedding_dict.items():
-                 if "true" in dynamic_goal:
-                     continue
-                 curr_similarity = torch.exp(
-                     -torch.sum(torch.abs(goal_embedding - real_goal_embedding))
-                 )
-                 similarities[dynamic_goal] = curr_similarity.item()
-             if goal not in goals_similarity_dict.keys():
-                 goals_similarity_dict[goal] = {}
-             goals_similarity_dict[goal][percentage] = (
-                 similarities_vector_to_std_deviation_units_vector(
-                     ref_dict=similarities, relative_to_largest=True
-                 )
-             )
-
-         for goal_percentage, plans_dict in tasks_plans_dict.items():
-             goal, percentage = goal_percentage.split("_")
-             real_plan = plans_dict[f"{goal}_true"]
-             sequence_similarities = {
-                 d_goal: measure_average_sequence_distance(real_plan, plan)
-                 for d_goal, plan in plans_dict.items()
-                 if "true" not in d_goal
-             }  # aps = agent plan sequence?
-             if goal not in plans_similarity_dict.keys():
-                 plans_similarity_dict[goal] = {}
-             plans_similarity_dict[goal][percentage] = (
-                 similarities_vector_to_std_deviation_units_vector(
-                     ref_dict=sequence_similarities, relative_to_largest=False
-                 )
-             )
-
-         goals = list(goals_similarity_dict.keys())
-         percentages = sorted(
-             {
-                 percentage
-                 for similarities in goals_similarity_dict.values()
-                 for percentage in similarities.keys()
-             }
-         )
-         num_percentages = len(percentages)
-         fig_string = f"{recognizer_type}_{domain_name}_{env_name}_{fragmented_status}_{inf_same_length_status}_{learn_same_length_status}"
-
-     else:  # algorithm = "graql"
-         assert os.path.exists(
-             get_graql_experiment_confidence_path(domain_name)
-         ), "Embeddings weren't made for this environment, run graml_main.py with this environment first."
-         tasks_scores_dict = {}
-         goals_similarity_dict = {}
-         experiments_dir_path = get_graql_experiment_confidence_path(domain_name)
-         for experiments_file_name in os.listdir(experiments_dir_path):
-             with open(
-                 os.path.join(experiments_dir_path, experiments_file_name), "rb"
-             ) as exp_file:
-                 splitted_name = experiments_file_name.split("_")
-                 goal, percentage = splitted_name[1], splitted_name[2]
-                 tasks_scores_dict[f"{goal}_{percentage}"] = dill.load(exp_file)
-
-         for goal_percentage, scores_list in tasks_scores_dict.items():
-             goal, percentage = goal_percentage.split("_")
-             similarities = {
-                 dynamic_goal: score for (dynamic_goal, score) in scores_list
-             }
-             if goal not in goals_similarity_dict.keys():
-                 goals_similarity_dict[goal] = {}
-             goals_similarity_dict[goal][percentage] = (
-                 similarities_vector_to_std_deviation_units_vector(
-                     ref_dict=similarities, relative_to_largest=False
-                 )
-             )
-
-         goals = list(goals_similarity_dict.keys())
-         percentages = sorted(
-             {
-                 percentage
-                 for similarities in goals_similarity_dict.values()
-                 for percentage in similarities.keys()
-             }
-         )
-         num_percentages = len(percentages)
-         fig_string = f"{recognizer_type}_{domain_name}_{env_name}_{fragmented_status}"
-
-     # -------------------- Start of Confusion Matrix Code --------------------
-     # Initialize matrices of size len(goals) x len(goals)
-     confusion_matrix_goals, confusion_matrix_plans = np.zeros(
-         (len(goals), len(goals))
-     ), np.zeros((len(goals), len(goals)))
-
-     # if domain_name == 'point_maze' and args.task == 'L555':
-     #     if env_name == 'obstacles':
-     #         goals = ['(4, 7)', '(3, 6)', '(5, 5)', '(8, 8)', '(6, 3)', '(7, 4)']
-     #     else:  # if env_name is 'four_rooms'
-     #         goals = ['(2, 8)', '(3, 7)', '(3, 4)', '(4, 4)', '(4, 3)', '(7, 3)', '(8, 2)']
-
-     # Populate confusion matrix with similarity values for goals
-     for i, true_goal in enumerate(goals):
-         for j, dynamic_goal in enumerate(goals):
-             percentage = percentages[-3]
-             confusion_matrix_goals[i, j] = goals_similarity_dict[true_goal][
-                 percentage
-             ].get(dynamic_goal, 0)
-
-     if plans_similarity_dict:
-         # Populate confusion matrix with similarity values for plans
-         for i, true_goal in enumerate(goals):
-             for j, dynamic_goal in enumerate(goals):
-                 percentage = percentages[-1]
-                 confusion_matrix_plans[i, j] = plans_similarity_dict[true_goal][
-                     percentage
-                 ].get(dynamic_goal, 0)
-
-     # Create the figure and subplots for the unified display
-     fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6), sharex=True)
-
-     # Plot for goal similarities
-     im1 = ax1.imshow(confusion_matrix_goals, cmap="Blues", interpolation="nearest")
-     cbar1 = fig.colorbar(im1, ax=ax1, fraction=0.046, pad=0.04)
-     cbar1.set_label("St. dev from most probable goal", fontsize=18)
-     ax1.set_title("Embeddings", fontsize=22, pad=20)
-     ax1.set_xticks(np.arange(len(goals)))
-     ax1.set_xticklabels(goals, rotation=45, ha="right", fontsize=16)
-     ax1.set_yticks(np.arange(len(goals)))
-     ax1.set_yticklabels(goals, fontsize=16)  # y-tick labels for ax1
-
-     # Plot for plan similarities
-     im2 = ax2.imshow(confusion_matrix_plans, cmap="Greens", interpolation="nearest")
-     cbar2 = fig.colorbar(im2, ax=ax2, fraction=0.046, pad=0.04)
-     cbar2.set_label("Distance between plans", fontsize=18)
-     ax2.set_title("Sequences", fontsize=22, pad=20)
-     ax2.set_xticks(np.arange(len(goals)))
-     ax2.set_xticklabels(goals, rotation=45, ha="right", fontsize=16)
-     ax2.set_yticks(np.arange(len(goals)))  # y-ticks for ax2 explicitly
-     ax2.set_yticklabels(goals, fontsize=16)  # y-tick labels for ax2
-
-     # Adjust the figure layout to reduce overlap
-     plt.subplots_adjust(left=0.15, right=0.9, bottom=0.25, top=0.85, wspace=0.1)
-
-     # Unified axis labels, placed closer to the left
-     fig.text(0.57, 0.07, "Goals Adaptation Phase", ha="center", fontsize=22)
-     fig.text(
-         0.12, 0.5, "Inference Phase", va="center", rotation="vertical", fontsize=22
-     )
-
-     # Save the combined plot
-     fig_dir = get_figures_dir_path(domain_name=domain_name, env_name=env_name)
-     if not os.path.exists(fig_dir):
-         os.makedirs(fig_dir)
-     confusion_matrix_combined_path = os.path.join(
-         fig_dir, f"{fig_string}_combined_conf_mat.png"
-     )
-     plt.savefig(confusion_matrix_combined_path, dpi=300)
-     print(
-         f"Combined confusion matrix figure saved at: {confusion_matrix_combined_path}"
-     )
-
-     # -------------------- End of Confusion Matrix Code --------------------
-     fig, axes = plt.subplots(
-         nrows=num_percentages, ncols=1, figsize=(10, 6 * num_percentages)
-     )
-
-     if num_percentages == 1:
-         axes = [axes]
-
-     for i, percentage in enumerate(percentages):
-         correct_tasks, tasks_num = 0, 0
-         ax = axes[i]
-         dynamic_goals = list(
-             next(iter(goals_similarity_dict.values()))[percentage].keys()
-         )
-         num_goals = len(goals)
-         num_dynamic_goals = len(dynamic_goals)
-         bar_width = 0.8 / num_dynamic_goals
-         bar_positions = np.arange(num_goals)
-
-         if recognizer_type == "graml":
-             for j, dynamic_goal in enumerate(dynamic_goals):
-                 goal_similarities = [
-                     goals_similarity_dict[goal][percentage][dynamic_goal] + 0.04
-                     for goal in goals
-                 ]
-                 plan_similarities = [
-                     plans_similarity_dict[goal][percentage][dynamic_goal] + 0.04
-                     for goal in goals
-                 ]
-                 ax.bar(
-                     bar_positions + j * bar_width,
-                     goal_similarities,
-                     bar_width / 2,
-                     label=f"embedding of {dynamic_goal}",
-                 )
-                 ax.bar(
-                     bar_positions + j * bar_width + bar_width / 2,
-                     plan_similarities,
-                     bar_width / 2,
-                     label=f"plan to {dynamic_goal}",
-                 )
-         else:
-             for j, dynamic_goal in enumerate(dynamic_goals):
-                 goal_similarities = [
-                     goals_similarity_dict[goal][percentage][dynamic_goal] + 0.04
-                     for goal in goals
-                 ]
-                 ax.bar(
-                     bar_positions + j * bar_width,
-                     goal_similarities,
-                     bar_width,
-                     label=f"policy to {dynamic_goal}",
-                 )
-
-         x_labels = []
-         for true_goal in goals:
-             guessed_goal = min(
-                 goals_similarity_dict[true_goal][percentage],
-                 key=goals_similarity_dict[true_goal][percentage].get,
-             )
-             tasks_num += 1
-             if true_goal == guessed_goal:
-                 correct_tasks += 1
-             second_lowest_value = sorted(
-                 goals_similarity_dict[true_goal][percentage].values()
-             )[1]
-             confidence_level = abs(
-                 goals_similarity_dict[true_goal][percentage][guessed_goal]
-                 - second_lowest_value
-             )
-             label = f"True: {true_goal}\nGuessed: {guessed_goal}\nConfidence: {confidence_level:.2f}"
-             x_labels.append(label)
-
-         ax.set_ylabel("Distance (units in st. deviations)", fontsize=10)
-         ax.set_title(
-             f"Confidence level for {domain_name}, {env_name}, {fragmented_status}. Accuracy: {correct_tasks / tasks_num}",
-             fontsize=12,
-         )
-         ax.set_xticks(bar_positions + bar_width * (num_dynamic_goals - 1) / 2)
-         ax.set_xticklabels(x_labels, fontsize=8)
-         ax.legend()
-
-     fig_path = os.path.join(fig_dir, f"{fig_string}_stats.png")
-     fig.savefig(fig_path)
-     print(f"general figure saved at: {fig_path}")
-
-
- def parse_args():
-     parser = argparse.ArgumentParser(
-         description="Parse command-line arguments for the RL experiment.",
-         formatter_class=argparse.RawTextHelpFormatter,
-     )
-
-     # Required arguments
-     required_group = parser.add_argument_group("Required arguments")
-     required_group.add_argument(
-         "--domain",
-         choices=["point_maze", "minigrid", "parking", "franka_kitchen", "panda"],
-         required=True,
-         help="Domain type (point_maze, minigrid, parking, or franka_kitchen)",
-     )
-     required_group.add_argument(
-         "--recognizer",
-         choices=["graml", "graql", "draco"],
-         required=True,
-         help="Recognizer type (graml, graql, draco). graql only for discrete domains.",
-     )
-     required_group.add_argument(
-         "--task",
-         choices=[
-             "L1",
-             "L2",
-             "L3",
-             "L4",
-             "L5",
-             "L11",
-             "L22",
-             "L33",
-             "L44",
-             "L55",
-             "L111",
-             "L222",
-             "L333",
-             "L444",
-             "L555",
-         ],
-         required=True,
-         help="Task identifier (e.g., L1, L2,...,L5)",
-     )
-     required_group.add_argument(
-         "--partial_obs_type",
-         required=True,
-         choices=["fragmented", "continuing"],
-         help="Give fragmented or continuing partial observations for inference phase inputs.",
-     )
-
-     # Optional arguments
-     optional_group = parser.add_argument_group("Optional arguments")
-     optional_group.add_argument(
-         "--minigrid_env",
-         choices=["four_rooms", "obstacles"],
-         help="Minigrid environment (four_rooms or obstacles)",
-     )
-     optional_group.add_argument(
-         "--parking_env",
-         choices=["gd_agent", "gc_agent"],
-         help="Parking environment (agent or gc_agent)",
-     )
-     optional_group.add_argument(
-         "--point_maze_env",
-         choices=["obstacles", "four_rooms"],
-         help="Parking environment (agent or gc_agent)",
-     )
-     optional_group.add_argument(
-         "--franka_env",
-         choices=["comb1", "comb2"],
-         help="Franka Kitchen environment (comb1 or comb2)",
-     )
-     optional_group.add_argument(
-         "--panda_env",
-         choices=["gc_agent", "gd_agent"],
-         help="Panda Robotics environment (gc_agent or gd_agent)",
-     )
-     optional_group.add_argument(
-         "--learn_same_seq_len",
-         action="store_true",
-         help="Learn with the same sequence length",
-     )
-     optional_group.add_argument(
-         "--inference_same_seq_len",
-         action="store_true",
-         help="Infer with the same sequence length",
-     )
-
-     args = parser.parse_args()
-
-     ### VALIDATE INPUTS ###
-     # Assert that all required arguments are provided
-     assert (
-         args.domain is not None
-         and args.recognizer is not None
-         and args.task is not None
-     ), "Missing required arguments: domain, recognizer, or task"
-
-     # Validate the combination of domain and environment
-     if args.domain == "minigrid" and args.minigrid_env is None:
-         parser.error(
-             "Missing required argument: --minigrid_env must be provided when --domain is minigrid"
-         )
-     elif args.domain == "parking" and args.parking_env is None:
-         parser.error(
-             "Missing required argument: --parking_env must be provided when --domain is parking"
-         )
-     elif args.domain == "point_maze" and args.point_maze_env is None:
-         parser.error(
-             "Missing required argument: --point_maze_env must be provided when --domain is point_maze"
-         )
-     elif args.domain == "franka_kitchen" and args.franka_env is None:
-         parser.error(
-             "Missing required argument: --franka_env must be provided when --domain is franka_kitchen"
-         )
-
-     if args.recognizer != "graml":
-         if args.learn_same_seq_len == True:
-             parser.error("learn_same_seq_len is only relevant for graml.")
-         if args.inference_same_seq_len == True:
-             parser.error("inference_same_seq_len is only relevant for graml.")
-
-     return args
-
-
- if __name__ == "__main__":
-     args = parse_args()
-     set_global_storage_configs(
-         recognizer_str=args.recognizer,
-         is_fragmented=args.partial_obs_type,
-         is_inference_same_length_sequences=args.inference_same_seq_len,
-         is_learn_same_length_sequences=args.learn_same_seq_len,
-     )
-     (env_name,) = (
-         x
-         for x in [
-             args.minigrid_env,
-             args.parking_env,
-             args.point_maze_env,
-             args.franka_env,
-         ]
-         if isinstance(x, str)
-     )
-     if args.inference_same_seq_len:
-         inference_same_seq_len = "inference_same_seq_len"
-     else:
-         inference_same_seq_len = "inference_diff_seq_len"
-     if args.learn_same_seq_len:
-         learn_same_seq_len = "learn_same_seq_len"
-     else:
-         learn_same_seq_len = "learn_diff_seq_len"
-     analyze_and_produce_plots(
-         args.recognizer,
-         domain_name=args.domain,
-         env_name=env_name,
-         fragmented_status=args.partial_obs_type,
-         inf_same_length_status=inference_same_seq_len,
-         learn_same_length_status=learn_same_seq_len,
-     )
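
A note on the deleted similarities_vector_to_std_deviation_units_vector helper above: it rescales a dict of per-goal scores so each entry becomes the absolute distance from the best score, measured in units of the standard deviation of all scores. A minimal standalone sketch of the same arithmetic, with invented goal labels and scores:

    import numpy as np

    def to_std_dev_units(ref_dict, relative_to_largest=True):
        # Distance of each value from the max (or min) of the dict,
        # divided by the standard deviation of all values.
        values = np.array(list(ref_dict.values()))
        std_dev = np.std(values)  # zero when all scores tie; the original helper does not guard this either
        reference = values.max() if relative_to_largest else values.min()
        return {goal: abs(v - reference) / std_dev for goal, v in ref_dict.items()}

    similarities = {"(4, 7)": 0.91, "(3, 6)": 0.55, "(8, 8)": 0.12}
    print(to_std_dev_units(similarities))
    # {'(4, 7)': 0.0, '(3, 6)': ~1.11, '(8, 8)': ~2.45}

Because the reference goal maps to 0.0, the plotting code above can pick the predicted goal with a plain min() over the resulting dict.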

gr_libs/_evaluation/_get_plans_images.py (deleted)
@@ -1,61 +0,0 @@
- import inspect
- import os
- import pickle
- import sys
-
- currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
- GRAML_itself = os.path.dirname(currentdir)
- GRAML_includer = os.path.dirname(os.path.dirname(currentdir))
- sys.path.insert(0, GRAML_includer)
- sys.path.insert(0, GRAML_itself)
-
-
- def get_plans_result_path(env_name):
-     return os.path.join("dataset", (env_name), "plans")
-
-
- def get_policy_sequences_result_path(env_name):
-     return os.path.join("dataset", (env_name), "policy_sequences")
-
-
- # TODO: instead of loading the model and having it produce the sequence again, just save the sequence from the framework run, and have this script accept the whole path (including is_fragmented etc.)
- def analyze_and_produce_images(env_name):
-     models_dir = get_models_dir(env_name=env_name)
-     for dirname in os.listdir(models_dir):
-         if dirname.startswith("MiniGrid"):
-             model_dir = get_model_dir(
-                 env_name=env_name, model_name=dirname, class_name="MCTS"
-             )
-             model_file_path = os.path.join(model_dir, "mcts_model.pth")
-             try:
-                 with open(model_file_path, "rb") as file:  # Load the pre-existing model
-                     monteCarloTreeSearch = pickle.load(file)
-                     full_plan = monteCarloTreeSearch.generate_full_policy_sequence()
-                     plan = [pos for ((state, pos), action) in full_plan]
-                     plans_result_path = get_plans_result_path(env_name)
-                     if not os.path.exists(plans_result_path):
-                         os.makedirs(plans_result_path)
-                     img_path = os.path.join(get_plans_result_path(env_name), dirname)
-                     print(
-                         f"plan to {dirname} is:\n\t{plan}\ngenerating image at {img_path}."
-                     )
-                     create_sequence_image(plan, img_path, dirname)
-
-             except FileNotFoundError as e:
-                 print(
-                     f"Warning: {e.filename} doesn't exist. It's probably a base goal, not generating policy sequence for it."
-                 )
-
-
- if __name__ == "__main__":
-     # preventing circular imports. only needed for running this as main anyway.
-     from gr_libs.ml.utils.storage import get_model_dir, get_models_dir
-
-     # checks:
-     assert (
-         len(sys.argv) == 2
-     ), f"Assertion failed: len(sys.argv) is {len(sys.argv)} while it needs to be 2.\n Example: \n\t /usr/bin/python scripts/get_plans_images.py MiniGrid-Walls-13x13-v0"
-     assert os.path.exists(
-         get_models_dir(sys.argv[1])
-     ), "plans weren't made for this environment, run graml_main.py with this environment first."
-     analyze_and_produce_images(sys.argv[1])
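
The deleted _get_plans_images.py script re-derives each plan by unpickling a saved MCTS model and calling its generate_full_policy_sequence() method. Its list comprehension assumes every step in that sequence is a ((state, pos), action) tuple; a toy sketch of that extraction, with invented step values:

    # Each step pairs an (observation, position) tuple with the action taken there.
    full_plan = [
        (("obs0", (1, 1)), "right"),
        (("obs1", (1, 2)), "right"),
        (("obs2", (1, 3)), "down"),
    ]
    plan = [pos for ((state, pos), action) in full_plan]
    print(plan)  # [(1, 1), (1, 2), (1, 3)]

This keeps only the visited positions, which is what the script then passes to create_sequence_image.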

gr_libs/_evaluation/_increasing_and_decreasing_.py (deleted)
@@ -1,106 +0,0 @@
- import os
-
- import dill
- import matplotlib.pyplot as plt
- import numpy as np
-
- from gr_libs.ml.utils.storage import (
-     get_experiment_results_path,
-     set_global_storage_configs,
- )
-
- if __name__ == "__main__":
-
-     # Define the tasks and percentages
-     increasing_base_goals = ["L1", "L2", "L3", "L4", "L5"]
-     increasing_dynamic_goals = ["L111", "L222", "L555", "L333", "L444"]
-     percentages = ["0.3", "0.5", "0.7", "0.9", "1"]
-
-     # Prepare a dictionary to hold accuracy data
-     accuracies = {
-         task: {perc: [] for perc in percentages}
-         for task in increasing_base_goals + increasing_dynamic_goals
-     }
-
-     # Collect data for both sets of goals
-     for task in increasing_base_goals + increasing_dynamic_goals:
-         set_global_storage_configs(
-             recognizer_str="graml",
-             is_fragmented="fragmented",
-             is_inference_same_length_sequences=True,
-             is_learn_same_length_sequences=False,
-         )
-         res_file_path = (
-             f'{get_experiment_results_path("parking", "gd_agent", task)}.pkl'
-         )
-
-         if os.path.exists(res_file_path):
-             with open(res_file_path, "rb") as results_file:
-                 results = dill.load(results_file)
-                 for percentage in percentages:
-                     accuracies[task][percentage].append(results[percentage]["accuracy"])
-         else:
-             print(f"Warning: no file for {res_file_path}")
-
-     # Create the figure with two subplots
-     fig, axes = plt.subplots(1, 2, figsize=(12, 6))
-
-     # Bar plot function
-     def plot_accuracies(ax, task_set, title, type):
-         """Plot accuracies for a given set of tasks on the provided axis."""
-         x_vals = np.arange(len(task_set))  # X-axis positions for the number of goals
-         bar_width = 0.15  # Width of each bar
-         for i, perc in enumerate(["0.3", "0.5", "1"]):
-             if perc == "1":
-                 y_vals = [
-                     max(
-                         [
-                             accuracies[task]["0.5"][0],
-                             accuracies[task]["0.7"][0],
-                             accuracies[task]["0.9"][0],
-                             accuracies[task]["1"][0],
-                         ]
-                     )
-                     for task in task_set
-                 ]  # Get mean accuracies
-             else:
-                 y_vals = [
-                     accuracies[task][perc][0] for task in task_set
-                 ]  # Get mean accuracies
-             if type != "base":
-                 ax.bar(
-                     x_vals + i * bar_width,
-                     y_vals,
-                     width=bar_width,
-                     label=f"Percentage {perc}",
-                 )
-             else:
-                 ax.bar(x_vals + i * bar_width, y_vals, width=bar_width)
-         ax.set_xticks(x_vals + bar_width)  # Center x-ticks
-         ax.set_xticklabels(
-             [i + 3 for i in range(len(task_set))], fontsize=16
-         )  # Set custom x-tick labels
-         ax.set_yticks(np.linspace(0, 1, 6))
-         ax.set_ylim([0, 1])
-         ax.set_title(title, fontsize=20)
-         ax.set_xlabel(f"Number of {type} Goals", fontsize=20)
-         if type == "base":
-             ax.set_ylabel("Accuracy", fontsize=22)
-         ax.legend()
-
-     # Plot for increasing base goals
-     plot_accuracies(axes[0], increasing_base_goals, "Increasing Base Goals", "base")
-
-     # Plot for increasing dynamic goals
-     plot_accuracies(
-         axes[1], increasing_dynamic_goals, "Increasing Active Goals", "active"
-     )
-     plt.subplots_adjust(
-         left=0.1, right=0.9, top=0.9, bottom=0.1, wspace=0.3, hspace=0.3
-     )
-     # Adjust layout and save the figure
-     plt.tight_layout()
-     plt.savefig(
-         "increasing_goals_plot_bars.png", dpi=300
-     )  # Save the figure as a PNG file
-     print("Figure saved at: increasing_goals_plot_bars.png")