gr-libs 0.1.6.post1__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (64)
  1. evaluation/analyze_results_cross_alg_cross_domain.py +236 -246
  2. evaluation/create_minigrid_map_image.py +10 -6
  3. evaluation/file_system.py +16 -5
  4. evaluation/generate_experiments_results.py +123 -74
  5. evaluation/generate_experiments_results_new_ver1.py +227 -243
  6. evaluation/generate_experiments_results_new_ver2.py +317 -317
  7. evaluation/generate_task_specific_statistics_plots.py +481 -253
  8. evaluation/get_plans_images.py +41 -26
  9. evaluation/increasing_and_decreasing_.py +97 -56
  10. gr_libs/__init__.py +6 -1
  11. gr_libs/_version.py +2 -2
  12. gr_libs/environment/__init__.py +17 -9
  13. gr_libs/environment/environment.py +167 -39
  14. gr_libs/environment/utils/utils.py +22 -12
  15. gr_libs/metrics/__init__.py +5 -0
  16. gr_libs/metrics/metrics.py +76 -34
  17. gr_libs/ml/__init__.py +2 -0
  18. gr_libs/ml/agent.py +21 -6
  19. gr_libs/ml/base/__init__.py +1 -1
  20. gr_libs/ml/base/rl_agent.py +13 -10
  21. gr_libs/ml/consts.py +1 -1
  22. gr_libs/ml/neural/deep_rl_learner.py +433 -352
  23. gr_libs/ml/neural/utils/__init__.py +1 -1
  24. gr_libs/ml/neural/utils/dictlist.py +3 -3
  25. gr_libs/ml/neural/utils/penv.py +5 -2
  26. gr_libs/ml/planner/mcts/mcts_model.py +524 -302
  27. gr_libs/ml/planner/mcts/utils/__init__.py +1 -1
  28. gr_libs/ml/planner/mcts/utils/node.py +11 -7
  29. gr_libs/ml/planner/mcts/utils/tree.py +14 -10
  30. gr_libs/ml/sequential/__init__.py +1 -1
  31. gr_libs/ml/sequential/lstm_model.py +256 -175
  32. gr_libs/ml/tabular/state.py +7 -7
  33. gr_libs/ml/tabular/tabular_q_learner.py +123 -73
  34. gr_libs/ml/tabular/tabular_rl_agent.py +20 -19
  35. gr_libs/ml/utils/__init__.py +8 -2
  36. gr_libs/ml/utils/format.py +78 -70
  37. gr_libs/ml/utils/math.py +2 -1
  38. gr_libs/ml/utils/other.py +1 -1
  39. gr_libs/ml/utils/storage.py +95 -28
  40. gr_libs/problems/consts.py +1549 -1227
  41. gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +145 -80
  42. gr_libs/recognizer/graml/gr_dataset.py +209 -110
  43. gr_libs/recognizer/graml/graml_recognizer.py +431 -231
  44. gr_libs/recognizer/recognizer.py +38 -27
  45. gr_libs/recognizer/utils/__init__.py +1 -1
  46. gr_libs/recognizer/utils/format.py +8 -3
  47. {gr_libs-0.1.6.post1.dist-info → gr_libs-0.1.8.dist-info}/METADATA +1 -1
  48. gr_libs-0.1.8.dist-info/RECORD +70 -0
  49. {gr_libs-0.1.6.post1.dist-info → gr_libs-0.1.8.dist-info}/WHEEL +1 -1
  50. {gr_libs-0.1.6.post1.dist-info → gr_libs-0.1.8.dist-info}/top_level.txt +0 -1
  51. tests/test_gcdraco.py +10 -0
  52. tests/test_graml.py +8 -4
  53. tests/test_graql.py +2 -1
  54. tutorials/gcdraco_panda_tutorial.py +66 -0
  55. tutorials/gcdraco_parking_tutorial.py +61 -0
  56. tutorials/graml_minigrid_tutorial.py +42 -12
  57. tutorials/graml_panda_tutorial.py +35 -14
  58. tutorials/graml_parking_tutorial.py +37 -19
  59. tutorials/graml_point_maze_tutorial.py +33 -13
  60. tutorials/graql_minigrid_tutorial.py +31 -15
  61. CI/README.md +0 -12
  62. CI/docker_build_context/Dockerfile +0 -15
  63. gr_libs/recognizer/recognizer_doc.md +0 -61
  64. gr_libs-0.1.6.post1.dist-info/RECORD +0 -70
@@ -6,326 +6,326 @@ import os
  import dill
  from scipy.interpolate import make_interp_spline
  from scipy.ndimage import gaussian_filter1d
- from gr_libs.ml.utils.storage import get_experiment_results_path, set_global_storage_configs
+ from gr_libs.ml.utils.storage import (
+ get_experiment_results_path,
+ set_global_storage_configs,
+ )
  from scripts.generate_task_specific_statistics_plots import get_figures_dir_path

+
  def smooth_line(x, y, num_points=300):
- x_smooth = np.linspace(np.min(x), np.max(x), num_points)
- spline = make_interp_spline(x, y, k=3) # Cubic spline
- y_smooth = spline(x_smooth)
- return x_smooth, y_smooth
+ x_smooth = np.linspace(np.min(x), np.max(x), num_points)
+ spline = make_interp_spline(x, y, k=3) # Cubic spline
+ y_smooth = spline(x_smooth)
+ return x_smooth, y_smooth
+

  if __name__ == "__main__":

- fragmented_accuracies = {
- 'graml': {
- 'panda': {'gc_agent': {
- '0.3': [],
- '0.5': [],
- '0.7': [],
- '0.9': [],
- '1' : []
- }},
- 'minigrid': {'obstacles': {
- '0.3': [],
- '0.5': [],
- '0.7': [],
- '0.9': [],
- '1' : []
- },
- 'lava_crossing': {
- '0.3': [],
- '0.5': [],
- '0.7': [],
- '0.9': [],
- '1' : []
- }},
- 'point_maze': {'obstacles': {
- '0.3': [],
- '0.5': [],
- '0.7': [],
- '0.9': [],
- '1' : []
- },
- 'four_rooms': {
- '0.3': [],
- '0.5': [],
- '0.7': [],
- '0.9': [],
- '1' : []
- }},
- 'parking': {'gc_agent': {
- '0.3': [],
- '0.5': [],
- '0.7': [],
- '0.9': [],
- '1' : []
- },
- 'gd_agent': {
- '0.3': [],
- '0.5': [],
- '0.7': [],
- '0.9': [],
- '1' : []
- },
- },
- },
- 'graql': {
- 'panda': {'gc_agent': {
- '0.3': [],
- '0.5': [],
- '0.7': [],
- '0.9': [],
- '1' : []
- }},
- 'minigrid': {'obstacles': {
- '0.3': [],
- '0.5': [],
- '0.7': [],
- '0.9': [],
- '1' : []
- },
- 'lava_crossing': {
- '0.3': [],
- '0.5': [],
- '0.7': [],
- '0.9': [],
- '1' : []
- }},
- 'point_maze': {'obstacles': {
- '0.3': [],
- '0.5': [],
- '0.7': [],
- '0.9': [],
- '1' : []
- },
- 'four_rooms': {
- '0.3': [],
- '0.5': [],
- '0.7': [],
- '0.9': [],
- '1' : []
- }},
- 'parking': {'gc_agent': {
- '0.3': [],
- '0.5': [],
- '0.7': [],
- '0.9': [],
- '1' : []
- },
- 'gd_agent': {
- '0.3': [],
- '0.5': [],
- '0.7': [],
- '0.9': [],
- '1' : []
- },
- },
- }
- }
-
- continuing_accuracies = copy.deepcopy(fragmented_accuracies)
-
- #domains = ['panda', 'minigrid', 'point_maze', 'parking']
- domains = ['parking']
- tasks = ['L555', 'L444', 'L333', 'L222', 'L111']
- percentages = ['0.3', '0.5', '1']
-
- for partial_obs_type, accuracies, is_same_learn in zip(['fragmented', 'continuing'], [fragmented_accuracies, continuing_accuracies], [False, True]):
- for domain in domains:
- for env in accuracies['graml'][domain].keys():
- for task in tasks:
- set_global_storage_configs(recognizer_str='graml', is_fragmented=partial_obs_type,
- is_inference_same_length_sequences=True, is_learn_same_length_sequences=is_same_learn)
- graml_res_file_path = f'{get_experiment_results_path(domain, env, task)}.pkl'
- set_global_storage_configs(recognizer_str='graql', is_fragmented=partial_obs_type)
- graql_res_file_path = f'{get_experiment_results_path(domain, env, task)}.pkl'
- if os.path.exists(graml_res_file_path):
- with open(graml_res_file_path, 'rb') as results_file:
- results = dill.load(results_file)
- for percentage in accuracies['graml'][domain][env].keys():
- accuracies['graml'][domain][env][percentage].append(results[percentage]['accuracy'])
- else:
- assert(False, f"no file for {graml_res_file_path}")
- if os.path.exists(graql_res_file_path):
- with open(graql_res_file_path, 'rb') as results_file:
- results = dill.load(results_file)
- for percentage in accuracies['graml'][domain][env].keys():
- accuracies['graql'][domain][env][percentage].append(results[percentage]['accuracy'])
- else:
- assert(False, f"no file for {graql_res_file_path}")
-
- plot_styles = {
- ('graml', 'fragmented', 0.3): 'g--o', # Green dashed line with circle markers
- ('graml', 'fragmented', 0.5): 'g--s', # Green dashed line with square markers
- ('graml', 'fragmented', 0.7): 'g--^', # Green dashed line with triangle-up markers
- ('graml', 'fragmented', 0.9): 'g--d', # Green dashed line with diamond markers
- ('graml', 'fragmented', 1.0): 'g--*', # Green dashed line with star markers
-
- ('graml', 'continuing', 0.3): 'g-o', # Green solid line with circle markers
- ('graml', 'continuing', 0.5): 'g-s', # Green solid line with square markers
- ('graml', 'continuing', 0.7): 'g-^', # Green solid line with triangle-up markers
- ('graml', 'continuing', 0.9): 'g-d', # Green solid line with diamond markers
- ('graml', 'continuing', 1.0): 'g-*', # Green solid line with star markers
-
- ('graql', 'fragmented', 0.3): 'b--o', # Blue dashed line with circle markers
- ('graql', 'fragmented', 0.5): 'b--s', # Blue dashed line with square markers
- ('graql', 'fragmented', 0.7): 'b--^', # Blue dashed line with triangle-up markers
- ('graql', 'fragmented', 0.9): 'b--d', # Blue dashed line with diamond markers
- ('graql', 'fragmented', 1.0): 'b--*', # Blue dashed line with star markers
-
- ('graql', 'continuing', 0.3): 'b-o', # Blue solid line with circle markers
- ('graql', 'continuing', 0.5): 'b-s', # Blue solid line with square markers
- ('graql', 'continuing', 0.7): 'b-^', # Blue solid line with triangle-up markers
- ('graql', 'continuing', 0.9): 'b-d', # Blue solid line with diamond markers
- ('graql', 'continuing', 1.0): 'b-*', # Blue solid line with star markers
- }
-
- def average_accuracies(accuracies, domain):
- avg_acc = {algo: {perc: [] for perc in percentages}
- for algo in ['graml', 'graql']}
-
- for algo in avg_acc.keys():
- for perc in percentages:
- for env in accuracies[algo][domain].keys():
- env_acc = accuracies[algo][domain][env][perc] # list of 5, averages for L111 to L555.
- if env_acc:
- avg_acc[algo][perc].append(np.array(env_acc))
-
- for algo in avg_acc.keys():
- for perc in percentages:
- if avg_acc[algo][perc]:
- avg_acc[algo][perc] = np.mean(np.array(avg_acc[algo][perc]), axis=0)
-
- return avg_acc
-
- def plot_domain_accuracies(ax, fragmented_accuracies, continuing_accuracies, domain, sigma=1, line_width=1.5):
- fragmented_avg_acc = average_accuracies(fragmented_accuracies, domain)
- continuing_avg_acc = average_accuracies(continuing_accuracies, domain)
-
- x_vals = np.arange(1, 6) # Number of goals
-
- # Create "waves" (shaded regions) for each algorithm
- for algo in ['graml', 'graql']:
- fragmented_y_vals_by_percentage = []
- continuing_y_vals_by_percentage = []
-
-
- for perc in percentages:
- fragmented_y_vals = np.array(fragmented_avg_acc[algo][perc])
- continuing_y_vals = np.array(continuing_avg_acc[algo][perc])
-
- # Smooth the trends using Gaussian filtering
- fragmented_y_smoothed = gaussian_filter1d(fragmented_y_vals, sigma=sigma)
- continuing_y_smoothed = gaussian_filter1d(continuing_y_vals, sigma=sigma)
-
- fragmented_y_vals_by_percentage.append(fragmented_y_smoothed)
- continuing_y_vals_by_percentage.append(continuing_y_smoothed)
-
- ax.plot(
- x_vals, fragmented_y_smoothed,
- plot_styles[(algo, 'fragmented', float(perc))],
- label=f"{algo}, non-consecutive, {perc}",
- linewidth=0.5 # Control line thickness here
- )
- ax.plot(
- x_vals, continuing_y_smoothed,
- plot_styles[(algo, 'continuing', float(perc))],
- label=f"{algo}, consecutive, {perc}",
- linewidth=0.5 # Control line thickness here
- )
-
- # Fill between trends of the same type that differ only by percentage
- # for i in range(len(percentages) - 1):
- # ax.fill_between(
- # x_vals, fragmented_y_vals_by_percentage[i], fragmented_y_vals_by_percentage[i+1],
- # color='green', alpha=0.1 # Adjust the fill color and transparency (for graml)
- # )
- # ax.fill_between(
- # x_vals, continuing_y_vals_by_percentage[i], continuing_y_vals_by_percentage[i+1],
- # color='blue', alpha=0.1 # Adjust the fill color and transparency (for graql)
- # )
-
- ax.set_xticks(x_vals)
- ax.set_yticks(np.linspace(0, 1, 6))
- ax.set_ylim([0, 1])
- ax.set_title(f'{domain.capitalize()} Domain', fontsize=16)
- ax.grid(True)
-
- # COMMENT FROM HERE AND UNTIL NEXT FUNCTION FOR BG GC COMPARISON
-
- # fig, axes = plt.subplots(1, 4, figsize=(24, 6)) # Increase the figure size for better spacing (width 24, height 6)
-
- # # Generate each plot in a subplot, including both fragmented and continuing accuracies
- # for i, domain in enumerate(domains):
- # plot_domain_accuracies(axes[i], fragmented_accuracies, continuing_accuracies, domain)
-
- # # Set a single x-axis and y-axis label for the entire figure
- # fig.text(0.5, 0.04, 'Number of Goals', ha='center', fontsize=20) # Centered x-axis label
- # fig.text(0.04, 0.5, 'Accuracy', va='center', rotation='vertical', fontsize=20) # Reduced spacing for y-axis label
-
- # # Adjust subplot layout to avoid overlap
- # plt.subplots_adjust(left=0.09, right=0.91, top=0.79, bottom=0.21, wspace=0.3) # More space on top (top=0.82)
-
- # # Place the legend above the plots with more space between legend and plots
- # handles, labels = axes[0].get_legend_handles_labels()
- # fig.legend(handles, labels, loc='upper center', ncol=4, bbox_to_anchor=(0.5, 1.05), fontsize=12) # Moved above with bbox_to_anchor
-
- # # Save the figure and show it
- # plt.savefig('accuracy_plots_smooth.png', dpi=300)
-
- # a specific comparison between bg-graml and gc-graml, aka gd_agent and gc_agent "envs":
- def plot_stick_figures(continuing_accuracies, fragmented_accuracies, title):
- fractions = ['0.3', '0.5', '1']
-
- def get_agent_data(data_dict, domain='graml', agent='gd_agent'):
- return [np.mean(data_dict[domain]['parking'][agent][fraction]) for fraction in fractions]
-
- # Continuing accuracies for gd_agent and gc_agent
- cont_gd = get_agent_data(continuing_accuracies, domain='graml', agent='gd_agent')
- cont_gc = get_agent_data(continuing_accuracies, domain='graml', agent='gc_agent')
-
- # Fragmented accuracies for gd_agent and gc_agent
- frag_gd = get_agent_data(fragmented_accuracies, domain='graml', agent='gd_agent')
- frag_gc = get_agent_data(fragmented_accuracies, domain='graml', agent='gc_agent')
-
- # Debugging: Print values to check if they're non-zero
- print("Continuing GD:", cont_gd)
- print("Continuing GC:", cont_gc)
- print("Fragmented GD:", frag_gd)
- print("Fragmented GC:", frag_gc)
-
- # Setting up figure
- x = np.arange(len(fractions)) # label locations
- width = 0.35 # width of the bars
-
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6), sharey=True)
-
- # Plot for continuing accuracies
- ax1.bar(x - width / 2, cont_gd, width, label='BG-GRAML')
- ax1.bar(x + width / 2, cont_gc, width, label='GC-GRAML')
- ax1.set_title('Consecutive Sequences', fontsize=20)
- ax1.set_xticks(x)
- ax1.set_xticklabels(fractions, fontsize=16)
- ax1.set_yticks(np.arange(0, 1.1, 0.2))
- ax1.set_yticklabels(np.round(np.arange(0, 1.1, 0.2), 1), fontsize=16)
- ax1.legend(fontsize=20)
-
- # Plot for fragmented accuracies
- ax2.bar(x - width / 2, frag_gd, width, label='BG-GRAML')
- ax2.bar(x + width / 2, frag_gc, width, label='GC-GRAML')
- ax2.set_title('Non-Consecutive Sequences', fontsize=20)
- ax2.set_xticks(x)
- ax2.set_xticklabels(fractions, fontsize=16)
- ax2.set_yticks(np.arange(0, 1.1, 0.2))
- ax2.set_yticklabels(np.round(np.arange(0, 1.1, 0.2), 1), fontsize=16)
- ax2.set_ylim(0, 1) # Ensure the y-axis is properly set
- ax2.legend(fontsize=20)
- # Common axis labels
- fig.text(0.5, 0.02, 'Observation Portion', ha='center', va='center', fontsize=24)
- fig.text(0.06, 0.5, 'Accuracy', ha='center', va='center', rotation='vertical', fontsize=24)
-
- plt.subplots_adjust(top=0.85)
- plt.savefig('gd_vs_gc_parking.png', dpi=300)
-
- plot_stick_figures(continuing_accuracies, fragmented_accuracies, "GC-GRAML compared with BG-GRAML")
+ fragmented_accuracies = {
+ "graml": {
+ "panda": {
+ "gc_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []}
+ },
+ "minigrid": {
+ "obstacles": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
+ "lava_crossing": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
+ },
+ "point_maze": {
+ "obstacles": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
+ "four_rooms": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
+ },
+ "parking": {
+ "gc_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
+ "gd_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
+ },
+ },
+ "graql": {
+ "panda": {
+ "gc_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []}
+ },
+ "minigrid": {
+ "obstacles": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
+ "lava_crossing": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
+ },
+ "point_maze": {
+ "obstacles": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
+ "four_rooms": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
+ },
+ "parking": {
+ "gc_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
+ "gd_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
+ },
+ },
+ }
+
+ continuing_accuracies = copy.deepcopy(fragmented_accuracies)
+
+ # domains = ['panda', 'minigrid', 'point_maze', 'parking']
+ domains = ["parking"]
+ tasks = ["L555", "L444", "L333", "L222", "L111"]
+ percentages = ["0.3", "0.5", "1"]
+
+ for partial_obs_type, accuracies, is_same_learn in zip(
+ ["fragmented", "continuing"],
+ [fragmented_accuracies, continuing_accuracies],
+ [False, True],
+ ):
+ for domain in domains:
+ for env in accuracies["graml"][domain].keys():
+ for task in tasks:
+ set_global_storage_configs(
+ recognizer_str="graml",
+ is_fragmented=partial_obs_type,
+ is_inference_same_length_sequences=True,
+ is_learn_same_length_sequences=is_same_learn,
+ )
+ graml_res_file_path = (
+ f"{get_experiment_results_path(domain, env, task)}.pkl"
+ )
+ set_global_storage_configs(
+ recognizer_str="graql", is_fragmented=partial_obs_type
+ )
+ graql_res_file_path = (
+ f"{get_experiment_results_path(domain, env, task)}.pkl"
+ )
+ if os.path.exists(graml_res_file_path):
+ with open(graml_res_file_path, "rb") as results_file:
+ results = dill.load(results_file)
+ for percentage in accuracies["graml"][domain][env].keys():
+ accuracies["graml"][domain][env][percentage].append(
+ results[percentage]["accuracy"]
+ )
+ else:
+ assert (False, f"no file for {graml_res_file_path}")
+ if os.path.exists(graql_res_file_path):
+ with open(graql_res_file_path, "rb") as results_file:
+ results = dill.load(results_file)
+ for percentage in accuracies["graml"][domain][env].keys():
+ accuracies["graql"][domain][env][percentage].append(
+ results[percentage]["accuracy"]
+ )
+ else:
+ assert (False, f"no file for {graql_res_file_path}")
+
+ plot_styles = {
+ ("graml", "fragmented", 0.3): "g--o", # Green dashed line with circle markers
+ ("graml", "fragmented", 0.5): "g--s", # Green dashed line with square markers
+ (
+ "graml",
+ "fragmented",
+ 0.7,
+ ): "g--^", # Green dashed line with triangle-up markers
+ ("graml", "fragmented", 0.9): "g--d", # Green dashed line with diamond markers
+ ("graml", "fragmented", 1.0): "g--*", # Green dashed line with star markers
+ ("graml", "continuing", 0.3): "g-o", # Green solid line with circle markers
+ ("graml", "continuing", 0.5): "g-s", # Green solid line with square markers
+ (
+ "graml",
+ "continuing",
+ 0.7,
+ ): "g-^", # Green solid line with triangle-up markers
+ ("graml", "continuing", 0.9): "g-d", # Green solid line with diamond markers
+ ("graml", "continuing", 1.0): "g-*", # Green solid line with star markers
+ ("graql", "fragmented", 0.3): "b--o", # Blue dashed line with circle markers
+ ("graql", "fragmented", 0.5): "b--s", # Blue dashed line with square markers
+ (
+ "graql",
+ "fragmented",
+ 0.7,
+ ): "b--^", # Blue dashed line with triangle-up markers
+ ("graql", "fragmented", 0.9): "b--d", # Blue dashed line with diamond markers
+ ("graql", "fragmented", 1.0): "b--*", # Blue dashed line with star markers
+ ("graql", "continuing", 0.3): "b-o", # Blue solid line with circle markers
+ ("graql", "continuing", 0.5): "b-s", # Blue solid line with square markers
+ ("graql", "continuing", 0.7): "b-^", # Blue solid line with triangle-up markers
+ ("graql", "continuing", 0.9): "b-d", # Blue solid line with diamond markers
+ ("graql", "continuing", 1.0): "b-*", # Blue solid line with star markers
+ }
+
+ def average_accuracies(accuracies, domain):
+ avg_acc = {
+ algo: {perc: [] for perc in percentages} for algo in ["graml", "graql"]
+ }
+
+ for algo in avg_acc.keys():
+ for perc in percentages:
+ for env in accuracies[algo][domain].keys():
+ env_acc = accuracies[algo][domain][env][
+ perc
+ ] # list of 5, averages for L111 to L555.
+ if env_acc:
+ avg_acc[algo][perc].append(np.array(env_acc))
+
+ for algo in avg_acc.keys():
+ for perc in percentages:
+ if avg_acc[algo][perc]:
+ avg_acc[algo][perc] = np.mean(np.array(avg_acc[algo][perc]), axis=0)
+
+ return avg_acc
+
+ def plot_domain_accuracies(
+ ax,
+ fragmented_accuracies,
+ continuing_accuracies,
+ domain,
+ sigma=1,
+ line_width=1.5,
+ ):
+ fragmented_avg_acc = average_accuracies(fragmented_accuracies, domain)
+ continuing_avg_acc = average_accuracies(continuing_accuracies, domain)
+
+ x_vals = np.arange(1, 6) # Number of goals
+
+ # Create "waves" (shaded regions) for each algorithm
+ for algo in ["graml", "graql"]:
+ fragmented_y_vals_by_percentage = []
+ continuing_y_vals_by_percentage = []
+
+ for perc in percentages:
+ fragmented_y_vals = np.array(fragmented_avg_acc[algo][perc])
+ continuing_y_vals = np.array(continuing_avg_acc[algo][perc])
+
+ # Smooth the trends using Gaussian filtering
+ fragmented_y_smoothed = gaussian_filter1d(
+ fragmented_y_vals, sigma=sigma
+ )
+ continuing_y_smoothed = gaussian_filter1d(
+ continuing_y_vals, sigma=sigma
+ )
+
+ fragmented_y_vals_by_percentage.append(fragmented_y_smoothed)
+ continuing_y_vals_by_percentage.append(continuing_y_smoothed)
+
+ ax.plot(
+ x_vals,
+ fragmented_y_smoothed,
+ plot_styles[(algo, "fragmented", float(perc))],
+ label=f"{algo}, non-consecutive, {perc}",
+ linewidth=0.5, # Control line thickness here
+ )
+ ax.plot(
+ x_vals,
+ continuing_y_smoothed,
+ plot_styles[(algo, "continuing", float(perc))],
+ label=f"{algo}, consecutive, {perc}",
+ linewidth=0.5, # Control line thickness here
+ )
+
+ # Fill between trends of the same type that differ only by percentage
+ # for i in range(len(percentages) - 1):
+ # ax.fill_between(
+ # x_vals, fragmented_y_vals_by_percentage[i], fragmented_y_vals_by_percentage[i+1],
+ # color='green', alpha=0.1 # Adjust the fill color and transparency (for graml)
+ # )
+ # ax.fill_between(
+ # x_vals, continuing_y_vals_by_percentage[i], continuing_y_vals_by_percentage[i+1],
+ # color='blue', alpha=0.1 # Adjust the fill color and transparency (for graql)
+ # )
+
+ ax.set_xticks(x_vals)
+ ax.set_yticks(np.linspace(0, 1, 6))
+ ax.set_ylim([0, 1])
+ ax.set_title(f"{domain.capitalize()} Domain", fontsize=16)
+ ax.grid(True)
+
+ # COMMENT FROM HERE AND UNTIL NEXT FUNCTION FOR BG GC COMPARISON
+
+ # fig, axes = plt.subplots(1, 4, figsize=(24, 6)) # Increase the figure size for better spacing (width 24, height 6)
+
+ # # Generate each plot in a subplot, including both fragmented and continuing accuracies
+ # for i, domain in enumerate(domains):
+ # plot_domain_accuracies(axes[i], fragmented_accuracies, continuing_accuracies, domain)
+
+ # # Set a single x-axis and y-axis label for the entire figure
+ # fig.text(0.5, 0.04, 'Number of Goals', ha='center', fontsize=20) # Centered x-axis label
+ # fig.text(0.04, 0.5, 'Accuracy', va='center', rotation='vertical', fontsize=20) # Reduced spacing for y-axis label
+
+ # # Adjust subplot layout to avoid overlap
+ # plt.subplots_adjust(left=0.09, right=0.91, top=0.79, bottom=0.21, wspace=0.3) # More space on top (top=0.82)
+
+ # # Place the legend above the plots with more space between legend and plots
+ # handles, labels = axes[0].get_legend_handles_labels()
+ # fig.legend(handles, labels, loc='upper center', ncol=4, bbox_to_anchor=(0.5, 1.05), fontsize=12) # Moved above with bbox_to_anchor
+
+ # # Save the figure and show it
+ # plt.savefig('accuracy_plots_smooth.png', dpi=300)
+
+ # a specific comparison between bg-graml and gc-graml, aka gd_agent and gc_agent "envs":
+ def plot_stick_figures(continuing_accuracies, fragmented_accuracies, title):
+ fractions = ["0.3", "0.5", "1"]
+
+ def get_agent_data(data_dict, domain="graml", agent="gd_agent"):
+ return [
+ np.mean(data_dict[domain]["parking"][agent][fraction])
+ for fraction in fractions
+ ]
+
+ # Continuing accuracies for gd_agent and gc_agent
+ cont_gd = get_agent_data(
+ continuing_accuracies, domain="graml", agent="gd_agent"
+ )
+ cont_gc = get_agent_data(
+ continuing_accuracies, domain="graml", agent="gc_agent"
+ )
+
+ # Fragmented accuracies for gd_agent and gc_agent
+ frag_gd = get_agent_data(
+ fragmented_accuracies, domain="graml", agent="gd_agent"
+ )
+ frag_gc = get_agent_data(
+ fragmented_accuracies, domain="graml", agent="gc_agent"
+ )
+
+ # Debugging: Print values to check if they're non-zero
+ print("Continuing GD:", cont_gd)
+ print("Continuing GC:", cont_gc)
+ print("Fragmented GD:", frag_gd)
+ print("Fragmented GC:", frag_gc)
+
+ # Setting up figure
+ x = np.arange(len(fractions)) # label locations
+ width = 0.35 # width of the bars
+
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6), sharey=True)
+
+ # Plot for continuing accuracies
+ ax1.bar(x - width / 2, cont_gd, width, label="BG-GRAML")
+ ax1.bar(x + width / 2, cont_gc, width, label="GC-GRAML")
+ ax1.set_title("Consecutive Sequences", fontsize=20)
+ ax1.set_xticks(x)
+ ax1.set_xticklabels(fractions, fontsize=16)
+ ax1.set_yticks(np.arange(0, 1.1, 0.2))
+ ax1.set_yticklabels(np.round(np.arange(0, 1.1, 0.2), 1), fontsize=16)
+ ax1.legend(fontsize=20)
+
+ # Plot for fragmented accuracies
+ ax2.bar(x - width / 2, frag_gd, width, label="BG-GRAML")
+ ax2.bar(x + width / 2, frag_gc, width, label="GC-GRAML")
+ ax2.set_title("Non-Consecutive Sequences", fontsize=20)
+ ax2.set_xticks(x)
+ ax2.set_xticklabels(fractions, fontsize=16)
+ ax2.set_yticks(np.arange(0, 1.1, 0.2))
+ ax2.set_yticklabels(np.round(np.arange(0, 1.1, 0.2), 1), fontsize=16)
+ ax2.set_ylim(0, 1) # Ensure the y-axis is properly set
+ ax2.legend(fontsize=20)
+ # Common axis labels
+ fig.text(
+ 0.5, 0.02, "Observation Portion", ha="center", va="center", fontsize=24
+ )
+ fig.text(
+ 0.06,
+ 0.5,
+ "Accuracy",
+ ha="center",
+ va="center",
+ rotation="vertical",
+ fontsize=24,
+ )
+
+ plt.subplots_adjust(top=0.85)
+ plt.savefig("gd_vs_gc_parking.png", dpi=300)
+
+ plot_stick_figures(
+ continuing_accuracies, fragmented_accuracies, "GC-GRAML compared with BG-GRAML"
+ )
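
For readers tracing the plotting logic in this file, here is a minimal, self-contained sketch of the smoothing the script applies: cubic-spline interpolation as in smooth_line() and Gaussian filtering as in plot_domain_accuracies(). The sample accuracy values and the output filename below are hypothetical illustrations, not taken from the package.

import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import make_interp_spline
from scipy.ndimage import gaussian_filter1d

x = np.arange(1, 6)                                # number of goals, as in x_vals above
y = np.array([0.62, 0.70, 0.66, 0.74, 0.81])       # hypothetical per-goal accuracies

x_dense = np.linspace(x.min(), x.max(), 300)       # dense grid for the spline
y_spline = make_interp_spline(x, y, k=3)(x_dense)  # cubic spline, k=3, as in smooth_line()
y_gauss = gaussian_filter1d(y, sigma=1)            # Gaussian-smoothed trend, as in plot_domain_accuracies()

fig, ax = plt.subplots()
ax.plot(x_dense, y_spline, "g--", linewidth=0.5, label="spline-smoothed")
ax.plot(x, y_gauss, "b-o", linewidth=0.5, label="gaussian-filtered")
ax.set_xticks(x)
ax.set_ylim([0, 1])
ax.legend()
fig.savefig("smoothing_sketch.png", dpi=300)       # hypothetical output file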