synth-ai 0.2.9.dev5__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/__init__.py +16 -0
- examples/crafter_debug_render.py +23 -17
- examples/dev/qwen3_32b_qlora_4xh100.toml +40 -0
- examples/multi_step/crafter_rl_lora.md +29 -0
- examples/qwen_coder/README.md +102 -0
- examples/qwen_coder/_shared.py +113 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +61 -0
- examples/qwen_coder/configs/coder_lora_4b.toml +57 -0
- examples/qwen_coder/configs/coder_lora_small.toml +58 -0
- examples/qwen_coder/generate_dataset.py +98 -0
- examples/qwen_coder/infer_ft_smoke.py +65 -0
- examples/qwen_coder/infer_prod_proxy.py +73 -0
- examples/qwen_coder/infer_via_synth.py +87 -0
- examples/qwen_coder/scripts/infer_coder.sh +19 -0
- examples/qwen_coder/scripts/train_coder_30b.sh +22 -0
- examples/qwen_coder/sft_full_17b.py +103 -0
- examples/qwen_coder/sft_lora_30b.py +110 -0
- examples/qwen_coder/subset_jsonl.py +39 -0
- examples/qwen_coder/todos.md +38 -0
- examples/qwen_coder/validate_jsonl.py +60 -0
- examples/rl/configs/eval_base_qwen.toml +1 -1
- examples/rl/configs/rl_from_base_qwen17.toml +1 -1
- examples/rl/download_dataset.py +26 -10
- examples/rl/run_eval.py +53 -52
- examples/rl/run_rl_and_save.py +29 -12
- examples/rl/task_app/math_single_step.py +180 -41
- examples/rl/task_app/math_task_app.py +14 -6
- examples/sft/README.md +139 -0
- examples/sft/configs/crafter_fft_qwen0p6b.toml +44 -0
- examples/sft/configs/crafter_lora_qwen0p6b.toml +45 -0
- examples/sft/evaluate.py +117 -0
- examples/sft/export_dataset.py +117 -0
- examples/sft/generate_traces.py +162 -0
- examples/swe/__init__.py +12 -0
- examples/swe/task_app/README.md +105 -0
- examples/swe/task_app/__init__.py +2 -0
- examples/swe/task_app/grpo_swe_mini.py +571 -0
- examples/swe/task_app/grpo_swe_mini_task_app.py +136 -0
- examples/swe/task_app/hosted/README.md +173 -0
- examples/swe/task_app/hosted/__init__.py +5 -0
- examples/swe/task_app/hosted/branching.py +143 -0
- examples/swe/task_app/hosted/environment_routes.py +1289 -0
- examples/swe/task_app/hosted/envs/__init__.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/__init__.py +6 -0
- examples/swe/task_app/hosted/envs/crafter/app.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/environment.py +522 -0
- examples/swe/task_app/hosted/envs/crafter/policy.py +478 -0
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +108 -0
- examples/swe/task_app/hosted/envs/crafter/shared.py +305 -0
- examples/swe/task_app/hosted/envs/crafter/tools.py +47 -0
- examples/swe/task_app/hosted/envs/mini_swe/__init__.py +8 -0
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +1164 -0
- examples/swe/task_app/hosted/envs/mini_swe/policy.py +355 -0
- examples/swe/task_app/hosted/envs/mini_swe/shared.py +83 -0
- examples/swe/task_app/hosted/envs/mini_swe/tools.py +96 -0
- examples/swe/task_app/hosted/hosted_app.py +204 -0
- examples/swe/task_app/hosted/inference/__init__.py +5 -0
- examples/swe/task_app/hosted/inference/openai_client.py +618 -0
- examples/swe/task_app/hosted/main.py +100 -0
- examples/swe/task_app/hosted/policy_routes.py +1079 -0
- examples/swe/task_app/hosted/registry.py +195 -0
- examples/swe/task_app/hosted/rollout.py +1869 -0
- examples/swe/task_app/hosted/storage/__init__.py +5 -0
- examples/swe/task_app/hosted/storage/volume.py +211 -0
- examples/swe/task_app/hosted/test_agents.py +161 -0
- examples/swe/task_app/hosted/test_service.py +137 -0
- examples/swe/task_app/hosted/utils.py +62 -0
- examples/vlm/PROPOSAL.md +53 -0
- examples/vlm/README.md +68 -0
- examples/vlm/configs/crafter_vlm_gpt4o.toml +44 -0
- examples/vlm/crafter_image_only_agent.py +207 -0
- examples/vlm/crafter_openai_vlm_agent.py +277 -0
- examples/vlm/filter_image_rows.py +63 -0
- examples/vlm/run_crafter_vlm_benchmark.py +316 -0
- examples/warming_up_to_rl/analyze_trace_db.py +12 -10
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +11 -1
- examples/warming_up_to_rl/export_trace_sft.py +218 -36
- examples/warming_up_to_rl/groq_test.py +15 -8
- examples/warming_up_to_rl/manage_secrets.py +29 -25
- examples/warming_up_to_rl/readme.md +9 -2
- examples/warming_up_to_rl/run_eval.py +137 -61
- examples/warming_up_to_rl/run_fft_and_save.py +131 -60
- examples/warming_up_to_rl/run_local_rollout.py +88 -39
- examples/warming_up_to_rl/run_local_rollout_modal.py +114 -28
- examples/warming_up_to_rl/run_local_rollout_parallel.py +81 -20
- examples/warming_up_to_rl/run_local_rollout_traced.py +126 -23
- examples/warming_up_to_rl/run_rl_and_save.py +35 -12
- examples/warming_up_to_rl/run_rollout_remote.py +44 -19
- examples/warming_up_to_rl/task_app/README.md +6 -2
- examples/warming_up_to_rl/task_app/grpo_crafter.py +319 -57
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +11 -30
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +9 -11
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +137 -182
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +150 -57
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +105 -69
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +19 -7
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +45 -42
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +47 -45
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +198 -92
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +0 -2
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +361 -263
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +21 -23
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +394 -274
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +56 -62
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +6 -15
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +4 -3
- synth_ai/__init__.py +1 -0
- synth_ai/api/models/supported.py +376 -0
- synth_ai/api/train/builders.py +157 -26
- synth_ai/api/train/cli.py +213 -57
- synth_ai/api/train/config_finder.py +65 -5
- synth_ai/api/train/env_resolver.py +33 -15
- synth_ai/api/train/pollers.py +13 -4
- synth_ai/api/train/supported_algos.py +139 -0
- synth_ai/api/train/task_app.py +5 -3
- synth_ai/api/train/utils.py +33 -48
- synth_ai/cli/__init__.py +19 -4
- synth_ai/cli/_modal_wrapper.py +28 -0
- synth_ai/cli/_typer_patch.py +49 -0
- synth_ai/cli/balance.py +2 -3
- synth_ai/cli/calc.py +1 -1
- synth_ai/cli/demo.py +21 -6
- synth_ai/cli/recent.py +2 -2
- synth_ai/cli/rl_demo.py +77 -17
- synth_ai/cli/root.py +116 -39
- synth_ai/cli/status.py +2 -2
- synth_ai/cli/task_apps.py +1699 -259
- synth_ai/cli/traces.py +7 -4
- synth_ai/cli/turso.py +73 -0
- synth_ai/cli/watch.py +12 -18
- synth_ai/core/experiment.py +0 -2
- synth_ai/demo_registry.py +68 -31
- synth_ai/demos/core/cli.py +516 -194
- synth_ai/demos/demo_task_apps/__init__.py +3 -3
- synth_ai/demos/demo_task_apps/core.py +64 -28
- synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +2 -3
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +37 -30
- synth_ai/demos/demo_task_apps/math/_common.py +1 -2
- synth_ai/demos/demo_task_apps/math/app.py +2 -1
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +3 -6
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +183 -82
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -2
- synth_ai/environments/examples/bandit/engine.py +12 -4
- synth_ai/environments/examples/bandit/taskset.py +4 -4
- synth_ai/environments/examples/crafter_classic/environment.py +76 -1
- synth_ai/environments/reproducibility/tree.py +5 -6
- synth_ai/environments/service/app.py +11 -12
- synth_ai/environments/service/core_routes.py +10 -9
- synth_ai/environments/stateful/engine.py +1 -1
- synth_ai/environments/tasks/core.py +1 -0
- synth_ai/environments/tasks/filters.py +5 -6
- synth_ai/environments/tasks/utils.py +4 -5
- synth_ai/evals/base.py +0 -2
- synth_ai/handshake.py +11 -9
- synth_ai/http.py +1 -1
- synth_ai/http_client.py +43 -11
- synth_ai/inference/__init__.py +0 -2
- synth_ai/inference/client.py +20 -6
- synth_ai/jobs/client.py +103 -78
- synth_ai/learning/__init__.py +41 -6
- synth_ai/learning/algorithms.py +14 -0
- synth_ai/learning/client.py +121 -29
- synth_ai/learning/config.py +2 -40
- synth_ai/learning/constants.py +0 -2
- synth_ai/learning/ft_client.py +4 -56
- synth_ai/learning/health.py +13 -7
- synth_ai/learning/jobs.py +43 -47
- synth_ai/{rl → learning/rl}/__init__.py +14 -5
- synth_ai/learning/rl/client.py +267 -0
- synth_ai/learning/rl/config.py +31 -0
- synth_ai/{rl → learning/rl}/contracts.py +5 -10
- synth_ai/{rl → learning/rl}/env_keys.py +45 -16
- synth_ai/learning/rl/secrets.py +13 -0
- synth_ai/learning/rl_client.py +2 -253
- synth_ai/learning/sft/__init__.py +29 -0
- synth_ai/learning/sft/client.py +68 -0
- synth_ai/learning/sft/config.py +270 -0
- synth_ai/learning/sft/data.py +295 -0
- synth_ai/learning/sse.py +25 -26
- synth_ai/learning/validators.py +25 -24
- synth_ai/lm/__init__.py +21 -47
- synth_ai/task/__init__.py +26 -27
- synth_ai/task/apps/__init__.py +18 -19
- synth_ai/task/auth.py +35 -23
- synth_ai/task/client.py +15 -13
- synth_ai/task/contracts.py +37 -35
- synth_ai/task/datasets.py +9 -6
- synth_ai/task/errors.py +11 -10
- synth_ai/task/health.py +17 -11
- synth_ai/task/json.py +58 -24
- synth_ai/task/proxy.py +15 -14
- synth_ai/task/rubrics.py +22 -15
- synth_ai/task/server.py +43 -17
- synth_ai/task/tracing_utils.py +12 -7
- synth_ai/task/validators.py +0 -1
- synth_ai/task/vendors.py +5 -7
- synth_ai/tracing_v3/__init__.py +2 -0
- synth_ai/tracing_v3/abstractions.py +21 -4
- synth_ai/tracing_v3/db_config.py +26 -1
- synth_ai/tracing_v3/decorators.py +18 -15
- synth_ai/tracing_v3/examples/basic_usage.py +3 -2
- synth_ai/tracing_v3/hooks.py +6 -4
- synth_ai/tracing_v3/llm_call_record_helpers.py +6 -6
- synth_ai/tracing_v3/replica_sync.py +1 -0
- synth_ai/tracing_v3/session_tracer.py +63 -16
- synth_ai/tracing_v3/storage/base.py +89 -1
- synth_ai/tracing_v3/storage/config.py +21 -8
- synth_ai/tracing_v3/storage/factory.py +10 -8
- synth_ai/tracing_v3/storage/utils.py +4 -2
- synth_ai/tracing_v3/turso/daemon.py +7 -2
- synth_ai/tracing_v3/turso/models.py +5 -2
- synth_ai/tracing_v3/turso/native_manager.py +1173 -0
- synth_ai/tracing_v3/utils.py +4 -3
- synth_ai/v0/api/__init__.py +8 -0
- synth_ai/v0/api/models/__init__.py +8 -0
- synth_ai/v0/api/models/supported.py +8 -0
- synth_ai/v0/config/__init__.py +15 -0
- synth_ai/v0/config/base_url.py +12 -0
- synth_ai/v0/lm/__init__.py +51 -0
- synth_ai/{lm → v0/lm}/caching/ephemeral.py +3 -5
- synth_ai/{lm → v0/lm}/caching/handler.py +4 -4
- synth_ai/{lm → v0/lm}/caching/initialize.py +1 -1
- synth_ai/{lm → v0/lm}/caching/persistent.py +1 -1
- synth_ai/{lm → v0/lm}/config.py +6 -1
- synth_ai/{lm → v0/lm}/core/all.py +9 -9
- synth_ai/{lm → v0/lm}/core/exceptions.py +0 -2
- synth_ai/{lm → v0/lm}/core/main.py +19 -7
- synth_ai/{lm → v0/lm}/core/main_v3.py +10 -10
- synth_ai/{lm → v0/lm}/core/synth_models.py +2 -15
- synth_ai/{lm → v0/lm}/core/vendor_clients.py +6 -4
- synth_ai/{lm → v0/lm}/overrides.py +4 -4
- synth_ai/{lm → v0/lm}/provider_support/anthropic.py +4 -4
- synth_ai/{lm → v0/lm}/provider_support/openai.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/handler.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/rehabilitate.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/core/anthropic_api.py +16 -16
- synth_ai/{lm → v0/lm}/vendors/core/gemini_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/mistral_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/openai_api.py +12 -10
- synth_ai/{lm → v0/lm}/vendors/openai_standard.py +11 -9
- synth_ai/{lm → v0/lm}/vendors/openai_standard_responses.py +8 -5
- synth_ai/{lm → v0/lm}/vendors/supported/custom_endpoint.py +4 -6
- synth_ai/{lm → v0/lm}/vendors/supported/deepseek.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/grok.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/groq.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/ollama.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/openrouter.py +3 -3
- synth_ai/{lm → v0/lm}/vendors/supported/together.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/synth_client.py +38 -11
- synth_ai/v0/tracing/upload.py +32 -135
- synth_ai/v0/tracing_v3/__init__.py +10 -0
- synth_ai/v0/tracing_v3/abstractions.py +3 -0
- synth_ai/v0/tracing_v3/decorators.py +3 -0
- synth_ai/v0/tracing_v3/llm_call_record_helpers.py +3 -0
- synth_ai/v0/tracing_v3/session_tracer.py +3 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.10.dist-info}/METADATA +10 -7
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.10.dist-info}/RECORD +294 -258
- examples/common_old/backend.py +0 -21
- examples/evals_old/README.md +0 -98
- examples/evals_old/__init__.py +0 -6
- examples/evals_old/compare_models.py +0 -1037
- examples/evals_old/example_log.md +0 -145
- examples/evals_old/run_demo.sh +0 -126
- examples/evals_old/trace_analysis.py +0 -270
- examples/finetuning_old/_backup_synth_qwen/config.toml +0 -29
- examples/finetuning_old/_backup_synth_qwen/example_log.md +0 -324
- examples/finetuning_old/_backup_synth_qwen/filter_traces.py +0 -60
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +0 -239
- examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +0 -109
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +0 -1924
- examples/finetuning_old/_backup_synth_qwen/readme.md +0 -49
- examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +0 -114
- examples/finetuning_old/_backup_synth_qwen/run_demo.sh +0 -195
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +0 -118
- examples/finetuning_old/synth_qwen_v1/README.md +0 -68
- examples/finetuning_old/synth_qwen_v1/filter_traces.py +0 -60
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +0 -239
- examples/finetuning_old/synth_qwen_v1/finetune.py +0 -46
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +0 -71
- examples/finetuning_old/synth_qwen_v1/infer.py +0 -37
- examples/finetuning_old/synth_qwen_v1/poll.py +0 -44
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +0 -35
- examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +0 -109
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +0 -1932
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +0 -207
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +0 -232
- examples/finetuning_old/synth_qwen_v1/upload_data.py +0 -34
- examples/finetuning_old/synth_qwen_v1/util.py +0 -147
- examples/rl_old/task_app.py +0 -962
- synth_ai/experimental/synth_oss.py +0 -446
- synth_ai/install_sqld.sh +0 -40
- synth_ai/learning/filtering.py +0 -0
- synth_ai/learning/offline/dpo.py +0 -0
- synth_ai/learning/offline/providers.py +0 -7
- synth_ai/learning/offline/sft.py +0 -0
- synth_ai/learning/offline/shared.py +0 -0
- synth_ai/learning/online/grpo.py +0 -0
- synth_ai/learning/online/irft.py +0 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +0 -168
- synth_ai/learning/prompts/gepa.py +0 -0
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +0 -213
- synth_ai/learning/prompts/mipro.py +0 -289
- synth_ai/learning/prompts/random_search.py +0 -246
- synth_ai/learning/prompts/run_mipro_banking77.py +0 -172
- synth_ai/learning/prompts/run_random_search_banking77.py +0 -324
- synth_ai/rl/secrets.py +0 -19
- synth_ai/scripts/verify_rewards.py +0 -100
- synth_ai/tracing/__init__.py +0 -30
- synth_ai/tracing_v1/__init__.py +0 -33
- synth_ai/tracing_v3/turso/__init__.py +0 -25
- synth_ai/tracing_v3/turso/manager.py +0 -774
- synth_ai/zyk/__init__.py +0 -30
- /synth_ai/{lm → v0/lm}/caching/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/dbs.py +0 -0
- /synth_ai/{lm → v0/lm}/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/monitor.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/statefulness.py +0 -0
- /synth_ai/{lm → v0/lm}/injection.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/suppress_logging.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/inject.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/base.py +0 -0
- /synth_ai/{lm → v0/lm}/unified_interface.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/base.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/synth_dev_api.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/ollama.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/retries.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/supported/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/warmup.py +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.10.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.10.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.10.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.10.dist-info}/top_level.txt +0 -0
|
@@ -7,14 +7,12 @@ import argparse
|
|
|
7
7
|
import asyncio
|
|
8
8
|
import json
|
|
9
9
|
import os
|
|
10
|
+
import sys
|
|
10
11
|
from pathlib import Path
|
|
11
12
|
from typing import Any
|
|
12
13
|
|
|
13
|
-
import sys
|
|
14
|
-
|
|
15
14
|
import httpx
|
|
16
15
|
from dotenv import load_dotenv
|
|
17
|
-
|
|
18
16
|
from synth_ai.task import (
|
|
19
17
|
RolloutEnvSpec,
|
|
20
18
|
RolloutPolicySpec,
|
|
@@ -25,7 +23,9 @@ from synth_ai.task import (
|
|
|
25
23
|
)
|
|
26
24
|
|
|
27
25
|
|
|
28
|
-
def build_rollout_request(
|
|
26
|
+
def build_rollout_request(
|
|
27
|
+
seed: int, run_id: str, *, model: str, inference_url: str, ops: list[str], api_key: str
|
|
28
|
+
) -> RolloutRequest:
|
|
29
29
|
policy_config = {
|
|
30
30
|
"model": model,
|
|
31
31
|
"inference_url": inference_url,
|
|
@@ -45,7 +45,11 @@ def build_rollout_request(seed: int, run_id: str, *, model: str, inference_url:
|
|
|
45
45
|
|
|
46
46
|
|
|
47
47
|
def summarise_response(data: Any) -> dict[str, Any]:
|
|
48
|
-
metrics =
|
|
48
|
+
metrics = (
|
|
49
|
+
data.metrics.model_dump()
|
|
50
|
+
if hasattr(data.metrics, "model_dump")
|
|
51
|
+
else data.get("metrics", {})
|
|
52
|
+
)
|
|
49
53
|
return {
|
|
50
54
|
"run_id": getattr(data, "run_id", None) or data.get("run_id"),
|
|
51
55
|
"num_episodes": metrics.get("num_episodes"),
|
|
@@ -57,21 +61,54 @@ def summarise_response(data: Any) -> dict[str, Any]:
|
|
|
57
61
|
|
|
58
62
|
|
|
59
63
|
async def main() -> None:
|
|
64
|
+
# Load .env file from current directory first if it exists
|
|
65
|
+
default_env = Path.cwd() / ".env"
|
|
66
|
+
if default_env.exists():
|
|
67
|
+
load_dotenv(default_env, override=False)
|
|
68
|
+
|
|
60
69
|
parser = argparse.ArgumentParser(description=__doc__)
|
|
61
70
|
parser.add_argument("--base-url", default="http://localhost:8010", help="Task app base URL")
|
|
62
71
|
parser.add_argument("--env-file", type=str, default=None, help="Path to .env file with keys")
|
|
63
72
|
parser.add_argument("--seed", type=int, default=42, help="Env seed to rollout")
|
|
64
73
|
parser.add_argument("--run-id", default="modal-eval", help="Run identifier")
|
|
65
|
-
parser.add_argument(
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
parser.add_argument(
|
|
71
|
-
|
|
72
|
-
|
|
74
|
+
parser.add_argument(
|
|
75
|
+
"--model",
|
|
76
|
+
required=False,
|
|
77
|
+
help="Model identifier for the Crafter policy (e.g., fft:Qwen/Qwen3-4B:job_xxx)",
|
|
78
|
+
)
|
|
79
|
+
parser.add_argument(
|
|
80
|
+
"--inference-url",
|
|
81
|
+
required=False,
|
|
82
|
+
help="Modal backend inference base URL (e.g., http://localhost:8000/api)",
|
|
83
|
+
)
|
|
84
|
+
parser.add_argument(
|
|
85
|
+
"--task-app-key",
|
|
86
|
+
default=None,
|
|
87
|
+
help="Environment API key for the task app (fallback ENVIRONMENT_API_KEY)",
|
|
88
|
+
)
|
|
89
|
+
parser.add_argument(
|
|
90
|
+
"--modal-key",
|
|
91
|
+
default=None,
|
|
92
|
+
help="Synth/Modal API key for inference (fallback SYNTH_API_KEY)",
|
|
93
|
+
)
|
|
94
|
+
parser.add_argument(
|
|
95
|
+
"--max-llm-calls", type=int, default=20, help="Number of policy inference calls"
|
|
96
|
+
)
|
|
97
|
+
parser.add_argument(
|
|
98
|
+
"--ops", default=None, help="Comma-separated rollout ops (advanced override)"
|
|
99
|
+
)
|
|
100
|
+
parser.add_argument(
|
|
101
|
+
"--max-policy-tokens",
|
|
102
|
+
type=int,
|
|
103
|
+
default=None,
|
|
104
|
+
help="Optional per-call token limit forwarded to the policy config",
|
|
105
|
+
)
|
|
106
|
+
parser.add_argument(
|
|
107
|
+
"--verbose", action="store_true", help="Print resolved configuration and headers"
|
|
108
|
+
)
|
|
73
109
|
args = parser.parse_args()
|
|
74
110
|
|
|
111
|
+
# Also load from explicit --env-file if provided
|
|
75
112
|
if args.env_file:
|
|
76
113
|
env_path = Path(args.env_file).expanduser()
|
|
77
114
|
if not env_path.exists():
|
|
@@ -79,16 +116,51 @@ async def main() -> None:
|
|
|
79
116
|
else:
|
|
80
117
|
load_dotenv(env_path, override=False)
|
|
81
118
|
|
|
119
|
+
# Prompt for required parameters if not provided
|
|
120
|
+
base_url = args.base_url
|
|
121
|
+
if args.base_url == "http://localhost:8010":
|
|
122
|
+
print("\nTask app configuration:")
|
|
123
|
+
base_url_input = input("Task app base URL [http://localhost:8001]: ").strip()
|
|
124
|
+
base_url = base_url_input if base_url_input else "http://localhost:8001"
|
|
125
|
+
|
|
126
|
+
model = args.model
|
|
127
|
+
if not model:
|
|
128
|
+
print("\nFine-tuned model configuration:")
|
|
129
|
+
print(
|
|
130
|
+
"Note: This should be the model ID returned from training (e.g., fft:Qwen/Qwen3-4B:job_abc123)"
|
|
131
|
+
)
|
|
132
|
+
model_input = input("Fine-tuned model ID: ").strip()
|
|
133
|
+
if not model_input:
|
|
134
|
+
parser.error("Model identifier is required")
|
|
135
|
+
model = model_input
|
|
136
|
+
|
|
137
|
+
inference_url = args.inference_url
|
|
138
|
+
if not inference_url:
|
|
139
|
+
inference_url_input = input("Inference URL [http://localhost:8000/api]: ").strip()
|
|
140
|
+
inference_url = inference_url_input if inference_url_input else "http://localhost:8000/api"
|
|
141
|
+
|
|
142
|
+
# Override args
|
|
143
|
+
args.base_url = base_url
|
|
144
|
+
args.model = model
|
|
145
|
+
args.inference_url = inference_url
|
|
146
|
+
|
|
147
|
+
# Check environment variables first (loaded from .env)
|
|
82
148
|
task_app_key = args.task_app_key or os.getenv("ENVIRONMENT_API_KEY")
|
|
83
149
|
if not task_app_key:
|
|
84
|
-
|
|
150
|
+
print("\n[INFO] ENVIRONMENT_API_KEY not found in environment or .env file")
|
|
151
|
+
task_app_key = input("RL Environment API key: ").strip()
|
|
152
|
+
if not task_app_key:
|
|
153
|
+
parser.error("Missing task app API key")
|
|
85
154
|
|
|
86
155
|
modal_key = args.modal_key or os.getenv("SYNTH_API_KEY")
|
|
87
156
|
if not modal_key:
|
|
88
|
-
|
|
157
|
+
print("[INFO] SYNTH_API_KEY not found in environment or .env file")
|
|
158
|
+
modal_key = input("Synth API key: ").strip()
|
|
159
|
+
if not modal_key:
|
|
160
|
+
parser.error("Missing Synth/Modal API key")
|
|
89
161
|
|
|
90
|
-
if
|
|
91
|
-
os.environ["OPENAI_API_KEY"] =
|
|
162
|
+
if modal_key and "openai.com" not in args.inference_url.lower():
|
|
163
|
+
os.environ["OPENAI_API_KEY"] = modal_key
|
|
92
164
|
|
|
93
165
|
if args.ops:
|
|
94
166
|
ops = [op.strip() for op in args.ops.split(",") if op.strip()]
|
|
@@ -103,6 +175,7 @@ async def main() -> None:
|
|
|
103
175
|
ops.extend(["agent", "env"])
|
|
104
176
|
|
|
105
177
|
if args.verbose:
|
|
178
|
+
|
|
106
179
|
def _mask(val: str | None) -> str:
|
|
107
180
|
if not val:
|
|
108
181
|
return "<unset>"
|
|
@@ -115,11 +188,15 @@ async def main() -> None:
|
|
|
115
188
|
print(f" Modal API key : {_mask(modal_key)}")
|
|
116
189
|
print(f" Ops (count={len(ops)}) : {ops}")
|
|
117
190
|
|
|
118
|
-
inf_url_norm = args.inference_url.rstrip(
|
|
119
|
-
if
|
|
120
|
-
print(
|
|
121
|
-
|
|
122
|
-
|
|
191
|
+
inf_url_norm = args.inference_url.rstrip("/")
|
|
192
|
+
if "/api" not in inf_url_norm:
|
|
193
|
+
print(
|
|
194
|
+
"[WARN] Inference URL is missing /api prefix; proxy endpoints usually live at /api/inference/v1/chat/completions."
|
|
195
|
+
)
|
|
196
|
+
elif not inf_url_norm.lower().endswith("/api"):
|
|
197
|
+
print(
|
|
198
|
+
"[INFO] Using inference base URL; policy will append /v1/chat/completions automatically."
|
|
199
|
+
)
|
|
123
200
|
|
|
124
201
|
async with TaskAppClient(args.base_url, api_key=task_app_key) as client:
|
|
125
202
|
try:
|
|
@@ -139,20 +216,29 @@ async def main() -> None:
|
|
|
139
216
|
if args.verbose:
|
|
140
217
|
print(f"Request headers: {request.policy.config.get('extra_headers', {})}")
|
|
141
218
|
if args.max_policy_tokens is not None:
|
|
142
|
-
request.policy.config.update(
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
219
|
+
request.policy.config.update(
|
|
220
|
+
{
|
|
221
|
+
"max_completion_tokens": args.max_policy_tokens,
|
|
222
|
+
"max_tokens": args.max_policy_tokens,
|
|
223
|
+
}
|
|
224
|
+
)
|
|
146
225
|
print("Requesting rollout…")
|
|
147
226
|
response = await client.rollout(request)
|
|
148
227
|
summary = summarise_response(response)
|
|
149
228
|
print(json.dumps(summary, indent=2))
|
|
150
229
|
print(f"Ops executed: {ops}")
|
|
151
230
|
except httpx.HTTPStatusError as exc:
|
|
152
|
-
detail =
|
|
231
|
+
detail = (
|
|
232
|
+
exc.response.json()
|
|
233
|
+
if exc.response.headers.get("content-type", "").startswith("application/json")
|
|
234
|
+
else exc.response.text
|
|
235
|
+
)
|
|
153
236
|
print(f"HTTP error {exc.response.status_code}: {detail}", file=sys.stderr)
|
|
154
237
|
if exc.response.status_code in (401, 503):
|
|
155
|
-
print(
|
|
238
|
+
print(
|
|
239
|
+
"Hint: ensure ENVIRONMENT_API_KEY and SYNTH_API_KEY are correctly set.",
|
|
240
|
+
file=sys.stderr,
|
|
241
|
+
)
|
|
156
242
|
raise
|
|
157
243
|
|
|
158
244
|
|
|
@@ -5,7 +5,6 @@ from __future__ import annotations
|
|
|
5
5
|
|
|
6
6
|
import argparse
|
|
7
7
|
import asyncio
|
|
8
|
-
import json
|
|
9
8
|
import os
|
|
10
9
|
from collections import Counter
|
|
11
10
|
from pathlib import Path
|
|
@@ -13,15 +12,13 @@ from statistics import mean, median
|
|
|
13
12
|
from typing import Any
|
|
14
13
|
|
|
15
14
|
from dotenv import load_dotenv
|
|
16
|
-
|
|
17
|
-
from synth_ai.task import TaskAppClient
|
|
18
|
-
|
|
19
15
|
from synth_ai.task import (
|
|
20
16
|
RolloutEnvSpec,
|
|
21
17
|
RolloutPolicySpec,
|
|
22
18
|
RolloutRecordConfig,
|
|
23
19
|
RolloutRequest,
|
|
24
20
|
RolloutSafetyConfig,
|
|
21
|
+
TaskAppClient,
|
|
25
22
|
)
|
|
26
23
|
|
|
27
24
|
|
|
@@ -31,12 +28,17 @@ def build_rollout_request(
|
|
|
31
28
|
run_id: str,
|
|
32
29
|
model: str,
|
|
33
30
|
inference_url: str,
|
|
31
|
+
inference_api_key: str,
|
|
34
32
|
ops: list[str],
|
|
35
33
|
extra_headers: dict[str, str] | None = None,
|
|
36
34
|
trace_format: str = "compact",
|
|
37
35
|
return_trace: bool = False,
|
|
38
36
|
) -> RolloutRequest:
|
|
39
|
-
policy_config = {
|
|
37
|
+
policy_config = {
|
|
38
|
+
"model": model,
|
|
39
|
+
"inference_url": inference_url,
|
|
40
|
+
"api_key": inference_api_key,
|
|
41
|
+
}
|
|
40
42
|
if extra_headers:
|
|
41
43
|
policy_config["extra_headers"] = extra_headers
|
|
42
44
|
record_cfg = RolloutRecordConfig(
|
|
@@ -123,7 +125,9 @@ def analyse_rollout_response(response: Any) -> dict[str, Any]:
|
|
|
123
125
|
if isinstance(final_list, list):
|
|
124
126
|
final_achievements = [str(item) for item in final_list]
|
|
125
127
|
|
|
126
|
-
decision_rewards =
|
|
128
|
+
decision_rewards = (
|
|
129
|
+
trace_payload.get("decision_rewards") if isinstance(trace_payload, dict) else []
|
|
130
|
+
)
|
|
127
131
|
trace_all: list[str] = []
|
|
128
132
|
if isinstance(decision_rewards, list):
|
|
129
133
|
for item in decision_rewards:
|
|
@@ -180,7 +184,9 @@ def summarise_runs(run_summaries: list[dict[str, Any]]) -> dict[str, Any]:
|
|
|
180
184
|
return stats
|
|
181
185
|
|
|
182
186
|
|
|
183
|
-
def print_summary(
|
|
187
|
+
def print_summary(
|
|
188
|
+
stats: dict[str, Any], *, run_details: list[dict[str, Any]], total_runs: int
|
|
189
|
+
) -> None:
|
|
184
190
|
if not stats:
|
|
185
191
|
print("No successful rollouts to summarise.")
|
|
186
192
|
return
|
|
@@ -234,7 +240,22 @@ async def execute_rollouts(args: argparse.Namespace) -> None:
|
|
|
234
240
|
|
|
235
241
|
api_key = args.api_key or os.getenv("ENVIRONMENT_API_KEY")
|
|
236
242
|
if not api_key:
|
|
237
|
-
|
|
243
|
+
import sys
|
|
244
|
+
|
|
245
|
+
print("Please enter your RL Environment API key:", file=sys.stderr, flush=True)
|
|
246
|
+
api_key = input("> ").strip()
|
|
247
|
+
if not api_key:
|
|
248
|
+
raise RuntimeError("RL Environment API key is required")
|
|
249
|
+
|
|
250
|
+
# Prompt for Groq API key if not set
|
|
251
|
+
groq_api_key = os.getenv("GROQ_API_KEY")
|
|
252
|
+
if not groq_api_key:
|
|
253
|
+
import sys
|
|
254
|
+
|
|
255
|
+
print("Please enter your Groq API key:", file=sys.stderr, flush=True)
|
|
256
|
+
groq_api_key = input("> ").strip()
|
|
257
|
+
if not groq_api_key:
|
|
258
|
+
raise RuntimeError("Groq API key is required")
|
|
238
259
|
|
|
239
260
|
synth_key = os.getenv("SYNTH_API_KEY")
|
|
240
261
|
extra_headers: dict[str, str] | None = None
|
|
@@ -252,29 +273,41 @@ async def execute_rollouts(args: argparse.Namespace) -> None:
|
|
|
252
273
|
|
|
253
274
|
ops = build_ops(args.max_llm_calls, args.ops)
|
|
254
275
|
|
|
276
|
+
print(f"\n🚀 Starting {args.count} rollouts with {args.parallel} parallel workers...")
|
|
277
|
+
print(f"📊 Each rollout: {len(ops)} ops ({args.max_llm_calls} LLM calls)\n")
|
|
278
|
+
|
|
255
279
|
async with TaskAppClient(args.base_url, api_key=api_key, timeout=args.timeout) as client:
|
|
280
|
+
|
|
256
281
|
async def run_single(index: int) -> dict[str, Any]:
|
|
257
282
|
run_id = f"{args.run_id}-{index:03d}"
|
|
258
283
|
seed = args.seed + index * args.seed_stride
|
|
284
|
+
print(f"\n▶️ [{index + 1}/{args.count}] Starting rollout {run_id} (seed={seed})...")
|
|
285
|
+
|
|
259
286
|
request = build_rollout_request(
|
|
260
287
|
seed=seed,
|
|
261
288
|
run_id=run_id,
|
|
262
289
|
model=args.model,
|
|
263
290
|
inference_url=args.inference_url,
|
|
291
|
+
inference_api_key=groq_api_key,
|
|
264
292
|
ops=ops,
|
|
265
293
|
extra_headers=extra_headers,
|
|
266
294
|
trace_format=args.trace_format,
|
|
267
295
|
return_trace=True,
|
|
268
296
|
)
|
|
269
297
|
if args.max_policy_tokens is not None:
|
|
270
|
-
request.policy.config.update(
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
298
|
+
request.policy.config.update(
|
|
299
|
+
{
|
|
300
|
+
"max_completion_tokens": args.max_policy_tokens,
|
|
301
|
+
"max_tokens": args.max_policy_tokens,
|
|
302
|
+
}
|
|
303
|
+
)
|
|
274
304
|
|
|
275
305
|
try:
|
|
276
306
|
response = await client.rollout(request)
|
|
277
307
|
summary = analyse_rollout_response(response)
|
|
308
|
+
print(
|
|
309
|
+
f"\n✅ [{index + 1}/{args.count}] Completed {run_id} (outcome={summary.get('outcome_score', 'N/A')})"
|
|
310
|
+
)
|
|
278
311
|
return {
|
|
279
312
|
"ok": True,
|
|
280
313
|
"run_id": run_id,
|
|
@@ -283,6 +316,7 @@ async def execute_rollouts(args: argparse.Namespace) -> None:
|
|
|
283
316
|
"summary": summary,
|
|
284
317
|
}
|
|
285
318
|
except Exception as exc: # pragma: no cover - surface errors
|
|
319
|
+
print(f"\n❌ [{index + 1}/{args.count}] Failed {run_id}: {exc}")
|
|
286
320
|
return {
|
|
287
321
|
"ok": False,
|
|
288
322
|
"run_id": run_id,
|
|
@@ -302,6 +336,7 @@ async def execute_rollouts(args: argparse.Namespace) -> None:
|
|
|
302
336
|
successes = [item for item in results if item.get("ok")]
|
|
303
337
|
failures = [item for item in results if not item.get("ok")]
|
|
304
338
|
|
|
339
|
+
print(f"\n{'=' * 100}\n")
|
|
305
340
|
stats = summarise_runs([item["summary"] for item in successes])
|
|
306
341
|
print_summary(stats, run_details=successes, total_runs=args.count)
|
|
307
342
|
|
|
@@ -317,17 +352,43 @@ def parse_args() -> argparse.Namespace:
|
|
|
317
352
|
parser.add_argument("--base-url", default="http://localhost:8001", help="Task app base URL")
|
|
318
353
|
parser.add_argument("--api-key", help="Environment API key (or set via --env-file)")
|
|
319
354
|
parser.add_argument("--env-file", help="Path to .env file providing API keys")
|
|
320
|
-
parser.add_argument(
|
|
321
|
-
|
|
355
|
+
parser.add_argument(
|
|
356
|
+
"--model", default="gpt-4o-mini", help="Model identifier for the Crafter policy"
|
|
357
|
+
)
|
|
358
|
+
parser.add_argument(
|
|
359
|
+
"--inference-url",
|
|
360
|
+
default="https://api.openai.com",
|
|
361
|
+
help="Inference base URL for the policy",
|
|
362
|
+
)
|
|
322
363
|
parser.add_argument("--seed", type=int, default=42, help="Base seed for the first rollout")
|
|
323
|
-
parser.add_argument(
|
|
324
|
-
|
|
364
|
+
parser.add_argument(
|
|
365
|
+
"--seed-stride", type=int, default=1, help="Increment applied to the seed for each rollout"
|
|
366
|
+
)
|
|
367
|
+
parser.add_argument(
|
|
368
|
+
"--count", type=int, default=20, help="Number of rollout trajectories to execute"
|
|
369
|
+
)
|
|
325
370
|
parser.add_argument("--parallel", type=int, default=4, help="Maximum concurrent rollouts")
|
|
326
371
|
parser.add_argument("--ops", help="Comma-separated rollout ops (advanced override)")
|
|
327
|
-
parser.add_argument(
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
372
|
+
parser.add_argument(
|
|
373
|
+
"--max-llm-calls",
|
|
374
|
+
type=int,
|
|
375
|
+
default=20,
|
|
376
|
+
help="Number of agent/env pairs per rollout when --ops not provided",
|
|
377
|
+
)
|
|
378
|
+
parser.add_argument(
|
|
379
|
+
"--max-policy-tokens",
|
|
380
|
+
type=int,
|
|
381
|
+
help="Optional per-call token limit forwarded to the policy config",
|
|
382
|
+
)
|
|
383
|
+
parser.add_argument(
|
|
384
|
+
"--timeout", type=float, default=600.0, help="HTTP timeout (seconds) for task app requests"
|
|
385
|
+
)
|
|
386
|
+
parser.add_argument(
|
|
387
|
+
"--trace-format",
|
|
388
|
+
default="compact",
|
|
389
|
+
choices=["compact", "full"],
|
|
390
|
+
help="Trace format requested from the task app",
|
|
391
|
+
)
|
|
331
392
|
parser.add_argument("--run-id", default="batch-demo", help="Run ID prefix for rollouts")
|
|
332
393
|
parser.add_argument("--verbose", action="store_true", help="Print resolved configuration")
|
|
333
394
|
return parser.parse_args()
|
|
@@ -6,13 +6,12 @@ from __future__ import annotations
|
|
|
6
6
|
import argparse
|
|
7
7
|
import asyncio
|
|
8
8
|
import json
|
|
9
|
+
import os
|
|
10
|
+
import sys
|
|
9
11
|
from pathlib import Path
|
|
10
12
|
from typing import Any
|
|
11
13
|
|
|
12
|
-
import sys
|
|
13
|
-
|
|
14
14
|
import httpx
|
|
15
|
-
|
|
16
15
|
from synth_ai.task import (
|
|
17
16
|
RolloutEnvSpec,
|
|
18
17
|
RolloutPolicySpec,
|
|
@@ -29,6 +28,7 @@ def build_rollout_request(
|
|
|
29
28
|
run_id: str,
|
|
30
29
|
model: str,
|
|
31
30
|
inference_url: str,
|
|
31
|
+
inference_api_key: str,
|
|
32
32
|
ops: list[str],
|
|
33
33
|
return_trace: bool,
|
|
34
34
|
trace_format: str,
|
|
@@ -37,6 +37,7 @@ def build_rollout_request(
|
|
|
37
37
|
policy_config = {
|
|
38
38
|
"model": model,
|
|
39
39
|
"inference_url": inference_url,
|
|
40
|
+
"api_key": inference_api_key,
|
|
40
41
|
}
|
|
41
42
|
if max_policy_tokens is not None:
|
|
42
43
|
policy_config.update(
|
|
@@ -64,7 +65,11 @@ def build_rollout_request(
|
|
|
64
65
|
|
|
65
66
|
|
|
66
67
|
def summarise_rollout(response: Any) -> dict[str, Any]:
|
|
67
|
-
metrics =
|
|
68
|
+
metrics = (
|
|
69
|
+
response.metrics.model_dump()
|
|
70
|
+
if hasattr(response, "metrics")
|
|
71
|
+
else response.get("metrics", {})
|
|
72
|
+
)
|
|
68
73
|
return {
|
|
69
74
|
"run_id": getattr(response, "run_id", None) or response.get("run_id"),
|
|
70
75
|
"num_episodes": metrics.get("num_episodes"),
|
|
@@ -83,17 +88,25 @@ def summarise_trace(trace: Any) -> dict[str, Any]:
|
|
|
83
88
|
|
|
84
89
|
format_hint = "compact" if "events_count" in trace or "lm_calls" in trace else "full"
|
|
85
90
|
events_count = trace.get("events_count")
|
|
86
|
-
if
|
|
91
|
+
if (
|
|
92
|
+
events_count is None
|
|
93
|
+
and "event_history" in trace
|
|
94
|
+
and isinstance(trace["event_history"], list)
|
|
95
|
+
):
|
|
87
96
|
events_count = len(trace["event_history"])
|
|
88
97
|
messages_count = trace.get("messages_count")
|
|
89
|
-
if
|
|
90
|
-
|
|
98
|
+
if (
|
|
99
|
+
messages_count is None
|
|
100
|
+
and "markov_blanket_message_history" in trace
|
|
101
|
+
and isinstance(trace["markov_blanket_message_history"], list)
|
|
91
102
|
):
|
|
92
103
|
messages_count = len(trace["markov_blanket_message_history"])
|
|
93
104
|
|
|
94
105
|
metadata = trace.get("metadata") if isinstance(trace.get("metadata"), dict) else {}
|
|
95
106
|
lm_calls = trace.get("lm_calls") if isinstance(trace.get("lm_calls"), list) else []
|
|
96
|
-
decision_rewards =
|
|
107
|
+
decision_rewards = (
|
|
108
|
+
trace.get("decision_rewards") if isinstance(trace.get("decision_rewards"), list) else []
|
|
109
|
+
)
|
|
97
110
|
|
|
98
111
|
return {
|
|
99
112
|
"session_id": trace.get("session_id"),
|
|
@@ -215,11 +228,13 @@ def print_reward_summary(
|
|
|
215
228
|
if decision_rewards:
|
|
216
229
|
print(" Decision rewards:")
|
|
217
230
|
for entry in decision_rewards:
|
|
218
|
-
turn = entry.get(
|
|
219
|
-
ach_delta = entry.get(
|
|
220
|
-
unique_delta = entry.get(
|
|
221
|
-
achievements = entry.get(
|
|
222
|
-
print(
|
|
231
|
+
turn = entry.get("turn")
|
|
232
|
+
ach_delta = entry.get("ach_delta")
|
|
233
|
+
unique_delta = entry.get("unique_delta")
|
|
234
|
+
achievements = entry.get("achievements") or []
|
|
235
|
+
print(
|
|
236
|
+
f" turn={turn}, ach_delta={ach_delta}, unique_delta={unique_delta}, achievements={achievements}"
|
|
237
|
+
)
|
|
223
238
|
else:
|
|
224
239
|
print(" Decision rewards: none recorded")
|
|
225
240
|
|
|
@@ -242,16 +257,40 @@ def print_reward_summary(
|
|
|
242
257
|
|
|
243
258
|
|
|
244
259
|
async def main() -> None:
|
|
260
|
+
# Load .env file from current directory if it exists
|
|
261
|
+
env_file = Path.cwd() / ".env"
|
|
262
|
+
if env_file.exists():
|
|
263
|
+
from dotenv import load_dotenv
|
|
264
|
+
|
|
265
|
+
load_dotenv(env_file)
|
|
266
|
+
|
|
245
267
|
parser = argparse.ArgumentParser(description=__doc__)
|
|
246
|
-
parser.add_argument("--base-url", default="http://localhost:
|
|
247
|
-
parser.add_argument("--api-key",
|
|
268
|
+
parser.add_argument("--base-url", default="http://localhost:8001", help="Task app base URL")
|
|
269
|
+
parser.add_argument("--api-key", help="RL Environment API key (will prompt if not provided)")
|
|
270
|
+
parser.add_argument(
|
|
271
|
+
"--inference-api-key", help="Inference provider API key (will prompt if not provided)"
|
|
272
|
+
)
|
|
248
273
|
parser.add_argument("--seed", type=int, default=42, help="Environment seed")
|
|
249
274
|
parser.add_argument("--run-id", default="local-trace", help="Run identifier")
|
|
250
275
|
parser.add_argument("--model", default="gpt-4o-mini", help="OpenAI-compatible model id")
|
|
251
|
-
parser.add_argument(
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
parser.add_argument(
|
|
276
|
+
parser.add_argument(
|
|
277
|
+
"--inference-url", default="https://api.openai.com", help="Inference base URL (OpenAI/Groq)"
|
|
278
|
+
)
|
|
279
|
+
parser.add_argument(
|
|
280
|
+
"--ops", help="Comma-separated rollout ops (fallback: alternating agent/env)"
|
|
281
|
+
)
|
|
282
|
+
parser.add_argument(
|
|
283
|
+
"--max-llm-calls",
|
|
284
|
+
type=int,
|
|
285
|
+
default=1,
|
|
286
|
+
help="Number of agent/env pairs when --ops not supplied",
|
|
287
|
+
)
|
|
288
|
+
parser.add_argument(
|
|
289
|
+
"--max-policy-tokens",
|
|
290
|
+
type=int,
|
|
291
|
+
default=None,
|
|
292
|
+
help="Optional max token budget forwarded to policy",
|
|
293
|
+
)
|
|
255
294
|
parser.add_argument(
|
|
256
295
|
"--trace-format",
|
|
257
296
|
choices=["compact", "full"],
|
|
@@ -286,10 +325,69 @@ async def main() -> None:
|
|
|
286
325
|
)
|
|
287
326
|
args = parser.parse_args()
|
|
288
327
|
|
|
328
|
+
# Prompt for required parameters if not provided
|
|
329
|
+
base_url = args.base_url
|
|
330
|
+
if args.base_url == "http://localhost:8001":
|
|
331
|
+
print("\nTask app configuration:")
|
|
332
|
+
base_url_input = input("Task app base URL [http://localhost:8001]: ").strip()
|
|
333
|
+
base_url = base_url_input if base_url_input else "http://localhost:8001"
|
|
334
|
+
|
|
335
|
+
api_key = args.api_key or os.getenv("ENVIRONMENT_API_KEY")
|
|
336
|
+
if not api_key:
|
|
337
|
+
api_key = input("RL Environment API key (from ENVIRONMENT_API_KEY): ").strip()
|
|
338
|
+
if not api_key:
|
|
339
|
+
parser.error("RL Environment API key is required")
|
|
340
|
+
|
|
341
|
+
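# Always use Groq: model and inference_url are hard-coded here and passed to the rollout request instead of --model / --inference-url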
|
|
342
|
+
model = "llama-3.3-70b-versatile"
|
|
343
|
+
inference_url = "https://api.groq.com/openai"
|
|
344
|
+
|
|
345
|
+
print("\nInference configuration (Groq):")
|
|
346
|
+
inference_api_key = args.inference_api_key or os.getenv("GROQ_API_KEY")
|
|
347
|
+
if not inference_api_key:
|
|
348
|
+
inference_api_key = input("Groq API key: ").strip()
|
|
349
|
+
if not inference_api_key:
|
|
350
|
+
parser.error("Groq API key is required")
|
|
351
|
+
|
|
352
|
+
# Save to .env for future use
|
|
353
|
+
env_path = Path.cwd() / ".env"
|
|
354
|
+
try:
|
|
355
|
+
# Read existing .env
|
|
356
|
+
existing_lines = []
|
|
357
|
+
if env_path.exists():
|
|
358
|
+
existing_lines = env_path.read_text().splitlines()
|
|
359
|
+
|
|
360
|
+
# Check if GROQ_API_KEY already exists
|
|
361
|
+
key_exists = any(line.strip().startswith("GROQ_API_KEY=") for line in existing_lines)
|
|
362
|
+
|
|
363
|
+
if not key_exists:
|
|
364
|
+
# Append to .env
|
|
365
|
+
with open(env_path, "a") as f:
|
|
366
|
+
if existing_lines and not existing_lines[-1].strip():
|
|
367
|
+
# File exists and its last line is already blank, so no separator newline is needed
|
|
368
|
+
pass
|
|
369
|
+
elif existing_lines:
|
|
370
|
+
# Add newline before appending
|
|
371
|
+
f.write("\n")
|
|
372
|
+
f.write(f"GROQ_API_KEY={inference_api_key}\n")
|
|
373
|
+
print(f"[INFO] Saved GROQ_API_KEY to {env_path}")
|
|
374
|
+
except Exception as e:
|
|
375
|
+
print(f"[WARN] Could not save GROQ_API_KEY to .env: {e}")
|
|
376
|
+
|
|
377
|
+
print("\nRollout configuration:")
|
|
378
|
+
max_llm_calls = args.max_llm_calls
|
|
379
|
+
if args.max_llm_calls == 1:
|
|
380
|
+
max_llm_calls_input = input("Max LLM calls [10]: ").strip()
|
|
381
|
+
max_llm_calls = int(max_llm_calls_input) if max_llm_calls_input else 10
|
|
382
|
+
|
|
383
|
+
# Override args with prompted values
|
|
384
|
+
args.base_url = base_url
|
|
385
|
+
args.max_llm_calls = max_llm_calls
|
|
386
|
+
|
|
289
387
|
ops = ensure_ops(args.ops, args.max_llm_calls)
|
|
290
388
|
return_trace = not args.no_trace
|
|
291
389
|
|
|
292
|
-
async with TaskAppClient(args.base_url, api_key=
|
|
390
|
+
async with TaskAppClient(args.base_url, api_key=api_key, timeout=args.timeout) as client:
|
|
293
391
|
try:
|
|
294
392
|
print(f"Fetching task_info for seed {args.seed}…")
|
|
295
393
|
task_info = await client.task_info(seeds=[args.seed])
|
|
@@ -302,8 +400,9 @@ async def main() -> None:
|
|
|
302
400
|
request = build_rollout_request(
|
|
303
401
|
seed=args.seed,
|
|
304
402
|
run_id=args.run_id,
|
|
305
|
-
model=
|
|
306
|
-
inference_url=
|
|
403
|
+
model=model,
|
|
404
|
+
inference_url=inference_url,
|
|
405
|
+
inference_api_key=inference_api_key,
|
|
307
406
|
ops=ops,
|
|
308
407
|
return_trace=return_trace,
|
|
309
408
|
trace_format=args.trace_format,
|
|
@@ -350,7 +449,11 @@ async def main() -> None:
|
|
|
350
449
|
"Tip: export TASKAPP_TRACING_ENABLED=1 and optionally TASKAPP_SFT_OUTPUT_DIR before running `uvx synth-ai serve …` to persist traces/SFT."
|
|
351
450
|
)
|
|
352
451
|
except httpx.HTTPStatusError as exc:
|
|
353
|
-
detail =
|
|
452
|
+
detail = (
|
|
453
|
+
exc.response.json()
|
|
454
|
+
if exc.response.headers.get("content-type", "").startswith("application/json")
|
|
455
|
+
else exc.response.text
|
|
456
|
+
)
|
|
354
457
|
print(f"HTTP error {exc.response.status_code}: {detail}", file=sys.stderr)
|
|
355
458
|
if exc.response.status_code in (401, 503):
|
|
356
459
|
print(
|