synth-ai 0.1.9__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/__init__.py +28 -2
- synth_ai/core/system.py +4 -0
- synth_ai/environments/__init__.py +35 -0
- synth_ai/environments/environment/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/base.py +50 -0
- synth_ai/environments/environment/core.py +22 -0
- synth_ai/environments/environment/db/__init__.py +1 -0
- synth_ai/environments/environment/db/sqlite.py +45 -0
- synth_ai/environments/environment/registry.py +24 -0
- synth_ai/environments/environment/resources/sqlite.py +46 -0
- synth_ai/environments/environment/results.py +1 -0
- synth_ai/environments/environment/rewards/__init__.py +1 -0
- synth_ai/environments/environment/rewards/core.py +28 -0
- synth_ai/environments/environment/shared_engine.py +26 -0
- synth_ai/environments/environment/tools/__init__.py +34 -0
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +58 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +51 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +872 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/test_crafter_react_agent.py +1110 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/engine.py +502 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/environment.py +255 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +228 -0
- synth_ai/environments/examples/enron/agent_demos/test_synth_react.py +535 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/enron/units/keyword_stats.py +111 -0
- synth_ai/environments/examples/enron/units/test_email_index.py +8 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +47 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +220 -0
- synth_ai/environments/examples/minigrid/agent_demos/test_minigrid_react_agent.py +393 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/minigrid/units/test_action_behavior.py +226 -0
- synth_ai/environments/examples/minigrid/units/test_debug_messages.py +83 -0
- synth_ai/environments/examples/minigrid/units/test_exploration.py +120 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_engine.py +214 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment.py +238 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment_mapping.py +301 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_taskset.py +210 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +832 -0
- synth_ai/environments/examples/nethack/agent_demos/test_nethack_react_agent.py +1112 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/nethack/units/test_nethack_engine.py +277 -0
- synth_ai/environments/examples/nethack/units/test_nethack_environment.py +281 -0
- synth_ai/environments/examples/nethack/units/test_nethack_taskset.py +213 -0
- synth_ai/environments/examples/nethack/units/test_recording.py +307 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/red/agent_demos/test_synth_react.py +1471 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/red/test_fixes.py +125 -0
- synth_ai/environments/examples/red/test_fixes_mock.py +148 -0
- synth_ai/environments/examples/red/units/__init__.py +1 -0
- synth_ai/environments/examples/red/units/test_basic_functionality.py +97 -0
- synth_ai/environments/examples/red/units/test_button_press_requirements.py +217 -0
- synth_ai/environments/examples/red/units/test_engine.py +192 -0
- synth_ai/environments/examples/red/units/test_environment.py +455 -0
- synth_ai/environments/examples/red/units/test_exploration_strategy.py +227 -0
- synth_ai/environments/examples/red/units/test_integration.py +217 -0
- synth_ai/environments/examples/red/units/test_memory_extraction.py +111 -0
- synth_ai/environments/examples/red/units/test_menu_bug_reproduction.py +1100 -0
- synth_ai/environments/examples/red/units/test_movement_debug.py +255 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_debug.py +163 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_verbose.py +117 -0
- synth_ai/environments/examples/red/units/test_red_basic.py +145 -0
- synth_ai/environments/examples/red/units/test_red_comprehensive.py +323 -0
- synth_ai/environments/examples/red/units/test_retry_movement.py +195 -0
- synth_ai/environments/examples/red/units/test_reward_components.py +186 -0
- synth_ai/environments/examples/red/units/test_rom_integration.py +260 -0
- synth_ai/environments/examples/red/units/test_taskset.py +116 -0
- synth_ai/environments/examples/red/units/test_tree.py +448 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +900 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_dspy_react.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_sokoban_react_agent.py +498 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_lats.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_locally.py +748 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_service.py +296 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/sokoban/units/astar_common.py +94 -0
- synth_ai/environments/examples/sokoban/units/test_building_task_set.py +49 -0
- synth_ai/environments/examples/sokoban/units/test_false_positive.py +120 -0
- synth_ai/environments/examples/sokoban/units/test_simple_run_through_environment.py +119 -0
- synth_ai/environments/examples/sokoban/units/test_sokoban_environment.py +98 -0
- synth_ai/environments/examples/sokoban/units/test_tree.py +364 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_synth_react.py +266 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_tictactoe_react_agent.py +470 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_engine.py +393 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_environment.py +493 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_taskset.py +191 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/agent_demos/test_synth_react.py +520 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- synth_ai/environments/examples/verilog/units/test_verilog_engine.py +466 -0
- synth_ai/environments/examples/verilog/units/test_verilog_environment.py +585 -0
- synth_ai/environments/examples/verilog/units/test_verilog_integration.py +383 -0
- synth_ai/environments/examples/verilog/units/test_verilog_taskset.py +457 -0
- synth_ai/environments/reproducibility/core.py +42 -0
- synth_ai/environments/reproducibility/tree.py +364 -0
- synth_ai/environments/service/app.py +78 -0
- synth_ai/environments/service/core_routes.py +775 -0
- synth_ai/environments/service/external_registry.py +57 -0
- synth_ai/environments/service/registry.py +9 -0
- synth_ai/environments/stateful/__init__.py +1 -0
- synth_ai/environments/stateful/core.py +28 -0
- synth_ai/environments/stateful/engine.py +21 -0
- synth_ai/environments/stateful/state.py +7 -0
- synth_ai/environments/tasks/api.py +19 -0
- synth_ai/environments/tasks/core.py +78 -0
- synth_ai/environments/tasks/filters.py +39 -0
- synth_ai/environments/tasks/utils.py +89 -0
- synth_ai/environments/v0_observability/history.py +3 -0
- synth_ai/environments/v0_observability/log.py +2 -0
- synth_ai/lm/caching/constants.py +1 -0
- synth_ai/{zyk/lms → lm}/caching/ephemeral.py +4 -8
- synth_ai/{zyk/lms → lm}/caching/handler.py +15 -15
- synth_ai/{zyk/lms → lm}/caching/initialize.py +2 -4
- synth_ai/{zyk/lms → lm}/caching/persistent.py +4 -10
- synth_ai/{zyk/lms → lm}/config.py +2 -1
- synth_ai/{zyk/lms → lm}/constants.py +2 -2
- synth_ai/{zyk/lms → lm}/core/all.py +10 -10
- synth_ai/{zyk/lms → lm}/core/main.py +57 -33
- synth_ai/{zyk/lms → lm}/core/vendor_clients.py +12 -10
- synth_ai/lm/cost/monitor.py +1 -0
- synth_ai/lm/cost/statefulness.py +1 -0
- synth_ai/lm/provider_support/__init__.py +8 -0
- synth_ai/lm/provider_support/anthropic.py +945 -0
- synth_ai/lm/provider_support/openai.py +1115 -0
- synth_ai/lm/provider_support/suppress_logging.py +31 -0
- synth_ai/{zyk/lms → lm}/structured_outputs/handler.py +58 -80
- synth_ai/{zyk/lms → lm}/structured_outputs/inject.py +6 -20
- synth_ai/{zyk/lms → lm}/structured_outputs/rehabilitate.py +6 -12
- synth_ai/{zyk/lms → lm}/vendors/core/anthropic_api.py +21 -30
- synth_ai/{zyk/lms → lm}/vendors/core/gemini_api.py +37 -32
- synth_ai/{zyk/lms → lm}/vendors/core/mistral_api.py +19 -28
- synth_ai/{zyk/lms → lm}/vendors/core/openai_api.py +26 -36
- synth_ai/{zyk/lms → lm}/vendors/openai_standard.py +29 -33
- synth_ai/{zyk/lms → lm}/vendors/retries.py +1 -1
- synth_ai/lm/vendors/supported/__init__.py +0 -0
- synth_ai/{zyk/lms → lm}/vendors/supported/custom_endpoint.py +131 -118
- synth_ai/{zyk/lms → lm}/vendors/supported/deepseek.py +4 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/grok.py +6 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/groq.py +1 -1
- synth_ai/{zyk/lms → lm}/vendors/supported/ollama.py +2 -2
- synth_ai/{zyk/lms → lm}/vendors/supported/openrouter.py +18 -16
- synth_ai/{zyk/lms → lm}/vendors/supported/together.py +1 -1
- synth_ai/tracing/__init__.py +0 -0
- synth_ai/tracing/abstractions.py +224 -0
- synth_ai/tracing/base_client.py +91 -0
- synth_ai/tracing/client_manager.py +131 -0
- synth_ai/tracing/config.py +140 -0
- synth_ai/tracing/context.py +146 -0
- synth_ai/tracing/decorators.py +679 -0
- synth_ai/tracing/events/__init__.py +0 -0
- synth_ai/tracing/events/manage.py +147 -0
- synth_ai/tracing/events/scope.py +86 -0
- synth_ai/tracing/events/store.py +227 -0
- synth_ai/tracing/immediate_client.py +152 -0
- synth_ai/tracing/local.py +18 -0
- synth_ai/tracing/log_client_base.py +74 -0
- synth_ai/tracing/retry_queue.py +187 -0
- synth_ai/tracing/trackers.py +515 -0
- synth_ai/tracing/upload.py +504 -0
- synth_ai/tracing/utils.py +9 -0
- synth_ai/zyk/__init__.py +28 -2
- synth_ai-0.2.1.dev0.dist-info/METADATA +349 -0
- synth_ai-0.2.1.dev0.dist-info/RECORD +261 -0
- synth_ai/zyk/lms/caching/constants.py +0 -1
- synth_ai/zyk/lms/cost/monitor.py +0 -1
- synth_ai/zyk/lms/cost/statefulness.py +0 -1
- synth_ai-0.1.9.dist-info/METADATA +0 -37
- synth_ai-0.1.9.dist-info/RECORD +0 -50
- /synth_ai/{zyk/lms/__init__.py → environments/reproducibility/helpers.py} +0 -0
- /synth_ai/{zyk/lms/caching → lm}/__init__.py +0 -0
- /synth_ai/{zyk/lms/core → lm/caching}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/caching/dbs.py +0 -0
- /synth_ai/{zyk/lms/cost → lm/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/core/exceptions.py +0 -0
- /synth_ai/{zyk/lms/structured_outputs → lm/cost}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors → lm/structured_outputs}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/core → lm/vendors}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/local → lm/vendors/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors/supported → lm/vendors/local}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/local/ollama.py +0 -0
- {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/WHEEL +0 -0
- {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,111 @@
|
|
1
|
+
"""
|
2
|
+
Logging configuration for the Crafter Classic environment.
|
3
|
+
Suppresses obnoxious JAX debug messages and sets appropriate log levels.
|
4
|
+
"""
|
5
|
+
|
6
|
+
import logging
|
7
|
+
import os
|
8
|
+
import warnings
|
9
|
+
|
10
|
+
|
11
|
+
def configure_logging():
    """Quiet noisy third-party loggers and set JAX-related defaults.

    Effects, in order:
      * the JAX/jaxlib loggers listed below are raised to WARNING and
        stop propagating to ancestor loggers;
      * JAX environment variables receive defaults (values that are
        already set win, because ``setdefault`` is used);
      * the root logger gets a basic INFO configuration, but only when it
        has no handlers yet;
      * the matplotlib, PIL and urllib3 loggers are raised to WARNING;
      * ``UserWarning`` / ``FutureWarning`` from modules matching ``jax``
        are ignored.
    """
    noisy_jax_loggers = (
        "jax._src.cache_key",
        "jax._src.compilation_cache",
        "jax._src.compiler",
        "jax._src.dispatch",
        "jax",
        "jaxlib",
    )
    for name in noisy_jax_loggers:
        jax_logger = logging.getLogger(name)
        jax_logger.setLevel(logging.WARNING)
        jax_logger.propagate = False

    # Defaults only: anything the user exported beforehand is left alone.
    env_defaults = {
        "JAX_PLATFORMS": "cpu",
        "JAX_ENABLE_X64": "False",
        "JAX_LOG_COMPILES": "0",
        "JAX_COMPILATION_CACHE_DIR": "/tmp/jax_cache",
    }
    for key, value in env_defaults.items():
        os.environ.setdefault(key, value)

    # Do not override a root logger the application configured itself.
    root_logger = logging.getLogger()
    if not root_logger.handlers:
        logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(name)s:%(message)s")

    for name in ("matplotlib", "PIL", "urllib3"):
        logging.getLogger(name).setLevel(logging.WARNING)

    for category in (UserWarning, FutureWarning):
        warnings.filterwarnings("ignore", category=category, module="jax")
|
49
|
+
|
50
|
+
|
51
|
+
def _parse_number(text):
    """Parse *text* as an int when possible, otherwise as a float.

    Raises:
        ValueError: if *text* is not a valid numeric literal.
    """
    try:
        return int(text)
    except ValueError:
        return float(text)


def safe_compare(left, right, operation="<"):
    """
    Safely compare two values, handling string vs number comparison errors.

    Numeric strings are parsed (int first, then float) before comparing, so
    ``"10" > "9"`` and ``"3.5" > 3`` compare numerically. When both operands
    are strings and either one is not numeric, both are compared unchanged
    as strings.

    Args:
        left: Left operand
        right: Right operand
        operation: Comparison operation ('>', '<', '>=', '<=', '==', '!=')

    Returns:
        Result of the comparison. False is returned (and the problem is
        logged) when a string cannot be converted to match a numeric
        operand, when the operand types are incompatible, or when
        ``operation`` is not one of the supported operators.
    """
    import operator  # local import: the module header does not import it

    comparators = {
        "<": operator.lt,
        ">": operator.gt,
        "<=": operator.le,
        ">=": operator.ge,
        "==": operator.eq,
        "!=": operator.ne,
    }

    try:
        if isinstance(left, str) and isinstance(right, str):
            try:
                # Convert both or neither: the right-hand tuple is fully
                # evaluated before either name is rebound, so a failure on
                # `right` no longer leaves `left` converted to a number.
                left, right = _parse_number(left), _parse_number(right)
            except ValueError:
                # If conversion fails, compare as strings
                pass
        elif isinstance(left, str) and isinstance(right, (int, float)):
            try:
                left = _parse_number(left)
            except ValueError:
                logging.warning(f"Cannot compare string '{left}' with number {right}")
                return False
        elif isinstance(left, (int, float)) and isinstance(right, str):
            try:
                right = _parse_number(right)
            except ValueError:
                logging.warning(f"Cannot compare number {left} with string '{right}'")
                return False

        compare = comparators.get(operation)
        if compare is None:
            # Raised inside the try on purpose: as before, an unsupported
            # operator is logged by the generic handler and yields False.
            raise ValueError(f"Unsupported operation: {operation}")
        return compare(left, right)

    except TypeError as e:
        logging.error(f"Type error in comparison: {left} {operation} {right} - {e}")
        return False
    except Exception as e:
        logging.error(f"Unexpected error in comparison: {left} {operation} {right} - {e}")
        return False
|
108
|
+
|
109
|
+
|
110
|
+
# Configure logging when module is imported
|
111
|
+
configure_logging()
|
@@ -0,0 +1,502 @@
|
|
1
|
+
"""CrafterEngine — Stateful, reproducible wrapper around danijar/crafter.Env.
|
2
|
+
This file follows the same structure as the SokobanEngine (see examples/sokoban/engine.py).
|
3
|
+
"""
|
4
|
+
|
5
|
+
from __future__ import annotations
|
6
|
+
|
7
|
+
# Import logging configuration first to suppress JAX debug messages
|
8
|
+
from .config_logging import safe_compare
|
9
|
+
|
10
|
+
|
11
|
+
import logging
|
12
|
+
from dataclasses import dataclass
|
13
|
+
from typing import Any, Dict, Optional, Tuple, Union
|
14
|
+
|
15
|
+
import numpy as np
|
16
|
+
import crafter # type: ignore
|
17
|
+
import copy
|
18
|
+
import dataclasses
|
19
|
+
|
20
|
+
from synth_ai.environments.environment.shared_engine import (
|
21
|
+
GetObservationCallable,
|
22
|
+
InternalObservation,
|
23
|
+
)
|
24
|
+
from synth_ai.environments.stateful.engine import StatefulEngine, StatefulEngineSnapshot
|
25
|
+
from synth_ai.environments.tasks.core import TaskInstance
|
26
|
+
from synth_ai.environments.reproducibility.core import IReproducibleEngine
|
27
|
+
from synth_ai.environments.environment.rewards.core import RewardStack, RewardComponent # Added
|
28
|
+
|
29
|
+
# Local helper imports (must exist relative to this file)
|
30
|
+
from .engine_helpers.action_map import CRAFTER_ACTION_MAP # action‑name → int
|
31
|
+
from .engine_helpers.serialization import (
|
32
|
+
serialize_world_object,
|
33
|
+
)
|
34
|
+
|
35
|
+
logger = logging.getLogger(__name__)
|
36
|
+
logging.basicConfig(level=logging.INFO)
|
37
|
+
|
38
|
+
# -----------------------------------------------------------------------------
|
39
|
+
# Dataclasses for snapshot & (public, private) runtime state
|
40
|
+
# -----------------------------------------------------------------------------
|
41
|
+
|
42
|
+
|
43
|
+
@dataclass
class CrafterEngineSnapshot(StatefulEngineSnapshot):
    """Snapshot record for the Crafter engine.

    The code that fills and restores this record is outside this view; the
    field notes below restate only what the declarations themselves say.
    """

    env_raw_state: Any  # from crafter.Env.save()
    total_reward_snapshot: float  # presumably the engine's running total reward — TODO confirm
    crafter_seed: Optional[int] = None
    # Store previous states needed for reward calculation to resume correctly
    previous_public_state_snapshot: Optional[Dict] = None
    previous_private_state_snapshot: Optional[Dict] = None
    # Add _previous_public_state_for_reward and _previous_private_state_for_reward if needed for perfect resume
    # For RewardStack, its configuration is fixed at init. If it had internal state, that would need saving.
|
53
|
+
|
54
|
+
|
55
|
+
@dataclass
class CrafterPublicState:
    """Public half of the (public, private) Crafter runtime state.

    Holds the inventory, achievement flags, player pose, map/image arrays
    and step counters, plus an optional error message.
    """

    inventory: Dict[str, int]
    achievements_status: Dict[str, bool]
    player_position: Tuple[int, int]
    player_direction: Union[int, Tuple[int, int]]
    semantic_map: Optional[np.ndarray]
    world_material_map: np.ndarray
    observation_image: np.ndarray
    num_steps_taken: int
    max_steps_episode: int
    error_info: Optional[str] = None

    def diff(self, prev_state: "CrafterPublicState") -> Dict[str, Any]:
        """Return the fields whose values differ from ``prev_state``.

        Args:
            prev_state: The earlier state to compare against.

        Returns:
            A dict keyed by field name. Fields where either side is a NumPy
            array map to ``True`` when the contents differ; every other
            changed field maps to an ``(old_value, new_value)`` tuple.
            Unchanged fields are omitted.
        """
        changes: Dict[str, Any] = {}
        for field in dataclasses.fields(self):
            name = field.name
            new_v, old_v = getattr(self, name), getattr(prev_state, name)
            # Check both sides: ``semantic_map`` is Optional, and evaluating
            # ``None != ndarray`` yields an array whose truth value is
            # ambiguous, which would raise ValueError in the branch below.
            if isinstance(new_v, np.ndarray) or isinstance(old_v, np.ndarray):
                if not np.array_equal(new_v, old_v):
                    changes[name] = True
            elif new_v != old_v:
                changes[name] = (old_v, new_v)
        return changes
|
78
|
+
|
79
|
+
|
80
|
+
@dataclass
class CrafterPrivateState:
    """Private half of the (public, private) Crafter runtime state.

    Carries the reward figures, termination flags and the internal
    player/world values declared below.
    """

    reward_last_step: float
    total_reward_episode: float
    achievements_current_values: Dict[str, int]
    terminated: bool
    truncated: bool
    player_internal_stats: Dict[str, Any]
    world_rng_state_snapshot: Any

    def diff(self, prev_state: "CrafterPrivateState") -> Dict[str, Any]:
        """Map each field that differs from ``prev_state`` to ``(old, new)``.

        Fields that compare equal are left out of the result.
        """
        observed = (
            (name, getattr(self, name), getattr(prev_state, name))
            for name in self.__dataclass_fields__  # type: ignore[attr-defined]
        )
        return {name: (before, now) for name, now, before in observed if now != before}
|
97
|
+
|
98
|
+
|
99
|
+
# -----------------------------------------------------------------------------
|
100
|
+
# Observation helpers
|
101
|
+
# -----------------------------------------------------------------------------
|
102
|
+
|
103
|
+
|
104
|
+
class CrafterObservationCallable(GetObservationCallable):
    """Builds a flat observation dict from the public and private state."""

    def __init__(self) -> None:
        """No configuration is needed; the callable keeps no state."""

    async def get_observation(
        self, pub: CrafterPublicState, priv: CrafterPrivateState
    ) -> InternalObservation:  # type: ignore[override]
        """Assemble the observation for the current step.

        Args:
            pub: Public state; supplies inventory, achievements, player
                position and step count.
            priv: Private state; supplies last/total reward and the
                terminated/truncated flags.

        Returns:
            A dict with the keys ``inventory``, ``achievements``,
            ``player_pos``, ``steps``, ``reward_last``, ``total_reward``,
            ``terminated`` and ``truncated``. Values are the state objects'
            own attributes, not copies.
        """
        from_public: Dict[str, Any] = {
            "inventory": pub.inventory,
            "achievements": pub.achievements_status,
            "player_pos": pub.player_position,
            "steps": pub.num_steps_taken,
        }
        from_private: Dict[str, Any] = {
            "reward_last": priv.reward_last_step,
            "total_reward": priv.total_reward_episode,
            "terminated": priv.terminated,
            "truncated": priv.truncated,
        }
        return {**from_public, **from_private}  # type: ignore[return-value]
|
122
|
+
|
123
|
+
|
124
|
+
# -----------------------------------------------------------------------------
|
125
|
+
# CrafterEngine implementation
|
126
|
+
# -----------------------------------------------------------------------------
|
127
|
+
|
128
|
+
|
129
|
+
class CrafterEngine(StatefulEngine, IReproducibleEngine):
|
130
|
+
"""StatefulEngine wrapper around `crafter.Env` supporting full snapshotting."""
|
131
|
+
|
132
|
+
task_instance: TaskInstance
|
133
|
+
env: crafter.Env
|
134
|
+
|
135
|
+
# ────────────────────────────────────────────────────────────────────────
|
136
|
+
# Construction helpers
|
137
|
+
# ────────────────────────────────────────────────────────────────────────
|
138
|
+
|
139
|
+
def __init__(self, task_instance: TaskInstance):
    """Create the wrapped ``crafter.Env`` and the reward stack.

    Args:
        task_instance: Task description. Its optional ``config`` mapping may
            provide ``area`` (default ``(64, 64)``), ``length`` (default
            ``10000``) and ``seed``. A ``seed`` attribute on
            ``task_instance.metadata`` takes precedence over the config seed.
    """
    self.task_instance = task_instance

    # Reward bookkeeping.
    self._total_reward: float = 0.0
    self._current_action_for_reward: Optional[int] = None
    self._previous_public_state_for_reward: Optional[CrafterPublicState] = None
    # Previous private state, kept for stat changes.
    self._previous_private_state_for_reward: Optional[CrafterPrivateState] = None

    # Achievements seen so far.
    self.achievements_unlocked: set = set()

    config = getattr(task_instance, "config", {}) or {}
    world_area: Tuple[int, int] = tuple(config.get("area", (64, 64)))  # type: ignore[arg-type]
    episode_length: int = int(config.get("length", 10000))

    # The config seed is the fallback; a metadata seed overrides it.
    seed: Optional[int] = config.get("seed")
    metadata_has_seed = hasattr(task_instance, "metadata") and hasattr(
        task_instance.metadata, "seed"
    )
    if metadata_has_seed:
        seed = task_instance.metadata.seed

    self.env = crafter.Env(area=world_area, length=episode_length, seed=seed)
    # Store the original seed on the env for reproducibility.
    self.env._seed = seed

    reward_components = [
        CrafterAchievementComponent(),
        CrafterPlayerStatComponent(),
        CrafterStepPenaltyComponent(penalty=-0.001),
    ]
    self.reward_stack = RewardStack(components=reward_components)
|
171
|
+
|
172
|
+
# ────────────────────────────────────────────────────────────────────────
|
173
|
+
# Utility: action validation / mapping
|
174
|
+
# ────────────────────────────────────────────────────────────────────────
|
175
|
+
|
176
|
+
def _validate_action_engine(self, action: Union[int, str]) -> int:  # type: ignore[override]
    """Map an action name or index to a valid Crafter action index.

    Args:
        action: An action name (looked up in ``CRAFTER_ACTION_MAP``) or an
            integer index. Python ints and NumPy integer scalars are both
            accepted.

    Returns:
        The index clamped into ``[0, len(crafter.constants.actions) - 1]``.
        Unknown names and non-integer inputs yield ``0`` (presumably the
        no-op action — confirm against ``CRAFTER_ACTION_MAP``).
    """
    if isinstance(action, str):
        action = CRAFTER_ACTION_MAP.get(action, 0)
    # NumPy integer scalars (e.g. np.int64 from argmax or sampling) are not
    # instances of ``int``; without this they were silently replaced by 0.
    if not isinstance(action, (int, np.integer)):
        return 0
    return int(np.clip(action, 0, len(crafter.constants.actions) - 1))  # type: ignore
|
182
|
+
|
183
|
+
# ────────────────────────────────────────────────────────────────────────
|
184
|
+
# Core StatefulEngine API
|
185
|
+
# ────────────────────────────────────────────────────────────────────────
|
186
|
+
|
187
|
+
async def _reset_engine(
    self, *, seed: Optional[int] = None
) -> Tuple[CrafterPrivateState, CrafterPublicState]:
    """Start a new episode and return its initial ``(private, public)`` state.

    Args:
        seed: When given, the wrapped ``crafter.Env`` is re-created with this
            seed (same area and length as the current env) before resetting.

    Returns:
        The private state (zero reward, not terminated or truncated) and the
        public state built from the reset observation image.
    """
    if seed is not None:
        # Re-instantiate env with new seed to match crafter's internal reseeding convention
        self.env = crafter.Env(area=self.env._area, length=self.env._length, seed=seed)
        # Mirror __init__, which records the seed on the env for reproducibility.
        self.env._seed = seed
    obs_img = self.env.reset()
    self._total_reward = 0.0
    # Achievement tracking is per episode: without this, achievements unlocked
    # in an earlier episode would never register as new after a reset.
    self.achievements_unlocked = set()
    pub = self._build_public_state(obs_img)
    priv = self._build_private_state(reward=0.0, terminated=False, truncated=False)
    return priv, pub
|
198
|
+
|
199
|
+
async def _step_engine(self, action: int) -> Tuple[CrafterPrivateState, CrafterPublicState]:
    """Advance the wrapped Crafter env by one action.

    Args:
        action: Index into the env's discrete action space; must satisfy
            ``0 <= action < self.env.action_space.n``.

    Returns:
        ``(private_state, public_state)`` describing the world after the
        action has been applied. ``private_state.reward_last_step`` holds
        only this step's reward; the running total is kept in
        ``self._total_reward``.

        Any exception raised while stepping (including the ``ValueError``
        for an out-of-range action) is not propagated. It is converted into
        an error state pair with ``reward=-1.0`` and ``terminated=True``,
        and the message is stored in ``public_state.error_info``.
    """
    try:
        # Validate action is in valid range
        num_actions = self.env.action_space.n
        if action < 0 or action >= num_actions:
            raise ValueError(
                f"Invalid action {action}, must be in range [0, {num_actions})"
            )

        # Step the environment
        obs, reward, done, info = self.env.step(action)

        # Update internal state
        self.obs = obs
        self.done = done
        self.info = info
        self.last_reward = reward

        # Step count is tracked by the crafter environment itself in self.env._step

        # Remember every achievement that is reported as unlocked.
        if "achievements" in info:
            for achievement, status in info["achievements"].items():
                if status and achievement not in self.achievements_unlocked:
                    self.achievements_unlocked.add(achievement)

        # Drain any queued extra rewards.
        # NOTE(review): the visible part of __init__ sets `self.reward_stack`
        # (no leading underscore); unless `_reward_stack` is assigned
        # elsewhere this branch never contributes. Confirm which is intended.
        reward_from_stack = 0
        try:
            pending_rewards = getattr(self, "_reward_stack", None)
            if pending_rewards:
                reward_from_stack = sum(pending_rewards)
                pending_rewards.clear()
        except Exception as e:
            # Best effort: a malformed stack must not abort the step, but it
            # should not disappear silently either.
            logging.warning("Ignoring pending reward stack: %s", e)
            reward_from_stack = 0

        # Keep the per-step reward separate from the episode total so that
        # reward_last_step and total_reward_episode are not the same number.
        step_reward = reward + reward_from_stack
        self._total_reward += step_reward

        # Termination flags:
        # - terminated=True only if the player actually died (health <= 0)
        # - truncated=True for every other reason `done` was raised (step
        #   limit reached, or an unclear reason which is treated as timeout)
        player_died = self.env._player.health <= 0  # type: ignore[attr-defined]
        episode_over = bool(done)
        terminated = episode_over and player_died
        truncated = episode_over and not player_died

        # Build both states AFTER stepping so the caller observes the result
        # of this action rather than the frame that preceded it.
        final_pub_state = self._build_public_state(obs)
        final_priv_state = self._build_private_state(step_reward, terminated, truncated)

        self._previous_public_state_for_reward = final_pub_state
        self._previous_private_state_for_reward = final_priv_state

        return final_priv_state, final_pub_state

    except Exception as e:
        # Create error state
        error_pub_state = self._get_public_state_from_env()
        error_pub_state.error_info = f"Step engine error: {e}"
        error_priv_state = self._get_private_state_from_env(
            reward=-1.0, terminated=True, truncated=False
        )
        return error_priv_state, error_pub_state
|
286
|
+
|
287
|
+
# ------------------------------------------------------------------
|
288
|
+
# Rendering (simple text summary)
|
289
|
+
# ------------------------------------------------------------------
|
290
|
+
|
291
|
+
async def _render(
    self,
    private_state: CrafterPrivateState,
    public_state: CrafterPublicState,
    get_observation: Optional[GetObservationCallable] = None,
) -> str:  # type: ignore[override]
    """Produce a text rendering of the given state pair.

    The observation is obtained from ``get_observation`` (or a default
    ``CrafterObservationCallable`` when none is supplied):

    * a ``str`` observation is returned unchanged;
    * a ``dict`` observation yields a three-line summary (step counter and
      rewards, non-empty inventory entries, unlocked achievements) computed
      from the state objects rather than from the dict itself;
    * anything else is passed through ``str()``.
    """
    observer = get_observation if get_observation else CrafterObservationCallable()
    observation = await observer.get_observation(public_state, private_state)

    if not isinstance(observation, dict):
        return observation if isinstance(observation, str) else str(observation)

    header = f"steps: {public_state.num_steps_taken}/{public_state.max_steps_episode} | "
    header += f"last_r: {private_state.reward_last_step:.2f} | total_r: {private_state.total_reward_episode:.2f}"

    # Only items with a truthy count are listed.
    inventory_parts = []
    for k, v in public_state.inventory.items():
        if v:
            inventory_parts.append(f"{k}:{v}")
    inv = ", ".join(inventory_parts)

    # Only achievements whose status is truthy are listed.
    unlocked_names = []
    for k, v in public_state.achievements_status.items():
        if v:
            unlocked_names.append(k)
    ach = ", ".join(unlocked_names)

    return f"{header}\ninv: {inv}\nach: {ach}"
|
308
|
+
|
309
|
+
# ------------------------------------------------------------------
|
310
|
+
# Snapshotting for exact reproducibility
|
311
|
+
# ------------------------------------------------------------------
|
312
|
+
|
313
|
+
async def _serialize_engine(self) -> CrafterEngineSnapshot:
    """Capture the engine state in a ``CrafterEngineSnapshot``.

    The snapshot stores the value returned by ``self.env.save()``, the
    running episode reward, the env's ``_seed`` and, when present,
    ``dataclasses.asdict`` copies of the previous public/private states kept
    for reward continuity (``None`` otherwise).
    """
    prev_public = self._previous_public_state_for_reward
    prev_private = self._previous_private_state_for_reward
    return CrafterEngineSnapshot(
        env_raw_state=self.env.save(),
        total_reward_snapshot=self._total_reward,
        crafter_seed=self.env._seed,
        previous_public_state_snapshot=(
            dataclasses.asdict(prev_public) if prev_public else None
        ),
        previous_private_state_snapshot=(
            dataclasses.asdict(prev_private) if prev_private else None
        ),
    )
|
334
|
+
|
335
|
+
@classmethod
async def _deserialize_engine(
    cls, snapshot: CrafterEngineSnapshot, task_instance: TaskInstance
) -> "CrafterEngine":
    """Rebuild an engine from a snapshot produced by ``_serialize_engine``.

    Args:
        snapshot: Snapshot carrying the raw env state, total reward and seed.
        task_instance: Task instance used to construct the new engine.

    Returns:
        A new engine whose env has been reset and then loaded from
        ``snapshot.env_raw_state``, with the running reward restored and the
        "previous state" pair rebuilt from the env as it stands after loading.
    """
    engine = cls(task_instance)

    # Seed first so the reset below builds its world from the snapshot seed.
    engine.env._seed = snapshot.crafter_seed
    # reset() creates the initial world structure and therefore has to run
    # BEFORE load(); calling it afterwards would start a fresh episode and
    # throw away the state that was just restored.
    # NOTE(review): presumes env.load() restores what env.save() captured --
    # confirm against the env's save/load implementation.
    _ = engine.env.reset()
    engine.env.load(snapshot.env_raw_state)
    # Re-assert the seed in case load() replaced it.
    engine.env._seed = snapshot.crafter_seed
    engine._total_reward = snapshot.total_reward_snapshot

    # Re-establish previous states for reward system continuity if first step after load
    engine._previous_public_state_for_reward = engine._build_public_state(engine.env.render())
    # Safe comparisons to avoid string vs int errors
    health_dead = safe_compare(0, engine.env._player.health, ">=")
    step_exceeded = safe_compare(engine.env._length, engine.env._step, "<=")
    engine._previous_private_state_for_reward = engine._build_private_state(
        0.0, health_dead, step_exceeded
    )
    return engine
|
353
|
+
|
354
|
+
# ------------------------------------------------------------------
|
355
|
+
# Internal helpers
|
356
|
+
# ------------------------------------------------------------------
|
357
|
+
|
358
|
+
def _build_public_state(
    self, obs_img: np.ndarray, info: Optional[Dict[str, Any]] | None = None
) -> CrafterPublicState:
    """Assemble the agent-visible state for the current env situation.

    Args:
        obs_img: Rendered frame to store as ``observation_image``.
        info: Optional info dict. When provided, inventory, achievements and
            the semantic map are read from it; when ``None`` they are read
            from the live player object and the env's ``_sem_view`` (if the
            env has one).

    Returns:
        The populated ``CrafterPublicState``. Position, facing, material
        map and step counters always come from the live env. If anything
        raises while collecting the data, the error is logged and a minimal
        placeholder state carrying the message in ``error_info`` is returned.
    """
    try:
        if info is not None:
            inventory = info.get("inventory", {})
            # Safe achievement status check from info
            achievements_status = {
                name: safe_compare(0, count, "<")
                for name, count in info.get("achievements", {}).items()
            }
            semantic = info.get("semantic")
        else:
            live_player = self.env._player  # type: ignore[attr-defined]
            # Safe achievement status check
            achievements_status = {
                name: safe_compare(0, count, "<")
                for name, count in live_player.achievements.items()
            }
            inventory = live_player.inventory.copy()
            semantic = getattr(self.env, "_sem_view", lambda: None)()

        player = self.env._player  # type: ignore[attr-defined]
        return CrafterPublicState(
            inventory=inventory,
            achievements_status=achievements_status,
            player_position=tuple(player.pos),  # type: ignore[attr-defined]
            player_direction=player.facing,  # type: ignore[attr-defined]
            semantic_map=semantic,
            world_material_map=self.env._world._mat_map.copy(),  # type: ignore[attr-defined]
            observation_image=obs_img,
            num_steps_taken=self.env._step,  # type: ignore[attr-defined]
            max_steps_episode=self.env._length,  # type: ignore[attr-defined]
            error_info=info.get("error_info") if info else None,
        )
    except Exception as e:
        logging.error(f"Error building public state: {e}")
        # Fall back to a minimal state; keep the caller's frame if there is one.
        if obs_img is not None:
            fallback_image = obs_img
        else:
            fallback_image = np.zeros((64, 64, 3), dtype=np.uint8)
        return CrafterPublicState(
            inventory={},
            achievements_status={},
            player_position=(0, 0),
            player_direction=0,
            semantic_map=None,
            world_material_map=np.zeros((1, 1), dtype=np.uint8),
            observation_image=fallback_image,
            num_steps_taken=0,
            max_steps_episode=10000,
            error_info=f"State building error: {e}",
        )
|
409
|
+
|
410
|
+
def _build_private_state(
    self, reward: float, terminated: bool, truncated: bool
) -> CrafterPrivateState:
    """Assemble the private state from the live player and world.

    Args:
        reward: Value stored as ``reward_last_step``.
        terminated: Stored unchanged in the result.
        truncated: Stored unchanged in the result.

    Returns:
        A ``CrafterPrivateState`` holding the given values, the engine's
        running ``_total_reward``, a copy of the player's achievements, a
        dict of player vitals and the world RNG state.
    """
    hero = self.env._player  # type: ignore[attr-defined]

    vitals = {"health": hero.health}
    # These three live in the inventory; a missing entry yields None.
    for item in ("food", "drink", "energy"):
        vitals[item] = hero.inventory.get(item)
    # Internal counters default to 0 when the player lacks the attribute.
    for counter in ("_hunger", "_thirst"):
        vitals[counter] = getattr(hero, counter, 0)

    return CrafterPrivateState(
        reward_last_step=reward,
        total_reward_episode=self._total_reward,
        achievements_current_values=hero.achievements.copy(),
        terminated=terminated,
        truncated=truncated,
        player_internal_stats=vitals,
        world_rng_state_snapshot=self.env._world.random.get_state(),  # type: ignore[attr-defined]
    )
|
431
|
+
|
432
|
+
def _get_public_state_from_env(self) -> CrafterPublicState:
    """Return the public state for the env's current frame without raising.

    The env is rendered and the frame is handed to ``_build_public_state``.
    If that raises, the error is logged and a placeholder state with empty
    collections, zeroed arrays and the message in ``error_info`` is returned.
    """
    try:
        public_state = self._build_public_state(self.env.render())
    except Exception as e:
        logging.error(f"Error getting public state from env: {e}")
        # Placeholder used when the env cannot even be rendered.
        public_state = CrafterPublicState(
            inventory={},
            achievements_status={},
            player_position=(0, 0),
            player_direction=0,
            semantic_map=None,
            world_material_map=np.zeros((1, 1), dtype=np.uint8),
            observation_image=np.zeros((64, 64, 3), dtype=np.uint8),
            num_steps_taken=0,
            max_steps_episode=10000,
            error_info=f"State extraction error: {e}",
        )
    return public_state
|
452
|
+
|
453
|
+
def _get_private_state_from_env(
    self, reward: float, terminated: bool, truncated: bool
) -> CrafterPrivateState:
    """Return the private state for the current env without raising.

    Delegates to ``_build_private_state``. If that raises, the error is
    logged and a placeholder is returned that keeps the given ``reward``,
    ``terminated`` and ``truncated`` values but has a zero episode total,
    empty dicts and no RNG snapshot.
    """
    try:
        private_state = self._build_private_state(reward, terminated, truncated)
    except Exception as e:
        logging.error(f"Error getting private state from env: {e}")
        # Placeholder used when the live player/world cannot be read.
        private_state = CrafterPrivateState(
            reward_last_step=reward,
            total_reward_episode=0.0,
            achievements_current_values={},
            terminated=terminated,
            truncated=truncated,
            player_internal_stats={},
            world_rng_state_snapshot=None,
        )
    return private_state
|
471
|
+
|
472
|
+
|
473
|
+
# --- Reward Components ---
|
474
|
+
class CrafterAchievementComponent(RewardComponent):
    """Reward component paying 0.1 per achievement that is newly unlocked."""

    async def score(self, state: CrafterPublicState, action: Dict[str, Any]) -> float:
        """Count achievements that are set now but were not set before.

        The earlier status mapping is read from
        ``action["previous_public_state_achievements"]``; when the key is
        missing, every currently unlocked achievement counts as new.
        """
        earlier = action.get("previous_public_state_achievements", {})
        fresh_count = 0
        for name, unlocked in state.achievements_status.items():
            if unlocked and not earlier.get(name):
                fresh_count += 1
        return float(fresh_count) * 0.1
|
484
|
+
|
485
|
+
|
486
|
+
class CrafterPlayerStatComponent(RewardComponent):
    """Reward component applying a fixed penalty when the player's health drops."""

    async def score(self, state: CrafterPrivateState, action: Dict[str, Any]) -> float:
        """Return ``-0.05`` if health is lower than in the previous stats, else ``0.0``.

        Previous stats are read from ``action["previous_private_state_stats"]``.
        When they, or their ``"health"`` entry, are absent the current health
        is used as the baseline, so no penalty is applied.
        """
        health_now = state.player_internal_stats.get("health", 0)
        earlier_stats = action.get("previous_private_state_stats", {})
        health_before = earlier_stats.get("health", health_now)
        # Lost health penalty
        return -0.05 if health_now < health_before else 0.0
|
493
|
+
|
494
|
+
|
495
|
+
class CrafterStepPenaltyComponent(RewardComponent):
    """Reward component that returns the same configured penalty on every call."""

    def __init__(self, penalty: float = -0.001):
        """Store ``penalty`` and set this component's ``weight`` to ``1.0``."""
        super().__init__()
        self.penalty, self.weight = penalty, 1.0

    async def score(self, state: Any, action: Any) -> float:
        """Return the configured penalty; ``state`` and ``action`` are ignored."""
        return self.penalty
|
@@ -0,0 +1,63 @@
|
|
1
|
+
"""
|
2
|
+
Apply once: import this module anywhere before CrafterEngine is used.
It wraps Env._balance_object so that every per-chunk object list is
sorted by (x, y, class-name) before any random choice is made, removing
the hash-based set-iteration nondeterminism that caused the drift.
It also replaces World.chunks and World.objects with properties that
return their contents in a stable, sorted order.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import collections
|
9
|
+
import crafter
|
10
|
+
|
11
|
+
print("[PATCH] Attempting to apply Crafter deterministic patch...")

# -----------------------------------------------------------------------------
# 1. Make per–chunk object order stable
# -----------------------------------------------------------------------------
# The presence of `_orig_balance_object` doubles as the "already patched"
# marker, so importing this module twice does not wrap the method twice.
if hasattr(crafter.Env, "_orig_balance_object"):
    print("[PATCH] crafter.Env._balance_object already patched or _orig_balance_object exists.")
else:
    print("[PATCH] Patching crafter.Env._balance_object...")
    crafter.Env._orig_balance_object = crafter.Env._balance_object

    def _balance_object_det(self, chunk, objs, *args, **kwargs):
        """Call the original ``_balance_object`` with ``objs`` in sorted order.

        The remaining arguments of the original signature (cls, material,
        span_dist, despan_dist, spawn_prob, despawn_prob, ctor, target_fn)
        are forwarded untouched through ``*args`` / ``**kwargs``.
        """

        def _position_then_type(obj):
            return (obj.pos[0], obj.pos[1], obj.__class__.__name__)

        ordered_objs = sorted(objs, key=_position_then_type)
        return crafter.Env._orig_balance_object(self, chunk, ordered_objs, *args, **kwargs)

    crafter.Env._balance_object = _balance_object_det
    print("[PATCH] crafter.Env._balance_object patched.")
|
30
|
+
|
31
|
+
# -----------------------------------------------------------------------------
|
32
|
+
# 2. Make *chunk* iteration order stable
|
33
|
+
# -----------------------------------------------------------------------------
|
34
|
+
# Skip if a previous import already saved the original property.
if not hasattr(crafter.engine.World, "_orig_chunks_prop"):
    crafter.engine.World._orig_chunks_prop = crafter.engine.World.chunks

    def _chunks_sorted(self):
        """Return the chunk mapping as a new OrderedDict in ascending key order."""
        ordered = collections.OrderedDict()
        # Insertion order is what OrderedDict iterates in, so insert sorted.
        for chunk_key in sorted(self._chunks):
            ordered[chunk_key] = self._chunks[chunk_key]
        return ordered

    crafter.engine.World.chunks = property(_chunks_sorted)
|
42
|
+
|
43
|
+
# -----------------------------------------------------------------------------
|
44
|
+
# 3. NEW: keep per-frame object update order deterministic
|
45
|
+
# -----------------------------------------------------------------------------
|
46
|
+
# Skip if a previous import already saved the original property.
if not hasattr(crafter.engine.World, "_orig_objects_prop"):
    crafter.engine.World._orig_objects_prop = crafter.engine.World.objects  # save original

    @property
    def _objects_sorted(self):
        """Return the non-empty world objects in a stable sorted order.

        Falsy slots in ``self._objects`` (e.g. ``None`` left by removed
        objects) are dropped. The rest are ordered by x, y, class name and
        the object's ``_id`` attribute (``0`` when it has none).
        """

        def _stable_order(obj):
            return (
                obj.pos[0],
                obj.pos[1],
                obj.__class__.__name__,
                getattr(obj, "_id", 0),
            )

        present = [obj for obj in self._objects if obj]
        present.sort(key=_stable_order)
        return present

    crafter.engine.World.objects = _objects_sorted
|