google-adk 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/adk/__init__.py +20 -0
- google/adk/agents/__init__.py +32 -0
- google/adk/agents/active_streaming_tool.py +38 -0
- google/adk/agents/base_agent.py +345 -0
- google/adk/agents/callback_context.py +112 -0
- google/adk/agents/invocation_context.py +181 -0
- google/adk/agents/langgraph_agent.py +140 -0
- google/adk/agents/live_request_queue.py +64 -0
- google/adk/agents/llm_agent.py +376 -0
- google/adk/agents/loop_agent.py +62 -0
- google/adk/agents/parallel_agent.py +96 -0
- google/adk/agents/readonly_context.py +46 -0
- google/adk/agents/remote_agent.py +50 -0
- google/adk/agents/run_config.py +87 -0
- google/adk/agents/sequential_agent.py +45 -0
- google/adk/agents/transcription_entry.py +34 -0
- google/adk/artifacts/__init__.py +23 -0
- google/adk/artifacts/base_artifact_service.py +128 -0
- google/adk/artifacts/gcs_artifact_service.py +195 -0
- google/adk/artifacts/in_memory_artifact_service.py +133 -0
- google/adk/auth/__init__.py +22 -0
- google/adk/auth/auth_credential.py +220 -0
- google/adk/auth/auth_handler.py +268 -0
- google/adk/auth/auth_preprocessor.py +116 -0
- google/adk/auth/auth_schemes.py +67 -0
- google/adk/auth/auth_tool.py +55 -0
- google/adk/cli/__init__.py +15 -0
- google/adk/cli/__main__.py +18 -0
- google/adk/cli/agent_graph.py +122 -0
- google/adk/cli/browser/adk_favicon.svg +17 -0
- google/adk/cli/browser/assets/audio-processor.js +51 -0
- google/adk/cli/browser/assets/config/runtime-config.json +3 -0
- google/adk/cli/browser/index.html +33 -0
- google/adk/cli/browser/main-XUU6OGCC.js +75 -0
- google/adk/cli/browser/polyfills-FFHMD2TL.js +18 -0
- google/adk/cli/browser/styles-4VDSPQ37.css +17 -0
- google/adk/cli/cli.py +181 -0
- google/adk/cli/cli_deploy.py +181 -0
- google/adk/cli/cli_eval.py +282 -0
- google/adk/cli/cli_tools_click.py +479 -0
- google/adk/cli/fast_api.py +774 -0
- google/adk/cli/media_streamer/__init__.py +19 -0
- google/adk/cli/media_streamer/index.html +228 -0
- google/adk/cli/utils/__init__.py +49 -0
- google/adk/cli/utils/envs.py +57 -0
- google/adk/cli/utils/evals.py +93 -0
- google/adk/cli/utils/logs.py +72 -0
- google/adk/code_executors/__init__.py +49 -0
- google/adk/code_executors/base_code_executor.py +97 -0
- google/adk/code_executors/code_execution_utils.py +256 -0
- google/adk/code_executors/code_executor_context.py +202 -0
- google/adk/code_executors/container_code_executor.py +196 -0
- google/adk/code_executors/unsafe_local_code_executor.py +71 -0
- google/adk/code_executors/vertex_ai_code_executor.py +234 -0
- google/adk/evaluation/__init__.py +31 -0
- google/adk/evaluation/agent_evaluator.py +329 -0
- google/adk/evaluation/evaluation_constants.py +24 -0
- google/adk/evaluation/evaluation_generator.py +270 -0
- google/adk/evaluation/response_evaluator.py +135 -0
- google/adk/evaluation/trajectory_evaluator.py +184 -0
- google/adk/events/__init__.py +21 -0
- google/adk/events/event.py +130 -0
- google/adk/events/event_actions.py +55 -0
- google/adk/examples/__init__.py +28 -0
- google/adk/examples/base_example_provider.py +35 -0
- google/adk/examples/example.py +27 -0
- google/adk/examples/example_util.py +123 -0
- google/adk/examples/vertex_ai_example_store.py +104 -0
- google/adk/flows/__init__.py +14 -0
- google/adk/flows/llm_flows/__init__.py +20 -0
- google/adk/flows/llm_flows/_base_llm_processor.py +52 -0
- google/adk/flows/llm_flows/_code_execution.py +458 -0
- google/adk/flows/llm_flows/_nl_planning.py +129 -0
- google/adk/flows/llm_flows/agent_transfer.py +132 -0
- google/adk/flows/llm_flows/audio_transcriber.py +109 -0
- google/adk/flows/llm_flows/auto_flow.py +49 -0
- google/adk/flows/llm_flows/base_llm_flow.py +559 -0
- google/adk/flows/llm_flows/basic.py +72 -0
- google/adk/flows/llm_flows/contents.py +370 -0
- google/adk/flows/llm_flows/functions.py +486 -0
- google/adk/flows/llm_flows/identity.py +47 -0
- google/adk/flows/llm_flows/instructions.py +137 -0
- google/adk/flows/llm_flows/single_flow.py +57 -0
- google/adk/memory/__init__.py +35 -0
- google/adk/memory/base_memory_service.py +74 -0
- google/adk/memory/in_memory_memory_service.py +62 -0
- google/adk/memory/vertex_ai_rag_memory_service.py +177 -0
- google/adk/models/__init__.py +31 -0
- google/adk/models/anthropic_llm.py +243 -0
- google/adk/models/base_llm.py +87 -0
- google/adk/models/base_llm_connection.py +76 -0
- google/adk/models/gemini_llm_connection.py +200 -0
- google/adk/models/google_llm.py +331 -0
- google/adk/models/lite_llm.py +673 -0
- google/adk/models/llm_request.py +98 -0
- google/adk/models/llm_response.py +111 -0
- google/adk/models/registry.py +102 -0
- google/adk/planners/__init__.py +23 -0
- google/adk/planners/base_planner.py +66 -0
- google/adk/planners/built_in_planner.py +75 -0
- google/adk/planners/plan_re_act_planner.py +208 -0
- google/adk/runners.py +456 -0
- google/adk/sessions/__init__.py +41 -0
- google/adk/sessions/base_session_service.py +133 -0
- google/adk/sessions/database_session_service.py +522 -0
- google/adk/sessions/in_memory_session_service.py +206 -0
- google/adk/sessions/session.py +54 -0
- google/adk/sessions/state.py +71 -0
- google/adk/sessions/vertex_ai_session_service.py +356 -0
- google/adk/telemetry.py +189 -0
- google/adk/tests/__init__.py +14 -0
- google/adk/tests/integration/.env.example +10 -0
- google/adk/tests/integration/__init__.py +18 -0
- google/adk/tests/integration/conftest.py +119 -0
- google/adk/tests/integration/fixture/__init__.py +14 -0
- google/adk/tests/integration/fixture/agent_with_config/__init__.py +15 -0
- google/adk/tests/integration/fixture/agent_with_config/agent.py +88 -0
- google/adk/tests/integration/fixture/callback_agent/__init__.py +15 -0
- google/adk/tests/integration/fixture/callback_agent/agent.py +105 -0
- google/adk/tests/integration/fixture/context_update_test/OWNERS +1 -0
- google/adk/tests/integration/fixture/context_update_test/__init__.py +15 -0
- google/adk/tests/integration/fixture/context_update_test/agent.py +43 -0
- google/adk/tests/integration/fixture/context_update_test/successful_test.session.json +582 -0
- google/adk/tests/integration/fixture/context_variable_agent/__init__.py +15 -0
- google/adk/tests/integration/fixture/context_variable_agent/agent.py +115 -0
- google/adk/tests/integration/fixture/customer_support_ma/__init__.py +15 -0
- google/adk/tests/integration/fixture/customer_support_ma/agent.py +172 -0
- google/adk/tests/integration/fixture/ecommerce_customer_service_agent/__init__.py +15 -0
- google/adk/tests/integration/fixture/ecommerce_customer_service_agent/agent.py +338 -0
- google/adk/tests/integration/fixture/ecommerce_customer_service_agent/order_query.test.json +69 -0
- google/adk/tests/integration/fixture/ecommerce_customer_service_agent/test_config.json +6 -0
- google/adk/tests/integration/fixture/flow_complex_spark/__init__.py +15 -0
- google/adk/tests/integration/fixture/flow_complex_spark/agent.py +182 -0
- google/adk/tests/integration/fixture/flow_complex_spark/sample.debug.log +243 -0
- google/adk/tests/integration/fixture/flow_complex_spark/sample.session.json +190 -0
- google/adk/tests/integration/fixture/hello_world_agent/__init__.py +15 -0
- google/adk/tests/integration/fixture/hello_world_agent/agent.py +95 -0
- google/adk/tests/integration/fixture/hello_world_agent/roll_die.test.json +24 -0
- google/adk/tests/integration/fixture/hello_world_agent/test_config.json +6 -0
- google/adk/tests/integration/fixture/home_automation_agent/__init__.py +15 -0
- google/adk/tests/integration/fixture/home_automation_agent/agent.py +304 -0
- google/adk/tests/integration/fixture/home_automation_agent/simple_test.test.json +5 -0
- google/adk/tests/integration/fixture/home_automation_agent/simple_test2.test.json +5 -0
- google/adk/tests/integration/fixture/home_automation_agent/test_config.json +5 -0
- google/adk/tests/integration/fixture/home_automation_agent/test_files/dependent_tool_calls.test.json +18 -0
- google/adk/tests/integration/fixture/home_automation_agent/test_files/memorizing_past_events/eval_data.test.json +17 -0
- google/adk/tests/integration/fixture/home_automation_agent/test_files/memorizing_past_events/test_config.json +6 -0
- google/adk/tests/integration/fixture/home_automation_agent/test_files/simple_multi_turn_conversation.test.json +18 -0
- google/adk/tests/integration/fixture/home_automation_agent/test_files/simple_test.test.json +17 -0
- google/adk/tests/integration/fixture/home_automation_agent/test_files/simple_test2.test.json +5 -0
- google/adk/tests/integration/fixture/home_automation_agent/test_files/test_config.json +5 -0
- google/adk/tests/integration/fixture/tool_agent/__init__.py +15 -0
- google/adk/tests/integration/fixture/tool_agent/agent.py +218 -0
- google/adk/tests/integration/fixture/tool_agent/files/Agent_test_plan.pdf +0 -0
- google/adk/tests/integration/fixture/trip_planner_agent/__init__.py +15 -0
- google/adk/tests/integration/fixture/trip_planner_agent/agent.py +110 -0
- google/adk/tests/integration/fixture/trip_planner_agent/initial.session.json +13 -0
- google/adk/tests/integration/fixture/trip_planner_agent/test_config.json +5 -0
- google/adk/tests/integration/fixture/trip_planner_agent/test_files/initial.session.json +13 -0
- google/adk/tests/integration/fixture/trip_planner_agent/test_files/test_config.json +5 -0
- google/adk/tests/integration/fixture/trip_planner_agent/test_files/trip_inquiry_sub_agent.test.json +7 -0
- google/adk/tests/integration/fixture/trip_planner_agent/trip_inquiry.test.json +19 -0
- google/adk/tests/integration/models/__init__.py +14 -0
- google/adk/tests/integration/models/test_google_llm.py +65 -0
- google/adk/tests/integration/test_callback.py +70 -0
- google/adk/tests/integration/test_context_variable.py +67 -0
- google/adk/tests/integration/test_evalute_agent_in_fixture.py +76 -0
- google/adk/tests/integration/test_multi_agent.py +28 -0
- google/adk/tests/integration/test_multi_turn.py +42 -0
- google/adk/tests/integration/test_single_agent.py +23 -0
- google/adk/tests/integration/test_sub_agent.py +26 -0
- google/adk/tests/integration/test_system_instruction.py +177 -0
- google/adk/tests/integration/test_tools.py +287 -0
- google/adk/tests/integration/test_with_test_file.py +34 -0
- google/adk/tests/integration/tools/__init__.py +14 -0
- google/adk/tests/integration/utils/__init__.py +16 -0
- google/adk/tests/integration/utils/asserts.py +75 -0
- google/adk/tests/integration/utils/test_runner.py +97 -0
- google/adk/tests/unittests/__init__.py +14 -0
- google/adk/tests/unittests/agents/__init__.py +14 -0
- google/adk/tests/unittests/agents/test_base_agent.py +407 -0
- google/adk/tests/unittests/agents/test_langgraph_agent.py +191 -0
- google/adk/tests/unittests/agents/test_llm_agent_callbacks.py +138 -0
- google/adk/tests/unittests/agents/test_llm_agent_fields.py +231 -0
- google/adk/tests/unittests/agents/test_loop_agent.py +136 -0
- google/adk/tests/unittests/agents/test_parallel_agent.py +92 -0
- google/adk/tests/unittests/agents/test_sequential_agent.py +114 -0
- google/adk/tests/unittests/artifacts/__init__.py +14 -0
- google/adk/tests/unittests/artifacts/test_artifact_service.py +276 -0
- google/adk/tests/unittests/auth/test_auth_handler.py +575 -0
- google/adk/tests/unittests/conftest.py +73 -0
- google/adk/tests/unittests/fast_api/__init__.py +14 -0
- google/adk/tests/unittests/fast_api/test_fast_api.py +269 -0
- google/adk/tests/unittests/flows/__init__.py +14 -0
- google/adk/tests/unittests/flows/llm_flows/__init__.py +14 -0
- google/adk/tests/unittests/flows/llm_flows/_test_examples.py +142 -0
- google/adk/tests/unittests/flows/llm_flows/test_agent_transfer.py +311 -0
- google/adk/tests/unittests/flows/llm_flows/test_functions_long_running.py +244 -0
- google/adk/tests/unittests/flows/llm_flows/test_functions_request_euc.py +346 -0
- google/adk/tests/unittests/flows/llm_flows/test_functions_sequential.py +93 -0
- google/adk/tests/unittests/flows/llm_flows/test_functions_simple.py +258 -0
- google/adk/tests/unittests/flows/llm_flows/test_identity.py +66 -0
- google/adk/tests/unittests/flows/llm_flows/test_instructions.py +164 -0
- google/adk/tests/unittests/flows/llm_flows/test_model_callbacks.py +142 -0
- google/adk/tests/unittests/flows/llm_flows/test_other_configs.py +46 -0
- google/adk/tests/unittests/flows/llm_flows/test_tool_callbacks.py +269 -0
- google/adk/tests/unittests/models/__init__.py +14 -0
- google/adk/tests/unittests/models/test_google_llm.py +224 -0
- google/adk/tests/unittests/models/test_litellm.py +804 -0
- google/adk/tests/unittests/models/test_models.py +60 -0
- google/adk/tests/unittests/sessions/__init__.py +14 -0
- google/adk/tests/unittests/sessions/test_session_service.py +227 -0
- google/adk/tests/unittests/sessions/test_vertex_ai_session_service.py +246 -0
- google/adk/tests/unittests/streaming/__init__.py +14 -0
- google/adk/tests/unittests/streaming/test_streaming.py +50 -0
- google/adk/tests/unittests/tools/__init__.py +14 -0
- google/adk/tests/unittests/tools/apihub_tool/clients/test_apihub_client.py +499 -0
- google/adk/tests/unittests/tools/apihub_tool/test_apihub_toolset.py +204 -0
- google/adk/tests/unittests/tools/application_integration_tool/clients/test_connections_client.py +600 -0
- google/adk/tests/unittests/tools/application_integration_tool/clients/test_integration_client.py +630 -0
- google/adk/tests/unittests/tools/application_integration_tool/test_application_integration_toolset.py +345 -0
- google/adk/tests/unittests/tools/google_api_tool/__init__.py +13 -0
- google/adk/tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py +657 -0
- google/adk/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_auto_auth_credential_exchanger.py +145 -0
- google/adk/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_base_auth_credential_exchanger.py +68 -0
- google/adk/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_oauth2_exchanger.py +153 -0
- google/adk/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py +196 -0
- google/adk/tests/unittests/tools/openapi_tool/auth/test_auth_helper.py +573 -0
- google/adk/tests/unittests/tools/openapi_tool/common/test_common.py +436 -0
- google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test.yaml +1367 -0
- google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_openapi_spec_parser.py +628 -0
- google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_openapi_toolset.py +139 -0
- google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_operation_parser.py +406 -0
- google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py +966 -0
- google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py +201 -0
- google/adk/tests/unittests/tools/retrieval/__init__.py +14 -0
- google/adk/tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py +147 -0
- google/adk/tests/unittests/tools/test_agent_tool.py +167 -0
- google/adk/tests/unittests/tools/test_base_tool.py +141 -0
- google/adk/tests/unittests/tools/test_build_function_declaration.py +277 -0
- google/adk/tests/unittests/utils.py +304 -0
- google/adk/tools/__init__.py +51 -0
- google/adk/tools/_automatic_function_calling_util.py +346 -0
- google/adk/tools/agent_tool.py +176 -0
- google/adk/tools/apihub_tool/__init__.py +19 -0
- google/adk/tools/apihub_tool/apihub_toolset.py +209 -0
- google/adk/tools/apihub_tool/clients/__init__.py +13 -0
- google/adk/tools/apihub_tool/clients/apihub_client.py +332 -0
- google/adk/tools/apihub_tool/clients/secret_client.py +115 -0
- google/adk/tools/application_integration_tool/__init__.py +19 -0
- google/adk/tools/application_integration_tool/application_integration_toolset.py +230 -0
- google/adk/tools/application_integration_tool/clients/connections_client.py +903 -0
- google/adk/tools/application_integration_tool/clients/integration_client.py +253 -0
- google/adk/tools/base_tool.py +144 -0
- google/adk/tools/built_in_code_execution_tool.py +59 -0
- google/adk/tools/crewai_tool.py +72 -0
- google/adk/tools/example_tool.py +62 -0
- google/adk/tools/exit_loop_tool.py +23 -0
- google/adk/tools/function_parameter_parse_util.py +307 -0
- google/adk/tools/function_tool.py +87 -0
- google/adk/tools/get_user_choice_tool.py +28 -0
- google/adk/tools/google_api_tool/__init__.py +14 -0
- google/adk/tools/google_api_tool/google_api_tool.py +59 -0
- google/adk/tools/google_api_tool/google_api_tool_set.py +107 -0
- google/adk/tools/google_api_tool/google_api_tool_sets.py +55 -0
- google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py +521 -0
- google/adk/tools/google_search_tool.py +68 -0
- google/adk/tools/langchain_tool.py +86 -0
- google/adk/tools/load_artifacts_tool.py +113 -0
- google/adk/tools/load_memory_tool.py +58 -0
- google/adk/tools/load_web_page.py +41 -0
- google/adk/tools/long_running_tool.py +39 -0
- google/adk/tools/mcp_tool/__init__.py +42 -0
- google/adk/tools/mcp_tool/conversion_utils.py +161 -0
- google/adk/tools/mcp_tool/mcp_tool.py +113 -0
- google/adk/tools/mcp_tool/mcp_toolset.py +272 -0
- google/adk/tools/openapi_tool/__init__.py +21 -0
- google/adk/tools/openapi_tool/auth/__init__.py +19 -0
- google/adk/tools/openapi_tool/auth/auth_helpers.py +498 -0
- google/adk/tools/openapi_tool/auth/credential_exchangers/__init__.py +25 -0
- google/adk/tools/openapi_tool/auth/credential_exchangers/auto_auth_credential_exchanger.py +105 -0
- google/adk/tools/openapi_tool/auth/credential_exchangers/base_credential_exchanger.py +55 -0
- google/adk/tools/openapi_tool/auth/credential_exchangers/oauth2_exchanger.py +117 -0
- google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py +97 -0
- google/adk/tools/openapi_tool/common/__init__.py +19 -0
- google/adk/tools/openapi_tool/common/common.py +300 -0
- google/adk/tools/openapi_tool/openapi_spec_parser/__init__.py +32 -0
- google/adk/tools/openapi_tool/openapi_spec_parser/openapi_spec_parser.py +231 -0
- google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +144 -0
- google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +260 -0
- google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +496 -0
- google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +268 -0
- google/adk/tools/preload_memory_tool.py +72 -0
- google/adk/tools/retrieval/__init__.py +36 -0
- google/adk/tools/retrieval/base_retrieval_tool.py +37 -0
- google/adk/tools/retrieval/files_retrieval.py +33 -0
- google/adk/tools/retrieval/llama_index_retrieval.py +41 -0
- google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +107 -0
- google/adk/tools/tool_context.py +90 -0
- google/adk/tools/toolbox_tool.py +46 -0
- google/adk/tools/transfer_to_agent_tool.py +21 -0
- google/adk/tools/vertex_ai_search_tool.py +96 -0
- google/adk/version.py +16 -0
- google_adk-0.0.1.dist-info/LICENSE.txt → google_adk-0.0.2.dist-info/LICENSE +32 -0
- google_adk-0.0.2.dist-info/METADATA +73 -0
- google_adk-0.0.2.dist-info/RECORD +308 -0
- {google_adk-0.0.1.dist-info → google_adk-0.0.2.dist-info}/WHEEL +1 -2
- google_adk-0.0.2.dist-info/entry_points.txt +3 -0
- agent_kit/__init__.py +0 -0
- google_adk-0.0.1.dist-info/METADATA +0 -15
- google_adk-0.0.1.dist-info/RECORD +0 -6
- google_adk-0.0.1.dist-info/top_level.txt +0 -1
@@ -0,0 +1,804 @@
|
|
1
|
+
# Copyright 2025 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
|
16
|
+
from unittest.mock import AsyncMock
|
17
|
+
from unittest.mock import Mock
|
18
|
+
from google.adk.models.lite_llm import _content_to_message_param
|
19
|
+
from google.adk.models.lite_llm import _function_declaration_to_tool_param
|
20
|
+
from google.adk.models.lite_llm import _get_content
|
21
|
+
from google.adk.models.lite_llm import _message_to_generate_content_response
|
22
|
+
from google.adk.models.lite_llm import _model_response_to_chunk
|
23
|
+
from google.adk.models.lite_llm import _to_litellm_role
|
24
|
+
from google.adk.models.lite_llm import FunctionChunk
|
25
|
+
from google.adk.models.lite_llm import LiteLlm
|
26
|
+
from google.adk.models.lite_llm import LiteLLMClient
|
27
|
+
from google.adk.models.lite_llm import TextChunk
|
28
|
+
from google.adk.models.llm_request import LlmRequest
|
29
|
+
from google.genai import types
|
30
|
+
from litellm import ChatCompletionAssistantMessage
|
31
|
+
from litellm import ChatCompletionMessageToolCall
|
32
|
+
from litellm import Function
|
33
|
+
from litellm.types.utils import ChatCompletionDeltaToolCall
|
34
|
+
from litellm.types.utils import Choices
|
35
|
+
from litellm.types.utils import Delta
|
36
|
+
from litellm.types.utils import ModelResponse
|
37
|
+
from litellm.types.utils import StreamingChoices
|
38
|
+
import pytest
|
39
|
+
|
40
|
+
# Shared request for the non-streaming tests: one user prompt plus a single
# function declaration whose schema covers a scalar, an array and a nested
# object parameter.
LLM_REQUEST_WITH_FUNCTION_DECLARATION = LlmRequest(
    contents=[
        types.Content(
            role="user", parts=[types.Part.from_text(text="Test prompt")]
        )
    ],
    config=types.GenerateContentConfig(
        tools=[
            types.Tool(
                function_declarations=[
                    types.FunctionDeclaration(
                        name="test_function",
                        description="Test function description",
                        parameters=types.Schema(
                            type=types.Type.OBJECT,
                            properties={
                                "test_arg": types.Schema(
                                    type=types.Type.STRING
                                ),
                                # `items` is built as a Schema rather than a
                                # raw dict, matching the declarations in
                                # `function_declaration_test_cases` and the
                                # field's declared type.
                                "array_arg": types.Schema(
                                    type=types.Type.ARRAY,
                                    items=types.Schema(
                                        type=types.Type.STRING,
                                    ),
                                ),
                                "nested_arg": types.Schema(
                                    type=types.Type.OBJECT,
                                    properties={
                                        "nested_key1": types.Schema(
                                            type=types.Type.STRING
                                        ),
                                        "nested_key2": types.Schema(
                                            type=types.Type.STRING
                                        ),
                                    },
                                ),
                            },
                        ),
                    )
                ]
            )
        ],
    ),
)
|
84
|
+
|
85
|
+
|
86
|
+
# Chunks mimicking a streamed completion, in order:
#   1-3. assistant text deltas "zero, ", "one, ", "two:";
#   4-5. one tool call whose JSON arguments are split across two deltas
#        (only the first delta carries the id and function name);
#   6.   a final chunk that only sets finish_reason.
STREAMING_MODEL_RESPONSE = [
    *[
        ModelResponse(
            choices=[
                StreamingChoices(
                    finish_reason=None,
                    delta=Delta(role="assistant", content=text_piece),
                )
            ]
        )
        for text_piece in ("zero, ", "one, ", "two:")
    ],
    ModelResponse(
        choices=[
            StreamingChoices(
                finish_reason=None,
                delta=Delta(
                    role="assistant",
                    tool_calls=[
                        ChatCompletionDeltaToolCall(
                            type="function",
                            id="test_tool_call_id",
                            function=Function(
                                name="test_function",
                                arguments='{"test_arg": "test_',
                            ),
                            index=0,
                        )
                    ],
                ),
            )
        ]
    ),
    ModelResponse(
        choices=[
            StreamingChoices(
                finish_reason=None,
                delta=Delta(
                    role="assistant",
                    tool_calls=[
                        ChatCompletionDeltaToolCall(
                            type="function",
                            id=None,
                            function=Function(
                                name=None,
                                arguments='value"}',
                            ),
                            index=0,
                        )
                    ],
                ),
            )
        ]
    ),
    ModelResponse(
        choices=[
            StreamingChoices(
                finish_reason="tool_use",
            )
        ]
    ),
]
|
170
|
+
|
171
|
+
@pytest.fixture
def mock_response():
    """Return a non-streaming response carrying text plus one tool call.

    The tool call's arguments are a JSON-encoded string.
    """
    tool_call = ChatCompletionMessageToolCall(
        type="function",
        id="test_tool_call_id",
        function=Function(
            name="test_function",
            arguments='{"test_arg": "test_value"}',
        ),
    )
    assistant_message = ChatCompletionAssistantMessage(
        role="assistant",
        content="Test response",
        tool_calls=[tool_call],
    )
    return ModelResponse(choices=[Choices(message=assistant_message)])
|
193
|
+
|
194
|
+
|
195
|
+
@pytest.fixture
def mock_acompletion(mock_response):
    """Async mock used as `acompletion`; awaiting it yields `mock_response`."""
    async_mock = AsyncMock(return_value=mock_response)
    return async_mock
|
198
|
+
|
199
|
+
|
200
|
+
@pytest.fixture
def mock_completion(mock_response):
    """Synchronous mock used as `completion`; calling it returns `mock_response`."""
    sync_mock = Mock(return_value=mock_response)
    return sync_mock
|
203
|
+
|
204
|
+
|
205
|
+
@pytest.fixture
def mock_client(mock_acompletion, mock_completion):
    """Build a `MockLLMClient` that delegates to the two completion mocks."""
    return MockLLMClient(
        acompletion_mock=mock_acompletion,
        completion_mock=mock_completion,
    )
|
208
|
+
|
209
|
+
|
210
|
+
@pytest.fixture
def lite_llm_instance(mock_client):
    """`LiteLlm` under test, configured to use the mock client."""
    instance = LiteLlm(model="test_model", llm_client=mock_client)
    return instance
|
213
|
+
|
214
|
+
|
215
|
+
class MockLLMClient(LiteLLMClient):
    """`LiteLLMClient` test double that forwards every call to injected mocks.

    All arguments are re-passed as keywords, so tests can read them back
    from `mock.call_args` by name.
    """

    def __init__(self, acompletion_mock, completion_mock):
        # Mocks standing in for the async and sync completion entry points.
        self.acompletion_mock = acompletion_mock
        self.completion_mock = completion_mock

    async def acompletion(self, model, messages, tools, **kwargs):
        """Await the async mock with the request forwarded as keywords."""
        forwarded = dict(model=model, messages=messages, tools=tools, **kwargs)
        return await self.acompletion_mock(**forwarded)

    def completion(self, model, messages, tools, stream, **kwargs):
        """Call the sync mock with the request (including `stream`) as keywords."""
        forwarded = dict(
            model=model, messages=messages, tools=tools, stream=stream, **kwargs
        )
        return self.completion_mock(**forwarded)
|
230
|
+
|
231
|
+
|
232
|
+
@pytest.mark.asyncio
async def test_generate_content_async(mock_acompletion, lite_llm_instance):
    """Non-streaming generation round-trips the request and the response.

    Checks the text and function-call parts of the converted response, and
    the model, messages and tool schema handed to `acompletion`.
    """
    # Collect first: asserting inside the `async for` body would let the test
    # pass vacuously if the generator yielded nothing.
    responses = [
        response
        async for response in lite_llm_instance.generate_content_async(
            LLM_REQUEST_WITH_FUNCTION_DECLARATION
        )
    ]
    assert len(responses) == 1
    response = responses[0]

    assert response.content.role == "model"
    assert len(response.content.parts) == 2
    assert response.content.parts[0].text == "Test response"
    function_call = response.content.parts[1].function_call
    assert function_call.name == "test_function"
    assert function_call.args == {"test_arg": "test_value"}
    assert function_call.id == "test_tool_call_id"

    mock_acompletion.assert_called_once()

    _, kwargs = mock_acompletion.call_args
    assert kwargs["model"] == "test_model"
    assert kwargs["messages"][0]["role"] == "user"
    assert kwargs["messages"][0]["content"] == "Test prompt"
    tool_function = kwargs["tools"][0]["function"]
    assert tool_function["name"] == "test_function"
    assert tool_function["description"] == "Test function description"
    assert (
        tool_function["parameters"]["properties"]["test_arg"]["type"]
        == "string"
    )
|
263
|
+
|
264
|
+
|
265
|
+
# Each case is (test id, FunctionDeclaration, expected tool-parameter dict).
# Declarations without a description are expected to map to "".
function_declaration_test_cases = [
    (
        "simple_function",
        types.FunctionDeclaration(
            name="test_function",
            description="Test function description",
            parameters=types.Schema(
                type=types.Type.OBJECT,
                properties={
                    "test_arg": types.Schema(type=types.Type.STRING),
                    "array_arg": types.Schema(
                        type=types.Type.ARRAY,
                        items=types.Schema(
                            type=types.Type.STRING,
                        ),
                    ),
                    "nested_arg": types.Schema(
                        type=types.Type.OBJECT,
                        properties={
                            "nested_key1": types.Schema(type=types.Type.STRING),
                            "nested_key2": types.Schema(type=types.Type.STRING),
                        },
                    ),
                },
            ),
        ),
        {
            "type": "function",
            "function": {
                "name": "test_function",
                "description": "Test function description",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "test_arg": {"type": "string"},
                        "array_arg": {
                            "items": {"type": "string"},
                            "type": "array",
                        },
                        "nested_arg": {
                            "properties": {
                                "nested_key1": {"type": "string"},
                                "nested_key2": {"type": "string"},
                            },
                            "type": "object",
                        },
                    },
                },
            },
        },
    ),
    (
        "no_description",
        types.FunctionDeclaration(
            name="test_function_no_description",
            parameters=types.Schema(
                type=types.Type.OBJECT,
                properties={
                    "test_arg": types.Schema(type=types.Type.STRING),
                },
            ),
        ),
        {
            "type": "function",
            "function": {
                "name": "test_function_no_description",
                "description": "",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "test_arg": {"type": "string"},
                    },
                },
            },
        },
    ),
    (
        "empty_parameters",
        types.FunctionDeclaration(
            name="test_function_empty_params",
            parameters=types.Schema(type=types.Type.OBJECT, properties={}),
        ),
        {
            "type": "function",
            "function": {
                "name": "test_function_empty_params",
                "description": "",
                "parameters": {
                    "type": "object",
                    "properties": {},
                },
            },
        },
    ),
    (
        "nested_array",
        types.FunctionDeclaration(
            name="test_function_nested_array",
            parameters=types.Schema(
                type=types.Type.OBJECT,
                properties={
                    "array_arg": types.Schema(
                        type=types.Type.ARRAY,
                        items=types.Schema(
                            type=types.Type.OBJECT,
                            properties={
                                "nested_key": types.Schema(
                                    type=types.Type.STRING
                                )
                            },
                        ),
                    ),
                },
            ),
        ),
        {
            "type": "function",
            "function": {
                "name": "test_function_nested_array",
                "description": "",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "array_arg": {
                            "items": {
                                "properties": {
                                    "nested_key": {"type": "string"}
                                },
                                "type": "object",
                            },
                            "type": "array",
                        },
                    },
                },
            },
        },
    ),
]
|
403
|
+
|
404
|
+
|
405
|
+
@pytest.mark.parametrize(
    "case_id, function_declaration, expected_output",
    function_declaration_test_cases,
    ids=[case_id for case_id, _, _ in function_declaration_test_cases],
)
def test_function_declaration_to_tool_param(
    case_id, function_declaration, expected_output
):
    """Each declaration converts to its expected tool-parameter dict."""
    # `case_id` only feeds the pytest test id; the conversion ignores it.
    converted = _function_declaration_to_tool_param(function_declaration)
    assert converted == expected_output
|
417
|
+
|
418
|
+
|
419
|
+
@pytest.mark.asyncio
async def test_generate_content_async_with_system_instruction(
    lite_llm_instance, mock_acompletion
):
  """Checks that the system instruction is sent ahead of the user turn.

  The text from ``GenerateContentConfig.system_instruction`` is expected as the
  first message (role ``developer``), followed by the user prompt.
  """
  assistant_message = ChatCompletionAssistantMessage(
      role="assistant",
      content="Test response",
  )
  mock_acompletion.return_value = ModelResponse(
      choices=[Choices(message=assistant_message)]
  )

  user_turn = types.Content(
      role="user", parts=[types.Part.from_text(text="Test prompt")]
  )
  request_config = types.GenerateContentConfig(
      system_instruction="Test system instruction"
  )
  llm_request = LlmRequest(contents=[user_turn], config=request_config)

  async for response in lite_llm_instance.generate_content_async(llm_request):
    assert response.content.role == "model"
    assert response.content.parts[0].text == "Test response"

  mock_acompletion.assert_called_once()

  _, kwargs = mock_acompletion.call_args
  assert kwargs["model"] == "test_model"

  sent_messages = kwargs["messages"]
  # Index 0: the system instruction; index 1: the user prompt.
  assert sent_messages[0]["role"] == "developer"
  assert sent_messages[0]["content"] == "Test system instruction"
  assert sent_messages[1]["role"] == "user"
  assert sent_messages[1]["content"] == "Test prompt"
|
458
|
+
|
459
|
+
|
460
|
+
@pytest.mark.asyncio
async def test_generate_content_async_with_tool_response(
    lite_llm_instance, mock_acompletion
):
  """Checks that a function response in the history becomes a ``tool`` message.

  The request holds a user prompt followed by a ``tool`` turn carrying a
  function response. The message at index 2 sent to litellm should have role
  ``tool`` and the serialized result as its content.
  """
  tool_message = ChatCompletionAssistantMessage(
      role="tool",
      content='{"result": "test_result"}',
      tool_call_id="test_tool_call_id",
  )
  mock_acompletion.return_value = ModelResponse(
      choices=[Choices(message=tool_message)]
  )

  function_response_part = types.Part.from_function_response(
      name="test_function",
      response={"result": "test_result"},
  )
  history = [
      types.Content(
          role="user", parts=[types.Part.from_text(text="Test prompt")]
      ),
      types.Content(role="tool", parts=[function_response_part]),
  ]
  llm_request = LlmRequest(
      contents=history,
      config=types.GenerateContentConfig(
          system_instruction="test instruction",
      ),
  )

  async for response in lite_llm_instance.generate_content_async(llm_request):
    assert response.content.role == "model"
    assert response.content.parts[0].text == '{"result": "test_result"}'

  mock_acompletion.assert_called_once()

  _, kwargs = mock_acompletion.call_args
  assert kwargs["model"] == "test_model"

  tool_message_param = kwargs["messages"][2]
  assert tool_message_param["role"] == "tool"
  assert tool_message_param["content"] == '{"result": "test_result"}'
|
507
|
+
|
508
|
+
|
509
|
+
def test_content_to_message_param_user_message():
  """A plain-text user turn maps to a ``user`` message with string content."""
  prompt_part = types.Part.from_text(text="Test prompt")
  message = _content_to_message_param(
      types.Content(role="user", parts=[prompt_part])
  )
  assert message["role"] == "user"
  assert message["content"] == "Test prompt"
|
516
|
+
|
517
|
+
|
518
|
+
def test_content_to_message_param_assistant_message():
  """A plain-text assistant turn keeps the ``assistant`` role and its text."""
  response_part = types.Part.from_text(text="Test response")
  message = _content_to_message_param(
      types.Content(role="assistant", parts=[response_part])
  )
  assert message["role"] == "assistant"
  assert message["content"] == "Test response"
|
525
|
+
|
526
|
+
|
527
|
+
def test_content_to_message_param_function_call():
  """A function-call part becomes an assistant message carrying a tool call.

  The tool call keeps the function-call id and name, its arguments are
  serialized to a JSON string, and the message content itself is empty.
  """
  call_part = types.Part.from_function_call(
      name="test_function", args={"test_arg": "test_value"}
  )
  call_part.function_call.id = "test_tool_call_id"
  content = types.Content(role="assistant", parts=[call_part])

  message = _content_to_message_param(content)
  assert message["role"] == "assistant"
  assert message["content"] == []

  tool_call = message["tool_calls"][0]
  assert tool_call.type == "function"
  assert tool_call.id == "test_tool_call_id"
  assert tool_call.function.name == "test_function"
  assert tool_call.function.arguments == '{"test_arg": "test_value"}'
|
547
|
+
|
548
|
+
|
549
|
+
def test_message_to_generate_content_response_text():
  """A text-only assistant message converts to a ``model`` text part."""
  response = _message_to_generate_content_response(
      ChatCompletionAssistantMessage(role="assistant", content="Test response")
  )
  model_content = response.content
  assert model_content.role == "model"
  assert model_content.parts[0].text == "Test response"
|
557
|
+
|
558
|
+
|
559
|
+
def test_message_to_generate_content_response_tool_call():
  """An assistant tool call converts to a ``model`` function-call part.

  The JSON ``arguments`` string should come back as a dict, and the tool-call
  id should be carried over to ``function_call.id``.
  """
  tool_call = ChatCompletionMessageToolCall(
      type="function",
      id="test_tool_call_id",
      function=Function(
          name="test_function",
          arguments='{"test_arg": "test_value"}',
      ),
  )
  message = ChatCompletionAssistantMessage(
      role="assistant",
      content=None,
      tool_calls=[tool_call],
  )

  response = _message_to_generate_content_response(message)
  assert response.content.role == "model"

  function_call = response.content.parts[0].function_call
  assert function_call.name == "test_function"
  assert function_call.args == {"test_arg": "test_value"}
  assert function_call.id == "test_tool_call_id"
|
582
|
+
|
583
|
+
|
584
|
+
def test_get_content_text():
  """A single text part is returned as a plain string."""
  text_parts = [types.Part.from_text(text="Test text")]
  assert _get_content(text_parts) == "Test text"
|
588
|
+
|
589
|
+
|
590
|
+
def test_get_content_image():
  """Inline image bytes should become an ``image_url`` entry with a data URI.

  The expected value has the ``data:<mime_type>;base64,<payload>`` form, the
  same data-URI form expected for video parts.
  """
  image_part = types.Part.from_bytes(
      data=b"test_image_data", mime_type="image/png"
  )
  content = _get_content([image_part])
  assert content[0]["type"] == "image_url"
  # "dGVzdF9pbWFnZV9kYXRh" == base64.b64encode(b"test_image_data").decode().
  assert content[0]["image_url"] == "data:image/png;base64,dGVzdF9pbWFnZV9kYXRh"
|
597
|
+
|
598
|
+
|
599
|
+
def test_get_content_video():
  """Inline video bytes become a ``video_url`` entry holding a base64 data URI."""
  video_part = types.Part.from_bytes(
      data=b"test_video_data", mime_type="video/mp4"
  )
  first_item = _get_content([video_part])[0]
  assert first_item["type"] == "video_url"
  assert first_item["video_url"] == "data:video/mp4;base64,dGVzdF92aWRlb19kYXRh"
|
606
|
+
|
607
|
+
|
608
|
+
def test_to_litellm_role():
  """Maps genai roles onto litellm roles; ``None`` falls back to ``user``."""
  expected_roles = {
      "model": "assistant",
      "assistant": "assistant",
      "user": "user",
      None: "user",
  }
  for genai_role, litellm_role in expected_roles.items():
    assert _to_litellm_role(genai_role) == litellm_role
|
613
|
+
|
614
|
+
|
615
|
+
@pytest.mark.parametrize(
    "response, expected_chunk, expected_finished",
    [
        (
            ModelResponse(
                choices=[
                    {
                        "message": {
                            "content": "this is a test",
                        }
                    }
                ]
            ),
            TextChunk(text="this is a test"),
            "stop",
        ),
        (
            ModelResponse(
                choices=[
                    StreamingChoices(
                        finish_reason=None,
                        delta=Delta(
                            role="assistant",
                            tool_calls=[
                                ChatCompletionDeltaToolCall(
                                    type="function",
                                    id="1",
                                    function=Function(
                                        name="test_function",
                                        arguments='{"key": "va',
                                    ),
                                    index=0,
                                )
                            ],
                        ),
                    )
                ]
            ),
            FunctionChunk(id="1", name="test_function", args='{"key": "va'),
            None,
        ),
        (
            ModelResponse(choices=[{"finish_reason": "tool_calls"}]),
            None,
            "tool_calls",
        ),
        (ModelResponse(choices=[{}]), None, "stop"),
    ],
)
def test_model_response_to_chunk(response, expected_chunk, expected_finished):
  """Each response should yield exactly one ``(chunk, finish_reason)`` pair.

  The cases cover:
  - a plain text message;
  - a partial streamed tool call with no finish reason yet;
  - a bare ``tool_calls`` finish;
  - an empty choice, expected to report ``stop``.
  """
  produced = list(_model_response_to_chunk(response))
  assert len(produced) == 1
  ((chunk, finished),) = produced

  if not expected_chunk:
    assert chunk is None
  else:
    assert isinstance(chunk, type(expected_chunk))
    assert chunk == expected_chunk
  assert finished == expected_finished
|
674
|
+
|
675
|
+
|
676
|
+
@pytest.mark.asyncio
async def test_acompletion_additional_args(mock_acompletion, mock_client):
  """Checks which extra LiteLlm constructor kwargs reach ``acompletion``.

  - Connection settings such as ``api_base`` are forwarded.
  - ``stream`` and ``llm_client`` never appear in the call.
  - The constructor-supplied ``messages``/``tools`` are not the ones sent: the
    sent ones carry ``Test prompt`` / ``test_function``.
  """
  # invalid args (ignored)
  ignored_kwargs = {
      "stream": True,
      "messages": [{"role": "invalid", "content": "invalid"}],
      "tools": [{"type": "function", "function": {"name": "invalid"}}],
  }
  lite_llm_instance = LiteLlm(
      # valid args
      model="test_model",
      llm_client=mock_client,
      api_key="test_key",
      api_base="some://url",
      api_version="2024-09-12",
      **ignored_kwargs,
  )

  async for response in lite_llm_instance.generate_content_async(
      LLM_REQUEST_WITH_FUNCTION_DECLARATION
  ):
    content = response.content
    assert content.role == "model"
    assert content.parts[0].text == "Test response"
    function_call = content.parts[1].function_call
    assert function_call.name == "test_function"
    assert function_call.args == {"test_arg": "test_value"}
    assert function_call.id == "test_tool_call_id"

  mock_acompletion.assert_called_once()

  _, kwargs = mock_acompletion.call_args

  assert kwargs["model"] == "test_model"
  first_message = kwargs["messages"][0]
  assert first_message["role"] == "user"
  assert first_message["content"] == "Test prompt"
  assert kwargs["tools"][0]["function"]["name"] == "test_function"
  assert "stream" not in kwargs
  assert "llm_client" not in kwargs
  assert kwargs["api_base"] == "some://url"
|
718
|
+
|
719
|
+
|
720
|
+
@pytest.mark.asyncio
async def test_completion_additional_args(mock_completion, mock_client):
  """Checks kwarg forwarding on the streaming (``completion``) path.

  - ``api_base`` is forwarded.
  - ``llm_client`` never appears in the call.
  - The constructor-supplied ``stream=False`` does not stop ``stream`` from
    being truthy in the call.
  - The sent ``messages``/``tools`` carry ``Test prompt`` / ``test_function``
    rather than the constructor-supplied ones.
  """
  # invalid args (ignored)
  ignored_kwargs = {
      "stream": False,
      "messages": [{"role": "invalid", "content": "invalid"}],
      "tools": [{"type": "function", "function": {"name": "invalid"}}],
  }
  lite_llm_instance = LiteLlm(
      # valid args
      model="test_model",
      llm_client=mock_client,
      api_key="test_key",
      api_base="some://url",
      api_version="2024-09-12",
      **ignored_kwargs,
  )

  mock_completion.return_value = iter(STREAMING_MODEL_RESPONSE)

  stream = lite_llm_instance.generate_content_async(
      LLM_REQUEST_WITH_FUNCTION_DECLARATION, stream=True
  )
  responses = [response async for response in stream]
  assert len(responses) == 4
  mock_completion.assert_called_once()

  _, kwargs = mock_completion.call_args

  assert kwargs["model"] == "test_model"
  first_message = kwargs["messages"][0]
  assert first_message["role"] == "user"
  assert first_message["content"] == "Test prompt"
  assert kwargs["tools"][0]["function"]["name"] == "test_function"
  assert kwargs["stream"]
  assert "llm_client" not in kwargs
  assert kwargs["api_base"] == "some://url"
|
760
|
+
|
761
|
+
|
762
|
+
@pytest.mark.asyncio
|
763
|
+
async def test_generate_content_async_stream(
|
764
|
+
mock_completion, lite_llm_instance
|
765
|
+
):
|
766
|
+
|
767
|
+
mock_completion.return_value = iter(STREAMING_MODEL_RESPONSE)
|
768
|
+
|
769
|
+
responses = [
|
770
|
+
response
|
771
|
+
async for response in lite_llm_instance.generate_content_async(
|
772
|
+
LLM_REQUEST_WITH_FUNCTION_DECLARATION, stream=True
|
773
|
+
)
|
774
|
+
]
|
775
|
+
assert len(responses) == 4
|
776
|
+
assert responses[0].content.role == "model"
|
777
|
+
assert responses[0].content.parts[0].text == "zero, "
|
778
|
+
assert responses[1].content.role == "model"
|
779
|
+
assert responses[1].content.parts[0].text == "one, "
|
780
|
+
assert responses[2].content.role == "model"
|
781
|
+
assert responses[2].content.parts[0].text == "two:"
|
782
|
+
assert responses[3].content.role == "model"
|
783
|
+
assert responses[3].content.parts[0].function_call.name == "test_function"
|
784
|
+
assert responses[3].content.parts[0].function_call.args == {
|
785
|
+
"test_arg": "test_value"
|
786
|
+
}
|
787
|
+
assert responses[3].content.parts[0].function_call.id == "test_tool_call_id"
|
788
|
+
mock_completion.assert_called_once()
|
789
|
+
|
790
|
+
_, kwargs = mock_completion.call_args
|
791
|
+
assert kwargs["model"] == "test_model"
|
792
|
+
assert kwargs["messages"][0]["role"] == "user"
|
793
|
+
assert kwargs["messages"][0]["content"] == "Test prompt"
|
794
|
+
assert kwargs["tools"][0]["function"]["name"] == "test_function"
|
795
|
+
assert (
|
796
|
+
kwargs["tools"][0]["function"]["description"]
|
797
|
+
== "Test function description"
|
798
|
+
)
|
799
|
+
assert (
|
800
|
+
kwargs["tools"][0]["function"]["parameters"]["properties"]["test_arg"][
|
801
|
+
"type"
|
802
|
+
]
|
803
|
+
== "string"
|
804
|
+
)
|