gr-libs 0.1.7.post0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (86)
  1. gr_libs/__init__.py +4 -1
  2. gr_libs/_evaluation/__init__.py +1 -0
  3. gr_libs/_evaluation/_analyze_results_cross_alg_cross_domain.py +260 -0
  4. gr_libs/_evaluation/_generate_experiments_results.py +141 -0
  5. gr_libs/_evaluation/_generate_task_specific_statistics_plots.py +497 -0
  6. gr_libs/_evaluation/_get_plans_images.py +61 -0
  7. gr_libs/_evaluation/_increasing_and_decreasing_.py +106 -0
  8. gr_libs/_version.py +2 -2
  9. gr_libs/all_experiments.py +294 -0
  10. gr_libs/environment/__init__.py +30 -9
  11. gr_libs/environment/_utils/utils.py +27 -0
  12. gr_libs/environment/environment.py +417 -54
  13. gr_libs/metrics/__init__.py +7 -0
  14. gr_libs/metrics/metrics.py +231 -54
  15. gr_libs/ml/__init__.py +2 -5
  16. gr_libs/ml/agent.py +21 -6
  17. gr_libs/ml/base/__init__.py +3 -1
  18. gr_libs/ml/base/rl_agent.py +81 -13
  19. gr_libs/ml/consts.py +1 -1
  20. gr_libs/ml/neural/__init__.py +1 -3
  21. gr_libs/ml/neural/deep_rl_learner.py +619 -378
  22. gr_libs/ml/neural/utils/__init__.py +1 -2
  23. gr_libs/ml/neural/utils/dictlist.py +3 -3
  24. gr_libs/ml/planner/mcts/{utils → _utils}/__init__.py +1 -1
  25. gr_libs/ml/planner/mcts/{utils → _utils}/node.py +11 -7
  26. gr_libs/ml/planner/mcts/{utils → _utils}/tree.py +15 -11
  27. gr_libs/ml/planner/mcts/mcts_model.py +571 -312
  28. gr_libs/ml/sequential/__init__.py +0 -1
  29. gr_libs/ml/sequential/_lstm_model.py +270 -0
  30. gr_libs/ml/tabular/__init__.py +1 -3
  31. gr_libs/ml/tabular/state.py +7 -7
  32. gr_libs/ml/tabular/tabular_q_learner.py +150 -82
  33. gr_libs/ml/tabular/tabular_rl_agent.py +42 -28
  34. gr_libs/ml/utils/__init__.py +2 -3
  35. gr_libs/ml/utils/format.py +28 -97
  36. gr_libs/ml/utils/math.py +5 -3
  37. gr_libs/ml/utils/other.py +3 -3
  38. gr_libs/ml/utils/storage.py +88 -81
  39. gr_libs/odgr_executor.py +268 -0
  40. gr_libs/problems/consts.py +1549 -1227
  41. gr_libs/recognizer/_utils/__init__.py +0 -0
  42. gr_libs/recognizer/_utils/format.py +18 -0
  43. gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +233 -88
  44. gr_libs/recognizer/graml/_gr_dataset.py +233 -0
  45. gr_libs/recognizer/graml/graml_recognizer.py +586 -252
  46. gr_libs/recognizer/recognizer.py +90 -30
  47. gr_libs/tutorials/draco_panda_tutorial.py +58 -0
  48. gr_libs/tutorials/draco_parking_tutorial.py +56 -0
  49. gr_libs/tutorials/gcdraco_panda_tutorial.py +62 -0
  50. gr_libs/tutorials/gcdraco_parking_tutorial.py +57 -0
  51. gr_libs/tutorials/graml_minigrid_tutorial.py +64 -0
  52. gr_libs/tutorials/graml_panda_tutorial.py +57 -0
  53. gr_libs/tutorials/graml_parking_tutorial.py +52 -0
  54. gr_libs/tutorials/graml_point_maze_tutorial.py +60 -0
  55. gr_libs/tutorials/graql_minigrid_tutorial.py +50 -0
  56. {gr_libs-0.1.7.post0.dist-info → gr_libs-0.2.2.dist-info}/METADATA +84 -29
  57. gr_libs-0.2.2.dist-info/RECORD +71 -0
  58. {gr_libs-0.1.7.post0.dist-info → gr_libs-0.2.2.dist-info}/WHEEL +1 -1
  59. gr_libs-0.2.2.dist-info/top_level.txt +2 -0
  60. tests/test_draco.py +14 -0
  61. tests/test_gcdraco.py +10 -0
  62. tests/test_graml.py +12 -8
  63. tests/test_graql.py +3 -2
  64. evaluation/analyze_results_cross_alg_cross_domain.py +0 -277
  65. evaluation/create_minigrid_map_image.py +0 -34
  66. evaluation/file_system.py +0 -42
  67. evaluation/generate_experiments_results.py +0 -92
  68. evaluation/generate_experiments_results_new_ver1.py +0 -254
  69. evaluation/generate_experiments_results_new_ver2.py +0 -331
  70. evaluation/generate_task_specific_statistics_plots.py +0 -272
  71. evaluation/get_plans_images.py +0 -47
  72. evaluation/increasing_and_decreasing_.py +0 -63
  73. gr_libs/environment/utils/utils.py +0 -17
  74. gr_libs/ml/neural/utils/penv.py +0 -57
  75. gr_libs/ml/sequential/lstm_model.py +0 -192
  76. gr_libs/recognizer/graml/gr_dataset.py +0 -134
  77. gr_libs/recognizer/utils/__init__.py +0 -1
  78. gr_libs/recognizer/utils/format.py +0 -13
  79. gr_libs-0.1.7.post0.dist-info/RECORD +0 -67
  80. gr_libs-0.1.7.post0.dist-info/top_level.txt +0 -4
  81. tutorials/graml_minigrid_tutorial.py +0 -34
  82. tutorials/graml_panda_tutorial.py +0 -41
  83. tutorials/graml_parking_tutorial.py +0 -39
  84. tutorials/graml_point_maze_tutorial.py +0 -39
  85. tutorials/graql_minigrid_tutorial.py +0 -34
  86. /gr_libs/environment/{utils → _utils}/__init__.py +0 -0
gr_libs/ml/utils/format.py CHANGED
@@ -1,100 +1,31 @@
-import numpy
-import re
-import torch
-import gr_libs.ml
-import gymnasium as gym
-import random
-
-def get_obss_preprocessor(obs_space):
-    # Check if obs_space is an image space
-    if isinstance(obs_space, gym.spaces.Box):
-        obs_space = {"image": obs_space.shape}
-
-        def preprocess_obss(obss, device=None):
-            return ml.DictList({
-                "image": preprocess_images(obss, device=device)
-            })
-
-    # Check if it is a MiniGrid observation space
-    elif isinstance(obs_space, gym.spaces.Dict) and "image" in obs_space.spaces.keys():
-        obs_space = {"image": obs_space.spaces["image"].shape, "text": 100}
-
-        vocab = Vocabulary(obs_space["text"])
-
-        def preprocess_obss(obss, device=None):
-            return ml.DictList({
-                "image": preprocess_images([obs["image"] for obs in obss], device=device),
-                "text": preprocess_texts([obs["mission"] for obs in obss], vocab, device=device)
-            })
-
-        preprocess_obss.vocab = vocab
-
-    # Check if it is a MiniGrid observation space
-    elif isinstance(obs_space, gym.spaces.Dict) and "observation" in obs_space.spaces.keys():
-        obs_space = {"observation": obs_space.spaces["observation"].shape}
-
-        def preprocess_obss(obss, device=None):
-            return ml.DictList({
-                "observation": preprocess_images(obss, device=device)
-            })
-
-
-    else:
-        raise ValueError("Unknown observation space: " + str(obs_space))
-
-    return obs_space, preprocess_obss
-
+""" formatting-related utilities """
 
-def preprocess_images(images, device=None):
-    # Bug of Pytorch: very slow if not first converted to numpy array
-    images = numpy.array(images)
-    return torch.tensor(images, device=device, dtype=torch.float)
-
-
-def random_subset_with_order(sequence, subset_size, is_consecutive = True):
-    if subset_size >= len(sequence):
-        return sequence
-    else:
-        if is_consecutive:
-            indices_to_select = [i for i in range(subset_size)]
-        else:
-            indices_to_select = sorted(random.sample(range(len(sequence)), subset_size)) # Randomly select indices to keep
-        return [sequence[i] for i in indices_to_select] # Return the elements corresponding to the selected indices
-
-
-
-def preprocess_texts(texts, vocab, device=None):
-    var_indexed_texts = []
-    max_text_len = 0
-
-    for text in texts:
-        tokens = re.findall("([a-z]+)", text.lower())
-        var_indexed_text = numpy.array([vocab[token] for token in tokens])
-        var_indexed_texts.append(var_indexed_text)
-        max_text_len = max(len(var_indexed_text), max_text_len)
-
-    indexed_texts = numpy.zeros((len(texts), max_text_len))
-
-    for i, indexed_text in enumerate(var_indexed_texts):
-        indexed_texts[i, :len(indexed_text)] = indexed_text
-
-    return torch.tensor(indexed_texts, device=device, dtype=torch.long)
-
-
-class Vocabulary:
-    """A mapping from tokens to ids with a capacity of `max_size` words.
-    It can be saved in a `vocab.json` file."""
-
-    def __init__(self, max_size):
-        self.max_size = max_size
-        self.vocab = {}
+import random
 
-    def load_vocab(self, vocab):
-        self.vocab = vocab
 
-    def __getitem__(self, token):
-        if not token in self.vocab.keys():
-            if len(self.vocab) >= self.max_size:
-                raise ValueError("Maximum vocabulary capacity reached")
-            self.vocab[token] = len(self.vocab) + 1
-        return self.vocab[token]
+def random_subset_with_order(sequence, subset_size, is_consecutive=True):
+    """
+    Returns a random subset of elements from the given sequence with a specified subset size.
+
+    Args:
+        sequence (list): The sequence of elements to select from.
+        subset_size (int): The size of the desired subset.
+        is_consecutive (bool, optional): Whether the selected subset should be consecutive elements from the sequence.
+            Defaults to True.
+
+    Returns:
+        list: A random subset of elements from the sequence.
+
+    """
+    if subset_size >= len(sequence):
+        return sequence
+    else:
+        if is_consecutive:
+            indices_to_select = [i for i in range(subset_size)]
+        else:
+            indices_to_select = sorted(
+                random.sample(range(len(sequence)), subset_size)
+            )  # Randomly select indices to keep
+        return [
+            sequence[i] for i in indices_to_select
+        ]  # Return the elements corresponding to the selected indices
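As a quick sanity check, a minimal usage sketch of the retained helper (the sequence and seed are illustrative, and assume the package is installed); note that with the default is_consecutive=True the function returns the first subset_size elements, i.e. a prefix, not an arbitrary consecutive window:

import random

from gr_libs.ml.utils.format import random_subset_with_order

random.seed(0)  # illustrative seed, for reproducibility only
seq = list(range(10))
print(random_subset_with_order(seq, 4))  # [0, 1, 2, 3]: a prefix, since is_consecutive defaults to True
print(random_subset_with_order(seq, 4, is_consecutive=False))  # e.g. [1, 4, 6, 8]: sorted random indices, original order preserved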
gr_libs/ml/utils/math.py CHANGED
@@ -1,7 +1,9 @@
+""" math-related functions """
+
 import math
-from typing import Callable, Generator, List
 
-def softmax(values: List[float]) -> List[float]:
+
+def softmax(values: list[float]) -> list[float]:
     """Computes softmax probabilities for an array of values
     TODO We should probably use numpy arrays here
     Args:
@@ -10,4 +12,4 @@ def softmax(values: List[float]) -> List[float]:
     Returns:
         np.array: softmax probabilities
     """
-    return [(math.exp(q)) / sum([math.exp(_q) for _q in values]) for q in values]
+    return [(math.exp(q)) / sum([math.exp(_q) for _q in values]) for q in values]
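A minimal check of the softmax above (the call values are illustrative, assuming gr-libs 0.2.2 is installed); note the list-based implementation recomputes the denominator for every element and math.exp can overflow for large inputs, so this sketch only covers small value arrays:

import math

from gr_libs.ml.utils.math import softmax

probs = softmax([0.0, 1.0, 2.0])
print([round(p, 4) for p in probs])  # [0.09, 0.2447, 0.6652]
assert math.isclose(sum(probs), 1.0)  # the outputs form a probability distribution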
gr_libs/ml/utils/other.py CHANGED
@@ -1,8 +1,8 @@
+import collections
 import random
+
 import numpy
 import torch
-import collections
-
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
@@ -21,4 +21,4 @@ def synthesize(array):
     d["std"] = numpy.std(array)
     d["min"] = numpy.amin(array)
     d["max"] = numpy.amax(array)
-    return d
+    return d
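For context, a hedged sketch of calling synthesize; the hunk only shows the function's tail, so any keys beyond "std", "min" and "max" (such as a mean entry) are an assumption:

from gr_libs.ml.utils.other import synthesize

stats = synthesize([1.0, 2.0, 3.0, 4.0])
# numpy.std computes the population std, ~1.118 for this array
print(stats["std"], stats["min"], stats["max"])  # 1.118..., 1.0, 4.0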
gr_libs/ml/utils/storage.py CHANGED
@@ -1,134 +1,141 @@
-import csv
 import os
-import torch
-import logging
-import sys
 
-from .other import device
 
 def create_folders_if_necessary(path):
     if not os.path.exists(path):
         os.makedirs(path)
 
 
-def get_storage_framework_dir(recognizer: str):
-    return os.path.join(get_storage_dir(),recognizer)
+def get_outputs_dir():
+    return "outputs"
 
-def get_storage_dir():
+
+def get_recognizer_outputs_dir(recognizer: str):
+    return os.path.join(get_outputs_dir(), recognizer)
+
+
+def get_gr_cache_dir():
     # Prefer local directory if it exists (e.g., in GitHub workspace)
-    if os.path.exists("dataset"):
-        return "dataset"
+    if os.path.exists("gr_cache"):
+        return "gr_cache"
     # Fall back to pre-mounted directory (e.g., in Docker container)
-    if os.path.exists("/preloaded_data"):
-        return "/preloaded_data"
+    if os.path.exists("/gr_cache"):
+        return "/gr_cache"
     # Default to "dataset" even if it doesn't exist (e.g., will be created)
-    return "dataset"
+    return "gr_cache"
+
+
+def get_trained_agents_dir():
+    # Prefer local directory if it exists (e.g., in GitHub workspace)
+    if os.path.exists("trained_agents"):
+        return "trained_agents"
+    # Fall back to pre-mounted directory (e.g., in Docker container)
+    if os.path.exists("/trained_agents"):
+        return "/trained_agents"
+    # Default to "dataset" even if it doesn't exist (e.g., will be created)
+    return "trained_agents"
 
-def _get_models_directory_name():
-    return "models"
 
 def _get_siamese_datasets_directory_name():
     return "siamese_datasets"
 
+
 def _get_observations_directory_name():
     return "observations"
 
-def get_observation_file_name(observability_percentage: float):
-    return 'obs' + str(observability_percentage) + '.pkl'
 
+def get_observation_file_name(observability_percentage: float):
+    return "obs" + str(observability_percentage) + ".pkl"
 
-def get_domain_dir(domain_name, recognizer:str):
-    return os.path.join(get_storage_framework_dir(recognizer), domain_name)
 
-def get_env_dir(domain_name, env_name, recognizer:str):
-    return os.path.join(get_domain_dir(domain_name, recognizer), env_name)
+def get_domain_outputs_dir(domain_name, recognizer: str):
+    return os.path.join(get_recognizer_outputs_dir(recognizer), domain_name)
 
-def get_observations_dir(domain_name, env_name, recognizer:str):
-    return os.path.join(get_env_dir(domain_name=domain_name, env_name=env_name, recognizer=recognizer), _get_observations_directory_name())
 
-def get_agent_model_dir(domain_name, model_name, class_name):
-    return os.path.join(get_storage_dir(), _get_models_directory_name(), domain_name, model_name, class_name)
+def get_env_outputs_dir(domain_name, env_name, recognizer: str):
+    return os.path.join(get_domain_outputs_dir(domain_name, recognizer), env_name)
 
-def get_lstm_model_dir(domain_name, env_name, model_name, recognizer:str):
-    return os.path.join(get_env_dir(domain_name=domain_name, env_name=env_name, recognizer=recognizer), model_name)
 
-def get_models_dir(domain_name, env_name, recognizer:str):
-    return os.path.join(get_env_dir(domain_name=domain_name, env_name=env_name, recognizer=recognizer), _get_models_directory_name())
+def get_observations_dir(domain_name, env_name, recognizer: str):
+    return os.path.join(
+        get_env_outputs_dir(
+            domain_name=domain_name, env_name=env_name, recognizer=recognizer
+        ),
+        _get_observations_directory_name(),
+    )
 
-### GRAML PATHS ###
 
-def get_siamese_dataset_path(domain_name, env_name, model_name, recognizer:str):
-    return os.path.join(get_lstm_model_dir(domain_name, env_name, model_name, recognizer), _get_siamese_datasets_directory_name())
+def get_agent_model_dir(domain_name, model_name, class_name):
+    return os.path.join(
+        get_trained_agents_dir(),
+        domain_name,
+        model_name,
+        class_name,
+    )
 
-def get_embeddings_result_path(domain_name, env_name, recognizer:str):
-    return os.path.join(get_env_dir(domain_name, env_name=env_name, recognizer=recognizer), "goal_embeddings")
 
-def get_embeddings_result_path(domain_name, env_name, recognizer:str):
-    return os.path.join(get_env_dir(domain_name, env_name=env_name, recognizer=recognizer), "goal_embeddings")
+def get_lstm_model_dir(domain_name, env_name, model_name, recognizer: str):
+    return os.path.join(
+        get_gr_cache_dir(), recognizer, domain_name, env_name, model_name
+    )
 
-def get_and_create(path):
-    create_folders_if_necessary(path)
-    return path
 
-def get_experiment_results_path(domain, env_name, task, recognizer:str):
-    return os.path.join(get_env_dir(domain, env_name=env_name, recognizer=recognizer), "experiment_results", env_name, task, "experiment_results")
+### GRAML PATHS ###
 
-def get_plans_result_path(domain_name, env_name, recognizer:str):
-    return os.path.join(get_env_dir(domain_name, env_name=env_name, recognizer=recognizer), "plans")
 
-def get_policy_sequences_result_path(domain_name, env_name, recognizer:str):
-    return os.path.join(get_env_dir(domain_name, env_name, recognizer=recognizer), "policy_sequences")
+def get_siamese_dataset_path(domain_name, env_name, model_name, recognizer: str):
+    return os.path.join(
+        get_lstm_model_dir(domain_name, env_name, model_name, recognizer),
+        _get_siamese_datasets_directory_name(),
+    )
 
-### END GRAML PATHS ###
-''
-### GRAQL PATHS ###
 
-def get_gr_as_rl_experiment_confidence_path(domain_name, env_name, recognizer:str):
-    return os.path.join(get_env_dir(domain_name=domain_name, env_name=env_name, recognizer=recognizer), "experiments")
+def get_embeddings_result_path(domain_name, env_name, recognizer: str):
+    return os.path.join(
+        get_env_outputs_dir(domain_name, env_name=env_name, recognizer=recognizer),
+        "goal_embeddings",
+    )
 
-### GRAQL PATHS ###
 
-def get_status_path(model_dir):
-    return os.path.join(model_dir, "status.pt")
+def get_and_create(path):
+    create_folders_if_necessary(path)
+    return path
 
 
-def get_status(model_dir):
-    path = get_status_path(model_dir)
-    return torch.load(path, map_location=device)
+def get_experiment_results_path(domain, env_name, task, recognizer: str):
+    return os.path.join(
+        get_env_outputs_dir(domain, env_name=env_name, recognizer=recognizer),
+        task,
+        "experiment_results",
+    )
 
 
-def save_status(status, model_dir):
-    path = get_status_path(model_dir)
-    utils.create_folders_if_necessary(path)
-    torch.save(status, path)
+def get_plans_result_path(domain_name, env_name, recognizer: str):
+    return os.path.join(
+        get_env_outputs_dir(domain_name, env_name=env_name, recognizer=recognizer),
+        "plans",
+    )
 
 
-def get_vocab(model_dir):
-    return get_status(model_dir)["vocab"]
+def get_policy_sequences_result_path(domain_name, env_name, recognizer: str):
+    return os.path.join(
+        get_env_outputs_dir(domain_name, env_name, recognizer=recognizer),
+        "policy_sequences",
+    )
 
 
-def get_model_state(model_dir):
-    return get_status(model_dir)["model_state"]
+### END GRAML PATHS ###
 
+### GRAQL PATHS ###
 
-def get_txt_logger(model_dir):
-    path = os.path.join(model_dir, "log.txt")
-    utils.create_folders_if_necessary(path)
 
-    logging.basicConfig(
-        level=logging.INFO,
-        format="%(message)s",
-        handlers=[
-            logging.FileHandler(filename=path),
-            logging.StreamHandler(sys.stdout)
-        ]
+def get_gr_as_rl_experiment_confidence_path(domain_name, env_name, recognizer: str):
+    return os.path.join(
+        get_env_outputs_dir(
+            domain_name=domain_name, env_name=env_name, recognizer=recognizer
+        ),
+        "confidence",
     )
 
-    return logging.getLogger()
 
-
-def get_csv_logger(model_dir):
-    csv_path = os.path.join(model_dir, "log.csv")
-    utils.create_folders_if_necessary(csv_path)
-    csv_file = open(csv_path, "a")
-    return csv_file, csv.writer(csv_file)
+### GRAQL PATHS ###
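The renamed helpers compose into the new on-disk layout. A sketch with hypothetical domain, env and model names, assuming neither the local nor the pre-mounted override directories exist so the defaults are used:

from gr_libs.ml.utils.storage import get_agent_model_dir, get_lstm_model_dir

# Hypothetical arguments, purely illustrative; POSIX-style separators shown:
print(get_lstm_model_dir("minigrid", "MiniGrid-Empty-8x8", "lstm_cnn", "GCGraml"))
# gr_cache/GCGraml/minigrid/MiniGrid-Empty-8x8/lstm_cnn
print(get_agent_model_dir("minigrid", "MiniGrid-Empty-8x8", "TabularQLearner"))
# trained_agents/minigrid/MiniGrid-Empty-8x8/TabularQLearner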
gr_libs/odgr_executor.py ADDED
@@ -0,0 +1,268 @@
+import argparse
+import os
+import time
+
+import dill
+
+from gr_libs.environment.utils.utils import domain_to_env_property
+from gr_libs.metrics.metrics import stochastic_amplified_selection
+from gr_libs.ml.neural.deep_rl_learner import DeepRLAgent
+from gr_libs.ml.utils.format import random_subset_with_order
+from gr_libs.ml.utils.storage import (
+    get_and_create,
+    get_experiment_results_path,
+    get_policy_sequences_result_path,
+)
+from gr_libs.problems.consts import PROBLEMS
+from gr_libs.recognizer.gr_as_rl.gr_as_rl_recognizer import Draco, GCDraco
+from gr_libs.recognizer.graml.graml_recognizer import Graml
+from gr_libs.recognizer.recognizer import GaAgentTrainerRecognizer, LearningRecognizer
+from gr_libs.recognizer.utils import recognizer_str_to_obj
+
+
+def validate(args, recognizer_type, task_inputs):
+    if "base" in task_inputs.keys():
+        # assert issubclass(recognizer_type, LearningRecognizer), f"base is in the task_inputs for the recognizer {args.recognizer}, which doesn't have a domain learning phase (is not a learning recognizer)."
+        assert (
+            list(task_inputs.keys())[0] == "base"
+        ), "In case of LearningRecognizer, base should be the first element in the task_inputs dict in consts.py"
+        assert (
+            "base" not in list(task_inputs.keys())[1:]
+        ), "In case of LearningRecognizer, base should be only in the first element in the task_inputs dict in consts.py"
+    # else:
+    #     assert not issubclass(recognizer_type, LearningRecognizer), f"base is not in the task_inputs for the recognizer {args.recognizer}, which has a domain learning phase (is a learning recognizer). Remove it from the task_inputs dict in consts.py."
+
+
+def run_odgr_problem(args):
+    recognizer_type = recognizer_str_to_obj(args.recognizer)
+    env_inputs = PROBLEMS[args.domain]
+    assert (
+        args.env_name in env_inputs.keys()
+    ), f"env_name {args.env_name} is not in the list of available environments for the domain {args.domain}. Add it to PROBLEMS dict in consts.py"
+    task_inputs = env_inputs[args.env_name][args.task]
+    recognizer = recognizer_type(
+        domain_name=args.domain,
+        env_name=args.env_name,
+        collect_statistics=args.collect_stats,
+    )
+    validate(args, recognizer_type, task_inputs)
+    ga_times, results = [], {}
+    for key, value in task_inputs.items():
+        if key == "base":
+            dlp_time = 0
+            if issubclass(recognizer_type, LearningRecognizer):
+                start_dlp_time = time.time()
+                recognizer.domain_learning_phase(
+                    base_goals=value["goals"], train_configs=value["train_configs"]
+                )
+                dlp_time = time.time() - start_dlp_time
+        elif key.startswith("G_"):
+            start_ga_time = time.time()
+            kwargs = {"dynamic_goals": value["goals"]}
+            if issubclass(recognizer_type, GaAgentTrainerRecognizer):
+                kwargs["dynamic_train_configs"] = value["train_configs"]
+            recognizer.goals_adaptation_phase(**kwargs)
+            ga_times.append(time.time() - start_ga_time)
+        elif key.startswith("I_"):
+            goal, train_config, consecutive, consecutive_str, percentage = (
+                value["goal"],
+                value["train_config"],
+                value["consecutive"],
+                "consecutive" if value["consecutive"] == True else "non_consecutive",
+                value["percentage"],
+            )
+            results.setdefault(str(percentage), {})
+            results[str(percentage)].setdefault(
+                consecutive_str,
+                {
+                    "correct": 0,
+                    "num_of_tasks": 0,
+                    "accuracy": 0,
+                    "average_inference_time": 0,
+                },
+            )
+            property_type = domain_to_env_property(args.domain)
+            env_property = property_type(args.env_name)
+            problem_name = env_property.goal_to_problem_str(goal)
+            rl_agent_type = recognizer.rl_agent_type
+            agent = rl_agent_type(
+                domain_name=args.domain,
+                problem_name=problem_name,
+                algorithm=train_config[0],
+                num_timesteps=train_config[1],
+                env_prop=env_property,
+            )
+            agent.learn()
+            fig_path = get_and_create(
+                f"{os.path.abspath(os.path.join(get_policy_sequences_result_path(domain_name=args.domain, env_name=args.env_name, recognizer=args.recognizer), problem_name))}_inference_seq"
+            )
+            generate_obs_kwargs = {
+                "action_selection_method": stochastic_amplified_selection,
+                "save_fig": args.collect_stats,
+                "random_optimalism": True,
+                "fig_path": fig_path if args.collect_stats else None,
+            }
+
+            # need to dump the whole plan for draco because it needs it for inference phase for checking likelihood.
+            if (recognizer_type == Draco or recognizer_type == GCDraco) and issubclass(
+                rl_agent_type, DeepRLAgent
+            ):  # TODO remove this condition, remove the assumption.
+                generate_obs_kwargs["with_dict"] = True
+            sequence = agent.generate_observation(**generate_obs_kwargs)
+            if issubclass(
+                recognizer_type, Graml
+            ):  # need to dump the plans to compute offline plan similarity only in graml's case for evaluation.
+                recognizer.dump_plans(
+                    true_sequence=sequence, true_goal=goal, percentage=percentage
+                )
+            partial_sequence = random_subset_with_order(
+                sequence, (int)(percentage * len(sequence)), is_consecutive=consecutive
+            )
+            # add evaluation_function to kwargs if this is graql. move everything to kwargs...
+            start_inf_time = time.time()
+            closest_goal = recognizer.inference_phase(
+                partial_sequence, goal, percentage
+            )
+            results[str(percentage)][consecutive_str]["average_inference_time"] += (
+                time.time() - start_inf_time
+            )
+            # print(f'real goal {goal}, closest goal is: {closest_goal}')
+            if all(a == b for a, b in zip(str(goal), closest_goal)):
+                results[str(percentage)][consecutive_str]["correct"] += 1
+            results[str(percentage)][consecutive_str]["num_of_tasks"] += 1
+
+    for percentage in results.keys():
+        for consecutive_str in results[str(percentage)].keys():
+            results[str(percentage)][consecutive_str]["average_inference_time"] /= len(
+                results[str(percentage)][consecutive_str]
+            )
+            results[str(percentage)][consecutive_str]["accuracy"] = (
+                results[str(percentage)][consecutive_str]["correct"]
+                / results[str(percentage)][consecutive_str]["num_of_tasks"]
+            )
+
+    # aggregate
+    total_correct = sum(
+        [
+            result["correct"]
+            for cons_result in results.values()
+            for result in cons_result.values()
+        ]
+    )
+    total_tasks = sum(
+        [
+            result["num_of_tasks"]
+            for cons_result in results.values()
+            for result in cons_result.values()
+        ]
+    )
+    total_average_inference_time = (
+        sum(
+            [
+                result["average_inference_time"]
+                for cons_result in results.values()
+                for result in cons_result.values()
+            ]
+        )
+        / total_tasks
+    )
+
+    results["total"] = {
+        "total_correct": total_correct,
+        "total_tasks": total_tasks,
+        "total_accuracy": total_correct / total_tasks,
+        "total_average_inference_time": total_average_inference_time,
+        "goals_adaptation_time": sum(ga_times) / len(ga_times),
+        "domain_learning_time": dlp_time,
+    }
+    print(str(results))
+    res_file_path = get_and_create(
+        get_experiment_results_path(
+            domain=args.domain,
+            env_name=args.env_name,
+            task=args.task,
+            recognizer=args.recognizer,
+        )
+    )
+    print(f"generating results into {res_file_path}")
+    with open(os.path.join(res_file_path, "res.pkl"), "wb") as results_file:
+        dill.dump(results, results_file)
+    with open(os.path.join(res_file_path, "res.txt"), "w") as results_file:
+        results_file.write(str(results))
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Parse command-line arguments for the RL experiment.",
+        formatter_class=argparse.RawTextHelpFormatter,
+    )
+
+    # Required arguments
+    required_group = parser.add_argument_group("Required arguments")
+    required_group.add_argument(
+        "--domain",
+        choices=["point_maze", "minigrid", "parking", "panda"],
+        required=True,
+        help="Domain name (point_maze, minigrid, parking, or panda)",
+    )
+    required_group.add_argument(
+        "--env_name",
+        required=True,
+        help="Env name (point_maze, minigrid, parking, or panda). For example, Parking-S-14-PC--v0",
+    )
+    required_group.add_argument(
+        "--recognizer",
+        choices=[
+            "MCTSBasedGraml",
+            "ExpertBasedGraml",
+            "GCGraml",
+            "Graql",
+            "Draco",
+            "GCDraco",
+        ],
+        required=True,
+        help="Recognizer type. Follow readme.md and recognizer folder for more information and rules.",
+    )
+    required_group.add_argument(
+        "--task",
+        choices=[
+            "L1",
+            "L2",
+            "L3",
+            "L4",
+            "L5",
+            "L11",
+            "L22",
+            "L33",
+            "L44",
+            "L55",
+            "L111",
+            "L222",
+            "L333",
+            "L444",
+            "L555",
+        ],
+        required=True,
+        help="Task identifier (e.g., L1, L2,...,L5)",
+    )
+
+    # Optional arguments
+    optional_group = parser.add_argument_group("Optional arguments")
+    optional_group.add_argument(
+        "--collect_stats", action="store_true", help="Whether to collect statistics"
+    )
+    args = parser.parse_args()
+
+    ### VALIDATE INPUTS ###
+    # Assert that all required arguments are provided
+    assert (
+        args.domain is not None
+        and args.recognizer is not None
+        and args.task is not None
+    ), "Missing required arguments: domain, recognizer, or task"
+    return args
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    run_odgr_problem(args)
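A sketch of driving the new executor programmatically rather than from the shell; the flag names come from parse_args() above, the env_name value is borrowed from its help text, and the run assumes the PROBLEMS dict in consts.py defines an L1 task for that environment:

import sys

from gr_libs.odgr_executor import parse_args, run_odgr_problem

# Equivalent to: python -m gr_libs.odgr_executor --domain parking \
#     --env_name Parking-S-14-PC--v0 --recognizer GCDraco --task L1
sys.argv = [
    "odgr_executor.py",
    "--domain", "parking",
    "--env_name", "Parking-S-14-PC--v0",
    "--recognizer", "GCDraco",
    "--task", "L1",
]
run_odgr_problem(parse_args())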