synth-ai 0.2.9.dev7__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic.
- examples/__init__.py +16 -0
- examples/crafter_debug_render.py +8 -11
- examples/dev/qwen3_32b_qlora_4xh100.toml +40 -0
- examples/multi_step/crafter_rl_lora.md +29 -0
- examples/qwen_coder/README.md +102 -0
- examples/qwen_coder/_shared.py +113 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +61 -0
- examples/qwen_coder/configs/coder_lora_4b.toml +57 -0
- examples/qwen_coder/configs/coder_lora_small.toml +58 -0
- examples/qwen_coder/generate_dataset.py +98 -0
- examples/qwen_coder/infer_ft_smoke.py +65 -0
- examples/qwen_coder/infer_prod_proxy.py +73 -0
- examples/qwen_coder/infer_via_synth.py +87 -0
- examples/qwen_coder/scripts/infer_coder.sh +19 -0
- examples/qwen_coder/scripts/train_coder_30b.sh +22 -0
- examples/qwen_coder/sft_full_17b.py +103 -0
- examples/qwen_coder/sft_lora_30b.py +110 -0
- examples/qwen_coder/subset_jsonl.py +39 -0
- examples/qwen_coder/todos.md +38 -0
- examples/qwen_coder/validate_jsonl.py +60 -0
- examples/rl/run_eval.py +36 -37
- examples/rl/run_rl_and_save.py +5 -5
- examples/rl/task_app/math_single_step.py +65 -43
- examples/rl/task_app/math_task_app.py +3 -3
- examples/sft/README.md +139 -0
- examples/sft/configs/crafter_fft_qwen0p6b.toml +44 -0
- examples/sft/configs/crafter_lora_qwen0p6b.toml +45 -0
- examples/sft/evaluate.py +117 -0
- examples/sft/export_dataset.py +117 -0
- examples/sft/generate_traces.py +162 -0
- examples/swe/__init__.py +12 -0
- examples/swe/task_app/README.md +105 -0
- examples/swe/task_app/__init__.py +2 -0
- examples/swe/task_app/grpo_swe_mini.py +571 -0
- examples/swe/task_app/grpo_swe_mini_task_app.py +136 -0
- examples/swe/task_app/hosted/README.md +173 -0
- examples/swe/task_app/hosted/__init__.py +5 -0
- examples/swe/task_app/hosted/branching.py +143 -0
- examples/swe/task_app/hosted/environment_routes.py +1289 -0
- examples/swe/task_app/hosted/envs/__init__.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/__init__.py +6 -0
- examples/swe/task_app/hosted/envs/crafter/app.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/environment.py +522 -0
- examples/swe/task_app/hosted/envs/crafter/policy.py +478 -0
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +108 -0
- examples/swe/task_app/hosted/envs/crafter/shared.py +305 -0
- examples/swe/task_app/hosted/envs/crafter/tools.py +47 -0
- examples/swe/task_app/hosted/envs/mini_swe/__init__.py +8 -0
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +1164 -0
- examples/swe/task_app/hosted/envs/mini_swe/policy.py +355 -0
- examples/swe/task_app/hosted/envs/mini_swe/shared.py +83 -0
- examples/swe/task_app/hosted/envs/mini_swe/tools.py +96 -0
- examples/swe/task_app/hosted/hosted_app.py +204 -0
- examples/swe/task_app/hosted/inference/__init__.py +5 -0
- examples/swe/task_app/hosted/inference/openai_client.py +618 -0
- examples/swe/task_app/hosted/main.py +100 -0
- examples/swe/task_app/hosted/policy_routes.py +1079 -0
- examples/swe/task_app/hosted/registry.py +195 -0
- examples/swe/task_app/hosted/rollout.py +1869 -0
- examples/swe/task_app/hosted/storage/__init__.py +5 -0
- examples/swe/task_app/hosted/storage/volume.py +211 -0
- examples/swe/task_app/hosted/test_agents.py +161 -0
- examples/swe/task_app/hosted/test_service.py +137 -0
- examples/swe/task_app/hosted/utils.py +62 -0
- examples/vlm/PROPOSAL.md +53 -0
- examples/vlm/README.md +68 -0
- examples/vlm/configs/crafter_vlm_gpt4o.toml +44 -0
- examples/vlm/crafter_image_only_agent.py +207 -0
- examples/vlm/crafter_openai_vlm_agent.py +277 -0
- examples/vlm/filter_image_rows.py +63 -0
- examples/vlm/run_crafter_vlm_benchmark.py +316 -0
- examples/warming_up_to_rl/analyze_trace_db.py +5 -5
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +11 -1
- examples/warming_up_to_rl/export_trace_sft.py +78 -21
- examples/warming_up_to_rl/groq_test.py +4 -4
- examples/warming_up_to_rl/manage_secrets.py +13 -18
- examples/warming_up_to_rl/run_eval.py +42 -44
- examples/warming_up_to_rl/run_fft_and_save.py +11 -16
- examples/warming_up_to_rl/run_local_rollout.py +1 -3
- examples/warming_up_to_rl/run_local_rollout_modal.py +2 -4
- examples/warming_up_to_rl/run_local_rollout_parallel.py +1 -4
- examples/warming_up_to_rl/run_local_rollout_traced.py +3 -5
- examples/warming_up_to_rl/run_rl_and_save.py +5 -6
- examples/warming_up_to_rl/run_rollout_remote.py +8 -10
- examples/warming_up_to_rl/task_app/README.md +6 -2
- examples/warming_up_to_rl/task_app/grpo_crafter.py +234 -35
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +2 -3
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +9 -11
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +131 -114
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +101 -41
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +73 -51
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +14 -6
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +16 -16
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +32 -34
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +94 -31
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +0 -2
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +303 -203
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +21 -23
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +328 -225
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +13 -13
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +4 -3
- synth_ai/api/models/supported.py +376 -0
- synth_ai/api/train/builders.py +128 -21
- synth_ai/api/train/cli.py +80 -64
- synth_ai/api/train/config_finder.py +7 -2
- synth_ai/api/train/env_resolver.py +1 -1
- synth_ai/api/train/pollers.py +2 -1
- synth_ai/api/train/supported_algos.py +139 -0
- synth_ai/api/train/task_app.py +1 -2
- synth_ai/api/train/utils.py +13 -44
- synth_ai/cli/__init__.py +8 -0
- synth_ai/cli/_modal_wrapper.py +28 -0
- synth_ai/cli/_typer_patch.py +49 -0
- synth_ai/cli/balance.py +1 -2
- synth_ai/cli/calc.py +1 -1
- synth_ai/cli/demo.py +2 -1
- synth_ai/cli/recent.py +2 -2
- synth_ai/cli/rl_demo.py +2 -1
- synth_ai/cli/root.py +11 -13
- synth_ai/cli/status.py +2 -2
- synth_ai/cli/task_apps.py +529 -179
- synth_ai/cli/traces.py +6 -4
- synth_ai/cli/watch.py +12 -18
- synth_ai/demo_registry.py +1 -1
- synth_ai/demos/core/cli.py +36 -43
- synth_ai/demos/demo_task_apps/__init__.py +3 -3
- synth_ai/demos/demo_task_apps/core.py +17 -25
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +3 -4
- synth_ai/demos/demo_task_apps/math/app.py +2 -1
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +3 -4
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +16 -18
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -1
- synth_ai/environments/examples/crafter_classic/environment.py +76 -1
- synth_ai/environments/reproducibility/tree.py +2 -5
- synth_ai/environments/service/app.py +11 -12
- synth_ai/environments/service/core_routes.py +4 -7
- synth_ai/environments/stateful/engine.py +1 -1
- synth_ai/environments/tasks/core.py +1 -0
- synth_ai/environments/tasks/filters.py +5 -6
- synth_ai/environments/tasks/utils.py +4 -5
- synth_ai/handshake.py +9 -9
- synth_ai/http.py +1 -1
- synth_ai/http_client.py +18 -10
- synth_ai/inference/client.py +15 -5
- synth_ai/jobs/client.py +78 -83
- synth_ai/learning/__init__.py +41 -6
- synth_ai/learning/algorithms.py +14 -0
- synth_ai/learning/client.py +91 -24
- synth_ai/learning/config.py +2 -38
- synth_ai/learning/ft_client.py +4 -59
- synth_ai/learning/health.py +5 -6
- synth_ai/learning/jobs.py +31 -47
- synth_ai/{rl → learning/rl}/__init__.py +14 -4
- synth_ai/learning/rl/client.py +267 -0
- synth_ai/learning/rl/config.py +31 -0
- synth_ai/{rl → learning/rl}/contracts.py +5 -8
- synth_ai/{rl → learning/rl}/env_keys.py +39 -15
- synth_ai/learning/rl/secrets.py +13 -0
- synth_ai/learning/rl_client.py +2 -281
- synth_ai/learning/sft/__init__.py +29 -0
- synth_ai/learning/sft/client.py +68 -0
- synth_ai/learning/sft/config.py +270 -0
- synth_ai/learning/sft/data.py +295 -0
- synth_ai/learning/sse.py +25 -24
- synth_ai/learning/validators.py +25 -28
- synth_ai/lm/__init__.py +21 -47
- synth_ai/task/__init__.py +25 -27
- synth_ai/task/apps/__init__.py +7 -8
- synth_ai/task/auth.py +8 -8
- synth_ai/task/client.py +14 -14
- synth_ai/task/contracts.py +36 -35
- synth_ai/task/datasets.py +6 -5
- synth_ai/task/errors.py +10 -10
- synth_ai/task/health.py +17 -9
- synth_ai/task/json.py +58 -23
- synth_ai/task/proxy.py +13 -9
- synth_ai/task/rubrics.py +16 -15
- synth_ai/task/server.py +12 -12
- synth_ai/task/tracing_utils.py +4 -4
- synth_ai/task/vendors.py +5 -6
- synth_ai/tracing_v3/__init__.py +2 -0
- synth_ai/tracing_v3/abstractions.py +21 -4
- synth_ai/tracing_v3/decorators.py +18 -16
- synth_ai/tracing_v3/hooks.py +5 -5
- synth_ai/tracing_v3/llm_call_record_helpers.py +6 -6
- synth_ai/tracing_v3/session_tracer.py +40 -14
- synth_ai/tracing_v3/storage/base.py +85 -0
- synth_ai/tracing_v3/storage/config.py +21 -8
- synth_ai/tracing_v3/storage/factory.py +10 -7
- synth_ai/tracing_v3/storage/utils.py +4 -2
- synth_ai/tracing_v3/turso/daemon.py +7 -2
- synth_ai/tracing_v3/turso/models.py +2 -2
- synth_ai/tracing_v3/turso/native_manager.py +1173 -0
- synth_ai/tracing_v3/utils.py +4 -4
- synth_ai/v0/api/__init__.py +8 -0
- synth_ai/v0/api/models/__init__.py +8 -0
- synth_ai/v0/api/models/supported.py +8 -0
- synth_ai/v0/config/__init__.py +15 -0
- synth_ai/v0/config/base_url.py +12 -0
- synth_ai/v0/lm/__init__.py +51 -0
- synth_ai/{lm → v0/lm}/caching/ephemeral.py +2 -2
- synth_ai/{lm → v0/lm}/caching/handler.py +4 -4
- synth_ai/{lm → v0/lm}/caching/initialize.py +1 -1
- synth_ai/{lm → v0/lm}/caching/persistent.py +1 -1
- synth_ai/{lm → v0/lm}/config.py +6 -1
- synth_ai/{lm → v0/lm}/core/all.py +9 -9
- synth_ai/{lm → v0/lm}/core/main.py +6 -6
- synth_ai/{lm → v0/lm}/core/main_v3.py +10 -10
- synth_ai/{lm → v0/lm}/core/synth_models.py +2 -14
- synth_ai/{lm → v0/lm}/core/vendor_clients.py +2 -2
- synth_ai/{lm → v0/lm}/overrides.py +2 -2
- synth_ai/{lm → v0/lm}/provider_support/anthropic.py +4 -4
- synth_ai/{lm → v0/lm}/provider_support/openai.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/handler.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/rehabilitate.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/core/anthropic_api.py +9 -9
- synth_ai/{lm → v0/lm}/vendors/core/gemini_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/mistral_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/openai_api.py +10 -10
- synth_ai/{lm → v0/lm}/vendors/openai_standard.py +8 -8
- synth_ai/{lm → v0/lm}/vendors/openai_standard_responses.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/custom_endpoint.py +3 -3
- synth_ai/{lm → v0/lm}/vendors/supported/deepseek.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/grok.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/groq.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/ollama.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/openrouter.py +3 -3
- synth_ai/{lm → v0/lm}/vendors/supported/together.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/synth_client.py +1 -1
- synth_ai/v0/tracing_v3/__init__.py +10 -0
- synth_ai/v0/tracing_v3/abstractions.py +3 -0
- synth_ai/v0/tracing_v3/decorators.py +3 -0
- synth_ai/v0/tracing_v3/llm_call_record_helpers.py +3 -0
- synth_ai/v0/tracing_v3/session_tracer.py +3 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.10.dist-info}/METADATA +10 -7
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.10.dist-info}/RECORD +269 -233
- examples/common_old/backend.py +0 -20
- examples/evals_old/README.md +0 -98
- examples/evals_old/__init__.py +0 -6
- examples/evals_old/compare_models.py +0 -1038
- examples/evals_old/example_log.md +0 -145
- examples/evals_old/run_demo.sh +0 -126
- examples/evals_old/trace_analysis.py +0 -270
- examples/finetuning_old/_backup_synth_qwen/config.toml +0 -29
- examples/finetuning_old/_backup_synth_qwen/example_log.md +0 -324
- examples/finetuning_old/_backup_synth_qwen/filter_traces.py +0 -60
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +0 -243
- examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +0 -109
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +0 -1924
- examples/finetuning_old/_backup_synth_qwen/readme.md +0 -49
- examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +0 -114
- examples/finetuning_old/_backup_synth_qwen/run_demo.sh +0 -195
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +0 -119
- examples/finetuning_old/synth_qwen_v1/README.md +0 -68
- examples/finetuning_old/synth_qwen_v1/filter_traces.py +0 -60
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +0 -243
- examples/finetuning_old/synth_qwen_v1/finetune.py +0 -46
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +0 -71
- examples/finetuning_old/synth_qwen_v1/infer.py +0 -36
- examples/finetuning_old/synth_qwen_v1/poll.py +0 -46
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +0 -35
- examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +0 -109
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +0 -1933
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +0 -210
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +0 -237
- examples/finetuning_old/synth_qwen_v1/upload_data.py +0 -34
- examples/finetuning_old/synth_qwen_v1/util.py +0 -152
- examples/rl_old/task_app.py +0 -1131
- synth_ai/experimental/synth_oss.py +0 -445
- synth_ai/learning/filtering.py +0 -0
- synth_ai/learning/offline/dpo.py +0 -0
- synth_ai/learning/offline/providers.py +0 -7
- synth_ai/learning/offline/sft.py +0 -0
- synth_ai/learning/offline/shared.py +0 -0
- synth_ai/learning/online/grpo.py +0 -0
- synth_ai/learning/online/irft.py +0 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +0 -168
- synth_ai/learning/prompts/gepa.py +0 -0
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +0 -211
- synth_ai/learning/prompts/mipro.py +0 -289
- synth_ai/learning/prompts/random_search.py +0 -249
- synth_ai/learning/prompts/run_mipro_banking77.py +0 -172
- synth_ai/learning/prompts/run_random_search_banking77.py +0 -329
- synth_ai/rl/secrets.py +0 -19
- synth_ai/scripts/verify_rewards.py +0 -100
- synth_ai/tracing/__init__.py +0 -30
- synth_ai/tracing_v1/__init__.py +0 -33
- synth_ai/tracing_v3/turso/__init__.py +0 -25
- synth_ai/tracing_v3/turso/manager.py +0 -838
- synth_ai/zyk/__init__.py +0 -30
- /synth_ai/{lm → v0/lm}/caching/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/dbs.py +0 -0
- /synth_ai/{lm → v0/lm}/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/core/exceptions.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/monitor.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/statefulness.py +0 -0
- /synth_ai/{lm → v0/lm}/injection.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/suppress_logging.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/inject.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/base.py +0 -0
- /synth_ai/{lm → v0/lm}/unified_interface.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/base.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/synth_dev_api.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/ollama.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/retries.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/supported/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/warmup.py +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.10.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.10.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.10.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.10.dist-info}/top_level.txt +0 -0
examples/swe/task_app/hosted/inference/openai_client.py (new file)

@@ -0,0 +1,618 @@

from __future__ import annotations

import asyncio
import contextlib
import logging
from typing import Any

import httpx

logger = logging.getLogger(__name__)


class OpenAIClient:
    """Async HTTP client for OpenAI-compatible inference servers (vLLM)."""

    def __init__(
        self,
        base_url: str,
        api_key: str | None = None,
        timeout_s: float = 120.0,
    ) -> None:
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.timeout_s = timeout_s
        self.headers = {}

        if api_key:
            self.headers["Authorization"] = f"Bearer {api_key}"

    def _fix_model_parameters(
        self, request: dict[str, Any], target_url: str | None = None
    ) -> dict[str, Any]:
        """
        Fix parameter compatibility for newer OpenAI models.

        Newer models like gpt-5-nano use 'max_completion_tokens' instead of 'max_tokens'.
        """
        if not request:
            return request

        # Make a copy to avoid modifying the original
        fixed_request = request.copy()

        # Determine if target is OpenAI-compatible (OpenAI, Azure OpenAI, Groq);
        # strip fields those endpoints don't accept
        is_openai = False
        try:
            if isinstance(target_url, str):
                low = target_url.lower()
                is_openai = (
                    ("openai.com" in low)
                    or ("azure" in low and ".openai." in low)
                    or ("groq.com" in low)
                    or ("/openai" in low)
                )
        except Exception:
            is_openai = False

        model = fixed_request.get("model", "")

        if is_openai:
            # Remove fields OpenAI/Groq don't accept
            for k in (
                "stop_after_tool_calls",
                "thinking_mode",
                "thinking_budget",
                "reasoning",
                "extra_body",
                "parallel_tool_calls",
                "function_call",
            ):
                if k in fixed_request:
                    fixed_request.pop(k, None)

            # GPT-5 family specifics
            if "gpt-5" in model or "gpt-4.1" in model:
                # Convert max_tokens to max_completion_tokens for newer models
                if "max_tokens" in fixed_request:
                    if "max_completion_tokens" not in fixed_request:
                        fixed_request["max_completion_tokens"] = fixed_request.pop("max_tokens")
                        logger.info(
                            f"Converted max_tokens to max_completion_tokens for model {model}"
                        )
                    else:
                        fixed_request.pop("max_tokens")
                        logger.info(f"Removed conflicting max_tokens parameter for model {model}")
                # Some OpenAI endpoints ignore/deny sampling fields for reasoning models
                for k in ("temperature", "top_p"):
                    if k in fixed_request:
                        fixed_request.pop(k, None)
            # If tools are present, force single tool choice to our function
            try:
                tools = fixed_request.get("tools")
                if isinstance(tools, list) and tools:
                    # Choose the first provided function name from tools schema (e.g., run_command)
                    func_name = None
                    for t in tools:
                        try:
                            cand = None
                            if isinstance(t, dict):
                                f = t.get("function")
                                if isinstance(f, dict):
                                    cand = f.get("name")
                            if isinstance(cand, str) and cand:
                                func_name = cand
                                break
                        except Exception:
                            continue
                    if not func_name:
                        func_name = "run_command"
                    fixed_request["tool_choice"] = {
                        "type": "function",
                        "function": {"name": func_name},
                    }
                    fixed_request["parallel_tool_calls"] = False
            except Exception:
                pass

        return fixed_request

    async def generate(
        self,
        request: dict[str, Any],
        base_url: str | None = None,
        timeout_s: float | None = None,
        extra_headers: dict[str, str] | None = None,
    ) -> dict[str, Any]:
        """
        Send a chat completion request to the inference server.

        Args:
            request: OpenAI-compatible chat completion request
            base_url: Override base URL for this request
            timeout_s: Override timeout for this request
            extra_headers: Additional headers to include (e.g., X-Policy-Name)

        Returns:
            OpenAI-compatible chat completion response
        """
        url = (base_url or self.base_url).rstrip("/") + "/v1/chat/completions"
        timeout = timeout_s or self.timeout_s

        # Merge headers
        headers = self.headers.copy()
        if extra_headers:
            headers.update(extra_headers)

        # Fix parameter compatibility for newer models
        processed_request = self._fix_model_parameters(request, target_url=url)

        # Log request (redact messages in production)
        logger.info(f"Inference POST target: {url}")
        if extra_headers:
            logger.info(f"Extra headers: {extra_headers}")
        with contextlib.suppress(Exception):
            keys_preview = sorted(processed_request.keys())
            logger.info(f"Request keys: {keys_preview}")

        # Final hard-guard for OpenAI: ensure unsupported field is not present
        try:
            if "openai" in url.lower() and "stop_after_tool_calls" in processed_request:
                processed_request.pop("stop_after_tool_calls", None)
                logger.info("Removed stop_after_tool_calls for OpenAI request")
            # Groq-specific requirement: when using JSON mode, one of the messages must contain the word 'json'
            low_url = url.lower()
            if ("groq.com" in low_url or "/openai" in low_url) and isinstance(
                processed_request, dict
            ):
                rf = processed_request.get("response_format")
                rf_type = None
                if isinstance(rf, dict):
                    rf_type = str(rf.get("type") or "").lower()
                if rf_type in {"json_object", "json_schema"}:
                    msgs = processed_request.get("messages")
                    has_json_word = False
                    if isinstance(msgs, list):
                        for m in msgs:
                            try:
                                content = m.get("content") if isinstance(m, dict) else None
                                text = None
                                if isinstance(content, str):
                                    text = content
                                elif isinstance(content, list):
                                    # Join any text segments
                                    parts = []
                                    for seg in content:
                                        if isinstance(seg, dict) and isinstance(
                                            seg.get("text"), str
                                        ):
                                            parts.append(seg["text"])
                                    text = "\n".join(parts)
                                if isinstance(text, str) and ("json" in text.lower()):
                                    has_json_word = True
                                    break
                            except Exception:
                                continue
                    if not has_json_word:
                        try:
                            instruction = (
                                "Respond in strict JSON only. Output a single valid JSON object."
                            )
                            if not isinstance(msgs, list):
                                msgs = []
                            # Prepend a system message to satisfy Groq requirement without changing user intent
                            prepend = {"role": "system", "content": instruction}
                            processed_request["messages"] = [prepend] + list(msgs)
                            logger.info(
                                "Injected JSON-mode system instruction for Groq response_format compliance"
                            )
                        except Exception:
                            pass
        except Exception:
            pass

        async with httpx.AsyncClient(timeout=timeout) as client:
            try:
                response = await client.post(
                    url,
                    json=processed_request,
                    headers=headers,
                )
                response.raise_for_status()

                # Rich response diagnostics
                content_type = response.headers.get("content-type")
                body_text = response.text
                logger.info(
                    f"Inference response status=200, content-type={content_type}, bytes={len(body_text)}"
                )
                if body_text:
                    preview_len = min(800, len(body_text))
                    logger.info(
                        f"Inference response preview ({preview_len} bytes): {body_text[:preview_len]}"
                    )

                result = response.json()
                logger.info(f"Inference response parsed_type={type(result).__name__}")
                return result

            except httpx.TimeoutException:
                logger.error(f"Request to {url} timed out after {timeout}s")
                raise
            except httpx.HTTPStatusError as e:
                status = e.response.status_code if e.response is not None else None
                text = e.response.text if e.response is not None else str(e)
                # Log full body for debugging remote failures
                try:
                    logger.error(
                        {
                            "openai_http_error": True,
                            "status": status,
                            "url": url,
                            "body": text,
                        }
                    )
                except Exception:
                    logger.error(f"HTTP error from {url}: {status} - {text}")
                # For 4xx/5xx, print full sanitized request to aid debugging (especially Groq 400s)
                try:
                    redacted_headers = dict(headers)
                    if "Authorization" in redacted_headers:
                        redacted_headers["Authorization"] = "***REDACTED***"
                    logger.error(
                        {
                            "request_debug": True,
                            "status": status,
                            "target": url,
                            "headers": redacted_headers,
                            "payload": processed_request,
                        }
                    )
                except Exception:
                    pass
                # Special case: token budget exceeded (OpenAI-compatible error schema)
                try:
                    if status == 400 and e.response is not None:
                        data = e.response.json()
                        detail = data.get("detail") if isinstance(data, dict) else None
                        err_code = (detail or {}).get("error") if isinstance(detail, dict) else None
                        if err_code == "token_budget_exceeded":
                            info = (detail or {}).get("details") or {}
                            messages_tokens = int(info.get("messages_tokens") or 0)
                            model_limit = int(info.get("model_limit") or 0)
                            safety = 64
                            # Compute a conservative new max_tokens
                            new_max = max(16, model_limit - messages_tokens - safety)
                            try:
                                # Update request and retry once immediately with smaller budget
                                if isinstance(processed_request, dict):
                                    processed_request = dict(processed_request)
                                    if "max_completion_tokens" in processed_request:
                                        processed_request["max_completion_tokens"] = new_max
                                        processed_request.pop("max_tokens", None)
                                    else:
                                        processed_request["max_tokens"] = new_max
                                    # Remove optional fields that some servers reject
                                    for k in ("thinking_mode", "thinking_budget", "reasoning"):
                                        processed_request.pop(k, None)
                                    # Force structured tool choice
                                    if processed_request.get("tool_choice") == "required":
                                        func_name = "run_command"
                                        try:
                                            tools_arr = processed_request.get("tools") or []
                                            if isinstance(tools_arr, list) and tools_arr:
                                                f = (
                                                    tools_arr[0].get("function")
                                                    if isinstance(tools_arr[0], dict)
                                                    else None
                                                )
                                                cand = (
                                                    (f or {}).get("name")
                                                    if isinstance(f, dict)
                                                    else None
                                                )
                                                if isinstance(cand, str) and cand:
                                                    func_name = cand
                                        except Exception:
                                            pass
                                        processed_request["tool_choice"] = {
                                            "type": "function",
                                            "function": {"name": func_name},
                                        }
                                        processed_request["parallel_tool_calls"] = False
                                logger.warning(
                                    {
                                        "token_budget_recovery": True,
                                        "messages_tokens": messages_tokens,
                                        "model_limit": model_limit,
                                        "retry_max_tokens": new_max,
                                    }
                                )
                                # Retry once with reduced budget
                                async with httpx.AsyncClient(timeout=timeout) as client2:
                                    r2 = await client2.post(
                                        url, json=processed_request, headers=headers
                                    )
                                    r2.raise_for_status()
                                    return r2.json()
                            except Exception:
                                pass
                except Exception:
                    pass
                # Gracefully degrade on 422 so rollouts can still produce a trajectory
                if status == 422:
                    try:
                        # Best-effort parse of error for diagnostics
                        err = None
                        try:
                            err = e.response.json()
                        except Exception:
                            err = {"error": "unprocessable", "detail": (text or "")[:200]}
                        logger.warning(
                            {
                                "inference_422_recovered": True,
                                "detail": err,
                            }
                        )
                    except Exception:
                        pass
                    # Return a minimal OpenAI-compatible response with no tool_calls/content
                    import time as _t

                    return {
                        "id": f"cmpl-{int(_t.time())}",
                        "object": "chat.completion",
                        "created": int(_t.time()),
                        "model": processed_request.get("model") or "unknown",
                        "choices": [
                            {
                                "index": 0,
                                "message": {"role": "assistant", "content": "", "tool_calls": []},
                                "finish_reason": "stop",
                            }
                        ],
                        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
                    }
                raise
            except Exception as e:
                logger.error(f"Unexpected error calling {url}: {e}")
                raise

    async def check_health(
        self,
        base_url: str | None = None,
        timeout_s: float | None = None,
    ) -> dict[str, Any]:
        """
        Check if the inference service is healthy.

        Args:
            base_url: Override base URL for this request
            timeout_s: Override timeout for this request

        Returns:
            Health status dict with 'status' field
        """
        url = (base_url or self.base_url).rstrip("/") + "/health"
        timeout = timeout_s or 10.0

        try:
            async with httpx.AsyncClient(timeout=timeout) as client:
                response = await client.get(url, headers=self.headers)
                response.raise_for_status()
                return response.json()
        except httpx.HTTPStatusError as e:
            if e.response.status_code == 400:
                # Service is overloaded but still responding
                try:
                    data = e.response.json()
                    if data.get("status") == "overloaded":
                        return {"status": "overloaded", "retry_after": data.get("retry_after", 1)}
                except Exception:
                    pass
            return {"status": "unhealthy", "error": str(e)}
        except Exception as e:
            return {"status": "unhealthy", "error": str(e)}

    async def generate_with_retries(
        self,
        request: dict[str, Any],
        base_url: str | None = None,
        timeout_s: float | None = None,
        max_retries: int = 4,
        backoff_factor: float = 2.0,
        extra_headers: dict[str, str] | None = None,
    ) -> dict[str, Any]:
        """
        Generate with exponential backoff retries for transient errors.

        Args:
            request: OpenAI-compatible chat completion request
            base_url: Override base URL
            timeout_s: Override timeout
            max_retries: Maximum number of retry attempts
            backoff_factor: Exponential backoff multiplier
            extra_headers: Additional headers to include (e.g., X-Policy-Name)

        Returns:
            OpenAI-compatible chat completion response
        """
        last_error = None
        wait_time = 1.0

        for attempt in range(max_retries + 1):
            try:
                # Apply parameter fixes to the request
                processed_request = self._fix_model_parameters(
                    request,
                    target_url=(base_url or self.base_url).rstrip("/") + "/v1/chat/completions",
                )
                return await self.generate(
                    request=processed_request,
                    base_url=base_url,
                    timeout_s=timeout_s,
                    extra_headers=extra_headers,
                )
            except httpx.HTTPStatusError as e:
                # Retry on 400 (overloaded), 429 (rate limit), 500 (internal error), 503 (service unavailable)
                if e.response.status_code not in [400, 429, 500, 503]:
                    raise
                last_error = e
                if e.response.status_code == 400:
                    # Check if this is an overload error by looking at response content
                    try:
                        response_data = e.response.json()
                        if response_data.get("status") == "overloaded":
                            retry_after = response_data.get("retry_after", 1)
                            # Use the suggested retry_after time instead of exponential backoff for overload
                            wait_time = max(wait_time, float(retry_after))
                            logger.warning(
                                f"Inference service overloaded (400). {response_data} Retrying after {wait_time}s..."
                            )
                        else:
                            # This is a different type of 400 error, don't retry
                            try:
                                redacted_headers = {}
                                try:
                                    redacted_headers = dict(self.headers)
                                    if "Authorization" in redacted_headers:
                                        redacted_headers["Authorization"] = "***REDACTED***"
                                except Exception:
                                    redacted_headers = {}
                                logger.error(
                                    {
                                        "non_overload_400": True,
                                        "target": (base_url or self.base_url),
                                        "payload": processed_request,
                                        "headers": redacted_headers,
                                        "body": e.response.text if e.response is not None else None,
                                    }
                                )
                            except Exception:
                                pass
                            raise RuntimeError(
                                f"Inference 400 response: {e.response.text if e.response is not None else 'Bad Request'}"
                            ) from e
                    except Exception:
                        # If we can't parse the response, don't retry 400 errors
                        with contextlib.suppress(Exception):
                            logger.error(
                                {
                                    "non_overload_400_unparsed": True,
                                    "target": (base_url or self.base_url),
                                    "payload": processed_request,
                                }
                            )
                        raise RuntimeError(
                            f"Inference 400 response (unparsed): {e.response.text if e.response is not None else 'Bad Request'}"
                        ) from e
                elif e.response.status_code == 503:
                    # Avoid referencing undefined response_data
                    try:
                        preview = (e.response.text or "")[:200]
                    except Exception:
                        preview = ""
                    logger.warning(
                        f"Flash returned 503; container may be cold starting. Retrying... body={preview}"
                    )
                elif e.response.status_code == 500:
                    try:
                        preview = (e.response.text or "")[:200]
                    except Exception:
                        preview = ""
                    logger.warning(
                        f"Flash returned 500; inference service error. Retrying... body={preview}"
                    )
            except httpx.TimeoutException as e:
                last_error = e

            if attempt < max_retries:
                logger.warning(
                    f"Inference request failed (attempt {attempt + 1}/{max_retries + 1}), "
                    f"retrying in {wait_time}s..."
                )
                await asyncio.sleep(wait_time)
                wait_time *= backoff_factor

        raise last_error


def create_inference_client(
    task_app: Any,
    api_key: str | None = None,
) -> OpenAIClient:
    """
    Create an inference client using TaskApp configuration.

    Args:
        task_app: TaskApp instance with vllm_base_url
        api_key: Optional API key for authentication

    Returns:
        Configured OpenAIClient instance
    """
    # Fallback to environment if caller didn't provide an API key
    if api_key is None:
        try:
            import os as _os  # local import to avoid module-level side effects

            api_key = _os.getenv("OPENAI_API_KEY") or getattr(task_app, "openai_api_key", None)
        except Exception:
            api_key = None

    import json as _json
    import os as _os
    import time as _time

    if _os.getenv("SYNTH_FAKE_INFERENCE", "").strip():

        class _DummyClient:
            async def generate_with_retries(
                self,
                request: dict[str, Any],
                base_url: str | None = None,
                max_retries: int = 0,
                backoff_factor: float = 1.0,
                extra_headers: dict[str, str] | None = None,
            ) -> dict[str, Any]:
                tool_call = {
                    "id": "call_dummy",
                    "type": "function",
                    "function": {
                        "name": "interact_many",
                        "arguments": _json.dumps({"actions": ["move_right"]}),
                    },
                }
                return {
                    "id": f"cmpl-{int(_time.time())}",
                    "object": "chat.completion",
                    "created": int(_time.time()),
                    "model": request.get("model") or "dummy-model",
                    "choices": [
                        {
                            "index": 0,
                            "message": {
                                "role": "assistant",
                                "content": "",
                                "tool_calls": [tool_call],
                            },
                            "finish_reason": "tool_calls",
                        }
                    ],
                    "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
                }

            async def check_health(
                self,
                base_url: str | None = None,
                timeout_s: float | None = None,
            ) -> dict[str, Any]:
                return {"status": "ok", "dummy": True}

        return _DummyClient()

    return OpenAIClient(
        base_url=task_app.vllm_base_url,
        api_key=api_key,
    )