gr-libs 0.1.6.post1__tar.gz → 0.1.7.post0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/PKG-INFO +1 -1
  2. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/__init__.py +5 -1
  3. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/_version.py +2 -2
  4. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/environment/__init__.py +1 -1
  5. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/ml/tabular/tabular_q_learner.py +1 -1
  6. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/ml/utils/storage.py +7 -0
  7. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/recognizer/graml/graml_recognizer.py +21 -12
  8. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs.egg-info/PKG-INFO +1 -1
  9. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs.egg-info/SOURCES.txt +0 -10
  10. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs.egg-info/top_level.txt +1 -0
  11. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/tutorials/graml_parking_tutorial.py +1 -0
  12. gr_libs-0.1.6.post1/.github/workflows/common_test_steps.yml +0 -26
  13. gr_libs-0.1.6.post1/.github/workflows/pr_flow.yml +0 -10
  14. gr_libs-0.1.6.post1/.github/workflows/release.yml +0 -33
  15. gr_libs-0.1.6.post1/.gitignore +0 -160
  16. gr_libs-0.1.6.post1/CI/README.md +0 -12
  17. gr_libs-0.1.6.post1/CI/docker_build_context/Dockerfile +0 -15
  18. gr_libs-0.1.6.post1/all_experiments.py +0 -194
  19. gr_libs-0.1.6.post1/download_dataset.py +0 -19
  20. gr_libs-0.1.6.post1/gr_libs/recognizer/recognizer_doc.md +0 -61
  21. gr_libs-0.1.6.post1/odgr_executor.py +0 -125
  22. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/README.md +0 -0
  23. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/evaluation/analyze_results_cross_alg_cross_domain.py +0 -0
  24. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/evaluation/create_minigrid_map_image.py +0 -0
  25. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/evaluation/file_system.py +0 -0
  26. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/evaluation/generate_experiments_results.py +0 -0
  27. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/evaluation/generate_experiments_results_new_ver1.py +0 -0
  28. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/evaluation/generate_experiments_results_new_ver2.py +0 -0
  29. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/evaluation/generate_task_specific_statistics_plots.py +0 -0
  30. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/evaluation/get_plans_images.py +0 -0
  31. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/evaluation/increasing_and_decreasing_.py +0 -0
  32. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/environment/environment.py +0 -0
  33. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/environment/utils/__init__.py +0 -0
  34. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/environment/utils/utils.py +0 -0
  35. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/metrics/__init__.py +0 -0
  36. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/metrics/metrics.py +0 -0
  37. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/ml/__init__.py +0 -0
  38. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/ml/agent.py +0 -0
  39. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/ml/base/__init__.py +0 -0
  40. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/ml/base/rl_agent.py +0 -0
  41. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/ml/consts.py +0 -0
  42. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/ml/neural/__init__.py +0 -0
  43. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/ml/neural/deep_rl_learner.py +0 -0
  44. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/ml/neural/utils/__init__.py +0 -0
  45. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/ml/neural/utils/dictlist.py +0 -0
  46. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/ml/neural/utils/penv.py +0 -0
  47. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/ml/planner/__init__.py +0 -0
  48. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/ml/planner/mcts/__init__.py +0 -0
  49. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/ml/planner/mcts/mcts_model.py +0 -0
  50. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/ml/planner/mcts/utils/__init__.py +0 -0
  51. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/ml/planner/mcts/utils/node.py +0 -0
  52. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/ml/planner/mcts/utils/tree.py +0 -0
  53. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/ml/sequential/__init__.py +0 -0
  54. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/ml/sequential/lstm_model.py +0 -0
  55. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/ml/tabular/__init__.py +0 -0
  56. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/ml/tabular/state.py +0 -0
  57. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/ml/tabular/tabular_rl_agent.py +0 -0
  58. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/ml/utils/__init__.py +0 -0
  59. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/ml/utils/env.py +0 -0
  60. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/ml/utils/format.py +0 -0
  61. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/ml/utils/math.py +0 -0
  62. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/ml/utils/other.py +0 -0
  63. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/problems/__init__.py +0 -0
  64. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/problems/consts.py +0 -0
  65. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/recognizer/__init__.py +0 -0
  66. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/recognizer/gr_as_rl/__init__.py +0 -0
  67. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +0 -0
  68. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/recognizer/graml/__init__.py +0 -0
  69. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/recognizer/graml/gr_dataset.py +0 -0
  70. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/recognizer/recognizer.py +0 -0
  71. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/recognizer/utils/__init__.py +0 -0
  72. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs/recognizer/utils/format.py +0 -0
  73. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs.egg-info/dependency_links.txt +0 -0
  74. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/gr_libs.egg-info/requires.txt +0 -0
  75. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/pyproject.toml +0 -0
  76. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/setup.cfg +0 -0
  77. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/tests/test_graml.py +0 -0
  78. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/tests/test_graql.py +0 -0
  79. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/tutorials/graml_minigrid_tutorial.py +0 -0
  80. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/tutorials/graml_panda_tutorial.py +0 -0
  81. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/tutorials/graml_point_maze_tutorial.py +0 -0
  82. {gr_libs-0.1.6.post1 → gr_libs-0.1.7.post0}/tutorials/graql_minigrid_tutorial.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: gr_libs
- Version: 0.1.6.post1
+ Version: 0.1.7.post0
  Summary: Package with goal recognition frameworks baselines
  Author: Ben Nageris
  Author-email: Matan Shamir <matan.shamir@live.biu.ac.il>, Osher Elhadad <osher.elhadad@live.biu.ac.il>
gr_libs/__init__.py
@@ -1,2 +1,6 @@
  from gr_libs.recognizer.graml.graml_recognizer import ExpertBasedGraml, GCGraml
- from gr_libs.recognizer.gr_as_rl.gr_as_rl_recognizer import Graql
+ from gr_libs.recognizer.gr_as_rl.gr_as_rl_recognizer import Graql
+ try:
+     from ._version import version as __version__
+ except ImportError:
+     __version__ = "0.0.0" # fallback if file isn't present
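
With this change, the package exposes its version at runtime. A minimal usage sketch (only `gr_libs.__version__` is taken from the diff above; the printed values are illustrative):

```python
# Read the version exposed by gr_libs/__init__.py; "0.0.0" means the
# generated _version.py was absent at import time (the fallback branch above).
import gr_libs

print(gr_libs.__version__)  # e.g. "0.1.7.post0", or "0.0.0" via the fallback
```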
gr_libs/_version.py
@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '0.1.6.post1'
- __version_tuple__ = version_tuple = (0, 1, 6)
+ __version__ = version = '0.1.7.post0'
+ __version_tuple__ = version_tuple = (0, 1, 7, 'post0')
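
Note that the version tuple now carries the post-release segment. A hypothetical compatibility gate against it (the feature/version association is inferred from this diff only):

```python
# Hypothetical: gate behavior on the installed gr_libs version tuple.
from gr_libs._version import version_tuple

if version_tuple[:3] >= (0, 1, 7):
    print("running a release with the save_fig-aware adaptation phase")
```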
gr_libs/environment/__init__.py
@@ -12,7 +12,7 @@ def is_extra_installed(package: str, extra: str) -> bool:
      return False # The package is not installed

  # Check if `gr_libs[minigrid]` was installed
- for env in ["minigrid", "panda", "highway", "point_maze"]:
+ for env in ["minigrid", "panda", "highway", "maze"]:
      if is_extra_installed("gr_libs", f"gr_envs[{env}]"):
          try:
              importlib.import_module(f"gr_envs.{env}_scripts.envs")
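
Only the tail of `is_extra_installed` is visible in this hunk. For orientation, a helper with this shape could be sketched with `importlib.metadata` as below; this is an assumption-laden illustration, not the package's actual implementation:

```python
# Illustrative sketch only: check that a requirement such as "gr_envs[minigrid]"
# is declared by the package and that its backing distribution is installed.
from importlib.metadata import PackageNotFoundError, distribution, requires

def is_extra_installed(package: str, extra: str) -> bool:
    try:
        declared = requires(package) or []  # e.g. 'gr_envs[minigrid]; extra == "minigrid"'
        if not any(extra in req for req in declared):
            return False
        dist_name = extra.split("[", 1)[0]  # 'gr_envs[minigrid]' -> 'gr_envs'
        distribution(dist_name)  # raises if the dependency is missing
        return True
    except PackageNotFoundError:
        return False  # The package is not installed
```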
gr_libs/ml/tabular/tabular_q_learner.py
@@ -351,7 +351,7 @@ class TabularQLearner(TabularRLAgent):
      def simplify_observation(self, observation):
          return [(obs['direction'], agent_pos_x, agent_pos_y, action) for ((obs, (agent_pos_x, agent_pos_y)), action) in observation] # list of tuples, each tuple the sample

-     def generate_observation(self, action_selection_method: MethodType, random_optimalism, save_fig = False, fig_path: str=None, env_prop=None):
+     def generate_observation(self, action_selection_method: MethodType, random_optimalism, save_fig=False, fig_path: str=None, env_prop=None):
          """
          Generate a single observation given a list of agents

gr_libs/ml/utils/storage.py
@@ -15,6 +15,13 @@ def get_storage_framework_dir(recognizer: str):
      return os.path.join(get_storage_dir(),recognizer)

  def get_storage_dir():
+     # Prefer local directory if it exists (e.g., in GitHub workspace)
+     if os.path.exists("dataset"):
+         return "dataset"
+     # Fall back to pre-mounted directory (e.g., in Docker container)
+     if os.path.exists("/preloaded_data"):
+         return "/preloaded_data"
+     # Default to "dataset" even if it doesn't exist (e.g., will be created)
      return "dataset"

  def _get_models_directory_name():
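
The new lookup order is: a local `dataset` directory first (e.g., a GitHub workspace), then `/preloaded_data` (e.g., a Docker mount), then `dataset` as the default. A quick sanity check of which location wins on a given machine (trivial snippet, assuming the package is importable):

```python
# Prints the directory the new resolution order selects on this machine.
from gr_libs.ml.utils.storage import get_storage_dir

print(get_storage_dir())  # "dataset" or "/preloaded_data"
```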
gr_libs/recognizer/graml/graml_recognizer.py
@@ -82,14 +82,14 @@ class Graml(LearningRecognizer):
              dev_loader=DataLoader(dev_dataset, batch_size=self.env_prop.get_lstm_props().batch_size, shuffle=False, collate_fn=self.collate_func))
          save_weights(model=self.model, path=self.model_file_path)

-     def goals_adaptation_phase(self, dynamic_goals: List[EnvProperty]):
+     def goals_adaptation_phase(self, dynamic_goals: List[EnvProperty], save_fig=False):
          self.is_first_inf_since_new_goals = True
          self.current_goals = dynamic_goals
          # start by training each rl agent on the base goal set
          self.embeddings_dict = {} # relevant if the embedding of the plan occurs during the goals adaptation phase
          self.plans_dict = {} # relevant if the embedding of the plan occurs during the inference phase
          for goal in self.current_goals:
-             obss = self.generate_sequences_library(goal)
+             obss = self.generate_sequences_library(goal, save_fig=save_fig)
              self.plans_dict[str(goal)] = obss

      def get_goal_plan(self, goal):
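
Figure generation during goal adaptation is now opt-in. A hedged usage sketch; `recognizer` and `dynamic_goals` stand in for a configured Graml instance and its goal list, which this diff does not show being constructed:

```python
# Illustrative only: save_fig defaults to False, so existing callers keep
# their behavior; passing True restores the old always-save-figures path.
recognizer.goals_adaptation_phase(dynamic_goals=dynamic_goals)
recognizer.goals_adaptation_phase(dynamic_goals=dynamic_goals, save_fig=True)
```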
@@ -150,7 +150,7 @@ class Graml(LearningRecognizer):
          return closest_goal

      @abstractmethod
-     def generate_sequences_library(self, goal: str) -> List[List[Tuple[np.ndarray, np.ndarray]]]:
+     def generate_sequences_library(self, goal: str, save_fig=False) -> List[List[Tuple[np.ndarray, np.ndarray]]]:
          pass

      # this function duplicates every sequence and creates a consecutive and non-consecutive version of it
@@ -192,10 +192,10 @@ class MCTSBasedGraml(BGGraml, GaAdaptingRecognizer):
          super().__init__(*args, **kwargs)
          if self.rl_agent_type==None: self.rl_agent_type = TabularQLearner

-     def generate_sequences_library(self, goal: str) -> List[List[Tuple[np.ndarray, np.ndarray]]]:
+     def generate_sequences_library(self, goal: str, save_fig=False) -> List[List[Tuple[np.ndarray, np.ndarray]]]:
          problem_name = self.env_prop.goal_to_problem_str(goal)
          img_path = os.path.join(get_policy_sequences_result_path(self.env_prop.domain_name, recognizer=self.__class__.__name__), problem_name + "_MCTS")
-         return mcts_model.plan(self.env_prop.name, problem_name, goal, save_fig=True, img_path=img_path, env_prop=self.env_prop)
+         return mcts_model.plan(self.env_prop.name, problem_name, goal, save_fig=save_fig, img_path=img_path, env_prop=self.env_prop)

  class ExpertBasedGraml(BGGraml, GaAgentTrainerRecognizer):
      def __init__(self, *args, **kwargs):
@@ -206,15 +206,23 @@ class ExpertBasedGraml(BGGraml, GaAgentTrainerRecognizer):
          else:
              self.rl_agent_type = DeepRLAgent

-     def generate_sequences_library(self, goal: str) -> List[List[Tuple[np.ndarray, np.ndarray]]]:
+     def generate_sequences_library(self, goal: str, save_fig=False) -> List[List[Tuple[np.ndarray, np.ndarray]]]:
          problem_name = self.env_prop.goal_to_problem_str(goal)
          kwargs = {"domain_name":self.domain_name, "problem_name":problem_name}
          if self.dynamic_train_configs_dict[problem_name][0] != None: kwargs["algorithm"] = self.dynamic_train_configs_dict[problem_name][0]
          if self.dynamic_train_configs_dict[problem_name][1] != None: kwargs["num_timesteps"] = self.dynamic_train_configs_dict[problem_name][1]
          agent = self.rl_agent_type(**kwargs)
          agent.learn()
-         fig_path = get_and_create(f"{os.path.abspath(os.path.join(get_policy_sequences_result_path(domain_name=self.env_prop.domain_name, env_name=self.env_prop.name, recognizer=self.__class__.__name__), problem_name))}_bg_sequence")
-         return [agent.generate_observation(action_selection_method=metrics.greedy_selection, random_optimalism=False, save_fig=True, fig_path=fig_path, env_prop=self.env_prop)]
+         agent_kwargs = {
+             "action_selection_method": metrics.greedy_selection,
+             "random_optimalism": False,
+             "save_fig": save_fig,
+             "env_prop": self.env_prop
+         }
+         if save_fig:
+             fig_path = get_and_create(f"{os.path.abspath(os.path.join(get_policy_sequences_result_path(domain_name=self.env_prop.domain_name, env_name=self.env_prop.name, recognizer=self.__class__.__name__), problem_name))}_bg_sequence")
+             agent_kwargs["fig_path"] = fig_path
+         return [agent.generate_observation(**agent_kwargs)]

      def goals_adaptation_phase(self, dynamic_goals: List[str], dynamic_train_configs):
          self.dynamic_goals_problems = [self.env_prop.goal_to_problem_str(g) for g in dynamic_goals]
@@ -244,20 +252,21 @@ class GCGraml(Graml, GaAdaptingRecognizer):
          gc_agent.learn()
          self.agents.append(ContextualAgent(problem_name=self.env_prop.name, problem_goal="general", agent=gc_agent))

-     def generate_sequences_library(self, goal: str) -> List[List[Tuple[np.ndarray, np.ndarray]]]:
+     def generate_sequences_library(self, goal: str, save_fig=False) -> List[List[Tuple[np.ndarray, np.ndarray]]]:
          problem_name = self.env_prop.goal_to_problem_str(goal)
          kwargs = {"domain_name":self.domain_name, "problem_name":self.env_prop.name} # problem name is env name in gc case
          if self.original_train_configs[0][0] != None: kwargs["algorithm"] = self.original_train_configs[0][0]
          if self.original_train_configs[0][1] != None: kwargs["num_timesteps"] = self.original_train_configs[0][1]
          agent = self.rl_agent_type(**kwargs)
          agent.learn()
-         fig_path = get_and_create(f"{os.path.abspath(os.path.join(get_policy_sequences_result_path(domain_name=self.env_prop.domain_name, env_name=self.env_prop.name, recognizer=self.__class__.__name__), problem_name))}_gc_sequence")
          agent_kwargs = {
              "action_selection_method": metrics.stochastic_amplified_selection,
              "random_optimalism": True,
-             "save_fig": True,
-             "fig_path": fig_path
+             "save_fig": save_fig
          }
+         if save_fig:
+             fig_path = get_and_create(f"{os.path.abspath(os.path.join(get_policy_sequences_result_path(domain_name=self.env_prop.domain_name, env_name=self.env_prop.name, recognizer=self.__class__.__name__), problem_name))}_gc_sequence")
+             agent_kwargs["fig_path"] = fig_path
          if self.env_prop.use_goal_directed_problem(): agent_kwargs["goal_directed_problem"] = problem_name
          else: agent_kwargs["goal_directed_goal"] = goal
          obss = []
gr_libs.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: gr_libs
- Version: 0.1.6.post1
+ Version: 0.1.7.post0
  Summary: Package with goal recognition frameworks baselines
  Author: Ben Nageris
  Author-email: Matan Shamir <matan.shamir@live.biu.ac.il>, Osher Elhadad <osher.elhadad@live.biu.ac.il>
gr_libs.egg-info/SOURCES.txt
@@ -1,14 +1,5 @@
- .gitignore
  README.md
- all_experiments.py
- download_dataset.py
- odgr_executor.py
  pyproject.toml
- .github/workflows/common_test_steps.yml
- .github/workflows/pr_flow.yml
- .github/workflows/release.yml
- CI/README.md
- CI/docker_build_context/Dockerfile
  evaluation/analyze_results_cross_alg_cross_domain.py
  evaluation/create_minigrid_map_image.py
  evaluation/file_system.py
@@ -63,7 +54,6 @@ gr_libs/problems/__init__.py
  gr_libs/problems/consts.py
  gr_libs/recognizer/__init__.py
  gr_libs/recognizer/recognizer.py
- gr_libs/recognizer/recognizer_doc.md
  gr_libs/recognizer/gr_as_rl/__init__.py
  gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py
  gr_libs/recognizer/graml/__init__.py
gr_libs.egg-info/top_level.txt
@@ -1,4 +1,5 @@
  CI
+ build
  dist
  evaluation
  gr_libs
tutorials/graml_parking_tutorial.py
@@ -5,6 +5,7 @@ from gr_libs.metrics.metrics import stochastic_amplified_selection
  from gr_libs.ml.neural.deep_rl_learner import DeepRLAgent, GCDeepRLAgent
  from gr_libs.ml.utils.format import random_subset_with_order
  from gr_libs.recognizer.graml.graml_recognizer import ExpertBasedGraml, GCGraml
+ import gr_libs.environment.environment

  def run_graml_parking_tutorial():
      recognizer = GCGraml(
.github/workflows/common_test_steps.yml (deleted)
@@ -1,26 +0,0 @@
- name: Common Test Steps
-
- on:
-   workflow_call:
-
- jobs:
-   test_steps:
-     runs-on: ubuntu-latest
-     container:
-       image: ghcr.io/matanshamir1/gr_test_base_slim:latest
-     steps:
-       - name: Check out the repository
-         uses: actions/checkout@v4
-
-       - name: Install gr_libs with all extras and test tools
-         env:
-           SETUPTOOLS_SCM_PRETEND_VERSION_FOR_GR_LIBS: "0.0.0"
-         run: |
-           python -m pip install --upgrade pip
-           pip install setuptools_scm
-           pip install gr_envs[minigrid,panda,parking,maze]
-           pip install .[minigrid,panda,parking,maze]
-           pip install pytest
-
-       - name: Run tests
-         run: pytest tests/
.github/workflows/pr_flow.yml (deleted)
@@ -1,10 +0,0 @@
- name: PR Test Flow
-
- on:
-   pull_request:
-     branches:
-       - main # or whichever branch you're targeting for PRs
-
- jobs:
-   run_tests:
-     uses: ./.github/workflows/common_test_steps.yml
.github/workflows/release.yml (deleted)
@@ -1,33 +0,0 @@
- name: Publish to PyPI
-
- on:
-   push:
-     tags:
-       - "v*"
-
- jobs:
-   release:
-     runs-on: ubuntu-latest
-     steps:
-       # from here to remov when returning uses: ./.github/workflows/common_test_steps.yml
-       - name: Checkout code
-         uses: actions/checkout@v4
-
-       - name: Set up Python
-         uses: actions/setup-python@v5
-         with:
-           python-version: "3.11"
-
-       - name: Install build tools
-         run: |
-           python -m pip install --upgrade pip
-           pip install build twine
-       # until here!
-       - name: Build the package
-         run: python -m build
-
-       - name: Publish to PyPI
-         env:
-           TWINE_USERNAME: __token__
-           TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
-         run: python -m twine upload dist/*
.gitignore (deleted)
@@ -1,160 +0,0 @@
- # Byte-compiled / optimized / DLL files
- __pycache__/
- *.py[cod]
- *$py.class
-
- # C extensions
- *.so
-
- # Distribution / packaging
- .Python
- build/
- develop-eggs/
- dist/
- downloads/
- eggs/
- .eggs/
- lib/
- lib64/
- parts/
- sdist/
- var/
- wheels/
- share/python-wheels/
- *.egg-info/
- .installed.cfg
- *.egg
- MANIFEST
-
- # PyInstaller
- # Usually these files are written by a python script from a template
- # before PyInstaller builds the exe, so as to inject date/other infos into it.
- *.manifest
- *.spec
-
- # Installer logs
- pip-log.txt
- pip-delete-this-directory.txt
-
- # Unit test / coverage reports
- htmlcov/
- .tox/
- .nox/
- .coverage
- .coverage.*
- .cache
- nosetests.xml
- coverage.xml
- *.cover
- *.py,cover
- .hypothesis/
- .pytest_cache/
- cover/
-
- # Translations
- *.mo
- *.pot
-
- # Django stuff:
- *.log
- local_settings.py
- db.sqlite3
- db.sqlite3-journal
-
- # Flask stuff:
- instance/
- .webassets-cache
-
- # Scrapy stuff:
- .scrapy
-
- # Sphinx documentation
- docs/_build/
-
- # PyBuilder
- .pybuilder/
- target/
-
- # Jupyter Notebook
- .ipynb_checkpoints
-
- # IPython
- profile_default/
- ipython_config.py
-
- # pyenv
- # For a library or package, you might want to ignore these files since the code is
- # intended to run in multiple environments; otherwise, check them in:
- # .python-version
-
- # pipenv
- # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
- # However, in case of collaboration, if having platform-specific dependencies or dependencies
- # having no cross-platform support, pipenv may install dependencies that don't work, or not
- # install all needed dependencies.
- #Pipfile.lock
-
- # poetry
- # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
- # This is especially recommended for binary packages to ensure reproducibility, and is more
- # commonly ignored for libraries.
- # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
- #poetry.lock
-
- # pdm
- # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
- #pdm.lock
- # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
- # in version control.
- # https://pdm.fming.dev/#use-with-ide
- .pdm.toml
-
- # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
- __pypackages__/
-
- # Celery stuff
- celerybeat-schedule
- celerybeat.pid
-
- # SageMath parsed files
- *.sage.py
-
- # Environments
- .env
- .venv
- env/
- venv/
- ENV/
- env.bak/
- venv.bak/
-
- # Spyder project settings
- .spyderproject
- .spyproject
-
- # Rope project settings
- .ropeproject
-
- # mkdocs documentation
- /site
-
- # mypy
- .mypy_cache/
- .dmypy.json
- dmypy.json
-
- # Pyre type checker
- .pyre/
-
- # pytype static type analyzer
- .pytype/
-
- # Cython debug symbols
- cython_debug/
-
- # PyCharm
- # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
- # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
- # and can be added to the global gitignore or merged into this file. For a more nuclear
- # option (not recommended) you can uncomment the following to ignore the entire idea folder.
- #.idea/
CI/README.md (deleted)
@@ -1,12 +0,0 @@
- ## How to build a new docker image including new trained agents:
- 1. Install docker
- 2. Make sure you have a dataset.zip at your repo root
- 3. Make sure you have a classic token in github: https://github.com/settings/tokens . If you don't, create one with package write, read and delete permissions and copy it somewhere safe.
- 4. Authenticate to ghcr with docker by running:
- ```sh
- echo ghp_REST_OF_TOKEN | docker login ghcr.io -u MatanShamir1 --password-stdin
- ```
- 3. docker build -t ghcr.io/<your-username>/gr_test_base:latest -f CI/Dockerfile .
- (the -f Dockerfile tells docker which Dockerfile to use and the '.' tells docker what's the build context, or where the dataset.zip should live)
- 4. docker push ghcr.io/<your-username>/gr_test_base:latest
- docker push ghcr.io/MatanShamir1/gr_test_base:latest
CI/docker_build_context/Dockerfile (deleted)
@@ -1,15 +0,0 @@
- FROM python:3.11-slim
-
- # Set workdir
- WORKDIR /app
-
- # Install unzip
- RUN apt-get update && apt-get install -y unzip && rm -rf /var/lib/apt/lists/*
-
- # Copy and unzip the dataset
- COPY dataset.zip .
- RUN unzip dataset.zip && rm dataset.zip
- RUN mv dataset_new dataset
-
- # Just start with bash by default
- CMD [ "bash" ]
all_experiments.py (deleted)
@@ -1,194 +0,0 @@
- import os
- import sys
- import threading
- import dill
- import subprocess
- import concurrent.futures
- import numpy as np
-
- from gr_libs.ml.utils.storage import get_experiment_results_path
-
- # Define the lists
- # domains = ['minigrid', 'point_maze', 'parking', 'panda']
- # envs = {
- #     'minigrid': ['obstacles', 'lava_crossing'],
- #     'point_maze': ['four_rooms', 'lava_crossing'],
- #     'parking': ['gc_agent', 'gd_agent'],
- #     'panda': ['gc_agent', 'gd_agent']
- # }
- # tasks = {
- #     'minigrid': ['L111', 'L222', 'L333', 'L444', 'L555'],
- #     'point_maze': ['L111', 'L222', 'L333', 'L444', 'L555'],
- #     'parking': ['L111', 'L222', 'L333', 'L444', 'L555'],
- #     'panda': ['L111', 'L222', 'L333', 'L444', 'L555']
- # }
- configs = {
-     'minigrid': {
-         'MiniGrid-SimpleCrossingS13N4': ['L1', 'L2', 'L3', 'L4', 'L5'],
-         'MiniGrid-LavaCrossingS9N2': ['L1', 'L2', 'L3', 'L4', 'L5']
-     }
-     # 'point_maze': {
-     #     'PointMaze-FourRoomsEnvDense-11x11': ['L1', 'L2', 'L3', 'L4', 'L5'],
-     #     'PointMaze-ObstaclesEnvDense-11x11': ['L1', 'L2', 'L3', 'L4', 'L5']
-     # }
-     # 'parking': {
-     #     'Parking-S-14-PC-': ['L1', 'L2', 'L3', 'L4', 'L5'],
-     #     'Parking-S-14-PC-': ['L1', 'L2', 'L3', 'L4', 'L5']
-     # }
-     # 'panda': {
-     #     'PandaMyReachDense': ['L1', 'L2', 'L3', 'L4', 'L5'],
-     #     'PandaMyReachDense': ['L1', 'L2', 'L3', 'L4', 'L5']
-     # }
- }
- # for minigrid: #TODO assert these instead i the beggingning of the code before beginning with the actual threading
- recognizers = ['ExpertBasedGraml']
- # recognizers = ['Graql']
-
- # for point_maze:
- # recognizers = ['ExpertBasedGraml']
- # recognizers = ['Draco']
-
- # for parking:
- # recognizers = ['GCGraml']
- # recognizers = ['GCDraco']
-
- # for panda:
- # recognizers = ['GCGraml']
- # recognizers = ['GCDraco']
-
- n = 5 # Number of times to execute each task
-
- # Function to read results from the result file
- def read_results(res_file_path):
-     with open(res_file_path, 'rb') as f:
-         results = dill.load(f)
-     return results
-
- # Every thread worker executes this function.
- def run_experiment(domain, env, task, recognizer, i, generate_new=False):
-     cmd = f"python odgr_executor.py --domain {domain} --recognizer {recognizer} --env_name {env} --task {task} --collect_stats"
-     print(f"Starting execution: {cmd}")
-     try:
-         res_file_path = get_experiment_results_path(domain, env, task, recognizer)
-         res_file_path_txt = res_file_path + '.txt'
-         i_res_file_path_txt = res_file_path + f'_{i}.txt'
-         res_file_path_pkl = res_file_path + '.pkl'
-         i_res_file_path_pkl = res_file_path + f'_{i}.pkl'
-         if generate_new or (not os.path.exists(i_res_file_path_txt) or not os.path.exists(i_res_file_path_pkl)):
-             if os.path.exists(i_res_file_path_txt) or os.path.exists(i_res_file_path_pkl):
-                 i_res_file_path_txt = i_res_file_path_txt.replace(f'_{i}', f'_{i}_new')
-                 i_res_file_path_pkl = i_res_file_path_pkl.replace(f'_{i}', f'_{i}_new')
-             # every thread in the current process starts a new process which executes the command. the current thread waits for the subprocess to finish.
-             process = subprocess.Popen(cmd, shell=True)
-             process.wait()
-             # result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
-             if process.returncode != 0:
-                 print(f"Execution failed: {cmd}")
-                 print(f"Error: {result.stderr}")
-                 return None
-             else:
-                 print(f"Finished execution successfully: {cmd}")
-                 file_lock = threading.Lock()
-                 with file_lock:
-                     os.rename(res_file_path_pkl, i_res_file_path_pkl)
-                     os.rename(res_file_path_txt, i_res_file_path_txt)
-         else:
-             print(f"File {i_res_file_path_txt} already exists. Skipping execution of {cmd}")
-         return ((domain, env, task, recognizer), read_results(i_res_file_path_pkl))
-     except Exception as e:
-         print(f"Exception occurred while running experiment: {e}")
-         return None
-
- # Collect results
- results = {}
-
- # create an executor that manages a pool of threads.
- # Note that any failure in the threads will not stop the main thread from continuing and vice versa, nor will the debugger view the failure if in debug mode.
- # Use prints and if any thread's printing stops suspect failure. If failure happened, use breakpoints before failure and use the watch to see the failure by pasting the problematic piece of code.
- with concurrent.futures.ThreadPoolExecutor() as executor:
-     futures = []
-     for domain, envs in configs.items():
-         for env, tasks in envs.items():
-             for task in tasks:
-                 for recognizer in recognizers:
-                     for i in range(n):
-                         # submit returns a future object that represents the execution of the function by some thread in the pool. the method is added to the list and the executor will run it as soon as it can.
-                         futures.append(executor.submit(run_experiment, domain, env, task, recognizer, i, generate_new=True if len(sys.argv) > 1 and sys.argv[1] == '--generate_new' else False))
-
-     # main thread continues execution after the loop above. Here, it waits for all threads. every time a thread finishes, the main thread reads the results from the future object.
-     # If debugging main thread, note the executor will stop creating new threads and running them since it exists and runs in the main thread.
-     # probably, the main thread gets interrupted by the return of a future and knows to start execution of new threads and then continue main thread execution
-     for future in concurrent.futures.as_completed(futures):
-         # the objects returned by the 'result' func are tuples with key being the args inserted to 'submit'.
-         if future.result() is None:
-             print(f"for future {future}, future.result() is None. Continuing to next future.")
-             continue
-         key, result = future.result()
-         print(f"main thread reading results from future {key}")
-         # list because every experiment is executed n times.
-         if key not in results:
-             results[key] = []
-         results[key].append(result)
-
- # Calculate average accuracy and standard deviation for each percentage
- detailed_summary = {}
- compiled_accuracies = {}
- for key, result_list in results.items():
-     domain, env, task, recognizer = key
-     percentages = result_list[0].keys()
-     detailed_summary[key] = {}
-     if (domain, recognizer) not in compiled_accuracies:
-         compiled_accuracies[(domain, recognizer)] = {}
-     for percentage in percentages:
-         if percentage == 'total':
-             continue
-         if percentage not in compiled_accuracies[(domain, recognizer)].keys(): compiled_accuracies[(domain, recognizer)][percentage] = {}
-         if percentage not in detailed_summary[key].keys(): detailed_summary[key][percentage] = {}
-         consecutive_accuracies = [result[percentage]['consecutive']['accuracy'] for result in result_list] # accuracies in all different n executions
-         non_consecutive_accuracies = [result[percentage]['non_consecutive']['accuracy'] for result in result_list]
-         if 'consecutive' in compiled_accuracies[(domain, recognizer)][percentage].keys():
-             compiled_accuracies[(domain, recognizer)][percentage]['consecutive'].extend(consecutive_accuracies)
-         else:
-             compiled_accuracies[(domain, recognizer)][percentage]['consecutive'] = consecutive_accuracies
-         if 'non_consecutive' in compiled_accuracies[(domain, recognizer)][percentage].keys():
-             compiled_accuracies[(domain, recognizer)][percentage]['non_consecutive'].extend(non_consecutive_accuracies)
-         else:
-             compiled_accuracies[(domain, recognizer)][percentage]['non_consecutive'] = non_consecutive_accuracies
-         avg_consecutive_accuracy = np.mean(consecutive_accuracies)
-         consecutive_std_dev = np.std(consecutive_accuracies)
-         detailed_summary[key][percentage]['consecutive'] = (avg_consecutive_accuracy, consecutive_std_dev)
-         avg_non_consecutive_accuracy = np.mean(non_consecutive_accuracies)
-         non_consecutive_std_dev = np.std(non_consecutive_accuracies)
-         detailed_summary[key][percentage]['non_consecutive'] = (avg_non_consecutive_accuracy, non_consecutive_std_dev)
-
- compiled_summary = {}
- for key, percentage_dict in compiled_accuracies.items():
-     compiled_summary[key] = {}
-     for percentage, cons_accuracies in percentage_dict.items():
-         compiled_summary[key][percentage] = {}
-         for is_cons, accuracies in cons_accuracies.items():
-             avg_accuracy = np.mean(accuracies)
-             std_dev = np.std(accuracies)
-             compiled_summary[key][percentage][is_cons] = (avg_accuracy, std_dev)
-
- # Write different summary results to different files
- detailed_summary_file_path = os.path.join('summaries',f"detailed_summary_{''.join(configs.keys())}_{recognizers[0]}.txt")
- compiled_summary_file_path = os.path.join('summaries',f"compiled_summary_{''.join(configs.keys())}_{recognizers[0]}.txt")
- with open(detailed_summary_file_path, 'w') as f:
-     for key, percentage_dict in detailed_summary.items():
-         domain, env, task, recognizer = key
-         f.write(f"{domain}\t{env}\t{task}\t{recognizer}\n")
-         for percentage, cons_info in percentage_dict.items():
-             for is_cons, (avg_accuracy, std_dev) in cons_info.items():
-                 f.write(f"\t\t{percentage}\t{is_cons}\t{avg_accuracy:.4f}\t{std_dev:.4f}\n")
-
- with open(compiled_summary_file_path, 'w') as f:
-     for key, percentage_dict in compiled_summary.items():
-         domain, recognizer = key
-         f.write(f"{domain}\t{recognizer}\n")
-         for percentage, cons_info in percentage_dict.items():
-             for is_cons, (avg_accuracy, std_dev) in cons_info.items():
-                 f.write(f"\t\t{percentage}\t{is_cons}\t{avg_accuracy:.4f}\t{std_dev:.4f}\n")
-
- print(f"Detailed summary results written to {detailed_summary_file_path}")
- print(f"Compiled summary results written to {compiled_summary_file_path}")
download_dataset.py (deleted)
@@ -1,19 +0,0 @@
- import requests
- import zipfile
- import os
-
- def download_and_extract_dataset(google_drive_url, extract_to):
-     os.makedirs(extract_to, exist_ok=True)
-     download_url = google_drive_url + "&export=download"
-     response = requests.get(download_url)
-     response.raise_for_status()
-     with open('dataset.zip', 'wb') as f:
-         f.write(response.content)
-     with zipfile.ZipFile('dataset.zip', 'r') as zip_ref:
-         zip_ref.extractall(extract_to)
-     os.remove('dataset.zip')
-
- if __name__ == "__main__":
-     google_drive_url = "https://drive.google.com/file/d/1PK1iZONTyiQZBgLErUO88p1YWdL4B9Xn/view?usp=sharing"
-     extract_to = "dataset"
-     download_and_extract_dataset(google_drive_url, extract_to)
gr_libs/recognizer/recognizer_doc.md (deleted)
@@ -1,61 +0,0 @@
- # Recognizer Module Documentation
-
- This document provides an overview of the recognizer module, including its class hierarchy and instructions for adding a new class of recognizer.
-
- ## Class Hierarchy
-
- The recognizer module consists of an abstract base class `Recognizer` and several derived classes, each implementing specific behaviors. The main classes are:
-
- 1. **Recognizer (Abstract Base Class)**
-    - `inference_phase()` (abstract method)
-
- 2. **LearningRecognizer (Extends Recognizer)**
-    - `domain_learning_phase()`
-
- 3. **GaAgentTrainerRecognizer (Extends Recognizer)**
-    - `goals_adaptation_phase()` (abstract method)
-    - `domain_learning_phase()`
-
- 4. **GaAdaptingRecognizer (Extends Recognizer)**
-    - `goals_adaptation_phase()` (abstract method)
-
- 5. **GRAsRL (Extends Recognizer)**
-    - Implements `goals_adaptation_phase()`
-    - Implements `inference_phase()`
-
- 6. **Specific Implementations:**
-    - `Graql (Extends GRAsRL, GaAgentTrainerRecognizer)`
-    - `Draco (Extends GRAsRL, GaAgentTrainerRecognizer)`
-    - `GCDraco (Extends GRAsRL, LearningRecognizer, GaAdaptingRecognizer)`
-    - `Graml (Extends LearningRecognizer)`
-
- ## How to Add a New Recognizer Class
-
- To add a new class of recognizer, follow these steps:
-
- 1. **Determine the Type of Recognizer:**
-    - Will it require learning? Extend `LearningRecognizer`.
-    - Will it adapt goals dynamically? Extend `GaAdaptingRecognizer`.
-    - Will it train agents for new goals? Extend `GaAgentTrainerRecognizer`.
-    - Will it involve RL-based recognition? Extend `GRAsRL`.
-
- 2. **Define the Class:**
-    - Create a new class that extends the appropriate base class(es).
-    - Implement the required abstract methods (`inference_phase()`, `goals_adaptation_phase()`, etc.).
-
- 3. **Initialize the Recognizer:**
-    - Ensure proper initialization by calling `super().__init__(*args, **kwargs)`.
-    - Set up any necessary agent storage or evaluation functions.
-
- 4. **Implement Core Methods:**
-    - Define how the recognizer processes inference sequences.
-    - Implement learning or goal adaptation logic if applicable.
-
- 5. **Register the Recognizer:**
-    - Ensure it integrates properly with the existing system by using the correct `domain_to_env_property()`.
-
- 6. **Test the New Recognizer:**
-    - Run experiments to validate its behavior.
-    - Compare results against existing recognizers to ensure correctness.
-
- By following these steps, you can seamlessly integrate a new recognizer into the framework while maintaining compatibility with the existing structure.
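
The deleted document's steps translate into a skeleton along these lines. This is a hedged sketch: the base class, `super().__init__`, and the phase-method names come from the documentation and imports seen elsewhere in this diff, while `MyRecognizer` and its bodies are invented for illustration:

```python
# Hypothetical skeleton following the deleted doc's steps 1-4.
from gr_libs.recognizer.recognizer import LearningRecognizer

class MyRecognizer(LearningRecognizer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)  # step 3: proper initialization

    def domain_learning_phase(self, base_goals, train_configs):
        super().domain_learning_phase(base_goals, train_configs)
        # step 4: learning logic over the base goal set goes here

    def inference_phase(self, partial_sequence, true_goal, percentage):
        # step 4: map an observed partial sequence to the closest known goal
        raise NotImplementedError
```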
odgr_executor.py (deleted)
@@ -1,125 +0,0 @@
- import argparse
- import os
- import time
- import dill
-
- from gr_libs.environment.utils.utils import domain_to_env_property
- from gr_libs.ml.neural.deep_rl_learner import DeepRLAgent
- from gr_libs.recognizer.gr_as_rl.gr_as_rl_recognizer import Draco, GCDraco, Graql
- from gr_libs.recognizer import Graml
- from gr_libs.metrics.metrics import stochastic_amplified_selection
- from gr_libs.ml.utils.format import random_subset_with_order
- from gr_libs.recognizer.recognizer import GaAgentTrainerRecognizer, LearningRecognizer
- from gr_libs.recognizer.utils import recognizer_str_to_obj
- from gr_libs.ml.utils.storage import create_folders_if_necessary, get_and_create, get_experiment_results_path, get_policy_sequences_result_path
-
- from gr_libs.problems.consts import PROBLEMS
-
- def validate(args, recognizer_type, task_inputs):
-     if "base" in task_inputs.keys():
-         #assert issubclass(recognizer_type, LearningRecognizer), f"base is in the task_inputs for the recognizer {args.recognizer}, which doesn't have a domain learning phase (is not a learning recognizer)."
-         assert list(task_inputs.keys())[0] == "base", "In case of LearningRecognizer, base should be the first element in the task_inputs dict in consts.py"
-         assert "base" not in list(task_inputs.keys())[1:], "In case of LearningRecognizer, base should be only in the first element in the task_inputs dict in consts.py"
-     #else:
-         #assert not issubclass(recognizer_type, LearningRecognizer), f"base is not in the task_inputs for the recognizer {args.recognizer}, which has a domain learning phase (is a learning recognizer). Remove it from the task_inputs dict in consts.py."
-
- def run_odgr_problem(args):
-     recognizer_type = recognizer_str_to_obj(args.recognizer)
-     env_inputs = PROBLEMS[args.domain]
-     assert args.env_name in env_inputs.keys(), f"env_name {args.env_name} is not in the list of available environments for the domain {args.domain}. Add it to PROBLEMS dict in consts.py"
-     task_inputs = env_inputs[args.env_name][args.task]
-     recognizer = recognizer_type(domain_name=args.domain, env_name=args.env_name, collect_statistics=args.collect_stats)
-     validate(args, recognizer_type, task_inputs)
-     ga_times, results = [], {}
-     for key, value in task_inputs.items():
-         if key == "base":
-             dlp_time = 0
-             if issubclass(recognizer_type, LearningRecognizer):
-                 start_dlp_time = time.time()
-                 recognizer.domain_learning_phase(base_goals=value["goals"], train_configs=value["train_configs"])
-                 dlp_time = time.time() - start_dlp_time
-         elif key.startswith("G_"):
-             start_ga_time = time.time()
-             kwargs = {"dynamic_goals": value["goals"]}
-             if issubclass(recognizer_type, GaAgentTrainerRecognizer): kwargs["dynamic_train_configs"] = value["train_configs"]
-             recognizer.goals_adaptation_phase(**kwargs)
-             ga_times.append(time.time() - start_ga_time)
-         elif key.startswith("I_"):
-             goal, train_config, consecutive, consecutive_str, percentage = value["goal"], value["train_config"], value["consecutive"], "consecutive" if value["consecutive"] == True else "non_consecutive", value["percentage"]
-             results.setdefault(str(percentage), {})
-             results[str(percentage)].setdefault(consecutive_str, {'correct': 0, 'num_of_tasks': 0, 'accuracy': 0, 'average_inference_time': 0})
-             property_type = domain_to_env_property(args.domain)
-             env_property = property_type(args.env_name)
-             problem_name = env_property.goal_to_problem_str(goal)
-             rl_agent_type = recognizer.rl_agent_type
-             agent = rl_agent_type(domain_name=args.domain, problem_name=problem_name, algorithm=train_config[0], num_timesteps=train_config[1])
-             agent.learn()
-             fig_path = get_and_create(f"{os.path.abspath(os.path.join(get_policy_sequences_result_path(domain_name=args.domain, env_name=args.env_name, recognizer=args.recognizer), problem_name))}_inference_seq")
-             generate_obs_kwargs = {"action_selection_method": stochastic_amplified_selection,
-                                    "save_fig": True,
-                                    "random_optimalism": True,
-                                    "fig_path": fig_path,
-                                    "env_prop": env_property}
-
-             # need to dump the whole plan for draco because it needs it for inference phase for checking likelihood.
-             if (recognizer_type == Draco or recognizer_type == GCDraco) and issubclass(rl_agent_type, DeepRLAgent): # TODO remove this condition, remove the assumption.
-                 generate_obs_kwargs["with_dict"] = True
-             sequence = agent.generate_observation(**generate_obs_kwargs)
-             if issubclass(recognizer_type, Graml): # need to dump the plans to compute offline plan similarity only in graml's case for evaluation.
-                 recognizer.dump_plans(true_sequence=sequence, true_goal=goal, percentage=percentage)
-             partial_sequence = random_subset_with_order(sequence, (int)(percentage * len(sequence)), is_consecutive=consecutive)
-             # add evaluation_function to kwargs if this is graql. move everything to kwargs...
-             start_inf_time = time.time()
-             closest_goal = recognizer.inference_phase(partial_sequence, goal, percentage)
-             results[str(percentage)][consecutive_str]['average_inference_time'] += time.time() - start_inf_time
-             # print(f'real goal {goal}, closest goal is: {closest_goal}')
-             if all(a == b for a, b in zip(str(goal), closest_goal)):
-                 results[str(percentage)][consecutive_str]['correct'] += 1
-             results[str(percentage)][consecutive_str]['num_of_tasks'] += 1
-
-     for percentage in results.keys():
-         for consecutive_str in results[str(percentage)].keys():
-             results[str(percentage)][consecutive_str]['average_inference_time'] /= len(results[str(percentage)][consecutive_str])
-             results[str(percentage)][consecutive_str]['accuracy'] = results[str(percentage)][consecutive_str]['correct'] / results[str(percentage)][consecutive_str]['num_of_tasks']
-
-     # aggregate
-     total_correct = sum([result['correct'] for cons_result in results.values() for result in cons_result.values()])
-     total_tasks = sum([result['num_of_tasks'] for cons_result in results.values() for result in cons_result.values()])
-     total_average_inference_time = sum([result['average_inference_time'] for cons_result in results.values() for result in cons_result.values()]) / total_tasks
-
-     results['total'] = {'total_correct': total_correct, 'total_tasks': total_tasks, "total_accuracy": total_correct/total_tasks, 'total_average_inference_time': total_average_inference_time, 'goals_adaptation_time': sum(ga_times)/len(ga_times), 'domain_learning_time': dlp_time}
-     print(str(results))
-     res_file_path = get_and_create(get_experiment_results_path(domain=args.domain, env_name=args.env_name, task=args.task, recognizer=args.recognizer))
-     print(f"generating results into {res_file_path}")
-     with open(f'{res_file_path}.pkl', 'wb') as results_file:
-         dill.dump(results, results_file)
-     with open(f'{res_file_path}.txt', 'w') as results_file:
-         results_file.write(str(results))
-
-
- def parse_args():
-     parser = argparse.ArgumentParser(
-         description="Parse command-line arguments for the RL experiment.",
-         formatter_class=argparse.RawTextHelpFormatter
-     )
-
-     # Required arguments
-     required_group = parser.add_argument_group("Required arguments")
-     required_group.add_argument("--domain", choices=["point_maze", "minigrid", "parking", "panda"], required=True, help="Domain name (point_maze, minigrid, parking, or panda)")
-     required_group.add_argument("--env_name", required=True, help="Env name (point_maze, minigrid, parking, or panda). For example, Parking-S-14-PC--v0")
-     required_group.add_argument("--recognizer", choices=["MCTSBasedGraml", "ExpertBasedGraml", "GCGraml", "Graql", "Draco", "GCDraco"], required=True, help="Recognizer type. Follow readme.md and recognizer folder for more information and rules.")
-     required_group.add_argument("--task", choices=["L1", "L2", "L3", "L4", "L5", "L11", "L22", "L33", "L44", "L55", "L111", "L222", "L333", "L444", "L555"], required=True, help="Task identifier (e.g., L1, L2,...,L5)")
-
-     # Optional arguments
-     optional_group = parser.add_argument_group("Optional arguments")
-     optional_group.add_argument("--collect_stats", action="store_true", help="Whether to collect statistics")
-     args = parser.parse_args()
-
-     ### VALIDATE INPUTS ###
-     # Assert that all required arguments are provided
-     assert args.domain is not None and args.recognizer is not None and args.task is not None, "Missing required arguments: domain, recognizer, or task"
-     return args
-
- if __name__ == "__main__":
-     args = parse_args()
-     run_odgr_problem(args)