synth-ai 0.2.9.dev7__py3-none-any.whl → 0.2.9.dev9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic.
- examples/__init__.py +16 -0
- examples/crafter_debug_render.py +8 -11
- examples/qwen_coder/README.md +102 -0
- examples/qwen_coder/_shared.py +113 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +61 -0
- examples/qwen_coder/configs/coder_lora_4b.toml +57 -0
- examples/qwen_coder/configs/coder_lora_small.toml +58 -0
- examples/qwen_coder/generate_dataset.py +98 -0
- examples/qwen_coder/infer_ft_smoke.py +64 -0
- examples/qwen_coder/infer_prod_proxy.py +73 -0
- examples/qwen_coder/infer_via_synth.py +87 -0
- examples/qwen_coder/scripts/infer_coder.sh +18 -0
- examples/qwen_coder/scripts/train_coder_30b.sh +21 -0
- examples/qwen_coder/sft_full_17b.py +103 -0
- examples/qwen_coder/sft_lora_30b.py +110 -0
- examples/qwen_coder/subset_jsonl.py +38 -0
- examples/qwen_coder/validate_jsonl.py +59 -0
- examples/rl/run_eval.py +36 -37
- examples/rl/run_rl_and_save.py +5 -5
- examples/rl/task_app/math_single_step.py +65 -43
- examples/rl/task_app/math_task_app.py +3 -3
- examples/sft/README.md +139 -0
- examples/sft/configs/crafter_fft_qwen0p6b.toml +44 -0
- examples/sft/configs/crafter_lora_qwen0p6b.toml +45 -0
- examples/sft/evaluate.py +117 -0
- examples/sft/export_dataset.py +117 -0
- examples/sft/generate_traces.py +162 -0
- examples/swe/__init__.py +12 -0
- examples/swe/task_app/README.md +105 -0
- examples/swe/task_app/__init__.py +2 -0
- examples/swe/task_app/grpo_swe_mini.py +571 -0
- examples/swe/task_app/grpo_swe_mini_task_app.py +136 -0
- examples/swe/task_app/hosted/README.md +173 -0
- examples/swe/task_app/hosted/__init__.py +5 -0
- examples/swe/task_app/hosted/branching.py +143 -0
- examples/swe/task_app/hosted/environment_routes.py +1289 -0
- examples/swe/task_app/hosted/envs/__init__.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/__init__.py +6 -0
- examples/swe/task_app/hosted/envs/crafter/app.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/environment.py +522 -0
- examples/swe/task_app/hosted/envs/crafter/policy.py +478 -0
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +108 -0
- examples/swe/task_app/hosted/envs/crafter/shared.py +305 -0
- examples/swe/task_app/hosted/envs/crafter/tools.py +47 -0
- examples/swe/task_app/hosted/envs/mini_swe/__init__.py +8 -0
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +1164 -0
- examples/swe/task_app/hosted/envs/mini_swe/policy.py +355 -0
- examples/swe/task_app/hosted/envs/mini_swe/shared.py +83 -0
- examples/swe/task_app/hosted/envs/mini_swe/tools.py +96 -0
- examples/swe/task_app/hosted/hosted_app.py +204 -0
- examples/swe/task_app/hosted/inference/__init__.py +5 -0
- examples/swe/task_app/hosted/inference/openai_client.py +618 -0
- examples/swe/task_app/hosted/main.py +100 -0
- examples/swe/task_app/hosted/policy_routes.py +1079 -0
- examples/swe/task_app/hosted/registry.py +195 -0
- examples/swe/task_app/hosted/rollout.py +1869 -0
- examples/swe/task_app/hosted/storage/__init__.py +5 -0
- examples/swe/task_app/hosted/storage/volume.py +211 -0
- examples/swe/task_app/hosted/test_agents.py +161 -0
- examples/swe/task_app/hosted/test_service.py +137 -0
- examples/swe/task_app/hosted/utils.py +62 -0
- examples/vlm/README.md +68 -0
- examples/vlm/configs/crafter_vlm_gpt4o.toml +44 -0
- examples/vlm/crafter_image_only_agent.py +207 -0
- examples/vlm/crafter_openai_vlm_agent.py +277 -0
- examples/vlm/filter_image_rows.py +63 -0
- examples/vlm/run_crafter_vlm_benchmark.py +316 -0
- examples/warming_up_to_rl/analyze_trace_db.py +5 -5
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +11 -1
- examples/warming_up_to_rl/export_trace_sft.py +78 -21
- examples/warming_up_to_rl/groq_test.py +4 -4
- examples/warming_up_to_rl/manage_secrets.py +13 -18
- examples/warming_up_to_rl/run_eval.py +42 -44
- examples/warming_up_to_rl/run_fft_and_save.py +11 -16
- examples/warming_up_to_rl/run_local_rollout.py +1 -3
- examples/warming_up_to_rl/run_local_rollout_modal.py +2 -4
- examples/warming_up_to_rl/run_local_rollout_parallel.py +1 -4
- examples/warming_up_to_rl/run_local_rollout_traced.py +3 -5
- examples/warming_up_to_rl/run_rl_and_save.py +5 -6
- examples/warming_up_to_rl/run_rollout_remote.py +8 -10
- examples/warming_up_to_rl/task_app/README.md +6 -2
- examples/warming_up_to_rl/task_app/grpo_crafter.py +234 -35
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +2 -3
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +9 -11
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +131 -114
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +101 -41
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +73 -51
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +14 -6
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +16 -16
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +32 -34
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +94 -31
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +0 -2
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +303 -203
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +21 -23
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +328 -225
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +13 -13
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +4 -3
- synth/__init__.py +14 -0
- synth_ai/__init__.py +26 -4
- synth_ai/api/models/supported.py +376 -0
- synth_ai/api/train/builders.py +128 -21
- synth_ai/api/train/cli.py +80 -64
- synth_ai/api/train/config_finder.py +7 -2
- synth_ai/api/train/env_resolver.py +1 -1
- synth_ai/api/train/pollers.py +2 -1
- synth_ai/api/train/supported_algos.py +139 -0
- synth_ai/api/train/task_app.py +1 -2
- synth_ai/api/train/utils.py +13 -44
- synth_ai/cli/__init__.py +8 -0
- synth_ai/cli/_modal_wrapper.py +28 -0
- synth_ai/cli/_typer_patch.py +49 -0
- synth_ai/cli/balance.py +1 -2
- synth_ai/cli/calc.py +1 -1
- synth_ai/cli/demo.py +2 -1
- synth_ai/cli/recent.py +2 -2
- synth_ai/cli/rl_demo.py +2 -1
- synth_ai/cli/root.py +11 -13
- synth_ai/cli/status.py +2 -2
- synth_ai/cli/task_apps.py +529 -179
- synth_ai/cli/traces.py +6 -4
- synth_ai/cli/watch.py +12 -18
- synth_ai/demo_registry.py +1 -1
- synth_ai/demos/core/cli.py +36 -43
- synth_ai/demos/demo_task_apps/__init__.py +3 -3
- synth_ai/demos/demo_task_apps/core.py +17 -25
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +3 -4
- synth_ai/demos/demo_task_apps/math/app.py +2 -1
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +3 -4
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +16 -18
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -1
- synth_ai/environments/examples/crafter_classic/environment.py +76 -1
- synth_ai/environments/reproducibility/tree.py +2 -5
- synth_ai/environments/service/app.py +11 -12
- synth_ai/environments/service/core_routes.py +4 -7
- synth_ai/environments/stateful/engine.py +1 -1
- synth_ai/environments/tasks/core.py +1 -0
- synth_ai/environments/tasks/filters.py +5 -6
- synth_ai/environments/tasks/utils.py +4 -5
- synth_ai/handshake.py +9 -9
- synth_ai/http.py +1 -1
- synth_ai/http_client.py +18 -10
- synth_ai/inference/client.py +15 -5
- synth_ai/jobs/client.py +78 -83
- synth_ai/learning/__init__.py +41 -6
- synth_ai/learning/algorithms.py +14 -0
- synth_ai/learning/client.py +91 -24
- synth_ai/learning/config.py +2 -38
- synth_ai/learning/ft_client.py +4 -59
- synth_ai/learning/health.py +5 -6
- synth_ai/learning/jobs.py +31 -47
- synth_ai/{rl → learning/rl}/__init__.py +14 -4
- synth_ai/learning/rl/client.py +267 -0
- synth_ai/learning/rl/config.py +31 -0
- synth_ai/{rl → learning/rl}/contracts.py +5 -8
- synth_ai/{rl → learning/rl}/env_keys.py +39 -15
- synth_ai/learning/rl/secrets.py +13 -0
- synth_ai/learning/rl_client.py +2 -281
- synth_ai/learning/sft/__init__.py +29 -0
- synth_ai/learning/sft/client.py +68 -0
- synth_ai/learning/sft/config.py +270 -0
- synth_ai/learning/sft/data.py +295 -0
- synth_ai/learning/sse.py +25 -24
- synth_ai/learning/validators.py +25 -28
- synth_ai/lm/__init__.py +21 -47
- synth_ai/main.py +6 -0
- synth_ai/task/__init__.py +25 -27
- synth_ai/task/apps/__init__.py +7 -8
- synth_ai/task/auth.py +8 -8
- synth_ai/task/client.py +14 -14
- synth_ai/task/contracts.py +36 -35
- synth_ai/task/datasets.py +6 -5
- synth_ai/task/errors.py +10 -10
- synth_ai/task/health.py +17 -9
- synth_ai/task/json.py +58 -23
- synth_ai/task/proxy.py +13 -9
- synth_ai/task/rubrics.py +16 -15
- synth_ai/task/server.py +12 -12
- synth_ai/task/tracing_utils.py +4 -4
- synth_ai/task/vendors.py +5 -6
- synth_ai/tracing_v3/__init__.py +2 -0
- synth_ai/tracing_v3/abstractions.py +21 -4
- synth_ai/tracing_v3/decorators.py +18 -16
- synth_ai/tracing_v3/hooks.py +5 -5
- synth_ai/tracing_v3/llm_call_record_helpers.py +6 -6
- synth_ai/tracing_v3/session_tracer.py +40 -14
- synth_ai/tracing_v3/storage/base.py +85 -0
- synth_ai/tracing_v3/storage/config.py +21 -8
- synth_ai/tracing_v3/storage/factory.py +10 -7
- synth_ai/tracing_v3/storage/utils.py +4 -2
- synth_ai/tracing_v3/turso/daemon.py +7 -2
- synth_ai/tracing_v3/turso/models.py +2 -2
- synth_ai/tracing_v3/turso/native_manager.py +1173 -0
- synth_ai/tracing_v3/utils.py +4 -4
- synth_ai/v0/api/__init__.py +8 -0
- synth_ai/v0/api/models/__init__.py +8 -0
- synth_ai/v0/api/models/supported.py +8 -0
- synth_ai/v0/config/__init__.py +15 -0
- synth_ai/v0/config/base_url.py +12 -0
- synth_ai/v0/lm/__init__.py +51 -0
- synth_ai/{lm → v0/lm}/caching/ephemeral.py +2 -2
- synth_ai/{lm → v0/lm}/caching/handler.py +4 -4
- synth_ai/{lm → v0/lm}/caching/initialize.py +1 -1
- synth_ai/{lm → v0/lm}/caching/persistent.py +1 -1
- synth_ai/{lm → v0/lm}/config.py +6 -1
- synth_ai/{lm → v0/lm}/core/all.py +9 -9
- synth_ai/{lm → v0/lm}/core/main.py +6 -6
- synth_ai/{lm → v0/lm}/core/main_v3.py +10 -10
- synth_ai/{lm → v0/lm}/core/synth_models.py +2 -14
- synth_ai/{lm → v0/lm}/core/vendor_clients.py +2 -2
- synth_ai/{lm → v0/lm}/overrides.py +2 -2
- synth_ai/{lm → v0/lm}/provider_support/anthropic.py +4 -4
- synth_ai/{lm → v0/lm}/provider_support/openai.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/handler.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/rehabilitate.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/core/anthropic_api.py +9 -9
- synth_ai/{lm → v0/lm}/vendors/core/gemini_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/mistral_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/openai_api.py +10 -10
- synth_ai/{lm → v0/lm}/vendors/openai_standard.py +8 -8
- synth_ai/{lm → v0/lm}/vendors/openai_standard_responses.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/custom_endpoint.py +3 -3
- synth_ai/{lm → v0/lm}/vendors/supported/deepseek.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/grok.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/groq.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/ollama.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/openrouter.py +3 -3
- synth_ai/{lm → v0/lm}/vendors/supported/together.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/synth_client.py +1 -1
- synth_ai/v0/tracing_v3/__init__.py +10 -0
- synth_ai/v0/tracing_v3/abstractions.py +3 -0
- synth_ai/v0/tracing_v3/decorators.py +3 -0
- synth_ai/v0/tracing_v3/llm_call_record_helpers.py +3 -0
- synth_ai/v0/tracing_v3/session_tracer.py +3 -0
- synth_ai-0.2.9.dev9.dist-info/METADATA +191 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev9.dist-info}/RECORD +268 -238
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev9.dist-info}/top_level.txt +1 -0
- examples/common_old/backend.py +0 -20
- examples/evals_old/README.md +0 -98
- examples/evals_old/__init__.py +0 -6
- examples/evals_old/compare_models.py +0 -1038
- examples/evals_old/example_log.md +0 -145
- examples/evals_old/run_demo.sh +0 -126
- examples/evals_old/trace_analysis.py +0 -270
- examples/finetuning_old/_backup_synth_qwen/config.toml +0 -29
- examples/finetuning_old/_backup_synth_qwen/example_log.md +0 -324
- examples/finetuning_old/_backup_synth_qwen/filter_traces.py +0 -60
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +0 -243
- examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +0 -109
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +0 -1924
- examples/finetuning_old/_backup_synth_qwen/readme.md +0 -49
- examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +0 -114
- examples/finetuning_old/_backup_synth_qwen/run_demo.sh +0 -195
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +0 -119
- examples/finetuning_old/synth_qwen_v1/README.md +0 -68
- examples/finetuning_old/synth_qwen_v1/filter_traces.py +0 -60
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +0 -243
- examples/finetuning_old/synth_qwen_v1/finetune.py +0 -46
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +0 -71
- examples/finetuning_old/synth_qwen_v1/infer.py +0 -36
- examples/finetuning_old/synth_qwen_v1/poll.py +0 -46
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +0 -35
- examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +0 -109
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +0 -1933
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +0 -210
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +0 -237
- examples/finetuning_old/synth_qwen_v1/upload_data.py +0 -34
- examples/finetuning_old/synth_qwen_v1/util.py +0 -152
- examples/rl_old/task_app.py +0 -1131
- examples/warming_up_to_rl/old/event_rewards.md +0 -234
- examples/warming_up_to_rl/old/notes.md +0 -73
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +0 -738
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +0 -580
- synth_ai/experimental/synth_oss.py +0 -445
- synth_ai/learning/filtering.py +0 -0
- synth_ai/learning/offline/dpo.py +0 -0
- synth_ai/learning/offline/providers.py +0 -7
- synth_ai/learning/offline/sft.py +0 -0
- synth_ai/learning/offline/shared.py +0 -0
- synth_ai/learning/online/grpo.py +0 -0
- synth_ai/learning/online/irft.py +0 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +0 -168
- synth_ai/learning/prompts/gepa.py +0 -0
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +0 -211
- synth_ai/learning/prompts/mipro.py +0 -289
- synth_ai/learning/prompts/random_search.py +0 -249
- synth_ai/learning/prompts/run_mipro_banking77.py +0 -172
- synth_ai/learning/prompts/run_random_search_banking77.py +0 -329
- synth_ai/rl/secrets.py +0 -19
- synth_ai/scripts/verify_rewards.py +0 -100
- synth_ai/tracing/__init__.py +0 -30
- synth_ai/tracing_v1/__init__.py +0 -33
- synth_ai/tracing_v3/turso/__init__.py +0 -25
- synth_ai/tracing_v3/turso/manager.py +0 -838
- synth_ai/zyk/__init__.py +0 -30
- synth_ai-0.2.9.dev7.dist-info/METADATA +0 -131
- /synth_ai/{lm → v0/lm}/caching/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/dbs.py +0 -0
- /synth_ai/{lm → v0/lm}/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/core/exceptions.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/monitor.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/statefulness.py +0 -0
- /synth_ai/{lm → v0/lm}/injection.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/suppress_logging.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/inject.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/base.py +0 -0
- /synth_ai/{lm → v0/lm}/unified_interface.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/base.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/synth_dev_api.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/ollama.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/retries.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/supported/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/warmup.py +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev9.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev9.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev9.dist-info}/licenses/LICENSE +0 -0
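A recurring mechanical change in the rollout-module diff shown below is the replacement of best-effort try/except Exception: pass blocks with contextlib.suppress. A minimal sketch of the pattern (illustrative values only, not code copied from the package):

import contextlib

# Before: optional-dependency seeding guarded by an explicit try/except
try:
    import numpy as _np
    _np.random.seed(42)
except Exception:
    pass

# After: the same failures are suppressed, with flatter nesting
with contextlib.suppress(Exception):
    import numpy as _np
    _np.random.seed(42)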
@@ -1,20 +1,20 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
import
|
|
3
|
+
import contextlib
|
|
4
4
|
import json
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
import time as _time
|
|
5
8
|
from datetime import datetime
|
|
6
|
-
from
|
|
7
|
-
from typing import Any, Dict, List, Optional, Tuple
|
|
9
|
+
from typing import Any
|
|
8
10
|
|
|
9
11
|
from fastapi import APIRouter, HTTPException, Request, status
|
|
10
|
-
import os
|
|
11
|
-
import time as _time
|
|
12
12
|
from pydantic import BaseModel
|
|
13
13
|
from synth_ai.lm.vendors.base import BaseLMResponse
|
|
14
|
-
from synth_ai.
|
|
14
|
+
from synth_ai.task.tracing_utils import unique_sft_path
|
|
15
15
|
from synth_ai.tracing_v3.abstractions import EnvironmentEvent, LMCAISEvent, TimeRecord
|
|
16
16
|
from synth_ai.tracing_v3.llm_call_record_helpers import create_llm_call_record_from_response
|
|
17
|
-
from synth_ai.
|
|
17
|
+
from synth_ai.tracing_v3.session_tracer import SessionTracer
|
|
18
18
|
|
|
19
19
|
from .registry import registry
|
|
20
20
|
|
|
@@ -22,48 +22,38 @@ logger = logging.getLogger(__name__)
|
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
# --- Seeding utilities (robust, optional deps) ---
|
|
25
|
-
def _set_global_seed(seed_value: int) ->
|
|
25
|
+
def _set_global_seed(seed_value: int) -> dict[str, Any]:
|
|
26
26
|
"""Set global RNG seeds across common libraries; return details for logging/restoration.
|
|
27
27
|
|
|
28
28
|
Returns a dict containing which libraries were seeded and prior states if obtainable.
|
|
29
29
|
"""
|
|
30
|
-
seeded:
|
|
31
|
-
|
|
30
|
+
seeded: dict[str, Any] = {"seed": int(seed_value), "libs": []}
|
|
31
|
+
with contextlib.suppress(Exception):
|
|
32
32
|
import random as _random # type: ignore
|
|
33
33
|
|
|
34
34
|
_random.seed(seed_value)
|
|
35
35
|
seeded["libs"].append("random")
|
|
36
|
-
|
|
37
|
-
pass
|
|
38
|
-
try:
|
|
36
|
+
with contextlib.suppress(Exception):
|
|
39
37
|
import numpy as _np # type: ignore
|
|
40
38
|
|
|
41
39
|
_np.random.seed(seed_value)
|
|
42
40
|
seeded["libs"].append("numpy")
|
|
43
|
-
|
|
44
|
-
pass
|
|
45
|
-
try:
|
|
41
|
+
with contextlib.suppress(Exception):
|
|
46
42
|
import torch as _torch # type: ignore
|
|
47
43
|
|
|
48
44
|
if hasattr(_torch, "manual_seed"):
|
|
49
45
|
_torch.manual_seed(seed_value)
|
|
50
46
|
seeded["libs"].append("torch")
|
|
51
47
|
# Make CUDA deterministic if present (best-effort)
|
|
52
|
-
|
|
48
|
+
with contextlib.suppress(Exception):
|
|
53
49
|
if getattr(_torch, "cuda", None) and _torch.cuda.is_available():
|
|
54
50
|
_torch.cuda.manual_seed_all(seed_value)
|
|
55
51
|
seeded.setdefault("cuda", True)
|
|
56
|
-
except Exception:
|
|
57
|
-
pass
|
|
58
52
|
# CUDNN deterministic flags (optional)
|
|
59
|
-
|
|
53
|
+
with contextlib.suppress(Exception):
|
|
60
54
|
if getattr(_torch, "backends", None) and getattr(_torch.backends, "cudnn", None):
|
|
61
55
|
_torch.backends.cudnn.deterministic = True # type: ignore[attr-defined]
|
|
62
56
|
_torch.backends.cudnn.benchmark = False # type: ignore[attr-defined]
|
|
63
|
-
except Exception:
|
|
64
|
-
pass
|
|
65
|
-
except Exception:
|
|
66
|
-
pass
|
|
67
57
|
return seeded
|
|
68
58
|
|
|
69
59
|
|
|
@@ -71,39 +61,35 @@ def _clear_seed_side_effects() -> None:
|
|
|
71
61
|
"""Best-effort cleanup to avoid global deterministic side-effects between requests."""
|
|
72
62
|
# We cannot truly restore prior RNG states without capturing them; we just avoid
|
|
73
63
|
# leaving aggressive deterministic flags enabled where it matters.
|
|
74
|
-
|
|
64
|
+
with contextlib.suppress(Exception):
|
|
75
65
|
import torch as _torch # type: ignore
|
|
76
66
|
|
|
77
|
-
|
|
67
|
+
with contextlib.suppress(Exception):
|
|
78
68
|
if getattr(_torch, "backends", None) and getattr(_torch.backends, "cudnn", None):
|
|
79
69
|
# Re-enable cudnn.benchmark default True only if it was True; safest is False -> leave as is.
|
|
80
70
|
# We'll keep deterministic False to avoid global impact; benchmark left False for stability.
|
|
81
71
|
_torch.backends.cudnn.deterministic = False # type: ignore[attr-defined]
|
|
82
|
-
except Exception:
|
|
83
|
-
pass
|
|
84
|
-
except Exception:
|
|
85
|
-
pass
|
|
86
72
|
|
|
87
73
|
|
|
88
74
|
router = APIRouter()
|
|
89
75
|
|
|
90
76
|
|
|
91
77
|
class RolloutEnvSpec(BaseModel):
|
|
92
|
-
env_id:
|
|
93
|
-
env_name:
|
|
94
|
-
config:
|
|
95
|
-
seed:
|
|
78
|
+
env_id: str | None = None
|
|
79
|
+
env_name: str | None = None
|
|
80
|
+
config: dict[str, Any] = {}
|
|
81
|
+
seed: int | None = None
|
|
96
82
|
|
|
97
83
|
|
|
98
84
|
class RolloutPolicySpec(BaseModel):
|
|
99
|
-
policy_id:
|
|
100
|
-
policy_name:
|
|
101
|
-
config:
|
|
85
|
+
policy_id: str | None = None
|
|
86
|
+
policy_name: str | None = None
|
|
87
|
+
config: dict[str, Any] = {}
|
|
102
88
|
|
|
103
89
|
|
|
104
90
|
class RolloutBranchConfig(BaseModel):
|
|
105
91
|
branch_every_n_steps: int = 0
|
|
106
|
-
branch_on_condition:
|
|
92
|
+
branch_on_condition: str | None = None
|
|
107
93
|
max_branches: int = 0
|
|
108
94
|
branch_policy: bool = False
|
|
109
95
|
branch_env: bool = False
|
|
@@ -126,43 +112,43 @@ class RolloutRequest(BaseModel):
|
|
|
126
112
|
run_id: str
|
|
127
113
|
env: RolloutEnvSpec
|
|
128
114
|
policy: RolloutPolicySpec
|
|
129
|
-
ops:
|
|
115
|
+
ops: list[str] # ["agent", "env", ...]
|
|
130
116
|
record: RolloutRecordConfig = RolloutRecordConfig()
|
|
131
117
|
on_done: str = "reset" # "reset" | "terminate"
|
|
132
|
-
branch:
|
|
118
|
+
branch: RolloutBranchConfig | None = None
|
|
133
119
|
safety: RolloutSafetyConfig = RolloutSafetyConfig()
|
|
134
120
|
# Optional run/session context
|
|
135
|
-
training_session_id:
|
|
136
|
-
synth_base_url:
|
|
121
|
+
training_session_id: str | None = None
|
|
122
|
+
synth_base_url: str | None = None
|
|
137
123
|
|
|
138
124
|
|
|
139
125
|
class RolloutStep(BaseModel):
|
|
140
|
-
obs:
|
|
141
|
-
tool_calls:
|
|
142
|
-
reward:
|
|
126
|
+
obs: dict[str, Any]
|
|
127
|
+
tool_calls: list[dict[str, Any]]
|
|
128
|
+
reward: float | None = None
|
|
143
129
|
done: bool = False
|
|
144
|
-
truncated:
|
|
145
|
-
logprob:
|
|
146
|
-
value:
|
|
147
|
-
info:
|
|
130
|
+
truncated: bool | None = None
|
|
131
|
+
logprob: float | None = None
|
|
132
|
+
value: float | None = None
|
|
133
|
+
info: dict[str, Any] | None = None
|
|
148
134
|
|
|
149
135
|
|
|
150
136
|
class RolloutTrajectory(BaseModel):
|
|
151
137
|
env_id: str
|
|
152
138
|
policy_id: str
|
|
153
|
-
steps:
|
|
154
|
-
final:
|
|
139
|
+
steps: list[RolloutStep]
|
|
140
|
+
final: dict[str, Any] | None = None
|
|
155
141
|
length: int
|
|
156
|
-
decision_samples:
|
|
142
|
+
decision_samples: list[dict[str, Any]] | None = None
|
|
157
143
|
|
|
158
144
|
|
|
159
145
|
def compute_stepwise_reward(
|
|
160
|
-
prev_achievements:
|
|
161
|
-
new_achievements:
|
|
146
|
+
prev_achievements: dict[str, bool],
|
|
147
|
+
new_achievements: dict[str, bool],
|
|
162
148
|
decision_index: int,
|
|
163
|
-
actions_summary:
|
|
149
|
+
actions_summary: list[dict[str, Any]],
|
|
164
150
|
indicator_lambda: float,
|
|
165
|
-
) ->
|
|
151
|
+
) -> tuple[dict[str, Any], dict[str, Any], dict[str, float]]:
|
|
166
152
|
"""Compute stepwise reward metadata given achievement states before/after a decision."""
|
|
167
153
|
|
|
168
154
|
prev_map = prev_achievements or {}
|
|
@@ -193,7 +179,7 @@ def compute_stepwise_reward(
|
|
|
193
179
|
|
|
194
180
|
|
|
195
181
|
class RolloutMetrics(BaseModel):
|
|
196
|
-
episode_returns:
|
|
182
|
+
episode_returns: list[float]
|
|
197
183
|
mean_return: float
|
|
198
184
|
num_steps: int
|
|
199
185
|
num_episodes: int = 0
|
|
@@ -201,12 +187,12 @@ class RolloutMetrics(BaseModel):
|
|
|
201
187
|
|
|
202
188
|
class RolloutResponse(BaseModel):
|
|
203
189
|
run_id: str
|
|
204
|
-
trajectories:
|
|
205
|
-
branches:
|
|
190
|
+
trajectories: list[RolloutTrajectory]
|
|
191
|
+
branches: dict[str, list[str]] = {}
|
|
206
192
|
metrics: RolloutMetrics
|
|
207
193
|
aborted: bool = False
|
|
208
194
|
ops_executed: int = 0
|
|
209
|
-
trace:
|
|
195
|
+
trace: dict[str, Any] | None = None
|
|
210
196
|
|
|
211
197
|
|
|
212
198
|
class RolloutTracingContext:
|
|
@@ -230,6 +216,8 @@ class RolloutTracingContext:
|
|
|
230
216
|
self.sft_records: list[dict[str, Any]] = []
|
|
231
217
|
self.latest_system_messages: list[str] = []
|
|
232
218
|
self.latest_user_messages: list[str] = []
|
|
219
|
+
self.latest_system_prompt_content: list[Any] = []
|
|
220
|
+
self.latest_user_prompt_content: list[Any] = []
|
|
233
221
|
self.trace_format = (
|
|
234
222
|
getattr(request.record, "trace_format", "compact") or "compact"
|
|
235
223
|
).lower()
|
|
@@ -298,26 +286,32 @@ class RolloutTracingContext:
|
|
|
298
286
|
|
|
299
287
|
async def record_policy_prompts(
|
|
300
288
|
self,
|
|
301
|
-
system_messages: list[
|
|
302
|
-
user_messages: list[
|
|
289
|
+
system_messages: list[Any],
|
|
290
|
+
user_messages: list[Any],
|
|
303
291
|
) -> None:
|
|
304
|
-
self.latest_system_messages =
|
|
305
|
-
self.latest_user_messages =
|
|
292
|
+
self.latest_system_messages = [self._prompt_text(entry) for entry in system_messages]
|
|
293
|
+
self.latest_user_messages = [self._prompt_text(entry) for entry in user_messages]
|
|
294
|
+
self.latest_system_prompt_content = [
|
|
295
|
+
self._prompt_content(entry, role="system") for entry in system_messages
|
|
296
|
+
]
|
|
297
|
+
self.latest_user_prompt_content = [
|
|
298
|
+
self._prompt_content(entry, role="user") for entry in user_messages
|
|
299
|
+
]
|
|
306
300
|
if not self.enabled or self.tracer is None:
|
|
307
301
|
return
|
|
308
|
-
for
|
|
302
|
+
for entry in system_messages:
|
|
309
303
|
try:
|
|
310
304
|
await self.tracer.record_message(
|
|
311
|
-
content=
|
|
305
|
+
content=self._prompt_payload(entry, role="system"),
|
|
312
306
|
message_type="policy_system_prompt",
|
|
313
307
|
metadata=self._message_metadata(),
|
|
314
308
|
)
|
|
315
309
|
except Exception as exc:
|
|
316
310
|
logger.debug("TRACING_SYSTEM_MSG_FAIL: %s", exc)
|
|
317
|
-
for
|
|
311
|
+
for entry in user_messages:
|
|
318
312
|
try:
|
|
319
313
|
await self.tracer.record_message(
|
|
320
|
-
content=
|
|
314
|
+
content=self._prompt_payload(entry, role="user"),
|
|
321
315
|
message_type="policy_user_prompt",
|
|
322
316
|
metadata=self._message_metadata(),
|
|
323
317
|
)
|
|
@@ -339,6 +333,49 @@ class RolloutTracingContext:
|
|
|
339
333
|
return ""
|
|
340
334
|
return str(content)
|
|
341
335
|
|
|
336
|
+
def _prompt_text(self, entry: Any) -> str:
|
|
337
|
+
if isinstance(entry, dict):
|
|
338
|
+
text = entry.get("text")
|
|
339
|
+
if isinstance(text, str):
|
|
340
|
+
return text
|
|
341
|
+
content = entry.get("content")
|
|
342
|
+
return self._content_to_text(content)
|
|
343
|
+
return self._content_to_text(entry)
|
|
344
|
+
|
|
345
|
+
def _prompt_payload(self, entry: Any, *, role: str) -> dict[str, Any]:
|
|
346
|
+
if isinstance(entry, dict):
|
|
347
|
+
payload = dict(entry)
|
|
348
|
+
payload.setdefault("role", role)
|
|
349
|
+
return payload
|
|
350
|
+
return {
|
|
351
|
+
"role": role,
|
|
352
|
+
"text": self._prompt_text(entry),
|
|
353
|
+
"content": entry,
|
|
354
|
+
}
|
|
355
|
+
|
|
356
|
+
def _prompt_content(self, entry: Any, *, role: str) -> Any:
|
|
357
|
+
payload = self._prompt_payload(entry, role=role)
|
|
358
|
+
return payload.get("content", payload.get("text"))
|
|
359
|
+
|
|
360
|
+
def _content_has_image(self, content: Any) -> bool:
|
|
361
|
+
if isinstance(content, list):
|
|
362
|
+
return any(
|
|
363
|
+
isinstance(seg, dict)
|
|
364
|
+
and seg.get("type") in {"image", "image_url"}
|
|
365
|
+
for seg in content
|
|
366
|
+
)
|
|
367
|
+
if isinstance(content, dict):
|
|
368
|
+
if content.get("type") in {"image", "image_url"}:
|
|
369
|
+
return True
|
|
370
|
+
inner = content.get("content")
|
|
371
|
+
if isinstance(inner, list):
|
|
372
|
+
return any(
|
|
373
|
+
isinstance(seg, dict)
|
|
374
|
+
and seg.get("type") in {"image", "image_url"}
|
|
375
|
+
for seg in inner
|
|
376
|
+
)
|
|
377
|
+
return False
|
|
378
|
+
|
|
342
379
|
def _safe_json(self, payload: Any, limit: int = 4000) -> str:
|
|
343
380
|
try:
|
|
344
381
|
text = json.dumps(payload, ensure_ascii=False)
|
|
@@ -464,21 +501,44 @@ class RolloutTracingContext:
|
|
|
464
501
|
)
|
|
465
502
|
|
|
466
503
|
if self.sft_output_dir is not None:
|
|
504
|
+
assistant_structured = assistant_content if assistant_content is not None else ""
|
|
467
505
|
assistant_text = self._content_to_text(assistant_content)
|
|
506
|
+
dialogue_structured: list[dict[str, Any]] = []
|
|
507
|
+
for content in self.latest_system_prompt_content:
|
|
508
|
+
if content is None:
|
|
509
|
+
continue
|
|
510
|
+
dialogue_structured.append({"role": "system", "content": content})
|
|
511
|
+
for content in self.latest_user_prompt_content:
|
|
512
|
+
if content is None:
|
|
513
|
+
continue
|
|
514
|
+
dialogue_structured.append({"role": "user", "content": content})
|
|
515
|
+
dialogue_text = (
|
|
516
|
+
[{"role": "system", "content": s} for s in self.latest_system_messages]
|
|
517
|
+
+ [{"role": "user", "content": u} for u in self.latest_user_messages]
|
|
518
|
+
)
|
|
519
|
+
user_has_image = any(
|
|
520
|
+
self._content_has_image(content) for content in self.latest_user_prompt_content
|
|
521
|
+
)
|
|
522
|
+
assistant_has_image = self._content_has_image(assistant_structured)
|
|
468
523
|
record = {
|
|
469
524
|
"run_id": self.run_id,
|
|
470
525
|
"turn": self.current_turn,
|
|
471
526
|
"model": model_name,
|
|
472
527
|
"provider": provider,
|
|
473
|
-
"dialogue":
|
|
474
|
-
|
|
475
|
-
+ [{"role": "user", "content": u} for u in self.latest_user_messages]
|
|
476
|
-
),
|
|
528
|
+
"dialogue": dialogue_structured,
|
|
529
|
+
"dialogue_text": dialogue_text,
|
|
477
530
|
"assistant": {
|
|
478
|
-
"content":
|
|
531
|
+
"content": assistant_structured,
|
|
532
|
+
"content_text": assistant_text,
|
|
479
533
|
"tool_calls": assistant_message.get("tool_calls")
|
|
480
534
|
if isinstance(assistant_message, dict)
|
|
481
535
|
else [],
|
|
536
|
+
"has_image": assistant_has_image,
|
|
537
|
+
},
|
|
538
|
+
"metadata": {
|
|
539
|
+
"user_has_image": user_has_image,
|
|
540
|
+
"assistant_has_image": assistant_has_image,
|
|
541
|
+
"has_image": user_has_image or assistant_has_image,
|
|
482
542
|
},
|
|
483
543
|
"timestamp": datetime.utcnow().isoformat(),
|
|
484
544
|
}
|
|
@@ -488,10 +548,10 @@ class RolloutTracingContext:
|
|
|
488
548
|
self,
|
|
489
549
|
*,
|
|
490
550
|
env_handle: Any,
|
|
491
|
-
prev_obs:
|
|
551
|
+
prev_obs: dict[str, Any] | None,
|
|
492
552
|
env_response: Any,
|
|
493
|
-
next_obs:
|
|
494
|
-
metadata:
|
|
553
|
+
next_obs: dict[str, Any] | None,
|
|
554
|
+
metadata: dict[str, Any] | None = None,
|
|
495
555
|
) -> int | None:
|
|
496
556
|
if not self.enabled or self.tracer is None:
|
|
497
557
|
return None
|
|
@@ -540,7 +600,7 @@ class RolloutTracingContext:
|
|
|
540
600
|
self,
|
|
541
601
|
*,
|
|
542
602
|
event_id: int | None,
|
|
543
|
-
decision_meta:
|
|
603
|
+
decision_meta: dict[str, Any] | None,
|
|
544
604
|
) -> None:
|
|
545
605
|
decision_meta = decision_meta or {}
|
|
546
606
|
ach_delta = int(decision_meta.get("ach_delta", 0))
|
|
@@ -588,7 +648,7 @@ class RolloutTracingContext:
|
|
|
588
648
|
self,
|
|
589
649
|
*,
|
|
590
650
|
total_reward: float,
|
|
591
|
-
achievement_state:
|
|
651
|
+
achievement_state: dict[str, bool] | None,
|
|
592
652
|
total_steps: int,
|
|
593
653
|
) -> Any:
|
|
594
654
|
final_achievements = [key for key, val in (achievement_state or {}).items() if val]
|
|
@@ -610,10 +670,8 @@ class RolloutTracingContext:
|
|
|
610
670
|
except Exception as exc:
|
|
611
671
|
logger.debug("TRACING_END_SESSION_FAIL: %s", exc)
|
|
612
672
|
self.session_trace = None
|
|
613
|
-
|
|
673
|
+
with contextlib.suppress(Exception):
|
|
614
674
|
await self.tracer.close()
|
|
615
|
-
except Exception:
|
|
616
|
-
pass
|
|
617
675
|
|
|
618
676
|
if self.sft_records and self.sft_output_dir:
|
|
619
677
|
self.write_sft_records()
|
|
@@ -639,7 +697,7 @@ class RolloutTracingContext:
|
|
|
639
697
|
finally:
|
|
640
698
|
self.sft_records.clear()
|
|
641
699
|
|
|
642
|
-
def build_trace_payload(self, session_trace: Any) ->
|
|
700
|
+
def build_trace_payload(self, session_trace: Any) -> dict[str, Any] | None:
|
|
643
701
|
if not self.return_trace or session_trace is None:
|
|
644
702
|
return None
|
|
645
703
|
if self.trace_format == "full":
|
|
@@ -660,32 +718,33 @@ class RolloutTracingContext:
|
|
|
660
718
|
|
|
661
719
|
|
|
662
720
|
def _summarize_observation_for_storage(
|
|
663
|
-
env_handle: Any, observation:
|
|
664
|
-
) ->
|
|
721
|
+
env_handle: Any, observation: dict[str, Any]
|
|
722
|
+
) -> dict[str, Any]:
|
|
665
723
|
"""Return a compact dict for trajectory storage instead of the raw observation.
|
|
666
724
|
|
|
667
725
|
- For Crafter, use the same summary used for the policy user prompt
|
|
668
726
|
- For others, keep a minimal subset or plain text preview
|
|
669
727
|
"""
|
|
670
728
|
# Try Crafter-specific formatter
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
729
|
+
crafter_wrapper = None
|
|
730
|
+
with contextlib.suppress(Exception):
|
|
731
|
+
from .envs.crafter.environment import (
|
|
732
|
+
CrafterEnvironmentWrapper as _CrafterWrapper, # type: ignore
|
|
733
|
+
)
|
|
734
|
+
|
|
735
|
+
crafter_wrapper = _CrafterWrapper # type: ignore[assignment]
|
|
675
736
|
|
|
676
|
-
if
|
|
677
|
-
getattr(env_handle, "env", None),
|
|
737
|
+
if crafter_wrapper is not None and isinstance(
|
|
738
|
+
getattr(env_handle, "env", None), crafter_wrapper
|
|
678
739
|
):
|
|
679
|
-
|
|
740
|
+
with contextlib.suppress(Exception):
|
|
680
741
|
from .envs.crafter.shared import format_observation as _fmt # type: ignore
|
|
681
742
|
|
|
682
743
|
text = _fmt(observation or {})
|
|
683
744
|
return {"text": text}
|
|
684
|
-
except Exception:
|
|
685
|
-
pass
|
|
686
745
|
|
|
687
746
|
# Generic fallback: extract a few small fields if present; avoid huge arrays
|
|
688
|
-
|
|
747
|
+
with contextlib.suppress(Exception):
|
|
689
748
|
inv = observation.get("inventory") if isinstance(observation, dict) else None
|
|
690
749
|
ach = observation.get("achievements_status") if isinstance(observation, dict) else None
|
|
691
750
|
pos = observation.get("player_position") if isinstance(observation, dict) else None
|
|
@@ -695,16 +754,14 @@ def _summarize_observation_for_storage(
|
|
|
695
754
|
summary = {
|
|
696
755
|
"position": pos,
|
|
697
756
|
"health": health,
|
|
698
|
-
"inventory_keys": sorted(
|
|
757
|
+
"inventory_keys": sorted(k for k, v in (inv or {}).items() if v)[:10]
|
|
699
758
|
if isinstance(inv, dict)
|
|
700
759
|
else None,
|
|
701
|
-
"achievements_unlocked": sorted(
|
|
760
|
+
"achievements_unlocked": sorted(k for k, v in (ach or {}).items() if v)[:10]
|
|
702
761
|
if isinstance(ach, dict)
|
|
703
762
|
else None,
|
|
704
763
|
}
|
|
705
764
|
return {"text": json.dumps(summary, ensure_ascii=False)}
|
|
706
|
-
except Exception:
|
|
707
|
-
pass
|
|
708
765
|
|
|
709
766
|
# Last resort: plain string preview
|
|
710
767
|
try:
|
|
@@ -726,7 +783,7 @@ class RunStatusResponse(BaseModel):
|
|
|
726
783
|
run_id: str
|
|
727
784
|
status: str
|
|
728
785
|
started_at: datetime
|
|
729
|
-
finished_at:
|
|
786
|
+
finished_at: datetime | None = None
|
|
730
787
|
|
|
731
788
|
|
|
732
789
|
@router.post("/rollout", response_model=RolloutResponse)
|
|
@@ -735,6 +792,13 @@ async def execute_rollout(
|
|
|
735
792
|
req: Request,
|
|
736
793
|
) -> RolloutResponse:
|
|
737
794
|
"""Execute a rollout with coordinated environment and policy steps."""
|
|
795
|
+
# Emit rollout identifier early for correlation
|
|
796
|
+
with contextlib.suppress(Exception):
|
|
797
|
+
_rid = getattr(request, "run_id", None)
|
|
798
|
+
_pol = getattr(request.policy, "policy_name", None) or getattr(request.policy, "policy_id", None)
|
|
799
|
+
_env = getattr(request.env, "env_name", None) or getattr(request.env, "env_id", None)
|
|
800
|
+
logger.info("ROLLOUT_BEGIN: run_id=%s policy=%s env=%s", _rid, _pol, _env)
|
|
801
|
+
print(f"[rollout] begin run_id={_rid} policy={_pol} env={_env}", flush=True)
|
|
738
802
|
# Enforce per-episode step cap via env-specific parameters; default to 20 if omitted
|
|
739
803
|
try:
|
|
740
804
|
_env_params = {}
|
|
@@ -749,32 +813,30 @@ async def execute_rollout(
|
|
|
749
813
|
"error": "invalid_env_params",
|
|
750
814
|
"message": f"Invalid or missing env_params.max_steps_per_episode: {_mse}",
|
|
751
815
|
},
|
|
752
|
-
)
|
|
816
|
+
) from _mse
|
|
753
817
|
# Truncate incoming ops to the enforced cap (each step is [agent, env])
|
|
754
|
-
ops_seq:
|
|
818
|
+
ops_seq: list[str] = list(request.ops or [])
|
|
755
819
|
allowed_ops = max(0, int(max_steps_per_episode) * 2)
|
|
756
820
|
if len(ops_seq) > allowed_ops:
|
|
757
|
-
|
|
821
|
+
with contextlib.suppress(Exception):
|
|
758
822
|
logger.info(
|
|
759
823
|
"ROLL_OUT: truncating ops to cap: requested_ops=%s allowed_ops=%s",
|
|
760
824
|
str(len(ops_seq)),
|
|
761
825
|
str(allowed_ops),
|
|
762
826
|
)
|
|
763
|
-
except Exception:
|
|
764
|
-
pass
|
|
765
827
|
ops_seq = ops_seq[:allowed_ops]
|
|
766
828
|
# Simple API key auth for inbound rollout
|
|
767
829
|
header_key = req.headers.get("x-api-key")
|
|
768
830
|
env_key = os.getenv("ENVIRONMENT_API_KEY")
|
|
769
|
-
dev_key = os.getenv("
|
|
770
|
-
# Accept either ENVIRONMENT_API_KEY or
|
|
831
|
+
dev_key = os.getenv("DEV_ENVIRONMENT_API_KEY")
|
|
832
|
+
# Accept either ENVIRONMENT_API_KEY or DEV_ENVIRONMENT_API_KEY
|
|
771
833
|
expected_keys = [k for k in (env_key, dev_key) if k]
|
|
772
834
|
if not expected_keys:
|
|
773
835
|
missing = []
|
|
774
836
|
if not env_key:
|
|
775
837
|
missing.append("ENVIRONMENT_API_KEY")
|
|
776
838
|
if not dev_key:
|
|
777
|
-
missing.append("
|
|
839
|
+
missing.append("DEV_ENVIRONMENT_API_KEY")
|
|
778
840
|
msg = f"Auth not configured: missing {', '.join(missing)} in task service environment"
|
|
779
841
|
logger.error(msg)
|
|
780
842
|
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=msg)
|
|
@@ -800,33 +862,38 @@ async def execute_rollout(
|
|
|
800
862
|
logger.info(f"ROLL_OUT: synth_base_url={request.synth_base_url}")
|
|
801
863
|
|
|
802
864
|
# Log masked OpenAI API key presence for diagnostics
|
|
803
|
-
|
|
865
|
+
with contextlib.suppress(Exception):
|
|
804
866
|
_oa = os.getenv("OPENAI_API_KEY")
|
|
805
867
|
if _oa:
|
|
806
868
|
_pref = (_oa[:6] + "…") if len(_oa) >= 6 else "set"
|
|
807
869
|
logger.info(f"ROLL_OUT: OPENAI_API_KEY present (prefix={_pref})")
|
|
808
870
|
else:
|
|
809
871
|
logger.warning("ROLL_OUT: OPENAI_API_KEY missing")
|
|
810
|
-
except Exception:
|
|
811
|
-
pass
|
|
812
872
|
|
|
813
873
|
# Make synth_base_url available for outbound calls in this app
|
|
814
|
-
|
|
874
|
+
with contextlib.suppress(Exception):
|
|
815
875
|
task_app = req.app.state.task_app
|
|
816
876
|
if request.synth_base_url:
|
|
817
|
-
|
|
818
|
-
except Exception:
|
|
819
|
-
pass
|
|
877
|
+
task_app.synth_base_url = request.synth_base_url
|
|
820
878
|
|
|
821
879
|
tracer_factory = getattr(req.app.state, "session_tracer_factory", None)
|
|
822
|
-
tracer_instance = None
|
|
880
|
+
tracer_instance: SessionTracer | None = None
|
|
823
881
|
if callable(tracer_factory):
|
|
824
882
|
try:
|
|
825
|
-
|
|
883
|
+
inst = tracer_factory()
|
|
884
|
+
tracer_instance = inst if isinstance(inst, SessionTracer) else None
|
|
826
885
|
except Exception as exc:
|
|
827
886
|
logger.debug(f"TRACER_FACTORY_FAIL: {exc}")
|
|
828
887
|
tracing_context = RolloutTracingContext(tracer_instance, request, req)
|
|
829
888
|
await tracing_context.start_session()
|
|
889
|
+
# Print whether tracing is active for this rollout
|
|
890
|
+
try:
|
|
891
|
+
print(
|
|
892
|
+
f"[rollout] tracing enabled={bool(tracing_context.enabled)} run_id={request.run_id}",
|
|
893
|
+
flush=True,
|
|
894
|
+
)
|
|
895
|
+
except Exception:
|
|
896
|
+
pass
|
|
830
897
|
|
|
831
898
|
# Register run
|
|
832
899
|
registry.register_run(request.run_id)
|
|
@@ -835,10 +902,25 @@ async def execute_rollout(
|
|
|
835
902
|
created_env_id: str | None = None
|
|
836
903
|
created_policy_id: str | None = None
|
|
837
904
|
env_seed_used: int | None = None
|
|
905
|
+
trajectory_steps: list[RolloutStep] = []
|
|
906
|
+
decision_samples: list[dict[str, Any]] = []
|
|
907
|
+
pending_tool_calls: Any = None
|
|
908
|
+
current_obs: Any = {}
|
|
909
|
+
total_reward: float = 0.0
|
|
910
|
+
ops_executed = 0
|
|
911
|
+
last_agent_response_ts: float | None = None
|
|
912
|
+
last_policy_meta: dict[str, Any] | None = None
|
|
913
|
+
last_env_step_ms: float | None = None
|
|
914
|
+
last_env_step_completed_ts: float | None = None
|
|
915
|
+
decision_open = False
|
|
916
|
+
finalized = False
|
|
917
|
+
prev_achievements: dict[str, bool] = {}
|
|
918
|
+
session_trace = None
|
|
919
|
+
step_rewards_active = False
|
|
838
920
|
|
|
839
921
|
try:
|
|
840
922
|
# Initialize deterministic seed early for the entire rollout
|
|
841
|
-
seed_value:
|
|
923
|
+
seed_value: int | None = None
|
|
842
924
|
try:
|
|
843
925
|
if request.env and request.env.seed is not None:
|
|
844
926
|
seed_value = int(request.env.seed)
|
|
@@ -857,14 +939,12 @@ async def execute_rollout(
|
|
|
857
939
|
seed_value = 42
|
|
858
940
|
|
|
859
941
|
_seed_info = _set_global_seed(int(seed_value))
|
|
860
|
-
|
|
942
|
+
with contextlib.suppress(Exception):
|
|
861
943
|
logger.info(
|
|
862
944
|
"ROLL_OUT: RNG seeded seed=%s libs=%s",
|
|
863
945
|
str(_seed_info.get("seed")),
|
|
864
946
|
",".join(_seed_info.get("libs", [])),
|
|
865
947
|
)
|
|
866
|
-
except Exception:
|
|
867
|
-
pass
|
|
868
948
|
# Resolve or create environment
|
|
869
949
|
if request.env.env_id:
|
|
870
950
|
env_handle = registry.get_env(request.env.env_id)
|
|
@@ -876,7 +956,7 @@ async def execute_rollout(
|
|
|
876
956
|
env_id = request.env.env_id
|
|
877
957
|
else:
|
|
878
958
|
# Create new environment
|
|
879
|
-
from .environment_routes import
|
|
959
|
+
from .environment_routes import EnvCreateRequest, create_environment
|
|
880
960
|
|
|
881
961
|
if not request.env.env_name:
|
|
882
962
|
raise ValueError("FATAL: env_name is required - NO FALLBACKS!")
|
|
@@ -910,7 +990,7 @@ async def execute_rollout(
|
|
|
910
990
|
policy_id = request.policy.policy_id
|
|
911
991
|
else:
|
|
912
992
|
# Create new policy
|
|
913
|
-
from .policy_routes import
|
|
993
|
+
from .policy_routes import PolicyCreateRequest, create_policy
|
|
914
994
|
|
|
915
995
|
if not request.policy.policy_name:
|
|
916
996
|
raise ValueError("FATAL: policy_name is required - NO FALLBACKS!")
|
|
@@ -946,20 +1026,19 @@ async def execute_rollout(
|
|
|
946
1026
|
except Exception:
|
|
947
1027
|
env_seed_used = None
|
|
948
1028
|
tracing_context.update_metadata(env_seed=env_seed_used)
|
|
949
|
-
|
|
950
1029
|
# Initialize trajectory
|
|
951
1030
|
trajectory_steps = []
|
|
952
1031
|
pending_tool_calls = None
|
|
953
1032
|
current_obs = env_handle.last_observation
|
|
954
1033
|
total_reward = 0.0
|
|
955
1034
|
ops_executed = 0
|
|
956
|
-
last_agent_response_ts
|
|
957
|
-
last_policy_meta
|
|
958
|
-
last_env_step_ms
|
|
959
|
-
last_env_step_completed_ts
|
|
1035
|
+
last_agent_response_ts = None
|
|
1036
|
+
last_policy_meta = None
|
|
1037
|
+
last_env_step_ms = None
|
|
1038
|
+
last_env_step_completed_ts = None
|
|
960
1039
|
|
|
961
1040
|
# Stepwise reward configuration (Crafter shaping; gate on explicit enable)
|
|
962
|
-
step_rewards_cfg_raw:
|
|
1041
|
+
step_rewards_cfg_raw: dict[str, Any] = {}
|
|
963
1042
|
try:
|
|
964
1043
|
if isinstance(request.policy.config, dict):
|
|
965
1044
|
step_rewards_cfg_raw = dict(request.policy.config.get("step_rewards") or {})
|
|
@@ -986,7 +1065,7 @@ async def execute_rollout(
|
|
|
986
1065
|
step_rewards_beta = 0.0
|
|
987
1066
|
step_rewards_active = step_rewards_enabled and step_rewards_mode == "decision_stepwise"
|
|
988
1067
|
|
|
989
|
-
def _extract_achievements(obs: Any) ->
|
|
1068
|
+
def _extract_achievements(obs: Any) -> dict[str, bool]:
|
|
990
1069
|
if not isinstance(obs, dict):
|
|
991
1070
|
return {}
|
|
992
1071
|
ach = obs.get("achievements_status")
|
|
@@ -994,7 +1073,7 @@ async def execute_rollout(
|
|
|
994
1073
|
return {str(k): bool(v) for k, v in ach.items()}
|
|
995
1074
|
return {}
|
|
996
1075
|
|
|
997
|
-
def _summarize_tool_calls(tool_calls: Any) ->
|
|
1076
|
+
def _summarize_tool_calls(tool_calls: Any) -> list[dict[str, Any]]:
|
|
998
1077
|
if not tool_calls:
|
|
999
1078
|
return []
|
|
1000
1079
|
try:
|
|
@@ -1005,7 +1084,7 @@ async def execute_rollout(
|
|
|
1005
1084
|
)
|
|
1006
1085
|
except Exception:
|
|
1007
1086
|
return []
|
|
1008
|
-
summary:
|
|
1087
|
+
summary: list[dict[str, Any]] = []
|
|
1009
1088
|
for tc in items:
|
|
1010
1089
|
tool_name = None
|
|
1011
1090
|
args: Any = {}
|
|
@@ -1024,16 +1103,16 @@ async def execute_rollout(
|
|
|
1024
1103
|
summary.append({"tool": tool_name, "args": args})
|
|
1025
1104
|
return summary
|
|
1026
1105
|
|
|
1027
|
-
decision_samples:
|
|
1106
|
+
decision_samples: list[dict[str, Any]] = []
|
|
1028
1107
|
decision_index = 0
|
|
1029
1108
|
decision_open = False
|
|
1030
1109
|
session_trace = None
|
|
1031
1110
|
finalized = False
|
|
1032
1111
|
prev_achievements = _extract_achievements(current_obs)
|
|
1033
1112
|
# Track episode-level achievements that have been seen as true at any point so far
|
|
1034
|
-
episode_seen_achievements: set[str] =
|
|
1035
|
-
|
|
1036
|
-
|
|
1113
|
+
episode_seen_achievements: set[str] = {
|
|
1114
|
+
k for k, v in (prev_achievements or {}).items() if bool(v)
|
|
1115
|
+
}
|
|
1037
1116
|
stepwise_indicator_sum = 0.0
|
|
1038
1117
|
stepwise_reward_sum = 0.0
|
|
1039
1118
|
stepwise_new_achievements_total = 0
|
|
@@ -1053,7 +1132,7 @@ async def execute_rollout(
|
|
|
1053
1132
|
|
|
1054
1133
|
if op == "agent":
|
|
1055
1134
|
# Policy step
|
|
1056
|
-
from .policy_routes import
|
|
1135
|
+
from .policy_routes import PolicyStepRequest, step_policy
|
|
1057
1136
|
|
|
1058
1137
|
if not decision_open:
|
|
1059
1138
|
await tracing_context.start_decision(decision_index)
|
|
@@ -1061,7 +1140,7 @@ async def execute_rollout(
|
|
|
1061
1140
|
|
|
1062
1141
|
agent_request_start = _time.perf_counter()
|
|
1063
1142
|
if last_agent_response_ts is not None and last_policy_meta is not None:
|
|
1064
|
-
|
|
1143
|
+
with contextlib.suppress(Exception):
|
|
1065
1144
|
timing_prev = last_policy_meta.setdefault("timing", {})
|
|
1066
1145
|
decision_ms = max(
|
|
1067
1146
|
0.0,
|
|
@@ -1080,7 +1159,7 @@ async def execute_rollout(
|
|
|
1080
1159
|
# Also backfill the last appended trajectory step so the trainer
|
|
1081
1160
|
# can always see decision_ms without relying on shared dict refs.
|
|
1082
1161
|
if trajectory_steps:
|
|
1083
|
-
|
|
1162
|
+
with contextlib.suppress(Exception):
|
|
1084
1163
|
_last = trajectory_steps[-1]
|
|
1085
1164
|
_info = dict(_last.info or {})
|
|
1086
1165
|
_meta = dict(_info.get("meta") or {})
|
|
@@ -1097,10 +1176,6 @@ async def execute_rollout(
|
|
|
1097
1176
|
_meta["timing"] = _timing
|
|
1098
1177
|
_info["meta"] = _meta
|
|
1099
1178
|
_last.info = _info
|
|
1100
|
-
except Exception:
|
|
1101
|
-
pass
|
|
1102
|
-
except Exception:
|
|
1103
|
-
pass
|
|
1104
1179
|
last_env_step_ms = None
|
|
1105
1180
|
last_env_step_completed_ts = None
|
|
1106
1181
|
|
|
@@ -1123,37 +1198,25 @@ async def execute_rollout(
|
|
|
1123
1198
|
}
|
|
1124
1199
|
|
|
1125
1200
|
# Log compact metadata summary to confirm history threading
|
|
1126
|
-
|
|
1127
|
-
_prev_calls = (
|
|
1128
|
-
metadata["prev_tool_calls"]
|
|
1129
|
-
if isinstance(metadata, dict) and "prev_tool_calls" in metadata
|
|
1130
|
-
else None
|
|
1131
|
-
)
|
|
1201
|
+
with contextlib.suppress(Exception):
|
|
1202
|
+
_prev_calls = metadata.get("prev_tool_calls")
|
|
1132
1203
|
_count = len(_prev_calls) if isinstance(_prev_calls, list) else 0
|
|
1133
1204
|
_first_guess = None
|
|
1134
1205
|
if _count > 0 and isinstance(_prev_calls[0], dict):
|
|
1135
|
-
_args = (
|
|
1136
|
-
_prev_calls[0]["arguments"] if "arguments" in _prev_calls[0] else None
|
|
1137
|
-
)
|
|
1206
|
+
_args = _prev_calls[0].get("arguments", None)
|
|
1138
1207
|
if isinstance(_args, str):
|
|
1139
1208
|
import json as _json
|
|
1140
|
-
|
|
1141
|
-
try:
|
|
1209
|
+
with contextlib.suppress(Exception):
|
|
1142
1210
|
_args = _json.loads(_args)
|
|
1143
|
-
|
|
1144
|
-
|
|
1145
|
-
|
|
1146
|
-
_first_guess = (_args["guess"] if "guess" in _args else None) or (
|
|
1147
|
-
_args["word"] if "word" in _args else None
|
|
1148
|
-
)
|
|
1211
|
+
if not isinstance(_args, dict):
|
|
1212
|
+
_args = {}
|
|
1213
|
+
_first_guess = _args.get("guess") or _args.get("word")
|
|
1149
1214
|
logger.info(
|
|
1150
1215
|
"POLICY_METADATA: prev_tool_calls=%d first_guess=%r has_prev_env_result=%s",
|
|
1151
1216
|
_count,
|
|
1152
1217
|
_first_guess,
|
|
1153
1218
|
str("prev_env_result" in metadata),
|
|
1154
1219
|
)
|
|
1155
|
-
except Exception:
|
|
1156
|
-
pass
|
|
1157
1220
|
|
|
1158
1221
|
try:
|
|
1159
1222
|
policy_response = await step_policy(
|
|
@@ -1166,15 +1229,13 @@ async def execute_rollout(
|
|
|
1166
1229
|
)
|
|
1167
1230
|
except Exception as _pe:
|
|
1168
1231
|
# Do not 500 the rollout; finalize with partial trajectory
|
|
1169
|
-
|
|
1232
|
+
with contextlib.suppress(Exception):
|
|
1170
1233
|
logger.warning(
|
|
1171
1234
|
"POLICY_STEP_FAIL: terminating episode early run_id=%s op_idx=%s err=%s",
|
|
1172
1235
|
request.run_id,
|
|
1173
1236
|
str(op_idx),
|
|
1174
1237
|
str(_pe),
|
|
1175
1238
|
)
|
|
1176
|
-
except Exception:
|
|
1177
|
-
pass
|
|
1178
1239
|
|
|
1179
1240
|
# Build partial trajectory and return HTTP 200
|
|
1180
1241
|
trajectory = RolloutTrajectory(
|
|
@@ -1222,12 +1283,12 @@ async def execute_rollout(
|
|
|
1222
1283
|
|
|
1223
1284
|
agent_response_ts = _time.perf_counter()
|
|
1224
1285
|
if isinstance(policy_response.meta, dict):
|
|
1225
|
-
|
|
1286
|
+
with contextlib.suppress(Exception):
|
|
1226
1287
|
timing_cur = policy_response.meta.setdefault("timing", {})
|
|
1227
1288
|
timing_cur["agent_request_start_s"] = agent_request_start
|
|
1228
1289
|
timing_cur["agent_response_s"] = agent_response_ts
|
|
1229
1290
|
if "inference_ms" in policy_response.meta:
|
|
1230
|
-
|
|
1291
|
+
with contextlib.suppress(Exception):
|
|
1231
1292
|
timing_cur.setdefault(
|
|
1232
1293
|
"inference_ms",
|
|
1233
1294
|
float(policy_response.meta["inference_ms"]),
|
|
@@ -1236,30 +1297,66 @@ async def execute_rollout(
|
|
|
1236
1297
|
"inference_s",
|
|
1237
1298
|
float(policy_response.meta["inference_ms"]) / 1000.0,
|
|
1238
1299
|
)
|
|
1239
|
-
except Exception:
|
|
1240
|
-
pass
|
|
1241
|
-
except Exception:
|
|
1242
|
-
pass
|
|
1243
1300
|
last_policy_meta = policy_response.meta
|
|
1244
1301
|
else:
|
|
1245
1302
|
last_policy_meta = None
|
|
1246
1303
|
last_agent_response_ts = agent_response_ts
|
|
1247
1304
|
|
|
1305
|
+
# Diagnostic: summarize policy step target and tool calls
|
|
1306
|
+
try:
|
|
1307
|
+
model_name = None
|
|
1308
|
+
target_url = None
|
|
1309
|
+
if isinstance(policy_response.meta, dict):
|
|
1310
|
+
req_body = policy_response.meta.get("inference_request") or {}
|
|
1311
|
+
model_name = req_body.get("model")
|
|
1312
|
+
target_url = policy_response.meta.get("inference_url")
|
|
1313
|
+
_tc = policy_response.tool_calls or []
|
|
1314
|
+
print(
|
|
1315
|
+
{
|
|
1316
|
+
"rollout.policy_step": True,
|
|
1317
|
+
"run_id": request.run_id,
|
|
1318
|
+
"model": model_name,
|
|
1319
|
+
"inference_url": target_url,
|
|
1320
|
+
"tool_calls_count": len(_tc) if isinstance(_tc, list) else 0,
|
|
1321
|
+
},
|
|
1322
|
+
flush=True,
|
|
1323
|
+
)
|
|
1324
|
+
except Exception:
|
|
1325
|
+
pass
|
|
1326
|
+
|
|
1248
1327
|
pending_tool_calls = policy_response.tool_calls
|
|
1328
|
+
# Log summarized agent tool calls
|
|
1329
|
+
with contextlib.suppress(Exception):
|
|
1330
|
+
_tc = pending_tool_calls or []
|
|
1331
|
+
_summary = []
|
|
1332
|
+
for _item in (_tc if isinstance(_tc, list) else []):
|
|
1333
|
+
try:
|
|
1334
|
+
if isinstance(_item, dict):
|
|
1335
|
+
_tool = _item.get("tool")
|
|
1336
|
+
_args = _item.get("args")
|
|
1337
|
+
_keys = list(_args.keys()) if isinstance(_args, dict) else []
|
|
1338
|
+
_summary.append({"tool": _tool, "args_keys": _keys})
|
|
1339
|
+
except Exception:
|
|
1340
|
+
continue
|
|
1341
|
+
_rid = getattr(request, "run_id", None)
|
|
1342
|
+
logger.info("AGENT_TOOL_CALLS: run_id=%s count=%d summary=%s", _rid, len(_tc), _summary)
|
|
1343
|
+
print(f"[rollout] agent tool_calls run_id={_rid} count={len(_tc)} summary={_summary}", flush=True)
|
|
1249
1344
|
await tracing_context.record_tool_invocation(pending_tool_calls)
|
|
1250
1345
|
ops_executed += 1
|
|
1251
1346
|
|
|
1252
1347
|
elif op == "env":
|
|
1253
1348
|
if not pending_tool_calls:
|
|
1254
1349
|
# Treat absence of tool calls as a soft terminal condition; yield partial trajectory
|
|
1255
|
-
|
|
1350
|
+
with contextlib.suppress(Exception):
|
|
1256
1351
|
logger.warning(
|
|
1257
1352
|
"NO_TOOL_CALLS: terminating episode early run_id=%s op_idx=%s",
|
|
1258
1353
|
request.run_id,
|
|
1259
1354
|
str(op_idx),
|
|
1260
1355
|
)
|
|
1261
|
-
|
|
1262
|
-
|
|
1356
|
+
print(
|
|
1357
|
+
f"[rollout] no tool_calls; terminating early run_id={request.run_id} op_idx={op_idx}",
|
|
1358
|
+
flush=True,
|
|
1359
|
+
)
|
|
1263
1360
|
term_step = RolloutStep(
|
|
1264
1361
|
obs=current_obs,
|
|
1265
1362
|
tool_calls=[],
|
|
@@ -1315,7 +1412,7 @@ async def execute_rollout(
|
|
|
1315
1412
|
)
|
|
1316
1413
|
|
|
1317
1414
|
# Environment step
|
|
1318
|
-
from .environment_routes import
|
|
1415
|
+
from .environment_routes import EnvStepRequest, step_environment
|
|
1319
1416
|
|
|
1320
1417
|
env_step_error: Exception | None = None
|
|
1321
1418
|
env_response = None
|
|
@@ -1334,24 +1431,20 @@ async def execute_rollout(
|
|
|
1334
1431
|
last_env_step_ms = env_step_duration_ms
|
|
1335
1432
|
last_env_step_completed_ts = env_step_end
|
|
1336
1433
|
if last_policy_meta is not None:
|
|
1337
|
-
|
|
1434
|
+
with contextlib.suppress(Exception):
|
|
1338
1435
|
timing_env = last_policy_meta.setdefault("timing", {})
|
|
1339
1436
|
timing_env["env_step_ms"] = env_step_duration_ms
|
|
1340
1437
|
timing_env["env_step_end_s"] = env_step_end
|
|
1341
|
-
except Exception:
|
|
1342
|
-
pass
|
|
1343
1438
|
|
|
1344
1439
|
if env_step_error is not None:
|
|
1345
1440
|
# Invalid action or environment rejection — terminate episode early with partial trajectory
|
|
1346
|
-
|
|
1441
|
+
with contextlib.suppress(Exception):
|
|
1347
1442
|
logger.warning(
|
|
1348
1443
|
"ENV_STEP_FAIL: terminating episode early run_id=%s op_idx=%s err=%s",
|
|
1349
1444
|
request.run_id,
|
|
1350
1445
|
str(op_idx),
|
|
1351
1446
|
str(env_step_error),
|
|
1352
1447
|
)
|
|
1353
|
-
except Exception:
|
|
1354
|
-
pass
|
|
1355
1448
|
|
|
1356
1449
|
term_step = RolloutStep(
|
|
1357
1450
|
obs=current_obs,
|
|
@@ -1394,7 +1487,7 @@ async def execute_rollout(
                     and last_agent_response_ts is not None
                     and "decision_ms" not in last_policy_meta.get("timing", {})
                 ):
-                    try:
+                    with contextlib.suppress(Exception):
                         timing_last = last_policy_meta.setdefault("timing", {})
                         decision_ms = max(
                             0.0,
@@ -1404,8 +1497,6 @@ async def execute_rollout(
                         timing_last.setdefault(
                             "overhead_ms", max(0.0, decision_ms - env_step_duration_ms)
                         )
-                    except Exception:
-                        pass
             if decision_open:
                 await tracing_context.end_decision()
                 decision_open = False
@@ -1433,15 +1524,13 @@ async def execute_rollout(
             # Record step, including policy meta if present for timing/tokens observability
             _info = env_response.info if isinstance(env_response.info, dict) else {}
             # Attach policy meta from the immediately preceding agent step
-            try:
+            with contextlib.suppress(Exception):
                 prev_meta = {}
                 if "policy_response" in locals() and isinstance(policy_response.meta, dict):  # type: ignore[name-defined]
                     prev_meta = policy_response.meta
                 if prev_meta:
                     _info = dict(_info)
                     _info["meta"] = prev_meta
-            except Exception:
-                pass

             event_metadata = {
                 "op_index": op_idx,
@@ -1462,7 +1551,7 @@ async def execute_rollout(
             )
             indicator_val = 0
             reward_stepwise = 0.0
-            decision_rewards_meta:
+            decision_rewards_meta: dict[str, Any] | None = None
             if step_rewards_active:
                 decision_actions = _summarize_tool_calls(pending_tool_calls)
                 stepwise_info, decision_record, stats = compute_stepwise_reward(
@@ -1477,20 +1566,17 @@ async def execute_rollout(
                 stepwise_indicator_sum += float(stats.get("indicator", 0.0))
                 stepwise_reward_sum += reward_stepwise
                 stepwise_new_achievements_total += int(stats.get("new_achievements_count", 0.0))
-                if not isinstance(_info, dict):
-                    _info = {}
-                else:
-                    _info = dict(_info)
+                _info = {} if not isinstance(_info, dict) else dict(_info)
                 _info["stepwise"] = stepwise_info
                 # Compute decision-level rewards (absolute vs unique) and attach to metadata
-                try:
+                with contextlib.suppress(Exception):
                     turned_true = set(stepwise_info.get("new_achievements") or [])
                     seen_before = set(episode_seen_achievements)
-                    new_unique = sorted(
+                    new_unique = sorted(turned_true - seen_before)
                     ach_delta = int(len(turned_true))
                     unique_delta = int(len(new_unique))
                     # Prepare stable lists for logging/metadata
-                    all_list = sorted(
+                    all_list = sorted(turned_true)
                     # Ensure nested meta exists
                     meta_block = (
                         _info.get("meta") if isinstance(_info.get("meta"), dict) else {}
@@ -1507,9 +1593,6 @@ async def execute_rollout(
                     _info["meta"] = meta_block
                     # Update episode-level seen set after attributing uniqueness to this decision
                     episode_seen_achievements.update(turned_true)
-                except Exception:
-                    # Best-effort; do not block rollout on metadata computation
-                    pass
                 decision_samples.append(decision_record)
                 prev_achievements = new_achievement_state

@@ -1526,6 +1609,32 @@ async def execute_rollout(
                 truncated=env_response.truncated,
                 info=_info,
             )
+            # Log summarized env application of tool calls and immediate reward/done
+            with contextlib.suppress(Exception):
+                _tc = pending_tool_calls or []
+                _summary = []
+                for _item in (_tc if isinstance(_tc, list) else []):
+                    try:
+                        if isinstance(_item, dict):
+                            _tool = _item.get("tool")
+                            _args = _item.get("args")
+                            _keys = list(_args.keys()) if isinstance(_args, dict) else []
+                            _summary.append({"tool": _tool, "args_keys": _keys})
+                    except Exception:
+                        continue
+                _rid = getattr(request, "run_id", None)
+                logger.info(
+                    "ENV_APPLY: run_id=%s tool_calls=%d reward=%s done=%s summary=%s",
+                    _rid,
+                    len(_tc),
+                    str(env_response.reward),
+                    str(env_response.done),
+                    _summary,
+                )
+                print(
+                    f"[rollout] env apply run_id={_rid} tool_calls={len(_tc)} reward={env_response.reward} done={env_response.done} summary={_summary}",
+                    flush=True,
+                )
             trajectory_steps.append(step)

             if env_response.reward is not None:
@@ -1541,8 +1650,8 @@ async def execute_rollout(
             if request.on_done == "reset":
                 # Reset environment
                 from .environment_routes import (
-                    reset_environment,
                     EnvResetRequest,
+                    reset_environment,
                 )

                 reset_response = await reset_environment(EnvResetRequest(env_id=env_id))
@@ -1564,7 +1673,7 @@ async def execute_rollout(
            and isinstance(last_policy_meta["timing"], dict)
            and "decision_ms" not in last_policy_meta["timing"]
        ):
-            try:
+            with contextlib.suppress(Exception):
                final_now = last_env_step_completed_ts or _time.perf_counter()
                final_decision_ms = max(0.0, (final_now - float(last_agent_response_ts)) * 1000.0)
                timing_final = last_policy_meta.setdefault("timing", {})
@@ -1577,8 +1686,6 @@ async def execute_rollout(
                    )
                else:
                    timing_final.setdefault("overhead_ms", 0.0)
-        except Exception:
-            pass

        # Build trajectory
        trajectory = RolloutTrajectory(
@@ -1601,18 +1708,24 @@ async def execute_rollout(
        # Environment-specific: Log summary if available
        try:
            # Check if this is a Wordle environment and use Wordle helpers (lazy import)
+            wordle_wrapper_cls = None
            try:
-                from .envs.wordle.environment import WordleEnvironmentWrapper
+                from .envs.wordle.environment import WordleEnvironmentWrapper
                from .envs.wordle.helpers import (
                    get_wordle_rollout_summary,
                    log_wordle_rollout_summary,
                )
+
+                wordle_wrapper_cls = WordleEnvironmentWrapper
            except Exception:
-
+                wordle_wrapper_cls = None  # type: ignore[assignment]
                get_wordle_rollout_summary = None  # type: ignore
                log_wordle_rollout_summary = None  # type: ignore

-            is_wordle =
+            is_wordle = wordle_wrapper_cls is not None and isinstance(
+                env_handle.env,
+                wordle_wrapper_cls,  # type: ignore[arg-type]
+            )
            if is_wordle:
                # Convert trajectory steps to expected format
                formatted_steps = []
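The hunk above guards the Wordle summary path by importing the wrapper class lazily into a local variable and falling back to `None`, so the later `isinstance` check never touches an unbound name. A minimal sketch of that guard, using placeholder module and class names rather than the package's real import paths, is below.

```python
# Minimal sketch of the lazy optional-import guard shown above; the module and
# class names are hypothetical stand-ins.
wrapper_cls = None
try:
    from some_optional_module import OptionalWrapper  # hypothetical optional dependency

    wrapper_cls = OptionalWrapper
except Exception:
    wrapper_cls = None

env = object()  # stand-in for env_handle.env
is_special = wrapper_cls is not None and isinstance(env, wrapper_cls)
print(is_special)  # False whenever the optional module is unavailable
```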
@@ -1661,27 +1774,24 @@ async def execute_rollout(
        logger.error(f"Rollout failed for run {request.run_id}: {e}")
        registry.abort_run(request.run_id)
        if decision_open:
-            try:
+            with contextlib.suppress(Exception):
                await tracing_context.end_decision()
-            except Exception:
-                pass
            decision_open = False
        if not finalized:
-            try:
+            session_trace = None
+            with contextlib.suppress(Exception):
                session_trace = await tracing_context.finalize(
                    total_reward=total_reward,
                    achievement_state=prev_achievements,
                    total_steps=len(trajectory_steps),
                )
-            except Exception:
-                session_trace = None
            finalized = True
-        raise HTTPException(status_code=500, detail=str(e))
+        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Ensure any environment created for this rollout is terminated (no reuse across rollouts)
        try:
            if created_env_id:
-                from .environment_routes import
+                from .environment_routes import EnvTerminateRequest, terminate_environment

                await terminate_environment(EnvTerminateRequest(env_id=created_env_id))
                logger.info(
@@ -1690,44 +1800,37 @@ async def execute_rollout(
                    str(env_seed_used) if env_seed_used is not None else "unknown",
                )
                # Verify removal from registry
-                try:
+                with contextlib.suppress(Exception):
                    _post = registry.get_env(created_env_id)
                    logger.info(
                        "ROLL_OUT: env_killed=%s (post_lookup=%s)",
                        str(_post is None),
                        str(_post),
                    )
-                except Exception:
-                    pass
        except Exception as _te:
            logger.warning(f"ROLL_OUT: failed to terminate environment {created_env_id}: {_te}")

        # Best-effort policy cleanup if we created one (avoid reuse across rollouts)
-        try:
+        with contextlib.suppress(Exception):
            if created_policy_id:
-                from .policy_routes import
+                from .policy_routes import PolicyTerminateRequest, terminate_policy

                await terminate_policy(PolicyTerminateRequest(policy_id=created_policy_id))
                logger.info("ROLL_OUT: terminated policy policy_id=%s", str(created_policy_id))
-        except Exception:
-            pass

        if not finalized:
-            try:
+            session_trace = None
+            with contextlib.suppress(Exception):
                session_trace = await tracing_context.finalize(
                    total_reward=total_reward,
                    achievement_state=prev_achievements,
                    total_steps=len(trajectory_steps),
                )
-            except Exception:
-                session_trace = None
            finalized = True

-        try:
+        with contextlib.suppress(Exception):
            _clear_seed_side_effects()
            logger.info("ROLL_OUT: RNG seed terminated/cleared before conclusion")
-        except Exception:
-            pass


@router.post("/run/abort", response_model=RunAbortResponse)