cogames-agents 0.0.0.7__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cogames_agents/__init__.py +0 -0
- cogames_agents/evals/__init__.py +5 -0
- cogames_agents/evals/planky_evals.py +415 -0
- cogames_agents/policy/__init__.py +0 -0
- cogames_agents/policy/evolution/__init__.py +0 -0
- cogames_agents/policy/evolution/cogsguard/__init__.py +0 -0
- cogames_agents/policy/evolution/cogsguard/evolution.py +695 -0
- cogames_agents/policy/evolution/cogsguard/evolutionary_coordinator.py +540 -0
- cogames_agents/policy/nim_agents/__init__.py +20 -0
- cogames_agents/policy/nim_agents/agents.py +98 -0
- cogames_agents/policy/nim_agents/bindings/generated/libnim_agents.dylib +0 -0
- cogames_agents/policy/nim_agents/bindings/generated/nim_agents.py +215 -0
- cogames_agents/policy/nim_agents/cogsguard_agents.nim +555 -0
- cogames_agents/policy/nim_agents/cogsguard_align_all_agents.nim +569 -0
- cogames_agents/policy/nim_agents/common.nim +1054 -0
- cogames_agents/policy/nim_agents/install.sh +1 -0
- cogames_agents/policy/nim_agents/ladybug_agent.nim +954 -0
- cogames_agents/policy/nim_agents/nim_agents.nim +68 -0
- cogames_agents/policy/nim_agents/nim_agents.nims +14 -0
- cogames_agents/policy/nim_agents/nimby.lock +3 -0
- cogames_agents/policy/nim_agents/racecar_agents.nim +844 -0
- cogames_agents/policy/nim_agents/random_agents.nim +68 -0
- cogames_agents/policy/nim_agents/test_agents.py +53 -0
- cogames_agents/policy/nim_agents/thinky_agents.nim +677 -0
- cogames_agents/policy/nim_agents/thinky_eval.py +230 -0
- cogames_agents/policy/scripted_agent/README.md +360 -0
- cogames_agents/policy/scripted_agent/__init__.py +0 -0
- cogames_agents/policy/scripted_agent/baseline_agent.py +1031 -0
- cogames_agents/policy/scripted_agent/cogas/__init__.py +5 -0
- cogames_agents/policy/scripted_agent/cogas/context.py +68 -0
- cogames_agents/policy/scripted_agent/cogas/entity_map.py +152 -0
- cogames_agents/policy/scripted_agent/cogas/goal.py +115 -0
- cogames_agents/policy/scripted_agent/cogas/goals/__init__.py +27 -0
- cogames_agents/policy/scripted_agent/cogas/goals/aligner.py +160 -0
- cogames_agents/policy/scripted_agent/cogas/goals/gear.py +197 -0
- cogames_agents/policy/scripted_agent/cogas/goals/miner.py +441 -0
- cogames_agents/policy/scripted_agent/cogas/goals/scout.py +40 -0
- cogames_agents/policy/scripted_agent/cogas/goals/scrambler.py +174 -0
- cogames_agents/policy/scripted_agent/cogas/goals/shared.py +160 -0
- cogames_agents/policy/scripted_agent/cogas/goals/stem.py +60 -0
- cogames_agents/policy/scripted_agent/cogas/goals/survive.py +100 -0
- cogames_agents/policy/scripted_agent/cogas/navigator.py +401 -0
- cogames_agents/policy/scripted_agent/cogas/obs_parser.py +238 -0
- cogames_agents/policy/scripted_agent/cogas/policy.py +525 -0
- cogames_agents/policy/scripted_agent/cogas/trace.py +69 -0
- cogames_agents/policy/scripted_agent/cogsguard/CLAUDE.md +517 -0
- cogames_agents/policy/scripted_agent/cogsguard/README.md +252 -0
- cogames_agents/policy/scripted_agent/cogsguard/__init__.py +74 -0
- cogames_agents/policy/scripted_agent/cogsguard/aligned_junction_held_investigation.md +152 -0
- cogames_agents/policy/scripted_agent/cogsguard/aligner.py +333 -0
- cogames_agents/policy/scripted_agent/cogsguard/behavior_hooks.py +44 -0
- cogames_agents/policy/scripted_agent/cogsguard/control_agent.py +323 -0
- cogames_agents/policy/scripted_agent/cogsguard/debug_agent.py +533 -0
- cogames_agents/policy/scripted_agent/cogsguard/miner.py +589 -0
- cogames_agents/policy/scripted_agent/cogsguard/options.py +67 -0
- cogames_agents/policy/scripted_agent/cogsguard/parity_metrics.py +36 -0
- cogames_agents/policy/scripted_agent/cogsguard/policy.py +1967 -0
- cogames_agents/policy/scripted_agent/cogsguard/prereq_trace.py +33 -0
- cogames_agents/policy/scripted_agent/cogsguard/role_trace.py +50 -0
- cogames_agents/policy/scripted_agent/cogsguard/roles.py +31 -0
- cogames_agents/policy/scripted_agent/cogsguard/rollout_trace.py +40 -0
- cogames_agents/policy/scripted_agent/cogsguard/scout.py +69 -0
- cogames_agents/policy/scripted_agent/cogsguard/scrambler.py +350 -0
- cogames_agents/policy/scripted_agent/cogsguard/targeted_agent.py +418 -0
- cogames_agents/policy/scripted_agent/cogsguard/teacher.py +224 -0
- cogames_agents/policy/scripted_agent/cogsguard/types.py +381 -0
- cogames_agents/policy/scripted_agent/cogsguard/v2_agent.py +49 -0
- cogames_agents/policy/scripted_agent/common/__init__.py +0 -0
- cogames_agents/policy/scripted_agent/common/geometry.py +24 -0
- cogames_agents/policy/scripted_agent/common/roles.py +34 -0
- cogames_agents/policy/scripted_agent/common/tag_utils.py +48 -0
- cogames_agents/policy/scripted_agent/demo_policy.py +242 -0
- cogames_agents/policy/scripted_agent/pathfinding.py +126 -0
- cogames_agents/policy/scripted_agent/pinky/DESIGN.md +317 -0
- cogames_agents/policy/scripted_agent/pinky/__init__.py +5 -0
- cogames_agents/policy/scripted_agent/pinky/behaviors/__init__.py +17 -0
- cogames_agents/policy/scripted_agent/pinky/behaviors/aligner.py +400 -0
- cogames_agents/policy/scripted_agent/pinky/behaviors/base.py +119 -0
- cogames_agents/policy/scripted_agent/pinky/behaviors/miner.py +632 -0
- cogames_agents/policy/scripted_agent/pinky/behaviors/scout.py +138 -0
- cogames_agents/policy/scripted_agent/pinky/behaviors/scrambler.py +433 -0
- cogames_agents/policy/scripted_agent/pinky/policy.py +570 -0
- cogames_agents/policy/scripted_agent/pinky/services/__init__.py +7 -0
- cogames_agents/policy/scripted_agent/pinky/services/map_tracker.py +808 -0
- cogames_agents/policy/scripted_agent/pinky/services/navigator.py +864 -0
- cogames_agents/policy/scripted_agent/pinky/services/safety.py +189 -0
- cogames_agents/policy/scripted_agent/pinky/state.py +299 -0
- cogames_agents/policy/scripted_agent/pinky/types.py +138 -0
- cogames_agents/policy/scripted_agent/planky/CLAUDE.md +124 -0
- cogames_agents/policy/scripted_agent/planky/IMPROVEMENTS.md +160 -0
- cogames_agents/policy/scripted_agent/planky/NOTES.md +153 -0
- cogames_agents/policy/scripted_agent/planky/PLAN.md +254 -0
- cogames_agents/policy/scripted_agent/planky/README.md +214 -0
- cogames_agents/policy/scripted_agent/planky/STRATEGY.md +100 -0
- cogames_agents/policy/scripted_agent/planky/__init__.py +5 -0
- cogames_agents/policy/scripted_agent/planky/context.py +68 -0
- cogames_agents/policy/scripted_agent/planky/entity_map.py +152 -0
- cogames_agents/policy/scripted_agent/planky/goal.py +107 -0
- cogames_agents/policy/scripted_agent/planky/goals/__init__.py +27 -0
- cogames_agents/policy/scripted_agent/planky/goals/aligner.py +168 -0
- cogames_agents/policy/scripted_agent/planky/goals/gear.py +179 -0
- cogames_agents/policy/scripted_agent/planky/goals/miner.py +416 -0
- cogames_agents/policy/scripted_agent/planky/goals/scout.py +40 -0
- cogames_agents/policy/scripted_agent/planky/goals/scrambler.py +174 -0
- cogames_agents/policy/scripted_agent/planky/goals/shared.py +160 -0
- cogames_agents/policy/scripted_agent/planky/goals/stem.py +49 -0
- cogames_agents/policy/scripted_agent/planky/goals/survive.py +96 -0
- cogames_agents/policy/scripted_agent/planky/navigator.py +388 -0
- cogames_agents/policy/scripted_agent/planky/obs_parser.py +238 -0
- cogames_agents/policy/scripted_agent/planky/policy.py +485 -0
- cogames_agents/policy/scripted_agent/planky/tests/__init__.py +0 -0
- cogames_agents/policy/scripted_agent/planky/tests/conftest.py +66 -0
- cogames_agents/policy/scripted_agent/planky/tests/helpers.py +152 -0
- cogames_agents/policy/scripted_agent/planky/tests/test_aligner.py +24 -0
- cogames_agents/policy/scripted_agent/planky/tests/test_miner.py +30 -0
- cogames_agents/policy/scripted_agent/planky/tests/test_scout.py +15 -0
- cogames_agents/policy/scripted_agent/planky/tests/test_scrambler.py +29 -0
- cogames_agents/policy/scripted_agent/planky/tests/test_stem.py +36 -0
- cogames_agents/policy/scripted_agent/planky/trace.py +69 -0
- cogames_agents/policy/scripted_agent/types.py +239 -0
- cogames_agents/policy/scripted_agent/unclipping_agent.py +461 -0
- cogames_agents/policy/scripted_agent/utils.py +381 -0
- cogames_agents/policy/scripted_registry.py +80 -0
- cogames_agents/py.typed +0 -0
- cogames_agents-0.0.0.7.dist-info/METADATA +98 -0
- cogames_agents-0.0.0.7.dist-info/RECORD +128 -0
- cogames_agents-0.0.0.7.dist-info/WHEEL +6 -0
- cogames_agents-0.0.0.7.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,323 @@
|
|
|
1
|
+
"""CoGsGuard scripted policy with a phased leader coordinator."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Any, Optional
|
|
7
|
+
|
|
8
|
+
from cogames_agents.policy.scripted_agent.utils import change_vibe_action
|
|
9
|
+
from mettagrid.policy.policy import StatefulAgentPolicy
|
|
10
|
+
from mettagrid.policy.policy_env_interface import PolicyEnvInterface
|
|
11
|
+
from mettagrid.simulator import Action
|
|
12
|
+
|
|
13
|
+
from .aligner import AlignerAgentPolicyImpl
|
|
14
|
+
from .miner import MinerAgentPolicyImpl
|
|
15
|
+
from .policy import DEBUG, CogsguardAgentPolicyImpl, CogsguardMultiRoleImpl, CogsguardPolicy
|
|
16
|
+
from .scout import ScoutAgentPolicyImpl
|
|
17
|
+
from .scrambler import ScramblerAgentPolicyImpl
|
|
18
|
+
from .types import CogsguardAgentState, Role, StructureType
|
|
19
|
+
|
|
20
|
+
# Minimum number of steps that must pass between two re-plans of the shared planner.
PLAN_INTERVAL_STEPS: int = 40
# While the step count is below this value the planner uses its exploration role mix.
PHASE_EXPLORE_END: int = 60
# While the step count is below this value the planner may still use its control role mix.
PHASE_CONTROL_END: int = 220
# Exclusive upper bound of the (0, N) chest-resource range the planner checks for.
CHEST_LOW_THRESHOLD: int = 60
# Vibes whose agents are handed junction targets by the planner.
CONTROL_VIBES: set[str] = {"scrambler", "aligner"}
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _default_role_counts(num_agents: int) -> dict[str, int]:
|
|
28
|
+
if num_agents <= 1:
|
|
29
|
+
return {"miner": 1}
|
|
30
|
+
if num_agents == 2:
|
|
31
|
+
return {"scrambler": 1, "miner": 1}
|
|
32
|
+
if num_agents == 3:
|
|
33
|
+
return {"scrambler": 1, "miner": 1, "scout": 1}
|
|
34
|
+
if num_agents <= 7:
|
|
35
|
+
scramblers = 1
|
|
36
|
+
aligners = 1
|
|
37
|
+
scouts = 1
|
|
38
|
+
else:
|
|
39
|
+
scramblers = max(2, num_agents // 6)
|
|
40
|
+
aligners = max(2, num_agents // 6)
|
|
41
|
+
scouts = 1
|
|
42
|
+
miners = max(1, num_agents - scramblers - scouts - aligners)
|
|
43
|
+
return {
|
|
44
|
+
"scrambler": scramblers,
|
|
45
|
+
"aligner": aligners,
|
|
46
|
+
"miner": miners,
|
|
47
|
+
"scout": scouts,
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _normalize_counts(num_agents: int, counts: dict[str, int]) -> dict[str, int]:
|
|
52
|
+
normalized = {k: v for k, v in counts.items() if isinstance(v, int)}
|
|
53
|
+
total = sum(normalized.values())
|
|
54
|
+
if total < num_agents:
|
|
55
|
+
normalized["miner"] = normalized.get("miner", 0) + (num_agents - total)
|
|
56
|
+
elif total > num_agents:
|
|
57
|
+
overflow = total - num_agents
|
|
58
|
+
miners = normalized.get("miner", 0)
|
|
59
|
+
normalized["miner"] = max(0, miners - overflow)
|
|
60
|
+
return normalized
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _build_role_plan(num_agents: int, counts: dict[str, int]) -> list[str]:
|
|
64
|
+
ordered: list[str] = []
|
|
65
|
+
for role_name in ["scrambler", "aligner", "miner", "scout"]:
|
|
66
|
+
ordered.extend([role_name] * counts.get(role_name, 0))
|
|
67
|
+
if len(ordered) < num_agents:
|
|
68
|
+
ordered.extend(["miner"] * (num_agents - len(ordered)))
|
|
69
|
+
return ordered[:num_agents]
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@dataclass
class CommanderPlannerState:
    """Planner state shared by all agents of one commander policy.

    Agents report what they know through ``update_from_agent``.  ``maybe_plan``
    periodically rebuilds the per-agent vibe plan and the junction targets for
    agents whose planned vibe is in ``CONTROL_VIBES``.
    """

    num_agents: int
    # Planned vibe per agent, indexed by agent id.
    desired_vibes: list[str] = field(default_factory=list)
    # Step count at which the last plan was produced.
    last_plan_step: int = 0
    # Largest number of junctions any single agent has reported so far.
    known_junctions: int = 0
    # Largest number of "cogs"-aligned junctions any single agent has reported so far.
    aligned_junctions: int = 0
    # Largest positive chest inventory amount reported so far.
    chest_resources: int = 0
    # Junction position -> most recently reported alignment.
    junction_map: dict[tuple[int, int], Optional[str]] = field(default_factory=dict)
    # Agent id -> junction position assigned by the last plan.
    assigned_targets: dict[int, tuple[int, int]] = field(default_factory=dict)

    def update_from_agent(self, s: CogsguardAgentState) -> None:
        """Merge one agent's view of junctions and chests into the shared state."""
        junctions = s.get_structures_by_type(StructureType.CHARGER)
        aligned_count = sum(1 for junction in junctions if junction.alignment == "cogs")

        # The counters only ever grow; the map keeps the latest alignment per position.
        self.known_junctions = max(self.known_junctions, len(junctions))
        self.aligned_junctions = max(self.aligned_junctions, aligned_count)
        self.junction_map.update((junction.position, junction.alignment) for junction in junctions)

        chests = s.get_structures_by_type(StructureType.CHEST)
        fullest = max((chest.inventory_amount for chest in chests), default=0)
        if fullest > 0:
            self.chest_resources = max(self.chest_resources, fullest)

    def maybe_plan(self, step_count: int) -> None:
        """Rebuild roles and targets once ``PLAN_INTERVAL_STEPS`` steps have passed."""
        elapsed = step_count - self.last_plan_step
        if elapsed < PLAN_INTERVAL_STEPS:
            return

        self.last_plan_step = step_count
        counts = self._choose_counts(step_count)
        self.desired_vibes = _build_role_plan(self.num_agents, counts)
        self._assign_targets()

        if DEBUG:
            print(
                f"[COMMANDER] plan@{step_count}: junctions={self.known_junctions} "
                f"aligned={self.aligned_junctions} chest={self.chest_resources} "
                f"roles={counts}"
            )

    def _choose_counts(self, step_count: int) -> dict[str, int]:
        """Pick the role mix for the phase implied by the step count and shared state."""
        team_size = self.num_agents

        # Exploration: early in the run, or as long as no junction has been seen.
        still_exploring = step_count < PHASE_EXPLORE_END or self.known_junctions == 0
        if still_exploring:
            scouts = 2 if team_size >= 5 else 1
            return {
                "scrambler": 1,
                "aligner": 0,
                "scout": scouts,
                "miner": max(1, team_size - (1 + scouts)),
            }

        # Control: before the phase ends and while fewer than a third (at least
        # one) of the known junctions have been reported as aligned.
        aligned_goal = max(1, self.known_junctions // 3)
        if step_count < PHASE_CONTROL_END and self.aligned_junctions < aligned_goal:
            per_control_role = 2 if team_size >= 6 else 1
            return {
                "scrambler": per_control_role,
                "aligner": per_control_role,
                "scout": 1,
                "miner": max(1, team_size - (per_control_role + per_control_role + 1)),
            }

        miners = max(1, team_size - 3)
        if 0 < self.chest_resources < CHEST_LOW_THRESHOLD:
            # NOTE(review): the low-chest case currently yields the same mix as
            # the fallback below.
            return {
                "scrambler": 1,
                "aligner": 1,
                "scout": 1,
                "miner": miners,
            }

        return {
            "scrambler": 1,
            "aligner": 1,
            "scout": 1,
            "miner": miners,
        }

    def _assign_targets(self) -> None:
        """Hand non-"cogs" junctions out round-robin to agents with a control vibe."""
        targets = sorted(pos for pos, alignment in self.junction_map.items() if alignment != "cogs")

        self.assigned_targets.clear()
        if not targets or not self.desired_vibes:
            return

        control_agents = [
            agent_id for agent_id, vibe in enumerate(self.desired_vibes) if vibe in CONTROL_VIBES
        ]
        # Wrap around when there are more control agents than targets.
        for slot, agent_id in enumerate(control_agents):
            self.assigned_targets[agent_id] = targets[slot % len(targets)]
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
class CogsguardCommanderMultiRoleImpl(CogsguardMultiRoleImpl):
    """Multi-role agent implementation that follows a shared commander plan."""

    def __init__(
        self,
        policy_env_info: PolicyEnvInterface,
        agent_id: int,
        initial_target_vibe: Optional[str],
        shared_state: CommanderPlannerState,
    ):
        """Set up the base implementation and keep a handle on the shared planner state."""
        super().__init__(policy_env_info, agent_id, initial_target_vibe=initial_target_vibe)
        self._shared_state = shared_state

    def _execute_phase(self, s: CogsguardAgentState) -> Action:
        """Report to the planner, switch to the planned vibe if needed, else defer to the base."""
        planner = self._shared_state
        planner.update_from_agent(s)

        # Only agent 0 triggers re-planning.
        if s.agent_id == 0:
            planner.maybe_plan(s.step_count)

        plan = planner.desired_vibes
        if plan:
            wanted_vibe = plan[s.agent_id]
            if wanted_vibe != s.current_vibe:
                return change_vibe_action(wanted_vibe, action_names=self._action_names)

        return super()._execute_phase(s)

    def execute_role(self, s: CogsguardAgentState) -> Action:
        """Work the planner-assigned junction when possible, else run the normal role."""
        target = self._shared_state.assigned_targets.get(s.agent_id)
        can_work_target = target and s.current_vibe in CONTROL_VIBES and s.has_gear() and s.heart >= 1
        if not can_work_target:
            return super().execute_role(s)

        struct = s.get_structure_at(target)
        if not struct or struct.alignment == "cogs":
            return super().execute_role(s)

        # Manhattan distance: walk until adjacent, then use the structure.
        distance = abs(target[0] - s.row) + abs(target[1] - s.col)
        if distance > 1:
            return self._move_towards(s, target, reach_adjacent=True)
        return self._use_object_at(s, target)

    def _get_role_impl(self, role: Role) -> CogsguardAgentPolicyImpl:
        """Return the cached implementation for ``role``, creating it on first use."""
        if role in self._role_impls:
            return self._role_impls[role]

        # Aligner and scrambler use the commander variants defined in this module.
        impl_classes = {
            Role.MINER: MinerAgentPolicyImpl,
            Role.SCOUT: ScoutAgentPolicyImpl,
            Role.ALIGNER: CommanderAlignerAgentPolicyImpl,
            Role.SCRAMBLER: CommanderScramblerAgentPolicyImpl,
        }
        impl_class = impl_classes[role]
        self._role_impls[role] = impl_class(self._policy_env_info, self._agent_id, role)
        return self._role_impls[role]
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
class CommanderScramblerAgentPolicyImpl(ScramblerAgentPolicyImpl):
    """Scrambler role whose target choice puts clips/clipped junctions first."""

    def _find_best_target(self, s: CogsguardAgentState) -> Optional[tuple[int, int]]:
        """Return the position of the nearest eligible junction, by priority.

        Junctions aligned to "cogs" and junctions this agent worked within the
        cooldown window are skipped.  The rest are ranked: clips-aligned or
        clipped first, then unaligned/"neutral", then anything else.  Within a
        rank the smallest Manhattan distance wins.  With no candidates the
        base implementation decides.
        """
        junctions = s.get_structures_by_type(StructureType.CHARGER)
        # Shorter cooldown when only a few junctions are known.
        cooldown = 50 if len(junctions) > 4 else 20

        enemy: list[tuple[int, tuple[int, int]]] = []
        neutral: list[tuple[int, tuple[int, int]]] = []
        other: list[tuple[int, tuple[int, int]]] = []

        for junction in junctions:
            pos = junction.position

            last_worked = s.worked_junctions.get(pos, 0)
            recently_worked = last_worked > 0 and s.step_count - last_worked < cooldown
            if recently_worked or junction.alignment == "cogs":
                continue

            if junction.alignment == "clips" or junction.clipped:
                bucket = enemy
            elif junction.alignment is None or junction.alignment == "neutral":
                bucket = neutral
            else:
                bucket = other
            bucket.append((abs(pos[0] - s.row) + abs(pos[1] - s.col), pos))

        # First non-empty bucket wins; (distance, position) tuples order by distance first.
        for bucket in (enemy, neutral, other):
            if bucket:
                return min(bucket)[1]

        return super()._find_best_target(s)
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
class CommanderAlignerAgentPolicyImpl(AlignerAgentPolicyImpl):
    """Aligner role whose target choice puts unaligned junctions first."""

    def _find_best_target(self, s: CogsguardAgentState) -> Optional[tuple[int, int]]:
        """Return the position of the nearest eligible junction, by priority.

        Junctions aligned to "cogs" and junctions this agent worked within the
        cooldown window are skipped.  The rest are ranked: unaligned/"neutral"
        first, then clips-aligned or clipped, then anything else.  Within a
        rank the smallest Manhattan distance wins.  With no candidates the
        base implementation decides.
        """
        junctions = s.get_structures_by_type(StructureType.CHARGER)
        # Shorter cooldown when only a few junctions are known.
        cooldown = 50 if len(junctions) > 4 else 20

        neutral: list[tuple[int, tuple[int, int]]] = []
        clips: list[tuple[int, tuple[int, int]]] = []
        other: list[tuple[int, tuple[int, int]]] = []

        for junction in junctions:
            pos = junction.position

            last_worked = s.worked_junctions.get(pos, 0)
            recently_worked = last_worked > 0 and s.step_count - last_worked < cooldown
            if recently_worked or junction.alignment == "cogs":
                continue

            # The neutral test comes first, so an unaligned junction lands here
            # even when it is also flagged as clipped.
            if junction.alignment is None or junction.alignment == "neutral":
                bucket = neutral
            elif junction.alignment == "clips" or junction.clipped:
                bucket = clips
            else:
                bucket = other
            bucket.append((abs(pos[0] - s.row) + abs(pos[1] - s.col), pos))

        # First non-empty bucket wins; (distance, position) tuples order by distance first.
        for bucket in (neutral, clips, other):
            if bucket:
                return min(bucket)[1]

        return super()._find_best_target(s)
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
class CogsguardControlAgent(CogsguardPolicy):
    """CoGsGuard policy with a phased coordinator that overrides roles."""

    short_names = ["cogsguard_control"]

    def __init__(
        self,
        policy_env_info: PolicyEnvInterface,
        device: str = "cpu",
        **vibe_counts: Any,
    ):
        """Resolve the starting role counts and create the shared planner state.

        When any keyword value is an ``int`` the keyword counts are normalized
        to the team size; otherwise the default split for the team size is used.
        """
        team_size = policy_env_info.num_agents
        if any(isinstance(value, int) for value in vibe_counts.values()):
            counts = _normalize_counts(team_size, vibe_counts)
        else:
            counts = _default_role_counts(team_size)

        super().__init__(policy_env_info, device=device, **counts)

        # Seed the plan so agents have a desired vibe before the first re-plan.
        self._shared_state = CommanderPlannerState(
            team_size,
            desired_vibes=_build_role_plan(team_size, counts),
        )

    def agent_policy(self, agent_id: int) -> StatefulAgentPolicy[CogsguardAgentState]:
        """Return the policy for ``agent_id``, building and caching it on first request."""
        if agent_id in self._agent_policies:
            return self._agent_policies[agent_id]

        has_initial_vibe = agent_id < len(self._initial_vibes)
        target_vibe = self._initial_vibes[agent_id] if has_initial_vibe else None

        # Every agent's implementation shares the same planner state object.
        impl = CogsguardCommanderMultiRoleImpl(
            self._policy_env_info,
            agent_id,
            initial_target_vibe=target_vibe,
            shared_state=self._shared_state,
        )
        self._agent_policies[agent_id] = StatefulAgentPolicy(impl, self._policy_env_info, agent_id=agent_id)
        return self._agent_policies[agent_id]
|