synth-ai 0.2.4.dev4__py3-none-any.whl → 0.2.4.dev6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/debug_translation.py +0 -0
- synth_ai/environments/examples/crafter_classic/engine.py +579 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +266 -0
- synth_ai/environments/examples/crafter_classic/environment.py +364 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +233 -0
- synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +229 -0
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +298 -0
- synth_ai/environments/examples/crafter_custom/__init__.py +4 -0
- synth_ai/environments/examples/crafter_custom/crafter/__init__.py +7 -0
- synth_ai/environments/examples/crafter_custom/crafter/config.py +182 -0
- synth_ai/environments/examples/crafter_custom/crafter/constants.py +8 -0
- synth_ai/environments/examples/crafter_custom/crafter/engine.py +269 -0
- synth_ai/environments/examples/crafter_custom/crafter/env.py +266 -0
- synth_ai/environments/examples/crafter_custom/crafter/objects.py +418 -0
- synth_ai/environments/examples/crafter_custom/crafter/recorder.py +187 -0
- synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +119 -0
- synth_ai/environments/examples/crafter_custom/dataset_builder.py +373 -0
- synth_ai/environments/examples/crafter_custom/environment.py +312 -0
- synth_ai/environments/examples/crafter_custom/run_dataset.py +305 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- synth_ai/environments/examples/wordle/__init__.py +29 -0
- synth_ai/environments/examples/wordle/engine.py +391 -0
- synth_ai/environments/examples/wordle/environment.py +154 -0
- synth_ai/environments/examples/wordle/helpers/generate_instances_wordfreq.py +75 -0
- synth_ai/environments/examples/wordle/taskset.py +222 -0
- synth_ai/environments/service/app.py +8 -0
- synth_ai/environments/service/core_routes.py +38 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +163 -0
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +201 -0
- synth_ai/learning/prompts/mipro.py +273 -1
- synth_ai/learning/prompts/random_search.py +247 -0
- synth_ai/learning/prompts/run_mipro_banking77.py +160 -0
- synth_ai/learning/prompts/run_random_search_banking77.py +305 -0
- synth_ai/lm/injection.py +81 -0
- synth_ai/lm/overrides.py +204 -0
- synth_ai/lm/provider_support/anthropic.py +39 -12
- synth_ai/lm/provider_support/openai.py +31 -4
- synth_ai/lm/vendors/core/anthropic_api.py +16 -0
- synth_ai/lm/vendors/openai_standard.py +35 -5
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/METADATA +2 -1
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/RECORD +123 -13
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,214 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import random
|
4
|
+
from dataclasses import dataclass
|
5
|
+
from typing import List
|
6
|
+
import numpy as np
|
7
|
+
|
8
|
+
from uuid import uuid4
|
9
|
+
from synth_ai.environments.tasks.core import (
|
10
|
+
TaskInstance,
|
11
|
+
TaskInstanceMetadata,
|
12
|
+
TaskInstanceSet,
|
13
|
+
Impetus,
|
14
|
+
Intent,
|
15
|
+
SplitInfo,
|
16
|
+
)
|
17
|
+
|
18
|
+
from .engine import COORD_TO_IDX, WIN_PATTERNS, PLAYER_MARKS
|
19
|
+
|
20
|
+
|
21
|
+
@dataclass
class TicTacToeTaskInstanceMetadata(TaskInstanceMetadata):
    """Descriptive metadata attached to one generated TicTacToe task."""

    # Mark of the side that moves next: "X" or "O".
    starting_player: str
    # Coordinates already played, in order, to reach the start position.
    opening_moves: List[str]
    # "win", "draw" or "loss", seen from the starting player's side.
    optimal_outcome: str
    # Number of pre-moves that were played to build the position.
    position_complexity: int
    # Estimated minimum number of moves needed to force a win/draw.
    shortest_win_length: int
|
28
|
+
|
29
|
+
|
30
|
+
@dataclass
class TicTacToeTaskInstance(TaskInstance):
    """TicTacToe task instance that round-trips through a plain dict."""

    async def serialize(self) -> dict:
        """Return a dict representation of this instance.

        The id is stringified; every other value is stored as-is.
        """
        payload: dict = {}
        payload["id"] = str(self.id)
        payload["impetus"] = {"instructions": self.impetus.instructions}
        payload["intent"] = {
            "rubric": self.intent.rubric,
            "gold_trajectories": self.intent.gold_trajectories,
            "gold_state_diff": self.intent.gold_state_diff,
        }
        payload["metadata"] = {
            "starting_player": self.metadata.starting_player,
            "opening_moves": self.metadata.opening_moves,
            "optimal_outcome": self.metadata.optimal_outcome,
            "position_complexity": self.metadata.position_complexity,
            "shortest_win_length": self.metadata.shortest_win_length,
        }
        payload["is_reproducible"] = self.is_reproducible
        payload["initial_engine_snapshot"] = self.initial_engine_snapshot
        return payload

    @classmethod
    async def deserialize(cls, data: dict) -> "TicTacToeTaskInstance":
        """Rebuild an instance from a dict shaped like ``serialize`` output.

        Raises ``KeyError`` when a required key is absent and ``ValueError``
        when ``data["id"]`` is not a valid UUID string.
        """
        from uuid import UUID

        raw_meta = data["metadata"]
        restored_metadata = TicTacToeTaskInstanceMetadata(
            starting_player=raw_meta["starting_player"],
            opening_moves=raw_meta["opening_moves"],
            optimal_outcome=raw_meta["optimal_outcome"],
            position_complexity=raw_meta["position_complexity"],
            shortest_win_length=raw_meta["shortest_win_length"],
        )

        instance_id = UUID(data["id"])
        restored_impetus = Impetus(instructions=data["impetus"]["instructions"])
        raw_intent = data["intent"]
        restored_intent = Intent(
            rubric=raw_intent["rubric"],
            gold_trajectories=raw_intent["gold_trajectories"],
            gold_state_diff=raw_intent["gold_state_diff"],
        )

        return cls(
            id=instance_id,
            impetus=restored_impetus,
            intent=restored_intent,
            metadata=restored_metadata,
            is_reproducible=data["is_reproducible"],
            initial_engine_snapshot=data["initial_engine_snapshot"],
        )
|
76
|
+
|
77
|
+
|
78
|
+
def _evaluate_position(board: np.ndarray, player: int) -> str:
    """Classify ``board`` as "win", "loss" or "draw" for ``player``.

    Only already-completed lines are detected. Every other position, full
    board or not, is reported as "draw"; no game-tree search is performed.
    """
    # NOTE(review): ``3 - player`` presumes the two marks are 1 and 2 —
    # confirm against PLAYER_MARKS in the engine module.
    opponent = 3 - player

    for line in WIN_PATTERNS:
        marks = [board[cell] for cell in line]
        own_count = sum(1 for mark in marks if mark == player)
        if own_count == 3:
            return "win"
        opponent_count = sum(1 for mark in marks if mark == opponent)
        if opponent_count == 3:
            return "loss"

    # A full board with no completed line is a genuine draw.
    board_is_full = np.all(board != 0)
    if board_is_full:
        return "draw"

    # Unfinished positions are also labelled "draw" as a simplification;
    # a minimax search would be needed for the true perfect-play result.
    return "draw"
|
95
|
+
|
96
|
+
|
97
|
+
def _count_shortest_win(board: np.ndarray, player: int) -> int:
|
98
|
+
"""Count minimum moves to force a win/draw"""
|
99
|
+
# Simplified: return remaining empty cells
|
100
|
+
empty_cells = sum(1 for i in range(9) if board[i] == 0)
|
101
|
+
return max(1, empty_cells // 2)
|
102
|
+
|
103
|
+
|
104
|
+
async def create_tictactoe_taskset() -> TaskInstanceSet:
    """Build the procedural TicTacToe task set.

    Positions are produced by playing 0-3 random pre-moves (X first, marks
    alternating) from an empty board. Instances with exactly one pre-move
    form the validation split and those with two or more form the test
    split; a 15% / 15% positional split is used only if either would be
    empty.
    """
    # (pre_moves, count) per position family, in generation order.
    position_plan = (
        (0, 10),  # "opening": fresh games
        (1, 15),  # "early": after 1 move
        (2, 15),  # "mid": after 2 moves
        (3, 10),  # "complex": after 3 moves
    )

    coords = list(COORD_TO_IDX.keys())
    other_side = {"X": "O", "O": "X"}
    generated = []

    for pre_moves, how_many in position_plan:
        for _ in range(how_many):
            played = []
            cells = np.zeros(9, dtype=int)
            to_move = "X"
            remaining = coords.copy()

            # Play the random pre-moves, alternating marks.
            for _ply in range(pre_moves):
                if not remaining:
                    break
                pick = random.choice(remaining)
                played.append(pick)
                remaining.remove(pick)
                cells[COORD_TO_IDX[pick]] = PLAYER_MARKS[to_move]
                to_move = other_side[to_move]

            # Whoever is due to move next is the agent's side.
            outcome = _evaluate_position(cells, PLAYER_MARKS[to_move])
            win_length = _count_shortest_win(cells, PLAYER_MARKS[to_move])

            task_metadata = TicTacToeTaskInstanceMetadata(
                starting_player=to_move,
                opening_moves=played,
                optimal_outcome=outcome,
                position_complexity=pre_moves,
                shortest_win_length=win_length,
            )

            if played:
                history_note = f"The game has already had {len(played)} moves."
            else:
                history_note = "This is a fresh game."
            instruction_parts = [
                f"You are playing TicTacToe as {to_move}. ",
                "The game is played on a 3x3 grid with cells labeled A1-A3, B1-B3, C1-C3. ",
                history_note,
                f" You must place your mark ({to_move}) in an empty cell. ",
                "Win by getting three of your marks in a row (horizontally, vertically, or diagonally).",
            ]
            task_impetus = Impetus(instructions="".join(instruction_parts))

            task_intent = Intent(
                rubric={"goal": f"Win the game as {to_move}, or at least force a draw"},
                gold_trajectories=None,
                gold_state_diff={"optimal_outcome": outcome},
            )

            generated.append(
                TicTacToeTaskInstance(
                    id=uuid4(),
                    impetus=task_impetus,
                    intent=task_intent,
                    metadata=task_metadata,
                    is_reproducible=True,
                    initial_engine_snapshot=None,
                )
            )

    random.shuffle(generated)

    # Primary split rule: by number of pre-moves.
    val_ids = {task.id for task in generated if task.metadata.position_complexity == 1}
    test_ids = {task.id for task in generated if task.metadata.position_complexity >= 2}

    # Fallback: positional 15% / 15% split when either set came out empty.
    if not val_ids or not test_ids:
        total = len(generated)
        val_end = int(total * 0.15)
        test_end = int(total * 0.30)
        val_ids = {task.id for task in generated[:val_end]}
        test_ids = {task.id for task in generated[val_end:test_end]}

    splits = SplitInfo(
        val_instance_ids=val_ids, test_instance_ids=test_ids, _is_split_defined=True
    )

    return TaskInstanceSet(
        name="TicTacToe Procedural TaskSet",
        description="Procedurally generated TicTacToe tasks with varying starting positions.",
        instances=generated,
        split_info=splits,
    )
|
211
|
+
|
212
|
+
|
213
|
+
# Make taskset available as module attribute
|
214
|
+
# NOTE: ``taskset`` is bound to the async factory function itself, not to a
# built TaskInstanceSet; callers must call it and await the result.
taskset = create_tictactoe_taskset
|
@@ -0,0 +1,10 @@
|
|
1
|
+
from .engine import VerilogEngine
|
2
|
+
from .environment import VerilogEnvironment
|
3
|
+
from .taskset import VerilogTaskInstance, create_verilog_taskset
|
4
|
+
|
5
|
+
__all__ = [
|
6
|
+
"VerilogEngine",
|
7
|
+
"VerilogEnvironment",
|
8
|
+
"VerilogTaskInstance",
|
9
|
+
"create_verilog_taskset",
|
10
|
+
]
|
@@ -0,0 +1,328 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
import shutil
|
3
|
+
import subprocess
|
4
|
+
from pathlib import Path
|
5
|
+
from typing import Dict, Any, Tuple, Optional
|
6
|
+
from dataclasses import dataclass
|
7
|
+
|
8
|
+
from synth_ai.environments.stateful.engine import StatefulEngine, StatefulEngineSnapshot
|
9
|
+
from synth_ai.environments.tasks.core import TaskInstance
|
10
|
+
from synth_ai.environments.environment.rewards.core import RewardStack, RewardComponent
|
11
|
+
|
12
|
+
|
13
|
+
@dataclass
class VerilogEngineSnapshot(StatefulEngineSnapshot):
    """Serializable capture of a Verilog engine: task data plus engine bookkeeping."""

    # Serialized form of the task instance the engine was running.
    task_instance_dict: Dict
    # Engine-level values stored alongside the task.
    engine_snapshot: Dict

    def model_dump(self) -> Dict:
        """Return both fields as a plain dict, mirroring the Pydantic method name."""
        dumped: Dict = {}
        dumped["task_instance_dict"] = self.task_instance_dict
        dumped["engine_snapshot"] = self.engine_snapshot
        return dumped
|
24
|
+
|
25
|
+
|
26
|
+
@dataclass
class VerilogPublicState:
    """Agent-visible view of the Verilog workspace."""

    # Mapping of file path to file text.
    files: Dict[str, str]
    # Build directory path, stored as a string.
    build_dir: str
    # Whether the task is considered completed.
    task_completed: bool = False
    # Output of the most recent compile, if any.
    last_compile_output: Optional[str] = None
    # Output of the most recent simulation, if any.
    last_simulate_output: Optional[str] = None
|
33
|
+
|
34
|
+
|
35
|
+
@dataclass
class VerilogPrivateState:
    """Engine-side bookkeeping: rewards and episode-end flags."""

    # Reward produced by the most recent step.
    reward_last: float
    # Sum of rewards accumulated so far.
    total_reward: float
    # True once the episode has ended.
    terminated: bool
    # True when the episode was cut short.
    truncated: bool
|
41
|
+
|
42
|
+
|
43
|
+
class VerilogCompileSuccessComponent(RewardComponent):
    """Reward component that pays 0.1 for a compile action with return code 0."""

    async def score(self, state: VerilogPublicState, action: Any) -> float:
        """Return 0.1 for a successful compile action, otherwise 0.0.

        ``action`` only counts when it exposes ``get`` and its "type" entry
        is "compile"; ``state`` is not consulted.
        """
        if not hasattr(action, "get"):
            return 0.0

        is_compile = action.get("type") == "compile"
        if not is_compile:
            return 0.0

        # A return code of 0 marks a successful compilation.
        compiled_cleanly = action.get("returncode") == 0
        return 0.1 if compiled_cleanly else 0.0
|
50
|
+
|
51
|
+
|
52
|
+
class VerilogSimulationPassComponent(RewardComponent):
    """Reward component that pays 1.0 for a simulate action flagged as passed."""

    async def score(self, state: VerilogPublicState, action: Any) -> float:
        """Return 1.0 for a passing simulate action, otherwise 0.0.

        ``action`` only counts when it exposes ``get`` and its "type" entry
        is "simulate"; ``state`` is not consulted.
        """
        if not hasattr(action, "get"):
            return 0.0

        is_simulation = action.get("type") == "simulate"
        if not is_simulation:
            return 0.0

        if action.get("passed", False):
            return 1.0
        return 0.0
|
59
|
+
|
60
|
+
|
61
|
+
class VerilogStepPenaltyComponent(RewardComponent):
    """Reward component that returns the same fixed value on every step."""

    def __init__(self, penalty: float = -0.01):
        """Store the per-step value (negative by default, i.e. a cost)."""
        self.penalty = penalty

    async def score(self, state: Any, action: Any) -> float:
        """Return the configured penalty; ``state`` and ``action`` are ignored."""
        per_step_value = self.penalty
        return per_step_value
|
67
|
+
|
68
|
+
|
69
|
+
class VerilogEngine(StatefulEngine):
    """Stateful Verilog evaluation engine with persistent artifact snapshots.

    The engine works inside a snapshot directory (taken from the task
    instance) that holds the Verilog sources, plus a ``build``
    sub-directory for compiled output. The tool methods ``write_file``,
    ``compile``, ``simulate`` and ``submit`` return plain result dicts;
    ``_step_engine`` consumes such a dict to update reward and termination
    state.
    """

    def __init__(self, task_instance: TaskInstance):
        """Store the task and build the reward stack; paths are set on reset."""
        self.task_instance = task_instance
        self._total_reward = 0.0
        self._current_action_for_reward: Optional[Dict[str, Any]] = None

        self.reward_stack = RewardStack(
            components=[
                VerilogCompileSuccessComponent(),
                VerilogSimulationPassComponent(),
                VerilogStepPenaltyComponent(penalty=-0.01),
            ]
        )

        # Initialize paths - will be set properly in _reset_engine
        self.snapshot_dir: Optional[Path] = None
        self.build_dir: Optional[Path] = None

        # Track last compile/simulate outputs
        self._last_compile_output: Optional[str] = None
        self._last_simulate_output: Optional[str] = None

    async def _reset_engine(
        self, *, seed: Optional[int] = None
    ) -> Tuple[VerilogPrivateState, VerilogPublicState]:
        """Reset counters, prepare the snapshot directory, return initial state.

        ``seed`` is accepted but not used by this method.
        """
        self._total_reward = 0.0
        self._current_action_for_reward = None
        self._last_compile_output = None
        self._last_simulate_output = None

        # Initialize snapshot from task instance
        self._init_snapshot()

        priv = VerilogPrivateState(
            reward_last=0.0, total_reward=0.0, terminated=False, truncated=False
        )

        pub = VerilogPublicState(
            files=self._get_file_contents(),
            build_dir=str(self.build_dir),
            task_completed=False,
        )

        return priv, pub

    async def _step_engine(
        self, action_result: Dict[str, Any]
    ) -> Tuple[VerilogPrivateState, VerilogPublicState]:
        """Fold one tool result into the engine state.

        Args:
            action_result: Result dict as returned by ``write_file``,
                ``compile``, ``simulate`` or ``submit``.

        Returns:
            The new private state (rewards, termination) and public state
            (files and last tool outputs).
        """
        self._current_action_for_reward = action_result

        # Update last outputs if this is a compile or simulate action
        if action_result.get("type") == "compile":
            stdout = action_result.get("stdout", "")
            stderr = action_result.get("stderr", "")
            # Prefer stderr (it carries the error info); fall back to stdout
            # only when stderr is empty. The two are not concatenated.
            self._last_compile_output = stderr if stderr else stdout
        elif action_result.get("type") == "simulate":
            self._last_simulate_output = action_result.get("stdout")

        # Calculate reward using RewardStack
        current_pub_state = VerilogPublicState(
            files=self._get_file_contents(),
            build_dir=str(self.build_dir),
            task_completed=action_result.get("passed", False),
        )

        reward_from_stack = await self.reward_stack.step_reward(
            state=current_pub_state, action=self._current_action_for_reward
        )
        self._current_action_for_reward = None

        self._total_reward += reward_from_stack

        # The episode ends on a passing result or on any submission.
        terminated = action_result.get("passed", False) or action_result.get("submitted", False)

        priv = VerilogPrivateState(
            reward_last=reward_from_stack,
            total_reward=self._total_reward,
            terminated=terminated,
            truncated=False,
        )

        pub = VerilogPublicState(
            files=self._get_file_contents(),
            build_dir=str(self.build_dir),
            task_completed=action_result.get("passed", False),
            last_compile_output=self._last_compile_output,
            last_simulate_output=self._last_simulate_output,
        )

        return priv, pub

    def _init_snapshot(self) -> None:
        """Prepare ``snapshot_dir`` and ``build_dir`` from the task instance.

        A non-empty snapshot directory is reused as-is. Otherwise it is
        populated from the task's ``pristine_dir`` when that exists, or
        created empty.

        Raises:
            ValueError: If the task instance has no ``snapshot_dir`` attribute.
        """
        if not hasattr(self.task_instance, "snapshot_dir"):
            raise ValueError("Task instance must have a snapshot_dir attribute")

        self.snapshot_dir = Path(self.task_instance.snapshot_dir)

        if self.snapshot_dir.exists() and any(self.snapshot_dir.iterdir()):
            # Already initialized
            self.build_dir = self.snapshot_dir / "build"
            self.build_dir.mkdir(exist_ok=True)
            return

        # Copy pristine files from task data
        pristine_dir = getattr(self.task_instance, "pristine_dir", None)
        if pristine_dir and Path(pristine_dir).exists():
            shutil.copytree(pristine_dir, self.snapshot_dir, dirs_exist_ok=True)
        else:
            # Create basic structure if no pristine dir
            self.snapshot_dir.mkdir(parents=True, exist_ok=True)

        self.build_dir = self.snapshot_dir / "build"
        self.build_dir.mkdir(exist_ok=True)

    def _get_file_contents(self) -> Dict[str, str]:
        """Return ``{relative_path: text}`` for every ``*.v`` file in the snapshot.

        The search is recursive. Files that cannot be read or decoded are
        skipped. Returns an empty dict when no snapshot directory is set.
        """
        if not self.snapshot_dir:
            return {}

        files = {}
        for p in self.snapshot_dir.rglob("*.v"):
            try:
                relative_path = p.relative_to(self.snapshot_dir)
                files[str(relative_path)] = p.read_text()
            except (OSError, ValueError):
                # Best effort: unreadable files (OSError) and undecodable
                # ones (UnicodeDecodeError is a ValueError) are left out.
                continue
        return files

    async def write_file(self, path: str, content: str) -> Dict[str, Any]:
        """Write ``content`` to ``path`` inside the snapshot directory.

        Args:
            path: Destination, interpreted relative to the snapshot directory.
            content: Text to write; parent directories are created as needed.

        Returns:
            ``{"ok": True, "type": "write_file"}`` on success, or a dict
            with ``"ok": False`` and an ``"error"`` message when the
            snapshot directory is unset or ``path`` resolves outside it.
        """
        if not self.snapshot_dir:
            return {"ok": False, "error": "Snapshot directory not initialized"}

        # Resolve before writing so absolute paths and ".." segments cannot
        # place files outside the snapshot directory.
        root = self.snapshot_dir.resolve()
        file_path = (root / path).resolve()
        try:
            file_path.relative_to(root)
        except ValueError:
            return {
                "ok": False,
                "error": f"Path escapes snapshot directory: {path}",
                "type": "write_file",
            }

        file_path.parent.mkdir(parents=True, exist_ok=True)
        file_path.write_text(content)
        return {"ok": True, "type": "write_file"}

    async def compile(
        self, sources: Optional[list] = None, testbench: Optional[str] = None
    ) -> Dict[str, Any]:
        """Compile Verilog sources with iverilog into ``build_dir / "a.out"``.

        Args:
            sources: Paths relative to the snapshot directory. Defaults to
                every ``*.v`` file directly in the snapshot directory
                (non-recursive).
            testbench: Optional testbench path relative to the snapshot
                directory; added only if it exists and is not already
                among the sources.

        Returns:
            A result dict with ``"type": "compile"``; on a completed run it
            carries ``stdout``, ``stderr``, ``returncode`` and ``binary``
            (``None`` when the return code is non-zero).
        """
        if not self.snapshot_dir or not self.build_dir:
            return {"ok": False, "error": "Directories not initialized"}

        # Default to all .v files if no sources specified
        if sources is None:
            sources = [str(p.relative_to(self.snapshot_dir)) for p in self.snapshot_dir.glob("*.v")]

        src_paths = [self.snapshot_dir / src for src in sources]

        # Add the testbench once: with default sources it is usually already
        # in the list, and handing iverilog the same file twice makes it see
        # every definition in that file twice.
        if testbench:
            tb_path = self.snapshot_dir / testbench
            if tb_path.exists() and tb_path not in src_paths:
                src_paths.append(tb_path)

        binary = self.build_dir / "a.out"
        cmd = ["iverilog", "-g2012", "-o", str(binary)] + [str(p) for p in src_paths]

        # NOTE(review): subprocess.run is synchronous, so this coroutine
        # blocks its event loop for up to the 30 s timeout.
        try:
            proc = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
            return {
                "ok": proc.returncode == 0,
                "type": "compile",
                "stdout": proc.stdout,
                "stderr": proc.stderr,
                "returncode": proc.returncode,
                "binary": str(binary) if proc.returncode == 0 else None,
            }
        except subprocess.TimeoutExpired:
            return {"ok": False, "error": "Compilation timeout", "type": "compile"}
        except Exception as e:
            # Deliberate catch-all (e.g. iverilog missing): report, don't raise.
            return {"ok": False, "error": str(e), "type": "compile"}

    async def simulate(self, binary: Optional[str] = None) -> Dict[str, Any]:
        """Run ``vvp`` on a compiled binary and detect a passing run.

        Args:
            binary: Path of the binary; defaults to ``build_dir / "a.out"``.

        Returns:
            A result dict with ``"type": "simulate"``. ``passed`` is derived
            from marker strings in stdout, not from the return code, and
            ``ok`` is True whenever ``vvp`` ran to completion.
        """
        if not self.build_dir:
            return {"ok": False, "error": "Build directory not initialized"}

        bin_path = binary if binary else str(self.build_dir / "a.out")

        try:
            proc = subprocess.run(["vvp", bin_path], capture_output=True, text=True, timeout=30)

            # Check for various success indicators
            stdout = proc.stdout
            passed = (
                "ALL_TESTS_PASSED" in stdout
                or ("Mismatches: 0 " in stdout and "samples" in stdout)
                or ("no mismatches" in stdout.lower() and "errors" not in stdout.lower())
            )

            return {
                "ok": True,
                "type": "simulate",
                "stdout": proc.stdout,
                "stderr": proc.stderr,
                "returncode": proc.returncode,
                "passed": passed,
            }
        except subprocess.TimeoutExpired:
            return {"ok": False, "error": "Simulation timeout", "type": "simulate"}
        except Exception as e:
            # Deliberate catch-all (e.g. vvp missing): report, don't raise.
            return {"ok": False, "error": str(e), "type": "simulate"}

    async def submit(self) -> Dict[str, Any]:
        """Submit the solution for grading.

        Placeholder: always reports ``passed`` and ``submitted`` as True
        without inspecting any simulation result.
        """
        # For now, simple check based on last simulation
        # In a full implementation, this would call the task's verify method
        return {
            "ok": True,
            "type": "submit",
            "passed": True,  # Placeholder
            "detail": "Submission processed",
            "submitted": True,
        }

    async def _serialize_engine(self) -> VerilogEngineSnapshot:
        """Capture total reward, directory paths and the serialized task."""
        engine_data = {
            "total_reward": self._total_reward,
            "snapshot_dir": str(self.snapshot_dir) if self.snapshot_dir else None,
            "build_dir": str(self.build_dir) if self.build_dir else None,
        }

        task_instance_dict = await self.task_instance.serialize()

        return VerilogEngineSnapshot(
            task_instance_dict=task_instance_dict, engine_snapshot=engine_data
        )

    @classmethod
    async def _deserialize_engine(cls, snapshot: VerilogEngineSnapshot) -> "VerilogEngine":
        """Rebuild an engine from a snapshot.

        Restores the task instance, total reward and directory paths. The
        last compile/simulate outputs are not part of the snapshot and
        start out as ``None``.
        """
        # Imported here rather than at module top; presumably to avoid a
        # circular import with the taskset module — TODO confirm.
        from synth_ai.environments.examples.verilog.taskset import VerilogTaskInstance

        task_instance = await VerilogTaskInstance.deserialize(snapshot.task_instance_dict)
        engine = cls(task_instance)

        engine_data = snapshot.engine_snapshot
        engine._total_reward = engine_data.get("total_reward", 0.0)

        if engine_data.get("snapshot_dir"):
            engine.snapshot_dir = Path(engine_data["snapshot_dir"])
        if engine_data.get("build_dir"):
            engine.build_dir = Path(engine_data["build_dir"])

        return engine
|