synth-ai 0.1.9__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/__init__.py +28 -2
- synth_ai/core/system.py +4 -0
- synth_ai/environments/__init__.py +35 -0
- synth_ai/environments/environment/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/base.py +50 -0
- synth_ai/environments/environment/core.py +22 -0
- synth_ai/environments/environment/db/__init__.py +1 -0
- synth_ai/environments/environment/db/sqlite.py +45 -0
- synth_ai/environments/environment/registry.py +24 -0
- synth_ai/environments/environment/resources/sqlite.py +46 -0
- synth_ai/environments/environment/results.py +1 -0
- synth_ai/environments/environment/rewards/__init__.py +1 -0
- synth_ai/environments/environment/rewards/core.py +28 -0
- synth_ai/environments/environment/shared_engine.py +26 -0
- synth_ai/environments/environment/tools/__init__.py +34 -0
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +58 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +51 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +872 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/test_crafter_react_agent.py +1110 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/engine.py +502 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/environment.py +255 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +228 -0
- synth_ai/environments/examples/enron/agent_demos/test_synth_react.py +535 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/enron/units/keyword_stats.py +111 -0
- synth_ai/environments/examples/enron/units/test_email_index.py +8 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +47 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +220 -0
- synth_ai/environments/examples/minigrid/agent_demos/test_minigrid_react_agent.py +393 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/minigrid/units/test_action_behavior.py +226 -0
- synth_ai/environments/examples/minigrid/units/test_debug_messages.py +83 -0
- synth_ai/environments/examples/minigrid/units/test_exploration.py +120 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_engine.py +214 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment.py +238 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment_mapping.py +301 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_taskset.py +210 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +832 -0
- synth_ai/environments/examples/nethack/agent_demos/test_nethack_react_agent.py +1112 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/nethack/units/test_nethack_engine.py +277 -0
- synth_ai/environments/examples/nethack/units/test_nethack_environment.py +281 -0
- synth_ai/environments/examples/nethack/units/test_nethack_taskset.py +213 -0
- synth_ai/environments/examples/nethack/units/test_recording.py +307 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/red/agent_demos/test_synth_react.py +1471 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/red/test_fixes.py +125 -0
- synth_ai/environments/examples/red/test_fixes_mock.py +148 -0
- synth_ai/environments/examples/red/units/__init__.py +1 -0
- synth_ai/environments/examples/red/units/test_basic_functionality.py +97 -0
- synth_ai/environments/examples/red/units/test_button_press_requirements.py +217 -0
- synth_ai/environments/examples/red/units/test_engine.py +192 -0
- synth_ai/environments/examples/red/units/test_environment.py +455 -0
- synth_ai/environments/examples/red/units/test_exploration_strategy.py +227 -0
- synth_ai/environments/examples/red/units/test_integration.py +217 -0
- synth_ai/environments/examples/red/units/test_memory_extraction.py +111 -0
- synth_ai/environments/examples/red/units/test_menu_bug_reproduction.py +1100 -0
- synth_ai/environments/examples/red/units/test_movement_debug.py +255 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_debug.py +163 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_verbose.py +117 -0
- synth_ai/environments/examples/red/units/test_red_basic.py +145 -0
- synth_ai/environments/examples/red/units/test_red_comprehensive.py +323 -0
- synth_ai/environments/examples/red/units/test_retry_movement.py +195 -0
- synth_ai/environments/examples/red/units/test_reward_components.py +186 -0
- synth_ai/environments/examples/red/units/test_rom_integration.py +260 -0
- synth_ai/environments/examples/red/units/test_taskset.py +116 -0
- synth_ai/environments/examples/red/units/test_tree.py +448 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +900 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_dspy_react.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_sokoban_react_agent.py +498 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_lats.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_locally.py +748 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_service.py +296 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/sokoban/units/astar_common.py +94 -0
- synth_ai/environments/examples/sokoban/units/test_building_task_set.py +49 -0
- synth_ai/environments/examples/sokoban/units/test_false_positive.py +120 -0
- synth_ai/environments/examples/sokoban/units/test_simple_run_through_environment.py +119 -0
- synth_ai/environments/examples/sokoban/units/test_sokoban_environment.py +98 -0
- synth_ai/environments/examples/sokoban/units/test_tree.py +364 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_synth_react.py +266 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_tictactoe_react_agent.py +470 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_engine.py +393 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_environment.py +493 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_taskset.py +191 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/agent_demos/test_synth_react.py +520 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- synth_ai/environments/examples/verilog/units/test_verilog_engine.py +466 -0
- synth_ai/environments/examples/verilog/units/test_verilog_environment.py +585 -0
- synth_ai/environments/examples/verilog/units/test_verilog_integration.py +383 -0
- synth_ai/environments/examples/verilog/units/test_verilog_taskset.py +457 -0
- synth_ai/environments/reproducibility/core.py +42 -0
- synth_ai/environments/reproducibility/tree.py +364 -0
- synth_ai/environments/service/app.py +78 -0
- synth_ai/environments/service/core_routes.py +775 -0
- synth_ai/environments/service/external_registry.py +57 -0
- synth_ai/environments/service/registry.py +9 -0
- synth_ai/environments/stateful/__init__.py +1 -0
- synth_ai/environments/stateful/core.py +28 -0
- synth_ai/environments/stateful/engine.py +21 -0
- synth_ai/environments/stateful/state.py +7 -0
- synth_ai/environments/tasks/api.py +19 -0
- synth_ai/environments/tasks/core.py +78 -0
- synth_ai/environments/tasks/filters.py +39 -0
- synth_ai/environments/tasks/utils.py +89 -0
- synth_ai/environments/v0_observability/history.py +3 -0
- synth_ai/environments/v0_observability/log.py +2 -0
- synth_ai/lm/caching/constants.py +1 -0
- synth_ai/{zyk/lms → lm}/caching/ephemeral.py +4 -8
- synth_ai/{zyk/lms → lm}/caching/handler.py +15 -15
- synth_ai/{zyk/lms → lm}/caching/initialize.py +2 -4
- synth_ai/{zyk/lms → lm}/caching/persistent.py +4 -10
- synth_ai/{zyk/lms → lm}/config.py +2 -1
- synth_ai/{zyk/lms → lm}/constants.py +2 -2
- synth_ai/{zyk/lms → lm}/core/all.py +10 -10
- synth_ai/{zyk/lms → lm}/core/main.py +57 -33
- synth_ai/{zyk/lms → lm}/core/vendor_clients.py +12 -10
- synth_ai/lm/cost/monitor.py +1 -0
- synth_ai/lm/cost/statefulness.py +1 -0
- synth_ai/lm/provider_support/__init__.py +8 -0
- synth_ai/lm/provider_support/anthropic.py +945 -0
- synth_ai/lm/provider_support/openai.py +1115 -0
- synth_ai/lm/provider_support/suppress_logging.py +31 -0
- synth_ai/{zyk/lms → lm}/structured_outputs/handler.py +58 -80
- synth_ai/{zyk/lms → lm}/structured_outputs/inject.py +6 -20
- synth_ai/{zyk/lms → lm}/structured_outputs/rehabilitate.py +6 -12
- synth_ai/{zyk/lms → lm}/vendors/core/anthropic_api.py +21 -30
- synth_ai/{zyk/lms → lm}/vendors/core/gemini_api.py +37 -32
- synth_ai/{zyk/lms → lm}/vendors/core/mistral_api.py +19 -28
- synth_ai/{zyk/lms → lm}/vendors/core/openai_api.py +26 -36
- synth_ai/{zyk/lms → lm}/vendors/openai_standard.py +29 -33
- synth_ai/{zyk/lms → lm}/vendors/retries.py +1 -1
- synth_ai/lm/vendors/supported/__init__.py +0 -0
- synth_ai/{zyk/lms → lm}/vendors/supported/custom_endpoint.py +131 -118
- synth_ai/{zyk/lms → lm}/vendors/supported/deepseek.py +4 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/grok.py +6 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/groq.py +1 -1
- synth_ai/{zyk/lms → lm}/vendors/supported/ollama.py +2 -2
- synth_ai/{zyk/lms → lm}/vendors/supported/openrouter.py +18 -16
- synth_ai/{zyk/lms → lm}/vendors/supported/together.py +1 -1
- synth_ai/tracing/__init__.py +0 -0
- synth_ai/tracing/abstractions.py +224 -0
- synth_ai/tracing/base_client.py +91 -0
- synth_ai/tracing/client_manager.py +131 -0
- synth_ai/tracing/config.py +140 -0
- synth_ai/tracing/context.py +146 -0
- synth_ai/tracing/decorators.py +679 -0
- synth_ai/tracing/events/__init__.py +0 -0
- synth_ai/tracing/events/manage.py +147 -0
- synth_ai/tracing/events/scope.py +86 -0
- synth_ai/tracing/events/store.py +227 -0
- synth_ai/tracing/immediate_client.py +152 -0
- synth_ai/tracing/local.py +18 -0
- synth_ai/tracing/log_client_base.py +74 -0
- synth_ai/tracing/retry_queue.py +187 -0
- synth_ai/tracing/trackers.py +515 -0
- synth_ai/tracing/upload.py +504 -0
- synth_ai/tracing/utils.py +9 -0
- synth_ai/zyk/__init__.py +28 -2
- synth_ai-0.2.1.dev0.dist-info/METADATA +349 -0
- synth_ai-0.2.1.dev0.dist-info/RECORD +261 -0
- synth_ai/zyk/lms/caching/constants.py +0 -1
- synth_ai/zyk/lms/cost/monitor.py +0 -1
- synth_ai/zyk/lms/cost/statefulness.py +0 -1
- synth_ai-0.1.9.dist-info/METADATA +0 -37
- synth_ai-0.1.9.dist-info/RECORD +0 -50
- /synth_ai/{zyk/lms/__init__.py → environments/reproducibility/helpers.py} +0 -0
- /synth_ai/{zyk/lms/caching → lm}/__init__.py +0 -0
- /synth_ai/{zyk/lms/core → lm/caching}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/caching/dbs.py +0 -0
- /synth_ai/{zyk/lms/cost → lm/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/core/exceptions.py +0 -0
- /synth_ai/{zyk/lms/structured_outputs → lm/cost}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors → lm/structured_outputs}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/core → lm/vendors}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/local → lm/vendors/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors/supported → lm/vendors/local}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/local/ollama.py +0 -0
- {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/WHEEL +0 -0
- {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,328 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
import shutil
|
3
|
+
import subprocess
|
4
|
+
from pathlib import Path
|
5
|
+
from typing import Dict, Any, Tuple, Optional
|
6
|
+
from dataclasses import dataclass
|
7
|
+
|
8
|
+
from synth_ai.environments.stateful.engine import StatefulEngine, StatefulEngineSnapshot
|
9
|
+
from synth_ai.environments.tasks.core import TaskInstance
|
10
|
+
from synth_ai.environments.environment.rewards.core import RewardStack, RewardComponent
|
11
|
+
|
12
|
+
|
13
|
+
@dataclass
class VerilogEngineSnapshot(StatefulEngineSnapshot):
    """Serializable snapshot of a ``VerilogEngine``.

    The snapshot is two plain dictionaries: the serialized task instance and
    the engine's own bookkeeping data.
    """

    # Result of ``task_instance.serialize()`` at snapshot time.
    task_instance_dict: Dict
    # Engine bookkeeping written by ``VerilogEngine._serialize_engine``
    # (total reward plus the snapshot/build directory paths as strings).
    engine_snapshot: Dict

    def model_dump(self) -> Dict:
        """Convert dataclass to dictionary for compatibility with Pydantic models."""
        # The dictionaries are returned by reference, not copied.
        return {
            "task_instance_dict": self.task_instance_dict,
            "engine_snapshot": self.engine_snapshot,
        }
|
24
|
+
|
25
|
+
|
26
|
+
@dataclass
class VerilogPublicState:
    """Agent-visible view of the Verilog workspace after reset or a step."""

    # Verilog file contents keyed by path relative to the snapshot directory.
    files: Dict[str, str]
    # Build directory path, stored as a string.
    build_dir: str
    # True once an action result reported ``passed``.
    task_completed: bool = False
    # Output recorded from the most recent compile action, if any.
    last_compile_output: Optional[str] = None
    # Stdout recorded from the most recent simulate action, if any.
    last_simulate_output: Optional[str] = None
|
33
|
+
|
34
|
+
|
35
|
+
@dataclass
class VerilogPrivateState:
    """Engine-side bookkeeping returned alongside the public state."""

    # Reward produced by the most recent step (0.0 right after reset).
    reward_last: float
    # Sum of all step rewards since the last reset.
    total_reward: float
    # True when an action result reported ``passed`` or ``submitted``.
    terminated: bool
    # The engine always sets this to False in the visible code.
    truncated: bool
|
41
|
+
|
42
|
+
|
43
|
+
class VerilogCompileSuccessComponent(RewardComponent):
    """Reward component that pays 0.1 for a successful compile action."""

    async def score(self, state: VerilogPublicState, action: Any) -> float:
        """Return 0.1 when ``action`` is a compile result with returncode 0, else 0.0.

        ``action`` is treated as a mapping only if it exposes ``.get``; any
        other object scores 0.0. ``state`` is not consulted.
        """
        if not hasattr(action, "get"):
            return 0.0
        if action.get("type") != "compile":
            return 0.0
        # Check if compilation was successful (returncode 0)
        return 0.1 if action.get("returncode") == 0 else 0.0
|
50
|
+
|
51
|
+
|
52
|
+
class VerilogSimulationPassComponent(RewardComponent):
    """Reward component that pays 1.0 for a simulate action that passed."""

    async def score(self, state: VerilogPublicState, action: Any) -> float:
        """Return 1.0 when ``action`` is a simulate result flagged ``passed``, else 0.0.

        ``action`` is treated as a mapping only if it exposes ``.get``; any
        other object scores 0.0. ``state`` is not consulted.
        """
        if not hasattr(action, "get"):
            return 0.0
        if action.get("type") != "simulate":
            return 0.0
        # Check if simulation passed
        return 1.0 if action.get("passed", False) else 0.0
|
59
|
+
|
60
|
+
|
61
|
+
class VerilogStepPenaltyComponent(RewardComponent):
    """Reward component that returns a fixed value on every step."""

    def __init__(self, penalty: float = -0.01):
        """Store the per-step value; the default is a small negative reward."""
        self.penalty = penalty

    async def score(self, state: Any, action: Any) -> float:
        """Return the configured penalty regardless of ``state`` and ``action``."""
        return self.penalty
|
67
|
+
|
68
|
+
|
69
|
+
class VerilogEngine(StatefulEngine):
    """
    Stateful Verilog evaluation engine with persistent artifact snapshots.

    The engine works on the files under ``snapshot_dir`` and compiles into
    ``snapshot_dir / "build"``. ``write_file``, ``compile``, ``simulate`` and
    ``submit`` each return a plain result dict; ``_step_engine`` consumes such
    a dict to update rewards and produce the next private/public state.

    Note: ``compile`` and ``simulate`` call ``subprocess.run`` synchronously
    inside ``async`` methods, so they block for up to 30 seconds each.
    """

    def __init__(self, task_instance: TaskInstance):
        """Create an engine for ``task_instance``; directories are set on reset."""
        self.task_instance = task_instance
        self._total_reward = 0.0
        self._current_action_for_reward: Optional[Dict[str, Any]] = None

        self.reward_stack = RewardStack(
            components=[
                VerilogCompileSuccessComponent(),
                VerilogSimulationPassComponent(),
                VerilogStepPenaltyComponent(penalty=-0.01),
            ]
        )

        # Initialize paths - will be set properly in _reset_engine
        self.snapshot_dir: Optional[Path] = None
        self.build_dir: Optional[Path] = None

        # Track last compile/simulate outputs
        self._last_compile_output: Optional[str] = None
        self._last_simulate_output: Optional[str] = None

    async def _reset_engine(
        self, *, seed: Optional[int] = None
    ) -> Tuple[VerilogPrivateState, VerilogPublicState]:
        """Initialize the Verilog environment with task files.

        Clears reward and output tracking, prepares the snapshot directory and
        returns the initial ``(private, public)`` state pair. ``seed`` is
        accepted for interface compatibility but is not used here.
        """
        self._total_reward = 0.0
        self._current_action_for_reward = None
        self._last_compile_output = None
        self._last_simulate_output = None

        # Initialize snapshot from task instance
        self._init_snapshot()

        priv = VerilogPrivateState(
            reward_last=0.0, total_reward=0.0, terminated=False, truncated=False
        )

        pub = VerilogPublicState(
            files=self._get_file_contents(),
            build_dir=str(self.build_dir),
            task_completed=False,
        )

        return priv, pub

    async def _step_engine(
        self, action_result: Dict[str, Any]
    ) -> Tuple[VerilogPrivateState, VerilogPublicState]:
        """Process an action result and update engine state.

        Args:
            action_result: Result dict produced by one of the action methods
                (``write_file``, ``compile``, ``simulate``, ``submit``).

        Returns:
            The ``(private, public)`` state pair after applying the result.
        """
        self._current_action_for_reward = action_result

        # Update last outputs if this is a compile or simulate action
        action_type = action_result.get("type")
        if action_type == "compile":
            stdout = action_result.get("stdout", "")
            stderr = action_result.get("stderr", "")
            # Prefer stderr (it carries the error info); fall back to stdout
            # only when stderr is empty. The two are not concatenated.
            self._last_compile_output = stderr if stderr else stdout
        elif action_type == "simulate":
            self._last_simulate_output = action_result.get("stdout")

        passed = action_result.get("passed", False)

        # Build the public state once: the snapshot directory is read from
        # disk a single time and the same state is used both for reward
        # scoring and as the returned observation.
        pub = VerilogPublicState(
            files=self._get_file_contents(),
            build_dir=str(self.build_dir),
            task_completed=passed,
            last_compile_output=self._last_compile_output,
            last_simulate_output=self._last_simulate_output,
        )

        # Calculate reward using RewardStack
        reward_from_stack = await self.reward_stack.step_reward(
            state=pub, action=self._current_action_for_reward
        )
        self._current_action_for_reward = None

        self._total_reward += reward_from_stack

        # Check termination conditions
        terminated = passed or action_result.get("submitted", False)

        priv = VerilogPrivateState(
            reward_last=reward_from_stack,
            total_reward=self._total_reward,
            terminated=terminated,
            truncated=False,
        )

        return priv, pub

    def _init_snapshot(self) -> None:
        """Initialize snapshot directory from task instance data.

        If the snapshot directory is missing or empty it is populated from the
        task's ``pristine_dir`` when that exists, otherwise created empty. The
        ``build`` subdirectory is created in every case.

        Raises:
            ValueError: If the task instance has no ``snapshot_dir`` attribute.
        """
        if not hasattr(self.task_instance, "snapshot_dir"):
            raise ValueError("Task instance must have a snapshot_dir attribute")

        self.snapshot_dir = Path(self.task_instance.snapshot_dir)

        already_initialized = self.snapshot_dir.exists() and any(self.snapshot_dir.iterdir())
        if not already_initialized:
            # Copy pristine files from task data
            pristine_dir = getattr(self.task_instance, "pristine_dir", None)
            if pristine_dir and Path(pristine_dir).exists():
                shutil.copytree(pristine_dir, self.snapshot_dir, dirs_exist_ok=True)
            else:
                # Create basic structure if no pristine dir
                self.snapshot_dir.mkdir(parents=True, exist_ok=True)

        self.build_dir = self.snapshot_dir / "build"
        self.build_dir.mkdir(exist_ok=True)

    def _get_file_contents(self) -> Dict[str, str]:
        """Get contents of all Verilog files in the snapshot directory.

        Returns:
            Mapping of snapshot-relative path to file text for every ``*.v``
            file found recursively; empty if the snapshot is not initialized.
        """
        if not self.snapshot_dir:
            return {}

        files = {}
        for p in self.snapshot_dir.rglob("*.v"):
            try:
                relative_path = p.relative_to(self.snapshot_dir)
                files[str(relative_path)] = p.read_text()
            except (OSError, ValueError):
                # Best effort: skip files that cannot be read or decoded
                # (UnicodeDecodeError is a ValueError) instead of failing
                # the whole observation.
                continue
        return files

    async def write_file(self, path: str, content: str) -> Dict[str, Any]:
        """Write content to a file in the snapshot directory.

        Args:
            path: Destination path, interpreted relative to the snapshot
                directory. Paths that resolve outside of it are rejected.
            content: Text to write.

        Returns:
            ``{"ok": True, "type": "write_file"}`` on success, otherwise a
            dict with ``"ok": False`` and an ``"error"`` message.
        """
        if not self.snapshot_dir:
            return {"ok": False, "error": "Snapshot directory not initialized"}

        # ``path`` comes from tool-call arguments: an absolute path or ``..``
        # segments would otherwise let a write land outside the snapshot.
        root = self.snapshot_dir.resolve()
        file_path = (root / path).resolve()
        try:
            file_path.relative_to(root)
        except ValueError:
            return {
                "ok": False,
                "error": f"Path escapes snapshot directory: {path}",
                "type": "write_file",
            }

        file_path.parent.mkdir(parents=True, exist_ok=True)
        file_path.write_text(content)
        return {"ok": True, "type": "write_file"}

    async def compile(
        self, sources: Optional[list] = None, testbench: Optional[str] = None
    ) -> Dict[str, Any]:
        """Compile Verilog sources with iverilog.

        Args:
            sources: Snapshot-relative source paths. Defaults to every ``*.v``
                file directly inside the snapshot directory (non-recursive).
            testbench: Optional snapshot-relative testbench path; it is added
                to the compile only if the file exists.

        Returns:
            A result dict with ``ok``, ``type`` and, when iverilog ran,
            ``stdout``, ``stderr``, ``returncode`` and ``binary``. On timeout
            or any other failure an ``error`` message is returned instead.
        """
        if not self.snapshot_dir or not self.build_dir:
            return {"ok": False, "error": "Directories not initialized"}

        # Default to all .v files if no sources specified
        if sources is None:
            sources = [str(p.relative_to(self.snapshot_dir)) for p in self.snapshot_dir.glob("*.v")]

        src_paths = [self.snapshot_dir / src for src in sources]

        # Add testbench if specified
        if testbench:
            tb_path = self.snapshot_dir / testbench
            # Skip it when it is already among the sources (e.g. picked up by
            # the default glob) so the same file is not passed twice, which
            # would declare its modules a second time.
            if tb_path.exists() and tb_path not in src_paths:
                src_paths.append(tb_path)

        binary = self.build_dir / "a.out"
        cmd = ["iverilog", "-g2012", "-o", str(binary)] + [str(p) for p in src_paths]

        try:
            proc = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
            return {
                "ok": proc.returncode == 0,
                "type": "compile",
                "stdout": proc.stdout,
                "stderr": proc.stderr,
                "returncode": proc.returncode,
                "binary": str(binary) if proc.returncode == 0 else None,
            }
        except subprocess.TimeoutExpired:
            return {"ok": False, "error": "Compilation timeout", "type": "compile"}
        except Exception as e:
            # Broad on purpose: failures such as a missing iverilog executable
            # are reported as a failed result rather than raised.
            return {"ok": False, "error": str(e), "type": "compile"}

    async def simulate(self, binary: Optional[str] = None) -> Dict[str, Any]:
        """Run vvp on compiled binary.

        Args:
            binary: Path of the compiled binary; defaults to ``a.out`` in the
                build directory.

        Returns:
            A result dict with ``ok``, ``type`` and, when vvp ran, ``stdout``,
            ``stderr``, ``returncode`` and ``passed``. ``ok`` is True whenever
            vvp ran to completion, independent of its return code; ``passed``
            is derived from marker strings in stdout.
        """
        if not self.build_dir:
            return {"ok": False, "error": "Build directory not initialized"}

        bin_path = binary if binary else str(self.build_dir / "a.out")

        try:
            proc = subprocess.run(["vvp", bin_path], capture_output=True, text=True, timeout=30)

            # Check for various success indicators
            stdout = proc.stdout
            lowered = stdout.lower()
            passed = (
                "ALL_TESTS_PASSED" in stdout
                or ("Mismatches: 0 " in stdout and "samples" in stdout)
                or ("no mismatches" in lowered and "errors" not in lowered)
            )

            return {
                "ok": True,
                "type": "simulate",
                "stdout": proc.stdout,
                "stderr": proc.stderr,
                "returncode": proc.returncode,
                "passed": passed,
            }
        except subprocess.TimeoutExpired:
            return {"ok": False, "error": "Simulation timeout", "type": "simulate"}
        except Exception as e:
            # Broad on purpose: failures such as a missing vvp executable are
            # reported as a failed result rather than raised.
            return {"ok": False, "error": str(e), "type": "simulate"}

    async def submit(self) -> Dict[str, Any]:
        """Submit solution for grading.

        NOTE(review): ``passed`` is a hard-coded placeholder, so every submit
        is reported as passing and ends the episode in ``_step_engine``.
        """
        # For now, simple check based on last simulation
        # In a full implementation, this would call the task's verify method
        return {
            "ok": True,
            "type": "submit",
            "passed": True,  # Placeholder
            "detail": "Submission processed",
            "submitted": True,
        }

    async def _serialize_engine(self) -> VerilogEngineSnapshot:
        """Serialize engine state to a snapshot.

        Only the total reward and the two directory paths are captured; the
        last compile/simulate outputs are not part of the snapshot.
        """
        engine_data = {
            "total_reward": self._total_reward,
            "snapshot_dir": str(self.snapshot_dir) if self.snapshot_dir else None,
            "build_dir": str(self.build_dir) if self.build_dir else None,
        }

        task_instance_dict = await self.task_instance.serialize()

        return VerilogEngineSnapshot(
            task_instance_dict=task_instance_dict, engine_snapshot=engine_data
        )

    @classmethod
    async def _deserialize_engine(cls, snapshot: VerilogEngineSnapshot) -> "VerilogEngine":
        """Deserialize engine from snapshot.

        Rebuilds the task instance, constructs a fresh engine and restores the
        total reward and directory paths recorded in the snapshot.
        """
        # This would need proper task instance deserialization
        # For now, create a minimal implementation
        # Imported here, not at module level (presumably to avoid a circular
        # import with the taskset module - TODO confirm).
        from synth_ai.environments.examples.verilog.taskset import VerilogTaskInstance

        task_instance = await VerilogTaskInstance.deserialize(snapshot.task_instance_dict)
        engine = cls(task_instance)

        engine_data = snapshot.engine_snapshot
        engine._total_reward = engine_data.get("total_reward", 0.0)

        if engine_data.get("snapshot_dir"):
            engine.snapshot_dir = Path(engine_data["snapshot_dir"])
        if engine_data.get("build_dir"):
            engine.build_dir = Path(engine_data["build_dir"])

        return engine
|
@@ -0,0 +1,349 @@
|
|
1
|
+
from typing import List, Optional, Any, Dict, Union
|
2
|
+
from pydantic import BaseModel
|
3
|
+
|
4
|
+
from synth_ai.environments.examples.verilog.engine import (
|
5
|
+
VerilogEngine,
|
6
|
+
VerilogPrivateState,
|
7
|
+
VerilogPublicState,
|
8
|
+
VerilogEngineSnapshot,
|
9
|
+
)
|
10
|
+
from synth_ai.environments.environment.shared_engine import (
|
11
|
+
GetObservationCallable,
|
12
|
+
InternalObservation,
|
13
|
+
)
|
14
|
+
from synth_ai.environments.stateful.core import StatefulEnvironment
|
15
|
+
from synth_ai.environments.tasks.core import TaskInstance
|
16
|
+
from synth_ai.environments.environment.tools import (
|
17
|
+
AbstractTool,
|
18
|
+
EnvToolCall,
|
19
|
+
ToolResult,
|
20
|
+
TOOL_REGISTRY,
|
21
|
+
register_tool,
|
22
|
+
)
|
23
|
+
|
24
|
+
|
25
|
+
# Tool Input Schemas
|
26
|
+
class WriteFileInput(BaseModel):
    # Argument schema for the "write_file" tool.

    # Destination of the file; forwarded unchanged to the engine's
    # write_file call (presumably relative to the engine workspace --
    # confirm against the engine).
    path: str
    # Full text to store in the file.
    content: str
|
29
|
+
|
30
|
+
|
31
|
+
class CompileInput(BaseModel):
    # Argument schema for the "compile" tool. Both fields are optional and
    # are forwarded as-is to the engine's compile call; how ``None`` is
    # interpreted is decided by the engine.

    # Source files to compile.
    sources: Optional[List[str]] = None
    # Testbench file to compile alongside the sources.
    testbench: Optional[str] = None
|
34
|
+
|
35
|
+
|
36
|
+
class SimulateInput(BaseModel):
    # Argument schema for the "simulate" tool.

    # Compiled binary to run; forwarded as-is to the engine's simulate
    # call, which decides what ``None`` means.
    binary: Optional[str] = None
|
38
|
+
|
39
|
+
|
40
|
+
class SubmitInput(BaseModel):
    # Argument schema for the "submit" tool: submitting takes no arguments,
    # so the model intentionally declares no fields.
    pass
|
42
|
+
|
43
|
+
|
44
|
+
# Tool Implementations
|
45
|
+
class VerilogWriteFileTool(AbstractTool):
    """Tool wrapper that writes a file through a ``VerilogEngine``."""

    name = "write_file"
    description = "Write content to a Verilog file"
    call_schema = WriteFileInput
    result_schema = ToolResult

    def __init__(self, engine: VerilogEngine):
        # Engine that performs the actual write.
        self.engine = engine

    async def __call__(self, call: EnvToolCall) -> ToolResult:
        """Validate the call arguments and delegate to ``engine.write_file``.

        Any exception (validation failure, engine error, missing ``"ok"``
        key in the engine result) is converted into a failed ``ToolResult``
        carrying the exception text instead of being raised.
        """
        try:
            params = self.call_schema(**call.args)
            outcome = await self.engine.write_file(params.path, params.content)
            succeeded = outcome["ok"]
            return ToolResult(ok=succeeded, payload=outcome)
        except Exception as exc:
            return ToolResult(ok=False, error=str(exc))
|
61
|
+
|
62
|
+
|
63
|
+
class VerilogCompileTool(AbstractTool):
    """Tool wrapper that compiles sources through a ``VerilogEngine``."""

    name = "compile"
    description = "Compile Verilog sources with iverilog"
    call_schema = CompileInput
    result_schema = ToolResult

    def __init__(self, engine: VerilogEngine):
        # Engine that performs the actual compilation.
        self.engine = engine

    async def __call__(self, call: EnvToolCall) -> ToolResult:
        """Validate the call arguments and delegate to ``engine.compile``.

        Any exception (validation failure, engine error, missing ``"ok"``
        key in the engine result) is converted into a failed ``ToolResult``
        carrying the exception text instead of being raised.
        """
        try:
            params = self.call_schema(**call.args)
            outcome = await self.engine.compile(params.sources, params.testbench)
            succeeded = outcome["ok"]
            return ToolResult(ok=succeeded, payload=outcome)
        except Exception as exc:
            return ToolResult(ok=False, error=str(exc))
|
79
|
+
|
80
|
+
|
81
|
+
class VerilogSimulateTool(AbstractTool):
    """Tool wrapper that runs a simulation through a ``VerilogEngine``."""

    name = "simulate"
    description = "Run vvp on compiled binary"
    call_schema = SimulateInput
    result_schema = ToolResult

    def __init__(self, engine: VerilogEngine):
        # Engine that performs the actual simulation run.
        self.engine = engine

    async def __call__(self, call: EnvToolCall) -> ToolResult:
        """Validate the call arguments and delegate to ``engine.simulate``.

        Any exception (validation failure, engine error, missing ``"ok"``
        key in the engine result) is converted into a failed ``ToolResult``
        carrying the exception text instead of being raised.
        """
        try:
            params = self.call_schema(**call.args)
            outcome = await self.engine.simulate(params.binary)
            succeeded = outcome["ok"]
            return ToolResult(ok=succeeded, payload=outcome)
        except Exception as exc:
            return ToolResult(ok=False, error=str(exc))
|
97
|
+
|
98
|
+
|
99
|
+
class VerilogSubmitTool(AbstractTool):
    """Tool wrapper that submits the current solution through a ``VerilogEngine``."""

    name = "submit"
    description = "Submit solution for grading"
    call_schema = SubmitInput
    result_schema = ToolResult

    def __init__(self, engine: VerilogEngine):
        # Engine that performs the actual submission.
        self.engine = engine

    async def __call__(self, call: EnvToolCall) -> ToolResult:
        """Delegate to ``engine.submit``; the call's arguments are not read.

        Any exception (engine error, missing ``"ok"`` key in the engine
        result) is converted into a failed ``ToolResult`` carrying the
        exception text instead of being raised.
        """
        try:
            outcome = await self.engine.submit()
            succeeded = outcome["ok"]
            return ToolResult(ok=succeeded, payload=outcome)
        except Exception as exc:
            return ToolResult(ok=False, error=str(exc))
|
114
|
+
|
115
|
+
|
116
|
+
class VerilogObservationCallable(GetObservationCallable):
    """Default observation builder for the Verilog environment."""

    async def get_observation(
        self, pub: VerilogPublicState, priv: VerilogPrivateState
    ) -> InternalObservation:
        """Combine public and private engine state into an observation dict.

        Besides the raw state fields, the observation carries three
        human-readable strings: a file summary, a compile status and a
        simulation status. The two statuses are derived by searching the
        captured tool output for marker substrings, and are empty when no
        corresponding output is available.
        """
        # File listing, e.g. "2 Verilog files available: a.v, b.v".
        files_summary = f"{len(pub.files)} Verilog files available"
        if pub.files:
            joined_names = ", ".join(pub.files.keys())
            files_summary += f": {joined_names}"

        # Compile status: any of the marker words (case-insensitive) in the
        # output is treated as a failed compile, and the raw output is then
        # appended so the agent can debug it.
        compile_status = ""
        if pub.last_compile_output is not None:
            lowered_compile = pub.last_compile_output.lower()
            failure_markers = ["error", "failed", "syntax"]
            compile_failed = any(marker in lowered_compile for marker in failure_markers)
            compile_status = (
                f"Last compile: Failed\n{pub.last_compile_output}"
                if compile_failed
                else "Last compile: Success"
            )

        # Simulation status: meant to mirror the pass detection used by the
        # engine (per the original comment), so the markers are kept verbatim.
        simulate_status = ""
        if pub.last_simulate_output:
            sim_text = pub.last_simulate_output
            if "ALL_TESTS_PASSED" in sim_text:
                passed = True
            elif "Mismatches: 0 " in sim_text and "samples" in sim_text:
                passed = True
            else:
                lowered_sim = sim_text.lower()
                passed = "no mismatches" in lowered_sim and "errors" not in lowered_sim
            verdict = "Passed" if passed else "Failed"
            simulate_status = f"Last simulation: {verdict}"

        observation: Dict[str, Any] = {
            "files": pub.files,
            "build_dir": pub.build_dir,
            "files_summary": files_summary,
            "task_completed": pub.task_completed,
            "reward_last": priv.reward_last,
            "total_reward": priv.total_reward,
            "terminated": priv.terminated,
            "compile_status": compile_status,
            "simulate_status": simulate_status,
        }
        return observation  # type: ignore[return-value]
|
160
|
+
|
161
|
+
|
162
|
+
class VerilogEnvironment(StatefulEnvironment):
    """Stateful environment driving a ``VerilogEngine`` through tool calls.

    The environment owns one engine and one tool instance per supported
    tool name (``write_file``, ``compile``, ``simulate``, ``submit``), all
    bound to that engine. Observations are produced by a
    ``GetObservationCallable`` (``VerilogObservationCallable`` by default).
    """

    def __init__(
        self,
        task_instance: TaskInstance,
        custom_obs: Optional[GetObservationCallable] = None,
    ):
        """Create the engine, its tools, and register the tools globally.

        Args:
            task_instance: Task the engine is constructed from.
            custom_obs: Optional observation builder replacing the default.
        """
        self.name = "VerilogEval"
        self.task_instance = task_instance
        self.custom_observation_callable = custom_obs or VerilogObservationCallable()
        self.engine: VerilogEngine = VerilogEngine(task_instance)

        # Initialize tools
        self._tools_instances = self._build_tools(self.engine)

        # Register tools (only names not already present in the registry,
        # so an existing registration is never overwritten).
        for tool_name, tool_instance in self._tools_instances.items():
            if tool_name not in TOOL_REGISTRY:
                register_tool(tool_instance)

    @staticmethod
    def _build_tools(engine: VerilogEngine) -> Dict[str, AbstractTool]:
        """Return a fresh tool-name -> tool-instance map bound to ``engine``."""
        return {
            "write_file": VerilogWriteFileTool(engine),
            "compile": VerilogCompileTool(engine),
            "simulate": VerilogSimulateTool(engine),
            "submit": VerilogSubmitTool(engine),
        }

    async def initialize(self) -> InternalObservation:
        """Reset the engine and return the initial observation."""
        priv, pub = await self.engine._reset_engine()
        return await self._to_observation(priv, pub)

    async def terminate(self) -> InternalObservation:
        """Return a final observation flagged as terminated.

        The observation is built from the engine's current files, build
        directory and total reward; if reading that state raises, an empty
        zero-reward state is used instead (deliberate best-effort).
        """
        # Get current state and mark as terminated
        try:
            # Try to get current state from engine
            current_files = self.engine._get_file_contents()
            build_dir = str(self.engine.build_dir) if self.engine.build_dir else ""

            priv = VerilogPrivateState(
                reward_last=0.0,
                total_reward=self.engine._total_reward,
                terminated=True,
                truncated=False,
            )

            pub = VerilogPublicState(files=current_files, build_dir=build_dir, task_completed=False)
        except Exception:
            # Fallback if engine state is not accessible
            priv = VerilogPrivateState(
                reward_last=0.0, total_reward=0.0, terminated=True, truncated=False
            )

            pub = VerilogPublicState(files={}, build_dir="", task_completed=False)

        obs = await self._to_observation(priv, pub)
        if isinstance(obs, dict):
            obs["terminated"] = True
            obs["message"] = "Environment terminated."
        return obs

    def validate_tool_calls(
        self,
        tool_calls: Union[
            EnvToolCall,
            List[Dict[str, Any]],
            List[List[Dict[str, Any]]],
            Dict[str, Any],
        ],
    ) -> EnvToolCall:
        """Normalize and validate tool calls to a single EnvToolCall.

        Accepts an ``EnvToolCall``, a dict with ``"tool"`` and optional
        ``"args"`` keys, a list of those, or a list of lists of dicts. For
        list inputs only the first element (or the first element of the
        first inner list) is used; the remaining entries are ignored.

        Raises:
            ValueError: On an empty (inner) list or an unknown tool name.
            TypeError: On an unsupported input or element type.
        """
        raw_call_data: Dict[str, Any]

        if isinstance(tool_calls, list):
            if not tool_calls:
                raise ValueError("Received empty list of tool calls.")
            first_item = tool_calls[0]
            if isinstance(first_item, list):
                if not first_item:
                    raise ValueError("Received empty inner list of tool calls.")
                raw_call_data = first_item[0]
            elif isinstance(first_item, dict):
                raw_call_data = first_item
            elif isinstance(first_item, EnvToolCall):
                return first_item
            else:
                raise TypeError(f"Unexpected type in tool_calls list: {type(first_item)}")
        elif isinstance(tool_calls, dict):
            raw_call_data = tool_calls
        elif isinstance(tool_calls, EnvToolCall):
            return tool_calls
        else:
            raise TypeError(f"Unexpected type for tool_calls: {type(tool_calls)}")

        if not isinstance(raw_call_data, dict):
            raise TypeError(f"Processed call data is not a dict: {type(raw_call_data)}")

        # Convert dict to EnvToolCall instance
        tool_name = raw_call_data.get("tool")
        tool_args = raw_call_data.get("args", {})

        valid_tools = {"write_file", "compile", "simulate", "submit"}
        if tool_name not in valid_tools:
            raise ValueError(f"Unknown tool: {tool_name}. Expected one of: {valid_tools}")

        return EnvToolCall(tool=tool_name, args=tool_args)

    async def step(
        self,
        tool_calls: Union[
            EnvToolCall,
            List[Dict[str, Any]],
            List[List[Dict[str, Any]]],
            Dict[str, Any],
        ],
    ) -> InternalObservation:
        """Execute one tool call and advance the engine.

        The call is normalized with ``validate_tool_calls``, dispatched to
        the matching tool, and the tool's payload (or an error record when
        the tool failed without a payload) is handed to
        ``engine._step_engine``. The resulting state is returned as an
        observation.

        Raises:
            ValueError: If no tool instance is found for the requested name.
        """
        agent_call = self.validate_tool_calls(tool_calls)

        # Get the appropriate tool: this environment's own instances first,
        # then the global registry as a fallback.
        tool_instance = self._tools_instances.get(agent_call.tool)
        if not tool_instance:
            tool_instance = TOOL_REGISTRY.get(agent_call.tool)
            if not tool_instance:
                raise ValueError(f"Tool '{agent_call.tool}' not found.")

        # Execute the tool
        tool_result: ToolResult = await tool_instance(agent_call)

        # Update engine state with tool result
        if tool_result.payload:
            action_result = tool_result.payload
        elif not tool_result.ok:
            action_result = {
                "ok": False,
                "error": tool_result.error,
                "type": agent_call.tool,
            }
        else:
            action_result = {}

        priv_state, pub_state = await self.engine._step_engine(action_result)

        return await self._to_observation(priv_state, pub_state)

    async def checkpoint(self) -> InternalObservation:
        """Serialize the engine and return an observation carrying the snapshot.

        The observation is built from the engine's current state; if that
        fails, a minimal ``{"message": "Checkpoint created"}`` dict is used.
        When the observation is a dict, the dumped snapshot is attached
        under ``"engine_snapshot_data"``.
        """
        engine_snapshot: VerilogEngineSnapshot = await self.engine._serialize_engine()

        # Get current state for observation
        try:
            current_files = self.engine._get_file_contents()
            build_dir = str(self.engine.build_dir) if self.engine.build_dir else ""

            priv = VerilogPrivateState(
                reward_last=0.0,
                total_reward=self.engine._total_reward,
                terminated=False,
                truncated=False,
            )

            pub = VerilogPublicState(files=current_files, build_dir=build_dir, task_completed=False)

            obs_data = await self._to_observation(priv, pub)
        except Exception:
            obs_data = {"message": "Checkpoint created"}

        if isinstance(obs_data, dict):
            obs_data["engine_snapshot_data"] = engine_snapshot.model_dump()

        return obs_data

    async def _to_observation(
        self,
        priv: VerilogPrivateState,
        pub: VerilogPublicState,
        extra_obs: Optional[Dict[str, Any]] = None,
    ) -> InternalObservation:
        """Build an observation via the configured callable.

        ``extra_obs`` entries, when given, are merged into the observation
        if the observation is a dict.
        """
        observation = await self.custom_observation_callable.get_observation(pub, priv)
        if extra_obs and isinstance(observation, dict):
            observation.update(extra_obs)
        return observation

    async def _serialize_engine(self) -> VerilogEngineSnapshot:
        """Return the engine's snapshot."""
        return await self.engine._serialize_engine()

    @classmethod
    async def _deserialize_engine(
        cls, snapshot: VerilogEngineSnapshot, task_instance: TaskInstance
    ) -> "VerilogEnvironment":
        """Create an environment whose engine is restored from ``snapshot``.

        Args:
            snapshot: Engine snapshot to restore.
            task_instance: Task instance used to construct the environment.

        Returns:
            A new environment whose engine and tools both refer to the
            restored engine.
        """
        eng = await VerilogEngine._deserialize_engine(snapshot)
        env = cls(task_instance)
        env.engine = eng
        # The tools created in __init__ hold a reference to the engine that
        # was just replaced; rebuild them so tool calls act on the restored
        # engine instead of the discarded one.
        env._tools_instances = cls._build_tools(eng)
        return env
|