cogames 0.3.49__py3-none-any.whl → 0.3.64__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169) hide show
  1. cogames/cli/client.py +60 -6
  2. cogames/cli/docsync/__init__.py +0 -0
  3. cogames/cli/docsync/_nb_md_directive_processing.py +180 -0
  4. cogames/cli/docsync/_nb_md_sync.py +103 -0
  5. cogames/cli/docsync/_nb_py_sync.py +122 -0
  6. cogames/cli/docsync/_three_way_sync.py +115 -0
  7. cogames/cli/docsync/_utils.py +76 -0
  8. cogames/cli/docsync/docsync.py +156 -0
  9. cogames/cli/leaderboard.py +112 -28
  10. cogames/cli/mission.py +64 -53
  11. cogames/cli/policy.py +46 -10
  12. cogames/cli/submit.py +268 -67
  13. cogames/cogs_vs_clips/cog.py +79 -0
  14. cogames/cogs_vs_clips/cogs_vs_clips_mapgen.md +19 -16
  15. cogames/cogs_vs_clips/cogsguard_reward_variants.py +153 -0
  16. cogames/cogs_vs_clips/cogsguard_tutorial.py +56 -0
  17. cogames/cogs_vs_clips/evals/README.md +10 -16
  18. cogames/cogs_vs_clips/evals/cogsguard_evals.py +81 -0
  19. cogames/cogs_vs_clips/evals/diagnostic_evals.py +49 -444
  20. cogames/cogs_vs_clips/evals/difficulty_variants.py +13 -326
  21. cogames/cogs_vs_clips/evals/integrated_evals.py +5 -45
  22. cogames/cogs_vs_clips/evals/spanning_evals.py +9 -180
  23. cogames/cogs_vs_clips/mission.py +187 -146
  24. cogames/cogs_vs_clips/missions.py +46 -137
  25. cogames/cogs_vs_clips/procedural.py +8 -8
  26. cogames/cogs_vs_clips/sites.py +107 -3
  27. cogames/cogs_vs_clips/stations.py +198 -186
  28. cogames/cogs_vs_clips/tutorial_missions.py +1 -1
  29. cogames/cogs_vs_clips/variants.py +25 -476
  30. cogames/device.py +13 -1
  31. cogames/{policy/scripted_agent/README.md → docs/SCRIPTED_AGENT.md} +82 -58
  32. cogames/evaluate.py +18 -30
  33. cogames/main.py +1434 -243
  34. cogames/maps/canidate1_1000.map +1 -1
  35. cogames/maps/canidate1_1000_stations.map +2 -2
  36. cogames/maps/canidate1_500.map +1 -1
  37. cogames/maps/canidate1_500_stations.map +2 -2
  38. cogames/maps/canidate2_1000.map +1 -1
  39. cogames/maps/canidate2_1000_stations.map +2 -2
  40. cogames/maps/canidate2_500.map +1 -1
  41. cogames/maps/canidate2_500_stations.map +2 -2
  42. cogames/maps/canidate3_1000.map +1 -1
  43. cogames/maps/canidate3_1000_stations.map +2 -2
  44. cogames/maps/canidate3_500.map +1 -1
  45. cogames/maps/canidate3_500_stations.map +2 -2
  46. cogames/maps/canidate4_500.map +1 -1
  47. cogames/maps/canidate4_500_stations.map +2 -2
  48. cogames/maps/cave_base_50.map +2 -2
  49. cogames/maps/diagnostic_evals/diagnostic_agile.map +2 -2
  50. cogames/maps/diagnostic_evals/diagnostic_agile_hard.map +2 -2
  51. cogames/maps/diagnostic_evals/diagnostic_charge_up.map +2 -2
  52. cogames/maps/diagnostic_evals/diagnostic_charge_up_hard.map +2 -2
  53. cogames/maps/diagnostic_evals/diagnostic_chest_navigation1.map +2 -2
  54. cogames/maps/diagnostic_evals/diagnostic_chest_navigation1_hard.map +2 -2
  55. cogames/maps/diagnostic_evals/diagnostic_chest_navigation2.map +2 -2
  56. cogames/maps/diagnostic_evals/diagnostic_chest_navigation2_hard.map +2 -2
  57. cogames/maps/diagnostic_evals/diagnostic_chest_navigation3.map +2 -2
  58. cogames/maps/diagnostic_evals/diagnostic_chest_navigation3_hard.map +2 -2
  59. cogames/maps/diagnostic_evals/diagnostic_chest_near.map +2 -2
  60. cogames/maps/diagnostic_evals/diagnostic_chest_search.map +2 -2
  61. cogames/maps/diagnostic_evals/diagnostic_chest_search_hard.map +2 -2
  62. cogames/maps/diagnostic_evals/diagnostic_extract_lab.map +2 -2
  63. cogames/maps/diagnostic_evals/diagnostic_extract_lab_hard.map +2 -2
  64. cogames/maps/diagnostic_evals/diagnostic_memory.map +2 -2
  65. cogames/maps/diagnostic_evals/diagnostic_memory_hard.map +2 -2
  66. cogames/maps/diagnostic_evals/diagnostic_radial.map +2 -2
  67. cogames/maps/diagnostic_evals/diagnostic_radial_hard.map +2 -2
  68. cogames/maps/diagnostic_evals/diagnostic_resource_lab.map +2 -2
  69. cogames/maps/diagnostic_evals/diagnostic_unclip.map +2 -2
  70. cogames/maps/evals/eval_balanced_spread.map +9 -5
  71. cogames/maps/evals/eval_clip_oxygen.map +9 -5
  72. cogames/maps/evals/eval_collect_resources.map +9 -5
  73. cogames/maps/evals/eval_collect_resources_hard.map +9 -5
  74. cogames/maps/evals/eval_collect_resources_medium.map +9 -5
  75. cogames/maps/evals/eval_divide_and_conquer.map +9 -5
  76. cogames/maps/evals/eval_energy_starved.map +9 -5
  77. cogames/maps/evals/eval_multi_coordinated_collect_hard.map +9 -5
  78. cogames/maps/evals/eval_oxygen_bottleneck.map +9 -5
  79. cogames/maps/evals/eval_single_use_world.map +9 -5
  80. cogames/maps/evals/extractor_hub_100x100.map +9 -5
  81. cogames/maps/evals/extractor_hub_30x30.map +9 -5
  82. cogames/maps/evals/extractor_hub_50x50.map +9 -5
  83. cogames/maps/evals/extractor_hub_70x70.map +9 -5
  84. cogames/maps/evals/extractor_hub_80x80.map +9 -5
  85. cogames/maps/machina_100_stations.map +2 -2
  86. cogames/maps/machina_200_stations.map +2 -2
  87. cogames/maps/machina_200_stations_small.map +2 -2
  88. cogames/maps/machina_eval_exp01.map +2 -2
  89. cogames/maps/machina_eval_template_large.map +2 -2
  90. cogames/maps/machinatrainer4agents.map +2 -2
  91. cogames/maps/machinatrainer4agentsbase.map +2 -2
  92. cogames/maps/machinatrainerbig.map +2 -2
  93. cogames/maps/machinatrainersmall.map +2 -2
  94. cogames/maps/planky_evals/aligner_avoid_aoe.map +28 -0
  95. cogames/maps/planky_evals/aligner_full_cycle.map +28 -0
  96. cogames/maps/planky_evals/aligner_gear.map +24 -0
  97. cogames/maps/planky_evals/aligner_hearts.map +24 -0
  98. cogames/maps/planky_evals/aligner_junction.map +26 -0
  99. cogames/maps/planky_evals/exploration_distant.map +28 -0
  100. cogames/maps/planky_evals/maze.map +32 -0
  101. cogames/maps/planky_evals/miner_best_resource.map +26 -0
  102. cogames/maps/planky_evals/miner_deposit.map +24 -0
  103. cogames/maps/planky_evals/miner_extract.map +26 -0
  104. cogames/maps/planky_evals/miner_full_cycle.map +28 -0
  105. cogames/maps/planky_evals/miner_gear.map +24 -0
  106. cogames/maps/planky_evals/multi_role.map +28 -0
  107. cogames/maps/planky_evals/resource_chain.map +30 -0
  108. cogames/maps/planky_evals/scout_explore.map +32 -0
  109. cogames/maps/planky_evals/scout_gear.map +24 -0
  110. cogames/maps/planky_evals/scrambler_full_cycle.map +28 -0
  111. cogames/maps/planky_evals/scrambler_gear.map +24 -0
  112. cogames/maps/planky_evals/scrambler_target.map +26 -0
  113. cogames/maps/planky_evals/stuck_corridor.map +32 -0
  114. cogames/maps/planky_evals/survive_retreat.map +26 -0
  115. cogames/maps/training_facility_clipped.map +2 -2
  116. cogames/maps/training_facility_open_1.map +2 -2
  117. cogames/maps/training_facility_open_2.map +2 -2
  118. cogames/maps/training_facility_open_3.map +2 -2
  119. cogames/maps/training_facility_tight_4.map +2 -2
  120. cogames/maps/training_facility_tight_5.map +2 -2
  121. cogames/maps/vanilla_large.map +2 -2
  122. cogames/maps/vanilla_small.map +2 -2
  123. cogames/pickup.py +183 -0
  124. cogames/play.py +166 -33
  125. cogames/policy/chaos_monkey.py +54 -0
  126. cogames/policy/nim_agents/__init__.py +27 -10
  127. cogames/policy/nim_agents/agents.py +121 -60
  128. cogames/policy/nim_agents/thinky_eval.py +35 -222
  129. cogames/policy/pufferlib_policy.py +67 -32
  130. cogames/policy/starter_agent.py +184 -0
  131. cogames/policy/trainable_policy_template.py +4 -1
  132. cogames/train.py +51 -13
  133. cogames/verbose.py +2 -2
  134. cogames-0.3.64.dist-info/METADATA +1842 -0
  135. cogames-0.3.64.dist-info/RECORD +159 -0
  136. cogames-0.3.64.dist-info/licenses/LICENSE +21 -0
  137. cogames-0.3.64.dist-info/top_level.txt +2 -0
  138. metta_alo/__init__.py +0 -0
  139. metta_alo/job_specs.py +17 -0
  140. metta_alo/policy.py +16 -0
  141. metta_alo/pure_single_episode_runner.py +75 -0
  142. metta_alo/py.typed +0 -0
  143. metta_alo/rollout.py +322 -0
  144. metta_alo/scoring.py +168 -0
  145. cogames/maps/diagnostic_evals/diagnostic_assembler_near.map +0 -49
  146. cogames/maps/diagnostic_evals/diagnostic_assembler_search.map +0 -49
  147. cogames/maps/diagnostic_evals/diagnostic_assembler_search_hard.map +0 -89
  148. cogames/policy/nim_agents/common.nim +0 -887
  149. cogames/policy/nim_agents/install.sh +0 -1
  150. cogames/policy/nim_agents/ladybug_agent.nim +0 -984
  151. cogames/policy/nim_agents/nim_agents.nim +0 -55
  152. cogames/policy/nim_agents/nim_agents.nims +0 -14
  153. cogames/policy/nim_agents/nimby.lock +0 -3
  154. cogames/policy/nim_agents/racecar_agents.nim +0 -884
  155. cogames/policy/nim_agents/random_agents.nim +0 -68
  156. cogames/policy/nim_agents/test_agents.py +0 -53
  157. cogames/policy/nim_agents/thinky_agents.nim +0 -717
  158. cogames/policy/scripted_agent/baseline_agent.py +0 -1049
  159. cogames/policy/scripted_agent/demo_policy.py +0 -244
  160. cogames/policy/scripted_agent/pathfinding.py +0 -126
  161. cogames/policy/scripted_agent/starter_agent.py +0 -136
  162. cogames/policy/scripted_agent/types.py +0 -235
  163. cogames/policy/scripted_agent/unclipping_agent.py +0 -476
  164. cogames/policy/scripted_agent/utils.py +0 -385
  165. cogames-0.3.49.dist-info/METADATA +0 -406
  166. cogames-0.3.49.dist-info/RECORD +0 -136
  167. cogames-0.3.49.dist-info/top_level.txt +0 -1
  168. {cogames-0.3.49.dist-info → cogames-0.3.64.dist-info}/WHEEL +0 -0
  169. {cogames-0.3.49.dist-info → cogames-0.3.64.dist-info}/entry_points.txt +0 -0
cogames/play.py CHANGED
@@ -1,21 +1,160 @@
1
1
  """Game playing functionality for CoGames."""
2
2
 
3
3
  import logging
4
+ import uuid
4
5
  from pathlib import Path
5
6
  from typing import Optional
6
7
 
8
+ from rich import box
7
9
  from rich.console import Console
10
+ from rich.table import Table
8
11
 
12
+ from metta_alo.pure_single_episode_runner import PureSingleEpisodeResult
13
+ from metta_alo.rollout import run_single_episode
9
14
  from mettagrid import MettaGridConfig
10
- from mettagrid.policy.loader import initialize_or_load_policy
11
15
  from mettagrid.policy.policy import PolicySpec
12
- from mettagrid.policy.policy_env_interface import PolicyEnvInterface
13
16
  from mettagrid.renderer.renderer import RenderMode
14
- from mettagrid.simulator.replay_log_writer import ReplayLogWriter
15
- from mettagrid.simulator.rollout import Rollout
16
17
 
17
18
  logger = logging.getLogger("cogames.play")
18
19
 
20
+ # Resources and gear types for CogsGuard
21
+ ELEMENTS = ["carbon", "oxygen", "germanium", "silicon"]
22
+ GEAR = ["miner", "aligner", "scrambler", "scout"]
23
+
24
+
25
def _print_episode_stats(console: Console, results: PureSingleEpisodeResult) -> None:
    """Render a summary table for a finished episode.

    Per-agent stats are summed key-by-key across all agents. If the episode
    stats carry a non-empty "cogs" or "clips" entry under "collective", the
    CogsGuard layout is used; otherwise the generic layout is printed.

    Args:
        console: Rich console the table is written to.
        results: Episode result providing ``stats`` and ``rewards``.
    """
    episode_stats = results.stats
    reward_sum = sum(results.rewards)

    # Fold every agent's counters into a single dict keyed by stat name.
    summed: dict[str, float] = {}
    for per_agent in episode_stats.get("agent", []):
        for name, amount in per_agent.items():
            summed[name] = summed.get(name, 0) + amount

    # Collective (team-level) stats are what distinguish a CogsGuard mission.
    collective = episode_stats.get("collective", {})
    cogs = collective.get("cogs", {})
    clips = collective.get("clips", {})

    if not (cogs or clips):
        # Standard mission - only the aggregated agent stats are available.
        _print_standard_stats(console, summed, reward_sum)
        return

    _print_cogsguard_stats(console, summed, cogs, clips, reward_sum)
48
+
49
+
50
def _print_cogsguard_stats(
    console: Console,
    agent_totals: dict[str, float],
    cogs_stats: dict[str, float],
    clips_stats: dict[str, float],
    total_reward: float,
) -> None:
    """Print the CogsGuard episode table (junctions, gear, resources, reward).

    Each group is only emitted when at least one of its rows has a positive
    value; groups are separated by table sections. The total reward row is
    always printed last.

    Args:
        console: Rich console the table is written to.
        agent_totals: Per-agent stats already summed across agents.
        cogs_stats: Collective stats for the "cogs" team.
        clips_stats: Collective stats for the "clips" team.
        total_reward: Sum of all agent rewards for the episode.
    """
    table = Table(title="Episode Stats", box=box.ROUNDED, show_header=True, header_style="bold cyan")
    table.add_column("Category", style="yellow")
    table.add_column("Stat", style="white")
    for heading, heading_style in (("Gained", "green"), ("Lost", "red"), ("Final", "cyan")):
        table.add_column(heading, style=heading_style, justify="right")

    # Header rows emitted so far, in order; doubles as the "any group yet?" flag.
    opened: list[str] = []

    def open_group(header: str) -> None:
        """Emit the header row for a group once, separated from any earlier group."""
        if header in opened:
            return
        if opened:
            table.add_section()
        table.add_row(header, "", "", "", "")
        opened.append(header)

    # Junctions: gained (aligned) | lost | final (current count), one row per team.
    for team, team_stats, color in (("cogs", cogs_stats, "green"), ("clips", clips_stats, "red")):
        counts = [int(team_stats.get(key, 0)) for key in ("junction.gained", "junction.lost", "junction")]
        if not any(count > 0 for count in counts):
            continue
        open_group("[bold]Junctions[/bold]")
        table.add_row(
            "",
            f"[{color}]{team}[/{color}]",
            *(f"[{color}]{count}[/{color}]" for count in counts),
        )

    # Gear: gained | lost | final, where final is the net (gained - lost).
    for gear in GEAR:
        gear_gained = int(agent_totals.get(f"{gear}.gained", 0))
        gear_lost = int(agent_totals.get(f"{gear}.lost", 0))
        if gear_gained <= 0 and gear_lost <= 0:
            continue
        open_group("[bold]Gear[/bold]")
        table.add_row("", gear, str(gear_gained), str(gear_lost), str(gear_gained - gear_lost))

    # Hearts share the Gear group; their "Final" cell is left blank.
    hearts_gained = int(agent_totals.get("heart.gained", 0))
    hearts_lost = int(agent_totals.get("heart.lost", 0))
    if hearts_gained > 0 or hearts_lost > 0:
        open_group("[bold]Gear[/bold]")
        table.add_row("", "hearts", str(hearts_gained), str(hearts_lost), "")

    # Resources: gained (deposited) | lost (withdrawn) | final (current amount).
    for resource in ELEMENTS:
        deposited = int(cogs_stats.get(f"collective.{resource}.deposited", 0))
        withdrawn = int(cogs_stats.get(f"collective.{resource}.withdrawn", 0))
        amount = int(cogs_stats.get(f"collective.{resource}.amount", 0))
        if deposited > 0 or withdrawn > 0 or amount > 0:
            open_group("[bold]Resources[/bold]")
            table.add_row("", resource, str(deposited), str(withdrawn), str(amount))

    # Total reward at the bottom, set apart only if something precedes it.
    if opened:
        table.add_section()
    table.add_row("[bold]Reward[/bold]", "total", f"{total_reward:.2f}", "", "")

    console.print(table)
134
+
135
+
136
def _print_standard_stats(console: Console, agent_totals: dict[str, float], total_reward: float) -> None:
    """Print a two-column stats table for missions without collective stats.

    Only non-zero stats whose key contains one of the gain/loss/deposit/
    withdraw markers (or "heart") are listed, sorted by key, followed by the
    total reward.

    Args:
        console: Rich console the table is written to.
        agent_totals: Per-agent stats already summed across agents.
        total_reward: Sum of all agent rewards for the episode.
    """
    markers = [".gained", ".lost", ".deposited", ".withdrawn", "heart"]
    shown = {
        name: amount
        for name, amount in agent_totals.items()
        if amount != 0 and any(marker in name for marker in markers)
    }

    table = Table(title="Episode Stats", box=box.ROUNDED, show_header=True, header_style="bold cyan")
    table.add_column("Stat", style="white")
    table.add_column("Value", style="green", justify="right")

    for name in sorted(shown):
        table.add_row(name, f"{int(shown[name])}")

    # Separate the reward row from the stat rows only when there are any.
    if shown:
        table.add_section()
    table.add_row("[bold]Reward (total)[/bold]", f"{total_reward:.2f}")

    console.print(table)
157
+
19
158
 
20
159
  def play(
21
160
  console: Console,
@@ -34,48 +173,42 @@ def play(
34
173
  policy_spec: Policy specification (class path and optional data path)
35
174
  game_name: Human-readable name of the game (used for logging/metadata)
36
175
  seed: Random seed
37
- render_mode: Render mode - "gui", "unicode", or "none"
176
+ render_mode: Render mode - "gui", "vibescope", "unicode", or "none"
38
177
  save_replay: Optional directory path to save replay. Directory will be created if it doesn't exist.
39
178
  Replay will be saved with a unique UUID-based filename.
40
179
  """
41
180
 
42
181
  logger.debug("Starting play session", extra={"game_name": game_name})
43
182
 
44
- policy_env_info = PolicyEnvInterface.from_mg_cfg(env_cfg)
45
- policy = initialize_or_load_policy(policy_env_info, policy_spec)
46
- agent_policies = [policy.agent_policy(agent_id) for agent_id in range(env_cfg.game.num_agents)]
47
-
48
- # Set up replay writer if requested
49
- event_handlers = []
50
- replay_writer = None
183
+ replay_path = None
51
184
  if save_replay:
52
- replay_writer = ReplayLogWriter(str(save_replay))
53
- event_handlers.append(replay_writer)
54
-
55
- # Create simulator and renderer
56
- rollout = Rollout(
57
- env_cfg,
58
- agent_policies,
59
- render_mode=render_mode,
60
- seed=seed,
61
- event_handlers=event_handlers,
62
- )
185
+ save_replay.mkdir(parents=True, exist_ok=True)
186
+ replay_path = save_replay / f"{uuid.uuid4()}.json.z"
187
+
63
188
  try:
64
- rollout.run_until_done()
189
+ results, _replay = run_single_episode(
190
+ policy_specs=[policy_spec],
191
+ assignments=[0] * env_cfg.game.num_agents,
192
+ env=env_cfg,
193
+ results_uri=None,
194
+ replay_uri=str(replay_path) if replay_path else None,
195
+ seed=seed,
196
+ device="cpu",
197
+ render_mode=render_mode,
198
+ )
65
199
  except KeyboardInterrupt:
66
200
  logger.info("Interrupted; ending episode early.")
67
- rollout._sim.end_episode()
68
201
  return
69
202
 
70
203
  # Print summary
71
204
  console.print("\n[bold green]Episode Complete![/bold green]")
72
- console.print(f"Steps: {rollout._sim.current_step}")
73
- console.print(f"Total Rewards: {rollout._sim.episode_rewards}")
74
- console.print(f"Final Reward Sum: {float(sum(rollout._sim.episode_rewards)):.2f}")
205
+ console.print(f"Steps: {results.steps}")
206
+
207
+ # Print episode stats
208
+ _print_episode_stats(console, results)
75
209
 
76
210
  # Print replay command if replay was saved
77
- if replay_writer:
78
- for replay_path in replay_writer.get_written_replay_paths():
79
- console.print("\n[bold cyan]Replay saved![/bold cyan]")
80
- console.print("To watch the replay, run:")
81
- console.print(f"[bold green]cogames replay {replay_path}[/bold green]")
211
+ if replay_path:
212
+ console.print("\n[bold cyan]Replay saved![/bold cyan]")
213
+ console.print("To watch the replay, run:")
214
+ console.print(f"[bold green]cogames replay {replay_path}[/bold green]")
@@ -0,0 +1,54 @@
1
+ from __future__ import annotations
2
+
3
+ import random
4
+ from typing import Optional
5
+
6
+ from mettagrid.policy.policy import AgentPolicy, MultiAgentPolicy
7
+ from mettagrid.policy.policy_env_interface import PolicyEnvInterface
8
+ from mettagrid.simulator import Action, AgentObservation, Simulation
9
+
10
+
11
class _ChaosMonkeyAgent(AgentPolicy):
    """Per-agent policy that returns noop actions until it deliberately raises.

    Once the internal step counter reaches ``fail_step``, every call to
    ``step`` raises ``RuntimeError`` with probability ``fail_probability``.
    After the first failure the agent keeps returning noops without raising.
    """

    def __init__(self, policy_env_info: PolicyEnvInterface, fail_step: int, fail_probability: float):
        super().__init__(policy_env_info)
        self._fail_step = fail_step
        self._fail_probability = fail_probability
        # Mutable per-episode state; cleared again by reset().
        self._step = 0
        self._failed = False
        self._noop = Action(name="noop")

    def reset(self, simulation: Optional[Simulation] = None) -> None:
        """Clear the step counter and the failed flag."""
        self._failed = False
        self._step = 0

    def step(self, obs: AgentObservation) -> Action:
        """Return a noop, or raise once the failure condition triggers.

        Raises:
            RuntimeError: when the step threshold has been reached and the
                random draw falls below the configured probability.
        """
        if not self._failed:
            threshold_reached = self._step >= self._fail_step
            # The random draw only happens once the threshold is reached.
            if threshold_reached and random.random() < self._fail_probability:
                self._failed = True
                raise RuntimeError(f"Chaos monkey triggered at step {self._step}")
            self._step += 1
        return self._noop
34
+
35
+
36
class ChaosMonkeyPolicy(MultiAgentPolicy):
    """A scripted policy that intentionally fails mid-episode to test robustness."""

    short_names = ["chaos-monkey"]

    def __init__(
        self,
        policy_env_info: PolicyEnvInterface,
        fail_step: int | str = 10,
        fail_probability: float | int | str = 1.0,
        device: str = "cpu",
        **_: object,
    ):
        """Store the failure settings shared by every agent.

        Args:
            policy_env_info: Environment interface forwarded to the base class.
            fail_step: Step threshold; coerced with ``int()``.
            fail_probability: Failure probability; coerced with ``float()``.
            device: Forwarded to the base class.
            **_: Extra keyword arguments are accepted and ignored.
        """
        super().__init__(policy_env_info, device=device)
        # Coerce eagerly so string inputs are normalised once
        # (presumably they arrive as strings from CLI/policy args - confirm).
        self._fail_step = int(fail_step)
        self._fail_probability = float(fail_probability)

    def agent_policy(self, agent_id: int) -> AgentPolicy:
        """Build a fresh chaos-monkey agent; ``agent_id`` does not affect it."""
        agent = _ChaosMonkeyAgent(self.policy_env_info, self._fail_step, self._fail_probability)
        return agent
@@ -1,18 +1,35 @@
1
- """Nim-based agent policies for CoGames."""
1
+ """Legacy import shim for Nim-based agents.
2
2
 
3
- from cogames.policy.nim_agents import agents # noqa: F401
3
+ Historically, Nim agents lived at `cogames.policy.nim_agents`. They were moved
4
+ to the optional `cogames-agents` package under `cogames_agents.policy.nim_agents`.
5
+
6
+ Old policy bundles may still import the legacy path; this module preserves
7
+ backwards compatibility by re-exporting the new implementation when available.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
try:
    import cogames.policy.nim_agents.agents as agents  # noqa: F401
    from cogames.policy.nim_agents.agents import (  # noqa: F401
        CogsguardAlignAllAgentsMultiPolicy,
        LadyBugAgentsMultiPolicy,
        RaceCarAgentsMultiPolicy,
        RandomAgentsMultiPolicy,
        ThinkyAgentsMultiPolicy,
    )
except ModuleNotFoundError as exc:
    # Only a missing `cogames_agents` package (or one of its submodules) gets
    # the friendlier install hint; any other missing module propagates as-is.
    if not exc.name or exc.name.partition(".")[0] != "cogames_agents":
        raise
    raise ModuleNotFoundError(
        "Legacy import `cogames.policy.nim_agents` requires optional dependency "
        "`cogames-agents` (install `cogames[agents]`)."
    ) from exc
4
28
 
5
29
  __all__ = [
6
30
  "RandomAgentsMultiPolicy",
7
31
  "ThinkyAgentsMultiPolicy",
8
32
  "RaceCarAgentsMultiPolicy",
9
33
  "LadyBugAgentsMultiPolicy",
34
+ "CogsguardAlignAllAgentsMultiPolicy",
10
35
  ]
11
-
12
- # Re-export the policy classes for convenience
13
- from cogames.policy.nim_agents.agents import ( # noqa: F401
14
- LadyBugAgentsMultiPolicy,
15
- RaceCarAgentsMultiPolicy,
16
- RandomAgentsMultiPolicy,
17
- ThinkyAgentsMultiPolicy,
18
- )
@@ -1,66 +1,127 @@
1
- import importlib
2
- import os
3
- import sys
4
- from typing import Sequence
5
-
6
- from mettagrid.policy.policy import NimMultiAgentPolicy
7
- from mettagrid.policy.policy_env_interface import PolicyEnvInterface
8
-
9
- current_dir = os.path.dirname(os.path.abspath(__file__))
10
- bindings_dir = os.path.join(current_dir, "bindings/generated")
11
- if bindings_dir not in sys.path:
12
- sys.path.append(bindings_dir)
13
-
14
- na = importlib.import_module("nim_agents")
15
-
16
-
17
- def start_measure():
18
- na.start_measure()
19
-
20
-
21
- def end_measure():
22
- na.end_measure()
1
+ """Legacy module path for Nim-based agent policies.
23
2
 
3
+ The implementation moved to `cogames_agents.policy.nim_agents.agents`. Keep this
4
+ wrapper so older checkpoints and policy bundles that reference the old import
5
+ path can still load.
6
+ """
24
7
 
25
- class ThinkyAgentsMultiPolicy(NimMultiAgentPolicy):
26
- short_names = ["thinky"]
8
+ from __future__ import annotations
27
9
 
28
- def __init__(self, policy_env_info: PolicyEnvInterface, agent_ids: Sequence[int] | None = None):
29
- super().__init__(
30
- policy_env_info,
31
- nim_policy_factory=na.ThinkyPolicy,
32
- agent_ids=agent_ids,
33
- )
34
-
35
-
36
- class RandomAgentsMultiPolicy(NimMultiAgentPolicy):
37
- short_names = ["nim_random"]
38
-
39
- def __init__(self, policy_env_info: PolicyEnvInterface, agent_ids: Sequence[int] | None = None):
40
- super().__init__(
41
- policy_env_info,
42
- nim_policy_factory=na.RandomPolicy,
43
- agent_ids=agent_ids,
44
- )
45
-
46
-
47
- class RaceCarAgentsMultiPolicy(NimMultiAgentPolicy):
48
- short_names = ["race_car"]
49
-
50
- def __init__(self, policy_env_info: PolicyEnvInterface, agent_ids: Sequence[int] | None = None):
51
- super().__init__(
52
- policy_env_info,
53
- nim_policy_factory=na.RaceCarPolicy,
54
- agent_ids=agent_ids,
55
- )
10
+ from collections.abc import Callable
56
11
 
12
+ from mettagrid.policy.policy import MultiAgentPolicy
13
+ from mettagrid.policy.policy_env_interface import PolicyEnvInterface
57
14
 
58
- class LadyBugAgentsMultiPolicy(NimMultiAgentPolicy):
59
- short_names = ["nim_ladybug"]
60
15
 
61
- def __init__(self, policy_env_info: PolicyEnvInterface, agent_ids: Sequence[int] | None = None):
62
- super().__init__(
63
- policy_env_info,
64
- nim_policy_factory=na.LadybugPolicy,
65
- agent_ids=agent_ids,
66
- )
16
+ def _raise_missing_nim_agents() -> None:
17
+ raise ModuleNotFoundError(
18
+ "Nim scripted agents are not available (missing `nim_agents` bindings). "
19
+ "If you're developing locally, build them with: "
20
+ "`cd packages/cogames-agents/src/cogames_agents/policy/nim_agents && nim c nim_agents.nim`."
21
+ )
22
+
23
+
24
+ start_measure: Callable[[], None]
25
+ end_measure: Callable[[], None]
26
+ ThinkyAgentsMultiPolicy: type[MultiAgentPolicy]
27
+ RandomAgentsMultiPolicy: type[MultiAgentPolicy]
28
+ RaceCarAgentsMultiPolicy: type[MultiAgentPolicy]
29
+ LadyBugAgentsMultiPolicy: type[MultiAgentPolicy]
30
+ CogsguardAgentsMultiPolicy: type[MultiAgentPolicy]
31
+ CogsguardAlignAllAgentsMultiPolicy: type[MultiAgentPolicy]
32
+
33
# Resolve the public names from the relocated implementation; if only the Nim
# bindings are unavailable, bind them to stubs that raise when used.
try:
    from cogames_agents.policy.nim_agents.agents import (
        CogsguardAgentsMultiPolicy as _CogsguardAgentsMultiPolicy,
        CogsguardAlignAllAgentsMultiPolicy as _CogsguardAlignAllAgentsMultiPolicy,
        LadyBugAgentsMultiPolicy as _LadyBugAgentsMultiPolicy,
        RaceCarAgentsMultiPolicy as _RaceCarAgentsMultiPolicy,
        RandomAgentsMultiPolicy as _RandomAgentsMultiPolicy,
        ThinkyAgentsMultiPolicy as _ThinkyAgentsMultiPolicy,
        end_measure as _end_measure,
        start_measure as _start_measure,
    )
except (ModuleNotFoundError, OSError) as exc:
    # A missing `cogames_agents` package is a hard error with an install hint.
    if (
        isinstance(exc, ModuleNotFoundError)
        and exc.name
        and exc.name.partition(".")[0] == "cogames_agents"
    ):
        raise ModuleNotFoundError(
            "Legacy import `cogames.policy.nim_agents.agents` requires optional dependency "
            "`cogames-agents` (install `cogames[agents]`)."
        ) from exc

    # Any other ModuleNotFoundError/OSError: fall back to stubs so importing
    # this module still succeeds and the error surfaces only on use.

    def _missing_start_measure() -> None:
        _raise_missing_nim_agents()

    def _missing_end_measure() -> None:
        _raise_missing_nim_agents()

    class _MissingThinkyAgentsMultiPolicy(MultiAgentPolicy):
        def __init__(self, policy_env_info: PolicyEnvInterface, **_: object):
            _raise_missing_nim_agents()

    class _MissingRandomAgentsMultiPolicy(MultiAgentPolicy):
        def __init__(self, policy_env_info: PolicyEnvInterface, **_: object):
            _raise_missing_nim_agents()

    class _MissingRaceCarAgentsMultiPolicy(MultiAgentPolicy):
        def __init__(self, policy_env_info: PolicyEnvInterface, **_: object):
            _raise_missing_nim_agents()

    class _MissingLadyBugAgentsMultiPolicy(MultiAgentPolicy):
        def __init__(self, policy_env_info: PolicyEnvInterface, **_: object):
            _raise_missing_nim_agents()

    class _MissingCogsguardAgentsMultiPolicy(MultiAgentPolicy):
        def __init__(self, policy_env_info: PolicyEnvInterface, **_: object):
            _raise_missing_nim_agents()

    class _MissingCogsguardAlignAllAgentsMultiPolicy(MultiAgentPolicy):
        def __init__(self, policy_env_info: PolicyEnvInterface, **_: object):
            _raise_missing_nim_agents()

    # Publish the stubs under the public names.
    start_measure = _missing_start_measure
    end_measure = _missing_end_measure
    ThinkyAgentsMultiPolicy = _MissingThinkyAgentsMultiPolicy
    RandomAgentsMultiPolicy = _MissingRandomAgentsMultiPolicy
    RaceCarAgentsMultiPolicy = _MissingRaceCarAgentsMultiPolicy
    LadyBugAgentsMultiPolicy = _MissingLadyBugAgentsMultiPolicy
    CogsguardAgentsMultiPolicy = _MissingCogsguardAgentsMultiPolicy
    CogsguardAlignAllAgentsMultiPolicy = _MissingCogsguardAlignAllAgentsMultiPolicy
else:
    # Import succeeded: publish the real implementations under the public names.
    start_measure = _start_measure
    end_measure = _end_measure
    ThinkyAgentsMultiPolicy = _ThinkyAgentsMultiPolicy
    RandomAgentsMultiPolicy = _RandomAgentsMultiPolicy
    RaceCarAgentsMultiPolicy = _RaceCarAgentsMultiPolicy
    LadyBugAgentsMultiPolicy = _LadyBugAgentsMultiPolicy
    CogsguardAgentsMultiPolicy = _CogsguardAgentsMultiPolicy
    CogsguardAlignAllAgentsMultiPolicy = _CogsguardAlignAllAgentsMultiPolicy
117
+
118
+ __all__ = [
119
+ "start_measure",
120
+ "end_measure",
121
+ "ThinkyAgentsMultiPolicy",
122
+ "RandomAgentsMultiPolicy",
123
+ "RaceCarAgentsMultiPolicy",
124
+ "LadyBugAgentsMultiPolicy",
125
+ "CogsguardAgentsMultiPolicy",
126
+ "CogsguardAlignAllAgentsMultiPolicy",
127
+ ]