cogames_agents-0.0.0.7-cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128)
  1. cogames_agents/__init__.py +0 -0
  2. cogames_agents/evals/__init__.py +5 -0
  3. cogames_agents/evals/planky_evals.py +415 -0
  4. cogames_agents/policy/__init__.py +0 -0
  5. cogames_agents/policy/evolution/__init__.py +0 -0
  6. cogames_agents/policy/evolution/cogsguard/__init__.py +0 -0
  7. cogames_agents/policy/evolution/cogsguard/evolution.py +695 -0
  8. cogames_agents/policy/evolution/cogsguard/evolutionary_coordinator.py +540 -0
  9. cogames_agents/policy/nim_agents/__init__.py +20 -0
  10. cogames_agents/policy/nim_agents/agents.py +98 -0
  11. cogames_agents/policy/nim_agents/bindings/generated/libnim_agents.dylib +0 -0
  12. cogames_agents/policy/nim_agents/bindings/generated/nim_agents.py +215 -0
  13. cogames_agents/policy/nim_agents/cogsguard_agents.nim +555 -0
  14. cogames_agents/policy/nim_agents/cogsguard_align_all_agents.nim +569 -0
  15. cogames_agents/policy/nim_agents/common.nim +1054 -0
  16. cogames_agents/policy/nim_agents/install.sh +1 -0
  17. cogames_agents/policy/nim_agents/ladybug_agent.nim +954 -0
  18. cogames_agents/policy/nim_agents/nim_agents.nim +68 -0
  19. cogames_agents/policy/nim_agents/nim_agents.nims +14 -0
  20. cogames_agents/policy/nim_agents/nimby.lock +3 -0
  21. cogames_agents/policy/nim_agents/racecar_agents.nim +844 -0
  22. cogames_agents/policy/nim_agents/random_agents.nim +68 -0
  23. cogames_agents/policy/nim_agents/test_agents.py +53 -0
  24. cogames_agents/policy/nim_agents/thinky_agents.nim +677 -0
  25. cogames_agents/policy/nim_agents/thinky_eval.py +230 -0
  26. cogames_agents/policy/scripted_agent/README.md +360 -0
  27. cogames_agents/policy/scripted_agent/__init__.py +0 -0
  28. cogames_agents/policy/scripted_agent/baseline_agent.py +1031 -0
  29. cogames_agents/policy/scripted_agent/cogas/__init__.py +5 -0
  30. cogames_agents/policy/scripted_agent/cogas/context.py +68 -0
  31. cogames_agents/policy/scripted_agent/cogas/entity_map.py +152 -0
  32. cogames_agents/policy/scripted_agent/cogas/goal.py +115 -0
  33. cogames_agents/policy/scripted_agent/cogas/goals/__init__.py +27 -0
  34. cogames_agents/policy/scripted_agent/cogas/goals/aligner.py +160 -0
  35. cogames_agents/policy/scripted_agent/cogas/goals/gear.py +197 -0
  36. cogames_agents/policy/scripted_agent/cogas/goals/miner.py +441 -0
  37. cogames_agents/policy/scripted_agent/cogas/goals/scout.py +40 -0
  38. cogames_agents/policy/scripted_agent/cogas/goals/scrambler.py +174 -0
  39. cogames_agents/policy/scripted_agent/cogas/goals/shared.py +160 -0
  40. cogames_agents/policy/scripted_agent/cogas/goals/stem.py +60 -0
  41. cogames_agents/policy/scripted_agent/cogas/goals/survive.py +100 -0
  42. cogames_agents/policy/scripted_agent/cogas/navigator.py +401 -0
  43. cogames_agents/policy/scripted_agent/cogas/obs_parser.py +238 -0
  44. cogames_agents/policy/scripted_agent/cogas/policy.py +525 -0
  45. cogames_agents/policy/scripted_agent/cogas/trace.py +69 -0
  46. cogames_agents/policy/scripted_agent/cogsguard/CLAUDE.md +517 -0
  47. cogames_agents/policy/scripted_agent/cogsguard/README.md +252 -0
  48. cogames_agents/policy/scripted_agent/cogsguard/__init__.py +74 -0
  49. cogames_agents/policy/scripted_agent/cogsguard/aligned_junction_held_investigation.md +152 -0
  50. cogames_agents/policy/scripted_agent/cogsguard/aligner.py +333 -0
  51. cogames_agents/policy/scripted_agent/cogsguard/behavior_hooks.py +44 -0
  52. cogames_agents/policy/scripted_agent/cogsguard/control_agent.py +323 -0
  53. cogames_agents/policy/scripted_agent/cogsguard/debug_agent.py +533 -0
  54. cogames_agents/policy/scripted_agent/cogsguard/miner.py +589 -0
  55. cogames_agents/policy/scripted_agent/cogsguard/options.py +67 -0
  56. cogames_agents/policy/scripted_agent/cogsguard/parity_metrics.py +36 -0
  57. cogames_agents/policy/scripted_agent/cogsguard/policy.py +1967 -0
  58. cogames_agents/policy/scripted_agent/cogsguard/prereq_trace.py +33 -0
  59. cogames_agents/policy/scripted_agent/cogsguard/role_trace.py +50 -0
  60. cogames_agents/policy/scripted_agent/cogsguard/roles.py +31 -0
  61. cogames_agents/policy/scripted_agent/cogsguard/rollout_trace.py +40 -0
  62. cogames_agents/policy/scripted_agent/cogsguard/scout.py +69 -0
  63. cogames_agents/policy/scripted_agent/cogsguard/scrambler.py +350 -0
  64. cogames_agents/policy/scripted_agent/cogsguard/targeted_agent.py +418 -0
  65. cogames_agents/policy/scripted_agent/cogsguard/teacher.py +224 -0
  66. cogames_agents/policy/scripted_agent/cogsguard/types.py +381 -0
  67. cogames_agents/policy/scripted_agent/cogsguard/v2_agent.py +49 -0
  68. cogames_agents/policy/scripted_agent/common/__init__.py +0 -0
  69. cogames_agents/policy/scripted_agent/common/geometry.py +24 -0
  70. cogames_agents/policy/scripted_agent/common/roles.py +34 -0
  71. cogames_agents/policy/scripted_agent/common/tag_utils.py +48 -0
  72. cogames_agents/policy/scripted_agent/demo_policy.py +242 -0
  73. cogames_agents/policy/scripted_agent/pathfinding.py +126 -0
  74. cogames_agents/policy/scripted_agent/pinky/DESIGN.md +317 -0
  75. cogames_agents/policy/scripted_agent/pinky/__init__.py +5 -0
  76. cogames_agents/policy/scripted_agent/pinky/behaviors/__init__.py +17 -0
  77. cogames_agents/policy/scripted_agent/pinky/behaviors/aligner.py +400 -0
  78. cogames_agents/policy/scripted_agent/pinky/behaviors/base.py +119 -0
  79. cogames_agents/policy/scripted_agent/pinky/behaviors/miner.py +632 -0
  80. cogames_agents/policy/scripted_agent/pinky/behaviors/scout.py +138 -0
  81. cogames_agents/policy/scripted_agent/pinky/behaviors/scrambler.py +433 -0
  82. cogames_agents/policy/scripted_agent/pinky/policy.py +570 -0
  83. cogames_agents/policy/scripted_agent/pinky/services/__init__.py +7 -0
  84. cogames_agents/policy/scripted_agent/pinky/services/map_tracker.py +808 -0
  85. cogames_agents/policy/scripted_agent/pinky/services/navigator.py +864 -0
  86. cogames_agents/policy/scripted_agent/pinky/services/safety.py +189 -0
  87. cogames_agents/policy/scripted_agent/pinky/state.py +299 -0
  88. cogames_agents/policy/scripted_agent/pinky/types.py +138 -0
  89. cogames_agents/policy/scripted_agent/planky/CLAUDE.md +124 -0
  90. cogames_agents/policy/scripted_agent/planky/IMPROVEMENTS.md +160 -0
  91. cogames_agents/policy/scripted_agent/planky/NOTES.md +153 -0
  92. cogames_agents/policy/scripted_agent/planky/PLAN.md +254 -0
  93. cogames_agents/policy/scripted_agent/planky/README.md +214 -0
  94. cogames_agents/policy/scripted_agent/planky/STRATEGY.md +100 -0
  95. cogames_agents/policy/scripted_agent/planky/__init__.py +5 -0
  96. cogames_agents/policy/scripted_agent/planky/context.py +68 -0
  97. cogames_agents/policy/scripted_agent/planky/entity_map.py +152 -0
  98. cogames_agents/policy/scripted_agent/planky/goal.py +107 -0
  99. cogames_agents/policy/scripted_agent/planky/goals/__init__.py +27 -0
  100. cogames_agents/policy/scripted_agent/planky/goals/aligner.py +168 -0
  101. cogames_agents/policy/scripted_agent/planky/goals/gear.py +179 -0
  102. cogames_agents/policy/scripted_agent/planky/goals/miner.py +416 -0
  103. cogames_agents/policy/scripted_agent/planky/goals/scout.py +40 -0
  104. cogames_agents/policy/scripted_agent/planky/goals/scrambler.py +174 -0
  105. cogames_agents/policy/scripted_agent/planky/goals/shared.py +160 -0
  106. cogames_agents/policy/scripted_agent/planky/goals/stem.py +49 -0
  107. cogames_agents/policy/scripted_agent/planky/goals/survive.py +96 -0
  108. cogames_agents/policy/scripted_agent/planky/navigator.py +388 -0
  109. cogames_agents/policy/scripted_agent/planky/obs_parser.py +238 -0
  110. cogames_agents/policy/scripted_agent/planky/policy.py +485 -0
  111. cogames_agents/policy/scripted_agent/planky/tests/__init__.py +0 -0
  112. cogames_agents/policy/scripted_agent/planky/tests/conftest.py +66 -0
  113. cogames_agents/policy/scripted_agent/planky/tests/helpers.py +152 -0
  114. cogames_agents/policy/scripted_agent/planky/tests/test_aligner.py +24 -0
  115. cogames_agents/policy/scripted_agent/planky/tests/test_miner.py +30 -0
  116. cogames_agents/policy/scripted_agent/planky/tests/test_scout.py +15 -0
  117. cogames_agents/policy/scripted_agent/planky/tests/test_scrambler.py +29 -0
  118. cogames_agents/policy/scripted_agent/planky/tests/test_stem.py +36 -0
  119. cogames_agents/policy/scripted_agent/planky/trace.py +69 -0
  120. cogames_agents/policy/scripted_agent/types.py +239 -0
  121. cogames_agents/policy/scripted_agent/unclipping_agent.py +461 -0
  122. cogames_agents/policy/scripted_agent/utils.py +381 -0
  123. cogames_agents/policy/scripted_registry.py +80 -0
  124. cogames_agents/py.typed +0 -0
  125. cogames_agents-0.0.0.7.dist-info/METADATA +98 -0
  126. cogames_agents-0.0.0.7.dist-info/RECORD +128 -0
  127. cogames_agents-0.0.0.7.dist-info/WHEEL +6 -0
  128. cogames_agents-0.0.0.7.dist-info/top_level.txt +1 -0
cogames_agents/policy/scripted_agent/cogsguard/targeted_agent.py
@@ -0,0 +1,418 @@
+ """CoGsGuard scripted policy with targeted role assignments."""
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass, field
+ from typing import Any, Optional
+
+ from cogames_agents.policy.scripted_agent.utils import change_vibe_action
+ from mettagrid.policy.policy import StatefulAgentPolicy
+ from mettagrid.policy.policy_env_interface import PolicyEnvInterface
+ from mettagrid.simulator import Action
+
+ from .aligner import AlignerAgentPolicyImpl
+ from .miner import HEALING_AOE_RANGE, MinerAgentPolicyImpl
+ from .policy import DEBUG, CogsguardAgentPolicyImpl, CogsguardMultiRoleImpl, CogsguardPolicy
+ from .scout import ScoutAgentPolicyImpl
+ from .scrambler import ScramblerAgentPolicyImpl
+ from .types import CogsguardAgentState, Role, StructureInfo, StructureType
+
+ PLAN_INTERVAL_STEPS = 25
+ PHASE_EXPLORE_END = 80
+ PHASE_CONTROL_END = 260
+ CHEST_LOW_THRESHOLD = 60
+ CONTROL_VIBES = {"scrambler", "aligner"}
+ RESOURCE_CYCLE = ["carbon", "oxygen", "germanium", "silicon"]
+
+
+ def _default_role_counts(num_agents: int) -> dict[str, int]:
+     if num_agents <= 1:
+         return {"miner": 1}
+     if num_agents == 2:
+         return {"scrambler": 1, "miner": 1}
+     if num_agents == 3:
+         return {"scrambler": 1, "miner": 1, "scout": 1}
+     if num_agents <= 7:
+         scramblers = 1
+         aligners = 1
+         scouts = 1
+     else:
+         scramblers = max(2, num_agents // 6)
+         aligners = max(2, num_agents // 6)
+         scouts = 1
+     miners = max(1, num_agents - scramblers - scouts - aligners)
+     return {
+         "scrambler": scramblers,
+         "aligner": aligners,
+         "miner": miners,
+         "scout": scouts,
+     }
+
+
+ def _normalize_counts(num_agents: int, counts: dict[str, int]) -> dict[str, int]:
+     normalized = {k: v for k, v in counts.items() if isinstance(v, int)}
+     total = sum(normalized.values())
+     if total < num_agents:
+         normalized["miner"] = normalized.get("miner", 0) + (num_agents - total)
+     elif total > num_agents:
+         overflow = total - num_agents
+         miners = normalized.get("miner", 0)
+         normalized["miner"] = max(0, miners - overflow)
+     return normalized
+
+
+ def _build_role_plan(num_agents: int, counts: dict[str, int]) -> list[str]:
+     ordered: list[str] = []
+     for role_name in ["scrambler", "aligner", "miner", "scout"]:
+         ordered.extend([role_name] * counts.get(role_name, 0))
+     if len(ordered) < num_agents:
+         ordered.extend(["miner"] * (num_agents - len(ordered)))
+     return ordered[:num_agents]
+
+
+ @dataclass
+ class TargetedPlannerState:
+     num_agents: int
+     desired_vibes: list[str] = field(default_factory=list)
+     last_plan_step: int = 0
+     known_junctions: int = 0
+     aligned_junctions: int = 0
+     chest_resources: int = 0
+     junction_map: dict[tuple[int, int], Optional[str]] = field(default_factory=dict)
+     extractor_map: dict[tuple[int, int], Optional[str]] = field(default_factory=dict)
+     assigned_junctions: dict[int, tuple[int, int]] = field(default_factory=dict)
+     assigned_extractors: dict[int, tuple[int, int]] = field(default_factory=dict)
+
+     def update_from_agent(self, s: CogsguardAgentState) -> None:
+         junctions = s.get_structures_by_type(StructureType.CHARGER)
+         aligned = [c for c in junctions if c.alignment == "cogs"]
+         self.known_junctions = max(self.known_junctions, len(junctions))
+         self.aligned_junctions = max(self.aligned_junctions, len(aligned))
+         for junction in junctions:
+             self.junction_map[junction.position] = junction.alignment
+
+         for extractor in s.get_usable_extractors():
+             self.extractor_map[extractor.position] = extractor.resource_type
+
+         chest_resources = 0
+         for struct in s.get_structures_by_type(StructureType.CHEST):
+             chest_resources = max(chest_resources, struct.inventory_amount)
+         if chest_resources > 0:
+             self.chest_resources = max(self.chest_resources, chest_resources)
+
+     def maybe_plan(self, step_count: int) -> None:
+         if step_count - self.last_plan_step < PLAN_INTERVAL_STEPS:
+             return
+         self.last_plan_step = step_count
+
+         counts = self._choose_counts(step_count)
+         self.desired_vibes = _build_role_plan(self.num_agents, counts)
+         self._assign_targets()
+
+         if DEBUG:
+             print(
+                 f"[TARGETED] plan@{step_count}: junctions={self.known_junctions} "
+                 f"aligned={self.aligned_junctions} chest={self.chest_resources} "
+                 f"roles={counts}"
+             )
+
+     def _choose_counts(self, step_count: int) -> dict[str, int]:
+         if step_count < PHASE_EXPLORE_END or self.known_junctions == 0:
+             scouts = 3 if self.num_agents >= 8 else 2 if self.num_agents >= 5 else 1
+             return {
+                 "scrambler": 0,
+                 "aligner": 0,
+                 "scout": scouts,
+                 "miner": max(1, self.num_agents - scouts),
+             }
+
+         if 0 < self.chest_resources < CHEST_LOW_THRESHOLD:
+             scramblers = 1
+             aligners = 1
+             scouts = 1
+             return {
+                 "scrambler": scramblers,
+                 "aligner": aligners,
+                 "scout": scouts,
+                 "miner": max(1, self.num_agents - (scramblers + aligners + scouts)),
+             }
+
+         if step_count < PHASE_CONTROL_END and self.aligned_junctions < max(1, self.known_junctions // 2):
+             if self.num_agents >= 8:
+                 scramblers = 2
+                 aligners = 3
+             elif self.num_agents >= 6:
+                 scramblers = 1
+                 aligners = 2
+             else:
+                 scramblers = 1
+                 aligners = 1
+             return {
+                 "scrambler": scramblers,
+                 "aligner": aligners,
+                 "scout": 1,
+                 "miner": max(1, self.num_agents - (scramblers + aligners + 1)),
+             }
+
+         return {
+             "scrambler": 1,
+             "aligner": 2 if self.num_agents >= 6 else 1,
+             "scout": 1,
+             "miner": max(1, self.num_agents - (2 if self.num_agents >= 6 else 1) - 2),
+         }
+
+     def _assign_targets(self) -> None:
+         junctions = [pos for pos, alignment in self.junction_map.items() if alignment != "cogs"]
+         junctions.sort()
+         extractors_by_resource: dict[str, list[tuple[int, int]]] = {res: [] for res in RESOURCE_CYCLE}
+         for pos, resource in self.extractor_map.items():
+             if resource in extractors_by_resource:
+                 extractors_by_resource[resource].append(pos)
+         for positions in extractors_by_resource.values():
+             positions.sort()
+         all_extractors = sorted(self.extractor_map.keys())
+
+         self.assigned_junctions.clear()
+         self.assigned_extractors.clear()
+         if not self.desired_vibes:
+             return
+
+         junction_index = 0
+         extractor_index = 0
+         for agent_id, vibe in enumerate(self.desired_vibes):
+             if vibe in CONTROL_VIBES and junctions:
+                 self.assigned_junctions[agent_id] = junctions[junction_index]
+                 junction_index = (junction_index + 1) % len(junctions)
+             elif vibe == "miner" and self.extractor_map:
+                 preferred = RESOURCE_CYCLE[agent_id % len(RESOURCE_CYCLE)]
+                 preferred_list = extractors_by_resource.get(preferred, [])
+                 if preferred_list:
+                     self.assigned_extractors[agent_id] = preferred_list[extractor_index % len(preferred_list)]
+                 elif all_extractors:
+                     self.assigned_extractors[agent_id] = all_extractors[extractor_index % len(all_extractors)]
+                 extractor_index += 1
+
+
+ class TargetedMultiRoleImpl(CogsguardMultiRoleImpl):
+     def __init__(
+         self,
+         policy_env_info: PolicyEnvInterface,
+         agent_id: int,
+         initial_target_vibe: Optional[str],
+         shared_state: TargetedPlannerState,
+     ):
+         super().__init__(policy_env_info, agent_id, initial_target_vibe=initial_target_vibe)
+         self._shared_state = shared_state
+
+     def _execute_phase(self, s: CogsguardAgentState) -> Action:
+         self._shared_state.update_from_agent(s)
+         if s.agent_id == 0:
+             self._shared_state.maybe_plan(s.step_count)
+
+         if self._shared_state.desired_vibes:
+             desired = self._shared_state.desired_vibes[s.agent_id]
+             if desired != s.current_vibe:
+                 return change_vibe_action(desired, action_names=self._action_names)
+
+         return super()._execute_phase(s)
+
+     def execute_role(self, s: CogsguardAgentState) -> Action:
+         target = self._shared_state.assigned_junctions.get(s.agent_id)
+         if target and s.current_vibe in CONTROL_VIBES and s.has_gear() and s.heart >= 1:
+             struct = s.get_structure_at(target)
+             if struct and struct.alignment != "cogs":
+                 if abs(target[0] - s.row) + abs(target[1] - s.col) > 1:
+                     return self._move_towards(s, target, reach_adjacent=True)
+                 return self._use_object_at(s, target)
+         return super().execute_role(s)
+
+     def _get_role_impl(self, role: Role) -> CogsguardAgentPolicyImpl:
+         if role not in self._role_impls:
+             impl_class = {
+                 Role.MINER: TargetedMinerAgentPolicyImpl,
+                 Role.SCOUT: ScoutAgentPolicyImpl,
+                 Role.ALIGNER: TargetedAlignerAgentPolicyImpl,
+                 Role.SCRAMBLER: TargetedScramblerAgentPolicyImpl,
+             }[role]
+             if role == Role.MINER:
+                 self._role_impls[role] = impl_class(self._policy_env_info, self._agent_id, role, self._shared_state)
+             else:
+                 self._role_impls[role] = impl_class(self._policy_env_info, self._agent_id, role)
+         return self._role_impls[role]
+
+
+ class TargetedScramblerAgentPolicyImpl(ScramblerAgentPolicyImpl):
+     def _find_best_target(self, s: CogsguardAgentState) -> Optional[tuple[int, int]]:
+         junctions = s.get_structures_by_type(StructureType.CHARGER)
+         cooldown = 20 if len(junctions) <= 4 else 50
+
+         enemy_junctions: list[tuple[int, tuple[int, int]]] = []
+         neutral_junctions: list[tuple[int, tuple[int, int]]] = []
+         any_junctions: list[tuple[int, tuple[int, int]]] = []
+
+         for junction in junctions:
+             pos = junction.position
+             dist = abs(pos[0] - s.row) + abs(pos[1] - s.col)
+
+             last_worked = s.worked_junctions.get(pos, 0)
+             if last_worked > 0 and s.step_count - last_worked < cooldown:
+                 continue
+
+             if junction.alignment == "cogs":
+                 continue
+
+             if junction.alignment == "clips" or junction.clipped:
+                 enemy_junctions.append((dist, pos))
+             elif junction.alignment is None or junction.alignment == "neutral":
+                 neutral_junctions.append((dist, pos))
+             else:
+                 any_junctions.append((dist, pos))
+
+         if enemy_junctions:
+             enemy_junctions.sort()
+             return enemy_junctions[0][1]
+         if neutral_junctions:
+             neutral_junctions.sort()
+             return neutral_junctions[0][1]
+         if any_junctions:
+             any_junctions.sort()
+             return any_junctions[0][1]
+
+         return super()._find_best_target(s)
+
+
+ class TargetedAlignerAgentPolicyImpl(AlignerAgentPolicyImpl):
+     def _find_best_target(self, s: CogsguardAgentState) -> Optional[tuple[int, int]]:
+         junctions = s.get_structures_by_type(StructureType.CHARGER)
+         cooldown = 20 if len(junctions) <= 4 else 50
+
+         neutral_junctions: list[tuple[int, tuple[int, int]]] = []
+         clips_junctions: list[tuple[int, tuple[int, int]]] = []
+         other_junctions: list[tuple[int, tuple[int, int]]] = []
+
+         for junction in junctions:
+             pos = junction.position
+             dist = abs(pos[0] - s.row) + abs(pos[1] - s.col)
+
+             last_worked = s.worked_junctions.get(pos, 0)
+             if last_worked > 0 and s.step_count - last_worked < cooldown:
+                 continue
+
+             if junction.alignment == "cogs":
+                 continue
+
+             if junction.alignment is None or junction.alignment == "neutral":
+                 neutral_junctions.append((dist, pos))
+             elif junction.alignment == "clips" or junction.clipped:
+                 clips_junctions.append((dist, pos))
+             else:
+                 other_junctions.append((dist, pos))
+
+         if neutral_junctions:
+             neutral_junctions.sort()
+             return neutral_junctions[0][1]
+         if clips_junctions:
+             clips_junctions.sort()
+             return clips_junctions[0][1]
+         if other_junctions:
+             other_junctions.sort()
+             return other_junctions[0][1]
+
+         return super()._find_best_target(s)
+
+
+ class TargetedMinerAgentPolicyImpl(MinerAgentPolicyImpl):
+     def __init__(
+         self,
+         policy_env_info: PolicyEnvInterface,
+         agent_id: int,
+         role: Role,
+         shared_state: TargetedPlannerState,
+     ):
+         super().__init__(policy_env_info, agent_id, role)
+         self._shared_state = shared_state
+         self._preferred_resource = RESOURCE_CYCLE[agent_id % len(RESOURCE_CYCLE)]
+
+     def _get_safe_extractor(
+         self,
+         s: CogsguardAgentState,
+         preferred_resource: str | None = None,
+     ) -> Optional[StructureInfo]:
+         target = self._shared_state.assigned_extractors.get(s.agent_id)
+         if target:
+             current = s.get_structure_at(target)
+             if current and current.is_usable_extractor():
+                 max_safe_dist = self._get_max_safe_distance(s)
+                 dist_to_ext = abs(target[0] - s.row) + abs(target[1] - s.col)
+                 nearest_depot = self._get_nearest_aligned_depot(s)
+                 if nearest_depot:
+                     dist_ext_to_depot = abs(target[0] - nearest_depot[0]) + abs(target[1] - nearest_depot[1])
+                     round_trip = dist_to_ext + max(0, dist_ext_to_depot - HEALING_AOE_RANGE)
+                 else:
+                     round_trip = dist_to_ext * 2
+                 if round_trip <= max_safe_dist:
+                     return current
+         resource = preferred_resource or self._preferred_resource
+         preferred = [ext for ext in s.get_usable_extractors() if ext.resource_type == resource]
+         if preferred:
+             nearest_depot = self._get_nearest_aligned_depot(s)
+             max_safe_dist = self._get_max_safe_distance(s)
+             candidates: list[tuple[int, int, int, StructureInfo]] = []
+             for ext in preferred:
+                 dist_to_ext = abs(ext.position[0] - s.row) + abs(ext.position[1] - s.col)
+                 if nearest_depot:
+                     dist_ext_to_depot = abs(ext.position[0] - nearest_depot[0]) + abs(
+                         ext.position[1] - nearest_depot[1]
+                     )
+                     round_trip = dist_to_ext + max(0, dist_ext_to_depot - HEALING_AOE_RANGE)
+                 else:
+                     round_trip = dist_to_ext * 2
+                 if round_trip <= max_safe_dist:
+                     dist_ext_to_depot = (
+                         abs(ext.position[0] - nearest_depot[0]) + abs(ext.position[1] - nearest_depot[1])
+                         if nearest_depot
+                         else 100
+                     )
+                     candidates.append((ext.inventory_amount, dist_ext_to_depot, dist_to_ext, ext))
+             if candidates:
+                 candidates.sort(key=lambda x: (-x[0], x[1], x[2]))
+                 return candidates[0][3]
+
+         return super()._get_safe_extractor(s, preferred_resource=preferred_resource)
+
+
+ class CogsguardTargetedAgent(CogsguardPolicy):
+     """CoGsGuard policy with coordinated role and target assignment."""
+
+     short_names = ["cogsguard_targeted"]
+
+     def __init__(
+         self,
+         policy_env_info: PolicyEnvInterface,
+         device: str = "cpu",
+         **vibe_counts: Any,
+     ):
+         has_explicit_counts = any(isinstance(v, int) for v in vibe_counts.values())
+         if has_explicit_counts:
+             counts = _normalize_counts(policy_env_info.num_agents, vibe_counts)
+         else:
+             counts = _default_role_counts(policy_env_info.num_agents)
+         super().__init__(policy_env_info, device=device, **counts)
+         self._shared_state = TargetedPlannerState(policy_env_info.num_agents)
+         self._shared_state.desired_vibes = _build_role_plan(policy_env_info.num_agents, counts)
+
+     def agent_policy(self, agent_id: int) -> StatefulAgentPolicy[CogsguardAgentState]:
+         if agent_id not in self._agent_policies:
+             target_vibe = None
+             if agent_id < len(self._initial_vibes):
+                 target_vibe = self._initial_vibes[agent_id]
+
+             impl = TargetedMultiRoleImpl(
+                 self._policy_env_info,
+                 agent_id,
+                 initial_target_vibe=target_vibe,
+                 shared_state=self._shared_state,
+             )
+             self._agent_policies[agent_id] = StatefulAgentPolicy(impl, self._policy_env_info, agent_id=agent_id)
+
+         return self._agent_policies[agent_id]
cogames_agents/policy/scripted_agent/cogsguard/teacher.py
@@ -0,0 +1,224 @@
+ from __future__ import annotations
+
+ from typing import Optional, Sequence
+
+ import numpy as np
+
+ from cogames_agents.policy.nim_agents.agents import CogsguardAgentsMultiPolicy
+ from cogames_agents.policy.scripted_agent.cogsguard.types import Role as CogsguardRole
+ from cogames_agents.policy.scripted_agent.common.roles import ROLE_VIBES
+ from mettagrid.policy.policy import AgentPolicy, MultiAgentPolicy
+ from mettagrid.policy.policy_env_interface import PolicyEnvInterface
+ from mettagrid.simulator import Action, AgentObservation
+
+ DEFAULT_ROLE_VIBES = tuple(ROLE_VIBES)
+
+
+ class CogsguardTeacherPolicy(MultiAgentPolicy):
+     """Teacher wrapper that forces an initial vibe, then delegates to the Nim policy."""
+
+     short_names = ["teacher"]
+
+     def __init__(
+         self,
+         policy_env_info: PolicyEnvInterface,
+         device: str = "cpu",
+         role_vibes: Optional[Sequence[str | CogsguardRole] | str] = None,
+     ) -> None:
+         super().__init__(policy_env_info, device=device)
+         self._delegate = CogsguardAgentsMultiPolicy(policy_env_info)
+         self._num_agents = policy_env_info.num_agents
+         self._action_names = list(policy_env_info.action_names)
+         self._action_name_to_index = {name: idx for idx, name in enumerate(self._action_names)}
+         self._delegate_agents = [self._delegate.agent_policy(i) for i in range(self._num_agents)]
+
+         self._episode_feature_id = self._find_feature_id("episode_completion_pct")
+         self._last_action_feature_id = self._find_feature_id("last_action")
+
+         self._role_action_ids = self._resolve_role_actions(role_vibes)
+         self._reset_episode_state()
+
+     def agent_policy(self, agent_id: int) -> AgentPolicy:
+         return _CogsguardTeacherAgentPolicy(self, agent_id)
+
+     def reset(self) -> None:
+         self._delegate.reset()
+         self._reset_episode_state()
+
+     def step_batch(self, raw_observations: np.ndarray, raw_actions: np.ndarray) -> None:
+         self._delegate.step_batch(raw_observations, raw_actions)
+         if not self._role_action_ids:
+             return
+         if raw_observations.shape[0] != self._num_agents:
+             return
+         for agent_id in range(self._num_agents):
+             episode_pct = self._extract_episode_pct_raw(raw_observations[agent_id])
+             last_action = self._extract_last_action_raw(raw_observations[agent_id])
+             forced_action = self._maybe_force_action(agent_id, episode_pct, last_action)
+             if forced_action is not None:
+                 raw_actions[agent_id] = forced_action
+
+     def _step_single(self, agent_id: int, obs: AgentObservation) -> Action:
+         base_action = self._delegate_agents[agent_id].step(obs)
+         if not self._role_action_ids:
+             return base_action
+         episode_pct = self._extract_episode_pct_obs(obs)
+         last_action = self._extract_last_action_obs(obs)
+         forced_action = self._maybe_force_action(agent_id, episode_pct, last_action)
+         if forced_action is None:
+             return base_action
+         action_name = self._action_names[forced_action]
+         return Action(name=action_name)
+
+     def _extract_episode_pct_raw(self, raw_obs: np.ndarray) -> Optional[int]:
+         if self._episode_feature_id is None:
+             return None
+         for token in raw_obs:
+             if token[0] == 255 and token[1] == 255 and token[2] == 255:
+                 break
+             if token[1] == self._episode_feature_id:
+                 return int(token[2])
+         return 0
+
+     def _extract_episode_pct_obs(self, obs: AgentObservation) -> Optional[int]:
+         if self._episode_feature_id is None:
+             return None
+         for token in obs.tokens:
+             if token.feature.name == "episode_completion_pct":
+                 return token.value
+         return 0
+
+     def _extract_last_action_raw(self, raw_obs: np.ndarray) -> Optional[int]:
+         if self._last_action_feature_id is None:
+             return None
+         for token in raw_obs:
+             if token[0] == 255 and token[1] == 255 and token[2] == 255:
+                 break
+             if token[1] == self._last_action_feature_id:
+                 return int(token[2])
+         return 0
+
+     def _extract_last_action_obs(self, obs: AgentObservation) -> Optional[int]:
+         if self._last_action_feature_id is None:
+             return None
+         for token in obs.tokens:
+             if token.feature.name == "last_action":
+                 return token.value
+         return 0
+
+     def _find_feature_id(self, feature_name: str) -> Optional[int]:
+         for feature in self.policy_env_info.obs_features:
+             if feature.name == feature_name:
+                 return feature.id
+         return None
+
+     def _resolve_role_actions(self, role_vibes: Optional[Sequence[str | CogsguardRole] | str]) -> list[int]:
+         change_vibe_actions = [name for name in self._action_names if name.startswith("change_vibe_")]
+         if len(change_vibe_actions) <= 1:
+             return []
+
+         available_vibes = [name[len("change_vibe_") :] for name in change_vibe_actions]
+         if role_vibes is None:
+             role_vibes = [vibe for vibe in DEFAULT_ROLE_VIBES if vibe in available_vibes]
+             if not role_vibes:
+                 role_vibes = [vibe for vibe in available_vibes if vibe != "default"]
+             if not role_vibes:
+                 role_vibes = available_vibes
+         else:
+             if isinstance(role_vibes, str):
+                 normalized_vibes = [vibe.strip() for vibe in role_vibes.split(",") if vibe.strip()]
+             else:
+                 normalized_vibes = [vibe.value if isinstance(vibe, CogsguardRole) else str(vibe) for vibe in role_vibes]
+             role_vibes = [vibe for vibe in normalized_vibes if vibe in available_vibes]
+             if not role_vibes:
+                 role_vibes = available_vibes
+
+         role_action_ids = []
+         for vibe_name in role_vibes:
+             action_name = f"change_vibe_{vibe_name}"
+             action_id = self._action_name_to_index.get(action_name)
+             if action_id is not None:
+                 role_action_ids.append(action_id)
+         return role_action_ids
+
+     def _reset_episode_state(self) -> None:
+         self._episode_index = [0] * self._num_agents
+         self._forced_vibe = [False] * self._num_agents
+         self._last_episode_pct = [-1] * self._num_agents
+         self._step_in_episode = [0] * self._num_agents
+         self._last_action_value: list[Optional[int]] = [None] * self._num_agents
+
+     def _maybe_force_action(
+         self,
+         agent_id: int,
+         episode_pct: Optional[int],
+         last_action: Optional[int],
+     ) -> Optional[int]:
+         self._update_episode_state(agent_id, episode_pct, last_action)
+         if self._forced_vibe[agent_id] or self._step_in_episode[agent_id] != 0:
+             return None
+         self._forced_vibe[agent_id] = True
+         role_index = (self._episode_index[agent_id] + agent_id) % len(self._role_action_ids)
+         return self._role_action_ids[role_index]
+
+     def _update_episode_state(
+         self,
+         agent_id: int,
+         episode_pct: Optional[int],
+         last_action: Optional[int],
+     ) -> None:
+         last_pct = self._last_episode_pct[agent_id]
+         if episode_pct is None:
+             last_action_seen = self._last_action_value[agent_id]
+             if (
+                 last_action is not None
+                 and last_action_seen is not None
+                 and last_action == 0
+                 and last_action_seen != 0
+                 and self._step_in_episode[agent_id] > 0
+             ):
+                 self._episode_index[agent_id] += 1
+                 self._step_in_episode[agent_id] = 0
+                 self._forced_vibe[agent_id] = False
+                 self._last_episode_pct[agent_id] = 0
+                 self._last_action_value[agent_id] = last_action
+                 return
+
+             if last_pct == -1:
+                 self._step_in_episode[agent_id] = 0
+             else:
+                 self._step_in_episode[agent_id] += 1
+             self._last_episode_pct[agent_id] = 0
+             if last_action is not None:
+                 self._last_action_value[agent_id] = last_action
+             return
+
+         new_episode = False
+         if last_pct == -1:
+             new_episode = True
+         elif episode_pct < last_pct:
+             new_episode = True
+         elif last_pct > 0 and episode_pct == 0:
+             new_episode = True
+
+         if new_episode:
+             if last_pct != -1:
+                 self._episode_index[agent_id] += 1
+             self._step_in_episode[agent_id] = 0
+             self._forced_vibe[agent_id] = False
+         else:
+             self._step_in_episode[agent_id] += 1
+
+         self._last_episode_pct[agent_id] = episode_pct
+         if last_action is not None:
+             self._last_action_value[agent_id] = last_action
+
+
+ class _CogsguardTeacherAgentPolicy(AgentPolicy):
+     def __init__(self, parent: CogsguardTeacherPolicy, agent_id: int) -> None:
+         super().__init__(parent.policy_env_info)
+         self._parent = parent
+         self._agent_id = agent_id
+
+     def step(self, obs: AgentObservation) -> Action:
+         return self._parent._step_single(self._agent_id, obs)