synth-ai 0.2.13.dev1__py3-none-any.whl → 0.2.13.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (226)
  1. examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +12 -1
  2. examples/swe/task_app/grpo_swe_mini.py +55 -26
  3. examples/swe/task_app/hosted/rollout.py +40 -0
  4. examples/swe/task_app/hosted/test_service.py +5 -6
  5. examples/task_apps/TESTING.md +275 -0
  6. examples/task_apps/__init__.py +0 -0
  7. examples/task_apps/crafter/__init__.py +0 -0
  8. examples/task_apps/crafter/task_app/__init__.py +2 -0
  9. examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter.py +18 -13
  10. examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter_task_app.py +1 -1
  11. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/policy.py +60 -4
  12. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/policy_routes.py +25 -3
  13. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/rollout.py +10 -0
  14. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_service.py +5 -6
  15. examples/task_apps/dev/pokemon_emerald/__init__.py +2 -0
  16. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/README.md +811 -0
  17. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/__init__.py +120 -0
  18. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/action.py +160 -0
  19. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/memory.py +155 -0
  20. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/perception.py +69 -0
  21. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/planning.py +96 -0
  22. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/simple.py +1502 -0
  23. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/system_prompt.py +4 -0
  24. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/grab_map.py +68 -0
  25. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/manual.py +216 -0
  26. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/__init__.py +35 -0
  27. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emerald_utils.py +631 -0
  28. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emulator.py +1544 -0
  29. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/enums.py +1428 -0
  30. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/memory_reader.py +4848 -0
  31. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/types.py +41 -0
  32. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/utils.py +298 -0
  33. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pyproject.toml +95 -0
  34. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/run.py +204 -0
  35. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/__init__.py +0 -0
  36. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/app.py +2152 -0
  37. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/client.py +429 -0
  38. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/frame_server.py +155 -0
  39. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/README.md +78 -0
  40. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/__init__.py +0 -0
  41. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/run_tests.py +122 -0
  42. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_direct.py +76 -0
  43. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_prompts.py +413 -0
  44. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_battle_state_formatting.py +204 -0
  45. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection.py +133 -0
  46. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection_comprehensive.py +229 -0
  47. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_direct_agent_emulator.py +300 -0
  48. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_fps_adjustment_pytest.py +205 -0
  49. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_direct.py +200 -0
  50. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_transition.py +284 -0
  51. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_map_ground_truth_comparison.py +468 -0
  52. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_memory_map.py +575 -0
  53. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_server_map_validation.py +311 -0
  54. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_torchic_state.py +259 -0
  55. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/__init__.py +0 -0
  56. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/anticheat.py +372 -0
  57. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/checkpoint.py +296 -0
  58. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/error_handler.py +275 -0
  59. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/get_local_ip.py +22 -0
  60. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/helpers.py +44 -0
  61. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/llm_logger.py +514 -0
  62. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_formatter.py +415 -0
  63. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher.py +1763 -0
  64. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher_singleton.py +33 -0
  65. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_trimmer.py +106 -0
  66. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_visualizer.py +334 -0
  67. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/ocr_dialogue.py +1020 -0
  68. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/recording.py +188 -0
  69. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/state_formatter.py +1481 -0
  70. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/vlm.py +862 -0
  71. examples/task_apps/dev/pokemon_emerald/modal_app.py +114 -0
  72. examples/task_apps/dev/pokemon_emerald/task_app/README.md +81 -0
  73. examples/task_apps/dev/pokemon_emerald/task_app/__init__.py +6 -0
  74. examples/task_apps/dev/pokemon_emerald/task_app/pokemon_emerald.py +685 -0
  75. examples/task_apps/enron/__init__.py +1 -0
  76. examples/task_apps/enron/eval_groq_qwen32.toml +16 -0
  77. examples/task_apps/enron/task_app/README.md +14 -0
  78. examples/task_apps/enron/task_app/__init__.py +1 -0
  79. examples/task_apps/enron/task_app/grpo_enron.py +906 -0
  80. examples/task_apps/enron/task_app/grpo_enron_task_app.py +146 -0
  81. examples/task_apps/enron/tests/__init__.py +2 -0
  82. examples/task_apps/enron/tests/conftest.py +115 -0
  83. examples/task_apps/enron/tests/integration/__init__.py +2 -0
  84. examples/task_apps/enron/tests/integration/test_enron_eval.py +177 -0
  85. examples/task_apps/enron/tests/integration/test_enron_rollout.py +135 -0
  86. examples/task_apps/enron/tests/unit/__init__.py +2 -0
  87. examples/task_apps/enron/tests/unit/test_enron_environment.py +126 -0
  88. examples/task_apps/math/__init__.py +0 -0
  89. examples/{rl/task_app → task_apps/math}/math_single_step.py +19 -10
  90. examples/task_apps/pokemon_battle/__init__.py +2 -0
  91. examples/task_apps/pokemon_battle/modal_app.py +104 -0
  92. examples/task_apps/pokemon_battle/task_app/README.md +68 -0
  93. examples/task_apps/pokemon_battle/task_app/__init__.py +6 -0
  94. examples/task_apps/pokemon_battle/task_app/pokemon_showdown.py +932 -0
  95. examples/task_apps/pokemon_red/README.md +357 -0
  96. examples/task_apps/pokemon_red/__init__.py +3 -0
  97. examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +225 -0
  98. examples/task_apps/pokemon_red/pallet_town_rl_config.toml +73 -0
  99. examples/task_apps/pokemon_red/task_app.py +606 -0
  100. examples/task_apps/pokemon_red/test_pallet_town_rewards.py +191 -0
  101. examples/task_apps/sokoban/README.md +307 -0
  102. examples/task_apps/sokoban/__init__.py +3 -0
  103. examples/task_apps/sokoban/eval_groq_qwen32.toml +16 -0
  104. examples/task_apps/sokoban/eval_openai_gpt5.toml +16 -0
  105. examples/task_apps/sokoban/task_app.py +1058 -0
  106. examples/task_apps/sokoban/tests/__init__.py +2 -0
  107. examples/task_apps/sokoban/tests/conftest.py +113 -0
  108. examples/task_apps/sokoban/tests/integration/__init__.py +2 -0
  109. examples/task_apps/sokoban/tests/integration/test_sokoban_eval.py +57 -0
  110. examples/task_apps/sokoban/tests/integration/test_sokoban_rollout.py +198 -0
  111. examples/task_apps/sokoban/tests/unit/__init__.py +2 -0
  112. examples/task_apps/sokoban/tests/unit/test_sokoban_environment.py +114 -0
  113. examples/task_apps/verilog/__init__.py +1 -0
  114. examples/task_apps/verilog/eval_groq_qwen32b.toml +20 -0
  115. examples/task_apps/verilog/task_app/README.md +12 -0
  116. examples/task_apps/verilog/task_app/__init__.py +1 -0
  117. examples/task_apps/verilog/task_app/grpo_verilog.py +931 -0
  118. examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +145 -0
  119. examples/task_apps/verilog/tests/__init__.py +2 -0
  120. examples/task_apps/verilog/tests/conftest.py +115 -0
  121. examples/task_apps/verilog/tests/integration/__init__.py +2 -0
  122. examples/task_apps/verilog/tests/integration/test_verilog_eval.py +179 -0
  123. examples/task_apps/verilog/tests/integration/test_verilog_rollout.py +55 -0
  124. examples/task_apps/verilog/tests/unit/__init__.py +2 -0
  125. examples/task_apps/verilog/tests/unit/test_verilog_scoring.py +118 -0
  126. examples/vlm/crafter_openai_vlm_agent.py +4 -4
  127. examples/vlm/run_crafter_vlm_benchmark.py +4 -4
  128. examples/workflows/__init__.py +0 -0
  129. examples/workflows/math_rl/__init__.py +0 -0
  130. examples/workflows/math_rl/download_dataset.py +80 -0
  131. synth_ai/__init__.py +2 -2
  132. synth_ai/api/train/builders.py +25 -11
  133. synth_ai/api/train/cli.py +12 -6
  134. synth_ai/api/train/configs/__init__.py +10 -10
  135. synth_ai/api/train/configs/rl.py +5 -4
  136. synth_ai/api/train/configs/sft.py +4 -3
  137. synth_ai/api/train/env_resolver.py +5 -2
  138. synth_ai/api/train/supported_algos.py +10 -5
  139. synth_ai/api/train/utils.py +7 -4
  140. synth_ai/cli/__init__.py +7 -51
  141. synth_ai/cli/_storage.py +4 -3
  142. synth_ai/cli/_validate_task_app.py +11 -0
  143. synth_ai/cli/balance.py +4 -3
  144. synth_ai/cli/calc.py +2 -2
  145. synth_ai/cli/demo.py +14 -7
  146. synth_ai/cli/legacy_root_backup.py +1 -1
  147. synth_ai/cli/rl_demo.py +8 -7
  148. synth_ai/cli/root.py +0 -97
  149. synth_ai/cli/task_apps.py +1707 -186
  150. synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +28 -16
  151. synth_ai/environments/examples/enron/engine.py +7 -2
  152. synth_ai/environments/examples/enron/environment.py +68 -0
  153. synth_ai/environments/examples/red/engine.py +27 -0
  154. synth_ai/environments/examples/red/engine_helpers/memory_map.py +7 -0
  155. synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_progression.py +477 -0
  156. synth_ai/environments/examples/red/engine_helpers/state_extraction.py +32 -0
  157. synth_ai/environments/examples/red/environment.py +60 -0
  158. synth_ai/environments/examples/sokoban/taskset.py +116 -0
  159. synth_ai/environments/examples/verilog/engine.py +30 -4
  160. synth_ai/evals/client.py +58 -61
  161. synth_ai/jobs/client.py +16 -4
  162. synth_ai/judge_schemas.py +16 -16
  163. synth_ai/py.typed +0 -0
  164. synth_ai/task/__init__.py +14 -5
  165. synth_ai/task/contracts.py +124 -38
  166. synth_ai/task/proxy.py +48 -56
  167. synth_ai/task/rubrics/__init__.py +53 -0
  168. synth_ai/task/rubrics/loaders.py +133 -0
  169. synth_ai/task/rubrics/models.py +57 -0
  170. synth_ai/task/rubrics/scoring.py +113 -0
  171. synth_ai/{rubrics/validators.py → task/rubrics/strict.py} +53 -30
  172. synth_ai/task/server.py +8 -7
  173. synth_ai/task/validators.py +269 -6
  174. synth_ai/tracing_v3/decorators.py +7 -3
  175. synth_ai/tracing_v3/replica_sync.py +4 -4
  176. synth_ai/tracing_v3/serialization.py +5 -5
  177. synth_ai/tracing_v3/trace_utils.py +317 -0
  178. synth_ai/tracing_v3/turso/native_manager.py +3 -3
  179. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/METADATA +4 -1
  180. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/RECORD +214 -101
  181. examples/agora_ex/README_MoE.md +0 -224
  182. examples/agora_ex/__init__.py +0 -7
  183. examples/agora_ex/agora_ex.py +0 -65
  184. examples/agora_ex/agora_ex_task_app.py +0 -590
  185. examples/agora_ex/configs/rl_lora_qwen3_moe_2xh200.toml +0 -121
  186. examples/agora_ex/reward_fn_grpo-human.py +0 -129
  187. examples/agora_ex/system_prompt_CURRENT.md +0 -63
  188. examples/agora_ex/task_app/agora_ex_task_app.py +0 -590
  189. examples/agora_ex/task_app/reward_fn_grpo-human.py +0 -129
  190. examples/agora_ex/task_app/system_prompt_CURRENT.md +0 -63
  191. synth_ai/rubrics/__init__.py +0 -22
  192. synth_ai/task/rubrics.py +0 -219
  193. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/README.md +0 -0
  194. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/README.md +0 -0
  195. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/__init__.py +0 -0
  196. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/branching.py +0 -0
  197. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/environment_routes.py +0 -0
  198. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/__init__.py +0 -0
  199. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/__init__.py +0 -0
  200. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/app.py +0 -0
  201. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/environment.py +0 -0
  202. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/react_agent.py +0 -0
  203. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/shared.py +0 -0
  204. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/tools.py +0 -0
  205. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/hosted_app.py +0 -0
  206. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/__init__.py +0 -0
  207. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/openai_client.py +0 -0
  208. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/main.py +0 -0
  209. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/registry.py +0 -0
  210. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/__init__.py +0 -0
  211. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/volume.py +0 -0
  212. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_agents.py +0 -0
  213. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/utils.py +0 -0
  214. /examples/{rl/task_app → task_apps/math}/README.md +0 -0
  215. /examples/{rl/task_app → task_apps/math}/math_task_app.py +0 -0
  216. /examples/{rl → workflows/math_rl}/configs/eval_base_qwen.toml +0 -0
  217. /examples/{rl → workflows/math_rl}/configs/eval_rl_qwen.toml +0 -0
  218. /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen.toml +0 -0
  219. /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen17.toml +0 -0
  220. /examples/{rl → workflows/math_rl}/configs/rl_from_ft_qwen.toml +0 -0
  221. /examples/{rl → workflows/math_rl}/run_eval.py +0 -0
  222. /examples/{rl → workflows/math_rl}/run_rl_and_save.py +0 -0
  223. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/WHEEL +0 -0
  224. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/entry_points.txt +0 -0
  225. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/licenses/LICENSE +0 -0
  226. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,906 @@
1
+ """Task App configuration for the GRPO Enron email QA example."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import contextlib
6
+ import json
7
+ import logging
8
+ import os
9
+ import time
10
+ from dataclasses import dataclass
11
+ from datetime import datetime, timezone
12
+ from pathlib import Path
13
+ from typing import Any, Iterable, Sequence
14
+ from uuid import UUID, uuid4
15
+
16
+ from datasets import load_dataset
17
+ import httpx
18
+
19
+ from fastapi import HTTPException
20
+
21
+ from synth_ai.environments.examples.enron.environment import EnronEnvironment
22
+ from synth_ai.environments.examples.enron.taskset import (
23
+ EnronTaskInstance,
24
+ EnronTaskInstanceMetadata,
25
+ )
26
+ from synth_ai.environments.tasks.core import (
27
+ Impetus,
28
+ Intent,
29
+ SplitInfo,
30
+ TaskInstanceSet,
31
+ )
32
+ from synth_ai.task.apps import ModalDeploymentConfig, TaskAppEntry, register_task_app
33
+ from synth_ai.task.contracts import (
34
+ RolloutMetrics,
35
+ RolloutRequest,
36
+ RolloutResponse,
37
+ RolloutStep,
38
+ RolloutTrajectory,
39
+ TaskInfo,
40
+ )
41
+ from synth_ai.task.datasets import TaskDatasetRegistry, TaskDatasetSpec
42
+ from synth_ai.task.rubrics import load_rubric
43
+ from synth_ai.task.server import ProxyConfig, RubricBundle, TaskAppConfig
44
+ from synth_ai.task.tracing_utils import (
45
+ build_tracer_factory,
46
+ resolve_sft_output_dir,
47
+ resolve_tracing_db_url,
48
+ tracing_env_enabled,
49
+ )
50
+ from synth_ai.tracing_v3.session_tracer import SessionTracer
51
+ from synth_ai.environments.environment.tools import EnvToolCall
52
+
53
logger = logging.getLogger(__name__)

# Assumes this module lives at <repo>/examples/task_apps/enron/task_app/,
# so the fifth ancestor directory (parents[4]) is the repository root.
_HERE = Path(__file__).resolve()
REPO_ROOT = _HERE.parents[4]

DATASET_SPEC = TaskDatasetSpec(
    id="enron_email_qa",
    name="Enron Email QA",
    version="1.0.0",
    splits=["train", "test"],
    default_split="train",
    description="Question answering over a sample of Enron emails.",
)

HF_DATASET_ID = "corbt/enron_emails_sample_questions"
# Hugging Face dataset cache; override with ENRON_DATASET_CACHE_DIR. An empty
# value falls back to the default instead of silently resolving to the CWD.
HF_CACHE_DIR = os.getenv("ENRON_DATASET_CACHE_DIR") or str(
    REPO_ROOT / ".cache" / "hf-datasets"
)

# Tool names advertised to the policy in TaskInfo.action_space.
TOOLS = ["search_emails", "read_email", "answer_question", "terminate"]
GROQ_CHAT_URL = "https://api.groq.com/openai/v1/chat/completions"
DEFAULT_GROQ_MODEL = "qwen/qwen3-32b"
ENRON_SYSTEM_PROMPT = (
    "You are an Enron investigations analyst. Answer the user's question by reading emails. "
    "You can call tools to search the corpus, read specific messages, and submit a final answer. "
    "Use the tools deliberately, gather evidence before answering, and when confident call "
    "answer_question with your final answer. If you cannot find the answer after thorough search, "
    "answer_question with your best attempt noting uncertainty."
)
82
+
83
+
84
+ def _simplify(obj: Any) -> Any:
85
+ if isinstance(obj, (str, int, float, bool)) or obj is None:
86
+ return obj
87
+ if isinstance(obj, dict):
88
+ return {str(k): _simplify(v) for k, v in obj.items()}
89
+ if isinstance(obj, (list, tuple, set)):
90
+ return [_simplify(v) for v in obj]
91
+ return str(obj)
92
+
93
+
94
+ def _render_search_results(results: list[dict[str, Any]]) -> str:
95
+ if not results:
96
+ return "No search results."
97
+ lines = []
98
+ for item in results[:5]:
99
+ message_id = item.get("message_id") or item.get("id") or "<unknown>"
100
+ snippet = (item.get("snippet") or item.get("snip") or "").strip()
101
+ lines.append(f"- {message_id}: {snippet[:280]}")
102
+ return "\n".join(lines)
103
+
104
+
105
+ def _render_email(email: dict[str, Any] | None) -> str:
106
+ if not email:
107
+ return "No email loaded."
108
+ subject = email.get("subject", "<no subject>")
109
+ from_addr = email.get("from_address") or email.get("from_addr") or "<unknown>"
110
+ date = email.get("date", "<unknown date>")
111
+ snippet = (email.get("body") or "")[:600]
112
+ return f"Subject: {subject}\nFrom: {from_addr}\nDate: {date}\nBody Preview:\n{snippet}"
113
+
114
+
115
+ def _render_observation(obs: dict[str, Any]) -> str:
116
+ lines = [
117
+ f"Question: {obs.get('question', '')}",
118
+ f"Already answered: {bool(obs.get('already_answered'))}",
119
+ f"Available tools: {', '.join(obs.get('tools') or [])}",
120
+ f"Inbox address: {obs.get('inbox_address', '<unknown>')}",
121
+ f"Reward Δ: {obs.get('reward_last', 0)} Total Reward: {obs.get('total_reward', 0)}",
122
+ ]
123
+ tool_error = obs.get("tool_error")
124
+ if tool_error:
125
+ lines.append(f"Last tool error: {tool_error}")
126
+ search_results = obs.get("search_results") or []
127
+ if search_results:
128
+ lines.append("Search Results:")
129
+ lines.append(_render_search_results(search_results))
130
+ email = obs.get("email")
131
+ if email:
132
+ lines.append("Email Content:")
133
+ lines.append(_render_email(email))
134
+ gold = obs.get("gold_answer")
135
+ if gold and obs.get("terminated"):
136
+ lines.append(f"Gold Answer: {gold}")
137
+ return "\n".join(lines)
138
+
139
+
140
+ def _conversation_message(role: str, content: Any, **metadata: Any) -> dict[str, Any]:
141
+ if isinstance(content, (dict, list)):
142
+ rendered = json.dumps(_simplify(content), ensure_ascii=False)
143
+ else:
144
+ rendered = str(content)
145
+ message: dict[str, Any] = {"role": role, "content": rendered}
146
+ message.update({k: v for k, v in metadata.items() if v is not None})
147
+ return message
148
+
149
+
150
def _build_trace_payload_enron(
    run_id: str,
    request: RolloutRequest,
    steps: list[RolloutStep],
    metrics: RolloutMetrics,
    *,
    provider: str,
    model: str,
    conversation: list[dict[str, Any]],
    metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Assemble a version-3 trace payload for one Enron rollout.

    Each conversation message becomes an entry in
    ``markov_blanket_message_history``. Entries get synthetic timestamps spaced
    5 ms apart, starting from the current wall-clock time. Session metadata
    records the task, provider/model, and simplified policy/env configs. Keys
    from *metadata* are merged last, so they can override those fields.

    NOTE(review): *steps* is accepted but not serialized yet, so
    ``session_time_steps`` and ``event_history`` are always empty.

    Args:
        run_id: Identifier used as both ``session_id`` and ``run_id``.
        request: The rollout request; its policy/env configs are recorded.
        steps: Rollout steps (currently unused, see note above).
        metrics: Source of the reward and step-count summary fields.
        provider: Inference provider name stored in session metadata.
        model: Model name stored in session metadata.
        conversation: Message dicts with ``role``/``content`` plus optional extras.
        metadata: Extra session metadata, simplified before merging.

    Returns:
        A dict with the session trace plus run id, policy id and reward summary.
    """
    policy = request.policy
    env = request.env
    created_at = datetime.now(timezone.utc)
    event_time = time.time()
    session_steps: list[dict[str, Any]] = []
    event_history: list[dict[str, Any]] = []
    markov_history: list[dict[str, Any]] = []
    for msg in conversation:
        event_time += 0.005
        markov_history.append(
            {
                "content": {"text": msg.get("content", "")},
                "message_type": msg.get("role", "system"),
                "time_record": {"event_time": event_time},
                "metadata": _simplify({k: v for k, v in msg.items() if k not in {"role", "content"}}),
            }
        )

    session_trace = {
        "session_id": run_id,
        "created_at": created_at.isoformat(),
        "metadata": {
            "task": "enron_email_qa",
            "provider": provider,
            "model": model,
            "policy": _simplify(policy.model_dump() if policy else {}),
            "env": _simplify(env.model_dump() if env else {}),
            **(_simplify(metadata or {})),
        },
        "session_time_steps": session_steps,
        "event_history": event_history,
        "markov_blanket_message_history": markov_history,
    }

    # Guard the policy the same way the metadata block above does, so a
    # missing policy yields policy_id=None instead of an AttributeError.
    policy_id = (policy.policy_id or policy.policy_name) if policy else None

    return {
        "version": 3,
        "session_trace": session_trace,
        "run_id": run_id,
        "policy_id": policy_id,
        "reward": metrics.mean_return,
        "episode_returns": metrics.episode_returns,
        "mean_return": metrics.mean_return,
        "num_steps": metrics.num_steps,
    }
203
+
204
+
205
async def _call_groq_chat(
    client: httpx.AsyncClient,
    api_key: str,
    payload: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
    """POST *payload* to the Groq chat-completions endpoint (``GROQ_CHAT_URL``).

    Args:
        client: Async HTTP client used to send the request.
        api_key: Groq API key, sent as a ``Bearer`` token.
        payload: JSON request body for the chat-completions endpoint.

    Returns:
        ``(data, record)``, where ``data`` is the decoded JSON response and
        ``record`` holds the status code, response headers and body.

    Raises:
        HTTPException: When Groq answers with status >= 400. The exception
            reuses the upstream status code, and ``detail`` carries the
            status, the body (JSON, or ``{"raw": text}``) and the headers.
    """
    response = await client.post(
        GROQ_CHAT_URL,
        json=payload,
        headers={"Authorization": f"Bearer {api_key}"},
    )
    if response.status_code >= 400:
        try:
            body: Any = response.json()
        except ValueError:
            # Error bodies may not be JSON; keep the raw text instead.
            # (JSON and Unicode decode errors are both ValueError subclasses.)
            body = {"raw": response.text}
        detail = {
            "status": response.status_code,
            "body": body,
            "headers": dict(response.headers),
        }
        raise HTTPException(status_code=response.status_code, detail=detail)
    data = response.json()
    return data, {
        "status": response.status_code,
        "headers": dict(response.headers),
        "body": data,
    }
232
+
233
+
234
def _load_taskset_blocking() -> TaskInstanceSet:
    """Build the Enron taskset synchronously.

    Loads the ``train`` and ``test`` splits of ``HF_DATASET_ID`` into
    ``HF_CACHE_DIR``, creating that directory first. Each row becomes an
    :class:`EnronTaskInstance`. Instance ids are derived deterministically
    (``uuid5``) from the dataset id, split and row index, so they stay stable
    across process restarts. This matches the ``is_reproducible=True`` flag on
    every instance.

    Returns:
        A :class:`TaskInstanceSet` with train instances followed by test
        instances; the split info lists the test ids.
    """
    from uuid import NAMESPACE_URL, uuid5

    cache_dir = Path(HF_CACHE_DIR)
    cache_dir.mkdir(parents=True, exist_ok=True)

    ds_train = load_dataset(HF_DATASET_ID, split="train", cache_dir=str(cache_dir))
    ds_test = load_dataset(HF_DATASET_ID, split="test", cache_dir=str(cache_dir))

    def _instance_from_row(row: dict[str, Any], split: str, index: int) -> EnronTaskInstance:
        """Convert one dataset row (position *index* in *split*) into a task instance."""
        question = str(row.get("question") or "").strip()
        answer = str(row.get("answer") or "").strip()
        message_ids = row.get("message_ids") or []
        if not isinstance(message_ids, list):
            message_ids = list(message_ids)
        impetus = Impetus(instructions=question)
        intent = Intent(
            rubric={"goal": "Answer the question using the Enron emails."},
            gold_trajectories=None,
            gold_state_diff={"answer": answer},
        )
        metadata = EnronTaskInstanceMetadata(
            split=split,
            email_count=len(message_ids),
            message_ids=message_ids,
        )
        return EnronTaskInstance(
            # Deterministic id: the same row always maps to the same UUID.
            id=uuid5(NAMESPACE_URL, f"hf://{HF_DATASET_ID}/{split}/{index}"),
            impetus=impetus,
            intent=intent,
            metadata=metadata,
            is_reproducible=True,
            initial_engine_snapshot=row,
        )

    train_instances = [_instance_from_row(row, "train", i) for i, row in enumerate(ds_train)]
    test_instances = [_instance_from_row(row, "test", i) for i, row in enumerate(ds_test)]

    split_info = SplitInfo(
        val_instance_ids=set(),
        test_instance_ids={inst.id for inst in test_instances},
        _is_split_defined=True,
    )

    return TaskInstanceSet(
        name="Enron-QA",
        description="QA over Enron email dataset sample.",
        instances=train_instances + test_instances,
        split_info=split_info,
    )
284
+
285
+
286
+ def _safe_uuid(value: Any) -> UUID:
287
+ if isinstance(value, UUID):
288
+ return value
289
+ try:
290
+ return UUID(str(value))
291
+ except Exception:
292
+ return UUID(int=0)
293
+
294
+
295
@dataclass
class EnronDataset:
    """In-memory view of the Enron QA taskset, addressable by integer seed.

    The full taskset is loaded synchronously when the dataclass is constructed.
    """

    spec: TaskDatasetSpec

    def __post_init__(self) -> None:
        """Load the taskset and cache instance ids and seed bounds."""
        self._taskset = _load_taskset_blocking()
        self.instances: list[EnronTaskInstance] = list(self._taskset.instances)
        self.instance_ids = [str(_safe_uuid(item.id)) for item in self.instances]
        self.default_seed = 0
        self.seed_min = 0
        self.seed_max = max(0, len(self.instances) - 1)

    def describe(self) -> dict[str, Any]:
        """Return the spec fields plus the instance count and up to 50 instance ids."""
        summary = self.spec.model_dump()
        summary["instance_count"] = len(self.instances)
        summary["instance_ids"] = self.instance_ids[:50]
        return summary

    def instance_by_seed(self, seed: int | None) -> EnronTaskInstance:
        """Return the instance at ``int(seed) % len(instances)``; seed ``None`` maps to index 0.

        Raises:
            ValueError: If the dataset has no instances.
        """
        count = len(self.instances)
        if count == 0:
            raise ValueError("Enron dataset is empty.")
        index = 0 if seed is None else int(seed) % count
        return self.instances[index]
322
+
323
+
324
def build_dataset() -> tuple[TaskDatasetRegistry, EnronDataset]:
    """Create a dataset registry holding one eagerly loaded Enron dataset.

    The registered factory ignores its spec argument and always returns the
    same, already loaded dataset. It is registered with ``cache=True``.

    Returns:
        ``(registry, dataset)``.
    """
    registry = TaskDatasetRegistry()
    enron_dataset = EnronDataset(DATASET_SPEC)

    def _factory(_spec: TaskDatasetSpec) -> EnronDataset:
        return enron_dataset

    registry.register(DATASET_SPEC, _factory, cache=True)
    return registry, enron_dataset
329
+
330
+
331
def _base_task_info(dataset: EnronDataset) -> TaskInfo:
    """Return the static :class:`TaskInfo` describing the Enron QA task.

    The dataset section is ``dataset.describe()`` extended with the dataset's
    ``default_seed``. All other sections are fixed literals.
    """
    dataset_info = dict(dataset.describe(), default_seed=dataset.default_seed)
    task_meta = {"id": "enron_email_qa", "name": "Enron Email QA", "version": "1.0.0"}
    action_space = {
        "type": "tool_calls",
        "tools": TOOLS,
        "description": "Tool-assisted QA workflow over an email corpus.",
    }
    observation = {
        "summary": "Text observations describing the question, tool status, and last reward.",
        "format": "text",
    }
    rubric_meta = {
        "version": "1",
        "criteria_count": 1,
        "source": "inline",
        "aggregation": "weighted_sum",
    }
    inference = {
        "supports_proxy": False,
        "endpoints": {},
        "tool": {"name": "enron_tools", "parallel_tool_calls": False},
    }
    return TaskInfo(
        task=task_meta,
        environment="enron",
        action_space=action_space,
        observation=observation,
        dataset=dataset_info,
        rubric=rubric_meta,
        inference=inference,
        limits={"max_ops": 0, "max_time_s": 900},
    )
358
+
359
+
360
def _single_criterion_rubric(goal_text: str, criterion_id: str, description: str) -> dict[str, Any]:
    """Return a version-1, weighted-sum rubric spec with one criterion of weight 1.0."""
    return {
        "version": "1",
        "goal_text": goal_text,
        "aggregation": "weighted_sum",
        "criteria": [
            {
                "id": criterion_id,
                "description": description,
                "weight": 1.0,
            }
        ],
    }


# Outcome rubric: judges whether the final answer is correct.
OUTCOME_RUBRIC = load_rubric(
    _single_criterion_rubric(
        "Provide the correct answer to the question using the Enron emails.",
        "accuracy",
        "Final answer matches the gold answer.",
    )
)

# Event rubric: judges how tools are used while exploring the corpus.
EVENTS_RUBRIC = load_rubric(
    _single_criterion_rubric(
        "Encourage efficient use of tools when exploring the corpus.",
        "tool_use",
        "Use search, read, and answer tools deliberately.",
    )
)
389
+
390
+
391
def describe_taskset(dataset: EnronDataset) -> dict[str, Any]:
    """Return the summary dict produced by ``dataset.describe()``."""
    summary = dataset.describe()
    return summary
393
+
394
+
395
def _as_plain_dict(value: Any) -> dict[str, Any]:
    """Return *value* as a dict: models via ``model_dump()``, dicts shallow-copied, anything else ``{}``."""
    if hasattr(value, "model_dump"):
        return value.model_dump()
    if isinstance(value, dict):
        return dict(value)
    return {}


def provide_task_instances(
    dataset: EnronDataset, base_info: TaskInfo, seeds: Sequence[int]
) -> Iterable[TaskInfo]:
    """Build one :class:`TaskInfo` per seed from the *base_info* template.

    Each result copies the base task/action/rubric/inference/limits sections.
    The observation is extended with the instance's question. The dataset
    section is extended with the instance id and its split/email metadata.
    ``observation`` and ``dataset`` on *base_info* may each be a pydantic model
    or a plain dict.

    Args:
        dataset: Source of instances, looked up with ``instance_by_seed``.
        base_info: Template task info.
        seeds: Seeds to materialise, in order.

    Returns:
        A list with one TaskInfo per seed.
    """
    infos: list[TaskInfo] = []
    observation_template = _as_plain_dict(getattr(base_info, "observation", None))

    for seed in seeds:
        instance = dataset.instance_by_seed(seed)
        metadata = instance.metadata
        message_ids = getattr(metadata, "message_ids", None)
        meta_dict = {
            "split": getattr(metadata, "split", None),
            "email_count": getattr(metadata, "email_count", None),
            # Copy so consumers of the TaskInfo cannot mutate the instance's list.
            "message_ids": list(message_ids) if message_ids is not None else None,
        }
        # Handle the dataset section like the observation: it may be a model
        # or a plain dict, so do not assume .model_dump() exists.
        dataset_template = _as_plain_dict(getattr(base_info, "dataset", None))
        infos.append(
            TaskInfo(
                task=base_info.task,
                environment=base_info.environment,
                action_space=base_info.action_space,
                observation={
                    **observation_template,
                    "question": instance.impetus.instructions,
                },
                dataset={
                    **dataset_template,
                    "instance_id": str(_safe_uuid(instance.id)),
                    "metadata": meta_dict,
                },
                rubric=base_info.rubric,
                inference=base_info.inference,
                limits=base_info.limits,
            )
        )
    return infos
435
+
436
+
437
+ def _ensure_dataset_from_state(fastapi_request, fallback: EnronDataset) -> EnronDataset:
438
+ if fastapi_request is None:
439
+ return fallback
440
+ dataset = getattr(getattr(fastapi_request, "app", None), "state", None)
441
+ candidate = getattr(dataset, "dataset", None)
442
+ return candidate or fallback
443
+
444
+
445
+ def _normalise_observation(value: Any) -> dict[str, Any]:
446
+ if isinstance(value, dict):
447
+ return value
448
+ if hasattr(value, "observation"):
449
+ obs = getattr(value, "observation")
450
+ if isinstance(obs, dict):
451
+ return obs
452
+ return {"text": str(obs)}
453
+ return {"text": str(value)}
454
+
455
+
456
+ async def rollout_executor(request: RolloutRequest, fastapi_request) -> RolloutResponse:
457
+ policy_cfg = dict(request.policy.config or {})
458
+ provider = str(policy_cfg.get("provider") or "").strip().lower()
459
+ if provider == "groq":
460
+ return await _rollout_with_groq(request, fastapi_request, policy_cfg)
461
+
462
+ # Fallback: return initial observation but include minimal trace payload
463
+ dataset = _ensure_dataset_from_state(fastapi_request, RUNTIME_DATASET)
464
+ env_seed = getattr(request.env, "seed", None) if request and request.env else None
465
+ instance = dataset.instance_by_seed(env_seed)
466
+ env = EnronEnvironment(task_instance=instance)
467
+ env.custom_obs = None
468
+ try:
469
+ initial_observation = await env.initialize()
470
+ finally:
471
+ with contextlib.suppress(Exception):
472
+ await env.terminate()
473
+
474
+ obs_dict = _normalise_observation(initial_observation)
475
+ step = RolloutStep(
476
+ obs=obs_dict,
477
+ tool_calls=[],
478
+ reward=0.0,
479
+ done=True,
480
+ truncated=None,
481
+ info={"note": "No rollout executed; provider unset."},
482
+ )
483
+ # No inference_url for noop policy
484
+ trajectory = RolloutTrajectory(
485
+ env_id=request.env.env_id or "enron",
486
+ policy_id=request.policy.policy_id or request.policy.policy_name or "noop-policy",
487
+ steps=[step],
488
+ final={"observation": obs_dict},
489
+ length=1,
490
+ inference_url=None, # NEW: No inference for noop policy
491
+ decision_samples=None,
492
+ )
493
+ metrics = RolloutMetrics(
494
+ episode_returns=[0.0],
495
+ mean_return=0.0,
496
+ num_steps=1,
497
+ num_episodes=1,
498
+ outcome_score=None,
499
+ events_score=None,
500
+ details={"note": "Provider not configured; returning initial state."},
501
+ )
502
+ trace_payload = _build_trace_payload_enron(
503
+ request.run_id,
504
+ request,
505
+ [step],
506
+ metrics,
507
+ provider="local",
508
+ model=policy_cfg.get("model") or "noop",
509
+ conversation=[
510
+ _conversation_message("system", ENRON_SYSTEM_PROMPT),
511
+ _conversation_message("user", _render_observation(obs_dict)),
512
+ ],
513
+ metadata={"mode": "noop"},
514
+ )
515
+ return RolloutResponse(
516
+ run_id=request.run_id,
517
+ trajectories=[trajectory],
518
+ branches={},
519
+ metrics=metrics,
520
+ aborted=False,
521
+ ops_executed=0,
522
+ trace=trace_payload,
523
+ )
524
+
525
+
526
+ def _prepare_tool_call(
527
+ tool_name: str,
528
+ raw_args: dict[str, Any],
529
+ current_obs: dict[str, Any],
530
+ ) -> EnvToolCall:
531
+ if tool_name == "search_emails":
532
+ keywords = raw_args.get("keywords")
533
+ if isinstance(keywords, str):
534
+ keywords = [k.strip() for k in keywords.split(",") if k.strip()]
535
+ if not isinstance(keywords, list) or not keywords:
536
+ raise ValueError("search_emails requires a non-empty list of keywords.")
537
+ inbox = raw_args.get("inbox") or current_obs.get("inbox_address") or "investigator@enron.com"
538
+ args = {
539
+ "inbox": str(inbox),
540
+ "keywords": [str(k) for k in keywords],
541
+ "from_addr": raw_args.get("from_addr"),
542
+ "to_addr": raw_args.get("to_addr"),
543
+ "sent_after": raw_args.get("sent_after"),
544
+ "sent_before": raw_args.get("sent_before"),
545
+ "max_results": int(raw_args.get("max_results") or 5),
546
+ }
547
+ return EnvToolCall(tool="search_emails", args=args)
548
+
549
+ if tool_name == "read_email":
550
+ message_id = raw_args.get("message_id")
551
+ if not message_id:
552
+ raise ValueError("read_email requires 'message_id'.")
553
+ return EnvToolCall(tool="read_email", args={"message_id": str(message_id)})
554
+
555
+ if tool_name == "answer_question":
556
+ answer = raw_args.get("answer")
557
+ if not isinstance(answer, str) or not answer.strip():
558
+ raise ValueError("answer_question requires a non-empty 'answer'.")
559
+ return EnvToolCall(tool="answer_question", args={"answer": answer.strip()})
560
+
561
+ if tool_name == "terminate":
562
+ return EnvToolCall(tool="terminate", args={})
563
+
564
+ raise ValueError(f"Unsupported tool '{tool_name}'")
565
+
566
+
567
+ async def _rollout_with_groq(
568
+ request: RolloutRequest,
569
+ fastapi_request,
570
+ config: dict[str, Any],
571
+ ) -> RolloutResponse:
572
+ api_key = os.getenv("GROQ_API_KEY")
573
+ if not api_key:
574
+ raise HTTPException(
575
+ status_code=503,
576
+ detail="GROQ_API_KEY environment variable is required for Groq rollouts.",
577
+ )
578
+
579
+ dataset = _ensure_dataset_from_state(fastapi_request, RUNTIME_DATASET)
580
+ env_seed = getattr(request.env, "seed", None) if request and request.env else None
581
+ instance = dataset.instance_by_seed(env_seed)
582
+ env = EnronEnvironment(task_instance=instance)
583
+ env.custom_obs = None
584
+
585
+ metadata_extra = {
586
+ "split": getattr(instance.metadata, "split", None),
587
+ "email_count": getattr(instance.metadata, "email_count", None),
588
+ "message_ids": list(getattr(instance.metadata, "message_ids", []))[:10],
589
+ }
590
+
591
+ model = config.get("model") or DEFAULT_GROQ_MODEL
592
+ temperature = float(config.get("temperature", 0.2) or 0.2)
593
+ top_p = float(config.get("top_p", 0.8) or 0.8)
594
+ max_tokens = int(config.get("max_tokens", 768) or 768)
595
+ max_turns = int(config.get("max_turns", config.get("max_steps", 12)) or 12)
596
+
597
+ tool_schemas = [
598
+ {
599
+ "type": "function",
600
+ "function": {
601
+ "name": "search_emails",
602
+ "description": "Search the Enron corpus for emails matching keywords.",
603
+ "parameters": {
604
+ "type": "object",
605
+ "properties": {
606
+ "inbox": {"type": "string", "description": "Email address performing the search."},
607
+ "keywords": {
608
+ "type": "array",
609
+ "items": {"type": "string"},
610
+ "minItems": 1,
611
+ "description": "Keywords to include in the search.",
612
+ },
613
+ "from_addr": {"type": "string"},
614
+ "to_addr": {"type": "string"},
615
+ "sent_after": {"type": "string", "description": "YYYY-MM-DD"},
616
+ "sent_before": {"type": "string", "description": "YYYY-MM-DD"},
617
+ "max_results": {"type": "integer", "minimum": 1, "maximum": 10},
618
+ },
619
+ "required": ["keywords"],
620
+ "additionalProperties": False,
621
+ },
622
+ },
623
+ },
624
+ {
625
+ "type": "function",
626
+ "function": {
627
+ "name": "read_email",
628
+ "description": "Read the full contents of an email by message_id.",
629
+ "parameters": {
630
+ "type": "object",
631
+ "properties": {"message_id": {"type": "string"}},
632
+ "required": ["message_id"],
633
+ "additionalProperties": False,
634
+ },
635
+ },
636
+ },
637
+ {
638
+ "type": "function",
639
+ "function": {
640
+ "name": "answer_question",
641
+ "description": "Submit the final answer to the investigation question.",
642
+ "parameters": {
643
+ "type": "object",
644
+ "properties": {"answer": {"type": "string"}},
645
+ "required": ["answer"],
646
+ "additionalProperties": False,
647
+ },
648
+ },
649
+ },
650
+ {
651
+ "type": "function",
652
+ "function": {
653
+ "name": "terminate",
654
+ "description": "Terminate the investigation without answering.",
655
+ "parameters": {"type": "object", "properties": {}, "additionalProperties": False},
656
+ },
657
+ },
658
+ ]
659
+
660
+ steps: list[RolloutStep] = []
661
+ conversation: list[dict[str, Any]] = []
662
+ executed = 0
663
+ try:
664
+ observation = await env.initialize()
665
+ obs_dict = _normalise_observation(observation)
666
+ conversation.append(_conversation_message("system", ENRON_SYSTEM_PROMPT))
667
+ conversation.append(_conversation_message("user", _render_observation(obs_dict)))
668
+
669
+ async with httpx.AsyncClient(timeout=httpx.Timeout(60.0)) as client:
670
+ for turn in range(max_turns):
671
+ payload = {
672
+ "model": model,
673
+ "messages": conversation,
674
+ "temperature": temperature,
675
+ "top_p": top_p,
676
+ "max_tokens": max_tokens,
677
+ "tools": tool_schemas,
678
+ "tool_choice": "auto",
679
+ }
680
+ vendor_attempts: list[dict[str, Any]] = []
681
+ response, response_meta = await _call_groq_chat(client, api_key, payload)
682
+ vendor_attempts.append({"request": payload, "response": response_meta})
683
+
684
+ choices = response.get("choices") or []
685
+ if not choices:
686
+ break
687
+ message = choices[0].get("message") or {}
688
+ tool_calls = message.get("tool_calls") or []
689
+ assistant_msg_meta = {"tool_calls": _simplify(tool_calls)} if tool_calls else {}
690
+ conversation.append(
691
+ _conversation_message("assistant", message.get("content") or "", **assistant_msg_meta)
692
+ )
693
+
694
+ tool_call_records: list[dict[str, Any]] = []
695
+ step_reward = 0.0
696
+ done = False
697
+ truncated = False
698
+
699
+ if not tool_calls:
700
+ final_answer = (message.get("content") or "").strip()
701
+ if final_answer:
702
+ env_call = EnvToolCall(tool="answer_question", args={"answer": final_answer})
703
+ observation = await env.step(env_call)
704
+ executed += 1
705
+ obs_dict = _normalise_observation(observation)
706
+ step_reward += float(obs_dict.get("reward_last") or 0.0)
707
+ done = bool(obs_dict.get("terminated"))
708
+ truncated = bool(obs_dict.get("truncated"))
709
+ tool_call_records.append({"tool": "answer_question", "args": env_call.args})
710
+ conversation.append(
711
+ _conversation_message(
712
+ "tool",
713
+ {"result": "answer_submitted", "observation": obs_dict},
714
+ name="answer_question",
715
+ )
716
+ )
717
+ else:
718
+ break
719
+ else:
720
+ for call in tool_calls:
721
+ func = call.get("function") or {}
722
+ name = func.get("name")
723
+ raw_args = func.get("arguments")
724
+ if isinstance(raw_args, str):
725
+ try:
726
+ parsed_args = json.loads(raw_args)
727
+ except json.JSONDecodeError:
728
+ parsed_args = {}
729
+ elif isinstance(raw_args, dict):
730
+ parsed_args = raw_args
731
+ else:
732
+ parsed_args = {}
733
+
734
+ env_call = _prepare_tool_call(name, parsed_args, obs_dict)
735
+ observation = await env.step(env_call)
736
+ executed += 1
737
+ obs_dict = _normalise_observation(observation)
738
+ reward_delta = float(obs_dict.get("reward_last") or 0.0)
739
+ step_reward += reward_delta
740
+ done = bool(obs_dict.get("terminated"))
741
+ truncated = bool(obs_dict.get("truncated"))
742
+ tool_call_records.append({"tool": env_call.tool, "args": env_call.args})
743
+ conversation.append(
744
+ _conversation_message(
745
+ "tool",
746
+ {
747
+ "tool": env_call.tool,
748
+ "args": env_call.args,
749
+ "reward_delta": reward_delta,
750
+ "observation": obs_dict,
751
+ },
752
+ name=env_call.tool,
753
+ tool_call_id=call.get("id"),
754
+ )
755
+ )
756
+ if done or truncated:
757
+ break
758
+
759
+ conversation.append(_conversation_message("user", _render_observation(obs_dict)))
760
+
761
+ step = RolloutStep(
762
+ obs=obs_dict,
763
+ tool_calls=tool_call_records,
764
+ reward=step_reward,
765
+ done=done,
766
+ truncated=truncated if truncated else None,
767
+ info={
768
+ "provider": "groq",
769
+ "model": model,
770
+ "vendor_attempts": vendor_attempts,
771
+ "turn": turn,
772
+ },
773
+ )
774
+ steps.append(step)
775
+
776
+ if done or truncated:
777
+ break
778
+ finally:
779
+ with contextlib.suppress(Exception):
780
+ await env.terminate()
781
+
782
+ if steps:
783
+ final_obs = steps[-1].obs
784
+ total_reward = float(final_obs.get("total_reward") or 0.0)
785
+ else:
786
+ total_reward = 0.0
787
+
788
+ metrics = RolloutMetrics(
789
+ episode_returns=[total_reward],
790
+ mean_return=total_reward if steps else 0.0,
791
+ num_steps=len(steps),
792
+ num_episodes=1,
793
+ outcome_score=None,
794
+ events_score=None,
795
+ details={"provider": "groq", "model": model},
796
+ )
797
+ inference_url_groq = "https://api.groq.com/openai/v1/chat/completions"
798
+
799
+ trajectory = RolloutTrajectory(
800
+ env_id=request.env.env_id or "enron",
801
+ policy_id=request.policy.policy_id or request.policy.policy_name or "enron-groq",
802
+ steps=steps,
803
+ final={"observation": steps[-1].obs if steps else {}},
804
+ length=len(steps),
805
+ inference_url=inference_url_groq, # NEW: Required for trace correlation
806
+ decision_samples=None,
807
+ )
808
+ trace_payload = _build_trace_payload_enron(
809
+ request.run_id,
810
+ request,
811
+ steps,
812
+ metrics,
813
+ provider="groq",
814
+ model=model,
815
+ conversation=conversation,
816
+ metadata=metadata_extra,
817
+ )
818
+ return RolloutResponse(
819
+ run_id=request.run_id,
820
+ trajectories=[trajectory],
821
+ branches={},
822
+ metrics=metrics,
823
+ aborted=False,
824
+ ops_executed=executed,
825
+ trace=trace_payload,
826
+ )
827
+
828
+
829
+ RUNTIME_DATASET: EnronDataset
830
+ registry, RUNTIME_DATASET = build_dataset()
831
+ BASE_INFO = _base_task_info(RUNTIME_DATASET)
832
+
833
+
834
+ def build_config() -> TaskAppConfig:
835
+ tracing_enabled = tracing_env_enabled()
836
+ tracing_db_url = resolve_tracing_db_url()
837
+ tracer_factory = build_tracer_factory(
838
+ SessionTracer, enabled=tracing_enabled, db_url=tracing_db_url
839
+ )
840
+ sft_output_dir = resolve_sft_output_dir()
841
+
842
+ app_state: dict[str, Any] = {
843
+ "dataset": RUNTIME_DATASET,
844
+ "allowed_environments": ["enron"],
845
+ "tracing_enabled": tracing_enabled,
846
+ }
847
+ if tracer_factory is not None:
848
+ app_state["session_tracer_factory"] = tracer_factory
849
+ if sft_output_dir:
850
+ app_state["sft_output_dir"] = sft_output_dir
851
+
852
+ if tracing_enabled:
853
+ logger.info("[enron:tracing] enabled (db=%s)", tracing_db_url or "default")
854
+ else:
855
+ logger.info("[enron:tracing] disabled")
856
+ if sft_output_dir:
857
+ logger.info("[enron:sft] writing JSONL to %s", sft_output_dir)
858
+
859
+ config = TaskAppConfig(
860
+ app_id="grpo-enron",
861
+ name="GRPO Enron Email QA Task App",
862
+ description="Tool-assisted QA environment over Enron emails with GRPO-compatible endpoints.",
863
+ base_task_info=BASE_INFO,
864
+ describe_taskset=lambda: describe_taskset(RUNTIME_DATASET),
865
+ provide_task_instances=lambda seeds: provide_task_instances(RUNTIME_DATASET, BASE_INFO, seeds),
866
+ rollout=rollout_executor,
867
+ dataset_registry=registry,
868
+ rubrics=RubricBundle(outcome=OUTCOME_RUBRIC, events=EVENTS_RUBRIC),
869
+ proxy=ProxyConfig(enable_openai=False, enable_groq=False),
870
+ routers=(),
871
+ app_state=app_state,
872
+ cors_origins=["*"],
873
+ )
874
+ return config
875
+
876
+
877
+ register_task_app(
878
+ entry=TaskAppEntry(
879
+ app_id="grpo-enron",
880
+ description="Enron email QA task app with rollout metadata endpoints.",
881
+ config_factory=build_config,
882
+ aliases=("enron", "enron-task"),
883
+ env_files=(str(REPO_ROOT / "backend" / ".env.dev"),),
884
+ modal=ModalDeploymentConfig(
885
+ app_name="grpo-enron-task-app",
886
+ python_version="3.11",
887
+ pip_packages=(
888
+ "fastapi>=0.100.0",
889
+ "uvicorn>=0.23.0",
890
+ "pydantic>=2.0.0",
891
+ "httpx>=0.24.0",
892
+ "python-dotenv>=1.0.1",
893
+ "datasets>=2.10.0",
894
+ ),
895
+ extra_local_dirs=(
896
+ (str(REPO_ROOT), "/opt/synth_ai_repo"),
897
+ (str(REPO_ROOT / "synth_ai"), "/opt/synth_ai_repo/synth_ai"),
898
+ (str(_HERE.parent), "/opt/synth_ai_repo/examples/task_apps/enron/task_app"),
899
+ ),
900
+ secret_names=("groq-api-key", "openai-api-key"),
901
+ memory=8192,
902
+ cpu=2.0,
903
+ max_containers=4,
904
+ ),
905
+ )
906
+ )