gr-libs 0.1.6.post1__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64) hide show
  1. evaluation/analyze_results_cross_alg_cross_domain.py +236 -246
  2. evaluation/create_minigrid_map_image.py +10 -6
  3. evaluation/file_system.py +16 -5
  4. evaluation/generate_experiments_results.py +123 -74
  5. evaluation/generate_experiments_results_new_ver1.py +227 -243
  6. evaluation/generate_experiments_results_new_ver2.py +317 -317
  7. evaluation/generate_task_specific_statistics_plots.py +481 -253
  8. evaluation/get_plans_images.py +41 -26
  9. evaluation/increasing_and_decreasing_.py +97 -56
  10. gr_libs/__init__.py +6 -1
  11. gr_libs/_version.py +2 -2
  12. gr_libs/environment/__init__.py +17 -9
  13. gr_libs/environment/environment.py +167 -39
  14. gr_libs/environment/utils/utils.py +22 -12
  15. gr_libs/metrics/__init__.py +5 -0
  16. gr_libs/metrics/metrics.py +76 -34
  17. gr_libs/ml/__init__.py +2 -0
  18. gr_libs/ml/agent.py +21 -6
  19. gr_libs/ml/base/__init__.py +1 -1
  20. gr_libs/ml/base/rl_agent.py +13 -10
  21. gr_libs/ml/consts.py +1 -1
  22. gr_libs/ml/neural/deep_rl_learner.py +433 -352
  23. gr_libs/ml/neural/utils/__init__.py +1 -1
  24. gr_libs/ml/neural/utils/dictlist.py +3 -3
  25. gr_libs/ml/neural/utils/penv.py +5 -2
  26. gr_libs/ml/planner/mcts/mcts_model.py +524 -302
  27. gr_libs/ml/planner/mcts/utils/__init__.py +1 -1
  28. gr_libs/ml/planner/mcts/utils/node.py +11 -7
  29. gr_libs/ml/planner/mcts/utils/tree.py +14 -10
  30. gr_libs/ml/sequential/__init__.py +1 -1
  31. gr_libs/ml/sequential/lstm_model.py +256 -175
  32. gr_libs/ml/tabular/state.py +7 -7
  33. gr_libs/ml/tabular/tabular_q_learner.py +123 -73
  34. gr_libs/ml/tabular/tabular_rl_agent.py +20 -19
  35. gr_libs/ml/utils/__init__.py +8 -2
  36. gr_libs/ml/utils/format.py +78 -70
  37. gr_libs/ml/utils/math.py +2 -1
  38. gr_libs/ml/utils/other.py +1 -1
  39. gr_libs/ml/utils/storage.py +95 -28
  40. gr_libs/problems/consts.py +1549 -1227
  41. gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +145 -80
  42. gr_libs/recognizer/graml/gr_dataset.py +209 -110
  43. gr_libs/recognizer/graml/graml_recognizer.py +431 -231
  44. gr_libs/recognizer/recognizer.py +38 -27
  45. gr_libs/recognizer/utils/__init__.py +1 -1
  46. gr_libs/recognizer/utils/format.py +8 -3
  47. {gr_libs-0.1.6.post1.dist-info → gr_libs-0.1.8.dist-info}/METADATA +1 -1
  48. gr_libs-0.1.8.dist-info/RECORD +70 -0
  49. {gr_libs-0.1.6.post1.dist-info → gr_libs-0.1.8.dist-info}/WHEEL +1 -1
  50. {gr_libs-0.1.6.post1.dist-info → gr_libs-0.1.8.dist-info}/top_level.txt +0 -1
  51. tests/test_gcdraco.py +10 -0
  52. tests/test_graml.py +8 -4
  53. tests/test_graql.py +2 -1
  54. tutorials/gcdraco_panda_tutorial.py +66 -0
  55. tutorials/gcdraco_parking_tutorial.py +61 -0
  56. tutorials/graml_minigrid_tutorial.py +42 -12
  57. tutorials/graml_panda_tutorial.py +35 -14
  58. tutorials/graml_parking_tutorial.py +37 -19
  59. tutorials/graml_point_maze_tutorial.py +33 -13
  60. tutorials/graql_minigrid_tutorial.py +31 -15
  61. CI/README.md +0 -12
  62. CI/docker_build_context/Dockerfile +0 -15
  63. gr_libs/recognizer/recognizer_doc.md +0 -61
  64. gr_libs-0.1.6.post1.dist-info/RECORD +0 -70
@@ -6,272 +6,262 @@ import os
6
6
  import dill
7
7
  from scipy.interpolate import make_interp_spline
8
8
  from scipy.ndimage import gaussian_filter1d
9
- from gr_libs.ml.utils.storage import get_experiment_results_path, set_global_storage_configs
9
+ from gr_libs.ml.utils.storage import (
10
+ get_experiment_results_path,
11
+ set_global_storage_configs,
12
+ )
10
13
  from scripts.generate_task_specific_statistics_plots import get_figures_dir_path
11
14
 
15
+
12
16
  def smooth_line(x, y, num_points=300):
13
17
  x_smooth = np.linspace(np.min(x), np.max(x), num_points)
14
18
  spline = make_interp_spline(x, y, k=3) # Cubic spline
15
19
  y_smooth = spline(x_smooth)
16
20
  return x_smooth, y_smooth
17
21
 
22
+
18
23
  if __name__ == "__main__":
19
24
 
20
- fragmented_accuracies = {
21
- 'graml': {
22
- 'panda': {'gd_agent': {
23
- '0.3': [], # every list here should have number of tasks accuracies in it, since we done experiments for L111-L555. remember each accuracy is an average of #goals different tasks.
24
- '0.5': [],
25
- '0.7': [],
26
- '0.9': [],
27
- '1' : []
28
- },
29
- 'gc_agent': {
30
- '0.3': [],
31
- '0.5': [],
32
- '0.7': [],
33
- '0.9': [],
34
- '1' : []
35
- }},
36
- 'minigrid': {'obstacles': {
37
- '0.3': [],
38
- '0.5': [],
39
- '0.7': [],
40
- '0.9': [],
41
- '1' : []
42
- },
43
- 'lava_crossing': {
44
- '0.3': [],
45
- '0.5': [],
46
- '0.7': [],
47
- '0.9': [],
48
- '1' : []
49
- }},
50
- 'point_maze': {'obstacles': {
51
- '0.3': [],
52
- '0.5': [],
53
- '0.7': [],
54
- '0.9': [],
55
- '1' : []
56
- },
57
- 'four_rooms': {
58
- '0.3': [],
59
- '0.5': [],
60
- '0.7': [],
61
- '0.9': [],
62
- '1' : []
63
- }},
64
- 'parking': {'gd_agent': {
65
- '0.3': [],
66
- '0.5': [],
67
- '0.7': [],
68
- '0.9': [],
69
- '1' : []
70
- },
71
- 'gc_agent': {
72
- '0.3': [],
73
- '0.5': [],
74
- '0.7': [],
75
- '0.9': [],
76
- '1' : []
77
- }},
78
- },
79
- 'graql': {
80
- 'panda': {'gd_agent': {
81
- '0.3': [],
82
- '0.5': [],
83
- '0.7': [],
84
- '0.9': [],
85
- '1' : []
86
- },
87
- 'gc_agent': {
88
- '0.3': [],
89
- '0.5': [],
90
- '0.7': [],
91
- '0.9': [],
92
- '1' : []
93
- }},
94
- 'minigrid': {'obstacles': {
95
- '0.3': [],
96
- '0.5': [],
97
- '0.7': [],
98
- '0.9': [],
99
- '1' : []
100
- },
101
- 'lava_crossing': {
102
- '0.3': [],
103
- '0.5': [],
104
- '0.7': [],
105
- '0.9': [],
106
- '1' : []
107
- }},
108
- 'point_maze': {'obstacles': {
109
- '0.3': [],
110
- '0.5': [],
111
- '0.7': [],
112
- '0.9': [],
113
- '1' : []
114
- },
115
- 'four_rooms': {
116
- '0.3': [],
117
- '0.5': [],
118
- '0.7': [],
119
- '0.9': [],
120
- '1' : []
121
- }},
122
- 'parking': {'gd_agent': {
123
- '0.3': [],
124
- '0.5': [],
125
- '0.7': [],
126
- '0.9': [],
127
- '1' : []
128
- },
129
- 'gc_agent': {
130
- '0.3': [],
131
- '0.5': [],
132
- '0.7': [],
133
- '0.9': [],
134
- '1' : []
135
- }},
136
- }
137
- }
25
+ fragmented_accuracies = {
26
+ "graml": {
27
+ "panda": {
28
+ "gd_agent": {
29
+ "0.3": [], # every list here should have number of tasks accuracies in it, since we done experiments for L111-L555. remember each accuracy is an average of #goals different tasks.
30
+ "0.5": [],
31
+ "0.7": [],
32
+ "0.9": [],
33
+ "1": [],
34
+ },
35
+ "gc_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
36
+ },
37
+ "minigrid": {
38
+ "obstacles": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
39
+ "lava_crossing": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
40
+ },
41
+ "point_maze": {
42
+ "obstacles": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
43
+ "four_rooms": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
44
+ },
45
+ "parking": {
46
+ "gd_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
47
+ "gc_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
48
+ },
49
+ },
50
+ "graql": {
51
+ "panda": {
52
+ "gd_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
53
+ "gc_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
54
+ },
55
+ "minigrid": {
56
+ "obstacles": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
57
+ "lava_crossing": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
58
+ },
59
+ "point_maze": {
60
+ "obstacles": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
61
+ "four_rooms": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
62
+ },
63
+ "parking": {
64
+ "gd_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
65
+ "gc_agent": {"0.3": [], "0.5": [], "0.7": [], "0.9": [], "1": []},
66
+ },
67
+ },
68
+ }
69
+
70
+ continuing_accuracies = copy.deepcopy(fragmented_accuracies)
71
+
72
+ # domains = ['panda', 'minigrid', 'point_maze', 'parking']
73
+ domains = ["minigrid", "point_maze", "parking"]
74
+ tasks = ["L111", "L222", "L333", "L444", "L555"]
75
+ percentages = ["0.3", "0.5", "1"]
76
+
77
+ for partial_obs_type, accuracies, is_same_learn in zip(
78
+ ["fragmented", "continuing"],
79
+ [fragmented_accuracies, continuing_accuracies],
80
+ [False, True],
81
+ ):
82
+ for domain in domains:
83
+ for env in accuracies["graml"][domain].keys():
84
+ for task in tasks:
85
+ set_global_storage_configs(
86
+ recognizer_str="graml",
87
+ is_fragmented=partial_obs_type,
88
+ is_inference_same_length_sequences=True,
89
+ is_learn_same_length_sequences=is_same_learn,
90
+ )
91
+ graml_res_file_path = (
92
+ f"{get_experiment_results_path(domain, env, task)}.pkl"
93
+ )
94
+ set_global_storage_configs(
95
+ recognizer_str="graql", is_fragmented=partial_obs_type
96
+ )
97
+ graql_res_file_path = (
98
+ f"{get_experiment_results_path(domain, env, task)}.pkl"
99
+ )
100
+ if os.path.exists(graml_res_file_path):
101
+ with open(graml_res_file_path, "rb") as results_file:
102
+ results = dill.load(results_file)
103
+ for percentage in accuracies["graml"][domain][env].keys():
104
+ accuracies["graml"][domain][env][percentage].append(
105
+ results[percentage]["accuracy"]
106
+ )
107
+ else:
108
+ assert (False, f"no file for {graml_res_file_path}")
109
+ if os.path.exists(graql_res_file_path):
110
+ with open(graql_res_file_path, "rb") as results_file:
111
+ results = dill.load(results_file)
112
+ for percentage in accuracies["graml"][domain][env].keys():
113
+ accuracies["graql"][domain][env][percentage].append(
114
+ results[percentage]["accuracy"]
115
+ )
116
+ else:
117
+ assert (False, f"no file for {graql_res_file_path}")
118
+
119
+ plot_styles = {
120
+ ("graml", "fragmented", 0.3): "g--o", # Green dashed line with circle markers
121
+ ("graml", "fragmented", 0.5): "g--s", # Green dashed line with square markers
122
+ (
123
+ "graml",
124
+ "fragmented",
125
+ 0.7,
126
+ ): "g--^", # Green dashed line with triangle-up markers
127
+ ("graml", "fragmented", 0.9): "g--d", # Green dashed line with diamond markers
128
+ ("graml", "fragmented", 1.0): "g--*", # Green dashed line with star markers
129
+ ("graml", "continuing", 0.3): "g-o", # Green solid line with circle markers
130
+ ("graml", "continuing", 0.5): "g-s", # Green solid line with square markers
131
+ (
132
+ "graml",
133
+ "continuing",
134
+ 0.7,
135
+ ): "g-^", # Green solid line with triangle-up markers
136
+ ("graml", "continuing", 0.9): "g-d", # Green solid line with diamond markers
137
+ ("graml", "continuing", 1.0): "g-*", # Green solid line with star markers
138
+ ("graql", "fragmented", 0.3): "b--o", # Blue dashed line with circle markers
139
+ ("graql", "fragmented", 0.5): "b--s", # Blue dashed line with square markers
140
+ (
141
+ "graql",
142
+ "fragmented",
143
+ 0.7,
144
+ ): "b--^", # Blue dashed line with triangle-up markers
145
+ ("graql", "fragmented", 0.9): "b--d", # Blue dashed line with diamond markers
146
+ ("graql", "fragmented", 1.0): "b--*", # Blue dashed line with star markers
147
+ ("graql", "continuing", 0.3): "b-o", # Blue solid line with circle markers
148
+ ("graql", "continuing", 0.5): "b-s", # Blue solid line with square markers
149
+ ("graql", "continuing", 0.7): "b-^", # Blue solid line with triangle-up markers
150
+ ("graql", "continuing", 0.9): "b-d", # Blue solid line with diamond markers
151
+ ("graql", "continuing", 1.0): "b-*", # Blue solid line with star markers
152
+ }
153
+
154
+ def average_accuracies(accuracies, domain):
155
+ avg_acc = {
156
+ algo: {perc: [] for perc in percentages} for algo in ["graml", "graql"]
157
+ }
158
+
159
+ for algo in avg_acc.keys():
160
+ for perc in percentages:
161
+ for env in accuracies[algo][domain].keys():
162
+ env_acc = accuracies[algo][domain][env][
163
+ perc
164
+ ] # list of 5, averages for L111 to L555.
165
+ if env_acc:
166
+ avg_acc[algo][perc].append(np.array(env_acc))
167
+
168
+ for algo in avg_acc.keys():
169
+ for perc in percentages:
170
+ if avg_acc[algo][perc]:
171
+ avg_acc[algo][perc] = np.mean(np.array(avg_acc[algo][perc]), axis=0)
172
+
173
+ return avg_acc
174
+
175
+ def plot_domain_accuracies(
176
+ ax,
177
+ fragmented_accuracies,
178
+ continuing_accuracies,
179
+ domain,
180
+ sigma=1,
181
+ line_width=1.5,
182
+ ):
183
+ fragmented_avg_acc = average_accuracies(fragmented_accuracies, domain)
184
+ continuing_avg_acc = average_accuracies(continuing_accuracies, domain)
138
185
 
139
- continuing_accuracies = copy.deepcopy(fragmented_accuracies)
140
-
141
- #domains = ['panda', 'minigrid', 'point_maze', 'parking']
142
- domains = ['minigrid', 'point_maze', 'parking']
143
- tasks = ['L111', 'L222', 'L333', 'L444', 'L555']
144
- percentages = ['0.3', '0.5', '1']
186
+ x_vals = np.arange(1, 6) # Number of goals
145
187
 
146
- for partial_obs_type, accuracies, is_same_learn in zip(['fragmented', 'continuing'], [fragmented_accuracies, continuing_accuracies], [False, True]):
147
- for domain in domains:
148
- for env in accuracies['graml'][domain].keys():
149
- for task in tasks:
150
- set_global_storage_configs(recognizer_str='graml', is_fragmented=partial_obs_type,
151
- is_inference_same_length_sequences=True, is_learn_same_length_sequences=is_same_learn)
152
- graml_res_file_path = f'{get_experiment_results_path(domain, env, task)}.pkl'
153
- set_global_storage_configs(recognizer_str='graql', is_fragmented=partial_obs_type)
154
- graql_res_file_path = f'{get_experiment_results_path(domain, env, task)}.pkl'
155
- if os.path.exists(graml_res_file_path):
156
- with open(graml_res_file_path, 'rb') as results_file:
157
- results = dill.load(results_file)
158
- for percentage in accuracies['graml'][domain][env].keys():
159
- accuracies['graml'][domain][env][percentage].append(results[percentage]['accuracy'])
160
- else:
161
- assert(False, f"no file for {graml_res_file_path}")
162
- if os.path.exists(graql_res_file_path):
163
- with open(graql_res_file_path, 'rb') as results_file:
164
- results = dill.load(results_file)
165
- for percentage in accuracies['graml'][domain][env].keys():
166
- accuracies['graql'][domain][env][percentage].append(results[percentage]['accuracy'])
167
- else:
168
- assert(False, f"no file for {graql_res_file_path}")
188
+ # Create "waves" (shaded regions) for each algorithm
189
+ for algo in ["graml", "graql"]:
190
+ fragmented_y_vals_by_percentage = []
191
+ continuing_y_vals_by_percentage = []
169
192
 
170
- plot_styles = {
171
- ('graml', 'fragmented', 0.3): 'g--o', # Green dashed line with circle markers
172
- ('graml', 'fragmented', 0.5): 'g--s', # Green dashed line with square markers
173
- ('graml', 'fragmented', 0.7): 'g--^', # Green dashed line with triangle-up markers
174
- ('graml', 'fragmented', 0.9): 'g--d', # Green dashed line with diamond markers
175
- ('graml', 'fragmented', 1.0): 'g--*', # Green dashed line with star markers
176
-
177
- ('graml', 'continuing', 0.3): 'g-o', # Green solid line with circle markers
178
- ('graml', 'continuing', 0.5): 'g-s', # Green solid line with square markers
179
- ('graml', 'continuing', 0.7): 'g-^', # Green solid line with triangle-up markers
180
- ('graml', 'continuing', 0.9): 'g-d', # Green solid line with diamond markers
181
- ('graml', 'continuing', 1.0): 'g-*', # Green solid line with star markers
182
-
183
- ('graql', 'fragmented', 0.3): 'b--o', # Blue dashed line with circle markers
184
- ('graql', 'fragmented', 0.5): 'b--s', # Blue dashed line with square markers
185
- ('graql', 'fragmented', 0.7): 'b--^', # Blue dashed line with triangle-up markers
186
- ('graql', 'fragmented', 0.9): 'b--d', # Blue dashed line with diamond markers
187
- ('graql', 'fragmented', 1.0): 'b--*', # Blue dashed line with star markers
188
-
189
- ('graql', 'continuing', 0.3): 'b-o', # Blue solid line with circle markers
190
- ('graql', 'continuing', 0.5): 'b-s', # Blue solid line with square markers
191
- ('graql', 'continuing', 0.7): 'b-^', # Blue solid line with triangle-up markers
192
- ('graql', 'continuing', 0.9): 'b-d', # Blue solid line with diamond markers
193
- ('graql', 'continuing', 1.0): 'b-*', # Blue solid line with star markers
194
- }
193
+ for perc in percentages:
194
+ fragmented_y_vals = np.array(fragmented_avg_acc[algo][perc])
195
+ continuing_y_vals = np.array(continuing_avg_acc[algo][perc])
195
196
 
196
- def average_accuracies(accuracies, domain):
197
- avg_acc = {algo: {perc: [] for perc in percentages}
198
- for algo in ['graml', 'graql']}
199
-
200
- for algo in avg_acc.keys():
201
- for perc in percentages:
202
- for env in accuracies[algo][domain].keys():
203
- env_acc = accuracies[algo][domain][env][perc] # list of 5, averages for L111 to L555.
204
- if env_acc:
205
- avg_acc[algo][perc].append(np.array(env_acc))
206
-
207
- for algo in avg_acc.keys():
208
- for perc in percentages:
209
- if avg_acc[algo][perc]:
210
- avg_acc[algo][perc] = np.mean(np.array(avg_acc[algo][perc]), axis=0)
211
-
212
- return avg_acc
197
+ # Smooth the trends using Gaussian filtering
198
+ fragmented_y_smoothed = gaussian_filter1d(
199
+ fragmented_y_vals, sigma=sigma
200
+ )
201
+ continuing_y_smoothed = gaussian_filter1d(
202
+ continuing_y_vals, sigma=sigma
203
+ )
213
204
 
214
- def plot_domain_accuracies(ax, fragmented_accuracies, continuing_accuracies, domain, sigma=1, line_width=1.5):
215
- fragmented_avg_acc = average_accuracies(fragmented_accuracies, domain)
216
- continuing_avg_acc = average_accuracies(continuing_accuracies, domain)
217
-
218
- x_vals = np.arange(1, 6) # Number of goals
219
-
220
- # Create "waves" (shaded regions) for each algorithm
221
- for algo in ['graml', 'graql']:
222
- fragmented_y_vals_by_percentage = []
223
- continuing_y_vals_by_percentage = []
205
+ fragmented_y_vals_by_percentage.append(fragmented_y_smoothed)
206
+ continuing_y_vals_by_percentage.append(continuing_y_smoothed)
224
207
 
225
- for perc in percentages:
226
- fragmented_y_vals = np.array(fragmented_avg_acc[algo][perc])
227
- continuing_y_vals = np.array(continuing_avg_acc[algo][perc])
208
+ ax.plot(
209
+ x_vals,
210
+ fragmented_y_smoothed,
211
+ plot_styles[(algo, "fragmented", float(perc))],
212
+ label=f"{algo}, non-consecutive, {perc}",
213
+ linewidth=0.5, # Control line thickness here
214
+ )
215
+ ax.plot(
216
+ x_vals,
217
+ continuing_y_smoothed,
218
+ plot_styles[(algo, "continuing", float(perc))],
219
+ label=f"{algo}, consecutive, {perc}",
220
+ linewidth=0.5, # Control line thickness here
221
+ )
228
222
 
229
- # Smooth the trends using Gaussian filtering
230
- fragmented_y_smoothed = gaussian_filter1d(fragmented_y_vals, sigma=sigma)
231
- continuing_y_smoothed = gaussian_filter1d(continuing_y_vals, sigma=sigma)
223
+ ax.set_xticks(x_vals)
224
+ ax.set_yticks(np.linspace(0, 1, 6))
225
+ ax.set_ylim([0, 1])
226
+ ax.set_title(f"{domain.capitalize()} Domain", fontsize=16)
227
+ ax.grid(True)
232
228
 
233
- fragmented_y_vals_by_percentage.append(fragmented_y_smoothed)
234
- continuing_y_vals_by_percentage.append(continuing_y_smoothed)
229
+ fig, axes = plt.subplots(
230
+ 1, 4, figsize=(24, 6)
231
+ ) # Increase the figure size for better spacing (width 24, height 6)
235
232
 
236
- ax.plot(
237
- x_vals, fragmented_y_smoothed,
238
- plot_styles[(algo, 'fragmented', float(perc))],
239
- label=f"{algo}, non-consecutive, {perc}",
240
- linewidth=0.5 # Control line thickness here
241
- )
242
- ax.plot(
243
- x_vals, continuing_y_smoothed,
244
- plot_styles[(algo, 'continuing', float(perc))],
245
- label=f"{algo}, consecutive, {perc}",
246
- linewidth=0.5 # Control line thickness here
247
- )
248
-
249
- ax.set_xticks(x_vals)
250
- ax.set_yticks(np.linspace(0, 1, 6))
251
- ax.set_ylim([0, 1])
252
- ax.set_title(f'{domain.capitalize()} Domain', fontsize=16)
253
- ax.grid(True)
233
+ # Generate each plot in a subplot, including both fragmented and continuing accuracies
234
+ for i, domain in enumerate(domains):
235
+ plot_domain_accuracies(
236
+ axes[i], fragmented_accuracies, continuing_accuracies, domain
237
+ )
254
238
 
255
- fig, axes = plt.subplots(1, 4, figsize=(24, 6)) # Increase the figure size for better spacing (width 24, height 6)
256
-
257
- # Generate each plot in a subplot, including both fragmented and continuing accuracies
258
- for i, domain in enumerate(domains):
259
- plot_domain_accuracies(axes[i], fragmented_accuracies, continuing_accuracies, domain)
239
+ # Set a single x-axis and y-axis label for the entire figure
240
+ fig.text(
241
+ 0.5, 0.04, "Number of Goals", ha="center", fontsize=20
242
+ ) # Centered x-axis label
243
+ fig.text(
244
+ 0.04, 0.5, "Accuracy", va="center", rotation="vertical", fontsize=20
245
+ ) # Reduced spacing for y-axis label
260
246
 
261
- # Set a single x-axis and y-axis label for the entire figure
262
- fig.text(0.5, 0.04, 'Number of Goals', ha='center', fontsize=20) # Centered x-axis label
263
- fig.text(0.04, 0.5, 'Accuracy', va='center', rotation='vertical', fontsize=20) # Reduced spacing for y-axis label
247
+ # Adjust subplot layout to avoid overlap
248
+ plt.subplots_adjust(
249
+ left=0.09, right=0.91, top=0.79, bottom=0.21, wspace=0.3
250
+ ) # More space on top (top=0.82)
264
251
 
265
- # Adjust subplot layout to avoid overlap
266
- plt.subplots_adjust(left=0.09, right=0.91, top=0.79, bottom=0.21, wspace=0.3) # More space on top (top=0.82)
267
-
268
- # Place the legend above the plots with more space between legend and plots
269
- handles, labels = axes[0].get_legend_handles_labels()
270
- fig.legend(handles, labels, loc='upper center', ncol=4, bbox_to_anchor=(0.5, 1.05), fontsize=12) # Moved above with bbox_to_anchor
252
+ # Place the legend above the plots with more space between legend and plots
253
+ handles, labels = axes[0].get_legend_handles_labels()
254
+ fig.legend(
255
+ handles,
256
+ labels,
257
+ loc="upper center",
258
+ ncol=4,
259
+ bbox_to_anchor=(0.5, 1.05),
260
+ fontsize=12,
261
+ ) # Moved above with bbox_to_anchor
271
262
 
272
- # Save the figure and show it
273
- save_dir = os.path.join('figures', 'all_domains_accuracy_plots')
274
- if not os.path.exists(save_dir):
275
- os.makedirs(save_dir)
276
- plt.savefig(os.path.join(save_dir, 'accuracy_plots_smooth.png'), dpi=300)
277
-
263
+ # Save the figure and show it
264
+ save_dir = os.path.join("figures", "all_domains_accuracy_plots")
265
+ if not os.path.exists(save_dir):
266
+ os.makedirs(save_dir)
267
+ plt.savefig(os.path.join(save_dir, "accuracy_plots_smooth.png"), dpi=300)
@@ -2,21 +2,25 @@ from minigrid.wrappers import RGBImgPartialObsWrapper, ImgObsWrapper
2
2
  import numpy as np
3
3
  import gr_libs.ml as ml
4
4
  from minigrid.core.world_object import Wall
5
- #from q_table_plot import save_q_table_plot_image
5
+
6
+ # from q_table_plot import save_q_table_plot_image
6
7
  from gymnasium.envs.registration import register
7
8
 
8
9
  env_name = "MiniGrid-SimpleCrossingS13N4-DynamicGoal-5x9-v0"
9
10
  # create an agent and train it (if it is already trained, it will get q-table from cache)
10
- agent = ml.TabularQLearner(env_name='MiniGrid-Walls-13x13-v0',problem_name = "MiniGrid-SimpleCrossingS13N4-DynamicGoal-5x9-v0")
11
+ agent = ml.TabularQLearner(
12
+ env_name="MiniGrid-Walls-13x13-v0",
13
+ problem_name="MiniGrid-SimpleCrossingS13N4-DynamicGoal-5x9-v0",
14
+ )
11
15
  # agent.learn()
12
16
 
13
17
  # save_q_table_plot_image(agent.q_table, 15, 15, (10,7))
14
18
 
15
19
  # add to the steps list the step the trained agent would take on the env in every state according to the q_table
16
20
  env = agent.env
17
- env = RGBImgPartialObsWrapper(env) # Get pixel observations
18
- env = ImgObsWrapper(env) # Get rid of the 'mission' field
19
- obs, _ = env.reset() # This now produces an RGB tensor only
21
+ env = RGBImgPartialObsWrapper(env) # Get pixel observations
22
+ env = ImgObsWrapper(env) # Get rid of the 'mission' field
23
+ obs, _ = env.reset() # This now produces an RGB tensor only
20
24
 
21
25
  img = env.get_frame()
22
26
 
@@ -24,7 +28,7 @@ img = env.get_frame()
24
28
  from PIL import Image
25
29
  import numpy as np
26
30
 
27
- image_pil = Image.fromarray(np.uint8(img)).convert('RGB')
31
+ image_pil = Image.fromarray(np.uint8(img)).convert("RGB")
28
32
  image_pil.save(r"{}.png".format(env_name))
29
33
 
30
34
  # ####### show image
evaluation/file_system.py CHANGED
@@ -4,26 +4,36 @@ import random
4
4
  import hashlib
5
5
  from typing import List
6
6
 
7
+
7
8
  def get_observations_path(env_name: str):
8
9
  return f"dataset/{env_name}/observations"
9
10
 
11
+
10
12
  def get_observations_paths(path: str):
11
13
  return [os.path.join(path, file_name) for file_name in os.listdir(path)]
12
14
 
15
+
13
16
  def create_partial_observabilities_files(env_name: str, observabilities: List[float]):
14
- with open(r"dataset/{env_name}/observations/obs1.0.pkl".format(env_name=env_name), "rb") as f:
17
+ with open(
18
+ r"dataset/{env_name}/observations/obs1.0.pkl".format(env_name=env_name), "rb"
19
+ ) as f:
15
20
  step_1_0 = dill.load(f)
16
21
 
17
- number_of_items_to_randomize = [int(observability * len(step_1_0)) for observability in observabilities]
22
+ number_of_items_to_randomize = [
23
+ int(observability * len(step_1_0)) for observability in observabilities
24
+ ]
18
25
  obs = []
19
26
  for items_to_randomize in number_of_items_to_randomize:
20
27
  obs.append(random.sample(step_1_0, items_to_randomize))
21
28
  for index, observability in enumerate(observabilities):
22
29
  partial_steps = obs[index]
23
- file_path = r"dataset/{env_name}/observations/obs{obs}.pkl".format(env_name=env_name, obs=observability)
30
+ file_path = r"dataset/{env_name}/observations/obs{obs}.pkl".format(
31
+ env_name=env_name, obs=observability
32
+ )
24
33
  with open(file_path, "wb+") as f:
25
34
  dill.dump(partial_steps, f)
26
-
35
+
36
+
27
37
  def md5(file_path: str):
28
38
  hash_md5 = hashlib.md5()
29
39
  with open(file_path, "rb") as f:
@@ -31,6 +41,7 @@ def md5(file_path: str):
31
41
  hash_md5.update(chunk)
32
42
  return hash_md5.hexdigest()
33
43
 
44
+
34
45
  def get_md5(file_path_list: List[str]):
35
46
  return [(file_path, md5(file_path=file_path)) for file_path in file_path_list]
36
47
 
@@ -39,4 +50,4 @@ def print_md5(file_path_list: List[str]):
39
50
  md5_of_observations = get_md5(file_path_list=file_path_list)
40
51
  for file_name, file_md5 in md5_of_observations:
41
52
  print(f"{file_name}:{file_md5}")
42
- print("")
53
+ print("")