cogames 0.3.65__py3-none-any.whl → 0.3.68__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134):
  1. cogames/cli/client.py +0 -3
  2. cogames/cli/docsync/docsync.py +7 -1
  3. cogames/cli/mission.py +44 -19
  4. cogames/cli/policy.py +26 -10
  5. cogames/cli/submit.py +127 -141
  6. cogames/cli/utils.py +5 -0
  7. cogames/cogs_vs_clips/clip_difficulty.py +57 -0
  8. cogames/cogs_vs_clips/clips.py +23 -6
  9. cogames/cogs_vs_clips/cog.py +16 -5
  10. cogames/cogs_vs_clips/cogsguard_curriculum.py +122 -0
  11. cogames/cogs_vs_clips/cogsguard_tutorial.py +5 -5
  12. cogames/cogs_vs_clips/config.py +1 -1
  13. cogames/cogs_vs_clips/docs/cogs_vs_clips_mapgen.md +2 -3
  14. cogames/cogs_vs_clips/evals/README.md +8 -32
  15. cogames/cogs_vs_clips/evals/diagnostic_evals.py +0 -1
  16. cogames/cogs_vs_clips/evals/difficulty_variants.py +7 -10
  17. cogames/cogs_vs_clips/mission.py +38 -10
  18. cogames/cogs_vs_clips/missions.py +1 -1
  19. cogames/cogs_vs_clips/reward_variants.py +173 -0
  20. cogames/cogs_vs_clips/sites.py +6 -5
  21. cogames/cogs_vs_clips/stations.py +13 -9
  22. cogames/cogs_vs_clips/team.py +3 -1
  23. cogames/cogs_vs_clips/terrain.py +2 -2
  24. cogames/cogs_vs_clips/variants.py +175 -4
  25. cogames/cogs_vs_clips/weather.py +52 -0
  26. cogames/docs/SCRIPTED_AGENT.md +3 -3
  27. cogames/evaluate.py +4 -2
  28. cogames/main.py +357 -51
  29. cogames/maps/canidate1_1000.map +1 -1
  30. cogames/maps/canidate1_1000_stations.map +2 -2
  31. cogames/maps/canidate1_500.map +1 -1
  32. cogames/maps/canidate1_500_stations.map +2 -2
  33. cogames/maps/canidate2_1000.map +1 -1
  34. cogames/maps/canidate2_1000_stations.map +2 -2
  35. cogames/maps/canidate2_500.map +1 -1
  36. cogames/maps/canidate2_500_stations.map +1 -1
  37. cogames/maps/canidate3_1000.map +1 -1
  38. cogames/maps/canidate3_1000_stations.map +2 -2
  39. cogames/maps/canidate3_500.map +1 -1
  40. cogames/maps/canidate3_500_stations.map +2 -2
  41. cogames/maps/canidate4_500.map +1 -1
  42. cogames/maps/canidate4_500_stations.map +2 -2
  43. cogames/maps/cave_base_50.map +2 -2
  44. cogames/maps/diagnostic_evals/diagnostic_agile.map +2 -2
  45. cogames/maps/diagnostic_evals/diagnostic_agile_hard.map +2 -2
  46. cogames/maps/diagnostic_evals/diagnostic_charge_up.map +6 -6
  47. cogames/maps/diagnostic_evals/diagnostic_charge_up_hard.map +6 -6
  48. cogames/maps/diagnostic_evals/diagnostic_chest_navigation1.map +6 -6
  49. cogames/maps/diagnostic_evals/diagnostic_chest_navigation1_hard.map +6 -6
  50. cogames/maps/diagnostic_evals/diagnostic_chest_navigation2.map +6 -6
  51. cogames/maps/diagnostic_evals/diagnostic_chest_navigation2_hard.map +6 -6
  52. cogames/maps/diagnostic_evals/diagnostic_chest_navigation3.map +6 -6
  53. cogames/maps/diagnostic_evals/diagnostic_chest_navigation3_hard.map +6 -6
  54. cogames/maps/diagnostic_evals/diagnostic_chest_near.map +6 -6
  55. cogames/maps/diagnostic_evals/diagnostic_chest_search.map +6 -6
  56. cogames/maps/diagnostic_evals/diagnostic_chest_search_hard.map +6 -6
  57. cogames/maps/diagnostic_evals/diagnostic_extract_lab.map +6 -6
  58. cogames/maps/diagnostic_evals/diagnostic_extract_lab_hard.map +6 -6
  59. cogames/maps/diagnostic_evals/diagnostic_memory.map +6 -6
  60. cogames/maps/diagnostic_evals/diagnostic_memory_hard.map +6 -6
  61. cogames/maps/diagnostic_evals/diagnostic_radial.map +2 -2
  62. cogames/maps/diagnostic_evals/diagnostic_radial_hard.map +2 -2
  63. cogames/maps/diagnostic_evals/diagnostic_resource_lab.map +6 -6
  64. cogames/maps/diagnostic_evals/diagnostic_unclip.map +6 -6
  65. cogames/maps/evals/eval_balanced_spread.map +6 -6
  66. cogames/maps/evals/eval_clip_oxygen.map +6 -6
  67. cogames/maps/evals/eval_collect_resources.map +6 -6
  68. cogames/maps/evals/eval_collect_resources_hard.map +6 -6
  69. cogames/maps/evals/eval_collect_resources_medium.map +6 -6
  70. cogames/maps/evals/eval_divide_and_conquer.map +6 -6
  71. cogames/maps/evals/eval_energy_starved.map +6 -6
  72. cogames/maps/evals/eval_multi_coordinated_collect_hard.map +6 -6
  73. cogames/maps/evals/eval_oxygen_bottleneck.map +6 -6
  74. cogames/maps/evals/eval_single_use_world.map +6 -6
  75. cogames/maps/evals/extractor_hub_100x100.map +6 -6
  76. cogames/maps/evals/extractor_hub_30x30.map +6 -6
  77. cogames/maps/evals/extractor_hub_50x50.map +6 -6
  78. cogames/maps/evals/extractor_hub_70x70.map +6 -6
  79. cogames/maps/evals/extractor_hub_80x80.map +6 -6
  80. cogames/maps/machina_100_stations.map +2 -2
  81. cogames/maps/machina_200_stations.map +2 -2
  82. cogames/maps/machina_200_stations_small.map +2 -2
  83. cogames/maps/machina_eval_exp01.map +2 -2
  84. cogames/maps/machina_eval_template_large.map +2 -2
  85. cogames/maps/machinatrainer4agents.map +2 -2
  86. cogames/maps/machinatrainer4agentsbase.map +2 -2
  87. cogames/maps/machinatrainerbig.map +2 -2
  88. cogames/maps/machinatrainersmall.map +2 -2
  89. cogames/maps/planky_evals/aligner_avoid_aoe.map +6 -6
  90. cogames/maps/planky_evals/aligner_full_cycle.map +6 -6
  91. cogames/maps/planky_evals/aligner_gear.map +6 -6
  92. cogames/maps/planky_evals/aligner_hearts.map +6 -6
  93. cogames/maps/planky_evals/aligner_junction.map +6 -6
  94. cogames/maps/planky_evals/exploration_distant.map +6 -6
  95. cogames/maps/planky_evals/maze.map +6 -6
  96. cogames/maps/planky_evals/miner_best_resource.map +6 -6
  97. cogames/maps/planky_evals/miner_deposit.map +6 -6
  98. cogames/maps/planky_evals/miner_extract.map +6 -6
  99. cogames/maps/planky_evals/miner_full_cycle.map +6 -6
  100. cogames/maps/planky_evals/miner_gear.map +6 -6
  101. cogames/maps/planky_evals/multi_role.map +6 -6
  102. cogames/maps/planky_evals/resource_chain.map +6 -6
  103. cogames/maps/planky_evals/scout_explore.map +6 -6
  104. cogames/maps/planky_evals/scout_gear.map +6 -6
  105. cogames/maps/planky_evals/scrambler_full_cycle.map +6 -6
  106. cogames/maps/planky_evals/scrambler_gear.map +6 -6
  107. cogames/maps/planky_evals/scrambler_target.map +6 -6
  108. cogames/maps/planky_evals/stuck_corridor.map +6 -6
  109. cogames/maps/planky_evals/survive_retreat.map +6 -6
  110. cogames/maps/training_facility_clipped.map +2 -2
  111. cogames/maps/training_facility_open_1.map +2 -2
  112. cogames/maps/training_facility_open_2.map +2 -2
  113. cogames/maps/training_facility_open_3.map +2 -2
  114. cogames/maps/training_facility_tight_4.map +2 -2
  115. cogames/maps/training_facility_tight_5.map +2 -2
  116. cogames/maps/vanilla_large.map +2 -2
  117. cogames/maps/vanilla_small.map +2 -2
  118. cogames/pickup.py +6 -5
  119. cogames/play.py +14 -16
  120. cogames/policy/nim_agents/__init__.py +0 -2
  121. cogames/policy/nim_agents/agents.py +0 -11
  122. cogames/policy/starter_agent.py +4 -1
  123. {cogames-0.3.65.dist-info → cogames-0.3.68.dist-info}/METADATA +45 -29
  124. cogames-0.3.68.dist-info/RECORD +160 -0
  125. metta_alo/scoring.py +7 -7
  126. cogames-0.3.65.dist-info/RECORD +0 -160
  127. metta_alo/job_specs.py +0 -17
  128. metta_alo/policy.py +0 -16
  129. metta_alo/pure_single_episode_runner.py +0 -75
  130. metta_alo/rollout.py +0 -322
  131. {cogames-0.3.65.dist-info → cogames-0.3.68.dist-info}/WHEEL +0 -0
  132. {cogames-0.3.65.dist-info → cogames-0.3.68.dist-info}/entry_points.txt +0 -0
  133. {cogames-0.3.65.dist-info → cogames-0.3.68.dist-info}/licenses/LICENSE +0 -0
  134. {cogames-0.3.65.dist-info → cogames-0.3.68.dist-info}/top_level.txt +0 -0
cogames/main.py CHANGED
@@ -8,6 +8,7 @@ from cogames.cli.utils import suppress_noisy_logs
8
8
 
9
9
  suppress_noisy_logs()
10
10
 
11
+ import importlib
11
12
  import importlib.metadata
12
13
  import importlib.util
13
14
  import json
@@ -17,6 +18,7 @@ import subprocess
17
18
  import sys
18
19
  import threading
19
20
  import time
21
+ from dataclasses import dataclass
20
22
  from pathlib import Path
21
23
  from typing import Literal, Optional, TypeVar
22
24
 
@@ -38,7 +40,6 @@ from cogames import play as play_module
38
40
  from cogames import train as train_module
39
41
  from cogames.cli.base import console
40
42
  from cogames.cli.client import SeasonInfo, TournamentServerClient, fetch_default_season, fetch_season_info
41
- from cogames.cli.docsync import docsync
42
43
  from cogames.cli.leaderboard import (
43
44
  leaderboard_cmd,
44
45
  parse_policy_identifier,
@@ -63,6 +64,7 @@ from cogames.cli.policy import (
63
64
  policy_arg_w_proportion_example,
64
65
  )
65
66
  from cogames.cli.submit import DEFAULT_SUBMIT_SERVER, results_url_for_season, upload_policy, validate_policy_spec
67
+ from cogames.cogs_vs_clips.mission import CvCMission, NumCogsVariant
66
68
  from cogames.curricula import make_rotation
67
69
  from cogames.device import resolve_training_device
68
70
  from mettagrid.config.mettagrid_config import MettaGridConfig
@@ -87,6 +89,158 @@ logger = logging.getLogger("cogames.main")
87
89
  T = TypeVar("T")
88
90
 
89
91
 
92
+ @dataclass(frozen=True)
93
+ class DiagnoseCase:
94
+ name: str
95
+ env_cfg: MettaGridConfig
96
+
97
+
98
+ def _load_eval_missions(module_path: str) -> list[CvCMission]:
99
+ module = importlib.import_module(module_path)
100
+ missions = getattr(module, "EVAL_MISSIONS", None)
101
+ if missions is None:
102
+ raise AttributeError(f"Module '{module_path}' does not define EVAL_MISSIONS")
103
+ return list(missions)
104
+
105
+
106
+ def _load_diagnose_missions(mission_set: str) -> list[CvCMission]:
107
+ if mission_set == "thinky_evals":
108
+ return []
109
+
110
+ if mission_set == "all":
111
+ from cogames.cogs_vs_clips.evals.cogsguard_evals import COGSGUARD_EVAL_MISSIONS # noqa: PLC0415
112
+ from cogames.cogs_vs_clips.evals.diagnostic_evals import DIAGNOSTIC_EVALS # noqa: PLC0415
113
+ from cogames.cogs_vs_clips.missions import MISSIONS as ALL_MISSIONS # noqa: PLC0415
114
+
115
+ missions_list: list[CvCMission] = []
116
+ missions_list.extend(COGSGUARD_EVAL_MISSIONS)
117
+ missions_list.extend(_load_eval_missions("cogames.cogs_vs_clips.evals.integrated_evals"))
118
+ missions_list.extend(_load_eval_missions("cogames.cogs_vs_clips.evals.spanning_evals"))
119
+ missions_list.extend([mission_cls() for mission_cls in DIAGNOSTIC_EVALS]) # type: ignore[call-arg]
120
+ eval_mission_names = {mission.name for mission in missions_list}
121
+ for mission in ALL_MISSIONS:
122
+ if mission.name not in eval_mission_names:
123
+ missions_list.append(mission)
124
+ return missions_list
125
+
126
+ if mission_set == "cogsguard_evals":
127
+ from cogames.cogs_vs_clips.evals.cogsguard_evals import COGSGUARD_EVAL_MISSIONS # noqa: PLC0415
128
+
129
+ return list(COGSGUARD_EVAL_MISSIONS)
130
+
131
+ if mission_set == "diagnostic_evals":
132
+ from cogames.cogs_vs_clips.evals.diagnostic_evals import DIAGNOSTIC_EVALS # noqa: PLC0415
133
+
134
+ return [mission_cls() for mission_cls in DIAGNOSTIC_EVALS] # type: ignore[call-arg]
135
+
136
+ if mission_set == "tournament":
137
+ from cogames.cogs_vs_clips.evals.diagnostic_evals import DIAGNOSTIC_EVALS # noqa: PLC0415
138
+
139
+ missions_list = []
140
+ missions_list.extend(_load_eval_missions("cogames.cogs_vs_clips.evals.integrated_evals"))
141
+ missions_list.extend([mission_cls() for mission_cls in DIAGNOSTIC_EVALS]) # type: ignore[call-arg]
142
+ return missions_list
143
+
144
+ if mission_set == "integrated_evals":
145
+ return _load_eval_missions("cogames.cogs_vs_clips.evals.integrated_evals")
146
+
147
+ if mission_set == "spanning_evals":
148
+ return _load_eval_missions("cogames.cogs_vs_clips.evals.spanning_evals")
149
+
150
+ raise ValueError(f"Unknown mission set: {mission_set}")
151
+
152
+
153
+ def _build_thinky_mission_map() -> dict[str, CvCMission]:
154
+ from cogames.cogs_vs_clips.evals.cogsguard_evals import COGSGUARD_EVAL_MISSIONS # noqa: PLC0415
155
+ from cogames.cogs_vs_clips.evals.diagnostic_evals import DIAGNOSTIC_EVALS # noqa: PLC0415
156
+ from cogames.cogs_vs_clips.missions import MISSIONS as ALL_MISSIONS # noqa: PLC0415
157
+
158
+ missions: list[CvCMission] = []
159
+ missions.extend(_load_eval_missions("cogames.cogs_vs_clips.evals.integrated_evals"))
160
+ missions.extend(_load_eval_missions("cogames.cogs_vs_clips.evals.spanning_evals"))
161
+ missions.extend([mission_cls() for mission_cls in DIAGNOSTIC_EVALS]) # type: ignore[call-arg]
162
+ missions.extend(COGSGUARD_EVAL_MISSIONS)
163
+ missions.extend(ALL_MISSIONS)
164
+
165
+ mission_map: dict[str, CvCMission] = {}
166
+ for mission in missions:
167
+ mission_map.setdefault(mission.name, mission)
168
+ return mission_map
169
+
170
+
171
+ def _matches_experiment(mission_name: str, experiment_filters: set[str]) -> bool:
172
+ if not experiment_filters:
173
+ return True
174
+ if mission_name in experiment_filters:
175
+ return True
176
+ suffix = f".{mission_name}"
177
+ return any(name.endswith(suffix) for name in experiment_filters)
178
+
179
+
180
+ def _cogs_for_mission(mission: CvCMission, cogs_list: list[int], respect_cogs_list: bool) -> list[int]:
181
+ fixed_cogs = getattr(mission, "num_cogs", None)
182
+ if fixed_cogs is not None:
183
+ if respect_cogs_list and fixed_cogs not in cogs_list:
184
+ return []
185
+ return [fixed_cogs]
186
+ site = getattr(mission, "site", None)
187
+ if site is None:
188
+ return list(cogs_list)
189
+ min_cogs = getattr(site, "min_cogs", None)
190
+ max_cogs = getattr(site, "max_cogs", None)
191
+ return [
192
+ num_cogs
193
+ for num_cogs in cogs_list
194
+ if (min_cogs is None or num_cogs >= min_cogs) and (max_cogs is None or num_cogs <= max_cogs)
195
+ ]
196
+
197
+
198
+ def _build_diagnose_case(mission: CvCMission, num_cogs: int, steps: int) -> DiagnoseCase:
199
+ mission_with_cogs = mission.with_variants([NumCogsVariant(num_cogs=num_cogs)])
200
+ env_cfg = mission_with_cogs.make_env()
201
+ env_cfg.game.max_steps = steps
202
+ name = f"{mission.full_name()} (cogs={num_cogs})"
203
+ return DiagnoseCase(name=name, env_cfg=env_cfg)
204
+
205
+
206
+ def _build_diagnose_cases(
207
+ *,
208
+ mission_set: str,
209
+ experiments: Optional[list[str]],
210
+ cogs: Optional[list[int]],
211
+ steps: int,
212
+ ) -> list[DiagnoseCase]:
213
+ experiment_filters = set(experiments or [])
214
+ cogs_list = cogs if cogs else [1, 2, 4]
215
+ respect_cogs_list = cogs is not None
216
+ cases: list[DiagnoseCase] = []
217
+
218
+ if mission_set == "thinky_evals":
219
+ from cogames_agents.policy.nim_agents.thinky_eval import EVALS as THINKY_EVALS # noqa: PLC0415
220
+
221
+ mission_map = _build_thinky_mission_map()
222
+ for exp_name, _tag, num_cogs in THINKY_EVALS:
223
+ if not _matches_experiment(exp_name, experiment_filters):
224
+ continue
225
+ if respect_cogs_list and num_cogs not in cogs_list:
226
+ continue
227
+ base_mission = mission_map.get(exp_name)
228
+ if base_mission is None:
229
+ logger.warning("Thinky eval mission '%s' not found; skipping.", exp_name)
230
+ continue
231
+ cases.append(_build_diagnose_case(base_mission, num_cogs, steps))
232
+ return cases
233
+
234
+ missions = _load_diagnose_missions(mission_set)
235
+ for mission in missions:
236
+ if not _matches_experiment(mission.name, experiment_filters):
237
+ continue
238
+ for num_cogs in _cogs_for_mission(mission, cogs_list, respect_cogs_list):
239
+ cases.append(_build_diagnose_case(mission, num_cogs, steps))
240
+
241
+ return cases
242
+
243
+
90
244
  def _resolve_mettascope_script() -> Path:
91
245
  spec = importlib.util.find_spec("mettagrid")
92
246
  if spec is None or spec.origin is None:
@@ -132,7 +286,18 @@ tutorial_app = typer.Typer(
132
286
  if register_tribal_cli is not None:
133
287
  register_tribal_cli(app)
134
288
 
135
- app.add_typer(docsync.app, name="docsync", hidden=True)
289
+
290
+ @app.command(
291
+ name="docsync",
292
+ hidden=True,
293
+ context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
294
+ add_help_option=False,
295
+ )
296
+ def docsync_cmd(ctx: typer.Context) -> None:
297
+ """Sync cogames docs between .ipynb, .py, and .md formats (dev-only)."""
298
+ from cogames.cli.docsync import docsync # noqa: PLC0415
299
+
300
+ docsync.app(prog_name="cogames docsync", standalone_mode=False, args=list(ctx.args))
136
301
 
137
302
 
138
303
  @tutorial_app.command(
@@ -160,7 +325,7 @@ def tutorial_cmd(
160
325
  console.print("[dim]Initializing Mettascope...[/dim]")
161
326
 
162
327
  # Load tutorial mission (CogsGuard)
163
- from cogames.cogs_vs_clips.missions import make_cogsguard_mission
328
+ from cogames.cogs_vs_clips.missions import make_cogsguard_mission # noqa: PLC0415
164
329
 
165
330
  # Create environment config
166
331
  env_cfg = make_cogsguard_mission(num_agents=1, max_steps=1000).make_env()
@@ -310,7 +475,7 @@ def cogsguard_tutorial_cmd(
310
475
  console.print("[dim]Initializing Mettascope...[/dim]")
311
476
 
312
477
  # Load CogsGuard tutorial mission
313
- from cogames.cogs_vs_clips.cogsguard_tutorial import CogsGuardTutorialMission
478
+ from cogames.cogs_vs_clips.cogsguard_tutorial import CogsGuardTutorialMission # noqa: PLC0415
314
479
 
315
480
  # Create environment config
316
481
  env_cfg = CogsGuardTutorialMission.make_env()
@@ -526,6 +691,13 @@ def games_cmd(
526
691
  help="Apply variant (requires -m, repeatable)",
527
692
  rich_help_panel="Describe",
528
693
  ),
694
+ difficulty: Optional[str] = typer.Option(
695
+ None,
696
+ "--difficulty",
697
+ metavar="LEVEL",
698
+ help="Difficulty (easy, medium, hard) controlling clips events (requires -m)",
699
+ rich_help_panel="Describe",
700
+ ),
529
701
  format_: Optional[Literal["yaml", "json"]] = typer.Option(
530
702
  None,
531
703
  "--format",
@@ -569,7 +741,13 @@ def games_cmd(
569
741
  return
570
742
 
571
743
  try:
572
- resolved_mission, env_cfg, mission_cfg = get_mission_name_and_config(ctx, mission, variant, cogs)
744
+ resolved_mission, env_cfg, mission_cfg = get_mission_name_and_config(
745
+ ctx,
746
+ mission,
747
+ variants_arg=variant,
748
+ cogs=cogs,
749
+ difficulty=difficulty,
750
+ )
573
751
  except typer.Exit as exc:
574
752
  if exc.exit_code != 1:
575
753
  raise
@@ -653,6 +831,13 @@ def describe_cmd(
653
831
  help="Apply variant (repeatable)",
654
832
  rich_help_panel="Configuration",
655
833
  ),
834
+ difficulty: Optional[str] = typer.Option(
835
+ None,
836
+ "--difficulty",
837
+ metavar="LEVEL",
838
+ help="Difficulty (easy, medium, hard) controlling clips events",
839
+ rich_help_panel="Configuration",
840
+ ),
656
841
  _help: bool = typer.Option(
657
842
  False,
658
843
  "--help",
@@ -663,7 +848,13 @@ def describe_cmd(
663
848
  rich_help_panel="Other",
664
849
  ),
665
850
  ) -> None:
666
- resolved_mission, env_cfg, mission_cfg = get_mission_name_and_config(ctx, mission, variant, cogs)
851
+ resolved_mission, env_cfg, mission_cfg = get_mission_name_and_config(
852
+ ctx,
853
+ mission,
854
+ variants_arg=variant,
855
+ cogs=cogs,
856
+ difficulty=difficulty,
857
+ )
667
858
  describe_mission(resolved_mission, env_cfg, mission_cfg)
668
859
 
669
860
 
@@ -711,6 +902,13 @@ def play_cmd(
711
902
  help="Apply variant modifier (repeatable)",
712
903
  rich_help_panel="Game Setup",
713
904
  ),
905
+ difficulty: Optional[str] = typer.Option(
906
+ None,
907
+ "--difficulty",
908
+ metavar="LEVEL",
909
+ help="Difficulty (easy, medium, hard) controlling clips events",
910
+ rich_help_panel="Game Setup",
911
+ ),
714
912
  cogs: Optional[int] = typer.Option(
715
913
  None,
716
914
  "--cogs",
@@ -729,6 +927,13 @@ def play_cmd(
729
927
  help="Policy controlling cogs ([bold]noop[/bold], [bold]random[/bold], [bold]lstm[/bold], or path)",
730
928
  rich_help_panel="Policy",
731
929
  ),
930
+ device: str = typer.Option(
931
+ "auto",
932
+ "--device",
933
+ metavar="DEVICE",
934
+ help="Policy device (auto, cpu, cuda, cuda:0, etc.)",
935
+ rich_help_panel="Policy",
936
+ ),
732
937
  # --- Simulation ---
733
938
  steps: int = typer.Option(
734
939
  1000,
@@ -762,6 +967,12 @@ def play_cmd(
762
967
  show_default="same as --seed",
763
968
  rich_help_panel="Simulation",
764
969
  ),
970
+ autostart: bool = typer.Option(
971
+ False,
972
+ "--autostart",
973
+ help="Start simulation immediately without waiting for user input",
974
+ rich_help_panel="Simulation",
975
+ ),
765
976
  # --- Output ---
766
977
  save_replay_dir: Optional[Path] = typer.Option( # noqa: B008
767
978
  None,
@@ -796,7 +1007,13 @@ def play_cmd(
796
1007
  rich_help_panel="Other",
797
1008
  ),
798
1009
  ) -> None:
799
- resolved_mission, env_cfg, mission_cfg = get_mission_name_and_config(ctx, mission, variant, cogs)
1010
+ resolved_mission, env_cfg, mission_cfg = get_mission_name_and_config(
1011
+ ctx,
1012
+ mission,
1013
+ variants_arg=variant,
1014
+ cogs=cogs,
1015
+ difficulty=difficulty,
1016
+ )
800
1017
 
801
1018
  if print_cvc_config or print_mg_config:
802
1019
  try:
@@ -811,9 +1028,8 @@ def play_cmd(
811
1028
  if isinstance(map_builder, MapGen.Config):
812
1029
  map_builder.seed = map_seed
813
1030
 
814
- policy_spec = get_policy_spec(ctx, policy)
815
- console.print(f"[cyan]Playing {resolved_mission}[/cyan]")
816
- console.print(f"Max Steps: {steps}, Render: {render}")
1031
+ resolved_device = resolve_training_device(console, device)
1032
+ policy_spec = get_policy_spec(ctx, policy, device=str(resolved_device))
817
1033
 
818
1034
  if ctx.get_parameter_source("steps") in (
819
1035
  ParameterSource.COMMANDLINE,
@@ -822,14 +1038,19 @@ def play_cmd(
822
1038
  ):
823
1039
  env_cfg.game.max_steps = steps
824
1040
 
1041
+ console.print(f"[cyan]Playing {resolved_mission}[/cyan]")
1042
+ console.print(f"Max Steps: {env_cfg.game.max_steps}, Render: {render}")
1043
+
825
1044
  play_module.play(
826
1045
  console,
827
1046
  env_cfg=env_cfg,
828
1047
  policy_spec=policy_spec,
829
1048
  seed=seed,
1049
+ device=str(resolved_device),
830
1050
  render_mode=render,
831
1051
  game_name=resolved_mission,
832
1052
  save_replay=save_replay_dir,
1053
+ autostart=autostart,
833
1054
  )
834
1055
 
835
1056
 
@@ -1151,6 +1372,13 @@ def train_cmd(
1151
1372
  help="Mission variant (repeatable)",
1152
1373
  rich_help_panel="Mission Setup",
1153
1374
  ),
1375
+ difficulty: Optional[str] = typer.Option(
1376
+ None,
1377
+ "--difficulty",
1378
+ metavar="LEVEL",
1379
+ help="Difficulty (easy, medium, hard) controlling clips events",
1380
+ rich_help_panel="Mission Setup",
1381
+ ),
1154
1382
  # --- Policy ---
1155
1383
  policy: str = typer.Option(
1156
1384
  "class=lstm",
@@ -1261,7 +1489,13 @@ def train_cmd(
1261
1489
  rich_help_panel="Other",
1262
1490
  ),
1263
1491
  ) -> None:
1264
- selected_missions = get_mission_names_and_configs(ctx, missions, variants_arg=variant, cogs=cogs)
1492
+ selected_missions = get_mission_names_and_configs(
1493
+ ctx,
1494
+ missions,
1495
+ variants_arg=variant,
1496
+ cogs=cogs,
1497
+ difficulty=difficulty,
1498
+ )
1265
1499
  if len(selected_missions) == 1:
1266
1500
  mission_name, env_cfg = selected_missions[0]
1267
1501
  supplier = None
@@ -1380,6 +1614,13 @@ def run_cmd(
1380
1614
  help="Mission variant (repeatable)",
1381
1615
  rich_help_panel="Mission",
1382
1616
  ),
1617
+ difficulty: Optional[str] = typer.Option(
1618
+ None,
1619
+ "--difficulty",
1620
+ metavar="LEVEL",
1621
+ help="Difficulty (easy, medium, hard) controlling clips events",
1622
+ rich_help_panel="Mission",
1623
+ ),
1383
1624
  # --- Policy ---
1384
1625
  policies: Optional[list[str]] = typer.Option( # noqa: B008
1385
1626
  None,
@@ -1389,6 +1630,13 @@ def run_cmd(
1389
1630
  help=f"Policies to evaluate: ({policy_arg_w_proportion_example}...)",
1390
1631
  rich_help_panel="Policy",
1391
1632
  ),
1633
+ device: str = typer.Option(
1634
+ "auto",
1635
+ "--device",
1636
+ metavar="DEVICE",
1637
+ help="Policy device (auto, cpu, cuda, cuda:0, etc.)",
1638
+ rich_help_panel="Policy",
1639
+ ),
1392
1640
  # --- Simulation ---
1393
1641
  episodes: int = typer.Option(
1394
1642
  10,
@@ -1400,12 +1648,13 @@ def run_cmd(
1400
1648
  rich_help_panel="Simulation",
1401
1649
  ),
1402
1650
  steps: Optional[int] = typer.Option(
1403
- 1000,
1651
+ None,
1404
1652
  "--steps",
1405
1653
  "-s",
1406
1654
  metavar="N",
1407
1655
  help="Max steps per episode",
1408
1656
  min=1,
1657
+ show_default="from mission",
1409
1658
  rich_help_panel="Simulation",
1410
1659
  ),
1411
1660
  seed: int = typer.Option(
@@ -1465,7 +1714,7 @@ def run_cmd(
1465
1714
  raise typer.Exit(1)
1466
1715
 
1467
1716
  if mission_set:
1468
- from cogames.cli.mission import load_mission_set
1717
+ from cogames.cli.mission import load_mission_set # noqa: PLC0415
1469
1718
 
1470
1719
  try:
1471
1720
  mission_objs = load_mission_set(mission_set)
@@ -1479,7 +1728,14 @@ def run_cmd(
1479
1728
  if cogs is None:
1480
1729
  cogs = 4
1481
1730
 
1482
- selected_missions = get_mission_names_and_configs(ctx, missions, variants_arg=variant, cogs=cogs, steps=steps)
1731
+ selected_missions = get_mission_names_and_configs(
1732
+ ctx,
1733
+ missions,
1734
+ variants_arg=variant,
1735
+ cogs=cogs,
1736
+ steps=steps,
1737
+ difficulty=difficulty,
1738
+ )
1483
1739
 
1484
1740
  # Optional MapGen seed override for procedural maps.
1485
1741
  if map_seed is not None:
@@ -1488,7 +1744,8 @@ def run_cmd(
1488
1744
  if isinstance(map_builder, MapGen.Config):
1489
1745
  map_builder.seed = map_seed
1490
1746
 
1491
- policy_specs = get_policy_specs_with_proportions(ctx, policies)
1747
+ resolved_device = resolve_training_device(console, device)
1748
+ policy_specs = get_policy_specs_with_proportions(ctx, policies, device=str(resolved_device))
1492
1749
 
1493
1750
  if ctx.info_name == "scrimmage":
1494
1751
  if len(policy_specs) != 1:
@@ -1510,6 +1767,7 @@ def run_cmd(
1510
1767
  action_timeout_ms=action_timeout_ms,
1511
1768
  episodes=episodes,
1512
1769
  seed=seed,
1770
+ device=str(resolved_device),
1513
1771
  output_format=format_,
1514
1772
  save_replay=str(save_replay_dir) if save_replay_dir else None,
1515
1773
  )
@@ -1552,6 +1810,13 @@ def pickup_cmd(
1552
1810
  help="Mission variant (repeatable)",
1553
1811
  rich_help_panel="Mission",
1554
1812
  ),
1813
+ difficulty: Optional[str] = typer.Option(
1814
+ None,
1815
+ "--difficulty",
1816
+ metavar="LEVEL",
1817
+ help="Difficulty (easy, medium, hard) controlling clips events",
1818
+ rich_help_panel="Mission",
1819
+ ),
1555
1820
  # --- Policy ---
1556
1821
  policy: Optional[str] = typer.Option(
1557
1822
  None,
@@ -1568,6 +1833,13 @@ def pickup_cmd(
1568
1833
  help="Pool policy (repeatable)",
1569
1834
  rich_help_panel="Policy",
1570
1835
  ),
1836
+ device: str = typer.Option(
1837
+ "auto",
1838
+ "--device",
1839
+ metavar="DEVICE",
1840
+ help="Policy device (auto, cpu, cuda, cuda:0, etc.)",
1841
+ rich_help_panel="Policy",
1842
+ ),
1571
1843
  # --- Simulation ---
1572
1844
  episodes: int = typer.Option(
1573
1845
  1,
@@ -1631,7 +1903,7 @@ def pickup_cmd(
1631
1903
  rich_help_panel="Other",
1632
1904
  ),
1633
1905
  ) -> None:
1634
- import httpx
1906
+ import httpx # noqa: PLC0415
1635
1907
 
1636
1908
  if policy is None:
1637
1909
  console.print(ctx.get_help())
@@ -1644,15 +1916,22 @@ def pickup_cmd(
1644
1916
  raise typer.Exit(1)
1645
1917
 
1646
1918
  # Resolve mission
1647
- resolved_mission, env_cfg, _ = get_mission_name_and_config(ctx, mission, variants_arg=variant, cogs=cogs)
1919
+ resolved_mission, env_cfg, _ = get_mission_name_and_config(
1920
+ ctx,
1921
+ mission,
1922
+ variants_arg=variant,
1923
+ cogs=cogs,
1924
+ difficulty=difficulty,
1925
+ )
1648
1926
  if steps is not None:
1649
1927
  env_cfg.game.max_steps = steps
1650
1928
 
1651
1929
  candidate_label = policy
1652
1930
  pool_labels = pool
1653
- candidate_spec = get_policy_spec(ctx, policy)
1931
+ resolved_device = resolve_training_device(console, device)
1932
+ candidate_spec = get_policy_spec(ctx, policy, device=str(resolved_device))
1654
1933
  try:
1655
- pool_specs = [parse_policy_spec(spec).to_policy_spec() for spec in pool]
1934
+ pool_specs = [parse_policy_spec(spec, device=str(resolved_device)).to_policy_spec() for spec in pool]
1656
1935
  except (ValueError, ModuleNotFoundError, httpx.HTTPError) as exc:
1657
1936
  translated = _translate_error(exc)
1658
1937
  console.print(f"[yellow]Error parsing pool policy: {translated}[/yellow]\n")
@@ -1669,6 +1948,7 @@ def pickup_cmd(
1669
1948
  map_seed=map_seed,
1670
1949
  action_timeout_ms=action_timeout_ms,
1671
1950
  save_replay_dir=save_replay_dir,
1951
+ device=str(resolved_device),
1672
1952
  candidate_label=candidate_label,
1673
1953
  pool_labels=pool_labels,
1674
1954
  )
@@ -1762,10 +2042,10 @@ def login_cmd(
1762
2042
  rich_help_panel="Other",
1763
2043
  ),
1764
2044
  ) -> None:
1765
- from urllib.parse import urlparse
2045
+ from urllib.parse import urlparse # noqa: PLC0415
1766
2046
 
1767
2047
  # Check if we already have a token
1768
- from cogames.auth import BaseCLIAuthenticator
2048
+ from cogames.auth import BaseCLIAuthenticator # noqa: PLC0415
1769
2049
 
1770
2050
  temp_auth = BaseCLIAuthenticator(
1771
2051
  token_file_name="cogames.yaml",
@@ -1823,7 +2103,9 @@ app.command(
1823
2103
  rich_help_panel="Evaluate",
1824
2104
  epilog="""[dim]Examples:[/dim]
1825
2105
 
1826
- [cyan]cogames diagnose ./train_dir/my_run[/cyan] Default diagnostics
2106
+ [cyan]cogames diagnose ./train_dir/my_run[/cyan] Default CogsGuard evals
2107
+
2108
+ [cyan]cogames diagnose lstm -S diagnostic_evals[/cyan] Diagnostic evals (non-CogsGuard)
1827
2109
 
1828
2110
  [cyan]cogames diagnose lstm -S tournament[/cyan] Tournament suite
1829
2111
 
@@ -1831,6 +2113,7 @@ app.command(
1831
2113
  add_help_option=False,
1832
2114
  )
1833
2115
  def diagnose_cmd(
2116
+ ctx: typer.Context,
1834
2117
  policy: str = typer.Argument(
1835
2118
  ...,
1836
2119
  metavar="POLICY",
@@ -1838,6 +2121,7 @@ def diagnose_cmd(
1838
2121
  ),
1839
2122
  # --- Evaluation ---
1840
2123
  mission_set: Literal[
2124
+ "cogsguard_evals",
1841
2125
  "diagnostic_evals",
1842
2126
  "integrated_evals",
1843
2127
  "spanning_evals",
@@ -1845,7 +2129,7 @@ def diagnose_cmd(
1845
2129
  "tournament",
1846
2130
  "all",
1847
2131
  ] = typer.Option(
1848
- "diagnostic_evals",
2132
+ "cogsguard_evals",
1849
2133
  "--mission-set",
1850
2134
  "-S",
1851
2135
  metavar="SET",
@@ -1867,6 +2151,13 @@ def diagnose_cmd(
1867
2151
  help="Agent counts to test (repeatable)",
1868
2152
  rich_help_panel="Evaluation",
1869
2153
  ),
2154
+ device: str = typer.Option(
2155
+ "auto",
2156
+ "--device",
2157
+ metavar="DEVICE",
2158
+ help="Policy device (auto, cpu, cuda, cuda:0, etc.)",
2159
+ rich_help_panel="Evaluation",
2160
+ ),
1870
2161
  # --- Simulation ---
1871
2162
  steps: int = typer.Option(
1872
2163
  1000,
@@ -1895,28 +2186,30 @@ def diagnose_cmd(
1895
2186
  rich_help_panel="Other",
1896
2187
  ),
1897
2188
  ) -> None:
1898
- script_path = Path(__file__).resolve().parents[2] / "scripts" / "run_evaluation.py"
1899
-
1900
- cmd = [sys.executable, str(script_path)]
1901
- cmd.extend(["--mission-set", mission_set])
1902
-
1903
- if experiments:
1904
- cmd.append("--experiments")
1905
- cmd.extend(experiments)
1906
-
1907
- if cogs:
1908
- cmd.append("--cogs")
1909
- cmd.extend(str(c) for c in cogs)
1910
-
1911
- cmd.extend(["--steps", str(steps)])
1912
- cmd.extend(["--repeats", str(episodes)])
1913
- cmd.append("--no-plots")
1914
-
1915
- cmd.extend(["--policy", policy])
2189
+ resolved_device = resolve_training_device(console, device)
2190
+ policy_spec = get_policy_spec(ctx, policy, device=str(resolved_device))
2191
+
2192
+ cases = _build_diagnose_cases(
2193
+ mission_set=mission_set,
2194
+ experiments=experiments,
2195
+ cogs=cogs,
2196
+ steps=steps,
2197
+ )
2198
+ if not cases:
2199
+ console.print("[red]No evaluation cases matched your filters.[/red]")
2200
+ raise typer.Exit(1)
1916
2201
 
1917
- console.print("[cyan]Running diagnostic evaluation...[/cyan]")
1918
- console.print(f"[dim]{' '.join(cmd)}[/dim]")
1919
- subprocess.run(cmd, check=True)
2202
+ console.print(f"[cyan]Running diagnostic evaluation ({len(cases)} cases)...[/cyan]")
2203
+ evaluate_module.evaluate(
2204
+ console,
2205
+ missions=[(case.name, case.env_cfg) for case in cases],
2206
+ policy_specs=[policy_spec],
2207
+ proportions=[1.0],
2208
+ action_timeout_ms=10000,
2209
+ episodes=episodes,
2210
+ seed=42,
2211
+ device=str(resolved_device),
2212
+ )
1920
2213
 
1921
2214
 
1922
2215
  def _resolve_season(server: str, season_name: str | None = None) -> SeasonInfo:
@@ -1950,6 +2243,13 @@ def validate_policy_cmd(
1950
2243
  help=f"Policy specification: {policy_arg_example}",
1951
2244
  rich_help_panel="Policy",
1952
2245
  ),
2246
+ device: str = typer.Option(
2247
+ "auto",
2248
+ "--device",
2249
+ metavar="DEVICE",
2250
+ help="Policy device (auto, cpu, cuda, cuda:0, etc.)",
2251
+ rich_help_panel="Policy",
2252
+ ),
1953
2253
  setup_script: Optional[str] = typer.Option(
1954
2254
  None,
1955
2255
  "--setup-script",
@@ -1991,9 +2291,9 @@ def validate_policy_cmd(
1991
2291
  env_cfg = MettaGridConfig.model_validate(config_data)
1992
2292
 
1993
2293
  if setup_script:
1994
- import subprocess
1995
- import sys
1996
- from pathlib import Path
2294
+ import subprocess # noqa: PLC0415
2295
+ import sys # noqa: PLC0415
2296
+ from pathlib import Path # noqa: PLC0415
1997
2297
 
1998
2298
  script_path = Path(setup_script)
1999
2299
  if not script_path.exists():
@@ -2012,8 +2312,14 @@ def validate_policy_cmd(
2012
2312
  raise typer.Exit(1)
2013
2313
  console.print("[green]Setup script completed[/green]")
2014
2314
 
2015
- policy_spec = get_policy_spec(ctx, policy)
2016
- validate_policy_spec(policy_spec, env_cfg)
2315
+ resolved_device = resolve_training_device(console, device)
2316
+ policy_spec = get_policy_spec(ctx, policy, device=str(resolved_device))
2317
+ validate_policy_spec(
2318
+ policy_spec,
2319
+ env_cfg,
2320
+ device=str(resolved_device),
2321
+ season=season_info.name,
2322
+ )
2017
2323
  console.print("[green]Policy validated successfully[/green]")
2018
2324
  raise typer.Exit(0)
2019
2325
 
@@ -2224,7 +2530,7 @@ def submit_cmd(
2224
2530
  rich_help_panel="Other",
2225
2531
  ),
2226
2532
  ) -> None:
2227
- import httpx
2533
+ import httpx # noqa: PLC0415
2228
2534
 
2229
2535
  season_info = _resolve_season(server, season)
2230
2536
  season_name = season_info.name
@@ -2322,7 +2628,7 @@ def docs_cmd(
2322
2628
 
2323
2629
  # If no argument provided, show available documents
2324
2630
  if doc_name is None:
2325
- from rich.table import Table
2631
+ from rich.table import Table # noqa: PLC0415
2326
2632
 
2327
2633
  console.print("\n[bold cyan]Available Documents:[/bold cyan]\n")
2328
2634
  table = Table(show_header=True, header_style="bold magenta", box=box.ROUNDED, padding=(0, 1))