cogames-agents 0.0.0.7__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cogames_agents/__init__.py +0 -0
- cogames_agents/evals/__init__.py +5 -0
- cogames_agents/evals/planky_evals.py +415 -0
- cogames_agents/policy/__init__.py +0 -0
- cogames_agents/policy/evolution/__init__.py +0 -0
- cogames_agents/policy/evolution/cogsguard/__init__.py +0 -0
- cogames_agents/policy/evolution/cogsguard/evolution.py +695 -0
- cogames_agents/policy/evolution/cogsguard/evolutionary_coordinator.py +540 -0
- cogames_agents/policy/nim_agents/__init__.py +20 -0
- cogames_agents/policy/nim_agents/agents.py +98 -0
- cogames_agents/policy/nim_agents/bindings/generated/libnim_agents.dylib +0 -0
- cogames_agents/policy/nim_agents/bindings/generated/nim_agents.py +215 -0
- cogames_agents/policy/nim_agents/cogsguard_agents.nim +555 -0
- cogames_agents/policy/nim_agents/cogsguard_align_all_agents.nim +569 -0
- cogames_agents/policy/nim_agents/common.nim +1054 -0
- cogames_agents/policy/nim_agents/install.sh +1 -0
- cogames_agents/policy/nim_agents/ladybug_agent.nim +954 -0
- cogames_agents/policy/nim_agents/nim_agents.nim +68 -0
- cogames_agents/policy/nim_agents/nim_agents.nims +14 -0
- cogames_agents/policy/nim_agents/nimby.lock +3 -0
- cogames_agents/policy/nim_agents/racecar_agents.nim +844 -0
- cogames_agents/policy/nim_agents/random_agents.nim +68 -0
- cogames_agents/policy/nim_agents/test_agents.py +53 -0
- cogames_agents/policy/nim_agents/thinky_agents.nim +677 -0
- cogames_agents/policy/nim_agents/thinky_eval.py +230 -0
- cogames_agents/policy/scripted_agent/README.md +360 -0
- cogames_agents/policy/scripted_agent/__init__.py +0 -0
- cogames_agents/policy/scripted_agent/baseline_agent.py +1031 -0
- cogames_agents/policy/scripted_agent/cogas/__init__.py +5 -0
- cogames_agents/policy/scripted_agent/cogas/context.py +68 -0
- cogames_agents/policy/scripted_agent/cogas/entity_map.py +152 -0
- cogames_agents/policy/scripted_agent/cogas/goal.py +115 -0
- cogames_agents/policy/scripted_agent/cogas/goals/__init__.py +27 -0
- cogames_agents/policy/scripted_agent/cogas/goals/aligner.py +160 -0
- cogames_agents/policy/scripted_agent/cogas/goals/gear.py +197 -0
- cogames_agents/policy/scripted_agent/cogas/goals/miner.py +441 -0
- cogames_agents/policy/scripted_agent/cogas/goals/scout.py +40 -0
- cogames_agents/policy/scripted_agent/cogas/goals/scrambler.py +174 -0
- cogames_agents/policy/scripted_agent/cogas/goals/shared.py +160 -0
- cogames_agents/policy/scripted_agent/cogas/goals/stem.py +60 -0
- cogames_agents/policy/scripted_agent/cogas/goals/survive.py +100 -0
- cogames_agents/policy/scripted_agent/cogas/navigator.py +401 -0
- cogames_agents/policy/scripted_agent/cogas/obs_parser.py +238 -0
- cogames_agents/policy/scripted_agent/cogas/policy.py +525 -0
- cogames_agents/policy/scripted_agent/cogas/trace.py +69 -0
- cogames_agents/policy/scripted_agent/cogsguard/CLAUDE.md +517 -0
- cogames_agents/policy/scripted_agent/cogsguard/README.md +252 -0
- cogames_agents/policy/scripted_agent/cogsguard/__init__.py +74 -0
- cogames_agents/policy/scripted_agent/cogsguard/aligned_junction_held_investigation.md +152 -0
- cogames_agents/policy/scripted_agent/cogsguard/aligner.py +333 -0
- cogames_agents/policy/scripted_agent/cogsguard/behavior_hooks.py +44 -0
- cogames_agents/policy/scripted_agent/cogsguard/control_agent.py +323 -0
- cogames_agents/policy/scripted_agent/cogsguard/debug_agent.py +533 -0
- cogames_agents/policy/scripted_agent/cogsguard/miner.py +589 -0
- cogames_agents/policy/scripted_agent/cogsguard/options.py +67 -0
- cogames_agents/policy/scripted_agent/cogsguard/parity_metrics.py +36 -0
- cogames_agents/policy/scripted_agent/cogsguard/policy.py +1967 -0
- cogames_agents/policy/scripted_agent/cogsguard/prereq_trace.py +33 -0
- cogames_agents/policy/scripted_agent/cogsguard/role_trace.py +50 -0
- cogames_agents/policy/scripted_agent/cogsguard/roles.py +31 -0
- cogames_agents/policy/scripted_agent/cogsguard/rollout_trace.py +40 -0
- cogames_agents/policy/scripted_agent/cogsguard/scout.py +69 -0
- cogames_agents/policy/scripted_agent/cogsguard/scrambler.py +350 -0
- cogames_agents/policy/scripted_agent/cogsguard/targeted_agent.py +418 -0
- cogames_agents/policy/scripted_agent/cogsguard/teacher.py +224 -0
- cogames_agents/policy/scripted_agent/cogsguard/types.py +381 -0
- cogames_agents/policy/scripted_agent/cogsguard/v2_agent.py +49 -0
- cogames_agents/policy/scripted_agent/common/__init__.py +0 -0
- cogames_agents/policy/scripted_agent/common/geometry.py +24 -0
- cogames_agents/policy/scripted_agent/common/roles.py +34 -0
- cogames_agents/policy/scripted_agent/common/tag_utils.py +48 -0
- cogames_agents/policy/scripted_agent/demo_policy.py +242 -0
- cogames_agents/policy/scripted_agent/pathfinding.py +126 -0
- cogames_agents/policy/scripted_agent/pinky/DESIGN.md +317 -0
- cogames_agents/policy/scripted_agent/pinky/__init__.py +5 -0
- cogames_agents/policy/scripted_agent/pinky/behaviors/__init__.py +17 -0
- cogames_agents/policy/scripted_agent/pinky/behaviors/aligner.py +400 -0
- cogames_agents/policy/scripted_agent/pinky/behaviors/base.py +119 -0
- cogames_agents/policy/scripted_agent/pinky/behaviors/miner.py +632 -0
- cogames_agents/policy/scripted_agent/pinky/behaviors/scout.py +138 -0
- cogames_agents/policy/scripted_agent/pinky/behaviors/scrambler.py +433 -0
- cogames_agents/policy/scripted_agent/pinky/policy.py +570 -0
- cogames_agents/policy/scripted_agent/pinky/services/__init__.py +7 -0
- cogames_agents/policy/scripted_agent/pinky/services/map_tracker.py +808 -0
- cogames_agents/policy/scripted_agent/pinky/services/navigator.py +864 -0
- cogames_agents/policy/scripted_agent/pinky/services/safety.py +189 -0
- cogames_agents/policy/scripted_agent/pinky/state.py +299 -0
- cogames_agents/policy/scripted_agent/pinky/types.py +138 -0
- cogames_agents/policy/scripted_agent/planky/CLAUDE.md +124 -0
- cogames_agents/policy/scripted_agent/planky/IMPROVEMENTS.md +160 -0
- cogames_agents/policy/scripted_agent/planky/NOTES.md +153 -0
- cogames_agents/policy/scripted_agent/planky/PLAN.md +254 -0
- cogames_agents/policy/scripted_agent/planky/README.md +214 -0
- cogames_agents/policy/scripted_agent/planky/STRATEGY.md +100 -0
- cogames_agents/policy/scripted_agent/planky/__init__.py +5 -0
- cogames_agents/policy/scripted_agent/planky/context.py +68 -0
- cogames_agents/policy/scripted_agent/planky/entity_map.py +152 -0
- cogames_agents/policy/scripted_agent/planky/goal.py +107 -0
- cogames_agents/policy/scripted_agent/planky/goals/__init__.py +27 -0
- cogames_agents/policy/scripted_agent/planky/goals/aligner.py +168 -0
- cogames_agents/policy/scripted_agent/planky/goals/gear.py +179 -0
- cogames_agents/policy/scripted_agent/planky/goals/miner.py +416 -0
- cogames_agents/policy/scripted_agent/planky/goals/scout.py +40 -0
- cogames_agents/policy/scripted_agent/planky/goals/scrambler.py +174 -0
- cogames_agents/policy/scripted_agent/planky/goals/shared.py +160 -0
- cogames_agents/policy/scripted_agent/planky/goals/stem.py +49 -0
- cogames_agents/policy/scripted_agent/planky/goals/survive.py +96 -0
- cogames_agents/policy/scripted_agent/planky/navigator.py +388 -0
- cogames_agents/policy/scripted_agent/planky/obs_parser.py +238 -0
- cogames_agents/policy/scripted_agent/planky/policy.py +485 -0
- cogames_agents/policy/scripted_agent/planky/tests/__init__.py +0 -0
- cogames_agents/policy/scripted_agent/planky/tests/conftest.py +66 -0
- cogames_agents/policy/scripted_agent/planky/tests/helpers.py +152 -0
- cogames_agents/policy/scripted_agent/planky/tests/test_aligner.py +24 -0
- cogames_agents/policy/scripted_agent/planky/tests/test_miner.py +30 -0
- cogames_agents/policy/scripted_agent/planky/tests/test_scout.py +15 -0
- cogames_agents/policy/scripted_agent/planky/tests/test_scrambler.py +29 -0
- cogames_agents/policy/scripted_agent/planky/tests/test_stem.py +36 -0
- cogames_agents/policy/scripted_agent/planky/trace.py +69 -0
- cogames_agents/policy/scripted_agent/types.py +239 -0
- cogames_agents/policy/scripted_agent/unclipping_agent.py +461 -0
- cogames_agents/policy/scripted_agent/utils.py +381 -0
- cogames_agents/policy/scripted_registry.py +80 -0
- cogames_agents/py.typed +0 -0
- cogames_agents-0.0.0.7.dist-info/METADATA +98 -0
- cogames_agents-0.0.0.7.dist-info/RECORD +128 -0
- cogames_agents-0.0.0.7.dist-info/WHEEL +6 -0
- cogames_agents-0.0.0.7.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,808 @@
|
|
|
1
|
+
"""
|
|
2
|
+
MapTracker service for Pinky policy.
|
|
3
|
+
|
|
4
|
+
Processes observations and maintains map knowledge.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from collections import Counter
|
|
10
|
+
from typing import TYPE_CHECKING, Optional, Union
|
|
11
|
+
|
|
12
|
+
from cogames_agents.policy.scripted_agent.common.tag_utils import derive_alignment, select_primary_tag
|
|
13
|
+
from cogames_agents.policy.scripted_agent.pinky.types import (
|
|
14
|
+
DEBUG,
|
|
15
|
+
ROLE_TO_STATION,
|
|
16
|
+
CellType,
|
|
17
|
+
StructureInfo,
|
|
18
|
+
StructureType,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from cogames_agents.policy.scripted_agent.pinky.state import AgentState
|
|
23
|
+
from mettagrid.policy.policy_env_interface import PolicyEnvInterface
|
|
24
|
+
from mettagrid.simulator.interface import AgentObservation
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class MapTracker:
|
|
28
|
+
"""Processes observations and maintains map knowledge."""
|
|
29
|
+
|
|
30
|
+
def __init__(self, policy_env_info: PolicyEnvInterface):
    """Cache observation geometry and the lookup tables used when parsing tokens.

    Args:
        policy_env_info: Environment description; this constructor reads its
            ``obs_height``, ``obs_width``, ``tag_id_to_name`` and
            ``action_names`` attributes.
    """
    # Half-extents of the observation window; (obs_hr, obs_wr) is the centre cell.
    self._obs_hr = policy_env_info.obs_height // 2
    self._obs_wr = policy_env_info.obs_width // 2
    self._tag_names = policy_env_info.tag_id_to_name
    if DEBUG:
        print(f"[MAP] Available tags: {list(self._tag_names.values())}")
    self._spatial_feature_names = {"tag", "cooldown_remaining", "clipped", "remaining_uses", "collective"}
    self._agent_feature_key_by_name = {"agent:group": "agent_group", "agent:frozen": "agent_frozen"}

    # Collective IDs are taken to be the index of the name in this
    # alphabetically sorted list (stated upstream to match the mettagrid convention).
    self._collective_names = ["clips", "cogs"]
    index_by_collective = {name: idx for idx, name in enumerate(self._collective_names)}
    self._cogs_collective_id: Optional[int] = index_by_collective.get("cogs")
    self._clips_collective_id: Optional[int] = index_by_collective.get("clips")
    if DEBUG:
        print(f"[MAP] Collective IDs: cogs={self._cogs_collective_id}, clips={self._clips_collective_id}")

    # Vibe names come from the "change_vibe_<name>" actions; their order in
    # action_names is used as the vibe ID order.
    vibe_prefix = "change_vibe_"
    self._vibe_names: list[str] = [
        action[len(vibe_prefix) :] for action in policy_env_info.action_names if action.startswith(vibe_prefix)
    ]
|
|
60
|
+
|
|
61
|
+
def update(self, state: AgentState, obs: AgentObservation) -> None:
    """Fold one observation into ``state``: inventory, position, then the map.

    The order matters: the position is re-derived before spatial tokens are
    converted to world coordinates, and the window is marked FREE before the
    per-object pass re-marks occupied cells.
    """
    # Agent occupancy is only valid for the current step.
    state.map.agent_occupancy.clear()

    self._read_inventory(state, obs)
    self._compute_position_from_observation(state, obs)
    features_by_pos = self._parse_observation(state, obs)
    self._mark_explored(state)

    # World positions inside the window where another agent was seen this step.
    agents_seen: set[tuple[int, int]] = set()
    for world_pos, cell_features in features_by_pos.items():
        if "tags" in cell_features:
            name = self._get_object_name(cell_features)
            self._process_object(state, world_pos, name, cell_features)
            if name.lower() == "agent":
                agents_seen.add(world_pos)

    self._update_recent_agents(state, agents_seen)
|
|
95
|
+
|
|
96
|
+
def _read_inventory(self, state: AgentState, obs: AgentObservation) -> None:
|
|
97
|
+
"""Read inventory, vibe, and collective stats from observation center cell."""
|
|
98
|
+
inv: dict[str, int] = {}
|
|
99
|
+
collective_inv: dict[str, int] = {}
|
|
100
|
+
vibe_id = 0
|
|
101
|
+
|
|
102
|
+
center_r, center_c = self._obs_hr, self._obs_wr
|
|
103
|
+
for tok in obs.tokens:
|
|
104
|
+
if tok.row() == center_r and tok.col() == center_c:
|
|
105
|
+
feature_name = tok.feature.name
|
|
106
|
+
if feature_name.startswith("inv:"):
|
|
107
|
+
resource_name = feature_name[4:]
|
|
108
|
+
inv[resource_name] = tok.value
|
|
109
|
+
elif feature_name == "vibe":
|
|
110
|
+
vibe_id = tok.value
|
|
111
|
+
elif feature_name.startswith("stat:collective:collective.") and feature_name.endswith(".amount"):
|
|
112
|
+
# Parse collective resource amount from "stat:collective:collective.{resource}.amount"
|
|
113
|
+
# Extract resource name between "collective." and ".amount"
|
|
114
|
+
prefix = "stat:collective:collective."
|
|
115
|
+
suffix = ".amount"
|
|
116
|
+
resource_name = feature_name[len(prefix) : -len(suffix)]
|
|
117
|
+
collective_inv[resource_name] = tok.value
|
|
118
|
+
|
|
119
|
+
# Update inventory
|
|
120
|
+
state.energy = inv.get("energy", 0)
|
|
121
|
+
state.hp = inv.get("hp", 100)
|
|
122
|
+
state.carbon = inv.get("carbon", 0)
|
|
123
|
+
state.oxygen = inv.get("oxygen", 0)
|
|
124
|
+
state.germanium = inv.get("germanium", 0)
|
|
125
|
+
state.silicon = inv.get("silicon", 0)
|
|
126
|
+
state.heart = inv.get("heart", 0)
|
|
127
|
+
state.influence = inv.get("influence", 0)
|
|
128
|
+
|
|
129
|
+
# Update gear
|
|
130
|
+
state.miner_gear = inv.get("miner", 0) > 0
|
|
131
|
+
state.scout_gear = inv.get("scout", 0) > 0
|
|
132
|
+
state.aligner_gear = inv.get("aligner", 0) > 0
|
|
133
|
+
state.scrambler_gear = inv.get("scrambler", 0) > 0
|
|
134
|
+
|
|
135
|
+
# Update collective inventory
|
|
136
|
+
state.collective_carbon = collective_inv.get("carbon", 0)
|
|
137
|
+
state.collective_oxygen = collective_inv.get("oxygen", 0)
|
|
138
|
+
state.collective_germanium = collective_inv.get("germanium", 0)
|
|
139
|
+
state.collective_silicon = collective_inv.get("silicon", 0)
|
|
140
|
+
|
|
141
|
+
# Update vibe
|
|
142
|
+
state.vibe = self._get_vibe_name(vibe_id)
|
|
143
|
+
|
|
144
|
+
def _compute_position_from_observation(self, state: AgentState, obs: AgentObservation) -> None:
|
|
145
|
+
"""Compute agent position by matching known objects in observation.
|
|
146
|
+
|
|
147
|
+
More reliable than action-based tracking because it:
|
|
148
|
+
- Doesn't matter if moves succeed or fail
|
|
149
|
+
- Self-corrects any position drift
|
|
150
|
+
- Based on actual observation, not assumptions
|
|
151
|
+
|
|
152
|
+
Strategy: Find objects in observation, match to known world positions,
|
|
153
|
+
derive agent position from the offset.
|
|
154
|
+
"""
|
|
155
|
+
# Collect observed objects with their observation-relative positions
|
|
156
|
+
# Format: {obs_pos: obj_name}
|
|
157
|
+
observed_objects: dict[tuple[int, int], str] = {}
|
|
158
|
+
|
|
159
|
+
for tok in obs.tokens:
|
|
160
|
+
if tok.feature.name != "tag":
|
|
161
|
+
continue
|
|
162
|
+
|
|
163
|
+
obs_r, obs_c = tok.row(), tok.col()
|
|
164
|
+
# Skip center cell (that's the agent itself)
|
|
165
|
+
if obs_r == self._obs_hr and obs_c == self._obs_wr:
|
|
166
|
+
continue
|
|
167
|
+
|
|
168
|
+
tag_name = self._tag_names.get(tok.value, "")
|
|
169
|
+
obj_name = tag_name.removeprefix("type:")
|
|
170
|
+
|
|
171
|
+
# Only track unique identifiable objects (stations, junctions, hubs)
|
|
172
|
+
if obj_name in self.MATCHABLE_OBJECTS:
|
|
173
|
+
obs_pos = (obs_r, obs_c)
|
|
174
|
+
# Keep the most specific name if multiple tags
|
|
175
|
+
if obs_pos not in observed_objects or obj_name in self.PRIORITY_OBJECTS:
|
|
176
|
+
observed_objects[obs_pos] = obj_name
|
|
177
|
+
|
|
178
|
+
if not observed_objects:
|
|
179
|
+
return # No matchable objects, keep current position
|
|
180
|
+
|
|
181
|
+
# Try to match observed objects against known world positions
|
|
182
|
+
position_votes: list[tuple[int, int]] = []
|
|
183
|
+
|
|
184
|
+
for obs_pos, obj_name in observed_objects.items():
|
|
185
|
+
obs_r, obs_c = obs_pos
|
|
186
|
+
|
|
187
|
+
# Check known stations
|
|
188
|
+
for station_name, world_pos in state.map.stations.items():
|
|
189
|
+
# Match by name similarity
|
|
190
|
+
if self._objects_match(obj_name, station_name):
|
|
191
|
+
# Derive agent position: world_pos - obs_offset
|
|
192
|
+
derived_row = world_pos[0] - (obs_r - self._obs_hr)
|
|
193
|
+
derived_col = world_pos[1] - (obs_c - self._obs_wr)
|
|
194
|
+
position_votes.append((derived_row, derived_col))
|
|
195
|
+
|
|
196
|
+
# Check known structures (junctions, extractors, etc.)
|
|
197
|
+
for world_pos, struct in state.map.structures.items():
|
|
198
|
+
if self._objects_match(obj_name, struct.name):
|
|
199
|
+
derived_row = world_pos[0] - (obs_r - self._obs_hr)
|
|
200
|
+
derived_col = world_pos[1] - (obs_c - self._obs_wr)
|
|
201
|
+
position_votes.append((derived_row, derived_col))
|
|
202
|
+
|
|
203
|
+
if not position_votes:
|
|
204
|
+
return # No matches found, keep current position
|
|
205
|
+
|
|
206
|
+
# Use majority vote (most common derived position)
|
|
207
|
+
# This handles cases where multiple objects are visible
|
|
208
|
+
position_counts = Counter(position_votes)
|
|
209
|
+
most_common_pos, count = position_counts.most_common(1)[0]
|
|
210
|
+
|
|
211
|
+
# Only update if we have confidence (multiple matches or different from current)
|
|
212
|
+
if count >= 1:
|
|
213
|
+
old_pos = (state.row, state.col)
|
|
214
|
+
if most_common_pos != old_pos:
|
|
215
|
+
if DEBUG:
|
|
216
|
+
print(
|
|
217
|
+
f"[A{state.agent_id}] MAP: Position corrected via object matching: "
|
|
218
|
+
f"{old_pos} -> {most_common_pos} (votes={count})"
|
|
219
|
+
)
|
|
220
|
+
state.row, state.col = most_common_pos
|
|
221
|
+
|
|
222
|
+
# Object type names that _compute_position_from_observation is allowed to use
# as landmarks when re-deriving the agent's position (intended to be static objects).
# NOTE(review): several of these types (junctions, extractors) can presumably
# occur more than once on a map, so a name-only match may be ambiguous — confirm.
MATCHABLE_OBJECTS = frozenset(
    {
        "miner_station",
        "scout_station",
        "aligner_station",
        "scrambler_station",
        "nexus",
        "hub",
        "junction",
        "chest",
        "carbon_extractor",
        "oxygen_extractor",
        "germanium_extractor",
        "silicon_extractor",
    }
)
|
|
239
|
+
|
|
240
|
+
def _objects_match(self, obs_name: str, known_name: str) -> bool:
|
|
241
|
+
"""Check if observed object name matches a known object name."""
|
|
242
|
+
obs_lower = obs_name.lower()
|
|
243
|
+
known_lower = known_name.lower()
|
|
244
|
+
# Direct match
|
|
245
|
+
if obs_lower == known_lower:
|
|
246
|
+
return True
|
|
247
|
+
# Substring match (e.g., "junction" matches "cogs_junction")
|
|
248
|
+
if obs_lower in known_lower or known_lower in obs_lower:
|
|
249
|
+
return True
|
|
250
|
+
# Type match (e.g., "junction" matches "junction" which is a junction type)
|
|
251
|
+
junction_names = {"junction", "supply_depot"}
|
|
252
|
+
if obs_lower in junction_names and known_lower in junction_names:
|
|
253
|
+
return True
|
|
254
|
+
return False
|
|
255
|
+
|
|
256
|
+
def _get_vibe_name(self, vibe_id: int) -> str:
|
|
257
|
+
"""Convert vibe ID to name."""
|
|
258
|
+
# Use dynamically derived vibe names from action names
|
|
259
|
+
if 0 <= vibe_id < len(self._vibe_names):
|
|
260
|
+
return self._vibe_names[vibe_id]
|
|
261
|
+
return "default"
|
|
262
|
+
|
|
263
|
+
def _parse_observation(
|
|
264
|
+
self, state: AgentState, obs: AgentObservation
|
|
265
|
+
) -> dict[tuple[int, int], dict[str, Union[int, list[int], dict[str, int]]]]:
|
|
266
|
+
"""Parse observation tokens into position-keyed features."""
|
|
267
|
+
position_features: dict[tuple[int, int], dict[str, Union[int, list[int], dict[str, int]]]] = {}
|
|
268
|
+
|
|
269
|
+
for tok in obs.tokens:
|
|
270
|
+
# Use row()/col() methods - location tuple is (col, row) format
|
|
271
|
+
obs_r, obs_c = tok.row(), tok.col()
|
|
272
|
+
feature_name = tok.feature.name
|
|
273
|
+
value = tok.value
|
|
274
|
+
|
|
275
|
+
# Skip center cell (inventory/global)
|
|
276
|
+
if obs_r == self._obs_hr and obs_c == self._obs_wr:
|
|
277
|
+
continue
|
|
278
|
+
|
|
279
|
+
# Convert to world coords
|
|
280
|
+
r = obs_r - self._obs_hr + state.row
|
|
281
|
+
c = obs_c - self._obs_wr + state.col
|
|
282
|
+
|
|
283
|
+
if not (0 <= r < state.map.grid_size and 0 <= c < state.map.grid_size):
|
|
284
|
+
continue
|
|
285
|
+
|
|
286
|
+
pos = (r, c)
|
|
287
|
+
if pos not in position_features:
|
|
288
|
+
position_features[pos] = {}
|
|
289
|
+
|
|
290
|
+
# Handle spatial features
|
|
291
|
+
if feature_name in self._spatial_feature_names:
|
|
292
|
+
if feature_name == "tag":
|
|
293
|
+
tags = position_features[pos].setdefault("tags", [])
|
|
294
|
+
if isinstance(tags, list):
|
|
295
|
+
tags.append(value)
|
|
296
|
+
else:
|
|
297
|
+
position_features[pos][feature_name] = value
|
|
298
|
+
|
|
299
|
+
# Handle inventory features (for extractors)
|
|
300
|
+
# Multi-token encoding: inv:{resource} = base, inv:{resource}:p1 = power1, etc.
|
|
301
|
+
# We accumulate: base + p1*256 + p2*256^2 + ... (token_value_base=256 is default)
|
|
302
|
+
elif feature_name.startswith("inv:"):
|
|
303
|
+
suffix = feature_name[4:] # e.g., "carbon" or "carbon:p1"
|
|
304
|
+
# Parse resource name and power
|
|
305
|
+
if ":p" in suffix:
|
|
306
|
+
resource, power_str = suffix.rsplit(":p", 1)
|
|
307
|
+
power = int(power_str)
|
|
308
|
+
else:
|
|
309
|
+
resource = suffix
|
|
310
|
+
power = 0
|
|
311
|
+
|
|
312
|
+
inventory = position_features[pos].setdefault("inventory", {})
|
|
313
|
+
if isinstance(inventory, dict):
|
|
314
|
+
# Accumulate multi-token value (token_value_base=256 is mettagrid default)
|
|
315
|
+
# inv:X = amount % 256, inv:X:p1 = (amount // 256) % 256, etc.
|
|
316
|
+
token_base = 256
|
|
317
|
+
current = inventory.get(resource, 0)
|
|
318
|
+
inventory[resource] = current + value * (token_base**power)
|
|
319
|
+
|
|
320
|
+
return position_features
|
|
321
|
+
|
|
322
|
+
def _mark_explored(self, state: AgentState) -> None:
    """Flag every in-bounds cell of the current observation window as explored and FREE.

    The window is centred on (state.row, state.col). Cells are reset to FREE
    here; ``_process_object`` later re-marks the ones that hold an object as
    OBSTACLE. Cells outside the window are not touched, so they keep whatever
    was recorded for them earlier.
    """
    half_h, half_w = self._obs_hr, self._obs_wr
    for d_row in range(-half_h, half_h + 1):
        for d_col in range(-half_w, half_w + 1):
            r = state.row + d_row
            c = state.col + d_col
            inside = 0 <= r < state.map.grid_size and 0 <= c < state.map.grid_size
            if not inside:
                continue
            state.map.occupancy[r][c] = CellType.FREE.value
            state.map.explored[r][c] = True
|
|
340
|
+
|
|
341
|
+
# Object type names that win when a cell carries several tags: passed to
# select_primary_tag by _get_object_name, and used by
# _compute_position_from_observation to let a later tag replace an earlier one.
PRIORITY_OBJECTS = frozenset(
    {
        "miner_station",
        "scout_station",
        "aligner_station",
        "scrambler_station",
        "carbon_extractor",
        "oxygen_extractor",
        "germanium_extractor",
        "silicon_extractor",
        "junction",
        "hub",
        "chest",
        "wall",
        "agent",
    }
)
|
|
359
|
+
|
|
360
|
+
def _get_object_name(self, features: dict[str, Union[int, list[int], dict[str, int]]]) -> str:
|
|
361
|
+
"""Get object name from features - prioritize type: tags over collective: tags.
|
|
362
|
+
|
|
363
|
+
Priority order (matching common.tag_utils.select_primary_tag):
|
|
364
|
+
1. type:* tags (strip prefix and return the type name)
|
|
365
|
+
2. Non-collective tags that are PRIORITY_OBJECTS
|
|
366
|
+
3. Any non-collective tag
|
|
367
|
+
4. First tag (fallback)
|
|
368
|
+
"""
|
|
369
|
+
tags_value = features.get("tags", [])
|
|
370
|
+
if not isinstance(tags_value, list) or not tags_value:
|
|
371
|
+
return "unknown"
|
|
372
|
+
|
|
373
|
+
resolved_tags = [self._tag_names.get(tag_id, "") for tag_id in tags_value]
|
|
374
|
+
return select_primary_tag(resolved_tags, priority_objects=set(self.PRIORITY_OBJECTS))
|
|
375
|
+
|
|
376
|
+
def _process_object(
    self,
    state: AgentState,
    pos: tuple[int, int],
    obj_name: str,
    features: dict[str, Union[int, list[int], dict[str, int]]],
) -> None:
    """Classify one observed object and record it in the agent's map state.

    The object is tested, in order, as: wall, agent, gear station, junction,
    hub/nexus, (heart) chest, resource extractor/chest. The first category
    that matches updates occupancy / structures / stations and returns;
    objects matching none of them are ignored.

    Args:
        state: Agent state whose ``map`` is updated in place.
        pos: World (row, col) of the object.
        obj_name: Primary object name for the cell.
        features: Parsed features for the cell (see ``_parse_observation``).
    """
    obj_lower = obj_name.lower()

    # Common features; non-int values fall back to the defaults.
    cooldown_val = features.get("cooldown_remaining", 0)
    cooldown = cooldown_val if isinstance(cooldown_val, int) else 0
    clipped_val = features.get("clipped", 0)
    clipped = clipped_val if isinstance(clipped_val, int) else 0
    remaining_val = features.get("remaining_uses", 999)
    remaining = remaining_val if isinstance(remaining_val, int) else 999
    inventory = features.get("inventory")

    # Inventory handling:
    # - inv: tokens present (non-empty dict): the amount is the sum over all resources.
    # - no inv: tokens: -1 marks "no inventory seen"; the update helpers decide
    #   whether that means depleted or protocol-based.
    has_inv_tokens = isinstance(inventory, dict) and bool(inventory)
    if has_inv_tokens:
        inv_amount = sum(inventory.values())
    else:
        inv_amount = -1

    # Collective ID, used for junction alignment detection.
    collective_val = features.get("collective")
    collective_id: Optional[int] = collective_val if isinstance(collective_val, int) else None

    # Walls only block movement.
    if self._is_wall(obj_lower):
        state.map.occupancy[pos[0]][pos[1]] = CellType.OBSTACLE.value
        return

    # Other agents: record current-step occupancy only. The recent-agents
    # tracking is done by update(), so nothing else is needed here.
    if obj_lower == "agent":
        state.map.agent_occupancy.add(pos)
        return

    # Gear stations
    for _role, station_name in ROLE_TO_STATION.items():
        if station_name in obj_lower or self._is_station(obj_lower, station_name):
            state.map.occupancy[pos[0]][pos[1]] = CellType.OBSTACLE.value
            struct_type = self._get_station_type(station_name)
            self._update_structure(state, pos, obj_name, struct_type, None, cooldown, remaining, inv_amount)
            # Only the first station of each kind is remembered in stations.
            if station_name not in state.map.stations:
                state.map.stations[station_name] = pos
                if DEBUG:
                    print(f"[A{state.agent_id}] MAP: Found {station_name} at {pos}")
            return

    # Junctions (also known as supply depots)
    if "junction" in obj_lower or "supply_depot" in obj_lower:
        state.map.occupancy[pos[0]][pos[1]] = CellType.OBSTACLE.value
        alignment = self._derive_alignment(obj_lower, clipped, collective_id)
        # Decide "newly found" before the structure is (presumably) registered
        # by _update_structure; checking afterwards would never report it.
        is_new_junction = pos not in state.map.structures
        self._update_structure(
            state, pos, obj_name, StructureType.JUNCTION, alignment, cooldown, remaining, inv_amount
        )
        if "junction" not in state.map.stations:
            state.map.stations["junction"] = pos
        if DEBUG and is_new_junction:
            print(
                f"[A{state.agent_id}] MAP: Found junction at {pos} "
                f"(alignment={alignment}, collective_id={collective_id})"
            )
        return

    # Hub / nexus (always recorded with "cogs" alignment)
    if "hub" in obj_lower or "nexus" in obj_lower:
        state.map.occupancy[pos[0]][pos[1]] = CellType.OBSTACLE.value
        self._update_structure(state, pos, obj_name, StructureType.HUB, "cogs", cooldown, remaining, inv_amount)
        if "hub" not in state.map.stations:
            state.map.stations["hub"] = pos
            if DEBUG:
                print(f"[A{state.agent_id}] MAP: Found hub at {pos}")
        return

    # Plain chest (hearts); "<resource>_..." chests are handled as extractors below.
    resources = ["carbon", "oxygen", "germanium", "silicon"]
    is_resource_chest = any(f"{res}_" in obj_lower for res in resources)
    if "chest" in obj_lower and not is_resource_chest:
        state.map.occupancy[pos[0]][pos[1]] = CellType.OBSTACLE.value
        self._update_structure(state, pos, obj_name, StructureType.CHEST, None, cooldown, remaining, inv_amount)
        if "chest" not in state.map.stations:
            state.map.stations["chest"] = pos
            if DEBUG:
                print(f"[A{state.agent_id}] MAP: Found chest at {pos}")
        return

    # Resource extractors and resource chests
    for resource in resources:
        if f"{resource}_extractor" in obj_lower or f"{resource}_chest" in obj_lower:
            state.map.occupancy[pos[0]][pos[1]] = CellType.OBSTACLE.value
            # Prefer the amount of this specific resource when it was reported.
            res_amount = inventory.get(resource, inv_amount) if isinstance(inventory, dict) else inv_amount
            self._update_extractor(state, pos, obj_name, cooldown, remaining, res_amount, resource, has_inv_tokens)
            if DEBUG and state.step < 10:
                print(
                    f"[A{state.agent_id}] MAP: Found {resource} extractor at {pos} "
                    f"(remaining={remaining}, inv={res_amount}, has_inv={has_inv_tokens})"
                )
            return
|
|
483
|
+
|
|
484
|
+
def _update_extractor(
    self,
    state: AgentState,
    pos: tuple[int, int],
    obj_name: str,
    cooldown: int,
    remaining: int,
    inv_amount: int,
    resource_type: str,
    has_inv_tokens: bool,
) -> None:
    """Record a sighting of an extractor at ``pos`` in ``state.map.structures``.

    Two flavours of extractor are told apart by whether inventory tokens
    have ever been observed for them:

    - Chest-based: inventory tokens were seen at least once. Their
      ``inventory_amount`` follows the observed amount and drops to 0 when
      the tokens stop appearing.
    - Protocol-based: inventory tokens were never seen. Their
      ``inventory_amount`` is left at the -1 sentinel and depletion is read
      from ``remaining_uses`` instead.

    Args:
        state: Agent state whose map is updated in place.
        pos: Map position of the extractor, used as the structures key.
        obj_name: Object name stored on a newly created entry.
        cooldown: Value stored in ``cooldown_remaining``.
        remaining: Value stored in ``remaining_uses``.
        inv_amount: Observed inventory amount (only used with tokens).
        resource_type: Resource stored on a newly created entry.
        has_inv_tokens: Whether inventory tokens were seen this step.
    """
    # First sighting: create the entry and stop.
    if pos not in state.map.structures:
        initial_inventory = inv_amount if has_inv_tokens else -1
        state.map.structures[pos] = StructureInfo(
            position=pos,
            structure_type=StructureType.EXTRACTOR,
            name=obj_name,
            last_seen_step=state.step,
            alignment=None,
            resource_type=resource_type,
            cooldown_remaining=cooldown,
            remaining_uses=remaining,
            inventory_amount=initial_inventory,
            has_inventory=has_inv_tokens,
        )
        return

    # Known extractor: refresh the per-step fields.
    extractor = state.map.structures[pos]
    extractor.last_seen_step = state.step
    extractor.cooldown_remaining = cooldown
    extractor.remaining_uses = remaining

    if has_inv_tokens:
        # Inventory tokens are visible, so this is a stocked chest-based extractor.
        extractor.has_inventory = True
        extractor.inventory_amount = inv_amount
        return

    if extractor.has_inventory:
        # Tokens were seen before but not now: treat the extractor as depleted.
        extractor.inventory_amount = 0
    # Otherwise inventory was never seen; the entry stays protocol-based
    # and inventory_amount is deliberately left untouched.
|
|
529
|
+
|
|
530
|
+
def _update_structure(
    self,
    state: AgentState,
    pos: tuple[int, int],
    obj_name: str,
    struct_type: StructureType,
    alignment: Optional[str],
    cooldown: int,
    remaining: int,
    inv_amount: int,
    resource_type: Optional[str] = None,
) -> None:
    """Record a sighting of a non-extractor structure at ``pos``.

    A position not yet present in ``state.map.structures`` gets a fresh
    ``StructureInfo``. For a known position only the per-step fields
    (last seen step, cooldown, remaining uses, inventory amount and
    alignment) are refreshed; its type, name and resource type are kept.

    Args:
        state: Agent state whose map is updated in place.
        pos: Map position of the structure, used as the structures key.
        obj_name: Object name stored on a newly created entry.
        struct_type: Structure type stored on a newly created entry.
        alignment: Current alignment; ``None`` means neutral.
        cooldown: Value stored in ``cooldown_remaining``.
        remaining: Value stored in ``remaining_uses``.
        inv_amount: Value stored in ``inventory_amount``.
        resource_type: Resource stored on a newly created entry.
    """
    if pos not in state.map.structures:
        state.map.structures[pos] = StructureInfo(
            position=pos,
            structure_type=struct_type,
            name=obj_name,
            last_seen_step=state.step,
            alignment=alignment,
            resource_type=resource_type,
            cooldown_remaining=cooldown,
            remaining_uses=remaining,
            inventory_amount=inv_amount,
        )
        return

    existing = state.map.structures[pos]

    # Report an alignment flip before the stored value is overwritten.
    if DEBUG and existing.alignment != alignment:
        print(f"[A{state.agent_id}] MAP: Junction {pos} alignment changed: {existing.alignment} -> {alignment}")

    # Alignment is always overwritten: None means neutral (not cogs or clips).
    refreshed_fields = {
        "last_seen_step": state.step,
        "cooldown_remaining": cooldown,
        "remaining_uses": remaining,
        "inventory_amount": inv_amount,
        "alignment": alignment,
    }
    for field_name, new_value in refreshed_fields.items():
        setattr(existing, field_name, new_value)
|
|
566
|
+
|
|
567
|
+
def _derive_alignment(self, obj_name: str, clipped: int, collective_id: Optional[int] = None) -> Optional[str]:
    """Delegate to ``derive_alignment`` using this tracker's collective ids.

    Forwards ``obj_name``, ``clipped`` and ``collective_id`` unchanged and
    supplies ``self._cogs_collective_id`` / ``self._clips_collective_id``
    as the keyword arguments of the same names.
    """
    team_ids = {
        "cogs_collective_id": self._cogs_collective_id,
        "clips_collective_id": self._clips_collective_id,
    }
    return derive_alignment(obj_name, clipped, collective_id, **team_ids)
|
|
575
|
+
|
|
576
|
+
def _update_recent_agents(self, state: AgentState, observed_agent_positions: set[tuple[int, int]]) -> None:
    """Refresh ``state.map.recent_agents`` from this step's agent sightings.

    Three passes, in order:

    1. Every position in ``observed_agent_positions`` gets a fresh
       ``AgentSighting`` stamped with the current step.
    2. Tracked positions inside the current observation window that are
       not in ``observed_agent_positions`` are dropped (the agent moved,
       or the earlier sighting was wrong).
    3. Any remaining entry last seen more than 50 steps ago is dropped,
       so sightings outside the window do not accumulate forever.

    Positions outside the window that are not yet stale are kept, since
    the agent may still be there.
    """
    from cogames_agents.policy.scripted_agent.pinky.state import AgentSighting

    # Inclusive bounds of the window currently visible to the agent.
    row_lo = state.row - self._obs_hr
    row_hi = state.row + self._obs_hr
    col_lo = state.col - self._obs_wr
    col_hi = state.col + self._obs_wr

    tracked = state.map.recent_agents
    now = state.step

    # Pass 1: stamp everything seen this step.
    for seen_at in observed_agent_positions:
        tracked[seen_at] = AgentSighting(position=seen_at, last_seen_step=now)

    # Pass 2: visible positions where no agent is seen any more.
    vanished = [
        spot
        for spot in tracked
        if row_lo <= spot[0] <= row_hi and col_lo <= spot[1] <= col_hi and spot not in observed_agent_positions
    ]
    for spot in vanished:
        del tracked[spot]

    # Pass 3: age out sightings not refreshed for more than 50 steps.
    max_age = 50
    expired = [spot for spot, sighting in tracked.items() if now - sighting.last_seen_step > max_age]
    for spot in expired:
        del tracked[spot]
|
|
618
|
+
|
|
619
|
+
def _is_wall(self, obj_name: str) -> bool:
    """Return True when ``obj_name`` contains "wall" or the "#" marker."""
    wall_markers = ("wall", "#")
    return any(marker in obj_name for marker in wall_markers)
|
|
622
|
+
|
|
623
|
+
def _is_station(self, obj_name: str, station: str) -> bool:
    """Return True when ``station`` occurs as a substring of ``obj_name``."""
    return obj_name.find(station) != -1
|
|
626
|
+
|
|
627
|
+
def _get_station_type(self, station_name: str) -> StructureType:
    """Translate a station name into its ``StructureType`` member.

    The four recognised names are "miner_station", "scout_station",
    "aligner_station" and "scrambler_station"; any other name yields
    ``StructureType.UNKNOWN``.
    """
    station_types = {
        "miner_station": StructureType.MINER_STATION,
        "scout_station": StructureType.SCOUT_STATION,
        "aligner_station": StructureType.ALIGNER_STATION,
        "scrambler_station": StructureType.SCRAMBLER_STATION,
    }
    try:
        return station_types[station_name]
    except KeyError:
        return StructureType.UNKNOWN
|
|
636
|
+
|
|
637
|
+
def get_direction_to_nearest(
    self,
    state: AgentState,
    obs: AgentObservation,
    target_types: frozenset[str],
    exclude_positions: Optional[set[tuple[int, int]]] = None,
) -> Optional[tuple[str, tuple[int, int]]]:
    """Get the direction to move toward the nearest target in the current observation.

    Scans the "tag" tokens of ``obs``, picks the matching target with the
    smallest Manhattan distance from the agent, and runs A* inside the
    observation window to find the first step of a path that ends on a
    cell orthogonally adjacent to that target. All coordinates used by the
    search are (row, col) offsets relative to the agent, which sits at
    (0, 0).

    Only the single nearest target is searched for: if no path to it
    exists the method returns None even when other targets are visible.

    Args:
        state: Agent state; ``state.row`` / ``state.col`` are used to
            convert offsets to world coordinates.
        obs: Current observation; only tokens whose feature name is
            "tag" are considered.
        target_types: Object type names to target. A tag matches when its
            name (with any "type:" prefix removed) equals one of these or
            contains one of these as a substring.
        exclude_positions: Optional set of world positions to exclude
            from targets. Excluded cells still block movement.

    Returns:
        ``(direction, world_pos)`` or None if there is no target or no path.
        direction: "north", "south", "east", "west"
        world_pos: (row, col) of the chosen target in world coordinates
    """
    import heapq

    # Observation-grid coordinates of the agent itself (window centre).
    center_c, center_r = self._obs_wr, self._obs_hr  # (5, 5) typically

    # Find cells with any objects (blocked) and target cells
    blocked_cells: set[tuple[int, int]] = set()
    target_cells: list[tuple[int, int]] = []

    for tok in obs.tokens:
        if tok.feature.name != "tag":
            continue

        # Offset of this token's cell relative to the agent.
        dr = tok.row() - center_r
        dc = tok.col() - center_c

        # Unknown tag ids map to "" and therefore never match a non-empty target.
        tag_name = self._tag_names.get(tok.value, "")
        match_name = tag_name.removeprefix("type:")

        # Check if this is a target (use substring matching for flexibility)
        is_target = match_name in target_types or any(t in match_name for t in target_types)
        if is_target:
            # Check if this position should be excluded
            world_pos = (state.row + dr, state.col + dc)
            if exclude_positions and world_pos in exclude_positions:
                if DEBUG and state.agent_id == 0 and state.step % 50 == 0:
                    print(f"[A{state.agent_id}] DIR: Excluding target '{match_name}' at {world_pos}")
            else:
                # A cell carrying several matching tags is appended once per tag.
                target_cells.append((dr, dc))
                if DEBUG and state.agent_id == 0 and state.step % 50 == 0:
                    print(f"[A{state.agent_id}] DIR: Found target '{match_name}' at ({dr},{dc})")

        # All objects block movement (except self at center)
        if dr != 0 or dc != 0:
            blocked_cells.add((dr, dc))

    if not target_cells:
        return None

    # Find nearest target (for A* goal)
    # The sort is stable, so equidistant targets keep token order.
    # NOTE(review): a matching tag on the agent's own cell would give goal == (0, 0);
    # the search would then step toward a neighbouring cell - confirm this cannot occur.
    target_cells.sort(key=lambda t: abs(t[0]) + abs(t[1]))
    goal = target_cells[0]

    # Convert goal to world coordinates for return value
    goal_world = (state.row + goal[0], state.col + goal[1])

    # A* pathfinding from (0, 0) to goal
    # We want to reach adjacent to goal (since goal cell is blocked)
    def heuristic(pos: tuple[int, int]) -> int:
        # Manhattan distance to the goal cell.
        return abs(pos[0] - goal[0]) + abs(pos[1] - goal[1])

    def is_adjacent_to_goal(pos: tuple[int, int]) -> bool:
        # True for the four orthogonal neighbours of the goal cell.
        return abs(pos[0] - goal[0]) + abs(pos[1] - goal[1]) == 1

    start = (0, 0)

    # Special case: already adjacent to goal - return direction to walk into target
    if is_adjacent_to_goal(start):
        dr, dc = goal
        if dr == 1:
            return ("south", goal_world)
        elif dr == -1:
            return ("north", goal_world)
        elif dc == 1:
            return ("east", goal_world)
        else:
            return ("west", goal_world)

    # Priority queue: (f_score, g_score, position)
    open_set: list[tuple[int, int, tuple[int, int]]] = [(heuristic(start), 0, start)]
    # came_from maps a cell to the cell it was reached from; start has no entry.
    came_from: dict[tuple[int, int], tuple[int, int]] = {}
    g_score: dict[tuple[int, int], int] = {start: 0}

    directions = [(1, 0), (-1, 0), (0, 1), (0, -1)]  # south, north, east, west
    obs_bounds = (self._obs_hr, self._obs_wr)  # max distance from center

    while open_set:
        _, current_g, current = heapq.heappop(open_set)

        # Check if we reached adjacent to goal
        if is_adjacent_to_goal(current):
            # Reconstruct path and return first direction
            # Walk back to start, collecting every cell except start itself.
            path = []
            pos = current
            while pos in came_from:
                path.append(pos)
                pos = came_from[pos]
            path.reverse()

            if path:
                # The first path cell is one step from (0, 0), so its
                # offset is the move direction itself.
                first_step = path[0]
                dr, dc = first_step
                if dr == 1:
                    direction = "south"
                elif dr == -1:
                    direction = "north"
                elif dc == 1:
                    direction = "east"
                else:
                    direction = "west"

                if DEBUG and state.agent_id == 0:
                    print(f"[A{state.agent_id}] DIR: A* path to {goal}, first step → {direction}")
                return (direction, goal_world)
            # Empty path means current is the start cell; the adjacent-start
            # special case above already returns for that situation.
            return None

        for dr, dc in directions:
            neighbor = (current[0] + dr, current[1] + dc)

            # Check bounds (stay within observation window)
            if abs(neighbor[0]) > obs_bounds[0] or abs(neighbor[1]) > obs_bounds[1]:
                continue

            # Check if blocked (but goal cell is allowed as destination check happens above)
            if neighbor in blocked_cells and neighbor != goal:
                continue

            # Every move costs 1.
            tentative_g = current_g + 1
            if neighbor not in g_score or tentative_g < g_score[neighbor]:
                g_score[neighbor] = tentative_g
                f_score = tentative_g + heuristic(neighbor)
                heapq.heappush(open_set, (f_score, tentative_g, neighbor))
                came_from[neighbor] = current

    # No path found
    if DEBUG and state.agent_id == 0:
        print(f"[A{state.agent_id}] DIR: A* no path to {goal}")
    return None
|
|
783
|
+
|
|
784
|
+
# === Query methods ===
|
|
785
|
+
|
|
786
|
+
def find_nearest(
    self, state: AgentState, structure_type: StructureType, exclude: Optional[tuple[int, int]] = None
) -> Optional[StructureInfo]:
    """Return the known structure of ``structure_type`` closest to the agent.

    Distance is Manhattan distance from ``(state.row, state.col)`` to the
    structure's position. A structure located at ``exclude`` is ignored.
    When several structures are equally close, the one that comes first
    in ``state.map.structures`` wins. Returns None when nothing matches.
    """

    def is_candidate(info: StructureInfo) -> bool:
        if info.structure_type != structure_type:
            return False
        return not (exclude and info.position == exclude)

    def manhattan(info: StructureInfo) -> int:
        return abs(info.position[0] - state.row) + abs(info.position[1] - state.col)

    candidates = (info for info in state.map.structures.values() if is_candidate(info))
    # min() keeps the first of several equally-near candidates.
    return min(candidates, key=manhattan, default=None)
|
|
805
|
+
|
|
806
|
+
def distance_to(self, state: AgentState, pos: tuple[int, int]) -> int:
    """Return the Manhattan distance between the agent and ``pos`` (row, col)."""
    row_gap = abs(pos[0] - state.row)
    col_gap = abs(pos[1] - state.col)
    return row_gap + col_gap
|