cogames 0.3.65__py3-none-any.whl → 0.3.68__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134)
  1. cogames/cli/client.py +0 -3
  2. cogames/cli/docsync/docsync.py +7 -1
  3. cogames/cli/mission.py +44 -19
  4. cogames/cli/policy.py +26 -10
  5. cogames/cli/submit.py +127 -141
  6. cogames/cli/utils.py +5 -0
  7. cogames/cogs_vs_clips/clip_difficulty.py +57 -0
  8. cogames/cogs_vs_clips/clips.py +23 -6
  9. cogames/cogs_vs_clips/cog.py +16 -5
  10. cogames/cogs_vs_clips/cogsguard_curriculum.py +122 -0
  11. cogames/cogs_vs_clips/cogsguard_tutorial.py +5 -5
  12. cogames/cogs_vs_clips/config.py +1 -1
  13. cogames/cogs_vs_clips/docs/cogs_vs_clips_mapgen.md +2 -3
  14. cogames/cogs_vs_clips/evals/README.md +8 -32
  15. cogames/cogs_vs_clips/evals/diagnostic_evals.py +0 -1
  16. cogames/cogs_vs_clips/evals/difficulty_variants.py +7 -10
  17. cogames/cogs_vs_clips/mission.py +38 -10
  18. cogames/cogs_vs_clips/missions.py +1 -1
  19. cogames/cogs_vs_clips/reward_variants.py +173 -0
  20. cogames/cogs_vs_clips/sites.py +6 -5
  21. cogames/cogs_vs_clips/stations.py +13 -9
  22. cogames/cogs_vs_clips/team.py +3 -1
  23. cogames/cogs_vs_clips/terrain.py +2 -2
  24. cogames/cogs_vs_clips/variants.py +175 -4
  25. cogames/cogs_vs_clips/weather.py +52 -0
  26. cogames/docs/SCRIPTED_AGENT.md +3 -3
  27. cogames/evaluate.py +4 -2
  28. cogames/main.py +357 -51
  29. cogames/maps/canidate1_1000.map +1 -1
  30. cogames/maps/canidate1_1000_stations.map +2 -2
  31. cogames/maps/canidate1_500.map +1 -1
  32. cogames/maps/canidate1_500_stations.map +2 -2
  33. cogames/maps/canidate2_1000.map +1 -1
  34. cogames/maps/canidate2_1000_stations.map +2 -2
  35. cogames/maps/canidate2_500.map +1 -1
  36. cogames/maps/canidate2_500_stations.map +1 -1
  37. cogames/maps/canidate3_1000.map +1 -1
  38. cogames/maps/canidate3_1000_stations.map +2 -2
  39. cogames/maps/canidate3_500.map +1 -1
  40. cogames/maps/canidate3_500_stations.map +2 -2
  41. cogames/maps/canidate4_500.map +1 -1
  42. cogames/maps/canidate4_500_stations.map +2 -2
  43. cogames/maps/cave_base_50.map +2 -2
  44. cogames/maps/diagnostic_evals/diagnostic_agile.map +2 -2
  45. cogames/maps/diagnostic_evals/diagnostic_agile_hard.map +2 -2
  46. cogames/maps/diagnostic_evals/diagnostic_charge_up.map +6 -6
  47. cogames/maps/diagnostic_evals/diagnostic_charge_up_hard.map +6 -6
  48. cogames/maps/diagnostic_evals/diagnostic_chest_navigation1.map +6 -6
  49. cogames/maps/diagnostic_evals/diagnostic_chest_navigation1_hard.map +6 -6
  50. cogames/maps/diagnostic_evals/diagnostic_chest_navigation2.map +6 -6
  51. cogames/maps/diagnostic_evals/diagnostic_chest_navigation2_hard.map +6 -6
  52. cogames/maps/diagnostic_evals/diagnostic_chest_navigation3.map +6 -6
  53. cogames/maps/diagnostic_evals/diagnostic_chest_navigation3_hard.map +6 -6
  54. cogames/maps/diagnostic_evals/diagnostic_chest_near.map +6 -6
  55. cogames/maps/diagnostic_evals/diagnostic_chest_search.map +6 -6
  56. cogames/maps/diagnostic_evals/diagnostic_chest_search_hard.map +6 -6
  57. cogames/maps/diagnostic_evals/diagnostic_extract_lab.map +6 -6
  58. cogames/maps/diagnostic_evals/diagnostic_extract_lab_hard.map +6 -6
  59. cogames/maps/diagnostic_evals/diagnostic_memory.map +6 -6
  60. cogames/maps/diagnostic_evals/diagnostic_memory_hard.map +6 -6
  61. cogames/maps/diagnostic_evals/diagnostic_radial.map +2 -2
  62. cogames/maps/diagnostic_evals/diagnostic_radial_hard.map +2 -2
  63. cogames/maps/diagnostic_evals/diagnostic_resource_lab.map +6 -6
  64. cogames/maps/diagnostic_evals/diagnostic_unclip.map +6 -6
  65. cogames/maps/evals/eval_balanced_spread.map +6 -6
  66. cogames/maps/evals/eval_clip_oxygen.map +6 -6
  67. cogames/maps/evals/eval_collect_resources.map +6 -6
  68. cogames/maps/evals/eval_collect_resources_hard.map +6 -6
  69. cogames/maps/evals/eval_collect_resources_medium.map +6 -6
  70. cogames/maps/evals/eval_divide_and_conquer.map +6 -6
  71. cogames/maps/evals/eval_energy_starved.map +6 -6
  72. cogames/maps/evals/eval_multi_coordinated_collect_hard.map +6 -6
  73. cogames/maps/evals/eval_oxygen_bottleneck.map +6 -6
  74. cogames/maps/evals/eval_single_use_world.map +6 -6
  75. cogames/maps/evals/extractor_hub_100x100.map +6 -6
  76. cogames/maps/evals/extractor_hub_30x30.map +6 -6
  77. cogames/maps/evals/extractor_hub_50x50.map +6 -6
  78. cogames/maps/evals/extractor_hub_70x70.map +6 -6
  79. cogames/maps/evals/extractor_hub_80x80.map +6 -6
  80. cogames/maps/machina_100_stations.map +2 -2
  81. cogames/maps/machina_200_stations.map +2 -2
  82. cogames/maps/machina_200_stations_small.map +2 -2
  83. cogames/maps/machina_eval_exp01.map +2 -2
  84. cogames/maps/machina_eval_template_large.map +2 -2
  85. cogames/maps/machinatrainer4agents.map +2 -2
  86. cogames/maps/machinatrainer4agentsbase.map +2 -2
  87. cogames/maps/machinatrainerbig.map +2 -2
  88. cogames/maps/machinatrainersmall.map +2 -2
  89. cogames/maps/planky_evals/aligner_avoid_aoe.map +6 -6
  90. cogames/maps/planky_evals/aligner_full_cycle.map +6 -6
  91. cogames/maps/planky_evals/aligner_gear.map +6 -6
  92. cogames/maps/planky_evals/aligner_hearts.map +6 -6
  93. cogames/maps/planky_evals/aligner_junction.map +6 -6
  94. cogames/maps/planky_evals/exploration_distant.map +6 -6
  95. cogames/maps/planky_evals/maze.map +6 -6
  96. cogames/maps/planky_evals/miner_best_resource.map +6 -6
  97. cogames/maps/planky_evals/miner_deposit.map +6 -6
  98. cogames/maps/planky_evals/miner_extract.map +6 -6
  99. cogames/maps/planky_evals/miner_full_cycle.map +6 -6
  100. cogames/maps/planky_evals/miner_gear.map +6 -6
  101. cogames/maps/planky_evals/multi_role.map +6 -6
  102. cogames/maps/planky_evals/resource_chain.map +6 -6
  103. cogames/maps/planky_evals/scout_explore.map +6 -6
  104. cogames/maps/planky_evals/scout_gear.map +6 -6
  105. cogames/maps/planky_evals/scrambler_full_cycle.map +6 -6
  106. cogames/maps/planky_evals/scrambler_gear.map +6 -6
  107. cogames/maps/planky_evals/scrambler_target.map +6 -6
  108. cogames/maps/planky_evals/stuck_corridor.map +6 -6
  109. cogames/maps/planky_evals/survive_retreat.map +6 -6
  110. cogames/maps/training_facility_clipped.map +2 -2
  111. cogames/maps/training_facility_open_1.map +2 -2
  112. cogames/maps/training_facility_open_2.map +2 -2
  113. cogames/maps/training_facility_open_3.map +2 -2
  114. cogames/maps/training_facility_tight_4.map +2 -2
  115. cogames/maps/training_facility_tight_5.map +2 -2
  116. cogames/maps/vanilla_large.map +2 -2
  117. cogames/maps/vanilla_small.map +2 -2
  118. cogames/pickup.py +6 -5
  119. cogames/play.py +14 -16
  120. cogames/policy/nim_agents/__init__.py +0 -2
  121. cogames/policy/nim_agents/agents.py +0 -11
  122. cogames/policy/starter_agent.py +4 -1
  123. {cogames-0.3.65.dist-info → cogames-0.3.68.dist-info}/METADATA +45 -29
  124. cogames-0.3.68.dist-info/RECORD +160 -0
  125. metta_alo/scoring.py +7 -7
  126. cogames-0.3.65.dist-info/RECORD +0 -160
  127. metta_alo/job_specs.py +0 -17
  128. metta_alo/policy.py +0 -16
  129. metta_alo/pure_single_episode_runner.py +0 -75
  130. metta_alo/rollout.py +0 -322
  131. {cogames-0.3.65.dist-info → cogames-0.3.68.dist-info}/WHEEL +0 -0
  132. {cogames-0.3.65.dist-info → cogames-0.3.68.dist-info}/entry_points.txt +0 -0
  133. {cogames-0.3.65.dist-info → cogames-0.3.68.dist-info}/licenses/LICENSE +0 -0
  134. {cogames-0.3.65.dist-info → cogames-0.3.68.dist-info}/top_level.txt +0 -0
cogames/cli/client.py CHANGED
@@ -253,9 +253,6 @@ class TournamentServerClient:
253
253
  json=payload,
254
254
  )
255
255
 
256
- def update_policy_version_tags(self, policy_version_id: uuid.UUID, tags: dict[str, str]) -> dict[str, Any]:
257
- return self._put(f"/stats/policies/versions/{policy_version_id}/tags", json=tags)
258
-
259
256
  def get_policy_memberships(self, policy_version_id: uuid.UUID) -> list[dict[str, Any]]:
260
257
  return self._get(f"/tournament/policies/{policy_version_id}/memberships", list[dict[str, Any]])
261
258
 
cogames/cli/docsync/docsync.py CHANGED
@@ -120,7 +120,13 @@ def check(
120
120
  typer.echo("\nNotebook documentation is out of sync!", err=True)
121
121
  for error in errors:
122
122
  typer.echo(f" {error}", err=True)
123
- typer.echo("\nTo fix: run 'cogames docsync all'", err=True)
123
+ typer.echo(
124
+ "\nThis can happen if you modified cogames CLI flags (the README includes a command reference)"
125
+ "\nor changed code that affects notebook outputs."
126
+ "\n"
127
+ "\nTo fix: run 'cogames docsync all'",
128
+ err=True,
129
+ )
124
130
  raise typer.Exit(1)
125
131
 
126
132
  typer.echo("All notebooks are in sync!")
cogames/cli/mission.py CHANGED
@@ -8,10 +8,8 @@ from rich import box
8
8
  from rich.table import Table
9
9
 
10
10
  from cogames.cli.base import console
11
- from cogames.cogs_vs_clips.mission import (
12
- CvCMission,
13
- NumCogsVariant,
14
- )
11
+ from cogames.cogs_vs_clips.clip_difficulty import get_cogsguard_difficulty
12
+ from cogames.cogs_vs_clips.mission import CvCMission, NumCogsVariant
15
13
  from cogames.cogs_vs_clips.sites import SITES
16
14
  from cogames.cogs_vs_clips.terrain import MachinaArena
17
15
  from cogames.cogs_vs_clips.variants import HIDDEN_VARIANTS, VARIANTS
@@ -27,16 +25,16 @@ from mettagrid.mapgen.mapgen import MapGen
27
25
 
28
26
  @lru_cache(maxsize=1)
29
27
  def _get_core_missions() -> list[CvCMission]:
30
- from cogames.cogs_vs_clips.missions import get_core_missions
28
+ from cogames.cogs_vs_clips.missions import get_core_missions # noqa: PLC0415
31
29
 
32
30
  return get_core_missions()
33
31
 
34
32
 
35
33
  @lru_cache(maxsize=1)
36
34
  def _get_eval_missions_all() -> list[CvCMission]:
37
- from cogames.cogs_vs_clips.evals.diagnostic_evals import DIAGNOSTIC_EVALS
38
- from cogames.cogs_vs_clips.evals.integrated_evals import EVAL_MISSIONS as INTEGRATED_EVAL_MISSIONS
39
- from cogames.cogs_vs_clips.evals.spanning_evals import EVAL_MISSIONS as SPANNING_EVAL_MISSIONS
35
+ from cogames.cogs_vs_clips.evals.diagnostic_evals import DIAGNOSTIC_EVALS # noqa: PLC0415
36
+ from cogames.cogs_vs_clips.evals.integrated_evals import EVAL_MISSIONS as INTEGRATED_EVAL_MISSIONS # noqa: PLC0415
37
+ from cogames.cogs_vs_clips.evals.spanning_evals import EVAL_MISSIONS as SPANNING_EVAL_MISSIONS # noqa: PLC0415
40
38
 
41
39
  missions: list[CvCMission] = []
42
40
  missions.extend(INTEGRATED_EVAL_MISSIONS)
@@ -74,19 +72,21 @@ def load_mission_set(mission_set: str) -> list[CvCMission]:
74
72
  missions_list.append(mission)
75
73
 
76
74
  elif mission_set == "diagnostic_evals":
77
- from cogames.cogs_vs_clips.evals.diagnostic_evals import DIAGNOSTIC_EVALS
75
+ from cogames.cogs_vs_clips.evals.diagnostic_evals import DIAGNOSTIC_EVALS # noqa: PLC0415
78
76
 
79
77
  missions_list = [mission_cls() for mission_cls in DIAGNOSTIC_EVALS] # type: ignore[call-arg]
80
78
  elif mission_set == "cogsguard_evals":
81
- from cogames.cogs_vs_clips.evals.cogsguard_evals import COGSGUARD_EVAL_MISSIONS
79
+ from cogames.cogs_vs_clips.evals.cogsguard_evals import COGSGUARD_EVAL_MISSIONS # noqa: PLC0415
82
80
 
83
81
  missions_list = list(COGSGUARD_EVAL_MISSIONS)
84
82
  elif mission_set == "integrated_evals":
85
- from cogames.cogs_vs_clips.evals.integrated_evals import EVAL_MISSIONS as INTEGRATED_EVAL_MISSIONS
83
+ from cogames.cogs_vs_clips.evals.integrated_evals import ( # noqa: PLC0415
84
+ EVAL_MISSIONS as INTEGRATED_EVAL_MISSIONS,
85
+ )
86
86
 
87
87
  missions_list = list(INTEGRATED_EVAL_MISSIONS)
88
88
  elif mission_set == "spanning_evals":
89
- from cogames.cogs_vs_clips.evals.spanning_evals import EVAL_MISSIONS as SPANNING_EVAL_MISSIONS
89
+ from cogames.cogs_vs_clips.evals.spanning_evals import EVAL_MISSIONS as SPANNING_EVAL_MISSIONS # noqa: PLC0415
90
90
 
91
91
  missions_list = list(SPANNING_EVAL_MISSIONS)
92
92
  else:
@@ -132,6 +132,12 @@ def parse_variants(variants_arg: Optional[list[str]]) -> list[CoGameMissionVaria
132
132
  return variants
133
133
 
134
134
 
135
+ def parse_difficulty(difficulty_arg: Optional[str]) -> Optional[CoGameMissionVariant]:
136
+ if difficulty_arg is None:
137
+ return None
138
+ return get_cogsguard_difficulty(difficulty_arg)
139
+
140
+
135
141
  def get_all_missions() -> list[str]:
136
142
  """Get all core mission names in the format site.mission (excludes evals)."""
137
143
  return [mission.full_name() for mission in _get_core_missions()]
@@ -157,17 +163,25 @@ def get_site_by_name(site_name: str) -> CoGameSite:
157
163
 
158
164
 
159
165
  def get_mission_name_and_config(
160
- ctx: typer.Context, mission_arg: Optional[str], variants_arg: Optional[list[str]] = None, cogs: Optional[int] = None
166
+ ctx: typer.Context,
167
+ mission_arg: Optional[str],
168
+ variants_arg: Optional[list[str]] = None,
169
+ cogs: Optional[int] = None,
170
+ difficulty: Optional[str] = None,
161
171
  ) -> tuple[str, MettaGridConfig, Optional[CvCMission]]:
162
172
  if not mission_arg:
163
173
  console.print(ctx.get_help())
164
174
  console.print("[yellow]Missing: --mission / -m[/yellow]\n")
165
175
  else:
166
176
  try:
167
- return get_mission(mission_arg, variants_arg, cogs)
177
+ return get_mission(mission_arg, variants_arg=variants_arg, cogs=cogs, difficulty=difficulty)
168
178
  except ValueError as e:
169
- console.print(f"[red]Mission '{mission_arg}' not found.[/red]")
170
- console.print("[dim]Run: cogames missions (or cogames missions <site>) to list options.[/dim]\n")
179
+ error_msg = str(e)
180
+ if "variant" in error_msg.lower() or "difficulty" in error_msg.lower():
181
+ console.print(f"[red]{error_msg}[/red]")
182
+ else:
183
+ console.print(f"[red]Mission '{mission_arg}' not found.[/red]")
184
+ console.print("[dim]Run: cogames missions (or cogames missions <site>) to list options.[/dim]\n")
171
185
  raise typer.Exit(1) from e
172
186
  list_missions()
173
187
 
@@ -183,6 +197,7 @@ def get_mission_names_and_configs(
183
197
  variants_arg: Optional[list[str]] = None,
184
198
  cogs: Optional[int] = None,
185
199
  steps: Optional[int] = None,
200
+ difficulty: Optional[str] = None,
186
201
  ) -> list[tuple[str, MettaGridConfig]]:
187
202
  if not missions_arg:
188
203
  console.print(ctx.get_help())
@@ -192,7 +207,7 @@ def get_mission_names_and_configs(
192
207
  not_deduped = [
193
208
  mission
194
209
  for missions in missions_arg
195
- for mission in _get_missions_by_possible_wildcard(missions, variants_arg, cogs)
210
+ for mission in _get_missions_by_possible_wildcard(missions, variants_arg, cogs, difficulty)
196
211
  ]
197
212
  name_set: set[str] = set()
198
213
  deduped = []
@@ -224,6 +239,7 @@ def _get_missions_by_possible_wildcard(
224
239
  mission_arg: str,
225
240
  variants_arg: Optional[list[str]],
226
241
  cogs: Optional[int],
242
+ difficulty: Optional[str],
227
243
  ) -> list[tuple[str, MettaGridConfig]]:
228
244
  if "*" in mission_arg:
229
245
  # Convert shell-style wildcard to regex pattern
@@ -232,10 +248,12 @@ def _get_missions_by_possible_wildcard(
232
248
  # Drop the Mission (3rd element) for wildcard results
233
249
  return [
234
250
  (name, env_cfg)
235
- for name, env_cfg, _ in (get_mission(m, variants_arg=variants_arg, cogs=cogs) for m in missions)
251
+ for name, env_cfg, _ in (
252
+ get_mission(m, variants_arg=variants_arg, cogs=cogs, difficulty=difficulty) for m in missions
253
+ )
236
254
  ]
237
255
  # Drop the Mission for single mission
238
- name, env_cfg, _ = get_mission(mission_arg, variants_arg=variants_arg, cogs=cogs)
256
+ name, env_cfg, _ = get_mission(mission_arg, variants_arg=variants_arg, cogs=cogs, difficulty=difficulty)
239
257
  return [(name, env_cfg)]
240
258
 
241
259
 
@@ -278,6 +296,7 @@ def get_mission(
278
296
  variants_arg: Optional[list[str]] = None,
279
297
  cogs: Optional[int] = None,
280
298
  include_legacy: bool = False,
299
+ difficulty: Optional[str] = None,
281
300
  ) -> tuple[str, MettaGridConfig, Optional[CvCMission]]:
282
301
  """Get a specific mission configuration by name or file path.
283
302
 
@@ -286,6 +305,7 @@ def get_mission(
286
305
  variants_arg: List of variant names like ["solar_flare", "dark_side"]
287
306
  cogs: Number of cogs (agents) to use, overrides the default from the mission
288
307
  include_legacy: Whether to include legacy (pre-CogsGuard) missions
308
+ difficulty: Difficulty name (easy, medium, hard) controlling clips events
289
309
 
290
310
  Returns:
291
311
  Tuple of (mission name, MettaGridConfig, CvCMission or None)
@@ -295,6 +315,8 @@ def get_mission(
295
315
  """
296
316
  # Check if it's a file path
297
317
  if any(mission_arg.endswith(ext) for ext in [".yaml", ".yml", ".json", ".py"]):
318
+ if difficulty is not None:
319
+ raise ValueError("Difficulty can only be used with mission names, not config files.")
298
320
  path = Path(mission_arg)
299
321
  if not path.exists():
300
322
  raise ValueError(f"File not found: {mission_arg}")
@@ -311,6 +333,7 @@ def get_mission(
311
333
 
312
334
  # Parse variants if provided
313
335
  variants = parse_variants(variants_arg)
336
+ difficulty_variant = parse_difficulty(difficulty)
314
337
 
315
338
  # Otherwise, treat it as a fully qualified mission name, or as a site name
316
339
  if (delim_count := mission_arg.count(MAP_MISSION_DELIMITER)) == 0:
@@ -322,6 +345,8 @@ def get_mission(
322
345
 
323
346
  mission: CvCMission = find_mission(site_name, mission_name, include_evals=True, include_legacy=include_legacy)
324
347
 
348
+ if difficulty_variant is not None:
349
+ mission = mission.with_variants([difficulty_variant])
325
350
  if variants:
326
351
  mission = mission.with_variants(variants)
327
352
  if cogs is not None:
cogames/cli/policy.py CHANGED
@@ -68,13 +68,13 @@ def _translate_error(e: Exception) -> str:
68
68
  return translated
69
69
 
70
70
 
71
- def get_policy_spec(ctx: typer.Context, policy_arg: Optional[str]) -> PolicySpec:
71
+ def get_policy_spec(ctx: typer.Context, policy_arg: Optional[str], device: str | None = None) -> PolicySpec:
72
72
  if policy_arg is None:
73
73
  console.print(ctx.get_help())
74
74
  console.print("[yellow]Missing: --policy / -p[/yellow]\n")
75
75
  else:
76
76
  try:
77
- return parse_policy_spec(spec=policy_arg).to_policy_spec()
77
+ return parse_policy_spec(spec=policy_arg, device=device).to_policy_spec()
78
78
  except (ValueError, ModuleNotFoundError) as e:
79
79
  translated = _translate_error(e)
80
80
  console.print(f"[yellow]Error parsing policy argument: {translated}[/yellow]\n")
@@ -90,14 +90,14 @@ def get_policy_spec(ctx: typer.Context, policy_arg: Optional[str]) -> PolicySpec
90
90
 
91
91
 
92
92
  def get_policy_specs_with_proportions(
93
- ctx: typer.Context, policy_args: Optional[list[str]]
93
+ ctx: typer.Context, policy_args: Optional[list[str]], device: str | None = None
94
94
  ) -> list[PolicySpecWithProportion]:
95
95
  if not policy_args:
96
96
  console.print(ctx.get_help())
97
97
  console.print("[yellow]Supply at least one: --policy / -p[/yellow]\n")
98
98
  else:
99
99
  try:
100
- return [parse_policy_spec(spec=policy_arg) for policy_arg in policy_args]
100
+ return [parse_policy_spec(spec=policy_arg, device=device) for policy_arg in policy_args]
101
101
  except (ValueError, ModuleNotFoundError) as e:
102
102
  translated = _translate_error(e)
103
103
  console.print(f"[yellow]Error parsing policy argument: {translated}[/yellow]")
@@ -112,7 +112,20 @@ def get_policy_specs_with_proportions(
112
112
  raise typer.Exit(1)
113
113
 
114
114
 
115
- def parse_policy_spec(spec: str) -> PolicySpecWithProportion:
115
+ def _apply_device_override(spec: PolicySpecWithProportion, device: str | None) -> PolicySpecWithProportion:
116
+ if device is None:
117
+ return spec
118
+ init_kwargs = dict(spec.init_kwargs or {})
119
+ init_kwargs["device"] = device
120
+ return PolicySpecWithProportion(
121
+ class_path=spec.class_path,
122
+ data_path=spec.data_path,
123
+ proportion=spec.proportion,
124
+ init_kwargs=init_kwargs,
125
+ )
126
+
127
+
128
+ def parse_policy_spec(spec: str, device: str | None = None) -> PolicySpecWithProportion:
116
129
  """Parse a policy CLI option into its components.
117
130
 
118
131
  Supports two formats:
@@ -161,13 +174,14 @@ def parse_policy_spec(spec: str) -> PolicySpecWithProportion:
161
174
  uri_part = spec
162
175
  fraction = 1.0
163
176
 
164
- policy = policy_spec_from_uri(uri_part.strip())
165
- return PolicySpecWithProportion(
177
+ policy = policy_spec_from_uri(uri_part.strip(), device=device or "cpu")
178
+ policy_spec = PolicySpecWithProportion(
166
179
  class_path=policy.class_path,
167
180
  data_path=policy.data_path,
168
181
  proportion=fraction,
169
182
  init_kwargs=policy.init_kwargs,
170
183
  )
184
+ return _apply_device_override(policy_spec, device)
171
185
 
172
186
  entries = [part.strip() for part in spec.split(",") if part.strip()]
173
187
  if not entries:
@@ -176,19 +190,20 @@ def parse_policy_spec(spec: str) -> PolicySpecWithProportion:
176
190
  fraction = 1.0
177
191
  first = entries[0]
178
192
  if is_uri(first) or ("=" not in first and is_path_like(first)):
179
- policy = policy_spec_from_uri(first)
193
+ policy = policy_spec_from_uri(first, device=device or "cpu")
180
194
  for entry in entries[1:]:
181
195
  key, value = parse_key_value(entry)
182
196
  if key != "proportion":
183
197
  raise ValueError("Only proportion is supported after a checkpoint URI.")
184
198
  fraction = parse_proportion(value)
185
199
 
186
- return PolicySpecWithProportion(
200
+ policy_spec = PolicySpecWithProportion(
187
201
  class_path=policy.class_path,
188
202
  data_path=policy.data_path,
189
203
  proportion=fraction,
190
204
  init_kwargs=policy.init_kwargs,
191
205
  )
206
+ return _apply_device_override(policy_spec, device)
192
207
 
193
208
  if "=" not in first:
194
209
  if ":" in first:
@@ -230,9 +245,10 @@ def parse_policy_spec(spec: str) -> PolicySpecWithProportion:
230
245
  if class_path is None:
231
246
  raise ValueError("Policy specification must include class= for key=value format.")
232
247
 
233
- return PolicySpecWithProportion(
248
+ policy_spec = PolicySpecWithProportion(
234
249
  class_path=class_path,
235
250
  data_path=data_path,
236
251
  proportion=fraction,
237
252
  init_kwargs=init_kwargs,
238
253
  )
254
+ return _apply_device_override(policy_spec, device)