cogames 0.3.64__py3-none-any.whl → 0.3.68__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cogames/cli/client.py +0 -3
- cogames/cli/docsync/docsync.py +7 -1
- cogames/cli/mission.py +68 -53
- cogames/cli/policy.py +26 -10
- cogames/cli/submit.py +128 -142
- cogames/cli/utils.py +5 -0
- cogames/cogs_vs_clips/clip_difficulty.py +57 -0
- cogames/cogs_vs_clips/clips.py +103 -0
- cogames/cogs_vs_clips/cog.py +29 -11
- cogames/cogs_vs_clips/cogsguard_curriculum.py +122 -0
- cogames/cogs_vs_clips/cogsguard_tutorial.py +15 -16
- cogames/cogs_vs_clips/config.py +38 -0
- cogames/cogs_vs_clips/{cogs_vs_clips_mapgen.md → docs/cogs_vs_clips_mapgen.md} +8 -10
- cogames/cogs_vs_clips/evals/README.md +11 -35
- cogames/cogs_vs_clips/evals/cogsguard_evals.py +21 -6
- cogames/cogs_vs_clips/evals/diagnostic_evals.py +13 -101
- cogames/cogs_vs_clips/evals/difficulty_variants.py +16 -28
- cogames/cogs_vs_clips/evals/integrated_evals.py +8 -60
- cogames/cogs_vs_clips/evals/spanning_evals.py +48 -54
- cogames/cogs_vs_clips/mission.py +93 -277
- cogames/cogs_vs_clips/missions.py +17 -27
- cogames/cogs_vs_clips/{cogsguard_reward_variants.py → reward_variants.py} +22 -2
- cogames/cogs_vs_clips/sites.py +41 -30
- cogames/cogs_vs_clips/stations.py +39 -84
- cogames/cogs_vs_clips/team.py +46 -0
- cogames/cogs_vs_clips/{procedural.py → terrain.py} +14 -8
- cogames/cogs_vs_clips/variants.py +201 -107
- cogames/cogs_vs_clips/weather.py +52 -0
- cogames/core.py +87 -0
- cogames/docs/SCRIPTED_AGENT.md +3 -3
- cogames/evaluate.py +4 -2
- cogames/main.py +357 -51
- cogames/maps/canidate1_1000.map +1 -1
- cogames/maps/canidate1_1000_stations.map +2 -2
- cogames/maps/canidate1_500.map +1 -1
- cogames/maps/canidate1_500_stations.map +2 -2
- cogames/maps/canidate2_1000.map +1 -1
- cogames/maps/canidate2_1000_stations.map +2 -2
- cogames/maps/canidate2_500.map +1 -1
- cogames/maps/canidate2_500_stations.map +1 -1
- cogames/maps/canidate3_1000.map +1 -1
- cogames/maps/canidate3_1000_stations.map +2 -2
- cogames/maps/canidate3_500.map +1 -1
- cogames/maps/canidate3_500_stations.map +2 -2
- cogames/maps/canidate4_500.map +1 -1
- cogames/maps/canidate4_500_stations.map +2 -2
- cogames/maps/cave_base_50.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_agile.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_agile_hard.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_charge_up.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_charge_up_hard.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_chest_navigation1.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_chest_navigation1_hard.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_chest_navigation2.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_chest_navigation2_hard.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_chest_navigation3.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_chest_navigation3_hard.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_chest_near.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_chest_search.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_chest_search_hard.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_extract_lab.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_extract_lab_hard.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_memory.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_memory_hard.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_radial.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_radial_hard.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_resource_lab.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_unclip.map +6 -6
- cogames/maps/evals/eval_balanced_spread.map +6 -6
- cogames/maps/evals/eval_clip_oxygen.map +6 -6
- cogames/maps/evals/eval_collect_resources.map +6 -6
- cogames/maps/evals/eval_collect_resources_hard.map +6 -6
- cogames/maps/evals/eval_collect_resources_medium.map +6 -6
- cogames/maps/evals/eval_divide_and_conquer.map +6 -6
- cogames/maps/evals/eval_energy_starved.map +6 -6
- cogames/maps/evals/eval_multi_coordinated_collect_hard.map +6 -6
- cogames/maps/evals/eval_oxygen_bottleneck.map +6 -6
- cogames/maps/evals/eval_single_use_world.map +6 -6
- cogames/maps/evals/extractor_hub_100x100.map +6 -6
- cogames/maps/evals/extractor_hub_30x30.map +6 -6
- cogames/maps/evals/extractor_hub_50x50.map +6 -6
- cogames/maps/evals/extractor_hub_70x70.map +6 -6
- cogames/maps/evals/extractor_hub_80x80.map +6 -6
- cogames/maps/machina_100_stations.map +2 -2
- cogames/maps/machina_200_stations.map +2 -2
- cogames/maps/machina_200_stations_small.map +2 -2
- cogames/maps/machina_eval_exp01.map +2 -2
- cogames/maps/machina_eval_template_large.map +2 -2
- cogames/maps/machinatrainer4agents.map +2 -2
- cogames/maps/machinatrainer4agentsbase.map +2 -2
- cogames/maps/machinatrainerbig.map +2 -2
- cogames/maps/machinatrainersmall.map +2 -2
- cogames/maps/planky_evals/aligner_avoid_aoe.map +6 -6
- cogames/maps/planky_evals/aligner_full_cycle.map +6 -6
- cogames/maps/planky_evals/aligner_gear.map +6 -6
- cogames/maps/planky_evals/aligner_hearts.map +6 -6
- cogames/maps/planky_evals/aligner_junction.map +6 -6
- cogames/maps/planky_evals/exploration_distant.map +6 -6
- cogames/maps/planky_evals/maze.map +6 -6
- cogames/maps/planky_evals/miner_best_resource.map +6 -6
- cogames/maps/planky_evals/miner_deposit.map +6 -6
- cogames/maps/planky_evals/miner_extract.map +6 -6
- cogames/maps/planky_evals/miner_full_cycle.map +6 -6
- cogames/maps/planky_evals/miner_gear.map +6 -6
- cogames/maps/planky_evals/multi_role.map +6 -6
- cogames/maps/planky_evals/resource_chain.map +6 -6
- cogames/maps/planky_evals/scout_explore.map +6 -6
- cogames/maps/planky_evals/scout_gear.map +6 -6
- cogames/maps/planky_evals/scrambler_full_cycle.map +6 -6
- cogames/maps/planky_evals/scrambler_gear.map +6 -6
- cogames/maps/planky_evals/scrambler_target.map +6 -6
- cogames/maps/planky_evals/stuck_corridor.map +6 -6
- cogames/maps/planky_evals/survive_retreat.map +6 -6
- cogames/maps/training_facility_clipped.map +2 -2
- cogames/maps/training_facility_open_1.map +2 -2
- cogames/maps/training_facility_open_2.map +2 -2
- cogames/maps/training_facility_open_3.map +2 -2
- cogames/maps/training_facility_tight_4.map +2 -2
- cogames/maps/training_facility_tight_5.map +2 -2
- cogames/maps/vanilla_large.map +2 -2
- cogames/maps/vanilla_small.map +2 -2
- cogames/pickup.py +6 -5
- cogames/play.py +14 -16
- cogames/policy/nim_agents/__init__.py +0 -2
- cogames/policy/nim_agents/agents.py +0 -11
- cogames/policy/starter_agent.py +4 -1
- cogames/verbose.py +2 -2
- {cogames-0.3.64.dist-info → cogames-0.3.68.dist-info}/METADATA +45 -29
- cogames-0.3.68.dist-info/RECORD +160 -0
- metta_alo/scoring.py +7 -7
- cogames/cogs_vs_clips/mission_utils.py +0 -19
- cogames/cogs_vs_clips/tutorial_missions.py +0 -25
- cogames-0.3.64.dist-info/RECORD +0 -159
- metta_alo/job_specs.py +0 -17
- metta_alo/policy.py +0 -16
- metta_alo/pure_single_episode_runner.py +0 -75
- metta_alo/rollout.py +0 -322
- {cogames-0.3.64.dist-info → cogames-0.3.68.dist-info}/WHEEL +0 -0
- {cogames-0.3.64.dist-info → cogames-0.3.68.dist-info}/entry_points.txt +0 -0
- {cogames-0.3.64.dist-info → cogames-0.3.68.dist-info}/licenses/LICENSE +0 -0
- {cogames-0.3.64.dist-info → cogames-0.3.68.dist-info}/top_level.txt +0 -0
cogames/cli/client.py
CHANGED
|
@@ -253,9 +253,6 @@ class TournamentServerClient:
|
|
|
253
253
|
json=payload,
|
|
254
254
|
)
|
|
255
255
|
|
|
256
|
-
def update_policy_version_tags(self, policy_version_id: uuid.UUID, tags: dict[str, str]) -> dict[str, Any]:
|
|
257
|
-
return self._put(f"/stats/policies/versions/{policy_version_id}/tags", json=tags)
|
|
258
|
-
|
|
259
256
|
def get_policy_memberships(self, policy_version_id: uuid.UUID) -> list[dict[str, Any]]:
|
|
260
257
|
return self._get(f"/tournament/policies/{policy_version_id}/memberships", list[dict[str, Any]])
|
|
261
258
|
|
cogames/cli/docsync/docsync.py
CHANGED
|
@@ -120,7 +120,13 @@ def check(
|
|
|
120
120
|
typer.echo("\nNotebook documentation is out of sync!", err=True)
|
|
121
121
|
for error in errors:
|
|
122
122
|
typer.echo(f" {error}", err=True)
|
|
123
|
-
typer.echo(
|
|
123
|
+
typer.echo(
|
|
124
|
+
"\nThis can happen if you modified cogames CLI flags (the README includes a command reference)"
|
|
125
|
+
"\nor changed code that affects notebook outputs."
|
|
126
|
+
"\n"
|
|
127
|
+
"\nTo fix: run 'cogames docsync all'",
|
|
128
|
+
err=True,
|
|
129
|
+
)
|
|
124
130
|
raise typer.Exit(1)
|
|
125
131
|
|
|
126
132
|
typer.echo("All notebooks are in sync!")
|
cogames/cli/mission.py
CHANGED
|
@@ -8,52 +8,42 @@ from rich import box
|
|
|
8
8
|
from rich.table import Table
|
|
9
9
|
|
|
10
10
|
from cogames.cli.base import console
|
|
11
|
-
from cogames.cogs_vs_clips.
|
|
12
|
-
|
|
13
|
-
AnyMission,
|
|
14
|
-
Mission,
|
|
15
|
-
MissionVariant,
|
|
16
|
-
NumCogsVariant,
|
|
17
|
-
Site,
|
|
18
|
-
)
|
|
19
|
-
from cogames.cogs_vs_clips.procedural import MachinaArena
|
|
11
|
+
from cogames.cogs_vs_clips.clip_difficulty import get_cogsguard_difficulty
|
|
12
|
+
from cogames.cogs_vs_clips.mission import CvCMission, NumCogsVariant
|
|
20
13
|
from cogames.cogs_vs_clips.sites import SITES
|
|
14
|
+
from cogames.cogs_vs_clips.terrain import MachinaArena
|
|
21
15
|
from cogames.cogs_vs_clips.variants import HIDDEN_VARIANTS, VARIANTS
|
|
16
|
+
from cogames.core import (
|
|
17
|
+
MAP_MISSION_DELIMITER,
|
|
18
|
+
CoGameMissionVariant,
|
|
19
|
+
CoGameSite,
|
|
20
|
+
)
|
|
22
21
|
from cogames.game import load_mission_config, load_mission_config_from_python
|
|
23
22
|
from mettagrid import MettaGridConfig
|
|
24
23
|
from mettagrid.mapgen.mapgen import MapGen
|
|
25
24
|
|
|
26
25
|
|
|
27
26
|
@lru_cache(maxsize=1)
|
|
28
|
-
def _get_core_missions() -> list[
|
|
29
|
-
from cogames.cogs_vs_clips.missions import get_core_missions
|
|
27
|
+
def _get_core_missions() -> list[CvCMission]:
|
|
28
|
+
from cogames.cogs_vs_clips.missions import get_core_missions # noqa: PLC0415
|
|
30
29
|
|
|
31
30
|
return get_core_missions()
|
|
32
31
|
|
|
33
32
|
|
|
34
33
|
@lru_cache(maxsize=1)
|
|
35
|
-
def
|
|
36
|
-
from cogames.cogs_vs_clips.
|
|
34
|
+
def _get_eval_missions_all() -> list[CvCMission]:
|
|
35
|
+
from cogames.cogs_vs_clips.evals.diagnostic_evals import DIAGNOSTIC_EVALS # noqa: PLC0415
|
|
36
|
+
from cogames.cogs_vs_clips.evals.integrated_evals import EVAL_MISSIONS as INTEGRATED_EVAL_MISSIONS # noqa: PLC0415
|
|
37
|
+
from cogames.cogs_vs_clips.evals.spanning_evals import EVAL_MISSIONS as SPANNING_EVAL_MISSIONS # noqa: PLC0415
|
|
37
38
|
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
@lru_cache(maxsize=1)
|
|
42
|
-
def _get_eval_missions_all() -> list[AnyMission]:
|
|
43
|
-
from cogames.cogs_vs_clips.evals.cogsguard_evals import COGSGUARD_EVAL_MISSIONS
|
|
44
|
-
from cogames.cogs_vs_clips.evals.diagnostic_evals import DIAGNOSTIC_EVALS
|
|
45
|
-
from cogames.cogs_vs_clips.evals.integrated_evals import EVAL_MISSIONS as INTEGRATED_EVAL_MISSIONS
|
|
46
|
-
from cogames.cogs_vs_clips.evals.spanning_evals import EVAL_MISSIONS as SPANNING_EVAL_MISSIONS
|
|
47
|
-
|
|
48
|
-
missions: list[AnyMission] = []
|
|
49
|
-
missions.extend(COGSGUARD_EVAL_MISSIONS)
|
|
39
|
+
missions: list[CvCMission] = []
|
|
50
40
|
missions.extend(INTEGRATED_EVAL_MISSIONS)
|
|
51
41
|
missions.extend(SPANNING_EVAL_MISSIONS)
|
|
52
42
|
missions.extend(mission_cls() for mission_cls in DIAGNOSTIC_EVALS) # type: ignore[call-arg]
|
|
53
43
|
return missions
|
|
54
44
|
|
|
55
45
|
|
|
56
|
-
def load_mission_set(mission_set: str) -> list[
|
|
46
|
+
def load_mission_set(mission_set: str) -> list[CvCMission]:
|
|
57
47
|
"""Load a predefined set of evaluation missions.
|
|
58
48
|
|
|
59
49
|
Args:
|
|
@@ -70,7 +60,7 @@ def load_mission_set(mission_set: str) -> list[AnyMission]:
|
|
|
70
60
|
Raises:
|
|
71
61
|
ValueError: If mission_set name is unknown
|
|
72
62
|
"""
|
|
73
|
-
missions_list: list[
|
|
63
|
+
missions_list: list[CvCMission]
|
|
74
64
|
if mission_set == "all":
|
|
75
65
|
# All missions: eval missions + integrated + spanning + diagnostic + core missions
|
|
76
66
|
missions_list = list(_get_eval_missions_all())
|
|
@@ -82,19 +72,21 @@ def load_mission_set(mission_set: str) -> list[AnyMission]:
|
|
|
82
72
|
missions_list.append(mission)
|
|
83
73
|
|
|
84
74
|
elif mission_set == "diagnostic_evals":
|
|
85
|
-
from cogames.cogs_vs_clips.evals.diagnostic_evals import DIAGNOSTIC_EVALS
|
|
75
|
+
from cogames.cogs_vs_clips.evals.diagnostic_evals import DIAGNOSTIC_EVALS # noqa: PLC0415
|
|
86
76
|
|
|
87
77
|
missions_list = [mission_cls() for mission_cls in DIAGNOSTIC_EVALS] # type: ignore[call-arg]
|
|
88
78
|
elif mission_set == "cogsguard_evals":
|
|
89
|
-
from cogames.cogs_vs_clips.evals.cogsguard_evals import COGSGUARD_EVAL_MISSIONS
|
|
79
|
+
from cogames.cogs_vs_clips.evals.cogsguard_evals import COGSGUARD_EVAL_MISSIONS # noqa: PLC0415
|
|
90
80
|
|
|
91
81
|
missions_list = list(COGSGUARD_EVAL_MISSIONS)
|
|
92
82
|
elif mission_set == "integrated_evals":
|
|
93
|
-
from cogames.cogs_vs_clips.evals.integrated_evals import
|
|
83
|
+
from cogames.cogs_vs_clips.evals.integrated_evals import ( # noqa: PLC0415
|
|
84
|
+
EVAL_MISSIONS as INTEGRATED_EVAL_MISSIONS,
|
|
85
|
+
)
|
|
94
86
|
|
|
95
87
|
missions_list = list(INTEGRATED_EVAL_MISSIONS)
|
|
96
88
|
elif mission_set == "spanning_evals":
|
|
97
|
-
from cogames.cogs_vs_clips.evals.spanning_evals import EVAL_MISSIONS as SPANNING_EVAL_MISSIONS
|
|
89
|
+
from cogames.cogs_vs_clips.evals.spanning_evals import EVAL_MISSIONS as SPANNING_EVAL_MISSIONS # noqa: PLC0415
|
|
98
90
|
|
|
99
91
|
missions_list = list(SPANNING_EVAL_MISSIONS)
|
|
100
92
|
else:
|
|
@@ -104,14 +96,14 @@ def load_mission_set(mission_set: str) -> list[AnyMission]:
|
|
|
104
96
|
return missions_list
|
|
105
97
|
|
|
106
98
|
|
|
107
|
-
def parse_variants(variants_arg: Optional[list[str]]) -> list[
|
|
99
|
+
def parse_variants(variants_arg: Optional[list[str]]) -> list[CoGameMissionVariant]:
|
|
108
100
|
"""Parse variant specifications from command line.
|
|
109
101
|
|
|
110
102
|
Args:
|
|
111
103
|
variants_arg: List of variant names like ["solar_flare", "dark_side"]
|
|
112
104
|
|
|
113
105
|
Returns:
|
|
114
|
-
List of configured
|
|
106
|
+
List of configured CoGameMissionVariant instances
|
|
115
107
|
|
|
116
108
|
Raises:
|
|
117
109
|
ValueError: If variant name is unknown
|
|
@@ -119,11 +111,11 @@ def parse_variants(variants_arg: Optional[list[str]]) -> list[MissionVariant]:
|
|
|
119
111
|
if not variants_arg:
|
|
120
112
|
return []
|
|
121
113
|
|
|
122
|
-
variants: list[
|
|
114
|
+
variants: list[CoGameMissionVariant] = []
|
|
123
115
|
all_variants = [*VARIANTS, *HIDDEN_VARIANTS]
|
|
124
116
|
for name in variants_arg:
|
|
125
117
|
# Find matching variant class by instantiating and checking the name
|
|
126
|
-
variant:
|
|
118
|
+
variant: CoGameMissionVariant | None = None
|
|
127
119
|
for v in all_variants:
|
|
128
120
|
if v.name == name:
|
|
129
121
|
variant = v
|
|
@@ -140,6 +132,12 @@ def parse_variants(variants_arg: Optional[list[str]]) -> list[MissionVariant]:
|
|
|
140
132
|
return variants
|
|
141
133
|
|
|
142
134
|
|
|
135
|
+
def parse_difficulty(difficulty_arg: Optional[str]) -> Optional[CoGameMissionVariant]:
|
|
136
|
+
if difficulty_arg is None:
|
|
137
|
+
return None
|
|
138
|
+
return get_cogsguard_difficulty(difficulty_arg)
|
|
139
|
+
|
|
140
|
+
|
|
143
141
|
def get_all_missions() -> list[str]:
|
|
144
142
|
"""Get all core mission names in the format site.mission (excludes evals)."""
|
|
145
143
|
return [mission.full_name() for mission in _get_core_missions()]
|
|
@@ -150,7 +148,7 @@ def get_all_eval_missions() -> list[str]:
|
|
|
150
148
|
return [mission.full_name() for mission in _get_eval_missions_all()]
|
|
151
149
|
|
|
152
150
|
|
|
153
|
-
def get_site_by_name(site_name: str) ->
|
|
151
|
+
def get_site_by_name(site_name: str) -> CoGameSite:
|
|
154
152
|
"""Get a site by name.
|
|
155
153
|
|
|
156
154
|
Raises:
|
|
@@ -165,17 +163,25 @@ def get_site_by_name(site_name: str) -> Site:
|
|
|
165
163
|
|
|
166
164
|
|
|
167
165
|
def get_mission_name_and_config(
|
|
168
|
-
ctx: typer.Context,
|
|
169
|
-
|
|
166
|
+
ctx: typer.Context,
|
|
167
|
+
mission_arg: Optional[str],
|
|
168
|
+
variants_arg: Optional[list[str]] = None,
|
|
169
|
+
cogs: Optional[int] = None,
|
|
170
|
+
difficulty: Optional[str] = None,
|
|
171
|
+
) -> tuple[str, MettaGridConfig, Optional[CvCMission]]:
|
|
170
172
|
if not mission_arg:
|
|
171
173
|
console.print(ctx.get_help())
|
|
172
174
|
console.print("[yellow]Missing: --mission / -m[/yellow]\n")
|
|
173
175
|
else:
|
|
174
176
|
try:
|
|
175
|
-
return get_mission(mission_arg, variants_arg, cogs)
|
|
177
|
+
return get_mission(mission_arg, variants_arg=variants_arg, cogs=cogs, difficulty=difficulty)
|
|
176
178
|
except ValueError as e:
|
|
177
|
-
|
|
178
|
-
|
|
179
|
+
error_msg = str(e)
|
|
180
|
+
if "variant" in error_msg.lower() or "difficulty" in error_msg.lower():
|
|
181
|
+
console.print(f"[red]{error_msg}[/red]")
|
|
182
|
+
else:
|
|
183
|
+
console.print(f"[red]Mission '{mission_arg}' not found.[/red]")
|
|
184
|
+
console.print("[dim]Run: cogames missions (or cogames missions <site>) to list options.[/dim]\n")
|
|
179
185
|
raise typer.Exit(1) from e
|
|
180
186
|
list_missions()
|
|
181
187
|
|
|
@@ -191,6 +197,7 @@ def get_mission_names_and_configs(
|
|
|
191
197
|
variants_arg: Optional[list[str]] = None,
|
|
192
198
|
cogs: Optional[int] = None,
|
|
193
199
|
steps: Optional[int] = None,
|
|
200
|
+
difficulty: Optional[str] = None,
|
|
194
201
|
) -> list[tuple[str, MettaGridConfig]]:
|
|
195
202
|
if not missions_arg:
|
|
196
203
|
console.print(ctx.get_help())
|
|
@@ -200,7 +207,7 @@ def get_mission_names_and_configs(
|
|
|
200
207
|
not_deduped = [
|
|
201
208
|
mission
|
|
202
209
|
for missions in missions_arg
|
|
203
|
-
for mission in _get_missions_by_possible_wildcard(missions, variants_arg, cogs)
|
|
210
|
+
for mission in _get_missions_by_possible_wildcard(missions, variants_arg, cogs, difficulty)
|
|
204
211
|
]
|
|
205
212
|
name_set: set[str] = set()
|
|
206
213
|
deduped = []
|
|
@@ -232,6 +239,7 @@ def _get_missions_by_possible_wildcard(
|
|
|
232
239
|
mission_arg: str,
|
|
233
240
|
variants_arg: Optional[list[str]],
|
|
234
241
|
cogs: Optional[int],
|
|
242
|
+
difficulty: Optional[str],
|
|
235
243
|
) -> list[tuple[str, MettaGridConfig]]:
|
|
236
244
|
if "*" in mission_arg:
|
|
237
245
|
# Convert shell-style wildcard to regex pattern
|
|
@@ -240,10 +248,12 @@ def _get_missions_by_possible_wildcard(
|
|
|
240
248
|
# Drop the Mission (3rd element) for wildcard results
|
|
241
249
|
return [
|
|
242
250
|
(name, env_cfg)
|
|
243
|
-
for name, env_cfg, _ in (
|
|
251
|
+
for name, env_cfg, _ in (
|
|
252
|
+
get_mission(m, variants_arg=variants_arg, cogs=cogs, difficulty=difficulty) for m in missions
|
|
253
|
+
)
|
|
244
254
|
]
|
|
245
255
|
# Drop the Mission for single mission
|
|
246
|
-
name, env_cfg, _ = get_mission(mission_arg, variants_arg=variants_arg, cogs=cogs)
|
|
256
|
+
name, env_cfg, _ = get_mission(mission_arg, variants_arg=variants_arg, cogs=cogs, difficulty=difficulty)
|
|
247
257
|
return [(name, env_cfg)]
|
|
248
258
|
|
|
249
259
|
|
|
@@ -253,12 +263,10 @@ def find_mission(
|
|
|
253
263
|
*,
|
|
254
264
|
include_evals: bool = False,
|
|
255
265
|
include_legacy: bool = False,
|
|
256
|
-
) ->
|
|
257
|
-
missions: list[
|
|
266
|
+
) -> CvCMission:
|
|
267
|
+
missions: list[CvCMission] = list(_get_core_missions())
|
|
258
268
|
if include_evals:
|
|
259
269
|
missions = [*missions, *_get_eval_missions_all()]
|
|
260
|
-
if include_legacy:
|
|
261
|
-
missions = [*missions, *_get_legacy_missions()]
|
|
262
270
|
|
|
263
271
|
found_site = False
|
|
264
272
|
for mission in missions:
|
|
@@ -288,7 +296,8 @@ def get_mission(
|
|
|
288
296
|
variants_arg: Optional[list[str]] = None,
|
|
289
297
|
cogs: Optional[int] = None,
|
|
290
298
|
include_legacy: bool = False,
|
|
291
|
-
|
|
299
|
+
difficulty: Optional[str] = None,
|
|
300
|
+
) -> tuple[str, MettaGridConfig, Optional[CvCMission]]:
|
|
292
301
|
"""Get a specific mission configuration by name or file path.
|
|
293
302
|
|
|
294
303
|
Args:
|
|
@@ -296,15 +305,18 @@ def get_mission(
|
|
|
296
305
|
variants_arg: List of variant names like ["solar_flare", "dark_side"]
|
|
297
306
|
cogs: Number of cogs (agents) to use, overrides the default from the mission
|
|
298
307
|
include_legacy: Whether to include legacy (pre-CogsGuard) missions
|
|
308
|
+
difficulty: Difficulty name (easy, medium, hard) controlling clips events
|
|
299
309
|
|
|
300
310
|
Returns:
|
|
301
|
-
Tuple of (mission name, MettaGridConfig,
|
|
311
|
+
Tuple of (mission name, MettaGridConfig, CvCMission or None)
|
|
302
312
|
|
|
303
313
|
Raises:
|
|
304
314
|
ValueError: If mission not found or file cannot be loaded
|
|
305
315
|
"""
|
|
306
316
|
# Check if it's a file path
|
|
307
317
|
if any(mission_arg.endswith(ext) for ext in [".yaml", ".yml", ".json", ".py"]):
|
|
318
|
+
if difficulty is not None:
|
|
319
|
+
raise ValueError("Difficulty can only be used with mission names, not config files.")
|
|
308
320
|
path = Path(mission_arg)
|
|
309
321
|
if not path.exists():
|
|
310
322
|
raise ValueError(f"File not found: {mission_arg}")
|
|
@@ -321,6 +333,7 @@ def get_mission(
|
|
|
321
333
|
|
|
322
334
|
# Parse variants if provided
|
|
323
335
|
variants = parse_variants(variants_arg)
|
|
336
|
+
difficulty_variant = parse_difficulty(difficulty)
|
|
324
337
|
|
|
325
338
|
# Otherwise, treat it as a fully qualified mission name, or as a site name
|
|
326
339
|
if (delim_count := mission_arg.count(MAP_MISSION_DELIMITER)) == 0:
|
|
@@ -330,8 +343,10 @@ def get_mission(
|
|
|
330
343
|
else:
|
|
331
344
|
site_name, mission_name = mission_arg.split(MAP_MISSION_DELIMITER)
|
|
332
345
|
|
|
333
|
-
mission:
|
|
346
|
+
mission: CvCMission = find_mission(site_name, mission_name, include_evals=True, include_legacy=include_legacy)
|
|
334
347
|
|
|
348
|
+
if difficulty_variant is not None:
|
|
349
|
+
mission = mission.with_variants([difficulty_variant])
|
|
335
350
|
if variants:
|
|
336
351
|
mission = mission.with_variants(variants)
|
|
337
352
|
if cogs is not None:
|
|
@@ -454,7 +469,7 @@ def list_evals() -> None:
|
|
|
454
469
|
return
|
|
455
470
|
|
|
456
471
|
# Group missions by site
|
|
457
|
-
missions_by_site: dict[str, list[
|
|
472
|
+
missions_by_site: dict[str, list[CvCMission]] = {}
|
|
458
473
|
for m in evals:
|
|
459
474
|
missions_by_site.setdefault(m.site.name, []).append(m)
|
|
460
475
|
|
|
@@ -506,7 +521,7 @@ def list_evals() -> None:
|
|
|
506
521
|
console.print(" [bold]cogames play[/bold] --mission [blue]evals.divide_and_conquer[/blue]")
|
|
507
522
|
|
|
508
523
|
|
|
509
|
-
def describe_mission(mission_name: str, game_config: MettaGridConfig, mission_cfg:
|
|
524
|
+
def describe_mission(mission_name: str, game_config: MettaGridConfig, mission_cfg: CvCMission | None = None) -> None:
|
|
510
525
|
"""Print detailed information about a specific mission.
|
|
511
526
|
|
|
512
527
|
Args:
|
cogames/cli/policy.py
CHANGED
|
@@ -68,13 +68,13 @@ def _translate_error(e: Exception) -> str:
|
|
|
68
68
|
return translated
|
|
69
69
|
|
|
70
70
|
|
|
71
|
-
def get_policy_spec(ctx: typer.Context, policy_arg: Optional[str]) -> PolicySpec:
|
|
71
|
+
def get_policy_spec(ctx: typer.Context, policy_arg: Optional[str], device: str | None = None) -> PolicySpec:
|
|
72
72
|
if policy_arg is None:
|
|
73
73
|
console.print(ctx.get_help())
|
|
74
74
|
console.print("[yellow]Missing: --policy / -p[/yellow]\n")
|
|
75
75
|
else:
|
|
76
76
|
try:
|
|
77
|
-
return parse_policy_spec(spec=policy_arg).to_policy_spec()
|
|
77
|
+
return parse_policy_spec(spec=policy_arg, device=device).to_policy_spec()
|
|
78
78
|
except (ValueError, ModuleNotFoundError) as e:
|
|
79
79
|
translated = _translate_error(e)
|
|
80
80
|
console.print(f"[yellow]Error parsing policy argument: {translated}[/yellow]\n")
|
|
@@ -90,14 +90,14 @@ def get_policy_spec(ctx: typer.Context, policy_arg: Optional[str]) -> PolicySpec
|
|
|
90
90
|
|
|
91
91
|
|
|
92
92
|
def get_policy_specs_with_proportions(
|
|
93
|
-
ctx: typer.Context, policy_args: Optional[list[str]]
|
|
93
|
+
ctx: typer.Context, policy_args: Optional[list[str]], device: str | None = None
|
|
94
94
|
) -> list[PolicySpecWithProportion]:
|
|
95
95
|
if not policy_args:
|
|
96
96
|
console.print(ctx.get_help())
|
|
97
97
|
console.print("[yellow]Supply at least one: --policy / -p[/yellow]\n")
|
|
98
98
|
else:
|
|
99
99
|
try:
|
|
100
|
-
return [parse_policy_spec(spec=policy_arg) for policy_arg in policy_args]
|
|
100
|
+
return [parse_policy_spec(spec=policy_arg, device=device) for policy_arg in policy_args]
|
|
101
101
|
except (ValueError, ModuleNotFoundError) as e:
|
|
102
102
|
translated = _translate_error(e)
|
|
103
103
|
console.print(f"[yellow]Error parsing policy argument: {translated}[/yellow]")
|
|
@@ -112,7 +112,20 @@ def get_policy_specs_with_proportions(
|
|
|
112
112
|
raise typer.Exit(1)
|
|
113
113
|
|
|
114
114
|
|
|
115
|
-
def
|
|
115
|
+
def _apply_device_override(spec: PolicySpecWithProportion, device: str | None) -> PolicySpecWithProportion:
|
|
116
|
+
if device is None:
|
|
117
|
+
return spec
|
|
118
|
+
init_kwargs = dict(spec.init_kwargs or {})
|
|
119
|
+
init_kwargs["device"] = device
|
|
120
|
+
return PolicySpecWithProportion(
|
|
121
|
+
class_path=spec.class_path,
|
|
122
|
+
data_path=spec.data_path,
|
|
123
|
+
proportion=spec.proportion,
|
|
124
|
+
init_kwargs=init_kwargs,
|
|
125
|
+
)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def parse_policy_spec(spec: str, device: str | None = None) -> PolicySpecWithProportion:
|
|
116
129
|
"""Parse a policy CLI option into its components.
|
|
117
130
|
|
|
118
131
|
Supports two formats:
|
|
@@ -161,13 +174,14 @@ def parse_policy_spec(spec: str) -> PolicySpecWithProportion:
|
|
|
161
174
|
uri_part = spec
|
|
162
175
|
fraction = 1.0
|
|
163
176
|
|
|
164
|
-
policy = policy_spec_from_uri(uri_part.strip())
|
|
165
|
-
|
|
177
|
+
policy = policy_spec_from_uri(uri_part.strip(), device=device or "cpu")
|
|
178
|
+
policy_spec = PolicySpecWithProportion(
|
|
166
179
|
class_path=policy.class_path,
|
|
167
180
|
data_path=policy.data_path,
|
|
168
181
|
proportion=fraction,
|
|
169
182
|
init_kwargs=policy.init_kwargs,
|
|
170
183
|
)
|
|
184
|
+
return _apply_device_override(policy_spec, device)
|
|
171
185
|
|
|
172
186
|
entries = [part.strip() for part in spec.split(",") if part.strip()]
|
|
173
187
|
if not entries:
|
|
@@ -176,19 +190,20 @@ def parse_policy_spec(spec: str) -> PolicySpecWithProportion:
|
|
|
176
190
|
fraction = 1.0
|
|
177
191
|
first = entries[0]
|
|
178
192
|
if is_uri(first) or ("=" not in first and is_path_like(first)):
|
|
179
|
-
policy = policy_spec_from_uri(first)
|
|
193
|
+
policy = policy_spec_from_uri(first, device=device or "cpu")
|
|
180
194
|
for entry in entries[1:]:
|
|
181
195
|
key, value = parse_key_value(entry)
|
|
182
196
|
if key != "proportion":
|
|
183
197
|
raise ValueError("Only proportion is supported after a checkpoint URI.")
|
|
184
198
|
fraction = parse_proportion(value)
|
|
185
199
|
|
|
186
|
-
|
|
200
|
+
policy_spec = PolicySpecWithProportion(
|
|
187
201
|
class_path=policy.class_path,
|
|
188
202
|
data_path=policy.data_path,
|
|
189
203
|
proportion=fraction,
|
|
190
204
|
init_kwargs=policy.init_kwargs,
|
|
191
205
|
)
|
|
206
|
+
return _apply_device_override(policy_spec, device)
|
|
192
207
|
|
|
193
208
|
if "=" not in first:
|
|
194
209
|
if ":" in first:
|
|
@@ -230,9 +245,10 @@ def parse_policy_spec(spec: str) -> PolicySpecWithProportion:
|
|
|
230
245
|
if class_path is None:
|
|
231
246
|
raise ValueError("Policy specification must include class= for key=value format.")
|
|
232
247
|
|
|
233
|
-
|
|
248
|
+
policy_spec = PolicySpecWithProportion(
|
|
234
249
|
class_path=class_path,
|
|
235
250
|
data_path=data_path,
|
|
236
251
|
proportion=fraction,
|
|
237
252
|
init_kwargs=init_kwargs,
|
|
238
253
|
)
|
|
254
|
+
return _apply_device_override(policy_spec, device)
|