cogames 0.3.65__py3-none-any.whl → 0.3.69__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134) hide show
  1. cogames/cli/client.py +0 -3
  2. cogames/cli/docsync/docsync.py +7 -1
  3. cogames/cli/mission.py +44 -19
  4. cogames/cli/policy.py +26 -10
  5. cogames/cli/submit.py +201 -495
  6. cogames/cli/utils.py +5 -0
  7. cogames/cogs_vs_clips/clip_difficulty.py +57 -0
  8. cogames/cogs_vs_clips/clips.py +23 -6
  9. cogames/cogs_vs_clips/cog.py +16 -5
  10. cogames/cogs_vs_clips/cogsguard_curriculum.py +122 -0
  11. cogames/cogs_vs_clips/cogsguard_tutorial.py +5 -5
  12. cogames/cogs_vs_clips/config.py +1 -1
  13. cogames/cogs_vs_clips/docs/cogs_vs_clips_mapgen.md +2 -3
  14. cogames/cogs_vs_clips/evals/README.md +8 -32
  15. cogames/cogs_vs_clips/evals/diagnostic_evals.py +0 -1
  16. cogames/cogs_vs_clips/evals/difficulty_variants.py +7 -10
  17. cogames/cogs_vs_clips/mission.py +38 -10
  18. cogames/cogs_vs_clips/missions.py +1 -1
  19. cogames/cogs_vs_clips/reward_variants.py +173 -0
  20. cogames/cogs_vs_clips/sites.py +6 -5
  21. cogames/cogs_vs_clips/stations.py +13 -9
  22. cogames/cogs_vs_clips/team.py +3 -1
  23. cogames/cogs_vs_clips/terrain.py +2 -2
  24. cogames/cogs_vs_clips/variants.py +175 -4
  25. cogames/cogs_vs_clips/weather.py +52 -0
  26. cogames/docs/SCRIPTED_AGENT.md +3 -3
  27. cogames/evaluate.py +4 -2
  28. cogames/main.py +420 -84
  29. cogames/maps/canidate1_1000.map +1 -1
  30. cogames/maps/canidate1_1000_stations.map +2 -2
  31. cogames/maps/canidate1_500.map +1 -1
  32. cogames/maps/canidate1_500_stations.map +2 -2
  33. cogames/maps/canidate2_1000.map +1 -1
  34. cogames/maps/canidate2_1000_stations.map +2 -2
  35. cogames/maps/canidate2_500.map +1 -1
  36. cogames/maps/canidate2_500_stations.map +1 -1
  37. cogames/maps/canidate3_1000.map +1 -1
  38. cogames/maps/canidate3_1000_stations.map +2 -2
  39. cogames/maps/canidate3_500.map +1 -1
  40. cogames/maps/canidate3_500_stations.map +2 -2
  41. cogames/maps/canidate4_500.map +1 -1
  42. cogames/maps/canidate4_500_stations.map +2 -2
  43. cogames/maps/cave_base_50.map +2 -2
  44. cogames/maps/diagnostic_evals/diagnostic_agile.map +2 -2
  45. cogames/maps/diagnostic_evals/diagnostic_agile_hard.map +2 -2
  46. cogames/maps/diagnostic_evals/diagnostic_charge_up.map +6 -6
  47. cogames/maps/diagnostic_evals/diagnostic_charge_up_hard.map +6 -6
  48. cogames/maps/diagnostic_evals/diagnostic_chest_navigation1.map +6 -6
  49. cogames/maps/diagnostic_evals/diagnostic_chest_navigation1_hard.map +6 -6
  50. cogames/maps/diagnostic_evals/diagnostic_chest_navigation2.map +6 -6
  51. cogames/maps/diagnostic_evals/diagnostic_chest_navigation2_hard.map +6 -6
  52. cogames/maps/diagnostic_evals/diagnostic_chest_navigation3.map +6 -6
  53. cogames/maps/diagnostic_evals/diagnostic_chest_navigation3_hard.map +6 -6
  54. cogames/maps/diagnostic_evals/diagnostic_chest_near.map +6 -6
  55. cogames/maps/diagnostic_evals/diagnostic_chest_search.map +6 -6
  56. cogames/maps/diagnostic_evals/diagnostic_chest_search_hard.map +6 -6
  57. cogames/maps/diagnostic_evals/diagnostic_extract_lab.map +6 -6
  58. cogames/maps/diagnostic_evals/diagnostic_extract_lab_hard.map +6 -6
  59. cogames/maps/diagnostic_evals/diagnostic_memory.map +6 -6
  60. cogames/maps/diagnostic_evals/diagnostic_memory_hard.map +6 -6
  61. cogames/maps/diagnostic_evals/diagnostic_radial.map +2 -2
  62. cogames/maps/diagnostic_evals/diagnostic_radial_hard.map +2 -2
  63. cogames/maps/diagnostic_evals/diagnostic_resource_lab.map +6 -6
  64. cogames/maps/diagnostic_evals/diagnostic_unclip.map +6 -6
  65. cogames/maps/evals/eval_balanced_spread.map +6 -6
  66. cogames/maps/evals/eval_clip_oxygen.map +6 -6
  67. cogames/maps/evals/eval_collect_resources.map +6 -6
  68. cogames/maps/evals/eval_collect_resources_hard.map +6 -6
  69. cogames/maps/evals/eval_collect_resources_medium.map +6 -6
  70. cogames/maps/evals/eval_divide_and_conquer.map +6 -6
  71. cogames/maps/evals/eval_energy_starved.map +6 -6
  72. cogames/maps/evals/eval_multi_coordinated_collect_hard.map +6 -6
  73. cogames/maps/evals/eval_oxygen_bottleneck.map +6 -6
  74. cogames/maps/evals/eval_single_use_world.map +6 -6
  75. cogames/maps/evals/extractor_hub_100x100.map +6 -6
  76. cogames/maps/evals/extractor_hub_30x30.map +6 -6
  77. cogames/maps/evals/extractor_hub_50x50.map +6 -6
  78. cogames/maps/evals/extractor_hub_70x70.map +6 -6
  79. cogames/maps/evals/extractor_hub_80x80.map +6 -6
  80. cogames/maps/machina_100_stations.map +2 -2
  81. cogames/maps/machina_200_stations.map +2 -2
  82. cogames/maps/machina_200_stations_small.map +2 -2
  83. cogames/maps/machina_eval_exp01.map +2 -2
  84. cogames/maps/machina_eval_template_large.map +2 -2
  85. cogames/maps/machinatrainer4agents.map +2 -2
  86. cogames/maps/machinatrainer4agentsbase.map +2 -2
  87. cogames/maps/machinatrainerbig.map +2 -2
  88. cogames/maps/machinatrainersmall.map +2 -2
  89. cogames/maps/planky_evals/aligner_avoid_aoe.map +6 -6
  90. cogames/maps/planky_evals/aligner_full_cycle.map +6 -6
  91. cogames/maps/planky_evals/aligner_gear.map +6 -6
  92. cogames/maps/planky_evals/aligner_hearts.map +6 -6
  93. cogames/maps/planky_evals/aligner_junction.map +6 -6
  94. cogames/maps/planky_evals/exploration_distant.map +6 -6
  95. cogames/maps/planky_evals/maze.map +6 -6
  96. cogames/maps/planky_evals/miner_best_resource.map +6 -6
  97. cogames/maps/planky_evals/miner_deposit.map +6 -6
  98. cogames/maps/planky_evals/miner_extract.map +6 -6
  99. cogames/maps/planky_evals/miner_full_cycle.map +6 -6
  100. cogames/maps/planky_evals/miner_gear.map +6 -6
  101. cogames/maps/planky_evals/multi_role.map +6 -6
  102. cogames/maps/planky_evals/resource_chain.map +6 -6
  103. cogames/maps/planky_evals/scout_explore.map +6 -6
  104. cogames/maps/planky_evals/scout_gear.map +6 -6
  105. cogames/maps/planky_evals/scrambler_full_cycle.map +6 -6
  106. cogames/maps/planky_evals/scrambler_gear.map +6 -6
  107. cogames/maps/planky_evals/scrambler_target.map +6 -6
  108. cogames/maps/planky_evals/stuck_corridor.map +6 -6
  109. cogames/maps/planky_evals/survive_retreat.map +6 -6
  110. cogames/maps/training_facility_clipped.map +2 -2
  111. cogames/maps/training_facility_open_1.map +2 -2
  112. cogames/maps/training_facility_open_2.map +2 -2
  113. cogames/maps/training_facility_open_3.map +2 -2
  114. cogames/maps/training_facility_tight_4.map +2 -2
  115. cogames/maps/training_facility_tight_5.map +2 -2
  116. cogames/maps/vanilla_large.map +2 -2
  117. cogames/maps/vanilla_small.map +2 -2
  118. cogames/pickup.py +6 -5
  119. cogames/play.py +14 -16
  120. cogames/policy/nim_agents/__init__.py +0 -2
  121. cogames/policy/nim_agents/agents.py +0 -11
  122. cogames/policy/starter_agent.py +4 -1
  123. {cogames-0.3.65.dist-info → cogames-0.3.69.dist-info}/METADATA +45 -29
  124. cogames-0.3.69.dist-info/RECORD +160 -0
  125. metta_alo/scoring.py +7 -7
  126. cogames-0.3.65.dist-info/RECORD +0 -160
  127. metta_alo/job_specs.py +0 -17
  128. metta_alo/policy.py +0 -16
  129. metta_alo/pure_single_episode_runner.py +0 -75
  130. metta_alo/rollout.py +0 -322
  131. {cogames-0.3.65.dist-info → cogames-0.3.69.dist-info}/WHEEL +0 -0
  132. {cogames-0.3.65.dist-info → cogames-0.3.69.dist-info}/entry_points.txt +0 -0
  133. {cogames-0.3.65.dist-info → cogames-0.3.69.dist-info}/licenses/LICENSE +0 -0
  134. {cogames-0.3.65.dist-info → cogames-0.3.69.dist-info}/top_level.txt +0 -0
cogames/cli/submit.py CHANGED
@@ -5,33 +5,28 @@ from __future__ import annotations
5
5
  import os
6
6
  import shutil
7
7
  import subprocess
8
+ import sys
8
9
  import tempfile
9
10
  import uuid
10
11
  import zipfile
11
12
  from dataclasses import dataclass
12
13
  from pathlib import Path
13
- from typing import TYPE_CHECKING
14
14
 
15
15
  import httpx
16
16
  import typer
17
- from rich.console import Console
18
17
 
19
18
  from cogames.cli.base import console
19
+ from cogames.cli.client import TournamentServerClient
20
20
  from cogames.cli.login import DEFAULT_COGAMES_SERVER
21
21
  from cogames.cli.policy import PolicySpec, get_policy_spec
22
- from metta_alo.rollout import run_single_episode
23
-
24
- if TYPE_CHECKING:
25
- from cogames.cli.client import TournamentServerClient
26
-
27
22
  from mettagrid.config.mettagrid_config import MettaGridConfig
28
23
  from mettagrid.policy.submission import POLICY_SPEC_FILENAME, SubmissionPolicySpec
24
+ from mettagrid.runner.episode_runner import run_episode_isolated
25
+ from mettagrid.runner.types import EpisodeSpec
26
+ from mettagrid.util.uri_resolvers.schemes import localize_uri, parse_uri
29
27
 
30
28
  DEFAULT_SUBMIT_SERVER = "https://api.observatory.softmax-research.net"
31
-
32
-
33
- def results_url_for_season(_server: str, _season: str) -> str:
34
- return "https://www.softmax.com/alignmentleague"
29
+ RESULTS_URL = "https://www.softmax.com/alignmentleague"
35
30
 
36
31
 
37
32
  @dataclass
@@ -42,277 +37,47 @@ class UploadResult:
42
37
  pools: list[str] | None = None
43
38
 
44
39
 
45
- def validate_paths(paths: list[str], console: Console) -> list[Path]:
46
- """Validate paths are within CWD and return them as relative paths."""
40
+ def _resolve_path_within_cwd(path_str: str, cwd: Path) -> Path:
41
+ """Resolve a path and return it relative to CWD. Raises if path escapes CWD."""
42
+ raw_path = Path(path_str).expanduser()
43
+ resolved = raw_path.resolve() if raw_path.is_absolute() else (cwd / raw_path).resolve()
44
+ if not resolved.is_relative_to(cwd):
45
+ console.print(f"[red]Error:[/red] Path must be within the current directory: {path_str}")
46
+ raise ValueError(f"Path escapes CWD: {path_str}")
47
+ return resolved.relative_to(cwd)
48
+
49
+
50
+ def validate_paths(paths: list[str]) -> list[Path]:
51
+ """Validate paths exist and are within CWD, return them as relative paths."""
47
52
  cwd = Path.cwd().resolve()
48
53
  validated_paths = []
49
-
50
54
  for path_str in paths:
51
- raw_path = Path(path_str).expanduser()
52
-
53
- # Resolve the path and check it's within CWD
54
- try:
55
- resolved = raw_path.resolve() if raw_path.is_absolute() else (cwd / raw_path).resolve()
56
- relative = resolved.relative_to(cwd)
57
- except ValueError:
58
- console.print(f"[red]Error:[/red] Path must be within the current directory: {path_str}")
59
- raise ValueError(f"Path escapes CWD: {path_str}") from None
60
-
61
- # Check if path exists
55
+ relative = _resolve_path_within_cwd(path_str, cwd)
56
+ resolved = cwd / relative
62
57
  if not resolved.exists():
63
58
  console.print(f"[red]Error:[/red] Path does not exist: {path_str}")
64
59
  raise FileNotFoundError(f"Path not found: {path_str}")
65
-
66
60
  validated_paths.append(relative)
67
-
68
61
  return validated_paths
69
62
 
70
63
 
71
- def _maybe_resolve_checkpoint_bundle_uri(policy: str) -> tuple[Path, bool] | None:
72
- """Return (local_zip_path, cleanup) if policy points to a checkpoint bundle URI."""
73
- from mettagrid.policy.prepare_policy_spec import download_policy_spec_from_s3_as_zip
74
- from mettagrid.util.uri_resolvers.schemes import parse_uri, resolve_uri
75
-
76
- first = policy.split(",", 1)[0].strip()
77
- parsed = parse_uri(first, allow_none=True, default_scheme=None)
78
- if parsed is None or parsed.scheme not in {"file", "s3"}:
79
- return None
80
-
81
- resolved = resolve_uri(first)
82
- if resolved.scheme == "s3":
83
- if not resolved.canonical.endswith(".zip"):
84
- raise ValueError("S3 policy specs must be .zip bundles.")
85
- return download_policy_spec_from_s3_as_zip(resolved.canonical), False
86
-
87
- local_path = resolved.local_path
88
- if local_path is None:
89
- raise ValueError(f"Cannot resolve local path for URI: {policy}")
90
- if not local_path.exists():
91
- raise FileNotFoundError(f"Bundle path not found: {local_path}")
92
- if local_path.is_dir():
93
- return _zip_directory_bundle(local_path), True
94
- if local_path.suffix == ".zip":
95
- return local_path, False
96
- raise ValueError("Checkpoint bundle must be a directory or .zip file.")
97
-
98
-
99
- def _zip_directory_bundle(bundle_dir: Path) -> Path:
100
- zip_fd, zip_path = tempfile.mkstemp(suffix=".zip", prefix="cogames_bundle_")
101
- os.close(zip_fd)
102
-
103
- with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
104
- for file_path in bundle_dir.rglob("*"):
64
+ def _zip_directory_to(src: Path, dest: Path) -> None:
65
+ with zipfile.ZipFile(dest, "w", zipfile.ZIP_DEFLATED) as zipf:
66
+ for file_path in src.rglob("*"):
105
67
  if file_path.is_file():
106
- zipf.write(file_path, arcname=file_path.relative_to(bundle_dir))
107
-
108
- return Path(zip_path)
109
-
110
-
111
- def validate_bundle_in_isolation(policy_zip: Path, console: Console, *, season: str, server: str) -> bool:
112
- console.print("[dim]Testing policy bundle can run 10 steps...[/dim]")
68
+ zipf.write(file_path, arcname=file_path.relative_to(src))
113
69
 
114
- temp_dir = None
115
- try:
116
- temp_dir = create_temp_validation_env()
117
- bundle_name = policy_zip.name
118
- shutil.copy2(policy_zip, temp_dir / bundle_name)
119
70
 
120
- env = os.environ.copy()
121
- env["UV_NO_CACHE"] = "1"
122
-
123
- res = subprocess.run(
124
- [
125
- "uv",
126
- "run",
127
- "cogames",
128
- "validate-policy",
129
- "--season",
130
- season,
131
- "--server",
132
- server,
133
- "--policy",
134
- f"file://./{bundle_name}",
135
- ],
136
- cwd=temp_dir,
137
- capture_output=True,
138
- text=True,
139
- timeout=300,
140
- env=env,
141
- )
142
- if res.returncode != 0:
143
- console.print("[red]Validation failed[/red]")
144
- console.print(f"\n[red]Error:[/red]\n{res.stderr}")
145
- if res.stdout:
146
- console.print(f"\n[dim]Output:[/dim]\n{res.stdout}")
147
- return False
148
-
149
- console.print("[green]Validation passed[/green]")
150
- return True
151
- except subprocess.TimeoutExpired:
152
- console.print("[red]Validation timed out after 5 minutes[/red]")
153
- return False
154
- finally:
155
- if temp_dir and temp_dir.exists():
156
- shutil.rmtree(temp_dir)
157
-
158
-
159
- def get_latest_pypi_version(package: str) -> str:
160
- """Get the latest published version of a package from PyPI."""
161
- response = httpx.get(f"https://pypi.org/pypi/{package}/json")
162
- response.raise_for_status()
163
- return response.json()["info"]["version"]
164
-
165
-
166
- def get_pypi_requires_python(package: str) -> str:
167
- """Get the requires_python constraint from PyPI."""
168
- response = httpx.get(f"https://pypi.org/pypi/{package}/json")
169
- response.raise_for_status()
170
- return response.json()["info"]["requires_python"]
171
-
172
-
173
- def create_temp_validation_env() -> Path:
174
- """Create a temporary directory with a minimal pyproject.toml.
175
-
176
- The pyproject.toml depends on the latest published cogames and cogames-agents packages.
177
- Python version is constrained to match mettagrid's published wheels.
178
- """
179
- temp_dir = Path(tempfile.mkdtemp(prefix="cogames_submit_"))
180
-
181
- latest_cogames_version = get_latest_pypi_version("cogames")
182
- latest_agents_version = get_latest_pypi_version("cogames-agents")
183
- mettagrid_requires_python = get_pypi_requires_python("mettagrid")
184
-
185
- pyproject_content = f"""[project]
186
- name = "cogames-submission-validator"
187
- version = "0.1.0"
188
- requires-python = "{mettagrid_requires_python}"
189
- dependencies = ["cogames=={latest_cogames_version}", "cogames-agents=={latest_agents_version}"]
190
-
191
- [build-system]
192
- requires = ["setuptools>=42"]
193
- build-backend = "setuptools.build_meta"
194
- """
195
-
196
- pyproject_path = temp_dir / "pyproject.toml"
197
- pyproject_path.write_text(pyproject_content)
198
-
199
- return temp_dir
200
-
201
-
202
- def copy_files_maintaining_structure(files: list[Path], dest_dir: Path) -> None:
203
- """Copy files to destination, maintaining directory structure.
204
-
205
- If a file is 'train_dir/model.pt', it will be copied to 'dest_dir/train_dir/model.pt'.
206
- """
207
- for file_path in files:
208
- dest_path = dest_dir / file_path
209
-
210
- if file_path.is_dir():
211
- shutil.copytree(file_path, dest_path, dirs_exist_ok=True)
212
- else:
213
- dest_path.parent.mkdir(parents=True, exist_ok=True)
214
- shutil.copy2(file_path, dest_path)
215
-
216
-
217
- def validate_policy_spec(policy_spec: PolicySpec, env_cfg: MettaGridConfig) -> None:
218
- env_cfg = env_cfg.model_copy()
219
- env_cfg.game.max_steps = 10
220
- n = env_cfg.game.num_agents
221
- n_submitted = min(2, n)
222
- noop_spec = PolicySpec(class_path="mettagrid.policy.noop.NoopPolicy")
223
- if n_submitted < n:
224
- policy_specs = [policy_spec, noop_spec]
225
- assignments = [0] * n_submitted + [1] * (n - n_submitted)
226
- else:
227
- policy_specs = [policy_spec]
228
- assignments = [0] * n
229
- run_single_episode(
230
- policy_specs=policy_specs,
231
- assignments=assignments,
232
- env=env_cfg,
233
- results_uri=None,
234
- replay_uri=None,
235
- seed=42,
236
- max_action_time_ms=10000,
237
- device="cpu",
238
- )
239
-
240
-
241
- def validate_policy_in_isolation(
242
- policy_spec: PolicySpec,
243
- include_files: list[Path],
244
- console: Console,
245
- setup_script: str | None = None,
246
- *,
247
- season: str,
248
- server: str,
249
- ) -> bool:
250
- def _format_policy_arg(spec: PolicySpec) -> str:
251
- parts = [f"class={spec.class_path}"]
252
- if spec.data_path:
253
- parts.append(f"data={spec.data_path}")
254
- for key, value in spec.init_kwargs.items():
255
- parts.append(f"kw.{key}={value}")
256
- return ",".join(parts)
257
-
258
- console.print("[dim]Testing policy can run 10 steps...[/dim]")
259
-
260
- temp_dir = None
261
- try:
262
- temp_dir = create_temp_validation_env()
263
- copy_files_maintaining_structure(include_files, temp_dir)
264
-
265
- policy_arg = _format_policy_arg(policy_spec)
266
-
267
- def _run_from_tmp_dir(cmd: list[str]) -> subprocess.CompletedProcess[str]:
268
- env = os.environ.copy()
269
- env["UV_NO_CACHE"] = "1"
270
- res = subprocess.run(
271
- cmd,
272
- cwd=temp_dir,
273
- capture_output=True,
274
- text=True,
275
- timeout=300,
276
- env=env,
277
- )
278
- if not res.returncode == 0:
279
- console.print("[red]Validation failed[/red]")
280
- console.print(f"\n[red]Error:[/red]\n{res.stderr}")
281
- if res.stdout:
282
- console.print(f"\n[dim]Output:[/dim]\n{res.stdout}")
283
- raise Exception("Validation failed")
284
- return res
285
-
286
- _run_from_tmp_dir(["uv", "run", "cogames", "version"])
287
-
288
- validate_cmd = [
289
- "uv",
290
- "run",
291
- "cogames",
292
- "validate-policy",
293
- "--policy",
294
- policy_arg,
295
- "--season",
296
- season,
297
- "--server",
298
- server,
299
- ]
300
- if setup_script:
301
- validate_cmd.extend(["--setup-script", setup_script])
302
-
303
- _run_from_tmp_dir(validate_cmd)
304
-
305
- console.print("[green]Validation passed[/green]")
306
- return True
307
-
308
- except subprocess.TimeoutExpired:
309
- console.print("[red]Validation timed out after 5 minutes[/red]")
310
- return False
311
- except Exception:
312
- return False
313
- finally:
314
- if temp_dir and temp_dir.exists():
315
- shutil.rmtree(temp_dir)
71
+ def _collect_ancestor_init_files(include_files: list[Path]) -> list[Path]:
72
+ found: set[Path] = set()
73
+ for path in include_files:
74
+ parent = path.parent
75
+ while parent != Path(".") and parent != parent.parent:
76
+ init = parent / "__init__.py"
77
+ if init.is_file():
78
+ found.add(init)
79
+ parent = parent.parent
80
+ return sorted(found)
316
81
 
317
82
 
318
83
  def create_submission_zip(
@@ -335,156 +100,169 @@ def create_submission_zip(
335
100
  setup_script=setup_script,
336
101
  )
337
102
 
103
+ all_files: dict[str, Path] = {}
104
+ for init_path in _collect_ancestor_init_files(include_files):
105
+ all_files[str(init_path)] = init_path
106
+ for file_path in include_files:
107
+ if file_path.is_dir():
108
+ for root, _, files in os.walk(file_path):
109
+ for file in files:
110
+ full = Path(root) / file
111
+ all_files[str(full)] = full
112
+ else:
113
+ all_files[str(file_path)] = file_path
114
+
338
115
  with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
339
116
  zipf.writestr(data=submission_spec.model_dump_json(), zinfo_or_arcname=POLICY_SPEC_FILENAME)
340
-
341
- for file_path in include_files:
342
- if file_path.is_dir():
343
- for root, _, files in os.walk(file_path):
344
- for file in files:
345
- file_full_path = Path(root) / file
346
- zipf.write(file_full_path, arcname=file_full_path)
347
- else:
348
- zipf.write(file_path, arcname=file_path)
117
+ for arcname, path in all_files.items():
118
+ zipf.write(path, arcname=arcname)
349
119
 
350
120
  return Path(zip_path)
351
121
 
352
122
 
123
+ def create_bundle(
124
+ ctx: typer.Context,
125
+ policy: str,
126
+ output: Path,
127
+ include_files: list[str] | None = None,
128
+ init_kwargs: dict[str, str] | None = None,
129
+ setup_script: str | None = None,
130
+ ) -> Path:
131
+ # TODO: Unify the two paths below. For URI inputs, extract the PolicySpec from the
132
+ # bundle so we can apply init_kwargs/include_files/setup_script, then re-zip.
133
+ local = localize_uri(policy) if parse_uri(policy, allow_none=True, default_scheme=None) else None
134
+ if local is not None:
135
+ if init_kwargs or include_files or setup_script:
136
+ console.print("[red]Error:[/red] Extra files/kwargs are not supported with bundle URIs.")
137
+ raise typer.Exit(1)
138
+ console.print(f"[dim]Packaging existing bundle: {local}[/dim]")
139
+ if local.is_dir():
140
+ _zip_directory_to(local, output)
141
+ else:
142
+ shutil.copy2(local, output)
143
+ console.print(f"[dim]Bundle size: {output.stat().st_size / 1024:.0f} KB[/dim]")
144
+ return output
145
+
146
+ policy_spec = get_policy_spec(ctx, policy)
147
+ console.print(f"[dim]Policy class: {policy_spec.class_path}[/dim]")
148
+
149
+ if init_kwargs:
150
+ merged_kwargs = {**policy_spec.init_kwargs, **init_kwargs}
151
+ policy_spec = PolicySpec(
152
+ class_path=policy_spec.class_path,
153
+ data_path=policy_spec.data_path,
154
+ init_kwargs=merged_kwargs,
155
+ )
156
+
157
+ if policy_spec.init_kwargs:
158
+ console.print(f"[dim]Init kwargs: {policy_spec.init_kwargs}[/dim]")
159
+
160
+ cwd = Path.cwd().resolve()
161
+ if policy_spec.data_path:
162
+ data_rel = str(_resolve_path_within_cwd(policy_spec.data_path, cwd))
163
+ policy_spec = PolicySpec(
164
+ class_path=policy_spec.class_path,
165
+ data_path=data_rel,
166
+ init_kwargs=policy_spec.init_kwargs,
167
+ )
168
+ console.print(f"[dim]Data path: {data_rel}[/dim]")
169
+
170
+ setup_script_rel: str | None = None
171
+ if setup_script:
172
+ setup_script_rel = str(_resolve_path_within_cwd(setup_script, cwd))
173
+ console.print(f"[dim]Setup script: {setup_script_rel}[/dim]")
174
+
175
+ files_to_include = []
176
+ if policy_spec.data_path:
177
+ files_to_include.append(policy_spec.data_path)
178
+ if setup_script_rel:
179
+ files_to_include.append(setup_script_rel)
180
+ if include_files:
181
+ files_to_include.extend(include_files)
182
+
183
+ validated_paths: list[Path] = []
184
+ if files_to_include:
185
+ validated_paths = validate_paths(files_to_include)
186
+ console.print(f"[dim]Including {len(validated_paths)} file(s)[/dim]")
187
+
188
+ tmp_zip = create_submission_zip(validated_paths, policy_spec, setup_script=setup_script_rel)
189
+ shutil.move(str(tmp_zip), str(output))
190
+ console.print(f"[dim]Bundle size: {output.stat().st_size / 1024:.0f} KB[/dim]")
191
+ return output
192
+
193
+
194
+ def validate_bundle(policy_uri: str, env_cfg: MettaGridConfig) -> None:
195
+ """Validate a policy bundle by running a short episode in process isolation."""
196
+ env_cfg.game.max_steps = 10
197
+
198
+ spec = EpisodeSpec(
199
+ policy_uris=[policy_uri],
200
+ assignments=[0] * env_cfg.game.num_agents,
201
+ env=env_cfg,
202
+ seed=42,
203
+ max_action_time_ms=10000,
204
+ )
205
+
206
+ with tempfile.NamedTemporaryFile(suffix=".json", delete=True) as results_file:
207
+ res = run_episode_isolated(spec, Path(results_file.name))
208
+ console.print(f"[dim]Ran for {res.steps} steps[/dim]")
209
+
210
+ non_noop_actions = sum(
211
+ v for k, v in res.stats["agent"][0].items() if k.startswith("action.") and ".noop." not in k
212
+ )
213
+ if non_noop_actions == 0:
214
+ console.print("[yellow]Warning: Policy took no actions (all no-ops)[/yellow]")
215
+ raise typer.Exit(1)
216
+
217
+
353
218
  def upload_submission(
354
219
  client: TournamentServerClient,
355
220
  zip_path: Path,
356
221
  submission_name: str,
357
- console: Console,
358
222
  season: str | None = None,
359
223
  ) -> UploadResult | None:
360
224
  """Upload submission to CoGames backend using a presigned S3 URL."""
361
225
  console.print("[bold]Uploading[/bold]")
362
226
 
363
- try:
364
- presigned_data = client.get_presigned_upload_url()
365
- upload_url = presigned_data.get("upload_url")
366
- upload_id = presigned_data.get("upload_id")
227
+ presigned_data = client.get_presigned_upload_url()
228
+ upload_url = presigned_data.get("upload_url")
229
+ upload_id = presigned_data.get("upload_id")
367
230
 
368
- if not upload_url or not upload_id:
369
- console.print("[red]Upload URL missing from response[/red]")
370
- return None
371
- except httpx.TimeoutException:
372
- console.print("[red]Timed out while requesting upload URL[/red]")
373
- return None
374
- except httpx.HTTPStatusError as exc:
375
- console.print(f"[red]Failed to get upload URL ({exc.response.status_code})[/red]")
376
- console.print(f"[dim]{exc.response.text}[/dim]")
377
- return None
378
- except Exception as e:
379
- console.print(f"[red]Error requesting upload URL: {e}[/red]")
380
- return None
231
+ if not upload_url or not upload_id:
232
+ raise ValueError("Upload URL missing from response")
381
233
 
382
234
  console.print("[dim]Uploading to storage...[/dim]")
383
235
 
384
- try:
385
- with open(zip_path, "rb") as f:
386
- upload_response = httpx.put(
387
- upload_url,
388
- content=f,
389
- headers={"Content-Type": "application/zip"},
390
- timeout=600.0,
391
- )
392
- upload_response.raise_for_status()
393
- except httpx.TimeoutException:
394
- console.print("[red]Upload timed out after 10 minutes[/red]")
395
- return None
396
- except httpx.HTTPStatusError as exc:
397
- console.print(f"[red]Upload failed with status {exc.response.status_code}[/red]")
398
- console.print(f"[dim]{exc.response.text}[/dim]")
399
- return None
400
- except Exception as e:
401
- console.print(f"[red]Upload error: {e}[/red]")
402
- return None
236
+ with open(zip_path, "rb") as f:
237
+ upload_response = httpx.put(
238
+ upload_url,
239
+ content=f,
240
+ headers={"Content-Type": "application/zip"},
241
+ timeout=600.0,
242
+ )
243
+ upload_response.raise_for_status()
403
244
 
404
245
  if not season:
405
246
  console.print("[dim]Uploading policy...[/dim]")
406
247
  else:
407
248
  console.print(f"[dim]Uploading policy and submitting to season {season}...[/dim]")
408
249
 
250
+ result = client.complete_policy_upload(upload_id, submission_name, season=season)
251
+ submission_id = result.get("id")
252
+ name = result.get("name")
253
+ version = result.get("version")
254
+ pools = result.get("pools")
255
+ if submission_id is None or name is None or version is None:
256
+ raise ValueError("Missing fields in response")
409
257
  try:
410
- result = client.complete_policy_upload(upload_id, submission_name, season=season)
411
- submission_id = result.get("id")
412
- name = result.get("name")
413
- version = result.get("version")
414
- pools = result.get("pools")
415
- if submission_id is not None and name is not None and version is not None:
416
- try:
417
- return UploadResult(
418
- policy_version_id=uuid.UUID(str(submission_id)),
419
- name=name,
420
- version=version,
421
- pools=pools,
422
- )
423
- except ValueError:
424
- console.print(f"[red]Invalid submission ID returned: {submission_id}[/red]")
425
- return None
426
-
427
- console.print("[red]Missing fields in response[/red]")
428
- return None
429
- except httpx.TimeoutException:
430
- console.print("[red]Registration timed out[/red]")
431
- return None
432
- except httpx.HTTPStatusError as exc:
433
- console.print(f"[red]Registration failed with status {exc.response.status_code}[/red]")
434
- console.print(f"[dim]{exc.response.text}[/dim]")
435
- return None
436
- except Exception as e:
437
- console.print(f"[red]Registration error: {e}[/red]")
438
- return None
439
-
440
-
441
- def _upload_policy_bundle(
442
- bundle_result: tuple[Path, bool],
443
- *,
444
- client: TournamentServerClient,
445
- name: str,
446
- console: Console,
447
- init_kwargs: dict[str, str] | None,
448
- include_files: list[str] | None,
449
- setup_script: str | None,
450
- skip_validation: bool,
451
- dry_run: bool,
452
- validation_season: str,
453
- server: str,
454
- season: str | None = None,
455
- ) -> UploadResult | None:
456
- bundle_zip, cleanup_bundle_zip = bundle_result
457
-
458
- try:
459
- if init_kwargs or include_files or setup_script:
460
- console.print("[red]Error:[/red] Extra files/kwargs are not supported when uploading a bundle URI.")
461
- console.print(
462
- "Upload the bundle as-is, or use a short policy name or URI with local files to "
463
- "build a new submission zip."
464
- )
465
- return None
466
-
467
- if not skip_validation:
468
- if not validate_bundle_in_isolation(bundle_zip, console, season=validation_season, server=server):
469
- console.print("\n[red]Upload aborted due to validation failure.[/red]")
470
- return None
471
- else:
472
- console.print("[dim]Skipping validation[/dim]")
473
-
474
- if dry_run:
475
- console.print("[green]Dry run complete[/green]")
476
- return None
477
-
478
- with client:
479
- result = upload_submission(client, bundle_zip, name, console, season=season)
480
- if not result:
481
- console.print("\n[red]Upload failed.[/red]")
482
- return None
483
-
484
- return result
485
- finally:
486
- if cleanup_bundle_zip and bundle_zip.exists():
487
- bundle_zip.unlink()
258
+ return UploadResult(
259
+ policy_version_id=uuid.UUID(str(submission_id)),
260
+ name=name,
261
+ version=version,
262
+ pools=pools,
263
+ )
264
+ except ValueError as exc:
265
+ raise ValueError(f"Invalid submission ID returned: {submission_id}") from exc
488
266
 
489
267
 
490
268
  def upload_policy(
@@ -498,11 +276,8 @@ def upload_policy(
498
276
  dry_run: bool = False,
499
277
  skip_validation: bool = False,
500
278
  setup_script: str | None = None,
501
- validation_season: str = "",
502
279
  season: str | None = None,
503
280
  ) -> UploadResult | None:
504
- from cogames.cli.client import TournamentServerClient
505
-
506
281
  if dry_run:
507
282
  console.print("[dim]Dry run mode - no upload[/dim]\n")
508
283
 
@@ -510,116 +285,47 @@ def upload_policy(
510
285
  if not client:
511
286
  return None
512
287
 
513
- try:
514
- bundle_result = _maybe_resolve_checkpoint_bundle_uri(policy)
515
- except Exception as e:
516
- console.print(f"[red]Error resolving checkpoint bundle:[/red] {e}")
517
- return None
288
+ with tempfile.TemporaryDirectory(prefix="cogames_bundle_") as tmp_dir:
289
+ zip_path = Path(tmp_dir) / "bundle.zip"
518
290
 
519
- if bundle_result is not None:
520
- return _upload_policy_bundle(
521
- bundle_result,
522
- client=client,
523
- name=name,
524
- console=console,
525
- init_kwargs=init_kwargs,
291
+ create_bundle(
292
+ ctx=ctx,
293
+ policy=policy,
294
+ output=zip_path,
526
295
  include_files=include_files,
296
+ init_kwargs=init_kwargs,
527
297
  setup_script=setup_script,
528
- skip_validation=skip_validation,
529
- dry_run=dry_run,
530
- validation_season=validation_season,
531
- server=server,
532
- season=season,
533
298
  )
534
299
 
535
- try:
536
- policy_spec = get_policy_spec(ctx, policy)
537
- except Exception as e:
538
- console.print(f"[red]Error parsing policy:[/red] {e}")
539
- return None
540
-
541
- if init_kwargs:
542
- merged_kwargs = {**policy_spec.init_kwargs, **init_kwargs}
543
- policy_spec = PolicySpec(
544
- class_path=policy_spec.class_path,
545
- data_path=policy_spec.data_path,
546
- init_kwargs=merged_kwargs,
547
- )
548
-
549
- cwd = Path.cwd().resolve()
550
- if policy_spec.data_path:
551
- try:
552
- resolved = Path(policy_spec.data_path).expanduser().resolve()
553
- data_rel = str(resolved.relative_to(cwd))
554
- except ValueError:
555
- console.print("[red]Error:[/red] Policy weights path must be within the current directory.")
556
- console.print(f"[dim]{policy_spec.data_path}[/dim]")
557
- return None
558
- policy_spec = PolicySpec(
559
- class_path=policy_spec.class_path,
560
- data_path=data_rel,
561
- init_kwargs=policy_spec.init_kwargs,
562
- )
563
-
564
- setup_script_rel: str | None = None
565
- if setup_script:
566
- try:
567
- resolved = Path(setup_script).expanduser().resolve()
568
- setup_script_rel = str(resolved.relative_to(cwd))
569
- except ValueError:
570
- console.print("[red]Error:[/red] Setup script path must be within the current directory.")
571
- console.print(f"[dim]{setup_script}[/dim]")
572
- return None
573
-
574
- files_to_include = []
575
- if policy_spec.data_path:
576
- files_to_include.append(policy_spec.data_path)
577
- if setup_script_rel:
578
- files_to_include.append(setup_script_rel)
579
- if include_files:
580
- files_to_include.extend(include_files)
581
-
582
- validated_paths: list[Path] = []
583
- if files_to_include:
584
- try:
585
- validated_paths = validate_paths(files_to_include, console)
586
- except (ValueError, FileNotFoundError):
587
- return None
300
+ if not skip_validation:
301
+ cmd = [
302
+ sys.executable,
303
+ "-m",
304
+ "cogames",
305
+ "validate-bundle",
306
+ "--policy",
307
+ zip_path.as_uri(),
308
+ "--server",
309
+ server,
310
+ ]
311
+ if season:
312
+ cmd.extend(["--season", season])
313
+ result = subprocess.run(cmd, text=True, timeout=300)
314
+ if result.returncode != 0:
315
+ console.print("[red]Validation failed[/red]")
316
+ return None
317
+ console.print("[green]Validation passed[/green]")
318
+ else:
319
+ console.print("[dim]Skipping validation[/dim]")
588
320
 
589
- if not skip_validation:
590
- if not validate_policy_in_isolation(
591
- policy_spec,
592
- validated_paths,
593
- console,
594
- setup_script=setup_script_rel,
595
- season=validation_season,
596
- server=server,
597
- ):
598
- console.print("\n[red]Upload aborted due to validation failure.[/red]")
321
+ if dry_run:
322
+ console.print("[green]Dry run complete[/green]")
599
323
  return None
600
- else:
601
- console.print("[dim]Skipping validation[/dim]")
602
324
 
603
- try:
604
- zip_path = create_submission_zip(validated_paths, policy_spec, setup_script=setup_script_rel)
605
- except Exception as e:
606
- console.print(f"[red]Error creating zip:[/red] {e}")
607
- return None
608
-
609
- if dry_run:
610
- console.print("[green]Dry run complete[/green]")
611
- if zip_path.exists():
612
- zip_path.unlink()
613
- return None
614
-
615
- try:
616
325
  with client:
617
- result = upload_submission(client, zip_path, name, console, season=season)
326
+ result = upload_submission(client, zip_path, name, season=season)
618
327
  if not result:
619
328
  console.print("\n[red]Upload failed.[/red]")
620
329
  return None
621
330
 
622
331
  return result
623
- finally:
624
- if zip_path.exists():
625
- zip_path.unlink()