gr-libs 0.1.8__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. gr_libs/__init__.py +3 -1
  2. gr_libs/_version.py +2 -2
  3. gr_libs/all_experiments.py +260 -0
  4. gr_libs/environment/__init__.py +14 -1
  5. gr_libs/environment/_utils/__init__.py +0 -0
  6. gr_libs/environment/{utils → _utils}/utils.py +1 -1
  7. gr_libs/environment/environment.py +278 -23
  8. gr_libs/evaluation/__init__.py +1 -0
  9. gr_libs/evaluation/generate_experiments_results.py +100 -0
  10. gr_libs/metrics/__init__.py +2 -0
  11. gr_libs/metrics/metrics.py +166 -31
  12. gr_libs/ml/__init__.py +1 -6
  13. gr_libs/ml/base/__init__.py +3 -1
  14. gr_libs/ml/base/rl_agent.py +68 -3
  15. gr_libs/ml/neural/__init__.py +1 -3
  16. gr_libs/ml/neural/deep_rl_learner.py +241 -84
  17. gr_libs/ml/neural/utils/__init__.py +1 -2
  18. gr_libs/ml/planner/mcts/{utils → _utils}/tree.py +1 -1
  19. gr_libs/ml/planner/mcts/mcts_model.py +71 -34
  20. gr_libs/ml/sequential/__init__.py +0 -1
  21. gr_libs/ml/sequential/{lstm_model.py → _lstm_model.py} +11 -14
  22. gr_libs/ml/tabular/__init__.py +1 -3
  23. gr_libs/ml/tabular/tabular_q_learner.py +27 -9
  24. gr_libs/ml/tabular/tabular_rl_agent.py +22 -9
  25. gr_libs/ml/utils/__init__.py +2 -9
  26. gr_libs/ml/utils/format.py +13 -90
  27. gr_libs/ml/utils/math.py +3 -2
  28. gr_libs/ml/utils/other.py +2 -2
  29. gr_libs/ml/utils/storage.py +41 -94
  30. gr_libs/odgr_executor.py +263 -0
  31. gr_libs/problems/consts.py +570 -292
  32. gr_libs/recognizer/{utils → _utils}/format.py +2 -2
  33. gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +127 -36
  34. gr_libs/recognizer/graml/{gr_dataset.py → _gr_dataset.py} +11 -11
  35. gr_libs/recognizer/graml/graml_recognizer.py +186 -35
  36. gr_libs/recognizer/recognizer.py +59 -10
  37. gr_libs/tutorials/draco_panda_tutorial.py +58 -0
  38. gr_libs/tutorials/draco_parking_tutorial.py +56 -0
  39. {tutorials → gr_libs/tutorials}/gcdraco_panda_tutorial.py +11 -11
  40. {tutorials → gr_libs/tutorials}/gcdraco_parking_tutorial.py +6 -8
  41. {tutorials → gr_libs/tutorials}/graml_minigrid_tutorial.py +18 -14
  42. {tutorials → gr_libs/tutorials}/graml_panda_tutorial.py +11 -12
  43. {tutorials → gr_libs/tutorials}/graml_parking_tutorial.py +8 -10
  44. {tutorials → gr_libs/tutorials}/graml_point_maze_tutorial.py +17 -3
  45. {tutorials → gr_libs/tutorials}/graql_minigrid_tutorial.py +2 -2
  46. {gr_libs-0.1.8.dist-info → gr_libs-0.2.5.dist-info}/METADATA +95 -29
  47. gr_libs-0.2.5.dist-info/RECORD +72 -0
  48. {gr_libs-0.1.8.dist-info → gr_libs-0.2.5.dist-info}/WHEEL +1 -1
  49. gr_libs-0.2.5.dist-info/top_level.txt +2 -0
  50. tests/test_draco.py +14 -0
  51. tests/test_gcdraco.py +2 -2
  52. tests/test_graml.py +4 -4
  53. tests/test_graql.py +1 -1
  54. tests/test_odgr_executor_expertbasedgraml.py +14 -0
  55. tests/test_odgr_executor_gcdraco.py +14 -0
  56. tests/test_odgr_executor_gcgraml.py +14 -0
  57. tests/test_odgr_executor_graql.py +14 -0
  58. evaluation/analyze_results_cross_alg_cross_domain.py +0 -267
  59. evaluation/create_minigrid_map_image.py +0 -38
  60. evaluation/file_system.py +0 -53
  61. evaluation/generate_experiments_results.py +0 -141
  62. evaluation/generate_experiments_results_new_ver1.py +0 -238
  63. evaluation/generate_experiments_results_new_ver2.py +0 -331
  64. evaluation/generate_task_specific_statistics_plots.py +0 -500
  65. evaluation/get_plans_images.py +0 -62
  66. evaluation/increasing_and_decreasing_.py +0 -104
  67. gr_libs/ml/neural/utils/penv.py +0 -60
  68. gr_libs-0.1.8.dist-info/RECORD +0 -70
  69. gr_libs-0.1.8.dist-info/top_level.txt +0 -4
  70. /gr_libs/{environment/utils/__init__.py → _evaluation/_generate_experiments_results.py} +0 -0
  71. /gr_libs/ml/planner/mcts/{utils → _utils}/__init__.py +0 -0
  72. /gr_libs/ml/planner/mcts/{utils → _utils}/node.py +0 -0
  73. /gr_libs/recognizer/{utils → _utils}/__init__.py +0 -0
@@ -1,331 +0,0 @@
1
- import copy
2
- import sys
3
- import matplotlib.pyplot as plt
4
- import numpy as np
5
- import os
6
- import dill
7
- from scipy.interpolate import make_interp_spline
8
- from scipy.ndimage import gaussian_filter1d
9
- from gr_libs.ml.utils.storage import (
10
- get_experiment_results_path,
11
- set_global_storage_configs,
12
- )
13
- from scripts.generate_task_specific_statistics_plots import get_figures_dir_path
14
-
15
-
def smooth_line(x, y, num_points=300, k=3):
    """Interpolate ``(x, y)`` with a B-spline and resample it on a dense grid.

    Args:
        x: 1-D sample positions; ``make_interp_spline`` requires them to be
            strictly increasing.
        y: Sample values, one per entry of ``x``.
        num_points: Number of evenly spaced points in the returned curve.
        k: Requested spline degree (cubic by default). It is lowered to
            ``len(x) - 1`` when there are too few samples, because an
            interpolating spline of degree ``k`` needs at least ``k + 1``
            points (e.g. only three observation fractions are plotted here).

    Returns:
        A tuple ``(x_smooth, y_smooth)`` of arrays of length ``num_points``
        spanning ``[min(x), max(x)]``.
    """
    x = np.asarray(x, dtype=float)
    # Clamp the degree so short inputs degrade to a lower-order spline
    # instead of making make_interp_spline raise.
    degree = max(0, min(k, x.size - 1))
    x_smooth = np.linspace(np.min(x), np.max(x), num_points)
    spline = make_interp_spline(x, y, k=degree)
    return x_smooth, spline(x_smooth)
21
-
22
-
# Recognizers whose accuracies are compared.
ALGORITHMS = ("graml", "graql")

# Environments evaluated for each domain; every algorithm uses the same set.
DOMAIN_ENVS = {
    "panda": ("gc_agent",),
    "minigrid": ("obstacles", "lava_crossing"),
    "point_maze": ("obstacles", "four_rooms"),
    "parking": ("gc_agent", "gd_agent"),
}

# Observation-fraction keys present in every accuracy table.
ALL_PERCENTAGES = ("0.3", "0.5", "0.7", "0.9", "1")

# Observation fractions that are actually loaded and plotted.
PERCENTAGES = ("0.3", "0.5", "1")

# Tasks loaded for each environment, in the order their accuracies are stored.
TASKS = ("L555", "L444", "L333", "L222", "L111")

# Legend wording for each partial-observation type, in plotting order.
OBSERVATION_LABELS = {"fragmented": "non-consecutive", "continuing": "consecutive"}


def _build_plot_styles():
    """Return the matplotlib format string for each (algo, obs type, fraction).

    Colour encodes the algorithm (green GRAML, blue GRAQL), line style the
    observation type (dashed fragmented, solid continuing) and the marker the
    observation fraction.
    """
    colors = {"graml": "g", "graql": "b"}
    line_styles = {"fragmented": "--", "continuing": "-"}
    markers = {0.3: "o", 0.5: "s", 0.7: "^", 0.9: "d", 1.0: "*"}
    return {
        (algo, obs_type, fraction): f"{color}{line_style}{marker}"
        for algo, color in colors.items()
        for obs_type, line_style in line_styles.items()
        for fraction, marker in markers.items()
    }


PLOT_STYLES = _build_plot_styles()


def make_empty_accuracies():
    """Return a fresh table ``algo -> domain -> env -> fraction -> []``.

    Every leaf is a distinct list, so tables built by separate calls (or
    different leaves of one table) never share state.
    """
    return {
        algo: {
            domain: {env: {perc: [] for perc in ALL_PERCENTAGES} for env in envs}
            for domain, envs in DOMAIN_ENVS.items()
        }
        for algo in ALGORITHMS
    }


def _append_result_accuracies(env_accuracies, path, percentages):
    """Load one results pickle and append its accuracy for each fraction.

    Raises:
        FileNotFoundError: if ``path`` does not exist. A missing file must not
            be skipped silently, or the averages would mix different task sets.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"no file for {path}")
    # dill can execute arbitrary code while unpickling; only load trusted files.
    with open(path, "rb") as results_file:
        results = dill.load(results_file)
    for perc in percentages:
        env_accuracies[perc].append(results[perc]["accuracy"])


def load_accuracies(
    accuracies, domains, tasks, percentages, partial_obs_type, is_same_learn
):
    """Fill ``accuracies`` in place from the pickled experiment results.

    For each environment of each domain in ``domains`` and each task in
    ``tasks``, the GRAML and GRAQL result paths are resolved through the
    global storage configuration. Then ``results[perc]["accuracy"]`` is
    appended to ``accuracies[algo][domain][env][perc]`` for every ``perc`` in
    ``percentages``. The global storage configuration is left set for GRAQL.

    Args:
        accuracies: Table as returned by ``make_empty_accuracies``.
        domains: Domain keys to load.
        tasks: Task names to load, in storage order.
        percentages: Observation-fraction keys to read from each result.
        partial_obs_type: ``"fragmented"`` or ``"continuing"``; forwarded as
            ``is_fragmented``. NOTE(review): this is a string, not a bool —
            confirm it is what ``set_global_storage_configs`` expects.
        is_same_learn: Forwarded to GRAML as
            ``is_learn_same_length_sequences``.

    Raises:
        FileNotFoundError: if any required results file is missing.
    """
    for domain in domains:
        for env in accuracies["graml"][domain]:
            for task in tasks:
                set_global_storage_configs(
                    recognizer_str="graml",
                    is_fragmented=partial_obs_type,
                    is_inference_same_length_sequences=True,
                    is_learn_same_length_sequences=is_same_learn,
                )
                graml_path = f"{get_experiment_results_path(domain, env, task)}.pkl"
                set_global_storage_configs(
                    recognizer_str="graql", is_fragmented=partial_obs_type
                )
                graql_path = f"{get_experiment_results_path(domain, env, task)}.pkl"
                for algo, path in (("graml", graml_path), ("graql", graql_path)):
                    _append_result_accuracies(
                        accuracies[algo][domain][env], path, percentages
                    )


def average_accuracies(accuracies, domain, percentages=None):
    """Average per-task accuracies over the environments of ``domain``.

    Args:
        accuracies: Table as returned by ``make_empty_accuracies``.
        domain: Domain key to average.
        percentages: Fractions to average; defaults to ``PERCENTAGES``.

    Returns:
        ``{algo: {perc: value}}`` where ``value`` is the element-wise mean
        (one entry per task) over the environments that have data, or an
        empty list when no environment has data for that fraction.
    """
    percentages = PERCENTAGES if percentages is None else percentages
    averaged = {}
    for algo in ALGORITHMS:
        averaged[algo] = {}
        for perc in percentages:
            per_env = [
                np.array(env_table[perc])
                for env_table in accuracies[algo][domain].values()
                if env_table[perc]
            ]
            averaged[algo][perc] = (
                np.mean(np.array(per_env), axis=0) if per_env else []
            )
    return averaged


def plot_domain_accuracies(
    ax,
    fragmented_accuracies,
    continuing_accuracies,
    domain,
    sigma=1,
    line_width=1.5,
    percentages=None,
):
    """Draw smoothed accuracy-per-task curves for one domain onto ``ax``.

    For each algorithm and observation fraction, the environment-averaged
    per-task accuracies are smoothed with a Gaussian filter. Each is drawn
    once for non-consecutive (fragmented) observations and once for
    consecutive (continuing) ones.

    Args:
        ax: Matplotlib axes to draw on.
        fragmented_accuracies: Accuracy table for fragmented observations.
        continuing_accuracies: Accuracy table for continuing observations.
        domain: Domain key to plot.
        sigma: Standard deviation of the Gaussian smoothing kernel.
        line_width: Width of every plotted line.
        percentages: Fractions to plot; defaults to ``PERCENTAGES``.
    """
    percentages = PERCENTAGES if percentages is None else percentages
    averages = {
        "fragmented": average_accuracies(fragmented_accuracies, domain, percentages),
        "continuing": average_accuracies(continuing_accuracies, domain, percentages),
    }
    x_vals = np.arange(1, len(TASKS) + 1)  # Number of goals

    for algo in ALGORITHMS:
        for perc in percentages:
            for obs_type, obs_label in OBSERVATION_LABELS.items():
                y_vals = np.asarray(averages[obs_type][algo][perc], dtype=float)
                if y_vals.size == 0:
                    # No results were loaded for this combination; plotting an
                    # empty series against x_vals would fail.
                    continue
                ax.plot(
                    x_vals,
                    gaussian_filter1d(y_vals, sigma=sigma),
                    PLOT_STYLES[(algo, obs_type, float(perc))],
                    label=f"{algo}, {obs_label}, {perc}",
                    linewidth=line_width,
                )

    ax.set_xticks(x_vals)
    ax.set_yticks(np.linspace(0, 1, 6))
    ax.set_ylim([0, 1])
    ax.set_title(f"{domain.capitalize()} Domain", fontsize=16)
    ax.grid(True)


def plot_all_domains(
    fragmented_accuracies,
    continuing_accuracies,
    domains,
    out_path="accuracy_plots_smooth.png",
):
    """Save one ``plot_domain_accuracies`` subplot per domain, side by side."""
    fig, axes = plt.subplots(
        1, len(domains), figsize=(6 * len(domains), 6), squeeze=False
    )
    axes = axes[0]
    for ax, domain in zip(axes, domains):
        plot_domain_accuracies(ax, fragmented_accuracies, continuing_accuracies, domain)

    # One shared x/y label for the whole figure.
    fig.text(0.5, 0.04, "Number of Goals", ha="center", fontsize=20)
    fig.text(0.04, 0.5, "Accuracy", va="center", rotation="vertical", fontsize=20)
    plt.subplots_adjust(left=0.09, right=0.91, top=0.79, bottom=0.21, wspace=0.3)

    # Single legend above the plots.
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(
        handles,
        labels,
        loc="upper center",
        ncol=4,
        bbox_to_anchor=(0.5, 1.05),
        fontsize=12,
    )
    fig.savefig(out_path, dpi=300)
    plt.close(fig)


def plot_stick_figures(
    continuing_accuracies, fragmented_accuracies, title, out_path="gd_vs_gc_parking.png"
):
    """Save a bar chart comparing BG-GRAML and GC-GRAML on the parking domain.

    BG-GRAML reads the ``gd_agent`` environment and GC-GRAML the ``gc_agent``
    environment of the GRAML table. Each bar is the mean over tasks for one
    observation fraction. The left panel shows consecutive sequences; the
    right panel shows non-consecutive ones.

    Args:
        continuing_accuracies: Accuracy table for continuing observations.
        fragmented_accuracies: Accuracy table for fragmented observations.
        title: NOTE(review): currently not drawn on the figure.
        out_path: Where the PNG is written.
    """
    # Only the loaded fractions have data; others would average to NaN.
    fractions = list(PERCENTAGES)

    def agent_means(data_dict, agent):
        return [np.mean(data_dict["graml"]["parking"][agent][f]) for f in fractions]

    cont_gd = agent_means(continuing_accuracies, "gd_agent")
    cont_gc = agent_means(continuing_accuracies, "gc_agent")
    frag_gd = agent_means(fragmented_accuracies, "gd_agent")
    frag_gc = agent_means(fragmented_accuracies, "gc_agent")

    # Debugging: Print values to check if they're non-zero
    print("Continuing GD:", cont_gd)
    print("Continuing GC:", cont_gc)
    print("Fragmented GD:", frag_gd)
    print("Fragmented GC:", frag_gc)

    x = np.arange(len(fractions))  # label locations
    width = 0.35  # width of the bars
    y_ticks = np.arange(0, 1.1, 0.2)

    fig, axes = plt.subplots(1, 2, figsize=(12, 6), sharey=True)
    panels = (
        (axes[0], "Consecutive Sequences", cont_gd, cont_gc),
        (axes[1], "Non-Consecutive Sequences", frag_gd, frag_gc),
    )
    for ax, panel_title, gd_vals, gc_vals in panels:
        ax.bar(x - width / 2, gd_vals, width, label="BG-GRAML")
        ax.bar(x + width / 2, gc_vals, width, label="GC-GRAML")
        ax.set_title(panel_title, fontsize=20)
        ax.set_xticks(x)
        ax.set_xticklabels(fractions, fontsize=16)
        ax.set_yticks(y_ticks)
        ax.set_yticklabels(np.round(y_ticks, 1), fontsize=16)
        ax.set_ylim(0, 1)
        ax.legend(fontsize=20)

    # Common axis labels
    fig.text(0.5, 0.02, "Observation Portion", ha="center", va="center", fontsize=24)
    fig.text(
        0.06,
        0.5,
        "Accuracy",
        ha="center",
        va="center",
        rotation="vertical",
        fontsize=24,
    )

    plt.subplots_adjust(top=0.85)
    fig.savefig(out_path, dpi=300)
    plt.close(fig)  # release the figure; pyplot keeps it alive otherwise


if __name__ == "__main__":
    fragmented_accuracies = make_empty_accuracies()
    continuing_accuracies = make_empty_accuracies()

    # domains = ["panda", "minigrid", "point_maze", "parking"]
    domains = ["parking"]

    # Fragmented results are learned with variable-length sequences,
    # continuing results with same-length ones.
    for partial_obs_type, accuracies, is_same_learn in zip(
        ["fragmented", "continuing"],
        [fragmented_accuracies, continuing_accuracies],
        [False, True],
    ):
        load_accuracies(
            accuracies, domains, TASKS, PERCENTAGES, partial_obs_type, is_same_learn
        )

    # For the per-domain accuracy-vs-goals figure, call
    # plot_all_domains(fragmented_accuracies, continuing_accuracies, domains).
    plot_stick_figures(
        continuing_accuracies, fragmented_accuracies, "GC-GRAML compared with BG-GRAML"
    )