cogames-agents 0.0.0.7__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cogames_agents/__init__.py +0 -0
- cogames_agents/evals/__init__.py +5 -0
- cogames_agents/evals/planky_evals.py +415 -0
- cogames_agents/policy/__init__.py +0 -0
- cogames_agents/policy/evolution/__init__.py +0 -0
- cogames_agents/policy/evolution/cogsguard/__init__.py +0 -0
- cogames_agents/policy/evolution/cogsguard/evolution.py +695 -0
- cogames_agents/policy/evolution/cogsguard/evolutionary_coordinator.py +540 -0
- cogames_agents/policy/nim_agents/__init__.py +20 -0
- cogames_agents/policy/nim_agents/agents.py +98 -0
- cogames_agents/policy/nim_agents/bindings/generated/libnim_agents.dylib +0 -0
- cogames_agents/policy/nim_agents/bindings/generated/nim_agents.py +215 -0
- cogames_agents/policy/nim_agents/cogsguard_agents.nim +555 -0
- cogames_agents/policy/nim_agents/cogsguard_align_all_agents.nim +569 -0
- cogames_agents/policy/nim_agents/common.nim +1054 -0
- cogames_agents/policy/nim_agents/install.sh +1 -0
- cogames_agents/policy/nim_agents/ladybug_agent.nim +954 -0
- cogames_agents/policy/nim_agents/nim_agents.nim +68 -0
- cogames_agents/policy/nim_agents/nim_agents.nims +14 -0
- cogames_agents/policy/nim_agents/nimby.lock +3 -0
- cogames_agents/policy/nim_agents/racecar_agents.nim +844 -0
- cogames_agents/policy/nim_agents/random_agents.nim +68 -0
- cogames_agents/policy/nim_agents/test_agents.py +53 -0
- cogames_agents/policy/nim_agents/thinky_agents.nim +677 -0
- cogames_agents/policy/nim_agents/thinky_eval.py +230 -0
- cogames_agents/policy/scripted_agent/README.md +360 -0
- cogames_agents/policy/scripted_agent/__init__.py +0 -0
- cogames_agents/policy/scripted_agent/baseline_agent.py +1031 -0
- cogames_agents/policy/scripted_agent/cogas/__init__.py +5 -0
- cogames_agents/policy/scripted_agent/cogas/context.py +68 -0
- cogames_agents/policy/scripted_agent/cogas/entity_map.py +152 -0
- cogames_agents/policy/scripted_agent/cogas/goal.py +115 -0
- cogames_agents/policy/scripted_agent/cogas/goals/__init__.py +27 -0
- cogames_agents/policy/scripted_agent/cogas/goals/aligner.py +160 -0
- cogames_agents/policy/scripted_agent/cogas/goals/gear.py +197 -0
- cogames_agents/policy/scripted_agent/cogas/goals/miner.py +441 -0
- cogames_agents/policy/scripted_agent/cogas/goals/scout.py +40 -0
- cogames_agents/policy/scripted_agent/cogas/goals/scrambler.py +174 -0
- cogames_agents/policy/scripted_agent/cogas/goals/shared.py +160 -0
- cogames_agents/policy/scripted_agent/cogas/goals/stem.py +60 -0
- cogames_agents/policy/scripted_agent/cogas/goals/survive.py +100 -0
- cogames_agents/policy/scripted_agent/cogas/navigator.py +401 -0
- cogames_agents/policy/scripted_agent/cogas/obs_parser.py +238 -0
- cogames_agents/policy/scripted_agent/cogas/policy.py +525 -0
- cogames_agents/policy/scripted_agent/cogas/trace.py +69 -0
- cogames_agents/policy/scripted_agent/cogsguard/CLAUDE.md +517 -0
- cogames_agents/policy/scripted_agent/cogsguard/README.md +252 -0
- cogames_agents/policy/scripted_agent/cogsguard/__init__.py +74 -0
- cogames_agents/policy/scripted_agent/cogsguard/aligned_junction_held_investigation.md +152 -0
- cogames_agents/policy/scripted_agent/cogsguard/aligner.py +333 -0
- cogames_agents/policy/scripted_agent/cogsguard/behavior_hooks.py +44 -0
- cogames_agents/policy/scripted_agent/cogsguard/control_agent.py +323 -0
- cogames_agents/policy/scripted_agent/cogsguard/debug_agent.py +533 -0
- cogames_agents/policy/scripted_agent/cogsguard/miner.py +589 -0
- cogames_agents/policy/scripted_agent/cogsguard/options.py +67 -0
- cogames_agents/policy/scripted_agent/cogsguard/parity_metrics.py +36 -0
- cogames_agents/policy/scripted_agent/cogsguard/policy.py +1967 -0
- cogames_agents/policy/scripted_agent/cogsguard/prereq_trace.py +33 -0
- cogames_agents/policy/scripted_agent/cogsguard/role_trace.py +50 -0
- cogames_agents/policy/scripted_agent/cogsguard/roles.py +31 -0
- cogames_agents/policy/scripted_agent/cogsguard/rollout_trace.py +40 -0
- cogames_agents/policy/scripted_agent/cogsguard/scout.py +69 -0
- cogames_agents/policy/scripted_agent/cogsguard/scrambler.py +350 -0
- cogames_agents/policy/scripted_agent/cogsguard/targeted_agent.py +418 -0
- cogames_agents/policy/scripted_agent/cogsguard/teacher.py +224 -0
- cogames_agents/policy/scripted_agent/cogsguard/types.py +381 -0
- cogames_agents/policy/scripted_agent/cogsguard/v2_agent.py +49 -0
- cogames_agents/policy/scripted_agent/common/__init__.py +0 -0
- cogames_agents/policy/scripted_agent/common/geometry.py +24 -0
- cogames_agents/policy/scripted_agent/common/roles.py +34 -0
- cogames_agents/policy/scripted_agent/common/tag_utils.py +48 -0
- cogames_agents/policy/scripted_agent/demo_policy.py +242 -0
- cogames_agents/policy/scripted_agent/pathfinding.py +126 -0
- cogames_agents/policy/scripted_agent/pinky/DESIGN.md +317 -0
- cogames_agents/policy/scripted_agent/pinky/__init__.py +5 -0
- cogames_agents/policy/scripted_agent/pinky/behaviors/__init__.py +17 -0
- cogames_agents/policy/scripted_agent/pinky/behaviors/aligner.py +400 -0
- cogames_agents/policy/scripted_agent/pinky/behaviors/base.py +119 -0
- cogames_agents/policy/scripted_agent/pinky/behaviors/miner.py +632 -0
- cogames_agents/policy/scripted_agent/pinky/behaviors/scout.py +138 -0
- cogames_agents/policy/scripted_agent/pinky/behaviors/scrambler.py +433 -0
- cogames_agents/policy/scripted_agent/pinky/policy.py +570 -0
- cogames_agents/policy/scripted_agent/pinky/services/__init__.py +7 -0
- cogames_agents/policy/scripted_agent/pinky/services/map_tracker.py +808 -0
- cogames_agents/policy/scripted_agent/pinky/services/navigator.py +864 -0
- cogames_agents/policy/scripted_agent/pinky/services/safety.py +189 -0
- cogames_agents/policy/scripted_agent/pinky/state.py +299 -0
- cogames_agents/policy/scripted_agent/pinky/types.py +138 -0
- cogames_agents/policy/scripted_agent/planky/CLAUDE.md +124 -0
- cogames_agents/policy/scripted_agent/planky/IMPROVEMENTS.md +160 -0
- cogames_agents/policy/scripted_agent/planky/NOTES.md +153 -0
- cogames_agents/policy/scripted_agent/planky/PLAN.md +254 -0
- cogames_agents/policy/scripted_agent/planky/README.md +214 -0
- cogames_agents/policy/scripted_agent/planky/STRATEGY.md +100 -0
- cogames_agents/policy/scripted_agent/planky/__init__.py +5 -0
- cogames_agents/policy/scripted_agent/planky/context.py +68 -0
- cogames_agents/policy/scripted_agent/planky/entity_map.py +152 -0
- cogames_agents/policy/scripted_agent/planky/goal.py +107 -0
- cogames_agents/policy/scripted_agent/planky/goals/__init__.py +27 -0
- cogames_agents/policy/scripted_agent/planky/goals/aligner.py +168 -0
- cogames_agents/policy/scripted_agent/planky/goals/gear.py +179 -0
- cogames_agents/policy/scripted_agent/planky/goals/miner.py +416 -0
- cogames_agents/policy/scripted_agent/planky/goals/scout.py +40 -0
- cogames_agents/policy/scripted_agent/planky/goals/scrambler.py +174 -0
- cogames_agents/policy/scripted_agent/planky/goals/shared.py +160 -0
- cogames_agents/policy/scripted_agent/planky/goals/stem.py +49 -0
- cogames_agents/policy/scripted_agent/planky/goals/survive.py +96 -0
- cogames_agents/policy/scripted_agent/planky/navigator.py +388 -0
- cogames_agents/policy/scripted_agent/planky/obs_parser.py +238 -0
- cogames_agents/policy/scripted_agent/planky/policy.py +485 -0
- cogames_agents/policy/scripted_agent/planky/tests/__init__.py +0 -0
- cogames_agents/policy/scripted_agent/planky/tests/conftest.py +66 -0
- cogames_agents/policy/scripted_agent/planky/tests/helpers.py +152 -0
- cogames_agents/policy/scripted_agent/planky/tests/test_aligner.py +24 -0
- cogames_agents/policy/scripted_agent/planky/tests/test_miner.py +30 -0
- cogames_agents/policy/scripted_agent/planky/tests/test_scout.py +15 -0
- cogames_agents/policy/scripted_agent/planky/tests/test_scrambler.py +29 -0
- cogames_agents/policy/scripted_agent/planky/tests/test_stem.py +36 -0
- cogames_agents/policy/scripted_agent/planky/trace.py +69 -0
- cogames_agents/policy/scripted_agent/types.py +239 -0
- cogames_agents/policy/scripted_agent/unclipping_agent.py +461 -0
- cogames_agents/policy/scripted_agent/utils.py +381 -0
- cogames_agents/policy/scripted_registry.py +80 -0
- cogames_agents/py.typed +0 -0
- cogames_agents-0.0.0.7.dist-info/METADATA +98 -0
- cogames_agents-0.0.0.7.dist-info/RECORD +128 -0
- cogames_agents-0.0.0.7.dist-info/WHEEL +6 -0
- cogames_agents-0.0.0.7.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,418 @@
|
|
|
1
|
+
"""CoGsGuard scripted policy with targeted role assignments."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Any, Optional
|
|
7
|
+
|
|
8
|
+
from cogames_agents.policy.scripted_agent.utils import change_vibe_action
|
|
9
|
+
from mettagrid.policy.policy import StatefulAgentPolicy
|
|
10
|
+
from mettagrid.policy.policy_env_interface import PolicyEnvInterface
|
|
11
|
+
from mettagrid.simulator import Action
|
|
12
|
+
|
|
13
|
+
from .aligner import AlignerAgentPolicyImpl
|
|
14
|
+
from .miner import HEALING_AOE_RANGE, MinerAgentPolicyImpl
|
|
15
|
+
from .policy import DEBUG, CogsguardAgentPolicyImpl, CogsguardMultiRoleImpl, CogsguardPolicy
|
|
16
|
+
from .scout import ScoutAgentPolicyImpl
|
|
17
|
+
from .scrambler import ScramblerAgentPolicyImpl
|
|
18
|
+
from .types import CogsguardAgentState, Role, StructureInfo, StructureType
|
|
19
|
+
|
|
20
|
+
# Minimum number of steps between two planning passes (see TargetedPlannerState.maybe_plan).
PLAN_INTERVAL_STEPS = 25
# Before this step the planner fields only scouts and miners.
PHASE_EXPLORE_END = 80
# Before this step the planner may boost control roles while too few known junctions are aligned.
PHASE_CONTROL_END = 260
# A positive chest reading below this value switches the planner to a miner-heavy roster.
CHEST_LOW_THRESHOLD = 60
# Vibes that receive a junction target from the planner.
CONTROL_VIBES = {"scrambler", "aligner"}
# Resource names; an agent's preferred resource is the entry at agent_id modulo the list length.
RESOURCE_CYCLE = ["carbon", "oxygen", "germanium", "silicon"]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _default_role_counts(num_agents: int) -> dict[str, int]:
|
|
29
|
+
if num_agents <= 1:
|
|
30
|
+
return {"miner": 1}
|
|
31
|
+
if num_agents == 2:
|
|
32
|
+
return {"scrambler": 1, "miner": 1}
|
|
33
|
+
if num_agents == 3:
|
|
34
|
+
return {"scrambler": 1, "miner": 1, "scout": 1}
|
|
35
|
+
if num_agents <= 7:
|
|
36
|
+
scramblers = 1
|
|
37
|
+
aligners = 1
|
|
38
|
+
scouts = 1
|
|
39
|
+
else:
|
|
40
|
+
scramblers = max(2, num_agents // 6)
|
|
41
|
+
aligners = max(2, num_agents // 6)
|
|
42
|
+
scouts = 1
|
|
43
|
+
miners = max(1, num_agents - scramblers - scouts - aligners)
|
|
44
|
+
return {
|
|
45
|
+
"scrambler": scramblers,
|
|
46
|
+
"aligner": aligners,
|
|
47
|
+
"miner": miners,
|
|
48
|
+
"scout": scouts,
|
|
49
|
+
}
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _normalize_counts(num_agents: int, counts: dict[str, int]) -> dict[str, int]:
|
|
53
|
+
normalized = {k: v for k, v in counts.items() if isinstance(v, int)}
|
|
54
|
+
total = sum(normalized.values())
|
|
55
|
+
if total < num_agents:
|
|
56
|
+
normalized["miner"] = normalized.get("miner", 0) + (num_agents - total)
|
|
57
|
+
elif total > num_agents:
|
|
58
|
+
overflow = total - num_agents
|
|
59
|
+
miners = normalized.get("miner", 0)
|
|
60
|
+
normalized["miner"] = max(0, miners - overflow)
|
|
61
|
+
return normalized
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _build_role_plan(num_agents: int, counts: dict[str, int]) -> list[str]:
|
|
65
|
+
ordered: list[str] = []
|
|
66
|
+
for role_name in ["scrambler", "aligner", "miner", "scout"]:
|
|
67
|
+
ordered.extend([role_name] * counts.get(role_name, 0))
|
|
68
|
+
if len(ordered) < num_agents:
|
|
69
|
+
ordered.extend(["miner"] * (num_agents - len(ordered)))
|
|
70
|
+
return ordered[:num_agents]
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@dataclass
class TargetedPlannerState:
    """Team-wide planning state shared by all agents of a targeted policy.

    ``update_from_agent`` merges what one agent currently knows about
    junctions, extractors and chests. ``maybe_plan`` periodically turns that
    knowledge into a desired vibe per agent plus junction/extractor targets.
    """

    num_agents: int
    # Desired vibe per agent id; rebuilt on every planning pass.
    desired_vibes: list[str] = field(default_factory=list)
    # Step count at which the last planning pass ran.
    last_plan_step: int = 0
    # Largest junction count any single agent has reported so far.
    known_junctions: int = 0
    # Largest count of "cogs"-aligned junctions any single agent has reported so far.
    aligned_junctions: int = 0
    # Largest positive chest inventory reported so far.
    chest_resources: int = 0
    # Junction position -> most recently reported alignment.
    junction_map: dict[tuple[int, int], Optional[str]] = field(default_factory=dict)
    # Usable-extractor position -> most recently reported resource type.
    extractor_map: dict[tuple[int, int], Optional[str]] = field(default_factory=dict)
    # Agent id -> junction position handed out by the last planning pass.
    assigned_junctions: dict[int, tuple[int, int]] = field(default_factory=dict)
    # Agent id -> extractor position handed out by the last planning pass.
    assigned_extractors: dict[int, tuple[int, int]] = field(default_factory=dict)

    def update_from_agent(self, s: CogsguardAgentState) -> None:
        """Merge one agent's current structure knowledge into the shared state."""
        junctions = s.get_structures_by_type(StructureType.CHARGER)
        aligned = [c for c in junctions if c.alignment == "cogs"]
        # The counters only ever grow: they keep the best view any agent has had.
        self.known_junctions = max(self.known_junctions, len(junctions))
        self.aligned_junctions = max(self.aligned_junctions, len(aligned))
        for junction in junctions:
            self.junction_map[junction.position] = junction.alignment

        for extractor in s.get_usable_extractors():
            self.extractor_map[extractor.position] = extractor.resource_type

        chest_resources = max(
            (struct.inventory_amount for struct in s.get_structures_by_type(StructureType.CHEST)),
            default=0,
        )
        # Readings of zero (or no chest at all) never overwrite a known amount.
        if chest_resources > 0:
            self.chest_resources = max(self.chest_resources, chest_resources)

    def maybe_plan(self, step_count: int) -> None:
        """Replan roles and targets once ``PLAN_INTERVAL_STEPS`` steps have passed.

        Args:
            step_count: Current step counter of the calling agent.
        """
        if step_count < self.last_plan_step:
            # The step counter went backwards, i.e. the caller restarted its
            # count. Without this the interval check below would stay negative
            # and planning would stall until the old step count is passed again.
            # NOTE(review): observations gathered before the restart are kept.
            self.last_plan_step = 0
        if step_count - self.last_plan_step < PLAN_INTERVAL_STEPS:
            return
        self.last_plan_step = step_count

        counts = self._choose_counts(step_count)
        self.desired_vibes = _build_role_plan(self.num_agents, counts)
        self._assign_targets()

        if DEBUG:
            print(
                f"[TARGETED] plan@{step_count}: junctions={self.known_junctions} "
                f"aligned={self.aligned_junctions} chest={self.chest_resources} "
                f"roles={counts}"
            )

    def _choose_counts(self, step_count: int) -> dict[str, int]:
        """Pick role headcounts from the step count and the shared observations."""
        team = self.num_agents

        # Opening steps, or no junction known yet: scouts and miners only.
        if step_count < PHASE_EXPLORE_END or self.known_junctions == 0:
            scouts = 3 if team >= 8 else 2 if team >= 5 else 1
            return {
                "scrambler": 0,
                "aligner": 0,
                "scout": scouts,
                "miner": max(1, team - scouts),
            }

        # Chest is known and low: one of each other role, everyone else mines.
        if 0 < self.chest_resources < CHEST_LOW_THRESHOLD:
            scramblers = 1
            aligners = 1
            scouts = 1
            return {
                "scrambler": scramblers,
                "aligner": aligners,
                "scout": scouts,
                "miner": max(1, team - (scramblers + aligners + scouts)),
            }

        # Fewer than half of the known junctions (at least one) are aligned:
        # larger teams commit extra control agents.
        if step_count < PHASE_CONTROL_END and self.aligned_junctions < max(1, self.known_junctions // 2):
            if team >= 8:
                scramblers, aligners = 2, 3
            elif team >= 6:
                scramblers, aligners = 1, 2
            else:
                scramblers, aligners = 1, 1
            return {
                "scrambler": scramblers,
                "aligner": aligners,
                "scout": 1,
                "miner": max(1, team - (scramblers + aligners + 1)),
            }

        # Default roster once none of the cases above applies.
        aligners = 2 if team >= 6 else 1
        return {
            "scrambler": 1,
            "aligner": aligners,
            "scout": 1,
            "miner": max(1, team - aligners - 2),
        }

    def _assign_targets(self) -> None:
        """Hand out junction targets to control agents and extractors to miners."""
        # Only junctions not aligned to "cogs" are worth targeting.
        junctions = sorted(pos for pos, alignment in self.junction_map.items() if alignment != "cogs")
        extractors_by_resource: dict[str, list[tuple[int, int]]] = {res: [] for res in RESOURCE_CYCLE}
        for pos, resource in self.extractor_map.items():
            if resource in extractors_by_resource:
                extractors_by_resource[resource].append(pos)
        for positions in extractors_by_resource.values():
            positions.sort()
        all_extractors = sorted(self.extractor_map)

        self.assigned_junctions.clear()
        self.assigned_extractors.clear()
        if not self.desired_vibes:
            return

        junction_index = 0
        extractor_index = 0
        for agent_id, vibe in enumerate(self.desired_vibes):
            if vibe in CONTROL_VIBES and junctions:
                # Round-robin over the junctions so control agents spread out.
                self.assigned_junctions[agent_id] = junctions[junction_index]
                junction_index = (junction_index + 1) % len(junctions)
            elif vibe == "miner" and self.extractor_map:
                preferred = RESOURCE_CYCLE[agent_id % len(RESOURCE_CYCLE)]
                preferred_list = extractors_by_resource.get(preferred, [])
                if preferred_list:
                    self.assigned_extractors[agent_id] = preferred_list[extractor_index % len(preferred_list)]
                elif all_extractors:
                    # No extractor of the preferred resource known: take any.
                    self.assigned_extractors[agent_id] = all_extractors[extractor_index % len(all_extractors)]
                extractor_index += 1
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
class TargetedMultiRoleImpl(CogsguardMultiRoleImpl):
    """Multi-role agent that follows the vibe and junction target set by the shared planner."""

    def __init__(
        self,
        policy_env_info: PolicyEnvInterface,
        agent_id: int,
        initial_target_vibe: Optional[str],
        shared_state: TargetedPlannerState,
    ):
        """Store the planner state shared with the other agents of the policy."""
        super().__init__(policy_env_info, agent_id, initial_target_vibe=initial_target_vibe)
        self._shared_state = shared_state

    def _execute_phase(self, s: CogsguardAgentState) -> Action:
        """Report observations, let agent 0 replan, then switch vibe or defer to the parent."""
        planner = self._shared_state
        planner.update_from_agent(s)
        # Only agent 0 drives the planning clock.
        if s.agent_id == 0:
            planner.maybe_plan(s.step_count)

        plan = planner.desired_vibes
        if plan and plan[s.agent_id] != s.current_vibe:
            return change_vibe_action(plan[s.agent_id], action_names=self._action_names)

        return super()._execute_phase(s)

    def execute_role(self, s: CogsguardAgentState) -> Action:
        """Work the assigned junction when equipped for it; otherwise run the normal role."""
        target = self._shared_state.assigned_junctions.get(s.agent_id)
        if not (target and s.current_vibe in CONTROL_VIBES and s.has_gear() and s.heart >= 1):
            return super().execute_role(s)

        struct = s.get_structure_at(target)
        if not struct or struct.alignment == "cogs":
            # Target unknown to this agent or already aligned to "cogs".
            return super().execute_role(s)

        steps_away = abs(target[0] - s.row) + abs(target[1] - s.col)
        if steps_away > 1:
            return self._move_towards(s, target, reach_adjacent=True)
        return self._use_object_at(s, target)

    def _get_role_impl(self, role: Role) -> CogsguardAgentPolicyImpl:
        """Return the cached implementation for ``role``, creating it on first use."""
        if role in self._role_impls:
            return self._role_impls[role]

        impl_class = {
            Role.MINER: TargetedMinerAgentPolicyImpl,
            Role.SCOUT: ScoutAgentPolicyImpl,
            Role.ALIGNER: TargetedAlignerAgentPolicyImpl,
            Role.SCRAMBLER: TargetedScramblerAgentPolicyImpl,
        }[role]
        # Only the miner implementation takes the shared planner state.
        extra_args = (self._shared_state,) if role == Role.MINER else ()
        self._role_impls[role] = impl_class(self._policy_env_info, self._agent_id, role, *extra_args)
        return self._role_impls[role]
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
class TargetedScramblerAgentPolicyImpl(ScramblerAgentPolicyImpl):
    """Scrambler whose target choice ranks junctions: enemy, then neutral, then the rest."""

    def _find_best_target(self, s: CogsguardAgentState) -> Optional[tuple[int, int]]:
        """Return the nearest junction of the best non-empty tier.

        Junctions aligned to "cogs" and junctions worked within the cooldown
        window are skipped. When nothing qualifies, the parent decides.
        """
        junctions = s.get_structures_by_type(StructureType.CHARGER)
        cooldown = 50 if len(junctions) > 4 else 20

        # Each entry is (manhattan distance, position).
        enemy: list[tuple[int, tuple[int, int]]] = []
        neutral: list[tuple[int, tuple[int, int]]] = []
        remaining: list[tuple[int, tuple[int, int]]] = []

        for junction in junctions:
            pos = junction.position
            last_worked = s.worked_junctions.get(pos, 0)
            on_cooldown = last_worked > 0 and s.step_count - last_worked < cooldown
            if on_cooldown or junction.alignment == "cogs":
                continue

            entry = (abs(pos[0] - s.row) + abs(pos[1] - s.col), pos)
            if junction.alignment == "clips" or junction.clipped:
                enemy.append(entry)
            elif junction.alignment is None or junction.alignment == "neutral":
                neutral.append(entry)
            else:
                remaining.append(entry)

        for tier in (enemy, neutral, remaining):
            if tier:
                return min(tier)[1]

        return super()._find_best_target(s)
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
class TargetedAlignerAgentPolicyImpl(AlignerAgentPolicyImpl):
    """Aligner whose target choice ranks junctions: neutral, then clips, then the rest."""

    def _find_best_target(self, s: CogsguardAgentState) -> Optional[tuple[int, int]]:
        """Return the nearest junction of the best non-empty tier.

        Junctions aligned to "cogs" and junctions worked within the cooldown
        window are skipped. When nothing qualifies, the parent decides.
        """
        junctions = s.get_structures_by_type(StructureType.CHARGER)
        cooldown = 50 if len(junctions) > 4 else 20

        # Each entry is (manhattan distance, position).
        neutral: list[tuple[int, tuple[int, int]]] = []
        clips: list[tuple[int, tuple[int, int]]] = []
        remaining: list[tuple[int, tuple[int, int]]] = []

        for junction in junctions:
            pos = junction.position
            last_worked = s.worked_junctions.get(pos, 0)
            on_cooldown = last_worked > 0 and s.step_count - last_worked < cooldown
            if on_cooldown or junction.alignment == "cogs":
                continue

            entry = (abs(pos[0] - s.row) + abs(pos[1] - s.col), pos)
            # Unlike the scrambler, an unaligned junction counts as neutral
            # even when it is also flagged as clipped.
            if junction.alignment is None or junction.alignment == "neutral":
                neutral.append(entry)
            elif junction.alignment == "clips" or junction.clipped:
                clips.append(entry)
            else:
                remaining.append(entry)

        for tier in (neutral, clips, remaining):
            if tier:
                return min(tier)[1]

        return super()._find_best_target(s)
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
class TargetedMinerAgentPolicyImpl(MinerAgentPolicyImpl):
    """Miner that honours the planner's extractor assignment when it is within safe range."""

    def __init__(
        self,
        policy_env_info: PolicyEnvInterface,
        agent_id: int,
        role: Role,
        shared_state: TargetedPlannerState,
    ):
        """Store the shared planner state and derive this agent's preferred resource."""
        super().__init__(policy_env_info, agent_id, role)
        self._shared_state = shared_state
        self._preferred_resource = RESOURCE_CYCLE[agent_id % len(RESOURCE_CYCLE)]

    def _get_safe_extractor(
        self,
        s: CogsguardAgentState,
        preferred_resource: str | None = None,
    ) -> Optional[StructureInfo]:
        """Choose an extractor whose trip fits the agent's safe travel distance.

        Order of preference: the planner-assigned extractor, then the best
        extractor of the preferred resource (most inventory, then closest to a
        depot, then closest to the agent), then whatever the parent picks.
        """

        def trip(pos, depot):
            """Return (distance to pos, depot distance used for ranking, travel needed)."""
            to_ext = abs(pos[0] - s.row) + abs(pos[1] - s.col)
            if not depot:
                # No aligned depot known: assume a full walk back, rank as far away.
                return to_ext, 100, to_ext * 2
            to_depot = abs(pos[0] - depot[0]) + abs(pos[1] - depot[1])
            return to_ext, to_depot, to_ext + max(0, to_depot - HEALING_AOE_RANGE)

        assigned = self._shared_state.assigned_extractors.get(s.agent_id)
        if assigned:
            current = s.get_structure_at(assigned)
            if current and current.is_usable_extractor():
                max_safe_dist = self._get_max_safe_distance(s)
                _, _, needed = trip(assigned, self._get_nearest_aligned_depot(s))
                if needed <= max_safe_dist:
                    return current

        resource = preferred_resource or self._preferred_resource
        matching = [ext for ext in s.get_usable_extractors() if ext.resource_type == resource]
        if matching:
            nearest_depot = self._get_nearest_aligned_depot(s)
            max_safe_dist = self._get_max_safe_distance(s)
            reachable = []
            for ext in matching:
                to_ext, to_depot, needed = trip(ext.position, nearest_depot)
                if needed <= max_safe_dist:
                    reachable.append((ext, (-ext.inventory_amount, to_depot, to_ext)))
            if reachable:
                return min(reachable, key=lambda item: item[1])[0]

        return super()._get_safe_extractor(s, preferred_resource=preferred_resource)
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
class CogsguardTargetedAgent(CogsguardPolicy):
    """CoGsGuard policy with coordinated role and target assignment."""

    short_names = ["cogsguard_targeted"]

    def __init__(
        self,
        policy_env_info: PolicyEnvInterface,
        device: str = "cpu",
        **vibe_counts: Any,
    ):
        """Resolve the role counts and create the planner state shared by all agents.

        Integer keyword values are taken as explicit per-role counts; when none
        is given the default split for the team size is used.
        """
        team_size = policy_env_info.num_agents
        if any(isinstance(value, int) for value in vibe_counts.values()):
            counts = _normalize_counts(team_size, vibe_counts)
        else:
            counts = _default_role_counts(team_size)
        super().__init__(policy_env_info, device=device, **counts)

        planner = TargetedPlannerState(team_size)
        planner.desired_vibes = _build_role_plan(team_size, counts)
        self._shared_state = planner

    def agent_policy(self, agent_id: int) -> StatefulAgentPolicy[CogsguardAgentState]:
        """Return the policy for ``agent_id``, building and caching it on first request."""
        if agent_id in self._agent_policies:
            return self._agent_policies[agent_id]

        in_range = agent_id < len(self._initial_vibes)
        target_vibe = self._initial_vibes[agent_id] if in_range else None
        impl = TargetedMultiRoleImpl(
            self._policy_env_info,
            agent_id,
            initial_target_vibe=target_vibe,
            shared_state=self._shared_state,
        )
        self._agent_policies[agent_id] = StatefulAgentPolicy(impl, self._policy_env_info, agent_id=agent_id)
        return self._agent_policies[agent_id]
|
|
@@ -0,0 +1,224 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Optional, Sequence
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from cogames_agents.policy.nim_agents.agents import CogsguardAgentsMultiPolicy
|
|
8
|
+
from cogames_agents.policy.scripted_agent.cogsguard.types import Role as CogsguardRole
|
|
9
|
+
from cogames_agents.policy.scripted_agent.common.roles import ROLE_VIBES
|
|
10
|
+
from mettagrid.policy.policy import AgentPolicy, MultiAgentPolicy
|
|
11
|
+
from mettagrid.policy.policy_env_interface import PolicyEnvInterface
|
|
12
|
+
from mettagrid.simulator import Action, AgentObservation
|
|
13
|
+
|
|
14
|
+
# Tuple snapshot of ROLE_VIBES; the candidate vibes used when the caller passes no role_vibes.
DEFAULT_ROLE_VIBES = tuple(ROLE_VIBES)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class CogsguardTeacherPolicy(MultiAgentPolicy):
|
|
18
|
+
"""Teacher wrapper that forces an initial vibe, then delegates to the Nim policy."""
|
|
19
|
+
|
|
20
|
+
short_names = ["teacher"]
|
|
21
|
+
|
|
22
|
+
def __init__(
    self,
    policy_env_info: PolicyEnvInterface,
    device: str = "cpu",
    role_vibes: Optional[Sequence[str | CogsguardRole] | str] = None,
) -> None:
    """Wrap the Nim policy and resolve which change-vibe actions may be forced.

    Args:
        policy_env_info: Environment description (agent count, action names, features).
        device: Passed through to the base class.
        role_vibes: Vibes to choose from, as a sequence or a comma-separated
            string; ``None`` selects the defaults.
    """
    super().__init__(policy_env_info, device=device)
    self._delegate = CogsguardAgentsMultiPolicy(policy_env_info)
    self._num_agents = policy_env_info.num_agents

    names = list(policy_env_info.action_names)
    self._action_names = names
    self._action_name_to_index = dict(zip(names, range(len(names))))
    self._delegate_agents = list(map(self._delegate.agent_policy, range(self._num_agents)))

    # Feature ids are None when the environment does not expose the feature.
    self._episode_feature_id = self._find_feature_id("episode_completion_pct")
    self._last_action_feature_id = self._find_feature_id("last_action")

    self._role_action_ids = self._resolve_role_actions(role_vibes)
    self._reset_episode_state()
|
|
40
|
+
|
|
41
|
+
def agent_policy(self, agent_id: int) -> AgentPolicy:
    """Return a new per-agent wrapper bound to this policy and ``agent_id``."""
    wrapper = _CogsguardTeacherAgentPolicy(self, agent_id)
    return wrapper
|
|
43
|
+
|
|
44
|
+
def reset(self) -> None:
    """Reset the wrapped Nim policy first, then this wrapper's episode state."""
    delegate = self._delegate
    delegate.reset()
    self._reset_episode_state()
|
|
47
|
+
|
|
48
|
+
def step_batch(self, raw_observations: np.ndarray, raw_actions: np.ndarray) -> None:
    """Let the delegate fill ``raw_actions``, then overwrite entries that must be forced.

    Nothing is overwritten when no role actions are configured or when the
    first dimension of the observations differs from the agent count.
    """
    self._delegate.step_batch(raw_observations, raw_actions)
    if not self._role_action_ids or raw_observations.shape[0] != self._num_agents:
        return

    for agent_id, agent_obs in enumerate(raw_observations):
        episode_pct = self._extract_episode_pct_raw(agent_obs)
        last_action = self._extract_last_action_raw(agent_obs)
        forced_action = self._maybe_force_action(agent_id, episode_pct, last_action)
        if forced_action is None:
            continue
        raw_actions[agent_id] = forced_action
|
|
60
|
+
|
|
61
|
+
def _step_single(self, agent_id: int, obs: AgentObservation) -> Action:
|
|
62
|
+
base_action = self._delegate_agents[agent_id].step(obs)
|
|
63
|
+
if not self._role_action_ids:
|
|
64
|
+
return base_action
|
|
65
|
+
episode_pct = self._extract_episode_pct_obs(obs)
|
|
66
|
+
last_action = self._extract_last_action_obs(obs)
|
|
67
|
+
forced_action = self._maybe_force_action(agent_id, episode_pct, last_action)
|
|
68
|
+
if forced_action is None:
|
|
69
|
+
return base_action
|
|
70
|
+
action_name = self._action_names[forced_action]
|
|
71
|
+
return Action(name=action_name)
|
|
72
|
+
|
|
73
|
+
def _extract_episode_pct_raw(self, raw_obs: np.ndarray) -> Optional[int]:
|
|
74
|
+
if self._episode_feature_id is None:
|
|
75
|
+
return None
|
|
76
|
+
for token in raw_obs:
|
|
77
|
+
if token[0] == 255 and token[1] == 255 and token[2] == 255:
|
|
78
|
+
break
|
|
79
|
+
if token[1] == self._episode_feature_id:
|
|
80
|
+
return int(token[2])
|
|
81
|
+
return 0
|
|
82
|
+
|
|
83
|
+
def _extract_episode_pct_obs(self, obs: AgentObservation) -> Optional[int]:
|
|
84
|
+
if self._episode_feature_id is None:
|
|
85
|
+
return None
|
|
86
|
+
for token in obs.tokens:
|
|
87
|
+
if token.feature.name == "episode_completion_pct":
|
|
88
|
+
return token.value
|
|
89
|
+
return 0
|
|
90
|
+
|
|
91
|
+
def _extract_last_action_raw(self, raw_obs: np.ndarray) -> Optional[int]:
|
|
92
|
+
if self._last_action_feature_id is None:
|
|
93
|
+
return None
|
|
94
|
+
for token in raw_obs:
|
|
95
|
+
if token[0] == 255 and token[1] == 255 and token[2] == 255:
|
|
96
|
+
break
|
|
97
|
+
if token[1] == self._last_action_feature_id:
|
|
98
|
+
return int(token[2])
|
|
99
|
+
return 0
|
|
100
|
+
|
|
101
|
+
def _extract_last_action_obs(self, obs: AgentObservation) -> Optional[int]:
    """Read the last-action value from a structured observation.

    Returns ``None`` when ``_last_action_feature_id`` is unset; otherwise the
    ``value`` of the first token whose feature is named ``"last_action"``, or
    ``0`` when no such token exists.
    """
    if self._last_action_feature_id is None:
        return None
    hits = (tok.value for tok in obs.tokens if tok.feature.name == "last_action")
    return next(hits, 0)
|
|
108
|
+
|
|
109
|
+
def _find_feature_id(self, feature_name: str) -> Optional[int]:
    """Return the id of the first observation feature named ``feature_name``.

    Features are taken from ``self.policy_env_info.obs_features``; ``None`` is
    returned when none of them carries that name.
    """
    candidates = (feat.id for feat in self.policy_env_info.obs_features if feat.name == feature_name)
    return next(candidates, None)
|
|
114
|
+
|
|
115
|
+
def _resolve_role_actions(self, role_vibes: Optional[Sequence[str | CogsguardRole] | str]) -> list[int]:
    """Translate requested role vibes into ``change_vibe_*`` action indices.

    Args:
        role_vibes: ``None`` to use the defaults, a comma-separated string, or
            a sequence of vibe names / ``CogsguardRole`` members.

    Returns:
        Action indices in the order of the selected vibes. The list is empty
        when the action set offers at most one ``change_vibe_*`` action.

    Selection rules:
        * ``None``: the entries of ``DEFAULT_ROLE_VIBES`` that are offered;
          failing that, every offered vibe except ``"default"``.
        * otherwise: the requested vibes that are offered.
        * in both cases an empty selection falls back to all offered vibes.
    """
    prefix = "change_vibe_"
    offered = [name[len(prefix):] for name in self._action_names if name.startswith(prefix)]
    if len(offered) <= 1:
        return []

    if role_vibes is None:
        selected = [vibe for vibe in DEFAULT_ROLE_VIBES if vibe in offered]
        selected = selected or [vibe for vibe in offered if vibe != "default"]
    else:
        if isinstance(role_vibes, str):
            requested = [part.strip() for part in role_vibes.split(",") if part.strip()]
        else:
            requested = [item.value if isinstance(item, CogsguardRole) else str(item) for item in role_vibes]
        selected = [vibe for vibe in requested if vibe in offered]
    selected = selected or offered

    # Vibes whose action name has no index entry are silently dropped.
    looked_up = (self._action_name_to_index.get(f"change_vibe_{vibe}") for vibe in selected)
    return [index for index in looked_up if index is not None]
|
|
143
|
+
|
|
144
|
+
def _reset_episode_state(self) -> None:
    """Reinitialise every per-agent episode-tracking list.

    Each list gets one slot per agent (``self._num_agents``): episode index 0,
    step-in-episode 0, forced-vibe flag False, last episode percentage -1
    (the "nothing observed yet" sentinel tested by ``_update_episode_state``)
    and last action value None.
    """
    agent_count = self._num_agents
    self._episode_index = [0] * agent_count
    self._step_in_episode = [0] * agent_count
    self._forced_vibe = [False] * agent_count
    self._last_episode_pct = [-1] * agent_count
    self._last_action_value: list[Optional[int]] = [None] * agent_count
|
|
150
|
+
|
|
151
|
+
def _maybe_force_action(
    self,
    agent_id: int,
    episode_pct: Optional[int],
    last_action: Optional[int],
) -> Optional[int]:
    """Return the role action to force for this agent, if one is due.

    Episode bookkeeping is always updated first via ``_update_episode_state``.
    An action is forced only on step 0 of an episode, and only once per
    episode (tracked by ``_forced_vibe``).

    Args:
        agent_id: Index of the agent being stepped.
        episode_pct: Episode-completion value from the observation, or
            ``None`` when that feature is unavailable.
        last_action: Last-action value from the observation, or ``None`` when
            that feature is unavailable.

    Returns:
        An entry of ``_role_action_ids`` chosen by
        ``(episode_index + agent_id) % len(_role_action_ids)``, or ``None``
        when nothing should be forced.
    """
    self._update_episode_state(agent_id, episode_pct, last_action)
    # With no role actions there is nothing to force. Returning here avoids a
    # ZeroDivisionError in the modulo below and keeps the forced flag unset.
    if not self._role_action_ids:
        return None
    if self._forced_vibe[agent_id] or self._step_in_episode[agent_id] != 0:
        return None
    self._forced_vibe[agent_id] = True
    # Offsetting by agent_id spreads agents across roles; adding the episode
    # index rotates each agent's role from one episode to the next.
    role_index = (self._episode_index[agent_id] + agent_id) % len(self._role_action_ids)
    return self._role_action_ids[role_index]
|
|
163
|
+
|
|
164
|
+
def _update_episode_state(
    self,
    agent_id: int,
    episode_pct: Optional[int],
    last_action: Optional[int],
) -> None:
    """Advance one agent's episode bookkeeping by one observation.

    With ``episode_pct`` available, a new episode starts on the very first
    observation (stored percentage is -1) or whenever the percentage drops
    below the stored one; only a drop increments the episode index.

    With ``episode_pct`` unavailable (``None``), a new episode is inferred
    when the last-action value returns to 0 after a non-zero value, provided
    at least one step has been counted. The stored percentage is then pinned
    to 0 so the -1 "never observed" sentinel is cleared.

    Starting a new episode zeroes the step counter and clears the forced-vibe
    flag. The one exception is a first observation without ``episode_pct``,
    which zeroes the counter but leaves the flag untouched. Any other
    observation increments the step counter. A non-``None`` ``last_action`` is
    always remembered.
    """
    previous_pct = self._last_episode_pct[agent_id]
    never_observed = previous_pct == -1

    if episode_pct is not None:
        # A reset to 0 from a positive percentage is already a strict drop,
        # so the `<` comparison covers it.
        if never_observed or episode_pct < previous_pct:
            if not never_observed:
                self._episode_index[agent_id] += 1
            self._step_in_episode[agent_id] = 0
            self._forced_vibe[agent_id] = False
        else:
            self._step_in_episode[agent_id] += 1
        self._last_episode_pct[agent_id] = episode_pct
    else:
        previous_action = self._last_action_value[agent_id]
        wrapped = (
            last_action is not None
            and previous_action is not None
            and last_action == 0
            and previous_action != 0
            and self._step_in_episode[agent_id] > 0
        )
        if wrapped:
            self._episode_index[agent_id] += 1
            self._forced_vibe[agent_id] = False
        if wrapped or never_observed:
            self._step_in_episode[agent_id] = 0
        else:
            self._step_in_episode[agent_id] += 1
        self._last_episode_pct[agent_id] = 0

    if last_action is not None:
        self._last_action_value[agent_id] = last_action
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
class _CogsguardTeacherAgentPolicy(AgentPolicy):
    """Per-agent view of a ``CogsguardTeacherPolicy``.

    Holds the parent policy and an agent index, and forwards every step to
    the parent's ``_step_single`` for that agent.
    """

    def __init__(self, parent: CogsguardTeacherPolicy, agent_id: int) -> None:
        """Bind this wrapper to ``parent`` and ``agent_id``.

        The base class is initialised with the parent's ``policy_env_info``.
        """
        env_info = parent.policy_env_info
        super().__init__(env_info)
        self._agent_id = agent_id
        self._parent = parent

    def step(self, obs: AgentObservation) -> Action:
        """Return the parent's action for this agent given ``obs``."""
        owner = self._parent
        return owner._step_single(self._agent_id, obs)
|