synth-ai 0.2.9.dev7__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/__init__.py +16 -0
- examples/crafter_debug_render.py +8 -11
- examples/dev/qwen3_32b_qlora_4xh100.toml +40 -0
- examples/multi_step/crafter_rl_lora.md +29 -0
- examples/qwen_coder/README.md +102 -0
- examples/qwen_coder/_shared.py +113 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +61 -0
- examples/qwen_coder/configs/coder_lora_4b.toml +57 -0
- examples/qwen_coder/configs/coder_lora_small.toml +58 -0
- examples/qwen_coder/generate_dataset.py +98 -0
- examples/qwen_coder/infer_ft_smoke.py +65 -0
- examples/qwen_coder/infer_prod_proxy.py +73 -0
- examples/qwen_coder/infer_via_synth.py +87 -0
- examples/qwen_coder/scripts/infer_coder.sh +19 -0
- examples/qwen_coder/scripts/train_coder_30b.sh +22 -0
- examples/qwen_coder/sft_full_17b.py +103 -0
- examples/qwen_coder/sft_lora_30b.py +110 -0
- examples/qwen_coder/subset_jsonl.py +39 -0
- examples/qwen_coder/todos.md +38 -0
- examples/qwen_coder/validate_jsonl.py +60 -0
- examples/rl/run_eval.py +36 -37
- examples/rl/run_rl_and_save.py +5 -5
- examples/rl/task_app/math_single_step.py +65 -43
- examples/rl/task_app/math_task_app.py +3 -3
- examples/sft/README.md +139 -0
- examples/sft/configs/crafter_fft_qwen0p6b.toml +44 -0
- examples/sft/configs/crafter_lora_qwen0p6b.toml +45 -0
- examples/sft/evaluate.py +117 -0
- examples/sft/export_dataset.py +117 -0
- examples/sft/generate_traces.py +162 -0
- examples/swe/__init__.py +12 -0
- examples/swe/task_app/README.md +105 -0
- examples/swe/task_app/__init__.py +2 -0
- examples/swe/task_app/grpo_swe_mini.py +571 -0
- examples/swe/task_app/grpo_swe_mini_task_app.py +136 -0
- examples/swe/task_app/hosted/README.md +173 -0
- examples/swe/task_app/hosted/__init__.py +5 -0
- examples/swe/task_app/hosted/branching.py +143 -0
- examples/swe/task_app/hosted/environment_routes.py +1289 -0
- examples/swe/task_app/hosted/envs/__init__.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/__init__.py +6 -0
- examples/swe/task_app/hosted/envs/crafter/app.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/environment.py +522 -0
- examples/swe/task_app/hosted/envs/crafter/policy.py +478 -0
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +108 -0
- examples/swe/task_app/hosted/envs/crafter/shared.py +305 -0
- examples/swe/task_app/hosted/envs/crafter/tools.py +47 -0
- examples/swe/task_app/hosted/envs/mini_swe/__init__.py +8 -0
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +1164 -0
- examples/swe/task_app/hosted/envs/mini_swe/policy.py +355 -0
- examples/swe/task_app/hosted/envs/mini_swe/shared.py +83 -0
- examples/swe/task_app/hosted/envs/mini_swe/tools.py +96 -0
- examples/swe/task_app/hosted/hosted_app.py +204 -0
- examples/swe/task_app/hosted/inference/__init__.py +5 -0
- examples/swe/task_app/hosted/inference/openai_client.py +618 -0
- examples/swe/task_app/hosted/main.py +100 -0
- examples/swe/task_app/hosted/policy_routes.py +1079 -0
- examples/swe/task_app/hosted/registry.py +195 -0
- examples/swe/task_app/hosted/rollout.py +1869 -0
- examples/swe/task_app/hosted/storage/__init__.py +5 -0
- examples/swe/task_app/hosted/storage/volume.py +211 -0
- examples/swe/task_app/hosted/test_agents.py +161 -0
- examples/swe/task_app/hosted/test_service.py +137 -0
- examples/swe/task_app/hosted/utils.py +62 -0
- examples/vlm/PROPOSAL.md +53 -0
- examples/vlm/README.md +68 -0
- examples/vlm/configs/crafter_vlm_gpt4o.toml +44 -0
- examples/vlm/crafter_image_only_agent.py +207 -0
- examples/vlm/crafter_openai_vlm_agent.py +277 -0
- examples/vlm/filter_image_rows.py +63 -0
- examples/vlm/run_crafter_vlm_benchmark.py +316 -0
- examples/warming_up_to_rl/analyze_trace_db.py +5 -5
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +11 -1
- examples/warming_up_to_rl/export_trace_sft.py +78 -21
- examples/warming_up_to_rl/groq_test.py +4 -4
- examples/warming_up_to_rl/manage_secrets.py +13 -18
- examples/warming_up_to_rl/run_eval.py +42 -44
- examples/warming_up_to_rl/run_fft_and_save.py +11 -16
- examples/warming_up_to_rl/run_local_rollout.py +1 -3
- examples/warming_up_to_rl/run_local_rollout_modal.py +2 -4
- examples/warming_up_to_rl/run_local_rollout_parallel.py +1 -4
- examples/warming_up_to_rl/run_local_rollout_traced.py +3 -5
- examples/warming_up_to_rl/run_rl_and_save.py +5 -6
- examples/warming_up_to_rl/run_rollout_remote.py +8 -10
- examples/warming_up_to_rl/task_app/README.md +6 -2
- examples/warming_up_to_rl/task_app/grpo_crafter.py +234 -35
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +2 -3
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +9 -11
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +131 -114
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +101 -41
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +73 -51
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +14 -6
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +16 -16
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +32 -34
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +94 -31
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +0 -2
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +303 -203
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +21 -23
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +328 -225
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +13 -13
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +4 -3
- synth_ai/api/models/supported.py +376 -0
- synth_ai/api/train/builders.py +128 -21
- synth_ai/api/train/cli.py +80 -64
- synth_ai/api/train/config_finder.py +7 -2
- synth_ai/api/train/env_resolver.py +1 -1
- synth_ai/api/train/pollers.py +2 -1
- synth_ai/api/train/supported_algos.py +139 -0
- synth_ai/api/train/task_app.py +1 -2
- synth_ai/api/train/utils.py +13 -44
- synth_ai/cli/__init__.py +8 -0
- synth_ai/cli/_modal_wrapper.py +28 -0
- synth_ai/cli/_typer_patch.py +49 -0
- synth_ai/cli/balance.py +1 -2
- synth_ai/cli/calc.py +1 -1
- synth_ai/cli/demo.py +2 -1
- synth_ai/cli/recent.py +2 -2
- synth_ai/cli/rl_demo.py +2 -1
- synth_ai/cli/root.py +11 -13
- synth_ai/cli/status.py +2 -2
- synth_ai/cli/task_apps.py +529 -179
- synth_ai/cli/traces.py +6 -4
- synth_ai/cli/watch.py +12 -18
- synth_ai/demo_registry.py +1 -1
- synth_ai/demos/core/cli.py +36 -43
- synth_ai/demos/demo_task_apps/__init__.py +3 -3
- synth_ai/demos/demo_task_apps/core.py +17 -25
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +3 -4
- synth_ai/demos/demo_task_apps/math/app.py +2 -1
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +3 -4
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +16 -18
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -1
- synth_ai/environments/examples/crafter_classic/environment.py +76 -1
- synth_ai/environments/reproducibility/tree.py +2 -5
- synth_ai/environments/service/app.py +11 -12
- synth_ai/environments/service/core_routes.py +4 -7
- synth_ai/environments/stateful/engine.py +1 -1
- synth_ai/environments/tasks/core.py +1 -0
- synth_ai/environments/tasks/filters.py +5 -6
- synth_ai/environments/tasks/utils.py +4 -5
- synth_ai/handshake.py +9 -9
- synth_ai/http.py +1 -1
- synth_ai/http_client.py +18 -10
- synth_ai/inference/client.py +15 -5
- synth_ai/jobs/client.py +78 -83
- synth_ai/learning/__init__.py +41 -6
- synth_ai/learning/algorithms.py +14 -0
- synth_ai/learning/client.py +91 -24
- synth_ai/learning/config.py +2 -38
- synth_ai/learning/ft_client.py +4 -59
- synth_ai/learning/health.py +5 -6
- synth_ai/learning/jobs.py +31 -47
- synth_ai/{rl → learning/rl}/__init__.py +14 -4
- synth_ai/learning/rl/client.py +267 -0
- synth_ai/learning/rl/config.py +31 -0
- synth_ai/{rl → learning/rl}/contracts.py +5 -8
- synth_ai/{rl → learning/rl}/env_keys.py +39 -15
- synth_ai/learning/rl/secrets.py +13 -0
- synth_ai/learning/rl_client.py +2 -281
- synth_ai/learning/sft/__init__.py +29 -0
- synth_ai/learning/sft/client.py +68 -0
- synth_ai/learning/sft/config.py +270 -0
- synth_ai/learning/sft/data.py +295 -0
- synth_ai/learning/sse.py +25 -24
- synth_ai/learning/validators.py +25 -28
- synth_ai/lm/__init__.py +21 -47
- synth_ai/task/__init__.py +25 -27
- synth_ai/task/apps/__init__.py +7 -8
- synth_ai/task/auth.py +8 -8
- synth_ai/task/client.py +14 -14
- synth_ai/task/contracts.py +36 -35
- synth_ai/task/datasets.py +6 -5
- synth_ai/task/errors.py +10 -10
- synth_ai/task/health.py +17 -9
- synth_ai/task/json.py +58 -23
- synth_ai/task/proxy.py +13 -9
- synth_ai/task/rubrics.py +16 -15
- synth_ai/task/server.py +12 -12
- synth_ai/task/tracing_utils.py +4 -4
- synth_ai/task/vendors.py +5 -6
- synth_ai/tracing_v3/__init__.py +2 -0
- synth_ai/tracing_v3/abstractions.py +21 -4
- synth_ai/tracing_v3/decorators.py +18 -16
- synth_ai/tracing_v3/hooks.py +5 -5
- synth_ai/tracing_v3/llm_call_record_helpers.py +6 -6
- synth_ai/tracing_v3/session_tracer.py +40 -14
- synth_ai/tracing_v3/storage/base.py +85 -0
- synth_ai/tracing_v3/storage/config.py +21 -8
- synth_ai/tracing_v3/storage/factory.py +10 -7
- synth_ai/tracing_v3/storage/utils.py +4 -2
- synth_ai/tracing_v3/turso/daemon.py +7 -2
- synth_ai/tracing_v3/turso/models.py +2 -2
- synth_ai/tracing_v3/turso/native_manager.py +1173 -0
- synth_ai/tracing_v3/utils.py +4 -4
- synth_ai/v0/api/__init__.py +8 -0
- synth_ai/v0/api/models/__init__.py +8 -0
- synth_ai/v0/api/models/supported.py +8 -0
- synth_ai/v0/config/__init__.py +15 -0
- synth_ai/v0/config/base_url.py +12 -0
- synth_ai/v0/lm/__init__.py +51 -0
- synth_ai/{lm → v0/lm}/caching/ephemeral.py +2 -2
- synth_ai/{lm → v0/lm}/caching/handler.py +4 -4
- synth_ai/{lm → v0/lm}/caching/initialize.py +1 -1
- synth_ai/{lm → v0/lm}/caching/persistent.py +1 -1
- synth_ai/{lm → v0/lm}/config.py +6 -1
- synth_ai/{lm → v0/lm}/core/all.py +9 -9
- synth_ai/{lm → v0/lm}/core/main.py +6 -6
- synth_ai/{lm → v0/lm}/core/main_v3.py +10 -10
- synth_ai/{lm → v0/lm}/core/synth_models.py +2 -14
- synth_ai/{lm → v0/lm}/core/vendor_clients.py +2 -2
- synth_ai/{lm → v0/lm}/overrides.py +2 -2
- synth_ai/{lm → v0/lm}/provider_support/anthropic.py +4 -4
- synth_ai/{lm → v0/lm}/provider_support/openai.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/handler.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/rehabilitate.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/core/anthropic_api.py +9 -9
- synth_ai/{lm → v0/lm}/vendors/core/gemini_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/mistral_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/openai_api.py +10 -10
- synth_ai/{lm → v0/lm}/vendors/openai_standard.py +8 -8
- synth_ai/{lm → v0/lm}/vendors/openai_standard_responses.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/custom_endpoint.py +3 -3
- synth_ai/{lm → v0/lm}/vendors/supported/deepseek.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/grok.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/groq.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/ollama.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/openrouter.py +3 -3
- synth_ai/{lm → v0/lm}/vendors/supported/together.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/synth_client.py +1 -1
- synth_ai/v0/tracing_v3/__init__.py +10 -0
- synth_ai/v0/tracing_v3/abstractions.py +3 -0
- synth_ai/v0/tracing_v3/decorators.py +3 -0
- synth_ai/v0/tracing_v3/llm_call_record_helpers.py +3 -0
- synth_ai/v0/tracing_v3/session_tracer.py +3 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.10.dist-info}/METADATA +10 -7
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.10.dist-info}/RECORD +269 -233
- examples/common_old/backend.py +0 -20
- examples/evals_old/README.md +0 -98
- examples/evals_old/__init__.py +0 -6
- examples/evals_old/compare_models.py +0 -1038
- examples/evals_old/example_log.md +0 -145
- examples/evals_old/run_demo.sh +0 -126
- examples/evals_old/trace_analysis.py +0 -270
- examples/finetuning_old/_backup_synth_qwen/config.toml +0 -29
- examples/finetuning_old/_backup_synth_qwen/example_log.md +0 -324
- examples/finetuning_old/_backup_synth_qwen/filter_traces.py +0 -60
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +0 -243
- examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +0 -109
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +0 -1924
- examples/finetuning_old/_backup_synth_qwen/readme.md +0 -49
- examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +0 -114
- examples/finetuning_old/_backup_synth_qwen/run_demo.sh +0 -195
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +0 -119
- examples/finetuning_old/synth_qwen_v1/README.md +0 -68
- examples/finetuning_old/synth_qwen_v1/filter_traces.py +0 -60
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +0 -243
- examples/finetuning_old/synth_qwen_v1/finetune.py +0 -46
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +0 -71
- examples/finetuning_old/synth_qwen_v1/infer.py +0 -36
- examples/finetuning_old/synth_qwen_v1/poll.py +0 -46
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +0 -35
- examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +0 -109
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +0 -1933
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +0 -210
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +0 -237
- examples/finetuning_old/synth_qwen_v1/upload_data.py +0 -34
- examples/finetuning_old/synth_qwen_v1/util.py +0 -152
- examples/rl_old/task_app.py +0 -1131
- synth_ai/experimental/synth_oss.py +0 -445
- synth_ai/learning/filtering.py +0 -0
- synth_ai/learning/offline/dpo.py +0 -0
- synth_ai/learning/offline/providers.py +0 -7
- synth_ai/learning/offline/sft.py +0 -0
- synth_ai/learning/offline/shared.py +0 -0
- synth_ai/learning/online/grpo.py +0 -0
- synth_ai/learning/online/irft.py +0 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +0 -168
- synth_ai/learning/prompts/gepa.py +0 -0
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +0 -211
- synth_ai/learning/prompts/mipro.py +0 -289
- synth_ai/learning/prompts/random_search.py +0 -249
- synth_ai/learning/prompts/run_mipro_banking77.py +0 -172
- synth_ai/learning/prompts/run_random_search_banking77.py +0 -329
- synth_ai/rl/secrets.py +0 -19
- synth_ai/scripts/verify_rewards.py +0 -100
- synth_ai/tracing/__init__.py +0 -30
- synth_ai/tracing_v1/__init__.py +0 -33
- synth_ai/tracing_v3/turso/__init__.py +0 -25
- synth_ai/tracing_v3/turso/manager.py +0 -838
- synth_ai/zyk/__init__.py +0 -30
- /synth_ai/{lm → v0/lm}/caching/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/dbs.py +0 -0
- /synth_ai/{lm → v0/lm}/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/core/exceptions.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/monitor.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/statefulness.py +0 -0
- /synth_ai/{lm → v0/lm}/injection.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/suppress_logging.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/inject.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/base.py +0 -0
- /synth_ai/{lm → v0/lm}/unified_interface.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/base.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/synth_dev_api.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/ollama.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/retries.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/supported/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/warmup.py +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.10.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.10.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.10.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.10.dist-info}/top_level.txt +0 -0
examples/rl/run_eval.py
CHANGED
|
@@ -5,24 +5,24 @@ from __future__ import annotations
|
|
|
5
5
|
|
|
6
6
|
import argparse
|
|
7
7
|
import asyncio
|
|
8
|
+
import contextlib
|
|
8
9
|
import json
|
|
9
10
|
import os
|
|
10
|
-
|
|
11
|
-
from typing import Any
|
|
11
|
+
import tomllib
|
|
12
|
+
from typing import Any
|
|
12
13
|
|
|
13
14
|
import httpx
|
|
14
|
-
import tomllib
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
class TaskAppClient:
|
|
18
18
|
"""Minimal async client for math single-step task app."""
|
|
19
19
|
|
|
20
|
-
def __init__(self, base_url: str, api_key:
|
|
20
|
+
def __init__(self, base_url: str, api_key: str | None = None) -> None:
|
|
21
21
|
self.base_url = base_url.rstrip("/")
|
|
22
22
|
self.api_key = api_key
|
|
23
|
-
self._client:
|
|
23
|
+
self._client: httpx.AsyncClient | None = None
|
|
24
24
|
|
|
25
|
-
async def __aenter__(self) ->
|
|
25
|
+
async def __aenter__(self) -> TaskAppClient:
|
|
26
26
|
headers = {"X-API-Key": self.api_key} if self.api_key else {}
|
|
27
27
|
self._client = httpx.AsyncClient(
|
|
28
28
|
base_url=self.base_url,
|
|
@@ -49,32 +49,30 @@ class TaskAppClient:
|
|
|
49
49
|
)
|
|
50
50
|
return self._client
|
|
51
51
|
|
|
52
|
-
async def initialize(self, split: str, seed: int | None) ->
|
|
53
|
-
payload:
|
|
52
|
+
async def initialize(self, split: str, seed: int | None) -> dict[str, Any]:
|
|
53
|
+
payload: dict[str, Any] = {"config": {"split": split}}
|
|
54
54
|
if seed is not None:
|
|
55
55
|
payload["seed"] = seed
|
|
56
56
|
resp = await self.client.post("/env/math/initialize", json=payload)
|
|
57
57
|
resp.raise_for_status()
|
|
58
58
|
return resp.json()
|
|
59
59
|
|
|
60
|
-
async def step(self, env_id: str, tool_calls:
|
|
60
|
+
async def step(self, env_id: str, tool_calls: list[dict[str, Any]]) -> dict[str, Any]:
|
|
61
61
|
payload = {"env_id": env_id, "action": {"tool_calls": tool_calls}}
|
|
62
62
|
resp = await self.client.post("/env/math/step", json=payload)
|
|
63
63
|
resp.raise_for_status()
|
|
64
64
|
return resp.json()
|
|
65
65
|
|
|
66
66
|
async def terminate(self, env_id: str) -> None:
|
|
67
|
-
|
|
67
|
+
with contextlib.suppress(Exception):
|
|
68
68
|
await self.client.post("/env/math/terminate", json={"env_id": env_id})
|
|
69
|
-
except Exception:
|
|
70
|
-
pass
|
|
71
69
|
|
|
72
|
-
async def get_info(self) ->
|
|
70
|
+
async def get_info(self) -> dict[str, Any]:
|
|
73
71
|
resp = await self.client.get("/info")
|
|
74
72
|
resp.raise_for_status()
|
|
75
73
|
return resp.json()
|
|
76
74
|
|
|
77
|
-
async def rollout(self, payload:
|
|
75
|
+
async def rollout(self, payload: dict[str, Any]) -> dict[str, Any]:
|
|
78
76
|
resp = await self.client.post("/rollout", json=payload)
|
|
79
77
|
resp.raise_for_status()
|
|
80
78
|
return resp.json()
|
|
@@ -82,10 +80,10 @@ class TaskAppClient:
|
|
|
82
80
|
async def post_inference(
|
|
83
81
|
self,
|
|
84
82
|
url: str,
|
|
85
|
-
payload:
|
|
83
|
+
payload: dict[str, Any],
|
|
86
84
|
*,
|
|
87
|
-
headers:
|
|
88
|
-
) ->
|
|
85
|
+
headers: dict[str, str] | None = None,
|
|
86
|
+
) -> dict[str, Any]:
|
|
89
87
|
async with httpx.AsyncClient(timeout=httpx.Timeout(60.0)) as c:
|
|
90
88
|
resp = await c.post(url, json=payload, headers=headers)
|
|
91
89
|
resp.raise_for_status()
|
|
@@ -96,7 +94,7 @@ TOOL_NAME = "math_submit"
|
|
|
96
94
|
DEFAULT_SPLIT = os.getenv("MATH_EVAL_DEFAULT_SPLIT", "validation")
|
|
97
95
|
|
|
98
96
|
|
|
99
|
-
def _math_tool_schema() ->
|
|
97
|
+
def _math_tool_schema() -> list[dict[str, Any]]:
|
|
100
98
|
return [
|
|
101
99
|
{
|
|
102
100
|
"type": "function",
|
|
@@ -123,7 +121,7 @@ def _math_tool_schema() -> List[Dict[str, Any]]:
|
|
|
123
121
|
]
|
|
124
122
|
|
|
125
123
|
|
|
126
|
-
def _build_messages(problem: str) ->
|
|
124
|
+
def _build_messages(problem: str) -> list[dict[str, Any]]:
|
|
127
125
|
return [
|
|
128
126
|
{
|
|
129
127
|
"role": "system",
|
|
@@ -139,18 +137,18 @@ def _build_messages(problem: str) -> List[Dict[str, Any]]:
|
|
|
139
137
|
]
|
|
140
138
|
|
|
141
139
|
|
|
142
|
-
def _parse_tool_calls(data:
|
|
140
|
+
def _parse_tool_calls(data: dict[str, Any]) -> list[dict[str, Any]]:
|
|
143
141
|
choices = data.get("choices") or []
|
|
144
142
|
if not choices:
|
|
145
143
|
return []
|
|
146
144
|
message = choices[0].get("message") or {}
|
|
147
145
|
raw_calls = message.get("tool_calls") or []
|
|
148
|
-
tool_calls:
|
|
146
|
+
tool_calls: list[dict[str, Any]] = []
|
|
149
147
|
for call in raw_calls:
|
|
150
148
|
function = call.get("function") or {}
|
|
151
149
|
name = function.get("name")
|
|
152
150
|
arguments = function.get("arguments")
|
|
153
|
-
parsed_args:
|
|
151
|
+
parsed_args: dict[str, Any]
|
|
154
152
|
if isinstance(arguments, str):
|
|
155
153
|
try:
|
|
156
154
|
parsed_args = json.loads(arguments)
|
|
@@ -164,7 +162,7 @@ def _parse_tool_calls(data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
|
164
162
|
return tool_calls
|
|
165
163
|
|
|
166
164
|
|
|
167
|
-
def _detect_provider(model: str, hint:
|
|
165
|
+
def _detect_provider(model: str, hint: str | None) -> str:
|
|
168
166
|
if hint:
|
|
169
167
|
return hint.lower()
|
|
170
168
|
lowered = (model or "").lower()
|
|
@@ -193,10 +191,10 @@ async def _choose_actions(
|
|
|
193
191
|
provider: str,
|
|
194
192
|
model: str,
|
|
195
193
|
problem: str,
|
|
196
|
-
policy_cfg:
|
|
197
|
-
) ->
|
|
194
|
+
policy_cfg: dict[str, Any],
|
|
195
|
+
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
|
|
198
196
|
messages = _build_messages(problem)
|
|
199
|
-
payload:
|
|
197
|
+
payload: dict[str, Any] = {
|
|
200
198
|
"model": model,
|
|
201
199
|
"messages": messages,
|
|
202
200
|
"tools": _math_tool_schema(),
|
|
@@ -239,7 +237,7 @@ async def _choose_actions(
|
|
|
239
237
|
return tool_calls, body
|
|
240
238
|
|
|
241
239
|
|
|
242
|
-
def _tool_to_answer(tool_calls:
|
|
240
|
+
def _tool_to_answer(tool_calls: list[dict[str, Any]]) -> str:
|
|
243
241
|
if not tool_calls:
|
|
244
242
|
return ""
|
|
245
243
|
args = tool_calls[0].get("args") or {}
|
|
@@ -251,11 +249,11 @@ async def eval_episode(
|
|
|
251
249
|
client: TaskAppClient,
|
|
252
250
|
*,
|
|
253
251
|
split: str,
|
|
254
|
-
seed:
|
|
252
|
+
seed: int | None,
|
|
255
253
|
model: str,
|
|
256
254
|
provider: str,
|
|
257
|
-
policy_cfg:
|
|
258
|
-
) ->
|
|
255
|
+
policy_cfg: dict[str, Any],
|
|
256
|
+
) -> dict[str, Any]:
|
|
259
257
|
created = await client.initialize(split, seed)
|
|
260
258
|
env_id = created["env_id"]
|
|
261
259
|
observation = created.get("observation") or {}
|
|
@@ -288,10 +286,10 @@ async def eval_via_rollout(
|
|
|
288
286
|
*,
|
|
289
287
|
run_id: str,
|
|
290
288
|
split: str,
|
|
291
|
-
seed:
|
|
289
|
+
seed: int | None,
|
|
292
290
|
model: str,
|
|
293
|
-
policy_cfg:
|
|
294
|
-
) ->
|
|
291
|
+
policy_cfg: dict[str, Any],
|
|
292
|
+
) -> dict[str, Any]:
|
|
295
293
|
payload = {
|
|
296
294
|
"run_id": run_id,
|
|
297
295
|
"env": {
|
|
@@ -314,6 +312,7 @@ async def eval_via_rollout(
|
|
|
314
312
|
steps = traj.get("steps") or []
|
|
315
313
|
step = steps[0] if steps else {}
|
|
316
314
|
info = step.get("info") or {}
|
|
315
|
+
observation = step.get("obs") or {}
|
|
317
316
|
return {
|
|
318
317
|
"seed": seed,
|
|
319
318
|
"split": split,
|
|
@@ -328,14 +327,14 @@ async def eval_via_rollout(
|
|
|
328
327
|
}
|
|
329
328
|
|
|
330
329
|
|
|
331
|
-
def _load_config(path:
|
|
330
|
+
def _load_config(path: str | None) -> dict[str, Any]:
|
|
332
331
|
if not path:
|
|
333
332
|
return {}
|
|
334
333
|
with open(path, "rb") as fh:
|
|
335
334
|
return tomllib.load(fh)
|
|
336
335
|
|
|
337
336
|
|
|
338
|
-
def _default_policy_cfg(cfg:
|
|
337
|
+
def _default_policy_cfg(cfg: dict[str, Any]) -> dict[str, Any]:
|
|
339
338
|
policy = dict(cfg.get("policy") or {})
|
|
340
339
|
if "inference_url" not in policy:
|
|
341
340
|
env_url = os.getenv("INFERENCE_URL")
|
|
@@ -380,8 +379,8 @@ async def main() -> None:
|
|
|
380
379
|
api_key = os.getenv("ENVIRONMENT_API_KEY")
|
|
381
380
|
|
|
382
381
|
successes = 0
|
|
383
|
-
failures:
|
|
384
|
-
results:
|
|
382
|
+
failures: dict[str, int] = {}
|
|
383
|
+
results: list[dict[str, Any]] = []
|
|
385
384
|
|
|
386
385
|
async with TaskAppClient(task_app_url, api_key=api_key) as client:
|
|
387
386
|
for episode in range(episodes):
|
examples/rl/run_rl_and_save.py
CHANGED
|
@@ -7,14 +7,14 @@ import argparse
|
|
|
7
7
|
import json
|
|
8
8
|
import os
|
|
9
9
|
import sys
|
|
10
|
+
import tomllib
|
|
10
11
|
from pathlib import Path
|
|
11
|
-
from typing import Any
|
|
12
|
+
from typing import Any
|
|
12
13
|
|
|
13
14
|
import requests
|
|
14
|
-
import tomllib
|
|
15
15
|
|
|
16
16
|
|
|
17
|
-
def _load_toml(path: Path) ->
|
|
17
|
+
def _load_toml(path: Path) -> dict[str, Any]:
|
|
18
18
|
if not path.exists():
|
|
19
19
|
print(f"config not found: {path}", file=sys.stderr)
|
|
20
20
|
sys.exit(2)
|
|
@@ -65,7 +65,7 @@ def main() -> None:
|
|
|
65
65
|
)
|
|
66
66
|
sys.exit(2)
|
|
67
67
|
|
|
68
|
-
payload:
|
|
68
|
+
payload: dict[str, Any] = {
|
|
69
69
|
"job_type": "rl",
|
|
70
70
|
"compute": cfg.get("compute", {}),
|
|
71
71
|
"data": {
|
|
@@ -77,7 +77,7 @@ def main() -> None:
|
|
|
77
77
|
|
|
78
78
|
backend = str(args.backend).rstrip("/")
|
|
79
79
|
url = f"{backend}/rl/jobs"
|
|
80
|
-
api_key = (os.getenv("SYNTH_API_KEY") or os.getenv("
|
|
80
|
+
api_key = (os.getenv("SYNTH_API_KEY") or os.getenv("SYNTH_KEY") or "").strip()
|
|
81
81
|
if not api_key:
|
|
82
82
|
print("Missing SYNTH_API_KEY in env", file=sys.stderr)
|
|
83
83
|
sys.exit(2)
|
|
@@ -1,21 +1,22 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
1
|
"""Task app configuration for a single-step math reasoning environment."""
|
|
4
2
|
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
5
|
import contextlib
|
|
6
6
|
import os
|
|
7
7
|
import random
|
|
8
8
|
import re
|
|
9
9
|
import uuid
|
|
10
|
+
from collections.abc import Iterable, Mapping, MutableMapping, Sequence
|
|
10
11
|
from dataclasses import dataclass
|
|
11
12
|
from pathlib import Path
|
|
12
|
-
from typing import Any,
|
|
13
|
+
from typing import Any, cast
|
|
13
14
|
|
|
14
15
|
import httpx
|
|
15
16
|
from datasets import load_dataset
|
|
16
17
|
from fastapi import APIRouter, HTTPException, Request
|
|
17
18
|
from pydantic import BaseModel, Field
|
|
18
|
-
|
|
19
|
+
from synth_ai.task.apps import ModalDeploymentConfig, TaskAppEntry, register_task_app
|
|
19
20
|
from synth_ai.task.contracts import (
|
|
20
21
|
RolloutMetrics,
|
|
21
22
|
RolloutRequest,
|
|
@@ -25,9 +26,9 @@ from synth_ai.task.contracts import (
|
|
|
25
26
|
TaskInfo,
|
|
26
27
|
)
|
|
27
28
|
from synth_ai.task.datasets import TaskDatasetRegistry, TaskDatasetSpec
|
|
29
|
+
from synth_ai.task.errors import http_exception
|
|
28
30
|
from synth_ai.task.rubrics import Rubric, load_rubric
|
|
29
31
|
from synth_ai.task.server import ProxyConfig, RubricBundle, TaskAppConfig
|
|
30
|
-
from synth_ai.task.errors import http_exception
|
|
31
32
|
from synth_ai.task.tracing_utils import (
|
|
32
33
|
build_tracer_factory,
|
|
33
34
|
resolve_sft_output_dir,
|
|
@@ -35,7 +36,6 @@ from synth_ai.task.tracing_utils import (
|
|
|
35
36
|
tracing_env_enabled,
|
|
36
37
|
)
|
|
37
38
|
from synth_ai.task.vendors import normalize_vendor_keys
|
|
38
|
-
from synth_ai.task.apps import ModalDeploymentConfig, TaskAppEntry, register_task_app
|
|
39
39
|
from synth_ai.tracing_v3.session_tracer import SessionTracer
|
|
40
40
|
|
|
41
41
|
REPO_ROOT = Path(__file__).resolve().parents[3]
|
|
@@ -43,7 +43,7 @@ REPO_ROOT = Path(__file__).resolve().parents[3]
|
|
|
43
43
|
_modal_volume_candidate = Path(
|
|
44
44
|
os.getenv("MATH_MODAL_DATASET_DIR", "/modal_volumes/math_dataset")
|
|
45
45
|
).expanduser()
|
|
46
|
-
_modal_volume_root:
|
|
46
|
+
_modal_volume_root: Path | None = None
|
|
47
47
|
try:
|
|
48
48
|
_modal_volume_candidate.mkdir(parents=True, exist_ok=True)
|
|
49
49
|
_modal_volume_root = _modal_volume_candidate
|
|
@@ -105,7 +105,7 @@ MATH_DATASET_SPEC = TaskDatasetSpec(
|
|
|
105
105
|
_BOXED_MARKERS: tuple[str, ...] = ("\\boxed", "boxed")
|
|
106
106
|
|
|
107
107
|
|
|
108
|
-
def _extract_boxed(text: str) ->
|
|
108
|
+
def _extract_boxed(text: str) -> str | None:
|
|
109
109
|
if not text:
|
|
110
110
|
return None
|
|
111
111
|
for marker in _BOXED_MARKERS:
|
|
@@ -174,9 +174,9 @@ class MathDataset:
|
|
|
174
174
|
self.name = name
|
|
175
175
|
self.config = config
|
|
176
176
|
self.splits = [split for split in splits if split]
|
|
177
|
-
self._cache:
|
|
177
|
+
self._cache: dict[str, Any] = {}
|
|
178
178
|
self._local_dir = os.getenv("MATH_DATASET_LOCAL_DIR")
|
|
179
|
-
self._hf_token:
|
|
179
|
+
self._hf_token: str | None = None
|
|
180
180
|
for key in HF_TOKEN_ENV_KEYS:
|
|
181
181
|
value = os.getenv(key)
|
|
182
182
|
if value:
|
|
@@ -186,7 +186,7 @@ class MathDataset:
|
|
|
186
186
|
break
|
|
187
187
|
# No multi-candidate fallback: enforce explicit dataset id
|
|
188
188
|
|
|
189
|
-
def _local_file_for_split(self, split: str) ->
|
|
189
|
+
def _local_file_for_split(self, split: str) -> Path | None:
|
|
190
190
|
specific = os.getenv(f"MATH_DATASET_LOCAL_{split.upper()}_FILE")
|
|
191
191
|
if specific:
|
|
192
192
|
path = Path(specific).expanduser()
|
|
@@ -213,7 +213,7 @@ class MathDataset:
|
|
|
213
213
|
self._cache[split] = dataset["train"]
|
|
214
214
|
else:
|
|
215
215
|
try:
|
|
216
|
-
load_kwargs:
|
|
216
|
+
load_kwargs: dict[str, Any] = {"split": split}
|
|
217
217
|
if self.config:
|
|
218
218
|
load_kwargs["name"] = self.config
|
|
219
219
|
if self._hf_token:
|
|
@@ -227,7 +227,7 @@ class MathDataset:
|
|
|
227
227
|
tmp_path = target.with_name(target.name + ".tmp")
|
|
228
228
|
try:
|
|
229
229
|
local_dir.mkdir(parents=True, exist_ok=True)
|
|
230
|
-
|
|
230
|
+
ds.to_json(str(tmp_path))
|
|
231
231
|
tmp_path.replace(target)
|
|
232
232
|
except Exception:
|
|
233
233
|
with contextlib.suppress(FileNotFoundError):
|
|
@@ -241,7 +241,7 @@ class MathDataset:
|
|
|
241
241
|
raise RuntimeError(" ".join(hints)) from exc
|
|
242
242
|
return self._cache[split]
|
|
243
243
|
|
|
244
|
-
def sample(self, *, split: str, index:
|
|
244
|
+
def sample(self, *, split: str, index: int | None = None) -> dict[str, Any]:
|
|
245
245
|
dataset = self._load_split(split)
|
|
246
246
|
if len(dataset) == 0:
|
|
247
247
|
raise RuntimeError(f"Dataset split '{split}' is empty")
|
|
@@ -326,9 +326,9 @@ class MathEnvironmentManager:
|
|
|
326
326
|
|
|
327
327
|
def __init__(self, dataset: MathDataset) -> None:
|
|
328
328
|
self.dataset = dataset
|
|
329
|
-
self._states:
|
|
329
|
+
self._states: dict[str, MathEnvState] = {}
|
|
330
330
|
|
|
331
|
-
def create(self, *, split: str, index:
|
|
331
|
+
def create(self, *, split: str, index: int | None, seed: int | None) -> MathEnvState:
|
|
332
332
|
if index is None and seed is not None:
|
|
333
333
|
index = seed
|
|
334
334
|
sample = self.dataset.sample(split=split, index=index)
|
|
@@ -354,11 +354,11 @@ class MathEnvironmentManager:
|
|
|
354
354
|
|
|
355
355
|
|
|
356
356
|
class InitializePayload(BaseModel):
|
|
357
|
-
seed:
|
|
358
|
-
config:
|
|
357
|
+
seed: int | None = None
|
|
358
|
+
config: dict[str, Any] = Field(default_factory=dict)
|
|
359
359
|
|
|
360
360
|
|
|
361
|
-
def _observation_from_state(state: MathEnvState) ->
|
|
361
|
+
def _observation_from_state(state: MathEnvState) -> dict[str, Any]:
|
|
362
362
|
return {
|
|
363
363
|
"problem": state.problem,
|
|
364
364
|
"split": state.split,
|
|
@@ -390,12 +390,12 @@ def _score_submission(
|
|
|
390
390
|
math_router = APIRouter()
|
|
391
391
|
|
|
392
392
|
|
|
393
|
-
def _preview_tool_calls(tool_calls: Sequence[Mapping[str, Any]]) -> list[
|
|
393
|
+
def _preview_tool_calls(tool_calls: Sequence[Mapping[str, Any]]) -> list[dict[str, Any]]:
|
|
394
394
|
"""Return a compact, log-friendly preview of tool calls.
|
|
395
395
|
|
|
396
396
|
Truncates long fields to avoid noisy logs and leaking excessive content.
|
|
397
397
|
"""
|
|
398
|
-
preview: list[
|
|
398
|
+
preview: list[dict[str, Any]] = []
|
|
399
399
|
for call in list(tool_calls or [])[:3]:
|
|
400
400
|
args = dict(call.get("args") or {})
|
|
401
401
|
answer = str(args.get("answer") or "")
|
|
@@ -412,7 +412,7 @@ def _preview_tool_calls(tool_calls: Sequence[Mapping[str, Any]]) -> list[Dict[st
|
|
|
412
412
|
|
|
413
413
|
def _event_and_outcome_components(
|
|
414
414
|
tool_calls: Sequence[Mapping[str, Any]], *, correct: bool, reward: float
|
|
415
|
-
) ->
|
|
415
|
+
) -> dict[str, float]:
|
|
416
416
|
"""Approximate component-wise scores for RL-style logs.
|
|
417
417
|
|
|
418
418
|
- env: task-level scalar reward (our single-step outcome)
|
|
@@ -434,7 +434,7 @@ def _event_and_outcome_components(
|
|
|
434
434
|
|
|
435
435
|
|
|
436
436
|
@math_router.post("/env/math/initialize")
|
|
437
|
-
async def initialize_env(request: Request, payload: InitializePayload) ->
|
|
437
|
+
async def initialize_env(request: Request, payload: InitializePayload) -> dict[str, Any]:
|
|
438
438
|
manager: MathEnvironmentManager = request.app.state.math_env_manager
|
|
439
439
|
split = str(payload.config.get("split") or DEFAULT_SPLIT)
|
|
440
440
|
seed = payload.seed
|
|
@@ -450,7 +450,7 @@ async def initialize_env(request: Request, payload: InitializePayload) -> Dict[s
|
|
|
450
450
|
|
|
451
451
|
|
|
452
452
|
@math_router.post("/env/math/step")
|
|
453
|
-
async def step_env(request: Request, payload:
|
|
453
|
+
async def step_env(request: Request, payload: dict[str, Any]) -> dict[str, Any]:
|
|
454
454
|
manager: MathEnvironmentManager = request.app.state.math_env_manager
|
|
455
455
|
env_id = str(payload.get("env_id") or "")
|
|
456
456
|
if not env_id:
|
|
@@ -463,7 +463,7 @@ async def step_env(request: Request, payload: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
463
463
|
action = payload.get("action") or {}
|
|
464
464
|
tool_calls = action.get("tool_calls") or payload.get("tool_calls") or []
|
|
465
465
|
reward, status, correct = _score_submission(state, tool_calls)
|
|
466
|
-
|
|
466
|
+
with contextlib.suppress(Exception):
|
|
467
467
|
print(
|
|
468
468
|
"[MATH_STEP] env_id=",
|
|
469
469
|
state.env_id,
|
|
@@ -483,8 +483,6 @@ async def step_env(request: Request, payload: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
483
483
|
_event_and_outcome_components(tool_calls, correct=correct, reward=reward),
|
|
484
484
|
flush=True,
|
|
485
485
|
)
|
|
486
|
-
except Exception:
|
|
487
|
-
pass
|
|
488
486
|
state.done = True
|
|
489
487
|
|
|
490
488
|
observation = _observation_from_state(state)
|
|
@@ -502,7 +500,7 @@ async def step_env(request: Request, payload: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
502
500
|
|
|
503
501
|
|
|
504
502
|
@math_router.post("/env/math/terminate")
|
|
505
|
-
async def terminate_env(request: Request, payload:
|
|
503
|
+
async def terminate_env(request: Request, payload: dict[str, Any]) -> dict[str, Any]:
|
|
506
504
|
manager: MathEnvironmentManager = request.app.state.math_env_manager
|
|
507
505
|
env_id = str(payload.get("env_id") or "")
|
|
508
506
|
if env_id:
|
|
@@ -525,7 +523,7 @@ def _resolve_inference_url(base_url: str) -> str:
|
|
|
525
523
|
|
|
526
524
|
async def _call_inference(
|
|
527
525
|
policy_config: Mapping[str, Any], observation: Mapping[str, Any]
|
|
528
|
-
) -> tuple[list[
|
|
526
|
+
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
|
|
529
527
|
inference_url = str(policy_config.get("inference_url") or "").rstrip("/")
|
|
530
528
|
if not inference_url:
|
|
531
529
|
raise RuntimeError("policy.config.inference_url required for rollout")
|
|
@@ -557,7 +555,7 @@ async def _call_inference(
|
|
|
557
555
|
},
|
|
558
556
|
]
|
|
559
557
|
|
|
560
|
-
payload:
|
|
558
|
+
payload: dict[str, Any] = {
|
|
561
559
|
"model": model,
|
|
562
560
|
"messages": messages,
|
|
563
561
|
"tools": [
|
|
@@ -626,7 +624,7 @@ async def _call_inference(
|
|
|
626
624
|
function = call.get("function") or {}
|
|
627
625
|
name = function.get("name")
|
|
628
626
|
arguments = function.get("arguments")
|
|
629
|
-
parsed_args:
|
|
627
|
+
parsed_args: dict[str, Any]
|
|
630
628
|
if isinstance(arguments, str):
|
|
631
629
|
try:
|
|
632
630
|
import json
|
|
@@ -640,7 +638,7 @@ async def _call_inference(
|
|
|
640
638
|
parsed_args = {}
|
|
641
639
|
tool_calls.append({"tool": name, "args": parsed_args})
|
|
642
640
|
# Lightweight provider-side logging
|
|
643
|
-
|
|
641
|
+
with contextlib.suppress(Exception):
|
|
644
642
|
print(
|
|
645
643
|
"[MATH_INFER] model=",
|
|
646
644
|
model,
|
|
@@ -648,8 +646,6 @@ async def _call_inference(
|
|
|
648
646
|
_preview_tool_calls(tool_calls),
|
|
649
647
|
flush=True,
|
|
650
648
|
)
|
|
651
|
-
except Exception:
|
|
652
|
-
pass
|
|
653
649
|
return tool_calls, data
|
|
654
650
|
|
|
655
651
|
|
|
@@ -664,9 +660,9 @@ async def rollout_executor(request: RolloutRequest, fastapi_request: Request) ->
|
|
|
664
660
|
"index": sample["index"],
|
|
665
661
|
}
|
|
666
662
|
|
|
667
|
-
tool_calls: list[
|
|
668
|
-
inference_payload:
|
|
669
|
-
error_info:
|
|
663
|
+
tool_calls: list[dict[str, Any]] = []
|
|
664
|
+
inference_payload: dict[str, Any] | None = None
|
|
665
|
+
error_info: dict[str, Any] = {}
|
|
670
666
|
try:
|
|
671
667
|
tool_calls, inference_payload = await _call_inference(
|
|
672
668
|
request.policy.config or {}, observation
|
|
@@ -691,7 +687,7 @@ async def rollout_executor(request: RolloutRequest, fastapi_request: Request) ->
|
|
|
691
687
|
)
|
|
692
688
|
|
|
693
689
|
# Log a concise summary so we can debug reward=0 issues in production
|
|
694
|
-
|
|
690
|
+
with contextlib.suppress(Exception):
|
|
695
691
|
print(
|
|
696
692
|
"[MATH_ROLLOUT] run=",
|
|
697
693
|
request.run_id,
|
|
@@ -711,8 +707,6 @@ async def rollout_executor(request: RolloutRequest, fastapi_request: Request) ->
|
|
|
711
707
|
_event_and_outcome_components(tool_calls, correct=correct, reward=reward),
|
|
712
708
|
flush=True,
|
|
713
709
|
)
|
|
714
|
-
except Exception:
|
|
715
|
-
pass
|
|
716
710
|
|
|
717
711
|
step = RolloutStep(
|
|
718
712
|
obs=observation,
|
|
@@ -749,6 +743,34 @@ async def rollout_executor(request: RolloutRequest, fastapi_request: Request) ->
|
|
|
749
743
|
details={"status": status, "correct": correct},
|
|
750
744
|
)
|
|
751
745
|
|
|
746
|
+
# Include a minimal trace when requested or tracing is enabled via env
|
|
747
|
+
include_trace = bool(
|
|
748
|
+
(request.record and getattr(request.record, "return_trace", False))
|
|
749
|
+
or os.getenv("TASKAPP_TRACING_ENABLED")
|
|
750
|
+
)
|
|
751
|
+
trace_payload = None
|
|
752
|
+
if include_trace:
|
|
753
|
+
try:
|
|
754
|
+
# Minimal structured trace for assertions
|
|
755
|
+
trace_payload = {
|
|
756
|
+
"session_id": str(uuid.uuid4()),
|
|
757
|
+
"events_count": 1,
|
|
758
|
+
"decision_rewards": [reward],
|
|
759
|
+
"lm_calls": (
|
|
760
|
+
[{"prompt": str(observation.get("problem", "")), "response": str(tool_calls)}]
|
|
761
|
+
if tool_calls
|
|
762
|
+
else []
|
|
763
|
+
),
|
|
764
|
+
"metadata": {
|
|
765
|
+
"env": "math_single_step",
|
|
766
|
+
"split": sample["split"],
|
|
767
|
+
"index": sample["index"],
|
|
768
|
+
"status": status,
|
|
769
|
+
},
|
|
770
|
+
}
|
|
771
|
+
except Exception:
|
|
772
|
+
trace_payload = None
|
|
773
|
+
|
|
752
774
|
return RolloutResponse(
|
|
753
775
|
run_id=request.run_id,
|
|
754
776
|
trajectories=[trajectory],
|
|
@@ -756,7 +778,7 @@ async def rollout_executor(request: RolloutRequest, fastapi_request: Request) ->
|
|
|
756
778
|
metrics=metrics,
|
|
757
779
|
aborted=False,
|
|
758
780
|
ops_executed=2,
|
|
759
|
-
trace=
|
|
781
|
+
trace=trace_payload,
|
|
760
782
|
)
|
|
761
783
|
|
|
762
784
|
|
|
@@ -854,7 +876,7 @@ EVENTS_RUBRIC: Rubric = cast(
|
|
|
854
876
|
)
|
|
855
877
|
|
|
856
878
|
|
|
857
|
-
def describe_taskset(dataset: MathDataset) ->
|
|
879
|
+
def describe_taskset(dataset: MathDataset) -> dict[str, Any]:
|
|
858
880
|
return {
|
|
859
881
|
**MATH_DATASET_SPEC.model_dump(),
|
|
860
882
|
"hf_dataset": DATASET_NAME,
|
|
@@ -895,7 +917,7 @@ def build_config() -> TaskAppConfig:
|
|
|
895
917
|
)
|
|
896
918
|
sft_output_dir = resolve_sft_output_dir()
|
|
897
919
|
|
|
898
|
-
app_state:
|
|
920
|
+
app_state: dict[str, Any] = {
|
|
899
921
|
"math_dataset": dataset,
|
|
900
922
|
"math_env_manager": MathEnvironmentManager(dataset),
|
|
901
923
|
"tracing_enabled": tracing_enabled,
|
|
@@ -8,10 +8,10 @@ from pathlib import Path
|
|
|
8
8
|
from fastapi.exceptions import RequestValidationError
|
|
9
9
|
from fastapi.responses import JSONResponse
|
|
10
10
|
from starlette.requests import Request
|
|
11
|
-
|
|
11
|
+
from synth_ai.task.auth import is_api_key_header_authorized, normalize_environment_api_key
|
|
12
12
|
from synth_ai.task.server import create_task_app, run_task_app
|
|
13
|
+
|
|
13
14
|
from .math_single_step import build_config
|
|
14
|
-
from synth_ai.task.auth import is_api_key_header_authorized, normalize_environment_api_key
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
def fastapi_app():
|
|
@@ -73,7 +73,7 @@ def fastapi_app():
|
|
|
73
73
|
try:
|
|
74
74
|
hdr = request.headers
|
|
75
75
|
snapshot = {
|
|
76
|
-
"path": str(
|
|
76
|
+
"path": str(request.url.path),
|
|
77
77
|
"have_x_api_key": bool(hdr.get("x-api-key")),
|
|
78
78
|
"have_x_api_keys": bool(hdr.get("x-api-keys")),
|
|
79
79
|
"have_authorization": bool(hdr.get("authorization")),
|