synth-ai 0.2.9.dev7__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to their public registries, and is provided for informational purposes only.
- examples/__init__.py +16 -0
- examples/crafter_debug_render.py +8 -11
- examples/dev/qwen3_32b_qlora_4xh100.toml +40 -0
- examples/multi_step/crafter_rl_lora.md +29 -0
- examples/qwen_coder/README.md +102 -0
- examples/qwen_coder/_shared.py +113 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +61 -0
- examples/qwen_coder/configs/coder_lora_4b.toml +57 -0
- examples/qwen_coder/configs/coder_lora_small.toml +58 -0
- examples/qwen_coder/generate_dataset.py +98 -0
- examples/qwen_coder/infer_ft_smoke.py +65 -0
- examples/qwen_coder/infer_prod_proxy.py +73 -0
- examples/qwen_coder/infer_via_synth.py +87 -0
- examples/qwen_coder/scripts/infer_coder.sh +19 -0
- examples/qwen_coder/scripts/train_coder_30b.sh +22 -0
- examples/qwen_coder/sft_full_17b.py +103 -0
- examples/qwen_coder/sft_lora_30b.py +110 -0
- examples/qwen_coder/subset_jsonl.py +39 -0
- examples/qwen_coder/todos.md +38 -0
- examples/qwen_coder/validate_jsonl.py +60 -0
- examples/rl/run_eval.py +36 -37
- examples/rl/run_rl_and_save.py +5 -5
- examples/rl/task_app/math_single_step.py +65 -43
- examples/rl/task_app/math_task_app.py +3 -3
- examples/sft/README.md +139 -0
- examples/sft/configs/crafter_fft_qwen0p6b.toml +44 -0
- examples/sft/configs/crafter_lora_qwen0p6b.toml +45 -0
- examples/sft/evaluate.py +117 -0
- examples/sft/export_dataset.py +117 -0
- examples/sft/generate_traces.py +162 -0
- examples/swe/__init__.py +12 -0
- examples/swe/task_app/README.md +105 -0
- examples/swe/task_app/__init__.py +2 -0
- examples/swe/task_app/grpo_swe_mini.py +571 -0
- examples/swe/task_app/grpo_swe_mini_task_app.py +136 -0
- examples/swe/task_app/hosted/README.md +173 -0
- examples/swe/task_app/hosted/__init__.py +5 -0
- examples/swe/task_app/hosted/branching.py +143 -0
- examples/swe/task_app/hosted/environment_routes.py +1289 -0
- examples/swe/task_app/hosted/envs/__init__.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/__init__.py +6 -0
- examples/swe/task_app/hosted/envs/crafter/app.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/environment.py +522 -0
- examples/swe/task_app/hosted/envs/crafter/policy.py +478 -0
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +108 -0
- examples/swe/task_app/hosted/envs/crafter/shared.py +305 -0
- examples/swe/task_app/hosted/envs/crafter/tools.py +47 -0
- examples/swe/task_app/hosted/envs/mini_swe/__init__.py +8 -0
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +1164 -0
- examples/swe/task_app/hosted/envs/mini_swe/policy.py +355 -0
- examples/swe/task_app/hosted/envs/mini_swe/shared.py +83 -0
- examples/swe/task_app/hosted/envs/mini_swe/tools.py +96 -0
- examples/swe/task_app/hosted/hosted_app.py +204 -0
- examples/swe/task_app/hosted/inference/__init__.py +5 -0
- examples/swe/task_app/hosted/inference/openai_client.py +618 -0
- examples/swe/task_app/hosted/main.py +100 -0
- examples/swe/task_app/hosted/policy_routes.py +1079 -0
- examples/swe/task_app/hosted/registry.py +195 -0
- examples/swe/task_app/hosted/rollout.py +1869 -0
- examples/swe/task_app/hosted/storage/__init__.py +5 -0
- examples/swe/task_app/hosted/storage/volume.py +211 -0
- examples/swe/task_app/hosted/test_agents.py +161 -0
- examples/swe/task_app/hosted/test_service.py +137 -0
- examples/swe/task_app/hosted/utils.py +62 -0
- examples/vlm/PROPOSAL.md +53 -0
- examples/vlm/README.md +68 -0
- examples/vlm/configs/crafter_vlm_gpt4o.toml +44 -0
- examples/vlm/crafter_image_only_agent.py +207 -0
- examples/vlm/crafter_openai_vlm_agent.py +277 -0
- examples/vlm/filter_image_rows.py +63 -0
- examples/vlm/run_crafter_vlm_benchmark.py +316 -0
- examples/warming_up_to_rl/analyze_trace_db.py +5 -5
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +11 -1
- examples/warming_up_to_rl/export_trace_sft.py +78 -21
- examples/warming_up_to_rl/groq_test.py +4 -4
- examples/warming_up_to_rl/manage_secrets.py +13 -18
- examples/warming_up_to_rl/run_eval.py +42 -44
- examples/warming_up_to_rl/run_fft_and_save.py +11 -16
- examples/warming_up_to_rl/run_local_rollout.py +1 -3
- examples/warming_up_to_rl/run_local_rollout_modal.py +2 -4
- examples/warming_up_to_rl/run_local_rollout_parallel.py +1 -4
- examples/warming_up_to_rl/run_local_rollout_traced.py +3 -5
- examples/warming_up_to_rl/run_rl_and_save.py +5 -6
- examples/warming_up_to_rl/run_rollout_remote.py +8 -10
- examples/warming_up_to_rl/task_app/README.md +6 -2
- examples/warming_up_to_rl/task_app/grpo_crafter.py +234 -35
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +2 -3
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +9 -11
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +131 -114
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +101 -41
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +73 -51
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +14 -6
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +16 -16
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +32 -34
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +94 -31
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +0 -2
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +303 -203
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +21 -23
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +328 -225
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +13 -13
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +4 -3
- synth_ai/api/models/supported.py +376 -0
- synth_ai/api/train/builders.py +128 -21
- synth_ai/api/train/cli.py +80 -64
- synth_ai/api/train/config_finder.py +7 -2
- synth_ai/api/train/env_resolver.py +1 -1
- synth_ai/api/train/pollers.py +2 -1
- synth_ai/api/train/supported_algos.py +139 -0
- synth_ai/api/train/task_app.py +1 -2
- synth_ai/api/train/utils.py +13 -44
- synth_ai/cli/__init__.py +8 -0
- synth_ai/cli/_modal_wrapper.py +28 -0
- synth_ai/cli/_typer_patch.py +49 -0
- synth_ai/cli/balance.py +1 -2
- synth_ai/cli/calc.py +1 -1
- synth_ai/cli/demo.py +2 -1
- synth_ai/cli/recent.py +2 -2
- synth_ai/cli/rl_demo.py +2 -1
- synth_ai/cli/root.py +11 -13
- synth_ai/cli/status.py +2 -2
- synth_ai/cli/task_apps.py +529 -179
- synth_ai/cli/traces.py +6 -4
- synth_ai/cli/watch.py +12 -18
- synth_ai/demo_registry.py +1 -1
- synth_ai/demos/core/cli.py +36 -43
- synth_ai/demos/demo_task_apps/__init__.py +3 -3
- synth_ai/demos/demo_task_apps/core.py +17 -25
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +3 -4
- synth_ai/demos/demo_task_apps/math/app.py +2 -1
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +3 -4
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +16 -18
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -1
- synth_ai/environments/examples/crafter_classic/environment.py +76 -1
- synth_ai/environments/reproducibility/tree.py +2 -5
- synth_ai/environments/service/app.py +11 -12
- synth_ai/environments/service/core_routes.py +4 -7
- synth_ai/environments/stateful/engine.py +1 -1
- synth_ai/environments/tasks/core.py +1 -0
- synth_ai/environments/tasks/filters.py +5 -6
- synth_ai/environments/tasks/utils.py +4 -5
- synth_ai/handshake.py +9 -9
- synth_ai/http.py +1 -1
- synth_ai/http_client.py +18 -10
- synth_ai/inference/client.py +15 -5
- synth_ai/jobs/client.py +78 -83
- synth_ai/learning/__init__.py +41 -6
- synth_ai/learning/algorithms.py +14 -0
- synth_ai/learning/client.py +91 -24
- synth_ai/learning/config.py +2 -38
- synth_ai/learning/ft_client.py +4 -59
- synth_ai/learning/health.py +5 -6
- synth_ai/learning/jobs.py +31 -47
- synth_ai/{rl → learning/rl}/__init__.py +14 -4
- synth_ai/learning/rl/client.py +267 -0
- synth_ai/learning/rl/config.py +31 -0
- synth_ai/{rl → learning/rl}/contracts.py +5 -8
- synth_ai/{rl → learning/rl}/env_keys.py +39 -15
- synth_ai/learning/rl/secrets.py +13 -0
- synth_ai/learning/rl_client.py +2 -281
- synth_ai/learning/sft/__init__.py +29 -0
- synth_ai/learning/sft/client.py +68 -0
- synth_ai/learning/sft/config.py +270 -0
- synth_ai/learning/sft/data.py +295 -0
- synth_ai/learning/sse.py +25 -24
- synth_ai/learning/validators.py +25 -28
- synth_ai/lm/__init__.py +21 -47
- synth_ai/task/__init__.py +25 -27
- synth_ai/task/apps/__init__.py +7 -8
- synth_ai/task/auth.py +8 -8
- synth_ai/task/client.py +14 -14
- synth_ai/task/contracts.py +36 -35
- synth_ai/task/datasets.py +6 -5
- synth_ai/task/errors.py +10 -10
- synth_ai/task/health.py +17 -9
- synth_ai/task/json.py +58 -23
- synth_ai/task/proxy.py +13 -9
- synth_ai/task/rubrics.py +16 -15
- synth_ai/task/server.py +12 -12
- synth_ai/task/tracing_utils.py +4 -4
- synth_ai/task/vendors.py +5 -6
- synth_ai/tracing_v3/__init__.py +2 -0
- synth_ai/tracing_v3/abstractions.py +21 -4
- synth_ai/tracing_v3/decorators.py +18 -16
- synth_ai/tracing_v3/hooks.py +5 -5
- synth_ai/tracing_v3/llm_call_record_helpers.py +6 -6
- synth_ai/tracing_v3/session_tracer.py +40 -14
- synth_ai/tracing_v3/storage/base.py +85 -0
- synth_ai/tracing_v3/storage/config.py +21 -8
- synth_ai/tracing_v3/storage/factory.py +10 -7
- synth_ai/tracing_v3/storage/utils.py +4 -2
- synth_ai/tracing_v3/turso/daemon.py +7 -2
- synth_ai/tracing_v3/turso/models.py +2 -2
- synth_ai/tracing_v3/turso/native_manager.py +1173 -0
- synth_ai/tracing_v3/utils.py +4 -4
- synth_ai/v0/api/__init__.py +8 -0
- synth_ai/v0/api/models/__init__.py +8 -0
- synth_ai/v0/api/models/supported.py +8 -0
- synth_ai/v0/config/__init__.py +15 -0
- synth_ai/v0/config/base_url.py +12 -0
- synth_ai/v0/lm/__init__.py +51 -0
- synth_ai/{lm → v0/lm}/caching/ephemeral.py +2 -2
- synth_ai/{lm → v0/lm}/caching/handler.py +4 -4
- synth_ai/{lm → v0/lm}/caching/initialize.py +1 -1
- synth_ai/{lm → v0/lm}/caching/persistent.py +1 -1
- synth_ai/{lm → v0/lm}/config.py +6 -1
- synth_ai/{lm → v0/lm}/core/all.py +9 -9
- synth_ai/{lm → v0/lm}/core/main.py +6 -6
- synth_ai/{lm → v0/lm}/core/main_v3.py +10 -10
- synth_ai/{lm → v0/lm}/core/synth_models.py +2 -14
- synth_ai/{lm → v0/lm}/core/vendor_clients.py +2 -2
- synth_ai/{lm → v0/lm}/overrides.py +2 -2
- synth_ai/{lm → v0/lm}/provider_support/anthropic.py +4 -4
- synth_ai/{lm → v0/lm}/provider_support/openai.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/handler.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/rehabilitate.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/core/anthropic_api.py +9 -9
- synth_ai/{lm → v0/lm}/vendors/core/gemini_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/mistral_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/openai_api.py +10 -10
- synth_ai/{lm → v0/lm}/vendors/openai_standard.py +8 -8
- synth_ai/{lm → v0/lm}/vendors/openai_standard_responses.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/custom_endpoint.py +3 -3
- synth_ai/{lm → v0/lm}/vendors/supported/deepseek.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/grok.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/groq.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/ollama.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/openrouter.py +3 -3
- synth_ai/{lm → v0/lm}/vendors/supported/together.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/synth_client.py +1 -1
- synth_ai/v0/tracing_v3/__init__.py +10 -0
- synth_ai/v0/tracing_v3/abstractions.py +3 -0
- synth_ai/v0/tracing_v3/decorators.py +3 -0
- synth_ai/v0/tracing_v3/llm_call_record_helpers.py +3 -0
- synth_ai/v0/tracing_v3/session_tracer.py +3 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.10.dist-info}/METADATA +10 -7
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.10.dist-info}/RECORD +269 -233
- examples/common_old/backend.py +0 -20
- examples/evals_old/README.md +0 -98
- examples/evals_old/__init__.py +0 -6
- examples/evals_old/compare_models.py +0 -1038
- examples/evals_old/example_log.md +0 -145
- examples/evals_old/run_demo.sh +0 -126
- examples/evals_old/trace_analysis.py +0 -270
- examples/finetuning_old/_backup_synth_qwen/config.toml +0 -29
- examples/finetuning_old/_backup_synth_qwen/example_log.md +0 -324
- examples/finetuning_old/_backup_synth_qwen/filter_traces.py +0 -60
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +0 -243
- examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +0 -109
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +0 -1924
- examples/finetuning_old/_backup_synth_qwen/readme.md +0 -49
- examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +0 -114
- examples/finetuning_old/_backup_synth_qwen/run_demo.sh +0 -195
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +0 -119
- examples/finetuning_old/synth_qwen_v1/README.md +0 -68
- examples/finetuning_old/synth_qwen_v1/filter_traces.py +0 -60
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +0 -243
- examples/finetuning_old/synth_qwen_v1/finetune.py +0 -46
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +0 -71
- examples/finetuning_old/synth_qwen_v1/infer.py +0 -36
- examples/finetuning_old/synth_qwen_v1/poll.py +0 -46
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +0 -35
- examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +0 -109
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +0 -1933
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +0 -210
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +0 -237
- examples/finetuning_old/synth_qwen_v1/upload_data.py +0 -34
- examples/finetuning_old/synth_qwen_v1/util.py +0 -152
- examples/rl_old/task_app.py +0 -1131
- synth_ai/experimental/synth_oss.py +0 -445
- synth_ai/learning/filtering.py +0 -0
- synth_ai/learning/offline/dpo.py +0 -0
- synth_ai/learning/offline/providers.py +0 -7
- synth_ai/learning/offline/sft.py +0 -0
- synth_ai/learning/offline/shared.py +0 -0
- synth_ai/learning/online/grpo.py +0 -0
- synth_ai/learning/online/irft.py +0 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +0 -168
- synth_ai/learning/prompts/gepa.py +0 -0
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +0 -211
- synth_ai/learning/prompts/mipro.py +0 -289
- synth_ai/learning/prompts/random_search.py +0 -249
- synth_ai/learning/prompts/run_mipro_banking77.py +0 -172
- synth_ai/learning/prompts/run_random_search_banking77.py +0 -329
- synth_ai/rl/secrets.py +0 -19
- synth_ai/scripts/verify_rewards.py +0 -100
- synth_ai/tracing/__init__.py +0 -30
- synth_ai/tracing_v1/__init__.py +0 -33
- synth_ai/tracing_v3/turso/__init__.py +0 -25
- synth_ai/tracing_v3/turso/manager.py +0 -838
- synth_ai/zyk/__init__.py +0 -30
- /synth_ai/{lm → v0/lm}/caching/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/dbs.py +0 -0
- /synth_ai/{lm → v0/lm}/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/core/exceptions.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/monitor.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/statefulness.py +0 -0
- /synth_ai/{lm → v0/lm}/injection.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/suppress_logging.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/inject.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/base.py +0 -0
- /synth_ai/{lm → v0/lm}/unified_interface.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/base.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/synth_dev_api.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/ollama.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/retries.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/supported/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/warmup.py +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.10.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.10.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.10.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.10.dist-info}/top_level.txt +0 -0
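Two relocations dominate the file list: the RL client helpers move from `synth_ai/rl/` into `synth_ai/learning/rl/`, and the entire legacy LM stack moves from `synth_ai/lm/` into `synth_ai/v0/lm/`. A minimal compatibility sketch for downstream code, assuming the import paths mirror the file moves above (verify against the installed wheel; the `lm_core` alias is illustrative):

```python
# Hedged sketch: tolerate both layouts across the 0.2.9 -> 0.2.10 boundary.
try:
    # 0.2.10 layout: legacy LM modules live under the v0 namespace.
    from synth_ai.v0.lm.core import main_v3 as lm_core
except ImportError:
    # Pre-0.2.10 layout keeps them at the top level.
    from synth_ai.lm.core import main_v3 as lm_core
```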
The largest in-place rewrite in this release is `examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py` (+303 −203); its hunks are reproduced below. Lines the diff viewer truncated survive only as bare `-` markers or trailing fragments.

```diff
--- a/examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py
+++ b/examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py
@@ -1,8 +1,10 @@
 from __future__ import annotations
 
+import contextlib
 import logging
+import os
 from datetime import datetime
-from typing import Any
+from typing import Any
 
 from fastapi import APIRouter, HTTPException, Request
 from pydantic import BaseModel
@@ -11,8 +13,6 @@ from .envs.crafter.policy import CrafterPolicy
 from .inference.openai_client import create_inference_client
 from .registry import registry
 from .storage.volume import storage
-import os
-from typing import Tuple
 
 # Token budgeting (shared logic with inference server)
 try:
@@ -34,10 +34,10 @@ router = APIRouter()
 
 class PolicyCreateRequest(BaseModel):
     policy_name: str
-    config:
-    parent_policy_id:
+    config: dict[str, Any] = {}
+    parent_policy_id: str | None = None
     rl_run_id: str
-    bound_env_id:
+    bound_env_id: str | None = None
 
 
 class PolicyCreateResponse(BaseModel):
@@ -46,15 +46,15 @@ class PolicyCreateResponse(BaseModel):
 
 class PolicyStepRequest(BaseModel):
     policy_id: str
-    observation:
-    state:
-    metadata:
+    observation: dict[str, Any]
+    state: dict[str, Any] | None = None
+    metadata: dict[str, Any] | None = None
     dry_run: bool = False
 
 
 class PolicyStepResponse(BaseModel):
-    tool_calls:
-    meta:
+    tool_calls: list[dict[str, Any]]
+    meta: dict[str, Any]
 
 
 class PolicySnapshotRequest(BaseModel):
@@ -91,14 +91,23 @@ async def create_policy(
 ) -> PolicyCreateResponse:
     """Create a new policy instance."""
     try:
-        task_app = req.app.state
-
-        # Set defaults from TaskApp if not provided
-        config = request.config
-        if "inference_url" not in config:
-
-
-
+        task_app = getattr(req.app.state, "task_app", None)
+
+        # Set defaults from TaskApp / environment if not provided
+        config = dict(request.config or {})
+        if "inference_url" not in config and task_app is not None:
+            base_url = getattr(task_app, "vllm_base_url", None)
+            if base_url:
+                config["inference_url"] = base_url
+        if "model" not in config and task_app is not None:
+            default_model = getattr(task_app, "default_model", None)
+            if default_model:
+                config["model"] = default_model
+        if "inference_url" not in config or "model" not in config:
+            raise HTTPException(
+                status_code=422,
+                detail="Policy configuration must include 'inference_url' and 'model'.",
+            )
 
         # Create policy instance based on name
         pname = request.policy_name.lower()
@@ -110,11 +119,13 @@ async def create_policy(
             await policy.initialize(config)
         elif pname in ["wordle-react", "wordle"]:
             try:
-                from .envs.wordle.policy import WordlePolicy
+                from .envs.wordle.policy import WordlePolicy
             except Exception as e:
-                raise HTTPException(
+                raise HTTPException(
+                    status_code=500, detail=f"Wordle policy unavailable: {e}"
+                ) from e
 
-            policy =
+            policy = WordlePolicy(
                 inference_url=config["inference_url"],
                 model=config["model"],
                 word_length=int(config["word_length"]),
@@ -123,22 +134,24 @@ async def create_policy(
             await policy.initialize(config)
         elif pname in ["sokoban-react", "sokoban"]:
             try:
-                from .envs.sokoban.policy import SokobanPolicy
+                from .envs.sokoban.policy import SokobanPolicy
             except Exception as e:
-                raise HTTPException(
+                raise HTTPException(
+                    status_code=500, detail=f"Sokoban policy unavailable: {e}"
+                ) from e
 
-            policy =
+            policy = SokobanPolicy(
                 inference_url=config["inference_url"],
                 model=config["model"],
             )
             await policy.initialize(config)
         elif pname in ["math-react", "math"]:
             try:
-                from .envs.math.policy import MathPolicy
+                from .envs.math.policy import MathPolicy
             except Exception as e:
-                raise HTTPException(status_code=500, detail=f"Math policy unavailable: {e}")
+                raise HTTPException(status_code=500, detail=f"Math policy unavailable: {e}") from e
 
-            policy =
+            policy = MathPolicy(
                 inference_url=config["inference_url"],
                 model=config["model"],
             )
@@ -160,7 +173,7 @@ async def create_policy(
 
     except Exception as e:
         logger.error(f"Failed to create policy: {e}")
-        raise HTTPException(status_code=500, detail=str(e))
+        raise HTTPException(status_code=500, detail=str(e)) from e
 
 
 @router.post("/step", response_model=PolicyStepResponse)
```
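The hunks above pin down the previously untyped Pydantic fields with concrete PEP 604 annotations and chain every re-raised `HTTPException` with `from e` so the causing traceback survives into logs. A standalone sketch of both patterns (the model fields follow the diff; the handler name is illustrative, not package API):

```python
from __future__ import annotations

from typing import Any

from fastapi import HTTPException
from pydantic import BaseModel


class PolicyCreateRequest(BaseModel):
    policy_name: str
    config: dict[str, Any] = {}          # was an untyped field before 0.2.10
    parent_policy_id: str | None = None  # PEP 604 optional


def create_policy_or_500(request: PolicyCreateRequest) -> str:
    try:
        return request.policy_name.lower()
    except Exception as e:
        # `from e` preserves the original exception as __cause__.
        raise HTTPException(status_code=500, detail=str(e)) from e
```

The `step_policy` hunks continue below.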
```diff
@@ -178,107 +191,165 @@ async def step_policy(
         policy = handle.policy
         tracing_context = getattr(req.state, "rollout_tracing", None)
 
-
+        obs_text = request.observation
         if isinstance(request.observation, dict):
             if isinstance(policy, CrafterPolicy):
                 from .envs.crafter.shared import format_observation as format_crafter
 
                 obs_text = format_crafter(request.observation)
-
+            else:
+                formatted: str | None = None
+
+                # Wordle formatting
                 try:
-                    from .envs.wordle.policy import WordlePolicy
+                    from .envs.wordle.policy import WordlePolicy
                 except Exception:
-
+                    wordle_policy_cls = None  # type: ignore[assignment]
+                else:
+                    wordle_policy_cls = WordlePolicy
 
-                if
+                if formatted is None and wordle_policy_cls is not None and isinstance(
+                    policy, wordle_policy_cls
+                ):
                     from .envs.wordle.shared import format_observation_wordle
 
-
-
-
-
-
+                    # ASSERTION: Validate observation structure
+                    assert request.observation is not None, "request.observation cannot be None"
+                    assert isinstance(request.observation, dict), (
+                        f"request.observation must be dict, got {type(request.observation)}"
+                    )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                    f"DEBUG POLICY_ROUTES: Observation
-
+                    required_keys = {
+                        "text",
+                        "status",
+                        "remaining_guesses",
+                        "guesses",
+                        "feedback",
+                        "reward_last",
+                        "total_reward",
+                        "terminated",
+                    }
+                    missing_keys = required_keys - set(request.observation.keys())
+                    assert (
+                        not missing_keys
+                    ), f"Wordle observation missing required keys: {missing_keys}"
+
+                    print("DEBUG POLICY_ROUTES: About to format Wordle observation")
+                    print(f"DEBUG POLICY_ROUTES: Observation type: {type(request.observation)}")
+                    print(
+                        f"DEBUG POLICY_ROUTES: Observation keys: {list(request.observation.keys())}"
+                    )
+                    feedback_val = request.observation["feedback"]
+                    print(f"DEBUG POLICY_ROUTES: Observation feedback: {feedback_val}")
+                    print(
+                        f"DEBUG POLICY_ROUTES: Observation guesses: {request.observation['guesses']}"
+                    )
+                    print(
+                        "DEBUG POLICY_ROUTES: Observation text length: "
+                        f"{len(request.observation['text'])}"
+                    )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                    "
-
+                    guesses = request.observation["guesses"]
+                    feedback = request.observation["feedback"]
+                    assert isinstance(guesses, list), f"guesses must be list, got {type(guesses)}"
+                    assert isinstance(
+                        feedback, list
+                    ), f"feedback must be list, got {type(feedback)}"
+
+                    formatted = format_observation_wordle(request.observation)
+
+                    assert isinstance(formatted, str), (
+                        f"obs_text must be string, got {type(formatted)}"
+                    )
+                    assert len(formatted) > 0, "obs_text cannot be empty"
+                    assert "WORDLE" in formatted, "obs_text must contain 'WORDLE' header"
+                    assert "Respond with a single tool call" in formatted, (
+                        "obs_text must contain instruction text"
+                    )
 
-
-
-
-
-
-
+                    print(
+                        f"DEBUG POLICY_ROUTES: Formatted obs_text length: {len(formatted)}"
+                    )
+                    print(
+                        "DEBUG POLICY_ROUTES: Formatted obs_text contains 🟩: "
+                        f"{'🟩' in formatted}"
+                    )
+                    print(
+                        "DEBUG POLICY_ROUTES: Formatted obs_text contains 🟨: "
+                        f"{'🟨' in formatted}"
+                    )
+                    print(
+                        "DEBUG POLICY_ROUTES: Formatted obs_text contains ⬛: "
+                        f"{'⬛' in formatted}"
+                    )
+                    print(
+                        "DEBUG POLICY_ROUTES: Formatted obs_text first 200 chars: "
+                        f"{formatted[:200]}"
+                    )
+
+                # Sokoban formatting
                 try:
-                    from .envs.sokoban.policy import SokobanPolicy
+                    from .envs.sokoban.policy import SokobanPolicy
                 except Exception:
-
+                    sokoban_policy_cls = None  # type: ignore[assignment]
+                else:
+                    sokoban_policy_cls = SokobanPolicy
 
-                if
+                if formatted is None and sokoban_policy_cls is not None and isinstance(
+                    policy, sokoban_policy_cls
+                ):
                     from .envs.sokoban.shared import format_observation_sokoban
 
-
-
+                    formatted = format_observation_sokoban(request.observation)
+
+                # Math formatting
                 try:
-                    from .envs.math.policy import MathPolicy
+                    from .envs.math.policy import MathPolicy
                 except Exception:
-
-
-
+                    math_policy_cls = None  # type: ignore[assignment]
+                else:
+                    math_policy_cls = MathPolicy
+
+                if formatted is None and math_policy_cls is not None and isinstance(
+                    policy, math_policy_cls
+                ):
                     try:
-
+                        formatted = str(
                             request.observation.get("problem_text") or request.observation
                         )
                     except Exception:
-
-
-
-
-
+                        formatted = str(request.observation)
+
+                if formatted is None:
+                    formatted = str(request.observation)
+
+                obs_text = formatted
+
+        # Merge metadata with raw observation for multimodal policies
+        step_metadata: dict[str, Any] = dict(request.metadata or {})
+        step_metadata["raw_observation"] = request.observation
 
         # Execute policy step to get inference request
         tool_calls, meta = await policy.step(
             observation_text=obs_text,
             state=request.state,
-            metadata=
+            metadata=step_metadata,
        )
+        # Compact tool call summary
+        with contextlib.suppress(Exception):
+            _summary: list[dict[str, Any]] = []
+            _tc = tool_calls or []
+            for _item in (_tc if isinstance(_tc, list) else []):
+                if isinstance(_item, dict):
+                    _tool = _item.get("tool")
+                    _args = _item.get("args")
+                    _keys = list(_args.keys()) if isinstance(_args, dict) else []
+                    _summary.append({"tool": _tool, "args_keys": _keys})
+            logger.info(
+                "POLICY_STEP: tool_calls=%d summary=%s",
+                len(_tc),
+                _summary,
+            )
 
         # If not dry run, perform inference
         if not request.dry_run and "inference_request" in meta:
@@ -286,8 +357,8 @@ async def step_policy(
             inf_req = meta["inference_request"]
             msgs = inf_req["messages"]
             model_name = inf_req.get("model") or getattr(policy, "model", None) or ""
-            system_messages:
-            user_messages:
+            system_messages: list[str] = []
+            user_messages: list[str] = []
             if msgs and len(msgs) > 0 and msgs[0]["role"] == "system":
                 sys_text = msgs[0]["content"]
                 policy_name = getattr(policy, "name", "") or type(policy).__name__.lower()
@@ -314,7 +385,6 @@ async def step_policy(
                         raise ValueError(
                             f"PROMPT MISMATCH: Crafter policy {policy_name} received Wordle system prompt: {sys_text[:200]}..."
                         )
-
                 elif policy_name in ("sokoban-react", "sokoban"):
                     if "Sokoban" not in sys_text:
                         raise ValueError(
@@ -353,40 +423,54 @@ async def step_policy(
                         return "".join(parts)
                     return str(content)
 
-
-
+                system_prompt_records: list[dict[str, Any]] = []
+                user_prompt_records: list[dict[str, Any]] = []
                 for message in msgs:
                     role = message.get("role")
-
+                    raw_content = message.get("content")
+                    content = _as_text(raw_content)
+                    record = {"role": role, "text": content, "content": raw_content}
                     if role == "system":
-
+                        system_prompt_records.append(record)
                     elif role == "user":
-
+                        user_prompt_records.append(record)
+
+                logger.info(
+                    "PROMPTS: system_msgs=%d user_msgs=%d last_user_chars=%d",
+                    len(system_prompt_records),
+                    len(user_prompt_records),
+                    len(user_prompt_records[-1].get("text", "")) if user_prompt_records else 0,
+                )
 
-                if
+                if system_prompt_records:
                     logger.info("PROMPT_DUMP_SYSTEM_BEGIN")
-                    for idx,
+                    for idx, rec in enumerate(system_prompt_records):
+                        smsg = rec.get("text", "")
                         logger.info(f"SYSTEM[{idx}]\n{smsg}")
                     logger.info("PROMPT_DUMP_SYSTEM_END")
 
-                if
+                if user_prompt_records:
                     logger.info("PROMPT_DUMP_USER_BEGIN")
-                    for idx,
+                    for idx, rec in enumerate(user_prompt_records):
+                        umsg = rec.get("text", "")
                         logger.info(f"USER[{idx}]\n{umsg}")
                     logger.info("PROMPT_DUMP_USER_END")
                     # Print concise preview for visibility in standard logs
-
-                    last_user =
-
+                    with contextlib.suppress(Exception):
+                        last_user = (
+                            user_prompt_records[-1].get("text", "")
+                            if user_prompt_records
+                            else ""
+                        )
                         print(f"[task:crafter] user prompt: {last_user}", flush=True)
-                    except Exception:
-                        pass
             except Exception as e:
                 logger.warning(f"PROMPT_DUMP_FAILED: {e}")
 
             if tracing_context is not None:
                 try:
-                    await tracing_context.record_policy_prompts(
+                    await tracing_context.record_policy_prompts(
+                        system_prompt_records, user_prompt_records
+                    )
                 except Exception as exc:
                     logger.debug(f"TRACING_PROMPTS_FAIL: {exc}")
 
@@ -399,10 +483,8 @@ async def step_policy(
             )
 
             # Ensure meta carries the final target URL for downstream logging/clients
-
+            with contextlib.suppress(Exception):
                 meta["inference_url"] = target_url
-            except Exception:
-                pass
 
             # Select API key based on resolved target URL
             api_key_override = None
@@ -411,11 +493,14 @@ async def step_policy(
 
             if isinstance(target_url, str):
                 low_url = target_url.lower()
-
+                # Proxy endpoints should not receive a bearer; the server-side proxy holds the vendor key
+                if "/proxy/groq" in low_url or "/proxy/openai" in low_url:
+                    api_key_override = None
+                elif "openai.com" in low_url:
                     api_key_override = _os.getenv("OPENAI_API_KEY") or getattr(
                         task_app, "openai_api_key", None
                     )
-                elif "groq.com" in low_url:
+                elif "groq.com" in low_url or "/proxy/groq" in low_url:
                     api_key_override = _os.getenv("GROQ_API_KEY")
                 else:
                     api_key_override = (
@@ -530,16 +615,16 @@ async def step_policy(
                 except Exception:
                     return max(1, int(len(text) / 4))
 
-            def _count_messages_tokens(messages:
+            def _count_messages_tokens(messages: list[dict[str, Any]]) -> int:
                 total = 0
                 for m in messages:
                     total += _count_tokens(_content_to_text(m.get("content")))
                 return total
 
             def _truncate_messages_to_budget(
-                messages:
+                messages: list[dict[str, Any]],
                 max_tokens: int,
-            ) ->
+            ) -> tuple[list[dict[str, Any]], int, int, int]:
                 before = _count_messages_tokens(messages)
                 if before <= max_tokens:
                     return messages, before, before, len(messages)
@@ -549,7 +634,7 @@ async def step_policy(
                 if messages and messages[0].get("role") == "system":
                     system_msg = messages[0]
                     start_idx = 1
-                kept_rev:
+                kept_rev: list[dict[str, Any]] = []
                 total = _count_messages_tokens([system_msg] if system_msg else [])
                 # Walk from the end keeping most recent messages
                 for m in reversed(messages[start_idx:]):
@@ -590,7 +675,7 @@ async def step_policy(
                 )
                 if new_msgs is not msgs:
                     inf_req["messages"] = new_msgs
-
+                    with contextlib.suppress(Exception):
                         logger.info(
                             {
                                 "chat_truncated": True,
@@ -600,8 +685,6 @@ async def step_policy(
                                 "kept_msgs": int(kept_count),
                             }
                         )
-                    except Exception:
-                        pass
             except Exception as _trunc_e:
                 logger.warning(f"CHAT_TRUNCATION_FAILED: {type(_trunc_e).__name__}: {_trunc_e}")
 
```
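A cleanup repeated throughout `step_policy`: best-effort blocks written as `try: ... except Exception: pass` become `contextlib.suppress(Exception)`, which is equivalent but keeps the happy path at one indent level. A two-variant sketch of the equivalence:

```python
import contextlib

meta: dict[str, str] = {}
target_url = "http://localhost:8000/v1"

# 0.2.9.dev7 style: swallow any failure silently.
try:
    meta["inference_url"] = target_url
except Exception:
    pass

# 0.2.10 style: same semantics, less nesting.
with contextlib.suppress(Exception):
    meta["inference_url"] = target_url
```

The remaining hunks follow.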
```diff
@@ -629,64 +712,56 @@ async def step_policy(
             # Prompt diagnostics before sending to inference: build chat template locally,
             # count tokens, and log the first 10k tokens if oversized. Also stash a
             # compact preview in meta so the trainer can surface it.
-
+            with contextlib.suppress(Exception):
                 req_for_diag = meta.get("inference_request", {})
                 model_for_diag = req_for_diag.get("model") or getattr(policy, "model", None) or ""
                 messages_for_diag = req_for_diag.get("messages") or []
                 if model_for_diag and messages_for_diag:
-
-                    from transformers import AutoTokenizer
+                    from transformers import AutoTokenizer
 
-
-
-
-
-
+                    tok = AutoTokenizer.from_pretrained(model_for_diag)
+                    prompt_preview = tok.apply_chat_template(
+                        messages_for_diag,
+                        add_generation_prompt=True,
+                        tokenize=False,
+                    )
+                    ids = tok.encode(prompt_preview, add_special_tokens=False)
+                    max_len = getattr(tok, "model_max_length", None)
+                    over_limit = False
+                    with contextlib.suppress(Exception):
+                        over_limit = (
+                            isinstance(max_len, int) and max_len > 0 and len(ids) > int(max_len)
                         )
-
-
-
-
-
-
-
-
-
-
-
-                    preview_text = tok.decode(preview_ids, skip_special_tokens=False)
-                    try:
-                        logger.warning(
-                            {
-                                "prompt_token_overflow_local": True,
-                                "model": str(model_for_diag),
-                                "token_count": int(len(ids)),
-                                "model_max_length": int(max_len)
-                                if isinstance(max_len, int)
-                                else None,
-                                "preview_tokens_logged": int(len(preview_ids)),
-                                "prompt_preview_first_10k_tokens": preview_text,
-                            }
-                        )
-                    except Exception:
-                        pass
-                    try:
-                        meta["prompt_debug"] = {
+                    if over_limit or len(ids) > 10000:
+                        preview_ids = ids[:10000]
+                        preview_text = tok.decode(
+                            preview_ids,
+                            skip_special_tokens=False,
+                        )
+                        with contextlib.suppress(Exception):
+                            logger.warning(
+                                {
+                                    "prompt_token_overflow_local": True,
+                                    "model": str(model_for_diag),
                                     "token_count": int(len(ids)),
                                     "model_max_length": int(max_len)
                                     if isinstance(max_len, int)
                                     else None,
-                                "
+                                    "preview_tokens_logged": int(len(preview_ids)),
+                                    "prompt_preview_first_10k_tokens": preview_text,
                                 }
-
-
-
-
-
-
+                            )
+                        with contextlib.suppress(Exception):
+                            meta["prompt_debug"] = {
+                                "token_count": int(len(ids)),
+                                "model_max_length": int(max_len)
+                                if isinstance(max_len, int)
+                                else None,
+                                "preview_first_10k_tokens": preview_text,
+                            }
 
             # Emit the exact prompt/messages and tools before calling the LLM (bounded preview)
-
+            with contextlib.suppress(Exception):
                 req_dump = meta.get("inference_request", {})
                 msgs = req_dump.get("messages")
                 tools_dump = req_dump.get("tools")
@@ -706,11 +781,9 @@ async def step_policy(
                         "tools_preview": tools_compact,
                     }
                 )
-            except Exception:
-                pass
 
             # Normalize request for non-OpenAI endpoints (strict schemas)
-
+            with contextlib.suppress(Exception):
                 base = str(target_url or "")
                 is_openai_dotcom = "openai.com" in base.lower()
                 if not is_openai_dotcom:
@@ -719,7 +792,7 @@ async def step_policy(
                     # Force structured tool_choice if a bare "required" is present
                     if req_body.get("tool_choice") == "required":
                         func_name = "interact_many"
-
+                        with contextlib.suppress(Exception):
                             tools_arr = req_body.get("tools") or []
                             if isinstance(tools_arr, list) and tools_arr:
                                 f = (
@@ -730,8 +803,6 @@ async def step_policy(
                                 cand = (f or {}).get("name") if isinstance(f, dict) else None
                                 if isinstance(cand, str) and cand:
                                     func_name = cand
-                        except Exception:
-                            pass
                         req_body["tool_choice"] = {
                             "type": "function",
                             "function": {"name": func_name},
@@ -739,7 +810,7 @@ async def step_policy(
                         req_body["parallel_tool_calls"] = False
                         req_body.setdefault("function_call", {"name": func_name})
                     # Inject extra_body for thinking controls expected by Modal service
-
+                    with contextlib.suppress(Exception):
                         tb = req_body.get("thinking_budget")
                         tm = str(req_body.get("thinking_mode") or "").lower()
                         enable_thinking = bool(tb) or tm == "think"
@@ -747,25 +818,52 @@ async def step_policy(
                         chat_kwargs = dict(extra.get("chat_template_kwargs") or {})
                         if enable_thinking:
                             chat_kwargs["enable_thinking"] = True
-                        if isinstance(tb,
-
+                        if isinstance(tb, int | float | str) and str(tb).strip():
+                            with contextlib.suppress(Exception):
                                 chat_kwargs["thinking_budget"] = int(tb)
-                        except Exception:
-                            pass
                         if chat_kwargs:
                             extra["chat_template_kwargs"] = chat_kwargs
                         # Ensure stop_after_tool_calls honored via extra_body for stricter servers
                         extra.setdefault("stop_after_tool_calls", 1)
                         if extra:
                             req_body["extra_body"] = extra
-                    except Exception:
-                        pass
                     # Provide a conservative default temperature if missing
                     if "temperature" not in req_body:
                         req_body["temperature"] = 0.1
                     meta["inference_request"] = req_body
-
-
+
+            # Strip image parts: Crafter policy currently only uses text prompts.
+            # Some providers reject image_url payloads entirely, so always flatten to plain text.
+            req_body2 = meta.get("inference_request", {})
+            if isinstance(req_body2, dict):
+                msgs = req_body2.get("messages")
+                if isinstance(msgs, list):
+                    new_msgs = []
+                    changed = False
+                    for m in msgs:
+                        try:
+                            if isinstance(m, dict):
+                                content = m.get("content")
+                                if isinstance(content, list):
+                                    parts: list[str] = []
+                                    for seg in content:
+                                        if isinstance(seg, dict):
+                                            txt = seg.get("text") or seg.get("content")
+                                            if isinstance(txt, str) and txt:
+                                                parts.append(txt)
+                                    m2 = dict(m)
+                                    m2["content"] = "\n".join(parts)
+                                    new_msgs.append(m2)
+                                    changed = True
+                                else:
+                                    new_msgs.append(m)
+                            else:
+                                new_msgs.append(m)
+                        except Exception:
+                            new_msgs.append(m)
+                    if changed:
+                        req_body2["messages"] = new_msgs
+                        meta["inference_request"] = req_body2
 
             _t_start = _t.time()
             call_started_at = datetime.utcnow()
@@ -826,15 +924,13 @@ async def step_policy(
             # Replace tool_calls with parsed result
             if isinstance(parsed, list):
                 tool_calls = parsed
-
+            with contextlib.suppress(Exception):
                 logger.info(
                     "TOOLCALL_PARSE: parsed=%d has_tools=%s example=%r",
                     len(tool_calls) if isinstance(tool_calls, list) else -1,
                     bool(getattr(policy, "use_tools", True)),
                     (tool_calls[0] if isinstance(tool_calls, list) and tool_calls else None),
                 )
-            except Exception:
-                pass
         except Exception as _pe:
             logger.warning(f"Failed to parse tool calls: {str(_pe)}")
         # Attach raw response + usage for observability
@@ -864,7 +960,7 @@ async def step_policy(
 
     except Exception as e:
         logger.error(f"Failed to step policy {request.policy_id}: {e}")
-        raise HTTPException(status_code=500, detail=str(e))
+        raise HTTPException(status_code=500, detail=str(e)) from e
 
 
 @router.post("/snapshot", response_model=PolicySnapshotResponse)
@@ -902,7 +998,7 @@ async def snapshot_policy(request: PolicySnapshotRequest) -> PolicySnapshotRespo
 
     except Exception as e:
         logger.error(f"Failed to snapshot policy {request.policy_id}: {e}")
-        raise HTTPException(status_code=500, detail=str(e))
+        raise HTTPException(status_code=500, detail=str(e)) from e
 
 
 @router.post("/restore", response_model=PolicyRestoreResponse)
@@ -933,16 +1029,20 @@ async def restore_policy(request: PolicyRestoreRequest) -> PolicyRestoreResponse
             policy = await CrafterPolicy.deserialize(state_dict)
         elif low in ["wordle-react", "wordle"]:
             try:
-                from .envs.wordle.policy import WordlePolicy
+                from .envs.wordle.policy import WordlePolicy
             except Exception as e:
-                raise HTTPException(
-
+                raise HTTPException(
+                    status_code=500, detail=f"Wordle policy unavailable: {e}"
+                ) from e
+            policy = await WordlePolicy.deserialize(state_dict)
         elif low in ["sokoban-react", "sokoban"]:
             try:
-                from .envs.sokoban.policy import SokobanPolicy
+                from .envs.sokoban.policy import SokobanPolicy
             except Exception as e:
-                raise HTTPException(
-
+                raise HTTPException(
+                    status_code=500, detail=f"Sokoban policy unavailable: {e}"
+                ) from e
+            policy = await SokobanPolicy.deserialize(state_dict)
         else:
             raise HTTPException(
                 status_code=422,
@@ -959,7 +1059,7 @@ async def restore_policy(request: PolicyRestoreRequest) -> PolicyRestoreResponse
 
     except Exception as e:
         logger.error(f"Failed to restore policy from snapshot {request.snapshot_id}: {e}")
-        raise HTTPException(status_code=500, detail=str(e))
+        raise HTTPException(status_code=500, detail=str(e)) from e
 
 
 @router.post("/terminate", response_model=PolicyTerminateResponse)
@@ -980,4 +1080,4 @@ async def terminate_policy(request: PolicyTerminateRequest) -> PolicyTerminateRe
 
     except Exception as e:
         logger.error(f"Failed to terminate policy {request.policy_id}: {e}")
-        raise HTTPException(status_code=500, detail=str(e))
+        raise HTTPException(status_code=500, detail=str(e)) from e
```
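Beyond the mechanical `contextlib.suppress` and `raise ... from e` conversions, the later hunks add a token-count gate (the first 10k prompt tokens are logged only when the prompt is oversized) and an image-stripping pass that flattens multi-part message content to plain text before inference. A standalone sketch of that flattening, assuming OpenAI-style message dicts (`flatten_content` is an illustrative name, not package API):

```python
from typing import Any


def flatten_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Collapse multimodal content lists to newline-joined text segments."""
    out: list[dict[str, Any]] = []
    for m in messages:
        content = m.get("content")
        if isinstance(content, list):
            # Keep only the text parts; drop image_url and other segments.
            parts = [
                seg.get("text") or seg.get("content")
                for seg in content
                if isinstance(seg, dict)
            ]
            m = {**m, "content": "\n".join(p for p in parts if isinstance(p, str))}
        out.append(m)
    return out


msgs = [{"role": "user", "content": [
    {"type": "text", "text": "describe the scene"},
    {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
]}]
print(flatten_content(msgs))  # [{'role': 'user', 'content': 'describe the scene'}]
```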