synth-ai 0.2.0__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/__init__.py +28 -2
- synth_ai/core/system.py +4 -0
- synth_ai/environments/__init__.py +35 -0
- synth_ai/environments/environment/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/base.py +50 -0
- synth_ai/environments/environment/core.py +22 -0
- synth_ai/environments/environment/db/__init__.py +1 -0
- synth_ai/environments/environment/db/sqlite.py +45 -0
- synth_ai/environments/environment/registry.py +24 -0
- synth_ai/environments/environment/resources/sqlite.py +46 -0
- synth_ai/environments/environment/results.py +1 -0
- synth_ai/environments/environment/rewards/__init__.py +1 -0
- synth_ai/environments/environment/rewards/core.py +28 -0
- synth_ai/environments/environment/shared_engine.py +26 -0
- synth_ai/environments/environment/tools/__init__.py +34 -0
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +58 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +51 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +872 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/test_crafter_react_agent.py +1110 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/engine.py +502 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/environment.py +255 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +228 -0
- synth_ai/environments/examples/enron/agent_demos/test_synth_react.py +535 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/enron/units/keyword_stats.py +111 -0
- synth_ai/environments/examples/enron/units/test_email_index.py +8 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +47 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +220 -0
- synth_ai/environments/examples/minigrid/agent_demos/test_minigrid_react_agent.py +393 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/minigrid/units/test_action_behavior.py +226 -0
- synth_ai/environments/examples/minigrid/units/test_debug_messages.py +83 -0
- synth_ai/environments/examples/minigrid/units/test_exploration.py +120 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_engine.py +214 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment.py +238 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment_mapping.py +301 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_taskset.py +210 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +832 -0
- synth_ai/environments/examples/nethack/agent_demos/test_nethack_react_agent.py +1112 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/nethack/units/test_nethack_engine.py +277 -0
- synth_ai/environments/examples/nethack/units/test_nethack_environment.py +281 -0
- synth_ai/environments/examples/nethack/units/test_nethack_taskset.py +213 -0
- synth_ai/environments/examples/nethack/units/test_recording.py +307 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/red/agent_demos/test_synth_react.py +1471 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/red/test_fixes.py +125 -0
- synth_ai/environments/examples/red/test_fixes_mock.py +148 -0
- synth_ai/environments/examples/red/units/__init__.py +1 -0
- synth_ai/environments/examples/red/units/test_basic_functionality.py +97 -0
- synth_ai/environments/examples/red/units/test_button_press_requirements.py +217 -0
- synth_ai/environments/examples/red/units/test_engine.py +192 -0
- synth_ai/environments/examples/red/units/test_environment.py +455 -0
- synth_ai/environments/examples/red/units/test_exploration_strategy.py +227 -0
- synth_ai/environments/examples/red/units/test_integration.py +217 -0
- synth_ai/environments/examples/red/units/test_memory_extraction.py +111 -0
- synth_ai/environments/examples/red/units/test_menu_bug_reproduction.py +1100 -0
- synth_ai/environments/examples/red/units/test_movement_debug.py +255 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_debug.py +163 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_verbose.py +117 -0
- synth_ai/environments/examples/red/units/test_red_basic.py +145 -0
- synth_ai/environments/examples/red/units/test_red_comprehensive.py +323 -0
- synth_ai/environments/examples/red/units/test_retry_movement.py +195 -0
- synth_ai/environments/examples/red/units/test_reward_components.py +186 -0
- synth_ai/environments/examples/red/units/test_rom_integration.py +260 -0
- synth_ai/environments/examples/red/units/test_taskset.py +116 -0
- synth_ai/environments/examples/red/units/test_tree.py +448 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +900 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_dspy_react.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_sokoban_react_agent.py +498 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_lats.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_locally.py +748 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_service.py +296 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/sokoban/units/astar_common.py +94 -0
- synth_ai/environments/examples/sokoban/units/test_building_task_set.py +49 -0
- synth_ai/environments/examples/sokoban/units/test_false_positive.py +120 -0
- synth_ai/environments/examples/sokoban/units/test_simple_run_through_environment.py +119 -0
- synth_ai/environments/examples/sokoban/units/test_sokoban_environment.py +98 -0
- synth_ai/environments/examples/sokoban/units/test_tree.py +364 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_synth_react.py +266 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_tictactoe_react_agent.py +470 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_engine.py +393 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_environment.py +493 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_taskset.py +191 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/agent_demos/test_synth_react.py +520 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- synth_ai/environments/examples/verilog/units/test_verilog_engine.py +466 -0
- synth_ai/environments/examples/verilog/units/test_verilog_environment.py +585 -0
- synth_ai/environments/examples/verilog/units/test_verilog_integration.py +383 -0
- synth_ai/environments/examples/verilog/units/test_verilog_taskset.py +457 -0
- synth_ai/environments/reproducibility/core.py +42 -0
- synth_ai/environments/reproducibility/tree.py +364 -0
- synth_ai/environments/service/app.py +78 -0
- synth_ai/environments/service/core_routes.py +775 -0
- synth_ai/environments/service/external_registry.py +57 -0
- synth_ai/environments/service/registry.py +9 -0
- synth_ai/environments/stateful/__init__.py +1 -0
- synth_ai/environments/stateful/core.py +28 -0
- synth_ai/environments/stateful/engine.py +21 -0
- synth_ai/environments/stateful/state.py +7 -0
- synth_ai/environments/tasks/api.py +19 -0
- synth_ai/environments/tasks/core.py +78 -0
- synth_ai/environments/tasks/filters.py +39 -0
- synth_ai/environments/tasks/utils.py +89 -0
- synth_ai/environments/v0_observability/history.py +3 -0
- synth_ai/environments/v0_observability/log.py +2 -0
- synth_ai/lm/caching/constants.py +1 -0
- synth_ai/{zyk/lms → lm}/caching/ephemeral.py +4 -8
- synth_ai/{zyk/lms → lm}/caching/handler.py +15 -15
- synth_ai/{zyk/lms → lm}/caching/initialize.py +2 -4
- synth_ai/{zyk/lms → lm}/caching/persistent.py +4 -10
- synth_ai/{zyk/lms → lm}/config.py +2 -1
- synth_ai/{zyk/lms → lm}/constants.py +2 -2
- synth_ai/{zyk/lms → lm}/core/all.py +10 -10
- synth_ai/{zyk/lms → lm}/core/main.py +57 -33
- synth_ai/{zyk/lms → lm}/core/vendor_clients.py +12 -10
- synth_ai/lm/cost/monitor.py +1 -0
- synth_ai/lm/cost/statefulness.py +1 -0
- synth_ai/lm/provider_support/__init__.py +8 -0
- synth_ai/lm/provider_support/anthropic.py +945 -0
- synth_ai/lm/provider_support/openai.py +1115 -0
- synth_ai/lm/provider_support/suppress_logging.py +31 -0
- synth_ai/{zyk/lms → lm}/structured_outputs/handler.py +58 -80
- synth_ai/{zyk/lms → lm}/structured_outputs/inject.py +6 -20
- synth_ai/{zyk/lms → lm}/structured_outputs/rehabilitate.py +6 -12
- synth_ai/{zyk/lms → lm}/vendors/core/anthropic_api.py +21 -30
- synth_ai/{zyk/lms → lm}/vendors/core/gemini_api.py +35 -32
- synth_ai/{zyk/lms → lm}/vendors/core/mistral_api.py +19 -28
- synth_ai/{zyk/lms → lm}/vendors/core/openai_api.py +26 -36
- synth_ai/{zyk/lms → lm}/vendors/openai_standard.py +29 -33
- synth_ai/{zyk/lms → lm}/vendors/retries.py +1 -1
- synth_ai/lm/vendors/supported/__init__.py +0 -0
- synth_ai/{zyk/lms → lm}/vendors/supported/custom_endpoint.py +131 -118
- synth_ai/{zyk/lms → lm}/vendors/supported/deepseek.py +4 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/grok.py +6 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/groq.py +1 -1
- synth_ai/{zyk/lms → lm}/vendors/supported/ollama.py +2 -2
- synth_ai/{zyk/lms → lm}/vendors/supported/openrouter.py +18 -16
- synth_ai/{zyk/lms → lm}/vendors/supported/together.py +1 -1
- synth_ai/tracing/__init__.py +0 -0
- synth_ai/tracing/abstractions.py +224 -0
- synth_ai/tracing/base_client.py +91 -0
- synth_ai/tracing/client_manager.py +131 -0
- synth_ai/tracing/config.py +140 -0
- synth_ai/tracing/context.py +146 -0
- synth_ai/tracing/decorators.py +679 -0
- synth_ai/tracing/events/__init__.py +0 -0
- synth_ai/tracing/events/manage.py +147 -0
- synth_ai/tracing/events/scope.py +86 -0
- synth_ai/tracing/events/store.py +227 -0
- synth_ai/tracing/immediate_client.py +152 -0
- synth_ai/tracing/local.py +18 -0
- synth_ai/tracing/log_client_base.py +74 -0
- synth_ai/tracing/retry_queue.py +187 -0
- synth_ai/tracing/trackers.py +515 -0
- synth_ai/tracing/upload.py +504 -0
- synth_ai/tracing/utils.py +9 -0
- synth_ai/zyk/__init__.py +28 -2
- synth_ai-0.2.1.dev0.dist-info/METADATA +349 -0
- synth_ai-0.2.1.dev0.dist-info/RECORD +261 -0
- {synth_ai-0.2.0.dist-info → synth_ai-0.2.1.dev0.dist-info}/WHEEL +1 -1
- synth_ai/zyk/lms/caching/constants.py +0 -1
- synth_ai/zyk/lms/cost/monitor.py +0 -1
- synth_ai/zyk/lms/cost/statefulness.py +0 -1
- synth_ai-0.2.0.dist-info/METADATA +0 -36
- synth_ai-0.2.0.dist-info/RECORD +0 -50
- /synth_ai/{zyk/lms/__init__.py → environments/reproducibility/helpers.py} +0 -0
- /synth_ai/{zyk/lms/caching → lm}/__init__.py +0 -0
- /synth_ai/{zyk/lms/core → lm/caching}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/caching/dbs.py +0 -0
- /synth_ai/{zyk/lms/cost → lm/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/core/exceptions.py +0 -0
- /synth_ai/{zyk/lms/structured_outputs → lm/cost}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors → lm/structured_outputs}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/core → lm/vendors}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/local → lm/vendors/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors/supported → lm/vendors/local}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/local/ollama.py +0 -0
- {synth_ai-0.2.0.dist-info → synth_ai-0.2.1.dev0.dist-info/licenses}/LICENSE +0 -0
- {synth_ai-0.2.0.dist-info → synth_ai-0.2.1.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,535 @@
|
|
1
|
+
# react_agent.py ── minimal ReAct agent for the new tools (LLM wiring identical to Sokoban pattern)
|
2
|
+
# Combined with tests_eval_enron.py
|
3
|
+
import asyncio
|
4
|
+
import json
|
5
|
+
import uuid
|
6
|
+
from collections import deque
|
7
|
+
from pathlib import Path
|
8
|
+
from typing import Any, Deque, Dict, List
|
9
|
+
import textwrap
|
10
|
+
import os
|
11
|
+
|
12
|
+
import pytest
|
13
|
+
from pydantic import BaseModel
|
14
|
+
|
15
|
+
from synth_ai.zyk import LM
|
16
|
+
from synth_sdk.tracing.abstractions import Dataset, RewardSignal, TrainingQuestion
|
17
|
+
from synth_ai.environments.environment.tools import EnvToolCall
|
18
|
+
|
19
|
+
from synth_ai.environments.examples.enron.engine import ACTION_ANSWER
|
20
|
+
from synth_ai.environments.examples.enron.environment import (
|
21
|
+
AnswerQuestion,
|
22
|
+
AnswerQuestionArgs,
|
23
|
+
EnronEnvironment,
|
24
|
+
ReadEmail,
|
25
|
+
ReadEmailArgs,
|
26
|
+
SearchEmails,
|
27
|
+
SearchEmailsArgs,
|
28
|
+
Terminate,
|
29
|
+
)
|
30
|
+
from synth_ai.environments.examples.enron.taskset import create_enron_taskset
|
31
|
+
from synth_ai.environments.examples.enron.art_helpers import local_email_db, email_search_tools
|
32
|
+
|
33
|
+
# Ensure the SQLite email database exists in the dataset directory (laid out
# to match the HF dataset cache folder) and point both Enron helper modules
# at it.
DATASET_DIR = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "..", "dataset")
)
os.makedirs(DATASET_DIR, exist_ok=True)

DB_PATH = os.path.join(DATASET_DIR, "enron_emails.db")
# Override the helpers' module-level DEFAULT_DB_PATH with DB_PATH
# (assigned to local_email_db first, then email_search_tools).
local_email_db.DEFAULT_DB_PATH = email_search_tools.DEFAULT_DB_PATH = DB_PATH

# Build the database only when the file is not already there.
# NOTE(review): generate_database() is given no explicit path, so this relies
# on it reading local_email_db.DEFAULT_DB_PATH at call time rather than having
# bound it as a default argument at definition time — confirm.
if not os.path.exists(DB_PATH):
    local_email_db.generate_database(overwrite=False)
|
42
|
+
|
43
|
+
|
44
|
+
# ---- schemas for function-calling LLM
# Argument schema for the "terminate" tool: ReActEnronAgent passes
# TerminateArgs.model_json_schema() to the LLM as that tool's parameters.
# NOTE: avoid adding a class docstring here — pydantic v2 copies a model's
# docstring into model_json_schema() as its description, which would change
# the tool schema the LLM sees.
class TerminateArgs(BaseModel):
    # Free-text reason the model gives for stopping the episode.
    reason: str
|
47
|
+
|
48
|
+
|
49
|
+
# ---- ReAct Agent
class ReActEnronAgent:
    """Minimal ReAct agent for the Enron email-search environment.

    Every :meth:`act` call renders the latest observation into a system
    prompt, sends it to the LLM together with the function-calling schemas in
    ``self.tools``, and converts the *first* returned tool call into an
    environment action. Later tool calls in the same response are validated
    and recorded in ``self.history`` but are not executed.
    """

    # System-prompt template. It is dedented *before* interpolation: dedenting
    # after interpolation is a no-op as soon as a multi-line value (tool
    # history, search hits) contributes lines starting at column 0.
    _SYSTEM_PROMPT_TEMPLATE = textwrap.dedent(
        '''
        You are an email-search agent.

        • When calling **search_emails** pass *individual* words in `keywords`.
        Example → `search_emails(keywords=["Jeff","Skilling","sell","Enron","stock"])`
        (never a whole sentence or use quotes).

        • If a search returns 0 results, try different terms or read a promising
        message-id.

        You may take up to {max_steps} turns; finish with
        `answer_question` once confident.

        If an email already contains the answer, IMMEDIATELY finish with
        `answer_question(answer="…")`.
        • When calling `answer_question`, return only the exact answer sentence verbatim as it appears in the source; do not add any extra explanation or text.

        Recent tool history:
        {history_block}

        Context
        ────────
        • Inbox you can query: **{user_email}**
        • Today's date: **{today_string}**

        Original user question:
        """{user_query}"""

        Latest search hits:
        {hits_block}

        Latest email excerpt:
        {email_excerpt}
        '''
    )

    def __init__(self, llm: LM, max_steps: int = 8, tool_window: int = 12):
        """Create the agent.

        Args:
            llm: Language-model client; :meth:`act` awaits its ``respond_async``.
            max_steps: Turn budget quoted to the LLM in the system prompt
                (``run_episode`` also uses it as its loop bound).
            tool_window: Maximum number of entries kept in ``tool_history``.
        """
        self.llm, self.max_steps = llm, max_steps
        # Observations and tool-call records; serialized into every user message.
        self.history: Deque[Dict[str, Any]] = deque(maxlen=20)
        self.system_name, self.system_instance_id = "enron-react", str(uuid.uuid4())
        self.tool_window = tool_window
        # Executed-action log rendered into the prompt by _format_tool_history;
        # it is filled in by the episode driver (e.g. run_episode), not by act().
        self.tool_history: Deque[Dict[str, Any]] = deque(maxlen=self.tool_window)
        # Search hits shown in the most recent prompt.
        self.last_hits: List[Dict[str, Any]] = []

        self.tools = [
            {
                "type": "function",
                "function": {
                    "name": "search_emails",
                    "description": (
                        "Full-text search over the inbox. "
                        "`keywords` **must** be a list of individual words "
                        '— e.g. ["Jeff","Skilling","Enron","stock"]. '
                        "Do NOT wrap whole sentences or use quotes."
                    ),
                    "parameters": SearchEmailsArgs.model_json_schema(),
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "read_email",
                    "description": "Read a single email by message-id",
                    "parameters": ReadEmailArgs.model_json_schema(),
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "answer_question",
                    "description": "Final answer to the user's question",
                    "parameters": AnswerQuestionArgs.model_json_schema(),
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "terminate",
                    "description": "Stop the episode",
                    "parameters": TerminateArgs.model_json_schema(),
                },
            },
        ]

    async def act(self, observation: Dict[str, Any]) -> EnvToolCall:
        """Pick the next environment action for ``observation``.

        Args:
            observation: Latest environment observation. Keys read here:
                ``inbox_address``, ``question``, ``query_date``,
                ``search_results`` and ``email``. ``gold_answer`` is dropped
                before anything reaches the LLM.

        Returns:
            The action built from the LLM's first tool call, or
            ``AnswerQuestion("")`` when there is no usable first call.
        """
        # Never leak evaluation labels to the LLM.
        obs_filtered = {k: v for k, v in observation.items() if k != "gold_answer"}
        self.history.append({"obs": obs_filtered})

        system_message = self._build_system_message(observation)
        # default=str: the history is only rendered as prompt text, so a value
        # json cannot encode natively is stringified instead of aborting the
        # turn with a TypeError.
        user_message = json.dumps({"history": list(self.history)}, default=str)

        resp = await self.llm.respond_async(
            system_message=system_message,
            user_message=user_message,
            tools=self.tools,
        )
        if not resp.tool_calls:
            self.history.append({"tool": "no_op", "args": "LLM returned no tool calls."})
            return AnswerQuestion("")

        primary_action_to_execute = None
        for i, tc in enumerate(resp.tool_calls):
            action = self._process_tool_call(i, tc)
            if i == 0:
                # Only the first call is executed; a bad first call degrades
                # to an empty answer instead of raising.
                primary_action_to_execute = (
                    action if action is not None else AnswerQuestion("")
                )

        if primary_action_to_execute is not None:
            return primary_action_to_execute

        # Defensive: only reachable if ``tool_calls`` was truthy but yielded
        # nothing when iterated.
        print(
            "Fallback: No valid primary action determined from tool calls after processing all."
        )
        self.history.append(
            {
                "tool": "no_op",
                "args": "No valid primary action derived from LLM tools after loop.",
            }
        )
        return AnswerQuestion("")

    def _build_system_message(self, observation: Dict[str, Any]) -> str:
        """Render the system prompt (tool rules plus context) for ``observation``.

        Side effect: refreshes ``self.last_hits`` from the observation, so a
        0-hit search clears the previously shown hits.
        """
        user_email = observation.get("inbox_address", "<unknown>")
        user_query = observation.get("question", "")
        today_string = observation.get("query_date", "<unknown date>")

        # Expose at most 10 of the current search hits to the LLM.
        self.last_hits = observation.get("search_results") or []
        hits_block = (
            "\n".join(
                f"{i + 1}. {h.get('message_id', 'N/A')} : "
                f"{(h.get('snippet', '') or '')[:120].replace(chr(10), ' ')}…"
                for i, h in enumerate(self.last_hits[:10])
            )
            if self.last_hits
            else "No search results yet."
        )

        # Single-line excerpt (first 10,000 chars) of the last-opened email.
        em = observation.get("email")
        if isinstance(em, dict) and em.get("body"):
            email_excerpt = em["body"][:10000].replace("\n", " ") + "…"
        else:
            email_excerpt = "No email opened yet."

        return self._SYSTEM_PROMPT_TEMPLATE.format(
            max_steps=self.max_steps,
            history_block=self._format_tool_history(),
            user_email=user_email,
            today_string=today_string,
            user_query=user_query,
            hits_block=hits_block,
            email_excerpt=email_excerpt,
        ).strip()

    def _process_tool_call(self, i: int, tc: Any):
        """Validate tool call number ``i`` and convert it into an env action.

        Appends exactly one record for the call to ``self.history`` and
        returns the corresponding action, or ``None`` when the call is
        malformed, names an unknown tool, or its arguments fail validation.
        """
        if isinstance(tc, dict):
            # Dict-shaped calls (e.g. some OSS model backends).
            fc = tc.get("function") or {}
            name, args_raw = fc.get("name"), fc.get("arguments")
        elif hasattr(tc, "function"):
            # Object-shaped calls (OpenAI / Anthropic style ``.function``).
            name = tc.function.name
            args_raw = tc.function.arguments
        else:
            self.history.append(
                {
                    "tool": "unknown_format",
                    "raw_tool_call": str(tc),
                    "error": "Unknown tool call format",
                }
            )
            return None

        if not name or args_raw is None:
            print(
                f"Tool call {i}: Missing name or arguments. Name: '{name}', Args: '{args_raw}'. Skipping."
            )
            self.history.append(
                {
                    "tool": name or "unknown_name",
                    "args_str": args_raw,
                    "error": "Missing name or arguments",
                }
            )
            return None

        # Some backends return already-decoded dicts rather than JSON text;
        # json.loads would raise TypeError (not JSONDecodeError) on those.
        try:
            args = args_raw if isinstance(args_raw, dict) else json.loads(args_raw)
        except (json.JSONDecodeError, TypeError) as e:
            print(f"Tool call {i} ({name}): JSON decode error: {e}. Args: '{args_raw}'")
            self.history.append(
                {
                    "tool": name,
                    "args_str": args_raw,
                    "error": type(e).__name__,
                    "detail": str(e),
                }
            )
            return None

        if not isinstance(args, dict):
            # e.g. the model sent "null" or a JSON list; every tool takes an object.
            print(f"Tool call {i} ({name}): arguments are not a JSON object: '{args_raw}'")
            self.history.append(
                {
                    "tool": name,
                    "args_str": args_raw,
                    "error": "Arguments are not a JSON object",
                }
            )
            return None

        entry: Dict[str, Any] = {"tool": name, "args": args}
        action = None

        if name == "search_emails":
            try:
                parsed = SearchEmailsArgs(**args)
                # Always ask for at least 10 hits.
                if parsed.max_results is None or parsed.max_results < 10:
                    parsed.max_results = 10
                entry["args"] = parsed.model_dump()
                action = SearchEmails(**parsed.model_dump())
            except Exception as e:
                print(f"Tool call {i} ({name}): Args parsing error: {e}")
                entry["error"] = f"SearchEmailsArgs parsing error: {str(e)}"

        elif name == "read_email":
            msg_id = self._normalize_message_id(args.get("message_id"))
            if msg_id is None:
                print(f"Tool call {i} ({name}): message_id is missing.")
                entry["error"] = "message_id missing"
            else:
                entry["args"] = {"message_id": msg_id}
                action = ReadEmail(message_id=msg_id)

        elif name == "answer_question":
            try:
                parsed = AnswerQuestionArgs(**args)
                entry["args"] = parsed.model_dump()
                action = AnswerQuestion(parsed.answer)
            except Exception as e:
                print(f"Tool call {i} ({name}): Args parsing error: {e}")
                entry["error"] = f"AnswerQuestionArgs parsing error: {str(e)}"

        elif name == "terminate":
            try:
                parsed = TerminateArgs(**args)
                entry["args"] = parsed.model_dump()
            except Exception as e:
                print(
                    f"Tool call {i} ({name}): Args parsing error (TerminateArgs): {e}. Proceeding with Terminate()."
                )
                entry["args"] = args  # log raw args when TerminateArgs parsing fails
                entry["error"] = (
                    f"TerminateArgs parsing error: {str(e)}, but Terminate() called"
                )
            # Terminate even when the reason could not be parsed.
            action = Terminate()

        else:
            print(f"Tool call {i}: Unknown tool name '{name}'")
            entry["error"] = "Unknown tool name"

        self.history.append(entry)
        return action

    @staticmethod
    def _normalize_message_id(raw: Any):
        """Return ``raw`` as a single ``<id>``-wrapped string, or ``None``.

        ``None`` and blank values count as missing. Any partial or complete
        angle brackets are stripped and re-applied exactly once, so
        ``"abc"``, ``"<abc"``, ``"abc>"`` and ``"<abc>"`` all become
        ``"<abc>"``.
        """
        if raw is None:
            return None
        msg_id = str(raw).strip()
        if not msg_id:
            return None
        return f"<{msg_id.strip('<>')}>"

    def _format_tool_history(self) -> str:
        """Render ``self.tool_history`` as numbered lines for the prompt.

        Each entry is read for ``turn``, ``name``, ``args``, ``result`` and
        ``result_detail``. Search-style calls show only their keyword list.
        Returns ``"No calls yet."`` when the history is empty.
        """
        if not self.tool_history:
            return "No calls yet."

        lines = []
        for h in self.tool_history:
            action_name = h.get("name", "")
            action_args = h.get("args", "")
            if isinstance(action_args, dict) and (
                action_name == "search_emails" or "keywords" in action_args
            ):
                args_str = str(action_args.get("keywords", []))
            else:
                args_str = str(action_args)
            detail = h.get("result_detail", "")
            lines.append(
                f"{h.get('turn', 0)}. {action_name}({args_str}) → {h.get('result', '')}; {detail}"
            )
        return "\n".join(lines)
|
351
|
+
|
352
|
+
|
353
|
+
# ------------------------ helpers ----------------------------------------- #
|
354
|
+
def _summarize_search_hits(hits) -> str:
    """Render search hits as numbered ``message_id : snippet...`` lines.

    Each hit is read as a dict with ``message_id`` and ``snippet`` keys. A
    WARNING line is printed for a hit without a message_id, with an empty
    snippet, or whose first 80 display characters are blank. A snippet that
    is present but ``None`` is treated as empty instead of raising.

    Returns:
        The newline-joined detail lines, or a fixed "0 hits" message when
        ``hits`` is empty.
    """
    if not hits:
        return " (No specific details for 0 hits)"

    details = []
    for idx, result_item in enumerate(hits):
        message_id_val = result_item.get("message_id", "N/A")
        # `or ""` guards against an explicit None snippet (would crash .replace).
        snippet_from_db = result_item.get("snippet") or ""

        if message_id_val == "N/A":
            print(
                f"WARNING: SearchEmails - Result {idx + 1} - Message ID is 'N/A'. Search result item: {result_item}"
            )
        if not snippet_from_db:
            print(
                f"WARNING: SearchEmails - Result {idx + 1} - Snippet from DB is empty for Message ID '{message_id_val}'. Search result item: {result_item}"
            )

        snippet_for_display = snippet_from_db.replace("\n", " ")[:80]
        # The 80-char window can be all whitespace even when the full snippet is not.
        if not snippet_for_display.strip() and snippet_from_db.strip():
            print(
                f"WARNING: SearchEmails - Result {idx + 1} - Processed snippet for display ('{snippet_for_display}') is effectively empty for Message ID '{message_id_val}', original DB snippet ('{snippet_from_db[:40]}...'). Item: {result_item}"
            )

        details.append(f" {idx + 1}. {message_id_val} : {snippet_for_display}...")
    return "\n".join(details)


def _summarize_email(email_data) -> str:
    """Return the first 120 body characters plus "...", or a not-found message.

    ``email_data`` is expected to be a dict (or None). A ``None`` body is
    treated as empty.
    """
    body_preview = ""
    if email_data and isinstance(email_data, dict):
        body_preview = (email_data.get("body") or "")[:120]
    return body_preview + "..." if body_preview else "Email not found or empty."


async def run_episode(env: EnronEnvironment, agent: ReActEnronAgent) -> bool:
    """Run one Enron QA episode until the agent answers, the env terminates, or steps run out.

    After every environment step, a summary dict is appended to
    ``agent.tool_history``. Its keys are ``turn``, ``name`` and ``args``, plus
    ``result`` / ``result_detail`` for known tools, so the agent can see its
    own progress on the next turn.

    An ``AnswerQuestion`` with a non-empty answer is stepped, logged and ends
    the loop. An empty ``AnswerQuestion`` (the agent's fallback) is stepped
    like any other call.

    Returns:
        True when the final observation has ``terminated`` set and a positive
        ``reward_last``; False otherwise.
    """
    obs = await env.initialize()
    for _ in range(agent.max_steps):
        call = await agent.act(obs)
        obs = await env.step(call)

        tool_entry = {
            "turn": len(agent.tool_history) + 1,
            "name": call.action[0],
            "args": call.action[1],
        }

        if isinstance(call, AnswerQuestion) and call.action[1]:  # answered
            # Minimal logging for an answer: no result_detail.
            tool_entry["result"] = "Question answered"
            agent.tool_history.append(tool_entry)
            break

        if isinstance(call, SearchEmails):
            # `or []` also covers an explicit None in the observation.
            hits = obs.get("search_results") or []
            tool_entry["result"] = f"{len(hits)} hits"
            tool_entry["result_detail"] = _summarize_search_hits(hits)
        elif isinstance(call, ReadEmail):
            tool_entry["result"] = "email_read"
            tool_entry["result_detail"] = _summarize_email(obs.get("email"))
        elif isinstance(call, Terminate):
            tool_entry["result"] = "Session terminated"

        agent.tool_history.append(tool_entry)

        if obs["terminated"]:
            break
    return obs["terminated"] and obs["reward_last"] > 0
|
429
|
+
|
430
|
+
|
431
|
+
# ------------------------ unit-style sanity -------------------------------- #
|
432
|
+
@pytest.mark.asyncio
async def test_react_agent_enron(tmp_path: Path):
    """Smoke test: run the ReAct agent on the first Enron QA instance.

    The only assertion is that ``run_episode`` returns a bool. The engine's
    final reward bookkeeping is printed, and a ``Dataset`` record of the run
    is assembled but not uploaded. ``tmp_path`` is accepted but not used.
    """
    task_set = await create_enron_taskset()
    first_instance = task_set.instances[0]  # first QA pair
    environment = EnronEnvironment(first_instance)
    react_agent = ReActEnronAgent(
        LM(model_name="gpt-4.1", formatting_model_name="gpt-4.1", temperature=0.0)
    )
    was_solved = await run_episode(environment, react_agent)

    # Report the final rewards from the engine checkpoint.
    final_state = await environment.checkpoint()
    print(f"Total Reward: {final_state.total_reward}")
    print(f"Partial Rewards: {final_state.partial_rewards}")

    question = TrainingQuestion(id="enron_ep", intent="answer", criteria="correct")
    signal = RewardSignal(
        question_id="enron_ep",
        run_id=react_agent.system_instance_id,
        system_instance_id=react_agent.system_instance_id,
        reward=1 if was_solved else 0,
        error_message="",
        metadata={"history": list(react_agent.history)},
    )
    dataset = Dataset(questions=[question], reward_signals=[signal])
    # upload(dataset)  # optional
    assert isinstance(was_solved, bool)
|
460
|
+
|
461
|
+
|
462
|
+
# ------------------------ quick eval over n test instances ----------------- #
|
463
|
+
async def eval_react_enron(n: int = 2, model_name: str = "gpt-4.1-nano") -> None:
    """Run the ReAct agent concurrently on the first ``n`` test-split Enron instances.

    Prints the overall solve rate, then a summary table with one row per
    instance, in instance order. Each row shows the gold answer, the agent's
    last submitted answer, the engine's total reward and its partial rewards.

    Args:
        n: Maximum number of test-split instances to evaluate.
        model_name: Model used as both the acting and the formatting LM.
            Defaults to the previously hard-coded "gpt-4.1-nano".
    """
    ts = await create_enron_taskset()
    test_insts = [i for i in ts.instances if i.metadata.split == "test"][:n]
    if not test_insts:
        # Nothing to run; also avoids a ZeroDivisionError in the solve rate.
        print("Overall Solved: 0/0 (no test instances)")
        return

    async def _run(instance):
        """Build a fresh env/agent for ``instance``; return ``(solved, summary_row)``."""
        env = EnronEnvironment(instance)
        llm = LM(
            model_name=model_name,
            formatting_model_name=model_name,
            temperature=0.0,
        )
        agent = ReActEnronAgent(llm)
        solved = await run_episode(env, agent)
        snapshot = await env.checkpoint()

        # The most recent answer_question entry carries the answer in "args".
        agent_answer = next(
            (
                tool_call.get(
                    "args",
                    "Agent called answer_question, but answer was not logged correctly.",
                )
                for tool_call in reversed(agent.tool_history)
                if tool_call.get("name") == ACTION_ANSWER
            ),
            "Agent did not attempt to answer.",
        )
        row = {
            "gold": instance.intent.gold_state_diff["answer"],
            "agent": agent_answer,
            "score": snapshot.total_reward,
            "partials": snapshot.partial_rewards,
        }
        return solved, row

    # gather() returns results in input order. Appending rows inside _run
    # would record them in completion order and misalign the table.
    results = await asyncio.gather(*(_run(i) for i in test_insts))
    solved_count = sum(solved for solved, _ in results)
    rows: List[Dict[str, Any]] = [row for _, row in results]

    print(
        f"Overall Solved: {solved_count}/{len(test_insts)} ({solved_count / len(test_insts):.0%})"
    )
    # Print summary table
    print("\nSummary Table:")
    print(f"{'Gold Answer':<40} | {'Agent Answer':<40} | {'Score':<5} | Partial Rewards")
    print("-" * 100)
    for r in rows:
        gold = str(r["gold"])[:40]
        agent_ans = str(r["agent"])[:40]
        score = r["score"]
        partials = ",".join(str(x) for x in r["partials"])
        print(f"{gold:<40} | {agent_ans:<40} | {score:<5.1f} | {partials}")


if __name__ == "__main__":
    experiment_params = {"model": "gpt-4.1-mini", "n_questions": 5}
    # Previously "model" was ignored and gpt-4.1-nano always ran.
    asyncio.run(
        eval_react_enron(
            n=experiment_params["n_questions"],
            model_name=experiment_params["model"],
        )
    )

# Recorded results (15 questions each):
# gpt-4.1 Overall Solved: 6/15 (40%)
# gpt-4.1-mini Overall Solved: 8/15 (53%)
# gpt-4.1-nano Overall Solved: 3/15 (20%)
|
@@ -0,0 +1,156 @@
|
|
1
|
+
import sqlite3
|
2
|
+
import logging
|
3
|
+
import textwrap
|
4
|
+
from typing import List, Optional
|
5
|
+
from dataclasses import dataclass
|
6
|
+
|
7
|
+
from synth_ai.environments.environment.db.sqlite import SQLiteManager
|
8
|
+
from synth_ai.environments.examples.enron.art_helpers.types_enron import Email
|
9
|
+
|
10
|
+
# Module logger, verbose (DEBUG) by default. A stream handler is attached only
# when none exists yet, so repeated imports (e.g. under pytest -x) do not
# produce duplicate output lines.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
if not logger.handlers:
    h = logging.StreamHandler()
    h.setFormatter(logging.Formatter(fmt="%(asctime)s %(levelname)s %(message)s"))
    logger.addHandler(h)
|
17
|
+
|
18
|
+
|
19
|
+
@dataclass
class SearchResult:
    """A single hit returned by an email search."""

    # Identifier of the matching email (the emails.message_id column).
    message_id: str
    # Text excerpt around the matched keywords.
    snippet: str
    # Relevance score for the hit.
    score: float
|
24
|
+
|
25
|
+
|
26
|
+
def search_emails(
    sqlite_manager: SQLiteManager,
    inbox: str,
    keywords: List[str],
    from_addr: Optional[str] = None,
    to_addr: Optional[str] = None,
    sent_after: Optional[str] = None,
    sent_before: Optional[str] = None,
    max_results: int = 10,
) -> List[SearchResult]:
    """
    Searches the email database based on keywords, inbox, sender, recipient, and date range.

    Args:
        sqlite_manager: The SQLiteManager instance for database operations.
        inbox: The email address of the user performing the search.
               Results include emails sent from or to (inc. cc/bcc) this address.
        keywords: A list of keywords that must all appear in the subject or body.
        from_addr: Optional email address to filter emails sent *from*.
        to_addr: Optional email address to filter emails sent *to* (inc. cc/bcc).
        sent_after: Optional date string 'YYYY-MM-DD'. Filters for emails sent on or after this date.
        sent_before: Optional date string 'YYYY-MM-DD'. Filters for emails sent before this date.
        max_results: The maximum number of results to return. Values above 10
            are clamped to 10 (with a warning); values below 1 return [].

    Returns:
        A list of SearchResult objects, each containing 'message_id' and 'snippet'
        (score is always 0.0). Returns an empty list if no results are found
        or a database error occurs.

    Raises:
        ValueError: If ``keywords`` is empty.
    """
    if not keywords:
        raise ValueError("No keywords provided for search.")
    if max_results > 10:
        logger.warning("max_results=%d exceeds the limit of 10; clamping to 10.", max_results)
        max_results = 10
    if max_results < 1:
        # SQLite treats a negative LIMIT as "no limit"; never return unbounded results.
        return []

    # FTS5 query: each keyword becomes a quoted phrase (phrases are implicitly
    # ANDed), and an embedded double quote is escaped by doubling it. The
    # query string is bound as a parameter, so no SQL-level quoting is needed.
    fts_match_query = " ".join('"{}"'.format(k.replace('"', '""')) for k in keywords)

    conditions: List[str] = ["emails_fts MATCH ?"]
    params: List[object] = [fts_match_query]

    # Inbox scope: the user either sent the email or is one of its recipients.
    conditions.append(
        "(e.from_address = ? OR EXISTS ("
        "SELECT 1 FROM recipients r_in WHERE r_in.email_id = e.id "
        "AND r_in.recipient_address = ?))"
    )
    params.extend([inbox, inbox])

    if from_addr:
        conditions.append("e.from_address = ?")
        params.append(from_addr)
    if to_addr:
        conditions.append(
            "EXISTS (SELECT 1 FROM recipients r_to WHERE r_to.email_id = e.id "
            "AND r_to.recipient_address = ?)"
        )
        params.append(to_addr)
    # Date filters compare strings; this assumes emails.date is stored in an
    # ISO-like 'YYYY-MM-DD...' format. NOTE(review): confirm against the schema.
    if sent_after:
        conditions.append("e.date >= ?")
        params.append(sent_after)
    if sent_before:
        conditions.append("e.date < ?")
        params.append(sent_before)

    sql_template = textwrap.dedent("""
        SELECT DISTINCT
            e.message_id,
            snippet(emails_fts, -1, '⟪', '⟫', ' … ', 15) AS snip
        FROM emails e
        JOIN emails_fts ON e.id = emails_fts.rowid
        WHERE {conditions}
        LIMIT ?
    """).strip()
    sql_query = sql_template.format(conditions="\n  AND ".join(conditions))
    params.append(max_results)
    bound_params = tuple(params)

    try:
        with sqlite_manager.connection() as db_conn:
            rows = db_conn.execute(sql_query, bound_params).fetchall()
        return [SearchResult(message_id=row[0], snippet=row[1], score=0.0) for row in rows]
    except sqlite3.Error as e:
        logger.error(f"Database error during search: {e}\nSQL: {sql_query}\nParams: {bound_params}")
        return []
|
87
|
+
|
88
|
+
|
89
|
+
def read_email(sqlite_manager: SQLiteManager, message_id: str) -> Optional[Email]:
    """
    Retrieves a single email by its message_id from the database.

    Args:
        sqlite_manager: The SQLiteManager instance for database operations.
        message_id: The unique identifier of the email to retrieve.

    Returns:
        An Email object containing the details of the found email, with
        recipients split into to/cc/bcc lists (other or NULL recipient types
        are ignored), or None if the email is not found or a database error
        occurs.
    """

    email_sql = """
        SELECT id, message_id, date, subject, from_address, body, file_name
        FROM emails
        WHERE message_id = ?;
    """

    recipients_sql = """
        SELECT recipient_address, recipient_type
        FROM recipients
        WHERE email_id = ?;
    """

    try:
        with sqlite_manager.connection() as db_conn:
            cursor = db_conn.cursor()
            cursor.execute(email_sql, (message_id,))
            email_row = cursor.fetchone()

            if not email_row:
                # Use the module logger (previously the root logger was used).
                logger.warning("Email with message_id '%s' not found.", message_id)
                return None

            email_id, msg_id, date, subject, from_addr, body, file_name = email_row
            # Recipients are keyed by the email's primary key, not its message_id.
            cursor.execute(recipients_sql, (email_id,))
            recipient_rows = cursor.fetchall()
    except sqlite3.Error as e:
        logger.error(f"Database error reading email {message_id}: {e}")
        return None

    addresses_by_type = {"to": [], "cc": [], "bcc": []}
    for addr, type_val in recipient_rows:
        # `or ""` tolerates a NULL recipient_type instead of raising AttributeError.
        bucket = addresses_by_type.get((type_val or "").lower())
        if bucket is not None:
            bucket.append(addr)

    return Email(
        message_id=msg_id,
        date=date,
        subject=subject,
        from_address=from_addr,
        to_addresses=addresses_by_type["to"],
        cc_addresses=addresses_by_type["cc"],
        bcc_addresses=addresses_by_type["bcc"],
        body=body,
        file_name=file_name,
    )
|