cogames-agents 0.0.0.7__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cogames_agents/__init__.py +0 -0
- cogames_agents/evals/__init__.py +5 -0
- cogames_agents/evals/planky_evals.py +415 -0
- cogames_agents/policy/__init__.py +0 -0
- cogames_agents/policy/evolution/__init__.py +0 -0
- cogames_agents/policy/evolution/cogsguard/__init__.py +0 -0
- cogames_agents/policy/evolution/cogsguard/evolution.py +695 -0
- cogames_agents/policy/evolution/cogsguard/evolutionary_coordinator.py +540 -0
- cogames_agents/policy/nim_agents/__init__.py +20 -0
- cogames_agents/policy/nim_agents/agents.py +98 -0
- cogames_agents/policy/nim_agents/bindings/generated/libnim_agents.dylib +0 -0
- cogames_agents/policy/nim_agents/bindings/generated/nim_agents.py +215 -0
- cogames_agents/policy/nim_agents/cogsguard_agents.nim +555 -0
- cogames_agents/policy/nim_agents/cogsguard_align_all_agents.nim +569 -0
- cogames_agents/policy/nim_agents/common.nim +1054 -0
- cogames_agents/policy/nim_agents/install.sh +1 -0
- cogames_agents/policy/nim_agents/ladybug_agent.nim +954 -0
- cogames_agents/policy/nim_agents/nim_agents.nim +68 -0
- cogames_agents/policy/nim_agents/nim_agents.nims +14 -0
- cogames_agents/policy/nim_agents/nimby.lock +3 -0
- cogames_agents/policy/nim_agents/racecar_agents.nim +844 -0
- cogames_agents/policy/nim_agents/random_agents.nim +68 -0
- cogames_agents/policy/nim_agents/test_agents.py +53 -0
- cogames_agents/policy/nim_agents/thinky_agents.nim +677 -0
- cogames_agents/policy/nim_agents/thinky_eval.py +230 -0
- cogames_agents/policy/scripted_agent/README.md +360 -0
- cogames_agents/policy/scripted_agent/__init__.py +0 -0
- cogames_agents/policy/scripted_agent/baseline_agent.py +1031 -0
- cogames_agents/policy/scripted_agent/cogas/__init__.py +5 -0
- cogames_agents/policy/scripted_agent/cogas/context.py +68 -0
- cogames_agents/policy/scripted_agent/cogas/entity_map.py +152 -0
- cogames_agents/policy/scripted_agent/cogas/goal.py +115 -0
- cogames_agents/policy/scripted_agent/cogas/goals/__init__.py +27 -0
- cogames_agents/policy/scripted_agent/cogas/goals/aligner.py +160 -0
- cogames_agents/policy/scripted_agent/cogas/goals/gear.py +197 -0
- cogames_agents/policy/scripted_agent/cogas/goals/miner.py +441 -0
- cogames_agents/policy/scripted_agent/cogas/goals/scout.py +40 -0
- cogames_agents/policy/scripted_agent/cogas/goals/scrambler.py +174 -0
- cogames_agents/policy/scripted_agent/cogas/goals/shared.py +160 -0
- cogames_agents/policy/scripted_agent/cogas/goals/stem.py +60 -0
- cogames_agents/policy/scripted_agent/cogas/goals/survive.py +100 -0
- cogames_agents/policy/scripted_agent/cogas/navigator.py +401 -0
- cogames_agents/policy/scripted_agent/cogas/obs_parser.py +238 -0
- cogames_agents/policy/scripted_agent/cogas/policy.py +525 -0
- cogames_agents/policy/scripted_agent/cogas/trace.py +69 -0
- cogames_agents/policy/scripted_agent/cogsguard/CLAUDE.md +517 -0
- cogames_agents/policy/scripted_agent/cogsguard/README.md +252 -0
- cogames_agents/policy/scripted_agent/cogsguard/__init__.py +74 -0
- cogames_agents/policy/scripted_agent/cogsguard/aligned_junction_held_investigation.md +152 -0
- cogames_agents/policy/scripted_agent/cogsguard/aligner.py +333 -0
- cogames_agents/policy/scripted_agent/cogsguard/behavior_hooks.py +44 -0
- cogames_agents/policy/scripted_agent/cogsguard/control_agent.py +323 -0
- cogames_agents/policy/scripted_agent/cogsguard/debug_agent.py +533 -0
- cogames_agents/policy/scripted_agent/cogsguard/miner.py +589 -0
- cogames_agents/policy/scripted_agent/cogsguard/options.py +67 -0
- cogames_agents/policy/scripted_agent/cogsguard/parity_metrics.py +36 -0
- cogames_agents/policy/scripted_agent/cogsguard/policy.py +1967 -0
- cogames_agents/policy/scripted_agent/cogsguard/prereq_trace.py +33 -0
- cogames_agents/policy/scripted_agent/cogsguard/role_trace.py +50 -0
- cogames_agents/policy/scripted_agent/cogsguard/roles.py +31 -0
- cogames_agents/policy/scripted_agent/cogsguard/rollout_trace.py +40 -0
- cogames_agents/policy/scripted_agent/cogsguard/scout.py +69 -0
- cogames_agents/policy/scripted_agent/cogsguard/scrambler.py +350 -0
- cogames_agents/policy/scripted_agent/cogsguard/targeted_agent.py +418 -0
- cogames_agents/policy/scripted_agent/cogsguard/teacher.py +224 -0
- cogames_agents/policy/scripted_agent/cogsguard/types.py +381 -0
- cogames_agents/policy/scripted_agent/cogsguard/v2_agent.py +49 -0
- cogames_agents/policy/scripted_agent/common/__init__.py +0 -0
- cogames_agents/policy/scripted_agent/common/geometry.py +24 -0
- cogames_agents/policy/scripted_agent/common/roles.py +34 -0
- cogames_agents/policy/scripted_agent/common/tag_utils.py +48 -0
- cogames_agents/policy/scripted_agent/demo_policy.py +242 -0
- cogames_agents/policy/scripted_agent/pathfinding.py +126 -0
- cogames_agents/policy/scripted_agent/pinky/DESIGN.md +317 -0
- cogames_agents/policy/scripted_agent/pinky/__init__.py +5 -0
- cogames_agents/policy/scripted_agent/pinky/behaviors/__init__.py +17 -0
- cogames_agents/policy/scripted_agent/pinky/behaviors/aligner.py +400 -0
- cogames_agents/policy/scripted_agent/pinky/behaviors/base.py +119 -0
- cogames_agents/policy/scripted_agent/pinky/behaviors/miner.py +632 -0
- cogames_agents/policy/scripted_agent/pinky/behaviors/scout.py +138 -0
- cogames_agents/policy/scripted_agent/pinky/behaviors/scrambler.py +433 -0
- cogames_agents/policy/scripted_agent/pinky/policy.py +570 -0
- cogames_agents/policy/scripted_agent/pinky/services/__init__.py +7 -0
- cogames_agents/policy/scripted_agent/pinky/services/map_tracker.py +808 -0
- cogames_agents/policy/scripted_agent/pinky/services/navigator.py +864 -0
- cogames_agents/policy/scripted_agent/pinky/services/safety.py +189 -0
- cogames_agents/policy/scripted_agent/pinky/state.py +299 -0
- cogames_agents/policy/scripted_agent/pinky/types.py +138 -0
- cogames_agents/policy/scripted_agent/planky/CLAUDE.md +124 -0
- cogames_agents/policy/scripted_agent/planky/IMPROVEMENTS.md +160 -0
- cogames_agents/policy/scripted_agent/planky/NOTES.md +153 -0
- cogames_agents/policy/scripted_agent/planky/PLAN.md +254 -0
- cogames_agents/policy/scripted_agent/planky/README.md +214 -0
- cogames_agents/policy/scripted_agent/planky/STRATEGY.md +100 -0
- cogames_agents/policy/scripted_agent/planky/__init__.py +5 -0
- cogames_agents/policy/scripted_agent/planky/context.py +68 -0
- cogames_agents/policy/scripted_agent/planky/entity_map.py +152 -0
- cogames_agents/policy/scripted_agent/planky/goal.py +107 -0
- cogames_agents/policy/scripted_agent/planky/goals/__init__.py +27 -0
- cogames_agents/policy/scripted_agent/planky/goals/aligner.py +168 -0
- cogames_agents/policy/scripted_agent/planky/goals/gear.py +179 -0
- cogames_agents/policy/scripted_agent/planky/goals/miner.py +416 -0
- cogames_agents/policy/scripted_agent/planky/goals/scout.py +40 -0
- cogames_agents/policy/scripted_agent/planky/goals/scrambler.py +174 -0
- cogames_agents/policy/scripted_agent/planky/goals/shared.py +160 -0
- cogames_agents/policy/scripted_agent/planky/goals/stem.py +49 -0
- cogames_agents/policy/scripted_agent/planky/goals/survive.py +96 -0
- cogames_agents/policy/scripted_agent/planky/navigator.py +388 -0
- cogames_agents/policy/scripted_agent/planky/obs_parser.py +238 -0
- cogames_agents/policy/scripted_agent/planky/policy.py +485 -0
- cogames_agents/policy/scripted_agent/planky/tests/__init__.py +0 -0
- cogames_agents/policy/scripted_agent/planky/tests/conftest.py +66 -0
- cogames_agents/policy/scripted_agent/planky/tests/helpers.py +152 -0
- cogames_agents/policy/scripted_agent/planky/tests/test_aligner.py +24 -0
- cogames_agents/policy/scripted_agent/planky/tests/test_miner.py +30 -0
- cogames_agents/policy/scripted_agent/planky/tests/test_scout.py +15 -0
- cogames_agents/policy/scripted_agent/planky/tests/test_scrambler.py +29 -0
- cogames_agents/policy/scripted_agent/planky/tests/test_stem.py +36 -0
- cogames_agents/policy/scripted_agent/planky/trace.py +69 -0
- cogames_agents/policy/scripted_agent/types.py +239 -0
- cogames_agents/policy/scripted_agent/unclipping_agent.py +461 -0
- cogames_agents/policy/scripted_agent/utils.py +381 -0
- cogames_agents/policy/scripted_registry.py +80 -0
- cogames_agents/py.typed +0 -0
- cogames_agents-0.0.0.7.dist-info/METADATA +98 -0
- cogames_agents-0.0.0.7.dist-info/RECORD +128 -0
- cogames_agents-0.0.0.7.dist-info/WHEEL +6 -0
- cogames_agents-0.0.0.7.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1967 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CoGsGuard Scripted Agent - Vibe-based multi-agent policy.
|
|
3
|
+
|
|
4
|
+
Agents use vibes to determine their behavior:
|
|
5
|
+
- default: do nothing (noop)
|
|
6
|
+
- gear: pick a role via smart coordinator, change vibe to that role
|
|
7
|
+
- miner/scout/aligner/scrambler: get gear if needed, then execute role behavior
|
|
8
|
+
- heart: do nothing (noop)
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import random
|
|
14
|
+
from collections import deque
|
|
15
|
+
from dataclasses import dataclass, field
|
|
16
|
+
from typing import TYPE_CHECKING, Optional
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
|
|
20
|
+
from cogames_agents.policy.evolution.cogsguard.evolutionary_coordinator import (
|
|
21
|
+
EvolutionaryRoleCoordinator,
|
|
22
|
+
)
|
|
23
|
+
from cogames_agents.policy.scripted_agent.pathfinding import (
|
|
24
|
+
compute_goal_cells,
|
|
25
|
+
shortest_path,
|
|
26
|
+
)
|
|
27
|
+
from cogames_agents.policy.scripted_agent.pathfinding import (
|
|
28
|
+
is_traversable as path_is_traversable,
|
|
29
|
+
)
|
|
30
|
+
from cogames_agents.policy.scripted_agent.pathfinding import (
|
|
31
|
+
is_within_bounds as path_is_within_bounds,
|
|
32
|
+
)
|
|
33
|
+
from cogames_agents.policy.scripted_agent.types import CellType, ObjectState, ParsedObservation
|
|
34
|
+
from cogames_agents.policy.scripted_agent.utils import (
|
|
35
|
+
add_inventory_token,
|
|
36
|
+
change_vibe_action,
|
|
37
|
+
has_type_tag,
|
|
38
|
+
is_adjacent,
|
|
39
|
+
is_station,
|
|
40
|
+
is_wall,
|
|
41
|
+
)
|
|
42
|
+
from cogames_agents.policy.scripted_agent.utils import (
|
|
43
|
+
parse_observation as utils_parse_observation,
|
|
44
|
+
)
|
|
45
|
+
from mettagrid.config.mettagrid_config import CardinalDirection
|
|
46
|
+
from mettagrid.mettagrid_c import dtype_actions
|
|
47
|
+
from mettagrid.policy.policy import MultiAgentPolicy, StatefulAgentPolicy, StatefulPolicyImpl
|
|
48
|
+
from mettagrid.policy.policy_env_interface import PolicyEnvInterface
|
|
49
|
+
from mettagrid.simulator import Action, AgentObservation, ObservationToken
|
|
50
|
+
|
|
51
|
+
from .types import (
|
|
52
|
+
ROLE_TO_GEAR,
|
|
53
|
+
ROLE_TO_STATION,
|
|
54
|
+
CogsguardAgentState,
|
|
55
|
+
CogsguardPhase,
|
|
56
|
+
Role,
|
|
57
|
+
StructureInfo,
|
|
58
|
+
StructureType,
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
# Vibe names for role selection
|
|
62
|
+
ROLE_VIBES = ["scout", "miner", "aligner", "scrambler"]
|
|
63
|
+
VIBE_TO_ROLE = {
|
|
64
|
+
"miner": Role.MINER,
|
|
65
|
+
"scout": Role.SCOUT,
|
|
66
|
+
"aligner": Role.ALIGNER,
|
|
67
|
+
"scrambler": Role.SCRAMBLER,
|
|
68
|
+
}
|
|
69
|
+
SMART_ROLE_SWITCH_COOLDOWN = 40
|
|
70
|
+
SCRAMBLER_GEAR_PRIORITY_STEPS = 25
|
|
71
|
+
|
|
72
|
+
_GLOBAL_COORDINATORS: dict[int, "SmartRoleCoordinator"] = {}
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _shared_coordinator(policy_env_info: PolicyEnvInterface) -> "SmartRoleCoordinator":
    """Return the cached SmartRoleCoordinator for *policy_env_info*, creating it if needed.

    Coordinators live in the module-level ``_GLOBAL_COORDINATORS`` dict, keyed by
    ``id(policy_env_info)``. A cached coordinator is reused only while its
    ``num_agents`` matches ``policy_env_info.num_agents``; otherwise a new one is
    built and replaces the cached entry.
    """
    # NOTE(review): id() values can be reused once the original object has been
    # garbage collected; confirm env-info objects outlive their cache entry.
    cache_key = id(policy_env_info)
    cached = _GLOBAL_COORDINATORS.get(cache_key)
    if cached is not None and cached.num_agents == policy_env_info.num_agents:
        return cached

    fresh = SmartRoleCoordinator(policy_env_info.num_agents)
    _GLOBAL_COORDINATORS[cache_key] = fresh
    return fresh
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
if TYPE_CHECKING:
|
|
85
|
+
from mettagrid.simulator.interface import AgentObservation
|
|
86
|
+
|
|
87
|
+
# Debug flag - set to True to see detailed agent behavior
|
|
88
|
+
DEBUG = False
|
|
89
|
+
# Offsets probed when looking for gear stations, in priority order. Per the
# layout note below, the first component counts rows and the second counts columns.
GEAR_SEARCH_OFFSETS = [
    # BaseHub places stations ~4-5 rows below the hub, spaced by 2 columns.
    # Search those slots first to capture early gear windows.
    (4, -4),
    (4, -2),
    (4, 0),
    (4, 2),
    (4, 4),
    (5, -4),
    (5, -2),
    (5, 0),
    (5, 2),
    (5, 4),
    # Fallbacks for variations/tighter layouts.
    (6, -4),
    (6, 0),
    (6, 4),
    # The four 5-cell offsets along the axes; (5, 0) is already listed among the
    # primary slots above, so it is not repeated here.
    (0, 5),
    (0, -5),
    (-5, 0),
]
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
@dataclass
class SmartRoleAgentSnapshot:
    """Per-agent summary that SmartRoleCoordinator stores for role selection.

    One snapshot is written per agent by ``SmartRoleCoordinator.update_agent``
    and overwritten on every subsequent update for that agent.
    """

    # Agent step counter at the time the snapshot was taken.
    step: int
    # Role the agent held when the snapshot was taken.
    role: Role
    # Result of the agent state's has_gear() call.
    has_gear: bool
    # Sorted, de-duplicated structure-type values the agent knows about.
    structures_known: tuple[str, ...]
    # Number of entries in the agent's structure map.
    structures_seen: int
    # Agent inventory counters copied from its state.
    heart_count: int
    influence_count: int
    # Junction tallies keyed by "cogs" / "clips" / "neutral" / "unknown".
    junction_alignment_counts: dict[str, int]
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
@dataclass
class SmartRoleCoordinator:
    """Shared blackboard through which agents pool discoveries and pick roles.

    All shared positions are stored as ``(row, col)`` offsets from the hub
    position, so agents whose internal coordinates have different origins can
    exchange them once each has located the hub.
    """

    num_agents: int
    # Latest snapshot per agent id.
    agent_snapshots: dict[int, SmartRoleAgentSnapshot] = field(default_factory=dict)
    # Hub-relative junction offset -> last shared alignment (None is a valid value).
    junction_alignment_overrides: dict[tuple[int, int], Optional[str]] = field(default_factory=dict)
    # Station name -> hub-relative offset, recorded the first time any agent reports it.
    station_offsets: dict[str, tuple[int, int]] = field(default_factory=dict)
    # Hub-relative junction offset -> step at which it was registered with alignment None.
    recent_scrambles: dict[tuple[int, int], int] = field(default_factory=dict)

    def update_agent(self, s: CogsguardAgentState) -> None:
        """Merge agent state *s* into the shared view and refresh its snapshot.

        When the agent knows the hub position, its junctions and stations are
        recorded and the shared overrides are written back into *s* before the
        snapshot is taken.
        """
        hub = s.stations.get("hub")
        if hub is not None:
            self._record_known_junctions(s, hub)
            self._record_known_stations(s, hub)
            self._apply_alignment_overrides(s, hub)
            self._apply_station_overrides(s, hub)

        tally = dict.fromkeys(("cogs", "clips", "neutral", "unknown"), 0)
        for charger in s.get_structures_by_type(StructureType.CHARGER):
            tally[self._normalize_alignment(charger.alignment)] += 1

        type_names = {info.structure_type.value for info in s.structures.values()}
        known_types = tuple(sorted(type_names))
        self.agent_snapshots[s.agent_id] = SmartRoleAgentSnapshot(
            step=s.step_count,
            role=s.role,
            has_gear=s.has_gear(),
            structures_known=known_types,
            structures_seen=len(s.structures),
            heart_count=s.heart,
            influence_count=s.influence,
            junction_alignment_counts=tally,
        )

    def register_junction_alignment(
        self,
        pos: tuple[int, int],
        alignment: Optional[str],
        hub_pos: Optional[tuple[int, int]],
        step: Optional[int] = None,
    ) -> None:
        """Publish the alignment of the junction at *pos* for all agents.

        Does nothing when *hub_pos* is None. When *step* is given, an alignment
        of None records the junction in ``recent_scrambles`` and an alignment
        of "cogs" removes it; any other alignment leaves that record untouched.
        """
        if hub_pos is None:
            return

        rel = (pos[0] - hub_pos[0], pos[1] - hub_pos[1])
        self.junction_alignment_overrides[rel] = alignment

        if step is not None:
            if alignment is None:
                self.recent_scrambles[rel] = step
            elif alignment == "cogs":
                self.recent_scrambles.pop(rel, None)

    def recent_scramble_targets(
        self,
        hub_pos: Optional[tuple[int, int]],
        step: int,
        *,
        max_age: int = 200,
    ) -> list[tuple[int, int]]:
        """Return positions of junctions recorded as scrambled within *max_age* steps.

        Entries older than *max_age* (relative to *step*) are dropped from
        ``recent_scrambles``. Returned positions are *hub_pos* plus the stored
        offset. Returns an empty list, without pruning, when *hub_pos* is None.
        """
        if hub_pos is None:
            return []

        expired = [rel for rel, when in self.recent_scrambles.items() if step - when > max_age]
        for rel in expired:
            self.recent_scrambles.pop(rel, None)

        # Everything still in the dict is recent enough; insertion order is preserved.
        return [(hub_pos[0] + d_row, hub_pos[1] + d_col) for d_row, d_col in self.recent_scrambles]

    def _record_known_junctions(self, s: CogsguardAgentState, hub_pos: tuple[int, int]) -> None:
        """Seed the shared overrides with junction alignments this agent knows.

        Only junctions with a non-None alignment whose offset has no entry yet
        are recorded; existing entries are never overwritten here.
        """
        shared = self.junction_alignment_overrides
        for junction in s.get_structures_by_type(StructureType.CHARGER):
            if junction.alignment is None:
                continue
            rel = (junction.position[0] - hub_pos[0], junction.position[1] - hub_pos[1])
            shared.setdefault(rel, junction.alignment)

    def _record_known_stations(self, s: CogsguardAgentState, hub_pos: tuple[int, int]) -> None:
        """Record hub-relative offsets of stations this agent knows (first report wins).

        The "hub" and "junction" entries and stations with no position are skipped.
        """
        for station, where in s.stations.items():
            if where is None or station in ("hub", "junction"):
                continue
            self.station_offsets.setdefault(station, (where[0] - hub_pos[0], where[1] - hub_pos[1]))

    def _apply_station_overrides(self, s: CogsguardAgentState, hub_pos: tuple[int, int]) -> None:
        """Fill in station positions missing from *s* using the shared offsets.

        A position outside the agent's grid is ignored. When no structure is
        known at the derived position, a placeholder StructureInfo is inserted
        and the cell is marked as an obstacle.
        """
        for station, (d_row, d_col) in self.station_offsets.items():
            if s.stations.get(station) is not None:
                continue

            row, col = hub_pos[0] + d_row, hub_pos[1] + d_col
            if not (0 <= row < s.map_height and 0 <= col < s.map_width):
                continue

            cell = (row, col)
            s.stations[station] = cell
            if cell in s.structures:
                continue

            s.structures[cell] = StructureInfo(
                position=cell,
                structure_type=self._station_structure_type(station),
                name=station,
                last_seen_step=s.step_count,
            )
            s.occupancy[row][col] = CellType.OBSTACLE.value

    def _apply_alignment_overrides(self, s: CogsguardAgentState, hub_pos: tuple[int, int]) -> None:
        """Push shared junction alignments into the agent's structures and supply depots.

        Unknown in-bounds junctions are created as CHARGER structures. Known
        chargers are updated unless the agent saw them on the current step.
        Structures of any other type at the same position are left alone.
        """
        shared = self.junction_alignment_overrides
        if not shared:
            return

        for (d_row, d_col), alignment in shared.items():
            row, col = hub_pos[0] + d_row, hub_pos[1] + d_col
            if not (0 <= row < s.map_height and 0 <= col < s.map_width):
                continue

            cell = (row, col)
            existing = s.structures.get(cell)
            if existing is None:
                s.structures[cell] = StructureInfo(
                    position=cell,
                    structure_type=StructureType.CHARGER,
                    name="junction",
                    last_seen_step=s.step_count,
                    alignment=alignment,
                )
                s.occupancy[row][col] = CellType.OBSTACLE.value
                continue

            if existing.structure_type == StructureType.CHARGER:
                # A same-step observation by the agent takes precedence over shared data.
                if existing.last_seen_step == s.step_count:
                    continue
                if existing.alignment != alignment:
                    existing.alignment = alignment

        if s.supply_depots:
            for index, (depot_pos, _old_alignment) in enumerate(s.supply_depots):
                rel = (depot_pos[0] - hub_pos[0], depot_pos[1] - hub_pos[1])
                if rel in shared:
                    s.supply_depots[index] = (depot_pos, shared[rel])

    @staticmethod
    def _station_structure_type(name: str) -> StructureType:
        """Map a station name to its StructureType, or UNKNOWN if unrecognised."""
        by_name = {
            "miner_station": StructureType.MINER_STATION,
            "scout_station": StructureType.SCOUT_STATION,
            "aligner_station": StructureType.ALIGNER_STATION,
            "scrambler_station": StructureType.SCRAMBLER_STATION,
            "chest": StructureType.CHEST,
        }
        return by_name.get(name, StructureType.UNKNOWN)

    @staticmethod
    def _normalize_alignment(alignment: Optional[str]) -> str:
        """Collapse an alignment into "cogs", "clips", "neutral" or "unknown".

        None counts as "neutral"; any unrecognised string becomes "unknown".
        """
        if alignment in ("cogs", "clips"):
            return alignment
        if alignment is None or alignment == "neutral":
            return "neutral"
        return "unknown"

    def choose_role(self, agent_id: int) -> str:
        """Pick a role vibe for *agent_id* from the aggregated team snapshots.

        The checks run in priority order: an agent with no snapshot gets a
        random role; then scouting until hub and chest are known; then at least
        one scout and one miner; then scouting until a junction is known; then
        at least one scrambler and one aligner; then junction-driven choices;
        finally scout or miner depending on how many structures have been seen.
        """
        if self.agent_snapshots.get(agent_id) is None:
            return random.choice(ROLE_VIBES)

        seen_types = self._aggregate_structures()
        if not ("hub" in seen_types and "chest" in seen_types):
            return "scout"

        staffing = self._aggregate_role_counts()
        for must_have in ("scout", "miner"):
            if staffing.get(must_have, 0) == 0:
                return must_have

        junctions = self._aggregate_junction_counts()
        if sum(junctions.values()) - junctions["unknown"] == 0:
            return "scout"

        for must_have in ("scrambler", "aligner"):
            if staffing.get(must_have, 0) == 0:
                return must_have

        scramblers = staffing.get("scrambler", 0)
        aligners = staffing.get("aligner", 0)
        if junctions["clips"] > 0 and scramblers <= aligners:
            return "scrambler"
        if junctions["neutral"] > 0:
            return "aligner"

        if self._aggregate_structures_seen() < 10:
            return "scout"
        return "miner"

    def _aggregate_junction_counts(self) -> dict[str, int]:
        """Return, per alignment bucket, the largest count reported by any agent (minimum 0)."""
        snapshots = list(self.agent_snapshots.values())
        return {
            bucket: max([0, *(snap.junction_alignment_counts.get(bucket, 0) for snap in snapshots)])
            for bucket in ("cogs", "clips", "neutral", "unknown")
        }

    def aligned_junction_count(self) -> int:
        """Return the aggregated number of "cogs"-aligned junctions."""
        return self._aggregate_junction_counts().get("cogs", 0)

    def _aggregate_structures(self) -> set[str]:
        """Return the union of structure-type names known to any agent."""
        return set().union(*(snap.structures_known for snap in self.agent_snapshots.values()))

    def _aggregate_role_counts(self) -> dict[str, int]:
        """Count agents per role vibe; roles outside ROLE_VIBES are ignored."""
        tally = dict.fromkeys(ROLE_VIBES, 0)
        for snap in self.agent_snapshots.values():
            vibe = snap.role.value
            if vibe in tally:
                tally[vibe] += 1
        return tally

    def _aggregate_role_gear_counts(self) -> dict[str, int]:
        """Count agents per role vibe that currently hold gear."""
        tally = dict.fromkeys(ROLE_VIBES, 0)
        for snap in self.agent_snapshots.values():
            vibe = snap.role.value
            if vibe in tally and snap.has_gear:
                tally[vibe] += 1
        return tally

    def get_role_gear_counts(self) -> dict[str, int]:
        """Public accessor for the per-role geared-agent counts."""
        return self._aggregate_role_gear_counts()

    def _aggregate_heart_count(self) -> int:
        """Return the total hearts across all snapshots."""
        total = 0
        for snap in self.agent_snapshots.values():
            total += snap.heart_count
        return total

    def _aggregate_influence_count(self) -> int:
        """Return the total influence across all snapshots."""
        total = 0
        for snap in self.agent_snapshots.values():
            total += snap.influence_count
        return total

    def _aggregate_structures_seen(self) -> int:
        """Return the largest structures_seen of any agent, or 0 with no snapshots."""
        per_agent = [snap.structures_seen for snap in self.agent_snapshots.values()]
        return max(per_agent) if per_agent else 0
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
class CogsguardAgentPolicyImpl(StatefulPolicyImpl[CogsguardAgentState]):
|
|
359
|
+
"""Base policy implementation for CoGsGuard agents.
|
|
360
|
+
|
|
361
|
+
Handles common behavior like gear acquisition. Role-specific behavior
|
|
362
|
+
is implemented by overriding execute_role().
|
|
363
|
+
"""
|
|
364
|
+
|
|
365
|
+
# Subclasses set this
|
|
366
|
+
ROLE: Role = Role.MINER
|
|
367
|
+
|
|
368
|
+
def __init__(
|
|
369
|
+
self,
|
|
370
|
+
policy_env_info: PolicyEnvInterface,
|
|
371
|
+
agent_id: int,
|
|
372
|
+
role: Role,
|
|
373
|
+
smart_role_coordinator: Optional[SmartRoleCoordinator] = None,
|
|
374
|
+
evolutionary_role_coordinator: Optional[EvolutionaryRoleCoordinator] = None,
|
|
375
|
+
use_evolutionary_roles: bool = False,
|
|
376
|
+
):
|
|
377
|
+
self._agent_id = agent_id
|
|
378
|
+
self._role = role
|
|
379
|
+
self._policy_env_info = policy_env_info
|
|
380
|
+
self._smart_role_coordinator = smart_role_coordinator
|
|
381
|
+
self._evolutionary_role_coordinator = evolutionary_role_coordinator
|
|
382
|
+
self._use_evolutionary_roles = use_evolutionary_roles
|
|
383
|
+
# Some env configs omit move_energy_cost; default to 1 to match simulator fallback.
|
|
384
|
+
self._move_energy_cost = getattr(policy_env_info, "move_energy_cost", 1)
|
|
385
|
+
|
|
386
|
+
# Observation grid half-ranges
|
|
387
|
+
self._obs_hr = policy_env_info.obs_height // 2
|
|
388
|
+
self._obs_wr = policy_env_info.obs_width // 2
|
|
389
|
+
|
|
390
|
+
# Action lookup
|
|
391
|
+
self._action_names = list(policy_env_info.action_names)
|
|
392
|
+
self._action_set = set(self._action_names)
|
|
393
|
+
self._vibe_names = [
|
|
394
|
+
name[len("change_vibe_") :] for name in self._action_names if name.startswith("change_vibe_")
|
|
395
|
+
]
|
|
396
|
+
self._move_deltas = {
|
|
397
|
+
"north": (-1, 0),
|
|
398
|
+
"south": (1, 0),
|
|
399
|
+
"east": (0, 1),
|
|
400
|
+
"west": (0, -1),
|
|
401
|
+
}
|
|
402
|
+
|
|
403
|
+
# Feature name sets for observation parsing
|
|
404
|
+
self._spatial_feature_names = {"tag", "cooldown_remaining", "clipped", "remaining_uses"}
|
|
405
|
+
self._agent_feature_key_by_name = {"agent:group": "agent_group", "agent:frozen": "agent_frozen"}
|
|
406
|
+
self._protocol_input_prefix = "protocol_input:"
|
|
407
|
+
self._protocol_output_prefix = "protocol_output:"
|
|
408
|
+
|
|
409
|
+
# Cache tag names on first use
|
|
410
|
+
self._tag_names: dict[int, str] = {}
|
|
411
|
+
|
|
412
|
+
def _noop(self) -> Action:
    """Build the action named "noop"."""
    idle = Action(name="noop")
    return idle
|
|
414
|
+
|
|
415
|
+
def _has_vibe(self, vibe_name: str) -> bool:
|
|
416
|
+
return vibe_name in self._vibe_names
|
|
417
|
+
|
|
418
|
+
def _choose_role_vibe(self, s: CogsguardAgentState) -> str:
|
|
419
|
+
if self._use_evolutionary_roles and self._evolutionary_role_coordinator is not None:
|
|
420
|
+
return self._evolutionary_role_coordinator.choose_vibe(s.agent_id, s.step_count)
|
|
421
|
+
if self._smart_role_coordinator is None:
|
|
422
|
+
return random.choice(ROLE_VIBES)
|
|
423
|
+
return self._smart_role_coordinator.choose_role(s.agent_id)
|
|
424
|
+
|
|
425
|
+
def _move(self, direction: str) -> Action:
|
|
426
|
+
action_name = f"move_{direction}"
|
|
427
|
+
if action_name in self._action_set:
|
|
428
|
+
return Action(name=action_name)
|
|
429
|
+
return self._noop()
|
|
430
|
+
|
|
431
|
+
def initial_agent_state(self) -> CogsguardAgentState:
    """Create the starting state for this agent.

    Positions are tracked relative to the agent's starting cell rather than in
    true map coordinates:
    - the agent starts at the centre of a fixed 200x200 internal grid, which
      leaves room for offsets in every direction;
    - every cell starts as FREE and unexplored.

    Also refreshes the cached tag-name table from the environment info.
    """
    self._tag_names = self._policy_env_info.tag_id_to_name

    # Fixed internal grid; the agent's origin is stored at the grid centre.
    side = 200
    origin = side // 2

    free_cell = CellType.FREE.value
    occupancy_grid = [[free_cell] * side for _ in range(side)]
    explored_grid = [[False] * side for _ in range(side)]

    state = CogsguardAgentState(
        agent_id=self._agent_id,
        role=self._role,
        map_height=side,
        map_width=side,
        occupancy=occupancy_grid,
        explored=explored_grid,
        row=origin,
        col=origin,
    )

    move_cost = self._move_energy_cost
    if move_cost is not None:
        state.MOVE_ENERGY_COST = move_cost
    return state
|
|
463
|
+
|
|
464
|
+
def step_with_state(self, obs: AgentObservation, s: CogsguardAgentState) -> tuple[Action, CogsguardAgentState]:
    """Advance the agent by one step and return ``(action, state)``.

    The order matters: inventory (including the last executed action) is read
    first, the position is updated from it, the observation is parsed and
    merged into the map, the shared coordinator is updated, the phase is
    refreshed, and finally the phase is executed to produce the action.
    """
    s.step_count += 1
    s.current_obs = obs
    s.agent_occupancy.clear()

    # Inventory, then position (derived from the last executed action).
    self._read_inventory(s, obs)
    self._update_agent_position(s)

    # Parse the observation and fold it into the map knowledge.
    parsed_obs = self._parse_observation(s, obs)
    self._update_occupancy_and_discover(s, parsed_obs)

    coordinator = self._smart_role_coordinator
    if coordinator is not None:
        coordinator.update_agent(s)

    self._update_phase(s)
    action = self._execute_phase(s)

    # Debug trace, limited to the first 50 steps of each agent.
    if DEBUG and s.step_count <= 50:
        if s.has_gear():
            gear_status = "HAS_GEAR"
        else:
            gear_status = "NO_GEAR"
        nexus_pos = s.get_structure_position(StructureType.HUB) or "NOT_FOUND"
        print(
            f"[A{s.agent_id}] Step {s.step_count}: vibe={s.current_vibe} role={s.role.value} | "
            f"Phase={s.phase.value} | {gear_status} | "
            f"Energy={s.energy} | "
            f"Pos=({s.row},{s.col}) | "
            f"Nexus@{nexus_pos} | "
            f"Action={action.name}"
        )

    s.last_action = action
    return action, s
|
|
506
|
+
|
|
507
|
+
def _read_inventory(self, s: CogsguardAgentState, obs: AgentObservation) -> None:
|
|
508
|
+
"""Read inventory, vibe, and last executed action from observation."""
|
|
509
|
+
inv = {}
|
|
510
|
+
vibe_id = 0 # Default vibe ID
|
|
511
|
+
last_action_id: Optional[int] = None
|
|
512
|
+
center_r, center_c = self._obs_hr, self._obs_wr
|
|
513
|
+
token_value_base = None
|
|
514
|
+
for tok in obs.tokens:
|
|
515
|
+
if tok.location == (center_r, center_c):
|
|
516
|
+
feature_name = tok.feature.name
|
|
517
|
+
if feature_name.startswith("inv:"):
|
|
518
|
+
if token_value_base is None:
|
|
519
|
+
token_value_base = int(tok.feature.normalization)
|
|
520
|
+
add_inventory_token(inv, feature_name, tok.value, token_value_base=token_value_base)
|
|
521
|
+
elif feature_name == "vibe":
|
|
522
|
+
vibe_id = tok.value
|
|
523
|
+
elif feature_name == "last_action":
|
|
524
|
+
last_action_id = tok.value
|
|
525
|
+
|
|
526
|
+
s.energy = inv.get("energy", 0)
|
|
527
|
+
s.carbon = inv.get("carbon", 0)
|
|
528
|
+
s.oxygen = inv.get("oxygen", 0)
|
|
529
|
+
s.germanium = inv.get("germanium", 0)
|
|
530
|
+
s.silicon = inv.get("silicon", 0)
|
|
531
|
+
s.heart = inv.get("heart", 0)
|
|
532
|
+
s.influence = inv.get("influence", 0)
|
|
533
|
+
s.hp = inv.get("hp", 100)
|
|
534
|
+
|
|
535
|
+
# Gear items
|
|
536
|
+
s.miner = inv.get("miner", 0)
|
|
537
|
+
s.scout = inv.get("scout", 0)
|
|
538
|
+
s.aligner = inv.get("aligner", 0)
|
|
539
|
+
s.scrambler = inv.get("scrambler", 0)
|
|
540
|
+
|
|
541
|
+
if s.heart != s._last_heart_count:
|
|
542
|
+
s._heart_wait_start = 0
|
|
543
|
+
s._last_heart_count = s.heart
|
|
544
|
+
|
|
545
|
+
# Read vibe name from vibe ID using policy_env_info
|
|
546
|
+
s.current_vibe = self._get_vibe_name(vibe_id)
|
|
547
|
+
|
|
548
|
+
# Read last executed action from observation
|
|
549
|
+
# This tells us what the simulator actually did, not what we intended
|
|
550
|
+
if last_action_id is not None:
|
|
551
|
+
action_names = self._policy_env_info.action_names
|
|
552
|
+
if 0 <= last_action_id < len(action_names):
|
|
553
|
+
s.last_action_executed = action_names[last_action_id]
|
|
554
|
+
else:
|
|
555
|
+
s.last_action_executed = None
|
|
556
|
+
else:
|
|
557
|
+
s.last_action_executed = None
|
|
558
|
+
|
|
559
|
+
def _get_vibe_name(self, vibe_id: int) -> str:
    """Translate a numeric vibe ID into its name.

    Args:
        vibe_id: Index into ``self._vibe_names``.

    Returns:
        The name stored at ``vibe_id``, or ``"default"`` when the ID is
        negative or past the end of the list.
    """
    names = self._vibe_names
    in_range = 0 <= vibe_id < len(names)
    return names[vibe_id] if in_range else "default"
|
|
564
|
+
|
|
565
|
+
def _update_agent_position(self, s: CogsguardAgentState) -> None:
    """Advance the tracked position using the action the simulator reports as executed.

    The position changes only when the executed action is a ``move_<direction>``
    whose direction is a key of ``self._move_deltas`` and the previous step was
    not flagged as an object interaction. The interaction flag is always cleared
    afterwards, and the resulting position is appended to ``s.position_history``,
    which is capped at the 30 most recent entries.
    """
    # What the observation says happened vs. what this policy asked for.
    executed = s.last_action_executed
    intended = s.last_action.name if s.last_action else None

    # Debug: report a divergence between the two during the first 100 steps.
    mismatch = intended and executed and intended != executed
    if DEBUG and s.step_count <= 100 and mismatch:
        print(
            f"[A{s.agent_id}] ACTION_MISMATCH: intended={intended}, "
            f"executed={executed} (action failed/delayed or human control)"
        )

    move_prefix = "move_"
    moved = executed and executed.startswith(move_prefix) and not s.using_object_this_step
    if moved:
        heading = executed[len(move_prefix):]
        if heading in self._move_deltas:
            row_step, col_step = self._move_deltas[heading]
            s.row += row_step
            s.col += col_step

    # The interaction flag only applies to the step that set it.
    s.using_object_this_step = False

    # Track position history (bounded to the last 30 entries).
    trail = s.position_history
    trail.append((s.row, s.col))
    if len(trail) > 30:
        trail.pop(0)
|
|
601
|
+
|
|
602
|
+
def _parse_observation(self, s: CogsguardAgentState, obs: AgentObservation) -> ParsedObservation:
    """Delegate observation parsing to ``utils_parse_observation``.

    The agent state and raw observation are passed positionally; the
    observation radii, feature-name tables, protocol prefixes and tag names
    held on this policy are passed as keyword arguments.
    """
    parser_options = {
        "obs_hr": self._obs_hr,
        "obs_wr": self._obs_wr,
        "spatial_feature_names": self._spatial_feature_names,
        "agent_feature_key_by_name": self._agent_feature_key_by_name,
        "protocol_input_prefix": self._protocol_input_prefix,
        "protocol_output_prefix": self._protocol_output_prefix,
        "tag_names": self._tag_names,
    }
    # CogsguardAgentState is compatible with SimpleAgentState
    return utils_parse_observation(s, obs, **parser_options)  # type: ignore[arg-type]
|
|
615
|
+
|
|
616
|
+
def _update_occupancy_and_discover(self, s: CogsguardAgentState, parsed: ParsedObservation) -> None:
    """Refresh the occupancy/explored grids from the current view and register discovered objects.

    Every in-bounds cell of the observation window centred on ``(s.row, s.col)``
    is first marked FREE and explored. Each entry of ``parsed.nearby_objects`` is
    then classified by its lower-cased name and tags: walls, gear stations,
    junctions, hubs, chests and extractors are marked as obstacles and handed to
    ``_update_structure``; other agents are added to ``s.agent_occupancy``.

    Returns immediately while ``s.row`` is negative. Objects whose position lies
    outside the ``map_height`` x ``map_width`` grid are ignored.
    """
    if s.row < 0:
        return

    # Mark all observed cells as FREE and explored
    for obs_r in range(2 * self._obs_hr + 1):
        for obs_c in range(2 * self._obs_wr + 1):
            r, c = obs_r - self._obs_hr + s.row, obs_c - self._obs_wr + s.col
            if 0 <= r < s.map_height and 0 <= c < s.map_width:
                s.occupancy[r][c] = CellType.FREE.value
                s.explored[r][c] = True

    # Process discovered objects
    if DEBUG and s.step_count == 1:
        print(f"[A{s.agent_id}] Nearby objects: {[obj.name for obj in parsed.nearby_objects.values()]}")

    # Loop-invariant: resource names used by both the chest and extractor checks.
    resources = ("carbon", "oxygen", "germanium", "silicon")

    for pos, obj_state in parsed.nearby_objects.items():
        r, c = pos
        # Same bounds rule as the FREE-marking loop above: a negative index would
        # silently wrap to the opposite edge of the grid and an oversized one
        # would raise IndexError.
        if not (0 <= r < s.map_height and 0 <= c < s.map_width):
            continue

        obj_name = obj_state.name.lower()
        obj_tags = [tag.lower() for tag in obj_state.tags]

        # Walls are obstacles
        if is_wall(obj_name):
            s.occupancy[r][c] = CellType.OBSTACLE.value
            self._update_structure(s, pos, obj_name, StructureType.WALL, obj_state)
            continue

        # Track other agents
        if obj_name == "agent" and obj_state.agent_id != s.agent_id:
            s.agent_occupancy.add((r, c))
            continue

        # Discover gear stations (first matching station name wins)
        for station_name in ROLE_TO_STATION.values():
            if (
                is_station(obj_name, station_name)
                or station_name in obj_name
                or any(station_name in tag for tag in obj_tags)
            ):
                s.occupancy[r][c] = CellType.OBSTACLE.value
                struct_type = self._get_station_type(station_name)
                self._update_structure(s, pos, obj_name, struct_type, obj_state)
                break

        # Discover supply depots (junction in cogsguard)
        is_junction = has_type_tag(obj_tags, ("supply_depot", "junction", "junction"))
        if is_junction:
            s.occupancy[r][c] = CellType.OBSTACLE.value
            self._update_structure(s, pos, obj_name, StructureType.CHARGER, obj_state)

        # Discover hub (the main base / resource deposit point)
        is_hub = (
            "hub" in obj_name
            or obj_name in {"main_nexus"}
            or any("hub" in tag or "main_nexus" in tag or "nexus" in tag for tag in obj_tags)
        )
        if is_hub:
            s.occupancy[r][c] = CellType.OBSTACLE.value
            self._update_structure(s, pos, obj_name, StructureType.HUB, obj_state)

        # Discover chest (for hearts) - exclude extractors which are ChestConfig-backed.
        has_extractor_tag = "extractor" in obj_name or any("extractor" in tag for tag in obj_tags)
        is_resource_chest = any(f"{res}_" in obj_name or f"{res}chest" in obj_name for res in resources)
        looks_like_chest = (
            obj_name == "chest"
            or ("chest" in obj_name and not is_resource_chest)
            or any(tag == "chest" for tag in obj_tags)
        )
        # The extractor exclusion must cover the tag match as well. `and` binds
        # tighter than `or`, so without this grouping an extractor carrying a
        # "chest" tag would still be registered as a chest.
        if looks_like_chest and not has_extractor_tag:
            s.occupancy[r][c] = CellType.OBSTACLE.value
            self._update_structure(s, pos, obj_name, StructureType.CHEST, obj_state)

        # Discover extractors (in cogsguard they're named {resource}_extractor).
        for resource in resources:
            if f"{resource}_extractor" in obj_name or any(f"{resource}_extractor" in tag for tag in obj_tags):
                s.occupancy[r][c] = CellType.OBSTACLE.value
                self._update_structure(s, pos, obj_name, StructureType.EXTRACTOR, obj_state, resource_type=resource)
                break
|
|
695
|
+
|
|
696
|
+
def _get_station_type(self, station_name: str) -> StructureType:
    """Look up the ``StructureType`` for a gear-station name.

    Recognised names are ``miner_station``, ``scout_station``,
    ``aligner_station`` and ``scrambler_station``; anything else yields
    ``StructureType.UNKNOWN``.
    """
    station_types = {
        "miner_station": StructureType.MINER_STATION,
        "scout_station": StructureType.SCOUT_STATION,
        "aligner_station": StructureType.ALIGNER_STATION,
        "scrambler_station": StructureType.SCRAMBLER_STATION,
    }
    try:
        return station_types[station_name]
    except KeyError:
        return StructureType.UNKNOWN
|
|
705
|
+
|
|
706
|
+
def _update_structure(
    self,
    s: CogsguardAgentState,
    pos: tuple[int, int],
    obj_name: str,
    structure_type: StructureType,
    obj_state: ObjectState,
    resource_type: Optional[str] = None,
) -> None:
    """Record an observed structure at ``pos``, creating or refreshing its entry.

    Args:
        s: Agent state holding ``structures``, ``stations``, ``supply_depots``
            and ``alignment_overrides``.
        pos: Grid position of the structure.
        obj_name: Object name as passed by the caller.
        structure_type: Classification chosen by the caller.
        obj_state: Observed object state (clipped flag, tags, inventory,
            cooldown, remaining uses).
        resource_type: Resource name for extractors, if known.

    Inventory rule: an extractor seen for the first time with an empty
    inventory is assumed full (999); a known extractor with an empty inventory
    is treated as depleted (0). Non-extractors use the inventory sum, or 999
    when no inventory is reported.
    """
    # Derive alignment from clipped field, object name, and structure type
    clipped = obj_state.clipped > 0
    alignment = self._derive_alignment(obj_name, clipped, structure_type, obj_state.tags)
    if pos in s.alignment_overrides:
        remembered = s.alignment_overrides[pos]
        if alignment is None:
            # Nothing observable: fall back to the remembered override.
            alignment = remembered
        elif alignment != remembered:
            # The observation disagrees with the override: the observation wins.
            s.alignment_overrides[pos] = alignment

    already_known = pos in s.structures
    contents = obj_state.inventory

    if structure_type == StructureType.EXTRACTOR:
        # For extractors, track resource counts carefully
        if resource_type and resource_type in contents:
            inventory_amount = contents[resource_type]
        elif contents:
            inventory_amount = sum(contents.values())
        elif not already_known:
            # First sighting without inventory info: assume it still has resources.
            inventory_amount = 999
        else:
            # Known extractor with an empty inventory dict = depleted.
            inventory_amount = 0
        if DEBUG and inventory_amount == 0:
            print(f"[A{s.agent_id}] EXTRACTOR_EMPTY: {pos} resource={resource_type} inv={obj_state.inventory}")
    else:
        # Non-extractors: inventory sum if present, otherwise unknown/full.
        inventory_amount = sum(contents.values()) if contents else 999

    if already_known:
        # Refresh the volatile fields of the existing record.
        record = s.structures[pos]
        record.last_seen_step = s.step_count
        record.cooldown_remaining = obj_state.cooldown_remaining
        record.remaining_uses = obj_state.remaining_uses
        record.clipped = clipped
        record.alignment = alignment
        record.inventory_amount = inventory_amount
    else:
        s.structures[pos] = StructureInfo(
            position=pos,
            structure_type=structure_type,
            name=obj_name,
            last_seen_step=s.step_count,
            resource_type=resource_type,
            cooldown_remaining=obj_state.cooldown_remaining,
            remaining_uses=obj_state.remaining_uses,
            clipped=clipped,
            alignment=alignment,
            inventory_amount=inventory_amount,
        )

    # Structure types that also get an entry in s.stations, keyed by type value.
    station_like = (
        StructureType.HUB,
        StructureType.CHEST,
        StructureType.MINER_STATION,
        StructureType.SCOUT_STATION,
        StructureType.ALIGNER_STATION,
        StructureType.SCRAMBLER_STATION,
    )
    if structure_type in station_like:
        s.stations[structure_type.value] = pos

    if structure_type == StructureType.CHARGER:
        # Update the matching depot entry in place, or append a new one.
        for slot, (known_pos, _old_alignment) in enumerate(s.supply_depots):
            if known_pos == pos:
                s.supply_depots[slot] = (pos, alignment)
                break
        else:
            s.supply_depots.append((pos, alignment))
            if DEBUG:
                print(
                    f"[A{s.agent_id}] STRUCTURE: Added {structure_type.value} at {pos} "
                    f"(alignment={alignment}, inv={inventory_amount})"
                )
|
|
800
|
+
|
|
801
|
+
def _derive_alignment(
    self,
    obj_name: str,
    clipped: bool,
    structure_type: Optional[StructureType] = None,
    tags: Optional[list[str]] = None,
) -> Optional[str]:
    """Work out which side a structure belongs to.

    Checks are applied in order, case-insensitively:

    1. ``"cogs"`` anywhere in the name or any tag -> ``"cogs"``.
    2. ``"clips"`` anywhere in the name or any tag, or ``clipped`` set -> ``"clips"``.
    3. A ``StructureType.HUB`` whose name or tags mention ``"nexus"`` or
       ``"hub"`` -> ``"cogs"``.

    Returns:
        ``"cogs"``, ``"clips"``, or ``None`` when nothing indicates an alignment.
    """
    name = obj_name.lower()
    lowered_tags = [tag.lower() for tag in tags or []]

    def mentions(keyword: str) -> bool:
        # True when the keyword is a substring of the name or of any tag.
        return keyword in name or any(keyword in tag for tag in lowered_tags)

    if mentions("cogs"):
        return "cogs"
    if mentions("clips") or clipped:
        return "clips"
    # Hub/nexus defaults to cogs (main cogs building)
    if structure_type == StructureType.HUB and (mentions("nexus") or mentions("hub")):
        return "cogs"
    return None  # Unknown/neutral
|
|
835
|
+
|
|
836
|
+
def _update_phase(self, s: CogsguardAgentState) -> None:
    """Set ``s.phase`` (and ``s.role`` for role vibes) from the current vibe.

    For a vibe listed in ``VIBE_TO_ROLE`` the role is updated, and the phase
    becomes ``EXECUTE_ROLE`` once the agent has gear, or, for miners and scouts
    only, once more than 30 steps have elapsed. Otherwise the phase is
    ``GET_GEAR``. Any other vibe gets ``GET_GEAR`` as a placeholder; those
    vibes are handled in ``_execute_phase``.
    """
    vibe = s.current_vibe

    if vibe not in VIBE_TO_ROLE:
        # default / heart / gear vibes: placeholder, handled in _execute_phase
        s.phase = CogsguardPhase.GET_GEAR
        return

    # Role vibes: scout, miner, aligner, scrambler
    s.role = VIBE_TO_ROLE[vibe]

    # After 30 steps miners/scouts may proceed without gear to bootstrap
    # economy/exploration; everyone else must hold gear first.
    ready_for_role = s.has_gear() or (s.step_count > 30 and s.role in (Role.MINER, Role.SCOUT))
    s.phase = CogsguardPhase.EXECUTE_ROLE if ready_for_role else CogsguardPhase.GET_GEAR
|
|
861
|
+
|
|
862
|
+
def _execute_phase(self, s: CogsguardAgentState) -> Action:
    """Choose this step's action from the agent's vibe and phase.

    - ``default`` / ``heart``: noop.
    - ``gear``: pick a role via ``_choose_role_vibe`` and return the action
      that changes the vibe to it.
    - role vibes (keys of ``VIBE_TO_ROLE``): fetch gear in ``GET_GEAR`` phase,
      run ``execute_role`` in ``EXECUTE_ROLE`` phase.
    - anything else: noop.
    """
    vibe = s.current_vibe

    # Idle vibes: wait for something external to change the vibe.
    if vibe in ("default", "heart"):
        return self._noop()

    # Gear vibe: pick a role and change vibe to it
    if vibe == "gear":
        selected_role = self._choose_role_vibe(s)
        if DEBUG:
            print(f"[A{s.agent_id}] GEAR_VIBE: Picking role vibe: {selected_role}")
        return change_vibe_action(selected_role, action_names=self._action_names)

    if vibe not in VIBE_TO_ROLE:
        return self._noop()

    # Role vibes: execute the role behavior
    if s.phase == CogsguardPhase.GET_GEAR:
        return self._do_get_gear(s)
    if s.phase == CogsguardPhase.EXECUTE_ROLE:
        return self.execute_role(s)
    return self._noop()
|
|
896
|
+
|
|
897
|
+
def _do_recharge(self, s: CogsguardAgentState) -> Action:
    """Recharge by getting within range of the hub (main nexus) and waiting.

    Decision ladder:

    1. No hub known -> explore.
    2. Within Manhattan distance 10 of the hub -> noop and wait.
    3. Energy below one move's cost -> noop and wait.
    4. Energy below three moves' cost -> take one greedy step along the axis
       with the larger gap (ties go to the vertical axis).
    5. Otherwise -> pathfind to a cell adjacent to the hub.
    """
    # The main_nexus is cogs-aligned and has AOE that gives energy to cogs agents
    # The supply_depot is clips-aligned and won't give energy to cogs agents
    hub = s.get_structure_position(StructureType.HUB)
    if hub is None:
        if DEBUG:
            print(f"[A{s.agent_id}] RECHARGE: No nexus found, exploring")
        return self._explore(s)

    row_gap = hub[0] - s.row
    col_gap = hub[1] - s.col
    dist = abs(row_gap) + abs(col_gap)
    aoe_range = 10  # AOE range from recipe

    # Just need to be near the nexus (within AOE range), not adjacent
    if dist <= aoe_range:
        if DEBUG and s.step_count % 20 == 0:
            print(f"[A{s.agent_id}] RECHARGE: Near nexus (dist={dist}), waiting for AOE (energy={s.energy})")
        return self._noop()

    # Too little energy for even one move: wait and hope for regen.
    if s.energy < s.MOVE_ENERGY_COST:
        if DEBUG and s.step_count % 20 == 0:
            print(
                f"[A{s.agent_id}] RECHARGE: Energy critically low ({s.energy}), "
                f"can't move to nexus at dist={dist}, waiting for regen"
            )
        return self._noop()

    # Some energy but not much: take a single conservative step instead of
    # committing to a long path.
    if s.energy < s.MOVE_ENERGY_COST * 3:
        if DEBUG and s.step_count % 10 == 0:
            print(
                f"[A{s.agent_id}] RECHARGE: Low energy ({s.energy}), "
                f"taking single step towards nexus at {hub}"
            )
        if abs(row_gap) >= abs(col_gap):
            heading = "south" if row_gap > 0 else "north"
        else:
            heading = "east" if col_gap > 0 else "west"
        return self._move(heading)

    if DEBUG and s.step_count % 20 == 0:
        print(f"[A{s.agent_id}] RECHARGE: Moving to nexus at {hub} from ({s.row},{s.col}), dist={dist}")
    return self._move_towards(s, hub, reach_adjacent=True)
|
|
958
|
+
|
|
959
|
+
def _do_get_gear(self, s: CogsguardAgentState) -> Action:
    """Walk to the gear station for the agent's role and bump it.

    While a smart role coordinator is present, the step count is within
    ``SCRAMBLER_GEAR_PRIORITY_STEPS`` and the coordinator reports no scrambler
    gear, non-scrambler agents explore instead. When the role's station has not
    been discovered the agent, in order: fetches scout gear from a known scout
    station, searches around the hub using ``GEAR_SEARCH_OFFSETS``, or explores.
    """
    coordinator = self._smart_role_coordinator
    must_consider_yielding = (
        coordinator is not None
        and s.role != Role.SCRAMBLER
        and s.step_count <= SCRAMBLER_GEAR_PRIORITY_STEPS
    )
    if must_consider_yielding:
        gear_counts = coordinator.get_role_gear_counts()
        if gear_counts.get("scrambler", 0) == 0:
            if DEBUG and s.step_count % 5 == 0:
                print(f"[A{s.agent_id}] GET_GEAR: yielding to scrambler gear priority")
            return self._explore(s)

    station_name = s.get_gear_station_name()
    station_pos = s.get_structure_position(s.get_gear_station_type())
    hub_pos = s.get_structure_position(StructureType.HUB)

    if DEBUG and s.step_count <= 10:
        known_structures = sorted({struct.structure_type.value for struct in s.structures.values()})
        print(f"[A{s.agent_id}] GET_GEAR: station={station_name} pos={station_pos} all={known_structures}")

    if station_pos is not None:
        # Navigate to station, then bump it to get gear
        adj = is_adjacent((s.row, s.col), station_pos)
        if DEBUG and s.step_count <= 60:
            print(f"[A{s.agent_id}] GET_GEAR: pos=({s.row},{s.col}), station={station_pos}, adjacent={adj}")
        if not adj:
            return self._move_towards(s, station_pos, reach_adjacent=True)
        if DEBUG:
            print(f"[A{s.agent_id}] GET_GEAR: Adjacent to {station_name}, bumping it!")
        return self._use_object_at(s, station_pos)

    # Station unknown: bootstrap with scout gear for mobility when possible.
    scout_station = s.get_structure_position(StructureType.SCOUT_STATION)
    wants_scout_gear = scout_station is not None and s.scout == 0 and station_name != "scout_station"
    if wants_scout_gear:
        if is_adjacent((s.row, s.col), scout_station):
            return self._use_object_at(s, scout_station)
        return self._move_towards(s, scout_station, reach_adjacent=True)

    if hub_pos is not None:
        # Rotate through the search offsets by agent id and every 10 steps.
        offset = GEAR_SEARCH_OFFSETS[(s.agent_id + s.step_count // 10) % len(GEAR_SEARCH_OFFSETS)]
        target = (hub_pos[0] + offset[0], hub_pos[1] + offset[1])
        return self._move_towards(s, target, reach_adjacent=True)

    if DEBUG:
        print(f"[A{s.agent_id}] GET_GEAR: No {station_name} found, exploring")
    return self._explore(s)
|
|
1007
|
+
|
|
1008
|
+
def execute_role(self, s: CogsguardAgentState) -> Action:
    """Execute role-specific behavior. Override in subclasses.

    This base implementation simply explores. When ``DEBUG`` is enabled it
    reports, for the first 100 steps, that the base implementation was reached,
    naming the concrete policy class and the agent's role.
    """
    # Gate on DEBUG like every other diagnostic print in this class, so normal
    # runs do not write to stdout.
    if DEBUG and s.step_count <= 100:
        print(f"[A{s.agent_id}] BASE_EXECUTE_ROLE: impl={type(self).__name__}, role={s.role}")
    return self._explore(s)
|
|
1013
|
+
|
|
1014
|
+
# =========================================================================
|
|
1015
|
+
# Navigation utilities
|
|
1016
|
+
# =========================================================================
|
|
1017
|
+
|
|
1018
|
+
def _use_object_at(self, s: CogsguardAgentState, target_pos: tuple[int, int]) -> Action:
    """Use an object by issuing a move into its cell.

    Returns a noop when the agent is already on the target cell or when the
    target is listed in ``s.agent_occupancy``. Otherwise
    ``s.using_object_this_step`` is set and a move is issued based on the row
    offset first (north/south), then the column offset (east/west). If neither
    offset is exactly one cell, a noop is returned.
    """
    target_row, target_col = target_pos
    row_offset = target_row - s.row
    col_offset = target_col - s.col

    if row_offset == 0 and col_offset == 0:
        return self._noop()

    # Check agent collision
    if (target_row, target_col) in s.agent_occupancy:
        return self._noop()

    # Mark that we're using an object
    s.using_object_this_step = True

    if row_offset in (-1, 1):
        return self._move("north" if row_offset == -1 else "south")
    if col_offset in (1, -1):
        return self._move("east" if col_offset == 1 else "west")
    return self._noop()
|
|
1044
|
+
|
|
1045
|
+
def _explore_frontier(self, s: CogsguardAgentState) -> Optional[Action]:
    """Breadth-first search for the nearest unexplored cell and step toward it.

    The search expands from the agent's position through cells that are both
    explored and FREE. On reaching an unexplored cell it returns the first move
    of the route that led there, provided the cell that move enters is FREE and
    not listed in ``s.agent_occupancy``.

    Returns:
        A move action, or ``None`` when ``s.explored`` is empty or no usable
        frontier was found.
    """
    if not s.explored:
        return None

    origin = (s.row, s.col)
    seen: set[tuple[int, int]] = {origin}
    # Each entry is (cell, first move taken from the origin to reach it).
    frontier: deque[tuple[tuple[int, int], Optional[str]]] = deque([(origin, None)])

    directions = [("north", -1, 0), ("south", 1, 0), ("east", 0, 1), ("west", 0, -1)]
    direction_deltas = {name: (row_step, col_step) for name, row_step, col_step in directions}

    while frontier:
        (r, c), first_step = frontier.popleft()

        for direction, dr, dc in directions:
            nr, nc = r + dr, c + dc
            cell = (nr, nc)
            if not (0 <= nr < s.map_height and 0 <= nc < s.map_width) or cell in seen:
                continue
            seen.add(cell)

            if s.explored[nr][nc]:
                # Known free cells extend the search; known blocked cells end it.
                if s.occupancy[nr][nc] == CellType.FREE.value:
                    # Only the origin carries no first step, so its neighbours
                    # inherit the direction being expanded.
                    frontier.append((cell, direction if first_step is None else first_step))
                continue

            if first_step is None:
                # Unexplored cell directly next to the agent.
                if s.occupancy[nr][nc] == CellType.FREE.value and cell not in s.agent_occupancy:
                    if DEBUG and s.step_count <= 100:
                        print(f"[A{s.agent_id}] FRONTIER: Moving {direction} to unexplored ({nr},{nc})")
                    return self._move(direction)
                continue

            # Unexplored cell further away: validate the first step of the route.
            step_dr, step_dc = direction_deltas[first_step]
            step_r, step_c = s.row + step_dr, s.col + step_dc
            if not (0 <= step_r < s.map_height and 0 <= step_c < s.map_width):
                continue
            if s.occupancy[step_r][step_c] != CellType.FREE.value or (step_r, step_c) in s.agent_occupancy:
                continue
            if DEBUG and s.step_count <= 100:
                explored_count = sum(sum(row) for row in s.explored)
                total_cells = s.map_height * s.map_width
                print(
                    f"[A{s.agent_id}] FRONTIER: Heading {first_step} towards "
                    f"frontier at ({nr},{nc}), explored={explored_count}/{total_cells}"
                )
            return self._move(first_step)

    if DEBUG and s.step_count % 50 == 0:
        explored_count = sum(sum(row) for row in s.explored)
        total_cells = s.map_height * s.map_width
        print(f"[A{s.agent_id}] FRONTIER: None found, explored={explored_count}/{total_cells}")
    return None
|
|
1104
|
+
|
|
1105
|
+
def _explore(self, s: CogsguardAgentState) -> Action:
    """Explore by sweeping the cardinal directions in a fixed east/south/west/north cycle.

    After an optional loop-breaking move, the agent keeps its current heading
    for up to 8 steps while the next cell is traversable. Otherwise it tries the
    next direction in the cycle, then the remaining ones in order, and commits
    to the first traversable one. If all four are blocked it returns a noop.
    """
    # Check for location loop (agents blocking each other back and forth)
    if self._is_in_location_loop(s):
        escape = self._break_location_loop(s)
        if escape:
            return escape
        # If can't break loop, fall through to normal exploration

    # Start with east since gear stations are typically east of hub
    direction_cycle: list[CardinalDirection] = ["east", "south", "west", "north"]

    if DEBUG and s.step_count <= 30:
        print(f"[A{s.agent_id}] EXPLORE: target={s.exploration_target}, step={s.step_count}")

    heading = s.exploration_target

    # Keep going the same way for 8 steps before turning (faster cycles).
    if isinstance(heading, str) and s.step_count - s.exploration_target_step < 8:
        dr, dc = self._move_deltas.get(heading, (0, 0))
        if path_is_traversable(s, s.row + dr, s.col + dc, CellType):  # type: ignore[arg-type]
            return self._move(heading)  # type: ignore[arg-type]

    # Pick next direction in the cycle (don't randomize); unknown headings
    # restart at east (index 0).
    first = (direction_cycle.index(heading) + 1) % 4 if heading in direction_cycle else 0

    for direction in direction_cycle[first:] + direction_cycle[:first]:
        dr, dc = self._move_deltas[direction]
        next_r, next_c = s.row + dr, s.col + dc
        traversable = path_is_traversable(s, next_r, next_c, CellType)  # type: ignore[arg-type]
        if DEBUG and s.step_count <= 10:
            in_bounds = 0 <= next_r < s.map_height and 0 <= next_c < s.map_width
            cell_val = s.occupancy[next_r][next_c] if in_bounds else -1
            agent_occ = (next_r, next_c) in s.agent_occupancy
            print(
                f"[A{s.agent_id}] EXPLORE_DIR: {direction} -> ({next_r},{next_c}) "
                f"trav={traversable} cell={cell_val} agent={agent_occ}"
            )
        if traversable:
            s.exploration_target = direction
            s.exploration_target_step = s.step_count
            return self._move(direction)

    if DEBUG and s.step_count <= 10:
        print(f"[A{s.agent_id}] EXPLORE: All directions blocked, returning noop")
    return self._noop()
|
|
1159
|
+
|
|
1160
|
+
def _move_towards(
    self,
    s: CogsguardAgentState,
    target: tuple[int, int],
    *,
    reach_adjacent: bool = False,
) -> Action:
    """Take one step along a path toward ``target``.

    Args:
        s: Agent state; its cached path fields are read and updated.
        target: Destination cell.
        reach_adjacent: Forwarded to ``compute_goal_cells`` when choosing the
            goal cells, and part of the cached-path match.

    Returns:
        A move along the (cached or freshly computed) path; a loop-breaking or
        side-stepping move when one is needed and available; the result of
        ``_explore`` when no path exists; otherwise a noop.
    """
    # Check for location loop (agents blocking each other back and forth)
    if self._is_in_location_loop(s):
        escape = self._break_location_loop(s)
        if escape:
            return escape
        # If can't break loop, fall through to normal pathfinding

    start = (s.row, s.col)
    if start == target and not reach_adjacent:
        return self._noop()

    goal_cells = compute_goal_cells(s, target, reach_adjacent, CellType)  # type: ignore[arg-type]
    if not goal_cells:
        if DEBUG:
            print(f"[A{s.agent_id}] PATHFIND: No goal cells for {target}")
        return self._noop()

    # Reuse the cached path only if it was built for this exact request and its
    # next cell is still traversable.
    path = None
    cache_matches = (
        s.cached_path and s.cached_path_target == target and s.cached_path_reach_adjacent == reach_adjacent
    )
    if cache_matches:
        head = s.cached_path[0]
        if path_is_traversable(s, head[0], head[1], CellType):  # type: ignore[arg-type]
            path = s.cached_path

    # Compute new path if needed
    if path is None:
        path = shortest_path(s, start, goal_cells, False, CellType)  # type: ignore[arg-type]
        s.cached_path = path.copy() if path else None
        s.cached_path_target = target
        s.cached_path_reach_adjacent = reach_adjacent

    if not path:
        if DEBUG:
            print(f"[A{s.agent_id}] PATHFIND: No path to {target}, exploring instead")
        return self._explore(s)

    next_pos = path[0]
    step = (next_pos[0] - s.row, next_pos[1] - s.col)

    # Check agent collision
    if (next_pos[0], next_pos[1]) in s.agent_occupancy:
        sidestep = self._try_random_direction(s)
        if sidestep:
            s.cached_path = None
            s.cached_path_target = None
            return sidestep
        return self._noop()

    # Advance cached path only after taking a step
    if s.cached_path:
        remaining = s.cached_path[1:]
        if remaining:
            s.cached_path = remaining
        else:
            s.cached_path = None
            s.cached_path_target = None

    # Convert the single-cell offset to an action; anything else is a noop.
    heading_by_step = {(-1, 0): "north", (1, 0): "south", (0, 1): "east", (0, -1): "west"}
    heading = heading_by_step.get(step)
    if heading is None:
        return self._noop()
    return self._move(heading)
|
|
1236
|
+
|
|
1237
|
+
def _try_random_direction(self, s: CogsguardAgentState) -> Optional[Action]:
    """Move into a randomly chosen neighbouring cell that is free.

    The four cardinal directions are shuffled and tried in that order. A
    direction qualifies when its cell is within bounds, FREE in the occupancy
    grid and not listed in ``s.agent_occupancy``.

    Returns:
        The move action for the first qualifying direction, or ``None``.
    """
    candidates: list[CardinalDirection] = ["north", "south", "east", "west"]
    random.shuffle(candidates)
    for heading in candidates:
        row_step, col_step = self._move_deltas[heading]
        nr, nc = s.row + row_step, s.col + col_step
        if not path_is_within_bounds(s, nr, nc):  # type: ignore[arg-type]
            continue
        if s.occupancy[nr][nc] != CellType.FREE.value:
            continue
        if (nr, nc) in s.agent_occupancy:
            continue
        return self._move(heading)
    return None
|
|
1248
|
+
|
|
1249
|
+
def _is_in_location_loop(self, s: CogsguardAgentState) -> bool:
    """Report whether recent positions show the agent bouncing in place.

    Inspects the last six entries of ``s.position_history`` (or all of them
    when only five are recorded) and flags a loop when they contain at most
    two distinct positions, e.g. A→B→A→B→A.

    Args:
        s: Agent state whose ``position_history`` is examined.

    Returns:
        True when a loop is detected; False otherwise, including whenever
        fewer than five positions have been recorded.
    """
    trail = s.position_history
    # Fewer than five samples cannot show an A→B→A→B→A pattern.
    if len(trail) < 5:
        return False

    # Only slice when there is something to trim off the front.
    window = trail if len(trail) < 6 else trail[-6:]
    unique_positions = set(window)
    if len(unique_positions) > 2:
        return False

    if DEBUG:
        print(f"[A{s.agent_id}] LOOP_DETECTED: Oscillating between {unique_positions}")
    return True
|
|
1273
|
+
|
|
1274
|
+
def _break_location_loop(self, s: CogsguardAgentState) -> Optional[Action]:
    """Escape a detected location loop with a random sidestep.

    Before moving, the cached path and its target are discarded so the next
    pathfinding call starts fresh, and ``s.position_history`` is emptied so
    loop detection restarts from a clean slate.

    Args:
        s: Agent state whose path cache and position history are reset.

    Returns:
        Whatever ``_try_random_direction`` returns: a move action, or ``None``
        when no free neighbouring cell was found.
    """
    if DEBUG:
        print(f"[A{s.agent_id}] BREAKING_LOOP: Attempting random move to escape")

    # Forget the oscillation record so the detector does not fire again immediately.
    s.position_history.clear()

    # Drop the cached route; it is what led into the loop.
    s.cached_path_target = None
    s.cached_path = None

    return self._try_random_direction(s)
|
|
1290
|
+
|
|
1291
|
+
|
|
1292
|
+
# =============================================================================
|
|
1293
|
+
# Policy wrapper
|
|
1294
|
+
# =============================================================================
|
|
1295
|
+
|
|
1296
|
+
|
|
1297
|
+
def _parse_vibe_list(raw: Optional[str]) -> list[str]:
    """Split a comma-separated vibe string into normalized names.

    Each entry is stripped of surrounding whitespace and lower-cased; entries
    that are empty after stripping are dropped.

    Args:
        raw: Comma-separated vibe names, or ``None``/empty for "not given".

    Returns:
        The cleaned names in their original order; an empty list when ``raw``
        is ``None`` or empty.
    """
    if not raw:
        return []
    trimmed = (piece.strip() for piece in raw.split(","))
    return [piece.lower() for piece in trimmed if piece]
|
|
1301
|
+
|
|
1302
|
+
|
|
1303
|
+
class CogsguardPolicy(MultiAgentPolicy):
    """Multi-agent policy for CoGsGuard with vibe-based role selection.

    Agents use vibes to determine their behavior:
    - default: do nothing
    - gear: pick a role via smart or evolutionary coordinator, change vibe to that role
    - miner/scout/aligner/scrambler: get gear then execute role
    - heart: do nothing

    Initial vibe counts can be specified via URI query parameters:
        metta://policy/role_py?miner=4&scrambler=2&gear=1
    You can also set a fixed role pattern with:
        metta://policy/role_py?role_cycle=aligner,miner,scrambler,scout
        metta://policy/role_py?role_order=aligner,miner,aligner,miner,scout

    ``role_order`` takes precedence over ``role_cycle``, which takes precedence
    over per-vibe counts. Evolutionary role selection is enabled when any of the
    keyword flags ``evolution``, ``evolutionary`` or ``evolve`` is truthy (a bool,
    a non-zero int, or one of the strings "1", "true", "yes", "on").

    Vibes are assigned to agents in order. If counts don't sum to num_agents,
    remaining agents get "gear" vibe (which picks a role via the smart coordinator).
    """

    short_names = ["role_py"]

    def __init__(
        self,
        policy_env_info: PolicyEnvInterface,
        device: str = "cpu",
        role_cycle: Optional[str] = None,
        role_order: Optional[str] = None,
        **vibe_counts: int,
    ):
        """Build lookup tables and compute the initial vibe for each agent.

        Args:
            policy_env_info: Environment description (agent count, action names,
                observation features).
            device: Forwarded to the base policy.
            role_cycle: Comma-separated vibes repeated until every agent has one.
            role_order: Comma-separated vibes assigned to agents one by one.
            **vibe_counts: Per-vibe agent counts, plus the optional evolution
                flags described on the class.
        """
        super().__init__(policy_env_info, device=device)
        self._agent_policies: dict[int, StatefulAgentPolicy[CogsguardAgentState]] = {}
        self._smart_role_coordinator = _shared_coordinator(policy_env_info)
        self._feature_by_id = {feature.id: feature for feature in policy_env_info.obs_features}
        self._action_name_to_index = {name: idx for idx, name in enumerate(policy_env_info.action_names)}
        self._noop_action_value = dtype_actions.type(self._action_name_to_index.get("noop", 0))

        def _parse_flag(value: object) -> bool:
            """Interpret a bool, int, or string flag; anything else is False."""
            if isinstance(value, bool):
                return value
            if isinstance(value, int):
                return value != 0
            if isinstance(value, str):
                return value.strip().lower() in {"1", "true", "yes", "on"}
            return False

        # Pop EVERY alias before combining them. A short-circuiting `or` chain
        # would stop at the first truthy flag and leave the remaining aliases in
        # vibe_counts, where an int value would later be mistaken for a vibe count.
        evolution_flags = [
            _parse_flag(vibe_counts.pop(alias, None)) for alias in ("evolution", "evolutionary", "evolve")
        ]
        self._use_evolutionary_roles = any(evolution_flags)
        self._evolutionary_role_coordinator = (
            EvolutionaryRoleCoordinator(policy_env_info.num_agents) if self._use_evolutionary_roles else None
        )
        self._evolutionary_hooks_configured = False

        # A vibe is available only if the environment exposes a matching change_vibe_* action.
        available_vibes = {
            name[len("change_vibe_") :] for name in policy_env_info.action_names if name.startswith("change_vibe_")
        }
        role_vibes = [vibe for vibe in ["scrambler", "aligner", "miner", "scout"] if vibe in available_vibes]

        self._initial_vibes: list[str] = []
        role_cycle_list = _parse_vibe_list(role_cycle)
        role_order_list = _parse_vibe_list(role_order)

        if role_order_list:
            # Explicit per-agent order; unknown names fall back to "default".
            self._initial_vibes = []
            fallback_vibe = "default"
            for vibe in role_order_list:
                if vibe in available_vibes or vibe == "default":
                    self._initial_vibes.append(vibe)
                else:
                    if DEBUG:
                        print(f"[CogsguardPolicy] Unknown role_order vibe '{vibe}', using '{fallback_vibe}'")
                    self._initial_vibes.append(fallback_vibe)
            remaining = policy_env_info.num_agents - len(self._initial_vibes)
            if remaining > 0 and "gear" in available_vibes:
                self._initial_vibes.extend(["gear"] * remaining)
        elif role_cycle_list:
            # Repeat the (filtered) cycle until every agent has a vibe, then trim.
            cycle = [vibe for vibe in role_cycle_list if vibe in available_vibes]
            if cycle:
                while len(self._initial_vibes) < policy_env_info.num_agents:
                    self._initial_vibes.extend(cycle)
                self._initial_vibes = self._initial_vibes[: policy_env_info.num_agents]

        if not self._initial_vibes:
            # Build initial vibe assignment from URI params (e.g., ?scrambler=1&miner=4)
            counts = {k: v for k, v in vibe_counts.items() if isinstance(v, int)}
            if not counts and role_vibes:
                counts = {"scrambler": 1, "miner": 4}

            if role_vibes:
                for vibe_name in role_vibes:  # Role vibes first
                    self._initial_vibes.extend([vibe_name] * counts.get(vibe_name, 0))
            # Add gear vibes (agents will pick a role)
            if "gear" in available_vibes:
                self._initial_vibes.extend(["gear"] * counts.get("gear", 0))
            remaining = policy_env_info.num_agents - len(self._initial_vibes)
            if remaining > 0 and "gear" in available_vibes:
                self._initial_vibes.extend(["gear"] * remaining)

        if DEBUG:
            print(f"[CogsguardPolicy] Initial vibe assignment: {self._initial_vibes}")

    def agent_policy(self, agent_id: int) -> StatefulAgentPolicy[CogsguardAgentState]:
        """Return the cached policy for ``agent_id``, creating it on first use.

        New agents get a ``CogsguardMultiRoleImpl`` whose initial target vibe is
        taken from the assignment computed in ``__init__``.
        """
        if agent_id not in self._agent_policies:
            # Create a multi-role implementation that can handle any role
            # The actual role is determined by vibe at runtime
            # Assign initial target vibe based on agent_id and configured counts
            target_vibe: Optional[str] = None
            if agent_id < len(self._initial_vibes):
                target_vibe = self._initial_vibes[agent_id]
            # Agents without assigned vibes stay on "default" (noop)

            impl = CogsguardMultiRoleImpl(
                self._policy_env_info,
                agent_id,
                initial_target_vibe=target_vibe,
                smart_role_coordinator=self._smart_role_coordinator,
                evolutionary_role_coordinator=self._evolutionary_role_coordinator,
                use_evolutionary_roles=self._use_evolutionary_roles,
            )
            # Behavior hooks are registered once, using the first impl created.
            if self._evolutionary_role_coordinator is not None and not self._evolutionary_hooks_configured:
                from .behavior_hooks import build_cogsguard_behavior_hooks

                self._evolutionary_role_coordinator.behavior_hooks.update(build_cogsguard_behavior_hooks(impl))
                self._evolutionary_hooks_configured = True
            self._agent_policies[agent_id] = StatefulAgentPolicy(impl, self._policy_env_info, agent_id=agent_id)

        return self._agent_policies[agent_id]

    def step_batch(self, raw_observations: np.ndarray, raw_actions: np.ndarray) -> None:
        """Step every agent once and write the chosen action indices into ``raw_actions``.

        Only the first ``min(raw_observations.shape[0], num_agents)`` rows are
        processed. Action names missing from the environment map to the noop index.
        """
        num_agents = min(raw_observations.shape[0], self._policy_env_info.num_agents)
        for agent_id in range(num_agents):
            obs = self._raw_obs_to_agent_obs(agent_id, raw_observations[agent_id])
            action = self.agent_policy(agent_id).step(obs)
            action_index = self._action_name_to_index.get(action.name, self._noop_action_value)
            raw_actions[agent_id] = dtype_actions.type(action_index)

    def _raw_obs_to_agent_obs(self, agent_id: int, raw_obs: np.ndarray) -> AgentObservation:
        """Convert one agent's raw token rows into an ``AgentObservation``.

        Each row is read as ``(packed location, feature id, value)``. A feature id
        of ``0xFF`` ends the scan; rows with an unknown feature id are skipped.
        """
        tokens: list[ObservationToken] = []
        for token in raw_obs:
            feature_id = int(token[1])
            if feature_id == 0xFF:
                break
            feature = self._feature_by_id.get(feature_id)
            if feature is None:
                continue
            location_packed = int(token[0])
            value = int(token[2])
            tokens.append(
                ObservationToken(
                    feature=feature,
                    value=value,
                    raw_token=(location_packed, feature_id, value),
                )
            )
        return AgentObservation(agent_id=agent_id, tokens=tokens)
|
|
1460
|
+
|
|
1461
|
+
|
|
1462
|
+
class CogsguardMultiRoleImpl(CogsguardAgentPolicyImpl):
    """Agent implementation whose behavior follows the agent's current vibe.

    An optional initial target vibe is applied first. Afterwards the current
    vibe decides what happens: idle vibes do nothing, "gear" picks a role vibe,
    and role vibes run the matching role-specific implementation.
    """

    def __init__(
        self,
        policy_env_info: PolicyEnvInterface,
        agent_id: int,
        initial_target_vibe: Optional[str] = None,
        smart_role_coordinator: Optional[SmartRoleCoordinator] = None,
        evolutionary_role_coordinator: Optional[EvolutionaryRoleCoordinator] = None,
        use_evolutionary_roles: bool = False,
    ):
        """Set up the base implementation and remember the initial target vibe.

        Smart role switching is enabled only when the initial target vibe is "gear".
        """
        # MINER is only a starting value; the role is updated from the vibe later.
        super().__init__(
            policy_env_info,
            agent_id,
            Role.MINER,
            smart_role_coordinator=smart_role_coordinator,
            evolutionary_role_coordinator=evolutionary_role_coordinator,
            use_evolutionary_roles=use_evolutionary_roles,
        )

        # Role implementations are created lazily, one per role.
        self._role_impls: dict[Role, CogsguardAgentPolicyImpl] = {}

        # Vibe to switch to at start (if any) and whether that switch is finished.
        self._initial_target_vibe = initial_target_vibe
        self._initial_vibe_set = False
        self._smart_role_enabled = initial_target_vibe == "gear"

    def _execute_phase(self, s: CogsguardAgentState) -> Action:
        """Choose this step's action, applying the initial vibe before anything else.

        Without a configured initial vibe the base class phase logic runs
        unchanged (including its hardcoded agent 0 scrambler handling). With
        one, the agent first switches to that vibe and from then on uses
        ``_execute_vibe_behavior``.
        """
        target_vibe = self._initial_target_vibe
        if not target_vibe:
            # Nothing configured: normal phase execution.
            return super()._execute_phase(s)

        if not self._initial_vibe_set:
            if not self._has_vibe(target_vibe):
                # The requested vibe cannot be selected; give up on it and on smart switching.
                self._initial_vibe_set = True
                self._smart_role_enabled = False
                return self._noop()
            if s.current_vibe != target_vibe:
                if DEBUG:
                    print(
                        f"[A{s.agent_id}] INITIAL_VIBE: Switching from {s.current_vibe} to {self._initial_target_vibe}"
                    )
                return change_vibe_action(target_vibe, action_names=self._action_names)
            self._initial_vibe_set = True

        return self._execute_vibe_behavior(s)

    def _execute_vibe_behavior(self, s: CogsguardAgentState) -> Action:
        """Act on the current vibe, bypassing the base class's agent 0 scrambler override."""
        vibe = s.current_vibe

        # "default" and "heart" are idle vibes.
        if vibe in ("default", "heart"):
            return self._noop()

        # "gear" means: choose a role and switch to its vibe.
        if vibe == "gear":
            selected_role = self._choose_role_vibe(s)
            if DEBUG:
                print(f"[A{s.agent_id}] GEAR_VIBE: Picking role vibe: {selected_role}")
            if self._has_vibe(selected_role):
                return change_vibe_action(selected_role, action_names=self._action_names)
            return self._noop()

        # Role vibes: possibly switch role first, otherwise run the current phase.
        if vibe in VIBE_TO_ROLE:
            switch_action = self._maybe_switch_smart_role(s)
            if switch_action is not None:
                return switch_action
            if s.phase == CogsguardPhase.GET_GEAR:
                return self._do_get_gear(s)
            if s.phase == CogsguardPhase.EXECUTE_ROLE:
                return self.execute_role(s)

        return self._noop()

    def _maybe_switch_smart_role(self, s: CogsguardAgentState) -> Optional[Action]:
        """Return a vibe-change action when the agent should change role, else ``None``.

        No switch is attempted while smart roles are disabled, an action is
        pending, the agent is still getting gear, or the role lock is active.
        Gear already held wins over the coordinator's choice.
        """
        coordinator = self._smart_role_coordinator
        if not self._smart_role_enabled or coordinator is None:
            return None

        switching_blocked = (
            s._pending_action_type is not None
            or s.phase == CogsguardPhase.GET_GEAR
            or s.step_count < s.role_lock_until_step
        )
        if switching_blocked:
            return None

        # First gear the agent holds, checked in this fixed priority order.
        held_gear = next(
            (name for name in ("aligner", "scrambler", "miner", "scout") if getattr(s, name) > 0),
            None,
        )
        if held_gear and held_gear != s.current_vibe:
            if self._has_vibe(held_gear):
                return change_vibe_action(held_gear, action_names=self._action_names)
            return None

        selected_role = coordinator.choose_role(s.agent_id)
        if selected_role == s.current_vibe:
            return None

        # Record the switch and start the cooldown before emitting the action.
        s.last_role_switch_step = s.step_count
        s.role_lock_until_step = s.step_count + SMART_ROLE_SWITCH_COOLDOWN
        if DEBUG:
            print(f"[A{s.agent_id}] SMART_ROLE: Switching to {selected_role}")
        if not self._has_vibe(selected_role):
            return None
        return change_vibe_action(selected_role, action_names=self._action_names)

    def _get_role_impl(self, role: Role) -> CogsguardAgentPolicyImpl:
        """Return the implementation for ``role``, building and caching it on first request."""
        if role in self._role_impls:
            return self._role_impls[role]

        from .aligner import AlignerAgentPolicyImpl
        from .miner import MinerAgentPolicyImpl
        from .scout import ScoutAgentPolicyImpl
        from .scrambler import ScramblerAgentPolicyImpl

        impl_by_role = {
            Role.MINER: MinerAgentPolicyImpl,
            Role.SCOUT: ScoutAgentPolicyImpl,
            Role.ALIGNER: AlignerAgentPolicyImpl,
            Role.SCRAMBLER: ScramblerAgentPolicyImpl,
        }
        impl_class = impl_by_role[role]
        self._role_impls[role] = impl_class(
            self._policy_env_info,
            self._agent_id,
            role,
            smart_role_coordinator=self._smart_role_coordinator,
        )
        return self._role_impls[role]

    def execute_role(self, s: CogsguardAgentState) -> Action:
        """Run the role-specific implementation that matches ``s.role``."""
        return self._get_role_impl(s.role).execute_role(s)
|
|
1618
|
+
|
|
1619
|
+
|
|
1620
|
+
class CogsguardGeneralistImpl(CogsguardAgentPolicyImpl):
    """Generalist agent that re-evaluates its role every step from what it currently knows."""

    ROLE = Role.MINER
    ROLE_SWITCH_COOLDOWN = 120
    EARLY_SCOUT_STEPS = 80
    MIN_STRUCTURES_FOR_MIDGAME = 6
    MIN_ENERGY_BUFFER = 2

    def __init__(
        self,
        policy_env_info: PolicyEnvInterface,
        agent_id: int,
        smart_role_coordinator: Optional[SmartRoleCoordinator] = None,
    ):
        """Start as a miner with an empty cache of role implementations."""
        super().__init__(policy_env_info, agent_id, Role.MINER, smart_role_coordinator=smart_role_coordinator)
        self._role_impls: dict[Role, CogsguardAgentPolicyImpl] = {}

    def _update_phase(self, s: CogsguardAgentState) -> None:
        """Switch role if warranted, then set the phase.

        The phase is EXECUTE_ROLE once the agent has gear or more than 30 steps
        have passed; otherwise it is GET_GEAR.
        """
        desired_role = self._select_role(s)
        if self._should_switch_role(s, desired_role):
            if DEBUG:
                print(f"[A{s.agent_id}] GENERALIST: Switching role {s.role.value} -> {desired_role.value}")
            s.role = desired_role
            s.last_role_switch_step = s.step_count
            s.role_lock_until_step = s.step_count + self.ROLE_SWITCH_COOLDOWN

        ready_to_act = s.has_gear() or s.step_count > 30
        s.phase = CogsguardPhase.EXECUTE_ROLE if ready_to_act else CogsguardPhase.GET_GEAR

    def _execute_phase(self, s: CogsguardAgentState) -> Action:
        """Recharge when needed; otherwise act according to the current phase."""
        if self._should_recharge(s):
            return self._do_recharge(s)

        phase = s.phase
        if phase == CogsguardPhase.GET_GEAR:
            return self._do_get_gear(s)
        if phase == CogsguardPhase.EXECUTE_ROLE:
            return self.execute_role(s)
        return self._noop()

    def execute_role(self, s: CogsguardAgentState) -> Action:
        """Run the role-specific implementation that matches ``s.role``."""
        return self._get_role_impl(s.role).execute_role(s)

    def _select_role(self, s: CogsguardAgentState) -> Role:
        """Pick the role this agent would ideally hold right now.

        Checks are applied in priority order: keep the role while an action is
        pending, scout until the hub is known, align while an alignment target
        is pending, scout early in the episode, keep working junctions as
        scrambler/aligner, fill team role deficits, then fall back to
        junction work, scouting, mining, or a balanced pick.
        """
        # Never change plans in the middle of a pending action.
        if s._pending_action_type is not None:
            return s.role

        hub_known = s.stations.get("hub") is not None
        chest_known = s.stations.get("chest") is not None

        if not hub_known:
            return Role.SCOUT

        if s._pending_alignment_target is not None:
            return Role.ALIGNER

        still_early = s.step_count < self.EARLY_SCOUT_STEPS
        if still_early and (not chest_known or len(s.structures) < self.MIN_STRUCTURES_FOR_MIDGAME):
            return Role.SCOUT

        junctions = s.get_structures_by_type(StructureType.CHARGER)
        enemy_present = any(junction.alignment == "clips" for junction in junctions)
        neutral_present = any(junction.alignment in (None, "neutral") for junction in junctions)
        role_counts = self._role_counts()

        # Agents already doing junction work keep at it while targets remain.
        if s.role == Role.SCRAMBLER:
            if neutral_present and self._role_is_ready(s, Role.ALIGNER):
                return Role.ALIGNER
            if enemy_present:
                return Role.SCRAMBLER
        if s.role == Role.ALIGNER:
            only_enemy_left = (
                s._pending_alignment_target is None
                and not neutral_present
                and enemy_present
                and self._role_is_ready(s, Role.SCRAMBLER)
            )
            if only_enemy_left:
                return Role.SCRAMBLER
            if enemy_present or neutral_present:
                return Role.ALIGNER

        target_counts = self._target_role_counts(
            num_agents=self._policy_env_info.num_agents,
            has_enemy_junctions=enemy_present,
            has_neutral_junctions=neutral_present,
            hub_known=hub_known,
            step_count=s.step_count,
        )
        deficit_role = self._pick_deficit_role(s, role_counts, target_counts)
        if deficit_role is not None:
            return deficit_role

        if enemy_present and self._role_is_ready(s, Role.SCRAMBLER):
            return Role.SCRAMBLER
        if neutral_present and self._role_is_ready(s, Role.ALIGNER):
            return Role.ALIGNER

        if not s.get_usable_extractors():
            return Role.SCOUT

        if s.total_cargo < s.cargo_capacity - 2:
            return Role.MINER

        return self._pick_balanced_role(s, enemy_present, neutral_present)

    def _should_switch_role(self, s: CogsguardAgentState, desired_role: Role) -> bool:
        """Decide whether to actually adopt ``desired_role`` now.

        A pending alignment target, or a scrambler seeing a neutral junction,
        bypasses the role-switch cooldown; every other switch waits for it.
        """
        if desired_role == s.role or s._pending_action_type is not None:
            return False

        if desired_role == Role.ALIGNER:
            if s._pending_alignment_target is not None:
                return True
            if s.role == Role.SCRAMBLER:
                neutral_seen = any(
                    junction.alignment in (None, "neutral")
                    for junction in s.get_structures_by_type(StructureType.CHARGER)
                )
                if neutral_seen:
                    return True

        return not (s.step_count < s.role_lock_until_step)

    def _role_is_ready(self, s: CogsguardAgentState, role: Role) -> bool:
        """Return False only for aligner/scrambler while the hub is unknown; True otherwise."""
        needs_hub = role in (Role.ALIGNER, Role.SCRAMBLER)
        return not (needs_hub and s.stations.get("hub") is None)

    def _target_role_counts(
        self,
        num_agents: int,
        has_enemy_junctions: bool,
        has_neutral_junctions: bool,
        hub_known: bool,
        step_count: int,
    ) -> dict[Role, int]:
        """Return how many agents the team should have in each role.

        Scouts: 2 early on when there are at least 8 agents, otherwise 1.
        Miners: ``max(4, num_agents // 2)``. Scramblers and aligners are only
        targeted once the hub is known (2 when matching junctions exist, else 1).
        """
        early_big_team = step_count < self.EARLY_SCOUT_STEPS and num_agents >= 8
        targets: dict[Role, int] = {
            Role.SCOUT: 2 if early_big_team else 1,
            Role.MINER: max(4, num_agents // 2),
        }
        if hub_known:
            targets[Role.SCRAMBLER] = 2 if has_enemy_junctions else 1
            targets[Role.ALIGNER] = 2 if has_neutral_junctions else 1
        return targets

    def _pick_deficit_role(
        self,
        s: CogsguardAgentState,
        role_counts: dict[Role, int],
        target_counts: dict[Role, int],
    ) -> Role | None:
        """Pick an under-staffed role for this agent, or ``None``.

        Every missing slot contributes one entry to a list ordered scrambler,
        aligner, scout, miner. The agent id indexes into that list (modulo its
        length); the result is returned only if that role is ready.
        """
        if not target_counts:
            return None

        open_slots: list[Role] = []
        for role in (Role.SCRAMBLER, Role.ALIGNER, Role.SCOUT, Role.MINER):
            wanted = target_counts.get(role, 0)
            if wanted > 0:
                missing = max(wanted - role_counts.get(role, 0), 0)
                open_slots.extend([role] * missing)

        if not open_slots:
            return None

        choice = open_slots[s.agent_id % len(open_slots)]
        return choice if self._role_is_ready(s, choice) else None

    def _should_recharge(self, s: CogsguardAgentState) -> bool:
        """Recharge only with no cargo, energy below the buffer, and a known hub."""
        if s.total_cargo > 0 or s.energy >= s.MOVE_ENERGY_COST * self.MIN_ENERGY_BUFFER:
            return False
        return s.stations.get("hub") is not None

    def _role_has_gear(self, s: CogsguardAgentState, role: Role) -> bool:
        """Return True when the state attribute named by ``ROLE_TO_GEAR[role]`` is positive."""
        gear_attr = ROLE_TO_GEAR[role]
        return getattr(s, gear_attr, 0) > 0

    def _role_station_known(self, s: CogsguardAgentState, role: Role) -> bool:
        """Return True when the station named by ``ROLE_TO_STATION[role]`` is in ``s.stations``."""
        station_name = ROLE_TO_STATION[role]
        return s.stations.get(station_name) is not None

    def _pick_balanced_role(
        self,
        s: CogsguardAgentState,
        has_enemy_junctions: bool,
        has_neutral_junctions: bool,
    ) -> Role:
        """Score the candidate roles and return the best; the earliest candidate wins ties.

        Miner and scout are always candidates; scrambler and aligner join when
        matching junctions exist and the role is ready. Scoring: +2 for the
        current role, +3 for holding the role's gear (or +1 for knowing its
        station), and ``2 - count`` when team role counts are available.
        """
        candidates = [Role.MINER, Role.SCOUT]
        if has_enemy_junctions and self._role_is_ready(s, Role.SCRAMBLER):
            candidates.append(Role.SCRAMBLER)
        if has_neutral_junctions and self._role_is_ready(s, Role.ALIGNER):
            candidates.append(Role.ALIGNER)

        role_counts = self._role_counts()

        def score_of(role: Role) -> int:
            points = 2 if role == s.role else 0
            if self._role_has_gear(s, role):
                points += 3
            elif self._role_station_known(s, role):
                points += 1
            if role_counts:
                points += 2 - role_counts.get(role, 0)
            return points

        # max() returns the first maximal element, so earlier candidates win ties.
        return max(candidates, key=score_of)

    def _role_counts(self) -> dict[Role, int]:
        """Count agents per role from the coordinator's snapshots; ``{}`` without a coordinator."""
        coordinator = self._smart_role_coordinator
        if coordinator is None:
            return {}
        tally = dict.fromkeys(Role, 0)
        for snapshot in coordinator.agent_snapshots.values():
            tally[snapshot.role] += 1
        return tally

    def _get_role_impl(self, role: Role) -> CogsguardAgentPolicyImpl:
        """Return the implementation for ``role``, building and caching it on first request."""
        if role in self._role_impls:
            return self._role_impls[role]

        from .aligner import AlignerAgentPolicyImpl
        from .miner import MinerAgentPolicyImpl
        from .scout import ScoutAgentPolicyImpl
        from .scrambler import ScramblerAgentPolicyImpl

        impl_by_role = {
            Role.MINER: MinerAgentPolicyImpl,
            Role.SCOUT: ScoutAgentPolicyImpl,
            Role.ALIGNER: AlignerAgentPolicyImpl,
            Role.SCRAMBLER: ScramblerAgentPolicyImpl,
        }
        impl_class = impl_by_role[role]
        self._role_impls[role] = impl_class(
            self._policy_env_info,
            self._agent_id,
            role,
            smart_role_coordinator=self._smart_role_coordinator,
        )
        return self._role_impls[role]
|
|
1866
|
+
|
|
1867
|
+
|
|
1868
|
+
class CogsguardWomboImpl(CogsguardGeneralistImpl):
    """Generalist variant that pushes junction work until enough junctions are aligned.

    While the coordinator reports fewer than ``TARGET_ALIGNED_JUNCTIONS``
    aligned junctions, geared scramblers/aligners keep their role and the
    team's role targets are raised.
    """

    TARGET_ALIGNED_JUNCTIONS = 2
    JUNCTION_PUSH_SCOUTS = 2
    JUNCTION_PUSH_ALIGNERS = 2
    JUNCTION_PUSH_SCRAMBLERS = 2
    MIN_MINERS = 4

    def _aligned_count(self) -> int:
        """Return the coordinator's aligned-junction count, or 0 without a coordinator."""
        coordinator = self._smart_role_coordinator
        if coordinator is None:
            return 0
        return coordinator.aligned_junction_count()

    def _select_role(self, s: CogsguardAgentState) -> Role:
        """Hold junction roles during the push; otherwise defer to the generalist selection."""
        if self._aligned_count() >= self.TARGET_ALIGNED_JUNCTIONS:
            return super()._select_role(s)

        if s._pending_action_type is not None:
            return s.role
        if s.stations.get("hub") is None:
            return Role.SCOUT
        # A geared scrambler/aligner stays put while the push is on.
        if s.role in (Role.SCRAMBLER, Role.ALIGNER) and s.has_gear():
            return s.role
        if s.role == Role.SCRAMBLER and s._pending_alignment_target is not None:
            return Role.SCRAMBLER
        return super()._select_role(s)

    def _should_recharge(self, s: CogsguardAgentState) -> bool:
        """Decide whether to recharge.

        Scramblers/aligners in the push recharge only with no cargo, energy
        below the buffer, and a known hub; everyone else uses the generalist rule.
        """
        in_push = self._aligned_count() < self.TARGET_ALIGNED_JUNCTIONS and s.role in (
            Role.SCRAMBLER,
            Role.ALIGNER,
        )
        if not in_push:
            return super()._should_recharge(s)

        if s.total_cargo > 0 or s.energy >= s.MOVE_ENERGY_COST * self.MIN_ENERGY_BUFFER:
            return False
        return s.stations.get("hub") is not None

    def _target_role_counts(
        self,
        num_agents: int,
        has_enemy_junctions: bool,
        has_neutral_junctions: bool,
        hub_known: bool,
        step_count: int,
    ) -> dict[Role, int]:
        """Return the generalist targets, raised to the push minimums while under target."""
        targets = super()._target_role_counts(
            num_agents=num_agents,
            has_enemy_junctions=has_enemy_junctions,
            has_neutral_junctions=has_neutral_junctions,
            hub_known=hub_known,
            step_count=step_count,
        )

        if self._aligned_count() >= self.TARGET_ALIGNED_JUNCTIONS:
            return targets

        targets[Role.SCOUT] = max(targets.get(Role.SCOUT, 0), self.JUNCTION_PUSH_SCOUTS)
        if hub_known:
            targets[Role.SCRAMBLER] = max(targets.get(Role.SCRAMBLER, 0), self.JUNCTION_PUSH_SCRAMBLERS)
            targets[Role.ALIGNER] = max(targets.get(Role.ALIGNER, 0), self.JUNCTION_PUSH_ALIGNERS)
        targets[Role.MINER] = max(self.MIN_MINERS, num_agents // 2)
        return targets
|
|
1932
|
+
|
|
1933
|
+
|
|
1934
|
+
class CogsguardGeneralistPolicy(CogsguardPolicy):
    """Policy that runs every agent with ``CogsguardGeneralistImpl``.

    Extra keyword arguments are forwarded to ``CogsguardPolicy.__init__``. The
    per-agent implementations built here do not read the initial vibe assignment.
    """

    def __init__(self, policy_env_info: PolicyEnvInterface, device: str = "cpu", **_ignored: int):
        """Initialize the base policy with the given environment info and device."""
        super().__init__(policy_env_info, device=device, **_ignored)

    def agent_policy(self, agent_id: int) -> StatefulAgentPolicy[CogsguardAgentState]:
        """Return the cached generalist policy for ``agent_id``, creating it on first use."""
        policies = self._agent_policies
        if agent_id not in policies:
            generalist = CogsguardGeneralistImpl(
                self._policy_env_info,
                agent_id,
                smart_role_coordinator=self._smart_role_coordinator,
            )
            policies[agent_id] = StatefulAgentPolicy(generalist, self._policy_env_info, agent_id=agent_id)
        return policies[agent_id]
|
|
1949
|
+
|
|
1950
|
+
|
|
1951
|
+
class CogsguardWomboPolicy(CogsguardPolicy):
    """Policy that runs every agent with ``CogsguardWomboImpl``.

    Extra keyword arguments are forwarded to ``CogsguardPolicy.__init__``. The
    per-agent implementations built here do not read the initial vibe assignment.
    """

    short_names = ["wombo"]

    def __init__(self, policy_env_info: PolicyEnvInterface, device: str = "cpu", **_ignored: int):
        """Initialize the base policy with the given environment info and device."""
        super().__init__(policy_env_info, device=device, **_ignored)

    def agent_policy(self, agent_id: int) -> StatefulAgentPolicy[CogsguardAgentState]:
        """Return the cached wombo policy for ``agent_id``, creating it on first use."""
        policies = self._agent_policies
        if agent_id not in policies:
            wombo = CogsguardWomboImpl(
                self._policy_env_info,
                agent_id,
                smart_role_coordinator=self._smart_role_coordinator,
            )
            policies[agent_id] = StatefulAgentPolicy(wombo, self._policy_env_info, agent_id=agent_id)
        return policies[agent_id]
|