gr-libs 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- gr_libs/__init__.py +6 -1
- gr_libs/_version.py +2 -2
- gr_libs/environment/environment.py +104 -15
- gr_libs/ml/consts.py +1 -0
- gr_libs/ml/neural/deep_rl_learner.py +101 -14
- gr_libs/odgr_executor.py +7 -2
- gr_libs/recognizer/_utils/format.py +7 -1
- gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +146 -1
- gr_libs/recognizer/graml/graml_recognizer.py +4 -4
- gr_libs/recognizer/recognizer.py +4 -4
- gr_libs/tutorials/gcaura_panda_tutorial.py +168 -0
- gr_libs/tutorials/gcaura_parking_tutorial.py +167 -0
- gr_libs/tutorials/gcaura_point_maze_tutorial.py +169 -0
- {gr_libs-0.2.5.dist-info → gr_libs-0.2.6.dist-info}/METADATA +16 -11
- {gr_libs-0.2.5.dist-info → gr_libs-0.2.6.dist-info}/RECORD +19 -14
- tests/test_gcaura.py +15 -0
- tests/test_odgr_executor_gcaura.py +14 -0
- {gr_libs-0.2.5.dist-info → gr_libs-0.2.6.dist-info}/WHEEL +0 -0
- {gr_libs-0.2.5.dist-info → gr_libs-0.2.6.dist-info}/top_level.txt +0 -0
gr_libs/__init__.py
CHANGED
@@ -1,6 +1,11 @@
 """gr_libs: Baselines for goal recognition executions on gym environments."""
 
-from gr_libs.recognizer.gr_as_rl.gr_as_rl_recognizer import
+from gr_libs.recognizer.gr_as_rl.gr_as_rl_recognizer import (
+    Draco,
+    GCDraco,
+    Graql,
+    GCAura,
+)
 from gr_libs.recognizer.graml.graml_recognizer import ExpertBasedGraml, GCGraml
 
 try:
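The package root now re-exports GCAura alongside the existing recognizers, so downstream code can import it directly. A one-line sketch based on the import block above:

from gr_libs import Draco, GCDraco, Graql, GCAura  # GCAura is new in 0.2.6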
gr_libs/_version.py
CHANGED
gr_libs/environment/environment.py
CHANGED
@@ -1,4 +1,4 @@
-"""
+"""environment.py"""
 
 import os
 import sys
@@ -14,6 +14,8 @@ from minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
 from PIL import Image
 from stable_baselines3.common.vec_env import DummyVecEnv
 
+from gr_envs.wrappers.goal_wrapper import GoalRecognitionWrapper
+
 MINIGRID, PANDA, PARKING, POINT_MAZE = "minigrid", "panda", "parking", "point_maze"
 
 QLEARNING = "QLEARNING"
@@ -110,6 +112,12 @@ class EnvProperty:
         Convert a list of problems to a string tuple.
         """
 
+    @abstractmethod
+    def goal_to_str(self, goal):
+        """
+        Convert a goal to a string representation.
+        """
+
     @abstractmethod
     def goal_to_problem_str(self, goal):
         """
@@ -166,6 +174,29 @@ class EnvProperty:
         Change the goal to a specific desired goal.
         """
 
+    def is_goal_in_subspace(self, goal):
+        """
+        Check if a goal is within the specified goal subspace.
+
+        Args:
+            goal: The goal to check
+            goal_subspace: The goal subspace to check against
+
+        Returns:
+            bool: True if the goal is within the subspace, False otherwise
+        """
+        env = gym.make(id=self.name)
+        while env is not None and hasattr(env, "env"):
+            if isinstance(env, GoalRecognitionWrapper) and hasattr(
+                env, "is_goal_in_subspace"
+            ):
+                # If the environment has a goal recognition wrapper, use its method
+                return env.is_goal_in_subspace(goal)
+            # Traverse through wrappers to find the base environment
+            env = env.env
+
+        return True
+
 
 class GCEnvProperty(EnvProperty):
     """
@@ -194,16 +225,25 @@ class MinigridProperty(EnvProperty):
         super().__init__(name)
         self.domain_name = "minigrid"
 
+    def goal_to_str(self, goal):
+        """
+        Convert a goal to a string representation.
+        """
+        return f"{goal[0]}x{goal[1]}"
+
     def goal_to_problem_str(self, goal):
         """
         Convert a goal to a problem string.
         """
-        return self.name + f"-DynamicGoal-{goal
+        return self.name + f"-DynamicGoal-{self.goal_to_str(goal)}-v0"
 
-    def str_to_goal(self, problem_name):
+    def str_to_goal(self, problem_name=None):
         """
         Convert a problem name to a goal.
         """
+        if problem_name is None:
+            problem_name = self.name
+
         parts = problem_name.split("-")
         goal_part = [part for part in parts if "x" in part]
         width, height = goal_part[0].split("x")
@@ -325,30 +365,36 @@ class PandaProperty(GCEnvProperty):
         super().__init__(name)
         self.domain_name = "panda"
 
-    def str_to_goal(self, problem_name):
+    def str_to_goal(self, problem_name=None):
         """
         Convert a problem name to a goal.
         """
+        if problem_name is None:
+            return "general"
         try:
             numeric_part = problem_name.split("PandaMyReachDenseX")[1]
             components = [
                 component.replace("-v3", "").replace("y", ".").replace("M", "-")
                 for component in numeric_part.split("X")
             ]
-            floats = []
-
-                floats.append(float(component))
-            return np.array([floats], dtype=np.float32)
+            floats = [float(component) for component in components]
+            return np.array([floats])
         except Exception:
             return "general"
 
-    def
+    def goal_to_str(self, goal):
         """
-        Convert a goal to a
+        Convert a goal to a string representation.
         """
-
+        return "X".join(
             [str(float(g)).replace(".", "y").replace("-", "M") for g in goal[0]]
         )
+
+    def goal_to_problem_str(self, goal):
+        """
+        Convert a goal to a problem string.
+        """
+        goal_str = self.goal_to_str(goal)
         return f"PandaMyReachDenseX{goal_str}-v3"
 
     def gc_adaptable(self):
@@ -450,10 +496,34 @@ class ParkingProperty(GCEnvProperty):
         super().__init__(name)
         self.domain_name = "parking"
 
+    def str_to_goal(self, problem_name=None):
+        """
+        Convert a problem name to a goal.
+        """
+        if not problem_name:
+            problem_name = self.name
+        # Extract the goal from the part
+        return int(problem_name.split("GI-")[1].split("-v0")[0])
+
+    def goal_to_str(self, goal):
+        """
+        Convert a goal to a string representation.
+        """
+        if isinstance(goal, int):
+            return str(goal)
+        elif isinstance(goal, str):
+            return goal
+        else:
+            raise ValueError(
+                f"Unsupported goal type: {type(goal)}. Expected int or str."
+            )
+
     def goal_to_problem_str(self, goal):
         """
         Convert a goal to a problem string.
         """
+        if "-GI-" in self.name:
+            return self.name.split("-GI-")[0] + f"-GI-{goal}-v0"
         return self.name.split("-v0")[0] + f"-GI-{goal}-v0"
 
     def gc_adaptable(self):
@@ -536,9 +606,11 @@ class PointMazeProperty(EnvProperty):
         super().__init__(name)
         self.domain_name = "point_maze"
 
-    def str_to_goal(self):
+    def str_to_goal(self, problem_name=None):
         """Convert a problem name to a goal."""
-
+        if not problem_name:
+            problem_name = self.name
+        parts = problem_name.split("-")
         # Find the part containing the goal size (usually after "DynamicGoal")
         sizes_parts = [part for part in parts if "x" in part]
         goal_part = sizes_parts[1]
@@ -546,9 +618,15 @@ class PointMazeProperty(EnvProperty):
         width, height = goal_part.split("x")
         return (int(width), int(height))
 
+    def goal_to_str(self, goal):
+        """
+        Convert a goal to a string representation.
+        """
+        return f"{goal[0]}x{goal[1]}"
+
     def gc_adaptable(self):
         """Check if the environment is goal-conditioned adaptable."""
-        return
+        return True
 
     def problem_list_to_str_tuple(self, problems):
         """Convert a list of problems to a string tuple."""
@@ -574,7 +652,12 @@ class PointMazeProperty(EnvProperty):
         """
         Convert a goal to a problem string.
         """
-
+        possible_suffixes = ["-Goals-", "-Goal-", "-MultiGoals-", "-GoalConditioned-"]
+        for suffix in possible_suffixes:
+            if suffix in self.name:
+                return self.name.split(suffix)[0] + f"-Goal-{self.goal_to_str(goal)}"
+
+        return self.name + f"-Goal-{self.goal_to_str(goal)}"
 
     def change_done_by_specific_desired(self, obs, desired, old_success_done):
         """
@@ -592,6 +675,12 @@ class PointMazeProperty(EnvProperty):
         assert isinstance(done, np.ndarray)
         return done[0]
 
+    def use_goal_directed_problem(self):
+        """
+        Check if the environment uses a goal-directed problem.
+        """
+        return True
+
     def is_success(self, info):
         """
         Check if the episode is successful.
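Summary of the environment.py changes: every EnvProperty gains a goal_to_str / str_to_goal pair (str_to_goal now defaults to parsing the property's own name) plus an is_goal_in_subspace check that walks the gym wrapper chain looking for a GoalRecognitionWrapper. A minimal sketch of how the string helpers compose; the Minigrid id is hypothetical and must be registered by gr_envs for anything beyond string manipulation:

from gr_libs.environment.environment import MinigridProperty

env_prop = MinigridProperty("MiniGrid-SimpleCrossingS13N4")  # hypothetical base name

goal = (3, 5)
env_prop.goal_to_str(goal)           # "3x5"
env_prop.goal_to_problem_str(goal)   # "MiniGrid-SimpleCrossingS13N4-DynamicGoal-3x5-v0"
env_prop.str_to_goal("MiniGrid-SimpleCrossingS13N4-DynamicGoal-3x5-v0")  # parses back to the goal cell

# is_goal_in_subspace(goal) calls gym.make(self.name) and traverses env.env wrappers;
# if a GoalRecognitionWrapper exposing its own is_goal_in_subspace is found, that check
# is used, otherwise the goal is assumed to be inside the subspace and True is returned.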
gr_libs/ml/consts.py
CHANGED
gr_libs/ml/neural/deep_rl_learner.py
CHANGED
@@ -1,6 +1,7 @@
 import gc
 from collections import OrderedDict
 from types import MethodType
+from typing import Any
 
 import cv2
 import numpy as np
@@ -22,6 +23,10 @@ from stable_baselines3.common.base_class import BaseAlgorithm
 
 from gr_libs.ml.utils import device
 
+from gr_libs.ml.consts import (
+    FINETUNE_TIMESTEPS,
+)
+
 # TODO do we need this?
 NETWORK_SETUP = {
     SAC: OrderedDict(
@@ -236,27 +241,46 @@ class DeepRLAgent:
             self._model_file_path, env=self.env, device=device, **self.model_kwargs
         )
 
-    def learn(self):
+    def learn(self, goal=None, total_timesteps=None):
         """Train the agent."""
-
-
+        model_file_path = self._model_file_path
+        old_model_file_path = model_file_path
+        if goal is not None:
+            model_file_path = self._model_file_path.replace(
+                ".pth", f"_{goal}.pth"
+            ).replace(".zip", f"_{goal}.zip")
+        if total_timesteps is not None:
+            model_file_path = model_file_path.replace(
+                ".pth", f"_{total_timesteps}.pth"
+            ).replace(".zip", f"_{total_timesteps}.zip")
+
+        self._model_file_path = model_file_path
+
+        if os.path.exists(model_file_path):
+            print(f"Loading pre-existing model in {model_file_path}")
            self.load_model()
         else:
-            print(f"No existing model in {
-            if
-
-
-            self.
-
-
-
-
-
+            print(f"No existing model in {model_file_path}, starting learning")
+            if total_timesteps is None:
+                total_timesteps = self.num_timesteps
+            if self.exploration_rate is not None:
+                self._model = self.algorithm(
+                    "MultiInputPolicy",
+                    self.env,
+                    ent_coef=self.exploration_rate,
+                    verbose=1,
+                )
+            else:
+                self._model = self.algorithm(
+                    "MultiInputPolicy", self.env, verbose=1
+                )
             self._model.learn(
-                total_timesteps=
+                total_timesteps=total_timesteps, progress_bar=True
             ) # comment this in a normal env
             self.save_model()
 
+        self._model_file_path = old_model_file_path
+
     def safe_env_reset(self):
         """
         Reset the environment safely, suppressing output.
@@ -503,6 +527,69 @@ class DeepRLAgent:
         self.env.close()
         return observations
 
+    def fine_tune(
+        self,
+        goal: Any,
+        num_timesteps: int = FINETUNE_TIMESTEPS,
+    ) -> None:
+        """
+        Fine-tune this goal-conditioned agent on a single specified goal.
+        Overrides optimizer LR if provided, resets the env to the goal, and continues training.
+
+        Args:
+            goal: The specific goal to fine-tune on. Type depends on the environment.
+            num_timesteps: Number of timesteps for fine-tuning. Defaults to FINETUNE_TIMESTEPS.
+            learning_rate: Learning rate for fine-tuning. Defaults to FINETUNE_LR.
+        """
+        # Store original environment and problem
+        original_env = self.env
+        original_problem = self.problem_name
+        created_new_env = False
+
+        try:
+            # Try to create a goal-specific environment
+            if hasattr(self.env_prop, "goal_to_problem_str") and callable(
+                self.env_prop.goal_to_problem_str
+            ):
+                try:
+                    goal_problem = self.env_prop.goal_to_problem_str(goal)
+
+                    # Create the goal-specific environment
+                    env_kwargs = {"id": goal_problem, "render_mode": "rgb_array"}
+                    new_env = self.env_prop.create_vec_env(env_kwargs)
+
+                    # Update the model's environment
+                    self._model.set_env(new_env)
+                    self.env = new_env
+                    self.problem_name = goal_problem
+                    created_new_env = True
+                    print(f"Created a new environment for fine-tuning: {goal_problem}")
+                except Exception as e:
+                    print(f"Warning: Could not create goal-specific environment: {e}")
+
+            if not created_new_env:
+                print(
+                    (
+                        "Fine-tuning requires a goal-specific environment."
+                        "Please ensure that the environment with the specified goal exists."
+                    )
+                )
+
+            print(f"Fine-tuning for {num_timesteps} timesteps...")
+            self.learn(
+                goal=self.env_prop.goal_to_str(goal), total_timesteps=num_timesteps
+            )
+            print("Fine-tuning complete. Model saved.")
+
+        finally:
+            # Restore original environment if needed
+            if created_new_env:
+                self.env.close()
+                self._model.set_env(original_env)
+                self.env = original_env
+                self.problem_name = original_problem
+                print("Restored original environment.")
+
 
 class GCDeepRLAgent(DeepRLAgent):
     """
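DeepRLAgent.learn now derives a per-goal / per-budget model file name, and the new fine_tune method continues training an already-learned goal-conditioned model on one concrete goal in a goal-specific environment. A hedged sketch of the intended call sequence; the parking id, algorithm, and timestep numbers are illustrative, and the constructor kwargs mirror the agent_kwargs dict used by the recognizers:

from stable_baselines3 import SAC                     # assumption: any SB3 algorithm usable here
from gr_libs.environment.environment import ParkingProperty
from gr_libs.ml.neural.deep_rl_learner import GCDeepRLAgent

env_prop = ParkingProperty("Parking-S-14-PC--GI-v0")  # hypothetical goal-conditioned parking id

agent = GCDeepRLAgent(
    domain_name=env_prop.domain_name,
    problem_name=env_prop.name,
    algorithm=SAC,
    num_timesteps=300_000,    # illustrative training budget
    env_prop=env_prop,
)
agent.learn()                 # loads the saved base model if present, otherwise trains and saves it
agent.fine_tune(goal=11,      # an out-of-subspace parking goal index, per ParkingProperty.goal_to_str
                num_timesteps=100_000)  # saved under a "<model>_<goal>_<timesteps>"-style path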
gr_libs/odgr_executor.py
CHANGED
@@ -15,7 +15,7 @@ from gr_libs.ml.utils.storage import (
 )
 from gr_libs.problems.consts import PROBLEMS
 from gr_libs.recognizer._utils import recognizer_str_to_obj
-from gr_libs.recognizer.gr_as_rl.gr_as_rl_recognizer import Draco, GCDraco
+from gr_libs.recognizer.gr_as_rl.gr_as_rl_recognizer import Draco, GCDraco, GCAura
 from gr_libs.recognizer.graml.graml_recognizer import Graml
 from gr_libs.recognizer.recognizer import GaAgentTrainerRecognizer, LearningRecognizer
 
@@ -102,7 +102,11 @@ def run_odgr_problem(args):
     }
 
     # need to dump the whole plan for draco because it needs it for inference phase for checking likelihood.
-    if (
+    if (
+        recognizer_type == Draco
+        or recognizer_type == GCDraco
+        or recognizer_type == GCAura
+    ) and issubclass(
        rl_agent_type, DeepRLAgent
    ): # TODO remove this condition, remove the assumption.
        generate_obs_kwargs["with_dict"] = True
@@ -224,6 +228,7 @@ def parse_args():
            "Graql",
            "Draco",
            "GCDraco",
+           "GCAura",
        ],
        required=True,
        help="Recognizer type. Follow readme.md and recognizer folder for more information and rules.",
gr_libs/recognizer/_utils/format.py
CHANGED
@@ -1,4 +1,9 @@
-from gr_libs.recognizer.gr_as_rl.gr_as_rl_recognizer import
+from gr_libs.recognizer.gr_as_rl.gr_as_rl_recognizer import (
+    Draco,
+    GCDraco,
+    Graql,
+    GCAura,
+)
 from gr_libs.recognizer.graml.graml_recognizer import (
     ExpertBasedGraml,
     GCGraml,
@@ -14,5 +19,6 @@ def recognizer_str_to_obj(recognizer_str: str):
         "Graql": Graql,
         "Draco": Draco,
         "GCDraco": GCDraco,
+        "GCAura": GCAura,
     }
     return recognizer_map.get(recognizer_str)
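The updated mapping lets callers resolve the new recognizer from its string name, exactly as odgr_executor does:

from gr_libs.recognizer._utils import recognizer_str_to_obj

recognizer_cls = recognizer_str_to_obj("GCAura")    # the GCAura class
missing = recognizer_str_to_obj("NotARecognizer")   # None, via dict.get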
gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py
CHANGED
@@ -8,12 +8,14 @@ from gr_libs.ml.base import RLAgent
 from gr_libs.ml.neural.deep_rl_learner import DeepRLAgent, GCDeepRLAgent
 from gr_libs.ml.tabular.tabular_q_learner import TabularQLearner
 from gr_libs.ml.utils.storage import get_gr_as_rl_experiment_confidence_path
+from gymnasium.envs.registration import register, registry
 from gr_libs.recognizer.recognizer import (
     GaAdaptingRecognizer,
     GaAgentTrainerRecognizer,
     LearningRecognizer,
     Recognizer,
 )
+from gr_libs.ml.consts import FINETUNE_TIMESTEPS
 
 
 class GRAsRL(Recognizer):
@@ -234,7 +236,7 @@ class GCDraco(GRAsRL, LearningRecognizer, GaAdaptingRecognizer):
         base = problems["gc"]
         base_goals = base["goals"]
         train_configs = base["train_configs"]
-        super().domain_learning_phase(
+        super().domain_learning_phase(train_configs, base_goals)
         agent_kwargs = {
             "domain_name": self.env_prop.domain_name,
             "problem_name": self.env_prop.name,
@@ -256,3 +258,146 @@ class GCDraco(GRAsRL, LearningRecognizer, GaAdaptingRecognizer):
 
     def choose_agent(self, problem_name: str) -> RLAgent:
         return next(iter(self.agents.values()))
+
+
+class GCAura(GRAsRL, LearningRecognizer, GaAdaptingRecognizer):
+    """
+    GCAura uses goal-conditioned reinforcement learning with adaptive fine-tuning.
+
+    It trains a base goal-conditioned policy over a goal subspace in the domain learning phase.
+    During the goal adaptation phase, it checks if new goals are within the original goal subspace:
+    - If a goal is within the subspace, it uses the original trained model
+    - If a goal is outside the subspace, it fine-tunes the model for that specific goal
+
+    This approach combines the efficiency of goal-conditioned RL with the precision of
+    goal-specific fine-tuning when needed.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert (
+            self.env_prop.gc_adaptable()
+            and not self.env_prop.is_state_discrete()
+            and not self.env_prop.is_action_discrete()
+        )
+        if self.rl_agent_type is None:
+            self.rl_agent_type = GCDeepRLAgent
+        self.evaluation_function = kwargs.get("evaluation_function")
+        if self.evaluation_function is None:
+            from gr_libs.metrics.metrics import mean_wasserstein_distance
+
+            self.evaluation_function = mean_wasserstein_distance
+        assert callable(
+            self.evaluation_function
+        ), "Evaluation function must be a callable function."
+
+        # Store fine-tuning parameters
+        self.finetune_timesteps = kwargs.get("finetune_timesteps", FINETUNE_TIMESTEPS)
+
+        # Dictionary to store fine-tuned agents for specific goals
+        self.fine_tuned_agents = {}
+
+    def domain_learning_phase(self, problems):
+        base = problems["gc"]
+        train_configs = base["train_configs"]
+
+        # Store the goal subspace for later checks
+        self.original_train_configs = train_configs
+
+        super().domain_learning_phase(train_configs)
+
+        agent_kwargs = {
+            "domain_name": self.env_prop.domain_name,
+            "problem_name": self.env_prop.name,
+            "algorithm": train_configs[0][0],
+            "num_timesteps": train_configs[0][1],
+            "env_prop": self.env_prop,
+        }
+
+        agent = self.rl_agent_type(**agent_kwargs)
+        agent.learn()
+        self.agents[self.env_prop.name] = agent
+        self.action_space = agent.env.action_space
+
+    def _is_goal_in_subspace(self, goal):
+        """
+        Check if a goal is within the original training subspace.
+
+        Delegates to the environment property's implementation.
+
+        Args:
+            goal: The goal to check
+
+        Returns:
+            bool: True if the goal is within the training subspace
+        """
+        # Use the environment property's implementation
+        return self.env_prop.is_goal_in_subspace(goal)
+
+    def goals_adaptation_phase(self, dynamic_goals):
+        """
+        Adapt to new goals, fine-tuning if necessary.
+
+        For goals outside the original training subspace, fine-tune the model.
+
+        Args:
+            dynamic_goals: List of goals to adapt to
+        """
+        self.active_goals = dynamic_goals
+        self.active_problems = [
+            self.env_prop.goal_to_problem_str(goal) for goal in dynamic_goals
+        ]
+
+        # Check each goal and fine-tune if needed
+        for goal in dynamic_goals:
+            if not self._is_goal_in_subspace(goal):
+                print(f"Goal {goal} is outside the training subspace. Fine-tuning...")
+
+                # Create a new agent for this goal
+                agent_kwargs = {
+                    "domain_name": self.env_prop.domain_name,
+                    "problem_name": self.env_prop.name,
+                    "algorithm": self.original_train_configs[0][0],
+                    "num_timesteps": self.original_train_configs[0][1],
+                    "env_prop": self.env_prop,
+                }
+
+                # Create new agent with base model
+                fine_tuned_agent = self.rl_agent_type(**agent_kwargs)
+                fine_tuned_agent.learn()  # This loads the existing model
+
+                # Fine-tune for this specific goal
+                fine_tuned_agent.fine_tune(
+                    goal=goal,
+                    num_timesteps=self.finetune_timesteps,
+                )
+
+                # Store the fine-tuned agent
+                self.fine_tuned_agents[
+                    f"{self.env_prop.goal_to_str(goal)}_{self.finetune_timesteps}"
+                ] = fine_tuned_agent
+            else:
+                print(f"Goal {goal} is within the training subspace. Using base agent.")
+
+    def choose_agent(self, problem_name: str) -> RLAgent:
+        """
+        Return the appropriate agent for the given problem.
+
+        If the goal has a fine-tuned agent, return that; otherwise return the base agent.
+
+        Args:
+            problem_name: The problem name to get agent for
+
+        Returns:
+            The appropriate agent (base or fine-tuned)
+        """
+        # Extract the goal from the problem name
+        goal = self.env_prop.str_to_goal(problem_name)
+        agent_name = f"{self.env_prop.goal_to_str(goal)}_{self.finetune_timesteps}"
+
+        # Check if we have a fine-tuned agent for this goal
+        if agent_name in self.fine_tuned_agents:
+            return self.fine_tuned_agents[agent_name]
+
+        # Otherwise return the base agent
+        return self.agents[self.env_prop.name]
gr_libs/recognizer/graml/graml_recognizer.py
CHANGED
@@ -1,4 +1,4 @@
-"""
+"""Collection of recognizers that use GRAML methods: metric learning for ODGR."""
 
 import os
 from abc import abstractmethod
@@ -124,7 +124,7 @@ class Graml(LearningRecognizer):
         pass
 
     def domain_learning_phase(self, base_goals: list[str], train_configs: list):
-        super().domain_learning_phase(
+        super().domain_learning_phase(train_configs, base_goals)
         self.train_agents_on_base_goals(base_goals, train_configs)
         # train the network so it will find a metric for the observations of the base agents such that traces of agents to different goals are far from one another
         self.model_directory = get_lstm_model_dir(
@@ -343,7 +343,7 @@ class BGGraml(Graml):
         assert len(base_goals) == len(
             train_configs
         ), "base_goals and train_configs should have the same length"
-        super().domain_learning_phase(
+        super().domain_learning_phase(train_configs=train_configs, base_goals=base_goals)
 
     # In case we need goal-directed agent for every goal
     def train_agents_on_base_goals(self, base_goals: list[str], train_configs: list):
@@ -556,7 +556,7 @@ class GCGraml(Graml, GaAdaptingRecognizer):
         assert (
             len(train_configs) == 1
         ), "GCGraml should only have one train config for the base goals, it uses a single agent"
-        super().domain_learning_phase(
+        super().domain_learning_phase(train_configs=train_configs, base_goals=base_goals)
 
     # In case we need goal-directed agent for every goal
     def train_agents_on_base_goals(self, base_goals: list[str], train_configs: list):
gr_libs/recognizer/recognizer.py
CHANGED
@@ -36,7 +36,7 @@ class LearningRecognizer(Recognizer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-    def domain_learning_phase(self, base_goals: list[str]
+    def domain_learning_phase(self, train_configs: list, base_goals: list[str] = None):
         """
         Perform the domain learning phase.
 
@@ -70,18 +70,18 @@ class GaAgentTrainerRecognizer(Recognizer):
             None
         """
 
-    def domain_learning_phase(self, base_goals: list[str]
+    def domain_learning_phase(self, train_configs: list, base_goals: list[str] = None):
         """
         Perform the domain learning phase.
 
         Args:
-            base_goals (List[str]): List of base goals.
             train_configs (List): List of training configurations.
+            base_goals (List[str]): List of base goals for the learning phase.
 
         Returns:
             None
         """
-        super().domain_learning_phase(
+        super().domain_learning_phase(train_configs, base_goals)
 
 
 class GaAdaptingRecognizer(Recognizer):
|