synth-ai 0.2.13.dev2__py3-none-any.whl → 0.2.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/README.md +1 -0
- examples/multi_step/SFT_README.md +147 -0
- examples/multi_step/configs/README_verilog_rl.md +77 -0
- examples/multi_step/configs/VERILOG_REWARDS.md +90 -0
- examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +183 -0
- examples/multi_step/configs/crafter_eval_synth_qwen4b.toml +35 -0
- examples/multi_step/configs/crafter_eval_text_only_groq_qwen32b.toml +36 -0
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +12 -11
- examples/multi_step/configs/crafter_sft_qwen30b_lora.toml +62 -0
- examples/multi_step/configs/crafter_synth_backend.md +40 -0
- examples/multi_step/configs/verilog_eval_groq_qwen32b.toml +31 -0
- examples/multi_step/configs/verilog_eval_synth_qwen8b.toml +33 -0
- examples/multi_step/configs/verilog_rl_lora.toml +190 -0
- examples/multi_step/convert_traces_to_sft.py +84 -0
- examples/multi_step/judges/crafter_backend_judge.py +220 -0
- examples/multi_step/judges/verilog_backend_judge.py +234 -0
- examples/multi_step/readme.md +48 -0
- examples/multi_step/run_sft_qwen30b.sh +45 -0
- examples/multi_step/verilog_rl_lora.md +218 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +3 -2
- examples/qwen_coder/configs/coder_lora_4b.toml +2 -1
- examples/qwen_coder/configs/coder_lora_small.toml +2 -1
- examples/qwen_vl/BUGS_AND_FIXES.md +232 -0
- examples/qwen_vl/IMAGE_VALIDATION_COMPLETE.md +271 -0
- examples/qwen_vl/IMAGE_VALIDATION_SUMMARY.md +260 -0
- examples/qwen_vl/INFERENCE_SFT_TESTS.md +412 -0
- examples/qwen_vl/NEXT_STEPS_2B.md +325 -0
- examples/qwen_vl/QUICKSTART.md +327 -0
- examples/qwen_vl/QUICKSTART_RL_VISION.md +110 -0
- examples/qwen_vl/README.md +154 -0
- examples/qwen_vl/RL_VISION_COMPLETE.md +475 -0
- examples/qwen_vl/RL_VISION_TESTING.md +333 -0
- examples/qwen_vl/SDK_VISION_INTEGRATION.md +328 -0
- examples/qwen_vl/SETUP_COMPLETE.md +275 -0
- examples/qwen_vl/VISION_TESTS_COMPLETE.md +490 -0
- examples/qwen_vl/VLM_PIPELINE_COMPLETE.md +242 -0
- examples/qwen_vl/__init__.py +2 -0
- examples/qwen_vl/collect_data_via_cli.md +423 -0
- examples/qwen_vl/collect_vision_traces.py +368 -0
- examples/qwen_vl/configs/crafter_rl_vision_qwen3vl4b.toml +127 -0
- examples/qwen_vl/configs/crafter_vlm_sft_example.toml +60 -0
- examples/qwen_vl/configs/eval_gpt4o_mini_vision.toml +43 -0
- examples/qwen_vl/configs/eval_gpt4o_vision_proper.toml +29 -0
- examples/qwen_vl/configs/eval_gpt5nano_vision.toml +45 -0
- examples/qwen_vl/configs/eval_qwen2vl_vision.toml +44 -0
- examples/qwen_vl/configs/filter_qwen2vl_sft.toml +50 -0
- examples/qwen_vl/configs/filter_vision_sft.toml +53 -0
- examples/qwen_vl/configs/filter_vision_test.toml +8 -0
- examples/qwen_vl/configs/sft_qwen3_vl_2b_test.toml +54 -0
- examples/qwen_vl/crafter_gpt5nano_agent.py +308 -0
- examples/qwen_vl/crafter_qwen_vl_agent.py +300 -0
- examples/qwen_vl/run_vision_comparison.sh +62 -0
- examples/qwen_vl/run_vision_sft_pipeline.sh +175 -0
- examples/qwen_vl/test_image_validation.py +201 -0
- examples/qwen_vl/test_sft_vision_data.py +110 -0
- examples/rl/README.md +1 -1
- examples/rl/configs/eval_base_qwen.toml +17 -0
- examples/rl/configs/eval_rl_qwen.toml +13 -0
- examples/rl/configs/rl_from_base_qwen.toml +37 -0
- examples/rl/configs/rl_from_base_qwen17.toml +76 -0
- examples/rl/configs/rl_from_ft_qwen.toml +37 -0
- examples/rl/run_eval.py +436 -0
- examples/rl/run_rl_and_save.py +111 -0
- examples/rl/task_app/README.md +22 -0
- examples/rl/task_app/math_single_step.py +990 -0
- examples/rl/task_app/math_task_app.py +111 -0
- examples/sft/README.md +5 -5
- examples/sft/configs/crafter_fft_qwen0p6b.toml +4 -2
- examples/sft/configs/crafter_lora_qwen0p6b.toml +4 -3
- examples/sft/evaluate.py +4 -4
- examples/sft/export_dataset.py +7 -4
- examples/sft/generate_traces.py +2 -0
- examples/swe/task_app/README.md +1 -1
- examples/swe/task_app/grpo_swe_mini.py +1 -1
- examples/swe/task_app/grpo_swe_mini_task_app.py +0 -12
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +13 -13
- examples/swe/task_app/hosted/policy_routes.py +0 -2
- examples/swe/task_app/hosted/rollout.py +2 -8
- examples/task_apps/IMAGE_ONLY_EVAL_QUICKSTART.md +258 -0
- examples/task_apps/crafter/CREATE_SFT_DATASET.md +273 -0
- examples/task_apps/crafter/EVAL_IMAGE_ONLY_RESULTS.md +152 -0
- examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +174 -0
- examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +268 -0
- examples/task_apps/crafter/QUERY_EXAMPLES.md +203 -0
- examples/task_apps/crafter/README_IMAGE_ONLY_EVAL.md +316 -0
- examples/task_apps/crafter/eval_image_only_gpt4o.toml +28 -0
- examples/task_apps/crafter/eval_text_only_groq_llama.toml +36 -0
- examples/task_apps/crafter/filter_sft_dataset.toml +16 -0
- examples/task_apps/crafter/task_app/__init__.py +3 -0
- examples/task_apps/crafter/task_app/grpo_crafter.py +309 -14
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/environment.py +10 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +75 -4
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/react_agent.py +17 -2
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +55 -3
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +114 -32
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +127 -27
- examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +156 -0
- examples/task_apps/enron/__init__.py +1 -0
- examples/task_apps/enron/filter_sft.toml +5 -0
- examples/task_apps/enron/tests/__init__.py +2 -0
- examples/task_apps/enron/tests/integration/__init__.py +2 -0
- examples/task_apps/enron/tests/integration/test_enron_eval.py +2 -0
- examples/task_apps/enron/tests/unit/__init__.py +2 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_COMPLETE.md +283 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_STATUS.md +155 -0
- examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +415 -0
- examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +29 -0
- examples/task_apps/pokemon_red/pallet_town_rl_config.toml +2 -0
- examples/task_apps/pokemon_red/task_app.py +199 -6
- examples/task_apps/pokemon_red/test_pallet_town_rewards.py +2 -0
- examples/task_apps/sokoban/filter_sft.toml +5 -0
- examples/task_apps/sokoban/tests/__init__.py +2 -0
- examples/task_apps/sokoban/tests/integration/__init__.py +2 -0
- examples/task_apps/sokoban/tests/unit/__init__.py +2 -0
- examples/task_apps/verilog/eval_groq_qwen32b.toml +8 -4
- examples/task_apps/verilog/filter_sft.toml +5 -0
- examples/task_apps/verilog/task_app/grpo_verilog.py +258 -23
- examples/task_apps/verilog/tests/__init__.py +2 -0
- examples/task_apps/verilog/tests/integration/__init__.py +2 -0
- examples/task_apps/verilog/tests/integration/test_verilog_eval.py +2 -0
- examples/task_apps/verilog/tests/unit/__init__.py +2 -0
- examples/vlm/README.md +3 -3
- examples/vlm/configs/crafter_vlm_gpt4o.toml +2 -0
- examples/vlm/crafter_openai_vlm_agent.py +3 -5
- examples/vlm/filter_image_rows.py +1 -1
- examples/vlm/run_crafter_vlm_benchmark.py +2 -2
- examples/warming_up_to_rl/_utils.py +92 -0
- examples/warming_up_to_rl/analyze_trace_db.py +1 -1
- examples/warming_up_to_rl/configs/crafter_fft.toml +2 -0
- examples/warming_up_to_rl/configs/crafter_fft_4b.toml +2 -0
- examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +2 -0
- examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +2 -0
- examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +2 -1
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +2 -1
- examples/warming_up_to_rl/configs/rl_from_ft.toml +2 -0
- examples/warming_up_to_rl/export_trace_sft.py +174 -60
- examples/warming_up_to_rl/groq_test.py +2 -0
- examples/warming_up_to_rl/readme.md +63 -132
- examples/warming_up_to_rl/run_fft_and_save.py +1 -1
- examples/warming_up_to_rl/run_local_rollout.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_modal.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_parallel.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_traced.py +2 -0
- examples/warming_up_to_rl/run_rl_and_save.py +1 -1
- examples/warming_up_to_rl/run_rollout_remote.py +2 -0
- examples/warming_up_to_rl/task_app/README.md +42 -0
- examples/warming_up_to_rl/task_app/grpo_crafter.py +696 -0
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +135 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +143 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1226 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +522 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +478 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +108 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +305 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +204 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +618 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +100 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +1081 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +195 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1861 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +211 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +161 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +137 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +62 -0
- synth_ai/__init__.py +44 -30
- synth_ai/_utils/__init__.py +47 -0
- synth_ai/_utils/base_url.py +10 -0
- synth_ai/_utils/http.py +10 -0
- synth_ai/_utils/prompts.py +10 -0
- synth_ai/_utils/task_app_state.py +12 -0
- synth_ai/_utils/user_config.py +10 -0
- synth_ai/api/models/supported.py +145 -7
- synth_ai/api/train/__init__.py +13 -1
- synth_ai/api/train/cli.py +30 -7
- synth_ai/api/train/config_finder.py +18 -11
- synth_ai/api/train/env_resolver.py +13 -10
- synth_ai/cli/__init__.py +66 -49
- synth_ai/cli/_modal_wrapper.py +9 -6
- synth_ai/cli/_typer_patch.py +0 -2
- synth_ai/cli/_validate_task_app.py +22 -4
- synth_ai/cli/legacy_root_backup.py +3 -1
- synth_ai/cli/lib/__init__.py +10 -0
- synth_ai/cli/lib/task_app_discovery.py +7 -0
- synth_ai/cli/lib/task_app_env.py +518 -0
- synth_ai/cli/recent.py +1 -0
- synth_ai/cli/setup.py +266 -0
- synth_ai/cli/task_app_deploy.py +16 -0
- synth_ai/cli/task_app_list.py +25 -0
- synth_ai/cli/task_app_modal_serve.py +16 -0
- synth_ai/cli/task_app_serve.py +18 -0
- synth_ai/cli/task_apps.py +392 -141
- synth_ai/cli/train.py +18 -0
- synth_ai/cli/tui.py +62 -0
- synth_ai/demos/__init__.py +10 -0
- synth_ai/demos/core/__init__.py +28 -1
- synth_ai/demos/crafter/__init__.py +1 -0
- synth_ai/demos/crafter/crafter_fft_4b.toml +55 -0
- synth_ai/demos/crafter/grpo_crafter_task_app.py +185 -0
- synth_ai/demos/crafter/rl_from_base_qwen4b.toml +74 -0
- synth_ai/demos/demo_registry.py +176 -0
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +1 -1
- synth_ai/demos/math/__init__.py +1 -0
- synth_ai/demos/math/_common.py +16 -0
- synth_ai/demos/math/app.py +38 -0
- synth_ai/demos/math/config.toml +76 -0
- synth_ai/demos/math/deploy_modal.py +54 -0
- synth_ai/demos/math/modal_task_app.py +702 -0
- synth_ai/demos/math/task_app_entry.py +51 -0
- synth_ai/environments/environment/core.py +7 -1
- synth_ai/environments/examples/bandit/engine.py +0 -1
- synth_ai/environments/examples/bandit/environment.py +0 -1
- synth_ai/environments/examples/crafter_classic/environment.py +1 -1
- synth_ai/environments/examples/verilog/engine.py +76 -10
- synth_ai/environments/examples/wordle/environment.py +0 -1
- synth_ai/evals/base.py +16 -5
- synth_ai/evals/client.py +1 -1
- synth_ai/inference/client.py +1 -1
- synth_ai/learning/client.py +1 -1
- synth_ai/learning/health.py +1 -1
- synth_ai/learning/jobs.py +1 -1
- synth_ai/learning/rl/client.py +1 -1
- synth_ai/learning/rl/env_keys.py +1 -1
- synth_ai/learning/rl/secrets.py +1 -1
- synth_ai/learning/sft/client.py +1 -1
- synth_ai/learning/sft/data.py +407 -4
- synth_ai/learning/validators.py +4 -1
- synth_ai/task/__init__.py +11 -1
- synth_ai/task/apps/__init__.py +5 -2
- synth_ai/task/config.py +259 -0
- synth_ai/task/contracts.py +15 -2
- synth_ai/task/rubrics/__init__.py +4 -2
- synth_ai/task/rubrics/loaders.py +27 -4
- synth_ai/task/rubrics/scoring.py +3 -0
- synth_ai/task/rubrics.py +219 -0
- synth_ai/task/trace_correlation_helpers.py +328 -0
- synth_ai/task/tracing_utils.py +14 -3
- synth_ai/task/validators.py +145 -2
- synth_ai/tracing_v3/config.py +15 -13
- synth_ai/tracing_v3/constants.py +21 -0
- synth_ai/tracing_v3/db_config.py +3 -1
- synth_ai/tracing_v3/decorators.py +10 -7
- synth_ai/tracing_v3/session_tracer.py +10 -0
- synth_ai/tracing_v3/turso/daemon.py +2 -2
- synth_ai/tracing_v3/turso/native_manager.py +108 -77
- synth_ai/tracing_v3/utils.py +1 -1
- synth_ai/tui/__init__.py +5 -0
- synth_ai/tui/__main__.py +13 -0
- synth_ai/tui/cli/__init__.py +1 -0
- synth_ai/tui/cli/query_experiments.py +164 -0
- synth_ai/tui/cli/query_experiments_v3.py +164 -0
- synth_ai/tui/dashboard.py +911 -0
- synth_ai/utils/__init__.py +101 -0
- synth_ai/utils/base_url.py +94 -0
- synth_ai/utils/cli.py +131 -0
- synth_ai/utils/env.py +287 -0
- synth_ai/utils/http.py +169 -0
- synth_ai/utils/modal.py +308 -0
- synth_ai/utils/process.py +212 -0
- synth_ai/utils/prompts.py +39 -0
- synth_ai/utils/sqld.py +122 -0
- synth_ai/utils/task_app_discovery.py +882 -0
- synth_ai/utils/task_app_env.py +186 -0
- synth_ai/utils/task_app_state.py +318 -0
- synth_ai/utils/user_config.py +137 -0
- synth_ai/v0/config/__init__.py +1 -5
- synth_ai/v0/config/base_url.py +1 -7
- synth_ai/v0/tracing/config.py +1 -1
- synth_ai/v0/tracing/decorators.py +1 -1
- synth_ai/v0/tracing/upload.py +1 -1
- synth_ai/v0/tracing_v1/config.py +1 -1
- synth_ai/v0/tracing_v1/decorators.py +1 -1
- synth_ai/v0/tracing_v1/upload.py +1 -1
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.16.dist-info}/METADATA +85 -31
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.16.dist-info}/RECORD +286 -135
- synth_ai/cli/man.py +0 -106
- synth_ai/compound/cais.py +0 -0
- synth_ai/core/experiment.py +0 -13
- synth_ai/core/system.py +0 -15
- synth_ai/demo_registry.py +0 -295
- synth_ai/handshake.py +0 -109
- synth_ai/http.py +0 -26
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.16.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.16.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.16.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.16.dist-info}/top_level.txt +0 -0
|
@@ -34,6 +34,7 @@ from synth_ai.task.contracts import (
|
|
|
34
34
|
from synth_ai.task.datasets import TaskDatasetRegistry, TaskDatasetSpec
|
|
35
35
|
from synth_ai.task.rubrics import load_rubric
|
|
36
36
|
from synth_ai.task.server import ProxyConfig, RubricBundle, TaskAppConfig
|
|
37
|
+
from synth_ai.task.validators import normalize_inference_url
|
|
37
38
|
from synth_ai.task.tracing_utils import (
|
|
38
39
|
build_tracer_factory,
|
|
39
40
|
resolve_sft_output_dir,
|
|
@@ -45,7 +46,36 @@ from synth_ai.tracing_v3.session_tracer import SessionTracer
|
|
|
45
46
|
logger = logging.getLogger(__name__)
|
|
46
47
|
|
|
47
48
|
_HERE = Path(__file__).resolve()
|
|
48
|
-
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _resolve_repo_root() -> Path:
|
|
52
|
+
"""Find synth-ai repo root, checking env var and parent traversal."""
|
|
53
|
+
candidates: list[Path] = []
|
|
54
|
+
env_root = os.getenv("SYNTH_AI_REPO_ROOT")
|
|
55
|
+
if env_root:
|
|
56
|
+
candidates.append(Path(env_root).expanduser())
|
|
57
|
+
|
|
58
|
+
# Try Modal mount point
|
|
59
|
+
candidates.append(Path("/opt/synth_ai_repo"))
|
|
60
|
+
|
|
61
|
+
# Traverse up from current file
|
|
62
|
+
current = _HERE
|
|
63
|
+
for _ in range(6):
|
|
64
|
+
current = current.parent
|
|
65
|
+
candidates.append(current)
|
|
66
|
+
if (current / "synth_ai").is_dir() and (current / "examples").is_dir():
|
|
67
|
+
return current
|
|
68
|
+
|
|
69
|
+
# Return first existing candidate
|
|
70
|
+
for candidate in candidates:
|
|
71
|
+
if candidate.is_dir() and (candidate / "synth_ai").exists():
|
|
72
|
+
return candidate
|
|
73
|
+
|
|
74
|
+
# Fallback to current parent structure (may not work in Modal)
|
|
75
|
+
return _HERE.parent.parent.parent.parent
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
REPO_ROOT = _resolve_repo_root()
|
|
49
79
|
|
|
50
80
|
DATASET_SPEC = TaskDatasetSpec(
|
|
51
81
|
id="verilog_eval_v2",
|
|
@@ -161,23 +191,6 @@ def _base_task_info(dataset: VerilogDataset) -> TaskInfo:
|
|
|
161
191
|
)
|
|
162
192
|
|
|
163
193
|
|
|
164
|
-
def _normalize_inference_url(url: str | None) -> str:
|
|
165
|
-
candidate = (url or DEFAULT_INFERENCE_URL).strip()
|
|
166
|
-
if not candidate:
|
|
167
|
-
candidate = DEFAULT_INFERENCE_URL
|
|
168
|
-
if candidate.endswith("/v1/chat/completions"):
|
|
169
|
-
return candidate
|
|
170
|
-
if candidate.endswith("/chat/completions"):
|
|
171
|
-
return candidate
|
|
172
|
-
if candidate.endswith("/v1"):
|
|
173
|
-
return f"{candidate.rstrip('/')}/chat/completions"
|
|
174
|
-
if candidate.endswith("/v1/"):
|
|
175
|
-
return f"{candidate.rstrip('/')}/chat/completions"
|
|
176
|
-
if candidate.endswith("/chat"):
|
|
177
|
-
return f"{candidate.rstrip('/')}/completions"
|
|
178
|
-
if candidate.endswith("/chat/"):
|
|
179
|
-
return f"{candidate.rstrip('/')}/completions"
|
|
180
|
-
return f"{candidate.rstrip('/')}/v1/chat/completions"
|
|
181
194
|
|
|
182
195
|
|
|
183
196
|
def _format_file_previews(files: dict[str, str]) -> str:
|
|
@@ -336,7 +349,7 @@ class VerilogLLMAgent:
|
|
|
336
349
|
max_tokens: int,
|
|
337
350
|
) -> None:
|
|
338
351
|
self.instructions = instructions.strip()
|
|
339
|
-
self.inference_url =
|
|
352
|
+
self.inference_url = normalize_inference_url(inference_url, default=DEFAULT_INFERENCE_URL)
|
|
340
353
|
self.model = model or DEFAULT_MODEL
|
|
341
354
|
self.temperature = temperature
|
|
342
355
|
self.max_tokens = max_tokens
|
|
@@ -349,7 +362,16 @@ class VerilogLLMAgent:
|
|
|
349
362
|
if not api_key:
|
|
350
363
|
raise RuntimeError("GROQ_API_KEY is not configured for Verilog inference.")
|
|
351
364
|
self.headers["Authorization"] = f"Bearer {api_key.strip()}"
|
|
352
|
-
|
|
365
|
+
# If target is Synth backend (any deployment), use SYNTH_API_KEY
|
|
366
|
+
elif any(pattern in lowered for pattern in [
|
|
367
|
+
"synth-backend", "synth.run", "agent-learning",
|
|
368
|
+
"localhost:8000", "127.0.0.1:8000"
|
|
369
|
+
]):
|
|
370
|
+
api_key = os.getenv("SYNTH_API_KEY")
|
|
371
|
+
if not api_key:
|
|
372
|
+
raise RuntimeError("SYNTH_API_KEY is not configured for Verilog inference with Synth backend.")
|
|
373
|
+
self.headers["Authorization"] = f"Bearer {api_key.strip()}"
|
|
374
|
+
elif "openai" in lowered or "api.openai.com" in lowered:
|
|
353
375
|
api_key = os.getenv("OPENAI_API_KEY")
|
|
354
376
|
if not api_key:
|
|
355
377
|
raise RuntimeError("OPENAI_API_KEY is not configured for Verilog inference.")
|
|
@@ -574,6 +596,21 @@ async def rollout_executor(
|
|
|
574
596
|
total_reward = 0.0
|
|
575
597
|
final_observation: dict[str, Any] | None = None
|
|
576
598
|
truncated_due_to_limit = False
|
|
599
|
+
|
|
600
|
+
# Log episode start
|
|
601
|
+
problem_id = getattr(instance, "problem_id", "unknown")
|
|
602
|
+
logger.info("=" * 80)
|
|
603
|
+
logger.info(f"[EPISODE START] run_id={request.run_id}")
|
|
604
|
+
logger.info(f" Problem ID: {problem_id}")
|
|
605
|
+
logger.info(f" Policy: {policy_id}")
|
|
606
|
+
logger.info(f" Model: {policy_model}")
|
|
607
|
+
logger.info(f" Max steps: {max_steps}")
|
|
608
|
+
logger.info(f" Temperature: {temperature}")
|
|
609
|
+
logger.info(f" Max tokens: {max_tokens}")
|
|
610
|
+
if instructions:
|
|
611
|
+
instructions_preview = instructions[:150] + "..." if len(instructions) > 150 else instructions
|
|
612
|
+
logger.info(f" Instructions: {instructions_preview}")
|
|
613
|
+
logger.info("=" * 80)
|
|
577
614
|
code_dirty = False
|
|
578
615
|
last_compile_success = False
|
|
579
616
|
simulate_since_last_compile = False
|
|
@@ -648,7 +685,7 @@ async def rollout_executor(
|
|
|
648
685
|
and not code_dirty
|
|
649
686
|
)
|
|
650
687
|
if skip_env_step:
|
|
651
|
-
reward_last =
|
|
688
|
+
reward_last = 0.0 # No reward for blocked operations
|
|
652
689
|
total_reward += reward_last
|
|
653
690
|
current_observation = dict(current_observation)
|
|
654
691
|
current_observation["reward_last"] = reward_last
|
|
@@ -669,6 +706,23 @@ async def rollout_executor(
|
|
|
669
706
|
or current_observation.get("task_completed")
|
|
670
707
|
)
|
|
671
708
|
truncated_flag = bool(current_observation.get("truncated"))
|
|
709
|
+
|
|
710
|
+
# Log what the environment returned
|
|
711
|
+
print(f"\n{'='*80}")
|
|
712
|
+
print(f"[STEP {step_index}] TOOL CALL:")
|
|
713
|
+
print(f" Tool: {env_call.tool}")
|
|
714
|
+
print(f" Args: {env_call.args}")
|
|
715
|
+
print(f"\n[STEP {step_index}] ENVIRONMENT RESPONSE:")
|
|
716
|
+
print(f" Reward: {reward_last:.4f} (cumulative: {total_reward:.4f})")
|
|
717
|
+
print(f" Task completed: {step_observation.get('task_completed')}")
|
|
718
|
+
print(f" Done: {done_flag} | Truncated: {truncated_flag}")
|
|
719
|
+
if 'compile_status' in step_observation and step_observation.get('compile_status'):
|
|
720
|
+
print(f" Compile status:\n{step_observation.get('compile_status')}")
|
|
721
|
+
if 'simulate_status' in step_observation and step_observation.get('simulate_status'):
|
|
722
|
+
print(f" Simulate status:\n{step_observation.get('simulate_status')}")
|
|
723
|
+
if 'files' in step_observation:
|
|
724
|
+
print(f" Files: {list(step_observation.get('files', {}).keys())}")
|
|
725
|
+
print(f"{'='*80}\n")
|
|
672
726
|
|
|
673
727
|
executed_tool_name = str(primary_call["tool"])
|
|
674
728
|
normalized_executed_tool = executed_tool_name.strip().lower()
|
|
@@ -698,10 +752,40 @@ async def rollout_executor(
|
|
|
698
752
|
{"tool_name": call["tool"], "arguments": call["args"]}
|
|
699
753
|
for call in tool_calls
|
|
700
754
|
]
|
|
755
|
+
|
|
756
|
+
# Print tool calls for debugging
|
|
757
|
+
logger.info(f"[STEP {step_index}] Tool calls executed:")
|
|
758
|
+
for call in tool_calls:
|
|
759
|
+
tool_name = call["tool"]
|
|
760
|
+
args = call["args"]
|
|
761
|
+
# Truncate long arguments for readability
|
|
762
|
+
if "code" in args or "content" in args:
|
|
763
|
+
args_preview = {k: (v[:100] + "..." if isinstance(v, str) and len(v) > 100 else v)
|
|
764
|
+
for k, v in args.items()}
|
|
765
|
+
else:
|
|
766
|
+
args_preview = args
|
|
767
|
+
logger.info(f" └─ {tool_name}({args_preview})")
|
|
768
|
+
|
|
769
|
+
# Log reward details for debugging
|
|
770
|
+
logger.info(f"[STEP {step_index}] Reward details:")
|
|
771
|
+
logger.info(f" └─ reward_last: {reward_last:.4f}")
|
|
772
|
+
logger.info(f" └─ total_reward: {total_reward:.4f}")
|
|
773
|
+
logger.info(f" └─ skip_env_step: {skip_env_step}")
|
|
774
|
+
if not skip_env_step:
|
|
775
|
+
logger.info(f" └─ obs.task_completed: {current_observation.get('task_completed', False)}")
|
|
776
|
+
logger.info(f" └─ obs.compile_status: {current_observation.get('compile_status', 'N/A')}")
|
|
777
|
+
logger.info(f" └─ obs.simulate_status: {current_observation.get('simulate_status', 'N/A')}")
|
|
778
|
+
logger.info(f" └─ obs.terminated: {current_observation.get('terminated', False)}")
|
|
779
|
+
else:
|
|
780
|
+
logger.info(f" └─ (blocked operation - no env step)")
|
|
781
|
+
|
|
701
782
|
step_info = {
|
|
702
783
|
"assistant_message": assistant_text,
|
|
703
784
|
"model_response": raw_response,
|
|
704
785
|
"llm_request": request_payload,
|
|
786
|
+
"meta": {
|
|
787
|
+
"inference_url": policy_config.get("inference_url") or resolved_inference, # CRITICAL: Required by RL trainer for trace extraction (must have ?cid=...)
|
|
788
|
+
},
|
|
705
789
|
}
|
|
706
790
|
if override_info:
|
|
707
791
|
step_info["auto_override"] = override_info
|
|
@@ -756,6 +840,9 @@ async def rollout_executor(
|
|
|
756
840
|
"model_response": raw_response,
|
|
757
841
|
"llm_request": request_payload,
|
|
758
842
|
"error": error_text,
|
|
843
|
+
"meta": {
|
|
844
|
+
"inference_url": policy_config.get("inference_url") or resolved_inference, # CRITICAL: Required by RL trainer
|
|
845
|
+
},
|
|
759
846
|
}
|
|
760
847
|
steps.append(
|
|
761
848
|
RolloutStep(
|
|
@@ -797,6 +884,25 @@ async def rollout_executor(
|
|
|
797
884
|
},
|
|
798
885
|
)
|
|
799
886
|
|
|
887
|
+
# Extract inference_url from policy config (REQUIRED for RL trace correlation)
|
|
888
|
+
# The trainer injects this with ?cid=trace_xxxxx parameter for trace linking
|
|
889
|
+
final_inference_url = policy_config.get("inference_url")
|
|
890
|
+
if not isinstance(final_inference_url, str) or not final_inference_url.strip():
|
|
891
|
+
# Fallback to agent's inference_url if not in policy config
|
|
892
|
+
final_inference_url = agent.inference_url
|
|
893
|
+
logger.warning(
|
|
894
|
+
"VERILOG_ROLLOUT: inference_url not found in policy_config, using agent.inference_url run_id=%s url=%s",
|
|
895
|
+
request.run_id,
|
|
896
|
+
final_inference_url,
|
|
897
|
+
)
|
|
898
|
+
else:
|
|
899
|
+
logger.info(
|
|
900
|
+
"VERILOG_ROLLOUT: using inference_url from policy_config run_id=%s url=%s has_cid=%s",
|
|
901
|
+
request.run_id,
|
|
902
|
+
final_inference_url,
|
|
903
|
+
"?cid=" in final_inference_url,
|
|
904
|
+
)
|
|
905
|
+
|
|
800
906
|
trajectory = RolloutTrajectory(
|
|
801
907
|
env_id=str(env_id),
|
|
802
908
|
policy_id=str(policy_id),
|
|
@@ -810,11 +916,11 @@ async def rollout_executor(
|
|
|
810
916
|
"total_reward": final_total_reward,
|
|
811
917
|
"task_completed": bool(final_observation.get("task_completed")),
|
|
812
918
|
"policy_model": policy_model,
|
|
813
|
-
"inference_url":
|
|
919
|
+
"inference_url": final_inference_url,
|
|
814
920
|
},
|
|
815
921
|
},
|
|
816
922
|
length=len(steps),
|
|
817
|
-
inference_url=
|
|
923
|
+
inference_url=final_inference_url, # CRITICAL: Must contain ?cid=... for trace correlation
|
|
818
924
|
decision_samples=None,
|
|
819
925
|
)
|
|
820
926
|
|
|
@@ -836,6 +942,133 @@ async def rollout_executor(
|
|
|
836
942
|
}
|
|
837
943
|
}
|
|
838
944
|
|
|
945
|
+
# Build pipeline_metadata (required for RL training)
|
|
946
|
+
pipeline_metadata = {
|
|
947
|
+
"reward_score": final_total_reward,
|
|
948
|
+
"policy_id": policy_id,
|
|
949
|
+
"inference_url": final_inference_url, # CRITICAL: Must be at top level for RL trainer (expects ?cid=...)
|
|
950
|
+
"inference": {
|
|
951
|
+
"provider": "groq",
|
|
952
|
+
"model": policy_model,
|
|
953
|
+
"url": final_inference_url, # Use final_inference_url (has ?cid=...)
|
|
954
|
+
},
|
|
955
|
+
"env_name": env_id,
|
|
956
|
+
"task_id": getattr(instance, "problem_id", None),
|
|
957
|
+
"task_split": getattr(instance, "split", "val"),
|
|
958
|
+
}
|
|
959
|
+
|
|
960
|
+
# Log episode summary with reward breakdown
|
|
961
|
+
compile_status = final_observation.get("compile_status", "N/A")
|
|
962
|
+
simulate_status = final_observation.get("simulate_status", "N/A")
|
|
963
|
+
task_completed = bool(final_observation.get("task_completed", False))
|
|
964
|
+
|
|
965
|
+
logger.info("=" * 80)
|
|
966
|
+
logger.info(f"[EPISODE COMPLETE] run_id={request.run_id}")
|
|
967
|
+
logger.info(f" Steps taken: {len(steps)}")
|
|
968
|
+
logger.info(f" Total reward: {final_total_reward:.3f}")
|
|
969
|
+
logger.info(f" Task completed: {task_completed}")
|
|
970
|
+
logger.info(f" Compile status: {compile_status}")
|
|
971
|
+
logger.info(f" Simulate status: {simulate_status}")
|
|
972
|
+
logger.info(f" Done/Truncated: {final_done}/{final_truncated}")
|
|
973
|
+
logger.info(f" Problem ID: {getattr(instance, 'problem_id', 'N/A')}")
|
|
974
|
+
|
|
975
|
+
# DEBUG: Log each step's reward for RL debugging
|
|
976
|
+
print(f"\n[REWARD DEBUG] Step-by-step breakdown:")
|
|
977
|
+
for idx, step in enumerate(steps):
|
|
978
|
+
print(f" Step {idx}: reward={step.reward:.4f} tool_calls={[tc.get('tool_name') for tc in step.tool_calls]}")
|
|
979
|
+
print(f"[REWARD DEBUG] Final observation keys: {list(final_observation.keys())}")
|
|
980
|
+
print(f"[REWARD DEBUG] Final obs total_reward: {final_observation.get('total_reward')}")
|
|
981
|
+
print(f"[REWARD DEBUG] Metrics outcome_score: {metrics.outcome_score}")
|
|
982
|
+
print(f"[REWARD DEBUG] Metrics mean_return: {metrics.mean_return}")
|
|
983
|
+
|
|
984
|
+
# Reward breakdown for debugging
|
|
985
|
+
logger.info("\n[REWARD BREAKDOWN]")
|
|
986
|
+
compile_count = sum(1 for s in steps if any(tc.get("tool_name") == "compile" for tc in s.tool_calls))
|
|
987
|
+
simulate_count = sum(1 for s in steps if any(tc.get("tool_name") == "simulate" for tc in s.tool_calls))
|
|
988
|
+
submit_count = sum(1 for s in steps if any(tc.get("tool_name") == "submit" for tc in s.tool_calls))
|
|
989
|
+
write_count = sum(1 for s in steps if any(tc.get("tool_name") == "write_file" for tc in s.tool_calls))
|
|
990
|
+
|
|
991
|
+
logger.info(f" Tool usage: write_file={write_count}, compile={compile_count}, simulate={simulate_count}, submit={submit_count}")
|
|
992
|
+
|
|
993
|
+
# Show per-step rewards
|
|
994
|
+
step_rewards = [s.reward for s in steps]
|
|
995
|
+
nonzero_rewards = [r for r in step_rewards if r != 0.0]
|
|
996
|
+
logger.info(f" Step rewards: {step_rewards}")
|
|
997
|
+
if nonzero_rewards:
|
|
998
|
+
logger.info(f" Non-zero rewards: {nonzero_rewards}")
|
|
999
|
+
else:
|
|
1000
|
+
logger.info(f" ⚠️ ALL REWARDS ZERO! Possible reasons:")
|
|
1001
|
+
logger.info(f" - No successful compiles (compile reward = 0.01)")
|
|
1002
|
+
logger.info(f" - No successful simulations (simulate reward = 0.1)")
|
|
1003
|
+
logger.info(f" - No successful submits (submit reward = 1.0)")
|
|
1004
|
+
logger.info(f" - Check if task_completed={task_completed}")
|
|
1005
|
+
logger.info(f" - Check compile_status='{compile_status}'")
|
|
1006
|
+
logger.info(f" - Check simulate_status='{simulate_status}'")
|
|
1007
|
+
logger.info("=" * 80)
|
|
1008
|
+
|
|
1009
|
+
# Log for debugging RL training
|
|
1010
|
+
logger.info(
|
|
1011
|
+
"VERILOG_ROLLOUT: pipeline_metadata run_id=%s reward=%.3f inference_url=%s",
|
|
1012
|
+
request.run_id,
|
|
1013
|
+
final_total_reward,
|
|
1014
|
+
final_inference_url,
|
|
1015
|
+
)
|
|
1016
|
+
|
|
1017
|
+
# DEBUG: Log what we're returning to the RL trainer
|
|
1018
|
+
print(f"\n[RETURN DEBUG] Trajectory structure being returned:")
|
|
1019
|
+
print(f" trajectory.steps count: {len(steps)}")
|
|
1020
|
+
print(f" trajectory.final.reward: {trajectory.final.get('reward') if trajectory.final else 'None'}")
|
|
1021
|
+
print(f" trajectory.length: {trajectory.length}")
|
|
1022
|
+
print(f" metrics.outcome_score: {metrics.outcome_score}")
|
|
1023
|
+
print(f" metrics.mean_return: {metrics.mean_return}")
|
|
1024
|
+
print(f" metrics.episode_returns: {metrics.episode_returns}")
|
|
1025
|
+
print(f" pipeline_metadata.reward_score: {pipeline_metadata.get('reward_score')}")
|
|
1026
|
+
|
|
1027
|
+
# ASSERTIONS: Validate RL-required fields before returning
|
|
1028
|
+
# These catch structural issues early (before they reach the backend trainer)
|
|
1029
|
+
# Only enforce for RL mode, not EVAL mode
|
|
1030
|
+
is_rl_mode = hasattr(request, 'mode') and str(getattr(request, 'mode', '')).lower() == 'rl'
|
|
1031
|
+
|
|
1032
|
+
assert isinstance(pipeline_metadata, dict), (
|
|
1033
|
+
f"VERILOG_ROLLOUT_VALIDATION: pipeline_metadata must be dict, got {type(pipeline_metadata).__name__}"
|
|
1034
|
+
)
|
|
1035
|
+
assert "inference_url" in pipeline_metadata, (
|
|
1036
|
+
f"VERILOG_ROLLOUT_VALIDATION: pipeline_metadata missing 'inference_url' (REQUIRED for RL training)"
|
|
1037
|
+
)
|
|
1038
|
+
assert isinstance(pipeline_metadata["inference_url"], str), (
|
|
1039
|
+
f"VERILOG_ROLLOUT_VALIDATION: pipeline_metadata['inference_url'] must be string, got {type(pipeline_metadata['inference_url']).__name__}"
|
|
1040
|
+
)
|
|
1041
|
+
# Only require ?cid= for RL mode (not needed for EVAL)
|
|
1042
|
+
if is_rl_mode:
|
|
1043
|
+
assert "?cid=" in pipeline_metadata["inference_url"], (
|
|
1044
|
+
f"VERILOG_ROLLOUT_VALIDATION: pipeline_metadata['inference_url'] must contain '?cid=' for trace correlation in RL mode. "
|
|
1045
|
+
f"Got: {pipeline_metadata['inference_url'][:100]}"
|
|
1046
|
+
)
|
|
1047
|
+
|
|
1048
|
+
# Validate each step has meta.inference_url (backend expects this nested structure)
|
|
1049
|
+
for step_idx, step in enumerate(steps):
|
|
1050
|
+
step_dict = step if isinstance(step, dict) else (step.model_dump() if hasattr(step, "model_dump") else {})
|
|
1051
|
+
step_info = step_dict.get("info", {})
|
|
1052
|
+
assert isinstance(step_info, dict), (
|
|
1053
|
+
f"VERILOG_ROLLOUT_VALIDATION: step[{step_idx}].info must be dict, got {type(step_info).__name__}"
|
|
1054
|
+
)
|
|
1055
|
+
step_meta = step_info.get("meta", {})
|
|
1056
|
+
assert isinstance(step_meta, dict), (
|
|
1057
|
+
f"VERILOG_ROLLOUT_VALIDATION: step[{step_idx}].info.meta must be dict, got {type(step_meta).__name__}"
|
|
1058
|
+
)
|
|
1059
|
+
assert "inference_url" in step_meta, (
|
|
1060
|
+
f"VERILOG_ROLLOUT_VALIDATION: step[{step_idx}].info.meta missing 'inference_url' (REQUIRED for RL training)"
|
|
1061
|
+
)
|
|
1062
|
+
assert isinstance(step_meta["inference_url"], str), (
|
|
1063
|
+
f"VERILOG_ROLLOUT_VALIDATION: step[{step_idx}].info.meta['inference_url'] must be string, got {type(step_meta['inference_url']).__name__}"
|
|
1064
|
+
)
|
|
1065
|
+
|
|
1066
|
+
logger.info(
|
|
1067
|
+
"VERILOG_ROLLOUT_VALIDATION: ✓ All RL-required fields present run_id=%s steps=%d",
|
|
1068
|
+
request.run_id,
|
|
1069
|
+
len(steps),
|
|
1070
|
+
)
|
|
1071
|
+
|
|
839
1072
|
return RolloutResponse(
|
|
840
1073
|
run_id=request.run_id,
|
|
841
1074
|
trajectories=[trajectory],
|
|
@@ -844,6 +1077,7 @@ async def rollout_executor(
|
|
|
844
1077
|
aborted=False,
|
|
845
1078
|
ops_executed=len(steps),
|
|
846
1079
|
trace=trace_payload,
|
|
1080
|
+
pipeline_metadata=pipeline_metadata,
|
|
847
1081
|
)
|
|
848
1082
|
|
|
849
1083
|
|
|
@@ -917,6 +1151,7 @@ register_task_app(
|
|
|
917
1151
|
"python-dotenv>=1.0.1",
|
|
918
1152
|
"datasets>=2.10.0",
|
|
919
1153
|
),
|
|
1154
|
+
apt_packages=("iverilog",), # Icarus Verilog compiler and simulator (provides iverilog and vvp)
|
|
920
1155
|
extra_local_dirs=(
|
|
921
1156
|
(str(REPO_ROOT), "/opt/synth_ai_repo"),
|
|
922
1157
|
(str(REPO_ROOT / "synth_ai"), "/opt/synth_ai_repo/synth_ai"),
|
examples/vlm/README.md
CHANGED
|
@@ -21,8 +21,8 @@ plumbing with lightweight utilities for dataset curation and training.
|
|
|
21
21
|
3. **Export multimodal SFT rows**
|
|
22
22
|
```
|
|
23
23
|
uv run python examples/warming_up_to_rl/export_trace_sft.py \
|
|
24
|
-
|
|
25
|
-
--output examples/vlm/output/
|
|
24
|
+
--db traces/v3/task_app_traces_<timestamp>.db \
|
|
25
|
+
--output examples/vlm/output/crafter_sft_full.jsonl
|
|
26
26
|
```
|
|
27
27
|
The exporter now emits `metadata.has_image`, `metadata.user_has_image`, and
|
|
28
28
|
`metadata.assistant_has_image` flags per turn.
|
|
@@ -30,7 +30,7 @@ plumbing with lightweight utilities for dataset curation and training.
|
|
|
30
30
|
4. **Filter to image-rich turns**
|
|
31
31
|
```
|
|
32
32
|
uv run python examples/vlm/filter_image_rows.py \
|
|
33
|
-
--input examples/vlm/output/
|
|
33
|
+
--input examples/vlm/output/crafter_sft_full.jsonl \
|
|
34
34
|
--output examples/vlm/output/crafter_vlm_dataset.jsonl
|
|
35
35
|
```
|
|
36
36
|
|
|
@@ -24,6 +24,7 @@ import asyncio
|
|
|
24
24
|
import base64
|
|
25
25
|
import json
|
|
26
26
|
import os
|
|
27
|
+
from contextlib import suppress
|
|
27
28
|
from pathlib import Path
|
|
28
29
|
from typing import Any
|
|
29
30
|
from uuid import uuid4
|
|
@@ -62,7 +63,7 @@ class EpisodeResult:
|
|
|
62
63
|
if unlocked:
|
|
63
64
|
self.achievements.add(str(name))
|
|
64
65
|
reward = obs.get("reward_last_step")
|
|
65
|
-
if isinstance(reward,
|
|
66
|
+
if isinstance(reward, int | float):
|
|
66
67
|
self.total_reward += float(reward)
|
|
67
68
|
|
|
68
69
|
|
|
@@ -107,11 +108,8 @@ def _decode_and_save_image(observation: dict[str, Any], path: Path) -> None:
|
|
|
107
108
|
if not isinstance(base64_data, str) or not base64_data:
|
|
108
109
|
return
|
|
109
110
|
path.parent.mkdir(parents=True, exist_ok=True)
|
|
110
|
-
|
|
111
|
+
with suppress(Exception):
|
|
111
112
|
path.write_bytes(base64.b64decode(base64_data))
|
|
112
|
-
except Exception:
|
|
113
|
-
# Best-effort; corrupted frames should not halt rollout
|
|
114
|
-
pass
|
|
115
113
|
|
|
116
114
|
|
|
117
115
|
def _normalise_openai_request(payload: dict[str, Any], model: str, temperature: float) -> dict[str, Any]:
|
|
@@ -8,7 +8,7 @@ output now that each record's metadata includes `has_image`, `user_has_image`, a
|
|
|
8
8
|
|
|
9
9
|
Usage:
|
|
10
10
|
uv run python examples/vlm/filter_image_rows.py \
|
|
11
|
-
--input examples/sft/ft_data/
|
|
11
|
+
--input examples/sft/ft_data/crafter_sft.jsonl \
|
|
12
12
|
--output examples/vlm/output/crafter_vlm_dataset.jsonl
|
|
13
13
|
"""
|
|
14
14
|
|
|
@@ -224,7 +224,7 @@ async def _run_episode(
|
|
|
224
224
|
if unlocked:
|
|
225
225
|
achievements.add(str(name))
|
|
226
226
|
reward = obs.get("reward_last_step")
|
|
227
|
-
if isinstance(reward,
|
|
227
|
+
if isinstance(reward, int | float):
|
|
228
228
|
total_reward += float(reward)
|
|
229
229
|
|
|
230
230
|
_save_observation_frame(env_response, frames_dir / f"step_{step_idx + 1:03d}.png")
|
|
@@ -263,7 +263,7 @@ def _summarise(results: list[EpisodeResult]) -> dict[str, Any]:
|
|
|
263
263
|
"mean_steps": round(mean_steps, 2),
|
|
264
264
|
"mean_achievements": round(mean_achievements, 2),
|
|
265
265
|
"total_tool_calls": sum(r.tool_calls for r in mode_results),
|
|
266
|
-
"achievements":
|
|
266
|
+
"achievements": dict(sorted(achievement_counts.items())),
|
|
267
267
|
}
|
|
268
268
|
return summary
|
|
269
269
|
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Iterable, Sequence
|
|
4
|
+
|
|
5
|
+
from synth_ai.task import (
|
|
6
|
+
RolloutEnvSpec,
|
|
7
|
+
RolloutPolicySpec,
|
|
8
|
+
RolloutRecordConfig,
|
|
9
|
+
RolloutRequest,
|
|
10
|
+
RolloutSafetyConfig,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
DEFAULT_POLICY_NAME = "crafter-react"
|
|
14
|
+
DEFAULT_ENV_NAME = "crafter"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def parse_ops(spec: str | None) -> list[str] | None:
    """Split a comma-separated ops string into its non-empty entries.

    Args:
        spec: Comma-separated operation names, or ``None`` when no ops
            were specified.

    Returns:
        ``None`` when *spec* is ``None``; otherwise the entries with
        surrounding whitespace stripped and blank entries dropped.

    Raises:
        ValueError: If *spec* is given but holds no non-blank entry.
    """
    if spec is None:
        return None

    # Strip first, then filter, so whitespace-only entries are discarded.
    trimmed = (piece.strip() for piece in spec.split(","))
    entries = [piece for piece in trimmed if piece]
    if entries:
        return entries
    raise ValueError("Ops must contain at least one entry")
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def ops_from_pairs(max_llm_calls: int, *, cap: int | None = None) -> list[str]:
    """Return alternating agent/env ops for the requested number of LLM calls.

    Args:
        max_llm_calls: Requested number of agent/env pairs. Falsy values
            (``0``, ``None``) and values below 1 are raised to 1.
        cap: Optional upper bound on the number of pairs. It is applied
            after the floor of 1, so a cap below 1 yields an empty list.

    Returns:
        A new list of the form ``["agent", "env", "agent", "env", ...]``
        containing one ``"agent"``/``"env"`` pair per LLM call.
    """
    pairs = max(1, int(max_llm_calls or 0))
    if cap is not None:
        pairs = min(pairs, cap)
    # List repetition yields [] for a non-positive count, exactly like
    # looping over range(pairs) and extending would.
    return ["agent", "env"] * pairs
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def build_rollout_request(
    *,
    seed: int,
    run_id: str,
    model: str,
    inference_url: str,
    ops: Sequence[str] | Iterable[str],
    inference_api_key: str | None = None,
    extra_headers: dict[str, str] | None = None,
    trace_format: str = "compact",
    return_trace: bool = False,
    policy_name: str = DEFAULT_POLICY_NAME,
    env_name: str = DEFAULT_ENV_NAME,
    max_policy_tokens: int | None = None,
    record_trajectories: bool = True,
) -> RolloutRequest:
    """Assemble the RolloutRequest shared by the local rollout utilities.

    The policy config always carries ``model`` and ``inference_url``.
    ``api_key``, ``extra_headers`` and the two token limits
    (``max_completion_tokens`` and ``max_tokens``, both set from
    *max_policy_tokens*) are added only when the matching argument is
    supplied. The environment is created with an empty config and the
    request uses ``on_done="reset"`` with a default safety config.
    """
    policy_config: dict[str, object] = {"model": model, "inference_url": inference_url}
    if inference_api_key is not None:
        policy_config["api_key"] = inference_api_key
    if extra_headers:
        policy_config["extra_headers"] = extra_headers
    if max_policy_tokens is not None:
        # Both spellings of the token limit receive the same value.
        policy_config.update(
            max_completion_tokens=max_policy_tokens,
            max_tokens=max_policy_tokens,
        )

    record_cfg = RolloutRecordConfig(
        trajectories=record_trajectories,
        trace_format=trace_format,
        return_trace=return_trace,
    )
    env_spec = RolloutEnvSpec(env_name=env_name, seed=seed, config={})
    policy_spec = RolloutPolicySpec(policy_name=policy_name, config=policy_config)

    return RolloutRequest(
        run_id=run_id,
        env=env_spec,
        policy=policy_spec,
        ops=list(ops),  # materialise so any iterable is accepted
        record=record_cfg,
        on_done="reset",
        safety=RolloutSafetyConfig(),
    )
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
__all__ = [
|
|
87
|
+
"DEFAULT_POLICY_NAME",
|
|
88
|
+
"DEFAULT_ENV_NAME",
|
|
89
|
+
"build_rollout_request",
|
|
90
|
+
"ops_from_pairs",
|
|
91
|
+
"parse_ops",
|
|
92
|
+
]
|
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
# Eval config for Synth Modal inference Qwen/Qwen3-4B via task app rollout
|
|
2
2
|
|
|
3
|
+
type = "rl"
|
|
4
|
+
|
|
3
5
|
# Required
|
|
4
6
|
task_app_url = "https://synth-laboratories--grpo-crafter-task-app-final-warming--ceb5b2.modal.run"
|
|
5
7
|
model = "Qwen/Qwen3-4B"
|
|
@@ -20,4 +22,3 @@ concurrency = 10
|
|
|
20
22
|
# fetch the vLLM base from the task app /info to use as inference_url.
|
|
21
23
|
# - Ensure the task app mounts the openai-api-key secret if your vLLM gateway
|
|
22
24
|
# requires a bearer token (OPENAI_API_KEY). Otherwise it will call unauthenticated.
|
|
23
|
-
|
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
# RL training starting from base Qwen/Qwen3-4B (TOML-only model selection)
|
|
2
2
|
|
|
3
|
+
type = "rl"
|
|
4
|
+
|
|
3
5
|
[algorithm]
|
|
4
6
|
type = "online"
|
|
5
7
|
method = "policy_gradient"
|
|
6
8
|
variety = "gspo"
|
|
7
9
|
|
|
8
|
-
|
|
9
10
|
[services]
|
|
10
11
|
task_url = "https://synth-laboratories--grpo-crafter-task-app-final-warming--ceb5b2.modal.run"
|
|
11
12
|
|