gr-libs 0.1.8-py3-none-any.whl → 0.2.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. gr_libs/__init__.py +3 -1
  2. gr_libs/_version.py +2 -2
  3. gr_libs/all_experiments.py +260 -0
  4. gr_libs/environment/__init__.py +14 -1
  5. gr_libs/environment/_utils/__init__.py +0 -0
  6. gr_libs/environment/{utils → _utils}/utils.py +1 -1
  7. gr_libs/environment/environment.py +278 -23
  8. gr_libs/evaluation/__init__.py +1 -0
  9. gr_libs/evaluation/generate_experiments_results.py +100 -0
  10. gr_libs/metrics/__init__.py +2 -0
  11. gr_libs/metrics/metrics.py +166 -31
  12. gr_libs/ml/__init__.py +1 -6
  13. gr_libs/ml/base/__init__.py +3 -1
  14. gr_libs/ml/base/rl_agent.py +68 -3
  15. gr_libs/ml/neural/__init__.py +1 -3
  16. gr_libs/ml/neural/deep_rl_learner.py +241 -84
  17. gr_libs/ml/neural/utils/__init__.py +1 -2
  18. gr_libs/ml/planner/mcts/{utils → _utils}/tree.py +1 -1
  19. gr_libs/ml/planner/mcts/mcts_model.py +71 -34
  20. gr_libs/ml/sequential/__init__.py +0 -1
  21. gr_libs/ml/sequential/{lstm_model.py → _lstm_model.py} +11 -14
  22. gr_libs/ml/tabular/__init__.py +1 -3
  23. gr_libs/ml/tabular/tabular_q_learner.py +27 -9
  24. gr_libs/ml/tabular/tabular_rl_agent.py +22 -9
  25. gr_libs/ml/utils/__init__.py +2 -9
  26. gr_libs/ml/utils/format.py +13 -90
  27. gr_libs/ml/utils/math.py +3 -2
  28. gr_libs/ml/utils/other.py +2 -2
  29. gr_libs/ml/utils/storage.py +41 -94
  30. gr_libs/odgr_executor.py +263 -0
  31. gr_libs/problems/consts.py +570 -292
  32. gr_libs/recognizer/{utils → _utils}/format.py +2 -2
  33. gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +127 -36
  34. gr_libs/recognizer/graml/{gr_dataset.py → _gr_dataset.py} +11 -11
  35. gr_libs/recognizer/graml/graml_recognizer.py +186 -35
  36. gr_libs/recognizer/recognizer.py +59 -10
  37. gr_libs/tutorials/draco_panda_tutorial.py +58 -0
  38. gr_libs/tutorials/draco_parking_tutorial.py +56 -0
  39. {tutorials → gr_libs/tutorials}/gcdraco_panda_tutorial.py +11 -11
  40. {tutorials → gr_libs/tutorials}/gcdraco_parking_tutorial.py +6 -8
  41. {tutorials → gr_libs/tutorials}/graml_minigrid_tutorial.py +18 -14
  42. {tutorials → gr_libs/tutorials}/graml_panda_tutorial.py +11 -12
  43. {tutorials → gr_libs/tutorials}/graml_parking_tutorial.py +8 -10
  44. {tutorials → gr_libs/tutorials}/graml_point_maze_tutorial.py +17 -3
  45. {tutorials → gr_libs/tutorials}/graql_minigrid_tutorial.py +2 -2
  46. {gr_libs-0.1.8.dist-info → gr_libs-0.2.5.dist-info}/METADATA +95 -29
  47. gr_libs-0.2.5.dist-info/RECORD +72 -0
  48. {gr_libs-0.1.8.dist-info → gr_libs-0.2.5.dist-info}/WHEEL +1 -1
  49. gr_libs-0.2.5.dist-info/top_level.txt +2 -0
  50. tests/test_draco.py +14 -0
  51. tests/test_gcdraco.py +2 -2
  52. tests/test_graml.py +4 -4
  53. tests/test_graql.py +1 -1
  54. tests/test_odgr_executor_expertbasedgraml.py +14 -0
  55. tests/test_odgr_executor_gcdraco.py +14 -0
  56. tests/test_odgr_executor_gcgraml.py +14 -0
  57. tests/test_odgr_executor_graql.py +14 -0
  58. evaluation/analyze_results_cross_alg_cross_domain.py +0 -267
  59. evaluation/create_minigrid_map_image.py +0 -38
  60. evaluation/file_system.py +0 -53
  61. evaluation/generate_experiments_results.py +0 -141
  62. evaluation/generate_experiments_results_new_ver1.py +0 -238
  63. evaluation/generate_experiments_results_new_ver2.py +0 -331
  64. evaluation/generate_task_specific_statistics_plots.py +0 -500
  65. evaluation/get_plans_images.py +0 -62
  66. evaluation/increasing_and_decreasing_.py +0 -104
  67. gr_libs/ml/neural/utils/penv.py +0 -60
  68. gr_libs-0.1.8.dist-info/RECORD +0 -70
  69. gr_libs-0.1.8.dist-info/top_level.txt +0 -4
  70. /gr_libs/{environment/utils/__init__.py → _evaluation/_generate_experiments_results.py} +0 -0
  71. /gr_libs/ml/planner/mcts/{utils → _utils}/__init__.py +0 -0
  72. /gr_libs/ml/planner/mcts/{utils → _utils}/node.py +0 -0
  73. /gr_libs/recognizer/{utils → _utils}/__init__.py +0 -0
gr_libs/environment/environment.py
@@ -1,14 +1,18 @@
+""" environment.py """
+
+import os
+import sys
 from abc import abstractmethod
 from collections import namedtuple
-import os
+from contextlib import contextmanager
 
 import gymnasium as gym
-from stable_baselines3.common.vec_env import DummyVecEnv
-from PIL import Image
 import numpy as np
 from gymnasium.envs.registration import register
-from minigrid.core.world_object import Wall, Lava
-from minigrid.wrappers import RGBImgPartialObsWrapper, ImgObsWrapper
+from minigrid.core.world_object import Lava, Wall
+from minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
+from PIL import Image
+from stable_baselines3.common.vec_env import DummyVecEnv
 
 MINIGRID, PANDA, PARKING, POINT_MAZE = "minigrid", "panda", "parking", "point_maze"
 
@@ -21,112 +25,226 @@ LSTMProperties(
 )
 
 
+@contextmanager
+def suppress_output():
+    """
+    Context manager to suppress stdout and stderr (including C/C++ prints).
+    """
+    with open(os.devnull, "w") as devnull:
+        old_stdout = sys.stdout
+        old_stderr = sys.stderr
+        sys.stdout = devnull
+        sys.stderr = devnull
+        try:
+            yield
+        finally:
+            sys.stdout = old_stdout
+            sys.stderr = old_stderr
+
+
 class EnvProperty:
+    """
+    Base class for environment properties.
+    """
+
     def __init__(self, name):
+        """
+        Initializes a new instance of the Environment class.
+
+        Args:
+            name (str): The name of the environment.
+        """
         self.name = name
 
     def __str__(self):
+        """
+        Returns a string representation of the object.
+        """
         return f"{self.name}"
 
     def __repr__(self):
+        """
+        Returns a string representation of the object.
+        """
         return f"{self.name}"
 
     def __eq__(self, other):
+        """
+        Check if this object is equal to another object.
+
+        Args:
+            other: The other object to compare with.
+
+        Returns:
+            True if the objects are equal, False otherwise.
+        """
         return self.name == other.name
 
     def __ne__(self, other):
+        """
+        Check if the current object is not equal to the other object.
+
+        Args:
+            other: The object to compare with.
+
+        Returns:
+            bool: True if the objects are not equal, False otherwise.
+        """
         return not self.__eq__(other)
 
     @abstractmethod
     def str_to_goal(self):
-        pass
+        """
+        Convert a problem name to a goal.
+        """
 
     @abstractmethod
     def gc_adaptable(self):
-        pass
+        """
+        Check if the environment is goal-conditioned adaptable.
+        """
 
     @abstractmethod
     def problem_list_to_str_tuple(self, problems):
-        pass
+        """
+        Convert a list of problems to a string tuple.
+        """
 
     @abstractmethod
     def goal_to_problem_str(self, goal):
-        pass
+        """
+        Convert a goal to a problem string.
+        """
 
     @abstractmethod
     def is_action_discrete(self):
-        pass
+        """
+        Check if the action space is discrete.
+        """
 
     @abstractmethod
     def is_state_discrete(self):
-        pass
+        """
+        Check if the state space is discrete.
+        """
 
     @abstractmethod
     def get_lstm_props(self):
-        pass
+        """
+        Get the LSTM properties for the environment.
+        """
 
     @abstractmethod
     def change_done_by_specific_desired(self, obs, desired, old_success_done):
-        pass
+        """
+        Change the 'done' flag based on a specific desired goal.
+        """
 
     @abstractmethod
     def is_done(self, done):
-        pass
+        """
+        Check if the episode is done.
+        """
 
     @abstractmethod
     def is_success(self, info):
-        pass
+        """
+        Check if the episode is successful.
+        """
 
     def create_vec_env(self, kwargs):
-        env = gym.make(**kwargs)
+        """
+        Create a vectorized environment, suppressing prints from gym/pybullet/panda-gym.
+        """
+        with suppress_output():
+            env = gym.make(**kwargs)
         return DummyVecEnv([lambda: env])
 
     @abstractmethod
     def change_goal_to_specific_desired(self, obs, desired):
-        pass
+        """
+        Change the goal to a specific desired goal.
+        """
 
 
 class GCEnvProperty(EnvProperty):
+    """
+    Base class for goal-conditioned environment properties.
+    """
+
     @abstractmethod
     def use_goal_directed_problem(self):
-        pass
+        """
+        Check if the environment uses a goal-directed problem.
+        """
 
     def problem_list_to_str_tuple(self, problems):
+        """
+        Convert a list of problems to a string tuple.
+        """
         return "goal_conditioned"
 
 
 class MinigridProperty(EnvProperty):
+    """
+    Environment properties for the Minigrid domain.
+    """
+
     def __init__(self, name):
         super().__init__(name)
         self.domain_name = "minigrid"
 
     def goal_to_problem_str(self, goal):
+        """
+        Convert a goal to a problem string.
+        """
         return self.name + f"-DynamicGoal-{goal[0]}x{goal[1]}-v0"
 
     def str_to_goal(self, problem_name):
+        """
+        Convert a problem name to a goal.
+        """
         parts = problem_name.split("-")
         goal_part = [part for part in parts if "x" in part]
         width, height = goal_part[0].split("x")
         return (int(width), int(height))
 
     def gc_adaptable(self):
+        """
+        Check if the environment is goal-conditioned adaptable.
+        """
         return False
 
     def problem_list_to_str_tuple(self, problems):
+        """
+        Convert a list of problems to a string tuple.
+        """
         return "_".join([f"[{s.split('-')[-2]}]" for s in problems])
 
     def is_action_discrete(self):
+        """
+        Check if the action space is discrete.
+        """
         return True
 
     def is_state_discrete(self):
+        """
+        Check if the state space is discrete.
+        """
         return True
 
     def get_lstm_props(self):
+        """
+        Get the LSTM properties for the environment.
+        """
         return LSTMProperties(
             batch_size=16, input_size=4, hidden_size=8, num_samples=40000
         )
 
     def create_sequence_image(self, sequence, img_path, problem_name):
+        """
+        Create a sequence image for the environment.
+        """
         if not os.path.exists(os.path.dirname(img_path)):
             os.makedirs(os.path.dirname(img_path))
         env_id = (
@@ -134,7 +252,7 @@ class MinigridProperty(EnvProperty):
             + "-DynamicGoal-"
             + problem_name.split("-DynamicGoal-")[1]
         )
-        result = register(
+        register(
             id=env_id,
             entry_point="gr_envs.minigrid_scripts.envs:CustomColorEnv",
             kwargs={
@@ -146,7 +264,6 @@ class MinigridProperty(EnvProperty):
                 "plan": sequence,
             },
         )
-        # print(result)
         env = gym.make(id=env_id)
         env = RGBImgPartialObsWrapper(env)  # Get pixel observations
         env = ImgObsWrapper(env)  # Get rid of the 'mission' field
@@ -156,34 +273,62 @@ class MinigridProperty(EnvProperty):
 
         ####### save image to file
         image_pil = Image.fromarray(np.uint8(img)).convert("RGB")
-        image_pil.save(r"{}.png".format(img_path))
+        image_pil.save(r"{}.png".format(os.path.join(img_path, "plan_image")))
 
     def change_done_by_specific_desired(self, obs, desired, old_success_done):
+        """
+        Change the 'done' flag based on a specific desired goal.
+        """
         assert (
             desired is None
         ), "In MinigridProperty, giving a specific 'desired' is not supported."
         return old_success_done
 
     def is_done(self, done):
+        """
+        Check if the episode is done.
+        """
         assert isinstance(done, np.ndarray)
         return done[0]
 
-    # Not used currently since TabularQLearner doesn't need is_success from the environment
     def is_success(self, info):
+        """
+        Check if the episode is successful.
+        """
         raise NotImplementedError("no other option for any of the environments.")
 
     def change_goal_to_specific_desired(self, obs, desired):
+        """
+        Change the goal to a specific desired goal.
+        """
         assert (
             desired is None
         ), "In MinigridProperty, giving a specific 'desired' is not supported."
 
 
 class PandaProperty(GCEnvProperty):
+    """
+    Environment properties for the Panda domain.
+    """
+
     def __init__(self, name):
+        """
+        Initialize a new instance of the Environment class.
+
+        Args:
+            name (str): The name of the environment.
+
+        Attributes:
+            domain_name (str): The domain name of the environment.
+
+        """
         super().__init__(name)
         self.domain_name = "panda"
 
     def str_to_goal(self, problem_name):
+        """
+        Convert a problem name to a goal.
+        """
         try:
             numeric_part = problem_name.split("PandaMyReachDenseX")[1]
             components = [
@@ -194,38 +339,62 @@ class PandaProperty(GCEnvProperty):
             for component in components:
                 floats.append(float(component))
             return np.array([floats], dtype=np.float32)
-        except Exception as e:
+        except Exception:
             return "general"
 
     def goal_to_problem_str(self, goal):
+        """
+        Convert a goal to a problem string.
+        """
         goal_str = "X".join(
             [str(float(g)).replace(".", "y").replace("-", "M") for g in goal[0]]
         )
         return f"PandaMyReachDenseX{goal_str}-v3"
 
     def gc_adaptable(self):
+        """
+        Check if the environment is goal-conditioned adaptable.
+        """
         return True
 
     def use_goal_directed_problem(self):
+        """
+        Check if the environment uses a goal-directed problem.
+        """
         return False
 
     def is_action_discrete(self):
+        """
+        Check if the action space is discrete.
+        """
         return False
 
     def is_state_discrete(self):
+        """
+        Check if the state space is discrete.
+        """
         return False
 
     def get_lstm_props(self):
+        """
+        Get the LSTM properties for the environment.
+        """
         return LSTMProperties(
             batch_size=32, input_size=9, hidden_size=8, num_samples=20000
         )
 
     def sample_goal():
+        """
+        Sample a random goal.
+        """
         goal_range_low = np.array([-0.40, -0.40, 0.10])
         goal_range_high = np.array([0.2, 0.2, 0.10])
         return np.random.uniform(goal_range_low, goal_range_high)
 
     def change_done_by_specific_desired(self, obs, desired, old_success_done):
+        """
+        Change the 'done' flag based on a specific desired goal.
+        """
         if desired is None:
             return old_success_done
         assert isinstance(
@@ -241,70 +410,134 @@ class PandaProperty(GCEnvProperty):
         return old_success_done
 
     def is_done(self, done):
+        """
+        Check if the episode is done.
+        """
         assert isinstance(done, np.ndarray)
         return done[0]
 
     def is_success(self, info):
+        """
+        Check if the episode is successful.
+        """
         assert "is_success" in info[0].keys()
         return info[0]["is_success"]
 
     def change_goal_to_specific_desired(self, obs, desired):
+        """
+        Change the goal to a specific desired goal.
+        """
         if desired is not None:
             obs["desired_goal"] = desired
 
 
 class ParkingProperty(GCEnvProperty):
+    """
+    Environment properties for the Parking domain.
+    """
 
     def __init__(self, name):
+        """
+        Initialize a new environment object.
+
+        Args:
+            name (str): The name of the environment.
+
+        Attributes:
+            domain_name (str): The domain name of the environment.
+
+        """
         super().__init__(name)
         self.domain_name = "parking"
 
     def goal_to_problem_str(self, goal):
+        """
+        Convert a goal to a problem string.
+        """
         return self.name.split("-v0")[0] + f"-GI-{goal}-v0"
 
     def gc_adaptable(self):
+        """
+        Check if the environment is goal-conditioned adaptable.
+        """
         return True
 
     def is_action_discrete(self):
+        """
+        Check if the action space is discrete.
+        """
         return False
 
     def is_state_discrete(self):
+        """
+        Check if the state space is discrete.
+        """
         return False
 
     def use_goal_directed_problem(self):
+        """
+        Check if the environment uses a goal-directed problem.
+        """
         return True
 
     def get_lstm_props(self):
+        """
+        Get the LSTM properties for the environment.
+        """
         return LSTMProperties(
             batch_size=32, input_size=8, hidden_size=8, num_samples=20000
         )
 
     def change_done_by_specific_desired(self, obs, desired, old_success_done):
+        """
+        Change the 'done' flag based on a specific desired goal.
+        """
         assert (
             desired is None
         ), "In ParkingProperty, giving a specific 'desired' is not supported."
         return old_success_done
 
     def is_done(self, done):
+        """
+        Check if the episode is done.
+        """
         assert isinstance(done, np.ndarray)
         return done[0]
 
     def is_success(self, info):
+        """
+        Check if the episode is successful.
+        """
         assert "is_success" in info[0].keys()
         return info[0]["is_success"]
 
     def change_goal_to_specific_desired(self, obs, desired):
+        """
+        Change the goal to a specific desired goal.
+        """
         assert (
             desired is None
         ), "In ParkingProperty, giving a specific 'desired' is not supported."
 
 
 class PointMazeProperty(EnvProperty):
+    """Environment properties for the Point Maze domain."""
+
     def __init__(self, name):
+        """
+        Initializes a new instance of the Environment class.
+
+        Args:
+            name (str): The name of the environment.
+
+        Attributes:
+            domain_name (str): The domain name of the environment.
+        """
         super().__init__(name)
         self.domain_name = "point_maze"
 
     def str_to_goal(self):
+        """Convert a problem name to a goal."""
         parts = self.name.split("-")
         # Find the part containing the goal size (usually after "DynamicGoal")
         sizes_parts = [part for part in parts if "x" in part]
@@ -314,40 +547,62 @@ class PointMazeProperty(EnvProperty):
         return (int(width), int(height))
 
     def gc_adaptable(self):
+        """Check if the environment is goal-conditioned adaptable."""
         return False
 
     def problem_list_to_str_tuple(self, problems):
+        """Convert a list of problems to a string tuple."""
         return "_".join([f"[{s.split('-')[-1]}]" for s in problems])
 
     def is_action_discrete(self):
+        """Check if the action space is discrete."""
        return False
 
     def is_state_discrete(self):
+        """Check if the state space is discrete."""
         return False
 
     def get_lstm_props(self):
+        """
+        Get the LSTM properties for the environment.
+        """
         return LSTMProperties(
             batch_size=32, input_size=6, hidden_size=8, num_samples=20000
         )
 
     def goal_to_problem_str(self, goal):
+        """
+        Convert a goal to a problem string.
+        """
         return self.name + f"-Goal-{goal[0]}x{goal[1]}"
 
     def change_done_by_specific_desired(self, obs, desired, old_success_done):
+        """
+        Change the 'done' flag based on a specific desired goal.
+        """
         assert (
             desired is None
         ), "In PointMazeProperty, giving a specific 'desired' is not supported."
         return old_success_done
 
     def is_done(self, done):
+        """
+        Check if the episode is done.
+        """
         assert isinstance(done, np.ndarray)
         return done[0]
 
     def is_success(self, info):
+        """
+        Check if the episode is successful.
+        """
         assert "success" in info[0].keys()
         return info[0]["success"]
 
     def change_goal_to_specific_desired(self, obs, desired):
+        """
+        Change the goal to a specific desired goal.
+        """
         assert (
             desired is None
         ), "In ParkingProperty, giving a specific 'desired' is not supported."
gr_libs/evaluation/__init__.py (new file)
@@ -0,0 +1 @@
+""" This is a directory that includes scripts for analysis of GR results. """
gr_libs/evaluation/generate_experiments_results.py (new file)
@@ -0,0 +1,100 @@
+import argparse
+import os
+
+import dill
+import matplotlib.pyplot as plt
+import numpy as np
+
+from gr_libs.ml.utils.storage import get_experiment_results_path
+
+
+def load_results(domain, env, task, recognizer, n_runs, percentage, cons_type):
+    # Collect accuracy for a single task and recognizer
+    accs = []
+    res_dir = get_experiment_results_path(domain, env, task, recognizer)
+    if not os.path.exists(res_dir):
+        return accs
+    for i in range(n_runs):
+        res_file = os.path.join(res_dir, f"res_{i}.pkl")
+        if not os.path.exists(res_file):
+            continue
+        with open(res_file, "rb") as f:
+            results = dill.load(f)
+        if percentage in results and cons_type in results[percentage]:
+            acc = results[percentage][cons_type].get("accuracy")
+            if acc is not None:
+                accs.append(acc)
+    return accs
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--domain", required=True)
+    parser.add_argument("--env", required=True)
+    parser.add_argument("--tasks", nargs="+", required=True)
+    parser.add_argument("--recognizers", nargs="+", required=True)
+    parser.add_argument("--n_runs", type=int, default=5)
+    parser.add_argument("--percentage", required=True)
+    parser.add_argument(
+        "--cons_type", choices=["consecutive", "non_consecutive"], required=True
+    )
+    parser.add_argument("--graph_name", type=str, default="experiment_results")
+    args = parser.parse_args()
+
+    plt.figure(figsize=(7, 5))
+    has_data = False
+    missing_recognizers = []
+
+    for recognizer in args.recognizers:
+        x_vals = []
+        y_means = []
+        y_sems = []
+        for task in args.tasks:
+            accs = load_results(
+                args.domain,
+                args.env,
+                task,
+                recognizer,
+                args.n_runs,
+                args.percentage,
+                args.cons_type,
+            )
+            if accs:
+                x_vals.append(task)
+                y_means.append(np.mean(accs))
+                y_sems.append(np.std(accs) / np.sqrt(len(accs)))
+        if x_vals:
+            has_data = True
+            x_ticks = np.arange(len(x_vals))
+            plt.plot(x_ticks, y_means, marker="o", label=recognizer)
+            plt.fill_between(
+                x_ticks,
+                np.array(y_means) - np.array(y_sems),
+                np.array(y_means) + np.array(y_sems),
+                alpha=0.2,
+            )
+            plt.xticks(x_ticks, x_vals)
+        else:
+            print(
+                f"Warning: No data found for recognizer '{recognizer}' in {args.domain} / {args.env} / {args.percentage} / {args.cons_type}"
+            )
+            missing_recognizers.append(recognizer)
+
+    if not has_data:
+        raise RuntimeError(
+            f"No data found for any recognizer in {args.domain} / {args.env} / {args.percentage} / {args.cons_type}. "
+            f"Missing recognizers: {', '.join(missing_recognizers)}"
+        )
+
+    plt.xlabel("Task")
+    plt.ylabel("Accuracy")
+    plt.title(f"{args.domain} - {args.env} ({args.percentage}, {args.cons_type})")
+    plt.legend()
+    plt.grid(True)
+    fig_path = f"{args.graph_name}_{'_'.join(args.recognizers)}_{args.domain}_{args.env}_{args.percentage}_{args.cons_type}.png"
+    plt.savefig(fig_path)
+    print(f"Figure saved at: {fig_path}")
+
+
+if __name__ == "__main__":
+    main()
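
Since the new evaluation script exposes a plain argparse CLI, it can be run as a module once res_<i>.pkl files exist under the path returned by get_experiment_results_path. A hypothetical invocation is sketched below; the domain, env, task, and recognizer values are placeholders, not names the package necessarily ships.

# Sketch: driving the new CLI from Python; all argument values below are made up.
import subprocess

subprocess.run(
    [
        "python", "-m", "gr_libs.evaluation.generate_experiments_results",
        "--domain", "minigrid",
        "--env", "MiniGrid-SimpleCrossingS13N4",
        "--tasks", "L1", "L2", "L3",
        "--recognizers", "ExpertBasedGraml", "Graql",
        "--n_runs", "5",
        "--percentage", "0.5",
        "--cons_type", "consecutive",
    ],
    check=True,
)

On success the script saves a PNG whose name combines --graph_name, the recognizers, and the filter settings, and prints the figure path.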