synth-ai 0.2.0__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/__init__.py +28 -2
- synth_ai/core/system.py +4 -0
- synth_ai/environments/__init__.py +35 -0
- synth_ai/environments/environment/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/base.py +50 -0
- synth_ai/environments/environment/core.py +22 -0
- synth_ai/environments/environment/db/__init__.py +1 -0
- synth_ai/environments/environment/db/sqlite.py +45 -0
- synth_ai/environments/environment/registry.py +24 -0
- synth_ai/environments/environment/resources/sqlite.py +46 -0
- synth_ai/environments/environment/results.py +1 -0
- synth_ai/environments/environment/rewards/__init__.py +1 -0
- synth_ai/environments/environment/rewards/core.py +28 -0
- synth_ai/environments/environment/shared_engine.py +26 -0
- synth_ai/environments/environment/tools/__init__.py +34 -0
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +58 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +51 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +872 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/test_crafter_react_agent.py +1110 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/engine.py +502 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/environment.py +255 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +228 -0
- synth_ai/environments/examples/enron/agent_demos/test_synth_react.py +535 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/enron/units/keyword_stats.py +111 -0
- synth_ai/environments/examples/enron/units/test_email_index.py +8 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +47 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +220 -0
- synth_ai/environments/examples/minigrid/agent_demos/test_minigrid_react_agent.py +393 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/minigrid/units/test_action_behavior.py +226 -0
- synth_ai/environments/examples/minigrid/units/test_debug_messages.py +83 -0
- synth_ai/environments/examples/minigrid/units/test_exploration.py +120 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_engine.py +214 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment.py +238 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment_mapping.py +301 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_taskset.py +210 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +832 -0
- synth_ai/environments/examples/nethack/agent_demos/test_nethack_react_agent.py +1112 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/nethack/units/test_nethack_engine.py +277 -0
- synth_ai/environments/examples/nethack/units/test_nethack_environment.py +281 -0
- synth_ai/environments/examples/nethack/units/test_nethack_taskset.py +213 -0
- synth_ai/environments/examples/nethack/units/test_recording.py +307 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/red/agent_demos/test_synth_react.py +1471 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/red/test_fixes.py +125 -0
- synth_ai/environments/examples/red/test_fixes_mock.py +148 -0
- synth_ai/environments/examples/red/units/__init__.py +1 -0
- synth_ai/environments/examples/red/units/test_basic_functionality.py +97 -0
- synth_ai/environments/examples/red/units/test_button_press_requirements.py +217 -0
- synth_ai/environments/examples/red/units/test_engine.py +192 -0
- synth_ai/environments/examples/red/units/test_environment.py +455 -0
- synth_ai/environments/examples/red/units/test_exploration_strategy.py +227 -0
- synth_ai/environments/examples/red/units/test_integration.py +217 -0
- synth_ai/environments/examples/red/units/test_memory_extraction.py +111 -0
- synth_ai/environments/examples/red/units/test_menu_bug_reproduction.py +1100 -0
- synth_ai/environments/examples/red/units/test_movement_debug.py +255 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_debug.py +163 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_verbose.py +117 -0
- synth_ai/environments/examples/red/units/test_red_basic.py +145 -0
- synth_ai/environments/examples/red/units/test_red_comprehensive.py +323 -0
- synth_ai/environments/examples/red/units/test_retry_movement.py +195 -0
- synth_ai/environments/examples/red/units/test_reward_components.py +186 -0
- synth_ai/environments/examples/red/units/test_rom_integration.py +260 -0
- synth_ai/environments/examples/red/units/test_taskset.py +116 -0
- synth_ai/environments/examples/red/units/test_tree.py +448 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +900 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_dspy_react.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_sokoban_react_agent.py +498 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_lats.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_locally.py +748 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_service.py +296 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/sokoban/units/astar_common.py +94 -0
- synth_ai/environments/examples/sokoban/units/test_building_task_set.py +49 -0
- synth_ai/environments/examples/sokoban/units/test_false_positive.py +120 -0
- synth_ai/environments/examples/sokoban/units/test_simple_run_through_environment.py +119 -0
- synth_ai/environments/examples/sokoban/units/test_sokoban_environment.py +98 -0
- synth_ai/environments/examples/sokoban/units/test_tree.py +364 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_synth_react.py +266 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_tictactoe_react_agent.py +470 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_engine.py +393 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_environment.py +493 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_taskset.py +191 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/agent_demos/test_synth_react.py +520 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- synth_ai/environments/examples/verilog/units/test_verilog_engine.py +466 -0
- synth_ai/environments/examples/verilog/units/test_verilog_environment.py +585 -0
- synth_ai/environments/examples/verilog/units/test_verilog_integration.py +383 -0
- synth_ai/environments/examples/verilog/units/test_verilog_taskset.py +457 -0
- synth_ai/environments/reproducibility/core.py +42 -0
- synth_ai/environments/reproducibility/tree.py +364 -0
- synth_ai/environments/service/app.py +78 -0
- synth_ai/environments/service/core_routes.py +775 -0
- synth_ai/environments/service/external_registry.py +57 -0
- synth_ai/environments/service/registry.py +9 -0
- synth_ai/environments/stateful/__init__.py +1 -0
- synth_ai/environments/stateful/core.py +28 -0
- synth_ai/environments/stateful/engine.py +21 -0
- synth_ai/environments/stateful/state.py +7 -0
- synth_ai/environments/tasks/api.py +19 -0
- synth_ai/environments/tasks/core.py +78 -0
- synth_ai/environments/tasks/filters.py +39 -0
- synth_ai/environments/tasks/utils.py +89 -0
- synth_ai/environments/v0_observability/history.py +3 -0
- synth_ai/environments/v0_observability/log.py +2 -0
- synth_ai/lm/caching/constants.py +1 -0
- synth_ai/{zyk/lms → lm}/caching/ephemeral.py +4 -8
- synth_ai/{zyk/lms → lm}/caching/handler.py +15 -15
- synth_ai/{zyk/lms → lm}/caching/initialize.py +2 -4
- synth_ai/{zyk/lms → lm}/caching/persistent.py +4 -10
- synth_ai/{zyk/lms → lm}/config.py +2 -1
- synth_ai/{zyk/lms → lm}/constants.py +2 -2
- synth_ai/{zyk/lms → lm}/core/all.py +10 -10
- synth_ai/{zyk/lms → lm}/core/main.py +57 -33
- synth_ai/{zyk/lms → lm}/core/vendor_clients.py +12 -10
- synth_ai/lm/cost/monitor.py +1 -0
- synth_ai/lm/cost/statefulness.py +1 -0
- synth_ai/lm/provider_support/__init__.py +8 -0
- synth_ai/lm/provider_support/anthropic.py +945 -0
- synth_ai/lm/provider_support/openai.py +1115 -0
- synth_ai/lm/provider_support/suppress_logging.py +31 -0
- synth_ai/{zyk/lms → lm}/structured_outputs/handler.py +58 -80
- synth_ai/{zyk/lms → lm}/structured_outputs/inject.py +6 -20
- synth_ai/{zyk/lms → lm}/structured_outputs/rehabilitate.py +6 -12
- synth_ai/{zyk/lms → lm}/vendors/core/anthropic_api.py +21 -30
- synth_ai/{zyk/lms → lm}/vendors/core/gemini_api.py +35 -32
- synth_ai/{zyk/lms → lm}/vendors/core/mistral_api.py +19 -28
- synth_ai/{zyk/lms → lm}/vendors/core/openai_api.py +26 -36
- synth_ai/{zyk/lms → lm}/vendors/openai_standard.py +29 -33
- synth_ai/{zyk/lms → lm}/vendors/retries.py +1 -1
- synth_ai/lm/vendors/supported/__init__.py +0 -0
- synth_ai/{zyk/lms → lm}/vendors/supported/custom_endpoint.py +131 -118
- synth_ai/{zyk/lms → lm}/vendors/supported/deepseek.py +4 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/grok.py +6 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/groq.py +1 -1
- synth_ai/{zyk/lms → lm}/vendors/supported/ollama.py +2 -2
- synth_ai/{zyk/lms → lm}/vendors/supported/openrouter.py +18 -16
- synth_ai/{zyk/lms → lm}/vendors/supported/together.py +1 -1
- synth_ai/tracing/__init__.py +0 -0
- synth_ai/tracing/abstractions.py +224 -0
- synth_ai/tracing/base_client.py +91 -0
- synth_ai/tracing/client_manager.py +131 -0
- synth_ai/tracing/config.py +140 -0
- synth_ai/tracing/context.py +146 -0
- synth_ai/tracing/decorators.py +679 -0
- synth_ai/tracing/events/__init__.py +0 -0
- synth_ai/tracing/events/manage.py +147 -0
- synth_ai/tracing/events/scope.py +86 -0
- synth_ai/tracing/events/store.py +227 -0
- synth_ai/tracing/immediate_client.py +152 -0
- synth_ai/tracing/local.py +18 -0
- synth_ai/tracing/log_client_base.py +74 -0
- synth_ai/tracing/retry_queue.py +187 -0
- synth_ai/tracing/trackers.py +515 -0
- synth_ai/tracing/upload.py +504 -0
- synth_ai/tracing/utils.py +9 -0
- synth_ai/zyk/__init__.py +28 -2
- synth_ai-0.2.1.dev0.dist-info/METADATA +349 -0
- synth_ai-0.2.1.dev0.dist-info/RECORD +261 -0
- {synth_ai-0.2.0.dist-info → synth_ai-0.2.1.dev0.dist-info}/WHEEL +1 -1
- synth_ai/zyk/lms/caching/constants.py +0 -1
- synth_ai/zyk/lms/cost/monitor.py +0 -1
- synth_ai/zyk/lms/cost/statefulness.py +0 -1
- synth_ai-0.2.0.dist-info/METADATA +0 -36
- synth_ai-0.2.0.dist-info/RECORD +0 -50
- /synth_ai/{zyk/lms/__init__.py → environments/reproducibility/helpers.py} +0 -0
- /synth_ai/{zyk/lms/caching → lm}/__init__.py +0 -0
- /synth_ai/{zyk/lms/core → lm/caching}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/caching/dbs.py +0 -0
- /synth_ai/{zyk/lms/cost → lm/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/core/exceptions.py +0 -0
- /synth_ai/{zyk/lms/structured_outputs → lm/cost}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors → lm/structured_outputs}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/core → lm/vendors}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/local → lm/vendors/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors/supported → lm/vendors/local}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/local/ollama.py +0 -0
- {synth_ai-0.2.0.dist-info → synth_ai-0.2.1.dev0.dist-info/licenses}/LICENSE +0 -0
- {synth_ai-0.2.0.dist-info → synth_ai-0.2.1.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,280 @@
|
|
1
|
+
import sqlite3
|
2
|
+
import os
|
3
|
+
import logging
|
4
|
+
from datasets import load_dataset, Dataset, Features, Value, Sequence
|
5
|
+
from tqdm import tqdm
|
6
|
+
from datetime import datetime
|
7
|
+
|
8
|
+
# Resolve paths relative to this file so it works regardless of the current working directory
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Database will live in "../data/enron_emails.db" relative to project root
DEFAULT_DB_PATH = os.path.join(BASE_DIR, "..", "..", "data", "enron_emails.db")

# Hugging Face dataset repository the raw e-mails are downloaded from.
DEFAULT_REPO_ID = "corbt/enron-emails"

# NOTE(review): configuring the root logger at import time affects the whole
# process; fine for script-style use, but confirm no library callers mind.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
16
|
+
|
17
|
+
|
18
|
+
# --- Database Schema ---
# Rebuilds the schema from scratch: dependents are dropped first (recipients,
# emails_fts) so nothing references the emails table when it is dropped.
# emails.date is deliberately TEXT (SQLite has no native datetime type).
SQL_CREATE_TABLES = """
DROP TABLE IF EXISTS recipients;
DROP TABLE IF EXISTS emails_fts;
DROP TABLE IF EXISTS emails;

CREATE TABLE emails (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    message_id TEXT UNIQUE,
    subject TEXT,
    from_address TEXT,
    date TEXT, -- Store as ISO 8601 string 'YYYY-MM-DD HH:MM:SS'
    body TEXT,
    file_name TEXT
);

CREATE TABLE recipients (
    email_id INTEGER,
    recipient_address TEXT,
    recipient_type TEXT, -- 'to', 'cc', 'bcc'
    FOREIGN KEY(email_id) REFERENCES emails(id) ON DELETE CASCADE
);
"""
|
41
|
+
|
42
|
+
SQL_CREATE_INDEXES_TRIGGERS = """
|
43
|
+
CREATE INDEX idx_emails_from ON emails(from_address);
|
44
|
+
CREATE INDEX idx_emails_date ON emails(date);
|
45
|
+
CREATE INDEX idx_emails_message_id ON emails(message_id);
|
46
|
+
CREATE INDEX idx_recipients_address ON recipients(recipient_address);
|
47
|
+
CREATE INDEX idx_recipients_type ON recipients(recipient_type);
|
48
|
+
CREATE INDEX idx_recipients_email_id ON recipients(email_id);
|
49
|
+
CREATE INDEX idx_recipients_address_email ON recipients(recipient_address, email_id);
|
50
|
+
|
51
|
+
CREATE VIRTUAL TABLE emails_fts USING fts5(
|
52
|
+
subject,
|
53
|
+
body,
|
54
|
+
content='emails',
|
55
|
+
content_rowid='id'
|
56
|
+
);
|
57
|
+
|
58
|
+
CREATE TRIGGER emails_ai AFTER INSERT ON emails BEGIN
|
59
|
+
INSERT INTO emails_fts (rowid, subject, body)
|
60
|
+
VALUES (new.id, new.subject, new.body);
|
61
|
+
END;
|
62
|
+
|
63
|
+
CREATE TRIGGER emails_ad AFTER DELETE ON emails BEGIN
|
64
|
+
DELETE FROM emails_fts WHERE rowid=old.id;
|
65
|
+
END;
|
66
|
+
|
67
|
+
CREATE TRIGGER emails_au AFTER UPDATE ON emails BEGIN
|
68
|
+
UPDATE emails_fts SET subject=new.subject, body=new.body WHERE rowid=old.id;
|
69
|
+
END;
|
70
|
+
|
71
|
+
INSERT INTO emails_fts (rowid, subject, body) SELECT id, subject, body FROM emails;
|
72
|
+
"""
|
73
|
+
|
74
|
+
|
75
|
+
# --- Functions ---
|
76
|
+
|
77
|
+
|
78
|
+
def download_dataset(repo_id: str) -> Dataset:
    """Download the Enron e-mail dataset from the Hugging Face Hub.

    The feature schema is pinned explicitly so that upstream schema drift
    surfaces as a load-time error instead of silently changing types.

    Args:
        repo_id: Hugging Face dataset repository ID.

    Returns:
        The ``train`` split as a ``datasets.Dataset``.

    Raises:
        TypeError: if ``load_dataset`` returns anything but a ``Dataset``.
    """
    logging.info(f"Attempting to download dataset from Hugging Face Hub: {repo_id}")
    schema = {
        "message_id": Value("string"),
        "subject": Value("string"),
        "from": Value("string"),
        "to": Sequence(Value("string")),
        "cc": Sequence(Value("string")),
        "bcc": Sequence(Value("string")),
        "date": Value("timestamp[us]"),
        "body": Value("string"),
        "file_name": Value("string"),
    }
    dataset_obj = load_dataset(repo_id, features=Features(schema), split="train")
    # Basic type check remains useful
    if not isinstance(dataset_obj, Dataset):
        raise TypeError(f"Expected Dataset, got {type(dataset_obj)}")
    logging.info(f"Successfully loaded dataset '{repo_id}' with {len(dataset_obj)} records.")
    return dataset_obj
|
100
|
+
|
101
|
+
|
102
|
+
def create_database(db_path: str):
    """Create the SQLite database file and its base tables.

    Runs ``SQL_CREATE_TABLES``, which drops and recreates the schema, so any
    existing data at *db_path* is wiped.

    Args:
        db_path: Filesystem path of the SQLite database file to (re)create.
    """
    logging.info(f"Creating SQLite database and tables at: {db_path}")
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.executescript(SQL_CREATE_TABLES)
        conn.commit()
    finally:
        # Close even when the script fails so the file handle isn't leaked
        # and the database isn't left locked.
        conn.close()
    logging.info("Database tables created successfully.")
|
111
|
+
|
112
|
+
|
113
|
+
def populate_database(db_path: str, dataset: Dataset):
    """Populates the database with data from the Hugging Face dataset.

    Bulk-inserts rows into ``emails`` and ``recipients`` inside one big
    transaction, skipping over-long emails, mass mailings, and duplicates.
    Intended to run on a freshly created schema (before indexes/triggers).
    """
    logging.info(f"Populating database {db_path}...")
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    # --- Performance Pragmas ---
    # Trade durability for speed: acceptable because the DB is fully
    # regenerated from the source dataset if anything goes wrong.
    conn.execute("PRAGMA synchronous = OFF;")
    conn.execute("PRAGMA journal_mode = MEMORY;")

    record_count = 0
    skipped_count = 0  # Keep track of skipped emails due to filters
    duplicate_count = 0  # Keep track of skipped duplicate emails
    processed_emails = set()  # Track (subject, body, from) tuples to dedupe

    conn.execute("BEGIN TRANSACTION;")  # Single transaction for bulk insert

    for email_data in tqdm(dataset, desc="Inserting emails"):
        assert isinstance(email_data, dict)
        message_id = email_data["message_id"]
        subject = email_data["subject"]
        from_address = email_data["from"]
        # Dataset schema declares "date" as timestamp[us], so this arrives as
        # a datetime — presumably naive/UTC; TODO confirm against the dataset.
        date_obj: datetime = email_data["date"]
        body = email_data["body"]
        file_name = email_data["file_name"]
        to_list_raw = email_data["to"]
        cc_list_raw = email_data["cc"]
        bcc_list_raw = email_data["bcc"]

        # --- Data Cleaning and Filtering ---
        date_str = date_obj.strftime("%Y-%m-%d %H:%M:%S")
        # Drop empty/None recipient entries and coerce everything to str.
        to_list = [str(addr) for addr in to_list_raw if addr]
        cc_list = [str(addr) for addr in cc_list_raw if addr]
        bcc_list = [str(addr) for addr in bcc_list_raw if addr]

        # Check body length
        if len(body) > 5000:
            logging.debug(f"Skipping email {message_id}: Body length > 5000 characters.")
            skipped_count += 1
            continue

        # Check total recipients — filters out mass mailings.
        total_recipients = len(to_list) + len(cc_list) + len(bcc_list)
        if total_recipients > 30:
            logging.debug(
                f"Skipping email {message_id}: Total recipients ({total_recipients}) > 30."
            )
            skipped_count += 1
            continue
        # --- End Filtering ---

        # --- Deduplication Check ---
        # Dedup key is content-based (subject, body, from), not message_id.
        email_key = (subject, body, from_address)
        if email_key in processed_emails:
            logging.debug(
                f"Skipping duplicate email (Subject: {subject[:50]}..., From: {from_address})"
            )
            duplicate_count += 1
            continue
        else:
            processed_emails.add(email_key)
        # --- End Deduplication ---

        cursor.execute(
            """
            INSERT INTO emails (message_id, subject, from_address, date, body, file_name)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            (message_id, subject, from_address, date_str, body, file_name),
        )
        # lastrowid of the parent insert links the recipients rows below.
        email_pk_id = cursor.lastrowid

        recipient_data = []
        for addr in to_list:
            recipient_data.append((email_pk_id, addr, "to"))
        for addr in cc_list:
            recipient_data.append((email_pk_id, addr, "cc"))
        for addr in bcc_list:
            recipient_data.append((email_pk_id, addr, "bcc"))

        if recipient_data:
            cursor.executemany(
                """
                INSERT INTO recipients (email_id, recipient_address, recipient_type)
                VALUES (?, ?, ?)
                """,
                recipient_data,
            )
        record_count += 1

    conn.commit()
    conn.close()
    logging.info(f"Successfully inserted {record_count} email records.")
    if skipped_count > 0:
        logging.info(f"Skipped {skipped_count} email records due to length or recipient limits.")
    if duplicate_count > 0:
        logging.info(
            f"Skipped {duplicate_count} duplicate email records (based on subject, body, from)."
        )
|
212
|
+
|
213
|
+
|
214
|
+
def create_indexes_and_triggers(db_path: str):
    """Create indexes, the FTS table, and sync triggers on a populated DB.

    Runs ``SQL_CREATE_INDEXES_TRIGGERS``; call this after
    ``populate_database`` so bulk inserts aren't slowed by index maintenance.

    Args:
        db_path: Filesystem path of the populated SQLite database.
    """
    logging.info(f"Creating indexes and triggers for database: {db_path}...")
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.executescript(SQL_CREATE_INDEXES_TRIGGERS)
        conn.commit()
    finally:
        # Close on failure too, so the database file isn't left locked.
        conn.close()
    logging.info("Indexes and triggers created successfully.")
|
223
|
+
|
224
|
+
|
225
|
+
def generate_database(overwrite: bool = False):
    """
    Generate the SQLite database from the Hugging Face dataset.

    Downloads ``DEFAULT_REPO_ID`` and builds the database at
    ``DEFAULT_DB_PATH`` (both module-level constants). If the database file
    already exists and *overwrite* is False, this is a no-op.

    Args:
        overwrite: If True, any existing database file at DEFAULT_DB_PATH
            will be removed and regenerated.
    """
    logging.info(
        f"Starting database generation for repo '{DEFAULT_REPO_ID}' at '{DEFAULT_DB_PATH}'"
    )
    logging.info(f"Overwrite existing database: {overwrite}")

    # Ensure the parent directory exists before sqlite tries to create the file.
    db_dir = os.path.dirname(DEFAULT_DB_PATH)
    if db_dir and not os.path.exists(db_dir):
        logging.info(f"Creating data directory: {db_dir}")
        os.makedirs(db_dir)

    if overwrite and os.path.exists(DEFAULT_DB_PATH):
        logging.info(f"Removing existing database file: {DEFAULT_DB_PATH}")
        os.remove(DEFAULT_DB_PATH)

    if not os.path.exists(DEFAULT_DB_PATH):
        dataset = download_dataset(DEFAULT_REPO_ID)
        create_database(DEFAULT_DB_PATH)
        populate_database(DEFAULT_DB_PATH, dataset)
        create_indexes_and_triggers(DEFAULT_DB_PATH)

        # Ensure a UNIQUE index on emails.message_id exists.
        # NOTE(review): SQL_CREATE_INDEXES_TRIGGERS already creates a
        # NON-unique index with this exact name, so "IF NOT EXISTS" makes
        # this statement a no-op in practice. Kept as a belt-and-braces
        # guard, but confirm whether uniqueness enforcement is required
        # (the column is already declared UNIQUE in the table schema).
        conn = sqlite3.connect(DEFAULT_DB_PATH)
        try:
            cur = conn.cursor()
            logging.info("Ensuring UNIQUE index on emails.message_id exists...")
            cur.executescript(
                """
                CREATE UNIQUE INDEX IF NOT EXISTS idx_emails_message_id
                ON emails(message_id);
                """
            )
            conn.commit()
        finally:
            # Close even on failure so the DB file isn't left locked.
            conn.close()
        logging.info("UNIQUE index on emails.message_id verified/created.")
    else:
        logging.info(
            f"Database already exists at {DEFAULT_DB_PATH}. Set overwrite=True to regenerate."
        )

    logging.info("Database generation process complete.")
|
277
|
+
|
278
|
+
|
279
|
+
if __name__ == "__main__":
    # When run as a script, rebuild the database from scratch.
    generate_database(overwrite=True)
|
@@ -0,0 +1,24 @@
|
|
1
|
+
from pydantic import BaseModel
|
2
|
+
from typing import List, Optional
|
3
|
+
|
4
|
+
|
5
|
+
class SyntheticQuery(BaseModel):
    """A synthetic question/answer pair posed over the Enron e-mail corpus."""

    id: int
    question: str
    answer: str
    message_ids: List[str]  # message_ids (strings) of referenced emails
    how_realistic: float  # realism score — presumably 0..1; TODO confirm range
    inbox_address: str  # inbox the question is asked from — verify against generator
    query_date: str  # date context for the query — format not enforced here
|
13
|
+
|
14
|
+
|
15
|
+
class Email(BaseModel):
    """A single e-mail row, with recipient lists joined from the recipients table."""

    message_id: str
    date: str  # ISO 8601 string 'YYYY-MM-DD HH:MM:SS'
    subject: Optional[str] = None
    from_address: Optional[str] = None
    # Mutable `[]` defaults are safe here: pydantic deep-copies field
    # defaults per instance (unlike plain Python default arguments).
    to_addresses: List[str] = []  # Populated from recipients table
    cc_addresses: List[str] = []  # Populated from recipients table
    bcc_addresses: List[str] = []  # Populated from recipients table
    body: Optional[str] = None
    file_name: Optional[str] = None
|
@@ -0,0 +1,291 @@
|
|
1
|
+
# engine.py
|
2
|
+
from __future__ import annotations
|
3
|
+
from dataclasses import dataclass, asdict
|
4
|
+
from typing import Any, Dict, Tuple, Optional, List
|
5
|
+
from pydantic import BaseModel
|
6
|
+
from pathlib import Path
|
7
|
+
|
8
|
+
from synth_ai.environments.examples.enron.art_helpers.types_enron import Email
|
9
|
+
from synth_ai.environments.examples.enron.art_helpers.email_search_tools import (
|
10
|
+
search_emails as helper_search_emails,
|
11
|
+
read_email as helper_read_email,
|
12
|
+
SearchResult,
|
13
|
+
)
|
14
|
+
|
15
|
+
# SQLite-backed helpers
|
16
|
+
from synth_ai.environments.stateful.engine import StatefulEngine, StatefulEngineSnapshot
|
17
|
+
from synth_ai.environments.examples.enron.taskset import EnronTaskInstance
|
18
|
+
from synth_ai.zyk import LM # Import LM class
|
19
|
+
|
20
|
+
from synth_ai.environments.environment.db.sqlite import SQLiteManager
|
21
|
+
from synth_ai.environments.environment.rewards.core import RewardStack, RewardComponent
|
22
|
+
from synth_ai.environments.examples.enron.art_helpers.local_email_db import (
|
23
|
+
DEFAULT_DB_PATH,
|
24
|
+
generate_database,
|
25
|
+
)
|
26
|
+
|
27
|
+
# --------------------------------------------------------------------------- actions
# Action kinds accepted by EnronEngine: the first element of the
# (kind, arg) action tuple described in the class docstring.
ACTION_SEARCH = "search"
ACTION_READ = "read"
ACTION_ANSWER = "answer"
|
31
|
+
|
32
|
+
|
33
|
+
# --------------------------------------------------------------------------- snapshot
|
34
|
+
@dataclass
class EnronEngineSnapshot(StatefulEngineSnapshot):
    """Serializable capture of an EnronEngine episode's mutable state.

    Produced by EnronEngine._serialize_engine and consumed positionally by
    EnronEngine._deserialize_engine, so field order matters.
    """

    idx: int  # episode index; engine sets it to 0 and does not advance it per-instance
    answered: bool  # whether an answer action has been submitted this episode
    total_reward: float  # cumulative reward accrued so far
    partial_rewards: List[float]  # per-step reward history (engine's rewards_history)
|
40
|
+
|
41
|
+
|
42
|
+
# --------------------------------------------------------------------------- engine
|
43
|
+
class EnronEngine(StatefulEngine):
    """
    Minimal state-machine around the corbt/enron_emails_sample_questions dataset.
    Action is a tuple(kind, arg):

        (ACTION_SEARCH, query: str) → returns {"search_results": [message_ids]}
        (ACTION_READ, message_id: str) → returns {"email_body": str}
        (ACTION_ANSWER, answer: str) → rewards +1 / -1 and terminates
    """

    # ----------------------------- init / helpers
    def __init__(self, task_instance: EnronTaskInstance):
        """Bind the engine to one task instance; open the e-mail DB and build the reward stack.

        Generates the SQLite database on first use if it is missing, then
        opens it read-only for the episode.
        """
        # Use the provided TaskInstance snapshot for this episode
        self.instance = task_instance
        self.answered = False
        self.total_reward = 0.0
        self.idx = 0
        # List to track each step's reward
        self.rewards_history: List[float] = []

        db_file_path = Path(DEFAULT_DB_PATH)
        if not db_file_path.exists():
            generate_database(overwrite=False)  # Ensure DB exists
        self.sqlite_manager = SQLiteManager(db_path=db_file_path, read_only=True)

        # RewardStack is an attribute of the engine; its calculations update private_state fields
        self.reward_stack = RewardStack(
            components=[
                EnronAnswerCorrectnessComponent(),
                EnronStepPenaltyComponent(penalty=-0.05),
            ]
        )
        # This will hold the specific arguments/details of the current agent action
        # for the reward components to inspect.
        self._current_action_details_for_reward: Optional[Dict[str, Any]] = None

    def _sample(self) -> Dict[str, Any]:
        """Return the static Q/A snapshot dict for this episode."""
        # Return the snapshot dict from the TaskInstance
        return self.instance.initial_engine_snapshot

    # ----------------------------- step / reset
    async def _step_engine(
        self, tool_output_payload: Optional[Dict[str, Any]]
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Score the pending action and build the (private, public) state pair.

        Must run after one of the *_action methods has recorded
        _current_action_details_for_reward; the reward calculation consumes
        and clears that record. Termination is driven solely by the
        `answered` flag.
        """
        r = await self._calculate_and_apply_reward()

        # Determine termination: if an answer was attempted, task terminates.
        # The 'answered' flag is set by answer_question_action.
        term = self.answered

        s = self._sample()
        priv = {
            "reward_last": r,
            "total_reward": self.total_reward,
            "terminated": term,
            "truncated": False,
            "gold_answer": s["answer"],
        }

        # Public state combines static elements with dynamic ones from tool_output_payload
        pub = {
            "question": s["question"],
            "tools": [
                "search_emails",
                "read_email",
                "answer_question",
                "terminate",
            ],  # Available tools
            "already_answered": self.answered,
            "query_date": s.get("query_date", "<unknown date>"),
            "inbox_address": s.get("inbox_address", "<unknown_inbox>"),
            # Default empty values, to be overwritten by tool_output_payload if present
            "search_results": [],
            "email": None,
            **(tool_output_payload if tool_output_payload else {}),
        }

        return priv, pub

    async def _reset_engine(
        self, *, seed: Optional[int] = None
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        Advance to the next Q-A pair and emit an initial observation **without**
        issuing an empty-keyword DB search (which would raise).
        """
        # Reset answered status and total reward for this instance
        self.answered = False
        self.total_reward = 0.0
        self.rewards_history = []
        self._current_action_details_for_reward = None
        # self.sqlite_manager.reset()  # Enron DB is read-only; reset usually not needed unless switching DB files.

        s = self._sample()
        priv = {
            "reward_last": 0.0,
            "total_reward": 0.0,
            "terminated": False,
            "truncated": False,
            "gold_answer": s["answer"],
        }
        pub = {
            "question": s["question"],
            "tools": ["search_emails", "read_email", "answer_question", "terminate"],
            "already_answered": False,
            "query_date": s.get("query_date", "<unknown date>"),
            "inbox_address": s.get("inbox_address", "<unknown_inbox>"),
            "search_results": [],
            "email": None,
        }
        # No index advancement needed when using a single TaskInstance
        return priv, pub

    # ----------------------------- serialization helpers
    async def _serialize_engine(self) -> EnronEngineSnapshot:
        """Capture mutable episode state; fields are positional (see snapshot class)."""
        # Include partial rewards history in the snapshot
        return EnronEngineSnapshot(
            self.idx,
            self.answered,
            self.total_reward,
            self.rewards_history,
        )

    @classmethod
    async def _deserialize_engine(
        cls, snap: EnronEngineSnapshot, task_instance: EnronTaskInstance
    ) -> "EnronEngine":
        """Rebuild an engine from a snapshot plus its task instance."""
        eng = cls(task_instance)
        eng.idx = snap.idx
        eng.answered = snap.answered
        eng.total_reward = snap.total_reward
        eng.rewards_history = (
            snap.partial_rewards
        )  # Ensure this is correctly typed in Pydantic model if not List[float]
        # Note: SQLiteManager is re-initialized in __init__ based on DEFAULT_DB_PATH.
        # If the db path could change per instance/snapshot, that would need to be part of the snapshot.
        return eng

    def close_db(self) -> None:
        """Release the read-only SQLite connection held by this engine."""
        self.sqlite_manager.close()

    async def _calculate_and_apply_reward(self) -> float:
        """Score the recorded action via the RewardStack and accumulate the result.

        Consumes _current_action_details_for_reward (resetting it to None) so
        each recorded action is scored exactly once; updates total_reward and
        rewards_history as side effects and returns this step's reward.
        """
        s = self._sample()
        reward_context_state = {  # State snapshot for reward calculation
            "question": s["question"],
            "gold_answer": s["answer"],
            **(
                self._current_action_details_for_reward
                if self._current_action_details_for_reward
                else {}
            ),
        }

        # The 'action' param for score can be the conceptual action type or detailed args
        action_param_for_score = (
            self._current_action_details_for_reward
            if self._current_action_details_for_reward
            else {}
        )

        reward = await self.reward_stack.step_reward(
            state=reward_context_state, action=action_param_for_score
        )

        self.total_reward += reward
        self.rewards_history.append(reward)
        self._current_action_details_for_reward = None  # Reset after use
        return reward

    async def search_emails_action(self, search_args: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run a keyword search and record the action for reward scoring.

        Returns the search hits as plain dicts (one per SearchResult).
        """
        res: List[SearchResult] = helper_search_emails(self.sqlite_manager, **search_args)
        self._current_action_details_for_reward = {"type": "search", **search_args}
        return [asdict(item) for item in res]

    async def read_email_action(self, message_id: str) -> Optional[Dict[str, Any]]:
        """Fetch one e-mail by id and record the action; None when not found."""
        email: Optional[Email] = helper_read_email(self.sqlite_manager, message_id)
        self._current_action_details_for_reward = {
            "type": "read",
            "message_id": message_id,
        }
        return email.dict() if email else None

    async def answer_question_action(self, agent_answer: str) -> None:
        """Record an answer attempt; scoring happens later in the step's reward pass."""
        # This method now primarily sets up state for reward calculation.
        # The actual reward value and termination status are determined by _get_reward_and_update_state.
        s = self._sample()
        self._current_action_details_for_reward = {
            "type": "answer",
            "is_answer_action": True,  # Signal for reward component
            "question": s["question"],
            "gold_answer": s["answer"],
            "agent_answer": agent_answer,
        }
        self.answered = True  # Mark as answered, termination decided by reward logic
|
237
|
+
|
238
|
+
|
239
|
+
# ----------------------------- LLM Judge for answers
|
240
|
+
async def determine_if_answer_is_correct(
    question: str, gold_answer: str, agent_answer: str
) -> bool:
    """Ask an LLM judge whether *agent_answer* matches *gold_answer* for *question*."""

    class CorrectnessResponse(BaseModel):
        # Structured verdict parsed out of the judge's reply.
        correct: bool

    # Fresh judge per call; temperature 0 keeps the verdict deterministic.
    judge = LM(model_name="gpt-4.1-nano", formatting_model_name="gpt-4.1-nano", temperature=0.0)

    sys_msg = (
        "You will be given a question and two different answers to the question, "
        "the correct answer and the answer given by an AI. Your job is to determine "
        "if the answer given by the AI is correct."
    )
    usr_msg = (
        f"Question: {question}\nCorrect answer: {gold_answer}\nAI answer: {agent_answer}"
    )

    # Use LM.respond_async; caching, if any, is handled inside the LM class.
    verdict = await judge.respond_async(
        system_message=sys_msg,
        user_message=usr_msg,
        response_model=CorrectnessResponse,
    )
    return verdict.structured_output.correct
|
266
|
+
|
267
|
+
|
268
|
+
# --- Placeholder Reward Components (ideally defined elsewhere and imported) ---
|
269
|
+
# (These would typically live in a shared rewards components file or alongside the engine if very specific)
|
270
|
+
class EnronAnswerCorrectnessComponent(RewardComponent):
    """Reward +1.0 / -1.0 for a judged answer attempt; 0.0 for any other action."""

    async def score(self, state: Dict[str, Any], action: Any) -> float:
        # Only answer actions that actually carry an agent answer are judged.
        if not state.get("is_answer_action") or state.get("agent_answer") is None:
            return 0.0
        # Delegate correctness to the module-level LLM judge.
        is_correct = await determine_if_answer_is_correct(
            state["question"], state["gold_answer"], state["agent_answer"]
        )
        return 1.0 if is_correct else -1.0
|
280
|
+
|
281
|
+
|
282
|
+
class EnronStepPenaltyComponent(RewardComponent):
    """Flat negative reward for every non-answer step, to discourage dithering."""

    def __init__(self, penalty: float = -0.01):
        self.penalty = penalty

    async def score(self, state: Dict[str, Any], action: Any) -> float:
        # Answer attempts are exempt; every other action costs `penalty`.
        return 0.0 if state.get("is_answer_action") else self.penalty
|