cogames 0.3.49__py3-none-any.whl → 0.3.64__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cogames/cli/client.py +60 -6
- cogames/cli/docsync/__init__.py +0 -0
- cogames/cli/docsync/_nb_md_directive_processing.py +180 -0
- cogames/cli/docsync/_nb_md_sync.py +103 -0
- cogames/cli/docsync/_nb_py_sync.py +122 -0
- cogames/cli/docsync/_three_way_sync.py +115 -0
- cogames/cli/docsync/_utils.py +76 -0
- cogames/cli/docsync/docsync.py +156 -0
- cogames/cli/leaderboard.py +112 -28
- cogames/cli/mission.py +64 -53
- cogames/cli/policy.py +46 -10
- cogames/cli/submit.py +268 -67
- cogames/cogs_vs_clips/cog.py +79 -0
- cogames/cogs_vs_clips/cogs_vs_clips_mapgen.md +19 -16
- cogames/cogs_vs_clips/cogsguard_reward_variants.py +153 -0
- cogames/cogs_vs_clips/cogsguard_tutorial.py +56 -0
- cogames/cogs_vs_clips/evals/README.md +10 -16
- cogames/cogs_vs_clips/evals/cogsguard_evals.py +81 -0
- cogames/cogs_vs_clips/evals/diagnostic_evals.py +49 -444
- cogames/cogs_vs_clips/evals/difficulty_variants.py +13 -326
- cogames/cogs_vs_clips/evals/integrated_evals.py +5 -45
- cogames/cogs_vs_clips/evals/spanning_evals.py +9 -180
- cogames/cogs_vs_clips/mission.py +187 -146
- cogames/cogs_vs_clips/missions.py +46 -137
- cogames/cogs_vs_clips/procedural.py +8 -8
- cogames/cogs_vs_clips/sites.py +107 -3
- cogames/cogs_vs_clips/stations.py +198 -186
- cogames/cogs_vs_clips/tutorial_missions.py +1 -1
- cogames/cogs_vs_clips/variants.py +25 -476
- cogames/device.py +13 -1
- cogames/{policy/scripted_agent/README.md → docs/SCRIPTED_AGENT.md} +82 -58
- cogames/evaluate.py +18 -30
- cogames/main.py +1434 -243
- cogames/maps/canidate1_1000.map +1 -1
- cogames/maps/canidate1_1000_stations.map +2 -2
- cogames/maps/canidate1_500.map +1 -1
- cogames/maps/canidate1_500_stations.map +2 -2
- cogames/maps/canidate2_1000.map +1 -1
- cogames/maps/canidate2_1000_stations.map +2 -2
- cogames/maps/canidate2_500.map +1 -1
- cogames/maps/canidate2_500_stations.map +2 -2
- cogames/maps/canidate3_1000.map +1 -1
- cogames/maps/canidate3_1000_stations.map +2 -2
- cogames/maps/canidate3_500.map +1 -1
- cogames/maps/canidate3_500_stations.map +2 -2
- cogames/maps/canidate4_500.map +1 -1
- cogames/maps/canidate4_500_stations.map +2 -2
- cogames/maps/cave_base_50.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_agile.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_agile_hard.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_charge_up.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_charge_up_hard.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_chest_navigation1.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_chest_navigation1_hard.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_chest_navigation2.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_chest_navigation2_hard.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_chest_navigation3.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_chest_navigation3_hard.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_chest_near.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_chest_search.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_chest_search_hard.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_extract_lab.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_extract_lab_hard.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_memory.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_memory_hard.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_radial.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_radial_hard.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_resource_lab.map +2 -2
- cogames/maps/diagnostic_evals/diagnostic_unclip.map +2 -2
- cogames/maps/evals/eval_balanced_spread.map +9 -5
- cogames/maps/evals/eval_clip_oxygen.map +9 -5
- cogames/maps/evals/eval_collect_resources.map +9 -5
- cogames/maps/evals/eval_collect_resources_hard.map +9 -5
- cogames/maps/evals/eval_collect_resources_medium.map +9 -5
- cogames/maps/evals/eval_divide_and_conquer.map +9 -5
- cogames/maps/evals/eval_energy_starved.map +9 -5
- cogames/maps/evals/eval_multi_coordinated_collect_hard.map +9 -5
- cogames/maps/evals/eval_oxygen_bottleneck.map +9 -5
- cogames/maps/evals/eval_single_use_world.map +9 -5
- cogames/maps/evals/extractor_hub_100x100.map +9 -5
- cogames/maps/evals/extractor_hub_30x30.map +9 -5
- cogames/maps/evals/extractor_hub_50x50.map +9 -5
- cogames/maps/evals/extractor_hub_70x70.map +9 -5
- cogames/maps/evals/extractor_hub_80x80.map +9 -5
- cogames/maps/machina_100_stations.map +2 -2
- cogames/maps/machina_200_stations.map +2 -2
- cogames/maps/machina_200_stations_small.map +2 -2
- cogames/maps/machina_eval_exp01.map +2 -2
- cogames/maps/machina_eval_template_large.map +2 -2
- cogames/maps/machinatrainer4agents.map +2 -2
- cogames/maps/machinatrainer4agentsbase.map +2 -2
- cogames/maps/machinatrainerbig.map +2 -2
- cogames/maps/machinatrainersmall.map +2 -2
- cogames/maps/planky_evals/aligner_avoid_aoe.map +28 -0
- cogames/maps/planky_evals/aligner_full_cycle.map +28 -0
- cogames/maps/planky_evals/aligner_gear.map +24 -0
- cogames/maps/planky_evals/aligner_hearts.map +24 -0
- cogames/maps/planky_evals/aligner_junction.map +26 -0
- cogames/maps/planky_evals/exploration_distant.map +28 -0
- cogames/maps/planky_evals/maze.map +32 -0
- cogames/maps/planky_evals/miner_best_resource.map +26 -0
- cogames/maps/planky_evals/miner_deposit.map +24 -0
- cogames/maps/planky_evals/miner_extract.map +26 -0
- cogames/maps/planky_evals/miner_full_cycle.map +28 -0
- cogames/maps/planky_evals/miner_gear.map +24 -0
- cogames/maps/planky_evals/multi_role.map +28 -0
- cogames/maps/planky_evals/resource_chain.map +30 -0
- cogames/maps/planky_evals/scout_explore.map +32 -0
- cogames/maps/planky_evals/scout_gear.map +24 -0
- cogames/maps/planky_evals/scrambler_full_cycle.map +28 -0
- cogames/maps/planky_evals/scrambler_gear.map +24 -0
- cogames/maps/planky_evals/scrambler_target.map +26 -0
- cogames/maps/planky_evals/stuck_corridor.map +32 -0
- cogames/maps/planky_evals/survive_retreat.map +26 -0
- cogames/maps/training_facility_clipped.map +2 -2
- cogames/maps/training_facility_open_1.map +2 -2
- cogames/maps/training_facility_open_2.map +2 -2
- cogames/maps/training_facility_open_3.map +2 -2
- cogames/maps/training_facility_tight_4.map +2 -2
- cogames/maps/training_facility_tight_5.map +2 -2
- cogames/maps/vanilla_large.map +2 -2
- cogames/maps/vanilla_small.map +2 -2
- cogames/pickup.py +183 -0
- cogames/play.py +166 -33
- cogames/policy/chaos_monkey.py +54 -0
- cogames/policy/nim_agents/__init__.py +27 -10
- cogames/policy/nim_agents/agents.py +121 -60
- cogames/policy/nim_agents/thinky_eval.py +35 -222
- cogames/policy/pufferlib_policy.py +67 -32
- cogames/policy/starter_agent.py +184 -0
- cogames/policy/trainable_policy_template.py +4 -1
- cogames/train.py +51 -13
- cogames/verbose.py +2 -2
- cogames-0.3.64.dist-info/METADATA +1842 -0
- cogames-0.3.64.dist-info/RECORD +159 -0
- cogames-0.3.64.dist-info/licenses/LICENSE +21 -0
- cogames-0.3.64.dist-info/top_level.txt +2 -0
- metta_alo/__init__.py +0 -0
- metta_alo/job_specs.py +17 -0
- metta_alo/policy.py +16 -0
- metta_alo/pure_single_episode_runner.py +75 -0
- metta_alo/py.typed +0 -0
- metta_alo/rollout.py +322 -0
- metta_alo/scoring.py +168 -0
- cogames/maps/diagnostic_evals/diagnostic_assembler_near.map +0 -49
- cogames/maps/diagnostic_evals/diagnostic_assembler_search.map +0 -49
- cogames/maps/diagnostic_evals/diagnostic_assembler_search_hard.map +0 -89
- cogames/policy/nim_agents/common.nim +0 -887
- cogames/policy/nim_agents/install.sh +0 -1
- cogames/policy/nim_agents/ladybug_agent.nim +0 -984
- cogames/policy/nim_agents/nim_agents.nim +0 -55
- cogames/policy/nim_agents/nim_agents.nims +0 -14
- cogames/policy/nim_agents/nimby.lock +0 -3
- cogames/policy/nim_agents/racecar_agents.nim +0 -884
- cogames/policy/nim_agents/random_agents.nim +0 -68
- cogames/policy/nim_agents/test_agents.py +0 -53
- cogames/policy/nim_agents/thinky_agents.nim +0 -717
- cogames/policy/scripted_agent/baseline_agent.py +0 -1049
- cogames/policy/scripted_agent/demo_policy.py +0 -244
- cogames/policy/scripted_agent/pathfinding.py +0 -126
- cogames/policy/scripted_agent/starter_agent.py +0 -136
- cogames/policy/scripted_agent/types.py +0 -235
- cogames/policy/scripted_agent/unclipping_agent.py +0 -476
- cogames/policy/scripted_agent/utils.py +0 -385
- cogames-0.3.49.dist-info/METADATA +0 -406
- cogames-0.3.49.dist-info/RECORD +0 -136
- cogames-0.3.49.dist-info/top_level.txt +0 -1
- {cogames-0.3.49.dist-info → cogames-0.3.64.dist-info}/WHEEL +0 -0
- {cogames-0.3.49.dist-info → cogames-0.3.64.dist-info}/entry_points.txt +0 -0
|
@@ -1,233 +1,46 @@
|
|
|
1
|
-
|
|
1
|
+
"""Legacy module path for Thinky eval helpers.
|
|
2
2
|
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
import time
|
|
6
|
-
from typing import Dict, List, Tuple
|
|
7
|
-
|
|
8
|
-
import cogames.policy.nim_agents.agents as na
|
|
9
|
-
from cogames.cli.utils import suppress_noisy_logs
|
|
10
|
-
from cogames.cogs_vs_clips.evals.diagnostic_evals import DIAGNOSTIC_EVALS
|
|
11
|
-
from cogames.cogs_vs_clips.mission import Mission, NumCogsVariant
|
|
12
|
-
from mettagrid.policy.loader import initialize_or_load_policy
|
|
13
|
-
from mettagrid.policy.policy import PolicySpec
|
|
14
|
-
from mettagrid.policy.policy_env_interface import PolicyEnvInterface
|
|
15
|
-
from mettagrid.simulator.rollout import Rollout
|
|
16
|
-
|
|
17
|
-
# Agent to evaluate
|
|
18
|
-
AGENT_PATH = "cogames.policy.nim_agents.agents.ThinkyAgentsMultiPolicy"
|
|
19
|
-
|
|
20
|
-
# Defaults (keep simple)
|
|
21
|
-
NUM_COGS = 4
|
|
22
|
-
MAX_STEPS = 10000
|
|
23
|
-
SEED = 42
|
|
24
|
-
|
|
25
|
-
# Add/modify your evals here over time
|
|
26
|
-
EVALS: List[Tuple[str, str, int]] = [
|
|
27
|
-
# Regular evals
|
|
28
|
-
(
|
|
29
|
-
"energy_starved",
|
|
30
|
-
"buggy",
|
|
31
|
-
NUM_COGS,
|
|
32
|
-
), # E is very hard, max E is 256, but agents think its 100.
|
|
33
|
-
("oxygen_bottleneck", "", NUM_COGS),
|
|
34
|
-
("collect_resources_classic", "", NUM_COGS),
|
|
35
|
-
("collect_resources_spread", "", NUM_COGS),
|
|
36
|
-
("collect_far", "", NUM_COGS),
|
|
37
|
-
("divide_and_conquer", "", NUM_COGS),
|
|
38
|
-
("go_together", "", NUM_COGS),
|
|
39
|
-
("single_use_swarm", "flakey", NUM_COGS),
|
|
40
|
-
# Diagnostic evals
|
|
41
|
-
("diagnostic_chest_navigation1", "", 1),
|
|
42
|
-
("diagnostic_chest_navigation2", "", 1),
|
|
43
|
-
("diagnostic_chest_navigation3", "", 1),
|
|
44
|
-
("diagnostic_chest_deposit_near", "", 1),
|
|
45
|
-
("diagnostic_chest_deposit_search", "", 1),
|
|
46
|
-
("diagnostic_charge_up", "buggy", 1), # The cog needs to sacrifice itself to make hart.
|
|
47
|
-
("diagnostic_memory", "", 1),
|
|
48
|
-
("diagnostic_assemble_seeded_near", "", 1),
|
|
49
|
-
("diagnostic_assemble_seeded_search", "", 1),
|
|
50
|
-
("diagnostic_extract_missing_carbon", "", 1),
|
|
51
|
-
("diagnostic_extract_missing_oxygen", "", 1),
|
|
52
|
-
("diagnostic_extract_missing_germanium", "", 1),
|
|
53
|
-
("diagnostic_extract_missing_silicon", "", 1),
|
|
54
|
-
("diagnostic_unclip_craft", "", 1),
|
|
55
|
-
("diagnostic_unclip_preseed", "", 1),
|
|
56
|
-
("diagnostic_agile", "", 1),
|
|
57
|
-
("diagnostic_radial", "", 1),
|
|
58
|
-
# Hello World evals
|
|
59
|
-
("distant_resources", "buggy", NUM_COGS), # Not enough time for such distances.
|
|
60
|
-
("quadrant_buildings", "buggy", NUM_COGS), # Not enough charger for such distances.
|
|
61
|
-
("vibe_check", "", NUM_COGS),
|
|
62
|
-
("oxygen_bottleneck_easy", "", NUM_COGS),
|
|
63
|
-
("oxygen_bottleneck_standard", "", NUM_COGS),
|
|
64
|
-
("oxygen_bottleneck_hard", "buggy", NUM_COGS), # Not enough charger for such distances.
|
|
65
|
-
("energy_starved_easy", "", NUM_COGS),
|
|
66
|
-
("energy_starved_standard", "buggy", NUM_COGS), # E drain too high.
|
|
67
|
-
("energy_starved_hard", "buggy", NUM_COGS), # E drain too high.
|
|
68
|
-
("unclipping_easy", "n/a", NUM_COGS),
|
|
69
|
-
("unclipping_standard", "n/a", NUM_COGS),
|
|
70
|
-
("unclipping_hard", "n/a", NUM_COGS),
|
|
71
|
-
("distant_resources_easy", "", NUM_COGS),
|
|
72
|
-
("distant_resources_standard", "flakey", NUM_COGS), # Not enough time for such distances.
|
|
73
|
-
("distant_resources_hard", "buggy", NUM_COGS), # Not enough time for such distances.
|
|
74
|
-
("quadrant_buildings_easy", "", NUM_COGS),
|
|
75
|
-
("quadrant_buildings_standard", "buggy", NUM_COGS), # Not enough charger for such distances.
|
|
76
|
-
("quadrant_buildings_hard", "buggy", NUM_COGS), # Not enough charger for such distances.
|
|
77
|
-
("single_use_swarm_easy", "buggy", NUM_COGS),
|
|
78
|
-
("single_use_swarm_standard", "buggy", NUM_COGS), # Not enough time for such distances.
|
|
79
|
-
("single_use_swarm_hard", "buggy", NUM_COGS), # E drain too high.
|
|
80
|
-
("vibe_check_easy", "buggy", NUM_COGS), # No/invalid recipes available.
|
|
81
|
-
("vibe_check_standard", "", NUM_COGS),
|
|
82
|
-
("vibe_check_hard", "flakey", NUM_COGS), # Not enough time for such distances.
|
|
83
|
-
# Hearts evals
|
|
84
|
-
("easy_large_hearts", "slow", NUM_COGS),
|
|
85
|
-
("easy_medium_hearts", "", NUM_COGS),
|
|
86
|
-
("easy_small_hearts", "flakey", NUM_COGS),
|
|
87
|
-
# Missions from missions.py
|
|
88
|
-
("harvest", "", NUM_COGS),
|
|
89
|
-
("repair", "", 2), # repair uses 2 cogs
|
|
90
|
-
("hello_world_unclip", "", NUM_COGS),
|
|
91
|
-
]
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
def _load_all_missions() -> Dict[str, Mission]:
|
|
95
|
-
# Minimal loader: merge all known mission sets
|
|
96
|
-
from importlib import import_module
|
|
97
|
-
|
|
98
|
-
missions: List[Mission] = []
|
|
99
|
-
for mod_name in (
|
|
100
|
-
"cogames.cogs_vs_clips.evals.eval_missions",
|
|
101
|
-
"cogames.cogs_vs_clips.evals.integrated_evals",
|
|
102
|
-
"cogames.cogs_vs_clips.evals.spanning_evals",
|
|
103
|
-
"cogames.cogs_vs_clips.missions",
|
|
104
|
-
):
|
|
105
|
-
try:
|
|
106
|
-
mod = import_module(mod_name)
|
|
107
|
-
# missions.py uses MISSIONS, others use EVAL_MISSIONS
|
|
108
|
-
eval_list = getattr(mod, "MISSIONS", getattr(mod, "EVAL_MISSIONS", []))
|
|
109
|
-
missions.extend(eval_list)
|
|
110
|
-
except Exception:
|
|
111
|
-
pass
|
|
3
|
+
The implementation moved to `cogames_agents.policy.nim_agents.thinky_eval`.
|
|
4
|
+
"""
|
|
112
5
|
|
|
113
|
-
|
|
114
|
-
try:
|
|
115
|
-
missions.extend([cls() for cls in DIAGNOSTIC_EVALS]) # type: ignore[misc]
|
|
116
|
-
except Exception:
|
|
117
|
-
pass
|
|
118
|
-
|
|
119
|
-
# Build name -> mission instance map
|
|
120
|
-
mission_map: Dict[str, Mission] = {}
|
|
121
|
-
for m in missions:
|
|
122
|
-
# Items in EVAL_MISSIONS may be classes or instances; normalize to instances
|
|
123
|
-
try:
|
|
124
|
-
mission: Mission = m() if isinstance(m, type) else m # type: ignore[call-arg,assignment]
|
|
125
|
-
except Exception:
|
|
126
|
-
continue
|
|
127
|
-
mission_map[mission.name] = mission
|
|
128
|
-
return mission_map
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
def _ensure_vibe_supports_gear(env_cfg) -> None:
|
|
132
|
-
# Keep minimal and silent if anything fails
|
|
133
|
-
try:
|
|
134
|
-
assembler = env_cfg.game.objects.get("assembler")
|
|
135
|
-
uses_gear = False
|
|
136
|
-
if assembler is not None and hasattr(assembler, "protocols"):
|
|
137
|
-
for proto in assembler.protocols:
|
|
138
|
-
if any(v == "gear" for v in getattr(proto, "vibes", [])):
|
|
139
|
-
uses_gear = True
|
|
140
|
-
break
|
|
141
|
-
if uses_gear:
|
|
142
|
-
change_vibe = env_cfg.game.actions.change_vibe
|
|
143
|
-
has_gear = any(v.name == "gear" for v in change_vibe.vibes)
|
|
144
|
-
if not has_gear:
|
|
145
|
-
from mettagrid.config.vibes import VIBE_BY_NAME
|
|
146
|
-
|
|
147
|
-
change_vibe.vibes = list(change_vibe.vibes) + [VIBE_BY_NAME["gear"]]
|
|
148
|
-
except Exception:
|
|
149
|
-
pass
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
def run_eval(experiment_name: str, tag: str, mission_map: Dict[str, Mission], num_cogs: int, seed: int) -> float:
|
|
153
|
-
start = time.perf_counter()
|
|
154
|
-
try:
|
|
155
|
-
if experiment_name not in mission_map:
|
|
156
|
-
print(f"{tag:<6} {experiment_name:<40} {'MISSION NOT FOUND':>6}")
|
|
157
|
-
return 0.0
|
|
158
|
-
|
|
159
|
-
base_mission = mission_map[experiment_name]
|
|
160
|
-
mission = base_mission.with_variants([NumCogsVariant(num_cogs=num_cogs)])
|
|
6
|
+
from __future__ import annotations
|
|
161
7
|
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
env_cfg.game.max_steps = MAX_STEPS
|
|
8
|
+
import importlib
|
|
9
|
+
from typing import Any
|
|
165
10
|
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
11
|
+
try:
|
|
12
|
+
_thinky_eval = importlib.import_module("cogames_agents.policy.nim_agents.thinky_eval")
|
|
13
|
+
except ModuleNotFoundError as exc:
|
|
14
|
+
if exc.name and (exc.name == "cogames_agents" or exc.name.startswith("cogames_agents.")):
|
|
15
|
+
raise ModuleNotFoundError(
|
|
16
|
+
"Legacy import `cogames.policy.nim_agents.thinky_eval` requires optional dependency "
|
|
17
|
+
"`cogames-agents` (install `cogames[agents]`)."
|
|
18
|
+
) from exc
|
|
19
|
+
raise
|
|
173
20
|
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
21
|
+
AGENT_PATH = _thinky_eval.AGENT_PATH
|
|
22
|
+
EVALS = _thinky_eval.EVALS
|
|
23
|
+
MAX_STEPS = _thinky_eval.MAX_STEPS
|
|
24
|
+
NUM_COGS = _thinky_eval.NUM_COGS
|
|
25
|
+
SEED = _thinky_eval.SEED
|
|
26
|
+
main = _thinky_eval.main
|
|
27
|
+
run_eval = _thinky_eval.run_eval
|
|
181
28
|
|
|
182
|
-
total_reward = float(sum(rollout._sim.episode_rewards))
|
|
183
|
-
hearts_per_agent = total_reward / max(1, num_cogs)
|
|
184
|
-
elapsed = time.perf_counter() - start
|
|
185
29
|
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
tm = f"{elapsed:.2f}"
|
|
189
|
-
print(f"{tag:<6} {experiment_name:<40} {hpa:>6}h {tm:>6}s")
|
|
190
|
-
return hearts_per_agent
|
|
191
|
-
except Exception as e:
|
|
192
|
-
elapsed = time.perf_counter() - start
|
|
193
|
-
error_message = str(e)
|
|
194
|
-
print(f"{tag:<6} {experiment_name:<40} {error_message}")
|
|
195
|
-
return 0.0
|
|
30
|
+
def __getattr__(name: str) -> Any:
|
|
31
|
+
return getattr(_thinky_eval, name)
|
|
196
32
|
|
|
197
33
|
|
|
198
|
-
def
|
|
199
|
-
|
|
200
|
-
na.start_measure()
|
|
201
|
-
mission_map = _load_all_missions()
|
|
202
|
-
print(f"Loaded {len(mission_map)} missions")
|
|
203
|
-
print("tag .. map name ............................... harts/A .. time")
|
|
204
|
-
start = time.perf_counter()
|
|
205
|
-
total_hpa = 0.0
|
|
206
|
-
successful_evals = 0
|
|
207
|
-
num_evals = 0
|
|
208
|
-
for experiment_name, tag, num_cogs in EVALS:
|
|
209
|
-
num_evals += 1
|
|
210
|
-
if tag == "flakey":
|
|
211
|
-
for i in range(10):
|
|
212
|
-
hpa = run_eval(experiment_name, tag, mission_map, num_cogs, SEED + i)
|
|
213
|
-
if hpa > 0:
|
|
214
|
-
successful_evals += 1
|
|
215
|
-
total_hpa += hpa
|
|
216
|
-
break
|
|
217
|
-
else:
|
|
218
|
-
hpa = run_eval(experiment_name, tag, mission_map, num_cogs, SEED)
|
|
219
|
-
if hpa > 0:
|
|
220
|
-
successful_evals += 1
|
|
221
|
-
total_hpa += hpa
|
|
222
|
-
success_rate = successful_evals / num_evals
|
|
223
|
-
elapsed = time.perf_counter() - start
|
|
224
|
-
total_evals = f"{num_evals} evals {success_rate * 100:.1f}% successful"
|
|
225
|
-
hpa = f"{total_hpa:.2f}"
|
|
226
|
-
tm = f"{elapsed:.2f}"
|
|
227
|
-
tag = "total"
|
|
228
|
-
print(f"{tag:<6} {total_evals:<40} {hpa:>6}h {tm:>6}s")
|
|
229
|
-
na.end_measure()
|
|
34
|
+
def __dir__() -> list[str]:
|
|
35
|
+
return sorted(set(globals()).union(dir(_thinky_eval)))
|
|
230
36
|
|
|
231
37
|
|
|
232
|
-
|
|
233
|
-
|
|
38
|
+
__all__ = [
|
|
39
|
+
"AGENT_PATH",
|
|
40
|
+
"EVALS",
|
|
41
|
+
"MAX_STEPS",
|
|
42
|
+
"NUM_COGS",
|
|
43
|
+
"SEED",
|
|
44
|
+
"main",
|
|
45
|
+
"run_eval",
|
|
46
|
+
]
|
|
@@ -15,8 +15,9 @@ import torch
|
|
|
15
15
|
|
|
16
16
|
import pufferlib.models # type: ignore[import-untyped]
|
|
17
17
|
import pufferlib.pytorch # type: ignore[import-untyped]
|
|
18
|
-
from mettagrid.policy.policy import AgentPolicy, MultiAgentPolicy
|
|
18
|
+
from mettagrid.policy.policy import AgentPolicy, MultiAgentPolicy, StatefulAgentPolicy
|
|
19
19
|
from mettagrid.policy.policy_env_interface import PolicyEnvInterface
|
|
20
|
+
from mettagrid.policy.pufferlib import PufferlibStatefulImpl
|
|
20
21
|
from mettagrid.simulator import Action, AgentObservation, Simulation
|
|
21
22
|
|
|
22
23
|
|
|
@@ -24,7 +25,7 @@ class PufferlibCogsPolicy(MultiAgentPolicy, AgentPolicy):
|
|
|
24
25
|
"""Loads and runs checkpoints trained with PufferLib's CoGames policy.
|
|
25
26
|
|
|
26
27
|
This policy serves as both the MultiAgentPolicy factory and AgentPolicy
|
|
27
|
-
implementation, returning
|
|
28
|
+
implementation, returning per-agent wrappers that track state.
|
|
28
29
|
"""
|
|
29
30
|
|
|
30
31
|
short_names = ["pufferlib_cogs"]
|
|
@@ -38,64 +39,98 @@ class PufferlibCogsPolicy(MultiAgentPolicy, AgentPolicy):
|
|
|
38
39
|
):
|
|
39
40
|
MultiAgentPolicy.__init__(self, policy_env_info, device=device)
|
|
40
41
|
AgentPolicy.__init__(self, policy_env_info)
|
|
41
|
-
|
|
42
|
+
self._hidden_size = hidden_size
|
|
43
|
+
self._device = torch.device(device)
|
|
44
|
+
self._shim_env = SimpleNamespace(
|
|
42
45
|
single_observation_space=policy_env_info.observation_space,
|
|
43
46
|
single_action_space=policy_env_info.action_space,
|
|
44
47
|
observation_space=policy_env_info.observation_space,
|
|
45
48
|
action_space=policy_env_info.action_space,
|
|
46
49
|
num_agents=policy_env_info.num_agents,
|
|
47
50
|
)
|
|
48
|
-
|
|
49
|
-
self._net = pufferlib.models.Default(
|
|
50
|
-
self._net = self._net.to(torch.device(device))
|
|
51
|
+
self._shim_env.env = self._shim_env
|
|
52
|
+
self._net = pufferlib.models.Default(self._shim_env, hidden_size=hidden_size).to(self._device) # type: ignore[arg-type]
|
|
51
53
|
self._action_names = policy_env_info.action_names
|
|
52
|
-
self.
|
|
53
|
-
self.
|
|
54
|
+
self._is_recurrent = False
|
|
55
|
+
self._stateful_impl = PufferlibStatefulImpl(
|
|
56
|
+
self._net,
|
|
57
|
+
policy_env_info,
|
|
58
|
+
self._device,
|
|
59
|
+
is_recurrent=self._is_recurrent,
|
|
60
|
+
)
|
|
61
|
+
self._agent_policies: dict[int, StatefulAgentPolicy[dict[str, torch.Tensor | None]]] = {}
|
|
62
|
+
self._state_initialized = False
|
|
63
|
+
self._state: dict[str, torch.Tensor | None] = {}
|
|
54
64
|
|
|
55
65
|
def network(self) -> torch.nn.Module: # type: ignore[override]
|
|
56
66
|
return self._net
|
|
57
67
|
|
|
58
68
|
def agent_policy(self, agent_id: int) -> AgentPolicy: # type: ignore[override]
|
|
59
|
-
|
|
69
|
+
if agent_id not in self._agent_policies:
|
|
70
|
+
self._agent_policies[agent_id] = StatefulAgentPolicy(
|
|
71
|
+
self._stateful_impl,
|
|
72
|
+
self._policy_env_info,
|
|
73
|
+
agent_id=agent_id,
|
|
74
|
+
)
|
|
75
|
+
return self._agent_policies[agent_id]
|
|
60
76
|
|
|
61
77
|
def is_recurrent(self) -> bool:
|
|
62
|
-
return
|
|
78
|
+
return self._is_recurrent
|
|
63
79
|
|
|
64
80
|
def reset(self, simulation: Optional[Simulation] = None) -> None: # type: ignore[override]
|
|
65
|
-
|
|
66
|
-
|
|
81
|
+
for policy in self._agent_policies.values():
|
|
82
|
+
policy.reset(simulation)
|
|
83
|
+
self._reset_state()
|
|
67
84
|
|
|
68
85
|
def load_policy_data(self, policy_data_path: str) -> None:
|
|
69
|
-
state = torch.load(policy_data_path, map_location=
|
|
70
|
-
|
|
71
|
-
|
|
86
|
+
state = torch.load(policy_data_path, map_location=self._device)
|
|
87
|
+
state = {k.replace("module.", ""): v for k, v in state.items()}
|
|
88
|
+
uses_rnn = any(key.startswith(("lstm.", "cell.")) for key in state)
|
|
89
|
+
base_net = pufferlib.models.Default(self._shim_env, hidden_size=self._hidden_size) # type: ignore[arg-type]
|
|
90
|
+
net = (
|
|
91
|
+
pufferlib.models.LSTMWrapper(
|
|
92
|
+
self._shim_env,
|
|
93
|
+
base_net,
|
|
94
|
+
input_size=base_net.hidden_size,
|
|
95
|
+
hidden_size=base_net.hidden_size,
|
|
96
|
+
)
|
|
97
|
+
if uses_rnn
|
|
98
|
+
else base_net
|
|
99
|
+
)
|
|
100
|
+
net.load_state_dict(state)
|
|
101
|
+
self._net = net.to(self._device)
|
|
102
|
+
self._is_recurrent = uses_rnn
|
|
103
|
+
self._stateful_impl = PufferlibStatefulImpl(
|
|
104
|
+
self._net,
|
|
105
|
+
self._policy_env_info,
|
|
106
|
+
self._device,
|
|
107
|
+
is_recurrent=self._is_recurrent,
|
|
108
|
+
)
|
|
109
|
+
self._agent_policies.clear()
|
|
110
|
+
self._state_initialized = False
|
|
111
|
+
self._state = {}
|
|
72
112
|
|
|
73
113
|
def save_policy_data(self, policy_data_path: str) -> None:
|
|
74
114
|
torch.save(self._net.state_dict(), policy_data_path)
|
|
75
115
|
|
|
76
116
|
def step(self, obs: Union[AgentObservation, torch.Tensor, Sequence[Any]]) -> Action: # type: ignore[override]
|
|
77
117
|
if isinstance(obs, AgentObservation):
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
for idx, token in enumerate(obs.tokens):
|
|
85
|
-
if idx >= self._num_tokens:
|
|
86
|
-
break
|
|
87
|
-
raw = torch.as_tensor(token.raw_token, device=self._device, dtype=obs_tensor.dtype)
|
|
88
|
-
obs_tensor[idx, : raw.numel()] = raw
|
|
89
|
-
else:
|
|
90
|
-
obs_tensor = torch.as_tensor(obs, device=self._device, dtype=torch.float32)
|
|
91
|
-
|
|
92
|
-
obs_tensor = obs_tensor * (1.0 / 255.0)
|
|
118
|
+
if not self._state_initialized:
|
|
119
|
+
self._reset_state()
|
|
120
|
+
with torch.no_grad():
|
|
121
|
+
action, self._state = self._stateful_impl.step_with_state(obs, self._state)
|
|
122
|
+
return action
|
|
123
|
+
obs_tensor = torch.as_tensor(obs, device=self._device, dtype=torch.float32)
|
|
93
124
|
if obs_tensor.ndim == 2:
|
|
94
125
|
obs_tensor = obs_tensor.unsqueeze(0)
|
|
95
|
-
|
|
96
126
|
with torch.no_grad():
|
|
97
127
|
self._net.eval()
|
|
98
|
-
logits, _ = self._net.forward_eval(obs_tensor)
|
|
128
|
+
logits, _ = self._net.forward_eval(obs_tensor, None)
|
|
99
129
|
sampled, _, _ = pufferlib.pytorch.sample_logits(logits)
|
|
100
130
|
action_idx = max(0, min(int(sampled.item()), len(self._action_names) - 1))
|
|
101
131
|
return Action(name=self._action_names[action_idx])
|
|
132
|
+
|
|
133
|
+
def _reset_state(self) -> None:
|
|
134
|
+
self._stateful_impl.reset()
|
|
135
|
+
self._state = self._stateful_impl.initial_agent_state()
|
|
136
|
+
self._state_initialized = True
|
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Sample policy for the CoGames CogsGuard environment.
|
|
3
|
+
|
|
4
|
+
This starter policy uses simple heuristics:
|
|
5
|
+
- If the agent has no gear, head toward the nearest gear station.
|
|
6
|
+
- If the agent has aligner or scrambler gear, try to get hearts (and influence for aligner) then head to junctions.
|
|
7
|
+
- If the agent has miner gear, head to extractors.
|
|
8
|
+
- If the agent has scout gear, explore in a simple pattern.
|
|
9
|
+
|
|
10
|
+
Note to users of this policy:
|
|
11
|
+
We don't intend for scripted policies to be the final word on how policies are generated (e.g., we expect the
|
|
12
|
+
environment to be complicated enough that trained agents will be necessary). So we expect that scripting policies
|
|
13
|
+
is a good way to start, but don't want you to get stuck here. Feel free to prove us wrong!
|
|
14
|
+
|
|
15
|
+
Note to cogames developers:
|
|
16
|
+
This policy should be kept relatively minimalist, without dependencies on intricate algorithms.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
from __future__ import annotations
|
|
20
|
+
|
|
21
|
+
from dataclasses import dataclass
|
|
22
|
+
from typing import Iterable, Optional
|
|
23
|
+
|
|
24
|
+
from mettagrid.policy.policy import MultiAgentPolicy, StatefulAgentPolicy, StatefulPolicyImpl
|
|
25
|
+
from mettagrid.policy.policy_env_interface import PolicyEnvInterface
|
|
26
|
+
from mettagrid.simulator import Action
|
|
27
|
+
from mettagrid.simulator.interface import AgentObservation
|
|
28
|
+
|
|
29
|
+
GEAR = ("aligner", "scrambler", "miner", "scout")
|
|
30
|
+
ELEMENTS = ("carbon", "oxygen", "germanium", "silicon")
|
|
31
|
+
WANDER_DIRECTIONS = ("east", "south", "west", "north")
|
|
32
|
+
WANDER_STEPS = 8
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass
|
|
36
|
+
class StarterCogState:
|
|
37
|
+
wander_direction_index: int = 0
|
|
38
|
+
wander_steps_remaining: int = WANDER_STEPS
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class StarterCogPolicyImpl(StatefulPolicyImpl[StarterCogState]):
|
|
42
|
+
def __init__(
|
|
43
|
+
self,
|
|
44
|
+
policy_env_info: PolicyEnvInterface,
|
|
45
|
+
agent_id: int,
|
|
46
|
+
):
|
|
47
|
+
self._agent_id = agent_id
|
|
48
|
+
self._policy_env_info = policy_env_info
|
|
49
|
+
|
|
50
|
+
self._action_names = policy_env_info.action_names
|
|
51
|
+
self._action_name_set = set(self._action_names)
|
|
52
|
+
self._fallback_action_name = "noop" if "noop" in self._action_name_set else self._action_names[0]
|
|
53
|
+
self._center = (policy_env_info.obs_height // 2, policy_env_info.obs_width // 2)
|
|
54
|
+
self._tag_name_to_id = {name: idx for idx, name in enumerate(policy_env_info.tags)}
|
|
55
|
+
self._gear_station_tags = self._resolve_tag_ids([f"{gear}_station" for gear in GEAR])
|
|
56
|
+
self._extractor_tags = self._resolve_tag_ids([f"{element}_extractor" for element in ELEMENTS])
|
|
57
|
+
self._junction_tags = self._resolve_tag_ids(["junction"])
|
|
58
|
+
self._chest_tags = self._resolve_tag_ids(["chest"])
|
|
59
|
+
self._hub_tags = self._resolve_tag_ids(["hub"])
|
|
60
|
+
|
|
61
|
+
def _resolve_tag_ids(self, names: Iterable[str]) -> set[int]:
|
|
62
|
+
tag_ids: set[int] = set()
|
|
63
|
+
for name in names:
|
|
64
|
+
if name in self._tag_name_to_id:
|
|
65
|
+
tag_ids.add(self._tag_name_to_id[name])
|
|
66
|
+
if name.startswith("type:"):
|
|
67
|
+
continue
|
|
68
|
+
type_name = f"type:{name}"
|
|
69
|
+
if type_name in self._tag_name_to_id:
|
|
70
|
+
tag_ids.add(self._tag_name_to_id[type_name])
|
|
71
|
+
return tag_ids
|
|
72
|
+
|
|
73
|
+
def _inventory_items(self, obs: AgentObservation) -> set[str]:
|
|
74
|
+
items: set[str] = set()
|
|
75
|
+
for token in obs.tokens:
|
|
76
|
+
if token.location != self._center:
|
|
77
|
+
continue
|
|
78
|
+
name = token.feature.name
|
|
79
|
+
if not name.startswith("inv:"):
|
|
80
|
+
continue
|
|
81
|
+
parts = name.split(":", 2)
|
|
82
|
+
if len(parts) >= 2:
|
|
83
|
+
items.add(parts[1])
|
|
84
|
+
return items
|
|
85
|
+
|
|
86
|
+
def _closest_tag_location(self, obs: AgentObservation, tag_ids: set[int]) -> Optional[tuple[int, int]]:
|
|
87
|
+
if not tag_ids:
|
|
88
|
+
return None
|
|
89
|
+
best_location: Optional[tuple[int, int]] = None
|
|
90
|
+
best_distance = 999
|
|
91
|
+
for token in obs.tokens:
|
|
92
|
+
if token.feature.name != "tag":
|
|
93
|
+
continue
|
|
94
|
+
if token.value not in tag_ids:
|
|
95
|
+
continue
|
|
96
|
+
distance = abs(token.location[0] - self._center[0]) + abs(token.location[1] - self._center[1])
|
|
97
|
+
if distance < best_distance:
|
|
98
|
+
best_distance = distance
|
|
99
|
+
best_location = token.location
|
|
100
|
+
return best_location
|
|
101
|
+
|
|
102
|
+
def _action(self, name: str) -> Action:
|
|
103
|
+
if name in self._action_name_set:
|
|
104
|
+
return Action(name=name)
|
|
105
|
+
return Action(name=self._fallback_action_name)
|
|
106
|
+
|
|
107
|
+
def _wander(self, state: StarterCogState) -> tuple[Action, StarterCogState]:
|
|
108
|
+
if state.wander_steps_remaining <= 0:
|
|
109
|
+
state.wander_direction_index = (state.wander_direction_index + 1) % len(WANDER_DIRECTIONS)
|
|
110
|
+
state.wander_steps_remaining = WANDER_STEPS
|
|
111
|
+
direction = WANDER_DIRECTIONS[state.wander_direction_index]
|
|
112
|
+
state.wander_steps_remaining -= 1
|
|
113
|
+
return self._action(f"move_{direction}"), state
|
|
114
|
+
|
|
115
|
+
def _move_toward(self, state: StarterCogState, target: Optional[tuple[int, int]]) -> tuple[Action, StarterCogState]:
|
|
116
|
+
if target is None:
|
|
117
|
+
return self._wander(state)
|
|
118
|
+
delta_row = target[0] - self._center[0]
|
|
119
|
+
delta_col = target[1] - self._center[1]
|
|
120
|
+
if delta_row == 0 and delta_col == 0:
|
|
121
|
+
return self._action(self._fallback_action_name), state
|
|
122
|
+
if abs(delta_row) >= abs(delta_col):
|
|
123
|
+
direction = "south" if delta_row > 0 else "north"
|
|
124
|
+
else:
|
|
125
|
+
direction = "east" if delta_col > 0 else "west"
|
|
126
|
+
return self._action(f"move_{direction}"), state
|
|
127
|
+
|
|
128
|
+
def _current_gear(self, items: set[str]) -> Optional[str]:
    """Return the first entry of ``GEAR`` (in ``GEAR`` order) that is in ``items``, else ``None``."""
    return next((gear for gear in GEAR if gear in items), None)
|
|
133
|
+
|
|
134
|
+
def step_with_state(self, obs: AgentObservation, state: StarterCogState) -> tuple[Action, StarterCogState]:
    """Choose this Cog's next action from its gear and inventory.

    Target tag set by situation:
      - no gear                              -> ``self._gear_station_tags``
      - ``aligner`` without ``heart``        -> ``self._chest_tags``
      - ``aligner`` with heart and influence -> ``self._junction_tags``
      - ``aligner`` with heart only          -> ``self._hub_tags``
      - ``scrambler``                        -> junctions with a heart, else chests
      - ``miner``                            -> ``self._extractor_tags``
      - any other gear                       -> no target (``_move_toward`` wanders)

    The closest visible match is then handed to ``self._move_toward``.
    """
    inventory = self._inventory_items(obs)
    holding_heart = "heart" in inventory
    holding_influence = "influence" in inventory
    role = self._current_gear(inventory)

    if role is None:
        wanted = self._gear_station_tags
    elif role == "aligner":
        if not holding_heart:
            wanted = self._chest_tags
        elif holding_influence:
            wanted = self._junction_tags
        else:
            wanted = self._hub_tags
    elif role == "scrambler":
        if holding_heart:
            wanted = self._junction_tags
        else:
            wanted = self._chest_tags
    elif role == "miner":
        wanted = self._extractor_tags
    else:
        wanted = set()

    destination = None
    if wanted:
        destination = self._closest_tag_location(obs, wanted)
    return self._move_toward(state, destination)
|
|
159
|
+
|
|
160
|
+
def initial_agent_state(self) -> StarterCogState:
    """Build a fresh, default-constructed ``StarterCogState`` for a new agent."""
    fresh_state = StarterCogState()
    return fresh_state
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
# ============================================================================
|
|
166
|
+
# Policy Wrapper Classes
|
|
167
|
+
# ============================================================================
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
class StarterPolicy(MultiAgentPolicy):
    """Multi-agent policy that hands each agent its own ``StarterCogPolicyImpl``.

    The per-agent ``StatefulAgentPolicy`` wrapper is built on first request for an
    agent id and reused on every later request for that id.
    """

    # NOTE(review): presumably the alias(es) used to select this policy by name — confirm.
    short_names = ["starter"]

    def __init__(self, policy_env_info: PolicyEnvInterface, device: str = "cpu"):
        super().__init__(policy_env_info, device=device)
        # Lazily populated cache: agent id -> that agent's stateful policy wrapper.
        self._agent_policies: dict[int, StatefulAgentPolicy[StarterCogState]] = {}

    def agent_policy(self, agent_id: int) -> StatefulAgentPolicy[StarterCogState]:
        """Return the cached policy for ``agent_id``, creating it on first use."""
        cached = self._agent_policies.get(agent_id)
        if cached is None:
            cached = StatefulAgentPolicy(
                StarterCogPolicyImpl(self._policy_env_info, agent_id),
                self._policy_env_info,
                agent_id=agent_id,
            )
            self._agent_policies[agent_id] = cached
        return cached
|
|
@@ -1,6 +1,9 @@
|
|
|
1
1
|
"""
|
|
2
2
|
Trainable Policy Template for the CoGames environment.
|
|
3
3
|
|
|
4
|
+
This template is compatible with CogsGuard missions. It uses only the observation and action
|
|
5
|
+
spaces provided by the environment and makes no game-specific assumptions.
|
|
6
|
+
|
|
4
7
|
This template provides a minimal trainable neural network policy that can be used with
|
|
5
8
|
`cogames tutorial train`. It demonstrates the key interfaces required for training:
|
|
6
9
|
|
|
@@ -14,7 +17,7 @@ clarity and without the pufferlib dependency.
|
|
|
14
17
|
|
|
15
18
|
To use this template:
|
|
16
19
|
1. Modify MyNetwork to implement your desired architecture
|
|
17
|
-
2. Run: cogames tutorial train -m
|
|
20
|
+
2. Run: cogames tutorial train -m cogsguard_machina_1.basic -p class=my_trainable_policy.MyTrainablePolicy
|
|
18
21
|
"""
|
|
19
22
|
|
|
20
23
|
from __future__ import annotations
|