gr-libs 0.1.8__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gr_libs/__init__.py +3 -1
- gr_libs/_version.py +2 -2
- gr_libs/all_experiments.py +260 -0
- gr_libs/environment/__init__.py +14 -1
- gr_libs/environment/_utils/__init__.py +0 -0
- gr_libs/environment/{utils → _utils}/utils.py +1 -1
- gr_libs/environment/environment.py +278 -23
- gr_libs/evaluation/__init__.py +1 -0
- gr_libs/evaluation/generate_experiments_results.py +100 -0
- gr_libs/metrics/__init__.py +2 -0
- gr_libs/metrics/metrics.py +166 -31
- gr_libs/ml/__init__.py +1 -6
- gr_libs/ml/base/__init__.py +3 -1
- gr_libs/ml/base/rl_agent.py +68 -3
- gr_libs/ml/neural/__init__.py +1 -3
- gr_libs/ml/neural/deep_rl_learner.py +241 -84
- gr_libs/ml/neural/utils/__init__.py +1 -2
- gr_libs/ml/planner/mcts/{utils → _utils}/tree.py +1 -1
- gr_libs/ml/planner/mcts/mcts_model.py +71 -34
- gr_libs/ml/sequential/__init__.py +0 -1
- gr_libs/ml/sequential/{lstm_model.py → _lstm_model.py} +11 -14
- gr_libs/ml/tabular/__init__.py +1 -3
- gr_libs/ml/tabular/tabular_q_learner.py +27 -9
- gr_libs/ml/tabular/tabular_rl_agent.py +22 -9
- gr_libs/ml/utils/__init__.py +2 -9
- gr_libs/ml/utils/format.py +13 -90
- gr_libs/ml/utils/math.py +3 -2
- gr_libs/ml/utils/other.py +2 -2
- gr_libs/ml/utils/storage.py +41 -94
- gr_libs/odgr_executor.py +263 -0
- gr_libs/problems/consts.py +570 -292
- gr_libs/recognizer/{utils → _utils}/format.py +2 -2
- gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +127 -36
- gr_libs/recognizer/graml/{gr_dataset.py → _gr_dataset.py} +11 -11
- gr_libs/recognizer/graml/graml_recognizer.py +186 -35
- gr_libs/recognizer/recognizer.py +59 -10
- gr_libs/tutorials/draco_panda_tutorial.py +58 -0
- gr_libs/tutorials/draco_parking_tutorial.py +56 -0
- {tutorials → gr_libs/tutorials}/gcdraco_panda_tutorial.py +11 -11
- {tutorials → gr_libs/tutorials}/gcdraco_parking_tutorial.py +6 -8
- {tutorials → gr_libs/tutorials}/graml_minigrid_tutorial.py +18 -14
- {tutorials → gr_libs/tutorials}/graml_panda_tutorial.py +11 -12
- {tutorials → gr_libs/tutorials}/graml_parking_tutorial.py +8 -10
- {tutorials → gr_libs/tutorials}/graml_point_maze_tutorial.py +17 -3
- {tutorials → gr_libs/tutorials}/graql_minigrid_tutorial.py +2 -2
- {gr_libs-0.1.8.dist-info → gr_libs-0.2.5.dist-info}/METADATA +95 -29
- gr_libs-0.2.5.dist-info/RECORD +72 -0
- {gr_libs-0.1.8.dist-info → gr_libs-0.2.5.dist-info}/WHEEL +1 -1
- gr_libs-0.2.5.dist-info/top_level.txt +2 -0
- tests/test_draco.py +14 -0
- tests/test_gcdraco.py +2 -2
- tests/test_graml.py +4 -4
- tests/test_graql.py +1 -1
- tests/test_odgr_executor_expertbasedgraml.py +14 -0
- tests/test_odgr_executor_gcdraco.py +14 -0
- tests/test_odgr_executor_gcgraml.py +14 -0
- tests/test_odgr_executor_graql.py +14 -0
- evaluation/analyze_results_cross_alg_cross_domain.py +0 -267
- evaluation/create_minigrid_map_image.py +0 -38
- evaluation/file_system.py +0 -53
- evaluation/generate_experiments_results.py +0 -141
- evaluation/generate_experiments_results_new_ver1.py +0 -238
- evaluation/generate_experiments_results_new_ver2.py +0 -331
- evaluation/generate_task_specific_statistics_plots.py +0 -500
- evaluation/get_plans_images.py +0 -62
- evaluation/increasing_and_decreasing_.py +0 -104
- gr_libs/ml/neural/utils/penv.py +0 -60
- gr_libs-0.1.8.dist-info/RECORD +0 -70
- gr_libs-0.1.8.dist-info/top_level.txt +0 -4
- /gr_libs/{environment/utils/__init__.py → _evaluation/_generate_experiments_results.py} +0 -0
- /gr_libs/ml/planner/mcts/{utils → _utils}/__init__.py +0 -0
- /gr_libs/ml/planner/mcts/{utils → _utils}/node.py +0 -0
- /gr_libs/recognizer/{utils → _utils}/__init__.py +0 -0
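Two structural changes run through this listing: helper packages are renamed with a leading underscore to mark them private (`utils → _utils`, `lstm_model.py → _lstm_model.py`), and the tutorials move from a top-level `tutorials/` directory into the installed package. A sketch of the import-path consequence, under the assumption that the relocated tutorial files are importable as plain modules (the diff only shows the file moves, not the package's `__init__` exports):

```python
# Hypothetical import, inferred from the {tutorials -> gr_libs/tutorials} moves
# listed above; in 0.1.8 these scripts lived outside the package, so they were
# not importable from an installed wheel at all.
from gr_libs.tutorials import graml_minigrid_tutorial  # assumption: plain module import
```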
```diff
--- a/gr_libs/environment/environment.py
+++ b/gr_libs/environment/environment.py
@@ -1,14 +1,18 @@
+""" environment.py """
+
+import os
+import sys
 from abc import abstractmethod
 from collections import namedtuple
-import
+from contextlib import contextmanager
 
 import gymnasium as gym
-from stable_baselines3.common.vec_env import DummyVecEnv
-from PIL import Image
 import numpy as np
 from gymnasium.envs.registration import register
-from minigrid.core.world_object import
-from minigrid.wrappers import
+from minigrid.core.world_object import Lava, Wall
+from minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
+from PIL import Image
+from stable_baselines3.common.vec_env import DummyVecEnv
 
 MINIGRID, PANDA, PARKING, POINT_MAZE = "minigrid", "panda", "parking", "point_maze"
 
```
```diff
@@ -21,112 +25,226 @@ LSTMProperties = namedtuple(
 )
 
 
+@contextmanager
+def suppress_output():
+    """
+    Context manager to suppress stdout and stderr (including C/C++ prints).
+    """
+    with open(os.devnull, "w") as devnull:
+        old_stdout = sys.stdout
+        old_stderr = sys.stderr
+        sys.stdout = devnull
+        sys.stderr = devnull
+        try:
+            yield
+        finally:
+            sys.stdout = old_stdout
+            sys.stderr = old_stderr
+
+
 class EnvProperty:
+    """
+    Base class for environment properties.
+    """
+
     def __init__(self, name):
+        """
+        Initializes a new instance of the Environment class.
+
+        Args:
+            name (str): The name of the environment.
+        """
         self.name = name
 
     def __str__(self):
+        """
+        Returns a string representation of the object.
+        """
         return f"{self.name}"
 
     def __repr__(self):
+        """
+        Returns a string representation of the object.
+        """
         return f"{self.name}"
 
     def __eq__(self, other):
+        """
+        Check if this object is equal to another object.
+
+        Args:
+            other: The other object to compare with.
+
+        Returns:
+            True if the objects are equal, False otherwise.
+        """
         return self.name == other.name
 
     def __ne__(self, other):
+        """
+        Check if the current object is not equal to the other object.
+
+        Args:
+            other: The object to compare with.
+
+        Returns:
+            bool: True if the objects are not equal, False otherwise.
+        """
         return not self.__eq__(other)
 
     @abstractmethod
     def str_to_goal(self):
-
+        """
+        Convert a problem name to a goal.
+        """
 
     @abstractmethod
     def gc_adaptable(self):
-
+        """
+        Check if the environment is goal-conditioned adaptable.
+        """
 
     @abstractmethod
     def problem_list_to_str_tuple(self, problems):
-
+        """
+        Convert a list of problems to a string tuple.
+        """
 
     @abstractmethod
     def goal_to_problem_str(self, goal):
-
+        """
+        Convert a goal to a problem string.
+        """
 
     @abstractmethod
     def is_action_discrete(self):
-
+        """
+        Check if the action space is discrete.
+        """
 
     @abstractmethod
     def is_state_discrete(self):
-
+        """
+        Check if the state space is discrete.
+        """
 
     @abstractmethod
     def get_lstm_props(self):
-
+        """
+        Get the LSTM properties for the environment.
+        """
 
     @abstractmethod
     def change_done_by_specific_desired(self, obs, desired, old_success_done):
-
+        """
+        Change the 'done' flag based on a specific desired goal.
+        """
 
     @abstractmethod
     def is_done(self, done):
-
+        """
+        Check if the episode is done.
+        """
 
     @abstractmethod
     def is_success(self, info):
-
+        """
+        Check if the episode is successful.
+        """
 
     def create_vec_env(self, kwargs):
-
+        """
+        Create a vectorized environment, suppressing prints from gym/pybullet/panda-gym.
+        """
+        with suppress_output():
+            env = gym.make(**kwargs)
         return DummyVecEnv([lambda: env])
 
     @abstractmethod
     def change_goal_to_specific_desired(self, obs, desired):
-
+        """
+        Change the goal to a specific desired goal.
+        """
 
 
 class GCEnvProperty(EnvProperty):
+    """
+    Base class for goal-conditioned environment properties.
+    """
+
     @abstractmethod
     def use_goal_directed_problem(self):
-
+        """
+        Check if the environment uses a goal-directed problem.
+        """
 
     def problem_list_to_str_tuple(self, problems):
+        """
+        Convert a list of problems to a string tuple.
+        """
         return "goal_conditioned"
 
 
 class MinigridProperty(EnvProperty):
+    """
+    Environment properties for the Minigrid domain.
+    """
+
     def __init__(self, name):
         super().__init__(name)
         self.domain_name = "minigrid"
 
     def goal_to_problem_str(self, goal):
+        """
+        Convert a goal to a problem string.
+        """
         return self.name + f"-DynamicGoal-{goal[0]}x{goal[1]}-v0"
 
     def str_to_goal(self, problem_name):
+        """
+        Convert a problem name to a goal.
+        """
         parts = problem_name.split("-")
         goal_part = [part for part in parts if "x" in part]
         width, height = goal_part[0].split("x")
         return (int(width), int(height))
 
     def gc_adaptable(self):
+        """
+        Check if the environment is goal-conditioned adaptable.
+        """
         return False
 
     def problem_list_to_str_tuple(self, problems):
+        """
+        Convert a list of problems to a string tuple.
+        """
         return "_".join([f"[{s.split('-')[-2]}]" for s in problems])
 
     def is_action_discrete(self):
+        """
+        Check if the action space is discrete.
+        """
         return True
 
     def is_state_discrete(self):
+        """
+        Check if the state space is discrete.
+        """
         return True
 
     def get_lstm_props(self):
+        """
+        Get the LSTM properties for the environment.
+        """
         return LSTMProperties(
             batch_size=16, input_size=4, hidden_size=8, num_samples=40000
         )
 
     def create_sequence_image(self, sequence, img_path, problem_name):
+        """
+        Create a sequence image for the environment.
+        """
         if not os.path.exists(os.path.dirname(img_path)):
             os.makedirs(os.path.dirname(img_path))
         env_id = (
```
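An aside on the hunk above: `suppress_output` works by rebinding `sys.stdout` and `sys.stderr` to a handle on `os.devnull` and restoring them in a `finally` block, and `create_vec_env` now wraps `gym.make` in it to silence start-up chatter. A minimal usage sketch, assuming the helper is imported from the module this diff edits:

```python
from gr_libs.environment.environment import suppress_output  # assumed import path

with suppress_output():
    print("swallowed: sys.stdout points at os.devnull inside the block")
print("visible again: the original streams are restored on exit")
```

Note that rebinding the Python-level streams does not redirect writes made straight to the C-level file descriptors, so the docstring's claim about C/C++ prints holds only for output routed through `sys.stdout`/`sys.stderr`.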
```diff
@@ -134,7 +252,7 @@ class MinigridProperty(EnvProperty):
             + "-DynamicGoal-"
             + problem_name.split("-DynamicGoal-")[1]
         )
-
+        register(
             id=env_id,
             entry_point="gr_envs.minigrid_scripts.envs:CustomColorEnv",
             kwargs={
@@ -146,7 +264,6 @@ class MinigridProperty(EnvProperty):
                 "plan": sequence,
             },
         )
-        # print(result)
         env = gym.make(id=env_id)
         env = RGBImgPartialObsWrapper(env)  # Get pixel observations
         env = ImgObsWrapper(env)  # Get rid of the 'mission' field
@@ -156,34 +273,62 @@ class MinigridProperty(EnvProperty):
 
         ####### save image to file
         image_pil = Image.fromarray(np.uint8(img)).convert("RGB")
-        image_pil.save(r"{}.png".format(img_path))
+        image_pil.save(r"{}.png".format(os.path.join(img_path, "plan_image")))
 
     def change_done_by_specific_desired(self, obs, desired, old_success_done):
+        """
+        Change the 'done' flag based on a specific desired goal.
+        """
         assert (
             desired is None
         ), "In MinigridProperty, giving a specific 'desired' is not supported."
         return old_success_done
 
     def is_done(self, done):
+        """
+        Check if the episode is done.
+        """
         assert isinstance(done, np.ndarray)
         return done[0]
 
-    # Not used currently since TabularQLearner doesn't need is_success from the environment
     def is_success(self, info):
+        """
+        Check if the episode is successful.
+        """
         raise NotImplementedError("no other option for any of the environments.")
 
     def change_goal_to_specific_desired(self, obs, desired):
+        """
+        Change the goal to a specific desired goal.
+        """
         assert (
             desired is None
         ), "In MinigridProperty, giving a specific 'desired' is not supported."
 
 
 class PandaProperty(GCEnvProperty):
+    """
+    Environment properties for the Panda domain.
+    """
+
     def __init__(self, name):
+        """
+        Initialize a new instance of the Environment class.
+
+        Args:
+            name (str): The name of the environment.
+
+        Attributes:
+            domain_name (str): The domain name of the environment.
+
+        """
         super().__init__(name)
         self.domain_name = "panda"
 
     def str_to_goal(self, problem_name):
+        """
+        Convert a problem name to a goal.
+        """
         try:
             numeric_part = problem_name.split("PandaMyReachDenseX")[1]
             components = [
@@ -194,38 +339,62 @@ class PandaProperty(GCEnvProperty):
             for component in components:
                 floats.append(float(component))
             return np.array([floats], dtype=np.float32)
-        except Exception
+        except Exception:
             return "general"
 
     def goal_to_problem_str(self, goal):
+        """
+        Convert a goal to a problem string.
+        """
         goal_str = "X".join(
             [str(float(g)).replace(".", "y").replace("-", "M") for g in goal[0]]
         )
         return f"PandaMyReachDenseX{goal_str}-v3"
 
     def gc_adaptable(self):
+        """
+        Check if the environment is goal-conditioned adaptable.
+        """
         return True
 
     def use_goal_directed_problem(self):
+        """
+        Check if the environment uses a goal-directed problem.
+        """
         return False
 
     def is_action_discrete(self):
+        """
+        Check if the action space is discrete.
+        """
         return False
 
     def is_state_discrete(self):
+        """
+        Check if the state space is discrete.
+        """
         return False
 
     def get_lstm_props(self):
+        """
+        Get the LSTM properties for the environment.
+        """
         return LSTMProperties(
             batch_size=32, input_size=9, hidden_size=8, num_samples=20000
         )
 
     def sample_goal():
+        """
+        Sample a random goal.
+        """
         goal_range_low = np.array([-0.40, -0.40, 0.10])
         goal_range_high = np.array([0.2, 0.2, 0.10])
         return np.random.uniform(goal_range_low, goal_range_high)
 
     def change_done_by_specific_desired(self, obs, desired, old_success_done):
+        """
+        Change the 'done' flag based on a specific desired goal.
+        """
         if desired is None:
             return old_success_done
         assert isinstance(
@@ -241,70 +410,134 @@ class PandaProperty(GCEnvProperty):
         return old_success_done
 
     def is_done(self, done):
+        """
+        Check if the episode is done.
+        """
         assert isinstance(done, np.ndarray)
         return done[0]
 
     def is_success(self, info):
+        """
+        Check if the episode is successful.
+        """
         assert "is_success" in info[0].keys()
         return info[0]["is_success"]
 
     def change_goal_to_specific_desired(self, obs, desired):
+        """
+        Change the goal to a specific desired goal.
+        """
         if desired is not None:
             obs["desired_goal"] = desired
 
 
 class ParkingProperty(GCEnvProperty):
+    """
+    Environment properties for the Parking domain.
+    """
 
     def __init__(self, name):
+        """
+        Initialize a new environment object.
+
+        Args:
+            name (str): The name of the environment.
+
+        Attributes:
+            domain_name (str): The domain name of the environment.
+
+        """
         super().__init__(name)
         self.domain_name = "parking"
 
     def goal_to_problem_str(self, goal):
+        """
+        Convert a goal to a problem string.
+        """
         return self.name.split("-v0")[0] + f"-GI-{goal}-v0"
 
     def gc_adaptable(self):
+        """
+        Check if the environment is goal-conditioned adaptable.
+        """
         return True
 
     def is_action_discrete(self):
+        """
+        Check if the action space is discrete.
+        """
         return False
 
     def is_state_discrete(self):
+        """
+        Check if the state space is discrete.
+        """
         return False
 
     def use_goal_directed_problem(self):
+        """
+        Check if the environment uses a goal-directed problem.
+        """
         return True
 
     def get_lstm_props(self):
+        """
+        Get the LSTM properties for the environment.
+        """
         return LSTMProperties(
             batch_size=32, input_size=8, hidden_size=8, num_samples=20000
         )
 
     def change_done_by_specific_desired(self, obs, desired, old_success_done):
+        """
+        Change the 'done' flag based on a specific desired goal.
+        """
         assert (
             desired is None
         ), "In ParkingProperty, giving a specific 'desired' is not supported."
         return old_success_done
 
     def is_done(self, done):
+        """
+        Check if the episode is done.
+        """
         assert isinstance(done, np.ndarray)
         return done[0]
 
     def is_success(self, info):
+        """
+        Check if the episode is successful.
+        """
         assert "is_success" in info[0].keys()
         return info[0]["is_success"]
 
     def change_goal_to_specific_desired(self, obs, desired):
+        """
+        Change the goal to a specific desired goal.
+        """
         assert (
             desired is None
         ), "In ParkingProperty, giving a specific 'desired' is not supported."
 
 
 class PointMazeProperty(EnvProperty):
+    """Environment properties for the Point Maze domain."""
+
     def __init__(self, name):
+        """
+        Initializes a new instance of the Environment class.
+
+        Args:
+            name (str): The name of the environment.
+
+        Attributes:
+            domain_name (str): The domain name of the environment.
+        """
         super().__init__(name)
         self.domain_name = "point_maze"
 
     def str_to_goal(self):
+        """Convert a problem name to a goal."""
         parts = self.name.split("-")
         # Find the part containing the goal size (usually after "DynamicGoal")
         sizes_parts = [part for part in parts if "x" in part]
@@ -314,40 +547,62 @@ class PointMazeProperty(EnvProperty):
         return (int(width), int(height))
 
     def gc_adaptable(self):
+        """Check if the environment is goal-conditioned adaptable."""
         return False
 
     def problem_list_to_str_tuple(self, problems):
+        """Convert a list of problems to a string tuple."""
        return "_".join([f"[{s.split('-')[-1]}]" for s in problems])
 
     def is_action_discrete(self):
+        """Check if the action space is discrete."""
         return False
 
     def is_state_discrete(self):
+        """Check if the state space is discrete."""
         return False
 
     def get_lstm_props(self):
+        """
+        Get the LSTM properties for the environment.
+        """
         return LSTMProperties(
             batch_size=32, input_size=6, hidden_size=8, num_samples=20000
         )
 
     def goal_to_problem_str(self, goal):
+        """
+        Convert a goal to a problem string.
+        """
         return self.name + f"-Goal-{goal[0]}x{goal[1]}"
 
     def change_done_by_specific_desired(self, obs, desired, old_success_done):
+        """
+        Change the 'done' flag based on a specific desired goal.
+        """
         assert (
             desired is None
         ), "In PointMazeProperty, giving a specific 'desired' is not supported."
         return old_success_done
 
     def is_done(self, done):
+        """
+        Check if the episode is done.
+        """
         assert isinstance(done, np.ndarray)
         return done[0]
 
     def is_success(self, info):
+        """
+        Check if the episode is successful.
+        """
         assert "success" in info[0].keys()
         return info[0]["success"]
 
     def change_goal_to_specific_desired(self, obs, desired):
+        """
+        Change the goal to a specific desired goal.
+        """
         assert (
             desired is None
         ), "In ParkingProperty, giving a specific 'desired' is not supported."
```
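The one non-obvious convention in this file's diff is the Panda goal encoding: `goal_to_problem_str` flattens a goal array into a gym environment id, writing `.` as `y` and `-` as `M`. A round-trip sketch of those rules follows; the encode mirrors the diff above, while the decode is illustrative, since the body of `str_to_goal` is cut off at a hunk boundary:

```python
# Encode: mirrors PandaProperty.goal_to_problem_str from the diff above.
goal = [[-0.4, 0.2, 0.1]]
goal_str = "X".join(str(float(g)).replace(".", "y").replace("-", "M") for g in goal[0])
problem = f"PandaMyReachDenseX{goal_str}-v3"
assert problem == "PandaMyReachDenseXM0y4X0y2X0y1-v3"

# Decode: an assumed inverse of the encoding; the shipped str_to_goal is only
# partially visible above, so treat this as a sketch rather than library code.
numeric = problem.split("PandaMyReachDenseX")[1].removesuffix("-v3")
decoded = [float(c.replace("y", ".").replace("M", "-")) for c in numeric.split("X")]
assert decoded == goal[0]
```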
```diff
--- /dev/null
+++ b/gr_libs/evaluation/__init__.py
@@ -0,0 +1 @@
+""" This is a directory that includes scripts for analysis of GR results. """
```
```diff
--- /dev/null
+++ b/gr_libs/evaluation/generate_experiments_results.py
@@ -0,0 +1,100 @@
+import argparse
+import os
+
+import dill
+import matplotlib.pyplot as plt
+import numpy as np
+
+from gr_libs.ml.utils.storage import get_experiment_results_path
+
+
+def load_results(domain, env, task, recognizer, n_runs, percentage, cons_type):
+    # Collect accuracy for a single task and recognizer
+    accs = []
+    res_dir = get_experiment_results_path(domain, env, task, recognizer)
+    if not os.path.exists(res_dir):
+        return accs
+    for i in range(n_runs):
+        res_file = os.path.join(res_dir, f"res_{i}.pkl")
+        if not os.path.exists(res_file):
+            continue
+        with open(res_file, "rb") as f:
+            results = dill.load(f)
+        if percentage in results and cons_type in results[percentage]:
+            acc = results[percentage][cons_type].get("accuracy")
+            if acc is not None:
+                accs.append(acc)
+    return accs
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--domain", required=True)
+    parser.add_argument("--env", required=True)
+    parser.add_argument("--tasks", nargs="+", required=True)
+    parser.add_argument("--recognizers", nargs="+", required=True)
+    parser.add_argument("--n_runs", type=int, default=5)
+    parser.add_argument("--percentage", required=True)
+    parser.add_argument(
+        "--cons_type", choices=["consecutive", "non_consecutive"], required=True
+    )
+    parser.add_argument("--graph_name", type=str, default="experiment_results")
+    args = parser.parse_args()
+
+    plt.figure(figsize=(7, 5))
+    has_data = False
+    missing_recognizers = []
+
+    for recognizer in args.recognizers:
+        x_vals = []
+        y_means = []
+        y_sems = []
+        for task in args.tasks:
+            accs = load_results(
+                args.domain,
+                args.env,
+                task,
+                recognizer,
+                args.n_runs,
+                args.percentage,
+                args.cons_type,
+            )
+            if accs:
+                x_vals.append(task)
+                y_means.append(np.mean(accs))
+                y_sems.append(np.std(accs) / np.sqrt(len(accs)))
+        if x_vals:
+            has_data = True
+            x_ticks = np.arange(len(x_vals))
+            plt.plot(x_ticks, y_means, marker="o", label=recognizer)
+            plt.fill_between(
+                x_ticks,
+                np.array(y_means) - np.array(y_sems),
+                np.array(y_means) + np.array(y_sems),
+                alpha=0.2,
+            )
+            plt.xticks(x_ticks, x_vals)
+        else:
+            print(
+                f"Warning: No data found for recognizer '{recognizer}' in {args.domain} / {args.env} / {args.percentage} / {args.cons_type}"
+            )
+            missing_recognizers.append(recognizer)
+
+    if not has_data:
+        raise RuntimeError(
+            f"No data found for any recognizer in {args.domain} / {args.env} / {args.percentage} / {args.cons_type}. "
+            f"Missing recognizers: {', '.join(missing_recognizers)}"
+        )
+
+    plt.xlabel("Task")
+    plt.ylabel("Accuracy")
+    plt.title(f"{args.domain} - {args.env} ({args.percentage}, {args.cons_type})")
+    plt.legend()
+    plt.grid(True)
+    fig_path = f"{args.graph_name}_{'_'.join(args.recognizers)}_{args.domain}_{args.env}_{args.percentage}_{args.cons_type}.png"
+    plt.savefig(fig_path)
+    print(f"Figure saved at: {fig_path}")
+
+
+if __name__ == "__main__":
+    main()
```
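Because the script ends with a `__main__` guard, it can be invoked as a module. A hypothetical invocation follows; the flag names come from the argparse setup above, but every value below is a placeholder rather than something this diff ships:

```python
# Hypothetical CLI call for the new evaluation script; values are placeholders.
import subprocess

subprocess.run(
    [
        "python", "-m", "gr_libs.evaluation.generate_experiments_results",
        "--domain", "minigrid",                        # placeholder domain
        "--env", "MiniGrid-SimpleCrossingS13N4",       # placeholder env name
        "--tasks", "L1", "L2", "L3",                   # placeholder task ids
        "--recognizers", "Graql", "ExpertBasedGraml",  # placeholder recognizers
        "--percentage", "0.5",                         # placeholder results key
        "--cons_type", "consecutive",
    ],
    check=True,
)
```

Each recognizer's curve is the mean accuracy over up to `n_runs` pickled result files (`res_0.pkl` through `res_{n_runs-1}.pkl`), and the shaded band is the standard error of that mean.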