gr-libs 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. evaluation/analyze_results_cross_alg_cross_domain.py +277 -0
  2. evaluation/create_minigrid_map_image.py +34 -0
  3. evaluation/file_system.py +42 -0
  4. evaluation/generate_experiments_results.py +92 -0
  5. evaluation/generate_experiments_results_new_ver1.py +254 -0
  6. evaluation/generate_experiments_results_new_ver2.py +331 -0
  7. evaluation/generate_task_specific_statistics_plots.py +272 -0
  8. evaluation/get_plans_images.py +47 -0
  9. evaluation/increasing_and_decreasing_.py +63 -0
  10. gr_libs/__init__.py +2 -0
  11. gr_libs/environment/__init__.py +0 -0
  12. gr_libs/environment/environment.py +227 -0
  13. gr_libs/environment/utils/__init__.py +0 -0
  14. gr_libs/environment/utils/utils.py +17 -0
  15. gr_libs/metrics/__init__.py +0 -0
  16. gr_libs/metrics/metrics.py +224 -0
  17. gr_libs/ml/__init__.py +6 -0
  18. gr_libs/ml/agent.py +56 -0
  19. gr_libs/ml/base/__init__.py +1 -0
  20. gr_libs/ml/base/rl_agent.py +54 -0
  21. gr_libs/ml/consts.py +22 -0
  22. gr_libs/ml/neural/__init__.py +3 -0
  23. gr_libs/ml/neural/deep_rl_learner.py +395 -0
  24. gr_libs/ml/neural/utils/__init__.py +2 -0
  25. gr_libs/ml/neural/utils/dictlist.py +33 -0
  26. gr_libs/ml/neural/utils/penv.py +57 -0
  27. gr_libs/ml/planner/__init__.py +0 -0
  28. gr_libs/ml/planner/mcts/__init__.py +0 -0
  29. gr_libs/ml/planner/mcts/mcts_model.py +330 -0
  30. gr_libs/ml/planner/mcts/utils/__init__.py +2 -0
  31. gr_libs/ml/planner/mcts/utils/node.py +33 -0
  32. gr_libs/ml/planner/mcts/utils/tree.py +102 -0
  33. gr_libs/ml/sequential/__init__.py +1 -0
  34. gr_libs/ml/sequential/lstm_model.py +192 -0
  35. gr_libs/ml/tabular/__init__.py +3 -0
  36. gr_libs/ml/tabular/state.py +21 -0
  37. gr_libs/ml/tabular/tabular_q_learner.py +453 -0
  38. gr_libs/ml/tabular/tabular_rl_agent.py +126 -0
  39. gr_libs/ml/utils/__init__.py +6 -0
  40. gr_libs/ml/utils/env.py +7 -0
  41. gr_libs/ml/utils/format.py +100 -0
  42. gr_libs/ml/utils/math.py +13 -0
  43. gr_libs/ml/utils/other.py +24 -0
  44. gr_libs/ml/utils/storage.py +127 -0
  45. gr_libs/recognizer/__init__.py +0 -0
  46. gr_libs/recognizer/gr_as_rl/__init__.py +0 -0
  47. gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +102 -0
  48. gr_libs/recognizer/graml/__init__.py +0 -0
  49. gr_libs/recognizer/graml/gr_dataset.py +134 -0
  50. gr_libs/recognizer/graml/graml_recognizer.py +266 -0
  51. gr_libs/recognizer/recognizer.py +46 -0
  52. gr_libs/recognizer/utils/__init__.py +1 -0
  53. gr_libs/recognizer/utils/format.py +13 -0
  54. gr_libs-0.1.3.dist-info/METADATA +197 -0
  55. gr_libs-0.1.3.dist-info/RECORD +62 -0
  56. gr_libs-0.1.3.dist-info/WHEEL +5 -0
  57. gr_libs-0.1.3.dist-info/top_level.txt +3 -0
  58. tutorials/graml_minigrid_tutorial.py +30 -0
  59. tutorials/graml_panda_tutorial.py +32 -0
  60. tutorials/graml_parking_tutorial.py +38 -0
  61. tutorials/graml_point_maze_tutorial.py +43 -0
  62. tutorials/graql_minigrid_tutorial.py +29 -0
gr_libs/ml/utils/format.py
@@ -0,0 +1,100 @@
+ import random
+ import re
+
+ import numpy
+ import torch
+ import gymnasium as gym
+
+ import gr_libs.ml as ml
+
+
+ def get_obss_preprocessor(obs_space):
+     # Check if obs_space is a plain image space
+     if isinstance(obs_space, gym.spaces.Box):
+         obs_space = {"image": obs_space.shape}
+
+         def preprocess_obss(obss, device=None):
+             return ml.DictList({
+                 "image": preprocess_images(obss, device=device)
+             })
+
+     # Check if it is a MiniGrid observation space (image + mission text)
+     elif isinstance(obs_space, gym.spaces.Dict) and "image" in obs_space.spaces:
+         obs_space = {"image": obs_space.spaces["image"].shape, "text": 100}
+
+         vocab = Vocabulary(obs_space["text"])
+
+         def preprocess_obss(obss, device=None):
+             return ml.DictList({
+                 "image": preprocess_images([obs["image"] for obs in obss], device=device),
+                 "text": preprocess_texts([obs["mission"] for obs in obss], vocab, device=device)
+             })
+
+         preprocess_obss.vocab = vocab
+
+     # Check if it is a goal-conditioned Dict space with an "observation" key
+     elif isinstance(obs_space, gym.spaces.Dict) and "observation" in obs_space.spaces:
+         obs_space = {"observation": obs_space.spaces["observation"].shape}
+
+         def preprocess_obss(obss, device=None):
+             return ml.DictList({
+                 "observation": preprocess_images(obss, device=device)
+             })
+
+     else:
+         raise ValueError("Unknown observation space: " + str(obs_space))
+
+     return obs_space, preprocess_obss
+
+
+ def preprocess_images(images, device=None):
+     # PyTorch is very slow at building a tensor from a list of arrays,
+     # so convert to a single numpy array first.
+     images = numpy.array(images)
+     return torch.tensor(images, device=device, dtype=torch.float)
+
+
+ def random_subset_with_order(sequence, subset_size, is_consecutive=True):
+     """Return `subset_size` elements of `sequence`, preserving their original order."""
+     if subset_size >= len(sequence):
+         return sequence
+     if is_consecutive:
+         indices_to_select = list(range(subset_size))
+     else:
+         # Randomly select which indices to keep, then sort to preserve order.
+         indices_to_select = sorted(random.sample(range(len(sequence)), subset_size))
+     return [sequence[i] for i in indices_to_select]
+
+
+ def preprocess_texts(texts, vocab, device=None):
+     var_indexed_texts = []
+     max_text_len = 0
+
+     for text in texts:
+         tokens = re.findall("([a-z]+)", text.lower())
+         var_indexed_text = numpy.array([vocab[token] for token in tokens])
+         var_indexed_texts.append(var_indexed_text)
+         max_text_len = max(len(var_indexed_text), max_text_len)
+
+     # Pad every tokenized text with zeros up to the longest text in the batch.
+     indexed_texts = numpy.zeros((len(texts), max_text_len))
+
+     for i, indexed_text in enumerate(var_indexed_texts):
+         indexed_texts[i, :len(indexed_text)] = indexed_text
+
+     return torch.tensor(indexed_texts, device=device, dtype=torch.long)
+
+
+ class Vocabulary:
+     """A mapping from tokens to ids with a capacity of `max_size` words.
+     It can be saved in a `vocab.json` file."""
+
+     def __init__(self, max_size):
+         self.max_size = max_size
+         self.vocab = {}
+
+     def load_vocab(self, vocab):
+         self.vocab = vocab
+
+     def __getitem__(self, token):
+         if token not in self.vocab:
+             if len(self.vocab) >= self.max_size:
+                 raise ValueError("Maximum vocabulary capacity reached")
+             self.vocab[token] = len(self.vocab) + 1
+         return self.vocab[token]
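A minimal usage sketch (not shipped with the package), assuming `DictList` is exported from `gr_libs.ml` as the module's own import suggests: feeding a plain image observation space through `get_obss_preprocessor` and batching a few raw observations.

    import gymnasium as gym
    import numpy as np
    from gr_libs.ml.utils.format import get_obss_preprocessor

    # Hypothetical 7x7 RGB image space, just for illustration.
    obs_space = gym.spaces.Box(low=0, high=255, shape=(7, 7, 3), dtype=np.uint8)
    shapes, preprocess_obss = get_obss_preprocessor(obs_space)
    # shapes is now {"image": (7, 7, 3)}

    batch = [obs_space.sample() for _ in range(4)]
    tensors = preprocess_obss(batch, device="cpu")
    # tensors.image is a float tensor of shape (4, 7, 7, 3)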
gr_libs/ml/utils/math.py
@@ -0,0 +1,13 @@
+ import math
+ from typing import List
+
+
+ def softmax(values: List[float]) -> List[float]:
+     """Compute softmax probabilities for a list of values.
+
+     TODO: we should probably use numpy arrays here.
+
+     Args:
+         values (List[float]): input values for which to compute softmax
+
+     Returns:
+         List[float]: softmax probabilities
+     """
+     total = sum(math.exp(q) for q in values)
+     return [math.exp(q) / total for q in values]
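A quick check of the helper above (illustration only): softmax(q)_i = exp(q_i) / sum_j exp(q_j), so the outputs are positive and sum to 1.

    from gr_libs.ml.utils.math import softmax

    probs = softmax([1.0, 2.0, 3.0])
    print(probs)       # ~ [0.0900, 0.2447, 0.6652]
    print(sum(probs))  # 1.0, up to floating-point error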
gr_libs/ml/utils/other.py
@@ -0,0 +1,24 @@
+ import collections
+ import random
+
+ import numpy
+ import torch
+
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+ def seed(seed):
+     """Seed all random number generators used by the library."""
+     random.seed(seed)
+     numpy.random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(seed)
+
+
+ def synthesize(array):
+     """Summarize an array of values as an ordered dict of basic statistics."""
+     d = collections.OrderedDict()
+     d["mean"] = numpy.mean(array)
+     d["std"] = numpy.std(array)
+     d["min"] = numpy.amin(array)
+     d["max"] = numpy.amax(array)
+     return d
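Hypothetical usage of the two helpers above: seed every RNG once, then summarize a list of episode returns.

    from gr_libs.ml.utils.other import seed, synthesize

    seed(42)
    stats = synthesize([1.0, 2.0, 3.0, 4.0])
    print(stats)  # OrderedDict([('mean', 2.5), ('std', 1.118...), ('min', 1.0), ('max', 4.0)])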
gr_libs/ml/utils/storage.py
@@ -0,0 +1,127 @@
+ import csv
+ import logging
+ import os
+ import sys
+
+ import torch
+
+ from .other import device
+
+
+ def create_folders_if_necessary(path):
+     if not os.path.exists(path):
+         os.makedirs(path)
+
+
+ def get_storage_framework_dir(recognizer: str):
+     return os.path.join(get_storage_dir(), recognizer)
+
+ def get_storage_dir():
+     return "dataset"
+
+ def _get_models_directory_name():
+     return "models"
+
+ def _get_siamese_datasets_directory_name():
+     return "siamese_datasets"
+
+ def _get_observations_directory_name():
+     return "observations"
+
+ def get_observation_file_name(observability_percentage: float):
+     return 'obs' + str(observability_percentage) + '.pkl'
+
+ def get_domain_dir(domain_name, recognizer: str):
+     return os.path.join(get_storage_framework_dir(recognizer), domain_name)
+
+ def get_env_dir(domain_name, env_name, recognizer: str):
+     return os.path.join(get_domain_dir(domain_name, recognizer), env_name)
+
+ def get_observations_dir(domain_name, env_name, recognizer: str):
+     return os.path.join(get_env_dir(domain_name=domain_name, env_name=env_name, recognizer=recognizer), _get_observations_directory_name())
+
+ def get_agent_model_dir(domain_name, model_name, class_name):
+     return os.path.join(get_storage_dir(), _get_models_directory_name(), domain_name, model_name, class_name)
+
+ def get_lstm_model_dir(domain_name, env_name, model_name, recognizer: str):
+     return os.path.join(get_env_dir(domain_name=domain_name, env_name=env_name, recognizer=recognizer), model_name)
+
+ def get_models_dir(domain_name, env_name, recognizer: str):
+     return os.path.join(get_env_dir(domain_name=domain_name, env_name=env_name, recognizer=recognizer), _get_models_directory_name())
+
+ ### GRAML PATHS ###
+
+ def get_siamese_dataset_path(domain_name, env_name, model_name, recognizer: str):
+     return os.path.join(get_lstm_model_dir(domain_name, env_name, model_name, recognizer), _get_siamese_datasets_directory_name())
+
+ def get_embeddings_result_path(domain_name, env_name, recognizer: str):
+     return os.path.join(get_env_dir(domain_name, env_name=env_name, recognizer=recognizer), "goal_embeddings")
+
+ def get_and_create(path):
+     create_folders_if_necessary(path)
+     return path
+
+ def get_experiment_results_path(domain, env_name, task, recognizer: str):
+     return os.path.join(get_env_dir(domain, env_name=env_name, recognizer=recognizer), "experiment_results", env_name, task, "experiment_results")
+
+ def get_plans_result_path(domain_name, env_name, recognizer: str):
+     return os.path.join(get_env_dir(domain_name, env_name=env_name, recognizer=recognizer), "plans")
+
+ def get_policy_sequences_result_path(domain_name, env_name, recognizer: str):
+     return os.path.join(get_env_dir(domain_name, env_name, recognizer=recognizer), "policy_sequences")
+
+ ### END GRAML PATHS ###
+
+ ### GRAQL PATHS ###
+
+ def get_gr_as_rl_experiment_confidence_path(domain_name, env_name, recognizer: str):
+     return os.path.join(get_env_dir(domain_name=domain_name, env_name=env_name, recognizer=recognizer), "experiments")
+
+ ### END GRAQL PATHS ###
+
+ def get_status_path(model_dir):
+     return os.path.join(model_dir, "status.pt")
+
+
+ def get_status(model_dir):
+     path = get_status_path(model_dir)
+     return torch.load(path, map_location=device)
+
+
+ def save_status(status, model_dir):
+     path = get_status_path(model_dir)
+     # Create the parent directory of the status file, not the file path itself.
+     create_folders_if_necessary(os.path.dirname(path))
+     torch.save(status, path)
+
+
+ def get_vocab(model_dir):
+     return get_status(model_dir)["vocab"]
+
+
+ def get_model_state(model_dir):
+     return get_status(model_dir)["model_state"]
+
+
+ def get_txt_logger(model_dir):
+     path = os.path.join(model_dir, "log.txt")
+     create_folders_if_necessary(os.path.dirname(path))
+
+     logging.basicConfig(
+         level=logging.INFO,
+         format="%(message)s",
+         handlers=[
+             logging.FileHandler(filename=path),
+             logging.StreamHandler(sys.stdout)
+         ]
+     )
+
+     return logging.getLogger()
+
+
+ def get_csv_logger(model_dir):
+     csv_path = os.path.join(model_dir, "log.csv")
+     create_folders_if_necessary(os.path.dirname(csv_path))
+     csv_file = open(csv_path, "a")
+     return csv_file, csv.writer(csv_file)
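Illustration only of how the path helpers above compose under the default storage root "dataset". The domain, environment, and recognizer names here are made up for the example.

    from gr_libs.ml.utils.storage import get_and_create, get_observations_dir

    obs_dir = get_and_create(
        get_observations_dir("minigrid", "MiniGrid-SimpleCrossingS13N4", recognizer="Graql")
    )
    # obs_dir == "dataset/Graql/minigrid/MiniGrid-SimpleCrossingS13N4/observations"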
File without changes
File without changes
gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py
@@ -0,0 +1,102 @@
+ from abc import abstractmethod
+ import os
+ from typing import List, Type
+
+ import dill
+ import numpy as np
+
+ from gr_libs.environment.environment import EnvProperty, GCEnvProperty
+ from gr_libs.environment.utils.utils import domain_to_env_property
+ from gr_libs.metrics.metrics import kl_divergence_norm_softmax, mean_wasserstein_distance
+ from gr_libs.ml.base import RLAgent
+ from gr_libs.ml.neural.deep_rl_learner import DeepRLAgent, GCDeepRLAgent
+ from gr_libs.ml.tabular.tabular_q_learner import TabularQLearner
+ from gr_libs.ml.utils.storage import get_gr_as_rl_experiment_confidence_path
+ from gr_libs.recognizer.recognizer import GaAdaptingRecognizer, GaAgentTrainerRecognizer, LearningRecognizer, Recognizer
+
+
+ class GRAsRL(Recognizer):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.agents = {}  # consider changing to ContextualAgent
+
+     def goals_adaptation_phase(self, dynamic_goals: List[str], dynamic_train_configs):
+         super().goals_adaptation_phase(dynamic_goals, dynamic_train_configs)
+         dynamic_goals_problems = [self.env_prop.goal_to_problem_str(goal) for goal in dynamic_goals]
+         self.active_goals = dynamic_goals
+         self.active_problems = dynamic_goals_problems
+         # Train one RL agent per dynamic goal.
+         for problem_name, config in zip(dynamic_goals_problems, dynamic_train_configs):
+             agent_kwargs = {"domain_name": self.env_prop.domain_name,
+                             "problem_name": problem_name}
+             if config[0]: agent_kwargs["algorithm"] = config[0]
+             if config[1]: agent_kwargs["num_timesteps"] = config[1]
+             agent = self.rl_agent_type(**agent_kwargs)
+             agent.learn()
+             self.agents[problem_name] = agent
+         self.action_space = next(iter(self.agents.values())).env.action_space
+
+     def inference_phase(self, inf_sequence, true_goal, percentage) -> str:
+         # Score the observed sequence against every candidate goal's agent;
+         # a lower divergence means the sequence better matches that goal's policy.
+         scores = []
+         for problem_name in self.active_problems:
+             agent = self.choose_agent(problem_name)
+             if self.env_prop.gc_adaptable():
+                 assert self.__class__.__name__ == "GCDraco", "This recognizer is not compatible with goal conditioned problems."
+                 inf_sequence = self.prepare_inf_sequence(problem_name, inf_sequence)
+             score = self.evaluation_function(inf_sequence, agent, self.action_space)
+             scores.append(score)
+         # scores = metrics.softmin(np.array(scores))
+         if self.collect_statistics:
+             results_path = get_gr_as_rl_experiment_confidence_path(domain_name=self.env_prop.domain_name, env_name=self.env_prop.name, recognizer=self.__class__.__name__)
+             if not os.path.exists(results_path): os.makedirs(results_path)
+             with open(results_path + f'/true_{true_goal}_{percentage}_scores.pkl', 'wb') as scores_file:
+                 dill.dump([(str(goal), score) for (goal, score) in zip(self.active_goals, scores)], scores_file)
+         # Pick the goal whose agent yields the minimal divergence.
+         min_score, true_goal_index = min((score, idx) for idx, score in enumerate(scores))
+         return str(self.active_goals[true_goal_index])
+
+     def choose_agent(self, problem_name: str) -> RLAgent:
+         return self.agents[problem_name]
+
+
+ class Graql(GRAsRL, GaAgentTrainerRecognizer):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         assert not self.env_prop.gc_adaptable() and self.env_prop.is_state_discrete() and self.env_prop.is_action_discrete()
+         if self.rl_agent_type is None: self.rl_agent_type = TabularQLearner
+         self.evaluation_function = kl_divergence_norm_softmax
+
+
+ class Draco(GRAsRL, GaAgentTrainerRecognizer):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         assert not self.env_prop.is_state_discrete() and not self.env_prop.is_action_discrete()
+         if self.rl_agent_type is None: self.rl_agent_type = DeepRLAgent
+         self.evaluation_function = mean_wasserstein_distance
+
+
+ class GCDraco(GRAsRL, LearningRecognizer, GaAdaptingRecognizer):  # TODO problem: it inherits two goals_adaptation_phase methods from its parents, one with configs and one without.
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         assert self.env_prop.gc_adaptable() and not self.env_prop.is_state_discrete() and not self.env_prop.is_action_discrete()
+         self.evaluation_function = mean_wasserstein_distance
+         if self.rl_agent_type is None: self.rl_agent_type = GCDeepRLAgent
+
+     def domain_learning_phase(self, base_goals: List[str], train_configs):
+         super().domain_learning_phase(base_goals, train_configs)
+         agent_kwargs = {"domain_name": self.env_prop.domain_name,
+                         "problem_name": self.env_prop.name,
+                         "algorithm": self.original_train_configs[0][0],
+                         "num_timesteps": self.original_train_configs[0][1]}
+         agent = self.rl_agent_type(**agent_kwargs)
+         agent.learn()
+         self.agents[self.env_prop.name] = agent
+         self.action_space = agent.env.action_space
+
+     # This method currently does nothing beyond bookkeeping, but optimizations can be made here.
+     def goals_adaptation_phase(self, dynamic_goals):
+         self.active_goals = dynamic_goals
+         self.active_problems = [self.env_prop.goal_to_problem_str(goal) for goal in dynamic_goals]
+
+     def choose_agent(self, problem_name: str) -> RLAgent:
+         # A single goal-conditioned agent serves every candidate problem.
+         return next(iter(self.agents.values()))
+
+     def prepare_inf_sequence(self, problem_name: str, inf_sequence):
+         if not self.env_prop.use_goal_directed_problem():
+             # Overwrite the desired goal in every observation with the candidate goal.
+             for obs in inf_sequence:
+                 obs[0]['desired_goal'] = np.array([self.env_prop.str_to_goal(problem_name)], dtype=obs[0]['desired_goal'].dtype)
+         return inf_sequence
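A standalone sketch of the selection rule used in GRAsRL.inference_phase above, with made-up goal names and scores: each candidate goal's agent assigns a divergence score to the observed sequence, and the goal with the smallest score is returned.

    active_goals = ["goal_A", "goal_B", "goal_C"]
    scores = [0.42, 0.17, 0.55]  # e.g. KL divergence (Graql) or Wasserstein distance (Draco/GCDraco)

    min_score, true_goal_index = min((score, idx) for idx, score in enumerate(scores))
    print(active_goals[true_goal_index])  # -> "goal_B"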
File without changes
gr_libs/recognizer/graml/gr_dataset.py
@@ -0,0 +1,134 @@
+ import os
+ import random
+ from types import MethodType
+ from typing import List
+
+ import dill
+ import numpy as np
+ import torch
+ from torch.utils.data import Dataset
+
+ from gr_libs.environment.environment import EnvProperty
+ from gr_libs.metrics.metrics import measure_average_sequence_distance
+ from gr_libs.ml.base import RLAgent
+ from gr_libs.ml.base.rl_agent import ContextualAgent
+ from gr_libs.ml.utils import get_siamese_dataset_path
+
+
+ class GRDataset(Dataset):
+     def __init__(self, num_samples, samples):
+         self.num_samples = num_samples
+         self.samples = samples
+
+     def __len__(self):
+         return self.num_samples
+
+     def __getitem__(self, idx):
+         return self.samples[idx]  # returns a tuple - as appended in the last lines of 'generate_datasets'
+
+
+ def check_diff_goals(first_agent_goal, second_agent_goal):
+     # Goals may be scalars, arrays, or sequences of arrays; try increasingly
+     # permissive comparisons until one of them confirms the goals differ.
+     try:
+         assert first_agent_goal != second_agent_goal
+     except Exception:
+         try:
+             assert any(first_agent_goal != second_agent_goal)
+         except Exception:
+             for arr1, arr2 in zip(first_agent_goal, second_agent_goal):
+                 assert any(elm1 != elm2 for elm1, elm2 in zip(arr1, arr2))
+
+
+ def generate_datasets(num_samples, agents: List[ContextualAgent], observation_creation_method: MethodType, problems: List[str], env_prop: EnvProperty, recognizer_name: str, gc_goal_set=None):
+     if gc_goal_set: model_name = env_prop.name
+     else: model_name = env_prop.problem_list_to_str_tuple(problems)
+     dataset_directory = get_siamese_dataset_path(domain_name=env_prop.domain_name, env_name=env_prop.name, model_name=model_name, recognizer=recognizer_name)
+     dataset_train_path, dataset_dev_path = os.path.join(dataset_directory, 'train.pkl'), os.path.join(dataset_directory, 'dev.pkl')
+     if os.path.exists(dataset_train_path) and os.path.exists(dataset_dev_path):
+         print(f"Loading pre-existing datasets in {dataset_directory}")
+         with open(dataset_train_path, 'rb') as train_file:
+             train_samples = dill.load(train_file)
+         with open(dataset_dev_path, 'rb') as dev_file:
+             dev_samples = dill.load(dev_file)
+     else:
+         print(f"{dataset_directory} doesn't exist, generating datasets")
+         if not os.path.exists(dataset_directory):
+             os.makedirs(dataset_directory)
+         all_samples = []
+         for i in range(num_samples):
+             if gc_goal_set is not None:  # TODO change to having one flow for both cases and injecting according to gc_goal_set or not
+                 assert env_prop.gc_adaptable(), "shouldn't specify a goal directed representation if not generating datasets with a general agent."
+                 is_same_goal = (np.random.choice([1, 0], 1, p=[1 / max(len(gc_goal_set), 6), 1 - 1 / max(len(gc_goal_set), 6)]))[0]
+                 first_is_consecutive = np.random.choice([True, False], 1, p=[0.5, 0.5])[0]
+                 first_random_index = np.random.randint(0, len(gc_goal_set))  # works for lists of every object type, while np.random.choice only works for 1d arrays
+                 first_agent_goal = gc_goal_set[first_random_index]  # could be either a real goal or a goal-directed problem name
+                 # first_agent_goal = np.random.choice(gc_goal_set)
+                 first_trace_percentage = random.choice([0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1])
+                 first_observation = []
+                 first_agent_kwargs = {
+                     "action_selection_method": observation_creation_method,
+                     "percentage": first_trace_percentage,
+                     "is_consecutive": first_is_consecutive,
+                     "save_fig": False
+                 }
+                 while first_observation == []:
+                     # needs to be different than agents[0] problem_name, it should be from the gc_goal_set.
+                     # but the problem is with the panda because it
+                     if env_prop.use_goal_directed_problem(): first_agent_kwargs["goal_directed_problem"] = first_agent_goal
+                     else: first_agent_kwargs["goal_directed_goal"] = first_agent_goal
+                     first_observation = agents[0].agent.generate_partial_observation(**first_agent_kwargs)
+                 first_observation = agents[0].agent.simplify_observation(first_observation)
+
+                 second_is_consecutive = np.random.choice([True, False], 1, p=[0.5, 0.5])[0]
+                 second_agent_goal = first_agent_goal
+                 second_random_index = first_random_index
+                 if not is_same_goal:
+                     second_random_index = np.random.choice([i for i in range(len(gc_goal_set)) if i != first_random_index])
+                     assert first_random_index != second_random_index
+                     second_agent_goal = gc_goal_set[second_random_index]
+                     check_diff_goals(first_agent_goal, second_agent_goal)
+                 second_trace_percentage = first_trace_percentage
+                 second_observation = []
+                 second_agent_kwargs = {
+                     "action_selection_method": observation_creation_method,
+                     "percentage": second_trace_percentage,
+                     "is_consecutive": second_is_consecutive,
+                     "save_fig": False
+                 }
+                 while second_observation == []:
+                     if env_prop.use_goal_directed_problem(): second_agent_kwargs["goal_directed_problem"] = second_agent_goal
+                     else: second_agent_kwargs["goal_directed_goal"] = second_agent_goal
+                     second_observation = agents[0].agent.generate_partial_observation(**second_agent_kwargs)
+                 second_observation = agents[0].agent.simplify_observation(second_observation)
+             else:
+                 is_same_goal = (np.random.choice([1, 0], 1, p=[1 / max(len(agents), 6), 1 - 1 / max(len(agents), 6)]))[0]
+                 first_is_consecutive = np.random.choice([True, False], 1, p=[0.5, 0.5])[0]
+                 first_agent = np.random.choice(agents)
+                 first_trace_percentage = random.choice([0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1])
+                 first_observation = first_agent.agent.generate_partial_observation(action_selection_method=observation_creation_method, percentage=first_trace_percentage, is_consecutive=first_is_consecutive, save_fig=False, random_optimalism=True)
+                 first_observation = first_agent.agent.simplify_observation(first_observation)
+
+                 second_agent = first_agent
+                 if not is_same_goal:
+                     second_agent = np.random.choice([agent for agent in agents if agent != first_agent])
+                     assert second_agent != first_agent
+                 second_is_consecutive = np.random.choice([True, False], 1, p=[0.5, 0.5])[0]
+                 second_trace_percentage = first_trace_percentage
+                 second_observation = second_agent.agent.generate_partial_observation(action_selection_method=observation_creation_method, percentage=second_trace_percentage, is_consecutive=second_is_consecutive, save_fig=False, random_optimalism=True)
+                 second_observation = second_agent.agent.simplify_observation(second_observation)
+                 if is_same_goal:
+                     observations_distance = measure_average_sequence_distance(first_observation, second_observation)  # for debugging
+             all_samples.append((
+                 [torch.tensor(observation, dtype=torch.float32) for observation in first_observation],
+                 [torch.tensor(observation, dtype=torch.float32) for observation in second_observation],
+                 torch.tensor(is_same_goal, dtype=torch.float32)))
+             # all_samples.append((first_observation, second_observation, torch.tensor(is_same_goal, dtype=torch.float32)))
+             if i % 1000 == 0:
+                 print(f'generated {i} samples')
+
+         # 80/20 train/dev split over the generated pairs.
+         total_samples = len(all_samples)
+         train_size = int(0.8 * total_samples)
+         train_samples = all_samples[:train_size]
+         dev_samples = all_samples[train_size:]
+         with open(dataset_train_path, 'wb') as train_file:
+             dill.dump(train_samples, train_file)
+         with open(dataset_dev_path, 'wb') as dev_file:
+             dill.dump(dev_samples, dev_file)
+
+     return train_samples, dev_samples
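A standalone sketch of the pair-labeling rule used in generate_datasets above (numbers are illustrative): a pair of traces is labeled "same goal" with probability 1 / max(num_goals, 6), which keeps positive pairs from dominating even when only a few goals are available.

    import numpy as np

    num_goals = 4
    p_same = 1 / max(num_goals, 6)  # = 1/6 here, even though only 4 goals exist
    is_same_goal = np.random.choice([1, 0], p=[p_same, 1 - p_same])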