gr-libs 0.1.7.post0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff shows the content of publicly available package versions as released to their public registries. It is provided for informational purposes only.
- gr_libs/__init__.py +4 -1
- gr_libs/_evaluation/__init__.py +1 -0
- gr_libs/_evaluation/_analyze_results_cross_alg_cross_domain.py +260 -0
- gr_libs/_evaluation/_generate_experiments_results.py +141 -0
- gr_libs/_evaluation/_generate_task_specific_statistics_plots.py +497 -0
- gr_libs/_evaluation/_get_plans_images.py +61 -0
- gr_libs/_evaluation/_increasing_and_decreasing_.py +106 -0
- gr_libs/_version.py +2 -2
- gr_libs/all_experiments.py +294 -0
- gr_libs/environment/__init__.py +30 -9
- gr_libs/environment/_utils/utils.py +27 -0
- gr_libs/environment/environment.py +417 -54
- gr_libs/metrics/__init__.py +7 -0
- gr_libs/metrics/metrics.py +231 -54
- gr_libs/ml/__init__.py +2 -5
- gr_libs/ml/agent.py +21 -6
- gr_libs/ml/base/__init__.py +3 -1
- gr_libs/ml/base/rl_agent.py +81 -13
- gr_libs/ml/consts.py +1 -1
- gr_libs/ml/neural/__init__.py +1 -3
- gr_libs/ml/neural/deep_rl_learner.py +619 -378
- gr_libs/ml/neural/utils/__init__.py +1 -2
- gr_libs/ml/neural/utils/dictlist.py +3 -3
- gr_libs/ml/planner/mcts/{utils → _utils}/__init__.py +1 -1
- gr_libs/ml/planner/mcts/{utils → _utils}/node.py +11 -7
- gr_libs/ml/planner/mcts/{utils → _utils}/tree.py +15 -11
- gr_libs/ml/planner/mcts/mcts_model.py +571 -312
- gr_libs/ml/sequential/__init__.py +0 -1
- gr_libs/ml/sequential/_lstm_model.py +270 -0
- gr_libs/ml/tabular/__init__.py +1 -3
- gr_libs/ml/tabular/state.py +7 -7
- gr_libs/ml/tabular/tabular_q_learner.py +150 -82
- gr_libs/ml/tabular/tabular_rl_agent.py +42 -28
- gr_libs/ml/utils/__init__.py +2 -3
- gr_libs/ml/utils/format.py +28 -97
- gr_libs/ml/utils/math.py +5 -3
- gr_libs/ml/utils/other.py +3 -3
- gr_libs/ml/utils/storage.py +88 -81
- gr_libs/odgr_executor.py +268 -0
- gr_libs/problems/consts.py +1549 -1227
- gr_libs/recognizer/_utils/__init__.py +0 -0
- gr_libs/recognizer/_utils/format.py +18 -0
- gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +233 -88
- gr_libs/recognizer/graml/_gr_dataset.py +233 -0
- gr_libs/recognizer/graml/graml_recognizer.py +586 -252
- gr_libs/recognizer/recognizer.py +90 -30
- gr_libs/tutorials/draco_panda_tutorial.py +58 -0
- gr_libs/tutorials/draco_parking_tutorial.py +56 -0
- gr_libs/tutorials/gcdraco_panda_tutorial.py +62 -0
- gr_libs/tutorials/gcdraco_parking_tutorial.py +57 -0
- gr_libs/tutorials/graml_minigrid_tutorial.py +64 -0
- gr_libs/tutorials/graml_panda_tutorial.py +57 -0
- gr_libs/tutorials/graml_parking_tutorial.py +52 -0
- gr_libs/tutorials/graml_point_maze_tutorial.py +60 -0
- gr_libs/tutorials/graql_minigrid_tutorial.py +50 -0
- {gr_libs-0.1.7.post0.dist-info → gr_libs-0.2.2.dist-info}/METADATA +84 -29
- gr_libs-0.2.2.dist-info/RECORD +71 -0
- {gr_libs-0.1.7.post0.dist-info → gr_libs-0.2.2.dist-info}/WHEEL +1 -1
- gr_libs-0.2.2.dist-info/top_level.txt +2 -0
- tests/test_draco.py +14 -0
- tests/test_gcdraco.py +10 -0
- tests/test_graml.py +12 -8
- tests/test_graql.py +3 -2
- evaluation/analyze_results_cross_alg_cross_domain.py +0 -277
- evaluation/create_minigrid_map_image.py +0 -34
- evaluation/file_system.py +0 -42
- evaluation/generate_experiments_results.py +0 -92
- evaluation/generate_experiments_results_new_ver1.py +0 -254
- evaluation/generate_experiments_results_new_ver2.py +0 -331
- evaluation/generate_task_specific_statistics_plots.py +0 -272
- evaluation/get_plans_images.py +0 -47
- evaluation/increasing_and_decreasing_.py +0 -63
- gr_libs/environment/utils/utils.py +0 -17
- gr_libs/ml/neural/utils/penv.py +0 -57
- gr_libs/ml/sequential/lstm_model.py +0 -192
- gr_libs/recognizer/graml/gr_dataset.py +0 -134
- gr_libs/recognizer/utils/__init__.py +0 -1
- gr_libs/recognizer/utils/format.py +0 -13
- gr_libs-0.1.7.post0.dist-info/RECORD +0 -67
- gr_libs-0.1.7.post0.dist-info/top_level.txt +0 -4
- tutorials/graml_minigrid_tutorial.py +0 -34
- tutorials/graml_panda_tutorial.py +0 -41
- tutorials/graml_parking_tutorial.py +0 -39
- tutorials/graml_point_maze_tutorial.py +0 -39
- tutorials/graql_minigrid_tutorial.py +0 -34
- /gr_libs/environment/{utils → _utils}/__init__.py +0 -0
evaluation/generate_task_specific_statistics_plots.py
DELETED
@@ -1,272 +0,0 @@
-import argparse
-import sys
-import matplotlib.pyplot as plt
-import numpy as np
-import os
-import ast
-import inspect
-import torch
-import dill
-
-from gr_libs.ml.utils import get_embeddings_result_path
-from gr_libs.ml.utils.storage import get_experiment_results_path, set_global_storage_configs, get_graql_experiment_confidence_path
-from gr_libs.metrics.metrics import measure_average_sequence_distance
-
-def get_tasks_embeddings_dir_path(env_name):
-    return os.path.join("../gr_libs", get_embeddings_result_path(env_name))
-
-def get_figures_dir_path(domain_name, env_name):
-    return os.path.join("../gr_libs", "figures", domain_name, env_name)
-
-def similarities_vector_to_std_deviation_units_vector(ref_dict: dict, relative_to_largest):
-    """
-    Calculate the number of standard deviation units every other element is
-    from the largest/smallest element in the vector.
-
-    Parameters:
-    - vector: list or numpy array of numbers.
-    - relative_to_largest: boolean, if True, measure in relation to the largest element,
-      if False, measure in relation to the smallest element.
-
-    Returns:
-    - List of number of standard deviation units for each element in the vector.
-    """
-    vector = np.array(list(ref_dict.values()))
-    mean = np.mean(vector) # for the future maybe another method for measurement
-    std_dev = np.std(vector)
-
-    # Determine the reference element (largest or smallest)
-    if relative_to_largest:
-        reference_value = np.max(vector)
-    else:
-        reference_value = np.min(vector)
-    for goal, value in ref_dict.items():
-        ref_dict[goal] = abs(value - reference_value) / std_dev
-    return ref_dict
-
-def analyze_and_produce_plots(recognizer_type: str, domain_name: str, env_name: str, fragmented_status: str, inf_same_length_status: str, learn_same_length_status: str):
-    if recognizer_type == "graml":
-        assert os.path.exists(get_embeddings_result_path(domain_name)), "Embeddings weren't made for this environment, run graml_main.py with this environment first."
-        tasks_embedding_dicts = {}
-        tasks_plans_dict = {}
-        goals_similarity_dict = {}
-        plans_similarity_dict = {}
-
-        embeddings_dir_path = get_tasks_embeddings_dir_path(domain_name)
-        for embeddings_file_name in [filename for filename in os.listdir(embeddings_dir_path) if 'embeddings' in filename]:
-            with open(os.path.join(embeddings_dir_path, embeddings_file_name), 'rb') as emb_file:
-                splitted_name = embeddings_file_name.split('_')
-                goal, percentage = splitted_name[0], splitted_name[1]
-                with open(os.path.join(embeddings_dir_path, f'{goal}_{percentage}_plans_dict.pkl'), 'rb') as plan_file:
-                    tasks_plans_dict[f"{goal}_{percentage}"] = dill.load(plan_file)
-                tasks_embedding_dicts[f"{goal}_{percentage}"] = dill.load(emb_file)
-
-        for goal_percentage, embedding_dict in tasks_embedding_dicts.items():
-            goal, percentage = goal_percentage.split('_')
-            similarities = {dynamic_goal: [] for dynamic_goal in embedding_dict.keys() if 'true' not in dynamic_goal}
-            real_goal_embedding = embedding_dict[f"{goal}_true"]
-            for dynamic_goal, goal_embedding in embedding_dict.items():
-                if 'true' in dynamic_goal: continue
-                curr_similarity = torch.exp(-torch.sum(torch.abs(goal_embedding-real_goal_embedding)))
-                similarities[dynamic_goal] = curr_similarity.item()
-            if goal not in goals_similarity_dict.keys(): goals_similarity_dict[goal] = {}
-            goals_similarity_dict[goal][percentage] = similarities_vector_to_std_deviation_units_vector(ref_dict=similarities, relative_to_largest=True)
-
-        for goal_percentage, plans_dict in tasks_plans_dict.items():
-            goal, percentage = goal_percentage.split('_')
-            real_plan = plans_dict[f"{goal}_true"]
-            sequence_similarities = {d_goal:measure_average_sequence_distance(real_plan, plan) for d_goal,plan in plans_dict.items() if 'true' not in d_goal} # aps = agent plan sequence?
-            if goal not in plans_similarity_dict.keys(): plans_similarity_dict[goal] = {}
-            plans_similarity_dict[goal][percentage] = similarities_vector_to_std_deviation_units_vector(ref_dict=sequence_similarities, relative_to_largest=False)
-
-        goals = list(goals_similarity_dict.keys())
-        percentages = sorted(set(percentage for similarities in goals_similarity_dict.values() for percentage in similarities.keys()))
-        num_percentages = len(percentages)
-        fig_string = f"{recognizer_type}_{domain_name}_{env_name}_{fragmented_status}_{inf_same_length_status}_{learn_same_length_status}"
-
-    else: # algorithm = "graql"
-        assert os.path.exists(get_graql_experiment_confidence_path(domain_name)), "Embeddings weren't made for this environment, run graml_main.py with this environment first."
-        tasks_scores_dict = {}
-        goals_similarity_dict = {}
-        experiments_dir_path = get_graql_experiment_confidence_path(domain_name)
-        for experiments_file_name in os.listdir(experiments_dir_path):
-            with open(os.path.join(experiments_dir_path, experiments_file_name), 'rb') as exp_file:
-                splitted_name = experiments_file_name.split('_')
-                goal, percentage = splitted_name[1], splitted_name[2]
-                tasks_scores_dict[f"{goal}_{percentage}"] = dill.load(exp_file)
-
-        for goal_percentage, scores_list in tasks_scores_dict.items():
-            goal, percentage = goal_percentage.split('_')
-            similarities = {dynamic_goal: score for (dynamic_goal, score) in scores_list}
-            if goal not in goals_similarity_dict.keys(): goals_similarity_dict[goal] = {}
-            goals_similarity_dict[goal][percentage] = similarities_vector_to_std_deviation_units_vector(ref_dict=similarities, relative_to_largest=False)
-
-        goals = list(goals_similarity_dict.keys())
-        percentages = sorted(set(percentage for similarities in goals_similarity_dict.values() for percentage in similarities.keys()))
-        num_percentages = len(percentages)
-        fig_string = f"{recognizer_type}_{domain_name}_{env_name}_{fragmented_status}"
-
-    # -------------------- Start of Confusion Matrix Code --------------------
-    # Initialize matrices of size len(goals) x len(goals)
-    confusion_matrix_goals, confusion_matrix_plans = np.zeros((len(goals), len(goals))), np.zeros((len(goals), len(goals)))
-
-    # if domain_name == 'point_maze' and args.task == 'L555':
-    #     if env_name == 'obstacles':
-    #         goals = ['(4, 7)', '(3, 6)', '(5, 5)', '(8, 8)', '(6, 3)', '(7, 4)']
-    #     else: # if env_name is 'four_rooms'
-    #         goals = ['(2, 8)', '(3, 7)', '(3, 4)', '(4, 4)', '(4, 3)', '(7, 3)', '(8, 2)']
-
-    # Populate confusion matrix with similarity values for goals
-    for i, true_goal in enumerate(goals):
-        for j, dynamic_goal in enumerate(goals):
-            percentage = percentages[-3]
-            confusion_matrix_goals[i, j] = goals_similarity_dict[true_goal][percentage].get(dynamic_goal, 0)
-
-    if plans_similarity_dict:
-        # Populate confusion matrix with similarity values for plans
-        for i, true_goal in enumerate(goals):
-            for j, dynamic_goal in enumerate(goals):
-                percentage = percentages[-1]
-                confusion_matrix_plans[i, j] = plans_similarity_dict[true_goal][percentage].get(dynamic_goal, 0)
-
-    # Create the figure and subplots for the unified display
-    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6), sharex=True)
-
-    # Plot for goal similarities
-    im1 = ax1.imshow(confusion_matrix_goals, cmap='Blues', interpolation='nearest')
-    cbar1 = fig.colorbar(im1, ax=ax1, fraction=0.046, pad=0.04)
-    cbar1.set_label('St. dev from most probable goal', fontsize=18)
-    ax1.set_title('Embeddings', fontsize=22, pad=20)
-    ax1.set_xticks(np.arange(len(goals)))
-    ax1.set_xticklabels(goals, rotation=45, ha="right", fontsize=16)
-    ax1.set_yticks(np.arange(len(goals)))
-    ax1.set_yticklabels(goals, fontsize=16) # y-tick labels for ax1
-
-    # Plot for plan similarities
-    im2 = ax2.imshow(confusion_matrix_plans, cmap='Greens', interpolation='nearest')
-    cbar2 = fig.colorbar(im2, ax=ax2, fraction=0.046, pad=0.04)
-    cbar2.set_label('Distance between plans', fontsize=18)
-    ax2.set_title('Sequences', fontsize=22, pad=20)
-    ax2.set_xticks(np.arange(len(goals)))
-    ax2.set_xticklabels(goals, rotation=45, ha="right", fontsize=16)
-    ax2.set_yticks(np.arange(len(goals))) # y-ticks for ax2 explicitly
-    ax2.set_yticklabels(goals, fontsize=16) # y-tick labels for ax2
-
-    # Adjust the figure layout to reduce overlap
-    plt.subplots_adjust(left=0.15, right=0.9, bottom=0.25, top=0.85, wspace=0.1)
-
-    # Unified axis labels, placed closer to the left
-    fig.text(0.57, 0.07, 'Goals Adaptation Phase', ha='center', fontsize=22)
-    fig.text(0.12, 0.5, 'Inference Phase', va='center', rotation='vertical', fontsize=22)
-
-    # Save the combined plot
-    fig_dir = get_figures_dir_path(domain_name=domain_name, env_name=env_name)
-    if not os.path.exists(fig_dir):
-        os.makedirs(fig_dir)
-    confusion_matrix_combined_path = os.path.join(fig_dir, f"{fig_string}_combined_conf_mat.png")
-    plt.savefig(confusion_matrix_combined_path, dpi=300)
-    print(f"Combined confusion matrix figure saved at: {confusion_matrix_combined_path}")
-
-    # -------------------- End of Confusion Matrix Code --------------------
-    fig, axes = plt.subplots(nrows=num_percentages, ncols=1, figsize=(10, 6 * num_percentages))
-
-    if num_percentages == 1:
-        axes = [axes]
-
-    for i, percentage in enumerate(percentages):
-        correct_tasks, tasks_num = 0, 0
-        ax = axes[i]
-        dynamic_goals = list(next(iter(goals_similarity_dict.values()))[percentage].keys())
-        num_goals = len(goals)
-        num_dynamic_goals = len(dynamic_goals)
-        bar_width = 0.8 / num_dynamic_goals
-        bar_positions = np.arange(num_goals)
-
-        if recognizer_type == "graml":
-            for j, dynamic_goal in enumerate(dynamic_goals):
-                goal_similarities = [goals_similarity_dict[goal][percentage][dynamic_goal] + 0.04 for goal in goals]
-                plan_similarities = [plans_similarity_dict[goal][percentage][dynamic_goal] + 0.04 for goal in goals]
-                ax.bar(bar_positions + j * bar_width, goal_similarities, bar_width/2, label=f"embedding of {dynamic_goal}")
-                ax.bar(bar_positions + j * bar_width + bar_width/2, plan_similarities, bar_width/2, label=f"plan to {dynamic_goal}")
-        else:
-            for j, dynamic_goal in enumerate(dynamic_goals):
-                goal_similarities = [goals_similarity_dict[goal][percentage][dynamic_goal] + 0.04 for goal in goals]
-                ax.bar(bar_positions + j * bar_width, goal_similarities, bar_width, label=f"policy to {dynamic_goal}")
-
-        x_labels = []
-        for true_goal in goals:
-            guessed_goal = min(goals_similarity_dict[true_goal][percentage], key=goals_similarity_dict[true_goal][percentage].get)
-            tasks_num += 1
-            if true_goal == guessed_goal: correct_tasks += 1
-            second_lowest_value = sorted(goals_similarity_dict[true_goal][percentage].values())[1]
-            confidence_level = abs(goals_similarity_dict[true_goal][percentage][guessed_goal] - second_lowest_value)
-            label = f"True: {true_goal}\nGuessed: {guessed_goal}\nConfidence: {confidence_level:.2f}"
-            x_labels.append(label)
-
-        ax.set_ylabel('Distance (units in st. deviations)', fontsize=10)
-        ax.set_title(f'Confidence level for {domain_name}, {env_name}, {fragmented_status}. Accuracy: {correct_tasks / tasks_num}', fontsize=12)
-        ax.set_xticks(bar_positions + bar_width * (num_dynamic_goals - 1) / 2)
-        ax.set_xticklabels(x_labels, fontsize=8)
-        ax.legend()
-
-    fig_path = os.path.join(fig_dir, f"{fig_string}_stats.png")
-    fig.savefig(fig_path)
-    print(f"general figure saved at: {fig_path}")
-
-
-def parse_args():
-    parser = argparse.ArgumentParser(
-        description="Parse command-line arguments for the RL experiment.",
-        formatter_class=argparse.RawTextHelpFormatter
-    )
-
-    # Required arguments
-    required_group = parser.add_argument_group("Required arguments")
-    required_group.add_argument("--domain", choices=["point_maze", "minigrid", "parking", "franka_kitchen", "panda"], required=True, help="Domain type (point_maze, minigrid, parking, or franka_kitchen)")
-    required_group.add_argument("--recognizer", choices=["graml", "graql", "draco"], required=True, help="Recognizer type (graml, graql, draco). graql only for discrete domains.")
-    required_group.add_argument("--task", choices=["L1", "L2", "L3", "L4", "L5", "L11", "L22", "L33", "L44", "L55", "L111", "L222", "L333", "L444", "L555"], required=True, help="Task identifier (e.g., L1, L2,...,L5)")
-    required_group.add_argument("--partial_obs_type", required=True, choices=["fragmented", "continuing"], help="Give fragmented or continuing partial observations for inference phase inputs.")
-
-    # Optional arguments
-    optional_group = parser.add_argument_group("Optional arguments")
-    optional_group.add_argument("--minigrid_env", choices=["four_rooms", "obstacles"], help="Minigrid environment (four_rooms or obstacles)")
-    optional_group.add_argument("--parking_env", choices=["gd_agent", "gc_agent"], help="Parking environment (agent or gc_agent)")
-    optional_group.add_argument("--point_maze_env", choices=["obstacles", "four_rooms"], help="Parking environment (agent or gc_agent)")
-    optional_group.add_argument("--franka_env", choices=["comb1", "comb2"], help="Franka Kitchen environment (comb1 or comb2)")
-    optional_group.add_argument("--panda_env", choices=["gc_agent", "gd_agent"], help="Panda Robotics environment (gc_agent or gd_agent)")
-    optional_group.add_argument("--learn_same_seq_len", action="store_true", help="Learn with the same sequence length")
-    optional_group.add_argument("--inference_same_seq_len", action="store_true", help="Infer with the same sequence length")
-
-    args = parser.parse_args()
-
-    ### VALIDATE INPUTS ###
-    # Assert that all required arguments are provided
-    assert args.domain is not None and args.recognizer is not None and args.task is not None, "Missing required arguments: domain, recognizer, or task"
-
-    # Validate the combination of domain and environment
-    if args.domain == "minigrid" and args.minigrid_env is None:
-        parser.error("Missing required argument: --minigrid_env must be provided when --domain is minigrid")
-    elif args.domain == "parking" and args.parking_env is None:
-        parser.error("Missing required argument: --parking_env must be provided when --domain is parking")
-    elif args.domain == "point_maze" and args.point_maze_env is None:
-        parser.error("Missing required argument: --point_maze_env must be provided when --domain is point_maze")
-    elif args.domain == "franka_kitchen" and args.franka_env is None:
-        parser.error("Missing required argument: --franka_env must be provided when --domain is franka_kitchen")
-
-    if args.recognizer != "graml":
-        if args.learn_same_seq_len == True: parser.error("learn_same_seq_len is only relevant for graml.")
-        if args.inference_same_seq_len == True: parser.error("inference_same_seq_len is only relevant for graml.")
-
-    return args
-
-if __name__ == "__main__":
-    args = parse_args()
-    set_global_storage_configs(recognizer_str=args.recognizer, is_fragmented=args.partial_obs_type,
-                               is_inference_same_length_sequences=args.inference_same_seq_len, is_learn_same_length_sequences=args.learn_same_seq_len)
-    env_name, = [x for x in [args.minigrid_env, args.parking_env, args.point_maze_env, args.franka_env] if isinstance(x, str)]
-    if args.inference_same_seq_len: inference_same_seq_len = "inference_same_seq_len"
-    else: inference_same_seq_len = "inference_diff_seq_len"
-    if args.learn_same_seq_len: learn_same_seq_len = "learn_same_seq_len"
-    else: learn_same_seq_len = "learn_diff_seq_len"
-    analyze_and_produce_plots(args.recognizer, domain_name=args.domain, env_name=env_name, fragmented_status=args.partial_obs_type,
-                              inf_same_length_status=inference_same_seq_len, learn_same_length_status=learn_same_seq_len)
evaluation/get_plans_images.py
DELETED
@@ -1,47 +0,0 @@
-import sys
-import os
-import pickle
-import inspect
-
-
-currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
-GRAML_itself = os.path.dirname(currentdir)
-GRAML_includer = os.path.dirname(os.path.dirname(currentdir))
-sys.path.insert(0, GRAML_includer)
-sys.path.insert(0, GRAML_itself)
-
-def get_plans_result_path(env_name):
-    return os.path.join("dataset", (env_name), "plans")
-
-def get_policy_sequences_result_path(env_name):
-    return os.path.join("dataset", (env_name), "policy_sequences")
-
-
-# TODO: instead of loading the model and having it produce the sequence again, just save the sequence from the framework run, and have this script accept the whole path (including is_fragmented etc.)
-def analyze_and_produce_images(env_name):
-    models_dir = get_models_dir(env_name=env_name)
-    for dirname in os.listdir(models_dir):
-        if dirname.startswith('MiniGrid'):
-            model_dir = get_model_dir(env_name=env_name, model_name=dirname, class_name="MCTS")
-            model_file_path = os.path.join(model_dir, "mcts_model.pth")
-            try:
-                with open(model_file_path, 'rb') as file: # Load the pre-existing model
-                    monteCarloTreeSearch = pickle.load(file)
-                    full_plan = monteCarloTreeSearch.generate_full_policy_sequence()
-                    plan = [pos for ((state, pos), action) in full_plan]
-                    plans_result_path = get_plans_result_path(env_name)
-                    if not os.path.exists(plans_result_path): os.makedirs(plans_result_path)
-                    img_path = os.path.join(get_plans_result_path(env_name), dirname)
-                    print(f"plan to {dirname} is:\n\t{plan}\ngenerating image at {img_path}.")
-                    create_sequence_image(plan, img_path, dirname)
-
-            except FileNotFoundError as e:
-                print(f"Warning: {e.filename} doesn't exist. It's probably a base goal, not generating policy sequence for it.")
-
-if __name__ == "__main__":
-    # preventing circular imports. only needed for running this as main anyway.
-    from gr_libs.ml.utils.storage import get_models_dir, get_model_dir
-    # checks:
-    assert len(sys.argv) == 2, f"Assertion failed: len(sys.argv) is {len(sys.argv)} while it needs to be 2.\n Example: \n\t /usr/bin/python scripts/get_plans_images.py MiniGrid-Walls-13x13-v0"
-    assert os.path.exists(get_models_dir(sys.argv[1])), "plans weren't made for this environment, run graml_main.py with this environment first."
-    analyze_and_produce_images(sys.argv[1])
evaluation/increasing_and_decreasing_.py
DELETED
@@ -1,63 +0,0 @@
-import os
-import dill
-import numpy as np
-import matplotlib.pyplot as plt
-from gr_libs.ml.utils.storage import get_experiment_results_path, set_global_storage_configs
-
-if __name__ == "__main__":
-
-    # Define the tasks and percentages
-    increasing_base_goals = ['L1', 'L2', 'L3', 'L4', 'L5']
-    increasing_dynamic_goals = ['L111', 'L222', 'L555', 'L333', 'L444']
-    percentages = ['0.3', '0.5', '0.7', '0.9', '1']
-
-    # Prepare a dictionary to hold accuracy data
-    accuracies = {task: {perc: [] for perc in percentages} for task in increasing_base_goals + increasing_dynamic_goals}
-
-    # Collect data for both sets of goals
-    for task in increasing_base_goals + increasing_dynamic_goals:
-        set_global_storage_configs(recognizer_str='graml', is_fragmented='fragmented',
-                                   is_inference_same_length_sequences=True, is_learn_same_length_sequences=False)
-        res_file_path = f'{get_experiment_results_path("parking", "gd_agent", task)}.pkl'
-
-        if os.path.exists(res_file_path):
-            with open(res_file_path, 'rb') as results_file:
-                results = dill.load(results_file)
-                for percentage in percentages:
-                    accuracies[task][percentage].append(results[percentage]['accuracy'])
-        else:
-            print(f"Warning: no file for {res_file_path}")
-
-    # Create the figure with two subplots
-    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
-
-    # Bar plot function
-    def plot_accuracies(ax, task_set, title, type):
-        """Plot accuracies for a given set of tasks on the provided axis."""
-        x_vals = np.arange(len(task_set)) # X-axis positions for the number of goals
-        bar_width = 0.15 # Width of each bar
-        for i, perc in enumerate(['0.3', '0.5', '1']):
-            if perc == '1': y_vals = [max([accuracies[task]['0.5'][0], accuracies[task]['0.7'][0], accuracies[task]['0.9'][0], accuracies[task]['1'][0]]) for task in task_set] # Get mean accuracies
-            else: y_vals = [accuracies[task][perc][0] for task in task_set] # Get mean accuracies
-            if type != 'base': ax.bar(x_vals + i * bar_width, y_vals, width=bar_width, label=f'Percentage {perc}')
-            else: ax.bar(x_vals + i * bar_width, y_vals, width=bar_width)
-        ax.set_xticks(x_vals + bar_width) # Center x-ticks
-        ax.set_xticklabels([i+3 for i in range(len(task_set))], fontsize=16) # Set custom x-tick labels
-        ax.set_yticks(np.linspace(0, 1, 6))
-        ax.set_ylim([0, 1])
-        ax.set_title(title, fontsize=20)
-        ax.set_xlabel(f'Number of {type} Goals', fontsize=20)
-        if type == 'base':
-            ax.set_ylabel('Accuracy', fontsize=22)
-        ax.legend()
-
-    # Plot for increasing base goals
-    plot_accuracies(axes[0], increasing_base_goals, 'Increasing Base Goals', "base")
-
-    # Plot for increasing dynamic goals
-    plot_accuracies(axes[1], increasing_dynamic_goals, 'Increasing Active Goals', "active")
-    plt.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1, wspace=0.3, hspace=0.3)
-    # Adjust layout and save the figure
-    plt.tight_layout()
-    plt.savefig('increasing_goals_plot_bars.png', dpi=300) # Save the figure as a PNG file
-    print('Figure saved at: increasing_goals_plot_bars.png')
gr_libs/environment/utils/utils.py
DELETED
@@ -1,17 +0,0 @@
-import logging
-import sys
-from gr_libs.environment.environment import MINIGRID, PANDA, PARKING, POINT_MAZE, EnvProperty, MinigridProperty, PandaProperty, ParkingProperty, PointMazeProperty
-
-
-def domain_to_env_property(domain_name: str):
-    if domain_name == MINIGRID:
-        return MinigridProperty
-    elif domain_name == PARKING:
-        return ParkingProperty
-    elif domain_name == PANDA:
-        return PandaProperty
-    elif domain_name == POINT_MAZE:
-        return PointMazeProperty
-    else:
-        logging.error(f"Domain {domain_name} is not supported.")
-        sys.exit(1)
gr_libs/ml/neural/utils/penv.py
DELETED
@@ -1,57 +0,0 @@
-import multiprocessing
-import gymnasium as gym
-
-#multiprocessing.set_start_method("fork")
-
-
-def worker(conn, env):
-    while True:
-        cmd, data = conn.recv()
-        if cmd == "step":
-            obs, reward, terminated, truncated, info = env.step(data)
-            if terminated or truncated:
-                obs, _ = env.reset()
-            conn.send((obs, reward, terminated, truncated, info))
-        elif cmd == "reset":
-            obs, _ = env.reset()
-            conn.send(obs)
-        else:
-            raise NotImplementedError
-
-
-class ParallelEnv(gym.Env):
-    """A concurrent execution of environments in multiple processes."""
-
-    def __init__(self, envs):
-        assert len(envs) >= 1, "No environment given."
-
-        self.envs = envs
-        self.observation_space = self.envs[0].observation_space
-        self.action_space = self.envs[0].action_space
-
-        self.locals = []
-        for env in self.envs[1:]:
-            local, remote = multiprocessing.Pipe()
-            self.locals.append(local)
-            p = multiprocessing.Process(target=worker, args=(remote, env))
-            p.daemon = True
-            p.start()
-            remote.close()
-
-    def reset(self):
-        for local in self.locals:
-            local.send(("reset", None))
-        results = [self.envs[0].reset()[0]] + [local.recv() for local in self.locals]
-        return results
-
-    def step(self, actions):
-        for local, action in zip(self.locals, actions[1:]):
-            local.send(("step", action))
-        obs, reward, terminated, truncated, info = self.envs[0].step(actions[0])
-        if terminated or truncated:
-            obs, _ = self.envs[0].reset()
-        results = zip(*[(obs, reward, terminated, truncated, info)] + [local.recv() for local in self.locals])
-        return results
-
-    def render(self):
-        raise NotImplementedError
gr_libs/ml/sequential/lstm_model.py
DELETED
@@ -1,192 +0,0 @@
-import os
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
-from types import MethodType
-import numpy as np
-from gr_libs.ml.utils import device
-from torch.nn.utils.rnn import pack_padded_sequence
-
-
-def accuracy_per_epoch(model, data_loader):
-    model.eval()
-    correct = total = 0.0
-    sum_loss = 0.0
-    with torch.no_grad():
-        for (first_traces, second_traces, is_same_goals, first_traces_lengths, second_traces_lengths) in data_loader:
-            y_pred = model.forward_tab(first_traces, second_traces, first_traces_lengths, second_traces_lengths)
-            loss = F.binary_cross_entropy(y_pred, is_same_goals)
-            sum_loss += loss.item()
-            y_pred = (y_pred >= 0.5)
-            correct += torch.sum(y_pred == is_same_goals)
-            total += len(is_same_goals)
-    return correct / total, sum_loss / 32
-
-def accuracy_per_epoch_cont(model, data_loader):
-    model.eval()
-    correct = total = 0.0
-    sum_loss = 0.0
-    with torch.no_grad():
-        for (first_traces_images, first_traces_texts, second_traces_images, second_traces_texts, is_same_goals, first_traces_lengths, second_traces_lengths) in data_loader:
-            y_pred = model.forward_cont(first_traces_images, first_traces_texts, second_traces_images, second_traces_texts, first_traces_lengths, second_traces_lengths)
-            loss = F.binary_cross_entropy(y_pred, is_same_goals)
-            sum_loss += loss.item()
-            y_pred = (y_pred >= 0.5)
-            correct += torch.sum(y_pred == is_same_goals)
-            total += len(is_same_goals)
-    return correct / total, sum_loss / 32
-
-# class CNNImageEmbeddor(nn.Module):
-#     def __init__(self, obs_space, action_space, use_text=False):
-#         super().__init__()
-#         self.use_text = use_text
-#         self.image_conv = nn.Sequential(
-#             nn.Conv2d(3, 4, kernel_size=(3, 3), padding=1), # Reduced filters, added padding
-#             nn.ReLU(),
-#             nn.MaxPool2d((2, 2)),
-#             nn.Conv2d(4, 4, (3, 3), padding=1), # Reduced filters, added padding
-#             nn.ReLU(),
-#             nn.MaxPool2d((2, 2)), # Added additional pooling to reduce size
-#             nn.Conv2d(4, 8, (3, 3), padding=1), # Reduced filters, added padding
-#             nn.ReLU(),
-#             nn.BatchNorm2d(8)
-#         )
-#         n = obs_space["image"][0]
-#         m = obs_space["image"][1]
-#         self.image_embedding_size = ((n - 4) // 4 - 3) * ((m - 4) // 4 - 3) * 8
-#         if self.use_text:
-#             self.word_embedding_size = 32
-#             self.word_embedding = nn.Embedding(obs_space["text"], self.word_embedding_size)
-#             self.text_embedding_size = 128
-#             self.text_rnn = nn.GRU(self.word_embedding_size, self.text_embedding_size, batch_first=True)
-
-def forward(self, images, texts):
-    # images shape: batch_size X max_sequence_len X sample_size. same for text.
-    # need to reshape image to num_channels X height X width, like nn.Conv expects it to be.
-    x = images.transpose(2, 4).transpose(3, 4)
-    orig_shape = x.shape
-    # combine batch and sequence to 1 dimension so conv could handle it
-    x = x.view(orig_shape[0]*orig_shape[1], orig_shape[2], orig_shape[3], orig_shape[4]) # x shape: batch_size * max_sequence_len X sample_size
-    x = self.image_conv(x) # x shape: batch_size * max_sequence_len X last_conv_size X 1 X 1
-    # reshape x back to divide batches from sequences
-    x = x.view(orig_shape[0], orig_shape[1], x.shape[1]) # x shape: batch_size X max_sequence_len X last_conv_size. last 2 dimensions (1,1) are collapsed to last conv.
-    embedding = x
-
-    if self.use_text:
-        embed_text = self._get_embed_text(texts)
-        embedding = torch.cat((embedding, embed_text), dim=1)
-
-    return embedding
-
-def _get_embed_text(self, text):
-    _, hidden = self.text_rnn(self.word_embedding(text))
-    return hidden[-1]
-
-class LstmObservations(nn.Module):
-
    def __init__(self, input_size, hidden_size): # TODO make sure the right cuda is used!
-        super(LstmObservations,self).__init__()
-        #self.embeddor = CNNImageEmbeddor(obs_space, action_space)
-        # check if the traces are a bunch of images
-        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)
-        self.dropout = nn.Dropout(0.5) # Added dropout layer
-        # Initialize weights
-        for name, param in self.lstm.named_parameters():
-            if 'weight' in name:
-                nn.init.xavier_uniform_(param)
-            elif 'bias' in name:
-                nn.init.zeros_(param)
-
-
-    # tabular
-    def forward_tab(self, traces1, traces2, lengths1, lengths2):
-        out1, (ht1, ct1) = self.lstm(pack_padded_sequence(traces1, lengths1, batch_first=True, enforce_sorted=False), None) # traces1 & traces 2 shapes: batch_size X max sequence_length X embedding_size
-        out2, (ht2, ct2) = self.lstm(pack_padded_sequence(traces2, lengths2, batch_first=True, enforce_sorted=False), None)
-        # out1, _ = pad_packed_sequence(out1, batch_first=True, total_length=max(lengths1))
-        # out2, _ = pad_packed_sequence(out2, batch_first=True, total_length=max(lengths2))
-        manhattan_dis = torch.exp(-torch.sum(torch.abs(ht1[-1]-ht2[-1]),dim=1,keepdim=True))
-        return manhattan_dis.squeeze()
-
-    # continuous
-    # def forward_cont(self, traces1_images, traces1_texts, traces2_images, traces2_texts, lengths1, lengths2):
-    #     # we also embed '0' images, but we take them out of the equation in the lstm (it knows to not treat them when batching)
-    #     traces1 = self.embeddor(traces1_images, traces1_texts)
-    #     traces2 = self.embeddor(traces2_images, traces2_texts) # traces1 & traces 2 shapes: batch_size X max_sequence_length X embedding_size
-    #     out1, (ht1, ct1) = self.lstm(pack_padded_sequence(traces1, lengths1, batch_first=True, enforce_sorted=False), None)
-    #     out2, (ht2, ct2) = self.lstm(pack_padded_sequence(traces2, lengths2, batch_first=True, enforce_sorted=False), None)
-    #     manhattan_dis = torch.exp(-torch.sum(torch.abs(ht1[-1]-ht2[-1]),dim=1,keepdim=True))
-    #     return manhattan_dis.squeeze()
-
-    def embed_sequence(self, trace):
-        trace = torch.stack([torch.tensor(observation, dtype=torch.float32) for observation in trace]).to(device)
-        out, (ht, ct) = self.lstm(trace, None)
-        return ht[-1]
-
-    # def embed_sequence_cont(self, sequence, preprocess_obss):
-    #     sequence = [preprocess_obss([obs])[0] for ((obs, (_, _)), _) in sequence]
-    #     trace_images = torch.tensor(np.expand_dims(torch.stack([step.image for step in sequence]), axis=0)).to(device)
-    #     trace_texts = torch.tensor(np.expand_dims(torch.stack([step.text for step in sequence]), axis=0)).to(device)
-    #     embedded_trace = self.embeddor(trace_images, trace_texts)
-    #     out, (ht, ct) = self.lstm(embedded_trace)
-    #     return ht[-1]
-
-def train_metric_model(model, train_loader, dev_loader, nepochs=5, patience=2):
-    devAccuracy = []
-    best_dev_accuracy = 0.0
-    no_improvement_count = 0
-    optimizer = torch.optim.Adadelta(model.parameters(), weight_decay=0.1)
-    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, factor=0.5)
-    for epoch in range(nepochs):
-        sum_loss, denominator = 0.0, 0.0
-        model.train()
-        for (first_traces, second_traces, is_same_goals, first_traces_lengths, second_traces_lengths) in train_loader:
-            model.zero_grad()
-            y_pred = model.forward_tab(first_traces, second_traces, first_traces_lengths, second_traces_lengths)
-            if len(is_same_goals) == 1: is_same_goals = torch.squeeze(is_same_goals) # for the case of batches in size 1...
-            loss = F.binary_cross_entropy(y_pred, is_same_goals)
-            sum_loss += loss.item()
-            denominator += 1
-            loss.backward()
-            optimizer.step()
-
-        dev_accuracy, dev_loss = accuracy_per_epoch(model, dev_loader)
-        devAccuracy.append(dev_accuracy)
-        if dev_accuracy > best_dev_accuracy:
-            best_dev_accuracy = dev_accuracy
-            no_improvement_count = 0
-        else:
-            no_improvement_count = 1
-
-        print("epoch - {}/{}...".format(epoch + 1, nepochs),
-              "train loss - {:.6f}...".format(sum_loss / denominator),
-              "dev loss - {:.6f}...".format(dev_loss),
-              "dev accuracy - {:.6f}".format(dev_accuracy))
-
-        if no_improvement_count >= patience:
-            print(f"Early stopping after {epoch + 1} epochs with no improvement.")
-            break
-
-def train_metric_model_cont(model, train_loader, dev_loader, nepochs=5):
-    devAccuracy = []
-    optimizer = torch.optim.Adadelta(model.parameters(),weight_decay=1.25)
-    for epoch in range(nepochs):
-        sum_loss, denominator = 0.0, 0.0
-        model.train()
-        for (first_traces_images, first_traces_texts, second_traces_images, second_traces_texts, is_same_goals, first_traces_lengths, second_traces_lengths) in train_loader:
-            model.zero_grad()
-            y_pred = model.forward_cont(first_traces_images, first_traces_texts, second_traces_images, second_traces_texts, first_traces_lengths, second_traces_lengths)
-            loss = F.binary_cross_entropy(y_pred, is_same_goals)
-            sum_loss += loss.item()
-            denominator += 1
-            loss.backward()
-            optimizer.step()
-
-        dev_accuracy, dev_loss = accuracy_per_epoch_cont(model, dev_loader)
-        devAccuracy.append(dev_accuracy)
-
-        print("epoch - {}/{}...".format(epoch + 1, nepochs),
-              "train loss - {:.6f}...".format(sum_loss / denominator),
-              "dev loss - {:.6f}...".format(dev_loss),
-              "dev accuracy - {:.6f}".format(dev_accuracy))
-