synth-ai 0.2.13.dev1__py3-none-any.whl → 0.2.13.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai has been flagged as potentially problematic; see the registry's advisory page for details.
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +12 -1
- examples/swe/task_app/grpo_swe_mini.py +55 -26
- examples/swe/task_app/hosted/rollout.py +40 -0
- examples/swe/task_app/hosted/test_service.py +5 -6
- examples/task_apps/TESTING.md +275 -0
- examples/task_apps/__init__.py +0 -0
- examples/task_apps/crafter/__init__.py +0 -0
- examples/task_apps/crafter/task_app/__init__.py +2 -0
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter.py +18 -13
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter_task_app.py +1 -1
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/policy.py +60 -4
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/policy_routes.py +25 -3
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/rollout.py +10 -0
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_service.py +5 -6
- examples/task_apps/dev/pokemon_emerald/__init__.py +2 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/README.md +811 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/__init__.py +120 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/action.py +160 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/memory.py +155 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/perception.py +69 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/planning.py +96 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/simple.py +1502 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/system_prompt.py +4 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/grab_map.py +68 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/manual.py +216 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/__init__.py +35 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emerald_utils.py +631 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emulator.py +1544 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/enums.py +1428 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/memory_reader.py +4848 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/types.py +41 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/utils.py +298 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pyproject.toml +95 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/run.py +204 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/app.py +2152 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/client.py +429 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/frame_server.py +155 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/README.md +78 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/run_tests.py +122 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_direct.py +76 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_prompts.py +413 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_battle_state_formatting.py +204 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection.py +133 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection_comprehensive.py +229 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_direct_agent_emulator.py +300 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_fps_adjustment_pytest.py +205 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_direct.py +200 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_transition.py +284 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_map_ground_truth_comparison.py +468 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_memory_map.py +575 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_server_map_validation.py +311 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_torchic_state.py +259 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/anticheat.py +372 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/checkpoint.py +296 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/error_handler.py +275 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/get_local_ip.py +22 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/helpers.py +44 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/llm_logger.py +514 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_formatter.py +415 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher.py +1763 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher_singleton.py +33 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_trimmer.py +106 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_visualizer.py +334 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/ocr_dialogue.py +1020 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/recording.py +188 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/state_formatter.py +1481 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/vlm.py +862 -0
- examples/task_apps/dev/pokemon_emerald/modal_app.py +114 -0
- examples/task_apps/dev/pokemon_emerald/task_app/README.md +81 -0
- examples/task_apps/dev/pokemon_emerald/task_app/__init__.py +6 -0
- examples/task_apps/dev/pokemon_emerald/task_app/pokemon_emerald.py +685 -0
- examples/task_apps/enron/__init__.py +1 -0
- examples/task_apps/enron/eval_groq_qwen32.toml +16 -0
- examples/task_apps/enron/task_app/README.md +14 -0
- examples/task_apps/enron/task_app/__init__.py +1 -0
- examples/task_apps/enron/task_app/grpo_enron.py +906 -0
- examples/task_apps/enron/task_app/grpo_enron_task_app.py +146 -0
- examples/task_apps/enron/tests/__init__.py +2 -0
- examples/task_apps/enron/tests/conftest.py +115 -0
- examples/task_apps/enron/tests/integration/__init__.py +2 -0
- examples/task_apps/enron/tests/integration/test_enron_eval.py +177 -0
- examples/task_apps/enron/tests/integration/test_enron_rollout.py +135 -0
- examples/task_apps/enron/tests/unit/__init__.py +2 -0
- examples/task_apps/enron/tests/unit/test_enron_environment.py +126 -0
- examples/task_apps/math/__init__.py +0 -0
- examples/{rl/task_app → task_apps/math}/math_single_step.py +19 -10
- examples/task_apps/pokemon_battle/__init__.py +2 -0
- examples/task_apps/pokemon_battle/modal_app.py +104 -0
- examples/task_apps/pokemon_battle/task_app/README.md +68 -0
- examples/task_apps/pokemon_battle/task_app/__init__.py +6 -0
- examples/task_apps/pokemon_battle/task_app/pokemon_showdown.py +932 -0
- examples/task_apps/pokemon_red/README.md +357 -0
- examples/task_apps/pokemon_red/__init__.py +3 -0
- examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +225 -0
- examples/task_apps/pokemon_red/pallet_town_rl_config.toml +73 -0
- examples/task_apps/pokemon_red/task_app.py +606 -0
- examples/task_apps/pokemon_red/test_pallet_town_rewards.py +191 -0
- examples/task_apps/sokoban/README.md +307 -0
- examples/task_apps/sokoban/__init__.py +3 -0
- examples/task_apps/sokoban/eval_groq_qwen32.toml +16 -0
- examples/task_apps/sokoban/eval_openai_gpt5.toml +16 -0
- examples/task_apps/sokoban/task_app.py +1058 -0
- examples/task_apps/sokoban/tests/__init__.py +2 -0
- examples/task_apps/sokoban/tests/conftest.py +113 -0
- examples/task_apps/sokoban/tests/integration/__init__.py +2 -0
- examples/task_apps/sokoban/tests/integration/test_sokoban_eval.py +57 -0
- examples/task_apps/sokoban/tests/integration/test_sokoban_rollout.py +198 -0
- examples/task_apps/sokoban/tests/unit/__init__.py +2 -0
- examples/task_apps/sokoban/tests/unit/test_sokoban_environment.py +114 -0
- examples/task_apps/verilog/__init__.py +1 -0
- examples/task_apps/verilog/eval_groq_qwen32b.toml +20 -0
- examples/task_apps/verilog/task_app/README.md +12 -0
- examples/task_apps/verilog/task_app/__init__.py +1 -0
- examples/task_apps/verilog/task_app/grpo_verilog.py +931 -0
- examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +145 -0
- examples/task_apps/verilog/tests/__init__.py +2 -0
- examples/task_apps/verilog/tests/conftest.py +115 -0
- examples/task_apps/verilog/tests/integration/__init__.py +2 -0
- examples/task_apps/verilog/tests/integration/test_verilog_eval.py +179 -0
- examples/task_apps/verilog/tests/integration/test_verilog_rollout.py +55 -0
- examples/task_apps/verilog/tests/unit/__init__.py +2 -0
- examples/task_apps/verilog/tests/unit/test_verilog_scoring.py +118 -0
- examples/vlm/crafter_openai_vlm_agent.py +4 -4
- examples/vlm/run_crafter_vlm_benchmark.py +4 -4
- examples/workflows/__init__.py +0 -0
- examples/workflows/math_rl/__init__.py +0 -0
- examples/workflows/math_rl/download_dataset.py +80 -0
- synth_ai/__init__.py +2 -2
- synth_ai/api/train/builders.py +25 -11
- synth_ai/api/train/cli.py +12 -6
- synth_ai/api/train/configs/__init__.py +10 -10
- synth_ai/api/train/configs/rl.py +5 -4
- synth_ai/api/train/configs/sft.py +4 -3
- synth_ai/api/train/env_resolver.py +5 -2
- synth_ai/api/train/supported_algos.py +10 -5
- synth_ai/api/train/utils.py +7 -4
- synth_ai/cli/__init__.py +7 -51
- synth_ai/cli/_storage.py +4 -3
- synth_ai/cli/_validate_task_app.py +11 -0
- synth_ai/cli/balance.py +4 -3
- synth_ai/cli/calc.py +2 -2
- synth_ai/cli/demo.py +14 -7
- synth_ai/cli/legacy_root_backup.py +1 -1
- synth_ai/cli/rl_demo.py +8 -7
- synth_ai/cli/root.py +0 -97
- synth_ai/cli/task_apps.py +1707 -186
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +28 -16
- synth_ai/environments/examples/enron/engine.py +7 -2
- synth_ai/environments/examples/enron/environment.py +68 -0
- synth_ai/environments/examples/red/engine.py +27 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +7 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_progression.py +477 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +32 -0
- synth_ai/environments/examples/red/environment.py +60 -0
- synth_ai/environments/examples/sokoban/taskset.py +116 -0
- synth_ai/environments/examples/verilog/engine.py +30 -4
- synth_ai/evals/client.py +58 -61
- synth_ai/jobs/client.py +16 -4
- synth_ai/judge_schemas.py +16 -16
- synth_ai/py.typed +0 -0
- synth_ai/task/__init__.py +14 -5
- synth_ai/task/contracts.py +124 -38
- synth_ai/task/proxy.py +48 -56
- synth_ai/task/rubrics/__init__.py +53 -0
- synth_ai/task/rubrics/loaders.py +133 -0
- synth_ai/task/rubrics/models.py +57 -0
- synth_ai/task/rubrics/scoring.py +113 -0
- synth_ai/{rubrics/validators.py → task/rubrics/strict.py} +53 -30
- synth_ai/task/server.py +8 -7
- synth_ai/task/validators.py +269 -6
- synth_ai/tracing_v3/decorators.py +7 -3
- synth_ai/tracing_v3/replica_sync.py +4 -4
- synth_ai/tracing_v3/serialization.py +5 -5
- synth_ai/tracing_v3/trace_utils.py +317 -0
- synth_ai/tracing_v3/turso/native_manager.py +3 -3
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/METADATA +4 -1
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/RECORD +214 -101
- examples/agora_ex/README_MoE.md +0 -224
- examples/agora_ex/__init__.py +0 -7
- examples/agora_ex/agora_ex.py +0 -65
- examples/agora_ex/agora_ex_task_app.py +0 -590
- examples/agora_ex/configs/rl_lora_qwen3_moe_2xh200.toml +0 -121
- examples/agora_ex/reward_fn_grpo-human.py +0 -129
- examples/agora_ex/system_prompt_CURRENT.md +0 -63
- examples/agora_ex/task_app/agora_ex_task_app.py +0 -590
- examples/agora_ex/task_app/reward_fn_grpo-human.py +0 -129
- examples/agora_ex/task_app/system_prompt_CURRENT.md +0 -63
- synth_ai/rubrics/__init__.py +0 -22
- synth_ai/task/rubrics.py +0 -219
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/README.md +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/README.md +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/branching.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/environment_routes.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/app.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/environment.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/react_agent.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/shared.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/tools.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/hosted_app.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/openai_client.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/main.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/registry.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/volume.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_agents.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/utils.py +0 -0
- /examples/{rl/task_app → task_apps/math}/README.md +0 -0
- /examples/{rl/task_app → task_apps/math}/math_task_app.py +0 -0
- /examples/{rl → workflows/math_rl}/configs/eval_base_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/eval_rl_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen17.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_ft_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/run_eval.py +0 -0
- /examples/{rl → workflows/math_rl}/run_rl_and_save.py +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,906 @@
|
|
|
1
|
+
"""Task App configuration for the GRPO Enron email QA example."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import contextlib
|
|
6
|
+
import json
|
|
7
|
+
import logging
|
|
8
|
+
import os
|
|
9
|
+
import time
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from datetime import datetime, timezone
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import Any, Iterable, Sequence
|
|
14
|
+
from uuid import UUID, uuid4
|
|
15
|
+
|
|
16
|
+
from datasets import load_dataset
|
|
17
|
+
import httpx
|
|
18
|
+
|
|
19
|
+
from fastapi import HTTPException
|
|
20
|
+
|
|
21
|
+
from synth_ai.environments.examples.enron.environment import EnronEnvironment
|
|
22
|
+
from synth_ai.environments.examples.enron.taskset import (
|
|
23
|
+
EnronTaskInstance,
|
|
24
|
+
EnronTaskInstanceMetadata,
|
|
25
|
+
)
|
|
26
|
+
from synth_ai.environments.tasks.core import (
|
|
27
|
+
Impetus,
|
|
28
|
+
Intent,
|
|
29
|
+
SplitInfo,
|
|
30
|
+
TaskInstanceSet,
|
|
31
|
+
)
|
|
32
|
+
from synth_ai.task.apps import ModalDeploymentConfig, TaskAppEntry, register_task_app
|
|
33
|
+
from synth_ai.task.contracts import (
|
|
34
|
+
RolloutMetrics,
|
|
35
|
+
RolloutRequest,
|
|
36
|
+
RolloutResponse,
|
|
37
|
+
RolloutStep,
|
|
38
|
+
RolloutTrajectory,
|
|
39
|
+
TaskInfo,
|
|
40
|
+
)
|
|
41
|
+
from synth_ai.task.datasets import TaskDatasetRegistry, TaskDatasetSpec
|
|
42
|
+
from synth_ai.task.rubrics import load_rubric
|
|
43
|
+
from synth_ai.task.server import ProxyConfig, RubricBundle, TaskAppConfig
|
|
44
|
+
from synth_ai.task.tracing_utils import (
|
|
45
|
+
build_tracer_factory,
|
|
46
|
+
resolve_sft_output_dir,
|
|
47
|
+
resolve_tracing_db_url,
|
|
48
|
+
tracing_env_enabled,
|
|
49
|
+
)
|
|
50
|
+
from synth_ai.tracing_v3.session_tracer import SessionTracer
|
|
51
|
+
from synth_ai.environments.environment.tools import EnvToolCall
|
|
52
|
+
|
|
53
|
+
logger = logging.getLogger(__name__)
|
|
54
|
+
|
|
55
|
+
_HERE = Path(__file__).resolve()
# Four levels up from examples/task_apps/enron/task_app/grpo_enron.py.
# NOTE(review): breaks if this file is relocated — confirm depth on moves.
REPO_ROOT = _HERE.parents[4]

# Descriptor for the Enron QA dataset advertised by this task app.
DATASET_SPEC = TaskDatasetSpec(
    id="enron_email_qa",
    name="Enron Email QA",
    version="1.0.0",
    splits=["train", "test"],
    default_split="train",
    description="Question answering over a sample of Enron emails.",
)

# Hugging Face dataset holding the sampled Enron questions/answers.
HF_DATASET_ID = "corbt/enron_emails_sample_questions"
# Download cache location; the ENRON_DATASET_CACHE_DIR env var overrides the
# repo-local default.  (A single-argument os.path.join() wrapper here was a
# no-op and has been removed.)
HF_CACHE_DIR = os.getenv("ENRON_DATASET_CACHE_DIR", str(REPO_ROOT / ".cache" / "hf-datasets"))

# Tool names exposed to the policy, in the order advertised to the model.
TOOLS = ["search_emails", "read_email", "answer_question", "terminate"]
GROQ_CHAT_URL = "https://api.groq.com/openai/v1/chat/completions"
DEFAULT_GROQ_MODEL = "qwen/qwen3-32b"
ENRON_SYSTEM_PROMPT = (
    "You are an Enron investigations analyst. Answer the user's question by reading emails. "
    "You can call tools to search the corpus, read specific messages, and submit a final answer. "
    "Use the tools deliberately, gather evidence before answering, and when confident call "
    "answer_question with your final answer. If you cannot find the answer after thorough search, "
    "answer_question with your best attempt noting uncertainty."
)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _simplify(obj: Any) -> Any:
|
|
85
|
+
if isinstance(obj, (str, int, float, bool)) or obj is None:
|
|
86
|
+
return obj
|
|
87
|
+
if isinstance(obj, dict):
|
|
88
|
+
return {str(k): _simplify(v) for k, v in obj.items()}
|
|
89
|
+
if isinstance(obj, (list, tuple, set)):
|
|
90
|
+
return [_simplify(v) for v in obj]
|
|
91
|
+
return str(obj)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _render_search_results(results: list[dict[str, Any]]) -> str:
|
|
95
|
+
if not results:
|
|
96
|
+
return "No search results."
|
|
97
|
+
lines = []
|
|
98
|
+
for item in results[:5]:
|
|
99
|
+
message_id = item.get("message_id") or item.get("id") or "<unknown>"
|
|
100
|
+
snippet = (item.get("snippet") or item.get("snip") or "").strip()
|
|
101
|
+
lines.append(f"- {message_id}: {snippet[:280]}")
|
|
102
|
+
return "\n".join(lines)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def _render_email(email: dict[str, Any] | None) -> str:
|
|
106
|
+
if not email:
|
|
107
|
+
return "No email loaded."
|
|
108
|
+
subject = email.get("subject", "<no subject>")
|
|
109
|
+
from_addr = email.get("from_address") or email.get("from_addr") or "<unknown>"
|
|
110
|
+
date = email.get("date", "<unknown date>")
|
|
111
|
+
snippet = (email.get("body") or "")[:600]
|
|
112
|
+
return f"Subject: {subject}\nFrom: {from_addr}\nDate: {date}\nBody Preview:\n{snippet}"
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _render_observation(obs: dict[str, Any]) -> str:
|
|
116
|
+
lines = [
|
|
117
|
+
f"Question: {obs.get('question', '')}",
|
|
118
|
+
f"Already answered: {bool(obs.get('already_answered'))}",
|
|
119
|
+
f"Available tools: {', '.join(obs.get('tools') or [])}",
|
|
120
|
+
f"Inbox address: {obs.get('inbox_address', '<unknown>')}",
|
|
121
|
+
f"Reward Δ: {obs.get('reward_last', 0)} Total Reward: {obs.get('total_reward', 0)}",
|
|
122
|
+
]
|
|
123
|
+
tool_error = obs.get("tool_error")
|
|
124
|
+
if tool_error:
|
|
125
|
+
lines.append(f"Last tool error: {tool_error}")
|
|
126
|
+
search_results = obs.get("search_results") or []
|
|
127
|
+
if search_results:
|
|
128
|
+
lines.append("Search Results:")
|
|
129
|
+
lines.append(_render_search_results(search_results))
|
|
130
|
+
email = obs.get("email")
|
|
131
|
+
if email:
|
|
132
|
+
lines.append("Email Content:")
|
|
133
|
+
lines.append(_render_email(email))
|
|
134
|
+
gold = obs.get("gold_answer")
|
|
135
|
+
if gold and obs.get("terminated"):
|
|
136
|
+
lines.append(f"Gold Answer: {gold}")
|
|
137
|
+
return "\n".join(lines)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def _conversation_message(role: str, content: Any, **metadata: Any) -> dict[str, Any]:
|
|
141
|
+
if isinstance(content, (dict, list)):
|
|
142
|
+
rendered = json.dumps(_simplify(content), ensure_ascii=False)
|
|
143
|
+
else:
|
|
144
|
+
rendered = str(content)
|
|
145
|
+
message: dict[str, Any] = {"role": role, "content": rendered}
|
|
146
|
+
message.update({k: v for k, v in metadata.items() if v is not None})
|
|
147
|
+
return message
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def _build_trace_payload_enron(
    run_id: str,
    request: RolloutRequest,
    steps: list[RolloutStep],
    metrics: RolloutMetrics,
    *,
    provider: str,
    model: str,
    conversation: list[dict[str, Any]],
    metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Assemble a v3 session-trace payload for a finished Enron rollout.

    Every conversation message becomes one markov-blanket entry, spaced
    5 ms apart starting from "now"; the step and event histories are left
    empty for this task app.  ``steps`` is not consumed here — the caller
    has already folded it into ``metrics``.
    """
    started = datetime.now(timezone.utc)
    clock = time.time()
    time_steps: list[dict[str, Any]] = []
    events: list[dict[str, Any]] = []
    blanket_messages: list[dict[str, Any]] = []
    for message in conversation:
        clock += 0.005
        extra = {key: value for key, value in message.items() if key not in {"role", "content"}}
        blanket_messages.append(
            {
                "content": {"text": message.get("content", "")},
                "message_type": message.get("role", "system"),
                "time_record": {"event_time": clock},
                "metadata": _simplify(extra),
            }
        )

    trace_metadata = {
        "task": "enron_email_qa",
        "provider": provider,
        "model": model,
        "policy": _simplify(request.policy.model_dump() if request.policy else {}),
        "env": _simplify(request.env.model_dump() if request.env else {}),
        **(_simplify(metadata or {})),
    }
    session_trace = {
        "session_id": run_id,
        "created_at": started.isoformat(),
        "metadata": trace_metadata,
        "session_time_steps": time_steps,
        "event_history": events,
        "markov_blanket_message_history": blanket_messages,
    }

    return {
        "version": 3,
        "session_trace": session_trace,
        "run_id": run_id,
        "policy_id": request.policy.policy_id or request.policy.policy_name,
        "reward": metrics.mean_return,
        "episode_returns": metrics.episode_returns,
        "mean_return": metrics.mean_return,
        "num_steps": metrics.num_steps,
    }
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
async def _call_groq_chat(
    client: httpx.AsyncClient,
    api_key: str,
    payload: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
    """POST *payload* to the Groq chat-completions endpoint.

    Returns ``(data, debug)`` where ``data`` is the parsed JSON body and
    ``debug`` records status/headers/body for tracing.  Any 4xx/5xx Groq
    response is re-raised as an ``HTTPException`` carrying the same status
    plus the (best-effort parsed) error body and headers.
    """
    response = await client.post(
        GROQ_CHAT_URL,
        json=payload,
        headers={"Authorization": f"Bearer {api_key}"},
    )
    if response.status_code >= 400:
        try:
            error_body = response.json()
        except Exception:
            # Non-JSON error payloads are surfaced verbatim.
            error_body = {"raw": response.text}
        raise HTTPException(
            status_code=response.status_code,
            detail={
                "status": response.status_code,
                "body": error_body,
                "headers": dict(response.headers),
            },
        )
    data = response.json()
    debug = {
        "status": response.status_code,
        "headers": dict(response.headers),
        "body": data,
    }
    return data, debug
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def _load_taskset_blocking() -> TaskInstanceSet:
    """Download the Enron QA dataset from Hugging Face and build a TaskInstanceSet.

    Blocking: performs network/disk I/O via ``datasets.load_dataset``.
    Train and test rows become EnronTaskInstance objects; the split info
    marks every test-row id while leaving the validation split empty.
    """
    cache_dir = Path(HF_CACHE_DIR)
    cache_dir.mkdir(parents=True, exist_ok=True)

    train_rows = load_dataset(HF_DATASET_ID, split="train", cache_dir=cache_dir)
    test_rows = load_dataset(HF_DATASET_ID, split="test", cache_dir=cache_dir)

    def _row_to_instance(row: dict[str, Any], split: str) -> EnronTaskInstance:
        # Each row carries the question, gold answer, and supporting email ids.
        question = str(row.get("question") or "").strip()
        answer = str(row.get("answer") or "").strip()
        raw_ids = row.get("message_ids") or []
        message_ids = raw_ids if isinstance(raw_ids, list) else list(raw_ids)
        return EnronTaskInstance(
            id=uuid4(),
            impetus=Impetus(instructions=question),
            intent=Intent(
                rubric={"goal": "Answer the question using the Enron emails."},
                gold_trajectories=None,
                gold_state_diff={"answer": answer},
            ),
            metadata=EnronTaskInstanceMetadata(
                split=split,
                email_count=len(message_ids),
                message_ids=message_ids,
            ),
            is_reproducible=True,
            initial_engine_snapshot=row,
        )

    train_instances = [_row_to_instance(row, "train") for row in train_rows]
    test_instances = [_row_to_instance(row, "test") for row in test_rows]

    split_info = SplitInfo(
        val_instance_ids=set(),
        test_instance_ids={inst.id for inst in test_instances},
        _is_split_defined=True,
    )
    return TaskInstanceSet(
        name="Enron-QA",
        description="QA over Enron email dataset sample.",
        instances=train_instances + test_instances,
        split_info=split_info,
    )
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
def _safe_uuid(value: Any) -> UUID:
|
|
287
|
+
if isinstance(value, UUID):
|
|
288
|
+
return value
|
|
289
|
+
try:
|
|
290
|
+
return UUID(str(value))
|
|
291
|
+
except Exception:
|
|
292
|
+
return UUID(int=0)
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
@dataclass
class EnronDataset:
    """In-memory view of the Enron QA taskset with seed-based instance selection."""

    spec: TaskDatasetSpec  # dataset descriptor surfaced via describe()

    def __post_init__(self) -> None:
        # Blocks on the HF download the first time the taskset is built.
        taskset = _load_taskset_blocking()
        self._taskset = taskset
        self.instances: list[EnronTaskInstance] = list(taskset.instances)
        self.instance_ids = [str(_safe_uuid(item.id)) for item in self.instances]
        self.default_seed = 0
        self.seed_min = 0
        self.seed_max = max(len(self.instances) - 1, 0)

    def describe(self) -> dict[str, Any]:
        """Summarize the dataset: spec fields plus counts and a capped id list."""
        summary = dict(self.spec.model_dump())
        summary["instance_count"] = len(self.instances)
        summary["instance_ids"] = self.instance_ids[:50]
        return summary

    def instance_by_seed(self, seed: int | None) -> EnronTaskInstance:
        """Map *seed* onto an instance (modulo dataset size); ``None`` selects the first."""
        if not self.instances:
            raise ValueError("Enron dataset is empty.")
        index = 0 if seed is None else int(seed) % len(self.instances)
        return self.instances[index]
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
def build_dataset() -> tuple[TaskDatasetRegistry, EnronDataset]:
    """Create the Enron dataset and register it (cached) under DATASET_SPEC."""
    dataset = EnronDataset(DATASET_SPEC)
    registry = TaskDatasetRegistry()
    # The resolver always hands back the single shared dataset object.
    registry.register(DATASET_SPEC, lambda _spec: dataset, cache=True)
    return registry, dataset
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
def _base_task_info(dataset: EnronDataset) -> TaskInfo:
    """Construct the static TaskInfo template shared by all Enron task instances."""
    action_space = {
        "type": "tool_calls",
        "tools": TOOLS,
        "description": "Tool-assisted QA workflow over an email corpus.",
    }
    observation = {
        "summary": "Text observations describing the question, tool status, and last reward.",
        "format": "text",
    }
    # Mirrors the single-criterion inline rubric defined below.
    rubric_summary = {
        "version": "1",
        "criteria_count": 1,
        "source": "inline",
        "aggregation": "weighted_sum",
    }
    inference = {
        "supports_proxy": False,
        "endpoints": {},
        "tool": {"name": "enron_tools", "parallel_tool_calls": False},
    }
    return TaskInfo(
        task={"id": "enron_email_qa", "name": "Enron Email QA", "version": "1.0.0"},
        environment="enron",
        action_space=action_space,
        observation=observation,
        dataset={**dataset.describe(), "default_seed": dataset.default_seed},
        rubric=rubric_summary,
        inference=inference,
        limits={"max_ops": 0, "max_time_s": 900},
    )
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
# Outcome rubric: a single fully-weighted criterion scoring whether the
# final submitted answer matches the gold answer.
OUTCOME_RUBRIC = load_rubric(
    {
        "version": "1",
        "goal_text": "Provide the correct answer to the question using the Enron emails.",
        "aggregation": "weighted_sum",
        "criteria": [
            {
                "id": "accuracy",
                "description": "Final answer matches the gold answer.",
                "weight": 1.0,
            }
        ],
    }
)
|
|
374
|
+
|
|
375
|
+
# Per-event rubric: rewards deliberate, efficient tool usage during the
# rollout (search/read/answer), as opposed to the final-answer outcome above.
EVENTS_RUBRIC = load_rubric(
    {
        "version": "1",
        "goal_text": "Encourage efficient use of tools when exploring the corpus.",
        "aggregation": "weighted_sum",
        "criteria": [
            {
                "id": "tool_use",
                "description": "Use search, read, and answer tools deliberately.",
                "weight": 1.0,
            }
        ],
    }
)
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
def describe_taskset(dataset: EnronDataset) -> dict[str, Any]:
    """Return the dataset's self-description payload for the taskset endpoint."""
    return dataset.describe()
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
def provide_task_instances(
    dataset: EnronDataset, base_info: TaskInfo, seeds: Sequence[int]
) -> Iterable[TaskInfo]:
    """Materialise one ``TaskInfo`` per seed, specialised with per-instance data.

    Shared fields (task, environment, action space, rubric, inference, limits)
    are copied from ``base_info``; the observation gains the instance question
    and the dataset payload gains the instance id and metadata.

    Args:
        dataset: Source of task instances, resolved by seed.
        base_info: Template whose shared fields are reused for every seed.
        seeds: Seeds to resolve into concrete instances.

    Returns:
        A list of per-seed ``TaskInfo`` objects (one per seed, in order).
    """
    infos: list[TaskInfo] = []

    # Normalise the base observation into a plain dict we can overlay per seed.
    base_observation = getattr(base_info, "observation", None)
    if hasattr(base_observation, "model_dump"):
        observation_template = base_observation.model_dump()
    elif isinstance(base_observation, dict):
        observation_template = dict(base_observation)
    else:
        observation_template = {}

    # Loop-invariant: the base dataset payload is identical for every seed,
    # so dump it once instead of re-serialising inside the loop.
    base_dataset_payload = base_info.dataset.model_dump()

    for seed in seeds:
        instance = dataset.instance_by_seed(seed)
        metadata = instance.metadata
        # getattr with None defaults: metadata fields are optional on some instances.
        meta_dict = {
            "split": getattr(metadata, "split", None),
            "email_count": getattr(metadata, "email_count", None),
            "message_ids": getattr(metadata, "message_ids", None),
        }
        infos.append(
            TaskInfo(
                task=base_info.task,
                environment=base_info.environment,
                action_space=base_info.action_space,
                observation={
                    **observation_template,
                    "question": instance.impetus.instructions,
                },
                dataset={
                    **base_dataset_payload,
                    "instance_id": str(_safe_uuid(instance.id)),
                    "metadata": meta_dict,
                },
                rubric=base_info.rubric,
                inference=base_info.inference,
                limits=base_info.limits,
            )
        )
    return infos
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
def _ensure_dataset_from_state(fastapi_request, fallback: EnronDataset) -> EnronDataset:
    """Resolve the dataset stored on the FastAPI app state, else use *fallback*.

    Every lookup is defensive (``getattr`` with ``None`` default) so a missing
    request, app, state, or dataset attribute degrades to the fallback.
    """
    if fastapi_request is None:
        return fallback
    app = getattr(fastapi_request, "app", None)
    state = getattr(app, "state", None)
    candidate = getattr(state, "dataset", None)
    return candidate or fallback
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
def _normalise_observation(value: Any) -> dict[str, Any]:
|
|
446
|
+
if isinstance(value, dict):
|
|
447
|
+
return value
|
|
448
|
+
if hasattr(value, "observation"):
|
|
449
|
+
obs = getattr(value, "observation")
|
|
450
|
+
if isinstance(obs, dict):
|
|
451
|
+
return obs
|
|
452
|
+
return {"text": str(obs)}
|
|
453
|
+
return {"text": str(value)}
|
|
454
|
+
|
|
455
|
+
|
|
456
|
+
async def rollout_executor(request: RolloutRequest, fastapi_request) -> RolloutResponse:
    """Dispatch a rollout request to the configured provider.

    Only the ``groq`` provider executes a real multi-turn rollout; any other
    (or unset) provider yields a single-step "noop" response that carries the
    environment's initial observation plus a minimal trace payload.
    """
    policy_cfg = dict(request.policy.config or {})
    # Provider string is normalised (strip + lower) before comparison.
    provider = str(policy_cfg.get("provider") or "").strip().lower()
    if provider == "groq":
        return await _rollout_with_groq(request, fastapi_request, policy_cfg)

    # Fallback: return initial observation but include minimal trace payload
    dataset = _ensure_dataset_from_state(fastapi_request, RUNTIME_DATASET)
    env_seed = getattr(request.env, "seed", None) if request and request.env else None
    instance = dataset.instance_by_seed(env_seed)
    env = EnronEnvironment(task_instance=instance)
    env.custom_obs = None
    try:
        initial_observation = await env.initialize()
    finally:
        # Environment is only needed for its initial observation here; always
        # tear it down, swallowing teardown errors.
        with contextlib.suppress(Exception):
            await env.terminate()

    obs_dict = _normalise_observation(initial_observation)
    # Single terminal step with zero reward: nothing was actually executed.
    step = RolloutStep(
        obs=obs_dict,
        tool_calls=[],
        reward=0.0,
        done=True,
        truncated=None,
        info={"note": "No rollout executed; provider unset."},
    )
    # No inference_url for noop policy
    trajectory = RolloutTrajectory(
        env_id=request.env.env_id or "enron",
        policy_id=request.policy.policy_id or request.policy.policy_name or "noop-policy",
        steps=[step],
        final={"observation": obs_dict},
        length=1,
        inference_url=None,  # NEW: No inference for noop policy
        decision_samples=None,
    )
    metrics = RolloutMetrics(
        episode_returns=[0.0],
        mean_return=0.0,
        num_steps=1,
        num_episodes=1,
        outcome_score=None,
        events_score=None,
        details={"note": "Provider not configured; returning initial state."},
    )
    # Trace mirrors what a real rollout would record: system prompt + the
    # rendered initial observation as the user turn.
    trace_payload = _build_trace_payload_enron(
        request.run_id,
        request,
        [step],
        metrics,
        provider="local",
        model=policy_cfg.get("model") or "noop",
        conversation=[
            _conversation_message("system", ENRON_SYSTEM_PROMPT),
            _conversation_message("user", _render_observation(obs_dict)),
        ],
        metadata={"mode": "noop"},
    )
    return RolloutResponse(
        run_id=request.run_id,
        trajectories=[trajectory],
        branches={},
        metrics=metrics,
        aborted=False,
        ops_executed=0,
        trace=trace_payload,
    )
|
|
524
|
+
|
|
525
|
+
|
|
526
|
+
def _prepare_tool_call(
    tool_name: str,
    raw_args: dict[str, Any],
    current_obs: dict[str, Any],
) -> EnvToolCall:
    """Validate a model-issued tool invocation and build the environment call.

    Raises:
        ValueError: if *tool_name* is unknown or its arguments are invalid.
    """
    if tool_name == "terminate":
        return EnvToolCall(tool="terminate", args={})

    if tool_name == "read_email":
        message_id = raw_args.get("message_id")
        if not message_id:
            raise ValueError("read_email requires 'message_id'.")
        return EnvToolCall(tool="read_email", args={"message_id": str(message_id)})

    if tool_name == "answer_question":
        answer = raw_args.get("answer")
        if not isinstance(answer, str) or not answer.strip():
            raise ValueError("answer_question requires a non-empty 'answer'.")
        return EnvToolCall(tool="answer_question", args={"answer": answer.strip()})

    if tool_name == "search_emails":
        keywords = raw_args.get("keywords")
        if isinstance(keywords, str):
            # Accept a comma-separated string as shorthand for a keyword list.
            keywords = [part.strip() for part in keywords.split(",") if part.strip()]
        if not isinstance(keywords, list) or not keywords:
            raise ValueError("search_emails requires a non-empty list of keywords.")
        # Fall back to the observation's inbox, then to the default investigator.
        inbox = raw_args.get("inbox") or current_obs.get("inbox_address") or "investigator@enron.com"
        return EnvToolCall(
            tool="search_emails",
            args={
                "inbox": str(inbox),
                "keywords": [str(word) for word in keywords],
                "from_addr": raw_args.get("from_addr"),
                "to_addr": raw_args.get("to_addr"),
                "sent_after": raw_args.get("sent_after"),
                "sent_before": raw_args.get("sent_before"),
                "max_results": int(raw_args.get("max_results") or 5),
            },
        )

    raise ValueError(f"Unsupported tool '{tool_name}'")
|
|
565
|
+
|
|
566
|
+
|
|
567
|
+
async def _rollout_with_groq(
    request: RolloutRequest,
    fastapi_request,
    config: dict[str, Any],
) -> RolloutResponse:
    """Run a multi-turn tool-calling rollout against the Groq chat API.

    Each turn sends the running conversation to Groq, executes any returned
    tool calls against the Enron environment, and records one RolloutStep.
    A plain-text assistant reply (no tool calls) is treated as the final
    answer and submitted via ``answer_question``.
    """
    api_key = os.getenv("GROQ_API_KEY")
    if not api_key:
        # 503 rather than 500: missing credentials are a deployment problem.
        raise HTTPException(
            status_code=503,
            detail="GROQ_API_KEY environment variable is required for Groq rollouts.",
        )

    dataset = _ensure_dataset_from_state(fastapi_request, RUNTIME_DATASET)
    env_seed = getattr(request.env, "seed", None) if request and request.env else None
    instance = dataset.instance_by_seed(env_seed)
    env = EnronEnvironment(task_instance=instance)
    env.custom_obs = None

    # Snapshot of instance metadata for the trace; message_ids capped at 10
    # to bound payload size.
    metadata_extra = {
        "split": getattr(instance.metadata, "split", None),
        "email_count": getattr(instance.metadata, "email_count", None),
        "message_ids": list(getattr(instance.metadata, "message_ids", []))[:10],
    }

    # Sampling parameters, with defaults applied both for missing keys and
    # falsy values (the `or` fallback).
    model = config.get("model") or DEFAULT_GROQ_MODEL
    temperature = float(config.get("temperature", 0.2) or 0.2)
    top_p = float(config.get("top_p", 0.8) or 0.8)
    max_tokens = int(config.get("max_tokens", 768) or 768)
    # "max_turns" preferred; "max_steps" accepted as a legacy alias.
    max_turns = int(config.get("max_turns", config.get("max_steps", 12)) or 12)

    # OpenAI-style function schemas advertised to the model; these mirror the
    # tools accepted by _prepare_tool_call.
    tool_schemas = [
        {
            "type": "function",
            "function": {
                "name": "search_emails",
                "description": "Search the Enron corpus for emails matching keywords.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "inbox": {"type": "string", "description": "Email address performing the search."},
                        "keywords": {
                            "type": "array",
                            "items": {"type": "string"},
                            "minItems": 1,
                            "description": "Keywords to include in the search.",
                        },
                        "from_addr": {"type": "string"},
                        "to_addr": {"type": "string"},
                        "sent_after": {"type": "string", "description": "YYYY-MM-DD"},
                        "sent_before": {"type": "string", "description": "YYYY-MM-DD"},
                        "max_results": {"type": "integer", "minimum": 1, "maximum": 10},
                    },
                    "required": ["keywords"],
                    "additionalProperties": False,
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "read_email",
                "description": "Read the full contents of an email by message_id.",
                "parameters": {
                    "type": "object",
                    "properties": {"message_id": {"type": "string"}},
                    "required": ["message_id"],
                    "additionalProperties": False,
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "answer_question",
                "description": "Submit the final answer to the investigation question.",
                "parameters": {
                    "type": "object",
                    "properties": {"answer": {"type": "string"}},
                    "required": ["answer"],
                    "additionalProperties": False,
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "terminate",
                "description": "Terminate the investigation without answering.",
                "parameters": {"type": "object", "properties": {}, "additionalProperties": False},
            },
        },
    ]

    steps: list[RolloutStep] = []
    conversation: list[dict[str, Any]] = []
    executed = 0  # count of env.step() calls, reported as ops_executed
    try:
        observation = await env.initialize()
        obs_dict = _normalise_observation(observation)
        conversation.append(_conversation_message("system", ENRON_SYSTEM_PROMPT))
        conversation.append(_conversation_message("user", _render_observation(obs_dict)))

        async with httpx.AsyncClient(timeout=httpx.Timeout(60.0)) as client:
            for turn in range(max_turns):
                payload = {
                    "model": model,
                    "messages": conversation,
                    "temperature": temperature,
                    "top_p": top_p,
                    "max_tokens": max_tokens,
                    "tools": tool_schemas,
                    "tool_choice": "auto",
                }
                # Per-turn vendor exchange record (request + response metadata)
                # kept in the step's info for debugging/trace correlation.
                vendor_attempts: list[dict[str, Any]] = []
                response, response_meta = await _call_groq_chat(client, api_key, payload)
                vendor_attempts.append({"request": payload, "response": response_meta})

                choices = response.get("choices") or []
                if not choices:
                    # Empty response: abandon the rollout with the steps so far.
                    break
                message = choices[0].get("message") or {}
                tool_calls = message.get("tool_calls") or []
                assistant_msg_meta = {"tool_calls": _simplify(tool_calls)} if tool_calls else {}
                conversation.append(
                    _conversation_message("assistant", message.get("content") or "", **assistant_msg_meta)
                )

                tool_call_records: list[dict[str, Any]] = []
                step_reward = 0.0
                done = False
                truncated = False

                if not tool_calls:
                    # Plain text reply: treat non-empty content as the final
                    # answer and submit it; empty content ends the rollout.
                    final_answer = (message.get("content") or "").strip()
                    if final_answer:
                        env_call = EnvToolCall(tool="answer_question", args={"answer": final_answer})
                        observation = await env.step(env_call)
                        executed += 1
                        obs_dict = _normalise_observation(observation)
                        # assumes obs exposes reward_last/terminated/truncated keys — TODO confirm against EnronEnvironment
                        step_reward += float(obs_dict.get("reward_last") or 0.0)
                        done = bool(obs_dict.get("terminated"))
                        truncated = bool(obs_dict.get("truncated"))
                        tool_call_records.append({"tool": "answer_question", "args": env_call.args})
                        conversation.append(
                            _conversation_message(
                                "tool",
                                {"result": "answer_submitted", "observation": obs_dict},
                                name="answer_question",
                            )
                        )
                    else:
                        break
                else:
                    # Execute every tool call in order; stop early if the
                    # environment reports a terminal/truncated state.
                    for call in tool_calls:
                        func = call.get("function") or {}
                        name = func.get("name")
                        raw_args = func.get("arguments")
                        # Arguments may arrive as a JSON string or a dict;
                        # unparseable strings degrade to an empty dict.
                        if isinstance(raw_args, str):
                            try:
                                parsed_args = json.loads(raw_args)
                            except json.JSONDecodeError:
                                parsed_args = {}
                        elif isinstance(raw_args, dict):
                            parsed_args = raw_args
                        else:
                            parsed_args = {}

                        # NOTE(review): _prepare_tool_call raises ValueError on an
                        # unknown tool name, which propagates out of this handler.
                        env_call = _prepare_tool_call(name, parsed_args, obs_dict)
                        observation = await env.step(env_call)
                        executed += 1
                        obs_dict = _normalise_observation(observation)
                        reward_delta = float(obs_dict.get("reward_last") or 0.0)
                        step_reward += reward_delta
                        done = bool(obs_dict.get("terminated"))
                        truncated = bool(obs_dict.get("truncated"))
                        tool_call_records.append({"tool": env_call.tool, "args": env_call.args})
                        conversation.append(
                            _conversation_message(
                                "tool",
                                {
                                    "tool": env_call.tool,
                                    "args": env_call.args,
                                    "reward_delta": reward_delta,
                                    "observation": obs_dict,
                                },
                                name=env_call.tool,
                                tool_call_id=call.get("id"),
                            )
                        )
                        if done or truncated:
                            break

                # Prompt the model with the latest observation for the next turn.
                conversation.append(_conversation_message("user", _render_observation(obs_dict)))

                step = RolloutStep(
                    obs=obs_dict,
                    tool_calls=tool_call_records,
                    reward=step_reward,
                    done=done,
                    truncated=truncated if truncated else None,
                    info={
                        "provider": "groq",
                        "model": model,
                        "vendor_attempts": vendor_attempts,
                        "turn": turn,
                    },
                )
                steps.append(step)

                if done or truncated:
                    break
    finally:
        # Always tear the environment down, even on vendor/tool errors.
        with contextlib.suppress(Exception):
            await env.terminate()

    # Episode return is read off the final observation's running total.
    if steps:
        final_obs = steps[-1].obs
        total_reward = float(final_obs.get("total_reward") or 0.0)
    else:
        total_reward = 0.0

    metrics = RolloutMetrics(
        episode_returns=[total_reward],
        mean_return=total_reward if steps else 0.0,
        num_steps=len(steps),
        num_episodes=1,
        outcome_score=None,
        events_score=None,
        details={"provider": "groq", "model": model},
    )
    inference_url_groq = "https://api.groq.com/openai/v1/chat/completions"

    trajectory = RolloutTrajectory(
        env_id=request.env.env_id or "enron",
        policy_id=request.policy.policy_id or request.policy.policy_name or "enron-groq",
        steps=steps,
        final={"observation": steps[-1].obs if steps else {}},
        length=len(steps),
        inference_url=inference_url_groq,  # NEW: Required for trace correlation
        decision_samples=None,
    )
    trace_payload = _build_trace_payload_enron(
        request.run_id,
        request,
        steps,
        metrics,
        provider="groq",
        model=model,
        conversation=conversation,
        metadata=metadata_extra,
    )
    return RolloutResponse(
        run_id=request.run_id,
        trajectories=[trajectory],
        branches={},
        metrics=metrics,
        aborted=False,
        ops_executed=executed,
        trace=trace_payload,
    )
|
|
827
|
+
|
|
828
|
+
|
|
829
|
+
# Module-level singletons, built once at import time and shared by every
# request handler and the registered task-app entry below.
RUNTIME_DATASET: EnronDataset
registry, RUNTIME_DATASET = build_dataset()
BASE_INFO = _base_task_info(RUNTIME_DATASET)
|
|
832
|
+
|
|
833
|
+
|
|
834
|
+
def build_config() -> TaskAppConfig:
    """Assemble the TaskAppConfig for the Enron email QA task app.

    Resolves tracing / SFT-output settings from the environment, wires the
    module-level dataset and task-info singletons into the app state, and
    returns the fully-populated config.
    """
    tracing_on = tracing_env_enabled()
    db_url = resolve_tracing_db_url()
    tracer_factory = build_tracer_factory(SessionTracer, enabled=tracing_on, db_url=db_url)
    sft_dir = resolve_sft_output_dir()

    # Base app state; optional entries added only when actually configured.
    state: dict[str, Any] = {
        "dataset": RUNTIME_DATASET,
        "allowed_environments": ["enron"],
        "tracing_enabled": tracing_on,
    }
    if tracer_factory is not None:
        state["session_tracer_factory"] = tracer_factory
    if sft_dir:
        state["sft_output_dir"] = sft_dir

    if tracing_on:
        logger.info("[enron:tracing] enabled (db=%s)", db_url or "default")
    else:
        logger.info("[enron:tracing] disabled")
    if sft_dir:
        logger.info("[enron:sft] writing JSONL to %s", sft_dir)

    return TaskAppConfig(
        app_id="grpo-enron",
        name="GRPO Enron Email QA Task App",
        description="Tool-assisted QA environment over Enron emails with GRPO-compatible endpoints.",
        base_task_info=BASE_INFO,
        describe_taskset=lambda: describe_taskset(RUNTIME_DATASET),
        provide_task_instances=lambda seeds: provide_task_instances(RUNTIME_DATASET, BASE_INFO, seeds),
        rollout=rollout_executor,
        dataset_registry=registry,
        rubrics=RubricBundle(outcome=OUTCOME_RUBRIC, events=EVENTS_RUBRIC),
        proxy=ProxyConfig(enable_openai=False, enable_groq=False),
        routers=(),
        app_state=state,
        cors_origins=["*"],
    )
|
|
875
|
+
|
|
876
|
+
|
|
877
|
+
# Register the Enron task app with the global task-app registry so it can be
# served locally or deployed to Modal under the aliases below.
register_task_app(
    entry=TaskAppEntry(
        app_id="grpo-enron",
        description="Enron email QA task app with rollout metadata endpoints.",
        config_factory=build_config,
        aliases=("enron", "enron-task"),
        env_files=(str(REPO_ROOT / "backend" / ".env.dev"),),
        modal=ModalDeploymentConfig(
            app_name="grpo-enron-task-app",
            python_version="3.11",
            pip_packages=(
                "fastapi>=0.100.0",
                "uvicorn>=0.23.0",
                "pydantic>=2.0.0",
                "httpx>=0.24.0",
                "python-dotenv>=1.0.1",
                "datasets>=2.10.0",
            ),
            # Local dirs mounted into the Modal container (source dir, target path).
            extra_local_dirs=(
                (str(REPO_ROOT), "/opt/synth_ai_repo"),
                (str(REPO_ROOT / "synth_ai"), "/opt/synth_ai_repo/synth_ai"),
                (str(_HERE.parent), "/opt/synth_ai_repo/examples/task_apps/enron/task_app"),
            ),
            # Modal secrets exposing provider API keys inside the container.
            secret_names=("groq-api-key", "openai-api-key"),
            memory=8192,
            cpu=2.0,
            max_containers=4,
        ),
    )
)
|