synth-ai 0.2.16__py3-none-any.whl → 0.2.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic; see the registry's release notice for this version for more details.
- examples/analyze_semantic_words.sh +2 -2
- examples/baseline/banking77_baseline.py +204 -0
- examples/baseline/crafter_baseline.py +407 -0
- examples/baseline/pokemon_red_baseline.py +326 -0
- examples/baseline/simple_baseline.py +56 -0
- examples/baseline/warming_up_to_rl_baseline.py +239 -0
- examples/blog_posts/gepa/README.md +355 -0
- examples/blog_posts/gepa/configs/banking77_gepa_local.toml +95 -0
- examples/blog_posts/gepa/configs/banking77_gepa_test.toml +82 -0
- examples/blog_posts/gepa/configs/banking77_mipro_local.toml +52 -0
- examples/blog_posts/gepa/configs/hotpotqa_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/hotpotqa_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/hotpotqa_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/hover_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/hover_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/hover_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/ifbench_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/ifbench_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/ifbench_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/pupa_gepa_local.toml +60 -0
- examples/blog_posts/gepa/configs/pupa_mipro_local.toml +54 -0
- examples/blog_posts/gepa/deploy_banking77_task_app.sh +41 -0
- examples/blog_posts/gepa/gepa_baseline.py +204 -0
- examples/blog_posts/gepa/query_prompts_example.py +97 -0
- examples/blog_posts/gepa/run_gepa_banking77.sh +87 -0
- examples/blog_posts/gepa/task_apps.py +105 -0
- examples/blog_posts/gepa/test_gepa_local.sh +67 -0
- examples/blog_posts/gepa/verify_banking77_setup.sh +123 -0
- examples/blog_posts/pokemon_vl/README.md +98 -0
- examples/blog_posts/pokemon_vl/configs/eval_gpt5nano.toml +26 -0
- examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +27 -0
- examples/blog_posts/pokemon_vl/configs/eval_rl_final.toml +24 -0
- examples/blog_posts/pokemon_vl/configs/filter_high_reward.toml +10 -0
- examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +43 -0
- examples/blog_posts/pokemon_vl/configs/train_sft_qwen4b_vl.toml +40 -0
- examples/blog_posts/pokemon_vl/extract_images.py +239 -0
- examples/blog_posts/pokemon_vl/pokemon_vl_baseline.py +326 -0
- examples/blog_posts/pokemon_vl/run_eval_extract_images.py +209 -0
- examples/blog_posts/pokemon_vl/run_qwen_eval_extract_images.py +212 -0
- examples/blog_posts/pokemon_vl/text_box_analysis.md +106 -0
- examples/blog_posts/warming_up_to_rl/ARCHITECTURE.md +195 -0
- examples/blog_posts/warming_up_to_rl/FINAL_TEST_RESULTS.md +127 -0
- examples/blog_posts/warming_up_to_rl/INFERENCE_SUCCESS.md +132 -0
- examples/blog_posts/warming_up_to_rl/README.md +158 -0
- examples/blog_posts/warming_up_to_rl/SMOKE_TESTING.md +164 -0
- examples/blog_posts/warming_up_to_rl/SMOKE_TEST_COMPLETE.md +253 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_baseline_qwen32b_10x20.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b_10x20.toml +26 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_groq_qwen32b.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_openai_gpt_oss_120b.toml +29 -0
- examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +10 -0
- examples/blog_posts/warming_up_to_rl/configs/smoke_test.toml +75 -0
- examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +91 -0
- examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +40 -0
- examples/blog_posts/warming_up_to_rl/warming_up_to_rl_baseline.py +187 -0
- examples/dev/qwen3_32b_qlora_4xh100.toml +5 -0
- examples/multi_step/configs/VERILOG_REWARDS.md +4 -0
- examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +4 -0
- examples/multi_step/configs/crafter_rl_outcome.toml +2 -1
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +65 -107
- examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +2 -1
- examples/multi_step/configs/crafter_rl_stepwise_simple.toml +2 -1
- examples/multi_step/configs/crafter_rl_stepwise_simple_NEW_FORMAT.toml +105 -0
- examples/multi_step/configs/verilog_rl_lora.toml +80 -123
- examples/qwen_coder/configs/coder_lora_30b.toml +1 -3
- examples/qwen_coder/configs/coder_lora_4b.toml +4 -1
- examples/qwen_coder/configs/coder_lora_small.toml +1 -3
- examples/qwen_vl/README.md +10 -12
- examples/qwen_vl/SETUP_COMPLETE.md +7 -8
- examples/qwen_vl/VISION_TESTS_COMPLETE.md +2 -3
- examples/qwen_vl/collect_data_via_cli.md +76 -84
- examples/qwen_vl/collect_vision_traces.py +4 -4
- examples/qwen_vl/configs/crafter_rl_vision_qwen3vl4b.toml +40 -57
- examples/qwen_vl/configs/crafter_vlm_sft_example.toml +1 -2
- examples/qwen_vl/configs/eval_gpt4o_mini_vision.toml +20 -37
- examples/qwen_vl/configs/eval_gpt5nano_vision.toml +21 -40
- examples/qwen_vl/configs/eval_qwen3vl_vision.toml +26 -0
- examples/qwen_vl/configs/{filter_qwen2vl_sft.toml → filter_qwen3vl_sft.toml} +4 -5
- examples/qwen_vl/configs/filter_vision_sft.toml +2 -3
- examples/qwen_vl/crafter_qwen_vl_agent.py +5 -5
- examples/qwen_vl/run_vision_comparison.sh +6 -7
- examples/rl/README.md +5 -5
- examples/rl/configs/rl_from_base_qwen.toml +26 -1
- examples/rl/configs/rl_from_base_qwen17.toml +6 -2
- examples/rl/task_app/README.md +1 -2
- examples/rl/task_app/math_single_step.py +2 -2
- examples/run_crafter_demo.sh +2 -2
- examples/sft/README.md +1 -1
- examples/sft/configs/crafter_fft_qwen0p6b.toml +4 -1
- examples/sft/configs/crafter_lora_qwen0p6b.toml +4 -1
- examples/swe/task_app/README.md +32 -2
- examples/swe/task_app/grpo_swe_mini.py +4 -0
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +1 -1
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +37 -10
- examples/swe/task_app/hosted/inference/openai_client.py +4 -38
- examples/swe/task_app/hosted/policy_routes.py +17 -0
- examples/swe/task_app/hosted/rollout.py +4 -2
- examples/swe/task_app/morph_backend.py +178 -0
- examples/task_apps/banking77/__init__.py +6 -0
- examples/task_apps/banking77/banking77_task_app.py +841 -0
- examples/task_apps/banking77/deploy_wrapper.py +46 -0
- examples/task_apps/crafter/CREATE_SFT_DATASET.md +4 -0
- examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +4 -0
- examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +4 -0
- examples/task_apps/crafter/task_app/README.md +1 -1
- examples/task_apps/crafter/task_app/grpo_crafter.py +90 -5
- examples/task_apps/crafter/task_app/grpo_crafter_task_app.py +1 -1
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +4 -26
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/react_agent.py +1 -2
- examples/task_apps/crafter/task_app/synth_envs_hosted/hosted_app.py +49 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +372 -107
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +81 -12
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +82 -11
- examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +194 -1
- examples/task_apps/enron/task_app/grpo_enron_task_app.py +1 -1
- examples/task_apps/gepa_benchmarks/__init__.py +7 -0
- examples/task_apps/gepa_benchmarks/common.py +260 -0
- examples/task_apps/gepa_benchmarks/hotpotqa_task_app.py +507 -0
- examples/task_apps/gepa_benchmarks/hover_task_app.py +436 -0
- examples/task_apps/gepa_benchmarks/ifbench_task_app.py +563 -0
- examples/task_apps/gepa_benchmarks/pupa_task_app.py +460 -0
- examples/task_apps/math/README.md +1 -2
- examples/task_apps/pokemon_red/README.md +3 -4
- examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +4 -0
- examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +6 -5
- examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +1 -2
- examples/task_apps/pokemon_red/task_app.py +288 -39
- examples/task_apps/sokoban/README.md +2 -3
- examples/task_apps/verilog/eval_groq_qwen32b.toml +12 -14
- examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +1 -1
- examples/vlm/configs/crafter_vlm_gpt4o.toml +4 -1
- examples/warming_up_to_rl/configs/crafter_fft.toml +4 -1
- examples/warming_up_to_rl/configs/crafter_fft_4b.toml +0 -2
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +3 -2
- examples/warming_up_to_rl/run_local_rollout_traced.py +1 -1
- examples/warming_up_to_rl/task_app/README.md +1 -1
- examples/warming_up_to_rl/task_app/grpo_crafter.py +185 -5
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +3 -27
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +49 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +156 -45
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +37 -4
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +33 -3
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +67 -0
- examples/workflows/math_rl/configs/rl_from_base_qwen.toml +27 -0
- examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +6 -0
- synth_ai/api/train/builders.py +99 -4
- synth_ai/api/train/cli.py +516 -26
- synth_ai/api/train/config_finder.py +13 -2
- synth_ai/api/train/configs/__init__.py +23 -2
- synth_ai/api/train/configs/prompt_learning.py +442 -0
- synth_ai/api/train/configs/rl.py +61 -7
- synth_ai/api/train/configs/sft.py +6 -2
- synth_ai/api/train/configs/shared.py +59 -2
- synth_ai/api/train/task_app.py +1 -1
- synth_ai/api/train/validators.py +277 -0
- synth_ai/auth/credentials.py +119 -0
- synth_ai/baseline/__init__.py +25 -0
- synth_ai/baseline/config.py +209 -0
- synth_ai/baseline/discovery.py +214 -0
- synth_ai/baseline/execution.py +146 -0
- synth_ai/cli/__init__.py +94 -18
- synth_ai/cli/__main__.py +0 -0
- synth_ai/cli/claude.py +70 -0
- synth_ai/cli/codex.py +84 -0
- synth_ai/cli/commands/__init__.py +18 -0
- synth_ai/cli/commands/baseline/__init__.py +12 -0
- synth_ai/cli/commands/baseline/core.py +637 -0
- synth_ai/cli/commands/baseline/list.py +93 -0
- synth_ai/cli/commands/demo/__init__.py +6 -0
- synth_ai/cli/commands/demo/core.py +163 -0
- synth_ai/cli/commands/eval/__init__.py +19 -0
- synth_ai/cli/commands/eval/core.py +1112 -0
- synth_ai/cli/commands/eval/errors.py +81 -0
- synth_ai/cli/commands/eval/validation.py +133 -0
- synth_ai/cli/commands/filter/__init__.py +12 -0
- synth_ai/cli/commands/filter/core.py +424 -0
- synth_ai/cli/commands/filter/errors.py +55 -0
- synth_ai/cli/commands/filter/validation.py +77 -0
- synth_ai/cli/commands/help/__init__.py +177 -0
- synth_ai/cli/commands/help/core.py +72 -0
- synth_ai/cli/commands/smoke/__init__.py +7 -0
- synth_ai/cli/commands/smoke/core.py +1436 -0
- synth_ai/cli/commands/status/__init__.py +64 -0
- synth_ai/cli/commands/status/client.py +192 -0
- synth_ai/cli/commands/status/config.py +92 -0
- synth_ai/cli/commands/status/errors.py +20 -0
- synth_ai/cli/commands/status/formatters.py +164 -0
- synth_ai/cli/commands/status/subcommands/__init__.py +9 -0
- synth_ai/cli/commands/status/subcommands/files.py +79 -0
- synth_ai/cli/commands/status/subcommands/jobs.py +334 -0
- synth_ai/cli/commands/status/subcommands/models.py +79 -0
- synth_ai/cli/commands/status/subcommands/pricing.py +22 -0
- synth_ai/cli/commands/status/subcommands/runs.py +81 -0
- synth_ai/cli/commands/status/subcommands/summary.py +47 -0
- synth_ai/cli/commands/status/subcommands/usage.py +203 -0
- synth_ai/cli/commands/status/utils.py +114 -0
- synth_ai/cli/commands/train/__init__.py +53 -0
- synth_ai/cli/commands/train/core.py +21 -0
- synth_ai/cli/commands/train/errors.py +117 -0
- synth_ai/cli/commands/train/judge_schemas.py +200 -0
- synth_ai/cli/commands/train/judge_validation.py +305 -0
- synth_ai/cli/commands/train/validation.py +386 -0
- synth_ai/cli/demo.py +30 -158
- synth_ai/cli/deploy/__init__.py +43 -0
- synth_ai/cli/deploy.py +162 -0
- synth_ai/cli/eval/__init__.py +36 -0
- synth_ai/cli/eval/core.py +5 -0
- synth_ai/cli/eval/errors.py +31 -0
- synth_ai/cli/eval/validation.py +5 -0
- synth_ai/cli/filter/__init__.py +28 -0
- synth_ai/cli/filter/core.py +5 -0
- synth_ai/cli/filter/errors.py +23 -0
- synth_ai/cli/filter/validation.py +5 -0
- synth_ai/cli/legacy_root_backup.py +14 -8
- synth_ai/cli/modal_serve/__init__.py +12 -0
- synth_ai/cli/modal_serve/core.py +14 -0
- synth_ai/cli/modal_serve/errors.py +8 -0
- synth_ai/cli/modal_serve/validation.py +11 -0
- synth_ai/cli/opencode.py +107 -0
- synth_ai/cli/root.py +9 -5
- synth_ai/cli/serve/__init__.py +12 -0
- synth_ai/cli/serve/core.py +14 -0
- synth_ai/cli/serve/errors.py +8 -0
- synth_ai/cli/serve/validation.py +11 -0
- synth_ai/cli/setup.py +20 -265
- synth_ai/cli/status.py +7 -126
- synth_ai/cli/task_app_deploy.py +1 -10
- synth_ai/cli/task_app_modal_serve.py +4 -9
- synth_ai/cli/task_app_serve.py +4 -11
- synth_ai/cli/task_apps.py +51 -1480
- synth_ai/cli/train/__init__.py +12 -0
- synth_ai/cli/train/core.py +21 -0
- synth_ai/cli/train/errors.py +8 -0
- synth_ai/cli/train/validation.py +24 -0
- synth_ai/cli/train.py +1 -14
- synth_ai/demos/crafter/grpo_crafter_task_app.py +1 -1
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +1 -1
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +7 -4
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +9 -5
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +4 -3
- synth_ai/environments/examples/red/engine.py +33 -12
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +151 -179
- synth_ai/environments/examples/red/environment.py +26 -0
- synth_ai/environments/examples/red/trace_hooks_v3.py +168 -0
- synth_ai/http.py +12 -0
- synth_ai/judge_schemas.py +10 -10
- synth_ai/learning/__init__.py +10 -0
- synth_ai/learning/prompt_learning_client.py +276 -0
- synth_ai/learning/prompt_learning_types.py +184 -0
- synth_ai/learning/rl/client.py +3 -1
- synth_ai/pricing/__init__.py +2 -0
- synth_ai/pricing/model_pricing.py +57 -0
- synth_ai/streaming/__init__.py +29 -0
- synth_ai/streaming/config.py +94 -0
- synth_ai/streaming/handlers.py +518 -0
- synth_ai/streaming/streamer.py +320 -0
- synth_ai/streaming/types.py +95 -0
- synth_ai/task/apps/__init__.py +1 -0
- synth_ai/task/config.py +2 -0
- synth_ai/task/tracing_utils.py +25 -25
- synth_ai/task/validators.py +45 -9
- synth_ai/task_app_cfgs.py +21 -0
- synth_ai/tracing_v3/config.py +162 -19
- synth_ai/tracing_v3/constants.py +1 -1
- synth_ai/tracing_v3/db_config.py +24 -38
- synth_ai/tracing_v3/migration_helper.py +1 -2
- synth_ai/tracing_v3/storage/config.py +47 -13
- synth_ai/tracing_v3/storage/factory.py +3 -3
- synth_ai/tracing_v3/turso/daemon.py +113 -11
- synth_ai/tracing_v3/turso/native_manager.py +92 -16
- synth_ai/types.py +8 -0
- synth_ai/urls.py +11 -0
- synth_ai/utils/__init__.py +30 -1
- synth_ai/utils/agents.py +74 -0
- synth_ai/utils/bin.py +39 -0
- synth_ai/utils/cli.py +149 -5
- synth_ai/utils/env.py +40 -33
- synth_ai/utils/http.py +4 -1
- synth_ai/utils/json.py +72 -0
- synth_ai/utils/modal.py +285 -3
- synth_ai/utils/paths.py +48 -0
- synth_ai/utils/uvicorn.py +113 -0
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/METADATA +109 -6
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/RECORD +291 -142
- examples/qwen_vl/configs/eval_qwen2vl_vision.toml +0 -44
- synth_ai/cli/tui.py +0 -62
- synth_ai/tui/__init__.py +0 -5
- synth_ai/tui/__main__.py +0 -13
- synth_ai/tui/cli/__init__.py +0 -1
- synth_ai/tui/cli/query_experiments.py +0 -164
- synth_ai/tui/cli/query_experiments_v3.py +0 -164
- synth_ai/tui/dashboard.py +0 -911
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/top_level.txt +0 -0
synth_ai/cli/task_apps.py
CHANGED
|
@@ -9,19 +9,16 @@ import hashlib
|
|
|
9
9
|
import importlib
|
|
10
10
|
import importlib.util
|
|
11
11
|
import inspect
|
|
12
|
-
import json
|
|
13
12
|
import os
|
|
14
13
|
import shlex
|
|
15
14
|
import shutil
|
|
16
15
|
import signal
|
|
17
|
-
import sqlite3
|
|
18
16
|
import subprocess
|
|
19
17
|
import sys
|
|
20
18
|
import tempfile
|
|
21
19
|
import textwrap
|
|
22
20
|
import time
|
|
23
21
|
import types
|
|
24
|
-
import uuid
|
|
25
22
|
from collections.abc import Callable, Iterable, Iterator, Sequence
|
|
26
23
|
from dataclasses import dataclass
|
|
27
24
|
from datetime import UTC, datetime
|
|
@@ -35,6 +32,8 @@ except Exception: # pragma: no cover - fallback
|
|
|
35
32
|
|
|
36
33
|
import click
|
|
37
34
|
from click.exceptions import Abort
|
|
35
|
+
from synth_ai.cli.commands.eval import core as eval_core
|
|
36
|
+
from synth_ai.cli.commands.filter import core as filter_core
|
|
38
37
|
|
|
39
38
|
# Tracing imports - make conditional for optional dependencies
|
|
40
39
|
try:
|
|
@@ -269,20 +268,25 @@ def _markov_message_from_dict(payload: dict[str, Any]) -> SessionEventMarkovBlan
|
|
|
269
268
|
json_payload=content_payload.get("json_payload"),
|
|
270
269
|
)
|
|
271
270
|
raw_type = (payload.get("message_type") or "").lower()
|
|
272
|
-
|
|
271
|
+
original_type = payload.get("message_type") or raw_type
|
|
272
|
+
|
|
273
|
+
if raw_type in ("observation", "policy_system_prompt"):
|
|
273
274
|
normalized_type = "system"
|
|
274
|
-
elif raw_type
|
|
275
|
+
elif raw_type in ("action", "policy_tool_call"):
|
|
275
276
|
normalized_type = "assistant"
|
|
276
277
|
elif raw_type in {"user", "assistant", "system", "tool_use", "tool_result"}:
|
|
277
278
|
normalized_type = raw_type
|
|
278
279
|
else:
|
|
279
280
|
normalized_type = "system"
|
|
280
281
|
|
|
282
|
+
metadata = dict(payload.get("metadata") or {})
|
|
283
|
+
metadata["original_message_type"] = original_type
|
|
284
|
+
|
|
281
285
|
return SessionEventMarkovBlanketMessage(
|
|
282
286
|
content=content,
|
|
283
287
|
message_type=normalized_type,
|
|
284
288
|
time_record=_time_record_from_dict(payload.get("time_record")),
|
|
285
|
-
metadata=
|
|
289
|
+
metadata=metadata,
|
|
286
290
|
)
|
|
287
291
|
|
|
288
292
|
|
|
@@ -506,49 +510,6 @@ def _candidate_search_roots() -> list[Path]:
|
|
|
506
510
|
return ordered
|
|
507
511
|
|
|
508
512
|
|
|
509
|
-
def _eval_config_sort_key(path: Path) -> tuple[int, int, int, str]:
|
|
510
|
-
name = path.name.lower()
|
|
511
|
-
parent_names = {p.name.lower() for p in path.parents}
|
|
512
|
-
in_configs = 0 if "configs" in parent_names else 1
|
|
513
|
-
in_examples = 0 if "examples" in parent_names else 1
|
|
514
|
-
starts_eval = 0 if name.startswith("eval") else 1
|
|
515
|
-
return (in_configs, in_examples, starts_eval, str(path))
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
def _discover_eval_config_paths() -> list[Path]:
|
|
519
|
-
"""Find candidate eval TOML files near the current working directory."""
|
|
520
|
-
|
|
521
|
-
candidates: list[Path] = []
|
|
522
|
-
seen: set[Path] = set()
|
|
523
|
-
search_roots = _candidate_search_roots()
|
|
524
|
-
for root in search_roots:
|
|
525
|
-
if not root.exists() or not root.is_dir():
|
|
526
|
-
continue
|
|
527
|
-
try:
|
|
528
|
-
root = root.resolve()
|
|
529
|
-
except Exception:
|
|
530
|
-
continue
|
|
531
|
-
for path in root.rglob("*.toml"):
|
|
532
|
-
if not path.is_file():
|
|
533
|
-
continue
|
|
534
|
-
if _should_ignore_path(path):
|
|
535
|
-
continue
|
|
536
|
-
name_lower = path.name.lower()
|
|
537
|
-
if "eval" not in name_lower and "evaluation" not in name_lower:
|
|
538
|
-
continue
|
|
539
|
-
try:
|
|
540
|
-
resolved = path.resolve()
|
|
541
|
-
except Exception:
|
|
542
|
-
continue
|
|
543
|
-
if resolved in seen:
|
|
544
|
-
continue
|
|
545
|
-
seen.add(resolved)
|
|
546
|
-
candidates.append(resolved)
|
|
547
|
-
|
|
548
|
-
candidates.sort(key=_eval_config_sort_key)
|
|
549
|
-
return candidates
|
|
550
|
-
|
|
551
|
-
|
|
552
513
|
class _TaskAppConfigVisitor(ast.NodeVisitor):
|
|
553
514
|
def __init__(self) -> None:
|
|
554
515
|
self.matches: list[tuple[str, int]] = []
|
|
@@ -2264,7 +2225,6 @@ def validate_task_app_cmd(
|
|
|
2264
2225
|
• Debug failing deployments: Use --verbose to see detailed endpoint responses
|
|
2265
2226
|
• Test API key configuration: Verify authentication is set up correctly
|
|
2266
2227
|
"""
|
|
2267
|
-
import asyncio
|
|
2268
2228
|
import socket
|
|
2269
2229
|
import subprocess
|
|
2270
2230
|
import tempfile
|
|
@@ -2471,49 +2431,7 @@ def serve_command(
|
|
|
2471
2431
|
trace_dir: str | None,
|
|
2472
2432
|
trace_db: str | None,
|
|
2473
2433
|
) -> None:
|
|
2474
|
-
|
|
2475
|
-
if demo_dir_path:
|
|
2476
|
-
if not demo_dir_path.is_dir():
|
|
2477
|
-
raise click.ClickException(
|
|
2478
|
-
f"Demo directory not found: {demo_dir_path}\nRun 'synth-ai setup' to create a demo."
|
|
2479
|
-
)
|
|
2480
|
-
os.chdir(demo_dir_path)
|
|
2481
|
-
click.echo(f"Using demo directory: {demo_dir_path}\n")
|
|
2482
|
-
os.environ["SYNTH_DEMO_DIR"] = str(demo_dir_path.resolve())
|
|
2483
|
-
|
|
2484
|
-
# Prompt for port if not provided
|
|
2485
|
-
if port is None:
|
|
2486
|
-
port = click.prompt("Port to serve on", type=int, default=8001)
|
|
2487
|
-
|
|
2488
|
-
# Prompt for trace directory if not provided
|
|
2489
|
-
if trace_dir is None:
|
|
2490
|
-
click.echo(
|
|
2491
|
-
"\nTracing captures rollout data (actions, rewards, model outputs) to a local SQLite DB."
|
|
2492
|
-
)
|
|
2493
|
-
click.echo("This data can be exported to JSONL for supervised fine-tuning (SFT).")
|
|
2494
|
-
enable_tracing = click.confirm("Enable tracing?", default=True)
|
|
2495
|
-
if enable_tracing:
|
|
2496
|
-
demo_base = Path(os.environ.get("SYNTH_DEMO_DIR") or Path.cwd())
|
|
2497
|
-
default_trace_dir = str((demo_base / "traces/v3").resolve())
|
|
2498
|
-
trace_dir = click.prompt(
|
|
2499
|
-
"Trace directory", type=str, default=default_trace_dir, show_default=True
|
|
2500
|
-
)
|
|
2501
|
-
else:
|
|
2502
|
-
trace_dir = None
|
|
2503
|
-
|
|
2504
|
-
# Prompt for trace DB if not provided and tracing is enabled
|
|
2505
|
-
if trace_dir and trace_db is None:
|
|
2506
|
-
demo_base = Path(os.environ.get("SYNTH_DEMO_DIR") or Path.cwd())
|
|
2507
|
-
default_trace_db = str((demo_base / "traces/v3/synth_ai.db").resolve())
|
|
2508
|
-
trace_db = click.prompt(
|
|
2509
|
-
"Trace DB path", type=str, default=default_trace_db, show_default=True
|
|
2510
|
-
)
|
|
2511
|
-
|
|
2512
|
-
choice = _select_app_choice(app_id, purpose="serve")
|
|
2513
|
-
entry = choice.ensure_entry()
|
|
2514
|
-
_serve_entry(
|
|
2515
|
-
entry, host, port, env_file, reload_flag, force, trace_dir=trace_dir, trace_db=trace_db
|
|
2516
|
-
)
|
|
2434
|
+
return None
|
|
2517
2435
|
|
|
2518
2436
|
|
|
2519
2437
|
@task_app_group.command("info")
|
|
@@ -2625,51 +2543,53 @@ def serve_task_group(
|
|
|
2625
2543
|
trace_dir: str | None,
|
|
2626
2544
|
trace_db: str | None,
|
|
2627
2545
|
) -> None:
|
|
2628
|
-
|
|
2629
|
-
|
|
2630
|
-
|
|
2631
|
-
|
|
2632
|
-
|
|
2633
|
-
|
|
2634
|
-
os.chdir(demo_dir_path)
|
|
2635
|
-
click.echo(f"Using demo directory: {demo_dir_path}\n")
|
|
2636
|
-
os.environ["SYNTH_DEMO_DIR"] = str(demo_dir_path.resolve())
|
|
2637
|
-
|
|
2638
|
-
# Prompt for port if not provided
|
|
2546
|
+
"""Serve a TaskAppConfig-based task app using uvicorn."""
|
|
2547
|
+
import contextlib
|
|
2548
|
+
|
|
2549
|
+
if not host:
|
|
2550
|
+
host = "0.0.0.0"
|
|
2551
|
+
|
|
2639
2552
|
if port is None:
|
|
2640
|
-
port =
|
|
2641
|
-
|
|
2642
|
-
#
|
|
2643
|
-
|
|
2644
|
-
|
|
2645
|
-
|
|
2646
|
-
|
|
2647
|
-
|
|
2648
|
-
enable_tracing = click.confirm("Enable tracing?", default=True)
|
|
2649
|
-
if enable_tracing:
|
|
2650
|
-
demo_base = Path(os.environ.get("SYNTH_DEMO_DIR") or Path.cwd())
|
|
2651
|
-
default_trace_dir = str((demo_base / "traces/v3").resolve())
|
|
2652
|
-
trace_dir = click.prompt(
|
|
2653
|
-
"Trace directory", type=str, default=default_trace_dir, show_default=True
|
|
2654
|
-
)
|
|
2655
|
-
else:
|
|
2656
|
-
trace_dir = None
|
|
2553
|
+
port = 8001
|
|
2554
|
+
|
|
2555
|
+
# Auto-enable tracing by default
|
|
2556
|
+
try:
|
|
2557
|
+
auto_trace = os.getenv("SYNTH_AUTO_TRACE", "1")
|
|
2558
|
+
auto_trace_enabled = auto_trace not in {"0", "false", "False", ""}
|
|
2559
|
+
except Exception:
|
|
2560
|
+
auto_trace_enabled = True
|
|
2657
2561
|
|
|
2658
|
-
|
|
2659
|
-
if trace_dir and trace_db is None:
|
|
2562
|
+
if auto_trace_enabled:
|
|
2660
2563
|
demo_base = Path(os.environ.get("SYNTH_DEMO_DIR") or Path.cwd())
|
|
2661
|
-
|
|
2662
|
-
|
|
2663
|
-
|
|
2664
|
-
|
|
2665
|
-
|
|
2564
|
+
if trace_dir is None:
|
|
2565
|
+
default_trace_dir = (demo_base / "traces" / "v3").resolve()
|
|
2566
|
+
with contextlib.suppress(Exception):
|
|
2567
|
+
default_trace_dir.mkdir(parents=True, exist_ok=True)
|
|
2568
|
+
trace_dir = str(default_trace_dir)
|
|
2569
|
+
click.echo(f"[trace] Using trace directory: {trace_dir}")
|
|
2570
|
+
if trace_dir and trace_db is None:
|
|
2571
|
+
default_trace_db = (Path(trace_dir) / "synth_ai.db").resolve()
|
|
2572
|
+
with contextlib.suppress(Exception):
|
|
2573
|
+
default_trace_db.parent.mkdir(parents=True, exist_ok=True)
|
|
2574
|
+
trace_db = str(default_trace_db)
|
|
2575
|
+
click.echo(f"[trace] Using trace DB: {trace_db}")
|
|
2576
|
+
|
|
2577
|
+
# Select and serve the app
|
|
2666
2578
|
choice = _select_app_choice(app_id, purpose="serve")
|
|
2667
2579
|
entry = choice.ensure_entry()
|
|
2668
2580
|
_serve_entry(
|
|
2669
|
-
entry,
|
|
2581
|
+
entry,
|
|
2582
|
+
host,
|
|
2583
|
+
port,
|
|
2584
|
+
env_file,
|
|
2585
|
+
reload_flag,
|
|
2586
|
+
force,
|
|
2587
|
+
trace_dir=trace_dir,
|
|
2588
|
+
trace_db=trace_db,
|
|
2670
2589
|
)
|
|
2671
2590
|
|
|
2672
2591
|
|
|
2592
|
+
|
|
2673
2593
|
def _determine_env_files(
|
|
2674
2594
|
entry: TaskAppEntryType, user_env_files: Sequence[str], *, original_path: Path | None = None
|
|
2675
2595
|
) -> list[Path]:
|
|
@@ -2962,87 +2882,6 @@ def _serve_entry(
|
|
|
2962
2882
|
)
|
|
2963
2883
|
|
|
2964
2884
|
|
|
2965
|
-
@task_app_group.command("deploy")
|
|
2966
|
-
@click.argument("app_id", type=str, required=False)
|
|
2967
|
-
@click.option("--name", "modal_name", default=None, help="Override Modal app name")
|
|
2968
|
-
@click.option("--dry-run", is_flag=True, help="Print modal deploy command without executing")
|
|
2969
|
-
@click.option("--modal-cli", default="modal", help="Path to modal CLI executable")
|
|
2970
|
-
@click.option(
|
|
2971
|
-
"--env-file",
|
|
2972
|
-
multiple=True,
|
|
2973
|
-
type=click.Path(),
|
|
2974
|
-
help="Env file to load into the container (can be repeated)",
|
|
2975
|
-
)
|
|
2976
|
-
def deploy_app(
|
|
2977
|
-
app_id: str | None,
|
|
2978
|
-
modal_name: str | None,
|
|
2979
|
-
dry_run: bool,
|
|
2980
|
-
modal_cli: str,
|
|
2981
|
-
env_file: Sequence[str],
|
|
2982
|
-
) -> None:
|
|
2983
|
-
"""Deploy a task app to Modal."""
|
|
2984
|
-
|
|
2985
|
-
demo_dir_path = _load_demo_directory()
|
|
2986
|
-
if demo_dir_path:
|
|
2987
|
-
if not demo_dir_path.is_dir():
|
|
2988
|
-
raise click.ClickException(
|
|
2989
|
-
f"Demo directory not found: {demo_dir_path}\nRun 'synth-ai demo' to create a demo."
|
|
2990
|
-
)
|
|
2991
|
-
os.chdir(demo_dir_path)
|
|
2992
|
-
click.echo(f"Using demo directory: {demo_dir_path}\n")
|
|
2993
|
-
|
|
2994
|
-
choice = _select_app_choice(app_id, purpose="deploy")
|
|
2995
|
-
|
|
2996
|
-
if choice.modal_script:
|
|
2997
|
-
env_paths = _resolve_env_paths_for_script(choice.modal_script, env_file)
|
|
2998
|
-
click.echo("Using env file(s): " + ", ".join(str(p.resolve()) for p in env_paths))
|
|
2999
|
-
_run_modal_script(
|
|
3000
|
-
choice.modal_script,
|
|
3001
|
-
modal_cli,
|
|
3002
|
-
"deploy",
|
|
3003
|
-
env_paths,
|
|
3004
|
-
modal_name=modal_name,
|
|
3005
|
-
dry_run=dry_run,
|
|
3006
|
-
)
|
|
3007
|
-
return
|
|
3008
|
-
|
|
3009
|
-
entry = choice.ensure_entry()
|
|
3010
|
-
_deploy_entry(entry, modal_name, dry_run, modal_cli, env_file, original_path=choice.path)
|
|
3011
|
-
|
|
3012
|
-
|
|
3013
|
-
@task_app_group.command("modal-serve")
|
|
3014
|
-
@click.argument("app_id", type=str, required=False)
|
|
3015
|
-
@click.option("--modal-cli", default="modal", help="Path to modal CLI executable")
|
|
3016
|
-
@click.option("--name", "modal_name", default=None, help="Override Modal app name (optional)")
|
|
3017
|
-
@click.option(
|
|
3018
|
-
"--env-file",
|
|
3019
|
-
multiple=True,
|
|
3020
|
-
type=click.Path(),
|
|
3021
|
-
help="Env file to load into the container (can be repeated)",
|
|
3022
|
-
)
|
|
3023
|
-
def modal_serve_app(
|
|
3024
|
-
app_id: str | None, modal_cli: str, modal_name: str | None, env_file: Sequence[str]
|
|
3025
|
-
) -> None:
|
|
3026
|
-
click.echo(f"[modal-serve] requested app_id={app_id or '(auto)'} modal_cli={modal_cli}")
|
|
3027
|
-
try:
|
|
3028
|
-
choice = _select_app_choice(app_id, purpose="modal-serve")
|
|
3029
|
-
except SystemExit as exc: # bubble up with context (legacy argparse would trigger this)
|
|
3030
|
-
raise click.ClickException(
|
|
3031
|
-
f"Legacy CLI intercepted modal-serve (exit {exc.code}). "
|
|
3032
|
-
"Make sure you're running the Click CLI (synth_ai.cli:cli)."
|
|
3033
|
-
) from exc
|
|
3034
|
-
|
|
3035
|
-
if choice.modal_script:
|
|
3036
|
-
env_paths = _resolve_env_paths_for_script(choice.modal_script, env_file)
|
|
3037
|
-
click.echo("Using env file(s): " + ", ".join(str(p.resolve()) for p in env_paths))
|
|
3038
|
-
_run_modal_script(choice.modal_script, modal_cli, "serve", env_paths, modal_name=modal_name)
|
|
3039
|
-
return
|
|
3040
|
-
|
|
3041
|
-
entry = choice.ensure_entry()
|
|
3042
|
-
click.echo(f"[modal-serve] serving entry {entry.app_id} from {choice.path}")
|
|
3043
|
-
_modal_serve_entry(entry, modal_name, modal_cli, env_file, original_path=choice.path)
|
|
3044
|
-
|
|
3045
|
-
|
|
3046
2885
|
def _write_modal_entrypoint(
|
|
3047
2886
|
entry: TaskAppEntryType,
|
|
3048
2887
|
modal_cfg: ModalDeploymentConfigType,
|
|
@@ -3286,1277 +3125,9 @@ def register(cli: click.Group) -> None:
|
|
|
3286
3125
|
cli.add_command(filter_command)
|
|
3287
3126
|
|
|
3288
3127
|
|
|
3289
|
-
|
|
3290
|
-
"eval",
|
|
3291
|
-
help="Run one-off rollouts against a task app and print judge/eval summaries.",
|
|
3292
|
-
)
|
|
3293
|
-
@click.argument("app_id", type=str, required=False)
|
|
3294
|
-
@click.option(
|
|
3295
|
-
"--config",
|
|
3296
|
-
type=click.Path(),
|
|
3297
|
-
default=None,
|
|
3298
|
-
help="Path to eval TOML (short schema). Auto-discovers the first matching file when omitted.",
|
|
3299
|
-
)
|
|
3300
|
-
@click.option(
|
|
3301
|
-
"--url",
|
|
3302
|
-
"task_app_url",
|
|
3303
|
-
type=str,
|
|
3304
|
-
default=None,
|
|
3305
|
-
help="Base URL of a running task app instead of spawning locally (requires --env-file for secrets).",
|
|
3306
|
-
)
|
|
3307
|
-
@click.option(
|
|
3308
|
-
"--seeds",
|
|
3309
|
-
default="0,1,2,3,4",
|
|
3310
|
-
help="Comma-separated seeds/indices to evaluate. Use negative numbers to wrap around the dataset.",
|
|
3311
|
-
)
|
|
3312
|
-
@click.option("--split", default="train", show_default=True, help="Dataset split to use")
|
|
3313
|
-
@click.option(
|
|
3314
|
-
"--model",
|
|
3315
|
-
default=None,
|
|
3316
|
-
help="Model identifier. When omitted the CLI will prompt based on task metadata.",
|
|
3317
|
-
)
|
|
3318
|
-
@click.option(
|
|
3319
|
-
"--env-file",
|
|
3320
|
-
multiple=True,
|
|
3321
|
-
type=click.Path(),
|
|
3322
|
-
help="Env file(s) to load (API keys, etc.). Required when using --url or remote judges.",
|
|
3323
|
-
)
|
|
3324
|
-
@click.option(
|
|
3325
|
-
"--trace-db",
|
|
3326
|
-
default="traces/v3/synth_ai.db",
|
|
3327
|
-
show_default=True,
|
|
3328
|
-
help="SQLite/Turso URL for storing rollout traces set to 'none' to disable persistence.",
|
|
3329
|
-
)
|
|
3330
|
-
@click.option(
|
|
3331
|
-
"--metadata",
|
|
3332
|
-
multiple=True,
|
|
3333
|
-
help="Filter tasks by key=value metadata (e.g., --metadata difficulty=easy)",
|
|
3334
|
-
)
|
|
3335
|
-
@click.option(
|
|
3336
|
-
"--metadata-sql",
|
|
3337
|
-
default=None,
|
|
3338
|
-
help="SQLite query that returns seeds to evaluate (e.g., SELECT seed FROM tasks WHERE difficulty='easy' LIMIT 5)",
|
|
3339
|
-
)
|
|
3340
|
-
def eval_command(
|
|
3341
|
-
app_id: str | None,
|
|
3342
|
-
config: str | None,
|
|
3343
|
-
task_app_url: str | None,
|
|
3344
|
-
seeds: str,
|
|
3345
|
-
split: str,
|
|
3346
|
-
model: str | None,
|
|
3347
|
-
env_file: Sequence[str],
|
|
3348
|
-
trace_db: str,
|
|
3349
|
-
metadata: Sequence[str],
|
|
3350
|
-
metadata_sql: str | None,
|
|
3351
|
-
) -> None:
|
|
3352
|
-
"""Run rollouts against a task app and report judge statistics.
|
|
3353
|
-
|
|
3354
|
-
By default the command spins up the selected task app in-process, executes the
|
|
3355
|
-
requested seeds, and prints aggregate scores (official and custom judges). When
|
|
3356
|
-
pointing at a remote `--url`, supply matching `--env-file` values so the CLI can
|
|
3357
|
-
forward authentication headers to the running service.
|
|
3358
|
-
"""
|
|
3359
|
-
# Parse and validate TOML config
|
|
3360
|
-
from synth_ai.task.config import EvalConfig
|
|
3361
|
-
|
|
3362
|
-
cfg: dict[str, Any] = {}
|
|
3363
|
-
eval_cfg: EvalConfig | None = None
|
|
3364
|
-
config_path: Path | None = None
|
|
3365
|
-
|
|
3366
|
-
if config:
|
|
3367
|
-
config_path = Path(config)
|
|
3368
|
-
else:
|
|
3369
|
-
auto_configs = _discover_eval_config_paths()
|
|
3370
|
-
if auto_configs:
|
|
3371
|
-
config_path = auto_configs[0]
|
|
3372
|
-
click.echo(f"Using eval config: {config_path}")
|
|
3373
|
-
|
|
3374
|
-
if config_path:
|
|
3375
|
-
if _toml is None:
|
|
3376
|
-
raise click.ClickException(
|
|
3377
|
-
"TOML parser not available; use Python 3.11+ or install tomli"
|
|
3378
|
-
)
|
|
3379
|
-
if not config_path.exists():
|
|
3380
|
-
raise click.ClickException(f"Eval config not found: {config_path}")
|
|
3381
|
-
try:
|
|
3382
|
-
data = config_path.read_bytes()
|
|
3383
|
-
parsed = _toml.loads(data.decode("utf-8"))
|
|
3384
|
-
if isinstance(parsed, dict):
|
|
3385
|
-
section = parsed.get("eval")
|
|
3386
|
-
cfg = dict(section) if isinstance(section, dict) else dict(parsed)
|
|
3387
|
-
|
|
3388
|
-
# Validate config with dataclass
|
|
3389
|
-
try:
|
|
3390
|
-
eval_cfg = EvalConfig.from_dict(cfg)
|
|
3391
|
-
click.echo(f"✓ Config validated: {len(eval_cfg.seeds)} seeds, model={eval_cfg.model}")
|
|
3392
|
-
except (ValueError, TypeError) as validation_error:
|
|
3393
|
-
raise click.ClickException(f"Invalid eval config: {validation_error}") from validation_error
|
|
3394
|
-
except click.ClickException:
|
|
3395
|
-
raise
|
|
3396
|
-
except Exception as exc:
|
|
3397
|
-
raise click.ClickException(f"Failed to parse TOML '{config_path}': {exc}") from exc
|
|
3398
|
-
|
|
3399
|
-
# CLI args override config
|
|
3400
|
-
if eval_cfg:
|
|
3401
|
-
app_id = app_id or eval_cfg.app_id
|
|
3402
|
-
else:
|
|
3403
|
-
app_id = app_id or (cfg.get("app_id") if isinstance(cfg.get("app_id"), str) else None) # type: ignore
|
|
3404
|
-
|
|
3405
|
-
metadata_filters: dict[str, str] = {}
|
|
3406
|
-
if eval_cfg:
|
|
3407
|
-
metadata_filters.update(eval_cfg.metadata)
|
|
3408
|
-
else:
|
|
3409
|
-
cfg_metadata = cfg.get("metadata")
|
|
3410
|
-
if isinstance(cfg_metadata, dict):
|
|
3411
|
-
for key, value in cfg_metadata.items():
|
|
3412
|
-
metadata_filters[str(key)] = str(value)
|
|
3413
|
-
elif isinstance(cfg_metadata, list):
|
|
3414
|
-
for item in cfg_metadata:
|
|
3415
|
-
if isinstance(item, str) and "=" in item:
|
|
3416
|
-
key, value = item.split("=", 1)
|
|
3417
|
-
metadata_filters[key.strip()] = value.strip()
|
|
3418
|
-
|
|
3419
|
-
for item in metadata or ():
|
|
3420
|
-
if "=" not in item:
|
|
3421
|
-
raise click.ClickException(f"Metadata filters must be key=value (got: {item})")
|
|
3422
|
-
key, value = item.split("=", 1)
|
|
3423
|
-
key = key.strip()
|
|
3424
|
-
value = value.strip()
|
|
3425
|
-
if not key or not value:
|
|
3426
|
-
raise click.ClickException(f"Invalid metadata filter: {item}")
|
|
3427
|
-
metadata_filters[key] = value
|
|
3428
|
-
|
|
3429
|
-
metadata_sql_query: str | None = None
|
|
3430
|
-
if eval_cfg and eval_cfg.metadata_sql:
|
|
3431
|
-
metadata_sql_query = eval_cfg.metadata_sql
|
|
3432
|
-
else:
|
|
3433
|
-
cfg_metadata_sql = cfg.get("metadata_sql")
|
|
3434
|
-
if isinstance(cfg_metadata_sql, dict):
|
|
3435
|
-
metadata_sql_query = cfg_metadata_sql.get("query") or cfg_metadata_sql.get("sql")
|
|
3436
|
-
elif isinstance(cfg_metadata_sql, str):
|
|
3437
|
-
metadata_sql_query = cfg_metadata_sql
|
|
3438
|
-
|
|
3439
|
-
if metadata_sql:
|
|
3440
|
-
metadata_sql_query = metadata_sql
|
|
3441
|
-
if metadata_sql_query is not None:
|
|
3442
|
-
metadata_sql_query = str(metadata_sql_query)
|
|
3443
|
-
|
|
3444
|
-
trace_db_url: str | None = None
|
|
3445
|
-
trace_db = (trace_db or "").strip()
|
|
3446
|
-
if trace_db and trace_db.lower() not in {"none", "off", "disable"}:
|
|
3447
|
-
if "://" in trace_db:
|
|
3448
|
-
trace_db_url = trace_db
|
|
3449
|
-
else:
|
|
3450
|
-
trace_path = Path(trace_db).expanduser()
|
|
3451
|
-
trace_path.parent.mkdir(parents=True, exist_ok=True)
|
|
3452
|
-
trace_db_url = f"sqlite+aiosqlite:///{trace_path}"
|
|
3453
|
-
trace_tracer: SessionTracer | None = SessionTracer(db_url=trace_db_url, auto_save=True) if trace_db_url else None
|
|
3454
|
-
|
|
3455
|
-
# Determine selection params (CLI takes precedence; TOML only fills unset model/seeds/env)
|
|
3456
|
-
if cfg.get("model") and not model:
|
|
3457
|
-
model = str(cfg["model"]) # type: ignore[index]
|
|
3458
|
-
if cfg.get("seeds") and seeds == "0,1,2,3,4":
|
|
3459
|
-
val = cfg["seeds"]
|
|
3460
|
-
if isinstance(val, list):
|
|
3461
|
-
with contextlib.suppress(Exception):
|
|
3462
|
-
seeds = ",".join(str(int(x)) for x in val)
|
|
3463
|
-
elif isinstance(val, str):
|
|
3464
|
-
seeds = val
|
|
3465
|
-
elif isinstance(val, int):
|
|
3466
|
-
seeds = str(val)
|
|
3467
|
-
if cfg.get("env_file") and not env_file:
|
|
3468
|
-
ef = cfg["env_file"]
|
|
3469
|
-
if isinstance(ef, str):
|
|
3470
|
-
env_file = (ef,) # type: ignore[assignment]
|
|
3471
|
-
elif isinstance(ef, list):
|
|
3472
|
-
env_file = tuple(str(x) for x in ef) # type: ignore[assignment]
|
|
3473
|
-
|
|
3474
|
-
choice_for_env: AppChoice | None = None
|
|
3475
|
-
entry: TaskAppEntryType | None = None
|
|
3476
|
-
if task_app_url is None:
|
|
3477
|
-
choice_for_env = _select_app_choice(app_id, purpose="eval")
|
|
3478
|
-
entry = choice_for_env.ensure_entry()
|
|
3479
|
-
|
|
3480
|
-
env_paths: list[Path] = []
|
|
3481
|
-
if entry is not None:
|
|
3482
|
-
original_env_path = choice_for_env.path if choice_for_env is not None else None
|
|
3483
|
-
env_paths = _determine_env_files(entry, env_file, original_path=original_env_path)
|
|
3484
|
-
else:
|
|
3485
|
-
if not env_file:
|
|
3486
|
-
raise click.ClickException("--env-file is required when using --url")
|
|
3487
|
-
for candidate in env_file:
|
|
3488
|
-
p = Path(candidate).expanduser()
|
|
3489
|
-
if not p.exists():
|
|
3490
|
-
raise click.ClickException(f"Env file not found: {p}")
|
|
3491
|
-
env_paths.append(p)
|
|
3492
|
-
|
|
3493
|
-
click.echo("Using env file(s): " + ", ".join(str(p) for p in env_paths))
|
|
3494
|
-
_load_env_files_into_process([str(Path(p)) for p in env_paths])
|
|
3495
|
-
|
|
3496
|
-
if task_app_url is None:
|
|
3497
|
-
config = entry.config_factory() # type: ignore[union-attr]
|
|
3498
|
-
# Help the type checker; runtime check also enforced in server.run_task_app
|
|
3499
|
-
if not isinstance(config, TaskAppConfig):
|
|
3500
|
-
raise click.ClickException(
|
|
3501
|
-
"Invalid task app: config_factory did not return TaskAppConfig"
|
|
3502
|
-
)
|
|
3503
|
-
app = create_task_app(config)
|
|
3504
|
-
|
|
3505
|
-
# Determine supported models
|
|
3506
|
-
inference_meta: dict[str, Any] = {}
|
|
3507
|
-
supported: list[str] = []
|
|
3508
|
-
seen_models: set[str] = set()
|
|
3509
|
-
|
|
3510
|
-
def _add_supported_model(candidate: Any) -> None:
|
|
3511
|
-
if not candidate:
|
|
3512
|
-
return
|
|
3513
|
-
text = str(candidate).strip()
|
|
3514
|
-
if not text or text in seen_models:
|
|
3515
|
-
return
|
|
3516
|
-
supported.append(text)
|
|
3517
|
-
seen_models.add(text)
|
|
3518
|
-
|
|
3519
|
-
if task_app_url is None:
|
|
3520
|
-
try:
|
|
3521
|
-
if hasattr(config, "base_task_info") and config.base_task_info:
|
|
3522
|
-
inf_obj = getattr(config.base_task_info, "inference", None)
|
|
3523
|
-
if inf_obj is not None:
|
|
3524
|
-
if hasattr(inf_obj, "model_dump"):
|
|
3525
|
-
inference_meta = dict(inf_obj.model_dump(exclude_none=True)) # type: ignore[attr-defined]
|
|
3526
|
-
elif isinstance(inf_obj, dict):
|
|
3527
|
-
inference_meta = dict(inf_obj)
|
|
3528
|
-
except Exception:
|
|
3529
|
-
inference_meta = {}
|
|
3530
|
-
else:
|
|
3531
|
-
try:
|
|
3532
|
-
import httpx as _hx
|
|
3533
|
-
|
|
3534
|
-
headers = {}
|
|
3535
|
-
api_key = (os.environ.get("ENVIRONMENT_API_KEY") or "").strip()
|
|
3536
|
-
if api_key:
|
|
3537
|
-
headers["X-API-Key"] = api_key
|
|
3538
|
-
with _hx.Client(base_url=task_app_url, headers=headers, timeout=15.0) as c:
|
|
3539
|
-
info = c.get("/info").json()
|
|
3540
|
-
inf = info.get("inference") if isinstance(info, dict) else None
|
|
3541
|
-
if isinstance(inf, dict):
|
|
3542
|
-
inference_meta = dict(inf)
|
|
3543
|
-
except Exception:
|
|
3544
|
-
inference_meta = {}
|
|
3545
|
-
|
|
3546
|
-
default_model = inference_meta.get("model")
|
|
3547
|
-
if isinstance(default_model, str):
|
|
3548
|
-
_add_supported_model(default_model)
|
|
3549
|
-
|
|
3550
|
-
models_field = inference_meta.get("models")
|
|
3551
|
-
if isinstance(models_field, list):
|
|
3552
|
-
for candidate in models_field:
|
|
3553
|
-
_add_supported_model(candidate)
|
|
3554
|
-
|
|
3555
|
-
supported_models = inference_meta.get("supported_models")
|
|
3556
|
-
if isinstance(supported_models, list):
|
|
3557
|
-
for candidate in supported_models:
|
|
3558
|
-
_add_supported_model(candidate)
|
|
3559
|
-
|
|
3560
|
-
providers = inference_meta.get("providers")
|
|
3561
|
-
if isinstance(providers, list):
|
|
3562
|
-
if "openai" in providers:
|
|
3563
|
-
_add_supported_model("gpt-5")
|
|
3564
|
-
if "groq" in providers:
|
|
3565
|
-
_add_supported_model("groq:llama-3.1-70b-versatile")
|
|
3566
|
-
|
|
3567
|
-
_add_supported_model("synth:qwen-0.6b")
|
|
3568
|
-
|
|
3569
|
-
selected_model = model
|
|
3570
|
-
if not selected_model:
|
|
3571
|
-
if not supported:
|
|
3572
|
-
raise click.ClickException(
|
|
3573
|
-
"No supported models; supply --model or add base_task_info.inference.model"
|
|
3574
|
-
)
|
|
3575
|
-
click.echo("Select model to evaluate:")
|
|
3576
|
-
for idx, m in enumerate(supported, start=1):
|
|
3577
|
-
click.echo(f" {idx}) {m}")
|
|
3578
|
-
choice_idx = click.prompt("Enter choice", type=click.IntRange(1, len(supported)))
|
|
3579
|
-
selected_model = supported[choice_idx - 1]
|
|
3580
|
-
|
|
3581
|
-
try:
|
|
3582
|
-
seed_values = [int(s.strip()) for s in seeds.split(",") if s.strip()]
|
|
3583
|
-
except Exception as exc:
|
|
3584
|
-
raise click.ClickException("Invalid --seeds; expected comma-separated integers") from exc
|
|
3585
|
-
|
|
3586
|
-
import httpx
|
|
3587
|
-
|
|
3588
|
-
headers = {}
|
|
3589
|
-
api_key = (os.environ.get("ENVIRONMENT_API_KEY") or "").strip()
|
|
3590
|
-
if api_key:
|
|
3591
|
-
headers["X-API-Key"] = api_key
|
|
3592
|
-
|
|
3593
|
-
# Precompute optional policy overrides from TOML
|
|
3594
|
-
policy_overrides: dict[str, Any] = {}
|
|
3595
|
-
try:
|
|
3596
|
-
# Accept [eval.policy] table or top-level keys for convenience
|
|
3597
|
-
if isinstance(cfg.get("policy"), dict):
|
|
3598
|
-
policy_overrides.update(dict(cfg["policy"]))
|
|
3599
|
-
# Back-compat: allow temperature/max_tokens at top level
|
|
3600
|
-
for k in (
|
|
3601
|
-
"temperature",
|
|
3602
|
-
"max_tokens",
|
|
3603
|
-
"reasoning_effort",
|
|
3604
|
-
"system_hint",
|
|
3605
|
-
"tool_choice",
|
|
3606
|
-
"inference_url",
|
|
3607
|
-
):
|
|
3608
|
-
if k in cfg and k not in policy_overrides:
|
|
3609
|
-
policy_overrides[k] = cfg.get(k)
|
|
3610
|
-
except Exception:
|
|
3611
|
-
policy_overrides = {}
|
|
3612
|
-
|
|
3613
|
-
raw_concurrency = cfg.get("concurrency")
|
|
3614
|
-
try:
|
|
3615
|
-
concurrency_limit = int(raw_concurrency) if raw_concurrency is not None else 1
|
|
3616
|
-
except Exception:
|
|
3617
|
-
concurrency_limit = 1
|
|
3618
|
-
if concurrency_limit <= 0:
|
|
3619
|
-
concurrency_limit = 1
|
|
3620
|
-
concurrency_limit = min(concurrency_limit, max(1, len(seed_values)))
|
|
3621
|
-
|
|
3622
|
-
judge_specs: list[JudgeSpec] = []
|
|
3623
|
-
|
|
3624
|
-
def _register_judge(name_hint: str | None, judge_cfg: dict[str, Any]) -> None:
|
|
3625
|
-
if not judge_cfg:
|
|
3626
|
-
return
|
|
3627
|
-
judge_module = judge_cfg.get("module")
|
|
3628
|
-
judge_path = judge_cfg.get("path")
|
|
3629
|
-
judge_callable_name = judge_cfg.get("callable") or judge_cfg.get("function")
|
|
3630
|
-
if judge_module and judge_path:
|
|
3631
|
-
raise click.ClickException("Judge config cannot set both 'module' and 'path'")
|
|
3632
|
-
if not judge_module and not judge_path:
|
|
3633
|
-
raise click.ClickException("Judge config requires 'module' or 'path'")
|
|
3634
|
-
try:
|
|
3635
|
-
if judge_module:
|
|
3636
|
-
module = importlib.import_module(str(judge_module))
|
|
3637
|
-
else:
|
|
3638
|
-
path = Path(str(judge_path)).expanduser()
|
|
3639
|
-
if not path.exists():
|
|
3640
|
-
raise click.ClickException(f"Judge module path not found: {path}")
|
|
3641
|
-
spec = importlib.util.spec_from_file_location(
|
|
3642
|
-
f"_eval_judge_{path.stem}", path
|
|
3643
|
-
)
|
|
3644
|
-
if not spec or not spec.loader:
|
|
3645
|
-
raise click.ClickException(f"Failed to load judge module from {path}")
|
|
3646
|
-
module = importlib.util.module_from_spec(spec)
|
|
3647
|
-
sys.modules[spec.name] = module
|
|
3648
|
-
spec.loader.exec_module(module)
|
|
3649
|
-
except click.ClickException:
|
|
3650
|
-
raise
|
|
3651
|
-
except Exception as exc:
|
|
3652
|
-
raise click.ClickException(f"Unable to load judge module: {exc}") from exc
|
|
3653
|
-
|
|
3654
|
-
if judge_callable_name:
|
|
3655
|
-
try:
|
|
3656
|
-
judge_fn = getattr(module, str(judge_callable_name))
|
|
3657
|
-
except AttributeError as exc:
|
|
3658
|
-
raise click.ClickException(
|
|
3659
|
-
f"Judge callable '{judge_callable_name}' not found in module"
|
|
3660
|
-
) from exc
|
|
3661
|
-
else:
|
|
3662
|
-
if hasattr(module, "judge"):
|
|
3663
|
-
judge_fn = module.judge
|
|
3664
|
-
else:
|
|
3665
|
-
raise click.ClickException("Judge module must expose 'judge' callable")
|
|
3666
|
-
|
|
3667
|
-
if not callable(judge_fn):
|
|
3668
|
-
raise click.ClickException("Judge callable is not callable")
|
|
3669
|
-
|
|
3670
|
-
judge_kwargs = {
|
|
3671
|
-
k: v
|
|
3672
|
-
for k, v in judge_cfg.items()
|
|
3673
|
-
if k not in {"module", "path", "callable", "function", "name"}
|
|
3674
|
-
}
|
|
3675
|
-
display_name = str(
|
|
3676
|
-
judge_cfg.get("name")
|
|
3677
|
-
or name_hint
|
|
3678
|
-
or f"judge{len(judge_specs) + 1}"
|
|
3679
|
-
)
|
|
3680
|
-
judge_specs.append(JudgeSpec(display_name, judge_fn, judge_kwargs))
|
|
3681
|
-
|
|
3682
|
-
raw_judge_cfg = cfg.get("judge")
|
|
3683
|
-
if isinstance(raw_judge_cfg, dict) and raw_judge_cfg:
|
|
3684
|
-
direct_keys = {"module", "path", "callable", "function", "name"}
|
|
3685
|
-
has_direct_keys = any(key in raw_judge_cfg for key in direct_keys)
|
|
3686
|
-
nested_candidates = [
|
|
3687
|
-
(key, value)
|
|
3688
|
-
for key, value in raw_judge_cfg.items()
|
|
3689
|
-
if isinstance(value, dict)
|
|
3690
|
-
]
|
|
3691
|
-
if has_direct_keys and not nested_candidates:
|
|
3692
|
-
_register_judge(None, raw_judge_cfg)
|
|
3693
|
-
else:
|
|
3694
|
-
for sub_name, sub_cfg in nested_candidates:
|
|
3695
|
-
_register_judge(sub_name, sub_cfg)
|
|
3696
|
-
|
|
3697
|
-
raw_judges_list = cfg.get("judges")
|
|
3698
|
-
if isinstance(raw_judges_list, list):
|
|
3699
|
-
for _index, entry in enumerate(raw_judges_list, start=1):
|
|
3700
|
-
if isinstance(entry, dict):
|
|
3701
|
-
_register_judge(entry.get("name") or f"judge{len(judge_specs) + 1}", entry)
|
|
3702
|
-
|
|
3703
|
-
records: list[dict[str, Any]] = []
|
|
3704
|
-
|
|
3705
|
-
successes = 0
|
|
3706
|
-
failures = 0
|
|
3707
|
-
# Aggregate outcome stats across successful seeds
|
|
3708
|
-
outcome_sum: float = 0.0
|
|
3709
|
-
outcome_count: int = 0
|
|
3710
|
-
outcome_correct: int = 0
|
|
3711
|
-
|
|
3712
|
-
def _build_task_rows(taskset: Any) -> dict[int, dict[str, Any]]:
|
|
3713
|
-
rows: dict[int, dict[str, Any]] = {}
|
|
3714
|
-
if not isinstance(taskset, dict):
|
|
3715
|
-
return rows
|
|
3716
|
-
|
|
3717
|
-
scenario_ids = taskset.get("scenario_ids") or []
|
|
3718
|
-
loop_ids = taskset.get("loop_ids") or []
|
|
3719
|
-
thread_ids = taskset.get("thread_ids") or []
|
|
3720
|
-
difficulty_map = taskset.get("difficulty_map") or {}
|
|
3721
|
-
|
|
3722
|
-
max_len = max(len(scenario_ids), len(loop_ids), len(thread_ids))
|
|
3723
|
-
for seed in range(max_len):
|
|
3724
|
-
scenario_id = scenario_ids[seed] if seed < len(scenario_ids) else None
|
|
3725
|
-
loop_id = loop_ids[seed] if seed < len(loop_ids) else None
|
|
3726
|
-
thread_id = thread_ids[seed] if seed < len(thread_ids) else None
|
|
3727
|
-
difficulty = None
|
|
3728
|
-
if isinstance(difficulty_map, dict):
|
|
3729
|
-
if scenario_id and scenario_id in difficulty_map:
|
|
3730
|
-
difficulty = difficulty_map.get(scenario_id)
|
|
3731
|
-
elif str(seed) in difficulty_map:
|
|
3732
|
-
difficulty = difficulty_map.get(str(seed))
|
|
3733
|
-
|
|
3734
|
-
rows[seed] = {
|
|
3735
|
-
"seed": seed,
|
|
3736
|
-
"scenario_id": scenario_id,
|
|
3737
|
-
"loop_id": loop_id,
|
|
3738
|
-
"thread_id": thread_id,
|
|
3739
|
-
"difficulty": difficulty,
|
|
3740
|
-
}
|
|
3741
|
-
return rows
|
|
3742
|
-
|
|
3743
|
-
def _apply_metadata_filters(
|
|
3744
|
-
rows: dict[int, dict[str, Any]], seeds_list: list[int], filters: dict[str, str]
|
|
3745
|
-
) -> list[int]:
|
|
3746
|
-
if not filters:
|
|
3747
|
-
return seeds_list
|
|
3748
|
-
filtered: list[int] = []
|
|
3749
|
-
for seed in seeds_list:
|
|
3750
|
-
row = rows.get(seed)
|
|
3751
|
-
if not row:
|
|
3752
|
-
continue
|
|
3753
|
-
include = True
|
|
3754
|
-
for key, expected in filters.items():
|
|
3755
|
-
actual = row.get(key)
|
|
3756
|
-
if actual is None:
|
|
3757
|
-
include = False
|
|
3758
|
-
break
|
|
3759
|
-
if str(actual).lower() != expected.lower():
|
|
3760
|
-
include = False
|
|
3761
|
-
break
|
|
3762
|
-
if include:
|
|
3763
|
-
filtered.append(seed)
|
|
3764
|
-
return filtered
|
|
3765
|
-
|
|
3766
|
-
def _apply_metadata_sql(
|
|
3767
|
-
rows: dict[int, dict[str, Any]], seeds_list: list[int], query: str
|
|
3768
|
-
) -> list[int]:
|
|
3769
|
-
"""Return seeds that satisfy an arbitrary SQL query.
|
|
3770
|
-
|
|
3771
|
-
The query is executed against an in-memory SQLite table named `tasks`
|
|
3772
|
-
with columns (seed INTEGER, scenario_id TEXT, loop_id TEXT, thread_id TEXT, difficulty TEXT).
|
|
3773
|
-
Any rows whose `seed` value (or first column if `seed` is absent) appear in the result set are retained.
|
|
3774
|
-
"""
|
|
3775
|
-
if not query:
|
|
3776
|
-
return seeds_list
|
|
3777
|
-
conn = sqlite3.connect(":memory:")
|
|
3778
|
-
try:
|
|
3779
|
-
cur = conn.cursor()
|
|
3780
|
-
cur.execute(
|
|
3781
|
-
"CREATE TABLE tasks (seed INTEGER, scenario_id TEXT, loop_id TEXT, thread_id TEXT, difficulty TEXT)"
|
|
3782
|
-
)
|
|
3783
|
-
insert_stmt = (
|
|
3784
|
-
"INSERT INTO tasks (seed, scenario_id, loop_id, thread_id, difficulty) VALUES (?,?,?,?,?)"
|
|
3785
|
-
)
|
|
3786
|
-
for seed in seeds_list:
|
|
3787
|
-
row = rows.get(seed, {})
|
|
3788
|
-
cur.execute(
|
|
3789
|
-
insert_stmt,
|
|
3790
|
-
[
|
|
3791
|
-
seed,
|
|
3792
|
-
row.get("scenario_id"),
|
|
3793
|
-
row.get("loop_id"),
|
|
3794
|
-
row.get("thread_id"),
|
|
3795
|
-
row.get("difficulty"),
|
|
3796
|
-
],
|
|
3797
|
-
)
|
|
3798
|
-
|
|
3799
|
-
result = cur.execute(query)
|
|
3800
|
-
fetched = result.fetchall()
|
|
3801
|
-
if not fetched:
|
|
3802
|
-
return []
|
|
3803
|
-
description = result.description or []
|
|
3804
|
-
col_names = [col[0] for col in description]
|
|
3805
|
-
seeds_out: list[int] = []
|
|
3806
|
-
for entry in fetched:
|
|
3807
|
-
value = entry[col_names.index("seed")] if "seed" in col_names else entry[0]
|
|
3808
|
-
try:
|
|
3809
|
-
seeds_out.append(int(value))
|
|
3810
|
-
except Exception as exc:
|
|
3811
|
-
raise click.ClickException(
|
|
3812
|
-
"metadata SQL query must return seed integers"
|
|
3813
|
-
) from exc
|
|
3814
|
-
seeds_set = set(seeds_out)
|
|
3815
|
-
return [seed for seed in seeds_list if seed in seeds_set]
|
|
3816
|
-
except sqlite3.Error as exc:
|
|
3817
|
-
raise click.ClickException(f"Failed to execute metadata SQL query: {exc}") from exc
|
|
3818
|
-
finally:
|
|
3819
|
-
conn.close()
|
|
3820
|
-
|
|
3821
|
-
async def _run_eval() -> None:
|
|
3822
|
-
nonlocal successes, failures, outcome_sum, outcome_count, outcome_correct, records, seed_values
|
|
3823
|
-
|
|
3824
|
-
if trace_tracer is not None and trace_tracer.db is None:
|
|
3825
|
-
await trace_tracer.initialize()
|
|
3826
|
-
|
|
3827
|
-
if task_app_url is None:
|
|
3828
|
-
transport = httpx.ASGITransport(app=app) # type: ignore[name-defined]
|
|
3829
|
-
async_client = httpx.AsyncClient(
|
|
3830
|
-
transport=cast(Any, transport),
|
|
3831
|
-
base_url="http://eval.local",
|
|
3832
|
-
timeout=300.0,
|
|
3833
|
-
follow_redirects=True,
|
|
3834
|
-
headers=headers,
|
|
3835
|
-
)
|
|
3836
|
-
else:
|
|
3837
|
-
async_client = httpx.AsyncClient(
|
|
3838
|
-
base_url=task_app_url,
|
|
3839
|
-
timeout=300.0,
|
|
3840
|
-
follow_redirects=True,
|
|
3841
|
-
headers=headers,
|
|
3842
|
-
)
|
|
3843
|
-
|
|
3844
|
-
try:
|
|
3845
|
-
taskset_payload: dict[str, Any] | None = None
|
|
3846
|
-
try:
|
|
3847
|
-
task_info_response = await async_client.get("/task_info")
|
|
3848
|
-
except Exception:
|
|
3849
|
-
task_info_response = None
|
|
3850
|
-
if task_info_response is not None and task_info_response.status_code == 200:
|
|
3851
|
-
with contextlib.suppress(Exception):
|
|
3852
|
-
payload_json = task_info_response.json()
|
|
3853
|
-
if isinstance(payload_json, dict) and "taskset" in payload_json:
|
|
3854
|
-
taskset_payload = payload_json.get("taskset")
|
|
3855
|
-
if not isinstance(taskset_payload, dict):
|
|
3856
|
-
taskset_payload = None
|
|
3857
|
-
elif isinstance(payload_json, dict):
|
|
3858
|
-
taskset_payload = payload_json
|
|
3859
|
-
|
|
3860
|
-
available_seeds = list(seed_values)
|
|
3861
|
-
if metadata_sql_query or metadata_filters:
|
|
3862
|
-
if not taskset_payload:
|
|
3863
|
-
raise click.ClickException(
|
|
3864
|
-
"Task metadata filters require the task app to expose /task_info metadata"
|
|
3865
|
-
)
|
|
3866
|
-
rows = _build_task_rows(taskset_payload)
|
|
3867
|
-
if metadata_sql_query:
|
|
3868
|
-
available_seeds = _apply_metadata_sql(rows, available_seeds, metadata_sql_query)
|
|
3869
|
-
if metadata_filters:
|
|
3870
|
-
available_seeds = _apply_metadata_filters(rows, available_seeds, metadata_filters)
|
|
3871
|
-
if not available_seeds:
|
|
3872
|
-
raise click.ClickException("No seeds match the provided metadata filters")
|
|
3873
|
-
seed_values = available_seeds
|
|
3874
|
-
|
|
3875
|
-
semaphore = asyncio.Semaphore(concurrency_limit)
|
|
3876
|
-
|
|
3877
|
-
async def _run_seed(seed_val: int) -> None:
|
|
3878
|
-
nonlocal successes, failures, outcome_sum, outcome_count, outcome_correct, records
|
|
3879
|
-
# Read env_name and policy_name from config if available
|
|
3880
|
-
env_name = cfg.get("env_name") or (cfg.get("env", {}).get("env_name") if isinstance(cfg.get("env"), dict) else None)
|
|
3881
|
-
policy_name = cfg.get("policy_name") or (cfg.get("policy", {}).get("policy_name") if isinstance(cfg.get("policy"), dict) else None)
|
|
3882
|
-
env_config_overrides = cfg.get("env_config", {}) if isinstance(cfg.get("env_config"), dict) else {}
|
|
3883
|
-
policy_config_overrides = cfg.get("policy_config", {}) if isinstance(cfg.get("policy_config"), dict) else {}
|
|
3884
|
-
|
|
3885
|
-
# Debug: print config parsing
|
|
3886
|
-
if seed_val == 0:
|
|
3887
|
-
click.echo(f"[DEBUG] env_name from config: {env_name}")
|
|
3888
|
-
click.echo(f"[DEBUG] policy_name from config: {policy_name}")
|
|
3889
|
-
|
|
3890
|
-
# Generate default ops sequence if not provided
|
|
3891
|
-
max_llm_calls = policy_config_overrides.get("max_llm_calls", 10)
|
|
3892
|
-
ops_list = cfg.get("ops", [])
|
|
3893
|
-
if not ops_list:
|
|
3894
|
-
# Generate default "agent, env" pairs for max_llm_calls
|
|
3895
|
-
ops_list = ["agent", "env"] * int(max_llm_calls)
|
|
3896
|
-
|
|
3897
|
-
body = {
|
|
3898
|
-
"run_id": str(uuid.uuid4()),
|
|
3899
|
-
"env": {"config": {"split": split, "index": seed_val, **env_config_overrides}, "seed": seed_val},
|
|
3900
|
-
"policy": {
|
|
3901
|
-
"policy_name": policy_name or selected_model,
|
|
3902
|
-
"config": {"model": selected_model, **policy_overrides, **policy_config_overrides},
|
|
3903
|
-
},
|
|
3904
|
-
"ops": ops_list,
|
|
3905
|
-
"record": {
|
|
3906
|
-
"return_trace": cfg.get("return_trace", True),
|
|
3907
|
-
"trace_format": cfg.get("trace_format", "structured"),
|
|
3908
|
-
},
|
|
3909
|
-
"mode": "eval", # RolloutMode.EVAL: use inference URLs as-is, no transformations
|
|
3910
|
-
}
|
|
3911
|
-
if env_name:
|
|
3912
|
-
body["env"]["env_name"] = env_name
|
|
3913
|
-
|
|
3914
|
-
# Debug: print the body being sent
|
|
3915
|
-
if seed_val == 0:
|
|
3916
|
-
click.echo(f"[DEBUG] rollout body env: {body['env']}")
|
|
3917
|
-
click.echo(f"[DEBUG] rollout body policy: {body['policy']}")
|
|
3918
|
-
click.echo(f"[DEBUG] rollout body mode: {body.get('mode', 'NOT SET')}")
|
|
3919
|
-
rollout_elapsed: float | None = None
|
|
3920
|
-
rollout_start = time.perf_counter()
|
|
3921
|
-
try:
|
|
3922
|
-
import logging
|
|
3923
|
-
_log = logging.getLogger(__name__)
|
|
3924
|
-
_log.info(f"[EVAL_BODY_DEBUG] Sending body with mode={body.get('mode')}")
|
|
3925
|
-
async with semaphore:
|
|
3926
|
-
response = await async_client.post("/rollout", json=body)
|
|
3927
|
-
rollout_elapsed = time.perf_counter() - rollout_start
|
|
3928
|
-
except Exception as exc:
|
|
3929
|
-
failures += 1
|
|
3930
|
-
click.echo(f"seed={seed_val} error={exc}")
|
|
3931
|
-
return
|
|
3932
|
-
|
|
3933
|
-
ok = 200 <= response.status_code < 300
|
|
3934
|
-
if ok:
|
|
3935
|
-
successes += 1
|
|
3936
|
-
else:
|
|
3937
|
-
failures += 1
|
|
3938
|
-
|
|
3939
|
-
summary = [f"seed={seed_val}", f"status={response.status_code}"]
|
|
3940
|
-
data: Any
|
|
3941
|
-
try:
|
|
3942
|
-
data = response.json()
|
|
3943
|
-
except Exception:
|
|
3944
|
-
data = None
|
|
3945
|
-
|
|
3946
|
-
# Debug: print validation errors
|
|
3947
|
-
if response.status_code == 422 and data:
|
|
3948
|
-
click.echo(f"[DEBUG] 422 Validation Error: {data}")
|
|
3949
|
-
|
|
3950
|
-
metrics: dict[str, Any] | None = None
|
|
3951
|
-
completion: str | None = None
|
|
3952
|
-
prompt_index: int | None = None
|
|
3953
|
-
prompt_text: str | None = None
|
|
3954
|
-
task_id: str | None = None
|
|
3955
|
-
task_split: str | None = None
|
|
3956
|
-
task_rubric_id: str | None = None
|
|
3957
|
-
|
|
3958
|
-
trace_namespace: dict[str, Any] | None = None
|
|
3959
|
-
session_trace_dict: dict[str, Any] | None = None
|
|
3960
|
-
|
|
3961
|
-
if isinstance(data, dict):
|
|
3962
|
-
import logging
|
|
3963
|
-
_logger = logging.getLogger(__name__)
|
|
3964
|
-
_logger.info(f"[EVAL_DEBUG] Response data keys: {list(data.keys())}")
|
|
3965
|
-
if "detail" in data:
|
|
3966
|
-
_logger.error(f"[EVAL_DEBUG] Task app returned error: {data['detail']}")
|
|
3967
|
-
trace_namespace = data.get("trace")
|
|
3968
|
-
_logger.info(f"[EVAL_DEBUG] trace_namespace type: {type(trace_namespace)}, value: {trace_namespace if not isinstance(trace_namespace, dict) else 'dict with keys: ' + str(list(trace_namespace.keys()) if trace_namespace else 'None')}")
|
|
3969
|
-
if not isinstance(trace_namespace, dict):
|
|
3970
|
-
raise RuntimeError(
|
|
3971
|
-
"The 'synth-ai eval' command requires trace payloads in rollout responses. "
|
|
3972
|
-
"Ensure the rollout request includes 'trace_format': 'structured' and 'return_trace': true, "
|
|
3973
|
-
"and that task app tracing is enabled (TASKAPP_TRACING_ENABLED=1). "
|
|
3974
|
-
"Note: This is specific to the eval command - general rollout endpoints don't require traces."
|
|
3975
|
-
)
|
|
3976
|
-
# Handle both "compact" and "full" trace formats:
|
|
3977
|
-
# - compact: trace_namespace contains {session_id, metadata, ...}
|
|
3978
|
-
# - full: trace_namespace IS the full session_trace dict
|
|
3979
|
-
session_trace_dict = trace_namespace.get("session_trace")
|
|
3980
|
-
if not isinstance(session_trace_dict, dict):
|
|
3981
|
-
# If no session_trace key, assume "full" format where trace itself is the session_trace
|
|
3982
|
-
if "session_id" in trace_namespace:
|
|
3983
|
-
session_trace_dict = trace_namespace
|
|
3984
|
-
else:
|
|
3985
|
-
raise RuntimeError(
|
|
3986
|
-
"The 'synth-ai eval' command requires 'session_trace' in the trace payload or a valid full trace format. "
|
|
3987
|
-
"Ensure the task app is using tracing_v3 and returning structured trace data."
|
|
3988
|
-
)
|
|
3989
|
-
metrics = data.get("metrics") if isinstance(data.get("metrics"), dict) else None
|
|
3990
|
-
if metrics:
|
|
3991
|
-
mean_return = metrics.get("mean_return") or metrics.get("total_reward")
|
|
3992
|
-
outcome = metrics.get("outcome_score")
|
|
3993
|
-
if mean_return is not None:
|
|
3994
|
-
summary.append(f"mean_return={mean_return}")
|
|
3995
|
-
if outcome is not None:
|
|
3996
|
-
summary.append(f"outcome={outcome}")
|
|
3997
|
-
try:
|
|
3998
|
-
val = float(outcome)
|
|
3999
|
-
outcome_sum += val
|
|
4000
|
-
outcome_count += 1
|
|
4001
|
-
if val >= 0.5:
|
|
4002
|
-
outcome_correct += 1
|
|
4003
|
-
except Exception:
|
|
4004
|
-
pass
|
|
4005
|
-
trajs = (
|
|
4006
|
-
data.get("trajectories")
|
|
4007
|
-
if isinstance(data.get("trajectories"), list)
|
|
4008
|
-
else None
|
|
4009
|
-
)
|
|
4010
|
-
if trajs:
|
|
4011
|
-
first = trajs[0] if trajs else None
|
|
4012
|
-
steps = first.get("steps") if isinstance(first, dict) else None
|
|
4013
|
-
if isinstance(steps, list) and steps:
|
|
4014
|
-
step0 = steps[0]
|
|
4015
|
-
tool_calls = step0.get("tool_calls") or step0.get("tools") or []
|
|
4016
|
-
if isinstance(tool_calls, list):
|
|
4017
|
-
summary.append(f"tool_calls={len(tool_calls)}")
|
|
4018
|
-
obs = step0.get("obs") if isinstance(step0, dict) else None
|
|
4019
|
-
if isinstance(obs, dict):
|
|
4020
|
-
idx_val = obs.get("prompt_index")
|
|
4021
|
-
if isinstance(idx_val, int):
|
|
4022
|
-
prompt_index = idx_val
|
|
4023
|
-
prompt_raw = obs.get("prompt")
|
|
4024
|
-
if isinstance(prompt_raw, str):
|
|
4025
|
-
prompt_text = prompt_raw
|
|
4026
|
-
if task_id is None:
|
|
4027
|
-
candidate_id = obs.get("task_id")
|
|
4028
|
-
if isinstance(candidate_id, str) and candidate_id:
|
|
4029
|
-
task_id = candidate_id
|
|
4030
|
-
if task_split is None:
|
|
4031
|
-
candidate_split = obs.get("task_split")
|
|
4032
|
-
if isinstance(candidate_split, str) and candidate_split:
|
|
4033
|
-
task_split = candidate_split
|
|
4034
|
-
if task_rubric_id is None:
|
|
4035
|
-
candidate_rid = obs.get("task_rubric_id")
|
|
4036
|
-
if isinstance(candidate_rid, str) and candidate_rid:
|
|
4037
|
-
task_rubric_id = candidate_rid
|
|
4038
|
-
final = first.get("final") if isinstance(first, dict) else None
|
|
4039
|
-
if isinstance(final, dict):
|
|
4040
|
-
final_obs = final.get("observation")
|
|
4041
|
-
if isinstance(final_obs, dict):
|
|
4042
|
-
comp_val = final_obs.get("completion")
|
|
4043
|
-
if isinstance(comp_val, str):
|
|
4044
|
-
completion = comp_val
|
|
4045
|
-
if task_id is None:
|
|
4046
|
-
candidate_id = final_obs.get("task_id")
|
|
4047
|
-
if isinstance(candidate_id, str) and candidate_id:
|
|
4048
|
-
task_id = candidate_id
|
|
4049
|
-
if task_split is None:
|
|
4050
|
-
candidate_split = final_obs.get("task_split")
|
|
4051
|
-
if isinstance(candidate_split, str) and candidate_split:
|
|
4052
|
-
task_split = candidate_split
|
|
4053
|
-
if task_rubric_id is None:
|
|
4054
|
-
candidate_rid = final_obs.get("task_rubric_id")
|
|
4055
|
-
if isinstance(candidate_rid, str) and candidate_rid:
|
|
4056
|
-
task_rubric_id = candidate_rid
|
|
4057
|
-
final_info = final.get("info")
|
|
4058
|
-
if isinstance(final_info, dict):
|
|
4059
|
-
if task_id is None:
|
|
4060
|
-
candidate_id = final_info.get("task_id")
|
|
4061
|
-
if isinstance(candidate_id, str) and candidate_id:
|
|
4062
|
-
task_id = candidate_id
|
|
4063
|
-
if task_split is None:
|
|
4064
|
-
candidate_split = final_info.get("task_split")
|
|
4065
|
-
if isinstance(candidate_split, str) and candidate_split:
|
|
4066
|
-
task_split = candidate_split
|
|
4067
|
-
if task_rubric_id is None:
|
|
4068
|
-
candidate_rid = final_info.get("task_rubric_id")
|
|
4069
|
-
if isinstance(candidate_rid, str) and candidate_rid:
|
|
4070
|
-
task_rubric_id = candidate_rid
|
|
4071
|
-
if task_id:
|
|
4072
|
-
summary.append(f"task_id={task_id}")
|
|
4073
|
-
click.echo(" ".join(summary))
|
|
4074
|
-
with contextlib.suppress(Exception):
|
|
4075
|
-
click.echo(json.dumps(data, indent=2))
|
|
4076
|
-
else:
|
|
4077
|
-
click.echo(" ".join(summary))
|
|
4078
|
-
|
|
4079
|
-
official_score = None
|
|
4080
|
-
if isinstance(metrics, dict):
|
|
4081
|
-
for key in ("mean_return", "total_reward", "outcome_score"):
|
|
4082
|
-
val = metrics.get(key)
|
|
4083
|
-
if isinstance(val, int | float):
|
|
4084
|
-
official_score = float(val)
|
|
4085
|
-
break
|
|
4086
|
-
if official_score is None and isinstance(data, dict):
|
|
4087
|
-
try:
|
|
4088
|
-
reward_val = data["trajectories"][0]["steps"][0].get("reward")
|
|
4089
|
-
if isinstance(reward_val, int | float):
|
|
4090
|
-
official_score = float(reward_val)
|
|
4091
|
-
except Exception:
|
|
4092
|
-
pass
|
|
4093
|
-
|
|
4094
|
-
if official_score is not None:
|
|
4095
|
-
if official_score < 0.0:
|
|
4096
|
-
official_score = 0.0
|
|
4097
|
-
elif official_score > 1.0:
|
|
4098
|
-
official_score = min(1.0, official_score)
|
|
4099
|
-
|
|
4100
|
-
judge_scores: dict[str, float | None] = {}
|
|
4101
|
-
judges_timings: dict[str, float | None] = {}
|
|
4102
|
-
timings: dict[str, Any] = {
|
|
4103
|
-
"rollout_s": rollout_elapsed,
|
|
4104
|
-
"judges": judges_timings,
|
|
4105
|
-
}
|
|
4106
|
-
if judge_specs:
|
|
4107
|
-
for spec in judge_specs:
|
|
4108
|
-
score_value: float | None = None
|
|
4109
|
-
judge_elapsed: float | None = None
|
|
4110
|
-
# Run judges for all tasks (text-based and trajectory-based)
|
|
4111
|
-
# Text-based tasks have completion, trajectory-based tasks use response
|
|
4112
|
-
judge_payload = {
|
|
4113
|
-
"seed": seed_val,
|
|
4114
|
-
"prompt_index": prompt_index,
|
|
4115
|
-
"prompt": prompt_text,
|
|
4116
|
-
"completion": completion,
|
|
4117
|
-
"metrics": metrics,
|
|
4118
|
-
"response": data,
|
|
4119
|
-
"trace": trace_namespace,
|
|
4120
|
-
}
|
|
4121
|
-
try:
|
|
4122
|
-
judge_start = time.perf_counter()
|
|
4123
|
-
result = spec.fn(judge_payload, **spec.kwargs)
|
|
4124
|
-
judge_elapsed = time.perf_counter() - judge_start
|
|
4125
|
-
if isinstance(result, int | float):
|
|
4126
|
-
score_value = float(result)
|
|
4127
|
-
except Exception as exc:
|
|
4128
|
-
if judge_elapsed is None:
|
|
4129
|
-
judge_elapsed = time.perf_counter() - judge_start
|
|
4130
|
-
click.echo(f"seed={seed_val} judge[{spec.name}]_error={exc}")
|
|
4131
|
-
judges_timings[spec.name] = judge_elapsed
|
|
4132
|
-
judge_scores[spec.name] = score_value
|
|
4133
|
-
|
|
4134
|
-
if trace_tracer is not None and trace_namespace:
|
|
4135
|
-
storage_metadata = {
|
|
4136
|
-
"eval_seed": seed_val,
|
|
4137
|
-
"prompt_index": prompt_index,
|
|
4138
|
-
"task_id": task_id,
|
|
4139
|
-
"task_split": task_split,
|
|
4140
|
-
"task_rubric_id": task_rubric_id,
|
|
4141
|
-
"official_score": official_score,
|
|
4142
|
-
"judge_scores": judge_scores,
|
|
4143
|
-
"model": selected_model,
|
|
4144
|
-
"prompt": prompt_text,
|
|
4145
|
-
"completion": completion,
|
|
4146
|
-
}
|
|
4147
|
-
await _store_trace(trace_tracer, trace_namespace, storage_metadata)
|
|
4148
|
-
|
|
4149
|
-
records.append(
|
|
4150
|
-
{
|
|
4151
|
-
"seed": seed_val,
|
|
4152
|
-
"prompt_index": prompt_index,
|
|
4153
|
-
"task_id": task_id,
|
|
4154
|
-
"task_split": task_split,
|
|
4155
|
-
"task_rubric_id": task_rubric_id,
|
|
4156
|
-
"official_score": official_score,
|
|
4157
|
-
"judge_scores": judge_scores,
|
|
4158
|
-
"timings": timings,
|
|
4159
|
-
}
|
|
4160
|
-
)
|
|
4161
|
-
|
|
4162
|
-
await asyncio.gather(*[_run_seed(seed_val) for seed_val in seed_values])
|
|
4163
|
-
finally:
|
|
4164
|
-
await async_client.aclose()
|
|
4165
|
-
|
|
4166
|
-
try:
|
|
4167
|
-
asyncio.run(_run_eval())
|
|
4168
|
-
finally:
|
|
4169
|
-
if trace_tracer is not None and trace_tracer.db is not None:
|
|
4170
|
-
asyncio.run(trace_tracer.db.close())
|
|
4171
|
-
|
|
4172
|
-
click.echo(
|
|
4173
|
-
f"Eval complete: {successes} ok, {failures} failed; model={selected_model}, split={split}"
|
|
4174
|
-
)
|
|
4175
|
-
|
|
4176
|
-
if outcome_count > 0:
|
|
4177
|
-
mean_outcome = outcome_sum / float(outcome_count)
|
|
4178
|
-
frac_right = outcome_correct / float(outcome_count)
|
|
4179
|
-
click.echo(
|
|
4180
|
-
f"Outcome summary: correct={outcome_correct}/{outcome_count} ({frac_right:.2%}), mean_outcome={mean_outcome:.3f}"
|
|
4181
|
-
)
|
|
4182
|
-
|
|
4183
|
-
if records:
|
|
4184
|
-
judge_specs = judge_specs or [] # ensure iterable
|
|
4185
|
-
official_scores = [
|
|
4186
|
-
r["official_score"] for r in records if r["official_score"] is not None
|
|
4187
|
-
]
|
|
4188
|
-
if official_scores:
|
|
4189
|
-
click.echo(f" Official mean: {sum(official_scores) / len(official_scores):.3f}")
|
|
4190
|
-
else:
|
|
4191
|
-
click.echo(" Official mean: n/a")
|
|
4192
|
-
|
|
4193
|
-
for spec in judge_specs:
|
|
4194
|
-
spec_scores = [
|
|
4195
|
-
record["judge_scores"].get(spec.name)
|
|
4196
|
-
for record in records
|
|
4197
|
-
if record["judge_scores"].get(spec.name) is not None
|
|
4198
|
-
]
|
|
4199
|
-
if spec_scores:
|
|
4200
|
-
mean_spec = sum(spec_scores) / len(spec_scores)
|
|
4201
|
-
click.echo(f" [{spec.name}] mean: {mean_spec:.3f}")
|
|
4202
|
-
else:
|
|
4203
|
-
click.echo(f" [{spec.name}] mean: n/a")
|
|
4204
|
-
|
|
4205
|
-
paired = [
|
|
4206
|
-
(
|
|
4207
|
-
record["official_score"],
|
|
4208
|
-
record["judge_scores"].get(spec.name),
|
|
4209
|
-
)
|
|
4210
|
-
for record in records
|
|
4211
|
-
if record["official_score"] is not None
|
|
4212
|
-
and record["judge_scores"].get(spec.name) is not None
|
|
4213
|
-
]
|
|
4214
|
-
if len(paired) >= 2:
|
|
4215
|
-
corr = _pearson(
|
|
4216
|
-
[p[0] for p in paired if p[0] is not None],
|
|
4217
|
-
[p[1] for p in paired if p[1] is not None],
|
|
4218
|
-
)
|
|
4219
|
-
if corr is not None:
|
|
4220
|
-
click.echo(f" Pearson r: {corr:.3f}")
|
|
4221
|
-
else:
|
|
4222
|
-
click.echo(" Pearson r: undefined (zero variance)")
|
|
4223
|
-
else:
|
|
4224
|
-
click.echo(" Pearson r: n/a (need ≥2 paired scores)")
|
|
4225
|
-
|
|
4226
|
-
header = ["Seed", "Prompt", "Official"]
|
|
4227
|
-
header.extend(spec.name for spec in judge_specs)
|
|
4228
|
-
rows: list[list[str]] = []
|
|
4229
|
-
for record in sorted(records, key=lambda r: (r["seed"], r.get("prompt_index") or -1)):
|
|
4230
|
-
seed_val = str(record["seed"])
|
|
4231
|
-
prompt_idx = (
|
|
4232
|
-
str(record["prompt_index"])
|
|
4233
|
-
if record["prompt_index"] is not None
|
|
4234
|
-
else "-"
|
|
4235
|
-
)
|
|
4236
|
-
official_val = (
|
|
4237
|
-
f"{record['official_score']:.3f}"
|
|
4238
|
-
if record["official_score"] is not None
|
|
4239
|
-
else "-"
|
|
4240
|
-
)
|
|
4241
|
-
row = [seed_val, prompt_idx, official_val]
|
|
4242
|
-
for spec in judge_specs:
|
|
4243
|
-
score_val = record["judge_scores"].get(spec.name)
|
|
4244
|
-
row.append(f"{score_val:.3f}" if isinstance(score_val, int | float) else "-")
|
|
4245
|
-
rows.append(row)
|
|
4246
|
-
|
|
4247
|
-
widths = [len(col) for col in header]
|
|
4248
|
-
for row in rows:
|
|
4249
|
-
for idx, cell in enumerate(row):
|
|
4250
|
-
widths[idx] = max(widths[idx], len(cell))
|
|
4251
|
-
|
|
4252
|
-
click.echo("")
|
|
4253
|
-
click.echo(" ".join(h.ljust(widths[idx]) for idx, h in enumerate(header)))
|
|
4254
|
-
click.echo(" ".join("-" * widths[idx] for idx in range(len(header))))
|
|
4255
|
-
for row in rows:
|
|
4256
|
-
click.echo(" ".join(cell.ljust(widths[idx]) for idx, cell in enumerate(row)))
|
|
4257
|
-
|
|
4258
|
-
|
|
4259
|
-
|
|
4260
|
-
@click.command(
|
|
4261
|
-
"filter",
|
|
4262
|
-
help="Export filtered tracing sessions to SFT-ready JSONL based on a TOML config.",
|
|
4263
|
-
)
|
|
4264
|
-
@click.option(
|
|
4265
|
-
"--config",
|
|
4266
|
-
"config_path",
|
|
4267
|
-
type=click.Path(),
|
|
4268
|
-
required=True,
|
|
4269
|
-
help="Path to TOML config describing the input trace DB, score thresholds, and output JSONL.",
|
|
4270
|
-
)
|
|
4271
|
-
def filter_command(config_path: str) -> None:
|
|
4272
|
-
"""Render tracing sessions that match filter rules into SFT JSONL.
|
|
4273
|
-
|
|
4274
|
-
The TOML file should contain a `[filter]` table with at least:
|
|
4275
|
-
|
|
4276
|
-
db = \"path/to/traces.db\" # sqlite path or URL (sqlite+aiosqlite://...)
|
|
4277
|
-
output = \"ft_data/out.jsonl\" # destination JSONL
|
|
4278
|
-
|
|
4279
|
-
Optional keys such as `splits`, `task_ids`, `models`, `min_official_score`, or
|
|
4280
|
-
`min_judge_scores.my_judge = 0.7` allow you to narrow the dataset down to
|
|
4281
|
-
high-quality traces. See `customers/agora_single_file/configs/filter_local.toml`
|
|
4282
|
-
for a working example.
|
|
4283
|
-
"""
|
|
4284
|
-
# Parse and validate TOML config
|
|
4285
|
-
from synth_ai.task.config import FilterConfig
|
|
4286
|
-
|
|
4287
|
-
if _toml is None:
|
|
4288
|
-
raise click.ClickException("TOML parser not available; install tomli or use Python 3.11+")
|
|
4289
|
-
|
|
4290
|
-
cfg_path = Path(config_path)
|
|
4291
|
-
if not cfg_path.exists():
|
|
4292
|
-
raise click.ClickException(f"Filter config not found: {cfg_path}")
|
|
4293
|
-
|
|
4294
|
-
try:
|
|
4295
|
-
config_data = _toml.loads(cfg_path.read_text(encoding="utf-8"))
|
|
4296
|
-
except Exception as exc:
|
|
4297
|
-
raise click.ClickException(f"Failed to parse TOML '{cfg_path}': {exc}") from exc
|
|
4298
|
-
|
|
4299
|
-
filter_cfg_dict = config_data.get("filter") if isinstance(config_data, dict) else None
|
|
4300
|
-
if not isinstance(filter_cfg_dict, dict):
|
|
4301
|
-
raise click.ClickException("Config must contain a [filter] table")
|
|
4302
|
-
|
|
4303
|
-
# Validate config with dataclass
|
|
4304
|
-
try:
|
|
4305
|
-
filter_cfg = FilterConfig.from_dict(filter_cfg_dict)
|
|
4306
|
-
click.echo(f"✓ Config validated: db={filter_cfg.db}, output={filter_cfg.output}")
|
|
4307
|
-
if filter_cfg.min_official_score is not None:
|
|
4308
|
-
click.echo(f" → Filtering for official score >= {filter_cfg.min_official_score}")
|
|
4309
|
-
if filter_cfg.limit:
|
|
4310
|
-
click.echo(f" → Limiting to {filter_cfg.limit} examples")
|
|
4311
|
-
except (ValueError, TypeError) as validation_error:
|
|
4312
|
-
raise click.ClickException(f"Invalid filter config: {validation_error}") from validation_error
|
|
4313
|
-
|
|
4314
|
-
# Use validated config
|
|
4315
|
-
db_url = filter_cfg.get_db_url()
|
|
4316
|
-
output_path = filter_cfg.get_output_path()
|
|
4317
|
-
|
|
4318
|
-
# Extract validated fields from dataclass
|
|
4319
|
-
splits = set(filter_cfg.splits)
|
|
4320
|
-
task_ids = set(filter_cfg.task_ids)
|
|
4321
|
-
models = set(filter_cfg.models)
|
|
4322
|
-
min_official = filter_cfg.min_official_score
|
|
4323
|
-
max_official = filter_cfg.max_official_score
|
|
4324
|
-
min_judge_scores = filter_cfg.min_judge_scores
|
|
4325
|
-
max_judge_scores = filter_cfg.max_judge_scores
|
|
4326
|
-
# Note: min_created_at and max_created_at not yet in FilterConfig dataclass
|
|
4327
|
-
min_created = _parse_datetime_for_trace(filter_cfg_dict.get("min_created_at"))
|
|
4328
|
-
max_created = _parse_datetime_for_trace(filter_cfg_dict.get("max_created_at"))
|
|
4329
|
-
limit = filter_cfg.limit
|
|
4330
|
-
|
|
4331
|
-
def _score_ok(value: Any, min_val: Any, max_val: Any) -> bool:
|
|
4332
|
-
try:
|
|
4333
|
-
if value is None:
|
|
4334
|
-
return min_val is None
|
|
4335
|
-
value = float(value)
|
|
4336
|
-
except Exception:
|
|
4337
|
-
return False
|
|
4338
|
-
if min_val is not None and value < float(min_val):
|
|
4339
|
-
return False
|
|
4340
|
-
return not (max_val is not None and value > float(max_val))
|
|
4341
|
-
|
|
4342
|
-
async def _run_filter() -> None:
|
|
4343
|
-
tracer = SessionTracer(db_url=db_url, auto_save=False)
|
|
4344
|
-
await tracer.initialize()
|
|
4345
|
-
|
|
4346
|
-
df = await tracer.db.query_traces(
|
|
4347
|
-
"SELECT session_id, created_at, metadata FROM session_traces ORDER BY created_at"
|
|
4348
|
-
)
|
|
4349
|
-
if getattr(df, "empty", True):
|
|
4350
|
-
raise click.ClickException("No traces found in database")
|
|
4351
|
-
|
|
4352
|
-
sessions = df.to_dict("records")
|
|
4353
|
-
accepted: list[dict[str, Any]] = []
|
|
4354
|
-
|
|
4355
|
-
for row in sessions:
|
|
4356
|
-
metadata_raw = row.get("metadata")
|
|
4357
|
-
if isinstance(metadata_raw, str):
|
|
4358
|
-
try:
|
|
4359
|
-
metadata = json.loads(metadata_raw)
|
|
4360
|
-
except Exception:
|
|
4361
|
-
metadata = {}
|
|
4362
|
-
elif isinstance(metadata_raw, dict):
|
|
4363
|
-
metadata = dict(metadata_raw)
|
|
4364
|
-
else:
|
|
4365
|
-
metadata = {}
|
|
4366
|
-
|
|
4367
|
-
created_at_raw = row.get("created_at")
|
|
4368
|
-
created_at_dt = _parse_datetime_for_trace(created_at_raw)
|
|
4369
|
-
|
|
4370
|
-
session_id = row.get("session_id")
|
|
4371
|
-
|
|
4372
|
-
if splits and metadata.get("task_split") not in splits:
|
|
4373
|
-
continue
|
|
4374
|
-
if task_ids and metadata.get("task_id") not in task_ids:
|
|
4375
|
-
continue
|
|
4376
|
-
if models and metadata.get("model") not in models:
|
|
4377
|
-
continue
|
|
4378
|
-
|
|
4379
|
-
if min_created and (created_at_dt is None or created_at_dt < min_created):
|
|
4380
|
-
continue
|
|
4381
|
-
if max_created and (created_at_dt is None or created_at_dt > max_created):
|
|
4382
|
-
continue
|
|
4383
|
-
|
|
4384
|
-
# Check against outcome_rewards if score filter is set
|
|
4385
|
-
total_reward = None
|
|
4386
|
-
achievements_count = None
|
|
4387
|
-
if min_official is not None or max_official is not None:
|
|
4388
|
-
reward_query = "SELECT total_reward, achievements_count FROM outcome_rewards WHERE session_id = :session_id"
|
|
4389
|
-
reward_rows = await tracer.db.query_traces(reward_query, {"session_id": session_id})
|
|
4390
|
-
reward_records = reward_rows.to_dict("records") if hasattr(reward_rows, "to_dict") else []
|
|
4391
|
-
if reward_records:
|
|
4392
|
-
total_reward = reward_records[0].get("total_reward")
|
|
4393
|
-
achievements_count = reward_records[0].get("achievements_count")
|
|
4394
|
-
if not _score_ok(total_reward, min_official, max_official):
|
|
4395
|
-
continue
|
|
4396
|
-
elif min_official is not None:
|
|
4397
|
-
# No reward found, but score filter requires it
|
|
4398
|
-
continue
|
|
4399
|
-
|
|
4400
|
-
judge_scores = metadata.get("judge_scores") or {}
|
|
4401
|
-
include = True
|
|
4402
|
-
for judge_name, threshold in (min_judge_scores or {}).items():
|
|
4403
|
-
if not _score_ok(judge_scores.get(judge_name), threshold, None):
|
|
4404
|
-
include = False
|
|
4405
|
-
break
|
|
4406
|
-
if not include:
|
|
4407
|
-
continue
|
|
4408
|
-
for judge_name, threshold in (max_judge_scores or {}).items():
|
|
4409
|
-
if not _score_ok(judge_scores.get(judge_name), None, threshold):
|
|
4410
|
-
include = False
|
|
4411
|
-
break
|
|
4412
|
-
if not include:
|
|
4413
|
-
continue
|
|
4414
|
-
|
|
4415
|
-
# Query messages for this session
|
|
4416
|
-
messages_query = """
|
|
4417
|
-
SELECT message_type, content, timestamp
|
|
4418
|
-
FROM messages
|
|
4419
|
-
WHERE session_id = :session_id
|
|
4420
|
-
ORDER BY timestamp ASC, id ASC
|
|
4421
|
-
"""
|
|
4422
|
-
msg_df = await tracer.db.query_traces(messages_query, {"session_id": session_id})
|
|
4423
|
-
message_rows = msg_df.to_dict("records") if hasattr(msg_df, "to_dict") else []
|
|
4424
|
-
|
|
4425
|
-
if not message_rows:
|
|
4426
|
-
# Fallback: check if prompt/completion in metadata (old format)
|
|
4427
|
-
prompt = metadata.get("prompt") or ""
|
|
4428
|
-
completion = metadata.get("completion") or ""
|
|
4429
|
-
if prompt and completion:
|
|
4430
|
-
record = {
|
|
4431
|
-
"messages": [
|
|
4432
|
-
{"role": "user", "content": str(prompt)},
|
|
4433
|
-
{"role": "assistant", "content": str(completion)},
|
|
4434
|
-
],
|
|
4435
|
-
"metadata": {
|
|
4436
|
-
"session_id": session_id,
|
|
4437
|
-
"env_name": metadata.get("env_name"),
|
|
4438
|
-
"policy_name": metadata.get("policy_name"),
|
|
4439
|
-
"seed": metadata.get("seed"),
|
|
4440
|
-
"total_reward": total_reward,
|
|
4441
|
-
"achievements_count": achievements_count,
|
|
4442
|
-
"model": metadata.get("model"),
|
|
4443
|
-
"created_at": created_at_dt.isoformat() if created_at_dt else created_at_raw,
|
|
4444
|
-
},
|
|
4445
|
-
}
|
|
4446
|
-
accepted.append(record)
|
|
4447
|
-
continue
|
|
4448
|
-
|
|
4449
|
-
# Extract user/assistant pairs from messages
|
|
4450
|
-
for i, msg_row in enumerate(message_rows):
|
|
4451
|
-
msg_type = msg_row.get("message_type")
|
|
4452
|
-
content_raw = msg_row.get("content")
|
|
4453
|
-
|
|
4454
|
-
# Look for user message
|
|
4455
|
-
if msg_type in ("user", "policy_user_prompt"):
|
|
4456
|
-
# Find next policy_system_prompt or assistant
|
|
4457
|
-
assistant_msg = None
|
|
4458
|
-
for j in range(i + 1, len(message_rows)):
|
|
4459
|
-
next_type = message_rows[j].get("message_type")
|
|
4460
|
-
if next_type in ("assistant", "policy_system_prompt"):
|
|
4461
|
-
if next_type == "assistant":
|
|
4462
|
-
assistant_msg = message_rows[j]
|
|
4463
|
-
break
|
|
4464
|
-
|
|
4465
|
-
# Parse content
|
|
4466
|
-
try:
|
|
4467
|
-
user_content = json.loads(content_raw) if isinstance(content_raw, str) else content_raw
|
|
4468
|
-
except Exception:
|
|
4469
|
-
user_content = content_raw
|
|
4470
|
-
|
|
4471
|
-
# If user_content is a message dict with a 'content' key, extract it
|
|
4472
|
-
if isinstance(user_content, dict) and "content" in user_content:
|
|
4473
|
-
user_content = user_content["content"]
|
|
4474
|
-
|
|
4475
|
-
# Extract text from structured content
|
|
4476
|
-
def extract_text(content: Any) -> str:
|
|
4477
|
-
if isinstance(content, str):
|
|
4478
|
-
return content
|
|
4479
|
-
if isinstance(content, dict):
|
|
4480
|
-
# Try payload.content for user prompts
|
|
4481
|
-
if "payload" in content and isinstance(content["payload"], dict):
|
|
4482
|
-
payload = content["payload"]
|
|
4483
|
-
if "content" in payload:
|
|
4484
|
-
return extract_text(payload["content"])
|
|
4485
|
-
# Try common keys
|
|
4486
|
-
for key in ["text", "content", "content_text"]:
|
|
4487
|
-
if key in content:
|
|
4488
|
-
val = content[key]
|
|
4489
|
-
if isinstance(val, str):
|
|
4490
|
-
return val
|
|
4491
|
-
return json.dumps(content)
|
|
4492
|
-
if isinstance(content, list):
|
|
4493
|
-
# Multimodal content - concatenate text parts
|
|
4494
|
-
parts = []
|
|
4495
|
-
for item in content:
|
|
4496
|
-
if isinstance(item, dict) and item.get("type") == "text":
|
|
4497
|
-
parts.append(item.get("text", ""))
|
|
4498
|
-
return " ".join(parts) if parts else str(content)
|
|
4499
|
-
return str(content)
|
|
4500
|
-
|
|
4501
|
-
user_text = extract_text(user_content)
|
|
4502
|
-
|
|
4503
|
-
# For assistant, we might not have it recorded, so use tool calls as completion
|
|
4504
|
-
assistant_text = ""
|
|
4505
|
-
assistant_content = None
|
|
4506
|
-
if assistant_msg:
|
|
4507
|
-
assistant_content_raw = assistant_msg.get("content")
|
|
4508
|
-
try:
|
|
4509
|
-
assistant_content = json.loads(assistant_content_raw) if isinstance(assistant_content_raw, str) else assistant_content_raw
|
|
4510
|
-
except Exception:
|
|
4511
|
-
assistant_content = assistant_content_raw
|
|
4512
|
-
|
|
4513
|
-
# If assistant_content is a message dict with a 'content' key, extract it
|
|
4514
|
-
if isinstance(assistant_content, dict) and "content" in assistant_content:
|
|
4515
|
-
assistant_content = assistant_content["content"]
|
|
4516
|
-
|
|
4517
|
-
assistant_text = extract_text(assistant_content)
|
|
4518
|
-
|
|
4519
|
-
if not user_text:
|
|
4520
|
-
continue
|
|
4521
|
-
|
|
4522
|
-
# Use full multimodal content if it's a list (contains images), otherwise use text
|
|
4523
|
-
user_content_for_message = user_content if isinstance(user_content, list) else user_text
|
|
4524
|
-
assistant_content_for_message = assistant_content if isinstance(assistant_content, list) else (assistant_text if assistant_text else "[no response recorded]")
|
|
4525
|
-
|
|
4526
|
-
record = {
|
|
4527
|
-
"messages": [
|
|
4528
|
-
{"role": "user", "content": user_content_for_message},
|
|
4529
|
-
{"role": "assistant", "content": assistant_content_for_message},
|
|
4530
|
-
],
|
|
4531
|
-
"metadata": {
|
|
4532
|
-
"session_id": session_id,
|
|
4533
|
-
"env_name": metadata.get("env_name"),
|
|
4534
|
-
"policy_name": metadata.get("policy_name"),
|
|
4535
|
-
"seed": metadata.get("seed"),
|
|
4536
|
-
"total_reward": total_reward,
|
|
4537
|
-
"achievements_count": achievements_count,
|
|
4538
|
-
"model": metadata.get("model"),
|
|
4539
|
-
"created_at": created_at_dt.isoformat() if created_at_dt else created_at_raw,
|
|
4540
|
-
},
|
|
4541
|
-
}
|
|
4542
|
-
accepted.append(record)
|
|
4543
|
-
|
|
4544
|
-
if not accepted:
|
|
4545
|
-
raise click.ClickException("No sessions matched the provided filters")
|
|
4546
|
-
|
|
4547
|
-
if limit is not None and limit > 0:
|
|
4548
|
-
accepted = accepted[:limit]
|
|
4549
|
-
|
|
4550
|
-
output_path.parent.mkdir(parents=True, exist_ok=True)
|
|
4551
|
-
with output_path.open("w", encoding="utf-8") as handle:
|
|
4552
|
-
for item in accepted:
|
|
4553
|
-
handle.write(json.dumps(item, ensure_ascii=False))
|
|
4554
|
-
handle.write("\n")
|
|
4555
|
-
|
|
4556
|
-
click.echo(f"Wrote {len(accepted)} examples -> {output_path}")
|
|
4557
|
-
await tracer.db.close()
|
|
3128
|
+
eval_command = eval_core.command
|
|
4558
3129
|
|
|
4559
|
-
|
|
3130
|
+
filter_command = filter_core.command
|
|
4560
3131
|
|
|
4561
3132
|
|
|
4562
3133
|
def register_eval(cli: click.Group) -> None:
|