gr-libs 0.1.7.post0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. gr_libs/__init__.py +4 -1
  2. gr_libs/_evaluation/__init__.py +1 -0
  3. gr_libs/_evaluation/_analyze_results_cross_alg_cross_domain.py +260 -0
  4. gr_libs/_evaluation/_generate_experiments_results.py +141 -0
  5. gr_libs/_evaluation/_generate_task_specific_statistics_plots.py +497 -0
  6. gr_libs/_evaluation/_get_plans_images.py +61 -0
  7. gr_libs/_evaluation/_increasing_and_decreasing_.py +106 -0
  8. gr_libs/_version.py +2 -2
  9. gr_libs/all_experiments.py +294 -0
  10. gr_libs/environment/__init__.py +30 -9
  11. gr_libs/environment/_utils/utils.py +27 -0
  12. gr_libs/environment/environment.py +417 -54
  13. gr_libs/metrics/__init__.py +7 -0
  14. gr_libs/metrics/metrics.py +231 -54
  15. gr_libs/ml/__init__.py +2 -5
  16. gr_libs/ml/agent.py +21 -6
  17. gr_libs/ml/base/__init__.py +3 -1
  18. gr_libs/ml/base/rl_agent.py +81 -13
  19. gr_libs/ml/consts.py +1 -1
  20. gr_libs/ml/neural/__init__.py +1 -3
  21. gr_libs/ml/neural/deep_rl_learner.py +619 -378
  22. gr_libs/ml/neural/utils/__init__.py +1 -2
  23. gr_libs/ml/neural/utils/dictlist.py +3 -3
  24. gr_libs/ml/planner/mcts/{utils → _utils}/__init__.py +1 -1
  25. gr_libs/ml/planner/mcts/{utils → _utils}/node.py +11 -7
  26. gr_libs/ml/planner/mcts/{utils → _utils}/tree.py +15 -11
  27. gr_libs/ml/planner/mcts/mcts_model.py +571 -312
  28. gr_libs/ml/sequential/__init__.py +0 -1
  29. gr_libs/ml/sequential/_lstm_model.py +270 -0
  30. gr_libs/ml/tabular/__init__.py +1 -3
  31. gr_libs/ml/tabular/state.py +7 -7
  32. gr_libs/ml/tabular/tabular_q_learner.py +150 -82
  33. gr_libs/ml/tabular/tabular_rl_agent.py +42 -28
  34. gr_libs/ml/utils/__init__.py +2 -3
  35. gr_libs/ml/utils/format.py +28 -97
  36. gr_libs/ml/utils/math.py +5 -3
  37. gr_libs/ml/utils/other.py +3 -3
  38. gr_libs/ml/utils/storage.py +88 -81
  39. gr_libs/odgr_executor.py +268 -0
  40. gr_libs/problems/consts.py +1549 -1227
  41. gr_libs/recognizer/_utils/__init__.py +0 -0
  42. gr_libs/recognizer/_utils/format.py +18 -0
  43. gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +233 -88
  44. gr_libs/recognizer/graml/_gr_dataset.py +233 -0
  45. gr_libs/recognizer/graml/graml_recognizer.py +586 -252
  46. gr_libs/recognizer/recognizer.py +90 -30
  47. gr_libs/tutorials/draco_panda_tutorial.py +58 -0
  48. gr_libs/tutorials/draco_parking_tutorial.py +56 -0
  49. gr_libs/tutorials/gcdraco_panda_tutorial.py +62 -0
  50. gr_libs/tutorials/gcdraco_parking_tutorial.py +57 -0
  51. gr_libs/tutorials/graml_minigrid_tutorial.py +64 -0
  52. gr_libs/tutorials/graml_panda_tutorial.py +57 -0
  53. gr_libs/tutorials/graml_parking_tutorial.py +52 -0
  54. gr_libs/tutorials/graml_point_maze_tutorial.py +60 -0
  55. gr_libs/tutorials/graql_minigrid_tutorial.py +50 -0
  56. {gr_libs-0.1.7.post0.dist-info → gr_libs-0.2.2.dist-info}/METADATA +84 -29
  57. gr_libs-0.2.2.dist-info/RECORD +71 -0
  58. {gr_libs-0.1.7.post0.dist-info → gr_libs-0.2.2.dist-info}/WHEEL +1 -1
  59. gr_libs-0.2.2.dist-info/top_level.txt +2 -0
  60. tests/test_draco.py +14 -0
  61. tests/test_gcdraco.py +10 -0
  62. tests/test_graml.py +12 -8
  63. tests/test_graql.py +3 -2
  64. evaluation/analyze_results_cross_alg_cross_domain.py +0 -277
  65. evaluation/create_minigrid_map_image.py +0 -34
  66. evaluation/file_system.py +0 -42
  67. evaluation/generate_experiments_results.py +0 -92
  68. evaluation/generate_experiments_results_new_ver1.py +0 -254
  69. evaluation/generate_experiments_results_new_ver2.py +0 -331
  70. evaluation/generate_task_specific_statistics_plots.py +0 -272
  71. evaluation/get_plans_images.py +0 -47
  72. evaluation/increasing_and_decreasing_.py +0 -63
  73. gr_libs/environment/utils/utils.py +0 -17
  74. gr_libs/ml/neural/utils/penv.py +0 -57
  75. gr_libs/ml/sequential/lstm_model.py +0 -192
  76. gr_libs/recognizer/graml/gr_dataset.py +0 -134
  77. gr_libs/recognizer/utils/__init__.py +0 -1
  78. gr_libs/recognizer/utils/format.py +0 -13
  79. gr_libs-0.1.7.post0.dist-info/RECORD +0 -67
  80. gr_libs-0.1.7.post0.dist-info/top_level.txt +0 -4
  81. tutorials/graml_minigrid_tutorial.py +0 -34
  82. tutorials/graml_panda_tutorial.py +0 -41
  83. tutorials/graml_parking_tutorial.py +0 -39
  84. tutorials/graml_point_maze_tutorial.py +0 -39
  85. tutorials/graql_minigrid_tutorial.py +0 -34
  86. /gr_libs/environment/{utils → _utils}/__init__.py +0 -0
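Most of the churn in this list is relocation rather than new behavior: the top-level evaluation/ and tutorials/ trees move inside the package (gr_libs/_evaluation/, gr_libs/tutorials/), formerly public helper packages become private (utils → _utils), and new entry points appear (gr_libs/odgr_executor.py, gr_libs/all_experiments.py). A minimal sketch of what the moves imply for imports in 0.2.2, assuming the relocated modules keep their module-level names:

# Old layout (0.1.7.post0): tutorials were a top-level package in the wheel,
# so scripts presumably did:
#     from tutorials import graml_minigrid_tutorial
# New layout (0.2.2): tutorials ship inside gr_libs:
from gr_libs.tutorials import graml_minigrid_tutorial

# Renamed helper packages are now private implementation details, e.g.
# gr_libs.recognizer.utils -> gr_libs.recognizer._utils, and should no
# longer be imported by user code.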
gr_libs/environment/environment.py
@@ -1,13 +1,16 @@
+ """ environment.py """
+
+ import os
  from abc import abstractmethod
  from collections import namedtuple
- import os
 
- import gymnasium
- from PIL import Image
+ import gymnasium as gym
  import numpy as np
  from gymnasium.envs.registration import register
- from minigrid.core.world_object import Wall, Lava
- from minigrid.wrappers import RGBImgPartialObsWrapper, ImgObsWrapper
+ from minigrid.core.world_object import Lava, Wall
+ from minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
+ from PIL import Image
+ from stable_baselines3.common.vec_env import DummyVecEnv
 
  MINIGRID, PANDA, PARKING, POINT_MAZE = "minigrid", "panda", "parking", "point_maze"
 
@@ -15,189 +18,506 @@ QLEARNING = "QLEARNING"
 
  SUPPORTED_DOMAINS = [MINIGRID, PANDA, PARKING, POINT_MAZE]
 
- LSTMProperties = namedtuple('LSTMProperties', ['input_size', 'hidden_size', 'batch_size', 'num_samples'])
-
+ LSTMProperties = namedtuple(
+     "LSTMProperties", ["input_size", "hidden_size", "batch_size", "num_samples"]
+ )
 
 
  class EnvProperty:
+     """
+     Base class for environment properties.
+     """
+
      def __init__(self, name):
+         """
+         Initializes a new instance of the Environment class.
+
+         Args:
+             name (str): The name of the environment.
+         """
          self.name = name
 
      def __str__(self):
+         """
+         Returns a string representation of the object.
+         """
          return f"{self.name}"
 
      def __repr__(self):
+         """
+         Returns a string representation of the object.
+         """
          return f"{self.name}"
 
      def __eq__(self, other):
+         """
+         Check if this object is equal to another object.
+
+         Args:
+             other: The other object to compare with.
+
+         Returns:
+             True if the objects are equal, False otherwise.
+         """
          return self.name == other.name
 
      def __ne__(self, other):
+         """
+         Check if the current object is not equal to the other object.
+
+         Args:
+             other: The object to compare with.
+
+         Returns:
+             bool: True if the objects are not equal, False otherwise.
+         """
          return not self.__eq__(other)
-
+
      @abstractmethod
      def str_to_goal(self):
-         pass
+         """
+         Convert a problem name to a goal.
+         """
 
      @abstractmethod
      def gc_adaptable(self):
-         pass
+         """
+         Check if the environment is goal-conditioned adaptable.
+         """
 
      @abstractmethod
      def problem_list_to_str_tuple(self, problems):
-         pass
+         """
+         Convert a list of problems to a string tuple.
+         """
 
      @abstractmethod
      def goal_to_problem_str(self, goal):
-         pass
+         """
+         Convert a goal to a problem string.
+         """
 
      @abstractmethod
      def is_action_discrete(self):
-         pass
+         """
+         Check if the action space is discrete.
+         """
 
      @abstractmethod
      def is_state_discrete(self):
-         pass
+         """
+         Check if the state space is discrete.
+         """
 
      @abstractmethod
      def get_lstm_props(self):
-         pass
+         """
+         Get the LSTM properties for the environment.
+         """
+
+     @abstractmethod
+     def change_done_by_specific_desired(self, obs, desired, old_success_done):
+         """
+         Change the 'done' flag based on a specific desired goal.
+         """
+
+     @abstractmethod
+     def is_done(self, done):
+         """
+         Check if the episode is done.
+         """
+
+     @abstractmethod
+     def is_success(self, info):
+         """
+         Check if the episode is successful.
+         """
+
+     def create_vec_env(self, kwargs):
+         """
+         Create a vectorized environment.
+         """
+         env = gym.make(**kwargs)
+         return DummyVecEnv([lambda: env])
+
+     @abstractmethod
+     def change_goal_to_specific_desired(self, obs, desired):
+         """
+         Change the goal to a specific desired goal.
+         """
+
 
  class GCEnvProperty(EnvProperty):
+     """
+     Base class for goal-conditioned environment properties.
+     """
+
      @abstractmethod
      def use_goal_directed_problem(self):
-         pass
+         """
+         Check if the environment uses a goal-directed problem.
+         """
 
      def problem_list_to_str_tuple(self, problems):
+         """
+         Convert a list of problems to a string tuple.
+         """
          return "goal_conditioned"
 
+
  class MinigridProperty(EnvProperty):
+     """
+     Environment properties for the Minigrid domain.
+     """
+
      def __init__(self, name):
          super().__init__(name)
          self.domain_name = "minigrid"
 
      def goal_to_problem_str(self, goal):
+         """
+         Convert a goal to a problem string.
+         """
          return self.name + f"-DynamicGoal-{goal[0]}x{goal[1]}-v0"
 
      def str_to_goal(self, problem_name):
+         """
+         Convert a problem name to a goal.
+         """
          parts = problem_name.split("-")
          goal_part = [part for part in parts if "x" in part]
          width, height = goal_part[0].split("x")
         return (int(width), int(height))
 
      def gc_adaptable(self):
+         """
+         Check if the environment is goal-conditioned adaptable.
+         """
          return False
-
+
      def problem_list_to_str_tuple(self, problems):
+         """
+         Convert a list of problems to a string tuple.
+         """
          return "_".join([f"[{s.split('-')[-2]}]" for s in problems])
-
+
      def is_action_discrete(self):
+         """
+         Check if the action space is discrete.
+         """
          return True
 
      def is_state_discrete(self):
+         """
+         Check if the state space is discrete.
+         """
          return True
 
      def get_lstm_props(self):
-         return LSTMProperties(batch_size=16, input_size=4, hidden_size=8, num_samples=40000)
-
+         """
+         Get the LSTM properties for the environment.
+         """
+         return LSTMProperties(
+             batch_size=16, input_size=4, hidden_size=8, num_samples=40000
+         )
+
      def create_sequence_image(self, sequence, img_path, problem_name):
-         if not os.path.exists(os.path.dirname(img_path)): os.makedirs(os.path.dirname(img_path))
-         env_id = problem_name.split("-DynamicGoal-")[0] + "-DynamicGoal-" + problem_name.split("-DynamicGoal-")[1]
-         result = register(
+         """
+         Create a sequence image for the environment.
+         """
+         if not os.path.exists(os.path.dirname(img_path)):
+             os.makedirs(os.path.dirname(img_path))
+         env_id = (
+             problem_name.split("-DynamicGoal-")[0]
+             + "-DynamicGoal-"
+             + problem_name.split("-DynamicGoal-")[1]
+         )
+         register(
              id=env_id,
              entry_point="gr_envs.minigrid_scripts.envs:CustomColorEnv",
-             kwargs={"size": 13 if 'Simple' in problem_name else 9,
-                     "num_crossings": 4 if 'Simple' in problem_name else 3,
-                     "goal_pos": self.str_to_goal(problem_name),
-                     "obstacle_type": Wall if 'Simple' in problem_name else Lava,
-                     "start_pos": (1, 1) if 'Simple' in problem_name else (3, 1),
-                     "plan": sequence},
+             kwargs={
+                 "size": 13 if "Simple" in problem_name else 9,
+                 "num_crossings": 4 if "Simple" in problem_name else 3,
+                 "goal_pos": self.str_to_goal(problem_name),
+                 "obstacle_type": Wall if "Simple" in problem_name else Lava,
+                 "start_pos": (1, 1) if "Simple" in problem_name else (3, 1),
+                 "plan": sequence,
+             },
          )
-         #print(result)
-         env = gymnasium.make(id=env_id)
-         env = RGBImgPartialObsWrapper(env) # Get pixel observations
-         env = ImgObsWrapper(env) # Get rid of the 'mission' field
-         obs, _ = env.reset() # This now produces an RGB tensor only
+         env = gym.make(id=env_id)
+         env = RGBImgPartialObsWrapper(env)  # Get pixel observations
+         env = ImgObsWrapper(env)  # Get rid of the 'mission' field
+         obs, _ = env.reset()  # This now produces an RGB tensor only
 
          img = env.unwrapped.get_frame()
 
          ####### save image to file
-         image_pil = Image.fromarray(np.uint8(img)).convert('RGB')
-         image_pil.save(r"{}.png".format(img_path))
+         image_pil = Image.fromarray(np.uint8(img)).convert("RGB")
+         image_pil.save(r"{}.png".format(os.path.join(img_path, "plan_image")))
+
+     def change_done_by_specific_desired(self, obs, desired, old_success_done):
+         """
+         Change the 'done' flag based on a specific desired goal.
+         """
+         assert (
+             desired is None
+         ), "In MinigridProperty, giving a specific 'desired' is not supported."
+         return old_success_done
+
+     def is_done(self, done):
+         """
+         Check if the episode is done.
+         """
+         assert isinstance(done, np.ndarray)
+         return done[0]
+
+     def is_success(self, info):
+         """
+         Check if the episode is successful.
+         """
+         raise NotImplementedError("no other option for any of the environments.")
+
+     def change_goal_to_specific_desired(self, obs, desired):
+         """
+         Change the goal to a specific desired goal.
+         """
+         assert (
+             desired is None
+         ), "In MinigridProperty, giving a specific 'desired' is not supported."
+
 
-
  class PandaProperty(GCEnvProperty):
+     """
+     Environment properties for the Panda domain.
+     """
+
      def __init__(self, name):
+         """
+         Initialize a new instance of the Environment class.
+
+         Args:
+             name (str): The name of the environment.
+
+         Attributes:
+             domain_name (str): The domain name of the environment.
+
+         """
          super().__init__(name)
          self.domain_name = "panda"
 
      def str_to_goal(self, problem_name):
+         """
+         Convert a problem name to a goal.
+         """
          try:
-             numeric_part = problem_name.split('PandaMyReachDenseX')[1]
-             components = [component.replace('-v3', '').replace('y', '.').replace('M', '-') for component in numeric_part.split('X')]
+             numeric_part = problem_name.split("PandaMyReachDenseX")[1]
+             components = [
+                 component.replace("-v3", "").replace("y", ".").replace("M", "-")
+                 for component in numeric_part.split("X")
+             ]
              floats = []
              for component in components:
                  floats.append(float(component))
              return np.array([floats], dtype=np.float32)
-         except Exception as e:
+         except Exception:
              return "general"
-
+
      def goal_to_problem_str(self, goal):
-         goal_str = 'X'.join([str(float(g)).replace(".", "y").replace("-","M") for g in goal[0]])
+         """
+         Convert a goal to a problem string.
+         """
+         goal_str = "X".join(
+             [str(float(g)).replace(".", "y").replace("-", "M") for g in goal[0]]
+         )
          return f"PandaMyReachDenseX{goal_str}-v3"
 
      def gc_adaptable(self):
+         """
+         Check if the environment is goal-conditioned adaptable.
+         """
          return True
-
+
      def use_goal_directed_problem(self):
+         """
+         Check if the environment uses a goal-directed problem.
+         """
          return False
-
+
      def is_action_discrete(self):
+         """
+         Check if the action space is discrete.
+         """
          return False
 
      def is_state_discrete(self):
+         """
+         Check if the state space is discrete.
+         """
          return False
 
      def get_lstm_props(self):
-         return LSTMProperties(batch_size=32, input_size=9, hidden_size=8, num_samples=20000)
-
+         """
+         Get the LSTM properties for the environment.
+         """
+         return LSTMProperties(
+             batch_size=32, input_size=9, hidden_size=8, num_samples=20000
+         )
+
      def sample_goal():
+         """
+         Sample a random goal.
+         """
          goal_range_low = np.array([-0.40, -0.40, 0.10])
          goal_range_high = np.array([0.2, 0.2, 0.10])
          return np.random.uniform(goal_range_low, goal_range_high)
 
-
+     def change_done_by_specific_desired(self, obs, desired, old_success_done):
+         """
+         Change the 'done' flag based on a specific desired goal.
+         """
+         if desired is None:
+             return old_success_done
+         assert isinstance(
+             desired, np.ndarray
+         ), f"Unsupported type for desired: {type(desired)}"
+         if desired.size > 0 and not np.isnan(desired).all():
+             assert (
+                 obs["achieved_goal"].shape == desired.shape
+             ), f"Shape mismatch: {obs['achieved_goal'].shape} vs {desired.shape}"
+             d = np.linalg.norm(obs["achieved_goal"] - desired, axis=-1)
+             return (d < 0.04)[0]
+         else:
+             return old_success_done
+
+     def is_done(self, done):
+         """
+         Check if the episode is done.
+         """
+         assert isinstance(done, np.ndarray)
+         return done[0]
+
+     def is_success(self, info):
+         """
+         Check if the episode is successful.
+         """
+         assert "is_success" in info[0].keys()
+         return info[0]["is_success"]
+
+     def change_goal_to_specific_desired(self, obs, desired):
+         """
+         Change the goal to a specific desired goal.
+         """
+         if desired is not None:
+             obs["desired_goal"] = desired
+
+
  class ParkingProperty(GCEnvProperty):
+     """
+     Environment properties for the Parking domain.
+     """
 
      def __init__(self, name):
+         """
+         Initialize a new environment object.
+
+         Args:
+             name (str): The name of the environment.
+
+         Attributes:
+             domain_name (str): The domain name of the environment.
+
+         """
          super().__init__(name)
          self.domain_name = "parking"
 
      def goal_to_problem_str(self, goal):
+         """
+         Convert a goal to a problem string.
+         """
          return self.name.split("-v0")[0] + f"-GI-{goal}-v0"
 
      def gc_adaptable(self):
+         """
+         Check if the environment is goal-conditioned adaptable.
+         """
          return True
-
+
      def is_action_discrete(self):
+         """
+         Check if the action space is discrete.
+         """
          return False
 
      def is_state_discrete(self):
+         """
+         Check if the state space is discrete.
+         """
          return False
-
+
      def use_goal_directed_problem(self):
+         """
+         Check if the environment uses a goal-directed problem.
+         """
          return True
-
+
      def get_lstm_props(self):
-         return LSTMProperties(batch_size=32, input_size=8, hidden_size=8, num_samples=20000)
+         """
+         Get the LSTM properties for the environment.
+         """
+         return LSTMProperties(
+             batch_size=32, input_size=8, hidden_size=8, num_samples=20000
+         )
+
+     def change_done_by_specific_desired(self, obs, desired, old_success_done):
+         """
+         Change the 'done' flag based on a specific desired goal.
+         """
+         assert (
+             desired is None
+         ), "In ParkingProperty, giving a specific 'desired' is not supported."
+         return old_success_done
+
+     def is_done(self, done):
+         """
+         Check if the episode is done.
+         """
+         assert isinstance(done, np.ndarray)
+         return done[0]
+
+     def is_success(self, info):
+         """
+         Check if the episode is successful.
+         """
+         assert "is_success" in info[0].keys()
+         return info[0]["is_success"]
+
+     def change_goal_to_specific_desired(self, obs, desired):
+         """
+         Change the goal to a specific desired goal.
+         """
+         assert (
+             desired is None
+         ), "In ParkingProperty, giving a specific 'desired' is not supported."
 
 
  class PointMazeProperty(EnvProperty):
+     """Environment properties for the Point Maze domain."""
+
      def __init__(self, name):
+         """
+         Initializes a new instance of the Environment class.
+
+         Args:
+             name (str): The name of the environment.
+
+         Attributes:
+             domain_name (str): The domain name of the environment.
+         """
          super().__init__(name)
          self.domain_name = "point_maze"
 
      def str_to_goal(self):
+         """Convert a problem name to a goal."""
          parts = self.name.split("-")
          # Find the part containing the goal size (usually after "DynamicGoal")
          sizes_parts = [part for part in parts if "x" in part]
@@ -205,21 +525,64 @@ class PointMazeProperty(EnvProperty):
          # Extract width and height from the goal part
          width, height = goal_part.split("x")
          return (int(width), int(height))
-
+
      def gc_adaptable(self):
+         """Check if the environment is goal-conditioned adaptable."""
          return False
 
      def problem_list_to_str_tuple(self, problems):
+         """Convert a list of problems to a string tuple."""
          return "_".join([f"[{s.split('-')[-1]}]" for s in problems])
 
      def is_action_discrete(self):
+         """Check if the action space is discrete."""
          return False
 
      def is_state_discrete(self):
+         """Check if the state space is discrete."""
          return False
-
+
      def get_lstm_props(self):
-         return LSTMProperties(batch_size=32, input_size=6, hidden_size=8, num_samples=20000)
+         """
+         Get the LSTM properties for the environment.
+         """
+         return LSTMProperties(
+             batch_size=32, input_size=6, hidden_size=8, num_samples=20000
+         )
 
      def goal_to_problem_str(self, goal):
+         """
+         Convert a goal to a problem string.
+         """
          return self.name + f"-Goal-{goal[0]}x{goal[1]}"
+
+     def change_done_by_specific_desired(self, obs, desired, old_success_done):
+         """
+         Change the 'done' flag based on a specific desired goal.
+         """
+         assert (
+             desired is None
+         ), "In PointMazeProperty, giving a specific 'desired' is not supported."
+         return old_success_done
+
+     def is_done(self, done):
+         """
+         Check if the episode is done.
+         """
+         assert isinstance(done, np.ndarray)
+         return done[0]
+
+     def is_success(self, info):
+         """
+         Check if the episode is successful.
+         """
+         assert "success" in info[0].keys()
+         return info[0]["success"]
+
+     def change_goal_to_specific_desired(self, obs, desired):
+         """
+         Change the goal to a specific desired goal.
+         """
+         assert (
+             desired is None
+         ), "In PointMazeProperty, giving a specific 'desired' is not supported."
gr_libs/metrics/__init__.py
@@ -0,0 +1,7 @@
+ """ metrics for GR algorithms """
+
+ from .metrics import (
+     mean_p_value,
+     mean_wasserstein_distance,
+     stochastic_amplified_selection,
+ )
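With this new __init__, the metrics helpers are importable from the package root rather than from the inner gr_libs.metrics.metrics module. The import surface this hunk defines:

from gr_libs.metrics import (
    mean_p_value,
    mean_wasserstein_distance,
    stochastic_amplified_selection,
)

# stochastic_amplified_selection is the kind of hook passed to a recognizer's
# inference step (an assumption based on its name; its implementation lives in
# gr_libs/metrics/metrics.py, which this diff does not show in full).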