synth-ai 0.2.0__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/__init__.py +28 -2
- synth_ai/core/system.py +4 -0
- synth_ai/environments/__init__.py +35 -0
- synth_ai/environments/environment/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/base.py +50 -0
- synth_ai/environments/environment/core.py +22 -0
- synth_ai/environments/environment/db/__init__.py +1 -0
- synth_ai/environments/environment/db/sqlite.py +45 -0
- synth_ai/environments/environment/registry.py +24 -0
- synth_ai/environments/environment/resources/sqlite.py +46 -0
- synth_ai/environments/environment/results.py +1 -0
- synth_ai/environments/environment/rewards/__init__.py +1 -0
- synth_ai/environments/environment/rewards/core.py +28 -0
- synth_ai/environments/environment/shared_engine.py +26 -0
- synth_ai/environments/environment/tools/__init__.py +34 -0
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +58 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +51 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +872 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/test_crafter_react_agent.py +1110 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/engine.py +502 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/environment.py +255 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +228 -0
- synth_ai/environments/examples/enron/agent_demos/test_synth_react.py +535 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/enron/units/keyword_stats.py +111 -0
- synth_ai/environments/examples/enron/units/test_email_index.py +8 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +47 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +220 -0
- synth_ai/environments/examples/minigrid/agent_demos/test_minigrid_react_agent.py +393 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/minigrid/units/test_action_behavior.py +226 -0
- synth_ai/environments/examples/minigrid/units/test_debug_messages.py +83 -0
- synth_ai/environments/examples/minigrid/units/test_exploration.py +120 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_engine.py +214 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment.py +238 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment_mapping.py +301 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_taskset.py +210 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +832 -0
- synth_ai/environments/examples/nethack/agent_demos/test_nethack_react_agent.py +1112 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/nethack/units/test_nethack_engine.py +277 -0
- synth_ai/environments/examples/nethack/units/test_nethack_environment.py +281 -0
- synth_ai/environments/examples/nethack/units/test_nethack_taskset.py +213 -0
- synth_ai/environments/examples/nethack/units/test_recording.py +307 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/red/agent_demos/test_synth_react.py +1471 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/red/test_fixes.py +125 -0
- synth_ai/environments/examples/red/test_fixes_mock.py +148 -0
- synth_ai/environments/examples/red/units/__init__.py +1 -0
- synth_ai/environments/examples/red/units/test_basic_functionality.py +97 -0
- synth_ai/environments/examples/red/units/test_button_press_requirements.py +217 -0
- synth_ai/environments/examples/red/units/test_engine.py +192 -0
- synth_ai/environments/examples/red/units/test_environment.py +455 -0
- synth_ai/environments/examples/red/units/test_exploration_strategy.py +227 -0
- synth_ai/environments/examples/red/units/test_integration.py +217 -0
- synth_ai/environments/examples/red/units/test_memory_extraction.py +111 -0
- synth_ai/environments/examples/red/units/test_menu_bug_reproduction.py +1100 -0
- synth_ai/environments/examples/red/units/test_movement_debug.py +255 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_debug.py +163 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_verbose.py +117 -0
- synth_ai/environments/examples/red/units/test_red_basic.py +145 -0
- synth_ai/environments/examples/red/units/test_red_comprehensive.py +323 -0
- synth_ai/environments/examples/red/units/test_retry_movement.py +195 -0
- synth_ai/environments/examples/red/units/test_reward_components.py +186 -0
- synth_ai/environments/examples/red/units/test_rom_integration.py +260 -0
- synth_ai/environments/examples/red/units/test_taskset.py +116 -0
- synth_ai/environments/examples/red/units/test_tree.py +448 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +900 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_dspy_react.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_sokoban_react_agent.py +498 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_lats.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_locally.py +748 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_service.py +296 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/sokoban/units/astar_common.py +94 -0
- synth_ai/environments/examples/sokoban/units/test_building_task_set.py +49 -0
- synth_ai/environments/examples/sokoban/units/test_false_positive.py +120 -0
- synth_ai/environments/examples/sokoban/units/test_simple_run_through_environment.py +119 -0
- synth_ai/environments/examples/sokoban/units/test_sokoban_environment.py +98 -0
- synth_ai/environments/examples/sokoban/units/test_tree.py +364 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_synth_react.py +266 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_tictactoe_react_agent.py +470 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_engine.py +393 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_environment.py +493 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_taskset.py +191 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/agent_demos/test_synth_react.py +520 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- synth_ai/environments/examples/verilog/units/test_verilog_engine.py +466 -0
- synth_ai/environments/examples/verilog/units/test_verilog_environment.py +585 -0
- synth_ai/environments/examples/verilog/units/test_verilog_integration.py +383 -0
- synth_ai/environments/examples/verilog/units/test_verilog_taskset.py +457 -0
- synth_ai/environments/reproducibility/core.py +42 -0
- synth_ai/environments/reproducibility/tree.py +364 -0
- synth_ai/environments/service/app.py +78 -0
- synth_ai/environments/service/core_routes.py +775 -0
- synth_ai/environments/service/external_registry.py +57 -0
- synth_ai/environments/service/registry.py +9 -0
- synth_ai/environments/stateful/__init__.py +1 -0
- synth_ai/environments/stateful/core.py +28 -0
- synth_ai/environments/stateful/engine.py +21 -0
- synth_ai/environments/stateful/state.py +7 -0
- synth_ai/environments/tasks/api.py +19 -0
- synth_ai/environments/tasks/core.py +78 -0
- synth_ai/environments/tasks/filters.py +39 -0
- synth_ai/environments/tasks/utils.py +89 -0
- synth_ai/environments/v0_observability/history.py +3 -0
- synth_ai/environments/v0_observability/log.py +2 -0
- synth_ai/lm/caching/constants.py +1 -0
- synth_ai/{zyk/lms → lm}/caching/ephemeral.py +4 -8
- synth_ai/{zyk/lms → lm}/caching/handler.py +15 -15
- synth_ai/{zyk/lms → lm}/caching/initialize.py +2 -4
- synth_ai/{zyk/lms → lm}/caching/persistent.py +4 -10
- synth_ai/{zyk/lms → lm}/config.py +2 -1
- synth_ai/{zyk/lms → lm}/constants.py +2 -2
- synth_ai/{zyk/lms → lm}/core/all.py +10 -10
- synth_ai/{zyk/lms → lm}/core/main.py +57 -33
- synth_ai/{zyk/lms → lm}/core/vendor_clients.py +12 -10
- synth_ai/lm/cost/monitor.py +1 -0
- synth_ai/lm/cost/statefulness.py +1 -0
- synth_ai/lm/provider_support/__init__.py +8 -0
- synth_ai/lm/provider_support/anthropic.py +945 -0
- synth_ai/lm/provider_support/openai.py +1115 -0
- synth_ai/lm/provider_support/suppress_logging.py +31 -0
- synth_ai/{zyk/lms → lm}/structured_outputs/handler.py +58 -80
- synth_ai/{zyk/lms → lm}/structured_outputs/inject.py +6 -20
- synth_ai/{zyk/lms → lm}/structured_outputs/rehabilitate.py +6 -12
- synth_ai/{zyk/lms → lm}/vendors/core/anthropic_api.py +21 -30
- synth_ai/{zyk/lms → lm}/vendors/core/gemini_api.py +35 -32
- synth_ai/{zyk/lms → lm}/vendors/core/mistral_api.py +19 -28
- synth_ai/{zyk/lms → lm}/vendors/core/openai_api.py +26 -36
- synth_ai/{zyk/lms → lm}/vendors/openai_standard.py +29 -33
- synth_ai/{zyk/lms → lm}/vendors/retries.py +1 -1
- synth_ai/lm/vendors/supported/__init__.py +0 -0
- synth_ai/{zyk/lms → lm}/vendors/supported/custom_endpoint.py +131 -118
- synth_ai/{zyk/lms → lm}/vendors/supported/deepseek.py +4 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/grok.py +6 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/groq.py +1 -1
- synth_ai/{zyk/lms → lm}/vendors/supported/ollama.py +2 -2
- synth_ai/{zyk/lms → lm}/vendors/supported/openrouter.py +18 -16
- synth_ai/{zyk/lms → lm}/vendors/supported/together.py +1 -1
- synth_ai/tracing/__init__.py +0 -0
- synth_ai/tracing/abstractions.py +224 -0
- synth_ai/tracing/base_client.py +91 -0
- synth_ai/tracing/client_manager.py +131 -0
- synth_ai/tracing/config.py +140 -0
- synth_ai/tracing/context.py +146 -0
- synth_ai/tracing/decorators.py +679 -0
- synth_ai/tracing/events/__init__.py +0 -0
- synth_ai/tracing/events/manage.py +147 -0
- synth_ai/tracing/events/scope.py +86 -0
- synth_ai/tracing/events/store.py +227 -0
- synth_ai/tracing/immediate_client.py +152 -0
- synth_ai/tracing/local.py +18 -0
- synth_ai/tracing/log_client_base.py +74 -0
- synth_ai/tracing/retry_queue.py +187 -0
- synth_ai/tracing/trackers.py +515 -0
- synth_ai/tracing/upload.py +504 -0
- synth_ai/tracing/utils.py +9 -0
- synth_ai/zyk/__init__.py +28 -2
- synth_ai-0.2.1.dev0.dist-info/METADATA +349 -0
- synth_ai-0.2.1.dev0.dist-info/RECORD +261 -0
- {synth_ai-0.2.0.dist-info → synth_ai-0.2.1.dev0.dist-info}/WHEEL +1 -1
- synth_ai/zyk/lms/caching/constants.py +0 -1
- synth_ai/zyk/lms/cost/monitor.py +0 -1
- synth_ai/zyk/lms/cost/statefulness.py +0 -1
- synth_ai-0.2.0.dist-info/METADATA +0 -36
- synth_ai-0.2.0.dist-info/RECORD +0 -50
- /synth_ai/{zyk/lms/__init__.py → environments/reproducibility/helpers.py} +0 -0
- /synth_ai/{zyk/lms/caching → lm}/__init__.py +0 -0
- /synth_ai/{zyk/lms/core → lm/caching}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/caching/dbs.py +0 -0
- /synth_ai/{zyk/lms/cost → lm/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/core/exceptions.py +0 -0
- /synth_ai/{zyk/lms/structured_outputs → lm/cost}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors → lm/structured_outputs}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/core → lm/vendors}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/local → lm/vendors/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors/supported → lm/vendors/local}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/local/ollama.py +0 -0
- {synth_ai-0.2.0.dist-info → synth_ai-0.2.1.dev0.dist-info/licenses}/LICENSE +0 -0
- {synth_ai-0.2.0.dist-info → synth_ai-0.2.1.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,323 @@
|
|
1
|
+
"""TaskSet generation for NetHack environment."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
import random
|
6
|
+
from uuid import uuid4
|
7
|
+
from dataclasses import dataclass
|
8
|
+
from typing import Dict, Any, List, Optional, Set
|
9
|
+
|
10
|
+
from synth_ai.environments.tasks.core import (
|
11
|
+
TaskInstance,
|
12
|
+
TaskInstanceMetadata,
|
13
|
+
TaskInstanceSet,
|
14
|
+
Impetus,
|
15
|
+
Intent,
|
16
|
+
SplitInfo,
|
17
|
+
)
|
18
|
+
|
19
|
+
|
20
|
+
@dataclass
class NetHackTaskInstanceMetadata(TaskInstanceMetadata):
    """Task-specific metadata describing one NetHack scenario.

    All fields except ``seed`` are required and are positional in the
    generated ``__init__`` in the order declared below.
    """

    character_role: str  # Role name, e.g. "wizard", "knight" (a CHARACTER_ROLES key)
    starting_level: int  # Dungeon level to start on
    target_depth: int  # Goal depth to reach
    time_limit: int  # Maximum turns
    # Difficulty tier label; the generator in this module uses "tutorial",
    # "beginner", "intermediate", "advanced" and "expert".
    difficulty: str
    special_objectives: List[str]  # Additional goals beyond survival
    seed: Optional[int] = None  # Random seed for reproducibility
|
31
|
+
|
32
|
+
|
33
|
+
@dataclass
class NetHackTaskInstance(TaskInstance):
    """NetHack task instance with dict (de)serialization helpers.

    ``serialize`` flattens the instance into plain Python containers and
    ``deserialize`` rebuilds an instance from such a mapping.  Gold
    trajectories and the initial engine snapshot are not persisted: both are
    written as ``None`` and restored as ``None``.
    """

    async def serialize(self) -> dict:
        """Convert the instance to a plain ``dict``.

        Returns:
            A mapping with the keys ``id`` (stringified), ``impetus``,
            ``intent``, ``metadata``, ``is_reproducible`` and
            ``initial_engine_snapshot`` (always ``None``).
        """
        meta = self.metadata
        return {
            "id": str(self.id),
            "impetus": {"instructions": self.impetus.instructions},
            "intent": {
                "rubric": self.intent.rubric,
                "gold_trajectories": None,
                "gold_state_diff": self.intent.gold_state_diff,
            },
            "metadata": {
                "character_role": meta.character_role,
                "starting_level": meta.starting_level,
                "target_depth": meta.target_depth,
                "time_limit": meta.time_limit,
                "difficulty": meta.difficulty,
                "special_objectives": meta.special_objectives,
                "seed": meta.seed,
            },
            "is_reproducible": self.is_reproducible,
            "initial_engine_snapshot": None,
        }

    @classmethod
    async def deserialize(cls, data: dict) -> "NetHackTaskInstance":
        """Restore an instance from the mapping produced by ``serialize``.

        Args:
            data: Serialized task instance.  ``impetus``, ``intent`` and the
                required ``metadata`` entries must be present; ``id``,
                ``metadata["seed"]`` and ``is_reproducible`` are optional.

        Returns:
            The reconstructed ``NetHackTaskInstance``.

        Raises:
            KeyError: If a required entry is missing from ``data``.
        """
        # Local import: the module header only brings ``uuid4`` into scope.
        from uuid import UUID

        raw_id = data.get("id")
        if not raw_id:
            # Missing or empty id: mint a fresh one.
            instance_id = uuid4()
        else:
            # ``serialize`` stores ``str(self.id)``; parse it back so a round
            # trip yields a UUID again instead of silently becoming a str.
            try:
                instance_id = UUID(str(raw_id))
            except ValueError:
                # Not UUID-shaped: keep the caller's identifier untouched.
                instance_id = raw_id

        intent_data = data["intent"]
        meta_data = data["metadata"]
        return cls(
            id=instance_id,
            impetus=Impetus(instructions=data["impetus"]["instructions"]),
            intent=Intent(
                rubric=intent_data["rubric"],
                gold_trajectories=None,
                gold_state_diff=intent_data["gold_state_diff"],
            ),
            metadata=NetHackTaskInstanceMetadata(
                character_role=meta_data["character_role"],
                starting_level=meta_data["starting_level"],
                target_depth=meta_data["target_depth"],
                time_limit=meta_data["time_limit"],
                difficulty=meta_data["difficulty"],
                special_objectives=meta_data["special_objectives"],
                seed=meta_data.get("seed"),
            ),
            is_reproducible=data.get("is_reproducible", True),
            initial_engine_snapshot=None,
        )
|
83
|
+
|
84
|
+
|
85
|
+
# Character role definitions, keyed by role name.
#
# Each entry carries a prose description, a numeric difficulty modifier and
# lists of starting items, strengths and weaknesses.  Within the visible part
# of this module only "description", "strengths" and "weaknesses" are read
# (when the task instructions are rendered); "difficulty_modifier" and
# "starting_items" are not consumed here.
CHARACTER_ROLES: Dict[str, Dict[str, Any]] = {
    "tourist": {
        "description": "A tourist with a camera and Hawaiian shirt",
        "difficulty_modifier": 0.8,  # Easier
        "starting_items": ["camera", "credit card", "hawaiian shirt"],
        "strengths": ["gold finding", "luck"],
        "weaknesses": ["combat", "magic"],
    },
    "knight": {
        "description": "A noble knight in shining armor",
        "difficulty_modifier": 1.0,
        "starting_items": ["long sword", "armor", "shield"],
        "strengths": ["combat", "riding"],
        "weaknesses": ["magic"],
    },
    "wizard": {
        "description": "A powerful wizard with magical abilities",
        "difficulty_modifier": 1.2,
        "starting_items": ["quarterstaff", "spellbook", "cloak"],
        "strengths": ["magic", "identify"],
        "weaknesses": ["physical combat", "low hp"],
    },
    "barbarian": {
        "description": "A fierce barbarian warrior",
        "difficulty_modifier": 0.9,
        "starting_items": ["battle axe", "leather armor"],
        "strengths": ["combat", "hp", "strength"],
        "weaknesses": ["magic", "intelligence"],
    },
    "ranger": {
        "description": "A skilled ranger and tracker",
        "difficulty_modifier": 1.0,
        "starting_items": ["bow", "arrows", "cloak"],
        "strengths": ["ranged combat", "stealth"],
        "weaknesses": ["melee combat"],
    },
    "priest": {
        "description": "A holy priest with divine powers",
        "difficulty_modifier": 1.1,
        "starting_items": ["mace", "robe", "holy water"],
        "strengths": ["healing", "undead turning"],
        "weaknesses": ["edged weapons"],
    },
    "monk": {
        "description": "A disciplined monk with martial arts skills",
        "difficulty_modifier": 1.3,
        "starting_items": ["robe"],
        "strengths": ["martial arts", "speed"],
        "weaknesses": ["armor", "weapons"],
    },
    "rogue": {
        "description": "A stealthy rogue and thief",
        "difficulty_modifier": 1.1,
        "starting_items": ["dagger", "leather armor", "lock pick"],
        "strengths": ["stealth", "backstab", "traps"],
        "weaknesses": ["direct combat"],
    },
}
|
144
|
+
|
145
|
+
# Special objectives for variety, grouped by category name.
#
# Each category maps to a list of human-readable objective sentences.  The
# task generator in this module draws from these lists to fill a task's
# secondary goals; the strings end up verbatim in the instructions and rubric.
SPECIAL_OBJECTIVES: Dict[str, List[str]] = {
    "exploration": [
        "Explore at least 3 different dungeon levels",
        "Find and enter a shop",
        "Discover a special room (vault, zoo, etc.)",
        "Find the entrance to the Gnomish Mines",
    ],
    "combat": [
        "Defeat 10 monsters",
        "Defeat a monster using magic",
        "Defeat a monster using ranged weapons",
        "Survive an encounter with a tough monster",
    ],
    "collection": [
        "Collect 100 gold pieces",
        "Find and identify a magical item",
        "Collect food rations for survival",
        "Find a valuable gem",
    ],
    "survival": [
        "Survive for 500 turns",
        "Maintain full health for 100 turns",
        "Never let hunger status reach 'Weak'",
        "Avoid all traps",
    ],
    "progression": [
        "Reach experience level 3",
        "Improve at least one skill",
        "Successfully pray to your deity",
        "Complete a quest or mission",
    ],
}
|
178
|
+
|
179
|
+
|
180
|
+
async def create_nethack_taskset() -> TaskInstanceSet:
    """Generate a randomized set of NetHack task instances.

    For every difficulty tier a fixed number of instances is created (100 in
    total).  Each instance gets a random role from the tier's role list, a
    random target depth and turn limit inside the tier's inclusive ranges, a
    random seed, and a handful of *distinct* secondary objectives.  The
    instances are then shuffled and the first ~10% / next ~10% are recorded as
    validation / test ids; the remainder is presumably treated as the training
    split by ``SplitInfo`` consumers.

    The module-level ``random`` generator is used without seeding here, so two
    calls produce different task sets.

    Returns:
        A ``TaskInstanceSet`` holding the instances and their ``SplitInfo``.
    """
    instances = []

    # Configuration for different difficulty levels
    DIFFICULTY_CONFIGS = {
        "tutorial": {
            "roles": ["tourist"],
            "target_depth_range": (1, 3),
            "time_limit_range": (500, 1000),
            "objective_count": 1,
            "count": 20,
        },
        "beginner": {
            "roles": ["knight", "barbarian"],
            "target_depth_range": (3, 5),
            "time_limit_range": (1000, 2000),
            "objective_count": 2,
            "count": 30,
        },
        "intermediate": {
            "roles": ["wizard", "ranger", "priest"],
            "target_depth_range": (5, 10),
            "time_limit_range": (2000, 5000),
            "objective_count": 2,
            "count": 25,
        },
        "advanced": {
            "roles": ["monk", "rogue"],
            "target_depth_range": (10, 15),
            "time_limit_range": (5000, 10000),
            "objective_count": 3,
            "count": 15,
        },
        "expert": {
            "roles": list(CHARACTER_ROLES.keys()),
            "target_depth_range": (15, 20),
            "time_limit_range": (10000, 20000),
            "objective_count": 4,
            "count": 10,
        },
    }

    # Flatten every objective once (loop-invariant).  ``dict.fromkeys`` drops
    # any sentence that might appear in more than one category while keeping
    # a stable order.
    objective_pool = list(
        dict.fromkeys(
            objective
            for category_objectives in SPECIAL_OBJECTIVES.values()
            for objective in category_objectives
        )
    )

    # Generate instances for each difficulty
    for difficulty, config in DIFFICULTY_CONFIGS.items():
        for _ in range(config["count"]):
            # Random role selection
            role = random.choice(config["roles"])
            role_info = CHARACTER_ROLES[role]

            # Random parameters within difficulty range (bounds inclusive)
            min_depth, max_depth = config["target_depth_range"]
            target_depth = random.randint(min_depth, max_depth)
            min_time, max_time = config["time_limit_range"]
            time_limit = random.randint(min_time, max_time)

            # Select random objectives without replacement so a task never
            # lists the same objective twice (which would also inflate the
            # ``objectives_completed`` metric below).
            objective_count = min(config["objective_count"], len(objective_pool))
            objectives = random.sample(objective_pool, objective_count)
            objective_lines = "\n".join(f"- {obj}" for obj in objectives)

            # Create instruction text
            instructions = f"""You are a {role_info["description"]}.

Your primary goal is to descend to dungeon level {target_depth} within {time_limit} turns.

Additional objectives:
{objective_lines}

Character strengths: {", ".join(role_info["strengths"])}
Character weaknesses: {", ".join(role_info["weaknesses"])}

Tips:
- Use 'inventory' to check your items
- Use 'search' to find secret doors
- Eat food before you become weak from hunger
- Save valuable items for when you need them
- Be cautious around unfamiliar monsters

Remember: In NetHack, careful planning often beats hasty action!"""

            # Create success criteria
            rubric = {
                "goal": f"Reach dungeon level {target_depth}",
                "success_criteria": {
                    "primary": f"Reach dungeon level {target_depth} within {time_limit} turns",
                    "secondary": objectives,
                },
                "evaluation_metrics": {
                    "depth_reached": target_depth,
                    "time_limit": time_limit,
                    "objectives_completed": len(objectives),
                },
            }

            # Create metadata
            metadata = NetHackTaskInstanceMetadata(
                character_role=role,
                starting_level=1,
                target_depth=target_depth,
                time_limit=time_limit,
                difficulty=difficulty,
                special_objectives=objectives,
                seed=random.randint(0, 2**31 - 1),
            )

            # Create task instance
            instance = NetHackTaskInstance(
                id=uuid4(),
                impetus=Impetus(instructions=instructions),
                intent=Intent(rubric=rubric, gold_trajectories=None, gold_state_diff={}),
                metadata=metadata,
                is_reproducible=True,
                initial_engine_snapshot=None,
            )

            instances.append(instance)

    # Define splits (80% train, 10% val, 10% test).  Shuffling first mixes the
    # difficulty tiers, which were appended tier by tier above.
    random.shuffle(instances)
    n_instances = len(instances)
    n_val = n_instances // 10
    n_test = n_instances // 10

    val_ids = {inst.id for inst in instances[:n_val]}
    test_ids = {inst.id for inst in instances[n_val : n_val + n_test]}

    split_info = SplitInfo(
        val_instance_ids=val_ids, test_instance_ids=test_ids, _is_split_defined=True
    )

    return TaskInstanceSet(
        name="NetHack TaskSet",
        description="A comprehensive set of NetHack dungeon exploration tasks with varying difficulty levels, character roles, and objectives",
        instances=instances,
        split_info=split_info,
    )
|
320
|
+
|
321
|
+
|
322
|
+
# Module-level export: ``taskset`` is an alias for the async factory itself
# (not a built task set), so it must be called and awaited to obtain the
# TaskInstanceSet.
taskset = create_nethack_taskset
|
@@ -0,0 +1,277 @@
|
|
1
|
+
"""Unit tests for NetHack engine."""
|
2
|
+
|
3
|
+
import pytest
|
4
|
+
import asyncio
|
5
|
+
from uuid import uuid4
|
6
|
+
|
7
|
+
from synth_ai.environments.tasks.core import TaskInstance, Impetus, Intent, TaskInstanceMetadata
|
8
|
+
|
9
|
+
from synth_ai.environments.examples.nethack.engine import (
|
10
|
+
NetHackEngine,
|
11
|
+
NetHackPublicState,
|
12
|
+
NetHackPrivateState,
|
13
|
+
NetHackEngineSnapshot,
|
14
|
+
NetHackSurvivalComponent,
|
15
|
+
NetHackProgressComponent,
|
16
|
+
)
|
17
|
+
from synth_ai.environments.examples.nethack.taskset import (
|
18
|
+
NetHackTaskInstanceMetadata,
|
19
|
+
NetHackTaskInstance,
|
20
|
+
)
|
21
|
+
|
22
|
+
# Since engine requires NLE, all tests require it
|
23
|
+
pytest.importorskip("nle", reason="NLE is required for NetHack engine")
|
24
|
+
|
25
|
+
|
26
|
+
class TestNetHackEngine:
|
27
|
+
"""Test cases for NetHack engine."""
|
28
|
+
|
29
|
+
@pytest.fixture
def mock_task_instance(self):
    """Build a small, fixed NetHack task instance for the engine tests.

    The instance describes a tourist who must reach depth 5 within 1000
    turns, with a single gold-collection objective and seed 42.
    """
    task_impetus = Impetus(instructions="Test NetHack game")
    task_intent = Intent(
        rubric={"goal": "Reach depth 5"},
        gold_trajectories=None,
        gold_state_diff={},
    )
    task_metadata = NetHackTaskInstanceMetadata(
        character_role="tourist",
        starting_level=1,
        target_depth=5,
        time_limit=1000,
        difficulty="beginner",
        special_objectives=["Collect 100 gold pieces"],
        seed=42,
    )

    task = NetHackTaskInstance(
        id=uuid4(),
        impetus=task_impetus,
        intent=task_intent,
        metadata=task_metadata,
        is_reproducible=True,
        initial_engine_snapshot=None,
    )
    return task
|
54
|
+
|
55
|
+
@pytest.mark.asyncio
|
56
|
+
async def test_engine_initialization(self, mock_task_instance):
|
57
|
+
"""Test engine initialization."""
|
58
|
+
engine = NetHackEngine(mock_task_instance)
|
59
|
+
|
60
|
+
assert engine.task_instance == mock_task_instance
|
61
|
+
assert engine.character_role == "tourist"
|
62
|
+
assert engine.max_turns == 1000
|
63
|
+
assert engine.public_state is None
|
64
|
+
assert engine.private_state is None
|
65
|
+
|
66
|
+
# Cleanup
|
67
|
+
if hasattr(engine, "nle"):
|
68
|
+
engine.nle.close()
|
69
|
+
|
70
|
+
@pytest.mark.asyncio
|
71
|
+
async def test_engine_reset(self, mock_task_instance):
|
72
|
+
"""Test engine reset functionality."""
|
73
|
+
engine = NetHackEngine(mock_task_instance)
|
74
|
+
|
75
|
+
priv, pub = await engine._reset_engine(seed=42)
|
76
|
+
|
77
|
+
# Check private state
|
78
|
+
assert isinstance(priv, NetHackPrivateState)
|
79
|
+
assert priv.reward_last == 0.0
|
80
|
+
assert priv.total_reward == 0.0
|
81
|
+
assert priv.terminated is False
|
82
|
+
assert priv.score >= 0
|
83
|
+
assert priv.depth_reached >= 1
|
84
|
+
|
85
|
+
# Check public state
|
86
|
+
assert isinstance(pub, NetHackPublicState)
|
87
|
+
assert pub.dungeon_level >= 1
|
88
|
+
assert pub.terminated is False
|
89
|
+
assert pub.turn_count == 0
|
90
|
+
assert pub.max_turns == 1000
|
91
|
+
assert len(pub.message) > 0 # Should have some message
|
92
|
+
assert len(pub.ascii_map) > 0 # Should have a map
|
93
|
+
|
94
|
+
@pytest.mark.asyncio
|
95
|
+
async def test_basic_movement(self, mock_task_instance):
|
96
|
+
"""Test basic movement actions."""
|
97
|
+
engine = NetHackEngine(mock_task_instance)
|
98
|
+
priv0, pub0 = await engine._reset_engine()
|
99
|
+
initial_pos = pub0.position
|
100
|
+
|
101
|
+
# Test wait action (always valid)
|
102
|
+
priv, pub = await engine._step_engine("wait")
|
103
|
+
assert pub.last_action == "wait"
|
104
|
+
assert pub.turn_count == 1
|
105
|
+
assert priv.reward_last >= 0 # Should get at least survival reward
|
106
|
+
|
107
|
+
@pytest.mark.asyncio
|
108
|
+
async def test_invalid_action(self, mock_task_instance):
|
109
|
+
"""Test invalid action handling."""
|
110
|
+
engine = NetHackEngine(mock_task_instance)
|
111
|
+
await engine._reset_engine()
|
112
|
+
|
113
|
+
# Test invalid action
|
114
|
+
with pytest.raises(ValueError, match="Invalid action|Valid actions"):
|
115
|
+
await engine._step_engine("invalid_action_xyz")
|
116
|
+
|
117
|
+
@pytest.mark.asyncio
|
118
|
+
async def test_turn_limit(self, mock_task_instance):
|
119
|
+
"""Test turn limit enforcement."""
|
120
|
+
# Create instance with very short time limit
|
121
|
+
metadata = NetHackTaskInstanceMetadata(
|
122
|
+
character_role="tourist",
|
123
|
+
starting_level=1,
|
124
|
+
target_depth=5,
|
125
|
+
time_limit=3, # Very short
|
126
|
+
difficulty="test",
|
127
|
+
special_objectives=[],
|
128
|
+
seed=42,
|
129
|
+
)
|
130
|
+
|
131
|
+
task = NetHackTaskInstance(
|
132
|
+
id=uuid4(),
|
133
|
+
impetus=Impetus(instructions="Test"),
|
134
|
+
intent=Intent(rubric={"goal": "Test"}, gold_trajectories=None, gold_state_diff={}),
|
135
|
+
metadata=metadata,
|
136
|
+
is_reproducible=True,
|
137
|
+
initial_engine_snapshot=None,
|
138
|
+
)
|
139
|
+
|
140
|
+
engine = NetHackEngine(task)
|
141
|
+
await engine._reset_engine()
|
142
|
+
|
143
|
+
# Take actions until time limit
|
144
|
+
await engine._step_engine("wait")
|
145
|
+
await engine._step_engine("wait")
|
146
|
+
priv, pub = await engine._step_engine("wait")
|
147
|
+
|
148
|
+
# Should be terminated due to time limit
|
149
|
+
assert pub.terminated is True
|
150
|
+
assert priv.terminated is True
|
151
|
+
assert priv.truncated is True
|
152
|
+
assert "Time limit reached" in pub.message
|
153
|
+
|
154
|
+
@pytest.mark.asyncio
|
155
|
+
async def test_reward_calculation(self, mock_task_instance):
|
156
|
+
"""Test reward calculation."""
|
157
|
+
engine = NetHackEngine(mock_task_instance)
|
158
|
+
await engine._reset_engine()
|
159
|
+
|
160
|
+
# Test survival reward
|
161
|
+
priv1, pub1 = await engine._step_engine("wait")
|
162
|
+
# Rewards can be positive (survival) or negative (NLE penalty)
|
163
|
+
assert priv1.reward_last != 0 # Should get some reward
|
164
|
+
|
165
|
+
# Test multiple steps accumulate reward
|
166
|
+
total_before = priv1.total_reward
|
167
|
+
priv2, pub2 = await engine._step_engine("wait")
|
168
|
+
# Total reward should change
|
169
|
+
assert priv2.total_reward != total_before
|
170
|
+
|
171
|
+
@pytest.mark.asyncio
|
172
|
+
async def test_state_diff(self, mock_task_instance):
|
173
|
+
"""Test state diff functionality."""
|
174
|
+
engine = NetHackEngine(mock_task_instance)
|
175
|
+
priv1, pub1 = await engine._reset_engine()
|
176
|
+
|
177
|
+
# Since the engine returns references to its internal state,
|
178
|
+
# we need to capture the values we want to compare
|
179
|
+
initial_turn_count = pub1.turn_count
|
180
|
+
initial_last_action = pub1.last_action
|
181
|
+
initial_reward_last = priv1.reward_last
|
182
|
+
initial_total_reward = priv1.total_reward
|
183
|
+
|
184
|
+
# Take an action
|
185
|
+
priv2, pub2 = await engine._step_engine("wait")
|
186
|
+
|
187
|
+
# Verify states actually changed
|
188
|
+
assert pub2.turn_count == initial_turn_count + 1
|
189
|
+
assert pub2.last_action == "wait"
|
190
|
+
assert priv2.total_reward > initial_total_reward # Should have some reward
|
191
|
+
|
192
|
+
# Since pub1 and pub2 are the same object (engine returns references),
|
193
|
+
# the diff will be empty. This is expected behavior for a stateful engine.
|
194
|
+
# The test was incorrectly assuming the engine returns copies.
|
195
|
+
|
196
|
+
@pytest.mark.asyncio
|
197
|
+
async def test_serialization_roundtrip(self, mock_task_instance):
|
198
|
+
"""Test state serialization and deserialization."""
|
199
|
+
engine = NetHackEngine(mock_task_instance)
|
200
|
+
await engine._reset_engine()
|
201
|
+
|
202
|
+
# Take some actions
|
203
|
+
await engine._step_engine("wait")
|
204
|
+
await engine._step_engine("search")
|
205
|
+
|
206
|
+
# Serialize
|
207
|
+
snapshot = await engine._serialize_engine()
|
208
|
+
assert isinstance(snapshot, NetHackEngineSnapshot)
|
209
|
+
|
210
|
+
# Deserialize
|
211
|
+
restored_engine = await NetHackEngine._deserialize_engine(snapshot)
|
212
|
+
|
213
|
+
# Check restored state matches
|
214
|
+
orig_priv, orig_pub = engine.get_current_states_for_observation()
|
215
|
+
rest_priv, rest_pub = restored_engine.get_current_states_for_observation()
|
216
|
+
|
217
|
+
assert rest_pub.turn_count == orig_pub.turn_count
|
218
|
+
assert rest_pub.position == orig_pub.position
|
219
|
+
assert rest_priv.total_reward == orig_priv.total_reward
|
220
|
+
|
221
|
+
@pytest.mark.asyncio
|
222
|
+
async def test_get_current_states(self, mock_task_instance):
|
223
|
+
"""Test getting current states without advancing."""
|
224
|
+
engine = NetHackEngine(mock_task_instance)
|
225
|
+
|
226
|
+
# Should raise before initialization
|
227
|
+
with pytest.raises(RuntimeError, match="Engine not initialized"):
|
228
|
+
engine.get_current_states_for_observation()
|
229
|
+
|
230
|
+
# Initialize
|
231
|
+
await engine._reset_engine()
|
232
|
+
|
233
|
+
# Get states multiple times
|
234
|
+
priv1, pub1 = engine.get_current_states_for_observation()
|
235
|
+
priv2, pub2 = engine.get_current_states_for_observation()
|
236
|
+
|
237
|
+
# Should be same states
|
238
|
+
assert pub1.turn_count == pub2.turn_count
|
239
|
+
assert priv1.total_reward == priv2.total_reward
|
240
|
+
|
241
|
+
|
242
|
+
class TestRewardComponents:
    """Test reward components.

    Rewards are floats, so expected values are compared with
    ``pytest.approx`` instead of exact ``==``.
    """

    @pytest.mark.asyncio
    async def test_survival_component(self):
        """Test survival reward component."""
        component = NetHackSurvivalComponent()

        # Test alive state
        state = NetHackPublicState(terminated=False)
        reward = await component.score(state, "wait")
        assert reward == pytest.approx(0.01)

        # Test dead state (same state object, flipped to terminated)
        state.terminated = True
        reward = await component.score(state, "wait")
        assert reward == pytest.approx(-1.0)

    @pytest.mark.asyncio
    async def test_progress_component(self):
        """Test progress reward component."""
        component = NetHackProgressComponent()

        # Test no progress
        state = NetHackPublicState(dungeon_level=1)
        reward = await component.score(state, "wait")
        assert reward == pytest.approx(0.0)

        # Test depth increase
        state.dungeon_level = 2
        reward = await component.score(state, "go_down")
        assert reward == pytest.approx(1.0)

        # Test no further reward for same level.  This expects the same
        # component instance to pay out only once per newly reached level.
        reward = await component.score(state, "wait")
        assert reward == pytest.approx(0.0)
|