gr-libs 0.1.8__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (68)
  1. gr_libs/__init__.py +3 -1
  2. gr_libs/_evaluation/__init__.py +1 -0
  3. evaluation/analyze_results_cross_alg_cross_domain.py → gr_libs/_evaluation/_analyze_results_cross_alg_cross_domain.py +81 -88
  4. evaluation/generate_experiments_results.py → gr_libs/_evaluation/_generate_experiments_results.py +6 -6
  5. evaluation/generate_task_specific_statistics_plots.py → gr_libs/_evaluation/_generate_task_specific_statistics_plots.py +11 -14
  6. evaluation/get_plans_images.py → gr_libs/_evaluation/_get_plans_images.py +3 -4
  7. evaluation/increasing_and_decreasing_.py → gr_libs/_evaluation/_increasing_and_decreasing_.py +3 -1
  8. gr_libs/_version.py +2 -2
  9. gr_libs/all_experiments.py +294 -0
  10. gr_libs/environment/__init__.py +14 -1
  11. gr_libs/environment/{utils → _utils}/utils.py +1 -1
  12. gr_libs/environment/environment.py +257 -22
  13. gr_libs/metrics/__init__.py +2 -0
  14. gr_libs/metrics/metrics.py +166 -31
  15. gr_libs/ml/__init__.py +1 -6
  16. gr_libs/ml/base/__init__.py +3 -1
  17. gr_libs/ml/base/rl_agent.py +68 -3
  18. gr_libs/ml/neural/__init__.py +1 -3
  19. gr_libs/ml/neural/deep_rl_learner.py +227 -67
  20. gr_libs/ml/neural/utils/__init__.py +1 -2
  21. gr_libs/ml/planner/mcts/{utils → _utils}/tree.py +1 -1
  22. gr_libs/ml/planner/mcts/mcts_model.py +71 -34
  23. gr_libs/ml/sequential/__init__.py +0 -1
  24. gr_libs/ml/sequential/{lstm_model.py → _lstm_model.py} +11 -14
  25. gr_libs/ml/tabular/__init__.py +1 -3
  26. gr_libs/ml/tabular/tabular_q_learner.py +27 -9
  27. gr_libs/ml/tabular/tabular_rl_agent.py +22 -9
  28. gr_libs/ml/utils/__init__.py +2 -9
  29. gr_libs/ml/utils/format.py +13 -90
  30. gr_libs/ml/utils/math.py +3 -2
  31. gr_libs/ml/utils/other.py +2 -2
  32. gr_libs/ml/utils/storage.py +41 -94
  33. gr_libs/odgr_executor.py +268 -0
  34. gr_libs/problems/consts.py +2 -2
  35. gr_libs/recognizer/_utils/__init__.py +0 -0
  36. gr_libs/recognizer/{utils → _utils}/format.py +2 -2
  37. gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +116 -36
  38. gr_libs/recognizer/graml/{gr_dataset.py → _gr_dataset.py} +11 -11
  39. gr_libs/recognizer/graml/graml_recognizer.py +172 -29
  40. gr_libs/recognizer/recognizer.py +59 -10
  41. gr_libs/tutorials/draco_panda_tutorial.py +58 -0
  42. gr_libs/tutorials/draco_parking_tutorial.py +56 -0
  43. {tutorials → gr_libs/tutorials}/gcdraco_panda_tutorial.py +5 -9
  44. {tutorials → gr_libs/tutorials}/gcdraco_parking_tutorial.py +3 -7
  45. {tutorials → gr_libs/tutorials}/graml_minigrid_tutorial.py +2 -2
  46. {tutorials → gr_libs/tutorials}/graml_panda_tutorial.py +5 -10
  47. {tutorials → gr_libs/tutorials}/graml_parking_tutorial.py +5 -9
  48. {tutorials → gr_libs/tutorials}/graml_point_maze_tutorial.py +2 -1
  49. {tutorials → gr_libs/tutorials}/graql_minigrid_tutorial.py +2 -2
  50. {gr_libs-0.1.8.dist-info → gr_libs-0.2.2.dist-info}/METADATA +84 -29
  51. gr_libs-0.2.2.dist-info/RECORD +71 -0
  52. {gr_libs-0.1.8.dist-info → gr_libs-0.2.2.dist-info}/WHEEL +1 -1
  53. gr_libs-0.2.2.dist-info/top_level.txt +2 -0
  54. tests/test_draco.py +14 -0
  55. tests/test_gcdraco.py +2 -2
  56. tests/test_graml.py +4 -4
  57. tests/test_graql.py +1 -1
  58. evaluation/create_minigrid_map_image.py +0 -38
  59. evaluation/file_system.py +0 -53
  60. evaluation/generate_experiments_results_new_ver1.py +0 -238
  61. evaluation/generate_experiments_results_new_ver2.py +0 -331
  62. gr_libs/ml/neural/utils/penv.py +0 -60
  63. gr_libs/recognizer/utils/__init__.py +0 -1
  64. gr_libs-0.1.8.dist-info/RECORD +0 -70
  65. gr_libs-0.1.8.dist-info/top_level.txt +0 -4
  66. /gr_libs/environment/{utils → _utils}/__init__.py +0 -0
  67. /gr_libs/ml/planner/mcts/{utils → _utils}/__init__.py +0 -0
  68. /gr_libs/ml/planner/mcts/{utils → _utils}/node.py +0 -0
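Most of the renames above move formerly public helper packages (environment/utils, ml/planner/mcts/utils, recognizer/utils, ml/sequential/lstm_model.py, graml/gr_dataset.py) behind underscore-prefixed, package-internal paths, and the evaluation and tutorial scripts move inside the gr_libs package. A minimal sketch, assuming nothing beyond the old and new module paths listed above, of how downstream code could check which layout is installed; the helper below is hypothetical and not part of gr_libs:

import importlib.util

def has_module(name: str) -> bool:
    # Locate the module without importing it; a missing parent package counts as absent.
    try:
        return importlib.util.find_spec(name) is not None
    except ModuleNotFoundError:
        return False

if has_module("gr_libs.environment._utils"):
    print("0.2.2 layout: helper packages are underscore-prefixed (internal)")
elif has_module("gr_libs.environment.utils"):
    print("0.1.8 layout: helper packages are public")
else:
    print("gr_libs not installed or layout not recognized")

The underscore prefix conventionally marks these modules as private, so imports that reach into them directly are liable to break again in future releases.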
evaluation/create_minigrid_map_image.py DELETED
@@ -1,38 +0,0 @@
- from minigrid.wrappers import RGBImgPartialObsWrapper, ImgObsWrapper
- import numpy as np
- import gr_libs.ml as ml
- from minigrid.core.world_object import Wall
-
- # from q_table_plot import save_q_table_plot_image
- from gymnasium.envs.registration import register
-
- env_name = "MiniGrid-SimpleCrossingS13N4-DynamicGoal-5x9-v0"
- # create an agent and train it (if it is already trained, it will get q-table from cache)
- agent = ml.TabularQLearner(
-     env_name="MiniGrid-Walls-13x13-v0",
-     problem_name="MiniGrid-SimpleCrossingS13N4-DynamicGoal-5x9-v0",
- )
- # agent.learn()
-
- # save_q_table_plot_image(agent.q_table, 15, 15, (10,7))
-
- # add to the steps list the step the trained agent would take on the env in every state according to the q_table
- env = agent.env
- env = RGBImgPartialObsWrapper(env) # Get pixel observations
- env = ImgObsWrapper(env) # Get rid of the 'mission' field
- obs, _ = env.reset() # This now produces an RGB tensor only
-
- img = env.get_frame()
-
- ####### save image to file
- from PIL import Image
- import numpy as np
-
- image_pil = Image.fromarray(np.uint8(img)).convert("RGB")
- image_pil.save(r"{}.png".format(env_name))
-
- # ####### show image
- # from gym_minigrid.window import Window
- # window = Window(r"z")
- # window.show_img(img=img)
- # window.close()
evaluation/file_system.py DELETED
@@ -1,53 +0,0 @@
- import os
- import dill
- import random
- import hashlib
- from typing import List
-
-
- def get_observations_path(env_name: str):
-     return f"dataset/{env_name}/observations"
-
-
- def get_observations_paths(path: str):
-     return [os.path.join(path, file_name) for file_name in os.listdir(path)]
-
-
- def create_partial_observabilities_files(env_name: str, observabilities: List[float]):
-     with open(
-         r"dataset/{env_name}/observations/obs1.0.pkl".format(env_name=env_name), "rb"
-     ) as f:
-         step_1_0 = dill.load(f)
-
-     number_of_items_to_randomize = [
-         int(observability * len(step_1_0)) for observability in observabilities
-     ]
-     obs = []
-     for items_to_randomize in number_of_items_to_randomize:
-         obs.append(random.sample(step_1_0, items_to_randomize))
-     for index, observability in enumerate(observabilities):
-         partial_steps = obs[index]
-         file_path = r"dataset/{env_name}/observations/obs{obs}.pkl".format(
-             env_name=env_name, obs=observability
-         )
-         with open(file_path, "wb+") as f:
-             dill.dump(partial_steps, f)
-
-
- def md5(file_path: str):
-     hash_md5 = hashlib.md5()
-     with open(file_path, "rb") as f:
-         for chunk in iter(lambda: f.read(4096), b""):
-             hash_md5.update(chunk)
-     return hash_md5.hexdigest()
-
-
- def get_md5(file_path_list: List[str]):
-     return [(file_path, md5(file_path=file_path)) for file_path in file_path_list]
-
-
- def print_md5(file_path_list: List[str]):
-     md5_of_observations = get_md5(file_path_list=file_path_list)
-     for file_name, file_md5 in md5_of_observations:
-         print(f"{file_name}:{file_md5}")
-     print("")
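For context, a hypothetical usage sketch of the deleted helpers above. The env name is taken from create_minigrid_map_image.py, the dataset/<env>/observations layout comes from the code itself, and the import assumes the pre-0.2.2 top-level evaluation package; none of this exists in 0.2.2:

from evaluation.file_system import (
    create_partial_observabilities_files,
    get_observations_path,
    get_observations_paths,
    print_md5,
)

env_name = "MiniGrid-SimpleCrossingS13N4-DynamicGoal-5x9-v0"
# Subsample the full-observability trace (obs1.0.pkl) into obs0.3/obs0.5/obs0.7.pkl.
create_partial_observabilities_files(env_name, observabilities=[0.3, 0.5, 0.7])
# Print an MD5 fingerprint for every observation file, e.g. to compare dataset copies.
print_md5(get_observations_paths(get_observations_path(env_name)))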
evaluation/generate_experiments_results_new_ver1.py DELETED
@@ -1,238 +0,0 @@
- import copy
- import sys
- import matplotlib.pyplot as plt
- import numpy as np
- import os
- import dill
-
- from gr_libs.ml.utils.storage import (
-     get_experiment_results_path,
-     set_global_storage_configs,
- )
- from scripts.generate_task_specific_statistics_plots import get_figures_dir_path
-
- if __name__ == "__main__":
-
-     fragmented_accuracies = {
-         "graml": {
-             "panda": {
-                 "gd_agent": {
-                     "0.3": [], # every list here should have number of tasks accuracies in it, since we done experiments for L111-L555. remember each accuracy is an average of #goals different tasks.
-                     "0.5": [],
-                     "0.7": [],
-                     "0.9": [],
-                     "1": [],
-                 },
-                 "gc_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
-             },
-             "minigrid": {
-                 "obstacles": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
-                 "lava_crossing": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
-             },
-             "point_maze": {
-                 "obstacles": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
-                 "four_rooms": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
-             },
-             "parking": {
-                 "gd_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
-                 "gc_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
-             },
-         },
-         "graql": {
-             "panda": {
-                 "gd_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
-                 "gc_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
-             },
-             "minigrid": {
-                 "obstacles": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
-                 "lava_crossing": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
-             },
-             "point_maze": {
-                 "obstacles": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
-                 "four_rooms": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
-             },
-             "parking": {
-                 "gd_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
-                 "gc_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
-             },
-         },
-     }
-
-     continuing_accuracies = copy.deepcopy(fragmented_accuracies)
-
-     # domains = ['panda', 'minigrid', 'point_maze', 'parking']
-     domains = ["minigrid", "point_maze", "parking"]
-     tasks = ["L111", "L222", "L333", "L444", "L555"]
-     percentages = ["0.3", "0.5", "0.7", "0.9", "1"]
-
-     for partial_obs_type, accuracies, is_same_learn in zip(
-         ["fragmented", "continuing"],
-         [fragmented_accuracies, continuing_accuracies],
-         [False, True],
-     ):
-         for domain in domains:
-             for env in accuracies["graml"][domain].keys():
-                 for task in tasks:
-                     set_global_storage_configs(
-                         recognizer_str="graml",
-                         is_fragmented=partial_obs_type,
-                         is_inference_same_length_sequences=True,
-                         is_learn_same_length_sequences=is_same_learn,
-                     )
-                     graml_res_file_path = (
-                         f"{get_experiment_results_path(domain, env, task)}.pkl"
-                     )
-                     set_global_storage_configs(
-                         recognizer_str="graql", is_fragmented=partial_obs_type
-                     )
-                     graql_res_file_path = (
-                         f"{get_experiment_results_path(domain, env, task)}.pkl"
-                     )
-                     if os.path.exists(graml_res_file_path):
-                         with open(graml_res_file_path, "rb") as results_file:
-                             results = dill.load(results_file)
-                             for percentage in accuracies["graml"][domain][env].keys():
-                                 accuracies["graml"][domain][env][percentage].append(
-                                     results[percentage]["accuracy"]
-                                 )
-                     else:
-                         assert (False, f"no file for {graml_res_file_path}")
-                     if os.path.exists(graql_res_file_path):
-                         with open(graql_res_file_path, "rb") as results_file:
-                             results = dill.load(results_file)
-                             for percentage in accuracies["graml"][domain][env].keys():
-                                 accuracies["graql"][domain][env][percentage].append(
-                                     results[percentage]["accuracy"]
-                                 )
-                     else:
-                         assert (False, f"no file for {graql_res_file_path}")
-
-     plot_styles = {
-         ("graml", "fragmented", 0.3): "g--o", # Green dashed line with circle markers
-         ("graml", "fragmented", 0.5): "g--s", # Green dashed line with square markers
-         (
-             "graml",
-             "fragmented",
-             0.7,
-         ): "g--^", # Green dashed line with triangle-up markers
-         ("graml", "fragmented", 0.9): "g--d", # Green dashed line with diamond markers
-         ("graml", "fragmented", 1.0): "g--*", # Green dashed line with star markers
-         ("graml", "continuing", 0.3): "g-o", # Green solid line with circle markers
-         ("graml", "continuing", 0.5): "g-s", # Green solid line with square markers
-         (
-             "graml",
-             "continuing",
-             0.7,
-         ): "g-^", # Green solid line with triangle-up markers
-         ("graml", "continuing", 0.9): "g-d", # Green solid line with diamond markers
-         ("graml", "continuing", 1.0): "g-*", # Green solid line with star markers
-         ("graql", "fragmented", 0.3): "b--o", # Blue dashed line with circle markers
-         ("graql", "fragmented", 0.5): "b--s", # Blue dashed line with square markers
-         (
-             "graql",
-             "fragmented",
-             0.7,
-         ): "b--^", # Blue dashed line with triangle-up markers
-         ("graql", "fragmented", 0.9): "b--d", # Blue dashed line with diamond markers
-         ("graql", "fragmented", 1.0): "b--*", # Blue dashed line with star markers
-         ("graql", "continuing", 0.3): "b-o", # Blue solid line with circle markers
-         ("graql", "continuing", 0.5): "b-s", # Blue solid line with square markers
-         ("graql", "continuing", 0.7): "b-^", # Blue solid line with triangle-up markers
-         ("graql", "continuing", 0.9): "b-d", # Blue solid line with diamond markers
-         ("graql", "continuing", 1.0): "b-*", # Blue solid line with star markers
-     }
-
-     def average_accuracies(accuracies, domain):
-         avg_acc = {
-             algo: {perc: [] for perc in percentages} for algo in ["graml", "graql"]
-         }
-
-         for algo in avg_acc.keys():
-             for perc in percentages:
-                 for env in accuracies[algo][domain].keys():
-                     env_acc = accuracies[algo][domain][env][
-                         perc
-                     ] # list of 5, averages for L111 to L555.
-                     if env_acc:
-                         avg_acc[algo][perc].append(np.array(env_acc))
-
-         for algo in avg_acc.keys():
-             for perc in percentages:
-                 if avg_acc[algo][perc]:
-                     avg_acc[algo][perc] = np.mean(np.array(avg_acc[algo][perc]), axis=0)
-
-         return avg_acc
-
-     def plot_domain_accuracies(
-         ax, fragmented_accuracies, continuing_accuracies, domain
-     ):
-         fragmented_avg_acc = average_accuracies(fragmented_accuracies, domain)
-         continuing_avg_acc = average_accuracies(continuing_accuracies, domain)
-
-         x_vals = np.arange(1, 6) # Number of goals
-
-         # Create "waves" (shaded regions) for each algorithm
-         for algo in ["graml", "graql"]:
-             for perc in percentages:
-                 fragmented_y_vals = np.array(fragmented_avg_acc[algo][perc])
-                 continuing_y_vals = np.array(continuing_avg_acc[algo][perc])
-
-                 ax.plot(
-                     x_vals,
-                     fragmented_y_vals,
-                     plot_styles[
-                         (algo, "fragmented", float(perc))
-                     ], # Use the updated plot_styles dictionary with percentage
-                     label=f"{algo}, non-consecutive, {perc}",
-                 )
-                 ax.plot(
-                     x_vals,
-                     continuing_y_vals,
-                     plot_styles[
-                         (algo, "continuing", float(perc))
-                     ], # Use the updated plot_styles dictionary with percentage
-                     label=f"{algo}, consecutive, {perc}",
-                 )
-
-         ax.set_xticks(x_vals)
-         ax.set_yticks(np.linspace(0, 1, 6))
-         ax.set_ylim([0, 1])
-         ax.set_title(f"{domain.capitalize()} Domain", fontsize=16)
-         ax.grid(True)
-
-     fig, axes = plt.subplots(
-         1, 4, figsize=(24, 6)
-     ) # Increase the figure size for better spacing (width 24, height 6)
-
-     # Generate each plot in a subplot, including both fragmented and continuing accuracies
-     for i, domain in enumerate(domains):
-         plot_domain_accuracies(
-             axes[i], fragmented_accuracies, continuing_accuracies, domain
-         )
-
-     # Set a single x-axis and y-axis label for the entire figure
-     fig.text(
-         0.5, 0.04, "Number of Goals", ha="center", fontsize=20
-     ) # Centered x-axis label
-     fig.text(
-         0.04, 0.5, "Accuracy", va="center", rotation="vertical", fontsize=20
-     ) # Reduced spacing for y-axis label
-
-     # Adjust subplot layout to avoid overlap
-     plt.subplots_adjust(
-         left=0.09, right=0.91, top=0.76, bottom=0.24, wspace=0.3
-     ) # More space on top (top=0.82)
-
-     # Place the legend above the plots with more space between legend and plots
-     handles, labels = axes[0].get_legend_handles_labels()
-     fig.legend(
-         handles,
-         labels,
-         loc="upper center",
-         ncol=4,
-         bbox_to_anchor=(0.5, 1.05),
-         fontsize=12,
-     ) # Moved above with bbox_to_anchor
-
-     # Save the figure and show it
-     plt.savefig("accuracy_plots.png", dpi=300)
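Both removed aggregation scripts (this one and generate_experiments_results_new_ver2.py below) guard missing result files with assert (False, f"no file for ..."). A parenthesized tuple is always truthy, so that assertion can never fail and a missing file is silently skipped. A minimal sketch of the guard that was presumably intended (the path is illustrative):

import os

results_path = "results/minigrid/obstacles/L111.pkl"  # illustrative path
if not os.path.exists(results_path):
    # assert (False, msg) is a no-op because the tuple is truthy; raise instead,
    # or use the tuple-free form: assert False, f"no file for {results_path}"
    raise FileNotFoundError(f"no file for {results_path}")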
evaluation/generate_experiments_results_new_ver2.py DELETED
@@ -1,331 +0,0 @@
- import copy
- import sys
- import matplotlib.pyplot as plt
- import numpy as np
- import os
- import dill
- from scipy.interpolate import make_interp_spline
- from scipy.ndimage import gaussian_filter1d
- from gr_libs.ml.utils.storage import (
-     get_experiment_results_path,
-     set_global_storage_configs,
- )
- from scripts.generate_task_specific_statistics_plots import get_figures_dir_path
-
-
- def smooth_line(x, y, num_points=300):
-     x_smooth = np.linspace(np.min(x), np.max(x), num_points)
-     spline = make_interp_spline(x, y, k=3) # Cubic spline
-     y_smooth = spline(x_smooth)
-     return x_smooth, y_smooth
-
-
- if __name__ == "__main__":
-
-     fragmented_accuracies = {
-         "graml": {
-             "panda": {
-                 "gc_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []}
-             },
-             "minigrid": {
-                 "obstacles": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
-                 "lava_crossing": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
-             },
-             "point_maze": {
-                 "obstacles": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
-                 "four_rooms": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
-             },
-             "parking": {
-                 "gc_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
-                 "gd_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
-             },
-         },
-         "graql": {
-             "panda": {
-                 "gc_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []}
-             },
-             "minigrid": {
-                 "obstacles": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
-                 "lava_crossing": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
-             },
-             "point_maze": {
-                 "obstacles": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
-                 "four_rooms": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
-             },
-             "parking": {
-                 "gc_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
-                 "gd_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
-             },
-         },
-     }
-
-     continuing_accuracies = copy.deepcopy(fragmented_accuracies)
-
-     # domains = ['panda', 'minigrid', 'point_maze', 'parking']
-     domains = ["parking"]
-     tasks = ["L555", "L444", "L333", "L222", "L111"]
-     percentages = ["0.3", "0.5", "1"]
-
-     for partial_obs_type, accuracies, is_same_learn in zip(
-         ["fragmented", "continuing"],
-         [fragmented_accuracies, continuing_accuracies],
-         [False, True],
-     ):
-         for domain in domains:
-             for env in accuracies["graml"][domain].keys():
-                 for task in tasks:
-                     set_global_storage_configs(
-                         recognizer_str="graml",
-                         is_fragmented=partial_obs_type,
-                         is_inference_same_length_sequences=True,
-                         is_learn_same_length_sequences=is_same_learn,
-                     )
-                     graml_res_file_path = (
-                         f"{get_experiment_results_path(domain, env, task)}.pkl"
-                     )
-                     set_global_storage_configs(
-                         recognizer_str="graql", is_fragmented=partial_obs_type
-                     )
-                     graql_res_file_path = (
-                         f"{get_experiment_results_path(domain, env, task)}.pkl"
-                     )
-                     if os.path.exists(graml_res_file_path):
-                         with open(graml_res_file_path, "rb") as results_file:
-                             results = dill.load(results_file)
-                             for percentage in accuracies["graml"][domain][env].keys():
-                                 accuracies["graml"][domain][env][percentage].append(
-                                     results[percentage]["accuracy"]
-                                 )
-                     else:
-                         assert (False, f"no file for {graml_res_file_path}")
-                     if os.path.exists(graql_res_file_path):
-                         with open(graql_res_file_path, "rb") as results_file:
-                             results = dill.load(results_file)
-                             for percentage in accuracies["graml"][domain][env].keys():
-                                 accuracies["graql"][domain][env][percentage].append(
-                                     results[percentage]["accuracy"]
-                                 )
-                     else:
-                         assert (False, f"no file for {graql_res_file_path}")
-
-     plot_styles = {
-         ("graml", "fragmented", 0.3): "g--o", # Green dashed line with circle markers
-         ("graml", "fragmented", 0.5): "g--s", # Green dashed line with square markers
-         (
-             "graml",
-             "fragmented",
-             0.7,
-         ): "g--^", # Green dashed line with triangle-up markers
-         ("graml", "fragmented", 0.9): "g--d", # Green dashed line with diamond markers
-         ("graml", "fragmented", 1.0): "g--*", # Green dashed line with star markers
-         ("graml", "continuing", 0.3): "g-o", # Green solid line with circle markers
-         ("graml", "continuing", 0.5): "g-s", # Green solid line with square markers
-         (
-             "graml",
-             "continuing",
-             0.7,
-         ): "g-^", # Green solid line with triangle-up markers
-         ("graml", "continuing", 0.9): "g-d", # Green solid line with diamond markers
-         ("graml", "continuing", 1.0): "g-*", # Green solid line with star markers
-         ("graql", "fragmented", 0.3): "b--o", # Blue dashed line with circle markers
-         ("graql", "fragmented", 0.5): "b--s", # Blue dashed line with square markers
-         (
-             "graql",
-             "fragmented",
-             0.7,
-         ): "b--^", # Blue dashed line with triangle-up markers
-         ("graql", "fragmented", 0.9): "b--d", # Blue dashed line with diamond markers
-         ("graql", "fragmented", 1.0): "b--*", # Blue dashed line with star markers
-         ("graql", "continuing", 0.3): "b-o", # Blue solid line with circle markers
-         ("graql", "continuing", 0.5): "b-s", # Blue solid line with square markers
-         ("graql", "continuing", 0.7): "b-^", # Blue solid line with triangle-up markers
-         ("graql", "continuing", 0.9): "b-d", # Blue solid line with diamond markers
-         ("graql", "continuing", 1.0): "b-*", # Blue solid line with star markers
-     }
-
-     def average_accuracies(accuracies, domain):
-         avg_acc = {
-             algo: {perc: [] for perc in percentages} for algo in ["graml", "graql"]
-         }
-
-         for algo in avg_acc.keys():
-             for perc in percentages:
-                 for env in accuracies[algo][domain].keys():
-                     env_acc = accuracies[algo][domain][env][
-                         perc
-                     ] # list of 5, averages for L111 to L555.
-                     if env_acc:
-                         avg_acc[algo][perc].append(np.array(env_acc))
-
-         for algo in avg_acc.keys():
-             for perc in percentages:
-                 if avg_acc[algo][perc]:
-                     avg_acc[algo][perc] = np.mean(np.array(avg_acc[algo][perc]), axis=0)
-
-         return avg_acc
-
-     def plot_domain_accuracies(
-         ax,
-         fragmented_accuracies,
-         continuing_accuracies,
-         domain,
-         sigma=1,
-         line_width=1.5,
-     ):
-         fragmented_avg_acc = average_accuracies(fragmented_accuracies, domain)
-         continuing_avg_acc = average_accuracies(continuing_accuracies, domain)
-
-         x_vals = np.arange(1, 6) # Number of goals
-
-         # Create "waves" (shaded regions) for each algorithm
-         for algo in ["graml", "graql"]:
-             fragmented_y_vals_by_percentage = []
-             continuing_y_vals_by_percentage = []
-
-             for perc in percentages:
-                 fragmented_y_vals = np.array(fragmented_avg_acc[algo][perc])
-                 continuing_y_vals = np.array(continuing_avg_acc[algo][perc])
-
-                 # Smooth the trends using Gaussian filtering
-                 fragmented_y_smoothed = gaussian_filter1d(
-                     fragmented_y_vals, sigma=sigma
-                 )
-                 continuing_y_smoothed = gaussian_filter1d(
-                     continuing_y_vals, sigma=sigma
-                 )
-
-                 fragmented_y_vals_by_percentage.append(fragmented_y_smoothed)
-                 continuing_y_vals_by_percentage.append(continuing_y_smoothed)
-
-                 ax.plot(
-                     x_vals,
-                     fragmented_y_smoothed,
-                     plot_styles[(algo, "fragmented", float(perc))],
-                     label=f"{algo}, non-consecutive, {perc}",
-                     linewidth=0.5, # Control line thickness here
-                 )
-                 ax.plot(
-                     x_vals,
-                     continuing_y_smoothed,
-                     plot_styles[(algo, "continuing", float(perc))],
-                     label=f"{algo}, consecutive, {perc}",
-                     linewidth=0.5, # Control line thickness here
-                 )
-
-             # Fill between trends of the same type that differ only by percentage
-             # for i in range(len(percentages) - 1):
-             # ax.fill_between(
-             # x_vals, fragmented_y_vals_by_percentage[i], fragmented_y_vals_by_percentage[i+1],
-             # color='green', alpha=0.1 # Adjust the fill color and transparency (for graml)
-             # )
-             # ax.fill_between(
-             # x_vals, continuing_y_vals_by_percentage[i], continuing_y_vals_by_percentage[i+1],
-             # color='blue', alpha=0.1 # Adjust the fill color and transparency (for graql)
-             # )
-
-         ax.set_xticks(x_vals)
-         ax.set_yticks(np.linspace(0, 1, 6))
-         ax.set_ylim([0, 1])
-         ax.set_title(f"{domain.capitalize()} Domain", fontsize=16)
-         ax.grid(True)
-
-     # COMMENT FROM HERE AND UNTIL NEXT FUNCTION FOR BG GC COMPARISON
-
-     # fig, axes = plt.subplots(1, 4, figsize=(24, 6)) # Increase the figure size for better spacing (width 24, height 6)
-
-     # # Generate each plot in a subplot, including both fragmented and continuing accuracies
-     # for i, domain in enumerate(domains):
-     # plot_domain_accuracies(axes[i], fragmented_accuracies, continuing_accuracies, domain)
-
-     # # Set a single x-axis and y-axis label for the entire figure
-     # fig.text(0.5, 0.04, 'Number of Goals', ha='center', fontsize=20) # Centered x-axis label
-     # fig.text(0.04, 0.5, 'Accuracy', va='center', rotation='vertical', fontsize=20) # Reduced spacing for y-axis label
-
-     # # Adjust subplot layout to avoid overlap
-     # plt.subplots_adjust(left=0.09, right=0.91, top=0.79, bottom=0.21, wspace=0.3) # More space on top (top=0.82)
-
-     # # Place the legend above the plots with more space between legend and plots
-     # handles, labels = axes[0].get_legend_handles_labels()
-     # fig.legend(handles, labels, loc='upper center', ncol=4, bbox_to_anchor=(0.5, 1.05), fontsize=12) # Moved above with bbox_to_anchor
-
-     # # Save the figure and show it
-     # plt.savefig('accuracy_plots_smooth.png', dpi=300)
-
-     # a specific comparison between bg-graml and gc-graml, aka gd_agent and gc_agent "envs":
-     def plot_stick_figures(continuing_accuracies, fragmented_accuracies, title):
-         fractions = ["0.3", "0.5", "1"]
-
-         def get_agent_data(data_dict, domain="graml", agent="gd_agent"):
-             return [
-                 np.mean(data_dict[domain]["parking"][agent][fraction])
-                 for fraction in fractions
-             ]
-
-         # Continuing accuracies for gd_agent and gc_agent
-         cont_gd = get_agent_data(
-             continuing_accuracies, domain="graml", agent="gd_agent"
-         )
-         cont_gc = get_agent_data(
-             continuing_accuracies, domain="graml", agent="gc_agent"
-         )
-
-         # Fragmented accuracies for gd_agent and gc_agent
-         frag_gd = get_agent_data(
-             fragmented_accuracies, domain="graml", agent="gd_agent"
-         )
-         frag_gc = get_agent_data(
-             fragmented_accuracies, domain="graml", agent="gc_agent"
-         )
-
-         # Debugging: Print values to check if they're non-zero
-         print("Continuing GD:", cont_gd)
-         print("Continuing GC:", cont_gc)
-         print("Fragmented GD:", frag_gd)
-         print("Fragmented GC:", frag_gc)
-
-         # Setting up figure
-         x = np.arange(len(fractions)) # label locations
-         width = 0.35 # width of the bars
-
-         fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6), sharey=True)
-
-         # Plot for continuing accuracies
-         ax1.bar(x - width / 2, cont_gd, width, label="BG-GRAML")
-         ax1.bar(x + width / 2, cont_gc, width, label="GC-GRAML")
-         ax1.set_title("Consecutive Sequences", fontsize=20)
-         ax1.set_xticks(x)
-         ax1.set_xticklabels(fractions, fontsize=16)
-         ax1.set_yticks(np.arange(0, 1.1, 0.2))
-         ax1.set_yticklabels(np.round(np.arange(0, 1.1, 0.2), 1), fontsize=16)
-         ax1.legend(fontsize=20)
-
-         # Plot for fragmented accuracies
-         ax2.bar(x - width / 2, frag_gd, width, label="BG-GRAML")
-         ax2.bar(x + width / 2, frag_gc, width, label="GC-GRAML")
-         ax2.set_title("Non-Consecutive Sequences", fontsize=20)
-         ax2.set_xticks(x)
-         ax2.set_xticklabels(fractions, fontsize=16)
-         ax2.set_yticks(np.arange(0, 1.1, 0.2))
-         ax2.set_yticklabels(np.round(np.arange(0, 1.1, 0.2), 1), fontsize=16)
-         ax2.set_ylim(0, 1) # Ensure the y-axis is properly set
-         ax2.legend(fontsize=20)
-         # Common axis labels
-         fig.text(
-             0.5, 0.02, "Observation Portion", ha="center", va="center", fontsize=24
-         )
-         fig.text(
-             0.06,
-             0.5,
-             "Accuracy",
-             ha="center",
-             va="center",
-             rotation="vertical",
-             fontsize=24,
-         )
-
-         plt.subplots_adjust(top=0.85)
-         plt.savefig("gd_vs_gc_parking.png", dpi=300)
-
-     plot_stick_figures(
-         continuing_accuracies, fragmented_accuracies, "GC-GRAML compared with BG-GRAML"
-     )