synth-ai 0.2.4.dev4__py3-none-any.whl → 0.2.4.dev6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. synth_ai/environments/examples/__init__.py +1 -0
  2. synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
  3. synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
  4. synth_ai/environments/examples/crafter_classic/debug_translation.py +0 -0
  5. synth_ai/environments/examples/crafter_classic/engine.py +579 -0
  6. synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
  7. synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
  8. synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
  9. synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +266 -0
  10. synth_ai/environments/examples/crafter_classic/environment.py +364 -0
  11. synth_ai/environments/examples/crafter_classic/taskset.py +233 -0
  12. synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +229 -0
  13. synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +298 -0
  14. synth_ai/environments/examples/crafter_custom/__init__.py +4 -0
  15. synth_ai/environments/examples/crafter_custom/crafter/__init__.py +7 -0
  16. synth_ai/environments/examples/crafter_custom/crafter/config.py +182 -0
  17. synth_ai/environments/examples/crafter_custom/crafter/constants.py +8 -0
  18. synth_ai/environments/examples/crafter_custom/crafter/engine.py +269 -0
  19. synth_ai/environments/examples/crafter_custom/crafter/env.py +266 -0
  20. synth_ai/environments/examples/crafter_custom/crafter/objects.py +418 -0
  21. synth_ai/environments/examples/crafter_custom/crafter/recorder.py +187 -0
  22. synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +119 -0
  23. synth_ai/environments/examples/crafter_custom/dataset_builder.py +373 -0
  24. synth_ai/environments/examples/crafter_custom/environment.py +312 -0
  25. synth_ai/environments/examples/crafter_custom/run_dataset.py +305 -0
  26. synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
  27. synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
  28. synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
  29. synth_ai/environments/examples/enron/engine.py +291 -0
  30. synth_ai/environments/examples/enron/environment.py +165 -0
  31. synth_ai/environments/examples/enron/taskset.py +112 -0
  32. synth_ai/environments/examples/minigrid/__init__.py +48 -0
  33. synth_ai/environments/examples/minigrid/engine.py +589 -0
  34. synth_ai/environments/examples/minigrid/environment.py +274 -0
  35. synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
  36. synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
  37. synth_ai/environments/examples/minigrid/taskset.py +583 -0
  38. synth_ai/environments/examples/nethack/__init__.py +7 -0
  39. synth_ai/environments/examples/nethack/achievements.py +337 -0
  40. synth_ai/environments/examples/nethack/engine.py +738 -0
  41. synth_ai/environments/examples/nethack/environment.py +255 -0
  42. synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
  43. synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
  44. synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
  45. synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
  46. synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
  47. synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
  48. synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
  49. synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
  50. synth_ai/environments/examples/nethack/taskset.py +323 -0
  51. synth_ai/environments/examples/red/__init__.py +7 -0
  52. synth_ai/environments/examples/red/config_logging.py +110 -0
  53. synth_ai/environments/examples/red/engine.py +693 -0
  54. synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
  55. synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
  56. synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
  57. synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
  58. synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
  59. synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
  60. synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
  61. synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
  62. synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
  63. synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
  64. synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
  65. synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
  66. synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
  67. synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
  68. synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
  69. synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
  70. synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
  71. synth_ai/environments/examples/red/environment.py +235 -0
  72. synth_ai/environments/examples/red/taskset.py +77 -0
  73. synth_ai/environments/examples/sokoban/__init__.py +1 -0
  74. synth_ai/environments/examples/sokoban/engine.py +675 -0
  75. synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
  76. synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
  77. synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
  78. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
  79. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
  80. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
  81. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
  82. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
  83. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
  84. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
  85. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
  86. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
  87. synth_ai/environments/examples/sokoban/environment.py +228 -0
  88. synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
  89. synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
  90. synth_ai/environments/examples/sokoban/taskset.py +425 -0
  91. synth_ai/environments/examples/tictactoe/__init__.py +1 -0
  92. synth_ai/environments/examples/tictactoe/engine.py +368 -0
  93. synth_ai/environments/examples/tictactoe/environment.py +239 -0
  94. synth_ai/environments/examples/tictactoe/taskset.py +214 -0
  95. synth_ai/environments/examples/verilog/__init__.py +10 -0
  96. synth_ai/environments/examples/verilog/engine.py +328 -0
  97. synth_ai/environments/examples/verilog/environment.py +349 -0
  98. synth_ai/environments/examples/verilog/taskset.py +418 -0
  99. synth_ai/environments/examples/wordle/__init__.py +29 -0
  100. synth_ai/environments/examples/wordle/engine.py +391 -0
  101. synth_ai/environments/examples/wordle/environment.py +154 -0
  102. synth_ai/environments/examples/wordle/helpers/generate_instances_wordfreq.py +75 -0
  103. synth_ai/environments/examples/wordle/taskset.py +222 -0
  104. synth_ai/environments/service/app.py +8 -0
  105. synth_ai/environments/service/core_routes.py +38 -0
  106. synth_ai/learning/prompts/banking77_injection_eval.py +163 -0
  107. synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +201 -0
  108. synth_ai/learning/prompts/mipro.py +273 -1
  109. synth_ai/learning/prompts/random_search.py +247 -0
  110. synth_ai/learning/prompts/run_mipro_banking77.py +160 -0
  111. synth_ai/learning/prompts/run_random_search_banking77.py +305 -0
  112. synth_ai/lm/injection.py +81 -0
  113. synth_ai/lm/overrides.py +204 -0
  114. synth_ai/lm/provider_support/anthropic.py +39 -12
  115. synth_ai/lm/provider_support/openai.py +31 -4
  116. synth_ai/lm/vendors/core/anthropic_api.py +16 -0
  117. synth_ai/lm/vendors/openai_standard.py +35 -5
  118. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/METADATA +2 -1
  119. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/RECORD +123 -13
  120. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/WHEEL +0 -0
  121. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/entry_points.txt +0 -0
  122. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/licenses/LICENSE +0 -0
  123. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,291 @@
1
+ # engine.py
2
+ from __future__ import annotations
3
+ from dataclasses import dataclass, asdict
4
+ from typing import Any, Dict, Tuple, Optional, List
5
+ from pydantic import BaseModel
6
+ from pathlib import Path
7
+
8
+ from synth_ai.environments.examples.enron.art_helpers.types_enron import Email
9
+ from synth_ai.environments.examples.enron.art_helpers.email_search_tools import (
10
+ search_emails as helper_search_emails,
11
+ read_email as helper_read_email,
12
+ SearchResult,
13
+ )
14
+
15
+ # SQLite-backed helpers
16
+ from synth_ai.environments.stateful.engine import StatefulEngine, StatefulEngineSnapshot
17
+ from synth_ai.environments.examples.enron.taskset import EnronTaskInstance
18
+ from synth_ai.zyk import LM # Import LM class
19
+
20
+ from synth_ai.environments.environment.db.sqlite import SQLiteManager
21
+ from synth_ai.environments.environment.rewards.core import RewardStack, RewardComponent
22
+ from synth_ai.environments.examples.enron.art_helpers.local_email_db import (
23
+ DEFAULT_DB_PATH,
24
+ generate_database,
25
+ )
26
+
# --------------------------------------------------------------------------- actions
# Action kinds. The tool-call wrappers in environment.py store ``(kind, arg)``
# tuples built from these constants.
ACTION_SEARCH = "search"
ACTION_READ = "read"
ACTION_ANSWER = "answer"
31
+
32
+
# --------------------------------------------------------------------------- snapshot
@dataclass
class EnronEngineSnapshot(StatefulEngineSnapshot):
    """Serializable per-episode state of an ``EnronEngine``.

    Field order (and therefore positional-construction order) is
    ``idx``, ``answered``, ``total_reward``, ``partial_rewards``.
    """

    idx: int  # engine index counter, carried through serialization unchanged
    answered: bool  # True once an answer action has been recorded
    total_reward: float  # running sum of all step rewards
    partial_rewards: List[float]  # reward of each step, in order
40
+
41
+
# --------------------------------------------------------------------------- engine
class EnronEngine(StatefulEngine):
    """
    Minimal state-machine around the corbt/enron_emails_sample_questions dataset.

    The ``*_action`` methods perform the requested operation and record what the
    agent did. The following :meth:`_step_engine` call scores that record through
    ``reward_stack`` and builds the ``(private, public)`` state pair:

      search_emails_action(search_args) → list of search-result dicts
      read_email_action(message_id)     → email dict, or None if the helper
                                          returns nothing
      answer_question_action(answer)    → marks the episode answered; the next
                                          _step_engine scores it +1 / -1 and
                                          reports terminated=True
    """

    # Tool names advertised in every public observation.
    _TOOL_NAMES: Tuple[str, ...] = (
        "search_emails",
        "read_email",
        "answer_question",
        "terminate",
    )

    # ----------------------------- init / helpers
    def __init__(self, task_instance: EnronTaskInstance):
        # Use the provided TaskInstance snapshot for this episode.
        self.instance = task_instance
        self.answered = False
        self.total_reward = 0.0
        self.idx = 0
        # Reward of each step, in order; persisted as ``partial_rewards``.
        self.rewards_history: List[float] = []

        db_file_path = Path(DEFAULT_DB_PATH)
        if not db_file_path.exists():
            generate_database(overwrite=False)  # Build the e-mail DB on first use.
        self.sqlite_manager = SQLiteManager(db_path=db_file_path, read_only=True)

        # Rewards are computed by this stack in _calculate_and_apply_reward.
        self.reward_stack = RewardStack(
            components=[
                EnronAnswerCorrectnessComponent(),
                EnronStepPenaltyComponent(penalty=-0.05),
            ]
        )
        # Details of the most recent agent action; read and then cleared by
        # _calculate_and_apply_reward.
        self._current_action_details_for_reward: Optional[Dict[str, Any]] = None

    def _sample(self) -> Dict[str, Any]:
        """Return the dataset row stored as the task instance's initial snapshot."""
        return self.instance.initial_engine_snapshot

    # ----------------------------- step / reset
    async def _step_engine(
        self, tool_output_payload: Optional[Dict[str, Any]]
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Score the last recorded action and build the ``(private, public)`` state.

        Keys in *tool_output_payload* are merged last into the public state, so
        they override the defaults (e.g. ``search_results`` and ``email``).
        """
        r = await self._calculate_and_apply_reward()

        # The episode terminates once an answer has been submitted
        # (``answered`` is set by answer_question_action).
        term = self.answered

        s = self._sample()
        priv = {
            "reward_last": r,
            "total_reward": self.total_reward,
            "terminated": term,
            "truncated": False,
            "gold_answer": s["answer"],
        }

        # Public state: static task info plus dynamic tool output.
        pub = {
            "question": s["question"],
            "tools": list(self._TOOL_NAMES),
            "already_answered": self.answered,
            "query_date": s.get("query_date", "<unknown date>"),
            "inbox_address": s.get("inbox_address", "<unknown_inbox>"),
            # Defaults, overwritten by tool_output_payload when present.
            "search_results": [],
            "email": None,
            **(tool_output_payload or {}),
        }

        return priv, pub

    async def _reset_engine(
        self, *, seed: Optional[int] = None
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        Reset per-episode state for this engine's single task instance and emit
        an initial observation **without** issuing an empty-keyword DB search
        (which would raise). *seed* is accepted for interface compatibility and
        is not used.
        """
        self.answered = False
        self.total_reward = 0.0
        self.rewards_history = []
        self._current_action_details_for_reward = None
        # The Enron DB is opened read-only, so there is nothing to reset there.

        s = self._sample()
        priv = {
            "reward_last": 0.0,
            "total_reward": 0.0,
            "terminated": False,
            "truncated": False,
            "gold_answer": s["answer"],
        }
        pub = {
            "question": s["question"],
            "tools": list(self._TOOL_NAMES),
            "already_answered": False,
            "query_date": s.get("query_date", "<unknown date>"),
            "inbox_address": s.get("inbox_address", "<unknown_inbox>"),
            "search_results": [],
            "email": None,
        }
        return priv, pub

    # ----------------------------- serialization helpers
    async def _serialize_engine(self) -> EnronEngineSnapshot:
        """Capture the episode state, including the per-step reward history.

        The history is copied, so later steps cannot mutate a snapshot that was
        already taken.
        """
        return EnronEngineSnapshot(
            idx=self.idx,
            answered=self.answered,
            total_reward=self.total_reward,
            partial_rewards=list(self.rewards_history),
        )

    @classmethod
    async def _deserialize_engine(
        cls, snap: EnronEngineSnapshot, task_instance: EnronTaskInstance
    ) -> "EnronEngine":
        """Rebuild an engine for *task_instance* from *snap*.

        A fresh engine is constructed, which opens the DB at DEFAULT_DB_PATH.
        The DB path itself is not part of the snapshot.
        """
        eng = cls(task_instance)
        eng.idx = snap.idx
        eng.answered = snap.answered
        eng.total_reward = snap.total_reward
        # Copy so that new steps do not append into the caller's snapshot.
        eng.rewards_history = list(snap.partial_rewards or [])
        return eng

    def close_db(self):
        """Close the underlying SQLite connection manager."""
        self.sqlite_manager.close()

    async def _calculate_and_apply_reward(self) -> float:
        """Score the pending action with ``reward_stack`` and accumulate the result.

        Returns:
            The reward for this step. It is also added to ``total_reward`` and
            appended to ``rewards_history``.
        """
        s = self._sample()
        action_details = self._current_action_details_for_reward or {}
        # Reward components read question/answer info and the action details
        # from one flat dict.
        reward_context_state = {
            "question": s["question"],
            "gold_answer": s["answer"],
            **action_details,
        }

        reward = await self.reward_stack.step_reward(
            state=reward_context_state, action=action_details
        )

        self.total_reward += reward
        self.rewards_history.append(reward)
        self._current_action_details_for_reward = None  # Consumed.
        return reward

    async def search_emails_action(self, search_args: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run an e-mail search and record it for reward scoring.

        *search_args* is forwarded as keyword arguments to the search helper.
        """
        res: List[SearchResult] = helper_search_emails(self.sqlite_manager, **search_args)
        self._current_action_details_for_reward = {"type": "search", **search_args}
        return [asdict(item) for item in res]

    async def read_email_action(self, message_id: str) -> Optional[Dict[str, Any]]:
        """Load one e-mail by id and record the read for reward scoring.

        Returns the e-mail as a dict, or None if the helper returns nothing.
        """
        email: Optional[Email] = helper_read_email(self.sqlite_manager, message_id)
        self._current_action_details_for_reward = {
            "type": "read",
            "message_id": message_id,
        }
        return email.dict() if email else None

    async def answer_question_action(self, agent_answer: str) -> None:
        """Record the agent's final answer and mark the episode as answered.

        No reward is computed here. The next _calculate_and_apply_reward call
        (via _step_engine) judges the answer, and _step_engine then reports
        termination.
        """
        s = self._sample()
        self._current_action_details_for_reward = {
            "type": "answer",
            "is_answer_action": True,  # Signal for the reward components.
            "question": s["question"],
            "gold_answer": s["answer"],
            "agent_answer": agent_answer,
        }
        self.answered = True
237
+
238
+
# ----------------------------- LLM Judge for answers
class CorrectnessResponse(BaseModel):
    """Structured output schema for the answer-correctness judge."""

    correct: bool


# Defined once at module level so the pydantic model (and its schema) is not
# rebuilt on every judge call.
_JUDGE_SYSTEM_PROMPT = (
    "You will be given a question and two different answers to the question, "
    "the correct answer and the answer given by an AI. Your job is to determine "
    "if the answer given by the AI is correct."
)


async def determine_if_answer_is_correct(
    question: str,
    gold_answer: str,
    agent_answer: str,
    *,
    model_name: str = "gpt-4.1-nano",
) -> bool:
    """Ask an LLM judge whether *agent_answer* matches *gold_answer* for *question*.

    Args:
        question: The task question.
        gold_answer: The reference answer.
        agent_answer: The answer produced by the agent.
        model_name: Judge model. It is also used as the formatting model and
            defaults to the previously hard-coded ``"gpt-4.1-nano"``.

    Returns:
        The judge's boolean ``correct`` verdict.
    """
    # Temperature 0 keeps the judge as repeatable as the backend allows.
    llm = LM(model_name=model_name, formatting_model_name=model_name, temperature=0.0)

    user_message_content = (
        f"Question: {question}\nCorrect answer: {gold_answer}\nAI answer: {agent_answer}"
    )

    response = await llm.respond_async(
        system_message=_JUDGE_SYSTEM_PROMPT,
        user_message=user_message_content,
        response_model=CorrectnessResponse,
    )
    return response.structured_output.correct
266
+
267
+
# --- Placeholder Reward Components (ideally defined elsewhere and imported) ---
# (These would typically live in a shared rewards components file or alongside the engine if very specific)
class EnronAnswerCorrectnessComponent(RewardComponent):
    """Score an answer step +1.0 if the LLM judge accepts it and -1.0 otherwise.

    Steps that are not answer actions, or that carry no ``agent_answer``,
    score 0.0.
    """

    async def score(self, state: Dict[str, Any], action: Any) -> float:
        scorable = state.get("is_answer_action") and state.get("agent_answer") is not None
        if not scorable:
            return 0.0
        judged_correct = await determine_if_answer_is_correct(
            state["question"], state["gold_answer"], state["agent_answer"]
        )
        return 1.0 if judged_correct else -1.0
280
+
281
+
class EnronStepPenaltyComponent(RewardComponent):
    """Charge a fixed penalty on every step that is not an answer action.

    Answer steps (``state["is_answer_action"]`` truthy) score 0.0, regardless
    of whether the answer turns out to be correct.
    """

    def __init__(self, penalty: float = -0.01):
        # Value returned for each non-answer step.
        self.penalty = penalty

    async def score(self, state: Dict[str, Any], action: Any) -> float:
        return 0.0 if state.get("is_answer_action") else self.penalty
@@ -0,0 +1,165 @@
1
+ # environment.py
2
+ from __future__ import annotations
3
+ from typing import List, Optional, Dict, Any, Union
4
+ from pydantic import BaseModel, Field
5
+
6
+ from synth_ai.environments.environment.tools import (
7
+ EnvToolCall,
8
+ ToolResult,
9
+ TOOL_REGISTRY,
10
+ register_tool,
11
+ )
12
+ from synth_ai.environments.environment.shared_engine import (
13
+ GetObservationCallable,
14
+ InternalObservation,
15
+ )
16
+ from synth_ai.environments.stateful.core import StatefulEnvironment
17
+ from synth_ai.environments.examples.enron.engine import (
18
+ EnronEngine,
19
+ ACTION_SEARCH,
20
+ ACTION_READ,
21
+ ACTION_ANSWER,
22
+ )
23
+ from synth_ai.environments.examples.enron.taskset import EnronTaskInstance
24
+
25
+
# -------- pydantic schemas (used by agent / LLM function calls)
# Class docstrings are deliberately omitted: pydantic would copy them into the
# JSON-schema "description" that these argument models expose.


# Arguments for the ``search_emails`` tool.
class SearchEmailsArgs(BaseModel):
    inbox: str = Field(..., description="Email address performing the search (used by tool logic)")
    keywords: List[str] = Field(..., description="Keywords to AND-search for")
    from_addr: Optional[str] = None
    to_addr: Optional[str] = None
    sent_after: Optional[str] = None
    sent_before: Optional[str] = None
    max_results: int = Field(default=10, le=10)  # capped at 10


# Arguments for the ``read_email`` tool.
class ReadEmailArgs(BaseModel):
    message_id: str


# Arguments for the ``answer_question`` tool.
class AnswerQuestionArgs(BaseModel):
    answer: str
43
+
44
+
# --------------------------------------------------------------------------- tool wrappers
# Each wrapper only stores an ``(action_kind, argument)`` tuple on ``self.action``.


class SearchEmails(EnvToolCall):
    def __init__(self, **kwargs):
        # All search keyword arguments are kept together as a single dict.
        search_kwargs = kwargs
        self.action = (ACTION_SEARCH, search_kwargs)


class ReadEmail(EnvToolCall):
    def __init__(self, message_id: str):
        self.action = (ACTION_READ, message_id)


class AnswerQuestion(EnvToolCall):
    def __init__(self, answer: str):
        self.action = (ACTION_ANSWER, answer)


# Terminating is expressed as submitting an empty answer.
class Terminate(EnvToolCall):
    def __init__(self):
        empty_answer = ""
        self.action = (ACTION_ANSWER, empty_answer)
65
+
66
+
# -------- observation callable (optional for formatted observations)
class SynthEnronObservationCallable(GetObservationCallable):
    async def get_observation(
        self, pub: Dict[str, Any], priv: Dict[str, Any]
    ) -> InternalObservation:
        """Render public/private state as a newline-separated text summary."""
        summary_lines = [
            f"Q: {pub.get('question')}",
            f"Tools: {pub.get('tools')}",
            f"Answered: {pub.get('already_answered')}",
            f"Search Res: {len(pub.get('search_results', []))} items",
            f"Email Loaded: {pub.get('email') is not None}",
            f"Tool Error: {pub.get('tool_error')}",
            f"Reward Δ: {priv.get('reward_last')}",
        ]
        return "\n".join(summary_lines)
76
+
77
+
# --------------------------------------------------------------------------- environment
class EnronEnvironment(StatefulEnvironment):
    """Tool-calling environment for question answering over the Enron e-mail DB.

    Each :meth:`step` executes one tool call, passes the tool's payload to
    ``EnronEngine._step_engine`` and returns the formatted observation.
    """

    def __init__(
        self,
        task_instance: EnronTaskInstance,
        custom_obs: Optional[GetObservationCallable] = None,
    ):
        self.engine = EnronEngine(task_instance)
        self.custom_obs = custom_obs or SynthEnronObservationCallable()
        self.name = "Enron-QA-Env"

        # Store tool instances on self for reliable access.
        # NOTE(review): SearchEmailsTool / ReadEmailTool / AnswerQuestionTool /
        # TerminateTool are neither imported nor defined in this module as
        # shipped; confirm where they are expected to come from.
        self._tools_instances = {
            "search_emails": SearchEmailsTool(self.engine),
            "read_email": ReadEmailTool(self.engine),
            "answer_question": AnswerQuestionTool(self.engine),
            "terminate": TerminateTool(self.engine),
        }
        # Register each tool unless the global registry already holds an entry
        # bound to this very engine.
        for tool_name, tool_instance in self._tools_instances.items():
            if (
                tool_name not in TOOL_REGISTRY
                or TOOL_REGISTRY[tool_name].engine is not self.engine
            ):
                register_tool(tool_instance)

    async def initialize(self) -> InternalObservation:
        """Reset the engine and return the initial observation."""
        priv, pub = await self.engine._reset_engine()
        return await self._obs(priv, pub)

    async def step(
        self,
        calls: Union[EnvToolCall, List[EnvToolCall], List[List[EnvToolCall]]],
    ) -> InternalObservation:
        """Execute the first tool call in *calls* and return the new observation.

        *calls* may be a single call, a flat list or a list of lists. After
        normalisation only ``calls[0][0]`` is executed.

        Raises:
            ValueError: if *calls* contains no call, or the tool is unknown.
            TypeError: if the first call is not an ``EnvToolCall``.
        """
        # Normalise to [[EnvToolCall, ...], ...].
        if isinstance(calls, EnvToolCall):
            calls = [[calls]]
        elif calls and isinstance(calls[0], EnvToolCall):
            calls = [calls]

        if not calls or not calls[0]:
            raise ValueError("step() requires at least one EnvToolCall.")
        call = calls[0][0]
        if not isinstance(call, EnvToolCall):
            raise TypeError(f"Processed call is not EnvToolCall: {type(call)}")

        tool_name = call.tool
        # Prefer this environment's own tool instances; fall back to the registry.
        tool_to_execute = self._tools_instances.get(tool_name)
        if not tool_to_execute:
            tool_to_execute = TOOL_REGISTRY.get(tool_name)
        if not tool_to_execute:
            raise ValueError(f"Tool '{tool_name}' not found.")

        tool_result: ToolResult = await tool_to_execute(call)

        public_payload_for_engine = (
            tool_result.payload if tool_result.ok and tool_result.payload else {}
        )
        if not tool_result.ok:
            # Surface the failure in the public state instead of raising.
            public_payload_for_engine["tool_error"] = tool_result.error

        priv, pub = await self.engine._step_engine(public_payload_for_engine)
        return await self._obs(priv, pub)

    async def terminate(self) -> InternalObservation:
        """Close the DB and return a final observation marked as terminated."""
        self.engine.close_db()
        sample = self.engine._sample()
        priv_state_on_terminate = {
            "reward_last": 0,
            "total_reward": self.engine.total_reward,
            "terminated": True,
            "truncated": False,
            "gold_answer": sample["answer"],
        }
        pub_state_on_terminate = {
            "question": sample["question"],
            "tools": [],
            "already_answered": self.engine.answered,
            "status": "terminated_by_env",
        }
        return await self._obs(priv_state_on_terminate, pub_state_on_terminate)

    async def checkpoint(self) -> InternalObservation:
        """Serialize the engine state into a plain-dict checkpoint."""
        import dataclasses

        snapshot = await self.engine._serialize_engine()
        # EnronEngineSnapshot is declared with @dataclass, which provides no
        # ``model_dump``; only non-dataclass snapshots use pydantic's method.
        if dataclasses.is_dataclass(snapshot):
            snapshot_data = dataclasses.asdict(snapshot)
        else:
            snapshot_data = snapshot.model_dump()
        return {
            "engine_snapshot": snapshot_data,
            "message": "Checkpoint created",
        }

    async def _obs(self, priv: Dict[str, Any], pub: Dict[str, Any]):
        """Format state via the custom observation callable, or merge the dicts."""
        if self.custom_obs:
            return await self.custom_obs.get_observation(pub, priv)
        return {**pub, **priv}
@@ -0,0 +1,112 @@
1
+ # taskset.py
2
+ from __future__ import annotations
3
+ import asyncio
4
+ from uuid import uuid4
5
+ import os
6
+
7
+ from datasets import load_dataset
8
+ from dataclasses import dataclass, asdict
9
+
10
+ from synth_ai.environments.tasks.core import (
11
+ Task,
12
+ TaskInstance,
13
+ TaskInstanceSet,
14
+ TaskInstanceMetadata,
15
+ SplitInfo,
16
+ Impetus,
17
+ Intent,
18
+ )
19
+
# Task-level description shared by every Enron QA instance.
enron_task = Task(
    global_premises="Answer factual questions by reading Enron e-mails",
    global_constraints="",
    global_objectives="Provide the correct answer; minimise queries",
    shared_env_params={},  # no environment parameters are shared across instances
)
26
+
27
+
# --------------------------------------------------------------------------- metadata
@dataclass
class EnronTaskInstanceMetadata(TaskInstanceMetadata):
    """Per-instance metadata for an Enron QA task."""

    split: str  # dataset split the row came from ("train" or "test")
    email_count: int  # len(message_ids)
    message_ids: list[str]  # message ids copied from the dataset row
34
+
35
+
@dataclass
class EnronTaskInstance(TaskInstance):
    """A single Enron QA task instance; the raw dataset row is its engine snapshot."""

    async def serialize(self) -> dict:
        """Return a plain-dict form of this instance with a string ``id``.

        Nested dataclass fields are converted to dicts by ``dataclasses.asdict``.
        """
        from uuid import UUID

        data = asdict(self)
        if isinstance(data.get("id"), UUID):
            data["id"] = str(data["id"])
        return data

    @classmethod
    async def deserialize(cls, data: dict) -> "EnronTaskInstance":
        """Rebuild an instance from the output of :meth:`serialize`.

        Reverses serialization: a UUID-formatted string ``id`` becomes a
        ``UUID`` again, and dict-valued ``impetus`` / ``intent`` / ``metadata``
        are rebuilt into their classes. The caller's mapping is not mutated.
        """
        from uuid import UUID

        data = dict(data)
        raw_id = data.get("id")
        if isinstance(raw_id, str):
            try:
                data["id"] = UUID(raw_id)
            except ValueError:
                pass  # Not a UUID string; keep the id as given.
        if isinstance(data.get("impetus"), dict):
            data["impetus"] = Impetus(**data["impetus"])
        if isinstance(data.get("intent"), dict):
            data["intent"] = Intent(**data["intent"])
        if isinstance(data.get("metadata"), dict):
            data["metadata"] = EnronTaskInstanceMetadata(**data["metadata"])
        return cls(**data)
47
+
48
+
# --------------------------------------------------------------------------- task-set builder
# The Hugging Face `datasets` cache lives next to this module, under ./dataset.
_MODULE_DIR = os.path.dirname(__file__)
CACHE_DIR = os.path.join(_MODULE_DIR, "dataset")
os.makedirs(CACHE_DIR, exist_ok=True)  # created eagerly, at import time
53
+
54
+
async def create_enron_taskset() -> TaskInstanceSet:
    """Build the Enron-QA task set from ``corbt/enron_emails_sample_questions``.

    Every train and test row becomes one :class:`EnronTaskInstance` whose
    ``initial_engine_snapshot`` is the raw row. Test instance ids are recorded
    in ``split_info.test_instance_ids``; there is no validation split.
    """
    loop = asyncio.get_running_loop()

    def _load_split(split: str):
        # Synchronous: may download and parse the dataset into CACHE_DIR.
        return load_dataset(
            "corbt/enron_emails_sample_questions",
            split=split,
            cache_dir=CACHE_DIR,
        )

    # Run the blocking loads in the default executor so the event loop is not
    # stalled while the dataset downloads or parses. Loads stay sequential to
    # avoid concurrent writes into the same cache directory.
    ds_train = await loop.run_in_executor(None, _load_split, "train")
    ds_test = await loop.run_in_executor(None, _load_split, "test")

    def to_instance(row: dict, split: str) -> EnronTaskInstance:
        impetus = Impetus(instructions=row["question"])
        intent = Intent(
            rubric={"goal": "Answer the question using the Enron emails."},
            gold_trajectories=None,
            gold_state_diff={"answer": row["answer"]},
        )
        metadata = EnronTaskInstanceMetadata(
            split=split,
            email_count=len(row["message_ids"]),
            message_ids=row["message_ids"],
        )
        return EnronTaskInstance(
            id=uuid4(),
            impetus=impetus,
            intent=intent,
            metadata=metadata,
            is_reproducible=True,
            initial_engine_snapshot=row,
        )

    train_instances = [to_instance(r, "train") for r in ds_train]
    test_instances = [to_instance(r, "test") for r in ds_test]

    split_info = SplitInfo(
        val_instance_ids=set(),
        test_instance_ids={inst.id for inst in test_instances},
        _is_split_defined=True,
    )

    return TaskInstanceSet(
        name="Enron-QA",
        description="QA over Enron email dataset sample.",
        instances=train_instances + test_instances,
        split_info=split_info,
    )
103
+
104
+
# quick sanity check ----------------------------------------------------------
if __name__ == "__main__":

    async def _main() -> None:
        """Build the task set once and report how many instances it contains."""
        instance_set = await create_enron_taskset()
        print(f"{len(instance_set.instances)} instances built.")

    asyncio.run(_main())
@@ -0,0 +1,48 @@
1
+ """MiniGrid environment example for synth_env.
2
+
3
+ This module provides a comprehensive implementation of MiniGrid environments
4
+ with full state management, tool-based interaction, and task generation.
5
+ """
6
+
7
+ from synth_ai.environments.examples.minigrid.engine import (
8
+ MiniGridEngine,
9
+ MiniGridPublicState,
10
+ MiniGridPrivateState,
11
+ MiniGridGoalReachedComponent,
12
+ MiniGridStepPenaltyComponent,
13
+ MiniGridObservationCallable,
14
+ MiniGridCheckpointObservationCallable,
15
+ )
16
+ from synth_ai.environments.examples.minigrid.environment import (
17
+ MiniGridEnvironment,
18
+ MiniGridInteractTool,
19
+ MiniGridActionInput,
20
+ )
21
+ from synth_ai.environments.examples.minigrid.taskset import (
22
+ MiniGridTaskInstance,
23
+ MiniGridTaskInstanceMetadata,
24
+ DEFAULT_MINIGRID_TASK,
25
+ create_minigrid_taskset,
26
+ taskset,
27
+ )
28
+
# Public API of the MiniGrid example package, grouped by defining module.
__all__ = [
    # engine
    "MiniGridEngine",
    "MiniGridPublicState",
    "MiniGridPrivateState",
    "MiniGridGoalReachedComponent",
    "MiniGridStepPenaltyComponent",
    "MiniGridObservationCallable",
    "MiniGridCheckpointObservationCallable",
    # environment
    "MiniGridEnvironment",
    "MiniGridInteractTool",
    "MiniGridActionInput",
    # taskset
    "MiniGridTaskInstance",
    "MiniGridTaskInstanceMetadata",
    "DEFAULT_MINIGRID_TASK",
    "create_minigrid_taskset",
    "taskset",
]