gr-libs 0.1.7.post0__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. evaluation/analyze_results_cross_alg_cross_domain.py +236 -246
  2. evaluation/create_minigrid_map_image.py +10 -6
  3. evaluation/file_system.py +16 -5
  4. evaluation/generate_experiments_results.py +123 -74
  5. evaluation/generate_experiments_results_new_ver1.py +227 -243
  6. evaluation/generate_experiments_results_new_ver2.py +317 -317
  7. evaluation/generate_task_specific_statistics_plots.py +481 -253
  8. evaluation/get_plans_images.py +41 -26
  9. evaluation/increasing_and_decreasing_.py +97 -56
  10. gr_libs/__init__.py +2 -1
  11. gr_libs/_version.py +2 -2
  12. gr_libs/environment/__init__.py +16 -8
  13. gr_libs/environment/environment.py +167 -39
  14. gr_libs/environment/utils/utils.py +22 -12
  15. gr_libs/metrics/__init__.py +5 -0
  16. gr_libs/metrics/metrics.py +76 -34
  17. gr_libs/ml/__init__.py +2 -0
  18. gr_libs/ml/agent.py +21 -6
  19. gr_libs/ml/base/__init__.py +1 -1
  20. gr_libs/ml/base/rl_agent.py +13 -10
  21. gr_libs/ml/consts.py +1 -1
  22. gr_libs/ml/neural/deep_rl_learner.py +433 -352
  23. gr_libs/ml/neural/utils/__init__.py +1 -1
  24. gr_libs/ml/neural/utils/dictlist.py +3 -3
  25. gr_libs/ml/neural/utils/penv.py +5 -2
  26. gr_libs/ml/planner/mcts/mcts_model.py +524 -302
  27. gr_libs/ml/planner/mcts/utils/__init__.py +1 -1
  28. gr_libs/ml/planner/mcts/utils/node.py +11 -7
  29. gr_libs/ml/planner/mcts/utils/tree.py +14 -10
  30. gr_libs/ml/sequential/__init__.py +1 -1
  31. gr_libs/ml/sequential/lstm_model.py +256 -175
  32. gr_libs/ml/tabular/state.py +7 -7
  33. gr_libs/ml/tabular/tabular_q_learner.py +123 -73
  34. gr_libs/ml/tabular/tabular_rl_agent.py +20 -19
  35. gr_libs/ml/utils/__init__.py +8 -2
  36. gr_libs/ml/utils/format.py +78 -70
  37. gr_libs/ml/utils/math.py +2 -1
  38. gr_libs/ml/utils/other.py +1 -1
  39. gr_libs/ml/utils/storage.py +88 -28
  40. gr_libs/problems/consts.py +1549 -1227
  41. gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +145 -80
  42. gr_libs/recognizer/graml/gr_dataset.py +209 -110
  43. gr_libs/recognizer/graml/graml_recognizer.py +431 -240
  44. gr_libs/recognizer/recognizer.py +38 -27
  45. gr_libs/recognizer/utils/__init__.py +1 -1
  46. gr_libs/recognizer/utils/format.py +8 -3
  47. {gr_libs-0.1.7.post0.dist-info → gr_libs-0.1.8.dist-info}/METADATA +1 -1
  48. gr_libs-0.1.8.dist-info/RECORD +70 -0
  49. {gr_libs-0.1.7.post0.dist-info → gr_libs-0.1.8.dist-info}/WHEEL +1 -1
  50. tests/test_gcdraco.py +10 -0
  51. tests/test_graml.py +8 -4
  52. tests/test_graql.py +2 -1
  53. tutorials/gcdraco_panda_tutorial.py +66 -0
  54. tutorials/gcdraco_parking_tutorial.py +61 -0
  55. tutorials/graml_minigrid_tutorial.py +42 -12
  56. tutorials/graml_panda_tutorial.py +35 -14
  57. tutorials/graml_parking_tutorial.py +37 -20
  58. tutorials/graml_point_maze_tutorial.py +33 -13
  59. tutorials/graql_minigrid_tutorial.py +31 -15
  60. gr_libs-0.1.7.post0.dist-info/RECORD +0 -67
  61. {gr_libs-0.1.7.post0.dist-info → gr_libs-0.1.8.dist-info}/top_level.txt +0 -0
@@ -5,250 +5,234 @@ import numpy as np
5
5
  import os
6
6
  import dill
7
7
 
8
- from gr_libs.ml.utils.storage import get_experiment_results_path, set_global_storage_configs
8
+ from gr_libs.ml.utils.storage import (
9
+ get_experiment_results_path,
10
+ set_global_storage_configs,
11
+ )
9
12
  from scripts.generate_task_specific_statistics_plots import get_figures_dir_path
10
13
 
11
14
  if __name__ == "__main__":
12
15
 
13
- fragmented_accuracies = {
14
- 'graml': {
15
- 'panda': {'gd_agent': {
16
- '0.3': [], # every list here should have number of tasks accuracies in it, since we done experiments for L111-L555. remember each accuracy is an average of #goals different tasks.
17
- '0.5': [],
18
- '0.7': [],
19
- '0.9': [],
20
- '1' : []
21
- },
22
- 'gc_agent': {
23
- '0.3': [],
24
- '0.5': [],
25
- '0.7': [],
26
- '0.9': [],
27
- '1' : []
28
- }},
29
- 'minigrid': {'obstacles': {
30
- '0.3': [],
31
- '0.5': [],
32
- '0.7': [],
33
- '0.9': [],
34
- '1' : []
35
- },
36
- 'lava_crossing': {
37
- '0.3': [],
38
- '0.5': [],
39
- '0.7': [],
40
- '0.9': [],
41
- '1' : []
42
- }},
43
- 'point_maze': {'obstacles': {
44
- '0.3': [],
45
- '0.5': [],
46
- '0.7': [],
47
- '0.9': [],
48
- '1' : []
49
- },
50
- 'four_rooms': {
51
- '0.3': [],
52
- '0.5': [],
53
- '0.7': [],
54
- '0.9': [],
55
- '1' : []
56
- }},
57
- 'parking': {'gd_agent': {
58
- '0.3': [],
59
- '0.5': [],
60
- '0.7': [],
61
- '0.9': [],
62
- '1' : []
63
- },
64
- 'gc_agent': {
65
- '0.3': [],
66
- '0.5': [],
67
- '0.7': [],
68
- '0.9': [],
69
- '1' : []
70
- }},
71
- },
72
- 'graql': {
73
- 'panda': {'gd_agent': {
74
- '0.3': [],
75
- '0.5': [],
76
- '0.7': [],
77
- '0.9': [],
78
- '1' : []
79
- },
80
- 'gc_agent': {
81
- '0.3': [],
82
- '0.5': [],
83
- '0.7': [],
84
- '0.9': [],
85
- '1' : []
86
- }},
87
- 'minigrid': {'obstacles': {
88
- '0.3': [],
89
- '0.5': [],
90
- '0.7': [],
91
- '0.9': [],
92
- '1' : []
93
- },
94
- 'lava_crossing': {
95
- '0.3': [],
96
- '0.5': [],
97
- '0.7': [],
98
- '0.9': [],
99
- '1' : []
100
- }},
101
- 'point_maze': {'obstacles': {
102
- '0.3': [],
103
- '0.5': [],
104
- '0.7': [],
105
- '0.9': [],
106
- '1' : []
107
- },
108
- 'four_rooms': {
109
- '0.3': [],
110
- '0.5': [],
111
- '0.7': [],
112
- '0.9': [],
113
- '1' : []
114
- }},
115
- 'parking': {'gd_agent': {
116
- '0.3': [],
117
- '0.5': [],
118
- '0.7': [],
119
- '0.9': [],
120
- '1' : []
121
- },
122
- 'gc_agent': {
123
- '0.3': [],
124
- '0.5': [],
125
- '0.7': [],
126
- '0.9': [],
127
- '1' : []
128
- }},
129
- }
130
- }
131
-
132
- continuing_accuracies = copy.deepcopy(fragmented_accuracies)
133
-
134
- #domains = ['panda', 'minigrid', 'point_maze', 'parking']
135
- domains = ['minigrid', 'point_maze', 'parking']
136
- tasks = ['L111', 'L222', 'L333', 'L444', 'L555']
137
- percentages = ['0.3', '0.5', '0.7', '0.9', '1']
138
-
139
- for partial_obs_type, accuracies, is_same_learn in zip(['fragmented', 'continuing'], [fragmented_accuracies, continuing_accuracies], [False, True]):
140
- for domain in domains:
141
- for env in accuracies['graml'][domain].keys():
142
- for task in tasks:
143
- set_global_storage_configs(recognizer_str='graml', is_fragmented=partial_obs_type,
144
- is_inference_same_length_sequences=True, is_learn_same_length_sequences=is_same_learn)
145
- graml_res_file_path = f'{get_experiment_results_path(domain, env, task)}.pkl'
146
- set_global_storage_configs(recognizer_str='graql', is_fragmented=partial_obs_type)
147
- graql_res_file_path = f'{get_experiment_results_path(domain, env, task)}.pkl'
148
- if os.path.exists(graml_res_file_path):
149
- with open(graml_res_file_path, 'rb') as results_file:
150
- results = dill.load(results_file)
151
- for percentage in accuracies['graml'][domain][env].keys():
152
- accuracies['graml'][domain][env][percentage].append(results[percentage]['accuracy'])
153
- else:
154
- assert(False, f"no file for {graml_res_file_path}")
155
- if os.path.exists(graql_res_file_path):
156
- with open(graql_res_file_path, 'rb') as results_file:
157
- results = dill.load(results_file)
158
- for percentage in accuracies['graml'][domain][env].keys():
159
- accuracies['graql'][domain][env][percentage].append(results[percentage]['accuracy'])
160
- else:
161
- assert(False, f"no file for {graql_res_file_path}")
162
-
163
- plot_styles = {
164
- ('graml', 'fragmented', 0.3): 'g--o', # Green dashed line with circle markers
165
- ('graml', 'fragmented', 0.5): 'g--s', # Green dashed line with square markers
166
- ('graml', 'fragmented', 0.7): 'g--^', # Green dashed line with triangle-up markers
167
- ('graml', 'fragmented', 0.9): 'g--d', # Green dashed line with diamond markers
168
- ('graml', 'fragmented', 1.0): 'g--*', # Green dashed line with star markers
169
-
170
- ('graml', 'continuing', 0.3): 'g-o', # Green solid line with circle markers
171
- ('graml', 'continuing', 0.5): 'g-s', # Green solid line with square markers
172
- ('graml', 'continuing', 0.7): 'g-^', # Green solid line with triangle-up markers
173
- ('graml', 'continuing', 0.9): 'g-d', # Green solid line with diamond markers
174
- ('graml', 'continuing', 1.0): 'g-*', # Green solid line with star markers
175
-
176
- ('graql', 'fragmented', 0.3): 'b--o', # Blue dashed line with circle markers
177
- ('graql', 'fragmented', 0.5): 'b--s', # Blue dashed line with square markers
178
- ('graql', 'fragmented', 0.7): 'b--^', # Blue dashed line with triangle-up markers
179
- ('graql', 'fragmented', 0.9): 'b--d', # Blue dashed line with diamond markers
180
- ('graql', 'fragmented', 1.0): 'b--*', # Blue dashed line with star markers
181
-
182
- ('graql', 'continuing', 0.3): 'b-o', # Blue solid line with circle markers
183
- ('graql', 'continuing', 0.5): 'b-s', # Blue solid line with square markers
184
- ('graql', 'continuing', 0.7): 'b-^', # Blue solid line with triangle-up markers
185
- ('graql', 'continuing', 0.9): 'b-d', # Blue solid line with diamond markers
186
- ('graql', 'continuing', 1.0): 'b-*', # Blue solid line with star markers
187
- }
188
-
189
- def average_accuracies(accuracies, domain):
190
- avg_acc = {algo: {perc: [] for perc in percentages}
191
- for algo in ['graml', 'graql']}
192
-
193
- for algo in avg_acc.keys():
194
- for perc in percentages:
195
- for env in accuracies[algo][domain].keys():
196
- env_acc = accuracies[algo][domain][env][perc] # list of 5, averages for L111 to L555.
197
- if env_acc:
198
- avg_acc[algo][perc].append(np.array(env_acc))
199
-
200
- for algo in avg_acc.keys():
201
- for perc in percentages:
202
- if avg_acc[algo][perc]:
203
- avg_acc[algo][perc] = np.mean(np.array(avg_acc[algo][perc]), axis=0)
204
-
205
- return avg_acc
206
-
207
- def plot_domain_accuracies(ax, fragmented_accuracies, continuing_accuracies, domain):
208
- fragmented_avg_acc = average_accuracies(fragmented_accuracies, domain)
209
- continuing_avg_acc = average_accuracies(continuing_accuracies, domain)
210
-
211
- x_vals = np.arange(1, 6) # Number of goals
212
-
213
- # Create "waves" (shaded regions) for each algorithm
214
- for algo in ['graml', 'graql']:
215
- for perc in percentages:
216
- fragmented_y_vals = np.array(fragmented_avg_acc[algo][perc])
217
- continuing_y_vals = np.array(continuing_avg_acc[algo][perc])
218
-
219
- ax.plot(
220
- x_vals, fragmented_y_vals,
221
- plot_styles[(algo, 'fragmented', float(perc))], # Use the updated plot_styles dictionary with percentage
222
- label=f"{algo}, non-consecutive, {perc}"
223
- )
224
- ax.plot(
225
- x_vals, continuing_y_vals,
226
- plot_styles[(algo, 'continuing', float(perc))], # Use the updated plot_styles dictionary with percentage
227
- label=f"{algo}, consecutive, {perc}"
228
- )
229
-
230
- ax.set_xticks(x_vals)
231
- ax.set_yticks(np.linspace(0, 1, 6))
232
- ax.set_ylim([0, 1])
233
- ax.set_title(f'{domain.capitalize()} Domain', fontsize=16)
234
- ax.grid(True)
235
-
236
- fig, axes = plt.subplots(1, 4, figsize=(24, 6)) # Increase the figure size for better spacing (width 24, height 6)
237
-
238
- # Generate each plot in a subplot, including both fragmented and continuing accuracies
239
- for i, domain in enumerate(domains):
240
- plot_domain_accuracies(axes[i], fragmented_accuracies, continuing_accuracies, domain)
241
-
242
- # Set a single x-axis and y-axis label for the entire figure
243
- fig.text(0.5, 0.04, 'Number of Goals', ha='center', fontsize=20) # Centered x-axis label
244
- fig.text(0.04, 0.5, 'Accuracy', va='center', rotation='vertical', fontsize=20) # Reduced spacing for y-axis label
245
-
246
- # Adjust subplot layout to avoid overlap
247
- plt.subplots_adjust(left=0.09, right=0.91, top=0.76, bottom=0.24, wspace=0.3) # More space on top (top=0.82)
248
-
249
- # Place the legend above the plots with more space between legend and plots
250
- handles, labels = axes[0].get_legend_handles_labels()
251
- fig.legend(handles, labels, loc='upper center', ncol=4, bbox_to_anchor=(0.5, 1.05), fontsize=12) # Moved above with bbox_to_anchor
252
-
253
- # Save the figure and show it
254
- plt.savefig('accuracy_plots.png', dpi=300)
16
+ fragmented_accuracies = {
17
+ "graml": {
18
+ "panda": {
19
+ "gd_agent": {
20
+ "0.3": [], # every list here should have number of tasks accuracies in it, since we done experiments for L111-L555. remember each accuracy is an average of #goals different tasks.
21
+ "0.5": [],
22
+ "0.7": [],
23
+ "0.9": [],
24
+ "1": [],
25
+ },
26
+ "gc_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
27
+ },
28
+ "minigrid": {
29
+ "obstacles": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
30
+ "lava_crossing": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
31
+ },
32
+ "point_maze": {
33
+ "obstacles": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
34
+ "four_rooms": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
35
+ },
36
+ "parking": {
37
+ "gd_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
38
+ "gc_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
39
+ },
40
+ },
41
+ "graql": {
42
+ "panda": {
43
+ "gd_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
44
+ "gc_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
45
+ },
46
+ "minigrid": {
47
+ "obstacles": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
48
+ "lava_crossing": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
49
+ },
50
+ "point_maze": {
51
+ "obstacles": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
52
+ "four_rooms": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
53
+ },
54
+ "parking": {
55
+ "gd_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
56
+ "gc_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
57
+ },
58
+ },
59
+ }
60
+
61
+ continuing_accuracies = copy.deepcopy(fragmented_accuracies)
62
+
63
+ # domains = ['panda', 'minigrid', 'point_maze', 'parking']
64
+ domains = ["minigrid", "point_maze", "parking"]
65
+ tasks = ["L111", "L222", "L333", "L444", "L555"]
66
+ percentages = ["0.3", "0.5", "0.7", "0.9", "1"]
67
+
68
+ for partial_obs_type, accuracies, is_same_learn in zip(
69
+ ["fragmented", "continuing"],
70
+ [fragmented_accuracies, continuing_accuracies],
71
+ [False, True],
72
+ ):
73
+ for domain in domains:
74
+ for env in accuracies["graml"][domain].keys():
75
+ for task in tasks:
76
+ set_global_storage_configs(
77
+ recognizer_str="graml",
78
+ is_fragmented=partial_obs_type,
79
+ is_inference_same_length_sequences=True,
80
+ is_learn_same_length_sequences=is_same_learn,
81
+ )
82
+ graml_res_file_path = (
83
+ f"{get_experiment_results_path(domain, env, task)}.pkl"
84
+ )
85
+ set_global_storage_configs(
86
+ recognizer_str="graql", is_fragmented=partial_obs_type
87
+ )
88
+ graql_res_file_path = (
89
+ f"{get_experiment_results_path(domain, env, task)}.pkl"
90
+ )
91
+ if os.path.exists(graml_res_file_path):
92
+ with open(graml_res_file_path, "rb") as results_file:
93
+ results = dill.load(results_file)
94
+ for percentage in accuracies["graml"][domain][env].keys():
95
+ accuracies["graml"][domain][env][percentage].append(
96
+ results[percentage]["accuracy"]
97
+ )
98
+ else:
99
+ assert (False, f"no file for {graml_res_file_path}")
100
+ if os.path.exists(graql_res_file_path):
101
+ with open(graql_res_file_path, "rb") as results_file:
102
+ results = dill.load(results_file)
103
+ for percentage in accuracies["graml"][domain][env].keys():
104
+ accuracies["graql"][domain][env][percentage].append(
105
+ results[percentage]["accuracy"]
106
+ )
107
+ else:
108
+ assert (False, f"no file for {graql_res_file_path}")
109
+
110
+ plot_styles = {
111
+ ("graml", "fragmented", 0.3): "g--o", # Green dashed line with circle markers
112
+ ("graml", "fragmented", 0.5): "g--s", # Green dashed line with square markers
113
+ (
114
+ "graml",
115
+ "fragmented",
116
+ 0.7,
117
+ ): "g--^", # Green dashed line with triangle-up markers
118
+ ("graml", "fragmented", 0.9): "g--d", # Green dashed line with diamond markers
119
+ ("graml", "fragmented", 1.0): "g--*", # Green dashed line with star markers
120
+ ("graml", "continuing", 0.3): "g-o", # Green solid line with circle markers
121
+ ("graml", "continuing", 0.5): "g-s", # Green solid line with square markers
122
+ (
123
+ "graml",
124
+ "continuing",
125
+ 0.7,
126
+ ): "g-^", # Green solid line with triangle-up markers
127
+ ("graml", "continuing", 0.9): "g-d", # Green solid line with diamond markers
128
+ ("graml", "continuing", 1.0): "g-*", # Green solid line with star markers
129
+ ("graql", "fragmented", 0.3): "b--o", # Blue dashed line with circle markers
130
+ ("graql", "fragmented", 0.5): "b--s", # Blue dashed line with square markers
131
+ (
132
+ "graql",
133
+ "fragmented",
134
+ 0.7,
135
+ ): "b--^", # Blue dashed line with triangle-up markers
136
+ ("graql", "fragmented", 0.9): "b--d", # Blue dashed line with diamond markers
137
+ ("graql", "fragmented", 1.0): "b--*", # Blue dashed line with star markers
138
+ ("graql", "continuing", 0.3): "b-o", # Blue solid line with circle markers
139
+ ("graql", "continuing", 0.5): "b-s", # Blue solid line with square markers
140
+ ("graql", "continuing", 0.7): "b-^", # Blue solid line with triangle-up markers
141
+ ("graql", "continuing", 0.9): "b-d", # Blue solid line with diamond markers
142
+ ("graql", "continuing", 1.0): "b-*", # Blue solid line with star markers
143
+ }
144
+
145
+ def average_accuracies(accuracies, domain):
146
+ avg_acc = {
147
+ algo: {perc: [] for perc in percentages} for algo in ["graml", "graql"]
148
+ }
149
+
150
+ for algo in avg_acc.keys():
151
+ for perc in percentages:
152
+ for env in accuracies[algo][domain].keys():
153
+ env_acc = accuracies[algo][domain][env][
154
+ perc
155
+ ] # list of 5, averages for L111 to L555.
156
+ if env_acc:
157
+ avg_acc[algo][perc].append(np.array(env_acc))
158
+
159
+ for algo in avg_acc.keys():
160
+ for perc in percentages:
161
+ if avg_acc[algo][perc]:
162
+ avg_acc[algo][perc] = np.mean(np.array(avg_acc[algo][perc]), axis=0)
163
+
164
+ return avg_acc
165
+
166
+ def plot_domain_accuracies(
167
+ ax, fragmented_accuracies, continuing_accuracies, domain
168
+ ):
169
+ fragmented_avg_acc = average_accuracies(fragmented_accuracies, domain)
170
+ continuing_avg_acc = average_accuracies(continuing_accuracies, domain)
171
+
172
+ x_vals = np.arange(1, 6) # Number of goals
173
+
174
+ # Create "waves" (shaded regions) for each algorithm
175
+ for algo in ["graml", "graql"]:
176
+ for perc in percentages:
177
+ fragmented_y_vals = np.array(fragmented_avg_acc[algo][perc])
178
+ continuing_y_vals = np.array(continuing_avg_acc[algo][perc])
179
+
180
+ ax.plot(
181
+ x_vals,
182
+ fragmented_y_vals,
183
+ plot_styles[
184
+ (algo, "fragmented", float(perc))
185
+ ], # Use the updated plot_styles dictionary with percentage
186
+ label=f"{algo}, non-consecutive, {perc}",
187
+ )
188
+ ax.plot(
189
+ x_vals,
190
+ continuing_y_vals,
191
+ plot_styles[
192
+ (algo, "continuing", float(perc))
193
+ ], # Use the updated plot_styles dictionary with percentage
194
+ label=f"{algo}, consecutive, {perc}",
195
+ )
196
+
197
+ ax.set_xticks(x_vals)
198
+ ax.set_yticks(np.linspace(0, 1, 6))
199
+ ax.set_ylim([0, 1])
200
+ ax.set_title(f"{domain.capitalize()} Domain", fontsize=16)
201
+ ax.grid(True)
202
+
203
+ fig, axes = plt.subplots(
204
+ 1, 4, figsize=(24, 6)
205
+ ) # Increase the figure size for better spacing (width 24, height 6)
206
+
207
+ # Generate each plot in a subplot, including both fragmented and continuing accuracies
208
+ for i, domain in enumerate(domains):
209
+ plot_domain_accuracies(
210
+ axes[i], fragmented_accuracies, continuing_accuracies, domain
211
+ )
212
+
213
+ # Set a single x-axis and y-axis label for the entire figure
214
+ fig.text(
215
+ 0.5, 0.04, "Number of Goals", ha="center", fontsize=20
216
+ ) # Centered x-axis label
217
+ fig.text(
218
+ 0.04, 0.5, "Accuracy", va="center", rotation="vertical", fontsize=20
219
+ ) # Reduced spacing for y-axis label
220
+
221
+ # Adjust subplot layout to avoid overlap
222
+ plt.subplots_adjust(
223
+ left=0.09, right=0.91, top=0.76, bottom=0.24, wspace=0.3
224
+ ) # More space on top (top=0.82)
225
+
226
+ # Place the legend above the plots with more space between legend and plots
227
+ handles, labels = axes[0].get_legend_handles_labels()
228
+ fig.legend(
229
+ handles,
230
+ labels,
231
+ loc="upper center",
232
+ ncol=4,
233
+ bbox_to_anchor=(0.5, 1.05),
234
+ fontsize=12,
235
+ ) # Moved above with bbox_to_anchor
236
+
237
+ # Save the figure and show it
238
+ plt.savefig("accuracy_plots.png", dpi=300)