gr-libs 0.1.7.post0__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. evaluation/analyze_results_cross_alg_cross_domain.py +236 -246
  2. evaluation/create_minigrid_map_image.py +10 -6
  3. evaluation/file_system.py +16 -5
  4. evaluation/generate_experiments_results.py +123 -74
  5. evaluation/generate_experiments_results_new_ver1.py +227 -243
  6. evaluation/generate_experiments_results_new_ver2.py +317 -317
  7. evaluation/generate_task_specific_statistics_plots.py +481 -253
  8. evaluation/get_plans_images.py +41 -26
  9. evaluation/increasing_and_decreasing_.py +97 -56
  10. gr_libs/__init__.py +2 -1
  11. gr_libs/_version.py +2 -2
  12. gr_libs/environment/__init__.py +16 -8
  13. gr_libs/environment/environment.py +167 -39
  14. gr_libs/environment/utils/utils.py +22 -12
  15. gr_libs/metrics/__init__.py +5 -0
  16. gr_libs/metrics/metrics.py +76 -34
  17. gr_libs/ml/__init__.py +2 -0
  18. gr_libs/ml/agent.py +21 -6
  19. gr_libs/ml/base/__init__.py +1 -1
  20. gr_libs/ml/base/rl_agent.py +13 -10
  21. gr_libs/ml/consts.py +1 -1
  22. gr_libs/ml/neural/deep_rl_learner.py +433 -352
  23. gr_libs/ml/neural/utils/__init__.py +1 -1
  24. gr_libs/ml/neural/utils/dictlist.py +3 -3
  25. gr_libs/ml/neural/utils/penv.py +5 -2
  26. gr_libs/ml/planner/mcts/mcts_model.py +524 -302
  27. gr_libs/ml/planner/mcts/utils/__init__.py +1 -1
  28. gr_libs/ml/planner/mcts/utils/node.py +11 -7
  29. gr_libs/ml/planner/mcts/utils/tree.py +14 -10
  30. gr_libs/ml/sequential/__init__.py +1 -1
  31. gr_libs/ml/sequential/lstm_model.py +256 -175
  32. gr_libs/ml/tabular/state.py +7 -7
  33. gr_libs/ml/tabular/tabular_q_learner.py +123 -73
  34. gr_libs/ml/tabular/tabular_rl_agent.py +20 -19
  35. gr_libs/ml/utils/__init__.py +8 -2
  36. gr_libs/ml/utils/format.py +78 -70
  37. gr_libs/ml/utils/math.py +2 -1
  38. gr_libs/ml/utils/other.py +1 -1
  39. gr_libs/ml/utils/storage.py +88 -28
  40. gr_libs/problems/consts.py +1549 -1227
  41. gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +145 -80
  42. gr_libs/recognizer/graml/gr_dataset.py +209 -110
  43. gr_libs/recognizer/graml/graml_recognizer.py +431 -240
  44. gr_libs/recognizer/recognizer.py +38 -27
  45. gr_libs/recognizer/utils/__init__.py +1 -1
  46. gr_libs/recognizer/utils/format.py +8 -3
  47. {gr_libs-0.1.7.post0.dist-info → gr_libs-0.1.8.dist-info}/METADATA +1 -1
  48. gr_libs-0.1.8.dist-info/RECORD +70 -0
  49. {gr_libs-0.1.7.post0.dist-info → gr_libs-0.1.8.dist-info}/WHEEL +1 -1
  50. tests/test_gcdraco.py +10 -0
  51. tests/test_graml.py +8 -4
  52. tests/test_graql.py +2 -1
  53. tutorials/gcdraco_panda_tutorial.py +66 -0
  54. tutorials/gcdraco_parking_tutorial.py +61 -0
  55. tutorials/graml_minigrid_tutorial.py +42 -12
  56. tutorials/graml_panda_tutorial.py +35 -14
  57. tutorials/graml_parking_tutorial.py +37 -20
  58. tutorials/graml_point_maze_tutorial.py +33 -13
  59. tutorials/graql_minigrid_tutorial.py +31 -15
  60. gr_libs-0.1.7.post0.dist-info/RECORD +0 -67
  61. {gr_libs-0.1.7.post0.dist-info → gr_libs-0.1.8.dist-info}/top_level.txt +0 -0
evaluation/get_plans_images.py CHANGED
@@ -10,38 +10,53 @@ GRAML_includer = os.path.dirname(os.path.dirname(currentdir))
 sys.path.insert(0, GRAML_includer)
 sys.path.insert(0, GRAML_itself)
 
+
 def get_plans_result_path(env_name):
-    return os.path.join("dataset", (env_name), "plans")
+    return os.path.join("dataset", (env_name), "plans")
+
 
 def get_policy_sequences_result_path(env_name):
-    return os.path.join("dataset", (env_name), "policy_sequences")
+    return os.path.join("dataset", (env_name), "policy_sequences")
 
 
 # TODO: instead of loading the model and having it produce the sequence again, just save the sequence from the framework run, and have this script accept the whole path (including is_fragmented etc.)
 def analyze_and_produce_images(env_name):
-    models_dir = get_models_dir(env_name=env_name)
-    for dirname in os.listdir(models_dir):
-        if dirname.startswith('MiniGrid'):
-            model_dir = get_model_dir(env_name=env_name, model_name=dirname, class_name="MCTS")
-            model_file_path = os.path.join(model_dir, "mcts_model.pth")
-            try:
-                with open(model_file_path, 'rb') as file:  # Load the pre-existing model
-                    monteCarloTreeSearch = pickle.load(file)
-                full_plan = monteCarloTreeSearch.generate_full_policy_sequence()
-                plan = [pos for ((state, pos), action) in full_plan]
-                plans_result_path = get_plans_result_path(env_name)
-                if not os.path.exists(plans_result_path): os.makedirs(plans_result_path)
-                img_path = os.path.join(get_plans_result_path(env_name), dirname)
-                print(f"plan to {dirname} is:\n\t{plan}\ngenerating image at {img_path}.")
-                create_sequence_image(plan, img_path, dirname)
-
-            except FileNotFoundError as e:
-                print(f"Warning: {e.filename} doesn't exist. It's probably a base goal, not generating policy sequence for it.")
+    models_dir = get_models_dir(env_name=env_name)
+    for dirname in os.listdir(models_dir):
+        if dirname.startswith("MiniGrid"):
+            model_dir = get_model_dir(
+                env_name=env_name, model_name=dirname, class_name="MCTS"
+            )
+            model_file_path = os.path.join(model_dir, "mcts_model.pth")
+            try:
+                with open(model_file_path, "rb") as file:  # Load the pre-existing model
+                    monteCarloTreeSearch = pickle.load(file)
+                full_plan = monteCarloTreeSearch.generate_full_policy_sequence()
+                plan = [pos for ((state, pos), action) in full_plan]
+                plans_result_path = get_plans_result_path(env_name)
+                if not os.path.exists(plans_result_path):
+                    os.makedirs(plans_result_path)
+                img_path = os.path.join(get_plans_result_path(env_name), dirname)
+                print(
+                    f"plan to {dirname} is:\n\t{plan}\ngenerating image at {img_path}."
+                )
+                create_sequence_image(plan, img_path, dirname)
+
+            except FileNotFoundError as e:
+                print(
+                    f"Warning: {e.filename} doesn't exist. It's probably a base goal, not generating policy sequence for it."
+                )
+
 
 if __name__ == "__main__":
-    # preventing circular imports. only needed for running this as main anyway.
-    from gr_libs.ml.utils.storage import get_models_dir, get_model_dir
-    # checks:
-    assert len(sys.argv) == 2, f"Assertion failed: len(sys.argv) is {len(sys.argv)} while it needs to be 2.\n Example: \n\t /usr/bin/python scripts/get_plans_images.py MiniGrid-Walls-13x13-v0"
-    assert os.path.exists(get_models_dir(sys.argv[1])), "plans weren't made for this environment, run graml_main.py with this environment first."
-    analyze_and_produce_images(sys.argv[1])
+    # preventing circular imports. only needed for running this as main anyway.
+    from gr_libs.ml.utils.storage import get_models_dir, get_model_dir
+
+    # checks:
+    assert (
+        len(sys.argv) == 2
+    ), f"Assertion failed: len(sys.argv) is {len(sys.argv)} while it needs to be 2.\n Example: \n\t /usr/bin/python scripts/get_plans_images.py MiniGrid-Walls-13x13-v0"
+    assert os.path.exists(
+        get_models_dir(sys.argv[1])
+    ), "plans weren't made for this environment, run graml_main.py with this environment first."
+    analyze_and_produce_images(sys.argv[1])
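The comprehension plan = [pos for ((state, pos), action) in full_plan] implies that generate_full_policy_sequence() yields ((state, position), action) tuples. A minimal sketch of that unpacking with dummy data (the tuple contents below are illustrative, not taken from the package):

    # Hypothetical stand-in for monteCarloTreeSearch.generate_full_policy_sequence():
    # each step pairs a (state, position) observation with the action taken.
    full_plan = [
        (("state_a", (1, 1)), "right"),
        (("state_b", (2, 1)), "down"),
        (("state_c", (2, 2)), "down"),
    ]

    # Same unpacking as in analyze_and_produce_images: keep only the positions.
    plan = [pos for ((state, pos), action) in full_plan]
    assert plan == [(1, 1), (2, 1), (2, 2)]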
evaluation/increasing_and_decreasing_.py CHANGED
@@ -2,62 +2,103 @@ import os
 import dill
 import numpy as np
 import matplotlib.pyplot as plt
-from gr_libs.ml.utils.storage import get_experiment_results_path, set_global_storage_configs
+from gr_libs.ml.utils.storage import (
+    get_experiment_results_path,
+    set_global_storage_configs,
+)
 
 if __name__ == "__main__":
 
-    # Define the tasks and percentages
-    increasing_base_goals = ['L1', 'L2', 'L3', 'L4', 'L5']
-    increasing_dynamic_goals = ['L111', 'L222', 'L555', 'L333', 'L444']
-    percentages = ['0.3', '0.5', '0.7', '0.9', '1']
-
-    # Prepare a dictionary to hold accuracy data
-    accuracies = {task: {perc: [] for perc in percentages} for task in increasing_base_goals + increasing_dynamic_goals}
-
-    # Collect data for both sets of goals
-    for task in increasing_base_goals + increasing_dynamic_goals:
-        set_global_storage_configs(recognizer_str='graml', is_fragmented='fragmented',
-                                   is_inference_same_length_sequences=True, is_learn_same_length_sequences=False)
-        res_file_path = f'{get_experiment_results_path("parking", "gd_agent", task)}.pkl'
-
-        if os.path.exists(res_file_path):
-            with open(res_file_path, 'rb') as results_file:
-                results = dill.load(results_file)
-                for percentage in percentages:
-                    accuracies[task][percentage].append(results[percentage]['accuracy'])
-        else:
-            print(f"Warning: no file for {res_file_path}")
-
-    # Create the figure with two subplots
-    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
-
-    # Bar plot function
-    def plot_accuracies(ax, task_set, title, type):
-        """Plot accuracies for a given set of tasks on the provided axis."""
-        x_vals = np.arange(len(task_set))  # X-axis positions for the number of goals
-        bar_width = 0.15  # Width of each bar
-        for i, perc in enumerate(['0.3', '0.5', '1']):
-            if perc == '1': y_vals = [max([accuracies[task]['0.5'][0], accuracies[task]['0.7'][0], accuracies[task]['0.9'][0], accuracies[task]['1'][0]]) for task in task_set]  # Get mean accuracies
-            else: y_vals = [accuracies[task][perc][0] for task in task_set]  # Get mean accuracies
-            if type != 'base': ax.bar(x_vals + i * bar_width, y_vals, width=bar_width, label=f'Percentage {perc}')
-            else: ax.bar(x_vals + i * bar_width, y_vals, width=bar_width)
-        ax.set_xticks(x_vals + bar_width)  # Center x-ticks
-        ax.set_xticklabels([i+3 for i in range(len(task_set))], fontsize=16)  # Set custom x-tick labels
-        ax.set_yticks(np.linspace(0, 1, 6))
-        ax.set_ylim([0, 1])
-        ax.set_title(title, fontsize=20)
-        ax.set_xlabel(f'Number of {type} Goals', fontsize=20)
-        if type == 'base':
-            ax.set_ylabel('Accuracy', fontsize=22)
-        ax.legend()
-
-    # Plot for increasing base goals
-    plot_accuracies(axes[0], increasing_base_goals, 'Increasing Base Goals', "base")
-
-    # Plot for increasing dynamic goals
-    plot_accuracies(axes[1], increasing_dynamic_goals, 'Increasing Active Goals', "active")
-    plt.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1, wspace=0.3, hspace=0.3)
-    # Adjust layout and save the figure
-    plt.tight_layout()
-    plt.savefig('increasing_goals_plot_bars.png', dpi=300)  # Save the figure as a PNG file
-    print('Figure saved at: increasing_goals_plot_bars.png')
+    # Define the tasks and percentages
+    increasing_base_goals = ["L1", "L2", "L3", "L4", "L5"]
+    increasing_dynamic_goals = ["L111", "L222", "L555", "L333", "L444"]
+    percentages = ["0.3", "0.5", "0.7", "0.9", "1"]
+
+    # Prepare a dictionary to hold accuracy data
+    accuracies = {
+        task: {perc: [] for perc in percentages}
+        for task in increasing_base_goals + increasing_dynamic_goals
+    }
+
+    # Collect data for both sets of goals
+    for task in increasing_base_goals + increasing_dynamic_goals:
+        set_global_storage_configs(
+            recognizer_str="graml",
+            is_fragmented="fragmented",
+            is_inference_same_length_sequences=True,
+            is_learn_same_length_sequences=False,
+        )
+        res_file_path = (
+            f'{get_experiment_results_path("parking", "gd_agent", task)}.pkl'
+        )
+
+        if os.path.exists(res_file_path):
+            with open(res_file_path, "rb") as results_file:
+                results = dill.load(results_file)
+                for percentage in percentages:
+                    accuracies[task][percentage].append(results[percentage]["accuracy"])
+        else:
+            print(f"Warning: no file for {res_file_path}")
+
+    # Create the figure with two subplots
+    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
+
+    # Bar plot function
+    def plot_accuracies(ax, task_set, title, type):
+        """Plot accuracies for a given set of tasks on the provided axis."""
+        x_vals = np.arange(len(task_set))  # X-axis positions for the number of goals
+        bar_width = 0.15  # Width of each bar
+        for i, perc in enumerate(["0.3", "0.5", "1"]):
+            if perc == "1":
+                y_vals = [
+                    max(
+                        [
+                            accuracies[task]["0.5"][0],
+                            accuracies[task]["0.7"][0],
+                            accuracies[task]["0.9"][0],
+                            accuracies[task]["1"][0],
+                        ]
+                    )
+                    for task in task_set
+                ]  # Get mean accuracies
+            else:
+                y_vals = [
+                    accuracies[task][perc][0] for task in task_set
+                ]  # Get mean accuracies
+            if type != "base":
+                ax.bar(
+                    x_vals + i * bar_width,
+                    y_vals,
+                    width=bar_width,
+                    label=f"Percentage {perc}",
+                )
+            else:
+                ax.bar(x_vals + i * bar_width, y_vals, width=bar_width)
+        ax.set_xticks(x_vals + bar_width)  # Center x-ticks
+        ax.set_xticklabels(
+            [i + 3 for i in range(len(task_set))], fontsize=16
+        )  # Set custom x-tick labels
+        ax.set_yticks(np.linspace(0, 1, 6))
+        ax.set_ylim([0, 1])
+        ax.set_title(title, fontsize=20)
+        ax.set_xlabel(f"Number of {type} Goals", fontsize=20)
+        if type == "base":
+            ax.set_ylabel("Accuracy", fontsize=22)
+        ax.legend()
+
+    # Plot for increasing base goals
+    plot_accuracies(axes[0], increasing_base_goals, "Increasing Base Goals", "base")
+
+    # Plot for increasing dynamic goals
+    plot_accuracies(
+        axes[1], increasing_dynamic_goals, "Increasing Active Goals", "active"
+    )
+    plt.subplots_adjust(
+        left=0.1, right=0.9, top=0.9, bottom=0.1, wspace=0.3, hspace=0.3
+    )
+    # Adjust layout and save the figure
+    plt.tight_layout()
+    plt.savefig(
+        "increasing_goals_plot_bars.png", dpi=300
+    )  # Save the figure as a PNG file
+    print("Figure saved at: increasing_goals_plot_bars.png")
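The loader above assumes each results pickle maps an observability-percentage string to a dict carrying an "accuracy" entry. A minimal sketch of that expected layout with illustrative values (the real files come from the experiment runner; example_results.pkl is a made-up name):

    import dill

    # Hypothetical results structure matching results[percentage]["accuracy"] above.
    results = {
        perc: {"accuracy": acc}
        for perc, acc in [("0.3", 0.6), ("0.5", 0.7), ("0.7", 0.8), ("0.9", 0.9), ("1", 0.95)]
    }

    with open("example_results.pkl", "wb") as f:
        dill.dump(results, f)
    with open("example_results.pkl", "rb") as f:
        assert dill.load(f)["0.7"]["accuracy"] == 0.8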
gr_libs/__init__.py CHANGED
@@ -1,5 +1,6 @@
 from gr_libs.recognizer.graml.graml_recognizer import ExpertBasedGraml, GCGraml
-from gr_libs.recognizer.gr_as_rl.gr_as_rl_recognizer import Graql
+from gr_libs.recognizer.gr_as_rl.gr_as_rl_recognizer import Graql, Draco, GCDraco
+
 try:
     from ._version import version as __version__
 except ImportError:
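The new exports make the Draco recognizers importable from the package root alongside the existing ones; a quick sketch, assuming gr_libs 0.1.8 is installed (constructor arguments are omitted because they are not part of this diff):

    from gr_libs import ExpertBasedGraml, GCGraml, Graql, Draco, GCDraco

    # All five recognizer classes are now re-exported at the top level.
    print([cls.__name__ for cls in (ExpertBasedGraml, GCGraml, Graql, Draco, GCDraco)])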
gr_libs/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.1.7.post0'
-__version_tuple__ = version_tuple = (0, 1, 7, 'post0')
+__version__ = version = '0.1.8'
+__version_tuple__ = version_tuple = (0, 1, 8)
gr_libs/environment/__init__.py CHANGED
@@ -1,22 +1,30 @@
 import importlib.metadata
 import warnings
 
+
 def is_extra_installed(package: str, extra: str) -> bool:
     """Check if an extra was installed for a given package."""
     try:
         # Get metadata for the installed package
         dist = importlib.metadata.metadata(package)
-        requires = dist.get_all("Requires-Dist", [])  # Dependencies listed in the package metadata
+        requires = dist.get_all(
+            "Requires-Dist", []
+        )  # Dependencies listed in the package metadata
         return any(extra in req for req in requires)
     except importlib.metadata.PackageNotFoundError:
         return False  # The package is not installed
 
+
 # Check if `gr_libs[minigrid]` was installed
 for env in ["minigrid", "panda", "highway", "maze"]:
-    if is_extra_installed("gr_libs", f"gr_envs[{env}]"):
-        try:
-            importlib.import_module(f"gr_envs.{env}_scripts.envs")
-        except ImportError:
-            raise ImportError(f"gr_envs[{env}] was not installed, but gr_libs[{env}] requires it! if you messed with gr_envs installation, you can reinstall gr_libs.")
-    else:
-        warnings.warn(f"gr_libs[{env}] was not installed, skipping {env} imports.", RuntimeWarning)
+    if is_extra_installed("gr_libs", f"gr_envs[{env}]"):
+        try:
+            importlib.import_module(f"gr_envs.{env}_scripts.envs")
+        except ImportError:
+            raise ImportError(
+                f"gr_envs[{env}] was not installed, but gr_libs[{env}] requires it! if you messed with gr_envs installation, you can reinstall gr_libs."
+            )
+    else:
+        warnings.warn(
+            f"gr_libs[{env}] was not installed, skipping {env} imports.", RuntimeWarning
+        )
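The helper matches the extra as a substring of the distribution's Requires-Dist metadata entries. A standalone sketch of the same test (has_extra_dep is a hypothetical name so this illustrative copy doesn't shadow the real function):

    import importlib.metadata

    def has_extra_dep(package: str, extra: str) -> bool:
        # Same substring test as is_extra_installed above, restated for illustration.
        try:
            requires = importlib.metadata.metadata(package).get_all("Requires-Dist", [])
            return any(extra in req for req in requires)
        except importlib.metadata.PackageNotFoundError:
            return False

    # True only if the installed gr_libs metadata lists a gr_envs[minigrid] requirement:
    print(has_extra_dep("gr_libs", "gr_envs[minigrid]"))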
gr_libs/environment/environment.py CHANGED
@@ -2,7 +2,8 @@ from abc import abstractmethod
 from collections import namedtuple
 import os
 
-import gymnasium
+import gymnasium as gym
+from stable_baselines3.common.vec_env import DummyVecEnv
 from PIL import Image
 import numpy as np
 from gymnasium.envs.registration import register
@@ -15,8 +16,9 @@ QLEARNING = "QLEARNING"
 
 SUPPORTED_DOMAINS = [MINIGRID, PANDA, PARKING, POINT_MAZE]
 
-LSTMProperties = namedtuple('LSTMProperties', ['input_size', 'hidden_size', 'batch_size', 'num_samples'])
-
+LSTMProperties = namedtuple(
+    "LSTMProperties", ["input_size", "hidden_size", "batch_size", "num_samples"]
+)
 
 
 class EnvProperty:
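LSTMProperties is a plain namedtuple bundling the per-domain LSTM hyperparameters that each get_lstm_props() implementation below returns; for illustration, with the Minigrid values that appear later in this diff:

    from collections import namedtuple

    LSTMProperties = namedtuple(
        "LSTMProperties", ["input_size", "hidden_size", "batch_size", "num_samples"]
    )

    # The Minigrid values from MinigridProperty.get_lstm_props() further down:
    props = LSTMProperties(input_size=4, hidden_size=8, batch_size=16, num_samples=40000)
    print(props.input_size, props.num_samples)  # 4 40000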
@@ -34,7 +36,7 @@ class EnvProperty:
 
     def __ne__(self, other):
         return not self.__eq__(other)
-
+
     @abstractmethod
     def str_to_goal(self):
         pass
@@ -63,6 +65,27 @@ class EnvProperty:
     def get_lstm_props(self):
         pass
 
+    @abstractmethod
+    def change_done_by_specific_desired(self, obs, desired, old_success_done):
+        pass
+
+    @abstractmethod
+    def is_done(self, done):
+        pass
+
+    @abstractmethod
+    def is_success(self, info):
+        pass
+
+    def create_vec_env(self, kwargs):
+        env = gym.make(**kwargs)
+        return DummyVecEnv([lambda: env])
+
+    @abstractmethod
+    def change_goal_to_specific_desired(self, obs, desired):
+        pass
+
+
 class GCEnvProperty(EnvProperty):
     @abstractmethod
     def use_goal_directed_problem(self):
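Unlike its abstract siblings, create_vec_env is concrete on the base class: it forwards a kwargs dict to gym.make and wraps the result in a single-environment DummyVecEnv, which is why the concrete is_done/is_success implementations later in this diff index into batched results (done[0], info[0]). A hedged sketch of the same wrapping outside the class, using CartPole-v1 purely as a stand-in environment id:

    import gymnasium as gym
    from stable_baselines3.common.vec_env import DummyVecEnv

    def create_vec_env(kwargs):
        # Mirrors EnvProperty.create_vec_env from the diff above.
        env = gym.make(**kwargs)
        return DummyVecEnv([lambda: env])

    vec_env = create_vec_env({"id": "CartPole-v1"})
    obs = vec_env.reset()
    # The VecEnv API returns batched arrays, hence the done[0]/info[0] indexing.
    obs, rewards, dones, infos = vec_env.step([vec_env.action_space.sample()])
    print(dones[0], infos[0])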
@@ -71,6 +94,7 @@ class GCEnvProperty(EnvProperty):
     def problem_list_to_str_tuple(self, problems):
         return "goal_conditioned"
 
+
 class MinigridProperty(EnvProperty):
     def __init__(self, name):
         super().__init__(name)
@@ -87,10 +111,10 @@ class MinigridProperty(EnvProperty):
 
     def gc_adaptable(self):
         return False
-
+
     def problem_list_to_str_tuple(self, problems):
         return "_".join([f"[{s.split('-')[-2]}]" for s in problems])
-
+
     def is_action_discrete(self):
         return True
 
@@ -98,34 +122,62 @@
         return True
 
     def get_lstm_props(self):
-        return LSTMProperties(batch_size=16, input_size=4, hidden_size=8, num_samples=40000)
-
+        return LSTMProperties(
+            batch_size=16, input_size=4, hidden_size=8, num_samples=40000
+        )
+
     def create_sequence_image(self, sequence, img_path, problem_name):
-        if not os.path.exists(os.path.dirname(img_path)): os.makedirs(os.path.dirname(img_path))
-        env_id = problem_name.split("-DynamicGoal-")[0] + "-DynamicGoal-" + problem_name.split("-DynamicGoal-")[1]
+        if not os.path.exists(os.path.dirname(img_path)):
+            os.makedirs(os.path.dirname(img_path))
+        env_id = (
+            problem_name.split("-DynamicGoal-")[0]
+            + "-DynamicGoal-"
+            + problem_name.split("-DynamicGoal-")[1]
+        )
         result = register(
             id=env_id,
             entry_point="gr_envs.minigrid_scripts.envs:CustomColorEnv",
-            kwargs={"size": 13 if 'Simple' in problem_name else 9,
-                    "num_crossings": 4 if 'Simple' in problem_name else 3,
-                    "goal_pos": self.str_to_goal(problem_name),
-                    "obstacle_type": Wall if 'Simple' in problem_name else Lava,
-                    "start_pos": (1, 1) if 'Simple' in problem_name else (3, 1),
-                    "plan": sequence},
+            kwargs={
+                "size": 13 if "Simple" in problem_name else 9,
+                "num_crossings": 4 if "Simple" in problem_name else 3,
+                "goal_pos": self.str_to_goal(problem_name),
+                "obstacle_type": Wall if "Simple" in problem_name else Lava,
+                "start_pos": (1, 1) if "Simple" in problem_name else (3, 1),
+                "plan": sequence,
+            },
         )
-        #print(result)
-        env = gymnasium.make(id=env_id)
-        env = RGBImgPartialObsWrapper(env)  # Get pixel observations
-        env = ImgObsWrapper(env)  # Get rid of the 'mission' field
-        obs, _ = env.reset()  # This now produces an RGB tensor only
+        # print(result)
+        env = gym.make(id=env_id)
+        env = RGBImgPartialObsWrapper(env)  # Get pixel observations
+        env = ImgObsWrapper(env)  # Get rid of the 'mission' field
+        obs, _ = env.reset()  # This now produces an RGB tensor only
 
         img = env.unwrapped.get_frame()
 
         ####### save image to file
-        image_pil = Image.fromarray(np.uint8(img)).convert('RGB')
+        image_pil = Image.fromarray(np.uint8(img)).convert("RGB")
         image_pil.save(r"{}.png".format(img_path))
 
-
+    def change_done_by_specific_desired(self, obs, desired, old_success_done):
+        assert (
+            desired is None
+        ), "In MinigridProperty, giving a specific 'desired' is not supported."
+        return old_success_done
+
+    def is_done(self, done):
+        assert isinstance(done, np.ndarray)
+        return done[0]
+
+    # Not used currently since TabularQLearner doesn't need is_success from the environment
+    def is_success(self, info):
+        raise NotImplementedError("no other option for any of the environments.")
+
+    def change_goal_to_specific_desired(self, obs, desired):
+        assert (
+            desired is None
+        ), "In MinigridProperty, giving a specific 'desired' is not supported."
+
+
 class PandaProperty(GCEnvProperty):
     def __init__(self, name):
         super().__init__(name)
@@ -133,25 +185,30 @@ class PandaProperty(GCEnvProperty):
 
     def str_to_goal(self, problem_name):
         try:
-            numeric_part = problem_name.split('PandaMyReachDenseX')[1]
-            components = [component.replace('-v3', '').replace('y', '.').replace('M', '-') for component in numeric_part.split('X')]
+            numeric_part = problem_name.split("PandaMyReachDenseX")[1]
+            components = [
+                component.replace("-v3", "").replace("y", ".").replace("M", "-")
+                for component in numeric_part.split("X")
+            ]
             floats = []
             for component in components:
                 floats.append(float(component))
             return np.array([floats], dtype=np.float32)
         except Exception as e:
             return "general"
-
+
     def goal_to_problem_str(self, goal):
-        goal_str = 'X'.join([str(float(g)).replace(".", "y").replace("-","M") for g in goal[0]])
+        goal_str = "X".join(
+            [str(float(g)).replace(".", "y").replace("-", "M") for g in goal[0]]
+        )
         return f"PandaMyReachDenseX{goal_str}-v3"
 
     def gc_adaptable(self):
         return True
-
+
     def use_goal_directed_problem(self):
         return False
-
+
     def is_action_discrete(self):
         return False
 
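PandaProperty encodes goal coordinates into the problem id with X as the separator, y standing in for the decimal point and M for a minus sign, so goal_to_problem_str and str_to_goal are inverses. A worked round-trip under that reading:

    import numpy as np

    goal = np.array([[-0.1, 0.15, 0.1]])

    # goal_to_problem_str: -0.1 -> "M0y1", 0.15 -> "0y15", 0.1 -> "0y1"
    goal_str = "X".join(str(float(g)).replace(".", "y").replace("-", "M") for g in goal[0])
    problem_name = f"PandaMyReachDenseX{goal_str}-v3"
    print(problem_name)  # PandaMyReachDenseXM0y1X0y15X0y1-v3

    # str_to_goal: strip the version suffix and undo the substitutions
    numeric_part = problem_name.split("PandaMyReachDenseX")[1]
    components = [
        c.replace("-v3", "").replace("y", ".").replace("M", "-")
        for c in numeric_part.split("X")
    ]
    recovered = np.array([[float(c) for c in components]], dtype=np.float32)
    assert np.allclose(recovered, goal)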
@@ -159,14 +216,43 @@
         return False
 
     def get_lstm_props(self):
-        return LSTMProperties(batch_size=32, input_size=9, hidden_size=8, num_samples=20000)
-
+        return LSTMProperties(
+            batch_size=32, input_size=9, hidden_size=8, num_samples=20000
+        )
+
     def sample_goal():
         goal_range_low = np.array([-0.40, -0.40, 0.10])
         goal_range_high = np.array([0.2, 0.2, 0.10])
         return np.random.uniform(goal_range_low, goal_range_high)
 
-
+    def change_done_by_specific_desired(self, obs, desired, old_success_done):
+        if desired is None:
+            return old_success_done
+        assert isinstance(
+            desired, np.ndarray
+        ), f"Unsupported type for desired: {type(desired)}"
+        if desired.size > 0 and not np.isnan(desired).all():
+            assert (
+                obs["achieved_goal"].shape == desired.shape
+            ), f"Shape mismatch: {obs['achieved_goal'].shape} vs {desired.shape}"
+            d = np.linalg.norm(obs["achieved_goal"] - desired, axis=-1)
+            return (d < 0.04)[0]
+        else:
+            return old_success_done
+
+    def is_done(self, done):
+        assert isinstance(done, np.ndarray)
+        return done[0]
+
+    def is_success(self, info):
+        assert "is_success" in info[0].keys()
+        return info[0]["is_success"]
+
+    def change_goal_to_specific_desired(self, obs, desired):
+        if desired is not None:
+            obs["desired_goal"] = desired
+
+
 class ParkingProperty(GCEnvProperty):
 
     def __init__(self, name):
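For Panda, change_done_by_specific_desired overrides the environment's own success signal with a Euclidean distance test: the step counts as done once the achieved goal lies within 0.04 of the injected desired goal. A standalone sketch of that check with illustrative coordinates:

    import numpy as np

    # Batched observation, as produced by the DummyVecEnv wrapper (batch size 1).
    obs = {"achieved_goal": np.array([[0.10, 0.02, 0.10]])}
    desired = np.array([[0.11, 0.00, 0.10]])

    # Same test as in the diff: L2 distance under the 0.04 threshold.
    d = np.linalg.norm(obs["achieved_goal"] - desired, axis=-1)
    done = (d < 0.04)[0]
    print(d, done)  # distance ~0.0224 -> True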
@@ -178,18 +264,39 @@ class ParkingProperty(GCEnvProperty):
 
     def gc_adaptable(self):
         return True
-
+
     def is_action_discrete(self):
         return False
 
     def is_state_discrete(self):
         return False
-
+
     def use_goal_directed_problem(self):
         return True
-
+
     def get_lstm_props(self):
-        return LSTMProperties(batch_size=32, input_size=8, hidden_size=8, num_samples=20000)
+        return LSTMProperties(
+            batch_size=32, input_size=8, hidden_size=8, num_samples=20000
+        )
+
+    def change_done_by_specific_desired(self, obs, desired, old_success_done):
+        assert (
+            desired is None
+        ), "In ParkingProperty, giving a specific 'desired' is not supported."
+        return old_success_done
+
+    def is_done(self, done):
+        assert isinstance(done, np.ndarray)
+        return done[0]
+
+    def is_success(self, info):
+        assert "is_success" in info[0].keys()
+        return info[0]["is_success"]
+
+    def change_goal_to_specific_desired(self, obs, desired):
+        assert (
+            desired is None
+        ), "In ParkingProperty, giving a specific 'desired' is not supported."
 
 
 class PointMazeProperty(EnvProperty):
@@ -205,7 +312,7 @@
         # Extract width and height from the goal part
        width, height = goal_part.split("x")
         return (int(width), int(height))
-
+
     def gc_adaptable(self):
         return False
 
@@ -217,9 +324,30 @@
 
     def is_state_discrete(self):
         return False
-
+
     def get_lstm_props(self):
-        return LSTMProperties(batch_size=32, input_size=6, hidden_size=8, num_samples=20000)
+        return LSTMProperties(
+            batch_size=32, input_size=6, hidden_size=8, num_samples=20000
+        )
 
     def goal_to_problem_str(self, goal):
         return self.name + f"-Goal-{goal[0]}x{goal[1]}"
+
+    def change_done_by_specific_desired(self, obs, desired, old_success_done):
+        assert (
+            desired is None
+        ), "In PointMazeProperty, giving a specific 'desired' is not supported."
+        return old_success_done
+
+    def is_done(self, done):
+        assert isinstance(done, np.ndarray)
+        return done[0]
+
+    def is_success(self, info):
+        assert "success" in info[0].keys()
+        return info[0]["success"]
+
+    def change_goal_to_specific_desired(self, obs, desired):
+        assert (
+            desired is None
+        ), "In ParkingProperty, giving a specific 'desired' is not supported."