gr-libs 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gr_libs/__init__.py CHANGED
@@ -1,6 +1,11 @@
 """gr_libs: Baselines for goal recognition executions on gym environments."""
 
-from gr_libs.recognizer.gr_as_rl.gr_as_rl_recognizer import Draco, GCDraco, Graql
+from gr_libs.recognizer.gr_as_rl.gr_as_rl_recognizer import (
+    Draco,
+    GCDraco,
+    Graql,
+    GCAura,
+)
 from gr_libs.recognizer.graml.graml_recognizer import ExpertBasedGraml, GCGraml
 
 try:
gr_libs/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.2.5'
-__version_tuple__ = version_tuple = (0, 2, 5)
+__version__ = version = '0.2.6'
+__version_tuple__ = version_tuple = (0, 2, 6)
gr_libs/environment/environment.py CHANGED
@@ -1,4 +1,4 @@
-""" environment.py """
+"""environment.py"""
 
 import os
 import sys
@@ -14,6 +14,8 @@ from minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
 from PIL import Image
 from stable_baselines3.common.vec_env import DummyVecEnv
 
+from gr_envs.wrappers.goal_wrapper import GoalRecognitionWrapper
+
 MINIGRID, PANDA, PARKING, POINT_MAZE = "minigrid", "panda", "parking", "point_maze"
 
 QLEARNING = "QLEARNING"
@@ -110,6 +112,12 @@ class EnvProperty:
         Convert a list of problems to a string tuple.
         """
 
+    @abstractmethod
+    def goal_to_str(self, goal):
+        """
+        Convert a goal to a string representation.
+        """
+
     @abstractmethod
     def goal_to_problem_str(self, goal):
         """
@@ -166,6 +174,29 @@ class EnvProperty:
         Change the goal to a specific desired goal.
         """
 
+    def is_goal_in_subspace(self, goal):
+        """
+        Check if a goal is within the specified goal subspace.
+
+        Args:
+            goal: The goal to check
+
+        Returns:
+            bool: True if the goal is within the subspace, False otherwise
+        """
+        env = gym.make(id=self.name)
+        while env is not None and hasattr(env, "env"):
+            if isinstance(env, GoalRecognitionWrapper) and hasattr(
+                env, "is_goal_in_subspace"
+            ):
+                # If the environment has a goal recognition wrapper, use its method
+                return env.is_goal_in_subspace(goal)
+            # Traverse through wrappers to find the base environment
+            env = env.env
+
+        return True
+
 
 class GCEnvProperty(EnvProperty):
     """
@@ -194,16 +225,25 @@ class MinigridProperty(EnvProperty):
         super().__init__(name)
         self.domain_name = "minigrid"
 
+    def goal_to_str(self, goal):
+        """
+        Convert a goal to a string representation.
+        """
+        return f"{goal[0]}x{goal[1]}"
+
     def goal_to_problem_str(self, goal):
         """
         Convert a goal to a problem string.
         """
-        return self.name + f"-DynamicGoal-{goal[0]}x{goal[1]}-v0"
+        return self.name + f"-DynamicGoal-{self.goal_to_str(goal)}-v0"
 
-    def str_to_goal(self, problem_name):
+    def str_to_goal(self, problem_name=None):
         """
         Convert a problem name to a goal.
         """
+        if problem_name is None:
+            problem_name = self.name
+
         parts = problem_name.split("-")
         goal_part = [part for part in parts if "x" in part]
         width, height = goal_part[0].split("x")
@@ -325,30 +365,36 @@ class PandaProperty(GCEnvProperty):
         super().__init__(name)
         self.domain_name = "panda"
 
-    def str_to_goal(self, problem_name):
+    def str_to_goal(self, problem_name=None):
         """
         Convert a problem name to a goal.
         """
+        if problem_name is None:
+            return "general"
         try:
             numeric_part = problem_name.split("PandaMyReachDenseX")[1]
             components = [
                 component.replace("-v3", "").replace("y", ".").replace("M", "-")
                 for component in numeric_part.split("X")
             ]
-            floats = []
-            for component in components:
-                floats.append(float(component))
-            return np.array([floats], dtype=np.float32)
+            floats = [float(component) for component in components]
+            return np.array([floats])
         except Exception:
             return "general"
 
-    def goal_to_problem_str(self, goal):
+    def goal_to_str(self, goal):
         """
-        Convert a goal to a problem string.
+        Convert a goal to a string representation.
         """
-        goal_str = "X".join(
+        return "X".join(
             [str(float(g)).replace(".", "y").replace("-", "M") for g in goal[0]]
         )
+
+    def goal_to_problem_str(self, goal):
+        """
+        Convert a goal to a problem string.
+        """
+        goal_str = self.goal_to_str(goal)
         return f"PandaMyReachDenseX{goal_str}-v3"
 
     def gc_adaptable(self):
@@ -450,10 +496,34 @@ class ParkingProperty(GCEnvProperty):
         super().__init__(name)
         self.domain_name = "parking"
 
+    def str_to_goal(self, problem_name=None):
+        """
+        Convert a problem name to a goal.
+        """
+        if not problem_name:
+            problem_name = self.name
+        # Extract the goal index from the problem name
+        return int(problem_name.split("GI-")[1].split("-v0")[0])
+
+    def goal_to_str(self, goal):
+        """
+        Convert a goal to a string representation.
+        """
+        if isinstance(goal, int):
+            return str(goal)
+        elif isinstance(goal, str):
+            return goal
+        else:
+            raise ValueError(
+                f"Unsupported goal type: {type(goal)}. Expected int or str."
+            )
+
     def goal_to_problem_str(self, goal):
         """
         Convert a goal to a problem string.
         """
+        if "-GI-" in self.name:
+            return self.name.split("-GI-")[0] + f"-GI-{goal}-v0"
         return self.name.split("-v0")[0] + f"-GI-{goal}-v0"
 
     def gc_adaptable(self):
@@ -536,9 +606,11 @@ class PointMazeProperty(EnvProperty):
         super().__init__(name)
         self.domain_name = "point_maze"
 
-    def str_to_goal(self):
+    def str_to_goal(self, problem_name=None):
         """Convert a problem name to a goal."""
-        parts = self.name.split("-")
+        if not problem_name:
+            problem_name = self.name
+        parts = problem_name.split("-")
         # Find the part containing the goal size (usually after "DynamicGoal")
         sizes_parts = [part for part in parts if "x" in part]
         goal_part = sizes_parts[1]
@@ -546,9 +618,15 @@ class PointMazeProperty(EnvProperty):
         width, height = goal_part.split("x")
         return (int(width), int(height))
 
+    def goal_to_str(self, goal):
+        """
+        Convert a goal to a string representation.
+        """
+        return f"{goal[0]}x{goal[1]}"
+
     def gc_adaptable(self):
         """Check if the environment is goal-conditioned adaptable."""
-        return False
+        return True
 
     def problem_list_to_str_tuple(self, problems):
         """Convert a list of problems to a string tuple."""
@@ -574,7 +652,12 @@ class PointMazeProperty(EnvProperty):
         """
         Convert a goal to a problem string.
         """
-        return self.name + f"-Goal-{goal[0]}x{goal[1]}"
+        possible_suffixes = ["-Goals-", "-Goal-", "-MultiGoals-", "-GoalConditioned-"]
+        for suffix in possible_suffixes:
+            if suffix in self.name:
+                return self.name.split(suffix)[0] + f"-Goal-{self.goal_to_str(goal)}"
+
+        return self.name + f"-Goal-{self.goal_to_str(goal)}"
 
     def change_done_by_specific_desired(self, obs, desired, old_success_done):
         """
@@ -592,6 +675,12 @@ class PointMazeProperty(EnvProperty):
         assert isinstance(done, np.ndarray)
         return done[0]
 
+    def use_goal_directed_problem(self):
+        """
+        Check if the environment uses a goal-directed problem.
+        """
+        return True
+
     def is_success(self, info):
         """
         Check if the episode is successful.
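
Taken together, the environment changes give every `EnvProperty` subclass a shared `goal_to_str` helper that both `goal_to_problem_str` and `str_to_goal` are built around. A minimal standalone sketch of that round trip for the Minigrid naming convention (the free functions below are illustrative copies of the `MinigridProperty` methods above, not part of the package, and the environment name is only an example):

```python
# Standalone sketch of the Minigrid goal <-> problem-string round trip.

def goal_to_str(goal):
    return f"{goal[0]}x{goal[1]}"

def goal_to_problem_str(env_name, goal):
    return env_name + f"-DynamicGoal-{goal_to_str(goal)}-v0"

def str_to_goal(problem_name):
    parts = problem_name.split("-")
    goal_part = [part for part in parts if "x" in part]
    width, height = goal_part[0].split("x")
    return (int(width), int(height))

problem = goal_to_problem_str("MiniGrid-SimpleCrossingS13N4", (1, 11))
assert problem == "MiniGrid-SimpleCrossingS13N4-DynamicGoal-1x11-v0"
assert str_to_goal(problem) == (1, 11)
```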
gr_libs/ml/consts.py CHANGED
@@ -20,3 +20,4 @@ OPTIM_ALPHA = 0.99
 CLIP_EPS = 0.2
 RECURRENCE = 1
 TEXT = False
+FINETUNE_TIMESTEPS = 100000  # for GCAura fine-tuning
gr_libs/ml/neural/deep_rl_learner.py CHANGED
@@ -1,6 +1,7 @@
 import gc
 from collections import OrderedDict
 from types import MethodType
+from typing import Any
 
 import cv2
 import numpy as np
@@ -22,6 +23,10 @@ from stable_baselines3.common.base_class import BaseAlgorithm
 
 from gr_libs.ml.utils import device
 
+from gr_libs.ml.consts import (
+    FINETUNE_TIMESTEPS,
+)
+
 # TODO do we need this?
 NETWORK_SETUP = {
     SAC: OrderedDict(
@@ -236,27 +241,46 @@ class DeepRLAgent:
             self._model_file_path, env=self.env, device=device, **self.model_kwargs
         )
 
-    def learn(self):
+    def learn(self, goal=None, total_timesteps=None):
         """Train the agent."""
-        if os.path.exists(self._model_file_path):
-            print(f"Loading pre-existing model in {self._model_file_path}")
+        model_file_path = self._model_file_path
+        old_model_file_path = model_file_path
+        if goal is not None:
+            model_file_path = self._model_file_path.replace(
+                ".pth", f"_{goal}.pth"
+            ).replace(".zip", f"_{goal}.zip")
+        if total_timesteps is not None:
+            model_file_path = model_file_path.replace(
+                ".pth", f"_{total_timesteps}.pth"
+            ).replace(".zip", f"_{total_timesteps}.zip")
+
+        self._model_file_path = model_file_path
+
+        if os.path.exists(model_file_path):
+            print(f"Loading pre-existing model in {model_file_path}")
             self.load_model()
         else:
-            print(f"No existing model in {self._model_file_path}, starting learning")
-            if self.exploration_rate is not None:
-                self._model = self.algorithm(
-                    "MultiInputPolicy",
-                    self.env,
-                    ent_coef=self.exploration_rate,
-                    verbose=1,
-                )
-            else:
-                self._model = self.algorithm("MultiInputPolicy", self.env, verbose=1)
+            print(f"No existing model in {model_file_path}, starting learning")
+            if total_timesteps is None:
+                total_timesteps = self.num_timesteps
+            if self.exploration_rate is not None:
+                self._model = self.algorithm(
+                    "MultiInputPolicy",
+                    self.env,
+                    ent_coef=self.exploration_rate,
+                    verbose=1,
+                )
+            else:
+                self._model = self.algorithm(
+                    "MultiInputPolicy", self.env, verbose=1
+                )
             self._model.learn(
-                total_timesteps=self.num_timesteps, progress_bar=True
+                total_timesteps=total_timesteps, progress_bar=True
             )  # comment this in a normal env
             self.save_model()
 
+        self._model_file_path = old_model_file_path
+
     def safe_env_reset(self):
         """
         Reset the environment safely, suppressing output.
@@ -503,6 +527,69 @@ class DeepRLAgent:
         self.env.close()
         return observations
 
+    def fine_tune(
+        self,
+        goal: Any,
+        num_timesteps: int = FINETUNE_TIMESTEPS,
+    ) -> None:
+        """
+        Fine-tune this goal-conditioned agent on a single specified goal:
+        create an environment for that goal and continue training on it.
+
+        Args:
+            goal: The specific goal to fine-tune on. Type depends on the environment.
+            num_timesteps: Number of timesteps for fine-tuning. Defaults to FINETUNE_TIMESTEPS.
+        """
+        # Store original environment and problem
+        original_env = self.env
+        original_problem = self.problem_name
+        created_new_env = False
+
+        try:
+            # Try to create a goal-specific environment
+            if hasattr(self.env_prop, "goal_to_problem_str") and callable(
+                self.env_prop.goal_to_problem_str
+            ):
+                try:
+                    goal_problem = self.env_prop.goal_to_problem_str(goal)
+
+                    # Create the goal-specific environment
+                    env_kwargs = {"id": goal_problem, "render_mode": "rgb_array"}
+                    new_env = self.env_prop.create_vec_env(env_kwargs)
+
+                    # Update the model's environment
+                    self._model.set_env(new_env)
+                    self.env = new_env
+                    self.problem_name = goal_problem
+                    created_new_env = True
+                    print(f"Created a new environment for fine-tuning: {goal_problem}")
+                except Exception as e:
+                    print(f"Warning: Could not create goal-specific environment: {e}")
+
+            if not created_new_env:
+                print(
+                    "Fine-tuning requires a goal-specific environment. "
+                    "Please ensure that the environment with the specified goal exists."
+                )
+
+            print(f"Fine-tuning for {num_timesteps} timesteps...")
+            self.learn(
+                goal=self.env_prop.goal_to_str(goal), total_timesteps=num_timesteps
+            )
+            print("Fine-tuning complete. Model saved.")
+
+        finally:
+            # Restore original environment if needed
+            if created_new_env:
+                self.env.close()
+                self._model.set_env(original_env)
+                self.env = original_env
+                self.problem_name = original_problem
+                print("Restored original environment.")
+
 
 class GCDeepRLAgent(DeepRLAgent):
     """
gr_libs/odgr_executor.py CHANGED
@@ -15,7 +15,7 @@ from gr_libs.ml.utils.storage import (
 )
 from gr_libs.problems.consts import PROBLEMS
 from gr_libs.recognizer._utils import recognizer_str_to_obj
-from gr_libs.recognizer.gr_as_rl.gr_as_rl_recognizer import Draco, GCDraco
+from gr_libs.recognizer.gr_as_rl.gr_as_rl_recognizer import Draco, GCDraco, GCAura
 from gr_libs.recognizer.graml.graml_recognizer import Graml
 from gr_libs.recognizer.recognizer import GaAgentTrainerRecognizer, LearningRecognizer
 
@@ -102,7 +102,11 @@ def run_odgr_problem(args):
     }
 
     # need to dump the whole plan for draco because it needs it for inference phase for checking likelihood.
-    if (recognizer_type == Draco or recognizer_type == GCDraco) and issubclass(
+    if (
+        recognizer_type == Draco
+        or recognizer_type == GCDraco
+        or recognizer_type == GCAura
+    ) and issubclass(
         rl_agent_type, DeepRLAgent
     ):  # TODO remove this condition, remove the assumption.
         generate_obs_kwargs["with_dict"] = True
@@ -224,6 +228,7 @@ def parse_args():
             "Graql",
             "Draco",
             "GCDraco",
+            "GCAura",
         ],
         required=True,
         help="Recognizer type. Follow readme.md and recognizer folder for more information and rules.",
@@ -1,4 +1,9 @@
-from gr_libs.recognizer.gr_as_rl.gr_as_rl_recognizer import Draco, GCDraco, Graql
+from gr_libs.recognizer.gr_as_rl.gr_as_rl_recognizer import (
+    Draco,
+    GCDraco,
+    Graql,
+    GCAura,
+)
 from gr_libs.recognizer.graml.graml_recognizer import (
     ExpertBasedGraml,
     GCGraml,
@@ -14,5 +19,6 @@ def recognizer_str_to_obj(recognizer_str: str):
         "Graql": Graql,
         "Draco": Draco,
         "GCDraco": GCDraco,
+        "GCAura": GCAura,
     }
     return recognizer_map.get(recognizer_str)
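
With the new entry, resolving the `"GCAura"` string yields the class itself, and unknown names still fall through to `None` because the lookup uses `dict.get`. For example, assuming gr_libs 0.2.6 is installed:

```python
# Illustrative use of the mapping above.
from gr_libs.recognizer._utils import recognizer_str_to_obj
from gr_libs.recognizer.gr_as_rl.gr_as_rl_recognizer import GCAura

assert recognizer_str_to_obj("GCAura") is GCAura
assert recognizer_str_to_obj("NoSuchRecognizer") is None  # dict.get default
```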
gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py CHANGED
@@ -8,12 +8,14 @@ from gr_libs.ml.base import RLAgent
 from gr_libs.ml.neural.deep_rl_learner import DeepRLAgent, GCDeepRLAgent
 from gr_libs.ml.tabular.tabular_q_learner import TabularQLearner
 from gr_libs.ml.utils.storage import get_gr_as_rl_experiment_confidence_path
+from gymnasium.envs.registration import register, registry
 from gr_libs.recognizer.recognizer import (
     GaAdaptingRecognizer,
     GaAgentTrainerRecognizer,
     LearningRecognizer,
     Recognizer,
 )
+from gr_libs.ml.consts import FINETUNE_TIMESTEPS
 
 
 class GRAsRL(Recognizer):
@@ -234,7 +236,7 @@ class GCDraco(GRAsRL, LearningRecognizer, GaAdaptingRecognizer):
         base = problems["gc"]
         base_goals = base["goals"]
         train_configs = base["train_configs"]
-        super().domain_learning_phase(base_goals, train_configs)
+        super().domain_learning_phase(train_configs, base_goals)
         agent_kwargs = {
             "domain_name": self.env_prop.domain_name,
             "problem_name": self.env_prop.name,
@@ -256,3 +258,146 @@ class GCDraco(GRAsRL, LearningRecognizer, GaAdaptingRecognizer):
 
     def choose_agent(self, problem_name: str) -> RLAgent:
         return next(iter(self.agents.values()))
+
+
+class GCAura(GRAsRL, LearningRecognizer, GaAdaptingRecognizer):
+    """
+    GCAura uses goal-conditioned reinforcement learning with adaptive fine-tuning.
+
+    It trains a base goal-conditioned policy over a goal subspace in the domain learning phase.
+    During the goal adaptation phase, it checks if new goals are within the original goal subspace:
+    - If a goal is within the subspace, it uses the original trained model
+    - If a goal is outside the subspace, it fine-tunes the model for that specific goal
+
+    This approach combines the efficiency of goal-conditioned RL with the precision of
+    goal-specific fine-tuning when needed.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert (
+            self.env_prop.gc_adaptable()
+            and not self.env_prop.is_state_discrete()
+            and not self.env_prop.is_action_discrete()
+        )
+        if self.rl_agent_type is None:
+            self.rl_agent_type = GCDeepRLAgent
+        self.evaluation_function = kwargs.get("evaluation_function")
+        if self.evaluation_function is None:
+            from gr_libs.metrics.metrics import mean_wasserstein_distance
+
+            self.evaluation_function = mean_wasserstein_distance
+        assert callable(
+            self.evaluation_function
+        ), "Evaluation function must be a callable function."
+
+        # Store fine-tuning parameters
+        self.finetune_timesteps = kwargs.get("finetune_timesteps", FINETUNE_TIMESTEPS)
+
+        # Dictionary to store fine-tuned agents for specific goals
+        self.fine_tuned_agents = {}
+
+    def domain_learning_phase(self, problems):
+        base = problems["gc"]
+        train_configs = base["train_configs"]
+
+        # Store the training configs for the later fine-tuning runs
+        self.original_train_configs = train_configs
+
+        super().domain_learning_phase(train_configs)
+
+        agent_kwargs = {
+            "domain_name": self.env_prop.domain_name,
+            "problem_name": self.env_prop.name,
+            "algorithm": train_configs[0][0],
+            "num_timesteps": train_configs[0][1],
+            "env_prop": self.env_prop,
+        }
+
+        agent = self.rl_agent_type(**agent_kwargs)
+        agent.learn()
+        self.agents[self.env_prop.name] = agent
+        self.action_space = agent.env.action_space
+
+    def _is_goal_in_subspace(self, goal):
+        """
+        Check if a goal is within the original training subspace.
+
+        Delegates to the environment property's implementation.
+
+        Args:
+            goal: The goal to check
+
+        Returns:
+            bool: True if the goal is within the training subspace
+        """
+        # Use the environment property's implementation
+        return self.env_prop.is_goal_in_subspace(goal)
+
+    def goals_adaptation_phase(self, dynamic_goals):
+        """
+        Adapt to new goals, fine-tuning if necessary.
+
+        For goals outside the original training subspace, fine-tune the model.
+
+        Args:
+            dynamic_goals: List of goals to adapt to
+        """
+        self.active_goals = dynamic_goals
+        self.active_problems = [
+            self.env_prop.goal_to_problem_str(goal) for goal in dynamic_goals
+        ]
+
+        # Check each goal and fine-tune if needed
+        for goal in dynamic_goals:
+            if not self._is_goal_in_subspace(goal):
+                print(f"Goal {goal} is outside the training subspace. Fine-tuning...")
+
+                # Create a new agent for this goal
+                agent_kwargs = {
+                    "domain_name": self.env_prop.domain_name,
+                    "problem_name": self.env_prop.name,
+                    "algorithm": self.original_train_configs[0][0],
+                    "num_timesteps": self.original_train_configs[0][1],
+                    "env_prop": self.env_prop,
+                }
+
+                # Create new agent with base model
+                fine_tuned_agent = self.rl_agent_type(**agent_kwargs)
+                fine_tuned_agent.learn()  # This loads the existing model
+
+                # Fine-tune for this specific goal
+                fine_tuned_agent.fine_tune(
+                    goal=goal,
+                    num_timesteps=self.finetune_timesteps,
+                )
+
+                # Store the fine-tuned agent
+                self.fine_tuned_agents[
+                    f"{self.env_prop.goal_to_str(goal)}_{self.finetune_timesteps}"
+                ] = fine_tuned_agent
+            else:
+                print(f"Goal {goal} is within the training subspace. Using base agent.")
+
+    def choose_agent(self, problem_name: str) -> RLAgent:
+        """
+        Return the appropriate agent for the given problem.
+
+        If the goal has a fine-tuned agent, return that; otherwise return the base agent.
+
+        Args:
+            problem_name: The problem name to get agent for
+
+        Returns:
+            The appropriate agent (base or fine-tuned)
+        """
+        # Extract the goal from the problem name
+        goal = self.env_prop.str_to_goal(problem_name)
+        agent_name = f"{self.env_prop.goal_to_str(goal)}_{self.finetune_timesteps}"
+
+        # Check if we have a fine-tuned agent for this goal
+        if agent_name in self.fine_tuned_agents:
+            return self.fine_tuned_agents[agent_name]
+
+        # Otherwise return the base agent
+        return self.agents[self.env_prop.name]
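
For orientation, here is a hypothetical end-to-end use of `GCAura` following the three methods above. The constructor keywords, the environment name, and the `problems` dict shape are assumptions for illustration; this diff shows neither `Recognizer.__init__` nor the full `PROBLEMS` schema:

```python
from stable_baselines3 import SAC

from gr_libs.recognizer.gr_as_rl.gr_as_rl_recognizer import GCAura

# Constructor kwargs are assumed for illustration; see Recognizer.__init__.
recognizer = GCAura(domain_name="point_maze", env_name="PointMaze-FourRoomsEnvDense-11x11")

# 1. Train one goal-conditioned base agent over the "gc" goal subspace.
recognizer.domain_learning_phase({"gc": {"train_configs": [(SAC, 200000)]}})

# 2. Adapt: goals inside the subspace reuse the base agent; goals outside it
#    trigger fine_tune() and are cached in recognizer.fine_tuned_agents under
#    the key f"{goal_to_str(goal)}_{finetune_timesteps}".
recognizer.goals_adaptation_phase(dynamic_goals=[(4, 4), (9, 1)])

# 3. Inference: choose_agent maps the problem string back to a goal and
#    returns the fine-tuned agent if one exists, else the base agent.
agent = recognizer.choose_agent("PointMaze-FourRoomsEnvDense-11x11-Goal-9x1")
```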
gr_libs/recognizer/graml/graml_recognizer.py CHANGED
@@ -1,4 +1,4 @@
-""" Collection of recognizers that use GRAML methods: metric learning for ODGR. """
+"""Collection of recognizers that use GRAML methods: metric learning for ODGR."""
 
 import os
 from abc import abstractmethod
@@ -124,7 +124,7 @@ class Graml(LearningRecognizer):
         pass
 
     def domain_learning_phase(self, base_goals: list[str], train_configs: list):
-        super().domain_learning_phase(base_goals, train_configs)
+        super().domain_learning_phase(train_configs, base_goals)
         self.train_agents_on_base_goals(base_goals, train_configs)
         # train the network so it will find a metric for the observations of the base agents such that traces of agents to different goals are far from one another
         self.model_directory = get_lstm_model_dir(
@@ -343,7 +343,7 @@ class BGGraml(Graml):
         assert len(base_goals) == len(
             train_configs
         ), "base_goals and train_configs should have the same length"
-        super().domain_learning_phase(base_goals, train_configs)
+        super().domain_learning_phase(train_configs=train_configs, base_goals=base_goals)
 
     # In case we need goal-directed agent for every goal
     def train_agents_on_base_goals(self, base_goals: list[str], train_configs: list):
@@ -556,7 +556,7 @@ class GCGraml(Graml, GaAdaptingRecognizer):
         assert (
             len(train_configs) == 1
         ), "GCGraml should only have one train config for the base goals, it uses a single agent"
-        super().domain_learning_phase(base_goals, train_configs)
+        super().domain_learning_phase(train_configs=train_configs, base_goals=base_goals)
 
     # In case we need goal-directed agent for every goal
     def train_agents_on_base_goals(self, base_goals: list[str], train_configs: list):
gr_libs/recognizer/recognizer.py CHANGED
@@ -36,7 +36,7 @@ class LearningRecognizer(Recognizer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-    def domain_learning_phase(self, base_goals: list[str], train_configs: list):
+    def domain_learning_phase(self, train_configs: list, base_goals: list[str] = None):
         """
         Perform the domain learning phase.
 
@@ -70,18 +70,18 @@ class GaAgentTrainerRecognizer(Recognizer):
         None
         """
 
-    def domain_learning_phase(self, base_goals: list[str], train_configs: list):
+    def domain_learning_phase(self, train_configs: list, base_goals: list[str] = None):
         """
         Perform the domain learning phase.
 
         Args:
-            base_goals (List[str]): List of base goals.
             train_configs (List): List of training configurations.
+            base_goals (List[str]): List of base goals for the learning phase.
 
         Returns:
             None
         """
-        super().domain_learning_phase(base_goals, train_configs)
+        super().domain_learning_phase(train_configs, base_goals)
 
 
 class GaAdaptingRecognizer(Recognizer):
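
Note the breaking detail behind all the call-site edits above: `domain_learning_phase` swapped its parameters from `(base_goals, train_configs)` to `(train_configs, base_goals=None)`. Passing by keyword, as the Graml subclasses now do, makes overriding subclasses immune to the new order. A minimal self-contained sketch of the pattern (class bodies simplified for illustration):

```python
# Minimal sketch of the re-ordered hook; mirrors LearningRecognizer above.

class LearningRecognizer:
    def domain_learning_phase(self, train_configs: list, base_goals: list[str] = None):
        print(f"configs={train_configs}, goals={base_goals}")


class MyRecognizer(LearningRecognizer):
    def domain_learning_phase(self, base_goals: list[str], train_configs: list):
        # Keep the old argument order locally, but call the re-ordered base
        # hook by keyword so the swap cannot silently misbind positionally.
        super().domain_learning_phase(
            train_configs=train_configs, base_goals=base_goals
        )


MyRecognizer().domain_learning_phase(["goal_a"], [("SAC", 100000)])
```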