synth-ai 0.2.9.dev7__py3-none-any.whl → 0.2.9.dev8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic.
- examples/__init__.py +16 -0
- examples/crafter_debug_render.py +8 -11
- examples/qwen_coder/README.md +102 -0
- examples/qwen_coder/_shared.py +113 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +61 -0
- examples/qwen_coder/configs/coder_lora_4b.toml +57 -0
- examples/qwen_coder/configs/coder_lora_small.toml +58 -0
- examples/qwen_coder/generate_dataset.py +98 -0
- examples/qwen_coder/infer_ft_smoke.py +64 -0
- examples/qwen_coder/infer_prod_proxy.py +73 -0
- examples/qwen_coder/infer_via_synth.py +87 -0
- examples/qwen_coder/scripts/infer_coder.sh +18 -0
- examples/qwen_coder/scripts/train_coder_30b.sh +21 -0
- examples/qwen_coder/sft_full_17b.py +103 -0
- examples/qwen_coder/sft_lora_30b.py +110 -0
- examples/qwen_coder/subset_jsonl.py +38 -0
- examples/qwen_coder/validate_jsonl.py +59 -0
- examples/rl/run_eval.py +36 -37
- examples/rl/run_rl_and_save.py +5 -5
- examples/rl/task_app/math_single_step.py +65 -43
- examples/rl/task_app/math_task_app.py +3 -3
- examples/sft/README.md +139 -0
- examples/sft/configs/crafter_fft_qwen0p6b.toml +44 -0
- examples/sft/configs/crafter_lora_qwen0p6b.toml +45 -0
- examples/sft/evaluate.py +117 -0
- examples/sft/export_dataset.py +117 -0
- examples/sft/generate_traces.py +162 -0
- examples/swe/__init__.py +12 -0
- examples/swe/task_app/README.md +105 -0
- examples/swe/task_app/__init__.py +2 -0
- examples/swe/task_app/grpo_swe_mini.py +571 -0
- examples/swe/task_app/grpo_swe_mini_task_app.py +136 -0
- examples/swe/task_app/hosted/README.md +173 -0
- examples/swe/task_app/hosted/__init__.py +5 -0
- examples/swe/task_app/hosted/branching.py +143 -0
- examples/swe/task_app/hosted/environment_routes.py +1289 -0
- examples/swe/task_app/hosted/envs/__init__.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/__init__.py +6 -0
- examples/swe/task_app/hosted/envs/crafter/app.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/environment.py +522 -0
- examples/swe/task_app/hosted/envs/crafter/policy.py +478 -0
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +108 -0
- examples/swe/task_app/hosted/envs/crafter/shared.py +305 -0
- examples/swe/task_app/hosted/envs/crafter/tools.py +47 -0
- examples/swe/task_app/hosted/envs/mini_swe/__init__.py +8 -0
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +1164 -0
- examples/swe/task_app/hosted/envs/mini_swe/policy.py +355 -0
- examples/swe/task_app/hosted/envs/mini_swe/shared.py +83 -0
- examples/swe/task_app/hosted/envs/mini_swe/tools.py +96 -0
- examples/swe/task_app/hosted/hosted_app.py +204 -0
- examples/swe/task_app/hosted/inference/__init__.py +5 -0
- examples/swe/task_app/hosted/inference/openai_client.py +618 -0
- examples/swe/task_app/hosted/main.py +100 -0
- examples/swe/task_app/hosted/policy_routes.py +1079 -0
- examples/swe/task_app/hosted/registry.py +195 -0
- examples/swe/task_app/hosted/rollout.py +1869 -0
- examples/swe/task_app/hosted/storage/__init__.py +5 -0
- examples/swe/task_app/hosted/storage/volume.py +211 -0
- examples/swe/task_app/hosted/test_agents.py +161 -0
- examples/swe/task_app/hosted/test_service.py +137 -0
- examples/swe/task_app/hosted/utils.py +62 -0
- examples/vlm/README.md +68 -0
- examples/vlm/configs/crafter_vlm_gpt4o.toml +44 -0
- examples/vlm/crafter_image_only_agent.py +207 -0
- examples/vlm/crafter_openai_vlm_agent.py +277 -0
- examples/vlm/filter_image_rows.py +63 -0
- examples/vlm/run_crafter_vlm_benchmark.py +316 -0
- examples/warming_up_to_rl/analyze_trace_db.py +5 -5
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +11 -1
- examples/warming_up_to_rl/export_trace_sft.py +78 -21
- examples/warming_up_to_rl/groq_test.py +4 -4
- examples/warming_up_to_rl/manage_secrets.py +13 -18
- examples/warming_up_to_rl/run_eval.py +42 -44
- examples/warming_up_to_rl/run_fft_and_save.py +11 -16
- examples/warming_up_to_rl/run_local_rollout.py +1 -3
- examples/warming_up_to_rl/run_local_rollout_modal.py +2 -4
- examples/warming_up_to_rl/run_local_rollout_parallel.py +1 -4
- examples/warming_up_to_rl/run_local_rollout_traced.py +3 -5
- examples/warming_up_to_rl/run_rl_and_save.py +5 -6
- examples/warming_up_to_rl/run_rollout_remote.py +8 -10
- examples/warming_up_to_rl/task_app/README.md +6 -2
- examples/warming_up_to_rl/task_app/grpo_crafter.py +234 -35
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +2 -3
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +9 -11
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +131 -114
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +101 -41
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +73 -51
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +14 -6
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +16 -16
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +32 -34
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +94 -31
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +0 -2
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +303 -203
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +21 -23
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +328 -225
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +13 -13
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +4 -3
- synth/__init__.py +14 -0
- synth_ai/__init__.py +26 -4
- synth_ai/api/models/supported.py +376 -0
- synth_ai/api/train/builders.py +128 -21
- synth_ai/api/train/cli.py +80 -64
- synth_ai/api/train/config_finder.py +7 -2
- synth_ai/api/train/env_resolver.py +1 -1
- synth_ai/api/train/pollers.py +2 -1
- synth_ai/api/train/supported_algos.py +139 -0
- synth_ai/api/train/task_app.py +1 -2
- synth_ai/api/train/utils.py +13 -44
- synth_ai/cli/__init__.py +8 -0
- synth_ai/cli/_modal_wrapper.py +28 -0
- synth_ai/cli/_typer_patch.py +49 -0
- synth_ai/cli/balance.py +1 -2
- synth_ai/cli/calc.py +1 -1
- synth_ai/cli/demo.py +2 -1
- synth_ai/cli/recent.py +2 -2
- synth_ai/cli/rl_demo.py +2 -1
- synth_ai/cli/root.py +11 -13
- synth_ai/cli/status.py +2 -2
- synth_ai/cli/task_apps.py +529 -179
- synth_ai/cli/traces.py +6 -4
- synth_ai/cli/watch.py +12 -18
- synth_ai/demo_registry.py +1 -1
- synth_ai/demos/core/cli.py +36 -43
- synth_ai/demos/demo_task_apps/__init__.py +3 -3
- synth_ai/demos/demo_task_apps/core.py +17 -25
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +3 -4
- synth_ai/demos/demo_task_apps/math/app.py +2 -1
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +3 -4
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +16 -18
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -1
- synth_ai/environments/examples/crafter_classic/environment.py +76 -1
- synth_ai/environments/reproducibility/tree.py +2 -5
- synth_ai/environments/service/app.py +11 -12
- synth_ai/environments/service/core_routes.py +4 -7
- synth_ai/environments/stateful/engine.py +1 -1
- synth_ai/environments/tasks/core.py +1 -0
- synth_ai/environments/tasks/filters.py +5 -6
- synth_ai/environments/tasks/utils.py +4 -5
- synth_ai/handshake.py +9 -9
- synth_ai/http.py +1 -1
- synth_ai/http_client.py +18 -10
- synth_ai/inference/client.py +15 -5
- synth_ai/jobs/client.py +78 -83
- synth_ai/learning/__init__.py +41 -6
- synth_ai/learning/algorithms.py +14 -0
- synth_ai/learning/client.py +91 -24
- synth_ai/learning/config.py +2 -38
- synth_ai/learning/ft_client.py +4 -59
- synth_ai/learning/health.py +5 -6
- synth_ai/learning/jobs.py +31 -47
- synth_ai/{rl → learning/rl}/__init__.py +14 -4
- synth_ai/learning/rl/client.py +267 -0
- synth_ai/learning/rl/config.py +31 -0
- synth_ai/{rl → learning/rl}/contracts.py +5 -8
- synth_ai/{rl → learning/rl}/env_keys.py +39 -15
- synth_ai/learning/rl/secrets.py +13 -0
- synth_ai/learning/rl_client.py +2 -281
- synth_ai/learning/sft/__init__.py +29 -0
- synth_ai/learning/sft/client.py +68 -0
- synth_ai/learning/sft/config.py +270 -0
- synth_ai/learning/sft/data.py +295 -0
- synth_ai/learning/sse.py +25 -24
- synth_ai/learning/validators.py +25 -28
- synth_ai/lm/__init__.py +21 -47
- synth_ai/main.py +4 -0
- synth_ai/task/__init__.py +25 -27
- synth_ai/task/apps/__init__.py +7 -8
- synth_ai/task/auth.py +8 -8
- synth_ai/task/client.py +14 -14
- synth_ai/task/contracts.py +36 -35
- synth_ai/task/datasets.py +6 -5
- synth_ai/task/errors.py +10 -10
- synth_ai/task/health.py +17 -9
- synth_ai/task/json.py +58 -23
- synth_ai/task/proxy.py +13 -9
- synth_ai/task/rubrics.py +16 -15
- synth_ai/task/server.py +12 -12
- synth_ai/task/tracing_utils.py +4 -4
- synth_ai/task/vendors.py +5 -6
- synth_ai/tracing_v3/__init__.py +2 -0
- synth_ai/tracing_v3/abstractions.py +21 -4
- synth_ai/tracing_v3/decorators.py +18 -16
- synth_ai/tracing_v3/hooks.py +5 -5
- synth_ai/tracing_v3/llm_call_record_helpers.py +6 -6
- synth_ai/tracing_v3/session_tracer.py +40 -14
- synth_ai/tracing_v3/storage/base.py +85 -0
- synth_ai/tracing_v3/storage/config.py +21 -8
- synth_ai/tracing_v3/storage/factory.py +10 -7
- synth_ai/tracing_v3/storage/utils.py +4 -2
- synth_ai/tracing_v3/turso/daemon.py +7 -2
- synth_ai/tracing_v3/turso/models.py +2 -2
- synth_ai/tracing_v3/turso/native_manager.py +1173 -0
- synth_ai/tracing_v3/utils.py +4 -4
- synth_ai/v0/api/__init__.py +8 -0
- synth_ai/v0/api/models/__init__.py +8 -0
- synth_ai/v0/api/models/supported.py +8 -0
- synth_ai/v0/config/__init__.py +15 -0
- synth_ai/v0/config/base_url.py +12 -0
- synth_ai/v0/lm/__init__.py +51 -0
- synth_ai/{lm → v0/lm}/caching/ephemeral.py +2 -2
- synth_ai/{lm → v0/lm}/caching/handler.py +4 -4
- synth_ai/{lm → v0/lm}/caching/initialize.py +1 -1
- synth_ai/{lm → v0/lm}/caching/persistent.py +1 -1
- synth_ai/{lm → v0/lm}/config.py +6 -1
- synth_ai/{lm → v0/lm}/core/all.py +9 -9
- synth_ai/{lm → v0/lm}/core/main.py +6 -6
- synth_ai/{lm → v0/lm}/core/main_v3.py +10 -10
- synth_ai/{lm → v0/lm}/core/synth_models.py +2 -14
- synth_ai/{lm → v0/lm}/core/vendor_clients.py +2 -2
- synth_ai/{lm → v0/lm}/overrides.py +2 -2
- synth_ai/{lm → v0/lm}/provider_support/anthropic.py +4 -4
- synth_ai/{lm → v0/lm}/provider_support/openai.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/handler.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/rehabilitate.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/core/anthropic_api.py +9 -9
- synth_ai/{lm → v0/lm}/vendors/core/gemini_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/mistral_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/openai_api.py +10 -10
- synth_ai/{lm → v0/lm}/vendors/openai_standard.py +8 -8
- synth_ai/{lm → v0/lm}/vendors/openai_standard_responses.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/custom_endpoint.py +3 -3
- synth_ai/{lm → v0/lm}/vendors/supported/deepseek.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/grok.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/groq.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/ollama.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/openrouter.py +3 -3
- synth_ai/{lm → v0/lm}/vendors/supported/together.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/synth_client.py +1 -1
- synth_ai/v0/tracing_v3/__init__.py +10 -0
- synth_ai/v0/tracing_v3/abstractions.py +3 -0
- synth_ai/v0/tracing_v3/decorators.py +3 -0
- synth_ai/v0/tracing_v3/llm_call_record_helpers.py +3 -0
- synth_ai/v0/tracing_v3/session_tracer.py +3 -0
- synth_ai-0.2.9.dev8.dist-info/METADATA +191 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev8.dist-info}/RECORD +268 -238
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev8.dist-info}/top_level.txt +1 -0
- examples/common_old/backend.py +0 -20
- examples/evals_old/README.md +0 -98
- examples/evals_old/__init__.py +0 -6
- examples/evals_old/compare_models.py +0 -1038
- examples/evals_old/example_log.md +0 -145
- examples/evals_old/run_demo.sh +0 -126
- examples/evals_old/trace_analysis.py +0 -270
- examples/finetuning_old/_backup_synth_qwen/config.toml +0 -29
- examples/finetuning_old/_backup_synth_qwen/example_log.md +0 -324
- examples/finetuning_old/_backup_synth_qwen/filter_traces.py +0 -60
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +0 -243
- examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +0 -109
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +0 -1924
- examples/finetuning_old/_backup_synth_qwen/readme.md +0 -49
- examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +0 -114
- examples/finetuning_old/_backup_synth_qwen/run_demo.sh +0 -195
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +0 -119
- examples/finetuning_old/synth_qwen_v1/README.md +0 -68
- examples/finetuning_old/synth_qwen_v1/filter_traces.py +0 -60
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +0 -243
- examples/finetuning_old/synth_qwen_v1/finetune.py +0 -46
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +0 -71
- examples/finetuning_old/synth_qwen_v1/infer.py +0 -36
- examples/finetuning_old/synth_qwen_v1/poll.py +0 -46
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +0 -35
- examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +0 -109
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +0 -1933
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +0 -210
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +0 -237
- examples/finetuning_old/synth_qwen_v1/upload_data.py +0 -34
- examples/finetuning_old/synth_qwen_v1/util.py +0 -152
- examples/rl_old/task_app.py +0 -1131
- examples/warming_up_to_rl/old/event_rewards.md +0 -234
- examples/warming_up_to_rl/old/notes.md +0 -73
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +0 -738
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +0 -580
- synth_ai/experimental/synth_oss.py +0 -445
- synth_ai/learning/filtering.py +0 -0
- synth_ai/learning/offline/dpo.py +0 -0
- synth_ai/learning/offline/providers.py +0 -7
- synth_ai/learning/offline/sft.py +0 -0
- synth_ai/learning/offline/shared.py +0 -0
- synth_ai/learning/online/grpo.py +0 -0
- synth_ai/learning/online/irft.py +0 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +0 -168
- synth_ai/learning/prompts/gepa.py +0 -0
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +0 -211
- synth_ai/learning/prompts/mipro.py +0 -289
- synth_ai/learning/prompts/random_search.py +0 -249
- synth_ai/learning/prompts/run_mipro_banking77.py +0 -172
- synth_ai/learning/prompts/run_random_search_banking77.py +0 -329
- synth_ai/rl/secrets.py +0 -19
- synth_ai/scripts/verify_rewards.py +0 -100
- synth_ai/tracing/__init__.py +0 -30
- synth_ai/tracing_v1/__init__.py +0 -33
- synth_ai/tracing_v3/turso/__init__.py +0 -25
- synth_ai/tracing_v3/turso/manager.py +0 -838
- synth_ai/zyk/__init__.py +0 -30
- synth_ai-0.2.9.dev7.dist-info/METADATA +0 -131
- /synth_ai/{lm → v0/lm}/caching/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/dbs.py +0 -0
- /synth_ai/{lm → v0/lm}/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/core/exceptions.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/monitor.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/statefulness.py +0 -0
- /synth_ai/{lm → v0/lm}/injection.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/suppress_logging.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/inject.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/base.py +0 -0
- /synth_ai/{lm → v0/lm}/unified_interface.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/base.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/synth_dev_api.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/ollama.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/retries.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/supported/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/warmup.py +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev8.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev8.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev8.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,1869 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import contextlib
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
import time as _time
|
|
8
|
+
from datetime import datetime
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from fastapi import APIRouter, HTTPException, Request, status
|
|
12
|
+
from pydantic import BaseModel
|
|
13
|
+
from synth_ai.lm.vendors.base import BaseLMResponse
|
|
14
|
+
from synth_ai.task.tracing_utils import unique_sft_path
|
|
15
|
+
from synth_ai.tracing_v3.abstractions import EnvironmentEvent, LMCAISEvent, TimeRecord
|
|
16
|
+
from synth_ai.tracing_v3.llm_call_record_helpers import create_llm_call_record_from_response
|
|
17
|
+
from synth_ai.tracing_v3.session_tracer import SessionTracer
|
|
18
|
+
|
|
19
|
+
from .registry import registry
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# --- Seeding utilities (robust, optional deps) ---
|
|
25
|
+
def _set_global_seed(seed_value: int) -> dict[str, Any]:
|
|
26
|
+
"""Set global RNG seeds across common libraries; return details for logging/restoration.
|
|
27
|
+
|
|
28
|
+
Returns a dict containing which libraries were seeded and prior states if obtainable.
|
|
29
|
+
"""
|
|
30
|
+
seeded: dict[str, Any] = {"seed": int(seed_value), "libs": []}
|
|
31
|
+
with contextlib.suppress(Exception):
|
|
32
|
+
import random as _random # type: ignore
|
|
33
|
+
|
|
34
|
+
_random.seed(seed_value)
|
|
35
|
+
seeded["libs"].append("random")
|
|
36
|
+
with contextlib.suppress(Exception):
|
|
37
|
+
import numpy as _np # type: ignore
|
|
38
|
+
|
|
39
|
+
_np.random.seed(seed_value)
|
|
40
|
+
seeded["libs"].append("numpy")
|
|
41
|
+
with contextlib.suppress(Exception):
|
|
42
|
+
import torch as _torch # type: ignore
|
|
43
|
+
|
|
44
|
+
if hasattr(_torch, "manual_seed"):
|
|
45
|
+
_torch.manual_seed(seed_value)
|
|
46
|
+
seeded["libs"].append("torch")
|
|
47
|
+
# Make CUDA deterministic if present (best-effort)
|
|
48
|
+
with contextlib.suppress(Exception):
|
|
49
|
+
if getattr(_torch, "cuda", None) and _torch.cuda.is_available():
|
|
50
|
+
_torch.cuda.manual_seed_all(seed_value)
|
|
51
|
+
seeded.setdefault("cuda", True)
|
|
52
|
+
# CUDNN deterministic flags (optional)
|
|
53
|
+
with contextlib.suppress(Exception):
|
|
54
|
+
if getattr(_torch, "backends", None) and getattr(_torch.backends, "cudnn", None):
|
|
55
|
+
_torch.backends.cudnn.deterministic = True # type: ignore[attr-defined]
|
|
56
|
+
_torch.backends.cudnn.benchmark = False # type: ignore[attr-defined]
|
|
57
|
+
return seeded
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _clear_seed_side_effects() -> None:
|
|
61
|
+
"""Best-effort cleanup to avoid global deterministic side-effects between requests."""
|
|
62
|
+
# We cannot truly restore prior RNG states without capturing them; we just avoid
|
|
63
|
+
# leaving aggressive deterministic flags enabled where it matters.
|
|
64
|
+
with contextlib.suppress(Exception):
|
|
65
|
+
import torch as _torch # type: ignore
|
|
66
|
+
|
|
67
|
+
with contextlib.suppress(Exception):
|
|
68
|
+
if getattr(_torch, "backends", None) and getattr(_torch.backends, "cudnn", None):
|
|
69
|
+
# Re-enable cudnn.benchmark default True only if it was True; safest is False -> leave as is.
|
|
70
|
+
# We'll keep deterministic False to avoid global impact; benchmark left False for stability.
|
|
71
|
+
_torch.backends.cudnn.deterministic = False # type: ignore[attr-defined]
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
router = APIRouter()
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class RolloutEnvSpec(BaseModel):
|
|
78
|
+
env_id: str | None = None
|
|
79
|
+
env_name: str | None = None
|
|
80
|
+
config: dict[str, Any] = {}
|
|
81
|
+
seed: int | None = None
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class RolloutPolicySpec(BaseModel):
|
|
85
|
+
policy_id: str | None = None
|
|
86
|
+
policy_name: str | None = None
|
|
87
|
+
config: dict[str, Any] = {}
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class RolloutBranchConfig(BaseModel):
|
|
91
|
+
branch_every_n_steps: int = 0
|
|
92
|
+
branch_on_condition: str | None = None
|
|
93
|
+
max_branches: int = 0
|
|
94
|
+
branch_policy: bool = False
|
|
95
|
+
branch_env: bool = False
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class RolloutRecordConfig(BaseModel):
|
|
99
|
+
trajectories: bool = True
|
|
100
|
+
logprobs: bool = False
|
|
101
|
+
value: bool = False
|
|
102
|
+
return_trace: bool = False
|
|
103
|
+
trace_format: str = "compact"
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class RolloutSafetyConfig(BaseModel):
|
|
107
|
+
max_ops: int = 100000
|
|
108
|
+
max_time_s: float = 3600.0
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class RolloutRequest(BaseModel):
|
|
112
|
+
run_id: str
|
|
113
|
+
env: RolloutEnvSpec
|
|
114
|
+
policy: RolloutPolicySpec
|
|
115
|
+
ops: list[str] # ["agent", "env", ...]
|
|
116
|
+
record: RolloutRecordConfig = RolloutRecordConfig()
|
|
117
|
+
on_done: str = "reset" # "reset" | "terminate"
|
|
118
|
+
branch: RolloutBranchConfig | None = None
|
|
119
|
+
safety: RolloutSafetyConfig = RolloutSafetyConfig()
|
|
120
|
+
# Optional run/session context
|
|
121
|
+
training_session_id: str | None = None
|
|
122
|
+
synth_base_url: str | None = None
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class RolloutStep(BaseModel):
|
|
126
|
+
obs: dict[str, Any]
|
|
127
|
+
tool_calls: list[dict[str, Any]]
|
|
128
|
+
reward: float | None = None
|
|
129
|
+
done: bool = False
|
|
130
|
+
truncated: bool | None = None
|
|
131
|
+
logprob: float | None = None
|
|
132
|
+
value: float | None = None
|
|
133
|
+
info: dict[str, Any] | None = None
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class RolloutTrajectory(BaseModel):
|
|
137
|
+
env_id: str
|
|
138
|
+
policy_id: str
|
|
139
|
+
steps: list[RolloutStep]
|
|
140
|
+
final: dict[str, Any] | None = None
|
|
141
|
+
length: int
|
|
142
|
+
decision_samples: list[dict[str, Any]] | None = None
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def compute_stepwise_reward(
|
|
146
|
+
prev_achievements: dict[str, bool],
|
|
147
|
+
new_achievements: dict[str, bool],
|
|
148
|
+
decision_index: int,
|
|
149
|
+
actions_summary: list[dict[str, Any]],
|
|
150
|
+
indicator_lambda: float,
|
|
151
|
+
) -> tuple[dict[str, Any], dict[str, Any], dict[str, float]]:
|
|
152
|
+
"""Compute stepwise reward metadata given achievement states before/after a decision."""
|
|
153
|
+
|
|
154
|
+
prev_map = prev_achievements or {}
|
|
155
|
+
next_map = new_achievements or {}
|
|
156
|
+
|
|
157
|
+
unlocked = [name for name, value in next_map.items() if value and not prev_map.get(name, False)]
|
|
158
|
+
indicator = 1 if unlocked else 0
|
|
159
|
+
reward_value = float(indicator_lambda) * indicator
|
|
160
|
+
|
|
161
|
+
stepwise_info = {
|
|
162
|
+
"decision_index": decision_index,
|
|
163
|
+
"indicator": indicator,
|
|
164
|
+
"new_achievements": unlocked,
|
|
165
|
+
"reward": reward_value,
|
|
166
|
+
}
|
|
167
|
+
decision_sample = {
|
|
168
|
+
"decision_index": decision_index,
|
|
169
|
+
"indicator": indicator,
|
|
170
|
+
"r_i": reward_value,
|
|
171
|
+
"actions": actions_summary,
|
|
172
|
+
}
|
|
173
|
+
stats = {
|
|
174
|
+
"indicator": float(indicator),
|
|
175
|
+
"reward": reward_value,
|
|
176
|
+
"new_achievements_count": float(len(unlocked)),
|
|
177
|
+
}
|
|
178
|
+
return stepwise_info, decision_sample, stats
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
class RolloutMetrics(BaseModel):
|
|
182
|
+
episode_returns: list[float]
|
|
183
|
+
mean_return: float
|
|
184
|
+
num_steps: int
|
|
185
|
+
num_episodes: int = 0
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
class RolloutResponse(BaseModel):
|
|
189
|
+
run_id: str
|
|
190
|
+
trajectories: list[RolloutTrajectory]
|
|
191
|
+
branches: dict[str, list[str]] = {}
|
|
192
|
+
metrics: RolloutMetrics
|
|
193
|
+
aborted: bool = False
|
|
194
|
+
ops_executed: int = 0
|
|
195
|
+
trace: dict[str, Any] | None = None
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
class RolloutTracingContext:
|
|
199
|
+
"""Helper managing tracing_v3 recording and optional SFT dumps for a rollout."""
|
|
200
|
+
|
|
201
|
+
def __init__(
|
|
202
|
+
self,
|
|
203
|
+
tracer: SessionTracer | None,
|
|
204
|
+
request: RolloutRequest,
|
|
205
|
+
fastapi_request: Request,
|
|
206
|
+
) -> None:
|
|
207
|
+
self.tracer = tracer
|
|
208
|
+
self.enabled = tracer is not None
|
|
209
|
+
self.request = request
|
|
210
|
+
self.fastapi_request = fastapi_request
|
|
211
|
+
self.run_id = request.run_id
|
|
212
|
+
self.current_step_id: str | None = None
|
|
213
|
+
self.current_turn: int | None = None
|
|
214
|
+
self.lm_calls_summary: list[dict[str, Any]] = []
|
|
215
|
+
self.decision_rewards: list[dict[str, Any]] = []
|
|
216
|
+
self.sft_records: list[dict[str, Any]] = []
|
|
217
|
+
self.latest_system_messages: list[str] = []
|
|
218
|
+
self.latest_user_messages: list[str] = []
|
|
219
|
+
self.latest_system_prompt_content: list[Any] = []
|
|
220
|
+
self.latest_user_prompt_content: list[Any] = []
|
|
221
|
+
self.trace_format = (
|
|
222
|
+
getattr(request.record, "trace_format", "compact") or "compact"
|
|
223
|
+
).lower()
|
|
224
|
+
self.return_trace = bool(getattr(request.record, "return_trace", False))
|
|
225
|
+
self.sft_output_dir = getattr(fastapi_request.app.state, "sft_output_dir", None)
|
|
226
|
+
self.session_trace = None
|
|
227
|
+
self.metadata_updates: dict[str, Any] = {}
|
|
228
|
+
self.policy_name = request.policy.policy_name or ""
|
|
229
|
+
self.env_name = request.env.env_name or ""
|
|
230
|
+
self.metadata_base: dict[str, Any] = {
|
|
231
|
+
"run_id": self.run_id,
|
|
232
|
+
"policy_name": self.policy_name,
|
|
233
|
+
"policy_id": request.policy.policy_id,
|
|
234
|
+
"env_name": self.env_name,
|
|
235
|
+
"env_id": request.env.env_id,
|
|
236
|
+
"seed": request.env.seed,
|
|
237
|
+
"training_session_id": request.training_session_id,
|
|
238
|
+
"synth_base_url": request.synth_base_url,
|
|
239
|
+
}
|
|
240
|
+
|
|
241
|
+
# Expose context for downstream calls inside this request lifecycle
|
|
242
|
+
fastapi_request.state.rollout_tracing = self
|
|
243
|
+
fastapi_request.state.rollout_run_id = self.run_id
|
|
244
|
+
|
|
245
|
+
async def start_session(self) -> None:
|
|
246
|
+
if not self.enabled or self.tracer is None:
|
|
247
|
+
return
|
|
248
|
+
try:
|
|
249
|
+
await self.tracer.initialize()
|
|
250
|
+
except Exception as exc:
|
|
251
|
+
logger.debug("TRACING_INIT_FAIL: %s", exc)
|
|
252
|
+
try:
|
|
253
|
+
await self.tracer.start_session(
|
|
254
|
+
session_id=self.run_id, metadata=dict(self.metadata_base)
|
|
255
|
+
)
|
|
256
|
+
except Exception as exc:
|
|
257
|
+
logger.warning("TRACING_START_FAIL: %s", exc)
|
|
258
|
+
self.enabled = False
|
|
259
|
+
self.tracer = None
|
|
260
|
+
|
|
261
|
+
async def start_decision(self, turn_number: int) -> None:
|
|
262
|
+
self.current_turn = turn_number
|
|
263
|
+
self.current_step_id = f"decision_{turn_number}"
|
|
264
|
+
if not self.enabled or self.tracer is None:
|
|
265
|
+
return
|
|
266
|
+
try:
|
|
267
|
+
await self.tracer.start_timestep(step_id=self.current_step_id, turn_number=turn_number)
|
|
268
|
+
except Exception as exc:
|
|
269
|
+
logger.debug("TRACING_STEP_START_FAIL: %s", exc)
|
|
270
|
+
|
|
271
|
+
async def end_decision(self) -> None:
|
|
272
|
+
if not self.enabled or self.tracer is None:
|
|
273
|
+
return
|
|
274
|
+
try:
|
|
275
|
+
await self.tracer.end_timestep(step_id=self.current_step_id)
|
|
276
|
+
except Exception as exc:
|
|
277
|
+
logger.debug("TRACING_STEP_END_FAIL: %s", exc)
|
|
278
|
+
finally:
|
|
279
|
+
self.current_step_id = None
|
|
280
|
+
|
|
281
|
+
def _message_metadata(self) -> dict[str, Any]:
|
|
282
|
+
return {
|
|
283
|
+
"turn": self.current_turn,
|
|
284
|
+
"step_id": self.current_step_id,
|
|
285
|
+
}
|
|
286
|
+
|
|
287
|
+
async def record_policy_prompts(
|
|
288
|
+
self,
|
|
289
|
+
system_messages: list[Any],
|
|
290
|
+
user_messages: list[Any],
|
|
291
|
+
) -> None:
|
|
292
|
+
self.latest_system_messages = [self._prompt_text(entry) for entry in system_messages]
|
|
293
|
+
self.latest_user_messages = [self._prompt_text(entry) for entry in user_messages]
|
|
294
|
+
self.latest_system_prompt_content = [
|
|
295
|
+
self._prompt_content(entry, role="system") for entry in system_messages
|
|
296
|
+
]
|
|
297
|
+
self.latest_user_prompt_content = [
|
|
298
|
+
self._prompt_content(entry, role="user") for entry in user_messages
|
|
299
|
+
]
|
|
300
|
+
if not self.enabled or self.tracer is None:
|
|
301
|
+
return
|
|
302
|
+
for entry in system_messages:
|
|
303
|
+
try:
|
|
304
|
+
await self.tracer.record_message(
|
|
305
|
+
content=self._prompt_payload(entry, role="system"),
|
|
306
|
+
message_type="policy_system_prompt",
|
|
307
|
+
metadata=self._message_metadata(),
|
|
308
|
+
)
|
|
309
|
+
except Exception as exc:
|
|
310
|
+
logger.debug("TRACING_SYSTEM_MSG_FAIL: %s", exc)
|
|
311
|
+
for entry in user_messages:
|
|
312
|
+
try:
|
|
313
|
+
await self.tracer.record_message(
|
|
314
|
+
content=self._prompt_payload(entry, role="user"),
|
|
315
|
+
message_type="policy_user_prompt",
|
|
316
|
+
metadata=self._message_metadata(),
|
|
317
|
+
)
|
|
318
|
+
except Exception as exc:
|
|
319
|
+
logger.debug("TRACING_USER_MSG_FAIL: %s", exc)
|
|
320
|
+
|
|
321
|
+
def _content_to_text(self, content: Any) -> str:
|
|
322
|
+
if isinstance(content, str):
|
|
323
|
+
return content
|
|
324
|
+
if isinstance(content, list):
|
|
325
|
+
parts: list[str] = []
|
|
326
|
+
for seg in content:
|
|
327
|
+
if isinstance(seg, dict):
|
|
328
|
+
text_val = seg.get("text") or seg.get("content")
|
|
329
|
+
if isinstance(text_val, str):
|
|
330
|
+
parts.append(text_val)
|
|
331
|
+
return "".join(parts)
|
|
332
|
+
if content is None:
|
|
333
|
+
return ""
|
|
334
|
+
return str(content)
|
|
335
|
+
|
|
336
|
+
def _prompt_text(self, entry: Any) -> str:
|
|
337
|
+
if isinstance(entry, dict):
|
|
338
|
+
text = entry.get("text")
|
|
339
|
+
if isinstance(text, str):
|
|
340
|
+
return text
|
|
341
|
+
content = entry.get("content")
|
|
342
|
+
return self._content_to_text(content)
|
|
343
|
+
return self._content_to_text(entry)
|
|
344
|
+
|
|
345
|
+
def _prompt_payload(self, entry: Any, *, role: str) -> dict[str, Any]:
|
|
346
|
+
if isinstance(entry, dict):
|
|
347
|
+
payload = dict(entry)
|
|
348
|
+
payload.setdefault("role", role)
|
|
349
|
+
return payload
|
|
350
|
+
return {
|
|
351
|
+
"role": role,
|
|
352
|
+
"text": self._prompt_text(entry),
|
|
353
|
+
"content": entry,
|
|
354
|
+
}
|
|
355
|
+
|
|
356
|
+
def _prompt_content(self, entry: Any, *, role: str) -> Any:
|
|
357
|
+
payload = self._prompt_payload(entry, role=role)
|
|
358
|
+
return payload.get("content", payload.get("text"))
|
|
359
|
+
|
|
360
|
+
def _content_has_image(self, content: Any) -> bool:
|
|
361
|
+
if isinstance(content, list):
|
|
362
|
+
return any(
|
|
363
|
+
isinstance(seg, dict)
|
|
364
|
+
and seg.get("type") in {"image", "image_url"}
|
|
365
|
+
for seg in content
|
|
366
|
+
)
|
|
367
|
+
if isinstance(content, dict):
|
|
368
|
+
if content.get("type") in {"image", "image_url"}:
|
|
369
|
+
return True
|
|
370
|
+
inner = content.get("content")
|
|
371
|
+
if isinstance(inner, list):
|
|
372
|
+
return any(
|
|
373
|
+
isinstance(seg, dict)
|
|
374
|
+
and seg.get("type") in {"image", "image_url"}
|
|
375
|
+
for seg in inner
|
|
376
|
+
)
|
|
377
|
+
return False
|
|
378
|
+
|
|
379
|
+
def _safe_json(self, payload: Any, limit: int = 4000) -> str:
|
|
380
|
+
try:
|
|
381
|
+
text = json.dumps(payload, ensure_ascii=False)
|
|
382
|
+
except Exception:
|
|
383
|
+
text = str(payload)
|
|
384
|
+
if len(text) > limit:
|
|
385
|
+
return text[:limit] + "…"
|
|
386
|
+
return text
|
|
387
|
+
|
|
388
|
+
async def record_tool_invocation(self, tool_calls: list[dict[str, Any]] | None) -> None:
|
|
389
|
+
if tool_calls is None:
|
|
390
|
+
return
|
|
391
|
+
if self.enabled and self.tracer is not None:
|
|
392
|
+
try:
|
|
393
|
+
await self.tracer.record_message(
|
|
394
|
+
content=self._safe_json(tool_calls),
|
|
395
|
+
message_type="policy_tool_call",
|
|
396
|
+
metadata=self._message_metadata(),
|
|
397
|
+
)
|
|
398
|
+
except Exception as exc:
|
|
399
|
+
logger.debug("TRACING_TOOL_MSG_FAIL: %s", exc)
|
|
400
|
+
|
|
401
|
+
async def _record_event(self, event: Any) -> int | None:
|
|
402
|
+
if not self.enabled or self.tracer is None:
|
|
403
|
+
return None
|
|
404
|
+
try:
|
|
405
|
+
return await self.tracer.record_event(event)
|
|
406
|
+
except Exception as exc:
|
|
407
|
+
logger.debug("TRACING_EVENT_FAIL: %s", exc)
|
|
408
|
+
return None
|
|
409
|
+
|
|
410
|
+
async def record_llm_call(
|
|
411
|
+
self,
|
|
412
|
+
*,
|
|
413
|
+
inference_request: dict[str, Any],
|
|
414
|
+
inference_response: dict[str, Any],
|
|
415
|
+
tool_calls: list[dict[str, Any]] | None,
|
|
416
|
+
provider: str,
|
|
417
|
+
model_name: str,
|
|
418
|
+
started_at: datetime,
|
|
419
|
+
completed_at: datetime,
|
|
420
|
+
latency_ms: int | None,
|
|
421
|
+
) -> None:
|
|
422
|
+
usage = inference_response.get("usage") or {}
|
|
423
|
+
input_tokens = usage.get("input_tokens") or usage.get("prompt_tokens")
|
|
424
|
+
output_tokens = usage.get("output_tokens") or usage.get("completion_tokens")
|
|
425
|
+
total_tokens = usage.get("total_tokens")
|
|
426
|
+
cost_usd = usage.get("cost_usd") or usage.get("cost") or usage.get("total_cost")
|
|
427
|
+
|
|
428
|
+
assistant_message = None
|
|
429
|
+
choices = inference_response.get("choices") or []
|
|
430
|
+
if choices:
|
|
431
|
+
assistant_message = choices[0].get("message") or {}
|
|
432
|
+
assistant_content = (
|
|
433
|
+
assistant_message.get("content") if isinstance(assistant_message, dict) else None
|
|
434
|
+
)
|
|
435
|
+
|
|
436
|
+
raw_response = self._content_to_text(assistant_content)
|
|
437
|
+
if not raw_response:
|
|
438
|
+
raw_response = self._safe_json(inference_response, limit=2000)
|
|
439
|
+
|
|
440
|
+
base_response = BaseLMResponse(
|
|
441
|
+
raw_response=raw_response,
|
|
442
|
+
tool_calls=assistant_message.get("tool_calls")
|
|
443
|
+
if isinstance(assistant_message, dict)
|
|
444
|
+
else None,
|
|
445
|
+
usage=usage or None,
|
|
446
|
+
api_type="chat_completions",
|
|
447
|
+
)
|
|
448
|
+
|
|
449
|
+
request_messages = inference_request.get("messages") or []
|
|
450
|
+
try:
|
|
451
|
+
temperature = float(inference_request.get("temperature"))
|
|
452
|
+
except Exception:
|
|
453
|
+
temperature = 0.0
|
|
454
|
+
|
|
455
|
+
call_record = create_llm_call_record_from_response(
|
|
456
|
+
response=base_response,
|
|
457
|
+
model_name=model_name,
|
|
458
|
+
provider=provider,
|
|
459
|
+
messages=request_messages,
|
|
460
|
+
temperature=temperature,
|
|
461
|
+
request_params=inference_request,
|
|
462
|
+
tools=inference_request.get("tools"),
|
|
463
|
+
started_at=started_at,
|
|
464
|
+
completed_at=completed_at,
|
|
465
|
+
latency_ms=latency_ms,
|
|
466
|
+
)
|
|
467
|
+
|
|
468
|
+
event_metadata = {
|
|
469
|
+
"policy_id": self.request.policy.policy_id,
|
|
470
|
+
"turn": self.current_turn,
|
|
471
|
+
"run_id": self.run_id,
|
|
472
|
+
}
|
|
473
|
+
|
|
474
|
+
event = LMCAISEvent(
|
|
475
|
+
system_instance_id=f"policy:{self.policy_name or 'unknown'}",
|
|
476
|
+
time_record=TimeRecord(event_time=completed_at.timestamp()),
|
|
477
|
+
model_name=model_name,
|
|
478
|
+
provider=provider,
|
|
479
|
+
input_tokens=input_tokens,
|
|
480
|
+
output_tokens=output_tokens,
|
|
481
|
+
total_tokens=total_tokens,
|
|
482
|
+
cost_usd=cost_usd,
|
|
483
|
+
latency_ms=latency_ms,
|
|
484
|
+
call_records=[call_record],
|
|
485
|
+
metadata=event_metadata,
|
|
486
|
+
)
|
|
487
|
+
|
|
488
|
+
await self._record_event(event)
|
|
489
|
+
|
|
490
|
+
self.lm_calls_summary.append(
|
|
491
|
+
{
|
|
492
|
+
"turn": self.current_turn,
|
|
493
|
+
"model": model_name,
|
|
494
|
+
"provider": provider,
|
|
495
|
+
"total_tokens": total_tokens,
|
|
496
|
+
"input_tokens": input_tokens,
|
|
497
|
+
"output_tokens": output_tokens,
|
|
498
|
+
"latency_ms": latency_ms,
|
|
499
|
+
"tool_calls": len(tool_calls or []),
|
|
500
|
+
}
|
|
501
|
+
)
|
|
502
|
+
|
|
503
|
+
if self.sft_output_dir is not None:
|
|
504
|
+
assistant_structured = assistant_content if assistant_content is not None else ""
|
|
505
|
+
assistant_text = self._content_to_text(assistant_content)
|
|
506
|
+
dialogue_structured: list[dict[str, Any]] = []
|
|
507
|
+
for content in self.latest_system_prompt_content:
|
|
508
|
+
if content is None:
|
|
509
|
+
continue
|
|
510
|
+
dialogue_structured.append({"role": "system", "content": content})
|
|
511
|
+
for content in self.latest_user_prompt_content:
|
|
512
|
+
if content is None:
|
|
513
|
+
continue
|
|
514
|
+
dialogue_structured.append({"role": "user", "content": content})
|
|
515
|
+
dialogue_text = (
|
|
516
|
+
[{"role": "system", "content": s} for s in self.latest_system_messages]
|
|
517
|
+
+ [{"role": "user", "content": u} for u in self.latest_user_messages]
|
|
518
|
+
)
|
|
519
|
+
user_has_image = any(
|
|
520
|
+
self._content_has_image(content) for content in self.latest_user_prompt_content
|
|
521
|
+
)
|
|
522
|
+
assistant_has_image = self._content_has_image(assistant_structured)
|
|
523
|
+
record = {
|
|
524
|
+
"run_id": self.run_id,
|
|
525
|
+
"turn": self.current_turn,
|
|
526
|
+
"model": model_name,
|
|
527
|
+
"provider": provider,
|
|
528
|
+
"dialogue": dialogue_structured,
|
|
529
|
+
"dialogue_text": dialogue_text,
|
|
530
|
+
"assistant": {
|
|
531
|
+
"content": assistant_structured,
|
|
532
|
+
"content_text": assistant_text,
|
|
533
|
+
"tool_calls": assistant_message.get("tool_calls")
|
|
534
|
+
if isinstance(assistant_message, dict)
|
|
535
|
+
else [],
|
|
536
|
+
"has_image": assistant_has_image,
|
|
537
|
+
},
|
|
538
|
+
"metadata": {
|
|
539
|
+
"user_has_image": user_has_image,
|
|
540
|
+
"assistant_has_image": assistant_has_image,
|
|
541
|
+
"has_image": user_has_image or assistant_has_image,
|
|
542
|
+
},
|
|
543
|
+
"timestamp": datetime.utcnow().isoformat(),
|
|
544
|
+
}
|
|
545
|
+
self.sft_records.append(record)
|
|
546
|
+
|
|
547
|
+
async def record_environment_event(
|
|
548
|
+
self,
|
|
549
|
+
*,
|
|
550
|
+
env_handle: Any,
|
|
551
|
+
prev_obs: dict[str, Any] | None,
|
|
552
|
+
env_response: Any,
|
|
553
|
+
next_obs: dict[str, Any] | None,
|
|
554
|
+
metadata: dict[str, Any] | None = None,
|
|
555
|
+
) -> int | None:
|
|
556
|
+
if not self.enabled or self.tracer is None:
|
|
557
|
+
return None
|
|
558
|
+
|
|
559
|
+
try:
|
|
560
|
+
prev_summary = (
|
|
561
|
+
_summarize_observation_for_storage(env_handle, prev_obs or {})
|
|
562
|
+
if prev_obs is not None
|
|
563
|
+
else None
|
|
564
|
+
)
|
|
565
|
+
except Exception:
|
|
566
|
+
prev_summary = None
|
|
567
|
+
try:
|
|
568
|
+
next_summary = (
|
|
569
|
+
_summarize_observation_for_storage(env_handle, next_obs or {})
|
|
570
|
+
if next_obs is not None
|
|
571
|
+
else None
|
|
572
|
+
)
|
|
573
|
+
except Exception:
|
|
574
|
+
next_summary = None
|
|
575
|
+
|
|
576
|
+
reward_val = getattr(env_response, "reward", None)
|
|
577
|
+
try:
|
|
578
|
+
reward_float = float(reward_val) if reward_val is not None else 0.0
|
|
579
|
+
except Exception:
|
|
580
|
+
reward_float = 0.0
|
|
581
|
+
|
|
582
|
+
event = EnvironmentEvent(
|
|
583
|
+
system_instance_id=f"environment:{self.env_name or 'unknown'}",
|
|
584
|
+
time_record=TimeRecord(event_time=datetime.utcnow().timestamp()),
|
|
585
|
+
reward=reward_float,
|
|
586
|
+
terminated=bool(getattr(env_response, "done", False)),
|
|
587
|
+
truncated=bool(getattr(env_response, "truncated", False)),
|
|
588
|
+
system_state_before=prev_summary,
|
|
589
|
+
system_state_after=next_summary,
|
|
590
|
+
metadata={
|
|
591
|
+
"turn": self.current_turn,
|
|
592
|
+
"run_id": self.run_id,
|
|
593
|
+
**(metadata or {}),
|
|
594
|
+
},
|
|
595
|
+
)
|
|
596
|
+
|
|
597
|
+
return await self._record_event(event)
|
|
598
|
+
|
|
599
|
+
async def record_decision_reward(
|
|
600
|
+
self,
|
|
601
|
+
*,
|
|
602
|
+
event_id: int | None,
|
|
603
|
+
decision_meta: dict[str, Any] | None,
|
|
604
|
+
) -> None:
|
|
605
|
+
decision_meta = decision_meta or {}
|
|
606
|
+
ach_delta = int(decision_meta.get("ach_delta", 0))
|
|
607
|
+
unique_delta = int(decision_meta.get("unique_delta", 0))
|
|
608
|
+
all_ach = list(decision_meta.get("all") or [])
|
|
609
|
+
unique_ach = list(decision_meta.get("unique") or [])
|
|
610
|
+
|
|
611
|
+
self.decision_rewards.append(
|
|
612
|
+
{
|
|
613
|
+
"turn": self.current_turn,
|
|
614
|
+
"ach_delta": ach_delta,
|
|
615
|
+
"unique_delta": unique_delta,
|
|
616
|
+
"achievements": all_ach,
|
|
617
|
+
"unique_achievements": unique_ach,
|
|
618
|
+
}
|
|
619
|
+
)
|
|
620
|
+
|
|
621
|
+
if not self.enabled or self.tracer is None or event_id is None:
|
|
622
|
+
return
|
|
623
|
+
try:
|
|
624
|
+
await self.tracer.record_event_reward(
|
|
625
|
+
event_id=event_id,
|
|
626
|
+
turn_number=self.current_turn,
|
|
627
|
+
reward_value=float(ach_delta),
|
|
628
|
+
reward_type="achievement_delta",
|
|
629
|
+
annotation={"achievements": all_ach},
|
|
630
|
+
source="environment",
|
|
631
|
+
)
|
|
632
|
+
if unique_delta:
|
|
633
|
+
await self.tracer.record_event_reward(
|
|
634
|
+
event_id=event_id,
|
|
635
|
+
turn_number=self.current_turn,
|
|
636
|
+
reward_value=float(unique_delta),
|
|
637
|
+
reward_type="unique_achievement_delta",
|
|
638
|
+
annotation={"achievements": unique_ach},
|
|
639
|
+
source="environment",
|
|
640
|
+
)
|
|
641
|
+
except Exception as exc:
|
|
642
|
+
logger.debug("TRACING_REWARD_FAIL: %s", exc)
|
|
643
|
+
|
|
644
|
+
def update_metadata(self, **kwargs: Any) -> None:
|
|
645
|
+
self.metadata_updates.update({k: v for k, v in kwargs.items() if v is not None})
|
|
646
|
+
|
|
647
|
+
async def finalize(
|
|
648
|
+
self,
|
|
649
|
+
*,
|
|
650
|
+
total_reward: float,
|
|
651
|
+
achievement_state: dict[str, bool] | None,
|
|
652
|
+
total_steps: int,
|
|
653
|
+
) -> Any:
|
|
654
|
+
final_achievements = [key for key, val in (achievement_state or {}).items() if val]
|
|
655
|
+
self.metadata_updates.setdefault("final_achievements", final_achievements)
|
|
656
|
+
if self.enabled and self.tracer is not None:
|
|
657
|
+
try:
|
|
658
|
+
await self.tracer.record_outcome_reward(
|
|
659
|
+
total_reward=int(total_reward),
|
|
660
|
+
achievements_count=len(final_achievements),
|
|
661
|
+
total_steps=int(total_steps),
|
|
662
|
+
reward_metadata=dict(self.metadata_updates),
|
|
663
|
+
)
|
|
664
|
+
except Exception as exc:
|
|
665
|
+
logger.debug("TRACING_OUTCOME_FAIL: %s", exc)
|
|
666
|
+
try:
|
|
667
|
+
self.session_trace = await self.tracer.end_session()
|
|
668
|
+
if self.session_trace is not None:
|
|
669
|
+
self.session_trace.metadata.update(self.metadata_updates)
|
|
670
|
+
except Exception as exc:
|
|
671
|
+
logger.debug("TRACING_END_SESSION_FAIL: %s", exc)
|
|
672
|
+
self.session_trace = None
|
|
673
|
+
with contextlib.suppress(Exception):
|
|
674
|
+
await self.tracer.close()
|
|
675
|
+
|
|
676
|
+
if self.sft_records and self.sft_output_dir:
|
|
677
|
+
self.write_sft_records()
|
|
678
|
+
|
|
679
|
+
# Clear context from request state to avoid leaks
|
|
680
|
+
self.fastapi_request.state.rollout_tracing = None
|
|
681
|
+
|
|
682
|
+
return self.session_trace
|
|
683
|
+
|
|
684
|
+
def write_sft_records(self) -> None:
|
|
685
|
+
if not self.sft_output_dir or not self.sft_records:
|
|
686
|
+
return
|
|
687
|
+
try:
|
|
688
|
+
path = unique_sft_path(self.sft_output_dir, run_id=self.run_id)
|
|
689
|
+
path.parent.mkdir(parents=True, exist_ok=True)
|
|
690
|
+
with path.open("w", encoding="utf-8") as fh:
|
|
691
|
+
for record in self.sft_records:
|
|
692
|
+
json.dump(record, fh, ensure_ascii=False)
|
|
693
|
+
fh.write("\n")
|
|
694
|
+
logger.info(f"SFT_WRITTEN: {path}")
|
|
695
|
+
except Exception as exc:
|
|
696
|
+
logger.warning(f"SFT_WRITE_FAIL: {exc}")
|
|
697
|
+
finally:
|
|
698
|
+
self.sft_records.clear()
|
|
699
|
+
|
|
700
|
+
def build_trace_payload(self, session_trace: Any) -> dict[str, Any] | None:
|
|
701
|
+
if not self.return_trace or session_trace is None:
|
|
702
|
+
return None
|
|
703
|
+
if self.trace_format == "full":
|
|
704
|
+
payload = session_trace.to_dict()
|
|
705
|
+
payload.setdefault("metadata", {}).update(self.metadata_updates)
|
|
706
|
+
return payload
|
|
707
|
+
metadata = dict(session_trace.metadata)
|
|
708
|
+
metadata.update(self.metadata_updates)
|
|
709
|
+
return {
|
|
710
|
+
"session_id": session_trace.session_id,
|
|
711
|
+
"created_at": session_trace.created_at.isoformat(),
|
|
712
|
+
"metadata": metadata,
|
|
713
|
+
"events_count": len(session_trace.event_history),
|
|
714
|
+
"messages_count": len(session_trace.markov_blanket_message_history),
|
|
715
|
+
"lm_calls": self.lm_calls_summary,
|
|
716
|
+
"decision_rewards": self.decision_rewards,
|
|
717
|
+
}
|
|
718
|
+
|
|
719
|
+
|
|
720
|
+
def _summarize_observation_for_storage(
|
|
721
|
+
env_handle: Any, observation: dict[str, Any]
|
|
722
|
+
) -> dict[str, Any]:
|
|
723
|
+
"""Return a compact dict for trajectory storage instead of the raw observation.
|
|
724
|
+
|
|
725
|
+
- For Crafter, use the same summary used for the policy user prompt
|
|
726
|
+
- For others, keep a minimal subset or plain text preview
|
|
727
|
+
"""
|
|
728
|
+
# Try Crafter-specific formatter
|
|
729
|
+
crafter_wrapper = None
|
|
730
|
+
with contextlib.suppress(Exception):
|
|
731
|
+
from .envs.crafter.environment import (
|
|
732
|
+
CrafterEnvironmentWrapper as _CrafterWrapper, # type: ignore
|
|
733
|
+
)
|
|
734
|
+
|
|
735
|
+
crafter_wrapper = _CrafterWrapper # type: ignore[assignment]
|
|
736
|
+
|
|
737
|
+
if crafter_wrapper is not None and isinstance(
|
|
738
|
+
getattr(env_handle, "env", None), crafter_wrapper
|
|
739
|
+
):
|
|
740
|
+
with contextlib.suppress(Exception):
|
|
741
|
+
from .envs.crafter.shared import format_observation as _fmt # type: ignore
|
|
742
|
+
|
|
743
|
+
text = _fmt(observation or {})
|
|
744
|
+
return {"text": text}
|
|
745
|
+
|
|
746
|
+
# Generic fallback: extract a few small fields if present; avoid huge arrays
|
|
747
|
+
with contextlib.suppress(Exception):
|
|
748
|
+
inv = observation.get("inventory") if isinstance(observation, dict) else None
|
|
749
|
+
ach = observation.get("achievements_status") if isinstance(observation, dict) else None
|
|
750
|
+
pos = observation.get("player_position") if isinstance(observation, dict) else None
|
|
751
|
+
health = None
|
|
752
|
+
if isinstance(inv, dict):
|
|
753
|
+
health = inv.get("health")
|
|
754
|
+
summary = {
|
|
755
|
+
"position": pos,
|
|
756
|
+
"health": health,
|
|
757
|
+
"inventory_keys": sorted(k for k, v in (inv or {}).items() if v)[:10]
|
|
758
|
+
if isinstance(inv, dict)
|
|
759
|
+
else None,
|
|
760
|
+
"achievements_unlocked": sorted(k for k, v in (ach or {}).items() if v)[:10]
|
|
761
|
+
if isinstance(ach, dict)
|
|
762
|
+
else None,
|
|
763
|
+
}
|
|
764
|
+
return {"text": json.dumps(summary, ensure_ascii=False)}
|
|
765
|
+
|
|
766
|
+
# Last resort: plain string preview
|
|
767
|
+
try:
|
|
768
|
+
return {"text": str(observation)[:10000]}
|
|
769
|
+
except Exception:
|
|
770
|
+
return {"text": ""}
|
|
771
|
+
|
|
772
|
+
|
|
773
|
+
class RunAbortRequest(BaseModel):
|
|
774
|
+
run_id: str
|
|
775
|
+
|
|
776
|
+
|
|
777
|
+
class RunAbortResponse(BaseModel):
|
|
778
|
+
ok: bool
|
|
779
|
+
run_id: str
|
|
780
|
+
|
|
781
|
+
|
|
782
|
+
class RunStatusResponse(BaseModel):
|
|
783
|
+
run_id: str
|
|
784
|
+
status: str
|
|
785
|
+
started_at: datetime
|
|
786
|
+
finished_at: datetime | None = None
|
|
787
|
+
|
|
788
|
+
|
|
789
|
+
@router.post("/rollout", response_model=RolloutResponse)
|
|
790
|
+
async def execute_rollout(
|
|
791
|
+
request: RolloutRequest,
|
|
792
|
+
req: Request,
|
|
793
|
+
) -> RolloutResponse:
|
|
794
|
+
"""Execute a rollout with coordinated environment and policy steps."""
|
|
795
|
+
# Emit rollout identifier early for correlation
|
|
796
|
+
with contextlib.suppress(Exception):
|
|
797
|
+
_rid = getattr(request, "run_id", None)
|
|
798
|
+
_pol = getattr(request.policy, "policy_name", None) or getattr(request.policy, "policy_id", None)
|
|
799
|
+
_env = getattr(request.env, "env_name", None) or getattr(request.env, "env_id", None)
|
|
800
|
+
logger.info("ROLLOUT_BEGIN: run_id=%s policy=%s env=%s", _rid, _pol, _env)
|
|
801
|
+
print(f"[rollout] begin run_id={_rid} policy={_pol} env={_env}", flush=True)
|
|
802
|
+
# Enforce per-episode step cap via env-specific parameters; default to 20 if omitted
|
|
803
|
+
try:
|
|
804
|
+
_env_params = {}
|
|
805
|
+
if isinstance(request.env, RolloutEnvSpec) and isinstance(request.env.config, dict):
|
|
806
|
+
_env_params = dict(request.env.config.get("env_params") or {})
|
|
807
|
+
max_steps_per_episode = int(_env_params.get("max_steps_per_episode") or 20)
|
|
808
|
+
assert max_steps_per_episode > 0, "max_steps_per_episode must be a positive integer"
|
|
809
|
+
except Exception as _mse:
|
|
810
|
+
raise HTTPException(
|
|
811
|
+
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
|
812
|
+
detail={
|
|
813
|
+
"error": "invalid_env_params",
|
|
814
|
+
"message": f"Invalid or missing env_params.max_steps_per_episode: {_mse}",
|
|
815
|
+
},
|
|
816
|
+
) from _mse
|
|
817
|
+
# Truncate incoming ops to the enforced cap (each step is [agent, env])
|
|
818
|
+
ops_seq: list[str] = list(request.ops or [])
|
|
819
|
+
allowed_ops = max(0, int(max_steps_per_episode) * 2)
|
|
820
|
+
if len(ops_seq) > allowed_ops:
|
|
821
|
+
with contextlib.suppress(Exception):
|
|
822
|
+
logger.info(
|
|
823
|
+
"ROLL_OUT: truncating ops to cap: requested_ops=%s allowed_ops=%s",
|
|
824
|
+
str(len(ops_seq)),
|
|
825
|
+
str(allowed_ops),
|
|
826
|
+
)
|
|
827
|
+
ops_seq = ops_seq[:allowed_ops]
|
|
828
|
+
# Simple API key auth for inbound rollout
|
|
829
|
+
header_key = req.headers.get("x-api-key")
|
|
830
|
+
env_key = os.getenv("ENVIRONMENT_API_KEY")
|
|
831
|
+
dev_key = os.getenv("DEV_ENVIRONMENT_API_KEY")
|
|
832
|
+
# Accept either ENVIRONMENT_API_KEY or DEV_ENVIRONMENT_API_KEY
|
|
833
|
+
expected_keys = [k for k in (env_key, dev_key) if k]
|
|
834
|
+
if not expected_keys:
|
|
835
|
+
missing = []
|
|
836
|
+
if not env_key:
|
|
837
|
+
missing.append("ENVIRONMENT_API_KEY")
|
|
838
|
+
if not dev_key:
|
|
839
|
+
missing.append("DEV_ENVIRONMENT_API_KEY")
|
|
840
|
+
msg = f"Auth not configured: missing {', '.join(missing)} in task service environment"
|
|
841
|
+
logger.error(msg)
|
|
842
|
+
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=msg)
|
|
843
|
+
if not header_key:
|
|
844
|
+
raise HTTPException(
|
|
845
|
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
846
|
+
detail="Invalid or missing API key: X-API-Key header not provided",
|
|
847
|
+
)
|
|
848
|
+
if header_key not in expected_keys:
|
|
849
|
+
# Do not leak secrets; include short prefix for diagnostics
|
|
850
|
+
exp_src = env_key if env_key else (dev_key or "")
|
|
851
|
+
exp_prefix = (exp_src[:7] + "…") if len(exp_src) >= 7 else "set"
|
|
852
|
+
got_prefix = (header_key[:7] + "…") if len(header_key) >= 7 else "set"
|
|
853
|
+
raise HTTPException(
|
|
854
|
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
855
|
+
detail=f"Invalid API key: header does not match expected (got={got_prefix}, expected_prefix={exp_prefix})",
|
|
856
|
+
)
|
|
857
|
+
|
|
858
|
+
# Log contextual fields for traceability
|
|
859
|
+
if request.training_session_id:
|
|
860
|
+
logger.info(f"ROLL_OUT: training_session_id={request.training_session_id}")
|
|
861
|
+
if request.synth_base_url:
|
|
862
|
+
logger.info(f"ROLL_OUT: synth_base_url={request.synth_base_url}")
|
|
863
|
+
|
|
864
|
+
# Log masked OpenAI API key presence for diagnostics
|
|
865
|
+
with contextlib.suppress(Exception):
|
|
866
|
+
_oa = os.getenv("OPENAI_API_KEY")
|
|
867
|
+
if _oa:
|
|
868
|
+
_pref = (_oa[:6] + "…") if len(_oa) >= 6 else "set"
|
|
869
|
+
logger.info(f"ROLL_OUT: OPENAI_API_KEY present (prefix={_pref})")
|
|
870
|
+
else:
|
|
871
|
+
logger.warning("ROLL_OUT: OPENAI_API_KEY missing")
|
|
872
|
+
|
|
873
|
+
# Make synth_base_url available for outbound calls in this app
|
|
874
|
+
with contextlib.suppress(Exception):
|
|
875
|
+
task_app = req.app.state.task_app
|
|
876
|
+
if request.synth_base_url:
|
|
877
|
+
task_app.synth_base_url = request.synth_base_url
|
|
878
|
+
|
|
879
|
+
tracer_factory = getattr(req.app.state, "session_tracer_factory", None)
|
|
880
|
+
tracer_instance: SessionTracer | None = None
|
|
881
|
+
if callable(tracer_factory):
|
|
882
|
+
try:
|
|
883
|
+
inst = tracer_factory()
|
|
884
|
+
tracer_instance = inst if isinstance(inst, SessionTracer) else None
|
|
885
|
+
except Exception as exc:
|
|
886
|
+
logger.debug(f"TRACER_FACTORY_FAIL: {exc}")
|
|
887
|
+
tracing_context = RolloutTracingContext(tracer_instance, request, req)
|
|
888
|
+
await tracing_context.start_session()
|
|
889
|
+
# Print whether tracing is active for this rollout
|
|
890
|
+
try:
|
|
891
|
+
print(
|
|
892
|
+
f"[rollout] tracing enabled={bool(tracing_context.enabled)} run_id={request.run_id}",
|
|
893
|
+
flush=True,
|
|
894
|
+
)
|
|
895
|
+
except Exception:
|
|
896
|
+
pass
|
|
897
|
+
|
|
898
|
+
# Register run
|
|
899
|
+
registry.register_run(request.run_id)
|
|
900
|
+
|
|
901
|
+
# Track resources created during this rollout so we can guarantee cleanup
|
|
902
|
+
created_env_id: str | None = None
|
|
903
|
+
created_policy_id: str | None = None
|
|
904
|
+
env_seed_used: int | None = None
|
|
905
|
+
trajectory_steps: list[RolloutStep] = []
|
|
906
|
+
decision_samples: list[dict[str, Any]] = []
|
|
907
|
+
pending_tool_calls: Any = None
|
|
908
|
+
current_obs: Any = {}
|
|
909
|
+
total_reward: float = 0.0
|
|
910
|
+
ops_executed = 0
|
|
911
|
+
last_agent_response_ts: float | None = None
|
|
912
|
+
last_policy_meta: dict[str, Any] | None = None
|
|
913
|
+
last_env_step_ms: float | None = None
|
|
914
|
+
last_env_step_completed_ts: float | None = None
|
|
915
|
+
decision_open = False
|
|
916
|
+
finalized = False
|
|
917
|
+
prev_achievements: dict[str, bool] = {}
|
|
918
|
+
session_trace = None
|
|
919
|
+
step_rewards_active = False
|
|
920
|
+
|
|
921
|
+
try:
|
|
922
|
+
# Initialize deterministic seed early for the entire rollout
|
|
923
|
+
seed_value: int | None = None
|
|
924
|
+
try:
|
|
925
|
+
if request.env and request.env.seed is not None:
|
|
926
|
+
seed_value = int(request.env.seed)
|
|
927
|
+
else:
|
|
928
|
+
# Derive a stable seed from run_id
|
|
929
|
+
import hashlib as _hashlib # local import to avoid global deps
|
|
930
|
+
|
|
931
|
+
_digest = _hashlib.sha256(request.run_id.encode("utf-8")).hexdigest()
|
|
932
|
+
# Use lower 32 bits to fit common RNG ranges
|
|
933
|
+
seed_value = int(_digest[:8], 16)
|
|
934
|
+
except Exception:
|
|
935
|
+
# Fallback to time-based seed if anything goes wrong
|
|
936
|
+
try:
|
|
937
|
+
seed_value = int((_time.time_ns() // 1_000_000) % (2**31 - 1))
|
|
938
|
+
except Exception:
|
|
939
|
+
seed_value = 42
|
|
940
|
+
|
|
941
|
+
_seed_info = _set_global_seed(int(seed_value))
|
|
942
|
+
with contextlib.suppress(Exception):
|
|
943
|
+
logger.info(
|
|
944
|
+
"ROLL_OUT: RNG seeded seed=%s libs=%s",
|
|
945
|
+
str(_seed_info.get("seed")),
|
|
946
|
+
",".join(_seed_info.get("libs", [])),
|
|
947
|
+
)
|
|
948
|
+
        # Resolve or create environment
        if request.env.env_id:
            env_handle = registry.get_env(request.env.env_id)
            if not env_handle:
                raise HTTPException(
                    status_code=404,
                    detail=f"Environment {request.env.env_id} not found",
                )
            env_id = request.env.env_id
        else:
            # Create new environment
            from .environment_routes import EnvCreateRequest, create_environment

            if not request.env.env_name:
                raise ValueError("FATAL: env_name is required - NO FALLBACKS!")

            # Propagate training_session_id via env config for downstream usage
            _env_config = dict(request.env.config or {})
            if request.training_session_id is not None:
                _env_config.setdefault("training_session_id", request.training_session_id)
            env_response = await create_environment(
                EnvCreateRequest(
                    env_name=request.env.env_name,
                    config=_env_config,
                    seed=request.env.seed,
                    rl_run_id=request.run_id,
                )
            )
            env_id = env_response.env_id
            env_handle = registry.get_env(env_id)
            created_env_id = env_id

        tracing_context.update_metadata(env_id=env_id)

        # Resolve or create policy
        if request.policy.policy_id:
            policy_handle = registry.get_policy(request.policy.policy_id)
            if not policy_handle:
                raise HTTPException(
                    status_code=404,
                    detail=f"Policy {request.policy.policy_id} not found",
                )
            policy_id = request.policy.policy_id
        else:
            # Create new policy
            from .policy_routes import PolicyCreateRequest, create_policy

            if not request.policy.policy_name:
                raise ValueError("FATAL: policy_name is required - NO FALLBACKS!")

            # Propagate training_session_id and synth_base_url via policy config
            _policy_config = dict(request.policy.config or {})
            if request.training_session_id is not None:
                _policy_config.setdefault("training_session_id", request.training_session_id)
            if request.synth_base_url is not None:
                _policy_config.setdefault("synth_base_url", request.synth_base_url)
            policy_response = await create_policy(
                PolicyCreateRequest(
                    policy_name=request.policy.policy_name,
                    config=_policy_config,
                    rl_run_id=request.run_id,
                    bound_env_id=env_id,
                ),
                req,
            )
            policy_id = policy_response.policy_id
            policy_handle = registry.get_policy(policy_id)
            created_policy_id = policy_id

        tracing_context.update_metadata(policy_id=policy_id)

        # Bind policy to environment if not already bound
        if policy_handle and not policy_handle.bound_env_id:
            policy_handle.bound_env_id = env_id

        # Record seed bound to environment for end-of-rollout verification/logging
        try:
            env_seed_used = int(getattr(env_handle, "seed", 0) or 0)
        except Exception:
            env_seed_used = None
        tracing_context.update_metadata(env_seed=env_seed_used)
        # Initialize trajectory
        trajectory_steps = []
        pending_tool_calls = None
        current_obs = env_handle.last_observation
        total_reward = 0.0
        ops_executed = 0
        last_agent_response_ts = None
        last_policy_meta = None
        last_env_step_ms = None
        last_env_step_completed_ts = None

        # Stepwise reward configuration (Crafter shaping; gate on explicit enable)
        step_rewards_cfg_raw: dict[str, Any] = {}
        try:
            if isinstance(request.policy.config, dict):
                step_rewards_cfg_raw = dict(request.policy.config.get("step_rewards") or {})
        except Exception:
            step_rewards_cfg_raw = {}
        if not step_rewards_cfg_raw:
            try:
                if isinstance(request.env.config, dict):
                    step_rewards_cfg_raw = dict(request.env.config.get("step_rewards") or {})
            except Exception:
                step_rewards_cfg_raw = {}

        step_rewards_enabled = bool(step_rewards_cfg_raw.get("enabled", False))
        step_rewards_mode = str(step_rewards_cfg_raw.get("mode") or "off").lower()
        try:
            step_rewards_indicator_lambda = float(
                step_rewards_cfg_raw.get("indicator_lambda") or 0.0
            )
        except Exception:
            step_rewards_indicator_lambda = 0.0
        try:
            step_rewards_beta = float(step_rewards_cfg_raw.get("step_beta") or 0.0)
        except Exception:
            step_rewards_beta = 0.0
        step_rewards_active = step_rewards_enabled and step_rewards_mode == "decision_stepwise"
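        # Illustrative step_rewards config (keys match the reads above; values are examples only):
        #   {"enabled": true, "mode": "decision_stepwise", "indicator_lambda": 0.5, "step_beta": 0.0}
        # Shaping only activates when enabled is true and mode == "decision_stepwise".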
        def _extract_achievements(obs: Any) -> dict[str, bool]:
            if not isinstance(obs, dict):
                return {}
            ach = obs.get("achievements_status")
            if isinstance(ach, dict):
                return {str(k): bool(v) for k, v in ach.items()}
            return {}

        def _summarize_tool_calls(tool_calls: Any) -> list[dict[str, Any]]:
            if not tool_calls:
                return []
            try:
                items = (
                    tool_calls
                    if isinstance(tool_calls, list)
                    else list(tool_calls)  # tolerates tuples or pydantic lists
                )
            except Exception:
                return []
            summary: list[dict[str, Any]] = []
            for tc in items:
                tool_name = None
                args: Any = {}
                if isinstance(tc, dict):
                    tool_name = tc.get("tool") or tc.get("tool_name") or tc.get("name")
                    raw_args = tc.get("arguments") or tc.get("args") or {}
                else:
                    tool_name = getattr(tc, "tool", None) or getattr(tc, "tool_name", None)
                    raw_args = getattr(tc, "arguments", None) or getattr(tc, "args", None) or {}
                args = raw_args
                if isinstance(raw_args, str):
                    try:
                        args = json.loads(raw_args)
                    except Exception:
                        args = raw_args
                summary.append({"tool": tool_name, "args": args})
            return summary

        decision_samples: list[dict[str, Any]] = []
        decision_index = 0
        decision_open = False
        session_trace = None
        finalized = False
        prev_achievements = _extract_achievements(current_obs)
        # Track episode-level achievements that have been seen as true at any point so far
        episode_seen_achievements: set[str] = {
            k for k, v in (prev_achievements or {}).items() if bool(v)
        }
        stepwise_indicator_sum = 0.0
        stepwise_reward_sum = 0.0
        stepwise_new_achievements_total = 0
        final_achievement_count = sum(1 for v in prev_achievements.values() if v)
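        # Illustrative (hypothetical input): _summarize_tool_calls([{"tool": "interact", "arguments": '{"action": "move"}'}])
        # returns [{"tool": "interact", "args": {"action": "move"}}], i.e. string arguments are JSON-decoded when possible.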
        # Execute ops sequence (capped by env_params.max_steps_per_episode)
        for op_idx, op in enumerate(ops_seq):
            # Check for abort
            if registry.is_run_aborted(request.run_id):
                logger.info(f"Run {request.run_id} aborted at op {op_idx}")
                break

            # Check safety limits
            if ops_executed >= request.safety.max_ops:
                logger.warning(f"Reached max_ops limit ({request.safety.max_ops})")
                break

            if op == "agent":
                # Policy step
                from .policy_routes import PolicyStepRequest, step_policy

                if not decision_open:
                    await tracing_context.start_decision(decision_index)
                    decision_open = True

                agent_request_start = _time.perf_counter()
                if last_agent_response_ts is not None and last_policy_meta is not None:
                    with contextlib.suppress(Exception):
                        timing_prev = last_policy_meta.setdefault("timing", {})
                        decision_ms = max(
                            0.0,
                            (agent_request_start - float(last_agent_response_ts)) * 1000.0,
                        )
                        # Update timing on prior policy meta (kept by previous env step)
                        timing_prev["decision_ms"] = decision_ms
                        if last_env_step_ms is not None:
                            timing_prev["env_step_ms"] = float(last_env_step_ms)
                            timing_prev["overhead_ms"] = max(
                                0.0, decision_ms - float(last_env_step_ms)
                            )
                        else:
                            timing_prev.setdefault("overhead_ms", 0.0)
                        timing_prev["decision_ready_s"] = agent_request_start
                        # Also backfill the last appended trajectory step so the trainer
                        # can always see decision_ms without relying on shared dict refs.
                        if trajectory_steps:
                            with contextlib.suppress(Exception):
                                _last = trajectory_steps[-1]
                                _info = dict(_last.info or {})
                                _meta = dict(_info.get("meta") or {})
                                _timing = dict(_meta.get("timing") or {})
                                _timing["decision_ms"] = decision_ms
                                if last_env_step_ms is not None:
                                    _timing.setdefault("env_step_ms", float(last_env_step_ms))
                                    _timing.setdefault(
                                        "overhead_ms",
                                        max(0.0, decision_ms - float(last_env_step_ms)),
                                    )
                                else:
                                    _timing.setdefault("overhead_ms", 0.0)
                                _meta["timing"] = _timing
                                _info["meta"] = _meta
                                _last.info = _info
                    last_env_step_ms = None
                    last_env_step_completed_ts = None

                # Build metadata for policy (carry previous tool_calls and env result)
                metadata = {}
                if pending_tool_calls:
                    metadata["prev_tool_calls"] = pending_tool_calls
                if len(trajectory_steps) > 0:
                    last_step = trajectory_steps[-1]
                    # Prefer the last executed tool calls to seed history
                    if last_step.tool_calls:
                        metadata["prev_tool_calls"] = last_step.tool_calls
                    # Provide a compact env result snapshot
                    metadata["prev_env_result"] = {
                        "observation": last_step.obs,
                        "reward": last_step.reward,
                        "done": last_step.done,
                        "truncated": last_step.truncated,
                        "info": last_step.info,
                    }

                # Log compact metadata summary to confirm history threading
                with contextlib.suppress(Exception):
                    _prev_calls = metadata.get("prev_tool_calls")
                    _count = len(_prev_calls) if isinstance(_prev_calls, list) else 0
                    _first_guess = None
                    if _count > 0 and isinstance(_prev_calls[0], dict):
                        _args = _prev_calls[0].get("arguments", None)
                        if isinstance(_args, str):
                            import json as _json
                            with contextlib.suppress(Exception):
                                _args = _json.loads(_args)
                        if not isinstance(_args, dict):
                            _args = {}
                        _first_guess = _args.get("guess") or _args.get("word")
                    logger.info(
                        "POLICY_METADATA: prev_tool_calls=%d first_guess=%r has_prev_env_result=%s",
                        _count,
                        _first_guess,
                        str("prev_env_result" in metadata),
                    )

                try:
                    policy_response = await step_policy(
                        PolicyStepRequest(
                            policy_id=policy_id,
                            observation=current_obs,
                            metadata=metadata,
                        ),
                        req,
                    )
                except Exception as _pe:
                    # Do not 500 the rollout; finalize with partial trajectory
                    with contextlib.suppress(Exception):
                        logger.warning(
                            "POLICY_STEP_FAIL: terminating episode early run_id=%s op_idx=%s err=%s",
                            request.run_id,
                            str(op_idx),
                            str(_pe),
                        )

                    # Build partial trajectory and return HTTP 200
                    trajectory = RolloutTrajectory(
                        env_id=env_id,
                        policy_id=policy_id,
                        steps=trajectory_steps,
                        final={
                            "observation": current_obs,
                            "rollout_status": "partial_policy_error",
                            "error": str(_pe),
                            "at_op": op,
                        },
                        length=len(trajectory_steps),
                        decision_samples=decision_samples if step_rewards_active else None,
                    )
                    metrics = RolloutMetrics(
                        episode_returns=[total_reward],
                        mean_return=total_reward,
                        num_steps=len(trajectory_steps),
                        num_episodes=1,
                    )
                    aborted = registry.is_run_aborted(request.run_id)
                    if not aborted:
                        registry.complete_run(request.run_id)
                    if decision_open:
                        await tracing_context.end_decision()
                        decision_open = False
                    if not finalized:
                        session_trace = await tracing_context.finalize(
                            total_reward=total_reward,
                            achievement_state=prev_achievements,
                            total_steps=len(trajectory_steps),
                        )
                        finalized = True
                    trace_payload = tracing_context.build_trace_payload(session_trace)
                    return RolloutResponse(
                        run_id=request.run_id,
                        trajectories=[trajectory],
                        branches={},
                        metrics=metrics,
                        aborted=aborted,
                        ops_executed=ops_executed,
                        trace=trace_payload,
                    )

                agent_response_ts = _time.perf_counter()
                if isinstance(policy_response.meta, dict):
                    with contextlib.suppress(Exception):
                        timing_cur = policy_response.meta.setdefault("timing", {})
                        timing_cur["agent_request_start_s"] = agent_request_start
                        timing_cur["agent_response_s"] = agent_response_ts
                        if "inference_ms" in policy_response.meta:
                            with contextlib.suppress(Exception):
                                timing_cur.setdefault(
                                    "inference_ms",
                                    float(policy_response.meta["inference_ms"]),
                                )
                                timing_cur.setdefault(
                                    "inference_s",
                                    float(policy_response.meta["inference_ms"]) / 1000.0,
                                )
                    last_policy_meta = policy_response.meta
                else:
                    last_policy_meta = None
                last_agent_response_ts = agent_response_ts

                # Diagnostic: summarize policy step target and tool calls
                try:
                    model_name = None
                    target_url = None
                    if isinstance(policy_response.meta, dict):
                        req_body = policy_response.meta.get("inference_request") or {}
                        model_name = req_body.get("model")
                        target_url = policy_response.meta.get("inference_url")
                    _tc = policy_response.tool_calls or []
                    print(
                        {
                            "rollout.policy_step": True,
                            "run_id": request.run_id,
                            "model": model_name,
                            "inference_url": target_url,
                            "tool_calls_count": len(_tc) if isinstance(_tc, list) else 0,
                        },
                        flush=True,
                    )
                except Exception:
                    pass

                pending_tool_calls = policy_response.tool_calls
                # Log summarized agent tool calls
                with contextlib.suppress(Exception):
                    _tc = pending_tool_calls or []
                    _summary = []
                    for _item in (_tc if isinstance(_tc, list) else []):
                        try:
                            if isinstance(_item, dict):
                                _tool = _item.get("tool")
                                _args = _item.get("args")
                                _keys = list(_args.keys()) if isinstance(_args, dict) else []
                                _summary.append({"tool": _tool, "args_keys": _keys})
                        except Exception:
                            continue
                    _rid = getattr(request, "run_id", None)
                    logger.info("AGENT_TOOL_CALLS: run_id=%s count=%d summary=%s", _rid, len(_tc), _summary)
                    print(f"[rollout] agent tool_calls run_id={_rid} count={len(_tc)} summary={_summary}", flush=True)
                await tracing_context.record_tool_invocation(pending_tool_calls)
                ops_executed += 1
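                # The metadata threaded into each policy step above has the shape:
                #   {"prev_tool_calls": [...last executed tool calls...],
                #    "prev_env_result": {"observation": ..., "reward": ..., "done": ..., "truncated": ..., "info": ...}}
                # so the policy can carry history across turns without server-side chat state.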
            elif op == "env":
                if not pending_tool_calls:
                    # Treat absence of tool calls as a soft terminal condition; yield partial trajectory
                    with contextlib.suppress(Exception):
                        logger.warning(
                            "NO_TOOL_CALLS: terminating episode early run_id=%s op_idx=%s",
                            request.run_id,
                            str(op_idx),
                        )
                        print(
                            f"[rollout] no tool_calls; terminating early run_id={request.run_id} op_idx={op_idx}",
                            flush=True,
                        )
                    term_step = RolloutStep(
                        obs=current_obs,
                        tool_calls=[],
                        reward=None,
                        done=True,
                        truncated=False,
                        info={
                            "terminated": True,
                            "reason": "no_tool_calls",
                        },
                    )
                    trajectory_steps.append(term_step)
                    trajectory = RolloutTrajectory(
                        env_id=env_id,
                        policy_id=policy_id,
                        steps=trajectory_steps,
                        final={
                            "observation": current_obs,
                            "rollout_status": "partial_no_tool_calls",
                            "at_op": op,
                        },
                        length=len(trajectory_steps),
                        decision_samples=decision_samples if step_rewards_active else None,
                    )
                    metrics = RolloutMetrics(
                        episode_returns=[total_reward],
                        mean_return=total_reward,
                        num_steps=len(trajectory_steps),
                        num_episodes=1,
                    )
                    aborted = registry.is_run_aborted(request.run_id)
                    if not aborted:
                        registry.complete_run(request.run_id)
                    if decision_open:
                        await tracing_context.end_decision()
                        decision_open = False
                    if not finalized:
                        session_trace = await tracing_context.finalize(
                            total_reward=total_reward,
                            achievement_state=prev_achievements,
                            total_steps=len(trajectory_steps),
                        )
                        finalized = True
                    trace_payload = tracing_context.build_trace_payload(session_trace)
                    return RolloutResponse(
                        run_id=request.run_id,
                        trajectories=[trajectory],
                        branches={},
                        metrics=metrics,
                        aborted=aborted,
                        ops_executed=ops_executed,
                        trace=trace_payload,
                    )

                # Environment step
                from .environment_routes import EnvStepRequest, step_environment

                env_step_error: Exception | None = None
                env_response = None
                env_step_start = _time.perf_counter()
                try:
                    env_response = await step_environment(
                        EnvStepRequest(
                            env_id=env_id,
                            tool_calls=pending_tool_calls,
                        )
                    )
                except Exception as _ee:
                    env_step_error = _ee
                env_step_end = _time.perf_counter()
                env_step_duration_ms = (env_step_end - env_step_start) * 1000.0
                last_env_step_ms = env_step_duration_ms
                last_env_step_completed_ts = env_step_end
                if last_policy_meta is not None:
                    with contextlib.suppress(Exception):
                        timing_env = last_policy_meta.setdefault("timing", {})
                        timing_env["env_step_ms"] = env_step_duration_ms
                        timing_env["env_step_end_s"] = env_step_end

                if env_step_error is not None:
                    # Invalid action or environment rejection — terminate episode early with partial trajectory
                    with contextlib.suppress(Exception):
                        logger.warning(
                            "ENV_STEP_FAIL: terminating episode early run_id=%s op_idx=%s err=%s",
                            request.run_id,
                            str(op_idx),
                            str(env_step_error),
                        )

                    term_step = RolloutStep(
                        obs=current_obs,
                        tool_calls=pending_tool_calls,
                        reward=None,
                        done=True,
                        truncated=False,
                        info={
                            "terminated": True,
                            "reason": "invalid_action",
                            "error": str(env_step_error),
                        },
                    )
                    trajectory_steps.append(term_step)
                    # Build partial response
                    trajectory = RolloutTrajectory(
                        env_id=env_id,
                        policy_id=policy_id,
                        steps=trajectory_steps,
                        final={
                            "observation": current_obs,
                            "rollout_status": "partial_invalid_action",
                            "error": str(env_step_error),
                            "at_op": op,
                        },
                        length=len(trajectory_steps),
                        decision_samples=decision_samples if step_rewards_active else None,
                    )
                    metrics = RolloutMetrics(
                        episode_returns=[total_reward],
                        mean_return=total_reward,
                        num_steps=len(trajectory_steps),
                        num_episodes=1,
                    )
                    aborted = registry.is_run_aborted(request.run_id)
                    if not aborted:
                        registry.complete_run(request.run_id)
                    if (
                        last_policy_meta is not None
                        and last_agent_response_ts is not None
                        and "decision_ms" not in last_policy_meta.get("timing", {})
                    ):
                        with contextlib.suppress(Exception):
                            timing_last = last_policy_meta.setdefault("timing", {})
                            decision_ms = max(
                                0.0,
                                (env_step_end - float(last_agent_response_ts)) * 1000.0,
                            )
                            timing_last["decision_ms"] = decision_ms
                            timing_last.setdefault(
                                "overhead_ms", max(0.0, decision_ms - env_step_duration_ms)
                            )
                    if decision_open:
                        await tracing_context.end_decision()
                        decision_open = False
                    if not finalized:
                        session_trace = await tracing_context.finalize(
                            total_reward=total_reward,
                            achievement_state=prev_achievements,
                            total_steps=len(trajectory_steps),
                        )
                        finalized = True
                    trace_payload = tracing_context.build_trace_payload(session_trace)
                    return RolloutResponse(
                        run_id=request.run_id,
                        trajectories=[trajectory],
                        branches={},
                        metrics=metrics,
                        aborted=aborted,
                        ops_executed=ops_executed,
                        trace=trace_payload,
                    )
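                # At this point the env step succeeded; the early-exit paths above still report
                # HTTP 200 with a truncated trajectory, distinguished by rollout_status values
                # "partial_policy_error", "partial_no_tool_calls", and "partial_invalid_action".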
                # Reaching here means env step succeeded
                assert env_response is not None

                # Record step, including policy meta if present for timing/tokens observability
                _info = env_response.info if isinstance(env_response.info, dict) else {}
                # Attach policy meta from the immediately preceding agent step
                with contextlib.suppress(Exception):
                    prev_meta = {}
                    if "policy_response" in locals() and isinstance(policy_response.meta, dict):  # type: ignore[name-defined]
                        prev_meta = policy_response.meta
                    if prev_meta:
                        _info = dict(_info)
                        _info["meta"] = prev_meta

                event_metadata = {
                    "op_index": op_idx,
                }
                event_id = await tracing_context.record_environment_event(
                    env_handle=env_handle,
                    prev_obs=current_obs,
                    env_response=env_response,
                    next_obs=getattr(env_response, "observation", None),
                    metadata=event_metadata,
                )

                decision_index += 1
                next_obs = env_response.observation
                new_achievement_state = _extract_achievements(next_obs)
                final_achievement_count = sum(
                    1 for _, unlocked in new_achievement_state.items() if unlocked
                )
                indicator_val = 0
                reward_stepwise = 0.0
                decision_rewards_meta: dict[str, Any] | None = None
                if step_rewards_active:
                    decision_actions = _summarize_tool_calls(pending_tool_calls)
                    stepwise_info, decision_record, stats = compute_stepwise_reward(
                        prev_achievements or {},
                        new_achievement_state,
                        decision_index,
                        decision_actions,
                        step_rewards_indicator_lambda,
                    )
                    indicator_val = int(stats.get("indicator", 0.0))
                    reward_stepwise = float(stats.get("reward", 0.0))
                    stepwise_indicator_sum += float(stats.get("indicator", 0.0))
                    stepwise_reward_sum += reward_stepwise
                    stepwise_new_achievements_total += int(stats.get("new_achievements_count", 0.0))
                    _info = {} if not isinstance(_info, dict) else dict(_info)
                    _info["stepwise"] = stepwise_info
                    # Compute decision-level rewards (absolute vs unique) and attach to metadata
                    with contextlib.suppress(Exception):
                        turned_true = set(stepwise_info.get("new_achievements") or [])
                        seen_before = set(episode_seen_achievements)
                        new_unique = sorted(turned_true - seen_before)
                        ach_delta = int(len(turned_true))
                        unique_delta = int(len(new_unique))
                        # Prepare stable lists for logging/metadata
                        all_list = sorted(turned_true)
                        # Ensure nested meta exists
                        meta_block = (
                            _info.get("meta") if isinstance(_info.get("meta"), dict) else {}
                        )
                        decision_rewards = {
                            "turn": int(decision_index),
                            "ach_delta": ach_delta,
                            "unique_delta": unique_delta,
                            "all": all_list,
                            "unique": new_unique,
                        }
                        decision_rewards_meta = decision_rewards
                        meta_block["decision_rewards"] = decision_rewards
                        _info["meta"] = meta_block
                        # Update episode-level seen set after attributing uniqueness to this decision
                        episode_seen_achievements.update(turned_true)
                    decision_samples.append(decision_record)
                prev_achievements = new_achievement_state

                await tracing_context.record_decision_reward(
                    event_id=event_id,
                    decision_meta=decision_rewards_meta,
                )

                step = RolloutStep(
                    obs=_summarize_observation_for_storage(env_handle, current_obs),
                    tool_calls=pending_tool_calls,
                    reward=env_response.reward,
                    done=env_response.done,
                    truncated=env_response.truncated,
                    info=_info,
                )
                # Log summarized env application of tool calls and immediate reward/done
                with contextlib.suppress(Exception):
                    _tc = pending_tool_calls or []
                    _summary = []
                    for _item in (_tc if isinstance(_tc, list) else []):
                        try:
                            if isinstance(_item, dict):
                                _tool = _item.get("tool")
                                _args = _item.get("args")
                                _keys = list(_args.keys()) if isinstance(_args, dict) else []
                                _summary.append({"tool": _tool, "args_keys": _keys})
                        except Exception:
                            continue
                    _rid = getattr(request, "run_id", None)
                    logger.info(
                        "ENV_APPLY: run_id=%s tool_calls=%d reward=%s done=%s summary=%s",
                        _rid,
                        len(_tc),
                        str(env_response.reward),
                        str(env_response.done),
                        _summary,
                    )
                    print(
                        f"[rollout] env apply run_id={_rid} tool_calls={len(_tc)} reward={env_response.reward} done={env_response.done} summary={_summary}",
                        flush=True,
                    )
                trajectory_steps.append(step)

                if env_response.reward is not None:
                    total_reward += env_response.reward

                # Update state
                current_obs = next_obs
                pending_tool_calls = None
                ops_executed += 1

                # Handle episode end
                if env_response.done:
                    if request.on_done == "reset":
                        # Reset environment
                        from .environment_routes import (
                            EnvResetRequest,
                            reset_environment,
                        )

                        reset_response = await reset_environment(EnvResetRequest(env_id=env_id))
                        current_obs = reset_response.observation
                    elif request.on_done == "terminate":
                        break

                if decision_open:
                    await tracing_context.end_decision()
                    decision_open = False

            else:
                logger.warning(f"Unknown op: {op}")

        if (
            last_policy_meta is not None
            and last_agent_response_ts is not None
            and "timing" in last_policy_meta
            and isinstance(last_policy_meta["timing"], dict)
            and "decision_ms" not in last_policy_meta["timing"]
        ):
            with contextlib.suppress(Exception):
                final_now = last_env_step_completed_ts or _time.perf_counter()
                final_decision_ms = max(0.0, (final_now - float(last_agent_response_ts)) * 1000.0)
                timing_final = last_policy_meta.setdefault("timing", {})
                timing_final["decision_ms"] = final_decision_ms
                if last_env_step_ms is not None:
                    timing_final.setdefault("env_step_ms", float(last_env_step_ms))
                    timing_final.setdefault(
                        "overhead_ms",
                        max(0.0, final_decision_ms - float(last_env_step_ms)),
                    )
                else:
                    timing_final.setdefault("overhead_ms", 0.0)
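        # Per-decision timing fields recorded in policy meta: decision_ms (wall time between agent
        # turns), env_step_ms (environment step duration), and overhead_ms = max(0, decision_ms - env_step_ms).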
        # Build trajectory
        trajectory = RolloutTrajectory(
            env_id=env_id,
            policy_id=policy_id,
            steps=trajectory_steps,
            final={"observation": _summarize_observation_for_storage(env_handle, current_obs)},
            length=len(trajectory_steps),
            decision_samples=decision_samples if step_rewards_active else None,
        )

        # Build metrics
        metrics = RolloutMetrics(
            episode_returns=[total_reward],
            mean_return=total_reward,
            num_steps=len(trajectory_steps),
            num_episodes=1,
        )

        # Environment-specific: Log summary if available
        try:
            # Check if this is a Wordle environment and use Wordle helpers (lazy import)
            wordle_wrapper_cls = None
            try:
                from .envs.wordle.environment import WordleEnvironmentWrapper
                from .envs.wordle.helpers import (
                    get_wordle_rollout_summary,
                    log_wordle_rollout_summary,
                )

                wordle_wrapper_cls = WordleEnvironmentWrapper
            except Exception:
                wordle_wrapper_cls = None  # type: ignore[assignment]
                get_wordle_rollout_summary = None  # type: ignore
                log_wordle_rollout_summary = None  # type: ignore

            is_wordle = wordle_wrapper_cls is not None and isinstance(
                env_handle.env,
                wordle_wrapper_cls,  # type: ignore[arg-type]
            )
            if is_wordle:
                # Convert trajectory steps to expected format
                formatted_steps = []
                for step in trajectory_steps:
                    formatted_steps.append({"tool_calls": step.tool_calls or []})

                if (
                    get_wordle_rollout_summary is not None
                    and log_wordle_rollout_summary is not None
                ):
                    summary = get_wordle_rollout_summary(formatted_steps, current_obs, env_handle)
                    log_wordle_rollout_summary(request.run_id, summary)
        except ImportError:
            # Wordle helpers not available, skip Wordle-specific logging
            pass
        except Exception as e:
            logger.warning(f"Failed to generate environment-specific summary: {e}")

        # Mark run as completed
        aborted = registry.is_run_aborted(request.run_id)
        if not aborted:
            registry.complete_run(request.run_id)
        if decision_open:
            await tracing_context.end_decision()
            decision_open = False
        if not finalized:
            session_trace = await tracing_context.finalize(
                total_reward=total_reward,
                achievement_state=prev_achievements,
                total_steps=len(trajectory_steps),
            )
            finalized = True
        trace_payload = tracing_context.build_trace_payload(session_trace)

        return RolloutResponse(
            run_id=request.run_id,
            trajectories=[trajectory],
            branches={},
            metrics=metrics,
            aborted=aborted,
            ops_executed=ops_executed,
            trace=trace_payload,
        )

    except Exception as e:
        logger.error(f"Rollout failed for run {request.run_id}: {e}")
        registry.abort_run(request.run_id)
        if decision_open:
            with contextlib.suppress(Exception):
                await tracing_context.end_decision()
            decision_open = False
        if not finalized:
            session_trace = None
            with contextlib.suppress(Exception):
                session_trace = await tracing_context.finalize(
                    total_reward=total_reward,
                    achievement_state=prev_achievements,
                    total_steps=len(trajectory_steps),
                )
            finalized = True
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Ensure any environment created for this rollout is terminated (no reuse across rollouts)
        try:
            if created_env_id:
                from .environment_routes import EnvTerminateRequest, terminate_environment

                await terminate_environment(EnvTerminateRequest(env_id=created_env_id))
                logger.info(
                    "ROLL_OUT: terminated environment env_id=%s seed=%s",
                    str(created_env_id),
                    str(env_seed_used) if env_seed_used is not None else "unknown",
                )
                # Verify removal from registry
                with contextlib.suppress(Exception):
                    _post = registry.get_env(created_env_id)
                    logger.info(
                        "ROLL_OUT: env_killed=%s (post_lookup=%s)",
                        str(_post is None),
                        str(_post),
                    )
        except Exception as _te:
            logger.warning(f"ROLL_OUT: failed to terminate environment {created_env_id}: {_te}")

        # Best-effort policy cleanup if we created one (avoid reuse across rollouts)
        with contextlib.suppress(Exception):
            if created_policy_id:
                from .policy_routes import PolicyTerminateRequest, terminate_policy

                await terminate_policy(PolicyTerminateRequest(policy_id=created_policy_id))
                logger.info("ROLL_OUT: terminated policy policy_id=%s", str(created_policy_id))

        if not finalized:
            session_trace = None
            with contextlib.suppress(Exception):
                session_trace = await tracing_context.finalize(
                    total_reward=total_reward,
                    achievement_state=prev_achievements,
                    total_steps=len(trajectory_steps),
                )
            finalized = True

        with contextlib.suppress(Exception):
            _clear_seed_side_effects()
            logger.info("ROLL_OUT: RNG seed terminated/cleared before conclusion")
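# Cleanup guarantees for the rollout route above: environments and policies created inside the
# request are terminated in the `finally` block, tracing is finalized exactly once (guarded by
# `finalized`), and global RNG seed side effects are cleared before the response is returned.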
@router.post("/run/abort", response_model=RunAbortResponse)
|
|
1837
|
+
async def abort_run(request: RunAbortRequest) -> RunAbortResponse:
|
|
1838
|
+
"""Abort a running rollout."""
|
|
1839
|
+
success = registry.abort_run(request.run_id)
|
|
1840
|
+
|
|
1841
|
+
if not success:
|
|
1842
|
+
raise HTTPException(
|
|
1843
|
+
status_code=404,
|
|
1844
|
+
detail=f"Run {request.run_id} not found",
|
|
1845
|
+
)
|
|
1846
|
+
|
|
1847
|
+
return RunAbortResponse(
|
|
1848
|
+
ok=True,
|
|
1849
|
+
run_id=request.run_id,
|
|
1850
|
+
)
|
|
1851
|
+
|
|
1852
|
+
|
|
1853
|
+
@router.get("/run/status/{run_id}", response_model=RunStatusResponse)
|
|
1854
|
+
async def get_run_status(run_id: str) -> RunStatusResponse:
|
|
1855
|
+
"""Get the status of a run."""
|
|
1856
|
+
run_handle = registry.get_run(run_id)
|
|
1857
|
+
|
|
1858
|
+
if not run_handle:
|
|
1859
|
+
raise HTTPException(
|
|
1860
|
+
status_code=404,
|
|
1861
|
+
detail=f"Run {run_id} not found",
|
|
1862
|
+
)
|
|
1863
|
+
|
|
1864
|
+
return RunStatusResponse(
|
|
1865
|
+
run_id=run_id,
|
|
1866
|
+
status=run_handle.status,
|
|
1867
|
+
started_at=run_handle.started_at,
|
|
1868
|
+
finished_at=run_handle.finished_at,
|
|
1869
|
+
)
|
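A minimal client sketch for the two endpoints above (hypothetical base URL and run id; assumes the router is mounted at the app root and that httpx is installed):

    import httpx

    BASE = "http://localhost:8000"  # hypothetical host/port

    # Abort a run by id, then read its status.
    resp = httpx.post(f"{BASE}/run/abort", json={"run_id": "run-123"})
    resp.raise_for_status()
    status = httpx.get(f"{BASE}/run/status/run-123").json()
    print(status["status"], status.get("finished_at"))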