gr-libs 0.1.8__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (73)
  1. gr_libs/__init__.py +3 -1
  2. gr_libs/_version.py +2 -2
  3. gr_libs/all_experiments.py +260 -0
  4. gr_libs/environment/__init__.py +14 -1
  5. gr_libs/environment/_utils/__init__.py +0 -0
  6. gr_libs/environment/{utils → _utils}/utils.py +1 -1
  7. gr_libs/environment/environment.py +278 -23
  8. gr_libs/evaluation/__init__.py +1 -0
  9. gr_libs/evaluation/generate_experiments_results.py +100 -0
  10. gr_libs/metrics/__init__.py +2 -0
  11. gr_libs/metrics/metrics.py +166 -31
  12. gr_libs/ml/__init__.py +1 -6
  13. gr_libs/ml/base/__init__.py +3 -1
  14. gr_libs/ml/base/rl_agent.py +68 -3
  15. gr_libs/ml/neural/__init__.py +1 -3
  16. gr_libs/ml/neural/deep_rl_learner.py +241 -84
  17. gr_libs/ml/neural/utils/__init__.py +1 -2
  18. gr_libs/ml/planner/mcts/{utils → _utils}/tree.py +1 -1
  19. gr_libs/ml/planner/mcts/mcts_model.py +71 -34
  20. gr_libs/ml/sequential/__init__.py +0 -1
  21. gr_libs/ml/sequential/{lstm_model.py → _lstm_model.py} +11 -14
  22. gr_libs/ml/tabular/__init__.py +1 -3
  23. gr_libs/ml/tabular/tabular_q_learner.py +27 -9
  24. gr_libs/ml/tabular/tabular_rl_agent.py +22 -9
  25. gr_libs/ml/utils/__init__.py +2 -9
  26. gr_libs/ml/utils/format.py +13 -90
  27. gr_libs/ml/utils/math.py +3 -2
  28. gr_libs/ml/utils/other.py +2 -2
  29. gr_libs/ml/utils/storage.py +41 -94
  30. gr_libs/odgr_executor.py +263 -0
  31. gr_libs/problems/consts.py +570 -292
  32. gr_libs/recognizer/{utils → _utils}/format.py +2 -2
  33. gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +127 -36
  34. gr_libs/recognizer/graml/{gr_dataset.py → _gr_dataset.py} +11 -11
  35. gr_libs/recognizer/graml/graml_recognizer.py +186 -35
  36. gr_libs/recognizer/recognizer.py +59 -10
  37. gr_libs/tutorials/draco_panda_tutorial.py +58 -0
  38. gr_libs/tutorials/draco_parking_tutorial.py +56 -0
  39. {tutorials → gr_libs/tutorials}/gcdraco_panda_tutorial.py +11 -11
  40. {tutorials → gr_libs/tutorials}/gcdraco_parking_tutorial.py +6 -8
  41. {tutorials → gr_libs/tutorials}/graml_minigrid_tutorial.py +18 -14
  42. {tutorials → gr_libs/tutorials}/graml_panda_tutorial.py +11 -12
  43. {tutorials → gr_libs/tutorials}/graml_parking_tutorial.py +8 -10
  44. {tutorials → gr_libs/tutorials}/graml_point_maze_tutorial.py +17 -3
  45. {tutorials → gr_libs/tutorials}/graql_minigrid_tutorial.py +2 -2
  46. {gr_libs-0.1.8.dist-info → gr_libs-0.2.5.dist-info}/METADATA +95 -29
  47. gr_libs-0.2.5.dist-info/RECORD +72 -0
  48. {gr_libs-0.1.8.dist-info → gr_libs-0.2.5.dist-info}/WHEEL +1 -1
  49. gr_libs-0.2.5.dist-info/top_level.txt +2 -0
  50. tests/test_draco.py +14 -0
  51. tests/test_gcdraco.py +2 -2
  52. tests/test_graml.py +4 -4
  53. tests/test_graql.py +1 -1
  54. tests/test_odgr_executor_expertbasedgraml.py +14 -0
  55. tests/test_odgr_executor_gcdraco.py +14 -0
  56. tests/test_odgr_executor_gcgraml.py +14 -0
  57. tests/test_odgr_executor_graql.py +14 -0
  58. evaluation/analyze_results_cross_alg_cross_domain.py +0 -267
  59. evaluation/create_minigrid_map_image.py +0 -38
  60. evaluation/file_system.py +0 -53
  61. evaluation/generate_experiments_results.py +0 -141
  62. evaluation/generate_experiments_results_new_ver1.py +0 -238
  63. evaluation/generate_experiments_results_new_ver2.py +0 -331
  64. evaluation/generate_task_specific_statistics_plots.py +0 -500
  65. evaluation/get_plans_images.py +0 -62
  66. evaluation/increasing_and_decreasing_.py +0 -104
  67. gr_libs/ml/neural/utils/penv.py +0 -60
  68. gr_libs-0.1.8.dist-info/RECORD +0 -70
  69. gr_libs-0.1.8.dist-info/top_level.txt +0 -4
  70. /gr_libs/{environment/utils/__init__.py → _evaluation/_generate_experiments_results.py} +0 -0
  71. /gr_libs/ml/planner/mcts/{utils → _utils}/__init__.py +0 -0
  72. /gr_libs/ml/planner/mcts/{utils → _utils}/node.py +0 -0
  73. /gr_libs/recognizer/{utils → _utils}/__init__.py +0 -0
gr_libs/ml/utils/storage.py (+41 -94)
@@ -1,10 +1,4 @@
-import csv
 import os
-import torch
-import logging
-import sys
-
-from .other import device
 
 
 def create_folders_if_necessary(path):
@@ -12,23 +6,34 @@ def create_folders_if_necessary(path):
         os.makedirs(path)
 
 
-def get_storage_framework_dir(recognizer: str):
-    return os.path.join(get_storage_dir(), recognizer)
+def get_outputs_dir():
+    return "outputs"
+
+
+def get_recognizer_outputs_dir(recognizer: str):
+    return os.path.join(get_outputs_dir(), recognizer)
 
 
-def get_storage_dir():
+def get_gr_cache_dir():
     # Prefer local directory if it exists (e.g., in GitHub workspace)
-    if os.path.exists("dataset"):
-        return "dataset"
+    if os.path.exists("gr_cache"):
+        return "gr_cache"
     # Fall back to pre-mounted directory (e.g., in Docker container)
-    if os.path.exists("/preloaded_data"):
-        return "/preloaded_data"
+    if os.path.exists("/gr_cache"):
+        return "/gr_cache"
     # Default to "dataset" even if it doesn't exist (e.g., will be created)
-    return "dataset"
+    return "gr_cache"
 
 
-def _get_models_directory_name():
-    return "models"
+def get_trained_agents_dir():
+    # Prefer local directory if it exists (e.g., in GitHub workspace)
+    if os.path.exists("trained_agents"):
+        return "trained_agents"
+    # Fall back to pre-mounted directory (e.g., in Docker container)
+    if os.path.exists("/trained_agents"):
+        return "/trained_agents"
+    # Default to "dataset" even if it doesn't exist (e.g., will be created)
+    return "trained_agents"
 
 
 def _get_siamese_datasets_directory_name():
@@ -43,25 +48,26 @@ def get_observation_file_name(observability_percentage: float):
     return "obs" + str(observability_percentage) + ".pkl"
 
 
-def get_domain_dir(domain_name, recognizer: str):
-    return os.path.join(get_storage_framework_dir(recognizer), domain_name)
+def get_domain_outputs_dir(domain_name, recognizer: str):
+    return os.path.join(get_recognizer_outputs_dir(recognizer), domain_name)
 
 
-def get_env_dir(domain_name, env_name, recognizer: str):
-    return os.path.join(get_domain_dir(domain_name, recognizer), env_name)
+def get_env_outputs_dir(domain_name, env_name, recognizer: str):
+    return os.path.join(get_domain_outputs_dir(domain_name, recognizer), env_name)
 
 
 def get_observations_dir(domain_name, env_name, recognizer: str):
     return os.path.join(
-        get_env_dir(domain_name=domain_name, env_name=env_name, recognizer=recognizer),
+        get_env_outputs_dir(
+            domain_name=domain_name, env_name=env_name, recognizer=recognizer
+        ),
         _get_observations_directory_name(),
     )
 
 
 def get_agent_model_dir(domain_name, model_name, class_name):
     return os.path.join(
-        get_storage_dir(),
-        _get_models_directory_name(),
+        get_trained_agents_dir(),
         domain_name,
         model_name,
         class_name,
@@ -70,15 +76,7 @@ def get_agent_model_dir(domain_name, model_name, class_name):
 
 def get_lstm_model_dir(domain_name, env_name, model_name, recognizer: str):
     return os.path.join(
-        get_env_dir(domain_name=domain_name, env_name=env_name, recognizer=recognizer),
-        model_name,
-    )
-
-
-def get_models_dir(domain_name, env_name, recognizer: str):
-    return os.path.join(
-        get_env_dir(domain_name=domain_name, env_name=env_name, recognizer=recognizer),
-        _get_models_directory_name(),
+        get_gr_cache_dir(), recognizer, domain_name, env_name, model_name
     )
 
 
@@ -94,14 +92,7 @@ def get_siamese_dataset_path(domain_name, env_name, model_name, recognizer: str):
 
 def get_embeddings_result_path(domain_name, env_name, recognizer: str):
     return os.path.join(
-        get_env_dir(domain_name, env_name=env_name, recognizer=recognizer),
-        "goal_embeddings",
-    )
-
-
-def get_embeddings_result_path(domain_name, env_name, recognizer: str):
-    return os.path.join(
-        get_env_dir(domain_name, env_name=env_name, recognizer=recognizer),
+        get_env_outputs_dir(domain_name, env_name=env_name, recognizer=recognizer),
         "goal_embeddings",
     )
 
@@ -113,9 +104,7 @@ def get_and_create(path):
 
 def get_experiment_results_path(domain, env_name, task, recognizer: str):
     return os.path.join(
-        get_env_dir(domain, env_name=env_name, recognizer=recognizer),
-        "experiment_results",
-        env_name,
+        get_env_outputs_dir(domain, env_name=env_name, recognizer=recognizer),
         task,
         "experiment_results",
     )
@@ -123,72 +112,30 @@ def get_experiment_results_path(domain, env_name, task, recognizer: str):
 
 def get_plans_result_path(domain_name, env_name, recognizer: str):
     return os.path.join(
-        get_env_dir(domain_name, env_name=env_name, recognizer=recognizer), "plans"
+        get_env_outputs_dir(domain_name, env_name=env_name, recognizer=recognizer),
+        "plans",
    )
 
 
 def get_policy_sequences_result_path(domain_name, env_name, recognizer: str):
     return os.path.join(
-        get_env_dir(domain_name, env_name, recognizer=recognizer), "policy_sequences"
+        get_env_outputs_dir(domain_name, env_name, recognizer=recognizer),
+        "policy_sequences",
     )
 
 
 ### END GRAML PATHS ###
-""
+
 ### GRAQL PATHS ###
 
 
 def get_gr_as_rl_experiment_confidence_path(domain_name, env_name, recognizer: str):
     return os.path.join(
-        get_env_dir(domain_name=domain_name, env_name=env_name, recognizer=recognizer),
-        "experiments",
+        get_env_outputs_dir(
+            domain_name=domain_name, env_name=env_name, recognizer=recognizer
+        ),
+        "confidence",
     )
 
 
 ### GRAQL PATHS ###
-
-
-def get_status_path(model_dir):
-    return os.path.join(model_dir, "status.pt")
-
-
-def get_status(model_dir):
-    path = get_status_path(model_dir)
-    return torch.load(path, map_location=device)
-
-
-def save_status(status, model_dir):
-    path = get_status_path(model_dir)
-    utils.create_folders_if_necessary(path)
-    torch.save(status, path)
-
-
-def get_vocab(model_dir):
-    return get_status(model_dir)["vocab"]
-
-
-def get_model_state(model_dir):
-    return get_status(model_dir)["model_state"]
-
-
-def get_txt_logger(model_dir):
-    path = os.path.join(model_dir, "log.txt")
-    utils.create_folders_if_necessary(path)
-
-    logging.basicConfig(
-        level=logging.INFO,
-        format="%(message)s",
-        handlers=[
-            logging.FileHandler(filename=path),
-            logging.StreamHandler(sys.stdout),
-        ],
-    )
-
-    return logging.getLogger()
-
-
-def get_csv_logger(model_dir):
-    csv_path = os.path.join(model_dir, "log.csv")
-    utils.create_folders_if_necessary(csv_path)
-    csv_file = open(csv_path, "a")
-    return csv_file, csv.writer(csv_file)
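Note: the storage helpers above now split artifacts into three roots: experiment outputs under outputs/, cached goal-recognition artifacts under gr_cache/, and pre-trained agents under trained_agents/. A minimal sketch of how the renamed helpers compose, assuming gr_libs 0.2.5 is installed; the domain, environment, and model names below are placeholders, not values shipped with the package:

# Sketch only: shows the directory layout produced by the renamed helpers.
from gr_libs.ml.utils.storage import (
    get_agent_model_dir,
    get_experiment_results_path,
    get_lstm_model_dir,
)

# outputs/<recognizer>/<domain>/<env>/<task>/experiment_results
print(get_experiment_results_path("minigrid", "SomeEnv-v0", "L1", "ExpertBasedGraml"))

# gr_cache/<recognizer>/<domain>/<env>/<model>
print(get_lstm_model_dir("minigrid", "SomeEnv-v0", "lstm_model", "ExpertBasedGraml"))

# trained_agents/<domain>/<model>/<class>
print(get_agent_model_dir("minigrid", "SomeEnv-v0", "MyAgentClass"))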
gr_libs/odgr_executor.py (new file, +263 -0)
@@ -0,0 +1,263 @@
+import argparse
+import os
+import time
+
+import dill
+
+from gr_libs.environment._utils.utils import domain_to_env_property
+from gr_libs.metrics.metrics import stochastic_amplified_selection
+from gr_libs.ml.neural.deep_rl_learner import DeepRLAgent
+from gr_libs.ml.utils.format import random_subset_with_order
+from gr_libs.ml.utils.storage import (
+    get_and_create,
+    get_experiment_results_path,
+    get_policy_sequences_result_path,
+)
+from gr_libs.problems.consts import PROBLEMS
+from gr_libs.recognizer._utils import recognizer_str_to_obj
+from gr_libs.recognizer.gr_as_rl.gr_as_rl_recognizer import Draco, GCDraco
+from gr_libs.recognizer.graml.graml_recognizer import Graml
+from gr_libs.recognizer.recognizer import GaAgentTrainerRecognizer, LearningRecognizer
+
+
+def validate(args, recognizer_type, task_inputs):
+    if "base" in task_inputs.keys():
+        # assert issubclass(recognizer_type, LearningRecognizer), f"base is in the task_inputs for the recognizer {args.recognizer}, which doesn't have a domain learning phase (is not a learning recognizer)."
+        assert (
+            list(task_inputs.keys())[0] == "base"
+        ), "In case of LearningRecognizer, base should be the first element in the task_inputs dict in consts.py"
+        assert (
+            "base" not in list(task_inputs.keys())[1:]
+        ), "In case of LearningRecognizer, base should be only in the first element in the task_inputs dict in consts.py"
+    # else:
+    #     assert not issubclass(recognizer_type, LearningRecognizer), f"base is not in the task_inputs for the recognizer {args.recognizer}, which has a domain learning phase (is a learning recognizer). Remove it from the task_inputs dict in consts.py."
+
+
+def run_odgr_problem(args):
+    recognizer_type = recognizer_str_to_obj(args.recognizer)
+    env_inputs = PROBLEMS[args.domain]
+    assert (
+        args.env_name in env_inputs.keys()
+    ), f"env_name {args.env_name} is not in the list of available environments for the domain {args.domain}. Add it to PROBLEMS dict in consts.py"
+    task_inputs = env_inputs[args.env_name][args.task]
+    recognizer = recognizer_type(
+        domain_name=args.domain,
+        env_name=args.env_name,
+        collect_statistics=args.collect_stats,
+    )
+    validate(args, recognizer_type, task_inputs)
+    ga_times, results = [], {}
+    for key, value in task_inputs.items():
+        if key == "base":
+            dlp_time = 0
+            if issubclass(recognizer_type, LearningRecognizer):
+                start_dlp_time = time.time()
+                recognizer.domain_learning_phase(value)
+                dlp_time = time.time() - start_dlp_time
+        elif key.startswith("G_"):
+            start_ga_time = time.time()
+            kwargs = {"dynamic_goals": value["goals"]}
+            if issubclass(recognizer_type, GaAgentTrainerRecognizer):
+                kwargs["dynamic_train_configs"] = value["train_configs"]
+            recognizer.goals_adaptation_phase(**kwargs)
+            ga_times.append(time.time() - start_ga_time)
+        elif key.startswith("I_"):
+            goal, train_config, consecutive, consecutive_str, percentage = (
+                value["goal"],
+                value["train_config"],
+                value["consecutive"],
+                "consecutive" if value["consecutive"] == True else "non_consecutive",
+                value["percentage"],
+            )
+            results.setdefault(str(percentage), {})
+            results[str(percentage)].setdefault(
+                consecutive_str,
+                {
+                    "correct": 0,
+                    "num_of_tasks": 0,
+                    "accuracy": 0,
+                    "average_inference_time": 0,
+                },
+            )
+            property_type = domain_to_env_property(args.domain)
+            env_property = property_type(args.env_name)
+            problem_name = env_property.goal_to_problem_str(goal)
+            rl_agent_type = recognizer.rl_agent_type
+            agent = rl_agent_type(
+                domain_name=args.domain,
+                problem_name=problem_name,
+                algorithm=train_config[0],
+                num_timesteps=train_config[1],
+                env_prop=env_property,
+            )
+            agent.learn()
+            fig_path = get_and_create(
+                f"{os.path.abspath(os.path.join(get_policy_sequences_result_path(domain_name=args.domain, env_name=args.env_name, recognizer=args.recognizer), problem_name))}_inference_seq"
+            )
+            generate_obs_kwargs = {
+                "action_selection_method": stochastic_amplified_selection,
+                "save_fig": args.collect_stats,
+                "random_optimalism": True,
+                "fig_path": fig_path if args.collect_stats else None,
+            }
+
+            # need to dump the whole plan for draco because it needs it for inference phase for checking likelihood.
+            if (recognizer_type == Draco or recognizer_type == GCDraco) and issubclass(
+                rl_agent_type, DeepRLAgent
+            ):  # TODO remove this condition, remove the assumption.
+                generate_obs_kwargs["with_dict"] = True
+            sequence = agent.generate_observation(**generate_obs_kwargs)
+            if issubclass(
+                recognizer_type, Graml
+            ):  # need to dump the plans to compute offline plan similarity only in graml's case for evaluation.
+                recognizer.dump_plans(
+                    true_sequence=sequence, true_goal=goal, percentage=percentage
+                )
+            partial_sequence = random_subset_with_order(
+                sequence, (int)(percentage * len(sequence)), is_consecutive=consecutive
+            )
+            # add evaluation_function to kwargs if this is graql. move everything to kwargs...
+            start_inf_time = time.time()
+            closest_goal = recognizer.inference_phase(
+                partial_sequence, goal, percentage
+            )
+            results[str(percentage)][consecutive_str]["average_inference_time"] += (
+                time.time() - start_inf_time
+            )
+            # print(f'real goal {goal}, closest goal is: {closest_goal}')
+            if all(a == b for a, b in zip(str(goal), closest_goal)):
+                results[str(percentage)][consecutive_str]["correct"] += 1
+            results[str(percentage)][consecutive_str]["num_of_tasks"] += 1
+
+    for percentage in results.keys():
+        for consecutive_str in results[str(percentage)].keys():
+            results[str(percentage)][consecutive_str]["average_inference_time"] /= len(
+                results[str(percentage)][consecutive_str]
+            )
+            results[str(percentage)][consecutive_str]["accuracy"] = (
+                results[str(percentage)][consecutive_str]["correct"]
+                / results[str(percentage)][consecutive_str]["num_of_tasks"]
+            )
+
+    # aggregate
+    total_correct = sum(
+        [
+            result["correct"]
+            for cons_result in results.values()
+            for result in cons_result.values()
+        ]
+    )
+    total_tasks = sum(
+        [
+            result["num_of_tasks"]
+            for cons_result in results.values()
+            for result in cons_result.values()
+        ]
+    )
+    total_average_inference_time = (
+        sum(
+            [
+                result["average_inference_time"]
+                for cons_result in results.values()
+                for result in cons_result.values()
+            ]
+        )
+        / total_tasks
+    )
+
+    results["total"] = {
+        "total_correct": total_correct,
+        "total_tasks": total_tasks,
+        "total_accuracy": total_correct / total_tasks,
+        "total_average_inference_time": total_average_inference_time,
+        "goals_adaptation_time": sum(ga_times) / len(ga_times),
+        "domain_learning_time": dlp_time,
+    }
+    print(str(results))
+    res_file_path = get_and_create(
+        get_experiment_results_path(
+            domain=args.domain,
+            env_name=args.env_name,
+            task=args.task,
+            recognizer=args.recognizer,
+        )
+    )
+    if args.experiment_num is not None:
+        res_txt = os.path.join(res_file_path, f"res_{args.experiment_num}.txt")
+        res_pkl = os.path.join(res_file_path, f"res_{args.experiment_num}.pkl")
+    else:
+        res_txt = os.path.join(res_file_path, "res.txt")
+        res_pkl = os.path.join(res_file_path, "res.pkl")
+
+    print(f"generating results into {res_txt} and {res_pkl}")
+    with open(res_pkl, "wb") as results_file:
+        dill.dump(results, results_file)
+    with open(res_txt, "w") as results_file:
+        results_file.write(str(results))
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Parse command-line arguments for the RL experiment.",
+        formatter_class=argparse.RawTextHelpFormatter,
+    )
+
+    # Required arguments
+    required_group = parser.add_argument_group("Required arguments")
+    required_group.add_argument(
+        "--domain",
+        choices=["point_maze", "minigrid", "parking", "panda"],
+        required=True,
+        help="Domain name (point_maze, minigrid, parking, or panda)",
+    )
+    required_group.add_argument(
+        "--env_name",
+        required=True,
+        help="Env name (point_maze, minigrid, parking, or panda). For example, Parking-S-14-PC--v0",
+    )
+    required_group.add_argument(
+        "--recognizer",
+        choices=[
+            "MCTSBasedGraml",
+            "ExpertBasedGraml",
+            "GCGraml",
+            "Graql",
+            "Draco",
+            "GCDraco",
+        ],
+        required=True,
+        help="Recognizer type. Follow readme.md and recognizer folder for more information and rules.",
+    )
+    required_group.add_argument(
+        "--task",
+        choices=["L1", "L2", "L3", "L4", "L5"],
+        required=True,
+        help="Task identifier (e.g., L1, L2,...,L5)",
+    )
+
+    # Optional arguments
+    optional_group = parser.add_argument_group("Optional arguments")
+    optional_group.add_argument(
+        "--collect_stats", action="store_true", help="Whether to collect statistics"
+    )
+    optional_group.add_argument(
+        "--experiment_num",
+        type=int,
+        default=None,
+        help="Experiment number for parallel runs",
+    )
+    args = parser.parse_args()
+
+    ### VALIDATE INPUTS ###
+    # Assert that all required arguments are provided
+    assert (
+        args.domain is not None
+        and args.recognizer is not None
+        and args.task is not None
+    ), "Missing required arguments: domain, recognizer, or task"
+    return args
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    run_odgr_problem(args)
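The new gr_libs/odgr_executor.py is a CLI entry point (via parse_args) but also exposes run_odgr_problem for programmatic use. A minimal invocation sketch, assuming gr_libs 0.2.5 is installed and the chosen env_name and task exist under PROBLEMS in consts.py; the env key below is a placeholder:

# Roughly equivalent to:
#   python gr_libs/odgr_executor.py --domain minigrid --env_name <env> --recognizer Graql --task L1
from argparse import Namespace

from gr_libs.odgr_executor import run_odgr_problem

args = Namespace(
    domain="minigrid",              # one of: point_maze, minigrid, parking, panda
    env_name="SomeMiniGridEnv-v0",  # placeholder; must be a key of PROBLEMS["minigrid"]
    recognizer="Graql",             # any recognizer listed in the parse_args choices
    task="L1",                      # L1..L5
    collect_stats=False,
    experiment_num=None,
)
run_odgr_problem(args)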