synth-ai 0.2.4.dev3__py3-none-any.whl → 0.2.4.dev5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/debug_translation.py +0 -0
- synth_ai/environments/examples/crafter_classic/engine.py +575 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +266 -0
- synth_ai/environments/examples/crafter_classic/environment.py +364 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +233 -0
- synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +229 -0
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +298 -0
- synth_ai/environments/examples/crafter_custom/__init__.py +4 -0
- synth_ai/environments/examples/crafter_custom/crafter/__init__.py +7 -0
- synth_ai/environments/examples/crafter_custom/crafter/config.py +182 -0
- synth_ai/environments/examples/crafter_custom/crafter/constants.py +8 -0
- synth_ai/environments/examples/crafter_custom/crafter/engine.py +269 -0
- synth_ai/environments/examples/crafter_custom/crafter/env.py +266 -0
- synth_ai/environments/examples/crafter_custom/crafter/objects.py +418 -0
- synth_ai/environments/examples/crafter_custom/crafter/recorder.py +187 -0
- synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +119 -0
- synth_ai/environments/examples/crafter_custom/dataset_builder.py +373 -0
- synth_ai/environments/examples/crafter_custom/environment.py +312 -0
- synth_ai/environments/examples/crafter_custom/run_dataset.py +305 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- synth_ai/tracing_v3/examples/basic_usage.py +188 -0
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/METADATA +1 -1
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/RECORD +105 -6
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,291 @@
|
|
1
|
+
# engine.py
|
2
|
+
from __future__ import annotations
|
3
|
+
from dataclasses import dataclass, asdict
|
4
|
+
from typing import Any, Dict, Tuple, Optional, List
|
5
|
+
from pydantic import BaseModel
|
6
|
+
from pathlib import Path
|
7
|
+
|
8
|
+
from synth_ai.environments.examples.enron.art_helpers.types_enron import Email
|
9
|
+
from synth_ai.environments.examples.enron.art_helpers.email_search_tools import (
|
10
|
+
search_emails as helper_search_emails,
|
11
|
+
read_email as helper_read_email,
|
12
|
+
SearchResult,
|
13
|
+
)
|
14
|
+
|
15
|
+
# SQLite-backed helpers
|
16
|
+
from synth_ai.environments.stateful.engine import StatefulEngine, StatefulEngineSnapshot
|
17
|
+
from synth_ai.environments.examples.enron.taskset import EnronTaskInstance
|
18
|
+
from synth_ai.zyk import LM # Import LM class
|
19
|
+
|
20
|
+
from synth_ai.environments.environment.db.sqlite import SQLiteManager
|
21
|
+
from synth_ai.environments.environment.rewards.core import RewardStack, RewardComponent
|
22
|
+
from synth_ai.environments.examples.enron.art_helpers.local_email_db import (
|
23
|
+
DEFAULT_DB_PATH,
|
24
|
+
generate_database,
|
25
|
+
)
|
26
|
+
|
27
|
+
# --------------------------------------------------------------------------- actions
|
28
|
+
# Action kinds understood by the Enron engine / tool-call wrappers.
ACTION_SEARCH, ACTION_READ, ACTION_ANSWER = "search", "read", "answer"
|
31
|
+
|
32
|
+
|
33
|
+
# --------------------------------------------------------------------------- snapshot
|
34
|
+
@dataclass
class EnronEngineSnapshot(StatefulEngineSnapshot):
    """Serializable episode state of an EnronEngine.

    Field order matters: callers may construct this positionally as
    (idx, answered, total_reward, partial_rewards).
    """

    # Index kept for snapshots; EnronEngine initialises it to 0.
    idx: int
    # True once answer_question_action has been called.
    answered: bool
    # Cumulative reward accumulated so far in the episode.
    total_reward: float
    # Per-step rewards, in the order they were earned.
    partial_rewards: list[float]
|
40
|
+
|
41
|
+
|
42
|
+
# --------------------------------------------------------------------------- engine
|
43
|
+
class EnronEngine(StatefulEngine):
    """
    Minimal state machine around the corbt/enron_emails_sample_questions dataset.

    One engine wraps one EnronTaskInstance. Its ``initial_engine_snapshot`` (the
    dataset row) supplies ``question`` and ``answer``, plus optional
    ``query_date`` / ``inbox_address``.

    Actions are engine methods. Each one records details for the next reward
    computation:

      search_emails_action(search_args) -> list of search-result dicts
      read_email_action(message_id)     -> email dict, or None if not found
      answer_question_action(answer)    -> None; marks the episode as answered

    ``_step_engine`` then scores the recorded action with the RewardStack: the
    answer-correctness component gives +1 / -1, and non-answer actions incur a
    per-step penalty. It returns ``(private_state, public_state)``. The episode
    terminates once an answer has been submitted.
    """

    # Tool names advertised to the agent in every public state.
    _TOOL_NAMES = ("search_emails", "read_email", "answer_question", "terminate")

    # ----------------------------- init / helpers
    def __init__(self, task_instance: EnronTaskInstance):
        # Use the provided TaskInstance snapshot for this episode.
        self.instance = task_instance
        self.answered = False
        self.total_reward = 0.0
        # Never advanced: the engine serves a single TaskInstance (kept for snapshots).
        self.idx = 0
        # Reward earned at each step, in order.
        self.rewards_history: List[float] = []

        db_file_path = Path(DEFAULT_DB_PATH)
        if not db_file_path.exists():
            generate_database(overwrite=False)  # build the email DB on first use
        self.sqlite_manager = SQLiteManager(db_path=db_file_path, read_only=True)

        # Components score the state dict built in _calculate_and_apply_reward.
        self.reward_stack = RewardStack(
            components=[
                EnronAnswerCorrectnessComponent(),
                EnronStepPenaltyComponent(penalty=-0.05),
            ]
        )
        # Details of the pending agent action. The next reward computation
        # consumes and then clears them.
        self._current_action_details_for_reward: Optional[Dict[str, Any]] = None

    def _sample(self) -> Dict[str, Any]:
        """Return the task row (question/answer/...) backing this episode."""
        return self.instance.initial_engine_snapshot

    # ----------------------------- step / reset
    async def _step_engine(
        self, tool_output_payload: Optional[Dict[str, Any]]
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Score the pending action and build the new (private, public) state.

        Args:
            tool_output_payload: Tool output (e.g. ``search_results``, ``email``,
                ``tool_error``). It is merged over the default public fields, so
                its keys win.

        Returns:
            ``(priv, pub)``. ``priv`` holds the reward and termination flags plus
            the gold answer; ``pub`` holds what the agent may see.
        """
        reward = await self._calculate_and_apply_reward()

        # answer_question_action sets `answered`; any submitted answer ends the task.
        terminated = self.answered

        sample = self._sample()
        priv = {
            "reward_last": reward,
            "total_reward": self.total_reward,
            "terminated": terminated,
            "truncated": False,
            "gold_answer": sample["answer"],
        }
        pub = {
            "question": sample["question"],
            "tools": list(self._TOOL_NAMES),
            "already_answered": self.answered,
            "query_date": sample.get("query_date", "<unknown date>"),
            "inbox_address": sample.get("inbox_address", "<unknown_inbox>"),
            # Defaults, overwritten by the tool payload when present.
            "search_results": [],
            "email": None,
            **(tool_output_payload or {}),
        }
        return priv, pub

    async def _reset_engine(
        self, *, seed: Optional[int] = None
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        Reset per-episode state for this TaskInstance and emit the initial
        observation **without** issuing an empty-keyword DB search (which would
        raise).

        Args:
            seed: Accepted for interface compatibility; unused (the episode is
                fully determined by the TaskInstance).
        """
        self.answered = False
        self.total_reward = 0.0
        self.rewards_history = []
        self._current_action_details_for_reward = None
        # The Enron DB is opened read-only, so no DB reset is performed here.

        sample = self._sample()
        priv = {
            "reward_last": 0.0,
            "total_reward": 0.0,
            "terminated": False,
            "truncated": False,
            "gold_answer": sample["answer"],
        }
        pub = {
            "question": sample["question"],
            "tools": list(self._TOOL_NAMES),
            "already_answered": False,
            "query_date": sample.get("query_date", "<unknown date>"),
            "inbox_address": sample.get("inbox_address", "<unknown_inbox>"),
            "search_results": [],
            "email": None,
        }
        return priv, pub

    # ----------------------------- serialization helpers
    async def _serialize_engine(self) -> EnronEngineSnapshot:
        """Capture the episode state, including a copy of the per-step rewards."""
        # Copy the list so later steps do not mutate an already-taken snapshot.
        return EnronEngineSnapshot(
            idx=self.idx,
            answered=self.answered,
            total_reward=self.total_reward,
            partial_rewards=list(self.rewards_history),
        )

    @classmethod
    async def _deserialize_engine(
        cls, snap: EnronEngineSnapshot, task_instance: EnronTaskInstance
    ) -> "EnronEngine":
        """Rebuild an engine for ``task_instance`` from a snapshot.

        The SQLite manager is re-created by ``__init__`` from DEFAULT_DB_PATH. If
        the DB path could vary per snapshot, it would need to be stored in it.
        """
        eng = cls(task_instance)
        eng.idx = snap.idx
        eng.answered = snap.answered
        eng.total_reward = snap.total_reward
        # Copy so the restored engine does not share (and mutate) the snapshot's list.
        eng.rewards_history = list(snap.partial_rewards or [])
        return eng

    def close_db(self):
        """Close the engine's SQLite connection."""
        self.sqlite_manager.close()

    async def _calculate_and_apply_reward(self) -> float:
        """Score the pending action, accumulate it, and clear the pending action.

        Returns:
            The reward for this step as computed by ``self.reward_stack``.
        """
        sample = self._sample()
        action_details = self._current_action_details_for_reward or {}
        # State seen by the reward components: question/gold answer plus the
        # action's own details (which may override them).
        reward_context_state = {
            "question": sample["question"],
            "gold_answer": sample["answer"],
            **action_details,
        }

        reward = await self.reward_stack.step_reward(
            state=reward_context_state, action=action_details
        )

        self.total_reward += reward
        self.rewards_history.append(reward)
        self._current_action_details_for_reward = None  # consumed
        return reward

    async def search_emails_action(self, search_args: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run an email search and record it as the pending action.

        Args:
            search_args: Keyword arguments forwarded to the search helper.

        Returns:
            The search results converted to dicts.
        """
        res: List[SearchResult] = helper_search_emails(self.sqlite_manager, **search_args)
        self._current_action_details_for_reward = {"type": "search", **search_args}
        return [asdict(item) for item in res]

    async def read_email_action(self, message_id: str) -> Optional[Dict[str, Any]]:
        """Read one email by id and record it as the pending action.

        Returns:
            The email as a dict, or None when the helper finds nothing.
        """
        email: Optional[Email] = helper_read_email(self.sqlite_manager, message_id)
        self._current_action_details_for_reward = {
            "type": "read",
            "message_id": message_id,
        }
        return email.dict() if email else None

    async def answer_question_action(self, agent_answer: str) -> None:
        """Submit the agent's final answer.

        This only records the answer for grading and marks the episode as
        answered. The reward is computed by the next
        ``_calculate_and_apply_reward`` call (made from ``_step_engine``), which
        also reports termination.
        """
        sample = self._sample()
        self._current_action_details_for_reward = {
            "type": "answer",
            "is_answer_action": True,  # signal for the reward components
            "question": sample["question"],
            "gold_answer": sample["answer"],
            "agent_answer": agent_answer,
        }
        self.answered = True
|
237
|
+
|
238
|
+
|
239
|
+
# ----------------------------- LLM Judge for answers
|
240
|
+
async def determine_if_answer_is_correct(
    question: str,
    gold_answer: str,
    agent_answer: str,
    *,
    model_name: str = "gpt-4.1-nano",
) -> bool:
    """Ask an LLM judge whether ``agent_answer`` correctly answers ``question``.

    Args:
        question: The question that was asked.
        gold_answer: Reference answer from the dataset.
        agent_answer: The answer produced by the agent.
        model_name: Judge model. It is also used as the structured-output
            formatting model. Defaults to the previously hard-coded
            ``"gpt-4.1-nano"``.

    Returns:
        The judge's ``correct`` verdict from its structured response.
    """
    # temperature=0.0 to minimise sampling variance in the verdict.
    llm = LM(model_name=model_name, formatting_model_name=model_name, temperature=0.0)

    system_prompt = (
        "You will be given a question and two different answers to the question, "
        "the correct answer and the answer given by an AI. Your job is to determine "
        "if the answer given by the AI is correct."
    )
    user_message_content = (
        f"Question: {question}\nCorrect answer: {gold_answer}\nAI answer: {agent_answer}"
    )

    class CorrectnessResponse(BaseModel):
        correct: bool

    # Any caching is left to the LM class / its underlying client.
    response = await llm.respond_async(
        system_message=system_prompt,
        user_message=user_message_content,
        response_model=CorrectnessResponse,
    )
    return response.structured_output.correct
|
266
|
+
|
267
|
+
|
268
|
+
# --- Placeholder Reward Components (ideally defined elsewhere and imported) ---
|
269
|
+
# (These would typically live in a shared rewards components file or alongside the engine if very specific)
|
270
|
+
class EnronAnswerCorrectnessComponent(RewardComponent):
    """Grade answer actions with the LLM judge.

    Returns +1.0 for a correct answer and -1.0 for an incorrect one. Any state
    without a truthy ``is_answer_action`` and a non-None ``agent_answer`` scores 0.0.
    """

    async def score(self, state: Dict[str, Any], action: Any) -> float:
        gradable = state.get("is_answer_action") and state.get("agent_answer") is not None
        if not gradable:
            return 0.0
        # determine_if_answer_is_correct is the module-level LLM judge.
        verdict = await determine_if_answer_is_correct(
            state["question"], state["gold_answer"], state["agent_answer"]
        )
        return 1.0 if verdict else -1.0
|
280
|
+
|
281
|
+
|
282
|
+
class EnronStepPenaltyComponent(RewardComponent):
    """Charge a fixed penalty for every step that is not an answer action.

    Args:
        penalty: Reward returned for non-answer steps (default -0.01).
    """

    def __init__(self, penalty: float = -0.01):
        self.penalty = penalty

    async def score(self, state: Dict[str, Any], action: Any) -> float:
        # Every answer step is exempt, whether the answer was right or wrong.
        is_answer_step = state.get("is_answer_action")
        return 0.0 if is_answer_step else self.penalty
|
@@ -0,0 +1,165 @@
|
|
1
|
+
# environment.py
|
2
|
+
from __future__ import annotations
|
3
|
+
from typing import List, Optional, Dict, Any, Union
|
4
|
+
from pydantic import BaseModel, Field
|
5
|
+
|
6
|
+
from synth_ai.environments.environment.tools import (
|
7
|
+
EnvToolCall,
|
8
|
+
ToolResult,
|
9
|
+
TOOL_REGISTRY,
|
10
|
+
register_tool,
|
11
|
+
)
|
12
|
+
from synth_ai.environments.environment.shared_engine import (
|
13
|
+
GetObservationCallable,
|
14
|
+
InternalObservation,
|
15
|
+
)
|
16
|
+
from synth_ai.environments.stateful.core import StatefulEnvironment
|
17
|
+
from synth_ai.environments.examples.enron.engine import (
|
18
|
+
EnronEngine,
|
19
|
+
ACTION_SEARCH,
|
20
|
+
ACTION_READ,
|
21
|
+
ACTION_ANSWER,
|
22
|
+
)
|
23
|
+
from synth_ai.environments.examples.enron.taskset import EnronTaskInstance
|
24
|
+
|
25
|
+
|
26
|
+
# -------- pydantic schemas (used by agent / LLM function calls)
|
27
|
+
# Arguments for the `search_emails` tool (schema exposed to agents / LLM function calls).
class SearchEmailsArgs(BaseModel):
    inbox: str = Field(..., description="Email address performing the search (used by tool logic)")
    keywords: List[str] = Field(..., description="Keywords to AND-search for")
    from_addr: Optional[str] = None
    to_addr: Optional[str] = None
    # Passed through as strings; the accepted date format is whatever the
    # search helper expects (TODO confirm).
    sent_after: Optional[str] = None
    sent_before: Optional[str] = None
    # Bounded below as well as above: without ge=1, zero or negative values slip
    # past the le=10 cap. If the helper feeds this into a SQL LIMIT, a negative
    # value means "no limit" in SQLite.
    max_results: int = Field(10, ge=1, le=10)
|
35
|
+
|
36
|
+
|
37
|
+
# Arguments for the `read_email` tool.
class ReadEmailArgs(BaseModel):
    message_id: str = Field(...)  # required; presumably an id taken from search results
|
39
|
+
|
40
|
+
|
41
|
+
# Arguments for the `answer_question` tool.
class AnswerQuestionArgs(BaseModel):
    answer: str = Field(...)  # required; free-text final answer to be graded
|
43
|
+
|
44
|
+
|
45
|
+
# --------------------------------------------------------------------------- tool wrappers
|
46
|
+
# Tool-call wrapper for an email search: all keyword arguments are kept
# unchanged as the search arguments.
# NOTE(review): EnronEnvironment.step dispatches on `call.tool`, which this
# constructor never sets — confirm EnvToolCall provides it.
class SearchEmails(EnvToolCall):
    def __init__(self, **kwargs):
        search_arguments = kwargs
        self.action = (ACTION_SEARCH, search_arguments)
|
49
|
+
|
50
|
+
|
51
|
+
# Tool-call wrapper that reads a single email by message id.
# NOTE(review): EnronEnvironment.step dispatches on `call.tool`, which this
# constructor never sets — confirm EnvToolCall provides it.
class ReadEmail(EnvToolCall):
    def __init__(self, message_id: str):
        action_kind = ACTION_READ
        self.action = (action_kind, message_id)
|
54
|
+
|
55
|
+
|
56
|
+
# Tool-call wrapper that submits the agent's final answer.
# NOTE(review): EnronEnvironment.step dispatches on `call.tool`, which this
# constructor never sets — confirm EnvToolCall provides it.
class AnswerQuestion(EnvToolCall):
    def __init__(self, answer: str):
        action_kind = ACTION_ANSWER
        self.action = (action_kind, answer)
|
59
|
+
|
60
|
+
|
61
|
+
# -- terminate wrapper (maps to an empty-answer ACTION_ANSWER) --------------
|
62
|
+
# Tool-call wrapper for giving up: it is encoded as an answer action with an
# empty answer string.
# NOTE(review): EnronEnvironment.step dispatches on `call.tool`, which this
# constructor never sets — confirm EnvToolCall provides it.
class Terminate(EnvToolCall):
    def __init__(self):
        empty_answer = ""
        self.action = (ACTION_ANSWER, empty_answer)
|
65
|
+
|
66
|
+
|
67
|
+
# -------- observation callable (optional for formatted observations)
|
68
|
+
class SynthEnronObservationCallable(GetObservationCallable):
    """Render Enron environment state as a compact multi-line text observation."""

    async def get_observation(
        self, pub: Dict[str, Any], priv: Dict[str, Any]
    ) -> InternalObservation:
        """Format observation as a human-readable string.

        Args:
            pub: Public state (question, tools, search results, loaded email,
                tool error, ...).
            priv: Private state; only ``reward_last`` is displayed.

        Returns:
            One ``Label: value`` line per field, joined with newlines.
        """
        # `search_results` could be present but None (e.g. supplied by a tool
        # payload); count that as zero results instead of failing on len(None).
        search_results = pub.get("search_results") or []
        lines = [
            f"Q: {pub.get('question')}",
            f"Tools: {pub.get('tools')}",
            f"Answered: {pub.get('already_answered')}",
            f"Search Res: {len(search_results)} items",
            f"Email Loaded: {pub.get('email') is not None}",
            f"Tool Error: {pub.get('tool_error')}",
            f"Reward Δ: {priv.get('reward_last')}",
        ]
        return "\n".join(lines)
|
76
|
+
|
77
|
+
|
78
|
+
# --------------------------------------------------------------------------- environment
|
79
|
+
class _EnronTool:
    """Common base for the Enron tools: bound to one engine, looked up by ``name``.

    EnronEnvironment registers instances in TOOL_REGISTRY, and
    EnronEnvironment.step awaits them with the incoming EnvToolCall. The returned
    ToolResult's ``payload`` is merged into the engine's public state;
    ``error`` becomes ``tool_error``.
    """

    name: str = ""
    call_schema: Any = None

    def __init__(self, engine: EnronEngine):
        self.engine = engine

    def _parse_args(self, call: EnvToolCall) -> Any:
        """Validate the call's arguments against ``call_schema``."""
        # NOTE(review): assumes EnvToolCall carries its arguments as a dict in
        # `.args` — confirm against synth_ai.environments.environment.tools.
        return self.call_schema(**(getattr(call, "args", None) or {}))

    @staticmethod
    def _model_to_dict(model: Any) -> Dict[str, Any]:
        # Support both pydantic v2 (`model_dump`) and v1 (`dict`).
        if hasattr(model, "model_dump"):
            return model.model_dump()
        return model.dict()


class SearchEmailsTool(_EnronTool):
    """`search_emails`: keyword search over the Enron mailbox."""

    name = "search_emails"
    call_schema = SearchEmailsArgs

    async def __call__(self, call: EnvToolCall) -> ToolResult:
        try:
            args = self._parse_args(call)
            results = await self.engine.search_emails_action(self._model_to_dict(args))
        except Exception as exc:  # tool boundary: report as tool_error, don't crash the episode
            return ToolResult(ok=False, error=str(exc), payload={})
        return ToolResult(ok=True, payload={"search_results": results})


class ReadEmailTool(_EnronTool):
    """`read_email`: load one email by message id."""

    name = "read_email"
    call_schema = ReadEmailArgs

    async def __call__(self, call: EnvToolCall) -> ToolResult:
        try:
            args = self._parse_args(call)
            email = await self.engine.read_email_action(args.message_id)
        except Exception as exc:  # tool boundary: report as tool_error
            return ToolResult(ok=False, error=str(exc), payload={})
        if email is None:
            return ToolResult(ok=False, error=f"Email {args.message_id} not found.", payload={})
        return ToolResult(ok=True, payload={"email": email})


class AnswerQuestionTool(_EnronTool):
    """`answer_question`: submit the final answer (ends the episode)."""

    name = "answer_question"
    call_schema = AnswerQuestionArgs

    async def __call__(self, call: EnvToolCall) -> ToolResult:
        try:
            args = self._parse_args(call)
            await self.engine.answer_question_action(args.answer)
        except Exception as exc:  # tool boundary: report as tool_error
            return ToolResult(ok=False, error=str(exc), payload={})
        return ToolResult(ok=True, payload={"status": "answer_submitted", "answer": args.answer})


class TerminateTool(_EnronTool):
    """`terminate`: give up by submitting an empty answer (like the Terminate wrapper)."""

    name = "terminate"

    async def __call__(self, call: EnvToolCall) -> ToolResult:
        await self.engine.answer_question_action("")
        return ToolResult(ok=True, payload={"status": "terminated"})


class EnronEnvironment(StatefulEnvironment):
    """Stateful Enron email-QA environment driven by tool calls.

    Each ``step`` executes only the first tool call it receives. It then advances
    the engine with that tool's output and returns the formatted observation.
    """

    def __init__(
        self,
        task_instance: EnronTaskInstance,
        custom_obs: Optional[GetObservationCallable] = None,
    ):
        self.engine = EnronEngine(task_instance)
        self.custom_obs = custom_obs or SynthEnronObservationCallable()
        self.name = "Enron-QA-Env"

        # Keep our own tool instances so step() prefers tools bound to this engine.
        self._tools_instances = {
            "search_emails": SearchEmailsTool(self.engine),
            "read_email": ReadEmailTool(self.engine),
            "answer_question": AnswerQuestionTool(self.engine),
            "terminate": TerminateTool(self.engine),
        }
        for tool_name, tool_instance in self._tools_instances.items():
            registered = TOOL_REGISTRY.get(tool_name)
            # (Re-)register when missing or bound to a different engine; getattr
            # guards against registry entries that have no `engine` attribute.
            if registered is None or getattr(registered, "engine", None) is not self.engine:
                register_tool(tool_instance)

    async def initialize(self) -> InternalObservation:
        """Reset the engine and return the initial observation."""
        priv, pub = await self.engine._reset_engine()
        return await self._obs(priv, pub)

    async def step(
        self,
        calls: Union[EnvToolCall, List[EnvToolCall], List[List[EnvToolCall]]],
    ) -> InternalObservation:
        """Execute the first tool call in ``calls`` and return the new observation.

        ``calls`` may be a single EnvToolCall, a list of them, or a list of lists.
        After normalisation only ``calls[0][0]`` is executed.

        Raises:
            ValueError: if ``calls`` is empty or the tool name is unknown.
            TypeError: if the selected call is not an EnvToolCall.
        """
        # normalise → always [[EnvToolCall]]
        if isinstance(calls, EnvToolCall):
            calls = [[calls]]
        elif calls and isinstance(calls[0], EnvToolCall):
            calls = [calls]
        if not calls or not calls[0]:
            raise ValueError("step() requires at least one EnvToolCall.")

        call = calls[0][0]
        if not isinstance(call, EnvToolCall):
            raise TypeError(f"Processed call is not EnvToolCall: {type(call)}")

        tool_name = call.tool
        tool_to_execute = self._tools_instances.get(tool_name) or TOOL_REGISTRY.get(tool_name)
        if not tool_to_execute:
            raise ValueError(f"Tool '{tool_name}' not found.")

        tool_result: ToolResult = await tool_to_execute(call)

        if tool_result.ok:
            public_payload_for_engine = tool_result.payload or {}
        else:
            public_payload_for_engine = {"tool_error": tool_result.error}

        priv, pub = await self.engine._step_engine(public_payload_for_engine)
        return await self._obs(priv, pub)

    async def terminate(self) -> InternalObservation:
        """Close the email DB and return a final, tool-less observation."""
        self.engine.close_db()
        sample = self.engine._sample()
        priv_state_on_terminate = {
            "reward_last": 0,
            "total_reward": self.engine.total_reward,
            "terminated": True,
            "truncated": False,
            "gold_answer": sample["answer"],
        }
        pub_state_on_terminate = {
            "question": sample["question"],
            "tools": [],
            "already_answered": self.engine.answered,
            "status": "terminated_by_env",
        }
        return await self._obs(priv_state_on_terminate, pub_state_on_terminate)

    async def checkpoint(self) -> InternalObservation:
        """Return the serialized engine snapshot as a plain dict."""
        import dataclasses

        snapshot = await self.engine._serialize_engine()
        # EnronEngineSnapshot is declared as a @dataclass, which has no pydantic
        # `model_dump`; fall back to dataclasses.asdict in that case.
        if hasattr(snapshot, "model_dump"):
            snapshot_dict = snapshot.model_dump()
        else:
            snapshot_dict = dataclasses.asdict(snapshot)
        return {
            "engine_snapshot": snapshot_dict,
            "message": "Checkpoint created",
        }

    async def _obs(self, priv: Dict[str, Any], pub: Dict[str, Any]):
        """Format state via ``custom_obs``; otherwise merge pub and priv (priv wins)."""
        if self.custom_obs:
            return await self.custom_obs.get_observation(pub, priv)
        return {**pub, **priv}
|
@@ -0,0 +1,112 @@
|
|
1
|
+
# taskset.py
|
2
|
+
from __future__ import annotations
|
3
|
+
import asyncio
|
4
|
+
from uuid import uuid4
|
5
|
+
import os
|
6
|
+
|
7
|
+
from datasets import load_dataset
|
8
|
+
from dataclasses import dataclass, asdict
|
9
|
+
|
10
|
+
from synth_ai.environments.tasks.core import (
|
11
|
+
Task,
|
12
|
+
TaskInstance,
|
13
|
+
TaskInstanceSet,
|
14
|
+
TaskInstanceMetadata,
|
15
|
+
SplitInfo,
|
16
|
+
Impetus,
|
17
|
+
Intent,
|
18
|
+
)
|
19
|
+
|
20
|
+
# Task definition shared by every Enron QA instance.
enron_task = Task(
    shared_env_params={},
    global_objectives="Provide the correct answer; minimise queries",
    global_constraints="",
    global_premises="Answer factual questions by reading Enron e-mails",
)
|
26
|
+
|
27
|
+
|
28
|
+
# --------------------------------------------------------------------------- metadata
|
29
|
+
@dataclass
class EnronTaskInstanceMetadata(TaskInstanceMetadata):
    """Per-instance metadata for one Enron QA row."""

    # Dataset split the row came from ("train" or "test" in create_enron_taskset).
    split: str
    # Number of message ids referenced by the row.
    email_count: int
    # Message ids referenced by the row.
    message_ids: list[str]
|
34
|
+
|
35
|
+
|
36
|
+
@dataclass
class EnronTaskInstance(TaskInstance):
    """TaskInstance for one Enron QA row, with round-trippable (de)serialization."""

    async def serialize(self):
        """Return this instance as a plain dict; a UUID ``id`` becomes a string."""
        from uuid import UUID

        data = asdict(self)
        if isinstance(data.get("id"), UUID):
            data["id"] = str(data["id"])
        return data

    @classmethod
    async def deserialize(cls, data: dict) -> "EnronTaskInstance":
        """Rebuild an instance from ``serialize`` output.

        Two conversions are applied before construction:

        - A string ``id`` is restored to a UUID, so it compares equal to the ids
          stored in ``SplitInfo`` by create_enron_taskset.
        - Nested ``impetus`` / ``intent`` / ``metadata`` dicts, which ``asdict``
          flattened, are re-wrapped in their types.

        Values that are already objects pass through unchanged, and the caller's
        dict is not mutated.
        """
        from uuid import UUID

        data = dict(data)
        raw_id = data.get("id")
        if isinstance(raw_id, str):
            try:
                data["id"] = UUID(raw_id)
            except ValueError:
                pass  # not a UUID string; keep the id as given
        for key, factory in (
            ("impetus", Impetus),
            ("intent", Intent),
            ("metadata", EnronTaskInstanceMetadata),
        ):
            if isinstance(data.get(key), dict):
                data[key] = factory(**data[key])
        return cls(**data)
|
47
|
+
|
48
|
+
|
49
|
+
# --------------------------------------------------------------------------- task-set builder
|
50
|
+
# Use a local dataset cache under examples/enron/dataset
|
51
|
+
# Local dataset cache under examples/enron/dataset (next to this module).
CACHE_DIR = os.path.join(os.path.dirname(__file__), "dataset")
try:
    # Best-effort at import time: on a read-only install (e.g. system
    # site-packages) this must not make importing the module fail. Any real
    # problem surfaces when the cache is actually used.
    os.makedirs(CACHE_DIR, exist_ok=True)
except OSError:
    pass
|
53
|
+
|
54
|
+
|
55
|
+
async def create_enron_taskset() -> TaskInstanceSet:
    """Build the Enron QA task set from the Hugging Face dataset.

    Loads the ``train`` and ``test`` splits of
    ``corbt/enron_emails_sample_questions`` (cached under CACHE_DIR) and wraps
    each row in an EnronTaskInstance whose ``initial_engine_snapshot`` is the row
    itself. Test-split ids populate ``SplitInfo.test_instance_ids``; there is no
    validation split.

    NOTE: ``load_dataset`` is synchronous, so this coroutine blocks the event loop
    while the dataset is downloaded or read from cache.
    """
    # Ensure the cache directory exists at use time, not only at import time.
    os.makedirs(CACHE_DIR, exist_ok=True)

    def load_split(split: str):
        # Single place for the dataset name / cache location.
        return load_dataset(
            "corbt/enron_emails_sample_questions",
            split=split,
            cache_dir=CACHE_DIR,
        )

    def to_instance(row: dict, split: str) -> EnronTaskInstance:
        impetus = Impetus(instructions=row["question"])
        intent = Intent(
            rubric={"goal": "Answer the question using the Enron emails."},
            gold_trajectories=None,
            gold_state_diff={"answer": row["answer"]},
        )
        metadata = EnronTaskInstanceMetadata(
            split=split,
            email_count=len(row["message_ids"]),
            message_ids=row["message_ids"],
        )
        return EnronTaskInstance(
            id=uuid4(),
            impetus=impetus,
            intent=intent,
            metadata=metadata,
            is_reproducible=True,
            initial_engine_snapshot=row,
        )

    train_instances = [to_instance(r, "train") for r in load_split("train")]
    test_instances = [to_instance(r, "test") for r in load_split("test")]

    split_info = SplitInfo(
        val_instance_ids=set(),
        test_instance_ids={inst.id for inst in test_instances},
        _is_split_defined=True,
    )

    return TaskInstanceSet(
        name="Enron-QA",
        description="QA over Enron email dataset sample.",
        instances=train_instances + test_instances,
        split_info=split_info,
    )
|
103
|
+
|
104
|
+
|
105
|
+
# quick sanity check ----------------------------------------------------------
|
106
|
+
if __name__ == "__main__":

    async def _report_instance_count() -> None:
        """Build the task set once and print how many instances it contains."""
        task_set = await create_enron_taskset()
        print(f"{len(task_set.instances)} instances built.")

    asyncio.run(_report_instance_count())
|
@@ -0,0 +1,48 @@
|
|
1
|
+
"""MiniGrid environment example for synth_ai.environments.
|
2
|
+
|
3
|
+
This module provides a comprehensive implementation of MiniGrid environments
|
4
|
+
with full state management, tool-based interaction, and task generation.
|
5
|
+
"""
|
6
|
+
|
7
|
+
from synth_ai.environments.examples.minigrid.engine import (
|
8
|
+
MiniGridEngine,
|
9
|
+
MiniGridPublicState,
|
10
|
+
MiniGridPrivateState,
|
11
|
+
MiniGridGoalReachedComponent,
|
12
|
+
MiniGridStepPenaltyComponent,
|
13
|
+
MiniGridObservationCallable,
|
14
|
+
MiniGridCheckpointObservationCallable,
|
15
|
+
)
|
16
|
+
from synth_ai.environments.examples.minigrid.environment import (
|
17
|
+
MiniGridEnvironment,
|
18
|
+
MiniGridInteractTool,
|
19
|
+
MiniGridActionInput,
|
20
|
+
)
|
21
|
+
from synth_ai.environments.examples.minigrid.taskset import (
|
22
|
+
MiniGridTaskInstance,
|
23
|
+
MiniGridTaskInstanceMetadata,
|
24
|
+
DEFAULT_MINIGRID_TASK,
|
25
|
+
create_minigrid_taskset,
|
26
|
+
taskset,
|
27
|
+
)
|
28
|
+
|
29
|
+
# Public API, re-exported from the engine, environment and taskset submodules.
__all__ = [
    # Engine
    "MiniGridEngine", "MiniGridPublicState", "MiniGridPrivateState",
    "MiniGridGoalReachedComponent", "MiniGridStepPenaltyComponent",
    "MiniGridObservationCallable", "MiniGridCheckpointObservationCallable",
    # Environment
    "MiniGridEnvironment", "MiniGridInteractTool", "MiniGridActionInput",
    # TaskSet
    "MiniGridTaskInstance", "MiniGridTaskInstanceMetadata",
    "DEFAULT_MINIGRID_TASK", "create_minigrid_taskset", "taskset",
]
|