cogames 0.3.68__py3-none-any.whl → 0.3.69__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cogames/cli/submit.py CHANGED
@@ -5,6 +5,7 @@ from __future__ import annotations
5
5
  import os
6
6
  import shutil
7
7
  import subprocess
8
+ import sys
8
9
  import tempfile
9
10
  import uuid
10
11
  import zipfile
@@ -13,24 +14,19 @@ from pathlib import Path
13
14
 
14
15
  import httpx
15
16
  import typer
16
- from rich.console import Console
17
17
 
18
18
  from cogames.cli.base import console
19
19
  from cogames.cli.client import TournamentServerClient
20
20
  from cogames.cli.login import DEFAULT_COGAMES_SERVER
21
- from cogames.cli.mission import get_mission
22
- from cogames.cli.policy import PolicySpec, get_policy_spec, parse_policy_spec
21
+ from cogames.cli.policy import PolicySpec, get_policy_spec
23
22
  from mettagrid.config.mettagrid_config import MettaGridConfig
24
- from mettagrid.policy.prepare_policy_spec import download_policy_spec_from_s3_as_zip
25
23
  from mettagrid.policy.submission import POLICY_SPEC_FILENAME, SubmissionPolicySpec
26
- from mettagrid.runner.rollout import run_episode_local
27
- from mettagrid.util.uri_resolvers.schemes import parse_uri, resolve_uri
24
+ from mettagrid.runner.episode_runner import run_episode_isolated
25
+ from mettagrid.runner.types import EpisodeSpec
26
+ from mettagrid.util.uri_resolvers.schemes import localize_uri, parse_uri
28
27
 
29
28
  DEFAULT_SUBMIT_SERVER = "https://api.observatory.softmax-research.net"
30
-
31
-
32
- def results_url_for_season(_server: str, _season: str) -> str:
33
- return "https://www.softmax.com/alignmentleague"
29
+ RESULTS_URL = "https://www.softmax.com/alignmentleague"
34
30
 
35
31
 
36
32
  @dataclass
@@ -41,303 +37,35 @@ class UploadResult:
41
37
  pools: list[str] | None = None
42
38
 
43
39
 
44
- def validate_paths(paths: list[str], console: Console) -> list[Path]:
45
- """Validate paths are within CWD and return them as relative paths."""
40
+ def _resolve_path_within_cwd(path_str: str, cwd: Path) -> Path:
41
+ """Resolve a path and return it relative to CWD. Raises if path escapes CWD."""
42
+ raw_path = Path(path_str).expanduser()
43
+ resolved = raw_path.resolve() if raw_path.is_absolute() else (cwd / raw_path).resolve()
44
+ if not resolved.is_relative_to(cwd):
45
+ console.print(f"[red]Error:[/red] Path must be within the current directory: {path_str}")
46
+ raise ValueError(f"Path escapes CWD: {path_str}")
47
+ return resolved.relative_to(cwd)
48
+
49
+
50
+ def validate_paths(paths: list[str]) -> list[Path]:
51
+ """Validate paths exist and are within CWD, return them as relative paths."""
46
52
  cwd = Path.cwd().resolve()
47
53
  validated_paths = []
48
-
49
54
  for path_str in paths:
50
- raw_path = Path(path_str).expanduser()
51
-
52
- # Resolve the path and check it's within CWD
53
- resolved = raw_path.resolve() if raw_path.is_absolute() else (cwd / raw_path).resolve()
54
- if not resolved.is_relative_to(cwd):
55
- console.print(f"[red]Error:[/red] Path must be within the current directory: {path_str}")
56
- raise ValueError(f"Path escapes CWD: {path_str}")
57
- relative = resolved.relative_to(cwd)
58
-
59
- # Check if path exists
55
+ relative = _resolve_path_within_cwd(path_str, cwd)
56
+ resolved = cwd / relative
60
57
  if not resolved.exists():
61
58
  console.print(f"[red]Error:[/red] Path does not exist: {path_str}")
62
59
  raise FileNotFoundError(f"Path not found: {path_str}")
63
-
64
60
  validated_paths.append(relative)
65
-
66
61
  return validated_paths
67
62
 
68
63
 
69
- def _maybe_resolve_checkpoint_bundle_uri(policy: str) -> tuple[Path, bool] | None:
70
- """Return (local_zip_path, cleanup) if policy points to a checkpoint bundle URI."""
71
- first = policy.split(",", 1)[0].strip()
72
- parsed = parse_uri(first, allow_none=True, default_scheme=None)
73
- if parsed is None or parsed.scheme not in {"file", "s3"}:
74
- return None
75
-
76
- resolved = resolve_uri(first)
77
- if resolved.scheme == "s3":
78
- if not resolved.canonical.endswith(".zip"):
79
- raise ValueError("S3 policy specs must be .zip bundles.")
80
- return download_policy_spec_from_s3_as_zip(resolved.canonical), False
81
-
82
- local_path = resolved.local_path
83
- if local_path is None:
84
- raise ValueError(f"Cannot resolve local path for URI: {policy}")
85
- if not local_path.exists():
86
- raise FileNotFoundError(f"Bundle path not found: {local_path}")
87
- if local_path.is_dir():
88
- return _zip_directory_bundle(local_path), True
89
- if local_path.suffix == ".zip":
90
- return local_path, False
91
- raise ValueError("Checkpoint bundle must be a directory or .zip file.")
92
-
93
-
94
- def _zip_directory_bundle(bundle_dir: Path) -> Path:
95
- zip_fd, zip_path = tempfile.mkstemp(suffix=".zip", prefix="cogames_bundle_")
96
- os.close(zip_fd)
97
-
98
- with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
99
- for file_path in bundle_dir.rglob("*"):
64
+ def _zip_directory_to(src: Path, dest: Path) -> None:
65
+ with zipfile.ZipFile(dest, "w", zipfile.ZIP_DEFLATED) as zipf:
66
+ for file_path in src.rglob("*"):
100
67
  if file_path.is_file():
101
- zipf.write(file_path, arcname=file_path.relative_to(bundle_dir))
102
-
103
- return Path(zip_path)
104
-
105
-
106
- def validate_bundle_in_isolation(policy_zip: Path, console: Console, *, season: str, server: str) -> bool:
107
- console.print("[dim]Testing policy bundle can run 10 steps...[/dim]")
108
-
109
- temp_dir = create_temp_validation_env()
110
- try:
111
- bundle_name = policy_zip.name
112
- shutil.copy2(policy_zip, temp_dir / bundle_name)
113
-
114
- env = os.environ.copy()
115
- env["UV_NO_CACHE"] = "1"
116
-
117
- res = subprocess.run(
118
- [
119
- "uv",
120
- "run",
121
- "cogames",
122
- "validate-policy",
123
- "--season",
124
- season,
125
- "--server",
126
- server,
127
- "--policy",
128
- f"file://./{bundle_name}",
129
- ],
130
- cwd=temp_dir,
131
- capture_output=True,
132
- text=True,
133
- timeout=300,
134
- env=env,
135
- )
136
- if res.returncode != 0:
137
- console.print("[red]Validation failed[/red]")
138
- console.print(f"\n[red]Error:[/red]\n{res.stderr}")
139
- if res.stdout:
140
- console.print(f"\n[dim]Output:[/dim]\n{res.stdout}")
141
- return False
142
-
143
- console.print("[green]Validation passed[/green]")
144
- return True
145
- finally:
146
- if temp_dir.exists():
147
- shutil.rmtree(temp_dir)
148
-
149
-
150
- def get_latest_pypi_version(package: str) -> str:
151
- """Get the latest published version of a package from PyPI."""
152
- response = httpx.get(f"https://pypi.org/pypi/{package}/json")
153
- response.raise_for_status()
154
- return response.json()["info"]["version"]
155
-
156
-
157
- def get_pypi_requires_python(package: str) -> str:
158
- """Get the requires_python constraint from PyPI."""
159
- response = httpx.get(f"https://pypi.org/pypi/{package}/json")
160
- response.raise_for_status()
161
- return response.json()["info"]["requires_python"]
162
-
163
-
164
- def create_temp_validation_env() -> Path:
165
- """Create a temporary directory with a minimal pyproject.toml.
166
-
167
- The pyproject.toml depends on the latest published cogames and cogames-agents packages.
168
- Python version is constrained to match mettagrid's published wheels.
169
- """
170
- temp_dir = Path(tempfile.mkdtemp(prefix="cogames_submit_"))
171
-
172
- latest_cogames_version = get_latest_pypi_version("cogames")
173
- latest_agents_version = get_latest_pypi_version("cogames-agents")
174
- mettagrid_requires_python = get_pypi_requires_python("mettagrid")
175
-
176
- pyproject_content = f"""[project]
177
- name = "cogames-submission-validator"
178
- version = "0.1.0"
179
- requires-python = "{mettagrid_requires_python}"
180
- dependencies = ["cogames=={latest_cogames_version}", "cogames-agents=={latest_agents_version}"]
181
-
182
- [build-system]
183
- requires = ["setuptools>=42"]
184
- build-backend = "setuptools.build_meta"
185
- """
186
-
187
- pyproject_path = temp_dir / "pyproject.toml"
188
- pyproject_path.write_text(pyproject_content)
189
-
190
- return temp_dir
191
-
192
-
193
- def copy_files_maintaining_structure(files: list[Path], dest_dir: Path) -> None:
194
- """Copy files to destination, maintaining directory structure.
195
-
196
- If a file is 'train_dir/model.pt', it will be copied to 'dest_dir/train_dir/model.pt'.
197
- """
198
- for file_path in files:
199
- dest_path = dest_dir / file_path
200
-
201
- if file_path.is_dir():
202
- shutil.copytree(file_path, dest_path, dirs_exist_ok=True)
203
- else:
204
- dest_path.parent.mkdir(parents=True, exist_ok=True)
205
- shutil.copy2(file_path, dest_path)
206
-
207
-
208
- _SEASON_VALIDATION_MISSIONS: dict[str, str] = {
209
- "beta": "training_facility.harvest",
210
- "beta-cogsguard": "cogsguard_arena.basic",
211
- }
212
- _DEFAULT_VALIDATION_MISSION = "cogsguard_arena.basic"
213
-
214
-
215
- def get_validation_mission_for_season(season: str | None = None) -> str:
216
- """Get the appropriate mission for validating policies in a given season."""
217
- if season is None:
218
- return _DEFAULT_VALIDATION_MISSION
219
- return _SEASON_VALIDATION_MISSIONS.get(season, _DEFAULT_VALIDATION_MISSION)
220
-
221
-
222
- def validate_policy_spec(
223
- policy_spec: PolicySpec,
224
- env_cfg: MettaGridConfig | None = None,
225
- *,
226
- device: str = "cpu",
227
- season: str | None = None,
228
- ) -> None:
229
- """Validate policy works.
230
-
231
- Runs a single episode (up to 10 steps) using the same alo rollout flow as `cogames eval`.
232
-
233
- Args:
234
- policy_spec: The policy to validate.
235
- env_cfg: Optional environment config to validate against.
236
- device: Target device for policy evaluation (cpu/cuda/auto).
237
- season: Optional season name to determine which game to validate against.
238
- """
239
- if env_cfg is None:
240
- mission_name = get_validation_mission_for_season(season)
241
- # Legacy seasons (e.g., "beta") use legacy missions that require include_legacy=True
242
- include_legacy = season in _SEASON_VALIDATION_MISSIONS
243
- _, env_cfg, _ = get_mission(mission_name, include_legacy=include_legacy)
244
- else:
245
- env_cfg = env_cfg.model_copy()
246
-
247
- # Run 1 episode for up to 10 steps to validate the policy works
248
- env_cfg.game.max_steps = 10
249
- n = env_cfg.game.num_agents
250
- n_submitted = min(2, n)
251
- noop_spec = PolicySpec(class_path="mettagrid.policy.noop.NoopPolicy")
252
- if n_submitted < n:
253
- policy_specs = [policy_spec, noop_spec]
254
- assignments = [0] * n_submitted + [1] * (n - n_submitted)
255
- else:
256
- policy_specs = [policy_spec]
257
- assignments = [0] * n
258
- run_episode_local(
259
- policy_specs=policy_specs,
260
- assignments=assignments,
261
- env=env_cfg,
262
- seed=42,
263
- max_action_time_ms=10000,
264
- device=device,
265
- )
266
-
267
-
268
- def validate_policy_uri(policy_uri: str, env_cfg: MettaGridConfig, *, device: str = "cpu") -> None:
269
- policy_spec = parse_policy_spec(policy_uri, device=device).to_policy_spec()
270
- validate_policy_spec(policy_spec, env_cfg, device=device)
271
-
272
-
273
- def validate_policy_in_isolation(
274
- policy_spec: PolicySpec,
275
- include_files: list[Path],
276
- console: Console,
277
- setup_script: str | None = None,
278
- *,
279
- season: str,
280
- server: str,
281
- ) -> bool:
282
- def _format_policy_arg(spec: PolicySpec) -> str:
283
- parts = [f"class={spec.class_path}"]
284
- if spec.data_path:
285
- parts.append(f"data={spec.data_path}")
286
- for key, value in spec.init_kwargs.items():
287
- parts.append(f"kw.{key}={value}")
288
- return ",".join(parts)
289
-
290
- console.print("[dim]Testing policy can run 10 steps...[/dim]")
291
-
292
- temp_dir = create_temp_validation_env()
293
- try:
294
- copy_files_maintaining_structure(include_files, temp_dir)
295
-
296
- policy_arg = _format_policy_arg(policy_spec)
297
-
298
- def _run_from_tmp_dir(cmd: list[str]) -> subprocess.CompletedProcess[str]:
299
- env = os.environ.copy()
300
- env["UV_NO_CACHE"] = "1"
301
- res = subprocess.run(
302
- cmd,
303
- cwd=temp_dir,
304
- capture_output=True,
305
- text=True,
306
- timeout=300,
307
- env=env,
308
- )
309
- if not res.returncode == 0:
310
- console.print("[red]Validation failed[/red]")
311
- console.print(f"\n[red]Error:[/red]\n{res.stderr}")
312
- if res.stdout:
313
- console.print(f"\n[dim]Output:[/dim]\n{res.stdout}")
314
- raise Exception("Validation failed")
315
- return res
316
-
317
- _run_from_tmp_dir(["uv", "run", "cogames", "version"])
318
-
319
- validate_cmd = [
320
- "uv",
321
- "run",
322
- "cogames",
323
- "validate-policy",
324
- "--policy",
325
- policy_arg,
326
- "--season",
327
- season,
328
- "--server",
329
- server,
330
- ]
331
- if setup_script:
332
- validate_cmd.extend(["--setup-script", setup_script])
333
-
334
- _run_from_tmp_dir(validate_cmd)
335
-
336
- console.print("[green]Validation passed[/green]")
337
- return True
338
- finally:
339
- if temp_dir.exists():
340
- shutil.rmtree(temp_dir)
68
+ zipf.write(file_path, arcname=file_path.relative_to(src))
341
69
 
342
70
 
343
71
  def _collect_ancestor_init_files(include_files: list[Path]) -> list[Path]:
@@ -392,11 +120,105 @@ def create_submission_zip(
392
120
  return Path(zip_path)
393
121
 
394
122
 
123
+ def create_bundle(
124
+ ctx: typer.Context,
125
+ policy: str,
126
+ output: Path,
127
+ include_files: list[str] | None = None,
128
+ init_kwargs: dict[str, str] | None = None,
129
+ setup_script: str | None = None,
130
+ ) -> Path:
131
+ # TODO: Unify the two paths below. For URI inputs, extract the PolicySpec from the
132
+ # bundle so we can apply init_kwargs/include_files/setup_script, then re-zip.
133
+ local = localize_uri(policy) if parse_uri(policy, allow_none=True, default_scheme=None) else None
134
+ if local is not None:
135
+ if init_kwargs or include_files or setup_script:
136
+ console.print("[red]Error:[/red] Extra files/kwargs are not supported with bundle URIs.")
137
+ raise typer.Exit(1)
138
+ console.print(f"[dim]Packaging existing bundle: {local}[/dim]")
139
+ if local.is_dir():
140
+ _zip_directory_to(local, output)
141
+ else:
142
+ shutil.copy2(local, output)
143
+ console.print(f"[dim]Bundle size: {output.stat().st_size / 1024:.0f} KB[/dim]")
144
+ return output
145
+
146
+ policy_spec = get_policy_spec(ctx, policy)
147
+ console.print(f"[dim]Policy class: {policy_spec.class_path}[/dim]")
148
+
149
+ if init_kwargs:
150
+ merged_kwargs = {**policy_spec.init_kwargs, **init_kwargs}
151
+ policy_spec = PolicySpec(
152
+ class_path=policy_spec.class_path,
153
+ data_path=policy_spec.data_path,
154
+ init_kwargs=merged_kwargs,
155
+ )
156
+
157
+ if policy_spec.init_kwargs:
158
+ console.print(f"[dim]Init kwargs: {policy_spec.init_kwargs}[/dim]")
159
+
160
+ cwd = Path.cwd().resolve()
161
+ if policy_spec.data_path:
162
+ data_rel = str(_resolve_path_within_cwd(policy_spec.data_path, cwd))
163
+ policy_spec = PolicySpec(
164
+ class_path=policy_spec.class_path,
165
+ data_path=data_rel,
166
+ init_kwargs=policy_spec.init_kwargs,
167
+ )
168
+ console.print(f"[dim]Data path: {data_rel}[/dim]")
169
+
170
+ setup_script_rel: str | None = None
171
+ if setup_script:
172
+ setup_script_rel = str(_resolve_path_within_cwd(setup_script, cwd))
173
+ console.print(f"[dim]Setup script: {setup_script_rel}[/dim]")
174
+
175
+ files_to_include = []
176
+ if policy_spec.data_path:
177
+ files_to_include.append(policy_spec.data_path)
178
+ if setup_script_rel:
179
+ files_to_include.append(setup_script_rel)
180
+ if include_files:
181
+ files_to_include.extend(include_files)
182
+
183
+ validated_paths: list[Path] = []
184
+ if files_to_include:
185
+ validated_paths = validate_paths(files_to_include)
186
+ console.print(f"[dim]Including {len(validated_paths)} file(s)[/dim]")
187
+
188
+ tmp_zip = create_submission_zip(validated_paths, policy_spec, setup_script=setup_script_rel)
189
+ shutil.move(str(tmp_zip), str(output))
190
+ console.print(f"[dim]Bundle size: {output.stat().st_size / 1024:.0f} KB[/dim]")
191
+ return output
192
+
193
+
194
+ def validate_bundle(policy_uri: str, env_cfg: MettaGridConfig) -> None:
195
+ """Validate a policy bundle by running a short episode in process isolation."""
196
+ env_cfg.game.max_steps = 10
197
+
198
+ spec = EpisodeSpec(
199
+ policy_uris=[policy_uri],
200
+ assignments=[0] * env_cfg.game.num_agents,
201
+ env=env_cfg,
202
+ seed=42,
203
+ max_action_time_ms=10000,
204
+ )
205
+
206
+ with tempfile.NamedTemporaryFile(suffix=".json", delete=True) as results_file:
207
+ res = run_episode_isolated(spec, Path(results_file.name))
208
+ console.print(f"[dim]Ran for {res.steps} steps[/dim]")
209
+
210
+ non_noop_actions = sum(
211
+ v for k, v in res.stats["agent"][0].items() if k.startswith("action.") and ".noop." not in k
212
+ )
213
+ if non_noop_actions == 0:
214
+ console.print("[yellow]Warning: Policy took no actions (all no-ops)[/yellow]")
215
+ raise typer.Exit(1)
216
+
217
+
395
218
  def upload_submission(
396
219
  client: TournamentServerClient,
397
220
  zip_path: Path,
398
221
  submission_name: str,
399
- console: Console,
400
222
  season: str | None = None,
401
223
  ) -> UploadResult | None:
402
224
  """Upload submission to CoGames backend using a presigned S3 URL."""
@@ -443,55 +265,6 @@ def upload_submission(
443
265
  raise ValueError(f"Invalid submission ID returned: {submission_id}") from exc
444
266
 
445
267
 
446
- def _upload_policy_bundle(
447
- bundle_result: tuple[Path, bool],
448
- *,
449
- client: TournamentServerClient,
450
- name: str,
451
- console: Console,
452
- init_kwargs: dict[str, str] | None,
453
- include_files: list[str] | None,
454
- setup_script: str | None,
455
- skip_validation: bool,
456
- dry_run: bool,
457
- validation_season: str,
458
- server: str,
459
- season: str | None = None,
460
- ) -> UploadResult | None:
461
- bundle_zip, cleanup_bundle_zip = bundle_result
462
-
463
- try:
464
- if init_kwargs or include_files or setup_script:
465
- console.print("[red]Error:[/red] Extra files/kwargs are not supported when uploading a bundle URI.")
466
- console.print(
467
- "Upload the bundle as-is, or use a short policy name or URI with local files to "
468
- "build a new submission zip."
469
- )
470
- return None
471
-
472
- if not skip_validation:
473
- if not validate_bundle_in_isolation(bundle_zip, console, season=validation_season, server=server):
474
- console.print("\n[red]Upload aborted due to validation failure.[/red]")
475
- return None
476
- else:
477
- console.print("[dim]Skipping validation[/dim]")
478
-
479
- if dry_run:
480
- console.print("[green]Dry run complete[/green]")
481
- return None
482
-
483
- with client:
484
- result = upload_submission(client, bundle_zip, name, console, season=season)
485
- if not result:
486
- console.print("\n[red]Upload failed.[/red]")
487
- return None
488
-
489
- return result
490
- finally:
491
- if cleanup_bundle_zip and bundle_zip.exists():
492
- bundle_zip.unlink()
493
-
494
-
495
268
  def upload_policy(
496
269
  ctx: typer.Context,
497
270
  policy: str,
@@ -503,7 +276,6 @@ def upload_policy(
503
276
  dry_run: bool = False,
504
277
  skip_validation: bool = False,
505
278
  setup_script: str | None = None,
506
- validation_season: str = "",
507
279
  season: str | None = None,
508
280
  ) -> UploadResult | None:
509
281
  if dry_run:
@@ -513,99 +285,47 @@ def upload_policy(
513
285
  if not client:
514
286
  return None
515
287
 
516
- bundle_result = _maybe_resolve_checkpoint_bundle_uri(policy)
288
+ with tempfile.TemporaryDirectory(prefix="cogames_bundle_") as tmp_dir:
289
+ zip_path = Path(tmp_dir) / "bundle.zip"
517
290
 
518
- if bundle_result is not None:
519
- return _upload_policy_bundle(
520
- bundle_result,
521
- client=client,
522
- name=name,
523
- console=console,
524
- init_kwargs=init_kwargs,
291
+ create_bundle(
292
+ ctx=ctx,
293
+ policy=policy,
294
+ output=zip_path,
525
295
  include_files=include_files,
296
+ init_kwargs=init_kwargs,
526
297
  setup_script=setup_script,
527
- skip_validation=skip_validation,
528
- dry_run=dry_run,
529
- validation_season=validation_season,
530
- server=server,
531
- season=season,
532
- )
533
-
534
- policy_spec = get_policy_spec(ctx, policy)
535
-
536
- if init_kwargs:
537
- merged_kwargs = {**policy_spec.init_kwargs, **init_kwargs}
538
- policy_spec = PolicySpec(
539
- class_path=policy_spec.class_path,
540
- data_path=policy_spec.data_path,
541
- init_kwargs=merged_kwargs,
542
298
  )
543
299
 
544
- cwd = Path.cwd().resolve()
545
- if policy_spec.data_path:
546
- resolved = Path(policy_spec.data_path).expanduser().resolve()
547
- if not resolved.is_relative_to(cwd):
548
- console.print("[red]Error:[/red] Policy weights path must be within the current directory.")
549
- console.print(f"[dim]{policy_spec.data_path}[/dim]")
550
- raise ValueError("Policy weights path must be within the current directory.")
551
- data_rel = str(resolved.relative_to(cwd))
552
- policy_spec = PolicySpec(
553
- class_path=policy_spec.class_path,
554
- data_path=data_rel,
555
- init_kwargs=policy_spec.init_kwargs,
556
- )
557
-
558
- setup_script_rel: str | None = None
559
- if setup_script:
560
- resolved = Path(setup_script).expanduser().resolve()
561
- if not resolved.is_relative_to(cwd):
562
- console.print("[red]Error:[/red] Setup script path must be within the current directory.")
563
- console.print(f"[dim]{setup_script}[/dim]")
564
- raise ValueError("Setup script path must be within the current directory.")
565
- setup_script_rel = str(resolved.relative_to(cwd))
566
-
567
- files_to_include = []
568
- if policy_spec.data_path:
569
- files_to_include.append(policy_spec.data_path)
570
- if setup_script_rel:
571
- files_to_include.append(setup_script_rel)
572
- if include_files:
573
- files_to_include.extend(include_files)
300
+ if not skip_validation:
301
+ cmd = [
302
+ sys.executable,
303
+ "-m",
304
+ "cogames",
305
+ "validate-bundle",
306
+ "--policy",
307
+ zip_path.as_uri(),
308
+ "--server",
309
+ server,
310
+ ]
311
+ if season:
312
+ cmd.extend(["--season", season])
313
+ result = subprocess.run(cmd, text=True, timeout=300)
314
+ if result.returncode != 0:
315
+ console.print("[red]Validation failed[/red]")
316
+ return None
317
+ console.print("[green]Validation passed[/green]")
318
+ else:
319
+ console.print("[dim]Skipping validation[/dim]")
574
320
 
575
- validated_paths: list[Path] = []
576
- if files_to_include:
577
- validated_paths = validate_paths(files_to_include, console)
578
-
579
- if not skip_validation:
580
- if not validate_policy_in_isolation(
581
- policy_spec,
582
- validated_paths,
583
- console,
584
- setup_script=setup_script_rel,
585
- season=validation_season,
586
- server=server,
587
- ):
588
- console.print("\n[red]Upload aborted due to validation failure.[/red]")
321
+ if dry_run:
322
+ console.print("[green]Dry run complete[/green]")
589
323
  return None
590
- else:
591
- console.print("[dim]Skipping validation[/dim]")
592
-
593
- zip_path = create_submission_zip(validated_paths, policy_spec, setup_script=setup_script_rel)
594
-
595
- if dry_run:
596
- console.print("[green]Dry run complete[/green]")
597
- if zip_path.exists():
598
- zip_path.unlink()
599
- return None
600
324
 
601
- try:
602
325
  with client:
603
- result = upload_submission(client, zip_path, name, console, season=season)
326
+ result = upload_submission(client, zip_path, name, season=season)
604
327
  if not result:
605
328
  console.print("\n[red]Upload failed.[/red]")
606
329
  return None
607
330
 
608
331
  return result
609
- finally:
610
- if zip_path.exists():
611
- zip_path.unlink()
cogames/main.py CHANGED
@@ -63,7 +63,13 @@ from cogames.cli.policy import (
63
63
  policy_arg_example,
64
64
  policy_arg_w_proportion_example,
65
65
  )
66
- from cogames.cli.submit import DEFAULT_SUBMIT_SERVER, results_url_for_season, upload_policy, validate_policy_spec
66
+ from cogames.cli.submit import (
67
+ DEFAULT_SUBMIT_SERVER,
68
+ RESULTS_URL,
69
+ create_bundle,
70
+ upload_policy,
71
+ validate_bundle,
72
+ )
67
73
  from cogames.cogs_vs_clips.mission import CvCMission, NumCogsVariant
68
74
  from cogames.curricula import make_rotation
69
75
  from cogames.device import resolve_training_device
@@ -2228,12 +2234,12 @@ def _resolve_season(server: str, season_name: str | None = None) -> SeasonInfo:
2228
2234
 
2229
2235
 
2230
2236
  @app.command(
2231
- name="validate-policy",
2232
- help="Validate the policy loads and runs for at least a single step",
2237
+ name="create-bundle",
2238
+ help="Create a submission bundle zip from a policy",
2233
2239
  rich_help_panel="Policies",
2234
2240
  add_help_option=False,
2235
2241
  )
2236
- def validate_policy_cmd(
2242
+ def create_bundle_cmd(
2237
2243
  ctx: typer.Context,
2238
2244
  policy: str = typer.Option(
2239
2245
  ...,
@@ -2243,18 +2249,77 @@ def validate_policy_cmd(
2243
2249
  help=f"Policy specification: {policy_arg_example}",
2244
2250
  rich_help_panel="Policy",
2245
2251
  ),
2246
- device: str = typer.Option(
2247
- "auto",
2248
- "--device",
2249
- metavar="DEVICE",
2250
- help="Policy device (auto, cpu, cuda, cuda:0, etc.)",
2252
+ output: Path = typer.Option( # noqa: B008
2253
+ Path("submission.zip"),
2254
+ "--output",
2255
+ "-o",
2256
+ metavar="PATH",
2257
+ help="Output path for the bundle zip",
2258
+ rich_help_panel="Output",
2259
+ ),
2260
+ init_kwarg: Optional[list[str]] = typer.Option( # noqa: B008
2261
+ None,
2262
+ "--init-kwarg",
2263
+ "-k",
2264
+ metavar="KEY=VAL",
2265
+ help="Policy init kwargs (can be repeated)",
2251
2266
  rich_help_panel="Policy",
2252
2267
  ),
2268
+ include_files: Optional[list[str]] = typer.Option( # noqa: B008
2269
+ None,
2270
+ "--include-files",
2271
+ "-f",
2272
+ metavar="PATH",
2273
+ help="Files or directories to include (can be repeated)",
2274
+ rich_help_panel="Files",
2275
+ ),
2253
2276
  setup_script: Optional[str] = typer.Option(
2254
2277
  None,
2255
2278
  "--setup-script",
2256
- help="Path to a Python setup script to run before loading the policy",
2257
- rich_help_panel="Policy",
2279
+ metavar="PATH",
2280
+ help="Python setup script to include in the bundle",
2281
+ rich_help_panel="Files",
2282
+ ),
2283
+ _help: bool = typer.Option(
2284
+ False,
2285
+ "--help",
2286
+ "-h",
2287
+ help="Show this message and exit",
2288
+ is_eager=True,
2289
+ callback=_help_callback,
2290
+ rich_help_panel="Other",
2291
+ ),
2292
+ ) -> None:
2293
+ init_kwargs: dict[str, str] = {}
2294
+ if init_kwarg:
2295
+ for kv in init_kwarg:
2296
+ key, val = _parse_init_kwarg(kv)
2297
+ init_kwargs[key] = val
2298
+
2299
+ result_path = create_bundle(
2300
+ ctx=ctx,
2301
+ policy=policy,
2302
+ output=output.resolve(),
2303
+ include_files=include_files,
2304
+ init_kwargs=init_kwargs if init_kwargs else None,
2305
+ setup_script=setup_script,
2306
+ )
2307
+ console.print(f"[green]Bundle created:[/green] {result_path}")
2308
+
2309
+
2310
+ @app.command(
2311
+ name="validate-bundle",
2312
+ help="Validate a policy bundle runs correctly in process isolation",
2313
+ rich_help_panel="Policies",
2314
+ add_help_option=False,
2315
+ )
2316
+ def validate_bundle_cmd(
2317
+ policy: str = typer.Option(
2318
+ ...,
2319
+ "--policy",
2320
+ "-p",
2321
+ metavar="URI",
2322
+ help="Bundle URI (file://, s3://, or local path to .zip or directory)",
2258
2323
  ),
2259
2324
  season: Optional[str] = typer.Option(
2260
2325
  None,
@@ -2289,37 +2354,8 @@ def validate_policy_cmd(
2289
2354
  with TournamentServerClient(server_url=server) as client:
2290
2355
  config_data = client.get_config(entry_pool_info.config_id)
2291
2356
  env_cfg = MettaGridConfig.model_validate(config_data)
2357
+ validate_bundle(policy, env_cfg)
2292
2358
 
2293
- if setup_script:
2294
- import subprocess # noqa: PLC0415
2295
- import sys # noqa: PLC0415
2296
- from pathlib import Path # noqa: PLC0415
2297
-
2298
- script_path = Path(setup_script)
2299
- if not script_path.exists():
2300
- console.print(f"[red]Setup script not found: {setup_script}[/red]")
2301
- raise typer.Exit(1)
2302
- console.print(f"[yellow]Running setup script: {setup_script}[/yellow]")
2303
- result = subprocess.run(
2304
- [sys.executable, str(script_path)],
2305
- cwd=Path.cwd(),
2306
- capture_output=True,
2307
- text=True,
2308
- timeout=300,
2309
- )
2310
- if result.returncode != 0:
2311
- console.print(f"[red]Setup script failed:[/red]\n{result.stderr}")
2312
- raise typer.Exit(1)
2313
- console.print("[green]Setup script completed[/green]")
2314
-
2315
- resolved_device = resolve_training_device(console, device)
2316
- policy_spec = get_policy_spec(ctx, policy, device=str(resolved_device))
2317
- validate_policy_spec(
2318
- policy_spec,
2319
- env_cfg,
2320
- device=str(resolved_device),
2321
- season=season_info.name,
2322
- )
2323
2359
  console.print("[green]Policy validated successfully[/green]")
2324
2360
  raise typer.Exit(0)
2325
2361
 
@@ -2446,11 +2482,6 @@ def upload_cmd(
2446
2482
  ) -> None:
2447
2483
  season_info = _resolve_season(server, season)
2448
2484
 
2449
- has_entry_config = any(p.config_id for p in season_info.pools if p.name == season_info.entry_pool)
2450
- if not has_entry_config and not skip_validation:
2451
- console.print("[yellow]Warning: No entry config found for season. Skipping validation.[/yellow]")
2452
- skip_validation = True
2453
-
2454
2485
  init_kwargs: dict[str, str] = {}
2455
2486
  if init_kwarg:
2456
2487
  for kv in init_kwarg:
@@ -2468,7 +2499,6 @@ def upload_cmd(
2468
2499
  skip_validation=skip_validation,
2469
2500
  init_kwargs=init_kwargs if init_kwargs else None,
2470
2501
  setup_script=setup_script,
2471
- validation_season=season_info.name,
2472
2502
  season=season_info.name if not no_submit else None,
2473
2503
  )
2474
2504
 
@@ -2476,7 +2506,7 @@ def upload_cmd(
2476
2506
  console.print(f"[green]Upload complete: {result.name}:v{result.version}[/green]")
2477
2507
  if result.pools:
2478
2508
  console.print(f"[dim]Added to pools: {', '.join(result.pools)}[/dim]")
2479
- console.print(f"[dim]Results:[/dim] {results_url_for_season(server, season_info.name)}")
2509
+ console.print(f"[dim]Results:[/dim] {RESULTS_URL}")
2480
2510
  elif no_submit:
2481
2511
  console.print(f"\nTo submit to a tournament: cogames submit {result.name}:v{result.version}")
2482
2512
 
@@ -2574,7 +2604,7 @@ def submit_cmd(
2574
2604
  console.print(f"\n[bold green]Submitted to season '{season_name}'[/bold green]")
2575
2605
  if result.pools:
2576
2606
  console.print(f"[dim]Added to pools: {', '.join(result.pools)}[/dim]")
2577
- console.print(f"[dim]Results:[/dim] {results_url_for_season(server, season_name)}")
2607
+ console.print(f"[dim]Results:[/dim] {RESULTS_URL}")
2578
2608
  console.print(f"[dim]CLI:[/dim] cogames leaderboard --season {season_name}")
2579
2609
 
2580
2610
 
@@ -1,13 +1,13 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cogames
3
- Version: 0.3.68
3
+ Version: 0.3.69
4
4
  Summary: Multi-agent cooperative games
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: Programming Language :: Python :: 3.12
7
7
  Requires-Python: <3.13,>=3.12
8
8
  Description-Content-Type: text/markdown
9
9
  License-File: LICENSE
10
- Requires-Dist: mettagrid==0.2.0.82
10
+ Requires-Dist: mettagrid==0.2.0.83
11
11
  Requires-Dist: packaging>=24.0.0
12
12
  Requires-Dist: pufferlib-core
13
13
  Requires-Dist: pydantic>=2.11.5
@@ -5,7 +5,7 @@ cogames/curricula.py,sha256=a7Nd-av4epAjmoqlefyU4q5eAitpYjCEGDCF_tS_iaY,825
5
5
  cogames/device.py,sha256=GVC7g4tNVySn_rSbHJB0jGKvpYBL8-VmeYEQXWXtvy0,1680
6
6
  cogames/evaluate.py,sha256=Qc3KopiDJs4iF2qpBajQYh2-7HnFWpOk1enexVhEZPo,9358
7
7
  cogames/game.py,sha256=Qpyd4UOv97S-PCcEb422XFVbuXk46qhEx7_FlAr8ayA,3331
8
- cogames/main.py,sha256=yBhcgEl8__0ogMybctDwY4tVGy5lXy8WYJqzLN11pu4,90307
8
+ cogames/main.py,sha256=Z3xRk6JCPzVLFUr1SxiNIU4oLk1V16vVLkQi4ea7_D8,90453
9
9
  cogames/pickup.py,sha256=kZUR19_2HzUK3ltc3GyglEv0kgUCPvm8feeCzC9EKDI,6578
10
10
  cogames/play.py,sha256=waSx8523tOWX8_KrcjlgV50USU90wGhYmpSCPchdJkg,7882
11
11
  cogames/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -17,7 +17,7 @@ cogames/cli/leaderboard.py,sha256=eIbiSBOkvoi5IL2weiKBOJAjWGHquDk1PWoHKA8uu3g,13
17
17
  cogames/cli/login.py,sha256=_i1Hdbp_wAMsX0NLbCTSr7GmOFbzSqwyNP79OqLOg40,1237
18
18
  cogames/cli/mission.py,sha256=X8vAdWz2qBKdrB4gh9xuqoRb_L39bYzdXUf9l-eGc28,22624
19
19
  cogames/cli/policy.py,sha256=v9MyddJ56lkTHJ195HyWswVqVXLyxuCI064QSerOCWg,9513
20
- cogames/cli/submit.py,sha256=xEGeJA00eber_SHlAB2b21VaVPwvcLCbjDABPbkAKFM,20665
20
+ cogames/cli/submit.py,sha256=xJyBH0LS-y2q9V8l-iupnVei7CvXo4gW43VFIEu_ldk,11536
21
21
  cogames/cli/utils.py,sha256=VLTFNIAVNFPL_dUtJqLT7umE_0rrJSFdbo0kntmGmKE,1266
22
22
  cogames/cli/docsync/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
23
23
  cogames/cli/docsync/_nb_md_directive_processing.py,sha256=A920oS21mw9ZU2vp-cILzq5Sa-AMtmwMSayZKcQwX9M,6365
@@ -149,12 +149,12 @@ cogames/policy/trainable_policy_template.py,sha256=GbB4vi0wsIiuh_APtfjnavvh9ZVUx
149
149
  cogames/policy/nim_agents/__init__.py,sha256=ZMJgUljlvBoa3swAupfOyS6jF2jIwGPX-nL_T9vFmAw,1181
150
150
  cogames/policy/nim_agents/agents.py,sha256=Do9WIgeodBi99H5UfCN0MlCjosb8gxaSl2cojIbiK9o,4528
151
151
  cogames/policy/nim_agents/thinky_eval.py,sha256=NqIZng30MpyZD5hF03uoje4potd873hWyHOb3AZdeZk,1154
152
- cogames-0.3.68.dist-info/licenses/LICENSE,sha256=NG4hf0NHdGZhkabVCK1MKX8RAJmWaEm4eaGcUsONZ_E,1065
152
+ cogames-0.3.69.dist-info/licenses/LICENSE,sha256=NG4hf0NHdGZhkabVCK1MKX8RAJmWaEm4eaGcUsONZ_E,1065
153
153
  metta_alo/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
154
154
  metta_alo/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
155
155
  metta_alo/scoring.py,sha256=PcGVUXmxh4H4BoEnkMQaxrwxf06BGTqRZVSnOTPzG9M,5875
156
- cogames-0.3.68.dist-info/METADATA,sha256=RD6b9CR38DP0EhBiI7TQNkncvgrqJlpASQlO04h2nO8,242101
157
- cogames-0.3.68.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
158
- cogames-0.3.68.dist-info/entry_points.txt,sha256=GTHdgj-RC2hQYmjUMSp9RHX8kbru19k0LS2lAj8DnLE,45
159
- cogames-0.3.68.dist-info/top_level.txt,sha256=YErBkYWJd3-eksLpbgbMrETni1MPBNL4mqEyhZUa0UE,18
160
- cogames-0.3.68.dist-info/RECORD,,
156
+ cogames-0.3.69.dist-info/METADATA,sha256=mIF-Ot6rJWtbI39HxTGlFqG7J74EUxOh9pYQJYKsGJo,242101
157
+ cogames-0.3.69.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
158
+ cogames-0.3.69.dist-info/entry_points.txt,sha256=GTHdgj-RC2hQYmjUMSp9RHX8kbru19k0LS2lAj8DnLE,45
159
+ cogames-0.3.69.dist-info/top_level.txt,sha256=YErBkYWJd3-eksLpbgbMrETni1MPBNL4mqEyhZUa0UE,18
160
+ cogames-0.3.69.dist-info/RECORD,,