llama-stack 0.3.5__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in a supported public registry. It is provided for informational purposes only.
- llama_stack/__init__.py +0 -5
- llama_stack/cli/llama.py +3 -3
- llama_stack/cli/stack/_list_deps.py +12 -23
- llama_stack/cli/stack/list_stacks.py +37 -18
- llama_stack/cli/stack/run.py +121 -11
- llama_stack/cli/stack/utils.py +0 -127
- llama_stack/core/access_control/access_control.py +69 -28
- llama_stack/core/access_control/conditions.py +15 -5
- llama_stack/core/admin.py +267 -0
- llama_stack/core/build.py +6 -74
- llama_stack/core/client.py +1 -1
- llama_stack/core/configure.py +6 -6
- llama_stack/core/conversations/conversations.py +28 -25
- llama_stack/core/datatypes.py +271 -79
- llama_stack/core/distribution.py +15 -16
- llama_stack/core/external.py +3 -3
- llama_stack/core/inspect.py +98 -15
- llama_stack/core/library_client.py +73 -61
- llama_stack/core/prompts/prompts.py +12 -11
- llama_stack/core/providers.py +17 -11
- llama_stack/core/resolver.py +65 -56
- llama_stack/core/routers/__init__.py +8 -12
- llama_stack/core/routers/datasets.py +1 -4
- llama_stack/core/routers/eval_scoring.py +7 -4
- llama_stack/core/routers/inference.py +55 -271
- llama_stack/core/routers/safety.py +52 -24
- llama_stack/core/routers/tool_runtime.py +6 -48
- llama_stack/core/routers/vector_io.py +130 -51
- llama_stack/core/routing_tables/benchmarks.py +24 -20
- llama_stack/core/routing_tables/common.py +1 -4
- llama_stack/core/routing_tables/datasets.py +22 -22
- llama_stack/core/routing_tables/models.py +119 -6
- llama_stack/core/routing_tables/scoring_functions.py +7 -7
- llama_stack/core/routing_tables/shields.py +1 -2
- llama_stack/core/routing_tables/toolgroups.py +17 -7
- llama_stack/core/routing_tables/vector_stores.py +51 -16
- llama_stack/core/server/auth.py +5 -3
- llama_stack/core/server/auth_providers.py +36 -20
- llama_stack/core/server/fastapi_router_registry.py +84 -0
- llama_stack/core/server/quota.py +2 -2
- llama_stack/core/server/routes.py +79 -27
- llama_stack/core/server/server.py +102 -87
- llama_stack/core/stack.py +201 -58
- llama_stack/core/storage/datatypes.py +26 -3
- llama_stack/{providers/utils → core/storage}/kvstore/__init__.py +2 -0
- llama_stack/{providers/utils → core/storage}/kvstore/kvstore.py +55 -24
- llama_stack/{providers/utils → core/storage}/kvstore/mongodb/mongodb.py +13 -10
- llama_stack/{providers/utils → core/storage}/kvstore/postgres/postgres.py +28 -17
- llama_stack/{providers/utils → core/storage}/kvstore/redis/redis.py +41 -16
- llama_stack/{providers/utils → core/storage}/kvstore/sqlite/sqlite.py +1 -1
- llama_stack/core/storage/sqlstore/__init__.py +17 -0
- llama_stack/{providers/utils → core/storage}/sqlstore/authorized_sqlstore.py +69 -49
- llama_stack/{providers/utils → core/storage}/sqlstore/sqlalchemy_sqlstore.py +47 -17
- llama_stack/{providers/utils → core/storage}/sqlstore/sqlstore.py +25 -8
- llama_stack/core/store/registry.py +1 -1
- llama_stack/core/utils/config.py +8 -2
- llama_stack/core/utils/config_resolution.py +32 -29
- llama_stack/core/utils/context.py +4 -10
- llama_stack/core/utils/exec.py +9 -0
- llama_stack/core/utils/type_inspection.py +45 -0
- llama_stack/distributions/dell/{run.yaml → config.yaml} +3 -2
- llama_stack/distributions/dell/dell.py +2 -2
- llama_stack/distributions/dell/run-with-safety.yaml +3 -2
- llama_stack/distributions/meta-reference-gpu/{run.yaml → config.yaml} +3 -2
- llama_stack/distributions/meta-reference-gpu/meta_reference.py +2 -2
- llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +3 -2
- llama_stack/distributions/nvidia/{run.yaml → config.yaml} +4 -4
- llama_stack/distributions/nvidia/nvidia.py +1 -1
- llama_stack/distributions/nvidia/run-with-safety.yaml +4 -4
- llama_stack/{apis/datasetio → distributions/oci}/__init__.py +1 -1
- llama_stack/distributions/oci/config.yaml +134 -0
- llama_stack/distributions/oci/oci.py +108 -0
- llama_stack/distributions/open-benchmark/{run.yaml → config.yaml} +5 -4
- llama_stack/distributions/open-benchmark/open_benchmark.py +2 -3
- llama_stack/distributions/postgres-demo/{run.yaml → config.yaml} +4 -3
- llama_stack/distributions/starter/{run.yaml → config.yaml} +64 -13
- llama_stack/distributions/starter/run-with-postgres-store.yaml +64 -13
- llama_stack/distributions/starter/starter.py +8 -5
- llama_stack/distributions/starter-gpu/{run.yaml → config.yaml} +64 -13
- llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +64 -13
- llama_stack/distributions/template.py +13 -69
- llama_stack/distributions/watsonx/{run.yaml → config.yaml} +4 -3
- llama_stack/distributions/watsonx/watsonx.py +1 -1
- llama_stack/log.py +28 -11
- llama_stack/models/llama/checkpoint.py +6 -6
- llama_stack/models/llama/hadamard_utils.py +2 -0
- llama_stack/models/llama/llama3/generation.py +3 -1
- llama_stack/models/llama/llama3/interface.py +2 -5
- llama_stack/models/llama/llama3/multimodal/encoder_utils.py +3 -3
- llama_stack/models/llama/llama3/multimodal/image_transform.py +6 -6
- llama_stack/models/llama/llama3/prompt_templates/system_prompts.py +1 -1
- llama_stack/models/llama/llama3/tool_utils.py +2 -1
- llama_stack/models/llama/llama4/prompt_templates/system_prompts.py +1 -1
- llama_stack/providers/inline/agents/meta_reference/__init__.py +3 -3
- llama_stack/providers/inline/agents/meta_reference/agents.py +44 -261
- llama_stack/providers/inline/agents/meta_reference/config.py +6 -1
- llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +207 -57
- llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +308 -47
- llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py +162 -96
- llama_stack/providers/inline/agents/meta_reference/responses/types.py +23 -8
- llama_stack/providers/inline/agents/meta_reference/responses/utils.py +201 -33
- llama_stack/providers/inline/agents/meta_reference/safety.py +8 -13
- llama_stack/providers/inline/batches/reference/__init__.py +2 -4
- llama_stack/providers/inline/batches/reference/batches.py +78 -60
- llama_stack/providers/inline/datasetio/localfs/datasetio.py +2 -5
- llama_stack/providers/inline/eval/meta_reference/eval.py +16 -61
- llama_stack/providers/inline/files/localfs/files.py +37 -28
- llama_stack/providers/inline/inference/meta_reference/config.py +2 -2
- llama_stack/providers/inline/inference/meta_reference/generators.py +50 -60
- llama_stack/providers/inline/inference/meta_reference/inference.py +403 -19
- llama_stack/providers/inline/inference/meta_reference/model_parallel.py +7 -26
- llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +2 -12
- llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +10 -15
- llama_stack/providers/inline/post_training/common/validator.py +1 -5
- llama_stack/providers/inline/post_training/huggingface/post_training.py +8 -8
- llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py +18 -10
- llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py +12 -9
- llama_stack/providers/inline/post_training/huggingface/utils.py +27 -6
- llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py +1 -1
- llama_stack/providers/inline/post_training/torchtune/common/utils.py +1 -1
- llama_stack/providers/inline/post_training/torchtune/datasets/format_adapter.py +1 -1
- llama_stack/providers/inline/post_training/torchtune/post_training.py +8 -8
- llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +16 -16
- llama_stack/providers/inline/safety/code_scanner/code_scanner.py +13 -9
- llama_stack/providers/inline/safety/llama_guard/llama_guard.py +18 -15
- llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +9 -9
- llama_stack/providers/inline/scoring/basic/scoring.py +6 -13
- llama_stack/providers/inline/scoring/basic/scoring_fn/docvqa_scoring_fn.py +1 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py +1 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/docvqa.py +2 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py +2 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/ifeval.py +2 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_math_response.py +2 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py +2 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py +2 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/ifeval_scoring_fn.py +1 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_math_response_scoring_fn.py +1 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py +1 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py +1 -2
- llama_stack/providers/inline/scoring/braintrust/braintrust.py +12 -15
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_correctness.py +2 -2
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_relevancy.py +2 -2
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_similarity.py +2 -2
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_entity_recall.py +2 -2
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_precision.py +2 -2
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_recall.py +2 -2
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_relevancy.py +2 -2
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/factuality.py +2 -2
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/faithfulness.py +2 -2
- llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +7 -14
- llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_405b_simpleqa.py +2 -2
- llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_base.py +1 -2
- llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py +1 -3
- llama_stack/providers/inline/tool_runtime/rag/__init__.py +1 -1
- llama_stack/providers/inline/tool_runtime/rag/config.py +8 -1
- llama_stack/providers/inline/tool_runtime/rag/context_retriever.py +7 -6
- llama_stack/providers/inline/tool_runtime/rag/memory.py +64 -48
- llama_stack/providers/inline/vector_io/chroma/__init__.py +1 -1
- llama_stack/providers/inline/vector_io/chroma/config.py +1 -1
- llama_stack/providers/inline/vector_io/faiss/__init__.py +1 -1
- llama_stack/providers/inline/vector_io/faiss/config.py +1 -1
- llama_stack/providers/inline/vector_io/faiss/faiss.py +43 -28
- llama_stack/providers/inline/vector_io/milvus/__init__.py +1 -1
- llama_stack/providers/inline/vector_io/milvus/config.py +1 -1
- llama_stack/providers/inline/vector_io/qdrant/__init__.py +1 -1
- llama_stack/providers/inline/vector_io/qdrant/config.py +1 -1
- llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py +1 -1
- llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +40 -33
- llama_stack/providers/registry/agents.py +7 -3
- llama_stack/providers/registry/batches.py +1 -1
- llama_stack/providers/registry/datasetio.py +1 -1
- llama_stack/providers/registry/eval.py +1 -1
- llama_stack/{apis/datasets/__init__.py → providers/registry/file_processors.py} +5 -1
- llama_stack/providers/registry/files.py +11 -2
- llama_stack/providers/registry/inference.py +22 -3
- llama_stack/providers/registry/post_training.py +1 -1
- llama_stack/providers/registry/safety.py +1 -1
- llama_stack/providers/registry/scoring.py +1 -1
- llama_stack/providers/registry/tool_runtime.py +2 -2
- llama_stack/providers/registry/vector_io.py +7 -7
- llama_stack/providers/remote/datasetio/huggingface/huggingface.py +2 -5
- llama_stack/providers/remote/datasetio/nvidia/datasetio.py +1 -4
- llama_stack/providers/remote/eval/nvidia/eval.py +15 -9
- llama_stack/providers/remote/files/openai/__init__.py +19 -0
- llama_stack/providers/remote/files/openai/config.py +28 -0
- llama_stack/providers/remote/files/openai/files.py +253 -0
- llama_stack/providers/remote/files/s3/files.py +52 -30
- llama_stack/providers/remote/inference/anthropic/anthropic.py +2 -1
- llama_stack/providers/remote/inference/anthropic/config.py +1 -1
- llama_stack/providers/remote/inference/azure/azure.py +1 -3
- llama_stack/providers/remote/inference/azure/config.py +8 -7
- llama_stack/providers/remote/inference/bedrock/__init__.py +1 -1
- llama_stack/providers/remote/inference/bedrock/bedrock.py +82 -105
- llama_stack/providers/remote/inference/bedrock/config.py +24 -3
- llama_stack/providers/remote/inference/cerebras/cerebras.py +5 -5
- llama_stack/providers/remote/inference/cerebras/config.py +12 -5
- llama_stack/providers/remote/inference/databricks/config.py +13 -6
- llama_stack/providers/remote/inference/databricks/databricks.py +16 -6
- llama_stack/providers/remote/inference/fireworks/config.py +5 -5
- llama_stack/providers/remote/inference/fireworks/fireworks.py +1 -1
- llama_stack/providers/remote/inference/gemini/config.py +1 -1
- llama_stack/providers/remote/inference/gemini/gemini.py +13 -14
- llama_stack/providers/remote/inference/groq/config.py +5 -5
- llama_stack/providers/remote/inference/groq/groq.py +1 -1
- llama_stack/providers/remote/inference/llama_openai_compat/config.py +5 -5
- llama_stack/providers/remote/inference/llama_openai_compat/llama.py +8 -6
- llama_stack/providers/remote/inference/nvidia/__init__.py +1 -1
- llama_stack/providers/remote/inference/nvidia/config.py +21 -11
- llama_stack/providers/remote/inference/nvidia/nvidia.py +115 -3
- llama_stack/providers/remote/inference/nvidia/utils.py +1 -1
- llama_stack/providers/remote/inference/oci/__init__.py +17 -0
- llama_stack/providers/remote/inference/oci/auth.py +79 -0
- llama_stack/providers/remote/inference/oci/config.py +75 -0
- llama_stack/providers/remote/inference/oci/oci.py +162 -0
- llama_stack/providers/remote/inference/ollama/config.py +7 -5
- llama_stack/providers/remote/inference/ollama/ollama.py +17 -8
- llama_stack/providers/remote/inference/openai/config.py +4 -4
- llama_stack/providers/remote/inference/openai/openai.py +1 -1
- llama_stack/providers/remote/inference/passthrough/__init__.py +2 -2
- llama_stack/providers/remote/inference/passthrough/config.py +5 -10
- llama_stack/providers/remote/inference/passthrough/passthrough.py +97 -75
- llama_stack/providers/remote/inference/runpod/config.py +12 -5
- llama_stack/providers/remote/inference/runpod/runpod.py +2 -20
- llama_stack/providers/remote/inference/sambanova/config.py +5 -5
- llama_stack/providers/remote/inference/sambanova/sambanova.py +1 -1
- llama_stack/providers/remote/inference/tgi/config.py +7 -6
- llama_stack/providers/remote/inference/tgi/tgi.py +19 -11
- llama_stack/providers/remote/inference/together/config.py +5 -5
- llama_stack/providers/remote/inference/together/together.py +15 -12
- llama_stack/providers/remote/inference/vertexai/config.py +1 -1
- llama_stack/providers/remote/inference/vllm/config.py +5 -5
- llama_stack/providers/remote/inference/vllm/vllm.py +13 -14
- llama_stack/providers/remote/inference/watsonx/config.py +4 -4
- llama_stack/providers/remote/inference/watsonx/watsonx.py +21 -94
- llama_stack/providers/remote/post_training/nvidia/post_training.py +4 -4
- llama_stack/providers/remote/post_training/nvidia/utils.py +1 -1
- llama_stack/providers/remote/safety/bedrock/bedrock.py +6 -6
- llama_stack/providers/remote/safety/bedrock/config.py +1 -1
- llama_stack/providers/remote/safety/nvidia/config.py +1 -1
- llama_stack/providers/remote/safety/nvidia/nvidia.py +11 -5
- llama_stack/providers/remote/safety/sambanova/config.py +1 -1
- llama_stack/providers/remote/safety/sambanova/sambanova.py +6 -6
- llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py +11 -6
- llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py +12 -7
- llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py +8 -2
- llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +57 -15
- llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py +11 -6
- llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py +11 -6
- llama_stack/providers/remote/vector_io/chroma/__init__.py +1 -1
- llama_stack/providers/remote/vector_io/chroma/chroma.py +125 -20
- llama_stack/providers/remote/vector_io/chroma/config.py +1 -1
- llama_stack/providers/remote/vector_io/milvus/__init__.py +1 -1
- llama_stack/providers/remote/vector_io/milvus/config.py +1 -1
- llama_stack/providers/remote/vector_io/milvus/milvus.py +27 -21
- llama_stack/providers/remote/vector_io/pgvector/__init__.py +1 -1
- llama_stack/providers/remote/vector_io/pgvector/config.py +1 -1
- llama_stack/providers/remote/vector_io/pgvector/pgvector.py +26 -18
- llama_stack/providers/remote/vector_io/qdrant/__init__.py +1 -1
- llama_stack/providers/remote/vector_io/qdrant/config.py +1 -1
- llama_stack/providers/remote/vector_io/qdrant/qdrant.py +141 -24
- llama_stack/providers/remote/vector_io/weaviate/__init__.py +1 -1
- llama_stack/providers/remote/vector_io/weaviate/config.py +1 -1
- llama_stack/providers/remote/vector_io/weaviate/weaviate.py +26 -21
- llama_stack/providers/utils/common/data_schema_validator.py +1 -5
- llama_stack/providers/utils/files/form_data.py +1 -1
- llama_stack/providers/utils/inference/embedding_mixin.py +1 -1
- llama_stack/providers/utils/inference/inference_store.py +7 -8
- llama_stack/providers/utils/inference/litellm_openai_mixin.py +79 -79
- llama_stack/providers/utils/inference/model_registry.py +1 -3
- llama_stack/providers/utils/inference/openai_compat.py +44 -1171
- llama_stack/providers/utils/inference/openai_mixin.py +68 -42
- llama_stack/providers/utils/inference/prompt_adapter.py +50 -265
- llama_stack/providers/utils/inference/stream_utils.py +23 -0
- llama_stack/providers/utils/memory/__init__.py +2 -0
- llama_stack/providers/utils/memory/file_utils.py +1 -1
- llama_stack/providers/utils/memory/openai_vector_store_mixin.py +181 -84
- llama_stack/providers/utils/memory/vector_store.py +39 -38
- llama_stack/providers/utils/pagination.py +1 -1
- llama_stack/providers/utils/responses/responses_store.py +15 -25
- llama_stack/providers/utils/scoring/aggregation_utils.py +1 -2
- llama_stack/providers/utils/scoring/base_scoring_fn.py +1 -2
- llama_stack/providers/utils/tools/mcp.py +93 -11
- llama_stack/telemetry/constants.py +27 -0
- llama_stack/telemetry/helpers.py +43 -0
- llama_stack/testing/api_recorder.py +25 -16
- {llama_stack-0.3.5.dist-info → llama_stack-0.4.0.dist-info}/METADATA +56 -54
- llama_stack-0.4.0.dist-info/RECORD +588 -0
- llama_stack-0.4.0.dist-info/top_level.txt +2 -0
- llama_stack_api/__init__.py +945 -0
- llama_stack_api/admin/__init__.py +45 -0
- llama_stack_api/admin/api.py +72 -0
- llama_stack_api/admin/fastapi_routes.py +117 -0
- llama_stack_api/admin/models.py +113 -0
- llama_stack_api/agents.py +173 -0
- llama_stack_api/batches/__init__.py +40 -0
- llama_stack_api/batches/api.py +53 -0
- llama_stack_api/batches/fastapi_routes.py +113 -0
- llama_stack_api/batches/models.py +78 -0
- llama_stack_api/benchmarks/__init__.py +43 -0
- llama_stack_api/benchmarks/api.py +39 -0
- llama_stack_api/benchmarks/fastapi_routes.py +109 -0
- llama_stack_api/benchmarks/models.py +109 -0
- {llama_stack/apis → llama_stack_api}/common/content_types.py +1 -43
- {llama_stack/apis → llama_stack_api}/common/errors.py +0 -8
- {llama_stack/apis → llama_stack_api}/common/job_types.py +1 -1
- llama_stack_api/common/responses.py +77 -0
- {llama_stack/apis → llama_stack_api}/common/training_types.py +1 -1
- {llama_stack/apis → llama_stack_api}/common/type_system.py +2 -14
- llama_stack_api/connectors.py +146 -0
- {llama_stack/apis/conversations → llama_stack_api}/conversations.py +23 -39
- {llama_stack/apis/datasetio → llama_stack_api}/datasetio.py +4 -8
- llama_stack_api/datasets/__init__.py +61 -0
- llama_stack_api/datasets/api.py +35 -0
- llama_stack_api/datasets/fastapi_routes.py +104 -0
- llama_stack_api/datasets/models.py +152 -0
- {llama_stack/providers → llama_stack_api}/datatypes.py +166 -10
- {llama_stack/apis/eval → llama_stack_api}/eval.py +8 -40
- llama_stack_api/file_processors/__init__.py +27 -0
- llama_stack_api/file_processors/api.py +64 -0
- llama_stack_api/file_processors/fastapi_routes.py +78 -0
- llama_stack_api/file_processors/models.py +42 -0
- llama_stack_api/files/__init__.py +35 -0
- llama_stack_api/files/api.py +51 -0
- llama_stack_api/files/fastapi_routes.py +124 -0
- llama_stack_api/files/models.py +107 -0
- {llama_stack/apis/inference → llama_stack_api}/inference.py +90 -194
- llama_stack_api/inspect_api/__init__.py +37 -0
- llama_stack_api/inspect_api/api.py +25 -0
- llama_stack_api/inspect_api/fastapi_routes.py +76 -0
- llama_stack_api/inspect_api/models.py +28 -0
- {llama_stack/apis/agents → llama_stack_api/internal}/__init__.py +3 -1
- llama_stack/providers/utils/kvstore/api.py → llama_stack_api/internal/kvstore.py +5 -0
- llama_stack_api/internal/sqlstore.py +79 -0
- {llama_stack/apis/models → llama_stack_api}/models.py +11 -9
- {llama_stack/apis/agents → llama_stack_api}/openai_responses.py +184 -27
- {llama_stack/apis/post_training → llama_stack_api}/post_training.py +7 -11
- {llama_stack/apis/prompts → llama_stack_api}/prompts.py +3 -4
- llama_stack_api/providers/__init__.py +33 -0
- llama_stack_api/providers/api.py +16 -0
- llama_stack_api/providers/fastapi_routes.py +57 -0
- llama_stack_api/providers/models.py +24 -0
- {llama_stack/apis/tools → llama_stack_api}/rag_tool.py +2 -52
- {llama_stack/apis → llama_stack_api}/resource.py +1 -1
- llama_stack_api/router_utils.py +160 -0
- {llama_stack/apis/safety → llama_stack_api}/safety.py +6 -9
- {llama_stack → llama_stack_api}/schema_utils.py +94 -4
- {llama_stack/apis/scoring → llama_stack_api}/scoring.py +3 -3
- {llama_stack/apis/scoring_functions → llama_stack_api}/scoring_functions.py +9 -6
- {llama_stack/apis/shields → llama_stack_api}/shields.py +6 -7
- {llama_stack/apis/tools → llama_stack_api}/tools.py +26 -21
- {llama_stack/apis/vector_io → llama_stack_api}/vector_io.py +133 -152
- {llama_stack/apis/vector_stores → llama_stack_api}/vector_stores.py +1 -1
- llama_stack/apis/agents/agents.py +0 -894
- llama_stack/apis/batches/__init__.py +0 -9
- llama_stack/apis/batches/batches.py +0 -100
- llama_stack/apis/benchmarks/__init__.py +0 -7
- llama_stack/apis/benchmarks/benchmarks.py +0 -108
- llama_stack/apis/common/responses.py +0 -36
- llama_stack/apis/conversations/__init__.py +0 -31
- llama_stack/apis/datasets/datasets.py +0 -251
- llama_stack/apis/datatypes.py +0 -160
- llama_stack/apis/eval/__init__.py +0 -7
- llama_stack/apis/files/__init__.py +0 -7
- llama_stack/apis/files/files.py +0 -199
- llama_stack/apis/inference/__init__.py +0 -7
- llama_stack/apis/inference/event_logger.py +0 -43
- llama_stack/apis/inspect/__init__.py +0 -7
- llama_stack/apis/inspect/inspect.py +0 -94
- llama_stack/apis/models/__init__.py +0 -7
- llama_stack/apis/post_training/__init__.py +0 -7
- llama_stack/apis/prompts/__init__.py +0 -9
- llama_stack/apis/providers/__init__.py +0 -7
- llama_stack/apis/providers/providers.py +0 -69
- llama_stack/apis/safety/__init__.py +0 -7
- llama_stack/apis/scoring/__init__.py +0 -7
- llama_stack/apis/scoring_functions/__init__.py +0 -7
- llama_stack/apis/shields/__init__.py +0 -7
- llama_stack/apis/synthetic_data_generation/__init__.py +0 -7
- llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py +0 -77
- llama_stack/apis/telemetry/__init__.py +0 -7
- llama_stack/apis/telemetry/telemetry.py +0 -423
- llama_stack/apis/tools/__init__.py +0 -8
- llama_stack/apis/vector_io/__init__.py +0 -7
- llama_stack/apis/vector_stores/__init__.py +0 -7
- llama_stack/core/server/tracing.py +0 -80
- llama_stack/core/ui/app.py +0 -55
- llama_stack/core/ui/modules/__init__.py +0 -5
- llama_stack/core/ui/modules/api.py +0 -32
- llama_stack/core/ui/modules/utils.py +0 -42
- llama_stack/core/ui/page/__init__.py +0 -5
- llama_stack/core/ui/page/distribution/__init__.py +0 -5
- llama_stack/core/ui/page/distribution/datasets.py +0 -18
- llama_stack/core/ui/page/distribution/eval_tasks.py +0 -20
- llama_stack/core/ui/page/distribution/models.py +0 -18
- llama_stack/core/ui/page/distribution/providers.py +0 -27
- llama_stack/core/ui/page/distribution/resources.py +0 -48
- llama_stack/core/ui/page/distribution/scoring_functions.py +0 -18
- llama_stack/core/ui/page/distribution/shields.py +0 -19
- llama_stack/core/ui/page/evaluations/__init__.py +0 -5
- llama_stack/core/ui/page/evaluations/app_eval.py +0 -143
- llama_stack/core/ui/page/evaluations/native_eval.py +0 -253
- llama_stack/core/ui/page/playground/__init__.py +0 -5
- llama_stack/core/ui/page/playground/chat.py +0 -130
- llama_stack/core/ui/page/playground/tools.py +0 -352
- llama_stack/distributions/dell/build.yaml +0 -33
- llama_stack/distributions/meta-reference-gpu/build.yaml +0 -32
- llama_stack/distributions/nvidia/build.yaml +0 -29
- llama_stack/distributions/open-benchmark/build.yaml +0 -36
- llama_stack/distributions/postgres-demo/__init__.py +0 -7
- llama_stack/distributions/postgres-demo/build.yaml +0 -23
- llama_stack/distributions/postgres-demo/postgres_demo.py +0 -125
- llama_stack/distributions/starter/build.yaml +0 -61
- llama_stack/distributions/starter-gpu/build.yaml +0 -61
- llama_stack/distributions/watsonx/build.yaml +0 -33
- llama_stack/providers/inline/agents/meta_reference/agent_instance.py +0 -1024
- llama_stack/providers/inline/agents/meta_reference/persistence.py +0 -228
- llama_stack/providers/inline/telemetry/__init__.py +0 -5
- llama_stack/providers/inline/telemetry/meta_reference/__init__.py +0 -21
- llama_stack/providers/inline/telemetry/meta_reference/config.py +0 -47
- llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +0 -252
- llama_stack/providers/remote/inference/bedrock/models.py +0 -29
- llama_stack/providers/utils/kvstore/sqlite/config.py +0 -20
- llama_stack/providers/utils/sqlstore/__init__.py +0 -5
- llama_stack/providers/utils/sqlstore/api.py +0 -128
- llama_stack/providers/utils/telemetry/__init__.py +0 -5
- llama_stack/providers/utils/telemetry/trace_protocol.py +0 -142
- llama_stack/providers/utils/telemetry/tracing.py +0 -384
- llama_stack/strong_typing/__init__.py +0 -19
- llama_stack/strong_typing/auxiliary.py +0 -228
- llama_stack/strong_typing/classdef.py +0 -440
- llama_stack/strong_typing/core.py +0 -46
- llama_stack/strong_typing/deserializer.py +0 -877
- llama_stack/strong_typing/docstring.py +0 -409
- llama_stack/strong_typing/exception.py +0 -23
- llama_stack/strong_typing/inspection.py +0 -1085
- llama_stack/strong_typing/mapping.py +0 -40
- llama_stack/strong_typing/name.py +0 -182
- llama_stack/strong_typing/schema.py +0 -792
- llama_stack/strong_typing/serialization.py +0 -97
- llama_stack/strong_typing/serializer.py +0 -500
- llama_stack/strong_typing/slots.py +0 -27
- llama_stack/strong_typing/topological.py +0 -89
- llama_stack/ui/node_modules/flatted/python/flatted.py +0 -149
- llama_stack-0.3.5.dist-info/RECORD +0 -625
- llama_stack-0.3.5.dist-info/top_level.txt +0 -1
- /llama_stack/{providers/utils → core/storage}/kvstore/config.py +0 -0
- /llama_stack/{providers/utils → core/storage}/kvstore/mongodb/__init__.py +0 -0
- /llama_stack/{providers/utils → core/storage}/kvstore/postgres/__init__.py +0 -0
- /llama_stack/{providers/utils → core/storage}/kvstore/redis/__init__.py +0 -0
- /llama_stack/{providers/utils → core/storage}/kvstore/sqlite/__init__.py +0 -0
- /llama_stack/{apis → providers/inline/file_processor}/__init__.py +0 -0
- /llama_stack/{apis/common → telemetry}/__init__.py +0 -0
- {llama_stack-0.3.5.dist-info → llama_stack-0.4.0.dist-info}/WHEEL +0 -0
- {llama_stack-0.3.5.dist-info → llama_stack-0.4.0.dist-info}/entry_points.txt +0 -0
- {llama_stack-0.3.5.dist-info → llama_stack-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {llama_stack/core/ui → llama_stack_api/common}/__init__.py +0 -0
- {llama_stack/strong_typing → llama_stack_api}/py.typed +0 -0
- {llama_stack/apis → llama_stack_api}/version.py +0 -0
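Reading the renames above as a migration map: the `llama_stack.apis.*` modules become a new top-level `llama_stack_api` package, the `KVStore` interface moves to `llama_stack_api/internal/kvstore.py`, and the kvstore/sqlstore implementations move from `llama_stack/providers/utils/` into `llama_stack/core/storage/`. A minimal import-migration sketch follows; the diff records file moves only, not re-exports, so the exact symbol locations below are assumptions to verify against the 0.4.0 package.

```python
# Hypothetical 0.3.5 -> 0.4.0 import migration, inferred from the file renames
# in this diff. Verify each path against the installed 0.4.0 wheel.

# 0.3.5 (old paths, now removed or relocated):
# from llama_stack.apis.safety import Safety
# from llama_stack.providers.utils.kvstore import KVStore

# 0.4.0 (new paths implied by the renames):
from llama_stack_api.safety import Safety  # {llama_stack/apis/safety -> llama_stack_api}/safety.py
from llama_stack_api.internal.kvstore import KVStore  # providers/utils/kvstore/api.py -> llama_stack_api/internal/kvstore.py

# Concrete kvstore/sqlstore backends (sqlite, redis, postgres, mongodb) now live
# under llama_stack.core.storage.kvstore and llama_stack.core.storage.sqlstore.
```

The large new `llama_stack_api/__init__.py` (+945 lines) suggests most of these types are also re-exported at the package root, so `from llama_stack_api import Safety` may be the shorter route.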
llama_stack/providers/inline/agents/meta_reference/agent_instance.py (deleted)

@@ -1,1024 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import copy
-import json
-import re
-import uuid
-import warnings
-from collections.abc import AsyncGenerator
-from datetime import UTC, datetime
-
-import httpx
-
-from llama_stack.apis.agents import (
-    AgentConfig,
-    AgentToolGroup,
-    AgentToolGroupWithArgs,
-    AgentTurnCreateRequest,
-    AgentTurnResponseEvent,
-    AgentTurnResponseEventType,
-    AgentTurnResponseStepCompletePayload,
-    AgentTurnResponseStepProgressPayload,
-    AgentTurnResponseStepStartPayload,
-    AgentTurnResponseStreamChunk,
-    AgentTurnResponseTurnAwaitingInputPayload,
-    AgentTurnResponseTurnCompletePayload,
-    AgentTurnResumeRequest,
-    Attachment,
-    Document,
-    InferenceStep,
-    ShieldCallStep,
-    StepType,
-    ToolExecutionStep,
-    Turn,
-)
-from llama_stack.apis.common.content_types import (
-    URL,
-    TextContentItem,
-    ToolCallDelta,
-    ToolCallParseStatus,
-)
-from llama_stack.apis.common.errors import SessionNotFoundError
-from llama_stack.apis.inference import (
-    ChatCompletionResponseEventType,
-    CompletionMessage,
-    Inference,
-    Message,
-    OpenAIAssistantMessageParam,
-    OpenAIChatCompletionRequestWithExtraBody,
-    OpenAIDeveloperMessageParam,
-    OpenAIMessageParam,
-    OpenAISystemMessageParam,
-    OpenAIToolMessageParam,
-    OpenAIUserMessageParam,
-    SamplingParams,
-    StopReason,
-    SystemMessage,
-    ToolDefinition,
-    ToolResponse,
-    ToolResponseMessage,
-    UserMessage,
-)
-from llama_stack.apis.safety import Safety
-from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
-from llama_stack.apis.vector_io import VectorIO
-from llama_stack.core.datatypes import AccessRule
-from llama_stack.log import get_logger
-from llama_stack.models.llama.datatypes import (
-    BuiltinTool,
-    ToolCall,
-)
-from llama_stack.providers.utils.inference.openai_compat import (
-    convert_message_to_openai_dict_new,
-    convert_openai_chat_completion_stream,
-    convert_tooldef_to_openai_tool,
-)
-from llama_stack.providers.utils.kvstore import KVStore
-from llama_stack.providers.utils.telemetry import tracing
-
-from .persistence import AgentPersistence
-from .safety import SafetyException, ShieldRunnerMixin
-
-TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
-MEMORY_QUERY_TOOL = "knowledge_search"
-WEB_SEARCH_TOOL = "web_search"
-RAG_TOOL_GROUP = "builtin::rag"
-
-logger = get_logger(name=__name__, category="agents::meta_reference")
-
-
-class ChatAgent(ShieldRunnerMixin):
-    def __init__(
-        self,
-        agent_id: str,
-        agent_config: AgentConfig,
-        inference_api: Inference,
-        safety_api: Safety,
-        tool_runtime_api: ToolRuntime,
-        tool_groups_api: ToolGroups,
-        vector_io_api: VectorIO,
-        persistence_store: KVStore,
-        created_at: str,
-        policy: list[AccessRule],
-        telemetry_enabled: bool = False,
-    ):
-        self.agent_id = agent_id
-        self.agent_config = agent_config
-        self.inference_api = inference_api
-        self.safety_api = safety_api
-        self.vector_io_api = vector_io_api
-        self.storage = AgentPersistence(agent_id, persistence_store, policy)
-        self.tool_runtime_api = tool_runtime_api
-        self.tool_groups_api = tool_groups_api
-        self.created_at = created_at
-        self.telemetry_enabled = telemetry_enabled
-
-        ShieldRunnerMixin.__init__(
-            self,
-            safety_api,
-            input_shields=agent_config.input_shields,
-            output_shields=agent_config.output_shields,
-        )
-
-    def turn_to_messages(self, turn: Turn) -> list[Message]:
-        messages = []
-
-        # NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
-        tool_call_ids = set()
-        for step in turn.steps:
-            if step.step_type == StepType.tool_execution.value:
-                for response in step.tool_responses:
-                    tool_call_ids.add(response.call_id)
-
-        for m in turn.input_messages:
-            msg = m.model_copy()
-            # We do not want to keep adding RAG context to the input messages
-            # May be this should be a parameter of the agentic instance
-            # that can define its behavior in a custom way
-            if isinstance(msg, UserMessage):
-                msg.context = None
-            if isinstance(msg, ToolResponseMessage):
-                if msg.call_id in tool_call_ids:
-                    # NOTE: do not add ToolResponseMessage here, we'll add them in tool_execution steps
-                    continue
-
-            messages.append(msg)
-
-        for step in turn.steps:
-            if step.step_type == StepType.inference.value:
-                messages.append(step.model_response)
-            elif step.step_type == StepType.tool_execution.value:
-                for response in step.tool_responses:
-                    messages.append(
-                        ToolResponseMessage(
-                            call_id=response.call_id,
-                            content=response.content,
-                        )
-                    )
-            elif step.step_type == StepType.shield_call.value:
-                if step.violation:
-                    # CompletionMessage itself in the ShieldResponse
-                    messages.append(
-                        CompletionMessage(
-                            content=step.violation.user_message,
-                            stop_reason=StopReason.end_of_turn,
-                        )
-                    )
-        return messages
-
-    async def create_session(self, name: str) -> str:
-        return await self.storage.create_session(name)
-
-    async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
-        messages = []
-        if self.agent_config.instructions != "":
-            messages.append(SystemMessage(content=self.agent_config.instructions))
-
-        for turn in turns:
-            messages.extend(self.turn_to_messages(turn))
-        return messages
-
-    async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
-        turn_id = str(uuid.uuid4())
-        if self.telemetry_enabled:
-            span = tracing.get_current_span()
-            if span is not None:
-                span.set_attribute("session_id", request.session_id)
-                span.set_attribute("agent_id", self.agent_id)
-                span.set_attribute("request", request.model_dump_json())
-                span.set_attribute("turn_id", turn_id)
-                if self.agent_config.name:
-                    span.set_attribute("agent_name", self.agent_config.name)
-
-        await self._initialize_tools(request.toolgroups)
-        async for chunk in self._run_turn(request, turn_id):
-            yield chunk
-
-    async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
-        if self.telemetry_enabled:
-            span = tracing.get_current_span()
-            if span is not None:
-                span.set_attribute("agent_id", self.agent_id)
-                span.set_attribute("session_id", request.session_id)
-                span.set_attribute("request", request.model_dump_json())
-                span.set_attribute("turn_id", request.turn_id)
-                if self.agent_config.name:
-                    span.set_attribute("agent_name", self.agent_config.name)
-
-        await self._initialize_tools()
-        async for chunk in self._run_turn(request):
-            yield chunk
-
-    async def _run_turn(
-        self,
-        request: AgentTurnCreateRequest | AgentTurnResumeRequest,
-        turn_id: str | None = None,
-    ) -> AsyncGenerator:
-        assert request.stream is True, "Non-streaming not supported"
-
-        is_resume = isinstance(request, AgentTurnResumeRequest)
-        session_info = await self.storage.get_session_info(request.session_id)
-        if session_info is None:
-            raise SessionNotFoundError(request.session_id)
-
-        turns = await self.storage.get_session_turns(request.session_id)
-        if is_resume and len(turns) == 0:
-            raise ValueError("No turns found for session")
-
-        steps = []
-        messages = await self.get_messages_from_turns(turns)
-        if is_resume:
-            tool_response_messages = [
-                ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses
-            ]
-            messages.extend(tool_response_messages)
-            last_turn = turns[-1]
-            last_turn_messages = self.turn_to_messages(last_turn)
-            last_turn_messages = [
-                x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
-            ]
-            last_turn_messages.extend(tool_response_messages)
-
-            # get steps from the turn
-            steps = last_turn.steps
-
-            # mark tool execution step as complete
-            # if there's no tool execution in progress step (due to storage, or tool call parsing on client),
-            # we'll create a new tool execution step with current time
-            in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
-                request.session_id, request.turn_id
-            )
-            now = datetime.now(UTC).isoformat()
-            tool_execution_step = ToolExecutionStep(
-                step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
-                turn_id=request.turn_id,
-                tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
-                tool_responses=request.tool_responses,
-                completed_at=now,
-                started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
-            )
-            steps.append(tool_execution_step)
-            yield AgentTurnResponseStreamChunk(
-                event=AgentTurnResponseEvent(
-                    payload=AgentTurnResponseStepCompletePayload(
-                        step_type=StepType.tool_execution.value,
-                        step_id=tool_execution_step.step_id,
-                        step_details=tool_execution_step,
-                    )
-                )
-            )
-            input_messages = last_turn.input_messages
-
-            turn_id = request.turn_id
-            start_time = last_turn.started_at
-        else:
-            messages.extend(request.messages)
-            start_time = datetime.now(UTC).isoformat()
-            input_messages = request.messages
-
-        output_message = None
-        async for chunk in self.run(
-            session_id=request.session_id,
-            turn_id=turn_id,
-            input_messages=messages,
-            sampling_params=self.agent_config.sampling_params,
-            stream=request.stream,
-            documents=request.documents if not is_resume else None,
-        ):
-            if isinstance(chunk, CompletionMessage):
-                output_message = chunk
-                continue
-
-            assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
-            event = chunk.event
-            if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
-                steps.append(event.payload.step_details)
-
-            yield chunk
-
-        assert output_message is not None
-
-        turn = Turn(
-            turn_id=turn_id,
-            session_id=request.session_id,
-            input_messages=input_messages,
-            output_message=output_message,
-            started_at=start_time,
-            completed_at=datetime.now(UTC).isoformat(),
-            steps=steps,
-        )
-        await self.storage.add_turn_to_session(request.session_id, turn)
-        if output_message.tool_calls:
-            chunk = AgentTurnResponseStreamChunk(
-                event=AgentTurnResponseEvent(
-                    payload=AgentTurnResponseTurnAwaitingInputPayload(
-                        turn=turn,
-                    )
-                )
-            )
-        else:
-            chunk = AgentTurnResponseStreamChunk(
-                event=AgentTurnResponseEvent(
-                    payload=AgentTurnResponseTurnCompletePayload(
-                        turn=turn,
-                    )
-                )
-            )
-
-        yield chunk
-
-    async def run(
-        self,
-        session_id: str,
-        turn_id: str,
-        input_messages: list[Message],
-        sampling_params: SamplingParams,
-        stream: bool = False,
-        documents: list[Document] | None = None,
-    ) -> AsyncGenerator:
-        # Doing async generators makes downstream code much simpler and everything amenable to
-        # streaming. However, it also makes things complicated here because AsyncGenerators cannot
-        # return a "final value" for the `yield from` statement. we simulate that by yielding a
-        # final boolean (to see whether an exception happened) and then explicitly testing for it.
-
-        if len(self.input_shields) > 0:
-            async for res in self.run_multiple_shields_wrapper(
-                turn_id, input_messages, self.input_shields, "user-input"
-            ):
-                if isinstance(res, bool):
-                    return
-                else:
-                    yield res
-
-        async for res in self._run(
-            session_id,
-            turn_id,
-            input_messages,
-            sampling_params,
-            stream,
-            documents,
-        ):
-            if isinstance(res, bool):
-                return
-            elif isinstance(res, CompletionMessage):
-                final_response = res
-                break
-            else:
-                yield res
-
-        assert final_response is not None
-        # for output shields run on the full input and output combination
-        messages = input_messages + [final_response]
-
-        if len(self.output_shields) > 0:
-            async for res in self.run_multiple_shields_wrapper(
-                turn_id, messages, self.output_shields, "assistant-output"
-            ):
-                if isinstance(res, bool):
-                    return
-                else:
-                    yield res
-
-        yield final_response
-
-    async def run_multiple_shields_wrapper(
-        self,
-        turn_id: str,
-        messages: list[Message],
-        shields: list[str],
-        touchpoint: str,
-    ) -> AsyncGenerator:
-        async with tracing.span("run_shields") as span:
-            if self.telemetry_enabled and span is not None:
-                span.set_attribute("input", [m.model_dump_json() for m in messages])
-                if len(shields) == 0:
-                    span.set_attribute("output", "no shields")
-
-            if len(shields) == 0:
-                return
-
-            step_id = str(uuid.uuid4())
-            shield_call_start_time = datetime.now(UTC).isoformat()
-            try:
-                yield AgentTurnResponseStreamChunk(
-                    event=AgentTurnResponseEvent(
-                        payload=AgentTurnResponseStepStartPayload(
-                            step_type=StepType.shield_call.value,
-                            step_id=step_id,
-                            metadata=dict(touchpoint=touchpoint),
-                        )
-                    )
-                )
-                await self.run_multiple_shields(messages, shields)
-
-            except SafetyException as e:
-                yield AgentTurnResponseStreamChunk(
-                    event=AgentTurnResponseEvent(
-                        payload=AgentTurnResponseStepCompletePayload(
-                            step_type=StepType.shield_call.value,
-                            step_id=step_id,
-                            step_details=ShieldCallStep(
-                                step_id=step_id,
-                                turn_id=turn_id,
-                                violation=e.violation,
-                                started_at=shield_call_start_time,
-                                completed_at=datetime.now(UTC).isoformat(),
-                            ),
-                        )
-                    )
-                )
-                if self.telemetry_enabled and span is not None:
-                    span.set_attribute("output", e.violation.model_dump_json())
-
-                yield CompletionMessage(
-                    content=str(e),
-                    stop_reason=StopReason.end_of_turn,
-                )
-                yield False
-
-            yield AgentTurnResponseStreamChunk(
-                event=AgentTurnResponseEvent(
-                    payload=AgentTurnResponseStepCompletePayload(
-                        step_type=StepType.shield_call.value,
-                        step_id=step_id,
-                        step_details=ShieldCallStep(
-                            step_id=step_id,
-                            turn_id=turn_id,
-                            violation=None,
-                            started_at=shield_call_start_time,
-                            completed_at=datetime.now(UTC).isoformat(),
-                        ),
-                    )
-                )
-            )
-            if self.telemetry_enabled and span is not None:
-                span.set_attribute("output", "no violations")
-
-    async def _run(
-        self,
-        session_id: str,
-        turn_id: str,
-        input_messages: list[Message],
-        sampling_params: SamplingParams,
-        stream: bool = False,
-        documents: list[Document] | None = None,
-    ) -> AsyncGenerator:
-        # if document is passed in a turn, we parse the raw text of the document
-        # and sent it as a user message
-        if documents:
-            contexts = []
-            for document in documents:
-                raw_document_text = await get_raw_document_text(document)
-                contexts.append(raw_document_text)
-
-            attached_context = "\n".join(contexts)
-            if isinstance(input_messages[-1].content, str):
-                input_messages[-1].content += attached_context
-            elif isinstance(input_messages[-1].content, list):
-                input_messages[-1].content.append(TextContentItem(text=attached_context))
-            else:
-                input_messages[-1].content = [
-                    input_messages[-1].content,
-                    TextContentItem(text=attached_context),
-                ]
-
-        session_info = await self.storage.get_session_info(session_id)
-        # if the session has a memory bank id, let the memory tool use it
-        if session_info and session_info.vector_db_id:
-            for tool_name in self.tool_name_to_args.keys():
-                if tool_name == MEMORY_QUERY_TOOL:
-                    if "vector_db_ids" not in self.tool_name_to_args[tool_name]:
-                        self.tool_name_to_args[tool_name]["vector_db_ids"] = [session_info.vector_db_id]
-                    else:
-                        self.tool_name_to_args[tool_name]["vector_db_ids"].append(session_info.vector_db_id)
-
-        output_attachments = []
-
-        n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
-
-        # Build a map of custom tools to their definitions for faster lookup
-        client_tools = {}
-        for tool in self.agent_config.client_tools:
-            client_tools[tool.name] = tool
-        while True:
-            step_id = str(uuid.uuid4())
-            inference_start_time = datetime.now(UTC).isoformat()
-            yield AgentTurnResponseStreamChunk(
-                event=AgentTurnResponseEvent(
-                    payload=AgentTurnResponseStepStartPayload(
-                        step_type=StepType.inference.value,
-                        step_id=step_id,
-                    )
-                )
-            )
-
-            tool_calls = []
-            content = ""
-            stop_reason: StopReason | None = None
-
-            async with tracing.span("inference") as span:
-                if self.telemetry_enabled and span is not None:
-                    if self.agent_config.name:
-                        span.set_attribute("agent_name", self.agent_config.name)
-
-                def _serialize_nested(value):
-                    """Recursively serialize nested Pydantic models to dicts."""
-                    from pydantic import BaseModel
-
-                    if isinstance(value, BaseModel):
-                        return value.model_dump(mode="json")
-                    elif isinstance(value, dict):
-                        return {k: _serialize_nested(v) for k, v in value.items()}
-                    elif isinstance(value, list):
-                        return [_serialize_nested(item) for item in value]
-                    else:
-                        return value
-
-                def _add_type(openai_msg: dict) -> OpenAIMessageParam:
-                    # Serialize any nested Pydantic models to plain dicts
-                    openai_msg = _serialize_nested(openai_msg)
-
-                    role = openai_msg.get("role")
-                    if role == "user":
-                        return OpenAIUserMessageParam(**openai_msg)
-                    elif role == "system":
-                        return OpenAISystemMessageParam(**openai_msg)
-                    elif role == "assistant":
-                        return OpenAIAssistantMessageParam(**openai_msg)
-                    elif role == "tool":
-                        return OpenAIToolMessageParam(**openai_msg)
-                    elif role == "developer":
-                        return OpenAIDeveloperMessageParam(**openai_msg)
-                    else:
-                        raise ValueError(f"Unknown message role: {role}")
-
-                # Convert messages to OpenAI format
-                openai_messages: list[OpenAIMessageParam] = [
-                    _add_type(await convert_message_to_openai_dict_new(message)) for message in input_messages
-                ]
-
-                # Convert tool definitions to OpenAI format
-                openai_tools = [convert_tooldef_to_openai_tool(x) for x in (self.tool_defs or [])]
-
-                # Extract tool_choice from tool_config for OpenAI compatibility
-                # Note: tool_choice can only be provided when tools are also provided
-                tool_choice = None
-                if openai_tools and self.agent_config.tool_config and self.agent_config.tool_config.tool_choice:
-                    tc = self.agent_config.tool_config.tool_choice
-                    tool_choice_str = tc.value if hasattr(tc, "value") else str(tc)
-                    # Convert tool_choice to OpenAI format
-                    if tool_choice_str in ("auto", "none", "required"):
-                        tool_choice = tool_choice_str
-                    else:
-                        # It's a specific tool name, wrap it in the proper format
-                        tool_choice = {"type": "function", "function": {"name": tool_choice_str}}
-
-                # Convert sampling params to OpenAI format (temperature, top_p, max_tokens)
-                temperature = getattr(getattr(sampling_params, "strategy", None), "temperature", None)
-                top_p = getattr(getattr(sampling_params, "strategy", None), "top_p", None)
-                max_tokens = getattr(sampling_params, "max_tokens", None)
-
-                # Use OpenAI chat completion
-                params = OpenAIChatCompletionRequestWithExtraBody(
-                    model=self.agent_config.model,
-                    messages=openai_messages,
-                    tools=openai_tools if openai_tools else None,
-                    tool_choice=tool_choice,
-                    response_format=self.agent_config.response_format,
-                    temperature=temperature,
-                    top_p=top_p,
-                    max_tokens=max_tokens,
-                    stream=True,
-                )
-                openai_stream = await self.inference_api.openai_chat_completion(params)
-
-                # Convert OpenAI stream back to Llama Stack format
-                response_stream = convert_openai_chat_completion_stream(
-                    openai_stream, enable_incremental_tool_calls=True
-                )
-
-                async for chunk in response_stream:
-                    event = chunk.event
-                    if event.event_type == ChatCompletionResponseEventType.start:
-                        continue
-                    elif event.event_type == ChatCompletionResponseEventType.complete:
-                        stop_reason = event.stop_reason or StopReason.end_of_turn
-                        continue
-
-                    delta = event.delta
-                    if delta.type == "tool_call":
-                        if delta.parse_status == ToolCallParseStatus.succeeded:
-                            tool_calls.append(delta.tool_call)
-                        elif delta.parse_status == ToolCallParseStatus.failed:
-                            # If we cannot parse the tools, set the content to the unparsed raw text
-                            content = str(delta.tool_call)
-                        if stream:
-                            yield AgentTurnResponseStreamChunk(
-                                event=AgentTurnResponseEvent(
-                                    payload=AgentTurnResponseStepProgressPayload(
-                                        step_type=StepType.inference.value,
-                                        step_id=step_id,
-                                        delta=delta,
-                                    )
-                                )
-                            )
-
-                    elif delta.type == "text":
-                        content += delta.text
-                        if stream and event.stop_reason is None:
-                            yield AgentTurnResponseStreamChunk(
-                                event=AgentTurnResponseEvent(
-                                    payload=AgentTurnResponseStepProgressPayload(
-                                        step_type=StepType.inference.value,
-                                        step_id=step_id,
-                                        delta=delta,
-                                    )
-                                )
-                            )
-                    else:
-                        raise ValueError(f"Unexpected delta type {type(delta)}")
-
-                if self.telemetry_enabled and span is not None:
-                    span.set_attribute("stop_reason", stop_reason or StopReason.end_of_turn)
-                    span.set_attribute(
-                        "input",
-                        json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
-                    )
-                    output_attr = json.dumps(
-                        {
-                            "content": content,
-                            "tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
-                        }
-                    )
-                    span.set_attribute("output", output_attr)
-
-            n_iter += 1
-            await self.storage.set_num_infer_iters_in_turn(session_id, turn_id, n_iter)
-
-            stop_reason = stop_reason or StopReason.out_of_tokens
-
-            # If tool calls are parsed successfully,
-            # if content is not made null the tool call str will also be in the content
-            # and tokens will have tool call syntax included twice
-            if tool_calls:
-                content = ""
-
-            message = CompletionMessage(
-                content=content,
-                stop_reason=stop_reason,
-                tool_calls=tool_calls,
-            )
-
-            yield AgentTurnResponseStreamChunk(
-                event=AgentTurnResponseEvent(
-                    payload=AgentTurnResponseStepCompletePayload(
-                        step_type=StepType.inference.value,
-                        step_id=step_id,
-                        step_details=InferenceStep(
-                            # somewhere deep, we are re-assigning message or closing over some
-                            # variable which causes message to mutate later on. fix with a
-                            # `deepcopy` for now, but this is symptomatic of a deeper issue.
-                            step_id=step_id,
-                            turn_id=turn_id,
-                            model_response=copy.deepcopy(message),
-                            started_at=inference_start_time,
-                            completed_at=datetime.now(UTC).isoformat(),
-                        ),
-                    )
-                )
-            )
-
-            if n_iter >= self.agent_config.max_infer_iters:
-                logger.info(f"done with MAX iterations ({n_iter}), exiting.")
-                # NOTE: mark end_of_turn to indicate to client that we are done with the turn
-                # Do not continue the tool call loop after this point
-                message.stop_reason = StopReason.end_of_turn
-                yield message
-                break
-
-            if stop_reason == StopReason.out_of_tokens:
-                logger.info("out of token budget, exiting.")
-                yield message
-                break
-
-            if len(message.tool_calls) == 0:
-                if stop_reason == StopReason.end_of_turn:
-                    # TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
-                    if len(output_attachments) > 0:
-                        if isinstance(message.content, list):
-                            message.content += output_attachments
-                        else:
-                            message.content = [message.content] + output_attachments
-                    yield message
-                else:
-                    logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
-                    input_messages = input_messages + [message]
-            else:
-                input_messages = input_messages + [message]
-
-                # Process tool calls in the message
-                client_tool_calls = []
-                non_client_tool_calls = []
-
-                # Separate client and non-client tool calls
-                for tool_call in message.tool_calls:
-                    if tool_call.tool_name in client_tools:
-                        client_tool_calls.append(tool_call)
-                    else:
-                        non_client_tool_calls.append(tool_call)
-
-                # Process non-client tool calls first
-                for tool_call in non_client_tool_calls:
-                    step_id = str(uuid.uuid4())
-                    yield AgentTurnResponseStreamChunk(
-                        event=AgentTurnResponseEvent(
-                            payload=AgentTurnResponseStepStartPayload(
-                                step_type=StepType.tool_execution.value,
-                                step_id=step_id,
-                            )
-                        )
-                    )
-
-                    yield AgentTurnResponseStreamChunk(
-                        event=AgentTurnResponseEvent(
-                            payload=AgentTurnResponseStepProgressPayload(
-                                step_type=StepType.tool_execution.value,
-                                step_id=step_id,
-                                delta=ToolCallDelta(
-                                    parse_status=ToolCallParseStatus.in_progress,
-                                    tool_call=tool_call,
-                                ),
-                            )
-                        )
-                    )
-
-                    # Execute the tool call
-                    async with tracing.span(
-                        "tool_execution",
-                        {
-                            "tool_name": tool_call.tool_name,
-                            "input": message.model_dump_json(),
-                        }
-                        if self.telemetry_enabled
-                        else {},
-                    ) as span:
-                        tool_execution_start_time = datetime.now(UTC).isoformat()
-                        tool_result = await self.execute_tool_call_maybe(
-                            session_id,
-                            tool_call,
-                        )
-                        if tool_result.content is None:
-                            raise ValueError(
-                                f"Tool call result (id: {tool_call.call_id}, name: {tool_call.tool_name}) does not have any content"
-                            )
-                        result_message = ToolResponseMessage(
-                            call_id=tool_call.call_id,
-                            content=tool_result.content,
-                        )
-                        if self.telemetry_enabled and span is not None:
-                            span.set_attribute("output", result_message.model_dump_json())
-

[diff truncated in the source; the remaining deleted lines (~785-1024) are not shown]
|
-
# Store tool execution step
|
|
786
|
-
tool_execution_step = ToolExecutionStep(
|
|
787
|
-
step_id=step_id,
|
|
788
|
-
turn_id=turn_id,
|
|
789
|
-
tool_calls=[tool_call],
|
|
790
|
-
tool_responses=[
|
|
791
|
-
ToolResponse(
|
|
792
|
-
call_id=tool_call.call_id,
|
|
793
|
-
tool_name=tool_call.tool_name,
|
|
794
|
-
content=tool_result.content,
|
|
795
|
-
metadata=tool_result.metadata,
|
|
796
|
-
)
|
|
797
|
-
],
|
|
798
|
-
started_at=tool_execution_start_time,
|
|
799
|
-
completed_at=datetime.now(UTC).isoformat(),
|
|
800
|
-
)
|
|
801
|
-
|
|
802
|
-
# Yield the step completion event
|
|
803
|
-
yield AgentTurnResponseStreamChunk(
|
|
804
|
-
event=AgentTurnResponseEvent(
|
|
805
|
-
payload=AgentTurnResponseStepCompletePayload(
|
|
806
|
-
step_type=StepType.tool_execution.value,
|
|
807
|
-
step_id=step_id,
|
|
808
|
-
step_details=tool_execution_step,
|
|
809
|
-
)
|
|
810
|
-
)
|
|
811
|
-
)
|
|
812
|
-
|
|
813
|
-
# Add the result message to input_messages for the next iteration
|
|
814
|
-
input_messages.append(result_message)
|
|
815
|
-
|
|
816
|
-
# TODO: add tool-input touchpoint and a "start" event for this step also
|
|
817
|
-
# but that needs a lot more refactoring of Tool code potentially
|
|
818
|
-
if (type(result_message.content) is str) and (
|
|
819
|
-
out_attachment := _interpret_content_as_attachment(result_message.content)
|
|
820
|
-
):
|
|
821
|
-
# NOTE: when we push this message back to the model, the model may ignore the
|
|
822
|
-
# attached file path etc. since the model is trained to only provide a user message
|
|
823
|
-
# with the summary. We keep all generated attachments and then attach them to final message
|
|
824
|
-
output_attachments.append(out_attachment)
|
|
825
|
-
|
|
826
|
-
# If there are client tool calls, yield a message with only those tool calls
|
|
827
|
-
if client_tool_calls:
|
|
828
|
-
await self.storage.set_in_progress_tool_call_step(
|
|
829
|
-
session_id,
|
|
830
|
-
turn_id,
|
|
831
|
-
ToolExecutionStep(
|
|
832
|
-
step_id=step_id,
|
|
833
|
-
turn_id=turn_id,
|
|
834
|
-
tool_calls=client_tool_calls,
|
|
835
|
-
tool_responses=[],
|
|
836
|
-
started_at=datetime.now(UTC).isoformat(),
|
|
837
|
-
),
|
|
838
|
-
)
|
|
839
|
-
|
|
840
|
-
# Create a copy of the message with only client tool calls
|
|
841
|
-
client_message = message.model_copy(deep=True)
|
|
842
|
-
client_message.tool_calls = client_tool_calls
|
|
843
|
-
# NOTE: mark end_of_message to indicate to client that it may
|
|
844
|
-
# call the tool and continue the conversation with the tool's response.
|
|
845
|
-
client_message.stop_reason = StopReason.end_of_message
|
|
846
|
-
|
|
847
|
-
# Yield the message with client tool calls
|
|
848
|
-
yield client_message
|
|
849
|
-
return
|
|
850
|
-
|
|
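The removed turn loop above draws a sharp line between server-side and client-side tools: calls whose names appear in `client_tools` are handed back to the caller with `StopReason.end_of_message`, while all other calls are executed in-process via the tool runtime. A minimal sketch of that dispatch rule; `split_tool_calls` is an illustrative helper, not part of the package:

```python
def split_tool_calls(tool_calls, client_tool_names):
    """Client tools pause the turn for the caller; the rest run server-side."""
    client, server = [], []
    for call in tool_calls:
        # Mirrors the `tool_call.tool_name in client_tools` check above.
        (client if call.tool_name in client_tool_names else server).append(call)
    return client, server
```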
-    async def _initialize_tools(
-        self,
-        toolgroups_for_turn: list[AgentToolGroup] | None = None,
-    ) -> None:
-        toolgroup_to_args = {}
-        for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []):
-            if isinstance(toolgroup, AgentToolGroupWithArgs):
-                tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name)
-                toolgroup_to_args[tool_group_name] = toolgroup.args
-
-        # Determine which tools to include
-        tool_groups_to_include = toolgroups_for_turn or self.agent_config.toolgroups or []
-        agent_config_toolgroups = []
-        for toolgroup in tool_groups_to_include:
-            name = toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup
-            if name not in agent_config_toolgroups:
-                agent_config_toolgroups.append(name)
-
-        toolgroup_to_args = toolgroup_to_args or {}
-
-        tool_name_to_def = {}
-        tool_name_to_args = {}
-
-        for tool_def in self.agent_config.client_tools:
-            if tool_name_to_def.get(tool_def.name, None):
-                raise ValueError(f"Tool {tool_def.name} already exists")
-
-            # Use input_schema from ToolDef directly
-            tool_name_to_def[tool_def.name] = ToolDefinition(
-                tool_name=tool_def.name,
-                description=tool_def.description,
-                input_schema=tool_def.input_schema,
-            )
-        for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
-            toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
-            tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
-            if not tools.data:
-                available_tool_groups = ", ".join(
-                    [t.identifier for t in (await self.tool_groups_api.list_tool_groups()).data]
-                )
-                raise ValueError(f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}")
-            if input_tool_name is not None and not any(tool.name == input_tool_name for tool in tools.data):
-                raise ValueError(
-                    f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.name for tool in tools.data])}"
-                )
-
-            for tool_def in tools.data:
-                if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
-                    identifier: str | BuiltinTool | None = tool_def.name
-                    if identifier == "web_search":
-                        identifier = BuiltinTool.brave_search
-                    else:
-                        identifier = BuiltinTool(identifier)
-                else:
-                    # add if tool_name is unspecified or the tool_def identifier is the same as the tool_name
-                    if input_tool_name in (None, tool_def.name):
-                        identifier = tool_def.name
-                    else:
-                        identifier = None
-
-                if tool_name_to_def.get(identifier, None):
-                    raise ValueError(f"Tool {identifier} already exists")
-                if identifier:
-                    tool_name_to_def[identifier] = ToolDefinition(
-                        tool_name=identifier,
-                        description=tool_def.description,
-                        input_schema=tool_def.input_schema,
-                    )
-                    tool_name_to_args[identifier] = toolgroup_to_args.get(toolgroup_name, {})
-
-        self.tool_defs, self.tool_name_to_args = (
-            list(tool_name_to_def.values()),
-            tool_name_to_args,
-        )
-
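`_initialize_tools` above builds one flat table of tool definitions from two sources (the agent's client tools and the registered toolgroups), rejecting duplicate names and remembering per-toolgroup argument overrides. A self-contained sketch of the dedup-plus-args pattern; `ToolSpec` is a hypothetical stand-in for `ToolDefinition`:

```python
from dataclasses import dataclass, field

@dataclass
class ToolSpec:
    name: str
    args: dict = field(default_factory=dict)

def build_tool_table(specs: list[ToolSpec]) -> dict[str, ToolSpec]:
    # Mirrors the "Tool {name} already exists" guard in the removed method.
    table: dict[str, ToolSpec] = {}
    for spec in specs:
        if spec.name in table:
            raise ValueError(f"Tool {spec.name} already exists")
        table[spec.name] = spec
    return table
```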
-    def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, str | None]:
-        """Parse a toolgroup name into its components.
-
-        Args:
-            toolgroup_name_with_maybe_tool_name: The toolgroup name to parse (e.g. "builtin::rag/knowledge_search")
-
-        Returns:
-            A tuple of (tool_group, tool_name); tool_name is None when no specific tool is named.
-        """
-        split_names = toolgroup_name_with_maybe_tool_name.split("/")
-        if len(split_names) == 2:
-            # e.g. "builtin::rag/knowledge_search"
-            tool_group, tool_name = split_names
-        else:
-            tool_group, tool_name = split_names[0], None
-        return tool_group, tool_name
-
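The parser above splits on `/`: exactly one slash selects a single tool within a group, and any other shape treats the whole string as the toolgroup name. A standalone mirror of that behaviour for quick experimentation:

```python
def parse_toolgroup_name(name: str) -> tuple[str, str | None]:
    # Standalone equivalent of the removed method.
    parts = name.split("/")
    if len(parts) == 2:
        return parts[0], parts[1]
    return parts[0], None

assert parse_toolgroup_name("builtin::rag/knowledge_search") == ("builtin::rag", "knowledge_search")
assert parse_toolgroup_name("builtin::rag") == ("builtin::rag", None)
```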
-    async def execute_tool_call_maybe(
-        self,
-        session_id: str,
-        tool_call: ToolCall,
-    ) -> ToolInvocationResult:
-        tool_name = tool_call.tool_name
-        registered_tool_names = [tool_def.tool_name for tool_def in self.tool_defs]
-        if tool_name not in registered_tool_names:
-            raise ValueError(
-                f"Tool {tool_name} not found in provided tools, registered tools: {', '.join([str(x) for x in registered_tool_names])}"
-            )
-        if isinstance(tool_name, BuiltinTool):
-            if tool_name == BuiltinTool.brave_search:
-                tool_name_str = WEB_SEARCH_TOOL
-            else:
-                tool_name_str = tool_name.value
-        else:
-            tool_name_str = tool_name
-
-        logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
-
-        try:
-            args = json.loads(tool_call.arguments)
-        except json.JSONDecodeError as e:
-            raise ValueError(f"Failed to parse arguments for tool call: {tool_call.arguments}") from e
-
-        result = await self.tool_runtime_api.invoke_tool(
-            tool_name=tool_name_str,
-            kwargs={
-                "session_id": session_id,
-                # get the arguments generated by the model and augment with toolgroup arg overrides for the agent
-                **args,
-                **self.tool_name_to_args.get(tool_name_str, {}),
-            },
-        )
-        logger.debug(f"tool call {tool_name_str} completed with result: {result}")
-        return result
-
-
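Note the kwargs merge in `execute_tool_call_maybe`: because `self.tool_name_to_args` is spread after the model-generated `args`, per-agent toolgroup overrides win on key collisions. A tiny illustration with made-up values:

```python
# Later spreads win on key collisions, so agent-level overrides beat the model.
model_args = {"query": "llama", "top_k": 3}   # hypothetical model output
agent_overrides = {"top_k": 10}               # hypothetical toolgroup args
merged = {"session_id": "sess-123", **model_args, **agent_overrides}
assert merged["top_k"] == 10
```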
-async def load_data_from_url(url: str) -> str:
-    if url.startswith("http"):
-        async with httpx.AsyncClient() as client:
-            r = await client.get(url)
-            resp = r.text
-            return resp
-    raise ValueError(f"Unexpected URL: {url}")
-
-
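`load_data_from_url` only accepts `http(s)` URLs and returns the response body as text. A usage sketch, assuming `httpx` is installed and the host is reachable:

```python
import asyncio

async def main() -> None:
    text = await load_data_from_url("https://example.com/robots.txt")
    print(text[:80])

asyncio.run(main())
```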
-async def get_raw_document_text(document: Document) -> str:
-    # Handle deprecated text/yaml mime type with warning
-    if document.mime_type == "text/yaml":
-        warnings.warn(
-            "The 'text/yaml' MIME type is deprecated. Please use 'application/yaml' instead.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-    elif not (document.mime_type.startswith("text/") or document.mime_type in ("application/yaml", "application/json")):
-        raise ValueError(f"Unexpected document mime type: {document.mime_type}")
-
-    if isinstance(document.content, URL):
-        return await load_data_from_url(document.content.uri)
-    elif isinstance(document.content, str):
-        return document.content
-    elif isinstance(document.content, TextContentItem):
-        return document.content.text
-    else:
-        raise ValueError(f"Unexpected document content type: {type(document.content)}")
-
-
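The mime-type gate in `get_raw_document_text` reduces to a single predicate: anything under `text/`, plus YAML and JSON, with `text/yaml` accepted but deprecated. Sketched in isolation:

```python
def is_supported_mime(mime: str) -> bool:
    # text/yaml passes the startswith("text/") check but triggers a
    # DeprecationWarning in the removed function.
    return mime.startswith("text/") or mime in ("application/yaml", "application/json")

assert is_supported_mime("text/plain")
assert is_supported_mime("application/json")
assert not is_supported_mime("application/pdf")
```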
-def _interpret_content_as_attachment(
-    content: str,
-) -> Attachment | None:
-    match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
-    if match:
-        snippet = match.group(1)
-        data = json.loads(snippet)
-        return Attachment(
-            url=URL(uri="file://" + data["filepath"]),
-            mime_type=data["mimetype"],
-        )
-
-    return None
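`_interpret_content_as_attachment` scans tool output for a marker matched by `TOOLS_ATTACHMENT_KEY_REGEX` (defined elsewhere in the package) whose captured group is a JSON snippet with `filepath` and `mimetype` keys. A sketch of the payload shape the helper expects; the values are illustrative:

```python
import json

payload = {"filepath": "/tmp/plot.png", "mimetype": "image/png"}  # made-up values
data = json.loads(json.dumps(payload))
uri = "file://" + data["filepath"]  # -> "file:///tmp/plot.png"
assert data["mimetype"] == "image/png"
```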