llama-stack 0.3.5__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- llama_stack/__init__.py +0 -5
- llama_stack/cli/llama.py +3 -3
- llama_stack/cli/stack/_list_deps.py +12 -23
- llama_stack/cli/stack/list_stacks.py +37 -18
- llama_stack/cli/stack/run.py +121 -11
- llama_stack/cli/stack/utils.py +0 -127
- llama_stack/core/access_control/access_control.py +69 -28
- llama_stack/core/access_control/conditions.py +15 -5
- llama_stack/core/admin.py +267 -0
- llama_stack/core/build.py +6 -74
- llama_stack/core/client.py +1 -1
- llama_stack/core/configure.py +6 -6
- llama_stack/core/conversations/conversations.py +28 -25
- llama_stack/core/datatypes.py +271 -79
- llama_stack/core/distribution.py +15 -16
- llama_stack/core/external.py +3 -3
- llama_stack/core/inspect.py +98 -15
- llama_stack/core/library_client.py +73 -61
- llama_stack/core/prompts/prompts.py +12 -11
- llama_stack/core/providers.py +17 -11
- llama_stack/core/resolver.py +65 -56
- llama_stack/core/routers/__init__.py +8 -12
- llama_stack/core/routers/datasets.py +1 -4
- llama_stack/core/routers/eval_scoring.py +7 -4
- llama_stack/core/routers/inference.py +55 -271
- llama_stack/core/routers/safety.py +52 -24
- llama_stack/core/routers/tool_runtime.py +6 -48
- llama_stack/core/routers/vector_io.py +130 -51
- llama_stack/core/routing_tables/benchmarks.py +24 -20
- llama_stack/core/routing_tables/common.py +1 -4
- llama_stack/core/routing_tables/datasets.py +22 -22
- llama_stack/core/routing_tables/models.py +119 -6
- llama_stack/core/routing_tables/scoring_functions.py +7 -7
- llama_stack/core/routing_tables/shields.py +1 -2
- llama_stack/core/routing_tables/toolgroups.py +17 -7
- llama_stack/core/routing_tables/vector_stores.py +51 -16
- llama_stack/core/server/auth.py +5 -3
- llama_stack/core/server/auth_providers.py +36 -20
- llama_stack/core/server/fastapi_router_registry.py +84 -0
- llama_stack/core/server/quota.py +2 -2
- llama_stack/core/server/routes.py +79 -27
- llama_stack/core/server/server.py +102 -87
- llama_stack/core/stack.py +235 -62
- llama_stack/core/storage/datatypes.py +26 -3
- llama_stack/{providers/utils → core/storage}/kvstore/__init__.py +2 -0
- llama_stack/{providers/utils → core/storage}/kvstore/kvstore.py +55 -24
- llama_stack/{providers/utils → core/storage}/kvstore/mongodb/mongodb.py +13 -10
- llama_stack/{providers/utils → core/storage}/kvstore/postgres/postgres.py +28 -17
- llama_stack/{providers/utils → core/storage}/kvstore/redis/redis.py +41 -16
- llama_stack/{providers/utils → core/storage}/kvstore/sqlite/sqlite.py +1 -1
- llama_stack/core/storage/sqlstore/__init__.py +17 -0
- llama_stack/{providers/utils → core/storage}/sqlstore/authorized_sqlstore.py +69 -49
- llama_stack/{providers/utils → core/storage}/sqlstore/sqlalchemy_sqlstore.py +47 -17
- llama_stack/{providers/utils → core/storage}/sqlstore/sqlstore.py +25 -8
- llama_stack/core/store/registry.py +1 -1
- llama_stack/core/utils/config.py +8 -2
- llama_stack/core/utils/config_resolution.py +32 -29
- llama_stack/core/utils/context.py +4 -10
- llama_stack/core/utils/exec.py +9 -0
- llama_stack/core/utils/type_inspection.py +45 -0
- llama_stack/distributions/dell/{run.yaml → config.yaml} +3 -2
- llama_stack/distributions/dell/dell.py +2 -2
- llama_stack/distributions/dell/run-with-safety.yaml +3 -2
- llama_stack/distributions/meta-reference-gpu/{run.yaml → config.yaml} +3 -2
- llama_stack/distributions/meta-reference-gpu/meta_reference.py +2 -2
- llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +3 -2
- llama_stack/distributions/nvidia/{run.yaml → config.yaml} +4 -4
- llama_stack/distributions/nvidia/nvidia.py +1 -1
- llama_stack/distributions/nvidia/run-with-safety.yaml +4 -4
- llama_stack/{apis/datasetio → distributions/oci}/__init__.py +1 -1
- llama_stack/distributions/oci/config.yaml +134 -0
- llama_stack/distributions/oci/oci.py +108 -0
- llama_stack/distributions/open-benchmark/{run.yaml → config.yaml} +5 -4
- llama_stack/distributions/open-benchmark/open_benchmark.py +2 -3
- llama_stack/distributions/postgres-demo/{run.yaml → config.yaml} +4 -3
- llama_stack/distributions/starter/{run.yaml → config.yaml} +64 -13
- llama_stack/distributions/starter/run-with-postgres-store.yaml +64 -13
- llama_stack/distributions/starter/starter.py +8 -5
- llama_stack/distributions/starter-gpu/{run.yaml → config.yaml} +64 -13
- llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +64 -13
- llama_stack/distributions/template.py +13 -69
- llama_stack/distributions/watsonx/{run.yaml → config.yaml} +4 -3
- llama_stack/distributions/watsonx/watsonx.py +1 -1
- llama_stack/log.py +28 -11
- llama_stack/models/llama/checkpoint.py +6 -6
- llama_stack/models/llama/hadamard_utils.py +2 -0
- llama_stack/models/llama/llama3/generation.py +3 -1
- llama_stack/models/llama/llama3/interface.py +2 -5
- llama_stack/models/llama/llama3/multimodal/encoder_utils.py +3 -3
- llama_stack/models/llama/llama3/multimodal/image_transform.py +6 -6
- llama_stack/models/llama/llama3/prompt_templates/system_prompts.py +1 -1
- llama_stack/models/llama/llama3/tool_utils.py +2 -1
- llama_stack/models/llama/llama4/prompt_templates/system_prompts.py +1 -1
- llama_stack/providers/inline/agents/meta_reference/__init__.py +3 -3
- llama_stack/providers/inline/agents/meta_reference/agents.py +44 -261
- llama_stack/providers/inline/agents/meta_reference/config.py +6 -1
- llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +207 -57
- llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +308 -47
- llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py +162 -96
- llama_stack/providers/inline/agents/meta_reference/responses/types.py +23 -8
- llama_stack/providers/inline/agents/meta_reference/responses/utils.py +201 -33
- llama_stack/providers/inline/agents/meta_reference/safety.py +8 -13
- llama_stack/providers/inline/batches/reference/__init__.py +2 -4
- llama_stack/providers/inline/batches/reference/batches.py +78 -60
- llama_stack/providers/inline/datasetio/localfs/datasetio.py +2 -5
- llama_stack/providers/inline/eval/meta_reference/eval.py +16 -61
- llama_stack/providers/inline/files/localfs/files.py +37 -28
- llama_stack/providers/inline/inference/meta_reference/config.py +2 -2
- llama_stack/providers/inline/inference/meta_reference/generators.py +50 -60
- llama_stack/providers/inline/inference/meta_reference/inference.py +403 -19
- llama_stack/providers/inline/inference/meta_reference/model_parallel.py +7 -26
- llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +2 -12
- llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +10 -15
- llama_stack/providers/inline/post_training/common/validator.py +1 -5
- llama_stack/providers/inline/post_training/huggingface/post_training.py +8 -8
- llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py +18 -10
- llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py +12 -9
- llama_stack/providers/inline/post_training/huggingface/utils.py +27 -6
- llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py +1 -1
- llama_stack/providers/inline/post_training/torchtune/common/utils.py +1 -1
- llama_stack/providers/inline/post_training/torchtune/datasets/format_adapter.py +1 -1
- llama_stack/providers/inline/post_training/torchtune/post_training.py +8 -8
- llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +16 -16
- llama_stack/providers/inline/safety/code_scanner/code_scanner.py +13 -9
- llama_stack/providers/inline/safety/llama_guard/llama_guard.py +18 -15
- llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +9 -9
- llama_stack/providers/inline/scoring/basic/scoring.py +6 -13
- llama_stack/providers/inline/scoring/basic/scoring_fn/docvqa_scoring_fn.py +1 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py +1 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/docvqa.py +2 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py +2 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/ifeval.py +2 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_math_response.py +2 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py +2 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py +2 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/ifeval_scoring_fn.py +1 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_math_response_scoring_fn.py +1 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py +1 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py +1 -2
- llama_stack/providers/inline/scoring/braintrust/braintrust.py +12 -15
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_correctness.py +2 -2
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_relevancy.py +2 -2
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_similarity.py +2 -2
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_entity_recall.py +2 -2
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_precision.py +2 -2
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_recall.py +2 -2
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_relevancy.py +2 -2
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/factuality.py +2 -2
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/faithfulness.py +2 -2
- llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +7 -14
- llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_405b_simpleqa.py +2 -2
- llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_base.py +1 -2
- llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py +1 -3
- llama_stack/providers/inline/tool_runtime/rag/__init__.py +1 -1
- llama_stack/providers/inline/tool_runtime/rag/config.py +8 -1
- llama_stack/providers/inline/tool_runtime/rag/context_retriever.py +7 -6
- llama_stack/providers/inline/tool_runtime/rag/memory.py +64 -48
- llama_stack/providers/inline/vector_io/chroma/__init__.py +1 -1
- llama_stack/providers/inline/vector_io/chroma/config.py +1 -1
- llama_stack/providers/inline/vector_io/faiss/__init__.py +1 -1
- llama_stack/providers/inline/vector_io/faiss/config.py +1 -1
- llama_stack/providers/inline/vector_io/faiss/faiss.py +46 -28
- llama_stack/providers/inline/vector_io/milvus/__init__.py +1 -1
- llama_stack/providers/inline/vector_io/milvus/config.py +1 -1
- llama_stack/providers/inline/vector_io/qdrant/__init__.py +1 -1
- llama_stack/providers/inline/vector_io/qdrant/config.py +1 -1
- llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py +1 -1
- llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +44 -33
- llama_stack/providers/registry/agents.py +8 -3
- llama_stack/providers/registry/batches.py +1 -1
- llama_stack/providers/registry/datasetio.py +1 -1
- llama_stack/providers/registry/eval.py +1 -1
- llama_stack/{apis/datasets/__init__.py → providers/registry/file_processors.py} +5 -1
- llama_stack/providers/registry/files.py +11 -2
- llama_stack/providers/registry/inference.py +22 -3
- llama_stack/providers/registry/post_training.py +1 -1
- llama_stack/providers/registry/safety.py +1 -1
- llama_stack/providers/registry/scoring.py +1 -1
- llama_stack/providers/registry/tool_runtime.py +2 -2
- llama_stack/providers/registry/vector_io.py +7 -7
- llama_stack/providers/remote/datasetio/huggingface/huggingface.py +2 -5
- llama_stack/providers/remote/datasetio/nvidia/datasetio.py +1 -4
- llama_stack/providers/remote/eval/nvidia/eval.py +15 -9
- llama_stack/providers/remote/files/openai/__init__.py +19 -0
- llama_stack/providers/remote/files/openai/config.py +28 -0
- llama_stack/providers/remote/files/openai/files.py +253 -0
- llama_stack/providers/remote/files/s3/files.py +52 -30
- llama_stack/providers/remote/inference/anthropic/anthropic.py +2 -1
- llama_stack/providers/remote/inference/anthropic/config.py +1 -1
- llama_stack/providers/remote/inference/azure/azure.py +1 -3
- llama_stack/providers/remote/inference/azure/config.py +8 -7
- llama_stack/providers/remote/inference/bedrock/__init__.py +1 -1
- llama_stack/providers/remote/inference/bedrock/bedrock.py +82 -105
- llama_stack/providers/remote/inference/bedrock/config.py +24 -3
- llama_stack/providers/remote/inference/cerebras/cerebras.py +5 -5
- llama_stack/providers/remote/inference/cerebras/config.py +12 -5
- llama_stack/providers/remote/inference/databricks/config.py +13 -6
- llama_stack/providers/remote/inference/databricks/databricks.py +16 -6
- llama_stack/providers/remote/inference/fireworks/config.py +5 -5
- llama_stack/providers/remote/inference/fireworks/fireworks.py +1 -1
- llama_stack/providers/remote/inference/gemini/config.py +1 -1
- llama_stack/providers/remote/inference/gemini/gemini.py +13 -14
- llama_stack/providers/remote/inference/groq/config.py +5 -5
- llama_stack/providers/remote/inference/groq/groq.py +1 -1
- llama_stack/providers/remote/inference/llama_openai_compat/config.py +5 -5
- llama_stack/providers/remote/inference/llama_openai_compat/llama.py +8 -6
- llama_stack/providers/remote/inference/nvidia/__init__.py +1 -1
- llama_stack/providers/remote/inference/nvidia/config.py +21 -11
- llama_stack/providers/remote/inference/nvidia/nvidia.py +115 -3
- llama_stack/providers/remote/inference/nvidia/utils.py +1 -1
- llama_stack/providers/remote/inference/oci/__init__.py +17 -0
- llama_stack/providers/remote/inference/oci/auth.py +79 -0
- llama_stack/providers/remote/inference/oci/config.py +75 -0
- llama_stack/providers/remote/inference/oci/oci.py +162 -0
- llama_stack/providers/remote/inference/ollama/config.py +7 -5
- llama_stack/providers/remote/inference/ollama/ollama.py +17 -8
- llama_stack/providers/remote/inference/openai/config.py +4 -4
- llama_stack/providers/remote/inference/openai/openai.py +1 -1
- llama_stack/providers/remote/inference/passthrough/__init__.py +2 -2
- llama_stack/providers/remote/inference/passthrough/config.py +5 -10
- llama_stack/providers/remote/inference/passthrough/passthrough.py +97 -75
- llama_stack/providers/remote/inference/runpod/config.py +12 -5
- llama_stack/providers/remote/inference/runpod/runpod.py +2 -20
- llama_stack/providers/remote/inference/sambanova/config.py +5 -5
- llama_stack/providers/remote/inference/sambanova/sambanova.py +1 -1
- llama_stack/providers/remote/inference/tgi/config.py +7 -6
- llama_stack/providers/remote/inference/tgi/tgi.py +19 -11
- llama_stack/providers/remote/inference/together/config.py +5 -5
- llama_stack/providers/remote/inference/together/together.py +15 -12
- llama_stack/providers/remote/inference/vertexai/config.py +1 -1
- llama_stack/providers/remote/inference/vllm/config.py +5 -5
- llama_stack/providers/remote/inference/vllm/vllm.py +13 -14
- llama_stack/providers/remote/inference/watsonx/config.py +4 -4
- llama_stack/providers/remote/inference/watsonx/watsonx.py +21 -94
- llama_stack/providers/remote/post_training/nvidia/post_training.py +4 -4
- llama_stack/providers/remote/post_training/nvidia/utils.py +1 -1
- llama_stack/providers/remote/safety/bedrock/bedrock.py +6 -6
- llama_stack/providers/remote/safety/bedrock/config.py +1 -1
- llama_stack/providers/remote/safety/nvidia/config.py +1 -1
- llama_stack/providers/remote/safety/nvidia/nvidia.py +11 -5
- llama_stack/providers/remote/safety/sambanova/config.py +1 -1
- llama_stack/providers/remote/safety/sambanova/sambanova.py +6 -6
- llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py +11 -6
- llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py +12 -7
- llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py +8 -2
- llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +57 -15
- llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py +11 -6
- llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py +11 -6
- llama_stack/providers/remote/vector_io/chroma/__init__.py +1 -1
- llama_stack/providers/remote/vector_io/chroma/chroma.py +131 -23
- llama_stack/providers/remote/vector_io/chroma/config.py +1 -1
- llama_stack/providers/remote/vector_io/milvus/__init__.py +1 -1
- llama_stack/providers/remote/vector_io/milvus/config.py +1 -1
- llama_stack/providers/remote/vector_io/milvus/milvus.py +37 -28
- llama_stack/providers/remote/vector_io/pgvector/__init__.py +1 -1
- llama_stack/providers/remote/vector_io/pgvector/config.py +1 -1
- llama_stack/providers/remote/vector_io/pgvector/pgvector.py +37 -25
- llama_stack/providers/remote/vector_io/qdrant/__init__.py +1 -1
- llama_stack/providers/remote/vector_io/qdrant/config.py +1 -1
- llama_stack/providers/remote/vector_io/qdrant/qdrant.py +147 -30
- llama_stack/providers/remote/vector_io/weaviate/__init__.py +1 -1
- llama_stack/providers/remote/vector_io/weaviate/config.py +1 -1
- llama_stack/providers/remote/vector_io/weaviate/weaviate.py +31 -26
- llama_stack/providers/utils/common/data_schema_validator.py +1 -5
- llama_stack/providers/utils/files/form_data.py +1 -1
- llama_stack/providers/utils/inference/embedding_mixin.py +1 -1
- llama_stack/providers/utils/inference/inference_store.py +7 -8
- llama_stack/providers/utils/inference/litellm_openai_mixin.py +79 -79
- llama_stack/providers/utils/inference/model_registry.py +1 -3
- llama_stack/providers/utils/inference/openai_compat.py +44 -1171
- llama_stack/providers/utils/inference/openai_mixin.py +68 -42
- llama_stack/providers/utils/inference/prompt_adapter.py +50 -265
- llama_stack/providers/utils/inference/stream_utils.py +23 -0
- llama_stack/providers/utils/memory/__init__.py +2 -0
- llama_stack/providers/utils/memory/file_utils.py +1 -1
- llama_stack/providers/utils/memory/openai_vector_store_mixin.py +181 -84
- llama_stack/providers/utils/memory/vector_store.py +39 -38
- llama_stack/providers/utils/pagination.py +1 -1
- llama_stack/providers/utils/responses/responses_store.py +15 -25
- llama_stack/providers/utils/scoring/aggregation_utils.py +1 -2
- llama_stack/providers/utils/scoring/base_scoring_fn.py +1 -2
- llama_stack/providers/utils/tools/mcp.py +93 -11
- llama_stack/providers/utils/vector_io/__init__.py +16 -0
- llama_stack/providers/utils/vector_io/vector_utils.py +36 -0
- llama_stack/telemetry/constants.py +27 -0
- llama_stack/telemetry/helpers.py +43 -0
- llama_stack/testing/api_recorder.py +25 -16
- {llama_stack-0.3.5.dist-info → llama_stack-0.4.1.dist-info}/METADATA +57 -55
- llama_stack-0.4.1.dist-info/RECORD +588 -0
- llama_stack-0.4.1.dist-info/top_level.txt +2 -0
- llama_stack_api/__init__.py +945 -0
- llama_stack_api/admin/__init__.py +45 -0
- llama_stack_api/admin/api.py +72 -0
- llama_stack_api/admin/fastapi_routes.py +117 -0
- llama_stack_api/admin/models.py +113 -0
- llama_stack_api/agents.py +173 -0
- llama_stack_api/batches/__init__.py +40 -0
- llama_stack_api/batches/api.py +53 -0
- llama_stack_api/batches/fastapi_routes.py +113 -0
- llama_stack_api/batches/models.py +78 -0
- llama_stack_api/benchmarks/__init__.py +43 -0
- llama_stack_api/benchmarks/api.py +39 -0
- llama_stack_api/benchmarks/fastapi_routes.py +109 -0
- llama_stack_api/benchmarks/models.py +109 -0
- {llama_stack/apis → llama_stack_api}/common/content_types.py +1 -43
- {llama_stack/apis → llama_stack_api}/common/errors.py +0 -8
- {llama_stack/apis → llama_stack_api}/common/job_types.py +1 -1
- llama_stack_api/common/responses.py +77 -0
- {llama_stack/apis → llama_stack_api}/common/training_types.py +1 -1
- {llama_stack/apis → llama_stack_api}/common/type_system.py +2 -14
- llama_stack_api/connectors.py +146 -0
- {llama_stack/apis/conversations → llama_stack_api}/conversations.py +23 -39
- {llama_stack/apis/datasetio → llama_stack_api}/datasetio.py +4 -8
- llama_stack_api/datasets/__init__.py +61 -0
- llama_stack_api/datasets/api.py +35 -0
- llama_stack_api/datasets/fastapi_routes.py +104 -0
- llama_stack_api/datasets/models.py +152 -0
- {llama_stack/providers → llama_stack_api}/datatypes.py +166 -10
- {llama_stack/apis/eval → llama_stack_api}/eval.py +8 -40
- llama_stack_api/file_processors/__init__.py +27 -0
- llama_stack_api/file_processors/api.py +64 -0
- llama_stack_api/file_processors/fastapi_routes.py +78 -0
- llama_stack_api/file_processors/models.py +42 -0
- llama_stack_api/files/__init__.py +35 -0
- llama_stack_api/files/api.py +51 -0
- llama_stack_api/files/fastapi_routes.py +124 -0
- llama_stack_api/files/models.py +107 -0
- {llama_stack/apis/inference → llama_stack_api}/inference.py +90 -194
- llama_stack_api/inspect_api/__init__.py +37 -0
- llama_stack_api/inspect_api/api.py +25 -0
- llama_stack_api/inspect_api/fastapi_routes.py +76 -0
- llama_stack_api/inspect_api/models.py +28 -0
- {llama_stack/apis/agents → llama_stack_api/internal}/__init__.py +3 -1
- llama_stack/providers/utils/kvstore/api.py → llama_stack_api/internal/kvstore.py +5 -0
- llama_stack_api/internal/sqlstore.py +79 -0
- {llama_stack/apis/models → llama_stack_api}/models.py +11 -9
- {llama_stack/apis/agents → llama_stack_api}/openai_responses.py +184 -27
- {llama_stack/apis/post_training → llama_stack_api}/post_training.py +7 -11
- {llama_stack/apis/prompts → llama_stack_api}/prompts.py +3 -4
- llama_stack_api/providers/__init__.py +33 -0
- llama_stack_api/providers/api.py +16 -0
- llama_stack_api/providers/fastapi_routes.py +57 -0
- llama_stack_api/providers/models.py +24 -0
- {llama_stack/apis/tools → llama_stack_api}/rag_tool.py +2 -52
- {llama_stack/apis → llama_stack_api}/resource.py +1 -1
- llama_stack_api/router_utils.py +160 -0
- {llama_stack/apis/safety → llama_stack_api}/safety.py +6 -9
- {llama_stack → llama_stack_api}/schema_utils.py +94 -4
- {llama_stack/apis/scoring → llama_stack_api}/scoring.py +3 -3
- {llama_stack/apis/scoring_functions → llama_stack_api}/scoring_functions.py +9 -6
- {llama_stack/apis/shields → llama_stack_api}/shields.py +6 -7
- {llama_stack/apis/tools → llama_stack_api}/tools.py +26 -21
- {llama_stack/apis/vector_io → llama_stack_api}/vector_io.py +133 -152
- {llama_stack/apis/vector_stores → llama_stack_api}/vector_stores.py +1 -1
- llama_stack/apis/agents/agents.py +0 -894
- llama_stack/apis/batches/__init__.py +0 -9
- llama_stack/apis/batches/batches.py +0 -100
- llama_stack/apis/benchmarks/__init__.py +0 -7
- llama_stack/apis/benchmarks/benchmarks.py +0 -108
- llama_stack/apis/common/responses.py +0 -36
- llama_stack/apis/conversations/__init__.py +0 -31
- llama_stack/apis/datasets/datasets.py +0 -251
- llama_stack/apis/datatypes.py +0 -160
- llama_stack/apis/eval/__init__.py +0 -7
- llama_stack/apis/files/__init__.py +0 -7
- llama_stack/apis/files/files.py +0 -199
- llama_stack/apis/inference/__init__.py +0 -7
- llama_stack/apis/inference/event_logger.py +0 -43
- llama_stack/apis/inspect/__init__.py +0 -7
- llama_stack/apis/inspect/inspect.py +0 -94
- llama_stack/apis/models/__init__.py +0 -7
- llama_stack/apis/post_training/__init__.py +0 -7
- llama_stack/apis/prompts/__init__.py +0 -9
- llama_stack/apis/providers/__init__.py +0 -7
- llama_stack/apis/providers/providers.py +0 -69
- llama_stack/apis/safety/__init__.py +0 -7
- llama_stack/apis/scoring/__init__.py +0 -7
- llama_stack/apis/scoring_functions/__init__.py +0 -7
- llama_stack/apis/shields/__init__.py +0 -7
- llama_stack/apis/synthetic_data_generation/__init__.py +0 -7
- llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py +0 -77
- llama_stack/apis/telemetry/__init__.py +0 -7
- llama_stack/apis/telemetry/telemetry.py +0 -423
- llama_stack/apis/tools/__init__.py +0 -8
- llama_stack/apis/vector_io/__init__.py +0 -7
- llama_stack/apis/vector_stores/__init__.py +0 -7
- llama_stack/core/server/tracing.py +0 -80
- llama_stack/core/ui/app.py +0 -55
- llama_stack/core/ui/modules/__init__.py +0 -5
- llama_stack/core/ui/modules/api.py +0 -32
- llama_stack/core/ui/modules/utils.py +0 -42
- llama_stack/core/ui/page/__init__.py +0 -5
- llama_stack/core/ui/page/distribution/__init__.py +0 -5
- llama_stack/core/ui/page/distribution/datasets.py +0 -18
- llama_stack/core/ui/page/distribution/eval_tasks.py +0 -20
- llama_stack/core/ui/page/distribution/models.py +0 -18
- llama_stack/core/ui/page/distribution/providers.py +0 -27
- llama_stack/core/ui/page/distribution/resources.py +0 -48
- llama_stack/core/ui/page/distribution/scoring_functions.py +0 -18
- llama_stack/core/ui/page/distribution/shields.py +0 -19
- llama_stack/core/ui/page/evaluations/__init__.py +0 -5
- llama_stack/core/ui/page/evaluations/app_eval.py +0 -143
- llama_stack/core/ui/page/evaluations/native_eval.py +0 -253
- llama_stack/core/ui/page/playground/__init__.py +0 -5
- llama_stack/core/ui/page/playground/chat.py +0 -130
- llama_stack/core/ui/page/playground/tools.py +0 -352
- llama_stack/distributions/dell/build.yaml +0 -33
- llama_stack/distributions/meta-reference-gpu/build.yaml +0 -32
- llama_stack/distributions/nvidia/build.yaml +0 -29
- llama_stack/distributions/open-benchmark/build.yaml +0 -36
- llama_stack/distributions/postgres-demo/__init__.py +0 -7
- llama_stack/distributions/postgres-demo/build.yaml +0 -23
- llama_stack/distributions/postgres-demo/postgres_demo.py +0 -125
- llama_stack/distributions/starter/build.yaml +0 -61
- llama_stack/distributions/starter-gpu/build.yaml +0 -61
- llama_stack/distributions/watsonx/build.yaml +0 -33
- llama_stack/providers/inline/agents/meta_reference/agent_instance.py +0 -1024
- llama_stack/providers/inline/agents/meta_reference/persistence.py +0 -228
- llama_stack/providers/inline/telemetry/__init__.py +0 -5
- llama_stack/providers/inline/telemetry/meta_reference/__init__.py +0 -21
- llama_stack/providers/inline/telemetry/meta_reference/config.py +0 -47
- llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +0 -252
- llama_stack/providers/remote/inference/bedrock/models.py +0 -29
- llama_stack/providers/utils/kvstore/sqlite/config.py +0 -20
- llama_stack/providers/utils/sqlstore/__init__.py +0 -5
- llama_stack/providers/utils/sqlstore/api.py +0 -128
- llama_stack/providers/utils/telemetry/__init__.py +0 -5
- llama_stack/providers/utils/telemetry/trace_protocol.py +0 -142
- llama_stack/providers/utils/telemetry/tracing.py +0 -384
- llama_stack/strong_typing/__init__.py +0 -19
- llama_stack/strong_typing/auxiliary.py +0 -228
- llama_stack/strong_typing/classdef.py +0 -440
- llama_stack/strong_typing/core.py +0 -46
- llama_stack/strong_typing/deserializer.py +0 -877
- llama_stack/strong_typing/docstring.py +0 -409
- llama_stack/strong_typing/exception.py +0 -23
- llama_stack/strong_typing/inspection.py +0 -1085
- llama_stack/strong_typing/mapping.py +0 -40
- llama_stack/strong_typing/name.py +0 -182
- llama_stack/strong_typing/schema.py +0 -792
- llama_stack/strong_typing/serialization.py +0 -97
- llama_stack/strong_typing/serializer.py +0 -500
- llama_stack/strong_typing/slots.py +0 -27
- llama_stack/strong_typing/topological.py +0 -89
- llama_stack/ui/node_modules/flatted/python/flatted.py +0 -149
- llama_stack-0.3.5.dist-info/RECORD +0 -625
- llama_stack-0.3.5.dist-info/top_level.txt +0 -1
- /llama_stack/{providers/utils → core/storage}/kvstore/config.py +0 -0
- /llama_stack/{providers/utils → core/storage}/kvstore/mongodb/__init__.py +0 -0
- /llama_stack/{providers/utils → core/storage}/kvstore/postgres/__init__.py +0 -0
- /llama_stack/{providers/utils → core/storage}/kvstore/redis/__init__.py +0 -0
- /llama_stack/{providers/utils → core/storage}/kvstore/sqlite/__init__.py +0 -0
- /llama_stack/{apis → providers/inline/file_processor}/__init__.py +0 -0
- /llama_stack/{apis/common → telemetry}/__init__.py +0 -0
- {llama_stack-0.3.5.dist-info → llama_stack-0.4.1.dist-info}/WHEEL +0 -0
- {llama_stack-0.3.5.dist-info → llama_stack-0.4.1.dist-info}/entry_points.txt +0 -0
- {llama_stack-0.3.5.dist-info → llama_stack-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {llama_stack/core/ui → llama_stack_api/common}/__init__.py +0 -0
- {llama_stack/strong_typing → llama_stack_api}/py.typed +0 -0
- {llama_stack/apis → llama_stack_api}/version.py +0 -0
--- llama_stack/providers/inline/agents/meta_reference/responses/streaming.py (0.3.5)
+++ llama_stack/providers/inline/agents/meta_reference/responses/streaming.py (0.4.1)
@@ -8,15 +8,42 @@ import uuid
 from collections.abc import AsyncIterator
 from typing import Any

-from
+from openai.types.chat import ChatCompletionToolParam
+from opentelemetry import trace
+
+from llama_stack.log import get_logger
+from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
+from llama_stack_api import (
     AllowedToolsFilter,
     ApprovalFilter,
+    Inference,
     MCPListToolsTool,
+    ModelNotFoundError,
+    OpenAIAssistantMessageParam,
+    OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAIChatCompletionToolCall,
+    OpenAIChatCompletionToolChoice,
+    OpenAIChatCompletionToolChoiceAllowedTools,
+    OpenAIChatCompletionToolChoiceCustomTool,
+    OpenAIChatCompletionToolChoiceFunctionTool,
+    OpenAIChoice,
+    OpenAIChoiceLogprobs,
+    OpenAIMessageParam,
     OpenAIResponseContentPartOutputText,
     OpenAIResponseContentPartReasoningText,
     OpenAIResponseContentPartRefusal,
     OpenAIResponseError,
     OpenAIResponseInputTool,
+    OpenAIResponseInputToolChoice,
+    OpenAIResponseInputToolChoiceAllowedTools,
+    OpenAIResponseInputToolChoiceCustomTool,
+    OpenAIResponseInputToolChoiceFileSearch,
+    OpenAIResponseInputToolChoiceFunctionTool,
+    OpenAIResponseInputToolChoiceMCPTool,
+    OpenAIResponseInputToolChoiceMode,
+    OpenAIResponseInputToolChoiceWebSearch,
     OpenAIResponseInputToolMCP,
     OpenAIResponseMCPApprovalRequest,
     OpenAIResponseMessage,
@@ -49,34 +76,27 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseOutputMessageMCPCall,
     OpenAIResponseOutputMessageMCPListTools,
     OpenAIResponseOutputMessageWebSearchToolCall,
+    OpenAIResponsePrompt,
     OpenAIResponseText,
     OpenAIResponseUsage,
     OpenAIResponseUsageInputTokensDetails,
     OpenAIResponseUsageOutputTokensDetails,
+    OpenAIToolMessageParam,
+    ResponseItemInclude,
+    Safety,
     WebSearchToolTypes,
 )
-from llama_stack.apis.inference import (
-    Inference,
-    OpenAIAssistantMessageParam,
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAIChatCompletionRequestWithExtraBody,
-    OpenAIChatCompletionToolCall,
-    OpenAIChoice,
-    OpenAIMessageParam,
-)
-from llama_stack.log import get_logger
-from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
-from llama_stack.providers.utils.telemetry import tracing

 from .types import ChatCompletionContext, ChatCompletionResult
 from .utils import (
     convert_chat_choice_to_response_message,
+    convert_mcp_tool_choice,
     is_function_tool_call,
     run_guardrails,
 )

 logger = get_logger(name=__name__, category="agents::meta_reference")
+tracer = trace.get_tracer(__name__)


 def convert_tooldef_to_chat_tool(tool_def):
@@ -110,9 +130,14 @@ class StreamingResponseOrchestrator:
         text: OpenAIResponseText,
         max_infer_iters: int,
         tool_executor,  # Will be the tool execution logic from the main class
-        instructions: str,
-        safety_api,
+        instructions: str | None,
+        safety_api: Safety | None,
         guardrail_ids: list[str] | None = None,
+        prompt: OpenAIResponsePrompt | None = None,
+        parallel_tool_calls: bool | None = None,
+        max_tool_calls: int | None = None,
+        metadata: dict[str, str] | None = None,
+        include: list[ResponseItemInclude] | None = None,
     ):
         self.inference_api = inference_api
         self.ctx = ctx
@@ -123,9 +148,27 @@ class StreamingResponseOrchestrator:
         self.tool_executor = tool_executor
         self.safety_api = safety_api
         self.guardrail_ids = guardrail_ids or []
+        self.prompt = prompt
+        # System message that is inserted into the model's context
+        self.instructions = instructions
+        # Whether to allow more than one function tool call generated per turn.
+        self.parallel_tool_calls = parallel_tool_calls
+        # Max number of total calls to built-in tools that can be processed in a response
+        self.max_tool_calls = max_tool_calls
+        self.metadata = metadata
+        self.include = include
         self.sequence_number = 0
         # Store MCP tool mapping that gets built during tool processing
-        self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] =
+        self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (
+            ctx.tool_context.previous_tools if ctx.tool_context else {}
+        )
+        # Reverse mapping: server_label -> list of tool names for efficient lookup
+        self.server_label_to_tools: dict[str, list[str]] = {}
+        # Build initial reverse mapping from previous_tools
+        for tool_name, mcp_server in self.mcp_tool_to_server.items():
+            if mcp_server.server_label not in self.server_label_to_tools:
+                self.server_label_to_tools[mcp_server.server_label] = []
+            self.server_label_to_tools[mcp_server.server_label].append(tool_name)
         # Track final messages after all tool executions
         self.final_messages: list[OpenAIMessageParam] = []
         # mapping for annotations
@@ -134,8 +177,8 @@ class StreamingResponseOrchestrator:
         self.accumulated_usage: OpenAIResponseUsage | None = None
         # Track if we've sent a refusal response
         self.violation_detected = False
-        #
-        self.
+        # Track total calls made to built-in tools
+        self.accumulated_builtin_tool_calls = 0

     async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
         """Create a refusal response to replace streaming content."""
@@ -148,6 +191,7 @@ class StreamingResponseOrchestrator:
             model=self.ctx.model,
             status="completed",
             output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")],
+            metadata=self.metadata,
         )

         return OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)
@@ -177,9 +221,14 @@ class StreamingResponseOrchestrator:
             output=self._clone_outputs(outputs),
             text=self.text,
             tools=self.ctx.available_tools(),
+            tool_choice=self.ctx.tool_choice,
             error=error,
             usage=self.accumulated_usage,
             instructions=self.instructions,
+            prompt=self.prompt,
+            parallel_tool_calls=self.parallel_tool_calls,
+            max_tool_calls=self.max_tool_calls,
+            metadata=self.metadata,
         )

     async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
@@ -208,6 +257,34 @@ class StreamingResponseOrchestrator:
         async for stream_event in self._process_tools(output_messages):
             yield stream_event

+        chat_tool_choice = None
+        # Track allowed tools for filtering (persists across iterations)
+        allowed_tool_names: set[str] | None = None
+        if self.ctx.tool_choice and len(self.ctx.chat_tools) > 0:
+            processed_tool_choice = await _process_tool_choice(
+                self.ctx.chat_tools,
+                self.ctx.tool_choice,
+                self.server_label_to_tools,
+            )
+            # chat_tool_choice can be str, dict-like object, or None
+            if isinstance(processed_tool_choice, str | type(None)):
+                chat_tool_choice = processed_tool_choice
+            elif isinstance(processed_tool_choice, OpenAIChatCompletionToolChoiceAllowedTools):
+                # For allowed_tools: filter the tools list instead of using tool_choice
+                # This maintains the constraint across all iterations while letting model
+                # decide freely whether to call a tool or respond
+                allowed_tool_names = {
+                    tool["function"]["name"]
+                    for tool in processed_tool_choice.allowed_tools.tools
+                    if tool.get("type") == "function" and "function" in tool
+                }
+                # Use the mode (e.g., "required") for first iteration, then "auto"
+                chat_tool_choice = (
+                    processed_tool_choice.allowed_tools.mode if processed_tool_choice.allowed_tools.mode else "auto"
+                )
+            else:
+                chat_tool_choice = processed_tool_choice.model_dump()
+
         n_iter = 0
         messages = self.ctx.messages.copy()
         final_status = "completed"
@@ -217,19 +294,36 @@ class StreamingResponseOrchestrator:
             while True:
                 # Text is the default response format for chat completion so don't need to pass it
                 # (some providers don't support non-empty response_format when tools are present)
-                response_format =
-
+                response_format = (
+                    None if getattr(self.ctx.response_format, "type", None) == "text" else self.ctx.response_format
+                )
+                # Filter tools to only allowed ones if tool_choice specified an allowed list
+                effective_tools = self.ctx.chat_tools
+                if allowed_tool_names is not None:
+                    effective_tools = [
+                        tool
+                        for tool in self.ctx.chat_tools
+                        if tool.get("function", {}).get("name") in allowed_tool_names
+                    ]
+                logger.debug(f"calling openai_chat_completion with tools: {effective_tools}")
+
+                logprobs = (
+                    True if self.include and ResponseItemInclude.message_output_text_logprobs in self.include else None
+                )

                 params = OpenAIChatCompletionRequestWithExtraBody(
                     model=self.ctx.model,
                     messages=messages,
-
+                    # Pydantic models are dict-compatible but mypy treats them as distinct types
+                    tools=effective_tools,  # type: ignore[arg-type]
+                    tool_choice=chat_tool_choice,
                     stream=True,
                     temperature=self.ctx.temperature,
                     response_format=response_format,
                     stream_options={
                         "include_usage": True,
                     },
+                    logprobs=logprobs,
                 )
                 completion_result = await self.inference_api.openai_chat_completion(params)

@@ -266,7 +360,12 @@ class StreamingResponseOrchestrator:

                # Handle choices with no tool calls
                for choice in current_response.choices:
-
+                    has_tool_calls = (
+                        isinstance(choice.message, OpenAIAssistantMessageParam)
+                        and choice.message.tool_calls
+                        and self.ctx.response_tools
+                    )
+                    if not has_tool_calls:
                        output_messages.append(
                            await convert_chat_choice_to_response_message(
                                choice,
@@ -295,6 +394,14 @@ class StreamingResponseOrchestrator:
                    break

                n_iter += 1
+                # After first iteration, reset tool_choice to "auto" to let model decide freely
+                # based on tool results (prevents infinite loops when forcing specific tools)
+                # Note: When allowed_tool_names is set, tools are already filtered so model
+                # can only call allowed tools - we just need to let it decide whether to call
+                # a tool or respond (hence "auto" mode)
+                if n_iter == 1 and chat_tool_choice and chat_tool_choice != "auto":
+                    chat_tool_choice = "auto"
+
                if n_iter >= self.max_infer_iters:
                    logger.info(
                        f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}"
@@ -305,6 +412,8 @@ class StreamingResponseOrchestrator:
            if last_completion_result and last_completion_result.finish_reason == "length":
                final_status = "incomplete"

+        except ModelNotFoundError:
+            raise
        except Exception as exc:  # noqa: BLE001
            self.final_messages = messages.copy()
            self.sequence_number += 1
@@ -544,6 +653,7 @@ class StreamingResponseOrchestrator:
         chunk_created = 0
         chunk_model = ""
         chunk_finish_reason = ""
+        chat_response_logprobs = []

         # Create a placeholder message item for delta events
         message_item_id = f"msg_{uuid.uuid4()}"
@@ -573,6 +683,12 @@ class StreamingResponseOrchestrator:
            chunk_events: list[OpenAIResponseObjectStream] = []

            for chunk_choice in chunk.choices:
+                # Collect logprobs if present
+                chunk_logprobs = None
+                if chunk_choice.logprobs and chunk_choice.logprobs.content:
+                    chunk_logprobs = chunk_choice.logprobs.content
+                    chat_response_logprobs.extend(chunk_logprobs)
+
                # Emit incremental text content as delta events
                if chunk_choice.delta.content:
                    # Emit output_item.added for the message on first content
@@ -612,6 +728,7 @@ class StreamingResponseOrchestrator:
                            content_index=content_index,
                            delta=chunk_choice.delta.content,
                            item_id=message_item_id,
+                            logprobs=chunk_logprobs,
                            output_index=message_output_index,
                            sequence_number=self.sequence_number,
                        )
@@ -716,7 +833,10 @@ class StreamingResponseOrchestrator:
                        )

                        # Accumulate arguments for final response (only for subsequent chunks)
-                        if not is_new_tool_call:
+                        if not is_new_tool_call and response_tool_call is not None:
+                            # Both should have functions since we're inside the tool_call.function check above
+                            assert response_tool_call.function is not None
+                            assert tool_call.function is not None
                            response_tool_call.function.arguments = (
                                response_tool_call.function.arguments or ""
                            ) + tool_call.function.arguments
@@ -741,10 +861,13 @@ class StreamingResponseOrchestrator:
        for tool_call_index in sorted(chat_response_tool_calls.keys()):
            tool_call = chat_response_tool_calls[tool_call_index]
            # Ensure that arguments, if sent back to the inference provider, are not None
-
+            if tool_call.function:
+                tool_call.function.arguments = tool_call.function.arguments or "{}"
            tool_call_item_id = tool_call_item_ids[tool_call_index]
-            final_arguments = tool_call.function.arguments
-
+            final_arguments: str = tool_call.function.arguments or "{}" if tool_call.function else "{}"
+            func = chat_response_tool_calls[tool_call_index].function
+
+            tool_call_name = func.name if func else ""

            # Check if this is an MCP tool call
            is_mcp_tool = tool_call_name and tool_call_name in self.mcp_tool_to_server
@@ -809,6 +932,7 @@ class StreamingResponseOrchestrator:
                    OpenAIResponseOutputMessageContentOutputText(
                        text=final_text,
                        annotations=[],
+                        logprobs=chat_response_logprobs if chat_response_logprobs else None,
                    )
                )

@@ -836,6 +960,7 @@ class StreamingResponseOrchestrator:
            message_item_id=message_item_id,
            tool_call_item_ids=tool_call_item_ids,
            content_part_emitted=content_part_emitted,
+            logprobs=OpenAIChoiceLogprobs(content=chat_response_logprobs) if chat_response_logprobs else None,
        )

    def _build_chat_completion(self, result: ChatCompletionResult) -> OpenAIChatCompletion:
@@ -857,6 +982,7 @@ class StreamingResponseOrchestrator:
                    message=assistant_message,
                    finish_reason=result.finish_reason,
                    index=0,
+                    logprobs=result.logprobs,
                )
            ],
            created=result.created,
@@ -874,6 +1000,17 @@ class StreamingResponseOrchestrator:
        """Coordinate execution of both function and non-function tool calls."""
        # Execute non-function tool calls
        for tool_call in non_function_tool_calls:
+            # if total calls made to built-in and mcp tools exceed max_tool_calls
+            # then create a tool response message indicating the call was skipped
+            if self.max_tool_calls is not None and self.accumulated_builtin_tool_calls >= self.max_tool_calls:
+                logger.info(f"Ignoring built-in and mcp tool call since reached the limit of {self.max_tool_calls=}.")
+                skipped_call_message = OpenAIToolMessageParam(
+                    content=f"Tool call skipped: maximum tool calls limit ({self.max_tool_calls}) reached.",
+                    tool_call_id=tool_call.id,
+                )
+                next_turn_messages.append(skipped_call_message)
+                continue
+
            # Find the item_id for this tool call
            matching_item_id = None
            for index, item_id in completion_result_data.tool_call_item_ids.items():
@@ -888,12 +1025,11 @@ class StreamingResponseOrchestrator:

            self.sequence_number += 1
            if tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server:
-                item = OpenAIResponseOutputMessageMCPCall(
+                item: OpenAIResponseOutput = OpenAIResponseOutputMessageMCPCall(
                    arguments="",
                    name=tool_call.function.name,
                    id=matching_item_id,
                    server_label=self.mcp_tool_to_server[tool_call.function.name].server_label,
-                    status="in_progress",
                )
            elif tool_call.function.name == "web_search":
                item = OpenAIResponseOutputMessageWebSearchToolCall(
@@ -955,6 +1091,9 @@ class StreamingResponseOrchestrator:
            if tool_response_message:
                next_turn_messages.append(tool_response_message)

+            # Track number of calls made to built-in and mcp tools
+            self.accumulated_builtin_tool_calls += 1
+
        # Execute function tool calls (client-side)
        for tool_call in function_tool_calls:
            # Find the item_id for this tool call from our tracking dictionary
@@ -992,9 +1131,9 @@ class StreamingResponseOrchestrator:
        """Process all tools and emit appropriate streaming events."""
        from openai.types.chat import ChatCompletionToolParam

-        from llama_stack.apis.tools import ToolDef
        from llama_stack.models.llama.datatypes import ToolDefinition
        from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
+        from llama_stack_api import ToolDef

        def make_openai_tool(tool_name: str, tool: ToolDef) -> ChatCompletionToolParam:
            tool_def = ToolDefinition(
@@ -1002,7 +1141,7 @@ class StreamingResponseOrchestrator:
                description=tool.description,
                input_schema=tool.input_schema,
            )
-            return convert_tooldef_to_openai_tool(tool_def)
+            return convert_tooldef_to_openai_tool(tool_def)  # type: ignore[return-value] # Returns dict but ChatCompletionToolParam expects TypedDict

        # Initialize chat_tools if not already set
        if self.ctx.chat_tools is None:
@@ -1010,7 +1149,7 @@ class StreamingResponseOrchestrator:

        for input_tool in tools:
            if input_tool.type == "function":
-                self.ctx.chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
+                self.ctx.chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))  # type: ignore[typeddict-item,arg-type] # Dict compatible with FunctionDefinition
            elif input_tool.type in WebSearchToolTypes:
                tool_name = "web_search"
                # Need to access tool_groups_api from tool_executor
@@ -1049,8 +1188,8 @@ class StreamingResponseOrchestrator:
        if isinstance(mcp_tool.allowed_tools, list):
            always_allowed = mcp_tool.allowed_tools
        elif isinstance(mcp_tool.allowed_tools, AllowedToolsFilter):
-
-
+            # AllowedToolsFilter only has tool_names field (not allowed/disallowed)
+            always_allowed = mcp_tool.allowed_tools.tool_names

        # Call list_mcp_tools
        tool_defs = None
@@ -1060,10 +1199,14 @@ class StreamingResponseOrchestrator:
            "server_url": mcp_tool.server_url,
            "mcp_list_tools_id": list_id,
        }
-
+
+        # TODO: follow semantic conventions for Open Telemetry tool spans
+        # https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span
+        with tracer.start_as_current_span("list_mcp_tools", attributes=attributes):
            tool_defs = await list_mcp_tools(
                endpoint=mcp_tool.server_url,
-                headers=mcp_tool.headers
+                headers=mcp_tool.headers,
+                authorization=mcp_tool.authorization,
            )

        # Create the MCP list tools message
@@ -1082,13 +1225,18 @@ class StreamingResponseOrchestrator:
            openai_tool = convert_tooldef_to_chat_tool(t)
            if self.ctx.chat_tools is None:
                self.ctx.chat_tools = []
-            self.ctx.chat_tools.append(openai_tool)
+            self.ctx.chat_tools.append(openai_tool)  # type: ignore[arg-type] # Returns dict but ChatCompletionToolParam expects TypedDict

            # Add to MCP tool mapping
            if t.name in self.mcp_tool_to_server:
                raise ValueError(f"Duplicate tool name {t.name} found for server {mcp_tool.server_label}")
            self.mcp_tool_to_server[t.name] = mcp_tool

+            # Add to reverse mapping for efficient server_label lookup
+            if mcp_tool.server_label not in self.server_label_to_tools:
+                self.server_label_to_tools[mcp_tool.server_label] = []
+            self.server_label_to_tools[mcp_tool.server_label].append(t.name)
+
            # Add to MCP list message
            mcp_list_message.tools.append(
                MCPListToolsTool(
@@ -1114,13 +1262,17 @@ class StreamingResponseOrchestrator:
        self, output_messages: list[OpenAIResponseOutput]
    ) -> AsyncIterator[OpenAIResponseObjectStream]:
        # Handle all mcp tool lists from previous response that are still valid:
-
-
-
-
-
-
-
+        # tool_context can be None when no tools are provided in the response request
+        if self.ctx.tool_context:
+            for tool in self.ctx.tool_context.previous_tool_listings:
+                async for evt in self._reuse_mcp_list_tools(tool, output_messages):
+                    yield evt
+            # Process all remaining tools (including MCP tools) and emit streaming events
+            if self.ctx.tool_context.tools_to_process:
+                async for stream_event in self._process_new_tools(
+                    self.ctx.tool_context.tools_to_process, output_messages
+                ):
+                    yield stream_event

    def _approval_required(self, tool_name: str) -> bool:
        if tool_name not in self.mcp_tool_to_server:
@@ -1131,9 +1283,9 @@ class StreamingResponseOrchestrator:
        if mcp_server.require_approval == "never":
            return False
        if isinstance(mcp_server, ApprovalFilter):
-            if tool_name in mcp_server.always:
+            if mcp_server.always and tool_name in mcp_server.always:
                return True
-            if tool_name in mcp_server.never:
+            if mcp_server.never and tool_name in mcp_server.never:
                return False
        return True

@@ -1214,7 +1366,7 @@ class StreamingResponseOrchestrator:
            openai_tool = convert_tooldef_to_openai_tool(tool_def)
            if self.ctx.chat_tools is None:
                self.ctx.chat_tools = []
-            self.ctx.chat_tools.append(openai_tool)
+            self.ctx.chat_tools.append(openai_tool)  # type: ignore[arg-type] # Returns dict but ChatCompletionToolParam expects TypedDict

        mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
            id=f"mcp_list_{uuid.uuid4()}",
@@ -1224,3 +1376,112 @@ class StreamingResponseOrchestrator:

        async for stream_event in self._add_mcp_list_tools(mcp_list_message, output_messages):
            yield stream_event
+
+
+async def _process_tool_choice(
+    chat_tools: list[ChatCompletionToolParam],
+    tool_choice: OpenAIResponseInputToolChoice,
+    server_label_to_tools: dict[str, list[str]],
+) -> str | OpenAIChatCompletionToolChoice | None:
+    """Process and validate the OpenAI Responses tool choice and return the appropriate chat completion tool choice object.
+
+    :param chat_tools: The list of chat tools to enforce tool choice against.
+    :param tool_choice: The OpenAI Responses tool choice to process.
+    :param server_label_to_tools: A dictionary mapping server labels to the list of tools available on that server.
+    :return: The appropriate chat completion tool choice object.
+    """
+
+    # retrieve all function tool names from the chat tools
+    # Note: chat_tools contains dicts, not objects
+    chat_tool_names = [tool["function"]["name"] for tool in chat_tools if tool["type"] == "function"]
+
+    if isinstance(tool_choice, OpenAIResponseInputToolChoiceMode):
+        if tool_choice.value == "required":
+            if len(chat_tool_names) == 0:
+                return None
+
+            # add all function tools to the allowed tools list and set mode to required
+            return OpenAIChatCompletionToolChoiceAllowedTools(
+                tools=[{"type": "function", "function": {"name": tool}} for tool in chat_tool_names],
+                mode="required",
+            )
+        # return other modes as is
+        return tool_choice.value
+
+    elif isinstance(tool_choice, OpenAIResponseInputToolChoiceAllowedTools):
+        # ensure that specified tool choices are available in the chat tools, if not, remove them from the list
+        final_tools = []
+        for tool in tool_choice.tools:
+            match tool.get("type"):
+                case "function":
+                    final_tools.append({"type": "function", "function": {"name": tool.get("name")}})
+                case "custom":
+                    final_tools.append({"type": "custom", "custom": {"name": tool.get("name")}})
+                case "mcp":
+                    mcp_tools = convert_mcp_tool_choice(
+                        chat_tool_names, tool.get("server_label"), server_label_to_tools, None
+                    )
+                    # convert_mcp_tool_choice can return a dict, list, or None
+                    if isinstance(mcp_tools, list):
+                        final_tools.extend(mcp_tools)
+                    elif isinstance(mcp_tools, dict):
+                        final_tools.append(mcp_tools)
+                    # Skip if None or empty
+                case "file_search":
+                    final_tools.append({"type": "function", "function": {"name": "file_search"}})
+                case _ if tool["type"] in WebSearchToolTypes:
+                    final_tools.append({"type": "function", "function": {"name": "web_search"}})
+                case _:
+                    logger.warning(f"Unsupported tool type: {tool['type']}, skipping tool choice enforcement for it")
+                    continue
+
+        return OpenAIChatCompletionToolChoiceAllowedTools(
+            tools=final_tools,
+            mode=tool_choice.mode,
+        )
+
+    else:
+        # Handle specific tool choice by type
+        # Each case validates the tool exists in chat_tools before returning
+        tool_name = getattr(tool_choice, "name", None)
+        match tool_choice:
+            case OpenAIResponseInputToolChoiceCustomTool():
+                if tool_name and tool_name not in chat_tool_names:
+                    logger.warning(f"Tool {tool_name} not found in chat tools")
+                    return None
+                return OpenAIChatCompletionToolChoiceCustomTool(name=tool_name)
+
+            case OpenAIResponseInputToolChoiceFunctionTool():
+                if tool_name and tool_name not in chat_tool_names:
+                    logger.warning(f"Tool {tool_name} not found in chat tools")
+                    return None
+                return OpenAIChatCompletionToolChoiceFunctionTool(name=tool_name)
+
+            case OpenAIResponseInputToolChoiceFileSearch():
+                if "file_search" not in chat_tool_names:
+                    logger.warning("Tool file_search not found in chat tools")
+                    return None
+                return OpenAIChatCompletionToolChoiceFunctionTool(name="file_search")
+
+            case OpenAIResponseInputToolChoiceWebSearch():
+                if "web_search" not in chat_tool_names:
+                    logger.warning("Tool web_search not found in chat tools")
+                    return None
+                return OpenAIChatCompletionToolChoiceFunctionTool(name="web_search")
+
+            case OpenAIResponseInputToolChoiceMCPTool():
+                tool_choice = convert_mcp_tool_choice(
+                    chat_tool_names,
+                    tool_choice.server_label,
+                    server_label_to_tools,
+                    tool_name,
+                )
+                if isinstance(tool_choice, dict):
+                    # for single tool choice, return as function tool choice
+                    return OpenAIChatCompletionToolChoiceFunctionTool(name=tool_choice["function"]["name"])
+                elif isinstance(tool_choice, list):
+                    # for multiple tool choices, return as allowed tools
+                    return OpenAIChatCompletionToolChoiceAllowedTools(
+                        tools=tool_choice,
+                        mode="required",
+                    )