cogames 0.3.49__py3-none-any.whl → 0.3.64__py3-none-any.whl

This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (169)
  1. cogames/cli/client.py +60 -6
  2. cogames/cli/docsync/__init__.py +0 -0
  3. cogames/cli/docsync/_nb_md_directive_processing.py +180 -0
  4. cogames/cli/docsync/_nb_md_sync.py +103 -0
  5. cogames/cli/docsync/_nb_py_sync.py +122 -0
  6. cogames/cli/docsync/_three_way_sync.py +115 -0
  7. cogames/cli/docsync/_utils.py +76 -0
  8. cogames/cli/docsync/docsync.py +156 -0
  9. cogames/cli/leaderboard.py +112 -28
  10. cogames/cli/mission.py +64 -53
  11. cogames/cli/policy.py +46 -10
  12. cogames/cli/submit.py +268 -67
  13. cogames/cogs_vs_clips/cog.py +79 -0
  14. cogames/cogs_vs_clips/cogs_vs_clips_mapgen.md +19 -16
  15. cogames/cogs_vs_clips/cogsguard_reward_variants.py +153 -0
  16. cogames/cogs_vs_clips/cogsguard_tutorial.py +56 -0
  17. cogames/cogs_vs_clips/evals/README.md +10 -16
  18. cogames/cogs_vs_clips/evals/cogsguard_evals.py +81 -0
  19. cogames/cogs_vs_clips/evals/diagnostic_evals.py +49 -444
  20. cogames/cogs_vs_clips/evals/difficulty_variants.py +13 -326
  21. cogames/cogs_vs_clips/evals/integrated_evals.py +5 -45
  22. cogames/cogs_vs_clips/evals/spanning_evals.py +9 -180
  23. cogames/cogs_vs_clips/mission.py +187 -146
  24. cogames/cogs_vs_clips/missions.py +46 -137
  25. cogames/cogs_vs_clips/procedural.py +8 -8
  26. cogames/cogs_vs_clips/sites.py +107 -3
  27. cogames/cogs_vs_clips/stations.py +198 -186
  28. cogames/cogs_vs_clips/tutorial_missions.py +1 -1
  29. cogames/cogs_vs_clips/variants.py +25 -476
  30. cogames/device.py +13 -1
  31. cogames/{policy/scripted_agent/README.md → docs/SCRIPTED_AGENT.md} +82 -58
  32. cogames/evaluate.py +18 -30
  33. cogames/main.py +1434 -243
  34. cogames/maps/canidate1_1000.map +1 -1
  35. cogames/maps/canidate1_1000_stations.map +2 -2
  36. cogames/maps/canidate1_500.map +1 -1
  37. cogames/maps/canidate1_500_stations.map +2 -2
  38. cogames/maps/canidate2_1000.map +1 -1
  39. cogames/maps/canidate2_1000_stations.map +2 -2
  40. cogames/maps/canidate2_500.map +1 -1
  41. cogames/maps/canidate2_500_stations.map +2 -2
  42. cogames/maps/canidate3_1000.map +1 -1
  43. cogames/maps/canidate3_1000_stations.map +2 -2
  44. cogames/maps/canidate3_500.map +1 -1
  45. cogames/maps/canidate3_500_stations.map +2 -2
  46. cogames/maps/canidate4_500.map +1 -1
  47. cogames/maps/canidate4_500_stations.map +2 -2
  48. cogames/maps/cave_base_50.map +2 -2
  49. cogames/maps/diagnostic_evals/diagnostic_agile.map +2 -2
  50. cogames/maps/diagnostic_evals/diagnostic_agile_hard.map +2 -2
  51. cogames/maps/diagnostic_evals/diagnostic_charge_up.map +2 -2
  52. cogames/maps/diagnostic_evals/diagnostic_charge_up_hard.map +2 -2
  53. cogames/maps/diagnostic_evals/diagnostic_chest_navigation1.map +2 -2
  54. cogames/maps/diagnostic_evals/diagnostic_chest_navigation1_hard.map +2 -2
  55. cogames/maps/diagnostic_evals/diagnostic_chest_navigation2.map +2 -2
  56. cogames/maps/diagnostic_evals/diagnostic_chest_navigation2_hard.map +2 -2
  57. cogames/maps/diagnostic_evals/diagnostic_chest_navigation3.map +2 -2
  58. cogames/maps/diagnostic_evals/diagnostic_chest_navigation3_hard.map +2 -2
  59. cogames/maps/diagnostic_evals/diagnostic_chest_near.map +2 -2
  60. cogames/maps/diagnostic_evals/diagnostic_chest_search.map +2 -2
  61. cogames/maps/diagnostic_evals/diagnostic_chest_search_hard.map +2 -2
  62. cogames/maps/diagnostic_evals/diagnostic_extract_lab.map +2 -2
  63. cogames/maps/diagnostic_evals/diagnostic_extract_lab_hard.map +2 -2
  64. cogames/maps/diagnostic_evals/diagnostic_memory.map +2 -2
  65. cogames/maps/diagnostic_evals/diagnostic_memory_hard.map +2 -2
  66. cogames/maps/diagnostic_evals/diagnostic_radial.map +2 -2
  67. cogames/maps/diagnostic_evals/diagnostic_radial_hard.map +2 -2
  68. cogames/maps/diagnostic_evals/diagnostic_resource_lab.map +2 -2
  69. cogames/maps/diagnostic_evals/diagnostic_unclip.map +2 -2
  70. cogames/maps/evals/eval_balanced_spread.map +9 -5
  71. cogames/maps/evals/eval_clip_oxygen.map +9 -5
  72. cogames/maps/evals/eval_collect_resources.map +9 -5
  73. cogames/maps/evals/eval_collect_resources_hard.map +9 -5
  74. cogames/maps/evals/eval_collect_resources_medium.map +9 -5
  75. cogames/maps/evals/eval_divide_and_conquer.map +9 -5
  76. cogames/maps/evals/eval_energy_starved.map +9 -5
  77. cogames/maps/evals/eval_multi_coordinated_collect_hard.map +9 -5
  78. cogames/maps/evals/eval_oxygen_bottleneck.map +9 -5
  79. cogames/maps/evals/eval_single_use_world.map +9 -5
  80. cogames/maps/evals/extractor_hub_100x100.map +9 -5
  81. cogames/maps/evals/extractor_hub_30x30.map +9 -5
  82. cogames/maps/evals/extractor_hub_50x50.map +9 -5
  83. cogames/maps/evals/extractor_hub_70x70.map +9 -5
  84. cogames/maps/evals/extractor_hub_80x80.map +9 -5
  85. cogames/maps/machina_100_stations.map +2 -2
  86. cogames/maps/machina_200_stations.map +2 -2
  87. cogames/maps/machina_200_stations_small.map +2 -2
  88. cogames/maps/machina_eval_exp01.map +2 -2
  89. cogames/maps/machina_eval_template_large.map +2 -2
  90. cogames/maps/machinatrainer4agents.map +2 -2
  91. cogames/maps/machinatrainer4agentsbase.map +2 -2
  92. cogames/maps/machinatrainerbig.map +2 -2
  93. cogames/maps/machinatrainersmall.map +2 -2
  94. cogames/maps/planky_evals/aligner_avoid_aoe.map +28 -0
  95. cogames/maps/planky_evals/aligner_full_cycle.map +28 -0
  96. cogames/maps/planky_evals/aligner_gear.map +24 -0
  97. cogames/maps/planky_evals/aligner_hearts.map +24 -0
  98. cogames/maps/planky_evals/aligner_junction.map +26 -0
  99. cogames/maps/planky_evals/exploration_distant.map +28 -0
  100. cogames/maps/planky_evals/maze.map +32 -0
  101. cogames/maps/planky_evals/miner_best_resource.map +26 -0
  102. cogames/maps/planky_evals/miner_deposit.map +24 -0
  103. cogames/maps/planky_evals/miner_extract.map +26 -0
  104. cogames/maps/planky_evals/miner_full_cycle.map +28 -0
  105. cogames/maps/planky_evals/miner_gear.map +24 -0
  106. cogames/maps/planky_evals/multi_role.map +28 -0
  107. cogames/maps/planky_evals/resource_chain.map +30 -0
  108. cogames/maps/planky_evals/scout_explore.map +32 -0
  109. cogames/maps/planky_evals/scout_gear.map +24 -0
  110. cogames/maps/planky_evals/scrambler_full_cycle.map +28 -0
  111. cogames/maps/planky_evals/scrambler_gear.map +24 -0
  112. cogames/maps/planky_evals/scrambler_target.map +26 -0
  113. cogames/maps/planky_evals/stuck_corridor.map +32 -0
  114. cogames/maps/planky_evals/survive_retreat.map +26 -0
  115. cogames/maps/training_facility_clipped.map +2 -2
  116. cogames/maps/training_facility_open_1.map +2 -2
  117. cogames/maps/training_facility_open_2.map +2 -2
  118. cogames/maps/training_facility_open_3.map +2 -2
  119. cogames/maps/training_facility_tight_4.map +2 -2
  120. cogames/maps/training_facility_tight_5.map +2 -2
  121. cogames/maps/vanilla_large.map +2 -2
  122. cogames/maps/vanilla_small.map +2 -2
  123. cogames/pickup.py +183 -0
  124. cogames/play.py +166 -33
  125. cogames/policy/chaos_monkey.py +54 -0
  126. cogames/policy/nim_agents/__init__.py +27 -10
  127. cogames/policy/nim_agents/agents.py +121 -60
  128. cogames/policy/nim_agents/thinky_eval.py +35 -222
  129. cogames/policy/pufferlib_policy.py +67 -32
  130. cogames/policy/starter_agent.py +184 -0
  131. cogames/policy/trainable_policy_template.py +4 -1
  132. cogames/train.py +51 -13
  133. cogames/verbose.py +2 -2
  134. cogames-0.3.64.dist-info/METADATA +1842 -0
  135. cogames-0.3.64.dist-info/RECORD +159 -0
  136. cogames-0.3.64.dist-info/licenses/LICENSE +21 -0
  137. cogames-0.3.64.dist-info/top_level.txt +2 -0
  138. metta_alo/__init__.py +0 -0
  139. metta_alo/job_specs.py +17 -0
  140. metta_alo/policy.py +16 -0
  141. metta_alo/pure_single_episode_runner.py +75 -0
  142. metta_alo/py.typed +0 -0
  143. metta_alo/rollout.py +322 -0
  144. metta_alo/scoring.py +168 -0
  145. cogames/maps/diagnostic_evals/diagnostic_assembler_near.map +0 -49
  146. cogames/maps/diagnostic_evals/diagnostic_assembler_search.map +0 -49
  147. cogames/maps/diagnostic_evals/diagnostic_assembler_search_hard.map +0 -89
  148. cogames/policy/nim_agents/common.nim +0 -887
  149. cogames/policy/nim_agents/install.sh +0 -1
  150. cogames/policy/nim_agents/ladybug_agent.nim +0 -984
  151. cogames/policy/nim_agents/nim_agents.nim +0 -55
  152. cogames/policy/nim_agents/nim_agents.nims +0 -14
  153. cogames/policy/nim_agents/nimby.lock +0 -3
  154. cogames/policy/nim_agents/racecar_agents.nim +0 -884
  155. cogames/policy/nim_agents/random_agents.nim +0 -68
  156. cogames/policy/nim_agents/test_agents.py +0 -53
  157. cogames/policy/nim_agents/thinky_agents.nim +0 -717
  158. cogames/policy/scripted_agent/baseline_agent.py +0 -1049
  159. cogames/policy/scripted_agent/demo_policy.py +0 -244
  160. cogames/policy/scripted_agent/pathfinding.py +0 -126
  161. cogames/policy/scripted_agent/starter_agent.py +0 -136
  162. cogames/policy/scripted_agent/types.py +0 -235
  163. cogames/policy/scripted_agent/unclipping_agent.py +0 -476
  164. cogames/policy/scripted_agent/utils.py +0 -385
  165. cogames-0.3.49.dist-info/METADATA +0 -406
  166. cogames-0.3.49.dist-info/RECORD +0 -136
  167. cogames-0.3.49.dist-info/top_level.txt +0 -1
  168. {cogames-0.3.49.dist-info → cogames-0.3.64.dist-info}/WHEEL +0 -0
  169. {cogames-0.3.49.dist-info → cogames-0.3.64.dist-info}/entry_points.txt +0 -0
--- a/cogames/policy/nim_agents/thinky_eval.py
+++ b/cogames/policy/nim_agents/thinky_eval.py
@@ -1,233 +1,46 @@
-# much simpler evaluator for thinky agents.
+"""Legacy module path for Thinky eval helpers.
 
-from __future__ import annotations
-
-import time
-from typing import Dict, List, Tuple
-
-import cogames.policy.nim_agents.agents as na
-from cogames.cli.utils import suppress_noisy_logs
-from cogames.cogs_vs_clips.evals.diagnostic_evals import DIAGNOSTIC_EVALS
-from cogames.cogs_vs_clips.mission import Mission, NumCogsVariant
-from mettagrid.policy.loader import initialize_or_load_policy
-from mettagrid.policy.policy import PolicySpec
-from mettagrid.policy.policy_env_interface import PolicyEnvInterface
-from mettagrid.simulator.rollout import Rollout
-
-# Agent to evaluate
-AGENT_PATH = "cogames.policy.nim_agents.agents.ThinkyAgentsMultiPolicy"
-
-# Defaults (keep simple)
-NUM_COGS = 4
-MAX_STEPS = 10000
-SEED = 42
-
-# Add/modify your evals here over time
-EVALS: List[Tuple[str, str, int]] = [
-    # Regular evals
-    (
-        "energy_starved",
-        "buggy",
-        NUM_COGS,
-    ),  # E is very hard, max E is 256, but agents think its 100.
-    ("oxygen_bottleneck", "", NUM_COGS),
-    ("collect_resources_classic", "", NUM_COGS),
-    ("collect_resources_spread", "", NUM_COGS),
-    ("collect_far", "", NUM_COGS),
-    ("divide_and_conquer", "", NUM_COGS),
-    ("go_together", "", NUM_COGS),
-    ("single_use_swarm", "flakey", NUM_COGS),
-    # Diagnostic evals
-    ("diagnostic_chest_navigation1", "", 1),
-    ("diagnostic_chest_navigation2", "", 1),
-    ("diagnostic_chest_navigation3", "", 1),
-    ("diagnostic_chest_deposit_near", "", 1),
-    ("diagnostic_chest_deposit_search", "", 1),
-    ("diagnostic_charge_up", "buggy", 1),  # The cog needs to sacrifice itself to make hart.
-    ("diagnostic_memory", "", 1),
-    ("diagnostic_assemble_seeded_near", "", 1),
-    ("diagnostic_assemble_seeded_search", "", 1),
-    ("diagnostic_extract_missing_carbon", "", 1),
-    ("diagnostic_extract_missing_oxygen", "", 1),
-    ("diagnostic_extract_missing_germanium", "", 1),
-    ("diagnostic_extract_missing_silicon", "", 1),
-    ("diagnostic_unclip_craft", "", 1),
-    ("diagnostic_unclip_preseed", "", 1),
-    ("diagnostic_agile", "", 1),
-    ("diagnostic_radial", "", 1),
-    # Hello World evals
-    ("distant_resources", "buggy", NUM_COGS),  # Not enough time for such distances.
-    ("quadrant_buildings", "buggy", NUM_COGS),  # Not enough charger for such distances.
-    ("vibe_check", "", NUM_COGS),
-    ("oxygen_bottleneck_easy", "", NUM_COGS),
-    ("oxygen_bottleneck_standard", "", NUM_COGS),
-    ("oxygen_bottleneck_hard", "buggy", NUM_COGS),  # Not enough charger for such distances.
-    ("energy_starved_easy", "", NUM_COGS),
-    ("energy_starved_standard", "buggy", NUM_COGS),  # E drain too high.
-    ("energy_starved_hard", "buggy", NUM_COGS),  # E drain too high.
-    ("unclipping_easy", "n/a", NUM_COGS),
-    ("unclipping_standard", "n/a", NUM_COGS),
-    ("unclipping_hard", "n/a", NUM_COGS),
-    ("distant_resources_easy", "", NUM_COGS),
-    ("distant_resources_standard", "flakey", NUM_COGS),  # Not enough time for such distances.
-    ("distant_resources_hard", "buggy", NUM_COGS),  # Not enough time for such distances.
-    ("quadrant_buildings_easy", "", NUM_COGS),
-    ("quadrant_buildings_standard", "buggy", NUM_COGS),  # Not enough charger for such distances.
-    ("quadrant_buildings_hard", "buggy", NUM_COGS),  # Not enough charger for such distances.
-    ("single_use_swarm_easy", "buggy", NUM_COGS),
-    ("single_use_swarm_standard", "buggy", NUM_COGS),  # Not enough time for such distances.
-    ("single_use_swarm_hard", "buggy", NUM_COGS),  # E drain too high.
-    ("vibe_check_easy", "buggy", NUM_COGS),  # No/invalid recipes available.
-    ("vibe_check_standard", "", NUM_COGS),
-    ("vibe_check_hard", "flakey", NUM_COGS),  # Not enough time for such distances.
-    # Hearts evals
-    ("easy_large_hearts", "slow", NUM_COGS),
-    ("easy_medium_hearts", "", NUM_COGS),
-    ("easy_small_hearts", "flakey", NUM_COGS),
-    # Missions from missions.py
-    ("harvest", "", NUM_COGS),
-    ("repair", "", 2),  # repair uses 2 cogs
-    ("hello_world_unclip", "", NUM_COGS),
-]
-
-
-def _load_all_missions() -> Dict[str, Mission]:
-    # Minimal loader: merge all known mission sets
-    from importlib import import_module
-
-    missions: List[Mission] = []
-    for mod_name in (
-        "cogames.cogs_vs_clips.evals.eval_missions",
-        "cogames.cogs_vs_clips.evals.integrated_evals",
-        "cogames.cogs_vs_clips.evals.spanning_evals",
-        "cogames.cogs_vs_clips.missions",
-    ):
-        try:
-            mod = import_module(mod_name)
-            # missions.py uses MISSIONS, others use EVAL_MISSIONS
-            eval_list = getattr(mod, "MISSIONS", getattr(mod, "EVAL_MISSIONS", []))
-            missions.extend(eval_list)
-        except Exception:
-            pass
+The implementation moved to `cogames_agents.policy.nim_agents.thinky_eval`.
+"""
 
-    # Diagnostic evals are a list of classes; instantiate them
-    try:
-        missions.extend([cls() for cls in DIAGNOSTIC_EVALS])  # type: ignore[misc]
-    except Exception:
-        pass
-
-    # Build name -> mission instance map
-    mission_map: Dict[str, Mission] = {}
-    for m in missions:
-        # Items in EVAL_MISSIONS may be classes or instances; normalize to instances
-        try:
-            mission: Mission = m() if isinstance(m, type) else m  # type: ignore[call-arg,assignment]
-        except Exception:
-            continue
-        mission_map[mission.name] = mission
-    return mission_map
-
-
-def _ensure_vibe_supports_gear(env_cfg) -> None:
-    # Keep minimal and silent if anything fails
-    try:
-        assembler = env_cfg.game.objects.get("assembler")
-        uses_gear = False
-        if assembler is not None and hasattr(assembler, "protocols"):
-            for proto in assembler.protocols:
-                if any(v == "gear" for v in getattr(proto, "vibes", [])):
-                    uses_gear = True
-                    break
-        if uses_gear:
-            change_vibe = env_cfg.game.actions.change_vibe
-            has_gear = any(v.name == "gear" for v in change_vibe.vibes)
-            if not has_gear:
-                from mettagrid.config.vibes import VIBE_BY_NAME
-
-                change_vibe.vibes = list(change_vibe.vibes) + [VIBE_BY_NAME["gear"]]
-    except Exception:
-        pass
-
-
-def run_eval(experiment_name: str, tag: str, mission_map: Dict[str, Mission], num_cogs: int, seed: int) -> float:
-    start = time.perf_counter()
-    try:
-        if experiment_name not in mission_map:
-            print(f"{tag:<6} {experiment_name:<40} {'MISSION NOT FOUND':>6}")
-            return 0.0
-
-        base_mission = mission_map[experiment_name]
-        mission = base_mission.with_variants([NumCogsVariant(num_cogs=num_cogs)])
+from __future__ import annotations
 
-        env_cfg = mission.make_env()
-        _ensure_vibe_supports_gear(env_cfg)
-        env_cfg.game.max_steps = MAX_STEPS
+import importlib
+from typing import Any
 
-        # Create policy and rollout
-        pei = PolicyEnvInterface.from_mg_cfg(env_cfg)
-        policy = initialize_or_load_policy(
-            pei,
-            PolicySpec(class_path=AGENT_PATH, data_path=None),
-        )
-        agent_policies = [policy.agent_policy(i) for i in range(num_cogs)]
+try:
+    _thinky_eval = importlib.import_module("cogames_agents.policy.nim_agents.thinky_eval")
+except ModuleNotFoundError as exc:
+    if exc.name and (exc.name == "cogames_agents" or exc.name.startswith("cogames_agents.")):
+        raise ModuleNotFoundError(
+            "Legacy import `cogames.policy.nim_agents.thinky_eval` requires optional dependency "
+            "`cogames-agents` (install `cogames[agents]`)."
+        ) from exc
+    raise
 
-        rollout = Rollout(
-            env_cfg,
-            agent_policies,
-            render_mode="none",
-            seed=seed,
-        )
-        rollout.run_until_done()
+AGENT_PATH = _thinky_eval.AGENT_PATH
+EVALS = _thinky_eval.EVALS
+MAX_STEPS = _thinky_eval.MAX_STEPS
+NUM_COGS = _thinky_eval.NUM_COGS
+SEED = _thinky_eval.SEED
+main = _thinky_eval.main
+run_eval = _thinky_eval.run_eval
 
-        total_reward = float(sum(rollout._sim.episode_rewards))
-        hearts_per_agent = total_reward / max(1, num_cogs)
-        elapsed = time.perf_counter() - start
 
-        # One simple line per eval
-        hpa = f"{hearts_per_agent:.2f}"
-        tm = f"{elapsed:.2f}"
-        print(f"{tag:<6} {experiment_name:<40} {hpa:>6}h {tm:>6}s")
-        return hearts_per_agent
-    except Exception as e:
-        elapsed = time.perf_counter() - start
-        error_message = str(e)
-        print(f"{tag:<6} {experiment_name:<40} {error_message}")
-        return 0.0
+def __getattr__(name: str) -> Any:
+    return getattr(_thinky_eval, name)
 
 
-def main() -> None:
-    suppress_noisy_logs()
-    na.start_measure()
-    mission_map = _load_all_missions()
-    print(f"Loaded {len(mission_map)} missions")
-    print("tag .. map name ............................... harts/A .. time")
-    start = time.perf_counter()
-    total_hpa = 0.0
-    successful_evals = 0
-    num_evals = 0
-    for experiment_name, tag, num_cogs in EVALS:
-        num_evals += 1
-        if tag == "flakey":
-            for i in range(10):
-                hpa = run_eval(experiment_name, tag, mission_map, num_cogs, SEED + i)
-                if hpa > 0:
-                    successful_evals += 1
-                    total_hpa += hpa
-                    break
-        else:
-            hpa = run_eval(experiment_name, tag, mission_map, num_cogs, SEED)
-            if hpa > 0:
-                successful_evals += 1
-                total_hpa += hpa
-    success_rate = successful_evals / num_evals
-    elapsed = time.perf_counter() - start
-    total_evals = f"{num_evals} evals {success_rate * 100:.1f}% successful"
-    hpa = f"{total_hpa:.2f}"
-    tm = f"{elapsed:.2f}"
-    tag = "total"
-    print(f"{tag:<6} {total_evals:<40} {hpa:>6}h {tm:>6}s")
-    na.end_measure()
+def __dir__() -> list[str]:
+    return sorted(set(globals()).union(dir(_thinky_eval)))
 
 
-if __name__ == "__main__":
-    main()
+__all__ = [
+    "AGENT_PATH",
+    "EVALS",
+    "MAX_STEPS",
+    "NUM_COGS",
+    "SEED",
+    "main",
+    "run_eval",
+]
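The rewritten `thinky_eval` module is a forwarding shim: the evaluator itself now lives in `cogames_agents.policy.nim_agents.thinky_eval`, and the legacy path re-exports its public names, raising a descriptive error when the optional dependency is missing. A minimal usage sketch, assuming the `cogames[agents]` extra is installed:

# The legacy import path keeps working; the shim forwards everything to
# cogames_agents.policy.nim_agents.thinky_eval.
from cogames.policy.nim_agents import thinky_eval

print(thinky_eval.AGENT_PATH)  # re-exported constant
thinky_eval.main()             # runs the forwarded evaluation loop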
--- a/cogames/policy/pufferlib_policy.py
+++ b/cogames/policy/pufferlib_policy.py
@@ -15,8 +15,9 @@ import torch
 
 import pufferlib.models  # type: ignore[import-untyped]
 import pufferlib.pytorch  # type: ignore[import-untyped]
-from mettagrid.policy.policy import AgentPolicy, MultiAgentPolicy
+from mettagrid.policy.policy import AgentPolicy, MultiAgentPolicy, StatefulAgentPolicy
 from mettagrid.policy.policy_env_interface import PolicyEnvInterface
+from mettagrid.policy.pufferlib import PufferlibStatefulImpl
 from mettagrid.simulator import Action, AgentObservation, Simulation
 
 
@@ -24,7 +25,7 @@ class PufferlibCogsPolicy(MultiAgentPolicy, AgentPolicy):
     """Loads and runs checkpoints trained with PufferLib's CoGames policy.
 
     This policy serves as both the MultiAgentPolicy factory and AgentPolicy
-    implementation, returning itself from agent_policy().
+    implementation, returning per-agent wrappers that track state.
     """
 
     short_names = ["pufferlib_cogs"]
@@ -38,64 +39,98 @@ class PufferlibCogsPolicy(MultiAgentPolicy, AgentPolicy):
     ):
         MultiAgentPolicy.__init__(self, policy_env_info, device=device)
         AgentPolicy.__init__(self, policy_env_info)
-        shim_env = SimpleNamespace(
+        self._hidden_size = hidden_size
+        self._device = torch.device(device)
+        self._shim_env = SimpleNamespace(
             single_observation_space=policy_env_info.observation_space,
             single_action_space=policy_env_info.action_space,
             observation_space=policy_env_info.observation_space,
             action_space=policy_env_info.action_space,
             num_agents=policy_env_info.num_agents,
         )
-        shim_env.env = shim_env
-        self._net = pufferlib.models.Default(shim_env, hidden_size=hidden_size)  # type: ignore[arg-type]
-        self._net = self._net.to(torch.device(device))
+        self._shim_env.env = self._shim_env
+        self._net = pufferlib.models.Default(self._shim_env, hidden_size=hidden_size).to(self._device)  # type: ignore[arg-type]
         self._action_names = policy_env_info.action_names
-        self._num_tokens, self._token_dim = policy_env_info.observation_space.shape
-        self._device = next(self._net.parameters()).device
+        self._is_recurrent = False
+        self._stateful_impl = PufferlibStatefulImpl(
+            self._net,
+            policy_env_info,
+            self._device,
+            is_recurrent=self._is_recurrent,
+        )
+        self._agent_policies: dict[int, StatefulAgentPolicy[dict[str, torch.Tensor | None]]] = {}
+        self._state_initialized = False
+        self._state: dict[str, torch.Tensor | None] = {}
 
     def network(self) -> torch.nn.Module:  # type: ignore[override]
         return self._net
 
     def agent_policy(self, agent_id: int) -> AgentPolicy:  # type: ignore[override]
-        return self
+        if agent_id not in self._agent_policies:
+            self._agent_policies[agent_id] = StatefulAgentPolicy(
+                self._stateful_impl,
+                self._policy_env_info,
+                agent_id=agent_id,
+            )
+        return self._agent_policies[agent_id]
 
     def is_recurrent(self) -> bool:
-        return False
+        return self._is_recurrent
 
     def reset(self, simulation: Optional[Simulation] = None) -> None:  # type: ignore[override]
-        # No internal state to reset; signature satisfies AgentPolicy and MultiAgentPolicy
-        return None
+        for policy in self._agent_policies.values():
+            policy.reset(simulation)
+        self._reset_state()
 
     def load_policy_data(self, policy_data_path: str) -> None:
-        state = torch.load(policy_data_path, map_location=next(self._net.parameters()).device)
-        self._net.load_state_dict(state)
-        self._net = self._net.to(next(self._net.parameters()).device)
+        state = torch.load(policy_data_path, map_location=self._device)
+        state = {k.replace("module.", ""): v for k, v in state.items()}
+        uses_rnn = any(key.startswith(("lstm.", "cell.")) for key in state)
+        base_net = pufferlib.models.Default(self._shim_env, hidden_size=self._hidden_size)  # type: ignore[arg-type]
+        net = (
+            pufferlib.models.LSTMWrapper(
+                self._shim_env,
+                base_net,
+                input_size=base_net.hidden_size,
+                hidden_size=base_net.hidden_size,
+            )
+            if uses_rnn
+            else base_net
+        )
+        net.load_state_dict(state)
+        self._net = net.to(self._device)
+        self._is_recurrent = uses_rnn
+        self._stateful_impl = PufferlibStatefulImpl(
+            self._net,
+            self._policy_env_info,
+            self._device,
+            is_recurrent=self._is_recurrent,
+        )
+        self._agent_policies.clear()
+        self._state_initialized = False
+        self._state = {}
 
     def save_policy_data(self, policy_data_path: str) -> None:
         torch.save(self._net.state_dict(), policy_data_path)
 
     def step(self, obs: Union[AgentObservation, torch.Tensor, Sequence[Any]]) -> Action:  # type: ignore[override]
         if isinstance(obs, AgentObservation):
-            obs_tensor = torch.full(
-                (self._num_tokens, self._token_dim),
-                fill_value=255.0,
-                device=self._device,
-                dtype=torch.float32,
-            )
-            for idx, token in enumerate(obs.tokens):
-                if idx >= self._num_tokens:
-                    break
-                raw = torch.as_tensor(token.raw_token, device=self._device, dtype=obs_tensor.dtype)
-                obs_tensor[idx, : raw.numel()] = raw
-        else:
-            obs_tensor = torch.as_tensor(obs, device=self._device, dtype=torch.float32)
-
-        obs_tensor = obs_tensor * (1.0 / 255.0)
+            if not self._state_initialized:
+                self._reset_state()
+            with torch.no_grad():
+                action, self._state = self._stateful_impl.step_with_state(obs, self._state)
+            return action
+        obs_tensor = torch.as_tensor(obs, device=self._device, dtype=torch.float32)
         if obs_tensor.ndim == 2:
             obs_tensor = obs_tensor.unsqueeze(0)
-
         with torch.no_grad():
             self._net.eval()
-            logits, _ = self._net.forward_eval(obs_tensor)
+            logits, _ = self._net.forward_eval(obs_tensor, None)
             sampled, _, _ = pufferlib.pytorch.sample_logits(logits)
             action_idx = max(0, min(int(sampled.item()), len(self._action_names) - 1))
             return Action(name=self._action_names[action_idx])
+
+    def _reset_state(self) -> None:
+        self._stateful_impl.reset()
+        self._state = self._stateful_impl.initial_agent_state()
+        self._state_initialized = True
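`PufferlibCogsPolicy.load_policy_data` now infers the architecture from the checkpoint itself: state-dict keys prefixed with `lstm.` or `cell.` cause the base network to be wrapped in `pufferlib.models.LSTMWrapper`, and `agent_policy()` hands out per-agent `StatefulAgentPolicy` wrappers that carry the recurrent state. A rough sketch of the intended call pattern, assuming a MettaGrid env config and a checkpoint path are available (the helper name below is hypothetical, and the constructor arguments are only partially visible in this diff):

from cogames.policy.pufferlib_policy import PufferlibCogsPolicy
from mettagrid.policy.policy_env_interface import PolicyEnvInterface


def load_cogs_agents(env_cfg, checkpoint_path: str):
    # Build the policy/env interface, restore the checkpoint (LSTM checkpoints
    # are detected from their key prefixes), and hand out per-agent wrappers.
    pei = PolicyEnvInterface.from_mg_cfg(env_cfg)
    policy = PufferlibCogsPolicy(pei, device="cpu")
    policy.load_policy_data(checkpoint_path)
    return [policy.agent_policy(i) for i in range(pei.num_agents)]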
--- /dev/null
+++ b/cogames/policy/starter_agent.py
@@ -0,0 +1,184 @@
+"""
+Sample policy for the CoGames CogsGuard environment.
+
+This starter policy uses simple heuristics:
+- If the agent has no gear, head toward the nearest gear station.
+- If the agent has aligner or scrambler gear, try to get hearts (and influence for aligner) then head to junctions.
+- If the agent has miner gear, head to extractors.
+- If the agent has scout gear, explore in a simple pattern.
+
+Note to users of this policy:
+We don't intend for scripted policies to be the final word on how policies are generated (e.g., we expect the
+environment to be complicated enough that trained agents will be necessary). So we expect that scripting policies
+is a good way to start, but don't want you to get stuck here. Feel free to prove us wrong!
+
+Note to cogames developers:
+This policy should be kept relatively minimalist, without dependencies on intricate algorithms.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Iterable, Optional
+
+from mettagrid.policy.policy import MultiAgentPolicy, StatefulAgentPolicy, StatefulPolicyImpl
+from mettagrid.policy.policy_env_interface import PolicyEnvInterface
+from mettagrid.simulator import Action
+from mettagrid.simulator.interface import AgentObservation
+
+GEAR = ("aligner", "scrambler", "miner", "scout")
+ELEMENTS = ("carbon", "oxygen", "germanium", "silicon")
+WANDER_DIRECTIONS = ("east", "south", "west", "north")
+WANDER_STEPS = 8
+
+
+@dataclass
+class StarterCogState:
+    wander_direction_index: int = 0
+    wander_steps_remaining: int = WANDER_STEPS
+
+
+class StarterCogPolicyImpl(StatefulPolicyImpl[StarterCogState]):
+    def __init__(
+        self,
+        policy_env_info: PolicyEnvInterface,
+        agent_id: int,
+    ):
+        self._agent_id = agent_id
+        self._policy_env_info = policy_env_info
+
+        self._action_names = policy_env_info.action_names
+        self._action_name_set = set(self._action_names)
+        self._fallback_action_name = "noop" if "noop" in self._action_name_set else self._action_names[0]
+        self._center = (policy_env_info.obs_height // 2, policy_env_info.obs_width // 2)
+        self._tag_name_to_id = {name: idx for idx, name in enumerate(policy_env_info.tags)}
+        self._gear_station_tags = self._resolve_tag_ids([f"{gear}_station" for gear in GEAR])
+        self._extractor_tags = self._resolve_tag_ids([f"{element}_extractor" for element in ELEMENTS])
+        self._junction_tags = self._resolve_tag_ids(["junction"])
+        self._chest_tags = self._resolve_tag_ids(["chest"])
+        self._hub_tags = self._resolve_tag_ids(["hub"])
+
+    def _resolve_tag_ids(self, names: Iterable[str]) -> set[int]:
+        tag_ids: set[int] = set()
+        for name in names:
+            if name in self._tag_name_to_id:
+                tag_ids.add(self._tag_name_to_id[name])
+            if name.startswith("type:"):
+                continue
+            type_name = f"type:{name}"
+            if type_name in self._tag_name_to_id:
+                tag_ids.add(self._tag_name_to_id[type_name])
+        return tag_ids
+
+    def _inventory_items(self, obs: AgentObservation) -> set[str]:
+        items: set[str] = set()
+        for token in obs.tokens:
+            if token.location != self._center:
+                continue
+            name = token.feature.name
+            if not name.startswith("inv:"):
+                continue
+            parts = name.split(":", 2)
+            if len(parts) >= 2:
+                items.add(parts[1])
+        return items
+
+    def _closest_tag_location(self, obs: AgentObservation, tag_ids: set[int]) -> Optional[tuple[int, int]]:
+        if not tag_ids:
+            return None
+        best_location: Optional[tuple[int, int]] = None
+        best_distance = 999
+        for token in obs.tokens:
+            if token.feature.name != "tag":
+                continue
+            if token.value not in tag_ids:
+                continue
+            distance = abs(token.location[0] - self._center[0]) + abs(token.location[1] - self._center[1])
+            if distance < best_distance:
+                best_distance = distance
+                best_location = token.location
+        return best_location
+
+    def _action(self, name: str) -> Action:
+        if name in self._action_name_set:
+            return Action(name=name)
+        return Action(name=self._fallback_action_name)
+
+    def _wander(self, state: StarterCogState) -> tuple[Action, StarterCogState]:
+        if state.wander_steps_remaining <= 0:
+            state.wander_direction_index = (state.wander_direction_index + 1) % len(WANDER_DIRECTIONS)
+            state.wander_steps_remaining = WANDER_STEPS
+        direction = WANDER_DIRECTIONS[state.wander_direction_index]
+        state.wander_steps_remaining -= 1
+        return self._action(f"move_{direction}"), state
+
+    def _move_toward(self, state: StarterCogState, target: Optional[tuple[int, int]]) -> tuple[Action, StarterCogState]:
+        if target is None:
+            return self._wander(state)
+        delta_row = target[0] - self._center[0]
+        delta_col = target[1] - self._center[1]
+        if delta_row == 0 and delta_col == 0:
+            return self._action(self._fallback_action_name), state
+        if abs(delta_row) >= abs(delta_col):
+            direction = "south" if delta_row > 0 else "north"
+        else:
+            direction = "east" if delta_col > 0 else "west"
+        return self._action(f"move_{direction}"), state
+
+    def _current_gear(self, items: set[str]) -> Optional[str]:
+        for gear in GEAR:
+            if gear in items:
+                return gear
+        return None
+
+    def step_with_state(self, obs: AgentObservation, state: StarterCogState) -> tuple[Action, StarterCogState]:
+        """Compute the action for this Cog."""
+        items = self._inventory_items(obs)
+        gear = self._current_gear(items)
+        has_heart = "heart" in items
+        has_influence = "influence" in items
+
+        if gear is None:
+            target_tags = self._gear_station_tags
+        elif gear == "aligner":
+            if has_heart and has_influence:
+                target_tags = self._junction_tags
+            elif not has_heart:
+                target_tags = self._chest_tags
+            else:
+                target_tags = self._hub_tags
+        elif gear == "scrambler":
+            target_tags = self._junction_tags if has_heart else self._chest_tags
+        elif gear == "miner":
+            target_tags = self._extractor_tags
+        else:
+            target_tags = set()
+
+        target_location = self._closest_tag_location(obs, target_tags) if target_tags else None
+        return self._move_toward(state, target_location)
+
+    def initial_agent_state(self) -> StarterCogState:
+        """Get the initial state for a new agent."""
+        return StarterCogState()
+
+
+# ============================================================================
+# Policy Wrapper Classes
+# ============================================================================
+
+
+class StarterPolicy(MultiAgentPolicy):
+    short_names = ["starter"]
+
+    def __init__(self, policy_env_info: PolicyEnvInterface, device: str = "cpu"):
+        super().__init__(policy_env_info, device=device)
+        self._agent_policies: dict[int, StatefulAgentPolicy[StarterCogState]] = {}
+
+    def agent_policy(self, agent_id: int) -> StatefulAgentPolicy[StarterCogState]:
+        if agent_id not in self._agent_policies:
+            self._agent_policies[agent_id] = StatefulAgentPolicy(
+                StarterCogPolicyImpl(self._policy_env_info, agent_id),
+                self._policy_env_info,
+                agent_id=agent_id,
+            )
+        return self._agent_policies[agent_id]
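The new `starter_agent` module follows the same pattern as the PufferLib policy above: `StarterPolicy` builds one stateful heuristic controller per Cog. A small wiring sketch, assuming a `PolicyEnvInterface` is already available (the helper name is hypothetical):

from cogames.policy.starter_agent import StarterPolicy


def make_starter_agents(policy_env_info):
    # One heuristic StatefulAgentPolicy per agent id.
    policy = StarterPolicy(policy_env_info, device="cpu")
    return [policy.agent_policy(i) for i in range(policy_env_info.num_agents)]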
--- a/cogames/policy/trainable_policy_template.py
+++ b/cogames/policy/trainable_policy_template.py
@@ -1,6 +1,9 @@
 """
 Trainable Policy Template for the CoGames environment.
 
+This template is compatible with CogsGuard missions. It uses only the observation and action
+spaces provided by the environment and makes no game-specific assumptions.
+
 This template provides a minimal trainable neural network policy that can be used with
 `cogames tutorial train`. It demonstrates the key interfaces required for training:
 
@@ -14,7 +17,7 @@ clarity and without the pufferlib dependency.
 
 To use this template:
 1. Modify MyNetwork to implement your desired architecture
-2. Run: cogames tutorial train -m training_facility.harvest -p class=my_trainable_policy.MyTrainablePolicy
+2. Run: cogames tutorial train -m cogsguard_machina_1.basic -p class=my_trainable_policy.MyTrainablePolicy
 """
 
 from __future__ import annotations