synth-ai 0.2.4.dev4__py3-none-any.whl → 0.2.4.dev5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/debug_translation.py +0 -0
- synth_ai/environments/examples/crafter_classic/engine.py +575 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +266 -0
- synth_ai/environments/examples/crafter_classic/environment.py +364 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +233 -0
- synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +229 -0
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +298 -0
- synth_ai/environments/examples/crafter_custom/__init__.py +4 -0
- synth_ai/environments/examples/crafter_custom/crafter/__init__.py +7 -0
- synth_ai/environments/examples/crafter_custom/crafter/config.py +182 -0
- synth_ai/environments/examples/crafter_custom/crafter/constants.py +8 -0
- synth_ai/environments/examples/crafter_custom/crafter/engine.py +269 -0
- synth_ai/environments/examples/crafter_custom/crafter/env.py +266 -0
- synth_ai/environments/examples/crafter_custom/crafter/objects.py +418 -0
- synth_ai/environments/examples/crafter_custom/crafter/recorder.py +187 -0
- synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +119 -0
- synth_ai/environments/examples/crafter_custom/dataset_builder.py +373 -0
- synth_ai/environments/examples/crafter_custom/environment.py +312 -0
- synth_ai/environments/examples/crafter_custom/run_dataset.py +305 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/METADATA +1 -1
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/RECORD +104 -6
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ synth_ai/environments/examples/nethack/engine.py
@@ -0,0 +1,738 @@
+"""NetHack engine implementation with state management and NLE integration."""
+
+from __future__ import annotations
+
+import asyncio
+import base64
+from dataclasses import dataclass, field
+from typing import Dict, Any, Optional, Tuple, List, TYPE_CHECKING, cast
+import numpy as np
+import logging
+
+from synth_ai.environments.stateful.engine import StatefulEngine, StatefulEngineSnapshot
+from synth_ai.environments.reproducibility.core import IReproducibleEngine
+from synth_ai.environments.environment.rewards.core import RewardStack, RewardComponent
+from synth_ai.environments.environment.shared_engine import (
+    GetObservationCallable,
+    InternalObservation,
+)
+from synth_ai.environments.tasks.core import TaskInstance
+
+logger = logging.getLogger(__name__)
+
+# NLE imports are required
+try:
+    from .helpers.nle_wrapper import NLEWrapper
+    from .helpers.action_mapping import convert_action_to_nle
+    from .achievements import NetHackAchievements, calculate_balrog_reward
+except ImportError as e:
+    raise ImportError(
+        "NLE (NetHack Learning Environment) is required but not installed. "
+        "Please install it with: pip install nle"
+    ) from e
+
+if TYPE_CHECKING:
+    from .taskset import NetHackTaskInstanceMetadata
+
+
+@dataclass
+class NetHackPublicState:
+    """State visible to the agent."""
+
+    # Game state
+    dungeon_level: int = 1
+    character_stats: Dict[str, Any] = field(default_factory=dict)
+    inventory: List[Dict[str, Any]] = field(default_factory=list)
+    position: Tuple[int, int] = (0, 0)
+
+    # Observation data
+    ascii_map: str = ""
+    message: str = ""
+    cursor_position: Tuple[int, int] = (0, 0)
+
+    # Meta information
+    turn_count: int = 0
+    max_turns: int = 10000
+    last_action: str = ""
+    terminated: bool = False
+
+    # Game context
+    in_menu: bool = False
+    menu_items: List[str] = field(default_factory=list)
+
+    # Achievements tracking
+    achievements: NetHackAchievements = field(default_factory=NetHackAchievements)
+    achievements_unlocked: Dict[str, bool] = field(default_factory=dict)
+
+    def diff(self, prev_state: "NetHackPublicState") -> Dict[str, Any]:
+        """Track changes between states."""
+        differences = {}
+
+        if self.dungeon_level != prev_state.dungeon_level:
+            differences["dungeon_level"] = (
+                prev_state.dungeon_level,
+                self.dungeon_level,
+            )
+        if self.position != prev_state.position:
+            differences["position"] = (prev_state.position, self.position)
+        if self.message != prev_state.message:
+            differences["message"] = (prev_state.message, self.message)
+        if self.turn_count != prev_state.turn_count:
+            differences["turn_count"] = (prev_state.turn_count, self.turn_count)
+        if self.terminated != prev_state.terminated:
+            differences["terminated"] = (prev_state.terminated, self.terminated)
+        if self.last_action != prev_state.last_action:
+            differences["last_action"] = (prev_state.last_action, self.last_action)
+
+        return differences
+
+    @property
+    def map_text(self) -> str:
+        """Formatted ASCII dungeon map."""
+        return self.ascii_map
+
+
+@dataclass
+class NetHackPrivateState:
+    """Internal state (rewards, termination flags)."""
+
+    reward_last: float = 0.0
+    total_reward: float = 0.0
+    terminated: bool = False
+    truncated: bool = False
+
+    # Progress tracking
+    score: int = 0
+    depth_reached: int = 1
+    experience_level: int = 1
+    monsters_killed: int = 0
+    items_collected: int = 0
+
+    # Balrog reward tracking
+    balrog_reward_last: float = 0.0
+    balrog_total_reward: float = 0.0
+
+    def diff(self, prev_state: "NetHackPrivateState") -> Dict[str, Any]:
+        """Track reward/progress changes."""
+        differences = {}
+
+        if self.reward_last != prev_state.reward_last:
+            differences["reward_last"] = (prev_state.reward_last, self.reward_last)
+        if self.total_reward != prev_state.total_reward:
+            differences["total_reward"] = (prev_state.total_reward, self.total_reward)
+        if self.score != prev_state.score:
+            differences["score"] = (prev_state.score, self.score)
+        if self.depth_reached != prev_state.depth_reached:
+            differences["depth_reached"] = (
+                prev_state.depth_reached,
+                self.depth_reached,
+            )
+
+        return differences
+
+
+@dataclass
+class NetHackEngineSnapshot(StatefulEngineSnapshot):
+    """Serialization container for NetHack engine state."""
+
+    task_instance_dict: Dict[str, Any]
+    engine_snapshot: Dict[str, Any]
+    nle_state: Optional[Dict[str, Any]] = None  # NLE-specific state if available
+
+
+class NetHackSurvivalComponent(RewardComponent):
+    """Reward component for staying alive."""
+
+    async def score(self, state: NetHackPublicState, action: str) -> float:
+        if state.terminated:
+            return -1.0  # Penalty for death
+        return 0.01  # Small reward for each turn survived
+
+
+class NetHackProgressComponent(RewardComponent):
+    """Reward component for exploration and depth."""
+
+    def __init__(self):
+        self.last_depth = 1
+
+    async def score(self, state: NetHackPublicState, action: str) -> float:
+        reward = 0.0
+
+        # Reward for reaching new dungeon levels
+        if state.dungeon_level > self.last_depth:
+            reward += 1.0 * (state.dungeon_level - self.last_depth)
+            self.last_depth = state.dungeon_level
+
+        return reward
+
+
+class NetHackScoreComponent(RewardComponent):
+    """Reward component based on game score."""
+
+    def __init__(self):
+        self.last_score = 0
+
+    async def score(self, state: NetHackPublicState, action: str) -> float:
+        # Get score from character stats - require it exists
+        current_score = state.character_stats["score"]
+
+        # Calculate score delta
+        score_delta = current_score - self.last_score
+        self.last_score = current_score
+
+        # Scale the score reward (NLE scores can be large)
+        return score_delta / 100.0 if score_delta > 0 else 0.0
+
+
+class NetHackAchievementComponent(RewardComponent):
+    """Reward component for unlocking achievements."""
+
+    def __init__(self):
+        self.last_unlocked = set()
+
+    async def score(self, state: NetHackPublicState, action: str) -> float:
+        reward = 0.0
+
+        # Count newly unlocked achievements
+        current_unlocked = set(k for k, v in state.achievements_unlocked.items() if v)
+        new_achievements = current_unlocked - self.last_unlocked
+
+        # Give rewards for different achievement types
+        for achievement in new_achievements:
+            if "first_" in achievement:
+                reward += 1.0  # First-time achievements
+            elif "reached_dlvl_" in achievement:
+                reward += 2.0  # Depth achievements
+            elif "killed_" in achievement and "monsters" in achievement:
+                reward += 0.5  # Kill milestones
+            elif "collected_" in achievement and "gold" in achievement:
+                reward += 0.5  # Gold milestones
+            elif "reached_level_" in achievement:
+                reward += 1.5  # Experience level milestones
+            elif "minetown" in achievement or "castle" in achievement:
+                reward += 5.0  # Special locations
+            elif "quest" in achievement:
+                reward += 10.0  # Quest achievements
+            else:
+                reward += 0.5  # Default reward
+
+        self.last_unlocked = current_unlocked
+        return reward
+
+
+class NetHackEngine(StatefulEngine, IReproducibleEngine):
+    """NetHack game engine with NLE backend."""
+
+    def __init__(self, task_instance: TaskInstance):
+        self.task_instance = task_instance
+
+        # Require proper metadata
+        from .taskset import NetHackTaskInstanceMetadata
+
+        if not isinstance(task_instance.metadata, NetHackTaskInstanceMetadata):
+            raise TypeError(
+                f"Expected NetHackTaskInstanceMetadata, got {type(task_instance.metadata).__name__}"
+            )
+
+        metadata = cast(NetHackTaskInstanceMetadata, task_instance.metadata)
+        self.character_role = metadata.character_role
+        self.max_turns = metadata.time_limit
+
+        # Initialize NLE wrapper
+        self.nle = NLEWrapper(character_role=self.character_role)
+
+        # Initialize reward components with proper tracking - NO SURVIVAL NOISE
+        self.progress_component = NetHackProgressComponent()
+        self.score_component = NetHackScoreComponent()
+        self.achievement_component = NetHackAchievementComponent()
+
+        self.reward_stack = RewardStack(
+            [
+                self.progress_component,  # Depth progress
+                self.score_component,  # Game score changes
+                self.achievement_component,  # Achievement unlocks
+            ]
+        )
+
+        # State tracking
+        self.public_state: Optional[NetHackPublicState] = None
+        self.private_state: Optional[NetHackPrivateState] = None
+
+        # NLE observation processing
+        self.last_nle_obs = None
+
+    async def _reset_engine(
+        self, *, seed: int | None = None
+    ) -> Tuple[NetHackPrivateState, NetHackPublicState]:
+        """Reset to initial state using NLE."""
+        # Reset NLE environment with seed
+        obs = await asyncio.to_thread(self.nle.reset, seed)
+        self.last_nle_obs = obs
+
+        # Log what we actually got from NLE
+        logger.info(f"NLE reset returned observation keys: {list(obs.keys())}")
+        if "player_stats" in obs:
+            logger.info(f"Player stats keys: {list(obs['player_stats'].keys())}")
+
+        # Initialize private state - require all fields
+        player_stats = obs["player_stats"]  # Will KeyError if missing
+        self.private_state = NetHackPrivateState(
+            reward_last=0.0,
+            total_reward=0.0,
+            terminated=False,
+            truncated=False,
+            score=player_stats["score"],
+            depth_reached=player_stats["depth"],
+            experience_level=player_stats["experience_level"],
+            monsters_killed=0,
+            items_collected=0,
+            balrog_reward_last=0.0,
+            balrog_total_reward=0.0,
+        )
+
+        # Initialize public state from NLE observation - no fallbacks
+        self.public_state = NetHackPublicState(
+            dungeon_level=player_stats["depth"],
+            character_stats={
+                "hp": player_stats["hp"],
+                "max_hp": player_stats["max_hp"],
+                "strength": player_stats["strength"],
+                "dexterity": player_stats["dexterity"],
+                "constitution": player_stats["constitution"],
+                "intelligence": player_stats["intelligence"],
+                "wisdom": player_stats["wisdom"],
+                "charisma": player_stats["charisma"],
+                "gold": player_stats["gold"],
+                "experience": player_stats["experience_points"],
+                "level": player_stats["experience_level"],
+                "ac": player_stats["ac"],
+            },
+            inventory=self._process_inventory(obs["inventory"]) if "inventory" in obs else [],
+            position=(player_stats["y"], player_stats["x"]),
+            ascii_map=obs["ascii_map"],
+            message=obs["message"],
+            cursor_position=obs.get(
+                "cursor", (player_stats["y"], player_stats["x"])
+            ),  # Cursor might not be in processed obs
+            turn_count=0,
+            max_turns=self.max_turns,
+            last_action="",
+            terminated=False,
+            in_menu=obs.get("in_menu", False),  # Menu detection is heuristic-based
+            menu_items=obs.get("menu_text", []),  # Menu text only present when in menu
+            achievements=NetHackAchievements(),
+            achievements_unlocked={},
+        )
+
+        # Reset reward components
+        self.progress_component.last_depth = self.public_state.dungeon_level
+        self.score_component.last_score = self.private_state.score
+
+        return self.private_state, self.public_state
+
+    def _process_inventory(self, inventory_items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        """Process NLE inventory format to our format."""
+        processed_items = []
+        for item in inventory_items:
+            processed_items.append(
+                {
+                    "name": item["description"],
+                    "count": 1,  # NLE doesn't always provide count
+                    "letter": item["letter"],
+                }
+            )
+        return processed_items
+
+    async def _step_engine(self, action: str) -> Tuple[NetHackPrivateState, NetHackPublicState]:
+        """Execute one step/action using NLE."""
+        # print(f"===== NetHack Engine _step_engine called with action: {action} =====")
+        if self.public_state is None or self.private_state is None:
+            raise RuntimeError("Engine not initialized. Call _reset_engine first.")
+
+        # Validate action
+        if action not in self.nle.action_map and action not in ["terminate"]:
+            # Try to handle menu selections and special cases
+            if len(action) == 1 and (action.isalpha() or action.isdigit()):
+                # Single character actions are likely menu selections
+                pass
+            else:
+                raise ValueError(
+                    f"Invalid action: {action}. Valid actions: {list(self.nle.action_map.keys())}"
+                )
+
+        # Update turn count
+        self.public_state.turn_count += 1
+        self.public_state.last_action = action
+
+        # Define non-turn-consuming actions
+        non_turn_actions = [
+            "look",
+            "farlook",
+            "whatis",
+            "identify",
+            "discoveries",
+            "conduct",
+            "attributes",
+            "help",
+            "version",
+            "history",
+        ]
+
+        # Warn about non-advancing actions
+        if action in non_turn_actions:
+            logger.warning(f"Action '{action}' is a free action that doesn't advance game time!")
+            # If we're repeatedly using non-advancing actions, force a wait
+            if hasattr(self, "_consecutive_free_actions"):
+                self._consecutive_free_actions += 1
+                if self._consecutive_free_actions >= 3:
+                    logger.warning(
+                        f"Too many consecutive free actions ({self._consecutive_free_actions}), forcing 'wait'"
+                    )
+                    action = "wait"
+                    self._consecutive_free_actions = 0
+            else:
+                self._consecutive_free_actions = 1
+        else:
+            self._consecutive_free_actions = 0
+
+        # Check for manual termination
+        if action == "terminate":
+            self.public_state.terminated = True
+            self.private_state.terminated = True
+            self.public_state.message = "Game terminated by agent."
+            return self.private_state, self.public_state
+
+        # Check for timeout
+        if self.public_state.turn_count >= self.public_state.max_turns:
+            self.public_state.terminated = True
+            self.private_state.terminated = True
+            self.private_state.truncated = True
+            self.public_state.message = "Time limit reached. Game over!"
+            return self.private_state, self.public_state
+
+        # Execute action in NLE
+        try:
+            # Save previous observation BEFORE stepping
+            prev_obs = self.last_nle_obs
+
+            obs, reward, done, info = await asyncio.to_thread(self.nle.step, action)
+            logger.debug(f"NLE step returned - reward: {reward}, done: {done}, info: {info}")
+        except Exception as e:
+            logger.error(f"NLE step failed for action '{action}': {e}")
+            raise
+
+        # Log observation structure on first few steps for debugging
+        if self.public_state.turn_count < 3:
+            logger.info(f"Turn {self.public_state.turn_count} observation keys: {list(obs.keys())}")
+
+        # Update state from NLE observation - no defensive coding
+        player_stats = obs["player_stats"]  # Will KeyError if missing
+
+        # Track previous values for reward calculation
+        prev_score = self.private_state.score
+        prev_depth = self.private_state.depth_reached
+
+        # Update private state
+        self.private_state.score = player_stats["score"]
+        self.private_state.depth_reached = max(
+            self.private_state.depth_reached, player_stats["depth"]
+        )
+        self.private_state.experience_level = player_stats["experience_level"]
+
+        # Update public state
+        self.public_state.dungeon_level = player_stats["depth"]
+        self.public_state.position = (player_stats["y"], player_stats["x"])
+        self.public_state.ascii_map = obs["ascii_map"]
+        self.public_state.message = obs["message"]
+        self.public_state.cursor_position = obs.get(
+            "cursor", (player_stats["y"], player_stats["x"])
+        )
+        self.public_state.in_menu = obs.get("in_menu", False)
+        self.public_state.menu_items = obs.get("menu_text", [])
+
+        # Update character stats - require all fields
+        self.public_state.character_stats = {
+            "hp": player_stats["hp"],
+            "max_hp": player_stats["max_hp"],
+            "strength": player_stats["strength"],
+            "dexterity": player_stats["dexterity"],
+            "constitution": player_stats["constitution"],
+            "intelligence": player_stats["intelligence"],
+            "wisdom": player_stats["wisdom"],
+            "charisma": player_stats["charisma"],
+            "gold": player_stats["gold"],
+            "experience": player_stats["experience_points"],
+            "level": player_stats["experience_level"],
+            "ac": player_stats["ac"],
+            "score": player_stats["score"],
+        }
+
+        # Update inventory
+        self.public_state.inventory = (
+            self._process_inventory(obs["inventory"]) if "inventory" in obs else []
+        )
+
+        # Handle termination from NLE
+        if done:
+            self.public_state.terminated = True
+            self.private_state.terminated = True
+            # Log info to understand structure
+            logger.info(f"Game ended - info: {info}")
+            if "end_status" in info and info["end_status"] == 0:  # 0 means death
+                self.public_state.message = info.get(
+                    "death_reason", "You died!"
+                )  # death_reason might not always exist
+            else:
+                self.public_state.message = "Game ended."
+
+        # Update achievements before calculating rewards
+        newly_unlocked = self.public_state.achievements.update_from_observation(obs, prev_obs)
+        self.public_state.achievements_unlocked.update(
+            self.public_state.achievements.get_unlocked_achievements()
+        )
+
+        # Log newly unlocked achievements
+        if newly_unlocked:
+            logger.info(f"Achievements unlocked: {list(newly_unlocked.keys())}")
+
+        # Calculate rewards
+        # Base reward from NLE
+        nle_reward = reward
+
+        # Additional reward shaping
+        step_reward = await self.reward_stack.step_reward(self.public_state, action)
+
+        self.private_state.reward_last = nle_reward + step_reward
+        self.private_state.total_reward += self.private_state.reward_last
+
+        # Calculate Balrog-style reward
+        self.private_state.balrog_reward_last = calculate_balrog_reward(obs, prev_obs)
+        self.private_state.balrog_total_reward += self.private_state.balrog_reward_last
+
+        # Log balrog reward changes with context
+        if self.private_state.balrog_reward_last > 0:
+            print(
+                f"🏆 BALROG REWARD: +{self.private_state.balrog_reward_last:.3f} (total: {self.private_state.balrog_total_reward:.3f})"
+            )
+            balrog_score = self.public_state.achievements.balrog_progress.percent
+            print(
+                f" Balrog score: {balrog_score}% (dungeon: {self.public_state.achievements.balrog_progress.dungeon_progression}, exp: {self.public_state.achievements.balrog_progress.experience_progression})"
+            )
+
+        # NOW update last_nle_obs for next step
+        self.last_nle_obs = obs
+
+        return self.private_state, self.public_state
+
+    def __del__(self):
+        """Cleanup NLE environment on deletion."""
+        if hasattr(self, "nle"):
+            self.nle.close()
+
+    async def _serialize_engine(self) -> NetHackEngineSnapshot:
+        """Serialize current state."""
+        if self.public_state is None or self.private_state is None:
+            raise RuntimeError("Cannot serialize uninitialized engine")
+
+        # Get NLE state
+        nle_state = None
+        try:
+            nle_state_bytes = await asyncio.to_thread(self.nle.get_state)
+            # Convert bytes to base64 string for JSON serialization
+            nle_state = base64.b64encode(nle_state_bytes).decode("ascii")
+        except Exception as e:
+            logger.warning(f"Failed to serialize NLE state: {e}")
+
+        task_dict = await self.task_instance.serialize()
+        logger.debug(f"Serialized task instance: {task_dict}")
+
+        return NetHackEngineSnapshot(
+            task_instance_dict=task_dict,
+            engine_snapshot={
+                "public_state": {
+                    "dungeon_level": self.public_state.dungeon_level,
+                    "character_stats": self.public_state.character_stats,
+                    "inventory": self.public_state.inventory,
+                    "position": self.public_state.position,
+                    "ascii_map": self.public_state.ascii_map,
+                    "message": self.public_state.message,
+                    "cursor_position": self.public_state.cursor_position,
+                    "turn_count": self.public_state.turn_count,
+                    "max_turns": self.public_state.max_turns,
+                    "last_action": self.public_state.last_action,
+                    "terminated": self.public_state.terminated,
+                    "in_menu": self.public_state.in_menu,
+                    "menu_items": self.public_state.menu_items,
+                },
+                "private_state": {
+                    "reward_last": self.private_state.reward_last,
+                    "total_reward": self.private_state.total_reward,
+                    "terminated": self.private_state.terminated,
+                    "truncated": self.private_state.truncated,
+                    "score": self.private_state.score,
+                    "depth_reached": self.private_state.depth_reached,
+                    "experience_level": self.private_state.experience_level,
+                    "monsters_killed": self.private_state.monsters_killed,
+                    "items_collected": self.private_state.items_collected,
+                },
+                "character_role": self.character_role,
+                "progress_last_depth": self.progress_component.last_depth,
+                "score_last_score": self.score_component.last_score,
+            },
+            nle_state=nle_state,
+        )
+
+    @classmethod
+    async def _deserialize_engine(cls, snapshot: NetHackEngineSnapshot) -> "NetHackEngine":
+        """Restore from serialized state."""
+        from .taskset import NetHackTaskInstance
+
+        task_instance = await NetHackTaskInstance.deserialize(snapshot.task_instance_dict)
+        if task_instance is None:
+            raise ValueError("Failed to deserialize task instance")
+        engine = cls(task_instance)
+
+        # Restore state
+        engine_data = snapshot.engine_snapshot
+        pub_data = engine_data["public_state"]
+        priv_data = engine_data["private_state"]
+
+        engine.public_state = NetHackPublicState(
+            dungeon_level=pub_data["dungeon_level"],
+            character_stats=pub_data["character_stats"],
+            inventory=pub_data["inventory"],
+            position=(pub_data["position"][0], pub_data["position"][1]),
+            ascii_map=pub_data["ascii_map"],
+            message=pub_data["message"],
+            cursor_position=(
+                pub_data["cursor_position"][0],
+                pub_data["cursor_position"][1],
+            ),
+            turn_count=pub_data["turn_count"],
+            max_turns=pub_data["max_turns"],
+            last_action=pub_data["last_action"],
+            terminated=pub_data["terminated"],
+            in_menu=pub_data["in_menu"],
+            menu_items=pub_data["menu_items"],
+        )
+
+        engine.private_state = NetHackPrivateState(
+            reward_last=priv_data["reward_last"],
+            total_reward=priv_data["total_reward"],
+            terminated=priv_data["terminated"],
+            truncated=priv_data["truncated"],
+            score=priv_data["score"],
+            depth_reached=priv_data["depth_reached"],
+            experience_level=priv_data["experience_level"],
+            monsters_killed=priv_data["monsters_killed"],
+            items_collected=priv_data["items_collected"],
+        )
+
+        engine.character_role = engine_data["character_role"]
+
+        # Restore reward component states
+        engine.progress_component.last_depth = engine_data["progress_last_depth"]
+        engine.score_component.last_score = engine_data["score_last_score"]
+
+        # Restore NLE state if available
+        if snapshot.nle_state:
+            try:
+                nle_state_bytes = base64.b64decode(snapshot.nle_state)
+                await asyncio.to_thread(engine.nle.set_state, nle_state_bytes)
+            except Exception as e:
+                logger.warning(f"Failed to restore NLE state: {e}")
+                # If we can't restore NLE state, reset it
+                await asyncio.to_thread(engine.nle.reset)
+
+        return engine
+
+    def get_current_states_for_observation(
+        self,
+    ) -> Tuple[NetHackPrivateState, NetHackPublicState]:
+        """Get current states without advancing."""
+        if self.public_state is None or self.private_state is None:
+            raise RuntimeError("Engine not initialized")
+        return self.private_state, self.public_state
+
+
+class NetHackObservationCallable(GetObservationCallable):
+    """Standard observation callable for NetHack."""
+
+    async def get_observation(
+        self, pub: NetHackPublicState, priv: NetHackPrivateState
+    ) -> InternalObservation:
+        observation = {
+            "ascii_map": pub.ascii_map,
+            "message": pub.message,
+            "character_stats": pub.character_stats,
+            "inventory_summary": self._format_inventory(pub.inventory),
+            "dungeon_level": pub.dungeon_level,
+            "position": pub.position,
+            "turn_count": pub.turn_count,
+            "last_action": pub.last_action,
+            "reward_last": priv.reward_last,
+            "total_reward": priv.total_reward,
+            "balrog_reward_last": priv.balrog_reward_last,
+            "balrog_total_reward": priv.balrog_total_reward,
+            "score": priv.score,
+            "experience_level": priv.experience_level,
+            "terminated": priv.terminated,
+            "in_menu": pub.in_menu,
+            "menu_items": pub.menu_items if pub.in_menu else [],
+            "achievements_unlocked": pub.achievements_unlocked,
+            "achievements_summary": self._format_achievements(pub.achievements_unlocked),
+        }
+        return observation  # type: ignore[return-value]
+
+    def _format_inventory(self, inventory: List[Dict[str, Any]]) -> str:
+        """Format inventory for display."""
+        if not inventory:
+            return "Your inventory is empty."
+
+        items = []
+        for item in inventory:
+            items.append(f"- {item['name']} (count: {item.get('count', 1)})")
+        return "\n".join(items)
+
+    def _format_achievements(self, achievements: Dict[str, bool]) -> str:
+        """Format achievements for display."""
+        unlocked = [name for name, status in achievements.items() if status]
+        if not unlocked:
+            return "None unlocked yet"
+        if len(unlocked) <= 5:
+            return ", ".join(unlocked)
+        else:
+            return f"{', '.join(unlocked[:5])} and {len(unlocked) - 5} more"
+
+
+class NetHackCheckpointObservationCallable(GetObservationCallable):
+    """Checkpoint observation callable for NetHack."""
+
+    async def get_observation(
+        self, pub: NetHackPublicState, priv: NetHackPrivateState
+    ) -> InternalObservation:
+        observation = {
+            "final_score": priv.score,
+            "max_depth": priv.depth_reached,
+            "experience_level": priv.experience_level,
+            "monsters_killed": priv.monsters_killed,
+            "items_collected": priv.items_collected,
+            "turn_count_final": pub.turn_count,
+            "total_reward": priv.total_reward,
+            "balrog_total_reward": priv.balrog_total_reward,
+            "terminated": priv.terminated,
+            "truncated": priv.truncated,
+            "character_role": pub.character_stats.get("role", "unknown"),
+            "achievements_unlocked": list(pub.achievements_unlocked.keys()),
+            "achievements_count": len([v for v in pub.achievements_unlocked.values() if v]),
+            "achievement_stats": {
+                "depth_reached": pub.achievements.depth_reached,
+                "monsters_killed": pub.achievements.monsters_killed,
+                "gold_collected": pub.achievements.gold_collected,
+                "items_collected": pub.achievements.items_picked_up,
+                "max_level": pub.achievements.max_level_reached,
+                "turns_survived": pub.achievements.turns_survived,
+                "balrog_score": pub.achievements.balrog_progress.percent,
+            },
+        }
+        return observation  # type: ignore[return-value]
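
For orientation, the sketch below shows how the pieces added in this file are meant to fit together: a NetHackEngine is built from a task instance carrying NetHackTaskInstanceMetadata, reset to obtain the (private, public) state pair, stepped with a named action, and rendered for an agent through NetHackObservationCallable. This is a minimal sketch based only on the code above; constructing the task instance is handled by the taskset.py module added in the same release and is not shown in this diff, so make_task_instance() is a hypothetical placeholder.

import asyncio

from synth_ai.environments.examples.nethack.engine import (
    NetHackEngine,
    NetHackObservationCallable,
)


async def demo(task_instance) -> None:
    # task_instance.metadata must be a NetHackTaskInstanceMetadata
    # (character_role, time_limit); building one lives in taskset.py.
    engine = NetHackEngine(task_instance)

    # Reset forwards the seed to NLE and returns (private_state, public_state).
    priv, pub = await engine._reset_engine(seed=42)

    # "wait" is an action the engine itself falls back to, so it is a safe example action.
    priv, pub = await engine._step_engine("wait")

    # Build the agent-facing observation dict from the state pair.
    obs = await NetHackObservationCallable().get_observation(pub, priv)
    print(obs["message"], obs["total_reward"])


# asyncio.run(demo(make_task_instance()))  # make_task_instance() is hypothetical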