gr-libs 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. evaluation/analyze_results_cross_alg_cross_domain.py +277 -0
  2. evaluation/create_minigrid_map_image.py +34 -0
  3. evaluation/file_system.py +42 -0
  4. evaluation/generate_experiments_results.py +92 -0
  5. evaluation/generate_experiments_results_new_ver1.py +254 -0
  6. evaluation/generate_experiments_results_new_ver2.py +331 -0
  7. evaluation/generate_task_specific_statistics_plots.py +272 -0
  8. evaluation/get_plans_images.py +47 -0
  9. evaluation/increasing_and_decreasing_.py +63 -0
  10. gr_libs/__init__.py +2 -0
  11. gr_libs/environment/__init__.py +0 -0
  12. gr_libs/environment/environment.py +227 -0
  13. gr_libs/environment/utils/__init__.py +0 -0
  14. gr_libs/environment/utils/utils.py +17 -0
  15. gr_libs/metrics/__init__.py +0 -0
  16. gr_libs/metrics/metrics.py +224 -0
  17. gr_libs/ml/__init__.py +6 -0
  18. gr_libs/ml/agent.py +56 -0
  19. gr_libs/ml/base/__init__.py +1 -0
  20. gr_libs/ml/base/rl_agent.py +54 -0
  21. gr_libs/ml/consts.py +22 -0
  22. gr_libs/ml/neural/__init__.py +3 -0
  23. gr_libs/ml/neural/deep_rl_learner.py +395 -0
  24. gr_libs/ml/neural/utils/__init__.py +2 -0
  25. gr_libs/ml/neural/utils/dictlist.py +33 -0
  26. gr_libs/ml/neural/utils/penv.py +57 -0
  27. gr_libs/ml/planner/__init__.py +0 -0
  28. gr_libs/ml/planner/mcts/__init__.py +0 -0
  29. gr_libs/ml/planner/mcts/mcts_model.py +330 -0
  30. gr_libs/ml/planner/mcts/utils/__init__.py +2 -0
  31. gr_libs/ml/planner/mcts/utils/node.py +33 -0
  32. gr_libs/ml/planner/mcts/utils/tree.py +102 -0
  33. gr_libs/ml/sequential/__init__.py +1 -0
  34. gr_libs/ml/sequential/lstm_model.py +192 -0
  35. gr_libs/ml/tabular/__init__.py +3 -0
  36. gr_libs/ml/tabular/state.py +21 -0
  37. gr_libs/ml/tabular/tabular_q_learner.py +453 -0
  38. gr_libs/ml/tabular/tabular_rl_agent.py +126 -0
  39. gr_libs/ml/utils/__init__.py +6 -0
  40. gr_libs/ml/utils/env.py +7 -0
  41. gr_libs/ml/utils/format.py +100 -0
  42. gr_libs/ml/utils/math.py +13 -0
  43. gr_libs/ml/utils/other.py +24 -0
  44. gr_libs/ml/utils/storage.py +127 -0
  45. gr_libs/recognizer/__init__.py +0 -0
  46. gr_libs/recognizer/gr_as_rl/__init__.py +0 -0
  47. gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +102 -0
  48. gr_libs/recognizer/graml/__init__.py +0 -0
  49. gr_libs/recognizer/graml/gr_dataset.py +134 -0
  50. gr_libs/recognizer/graml/graml_recognizer.py +266 -0
  51. gr_libs/recognizer/recognizer.py +46 -0
  52. gr_libs/recognizer/utils/__init__.py +1 -0
  53. gr_libs/recognizer/utils/format.py +13 -0
  54. gr_libs-0.1.3.dist-info/METADATA +197 -0
  55. gr_libs-0.1.3.dist-info/RECORD +62 -0
  56. gr_libs-0.1.3.dist-info/WHEEL +5 -0
  57. gr_libs-0.1.3.dist-info/top_level.txt +3 -0
  58. tutorials/graml_minigrid_tutorial.py +30 -0
  59. tutorials/graml_panda_tutorial.py +32 -0
  60. tutorials/graml_parking_tutorial.py +38 -0
  61. tutorials/graml_point_maze_tutorial.py +43 -0
  62. tutorials/graql_minigrid_tutorial.py +29 -0
@@ -0,0 +1,254 @@
1
+ import copy
2
+ import sys
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import os
6
+ import dill
7
+
8
+ from gr_libs.ml.utils.storage import get_experiment_results_path, set_global_storage_configs
9
+ from scripts.generate_task_specific_statistics_plots import get_figures_dir_path
10
+
11
# Algorithms compared in every plot; both use the same nested result layout.
ALGORITHMS = ('graml', 'graql')

# Observation percentages read from each results file. They are string keys because the
# loaded results are indexed as results[percentage]['accuracy'] with these strings.
PERCENTAGES = ['0.3', '0.5', '0.7', '0.9', '1']

# Colour encodes the algorithm (green = graml, blue = graql), line style encodes the
# observation type (dashed = fragmented, solid = continuing), marker encodes the percentage.
MARKER_BY_PERCENTAGE = {0.3: 'o', 0.5: 's', 0.7: '^', 0.9: 'd', 1.0: '*'}
PLOT_STYLES = {
    (algo, obs_type, perc): f"{color}{line}{marker}"
    for algo, color in (('graml', 'g'), ('graql', 'b'))
    for obs_type, line in (('fragmented', '--'), ('continuing', '-'))
    for perc, marker in MARKER_BY_PERCENTAGE.items()
}

# Environments evaluated per domain. 'panda' stays in the layout even though it is
# currently not among the plotted domains (see main()).
ENVS_PER_DOMAIN = {
    'panda': ['gd_agent', 'gc_agent'],
    'minigrid': ['obstacles', 'lava_crossing'],
    'point_maze': ['obstacles', 'four_rooms'],
    'parking': ['gd_agent', 'gc_agent'],
}


def build_empty_accuracies(envs_per_domain, percentages):
    """Return ``{algo: {domain: {env: {percentage: []}}}}`` with independent empty lists.

    Each innermost list receives one accuracy per task from ``load_accuracies``.
    """
    return {
        algo: {
            domain: {env: {perc: [] for perc in percentages} for env in envs}
            for domain, envs in envs_per_domain.items()
        }
        for algo in ALGORITHMS
    }


def load_accuracies(accuracies, domains, tasks, partial_obs_type, is_same_learn):
    """Append the stored accuracy of every (domain, env, task, percentage) to ``accuracies``.

    The storage configuration is switched between the GRAML and GRAQL settings before each
    ``get_experiment_results_path`` call.

    Raises:
        FileNotFoundError: if a GRAML or GRAQL results file does not exist.
    """
    for domain in domains:
        for env in accuracies['graml'][domain]:
            for task in tasks:
                # NOTE(review): is_fragmented receives the string 'fragmented'/'continuing';
                # confirm that is what set_global_storage_configs expects.
                set_global_storage_configs(recognizer_str='graml', is_fragmented=partial_obs_type,
                                           is_inference_same_length_sequences=True,
                                           is_learn_same_length_sequences=is_same_learn)
                graml_res_file_path = f'{get_experiment_results_path(domain, env, task)}.pkl'
                set_global_storage_configs(recognizer_str='graql', is_fragmented=partial_obs_type)
                graql_res_file_path = f'{get_experiment_results_path(domain, env, task)}.pkl'
                for algo, res_file_path in (('graml', graml_res_file_path),
                                            ('graql', graql_res_file_path)):
                    # `assert(False, msg)` asserts a non-empty tuple and never fires,
                    # so a missing file must raise explicitly.
                    if not os.path.exists(res_file_path):
                        raise FileNotFoundError(f"no file for {res_file_path}")
                    with open(res_file_path, 'rb') as results_file:
                        results = dill.load(results_file)
                    for percentage, values in accuracies[algo][domain][env].items():
                        values.append(results[percentage]['accuracy'])


def average_accuracies(accuracies, domain, percentages):
    """Average per-task accuracy lists over all envs of ``domain``.

    Returns ``{algo: {perc: value}}`` where ``value`` is the element-wise mean (one entry
    per task) of the non-empty env lists, or an empty list when no env has data.
    """
    avg_acc = {algo: {} for algo in ALGORITHMS}
    for algo in ALGORITHMS:
        for perc in percentages:
            env_arrays = [np.array(env_accs[perc])
                          for env_accs in accuracies[algo][domain].values()
                          if env_accs[perc]]
            avg_acc[algo][perc] = np.mean(np.array(env_arrays), axis=0) if env_arrays else []
    return avg_acc


def plot_domain_accuracies(ax, fragmented_accuracies, continuing_accuracies, domain,
                           percentages, num_tasks=5):
    """Draw one line per (algorithm, percentage, observation type) for ``domain`` on ``ax``.

    The x axis is the task index 1..``num_tasks`` (labelled "Number of Goals" on the figure).
    """
    averaged = {
        'fragmented': average_accuracies(fragmented_accuracies, domain, percentages),
        'continuing': average_accuracies(continuing_accuracies, domain, percentages),
    }
    obs_labels = {'fragmented': 'non-consecutive', 'continuing': 'consecutive'}
    x_vals = np.arange(1, num_tasks + 1)

    for algo in ALGORITHMS:
        for perc in percentages:
            for obs_type in ('fragmented', 'continuing'):
                ax.plot(
                    x_vals, np.array(averaged[obs_type][algo][perc]),
                    PLOT_STYLES[(algo, obs_type, float(perc))],
                    label=f"{algo}, {obs_labels[obs_type]}, {perc}"
                )

    ax.set_xticks(x_vals)
    ax.set_yticks(np.linspace(0, 1, 6))
    ax.set_ylim([0, 1])
    ax.set_title(f'{domain.capitalize()} Domain', fontsize=16)
    ax.grid(True)


def main():
    """Load GRAML/GRAQL results for fragmented and continuing runs and save accuracy_plots.png."""
    domains = ['minigrid', 'point_maze', 'parking']  # 'panda' currently excluded
    tasks = ['L111', 'L222', 'L333', 'L444', 'L555']

    fragmented_accuracies = build_empty_accuracies(ENVS_PER_DOMAIN, PERCENTAGES)
    continuing_accuracies = build_empty_accuracies(ENVS_PER_DOMAIN, PERCENTAGES)

    # is_learn_same_length_sequences is False for fragmented runs and True for continuing runs.
    for partial_obs_type, accuracies, is_same_learn in zip(
            ['fragmented', 'continuing'],
            [fragmented_accuracies, continuing_accuracies],
            [False, True]):
        load_accuracies(accuracies, domains, tasks, partial_obs_type, is_same_learn)

    # One subplot per plotted domain (a fixed 4 left an empty panel for 3 domains).
    fig, axes = plt.subplots(1, len(domains), figsize=(24, 6), squeeze=False)
    axes = axes[0]
    for ax, domain in zip(axes, domains):
        plot_domain_accuracies(ax, fragmented_accuracies, continuing_accuracies, domain,
                               PERCENTAGES, num_tasks=len(tasks))

    # Shared axis labels for the whole figure.
    fig.text(0.5, 0.04, 'Number of Goals', ha='center', fontsize=20)
    fig.text(0.04, 0.5, 'Accuracy', va='center', rotation='vertical', fontsize=20)

    # Leave room at the top for the shared legend.
    plt.subplots_adjust(left=0.09, right=0.91, top=0.76, bottom=0.24, wspace=0.3)
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper center', ncol=4, bbox_to_anchor=(0.5, 1.05), fontsize=12)

    plt.savefig('accuracy_plots.png', dpi=300)


if __name__ == "__main__":
    main()
@@ -0,0 +1,331 @@
1
+ import copy
2
+ import sys
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import os
6
+ import dill
7
+ from scipy.interpolate import make_interp_spline
8
+ from scipy.ndimage import gaussian_filter1d
9
+ from gr_libs.ml.utils.storage import get_experiment_results_path, set_global_storage_configs
10
+ from scripts.generate_task_specific_statistics_plots import get_figures_dir_path
11
+
12
def smooth_line(x, y, num_points=300, k=3):
    """Return ``(x_smooth, y_smooth)`` sampled from an interpolating B-spline through (x, y).

    Args:
        x: strictly increasing sample positions.
        y: sample values, same length as ``x``.
        num_points: number of evenly spaced points between min(x) and max(x).
        k: maximum spline degree (cubic by default). It is clamped to ``len(x) - 1``
            because ``make_interp_spline`` needs at least ``k + 1`` points.

    Raises:
        ValueError: if fewer than two points are given.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.size < 2:
        raise ValueError("smooth_line needs at least two points")
    degree = min(k, x.size - 1)
    x_smooth = np.linspace(np.min(x), np.max(x), num_points)
    spline = make_interp_spline(x, y, k=degree)
    return x_smooth, spline(x_smooth)
17
+
18
# Algorithms compared in every plot; both use the same nested result layout.
ALGORITHMS = ('graml', 'graql')

# Every observation percentage read from each results file (string keys, as used in
# results[percentage]['accuracy']). Plots only use a subset of them (see main()).
STORED_PERCENTAGES = ['0.3', '0.5', '0.7', '0.9', '1']

# Colour encodes the algorithm (green = graml, blue = graql), line style encodes the
# observation type (dashed = fragmented, solid = continuing), marker encodes the percentage.
MARKER_BY_PERCENTAGE = {0.3: 'o', 0.5: 's', 0.7: '^', 0.9: 'd', 1.0: '*'}
PLOT_STYLES = {
    (algo, obs_type, perc): f"{color}{line}{marker}"
    for algo, color in (('graml', 'g'), ('graql', 'b'))
    for obs_type, line in (('fragmented', '--'), ('continuing', '-'))
    for perc, marker in MARKER_BY_PERCENTAGE.items()
}

# Environments ("agents" for panda/parking) evaluated per domain.
ENVS_PER_DOMAIN = {
    'panda': ['gc_agent'],
    'minigrid': ['obstacles', 'lava_crossing'],
    'point_maze': ['obstacles', 'four_rooms'],
    'parking': ['gc_agent', 'gd_agent'],
}


def build_empty_accuracies(envs_per_domain, percentages):
    """Return ``{algo: {domain: {env: {percentage: []}}}}`` with independent empty lists.

    Each innermost list receives one accuracy per task from ``load_accuracies``.
    """
    return {
        algo: {
            domain: {env: {perc: [] for perc in percentages} for env in envs}
            for domain, envs in envs_per_domain.items()
        }
        for algo in ALGORITHMS
    }


def load_accuracies(accuracies, domains, tasks, partial_obs_type, is_same_learn):
    """Append the stored accuracy of every (domain, env, task, percentage) to ``accuracies``.

    The storage configuration is switched between the GRAML and GRAQL settings before each
    ``get_experiment_results_path`` call.

    Raises:
        FileNotFoundError: if a GRAML or GRAQL results file does not exist.
    """
    for domain in domains:
        for env in accuracies['graml'][domain]:
            for task in tasks:
                # NOTE(review): is_fragmented receives the string 'fragmented'/'continuing';
                # confirm that is what set_global_storage_configs expects.
                set_global_storage_configs(recognizer_str='graml', is_fragmented=partial_obs_type,
                                           is_inference_same_length_sequences=True,
                                           is_learn_same_length_sequences=is_same_learn)
                graml_res_file_path = f'{get_experiment_results_path(domain, env, task)}.pkl'
                set_global_storage_configs(recognizer_str='graql', is_fragmented=partial_obs_type)
                graql_res_file_path = f'{get_experiment_results_path(domain, env, task)}.pkl'
                for algo, res_file_path in (('graml', graml_res_file_path),
                                            ('graql', graql_res_file_path)):
                    # `assert(False, msg)` asserts a non-empty tuple and never fires,
                    # so a missing file must raise explicitly.
                    if not os.path.exists(res_file_path):
                        raise FileNotFoundError(f"no file for {res_file_path}")
                    with open(res_file_path, 'rb') as results_file:
                        results = dill.load(results_file)
                    for percentage, values in accuracies[algo][domain][env].items():
                        values.append(results[percentage]['accuracy'])


def average_accuracies(accuracies, domain, percentages):
    """Average per-task accuracy lists over all envs of ``domain``.

    Returns ``{algo: {perc: value}}`` where ``value`` is the element-wise mean (one entry
    per task) of the non-empty env lists, or an empty list when no env has data.
    """
    avg_acc = {algo: {} for algo in ALGORITHMS}
    for algo in ALGORITHMS:
        for perc in percentages:
            env_arrays = [np.array(env_accs[perc])
                          for env_accs in accuracies[algo][domain].values()
                          if env_accs[perc]]
            avg_acc[algo][perc] = np.mean(np.array(env_arrays), axis=0) if env_arrays else []
    return avg_acc


def agent_mean_accuracies(accuracies, algo, domain, agent, fractions):
    """Return the mean accuracy over all tasks of ``agent`` for each observation fraction."""
    return [np.mean(accuracies[algo][domain][agent][fraction]) for fraction in fractions]


def plot_domain_accuracies(ax, fragmented_accuracies, continuing_accuracies, domain,
                           percentages, sigma=1, line_width=0.5, num_tasks=5):
    """Draw Gaussian-smoothed accuracy trends for ``domain`` on ``ax``.

    One line per (algorithm, percentage, observation type). The x axis is the position
    1..``num_tasks`` in the task list used when loading.

    Args:
        sigma: standard deviation passed to ``gaussian_filter1d``.
        line_width: width of every plotted line.
    """
    averaged = {
        'fragmented': average_accuracies(fragmented_accuracies, domain, percentages),
        'continuing': average_accuracies(continuing_accuracies, domain, percentages),
    }
    obs_labels = {'fragmented': 'non-consecutive', 'continuing': 'consecutive'}
    x_vals = np.arange(1, num_tasks + 1)

    for algo in ALGORITHMS:
        for perc in percentages:
            for obs_type in ('fragmented', 'continuing'):
                y_smoothed = gaussian_filter1d(np.array(averaged[obs_type][algo][perc]), sigma=sigma)
                ax.plot(
                    x_vals, y_smoothed,
                    PLOT_STYLES[(algo, obs_type, float(perc))],
                    label=f"{algo}, {obs_labels[obs_type]}, {perc}",
                    linewidth=line_width
                )

    ax.set_xticks(x_vals)
    ax.set_yticks(np.linspace(0, 1, 6))
    ax.set_ylim([0, 1])
    ax.set_title(f'{domain.capitalize()} Domain', fontsize=16)
    ax.grid(True)


def plot_accuracies_per_domain(fragmented_accuracies, continuing_accuracies, domains,
                               percentages, num_tasks=5, out_path='accuracy_plots_smooth.png'):
    """Save one smoothed accuracy subplot per domain, with a shared legend, to ``out_path``.

    Not called by ``main``, which produces the BG-vs-GC bar comparison instead.
    """
    fig, axes = plt.subplots(1, len(domains), figsize=(24, 6), squeeze=False)
    axes = axes[0]
    for ax, domain in zip(axes, domains):
        plot_domain_accuracies(ax, fragmented_accuracies, continuing_accuracies, domain,
                               percentages, num_tasks=num_tasks)

    fig.text(0.5, 0.04, 'Number of Goals', ha='center', fontsize=20)
    fig.text(0.04, 0.5, 'Accuracy', va='center', rotation='vertical', fontsize=20)
    plt.subplots_adjust(left=0.09, right=0.91, top=0.79, bottom=0.21, wspace=0.3)
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper center', ncol=4, bbox_to_anchor=(0.5, 1.05), fontsize=12)
    plt.savefig(out_path, dpi=300)


def plot_stick_figures(continuing_accuracies, fragmented_accuracies, title, algo='graml',
                       domain='parking', fractions=('0.3', '0.5', '1'),
                       out_path='gd_vs_gc_parking.png'):
    """Save a bar chart comparing gd_agent (BG-GRAML) with gc_agent (GC-GRAML).

    The left panel shows consecutive (continuing) sequences and the right panel shows
    non-consecutive (fragmented) sequences. Each bar is the mean accuracy over all tasks
    for one observation fraction. ``title`` is used as the figure's super-title.
    """
    fractions = list(fractions)
    cont_gd = agent_mean_accuracies(continuing_accuracies, algo, domain, 'gd_agent', fractions)
    cont_gc = agent_mean_accuracies(continuing_accuracies, algo, domain, 'gc_agent', fractions)
    frag_gd = agent_mean_accuracies(fragmented_accuracies, algo, domain, 'gd_agent', fractions)
    frag_gc = agent_mean_accuracies(fragmented_accuracies, algo, domain, 'gc_agent', fractions)

    # Print the plotted values so empty/NaN inputs are easy to spot.
    print("Continuing GD:", cont_gd)
    print("Continuing GC:", cont_gc)
    print("Fragmented GD:", frag_gd)
    print("Fragmented GC:", frag_gc)

    x = np.arange(len(fractions))
    width = 0.35
    y_ticks = np.arange(0, 1.1, 0.2)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6), sharey=True)
    panels = (
        (ax1, cont_gd, cont_gc, 'Consecutive Sequences'),
        (ax2, frag_gd, frag_gc, 'Non-Consecutive Sequences'),
    )
    for ax, gd_vals, gc_vals, panel_title in panels:
        ax.bar(x - width / 2, gd_vals, width, label='BG-GRAML')
        ax.bar(x + width / 2, gc_vals, width, label='GC-GRAML')
        ax.set_title(panel_title, fontsize=20)
        ax.set_xticks(x)
        ax.set_xticklabels(fractions, fontsize=16)
        ax.set_yticks(y_ticks)
        ax.set_yticklabels(np.round(y_ticks, 1), fontsize=16)
        ax.legend(fontsize=20)
    ax2.set_ylim(0, 1)  # the y axis is shared, so this limits both panels

    fig.suptitle(title, fontsize=24)
    fig.text(0.5, 0.02, 'Observation Portion', ha='center', va='center', fontsize=24)
    fig.text(0.06, 0.5, 'Accuracy', ha='center', va='center', rotation='vertical', fontsize=24)

    plt.subplots_adjust(top=0.85)
    plt.savefig(out_path, dpi=300)


def main():
    """Load parking results and save the BG-GRAML vs GC-GRAML comparison chart."""
    domains = ['parking']  # other available domains: 'panda', 'minigrid', 'point_maze'
    tasks = ['L555', 'L444', 'L333', 'L222', 'L111']
    percentages = ['0.3', '0.5', '1']

    fragmented_accuracies = build_empty_accuracies(ENVS_PER_DOMAIN, STORED_PERCENTAGES)
    continuing_accuracies = build_empty_accuracies(ENVS_PER_DOMAIN, STORED_PERCENTAGES)

    # is_learn_same_length_sequences is False for fragmented runs and True for continuing runs.
    for partial_obs_type, accuracies, is_same_learn in zip(
            ['fragmented', 'continuing'],
            [fragmented_accuracies, continuing_accuracies],
            [False, True]):
        load_accuracies(accuracies, domains, tasks, partial_obs_type, is_same_learn)

    # For per-domain smoothed line plots instead, call:
    #   plot_accuracies_per_domain(fragmented_accuracies, continuing_accuracies,
    #                              domains, percentages, num_tasks=len(tasks))
    plot_stick_figures(continuing_accuracies, fragmented_accuracies, "GC-GRAML compared with BG-GRAML",
                       fractions=percentages)


if __name__ == "__main__":
    main()