synth-ai 0.2.13.dev1__py3-none-any.whl → 0.2.13.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic.

Files changed (226)
  1. examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +12 -1
  2. examples/swe/task_app/grpo_swe_mini.py +55 -26
  3. examples/swe/task_app/hosted/rollout.py +40 -0
  4. examples/swe/task_app/hosted/test_service.py +5 -6
  5. examples/task_apps/TESTING.md +275 -0
  6. examples/task_apps/__init__.py +0 -0
  7. examples/task_apps/crafter/__init__.py +0 -0
  8. examples/task_apps/crafter/task_app/__init__.py +2 -0
  9. examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter.py +18 -13
  10. examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter_task_app.py +1 -1
  11. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/policy.py +60 -4
  12. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/policy_routes.py +25 -3
  13. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/rollout.py +10 -0
  14. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_service.py +5 -6
  15. examples/task_apps/dev/pokemon_emerald/__init__.py +2 -0
  16. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/README.md +811 -0
  17. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/__init__.py +120 -0
  18. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/action.py +160 -0
  19. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/memory.py +155 -0
  20. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/perception.py +69 -0
  21. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/planning.py +96 -0
  22. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/simple.py +1502 -0
  23. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/system_prompt.py +4 -0
  24. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/grab_map.py +68 -0
  25. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/manual.py +216 -0
  26. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/__init__.py +35 -0
  27. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emerald_utils.py +631 -0
  28. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emulator.py +1544 -0
  29. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/enums.py +1428 -0
  30. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/memory_reader.py +4848 -0
  31. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/types.py +41 -0
  32. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/utils.py +298 -0
  33. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pyproject.toml +95 -0
  34. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/run.py +204 -0
  35. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/__init__.py +0 -0
  36. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/app.py +2152 -0
  37. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/client.py +429 -0
  38. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/frame_server.py +155 -0
  39. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/README.md +78 -0
  40. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/__init__.py +0 -0
  41. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/run_tests.py +122 -0
  42. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_direct.py +76 -0
  43. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_prompts.py +413 -0
  44. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_battle_state_formatting.py +204 -0
  45. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection.py +133 -0
  46. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection_comprehensive.py +229 -0
  47. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_direct_agent_emulator.py +300 -0
  48. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_fps_adjustment_pytest.py +205 -0
  49. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_direct.py +200 -0
  50. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_transition.py +284 -0
  51. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_map_ground_truth_comparison.py +468 -0
  52. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_memory_map.py +575 -0
  53. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_server_map_validation.py +311 -0
  54. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_torchic_state.py +259 -0
  55. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/__init__.py +0 -0
  56. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/anticheat.py +372 -0
  57. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/checkpoint.py +296 -0
  58. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/error_handler.py +275 -0
  59. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/get_local_ip.py +22 -0
  60. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/helpers.py +44 -0
  61. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/llm_logger.py +514 -0
  62. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_formatter.py +415 -0
  63. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher.py +1763 -0
  64. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher_singleton.py +33 -0
  65. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_trimmer.py +106 -0
  66. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_visualizer.py +334 -0
  67. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/ocr_dialogue.py +1020 -0
  68. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/recording.py +188 -0
  69. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/state_formatter.py +1481 -0
  70. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/vlm.py +862 -0
  71. examples/task_apps/dev/pokemon_emerald/modal_app.py +114 -0
  72. examples/task_apps/dev/pokemon_emerald/task_app/README.md +81 -0
  73. examples/task_apps/dev/pokemon_emerald/task_app/__init__.py +6 -0
  74. examples/task_apps/dev/pokemon_emerald/task_app/pokemon_emerald.py +685 -0
  75. examples/task_apps/enron/__init__.py +1 -0
  76. examples/task_apps/enron/eval_groq_qwen32.toml +16 -0
  77. examples/task_apps/enron/task_app/README.md +14 -0
  78. examples/task_apps/enron/task_app/__init__.py +1 -0
  79. examples/task_apps/enron/task_app/grpo_enron.py +906 -0
  80. examples/task_apps/enron/task_app/grpo_enron_task_app.py +146 -0
  81. examples/task_apps/enron/tests/__init__.py +2 -0
  82. examples/task_apps/enron/tests/conftest.py +115 -0
  83. examples/task_apps/enron/tests/integration/__init__.py +2 -0
  84. examples/task_apps/enron/tests/integration/test_enron_eval.py +177 -0
  85. examples/task_apps/enron/tests/integration/test_enron_rollout.py +135 -0
  86. examples/task_apps/enron/tests/unit/__init__.py +2 -0
  87. examples/task_apps/enron/tests/unit/test_enron_environment.py +126 -0
  88. examples/task_apps/math/__init__.py +0 -0
  89. examples/{rl/task_app → task_apps/math}/math_single_step.py +19 -10
  90. examples/task_apps/pokemon_battle/__init__.py +2 -0
  91. examples/task_apps/pokemon_battle/modal_app.py +104 -0
  92. examples/task_apps/pokemon_battle/task_app/README.md +68 -0
  93. examples/task_apps/pokemon_battle/task_app/__init__.py +6 -0
  94. examples/task_apps/pokemon_battle/task_app/pokemon_showdown.py +932 -0
  95. examples/task_apps/pokemon_red/README.md +357 -0
  96. examples/task_apps/pokemon_red/__init__.py +3 -0
  97. examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +225 -0
  98. examples/task_apps/pokemon_red/pallet_town_rl_config.toml +73 -0
  99. examples/task_apps/pokemon_red/task_app.py +606 -0
  100. examples/task_apps/pokemon_red/test_pallet_town_rewards.py +191 -0
  101. examples/task_apps/sokoban/README.md +307 -0
  102. examples/task_apps/sokoban/__init__.py +3 -0
  103. examples/task_apps/sokoban/eval_groq_qwen32.toml +16 -0
  104. examples/task_apps/sokoban/eval_openai_gpt5.toml +16 -0
  105. examples/task_apps/sokoban/task_app.py +1058 -0
  106. examples/task_apps/sokoban/tests/__init__.py +2 -0
  107. examples/task_apps/sokoban/tests/conftest.py +113 -0
  108. examples/task_apps/sokoban/tests/integration/__init__.py +2 -0
  109. examples/task_apps/sokoban/tests/integration/test_sokoban_eval.py +57 -0
  110. examples/task_apps/sokoban/tests/integration/test_sokoban_rollout.py +198 -0
  111. examples/task_apps/sokoban/tests/unit/__init__.py +2 -0
  112. examples/task_apps/sokoban/tests/unit/test_sokoban_environment.py +114 -0
  113. examples/task_apps/verilog/__init__.py +1 -0
  114. examples/task_apps/verilog/eval_groq_qwen32b.toml +20 -0
  115. examples/task_apps/verilog/task_app/README.md +12 -0
  116. examples/task_apps/verilog/task_app/__init__.py +1 -0
  117. examples/task_apps/verilog/task_app/grpo_verilog.py +931 -0
  118. examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +145 -0
  119. examples/task_apps/verilog/tests/__init__.py +2 -0
  120. examples/task_apps/verilog/tests/conftest.py +115 -0
  121. examples/task_apps/verilog/tests/integration/__init__.py +2 -0
  122. examples/task_apps/verilog/tests/integration/test_verilog_eval.py +179 -0
  123. examples/task_apps/verilog/tests/integration/test_verilog_rollout.py +55 -0
  124. examples/task_apps/verilog/tests/unit/__init__.py +2 -0
  125. examples/task_apps/verilog/tests/unit/test_verilog_scoring.py +118 -0
  126. examples/vlm/crafter_openai_vlm_agent.py +4 -4
  127. examples/vlm/run_crafter_vlm_benchmark.py +4 -4
  128. examples/workflows/__init__.py +0 -0
  129. examples/workflows/math_rl/__init__.py +0 -0
  130. examples/workflows/math_rl/download_dataset.py +80 -0
  131. synth_ai/__init__.py +2 -2
  132. synth_ai/api/train/builders.py +25 -11
  133. synth_ai/api/train/cli.py +12 -6
  134. synth_ai/api/train/configs/__init__.py +10 -10
  135. synth_ai/api/train/configs/rl.py +5 -4
  136. synth_ai/api/train/configs/sft.py +4 -3
  137. synth_ai/api/train/env_resolver.py +5 -2
  138. synth_ai/api/train/supported_algos.py +10 -5
  139. synth_ai/api/train/utils.py +7 -4
  140. synth_ai/cli/__init__.py +7 -51
  141. synth_ai/cli/_storage.py +4 -3
  142. synth_ai/cli/_validate_task_app.py +11 -0
  143. synth_ai/cli/balance.py +4 -3
  144. synth_ai/cli/calc.py +2 -2
  145. synth_ai/cli/demo.py +14 -7
  146. synth_ai/cli/legacy_root_backup.py +1 -1
  147. synth_ai/cli/rl_demo.py +8 -7
  148. synth_ai/cli/root.py +0 -97
  149. synth_ai/cli/task_apps.py +1707 -186
  150. synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +28 -16
  151. synth_ai/environments/examples/enron/engine.py +7 -2
  152. synth_ai/environments/examples/enron/environment.py +68 -0
  153. synth_ai/environments/examples/red/engine.py +27 -0
  154. synth_ai/environments/examples/red/engine_helpers/memory_map.py +7 -0
  155. synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_progression.py +477 -0
  156. synth_ai/environments/examples/red/engine_helpers/state_extraction.py +32 -0
  157. synth_ai/environments/examples/red/environment.py +60 -0
  158. synth_ai/environments/examples/sokoban/taskset.py +116 -0
  159. synth_ai/environments/examples/verilog/engine.py +30 -4
  160. synth_ai/evals/client.py +58 -61
  161. synth_ai/jobs/client.py +16 -4
  162. synth_ai/judge_schemas.py +16 -16
  163. synth_ai/py.typed +0 -0
  164. synth_ai/task/__init__.py +14 -5
  165. synth_ai/task/contracts.py +124 -38
  166. synth_ai/task/proxy.py +48 -56
  167. synth_ai/task/rubrics/__init__.py +53 -0
  168. synth_ai/task/rubrics/loaders.py +133 -0
  169. synth_ai/task/rubrics/models.py +57 -0
  170. synth_ai/task/rubrics/scoring.py +113 -0
  171. synth_ai/{rubrics/validators.py → task/rubrics/strict.py} +53 -30
  172. synth_ai/task/server.py +8 -7
  173. synth_ai/task/validators.py +269 -6
  174. synth_ai/tracing_v3/decorators.py +7 -3
  175. synth_ai/tracing_v3/replica_sync.py +4 -4
  176. synth_ai/tracing_v3/serialization.py +5 -5
  177. synth_ai/tracing_v3/trace_utils.py +317 -0
  178. synth_ai/tracing_v3/turso/native_manager.py +3 -3
  179. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/METADATA +4 -1
  180. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/RECORD +214 -101
  181. examples/agora_ex/README_MoE.md +0 -224
  182. examples/agora_ex/__init__.py +0 -7
  183. examples/agora_ex/agora_ex.py +0 -65
  184. examples/agora_ex/agora_ex_task_app.py +0 -590
  185. examples/agora_ex/configs/rl_lora_qwen3_moe_2xh200.toml +0 -121
  186. examples/agora_ex/reward_fn_grpo-human.py +0 -129
  187. examples/agora_ex/system_prompt_CURRENT.md +0 -63
  188. examples/agora_ex/task_app/agora_ex_task_app.py +0 -590
  189. examples/agora_ex/task_app/reward_fn_grpo-human.py +0 -129
  190. examples/agora_ex/task_app/system_prompt_CURRENT.md +0 -63
  191. synth_ai/rubrics/__init__.py +0 -22
  192. synth_ai/task/rubrics.py +0 -219
  193. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/README.md +0 -0
  194. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/README.md +0 -0
  195. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/__init__.py +0 -0
  196. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/branching.py +0 -0
  197. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/environment_routes.py +0 -0
  198. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/__init__.py +0 -0
  199. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/__init__.py +0 -0
  200. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/app.py +0 -0
  201. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/environment.py +0 -0
  202. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/react_agent.py +0 -0
  203. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/shared.py +0 -0
  204. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/tools.py +0 -0
  205. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/hosted_app.py +0 -0
  206. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/__init__.py +0 -0
  207. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/openai_client.py +0 -0
  208. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/main.py +0 -0
  209. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/registry.py +0 -0
  210. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/__init__.py +0 -0
  211. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/volume.py +0 -0
  212. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_agents.py +0 -0
  213. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/utils.py +0 -0
  214. /examples/{rl/task_app → task_apps/math}/README.md +0 -0
  215. /examples/{rl/task_app → task_apps/math}/math_task_app.py +0 -0
  216. /examples/{rl → workflows/math_rl}/configs/eval_base_qwen.toml +0 -0
  217. /examples/{rl → workflows/math_rl}/configs/eval_rl_qwen.toml +0 -0
  218. /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen.toml +0 -0
  219. /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen17.toml +0 -0
  220. /examples/{rl → workflows/math_rl}/configs/rl_from_ft_qwen.toml +0 -0
  221. /examples/{rl → workflows/math_rl}/run_eval.py +0 -0
  222. /examples/{rl → workflows/math_rl}/run_rl_and_save.py +0 -0
  223. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/WHEEL +0 -0
  224. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/entry_points.txt +0 -0
  225. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/licenses/LICENSE +0 -0
  226. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/top_level.txt +0 -0
synth_ai/cli/task_apps.py CHANGED
@@ -1,23 +1,30 @@
 from __future__ import annotations
 
+import argparse
 import ast
 import asyncio
 import contextlib
+import functools
 import hashlib
 import importlib
 import importlib.util
 import inspect
 import json
 import os
+import shlex
 import shutil
 import signal
+import sqlite3
 import subprocess
 import sys
 import tempfile
 import textwrap
+import time
 import types
+import uuid
 from collections.abc import Callable, Iterable, Iterator, Sequence
 from dataclasses import dataclass
+from datetime import UTC, datetime
 from pathlib import Path
 from typing import Any, cast
 
@@ -25,24 +32,39 @@ try:  # Python 3.11+
     import tomllib as _toml
 except Exception:  # pragma: no cover - fallback
     _toml = None  # type: ignore
-import uuid
 
 import click
 from click.exceptions import Abort
 
+from synth_ai.tracing_v3 import (  # type: ignore[import-untyped]
+    BaseEvent,
+    EnvironmentEvent,
+    RuntimeEvent,
+    SessionEventMarkovBlanketMessage,
+    SessionMessageContent,
+    SessionTimeStep,
+    SessionTracer,
+    TimeRecord,
+)
+from synth_ai.tracing_v3 import (  # type: ignore[import-untyped]
+    SessionTrace as V3SessionTrace,
+)
+
 # ---------------------------------------------------------------------------
 # Dynamic imports to avoid hard dependencies during type checking.
 # ---------------------------------------------------------------------------
 ModalDeploymentConfigType = TaskAppConfigType = TaskAppEntryType = Any
 
 try:  # Resolve base URL defaults lazily
-    _config_module = importlib.import_module("synth_ai.config.base_url")
+    _config_module = cast(
+        Any, importlib.import_module("synth_ai.config.base_url")
+    )
     PROD_BASE_URL_DEFAULT = cast(str, _config_module.PROD_BASE_URL_DEFAULT)
 except Exception:  # pragma: no cover - fallback
     PROD_BASE_URL_DEFAULT = "https://agent-learning.onrender.com"
 
 try:
-    _task_apps_module = importlib.import_module("synth_ai.task.apps")
+    _task_apps_module = cast(Any, importlib.import_module("synth_ai.task.apps"))
     ModalDeploymentConfig = cast(
         type[ModalDeploymentConfigType], _task_apps_module.ModalDeploymentConfig
     )
@@ -53,9 +75,9 @@ except Exception as exc:  # pragma: no cover - critical dependency
     raise RuntimeError("Unable to load task app registry") from exc
 
 try:
-    _task_server_module = importlib.import_module("synth_ai.task.server")
-    create_task_app = _task_server_module.create_task_app
-    run_task_app = _task_server_module.run_task_app
+    _task_server_module = cast(Any, importlib.import_module("synth_ai.task.server"))
+    create_task_app = cast(Callable[..., Any], _task_server_module.create_task_app)
+    run_task_app = cast(Callable[..., Any], _task_server_module.run_task_app)
 except Exception as exc:  # pragma: no cover - critical dependency
     raise RuntimeError("Unable to load task app server utilities") from exc
 
@@ -64,10 +86,12 @@ def _load_demo_directory() -> Path | None:
     """Return the demo task apps directory if available."""
 
     try:
-        module = importlib.import_module("synth_ai.demos.demo_task_apps.core")
-        loader = module.load_demo_dir
+        module = cast(
+            Any, importlib.import_module("synth_ai.demos.demo_task_apps.core")
+        )
+        loader = cast(Callable[[], str | Path | None], module.load_demo_dir)
         demo_dir = loader()
-        if isinstance(demo_dir, (str, Path)):
+        if isinstance(demo_dir, str | Path):
             demo_path = Path(demo_dir)
             if demo_path.exists():
                 return demo_path.resolve()
@@ -105,6 +129,25 @@ DEFAULT_SEARCH_RELATIVE = (
 )
 
 
+def _pearson(xs: Sequence[float], ys: Sequence[float]) -> float | None:
+    if len(xs) != len(ys) or len(xs) < 2:
+        return None
+    mean_x = sum(xs) / len(xs)
+    mean_y = sum(ys) / len(ys)
+    num = 0.0
+    denom_x = 0.0
+    denom_y = 0.0
+    for x, y in zip(xs, ys, strict=False):
+        dx = x - mean_x
+        dy = y - mean_y
+        num += dx * dy
+        denom_x += dx * dx
+        denom_y += dy * dy
+    if denom_x <= 0 or denom_y <= 0:
+        return None
+    return num / (denom_x ** 0.5 * denom_y ** 0.5)
+
+
 @dataclass
 class AppChoice:
     app_id: str
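The `_pearson` helper added above computes a plain Pearson correlation over two equal-length sequences and returns None for degenerate inputs (fewer than two points or zero variance). A minimal usage sketch with made-up numbers, purely for illustration:

    # Hypothetical data: correlate judge scores with episode returns.
    judge_scores = [0.2, 0.4, 0.9, 0.7]
    episode_returns = [1.0, 2.0, 5.0, 4.0]
    r = _pearson(judge_scores, episode_returns)
    if r is not None:
        print(f"pearson = {r:.3f}")  # close to 1.0 for this toy data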
@@ -128,6 +171,171 @@ class AppChoice:
         return entry
 
 
+@dataclass
+class JudgeSpec:
+    name: str
+    fn: Callable[..., Any]
+    kwargs: dict[str, Any]
+
+
+def _parse_datetime_for_trace(value: Any) -> datetime | None:
+    if isinstance(value, datetime):
+        return value if value.tzinfo else value.replace(tzinfo=UTC)
+    if isinstance(value, str):
+        value = value.replace("Z", "+00:00")
+        try:
+            dt = datetime.fromisoformat(value)
+        except ValueError:
+            try:
+                dt = datetime.fromtimestamp(float(value), tz=UTC)
+            except Exception:
+                return None
+        return dt if dt.tzinfo else dt.replace(tzinfo=UTC)
+    if isinstance(value, int | float):
+        return datetime.fromtimestamp(float(value), tz=UTC)
+    return None
+
+
+def _time_record_from_dict(payload: dict[str, Any] | None) -> TimeRecord:
+    payload = payload or {}
+    event_time = payload.get("event_time")
+    if not isinstance(event_time, int | float):
+        try:
+            event_time = float(event_time)
+        except Exception:
+            event_time = float(time.time())
+    message_time = payload.get("message_time")
+    if message_time is not None:
+        try:
+            message_time = int(message_time)
+        except Exception:
+            message_time = None
+    return TimeRecord(event_time=event_time, message_time=message_time)
+
+
+def _event_from_dict(payload: dict[str, Any]) -> BaseEvent:
+    base_kwargs = {
+        "system_instance_id": payload.get("system_instance_id", ""),
+        "time_record": _time_record_from_dict(payload.get("time_record")),
+        "metadata": payload.get("metadata") or {},
+        "event_metadata": payload.get("event_metadata"),
+    }
+    if "actions" in payload:
+        return RuntimeEvent(actions=payload.get("actions") or [], **base_kwargs)
+    if any(key in payload for key in ("reward", "terminated", "truncated")):
+        return EnvironmentEvent(
+            reward=float(payload.get("reward", 0.0) or 0.0),
+            terminated=bool(payload.get("terminated", False)),
+            truncated=bool(payload.get("truncated", False)),
+            system_state_before=payload.get("system_state_before"),
+            system_state_after=payload.get("system_state_after"),
+            **base_kwargs,
+        )
+    return BaseEvent(**base_kwargs)
+
+
+def _markov_message_from_dict(payload: dict[str, Any]) -> SessionEventMarkovBlanketMessage:
+    content_payload = payload.get("content") or {}
+    content = SessionMessageContent(
+        text=content_payload.get("text"),
+        json_payload=content_payload.get("json_payload"),
+    )
+    raw_type = (payload.get("message_type") or "").lower()
+    if raw_type == "observation":
+        normalized_type = "system"
+    elif raw_type == "action":
+        normalized_type = "assistant"
+    elif raw_type in {"user", "assistant", "system", "tool_use", "tool_result"}:
+        normalized_type = raw_type
+    else:
+        normalized_type = "system"
+
+    return SessionEventMarkovBlanketMessage(
+        content=content,
+        message_type=normalized_type,
+        time_record=_time_record_from_dict(payload.get("time_record")),
+        metadata=payload.get("metadata") or {},
+    )
+
+
+def _step_from_dict(payload: dict[str, Any]) -> SessionTimeStep:
+    events = [
+        _event_from_dict(event)
+        for event in payload.get("events", [])
+        if isinstance(event, dict)
+    ]
+    messages = [
+        _markov_message_from_dict(msg)
+        for msg in payload.get("markov_blanket_messages", [])
+        if isinstance(msg, dict)
+    ]
+    timestamp = _parse_datetime_for_trace(payload.get("timestamp")) or datetime.now(UTC)
+    completed_at = _parse_datetime_for_trace(payload.get("completed_at"))
+    return SessionTimeStep(
+        step_id=payload.get("step_id", ""),
+        step_index=int(payload.get("step_index", 0) or 0),
+        timestamp=timestamp,
+        turn_number=payload.get("turn_number"),
+        events=events,
+        markov_blanket_messages=messages,
+        step_metadata=payload.get("step_metadata") or {},
+        completed_at=completed_at,
+    )
+
+
+def _session_trace_from_dict(payload: dict[str, Any]) -> V3SessionTrace | None:
+    if not isinstance(payload, dict):
+        return None
+    steps = [
+        _step_from_dict(step)
+        for step in payload.get("session_time_steps", [])
+        if isinstance(step, dict)
+    ]
+    events = [
+        _event_from_dict(event)
+        for event in payload.get("event_history", [])
+        if isinstance(event, dict)
+    ]
+    markov_history = [
+        _markov_message_from_dict(msg)
+        for msg in payload.get("markov_blanket_message_history", [])
+        if isinstance(msg, dict)
+    ]
+    created_at = _parse_datetime_for_trace(payload.get("created_at")) or datetime.now(UTC)
+    metadata = payload.get("metadata") or {}
+    session_metadata = payload.get("session_metadata")
+    return V3SessionTrace(
+        session_id=payload.get("session_id", ""),
+        created_at=created_at,
+        session_time_steps=steps,
+        event_history=events,
+        markov_blanket_message_history=markov_history,
+        metadata=metadata,
+        session_metadata=session_metadata,
+    )
+
+
+async def _store_trace(
+    tracer: SessionTracer | None,
+    trace_namespace: dict[str, Any] | None,
+    extra_metadata: dict[str, Any] | None = None,
+):
+    if tracer is None or not isinstance(trace_namespace, dict):
+        return
+    session_payload = trace_namespace.get("session_trace")
+    if not isinstance(session_payload, dict):
+        return
+    trace_obj = _session_trace_from_dict(session_payload)
+    if trace_obj is None:
+        return
+    if tracer.db is None:
+        await tracer.initialize()
+    meta = dict(trace_obj.metadata or {})
+    if extra_metadata:
+        meta.update(extra_metadata)
+    trace_obj.metadata = meta
+    await tracer.db.insert_session_trace(trace_obj)
+
 def _temporary_sys_path(paths: Sequence[Path]):
     """Context manager to prepend entries to sys.path temporarily."""
 
@@ -676,36 +884,44 @@ def _build_modal_config_from_ast(modal_call: ast.Call) -> ModalDeploymentConfigT
         elif kw.arg == "pip_packages" and isinstance(kw.value, (ast.List, ast.Tuple)):
             # Handle pip_packages list/tuple
             packages: list[str] = []
-            for elt in kw.value.elts:
-                if isinstance(elt, ast.Constant):
-                    packages.append(elt.value)
+            value_node = kw.value
+            if isinstance(value_node, (ast.List, ast.Tuple)):
+                for elt in value_node.elts:
+                    if isinstance(elt, ast.Constant):
+                        packages.append(elt.value)
             kwargs[kw.arg] = tuple(packages)
         elif kw.arg == "extra_local_dirs" and isinstance(kw.value, (ast.List, ast.Tuple)):
             # Handle extra_local_dirs list/tuple of tuples
             dirs = []
-            for elt in kw.value.elts:
-                if isinstance(elt, (ast.List, ast.Tuple)) and len(elt.elts) == 2:
-                    src = elt.elts[0].value if isinstance(elt.elts[0], ast.Constant) else None
-                    dst = elt.elts[1].value if isinstance(elt.elts[1], ast.Constant) else None
-                    if src and dst:
-                        dirs.append((src, dst))
+            value_node = kw.value
+            if isinstance(value_node, (ast.List, ast.Tuple)):
+                for elt in value_node.elts:
+                    if isinstance(elt, (ast.List, ast.Tuple)) and len(elt.elts) == 2:
+                        src = elt.elts[0].value if isinstance(elt.elts[0], ast.Constant) else None
+                        dst = elt.elts[1].value if isinstance(elt.elts[1], ast.Constant) else None
+                        if src and dst:
+                            dirs.append((src, dst))
             kwargs[kw.arg] = tuple(dirs)
         elif kw.arg == "secret_names" and isinstance(kw.value, (ast.List, ast.Tuple)):
             # Handle secret_names list/tuple
            secrets = []
-            for elt in kw.value.elts:
-                if isinstance(elt, ast.Constant):
-                    secrets.append(elt.value)
+            value_node = kw.value
+            if isinstance(value_node, (ast.List, ast.Tuple)):
+                for elt in value_node.elts:
+                    if isinstance(elt, ast.Constant):
+                        secrets.append(elt.value)
             kwargs[kw.arg] = tuple(secrets)
         elif kw.arg == "volume_mounts" and isinstance(kw.value, (ast.List, ast.Tuple)):
             # Handle volume_mounts list/tuple of tuples
             mounts = []
-            for elt in kw.value.elts:
-                if isinstance(elt, (ast.List, ast.Tuple)) and len(elt.elts) == 2:
-                    name = elt.elts[0].value if isinstance(elt.elts[0], ast.Constant) else None
-                    mount = elt.elts[1].value if isinstance(elt.elts[1], ast.Constant) else None
-                    if name and mount:
-                        mounts.append((name, mount))
+            value_node = kw.value
+            if isinstance(value_node, (ast.List, ast.Tuple)):
+                for elt in value_node.elts:
+                    if isinstance(elt, (ast.List, ast.Tuple)) and len(elt.elts) == 2:
+                        name = elt.elts[0].value if isinstance(elt.elts[0], ast.Constant) else None
+                        mount = elt.elts[1].value if isinstance(elt.elts[1], ast.Constant) else None
+                        if name and mount:
+                            mounts.append((name, mount))
             kwargs[kw.arg] = tuple(mounts)
 
     return ModalDeploymentConfig(**kwargs)
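The refactor above wraps each keyword extraction in an explicit `isinstance(value_node, (ast.List, ast.Tuple))` guard so literal values are pulled out of the parsed AST without executing the module. A self-contained sketch of the same constant-extraction idea (hypothetical example code, not part of the package):

    import ast

    call = ast.parse('configure(pip_packages=["fastapi", "uvicorn"])', mode="eval").body
    for kw in call.keywords:
        if kw.arg == "pip_packages" and isinstance(kw.value, (ast.List, ast.Tuple)):
            packages = tuple(
                elt.value for elt in kw.value.elts if isinstance(elt, ast.Constant)
            )
            print(packages)  # ('fastapi', 'uvicorn')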
@@ -832,6 +1048,71 @@ def _import_task_app_module(
     return module
 
 
+@contextlib.contextmanager
+def _safe_import_context() -> Iterator[None]:
+    """Guard module imports against argparse/uvicorn side effects."""
+
+    original_argv = sys.argv[:]
+    sys.argv = [original_argv[0]] if original_argv else ["python"]
+
+    parser_cls = argparse.ArgumentParser
+    old_parse_args = parser_cls.parse_args
+
+    def _parse_noargs(self, args=None, namespace=None):  # type: ignore[override]
+        if args is None:
+            args = []
+        if namespace is None:
+            namespace = argparse.Namespace()
+        try:
+            return old_parse_args(self, args, namespace)
+        except SystemExit:
+            return namespace
+
+    parser_cls.parse_args = _parse_noargs  # type: ignore[assignment]
+
+    uvicorn_run = None
+    run_task_app_orig = None
+    try:
+        import uvicorn  # type: ignore
+
+        uvicorn_run = uvicorn.run
+        uvicorn.run = lambda *args, **kwargs: None  # type: ignore[assignment]
+    except Exception:
+        uvicorn_run = None
+
+    try:
+        _task_server_patch = cast(
+            Any, importlib.import_module("synth_ai.task.server")
+        )
+        run_task_app_orig = cast(Callable[..., Any], _task_server_patch.run_task_app)
+        _task_server_patch.run_task_app = (  # type: ignore[assignment]
+            lambda *args, **kwargs: None
+        )
+    except Exception:
+        run_task_app_orig = None
+
+    try:
+        yield
+    finally:
+        sys.argv = original_argv
+        parser_cls.parse_args = old_parse_args  # type: ignore[assignment]
+        if uvicorn_run is not None:
+            try:
+                import uvicorn  # type: ignore
+
+                uvicorn.run = uvicorn_run  # type: ignore[assignment]
+            except Exception:
+                pass
+        if run_task_app_orig is not None:
+            try:
+                _task_server_patch = cast(
+                    Any, importlib.import_module("synth_ai.task.server")
+                )
+                _task_server_patch.run_task_app = run_task_app_orig  # type: ignore[assignment]
+            except Exception:
+                pass
+
+
 def _load_entry_from_path(
     path: Path, app_id: str, module_search_roots: Sequence[Path] | None = None
 ) -> TaskAppEntryType:
@@ -859,13 +1140,14 @@ def _load_entry_from_path(
 
     for module_name, namespace_root in _possible_module_names(resolved, search_roots):
         try:
-            module = _import_task_app_module(
-                resolved,
-                module_name,
-                namespace_root=namespace_root,
-                sys_path_roots=search_roots,
-                ensure_namespace=True,
-            )
+            with _safe_import_context():
+                module = _import_task_app_module(
+                    resolved,
+                    module_name,
+                    namespace_root=namespace_root,
+                    sys_path_roots=search_roots,
+                    ensure_namespace=True,
+                )
             break
         except Exception as exc:  # pragma: no cover - best-effort fallbacks
             last_error = exc
@@ -874,13 +1156,14 @@ def _load_entry_from_path(
     if module is None:
         hashed_name = f"_synth_task_app_{hashlib.md5(str(resolved).encode(), usedforsecurity=False).hexdigest()}"
         try:
-            module = _import_task_app_module(
-                resolved,
-                hashed_name,
-                namespace_root=None,
-                sys_path_roots=search_roots,
-                ensure_namespace=False,
-            )
+            with _safe_import_context():
+                module = _import_task_app_module(
+                    resolved,
+                    hashed_name,
+                    namespace_root=None,
+                    sys_path_roots=search_roots,
+                    ensure_namespace=False,
+                )
         except Exception as exc:  # pragma: no cover - propagate meaningful error
             detail = last_error or exc
             raise click.ClickException(f"Failed to import {resolved}: {detail}") from detail
@@ -928,7 +1211,10 @@ def _load_entry_from_path(
         if has_required:
             continue
         try:
-            result = attr()
+            with _safe_import_context():
+                result = attr()
+        except SystemExit:
+            continue
         except Exception:
             continue
         if isinstance(result, TaskAppConfig) and result.app_id == app_id:
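`_safe_import_context`, added above, temporarily blanks out `sys.argv` and monkeypatches `argparse.ArgumentParser.parse_args`, `uvicorn.run`, and `run_task_app` so that importing a task app module cannot consume the CLI's own arguments or start a server as a side effect; everything is restored in the `finally` block. The underlying patch-and-restore pattern, reduced to a minimal standalone sketch (simplified, not the shipped implementation):

    import contextlib
    import importlib
    import sys

    @contextlib.contextmanager
    def quiet_import():
        saved_argv = sys.argv[:]
        sys.argv = saved_argv[:1]  # hide CLI arguments from module-level argparse calls
        try:
            yield
        finally:
            sys.argv = saved_argv  # always restore, even if the import fails

    with quiet_import():
        module = importlib.import_module("json")  # stand-in for a task app module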
@@ -1024,21 +1310,173 @@ def _resolve_env_paths_for_script(script_path: Path, explicit: Sequence[str]) ->
     return [env_candidates[choice - 1]]
 
 
+def _path_is_within(child: Path, parent: Path) -> bool:
+    try:
+        child.resolve().relative_to(parent.resolve())
+        return True
+    except Exception:
+        return False
+
+
+@functools.lru_cache(maxsize=16)
+def _is_modal_shim(path_str: str) -> bool:
+    """Return True if the candidate CLI path refers to the synth-ai shim."""
+
+    path = Path(path_str)
+    try:
+        resolved = path.resolve(strict=True)
+    except Exception:
+        resolved = path
+
+    if not resolved.exists() or resolved.is_dir():
+        return False
+
+    snippet = ""
+    try:
+        snippet = resolved.read_bytes()[:4096].decode("utf-8", errors="ignore")
+    except Exception:
+        snippet = ""
+
+    shim_markers = (
+        "synth_ai.cli._modal_wrapper",
+        "from modal.__main__ import main",
+        "import modal.__main__",
+        "run_module('modal.__main__'",
+    )
+    if snippet and any(marker in snippet for marker in shim_markers):
+        return True
+
+    try:
+        size = resolved.stat().st_size
+    except Exception:
+        size = None
+
+    if (
+        size is not None
+        and size < 2048
+        and "python" in (snippet.splitlines() or [""])[0]
+        and (
+            "modal.__main__" in snippet
+            or "modal.__main__" in snippet.replace(" ", "")
+        )
+    ):
+        return True
+
+    virtual_env = os.environ.get("VIRTUAL_ENV")
+    if virtual_env and _path_is_within(resolved, Path(virtual_env)):
+        return True
+
+    if _path_is_within(resolved, REPO_ROOT):
+        return True
+
+    uv_tools_dir = Path.home() / ".local" / "share" / "uv" / "tools"
+    return uv_tools_dir.exists() and _path_is_within(resolved, uv_tools_dir)
+
+
+def _find_modal_executable(modal_cli: str) -> tuple[str | None, str | None]:
+    """Return the first non-shim executable and the first shim discovered on PATH."""
+
+    if not modal_cli:
+        modal_cli = "modal"
+
+    candidate_path = Path(modal_cli).expanduser()
+    if candidate_path.is_absolute() or len(candidate_path.parts) > 1:
+        resolved_candidate = candidate_path
+        if not resolved_candidate.is_absolute():
+            resolved_candidate = (Path.cwd() / resolved_candidate).resolve()
+        else:
+            resolved_candidate = resolved_candidate.resolve()
+        if not resolved_candidate.exists():
+            raise click.ClickException(f"--modal-cli path does not exist: {resolved_candidate}")
+        if not os.access(resolved_candidate, os.X_OK):
+            raise click.ClickException(f"--modal-cli is not executable: {resolved_candidate}")
+        return str(resolved_candidate), None
+
+    path_env = os.environ.get("PATH", "")
+    if not path_env:
+        return None, None
+
+    seen_dirs: set[str] = set()
+    seen_candidates: set[str] = set()
+    shim_path: str | None = None
+
+    for raw_entry in path_env.split(os.pathsep):
+        if not raw_entry:
+            continue
+        try:
+            resolved_entry = str(Path(raw_entry).resolve())
+        except Exception:
+            resolved_entry = os.path.normpath(raw_entry)
+        if resolved_entry in seen_dirs:
+            continue
+        seen_dirs.add(resolved_entry)
+
+        candidate = shutil.which(modal_cli, path=raw_entry)
+        if candidate is None:
+            continue
+        if candidate in seen_candidates:
+            continue
+        seen_candidates.add(candidate)
+
+        if _is_modal_shim(candidate):
+            if shim_path is None:
+                shim_path = candidate
+            continue
+        return candidate, shim_path
+
+    return None, shim_path
+
+
 def _modal_command_prefix(modal_cli: str) -> list[str]:
     """Resolve a command prefix for invoking the Modal CLI within the active environment."""
-    if modal_cli == "modal" and importlib.util.find_spec("modal") is not None:
+
+    force_wrapper_env = os.environ.get("SYNTH_FORCE_MODAL_WRAPPER", "").strip().lower()
+    if force_wrapper_env in {"1", "true", "yes"}:
+        click.secho(
+            "[modal-prefix] SYNTH_FORCE_MODAL_WRAPPER=1 -> using in-process wrapper",
+            fg="yellow",
+        )
         return [sys.executable, "-m", "synth_ai.cli._modal_wrapper"]
 
-    modal_path = shutil.which(modal_cli)
-    if modal_path is not None:
-        return [modal_path]
+    lookup = modal_cli or "modal"
+    spec = importlib.util.find_spec("modal") if lookup == "modal" else None
+
+    preferred, shim_candidate = _find_modal_executable(lookup)
+    if preferred is not None:
+        detail = f"[modal-prefix] modal_cli={lookup} selected={preferred}"
+        if lookup == "modal":
+            detail += f" spec={'yes' if spec else 'no'}"
+        click.secho(detail, fg="cyan")
+        return [preferred]
+
+    if lookup != "modal":
+        raise click.ClickException(f"Modal CLI not found (looked for '{lookup}')")
+
+    if spec is not None:
+        warning = "[modal-prefix] Using synth-ai modal shim; pass --modal-cli /path/to/modal to override."
+        if shim_candidate is not None:
+            warning = (
+                f"[modal-prefix] Using synth-ai modal shim at {shim_candidate}; "
+                "pass --modal-cli /path/to/modal to override."
+            )
+        click.secho(warning, fg="yellow")
+        click.secho(
+            "[modal-prefix] modal_cli=modal selected=module-wrapper spec=yes",
+            fg="yellow",
+        )
+        return [sys.executable, "-m", "synth_ai.cli._modal_wrapper"]
 
-    if modal_cli == "modal":
+    if shim_candidate is not None:
         raise click.ClickException(
-            "Modal CLI not found. Install the 'modal' package in this environment or pass "
-            "--modal-cli with an explicit path."
+            "Modal CLI resolution found the synth-ai shim but the 'modal' package "
+            "is not importable in this environment. Install the official Modal CLI "
+            "or pass --modal-cli with its path."
         )
-    raise click.ClickException(f"Modal CLI not found (looked for '{modal_cli}')")
+
+    raise click.ClickException(
+        "Modal CLI not found. Install the 'modal' package in this environment or pass "
+        "--modal-cli with an explicit path."
+    )
 
 
 def _build_modal_app_wrapper(original_script: Path) -> tuple[Path, Path]:
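Taken together, these helpers give `_modal_command_prefix` a fixed resolution order: a forced in-process wrapper via SYNTH_FORCE_MODAL_WRAPPER, then an explicit `--modal-cli` path or the first non-shim `modal` executable on PATH, then the `synth_ai.cli._modal_wrapper` module when the `modal` package is importable, and finally an error. A condensed sketch of that decision ladder (simplified; it omits the warning output of the real function):

    def resolve_modal_prefix(modal_cli: str = "modal") -> list[str]:
        # 1. Operator override: always use the in-process wrapper.
        if os.environ.get("SYNTH_FORCE_MODAL_WRAPPER", "").lower() in {"1", "true", "yes"}:
            return [sys.executable, "-m", "synth_ai.cli._modal_wrapper"]
        # 2. Explicit path or first real (non-shim) executable discovered on PATH.
        preferred, _shim = _find_modal_executable(modal_cli)
        if preferred:
            return [preferred]
        # 3. Fall back to the module wrapper when the modal package is importable.
        if importlib.util.find_spec("modal") is not None:
            return [sys.executable, "-m", "synth_ai.cli._modal_wrapper"]
        # 4. Nothing usable was found.
        raise RuntimeError("Modal CLI not found")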
@@ -1173,8 +1611,15 @@ def _run_modal_script(
     if modal_name and command == "deploy":
         cmd.extend(["--name", modal_name])
     if dry_run:
-        click.echo("Dry run: " + " ".join(cmd))
+        click.echo(
+            "Dry run: " + " ".join(shlex.quote(component) for component in cmd),
+            err=False,
+        )
         return
+    click.secho(
+        "[modal-exec] " + " ".join(shlex.quote(component) for component in cmd),
+        fg="cyan",
+    )
     try:
         # Stream output live for better diagnostics
         proc = subprocess.Popen(
@@ -1429,7 +1874,6 @@ def _run_modal_with_entry(
         inline_secret_values=inline_secret_values,
     )
     cmd = [*_modal_command_prefix(modal_cli), command, str(script_path)]
-
     if modal_name and command == "deploy":
         cmd.extend(["--name", modal_name])
 
@@ -1444,9 +1888,13 @@
     proc_env["PYTHONPATH"] = os.pathsep.join(list(dict.fromkeys(pythonpath_entries)))
 
     if dry_run:
-        click.echo("Dry run: " + " ".join(cmd))
+        click.echo("Dry run: " + " ".join(shlex.quote(component) for component in cmd))
        script_path.unlink(missing_ok=True)
         return
+    click.secho(
+        "[modal-exec] " + " ".join(shlex.quote(component) for component in cmd),
+        fg="cyan",
+    )
 
     try:
         # Stream output live for better diagnostics
@@ -1531,6 +1979,10 @@ def _parse_env_file(path: Path) -> dict[str, str]:
 
 
 def _interactive_fill_env(env_path: Path) -> Path | None:
+    if not sys.stdin.isatty():
+        raise click.ClickException(
+            "ENVIRONMENT_API_KEY missing. Provide --env-file or run `synth-ai setup` in an interactive shell to create one."
+        )
     existing = _parse_env_file(env_path) if env_path.exists() else {}
 
     def _prompt(label: str, *, default: str = "", required: bool) -> str | None:
@@ -1570,6 +2022,10 @@ def _ensure_env_values(env_paths: list[Path], fallback_dir: Path) -> None:
     if (os.environ.get("ENVIRONMENT_API_KEY") or "").strip():
         return
     target = env_paths[0] if env_paths else (fallback_dir / ".env").resolve()
+    click.echo(
+        "⚠️ ENVIRONMENT_API_KEY not set. Run `uvx synth-ai setup`, "
+        "or pass --env-file pointing at a .env with ENVIRONMENT_API_KEY."
+    )
     result = _interactive_fill_env(target)
     if result is None:
         raise click.ClickException("ENVIRONMENT_API_KEY required to continue")
@@ -1593,7 +2049,7 @@ def _deploy_entry(
            f"Task app '{entry.app_id}' does not define Modal deployment settings"
        )
 
-    env_paths = _determine_env_files(entry, env_file)
+    env_paths = _determine_env_files(entry, env_file, original_path=original_path)
     click.echo("Using env file(s): " + ", ".join(str(p.resolve()) for p in env_paths))
     _run_modal_with_entry(
         entry,
@@ -1620,7 +2076,7 @@ def _modal_serve_entry(
            f"Task app '{entry.app_id}' does not define Modal deployment settings"
        )
 
-    env_paths = _determine_env_files(entry, env_file)
+    env_paths = _determine_env_files(entry, env_file, original_path=original_path)
     click.echo("Using env file(s): " + ", ".join(str(p.resolve()) for p in env_paths))
     _run_modal_with_entry(
         entry,
@@ -1651,6 +2107,255 @@ def list_apps() -> None:
         click.echo(f"- {entry.app_id}{aliases}: {entry.description}")
 
 
+@task_app_group.command("validate")
+@click.argument("app_id", type=str, required=True)
+@click.option(
+    "--url",
+    type=str,
+    default=None,
+    help="Task app URL to validate (if not provided, starts a local server)",
+)
+@click.option(
+    "--port",
+    type=int,
+    default=8765,
+    help="Port to use for temporary server (default: 8765)",
+)
+@click.option(
+    "--api-key",
+    type=str,
+    default=None,
+    envvar="ENVIRONMENT_API_KEY",
+    help="API key for authentication (default: $ENVIRONMENT_API_KEY)",
+)
+@click.option(
+    "--min-instances",
+    type=int,
+    default=10,
+    help="Minimum number of task instances required (default: 10)",
+)
+@click.option(
+    "--verbose",
+    "-v",
+    is_flag=True,
+    help="Show detailed information about the task app",
+)
+@click.option(
+    "--json",
+    "output_json",
+    is_flag=True,
+    help="Output results as JSON",
+)
+def validate_task_app_cmd(
+    app_id: str,
+    url: str | None,
+    port: int,
+    api_key: str | None,
+    min_instances: int,
+    verbose: bool,
+    output_json: bool,
+) -> None:
+    """Validate a task app deployment readiness.
+
+    This command verifies that a task app is properly configured and ready to run
+    by checking all required HTTP endpoints, authentication, and task availability.
+
+    By default, it starts a temporary local server for validation. You can also
+    validate a remote deployment by passing --url.
+
+    \b
+    What gets validated:
+      • Root endpoint (/) responds correctly
+      • Health endpoint (/health) is accessible with proper authentication
+      • Info endpoint (/info) returns valid task metadata
+      • Task info endpoint (/task_info) provides task instances
+      • Rollout endpoint (/rollout) is registered
+      • At least N task instances are available (default: 10)
+
+    \b
+    Examples:
+
+    \b
+    Validate grpo-crafter (starts local server automatically):
+        $ synth-ai task-app validate grpo-crafter
+
+    \b
+    Validate sokoban with verbose output:
+        $ synth-ai task-app validate sokoban --verbose
+
+    \b
+    Validate with custom port:
+        $ synth-ai task-app validate sokoban --port 9000
+
+    \b
+    Validate a remote deployment:
+        $ synth-ai task-app validate grpo-crafter --url https://my-crafter.modal.run
+
+    \b
+    Require at least 20 task instances:
+        $ synth-ai task-app validate grpo-crafter --min-instances 20
+
+    \b
+    Get JSON output for automation:
+        $ synth-ai task-app validate sokoban --json
+
+    \b
+    Common use cases:
+      • Pre-deployment verification: Check task app works before deploying to Modal
+      • CI/CD integration: Use --json flag for automated validation in pipelines
+      • Debug failing deployments: Use --verbose to see detailed endpoint responses
+      • Test API key configuration: Verify authentication is set up correctly
+    """
+    import asyncio
+    import socket
+    import subprocess
+    import tempfile
+    import time
+
+    # Import the validate_task_app function defined in this module
+    from synth_ai.cli._validate_task_app import validate_task_app  # type: ignore[attr-defined]
+
+    proc = None
+    task_app_url = url
+
+    try:
+        # If no URL provided, start a temporary server
+        if not task_app_url:
+            # Find an available port
+            def is_port_available(port: int) -> bool:
+                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                    try:
+                        s.bind(("", port))
+                        return True
+                    except OSError:
+                        return False
+
+            while not is_port_available(port):
+                port += 1
+
+            task_app_url = f"http://localhost:{port}"
+
+            if not output_json:
+                click.echo(f"Starting temporary {app_id} server on port {port}...")
+
+            # Start the server in background
+            env = os.environ.copy()
+            if api_key:
+                env["ENVIRONMENT_API_KEY"] = api_key
+
+            # Create a temporary trace DB and trace dir to avoid prompts
+            import tempfile
+            temp_dir = tempfile.mkdtemp()
+            temp_trace_db = os.path.join(temp_dir, "validate_trace.db")
+            temp_trace_dir = os.path.join(temp_dir, "traces")
+            os.makedirs(temp_trace_dir, exist_ok=True)
+
+            proc = subprocess.Popen(
+                [
+                    "uv",
+                    "run",
+                    "synth-ai",
+                    "task-app",
+                    "serve",
+                    app_id,
+                    "--port",
+                    str(port),
+                    "--no-reload",
+                    "--trace",
+                    temp_trace_dir,
+                    "--trace-db",
+                    temp_trace_db,
+                ],
+                env=env,
+                stdin=subprocess.PIPE,  # Add stdin to handle any prompts
+                stdout=subprocess.DEVNULL if output_json else subprocess.PIPE,
+                stderr=subprocess.DEVNULL if output_json else subprocess.PIPE,
+                text=True,
+            )
+
+            # Write empty input to stdin to skip any prompts
+            if proc.stdin:
+                try:
+                    proc.stdin.write("\n")
+                    proc.stdin.flush()
+                    proc.stdin.close()
+                except Exception:
+                    pass
+
+            # Wait for server to be ready
+            if not output_json:
+                click.echo("Waiting for server to start...")
+
+            import httpx
+            for _attempt in range(60):  # 30 seconds timeout
+                try:
+                    async def check_health():
+                        async with httpx.AsyncClient(timeout=2.0) as client:
+                            resp = await client.get(f"{task_app_url}/")
+                            return resp.status_code == 200
+
+                    if asyncio.run(check_health()):
+                        break
+                except Exception:
+                    pass
+
+                # Check if process died
+                if proc.poll() is not None:
+                    stderr_output = ""
+                    if proc.stderr and not output_json:
+                        stderr_output = proc.stderr.read()
+                    click.echo(click.style("✗ Server process exited unexpectedly", fg="red"), err=True)
+                    if stderr_output and not output_json:
+                        click.echo(f"Error output:\n{stderr_output}", err=True)
+                    sys.exit(1)
+
+                time.sleep(0.5)
+            else:
+                click.echo(click.style("✗ Server failed to start within 30 seconds", fg="red"), err=True)
+                sys.exit(1)
+
+            if not output_json:
+                click.echo(click.style("✓ Server started", fg="green"))
+                click.echo()
+
+        # Ensure URL doesn't have trailing slash
+        task_app_url = task_app_url.rstrip("/")
+
+        async def _run() -> tuple[bool, dict[str, Any]]:
+            return await validate_task_app(
+                url=task_app_url,
+                api_key=api_key,
+                min_instances=min_instances,
+                verbose=verbose,
+            )
+
+        success, results = asyncio.run(_run())
+
+        if output_json:
+            import json as _json
+            click.echo(_json.dumps(results, indent=2))
+
+        sys.exit(0 if success else 1)
+
+    finally:
+        # Cleanup: stop the temporary server
+        if proc is not None:
+            if not output_json:
+                click.echo("\nStopping temporary server...")
+            try:
+                proc.terminate()
+                proc.wait(timeout=5)
+            except Exception:
+                proc.kill()
+
+        # Cleanup temp trace DB
+        if not url and 'temp_dir' in locals():
+            import contextlib
+            import shutil
+            with contextlib.suppress(Exception):
+                shutil.rmtree(temp_dir, ignore_errors=True)
+
+
 def _load_env_files_into_process(paths: Sequence[str]) -> None:
     for p in paths:
         try:
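Because `--json` prints the validation report on stdout and the exit code encodes pass/fail, the new `validate` subcommand drops straight into CI. An illustrative wrapper script (hypothetical, assuming `synth-ai` is installed on PATH):

    import json
    import subprocess

    result = subprocess.run(
        ["synth-ai", "task-app", "validate", "grpo-crafter", "--json", "--min-instances", "10"],
        capture_output=True,
        text=True,
    )
    report = json.loads(result.stdout or "{}")
    status = "passed" if result.returncode == 0 else "failed"
    print(f"task app validation {status}", report)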
@@ -1907,7 +2612,9 @@ def serve_task_group(
     )
 
 
-def _determine_env_files(entry: TaskAppEntryType, user_env_files: Sequence[str]) -> list[Path]:
+def _determine_env_files(
+    entry: TaskAppEntryType, user_env_files: Sequence[str], *, original_path: Path | None = None
+) -> list[Path]:
     resolved: list[Path] = []
     for candidate in user_env_files:
         p = Path(candidate).expanduser()
@@ -1917,30 +2624,46 @@ def _determine_env_files(entry: TaskAppEntryType, user_env_files: Sequence[str])
     if resolved:
         return resolved
 
-    # Always prompt for env file selection instead of auto-loading defaults
-    # Look for env files in current working directory first, then repo root
-    cwd = Path.cwd()
-    env_candidates = []
-
-    # Add CWD env files first (prioritized)
-    cwd_env_files = sorted(cwd.glob("**/*.env"))
-    env_candidates.extend(cwd_env_files)
+    declared: list[Path] = []
+    for candidate in getattr(entry, "env_files", ()) or ():
+        try:
+            p = Path(candidate).expanduser()
+        except Exception:
+            continue
+        if p.exists() and p.is_file():
+            declared.append(p)
+    if declared:
+        return declared
 
-    # Add repo root env files
-    repo_env_files = sorted(REPO_ROOT.glob("**/*.env"))
-    # Avoid duplicates
-    for repo_file in repo_env_files:
-        if repo_file not in env_candidates:
-            env_candidates.append(repo_file)
+    def _append_candidate(collection: list[Path], candidate: Path) -> None:
+        if candidate.exists() and candidate.is_file() and candidate not in collection:
+            collection.append(candidate)
 
-    if not env_candidates:
-        raise click.ClickException("No env file found. Pass --env-file explicitly.")
+    auto_candidates: list[Path] = []
 
-    click.echo("Select env file to load:")
-    for idx, path in enumerate(env_candidates, start=1):
-        click.echo(f" {idx}) {path.resolve()}")
-    choice = click.prompt("Enter choice", type=click.IntRange(1, len(env_candidates)), default=1)
-    return [env_candidates[choice - 1]]
+    search_dirs: list[Path] = []
+    if original_path is not None:
+        search_dirs.append(original_path.parent.resolve())
+        for parent in original_path.parent.resolve().parents:
+            search_dirs.append(parent)
+    cwd = Path.cwd().resolve()
+    if cwd not in search_dirs:
+        search_dirs.append(cwd)
+    repo_root = REPO_ROOT.resolve()
+    if repo_root not in search_dirs:
+        search_dirs.append(repo_root)
+
+    for directory in search_dirs:
+        _append_candidate(auto_candidates, directory / ".env")
+        for candidate in sorted(directory.glob("*.env")):
+            _append_candidate(auto_candidates, candidate)
+
+    if auto_candidates:
+        return [auto_candidates[0]]
+
+    raise click.ClickException(
+        "No .env file discovered automatically. Pass --env-file /path/to/.env or generate one with `uvx synth-ai setup`."
+    )
 
 
 def _ensure_port_free(port: int, host: str, *, force: bool) -> None:
@@ -2242,7 +2965,14 @@ def deploy_app(
 def modal_serve_app(
     app_id: str | None, modal_cli: str, modal_name: str | None, env_file: Sequence[str]
 ) -> None:
-    choice = _select_app_choice(app_id, purpose="modal-serve")
+    click.echo(f"[modal-serve] requested app_id={app_id or '(auto)'} modal_cli={modal_cli}")
+    try:
+        choice = _select_app_choice(app_id, purpose="modal-serve")
+    except SystemExit as exc:  # bubble up with context (legacy argparse would trigger this)
+        raise click.ClickException(
+            f"Legacy CLI intercepted modal-serve (exit {exc.code}). "
+            "Make sure you're running the Click CLI (synth_ai.cli:cli)."
+        ) from exc
 
     if choice.modal_script:
         env_paths = _resolve_env_paths_for_script(choice.modal_script, env_file)
@@ -2251,6 +2981,7 @@
         return
 
     entry = choice.ensure_entry()
+    click.echo(f"[modal-serve] serving entry {entry.app_id} from {choice.path}")
     _modal_serve_entry(entry, modal_name, modal_cli, env_file, original_path=choice.path)
 
 
@@ -2480,22 +3211,60 @@ def register(cli: click.Group) -> None:
2480
3211
  cli.add_command(serve_command)
2481
3212
  cli.add_command(task_app_group)
2482
3213
  cli.add_command(eval_command)
3214
+ cli.add_command(filter_command)
2483
3215
 
2484
3216
 
2485
- @click.command("eval")
3217
+ @click.command(
3218
+ "eval",
3219
+ help="Run one-off rollouts against a task app and print judge/eval summaries.",
3220
+ )
2486
3221
  @click.argument("app_id", type=str, required=False)
2487
- @click.option("--config", type=click.Path(), default=None, help="Path to eval TOML (short schema)")
3222
+ @click.option(
3223
+ "--config",
3224
+ type=click.Path(),
3225
+ default=None,
3226
+ help="Path to eval TOML (short schema). Auto-discovers the first matching file when omitted.",
3227
+ )
2488
3228
  @click.option(
2489
3229
  "--url",
2490
3230
  "task_app_url",
2491
3231
  type=str,
2492
3232
  default=None,
2493
- help="Base URL of a running task app (skip in-process server)",
3233
+ help="Base URL of a running task app instead of spawning locally (requires --env-file for secrets).",
3234
+ )
3235
+ @click.option(
3236
+ "--seeds",
3237
+ default="0,1,2,3,4",
3238
+ help="Comma-separated seeds/indices to evaluate. Use negative numbers to wrap around the dataset.",
2494
3239
  )
2495
- @click.option("--seeds", default="0,1,2,3,4", help="Comma-separated seeds/indices to evaluate")
2496
3240
  @click.option("--split", default="train", show_default=True, help="Dataset split to use")
2497
- @click.option("--model", default=None, help="Model identifier (prompted if omitted)")
2498
- @click.option("--env-file", multiple=True, type=click.Path(), help="Env file(s) for keys")
3241
+ @click.option(
3242
+ "--model",
3243
+ default=None,
3244
+ help="Model identifier. When omitted the CLI will prompt based on task metadata.",
3245
+ )
3246
+ @click.option(
3247
+ "--env-file",
3248
+ multiple=True,
3249
+ type=click.Path(),
3250
+ help="Env file(s) to load (API keys, etc.). Required when using --url or remote judges.",
3251
+ )
3252
+ @click.option(
3253
+ "--trace-db",
3254
+ default="traces/v3/eval_traces.db",
3255
+ show_default=True,
3256
+ help="SQLite/Turso URL for storing rollout traces set to 'none' to disable persistence.",
3257
+ )
3258
+ @click.option(
3259
+ "--metadata",
3260
+ multiple=True,
3261
+ help="Filter tasks by key=value metadata (e.g., --metadata difficulty=easy)",
3262
+ )
3263
+ @click.option(
3264
+ "--metadata-sql",
3265
+ default=None,
3266
+ help="SQLite query that returns seeds to evaluate (e.g., SELECT seed FROM tasks WHERE difficulty='easy' LIMIT 5)",
3267
+ )
2499
3268
  def eval_command(
2500
3269
  app_id: str | None,
2501
3270
  config: str | None,
@@ -2504,8 +3273,17 @@ def eval_command(
2504
3273
  split: str,
2505
3274
  model: str | None,
2506
3275
  env_file: Sequence[str],
3276
+ trace_db: str,
3277
+ metadata: Sequence[str],
3278
+ metadata_sql: str | None,
2507
3279
  ) -> None:
2508
- """Run local rollouts against a task app using in-process ASGI and summarize results."""
3280
+ """Run rollouts against a task app and report judge statistics.
3281
+
3282
+ By default the command spins up the selected task app in-process, executes the
3283
+ requested seeds, and prints aggregate scores (official and custom judges). When
3284
+ pointing at a remote `--url`, supply matching `--env-file` values so the CLI can
3285
+ forward authentication headers to the running service.
3286
+ """
2509
3287
  cfg: dict[str, Any] = {}
2510
3288
  config_path: Path | None = None
2511
3289
  if config:
@@ -2534,6 +3312,50 @@ def eval_command(
2534
3312
 
2535
3313
  app_id = app_id or (cfg.get("app_id") if isinstance(cfg.get("app_id"), str) else None) # type: ignore
2536
3314
 
3315
+ metadata_filters: dict[str, str] = {}
3316
+ cfg_metadata = cfg.get("metadata")
3317
+ if isinstance(cfg_metadata, dict):
3318
+ for key, value in cfg_metadata.items():
3319
+ metadata_filters[str(key)] = str(value)
3320
+ elif isinstance(cfg_metadata, list):
3321
+ for item in cfg_metadata:
3322
+ if isinstance(item, str) and "=" in item:
3323
+ key, value = item.split("=", 1)
3324
+ metadata_filters[key.strip()] = value.strip()
3325
+
3326
+ for item in metadata or ():
3327
+ if "=" not in item:
3328
+ raise click.ClickException(f"Metadata filters must be key=value (got: {item})")
3329
+ key, value = item.split("=", 1)
3330
+ key = key.strip()
3331
+ value = value.strip()
3332
+ if not key or not value:
3333
+ raise click.ClickException(f"Invalid metadata filter: {item}")
3334
+ metadata_filters[key] = value
3335
+
3336
+ metadata_sql_query: str | None = None
3337
+ cfg_metadata_sql = cfg.get("metadata_sql")
3338
+ if isinstance(cfg_metadata_sql, dict):
3339
+ metadata_sql_query = cfg_metadata_sql.get("query") or cfg_metadata_sql.get("sql")
3340
+ elif isinstance(cfg_metadata_sql, str):
3341
+ metadata_sql_query = cfg_metadata_sql
3342
+
3343
+ if metadata_sql:
3344
+ metadata_sql_query = metadata_sql
3345
+ if metadata_sql_query is not None:
3346
+ metadata_sql_query = str(metadata_sql_query)
3347
+
3348
+ trace_db_url: str | None = None
3349
+ trace_db = (trace_db or "").strip()
3350
+ if trace_db and trace_db.lower() not in {"none", "off", "disable"}:
3351
+ if "://" in trace_db:
3352
+ trace_db_url = trace_db
3353
+ else:
3354
+ trace_path = Path(trace_db).expanduser()
3355
+ trace_path.parent.mkdir(parents=True, exist_ok=True)
3356
+ trace_db_url = f"sqlite+aiosqlite:///{trace_path}"
3357
+ trace_tracer: SessionTracer | None = SessionTracer(db_url=trace_db_url, auto_save=True) if trace_db_url else None
3358
+
2537
3359
  # Determine selection params (CLI takes precedence; TOML only fills unset model/seeds/env)
2538
3360
  if cfg.get("model") and not model:
2539
3361
  model = str(cfg["model"]) # type: ignore[index]
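Pulling these options together, an illustrative short-schema eval TOML built only from keys this and the surrounding hunks read (values and file names are hypothetical; depending on how the earlier config-parsing hunk nests the short schema, the keys may also sit under an `[eval]` table):

    app_id = "grpo-crafter"
    model = "groq:llama-3.1-70b-versatile"
    concurrency = 2
    metadata = { difficulty = "easy" }
    # metadata_sql = "SELECT seed FROM tasks WHERE difficulty='easy' LIMIT 5"

    [policy]
    temperature = 0.2
    max_tokens = 512

    [judge.quality]
    path = "judges/quality.py"   # must expose a `judge` callable (see the example further down)
    min_length = 20              # extra keys are forwarded to the judge as keyword arguments

Seeds and the trace database stay on the command line, for example `uvx synth-ai eval grpo-crafter --config eval_config.toml --seeds 0,1,2 --trace-db traces/v3/eval_traces.db` (invocation shown for illustration).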
@@ -2553,14 +3375,16 @@ def eval_command(
2553
3375
  elif isinstance(ef, list):
2554
3376
  env_file = tuple(str(x) for x in ef) # type: ignore[assignment]
2555
3377
 
3378
+ choice_for_env: AppChoice | None = None
2556
3379
  entry: TaskAppEntryType | None = None
2557
3380
  if task_app_url is None:
2558
- choice = _select_app_choice(app_id, purpose="eval")
2559
- entry = choice.ensure_entry()
3381
+ choice_for_env = _select_app_choice(app_id, purpose="eval")
3382
+ entry = choice_for_env.ensure_entry()
2560
3383
 
2561
3384
  env_paths: list[Path] = []
2562
3385
  if entry is not None:
2563
- env_paths = _determine_env_files(entry, env_file)
3386
+ original_env_path = choice_for_env.path if choice_for_env is not None else None
3387
+ env_paths = _determine_env_files(entry, env_file, original_path=original_env_path)
2564
3388
  else:
2565
3389
  if not env_file:
2566
3390
  raise click.ClickException("--env-file is required when using --url")
@@ -2583,12 +3407,30 @@ def eval_command(
2583
3407
  app = create_task_app(config)
2584
3408
 
2585
3409
  # Determine supported models
3410
+ inference_meta: dict[str, Any] = {}
2586
3411
  supported: list[str] = []
3412
+ seen_models: set[str] = set()
3413
+
3414
+ def _add_supported_model(candidate: Any) -> None:
3415
+ if not candidate:
3416
+ return
3417
+ text = str(candidate).strip()
3418
+ if not text or text in seen_models:
3419
+ return
3420
+ supported.append(text)
3421
+ seen_models.add(text)
3422
+
2587
3423
  if task_app_url is None:
2588
3424
  try:
2589
- supported = list((config.base_task_info.inference or {}).get("models") or []) # type: ignore[union-attr]
3425
+ if hasattr(config, "base_task_info") and config.base_task_info:
3426
+ inf_obj = getattr(config.base_task_info, "inference", None)
3427
+ if inf_obj is not None:
3428
+ if hasattr(inf_obj, "model_dump"):
3429
+ inference_meta = dict(inf_obj.model_dump(exclude_none=True)) # type: ignore[attr-defined]
3430
+ elif isinstance(inf_obj, dict):
3431
+ inference_meta = dict(inf_obj)
2590
3432
  except Exception:
2591
- supported = []
3433
+ inference_meta = {}
2592
3434
  else:
2593
3435
  try:
2594
3436
  import httpx as _hx
@@ -2601,38 +3443,38 @@ def eval_command(
2601
3443
  info = c.get("/info").json()
2602
3444
  inf = info.get("inference") if isinstance(info, dict) else None
2603
3445
  if isinstance(inf, dict):
2604
- m = inf.get("models")
2605
- if isinstance(m, list):
2606
- supported = [str(x) for x in m]
2607
- if not supported:
2608
- providers = inf.get("providers")
2609
- if isinstance(providers, list):
2610
- if "openai" in providers:
2611
- supported.append("gpt-5")
2612
- if "groq" in providers:
2613
- supported.append("groq:llama-3.1-70b-versatile")
2614
- supported.append("synth:qwen-0.6b")
3446
+ inference_meta = dict(inf)
2615
3447
  except Exception:
2616
- supported = []
2617
- if not supported:
2618
- # Only fall back to local config-derived providers when running in-process
2619
- if task_app_url is None:
2620
- try:
2621
- providers = list((config.base_task_info.inference or {}).get("providers") or []) # type: ignore[union-attr]
2622
- except Exception:
2623
- providers = []
2624
- if "openai" in providers:
2625
- supported.append("gpt-5")
2626
- if "groq" in providers:
2627
- supported.append("groq:llama-3.1-70b-versatile")
2628
- # Always include a local synth model option for smoke tests
2629
- supported.append("synth:qwen-0.6b")
3448
+ inference_meta = {}
3449
+
3450
+ default_model = inference_meta.get("model")
3451
+ if isinstance(default_model, str):
3452
+ _add_supported_model(default_model)
3453
+
3454
+ models_field = inference_meta.get("models")
3455
+ if isinstance(models_field, list):
3456
+ for candidate in models_field:
3457
+ _add_supported_model(candidate)
3458
+
3459
+ supported_models = inference_meta.get("supported_models")
3460
+ if isinstance(supported_models, list):
3461
+ for candidate in supported_models:
3462
+ _add_supported_model(candidate)
3463
+
3464
+ providers = inference_meta.get("providers")
3465
+ if isinstance(providers, list):
3466
+ if "openai" in providers:
3467
+ _add_supported_model("gpt-5")
3468
+ if "groq" in providers:
3469
+ _add_supported_model("groq:llama-3.1-70b-versatile")
3470
+
3471
+ _add_supported_model("synth:qwen-0.6b")
2630
3472
 
2631
3473
  selected_model = model
2632
3474
  if not selected_model:
2633
3475
  if not supported:
2634
3476
  raise click.ClickException(
2635
- "No supported models; supply --model or add base_task_info.inference.models"
3477
+ "No supported models; supply --model or add base_task_info.inference.model"
2636
3478
  )
2637
3479
  click.echo("Select model to evaluate:")
2638
3480
  for idx, m in enumerate(supported, start=1):
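The judge-loading logic added in the following hunk imports a module (by dotted `module` path or file `path`), resolves a callable named `judge` unless `callable`/`function` overrides it, and invokes it with a payload dict plus any extra keys from the judge's TOML table as keyword arguments. A minimal, hypothetical judge module compatible with that contract:

    # judges/quality.py (illustrative file name)
    from typing import Any

    def judge(payload: dict[str, Any], *, min_length: int = 20, **_: Any) -> float:
        """Return 1.0 when the completion is non-trivial, else 0.0.

        `payload` carries seed, prompt_index, prompt, completion, metrics,
        response, and trace; keys from the judge's TOML table (e.g. min_length)
        arrive as keyword arguments. Non-numeric return values are ignored.
        """
        completion = payload.get("completion") or ""
        return 1.0 if len(completion.strip()) >= min_length else 0.0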
@@ -2652,70 +3494,347 @@ def eval_command(
2652
3494
  if api_key:
2653
3495
  headers["X-API-Key"] = api_key
2654
3496
 
3497
+ # Precompute optional policy overrides from TOML
3498
+ policy_overrides: dict[str, Any] = {}
3499
+ try:
3500
+ # Accept [eval.policy] table or top-level keys for convenience
3501
+ if isinstance(cfg.get("policy"), dict):
3502
+ policy_overrides.update(dict(cfg["policy"]))
3503
+ # Back-compat: allow temperature/max_tokens at top level
3504
+ for k in (
3505
+ "temperature",
3506
+ "max_tokens",
3507
+ "reasoning_effort",
3508
+ "system_hint",
3509
+ "tool_choice",
3510
+ "inference_url",
3511
+ ):
3512
+ if k in cfg and k not in policy_overrides:
3513
+ policy_overrides[k] = cfg.get(k)
3514
+ except Exception:
3515
+ policy_overrides = {}
3516
+
3517
+ raw_concurrency = cfg.get("concurrency")
3518
+ try:
3519
+ concurrency_limit = int(raw_concurrency) if raw_concurrency is not None else 1
3520
+ except Exception:
3521
+ concurrency_limit = 1
3522
+ if concurrency_limit <= 0:
3523
+ concurrency_limit = 1
3524
+ concurrency_limit = min(concurrency_limit, max(1, len(seed_values)))
3525
+
3526
+ judge_specs: list[JudgeSpec] = []
3527
+
3528
+ def _register_judge(name_hint: str | None, judge_cfg: dict[str, Any]) -> None:
3529
+ if not judge_cfg:
3530
+ return
3531
+ judge_module = judge_cfg.get("module")
3532
+ judge_path = judge_cfg.get("path")
3533
+ judge_callable_name = judge_cfg.get("callable") or judge_cfg.get("function")
3534
+ if judge_module and judge_path:
3535
+ raise click.ClickException("Judge config cannot set both 'module' and 'path'")
3536
+ if not judge_module and not judge_path:
3537
+ raise click.ClickException("Judge config requires 'module' or 'path'")
3538
+ try:
3539
+ if judge_module:
3540
+ module = importlib.import_module(str(judge_module))
3541
+ else:
3542
+ path = Path(str(judge_path)).expanduser()
3543
+ if not path.exists():
3544
+ raise click.ClickException(f"Judge module path not found: {path}")
3545
+ spec = importlib.util.spec_from_file_location(
3546
+ f"_eval_judge_{path.stem}", path
3547
+ )
3548
+ if not spec or not spec.loader:
3549
+ raise click.ClickException(f"Failed to load judge module from {path}")
3550
+ module = importlib.util.module_from_spec(spec)
3551
+ sys.modules[spec.name] = module
3552
+ spec.loader.exec_module(module)
3553
+ except click.ClickException:
3554
+ raise
3555
+ except Exception as exc:
3556
+ raise click.ClickException(f"Unable to load judge module: {exc}") from exc
3557
+
3558
+ if judge_callable_name:
3559
+ try:
3560
+ judge_fn = getattr(module, str(judge_callable_name))
3561
+ except AttributeError as exc:
3562
+ raise click.ClickException(
3563
+ f"Judge callable '{judge_callable_name}' not found in module"
3564
+ ) from exc
3565
+ else:
3566
+ if hasattr(module, "judge"):
3567
+ judge_fn = module.judge
3568
+ else:
3569
+ raise click.ClickException("Judge module must expose 'judge' callable")
3570
+
3571
+ if not callable(judge_fn):
3572
+ raise click.ClickException("Judge callable is not callable")
3573
+
3574
+ judge_kwargs = {
3575
+ k: v
3576
+ for k, v in judge_cfg.items()
3577
+ if k not in {"module", "path", "callable", "function", "name"}
3578
+ }
3579
+ display_name = str(
3580
+ judge_cfg.get("name")
3581
+ or name_hint
3582
+ or f"judge{len(judge_specs) + 1}"
3583
+ )
3584
+ judge_specs.append(JudgeSpec(display_name, judge_fn, judge_kwargs))
3585
+
3586
+ raw_judge_cfg = cfg.get("judge")
3587
+ if isinstance(raw_judge_cfg, dict) and raw_judge_cfg:
3588
+ direct_keys = {"module", "path", "callable", "function", "name"}
3589
+ has_direct_keys = any(key in raw_judge_cfg for key in direct_keys)
3590
+ nested_candidates = [
3591
+ (key, value)
3592
+ for key, value in raw_judge_cfg.items()
3593
+ if isinstance(value, dict)
3594
+ ]
3595
+ if has_direct_keys and not nested_candidates:
3596
+ _register_judge(None, raw_judge_cfg)
3597
+ else:
3598
+ for sub_name, sub_cfg in nested_candidates:
3599
+ _register_judge(sub_name, sub_cfg)
3600
+
3601
+ raw_judges_list = cfg.get("judges")
3602
+ if isinstance(raw_judges_list, list):
3603
+ for _index, entry in enumerate(raw_judges_list, start=1):
3604
+ if isinstance(entry, dict):
3605
+ _register_judge(entry.get("name") or f"judge{len(judge_specs) + 1}", entry)
3606
+
3607
+ records: list[dict[str, Any]] = []
3608
+
2655
3609
  successes = 0
2656
3610
  failures = 0
2657
3611
  # Aggregate outcome stats across successful seeds
2658
3612
  outcome_sum: float = 0.0
2659
3613
  outcome_count: int = 0
2660
3614
  outcome_correct: int = 0
2661
- if task_app_url is None:
2662
- transport = httpx.ASGITransport(app=app) # type: ignore[name-defined]
2663
- # Newer httpx types consider ASGITransport under httpx._transports; cast to satisfy type checker
2664
- client = httpx.Client(
2665
- transport=cast(Any, transport),
2666
- base_url="http://eval.local",
2667
- timeout=60.0,
2668
- headers=headers,
2669
- )
2670
- else:
2671
- client = httpx.Client(base_url=task_app_url, timeout=60.0, headers=headers)
2672
- try:
2673
- with contextlib.suppress(Exception):
2674
- client.get("/task_info")
2675
- # Precompute optional policy overrides from TOML
2676
- policy_overrides: dict[str, Any] = {}
2677
- try:
2678
- # Accept [eval.policy] table or top-level keys for convenience
2679
- if isinstance(cfg.get("policy"), dict):
2680
- policy_overrides.update(dict(cfg["policy"]))
2681
- # Back-compat: allow temperature/max_tokens at top level
2682
- for k in (
2683
- "temperature",
2684
- "max_tokens",
2685
- "reasoning_effort",
2686
- "system_hint",
2687
- "tool_choice",
2688
- ):
2689
- if k in cfg and k not in policy_overrides:
2690
- policy_overrides[k] = cfg.get(k)
2691
- except Exception:
2692
- policy_overrides = {}
2693
-
2694
- for seed_val in seed_values:
2695
- body = {
2696
- "run_id": str(uuid.uuid4()),
2697
- "env": {"config": {"split": split, "index": seed_val}, "seed": seed_val},
2698
- "policy": {
2699
- "policy_name": selected_model,
2700
- "config": {"model": selected_model, **policy_overrides},
2701
- },
2702
- "ops": [],
3615
+
3616
+ def _build_task_rows(taskset: Any) -> dict[int, dict[str, Any]]:
3617
+ rows: dict[int, dict[str, Any]] = {}
3618
+ if not isinstance(taskset, dict):
3619
+ return rows
3620
+
3621
+ scenario_ids = taskset.get("scenario_ids") or []
3622
+ loop_ids = taskset.get("loop_ids") or []
3623
+ thread_ids = taskset.get("thread_ids") or []
3624
+ difficulty_map = taskset.get("difficulty_map") or {}
3625
+
3626
+ max_len = max(len(scenario_ids), len(loop_ids), len(thread_ids))
3627
+ for seed in range(max_len):
3628
+ scenario_id = scenario_ids[seed] if seed < len(scenario_ids) else None
3629
+ loop_id = loop_ids[seed] if seed < len(loop_ids) else None
3630
+ thread_id = thread_ids[seed] if seed < len(thread_ids) else None
3631
+ difficulty = None
3632
+ if isinstance(difficulty_map, dict):
3633
+ if scenario_id and scenario_id in difficulty_map:
3634
+ difficulty = difficulty_map.get(scenario_id)
3635
+ elif str(seed) in difficulty_map:
3636
+ difficulty = difficulty_map.get(str(seed))
3637
+
3638
+ rows[seed] = {
3639
+ "seed": seed,
3640
+ "scenario_id": scenario_id,
3641
+ "loop_id": loop_id,
3642
+ "thread_id": thread_id,
3643
+ "difficulty": difficulty,
2703
3644
  }
3645
+ return rows
3646
+
3647
+ def _apply_metadata_filters(
3648
+ rows: dict[int, dict[str, Any]], seeds_list: list[int], filters: dict[str, str]
3649
+ ) -> list[int]:
3650
+ if not filters:
3651
+ return seeds_list
3652
+ filtered: list[int] = []
3653
+ for seed in seeds_list:
3654
+ row = rows.get(seed)
3655
+ if not row:
3656
+ continue
3657
+ include = True
3658
+ for key, expected in filters.items():
3659
+ actual = row.get(key)
3660
+ if actual is None:
3661
+ include = False
3662
+ break
3663
+ if str(actual).lower() != expected.lower():
3664
+ include = False
3665
+ break
3666
+ if include:
3667
+ filtered.append(seed)
3668
+ return filtered
3669
+
3670
+ def _apply_metadata_sql(
3671
+ rows: dict[int, dict[str, Any]], seeds_list: list[int], query: str
3672
+ ) -> list[int]:
3673
+ """Return seeds that satisfy an arbitrary SQL query.
3674
+
3675
+ The query is executed against an in-memory SQLite table named `tasks`
3676
+ with columns (seed INTEGER, scenario_id TEXT, loop_id TEXT, thread_id TEXT, difficulty TEXT).
3677
+ Seeds whose value appears in the result set (the `seed` column, or the first column when `seed` is absent) are kept, in their original order.
3678
+ """
3679
+ if not query:
3680
+ return seeds_list
3681
+ conn = sqlite3.connect(":memory:")
3682
+ try:
3683
+ cur = conn.cursor()
3684
+ cur.execute(
3685
+ "CREATE TABLE tasks (seed INTEGER, scenario_id TEXT, loop_id TEXT, thread_id TEXT, difficulty TEXT)"
3686
+ )
3687
+ insert_stmt = (
3688
+ "INSERT INTO tasks (seed, scenario_id, loop_id, thread_id, difficulty) VALUES (?,?,?,?,?)"
3689
+ )
3690
+ for seed in seeds_list:
3691
+ row = rows.get(seed, {})
3692
+ cur.execute(
3693
+ insert_stmt,
3694
+ [
3695
+ seed,
3696
+ row.get("scenario_id"),
3697
+ row.get("loop_id"),
3698
+ row.get("thread_id"),
3699
+ row.get("difficulty"),
3700
+ ],
3701
+ )
3702
+
3703
+ result = cur.execute(query)
3704
+ fetched = result.fetchall()
3705
+ if not fetched:
3706
+ return []
3707
+ description = result.description or []
3708
+ col_names = [col[0] for col in description]
3709
+ seeds_out: list[int] = []
3710
+ for entry in fetched:
3711
+ value = entry[col_names.index("seed")] if "seed" in col_names else entry[0]
3712
+ try:
3713
+ seeds_out.append(int(value))
3714
+ except Exception as exc:
3715
+ raise click.ClickException(
3716
+ "metadata SQL query must return seed integers"
3717
+ ) from exc
3718
+ seeds_set = set(seeds_out)
3719
+ return [seed for seed in seeds_list if seed in seeds_set]
3720
+ except sqlite3.Error as exc:
3721
+ raise click.ClickException(f"Failed to execute metadata SQL query: {exc}") from exc
3722
+ finally:
3723
+ conn.close()
3724
+
3725
+ async def _run_eval() -> None:
3726
+ nonlocal successes, failures, outcome_sum, outcome_count, outcome_correct, records, seed_values
3727
+
3728
+ if trace_tracer is not None and trace_tracer.db is None:
3729
+ await trace_tracer.initialize()
3730
+
3731
+ if task_app_url is None:
3732
+ transport = httpx.ASGITransport(app=app) # type: ignore[name-defined]
3733
+ async_client = httpx.AsyncClient(
3734
+ transport=cast(Any, transport),
3735
+ base_url="http://eval.local",
3736
+ timeout=300.0,
3737
+ follow_redirects=True,
3738
+ headers=headers,
3739
+ )
3740
+ else:
3741
+ async_client = httpx.AsyncClient(
3742
+ base_url=task_app_url,
3743
+ timeout=300.0,
3744
+ follow_redirects=True,
3745
+ headers=headers,
3746
+ )
3747
+
3748
+ try:
3749
+ taskset_payload: dict[str, Any] | None = None
2704
3750
  try:
2705
- resp = client.post("/rollout", json=body)
2706
- ok = 200 <= resp.status_code < 300
3751
+ task_info_response = await async_client.get("/task_info")
3752
+ except Exception:
3753
+ task_info_response = None
3754
+ if task_info_response is not None and task_info_response.status_code == 200:
3755
+ with contextlib.suppress(Exception):
3756
+ payload_json = task_info_response.json()
3757
+ if isinstance(payload_json, dict) and "taskset" in payload_json:
3758
+ taskset_payload = payload_json.get("taskset")
3759
+ if not isinstance(taskset_payload, dict):
3760
+ taskset_payload = None
3761
+ elif isinstance(payload_json, dict):
3762
+ taskset_payload = payload_json
3763
+
3764
+ available_seeds = list(seed_values)
3765
+ if metadata_sql_query or metadata_filters:
3766
+ if not taskset_payload:
3767
+ raise click.ClickException(
3768
+ "Task metadata filters require the task app to expose /task_info metadata"
3769
+ )
3770
+ rows = _build_task_rows(taskset_payload)
3771
+ if metadata_sql_query:
3772
+ available_seeds = _apply_metadata_sql(rows, available_seeds, metadata_sql_query)
3773
+ if metadata_filters:
3774
+ available_seeds = _apply_metadata_filters(rows, available_seeds, metadata_filters)
3775
+ if not available_seeds:
3776
+ raise click.ClickException("No seeds match the provided metadata filters")
3777
+ seed_values = available_seeds
3778
+
3779
+ semaphore = asyncio.Semaphore(concurrency_limit)
3780
+
3781
+ async def _run_seed(seed_val: int) -> None:
3782
+ nonlocal successes, failures, outcome_sum, outcome_count, outcome_correct, records
3783
+ body = {
3784
+ "run_id": str(uuid.uuid4()),
3785
+ "env": {"config": {"split": split, "index": seed_val}, "seed": seed_val},
3786
+ "policy": {
3787
+ "policy_name": selected_model,
3788
+ "config": {"model": selected_model, **policy_overrides},
3789
+ },
3790
+ "ops": [],
3791
+ }
3792
+ rollout_elapsed: float | None = None
3793
+ rollout_start = time.perf_counter()
3794
+ try:
3795
+ async with semaphore:
3796
+ response = await async_client.post("/rollout", json=body)
3797
+ rollout_elapsed = time.perf_counter() - rollout_start
3798
+ except Exception as exc:
3799
+ failures += 1
3800
+ click.echo(f"seed={seed_val} error={exc}")
3801
+ return
3802
+
3803
+ ok = 200 <= response.status_code < 300
2707
3804
  if ok:
2708
3805
  successes += 1
2709
3806
  else:
2710
3807
  failures += 1
2711
3808
 
2712
- # Print summary with any available metrics/tool calls
2713
- summary = [f"seed={seed_val}", f"status={resp.status_code}"]
3809
+ summary = [f"seed={seed_val}", f"status={response.status_code}"]
3810
+ data: Any
2714
3811
  try:
2715
- data = resp.json()
3812
+ data = response.json()
2716
3813
  except Exception:
2717
3814
  data = None
3815
+
3816
+ metrics: dict[str, Any] | None = None
3817
+ completion: str | None = None
3818
+ prompt_index: int | None = None
3819
+ prompt_text: str | None = None
3820
+ task_id: str | None = None
3821
+ task_split: str | None = None
3822
+ task_rubric_id: str | None = None
3823
+
3824
+ trace_namespace: dict[str, Any] | None = None
3825
+ session_trace_dict: dict[str, Any] | None = None
3826
+
2718
3827
  if isinstance(data, dict):
3828
+ trace_namespace = data.get("trace")
3829
+ if not isinstance(trace_namespace, dict):
3830
+ raise RuntimeError(
3831
+ "rollout response missing trace payload; task app must return tracing_v3 data"
3832
+ )
3833
+ session_trace_dict = trace_namespace.get("session_trace")
3834
+ if not isinstance(session_trace_dict, dict):
3835
+ raise RuntimeError(
3836
+ "rollout response trace missing 'session_trace'; ensure the task app is serving the tracing_v3 build"
3837
+ )
2719
3838
  metrics = data.get("metrics") if isinstance(data.get("metrics"), dict) else None
2720
3839
  if metrics:
2721
3840
  mean_return = metrics.get("mean_return") or metrics.get("total_reward")
@@ -2724,7 +3843,6 @@ def eval_command(
2724
3843
  summary.append(f"mean_return={mean_return}")
2725
3844
  if outcome is not None:
2726
3845
  summary.append(f"outcome={outcome}")
2727
- # Aggregate outcome stats
2728
3846
  try:
2729
3847
  val = float(outcome)
2730
3848
  outcome_sum += val
@@ -2733,7 +3851,6 @@ def eval_command(
2733
3851
  outcome_correct += 1
2734
3852
  except Exception:
2735
3853
  pass
2736
- # Try to infer tool call count from first trajectory step
2737
3854
  trajs = (
2738
3855
  data.get("trajectories")
2739
3856
  if isinstance(data.get("trajectories"), list)
@@ -2747,38 +3864,163 @@ def eval_command(
2747
3864
  tool_calls = step0.get("tool_calls") or step0.get("tools") or []
2748
3865
  if isinstance(tool_calls, list):
2749
3866
  summary.append(f"tool_calls={len(tool_calls)}")
3867
+ obs = step0.get("obs") if isinstance(step0, dict) else None
3868
+ if isinstance(obs, dict):
3869
+ idx_val = obs.get("prompt_index")
3870
+ if isinstance(idx_val, int):
3871
+ prompt_index = idx_val
3872
+ prompt_raw = obs.get("prompt")
3873
+ if isinstance(prompt_raw, str):
3874
+ prompt_text = prompt_raw
3875
+ if task_id is None:
3876
+ candidate_id = obs.get("task_id")
3877
+ if isinstance(candidate_id, str) and candidate_id:
3878
+ task_id = candidate_id
3879
+ if task_split is None:
3880
+ candidate_split = obs.get("task_split")
3881
+ if isinstance(candidate_split, str) and candidate_split:
3882
+ task_split = candidate_split
3883
+ if task_rubric_id is None:
3884
+ candidate_rid = obs.get("task_rubric_id")
3885
+ if isinstance(candidate_rid, str) and candidate_rid:
3886
+ task_rubric_id = candidate_rid
3887
+ final = first.get("final") if isinstance(first, dict) else None
3888
+ if isinstance(final, dict):
3889
+ final_obs = final.get("observation")
3890
+ if isinstance(final_obs, dict):
3891
+ comp_val = final_obs.get("completion")
3892
+ if isinstance(comp_val, str):
3893
+ completion = comp_val
3894
+ if task_id is None:
3895
+ candidate_id = final_obs.get("task_id")
3896
+ if isinstance(candidate_id, str) and candidate_id:
3897
+ task_id = candidate_id
3898
+ if task_split is None:
3899
+ candidate_split = final_obs.get("task_split")
3900
+ if isinstance(candidate_split, str) and candidate_split:
3901
+ task_split = candidate_split
3902
+ if task_rubric_id is None:
3903
+ candidate_rid = final_obs.get("task_rubric_id")
3904
+ if isinstance(candidate_rid, str) and candidate_rid:
3905
+ task_rubric_id = candidate_rid
3906
+ final_info = final.get("info")
3907
+ if isinstance(final_info, dict):
3908
+ if task_id is None:
3909
+ candidate_id = final_info.get("task_id")
3910
+ if isinstance(candidate_id, str) and candidate_id:
3911
+ task_id = candidate_id
3912
+ if task_split is None:
3913
+ candidate_split = final_info.get("task_split")
3914
+ if isinstance(candidate_split, str) and candidate_split:
3915
+ task_split = candidate_split
3916
+ if task_rubric_id is None:
3917
+ candidate_rid = final_info.get("task_rubric_id")
3918
+ if isinstance(candidate_rid, str) and candidate_rid:
3919
+ task_rubric_id = candidate_rid
3920
+ if task_id:
3921
+ summary.append(f"task_id={task_id}")
2750
3922
  click.echo(" ".join(summary))
2751
- # Print the full response JSON (trace, trajectories, metrics)
2752
3923
  with contextlib.suppress(Exception):
2753
3924
  click.echo(json.dumps(data, indent=2))
2754
3925
  else:
2755
3926
  click.echo(" ".join(summary))
2756
- except Exception as exc:
2757
- failures += 1
2758
- click.echo(f"seed={seed_val} error={exc}")
2759
3927
 
2760
- finally:
2761
- try:
2762
- client.close()
2763
- except AttributeError:
2764
- transport_obj = getattr(client, "_transport", None)
2765
- if transport_obj and hasattr(transport_obj, "aclose"):
2766
- try:
2767
- asyncio.run(transport_obj.aclose())
2768
- except RuntimeError:
2769
- # Fallback when already inside a running loop (rare for CLI).
2770
- new_loop = asyncio.new_event_loop()
3928
+ official_score = None
3929
+ if isinstance(metrics, dict):
3930
+ for key in ("mean_return", "total_reward", "outcome_score"):
3931
+ val = metrics.get(key)
3932
+ if isinstance(val, int | float):
3933
+ official_score = float(val)
3934
+ break
3935
+ if official_score is None and isinstance(data, dict):
2771
3936
  try:
2772
- new_loop.run_until_complete(transport_obj.aclose())
2773
- finally:
2774
- new_loop.close()
2775
- except Exception:
2776
- pass
3937
+ reward_val = data["trajectories"][0]["steps"][0].get("reward")
3938
+ if isinstance(reward_val, int | float):
3939
+ official_score = float(reward_val)
3940
+ except Exception:
3941
+ pass
3942
+
3943
+ if official_score is not None:
3944
+ if official_score < 0.0:
3945
+ official_score = 0.0
3946
+ elif official_score > 1.0:
3947
+ official_score = min(1.0, official_score)
3948
+
3949
+ judge_scores: dict[str, float | None] = {}
3950
+ judges_timings: dict[str, float | None] = {}
3951
+ timings: dict[str, Any] = {
3952
+ "rollout_s": rollout_elapsed,
3953
+ "judges": judges_timings,
3954
+ }
3955
+ if judge_specs:
3956
+ for spec in judge_specs:
3957
+ score_value: float | None = None
3958
+ judge_elapsed: float | None = None
3959
+ if completion is not None:
3960
+ judge_payload = {
3961
+ "seed": seed_val,
3962
+ "prompt_index": prompt_index,
3963
+ "prompt": prompt_text,
3964
+ "completion": completion,
3965
+ "metrics": metrics,
3966
+ "response": data,
3967
+ "trace": trace_namespace,
3968
+ }
3969
+ try:
3970
+ judge_start = time.perf_counter()
3971
+ result = spec.fn(judge_payload, **spec.kwargs)
3972
+ judge_elapsed = time.perf_counter() - judge_start
3973
+ if isinstance(result, int | float):
3974
+ score_value = float(result)
3975
+ except Exception as exc:
3976
+ if judge_elapsed is None:
3977
+ judge_elapsed = time.perf_counter() - judge_start
3978
+ click.echo(f"seed={seed_val} judge[{spec.name}]_error={exc}")
3979
+ judges_timings[spec.name] = judge_elapsed
3980
+ judge_scores[spec.name] = score_value
3981
+
3982
+ if trace_tracer is not None and trace_namespace:
3983
+ storage_metadata = {
3984
+ "eval_seed": seed_val,
3985
+ "prompt_index": prompt_index,
3986
+ "task_id": task_id,
3987
+ "task_split": task_split,
3988
+ "task_rubric_id": task_rubric_id,
3989
+ "official_score": official_score,
3990
+ "judge_scores": judge_scores,
3991
+ "model": selected_model,
3992
+ "prompt": prompt_text,
3993
+ "completion": completion,
3994
+ }
3995
+ await _store_trace(trace_tracer, trace_namespace, storage_metadata)
3996
+
3997
+ records.append(
3998
+ {
3999
+ "seed": seed_val,
4000
+ "prompt_index": prompt_index,
4001
+ "task_id": task_id,
4002
+ "task_split": task_split,
4003
+ "task_rubric_id": task_rubric_id,
4004
+ "official_score": official_score,
4005
+ "judge_scores": judge_scores,
4006
+ "timings": timings,
4007
+ }
4008
+ )
4009
+
4010
+ await asyncio.gather(*[_run_seed(seed_val) for seed_val in seed_values])
4011
+ finally:
4012
+ await async_client.aclose()
4013
+
4014
+ try:
4015
+ asyncio.run(_run_eval())
4016
+ finally:
4017
+ if trace_tracer is not None and trace_tracer.db is not None:
4018
+ asyncio.run(trace_tracer.db.close())
2777
4019
 
2778
4020
  click.echo(
2779
4021
  f"Eval complete: {successes} ok, {failures} failed; model={selected_model}, split={split}"
2780
4022
  )
2781
- # Print outcome summary if any successes
4023
+
2782
4024
  if outcome_count > 0:
2783
4025
  mean_outcome = outcome_sum / float(outcome_count)
2784
4026
  frac_right = outcome_correct / float(outcome_count)
@@ -2786,6 +4028,285 @@ def eval_command(
2786
4028
  f"Outcome summary: correct={outcome_correct}/{outcome_count} ({frac_right:.2%}), mean_outcome={mean_outcome:.3f}"
2787
4029
  )
2788
4030
 
4031
+ if records:
4032
+ judge_specs = judge_specs or [] # ensure iterable
4033
+ official_scores = [
4034
+ r["official_score"] for r in records if r["official_score"] is not None
4035
+ ]
4036
+ if official_scores:
4037
+ click.echo(f" Official mean: {sum(official_scores) / len(official_scores):.3f}")
4038
+ else:
4039
+ click.echo(" Official mean: n/a")
4040
+
4041
+ for spec in judge_specs:
4042
+ spec_scores = [
4043
+ record["judge_scores"].get(spec.name)
4044
+ for record in records
4045
+ if record["judge_scores"].get(spec.name) is not None
4046
+ ]
4047
+ if spec_scores:
4048
+ mean_spec = sum(spec_scores) / len(spec_scores)
4049
+ click.echo(f" [{spec.name}] mean: {mean_spec:.3f}")
4050
+ else:
4051
+ click.echo(f" [{spec.name}] mean: n/a")
4052
+
4053
+ paired = [
4054
+ (
4055
+ record["official_score"],
4056
+ record["judge_scores"].get(spec.name),
4057
+ )
4058
+ for record in records
4059
+ if record["official_score"] is not None
4060
+ and record["judge_scores"].get(spec.name) is not None
4061
+ ]
4062
+ if len(paired) >= 2:
4063
+ corr = _pearson(
4064
+ [p[0] for p in paired if p[0] is not None],
4065
+ [p[1] for p in paired if p[1] is not None],
4066
+ )
4067
+ if corr is not None:
4068
+ click.echo(f" Pearson r: {corr:.3f}")
4069
+ else:
4070
+ click.echo(" Pearson r: undefined (zero variance)")
4071
+ else:
4072
+ click.echo(" Pearson r: n/a (need ≥2 paired scores)")
4073
+
4074
+ header = ["Seed", "Prompt", "Official"]
4075
+ header.extend(spec.name for spec in judge_specs)
4076
+ rows: list[list[str]] = []
4077
+ for record in sorted(records, key=lambda r: (r["seed"], r.get("prompt_index") or -1)):
4078
+ seed_val = str(record["seed"])
4079
+ prompt_idx = (
4080
+ str(record["prompt_index"])
4081
+ if record["prompt_index"] is not None
4082
+ else "-"
4083
+ )
4084
+ official_val = (
4085
+ f"{record['official_score']:.3f}"
4086
+ if record["official_score"] is not None
4087
+ else "-"
4088
+ )
4089
+ row = [seed_val, prompt_idx, official_val]
4090
+ for spec in judge_specs:
4091
+ score_val = record["judge_scores"].get(spec.name)
4092
+ row.append(f"{score_val:.3f}" if isinstance(score_val, int | float) else "-")
4093
+ rows.append(row)
4094
+
4095
+ widths = [len(col) for col in header]
4096
+ for row in rows:
4097
+ for idx, cell in enumerate(row):
4098
+ widths[idx] = max(widths[idx], len(cell))
4099
+
4100
+ click.echo("")
4101
+ click.echo(" ".join(h.ljust(widths[idx]) for idx, h in enumerate(header)))
4102
+ click.echo(" ".join("-" * widths[idx] for idx in range(len(header))))
4103
+ for row in rows:
4104
+ click.echo(" ".join(cell.ljust(widths[idx]) for idx, cell in enumerate(row)))
4105
+
4106
+
4107
+
4108
+ @click.command(
4109
+ "filter",
4110
+ help="Export filtered tracing sessions to SFT-ready JSONL based on a TOML config.",
4111
+ )
4112
+ @click.option(
4113
+ "--config",
4114
+ "config_path",
4115
+ type=click.Path(),
4116
+ required=True,
4117
+ help="Path to TOML config describing the input trace DB, score thresholds, and output JSONL.",
4118
+ )
4119
+ def filter_command(config_path: str) -> None:
4120
+ """Render tracing sessions that match filter rules into SFT JSONL.
4121
+
4122
+ The TOML file should contain a `[filter]` table with at least:
4123
+
4124
+ db = \"path/to/traces.db\" # sqlite path or URL (sqlite+aiosqlite://...)
4125
+ output = \"ft_data/out.jsonl\" # destination JSONL
4126
+
4127
+ Optional keys such as `splits`, `task_ids`, `models`, `min_official_score`, or
4128
+ `min_judge_scores.my_judge = 0.7` allow you to narrow the dataset down to
4129
+ high-quality traces. See `customers/agora_single_file/configs/filter_local.toml`
4130
+ for a working example.
4131
+ """
4132
+ if _toml is None:
4133
+ raise click.ClickException("TOML parser not available; install tomli or use Python 3.11+")
4134
+
4135
+ cfg_path = Path(config_path)
4136
+ if not cfg_path.exists():
4137
+ raise click.ClickException(f"Filter config not found: {cfg_path}")
4138
+
4139
+ try:
4140
+ config_data = _toml.loads(cfg_path.read_text(encoding="utf-8"))
4141
+ except Exception as exc:
4142
+ raise click.ClickException(f"Failed to parse TOML '{cfg_path}': {exc}") from exc
4143
+
4144
+ filter_cfg = config_data.get("filter") if isinstance(config_data, dict) else None
4145
+ if not isinstance(filter_cfg, dict):
4146
+ raise click.ClickException("Config must contain a [filter] table")
4147
+
4148
+ db_value = str(filter_cfg.get("db", "traces/v3/eval_traces.db")).strip()
4149
+ if not db_value:
4150
+ raise click.ClickException("filter.db must be provided")
4151
+ if "://" in db_value:
4152
+ db_url = db_value
4153
+ else:
4154
+ db_path = Path(db_value).expanduser()
4155
+ db_path.parent.mkdir(parents=True, exist_ok=True)
4156
+ db_url = f"sqlite+aiosqlite:///{db_path}"
4157
+
4158
+ output_value = filter_cfg.get("output")
4159
+ if not output_value:
4160
+ raise click.ClickException("filter.output must be provided")
4161
+ output_path = Path(str(output_value)).expanduser()
4162
+
4163
+ splits = set(filter_cfg.get("splits", []) or [])
4164
+ task_ids = set(filter_cfg.get("task_ids", []) or [])
4165
+ models = set(filter_cfg.get("models", []) or [])
4166
+ min_official = filter_cfg.get("min_official_score")
4167
+ max_official = filter_cfg.get("max_official_score")
4168
+ if min_official is not None:
4169
+ try:
4170
+ min_official = float(min_official)
4171
+ except Exception as err:
4172
+ raise click.ClickException("filter.min_official_score must be numeric") from err
4173
+ if max_official is not None:
4174
+ try:
4175
+ max_official = float(max_official)
4176
+ except Exception as err:
4177
+ raise click.ClickException("filter.max_official_score must be numeric") from err
4178
+ min_judge_scores = filter_cfg.get("min_judge_scores", {}) or {}
4179
+ max_judge_scores = filter_cfg.get("max_judge_scores", {}) or {}
4180
+ try:
4181
+ min_judge_scores = {k: float(v) for k, v in min_judge_scores.items()}
4182
+ except Exception as err:
4183
+ raise click.ClickException("filter.min_judge_scores values must be numeric") from err
4184
+ try:
4185
+ max_judge_scores = {k: float(v) for k, v in max_judge_scores.items()}
4186
+ except Exception as err:
4187
+ raise click.ClickException("filter.max_judge_scores values must be numeric") from err
4188
+ min_created = _parse_datetime_for_trace(filter_cfg.get("min_created_at"))
4189
+ max_created = _parse_datetime_for_trace(filter_cfg.get("max_created_at"))
4190
+ limit = filter_cfg.get("limit")
4191
+ if limit is not None:
4192
+ try:
4193
+ limit = int(limit)
4194
+ except Exception as err:
4195
+ raise click.ClickException("filter.limit must be an integer") from err
4196
+
4197
+ def _score_ok(value: Any, min_val: Any, max_val: Any) -> bool:
4198
+ try:
4199
+ if value is None:
4200
+ return min_val is None
4201
+ value = float(value)
4202
+ except Exception:
4203
+ return False
4204
+ if min_val is not None and value < float(min_val):
4205
+ return False
4206
+ return not (max_val is not None and value > float(max_val))
4207
+
4208
+ async def _run_filter() -> None:
4209
+ tracer = SessionTracer(db_url=db_url, auto_save=False)
4210
+ await tracer.initialize()
4211
+
4212
+ df = await tracer.db.query_traces(
4213
+ "SELECT session_id, created_at, metadata FROM session_traces ORDER BY created_at"
4214
+ )
4215
+ if getattr(df, "empty", True):
4216
+ raise click.ClickException("No traces found in database")
4217
+
4218
+ sessions = df.to_dict("records")
4219
+ accepted: list[dict[str, Any]] = []
4220
+
4221
+ for row in sessions:
4222
+ metadata_raw = row.get("metadata")
4223
+ if isinstance(metadata_raw, str):
4224
+ try:
4225
+ metadata = json.loads(metadata_raw)
4226
+ except Exception:
4227
+ metadata = {}
4228
+ elif isinstance(metadata_raw, dict):
4229
+ metadata = dict(metadata_raw)
4230
+ else:
4231
+ metadata = {}
4232
+
4233
+ created_at_raw = row.get("created_at")
4234
+ created_at_dt = _parse_datetime_for_trace(created_at_raw)
4235
+
4236
+ session_id = row.get("session_id")
4237
+
4238
+ if splits and metadata.get("task_split") not in splits:
4239
+ continue
4240
+ if task_ids and metadata.get("task_id") not in task_ids:
4241
+ continue
4242
+ if models and metadata.get("model") not in models:
4243
+ continue
4244
+
4245
+ if min_created and (created_at_dt is None or created_at_dt < min_created):
4246
+ continue
4247
+ if max_created and (created_at_dt is None or created_at_dt > max_created):
4248
+ continue
4249
+
4250
+ if not _score_ok(metadata.get("official_score"), min_official, max_official):
4251
+ continue
4252
+
4253
+ judge_scores = metadata.get("judge_scores") or {}
4254
+ include = True
4255
+ for judge_name, threshold in (min_judge_scores or {}).items():
4256
+ if not _score_ok(judge_scores.get(judge_name), threshold, None):
4257
+ include = False
4258
+ break
4259
+ if not include:
4260
+ continue
4261
+ for judge_name, threshold in (max_judge_scores or {}).items():
4262
+ if not _score_ok(judge_scores.get(judge_name), None, threshold):
4263
+ include = False
4264
+ break
4265
+ if not include:
4266
+ continue
4267
+
4268
+ prompt = metadata.get("prompt") or ""
4269
+ completion = metadata.get("completion") or ""
4270
+ if not prompt or not completion:
4271
+ continue
4272
+
4273
+ record = {
4274
+ "messages": [
4275
+ {"role": "user", "content": str(prompt)},
4276
+ {"role": "assistant", "content": str(completion)},
4277
+ ],
4278
+ "metadata": {
4279
+ "session_id": session_id,
4280
+ "task_id": metadata.get("task_id"),
4281
+ "task_split": metadata.get("task_split"),
4282
+ "task_rubric_id": metadata.get("task_rubric_id"),
4283
+ "official_score": metadata.get("official_score"),
4284
+ "judge_scores": judge_scores,
4285
+ "model": metadata.get("model"),
4286
+ "created_at": created_at_dt.isoformat() if created_at_dt else created_at_raw,
4287
+ "prompt": prompt,
4288
+ "completion": completion,
4289
+ },
4290
+ }
4291
+ accepted.append(record)
4292
+
4293
+ if not accepted:
4294
+ raise click.ClickException("No sessions matched the provided filters")
4295
+
4296
+ if limit is not None and limit > 0:
4297
+ accepted = accepted[:limit]
4298
+
4299
+ output_path.parent.mkdir(parents=True, exist_ok=True)
4300
+ with output_path.open("w", encoding="utf-8") as handle:
4301
+ for item in accepted:
4302
+ handle.write(json.dumps(item, ensure_ascii=False))
4303
+ handle.write("\n")
4304
+
4305
+ click.echo(f"Wrote {len(accepted)} examples -> {output_path}")
4306
+ await tracer.db.close()
4307
+
4308
+ asyncio.run(_run_filter())
4309
+
2789
4310
 
2790
4311
  def register_eval(cli: click.Group) -> None:
2791
4312
  cli.add_command(eval_command)
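For reference, a minimal `[filter]` config covering the keys the new `filter` command reads (paths, model names, and thresholds below are illustrative):

    [filter]
    db = "traces/v3/eval_traces.db"
    output = "ft_data/crafter_sft.jsonl"
    splits = ["train"]
    models = ["groq:llama-3.1-70b-versatile"]
    min_official_score = 0.8
    limit = 500

    [filter.min_judge_scores]
    quality = 0.7

Run with something like `uvx synth-ai filter --config configs/filter_local.toml`; sessions whose stored metadata passes every filter are written to the output JSONL as user/assistant message pairs plus their score metadata.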