gr-libs 0.1.8__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gr_libs/__init__.py +3 -1
- gr_libs/_evaluation/__init__.py +1 -0
- evaluation/analyze_results_cross_alg_cross_domain.py → gr_libs/_evaluation/_analyze_results_cross_alg_cross_domain.py +81 -88
- evaluation/generate_experiments_results.py → gr_libs/_evaluation/_generate_experiments_results.py +6 -6
- evaluation/generate_task_specific_statistics_plots.py → gr_libs/_evaluation/_generate_task_specific_statistics_plots.py +11 -14
- evaluation/get_plans_images.py → gr_libs/_evaluation/_get_plans_images.py +3 -4
- evaluation/increasing_and_decreasing_.py → gr_libs/_evaluation/_increasing_and_decreasing_.py +3 -1
- gr_libs/_version.py +2 -2
- gr_libs/all_experiments.py +294 -0
- gr_libs/environment/__init__.py +14 -1
- gr_libs/environment/{utils → _utils}/utils.py +1 -1
- gr_libs/environment/environment.py +257 -22
- gr_libs/metrics/__init__.py +2 -0
- gr_libs/metrics/metrics.py +166 -31
- gr_libs/ml/__init__.py +1 -6
- gr_libs/ml/base/__init__.py +3 -1
- gr_libs/ml/base/rl_agent.py +68 -3
- gr_libs/ml/neural/__init__.py +1 -3
- gr_libs/ml/neural/deep_rl_learner.py +227 -67
- gr_libs/ml/neural/utils/__init__.py +1 -2
- gr_libs/ml/planner/mcts/{utils → _utils}/tree.py +1 -1
- gr_libs/ml/planner/mcts/mcts_model.py +71 -34
- gr_libs/ml/sequential/__init__.py +0 -1
- gr_libs/ml/sequential/{lstm_model.py → _lstm_model.py} +11 -14
- gr_libs/ml/tabular/__init__.py +1 -3
- gr_libs/ml/tabular/tabular_q_learner.py +27 -9
- gr_libs/ml/tabular/tabular_rl_agent.py +22 -9
- gr_libs/ml/utils/__init__.py +2 -9
- gr_libs/ml/utils/format.py +13 -90
- gr_libs/ml/utils/math.py +3 -2
- gr_libs/ml/utils/other.py +2 -2
- gr_libs/ml/utils/storage.py +41 -94
- gr_libs/odgr_executor.py +268 -0
- gr_libs/problems/consts.py +2 -2
- gr_libs/recognizer/_utils/__init__.py +0 -0
- gr_libs/recognizer/{utils → _utils}/format.py +2 -2
- gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +116 -36
- gr_libs/recognizer/graml/{gr_dataset.py → _gr_dataset.py} +11 -11
- gr_libs/recognizer/graml/graml_recognizer.py +172 -29
- gr_libs/recognizer/recognizer.py +59 -10
- gr_libs/tutorials/draco_panda_tutorial.py +58 -0
- gr_libs/tutorials/draco_parking_tutorial.py +56 -0
- {tutorials → gr_libs/tutorials}/gcdraco_panda_tutorial.py +5 -9
- {tutorials → gr_libs/tutorials}/gcdraco_parking_tutorial.py +3 -7
- {tutorials → gr_libs/tutorials}/graml_minigrid_tutorial.py +2 -2
- {tutorials → gr_libs/tutorials}/graml_panda_tutorial.py +5 -10
- {tutorials → gr_libs/tutorials}/graml_parking_tutorial.py +5 -9
- {tutorials → gr_libs/tutorials}/graml_point_maze_tutorial.py +2 -1
- {tutorials → gr_libs/tutorials}/graql_minigrid_tutorial.py +2 -2
- {gr_libs-0.1.8.dist-info → gr_libs-0.2.2.dist-info}/METADATA +84 -29
- gr_libs-0.2.2.dist-info/RECORD +71 -0
- {gr_libs-0.1.8.dist-info → gr_libs-0.2.2.dist-info}/WHEEL +1 -1
- gr_libs-0.2.2.dist-info/top_level.txt +2 -0
- tests/test_draco.py +14 -0
- tests/test_gcdraco.py +2 -2
- tests/test_graml.py +4 -4
- tests/test_graql.py +1 -1
- evaluation/create_minigrid_map_image.py +0 -38
- evaluation/file_system.py +0 -53
- evaluation/generate_experiments_results_new_ver1.py +0 -238
- evaluation/generate_experiments_results_new_ver2.py +0 -331
- gr_libs/ml/neural/utils/penv.py +0 -60
- gr_libs/recognizer/utils/__init__.py +0 -1
- gr_libs-0.1.8.dist-info/RECORD +0 -70
- gr_libs-0.1.8.dist-info/top_level.txt +0 -4
- /gr_libs/environment/{utils → _utils}/__init__.py +0 -0
- /gr_libs/ml/planner/mcts/{utils → _utils}/__init__.py +0 -0
- /gr_libs/ml/planner/mcts/{utils → _utils}/node.py +0 -0
gr_libs/ml/utils/storage.py
CHANGED
@@ -1,10 +1,4 @@
-import csv
 import os
-import torch
-import logging
-import sys
-
-from .other import device


 def create_folders_if_necessary(path):
@@ -12,23 +6,34 @@ def create_folders_if_necessary(path):
         os.makedirs(path)


-def
-    return
+def get_outputs_dir():
+    return "outputs"
+
+
+def get_recognizer_outputs_dir(recognizer: str):
+    return os.path.join(get_outputs_dir(), recognizer)


-def
+def get_gr_cache_dir():
     # Prefer local directory if it exists (e.g., in GitHub workspace)
-    if os.path.exists("
-        return "
+    if os.path.exists("gr_cache"):
+        return "gr_cache"
     # Fall back to pre-mounted directory (e.g., in Docker container)
-    if os.path.exists("/
-        return "/
+    if os.path.exists("/gr_cache"):
+        return "/gr_cache"
     # Default to "dataset" even if it doesn't exist (e.g., will be created)
-    return "
+    return "gr_cache"


-def
-
+def get_trained_agents_dir():
+    # Prefer local directory if it exists (e.g., in GitHub workspace)
+    if os.path.exists("trained_agents"):
+        return "trained_agents"
+    # Fall back to pre-mounted directory (e.g., in Docker container)
+    if os.path.exists("/trained_agents"):
+        return "/trained_agents"
+    # Default to "dataset" even if it doesn't exist (e.g., will be created)
+    return "trained_agents"


 def _get_siamese_datasets_directory_name():
@@ -43,25 +48,26 @@ def get_observation_file_name(observability_percentage: float):
     return "obs" + str(observability_percentage) + ".pkl"


-def
-    return os.path.join(
+def get_domain_outputs_dir(domain_name, recognizer: str):
+    return os.path.join(get_recognizer_outputs_dir(recognizer), domain_name)


-def
-    return os.path.join(
+def get_env_outputs_dir(domain_name, env_name, recognizer: str):
+    return os.path.join(get_domain_outputs_dir(domain_name, recognizer), env_name)


 def get_observations_dir(domain_name, env_name, recognizer: str):
     return os.path.join(
-
+        get_env_outputs_dir(
+            domain_name=domain_name, env_name=env_name, recognizer=recognizer
+        ),
         _get_observations_directory_name(),
     )


 def get_agent_model_dir(domain_name, model_name, class_name):
     return os.path.join(
-
-        _get_models_directory_name(),
+        get_trained_agents_dir(),
         domain_name,
         model_name,
         class_name,
@@ -70,15 +76,7 @@ def get_agent_model_dir(domain_name, model_name, class_name):

 def get_lstm_model_dir(domain_name, env_name, model_name, recognizer: str):
     return os.path.join(
-
-        model_name,
-    )
-
-
-def get_models_dir(domain_name, env_name, recognizer: str):
-    return os.path.join(
-        get_env_dir(domain_name=domain_name, env_name=env_name, recognizer=recognizer),
-        _get_models_directory_name(),
+        get_gr_cache_dir(), recognizer, domain_name, env_name, model_name
     )


@@ -94,14 +92,7 @@ def get_siamese_dataset_path(domain_name, env_name, model_name, recognizer: str)

 def get_embeddings_result_path(domain_name, env_name, recognizer: str):
     return os.path.join(
-
-        "goal_embeddings",
-    )
-
-
-def get_embeddings_result_path(domain_name, env_name, recognizer: str):
-    return os.path.join(
-        get_env_dir(domain_name, env_name=env_name, recognizer=recognizer),
+        get_env_outputs_dir(domain_name, env_name=env_name, recognizer=recognizer),
         "goal_embeddings",
     )

@@ -113,9 +104,7 @@ def get_and_create(path):

 def get_experiment_results_path(domain, env_name, task, recognizer: str):
     return os.path.join(
-
-        "experiment_results",
-        env_name,
+        get_env_outputs_dir(domain, env_name=env_name, recognizer=recognizer),
         task,
         "experiment_results",
     )
@@ -123,72 +112,30 @@ def get_experiment_results_path(domain, env_name, task, recognizer: str):

 def get_plans_result_path(domain_name, env_name, recognizer: str):
     return os.path.join(
-
+        get_env_outputs_dir(domain_name, env_name=env_name, recognizer=recognizer),
+        "plans",
     )


 def get_policy_sequences_result_path(domain_name, env_name, recognizer: str):
     return os.path.join(
-
+        get_env_outputs_dir(domain_name, env_name, recognizer=recognizer),
+        "policy_sequences",
     )


 ### END GRAML PATHS ###
-
+
 ### GRAQL PATHS ###


 def get_gr_as_rl_experiment_confidence_path(domain_name, env_name, recognizer: str):
     return os.path.join(
-
-
+        get_env_outputs_dir(
+            domain_name=domain_name, env_name=env_name, recognizer=recognizer
+        ),
+        "confidence",
     )


 ### GRAQL PATHS ###
-
-
-def get_status_path(model_dir):
-    return os.path.join(model_dir, "status.pt")
-
-
-def get_status(model_dir):
-    path = get_status_path(model_dir)
-    return torch.load(path, map_location=device)
-
-
-def save_status(status, model_dir):
-    path = get_status_path(model_dir)
-    utils.create_folders_if_necessary(path)
-    torch.save(status, path)
-
-
-def get_vocab(model_dir):
-    return get_status(model_dir)["vocab"]
-
-
-def get_model_state(model_dir):
-    return get_status(model_dir)["model_state"]
-
-
-def get_txt_logger(model_dir):
-    path = os.path.join(model_dir, "log.txt")
-    utils.create_folders_if_necessary(path)
-
-    logging.basicConfig(
-        level=logging.INFO,
-        format="%(message)s",
-        handlers=[
-            logging.FileHandler(filename=path),
-            logging.StreamHandler(sys.stdout),
-        ],
-    )
-
-    return logging.getLogger()
-
-
-def get_csv_logger(model_dir):
-    csv_path = os.path.join(model_dir, "log.csv")
-    utils.create_folders_if_necessary(csv_path)
-    csv_file = open(csv_path, "a")
-    return csv_file, csv.writer(csv_file)
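Taken together, the reworked storage helpers resolve every artifact path from three roots: "outputs" for results, "trained_agents" for agent checkpoints, and "gr_cache" for recognizer caches. Below is a minimal sketch of how the helpers shown in this diff compose; the domain, env, model, and recognizer strings are placeholders rather than values shipped with the package, and it assumes gr_libs imports cleanly in your environment:

from gr_libs.ml.utils import storage

# Result paths are rooted at outputs/<recognizer>/<domain>/<env>/...
obs_dir = storage.get_observations_dir(
    domain_name="parking", env_name="Parking-S-14-PC-", recognizer="GCDraco"
)

# Trained agent checkpoints are rooted at trained_agents/<domain>/<model>/<class>
agent_dir = storage.get_agent_model_dir(
    domain_name="parking", model_name="Parking-S-14-PC-", class_name="DeepRLAgent"
)

# LSTM checkpoints are cached under gr_cache/<recognizer>/<domain>/<env>/<model>
lstm_dir = storage.get_lstm_model_dir(
    domain_name="parking",
    env_name="Parking-S-14-PC-",
    model_name="lstm",
    recognizer="ExpertBasedGraml",
)

# get_and_create() would additionally create the directory; here we only inspect the paths.
print(obs_dir, agent_dir, lstm_dir, sep="\n")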
gr_libs/odgr_executor.py
ADDED
@@ -0,0 +1,268 @@
+import argparse
+import os
+import time
+
+import dill
+
+from gr_libs.environment.utils.utils import domain_to_env_property
+from gr_libs.metrics.metrics import stochastic_amplified_selection
+from gr_libs.ml.neural.deep_rl_learner import DeepRLAgent
+from gr_libs.ml.utils.format import random_subset_with_order
+from gr_libs.ml.utils.storage import (
+    get_and_create,
+    get_experiment_results_path,
+    get_policy_sequences_result_path,
+)
+from gr_libs.problems.consts import PROBLEMS
+from gr_libs.recognizer.gr_as_rl.gr_as_rl_recognizer import Draco, GCDraco
+from gr_libs.recognizer.graml.graml_recognizer import Graml
+from gr_libs.recognizer.recognizer import GaAgentTrainerRecognizer, LearningRecognizer
+from gr_libs.recognizer.utils import recognizer_str_to_obj
+
+
+def validate(args, recognizer_type, task_inputs):
+    if "base" in task_inputs.keys():
+        # assert issubclass(recognizer_type, LearningRecognizer), f"base is in the task_inputs for the recognizer {args.recognizer}, which doesn't have a domain learning phase (is not a learning recognizer)."
+        assert (
+            list(task_inputs.keys())[0] == "base"
+        ), "In case of LearningRecognizer, base should be the first element in the task_inputs dict in consts.py"
+        assert (
+            "base" not in list(task_inputs.keys())[1:]
+        ), "In case of LearningRecognizer, base should be only in the first element in the task_inputs dict in consts.py"
+    # else:
+    #     assert not issubclass(recognizer_type, LearningRecognizer), f"base is not in the task_inputs for the recognizer {args.recognizer}, which has a domain learning phase (is a learning recognizer). Remove it from the task_inputs dict in consts.py."
+
+
+def run_odgr_problem(args):
+    recognizer_type = recognizer_str_to_obj(args.recognizer)
+    env_inputs = PROBLEMS[args.domain]
+    assert (
+        args.env_name in env_inputs.keys()
+    ), f"env_name {args.env_name} is not in the list of available environments for the domain {args.domain}. Add it to PROBLEMS dict in consts.py"
+    task_inputs = env_inputs[args.env_name][args.task]
+    recognizer = recognizer_type(
+        domain_name=args.domain,
+        env_name=args.env_name,
+        collect_statistics=args.collect_stats,
+    )
+    validate(args, recognizer_type, task_inputs)
+    ga_times, results = [], {}
+    for key, value in task_inputs.items():
+        if key == "base":
+            dlp_time = 0
+            if issubclass(recognizer_type, LearningRecognizer):
+                start_dlp_time = time.time()
+                recognizer.domain_learning_phase(
+                    base_goals=value["goals"], train_configs=value["train_configs"]
+                )
+                dlp_time = time.time() - start_dlp_time
+        elif key.startswith("G_"):
+            start_ga_time = time.time()
+            kwargs = {"dynamic_goals": value["goals"]}
+            if issubclass(recognizer_type, GaAgentTrainerRecognizer):
+                kwargs["dynamic_train_configs"] = value["train_configs"]
+            recognizer.goals_adaptation_phase(**kwargs)
+            ga_times.append(time.time() - start_ga_time)
+        elif key.startswith("I_"):
+            goal, train_config, consecutive, consecutive_str, percentage = (
+                value["goal"],
+                value["train_config"],
+                value["consecutive"],
+                "consecutive" if value["consecutive"] == True else "non_consecutive",
+                value["percentage"],
+            )
+            results.setdefault(str(percentage), {})
+            results[str(percentage)].setdefault(
+                consecutive_str,
+                {
+                    "correct": 0,
+                    "num_of_tasks": 0,
+                    "accuracy": 0,
+                    "average_inference_time": 0,
+                },
+            )
+            property_type = domain_to_env_property(args.domain)
+            env_property = property_type(args.env_name)
+            problem_name = env_property.goal_to_problem_str(goal)
+            rl_agent_type = recognizer.rl_agent_type
+            agent = rl_agent_type(
+                domain_name=args.domain,
+                problem_name=problem_name,
+                algorithm=train_config[0],
+                num_timesteps=train_config[1],
+                env_prop=env_property,
+            )
+            agent.learn()
+            fig_path = get_and_create(
+                f"{os.path.abspath(os.path.join(get_policy_sequences_result_path(domain_name=args.domain, env_name=args.env_name, recognizer=args.recognizer), problem_name))}_inference_seq"
+            )
+            generate_obs_kwargs = {
+                "action_selection_method": stochastic_amplified_selection,
+                "save_fig": args.collect_stats,
+                "random_optimalism": True,
+                "fig_path": fig_path if args.collect_stats else None,
+            }
+
+            # need to dump the whole plan for draco because it needs it for inference phase for checking likelihood.
+            if (recognizer_type == Draco or recognizer_type == GCDraco) and issubclass(
+                rl_agent_type, DeepRLAgent
+            ):  # TODO remove this condition, remove the assumption.
+                generate_obs_kwargs["with_dict"] = True
+            sequence = agent.generate_observation(**generate_obs_kwargs)
+            if issubclass(
+                recognizer_type, Graml
+            ):  # need to dump the plans to compute offline plan similarity only in graml's case for evaluation.
+                recognizer.dump_plans(
+                    true_sequence=sequence, true_goal=goal, percentage=percentage
+                )
+            partial_sequence = random_subset_with_order(
+                sequence, (int)(percentage * len(sequence)), is_consecutive=consecutive
+            )
+            # add evaluation_function to kwargs if this is graql. move everything to kwargs...
+            start_inf_time = time.time()
+            closest_goal = recognizer.inference_phase(
+                partial_sequence, goal, percentage
+            )
+            results[str(percentage)][consecutive_str]["average_inference_time"] += (
+                time.time() - start_inf_time
+            )
+            # print(f'real goal {goal}, closest goal is: {closest_goal}')
+            if all(a == b for a, b in zip(str(goal), closest_goal)):
+                results[str(percentage)][consecutive_str]["correct"] += 1
+            results[str(percentage)][consecutive_str]["num_of_tasks"] += 1
+
+    for percentage in results.keys():
+        for consecutive_str in results[str(percentage)].keys():
+            results[str(percentage)][consecutive_str]["average_inference_time"] /= len(
+                results[str(percentage)][consecutive_str]
+            )
+            results[str(percentage)][consecutive_str]["accuracy"] = (
+                results[str(percentage)][consecutive_str]["correct"]
+                / results[str(percentage)][consecutive_str]["num_of_tasks"]
+            )
+
+    # aggregate
+    total_correct = sum(
+        [
+            result["correct"]
+            for cons_result in results.values()
+            for result in cons_result.values()
+        ]
+    )
+    total_tasks = sum(
+        [
+            result["num_of_tasks"]
+            for cons_result in results.values()
+            for result in cons_result.values()
+        ]
+    )
+    total_average_inference_time = (
+        sum(
+            [
+                result["average_inference_time"]
+                for cons_result in results.values()
+                for result in cons_result.values()
+            ]
+        )
+        / total_tasks
+    )
+
+    results["total"] = {
+        "total_correct": total_correct,
+        "total_tasks": total_tasks,
+        "total_accuracy": total_correct / total_tasks,
+        "total_average_inference_time": total_average_inference_time,
+        "goals_adaptation_time": sum(ga_times) / len(ga_times),
+        "domain_learning_time": dlp_time,
+    }
+    print(str(results))
+    res_file_path = get_and_create(
+        get_experiment_results_path(
+            domain=args.domain,
+            env_name=args.env_name,
+            task=args.task,
+            recognizer=args.recognizer,
+        )
+    )
+    print(f"generating results into {res_file_path}")
+    with open(os.path.join(res_file_path, "res.pkl"), "wb") as results_file:
+        dill.dump(results, results_file)
+    with open(os.path.join(res_file_path, "res.txt"), "w") as results_file:
+        results_file.write(str(results))
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Parse command-line arguments for the RL experiment.",
+        formatter_class=argparse.RawTextHelpFormatter,
+    )
+
+    # Required arguments
+    required_group = parser.add_argument_group("Required arguments")
+    required_group.add_argument(
+        "--domain",
+        choices=["point_maze", "minigrid", "parking", "panda"],
+        required=True,
+        help="Domain name (point_maze, minigrid, parking, or panda)",
+    )
+    required_group.add_argument(
+        "--env_name",
+        required=True,
+        help="Env name (point_maze, minigrid, parking, or panda). For example, Parking-S-14-PC--v0",
+    )
+    required_group.add_argument(
+        "--recognizer",
+        choices=[
+            "MCTSBasedGraml",
+            "ExpertBasedGraml",
+            "GCGraml",
+            "Graql",
+            "Draco",
+            "GCDraco",
+        ],
+        required=True,
+        help="Recognizer type. Follow readme.md and recognizer folder for more information and rules.",
+    )
+    required_group.add_argument(
+        "--task",
+        choices=[
+            "L1",
+            "L2",
+            "L3",
+            "L4",
+            "L5",
+            "L11",
+            "L22",
+            "L33",
+            "L44",
+            "L55",
+            "L111",
+            "L222",
+            "L333",
+            "L444",
+            "L555",
+        ],
+        required=True,
+        help="Task identifier (e.g., L1, L2,...,L5)",
+    )
+
+    # Optional arguments
+    optional_group = parser.add_argument_group("Optional arguments")
+    optional_group.add_argument(
+        "--collect_stats", action="store_true", help="Whether to collect statistics"
+    )
+    args = parser.parse_args()
+
+    ### VALIDATE INPUTS ###
+    # Assert that all required arguments are provided
+    assert (
+        args.domain is not None
+        and args.recognizer is not None
+        and args.task is not None
+    ), "Missing required arguments: domain, recognizer, or task"
+    return args
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    run_odgr_problem(args)
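For context, odgr_executor.py is driven entirely by the argparse flags defined above, and the same entry point can be called programmatically with an equivalent namespace. A minimal sketch follows; the domain, env_name, recognizer, and task values are placeholders and must correspond to entries in the PROBLEMS dict in consts.py:

from argparse import Namespace

from gr_libs.odgr_executor import run_odgr_problem

# Roughly equivalent to the CLI call (assuming the module is run as a script):
#   python -m gr_libs.odgr_executor --domain parking --env_name Parking-S-14-PC- \
#       --recognizer GCDraco --task L1 --collect_stats
args = Namespace(
    domain="parking",             # choices: point_maze, minigrid, parking, panda
    env_name="Parking-S-14-PC-",  # must be a key of PROBLEMS[domain] in consts.py
    recognizer="GCDraco",         # any recognizer name accepted by parse_args()
    task="L1",                    # L1..L5 or the longer variants listed above
    collect_stats=True,           # also dump figures and per-task statistics
)
run_odgr_problem(args)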
gr_libs/problems/consts.py
CHANGED
@@ -1,15 +1,15 @@
 import numpy as np
 from stable_baselines3 import PPO, SAC, TD3
+
 from gr_libs.environment.environment import (
     MINIGRID,
-    PARKING,
     PANDA,
+    PARKING,
     POINT_MAZE,
     QLEARNING,
     PandaProperty,
 )

-
 PROBLEMS = {
     PARKING: {
         "Parking-S-14-PC-": {
File without changes
gr_libs/recognizer/{utils → _utils}/format.py
CHANGED
@@ -1,9 +1,9 @@
+from gr_libs.recognizer.gr_as_rl.gr_as_rl_recognizer import Draco, GCDraco, Graql
 from gr_libs.recognizer.graml.graml_recognizer import (
-    GCGraml,
     ExpertBasedGraml,
+    GCGraml,
     MCTSBasedGraml,
 )
-from gr_libs.recognizer.gr_as_rl.gr_as_rl_recognizer import Graql, Draco, GCDraco


 def recognizer_str_to_obj(recognizer_str: str):