cogames-agents 0.0.0.7__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cogames_agents/__init__.py +0 -0
- cogames_agents/evals/__init__.py +5 -0
- cogames_agents/evals/planky_evals.py +415 -0
- cogames_agents/policy/__init__.py +0 -0
- cogames_agents/policy/evolution/__init__.py +0 -0
- cogames_agents/policy/evolution/cogsguard/__init__.py +0 -0
- cogames_agents/policy/evolution/cogsguard/evolution.py +695 -0
- cogames_agents/policy/evolution/cogsguard/evolutionary_coordinator.py +540 -0
- cogames_agents/policy/nim_agents/__init__.py +20 -0
- cogames_agents/policy/nim_agents/agents.py +98 -0
- cogames_agents/policy/nim_agents/bindings/generated/libnim_agents.dylib +0 -0
- cogames_agents/policy/nim_agents/bindings/generated/nim_agents.py +215 -0
- cogames_agents/policy/nim_agents/cogsguard_agents.nim +555 -0
- cogames_agents/policy/nim_agents/cogsguard_align_all_agents.nim +569 -0
- cogames_agents/policy/nim_agents/common.nim +1054 -0
- cogames_agents/policy/nim_agents/install.sh +1 -0
- cogames_agents/policy/nim_agents/ladybug_agent.nim +954 -0
- cogames_agents/policy/nim_agents/nim_agents.nim +68 -0
- cogames_agents/policy/nim_agents/nim_agents.nims +14 -0
- cogames_agents/policy/nim_agents/nimby.lock +3 -0
- cogames_agents/policy/nim_agents/racecar_agents.nim +844 -0
- cogames_agents/policy/nim_agents/random_agents.nim +68 -0
- cogames_agents/policy/nim_agents/test_agents.py +53 -0
- cogames_agents/policy/nim_agents/thinky_agents.nim +677 -0
- cogames_agents/policy/nim_agents/thinky_eval.py +230 -0
- cogames_agents/policy/scripted_agent/README.md +360 -0
- cogames_agents/policy/scripted_agent/__init__.py +0 -0
- cogames_agents/policy/scripted_agent/baseline_agent.py +1031 -0
- cogames_agents/policy/scripted_agent/cogas/__init__.py +5 -0
- cogames_agents/policy/scripted_agent/cogas/context.py +68 -0
- cogames_agents/policy/scripted_agent/cogas/entity_map.py +152 -0
- cogames_agents/policy/scripted_agent/cogas/goal.py +115 -0
- cogames_agents/policy/scripted_agent/cogas/goals/__init__.py +27 -0
- cogames_agents/policy/scripted_agent/cogas/goals/aligner.py +160 -0
- cogames_agents/policy/scripted_agent/cogas/goals/gear.py +197 -0
- cogames_agents/policy/scripted_agent/cogas/goals/miner.py +441 -0
- cogames_agents/policy/scripted_agent/cogas/goals/scout.py +40 -0
- cogames_agents/policy/scripted_agent/cogas/goals/scrambler.py +174 -0
- cogames_agents/policy/scripted_agent/cogas/goals/shared.py +160 -0
- cogames_agents/policy/scripted_agent/cogas/goals/stem.py +60 -0
- cogames_agents/policy/scripted_agent/cogas/goals/survive.py +100 -0
- cogames_agents/policy/scripted_agent/cogas/navigator.py +401 -0
- cogames_agents/policy/scripted_agent/cogas/obs_parser.py +238 -0
- cogames_agents/policy/scripted_agent/cogas/policy.py +525 -0
- cogames_agents/policy/scripted_agent/cogas/trace.py +69 -0
- cogames_agents/policy/scripted_agent/cogsguard/CLAUDE.md +517 -0
- cogames_agents/policy/scripted_agent/cogsguard/README.md +252 -0
- cogames_agents/policy/scripted_agent/cogsguard/__init__.py +74 -0
- cogames_agents/policy/scripted_agent/cogsguard/aligned_junction_held_investigation.md +152 -0
- cogames_agents/policy/scripted_agent/cogsguard/aligner.py +333 -0
- cogames_agents/policy/scripted_agent/cogsguard/behavior_hooks.py +44 -0
- cogames_agents/policy/scripted_agent/cogsguard/control_agent.py +323 -0
- cogames_agents/policy/scripted_agent/cogsguard/debug_agent.py +533 -0
- cogames_agents/policy/scripted_agent/cogsguard/miner.py +589 -0
- cogames_agents/policy/scripted_agent/cogsguard/options.py +67 -0
- cogames_agents/policy/scripted_agent/cogsguard/parity_metrics.py +36 -0
- cogames_agents/policy/scripted_agent/cogsguard/policy.py +1967 -0
- cogames_agents/policy/scripted_agent/cogsguard/prereq_trace.py +33 -0
- cogames_agents/policy/scripted_agent/cogsguard/role_trace.py +50 -0
- cogames_agents/policy/scripted_agent/cogsguard/roles.py +31 -0
- cogames_agents/policy/scripted_agent/cogsguard/rollout_trace.py +40 -0
- cogames_agents/policy/scripted_agent/cogsguard/scout.py +69 -0
- cogames_agents/policy/scripted_agent/cogsguard/scrambler.py +350 -0
- cogames_agents/policy/scripted_agent/cogsguard/targeted_agent.py +418 -0
- cogames_agents/policy/scripted_agent/cogsguard/teacher.py +224 -0
- cogames_agents/policy/scripted_agent/cogsguard/types.py +381 -0
- cogames_agents/policy/scripted_agent/cogsguard/v2_agent.py +49 -0
- cogames_agents/policy/scripted_agent/common/__init__.py +0 -0
- cogames_agents/policy/scripted_agent/common/geometry.py +24 -0
- cogames_agents/policy/scripted_agent/common/roles.py +34 -0
- cogames_agents/policy/scripted_agent/common/tag_utils.py +48 -0
- cogames_agents/policy/scripted_agent/demo_policy.py +242 -0
- cogames_agents/policy/scripted_agent/pathfinding.py +126 -0
- cogames_agents/policy/scripted_agent/pinky/DESIGN.md +317 -0
- cogames_agents/policy/scripted_agent/pinky/__init__.py +5 -0
- cogames_agents/policy/scripted_agent/pinky/behaviors/__init__.py +17 -0
- cogames_agents/policy/scripted_agent/pinky/behaviors/aligner.py +400 -0
- cogames_agents/policy/scripted_agent/pinky/behaviors/base.py +119 -0
- cogames_agents/policy/scripted_agent/pinky/behaviors/miner.py +632 -0
- cogames_agents/policy/scripted_agent/pinky/behaviors/scout.py +138 -0
- cogames_agents/policy/scripted_agent/pinky/behaviors/scrambler.py +433 -0
- cogames_agents/policy/scripted_agent/pinky/policy.py +570 -0
- cogames_agents/policy/scripted_agent/pinky/services/__init__.py +7 -0
- cogames_agents/policy/scripted_agent/pinky/services/map_tracker.py +808 -0
- cogames_agents/policy/scripted_agent/pinky/services/navigator.py +864 -0
- cogames_agents/policy/scripted_agent/pinky/services/safety.py +189 -0
- cogames_agents/policy/scripted_agent/pinky/state.py +299 -0
- cogames_agents/policy/scripted_agent/pinky/types.py +138 -0
- cogames_agents/policy/scripted_agent/planky/CLAUDE.md +124 -0
- cogames_agents/policy/scripted_agent/planky/IMPROVEMENTS.md +160 -0
- cogames_agents/policy/scripted_agent/planky/NOTES.md +153 -0
- cogames_agents/policy/scripted_agent/planky/PLAN.md +254 -0
- cogames_agents/policy/scripted_agent/planky/README.md +214 -0
- cogames_agents/policy/scripted_agent/planky/STRATEGY.md +100 -0
- cogames_agents/policy/scripted_agent/planky/__init__.py +5 -0
- cogames_agents/policy/scripted_agent/planky/context.py +68 -0
- cogames_agents/policy/scripted_agent/planky/entity_map.py +152 -0
- cogames_agents/policy/scripted_agent/planky/goal.py +107 -0
- cogames_agents/policy/scripted_agent/planky/goals/__init__.py +27 -0
- cogames_agents/policy/scripted_agent/planky/goals/aligner.py +168 -0
- cogames_agents/policy/scripted_agent/planky/goals/gear.py +179 -0
- cogames_agents/policy/scripted_agent/planky/goals/miner.py +416 -0
- cogames_agents/policy/scripted_agent/planky/goals/scout.py +40 -0
- cogames_agents/policy/scripted_agent/planky/goals/scrambler.py +174 -0
- cogames_agents/policy/scripted_agent/planky/goals/shared.py +160 -0
- cogames_agents/policy/scripted_agent/planky/goals/stem.py +49 -0
- cogames_agents/policy/scripted_agent/planky/goals/survive.py +96 -0
- cogames_agents/policy/scripted_agent/planky/navigator.py +388 -0
- cogames_agents/policy/scripted_agent/planky/obs_parser.py +238 -0
- cogames_agents/policy/scripted_agent/planky/policy.py +485 -0
- cogames_agents/policy/scripted_agent/planky/tests/__init__.py +0 -0
- cogames_agents/policy/scripted_agent/planky/tests/conftest.py +66 -0
- cogames_agents/policy/scripted_agent/planky/tests/helpers.py +152 -0
- cogames_agents/policy/scripted_agent/planky/tests/test_aligner.py +24 -0
- cogames_agents/policy/scripted_agent/planky/tests/test_miner.py +30 -0
- cogames_agents/policy/scripted_agent/planky/tests/test_scout.py +15 -0
- cogames_agents/policy/scripted_agent/planky/tests/test_scrambler.py +29 -0
- cogames_agents/policy/scripted_agent/planky/tests/test_stem.py +36 -0
- cogames_agents/policy/scripted_agent/planky/trace.py +69 -0
- cogames_agents/policy/scripted_agent/types.py +239 -0
- cogames_agents/policy/scripted_agent/unclipping_agent.py +461 -0
- cogames_agents/policy/scripted_agent/utils.py +381 -0
- cogames_agents/policy/scripted_registry.py +80 -0
- cogames_agents/py.typed +0 -0
- cogames_agents-0.0.0.7.dist-info/METADATA +98 -0
- cogames_agents-0.0.0.7.dist-info/RECORD +128 -0
- cogames_agents-0.0.0.7.dist-info/WHEEL +6 -0
- cogames_agents-0.0.0.7.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,401 @@
|
|
|
1
|
+
"""A* navigator for Cogas policy."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import heapq
|
|
6
|
+
import random
|
|
7
|
+
from typing import TYPE_CHECKING, Optional
|
|
8
|
+
|
|
9
|
+
from mettagrid.simulator import Action
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from .entity_map import EntityMap
|
|
13
|
+
|
|
14
|
+
# (row, col) offset produced by each cardinal move; rows grow southward and
# columns grow eastward.
MOVE_DELTAS: dict[str, tuple[int, int]] = dict(
    north=(-1, 0),
    south=(1, 0),
    east=(0, 1),
    west=(0, -1),
)

# Cardinal direction names in a fixed order (the key order of MOVE_DELTAS).
DIRECTIONS = list(MOVE_DELTAS)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class Navigator:
    """A* pathfinding over the entity map.

    State kept between calls:

    * a cached path toward the most recent target, reused while it is still
      consistent with the agent's position and the map;
    * a short history of visited positions, used to detect standing still or
      oscillating and to break out with a random step.
    """

    # How many recent positions are remembered for stuck detection.
    _HISTORY_LIMIT = 30
    # Upper bound on heap pops per A* search, so a single search stays cheap on
    # large or mostly unknown maps.
    _MAX_ASTAR_ITERATIONS = 5000
    # Order in which the frontier BFS expands neighbours, keyed by direction
    # bias; the biased direction is always tried first.
    _FRONTIER_DELTAS: dict[str, tuple[tuple[int, int], ...]] = {
        "north": ((-1, 0), (0, -1), (0, 1), (1, 0)),
        "south": ((1, 0), (0, -1), (0, 1), (-1, 0)),
        "east": ((0, 1), (-1, 0), (1, 0), (0, -1)),
        "west": ((0, -1), (-1, 0), (1, 0), (0, 1)),
    }
    # Expansion order used when no (or an unrecognised) bias is given.
    _DEFAULT_FRONTIER_DELTAS: tuple[tuple[int, int], ...] = ((-1, 0), (1, 0), (0, -1), (0, 1))

    def __init__(self) -> None:
        self._cached_path: Optional[list[tuple[int, int]]] = None
        self._cached_target: Optional[tuple[int, int]] = None
        self._cached_reach_adjacent: bool = False
        self._position_history: list[tuple[int, int]] = []

    def get_action(
        self,
        current: tuple[int, int],
        target: tuple[int, int],
        map: EntityMap,
        reach_adjacent: bool = False,
    ) -> Action:
        """Navigate from current to target using A*.

        Args:
            current: Current position
            target: Target position
            map: Entity map for pathfinding
            reach_adjacent: If True, stop adjacent to target

        Returns:
            A ``move_*`` action. This method never returns a noop: when already
            at (or adjacent to) the goal it bumps instead.
        """
        unstick = self._record_position_and_unstick(current, map)
        if unstick is not None:
            return unstick
        return self._navigate(current, target, map, reach_adjacent)

    def explore(
        self,
        current: tuple[int, int],
        map: EntityMap,
        direction_bias: Optional[str] = None,
    ) -> Action:
        """Navigate toward unexplored frontier cells.

        Args:
            current: Current position
            map: Entity map for pathfinding
            direction_bias: Optional "north"/"south"/"east"/"west" preference
                for which frontier the BFS reaches first.

        Returns:
            A ``move_*`` action toward the nearest frontier, or a random move
            when no frontier is found.
        """
        unstick = self._record_position_and_unstick(current, map)
        if unstick is not None:
            return unstick

        frontier = self._find_frontier(current, map, direction_bias)
        if frontier:
            # Go straight to the navigation core: the position was already
            # recorded above, and calling get_action() would record it a second
            # time, skewing stuck detection.
            return self._navigate(current, frontier, map)

        # No frontier — random walk
        return self._random_move(current, map)

    def _record_position_and_unstick(self, current: tuple[int, int], map: EntityMap) -> Optional[Action]:
        """Append ``current`` to the bounded history; return a recovery move if stuck, else None."""
        history = self._position_history
        history.append(current)
        if len(history) > self._HISTORY_LIMIT:
            history.pop(0)
        if self._is_stuck():
            return self._break_stuck(current, map)
        return None

    def _navigate(
        self,
        current: tuple[int, int],
        target: tuple[int, int],
        map: EntityMap,
        reach_adjacent: bool = False,
    ) -> Action:
        """Pick one step toward ``target`` (no history bookkeeping)."""
        if current == target and not reach_adjacent:
            # Already at target - bump in a random direction to stay active
            return self._random_move(current, map)

        if reach_adjacent and _manhattan(current, target) == 1:
            # Already adjacent - bump toward target instead of nooping
            return _move_action(current, target)

        path = self._get_path(current, target, map, reach_adjacent)
        if not path:
            # No path found — try exploring toward target
            return self._move_toward_greedy(current, target, map)

        next_pos = path[0]
        if map.has_agent(next_pos):
            # Another agent occupies the next step; replan next time.
            self._cached_path = None
            sidestep = self._find_sidestep(current, next_pos, target, map)
            if sidestep:
                return _move_action(current, sidestep)
            # Don't wait (noop) - try random move to break congestion
            return self._random_move(current, map)

        # Advance path assuming this move succeeds; _get_path re-validates it.
        self._cached_path = path[1:] if len(path) > 1 else None
        return _move_action(current, next_pos)

    def _get_path(
        self,
        start: tuple[int, int],
        target: tuple[int, int],
        map: EntityMap,
        reach_adjacent: bool,
    ) -> Optional[list[tuple[int, int]]]:
        """Get cached path or compute new one.

        The cached path is reused only if it was planned for the same target
        and mode and is still valid from ``start``. Otherwise A* runs, first
        over known terrain and then allowing unknown cells.

        Returns:
            The path excluding ``start``; ``[]`` if A* found nothing; ``None``
            if the target has no usable goal cells.
        """
        if (
            self._cached_target == target
            and self._cached_reach_adjacent == reach_adjacent
            and self._cached_path_is_valid(start, map)
        ):
            return self._cached_path

        goal_cells = self._compute_goals(target, map, reach_adjacent)
        if not goal_cells:
            return None

        # Try known terrain first, then allow unknown cells.
        path = self._astar(start, goal_cells, map, allow_unknown=False)
        if not path:
            path = self._astar(start, goal_cells, map, allow_unknown=True)

        self._cached_path = path.copy() if path else None
        self._cached_target = target
        self._cached_reach_adjacent = reach_adjacent
        return path

    def _cached_path_is_valid(self, current: tuple[int, int], map: EntityMap) -> bool:
        """Check whether the cached path can still be followed from ``current``."""
        path = self._cached_path
        if not path:
            return False
        # The first step must be adjacent. If the last move failed or the agent
        # was displaced, the cached path no longer starts here.
        if _manhattan(current, path[0]) != 1:
            return False
        # The final cell may be a goal that A* entered without it being
        # traversable (e.g. a structure targeted directly), so only require it
        # to be free of agents.
        if map.has_agent(path[-1]):
            return False
        # Intermediate cells must still be walkable: agents may have moved onto
        # them, and newly observed walls or structures invalidate the plan.
        return all(self._is_traversable(pos, map, allow_unknown=True) for pos in path[:-1])

    def _compute_goals(
        self,
        target: tuple[int, int],
        map: EntityMap,
        reach_adjacent: bool,
    ) -> list[tuple[int, int]]:
        """Return the cells A* may finish on: the target itself, or its walkable orthogonal neighbours."""
        if not reach_adjacent:
            return [target]
        neighbours = ((target[0] + dr, target[1] + dc) for dr, dc in MOVE_DELTAS.values())
        return [pos for pos in neighbours if self._is_traversable(pos, map, allow_unknown=True)]

    def _astar(
        self,
        start: tuple[int, int],
        goals: list[tuple[int, int]],
        map: EntityMap,
        allow_unknown: bool,
    ) -> list[tuple[int, int]]:
        """A* pathfinding with iteration limit to prevent hanging.

        Goal cells may be entered even when not traversable. The heuristic is
        the Manhattan distance to the nearest goal. With unit step costs this
        is consistent, so each cell is expanded at most once.

        Returns:
            Path from ``start`` (exclusive) to the first goal reached
            (inclusive), or ``[]`` if no goal is reached within the budget.
        """
        if not goals:
            return []
        goal_set = set(goals)

        def h(pos: tuple[int, int]) -> int:
            return min(_manhattan(pos, g) for g in goals)

        tie = 0
        open_set: list[tuple[int, int, tuple[int, int]]] = [(h(start), tie, start)]
        came_from: dict[tuple[int, int], Optional[tuple[int, int]]] = {start: None}
        g_score: dict[tuple[int, int], int] = {start: 0}
        closed: set[tuple[int, int]] = set()
        iterations = 0

        while open_set and iterations < self._MAX_ASTAR_ITERATIONS:
            iterations += 1
            _, _, current = heapq.heappop(open_set)

            if current in goal_set:
                return self._reconstruct(came_from, current)
            if current in closed:
                # Stale heap entry: this cell was already expanded with its best g.
                continue
            closed.add(current)

            current_g = g_score[current]
            for dr, dc in MOVE_DELTAS.values():
                neighbor = (current[0] + dr, current[1] + dc)
                if neighbor not in goal_set and not self._is_traversable(neighbor, map, allow_unknown):
                    continue
                tentative_g = current_g + 1
                if tentative_g < g_score.get(neighbor, float("inf")):
                    came_from[neighbor] = current
                    g_score[neighbor] = tentative_g
                    tie += 1
                    heapq.heappush(open_set, (tentative_g + h(neighbor), tie, neighbor))

        return []

    def _reconstruct(
        self,
        came_from: dict[tuple[int, int], Optional[tuple[int, int]]],
        current: tuple[int, int],
    ) -> list[tuple[int, int]]:
        """Walk ``came_from`` back from ``current`` and return the path without its start cell."""
        path = []
        while came_from[current] is not None:
            path.append(current)
            prev = came_from[current]
            assert prev is not None
            current = prev
        path.reverse()
        return path

    def _is_traversable(
        self,
        pos: tuple[int, int],
        map: EntityMap,
        allow_unknown: bool = False,
    ) -> bool:
        """Check if a cell can be walked through.

        Walls, structures and agent-occupied cells are never traversable.
        Explored cells are traversable when they hold no entity (or only an
        agent entity). Unexplored cells are traversable only if ``allow_unknown``.
        """
        if map.is_wall(pos) or map.is_structure(pos):
            return False
        if map.has_agent(pos):
            return False
        if pos in map.explored:
            return pos not in map.entities or map.entities[pos].type == "agent"
        # Unknown cell
        return allow_unknown

    def _find_frontier(
        self,
        from_pos: tuple[int, int],
        map: EntityMap,
        direction_bias: Optional[str] = None,
    ) -> Optional[tuple[int, int]]:
        """BFS to find nearest unexplored cell adjacent to explored free cell.

        The search moves only through explored free cells and stops expanding
        cells more than 50 steps from ``from_pos``.
        """
        from collections import deque

        deltas = self._FRONTIER_DELTAS.get(direction_bias, self._DEFAULT_FRONTIER_DELTAS)

        visited: set[tuple[int, int]] = {from_pos}
        queue: deque[tuple[int, int, int]] = deque([(from_pos[0], from_pos[1], 0)])

        while queue:
            r, c, dist = queue.popleft()
            if dist > 50:
                continue

            for dr, dc in deltas:
                pos = (r + dr, c + dc)
                if pos in visited:
                    continue
                visited.add(pos)

                if pos not in map.explored:
                    # Frontier cell: unexplored but touching an explored free cell.
                    for dr2, dc2 in deltas:
                        adj = (pos[0] + dr2, pos[1] + dc2)
                        if adj in map.explored and map.is_free(adj):
                            return pos
                    continue

                if map.is_free(pos):
                    queue.append((pos[0], pos[1], dist + 1))

        return None

    def _find_sidestep(
        self,
        current: tuple[int, int],
        blocked: tuple[int, int],
        target: tuple[int, int],
        map: EntityMap,
    ) -> Optional[tuple[int, int]]:
        """Find sidestep around blocking agent.

        Among traversable neighbours other than ``blocked``, prefer the one that
        brings the agent closest to ``target`` (ties broken by position).
        """
        current_dist = _manhattan(current, target)
        candidates = []
        for d in DIRECTIONS:
            dr, dc = MOVE_DELTAS[d]
            pos = (current[0] + dr, current[1] + dc)
            if pos == blocked:
                continue
            if not self._is_traversable(pos, map, allow_unknown=True):
                continue
            candidates.append((_manhattan(pos, target) - current_dist, pos))

        if not candidates:
            return None
        candidates.sort()
        # NOTE: a single step changes Manhattan distance by exactly +/-1, so this
        # threshold currently accepts every candidate.
        if candidates[0][0] <= 2:
            return candidates[0][1]
        return None

    def _is_stuck(self) -> bool:
        """Detect standing still / oscillation from the recorded position history."""
        history = self._position_history
        if len(history) < 6:
            return False
        # At most two distinct cells across the last six recorded positions.
        if len(set(history[-6:])) <= 2:
            return True
        # Current cell already seen twice before the last ten recorded positions.
        if len(history) >= 20 and history[:-10].count(history[-1]) >= 2:
            return True
        return False

    def _break_stuck(self, current: tuple[int, int], map: EntityMap) -> Optional[Action]:
        """Forget the cached plan and history, then take a random step."""
        self._cached_path = None
        self._cached_target = None
        self._position_history.clear()
        return self._random_move(current, map)

    def _random_move(self, current: tuple[int, int], map: EntityMap) -> Action:
        """Return a random move, preferring explored free cells, then occupied ones, then unknown ones."""
        dirs = list(DIRECTIONS)
        random.shuffle(dirs)
        # Try explored free cells first (excluding agent positions)
        for d in dirs:
            dr, dc = MOVE_DELTAS[d]
            pos = (current[0] + dr, current[1] + dc)
            if pos in map.explored and not map.is_wall(pos) and not map.is_structure(pos) and not map.has_agent(pos):
                return Action(name=f"move_{d}")
        # Try explored cells even with agents (will fail but better than noop)
        for d in dirs:
            dr, dc = MOVE_DELTAS[d]
            pos = (current[0] + dr, current[1] + dc)
            if pos in map.explored and not map.is_wall(pos) and not map.is_structure(pos):
                return Action(name=f"move_{d}")
        # Try unknown cells
        for d in dirs:
            dr, dc = MOVE_DELTAS[d]
            pos = (current[0] + dr, current[1] + dc)
            if not map.is_wall(pos):
                return Action(name=f"move_{d}")
        # Absolute last resort: try any direction (will likely fail but attempt something)
        return Action(name=f"move_{dirs[0]}")

    def _move_toward_greedy(self, current: tuple[int, int], target: tuple[int, int], map: EntityMap) -> Action:
        """Move greedily toward target without pathfinding.

        Tries the axis with the larger displacement first, then the other axis,
        and falls back to a random move if both are blocked.
        """
        dr = target[0] - current[0]
        dc = target[1] - current[1]

        if abs(dr) >= abs(dc):
            primary = "south" if dr > 0 else "north"
            secondary = "east" if dc > 0 else "west"
        else:
            primary = "east" if dc > 0 else "west"
            secondary = "south" if dr > 0 else "north"

        for d in (primary, secondary):
            ddr, ddc = MOVE_DELTAS[d]
            pos = (current[0] + ddr, current[1] + ddc)
            if not map.is_wall(pos) and not map.is_structure(pos) and not map.has_agent(pos):
                return Action(name=f"move_{d}")

        return self._random_move(current, map)
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
def _manhattan(a: tuple[int, int], b: tuple[int, int]) -> int:
|
|
385
|
+
return abs(a[0] - b[0]) + abs(a[1] - b[1])
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
def _move_action(current: tuple[int, int], target: tuple[int, int]) -> Action:
    """Return move action from current to adjacent target.

    If ``target`` is not exactly one orthogonal step away (including when it
    equals ``current``), a uniformly random cardinal move is returned instead
    of a noop.
    """
    step_names = {
        (-1, 0): "move_north",
        (1, 0): "move_south",
        (0, 1): "move_east",
        (0, -1): "move_west",
    }
    name = step_names.get((target[0] - current[0], target[1] - current[1]))
    if name is None:
        # Already at target - pick a random direction instead of nooping
        name = f"move_{random.choice(['north', 'south', 'east', 'west'])}"
    return Action(name=name)
|
|
@@ -0,0 +1,238 @@
|
|
|
1
|
+
"""Observation parser for Cogas policy.
|
|
2
|
+
|
|
3
|
+
Converts raw observation tokens into StateSnapshot and visible entities.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from typing import TYPE_CHECKING
|
|
9
|
+
|
|
10
|
+
from .context import StateSnapshot
|
|
11
|
+
from .entity_map import Entity
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from mettagrid.policy.policy_env_interface import PolicyEnvInterface
|
|
15
|
+
from mettagrid.simulator.interface import AgentObservation
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ObsParser:
    """Parses observation tokens into state snapshot and visible entities."""

    # Per-cell features copied verbatim into an entity's properties.
    _ENTITY_PROP_FEATURES = ("cooldown_remaining", "clipped", "remaining_uses", "collective")

    def __init__(self, policy_env_info: PolicyEnvInterface) -> None:
        self._obs_hr = policy_env_info.obs_height // 2
        self._obs_wr = policy_env_info.obs_width // 2
        self._tag_names = policy_env_info.tag_id_to_name

        # Derive vibe names from action names. A vibe id is used as an index
        # into this list. NOTE(review): assumes action order matches vibe ids.
        prefix = "change_vibe_"
        self._vibe_names: list[str] = [
            action_name[len(prefix) :] for action_name in policy_env_info.action_names if action_name.startswith(prefix)
        ]

        # Collective name mapping
        self._collective_names = ["clips", "cogs"]  # Alphabetical
        self._cogs_collective_id = 1  # "cogs" is index 1 alphabetically
        self._clips_collective_id = 0  # "clips" is index 0

    def parse(
        self,
        obs: AgentObservation,
        step: int,
        spawn_pos: tuple[int, int],
    ) -> tuple[StateSnapshot, dict[tuple[int, int], Entity]]:
        """Parse observation into state snapshot and visible entities.

        Args:
            obs: Raw observation
            step: Current tick
            spawn_pos: Agent's spawn position for offset calculation

        Returns:
            (state_snapshot, visible_entities_dict)
        """
        state = StateSnapshot()

        inv, vibe_id, position = self._read_center_cell(obs, spawn_pos)
        state.position = position

        state.hp = inv.get("hp", 100)
        state.energy = inv.get("energy", 100)
        state.carbon = inv.get("carbon", 0)
        state.oxygen = inv.get("oxygen", 0)
        state.germanium = inv.get("germanium", 0)
        state.silicon = inv.get("silicon", 0)
        state.heart = inv.get("heart", 0)
        state.influence = inv.get("influence", 0)
        state.miner_gear = inv.get("miner", 0) > 0
        state.scout_gear = inv.get("scout", 0) > 0
        state.aligner_gear = inv.get("aligner", 0) > 0
        state.scrambler_gear = inv.get("scrambler", 0) > 0
        state.vibe = self._get_vibe_name(vibe_id)

        # Collective tokens appear as "inv:collective:<resource>" features on the
        # center cell, parsed into keys like "collective:carbon".
        state.collective_carbon = inv.get("collective:carbon", 0)
        state.collective_oxygen = inv.get("collective:oxygen", 0)
        state.collective_germanium = inv.get("collective:germanium", 0)
        state.collective_silicon = inv.get("collective:silicon", 0)
        state.collective_heart = inv.get("collective:heart", 0)
        state.collective_influence = inv.get("collective:influence", 0)

        visible_entities = self._parse_entities(obs, state.position, step)
        return state, visible_entities

    @staticmethod
    def _accumulate_inventory(inv: dict[str, int], name: str, value: int) -> None:
        """Add one inventory token's contribution to ``inv``.

        A plain ``<name>`` token adds ``value`` to ``inv[name]``. A
        ``<name>:p<k>`` token (multi-token encoding) adds ``value * 256**k`` to
        ``inv[name]``. Only a ``:p`` followed by decimal digits is treated as a
        power suffix. Any other name containing ``:p`` is kept as-is instead of
        crashing in ``int()``.
        """
        base_name, sep, power_str = name.rpartition(":p")
        if sep and power_str.isdecimal():
            inv[base_name] = inv.get(base_name, 0) + value * (256 ** int(power_str))
        else:
            inv[name] = inv.get(name, 0) + value

    def _read_center_cell(
        self,
        obs: AgentObservation,
        spawn_pos: tuple[int, int],
    ) -> tuple[dict[str, int], int, tuple[int, int]]:
        """Read the agent's own inventory, vibe id and world position from the center cell.

        ``lp:*`` tokens give the offset from ``spawn_pos``; east and south are
        positive. Without any ``lp:*`` token the position is ``spawn_pos``.
        """
        center_r, center_c = self._obs_hr, self._obs_wr
        inv: dict[str, int] = {}
        vibe_id = 0
        row_offset = 0
        col_offset = 0
        has_position = False

        for tok in obs.tokens:
            if tok.row() != center_r or tok.col() != center_c:
                continue
            feature_name = tok.feature.name
            if feature_name.startswith("inv:"):
                self._accumulate_inventory(inv, feature_name[4:], tok.value)
            elif feature_name == "vibe":
                vibe_id = tok.value
            elif feature_name == "lp:east":
                col_offset = tok.value
                has_position = True
            elif feature_name == "lp:west":
                col_offset = -tok.value
                has_position = True
            elif feature_name == "lp:south":
                row_offset = tok.value
                has_position = True
            elif feature_name == "lp:north":
                row_offset = -tok.value
                has_position = True

        if has_position:
            position = (spawn_pos[0] + row_offset, spawn_pos[1] + col_offset)
        else:
            position = spawn_pos
        return inv, vibe_id, position

    def _parse_entities(
        self,
        obs: AgentObservation,
        position: tuple[int, int],
        step: int,
    ) -> dict[tuple[int, int], Entity]:
        """Build entities for every non-center cell that carries at least one resolvable tag."""
        center_r, center_c = self._obs_hr, self._obs_wr

        # Group raw features by world position.
        position_features: dict[tuple[int, int], dict] = {}
        for tok in obs.tokens:
            obs_r, obs_c = tok.row(), tok.col()
            if obs_r == center_r and obs_c == center_c:
                continue  # the agent's own cell is handled by _read_center_cell

            world_pos = (obs_r - self._obs_hr + position[0], obs_c - self._obs_wr + position[1])
            features = position_features.setdefault(world_pos, {"tags": [], "props": {}})

            feature_name = tok.feature.name
            if feature_name == "tag":
                features["tags"].append(tok.value)
            elif feature_name in self._ENTITY_PROP_FEATURES:
                features["props"][feature_name] = tok.value
            elif feature_name.startswith("inv:"):
                self._accumulate_inventory(features.setdefault("inventory", {}), feature_name[4:], tok.value)

        # Convert to entities
        visible_entities: dict[tuple[int, int], Entity] = {}
        for world_pos, features in position_features.items():
            tags = features["tags"]
            if not tags:
                continue
            obj_name = self._resolve_object_name(tags)
            if obj_name == "unknown":
                continue

            props = dict(features["props"])
            inv_data = features.get("inventory")

            # Alignment from collective ID
            collective_id = props.pop("collective", None)
            if collective_id is not None:
                props["collective_id"] = collective_id
            alignment = self._derive_alignment(obj_name, props.get("clipped", 0), collective_id)
            if alignment:
                props["alignment"] = alignment

            # Cells that report no remaining_uses get a large default.
            props.setdefault("remaining_uses", 999)

            # Inventory amount for extractors; -1 marks "no inventory seen".
            if inv_data:
                props["inventory_amount"] = sum(inv_data.values())
                props["has_inventory"] = True
            else:
                props.setdefault("inventory_amount", -1)

            visible_entities[world_pos] = Entity(
                type=obj_name,
                properties=props,
                last_seen=step,
            )

        return visible_entities

    def _resolve_object_name(self, tag_ids: list[int]) -> str:
        """Resolve tag IDs to an object name.

        Prefers the first ``type:<name>`` tag, then the first non-empty tag that
        is not a ``collective:`` tag, else ``"unknown"``.
        """
        resolved = [self._tag_names.get(tid, "") for tid in tag_ids]

        for tag in resolved:
            if tag.startswith("type:"):
                return tag[5:]

        for tag in resolved:
            if tag and not tag.startswith("collective:"):
                return tag

        return "unknown"

    def _get_vibe_name(self, vibe_id: int) -> str:
        """Map a vibe id to its name, or ``"default"`` when out of range."""
        if 0 <= vibe_id < len(self._vibe_names):
            return self._vibe_names[vibe_id]
        return "default"

    def _derive_alignment(self, obj_name: str, clipped: int, collective_id: int | None) -> str | None:
        """Return ``"cogs"``/``"clips"`` from the collective id, else from the object name / clipped flag."""
        if collective_id is not None:
            if collective_id == self._cogs_collective_id:
                return "cogs"
            elif collective_id == self._clips_collective_id:
                return "clips"
        if "cogs" in obj_name:
            return "cogs"
        if "clips" in obj_name or clipped > 0:
            return "clips"
        return None

    @property
    def obs_half_height(self) -> int:
        return self._obs_hr

    @property
    def obs_half_width(self) -> int:
        return self._obs_wr
|