gr-libs 0.1.7.post0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gr_libs/__init__.py +4 -1
- gr_libs/_evaluation/__init__.py +1 -0
- gr_libs/_evaluation/_analyze_results_cross_alg_cross_domain.py +260 -0
- gr_libs/_evaluation/_generate_experiments_results.py +141 -0
- gr_libs/_evaluation/_generate_task_specific_statistics_plots.py +497 -0
- gr_libs/_evaluation/_get_plans_images.py +61 -0
- gr_libs/_evaluation/_increasing_and_decreasing_.py +106 -0
- gr_libs/_version.py +2 -2
- gr_libs/all_experiments.py +294 -0
- gr_libs/environment/__init__.py +30 -9
- gr_libs/environment/_utils/utils.py +27 -0
- gr_libs/environment/environment.py +417 -54
- gr_libs/metrics/__init__.py +7 -0
- gr_libs/metrics/metrics.py +231 -54
- gr_libs/ml/__init__.py +2 -5
- gr_libs/ml/agent.py +21 -6
- gr_libs/ml/base/__init__.py +3 -1
- gr_libs/ml/base/rl_agent.py +81 -13
- gr_libs/ml/consts.py +1 -1
- gr_libs/ml/neural/__init__.py +1 -3
- gr_libs/ml/neural/deep_rl_learner.py +619 -378
- gr_libs/ml/neural/utils/__init__.py +1 -2
- gr_libs/ml/neural/utils/dictlist.py +3 -3
- gr_libs/ml/planner/mcts/{utils → _utils}/__init__.py +1 -1
- gr_libs/ml/planner/mcts/{utils → _utils}/node.py +11 -7
- gr_libs/ml/planner/mcts/{utils → _utils}/tree.py +15 -11
- gr_libs/ml/planner/mcts/mcts_model.py +571 -312
- gr_libs/ml/sequential/__init__.py +0 -1
- gr_libs/ml/sequential/_lstm_model.py +270 -0
- gr_libs/ml/tabular/__init__.py +1 -3
- gr_libs/ml/tabular/state.py +7 -7
- gr_libs/ml/tabular/tabular_q_learner.py +150 -82
- gr_libs/ml/tabular/tabular_rl_agent.py +42 -28
- gr_libs/ml/utils/__init__.py +2 -3
- gr_libs/ml/utils/format.py +28 -97
- gr_libs/ml/utils/math.py +5 -3
- gr_libs/ml/utils/other.py +3 -3
- gr_libs/ml/utils/storage.py +88 -81
- gr_libs/odgr_executor.py +268 -0
- gr_libs/problems/consts.py +1549 -1227
- gr_libs/recognizer/_utils/__init__.py +0 -0
- gr_libs/recognizer/_utils/format.py +18 -0
- gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +233 -88
- gr_libs/recognizer/graml/_gr_dataset.py +233 -0
- gr_libs/recognizer/graml/graml_recognizer.py +586 -252
- gr_libs/recognizer/recognizer.py +90 -30
- gr_libs/tutorials/draco_panda_tutorial.py +58 -0
- gr_libs/tutorials/draco_parking_tutorial.py +56 -0
- gr_libs/tutorials/gcdraco_panda_tutorial.py +62 -0
- gr_libs/tutorials/gcdraco_parking_tutorial.py +57 -0
- gr_libs/tutorials/graml_minigrid_tutorial.py +64 -0
- gr_libs/tutorials/graml_panda_tutorial.py +57 -0
- gr_libs/tutorials/graml_parking_tutorial.py +52 -0
- gr_libs/tutorials/graml_point_maze_tutorial.py +60 -0
- gr_libs/tutorials/graql_minigrid_tutorial.py +50 -0
- {gr_libs-0.1.7.post0.dist-info → gr_libs-0.2.2.dist-info}/METADATA +84 -29
- gr_libs-0.2.2.dist-info/RECORD +71 -0
- {gr_libs-0.1.7.post0.dist-info → gr_libs-0.2.2.dist-info}/WHEEL +1 -1
- gr_libs-0.2.2.dist-info/top_level.txt +2 -0
- tests/test_draco.py +14 -0
- tests/test_gcdraco.py +10 -0
- tests/test_graml.py +12 -8
- tests/test_graql.py +3 -2
- evaluation/analyze_results_cross_alg_cross_domain.py +0 -277
- evaluation/create_minigrid_map_image.py +0 -34
- evaluation/file_system.py +0 -42
- evaluation/generate_experiments_results.py +0 -92
- evaluation/generate_experiments_results_new_ver1.py +0 -254
- evaluation/generate_experiments_results_new_ver2.py +0 -331
- evaluation/generate_task_specific_statistics_plots.py +0 -272
- evaluation/get_plans_images.py +0 -47
- evaluation/increasing_and_decreasing_.py +0 -63
- gr_libs/environment/utils/utils.py +0 -17
- gr_libs/ml/neural/utils/penv.py +0 -57
- gr_libs/ml/sequential/lstm_model.py +0 -192
- gr_libs/recognizer/graml/gr_dataset.py +0 -134
- gr_libs/recognizer/utils/__init__.py +0 -1
- gr_libs/recognizer/utils/format.py +0 -13
- gr_libs-0.1.7.post0.dist-info/RECORD +0 -67
- gr_libs-0.1.7.post0.dist-info/top_level.txt +0 -4
- tutorials/graml_minigrid_tutorial.py +0 -34
- tutorials/graml_panda_tutorial.py +0 -41
- tutorials/graml_parking_tutorial.py +0 -39
- tutorials/graml_point_maze_tutorial.py +0 -39
- tutorials/graql_minigrid_tutorial.py +0 -34
- /gr_libs/environment/{utils → _utils}/__init__.py +0 -0
gr_libs/ml/utils/format.py
CHANGED
@@ -1,100 +1,31 @@
|
|
1
|
-
|
2
|
-
import re
|
3
|
-
import torch
|
4
|
-
import gr_libs.ml
|
5
|
-
import gymnasium as gym
|
6
|
-
import random
|
7
|
-
|
8
|
-
def get_obss_preprocessor(obs_space):
|
9
|
-
# Check if obs_space is an image space
|
10
|
-
if isinstance(obs_space, gym.spaces.Box):
|
11
|
-
obs_space = {"image": obs_space.shape}
|
12
|
-
|
13
|
-
def preprocess_obss(obss, device=None):
|
14
|
-
return ml.DictList({
|
15
|
-
"image": preprocess_images(obss, device=device)
|
16
|
-
})
|
17
|
-
|
18
|
-
# Check if it is a MiniGrid observation space
|
19
|
-
elif isinstance(obs_space, gym.spaces.Dict) and "image" in obs_space.spaces.keys():
|
20
|
-
obs_space = {"image": obs_space.spaces["image"].shape, "text": 100}
|
21
|
-
|
22
|
-
vocab = Vocabulary(obs_space["text"])
|
23
|
-
|
24
|
-
def preprocess_obss(obss, device=None):
|
25
|
-
return ml.DictList({
|
26
|
-
"image": preprocess_images([obs["image"] for obs in obss], device=device),
|
27
|
-
"text": preprocess_texts([obs["mission"] for obs in obss], vocab, device=device)
|
28
|
-
})
|
29
|
-
|
30
|
-
preprocess_obss.vocab = vocab
|
31
|
-
|
32
|
-
# Check if it is a MiniGrid observation space
|
33
|
-
elif isinstance(obs_space, gym.spaces.Dict) and "observation" in obs_space.spaces.keys():
|
34
|
-
obs_space = {"observation": obs_space.spaces["observation"].shape}
|
35
|
-
|
36
|
-
def preprocess_obss(obss, device=None):
|
37
|
-
return ml.DictList({
|
38
|
-
"observation": preprocess_images(obss, device=device)
|
39
|
-
})
|
40
|
-
|
41
|
-
|
42
|
-
else:
|
43
|
-
raise ValueError("Unknown observation space: " + str(obs_space))
|
44
|
-
|
45
|
-
return obs_space, preprocess_obss
|
46
|
-
|
1
|
+
""" formatting-related utilities """
|
47
2
|
|
48
|
-
|
49
|
-
# Bug of Pytorch: very slow if not first converted to numpy array
|
50
|
-
images = numpy.array(images)
|
51
|
-
return torch.tensor(images, device=device, dtype=torch.float)
|
52
|
-
|
53
|
-
|
54
|
-
def random_subset_with_order(sequence, subset_size, is_consecutive = True):
|
55
|
-
if subset_size >= len(sequence):
|
56
|
-
return sequence
|
57
|
-
else:
|
58
|
-
if is_consecutive:
|
59
|
-
indices_to_select = [i for i in range(subset_size)]
|
60
|
-
else:
|
61
|
-
indices_to_select = sorted(random.sample(range(len(sequence)), subset_size)) # Randomly select indices to keep
|
62
|
-
return [sequence[i] for i in indices_to_select] # Return the elements corresponding to the selected indices
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
def preprocess_texts(texts, vocab, device=None):
|
67
|
-
var_indexed_texts = []
|
68
|
-
max_text_len = 0
|
69
|
-
|
70
|
-
for text in texts:
|
71
|
-
tokens = re.findall("([a-z]+)", text.lower())
|
72
|
-
var_indexed_text = numpy.array([vocab[token] for token in tokens])
|
73
|
-
var_indexed_texts.append(var_indexed_text)
|
74
|
-
max_text_len = max(len(var_indexed_text), max_text_len)
|
75
|
-
|
76
|
-
indexed_texts = numpy.zeros((len(texts), max_text_len))
|
77
|
-
|
78
|
-
for i, indexed_text in enumerate(var_indexed_texts):
|
79
|
-
indexed_texts[i, :len(indexed_text)] = indexed_text
|
80
|
-
|
81
|
-
return torch.tensor(indexed_texts, device=device, dtype=torch.long)
|
82
|
-
|
83
|
-
|
84
|
-
class Vocabulary:
|
85
|
-
"""A mapping from tokens to ids with a capacity of `max_size` words.
|
86
|
-
It can be saved in a `vocab.json` file."""
|
87
|
-
|
88
|
-
def __init__(self, max_size):
|
89
|
-
self.max_size = max_size
|
90
|
-
self.vocab = {}
|
3
|
+
import random
|
91
4
|
|
92
|
-
def load_vocab(self, vocab):
|
93
|
-
self.vocab = vocab
|
94
5
|
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
6
|
+
def random_subset_with_order(sequence, subset_size, is_consecutive=True):
|
7
|
+
"""
|
8
|
+
Returns a random subset of elements from the given sequence with a specified subset size.
|
9
|
+
|
10
|
+
Args:
|
11
|
+
sequence (list): The sequence of elements to select from.
|
12
|
+
subset_size (int): The size of the desired subset.
|
13
|
+
is_consecutive (bool, optional): Whether the selected subset should be consecutive elements from the sequence.
|
14
|
+
Defaults to True.
|
15
|
+
|
16
|
+
Returns:
|
17
|
+
list: A random subset of elements from the sequence.
|
18
|
+
|
19
|
+
"""
|
20
|
+
if subset_size >= len(sequence):
|
21
|
+
return sequence
|
22
|
+
else:
|
23
|
+
if is_consecutive:
|
24
|
+
indices_to_select = [i for i in range(subset_size)]
|
25
|
+
else:
|
26
|
+
indices_to_select = sorted(
|
27
|
+
random.sample(range(len(sequence)), subset_size)
|
28
|
+
) # Randomly select indices to keep
|
29
|
+
return [
|
30
|
+
sequence[i] for i in indices_to_select
|
31
|
+
] # Return the elements corresponding to the selected indices
|
gr_libs/ml/utils/math.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1
|
+
""" math-related functions """
|
2
|
+
|
1
3
|
import math
|
2
|
-
from typing import Callable, Generator, List
|
3
4
|
|
4
|
-
|
5
|
+
|
6
|
+
def softmax(values: list[float]) -> list[float]:
|
5
7
|
"""Computes softmax probabilities for an array of values
|
6
8
|
TODO We should probably use numpy arrays here
|
7
9
|
Args:
|
@@ -10,4 +12,4 @@ def softmax(values: List[float]) -> List[float]:
|
|
10
12
|
Returns:
|
11
13
|
np.array: softmax probabilities
|
12
14
|
"""
|
13
|
-
return [(math.exp(q)) / sum([math.exp(_q) for _q in values]) for q in values]
|
15
|
+
return [(math.exp(q)) / sum([math.exp(_q) for _q in values]) for q in values]
|
gr_libs/ml/utils/other.py
CHANGED
@@ -1,8 +1,8 @@
|
|
1
|
+
import collections
|
1
2
|
import random
|
3
|
+
|
2
4
|
import numpy
|
3
5
|
import torch
|
4
|
-
import collections
|
5
|
-
|
6
6
|
|
7
7
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
8
8
|
|
@@ -21,4 +21,4 @@ def synthesize(array):
|
|
21
21
|
d["std"] = numpy.std(array)
|
22
22
|
d["min"] = numpy.amin(array)
|
23
23
|
d["max"] = numpy.amax(array)
|
24
|
-
return d
|
24
|
+
return d
|
gr_libs/ml/utils/storage.py
CHANGED
@@ -1,134 +1,141 @@
|
|
1
|
-
import csv
|
2
1
|
import os
|
3
|
-
import torch
|
4
|
-
import logging
|
5
|
-
import sys
|
6
2
|
|
7
|
-
from .other import device
|
8
3
|
|
9
4
|
def create_folders_if_necessary(path):
|
10
5
|
if not os.path.exists(path):
|
11
6
|
os.makedirs(path)
|
12
7
|
|
13
8
|
|
14
|
-
def
|
15
|
-
return
|
9
|
+
def get_outputs_dir():
|
10
|
+
return "outputs"
|
16
11
|
|
17
|
-
|
12
|
+
|
13
|
+
def get_recognizer_outputs_dir(recognizer: str):
|
14
|
+
return os.path.join(get_outputs_dir(), recognizer)
|
15
|
+
|
16
|
+
|
17
|
+
def get_gr_cache_dir():
|
18
18
|
# Prefer local directory if it exists (e.g., in GitHub workspace)
|
19
|
-
if os.path.exists("
|
20
|
-
return "
|
19
|
+
if os.path.exists("gr_cache"):
|
20
|
+
return "gr_cache"
|
21
21
|
# Fall back to pre-mounted directory (e.g., in Docker container)
|
22
|
-
if os.path.exists("/
|
23
|
-
return "/
|
22
|
+
if os.path.exists("/gr_cache"):
|
23
|
+
return "/gr_cache"
|
24
24
|
# Default to "dataset" even if it doesn't exist (e.g., will be created)
|
25
|
-
return "
|
25
|
+
return "gr_cache"
|
26
|
+
|
27
|
+
|
28
|
+
def get_trained_agents_dir():
|
29
|
+
# Prefer local directory if it exists (e.g., in GitHub workspace)
|
30
|
+
if os.path.exists("trained_agents"):
|
31
|
+
return "trained_agents"
|
32
|
+
# Fall back to pre-mounted directory (e.g., in Docker container)
|
33
|
+
if os.path.exists("/trained_agents"):
|
34
|
+
return "/trained_agents"
|
35
|
+
# Default to "dataset" even if it doesn't exist (e.g., will be created)
|
36
|
+
return "trained_agents"
|
26
37
|
|
27
|
-
def _get_models_directory_name():
|
28
|
-
return "models"
|
29
38
|
|
30
39
|
def _get_siamese_datasets_directory_name():
|
31
40
|
return "siamese_datasets"
|
32
41
|
|
42
|
+
|
33
43
|
def _get_observations_directory_name():
|
34
44
|
return "observations"
|
35
45
|
|
36
|
-
def get_observation_file_name(observability_percentage: float):
|
37
|
-
return 'obs' + str(observability_percentage) + '.pkl'
|
38
46
|
|
47
|
+
def get_observation_file_name(observability_percentage: float):
|
48
|
+
return "obs" + str(observability_percentage) + ".pkl"
|
39
49
|
|
40
|
-
def get_domain_dir(domain_name, recognizer:str):
|
41
|
-
return os.path.join(get_storage_framework_dir(recognizer), domain_name)
|
42
50
|
|
43
|
-
def
|
44
|
-
return os.path.join(
|
51
|
+
def get_domain_outputs_dir(domain_name, recognizer: str):
|
52
|
+
return os.path.join(get_recognizer_outputs_dir(recognizer), domain_name)
|
45
53
|
|
46
|
-
def get_observations_dir(domain_name, env_name, recognizer:str):
|
47
|
-
return os.path.join(get_env_dir(domain_name=domain_name, env_name=env_name, recognizer=recognizer), _get_observations_directory_name())
|
48
54
|
|
49
|
-
def
|
50
|
-
return os.path.join(
|
55
|
+
def get_env_outputs_dir(domain_name, env_name, recognizer: str):
|
56
|
+
return os.path.join(get_domain_outputs_dir(domain_name, recognizer), env_name)
|
51
57
|
|
52
|
-
def get_lstm_model_dir(domain_name, env_name, model_name, recognizer:str):
|
53
|
-
return os.path.join(get_env_dir(domain_name=domain_name, env_name=env_name, recognizer=recognizer), model_name)
|
54
58
|
|
55
|
-
def
|
56
|
-
return os.path.join(
|
59
|
+
def get_observations_dir(domain_name, env_name, recognizer: str):
|
60
|
+
return os.path.join(
|
61
|
+
get_env_outputs_dir(
|
62
|
+
domain_name=domain_name, env_name=env_name, recognizer=recognizer
|
63
|
+
),
|
64
|
+
_get_observations_directory_name(),
|
65
|
+
)
|
57
66
|
|
58
|
-
### GRAML PATHS ###
|
59
67
|
|
60
|
-
def
|
61
|
-
return os.path.join(
|
68
|
+
def get_agent_model_dir(domain_name, model_name, class_name):
|
69
|
+
return os.path.join(
|
70
|
+
get_trained_agents_dir(),
|
71
|
+
domain_name,
|
72
|
+
model_name,
|
73
|
+
class_name,
|
74
|
+
)
|
62
75
|
|
63
|
-
def get_embeddings_result_path(domain_name, env_name, recognizer:str):
|
64
|
-
return os.path.join(get_env_dir(domain_name, env_name=env_name, recognizer=recognizer), "goal_embeddings")
|
65
76
|
|
66
|
-
def
|
67
|
-
return os.path.join(
|
77
|
+
def get_lstm_model_dir(domain_name, env_name, model_name, recognizer: str):
|
78
|
+
return os.path.join(
|
79
|
+
get_gr_cache_dir(), recognizer, domain_name, env_name, model_name
|
80
|
+
)
|
68
81
|
|
69
|
-
def get_and_create(path):
|
70
|
-
create_folders_if_necessary(path)
|
71
|
-
return path
|
72
82
|
|
73
|
-
|
74
|
-
return os.path.join(get_env_dir(domain, env_name=env_name, recognizer=recognizer), "experiment_results", env_name, task, "experiment_results")
|
83
|
+
### GRAML PATHS ###
|
75
84
|
|
76
|
-
def get_plans_result_path(domain_name, env_name, recognizer:str):
|
77
|
-
return os.path.join(get_env_dir(domain_name, env_name=env_name, recognizer=recognizer), "plans")
|
78
85
|
|
79
|
-
def
|
80
|
-
return os.path.join(
|
86
|
+
def get_siamese_dataset_path(domain_name, env_name, model_name, recognizer: str):
|
87
|
+
return os.path.join(
|
88
|
+
get_lstm_model_dir(domain_name, env_name, model_name, recognizer),
|
89
|
+
_get_siamese_datasets_directory_name(),
|
90
|
+
)
|
81
91
|
|
82
|
-
### END GRAML PATHS ###
|
83
|
-
''
|
84
|
-
### GRAQL PATHS ###
|
85
92
|
|
86
|
-
def
|
87
|
-
return os.path.join(
|
93
|
+
def get_embeddings_result_path(domain_name, env_name, recognizer: str):
|
94
|
+
return os.path.join(
|
95
|
+
get_env_outputs_dir(domain_name, env_name=env_name, recognizer=recognizer),
|
96
|
+
"goal_embeddings",
|
97
|
+
)
|
88
98
|
|
89
|
-
### GRAQL PATHS ###
|
90
99
|
|
91
|
-
def
|
92
|
-
|
100
|
+
def get_and_create(path):
|
101
|
+
create_folders_if_necessary(path)
|
102
|
+
return path
|
93
103
|
|
94
104
|
|
95
|
-
def
|
96
|
-
path
|
97
|
-
|
105
|
+
def get_experiment_results_path(domain, env_name, task, recognizer: str):
|
106
|
+
return os.path.join(
|
107
|
+
get_env_outputs_dir(domain, env_name=env_name, recognizer=recognizer),
|
108
|
+
task,
|
109
|
+
"experiment_results",
|
110
|
+
)
|
98
111
|
|
99
112
|
|
100
|
-
def
|
101
|
-
path
|
102
|
-
|
103
|
-
|
113
|
+
def get_plans_result_path(domain_name, env_name, recognizer: str):
|
114
|
+
return os.path.join(
|
115
|
+
get_env_outputs_dir(domain_name, env_name=env_name, recognizer=recognizer),
|
116
|
+
"plans",
|
117
|
+
)
|
104
118
|
|
105
119
|
|
106
|
-
def
|
107
|
-
return
|
120
|
+
def get_policy_sequences_result_path(domain_name, env_name, recognizer: str):
|
121
|
+
return os.path.join(
|
122
|
+
get_env_outputs_dir(domain_name, env_name, recognizer=recognizer),
|
123
|
+
"policy_sequences",
|
124
|
+
)
|
108
125
|
|
109
126
|
|
110
|
-
|
111
|
-
return get_status(model_dir)["model_state"]
|
127
|
+
### END GRAML PATHS ###
|
112
128
|
|
129
|
+
### GRAQL PATHS ###
|
113
130
|
|
114
|
-
def get_txt_logger(model_dir):
|
115
|
-
path = os.path.join(model_dir, "log.txt")
|
116
|
-
utils.create_folders_if_necessary(path)
|
117
131
|
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
]
|
132
|
+
def get_gr_as_rl_experiment_confidence_path(domain_name, env_name, recognizer: str):
|
133
|
+
return os.path.join(
|
134
|
+
get_env_outputs_dir(
|
135
|
+
domain_name=domain_name, env_name=env_name, recognizer=recognizer
|
136
|
+
),
|
137
|
+
"confidence",
|
125
138
|
)
|
126
139
|
|
127
|
-
return logging.getLogger()
|
128
140
|
|
129
|
-
|
130
|
-
def get_csv_logger(model_dir):
|
131
|
-
csv_path = os.path.join(model_dir, "log.csv")
|
132
|
-
utils.create_folders_if_necessary(csv_path)
|
133
|
-
csv_file = open(csv_path, "a")
|
134
|
-
return csv_file, csv.writer(csv_file)
|
141
|
+
### GRAQL PATHS ###
|
gr_libs/odgr_executor.py
ADDED
@@ -0,0 +1,268 @@
|
|
1
|
+
import argparse
|
2
|
+
import os
|
3
|
+
import time
|
4
|
+
|
5
|
+
import dill
|
6
|
+
|
7
|
+
from gr_libs.environment.utils.utils import domain_to_env_property
|
8
|
+
from gr_libs.metrics.metrics import stochastic_amplified_selection
|
9
|
+
from gr_libs.ml.neural.deep_rl_learner import DeepRLAgent
|
10
|
+
from gr_libs.ml.utils.format import random_subset_with_order
|
11
|
+
from gr_libs.ml.utils.storage import (
|
12
|
+
get_and_create,
|
13
|
+
get_experiment_results_path,
|
14
|
+
get_policy_sequences_result_path,
|
15
|
+
)
|
16
|
+
from gr_libs.problems.consts import PROBLEMS
|
17
|
+
from gr_libs.recognizer.gr_as_rl.gr_as_rl_recognizer import Draco, GCDraco
|
18
|
+
from gr_libs.recognizer.graml.graml_recognizer import Graml
|
19
|
+
from gr_libs.recognizer.recognizer import GaAgentTrainerRecognizer, LearningRecognizer
|
20
|
+
from gr_libs.recognizer.utils import recognizer_str_to_obj
|
21
|
+
|
22
|
+
|
23
|
+
def validate(args, recognizer_type, task_inputs):
|
24
|
+
if "base" in task_inputs.keys():
|
25
|
+
# assert issubclass(recognizer_type, LearningRecognizer), f"base is in the task_inputs for the recognizer {args.recognizer}, which doesn't have a domain learning phase (is not a learning recognizer)."
|
26
|
+
assert (
|
27
|
+
list(task_inputs.keys())[0] == "base"
|
28
|
+
), "In case of LearningRecognizer, base should be the first element in the task_inputs dict in consts.py"
|
29
|
+
assert (
|
30
|
+
"base" not in list(task_inputs.keys())[1:]
|
31
|
+
), "In case of LearningRecognizer, base should be only in the first element in the task_inputs dict in consts.py"
|
32
|
+
# else:
|
33
|
+
# assert not issubclass(recognizer_type, LearningRecognizer), f"base is not in the task_inputs for the recognizer {args.recognizer}, which has a domain learning phase (is a learning recognizer). Remove it from the task_inputs dict in consts.py."
|
34
|
+
|
35
|
+
|
36
|
+
def run_odgr_problem(args):
|
37
|
+
recognizer_type = recognizer_str_to_obj(args.recognizer)
|
38
|
+
env_inputs = PROBLEMS[args.domain]
|
39
|
+
assert (
|
40
|
+
args.env_name in env_inputs.keys()
|
41
|
+
), f"env_name {args.env_name} is not in the list of available environments for the domain {args.domain}. Add it to PROBLEMS dict in consts.py"
|
42
|
+
task_inputs = env_inputs[args.env_name][args.task]
|
43
|
+
recognizer = recognizer_type(
|
44
|
+
domain_name=args.domain,
|
45
|
+
env_name=args.env_name,
|
46
|
+
collect_statistics=args.collect_stats,
|
47
|
+
)
|
48
|
+
validate(args, recognizer_type, task_inputs)
|
49
|
+
ga_times, results = [], {}
|
50
|
+
for key, value in task_inputs.items():
|
51
|
+
if key == "base":
|
52
|
+
dlp_time = 0
|
53
|
+
if issubclass(recognizer_type, LearningRecognizer):
|
54
|
+
start_dlp_time = time.time()
|
55
|
+
recognizer.domain_learning_phase(
|
56
|
+
base_goals=value["goals"], train_configs=value["train_configs"]
|
57
|
+
)
|
58
|
+
dlp_time = time.time() - start_dlp_time
|
59
|
+
elif key.startswith("G_"):
|
60
|
+
start_ga_time = time.time()
|
61
|
+
kwargs = {"dynamic_goals": value["goals"]}
|
62
|
+
if issubclass(recognizer_type, GaAgentTrainerRecognizer):
|
63
|
+
kwargs["dynamic_train_configs"] = value["train_configs"]
|
64
|
+
recognizer.goals_adaptation_phase(**kwargs)
|
65
|
+
ga_times.append(time.time() - start_ga_time)
|
66
|
+
elif key.startswith("I_"):
|
67
|
+
goal, train_config, consecutive, consecutive_str, percentage = (
|
68
|
+
value["goal"],
|
69
|
+
value["train_config"],
|
70
|
+
value["consecutive"],
|
71
|
+
"consecutive" if value["consecutive"] == True else "non_consecutive",
|
72
|
+
value["percentage"],
|
73
|
+
)
|
74
|
+
results.setdefault(str(percentage), {})
|
75
|
+
results[str(percentage)].setdefault(
|
76
|
+
consecutive_str,
|
77
|
+
{
|
78
|
+
"correct": 0,
|
79
|
+
"num_of_tasks": 0,
|
80
|
+
"accuracy": 0,
|
81
|
+
"average_inference_time": 0,
|
82
|
+
},
|
83
|
+
)
|
84
|
+
property_type = domain_to_env_property(args.domain)
|
85
|
+
env_property = property_type(args.env_name)
|
86
|
+
problem_name = env_property.goal_to_problem_str(goal)
|
87
|
+
rl_agent_type = recognizer.rl_agent_type
|
88
|
+
agent = rl_agent_type(
|
89
|
+
domain_name=args.domain,
|
90
|
+
problem_name=problem_name,
|
91
|
+
algorithm=train_config[0],
|
92
|
+
num_timesteps=train_config[1],
|
93
|
+
env_prop=env_property,
|
94
|
+
)
|
95
|
+
agent.learn()
|
96
|
+
fig_path = get_and_create(
|
97
|
+
f"{os.path.abspath(os.path.join(get_policy_sequences_result_path(domain_name=args.domain, env_name=args.env_name, recognizer=args.recognizer), problem_name))}_inference_seq"
|
98
|
+
)
|
99
|
+
generate_obs_kwargs = {
|
100
|
+
"action_selection_method": stochastic_amplified_selection,
|
101
|
+
"save_fig": args.collect_stats,
|
102
|
+
"random_optimalism": True,
|
103
|
+
"fig_path": fig_path if args.collect_stats else None,
|
104
|
+
}
|
105
|
+
|
106
|
+
# need to dump the whole plan for draco because it needs it for inference phase for checking likelihood.
|
107
|
+
if (recognizer_type == Draco or recognizer_type == GCDraco) and issubclass(
|
108
|
+
rl_agent_type, DeepRLAgent
|
109
|
+
): # TODO remove this condition, remove the assumption.
|
110
|
+
generate_obs_kwargs["with_dict"] = True
|
111
|
+
sequence = agent.generate_observation(**generate_obs_kwargs)
|
112
|
+
if issubclass(
|
113
|
+
recognizer_type, Graml
|
114
|
+
): # need to dump the plans to compute offline plan similarity only in graml's case for evaluation.
|
115
|
+
recognizer.dump_plans(
|
116
|
+
true_sequence=sequence, true_goal=goal, percentage=percentage
|
117
|
+
)
|
118
|
+
partial_sequence = random_subset_with_order(
|
119
|
+
sequence, (int)(percentage * len(sequence)), is_consecutive=consecutive
|
120
|
+
)
|
121
|
+
# add evaluation_function to kwargs if this is graql. move everything to kwargs...
|
122
|
+
start_inf_time = time.time()
|
123
|
+
closest_goal = recognizer.inference_phase(
|
124
|
+
partial_sequence, goal, percentage
|
125
|
+
)
|
126
|
+
results[str(percentage)][consecutive_str]["average_inference_time"] += (
|
127
|
+
time.time() - start_inf_time
|
128
|
+
)
|
129
|
+
# print(f'real goal {goal}, closest goal is: {closest_goal}')
|
130
|
+
if all(a == b for a, b in zip(str(goal), closest_goal)):
|
131
|
+
results[str(percentage)][consecutive_str]["correct"] += 1
|
132
|
+
results[str(percentage)][consecutive_str]["num_of_tasks"] += 1
|
133
|
+
|
134
|
+
for percentage in results.keys():
|
135
|
+
for consecutive_str in results[str(percentage)].keys():
|
136
|
+
results[str(percentage)][consecutive_str]["average_inference_time"] /= len(
|
137
|
+
results[str(percentage)][consecutive_str]
|
138
|
+
)
|
139
|
+
results[str(percentage)][consecutive_str]["accuracy"] = (
|
140
|
+
results[str(percentage)][consecutive_str]["correct"]
|
141
|
+
/ results[str(percentage)][consecutive_str]["num_of_tasks"]
|
142
|
+
)
|
143
|
+
|
144
|
+
# aggregate
|
145
|
+
total_correct = sum(
|
146
|
+
[
|
147
|
+
result["correct"]
|
148
|
+
for cons_result in results.values()
|
149
|
+
for result in cons_result.values()
|
150
|
+
]
|
151
|
+
)
|
152
|
+
total_tasks = sum(
|
153
|
+
[
|
154
|
+
result["num_of_tasks"]
|
155
|
+
for cons_result in results.values()
|
156
|
+
for result in cons_result.values()
|
157
|
+
]
|
158
|
+
)
|
159
|
+
total_average_inference_time = (
|
160
|
+
sum(
|
161
|
+
[
|
162
|
+
result["average_inference_time"]
|
163
|
+
for cons_result in results.values()
|
164
|
+
for result in cons_result.values()
|
165
|
+
]
|
166
|
+
)
|
167
|
+
/ total_tasks
|
168
|
+
)
|
169
|
+
|
170
|
+
results["total"] = {
|
171
|
+
"total_correct": total_correct,
|
172
|
+
"total_tasks": total_tasks,
|
173
|
+
"total_accuracy": total_correct / total_tasks,
|
174
|
+
"total_average_inference_time": total_average_inference_time,
|
175
|
+
"goals_adaptation_time": sum(ga_times) / len(ga_times),
|
176
|
+
"domain_learning_time": dlp_time,
|
177
|
+
}
|
178
|
+
print(str(results))
|
179
|
+
res_file_path = get_and_create(
|
180
|
+
get_experiment_results_path(
|
181
|
+
domain=args.domain,
|
182
|
+
env_name=args.env_name,
|
183
|
+
task=args.task,
|
184
|
+
recognizer=args.recognizer,
|
185
|
+
)
|
186
|
+
)
|
187
|
+
print(f"generating results into {res_file_path}")
|
188
|
+
with open(os.path.join(res_file_path, "res.pkl"), "wb") as results_file:
|
189
|
+
dill.dump(results, results_file)
|
190
|
+
with open(os.path.join(res_file_path, "res.txt"), "w") as results_file:
|
191
|
+
results_file.write(str(results))
|
192
|
+
|
193
|
+
|
194
|
+
def parse_args():
|
195
|
+
parser = argparse.ArgumentParser(
|
196
|
+
description="Parse command-line arguments for the RL experiment.",
|
197
|
+
formatter_class=argparse.RawTextHelpFormatter,
|
198
|
+
)
|
199
|
+
|
200
|
+
# Required arguments
|
201
|
+
required_group = parser.add_argument_group("Required arguments")
|
202
|
+
required_group.add_argument(
|
203
|
+
"--domain",
|
204
|
+
choices=["point_maze", "minigrid", "parking", "panda"],
|
205
|
+
required=True,
|
206
|
+
help="Domain name (point_maze, minigrid, parking, or panda)",
|
207
|
+
)
|
208
|
+
required_group.add_argument(
|
209
|
+
"--env_name",
|
210
|
+
required=True,
|
211
|
+
help="Env name (point_maze, minigrid, parking, or panda). For example, Parking-S-14-PC--v0",
|
212
|
+
)
|
213
|
+
required_group.add_argument(
|
214
|
+
"--recognizer",
|
215
|
+
choices=[
|
216
|
+
"MCTSBasedGraml",
|
217
|
+
"ExpertBasedGraml",
|
218
|
+
"GCGraml",
|
219
|
+
"Graql",
|
220
|
+
"Draco",
|
221
|
+
"GCDraco",
|
222
|
+
],
|
223
|
+
required=True,
|
224
|
+
help="Recognizer type. Follow readme.md and recognizer folder for more information and rules.",
|
225
|
+
)
|
226
|
+
required_group.add_argument(
|
227
|
+
"--task",
|
228
|
+
choices=[
|
229
|
+
"L1",
|
230
|
+
"L2",
|
231
|
+
"L3",
|
232
|
+
"L4",
|
233
|
+
"L5",
|
234
|
+
"L11",
|
235
|
+
"L22",
|
236
|
+
"L33",
|
237
|
+
"L44",
|
238
|
+
"L55",
|
239
|
+
"L111",
|
240
|
+
"L222",
|
241
|
+
"L333",
|
242
|
+
"L444",
|
243
|
+
"L555",
|
244
|
+
],
|
245
|
+
required=True,
|
246
|
+
help="Task identifier (e.g., L1, L2,...,L5)",
|
247
|
+
)
|
248
|
+
|
249
|
+
# Optional arguments
|
250
|
+
optional_group = parser.add_argument_group("Optional arguments")
|
251
|
+
optional_group.add_argument(
|
252
|
+
"--collect_stats", action="store_true", help="Whether to collect statistics"
|
253
|
+
)
|
254
|
+
args = parser.parse_args()
|
255
|
+
|
256
|
+
### VALIDATE INPUTS ###
|
257
|
+
# Assert that all required arguments are provided
|
258
|
+
assert (
|
259
|
+
args.domain is not None
|
260
|
+
and args.recognizer is not None
|
261
|
+
and args.task is not None
|
262
|
+
), "Missing required arguments: domain, recognizer, or task"
|
263
|
+
return args
|
264
|
+
|
265
|
+
|
266
|
+
if __name__ == "__main__":
|
267
|
+
args = parse_args()
|
268
|
+
run_odgr_problem(args)
|