synth-ai 0.2.9.dev5__py3-none-any.whl → 0.2.9.dev6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic.
- examples/__init__.py +16 -0
- examples/crafter_debug_render.py +23 -17
- examples/qwen_coder/README.md +102 -0
- examples/qwen_coder/_shared.py +113 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +61 -0
- examples/qwen_coder/configs/coder_lora_4b.toml +57 -0
- examples/qwen_coder/configs/coder_lora_small.toml +58 -0
- examples/qwen_coder/generate_dataset.py +98 -0
- examples/qwen_coder/infer_ft_smoke.py +64 -0
- examples/qwen_coder/infer_prod_proxy.py +73 -0
- examples/qwen_coder/infer_via_synth.py +87 -0
- examples/qwen_coder/scripts/infer_coder.sh +18 -0
- examples/qwen_coder/scripts/train_coder_30b.sh +21 -0
- examples/qwen_coder/sft_full_17b.py +103 -0
- examples/qwen_coder/sft_lora_30b.py +110 -0
- examples/qwen_coder/subset_jsonl.py +38 -0
- examples/qwen_coder/validate_jsonl.py +59 -0
- examples/rl/configs/eval_base_qwen.toml +1 -1
- examples/rl/configs/rl_from_base_qwen17.toml +1 -1
- examples/rl/download_dataset.py +26 -10
- examples/rl/run_eval.py +53 -52
- examples/rl/run_rl_and_save.py +29 -12
- examples/rl/task_app/math_single_step.py +180 -41
- examples/rl/task_app/math_task_app.py +14 -6
- examples/sft/README.md +139 -0
- examples/sft/configs/crafter_fft_qwen0p6b.toml +44 -0
- examples/sft/configs/crafter_lora_qwen0p6b.toml +45 -0
- examples/sft/evaluate.py +117 -0
- examples/sft/export_dataset.py +117 -0
- examples/sft/generate_traces.py +162 -0
- examples/swe/__init__.py +12 -0
- examples/swe/task_app/README.md +105 -0
- examples/swe/task_app/__init__.py +2 -0
- examples/swe/task_app/grpo_swe_mini.py +571 -0
- examples/swe/task_app/grpo_swe_mini_task_app.py +136 -0
- examples/swe/task_app/hosted/README.md +173 -0
- examples/swe/task_app/hosted/__init__.py +5 -0
- examples/swe/task_app/hosted/branching.py +143 -0
- examples/swe/task_app/hosted/environment_routes.py +1289 -0
- examples/swe/task_app/hosted/envs/__init__.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/__init__.py +6 -0
- examples/swe/task_app/hosted/envs/crafter/app.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/environment.py +522 -0
- examples/swe/task_app/hosted/envs/crafter/policy.py +478 -0
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +108 -0
- examples/swe/task_app/hosted/envs/crafter/shared.py +305 -0
- examples/swe/task_app/hosted/envs/crafter/tools.py +47 -0
- examples/swe/task_app/hosted/envs/mini_swe/__init__.py +8 -0
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +1164 -0
- examples/swe/task_app/hosted/envs/mini_swe/policy.py +355 -0
- examples/swe/task_app/hosted/envs/mini_swe/shared.py +83 -0
- examples/swe/task_app/hosted/envs/mini_swe/tools.py +96 -0
- examples/swe/task_app/hosted/hosted_app.py +204 -0
- examples/swe/task_app/hosted/inference/__init__.py +5 -0
- examples/swe/task_app/hosted/inference/openai_client.py +618 -0
- examples/swe/task_app/hosted/main.py +100 -0
- examples/swe/task_app/hosted/policy_routes.py +1079 -0
- examples/swe/task_app/hosted/registry.py +195 -0
- examples/swe/task_app/hosted/rollout.py +1869 -0
- examples/swe/task_app/hosted/storage/__init__.py +5 -0
- examples/swe/task_app/hosted/storage/volume.py +211 -0
- examples/swe/task_app/hosted/test_agents.py +161 -0
- examples/swe/task_app/hosted/test_service.py +137 -0
- examples/swe/task_app/hosted/utils.py +62 -0
- examples/vlm/README.md +68 -0
- examples/vlm/configs/crafter_vlm_gpt4o.toml +44 -0
- examples/vlm/crafter_image_only_agent.py +207 -0
- examples/vlm/crafter_openai_vlm_agent.py +277 -0
- examples/vlm/filter_image_rows.py +63 -0
- examples/vlm/run_crafter_vlm_benchmark.py +316 -0
- examples/warming_up_to_rl/analyze_trace_db.py +12 -10
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +11 -1
- examples/warming_up_to_rl/export_trace_sft.py +218 -36
- examples/warming_up_to_rl/groq_test.py +15 -8
- examples/warming_up_to_rl/manage_secrets.py +29 -25
- examples/warming_up_to_rl/readme.md +9 -2
- examples/warming_up_to_rl/run_eval.py +137 -61
- examples/warming_up_to_rl/run_fft_and_save.py +131 -60
- examples/warming_up_to_rl/run_local_rollout.py +88 -39
- examples/warming_up_to_rl/run_local_rollout_modal.py +114 -28
- examples/warming_up_to_rl/run_local_rollout_parallel.py +81 -20
- examples/warming_up_to_rl/run_local_rollout_traced.py +126 -23
- examples/warming_up_to_rl/run_rl_and_save.py +35 -12
- examples/warming_up_to_rl/run_rollout_remote.py +44 -19
- examples/warming_up_to_rl/task_app/README.md +6 -2
- examples/warming_up_to_rl/task_app/grpo_crafter.py +319 -57
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +11 -30
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +9 -11
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +137 -182
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +150 -57
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +105 -69
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +19 -7
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +45 -42
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +47 -45
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +198 -92
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +0 -2
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +361 -263
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +21 -23
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +394 -274
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +56 -62
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +6 -15
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +4 -3
- synth/__init__.py +14 -0
- synth_ai/__init__.py +20 -4
- synth_ai/api/models/supported.py +376 -0
- synth_ai/api/train/builders.py +157 -26
- synth_ai/api/train/cli.py +213 -57
- synth_ai/api/train/config_finder.py +65 -5
- synth_ai/api/train/env_resolver.py +33 -15
- synth_ai/api/train/pollers.py +13 -4
- synth_ai/api/train/supported_algos.py +139 -0
- synth_ai/api/train/task_app.py +5 -3
- synth_ai/api/train/utils.py +33 -48
- synth_ai/cli/__init__.py +19 -4
- synth_ai/cli/_modal_wrapper.py +28 -0
- synth_ai/cli/_typer_patch.py +49 -0
- synth_ai/cli/balance.py +2 -3
- synth_ai/cli/calc.py +1 -1
- synth_ai/cli/demo.py +21 -6
- synth_ai/cli/recent.py +2 -2
- synth_ai/cli/rl_demo.py +77 -17
- synth_ai/cli/root.py +116 -39
- synth_ai/cli/status.py +2 -2
- synth_ai/cli/task_apps.py +1699 -259
- synth_ai/cli/traces.py +7 -4
- synth_ai/cli/turso.py +73 -0
- synth_ai/cli/watch.py +12 -18
- synth_ai/core/experiment.py +0 -2
- synth_ai/demo_registry.py +68 -31
- synth_ai/demos/core/cli.py +516 -194
- synth_ai/demos/demo_task_apps/__init__.py +3 -3
- synth_ai/demos/demo_task_apps/core.py +64 -28
- synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +2 -3
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +37 -30
- synth_ai/demos/demo_task_apps/math/_common.py +1 -2
- synth_ai/demos/demo_task_apps/math/app.py +2 -1
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +3 -6
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +183 -82
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -2
- synth_ai/environments/examples/bandit/engine.py +12 -4
- synth_ai/environments/examples/bandit/taskset.py +4 -4
- synth_ai/environments/examples/crafter_classic/environment.py +76 -1
- synth_ai/environments/reproducibility/tree.py +5 -6
- synth_ai/environments/service/app.py +11 -12
- synth_ai/environments/service/core_routes.py +10 -9
- synth_ai/environments/stateful/engine.py +1 -1
- synth_ai/environments/tasks/core.py +1 -0
- synth_ai/environments/tasks/filters.py +5 -6
- synth_ai/environments/tasks/utils.py +4 -5
- synth_ai/evals/base.py +0 -2
- synth_ai/handshake.py +11 -9
- synth_ai/http.py +1 -1
- synth_ai/http_client.py +43 -11
- synth_ai/inference/__init__.py +0 -2
- synth_ai/inference/client.py +20 -6
- synth_ai/jobs/client.py +103 -78
- synth_ai/learning/__init__.py +41 -6
- synth_ai/learning/algorithms.py +14 -0
- synth_ai/learning/client.py +121 -29
- synth_ai/learning/config.py +2 -40
- synth_ai/learning/constants.py +0 -2
- synth_ai/learning/ft_client.py +4 -56
- synth_ai/learning/health.py +13 -7
- synth_ai/learning/jobs.py +43 -47
- synth_ai/{rl → learning/rl}/__init__.py +14 -5
- synth_ai/learning/rl/client.py +267 -0
- synth_ai/learning/rl/config.py +31 -0
- synth_ai/{rl → learning/rl}/contracts.py +5 -10
- synth_ai/{rl → learning/rl}/env_keys.py +45 -16
- synth_ai/learning/rl/secrets.py +13 -0
- synth_ai/learning/rl_client.py +2 -253
- synth_ai/learning/sft/__init__.py +29 -0
- synth_ai/learning/sft/client.py +68 -0
- synth_ai/learning/sft/config.py +270 -0
- synth_ai/learning/sft/data.py +295 -0
- synth_ai/learning/sse.py +25 -26
- synth_ai/learning/validators.py +25 -24
- synth_ai/lm/__init__.py +21 -47
- synth_ai/task/__init__.py +26 -27
- synth_ai/task/apps/__init__.py +18 -19
- synth_ai/task/auth.py +35 -23
- synth_ai/task/client.py +15 -13
- synth_ai/task/contracts.py +37 -35
- synth_ai/task/datasets.py +9 -6
- synth_ai/task/errors.py +11 -10
- synth_ai/task/health.py +17 -11
- synth_ai/task/json.py +58 -24
- synth_ai/task/proxy.py +15 -14
- synth_ai/task/rubrics.py +22 -15
- synth_ai/task/server.py +43 -17
- synth_ai/task/tracing_utils.py +12 -7
- synth_ai/task/validators.py +0 -1
- synth_ai/task/vendors.py +5 -7
- synth_ai/tracing_v3/__init__.py +2 -0
- synth_ai/tracing_v3/abstractions.py +21 -4
- synth_ai/tracing_v3/db_config.py +26 -1
- synth_ai/tracing_v3/decorators.py +18 -15
- synth_ai/tracing_v3/examples/basic_usage.py +3 -2
- synth_ai/tracing_v3/hooks.py +6 -4
- synth_ai/tracing_v3/llm_call_record_helpers.py +6 -6
- synth_ai/tracing_v3/replica_sync.py +1 -0
- synth_ai/tracing_v3/session_tracer.py +63 -16
- synth_ai/tracing_v3/storage/base.py +89 -1
- synth_ai/tracing_v3/storage/config.py +21 -8
- synth_ai/tracing_v3/storage/factory.py +10 -8
- synth_ai/tracing_v3/storage/utils.py +4 -2
- synth_ai/tracing_v3/turso/daemon.py +7 -2
- synth_ai/tracing_v3/turso/models.py +5 -2
- synth_ai/tracing_v3/turso/native_manager.py +1173 -0
- synth_ai/tracing_v3/utils.py +4 -3
- synth_ai/v0/api/__init__.py +8 -0
- synth_ai/v0/api/models/__init__.py +8 -0
- synth_ai/v0/api/models/supported.py +8 -0
- synth_ai/v0/config/__init__.py +15 -0
- synth_ai/v0/config/base_url.py +12 -0
- synth_ai/v0/lm/__init__.py +51 -0
- synth_ai/{lm → v0/lm}/caching/ephemeral.py +3 -5
- synth_ai/{lm → v0/lm}/caching/handler.py +4 -4
- synth_ai/{lm → v0/lm}/caching/initialize.py +1 -1
- synth_ai/{lm → v0/lm}/caching/persistent.py +1 -1
- synth_ai/{lm → v0/lm}/config.py +6 -1
- synth_ai/{lm → v0/lm}/core/all.py +9 -9
- synth_ai/{lm → v0/lm}/core/exceptions.py +0 -2
- synth_ai/{lm → v0/lm}/core/main.py +19 -7
- synth_ai/{lm → v0/lm}/core/main_v3.py +10 -10
- synth_ai/{lm → v0/lm}/core/synth_models.py +2 -15
- synth_ai/{lm → v0/lm}/core/vendor_clients.py +6 -4
- synth_ai/{lm → v0/lm}/overrides.py +4 -4
- synth_ai/{lm → v0/lm}/provider_support/anthropic.py +4 -4
- synth_ai/{lm → v0/lm}/provider_support/openai.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/handler.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/rehabilitate.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/core/anthropic_api.py +16 -16
- synth_ai/{lm → v0/lm}/vendors/core/gemini_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/mistral_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/openai_api.py +12 -10
- synth_ai/{lm → v0/lm}/vendors/openai_standard.py +11 -9
- synth_ai/{lm → v0/lm}/vendors/openai_standard_responses.py +8 -5
- synth_ai/{lm → v0/lm}/vendors/supported/custom_endpoint.py +4 -6
- synth_ai/{lm → v0/lm}/vendors/supported/deepseek.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/grok.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/groq.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/ollama.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/openrouter.py +3 -3
- synth_ai/{lm → v0/lm}/vendors/supported/together.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/synth_client.py +38 -11
- synth_ai/v0/tracing/upload.py +32 -135
- synth_ai/v0/tracing_v3/__init__.py +10 -0
- synth_ai/v0/tracing_v3/abstractions.py +3 -0
- synth_ai/v0/tracing_v3/decorators.py +3 -0
- synth_ai/v0/tracing_v3/llm_call_record_helpers.py +3 -0
- synth_ai/v0/tracing_v3/session_tracer.py +3 -0
- synth_ai-0.2.9.dev6.dist-info/METADATA +191 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev6.dist-info}/RECORD +291 -262
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev6.dist-info}/top_level.txt +1 -0
- examples/common_old/backend.py +0 -21
- examples/evals_old/README.md +0 -98
- examples/evals_old/__init__.py +0 -6
- examples/evals_old/compare_models.py +0 -1037
- examples/evals_old/example_log.md +0 -145
- examples/evals_old/run_demo.sh +0 -126
- examples/evals_old/trace_analysis.py +0 -270
- examples/finetuning_old/_backup_synth_qwen/config.toml +0 -29
- examples/finetuning_old/_backup_synth_qwen/example_log.md +0 -324
- examples/finetuning_old/_backup_synth_qwen/filter_traces.py +0 -60
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +0 -239
- examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +0 -109
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +0 -1924
- examples/finetuning_old/_backup_synth_qwen/readme.md +0 -49
- examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +0 -114
- examples/finetuning_old/_backup_synth_qwen/run_demo.sh +0 -195
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +0 -118
- examples/finetuning_old/synth_qwen_v1/README.md +0 -68
- examples/finetuning_old/synth_qwen_v1/filter_traces.py +0 -60
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +0 -239
- examples/finetuning_old/synth_qwen_v1/finetune.py +0 -46
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +0 -71
- examples/finetuning_old/synth_qwen_v1/infer.py +0 -37
- examples/finetuning_old/synth_qwen_v1/poll.py +0 -44
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +0 -35
- examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +0 -109
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +0 -1932
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +0 -207
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +0 -232
- examples/finetuning_old/synth_qwen_v1/upload_data.py +0 -34
- examples/finetuning_old/synth_qwen_v1/util.py +0 -147
- examples/rl_old/task_app.py +0 -962
- examples/warming_up_to_rl/old/event_rewards.md +0 -234
- examples/warming_up_to_rl/old/notes.md +0 -73
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +0 -738
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +0 -580
- synth_ai/experimental/synth_oss.py +0 -446
- synth_ai/install_sqld.sh +0 -40
- synth_ai/learning/filtering.py +0 -0
- synth_ai/learning/offline/dpo.py +0 -0
- synth_ai/learning/offline/providers.py +0 -7
- synth_ai/learning/offline/sft.py +0 -0
- synth_ai/learning/offline/shared.py +0 -0
- synth_ai/learning/online/grpo.py +0 -0
- synth_ai/learning/online/irft.py +0 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +0 -168
- synth_ai/learning/prompts/gepa.py +0 -0
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +0 -213
- synth_ai/learning/prompts/mipro.py +0 -289
- synth_ai/learning/prompts/random_search.py +0 -246
- synth_ai/learning/prompts/run_mipro_banking77.py +0 -172
- synth_ai/learning/prompts/run_random_search_banking77.py +0 -324
- synth_ai/rl/secrets.py +0 -19
- synth_ai/scripts/verify_rewards.py +0 -100
- synth_ai/tracing/__init__.py +0 -30
- synth_ai/tracing_v1/__init__.py +0 -33
- synth_ai/tracing_v3/turso/__init__.py +0 -25
- synth_ai/tracing_v3/turso/manager.py +0 -774
- synth_ai/zyk/__init__.py +0 -30
- synth_ai-0.2.9.dev5.dist-info/METADATA +0 -131
- /synth_ai/{lm → v0/lm}/caching/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/dbs.py +0 -0
- /synth_ai/{lm → v0/lm}/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/monitor.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/statefulness.py +0 -0
- /synth_ai/{lm → v0/lm}/injection.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/suppress_logging.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/inject.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/base.py +0 -0
- /synth_ai/{lm → v0/lm}/unified_interface.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/base.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/synth_dev_api.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/ollama.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/retries.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/supported/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/warmup.py +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev6.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev6.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev6.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,102 +1,95 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
import
|
|
3
|
+
import contextlib
|
|
4
4
|
import json
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
import time as _time
|
|
5
8
|
from datetime import datetime
|
|
6
|
-
from
|
|
7
|
-
from typing import Any, Dict, List, Optional, Tuple
|
|
9
|
+
from typing import Any
|
|
8
10
|
|
|
9
11
|
from fastapi import APIRouter, HTTPException, Request, status
|
|
10
|
-
import os
|
|
11
|
-
import time as _time
|
|
12
12
|
from pydantic import BaseModel
|
|
13
13
|
from synth_ai.lm.vendors.base import BaseLMResponse
|
|
14
|
-
from synth_ai.
|
|
14
|
+
from synth_ai.task.tracing_utils import unique_sft_path
|
|
15
15
|
from synth_ai.tracing_v3.abstractions import EnvironmentEvent, LMCAISEvent, TimeRecord
|
|
16
16
|
from synth_ai.tracing_v3.llm_call_record_helpers import create_llm_call_record_from_response
|
|
17
|
-
from synth_ai.
|
|
17
|
+
from synth_ai.tracing_v3.session_tracer import SessionTracer
|
|
18
18
|
|
|
19
19
|
from .registry import registry
|
|
20
20
|
|
|
21
21
|
logger = logging.getLogger(__name__)
|
|
22
22
|
|
|
23
|
+
|
|
23
24
|
# --- Seeding utilities (robust, optional deps) ---
|
|
24
|
-
def _set_global_seed(seed_value: int) ->
|
|
25
|
+
def _set_global_seed(seed_value: int) -> dict[str, Any]:
|
|
25
26
|
"""Set global RNG seeds across common libraries; return details for logging/restoration.
|
|
26
27
|
|
|
27
28
|
Returns a dict containing which libraries were seeded and prior states if obtainable.
|
|
28
29
|
"""
|
|
29
|
-
seeded:
|
|
30
|
-
|
|
30
|
+
seeded: dict[str, Any] = {"seed": int(seed_value), "libs": []}
|
|
31
|
+
with contextlib.suppress(Exception):
|
|
31
32
|
import random as _random # type: ignore
|
|
33
|
+
|
|
32
34
|
_random.seed(seed_value)
|
|
33
35
|
seeded["libs"].append("random")
|
|
34
|
-
|
|
35
|
-
pass
|
|
36
|
-
try:
|
|
36
|
+
with contextlib.suppress(Exception):
|
|
37
37
|
import numpy as _np # type: ignore
|
|
38
|
+
|
|
38
39
|
_np.random.seed(seed_value)
|
|
39
40
|
seeded["libs"].append("numpy")
|
|
40
|
-
|
|
41
|
-
pass
|
|
42
|
-
try:
|
|
41
|
+
with contextlib.suppress(Exception):
|
|
43
42
|
import torch as _torch # type: ignore
|
|
43
|
+
|
|
44
44
|
if hasattr(_torch, "manual_seed"):
|
|
45
45
|
_torch.manual_seed(seed_value)
|
|
46
46
|
seeded["libs"].append("torch")
|
|
47
47
|
# Make CUDA deterministic if present (best-effort)
|
|
48
|
-
|
|
48
|
+
with contextlib.suppress(Exception):
|
|
49
49
|
if getattr(_torch, "cuda", None) and _torch.cuda.is_available():
|
|
50
50
|
_torch.cuda.manual_seed_all(seed_value)
|
|
51
51
|
seeded.setdefault("cuda", True)
|
|
52
|
-
except Exception:
|
|
53
|
-
pass
|
|
54
52
|
# CUDNN deterministic flags (optional)
|
|
55
|
-
|
|
53
|
+
with contextlib.suppress(Exception):
|
|
56
54
|
if getattr(_torch, "backends", None) and getattr(_torch.backends, "cudnn", None):
|
|
57
55
|
_torch.backends.cudnn.deterministic = True # type: ignore[attr-defined]
|
|
58
56
|
_torch.backends.cudnn.benchmark = False # type: ignore[attr-defined]
|
|
59
|
-
except Exception:
|
|
60
|
-
pass
|
|
61
|
-
except Exception:
|
|
62
|
-
pass
|
|
63
57
|
return seeded
|
|
64
58
|
|
|
59
|
+
|
|
65
60
|
def _clear_seed_side_effects() -> None:
|
|
66
61
|
"""Best-effort cleanup to avoid global deterministic side-effects between requests."""
|
|
67
62
|
# We cannot truly restore prior RNG states without capturing them; we just avoid
|
|
68
63
|
# leaving aggressive deterministic flags enabled where it matters.
|
|
69
|
-
|
|
64
|
+
with contextlib.suppress(Exception):
|
|
70
65
|
import torch as _torch # type: ignore
|
|
71
|
-
|
|
66
|
+
|
|
67
|
+
with contextlib.suppress(Exception):
|
|
72
68
|
if getattr(_torch, "backends", None) and getattr(_torch.backends, "cudnn", None):
|
|
73
69
|
# Re-enable cudnn.benchmark default True only if it was True; safest is False -> leave as is.
|
|
74
70
|
# We'll keep deterministic False to avoid global impact; benchmark left False for stability.
|
|
75
71
|
_torch.backends.cudnn.deterministic = False # type: ignore[attr-defined]
|
|
76
|
-
|
|
77
|
-
pass
|
|
78
|
-
except Exception:
|
|
79
|
-
pass
|
|
72
|
+
|
|
80
73
|
|
|
81
74
|
router = APIRouter()
|
|
82
75
|
|
|
83
76
|
|
|
84
77
|
class RolloutEnvSpec(BaseModel):
|
|
85
|
-
env_id:
|
|
86
|
-
env_name:
|
|
87
|
-
config:
|
|
88
|
-
seed:
|
|
78
|
+
env_id: str | None = None
|
|
79
|
+
env_name: str | None = None
|
|
80
|
+
config: dict[str, Any] = {}
|
|
81
|
+
seed: int | None = None
|
|
89
82
|
|
|
90
83
|
|
|
91
84
|
class RolloutPolicySpec(BaseModel):
|
|
92
|
-
policy_id:
|
|
93
|
-
policy_name:
|
|
94
|
-
config:
|
|
85
|
+
policy_id: str | None = None
|
|
86
|
+
policy_name: str | None = None
|
|
87
|
+
config: dict[str, Any] = {}
|
|
95
88
|
|
|
96
89
|
|
|
97
90
|
class RolloutBranchConfig(BaseModel):
|
|
98
91
|
branch_every_n_steps: int = 0
|
|
99
|
-
branch_on_condition:
|
|
92
|
+
branch_on_condition: str | None = None
|
|
100
93
|
max_branches: int = 0
|
|
101
94
|
branch_policy: bool = False
|
|
102
95
|
branch_env: bool = False
|
|
@@ -119,53 +112,49 @@ class RolloutRequest(BaseModel):
|
|
|
119
112
|
run_id: str
|
|
120
113
|
env: RolloutEnvSpec
|
|
121
114
|
policy: RolloutPolicySpec
|
|
122
|
-
ops:
|
|
115
|
+
ops: list[str] # ["agent", "env", ...]
|
|
123
116
|
record: RolloutRecordConfig = RolloutRecordConfig()
|
|
124
117
|
on_done: str = "reset" # "reset" | "terminate"
|
|
125
|
-
branch:
|
|
118
|
+
branch: RolloutBranchConfig | None = None
|
|
126
119
|
safety: RolloutSafetyConfig = RolloutSafetyConfig()
|
|
127
120
|
# Optional run/session context
|
|
128
|
-
training_session_id:
|
|
129
|
-
synth_base_url:
|
|
121
|
+
training_session_id: str | None = None
|
|
122
|
+
synth_base_url: str | None = None
|
|
130
123
|
|
|
131
124
|
|
|
132
125
|
class RolloutStep(BaseModel):
|
|
133
|
-
obs:
|
|
134
|
-
tool_calls:
|
|
135
|
-
reward:
|
|
126
|
+
obs: dict[str, Any]
|
|
127
|
+
tool_calls: list[dict[str, Any]]
|
|
128
|
+
reward: float | None = None
|
|
136
129
|
done: bool = False
|
|
137
|
-
truncated:
|
|
138
|
-
logprob:
|
|
139
|
-
value:
|
|
140
|
-
info:
|
|
130
|
+
truncated: bool | None = None
|
|
131
|
+
logprob: float | None = None
|
|
132
|
+
value: float | None = None
|
|
133
|
+
info: dict[str, Any] | None = None
|
|
141
134
|
|
|
142
135
|
|
|
143
136
|
class RolloutTrajectory(BaseModel):
|
|
144
137
|
env_id: str
|
|
145
138
|
policy_id: str
|
|
146
|
-
steps:
|
|
147
|
-
final:
|
|
139
|
+
steps: list[RolloutStep]
|
|
140
|
+
final: dict[str, Any] | None = None
|
|
148
141
|
length: int
|
|
149
|
-
decision_samples:
|
|
142
|
+
decision_samples: list[dict[str, Any]] | None = None
|
|
150
143
|
|
|
151
144
|
|
|
152
145
|
def compute_stepwise_reward(
|
|
153
|
-
prev_achievements:
|
|
154
|
-
new_achievements:
|
|
146
|
+
prev_achievements: dict[str, bool],
|
|
147
|
+
new_achievements: dict[str, bool],
|
|
155
148
|
decision_index: int,
|
|
156
|
-
actions_summary:
|
|
149
|
+
actions_summary: list[dict[str, Any]],
|
|
157
150
|
indicator_lambda: float,
|
|
158
|
-
) ->
|
|
151
|
+
) -> tuple[dict[str, Any], dict[str, Any], dict[str, float]]:
|
|
159
152
|
"""Compute stepwise reward metadata given achievement states before/after a decision."""
|
|
160
153
|
|
|
161
154
|
prev_map = prev_achievements or {}
|
|
162
155
|
next_map = new_achievements or {}
|
|
163
156
|
|
|
164
|
-
unlocked = [
|
|
165
|
-
name
|
|
166
|
-
for name, value in next_map.items()
|
|
167
|
-
if value and not prev_map.get(name, False)
|
|
168
|
-
]
|
|
157
|
+
unlocked = [name for name, value in next_map.items() if value and not prev_map.get(name, False)]
|
|
169
158
|
indicator = 1 if unlocked else 0
|
|
170
159
|
reward_value = float(indicator_lambda) * indicator
|
|
171
160
|
|
|
@@ -190,7 +179,7 @@ def compute_stepwise_reward(
|
|
|
190
179
|
|
|
191
180
|
|
|
192
181
|
class RolloutMetrics(BaseModel):
|
|
193
|
-
episode_returns:
|
|
182
|
+
episode_returns: list[float]
|
|
194
183
|
mean_return: float
|
|
195
184
|
num_steps: int
|
|
196
185
|
num_episodes: int = 0
|
|
@@ -198,12 +187,12 @@ class RolloutMetrics(BaseModel):
|
|
|
198
187
|
|
|
199
188
|
class RolloutResponse(BaseModel):
|
|
200
189
|
run_id: str
|
|
201
|
-
trajectories:
|
|
202
|
-
branches:
|
|
190
|
+
trajectories: list[RolloutTrajectory]
|
|
191
|
+
branches: dict[str, list[str]] = {}
|
|
203
192
|
metrics: RolloutMetrics
|
|
204
193
|
aborted: bool = False
|
|
205
194
|
ops_executed: int = 0
|
|
206
|
-
trace:
|
|
195
|
+
trace: dict[str, Any] | None = None
|
|
207
196
|
|
|
208
197
|
|
|
209
198
|
class RolloutTracingContext:
|
|
@@ -227,7 +216,11 @@ class RolloutTracingContext:
|
|
|
227
216
|
self.sft_records: list[dict[str, Any]] = []
|
|
228
217
|
self.latest_system_messages: list[str] = []
|
|
229
218
|
self.latest_user_messages: list[str] = []
|
|
230
|
-
self.
|
|
219
|
+
self.latest_system_prompt_content: list[Any] = []
|
|
220
|
+
self.latest_user_prompt_content: list[Any] = []
|
|
221
|
+
self.trace_format = (
|
|
222
|
+
getattr(request.record, "trace_format", "compact") or "compact"
|
|
223
|
+
).lower()
|
|
231
224
|
self.return_trace = bool(getattr(request.record, "return_trace", False))
|
|
232
225
|
self.sft_output_dir = getattr(fastapi_request.app.state, "sft_output_dir", None)
|
|
233
226
|
self.session_trace = None
|
|
@@ -257,7 +250,9 @@ class RolloutTracingContext:
|
|
|
257
250
|
except Exception as exc:
|
|
258
251
|
logger.debug("TRACING_INIT_FAIL: %s", exc)
|
|
259
252
|
try:
|
|
260
|
-
await self.tracer.start_session(
|
|
253
|
+
await self.tracer.start_session(
|
|
254
|
+
session_id=self.run_id, metadata=dict(self.metadata_base)
|
|
255
|
+
)
|
|
261
256
|
except Exception as exc:
|
|
262
257
|
logger.warning("TRACING_START_FAIL: %s", exc)
|
|
263
258
|
self.enabled = False
|
|
@@ -291,26 +286,32 @@ class RolloutTracingContext:
|
|
|
291
286
|
|
|
292
287
|
async def record_policy_prompts(
|
|
293
288
|
self,
|
|
294
|
-
system_messages: list[
|
|
295
|
-
user_messages: list[
|
|
289
|
+
system_messages: list[Any],
|
|
290
|
+
user_messages: list[Any],
|
|
296
291
|
) -> None:
|
|
297
|
-
self.latest_system_messages =
|
|
298
|
-
self.latest_user_messages =
|
|
292
|
+
self.latest_system_messages = [self._prompt_text(entry) for entry in system_messages]
|
|
293
|
+
self.latest_user_messages = [self._prompt_text(entry) for entry in user_messages]
|
|
294
|
+
self.latest_system_prompt_content = [
|
|
295
|
+
self._prompt_content(entry, role="system") for entry in system_messages
|
|
296
|
+
]
|
|
297
|
+
self.latest_user_prompt_content = [
|
|
298
|
+
self._prompt_content(entry, role="user") for entry in user_messages
|
|
299
|
+
]
|
|
299
300
|
if not self.enabled or self.tracer is None:
|
|
300
301
|
return
|
|
301
|
-
for
|
|
302
|
+
for entry in system_messages:
|
|
302
303
|
try:
|
|
303
304
|
await self.tracer.record_message(
|
|
304
|
-
content=
|
|
305
|
+
content=self._prompt_payload(entry, role="system"),
|
|
305
306
|
message_type="policy_system_prompt",
|
|
306
307
|
metadata=self._message_metadata(),
|
|
307
308
|
)
|
|
308
309
|
except Exception as exc:
|
|
309
310
|
logger.debug("TRACING_SYSTEM_MSG_FAIL: %s", exc)
|
|
310
|
-
for
|
|
311
|
+
for entry in user_messages:
|
|
311
312
|
try:
|
|
312
313
|
await self.tracer.record_message(
|
|
313
|
-
content=
|
|
314
|
+
content=self._prompt_payload(entry, role="user"),
|
|
314
315
|
message_type="policy_user_prompt",
|
|
315
316
|
metadata=self._message_metadata(),
|
|
316
317
|
)
|
|
@@ -332,6 +333,49 @@ class RolloutTracingContext:
|
|
|
332
333
|
return ""
|
|
333
334
|
return str(content)
|
|
334
335
|
|
|
336
|
+
def _prompt_text(self, entry: Any) -> str:
|
|
337
|
+
if isinstance(entry, dict):
|
|
338
|
+
text = entry.get("text")
|
|
339
|
+
if isinstance(text, str):
|
|
340
|
+
return text
|
|
341
|
+
content = entry.get("content")
|
|
342
|
+
return self._content_to_text(content)
|
|
343
|
+
return self._content_to_text(entry)
|
|
344
|
+
|
|
345
|
+
def _prompt_payload(self, entry: Any, *, role: str) -> dict[str, Any]:
|
|
346
|
+
if isinstance(entry, dict):
|
|
347
|
+
payload = dict(entry)
|
|
348
|
+
payload.setdefault("role", role)
|
|
349
|
+
return payload
|
|
350
|
+
return {
|
|
351
|
+
"role": role,
|
|
352
|
+
"text": self._prompt_text(entry),
|
|
353
|
+
"content": entry,
|
|
354
|
+
}
|
|
355
|
+
|
|
356
|
+
def _prompt_content(self, entry: Any, *, role: str) -> Any:
|
|
357
|
+
payload = self._prompt_payload(entry, role=role)
|
|
358
|
+
return payload.get("content", payload.get("text"))
|
|
359
|
+
|
|
360
|
+
def _content_has_image(self, content: Any) -> bool:
|
|
361
|
+
if isinstance(content, list):
|
|
362
|
+
return any(
|
|
363
|
+
isinstance(seg, dict)
|
|
364
|
+
and seg.get("type") in {"image", "image_url"}
|
|
365
|
+
for seg in content
|
|
366
|
+
)
|
|
367
|
+
if isinstance(content, dict):
|
|
368
|
+
if content.get("type") in {"image", "image_url"}:
|
|
369
|
+
return True
|
|
370
|
+
inner = content.get("content")
|
|
371
|
+
if isinstance(inner, list):
|
|
372
|
+
return any(
|
|
373
|
+
isinstance(seg, dict)
|
|
374
|
+
and seg.get("type") in {"image", "image_url"}
|
|
375
|
+
for seg in inner
|
|
376
|
+
)
|
|
377
|
+
return False
|
|
378
|
+
|
|
335
379
|
def _safe_json(self, payload: Any, limit: int = 4000) -> str:
|
|
336
380
|
try:
|
|
337
381
|
text = json.dumps(payload, ensure_ascii=False)
|
|
@@ -379,17 +423,15 @@ class RolloutTracingContext:
|
|
|
379
423
|
input_tokens = usage.get("input_tokens") or usage.get("prompt_tokens")
|
|
380
424
|
output_tokens = usage.get("output_tokens") or usage.get("completion_tokens")
|
|
381
425
|
total_tokens = usage.get("total_tokens")
|
|
382
|
-
cost_usd = (
|
|
383
|
-
usage.get("cost_usd")
|
|
384
|
-
or usage.get("cost")
|
|
385
|
-
or usage.get("total_cost")
|
|
386
|
-
)
|
|
426
|
+
cost_usd = usage.get("cost_usd") or usage.get("cost") or usage.get("total_cost")
|
|
387
427
|
|
|
388
428
|
assistant_message = None
|
|
389
429
|
choices = inference_response.get("choices") or []
|
|
390
430
|
if choices:
|
|
391
431
|
assistant_message = choices[0].get("message") or {}
|
|
392
|
-
assistant_content =
|
|
432
|
+
assistant_content = (
|
|
433
|
+
assistant_message.get("content") if isinstance(assistant_message, dict) else None
|
|
434
|
+
)
|
|
393
435
|
|
|
394
436
|
raw_response = self._content_to_text(assistant_content)
|
|
395
437
|
if not raw_response:
|
|
@@ -397,7 +439,9 @@ class RolloutTracingContext:
|
|
|
397
439
|
|
|
398
440
|
base_response = BaseLMResponse(
|
|
399
441
|
raw_response=raw_response,
|
|
400
|
-
tool_calls=assistant_message.get("tool_calls")
|
|
442
|
+
tool_calls=assistant_message.get("tool_calls")
|
|
443
|
+
if isinstance(assistant_message, dict)
|
|
444
|
+
else None,
|
|
401
445
|
usage=usage or None,
|
|
402
446
|
api_type="chat_completions",
|
|
403
447
|
)
|
|
@@ -457,19 +501,44 @@ class RolloutTracingContext:
|
|
|
457
501
|
)
|
|
458
502
|
|
|
459
503
|
if self.sft_output_dir is not None:
|
|
504
|
+
assistant_structured = assistant_content if assistant_content is not None else ""
|
|
460
505
|
assistant_text = self._content_to_text(assistant_content)
|
|
506
|
+
dialogue_structured: list[dict[str, Any]] = []
|
|
507
|
+
for content in self.latest_system_prompt_content:
|
|
508
|
+
if content is None:
|
|
509
|
+
continue
|
|
510
|
+
dialogue_structured.append({"role": "system", "content": content})
|
|
511
|
+
for content in self.latest_user_prompt_content:
|
|
512
|
+
if content is None:
|
|
513
|
+
continue
|
|
514
|
+
dialogue_structured.append({"role": "user", "content": content})
|
|
515
|
+
dialogue_text = (
|
|
516
|
+
[{"role": "system", "content": s} for s in self.latest_system_messages]
|
|
517
|
+
+ [{"role": "user", "content": u} for u in self.latest_user_messages]
|
|
518
|
+
)
|
|
519
|
+
user_has_image = any(
|
|
520
|
+
self._content_has_image(content) for content in self.latest_user_prompt_content
|
|
521
|
+
)
|
|
522
|
+
assistant_has_image = self._content_has_image(assistant_structured)
|
|
461
523
|
record = {
|
|
462
524
|
"run_id": self.run_id,
|
|
463
525
|
"turn": self.current_turn,
|
|
464
526
|
"model": model_name,
|
|
465
527
|
"provider": provider,
|
|
466
|
-
"dialogue":
|
|
467
|
-
|
|
468
|
-
+ [{"role": "user", "content": u} for u in self.latest_user_messages]
|
|
469
|
-
),
|
|
528
|
+
"dialogue": dialogue_structured,
|
|
529
|
+
"dialogue_text": dialogue_text,
|
|
470
530
|
"assistant": {
|
|
471
|
-
"content":
|
|
472
|
-
"
|
|
531
|
+
"content": assistant_structured,
|
|
532
|
+
"content_text": assistant_text,
|
|
533
|
+
"tool_calls": assistant_message.get("tool_calls")
|
|
534
|
+
if isinstance(assistant_message, dict)
|
|
535
|
+
else [],
|
|
536
|
+
"has_image": assistant_has_image,
|
|
537
|
+
},
|
|
538
|
+
"metadata": {
|
|
539
|
+
"user_has_image": user_has_image,
|
|
540
|
+
"assistant_has_image": assistant_has_image,
|
|
541
|
+
"has_image": user_has_image or assistant_has_image,
|
|
473
542
|
},
|
|
474
543
|
"timestamp": datetime.utcnow().isoformat(),
|
|
475
544
|
}
|
|
@@ -479,20 +548,28 @@ class RolloutTracingContext:
|
|
|
479
548
|
self,
|
|
480
549
|
*,
|
|
481
550
|
env_handle: Any,
|
|
482
|
-
prev_obs:
|
|
551
|
+
prev_obs: dict[str, Any] | None,
|
|
483
552
|
env_response: Any,
|
|
484
|
-
next_obs:
|
|
485
|
-
metadata:
|
|
553
|
+
next_obs: dict[str, Any] | None,
|
|
554
|
+
metadata: dict[str, Any] | None = None,
|
|
486
555
|
) -> int | None:
|
|
487
556
|
if not self.enabled or self.tracer is None:
|
|
488
557
|
return None
|
|
489
558
|
|
|
490
559
|
try:
|
|
491
|
-
prev_summary =
|
|
560
|
+
prev_summary = (
|
|
561
|
+
_summarize_observation_for_storage(env_handle, prev_obs or {})
|
|
562
|
+
if prev_obs is not None
|
|
563
|
+
else None
|
|
564
|
+
)
|
|
492
565
|
except Exception:
|
|
493
566
|
prev_summary = None
|
|
494
567
|
try:
|
|
495
|
-
next_summary =
|
|
568
|
+
next_summary = (
|
|
569
|
+
_summarize_observation_for_storage(env_handle, next_obs or {})
|
|
570
|
+
if next_obs is not None
|
|
571
|
+
else None
|
|
572
|
+
)
|
|
496
573
|
except Exception:
|
|
497
574
|
next_summary = None
|
|
498
575
|
|
|
@@ -523,7 +600,7 @@ class RolloutTracingContext:
|
|
|
523
600
|
self,
|
|
524
601
|
*,
|
|
525
602
|
event_id: int | None,
|
|
526
|
-
decision_meta:
|
|
603
|
+
decision_meta: dict[str, Any] | None,
|
|
527
604
|
) -> None:
|
|
528
605
|
decision_meta = decision_meta or {}
|
|
529
606
|
ach_delta = int(decision_meta.get("ach_delta", 0))
|
|
@@ -571,7 +648,7 @@ class RolloutTracingContext:
|
|
|
571
648
|
self,
|
|
572
649
|
*,
|
|
573
650
|
total_reward: float,
|
|
574
|
-
achievement_state:
|
|
651
|
+
achievement_state: dict[str, bool] | None,
|
|
575
652
|
total_steps: int,
|
|
576
653
|
) -> Any:
|
|
577
654
|
final_achievements = [key for key, val in (achievement_state or {}).items() if val]
|
|
@@ -593,10 +670,8 @@ class RolloutTracingContext:
|
|
|
593
670
|
except Exception as exc:
|
|
594
671
|
logger.debug("TRACING_END_SESSION_FAIL: %s", exc)
|
|
595
672
|
self.session_trace = None
|
|
596
|
-
|
|
673
|
+
with contextlib.suppress(Exception):
|
|
597
674
|
await self.tracer.close()
|
|
598
|
-
except Exception:
|
|
599
|
-
pass
|
|
600
675
|
|
|
601
676
|
if self.sft_records and self.sft_output_dir:
|
|
602
677
|
self.write_sft_records()
|
|
@@ -622,7 +697,7 @@ class RolloutTracingContext:
|
|
|
622
697
|
finally:
|
|
623
698
|
self.sft_records.clear()
|
|
624
699
|
|
|
625
|
-
def build_trace_payload(self, session_trace: Any) ->
|
|
700
|
+
def build_trace_payload(self, session_trace: Any) -> dict[str, Any] | None:
|
|
626
701
|
if not self.return_trace or session_trace is None:
|
|
627
702
|
return None
|
|
628
703
|
if self.trace_format == "full":
|
|
@@ -640,28 +715,36 @@ class RolloutTracingContext:
|
|
|
640
715
|
"lm_calls": self.lm_calls_summary,
|
|
641
716
|
"decision_rewards": self.decision_rewards,
|
|
642
717
|
}
|
|
643
|
-
|
|
718
|
+
|
|
719
|
+
|
|
720
|
+
def _summarize_observation_for_storage(
|
|
721
|
+
env_handle: Any, observation: dict[str, Any]
|
|
722
|
+
) -> dict[str, Any]:
|
|
644
723
|
"""Return a compact dict for trajectory storage instead of the raw observation.
|
|
645
724
|
|
|
646
725
|
- For Crafter, use the same summary used for the policy user prompt
|
|
647
726
|
- For others, keep a minimal subset or plain text preview
|
|
648
727
|
"""
|
|
649
728
|
# Try Crafter-specific formatter
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
729
|
+
crafter_wrapper = None
|
|
730
|
+
with contextlib.suppress(Exception):
|
|
731
|
+
from .envs.crafter.environment import (
|
|
732
|
+
CrafterEnvironmentWrapper as _CrafterWrapper, # type: ignore
|
|
733
|
+
)
|
|
654
734
|
|
|
655
|
-
|
|
656
|
-
|
|
735
|
+
crafter_wrapper = _CrafterWrapper # type: ignore[assignment]
|
|
736
|
+
|
|
737
|
+
if crafter_wrapper is not None and isinstance(
|
|
738
|
+
getattr(env_handle, "env", None), crafter_wrapper
|
|
739
|
+
):
|
|
740
|
+
with contextlib.suppress(Exception):
|
|
657
741
|
from .envs.crafter.shared import format_observation as _fmt # type: ignore
|
|
742
|
+
|
|
658
743
|
text = _fmt(observation or {})
|
|
659
744
|
return {"text": text}
|
|
660
|
-
except Exception:
|
|
661
|
-
pass
|
|
662
745
|
|
|
663
746
|
# Generic fallback: extract a few small fields if present; avoid huge arrays
|
|
664
|
-
|
|
747
|
+
with contextlib.suppress(Exception):
|
|
665
748
|
inv = observation.get("inventory") if isinstance(observation, dict) else None
|
|
666
749
|
ach = observation.get("achievements_status") if isinstance(observation, dict) else None
|
|
667
750
|
pos = observation.get("player_position") if isinstance(observation, dict) else None
|
|
@@ -671,12 +754,14 @@ def _summarize_observation_for_storage(env_handle: Any, observation: Dict[str, A
|
|
|
671
754
|
summary = {
|
|
672
755
|
"position": pos,
|
|
673
756
|
"health": health,
|
|
674
|
-
"inventory_keys": sorted(
|
|
675
|
-
|
|
757
|
+
"inventory_keys": sorted(k for k, v in (inv or {}).items() if v)[:10]
|
|
758
|
+
if isinstance(inv, dict)
|
|
759
|
+
else None,
|
|
760
|
+
"achievements_unlocked": sorted(k for k, v in (ach or {}).items() if v)[:10]
|
|
761
|
+
if isinstance(ach, dict)
|
|
762
|
+
else None,
|
|
676
763
|
}
|
|
677
764
|
return {"text": json.dumps(summary, ensure_ascii=False)}
|
|
678
|
-
except Exception:
|
|
679
|
-
pass
|
|
680
765
|
|
|
681
766
|
# Last resort: plain string preview
|
|
682
767
|
try:
|
|
@@ -685,7 +770,6 @@ def _summarize_observation_for_storage(env_handle: Any, observation: Dict[str, A
|
|
|
685
770
|
return {"text": ""}
|
|
686
771
|
|
|
687
772
|
|
|
688
|
-
|
|
689
773
|
class RunAbortRequest(BaseModel):
|
|
690
774
|
run_id: str
|
|
691
775
|
|
|
@@ -699,7 +783,7 @@ class RunStatusResponse(BaseModel):
|
|
|
699
783
|
run_id: str
|
|
700
784
|
status: str
|
|
701
785
|
started_at: datetime
|
|
702
|
-
finished_at:
|
|
786
|
+
finished_at: datetime | None = None
|
|
703
787
|
|
|
704
788
|
|
|
705
789
|
@router.post("/rollout", response_model=RolloutResponse)
|
|
@@ -708,6 +792,13 @@ async def execute_rollout(
|
|
|
708
792
|
req: Request,
|
|
709
793
|
) -> RolloutResponse:
|
|
710
794
|
"""Execute a rollout with coordinated environment and policy steps."""
|
|
795
|
+
# Emit rollout identifier early for correlation
|
|
796
|
+
with contextlib.suppress(Exception):
|
|
797
|
+
_rid = getattr(request, "run_id", None)
|
|
798
|
+
_pol = getattr(request.policy, "policy_name", None) or getattr(request.policy, "policy_id", None)
|
|
799
|
+
_env = getattr(request.env, "env_name", None) or getattr(request.env, "env_id", None)
|
|
800
|
+
logger.info("ROLLOUT_BEGIN: run_id=%s policy=%s env=%s", _rid, _pol, _env)
|
|
801
|
+
print(f"[rollout] begin run_id={_rid} policy={_pol} env={_env}", flush=True)
|
|
711
802
|
# Enforce per-episode step cap via env-specific parameters; default to 20 if omitted
|
|
712
803
|
try:
|
|
713
804
|
_env_params = {}
|
|
@@ -722,32 +813,30 @@ async def execute_rollout(
|
|
|
722
813
|
"error": "invalid_env_params",
|
|
723
814
|
"message": f"Invalid or missing env_params.max_steps_per_episode: {_mse}",
|
|
724
815
|
},
|
|
725
|
-
)
|
|
816
|
+
) from _mse
|
|
726
817
|
# Truncate incoming ops to the enforced cap (each step is [agent, env])
|
|
727
|
-
ops_seq:
|
|
818
|
+
ops_seq: list[str] = list(request.ops or [])
|
|
728
819
|
allowed_ops = max(0, int(max_steps_per_episode) * 2)
|
|
729
820
|
if len(ops_seq) > allowed_ops:
|
|
730
|
-
|
|
821
|
+
with contextlib.suppress(Exception):
|
|
731
822
|
logger.info(
|
|
732
823
|
"ROLL_OUT: truncating ops to cap: requested_ops=%s allowed_ops=%s",
|
|
733
824
|
str(len(ops_seq)),
|
|
734
825
|
str(allowed_ops),
|
|
735
826
|
)
|
|
736
|
-
except Exception:
|
|
737
|
-
pass
|
|
738
827
|
ops_seq = ops_seq[:allowed_ops]
|
|
739
828
|
# Simple API key auth for inbound rollout
|
|
740
829
|
header_key = req.headers.get("x-api-key")
|
|
741
830
|
env_key = os.getenv("ENVIRONMENT_API_KEY")
|
|
742
|
-
dev_key = os.getenv("
|
|
743
|
-
# Accept either ENVIRONMENT_API_KEY or
|
|
831
|
+
dev_key = os.getenv("DEV_ENVIRONMENT_API_KEY")
|
|
832
|
+
# Accept either ENVIRONMENT_API_KEY or DEV_ENVIRONMENT_API_KEY
|
|
744
833
|
expected_keys = [k for k in (env_key, dev_key) if k]
|
|
745
834
|
if not expected_keys:
|
|
746
835
|
missing = []
|
|
747
836
|
if not env_key:
|
|
748
837
|
missing.append("ENVIRONMENT_API_KEY")
|
|
749
838
|
if not dev_key:
|
|
750
|
-
missing.append("
|
|
839
|
+
missing.append("DEV_ENVIRONMENT_API_KEY")
|
|
751
840
|
msg = f"Auth not configured: missing {', '.join(missing)} in task service environment"
|
|
752
841
|
logger.error(msg)
|
|
753
842
|
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=msg)
|
|
@@ -773,33 +862,38 @@ async def execute_rollout(
|
|
|
773
862
|
logger.info(f"ROLL_OUT: synth_base_url={request.synth_base_url}")
|
|
774
863
|
|
|
775
864
|
# Log masked OpenAI API key presence for diagnostics
|
|
776
|
-
|
|
865
|
+
with contextlib.suppress(Exception):
|
|
777
866
|
_oa = os.getenv("OPENAI_API_KEY")
|
|
778
867
|
if _oa:
|
|
779
868
|
_pref = (_oa[:6] + "…") if len(_oa) >= 6 else "set"
|
|
780
869
|
logger.info(f"ROLL_OUT: OPENAI_API_KEY present (prefix={_pref})")
|
|
781
870
|
else:
|
|
782
871
|
logger.warning("ROLL_OUT: OPENAI_API_KEY missing")
|
|
783
|
-
except Exception:
|
|
784
|
-
pass
|
|
785
872
|
|
|
786
873
|
# Make synth_base_url available for outbound calls in this app
|
|
787
|
-
|
|
874
|
+
with contextlib.suppress(Exception):
|
|
788
875
|
task_app = req.app.state.task_app
|
|
789
876
|
if request.synth_base_url:
|
|
790
|
-
|
|
791
|
-
except Exception:
|
|
792
|
-
pass
|
|
877
|
+
task_app.synth_base_url = request.synth_base_url
|
|
793
878
|
|
|
794
879
|
tracer_factory = getattr(req.app.state, "session_tracer_factory", None)
|
|
795
|
-
tracer_instance = None
|
|
880
|
+
tracer_instance: SessionTracer | None = None
|
|
796
881
|
if callable(tracer_factory):
|
|
797
882
|
try:
|
|
798
|
-
|
|
883
|
+
inst = tracer_factory()
|
|
884
|
+
tracer_instance = inst if isinstance(inst, SessionTracer) else None
|
|
799
885
|
except Exception as exc:
|
|
800
886
|
logger.debug(f"TRACER_FACTORY_FAIL: {exc}")
|
|
801
887
|
tracing_context = RolloutTracingContext(tracer_instance, request, req)
|
|
802
888
|
await tracing_context.start_session()
|
|
889
|
+
# Print whether tracing is active for this rollout
|
|
890
|
+
try:
|
|
891
|
+
print(
|
|
892
|
+
f"[rollout] tracing enabled={bool(tracing_context.enabled)} run_id={request.run_id}",
|
|
893
|
+
flush=True,
|
|
894
|
+
)
|
|
895
|
+
except Exception:
|
|
896
|
+
pass
|
|
803
897
|
|
|
804
898
|
# Register run
|
|
805
899
|
registry.register_run(request.run_id)
|
|
@@ -808,10 +902,25 @@ async def execute_rollout(
|
|
|
808
902
|
created_env_id: str | None = None
|
|
809
903
|
created_policy_id: str | None = None
|
|
810
904
|
env_seed_used: int | None = None
|
|
905
|
+
trajectory_steps: list[RolloutStep] = []
|
|
906
|
+
decision_samples: list[dict[str, Any]] = []
|
|
907
|
+
pending_tool_calls: Any = None
|
|
908
|
+
current_obs: Any = {}
|
|
909
|
+
total_reward: float = 0.0
|
|
910
|
+
ops_executed = 0
|
|
911
|
+
last_agent_response_ts: float | None = None
|
|
912
|
+
last_policy_meta: dict[str, Any] | None = None
|
|
913
|
+
last_env_step_ms: float | None = None
|
|
914
|
+
last_env_step_completed_ts: float | None = None
|
|
915
|
+
decision_open = False
|
|
916
|
+
finalized = False
|
|
917
|
+
prev_achievements: dict[str, bool] = {}
|
|
918
|
+
session_trace = None
|
|
919
|
+
step_rewards_active = False
|
|
811
920
|
|
|
812
921
|
try:
|
|
813
922
|
# Initialize deterministic seed early for the entire rollout
|
|
814
|
-
seed_value:
|
|
923
|
+
seed_value: int | None = None
|
|
815
924
|
try:
|
|
816
925
|
if request.env and request.env.seed is not None:
|
|
817
926
|
seed_value = int(request.env.seed)
|
|
@@ -830,14 +939,12 @@ async def execute_rollout(
|
|
|
830
939
|
seed_value = 42
|
|
831
940
|
|
|
832
941
|
_seed_info = _set_global_seed(int(seed_value))
|
|
833
|
-
|
|
942
|
+
with contextlib.suppress(Exception):
|
|
834
943
|
logger.info(
|
|
835
944
|
"ROLL_OUT: RNG seeded seed=%s libs=%s",
|
|
836
945
|
str(_seed_info.get("seed")),
|
|
837
946
|
",".join(_seed_info.get("libs", [])),
|
|
838
947
|
)
|
|
839
|
-
except Exception:
|
|
840
|
-
pass
|
|
841
948
|
# Resolve or create environment
|
|
842
949
|
if request.env.env_id:
|
|
843
950
|
env_handle = registry.get_env(request.env.env_id)
|
|
@@ -849,7 +956,7 @@ async def execute_rollout(
|
|
|
849
956
|
env_id = request.env.env_id
|
|
850
957
|
else:
|
|
851
958
|
# Create new environment
|
|
852
|
-
from .environment_routes import
|
|
959
|
+
from .environment_routes import EnvCreateRequest, create_environment
|
|
853
960
|
|
|
854
961
|
if not request.env.env_name:
|
|
855
962
|
raise ValueError("FATAL: env_name is required - NO FALLBACKS!")
|
|
@@ -857,9 +964,7 @@ async def execute_rollout(
|
|
|
857
964
|
# Propagate training_session_id via env config for downstream usage
|
|
858
965
|
_env_config = dict(request.env.config or {})
|
|
859
966
|
if request.training_session_id is not None:
|
|
860
|
-
_env_config.setdefault(
|
|
861
|
-
"training_session_id", request.training_session_id
|
|
862
|
-
)
|
|
967
|
+
_env_config.setdefault("training_session_id", request.training_session_id)
|
|
863
968
|
env_response = await create_environment(
|
|
864
969
|
EnvCreateRequest(
|
|
865
970
|
env_name=request.env.env_name,
|
|
@@ -885,7 +990,7 @@ async def execute_rollout(
|
|
|
885
990
|
policy_id = request.policy.policy_id
|
|
886
991
|
else:
|
|
887
992
|
# Create new policy
|
|
888
|
-
from .policy_routes import
|
|
993
|
+
from .policy_routes import PolicyCreateRequest, create_policy
|
|
889
994
|
|
|
890
995
|
if not request.policy.policy_name:
|
|
891
996
|
raise ValueError("FATAL: policy_name is required - NO FALLBACKS!")
|
|
@@ -893,9 +998,7 @@ async def execute_rollout(
|
|
|
893
998
|
# Propagate training_session_id and synth_base_url via policy config
|
|
894
999
|
_policy_config = dict(request.policy.config or {})
|
|
895
1000
|
if request.training_session_id is not None:
|
|
896
|
-
_policy_config.setdefault(
|
|
897
|
-
"training_session_id", request.training_session_id
|
|
898
|
-
)
|
|
1001
|
+
_policy_config.setdefault("training_session_id", request.training_session_id)
|
|
899
1002
|
if request.synth_base_url is not None:
|
|
900
1003
|
_policy_config.setdefault("synth_base_url", request.synth_base_url)
|
|
901
1004
|
policy_response = await create_policy(
|
|
@@ -923,20 +1026,19 @@ async def execute_rollout(
|
|
|
923
1026
|
except Exception:
|
|
924
1027
|
env_seed_used = None
|
|
925
1028
|
tracing_context.update_metadata(env_seed=env_seed_used)
|
|
926
|
-
|
|
927
1029
|
# Initialize trajectory
|
|
928
1030
|
trajectory_steps = []
|
|
929
1031
|
pending_tool_calls = None
|
|
930
1032
|
current_obs = env_handle.last_observation
|
|
931
1033
|
total_reward = 0.0
|
|
932
1034
|
ops_executed = 0
|
|
933
|
-
last_agent_response_ts
|
|
934
|
-
last_policy_meta
|
|
935
|
-
last_env_step_ms
|
|
936
|
-
last_env_step_completed_ts
|
|
1035
|
+
last_agent_response_ts = None
|
|
1036
|
+
last_policy_meta = None
|
|
1037
|
+
last_env_step_ms = None
|
|
1038
|
+
last_env_step_completed_ts = None
|
|
937
1039
|
|
|
938
1040
|
# Stepwise reward configuration (Crafter shaping; gate on explicit enable)
|
|
939
|
-
step_rewards_cfg_raw:
|
|
1041
|
+
step_rewards_cfg_raw: dict[str, Any] = {}
|
|
940
1042
|
try:
|
|
941
1043
|
if isinstance(request.policy.config, dict):
|
|
942
1044
|
step_rewards_cfg_raw = dict(request.policy.config.get("step_rewards") or {})
|
|
@@ -963,7 +1065,7 @@ async def execute_rollout(
|
|
|
963
1065
|
step_rewards_beta = 0.0
|
|
964
1066
|
step_rewards_active = step_rewards_enabled and step_rewards_mode == "decision_stepwise"
|
|
965
1067
|
|
|
966
|
-
def _extract_achievements(obs: Any) ->
|
|
1068
|
+
def _extract_achievements(obs: Any) -> dict[str, bool]:
|
|
967
1069
|
if not isinstance(obs, dict):
|
|
968
1070
|
return {}
|
|
969
1071
|
ach = obs.get("achievements_status")
|
|
@@ -971,7 +1073,7 @@ async def execute_rollout(
|
|
|
971
1073
|
return {str(k): bool(v) for k, v in ach.items()}
|
|
972
1074
|
return {}
|
|
973
1075
|
|
|
974
|
-
def _summarize_tool_calls(tool_calls: Any) ->
|
|
1076
|
+
def _summarize_tool_calls(tool_calls: Any) -> list[dict[str, Any]]:
|
|
975
1077
|
if not tool_calls:
|
|
976
1078
|
return []
|
|
977
1079
|
try:
|
|
@@ -982,7 +1084,7 @@ async def execute_rollout(
|
|
|
982
1084
|
)
|
|
983
1085
|
except Exception:
|
|
984
1086
|
return []
|
|
985
|
-
summary:
|
|
1087
|
+
summary: list[dict[str, Any]] = []
|
|
986
1088
|
for tc in items:
|
|
987
1089
|
tool_name = None
|
|
988
1090
|
args: Any = {}
|
|
@@ -1001,16 +1103,16 @@ async def execute_rollout(
|
|
|
1001
1103
|
summary.append({"tool": tool_name, "args": args})
|
|
1002
1104
|
return summary
|
|
1003
1105
|
|
|
1004
|
-
decision_samples:
|
|
1106
|
+
decision_samples: list[dict[str, Any]] = []
|
|
1005
1107
|
decision_index = 0
|
|
1006
1108
|
decision_open = False
|
|
1007
1109
|
session_trace = None
|
|
1008
1110
|
finalized = False
|
|
1009
1111
|
prev_achievements = _extract_achievements(current_obs)
|
|
1010
1112
|
# Track episode-level achievements that have been seen as true at any point so far
|
|
1011
|
-
episode_seen_achievements: set[str] =
|
|
1012
|
-
|
|
1013
|
-
|
|
1113
|
+
episode_seen_achievements: set[str] = {
|
|
1114
|
+
k for k, v in (prev_achievements or {}).items() if bool(v)
|
|
1115
|
+
}
|
|
1014
1116
|
stepwise_indicator_sum = 0.0
|
|
1015
1117
|
stepwise_reward_sum = 0.0
|
|
1016
1118
|
stepwise_new_achievements_total = 0
|
|
@@ -1030,7 +1132,7 @@ async def execute_rollout(
|
|
|
1030
1132
|
|
|
1031
1133
|
if op == "agent":
|
|
1032
1134
|
# Policy step
|
|
1033
|
-
from .policy_routes import
|
|
1135
|
+
from .policy_routes import PolicyStepRequest, step_policy
|
|
1034
1136
|
|
|
1035
1137
|
if not decision_open:
|
|
1036
1138
|
await tracing_context.start_decision(decision_index)
|
|
@@ -1038,7 +1140,7 @@ async def execute_rollout(
|
|
|
1038
1140
|
|
|
1039
1141
|
agent_request_start = _time.perf_counter()
|
|
1040
1142
|
if last_agent_response_ts is not None and last_policy_meta is not None:
|
|
1041
|
-
|
|
1143
|
+
with contextlib.suppress(Exception):
|
|
1042
1144
|
timing_prev = last_policy_meta.setdefault("timing", {})
|
|
1043
1145
|
decision_ms = max(
|
|
1044
1146
|
0.0,
|
|
@@ -1057,7 +1159,7 @@ async def execute_rollout(
|
|
|
1057
1159
|
# Also backfill the last appended trajectory step so the trainer
|
|
1058
1160
|
# can always see decision_ms without relying on shared dict refs.
|
|
1059
1161
|
if trajectory_steps:
|
|
1060
|
-
|
|
1162
|
+
with contextlib.suppress(Exception):
|
|
1061
1163
|
_last = trajectory_steps[-1]
|
|
1062
1164
|
_info = dict(_last.info or {})
|
|
1063
1165
|
_meta = dict(_info.get("meta") or {})
|
|
@@ -1065,16 +1167,15 @@ async def execute_rollout(
|
|
|
1065
1167
|
_timing["decision_ms"] = decision_ms
|
|
1066
1168
|
if last_env_step_ms is not None:
|
|
1067
1169
|
_timing.setdefault("env_step_ms", float(last_env_step_ms))
|
|
1068
|
-
_timing.setdefault(
|
|
1170
|
+
_timing.setdefault(
|
|
1171
|
+
"overhead_ms",
|
|
1172
|
+
max(0.0, decision_ms - float(last_env_step_ms)),
|
|
1173
|
+
)
|
|
1069
1174
|
else:
|
|
1070
1175
|
_timing.setdefault("overhead_ms", 0.0)
|
|
1071
1176
|
_meta["timing"] = _timing
|
|
1072
1177
|
_info["meta"] = _meta
|
|
1073
1178
|
_last.info = _info
|
|
1074
|
-
except Exception:
|
|
1075
|
-
pass
|
|
1076
|
-
except Exception:
|
|
1077
|
-
pass
|
|
1078
1179
|
last_env_step_ms = None
|
|
1079
1180
|
last_env_step_completed_ts = None
|
|
1080
1181
|
|
|
@@ -1097,39 +1198,25 @@ async def execute_rollout(
                 }
 
                 # Log compact metadata summary to confirm history threading
-                try:
-                    _prev_calls = (
-                        metadata["prev_tool_calls"]
-                        if isinstance(metadata, dict) and "prev_tool_calls" in metadata
-                        else None
-                    )
+                with contextlib.suppress(Exception):
+                    _prev_calls = metadata.get("prev_tool_calls")
                     _count = len(_prev_calls) if isinstance(_prev_calls, list) else 0
                     _first_guess = None
                     if _count > 0 and isinstance(_prev_calls[0], dict):
-                        _args = (
-                            _prev_calls[0]["arguments"]
-                            if "arguments" in _prev_calls[0]
-                            else None
-                        )
+                        _args = _prev_calls[0].get("arguments", None)
                         if isinstance(_args, str):
                             import json as _json
-
-                            try:
+                            with contextlib.suppress(Exception):
                                 _args = _json.loads(_args)
-                            except Exception:
-                                pass
-
-                        _first_guess = (
-                            _args["guess"] if "guess" in _args else None
-                        ) or (_args["word"] if "word" in _args else None)
+                        if not isinstance(_args, dict):
+                            _args = {}
+                        _first_guess = _args.get("guess") or _args.get("word")
                     logger.info(
                         "POLICY_METADATA: prev_tool_calls=%d first_guess=%r has_prev_env_result=%s",
                         _count,
                         _first_guess,
                         str("prev_env_result" in metadata),
                     )
-                except Exception:
-                    pass
 
                 try:
                     policy_response = await step_policy(
@@ -1142,15 +1229,13 @@ async def execute_rollout(
                     )
                 except Exception as _pe:
                     # Do not 500 the rollout; finalize with partial trajectory
-                    try:
+                    with contextlib.suppress(Exception):
                         logger.warning(
                             "POLICY_STEP_FAIL: terminating episode early run_id=%s op_idx=%s err=%s",
                             request.run_id,
                             str(op_idx),
                             str(_pe),
                         )
-                    except Exception:
-                        pass
 
                     # Build partial trajectory and return HTTP 200
                     trajectory = RolloutTrajectory(
@@ -1198,12 +1283,12 @@ async def execute_rollout(
 
                 agent_response_ts = _time.perf_counter()
                 if isinstance(policy_response.meta, dict):
-                    try:
+                    with contextlib.suppress(Exception):
                         timing_cur = policy_response.meta.setdefault("timing", {})
                         timing_cur["agent_request_start_s"] = agent_request_start
                         timing_cur["agent_response_s"] = agent_response_ts
                         if "inference_ms" in policy_response.meta:
-                            try:
+                            with contextlib.suppress(Exception):
                                 timing_cur.setdefault(
                                     "inference_ms",
                                     float(policy_response.meta["inference_ms"]),
@@ -1212,30 +1297,66 @@ async def execute_rollout(
                                     "inference_s",
                                     float(policy_response.meta["inference_ms"]) / 1000.0,
                                 )
-                            except Exception:
-                                pass
-                    except Exception:
-                        pass
                     last_policy_meta = policy_response.meta
                 else:
                     last_policy_meta = None
                 last_agent_response_ts = agent_response_ts
 
+                # Diagnostic: summarize policy step target and tool calls
+                try:
+                    model_name = None
+                    target_url = None
+                    if isinstance(policy_response.meta, dict):
+                        req_body = policy_response.meta.get("inference_request") or {}
+                        model_name = req_body.get("model")
+                        target_url = policy_response.meta.get("inference_url")
+                    _tc = policy_response.tool_calls or []
+                    print(
+                        {
+                            "rollout.policy_step": True,
+                            "run_id": request.run_id,
+                            "model": model_name,
+                            "inference_url": target_url,
+                            "tool_calls_count": len(_tc) if isinstance(_tc, list) else 0,
+                        },
+                        flush=True,
+                    )
+                except Exception:
+                    pass
+
                 pending_tool_calls = policy_response.tool_calls
+                # Log summarized agent tool calls
+                with contextlib.suppress(Exception):
+                    _tc = pending_tool_calls or []
+                    _summary = []
+                    for _item in (_tc if isinstance(_tc, list) else []):
+                        try:
+                            if isinstance(_item, dict):
+                                _tool = _item.get("tool")
+                                _args = _item.get("args")
+                                _keys = list(_args.keys()) if isinstance(_args, dict) else []
+                                _summary.append({"tool": _tool, "args_keys": _keys})
+                        except Exception:
+                            continue
+                    _rid = getattr(request, "run_id", None)
+                    logger.info("AGENT_TOOL_CALLS: run_id=%s count=%d summary=%s", _rid, len(_tc), _summary)
+                    print(f"[rollout] agent tool_calls run_id={_rid} count={len(_tc)} summary={_summary}", flush=True)
                 await tracing_context.record_tool_invocation(pending_tool_calls)
                 ops_executed += 1
 
             elif op == "env":
                 if not pending_tool_calls:
                     # Treat absence of tool calls as a soft terminal condition; yield partial trajectory
-                    try:
+                    with contextlib.suppress(Exception):
                         logger.warning(
                             "NO_TOOL_CALLS: terminating episode early run_id=%s op_idx=%s",
                             request.run_id,
                             str(op_idx),
                         )
-                    except Exception:
-                        pass
+                        print(
+                            f"[rollout] no tool_calls; terminating early run_id={request.run_id} op_idx={op_idx}",
+                            flush=True,
+                        )
                     term_step = RolloutStep(
                         obs=current_obs,
                         tool_calls=[],
@@ -1291,7 +1412,7 @@ async def execute_rollout(
                     )
 
                 # Environment step
-                from .environment_routes import
+                from .environment_routes import EnvStepRequest, step_environment
 
                 env_step_error: Exception | None = None
                 env_response = None
@@ -1310,24 +1431,20 @@ async def execute_rollout(
                 last_env_step_ms = env_step_duration_ms
                 last_env_step_completed_ts = env_step_end
                 if last_policy_meta is not None:
-                    try:
+                    with contextlib.suppress(Exception):
                         timing_env = last_policy_meta.setdefault("timing", {})
                         timing_env["env_step_ms"] = env_step_duration_ms
                         timing_env["env_step_end_s"] = env_step_end
-                    except Exception:
-                        pass
 
                 if env_step_error is not None:
                     # Invalid action or environment rejection — terminate episode early with partial trajectory
-                    try:
+                    with contextlib.suppress(Exception):
                         logger.warning(
                             "ENV_STEP_FAIL: terminating episode early run_id=%s op_idx=%s err=%s",
                             request.run_id,
                             str(op_idx),
                             str(env_step_error),
                         )
-                    except Exception:
-                        pass
 
                     term_step = RolloutStep(
                         obs=current_obs,
@@ -1370,16 +1487,16 @@ async def execute_rollout(
                     and last_agent_response_ts is not None
                     and "decision_ms" not in last_policy_meta.get("timing", {})
                 ):
-                    try:
+                    with contextlib.suppress(Exception):
                         timing_last = last_policy_meta.setdefault("timing", {})
                         decision_ms = max(
                             0.0,
                             (env_step_end - float(last_agent_response_ts)) * 1000.0,
                         )
                         timing_last["decision_ms"] = decision_ms
-                        timing_last.setdefault(
-                            "overhead_ms", max(0.0, decision_ms - env_step_duration_ms)
-                        )
+                        timing_last.setdefault(
+                            "overhead_ms", max(0.0, decision_ms - env_step_duration_ms)
+                        )
                 if decision_open:
                     await tracing_context.end_decision()
                     decision_open = False
@@ -1407,17 +1524,13 @@ async def execute_rollout(
                 # Record step, including policy meta if present for timing/tokens observability
                 _info = env_response.info if isinstance(env_response.info, dict) else {}
                 # Attach policy meta from the immediately preceding agent step
-                try:
+                with contextlib.suppress(Exception):
                     prev_meta = {}
-                    if "policy_response" in locals() and isinstance(
-                        policy_response.meta, dict
-                    ):  # type: ignore[name-defined]
+                    if "policy_response" in locals() and isinstance(policy_response.meta, dict):  # type: ignore[name-defined]
                         prev_meta = policy_response.meta
                     if prev_meta:
                         _info = dict(_info)
                         _info["meta"] = prev_meta
-                except Exception:
-                    pass
 
                 event_metadata = {
                     "op_index": op_idx,
@@ -1438,7 +1551,7 @@ async def execute_rollout(
                 )
                 indicator_val = 0
                 reward_stepwise = 0.0
-                decision_rewards_meta:
+                decision_rewards_meta: dict[str, Any] | None = None
                 if step_rewards_active:
                     decision_actions = _summarize_tool_calls(pending_tool_calls)
                     stepwise_info, decision_record, stats = compute_stepwise_reward(
@@ -1452,25 +1565,22 @@ async def execute_rollout(
                     reward_stepwise = float(stats.get("reward", 0.0))
                     stepwise_indicator_sum += float(stats.get("indicator", 0.0))
                     stepwise_reward_sum += reward_stepwise
-                    stepwise_new_achievements_total += int(
-                        stats.get("new_achievements_count", 0.0)
-                    )
-                    if not isinstance(_info, dict):
-                        _info = {}
-                    else:
-                        _info = dict(_info)
+                    stepwise_new_achievements_total += int(stats.get("new_achievements_count", 0.0))
+                    _info = {} if not isinstance(_info, dict) else dict(_info)
                     _info["stepwise"] = stepwise_info
                     # Compute decision-level rewards (absolute vs unique) and attach to metadata
-                    try:
+                    with contextlib.suppress(Exception):
                         turned_true = set(stepwise_info.get("new_achievements") or [])
                         seen_before = set(episode_seen_achievements)
-                        new_unique = sorted(
+                        new_unique = sorted(turned_true - seen_before)
                         ach_delta = int(len(turned_true))
                         unique_delta = int(len(new_unique))
                         # Prepare stable lists for logging/metadata
-                        all_list = sorted(
+                        all_list = sorted(turned_true)
                         # Ensure nested meta exists
-                        meta_block =
+                        meta_block = (
+                            _info.get("meta") if isinstance(_info.get("meta"), dict) else {}
+                        )
                         decision_rewards = {
                             "turn": int(decision_index),
                             "ach_delta": ach_delta,
@@ -1483,9 +1593,6 @@ async def execute_rollout(
                         _info["meta"] = meta_block
                         # Update episode-level seen set after attributing uniqueness to this decision
                         episode_seen_achievements.update(turned_true)
-                    except Exception:
-                        # Best-effort; do not block rollout on metadata computation
-                        pass
                     decision_samples.append(decision_record)
                     prev_achievements = new_achievement_state
 
@@ -1502,6 +1609,32 @@ async def execute_rollout(
                     truncated=env_response.truncated,
                     info=_info,
                 )
+                # Log summarized env application of tool calls and immediate reward/done
+                with contextlib.suppress(Exception):
+                    _tc = pending_tool_calls or []
+                    _summary = []
+                    for _item in (_tc if isinstance(_tc, list) else []):
+                        try:
+                            if isinstance(_item, dict):
+                                _tool = _item.get("tool")
+                                _args = _item.get("args")
+                                _keys = list(_args.keys()) if isinstance(_args, dict) else []
+                                _summary.append({"tool": _tool, "args_keys": _keys})
+                        except Exception:
+                            continue
+                    _rid = getattr(request, "run_id", None)
+                    logger.info(
+                        "ENV_APPLY: run_id=%s tool_calls=%d reward=%s done=%s summary=%s",
+                        _rid,
+                        len(_tc),
+                        str(env_response.reward),
+                        str(env_response.done),
+                        _summary,
+                    )
+                    print(
+                        f"[rollout] env apply run_id={_rid} tool_calls={len(_tc)} reward={env_response.reward} done={env_response.done} summary={_summary}",
+                        flush=True,
+                    )
                 trajectory_steps.append(step)
 
                 if env_response.reward is not None:
@@ -1517,13 +1650,11 @@ async def execute_rollout(
                     if request.on_done == "reset":
                         # Reset environment
                         from .environment_routes import (
-                            reset_environment,
                             EnvResetRequest,
+                            reset_environment,
                         )
 
-                        reset_response = await reset_environment(
-                            EnvResetRequest(env_id=env_id)
-                        )
+                        reset_response = await reset_environment(EnvResetRequest(env_id=env_id))
                         current_obs = reset_response.observation
                     elif request.on_done == "terminate":
                         break
@@ -1542,25 +1673,19 @@ async def execute_rollout(
             and isinstance(last_policy_meta["timing"], dict)
             and "decision_ms" not in last_policy_meta["timing"]
         ):
-            try:
+            with contextlib.suppress(Exception):
                 final_now = last_env_step_completed_ts or _time.perf_counter()
-                final_decision_ms = max(
-                    0.0, (final_now - float(last_agent_response_ts)) * 1000.0
-                )
+                final_decision_ms = max(0.0, (final_now - float(last_agent_response_ts)) * 1000.0)
                 timing_final = last_policy_meta.setdefault("timing", {})
                 timing_final["decision_ms"] = final_decision_ms
                 if last_env_step_ms is not None:
-                    timing_final.setdefault(
-                        "env_step_ms", float(last_env_step_ms)
-                    )
+                    timing_final.setdefault("env_step_ms", float(last_env_step_ms))
                     timing_final.setdefault(
                         "overhead_ms",
                         max(0.0, final_decision_ms - float(last_env_step_ms)),
                     )
                 else:
                     timing_final.setdefault("overhead_ms", 0.0)
-            except Exception:
-                pass
 
         # Build trajectory
         trajectory = RolloutTrajectory(
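For reference, the timing fields threaded through the hunks above relate to one another in a simple way: `decision_ms` is the wall-clock gap between the previous agent response and the end of the environment step, and `overhead_ms` is whatever remains once the measured `env_step_ms` is subtracted. A toy calculation with made-up numbers (the values are hypothetical, only the variable names come from the diff):

```python
# Hypothetical perf_counter()-style readings and a measured env-step duration.
last_agent_response_ts = 10.000   # seconds, when the agent replied
env_step_end = 10.250             # seconds, when the env step finished
env_step_duration_ms = 180.0      # milliseconds, measured env step

decision_ms = max(0.0, (env_step_end - last_agent_response_ts) * 1000.0)  # 250.0
overhead_ms = max(0.0, decision_ms - env_step_duration_ms)                # 70.0
print(decision_ms, overhead_ms)
```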
@@ -1583,28 +1708,35 @@ async def execute_rollout(
         # Environment-specific: Log summary if available
         try:
             # Check if this is a Wordle environment and use Wordle helpers (lazy import)
+            wordle_wrapper_cls = None
             try:
-                from .envs.wordle.environment import WordleEnvironmentWrapper
+                from .envs.wordle.environment import WordleEnvironmentWrapper
                 from .envs.wordle.helpers import (
                     get_wordle_rollout_summary,
                     log_wordle_rollout_summary,
                 )
+
+                wordle_wrapper_cls = WordleEnvironmentWrapper
             except Exception:
-
+                wordle_wrapper_cls = None  # type: ignore[assignment]
                 get_wordle_rollout_summary = None  # type: ignore
                 log_wordle_rollout_summary = None  # type: ignore
 
-            is_wordle =
+            is_wordle = wordle_wrapper_cls is not None and isinstance(
+                env_handle.env,
+                wordle_wrapper_cls,  # type: ignore[arg-type]
+            )
             if is_wordle:
                 # Convert trajectory steps to expected format
                 formatted_steps = []
                 for step in trajectory_steps:
                     formatted_steps.append({"tool_calls": step.tool_calls or []})
 
-                if
-
-
-
+                if (
+                    get_wordle_rollout_summary is not None
+                    and log_wordle_rollout_summary is not None
+                ):
+                    summary = get_wordle_rollout_summary(formatted_steps, current_obs, env_handle)
                    log_wordle_rollout_summary(request.run_id, summary)
         except ImportError:
             # Wordle helpers not available, skip Wordle-specific logging
@@ -1642,27 +1774,24 @@ async def execute_rollout(
         logger.error(f"Rollout failed for run {request.run_id}: {e}")
         registry.abort_run(request.run_id)
         if decision_open:
-            try:
+            with contextlib.suppress(Exception):
                 await tracing_context.end_decision()
-            except Exception:
-                pass
             decision_open = False
         if not finalized:
-            try:
+            session_trace = None
+            with contextlib.suppress(Exception):
                 session_trace = await tracing_context.finalize(
                     total_reward=total_reward,
                     achievement_state=prev_achievements,
                     total_steps=len(trajectory_steps),
                 )
-            except Exception:
-                session_trace = None
            finalized = True
-        raise HTTPException(status_code=500, detail=str(e))
+        raise HTTPException(status_code=500, detail=str(e)) from e
     finally:
         # Ensure any environment created for this rollout is terminated (no reuse across rollouts)
         try:
             if created_env_id:
-                from .environment_routes import
+                from .environment_routes import EnvTerminateRequest, terminate_environment
 
                 await terminate_environment(EnvTerminateRequest(env_id=created_env_id))
                 logger.info(
@@ -1671,46 +1800,37 @@ async def execute_rollout(
                     str(env_seed_used) if env_seed_used is not None else "unknown",
                 )
                 # Verify removal from registry
-                try:
+                with contextlib.suppress(Exception):
                     _post = registry.get_env(created_env_id)
                     logger.info(
                         "ROLL_OUT: env_killed=%s (post_lookup=%s)",
                         str(_post is None),
                         str(_post),
                     )
-                except Exception:
-                    pass
         except Exception as _te:
-            logger.warning(
-                f"ROLL_OUT: failed to terminate environment {created_env_id}: {_te}"
-            )
+            logger.warning(f"ROLL_OUT: failed to terminate environment {created_env_id}: {_te}")
 
         # Best-effort policy cleanup if we created one (avoid reuse across rollouts)
-        try:
+        with contextlib.suppress(Exception):
             if created_policy_id:
-                from .policy_routes import
+                from .policy_routes import PolicyTerminateRequest, terminate_policy
 
                 await terminate_policy(PolicyTerminateRequest(policy_id=created_policy_id))
                 logger.info("ROLL_OUT: terminated policy policy_id=%s", str(created_policy_id))
-        except Exception:
-            pass
 
         if not finalized:
-            try:
+            session_trace = None
+            with contextlib.suppress(Exception):
                 session_trace = await tracing_context.finalize(
                     total_reward=total_reward,
                     achievement_state=prev_achievements,
                     total_steps=len(trajectory_steps),
                 )
-            except Exception:
-                session_trace = None
             finalized = True
 
-        try:
+        with contextlib.suppress(Exception):
             _clear_seed_side_effects()
             logger.info("ROLL_OUT: RNG seed terminated/cleared before conclusion")
-        except Exception:
-            pass
 
 
 @router.post("/run/abort", response_model=RunAbortResponse)
|