gr-libs 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evaluation/analyze_results_cross_alg_cross_domain.py +277 -0
- evaluation/create_minigrid_map_image.py +34 -0
- evaluation/file_system.py +42 -0
- evaluation/generate_experiments_results.py +92 -0
- evaluation/generate_experiments_results_new_ver1.py +254 -0
- evaluation/generate_experiments_results_new_ver2.py +331 -0
- evaluation/generate_task_specific_statistics_plots.py +272 -0
- evaluation/get_plans_images.py +47 -0
- evaluation/increasing_and_decreasing_.py +63 -0
- gr_libs/__init__.py +2 -0
- gr_libs/environment/__init__.py +0 -0
- gr_libs/environment/environment.py +227 -0
- gr_libs/environment/utils/__init__.py +0 -0
- gr_libs/environment/utils/utils.py +17 -0
- gr_libs/metrics/__init__.py +0 -0
- gr_libs/metrics/metrics.py +224 -0
- gr_libs/ml/__init__.py +6 -0
- gr_libs/ml/agent.py +56 -0
- gr_libs/ml/base/__init__.py +1 -0
- gr_libs/ml/base/rl_agent.py +54 -0
- gr_libs/ml/consts.py +22 -0
- gr_libs/ml/neural/__init__.py +3 -0
- gr_libs/ml/neural/deep_rl_learner.py +395 -0
- gr_libs/ml/neural/utils/__init__.py +2 -0
- gr_libs/ml/neural/utils/dictlist.py +33 -0
- gr_libs/ml/neural/utils/penv.py +57 -0
- gr_libs/ml/planner/__init__.py +0 -0
- gr_libs/ml/planner/mcts/__init__.py +0 -0
- gr_libs/ml/planner/mcts/mcts_model.py +330 -0
- gr_libs/ml/planner/mcts/utils/__init__.py +2 -0
- gr_libs/ml/planner/mcts/utils/node.py +33 -0
- gr_libs/ml/planner/mcts/utils/tree.py +102 -0
- gr_libs/ml/sequential/__init__.py +1 -0
- gr_libs/ml/sequential/lstm_model.py +192 -0
- gr_libs/ml/tabular/__init__.py +3 -0
- gr_libs/ml/tabular/state.py +21 -0
- gr_libs/ml/tabular/tabular_q_learner.py +453 -0
- gr_libs/ml/tabular/tabular_rl_agent.py +126 -0
- gr_libs/ml/utils/__init__.py +6 -0
- gr_libs/ml/utils/env.py +7 -0
- gr_libs/ml/utils/format.py +100 -0
- gr_libs/ml/utils/math.py +13 -0
- gr_libs/ml/utils/other.py +24 -0
- gr_libs/ml/utils/storage.py +127 -0
- gr_libs/recognizer/__init__.py +0 -0
- gr_libs/recognizer/gr_as_rl/__init__.py +0 -0
- gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +102 -0
- gr_libs/recognizer/graml/__init__.py +0 -0
- gr_libs/recognizer/graml/gr_dataset.py +134 -0
- gr_libs/recognizer/graml/graml_recognizer.py +266 -0
- gr_libs/recognizer/recognizer.py +46 -0
- gr_libs/recognizer/utils/__init__.py +1 -0
- gr_libs/recognizer/utils/format.py +13 -0
- gr_libs-0.1.3.dist-info/METADATA +197 -0
- gr_libs-0.1.3.dist-info/RECORD +62 -0
- gr_libs-0.1.3.dist-info/WHEEL +5 -0
- gr_libs-0.1.3.dist-info/top_level.txt +3 -0
- tutorials/graml_minigrid_tutorial.py +30 -0
- tutorials/graml_panda_tutorial.py +32 -0
- tutorials/graml_parking_tutorial.py +38 -0
- tutorials/graml_point_maze_tutorial.py +43 -0
- tutorials/graql_minigrid_tutorial.py +29 -0
@@ -0,0 +1,254 @@
|
|
1
|
+
import copy
|
2
|
+
import sys
|
3
|
+
import matplotlib.pyplot as plt
|
4
|
+
import numpy as np
|
5
|
+
import os
|
6
|
+
import dill
|
7
|
+
|
8
|
+
from gr_libs.ml.utils.storage import get_experiment_results_path, set_global_storage_configs
|
9
|
+
from scripts.generate_task_specific_statistics_plots import get_figures_dir_path
|
10
|
+
|
11
|
+
# Recognition algorithms compared in the plots (top-level keys of every accuracy table).
ALGORITHMS = ('graml', 'graql')
# Observation modes as passed to set_global_storage_configs; the legend calls
# 'fragmented' "non-consecutive" and 'continuing' "consecutive".
OBS_TYPES = ('fragmented', 'continuing')
OBS_LABELS = {'fragmented': 'non-consecutive', 'continuing': 'consecutive'}
# Environments evaluated per domain, in table order.
ENVS_BY_DOMAIN = {
    'panda': ('gd_agent', 'gc_agent'),
    'minigrid': ('obstacles', 'lava_crossing'),
    'point_maze': ('obstacles', 'four_rooms'),
    'parking': ('gd_agent', 'gc_agent'),
}
# Observation percentages: keys of the results pickles and of the accuracy tables.
PERCENTAGES = ('0.3', '0.5', '0.7', '0.9', '1')

# matplotlib format strings keyed by (algorithm, observation type, percentage as float):
# colour encodes the algorithm (green graml / blue graql), line style the observation type
# (dashed fragmented / solid continuing) and the marker the observation percentage.
_ALGO_COLORS = {'graml': 'g', 'graql': 'b'}
_OBS_LINESTYLES = {'fragmented': '--', 'continuing': '-'}
_PERCENTAGE_MARKERS = {0.3: 'o', 0.5: 's', 0.7: '^', 0.9: 'd', 1.0: '*'}
plot_styles = {
    (algo, obs_type, perc): color + linestyle + marker
    for algo, color in _ALGO_COLORS.items()
    for obs_type, linestyle in _OBS_LINESTYLES.items()
    for perc, marker in _PERCENTAGE_MARKERS.items()
}


def build_accuracy_table(algorithms, envs_by_domain, percentages):
    """Return an empty table ``table[algo][domain][env][percentage] -> list``.

    Every leaf is a fresh list, so appending to one leaf never affects another.
    """
    return {
        algo: {
            domain: {env: {perc: [] for perc in percentages} for env in envs}
            for domain, envs in envs_by_domain.items()
        }
        for algo in algorithms
    }


def average_accuracies(accuracies, domain, percentages=PERCENTAGES):
    """Average per-task accuracy curves over all environments of ``domain``.

    Args:
        accuracies: Table shaped like ``build_accuracy_table`` output, whose leaves hold
            one accuracy per task.
        domain: Domain key to average over.
        percentages: Observation percentages to average.

    Returns:
        ``{algo: {percentage: curve}}`` for every algorithm in ``ALGORITHMS``. ``curve`` is
        the element-wise mean over the environments that have data (a NumPy array), or
        ``[]`` when no environment has data.

    Raises:
        ValueError: if environments hold different numbers of per-task accuracies
            (e.g. some results files were missing), so an element-wise mean is undefined.
    """
    avg_acc = {}
    for algo in ALGORITHMS:
        avg_acc[algo] = {}
        for perc in percentages:
            curves = [env_acc[perc] for env_acc in accuracies[algo][domain].values() if env_acc[perc]]
            if not curves:
                avg_acc[algo][perc] = []
                continue
            lengths = [len(curve) for curve in curves]
            if len(set(lengths)) > 1:
                raise ValueError(f"ragged accuracies for {algo}/{domain}/{perc}: lengths {lengths}")
            avg_acc[algo][perc] = np.mean(np.array(curves), axis=0)
    return avg_acc


def plot_domain_accuracies(ax, fragmented_accuracies, continuing_accuracies, domain, percentages=PERCENTAGES):
    """Plot both algorithms' averaged accuracy curves for one domain on ``ax``.

    One line is drawn per (algorithm, percentage, observation type), at x positions
    1..N (one per task; the figure labels this axis "Number of Goals"). Series without
    data are skipped, because plotting an empty curve against the task axis would raise.
    """
    averages = {
        'fragmented': average_accuracies(fragmented_accuracies, domain, percentages),
        'continuing': average_accuracies(continuing_accuracies, domain, percentages),
    }
    longest = 0
    for algo in ALGORITHMS:
        for perc in percentages:
            for obs_type in OBS_TYPES:
                y_vals = np.asarray(averages[obs_type][algo][perc], dtype=float)
                if y_vals.size == 0:
                    continue
                longest = max(longest, y_vals.size)
                ax.plot(
                    np.arange(1, y_vals.size + 1), y_vals,
                    plot_styles[(algo, obs_type, float(perc))],
                    label=f"{algo}, {OBS_LABELS[obs_type]}, {perc}",
                )

    ax.set_xticks(np.arange(1, longest + 1))
    ax.set_yticks(np.linspace(0, 1, 6))
    ax.set_ylim([0, 1])
    ax.set_title(f'{domain.capitalize()} Domain', fontsize=16)
    ax.grid(True)


def _append_results(accuracies, algo, domain, env, res_file_path):
    """Append one task's per-percentage accuracies from a results pickle to ``accuracies``.

    A missing file is reported on stderr and skipped (presumably an algorithm/domain pair
    that was not run); ``average_accuracies`` ignores environments whose lists stay empty.
    """
    if not os.path.exists(res_file_path):
        print(f"no file for {res_file_path}", file=sys.stderr)
        return
    with open(res_file_path, 'rb') as results_file:
        # dill, like pickle, can execute code while loading: only read trusted result files.
        results = dill.load(results_file)
    for percentage, acc_list in accuracies[algo][domain][env].items():
        acc_list.append(results[percentage]['accuracy'])


def plot_domain_grid(fragmented_accuracies, continuing_accuracies, domains, out_path='accuracy_plots.png'):
    """Draw one accuracy panel per domain side by side and save the figure to ``out_path``."""
    # squeeze=False keeps a 2-D axes array even when there is a single domain.
    fig, axes = plt.subplots(1, len(domains), figsize=(24, 6), squeeze=False)
    axes = axes[0]
    for ax, domain in zip(axes, domains):
        plot_domain_accuracies(ax, fragmented_accuracies, continuing_accuracies, domain)

    # Shared axis labels for the whole figure.
    fig.text(0.5, 0.04, 'Number of Goals', ha='center', fontsize=20)
    fig.text(0.04, 0.5, 'Accuracy', va='center', rotation='vertical', fontsize=20)
    fig.subplots_adjust(left=0.09, right=0.91, top=0.76, bottom=0.24, wspace=0.3)

    # A single legend above all panels, built from the first panel's series.
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper center', ncol=4, bbox_to_anchor=(0.5, 1.05), fontsize=12)
    # The legend is anchored above the canvas (y=1.05); a tight bbox keeps it in the saved image.
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.close(fig)


if __name__ == "__main__":
    # Each leaf list collects one accuracy per task in `tasks` (L111-L555); every stored
    # accuracy is itself an average over several goals.
    fragmented_accuracies = build_accuracy_table(ALGORITHMS, ENVS_BY_DOMAIN, PERCENTAGES)
    continuing_accuracies = copy.deepcopy(fragmented_accuracies)

    # domains = ['panda', 'minigrid', 'point_maze', 'parking']
    domains = ['minigrid', 'point_maze', 'parking']
    tasks = ['L111', 'L222', 'L333', 'L444', 'L555']

    for partial_obs_type, accuracies, is_same_learn in zip(
            OBS_TYPES, [fragmented_accuracies, continuing_accuracies], [False, True]):
        for domain in domains:
            for env in accuracies['graml'][domain]:
                for task in tasks:
                    # Results paths depend on the global storage configuration, so set it per recognizer.
                    set_global_storage_configs(recognizer_str='graml', is_fragmented=partial_obs_type,
                                               is_inference_same_length_sequences=True,
                                               is_learn_same_length_sequences=is_same_learn)
                    graml_res_file_path = f'{get_experiment_results_path(domain, env, task)}.pkl'
                    set_global_storage_configs(recognizer_str='graql', is_fragmented=partial_obs_type)
                    graql_res_file_path = f'{get_experiment_results_path(domain, env, task)}.pkl'
                    _append_results(accuracies, 'graml', domain, env, graml_res_file_path)
                    _append_results(accuracies, 'graql', domain, env, graql_res_file_path)

    plot_domain_grid(fragmented_accuracies, continuing_accuracies, domains, out_path='accuracy_plots.png')
|
@@ -0,0 +1,331 @@
|
|
1
|
+
import copy
|
2
|
+
import sys
|
3
|
+
import matplotlib.pyplot as plt
|
4
|
+
import numpy as np
|
5
|
+
import os
|
6
|
+
import dill
|
7
|
+
from scipy.interpolate import make_interp_spline
|
8
|
+
from scipy.ndimage import gaussian_filter1d
|
9
|
+
from gr_libs.ml.utils.storage import get_experiment_results_path, set_global_storage_configs
|
10
|
+
from scripts.generate_task_specific_statistics_plots import get_figures_dir_path
|
11
|
+
|
12
|
+
def smooth_line(x, y, num_points=300, k=3):
    """Return a densely resampled spline interpolation of the points ``(x, y)``.

    Args:
        x: Strictly increasing sample positions (required by ``make_interp_spline``).
        y: Sample values, one per entry of ``x``.
        num_points: Number of evenly spaced points in the returned curve.
        k: Spline degree, cubic by default. It is lowered to ``len(x) - 1`` when there are
            too few points for the requested degree (``make_interp_spline`` needs at least
            ``k + 1`` points), so e.g. two points give a straight line instead of an error.

    Returns:
        ``(x_smooth, y_smooth)``: ``num_points`` positions spanning ``[min(x), max(x)]`` and
        the spline evaluated at them.
    """
    degree = min(k, len(x) - 1)
    x_smooth = np.linspace(np.min(x), np.max(x), num_points)
    spline = make_interp_spline(x, y, k=degree)
    y_smooth = spline(x_smooth)
    return x_smooth, y_smooth
|
17
|
+
|
18
|
+
# Recognition algorithms compared in the plots (top-level keys of every accuracy table).
ALGORITHMS = ('graml', 'graql')
# Observation modes as passed to set_global_storage_configs; the legend calls
# 'fragmented' "non-consecutive" and 'continuing' "consecutive".
OBS_TYPES = ('fragmented', 'continuing')
OBS_LABELS = {'fragmented': 'non-consecutive', 'continuing': 'consecutive'}
# Environments evaluated per domain, in table order.
ENVS_BY_DOMAIN = {
    'panda': ('gc_agent',),
    'minigrid': ('obstacles', 'lava_crossing'),
    'point_maze': ('obstacles', 'four_rooms'),
    'parking': ('gc_agent', 'gd_agent'),
}
# Every observation percentage read from the results pickles.
RESULT_PERCENTAGES = ('0.3', '0.5', '0.7', '0.9', '1')
# The subset of percentages that is averaged and plotted.
PERCENTAGES = ('0.3', '0.5', '1')

# matplotlib format strings keyed by (algorithm, observation type, percentage as float):
# colour encodes the algorithm (green graml / blue graql), line style the observation type
# (dashed fragmented / solid continuing) and the marker the observation percentage.
_ALGO_COLORS = {'graml': 'g', 'graql': 'b'}
_OBS_LINESTYLES = {'fragmented': '--', 'continuing': '-'}
_PERCENTAGE_MARKERS = {0.3: 'o', 0.5: 's', 0.7: '^', 0.9: 'd', 1.0: '*'}
plot_styles = {
    (algo, obs_type, perc): color + linestyle + marker
    for algo, color in _ALGO_COLORS.items()
    for obs_type, linestyle in _OBS_LINESTYLES.items()
    for perc, marker in _PERCENTAGE_MARKERS.items()
}


def build_accuracy_table(algorithms, envs_by_domain, percentages):
    """Return an empty table ``table[algo][domain][env][percentage] -> list``.

    Every leaf is a fresh list, so appending to one leaf never affects another.
    """
    return {
        algo: {
            domain: {env: {perc: [] for perc in percentages} for env in envs}
            for domain, envs in envs_by_domain.items()
        }
        for algo in algorithms
    }


def average_accuracies(accuracies, domain, percentages=PERCENTAGES):
    """Average per-task accuracy curves over all environments of ``domain``.

    Args:
        accuracies: Table shaped like ``build_accuracy_table`` output, whose leaves hold
            one accuracy per task.
        domain: Domain key to average over.
        percentages: Observation percentages to average.

    Returns:
        ``{algo: {percentage: curve}}`` for every algorithm in ``ALGORITHMS``. ``curve`` is
        the element-wise mean over the environments that have data (a NumPy array), or
        ``[]`` when no environment has data.

    Raises:
        ValueError: if environments hold different numbers of per-task accuracies
            (e.g. some results files were missing), so an element-wise mean is undefined.
    """
    avg_acc = {}
    for algo in ALGORITHMS:
        avg_acc[algo] = {}
        for perc in percentages:
            curves = [env_acc[perc] for env_acc in accuracies[algo][domain].values() if env_acc[perc]]
            if not curves:
                avg_acc[algo][perc] = []
                continue
            lengths = [len(curve) for curve in curves]
            if len(set(lengths)) > 1:
                raise ValueError(f"ragged accuracies for {algo}/{domain}/{perc}: lengths {lengths}")
            avg_acc[algo][perc] = np.mean(np.array(curves), axis=0)
    return avg_acc


def plot_domain_accuracies(ax, fragmented_accuracies, continuing_accuracies, domain, sigma=1, line_width=1.5,
                           percentages=PERCENTAGES):
    """Plot Gaussian-smoothed, averaged accuracy curves of both algorithms for one domain.

    One line is drawn per (algorithm, percentage, observation type), at x positions 1..N
    (one per task; the figure labels this axis "Number of Goals").

    Args:
        ax: matplotlib axes to draw on.
        fragmented_accuracies / continuing_accuracies: Accuracy tables per observation mode.
        domain: Domain key to plot.
        sigma: Standard deviation of the Gaussian smoothing applied to each curve.
        line_width: Width of every plotted line.
        percentages: Observation percentages to plot.

    Series without data are skipped, because plotting an empty curve against the task
    axis would raise.
    """
    averages = {
        'fragmented': average_accuracies(fragmented_accuracies, domain, percentages),
        'continuing': average_accuracies(continuing_accuracies, domain, percentages),
    }
    longest = 0
    for algo in ALGORITHMS:
        for perc in percentages:
            for obs_type in OBS_TYPES:
                y_vals = np.asarray(averages[obs_type][algo][perc], dtype=float)
                if y_vals.size == 0:
                    continue
                longest = max(longest, y_vals.size)
                ax.plot(
                    np.arange(1, y_vals.size + 1),
                    gaussian_filter1d(y_vals, sigma=sigma),
                    plot_styles[(algo, obs_type, float(perc))],
                    label=f"{algo}, {OBS_LABELS[obs_type]}, {perc}",
                    linewidth=line_width,
                )

    ax.set_xticks(np.arange(1, longest + 1))
    ax.set_yticks(np.linspace(0, 1, 6))
    ax.set_ylim([0, 1])
    ax.set_title(f'{domain.capitalize()} Domain', fontsize=16)
    ax.grid(True)


def _append_results(accuracies, algo, domain, env, res_file_path):
    """Append one task's per-percentage accuracies from a results pickle to ``accuracies``.

    A missing file is reported on stderr and skipped (presumably an algorithm/domain pair
    that was not run); ``average_accuracies`` ignores environments whose lists stay empty.
    """
    if not os.path.exists(res_file_path):
        print(f"no file for {res_file_path}", file=sys.stderr)
        return
    with open(res_file_path, 'rb') as results_file:
        # dill, like pickle, can execute code while loading: only read trusted result files.
        results = dill.load(results_file)
    for percentage, acc_list in accuracies[algo][domain][env].items():
        acc_list.append(results[percentage]['accuracy'])


def plot_domain_grid(fragmented_accuracies, continuing_accuracies, domains, out_path='accuracy_plots_smooth.png'):
    """Draw one smoothed accuracy panel per domain side by side and save it to ``out_path``."""
    # squeeze=False keeps a 2-D axes array even when there is a single domain.
    fig, axes = plt.subplots(1, len(domains), figsize=(24, 6), squeeze=False)
    axes = axes[0]
    for ax, domain in zip(axes, domains):
        plot_domain_accuracies(ax, fragmented_accuracies, continuing_accuracies, domain)

    # Shared axis labels for the whole figure.
    fig.text(0.5, 0.04, 'Number of Goals', ha='center', fontsize=20)
    fig.text(0.04, 0.5, 'Accuracy', va='center', rotation='vertical', fontsize=20)
    fig.subplots_adjust(left=0.09, right=0.91, top=0.79, bottom=0.21, wspace=0.3)

    # A single legend above all panels, built from the first panel's series.
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper center', ncol=4, bbox_to_anchor=(0.5, 1.05), fontsize=12)
    # The legend is anchored above the canvas (y=1.05); a tight bbox keeps it in the saved image.
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.close(fig)


def plot_stick_figures(continuing_accuracies, fragmented_accuracies, title, domain='parking', algorithm='graml',
                       fractions=('0.3', '0.5', '1'), out_path='gd_vs_gc_parking.png'):
    """Bar-chart BG-GRAML (the ``gd_agent`` env) against GC-GRAML (the ``gc_agent`` env).

    The left panel uses consecutive ("continuing") sequences and the right panel uses
    non-consecutive ("fragmented") ones. Each bar is the mean accuracy over tasks for one
    observation fraction. ``title`` (if non-empty) becomes the figure title, and the figure
    is saved to ``out_path``.
    """
    def mean_per_fraction(table, agent):
        # An empty list gives NaN, as np.mean([]) would, but without its RuntimeWarning.
        values = []
        for fraction in fractions:
            accs = table[algorithm][domain][agent][fraction]
            values.append(np.mean(accs) if accs else np.nan)
        return values

    cont_gd = mean_per_fraction(continuing_accuracies, 'gd_agent')
    cont_gc = mean_per_fraction(continuing_accuracies, 'gc_agent')
    frag_gd = mean_per_fraction(fragmented_accuracies, 'gd_agent')
    frag_gc = mean_per_fraction(fragmented_accuracies, 'gc_agent')

    # Debugging: print the values to check they are non-zero.
    print("Continuing GD:", cont_gd)
    print("Continuing GC:", cont_gc)
    print("Fragmented GD:", frag_gd)
    print("Fragmented GC:", frag_gc)

    x = np.arange(len(fractions))  # bar-group locations
    width = 0.35  # width of each bar
    y_ticks = np.arange(0, 1.1, 0.2)

    fig, axes = plt.subplots(1, 2, figsize=(12, 6), sharey=True)
    panels = (
        ('Consecutive Sequences', cont_gd, cont_gc),
        ('Non-Consecutive Sequences', frag_gd, frag_gc),
    )
    for ax, (panel_title, gd_vals, gc_vals) in zip(axes, panels):
        ax.bar(x - width / 2, gd_vals, width, label='BG-GRAML')
        ax.bar(x + width / 2, gc_vals, width, label='GC-GRAML')
        ax.set_title(panel_title, fontsize=20)
        ax.set_xticks(x)
        ax.set_xticklabels(fractions, fontsize=16)
        ax.set_yticks(y_ticks)
        ax.set_yticklabels(np.round(y_ticks, 1), fontsize=16)
        ax.set_ylim(0, 1)
        ax.legend(fontsize=20)

    # Common axis labels.
    fig.text(0.5, 0.02, 'Observation Portion', ha='center', va='center', fontsize=24)
    fig.text(0.06, 0.5, 'Accuracy', ha='center', va='center', rotation='vertical', fontsize=24)
    if title:
        fig.suptitle(title, fontsize=24)

    fig.subplots_adjust(top=0.85)
    fig.savefig(out_path, dpi=300)
    plt.close(fig)


if __name__ == "__main__":
    # Each leaf list collects one accuracy per task in `tasks`.
    fragmented_accuracies = build_accuracy_table(ALGORITHMS, ENVS_BY_DOMAIN, RESULT_PERCENTAGES)
    continuing_accuracies = copy.deepcopy(fragmented_accuracies)

    # domains = ['panda', 'minigrid', 'point_maze', 'parking']
    domains = ['parking']
    tasks = ['L555', 'L444', 'L333', 'L222', 'L111']

    for partial_obs_type, accuracies, is_same_learn in zip(
            OBS_TYPES, [fragmented_accuracies, continuing_accuracies], [False, True]):
        for domain in domains:
            for env in accuracies['graml'][domain]:
                for task in tasks:
                    # Results paths depend on the global storage configuration, so set it per recognizer.
                    set_global_storage_configs(recognizer_str='graml', is_fragmented=partial_obs_type,
                                               is_inference_same_length_sequences=True,
                                               is_learn_same_length_sequences=is_same_learn)
                    graml_res_file_path = f'{get_experiment_results_path(domain, env, task)}.pkl'
                    set_global_storage_configs(recognizer_str='graql', is_fragmented=partial_obs_type)
                    graql_res_file_path = f'{get_experiment_results_path(domain, env, task)}.pkl'
                    _append_results(accuracies, 'graml', domain, env, graml_res_file_path)
                    _append_results(accuracies, 'graql', domain, env, graql_res_file_path)

    # For per-domain accuracy curves instead of the BG/GC comparison, call:
    # plot_domain_grid(fragmented_accuracies, continuing_accuracies, domains)
    plot_stick_figures(continuing_accuracies, fragmented_accuracies, "GC-GRAML compared with BG-GRAML")
|