cogames 0.3.64__py3-none-any.whl → 0.3.68__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. cogames/cli/client.py +0 -3
  2. cogames/cli/docsync/docsync.py +7 -1
  3. cogames/cli/mission.py +68 -53
  4. cogames/cli/policy.py +26 -10
  5. cogames/cli/submit.py +128 -142
  6. cogames/cli/utils.py +5 -0
  7. cogames/cogs_vs_clips/clip_difficulty.py +57 -0
  8. cogames/cogs_vs_clips/clips.py +103 -0
  9. cogames/cogs_vs_clips/cog.py +29 -11
  10. cogames/cogs_vs_clips/cogsguard_curriculum.py +122 -0
  11. cogames/cogs_vs_clips/cogsguard_tutorial.py +15 -16
  12. cogames/cogs_vs_clips/config.py +38 -0
  13. cogames/cogs_vs_clips/{cogs_vs_clips_mapgen.md → docs/cogs_vs_clips_mapgen.md} +8 -10
  14. cogames/cogs_vs_clips/evals/README.md +11 -35
  15. cogames/cogs_vs_clips/evals/cogsguard_evals.py +21 -6
  16. cogames/cogs_vs_clips/evals/diagnostic_evals.py +13 -101
  17. cogames/cogs_vs_clips/evals/difficulty_variants.py +16 -28
  18. cogames/cogs_vs_clips/evals/integrated_evals.py +8 -60
  19. cogames/cogs_vs_clips/evals/spanning_evals.py +48 -54
  20. cogames/cogs_vs_clips/mission.py +93 -277
  21. cogames/cogs_vs_clips/missions.py +17 -27
  22. cogames/cogs_vs_clips/{cogsguard_reward_variants.py → reward_variants.py} +22 -2
  23. cogames/cogs_vs_clips/sites.py +41 -30
  24. cogames/cogs_vs_clips/stations.py +39 -84
  25. cogames/cogs_vs_clips/team.py +46 -0
  26. cogames/cogs_vs_clips/{procedural.py → terrain.py} +14 -8
  27. cogames/cogs_vs_clips/variants.py +201 -107
  28. cogames/cogs_vs_clips/weather.py +52 -0
  29. cogames/core.py +87 -0
  30. cogames/docs/SCRIPTED_AGENT.md +3 -3
  31. cogames/evaluate.py +4 -2
  32. cogames/main.py +357 -51
  33. cogames/maps/canidate1_1000.map +1 -1
  34. cogames/maps/canidate1_1000_stations.map +2 -2
  35. cogames/maps/canidate1_500.map +1 -1
  36. cogames/maps/canidate1_500_stations.map +2 -2
  37. cogames/maps/canidate2_1000.map +1 -1
  38. cogames/maps/canidate2_1000_stations.map +2 -2
  39. cogames/maps/canidate2_500.map +1 -1
  40. cogames/maps/canidate2_500_stations.map +1 -1
  41. cogames/maps/canidate3_1000.map +1 -1
  42. cogames/maps/canidate3_1000_stations.map +2 -2
  43. cogames/maps/canidate3_500.map +1 -1
  44. cogames/maps/canidate3_500_stations.map +2 -2
  45. cogames/maps/canidate4_500.map +1 -1
  46. cogames/maps/canidate4_500_stations.map +2 -2
  47. cogames/maps/cave_base_50.map +2 -2
  48. cogames/maps/diagnostic_evals/diagnostic_agile.map +2 -2
  49. cogames/maps/diagnostic_evals/diagnostic_agile_hard.map +2 -2
  50. cogames/maps/diagnostic_evals/diagnostic_charge_up.map +6 -6
  51. cogames/maps/diagnostic_evals/diagnostic_charge_up_hard.map +6 -6
  52. cogames/maps/diagnostic_evals/diagnostic_chest_navigation1.map +6 -6
  53. cogames/maps/diagnostic_evals/diagnostic_chest_navigation1_hard.map +6 -6
  54. cogames/maps/diagnostic_evals/diagnostic_chest_navigation2.map +6 -6
  55. cogames/maps/diagnostic_evals/diagnostic_chest_navigation2_hard.map +6 -6
  56. cogames/maps/diagnostic_evals/diagnostic_chest_navigation3.map +6 -6
  57. cogames/maps/diagnostic_evals/diagnostic_chest_navigation3_hard.map +6 -6
  58. cogames/maps/diagnostic_evals/diagnostic_chest_near.map +6 -6
  59. cogames/maps/diagnostic_evals/diagnostic_chest_search.map +6 -6
  60. cogames/maps/diagnostic_evals/diagnostic_chest_search_hard.map +6 -6
  61. cogames/maps/diagnostic_evals/diagnostic_extract_lab.map +6 -6
  62. cogames/maps/diagnostic_evals/diagnostic_extract_lab_hard.map +6 -6
  63. cogames/maps/diagnostic_evals/diagnostic_memory.map +6 -6
  64. cogames/maps/diagnostic_evals/diagnostic_memory_hard.map +6 -6
  65. cogames/maps/diagnostic_evals/diagnostic_radial.map +2 -2
  66. cogames/maps/diagnostic_evals/diagnostic_radial_hard.map +2 -2
  67. cogames/maps/diagnostic_evals/diagnostic_resource_lab.map +6 -6
  68. cogames/maps/diagnostic_evals/diagnostic_unclip.map +6 -6
  69. cogames/maps/evals/eval_balanced_spread.map +6 -6
  70. cogames/maps/evals/eval_clip_oxygen.map +6 -6
  71. cogames/maps/evals/eval_collect_resources.map +6 -6
  72. cogames/maps/evals/eval_collect_resources_hard.map +6 -6
  73. cogames/maps/evals/eval_collect_resources_medium.map +6 -6
  74. cogames/maps/evals/eval_divide_and_conquer.map +6 -6
  75. cogames/maps/evals/eval_energy_starved.map +6 -6
  76. cogames/maps/evals/eval_multi_coordinated_collect_hard.map +6 -6
  77. cogames/maps/evals/eval_oxygen_bottleneck.map +6 -6
  78. cogames/maps/evals/eval_single_use_world.map +6 -6
  79. cogames/maps/evals/extractor_hub_100x100.map +6 -6
  80. cogames/maps/evals/extractor_hub_30x30.map +6 -6
  81. cogames/maps/evals/extractor_hub_50x50.map +6 -6
  82. cogames/maps/evals/extractor_hub_70x70.map +6 -6
  83. cogames/maps/evals/extractor_hub_80x80.map +6 -6
  84. cogames/maps/machina_100_stations.map +2 -2
  85. cogames/maps/machina_200_stations.map +2 -2
  86. cogames/maps/machina_200_stations_small.map +2 -2
  87. cogames/maps/machina_eval_exp01.map +2 -2
  88. cogames/maps/machina_eval_template_large.map +2 -2
  89. cogames/maps/machinatrainer4agents.map +2 -2
  90. cogames/maps/machinatrainer4agentsbase.map +2 -2
  91. cogames/maps/machinatrainerbig.map +2 -2
  92. cogames/maps/machinatrainersmall.map +2 -2
  93. cogames/maps/planky_evals/aligner_avoid_aoe.map +6 -6
  94. cogames/maps/planky_evals/aligner_full_cycle.map +6 -6
  95. cogames/maps/planky_evals/aligner_gear.map +6 -6
  96. cogames/maps/planky_evals/aligner_hearts.map +6 -6
  97. cogames/maps/planky_evals/aligner_junction.map +6 -6
  98. cogames/maps/planky_evals/exploration_distant.map +6 -6
  99. cogames/maps/planky_evals/maze.map +6 -6
  100. cogames/maps/planky_evals/miner_best_resource.map +6 -6
  101. cogames/maps/planky_evals/miner_deposit.map +6 -6
  102. cogames/maps/planky_evals/miner_extract.map +6 -6
  103. cogames/maps/planky_evals/miner_full_cycle.map +6 -6
  104. cogames/maps/planky_evals/miner_gear.map +6 -6
  105. cogames/maps/planky_evals/multi_role.map +6 -6
  106. cogames/maps/planky_evals/resource_chain.map +6 -6
  107. cogames/maps/planky_evals/scout_explore.map +6 -6
  108. cogames/maps/planky_evals/scout_gear.map +6 -6
  109. cogames/maps/planky_evals/scrambler_full_cycle.map +6 -6
  110. cogames/maps/planky_evals/scrambler_gear.map +6 -6
  111. cogames/maps/planky_evals/scrambler_target.map +6 -6
  112. cogames/maps/planky_evals/stuck_corridor.map +6 -6
  113. cogames/maps/planky_evals/survive_retreat.map +6 -6
  114. cogames/maps/training_facility_clipped.map +2 -2
  115. cogames/maps/training_facility_open_1.map +2 -2
  116. cogames/maps/training_facility_open_2.map +2 -2
  117. cogames/maps/training_facility_open_3.map +2 -2
  118. cogames/maps/training_facility_tight_4.map +2 -2
  119. cogames/maps/training_facility_tight_5.map +2 -2
  120. cogames/maps/vanilla_large.map +2 -2
  121. cogames/maps/vanilla_small.map +2 -2
  122. cogames/pickup.py +6 -5
  123. cogames/play.py +14 -16
  124. cogames/policy/nim_agents/__init__.py +0 -2
  125. cogames/policy/nim_agents/agents.py +0 -11
  126. cogames/policy/starter_agent.py +4 -1
  127. cogames/verbose.py +2 -2
  128. {cogames-0.3.64.dist-info → cogames-0.3.68.dist-info}/METADATA +45 -29
  129. cogames-0.3.68.dist-info/RECORD +160 -0
  130. metta_alo/scoring.py +7 -7
  131. cogames/cogs_vs_clips/mission_utils.py +0 -19
  132. cogames/cogs_vs_clips/tutorial_missions.py +0 -25
  133. cogames-0.3.64.dist-info/RECORD +0 -159
  134. metta_alo/job_specs.py +0 -17
  135. metta_alo/policy.py +0 -16
  136. metta_alo/pure_single_episode_runner.py +0 -75
  137. metta_alo/rollout.py +0 -322
  138. {cogames-0.3.64.dist-info → cogames-0.3.68.dist-info}/WHEEL +0 -0
  139. {cogames-0.3.64.dist-info → cogames-0.3.68.dist-info}/entry_points.txt +0 -0
  140. {cogames-0.3.64.dist-info → cogames-0.3.68.dist-info}/licenses/LICENSE +0 -0
  141. {cogames-0.3.64.dist-info → cogames-0.3.68.dist-info}/top_level.txt +0 -0
cogames/cli/client.py CHANGED
@@ -253,9 +253,6 @@ class TournamentServerClient:
253
253
  json=payload,
254
254
  )
255
255
 
256
- def update_policy_version_tags(self, policy_version_id: uuid.UUID, tags: dict[str, str]) -> dict[str, Any]:
257
- return self._put(f"/stats/policies/versions/{policy_version_id}/tags", json=tags)
258
-
259
256
  def get_policy_memberships(self, policy_version_id: uuid.UUID) -> list[dict[str, Any]]:
260
257
  return self._get(f"/tournament/policies/{policy_version_id}/memberships", list[dict[str, Any]])
261
258
 
@@ -120,7 +120,13 @@ def check(
120
120
  typer.echo("\nNotebook documentation is out of sync!", err=True)
121
121
  for error in errors:
122
122
  typer.echo(f" {error}", err=True)
123
- typer.echo("\nTo fix: run 'cogames docsync all'", err=True)
123
+ typer.echo(
124
+ "\nThis can happen if you modified cogames CLI flags (the README includes a command reference)"
125
+ "\nor changed code that affects notebook outputs."
126
+ "\n"
127
+ "\nTo fix: run 'cogames docsync all'",
128
+ err=True,
129
+ )
124
130
  raise typer.Exit(1)
125
131
 
126
132
  typer.echo("All notebooks are in sync!")
cogames/cli/mission.py CHANGED
@@ -8,52 +8,42 @@ from rich import box
8
8
  from rich.table import Table
9
9
 
10
10
  from cogames.cli.base import console
11
- from cogames.cogs_vs_clips.mission import (
12
- MAP_MISSION_DELIMITER,
13
- AnyMission,
14
- Mission,
15
- MissionVariant,
16
- NumCogsVariant,
17
- Site,
18
- )
19
- from cogames.cogs_vs_clips.procedural import MachinaArena
11
+ from cogames.cogs_vs_clips.clip_difficulty import get_cogsguard_difficulty
12
+ from cogames.cogs_vs_clips.mission import CvCMission, NumCogsVariant
20
13
  from cogames.cogs_vs_clips.sites import SITES
14
+ from cogames.cogs_vs_clips.terrain import MachinaArena
21
15
  from cogames.cogs_vs_clips.variants import HIDDEN_VARIANTS, VARIANTS
16
+ from cogames.core import (
17
+ MAP_MISSION_DELIMITER,
18
+ CoGameMissionVariant,
19
+ CoGameSite,
20
+ )
22
21
  from cogames.game import load_mission_config, load_mission_config_from_python
23
22
  from mettagrid import MettaGridConfig
24
23
  from mettagrid.mapgen.mapgen import MapGen
25
24
 
26
25
 
27
26
  @lru_cache(maxsize=1)
28
- def _get_core_missions() -> list[AnyMission]:
29
- from cogames.cogs_vs_clips.missions import get_core_missions
27
+ def _get_core_missions() -> list[CvCMission]:
28
+ from cogames.cogs_vs_clips.missions import get_core_missions # noqa: PLC0415
30
29
 
31
30
  return get_core_missions()
32
31
 
33
32
 
34
33
  @lru_cache(maxsize=1)
35
- def _get_legacy_missions() -> list[Mission]:
36
- from cogames.cogs_vs_clips.missions import get_legacy_missions
34
+ def _get_eval_missions_all() -> list[CvCMission]:
35
+ from cogames.cogs_vs_clips.evals.diagnostic_evals import DIAGNOSTIC_EVALS # noqa: PLC0415
36
+ from cogames.cogs_vs_clips.evals.integrated_evals import EVAL_MISSIONS as INTEGRATED_EVAL_MISSIONS # noqa: PLC0415
37
+ from cogames.cogs_vs_clips.evals.spanning_evals import EVAL_MISSIONS as SPANNING_EVAL_MISSIONS # noqa: PLC0415
37
38
 
38
- return get_legacy_missions()
39
-
40
-
41
- @lru_cache(maxsize=1)
42
- def _get_eval_missions_all() -> list[AnyMission]:
43
- from cogames.cogs_vs_clips.evals.cogsguard_evals import COGSGUARD_EVAL_MISSIONS
44
- from cogames.cogs_vs_clips.evals.diagnostic_evals import DIAGNOSTIC_EVALS
45
- from cogames.cogs_vs_clips.evals.integrated_evals import EVAL_MISSIONS as INTEGRATED_EVAL_MISSIONS
46
- from cogames.cogs_vs_clips.evals.spanning_evals import EVAL_MISSIONS as SPANNING_EVAL_MISSIONS
47
-
48
- missions: list[AnyMission] = []
49
- missions.extend(COGSGUARD_EVAL_MISSIONS)
39
+ missions: list[CvCMission] = []
50
40
  missions.extend(INTEGRATED_EVAL_MISSIONS)
51
41
  missions.extend(SPANNING_EVAL_MISSIONS)
52
42
  missions.extend(mission_cls() for mission_cls in DIAGNOSTIC_EVALS) # type: ignore[call-arg]
53
43
  return missions
54
44
 
55
45
 
56
- def load_mission_set(mission_set: str) -> list[AnyMission]:
46
+ def load_mission_set(mission_set: str) -> list[CvCMission]:
57
47
  """Load a predefined set of evaluation missions.
58
48
 
59
49
  Args:
@@ -70,7 +60,7 @@ def load_mission_set(mission_set: str) -> list[AnyMission]:
70
60
  Raises:
71
61
  ValueError: If mission_set name is unknown
72
62
  """
73
- missions_list: list[AnyMission]
63
+ missions_list: list[CvCMission]
74
64
  if mission_set == "all":
75
65
  # All missions: eval missions + integrated + spanning + diagnostic + core missions
76
66
  missions_list = list(_get_eval_missions_all())
@@ -82,19 +72,21 @@ def load_mission_set(mission_set: str) -> list[AnyMission]:
82
72
  missions_list.append(mission)
83
73
 
84
74
  elif mission_set == "diagnostic_evals":
85
- from cogames.cogs_vs_clips.evals.diagnostic_evals import DIAGNOSTIC_EVALS
75
+ from cogames.cogs_vs_clips.evals.diagnostic_evals import DIAGNOSTIC_EVALS # noqa: PLC0415
86
76
 
87
77
  missions_list = [mission_cls() for mission_cls in DIAGNOSTIC_EVALS] # type: ignore[call-arg]
88
78
  elif mission_set == "cogsguard_evals":
89
- from cogames.cogs_vs_clips.evals.cogsguard_evals import COGSGUARD_EVAL_MISSIONS
79
+ from cogames.cogs_vs_clips.evals.cogsguard_evals import COGSGUARD_EVAL_MISSIONS # noqa: PLC0415
90
80
 
91
81
  missions_list = list(COGSGUARD_EVAL_MISSIONS)
92
82
  elif mission_set == "integrated_evals":
93
- from cogames.cogs_vs_clips.evals.integrated_evals import EVAL_MISSIONS as INTEGRATED_EVAL_MISSIONS
83
+ from cogames.cogs_vs_clips.evals.integrated_evals import ( # noqa: PLC0415
84
+ EVAL_MISSIONS as INTEGRATED_EVAL_MISSIONS,
85
+ )
94
86
 
95
87
  missions_list = list(INTEGRATED_EVAL_MISSIONS)
96
88
  elif mission_set == "spanning_evals":
97
- from cogames.cogs_vs_clips.evals.spanning_evals import EVAL_MISSIONS as SPANNING_EVAL_MISSIONS
89
+ from cogames.cogs_vs_clips.evals.spanning_evals import EVAL_MISSIONS as SPANNING_EVAL_MISSIONS # noqa: PLC0415
98
90
 
99
91
  missions_list = list(SPANNING_EVAL_MISSIONS)
100
92
  else:
@@ -104,14 +96,14 @@ def load_mission_set(mission_set: str) -> list[AnyMission]:
104
96
  return missions_list
105
97
 
106
98
 
107
- def parse_variants(variants_arg: Optional[list[str]]) -> list[MissionVariant]:
99
+ def parse_variants(variants_arg: Optional[list[str]]) -> list[CoGameMissionVariant]:
108
100
  """Parse variant specifications from command line.
109
101
 
110
102
  Args:
111
103
  variants_arg: List of variant names like ["solar_flare", "dark_side"]
112
104
 
113
105
  Returns:
114
- List of configured MissionVariant instances
106
+ List of configured CoGameMissionVariant instances
115
107
 
116
108
  Raises:
117
109
  ValueError: If variant name is unknown
@@ -119,11 +111,11 @@ def parse_variants(variants_arg: Optional[list[str]]) -> list[MissionVariant]:
119
111
  if not variants_arg:
120
112
  return []
121
113
 
122
- variants: list[MissionVariant] = []
114
+ variants: list[CoGameMissionVariant] = []
123
115
  all_variants = [*VARIANTS, *HIDDEN_VARIANTS]
124
116
  for name in variants_arg:
125
117
  # Find matching variant class by instantiating and checking the name
126
- variant: MissionVariant | None = None
118
+ variant: CoGameMissionVariant | None = None
127
119
  for v in all_variants:
128
120
  if v.name == name:
129
121
  variant = v
@@ -140,6 +132,12 @@ def parse_variants(variants_arg: Optional[list[str]]) -> list[MissionVariant]:
140
132
  return variants
141
133
 
142
134
 
135
+ def parse_difficulty(difficulty_arg: Optional[str]) -> Optional[CoGameMissionVariant]:
136
+ if difficulty_arg is None:
137
+ return None
138
+ return get_cogsguard_difficulty(difficulty_arg)
139
+
140
+
143
141
  def get_all_missions() -> list[str]:
144
142
  """Get all core mission names in the format site.mission (excludes evals)."""
145
143
  return [mission.full_name() for mission in _get_core_missions()]
@@ -150,7 +148,7 @@ def get_all_eval_missions() -> list[str]:
150
148
  return [mission.full_name() for mission in _get_eval_missions_all()]
151
149
 
152
150
 
153
- def get_site_by_name(site_name: str) -> Site:
151
+ def get_site_by_name(site_name: str) -> CoGameSite:
154
152
  """Get a site by name.
155
153
 
156
154
  Raises:
@@ -165,17 +163,25 @@ def get_site_by_name(site_name: str) -> Site:
165
163
 
166
164
 
167
165
  def get_mission_name_and_config(
168
- ctx: typer.Context, mission_arg: Optional[str], variants_arg: Optional[list[str]] = None, cogs: Optional[int] = None
169
- ) -> tuple[str, MettaGridConfig, Optional[AnyMission]]:
166
+ ctx: typer.Context,
167
+ mission_arg: Optional[str],
168
+ variants_arg: Optional[list[str]] = None,
169
+ cogs: Optional[int] = None,
170
+ difficulty: Optional[str] = None,
171
+ ) -> tuple[str, MettaGridConfig, Optional[CvCMission]]:
170
172
  if not mission_arg:
171
173
  console.print(ctx.get_help())
172
174
  console.print("[yellow]Missing: --mission / -m[/yellow]\n")
173
175
  else:
174
176
  try:
175
- return get_mission(mission_arg, variants_arg, cogs)
177
+ return get_mission(mission_arg, variants_arg=variants_arg, cogs=cogs, difficulty=difficulty)
176
178
  except ValueError as e:
177
- console.print(f"[red]Mission '{mission_arg}' not found.[/red]")
178
- console.print("[dim]Run: cogames missions (or cogames missions <site>) to list options.[/dim]\n")
179
+ error_msg = str(e)
180
+ if "variant" in error_msg.lower() or "difficulty" in error_msg.lower():
181
+ console.print(f"[red]{error_msg}[/red]")
182
+ else:
183
+ console.print(f"[red]Mission '{mission_arg}' not found.[/red]")
184
+ console.print("[dim]Run: cogames missions (or cogames missions <site>) to list options.[/dim]\n")
179
185
  raise typer.Exit(1) from e
180
186
  list_missions()
181
187
 
@@ -191,6 +197,7 @@ def get_mission_names_and_configs(
191
197
  variants_arg: Optional[list[str]] = None,
192
198
  cogs: Optional[int] = None,
193
199
  steps: Optional[int] = None,
200
+ difficulty: Optional[str] = None,
194
201
  ) -> list[tuple[str, MettaGridConfig]]:
195
202
  if not missions_arg:
196
203
  console.print(ctx.get_help())
@@ -200,7 +207,7 @@ def get_mission_names_and_configs(
200
207
  not_deduped = [
201
208
  mission
202
209
  for missions in missions_arg
203
- for mission in _get_missions_by_possible_wildcard(missions, variants_arg, cogs)
210
+ for mission in _get_missions_by_possible_wildcard(missions, variants_arg, cogs, difficulty)
204
211
  ]
205
212
  name_set: set[str] = set()
206
213
  deduped = []
@@ -232,6 +239,7 @@ def _get_missions_by_possible_wildcard(
232
239
  mission_arg: str,
233
240
  variants_arg: Optional[list[str]],
234
241
  cogs: Optional[int],
242
+ difficulty: Optional[str],
235
243
  ) -> list[tuple[str, MettaGridConfig]]:
236
244
  if "*" in mission_arg:
237
245
  # Convert shell-style wildcard to regex pattern
@@ -240,10 +248,12 @@ def _get_missions_by_possible_wildcard(
240
248
  # Drop the Mission (3rd element) for wildcard results
241
249
  return [
242
250
  (name, env_cfg)
243
- for name, env_cfg, _ in (get_mission(m, variants_arg=variants_arg, cogs=cogs) for m in missions)
251
+ for name, env_cfg, _ in (
252
+ get_mission(m, variants_arg=variants_arg, cogs=cogs, difficulty=difficulty) for m in missions
253
+ )
244
254
  ]
245
255
  # Drop the Mission for single mission
246
- name, env_cfg, _ = get_mission(mission_arg, variants_arg=variants_arg, cogs=cogs)
256
+ name, env_cfg, _ = get_mission(mission_arg, variants_arg=variants_arg, cogs=cogs, difficulty=difficulty)
247
257
  return [(name, env_cfg)]
248
258
 
249
259
 
@@ -253,12 +263,10 @@ def find_mission(
253
263
  *,
254
264
  include_evals: bool = False,
255
265
  include_legacy: bool = False,
256
- ) -> AnyMission:
257
- missions: list[AnyMission] = list(_get_core_missions())
266
+ ) -> CvCMission:
267
+ missions: list[CvCMission] = list(_get_core_missions())
258
268
  if include_evals:
259
269
  missions = [*missions, *_get_eval_missions_all()]
260
- if include_legacy:
261
- missions = [*missions, *_get_legacy_missions()]
262
270
 
263
271
  found_site = False
264
272
  for mission in missions:
@@ -288,7 +296,8 @@ def get_mission(
288
296
  variants_arg: Optional[list[str]] = None,
289
297
  cogs: Optional[int] = None,
290
298
  include_legacy: bool = False,
291
- ) -> tuple[str, MettaGridConfig, Optional[AnyMission]]:
299
+ difficulty: Optional[str] = None,
300
+ ) -> tuple[str, MettaGridConfig, Optional[CvCMission]]:
292
301
  """Get a specific mission configuration by name or file path.
293
302
 
294
303
  Args:
@@ -296,15 +305,18 @@ def get_mission(
296
305
  variants_arg: List of variant names like ["solar_flare", "dark_side"]
297
306
  cogs: Number of cogs (agents) to use, overrides the default from the mission
298
307
  include_legacy: Whether to include legacy (pre-CogsGuard) missions
308
+ difficulty: Difficulty name (easy, medium, hard) controlling clips events
299
309
 
300
310
  Returns:
301
- Tuple of (mission name, MettaGridConfig, Mission or None)
311
+ Tuple of (mission name, MettaGridConfig, CvCMission or None)
302
312
 
303
313
  Raises:
304
314
  ValueError: If mission not found or file cannot be loaded
305
315
  """
306
316
  # Check if it's a file path
307
317
  if any(mission_arg.endswith(ext) for ext in [".yaml", ".yml", ".json", ".py"]):
318
+ if difficulty is not None:
319
+ raise ValueError("Difficulty can only be used with mission names, not config files.")
308
320
  path = Path(mission_arg)
309
321
  if not path.exists():
310
322
  raise ValueError(f"File not found: {mission_arg}")
@@ -321,6 +333,7 @@ def get_mission(
321
333
 
322
334
  # Parse variants if provided
323
335
  variants = parse_variants(variants_arg)
336
+ difficulty_variant = parse_difficulty(difficulty)
324
337
 
325
338
  # Otherwise, treat it as a fully qualified mission name, or as a site name
326
339
  if (delim_count := mission_arg.count(MAP_MISSION_DELIMITER)) == 0:
@@ -330,8 +343,10 @@ def get_mission(
330
343
  else:
331
344
  site_name, mission_name = mission_arg.split(MAP_MISSION_DELIMITER)
332
345
 
333
- mission: AnyMission = find_mission(site_name, mission_name, include_evals=True, include_legacy=include_legacy)
346
+ mission: CvCMission = find_mission(site_name, mission_name, include_evals=True, include_legacy=include_legacy)
334
347
 
348
+ if difficulty_variant is not None:
349
+ mission = mission.with_variants([difficulty_variant])
335
350
  if variants:
336
351
  mission = mission.with_variants(variants)
337
352
  if cogs is not None:
@@ -454,7 +469,7 @@ def list_evals() -> None:
454
469
  return
455
470
 
456
471
  # Group missions by site
457
- missions_by_site: dict[str, list[AnyMission]] = {}
472
+ missions_by_site: dict[str, list[CvCMission]] = {}
458
473
  for m in evals:
459
474
  missions_by_site.setdefault(m.site.name, []).append(m)
460
475
 
@@ -506,7 +521,7 @@ def list_evals() -> None:
506
521
  console.print(" [bold]cogames play[/bold] --mission [blue]evals.divide_and_conquer[/blue]")
507
522
 
508
523
 
509
- def describe_mission(mission_name: str, game_config: MettaGridConfig, mission_cfg: AnyMission | None = None) -> None:
524
+ def describe_mission(mission_name: str, game_config: MettaGridConfig, mission_cfg: CvCMission | None = None) -> None:
510
525
  """Print detailed information about a specific mission.
511
526
 
512
527
  Args:
cogames/cli/policy.py CHANGED
@@ -68,13 +68,13 @@ def _translate_error(e: Exception) -> str:
68
68
  return translated
69
69
 
70
70
 
71
- def get_policy_spec(ctx: typer.Context, policy_arg: Optional[str]) -> PolicySpec:
71
+ def get_policy_spec(ctx: typer.Context, policy_arg: Optional[str], device: str | None = None) -> PolicySpec:
72
72
  if policy_arg is None:
73
73
  console.print(ctx.get_help())
74
74
  console.print("[yellow]Missing: --policy / -p[/yellow]\n")
75
75
  else:
76
76
  try:
77
- return parse_policy_spec(spec=policy_arg).to_policy_spec()
77
+ return parse_policy_spec(spec=policy_arg, device=device).to_policy_spec()
78
78
  except (ValueError, ModuleNotFoundError) as e:
79
79
  translated = _translate_error(e)
80
80
  console.print(f"[yellow]Error parsing policy argument: {translated}[/yellow]\n")
@@ -90,14 +90,14 @@ def get_policy_spec(ctx: typer.Context, policy_arg: Optional[str]) -> PolicySpec
90
90
 
91
91
 
92
92
  def get_policy_specs_with_proportions(
93
- ctx: typer.Context, policy_args: Optional[list[str]]
93
+ ctx: typer.Context, policy_args: Optional[list[str]], device: str | None = None
94
94
  ) -> list[PolicySpecWithProportion]:
95
95
  if not policy_args:
96
96
  console.print(ctx.get_help())
97
97
  console.print("[yellow]Supply at least one: --policy / -p[/yellow]\n")
98
98
  else:
99
99
  try:
100
- return [parse_policy_spec(spec=policy_arg) for policy_arg in policy_args]
100
+ return [parse_policy_spec(spec=policy_arg, device=device) for policy_arg in policy_args]
101
101
  except (ValueError, ModuleNotFoundError) as e:
102
102
  translated = _translate_error(e)
103
103
  console.print(f"[yellow]Error parsing policy argument: {translated}[/yellow]")
@@ -112,7 +112,20 @@ def get_policy_specs_with_proportions(
112
112
  raise typer.Exit(1)
113
113
 
114
114
 
115
- def parse_policy_spec(spec: str) -> PolicySpecWithProportion:
115
+ def _apply_device_override(spec: PolicySpecWithProportion, device: str | None) -> PolicySpecWithProportion:
116
+ if device is None:
117
+ return spec
118
+ init_kwargs = dict(spec.init_kwargs or {})
119
+ init_kwargs["device"] = device
120
+ return PolicySpecWithProportion(
121
+ class_path=spec.class_path,
122
+ data_path=spec.data_path,
123
+ proportion=spec.proportion,
124
+ init_kwargs=init_kwargs,
125
+ )
126
+
127
+
128
+ def parse_policy_spec(spec: str, device: str | None = None) -> PolicySpecWithProportion:
116
129
  """Parse a policy CLI option into its components.
117
130
 
118
131
  Supports two formats:
@@ -161,13 +174,14 @@ def parse_policy_spec(spec: str) -> PolicySpecWithProportion:
161
174
  uri_part = spec
162
175
  fraction = 1.0
163
176
 
164
- policy = policy_spec_from_uri(uri_part.strip())
165
- return PolicySpecWithProportion(
177
+ policy = policy_spec_from_uri(uri_part.strip(), device=device or "cpu")
178
+ policy_spec = PolicySpecWithProportion(
166
179
  class_path=policy.class_path,
167
180
  data_path=policy.data_path,
168
181
  proportion=fraction,
169
182
  init_kwargs=policy.init_kwargs,
170
183
  )
184
+ return _apply_device_override(policy_spec, device)
171
185
 
172
186
  entries = [part.strip() for part in spec.split(",") if part.strip()]
173
187
  if not entries:
@@ -176,19 +190,20 @@ def parse_policy_spec(spec: str) -> PolicySpecWithProportion:
176
190
  fraction = 1.0
177
191
  first = entries[0]
178
192
  if is_uri(first) or ("=" not in first and is_path_like(first)):
179
- policy = policy_spec_from_uri(first)
193
+ policy = policy_spec_from_uri(first, device=device or "cpu")
180
194
  for entry in entries[1:]:
181
195
  key, value = parse_key_value(entry)
182
196
  if key != "proportion":
183
197
  raise ValueError("Only proportion is supported after a checkpoint URI.")
184
198
  fraction = parse_proportion(value)
185
199
 
186
- return PolicySpecWithProportion(
200
+ policy_spec = PolicySpecWithProportion(
187
201
  class_path=policy.class_path,
188
202
  data_path=policy.data_path,
189
203
  proportion=fraction,
190
204
  init_kwargs=policy.init_kwargs,
191
205
  )
206
+ return _apply_device_override(policy_spec, device)
192
207
 
193
208
  if "=" not in first:
194
209
  if ":" in first:
@@ -230,9 +245,10 @@ def parse_policy_spec(spec: str) -> PolicySpecWithProportion:
230
245
  if class_path is None:
231
246
  raise ValueError("Policy specification must include class= for key=value format.")
232
247
 
233
- return PolicySpecWithProportion(
248
+ policy_spec = PolicySpecWithProportion(
234
249
  class_path=class_path,
235
250
  data_path=data_path,
236
251
  proportion=fraction,
237
252
  init_kwargs=init_kwargs,
238
253
  )
254
+ return _apply_device_override(policy_spec, device)