cogames-agents 0.0.0.7__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128)
  1. cogames_agents/__init__.py +0 -0
  2. cogames_agents/evals/__init__.py +5 -0
  3. cogames_agents/evals/planky_evals.py +415 -0
  4. cogames_agents/policy/__init__.py +0 -0
  5. cogames_agents/policy/evolution/__init__.py +0 -0
  6. cogames_agents/policy/evolution/cogsguard/__init__.py +0 -0
  7. cogames_agents/policy/evolution/cogsguard/evolution.py +695 -0
  8. cogames_agents/policy/evolution/cogsguard/evolutionary_coordinator.py +540 -0
  9. cogames_agents/policy/nim_agents/__init__.py +20 -0
  10. cogames_agents/policy/nim_agents/agents.py +98 -0
  11. cogames_agents/policy/nim_agents/bindings/generated/libnim_agents.dylib +0 -0
  12. cogames_agents/policy/nim_agents/bindings/generated/nim_agents.py +215 -0
  13. cogames_agents/policy/nim_agents/cogsguard_agents.nim +555 -0
  14. cogames_agents/policy/nim_agents/cogsguard_align_all_agents.nim +569 -0
  15. cogames_agents/policy/nim_agents/common.nim +1054 -0
  16. cogames_agents/policy/nim_agents/install.sh +1 -0
  17. cogames_agents/policy/nim_agents/ladybug_agent.nim +954 -0
  18. cogames_agents/policy/nim_agents/nim_agents.nim +68 -0
  19. cogames_agents/policy/nim_agents/nim_agents.nims +14 -0
  20. cogames_agents/policy/nim_agents/nimby.lock +3 -0
  21. cogames_agents/policy/nim_agents/racecar_agents.nim +844 -0
  22. cogames_agents/policy/nim_agents/random_agents.nim +68 -0
  23. cogames_agents/policy/nim_agents/test_agents.py +53 -0
  24. cogames_agents/policy/nim_agents/thinky_agents.nim +677 -0
  25. cogames_agents/policy/nim_agents/thinky_eval.py +230 -0
  26. cogames_agents/policy/scripted_agent/README.md +360 -0
  27. cogames_agents/policy/scripted_agent/__init__.py +0 -0
  28. cogames_agents/policy/scripted_agent/baseline_agent.py +1031 -0
  29. cogames_agents/policy/scripted_agent/cogas/__init__.py +5 -0
  30. cogames_agents/policy/scripted_agent/cogas/context.py +68 -0
  31. cogames_agents/policy/scripted_agent/cogas/entity_map.py +152 -0
  32. cogames_agents/policy/scripted_agent/cogas/goal.py +115 -0
  33. cogames_agents/policy/scripted_agent/cogas/goals/__init__.py +27 -0
  34. cogames_agents/policy/scripted_agent/cogas/goals/aligner.py +160 -0
  35. cogames_agents/policy/scripted_agent/cogas/goals/gear.py +197 -0
  36. cogames_agents/policy/scripted_agent/cogas/goals/miner.py +441 -0
  37. cogames_agents/policy/scripted_agent/cogas/goals/scout.py +40 -0
  38. cogames_agents/policy/scripted_agent/cogas/goals/scrambler.py +174 -0
  39. cogames_agents/policy/scripted_agent/cogas/goals/shared.py +160 -0
  40. cogames_agents/policy/scripted_agent/cogas/goals/stem.py +60 -0
  41. cogames_agents/policy/scripted_agent/cogas/goals/survive.py +100 -0
  42. cogames_agents/policy/scripted_agent/cogas/navigator.py +401 -0
  43. cogames_agents/policy/scripted_agent/cogas/obs_parser.py +238 -0
  44. cogames_agents/policy/scripted_agent/cogas/policy.py +525 -0
  45. cogames_agents/policy/scripted_agent/cogas/trace.py +69 -0
  46. cogames_agents/policy/scripted_agent/cogsguard/CLAUDE.md +517 -0
  47. cogames_agents/policy/scripted_agent/cogsguard/README.md +252 -0
  48. cogames_agents/policy/scripted_agent/cogsguard/__init__.py +74 -0
  49. cogames_agents/policy/scripted_agent/cogsguard/aligned_junction_held_investigation.md +152 -0
  50. cogames_agents/policy/scripted_agent/cogsguard/aligner.py +333 -0
  51. cogames_agents/policy/scripted_agent/cogsguard/behavior_hooks.py +44 -0
  52. cogames_agents/policy/scripted_agent/cogsguard/control_agent.py +323 -0
  53. cogames_agents/policy/scripted_agent/cogsguard/debug_agent.py +533 -0
  54. cogames_agents/policy/scripted_agent/cogsguard/miner.py +589 -0
  55. cogames_agents/policy/scripted_agent/cogsguard/options.py +67 -0
  56. cogames_agents/policy/scripted_agent/cogsguard/parity_metrics.py +36 -0
  57. cogames_agents/policy/scripted_agent/cogsguard/policy.py +1967 -0
  58. cogames_agents/policy/scripted_agent/cogsguard/prereq_trace.py +33 -0
  59. cogames_agents/policy/scripted_agent/cogsguard/role_trace.py +50 -0
  60. cogames_agents/policy/scripted_agent/cogsguard/roles.py +31 -0
  61. cogames_agents/policy/scripted_agent/cogsguard/rollout_trace.py +40 -0
  62. cogames_agents/policy/scripted_agent/cogsguard/scout.py +69 -0
  63. cogames_agents/policy/scripted_agent/cogsguard/scrambler.py +350 -0
  64. cogames_agents/policy/scripted_agent/cogsguard/targeted_agent.py +418 -0
  65. cogames_agents/policy/scripted_agent/cogsguard/teacher.py +224 -0
  66. cogames_agents/policy/scripted_agent/cogsguard/types.py +381 -0
  67. cogames_agents/policy/scripted_agent/cogsguard/v2_agent.py +49 -0
  68. cogames_agents/policy/scripted_agent/common/__init__.py +0 -0
  69. cogames_agents/policy/scripted_agent/common/geometry.py +24 -0
  70. cogames_agents/policy/scripted_agent/common/roles.py +34 -0
  71. cogames_agents/policy/scripted_agent/common/tag_utils.py +48 -0
  72. cogames_agents/policy/scripted_agent/demo_policy.py +242 -0
  73. cogames_agents/policy/scripted_agent/pathfinding.py +126 -0
  74. cogames_agents/policy/scripted_agent/pinky/DESIGN.md +317 -0
  75. cogames_agents/policy/scripted_agent/pinky/__init__.py +5 -0
  76. cogames_agents/policy/scripted_agent/pinky/behaviors/__init__.py +17 -0
  77. cogames_agents/policy/scripted_agent/pinky/behaviors/aligner.py +400 -0
  78. cogames_agents/policy/scripted_agent/pinky/behaviors/base.py +119 -0
  79. cogames_agents/policy/scripted_agent/pinky/behaviors/miner.py +632 -0
  80. cogames_agents/policy/scripted_agent/pinky/behaviors/scout.py +138 -0
  81. cogames_agents/policy/scripted_agent/pinky/behaviors/scrambler.py +433 -0
  82. cogames_agents/policy/scripted_agent/pinky/policy.py +570 -0
  83. cogames_agents/policy/scripted_agent/pinky/services/__init__.py +7 -0
  84. cogames_agents/policy/scripted_agent/pinky/services/map_tracker.py +808 -0
  85. cogames_agents/policy/scripted_agent/pinky/services/navigator.py +864 -0
  86. cogames_agents/policy/scripted_agent/pinky/services/safety.py +189 -0
  87. cogames_agents/policy/scripted_agent/pinky/state.py +299 -0
  88. cogames_agents/policy/scripted_agent/pinky/types.py +138 -0
  89. cogames_agents/policy/scripted_agent/planky/CLAUDE.md +124 -0
  90. cogames_agents/policy/scripted_agent/planky/IMPROVEMENTS.md +160 -0
  91. cogames_agents/policy/scripted_agent/planky/NOTES.md +153 -0
  92. cogames_agents/policy/scripted_agent/planky/PLAN.md +254 -0
  93. cogames_agents/policy/scripted_agent/planky/README.md +214 -0
  94. cogames_agents/policy/scripted_agent/planky/STRATEGY.md +100 -0
  95. cogames_agents/policy/scripted_agent/planky/__init__.py +5 -0
  96. cogames_agents/policy/scripted_agent/planky/context.py +68 -0
  97. cogames_agents/policy/scripted_agent/planky/entity_map.py +152 -0
  98. cogames_agents/policy/scripted_agent/planky/goal.py +107 -0
  99. cogames_agents/policy/scripted_agent/planky/goals/__init__.py +27 -0
  100. cogames_agents/policy/scripted_agent/planky/goals/aligner.py +168 -0
  101. cogames_agents/policy/scripted_agent/planky/goals/gear.py +179 -0
  102. cogames_agents/policy/scripted_agent/planky/goals/miner.py +416 -0
  103. cogames_agents/policy/scripted_agent/planky/goals/scout.py +40 -0
  104. cogames_agents/policy/scripted_agent/planky/goals/scrambler.py +174 -0
  105. cogames_agents/policy/scripted_agent/planky/goals/shared.py +160 -0
  106. cogames_agents/policy/scripted_agent/planky/goals/stem.py +49 -0
  107. cogames_agents/policy/scripted_agent/planky/goals/survive.py +96 -0
  108. cogames_agents/policy/scripted_agent/planky/navigator.py +388 -0
  109. cogames_agents/policy/scripted_agent/planky/obs_parser.py +238 -0
  110. cogames_agents/policy/scripted_agent/planky/policy.py +485 -0
  111. cogames_agents/policy/scripted_agent/planky/tests/__init__.py +0 -0
  112. cogames_agents/policy/scripted_agent/planky/tests/conftest.py +66 -0
  113. cogames_agents/policy/scripted_agent/planky/tests/helpers.py +152 -0
  114. cogames_agents/policy/scripted_agent/planky/tests/test_aligner.py +24 -0
  115. cogames_agents/policy/scripted_agent/planky/tests/test_miner.py +30 -0
  116. cogames_agents/policy/scripted_agent/planky/tests/test_scout.py +15 -0
  117. cogames_agents/policy/scripted_agent/planky/tests/test_scrambler.py +29 -0
  118. cogames_agents/policy/scripted_agent/planky/tests/test_stem.py +36 -0
  119. cogames_agents/policy/scripted_agent/planky/trace.py +69 -0
  120. cogames_agents/policy/scripted_agent/types.py +239 -0
  121. cogames_agents/policy/scripted_agent/unclipping_agent.py +461 -0
  122. cogames_agents/policy/scripted_agent/utils.py +381 -0
  123. cogames_agents/policy/scripted_registry.py +80 -0
  124. cogames_agents/py.typed +0 -0
  125. cogames_agents-0.0.0.7.dist-info/METADATA +98 -0
  126. cogames_agents-0.0.0.7.dist-info/RECORD +128 -0
  127. cogames_agents-0.0.0.7.dist-info/WHEEL +6 -0
  128. cogames_agents-0.0.0.7.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1967 @@
1
+ """
2
+ CoGsGuard Scripted Agent - Vibe-based multi-agent policy.
3
+
4
+ Agents use vibes to determine their behavior:
5
+ - default: do nothing (noop)
6
+ - gear: pick a role via smart coordinator, change vibe to that role
7
+ - miner/scout/aligner/scrambler: get gear if needed, then execute role behavior
8
+ - heart: do nothing (noop)
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import random
14
+ from collections import deque
15
+ from dataclasses import dataclass, field
16
+ from typing import TYPE_CHECKING, Optional
17
+
18
+ import numpy as np
19
+
20
+ from cogames_agents.policy.evolution.cogsguard.evolutionary_coordinator import (
21
+ EvolutionaryRoleCoordinator,
22
+ )
23
+ from cogames_agents.policy.scripted_agent.pathfinding import (
24
+ compute_goal_cells,
25
+ shortest_path,
26
+ )
27
+ from cogames_agents.policy.scripted_agent.pathfinding import (
28
+ is_traversable as path_is_traversable,
29
+ )
30
+ from cogames_agents.policy.scripted_agent.pathfinding import (
31
+ is_within_bounds as path_is_within_bounds,
32
+ )
33
+ from cogames_agents.policy.scripted_agent.types import CellType, ObjectState, ParsedObservation
34
+ from cogames_agents.policy.scripted_agent.utils import (
35
+ add_inventory_token,
36
+ change_vibe_action,
37
+ has_type_tag,
38
+ is_adjacent,
39
+ is_station,
40
+ is_wall,
41
+ )
42
+ from cogames_agents.policy.scripted_agent.utils import (
43
+ parse_observation as utils_parse_observation,
44
+ )
45
+ from mettagrid.config.mettagrid_config import CardinalDirection
46
+ from mettagrid.mettagrid_c import dtype_actions
47
+ from mettagrid.policy.policy import MultiAgentPolicy, StatefulAgentPolicy, StatefulPolicyImpl
48
+ from mettagrid.policy.policy_env_interface import PolicyEnvInterface
49
+ from mettagrid.simulator import Action, AgentObservation, ObservationToken
50
+
51
+ from .types import (
52
+ ROLE_TO_GEAR,
53
+ ROLE_TO_STATION,
54
+ CogsguardAgentState,
55
+ CogsguardPhase,
56
+ Role,
57
+ StructureInfo,
58
+ StructureType,
59
+ )
60
+
61
# Role vibe names an agent can be assigned. The order is also the key order
# of the per-role tallies built by SmartRoleCoordinator.
ROLE_VIBES = ["scout", "miner", "aligner", "scrambler"]

# Reverse lookup: role vibe name -> Role.
VIBE_TO_ROLE = dict(
    miner=Role.MINER,
    scout=Role.SCOUT,
    aligner=Role.ALIGNER,
    scrambler=Role.SCRAMBLER,
)

# Tuning constants used elsewhere in this module (consumers not visible in
# this section; presumably step counts -- confirm at the use sites).
SMART_ROLE_SWITCH_COOLDOWN = 40
SCRAMBLER_GEAR_PRIORITY_STEPS = 25

# Process-wide cache of SmartRoleCoordinator instances, keyed by
# id(policy_env_info); populated by _shared_coordinator.
_GLOBAL_COORDINATORS: dict[int, "SmartRoleCoordinator"] = {}
75
# Strong references to the interfaces whose id() values key
# _GLOBAL_COORDINATORS. id() is only unique among objects alive at the same
# time. Holding the interface here keeps its id reserved, so a later interface
# cannot be allocated at the same address and silently inherit a stale
# coordinator (old snapshots, station offsets, junction overrides).
_GLOBAL_COORDINATOR_OWNERS: dict[int, PolicyEnvInterface] = {}


def _shared_coordinator(policy_env_info: PolicyEnvInterface) -> "SmartRoleCoordinator":
    """Return the SmartRoleCoordinator shared by all agents of ``policy_env_info``.

    Coordinators are cached per interface object. A new coordinator is created
    in three cases:

    - nothing is cached for this object;
    - the cached entry was registered for a different object that happened to
      have the same ``id()``;
    - the cached coordinator's ``num_agents`` no longer matches.

    Cache entries (coordinator and interface) are never evicted.

    Args:
        policy_env_info: The environment interface the agents were built from.

    Returns:
        The coordinator instance cached for this interface.
    """
    key = id(policy_env_info)
    coordinator = _GLOBAL_COORDINATORS.get(key)
    needs_new = (
        coordinator is None
        # Identity check guards against id() reuse after the previous owner died.
        or _GLOBAL_COORDINATOR_OWNERS.get(key) is not policy_env_info
        or coordinator.num_agents != policy_env_info.num_agents
    )
    if needs_new:
        coordinator = SmartRoleCoordinator(policy_env_info.num_agents)
        _GLOBAL_COORDINATORS[key] = coordinator
        _GLOBAL_COORDINATOR_OWNERS[key] = policy_env_info
    return coordinator
82
+
83
+
84
+ if TYPE_CHECKING:
85
+ from mettagrid.simulator.interface import AgentObservation
86
+
87
# Debug flag - set to True to print detailed per-agent behavior traces.
DEBUG = False

# Candidate gear-station cells as (row, col) offsets, tried in list order.
# Presumably relative to the hub, per the layout note below -- confirm at the
# use site.
GEAR_SEARCH_OFFSETS = [
    # BaseHub places stations ~4-5 rows below the hub, spaced by 2 columns.
    # Search those slots first to capture early gear windows.
    (4, -4),
    (4, -2),
    (4, 0),
    (4, 2),
    (4, 4),
    (5, -4),
    (5, -2),
    (5, 0),
    (5, 2),
    (5, 4),
    # Fallbacks for variations/tighter layouts: a sparser sixth row, then the
    # cardinal points at distance 5. (5, 0) is already listed above, so it is
    # not repeated here (a duplicate would only re-check the same cell).
    (6, -4),
    (6, 0),
    (6, 4),
    (0, 5),
    (0, -5),
    (-5, 0),
]
111
+
112
+
113
@dataclass
class SmartRoleAgentSnapshot:
    """Per-agent summary recorded by ``SmartRoleCoordinator.update_agent`` each step."""

    # Agent's step_count when the snapshot was taken.
    step: int
    # Role the agent held at that step.
    role: Role
    # Result of the agent's has_gear() check.
    has_gear: bool
    # Sorted, de-duplicated StructureType values present in the agent's map.
    structures_known: tuple[str, ...]
    # Number of entries in the agent's structures map.
    structures_seen: int
    # Hearts in the agent's inventory.
    heart_count: int
    # Influence in the agent's inventory.
    influence_count: int
    # Junction count per alignment bucket ("cogs", "clips", "neutral", "unknown").
    junction_alignment_counts: dict[str, int]
126
+
127
@dataclass
class SmartRoleCoordinator:
    """Team-level blackboard shared by all agents of one environment.

    Each agent reports a ``SmartRoleAgentSnapshot`` through ``update_agent``.
    Junction alignments and gear-station locations are stored as offsets from
    the hub. One agent's discoveries can therefore be re-projected into
    another agent's coordinates through that agent's own hub position.
    ``choose_role`` turns the aggregated snapshots into a role vibe name.
    """

    num_agents: int
    # Latest snapshot per agent id.
    agent_snapshots: dict[int, SmartRoleAgentSnapshot] = field(default_factory=dict)
    # Hub-relative junction offset -> last reported alignment (None = unaligned).
    junction_alignment_overrides: dict[tuple[int, int], Optional[str]] = field(default_factory=dict)
    # Station name -> hub-relative offset; the first report wins.
    station_offsets: dict[str, tuple[int, int]] = field(default_factory=dict)
    # Hub-relative junction offset -> step at which it was reported scrambled.
    recent_scrambles: dict[tuple[int, int], int] = field(default_factory=dict)

    def update_agent(self, s: CogsguardAgentState) -> None:
        """Exchange shared knowledge with ``s`` and store a fresh snapshot of it.

        When the agent knows the hub, its junctions and stations are recorded
        as hub offsets. The team's shared overrides are then written back
        into ``s`` before the snapshot is taken.
        """
        hub = s.stations.get("hub")
        if hub is not None:
            # Publish this agent's discoveries, then pull the merged view back.
            self._record_known_junctions(s, hub)
            self._record_known_stations(s, hub)
            self._apply_alignment_overrides(s, hub)
            self._apply_station_overrides(s, hub)

        tally = dict.fromkeys(("cogs", "clips", "neutral", "unknown"), 0)
        for junction in s.get_structures_by_type(StructureType.CHARGER):
            tally[self._normalize_alignment(junction.alignment)] += 1

        type_names = {info.structure_type.value for info in s.structures.values()}
        self.agent_snapshots[s.agent_id] = SmartRoleAgentSnapshot(
            step=s.step_count,
            role=s.role,
            has_gear=s.has_gear(),
            structures_known=tuple(sorted(type_names)),
            structures_seen=len(s.structures),
            heart_count=s.heart,
            influence_count=s.influence,
            junction_alignment_counts=tally,
        )

    def register_junction_alignment(
        self,
        pos: tuple[int, int],
        alignment: Optional[str],
        hub_pos: Optional[tuple[int, int]],
        step: Optional[int] = None,
    ) -> None:
        """Record that the junction at ``pos`` now has ``alignment``.

        Without a hub position nothing is recorded. When ``step`` is given:

        - ``alignment=None`` marks the junction as scrambled at that step;
        - ``"cogs"`` clears that mark.
        """
        if hub_pos is None:
            return
        key = (pos[0] - hub_pos[0], pos[1] - hub_pos[1])
        self.junction_alignment_overrides[key] = alignment
        if step is not None:
            if alignment is None:
                self.recent_scrambles[key] = step
            elif alignment == "cogs":
                self.recent_scrambles.pop(key, None)

    def recent_scramble_targets(
        self,
        hub_pos: Optional[tuple[int, int]],
        step: int,
        *,
        max_age: int = 200,
    ) -> list[tuple[int, int]]:
        """Return positions of junctions scrambled within ``max_age`` steps of ``step``.

        Positions are translated through ``hub_pos``. Older entries are
        pruned from ``recent_scrambles`` as a side effect. Returns an empty
        list (and prunes nothing) when ``hub_pos`` is None.
        """
        if hub_pos is None:
            return []
        live: list[tuple[int, int]] = []
        expired: list[tuple[int, int]] = []
        for offset, scrambled_at in self.recent_scrambles.items():
            (expired if step - scrambled_at > max_age else live).append(offset)
        # Prune after the loop: a dict must not shrink while being iterated.
        for offset in expired:
            del self.recent_scrambles[offset]
        return [(hub_pos[0] + d_row, hub_pos[1] + d_col) for d_row, d_col in live]

    def _record_known_junctions(self, s: CogsguardAgentState, hub_pos: tuple[int, int]) -> None:
        """Seed overrides with aligned junctions from ``s`` that have no entry yet."""
        for junction in s.get_structures_by_type(StructureType.CHARGER):
            if junction.alignment is None:
                continue
            key = (junction.position[0] - hub_pos[0], junction.position[1] - hub_pos[1])
            self.junction_alignment_overrides.setdefault(key, junction.alignment)

    def _record_known_stations(self, s: CogsguardAgentState, hub_pos: tuple[int, int]) -> None:
        """Remember hub offsets of stations ``s`` knows (excluding hub and junction)."""
        for name, pos in s.stations.items():
            if pos is not None and name not in ("hub", "junction"):
                self.station_offsets.setdefault(name, (pos[0] - hub_pos[0], pos[1] - hub_pos[1]))

    def _apply_station_overrides(self, s: CogsguardAgentState, hub_pos: tuple[int, int]) -> None:
        """Fill in stations ``s`` has not located yet from the shared offsets.

        A structure entry is created and its cell marked as an obstacle only
        when ``s`` has no structure recorded at that position.
        """
        for name, (d_row, d_col) in self.station_offsets.items():
            if s.stations.get(name) is not None:
                continue
            row, col = hub_pos[0] + d_row, hub_pos[1] + d_col
            if not (0 <= row < s.map_height and 0 <= col < s.map_width):
                continue
            pos = (row, col)
            s.stations[name] = pos
            if pos in s.structures:
                continue
            s.structures[pos] = StructureInfo(
                position=pos,
                structure_type=self._station_structure_type(name),
                name=name,
                last_seen_step=s.step_count,
            )
            s.occupancy[row][col] = CellType.OBSTACLE.value

    def _apply_alignment_overrides(self, s: CogsguardAgentState, hub_pos: tuple[int, int]) -> None:
        """Push shared junction alignments into ``s``'s structures and supply depots.

        Unknown junction cells get a new CHARGER entry. A junction that ``s``
        observed during the current step keeps its own alignment.
        """
        overrides = self.junction_alignment_overrides
        if not overrides:
            return
        for (d_row, d_col), alignment in overrides.items():
            row, col = hub_pos[0] + d_row, hub_pos[1] + d_col
            if not (0 <= row < s.map_height and 0 <= col < s.map_width):
                continue
            pos = (row, col)
            known = s.structures.get(pos)
            if known is None:
                s.structures[pos] = StructureInfo(
                    position=pos,
                    structure_type=StructureType.CHARGER,
                    name="junction",
                    last_seen_step=s.step_count,
                    alignment=alignment,
                )
                s.occupancy[row][col] = CellType.OBSTACLE.value
            elif (
                known.structure_type == StructureType.CHARGER
                and known.last_seen_step != s.step_count
                and known.alignment != alignment
            ):
                known.alignment = alignment

        if s.supply_depots:
            for idx, (depot_pos, _) in enumerate(s.supply_depots):
                key = (depot_pos[0] - hub_pos[0], depot_pos[1] - hub_pos[1])
                if key in overrides:
                    s.supply_depots[idx] = (depot_pos, overrides[key])

    @staticmethod
    def _station_structure_type(name: str) -> StructureType:
        """Map a station name to its StructureType (UNKNOWN when unrecognized)."""
        by_name = {
            "miner_station": StructureType.MINER_STATION,
            "scout_station": StructureType.SCOUT_STATION,
            "aligner_station": StructureType.ALIGNER_STATION,
            "scrambler_station": StructureType.SCRAMBLER_STATION,
            "chest": StructureType.CHEST,
        }
        return by_name.get(name, StructureType.UNKNOWN)

    @staticmethod
    def _normalize_alignment(alignment: Optional[str]) -> str:
        """Bucket a raw alignment into "cogs", "clips", "neutral" or "unknown"."""
        if alignment is None or alignment == "neutral":
            return "neutral"
        return alignment if alignment in ("cogs", "clips") else "unknown"

    def choose_role(self, agent_id: int) -> str:
        """Pick a role vibe for ``agent_id`` from the team-wide snapshots.

        Agents without a snapshot get a random role vibe. Otherwise the rules
        below are checked in order and the first that matches wins:

        1. Scout until hub and chest are known.
        2. Fill empty scout, then miner slots.
        3. Scout until any junction alignment is known.
        4. Fill empty scrambler, then aligner slots.
        5. Scramble when clips hold junctions and scramblers do not outnumber
           aligners; align when neutral junctions exist.
        6. Otherwise scout while fewer than 10 structures have been seen,
           else mine.
        """
        if self.agent_snapshots.get(agent_id) is None:
            return random.choice(ROLE_VIBES)

        if not {"hub", "chest"} <= self._aggregate_structures():
            return "scout"

        roles = self._aggregate_role_counts()
        for essential in ("scout", "miner"):
            if roles.get(essential, 0) == 0:
                return essential

        junctions = self._aggregate_junction_counts()
        if sum(junctions.values()) - junctions["unknown"] == 0:
            return "scout"

        for essential in ("scrambler", "aligner"):
            if roles.get(essential, 0) == 0:
                return essential

        if junctions["clips"] > 0 and roles.get("scrambler", 0) <= roles.get("aligner", 0):
            return "scrambler"
        if junctions["neutral"] > 0:
            return "aligner"
        return "scout" if self._aggregate_structures_seen() < 10 else "miner"

    def _aggregate_junction_counts(self) -> dict[str, int]:
        """Per-bucket maximum of junction counts across all snapshots (floor 0)."""
        best = dict.fromkeys(("cogs", "clips", "neutral", "unknown"), 0)
        for snap in self.agent_snapshots.values():
            counts = snap.junction_alignment_counts
            for bucket, current in best.items():
                best[bucket] = max(current, counts.get(bucket, 0))
        return best

    def aligned_junction_count(self) -> int:
        """Largest number of cogs-aligned junctions reported by any agent."""
        return self._aggregate_junction_counts().get("cogs", 0)

    def _aggregate_structures(self) -> set[str]:
        """Union of structure type names known to any agent."""
        return set().union(*(snap.structures_known for snap in self.agent_snapshots.values()))

    def _count_roles(self, *, geared_only: bool) -> dict[str, int]:
        """Count agents per role vibe, optionally only those that have gear."""
        tally = dict.fromkeys(ROLE_VIBES, 0)
        for snap in self.agent_snapshots.values():
            if geared_only and not snap.has_gear:
                continue
            name = snap.role.value
            if name in tally:
                tally[name] += 1
        return tally

    def _aggregate_role_counts(self) -> dict[str, int]:
        """Number of agents currently holding each role vibe."""
        return self._count_roles(geared_only=False)

    def _aggregate_role_gear_counts(self) -> dict[str, int]:
        """Number of agents per role vibe that also have gear."""
        return self._count_roles(geared_only=True)

    def get_role_gear_counts(self) -> dict[str, int]:
        """Public accessor for per-role counts of geared agents."""
        return self._aggregate_role_gear_counts()

    def _aggregate_heart_count(self) -> int:
        """Total hearts across all snapshots."""
        return sum(snap.heart_count for snap in self.agent_snapshots.values())

    def _aggregate_influence_count(self) -> int:
        """Total influence across all snapshots."""
        return sum(snap.influence_count for snap in self.agent_snapshots.values())

    def _aggregate_structures_seen(self) -> int:
        """Largest ``structures_seen`` among snapshots (0 when there are none)."""
        return max((snap.structures_seen for snap in self.agent_snapshots.values()), default=0)
356
+
357
+
358
+ class CogsguardAgentPolicyImpl(StatefulPolicyImpl[CogsguardAgentState]):
359
+ """Base policy implementation for CoGsGuard agents.
360
+
361
+ Handles common behavior like gear acquisition. Role-specific behavior
362
+ is implemented by overriding execute_role().
363
+ """
364
+
365
+ # Subclasses set this
366
+ ROLE: Role = Role.MINER
367
+
368
def __init__(
    self,
    policy_env_info: PolicyEnvInterface,
    agent_id: int,
    role: Role,
    smart_role_coordinator: Optional[SmartRoleCoordinator] = None,
    evolutionary_role_coordinator: Optional[EvolutionaryRoleCoordinator] = None,
    use_evolutionary_roles: bool = False,
):
    """Bind this implementation to one agent of the environment.

    Args:
        policy_env_info: Environment interface providing observation size and
            action names.
        agent_id: Id of the agent this implementation controls.
        role: Initial role for the agent.
        smart_role_coordinator: Optional shared coordinator used for role
            selection and knowledge sharing.
        evolutionary_role_coordinator: Optional coordinator that takes over
            role selection when ``use_evolutionary_roles`` is True.
        use_evolutionary_roles: Whether to prefer the evolutionary coordinator.
    """
    self._agent_id = agent_id
    self._role = role
    self._policy_env_info = policy_env_info
    self._smart_role_coordinator = smart_role_coordinator
    self._evolutionary_role_coordinator = evolutionary_role_coordinator
    self._use_evolutionary_roles = use_evolutionary_roles
    # Some env configs omit move_energy_cost; 1 mirrors the simulator fallback.
    self._move_energy_cost = getattr(policy_env_info, "move_energy_cost", 1)

    # Half-extents of the egocentric observation window; the agent's own
    # cell is at (self._obs_hr, self._obs_wr).
    self._obs_hr, self._obs_wr = policy_env_info.obs_height // 2, policy_env_info.obs_width // 2

    names = list(policy_env_info.action_names)
    self._action_names = names
    self._action_set = set(names)
    vibe_prefix = "change_vibe_"
    self._vibe_names = [name.removeprefix(vibe_prefix) for name in names if name.startswith(vibe_prefix)]
    # Cardinal move name -> (row delta, col delta).
    self._move_deltas = {
        "north": (-1, 0),
        "south": (1, 0),
        "east": (0, 1),
        "west": (0, -1),
    }

    # Configuration handed to the shared observation parser.
    self._spatial_feature_names = {"tag", "cooldown_remaining", "clipped", "remaining_uses"}
    self._agent_feature_key_by_name = {"agent:group": "agent_group", "agent:frozen": "agent_frozen"}
    self._protocol_input_prefix = "protocol_input:"
    self._protocol_output_prefix = "protocol_output:"

    # Filled from the env interface in initial_agent_state().
    self._tag_names: dict[int, str] = {}
411
+
412
def _noop(self) -> Action:
    """Return an action that leaves the agent idle this step."""
    idle = Action(name="noop")
    return idle
414
+
415
def _has_vibe(self, vibe_name: str) -> bool:
    """Tell whether the env exposes a ``change_vibe_<vibe_name>`` action."""
    available = self._vibe_names
    return vibe_name in available
417
+
418
def _choose_role_vibe(self, s: CogsguardAgentState) -> str:
    """Pick the role vibe this agent should switch to.

    Sources are tried in this order:

    1. The evolutionary coordinator, only when enabled and attached.
    2. The smart coordinator, when attached.
    3. A uniformly random role vibe.
    """
    evolutionary = self._evolutionary_role_coordinator
    if self._use_evolutionary_roles and evolutionary is not None:
        return evolutionary.choose_vibe(s.agent_id, s.step_count)
    smart = self._smart_role_coordinator
    if smart is not None:
        return smart.choose_role(s.agent_id)
    return random.choice(ROLE_VIBES)
424
+
425
def _move(self, direction: str) -> Action:
    """Build a ``move_<direction>`` action, or a noop if the env lacks that action."""
    action_name = f"move_{direction}"
    if action_name not in self._action_set:
        return self._noop()
    return Action(name=action_name)
430
+
431
def initial_agent_state(self) -> CogsguardAgentState:
    """Create the blank per-agent state for a new episode.

    IMPORTANT: Positions are tracked RELATIVE to the agent's starting position.

    - The agent starts at (0, 0) in internal coordinates.
    - Every discovered object position is relative to this origin.
    - The actual map size does not matter; only relative offsets are used.
    - The occupancy grid is centered at (grid_size/2, grid_size/2) so that
      negative relative positions still map to valid indices.
    """
    self._tag_names = self._policy_env_info.tag_id_to_name

    # Large enough for typical exploration range.
    size = 200
    # The spawn point is stored at the grid center so that negative relative
    # offsets remain valid indices.
    origin = size // 2
    free = CellType.FREE.value

    state = CogsguardAgentState(
        agent_id=self._agent_id,
        role=self._role,
        map_height=size,
        map_width=size,
        occupancy=[[free] * size for _ in range(size)],
        explored=[[False] * size for _ in range(size)],
        row=origin,
        col=origin,
    )

    if self._move_energy_cost is not None:
        state.MOVE_ENERGY_COST = self._move_energy_cost
    return state
463
+
464
def step_with_state(self, obs: AgentObservation, s: CogsguardAgentState) -> tuple[Action, CogsguardAgentState]:
    """Advance this agent by one simulator step.

    The step proceeds in this order:

    1. Refresh the agent's view of itself (inventory, vibe, position) and of
       the map from ``obs``.
    2. Share that view with the smart-role coordinator, when one is attached.
    3. Let the phase logic choose an action.

    The chosen action is stored on ``s``.
    """
    s.step_count += 1
    s.current_obs = obs
    # Other agents' cells are rebuilt from scratch every step.
    s.agent_occupancy.clear()

    # Self-knowledge first: the position update reads the executed action
    # that _read_inventory extracts from the observation.
    self._read_inventory(s, obs)
    self._update_agent_position(s)

    # World knowledge.
    self._update_occupancy_and_discover(s, self._parse_observation(s, obs))

    coordinator = self._smart_role_coordinator
    if coordinator is not None:
        coordinator.update_agent(s)

    self._update_phase(s)
    action = self._execute_phase(s)

    # Debug trace, limited to the first 50 steps per agent.
    if DEBUG and s.step_count <= 50:
        gear_status = "HAS_GEAR" if s.has_gear() else "NO_GEAR"
        nexus_pos = s.get_structure_position(StructureType.HUB) or "NOT_FOUND"
        print(
            f"[A{s.agent_id}] Step {s.step_count}: vibe={s.current_vibe} role={s.role.value} | "
            f"Phase={s.phase.value} | {gear_status} | "
            f"Energy={s.energy} | "
            f"Pos=({s.row},{s.col}) | "
            f"Nexus@{nexus_pos} | "
            f"Action={action.name}"
        )

    s.last_action = action
    return action, s
506
+
507
def _read_inventory(self, s: CogsguardAgentState, obs: AgentObservation) -> None:
    """Refresh inventory, vibe and last executed action on ``s`` from ``obs``.

    Only tokens located at the observation center (the agent's own cell) are
    considered. Missing inventory items default to 0, except ``hp``, which
    defaults to 100. An absent or out-of-range ``last_action`` id leaves
    ``s.last_action_executed`` as None.
    """
    center = (self._obs_hr, self._obs_wr)
    inventory: dict = {}
    vibe_id = 0
    last_action_id: Optional[int] = None
    value_base = None

    for token in (t for t in obs.tokens if t.location == center):
        feature = token.feature.name
        if feature.startswith("inv:"):
            # The first inventory token's normalization is reused as
            # token_value_base for every inventory token in this observation.
            if value_base is None:
                value_base = int(token.feature.normalization)
            add_inventory_token(inventory, feature, token.value, token_value_base=value_base)
        elif feature == "vibe":
            vibe_id = token.value
        elif feature == "last_action":
            last_action_id = token.value

    for resource in ("energy", "carbon", "oxygen", "germanium", "silicon", "heart", "influence"):
        setattr(s, resource, inventory.get(resource, 0))
    s.hp = inventory.get("hp", 100)
    # Gear items, one per role.
    for gear in ("miner", "scout", "aligner", "scrambler"):
        setattr(s, gear, inventory.get(gear, 0))

    # Any change in heart count resets the heart-wait marker.
    if s.heart != s._last_heart_count:
        s._heart_wait_start = 0
    s._last_heart_count = s.heart

    s.current_vibe = self._get_vibe_name(vibe_id)

    # What the simulator actually executed, which may differ from our intent.
    executed_name = None
    if last_action_id is not None:
        names = self._policy_env_info.action_names
        if 0 <= last_action_id < len(names):
            executed_name = names[last_action_id]
    s.last_action_executed = executed_name
558
+
559
def _get_vibe_name(self, vibe_id: int) -> str:
    """Translate a vibe id into its name, using "default" for unknown ids."""
    names = self._vibe_names
    return names[vibe_id] if 0 <= vibe_id < len(names) else "default"
564
+
565
def _update_agent_position(self, s: CogsguardAgentState) -> None:
    """Dead-reckon the agent's position from the move the simulator executed.

    The position is updated from ``s.last_action_executed`` (read from the
    observation), not from the intended ``s.last_action``. This keeps the
    internal position consistent with the simulator when a move is delayed or
    another controller drives the agent.

    Side effects: clears ``s.using_object_this_step`` and appends the
    resulting position to ``s.position_history`` (capped at 30 entries).
    """
    executed = s.last_action_executed
    intended = s.last_action.name if s.last_action else None

    # Debug: report when intent and execution differ (failed/delayed action or human control).
    if DEBUG and s.step_count <= 100 and intended and executed and intended != executed:
        print(
            f"[A{s.agent_id}] ACTION_MISMATCH: intended={intended}, "
            f"executed={executed} (action failed/delayed or human control)"
        )

    # A move executed while interacting with an object does not change position.
    if executed and executed.startswith("move_") and not s.using_object_this_step:
        delta = self._move_deltas.get(executed[len("move_") :])
        if delta is not None:
            s.row += delta[0]
            s.col += delta[1]

    s.using_object_this_step = False

    history = s.position_history
    history.append((s.row, s.col))
    if len(history) > 30:
        history.pop(0)
601
+
602
def _parse_observation(self, s: CogsguardAgentState, obs: AgentObservation) -> ParsedObservation:
    """Decode ``obs`` with the shared scripted-agent parser and this policy's feature config."""
    parser_options = {
        "obs_hr": self._obs_hr,
        "obs_wr": self._obs_wr,
        "spatial_feature_names": self._spatial_feature_names,
        "agent_feature_key_by_name": self._agent_feature_key_by_name,
        "protocol_input_prefix": self._protocol_input_prefix,
        "protocol_output_prefix": self._protocol_output_prefix,
        "tag_names": self._tag_names,
    }
    # CogsguardAgentState is compatible with SimpleAgentState for parsing.
    return utils_parse_observation(s, obs, **parser_options)  # type: ignore[arg-type]
615
+
616
def _update_occupancy_and_discover(self, s: CogsguardAgentState, parsed: ParsedObservation) -> None:
    """Refresh the occupancy grid from the current view and register visible objects.

    Every in-bounds cell of the ``(2*obs_hr+1) x (2*obs_wr+1)`` window centred on
    the agent is marked FREE and explored. Each visible object is then classified
    by its lower-cased name and tags:

    - walls: OBSTACLE plus a WALL structure, with no further checks;
    - other agents: added to ``s.agent_occupancy``, with no further checks;
    - gear stations, junctions/supply depots (recorded as CHARGER), hubs, chests
      and ``{resource}_extractor`` objects: OBSTACLE plus a record via
      ``_update_structure``.

    Does nothing while the agent's position is unknown (``s.row < 0``).
    """
    if s.row < 0:
        return

    # Mark all observed cells as FREE and explored; object cells are re-marked below.
    for obs_r in range(2 * self._obs_hr + 1):
        for obs_c in range(2 * self._obs_wr + 1):
            r, c = obs_r - self._obs_hr + s.row, obs_c - self._obs_wr + s.col
            if 0 <= r < s.map_height and 0 <= c < s.map_width:
                s.occupancy[r][c] = CellType.FREE.value
                s.explored[r][c] = True

    if DEBUG and s.step_count == 1:
        print(f"[A{s.agent_id}] Nearby objects: {[obj.name for obj in parsed.nearby_objects.values()]}")

    # Loop-invariant resource names used by the chest and extractor checks.
    resources = ("carbon", "oxygen", "germanium", "silicon")

    for pos, obj_state in parsed.nearby_objects.items():
        r, c = pos
        obj_name = obj_state.name.lower()
        obj_tags = [tag.lower() for tag in obj_state.tags]

        # Walls are obstacles.
        if is_wall(obj_name):
            s.occupancy[r][c] = CellType.OBSTACLE.value
            self._update_structure(s, pos, obj_name, StructureType.WALL, obj_state)
            continue

        # Other agents are tracked separately from the static occupancy grid.
        if obj_name == "agent" and obj_state.agent_id != s.agent_id:
            s.agent_occupancy.add((r, c))
            continue

        # Gear stations: the first matching station name wins.
        for station_name in ROLE_TO_STATION.values():
            if (
                is_station(obj_name, station_name)
                or station_name in obj_name
                or any(station_name in tag for tag in obj_tags)
            ):
                s.occupancy[r][c] = CellType.OBSTACLE.value
                struct_type = self._get_station_type(station_name)
                self._update_structure(s, pos, obj_name, struct_type, obj_state)
                break

        # Supply depots (called junctions in cogsguard) are recorded as CHARGER.
        if has_type_tag(obj_tags, ("supply_depot", "junction")):
            s.occupancy[r][c] = CellType.OBSTACLE.value
            self._update_structure(s, pos, obj_name, StructureType.CHARGER, obj_state)

        # Hub: the main base / resource deposit point.
        is_hub = (
            "hub" in obj_name
            or obj_name in {"main_nexus"}
            or any("hub" in tag or "main_nexus" in tag or "nexus" in tag for tag in obj_tags)
        )
        if is_hub:
            s.occupancy[r][c] = CellType.OBSTACLE.value
            self._update_structure(s, pos, obj_name, StructureType.HUB, obj_state)

        # Chests (for hearts). Extractors are ChestConfig-backed, so anything named
        # or tagged as an extractor must be excluded from BOTH the name test and
        # the "chest" tag test. Otherwise an extractor tagged "chest" would be
        # recorded as a CHEST, and that structure type is never corrected later.
        has_extractor_tag = "extractor" in obj_name or any("extractor" in tag for tag in obj_tags)
        is_resource_chest = any(f"{res}_" in obj_name or f"{res}chest" in obj_name for res in resources)
        named_chest = obj_name == "chest" or ("chest" in obj_name and not is_resource_chest)
        if not has_extractor_tag and (named_chest or "chest" in obj_tags):
            s.occupancy[r][c] = CellType.OBSTACLE.value
            self._update_structure(s, pos, obj_name, StructureType.CHEST, obj_state)

        # Extractors (named {resource}_extractor in cogsguard).
        for resource in resources:
            if f"{resource}_extractor" in obj_name or any(f"{resource}_extractor" in tag for tag in obj_tags):
                s.occupancy[r][c] = CellType.OBSTACLE.value
                self._update_structure(s, pos, obj_name, StructureType.EXTRACTOR, obj_state, resource_type=resource)
                break
695
+
696
def _get_station_type(self, station_name: str) -> StructureType:
    """Return the ``StructureType`` for a gear-station name.

    Unrecognized names map to ``StructureType.UNKNOWN``.
    """
    if station_name == "miner_station":
        return StructureType.MINER_STATION
    if station_name == "scout_station":
        return StructureType.SCOUT_STATION
    if station_name == "aligner_station":
        return StructureType.ALIGNER_STATION
    if station_name == "scrambler_station":
        return StructureType.SCRAMBLER_STATION
    return StructureType.UNKNOWN
705
+
706
def _update_structure(
    self,
    s: CogsguardAgentState,
    pos: tuple[int, int],
    obj_name: str,
    structure_type: StructureType,
    obj_state: ObjectState,
    resource_type: Optional[str] = None,
) -> None:
    """Record a sighting of a structure at ``pos`` in the agent's world model.

    Alignment is derived from the object's name, tags and ``clipped`` value, then
    reconciled with ``s.alignment_overrides``. An estimated resource amount is
    stored as ``inventory_amount``; 999 stands for "unknown, assume full".

    Existing entries in ``s.structures`` get their dynamic fields refreshed; new
    positions get a fresh ``StructureInfo``. Hubs, chests and gear stations are
    also indexed in ``s.stations``, and chargers in ``s.supply_depots``.

    Extractor inventory rules, in order:
    - a count for the specific resource wins;
    - otherwise the sum of any reported inventory is used;
    - an empty inventory on the first sighting means "no data yet" (999);
    - an empty inventory on a later sighting means "depleted" (0).
    """
    clipped = obj_state.clipped > 0
    alignment = self._derive_alignment(obj_name, clipped, structure_type, obj_state.tags)

    # Reconcile with a remembered alignment. Use it when the observation gives
    # none; replace it when the observation now disagrees.
    if pos in s.alignment_overrides:
        remembered = s.alignment_overrides[pos]
        if alignment is None:
            alignment = remembered
        elif alignment != remembered:
            s.alignment_overrides[pos] = alignment

    seen_before = pos in s.structures
    inventory = obj_state.inventory

    if structure_type != StructureType.EXTRACTOR:
        # Non-extractors: total of whatever is reported, else unknown/full.
        inventory_amount = sum(inventory.values()) if inventory else 999
    else:
        if resource_type and resource_type in inventory:
            inventory_amount = inventory[resource_type]
        elif inventory:
            inventory_amount = sum(inventory.values())
        elif seen_before:
            inventory_amount = 0  # known extractor now reporting nothing: depleted
        else:
            inventory_amount = 999  # first sighting without data: assume it has resources
        if DEBUG and inventory_amount == 0:
            print(f"[A{s.agent_id}] EXTRACTOR_EMPTY: {pos} resource={resource_type} inv={obj_state.inventory}")

    if seen_before:
        known = s.structures[pos]
        known.last_seen_step = s.step_count
        known.cooldown_remaining = obj_state.cooldown_remaining
        known.remaining_uses = obj_state.remaining_uses
        known.clipped = clipped
        known.alignment = alignment
        known.inventory_amount = inventory_amount
    else:
        s.structures[pos] = StructureInfo(
            position=pos,
            structure_type=structure_type,
            name=obj_name,
            last_seen_step=s.step_count,
            resource_type=resource_type,
            cooldown_remaining=obj_state.cooldown_remaining,
            remaining_uses=obj_state.remaining_uses,
            clipped=clipped,
            alignment=alignment,
            inventory_amount=inventory_amount,
        )

    station_types = {
        StructureType.HUB,
        StructureType.CHEST,
        StructureType.MINER_STATION,
        StructureType.SCOUT_STATION,
        StructureType.ALIGNER_STATION,
        StructureType.SCRAMBLER_STATION,
    }
    if structure_type in station_types:
        s.stations[structure_type.value] = pos

    if structure_type == StructureType.CHARGER:
        # Refresh the alignment of a known depot, or register a new one.
        match_idx = next(
            (idx for idx, (depot_pos, _old) in enumerate(s.supply_depots) if depot_pos == pos),
            None,
        )
        if match_idx is not None:
            s.supply_depots[match_idx] = (pos, alignment)
        else:
            s.supply_depots.append((pos, alignment))
            if DEBUG:
                print(
                    f"[A{s.agent_id}] STRUCTURE: Added {structure_type.value} at {pos} "
                    f"(alignment={alignment}, inv={inventory_amount})"
                )
800
+
801
+ def _derive_alignment(
802
+ self,
803
+ obj_name: str,
804
+ clipped: bool,
805
+ structure_type: Optional[StructureType] = None,
806
+ tags: Optional[list[str]] = None,
807
+ ) -> Optional[str]:
808
+ """Derive alignment from object name, tags, clipped status, and structure type.
809
+
810
+ In CoGsGuard:
811
+ - Hub/nexus = cogs-aligned
812
+ - Charger/supply_depot alignment comes from tags/collectives
813
+ """
814
+ obj_lower = obj_name.lower()
815
+ tag_lowers = [tag.lower() for tag in tags or []]
816
+ # Check if name contains alignment info
817
+ if "cogs" in obj_lower or "cogs_" in obj_lower or any("cogs" in tag for tag in tag_lowers):
818
+ return "cogs"
819
+ if "clips" in obj_lower or "clips_" in obj_lower or any("clips" in tag for tag in tag_lowers):
820
+ return "clips"
821
+ # Clipped field indicates clips alignment
822
+ if clipped:
823
+ return "clips"
824
+ # Structure type defaults:
825
+ # - Hub/nexus defaults to cogs (main cogs building)
826
+ if structure_type == StructureType.HUB:
827
+ if (
828
+ "nexus" in obj_lower
829
+ or "hub" in obj_lower
830
+ or any("nexus" in tag for tag in tag_lowers)
831
+ or any("hub" in tag for tag in tag_lowers)
832
+ ):
833
+ return "cogs"
834
+ return None # Unknown/neutral
835
+
836
def _update_phase(self, s: CogsguardAgentState) -> None:
    """Derive ``s.role`` and ``s.phase`` from the agent's current vibe.

    A role vibe (a key of ``VIBE_TO_ROLE``) sets ``s.role`` and then chooses a phase:
    - EXECUTE_ROLE once the agent has gear;
    - EXECUTE_ROLE without gear for miners and scouts after step 30, so the
      economy and exploration can bootstrap;
    - GET_GEAR otherwise.

    Any other vibe (default/heart/gear) only gets a placeholder GET_GEAR phase;
    ``_execute_phase`` dispatches those vibes directly.
    """
    vibe = s.current_vibe
    if vibe not in VIBE_TO_ROLE:
        s.phase = CogsguardPhase.GET_GEAR  # placeholder; overridden by vibe dispatch
        return

    s.role = VIBE_TO_ROLE[vibe]
    if s.has_gear() or (s.step_count > 30 and s.role in (Role.MINER, Role.SCOUT)):
        s.phase = CogsguardPhase.EXECUTE_ROLE
    else:
        s.phase = CogsguardPhase.GET_GEAR
861
+
862
+ def _execute_phase(self, s: CogsguardAgentState) -> Action:
863
+ """Execute action for current phase based on vibe.
864
+
865
+ Vibe-based behavior:
866
+ - default: do nothing (noop)
867
+ - gear: pick role via smart coordinator, change vibe to that role
868
+ - role vibe: get gear then execute role
869
+ - heart: do nothing (noop)
870
+ """
871
+ vibe = s.current_vibe
872
+
873
+ # Default vibe: do nothing (wait for external vibe change)
874
+ if vibe == "default":
875
+ return self._noop()
876
+
877
+ # Heart vibe: do nothing
878
+ if vibe == "heart":
879
+ return self._noop()
880
+
881
+ # Gear vibe: pick a role and change vibe to it
882
+ if vibe == "gear":
883
+ selected_role = self._choose_role_vibe(s)
884
+ if DEBUG:
885
+ print(f"[A{s.agent_id}] GEAR_VIBE: Picking role vibe: {selected_role}")
886
+ return change_vibe_action(selected_role, action_names=self._action_names)
887
+
888
+ # Role vibes: execute the role behavior
889
+ if vibe in VIBE_TO_ROLE:
890
+ if s.phase == CogsguardPhase.GET_GEAR:
891
+ return self._do_get_gear(s)
892
+ elif s.phase == CogsguardPhase.EXECUTE_ROLE:
893
+ return self.execute_role(s)
894
+
895
+ return self._noop()
896
+
897
def _do_recharge(self, s: CogsguardAgentState) -> Action:
    """Head for the cogs-aligned hub so its energy AOE can recharge the agent.

    Only the hub (main nexus) is targeted; the supply depot is clips-aligned and
    does not give energy to cogs agents. Cases:

    - hub unknown: explore to find it;
    - within ``aoe_range`` (Manhattan distance) of the hub: stay put (noop);
    - energy below one move's cost: noop and wait for regeneration;
    - energy below three moves' cost: one greedy step along the axis with the
      larger gap (no pathfinding), to avoid committing to a long path;
    - otherwise: pathfind to a cell adjacent to the hub.
    """
    hub = s.get_structure_position(StructureType.HUB)
    if hub is None:
        if DEBUG:
            print(f"[A{s.agent_id}] RECHARGE: No nexus found, exploring")
        return self._explore(s)

    aoe_range = 10  # AOE range from recipe
    row_gap = hub[0] - s.row
    col_gap = hub[1] - s.col
    dist = abs(row_gap) + abs(col_gap)

    if dist <= aoe_range:
        if DEBUG and s.step_count % 20 == 0:
            print(f"[A{s.agent_id}] RECHARGE: Near nexus (dist={dist}), waiting for AOE (energy={s.energy})")
        return self._noop()

    if s.energy < s.MOVE_ENERGY_COST:
        if DEBUG and s.step_count % 20 == 0:
            print(
                f"[A{s.agent_id}] RECHARGE: Energy critically low ({s.energy}), "
                f"can't move to nexus at dist={dist}, waiting for regen"
            )
        return self._noop()

    if s.energy < s.MOVE_ENERGY_COST * 3:
        if DEBUG and s.step_count % 10 == 0:
            print(
                f"[A{s.agent_id}] RECHARGE: Low energy ({s.energy}), "
                f"taking single step towards nexus at {hub}"
            )
        if abs(row_gap) >= abs(col_gap):
            step = "south" if row_gap > 0 else "north"
        else:
            step = "east" if col_gap > 0 else "west"
        return self._move(step)

    if DEBUG and s.step_count % 20 == 0:
        print(f"[A{s.agent_id}] RECHARGE: Moving to nexus at {hub} from ({s.row},{s.col}), dist={dist}")
    return self._move_towards(s, hub, reach_adjacent=True)
958
+
959
def _do_get_gear(self, s: CogsguardAgentState) -> Action:
    """Walk to the gear station for the agent's role and bump it to equip gear.

    Steps, in order:
    1. Scrambler priority: early on (``step_count <= SCRAMBLER_GEAR_PRIORITY_STEPS``),
       a non-scrambler explores instead while the smart coordinator reports no
       scrambler gear.
    2. Station known: approach it, then bump it once adjacent.
    3. Station unknown, scout station known and the agent holds no scout gear
       (and is not itself after scout gear): grab scout gear first.
    4. Otherwise search around the hub using a rotating ``GEAR_SEARCH_OFFSETS``
       target, or explore when the hub is unknown too.
    """
    coordinator = self._smart_role_coordinator
    early_non_scrambler = s.role != Role.SCRAMBLER and s.step_count <= SCRAMBLER_GEAR_PRIORITY_STEPS
    if coordinator is not None and early_non_scrambler:
        if coordinator.get_role_gear_counts().get("scrambler", 0) == 0:
            if DEBUG and s.step_count % 5 == 0:
                print(f"[A{s.agent_id}] GET_GEAR: yielding to scrambler gear priority")
            return self._explore(s)

    station_name = s.get_gear_station_name()
    station_pos = s.get_structure_position(s.get_gear_station_type())
    hub_pos = s.get_structure_position(StructureType.HUB)
    here = (s.row, s.col)

    if DEBUG and s.step_count <= 10:
        known_structures = sorted({struct.structure_type.value for struct in s.structures.values()})
        print(f"[A{s.agent_id}] GET_GEAR: station={station_name} pos={station_pos} all={known_structures}")

    if station_pos is not None:
        adjacent = is_adjacent(here, station_pos)
        if DEBUG and s.step_count <= 60:
            print(f"[A{s.agent_id}] GET_GEAR: pos=({s.row},{s.col}), station={station_pos}, adjacent={adjacent}")
        if not adjacent:
            return self._move_towards(s, station_pos, reach_adjacent=True)
        if DEBUG:
            print(f"[A{s.agent_id}] GET_GEAR: Adjacent to {station_name}, bumping it!")
        return self._use_object_at(s, station_pos)

    # Station unknown: bootstrap with scout gear for mobility if possible.
    scout_station = s.get_structure_position(StructureType.SCOUT_STATION)
    if scout_station is not None and s.scout == 0 and station_name != "scout_station":
        if is_adjacent(here, scout_station):
            return self._use_object_at(s, scout_station)
        return self._move_towards(s, scout_station, reach_adjacent=True)

    if hub_pos is not None:
        dr, dc = GEAR_SEARCH_OFFSETS[(s.agent_id + s.step_count // 10) % len(GEAR_SEARCH_OFFSETS)]
        return self._move_towards(s, (hub_pos[0] + dr, hub_pos[1] + dc), reach_adjacent=True)

    if DEBUG:
        print(f"[A{s.agent_id}] GET_GEAR: No {station_name} found, exploring")
    return self._explore(s)
1007
+
1008
def execute_role(self, s: CogsguardAgentState) -> Action:
    """Fallback role behavior: explore. Subclasses override this with real role logic.

    The diagnostic trace noting that this base implementation ran is gated on
    ``DEBUG``, like every other trace in this class, so non-debug runs do not
    print during the first 100 steps.
    """
    if DEBUG and s.step_count <= 100:
        print(f"[A{s.agent_id}] BASE_EXECUTE_ROLE: impl={type(self).__name__}, role={s.role}")
    return self._explore(s)
1013
+
1014
+ # =========================================================================
1015
+ # Navigation utilities
1016
+ # =========================================================================
1017
+
1018
+ def _use_object_at(self, s: CogsguardAgentState, target_pos: tuple[int, int]) -> Action:
1019
+ """Use an object by moving into its cell."""
1020
+ tr, tc = target_pos
1021
+ if s.row == tr and s.col == tc:
1022
+ return self._noop()
1023
+
1024
+ dr = tr - s.row
1025
+ dc = tc - s.col
1026
+
1027
+ # Check agent collision
1028
+ if (tr, tc) in s.agent_occupancy:
1029
+ return self._noop()
1030
+
1031
+ # Mark that we're using an object
1032
+ s.using_object_this_step = True
1033
+
1034
+ if dr == -1:
1035
+ return self._move("north")
1036
+ if dr == 1:
1037
+ return self._move("south")
1038
+ if dc == 1:
1039
+ return self._move("east")
1040
+ if dc == -1:
1041
+ return self._move("west")
1042
+
1043
+ return self._noop()
1044
+
1045
+ def _explore_frontier(self, s: CogsguardAgentState) -> Optional[Action]:
1046
+ """Find and move toward the nearest unexplored frontier."""
1047
+ if not s.explored or len(s.explored) == 0:
1048
+ return None
1049
+
1050
+ start = (s.row, s.col)
1051
+ visited: set[tuple[int, int]] = {start}
1052
+ queue: deque[tuple[tuple[int, int], Optional[str]]] = deque()
1053
+ queue.append((start, None))
1054
+
1055
+ directions = [("north", -1, 0), ("south", 1, 0), ("east", 0, 1), ("west", 0, -1)]
1056
+ direction_deltas = {direction: (dr, dc) for direction, dr, dc in directions}
1057
+
1058
+ while queue:
1059
+ pos, first_step = queue.popleft()
1060
+ r, c = pos
1061
+
1062
+ for direction, dr, dc in directions:
1063
+ nr, nc = r + dr, c + dc
1064
+ if not (0 <= nr < s.map_height and 0 <= nc < s.map_width):
1065
+ continue
1066
+ if (nr, nc) in visited:
1067
+ continue
1068
+
1069
+ visited.add((nr, nc))
1070
+
1071
+ if not s.explored[nr][nc]:
1072
+ if first_step is None:
1073
+ if s.occupancy[nr][nc] == CellType.FREE.value and (nr, nc) not in s.agent_occupancy:
1074
+ if DEBUG and s.step_count <= 100:
1075
+ print(f"[A{s.agent_id}] FRONTIER: Moving {direction} to unexplored ({nr},{nc})")
1076
+ return self._move(direction)
1077
+ else:
1078
+ step_dr, step_dc = direction_deltas[first_step]
1079
+ step_r, step_c = s.row + step_dr, s.col + step_dc
1080
+ if not (0 <= step_r < s.map_height and 0 <= step_c < s.map_width):
1081
+ continue
1082
+ if s.occupancy[step_r][step_c] != CellType.FREE.value or (step_r, step_c) in s.agent_occupancy:
1083
+ continue
1084
+ if DEBUG and s.step_count <= 100:
1085
+ explored_count = sum(sum(row) for row in s.explored)
1086
+ total_cells = s.map_height * s.map_width
1087
+ print(
1088
+ f"[A{s.agent_id}] FRONTIER: Heading {first_step} towards "
1089
+ f"frontier at ({nr},{nc}), explored={explored_count}/{total_cells}"
1090
+ )
1091
+ return self._move(first_step)
1092
+
1093
+ if s.explored[nr][nc] and s.occupancy[nr][nc] == CellType.FREE.value:
1094
+ next_first_step = first_step
1095
+ if first_step is None and (r, c) == start:
1096
+ next_first_step = direction
1097
+ queue.append(((nr, nc), next_first_step))
1098
+
1099
+ if DEBUG and s.step_count % 50 == 0:
1100
+ explored_count = sum(sum(row) for row in s.explored)
1101
+ total_cells = s.map_height * s.map_width
1102
+ print(f"[A{s.agent_id}] FRONTIER: None found, explored={explored_count}/{total_cells}")
1103
+ return None
1104
+
1105
def _explore(self, s: CogsguardAgentState) -> Action:
    """Wander the map by sweeping through the cardinal directions.

    Attempts, in order:
    1. If a back-and-forth loop is detected, try a random escape move.
    2. Keep the current heading (``s.exploration_target``) for up to 8 steps after
       it was chosen, as long as the next cell stays traversable.
    3. Rotate through east, south, west, north, starting with the direction after
       the current heading (or east if there is none). Take the first traversable
       one and record it as the new heading.
    4. If every direction is blocked, noop.
    """
    if self._is_in_location_loop(s):
        escape = self._break_location_loop(s)
        if escape:
            return escape
        # Could not break the loop; fall back to regular exploration.

    # East first: gear stations are typically east of the hub.
    sweep: list[CardinalDirection] = ["east", "south", "west", "north"]

    if DEBUG and s.step_count <= 30:
        print(f"[A{s.agent_id}] EXPLORE: target={s.exploration_target}, step={s.step_count}")

    heading = s.exploration_target
    if isinstance(heading, str) and s.step_count - s.exploration_target_step < 8:
        dr, dc = self._move_deltas.get(heading, (0, 0))
        if path_is_traversable(s, s.row + dr, s.col + dc, CellType):  # type: ignore[arg-type]
            return self._move(heading)  # type: ignore[arg-type]

    first = (sweep.index(heading) + 1) % len(sweep) if heading in sweep else 0

    for offset in range(len(sweep)):
        direction = sweep[(first + offset) % len(sweep)]
        dr, dc = self._move_deltas[direction]
        next_r, next_c = s.row + dr, s.col + dc
        traversable = path_is_traversable(s, next_r, next_c, CellType)  # type: ignore[arg-type]
        if DEBUG and s.step_count <= 10:
            in_bounds = 0 <= next_r < s.map_height and 0 <= next_c < s.map_width
            cell_val = s.occupancy[next_r][next_c] if in_bounds else -1
            agent_occ = (next_r, next_c) in s.agent_occupancy
            print(
                f"[A{s.agent_id}] EXPLORE_DIR: {direction} -> ({next_r},{next_c}) "
                f"trav={traversable} cell={cell_val} agent={agent_occ}"
            )
        if traversable:
            s.exploration_target = direction
            s.exploration_target_step = s.step_count
            return self._move(direction)

    if DEBUG and s.step_count <= 10:
        print(f"[A{s.agent_id}] EXPLORE: All directions blocked, returning noop")
    return self._noop()
1159
+
1160
def _move_towards(
    self,
    s: CogsguardAgentState,
    target: tuple[int, int],
    *,
    reach_adjacent: bool = False,
) -> Action:
    """Take one step along a shortest path toward ``target``.

    Args:
        s: Agent state; its cached-path fields are read and updated.
        target: Destination cell.
        reach_adjacent: If True, any goal cell produced by ``compute_goal_cells``
            for an adjacent approach is acceptable.

    Returns a move action, a random sidestep when another agent blocks the next
    cell, an exploration action when no path exists, or a noop.

    A cached path is reused only when it was planned for the same target and mode
    and its next cell is exactly one orthogonal step from the current position.
    After a failed or blocked move the cache can be ahead of the agent; reusing it
    would yield a non-unit step, which maps to noop while the cache keeps being
    consumed.
    """
    # Check for location loop (agents blocking each other back and forth).
    if self._is_in_location_loop(s):
        action = self._break_location_loop(s)
        if action:
            return action
        # If the loop can't be broken, fall through to normal pathfinding.

    start = (s.row, s.col)
    if start == target and not reach_adjacent:
        return self._noop()

    goal_cells = compute_goal_cells(s, target, reach_adjacent, CellType)  # type: ignore[arg-type]
    if not goal_cells:
        if DEBUG:
            print(f"[A{s.agent_id}] PATHFIND: No goal cells for {target}")
        return self._noop()

    # Reuse the cached path only if it still starts one step from where we are.
    path = None
    if s.cached_path and s.cached_path_target == target and s.cached_path_reach_adjacent == reach_adjacent:
        next_r, next_c = s.cached_path[0]
        one_step_away = abs(next_r - s.row) + abs(next_c - s.col) == 1
        if one_step_away and path_is_traversable(s, next_r, next_c, CellType):  # type: ignore[arg-type]
            path = s.cached_path

    # Compute a fresh path when the cache could not be reused.
    if path is None:
        path = shortest_path(s, start, goal_cells, False, CellType)  # type: ignore[arg-type]
        s.cached_path = path.copy() if path else None
        s.cached_path_target = target
        s.cached_path_reach_adjacent = reach_adjacent

    if not path:
        if DEBUG:
            print(f"[A{s.agent_id}] PATHFIND: No path to {target}, exploring instead")
        return self._explore(s)

    next_pos = path[0]
    dr = next_pos[0] - s.row
    dc = next_pos[1] - s.col

    # Another agent stands on the next cell: sidestep randomly and re-plan next time.
    if (next_pos[0], next_pos[1]) in s.agent_occupancy:
        action = self._try_random_direction(s)
        if action:
            s.cached_path = None
            s.cached_path_target = None
            return action
        return self._noop()

    # Consume the step we are about to take from the cache.
    if s.cached_path:
        s.cached_path = s.cached_path[1:]
        if not s.cached_path:
            s.cached_path = None
            s.cached_path_target = None

    if dr == -1 and dc == 0:
        return self._move("north")
    elif dr == 1 and dc == 0:
        return self._move("south")
    elif dr == 0 and dc == 1:
        return self._move("east")
    elif dr == 0 and dc == -1:
        return self._move("west")

    return self._noop()
1236
+
1237
def _try_random_direction(self, s: CogsguardAgentState) -> Optional[Action]:
    """Move toward a randomly chosen open neighbor, or return None if boxed in.

    The four cardinal directions are shuffled with ``random.shuffle``. The first
    one whose target cell is in bounds, FREE in the occupancy map, and not
    occupied by another agent is taken.
    """
    order: list[CardinalDirection] = ["north", "south", "east", "west"]
    random.shuffle(order)
    for heading in order:
        dr, dc = self._move_deltas[heading]
        nr, nc = s.row + dr, s.col + dc
        if not path_is_within_bounds(s, nr, nc):  # type: ignore[arg-type]
            continue
        if s.occupancy[nr][nc] != CellType.FREE.value or (nr, nc) in s.agent_occupancy:
            continue
        return self._move(heading)
    return None
1248
+
1249
+ def _is_in_location_loop(self, s: CogsguardAgentState) -> bool:
1250
+ """Detect if agent is stuck in a back-and-forth location loop.
1251
+
1252
+ Detects patterns like A→B→A→B→A (oscillating between 2 positions 3+ times).
1253
+ Returns True if such a loop is detected.
1254
+ """
1255
+ history = s.position_history
1256
+ # Need at least 5 positions to detect A→B→A→B→A pattern
1257
+ if len(history) < 5:
1258
+ return False
1259
+
1260
+ # Check last 6 positions for oscillation pattern
1261
+ recent = history[-6:] if len(history) >= 6 else history
1262
+
1263
+ # Count unique positions in recent history
1264
+ unique_positions = set(recent)
1265
+
1266
+ # If only 2 unique positions in last 6 moves, we're oscillating
1267
+ if len(unique_positions) <= 2 and len(recent) >= 5:
1268
+ if DEBUG:
1269
+ print(f"[A{s.agent_id}] LOOP_DETECTED: Oscillating between {unique_positions}")
1270
+ return True
1271
+
1272
+ return False
1273
+
1274
def _break_location_loop(self, s: CogsguardAgentState) -> Optional[Action]:
    """Escape an oscillation with a random free move (None if every side is blocked).

    First drops the cached path, so the next pathfinding call re-plans, and
    empties the position history, so loop detection starts over.
    """
    if DEBUG:
        print(f"[A{s.agent_id}] BREAKING_LOOP: Attempting random move to escape")

    s.cached_path = s.cached_path_target = None
    del s.position_history[:]

    return self._try_random_direction(s)
1290
+
1291
+
1292
+ # =============================================================================
1293
+ # Policy wrapper
1294
+ # =============================================================================
1295
+
1296
+
1297
+ def _parse_vibe_list(raw: Optional[str]) -> list[str]:
1298
+ if not raw:
1299
+ return []
1300
+ return [entry.strip().lower() for entry in raw.split(",") if entry.strip()]
1301
+
1302
+
1303
+ class CogsguardPolicy(MultiAgentPolicy):
1304
+ """Multi-agent policy for CoGsGuard with vibe-based role selection.
1305
+
1306
+ Agents use vibes to determine their behavior:
1307
+ - default: do nothing
1308
+ - gear: pick a role via smart or evolutionary coordinator, change vibe to that role
1309
+ - miner/scout/aligner/scrambler: get gear then execute role
1310
+ - heart: do nothing
1311
+
1312
+ Initial vibe counts can be specified via URI query parameters:
1313
+ metta://policy/role_py?miner=4&scrambler=2&gear=1
1314
+ You can also set a fixed role pattern with:
1315
+ metta://policy/role_py?role_cycle=aligner,miner,scrambler,scout
1316
+ metta://policy/role_py?role_order=aligner,miner,aligner,miner,scout
1317
+
1318
+ Vibes are assigned to agents in order. If counts don't sum to num_agents,
1319
+ remaining agents get "gear" vibe (which picks a role via the smart coordinator).
1320
+ """
1321
+
1322
+ short_names = ["role_py"]
1323
+
1324
def __init__(
    self,
    policy_env_info: PolicyEnvInterface,
    device: str = "cpu",
    role_cycle: Optional[str] = None,
    role_order: Optional[str] = None,
    **vibe_counts: int,
):
    """Set up shared coordinators and the per-agent initial vibe assignment.

    Args:
        policy_env_info: Environment interface (action names, observation
            features, agent count).
        device: Device string forwarded to ``MultiAgentPolicy``.
        role_cycle: Comma-separated vibes repeated across agents until every
            agent has one. Vibes the environment does not offer are skipped.
        role_order: Comma-separated vibes assigned to agents in order. Unknown
            names fall back to ``"default"``; agents left over get ``"gear"``
            when that vibe exists. Takes precedence over ``role_cycle``.
        **vibe_counts: Per-vibe agent counts (e.g. ``miner=4``) plus the optional
            ``evolution`` / ``evolutionary`` / ``evolve`` flags. Counts may be
            ints or numeric strings; only vibes the environment offers count.
    """
    super().__init__(policy_env_info, device=device)
    self._agent_policies: dict[int, StatefulAgentPolicy[CogsguardAgentState]] = {}
    self._smart_role_coordinator = _shared_coordinator(policy_env_info)
    self._feature_by_id = {feature.id: feature for feature in policy_env_info.obs_features}
    self._action_name_to_index = {name: idx for idx, name in enumerate(policy_env_info.action_names)}
    self._noop_action_value = dtype_actions.type(self._action_name_to_index.get("noop", 0))

    def _parse_flag(value: object) -> bool:
        """Interpret a bool/int/str option value as a flag; anything else is False."""
        if isinstance(value, bool):
            return value
        if isinstance(value, int):
            return value != 0
        if isinstance(value, str):
            return value.strip().lower() in {"1", "true", "yes", "on"}
        return False

    def _parse_count(value: object) -> Optional[int]:
        """Interpret an int or numeric-string option value as a count; None otherwise."""
        if isinstance(value, int):
            return value
        if isinstance(value, str):
            try:
                return int(value.strip())
            except ValueError:
                return None
        return None

    # Pop every alias up front. Chaining the pops with `or` would short-circuit and
    # leave later aliases in vibe_counts, where they would be read as vibe counts.
    evolution_flags = [vibe_counts.pop(key, None) for key in ("evolution", "evolutionary", "evolve")]
    self._use_evolutionary_roles = any(_parse_flag(flag) for flag in evolution_flags)
    self._evolutionary_role_coordinator = (
        EvolutionaryRoleCoordinator(policy_env_info.num_agents) if self._use_evolutionary_roles else None
    )
    self._evolutionary_hooks_configured = False

    available_vibes = {
        name[len("change_vibe_") :] for name in policy_env_info.action_names if name.startswith("change_vibe_")
    }
    role_vibes = [vibe for vibe in ["scrambler", "aligner", "miner", "scout"] if vibe in available_vibes]
    num_agents = policy_env_info.num_agents

    self._initial_vibes: list[str] = []
    role_cycle_list = _parse_vibe_list(role_cycle)
    role_order_list = _parse_vibe_list(role_order)

    if role_order_list:
        fallback_vibe = "default"
        for vibe in role_order_list:
            if vibe in available_vibes or vibe == "default":
                self._initial_vibes.append(vibe)
            else:
                if DEBUG:
                    print(f"[CogsguardPolicy] Unknown role_order vibe '{vibe}', using '{fallback_vibe}'")
                self._initial_vibes.append(fallback_vibe)
        remaining = num_agents - len(self._initial_vibes)
        if remaining > 0 and "gear" in available_vibes:
            self._initial_vibes.extend(["gear"] * remaining)
    elif role_cycle_list:
        cycle = [vibe for vibe in role_cycle_list if vibe in available_vibes]
        if cycle:
            while len(self._initial_vibes) < num_agents:
                self._initial_vibes.extend(cycle)
            self._initial_vibes = self._initial_vibes[:num_agents]

    if not self._initial_vibes:
        # Build the assignment from URI counts (e.g. ?scrambler=1&miner=4). Query
        # values may arrive as strings, and only vibes the environment offers are
        # meaningful (the same filter role_order/role_cycle apply).
        counts: dict[str, int] = {}
        for key, raw in vibe_counts.items():
            parsed = _parse_count(raw)
            if parsed is not None and key in available_vibes:
                counts[key] = parsed
        if not counts and role_vibes:
            counts = {"scrambler": 1, "miner": 4}

        for vibe_name in role_vibes:  # Role vibes first
            self._initial_vibes.extend([vibe_name] * counts.get(vibe_name, 0))
        # Gear vibes: these agents pick a role at runtime.
        if "gear" in available_vibes:
            self._initial_vibes.extend(["gear"] * counts.get("gear", 0))
        remaining = num_agents - len(self._initial_vibes)
        if remaining > 0 and "gear" in available_vibes:
            self._initial_vibes.extend(["gear"] * remaining)

    if DEBUG:
        print(f"[CogsguardPolicy] Initial vibe assignment: {self._initial_vibes}")
1405
+
1406
def agent_policy(self, agent_id: int) -> StatefulAgentPolicy[CogsguardAgentState]:
    """Return the per-agent policy wrapper, building it on first request.

    Every agent gets a ``CogsguardMultiRoleImpl``; its concrete role is decided
    at runtime from the agent's vibe. The agent's initial target vibe is taken
    from ``self._initial_vibes`` by index; agents beyond that list get no target
    vibe and stay on "default" (noop).

    The evolutionary coordinator's behavior hooks are built once, from the
    first implementation created while hooks are not yet configured.
    """
    if agent_id in self._agent_policies:
        return self._agent_policies[agent_id]

    target_vibe: Optional[str] = (
        self._initial_vibes[agent_id] if agent_id < len(self._initial_vibes) else None
    )

    multi_role = CogsguardMultiRoleImpl(
        self._policy_env_info,
        agent_id,
        initial_target_vibe=target_vibe,
        smart_role_coordinator=self._smart_role_coordinator,
        evolutionary_role_coordinator=self._evolutionary_role_coordinator,
        use_evolutionary_roles=self._use_evolutionary_roles,
    )

    coordinator = self._evolutionary_role_coordinator
    needs_hooks = coordinator is not None and not self._evolutionary_hooks_configured
    if needs_hooks:
        # Local import, kept as in the original (presumably avoids an import cycle).
        from .behavior_hooks import build_cogsguard_behavior_hooks

        coordinator.behavior_hooks.update(build_cogsguard_behavior_hooks(multi_role))
        self._evolutionary_hooks_configured = True

    self._agent_policies[agent_id] = StatefulAgentPolicy(multi_role, self._policy_env_info, agent_id=agent_id)
    return self._agent_policies[agent_id]
1432
+
1433
def step_batch(self, raw_observations: np.ndarray, raw_actions: np.ndarray) -> None:
    """Step every agent once and write the chosen action indices into ``raw_actions``.

    Only the first ``min(len(raw_observations), num_agents)`` agents are stepped.
    Action names that are not in ``self._action_name_to_index`` fall back to the
    noop action index.
    """
    agent_count = min(raw_observations.shape[0], self._policy_env_info.num_agents)
    for idx in range(agent_count):
        agent_obs = self._raw_obs_to_agent_obs(idx, raw_observations[idx])
        chosen = self.agent_policy(idx).step(agent_obs)
        index = self._action_name_to_index.get(chosen.name, self._noop_action_value)
        raw_actions[idx] = dtype_actions.type(index)
1440
+
1441
def _raw_obs_to_agent_obs(self, agent_id: int, raw_obs: np.ndarray) -> AgentObservation:
    """Decode raw ``(location, feature_id, value)`` triples into an ``AgentObservation``.

    Decoding stops at the first token whose feature id is 0xFF, which is treated
    as an end-of-tokens marker. Tokens whose feature id is not in
    ``self._feature_by_id`` are skipped.
    """
    end_marker = 0xFF
    decoded: list[ObservationToken] = []
    for raw_token in raw_obs:
        fid = int(raw_token[1])
        if fid == end_marker:
            break
        feature = self._feature_by_id.get(fid)
        if feature is None:
            continue
        packed_location = int(raw_token[0])
        token_value = int(raw_token[2])
        decoded.append(
            ObservationToken(
                feature=feature,
                value=token_value,
                raw_token=(packed_location, fid, token_value),
            )
        )
    return AgentObservation(agent_id=agent_id, tokens=decoded)
1460
+
1461
+
1462
class CogsguardMultiRoleImpl(CogsguardAgentPolicyImpl):
    """Multi-role implementation that delegates to role-specific behavior based on vibe.

    The agent is constructed as a MINER, but its effective role follows its
    current vibe at runtime. Role behavior is delegated to lazily created
    per-role implementations (aligner, miner, scout, scrambler).
    """

    def __init__(
        self,
        policy_env_info: PolicyEnvInterface,
        agent_id: int,
        initial_target_vibe: Optional[str] = None,
        smart_role_coordinator: Optional[SmartRoleCoordinator] = None,
        evolutionary_role_coordinator: Optional[EvolutionaryRoleCoordinator] = None,
        use_evolutionary_roles: bool = False,
    ):
        # Initialize with MINER as default; the role is updated based on vibe.
        super().__init__(
            policy_env_info,
            agent_id,
            Role.MINER,
            smart_role_coordinator=smart_role_coordinator,
            evolutionary_role_coordinator=evolutionary_role_coordinator,
            use_evolutionary_roles=use_evolutionary_roles,
        )

        # Vibe to switch to before anything else (None = no forced initial vibe).
        self._initial_target_vibe = initial_target_vibe
        self._initial_vibe_set = False
        # Coordinator-driven role switching is only enabled for agents that start on "gear".
        self._smart_role_enabled = initial_target_vibe == "gear"

        # Role implementations are created lazily by _get_role_impl.
        self._role_impls: dict[Role, CogsguardAgentPolicyImpl] = {}

    def _execute_phase(self, s: CogsguardAgentState) -> Action:
        """Execute action for current phase, handling initial vibe assignment.

        Overrides base class to:
        1. Handle initial vibe assignment from URI params
        2. Skip the hardcoded "agent 0 = scrambler" logic when initial vibe is configured

        If the target vibe is not available in the environment, the assignment
        is abandoned, smart role switching is disabled, and a noop is returned.
        """
        # If we have a target vibe and haven't switched yet, do it first.
        if self._initial_target_vibe and not self._initial_vibe_set:
            if not self._has_vibe(self._initial_target_vibe):
                self._initial_vibe_set = True
                self._smart_role_enabled = False
                return self._noop()
            if s.current_vibe != self._initial_target_vibe:
                if DEBUG:
                    print(
                        f"[A{s.agent_id}] INITIAL_VIBE: Switching from {s.current_vibe} to {self._initial_target_vibe}"
                    )
                return change_vibe_action(self._initial_target_vibe, action_names=self._action_names)
            self._initial_vibe_set = True

        # With a configured initial vibe, handle vibe-based behavior directly,
        # bypassing the base class's agent-0 scrambler logic.
        if self._initial_target_vibe:
            return self._execute_vibe_behavior(s)

        # Continue with normal phase execution (includes agent 0 scrambler logic).
        return super()._execute_phase(s)

    def _execute_vibe_behavior(self, s: CogsguardAgentState) -> Action:
        """Execute vibe-based behavior without the hardcoded agent 0 scrambler override.

        - "default" / "heart": noop.
        - "gear": choose a role vibe and switch to it (noop if unavailable).
        - role vibes (keys of VIBE_TO_ROLE): optionally switch role via the smart
          coordinator, then get gear or execute the role depending on phase.
        """
        vibe = s.current_vibe

        # Default vibe: do nothing (wait for external vibe change).
        if vibe == "default":
            return self._noop()

        # Heart vibe: do nothing.
        if vibe == "heart":
            return self._noop()

        # Gear vibe: pick a role and change vibe to it.
        if vibe == "gear":
            selected_role = self._choose_role_vibe(s)
            if DEBUG:
                print(f"[A{s.agent_id}] GEAR_VIBE: Picking role vibe: {selected_role}")
            if not self._has_vibe(selected_role):
                return self._noop()
            return change_vibe_action(selected_role, action_names=self._action_names)

        # Role vibes: execute the role behavior.
        if vibe in VIBE_TO_ROLE:
            action = self._maybe_switch_smart_role(s)
            if action is not None:
                return action
            if s.phase == CogsguardPhase.GET_GEAR:
                return self._do_get_gear(s)
            elif s.phase == CogsguardPhase.EXECUTE_ROLE:
                return self.execute_role(s)

        return self._noop()

    def _maybe_switch_smart_role(self, s: CogsguardAgentState) -> Optional[Action]:
        """Return a vibe-change action if the agent should change role, else None.

        Switching is skipped when smart roles are disabled, an action is pending,
        the agent is getting gear, or the role-switch cooldown is still active.
        Held gear takes precedence: an agent holding a role's gear is switched to
        that role's vibe. Otherwise the smart role coordinator picks the role.
        """
        if not self._smart_role_enabled or self._smart_role_coordinator is None:
            return None
        if s._pending_action_type is not None:
            return None
        if s.phase == CogsguardPhase.GET_GEAR:
            return None
        if s.step_count < s.role_lock_until_step:
            return None

        gear_role = None
        if s.aligner > 0:
            gear_role = "aligner"
        elif s.scrambler > 0:
            gear_role = "scrambler"
        elif s.miner > 0:
            gear_role = "miner"
        elif s.scout > 0:
            gear_role = "scout"
        if gear_role and gear_role != s.current_vibe:
            if not self._has_vibe(gear_role):
                return None
            return change_vibe_action(gear_role, action_names=self._action_names)

        selected_role = self._smart_role_coordinator.choose_role(s.agent_id)
        if selected_role == s.current_vibe:
            return None

        # Only record the switch (and start the cooldown) when a vibe change is
        # actually issued; otherwise the agent would be locked out of switching
        # even though no switch happened.
        if not self._has_vibe(selected_role):
            return None

        s.last_role_switch_step = s.step_count
        s.role_lock_until_step = s.step_count + SMART_ROLE_SWITCH_COOLDOWN
        if DEBUG:
            print(f"[A{s.agent_id}] SMART_ROLE: Switching to {selected_role}")
        return change_vibe_action(selected_role, action_names=self._action_names)

    def _get_role_impl(self, role: Role) -> CogsguardAgentPolicyImpl:
        """Get or create (and cache) the role-specific implementation for ``role``."""
        if role not in self._role_impls:
            from .aligner import AlignerAgentPolicyImpl
            from .miner import MinerAgentPolicyImpl
            from .scout import ScoutAgentPolicyImpl
            from .scrambler import ScramblerAgentPolicyImpl

            impl_class = {
                Role.MINER: MinerAgentPolicyImpl,
                Role.SCOUT: ScoutAgentPolicyImpl,
                Role.ALIGNER: AlignerAgentPolicyImpl,
                Role.SCRAMBLER: ScramblerAgentPolicyImpl,
            }[role]

            self._role_impls[role] = impl_class(
                self._policy_env_info,
                self._agent_id,
                role,
                smart_role_coordinator=self._smart_role_coordinator,
            )

        return self._role_impls[role]

    def execute_role(self, s: CogsguardAgentState) -> Action:
        """Delegate to role-specific implementation based on current role (set from vibe)."""
        role_impl = self._get_role_impl(s.role)
        return role_impl.execute_role(s)
1618
+
1619
+
1620
class CogsguardGeneralistImpl(CogsguardAgentPolicyImpl):
    """Generalist agent that picks roles based on situational priorities.

    On every phase update the agent re-selects its desired role (subject to a
    switch cooldown), then either fetches gear or executes the role. Role
    execution is delegated to lazily created role-specific implementations.
    """

    ROLE = Role.MINER
    # Steps after a role switch during which further switches are normally blocked.
    ROLE_SWITCH_COOLDOWN = 120
    # Before this step, scouting is favored while the map is still unexplored.
    EARLY_SCOUT_STEPS = 80
    # Known-structure count below which the early game still favors scouting.
    MIN_STRUCTURES_FOR_MIDGAME = 6
    # Recharge when energy is below MOVE_ENERGY_COST * MIN_ENERGY_BUFFER.
    MIN_ENERGY_BUFFER = 2
    # After this many steps, execute the role even without gear.
    GEAR_TIMEOUT_STEPS = 30
    # Keep mining while cargo is more than this many units below capacity.
    CARGO_HEADROOM = 2

    def __init__(
        self,
        policy_env_info: PolicyEnvInterface,
        agent_id: int,
        smart_role_coordinator: Optional[SmartRoleCoordinator] = None,
    ):
        super().__init__(policy_env_info, agent_id, Role.MINER, smart_role_coordinator=smart_role_coordinator)
        # Role implementations are created lazily by _get_role_impl.
        self._role_impls: dict[Role, CogsguardAgentPolicyImpl] = {}

    def _update_phase(self, s: CogsguardAgentState) -> None:
        """Re-select the role (respecting the cooldown) and choose the current phase."""
        desired_role = self._select_role(s)
        if self._should_switch_role(s, desired_role):
            if DEBUG:
                print(f"[A{s.agent_id}] GENERALIST: Switching role {s.role.value} -> {desired_role.value}")
            s.role = desired_role
            s.last_role_switch_step = s.step_count
            s.role_lock_until_step = s.step_count + self.ROLE_SWITCH_COOLDOWN

        # Execute the role once geared, or after a grace period without gear.
        if s.has_gear() or s.step_count > self.GEAR_TIMEOUT_STEPS:
            s.phase = CogsguardPhase.EXECUTE_ROLE
        else:
            s.phase = CogsguardPhase.GET_GEAR

    def _execute_phase(self, s: CogsguardAgentState) -> Action:
        """Recharge if needed, otherwise act according to the current phase."""
        if self._should_recharge(s):
            return self._do_recharge(s)
        if s.phase == CogsguardPhase.GET_GEAR:
            return self._do_get_gear(s)
        if s.phase == CogsguardPhase.EXECUTE_ROLE:
            return self.execute_role(s)
        return self._noop()

    def execute_role(self, s: CogsguardAgentState) -> Action:
        """Delegate to the implementation for the agent's current role."""
        role_impl = self._get_role_impl(s.role)
        return role_impl.execute_role(s)

    def _select_role(self, s: CogsguardAgentState) -> Role:
        """Return the role this agent would ideally play right now.

        Priority order: keep the role while an action is pending; scout until
        the hub is known; align while an alignment target is pending; scout in
        the early game while the map is sparse; keep scrambler/aligner roles
        while junction work remains; fill role deficits; then fall back to
        junction work, scouting, mining, or a balanced pick.
        """
        if s._pending_action_type is not None:
            return s.role

        hub_known = s.stations.get("hub") is not None
        chest_known = s.stations.get("chest") is not None

        if not hub_known:
            return Role.SCOUT

        if s._pending_alignment_target is not None:
            return Role.ALIGNER

        if s.step_count < self.EARLY_SCOUT_STEPS and (
            not chest_known or len(s.structures) < self.MIN_STRUCTURES_FOR_MIDGAME
        ):
            return Role.SCOUT

        junctions = s.get_structures_by_type(StructureType.CHARGER)
        has_enemy_junctions = any(junction.alignment == "clips" for junction in junctions)
        has_neutral_junctions = any(junction.alignment in (None, "neutral") for junction in junctions)
        role_counts = self._role_counts()

        if s.role == Role.SCRAMBLER:
            if has_neutral_junctions and self._role_is_ready(s, Role.ALIGNER):
                return Role.ALIGNER
            if has_enemy_junctions:
                return Role.SCRAMBLER
        if s.role == Role.ALIGNER:
            if (
                s._pending_alignment_target is None
                and not has_neutral_junctions
                and has_enemy_junctions
                and self._role_is_ready(s, Role.SCRAMBLER)
            ):
                return Role.SCRAMBLER
            if has_enemy_junctions or has_neutral_junctions:
                return Role.ALIGNER

        target_counts = self._target_role_counts(
            num_agents=self._policy_env_info.num_agents,
            has_enemy_junctions=has_enemy_junctions,
            has_neutral_junctions=has_neutral_junctions,
            hub_known=hub_known,
            step_count=s.step_count,
        )
        deficit_role = self._pick_deficit_role(s, role_counts, target_counts)
        if deficit_role is not None:
            return deficit_role

        if has_enemy_junctions and self._role_is_ready(s, Role.SCRAMBLER):
            return Role.SCRAMBLER
        if has_neutral_junctions and self._role_is_ready(s, Role.ALIGNER):
            return Role.ALIGNER

        if not s.get_usable_extractors():
            return Role.SCOUT

        if s.total_cargo < s.cargo_capacity - self.CARGO_HEADROOM:
            return Role.MINER

        return self._pick_balanced_role(s, has_enemy_junctions, has_neutral_junctions)

    def _should_switch_role(self, s: CogsguardAgentState, desired_role: Role) -> bool:
        """Return whether the agent should switch to ``desired_role`` now.

        Switches never happen while an action is pending. Switching to aligner
        bypasses the cooldown when an alignment target is pending, or when
        leaving scrambler while neutral junctions are known.
        """
        if desired_role == s.role:
            return False
        if s._pending_action_type is not None:
            return False
        if desired_role == Role.ALIGNER and s._pending_alignment_target is not None:
            return True
        if desired_role == Role.ALIGNER and s.role == Role.SCRAMBLER:
            has_neutral = any(
                junction.alignment in (None, "neutral") for junction in s.get_structures_by_type(StructureType.CHARGER)
            )
            if has_neutral:
                return True
        if s.step_count < s.role_lock_until_step:
            return False
        return True

    def _role_is_ready(self, s: CogsguardAgentState, role: Role) -> bool:
        """Return whether ``role`` can be taken up now.

        Aligner and scrambler require a known hub; all other roles are always ready.
        """
        if role in (Role.ALIGNER, Role.SCRAMBLER):
            return s.stations.get("hub") is not None
        return True

    def _target_role_counts(
        self,
        num_agents: int,
        has_enemy_junctions: bool,
        has_neutral_junctions: bool,
        hub_known: bool,
        step_count: int,
    ) -> dict[Role, int]:
        """Return the desired number of agents per role for the current situation."""
        targets: dict[Role, int] = {}
        if step_count < self.EARLY_SCOUT_STEPS:
            targets[Role.SCOUT] = 2 if num_agents >= 8 else 1
        else:
            targets[Role.SCOUT] = 1
        targets[Role.MINER] = max(4, num_agents // 2)
        if hub_known:
            targets[Role.SCRAMBLER] = 2 if has_enemy_junctions else 1
            targets[Role.ALIGNER] = 2 if has_neutral_junctions else 1
        return targets

    def _pick_deficit_role(
        self,
        s: CogsguardAgentState,
        role_counts: dict[Role, int],
        target_counts: dict[Role, int],
    ) -> Role | None:
        """Pick a role that is below its target count, or None.

        Each missing slot contributes one entry (scrambler, aligner, scout, miner
        order); the agent id selects an entry so agents spread across deficits.
        Returns None if there is no deficit or the chosen role is not ready.
        """
        if not target_counts:
            return None
        deficits: list[Role] = []
        ordered_roles = [Role.SCRAMBLER, Role.ALIGNER, Role.SCOUT, Role.MINER]
        for role in ordered_roles:
            target = target_counts.get(role, 0)
            if target <= 0:
                continue
            deficit = max(target - role_counts.get(role, 0), 0)
            deficits.extend([role] * deficit)
        if not deficits:
            return None
        role = deficits[s.agent_id % len(deficits)]
        if self._role_is_ready(s, role):
            return role
        return None

    def _should_recharge(self, s: CogsguardAgentState) -> bool:
        """Recharge only with empty cargo, low energy, and a known hub."""
        if s.total_cargo > 0:
            return False
        if s.energy >= s.MOVE_ENERGY_COST * self.MIN_ENERGY_BUFFER:
            return False
        return s.stations.get("hub") is not None

    def _role_has_gear(self, s: CogsguardAgentState, role: Role) -> bool:
        """Return whether the agent currently holds gear for ``role``."""
        return getattr(s, ROLE_TO_GEAR[role], 0) > 0

    def _role_station_known(self, s: CogsguardAgentState, role: Role) -> bool:
        """Return whether the station associated with ``role`` has been seen."""
        return s.stations.get(ROLE_TO_STATION[role]) is not None

    def _pick_balanced_role(
        self,
        s: CogsguardAgentState,
        has_enemy_junctions: bool,
        has_neutral_junctions: bool,
    ) -> Role:
        """Score candidate roles and return the best one (ties keep the earlier candidate).

        The score favors the current role, roles the agent holds gear for (or
        whose station is known), and roles with fewer agents.
        """
        candidates = [Role.MINER, Role.SCOUT]
        if has_enemy_junctions and self._role_is_ready(s, Role.SCRAMBLER):
            candidates.append(Role.SCRAMBLER)
        if has_neutral_junctions and self._role_is_ready(s, Role.ALIGNER):
            candidates.append(Role.ALIGNER)

        role_counts = self._role_counts()
        best_role = s.role
        best_score = float("-inf")
        for role in candidates:
            score = 0
            if role == s.role:
                score += 2
            if self._role_has_gear(s, role):
                score += 3
            elif self._role_station_known(s, role):
                score += 1
            if role_counts:
                score += 2 - role_counts.get(role, 0)
            if score > best_score:
                best_score = score
                best_role = role
        return best_role

    def _role_counts(self) -> dict[Role, int]:
        """Count agents per role from the coordinator's snapshots ({} without a coordinator)."""
        if self._smart_role_coordinator is None:
            return {}
        counts = {role: 0 for role in Role}
        for snapshot in self._smart_role_coordinator.agent_snapshots.values():
            counts[snapshot.role] += 1
        return counts

    def _get_role_impl(self, role: Role) -> CogsguardAgentPolicyImpl:
        """Get or create (and cache) the role-specific implementation for ``role``."""
        if role not in self._role_impls:
            from .aligner import AlignerAgentPolicyImpl
            from .miner import MinerAgentPolicyImpl
            from .scout import ScoutAgentPolicyImpl
            from .scrambler import ScramblerAgentPolicyImpl

            impl_class = {
                Role.MINER: MinerAgentPolicyImpl,
                Role.SCOUT: ScoutAgentPolicyImpl,
                Role.ALIGNER: AlignerAgentPolicyImpl,
                Role.SCRAMBLER: ScramblerAgentPolicyImpl,
            }[role]

            self._role_impls[role] = impl_class(
                self._policy_env_info,
                self._agent_id,
                role,
                smart_role_coordinator=self._smart_role_coordinator,
            )
        return self._role_impls[role]
1866
+
1867
+
1868
class CogsguardWomboImpl(CogsguardGeneralistImpl):
    """Generalist agent that prioritizes aligning multiple junctions.

    While fewer than TARGET_ALIGNED_JUNCTIONS junctions are aligned (per the
    smart role coordinator; 0 is assumed without a coordinator), the agent
    keeps scrambler/aligner roles it is already geared for, and raises target
    counts for scouts, scramblers and aligners. Recharge behavior is inherited
    unchanged from CogsguardGeneralistImpl.
    """

    TARGET_ALIGNED_JUNCTIONS = 2
    JUNCTION_PUSH_SCOUTS = 2
    JUNCTION_PUSH_ALIGNERS = 2
    JUNCTION_PUSH_SCRAMBLERS = 2
    MIN_MINERS = 4

    def _aligned_junction_count(self) -> int:
        """Return the coordinator's aligned junction count, or 0 without a coordinator."""
        if self._smart_role_coordinator is None:
            return 0
        return self._smart_role_coordinator.aligned_junction_count()

    def _junction_push_active(self) -> bool:
        """Return whether fewer than TARGET_ALIGNED_JUNCTIONS junctions are aligned."""
        return self._aligned_junction_count() < self.TARGET_ALIGNED_JUNCTIONS

    def _select_role(self, s: CogsguardAgentState) -> Role:
        """During the junction push, stick with geared junction roles; else defer to the generalist."""
        if self._junction_push_active():
            if s._pending_action_type is not None:
                return s.role
            if s.stations.get("hub") is None:
                return Role.SCOUT
            if s.role in (Role.SCRAMBLER, Role.ALIGNER) and s.has_gear():
                return s.role
            if s.role == Role.SCRAMBLER and s._pending_alignment_target is not None:
                return Role.SCRAMBLER
        return super()._select_role(s)

    def _target_role_counts(
        self,
        num_agents: int,
        has_enemy_junctions: bool,
        has_neutral_junctions: bool,
        hub_known: bool,
        step_count: int,
    ) -> dict[Role, int]:
        """Generalist targets, raised for scouts/scramblers/aligners during the junction push."""
        targets = super()._target_role_counts(
            num_agents=num_agents,
            has_enemy_junctions=has_enemy_junctions,
            has_neutral_junctions=has_neutral_junctions,
            hub_known=hub_known,
            step_count=step_count,
        )

        if self._junction_push_active():
            targets[Role.SCOUT] = max(targets.get(Role.SCOUT, 0), self.JUNCTION_PUSH_SCOUTS)
            if hub_known:
                targets[Role.SCRAMBLER] = max(targets.get(Role.SCRAMBLER, 0), self.JUNCTION_PUSH_SCRAMBLERS)
                targets[Role.ALIGNER] = max(targets.get(Role.ALIGNER, 0), self.JUNCTION_PUSH_ALIGNERS)
        targets[Role.MINER] = max(self.MIN_MINERS, num_agents // 2)

        return targets
1932
+
1933
+
1934
class CogsguardGeneralistPolicy(CogsguardPolicy):
    """Generalist policy that adapts roles based on map and resource priorities.

    Each agent is driven by a ``CogsguardGeneralistImpl`` that shares this
    policy's smart role coordinator.
    """

    def __init__(self, policy_env_info: PolicyEnvInterface, device: str = "cpu", **extra_kwargs: int):
        # Extra keyword arguments are forwarded unchanged to CogsguardPolicy.
        super().__init__(policy_env_info, device=device, **extra_kwargs)

    def agent_policy(self, agent_id: int) -> StatefulAgentPolicy[CogsguardAgentState]:
        """Return the cached per-agent policy, building a generalist one on first request."""
        if agent_id in self._agent_policies:
            return self._agent_policies[agent_id]
        generalist = CogsguardGeneralistImpl(
            self._policy_env_info,
            agent_id,
            smart_role_coordinator=self._smart_role_coordinator,
        )
        self._agent_policies[agent_id] = StatefulAgentPolicy(generalist, self._policy_env_info, agent_id=agent_id)
        return self._agent_policies[agent_id]
1949
+
1950
+
1951
class CogsguardWomboPolicy(CogsguardPolicy):
    """Generalist policy that prioritizes role rigs based on map conditions.

    Each agent is driven by a ``CogsguardWomboImpl`` that shares this policy's
    smart role coordinator.
    """

    short_names = ["wombo"]

    def __init__(self, policy_env_info: PolicyEnvInterface, device: str = "cpu", **extra_kwargs: int):
        # Extra keyword arguments are forwarded unchanged to CogsguardPolicy.
        super().__init__(policy_env_info, device=device, **extra_kwargs)

    def agent_policy(self, agent_id: int) -> StatefulAgentPolicy[CogsguardAgentState]:
        """Return the cached per-agent policy, building a wombo one on first request."""
        if agent_id in self._agent_policies:
            return self._agent_policies[agent_id]
        wombo = CogsguardWomboImpl(
            self._policy_env_info,
            agent_id,
            smart_role_coordinator=self._smart_role_coordinator,
        )
        self._agent_policies[agent_id] = StatefulAgentPolicy(wombo, self._policy_env_info, agent_id=agent_id)
        return self._agent_policies[agent_id]