cogames 0.3.65__py3-none-any.whl → 0.3.69__py3-none-any.whl
This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- cogames/cli/client.py +0 -3
- cogames/cli/docsync/docsync.py +7 -1
- cogames/cli/mission.py +44 -19
- cogames/cli/policy.py +26 -10
- cogames/cli/submit.py +201 -495
- cogames/cli/utils.py +5 -0
- cogames/cogs_vs_clips/clip_difficulty.py +57 -0
- cogames/cogs_vs_clips/clips.py +23 -6
- cogames/cogs_vs_clips/cog.py +16 -5
- cogames/cogs_vs_clips/cogsguard_curriculum.py +122 -0
- cogames/cogs_vs_clips/cogsguard_tutorial.py +5 -5
- cogames/cogs_vs_clips/config.py +1 -1
- cogames/cogs_vs_clips/docs/cogs_vs_clips_mapgen.md +2 -3
- cogames/cogs_vs_clips/evals/README.md +8 -32
- cogames/cogs_vs_clips/evals/diagnostic_evals.py +0 -1
- cogames/cogs_vs_clips/evals/difficulty_variants.py +7 -10
- cogames/cogs_vs_clips/mission.py +38 -10
- cogames/cogs_vs_clips/missions.py +1 -1
- cogames/cogs_vs_clips/reward_variants.py +173 -0
- cogames/cogs_vs_clips/sites.py +6 -5
- cogames/cogs_vs_clips/stations.py +13 -9
- cogames/cogs_vs_clips/team.py +3 -1
- cogames/cogs_vs_clips/terrain.py +2 -2
- cogames/cogs_vs_clips/variants.py +175 -4
- cogames/cogs_vs_clips/weather.py +52 -0
- cogames/docs/SCRIPTED_AGENT.md +3 -3
- cogames/evaluate.py +4 -2
- cogames/main.py +420 -84
- cogames/maps/canidate1_1000.map +1 -1
- cogames/maps/canidate1_1000_stations.map +2 -2
- cogames/maps/canidate1_500.map +1 -1
- cogames/maps/canidate1_500_stations.map +2 -2
- cogames/maps/canidate2_1000.map +1 -1
- cogames/maps/canidate2_1000_stations.map +2 -2
- cogames/maps/canidate2_500.map +1 -1
- cogames/maps/canidate2_500_stations.map +1 -1
- cogames/maps/canidate3_1000.map +1 -1
- cogames/maps/canidate3_1000_stations.map +2 -2
- cogames/maps/canidate3_500.map +1 -1
- cogames/maps/canidate3_500_stations.map +2 -2
- cogames/maps/canidate4_500.map +1 -1
- cogames/maps/canidate4_500_stations.map +2 -2
- cogames/maps/cave_base_50.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_agile.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_agile_hard.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_charge_up.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_charge_up_hard.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_chest_navigation1.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_chest_navigation1_hard.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_chest_navigation2.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_chest_navigation2_hard.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_chest_navigation3.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_chest_navigation3_hard.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_chest_near.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_chest_search.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_chest_search_hard.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_extract_lab.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_extract_lab_hard.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_memory.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_memory_hard.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_radial.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_radial_hard.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_resource_lab.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_unclip.map +6 -6
- cogames/maps/evals/eval_balanced_spread.map +6 -6
- cogames/maps/evals/eval_clip_oxygen.map +6 -6
- cogames/maps/evals/eval_collect_resources.map +6 -6
- cogames/maps/evals/eval_collect_resources_hard.map +6 -6
- cogames/maps/evals/eval_collect_resources_medium.map +6 -6
- cogames/maps/evals/eval_divide_and_conquer.map +6 -6
- cogames/maps/evals/eval_energy_starved.map +6 -6
- cogames/maps/evals/eval_multi_coordinated_collect_hard.map +6 -6
- cogames/maps/evals/eval_oxygen_bottleneck.map +6 -6
- cogames/maps/evals/eval_single_use_world.map +6 -6
- cogames/maps/evals/extractor_hub_100x100.map +6 -6
- cogames/maps/evals/extractor_hub_30x30.map +6 -6
- cogames/maps/evals/extractor_hub_50x50.map +6 -6
- cogames/maps/evals/extractor_hub_70x70.map +6 -6
- cogames/maps/evals/extractor_hub_80x80.map +6 -6
- cogames/maps/machina_100_stations.map +2 -2
- cogames/maps/machina_200_stations.map +2 -2
- cogames/maps/machina_200_stations_small.map +2 -2
- cogames/maps/machina_eval_exp01.map +2 -2
- cogames/maps/machina_eval_template_large.map +2 -2
- cogames/maps/machinatrainer4agents.map +2 -2
- cogames/maps/machinatrainer4agentsbase.map +2 -2
- cogames/maps/machinatrainerbig.map +2 -2
- cogames/maps/machinatrainersmall.map +2 -2
- cogames/maps/planky_evals/aligner_avoid_aoe.map +6 -6
- cogames/maps/planky_evals/aligner_full_cycle.map +6 -6
- cogames/maps/planky_evals/aligner_gear.map +6 -6
- cogames/maps/planky_evals/aligner_hearts.map +6 -6
- cogames/maps/planky_evals/aligner_junction.map +6 -6
- cogames/maps/planky_evals/exploration_distant.map +6 -6
- cogames/maps/planky_evals/maze.map +6 -6
- cogames/maps/planky_evals/miner_best_resource.map +6 -6
- cogames/maps/planky_evals/miner_deposit.map +6 -6
- cogames/maps/planky_evals/miner_extract.map +6 -6
- cogames/maps/planky_evals/miner_full_cycle.map +6 -6
- cogames/maps/planky_evals/miner_gear.map +6 -6
- cogames/maps/planky_evals/multi_role.map +6 -6
- cogames/maps/planky_evals/resource_chain.map +6 -6
- cogames/maps/planky_evals/scout_explore.map +6 -6
- cogames/maps/planky_evals/scout_gear.map +6 -6
- cogames/maps/planky_evals/scrambler_full_cycle.map +6 -6
- cogames/maps/planky_evals/scrambler_gear.map +6 -6
- cogames/maps/planky_evals/scrambler_target.map +6 -6
- cogames/maps/planky_evals/stuck_corridor.map +6 -6
- cogames/maps/planky_evals/survive_retreat.map +6 -6
- cogames/maps/training_facility_clipped.map +2 -2
- cogames/maps/training_facility_open_1.map +2 -2
- cogames/maps/training_facility_open_2.map +2 -2
- cogames/maps/training_facility_open_3.map +2 -2
- cogames/maps/training_facility_tight_4.map +2 -2
- cogames/maps/training_facility_tight_5.map +2 -2
- cogames/maps/vanilla_large.map +2 -2
- cogames/maps/vanilla_small.map +2 -2
- cogames/pickup.py +6 -5
- cogames/play.py +14 -16
- cogames/policy/nim_agents/__init__.py +0 -2
- cogames/policy/nim_agents/agents.py +0 -11
- cogames/policy/starter_agent.py +4 -1
- {cogames-0.3.65.dist-info → cogames-0.3.69.dist-info}/METADATA +45 -29
- cogames-0.3.69.dist-info/RECORD +160 -0
- metta_alo/scoring.py +7 -7
- cogames-0.3.65.dist-info/RECORD +0 -160
- metta_alo/job_specs.py +0 -17
- metta_alo/policy.py +0 -16
- metta_alo/pure_single_episode_runner.py +0 -75
- metta_alo/rollout.py +0 -322
- {cogames-0.3.65.dist-info → cogames-0.3.69.dist-info}/WHEEL +0 -0
- {cogames-0.3.65.dist-info → cogames-0.3.69.dist-info}/entry_points.txt +0 -0
- {cogames-0.3.65.dist-info → cogames-0.3.69.dist-info}/licenses/LICENSE +0 -0
- {cogames-0.3.65.dist-info → cogames-0.3.69.dist-info}/top_level.txt +0 -0
cogames/cli/client.py
CHANGED
|
@@ -253,9 +253,6 @@ class TournamentServerClient:
|
|
|
253
253
|
json=payload,
|
|
254
254
|
)
|
|
255
255
|
|
|
256
|
-
def update_policy_version_tags(self, policy_version_id: uuid.UUID, tags: dict[str, str]) -> dict[str, Any]:
|
|
257
|
-
return self._put(f"/stats/policies/versions/{policy_version_id}/tags", json=tags)
|
|
258
|
-
|
|
259
256
|
def get_policy_memberships(self, policy_version_id: uuid.UUID) -> list[dict[str, Any]]:
|
|
260
257
|
return self._get(f"/tournament/policies/{policy_version_id}/memberships", list[dict[str, Any]])
|
|
261
258
|
|
cogames/cli/docsync/docsync.py
CHANGED
|
@@ -120,7 +120,13 @@ def check(
|
|
|
120
120
|
typer.echo("\nNotebook documentation is out of sync!", err=True)
|
|
121
121
|
for error in errors:
|
|
122
122
|
typer.echo(f" {error}", err=True)
|
|
123
|
-
typer.echo(
|
|
123
|
+
typer.echo(
|
|
124
|
+
"\nThis can happen if you modified cogames CLI flags (the README includes a command reference)"
|
|
125
|
+
"\nor changed code that affects notebook outputs."
|
|
126
|
+
"\n"
|
|
127
|
+
"\nTo fix: run 'cogames docsync all'",
|
|
128
|
+
err=True,
|
|
129
|
+
)
|
|
124
130
|
raise typer.Exit(1)
|
|
125
131
|
|
|
126
132
|
typer.echo("All notebooks are in sync!")
|
cogames/cli/mission.py
CHANGED
|
@@ -8,10 +8,8 @@ from rich import box
|
|
|
8
8
|
from rich.table import Table
|
|
9
9
|
|
|
10
10
|
from cogames.cli.base import console
|
|
11
|
-
from cogames.cogs_vs_clips.
|
|
12
|
-
|
|
13
|
-
NumCogsVariant,
|
|
14
|
-
)
|
|
11
|
+
from cogames.cogs_vs_clips.clip_difficulty import get_cogsguard_difficulty
|
|
12
|
+
from cogames.cogs_vs_clips.mission import CvCMission, NumCogsVariant
|
|
15
13
|
from cogames.cogs_vs_clips.sites import SITES
|
|
16
14
|
from cogames.cogs_vs_clips.terrain import MachinaArena
|
|
17
15
|
from cogames.cogs_vs_clips.variants import HIDDEN_VARIANTS, VARIANTS
|
|
@@ -27,16 +25,16 @@ from mettagrid.mapgen.mapgen import MapGen
|
|
|
27
25
|
|
|
28
26
|
@lru_cache(maxsize=1)
|
|
29
27
|
def _get_core_missions() -> list[CvCMission]:
|
|
30
|
-
from cogames.cogs_vs_clips.missions import get_core_missions
|
|
28
|
+
from cogames.cogs_vs_clips.missions import get_core_missions # noqa: PLC0415
|
|
31
29
|
|
|
32
30
|
return get_core_missions()
|
|
33
31
|
|
|
34
32
|
|
|
35
33
|
@lru_cache(maxsize=1)
|
|
36
34
|
def _get_eval_missions_all() -> list[CvCMission]:
|
|
37
|
-
from cogames.cogs_vs_clips.evals.diagnostic_evals import DIAGNOSTIC_EVALS
|
|
38
|
-
from cogames.cogs_vs_clips.evals.integrated_evals import EVAL_MISSIONS as INTEGRATED_EVAL_MISSIONS
|
|
39
|
-
from cogames.cogs_vs_clips.evals.spanning_evals import EVAL_MISSIONS as SPANNING_EVAL_MISSIONS
|
|
35
|
+
from cogames.cogs_vs_clips.evals.diagnostic_evals import DIAGNOSTIC_EVALS # noqa: PLC0415
|
|
36
|
+
from cogames.cogs_vs_clips.evals.integrated_evals import EVAL_MISSIONS as INTEGRATED_EVAL_MISSIONS # noqa: PLC0415
|
|
37
|
+
from cogames.cogs_vs_clips.evals.spanning_evals import EVAL_MISSIONS as SPANNING_EVAL_MISSIONS # noqa: PLC0415
|
|
40
38
|
|
|
41
39
|
missions: list[CvCMission] = []
|
|
42
40
|
missions.extend(INTEGRATED_EVAL_MISSIONS)
|
|
@@ -74,19 +72,21 @@ def load_mission_set(mission_set: str) -> list[CvCMission]:
|
|
|
74
72
|
missions_list.append(mission)
|
|
75
73
|
|
|
76
74
|
elif mission_set == "diagnostic_evals":
|
|
77
|
-
from cogames.cogs_vs_clips.evals.diagnostic_evals import DIAGNOSTIC_EVALS
|
|
75
|
+
from cogames.cogs_vs_clips.evals.diagnostic_evals import DIAGNOSTIC_EVALS # noqa: PLC0415
|
|
78
76
|
|
|
79
77
|
missions_list = [mission_cls() for mission_cls in DIAGNOSTIC_EVALS] # type: ignore[call-arg]
|
|
80
78
|
elif mission_set == "cogsguard_evals":
|
|
81
|
-
from cogames.cogs_vs_clips.evals.cogsguard_evals import COGSGUARD_EVAL_MISSIONS
|
|
79
|
+
from cogames.cogs_vs_clips.evals.cogsguard_evals import COGSGUARD_EVAL_MISSIONS # noqa: PLC0415
|
|
82
80
|
|
|
83
81
|
missions_list = list(COGSGUARD_EVAL_MISSIONS)
|
|
84
82
|
elif mission_set == "integrated_evals":
|
|
85
|
-
from cogames.cogs_vs_clips.evals.integrated_evals import
|
|
83
|
+
from cogames.cogs_vs_clips.evals.integrated_evals import ( # noqa: PLC0415
|
|
84
|
+
EVAL_MISSIONS as INTEGRATED_EVAL_MISSIONS,
|
|
85
|
+
)
|
|
86
86
|
|
|
87
87
|
missions_list = list(INTEGRATED_EVAL_MISSIONS)
|
|
88
88
|
elif mission_set == "spanning_evals":
|
|
89
|
-
from cogames.cogs_vs_clips.evals.spanning_evals import EVAL_MISSIONS as SPANNING_EVAL_MISSIONS
|
|
89
|
+
from cogames.cogs_vs_clips.evals.spanning_evals import EVAL_MISSIONS as SPANNING_EVAL_MISSIONS # noqa: PLC0415
|
|
90
90
|
|
|
91
91
|
missions_list = list(SPANNING_EVAL_MISSIONS)
|
|
92
92
|
else:
|
|
@@ -132,6 +132,12 @@ def parse_variants(variants_arg: Optional[list[str]]) -> list[CoGameMissionVaria
|
|
|
132
132
|
return variants
|
|
133
133
|
|
|
134
134
|
|
|
135
|
+
def parse_difficulty(difficulty_arg: Optional[str]) -> Optional[CoGameMissionVariant]:
|
|
136
|
+
if difficulty_arg is None:
|
|
137
|
+
return None
|
|
138
|
+
return get_cogsguard_difficulty(difficulty_arg)
|
|
139
|
+
|
|
140
|
+
|
|
135
141
|
def get_all_missions() -> list[str]:
|
|
136
142
|
"""Get all core mission names in the format site.mission (excludes evals)."""
|
|
137
143
|
return [mission.full_name() for mission in _get_core_missions()]
|
|
@@ -157,17 +163,25 @@ def get_site_by_name(site_name: str) -> CoGameSite:
|
|
|
157
163
|
|
|
158
164
|
|
|
159
165
|
def get_mission_name_and_config(
|
|
160
|
-
ctx: typer.Context,
|
|
166
|
+
ctx: typer.Context,
|
|
167
|
+
mission_arg: Optional[str],
|
|
168
|
+
variants_arg: Optional[list[str]] = None,
|
|
169
|
+
cogs: Optional[int] = None,
|
|
170
|
+
difficulty: Optional[str] = None,
|
|
161
171
|
) -> tuple[str, MettaGridConfig, Optional[CvCMission]]:
|
|
162
172
|
if not mission_arg:
|
|
163
173
|
console.print(ctx.get_help())
|
|
164
174
|
console.print("[yellow]Missing: --mission / -m[/yellow]\n")
|
|
165
175
|
else:
|
|
166
176
|
try:
|
|
167
|
-
return get_mission(mission_arg, variants_arg, cogs)
|
|
177
|
+
return get_mission(mission_arg, variants_arg=variants_arg, cogs=cogs, difficulty=difficulty)
|
|
168
178
|
except ValueError as e:
|
|
169
|
-
|
|
170
|
-
|
|
179
|
+
error_msg = str(e)
|
|
180
|
+
if "variant" in error_msg.lower() or "difficulty" in error_msg.lower():
|
|
181
|
+
console.print(f"[red]{error_msg}[/red]")
|
|
182
|
+
else:
|
|
183
|
+
console.print(f"[red]Mission '{mission_arg}' not found.[/red]")
|
|
184
|
+
console.print("[dim]Run: cogames missions (or cogames missions <site>) to list options.[/dim]\n")
|
|
171
185
|
raise typer.Exit(1) from e
|
|
172
186
|
list_missions()
|
|
173
187
|
|
|
@@ -183,6 +197,7 @@ def get_mission_names_and_configs(
|
|
|
183
197
|
variants_arg: Optional[list[str]] = None,
|
|
184
198
|
cogs: Optional[int] = None,
|
|
185
199
|
steps: Optional[int] = None,
|
|
200
|
+
difficulty: Optional[str] = None,
|
|
186
201
|
) -> list[tuple[str, MettaGridConfig]]:
|
|
187
202
|
if not missions_arg:
|
|
188
203
|
console.print(ctx.get_help())
|
|
@@ -192,7 +207,7 @@ def get_mission_names_and_configs(
|
|
|
192
207
|
not_deduped = [
|
|
193
208
|
mission
|
|
194
209
|
for missions in missions_arg
|
|
195
|
-
for mission in _get_missions_by_possible_wildcard(missions, variants_arg, cogs)
|
|
210
|
+
for mission in _get_missions_by_possible_wildcard(missions, variants_arg, cogs, difficulty)
|
|
196
211
|
]
|
|
197
212
|
name_set: set[str] = set()
|
|
198
213
|
deduped = []
|
|
@@ -224,6 +239,7 @@ def _get_missions_by_possible_wildcard(
|
|
|
224
239
|
mission_arg: str,
|
|
225
240
|
variants_arg: Optional[list[str]],
|
|
226
241
|
cogs: Optional[int],
|
|
242
|
+
difficulty: Optional[str],
|
|
227
243
|
) -> list[tuple[str, MettaGridConfig]]:
|
|
228
244
|
if "*" in mission_arg:
|
|
229
245
|
# Convert shell-style wildcard to regex pattern
|
|
@@ -232,10 +248,12 @@ def _get_missions_by_possible_wildcard(
|
|
|
232
248
|
# Drop the Mission (3rd element) for wildcard results
|
|
233
249
|
return [
|
|
234
250
|
(name, env_cfg)
|
|
235
|
-
for name, env_cfg, _ in (
|
|
251
|
+
for name, env_cfg, _ in (
|
|
252
|
+
get_mission(m, variants_arg=variants_arg, cogs=cogs, difficulty=difficulty) for m in missions
|
|
253
|
+
)
|
|
236
254
|
]
|
|
237
255
|
# Drop the Mission for single mission
|
|
238
|
-
name, env_cfg, _ = get_mission(mission_arg, variants_arg=variants_arg, cogs=cogs)
|
|
256
|
+
name, env_cfg, _ = get_mission(mission_arg, variants_arg=variants_arg, cogs=cogs, difficulty=difficulty)
|
|
239
257
|
return [(name, env_cfg)]
|
|
240
258
|
|
|
241
259
|
|
|
@@ -278,6 +296,7 @@ def get_mission(
|
|
|
278
296
|
variants_arg: Optional[list[str]] = None,
|
|
279
297
|
cogs: Optional[int] = None,
|
|
280
298
|
include_legacy: bool = False,
|
|
299
|
+
difficulty: Optional[str] = None,
|
|
281
300
|
) -> tuple[str, MettaGridConfig, Optional[CvCMission]]:
|
|
282
301
|
"""Get a specific mission configuration by name or file path.
|
|
283
302
|
|
|
@@ -286,6 +305,7 @@ def get_mission(
|
|
|
286
305
|
variants_arg: List of variant names like ["solar_flare", "dark_side"]
|
|
287
306
|
cogs: Number of cogs (agents) to use, overrides the default from the mission
|
|
288
307
|
include_legacy: Whether to include legacy (pre-CogsGuard) missions
|
|
308
|
+
difficulty: Difficulty name (easy, medium, hard) controlling clips events
|
|
289
309
|
|
|
290
310
|
Returns:
|
|
291
311
|
Tuple of (mission name, MettaGridConfig, CvCMission or None)
|
|
@@ -295,6 +315,8 @@ def get_mission(
|
|
|
295
315
|
"""
|
|
296
316
|
# Check if it's a file path
|
|
297
317
|
if any(mission_arg.endswith(ext) for ext in [".yaml", ".yml", ".json", ".py"]):
|
|
318
|
+
if difficulty is not None:
|
|
319
|
+
raise ValueError("Difficulty can only be used with mission names, not config files.")
|
|
298
320
|
path = Path(mission_arg)
|
|
299
321
|
if not path.exists():
|
|
300
322
|
raise ValueError(f"File not found: {mission_arg}")
|
|
@@ -311,6 +333,7 @@ def get_mission(
|
|
|
311
333
|
|
|
312
334
|
# Parse variants if provided
|
|
313
335
|
variants = parse_variants(variants_arg)
|
|
336
|
+
difficulty_variant = parse_difficulty(difficulty)
|
|
314
337
|
|
|
315
338
|
# Otherwise, treat it as a fully qualified mission name, or as a site name
|
|
316
339
|
if (delim_count := mission_arg.count(MAP_MISSION_DELIMITER)) == 0:
|
|
@@ -322,6 +345,8 @@ def get_mission(
|
|
|
322
345
|
|
|
323
346
|
mission: CvCMission = find_mission(site_name, mission_name, include_evals=True, include_legacy=include_legacy)
|
|
324
347
|
|
|
348
|
+
if difficulty_variant is not None:
|
|
349
|
+
mission = mission.with_variants([difficulty_variant])
|
|
325
350
|
if variants:
|
|
326
351
|
mission = mission.with_variants(variants)
|
|
327
352
|
if cogs is not None:
|
cogames/cli/policy.py
CHANGED
|
@@ -68,13 +68,13 @@ def _translate_error(e: Exception) -> str:
|
|
|
68
68
|
return translated
|
|
69
69
|
|
|
70
70
|
|
|
71
|
-
def get_policy_spec(ctx: typer.Context, policy_arg: Optional[str]) -> PolicySpec:
|
|
71
|
+
def get_policy_spec(ctx: typer.Context, policy_arg: Optional[str], device: str | None = None) -> PolicySpec:
|
|
72
72
|
if policy_arg is None:
|
|
73
73
|
console.print(ctx.get_help())
|
|
74
74
|
console.print("[yellow]Missing: --policy / -p[/yellow]\n")
|
|
75
75
|
else:
|
|
76
76
|
try:
|
|
77
|
-
return parse_policy_spec(spec=policy_arg).to_policy_spec()
|
|
77
|
+
return parse_policy_spec(spec=policy_arg, device=device).to_policy_spec()
|
|
78
78
|
except (ValueError, ModuleNotFoundError) as e:
|
|
79
79
|
translated = _translate_error(e)
|
|
80
80
|
console.print(f"[yellow]Error parsing policy argument: {translated}[/yellow]\n")
|
|
@@ -90,14 +90,14 @@ def get_policy_spec(ctx: typer.Context, policy_arg: Optional[str]) -> PolicySpec
|
|
|
90
90
|
|
|
91
91
|
|
|
92
92
|
def get_policy_specs_with_proportions(
|
|
93
|
-
ctx: typer.Context, policy_args: Optional[list[str]]
|
|
93
|
+
ctx: typer.Context, policy_args: Optional[list[str]], device: str | None = None
|
|
94
94
|
) -> list[PolicySpecWithProportion]:
|
|
95
95
|
if not policy_args:
|
|
96
96
|
console.print(ctx.get_help())
|
|
97
97
|
console.print("[yellow]Supply at least one: --policy / -p[/yellow]\n")
|
|
98
98
|
else:
|
|
99
99
|
try:
|
|
100
|
-
return [parse_policy_spec(spec=policy_arg) for policy_arg in policy_args]
|
|
100
|
+
return [parse_policy_spec(spec=policy_arg, device=device) for policy_arg in policy_args]
|
|
101
101
|
except (ValueError, ModuleNotFoundError) as e:
|
|
102
102
|
translated = _translate_error(e)
|
|
103
103
|
console.print(f"[yellow]Error parsing policy argument: {translated}[/yellow]")
|
|
@@ -112,7 +112,20 @@ def get_policy_specs_with_proportions(
|
|
|
112
112
|
raise typer.Exit(1)
|
|
113
113
|
|
|
114
114
|
|
|
115
|
-
def
|
|
115
|
+
def _apply_device_override(spec: PolicySpecWithProportion, device: str | None) -> PolicySpecWithProportion:
|
|
116
|
+
if device is None:
|
|
117
|
+
return spec
|
|
118
|
+
init_kwargs = dict(spec.init_kwargs or {})
|
|
119
|
+
init_kwargs["device"] = device
|
|
120
|
+
return PolicySpecWithProportion(
|
|
121
|
+
class_path=spec.class_path,
|
|
122
|
+
data_path=spec.data_path,
|
|
123
|
+
proportion=spec.proportion,
|
|
124
|
+
init_kwargs=init_kwargs,
|
|
125
|
+
)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def parse_policy_spec(spec: str, device: str | None = None) -> PolicySpecWithProportion:
|
|
116
129
|
"""Parse a policy CLI option into its components.
|
|
117
130
|
|
|
118
131
|
Supports two formats:
|
|
@@ -161,13 +174,14 @@ def parse_policy_spec(spec: str) -> PolicySpecWithProportion:
|
|
|
161
174
|
uri_part = spec
|
|
162
175
|
fraction = 1.0
|
|
163
176
|
|
|
164
|
-
policy = policy_spec_from_uri(uri_part.strip())
|
|
165
|
-
|
|
177
|
+
policy = policy_spec_from_uri(uri_part.strip(), device=device or "cpu")
|
|
178
|
+
policy_spec = PolicySpecWithProportion(
|
|
166
179
|
class_path=policy.class_path,
|
|
167
180
|
data_path=policy.data_path,
|
|
168
181
|
proportion=fraction,
|
|
169
182
|
init_kwargs=policy.init_kwargs,
|
|
170
183
|
)
|
|
184
|
+
return _apply_device_override(policy_spec, device)
|
|
171
185
|
|
|
172
186
|
entries = [part.strip() for part in spec.split(",") if part.strip()]
|
|
173
187
|
if not entries:
|
|
@@ -176,19 +190,20 @@ def parse_policy_spec(spec: str) -> PolicySpecWithProportion:
|
|
|
176
190
|
fraction = 1.0
|
|
177
191
|
first = entries[0]
|
|
178
192
|
if is_uri(first) or ("=" not in first and is_path_like(first)):
|
|
179
|
-
policy = policy_spec_from_uri(first)
|
|
193
|
+
policy = policy_spec_from_uri(first, device=device or "cpu")
|
|
180
194
|
for entry in entries[1:]:
|
|
181
195
|
key, value = parse_key_value(entry)
|
|
182
196
|
if key != "proportion":
|
|
183
197
|
raise ValueError("Only proportion is supported after a checkpoint URI.")
|
|
184
198
|
fraction = parse_proportion(value)
|
|
185
199
|
|
|
186
|
-
|
|
200
|
+
policy_spec = PolicySpecWithProportion(
|
|
187
201
|
class_path=policy.class_path,
|
|
188
202
|
data_path=policy.data_path,
|
|
189
203
|
proportion=fraction,
|
|
190
204
|
init_kwargs=policy.init_kwargs,
|
|
191
205
|
)
|
|
206
|
+
return _apply_device_override(policy_spec, device)
|
|
192
207
|
|
|
193
208
|
if "=" not in first:
|
|
194
209
|
if ":" in first:
|
|
@@ -230,9 +245,10 @@ def parse_policy_spec(spec: str) -> PolicySpecWithProportion:
|
|
|
230
245
|
if class_path is None:
|
|
231
246
|
raise ValueError("Policy specification must include class= for key=value format.")
|
|
232
247
|
|
|
233
|
-
|
|
248
|
+
policy_spec = PolicySpecWithProportion(
|
|
234
249
|
class_path=class_path,
|
|
235
250
|
data_path=data_path,
|
|
236
251
|
proportion=fraction,
|
|
237
252
|
init_kwargs=init_kwargs,
|
|
238
253
|
)
|
|
254
|
+
return _apply_device_override(policy_spec, device)
|