google-adk 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/adk/__init__.py +20 -0
- google/adk/agents/__init__.py +32 -0
- google/adk/agents/active_streaming_tool.py +38 -0
- google/adk/agents/base_agent.py +345 -0
- google/adk/agents/callback_context.py +112 -0
- google/adk/agents/invocation_context.py +181 -0
- google/adk/agents/langgraph_agent.py +140 -0
- google/adk/agents/live_request_queue.py +64 -0
- google/adk/agents/llm_agent.py +376 -0
- google/adk/agents/loop_agent.py +62 -0
- google/adk/agents/parallel_agent.py +96 -0
- google/adk/agents/readonly_context.py +46 -0
- google/adk/agents/remote_agent.py +50 -0
- google/adk/agents/run_config.py +87 -0
- google/adk/agents/sequential_agent.py +45 -0
- google/adk/agents/transcription_entry.py +34 -0
- google/adk/artifacts/__init__.py +23 -0
- google/adk/artifacts/base_artifact_service.py +128 -0
- google/adk/artifacts/gcs_artifact_service.py +195 -0
- google/adk/artifacts/in_memory_artifact_service.py +133 -0
- google/adk/auth/__init__.py +22 -0
- google/adk/auth/auth_credential.py +220 -0
- google/adk/auth/auth_handler.py +268 -0
- google/adk/auth/auth_preprocessor.py +116 -0
- google/adk/auth/auth_schemes.py +67 -0
- google/adk/auth/auth_tool.py +55 -0
- google/adk/cli/__init__.py +15 -0
- google/adk/cli/__main__.py +18 -0
- google/adk/cli/agent_graph.py +122 -0
- google/adk/cli/browser/adk_favicon.svg +17 -0
- google/adk/cli/browser/assets/audio-processor.js +51 -0
- google/adk/cli/browser/assets/config/runtime-config.json +3 -0
- google/adk/cli/browser/index.html +33 -0
- google/adk/cli/browser/main-XUU6OGCC.js +75 -0
- google/adk/cli/browser/polyfills-FFHMD2TL.js +18 -0
- google/adk/cli/browser/styles-4VDSPQ37.css +17 -0
- google/adk/cli/cli.py +181 -0
- google/adk/cli/cli_deploy.py +181 -0
- google/adk/cli/cli_eval.py +282 -0
- google/adk/cli/cli_tools_click.py +479 -0
- google/adk/cli/fast_api.py +774 -0
- google/adk/cli/media_streamer/__init__.py +19 -0
- google/adk/cli/media_streamer/index.html +228 -0
- google/adk/cli/utils/__init__.py +49 -0
- google/adk/cli/utils/envs.py +57 -0
- google/adk/cli/utils/evals.py +93 -0
- google/adk/cli/utils/logs.py +72 -0
- google/adk/code_executors/__init__.py +49 -0
- google/adk/code_executors/base_code_executor.py +97 -0
- google/adk/code_executors/code_execution_utils.py +256 -0
- google/adk/code_executors/code_executor_context.py +202 -0
- google/adk/code_executors/container_code_executor.py +196 -0
- google/adk/code_executors/unsafe_local_code_executor.py +71 -0
- google/adk/code_executors/vertex_ai_code_executor.py +234 -0
- google/adk/evaluation/__init__.py +31 -0
- google/adk/evaluation/agent_evaluator.py +329 -0
- google/adk/evaluation/evaluation_constants.py +24 -0
- google/adk/evaluation/evaluation_generator.py +270 -0
- google/adk/evaluation/response_evaluator.py +135 -0
- google/adk/evaluation/trajectory_evaluator.py +184 -0
- google/adk/events/__init__.py +21 -0
- google/adk/events/event.py +130 -0
- google/adk/events/event_actions.py +55 -0
- google/adk/examples/__init__.py +28 -0
- google/adk/examples/base_example_provider.py +35 -0
- google/adk/examples/example.py +27 -0
- google/adk/examples/example_util.py +123 -0
- google/adk/examples/vertex_ai_example_store.py +104 -0
- google/adk/flows/__init__.py +14 -0
- google/adk/flows/llm_flows/__init__.py +20 -0
- google/adk/flows/llm_flows/_base_llm_processor.py +52 -0
- google/adk/flows/llm_flows/_code_execution.py +458 -0
- google/adk/flows/llm_flows/_nl_planning.py +129 -0
- google/adk/flows/llm_flows/agent_transfer.py +132 -0
- google/adk/flows/llm_flows/audio_transcriber.py +109 -0
- google/adk/flows/llm_flows/auto_flow.py +49 -0
- google/adk/flows/llm_flows/base_llm_flow.py +559 -0
- google/adk/flows/llm_flows/basic.py +72 -0
- google/adk/flows/llm_flows/contents.py +370 -0
- google/adk/flows/llm_flows/functions.py +486 -0
- google/adk/flows/llm_flows/identity.py +47 -0
- google/adk/flows/llm_flows/instructions.py +137 -0
- google/adk/flows/llm_flows/single_flow.py +57 -0
- google/adk/memory/__init__.py +35 -0
- google/adk/memory/base_memory_service.py +74 -0
- google/adk/memory/in_memory_memory_service.py +62 -0
- google/adk/memory/vertex_ai_rag_memory_service.py +177 -0
- google/adk/models/__init__.py +31 -0
- google/adk/models/anthropic_llm.py +243 -0
- google/adk/models/base_llm.py +87 -0
- google/adk/models/base_llm_connection.py +76 -0
- google/adk/models/gemini_llm_connection.py +200 -0
- google/adk/models/google_llm.py +331 -0
- google/adk/models/lite_llm.py +673 -0
- google/adk/models/llm_request.py +98 -0
- google/adk/models/llm_response.py +111 -0
- google/adk/models/registry.py +102 -0
- google/adk/planners/__init__.py +23 -0
- google/adk/planners/base_planner.py +66 -0
- google/adk/planners/built_in_planner.py +75 -0
- google/adk/planners/plan_re_act_planner.py +208 -0
- google/adk/runners.py +456 -0
- google/adk/sessions/__init__.py +41 -0
- google/adk/sessions/base_session_service.py +133 -0
- google/adk/sessions/database_session_service.py +522 -0
- google/adk/sessions/in_memory_session_service.py +206 -0
- google/adk/sessions/session.py +54 -0
- google/adk/sessions/state.py +71 -0
- google/adk/sessions/vertex_ai_session_service.py +356 -0
- google/adk/telemetry.py +189 -0
- google/adk/tests/__init__.py +14 -0
- google/adk/tests/integration/.env.example +10 -0
- google/adk/tests/integration/__init__.py +18 -0
- google/adk/tests/integration/conftest.py +119 -0
- google/adk/tests/integration/fixture/__init__.py +14 -0
- google/adk/tests/integration/fixture/agent_with_config/__init__.py +15 -0
- google/adk/tests/integration/fixture/agent_with_config/agent.py +88 -0
- google/adk/tests/integration/fixture/callback_agent/__init__.py +15 -0
- google/adk/tests/integration/fixture/callback_agent/agent.py +105 -0
- google/adk/tests/integration/fixture/context_update_test/OWNERS +1 -0
- google/adk/tests/integration/fixture/context_update_test/__init__.py +15 -0
- google/adk/tests/integration/fixture/context_update_test/agent.py +43 -0
- google/adk/tests/integration/fixture/context_update_test/successful_test.session.json +582 -0
- google/adk/tests/integration/fixture/context_variable_agent/__init__.py +15 -0
- google/adk/tests/integration/fixture/context_variable_agent/agent.py +115 -0
- google/adk/tests/integration/fixture/customer_support_ma/__init__.py +15 -0
- google/adk/tests/integration/fixture/customer_support_ma/agent.py +172 -0
- google/adk/tests/integration/fixture/ecommerce_customer_service_agent/__init__.py +15 -0
- google/adk/tests/integration/fixture/ecommerce_customer_service_agent/agent.py +338 -0
- google/adk/tests/integration/fixture/ecommerce_customer_service_agent/order_query.test.json +69 -0
- google/adk/tests/integration/fixture/ecommerce_customer_service_agent/test_config.json +6 -0
- google/adk/tests/integration/fixture/flow_complex_spark/__init__.py +15 -0
- google/adk/tests/integration/fixture/flow_complex_spark/agent.py +182 -0
- google/adk/tests/integration/fixture/flow_complex_spark/sample.debug.log +243 -0
- google/adk/tests/integration/fixture/flow_complex_spark/sample.session.json +190 -0
- google/adk/tests/integration/fixture/hello_world_agent/__init__.py +15 -0
- google/adk/tests/integration/fixture/hello_world_agent/agent.py +95 -0
- google/adk/tests/integration/fixture/hello_world_agent/roll_die.test.json +24 -0
- google/adk/tests/integration/fixture/hello_world_agent/test_config.json +6 -0
- google/adk/tests/integration/fixture/home_automation_agent/__init__.py +15 -0
- google/adk/tests/integration/fixture/home_automation_agent/agent.py +304 -0
- google/adk/tests/integration/fixture/home_automation_agent/simple_test.test.json +5 -0
- google/adk/tests/integration/fixture/home_automation_agent/simple_test2.test.json +5 -0
- google/adk/tests/integration/fixture/home_automation_agent/test_config.json +5 -0
- google/adk/tests/integration/fixture/home_automation_agent/test_files/dependent_tool_calls.test.json +18 -0
- google/adk/tests/integration/fixture/home_automation_agent/test_files/memorizing_past_events/eval_data.test.json +17 -0
- google/adk/tests/integration/fixture/home_automation_agent/test_files/memorizing_past_events/test_config.json +6 -0
- google/adk/tests/integration/fixture/home_automation_agent/test_files/simple_multi_turn_conversation.test.json +18 -0
- google/adk/tests/integration/fixture/home_automation_agent/test_files/simple_test.test.json +17 -0
- google/adk/tests/integration/fixture/home_automation_agent/test_files/simple_test2.test.json +5 -0
- google/adk/tests/integration/fixture/home_automation_agent/test_files/test_config.json +5 -0
- google/adk/tests/integration/fixture/tool_agent/__init__.py +15 -0
- google/adk/tests/integration/fixture/tool_agent/agent.py +218 -0
- google/adk/tests/integration/fixture/tool_agent/files/Agent_test_plan.pdf +0 -0
- google/adk/tests/integration/fixture/trip_planner_agent/__init__.py +15 -0
- google/adk/tests/integration/fixture/trip_planner_agent/agent.py +110 -0
- google/adk/tests/integration/fixture/trip_planner_agent/initial.session.json +13 -0
- google/adk/tests/integration/fixture/trip_planner_agent/test_config.json +5 -0
- google/adk/tests/integration/fixture/trip_planner_agent/test_files/initial.session.json +13 -0
- google/adk/tests/integration/fixture/trip_planner_agent/test_files/test_config.json +5 -0
- google/adk/tests/integration/fixture/trip_planner_agent/test_files/trip_inquiry_sub_agent.test.json +7 -0
- google/adk/tests/integration/fixture/trip_planner_agent/trip_inquiry.test.json +19 -0
- google/adk/tests/integration/models/__init__.py +14 -0
- google/adk/tests/integration/models/test_google_llm.py +65 -0
- google/adk/tests/integration/test_callback.py +70 -0
- google/adk/tests/integration/test_context_variable.py +67 -0
- google/adk/tests/integration/test_evalute_agent_in_fixture.py +76 -0
- google/adk/tests/integration/test_multi_agent.py +28 -0
- google/adk/tests/integration/test_multi_turn.py +42 -0
- google/adk/tests/integration/test_single_agent.py +23 -0
- google/adk/tests/integration/test_sub_agent.py +26 -0
- google/adk/tests/integration/test_system_instruction.py +177 -0
- google/adk/tests/integration/test_tools.py +287 -0
- google/adk/tests/integration/test_with_test_file.py +34 -0
- google/adk/tests/integration/tools/__init__.py +14 -0
- google/adk/tests/integration/utils/__init__.py +16 -0
- google/adk/tests/integration/utils/asserts.py +75 -0
- google/adk/tests/integration/utils/test_runner.py +97 -0
- google/adk/tests/unittests/__init__.py +14 -0
- google/adk/tests/unittests/agents/__init__.py +14 -0
- google/adk/tests/unittests/agents/test_base_agent.py +407 -0
- google/adk/tests/unittests/agents/test_langgraph_agent.py +191 -0
- google/adk/tests/unittests/agents/test_llm_agent_callbacks.py +138 -0
- google/adk/tests/unittests/agents/test_llm_agent_fields.py +231 -0
- google/adk/tests/unittests/agents/test_loop_agent.py +136 -0
- google/adk/tests/unittests/agents/test_parallel_agent.py +92 -0
- google/adk/tests/unittests/agents/test_sequential_agent.py +114 -0
- google/adk/tests/unittests/artifacts/__init__.py +14 -0
- google/adk/tests/unittests/artifacts/test_artifact_service.py +276 -0
- google/adk/tests/unittests/auth/test_auth_handler.py +575 -0
- google/adk/tests/unittests/conftest.py +73 -0
- google/adk/tests/unittests/fast_api/__init__.py +14 -0
- google/adk/tests/unittests/fast_api/test_fast_api.py +269 -0
- google/adk/tests/unittests/flows/__init__.py +14 -0
- google/adk/tests/unittests/flows/llm_flows/__init__.py +14 -0
- google/adk/tests/unittests/flows/llm_flows/_test_examples.py +142 -0
- google/adk/tests/unittests/flows/llm_flows/test_agent_transfer.py +311 -0
- google/adk/tests/unittests/flows/llm_flows/test_functions_long_running.py +244 -0
- google/adk/tests/unittests/flows/llm_flows/test_functions_request_euc.py +346 -0
- google/adk/tests/unittests/flows/llm_flows/test_functions_sequential.py +93 -0
- google/adk/tests/unittests/flows/llm_flows/test_functions_simple.py +258 -0
- google/adk/tests/unittests/flows/llm_flows/test_identity.py +66 -0
- google/adk/tests/unittests/flows/llm_flows/test_instructions.py +164 -0
- google/adk/tests/unittests/flows/llm_flows/test_model_callbacks.py +142 -0
- google/adk/tests/unittests/flows/llm_flows/test_other_configs.py +46 -0
- google/adk/tests/unittests/flows/llm_flows/test_tool_callbacks.py +269 -0
- google/adk/tests/unittests/models/__init__.py +14 -0
- google/adk/tests/unittests/models/test_google_llm.py +224 -0
- google/adk/tests/unittests/models/test_litellm.py +804 -0
- google/adk/tests/unittests/models/test_models.py +60 -0
- google/adk/tests/unittests/sessions/__init__.py +14 -0
- google/adk/tests/unittests/sessions/test_session_service.py +227 -0
- google/adk/tests/unittests/sessions/test_vertex_ai_session_service.py +246 -0
- google/adk/tests/unittests/streaming/__init__.py +14 -0
- google/adk/tests/unittests/streaming/test_streaming.py +50 -0
- google/adk/tests/unittests/tools/__init__.py +14 -0
- google/adk/tests/unittests/tools/apihub_tool/clients/test_apihub_client.py +499 -0
- google/adk/tests/unittests/tools/apihub_tool/test_apihub_toolset.py +204 -0
- google/adk/tests/unittests/tools/application_integration_tool/clients/test_connections_client.py +600 -0
- google/adk/tests/unittests/tools/application_integration_tool/clients/test_integration_client.py +630 -0
- google/adk/tests/unittests/tools/application_integration_tool/test_application_integration_toolset.py +345 -0
- google/adk/tests/unittests/tools/google_api_tool/__init__.py +13 -0
- google/adk/tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py +657 -0
- google/adk/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_auto_auth_credential_exchanger.py +145 -0
- google/adk/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_base_auth_credential_exchanger.py +68 -0
- google/adk/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_oauth2_exchanger.py +153 -0
- google/adk/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py +196 -0
- google/adk/tests/unittests/tools/openapi_tool/auth/test_auth_helper.py +573 -0
- google/adk/tests/unittests/tools/openapi_tool/common/test_common.py +436 -0
- google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test.yaml +1367 -0
- google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_openapi_spec_parser.py +628 -0
- google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_openapi_toolset.py +139 -0
- google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_operation_parser.py +406 -0
- google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py +966 -0
- google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py +201 -0
- google/adk/tests/unittests/tools/retrieval/__init__.py +14 -0
- google/adk/tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py +147 -0
- google/adk/tests/unittests/tools/test_agent_tool.py +167 -0
- google/adk/tests/unittests/tools/test_base_tool.py +141 -0
- google/adk/tests/unittests/tools/test_build_function_declaration.py +277 -0
- google/adk/tests/unittests/utils.py +304 -0
- google/adk/tools/__init__.py +51 -0
- google/adk/tools/_automatic_function_calling_util.py +346 -0
- google/adk/tools/agent_tool.py +176 -0
- google/adk/tools/apihub_tool/__init__.py +19 -0
- google/adk/tools/apihub_tool/apihub_toolset.py +209 -0
- google/adk/tools/apihub_tool/clients/__init__.py +13 -0
- google/adk/tools/apihub_tool/clients/apihub_client.py +332 -0
- google/adk/tools/apihub_tool/clients/secret_client.py +115 -0
- google/adk/tools/application_integration_tool/__init__.py +19 -0
- google/adk/tools/application_integration_tool/application_integration_toolset.py +230 -0
- google/adk/tools/application_integration_tool/clients/connections_client.py +903 -0
- google/adk/tools/application_integration_tool/clients/integration_client.py +253 -0
- google/adk/tools/base_tool.py +144 -0
- google/adk/tools/built_in_code_execution_tool.py +59 -0
- google/adk/tools/crewai_tool.py +72 -0
- google/adk/tools/example_tool.py +62 -0
- google/adk/tools/exit_loop_tool.py +23 -0
- google/adk/tools/function_parameter_parse_util.py +307 -0
- google/adk/tools/function_tool.py +87 -0
- google/adk/tools/get_user_choice_tool.py +28 -0
- google/adk/tools/google_api_tool/__init__.py +14 -0
- google/adk/tools/google_api_tool/google_api_tool.py +59 -0
- google/adk/tools/google_api_tool/google_api_tool_set.py +107 -0
- google/adk/tools/google_api_tool/google_api_tool_sets.py +55 -0
- google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py +521 -0
- google/adk/tools/google_search_tool.py +68 -0
- google/adk/tools/langchain_tool.py +86 -0
- google/adk/tools/load_artifacts_tool.py +113 -0
- google/adk/tools/load_memory_tool.py +58 -0
- google/adk/tools/load_web_page.py +41 -0
- google/adk/tools/long_running_tool.py +39 -0
- google/adk/tools/mcp_tool/__init__.py +42 -0
- google/adk/tools/mcp_tool/conversion_utils.py +161 -0
- google/adk/tools/mcp_tool/mcp_tool.py +113 -0
- google/adk/tools/mcp_tool/mcp_toolset.py +272 -0
- google/adk/tools/openapi_tool/__init__.py +21 -0
- google/adk/tools/openapi_tool/auth/__init__.py +19 -0
- google/adk/tools/openapi_tool/auth/auth_helpers.py +498 -0
- google/adk/tools/openapi_tool/auth/credential_exchangers/__init__.py +25 -0
- google/adk/tools/openapi_tool/auth/credential_exchangers/auto_auth_credential_exchanger.py +105 -0
- google/adk/tools/openapi_tool/auth/credential_exchangers/base_credential_exchanger.py +55 -0
- google/adk/tools/openapi_tool/auth/credential_exchangers/oauth2_exchanger.py +117 -0
- google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py +97 -0
- google/adk/tools/openapi_tool/common/__init__.py +19 -0
- google/adk/tools/openapi_tool/common/common.py +300 -0
- google/adk/tools/openapi_tool/openapi_spec_parser/__init__.py +32 -0
- google/adk/tools/openapi_tool/openapi_spec_parser/openapi_spec_parser.py +231 -0
- google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +144 -0
- google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +260 -0
- google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +496 -0
- google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +268 -0
- google/adk/tools/preload_memory_tool.py +72 -0
- google/adk/tools/retrieval/__init__.py +36 -0
- google/adk/tools/retrieval/base_retrieval_tool.py +37 -0
- google/adk/tools/retrieval/files_retrieval.py +33 -0
- google/adk/tools/retrieval/llama_index_retrieval.py +41 -0
- google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +107 -0
- google/adk/tools/tool_context.py +90 -0
- google/adk/tools/toolbox_tool.py +46 -0
- google/adk/tools/transfer_to_agent_tool.py +21 -0
- google/adk/tools/vertex_ai_search_tool.py +96 -0
- google/adk/version.py +16 -0
- google_adk-0.0.1.dist-info/LICENSE.txt → google_adk-0.0.2.dist-info/LICENSE +32 -0
- google_adk-0.0.2.dist-info/METADATA +73 -0
- google_adk-0.0.2.dist-info/RECORD +308 -0
- {google_adk-0.0.1.dist-info → google_adk-0.0.2.dist-info}/WHEEL +1 -2
- google_adk-0.0.2.dist-info/entry_points.txt +3 -0
- agent_kit/__init__.py +0 -0
- google_adk-0.0.1.dist-info/METADATA +0 -15
- google_adk-0.0.1.dist-info/RECORD +0 -6
- google_adk-0.0.1.dist-info/top_level.txt +0 -1
@@ -0,0 +1,57 @@
|
|
1
|
+
# Copyright 2025 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""Implementation of single flow."""
|
16
|
+
|
17
|
+
import logging
|
18
|
+
|
19
|
+
from ...auth import auth_preprocessor
|
20
|
+
from . import _code_execution
|
21
|
+
from . import _nl_planning
|
22
|
+
from . import basic
|
23
|
+
from . import contents
|
24
|
+
from . import identity
|
25
|
+
from . import instructions
|
26
|
+
from .base_llm_flow import BaseLlmFlow
|
27
|
+
|
28
|
+
logger = logging.getLogger(__name__)


class SingleFlow(BaseLlmFlow):
  """LLM flow in which a single agent works only with its own tools.

  The flow considers just the agent itself and the tools attached to it.
  Sub-agents are not supported by this flow.
  """

  def __init__(self):
    super().__init__()

    request_steps = [
        basic.request_processor,
        auth_preprocessor.request_processor,
        instructions.request_processor,
        identity.request_processor,
        contents.request_processor,
        # Some NL planning implementations mark planning contents as thoughts
        # in their post-processor. Those marks have to be removed again, so
        # NL planning must run after `contents`.
        _nl_planning.request_processor,
        # Code execution must also run after `contents`, because it mutates
        # the contents in order to optimize data files.
        _code_execution.request_processor,
    ]
    response_steps = [
        _nl_planning.response_processor,
        _code_execution.response_processor,
    ]

    self.request_processors += request_steps
    self.response_processors += response_steps
|
@@ -0,0 +1,35 @@
|
|
1
|
+
# Copyright 2025 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
import logging
|
15
|
+
|
16
|
+
from .base_memory_service import BaseMemoryService
|
17
|
+
from .in_memory_memory_service import InMemoryMemoryService
|
18
|
+
|
19
|
+
logger = logging.getLogger(__name__)

__all__ = ['BaseMemoryService', 'InMemoryMemoryService']

# VertexAiRagMemoryService depends on the Vertex AI SDK, which may be absent.
# Export it only when its module imports cleanly; otherwise log and move on.
try:
  from .vertex_ai_rag_memory_service import VertexAiRagMemoryService
except ImportError:
  logger.debug(
      'The Vertex sdk is not installed. If you want to use the'
      ' VertexAiRagMemoryService please install it. If not, you can ignore this'
      ' warning.'
  )
else:
  __all__.append('VertexAiRagMemoryService')
|
@@ -0,0 +1,74 @@
|
|
1
|
+
# Copyright 2025 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import abc
|
16
|
+
|
17
|
+
from pydantic import BaseModel
|
18
|
+
from pydantic import Field
|
19
|
+
|
20
|
+
from ..events.event import Event
|
21
|
+
from ..sessions.session import Session
|
22
|
+
|
23
|
+
|
24
|
+
class MemoryResult(BaseModel):
  """A single memory retrieval result.

  Groups the retrieved events that belong to one session.

  Attributes:
    session_id: The id of the session the events came from.
    events: The retrieved events of that session.
  """

  session_id: str
  events: list[Event]
|
33
|
+
|
34
|
+
|
35
|
+
class SearchMemoryResponse(BaseModel):
  """The response returned by a memory search.

  Attributes:
    memories: The memory results matching the search query; empty when
      nothing matched.
  """

  # default_factory gives every instance its own fresh list, so results
  # appended to one response are never shared with another.
  memories: list[MemoryResult] = Field(default_factory=list)
|
42
|
+
|
43
|
+
|
44
|
+
class BaseMemoryService(abc.ABC):
  """Abstract base class for memory services.

  A memory service ingests sessions so that their contents can later be
  searched to answer user queries. Concrete subclasses must implement both
  abstract methods below before they can be instantiated.
  """

  @abc.abstractmethod
  def add_session_to_memory(self, session: Session) -> None:
    """Adds a session to the memory service.

    A session may be added multiple times during its lifetime, so
    implementations should expect repeated calls for the same session.

    Args:
      session: The session to add.
    """

  @abc.abstractmethod
  def search_memory(
      self, *, app_name: str, user_id: str, query: str
  ) -> SearchMemoryResponse:
    """Searches for sessions that match the query.

    All arguments are keyword-only.

    Args:
      app_name: The name of the application.
      user_id: The id of the user.
      query: The query to search for.

    Returns:
      A SearchMemoryResponse containing the matching memories.
    """
|
@@ -0,0 +1,62 @@
|
|
1
|
+
# Copyright 2025 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from ..events.event import Event
|
16
|
+
from ..sessions.session import Session
|
17
|
+
from .base_memory_service import BaseMemoryService
|
18
|
+
from .base_memory_service import MemoryResult
|
19
|
+
from .base_memory_service import SearchMemoryResponse
|
20
|
+
|
21
|
+
|
22
|
+
class InMemoryMemoryService(BaseMemoryService):
  """An in-memory memory service for prototyping purposes only.

  Sessions are kept in a process-local dict. Search uses case-insensitive
  keyword (substring) matching instead of semantic search.
  """

  def __init__(self):
    # Maps 'app_name/user_id/session_id' to the session's events that carry
    # content. NOTE(review): the key format is ambiguous when any component
    # contains '/'; such ids can collide or fall under another user's prefix.
    self.session_events: dict[str, list[Event]] = {}

  def add_session_to_memory(self, session: Session):
    """Stores (or replaces) the content-bearing events of ``session``.

    Re-adding a session overwrites its previous snapshot, so a session can be
    added multiple times during its lifetime.

    Args:
      session: The session to add.
    """
    key = f'{session.app_name}/{session.user_id}/{session.id}'
    self.session_events[key] = [
        event for event in session.events if event.content
    ]

  def search_memory(
      self, *, app_name: str, user_id: str, query: str
  ) -> SearchMemoryResponse:
    """Returns events from this user's sessions that mention a query word.

    Prototyping purpose only. The query is lower-cased and split on
    whitespace. An event matches when its lower-cased text contains any of
    those words as a substring.

    Args:
      app_name: The name of the application.
      user_id: The id of the user.
      query: The query to search for.

    Returns:
      A SearchMemoryResponse holding one MemoryResult per session that has at
      least one matching event, in the order the sessions were first added.
    """
    keywords = set(query.lower().split())
    prefix = f'{app_name}/{user_id}/'
    response = SearchMemoryResponse()
    for key, events in self.session_events.items():
      if not key.startswith(prefix):
        continue
      matched_events = []
      for event in events:
        if not event.content or not event.content.parts:
          continue
        text = '\n'.join(
            part.text for part in event.content.parts if part.text
        ).lower()
        if any(keyword in text for keyword in keywords):
          matched_events.append(event)
      if matched_events:
        # Strip the known prefix rather than taking the last '/'-separated
        # segment, so session ids that themselves contain '/' stay intact.
        session_id = key[len(prefix):]
        response.memories.append(
            MemoryResult(session_id=session_id, events=matched_events)
        )
    return response
|
@@ -0,0 +1,177 @@
|
|
1
|
+
# Copyright 2025 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from collections import OrderedDict
|
16
|
+
import json
|
17
|
+
import os
|
18
|
+
import tempfile
|
19
|
+
|
20
|
+
from google.genai import types
|
21
|
+
from typing_extensions import override
|
22
|
+
from vertexai.preview import rag
|
23
|
+
|
24
|
+
from ..events.event import Event
|
25
|
+
from ..sessions.session import Session
|
26
|
+
from .base_memory_service import BaseMemoryService
|
27
|
+
from .base_memory_service import MemoryResult
|
28
|
+
from .base_memory_service import SearchMemoryResponse
|
29
|
+
|
30
|
+
|
31
|
+
class VertexAiRagMemoryService(BaseMemoryService):
|
32
|
+
"""A memory service that uses Vertex AI RAG for storage and retrieval."""
|
33
|
+
|
34
|
+
def __init__(
    self,
    rag_corpus: str = None,
    similarity_top_k: int = None,
    vector_distance_threshold: float = 10,
):
  """Creates the service and its Vertex AI RAG store configuration.

  Args:
    rag_corpus: The Vertex AI RAG corpus to use, given either as
      ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}``
      or as a bare ``{rag_corpus_id}``.
    similarity_top_k: How many contexts to retrieve.
    vector_distance_threshold: Only contexts whose vector distance is smaller
      than this threshold are returned.
  """
  # The configured corpus becomes the store's single RAG resource.
  corpus_resources = [rag.RagResource(rag_corpus=rag_corpus)]
  self.vertex_rag_store = types.VertexRagStore(
      similarity_top_k=similarity_top_k,
      vector_distance_threshold=vector_distance_threshold,
      rag_resources=corpus_resources,
  )
|
55
|
+
|
56
|
+
@override
def add_session_to_memory(self, session: Session):
  """Uploads the text of ``session`` to every configured RAG resource's corpus.

  Every event with text parts becomes one JSON line holding its ``author``,
  ``timestamp`` and ``text``. Newlines inside a part are replaced by spaces,
  and the parts are joined with ".". The lines go into a temporary ``.txt``
  file, which is uploaded with display name
  ``{app_name}.{user_id}.{session_id}`` and always deleted afterwards.

  Args:
    session: The session whose events should be stored in memory.
  """
  output_lines = []
  for event in session.events:
    if not event.content or not event.content.parts:
      continue
    text_parts = [
        part.text.replace("\n", " ")
        for part in event.content.parts
        if part.text
    ]
    if text_parts:
      output_lines.append(
          json.dumps({
              "author": event.author,
              "timestamp": event.timestamp,
              "text": ".".join(text_parts),
          })
      )

  # delete=False keeps the file on disk after it is closed, so it can be
  # uploaded by path. The finally clause guarantees removal even if writing
  # or any upload raises.
  temp_file = tempfile.NamedTemporaryFile(
      mode="w", delete=False, suffix=".txt"
  )
  try:
    with temp_file:
      temp_file.write("\n".join(output_lines))
    for rag_resource in self.vertex_rag_store.rag_resources:
      rag.upload_file(
          corpus_name=rag_resource.rag_corpus,
          path=temp_file.name,
          # Temporary workaround: upload_file does not support adding
          # metadata, so the session info is stored in the display name.
          display_name=f"{session.app_name}.{session.user_id}.{session.id}",
      )
  finally:
    os.remove(temp_file.name)
|
92
|
+
|
93
|
+
@override
def search_memory(
    self, *, app_name: str, user_id: str, query: str
) -> SearchMemoryResponse:
  """Searches for sessions that match the query using rag.retrieval_query.

  Only contexts whose source display name starts with
  ``"{app_name}.{user_id}."`` (the format written by
  ``add_session_to_memory``) are considered. Each context's text is parsed
  as JSON lines into events, and overlapping event lists from the same
  session are merged.

  Args:
    app_name: Application whose memories are searched.
    user_id: User whose memories are searched.
    query: Free-text retrieval query.

  Returns:
    A ``SearchMemoryResponse`` with one ``MemoryResult`` per merged group of
    events, events sorted by timestamp.
  """
  response = rag.retrieval_query(
      text=query,
      rag_resources=self.vertex_rag_store.rag_resources,
      rag_corpora=self.vertex_rag_store.rag_corpora,
      similarity_top_k=self.vertex_rag_store.similarity_top_k,
      vector_distance_threshold=self.vertex_rag_store.vector_distance_threshold,
  )

  owner_prefix = f"{app_name}.{user_id}."
  session_events_map = OrderedDict()
  for context in response.contexts.contexts:
    display_name = context.source_display_name or ""
    # TODO: Add server side filtering by app_name and user_id.
    # Until then filter client-side so memories of another app or user are
    # never returned to the caller.
    if not display_name.startswith(owner_prefix):
      continue
    session_id = display_name.split(".")[-1]

    events = []
    for line in (context.text or "").split("\n"):
      line = line.strip()
      if not line:
        continue
      try:
        event_data = json.loads(line)
      except json.JSONDecodeError:
        # Not valid JSON (e.g. a partial record), skip this line.
        continue
      if not isinstance(event_data, dict):
        # Valid JSON but not an event record.
        continue
      try:
        timestamp = float(event_data.get("timestamp", 0))
      except (TypeError, ValueError):
        continue
      author = event_data.get("author", "")
      text = event_data.get("text", "")

      content = types.Content(parts=[types.Part(text=text)])
      events.append(
          Event(author=author, timestamp=timestamp, content=content)
      )

    if events:
      session_events_map.setdefault(session_id, []).append(events)

  # Remove overlap and combine events from the same session.
  memory_results = []
  for session_id, event_lists in session_events_map.items():
    for events in _merge_event_lists(event_lists):
      sorted_events = sorted(events, key=lambda e: e.timestamp)
      memory_results.append(
          MemoryResult(session_id=session_id, events=sorted_events)
      )
  return SearchMemoryResponse(memories=memory_results)
|
151
|
+
|
152
|
+
|
153
|
+
def _merge_event_lists(event_lists: list[list[Event]]) -> list[list[Event]]:
|
154
|
+
"""Merge event lists that have overlapping timestamps."""
|
155
|
+
merged = []
|
156
|
+
while event_lists:
|
157
|
+
current = event_lists.pop(0)
|
158
|
+
current_ts = {event.timestamp for event in current}
|
159
|
+
merge_found = True
|
160
|
+
|
161
|
+
# Keep merging until no new overlap is found.
|
162
|
+
while merge_found:
|
163
|
+
merge_found = False
|
164
|
+
remaining = []
|
165
|
+
for other in event_lists:
|
166
|
+
other_ts = {event.timestamp for event in other}
|
167
|
+
# Overlap exists, so we merge and use the merged list to check again
|
168
|
+
if current_ts & other_ts:
|
169
|
+
new_events = [e for e in other if e.timestamp not in current_ts]
|
170
|
+
current.extend(new_events)
|
171
|
+
current_ts.update(e.timestamp for e in new_events)
|
172
|
+
merge_found = True
|
173
|
+
else:
|
174
|
+
remaining.append(other)
|
175
|
+
event_lists = remaining
|
176
|
+
merged.append(current)
|
177
|
+
return merged
|
@@ -0,0 +1,31 @@
|
|
1
|
+
# Copyright 2025 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""Defines the interface to support a model."""
|
16
|
+
|
17
|
+
from .base_llm import BaseLlm
|
18
|
+
from .google_llm import Gemini
|
19
|
+
from .llm_request import LlmRequest
|
20
|
+
from .llm_response import LlmResponse
|
21
|
+
from .registry import LLMRegistry
|
22
|
+
|
23
|
+
__all__ = [
    'BaseLlm',
    'Gemini',
    'LLMRegistry',
]


# ``register`` receives only the class, so a single call is enough; repeating
# the identical call per supported pattern adds nothing. NOTE(review):
# presumably ``register`` reads ``Gemini.supported_models()`` itself -- confirm
# in registry.py.
LLMRegistry.register(Gemini)
|
@@ -0,0 +1,243 @@
|
|
1
|
+
# Copyright 2025 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""Anthropic integration for Claude models."""
|
16
|
+
|
17
|
+
from __future__ import annotations
|
18
|
+
|
19
|
+
from functools import cached_property
|
20
|
+
import logging
|
21
|
+
import os
|
22
|
+
from typing import AsyncGenerator
|
23
|
+
from typing import Generator
|
24
|
+
from typing import Iterable
|
25
|
+
from typing import Literal
|
26
|
+
from typing import Optional, Union
|
27
|
+
from typing import TYPE_CHECKING
|
28
|
+
|
29
|
+
from anthropic import AnthropicVertex
|
30
|
+
from anthropic import NOT_GIVEN
|
31
|
+
from anthropic import types as anthropic_types
|
32
|
+
from google.genai import types
|
33
|
+
from pydantic import BaseModel
|
34
|
+
from typing_extensions import override
|
35
|
+
|
36
|
+
from .base_llm import BaseLlm
|
37
|
+
from .llm_response import LlmResponse
|
38
|
+
|
39
|
+
if TYPE_CHECKING:
|
40
|
+
from .llm_request import LlmRequest
|
41
|
+
|
42
|
+
# Public API of this module.
__all__ = ["Claude"]

# Module logger; Claude responses are logged through it.
logger: logging.Logger = logging.getLogger(__name__)

# Default ``max_tokens`` value sent with Claude requests.
MAX_TOKEN = 1024
|
47
|
+
|
48
|
+
|
49
|
+
class ClaudeRequest(BaseModel):
  """Pydantic container for the pieces of a Claude request.

  NOTE(review): this model is not referenced anywhere else in this module --
  confirm whether it is still needed.
  """

  # System instruction text.
  system_instruction: str
  # Conversation turns as Anthropic ``MessageParam`` values.
  messages: Iterable[anthropic_types.MessageParam]
  # Tool definitions as Anthropic ``ToolParam`` values.
  tools: list[anthropic_types.ToolParam]
|
53
|
+
|
54
|
+
|
55
|
+
def to_claude_role(role: Optional[str]) -> Literal["user", "assistant"]:
  """Maps a google.genai content role onto an Anthropic message role.

  Args:
    role: The genai role; ``"model"`` and ``"assistant"`` denote the model.

  Returns:
    ``"assistant"`` for model turns, ``"user"`` for anything else (including
    ``None``).
  """
  is_model_turn = role in ("model", "assistant")
  return "assistant" if is_model_turn else "user"
|
59
|
+
|
60
|
+
|
61
|
+
def to_google_genai_finish_reason(
    anthropic_stop_reason: Optional[str],
) -> types.FinishReason:
  """Translates an Anthropic ``stop_reason`` into a genai finish reason.

  Args:
    anthropic_stop_reason: The ``stop_reason`` reported by Anthropic.

  Returns:
    ``"STOP"`` for normal completions (end of turn, stop sequence, tool use),
    ``"MAX_TOKENS"`` when the token limit was hit, and
    ``"FINISH_REASON_UNSPECIFIED"`` for anything else.
  """
  normal_completions = ("end_turn", "stop_sequence", "tool_use")
  if anthropic_stop_reason in normal_completions:
    return "STOP"
  return (
      "MAX_TOKENS"
      if anthropic_stop_reason == "max_tokens"
      else "FINISH_REASON_UNSPECIFIED"
  )
|
69
|
+
|
70
|
+
|
71
|
+
def part_to_message_block(
    part: types.Part,
) -> Union[
    anthropic_types.TextBlockParam,
    anthropic_types.ImageBlockParam,
    anthropic_types.ToolUseBlockParam,
    anthropic_types.ToolResultBlockParam,
]:
  """Converts one google.genai ``Part`` into an Anthropic content block.

  Args:
    part: The part to convert. Text, function-call and function-response
      parts are supported.

  Returns:
    A text, tool-use or tool-result block param.

  Raises:
    NotImplementedError: If the part has none of the supported payloads.
  """
  if part.text:
    return anthropic_types.TextBlockParam(text=part.text, type="text")
  if part.function_call:
    assert part.function_call.name

    return anthropic_types.ToolUseBlockParam(
        id=part.function_call.id or "",
        name=part.function_call.name,
        # A call without arguments may leave ``args`` unset; send an empty
        # object rather than None as the tool input.
        input=part.function_call.args or {},
        type="tool_use",
    )
  if part.function_response:
    # ``response`` may be unset; treat that as an empty payload instead of
    # failing on the lookup below.
    response = part.function_response.response or {}
    content = ""
    if response.get("result"):
      # Transformation is required because the content is a list of dict.
      # ToolResultBlockParam content doesn't support list of dict. Converting
      # to str to prevent anthropic.BadRequestError from being thrown.
      content = str(response["result"])
    return anthropic_types.ToolResultBlockParam(
        tool_use_id=part.function_response.id or "",
        type="tool_result",
        content=content,
        is_error=False,
    )
  raise NotImplementedError("Not supported yet.")
|
107
|
+
|
108
|
+
|
109
|
+
def content_to_message_param(
    content: types.Content,
) -> anthropic_types.MessageParam:
  """Builds an Anthropic message dict from a google.genai ``Content``.

  Args:
    content: The content whose role and parts are converted.

  Returns:
    A ``{"role": ..., "content": [...]}`` dict with one block per part.
  """
  claude_role = to_claude_role(content.role)
  blocks = []
  for part in content.parts or []:
    blocks.append(part_to_message_block(part))
  return {"role": claude_role, "content": blocks}
|
116
|
+
|
117
|
+
|
118
|
+
def content_block_to_part(
    content_block: anthropic_types.ContentBlock,
) -> types.Part:
  """Converts an Anthropic response content block into a genai ``Part``.

  Args:
    content_block: A text block or a tool-use block.

  Returns:
    A text part, or a function-call part carrying the tool-use id.

  Raises:
    NotImplementedError: For any other block type.
  """
  if isinstance(content_block, anthropic_types.TextBlock):
    return types.Part.from_text(text=content_block.text)
  if not isinstance(content_block, anthropic_types.ToolUseBlock):
    raise NotImplementedError("Not supported yet.")

  tool_input = content_block.input
  assert isinstance(tool_input, dict)
  call_part = types.Part.from_function_call(
      name=content_block.name, args=tool_input
  )
  # Carry the Anthropic tool-use id so the function response can refer to it.
  call_part.function_call.id = content_block.id
  return call_part
|
131
|
+
|
132
|
+
|
133
|
+
def message_to_generate_content_response(
    message: anthropic_types.Message,
) -> LlmResponse:
  """Wraps an Anthropic ``Message`` as an ``LlmResponse`` with a model turn.

  Args:
    message: The Claude response message.

  Returns:
    An ``LlmResponse`` whose content has role ``"model"`` and one part per
    content block of ``message``.
  """
  model_parts = [content_block_to_part(block) for block in message.content]
  model_content = types.Content(role="model", parts=model_parts)
  # TODO: Deal with these later -- finish_reason (via
  # to_google_genai_finish_reason(message.stop_reason)) and usage_metadata
  # (prompt/candidates/total token counts from message.usage) are not set.
  return LlmResponse(content=model_content)
|
152
|
+
|
153
|
+
|
154
|
+
def function_declaration_to_tool_param(
|
155
|
+
function_declaration: types.FunctionDeclaration,
|
156
|
+
) -> anthropic_types.ToolParam:
|
157
|
+
assert function_declaration.name
|
158
|
+
|
159
|
+
properties = {}
|
160
|
+
if (
|
161
|
+
function_declaration.parameters
|
162
|
+
and function_declaration.parameters.properties
|
163
|
+
):
|
164
|
+
for key, value in function_declaration.parameters.properties.items():
|
165
|
+
value_dict = value.model_dump(exclude_none=True)
|
166
|
+
if "type" in value_dict:
|
167
|
+
value_dict["type"] = value_dict["type"].lower()
|
168
|
+
properties[key] = value_dict
|
169
|
+
|
170
|
+
return anthropic_types.ToolParam(
|
171
|
+
name=function_declaration.name,
|
172
|
+
description=function_declaration.description or "",
|
173
|
+
input_schema={
|
174
|
+
"type": "object",
|
175
|
+
"properties": properties,
|
176
|
+
},
|
177
|
+
)
|
178
|
+
|
179
|
+
|
180
|
+
class Claude(BaseLlm):
  """Claude models served through Anthropic on Vertex AI.

  Attributes:
    model: Default Claude model name; used when the request names none.
  """

  model: str = "claude-3-5-sonnet-v2@20241022"

  @staticmethod
  @override
  def supported_models() -> list[str]:
    """Returns the model-name regexes served by this class."""
    return [r"claude-3-.*"]

  @override
  async def generate_content_async(
      self, llm_request: LlmRequest, stream: bool = False
  ) -> AsyncGenerator[LlmResponse, None]:
    """Sends ``llm_request`` to Claude and yields the converted response.

    Args:
      llm_request: The request to send. Its contents become Anthropic
        messages. The function declarations of every entry in
        ``config.tools`` become Anthropic tools.
        ``config.system_instruction`` becomes the system prompt, and
        ``config.max_output_tokens``, when set, replaces the default
        ``MAX_TOKEN`` limit.
      stream: Accepted for interface compatibility but not used; exactly one
        non-streamed response is yielded.

    Yields:
      A single ``LlmResponse`` converted from the Claude message.

    Raises:
      ValueError: If the Vertex project/location environment variables are
        missing (raised while creating the client).
    """
    import asyncio  # Local import: only this method needs it.

    config = llm_request.config
    messages = [
        content_to_message_param(content)
        for content in llm_request.contents or []
    ]

    # Gather declarations from every tool entry, not only the first one.
    declarations = []
    if config and config.tools:
      for tool in config.tools:
        declarations.extend(getattr(tool, "function_declarations", None) or [])
    tools = (
        [function_declaration_to_tool_param(d) for d in declarations]
        if declarations
        else NOT_GIVEN
    )
    tool_choice = (
        anthropic_types.ToolChoiceAutoParam(
            type="auto",
            # TODO: allow parallel tool use.
            disable_parallel_tool_use=True,
        )
        if llm_request.tools_dict
        else NOT_GIVEN
    )

    # ``config`` may be None (it is guarded above), so read it defensively.
    system_instruction = config.system_instruction if config else None
    max_output_tokens = config.max_output_tokens if config else None

    # The Anthropic Vertex client is synchronous; run the blocking HTTP call
    # in a worker thread so the event loop is not stalled while waiting.
    message = await asyncio.to_thread(
        self._anthropic_client.messages.create,
        model=llm_request.model or self.model,
        system=system_instruction or NOT_GIVEN,
        messages=messages,
        tools=tools,
        tool_choice=tool_choice,
        max_tokens=max_output_tokens or MAX_TOKEN,
    )
    logger.info(
        "Claude response: %s",
        message.model_dump_json(indent=2, exclude_none=True),
    )
    yield message_to_generate_content_response(message)

  @cached_property
  def _anthropic_client(self) -> AnthropicVertex:
    """Creates the Vertex-backed Anthropic client once per instance.

    Raises:
      ValueError: If GOOGLE_CLOUD_PROJECT or GOOGLE_CLOUD_LOCATION is unset.
    """
    if (
        "GOOGLE_CLOUD_PROJECT" not in os.environ
        or "GOOGLE_CLOUD_LOCATION" not in os.environ
    ):
      raise ValueError(
          "GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION must be set for using"
          " Anthropic on Vertex."
      )

    return AnthropicVertex(
        project_id=os.environ["GOOGLE_CLOUD_PROJECT"],
        region=os.environ["GOOGLE_CLOUD_LOCATION"],
    )
|