synth-ai 0.2.13.dev2__py3-none-any.whl → 0.2.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/README.md +1 -0
- examples/multi_step/SFT_README.md +147 -0
- examples/multi_step/configs/README_verilog_rl.md +77 -0
- examples/multi_step/configs/VERILOG_REWARDS.md +90 -0
- examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +183 -0
- examples/multi_step/configs/crafter_eval_synth_qwen4b.toml +35 -0
- examples/multi_step/configs/crafter_eval_text_only_groq_qwen32b.toml +36 -0
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +12 -11
- examples/multi_step/configs/crafter_sft_qwen30b_lora.toml +62 -0
- examples/multi_step/configs/crafter_synth_backend.md +40 -0
- examples/multi_step/configs/verilog_eval_groq_qwen32b.toml +31 -0
- examples/multi_step/configs/verilog_eval_synth_qwen8b.toml +33 -0
- examples/multi_step/configs/verilog_rl_lora.toml +190 -0
- examples/multi_step/convert_traces_to_sft.py +84 -0
- examples/multi_step/judges/crafter_backend_judge.py +220 -0
- examples/multi_step/judges/verilog_backend_judge.py +234 -0
- examples/multi_step/readme.md +48 -0
- examples/multi_step/run_sft_qwen30b.sh +45 -0
- examples/multi_step/verilog_rl_lora.md +218 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +3 -2
- examples/qwen_coder/configs/coder_lora_4b.toml +2 -1
- examples/qwen_coder/configs/coder_lora_small.toml +2 -1
- examples/qwen_vl/BUGS_AND_FIXES.md +232 -0
- examples/qwen_vl/IMAGE_VALIDATION_COMPLETE.md +271 -0
- examples/qwen_vl/IMAGE_VALIDATION_SUMMARY.md +260 -0
- examples/qwen_vl/INFERENCE_SFT_TESTS.md +412 -0
- examples/qwen_vl/NEXT_STEPS_2B.md +325 -0
- examples/qwen_vl/QUICKSTART.md +327 -0
- examples/qwen_vl/QUICKSTART_RL_VISION.md +110 -0
- examples/qwen_vl/README.md +154 -0
- examples/qwen_vl/RL_VISION_COMPLETE.md +475 -0
- examples/qwen_vl/RL_VISION_TESTING.md +333 -0
- examples/qwen_vl/SDK_VISION_INTEGRATION.md +328 -0
- examples/qwen_vl/SETUP_COMPLETE.md +275 -0
- examples/qwen_vl/VISION_TESTS_COMPLETE.md +490 -0
- examples/qwen_vl/VLM_PIPELINE_COMPLETE.md +242 -0
- examples/qwen_vl/__init__.py +2 -0
- examples/qwen_vl/collect_data_via_cli.md +423 -0
- examples/qwen_vl/collect_vision_traces.py +368 -0
- examples/qwen_vl/configs/crafter_rl_vision_qwen3vl4b.toml +127 -0
- examples/qwen_vl/configs/crafter_vlm_sft_example.toml +60 -0
- examples/qwen_vl/configs/eval_gpt4o_mini_vision.toml +43 -0
- examples/qwen_vl/configs/eval_gpt4o_vision_proper.toml +29 -0
- examples/qwen_vl/configs/eval_gpt5nano_vision.toml +45 -0
- examples/qwen_vl/configs/eval_qwen2vl_vision.toml +44 -0
- examples/qwen_vl/configs/filter_qwen2vl_sft.toml +50 -0
- examples/qwen_vl/configs/filter_vision_sft.toml +53 -0
- examples/qwen_vl/configs/filter_vision_test.toml +8 -0
- examples/qwen_vl/configs/sft_qwen3_vl_2b_test.toml +54 -0
- examples/qwen_vl/crafter_gpt5nano_agent.py +308 -0
- examples/qwen_vl/crafter_qwen_vl_agent.py +300 -0
- examples/qwen_vl/run_vision_comparison.sh +62 -0
- examples/qwen_vl/run_vision_sft_pipeline.sh +175 -0
- examples/qwen_vl/test_image_validation.py +201 -0
- examples/qwen_vl/test_sft_vision_data.py +110 -0
- examples/rl/README.md +1 -1
- examples/rl/configs/eval_base_qwen.toml +17 -0
- examples/rl/configs/eval_rl_qwen.toml +13 -0
- examples/rl/configs/rl_from_base_qwen.toml +37 -0
- examples/rl/configs/rl_from_base_qwen17.toml +76 -0
- examples/rl/configs/rl_from_ft_qwen.toml +37 -0
- examples/rl/run_eval.py +436 -0
- examples/rl/run_rl_and_save.py +111 -0
- examples/rl/task_app/README.md +22 -0
- examples/rl/task_app/math_single_step.py +990 -0
- examples/rl/task_app/math_task_app.py +111 -0
- examples/sft/README.md +5 -5
- examples/sft/configs/crafter_fft_qwen0p6b.toml +4 -2
- examples/sft/configs/crafter_lora_qwen0p6b.toml +4 -3
- examples/sft/evaluate.py +4 -4
- examples/sft/export_dataset.py +7 -4
- examples/sft/generate_traces.py +2 -0
- examples/swe/task_app/README.md +1 -1
- examples/swe/task_app/grpo_swe_mini.py +1 -1
- examples/swe/task_app/grpo_swe_mini_task_app.py +0 -12
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +13 -13
- examples/swe/task_app/hosted/policy_routes.py +0 -2
- examples/swe/task_app/hosted/rollout.py +2 -8
- examples/task_apps/IMAGE_ONLY_EVAL_QUICKSTART.md +258 -0
- examples/task_apps/crafter/CREATE_SFT_DATASET.md +273 -0
- examples/task_apps/crafter/EVAL_IMAGE_ONLY_RESULTS.md +152 -0
- examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +174 -0
- examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +268 -0
- examples/task_apps/crafter/QUERY_EXAMPLES.md +203 -0
- examples/task_apps/crafter/README_IMAGE_ONLY_EVAL.md +316 -0
- examples/task_apps/crafter/eval_image_only_gpt4o.toml +28 -0
- examples/task_apps/crafter/eval_text_only_groq_llama.toml +36 -0
- examples/task_apps/crafter/filter_sft_dataset.toml +16 -0
- examples/task_apps/crafter/task_app/__init__.py +3 -0
- examples/task_apps/crafter/task_app/grpo_crafter.py +309 -14
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/environment.py +10 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +75 -4
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/react_agent.py +17 -2
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +55 -3
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +114 -32
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +127 -27
- examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +156 -0
- examples/task_apps/enron/__init__.py +1 -0
- examples/task_apps/enron/filter_sft.toml +5 -0
- examples/task_apps/enron/tests/__init__.py +2 -0
- examples/task_apps/enron/tests/integration/__init__.py +2 -0
- examples/task_apps/enron/tests/integration/test_enron_eval.py +2 -0
- examples/task_apps/enron/tests/unit/__init__.py +2 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_COMPLETE.md +283 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_STATUS.md +155 -0
- examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +415 -0
- examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +29 -0
- examples/task_apps/pokemon_red/pallet_town_rl_config.toml +2 -0
- examples/task_apps/pokemon_red/task_app.py +199 -6
- examples/task_apps/pokemon_red/test_pallet_town_rewards.py +2 -0
- examples/task_apps/sokoban/filter_sft.toml +5 -0
- examples/task_apps/sokoban/tests/__init__.py +2 -0
- examples/task_apps/sokoban/tests/integration/__init__.py +2 -0
- examples/task_apps/sokoban/tests/unit/__init__.py +2 -0
- examples/task_apps/verilog/eval_groq_qwen32b.toml +8 -4
- examples/task_apps/verilog/filter_sft.toml +5 -0
- examples/task_apps/verilog/task_app/grpo_verilog.py +258 -23
- examples/task_apps/verilog/tests/__init__.py +2 -0
- examples/task_apps/verilog/tests/integration/__init__.py +2 -0
- examples/task_apps/verilog/tests/integration/test_verilog_eval.py +2 -0
- examples/task_apps/verilog/tests/unit/__init__.py +2 -0
- examples/vlm/README.md +3 -3
- examples/vlm/configs/crafter_vlm_gpt4o.toml +2 -0
- examples/vlm/crafter_openai_vlm_agent.py +3 -5
- examples/vlm/filter_image_rows.py +1 -1
- examples/vlm/run_crafter_vlm_benchmark.py +2 -2
- examples/warming_up_to_rl/_utils.py +92 -0
- examples/warming_up_to_rl/analyze_trace_db.py +1 -1
- examples/warming_up_to_rl/configs/crafter_fft.toml +2 -0
- examples/warming_up_to_rl/configs/crafter_fft_4b.toml +2 -0
- examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +2 -0
- examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +2 -0
- examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +2 -1
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +2 -1
- examples/warming_up_to_rl/configs/rl_from_ft.toml +2 -0
- examples/warming_up_to_rl/export_trace_sft.py +174 -60
- examples/warming_up_to_rl/groq_test.py +2 -0
- examples/warming_up_to_rl/readme.md +63 -132
- examples/warming_up_to_rl/run_fft_and_save.py +1 -1
- examples/warming_up_to_rl/run_local_rollout.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_modal.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_parallel.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_traced.py +2 -0
- examples/warming_up_to_rl/run_rl_and_save.py +1 -1
- examples/warming_up_to_rl/run_rollout_remote.py +2 -0
- examples/warming_up_to_rl/task_app/README.md +42 -0
- examples/warming_up_to_rl/task_app/grpo_crafter.py +696 -0
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +135 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +143 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1226 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +522 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +478 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +108 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +305 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +204 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +618 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +100 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +1081 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +195 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1861 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +211 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +161 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +137 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +62 -0
- synth_ai/__init__.py +44 -30
- synth_ai/_utils/__init__.py +47 -0
- synth_ai/_utils/base_url.py +10 -0
- synth_ai/_utils/http.py +10 -0
- synth_ai/_utils/prompts.py +10 -0
- synth_ai/_utils/task_app_state.py +12 -0
- synth_ai/_utils/user_config.py +10 -0
- synth_ai/api/models/supported.py +145 -7
- synth_ai/api/train/__init__.py +13 -1
- synth_ai/api/train/cli.py +30 -7
- synth_ai/api/train/config_finder.py +18 -11
- synth_ai/api/train/env_resolver.py +13 -10
- synth_ai/cli/__init__.py +66 -49
- synth_ai/cli/_modal_wrapper.py +9 -6
- synth_ai/cli/_typer_patch.py +0 -2
- synth_ai/cli/_validate_task_app.py +22 -4
- synth_ai/cli/legacy_root_backup.py +3 -1
- synth_ai/cli/lib/__init__.py +10 -0
- synth_ai/cli/lib/task_app_discovery.py +7 -0
- synth_ai/cli/lib/task_app_env.py +518 -0
- synth_ai/cli/recent.py +1 -0
- synth_ai/cli/setup.py +266 -0
- synth_ai/cli/task_app_deploy.py +16 -0
- synth_ai/cli/task_app_list.py +25 -0
- synth_ai/cli/task_app_modal_serve.py +16 -0
- synth_ai/cli/task_app_serve.py +18 -0
- synth_ai/cli/task_apps.py +392 -141
- synth_ai/cli/train.py +18 -0
- synth_ai/cli/tui.py +62 -0
- synth_ai/demos/__init__.py +10 -0
- synth_ai/demos/core/__init__.py +28 -1
- synth_ai/demos/crafter/__init__.py +1 -0
- synth_ai/demos/crafter/crafter_fft_4b.toml +55 -0
- synth_ai/demos/crafter/grpo_crafter_task_app.py +185 -0
- synth_ai/demos/crafter/rl_from_base_qwen4b.toml +74 -0
- synth_ai/demos/demo_registry.py +176 -0
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +1 -1
- synth_ai/demos/math/__init__.py +1 -0
- synth_ai/demos/math/_common.py +16 -0
- synth_ai/demos/math/app.py +38 -0
- synth_ai/demos/math/config.toml +76 -0
- synth_ai/demos/math/deploy_modal.py +54 -0
- synth_ai/demos/math/modal_task_app.py +702 -0
- synth_ai/demos/math/task_app_entry.py +51 -0
- synth_ai/environments/environment/core.py +7 -1
- synth_ai/environments/examples/bandit/engine.py +0 -1
- synth_ai/environments/examples/bandit/environment.py +0 -1
- synth_ai/environments/examples/crafter_classic/environment.py +1 -1
- synth_ai/environments/examples/verilog/engine.py +76 -10
- synth_ai/environments/examples/wordle/environment.py +0 -1
- synth_ai/evals/base.py +16 -5
- synth_ai/evals/client.py +1 -1
- synth_ai/inference/client.py +1 -1
- synth_ai/learning/client.py +1 -1
- synth_ai/learning/health.py +1 -1
- synth_ai/learning/jobs.py +1 -1
- synth_ai/learning/rl/client.py +1 -1
- synth_ai/learning/rl/env_keys.py +1 -1
- synth_ai/learning/rl/secrets.py +1 -1
- synth_ai/learning/sft/client.py +1 -1
- synth_ai/learning/sft/data.py +407 -4
- synth_ai/learning/validators.py +4 -1
- synth_ai/task/__init__.py +11 -1
- synth_ai/task/apps/__init__.py +5 -2
- synth_ai/task/config.py +259 -0
- synth_ai/task/contracts.py +15 -2
- synth_ai/task/rubrics/__init__.py +4 -2
- synth_ai/task/rubrics/loaders.py +27 -4
- synth_ai/task/rubrics/scoring.py +3 -0
- synth_ai/task/rubrics.py +219 -0
- synth_ai/task/trace_correlation_helpers.py +328 -0
- synth_ai/task/tracing_utils.py +14 -3
- synth_ai/task/validators.py +145 -2
- synth_ai/tracing_v3/config.py +15 -13
- synth_ai/tracing_v3/constants.py +21 -0
- synth_ai/tracing_v3/db_config.py +3 -1
- synth_ai/tracing_v3/decorators.py +10 -7
- synth_ai/tracing_v3/session_tracer.py +10 -0
- synth_ai/tracing_v3/turso/daemon.py +2 -2
- synth_ai/tracing_v3/turso/native_manager.py +108 -77
- synth_ai/tracing_v3/utils.py +1 -1
- synth_ai/tui/__init__.py +5 -0
- synth_ai/tui/__main__.py +13 -0
- synth_ai/tui/cli/__init__.py +1 -0
- synth_ai/tui/cli/query_experiments.py +164 -0
- synth_ai/tui/cli/query_experiments_v3.py +164 -0
- synth_ai/tui/dashboard.py +911 -0
- synth_ai/utils/__init__.py +101 -0
- synth_ai/utils/base_url.py +94 -0
- synth_ai/utils/cli.py +131 -0
- synth_ai/utils/env.py +287 -0
- synth_ai/utils/http.py +169 -0
- synth_ai/utils/modal.py +308 -0
- synth_ai/utils/process.py +212 -0
- synth_ai/utils/prompts.py +39 -0
- synth_ai/utils/sqld.py +122 -0
- synth_ai/utils/task_app_discovery.py +882 -0
- synth_ai/utils/task_app_env.py +186 -0
- synth_ai/utils/task_app_state.py +318 -0
- synth_ai/utils/user_config.py +137 -0
- synth_ai/v0/config/__init__.py +1 -5
- synth_ai/v0/config/base_url.py +1 -7
- synth_ai/v0/tracing/config.py +1 -1
- synth_ai/v0/tracing/decorators.py +1 -1
- synth_ai/v0/tracing/upload.py +1 -1
- synth_ai/v0/tracing_v1/config.py +1 -1
- synth_ai/v0/tracing_v1/decorators.py +1 -1
- synth_ai/v0/tracing_v1/upload.py +1 -1
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.16.dist-info}/METADATA +85 -31
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.16.dist-info}/RECORD +286 -135
- synth_ai/cli/man.py +0 -106
- synth_ai/compound/cais.py +0 -0
- synth_ai/core/experiment.py +0 -13
- synth_ai/core/system.py +0 -15
- synth_ai/demo_registry.py +0 -295
- synth_ai/handshake.py +0 -109
- synth_ai/http.py +0 -26
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.16.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.16.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.16.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.16.dist-info}/top_level.txt +0 -0
synth_ai/cli/task_apps.py
CHANGED
|
@@ -36,19 +36,29 @@ except Exception: # pragma: no cover - fallback
|
|
|
36
36
|
import click
|
|
37
37
|
from click.exceptions import Abort
|
|
38
38
|
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
39
|
+
# Tracing imports - make conditional for optional dependencies
|
|
40
|
+
try:
|
|
41
|
+
from synth_ai.tracing_v3 import ( # type: ignore[import-untyped]
|
|
42
|
+
BaseEvent,
|
|
43
|
+
EnvironmentEvent,
|
|
44
|
+
RuntimeEvent,
|
|
45
|
+
SessionEventMarkovBlanketMessage,
|
|
46
|
+
SessionMessageContent,
|
|
47
|
+
SessionTimeStep,
|
|
48
|
+
SessionTracer,
|
|
49
|
+
TimeRecord,
|
|
50
|
+
)
|
|
51
|
+
from synth_ai.tracing_v3 import ( # type: ignore[import-untyped]
|
|
52
|
+
SessionTrace as V3SessionTrace,
|
|
53
|
+
)
|
|
54
|
+
_TRACING_AVAILABLE = True
|
|
55
|
+
except (ImportError, ModuleNotFoundError, TypeError):
|
|
56
|
+
# Tracing system not available (missing optional dependencies)
|
|
57
|
+
BaseEvent = EnvironmentEvent = RuntimeEvent = None # type: ignore
|
|
58
|
+
SessionEventMarkovBlanketMessage = SessionMessageContent = None # type: ignore
|
|
59
|
+
SessionTimeStep = SessionTracer = TimeRecord = None # type: ignore
|
|
60
|
+
V3SessionTrace = None # type: ignore
|
|
61
|
+
_TRACING_AVAILABLE = False
|
|
52
62
|
|
|
53
63
|
# ---------------------------------------------------------------------------
|
|
54
64
|
# Dynamic imports to avoid hard dependencies during type checking.
|
|
@@ -231,6 +241,24 @@ def _event_from_dict(payload: dict[str, Any]) -> BaseEvent:
|
|
|
231
241
|
system_state_after=payload.get("system_state_after"),
|
|
232
242
|
**base_kwargs,
|
|
233
243
|
)
|
|
244
|
+
# Check for LM CAIS event fields
|
|
245
|
+
if any(key in payload for key in ("model_name", "provider", "call_records")):
|
|
246
|
+
from synth_ai.tracing_v3.abstractions import LMCAISEvent
|
|
247
|
+
# Note: call_records are left as dicts - the storage layer will handle serialization
|
|
248
|
+
call_records = payload.get("call_records") or []
|
|
249
|
+
return LMCAISEvent(
|
|
250
|
+
model_name=payload.get("model_name", ""),
|
|
251
|
+
provider=payload.get("provider", ""),
|
|
252
|
+
input_tokens=payload.get("input_tokens"),
|
|
253
|
+
output_tokens=payload.get("output_tokens"),
|
|
254
|
+
total_tokens=payload.get("total_tokens"),
|
|
255
|
+
cost_usd=payload.get("cost_usd"),
|
|
256
|
+
latency_ms=payload.get("latency_ms"),
|
|
257
|
+
span_id=payload.get("span_id"),
|
|
258
|
+
trace_id=payload.get("trace_id"),
|
|
259
|
+
call_records=call_records,
|
|
260
|
+
**base_kwargs,
|
|
261
|
+
)
|
|
234
262
|
return BaseEvent(**base_kwargs)
|
|
235
263
|
|
|
236
264
|
|
|
@@ -320,21 +348,51 @@ async def _store_trace(
|
|
|
320
348
|
trace_namespace: dict[str, Any] | None,
|
|
321
349
|
extra_metadata: dict[str, Any] | None = None,
|
|
322
350
|
):
|
|
351
|
+
import logging
|
|
352
|
+
_logger = logging.getLogger(__name__)
|
|
353
|
+
|
|
354
|
+
_logger.info(f"[STORE_TRACE_DEBUG] Called with tracer={tracer is not None}, trace_namespace={trace_namespace is not None}")
|
|
355
|
+
|
|
323
356
|
if tracer is None or not isinstance(trace_namespace, dict):
|
|
357
|
+
_logger.warning(f"[STORE_TRACE_DEBUG] Early return: tracer={tracer is not None}, trace_namespace type={type(trace_namespace)}")
|
|
324
358
|
return
|
|
359
|
+
|
|
360
|
+
_logger.info(f"[STORE_TRACE_DEBUG] trace_namespace keys: {list(trace_namespace.keys())}")
|
|
361
|
+
|
|
362
|
+
# Handle both formats:
|
|
363
|
+
# - With session_trace key: {"session_trace": {...}}
|
|
364
|
+
# - Without session_trace key (trace itself is the session): {"session_id": ..., "markov_blanket_message_history": ...}
|
|
325
365
|
session_payload = trace_namespace.get("session_trace")
|
|
326
366
|
if not isinstance(session_payload, dict):
|
|
327
|
-
|
|
367
|
+
# If no session_trace key, assume "full" format where trace itself is the session_trace
|
|
368
|
+
if "session_id" in trace_namespace:
|
|
369
|
+
session_payload = trace_namespace
|
|
370
|
+
_logger.info("[STORE_TRACE_DEBUG] Using trace_namespace directly as session_payload (no session_trace key)")
|
|
371
|
+
else:
|
|
372
|
+
_logger.warning(f"[STORE_TRACE_DEBUG] No session_trace found or wrong type: {type(session_payload)}")
|
|
373
|
+
return
|
|
374
|
+
|
|
375
|
+
_logger.info(f"[STORE_TRACE_DEBUG] session_payload keys: {list(session_payload.keys())}")
|
|
376
|
+
msg_count = len(session_payload.get("markov_blanket_message_history", []))
|
|
377
|
+
_logger.info(f"[STORE_TRACE_DEBUG] Found {msg_count} messages in session_payload")
|
|
378
|
+
|
|
328
379
|
trace_obj = _session_trace_from_dict(session_payload)
|
|
329
380
|
if trace_obj is None:
|
|
381
|
+
_logger.warning("[STORE_TRACE_DEBUG] _session_trace_from_dict returned None")
|
|
330
382
|
return
|
|
383
|
+
|
|
384
|
+
_logger.info(f"[STORE_TRACE_DEBUG] Created SessionTrace object with {len(trace_obj.markov_blanket_message_history)} messages")
|
|
385
|
+
|
|
331
386
|
if tracer.db is None:
|
|
332
387
|
await tracer.initialize()
|
|
333
388
|
meta = dict(trace_obj.metadata or {})
|
|
334
389
|
if extra_metadata:
|
|
335
390
|
meta.update(extra_metadata)
|
|
336
391
|
trace_obj.metadata = meta
|
|
392
|
+
|
|
393
|
+
_logger.info(f"[STORE_TRACE_DEBUG] Calling insert_session_trace for session_id={trace_obj.session_id}")
|
|
337
394
|
await tracer.db.insert_session_trace(trace_obj)
|
|
395
|
+
_logger.info("[STORE_TRACE_DEBUG] Successfully inserted trace")
|
|
338
396
|
|
|
339
397
|
def _temporary_sys_path(paths: Sequence[Path]):
|
|
340
398
|
"""Context manager to prepend entries to sys.path temporarily."""
|
|
@@ -881,43 +939,43 @@ def _build_modal_config_from_ast(modal_call: ast.Call) -> ModalDeploymentConfigT
|
|
|
881
939
|
for kw in modal_call.keywords:
|
|
882
940
|
if kw.arg and isinstance(kw.value, ast.Constant):
|
|
883
941
|
kwargs[kw.arg] = kw.value.value
|
|
884
|
-
elif kw.arg == "pip_packages" and isinstance(kw.value,
|
|
942
|
+
elif kw.arg == "pip_packages" and isinstance(kw.value, ast.List | ast.Tuple):
|
|
885
943
|
# Handle pip_packages list/tuple
|
|
886
944
|
packages: list[str] = []
|
|
887
945
|
value_node = kw.value
|
|
888
|
-
if isinstance(value_node,
|
|
946
|
+
if isinstance(value_node, ast.List | ast.Tuple):
|
|
889
947
|
for elt in value_node.elts:
|
|
890
948
|
if isinstance(elt, ast.Constant):
|
|
891
949
|
packages.append(elt.value)
|
|
892
950
|
kwargs[kw.arg] = tuple(packages)
|
|
893
|
-
elif kw.arg == "extra_local_dirs" and isinstance(kw.value,
|
|
951
|
+
elif kw.arg == "extra_local_dirs" and isinstance(kw.value, ast.List | ast.Tuple):
|
|
894
952
|
# Handle extra_local_dirs list/tuple of tuples
|
|
895
953
|
dirs = []
|
|
896
954
|
value_node = kw.value
|
|
897
|
-
if isinstance(value_node,
|
|
955
|
+
if isinstance(value_node, ast.List | ast.Tuple):
|
|
898
956
|
for elt in value_node.elts:
|
|
899
|
-
if isinstance(elt,
|
|
957
|
+
if isinstance(elt, ast.List | ast.Tuple) and len(elt.elts) == 2:
|
|
900
958
|
src = elt.elts[0].value if isinstance(elt.elts[0], ast.Constant) else None
|
|
901
959
|
dst = elt.elts[1].value if isinstance(elt.elts[1], ast.Constant) else None
|
|
902
960
|
if src and dst:
|
|
903
961
|
dirs.append((src, dst))
|
|
904
962
|
kwargs[kw.arg] = tuple(dirs)
|
|
905
|
-
elif kw.arg == "secret_names" and isinstance(kw.value,
|
|
963
|
+
elif kw.arg == "secret_names" and isinstance(kw.value, ast.List | ast.Tuple):
|
|
906
964
|
# Handle secret_names list/tuple
|
|
907
965
|
secrets = []
|
|
908
966
|
value_node = kw.value
|
|
909
|
-
if isinstance(value_node,
|
|
967
|
+
if isinstance(value_node, ast.List | ast.Tuple):
|
|
910
968
|
for elt in value_node.elts:
|
|
911
969
|
if isinstance(elt, ast.Constant):
|
|
912
970
|
secrets.append(elt.value)
|
|
913
971
|
kwargs[kw.arg] = tuple(secrets)
|
|
914
|
-
elif kw.arg == "volume_mounts" and isinstance(kw.value,
|
|
972
|
+
elif kw.arg == "volume_mounts" and isinstance(kw.value, ast.List | ast.Tuple):
|
|
915
973
|
# Handle volume_mounts list/tuple of tuples
|
|
916
974
|
mounts = []
|
|
917
975
|
value_node = kw.value
|
|
918
|
-
if isinstance(value_node,
|
|
976
|
+
if isinstance(value_node, ast.List | ast.Tuple):
|
|
919
977
|
for elt in value_node.elts:
|
|
920
|
-
if isinstance(elt,
|
|
978
|
+
if isinstance(elt, ast.List | ast.Tuple) and len(elt.elts) == 2:
|
|
921
979
|
name = elt.elts[0].value if isinstance(elt.elts[0], ast.Constant) else None
|
|
922
980
|
mount = elt.elts[1].value if isinstance(elt.elts[1], ast.Constant) else None
|
|
923
981
|
if name and mount:
|
|
@@ -2213,7 +2271,7 @@ def validate_task_app_cmd(
|
|
|
2213
2271
|
import time
|
|
2214
2272
|
|
|
2215
2273
|
# Import the validate_task_app function defined in this module
|
|
2216
|
-
from
|
|
2274
|
+
from ._validate_task_app import validate_task_app # type: ignore[attr-defined]
|
|
2217
2275
|
|
|
2218
2276
|
proc = None
|
|
2219
2277
|
task_app_url = url
|
|
@@ -3044,6 +3102,11 @@ def _write_modal_entrypoint(
|
|
|
3044
3102
|
if not any(str(p).startswith("synth-ai") for p in pip_packages):
|
|
3045
3103
|
pip_packages.insert(0, synth_pkg)
|
|
3046
3104
|
|
|
3105
|
+
apt_packages = list(modal_cfg.apt_packages)
|
|
3106
|
+
click.echo(f"[DEBUG] modal_cfg.apt_packages type: {type(modal_cfg.apt_packages)}")
|
|
3107
|
+
click.echo(f"[DEBUG] modal_cfg.apt_packages value: {modal_cfg.apt_packages}")
|
|
3108
|
+
click.echo(f"[DEBUG] apt_packages after list(): {apt_packages}")
|
|
3109
|
+
|
|
3047
3110
|
local_dirs = [(str(Path(src)), dst) for src, dst in modal_cfg.extra_local_dirs]
|
|
3048
3111
|
# Also mount the host synth_ai source if available to ensure latest code is used
|
|
3049
3112
|
if host_synth is not None:
|
|
@@ -3090,6 +3153,15 @@ INLINE_SECRET_VALUES = {inline_secret_values!r}
|
|
|
3090
3153
|
|
|
3091
3154
|
image = Image.debian_slim(python_version={modal_cfg.python_version!r})
|
|
3092
3155
|
|
|
3156
|
+
# CRITICAL: Install iverilog for Verilog task app (hardcoded to prevent config issues)
|
|
3157
|
+
if {entry.app_id!r} == "grpo-verilog":
|
|
3158
|
+
image = image.apt_install("iverilog")
|
|
3159
|
+
|
|
3160
|
+
# Install apt packages first (before pip)
|
|
3161
|
+
apt_packages = {apt_packages!r}
|
|
3162
|
+
if apt_packages:
|
|
3163
|
+
image = image.apt_install(*apt_packages)
|
|
3164
|
+
|
|
3093
3165
|
pip_packages = {pip_packages!r}
|
|
3094
3166
|
if pip_packages:
|
|
3095
3167
|
image = image.pip_install(*pip_packages)
|
|
@@ -3251,7 +3323,7 @@ def register(cli: click.Group) -> None:
|
|
|
3251
3323
|
)
|
|
3252
3324
|
@click.option(
|
|
3253
3325
|
"--trace-db",
|
|
3254
|
-
default="traces/v3/
|
|
3326
|
+
default="traces/v3/synth_ai.db",
|
|
3255
3327
|
show_default=True,
|
|
3256
3328
|
help="SQLite/Turso URL for storing rollout traces set to 'none' to disable persistence.",
|
|
3257
3329
|
)
|
|
@@ -3284,8 +3356,13 @@ def eval_command(
|
|
|
3284
3356
|
pointing at a remote `--url`, supply matching `--env-file` values so the CLI can
|
|
3285
3357
|
forward authentication headers to the running service.
|
|
3286
3358
|
"""
|
|
3359
|
+
# Parse and validate TOML config
|
|
3360
|
+
from synth_ai.task.config import EvalConfig
|
|
3361
|
+
|
|
3287
3362
|
cfg: dict[str, Any] = {}
|
|
3363
|
+
eval_cfg: EvalConfig | None = None
|
|
3288
3364
|
config_path: Path | None = None
|
|
3365
|
+
|
|
3289
3366
|
if config:
|
|
3290
3367
|
config_path = Path(config)
|
|
3291
3368
|
else:
|
|
@@ -3307,21 +3384,37 @@ def eval_command(
|
|
|
3307
3384
|
if isinstance(parsed, dict):
|
|
3308
3385
|
section = parsed.get("eval")
|
|
3309
3386
|
cfg = dict(section) if isinstance(section, dict) else dict(parsed)
|
|
3387
|
+
|
|
3388
|
+
# Validate config with dataclass
|
|
3389
|
+
try:
|
|
3390
|
+
eval_cfg = EvalConfig.from_dict(cfg)
|
|
3391
|
+
click.echo(f"✓ Config validated: {len(eval_cfg.seeds)} seeds, model={eval_cfg.model}")
|
|
3392
|
+
except (ValueError, TypeError) as validation_error:
|
|
3393
|
+
raise click.ClickException(f"Invalid eval config: {validation_error}") from validation_error
|
|
3394
|
+
except click.ClickException:
|
|
3395
|
+
raise
|
|
3310
3396
|
except Exception as exc:
|
|
3311
3397
|
raise click.ClickException(f"Failed to parse TOML '{config_path}': {exc}") from exc
|
|
3312
3398
|
|
|
3313
|
-
|
|
3399
|
+
# CLI args override config
|
|
3400
|
+
if eval_cfg:
|
|
3401
|
+
app_id = app_id or eval_cfg.app_id
|
|
3402
|
+
else:
|
|
3403
|
+
app_id = app_id or (cfg.get("app_id") if isinstance(cfg.get("app_id"), str) else None) # type: ignore
|
|
3314
3404
|
|
|
3315
3405
|
metadata_filters: dict[str, str] = {}
|
|
3316
|
-
|
|
3317
|
-
|
|
3318
|
-
|
|
3319
|
-
|
|
3320
|
-
|
|
3321
|
-
|
|
3322
|
-
|
|
3323
|
-
|
|
3324
|
-
|
|
3406
|
+
if eval_cfg:
|
|
3407
|
+
metadata_filters.update(eval_cfg.metadata)
|
|
3408
|
+
else:
|
|
3409
|
+
cfg_metadata = cfg.get("metadata")
|
|
3410
|
+
if isinstance(cfg_metadata, dict):
|
|
3411
|
+
for key, value in cfg_metadata.items():
|
|
3412
|
+
metadata_filters[str(key)] = str(value)
|
|
3413
|
+
elif isinstance(cfg_metadata, list):
|
|
3414
|
+
for item in cfg_metadata:
|
|
3415
|
+
if isinstance(item, str) and "=" in item:
|
|
3416
|
+
key, value = item.split("=", 1)
|
|
3417
|
+
metadata_filters[key.strip()] = value.strip()
|
|
3325
3418
|
|
|
3326
3419
|
for item in metadata or ():
|
|
3327
3420
|
if "=" not in item:
|
|
@@ -3334,11 +3427,14 @@ def eval_command(
|
|
|
3334
3427
|
metadata_filters[key] = value
|
|
3335
3428
|
|
|
3336
3429
|
metadata_sql_query: str | None = None
|
|
3337
|
-
|
|
3338
|
-
|
|
3339
|
-
|
|
3340
|
-
|
|
3341
|
-
|
|
3430
|
+
if eval_cfg and eval_cfg.metadata_sql:
|
|
3431
|
+
metadata_sql_query = eval_cfg.metadata_sql
|
|
3432
|
+
else:
|
|
3433
|
+
cfg_metadata_sql = cfg.get("metadata_sql")
|
|
3434
|
+
if isinstance(cfg_metadata_sql, dict):
|
|
3435
|
+
metadata_sql_query = cfg_metadata_sql.get("query") or cfg_metadata_sql.get("sql")
|
|
3436
|
+
elif isinstance(cfg_metadata_sql, str):
|
|
3437
|
+
metadata_sql_query = cfg_metadata_sql
|
|
3342
3438
|
|
|
3343
3439
|
if metadata_sql:
|
|
3344
3440
|
metadata_sql_query = metadata_sql
|
|
@@ -3780,18 +3876,52 @@ def eval_command(
|
|
|
3780
3876
|
|
|
3781
3877
|
async def _run_seed(seed_val: int) -> None:
|
|
3782
3878
|
nonlocal successes, failures, outcome_sum, outcome_count, outcome_correct, records
|
|
3879
|
+
# Read env_name and policy_name from config if available
|
|
3880
|
+
env_name = cfg.get("env_name") or (cfg.get("env", {}).get("env_name") if isinstance(cfg.get("env"), dict) else None)
|
|
3881
|
+
policy_name = cfg.get("policy_name") or (cfg.get("policy", {}).get("policy_name") if isinstance(cfg.get("policy"), dict) else None)
|
|
3882
|
+
env_config_overrides = cfg.get("env_config", {}) if isinstance(cfg.get("env_config"), dict) else {}
|
|
3883
|
+
policy_config_overrides = cfg.get("policy_config", {}) if isinstance(cfg.get("policy_config"), dict) else {}
|
|
3884
|
+
|
|
3885
|
+
# Debug: print config parsing
|
|
3886
|
+
if seed_val == 0:
|
|
3887
|
+
click.echo(f"[DEBUG] env_name from config: {env_name}")
|
|
3888
|
+
click.echo(f"[DEBUG] policy_name from config: {policy_name}")
|
|
3889
|
+
|
|
3890
|
+
# Generate default ops sequence if not provided
|
|
3891
|
+
max_llm_calls = policy_config_overrides.get("max_llm_calls", 10)
|
|
3892
|
+
ops_list = cfg.get("ops", [])
|
|
3893
|
+
if not ops_list:
|
|
3894
|
+
# Generate default "agent, env" pairs for max_llm_calls
|
|
3895
|
+
ops_list = ["agent", "env"] * int(max_llm_calls)
|
|
3896
|
+
|
|
3783
3897
|
body = {
|
|
3784
3898
|
"run_id": str(uuid.uuid4()),
|
|
3785
|
-
"env": {"config": {"split": split, "index": seed_val}, "seed": seed_val},
|
|
3899
|
+
"env": {"config": {"split": split, "index": seed_val, **env_config_overrides}, "seed": seed_val},
|
|
3786
3900
|
"policy": {
|
|
3787
|
-
"policy_name": selected_model,
|
|
3788
|
-
"config": {"model": selected_model, **policy_overrides},
|
|
3901
|
+
"policy_name": policy_name or selected_model,
|
|
3902
|
+
"config": {"model": selected_model, **policy_overrides, **policy_config_overrides},
|
|
3903
|
+
},
|
|
3904
|
+
"ops": ops_list,
|
|
3905
|
+
"record": {
|
|
3906
|
+
"return_trace": cfg.get("return_trace", True),
|
|
3907
|
+
"trace_format": cfg.get("trace_format", "structured"),
|
|
3789
3908
|
},
|
|
3790
|
-
"
|
|
3909
|
+
"mode": "eval", # RolloutMode.EVAL: use inference URLs as-is, no transformations
|
|
3791
3910
|
}
|
|
3911
|
+
if env_name:
|
|
3912
|
+
body["env"]["env_name"] = env_name
|
|
3913
|
+
|
|
3914
|
+
# Debug: print the body being sent
|
|
3915
|
+
if seed_val == 0:
|
|
3916
|
+
click.echo(f"[DEBUG] rollout body env: {body['env']}")
|
|
3917
|
+
click.echo(f"[DEBUG] rollout body policy: {body['policy']}")
|
|
3918
|
+
click.echo(f"[DEBUG] rollout body mode: {body.get('mode', 'NOT SET')}")
|
|
3792
3919
|
rollout_elapsed: float | None = None
|
|
3793
3920
|
rollout_start = time.perf_counter()
|
|
3794
3921
|
try:
|
|
3922
|
+
import logging
|
|
3923
|
+
_log = logging.getLogger(__name__)
|
|
3924
|
+
_log.info(f"[EVAL_BODY_DEBUG] Sending body with mode={body.get('mode')}")
|
|
3795
3925
|
async with semaphore:
|
|
3796
3926
|
response = await async_client.post("/rollout", json=body)
|
|
3797
3927
|
rollout_elapsed = time.perf_counter() - rollout_start
|
|
@@ -3812,6 +3942,10 @@ def eval_command(
|
|
|
3812
3942
|
data = response.json()
|
|
3813
3943
|
except Exception:
|
|
3814
3944
|
data = None
|
|
3945
|
+
|
|
3946
|
+
# Debug: print validation errors
|
|
3947
|
+
if response.status_code == 422 and data:
|
|
3948
|
+
click.echo(f"[DEBUG] 422 Validation Error: {data}")
|
|
3815
3949
|
|
|
3816
3950
|
metrics: dict[str, Any] | None = None
|
|
3817
3951
|
completion: str | None = None
|
|
@@ -3825,16 +3959,33 @@ def eval_command(
|
|
|
3825
3959
|
session_trace_dict: dict[str, Any] | None = None
|
|
3826
3960
|
|
|
3827
3961
|
if isinstance(data, dict):
|
|
3962
|
+
import logging
|
|
3963
|
+
_logger = logging.getLogger(__name__)
|
|
3964
|
+
_logger.info(f"[EVAL_DEBUG] Response data keys: {list(data.keys())}")
|
|
3965
|
+
if "detail" in data:
|
|
3966
|
+
_logger.error(f"[EVAL_DEBUG] Task app returned error: {data['detail']}")
|
|
3828
3967
|
trace_namespace = data.get("trace")
|
|
3968
|
+
_logger.info(f"[EVAL_DEBUG] trace_namespace type: {type(trace_namespace)}, value: {trace_namespace if not isinstance(trace_namespace, dict) else 'dict with keys: ' + str(list(trace_namespace.keys()) if trace_namespace else 'None')}")
|
|
3829
3969
|
if not isinstance(trace_namespace, dict):
|
|
3830
3970
|
raise RuntimeError(
|
|
3831
|
-
"
|
|
3971
|
+
"The 'synth-ai eval' command requires trace payloads in rollout responses. "
|
|
3972
|
+
"Ensure the rollout request includes 'trace_format': 'structured' and 'return_trace': true, "
|
|
3973
|
+
"and that task app tracing is enabled (TASKAPP_TRACING_ENABLED=1). "
|
|
3974
|
+
"Note: This is specific to the eval command - general rollout endpoints don't require traces."
|
|
3832
3975
|
)
|
|
3976
|
+
# Handle both "compact" and "full" trace formats:
|
|
3977
|
+
# - compact: trace_namespace contains {session_id, metadata, ...}
|
|
3978
|
+
# - full: trace_namespace IS the full session_trace dict
|
|
3833
3979
|
session_trace_dict = trace_namespace.get("session_trace")
|
|
3834
3980
|
if not isinstance(session_trace_dict, dict):
|
|
3835
|
-
|
|
3836
|
-
|
|
3837
|
-
|
|
3981
|
+
# If no session_trace key, assume "full" format where trace itself is the session_trace
|
|
3982
|
+
if "session_id" in trace_namespace:
|
|
3983
|
+
session_trace_dict = trace_namespace
|
|
3984
|
+
else:
|
|
3985
|
+
raise RuntimeError(
|
|
3986
|
+
"The 'synth-ai eval' command requires 'session_trace' in the trace payload or a valid full trace format. "
|
|
3987
|
+
"Ensure the task app is using tracing_v3 and returning structured trace data."
|
|
3988
|
+
)
|
|
3838
3989
|
metrics = data.get("metrics") if isinstance(data.get("metrics"), dict) else None
|
|
3839
3990
|
if metrics:
|
|
3840
3991
|
mean_return = metrics.get("mean_return") or metrics.get("total_reward")
|
|
@@ -3956,26 +4107,27 @@ def eval_command(
|
|
|
3956
4107
|
for spec in judge_specs:
|
|
3957
4108
|
score_value: float | None = None
|
|
3958
4109
|
judge_elapsed: float | None = None
|
|
3959
|
-
|
|
3960
|
-
|
|
3961
|
-
|
|
3962
|
-
|
|
3963
|
-
|
|
3964
|
-
|
|
3965
|
-
|
|
3966
|
-
|
|
3967
|
-
|
|
3968
|
-
|
|
3969
|
-
|
|
3970
|
-
|
|
3971
|
-
|
|
4110
|
+
# Run judges for all tasks (text-based and trajectory-based)
|
|
4111
|
+
# Text-based tasks have completion, trajectory-based tasks use response
|
|
4112
|
+
judge_payload = {
|
|
4113
|
+
"seed": seed_val,
|
|
4114
|
+
"prompt_index": prompt_index,
|
|
4115
|
+
"prompt": prompt_text,
|
|
4116
|
+
"completion": completion,
|
|
4117
|
+
"metrics": metrics,
|
|
4118
|
+
"response": data,
|
|
4119
|
+
"trace": trace_namespace,
|
|
4120
|
+
}
|
|
4121
|
+
try:
|
|
4122
|
+
judge_start = time.perf_counter()
|
|
4123
|
+
result = spec.fn(judge_payload, **spec.kwargs)
|
|
4124
|
+
judge_elapsed = time.perf_counter() - judge_start
|
|
4125
|
+
if isinstance(result, int | float):
|
|
4126
|
+
score_value = float(result)
|
|
4127
|
+
except Exception as exc:
|
|
4128
|
+
if judge_elapsed is None:
|
|
3972
4129
|
judge_elapsed = time.perf_counter() - judge_start
|
|
3973
|
-
|
|
3974
|
-
score_value = float(result)
|
|
3975
|
-
except Exception as exc:
|
|
3976
|
-
if judge_elapsed is None:
|
|
3977
|
-
judge_elapsed = time.perf_counter() - judge_start
|
|
3978
|
-
click.echo(f"seed={seed_val} judge[{spec.name}]_error={exc}")
|
|
4130
|
+
click.echo(f"seed={seed_val} judge[{spec.name}]_error={exc}")
|
|
3979
4131
|
judges_timings[spec.name] = judge_elapsed
|
|
3980
4132
|
judge_scores[spec.name] = score_value
|
|
3981
4133
|
|
|
@@ -4129,6 +4281,9 @@ def filter_command(config_path: str) -> None:
|
|
|
4129
4281
|
high-quality traces. See `customers/agora_single_file/configs/filter_local.toml`
|
|
4130
4282
|
for a working example.
|
|
4131
4283
|
"""
|
|
4284
|
+
# Parse and validate TOML config
|
|
4285
|
+
from synth_ai.task.config import FilterConfig
|
|
4286
|
+
|
|
4132
4287
|
if _toml is None:
|
|
4133
4288
|
raise click.ClickException("TOML parser not available; install tomli or use Python 3.11+")
|
|
4134
4289
|
|
|
@@ -4141,58 +4296,37 @@ def filter_command(config_path: str) -> None:
|
|
|
4141
4296
|
except Exception as exc:
|
|
4142
4297
|
raise click.ClickException(f"Failed to parse TOML '{cfg_path}': {exc}") from exc
|
|
4143
4298
|
|
|
4144
|
-
|
|
4145
|
-
if not isinstance(
|
|
4299
|
+
filter_cfg_dict = config_data.get("filter") if isinstance(config_data, dict) else None
|
|
4300
|
+
if not isinstance(filter_cfg_dict, dict):
|
|
4146
4301
|
raise click.ClickException("Config must contain a [filter] table")
|
|
4147
4302
|
|
|
4148
|
-
|
|
4149
|
-
if not db_value:
|
|
4150
|
-
raise click.ClickException("filter.db must be provided")
|
|
4151
|
-
if "://" in db_value:
|
|
4152
|
-
db_url = db_value
|
|
4153
|
-
else:
|
|
4154
|
-
db_path = Path(db_value).expanduser()
|
|
4155
|
-
db_path.parent.mkdir(parents=True, exist_ok=True)
|
|
4156
|
-
db_url = f"sqlite+aiosqlite:///{db_path}"
|
|
4157
|
-
|
|
4158
|
-
output_value = filter_cfg.get("output")
|
|
4159
|
-
if not output_value:
|
|
4160
|
-
raise click.ClickException("filter.output must be provided")
|
|
4161
|
-
output_path = Path(str(output_value)).expanduser()
|
|
4162
|
-
|
|
4163
|
-
splits = set(filter_cfg.get("splits", []) or [])
|
|
4164
|
-
task_ids = set(filter_cfg.get("task_ids", []) or [])
|
|
4165
|
-
models = set(filter_cfg.get("models", []) or [])
|
|
4166
|
-
min_official = filter_cfg.get("min_official_score")
|
|
4167
|
-
max_official = filter_cfg.get("max_official_score")
|
|
4168
|
-
if min_official is not None:
|
|
4169
|
-
try:
|
|
4170
|
-
min_official = float(min_official)
|
|
4171
|
-
except Exception as err:
|
|
4172
|
-
raise click.ClickException("filter.min_official_score must be numeric") from err
|
|
4173
|
-
if max_official is not None:
|
|
4174
|
-
try:
|
|
4175
|
-
max_official = float(max_official)
|
|
4176
|
-
except Exception as err:
|
|
4177
|
-
raise click.ClickException("filter.max_official_score must be numeric") from err
|
|
4178
|
-
min_judge_scores = filter_cfg.get("min_judge_scores", {}) or {}
|
|
4179
|
-
max_judge_scores = filter_cfg.get("max_judge_scores", {}) or {}
|
|
4303
|
+
# Validate config with dataclass
|
|
4180
4304
|
try:
|
|
4181
|
-
|
|
4182
|
-
|
|
4183
|
-
|
|
4184
|
-
|
|
4185
|
-
|
|
4186
|
-
|
|
4187
|
-
|
|
4188
|
-
|
|
4189
|
-
|
|
4190
|
-
|
|
4191
|
-
|
|
4192
|
-
|
|
4193
|
-
|
|
4194
|
-
|
|
4195
|
-
|
|
4305
|
+
filter_cfg = FilterConfig.from_dict(filter_cfg_dict)
|
|
4306
|
+
click.echo(f"✓ Config validated: db={filter_cfg.db}, output={filter_cfg.output}")
|
|
4307
|
+
if filter_cfg.min_official_score is not None:
|
|
4308
|
+
click.echo(f" → Filtering for official score >= {filter_cfg.min_official_score}")
|
|
4309
|
+
if filter_cfg.limit:
|
|
4310
|
+
click.echo(f" → Limiting to {filter_cfg.limit} examples")
|
|
4311
|
+
except (ValueError, TypeError) as validation_error:
|
|
4312
|
+
raise click.ClickException(f"Invalid filter config: {validation_error}") from validation_error
|
|
4313
|
+
|
|
4314
|
+
# Use validated config
|
|
4315
|
+
db_url = filter_cfg.get_db_url()
|
|
4316
|
+
output_path = filter_cfg.get_output_path()
|
|
4317
|
+
|
|
4318
|
+
# Extract validated fields from dataclass
|
|
4319
|
+
splits = set(filter_cfg.splits)
|
|
4320
|
+
task_ids = set(filter_cfg.task_ids)
|
|
4321
|
+
models = set(filter_cfg.models)
|
|
4322
|
+
min_official = filter_cfg.min_official_score
|
|
4323
|
+
max_official = filter_cfg.max_official_score
|
|
4324
|
+
min_judge_scores = filter_cfg.min_judge_scores
|
|
4325
|
+
max_judge_scores = filter_cfg.max_judge_scores
|
|
4326
|
+
# Note: min_created_at and max_created_at not yet in FilterConfig dataclass
|
|
4327
|
+
min_created = _parse_datetime_for_trace(filter_cfg_dict.get("min_created_at"))
|
|
4328
|
+
max_created = _parse_datetime_for_trace(filter_cfg_dict.get("max_created_at"))
|
|
4329
|
+
limit = filter_cfg.limit
|
|
4196
4330
|
|
|
4197
4331
|
def _score_ok(value: Any, min_val: Any, max_val: Any) -> bool:
|
|
4198
4332
|
try:
|
|
@@ -4247,8 +4381,21 @@ def filter_command(config_path: str) -> None:
|
|
|
4247
4381
|
if max_created and (created_at_dt is None or created_at_dt > max_created):
|
|
4248
4382
|
continue
|
|
4249
4383
|
|
|
4250
|
-
if
|
|
4251
|
-
|
|
4384
|
+
# Check against outcome_rewards if score filter is set
|
|
4385
|
+
total_reward = None
|
|
4386
|
+
achievements_count = None
|
|
4387
|
+
if min_official is not None or max_official is not None:
|
|
4388
|
+
reward_query = "SELECT total_reward, achievements_count FROM outcome_rewards WHERE session_id = :session_id"
|
|
4389
|
+
reward_rows = await tracer.db.query_traces(reward_query, {"session_id": session_id})
|
|
4390
|
+
reward_records = reward_rows.to_dict("records") if hasattr(reward_rows, "to_dict") else []
|
|
4391
|
+
if reward_records:
|
|
4392
|
+
total_reward = reward_records[0].get("total_reward")
|
|
4393
|
+
achievements_count = reward_records[0].get("achievements_count")
|
|
4394
|
+
if not _score_ok(total_reward, min_official, max_official):
|
|
4395
|
+
continue
|
|
4396
|
+
elif min_official is not None:
|
|
4397
|
+
# No reward found, but score filter requires it
|
|
4398
|
+
continue
|
|
4252
4399
|
|
|
4253
4400
|
judge_scores = metadata.get("judge_scores") or {}
|
|
4254
4401
|
include = True
|
|
@@ -4265,30 +4412,134 @@ def filter_command(config_path: str) -> None:
|
|
|
4265
4412
|
if not include:
|
|
4266
4413
|
continue
|
|
4267
4414
|
|
|
4268
|
-
|
|
4269
|
-
|
|
4270
|
-
|
|
4415
|
+
# Query messages for this session
|
|
4416
|
+
messages_query = """
|
|
4417
|
+
SELECT message_type, content, timestamp
|
|
4418
|
+
FROM messages
|
|
4419
|
+
WHERE session_id = :session_id
|
|
4420
|
+
ORDER BY timestamp ASC, id ASC
|
|
4421
|
+
"""
|
|
4422
|
+
msg_df = await tracer.db.query_traces(messages_query, {"session_id": session_id})
|
|
4423
|
+
message_rows = msg_df.to_dict("records") if hasattr(msg_df, "to_dict") else []
|
|
4424
|
+
|
|
4425
|
+
if not message_rows:
|
|
4426
|
+
# Fallback: check if prompt/completion in metadata (old format)
|
|
4427
|
+
prompt = metadata.get("prompt") or ""
|
|
4428
|
+
completion = metadata.get("completion") or ""
|
|
4429
|
+
if prompt and completion:
|
|
4430
|
+
record = {
|
|
4431
|
+
"messages": [
|
|
4432
|
+
{"role": "user", "content": str(prompt)},
|
|
4433
|
+
{"role": "assistant", "content": str(completion)},
|
|
4434
|
+
],
|
|
4435
|
+
"metadata": {
|
|
4436
|
+
"session_id": session_id,
|
|
4437
|
+
"env_name": metadata.get("env_name"),
|
|
4438
|
+
"policy_name": metadata.get("policy_name"),
|
|
4439
|
+
"seed": metadata.get("seed"),
|
|
4440
|
+
"total_reward": total_reward,
|
|
4441
|
+
"achievements_count": achievements_count,
|
|
4442
|
+
"model": metadata.get("model"),
|
|
4443
|
+
"created_at": created_at_dt.isoformat() if created_at_dt else created_at_raw,
|
|
4444
|
+
},
|
|
4445
|
+
}
|
|
4446
|
+
accepted.append(record)
|
|
4271
4447
|
continue
|
|
4272
4448
|
|
|
4273
|
-
|
|
4274
|
-
|
|
4275
|
-
|
|
4276
|
-
|
|
4277
|
-
|
|
4278
|
-
|
|
4279
|
-
|
|
4280
|
-
|
|
4281
|
-
|
|
4282
|
-
|
|
4283
|
-
|
|
4284
|
-
|
|
4285
|
-
|
|
4286
|
-
|
|
4287
|
-
|
|
4288
|
-
|
|
4289
|
-
|
|
4290
|
-
|
|
4291
|
-
|
|
4449
|
+
# Extract user/assistant pairs from messages
|
|
4450
|
+
for i, msg_row in enumerate(message_rows):
|
|
4451
|
+
msg_type = msg_row.get("message_type")
|
|
4452
|
+
content_raw = msg_row.get("content")
|
|
4453
|
+
|
|
4454
|
+
# Look for user message
|
|
4455
|
+
if msg_type in ("user", "policy_user_prompt"):
|
|
4456
|
+
# Find next policy_system_prompt or assistant
|
|
4457
|
+
assistant_msg = None
|
|
4458
|
+
for j in range(i + 1, len(message_rows)):
|
|
4459
|
+
next_type = message_rows[j].get("message_type")
|
|
4460
|
+
if next_type in ("assistant", "policy_system_prompt"):
|
|
4461
|
+
if next_type == "assistant":
|
|
4462
|
+
assistant_msg = message_rows[j]
|
|
4463
|
+
break
|
|
4464
|
+
|
|
4465
|
+
# Parse content
|
|
4466
|
+
try:
|
|
4467
|
+
user_content = json.loads(content_raw) if isinstance(content_raw, str) else content_raw
|
|
4468
|
+
except Exception:
|
|
4469
|
+
user_content = content_raw
|
|
4470
|
+
|
|
4471
|
+
# If user_content is a message dict with a 'content' key, extract it
|
|
4472
|
+
if isinstance(user_content, dict) and "content" in user_content:
|
|
4473
|
+
user_content = user_content["content"]
|
|
4474
|
+
|
|
4475
|
+
# Extract text from structured content
|
|
4476
|
+
def extract_text(content: Any) -> str:
    """Flatten a stored message payload into plain text.

    Args:
        content: A decoded message body. Handled shapes are ``None``, a
            string, a dict, or a list of multimodal parts; anything else is
            stringified with ``str()``.

    Returns:
        The extracted text. ``None`` yields ``""`` (not the string
        ``"None"``) so that callers testing the result for emptiness can
        recognise a missing message body.
    """
    # A missing body must stay falsy; str(None) would produce "None".
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, dict):
        # Try payload.content first (shape used for user prompts).
        payload = content.get("payload")
        if isinstance(payload, dict) and "content" in payload:
            return extract_text(payload["content"])
        # Then the common text-bearing keys, in priority order.
        for key in ("text", "content", "content_text"):
            val = content.get(key)
            if isinstance(val, str):
                return val
        # No usable text field: fall back to a JSON dump, and to str() when
        # the dict holds values json cannot encode (or is self-referential),
        # so one odd record does not abort the caller.
        try:
            return json.dumps(content)
        except (TypeError, ValueError):
            return str(content)
    if isinstance(content, list):
        # Multimodal content - concatenate the text parts only.
        parts: list[str] = []
        for item in content:
            if isinstance(item, dict) and item.get("type") == "text":
                text = item.get("text")
                if text is None:
                    text = ""
                # Coerce non-string values so str.join cannot raise.
                parts.append(text if isinstance(text, str) else str(text))
        return " ".join(parts) if parts else str(content)
    return str(content)
|
|
4500
|
+
|
|
4501
|
+
user_text = extract_text(user_content)
|
|
4502
|
+
|
|
4503
|
+
# For assistant, we might not have it recorded, so use tool calls as completion
|
|
4504
|
+
assistant_text = ""
|
|
4505
|
+
assistant_content = None
|
|
4506
|
+
if assistant_msg:
|
|
4507
|
+
assistant_content_raw = assistant_msg.get("content")
|
|
4508
|
+
try:
|
|
4509
|
+
assistant_content = json.loads(assistant_content_raw) if isinstance(assistant_content_raw, str) else assistant_content_raw
|
|
4510
|
+
except Exception:
|
|
4511
|
+
assistant_content = assistant_content_raw
|
|
4512
|
+
|
|
4513
|
+
# If assistant_content is a message dict with a 'content' key, extract it
|
|
4514
|
+
if isinstance(assistant_content, dict) and "content" in assistant_content:
|
|
4515
|
+
assistant_content = assistant_content["content"]
|
|
4516
|
+
|
|
4517
|
+
assistant_text = extract_text(assistant_content)
|
|
4518
|
+
|
|
4519
|
+
if not user_text:
|
|
4520
|
+
continue
|
|
4521
|
+
|
|
4522
|
+
# Use full multimodal content if it's a list (contains images), otherwise use text
|
|
4523
|
+
user_content_for_message = user_content if isinstance(user_content, list) else user_text
|
|
4524
|
+
assistant_content_for_message = assistant_content if isinstance(assistant_content, list) else (assistant_text if assistant_text else "[no response recorded]")
|
|
4525
|
+
|
|
4526
|
+
record = {
|
|
4527
|
+
"messages": [
|
|
4528
|
+
{"role": "user", "content": user_content_for_message},
|
|
4529
|
+
{"role": "assistant", "content": assistant_content_for_message},
|
|
4530
|
+
],
|
|
4531
|
+
"metadata": {
|
|
4532
|
+
"session_id": session_id,
|
|
4533
|
+
"env_name": metadata.get("env_name"),
|
|
4534
|
+
"policy_name": metadata.get("policy_name"),
|
|
4535
|
+
"seed": metadata.get("seed"),
|
|
4536
|
+
"total_reward": total_reward,
|
|
4537
|
+
"achievements_count": achievements_count,
|
|
4538
|
+
"model": metadata.get("model"),
|
|
4539
|
+
"created_at": created_at_dt.isoformat() if created_at_dt else created_at_raw,
|
|
4540
|
+
},
|
|
4541
|
+
}
|
|
4542
|
+
accepted.append(record)
|
|
4292
4543
|
|
|
4293
4544
|
if not accepted:
|
|
4294
4545
|
raise click.ClickException("No sessions matched the provided filters")
|