cogames 0.3.64__py3-none-any.whl → 0.3.68__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cogames/cli/client.py +0 -3
- cogames/cli/docsync/docsync.py +7 -1
- cogames/cli/mission.py +68 -53
- cogames/cli/policy.py +26 -10
- cogames/cli/submit.py +128 -142
- cogames/cli/utils.py +5 -0
- cogames/cogs_vs_clips/clip_difficulty.py +57 -0
- cogames/cogs_vs_clips/clips.py +103 -0
- cogames/cogs_vs_clips/cog.py +29 -11
- cogames/cogs_vs_clips/cogsguard_curriculum.py +122 -0
- cogames/cogs_vs_clips/cogsguard_tutorial.py +15 -16
- cogames/cogs_vs_clips/config.py +38 -0
- cogames/cogs_vs_clips/{cogs_vs_clips_mapgen.md → docs/cogs_vs_clips_mapgen.md} +8 -10
- cogames/cogs_vs_clips/evals/README.md +11 -35
- cogames/cogs_vs_clips/evals/cogsguard_evals.py +21 -6
- cogames/cogs_vs_clips/evals/diagnostic_evals.py +13 -101
- cogames/cogs_vs_clips/evals/difficulty_variants.py +16 -28
- cogames/cogs_vs_clips/evals/integrated_evals.py +8 -60
- cogames/cogs_vs_clips/evals/spanning_evals.py +48 -54
- cogames/cogs_vs_clips/mission.py +93 -277
- cogames/cogs_vs_clips/missions.py +17 -27
- cogames/cogs_vs_clips/{cogsguard_reward_variants.py → reward_variants.py} +22 -2
- cogames/cogs_vs_clips/sites.py +41 -30
- cogames/cogs_vs_clips/stations.py +39 -84
- cogames/cogs_vs_clips/team.py +46 -0
- cogames/cogs_vs_clips/{procedural.py → terrain.py} +14 -8
- cogames/cogs_vs_clips/variants.py +201 -107
- cogames/cogs_vs_clips/weather.py +52 -0
- cogames/core.py +87 -0
- cogames/docs/SCRIPTED_AGENT.md +3 -3
- cogames/evaluate.py +4 -2
- cogames/main.py +357 -51
- cogames/maps/canidate1_1000.map +1 -1
- cogames/maps/canidate1_1000_stations.map +2 -2
- cogames/maps/canidate1_500.map +1 -1
- cogames/maps/canidate1_500_stations.map +2 -2
- cogames/maps/canidate2_1000.map +1 -1
- cogames/maps/canidate2_1000_stations.map +2 -2
- cogames/maps/canidate2_500.map +1 -1
- cogames/maps/canidate2_500_stations.map +1 -1
- cogames/maps/canidate3_1000.map +1 -1
- cogames/maps/canidate3_1000_stations.map +2 -2
- cogames/maps/canidate3_500.map +1 -1
- cogames/maps/canidate3_500_stations.map +2 -2
- cogames/maps/canidate4_500.map +1 -1
- cogames/maps/canidate4_500_stations.map +2 -2
- cogames/maps/cave_base_50.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_agile.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_agile_hard.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_charge_up.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_charge_up_hard.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_chest_navigation1.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_chest_navigation1_hard.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_chest_navigation2.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_chest_navigation2_hard.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_chest_navigation3.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_chest_navigation3_hard.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_chest_near.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_chest_search.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_chest_search_hard.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_extract_lab.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_extract_lab_hard.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_memory.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_memory_hard.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_radial.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_radial_hard.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_resource_lab.map +6 -6
- cogames/maps/diagnostic_evals/diagnostic_unclip.map +6 -6
- cogames/maps/evals/eval_balanced_spread.map +6 -6
- cogames/maps/evals/eval_clip_oxygen.map +6 -6
- cogames/maps/evals/eval_collect_resources.map +6 -6
- cogames/maps/evals/eval_collect_resources_hard.map +6 -6
- cogames/maps/evals/eval_collect_resources_medium.map +6 -6
- cogames/maps/evals/eval_divide_and_conquer.map +6 -6
- cogames/maps/evals/eval_energy_starved.map +6 -6
- cogames/maps/evals/eval_multi_coordinated_collect_hard.map +6 -6
- cogames/maps/evals/eval_oxygen_bottleneck.map +6 -6
- cogames/maps/evals/eval_single_use_world.map +6 -6
- cogames/maps/evals/extractor_hub_100x100.map +6 -6
- cogames/maps/evals/extractor_hub_30x30.map +6 -6
- cogames/maps/evals/extractor_hub_50x50.map +6 -6
- cogames/maps/evals/extractor_hub_70x70.map +6 -6
- cogames/maps/evals/extractor_hub_80x80.map +6 -6
- cogames/maps/machina_100_stations.map +2 -2
- cogames/maps/machina_200_stations.map +2 -2
- cogames/maps/machina_200_stations_small.map +2 -2
- cogames/maps/machina_eval_exp01.map +2 -2
- cogames/maps/machina_eval_template_large.map +2 -2
- cogames/maps/machinatrainer4agents.map +2 -2
- cogames/maps/machinatrainer4agentsbase.map +2 -2
- cogames/maps/machinatrainerbig.map +2 -2
- cogames/maps/machinatrainersmall.map +2 -2
- cogames/maps/planky_evals/aligner_avoid_aoe.map +6 -6
- cogames/maps/planky_evals/aligner_full_cycle.map +6 -6
- cogames/maps/planky_evals/aligner_gear.map +6 -6
- cogames/maps/planky_evals/aligner_hearts.map +6 -6
- cogames/maps/planky_evals/aligner_junction.map +6 -6
- cogames/maps/planky_evals/exploration_distant.map +6 -6
- cogames/maps/planky_evals/maze.map +6 -6
- cogames/maps/planky_evals/miner_best_resource.map +6 -6
- cogames/maps/planky_evals/miner_deposit.map +6 -6
- cogames/maps/planky_evals/miner_extract.map +6 -6
- cogames/maps/planky_evals/miner_full_cycle.map +6 -6
- cogames/maps/planky_evals/miner_gear.map +6 -6
- cogames/maps/planky_evals/multi_role.map +6 -6
- cogames/maps/planky_evals/resource_chain.map +6 -6
- cogames/maps/planky_evals/scout_explore.map +6 -6
- cogames/maps/planky_evals/scout_gear.map +6 -6
- cogames/maps/planky_evals/scrambler_full_cycle.map +6 -6
- cogames/maps/planky_evals/scrambler_gear.map +6 -6
- cogames/maps/planky_evals/scrambler_target.map +6 -6
- cogames/maps/planky_evals/stuck_corridor.map +6 -6
- cogames/maps/planky_evals/survive_retreat.map +6 -6
- cogames/maps/training_facility_clipped.map +2 -2
- cogames/maps/training_facility_open_1.map +2 -2
- cogames/maps/training_facility_open_2.map +2 -2
- cogames/maps/training_facility_open_3.map +2 -2
- cogames/maps/training_facility_tight_4.map +2 -2
- cogames/maps/training_facility_tight_5.map +2 -2
- cogames/maps/vanilla_large.map +2 -2
- cogames/maps/vanilla_small.map +2 -2
- cogames/pickup.py +6 -5
- cogames/play.py +14 -16
- cogames/policy/nim_agents/__init__.py +0 -2
- cogames/policy/nim_agents/agents.py +0 -11
- cogames/policy/starter_agent.py +4 -1
- cogames/verbose.py +2 -2
- {cogames-0.3.64.dist-info → cogames-0.3.68.dist-info}/METADATA +45 -29
- cogames-0.3.68.dist-info/RECORD +160 -0
- metta_alo/scoring.py +7 -7
- cogames/cogs_vs_clips/mission_utils.py +0 -19
- cogames/cogs_vs_clips/tutorial_missions.py +0 -25
- cogames-0.3.64.dist-info/RECORD +0 -159
- metta_alo/job_specs.py +0 -17
- metta_alo/policy.py +0 -16
- metta_alo/pure_single_episode_runner.py +0 -75
- metta_alo/rollout.py +0 -322
- {cogames-0.3.64.dist-info → cogames-0.3.68.dist-info}/WHEEL +0 -0
- {cogames-0.3.64.dist-info → cogames-0.3.68.dist-info}/entry_points.txt +0 -0
- {cogames-0.3.64.dist-info → cogames-0.3.68.dist-info}/licenses/LICENSE +0 -0
- {cogames-0.3.64.dist-info → cogames-0.3.68.dist-info}/top_level.txt +0 -0
cogames/main.py
CHANGED
|
@@ -8,6 +8,7 @@ from cogames.cli.utils import suppress_noisy_logs
|
|
|
8
8
|
|
|
9
9
|
suppress_noisy_logs()
|
|
10
10
|
|
|
11
|
+
import importlib
|
|
11
12
|
import importlib.metadata
|
|
12
13
|
import importlib.util
|
|
13
14
|
import json
|
|
@@ -17,6 +18,7 @@ import subprocess
|
|
|
17
18
|
import sys
|
|
18
19
|
import threading
|
|
19
20
|
import time
|
|
21
|
+
from dataclasses import dataclass
|
|
20
22
|
from pathlib import Path
|
|
21
23
|
from typing import Literal, Optional, TypeVar
|
|
22
24
|
|
|
@@ -38,7 +40,6 @@ from cogames import play as play_module
|
|
|
38
40
|
from cogames import train as train_module
|
|
39
41
|
from cogames.cli.base import console
|
|
40
42
|
from cogames.cli.client import SeasonInfo, TournamentServerClient, fetch_default_season, fetch_season_info
|
|
41
|
-
from cogames.cli.docsync import docsync
|
|
42
43
|
from cogames.cli.leaderboard import (
|
|
43
44
|
leaderboard_cmd,
|
|
44
45
|
parse_policy_identifier,
|
|
@@ -63,6 +64,7 @@ from cogames.cli.policy import (
|
|
|
63
64
|
policy_arg_w_proportion_example,
|
|
64
65
|
)
|
|
65
66
|
from cogames.cli.submit import DEFAULT_SUBMIT_SERVER, results_url_for_season, upload_policy, validate_policy_spec
|
|
67
|
+
from cogames.cogs_vs_clips.mission import CvCMission, NumCogsVariant
|
|
66
68
|
from cogames.curricula import make_rotation
|
|
67
69
|
from cogames.device import resolve_training_device
|
|
68
70
|
from mettagrid.config.mettagrid_config import MettaGridConfig
|
|
@@ -87,6 +89,158 @@ logger = logging.getLogger("cogames.main")
|
|
|
87
89
|
T = TypeVar("T")
|
|
88
90
|
|
|
89
91
|
|
|
92
|
+
@dataclass(frozen=True)
|
|
93
|
+
class DiagnoseCase:
|
|
94
|
+
name: str
|
|
95
|
+
env_cfg: MettaGridConfig
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _load_eval_missions(module_path: str) -> list[CvCMission]:
|
|
99
|
+
module = importlib.import_module(module_path)
|
|
100
|
+
missions = getattr(module, "EVAL_MISSIONS", None)
|
|
101
|
+
if missions is None:
|
|
102
|
+
raise AttributeError(f"Module '{module_path}' does not define EVAL_MISSIONS")
|
|
103
|
+
return list(missions)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _load_diagnose_missions(mission_set: str) -> list[CvCMission]:
|
|
107
|
+
if mission_set == "thinky_evals":
|
|
108
|
+
return []
|
|
109
|
+
|
|
110
|
+
if mission_set == "all":
|
|
111
|
+
from cogames.cogs_vs_clips.evals.cogsguard_evals import COGSGUARD_EVAL_MISSIONS # noqa: PLC0415
|
|
112
|
+
from cogames.cogs_vs_clips.evals.diagnostic_evals import DIAGNOSTIC_EVALS # noqa: PLC0415
|
|
113
|
+
from cogames.cogs_vs_clips.missions import MISSIONS as ALL_MISSIONS # noqa: PLC0415
|
|
114
|
+
|
|
115
|
+
missions_list: list[CvCMission] = []
|
|
116
|
+
missions_list.extend(COGSGUARD_EVAL_MISSIONS)
|
|
117
|
+
missions_list.extend(_load_eval_missions("cogames.cogs_vs_clips.evals.integrated_evals"))
|
|
118
|
+
missions_list.extend(_load_eval_missions("cogames.cogs_vs_clips.evals.spanning_evals"))
|
|
119
|
+
missions_list.extend([mission_cls() for mission_cls in DIAGNOSTIC_EVALS]) # type: ignore[call-arg]
|
|
120
|
+
eval_mission_names = {mission.name for mission in missions_list}
|
|
121
|
+
for mission in ALL_MISSIONS:
|
|
122
|
+
if mission.name not in eval_mission_names:
|
|
123
|
+
missions_list.append(mission)
|
|
124
|
+
return missions_list
|
|
125
|
+
|
|
126
|
+
if mission_set == "cogsguard_evals":
|
|
127
|
+
from cogames.cogs_vs_clips.evals.cogsguard_evals import COGSGUARD_EVAL_MISSIONS # noqa: PLC0415
|
|
128
|
+
|
|
129
|
+
return list(COGSGUARD_EVAL_MISSIONS)
|
|
130
|
+
|
|
131
|
+
if mission_set == "diagnostic_evals":
|
|
132
|
+
from cogames.cogs_vs_clips.evals.diagnostic_evals import DIAGNOSTIC_EVALS # noqa: PLC0415
|
|
133
|
+
|
|
134
|
+
return [mission_cls() for mission_cls in DIAGNOSTIC_EVALS] # type: ignore[call-arg]
|
|
135
|
+
|
|
136
|
+
if mission_set == "tournament":
|
|
137
|
+
from cogames.cogs_vs_clips.evals.diagnostic_evals import DIAGNOSTIC_EVALS # noqa: PLC0415
|
|
138
|
+
|
|
139
|
+
missions_list = []
|
|
140
|
+
missions_list.extend(_load_eval_missions("cogames.cogs_vs_clips.evals.integrated_evals"))
|
|
141
|
+
missions_list.extend([mission_cls() for mission_cls in DIAGNOSTIC_EVALS]) # type: ignore[call-arg]
|
|
142
|
+
return missions_list
|
|
143
|
+
|
|
144
|
+
if mission_set == "integrated_evals":
|
|
145
|
+
return _load_eval_missions("cogames.cogs_vs_clips.evals.integrated_evals")
|
|
146
|
+
|
|
147
|
+
if mission_set == "spanning_evals":
|
|
148
|
+
return _load_eval_missions("cogames.cogs_vs_clips.evals.spanning_evals")
|
|
149
|
+
|
|
150
|
+
raise ValueError(f"Unknown mission set: {mission_set}")
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def _build_thinky_mission_map() -> dict[str, CvCMission]:
|
|
154
|
+
from cogames.cogs_vs_clips.evals.cogsguard_evals import COGSGUARD_EVAL_MISSIONS # noqa: PLC0415
|
|
155
|
+
from cogames.cogs_vs_clips.evals.diagnostic_evals import DIAGNOSTIC_EVALS # noqa: PLC0415
|
|
156
|
+
from cogames.cogs_vs_clips.missions import MISSIONS as ALL_MISSIONS # noqa: PLC0415
|
|
157
|
+
|
|
158
|
+
missions: list[CvCMission] = []
|
|
159
|
+
missions.extend(_load_eval_missions("cogames.cogs_vs_clips.evals.integrated_evals"))
|
|
160
|
+
missions.extend(_load_eval_missions("cogames.cogs_vs_clips.evals.spanning_evals"))
|
|
161
|
+
missions.extend([mission_cls() for mission_cls in DIAGNOSTIC_EVALS]) # type: ignore[call-arg]
|
|
162
|
+
missions.extend(COGSGUARD_EVAL_MISSIONS)
|
|
163
|
+
missions.extend(ALL_MISSIONS)
|
|
164
|
+
|
|
165
|
+
mission_map: dict[str, CvCMission] = {}
|
|
166
|
+
for mission in missions:
|
|
167
|
+
mission_map.setdefault(mission.name, mission)
|
|
168
|
+
return mission_map
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def _matches_experiment(mission_name: str, experiment_filters: set[str]) -> bool:
|
|
172
|
+
if not experiment_filters:
|
|
173
|
+
return True
|
|
174
|
+
if mission_name in experiment_filters:
|
|
175
|
+
return True
|
|
176
|
+
suffix = f".{mission_name}"
|
|
177
|
+
return any(name.endswith(suffix) for name in experiment_filters)
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def _cogs_for_mission(mission: CvCMission, cogs_list: list[int], respect_cogs_list: bool) -> list[int]:
|
|
181
|
+
fixed_cogs = getattr(mission, "num_cogs", None)
|
|
182
|
+
if fixed_cogs is not None:
|
|
183
|
+
if respect_cogs_list and fixed_cogs not in cogs_list:
|
|
184
|
+
return []
|
|
185
|
+
return [fixed_cogs]
|
|
186
|
+
site = getattr(mission, "site", None)
|
|
187
|
+
if site is None:
|
|
188
|
+
return list(cogs_list)
|
|
189
|
+
min_cogs = getattr(site, "min_cogs", None)
|
|
190
|
+
max_cogs = getattr(site, "max_cogs", None)
|
|
191
|
+
return [
|
|
192
|
+
num_cogs
|
|
193
|
+
for num_cogs in cogs_list
|
|
194
|
+
if (min_cogs is None or num_cogs >= min_cogs) and (max_cogs is None or num_cogs <= max_cogs)
|
|
195
|
+
]
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def _build_diagnose_case(mission: CvCMission, num_cogs: int, steps: int) -> DiagnoseCase:
|
|
199
|
+
mission_with_cogs = mission.with_variants([NumCogsVariant(num_cogs=num_cogs)])
|
|
200
|
+
env_cfg = mission_with_cogs.make_env()
|
|
201
|
+
env_cfg.game.max_steps = steps
|
|
202
|
+
name = f"{mission.full_name()} (cogs={num_cogs})"
|
|
203
|
+
return DiagnoseCase(name=name, env_cfg=env_cfg)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def _build_diagnose_cases(
|
|
207
|
+
*,
|
|
208
|
+
mission_set: str,
|
|
209
|
+
experiments: Optional[list[str]],
|
|
210
|
+
cogs: Optional[list[int]],
|
|
211
|
+
steps: int,
|
|
212
|
+
) -> list[DiagnoseCase]:
|
|
213
|
+
experiment_filters = set(experiments or [])
|
|
214
|
+
cogs_list = cogs if cogs else [1, 2, 4]
|
|
215
|
+
respect_cogs_list = cogs is not None
|
|
216
|
+
cases: list[DiagnoseCase] = []
|
|
217
|
+
|
|
218
|
+
if mission_set == "thinky_evals":
|
|
219
|
+
from cogames_agents.policy.nim_agents.thinky_eval import EVALS as THINKY_EVALS # noqa: PLC0415
|
|
220
|
+
|
|
221
|
+
mission_map = _build_thinky_mission_map()
|
|
222
|
+
for exp_name, _tag, num_cogs in THINKY_EVALS:
|
|
223
|
+
if not _matches_experiment(exp_name, experiment_filters):
|
|
224
|
+
continue
|
|
225
|
+
if respect_cogs_list and num_cogs not in cogs_list:
|
|
226
|
+
continue
|
|
227
|
+
base_mission = mission_map.get(exp_name)
|
|
228
|
+
if base_mission is None:
|
|
229
|
+
logger.warning("Thinky eval mission '%s' not found; skipping.", exp_name)
|
|
230
|
+
continue
|
|
231
|
+
cases.append(_build_diagnose_case(base_mission, num_cogs, steps))
|
|
232
|
+
return cases
|
|
233
|
+
|
|
234
|
+
missions = _load_diagnose_missions(mission_set)
|
|
235
|
+
for mission in missions:
|
|
236
|
+
if not _matches_experiment(mission.name, experiment_filters):
|
|
237
|
+
continue
|
|
238
|
+
for num_cogs in _cogs_for_mission(mission, cogs_list, respect_cogs_list):
|
|
239
|
+
cases.append(_build_diagnose_case(mission, num_cogs, steps))
|
|
240
|
+
|
|
241
|
+
return cases
|
|
242
|
+
|
|
243
|
+
|
|
90
244
|
def _resolve_mettascope_script() -> Path:
|
|
91
245
|
spec = importlib.util.find_spec("mettagrid")
|
|
92
246
|
if spec is None or spec.origin is None:
|
|
@@ -132,7 +286,18 @@ tutorial_app = typer.Typer(
|
|
|
132
286
|
if register_tribal_cli is not None:
|
|
133
287
|
register_tribal_cli(app)
|
|
134
288
|
|
|
135
|
-
|
|
289
|
+
|
|
290
|
+
@app.command(
|
|
291
|
+
name="docsync",
|
|
292
|
+
hidden=True,
|
|
293
|
+
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
|
|
294
|
+
add_help_option=False,
|
|
295
|
+
)
|
|
296
|
+
def docsync_cmd(ctx: typer.Context) -> None:
|
|
297
|
+
"""Sync cogames docs between .ipynb, .py, and .md formats (dev-only)."""
|
|
298
|
+
from cogames.cli.docsync import docsync # noqa: PLC0415
|
|
299
|
+
|
|
300
|
+
docsync.app(prog_name="cogames docsync", standalone_mode=False, args=list(ctx.args))
|
|
136
301
|
|
|
137
302
|
|
|
138
303
|
@tutorial_app.command(
|
|
@@ -160,7 +325,7 @@ def tutorial_cmd(
|
|
|
160
325
|
console.print("[dim]Initializing Mettascope...[/dim]")
|
|
161
326
|
|
|
162
327
|
# Load tutorial mission (CogsGuard)
|
|
163
|
-
from cogames.cogs_vs_clips.missions import make_cogsguard_mission
|
|
328
|
+
from cogames.cogs_vs_clips.missions import make_cogsguard_mission # noqa: PLC0415
|
|
164
329
|
|
|
165
330
|
# Create environment config
|
|
166
331
|
env_cfg = make_cogsguard_mission(num_agents=1, max_steps=1000).make_env()
|
|
@@ -310,7 +475,7 @@ def cogsguard_tutorial_cmd(
|
|
|
310
475
|
console.print("[dim]Initializing Mettascope...[/dim]")
|
|
311
476
|
|
|
312
477
|
# Load CogsGuard tutorial mission
|
|
313
|
-
from cogames.cogs_vs_clips.cogsguard_tutorial import CogsGuardTutorialMission
|
|
478
|
+
from cogames.cogs_vs_clips.cogsguard_tutorial import CogsGuardTutorialMission # noqa: PLC0415
|
|
314
479
|
|
|
315
480
|
# Create environment config
|
|
316
481
|
env_cfg = CogsGuardTutorialMission.make_env()
|
|
@@ -526,6 +691,13 @@ def games_cmd(
|
|
|
526
691
|
help="Apply variant (requires -m, repeatable)",
|
|
527
692
|
rich_help_panel="Describe",
|
|
528
693
|
),
|
|
694
|
+
difficulty: Optional[str] = typer.Option(
|
|
695
|
+
None,
|
|
696
|
+
"--difficulty",
|
|
697
|
+
metavar="LEVEL",
|
|
698
|
+
help="Difficulty (easy, medium, hard) controlling clips events (requires -m)",
|
|
699
|
+
rich_help_panel="Describe",
|
|
700
|
+
),
|
|
529
701
|
format_: Optional[Literal["yaml", "json"]] = typer.Option(
|
|
530
702
|
None,
|
|
531
703
|
"--format",
|
|
@@ -569,7 +741,13 @@ def games_cmd(
|
|
|
569
741
|
return
|
|
570
742
|
|
|
571
743
|
try:
|
|
572
|
-
resolved_mission, env_cfg, mission_cfg = get_mission_name_and_config(
|
|
744
|
+
resolved_mission, env_cfg, mission_cfg = get_mission_name_and_config(
|
|
745
|
+
ctx,
|
|
746
|
+
mission,
|
|
747
|
+
variants_arg=variant,
|
|
748
|
+
cogs=cogs,
|
|
749
|
+
difficulty=difficulty,
|
|
750
|
+
)
|
|
573
751
|
except typer.Exit as exc:
|
|
574
752
|
if exc.exit_code != 1:
|
|
575
753
|
raise
|
|
@@ -653,6 +831,13 @@ def describe_cmd(
|
|
|
653
831
|
help="Apply variant (repeatable)",
|
|
654
832
|
rich_help_panel="Configuration",
|
|
655
833
|
),
|
|
834
|
+
difficulty: Optional[str] = typer.Option(
|
|
835
|
+
None,
|
|
836
|
+
"--difficulty",
|
|
837
|
+
metavar="LEVEL",
|
|
838
|
+
help="Difficulty (easy, medium, hard) controlling clips events",
|
|
839
|
+
rich_help_panel="Configuration",
|
|
840
|
+
),
|
|
656
841
|
_help: bool = typer.Option(
|
|
657
842
|
False,
|
|
658
843
|
"--help",
|
|
@@ -663,7 +848,13 @@ def describe_cmd(
|
|
|
663
848
|
rich_help_panel="Other",
|
|
664
849
|
),
|
|
665
850
|
) -> None:
|
|
666
|
-
resolved_mission, env_cfg, mission_cfg = get_mission_name_and_config(
|
|
851
|
+
resolved_mission, env_cfg, mission_cfg = get_mission_name_and_config(
|
|
852
|
+
ctx,
|
|
853
|
+
mission,
|
|
854
|
+
variants_arg=variant,
|
|
855
|
+
cogs=cogs,
|
|
856
|
+
difficulty=difficulty,
|
|
857
|
+
)
|
|
667
858
|
describe_mission(resolved_mission, env_cfg, mission_cfg)
|
|
668
859
|
|
|
669
860
|
|
|
@@ -711,6 +902,13 @@ def play_cmd(
|
|
|
711
902
|
help="Apply variant modifier (repeatable)",
|
|
712
903
|
rich_help_panel="Game Setup",
|
|
713
904
|
),
|
|
905
|
+
difficulty: Optional[str] = typer.Option(
|
|
906
|
+
None,
|
|
907
|
+
"--difficulty",
|
|
908
|
+
metavar="LEVEL",
|
|
909
|
+
help="Difficulty (easy, medium, hard) controlling clips events",
|
|
910
|
+
rich_help_panel="Game Setup",
|
|
911
|
+
),
|
|
714
912
|
cogs: Optional[int] = typer.Option(
|
|
715
913
|
None,
|
|
716
914
|
"--cogs",
|
|
@@ -729,6 +927,13 @@ def play_cmd(
|
|
|
729
927
|
help="Policy controlling cogs ([bold]noop[/bold], [bold]random[/bold], [bold]lstm[/bold], or path)",
|
|
730
928
|
rich_help_panel="Policy",
|
|
731
929
|
),
|
|
930
|
+
device: str = typer.Option(
|
|
931
|
+
"auto",
|
|
932
|
+
"--device",
|
|
933
|
+
metavar="DEVICE",
|
|
934
|
+
help="Policy device (auto, cpu, cuda, cuda:0, etc.)",
|
|
935
|
+
rich_help_panel="Policy",
|
|
936
|
+
),
|
|
732
937
|
# --- Simulation ---
|
|
733
938
|
steps: int = typer.Option(
|
|
734
939
|
1000,
|
|
@@ -762,6 +967,12 @@ def play_cmd(
|
|
|
762
967
|
show_default="same as --seed",
|
|
763
968
|
rich_help_panel="Simulation",
|
|
764
969
|
),
|
|
970
|
+
autostart: bool = typer.Option(
|
|
971
|
+
False,
|
|
972
|
+
"--autostart",
|
|
973
|
+
help="Start simulation immediately without waiting for user input",
|
|
974
|
+
rich_help_panel="Simulation",
|
|
975
|
+
),
|
|
765
976
|
# --- Output ---
|
|
766
977
|
save_replay_dir: Optional[Path] = typer.Option( # noqa: B008
|
|
767
978
|
None,
|
|
@@ -796,7 +1007,13 @@ def play_cmd(
|
|
|
796
1007
|
rich_help_panel="Other",
|
|
797
1008
|
),
|
|
798
1009
|
) -> None:
|
|
799
|
-
resolved_mission, env_cfg, mission_cfg = get_mission_name_and_config(
|
|
1010
|
+
resolved_mission, env_cfg, mission_cfg = get_mission_name_and_config(
|
|
1011
|
+
ctx,
|
|
1012
|
+
mission,
|
|
1013
|
+
variants_arg=variant,
|
|
1014
|
+
cogs=cogs,
|
|
1015
|
+
difficulty=difficulty,
|
|
1016
|
+
)
|
|
800
1017
|
|
|
801
1018
|
if print_cvc_config or print_mg_config:
|
|
802
1019
|
try:
|
|
@@ -811,9 +1028,8 @@ def play_cmd(
|
|
|
811
1028
|
if isinstance(map_builder, MapGen.Config):
|
|
812
1029
|
map_builder.seed = map_seed
|
|
813
1030
|
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
console.print(f"Max Steps: {steps}, Render: {render}")
|
|
1031
|
+
resolved_device = resolve_training_device(console, device)
|
|
1032
|
+
policy_spec = get_policy_spec(ctx, policy, device=str(resolved_device))
|
|
817
1033
|
|
|
818
1034
|
if ctx.get_parameter_source("steps") in (
|
|
819
1035
|
ParameterSource.COMMANDLINE,
|
|
@@ -822,14 +1038,19 @@ def play_cmd(
|
|
|
822
1038
|
):
|
|
823
1039
|
env_cfg.game.max_steps = steps
|
|
824
1040
|
|
|
1041
|
+
console.print(f"[cyan]Playing {resolved_mission}[/cyan]")
|
|
1042
|
+
console.print(f"Max Steps: {env_cfg.game.max_steps}, Render: {render}")
|
|
1043
|
+
|
|
825
1044
|
play_module.play(
|
|
826
1045
|
console,
|
|
827
1046
|
env_cfg=env_cfg,
|
|
828
1047
|
policy_spec=policy_spec,
|
|
829
1048
|
seed=seed,
|
|
1049
|
+
device=str(resolved_device),
|
|
830
1050
|
render_mode=render,
|
|
831
1051
|
game_name=resolved_mission,
|
|
832
1052
|
save_replay=save_replay_dir,
|
|
1053
|
+
autostart=autostart,
|
|
833
1054
|
)
|
|
834
1055
|
|
|
835
1056
|
|
|
@@ -1151,6 +1372,13 @@ def train_cmd(
|
|
|
1151
1372
|
help="Mission variant (repeatable)",
|
|
1152
1373
|
rich_help_panel="Mission Setup",
|
|
1153
1374
|
),
|
|
1375
|
+
difficulty: Optional[str] = typer.Option(
|
|
1376
|
+
None,
|
|
1377
|
+
"--difficulty",
|
|
1378
|
+
metavar="LEVEL",
|
|
1379
|
+
help="Difficulty (easy, medium, hard) controlling clips events",
|
|
1380
|
+
rich_help_panel="Mission Setup",
|
|
1381
|
+
),
|
|
1154
1382
|
# --- Policy ---
|
|
1155
1383
|
policy: str = typer.Option(
|
|
1156
1384
|
"class=lstm",
|
|
@@ -1261,7 +1489,13 @@ def train_cmd(
|
|
|
1261
1489
|
rich_help_panel="Other",
|
|
1262
1490
|
),
|
|
1263
1491
|
) -> None:
|
|
1264
|
-
selected_missions = get_mission_names_and_configs(
|
|
1492
|
+
selected_missions = get_mission_names_and_configs(
|
|
1493
|
+
ctx,
|
|
1494
|
+
missions,
|
|
1495
|
+
variants_arg=variant,
|
|
1496
|
+
cogs=cogs,
|
|
1497
|
+
difficulty=difficulty,
|
|
1498
|
+
)
|
|
1265
1499
|
if len(selected_missions) == 1:
|
|
1266
1500
|
mission_name, env_cfg = selected_missions[0]
|
|
1267
1501
|
supplier = None
|
|
@@ -1380,6 +1614,13 @@ def run_cmd(
|
|
|
1380
1614
|
help="Mission variant (repeatable)",
|
|
1381
1615
|
rich_help_panel="Mission",
|
|
1382
1616
|
),
|
|
1617
|
+
difficulty: Optional[str] = typer.Option(
|
|
1618
|
+
None,
|
|
1619
|
+
"--difficulty",
|
|
1620
|
+
metavar="LEVEL",
|
|
1621
|
+
help="Difficulty (easy, medium, hard) controlling clips events",
|
|
1622
|
+
rich_help_panel="Mission",
|
|
1623
|
+
),
|
|
1383
1624
|
# --- Policy ---
|
|
1384
1625
|
policies: Optional[list[str]] = typer.Option( # noqa: B008
|
|
1385
1626
|
None,
|
|
@@ -1389,6 +1630,13 @@ def run_cmd(
|
|
|
1389
1630
|
help=f"Policies to evaluate: ({policy_arg_w_proportion_example}...)",
|
|
1390
1631
|
rich_help_panel="Policy",
|
|
1391
1632
|
),
|
|
1633
|
+
device: str = typer.Option(
|
|
1634
|
+
"auto",
|
|
1635
|
+
"--device",
|
|
1636
|
+
metavar="DEVICE",
|
|
1637
|
+
help="Policy device (auto, cpu, cuda, cuda:0, etc.)",
|
|
1638
|
+
rich_help_panel="Policy",
|
|
1639
|
+
),
|
|
1392
1640
|
# --- Simulation ---
|
|
1393
1641
|
episodes: int = typer.Option(
|
|
1394
1642
|
10,
|
|
@@ -1400,12 +1648,13 @@ def run_cmd(
|
|
|
1400
1648
|
rich_help_panel="Simulation",
|
|
1401
1649
|
),
|
|
1402
1650
|
steps: Optional[int] = typer.Option(
|
|
1403
|
-
|
|
1651
|
+
None,
|
|
1404
1652
|
"--steps",
|
|
1405
1653
|
"-s",
|
|
1406
1654
|
metavar="N",
|
|
1407
1655
|
help="Max steps per episode",
|
|
1408
1656
|
min=1,
|
|
1657
|
+
show_default="from mission",
|
|
1409
1658
|
rich_help_panel="Simulation",
|
|
1410
1659
|
),
|
|
1411
1660
|
seed: int = typer.Option(
|
|
@@ -1465,7 +1714,7 @@ def run_cmd(
|
|
|
1465
1714
|
raise typer.Exit(1)
|
|
1466
1715
|
|
|
1467
1716
|
if mission_set:
|
|
1468
|
-
from cogames.cli.mission import load_mission_set
|
|
1717
|
+
from cogames.cli.mission import load_mission_set # noqa: PLC0415
|
|
1469
1718
|
|
|
1470
1719
|
try:
|
|
1471
1720
|
mission_objs = load_mission_set(mission_set)
|
|
@@ -1479,7 +1728,14 @@ def run_cmd(
|
|
|
1479
1728
|
if cogs is None:
|
|
1480
1729
|
cogs = 4
|
|
1481
1730
|
|
|
1482
|
-
selected_missions = get_mission_names_and_configs(
|
|
1731
|
+
selected_missions = get_mission_names_and_configs(
|
|
1732
|
+
ctx,
|
|
1733
|
+
missions,
|
|
1734
|
+
variants_arg=variant,
|
|
1735
|
+
cogs=cogs,
|
|
1736
|
+
steps=steps,
|
|
1737
|
+
difficulty=difficulty,
|
|
1738
|
+
)
|
|
1483
1739
|
|
|
1484
1740
|
# Optional MapGen seed override for procedural maps.
|
|
1485
1741
|
if map_seed is not None:
|
|
@@ -1488,7 +1744,8 @@ def run_cmd(
|
|
|
1488
1744
|
if isinstance(map_builder, MapGen.Config):
|
|
1489
1745
|
map_builder.seed = map_seed
|
|
1490
1746
|
|
|
1491
|
-
|
|
1747
|
+
resolved_device = resolve_training_device(console, device)
|
|
1748
|
+
policy_specs = get_policy_specs_with_proportions(ctx, policies, device=str(resolved_device))
|
|
1492
1749
|
|
|
1493
1750
|
if ctx.info_name == "scrimmage":
|
|
1494
1751
|
if len(policy_specs) != 1:
|
|
@@ -1510,6 +1767,7 @@ def run_cmd(
|
|
|
1510
1767
|
action_timeout_ms=action_timeout_ms,
|
|
1511
1768
|
episodes=episodes,
|
|
1512
1769
|
seed=seed,
|
|
1770
|
+
device=str(resolved_device),
|
|
1513
1771
|
output_format=format_,
|
|
1514
1772
|
save_replay=str(save_replay_dir) if save_replay_dir else None,
|
|
1515
1773
|
)
|
|
@@ -1552,6 +1810,13 @@ def pickup_cmd(
|
|
|
1552
1810
|
help="Mission variant (repeatable)",
|
|
1553
1811
|
rich_help_panel="Mission",
|
|
1554
1812
|
),
|
|
1813
|
+
difficulty: Optional[str] = typer.Option(
|
|
1814
|
+
None,
|
|
1815
|
+
"--difficulty",
|
|
1816
|
+
metavar="LEVEL",
|
|
1817
|
+
help="Difficulty (easy, medium, hard) controlling clips events",
|
|
1818
|
+
rich_help_panel="Mission",
|
|
1819
|
+
),
|
|
1555
1820
|
# --- Policy ---
|
|
1556
1821
|
policy: Optional[str] = typer.Option(
|
|
1557
1822
|
None,
|
|
@@ -1568,6 +1833,13 @@ def pickup_cmd(
|
|
|
1568
1833
|
help="Pool policy (repeatable)",
|
|
1569
1834
|
rich_help_panel="Policy",
|
|
1570
1835
|
),
|
|
1836
|
+
device: str = typer.Option(
|
|
1837
|
+
"auto",
|
|
1838
|
+
"--device",
|
|
1839
|
+
metavar="DEVICE",
|
|
1840
|
+
help="Policy device (auto, cpu, cuda, cuda:0, etc.)",
|
|
1841
|
+
rich_help_panel="Policy",
|
|
1842
|
+
),
|
|
1571
1843
|
# --- Simulation ---
|
|
1572
1844
|
episodes: int = typer.Option(
|
|
1573
1845
|
1,
|
|
@@ -1631,7 +1903,7 @@ def pickup_cmd(
|
|
|
1631
1903
|
rich_help_panel="Other",
|
|
1632
1904
|
),
|
|
1633
1905
|
) -> None:
|
|
1634
|
-
import httpx
|
|
1906
|
+
import httpx # noqa: PLC0415
|
|
1635
1907
|
|
|
1636
1908
|
if policy is None:
|
|
1637
1909
|
console.print(ctx.get_help())
|
|
@@ -1644,15 +1916,22 @@ def pickup_cmd(
|
|
|
1644
1916
|
raise typer.Exit(1)
|
|
1645
1917
|
|
|
1646
1918
|
# Resolve mission
|
|
1647
|
-
resolved_mission, env_cfg, _ = get_mission_name_and_config(
|
|
1919
|
+
resolved_mission, env_cfg, _ = get_mission_name_and_config(
|
|
1920
|
+
ctx,
|
|
1921
|
+
mission,
|
|
1922
|
+
variants_arg=variant,
|
|
1923
|
+
cogs=cogs,
|
|
1924
|
+
difficulty=difficulty,
|
|
1925
|
+
)
|
|
1648
1926
|
if steps is not None:
|
|
1649
1927
|
env_cfg.game.max_steps = steps
|
|
1650
1928
|
|
|
1651
1929
|
candidate_label = policy
|
|
1652
1930
|
pool_labels = pool
|
|
1653
|
-
|
|
1931
|
+
resolved_device = resolve_training_device(console, device)
|
|
1932
|
+
candidate_spec = get_policy_spec(ctx, policy, device=str(resolved_device))
|
|
1654
1933
|
try:
|
|
1655
|
-
pool_specs = [parse_policy_spec(spec).to_policy_spec() for spec in pool]
|
|
1934
|
+
pool_specs = [parse_policy_spec(spec, device=str(resolved_device)).to_policy_spec() for spec in pool]
|
|
1656
1935
|
except (ValueError, ModuleNotFoundError, httpx.HTTPError) as exc:
|
|
1657
1936
|
translated = _translate_error(exc)
|
|
1658
1937
|
console.print(f"[yellow]Error parsing pool policy: {translated}[/yellow]\n")
|
|
@@ -1669,6 +1948,7 @@ def pickup_cmd(
|
|
|
1669
1948
|
map_seed=map_seed,
|
|
1670
1949
|
action_timeout_ms=action_timeout_ms,
|
|
1671
1950
|
save_replay_dir=save_replay_dir,
|
|
1951
|
+
device=str(resolved_device),
|
|
1672
1952
|
candidate_label=candidate_label,
|
|
1673
1953
|
pool_labels=pool_labels,
|
|
1674
1954
|
)
|
|
@@ -1762,10 +2042,10 @@ def login_cmd(
|
|
|
1762
2042
|
rich_help_panel="Other",
|
|
1763
2043
|
),
|
|
1764
2044
|
) -> None:
|
|
1765
|
-
from urllib.parse import urlparse
|
|
2045
|
+
from urllib.parse import urlparse # noqa: PLC0415
|
|
1766
2046
|
|
|
1767
2047
|
# Check if we already have a token
|
|
1768
|
-
from cogames.auth import BaseCLIAuthenticator
|
|
2048
|
+
from cogames.auth import BaseCLIAuthenticator # noqa: PLC0415
|
|
1769
2049
|
|
|
1770
2050
|
temp_auth = BaseCLIAuthenticator(
|
|
1771
2051
|
token_file_name="cogames.yaml",
|
|
@@ -1823,7 +2103,9 @@ app.command(
|
|
|
1823
2103
|
rich_help_panel="Evaluate",
|
|
1824
2104
|
epilog="""[dim]Examples:[/dim]
|
|
1825
2105
|
|
|
1826
|
-
[cyan]cogames diagnose ./train_dir/my_run[/cyan] Default
|
|
2106
|
+
[cyan]cogames diagnose ./train_dir/my_run[/cyan] Default CogsGuard evals
|
|
2107
|
+
|
|
2108
|
+
[cyan]cogames diagnose lstm -S diagnostic_evals[/cyan] Diagnostic evals (non-CogsGuard)
|
|
1827
2109
|
|
|
1828
2110
|
[cyan]cogames diagnose lstm -S tournament[/cyan] Tournament suite
|
|
1829
2111
|
|
|
@@ -1831,6 +2113,7 @@ app.command(
|
|
|
1831
2113
|
add_help_option=False,
|
|
1832
2114
|
)
|
|
1833
2115
|
def diagnose_cmd(
|
|
2116
|
+
ctx: typer.Context,
|
|
1834
2117
|
policy: str = typer.Argument(
|
|
1835
2118
|
...,
|
|
1836
2119
|
metavar="POLICY",
|
|
@@ -1838,6 +2121,7 @@ def diagnose_cmd(
|
|
|
1838
2121
|
),
|
|
1839
2122
|
# --- Evaluation ---
|
|
1840
2123
|
mission_set: Literal[
|
|
2124
|
+
"cogsguard_evals",
|
|
1841
2125
|
"diagnostic_evals",
|
|
1842
2126
|
"integrated_evals",
|
|
1843
2127
|
"spanning_evals",
|
|
@@ -1845,7 +2129,7 @@ def diagnose_cmd(
|
|
|
1845
2129
|
"tournament",
|
|
1846
2130
|
"all",
|
|
1847
2131
|
] = typer.Option(
|
|
1848
|
-
"
|
|
2132
|
+
"cogsguard_evals",
|
|
1849
2133
|
"--mission-set",
|
|
1850
2134
|
"-S",
|
|
1851
2135
|
metavar="SET",
|
|
@@ -1867,6 +2151,13 @@ def diagnose_cmd(
|
|
|
1867
2151
|
help="Agent counts to test (repeatable)",
|
|
1868
2152
|
rich_help_panel="Evaluation",
|
|
1869
2153
|
),
|
|
2154
|
+
device: str = typer.Option(
|
|
2155
|
+
"auto",
|
|
2156
|
+
"--device",
|
|
2157
|
+
metavar="DEVICE",
|
|
2158
|
+
help="Policy device (auto, cpu, cuda, cuda:0, etc.)",
|
|
2159
|
+
rich_help_panel="Evaluation",
|
|
2160
|
+
),
|
|
1870
2161
|
# --- Simulation ---
|
|
1871
2162
|
steps: int = typer.Option(
|
|
1872
2163
|
1000,
|
|
@@ -1895,28 +2186,30 @@ def diagnose_cmd(
|
|
|
1895
2186
|
rich_help_panel="Other",
|
|
1896
2187
|
),
|
|
1897
2188
|
) -> None:
|
|
1898
|
-
|
|
1899
|
-
|
|
1900
|
-
|
|
1901
|
-
|
|
1902
|
-
|
|
1903
|
-
|
|
1904
|
-
|
|
1905
|
-
|
|
1906
|
-
|
|
1907
|
-
if
|
|
1908
|
-
|
|
1909
|
-
|
|
1910
|
-
|
|
1911
|
-
cmd.extend(["--steps", str(steps)])
|
|
1912
|
-
cmd.extend(["--repeats", str(episodes)])
|
|
1913
|
-
cmd.append("--no-plots")
|
|
1914
|
-
|
|
1915
|
-
cmd.extend(["--policy", policy])
|
|
2189
|
+
resolved_device = resolve_training_device(console, device)
|
|
2190
|
+
policy_spec = get_policy_spec(ctx, policy, device=str(resolved_device))
|
|
2191
|
+
|
|
2192
|
+
cases = _build_diagnose_cases(
|
|
2193
|
+
mission_set=mission_set,
|
|
2194
|
+
experiments=experiments,
|
|
2195
|
+
cogs=cogs,
|
|
2196
|
+
steps=steps,
|
|
2197
|
+
)
|
|
2198
|
+
if not cases:
|
|
2199
|
+
console.print("[red]No evaluation cases matched your filters.[/red]")
|
|
2200
|
+
raise typer.Exit(1)
|
|
1916
2201
|
|
|
1917
|
-
console.print("[cyan]Running diagnostic evaluation...[/cyan]")
|
|
1918
|
-
|
|
1919
|
-
|
|
2202
|
+
console.print(f"[cyan]Running diagnostic evaluation ({len(cases)} cases)...[/cyan]")
|
|
2203
|
+
evaluate_module.evaluate(
|
|
2204
|
+
console,
|
|
2205
|
+
missions=[(case.name, case.env_cfg) for case in cases],
|
|
2206
|
+
policy_specs=[policy_spec],
|
|
2207
|
+
proportions=[1.0],
|
|
2208
|
+
action_timeout_ms=10000,
|
|
2209
|
+
episodes=episodes,
|
|
2210
|
+
seed=42,
|
|
2211
|
+
device=str(resolved_device),
|
|
2212
|
+
)
|
|
1920
2213
|
|
|
1921
2214
|
|
|
1922
2215
|
def _resolve_season(server: str, season_name: str | None = None) -> SeasonInfo:
|
|
@@ -1950,6 +2243,13 @@ def validate_policy_cmd(
|
|
|
1950
2243
|
help=f"Policy specification: {policy_arg_example}",
|
|
1951
2244
|
rich_help_panel="Policy",
|
|
1952
2245
|
),
|
|
2246
|
+
device: str = typer.Option(
|
|
2247
|
+
"auto",
|
|
2248
|
+
"--device",
|
|
2249
|
+
metavar="DEVICE",
|
|
2250
|
+
help="Policy device (auto, cpu, cuda, cuda:0, etc.)",
|
|
2251
|
+
rich_help_panel="Policy",
|
|
2252
|
+
),
|
|
1953
2253
|
setup_script: Optional[str] = typer.Option(
|
|
1954
2254
|
None,
|
|
1955
2255
|
"--setup-script",
|
|
@@ -1991,9 +2291,9 @@ def validate_policy_cmd(
|
|
|
1991
2291
|
env_cfg = MettaGridConfig.model_validate(config_data)
|
|
1992
2292
|
|
|
1993
2293
|
if setup_script:
|
|
1994
|
-
import subprocess
|
|
1995
|
-
import sys
|
|
1996
|
-
from pathlib import Path
|
|
2294
|
+
import subprocess # noqa: PLC0415
|
|
2295
|
+
import sys # noqa: PLC0415
|
|
2296
|
+
from pathlib import Path # noqa: PLC0415
|
|
1997
2297
|
|
|
1998
2298
|
script_path = Path(setup_script)
|
|
1999
2299
|
if not script_path.exists():
|
|
@@ -2012,8 +2312,14 @@ def validate_policy_cmd(
|
|
|
2012
2312
|
raise typer.Exit(1)
|
|
2013
2313
|
console.print("[green]Setup script completed[/green]")
|
|
2014
2314
|
|
|
2015
|
-
|
|
2016
|
-
|
|
2315
|
+
resolved_device = resolve_training_device(console, device)
|
|
2316
|
+
policy_spec = get_policy_spec(ctx, policy, device=str(resolved_device))
|
|
2317
|
+
validate_policy_spec(
|
|
2318
|
+
policy_spec,
|
|
2319
|
+
env_cfg,
|
|
2320
|
+
device=str(resolved_device),
|
|
2321
|
+
season=season_info.name,
|
|
2322
|
+
)
|
|
2017
2323
|
console.print("[green]Policy validated successfully[/green]")
|
|
2018
2324
|
raise typer.Exit(0)
|
|
2019
2325
|
|
|
@@ -2224,7 +2530,7 @@ def submit_cmd(
|
|
|
2224
2530
|
rich_help_panel="Other",
|
|
2225
2531
|
),
|
|
2226
2532
|
) -> None:
|
|
2227
|
-
import httpx
|
|
2533
|
+
import httpx # noqa: PLC0415
|
|
2228
2534
|
|
|
2229
2535
|
season_info = _resolve_season(server, season)
|
|
2230
2536
|
season_name = season_info.name
|
|
@@ -2322,7 +2628,7 @@ def docs_cmd(
|
|
|
2322
2628
|
|
|
2323
2629
|
# If no argument provided, show available documents
|
|
2324
2630
|
if doc_name is None:
|
|
2325
|
-
from rich.table import Table
|
|
2631
|
+
from rich.table import Table # noqa: PLC0415
|
|
2326
2632
|
|
|
2327
2633
|
console.print("\n[bold cyan]Available Documents:[/bold cyan]\n")
|
|
2328
2634
|
table = Table(show_header=True, header_style="bold magenta", box=box.ROUNDED, padding=(0, 1))
|