gr-libs 0.1.7.post0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. gr_libs/__init__.py +4 -1
  2. gr_libs/_evaluation/__init__.py +1 -0
  3. gr_libs/_evaluation/_analyze_results_cross_alg_cross_domain.py +260 -0
  4. gr_libs/_evaluation/_generate_experiments_results.py +141 -0
  5. gr_libs/_evaluation/_generate_task_specific_statistics_plots.py +497 -0
  6. gr_libs/_evaluation/_get_plans_images.py +61 -0
  7. gr_libs/_evaluation/_increasing_and_decreasing_.py +106 -0
  8. gr_libs/_version.py +2 -2
  9. gr_libs/all_experiments.py +294 -0
  10. gr_libs/environment/__init__.py +30 -9
  11. gr_libs/environment/_utils/utils.py +27 -0
  12. gr_libs/environment/environment.py +417 -54
  13. gr_libs/metrics/__init__.py +7 -0
  14. gr_libs/metrics/metrics.py +231 -54
  15. gr_libs/ml/__init__.py +2 -5
  16. gr_libs/ml/agent.py +21 -6
  17. gr_libs/ml/base/__init__.py +3 -1
  18. gr_libs/ml/base/rl_agent.py +81 -13
  19. gr_libs/ml/consts.py +1 -1
  20. gr_libs/ml/neural/__init__.py +1 -3
  21. gr_libs/ml/neural/deep_rl_learner.py +619 -378
  22. gr_libs/ml/neural/utils/__init__.py +1 -2
  23. gr_libs/ml/neural/utils/dictlist.py +3 -3
  24. gr_libs/ml/planner/mcts/{utils → _utils}/__init__.py +1 -1
  25. gr_libs/ml/planner/mcts/{utils → _utils}/node.py +11 -7
  26. gr_libs/ml/planner/mcts/{utils → _utils}/tree.py +15 -11
  27. gr_libs/ml/planner/mcts/mcts_model.py +571 -312
  28. gr_libs/ml/sequential/__init__.py +0 -1
  29. gr_libs/ml/sequential/_lstm_model.py +270 -0
  30. gr_libs/ml/tabular/__init__.py +1 -3
  31. gr_libs/ml/tabular/state.py +7 -7
  32. gr_libs/ml/tabular/tabular_q_learner.py +150 -82
  33. gr_libs/ml/tabular/tabular_rl_agent.py +42 -28
  34. gr_libs/ml/utils/__init__.py +2 -3
  35. gr_libs/ml/utils/format.py +28 -97
  36. gr_libs/ml/utils/math.py +5 -3
  37. gr_libs/ml/utils/other.py +3 -3
  38. gr_libs/ml/utils/storage.py +88 -81
  39. gr_libs/odgr_executor.py +268 -0
  40. gr_libs/problems/consts.py +1549 -1227
  41. gr_libs/recognizer/_utils/__init__.py +0 -0
  42. gr_libs/recognizer/_utils/format.py +18 -0
  43. gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +233 -88
  44. gr_libs/recognizer/graml/_gr_dataset.py +233 -0
  45. gr_libs/recognizer/graml/graml_recognizer.py +586 -252
  46. gr_libs/recognizer/recognizer.py +90 -30
  47. gr_libs/tutorials/draco_panda_tutorial.py +58 -0
  48. gr_libs/tutorials/draco_parking_tutorial.py +56 -0
  49. gr_libs/tutorials/gcdraco_panda_tutorial.py +62 -0
  50. gr_libs/tutorials/gcdraco_parking_tutorial.py +57 -0
  51. gr_libs/tutorials/graml_minigrid_tutorial.py +64 -0
  52. gr_libs/tutorials/graml_panda_tutorial.py +57 -0
  53. gr_libs/tutorials/graml_parking_tutorial.py +52 -0
  54. gr_libs/tutorials/graml_point_maze_tutorial.py +60 -0
  55. gr_libs/tutorials/graql_minigrid_tutorial.py +50 -0
  56. {gr_libs-0.1.7.post0.dist-info → gr_libs-0.2.2.dist-info}/METADATA +84 -29
  57. gr_libs-0.2.2.dist-info/RECORD +71 -0
  58. {gr_libs-0.1.7.post0.dist-info → gr_libs-0.2.2.dist-info}/WHEEL +1 -1
  59. gr_libs-0.2.2.dist-info/top_level.txt +2 -0
  60. tests/test_draco.py +14 -0
  61. tests/test_gcdraco.py +10 -0
  62. tests/test_graml.py +12 -8
  63. tests/test_graql.py +3 -2
  64. evaluation/analyze_results_cross_alg_cross_domain.py +0 -277
  65. evaluation/create_minigrid_map_image.py +0 -34
  66. evaluation/file_system.py +0 -42
  67. evaluation/generate_experiments_results.py +0 -92
  68. evaluation/generate_experiments_results_new_ver1.py +0 -254
  69. evaluation/generate_experiments_results_new_ver2.py +0 -331
  70. evaluation/generate_task_specific_statistics_plots.py +0 -272
  71. evaluation/get_plans_images.py +0 -47
  72. evaluation/increasing_and_decreasing_.py +0 -63
  73. gr_libs/environment/utils/utils.py +0 -17
  74. gr_libs/ml/neural/utils/penv.py +0 -57
  75. gr_libs/ml/sequential/lstm_model.py +0 -192
  76. gr_libs/recognizer/graml/gr_dataset.py +0 -134
  77. gr_libs/recognizer/utils/__init__.py +0 -1
  78. gr_libs/recognizer/utils/format.py +0 -13
  79. gr_libs-0.1.7.post0.dist-info/RECORD +0 -67
  80. gr_libs-0.1.7.post0.dist-info/top_level.txt +0 -4
  81. tutorials/graml_minigrid_tutorial.py +0 -34
  82. tutorials/graml_panda_tutorial.py +0 -41
  83. tutorials/graml_parking_tutorial.py +0 -39
  84. tutorials/graml_point_maze_tutorial.py +0 -39
  85. tutorials/graql_minigrid_tutorial.py +0 -34
  86. /gr_libs/environment/{utils → _utils}/__init__.py +0 -0
gr_libs/recognizer/graml/graml_recognizer.py
@@ -1,274 +1,608 @@
1
- from abc import ABC, abstractmethod
2
- from collections import namedtuple
1
+ """ Collection of recognizers that use GRAML methods: metric learning for ODGR. """
2
+
3
3
  import os
4
- from gr_libs.environment.environment import EnvProperty, GCEnvProperty, LSTMProperties
5
- from gr_libs.ml import utils
6
- from gr_libs.ml.base import ContextualAgent
7
- from typing import List, Tuple
4
+ from abc import abstractmethod
5
+
6
+ import dill
8
7
  import numpy as np
9
- from torch.utils.data import DataLoader
10
- from torch.nn.utils.rnn import pad_sequence
11
8
  import torch
9
+ from torch.nn.utils.rnn import pad_sequence
10
+ from torch.utils.data import DataLoader
11
+
12
+ from gr_libs.environment.environment import EnvProperty
13
+ from gr_libs.metrics import metrics
14
+ from gr_libs.ml import utils
15
+ from gr_libs.ml.base import ContextualAgent
12
16
  from gr_libs.ml.neural.deep_rl_learner import DeepRLAgent, GCDeepRLAgent
13
17
  from gr_libs.ml.planner.mcts import mcts_model
14
- import dill
18
+ from gr_libs.ml.sequential._lstm_model import LstmObservations, train_metric_model
15
19
  from gr_libs.ml.tabular.tabular_q_learner import TabularQLearner
16
- from gr_libs.recognizer.graml.gr_dataset import GRDataset, generate_datasets
17
- from gr_libs.ml.sequential.lstm_model import LstmObservations, train_metric_model
18
20
  from gr_libs.ml.utils.format import random_subset_with_order
19
- from gr_libs.ml.utils.storage import get_and_create, get_lstm_model_dir, get_embeddings_result_path, get_policy_sequences_result_path
20
- from gr_libs.metrics import metrics
21
- from gr_libs.recognizer.recognizer import GaAdaptingRecognizer, GaAgentTrainerRecognizer, LearningRecognizer, Recognizer # import first, very dependent
21
+ from gr_libs.ml.utils.storage import (
22
+ get_and_create,
23
+ get_embeddings_result_path,
24
+ get_lstm_model_dir,
25
+ get_policy_sequences_result_path,
26
+ )
27
+ from gr_libs.recognizer.graml._gr_dataset import GRDataset, generate_datasets
28
+ from gr_libs.recognizer.recognizer import (
29
+ GaAdaptingRecognizer,
30
+ GaAgentTrainerRecognizer,
31
+ LearningRecognizer,
32
+ )
22
33
 
23
34
  ### TODO IMPLEMENT MORE SELECTION METHODS, MAKE SURE action_probs IS AS IT SEEMS: list of action-probability 'es ###
24
35
 
36
+
25
37
  def collate_fn(batch):
26
- first_traces, second_traces, is_same_goals = zip(*batch)
27
- # torch.stack takes tensor tuples (fixed size) and stacks them up in a matrix
28
- first_traces_padded = pad_sequence([torch.stack(sequence) for sequence in first_traces], batch_first=True)
29
- second_traces_padded = pad_sequence([torch.stack(sequence) for sequence in second_traces], batch_first=True)
30
- first_traces_lengths = [len(trace) for trace in first_traces]
31
- second_traces_lengths = [len(trace) for trace in second_traces]
32
- return first_traces_padded.to(utils.device), second_traces_padded.to(utils.device), torch.stack(is_same_goals).to(utils.device), first_traces_lengths, second_traces_lengths
33
-
34
- def load_weights(loaded_model : LstmObservations, path):
35
- # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
36
- loaded_model.load_state_dict(torch.load(path, map_location=utils.device))
37
- loaded_model.to(utils.device) # Ensure model is on the right device
38
- return loaded_model
39
-
40
- def save_weights(model : LstmObservations, path):
41
- directory = os.path.dirname(path)
42
- if not os.path.exists(directory):
43
- os.makedirs(directory)
44
- torch.save(model.state_dict(), path)
38
+ """
39
+ Collates a batch of data for training or evaluation.
40
+
41
+ Args:
42
+ batch (list): A list of tuples, where each tuple contains the first traces, second traces, and the label indicating whether the goals are the same.
43
+
44
+ Returns:
45
+ tuple: A tuple containing the padded first traces, padded second traces, labels, lengths of first traces, and lengths of second traces.
46
+ """
47
+ first_traces, second_traces, is_same_goals = zip(*batch)
48
+ # torch.stack takes tensor tuples (fixed size) and stacks them up in a matrix
49
+ first_traces_padded = pad_sequence(
50
+ [torch.stack(sequence) for sequence in first_traces], batch_first=True
51
+ )
52
+ second_traces_padded = pad_sequence(
53
+ [torch.stack(sequence) for sequence in second_traces], batch_first=True
54
+ )
55
+ first_traces_lengths = [len(trace) for trace in first_traces]
56
+ second_traces_lengths = [len(trace) for trace in second_traces]
57
+ return (
58
+ first_traces_padded.to(utils.device),
59
+ second_traces_padded.to(utils.device),
60
+ torch.stack(is_same_goals).to(utils.device),
61
+ first_traces_lengths,
62
+ second_traces_lengths,
63
+ )
64
+
65
+
66
+ def load_weights(loaded_model: LstmObservations, path):
67
+ # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
68
+ loaded_model.load_state_dict(torch.load(path, map_location=utils.device))
69
+ loaded_model.to(utils.device) # Ensure model is on the right device
70
+ return loaded_model
71
+
72
+
73
+ def save_weights(model: LstmObservations, path):
74
+ directory = os.path.dirname(path)
75
+ if not os.path.exists(directory):
76
+ os.makedirs(directory)
77
+ torch.save(model.state_dict(), path)
78
+
45
79
 
46
80
  class Graml(LearningRecognizer):
47
- def __init__(self, *args, **kwargs):
48
- super().__init__(*args, **kwargs)
49
- self.agents: List[ContextualAgent] = []
50
- self.train_func = train_metric_model; self.collate_func = collate_fn
51
-
52
- @abstractmethod
53
- def train_agents_on_base_goals(self, base_goals: List[str], train_configs: List):
54
- pass
55
-
56
- def domain_learning_phase(self, base_goals: List[str], train_configs: List):
57
- super().domain_learning_phase(base_goals, train_configs)
58
- self.train_agents_on_base_goals(base_goals, train_configs)
59
- # train the network so it will find a metric for the observations of the base agents such that traces of agents to different goals are far from one another
60
- self.model_directory = get_lstm_model_dir(domain_name=self.env_prop.domain_name, env_name=self.env_prop.name , model_name=self.env_prop.problem_list_to_str_tuple(self.original_problems), recognizer=self.__class__.__name__)
61
- last_path = r"lstm_model.pth"
62
- self.model_file_path = os.path.join(self.model_directory, last_path)
63
- self.model = LstmObservations(input_size=self.env_prop.get_lstm_props().input_size, hidden_size=self.env_prop.get_lstm_props().hidden_size)
64
- self.model.to(utils.device)
65
-
66
- if os.path.exists(self.model_file_path):
67
- print(f"Loading pre-existing lstm model in {self.model_file_path}")
68
- load_weights(loaded_model=self.model, path=self.model_file_path)
69
- else:
70
- print(f"{self.model_file_path} doesn't exist, training the model")
71
- train_samples, dev_samples = generate_datasets(num_samples=self.env_prop.get_lstm_props().num_samples,
72
- agents=self.agents,
73
- observation_creation_method=metrics.stochastic_amplified_selection,
74
- problems=self.original_problems,
75
- env_prop=self.env_prop,
76
- gc_goal_set=self.gc_goal_set if hasattr(self, 'gc_goal_set') else None,
77
- recognizer_name=self.__class__.__name__)
78
-
79
- train_dataset = GRDataset(len(train_samples), train_samples)
80
- dev_dataset = GRDataset(len(dev_samples), dev_samples)
81
- self.train_func(self.model, train_loader=DataLoader(train_dataset, batch_size=self.env_prop.get_lstm_props().batch_size, shuffle=False, collate_fn=self.collate_func),
82
- dev_loader=DataLoader(dev_dataset, batch_size=self.env_prop.get_lstm_props().batch_size, shuffle=False, collate_fn=self.collate_func))
83
- save_weights(model=self.model, path=self.model_file_path)
84
-
85
- def goals_adaptation_phase(self, dynamic_goals: List[EnvProperty], save_fig=False):
86
- self.is_first_inf_since_new_goals = True
87
- self.current_goals = dynamic_goals
88
- # start by training each rl agent on the base goal set
89
- self.embeddings_dict = {} # relevant if the embedding of the plan occurs during the goals adaptation phase
90
- self.plans_dict = {} # relevant if the embedding of the plan occurs during the inference phase
91
- for goal in self.current_goals:
92
- obss = self.generate_sequences_library(goal, save_fig=save_fig)
93
- self.plans_dict[str(goal)] = obss
94
-
95
- def get_goal_plan(self, goal):
96
- assert self.plans_dict, "plans_dict wasn't created during goals_adaptation_phase and now inference phase can't return the plans. when inference_same_length, keep the plans and not their embeddings during goals_adaptation_phase."
97
- return self.plans_dict[goal]
98
-
99
- def dump_plans(self, true_sequence, true_goal, percentage):
100
- assert self.plans_dict, "plans_dict wasn't created during goals_adaptation_phase and now inference phase can't return the plans. when inference_same_length, keep the plans and not their embeddings during goals_adaptation_phase."
101
- # Arrange storage
102
- embeddings_path = get_and_create(get_embeddings_result_path(domain_name=self.env_prop.domain_name, env_name=self.env_prop.name, recognizer=self.__class__.__name__))
103
- self.plans_dict[f"{true_goal}_true"] = true_sequence
104
-
105
- with open(embeddings_path + f'/{true_goal}_{percentage}_plans_dict.pkl', 'wb') as plans_file:
106
- to_dump = {}
107
- for goal, obss in self.plans_dict.items():
108
- if goal == f"{true_goal}_true":
109
- to_dump[goal] = self.agents[0].agent.simplify_observation(obss)
110
- else:
111
- to_dump[goal] = []
112
- for obs in obss:
113
- addition = self.agents[0].agent.simplify_observation(obs) if self.is_first_inf_since_new_goals else obs
114
- to_dump[goal].append(addition)
115
- dill.dump(to_dump, plans_file)
116
- self.plans_dict.pop(f"{true_goal}_true")
117
-
118
- def create_embeddings_dict(self):
119
- for goal, obss in self.plans_dict.items():
120
- self.embeddings_dict[goal] = []
121
- for (cons_seq, non_cons_seq) in obss:
122
- self.embeddings_dict[goal].append((self.model.embed_sequence(cons_seq), self.model.embed_sequence(non_cons_seq)))
123
-
124
- def inference_phase(self, inf_sequence, true_goal, percentage) -> str:
125
- embeddings_path = get_and_create(get_embeddings_result_path(domain_name=self.env_prop.domain_name, env_name=self.env_prop.name, recognizer=self.__class__.__name__))
126
- simplified_inf_sequence = self.agents[0].agent.simplify_observation(inf_sequence)
127
- new_embedding = self.model.embed_sequence(simplified_inf_sequence)
128
- assert self.plans_dict, "plans_dict wasn't created during goals_adaptation_phase and now inference phase can't embed the plans. when inference_same_length, keep the plans and not their embeddings during goals_adaptation_phase."
129
- if self.is_first_inf_since_new_goals:
130
- self.is_first_inf_since_new_goals = False
131
- self.update_sequences_library_inference_phase(inf_sequence)
132
- self.create_embeddings_dict()
133
-
134
- closest_goal, greatest_similarity = None, 0
135
- for (goal, embeddings) in self.embeddings_dict.items():
136
- sum_curr_similarities = 0
137
- for cons_embedding, non_cons_embedding in embeddings:
138
- sum_curr_similarities += max(torch.exp(-torch.sum(torch.abs(cons_embedding-new_embedding))), torch.exp(-torch.sum(torch.abs(non_cons_embedding-new_embedding))))
139
- mean_similarity = sum_curr_similarities/len(embeddings)
140
- if mean_similarity > greatest_similarity:
141
- closest_goal = goal
142
- greatest_similarity = mean_similarity
143
-
144
- self.embeddings_dict[f"{true_goal}_true"] = new_embedding
145
- if self.collect_statistics:
146
- with open(os.path.join(embeddings_path, f'{true_goal}_{percentage}_embeddings_dict.pkl'), 'wb') as embeddings_file:
147
- dill.dump(self.embeddings_dict, embeddings_file)
148
- self.embeddings_dict.pop(f"{true_goal}_true")
149
-
150
- return closest_goal
151
-
152
- @abstractmethod
153
- def generate_sequences_library(self, goal: str, save_fig=False) -> List[List[Tuple[np.ndarray, np.ndarray]]]:
154
- pass
155
-
156
- # this function duplicates every sequence and creates a consecutive and non-consecutive version of it
157
- def update_sequences_library_inference_phase(self, inf_sequence) -> List[List[Tuple[np.ndarray, np.ndarray]]]:
158
- new_plans_dict = {}
159
- for goal, obss in self.plans_dict.items():
160
- new_obss = []
161
- for obs in obss:
162
- consecutive_partial_obs = random_subset_with_order(obs, len(inf_sequence), is_consecutive=True)
163
- non_consecutive_partial_obs = random_subset_with_order(obs, len(inf_sequence), is_consecutive=False)
164
- simplified_consecutive_partial_obs = self.agents[0].agent.simplify_observation(consecutive_partial_obs)
165
- simplified_non_consecutive_partial_obs = self.agents[0].agent.simplify_observation(non_consecutive_partial_obs)
166
- new_obss.append((simplified_consecutive_partial_obs, simplified_non_consecutive_partial_obs))
167
- new_plans_dict[goal] = new_obss # override old full observations with new partial observations with consecutive and non-consecutive versions.
168
- self.plans_dict = new_plans_dict
81
+ """
82
+ The Graml class is a subclass of LearningRecognizer and represents a recognizer that uses the Graml algorithm for goal recognition.
83
+ Graml learns a metric over observation sequences, over time: using a GC or a collection of agents, it creates a dataset and learns
84
+ the metric on it during the domain learning phase. During the goals adaptation phase, it creates or receives a library of sequences for each goal,
85
+ and maintains embeddings of them for the inference phase. The inference phase uses the learned metric to find the closest goal to a given sequence.
86
+
87
+ Attributes:
88
+ agents (list[ContextualAgent]): A list of contextual agents associated with the recognizer.
89
+ train_func: The function used for training the metric model.
90
+ collate_func: The function used for collating data in the training process.
91
+
92
+ Methods:
93
+ train_agents_on_base_goals(base_goals: list[str], train_configs: list): Trains the agents on the given base goals and train configurations.
94
+ domain_learning_phase(base_goals: list[str], train_configs: list): Performs the domain learning phase of the Graml algorithm.
95
+ goals_adaptation_phase(dynamic_goals: list[EnvProperty], save_fig=False): Performs the goals adaptation phase of the Graml algorithm.
96
+ get_goal_plan(goal): Retrieves the plan associated with the given goal.
97
+ dump_plans(true_sequence, true_goal, percentage): Dumps the plans to a file.
98
+ create_embeddings_dict(): Creates the embeddings dictionary for the plans.
99
+ inference_phase(inf_sequence, true_goal, percentage) -> str: Performs the inference phase of the Graml algorithm and returns the closest goal.
100
+ generate_sequences_library(goal: str, save_fig=False) -> list[list[tuple[np.ndarray, np.ndarray]]]: Generates the sequences library for the given goal.
101
+ update_sequences_library_inference_phase(inf_sequence) -> list[list[tuple[np.ndarray, np.ndarray]]]: Updates the sequences library during the inference phase.
102
+ """
103
+
104
+ def __init__(self, *args, **kwargs):
105
+ """
106
+ Initialize the GramlRecognizer object.
107
+
108
+ Args:
109
+ *args: Variable length argument list.
110
+ **kwargs: Arbitrary keyword arguments.
111
+
112
+ Attributes:
113
+ agents (list[ContextualAgent]): List of contextual agents.
114
+ train_func: Training function for the metric model.
115
+ collate_func: Collate function for data batching.
116
+ """
117
+ super().__init__(*args, **kwargs)
118
+ self.agents: list[ContextualAgent] = []
119
+ self.train_func = train_metric_model
120
+ self.collate_func = collate_fn
121
+
122
+ @abstractmethod
123
+ def train_agents_on_base_goals(self, base_goals: list[str], train_configs: list):
124
+ pass
125
+
126
+ def domain_learning_phase(self, base_goals: list[str], train_configs: list):
127
+ super().domain_learning_phase(base_goals, train_configs)
128
+ self.train_agents_on_base_goals(base_goals, train_configs)
129
+ # train the network so it will find a metric for the observations of the base agents such that traces of agents to different goals are far from one another
130
+ self.model_directory = get_lstm_model_dir(
131
+ domain_name=self.env_prop.domain_name,
132
+ env_name=self.env_prop.name,
133
+ model_name=self.env_prop.problem_list_to_str_tuple(self.original_problems),
134
+ recognizer=self.__class__.__name__,
135
+ )
136
+ last_path = r"lstm_model.pth"
137
+ self.model_file_path = os.path.join(self.model_directory, last_path)
138
+ self.model = LstmObservations(
139
+ input_size=self.env_prop.get_lstm_props().input_size,
140
+ hidden_size=self.env_prop.get_lstm_props().hidden_size,
141
+ )
142
+ self.model.to(utils.device)
143
+
144
+ if os.path.exists(self.model_file_path):
145
+ print(f"Loading pre-existing lstm model in {self.model_file_path}")
146
+ load_weights(loaded_model=self.model, path=self.model_file_path)
147
+ else:
148
+ print(f"{self.model_file_path} doesn't exist, training the model")
149
+ train_samples, dev_samples = generate_datasets(
150
+ num_samples=self.env_prop.get_lstm_props().num_samples,
151
+ agents=self.agents,
152
+ observation_creation_method=metrics.stochastic_amplified_selection,
153
+ problems=self.original_problems,
154
+ env_prop=self.env_prop,
155
+ gc_goal_set=self.gc_goal_set if hasattr(self, "gc_goal_set") else None,
156
+ recognizer_name=self.__class__.__name__,
157
+ )
158
+
159
+ train_dataset = GRDataset(len(train_samples), train_samples)
160
+ dev_dataset = GRDataset(len(dev_samples), dev_samples)
161
+ self.train_func(
162
+ self.model,
163
+ train_loader=DataLoader(
164
+ train_dataset,
165
+ batch_size=self.env_prop.get_lstm_props().batch_size,
166
+ shuffle=False,
167
+ collate_fn=self.collate_func,
168
+ ),
169
+ dev_loader=DataLoader(
170
+ dev_dataset,
171
+ batch_size=self.env_prop.get_lstm_props().batch_size,
172
+ shuffle=False,
173
+ collate_fn=self.collate_func,
174
+ ),
175
+ )
176
+ save_weights(model=self.model, path=self.model_file_path)
177
+
178
+ def goals_adaptation_phase(self, dynamic_goals: list[EnvProperty], save_fig=False):
179
+ self.is_first_inf_since_new_goals = True
180
+ self.current_goals = dynamic_goals
181
+ # start by training each rl agent on the base goal set
182
+ self.embeddings_dict = (
183
+ {}
184
+ ) # relevant if the embedding of the plan occurs during the goals adaptation phase
185
+ self.plans_dict = (
186
+ {}
187
+ ) # relevant if the embedding of the plan occurs during the inference phase
188
+ for goal in self.current_goals:
189
+ obss = self.generate_sequences_library(goal, save_fig=save_fig)
190
+ self.plans_dict[str(goal)] = obss
191
+
192
+ def get_goal_plan(self, goal):
193
+ assert (
194
+ self.plans_dict
195
+ ), "plans_dict wasn't created during goals_adaptation_phase and now inference phase can't return the plans. when inference_same_length, keep the plans and not their embeddings during goals_adaptation_phase."
196
+ return self.plans_dict[goal]
197
+
198
+ def dump_plans(self, true_sequence, true_goal, percentage):
199
+ assert (
200
+ self.plans_dict
201
+ ), "plans_dict wasn't created during goals_adaptation_phase and now inference phase can't return the plans. when inference_same_length, keep the plans and not their embeddings during goals_adaptation_phase."
202
+ # Arrange storage
203
+ embeddings_path = get_and_create(
204
+ get_embeddings_result_path(
205
+ domain_name=self.env_prop.domain_name,
206
+ env_name=self.env_prop.name,
207
+ recognizer=self.__class__.__name__,
208
+ )
209
+ )
210
+ self.plans_dict[f"{true_goal}_true"] = true_sequence
211
+
212
+ with open(
213
+ embeddings_path + f"/{true_goal}_{percentage}_plans_dict.pkl", "wb"
214
+ ) as plans_file:
215
+ to_dump = {}
216
+ for goal, obss in self.plans_dict.items():
217
+ if goal == f"{true_goal}_true":
218
+ to_dump[goal] = self.agents[0].agent.simplify_observation(obss)
219
+ else:
220
+ to_dump[goal] = []
221
+ for obs in obss:
222
+ addition = (
223
+ self.agents[0].agent.simplify_observation(obs)
224
+ if self.is_first_inf_since_new_goals
225
+ else obs
226
+ )
227
+ to_dump[goal].append(addition)
228
+ dill.dump(to_dump, plans_file)
229
+ self.plans_dict.pop(f"{true_goal}_true")
230
+
231
+ def create_embeddings_dict(self):
232
+ for goal, obss in self.plans_dict.items():
233
+ self.embeddings_dict[goal] = []
234
+ for cons_seq, non_cons_seq in obss:
235
+ self.embeddings_dict[goal].append(
236
+ (
237
+ self.model.embed_sequence(cons_seq),
238
+ self.model.embed_sequence(non_cons_seq),
239
+ )
240
+ )
241
+
242
+ def inference_phase(self, inf_sequence, true_goal, percentage) -> str:
243
+ embeddings_path = get_and_create(
244
+ get_embeddings_result_path(
245
+ domain_name=self.env_prop.domain_name,
246
+ env_name=self.env_prop.name,
247
+ recognizer=self.__class__.__name__,
248
+ )
249
+ )
250
+ simplified_inf_sequence = self.agents[0].agent.simplify_observation(
251
+ inf_sequence
252
+ )
253
+ new_embedding = self.model.embed_sequence(simplified_inf_sequence)
254
+ assert (
255
+ self.plans_dict
256
+ ), "plans_dict wasn't created during goals_adaptation_phase and now inference phase can't embed the plans. when inference_same_length, keep the plans and not their embeddings during goals_adaptation_phase."
257
+ if self.is_first_inf_since_new_goals:
258
+ self.is_first_inf_since_new_goals = False
259
+ self.update_sequences_library_inference_phase(inf_sequence)
260
+ self.create_embeddings_dict()
261
+
262
+ closest_goal, greatest_similarity = None, 0
263
+ for goal, embeddings in self.embeddings_dict.items():
264
+ sum_curr_similarities = 0
265
+ for cons_embedding, non_cons_embedding in embeddings:
266
+ sum_curr_similarities += max(
267
+ torch.exp(-torch.sum(torch.abs(cons_embedding - new_embedding))),
268
+ torch.exp(
269
+ -torch.sum(torch.abs(non_cons_embedding - new_embedding))
270
+ ),
271
+ )
272
+ mean_similarity = sum_curr_similarities / len(embeddings)
273
+ if mean_similarity > greatest_similarity:
274
+ closest_goal = goal
275
+ greatest_similarity = mean_similarity
276
+
277
+ self.embeddings_dict[f"{true_goal}_true"] = new_embedding
278
+ if self.collect_statistics:
279
+ with open(
280
+ os.path.join(
281
+ embeddings_path, f"{true_goal}_{percentage}_embeddings_dict.pkl"
282
+ ),
283
+ "wb",
284
+ ) as embeddings_file:
285
+ dill.dump(self.embeddings_dict, embeddings_file)
286
+ self.embeddings_dict.pop(f"{true_goal}_true")
287
+
288
+ return closest_goal
289
+
290
+ @abstractmethod
291
+ def generate_sequences_library(
292
+ self, goal: str, save_fig=False
293
+ ) -> list[list[tuple[np.ndarray, np.ndarray]]]:
294
+ pass
295
+
296
+ # this function duplicates every sequence and creates a consecutive and non-consecutive version of it
297
+ def update_sequences_library_inference_phase(
298
+ self, inf_sequence
299
+ ) -> list[list[tuple[np.ndarray, np.ndarray]]]:
300
+ new_plans_dict = {}
301
+ for goal, obss in self.plans_dict.items():
302
+ new_obss = []
303
+ for obs in obss:
304
+ consecutive_partial_obs = random_subset_with_order(
305
+ obs, len(inf_sequence), is_consecutive=True
306
+ )
307
+ non_consecutive_partial_obs = random_subset_with_order(
308
+ obs, len(inf_sequence), is_consecutive=False
309
+ )
310
+ simplified_consecutive_partial_obs = self.agents[
311
+ 0
312
+ ].agent.simplify_observation(consecutive_partial_obs)
313
+ simplified_non_consecutive_partial_obs = self.agents[
314
+ 0
315
+ ].agent.simplify_observation(non_consecutive_partial_obs)
316
+ new_obss.append(
317
+ (
318
+ simplified_consecutive_partial_obs,
319
+ simplified_non_consecutive_partial_obs,
320
+ )
321
+ )
322
+ new_plans_dict[goal] = (
323
+ new_obss # override old full observations with new partial observations with consecutive and non-consecutive versions.
324
+ )
325
+ self.plans_dict = new_plans_dict
326
+
169
327
 
170
328
  class BGGraml(Graml):
171
- def __init__(self, *args, **kwargs):
172
- super().__init__(*args, **kwargs)
173
-
174
- def domain_learning_phase(self, base_goals: List[str], train_configs: List):
175
- assert len(train_configs) == len(base_goals), "There should be train configs for every goal in BGGraml."
176
- return super().domain_learning_phase(base_goals, train_configs)
177
-
178
- # In case we need goal-directed agent for every goal
179
- def train_agents_on_base_goals(self, base_goals: List[str], train_configs: List):
180
- self.original_problems = [self.env_prop.goal_to_problem_str(g) for g in base_goals]
181
- # start by training each rl agent on the base goal set
182
- for (problem, goal), (algorithm, num_timesteps) in zip(zip(self.original_problems, base_goals), train_configs):
183
- kwargs = {"domain_name":self.domain_name, "problem_name":problem}
184
- if algorithm != None: kwargs["algorithm"] = algorithm
185
- if num_timesteps != None: kwargs["num_timesteps"] = num_timesteps
186
- agent = self.rl_agent_type(**kwargs)
187
- agent.learn()
188
- self.agents.append(ContextualAgent(problem_name=problem, problem_goal=goal, agent=agent))
329
+ """
330
+ BGGraml class represents a goal-directed agent for the BGGraml algorithm.
331
+
332
+ It extends the Graml class and provides additional methods for training agents on base goals.
333
+ """
334
+
335
+ def __init__(self, *args, **kwargs):
336
+ super().__init__(*args, **kwargs)
337
+
338
+ def domain_learning_phase(self, base_goals: list[str], train_configs: list):
339
+ assert len(train_configs) == len(
340
+ base_goals
341
+ ), "There should be train configs for every goal in BGGraml."
342
+ return super().domain_learning_phase(base_goals, train_configs)
343
+
344
+ # In case we need goal-directed agent for every goal
345
+ def train_agents_on_base_goals(self, base_goals: list[str], train_configs: list):
346
+ self.original_problems = [
347
+ self.env_prop.goal_to_problem_str(g) for g in base_goals
348
+ ]
349
+ # start by training each rl agent on the base goal set
350
+ for (problem, goal), (algorithm, num_timesteps) in zip(
351
+ zip(self.original_problems, base_goals), train_configs
352
+ ):
353
+ kwargs = {
354
+ "domain_name": self.domain_name,
355
+ "problem_name": problem,
356
+ "env_prop": self.env_prop,
357
+ }
358
+ if algorithm != None:
359
+ kwargs["algorithm"] = algorithm
360
+ if num_timesteps != None:
361
+ kwargs["num_timesteps"] = num_timesteps
362
+ agent = self.rl_agent_type(**kwargs)
363
+ agent.learn()
364
+ self.agents.append(
365
+ ContextualAgent(problem_name=problem, problem_goal=goal, agent=agent)
366
+ )
367
+
189
368
 
190
369
  class MCTSBasedGraml(BGGraml, GaAdaptingRecognizer):
191
- def __init__(self, *args, **kwargs):
192
- super().__init__(*args, **kwargs)
193
- if self.rl_agent_type==None: self.rl_agent_type = TabularQLearner
370
+ """
371
+ MCTSBasedGraml is a class that represents a recognizer based on the MCTS algorithm.
372
+ It inherits from BGGraml and GaAdaptingRecognizer classes.
373
+
374
+ Attributes:
375
+ rl_agent_type (type): The type of reinforcement learning agent used.
376
+ """
377
+
378
+ def __init__(self, *args, **kwargs):
379
+ """
380
+ Initialize the GramlRecognizer object.
381
+
382
+ Args:
383
+ *args: Variable length argument list.
384
+ **kwargs: Arbitrary keyword arguments.
385
+
386
+ """
387
+ super().__init__(*args, **kwargs)
388
+ if self.rl_agent_type == None:
389
+ self.rl_agent_type = TabularQLearner
390
+
391
+ def generate_sequences_library(
392
+ self, goal: str, save_fig=False
393
+ ) -> list[list[tuple[np.ndarray, np.ndarray]]]:
394
+ """
395
+ Generates a library of sequences for a given goal.
396
+
397
+ Args:
398
+ goal (str): The goal for which to generate sequences.
399
+ save_fig (bool, optional): Whether to save the generated figure. Defaults to False.
400
+
401
+ Returns:
402
+ list[list[tuple[np.ndarray, np.ndarray]]]: The generated sequences library.
403
+ """
404
+ problem_name = self.env_prop.goal_to_problem_str(goal)
405
+ img_path = os.path.join(
406
+ get_policy_sequences_result_path(
407
+ self.env_prop.domain_name, recognizer=self.__class__.__name__
408
+ ),
409
+ problem_name + "_MCTS",
410
+ )
411
+ return mcts_model.plan(
412
+ self.env_prop.name,
413
+ problem_name,
414
+ goal,
415
+ save_fig=save_fig,
416
+ img_path=img_path,
417
+ env_prop=self.env_prop,
418
+ )
194
419
 
195
- def generate_sequences_library(self, goal: str, save_fig=False) -> List[List[Tuple[np.ndarray, np.ndarray]]]:
196
- problem_name = self.env_prop.goal_to_problem_str(goal)
197
- img_path = os.path.join(get_policy_sequences_result_path(self.env_prop.domain_name, recognizer=self.__class__.__name__), problem_name + "_MCTS")
198
- return mcts_model.plan(self.env_prop.name, problem_name, goal, save_fig=save_fig, img_path=img_path, env_prop=self.env_prop)
199
420
 
200
421
  class ExpertBasedGraml(BGGraml, GaAgentTrainerRecognizer):
201
- def __init__(self, *args, **kwargs):
202
- super().__init__(*args, **kwargs)
203
- if self.rl_agent_type==None:
204
- if self.env_prop.is_state_discrete() and self.env_prop.is_action_discrete():
205
- self.rl_agent_type = TabularQLearner
206
- else:
207
- self.rl_agent_type = DeepRLAgent
208
-
209
- def generate_sequences_library(self, goal: str, save_fig=False) -> List[List[Tuple[np.ndarray, np.ndarray]]]:
210
- problem_name = self.env_prop.goal_to_problem_str(goal)
211
- kwargs = {"domain_name":self.domain_name, "problem_name":problem_name}
212
- if self.dynamic_train_configs_dict[problem_name][0] != None: kwargs["algorithm"] = self.dynamic_train_configs_dict[problem_name][0]
213
- if self.dynamic_train_configs_dict[problem_name][1] != None: kwargs["num_timesteps"] = self.dynamic_train_configs_dict[problem_name][1]
214
- agent = self.rl_agent_type(**kwargs)
215
- agent.learn()
216
- agent_kwargs = {
217
- "action_selection_method": metrics.greedy_selection,
218
- "random_optimalism": False,
219
- "save_fig": save_fig,
220
- "env_prop": self.env_prop
221
- }
222
- if save_fig:
223
- fig_path = get_and_create(f"{os.path.abspath(os.path.join(get_policy_sequences_result_path(domain_name=self.env_prop.domain_name, env_name=self.env_prop.name, recognizer=self.__class__.__name__), problem_name))}_bg_sequence")
224
- agent_kwargs["fig_path"] = fig_path
225
- return [agent.generate_observation(**agent_kwargs)]
226
-
227
- def goals_adaptation_phase(self, dynamic_goals: List[str], dynamic_train_configs):
228
- self.dynamic_goals_problems = [self.env_prop.goal_to_problem_str(g) for g in dynamic_goals]
229
- self.dynamic_train_configs_dict = {problem:config for problem, config in zip(self.dynamic_goals_problems,dynamic_train_configs)}
230
- return super().goals_adaptation_phase(dynamic_goals)
422
+ """
423
+ ExpertBasedGraml class represents a Graml recognizer that uses expert knowledge to generate sequences library and adapt goals.
424
+
425
+ Args:
426
+ *args: Variable length argument list.
427
+ **kwargs: Arbitrary keyword arguments.
428
+
429
+ Attributes:
430
+ rl_agent_type (type): The type of reinforcement learning agent used.
431
+ env_prop (EnvironmentProperties): The environment properties.
432
+ dynamic_train_configs_dict (dict): The dynamic training configurations for each problem.
433
+
434
+ """
435
+
436
+ def __init__(self, *args, **kwargs):
437
+ """
438
+ Initialize the GRAML Recognizer.
439
+
440
+ Args:
441
+ *args: Variable length argument list.
442
+ **kwargs: Arbitrary keyword arguments.
443
+
444
+ """
445
+ super().__init__(*args, **kwargs)
446
+ if self.rl_agent_type == None:
447
+ if self.env_prop.is_state_discrete() and self.env_prop.is_action_discrete():
448
+ self.rl_agent_type = TabularQLearner
449
+ else:
450
+ self.rl_agent_type = DeepRLAgent
451
+
452
+ def generate_sequences_library(
453
+ self, goal: str, save_fig=False
454
+ ) -> list[list[tuple[np.ndarray, np.ndarray]]]:
455
+ """
456
+ Generates a sequences library for a given goal.
457
+
458
+ Args:
459
+ goal (str): The goal for which to generate the sequences library.
460
+ save_fig (bool, optional): Whether to save the figure. Defaults to False.
461
+
462
+ Returns:
463
+ list[list[tuple[np.ndarray, np.ndarray]]]: The generated sequences library.
464
+
465
+ """
466
+ problem_name = self.env_prop.goal_to_problem_str(goal)
467
+ kwargs = {
468
+ "domain_name": self.domain_name,
469
+ "problem_name": problem_name,
470
+ "env_prop": self.env_prop,
471
+ }
472
+ if self.dynamic_train_configs_dict[problem_name][0] != None:
473
+ kwargs["algorithm"] = self.dynamic_train_configs_dict[problem_name][0]
474
+ if self.dynamic_train_configs_dict[problem_name][1] != None:
475
+ kwargs["num_timesteps"] = self.dynamic_train_configs_dict[problem_name][1]
476
+ agent = self.rl_agent_type(**kwargs)
477
+ agent.learn()
478
+ agent_kwargs = {
479
+ "action_selection_method": metrics.greedy_selection,
480
+ "random_optimalism": False,
481
+ "save_fig": save_fig,
482
+ }
483
+ if save_fig:
484
+ fig_path = get_and_create(
485
+ f"{os.path.abspath(os.path.join(get_policy_sequences_result_path(domain_name=self.env_prop.domain_name, env_name=self.env_prop.name, recognizer=self.__class__.__name__), problem_name))}_bg_sequence"
486
+ )
487
+ agent_kwargs["fig_path"] = fig_path
488
+ return [agent.generate_observation(**agent_kwargs)]
489
+
490
+ def goals_adaptation_phase(self, dynamic_goals: list[str], dynamic_train_configs):
491
+ """
492
+ Performs the goals adaptation phase.
493
+
494
+ Args:
495
+ dynamic_goals (list[str]): The dynamic goals.
496
+ dynamic_train_configs: The dynamic training configurations.
497
+
498
+ Returns:
499
+ The result of the goals adaptation phase.
500
+
501
+ """
502
+ self.dynamic_goals_problems = [
503
+ self.env_prop.goal_to_problem_str(g) for g in dynamic_goals
504
+ ]
505
+ self.dynamic_train_configs_dict = {
506
+ problem: config
507
+ for problem, config in zip(
508
+ self.dynamic_goals_problems, dynamic_train_configs
509
+ )
510
+ }
511
+ return super().goals_adaptation_phase(dynamic_goals)
512
+
231
513
 
232
514
  class GCGraml(Graml, GaAdaptingRecognizer):
233
- def __init__(self, *args, **kwargs):
234
- super().__init__(*args, **kwargs)
235
- if self.rl_agent_type==None: self.rl_agent_type = GCDeepRLAgent
236
- assert self.env_prop.gc_adaptable() and not self.env_prop.is_state_discrete() and not self.env_prop.is_action_discrete()
237
-
238
- def domain_learning_phase(self, base_goals: List[str], train_configs: List):
239
- assert len(train_configs) == 1, "There should be one train config for the sole gc agent in GCGraml."
240
- return super().domain_learning_phase(base_goals, train_configs)
241
-
242
- # In case we need goal-directed agent for every goal
243
- def train_agents_on_base_goals(self, base_goals: List[str], train_configs: List):
244
- self.gc_goal_set = base_goals
245
- self.original_problems = self.env_prop.name # needed for gr_dataset
246
- # start by training each rl agent on the base goal set
247
- kwargs = {"domain_name":self.domain_name, "problem_name":self.env_prop.name}
248
- algorithm, num_timesteps = train_configs[0] # should only be one, was asserted
249
- if algorithm != None: kwargs["algorithm"] = algorithm
250
- if num_timesteps != None: kwargs["num_timesteps"] = num_timesteps
251
- gc_agent = self.rl_agent_type(**kwargs)
252
- gc_agent.learn()
253
- self.agents.append(ContextualAgent(problem_name=self.env_prop.name, problem_goal="general", agent=gc_agent))
254
-
255
- def generate_sequences_library(self, goal: str, save_fig=False) -> List[List[Tuple[np.ndarray, np.ndarray]]]:
256
- problem_name = self.env_prop.goal_to_problem_str(goal)
257
- kwargs = {"domain_name":self.domain_name, "problem_name":self.env_prop.name} # problem name is env name in gc case
258
- if self.original_train_configs[0][0] != None: kwargs["algorithm"] = self.original_train_configs[0][0]
259
- if self.original_train_configs[0][1] != None: kwargs["num_timesteps"] = self.original_train_configs[0][1]
260
- agent = self.rl_agent_type(**kwargs)
261
- agent.learn()
262
- agent_kwargs = {
263
- "action_selection_method": metrics.stochastic_amplified_selection,
264
- "random_optimalism": True,
265
- "save_fig": save_fig
266
- }
267
- if save_fig:
268
- fig_path = get_and_create(f"{os.path.abspath(os.path.join(get_policy_sequences_result_path(domain_name=self.env_prop.domain_name, env_name=self.env_prop.name, recognizer=self.__class__.__name__), problem_name))}_gc_sequence")
269
- agent_kwargs["fig_path"] = fig_path
270
- if self.env_prop.use_goal_directed_problem(): agent_kwargs["goal_directed_problem"] = problem_name
271
- else: agent_kwargs["goal_directed_goal"] = goal
272
- obss = []
273
- for _ in range(5): obss.append(agent.generate_observation(**agent_kwargs))
274
- return obss
515
+ """
516
+ GCGraml class represents a recognizer that uses the GCDeepRLAgent for domain learning and sequence generation.
517
+ It makes its adaptation phase quicker and require less assumptions, but the assumption of a GC agent is still needed and may result
518
+ in less optimal policies that generate the observations in the synthetic dataset, which could eventually lead to a less optimal metric.
519
+
520
+ Args:
521
+ Graml (class): Base class for Graml recognizers.
522
+ GaAdaptingRecognizer (class): Base class for GA adapting recognizers.
523
+
524
+ Attributes:
525
+ rl_agent_type (class): The type of RL agent to be used for learning and generation.
526
+ env_prop (object): The environment properties.
527
+ agents (list): List of contextual agents.
528
+
529
+ Methods:
530
+ __init__: Initializes the GCGraml recognizer.
531
+ domain_learning_phase: Performs the domain learning phase.
532
+ train_agents_on_base_goals: Trains the RL agents on the base goals.
533
+ generate_sequences_library: Generates sequences library for a specific goal.
534
+
535
+ """
536
+
537
+ def __init__(self, *args, **kwargs):
538
+ super().__init__(*args, **kwargs)
539
+ if self.rl_agent_type == None:
540
+ self.rl_agent_type = GCDeepRLAgent
541
+ assert (
542
+ self.env_prop.gc_adaptable()
543
+ and not self.env_prop.is_state_discrete()
544
+ and not self.env_prop.is_action_discrete()
545
+ )
546
+
547
+ def domain_learning_phase(self, base_goals: list[str], train_configs: list):
548
+ assert (
549
+ len(train_configs) == 1
550
+ ), "There should be one train config for the sole gc agent in GCGraml."
551
+ return super().domain_learning_phase(base_goals, train_configs)
552
+
553
+ # In case we need goal-directed agent for every goal
554
+ def train_agents_on_base_goals(self, base_goals: list[str], train_configs: list):
555
+ self.gc_goal_set = base_goals
556
+ self.original_problems = self.env_prop.name # needed for gr_dataset
557
+ # start by training each rl agent on the base goal set
558
+ kwargs = {
559
+ "domain_name": self.domain_name,
560
+ "problem_name": self.env_prop.name,
561
+ "env_prop": self.env_prop,
562
+ }
563
+ algorithm, num_timesteps = train_configs[0] # should only be one, was asserted
564
+ if algorithm != None:
565
+ kwargs["algorithm"] = algorithm
566
+ if num_timesteps != None:
567
+ kwargs["num_timesteps"] = num_timesteps
568
+ gc_agent = self.rl_agent_type(**kwargs)
569
+ gc_agent.learn()
570
+ self.agents.append(
571
+ ContextualAgent(
572
+ problem_name=self.env_prop.name, problem_goal="general", agent=gc_agent
573
+ )
574
+ )
575
+
576
+ def generate_sequences_library(
577
+ self, goal: str, save_fig=False
578
+ ) -> list[list[tuple[np.ndarray, np.ndarray]]]:
579
+ problem_name = self.env_prop.goal_to_problem_str(goal)
580
+ kwargs = {
581
+ "domain_name": self.domain_name,
582
+ "problem_name": self.env_prop.name,
583
+ "env_prop": self.env_prop,
584
+ } # problem name is env name in gc case
585
+ if self.original_train_configs[0][0] != None:
586
+ kwargs["algorithm"] = self.original_train_configs[0][0]
587
+ if self.original_train_configs[0][1] != None:
588
+ kwargs["num_timesteps"] = self.original_train_configs[0][1]
589
+ agent = self.rl_agent_type(**kwargs)
590
+ agent.learn()
591
+ agent_kwargs = {
592
+ "action_selection_method": metrics.stochastic_amplified_selection,
593
+ "random_optimalism": True,
594
+ "save_fig": save_fig,
595
+ }
596
+ if save_fig:
597
+ fig_path = get_and_create(
598
+ f"{os.path.abspath(os.path.join(get_policy_sequences_result_path(domain_name=self.env_prop.domain_name, env_name=self.env_prop.name, recognizer=self.__class__.__name__), problem_name))}_gc_sequence"
599
+ )
600
+ agent_kwargs["fig_path"] = fig_path
601
+ if self.env_prop.use_goal_directed_problem():
602
+ agent_kwargs["goal_directed_problem"] = problem_name
603
+ else:
604
+ agent_kwargs["goal_directed_goal"] = goal
605
+ obss = []
606
+ for _ in range(5):
607
+ obss.append(agent.generate_observation(**agent_kwargs))
608
+ return obss
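
Every Graml variant in the hunk above follows the same three-phase ODGR interface: domain_learning_phase trains the base agents and fits the LSTM metric, goals_adaptation_phase builds a per-goal sequence library, and inference_phase embeds an observed trace and returns the closest dynamic goal. The sketch below is a minimal illustration of that call order, assuming an ExpertBasedGraml constructor that accepts domain_name and env_name keywords; those keywords, the goal strings, and the training budgets are placeholders, since the Recognizer base classes and the 0.2.2 tutorials are not part of this hunk.

from gr_libs.recognizer.graml.graml_recognizer import ExpertBasedGraml

# Constructor kwargs are assumed; the real signature lives in gr_libs/recognizer/recognizer.py,
# which this hunk does not show.
recognizer = ExpertBasedGraml(domain_name="minigrid", env_name="MiniGrid-SimpleCrossingS13N4")

# Phase 1: train one goal-directed agent per base goal, then fit the LSTM metric.
# Each train config is an (algorithm, num_timesteps) pair; a None entry is simply not
# forwarded to the agent constructor, so the agent's default is used.
recognizer.domain_learning_phase(
    base_goals=["(11,1)", "(11,11)", "(1,11)"],   # illustrative goal strings
    train_configs=[(None, 100_000)] * 3,
)

# Phase 2: build the sequence library for the dynamic goal set.
recognizer.goals_adaptation_phase(
    dynamic_goals=["(11,1)", "(1,11)"],
    dynamic_train_configs=[(None, 100_000)] * 2,
)

# Phase 3: recognize the goal behind a partial trace. Here the trace is faked by
# taking a prefix of a stored library sequence via get_goal_plan().
inf_sequence = recognizer.get_goal_plan("(11,1)")[0][:10]
print(recognizer.inference_phase(inf_sequence, true_goal="(11,1)", percentage=0.5))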
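
The scoring rule inside inference_phase is worth restating: each stored pair of consecutive and non-consecutive embeddings is compared to the new trace embedding with an exponentiated negative L1 distance, the better of the two is kept, and the goal with the highest mean score over its library wins. The standalone sketch below mirrors that inner loop, with an assumed embedding size of 8 (the real size is the hidden dimension of LstmObservations).

import torch

def pair_similarity(cons_emb: torch.Tensor, non_cons_emb: torch.Tensor,
                    new_emb: torch.Tensor) -> torch.Tensor:
    # similarity(a, b) = exp(-||a - b||_1); keep the better of the two stored variants,
    # as in the inner loop of Graml.inference_phase above.
    return torch.max(
        torch.exp(-torch.sum(torch.abs(cons_emb - new_emb))),
        torch.exp(-torch.sum(torch.abs(non_cons_emb - new_emb))),
    )

# Toy check with random embeddings; a goal's final score is the mean of these values
# over every (consecutive, non-consecutive) pair stored for that goal.
cons, non_cons, new = torch.randn(8), torch.randn(8), torch.randn(8)
print(pair_similarity(cons, non_cons, new))  # scalar tensor in (0, 1]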