google-adk 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/adk/__init__.py +20 -0
- google/adk/agents/__init__.py +32 -0
- google/adk/agents/active_streaming_tool.py +38 -0
- google/adk/agents/base_agent.py +345 -0
- google/adk/agents/callback_context.py +112 -0
- google/adk/agents/invocation_context.py +181 -0
- google/adk/agents/langgraph_agent.py +140 -0
- google/adk/agents/live_request_queue.py +64 -0
- google/adk/agents/llm_agent.py +376 -0
- google/adk/agents/loop_agent.py +62 -0
- google/adk/agents/parallel_agent.py +96 -0
- google/adk/agents/readonly_context.py +46 -0
- google/adk/agents/remote_agent.py +50 -0
- google/adk/agents/run_config.py +87 -0
- google/adk/agents/sequential_agent.py +45 -0
- google/adk/agents/transcription_entry.py +34 -0
- google/adk/artifacts/__init__.py +23 -0
- google/adk/artifacts/base_artifact_service.py +128 -0
- google/adk/artifacts/gcs_artifact_service.py +195 -0
- google/adk/artifacts/in_memory_artifact_service.py +133 -0
- google/adk/auth/__init__.py +22 -0
- google/adk/auth/auth_credential.py +220 -0
- google/adk/auth/auth_handler.py +268 -0
- google/adk/auth/auth_preprocessor.py +116 -0
- google/adk/auth/auth_schemes.py +67 -0
- google/adk/auth/auth_tool.py +55 -0
- google/adk/cli/__init__.py +15 -0
- google/adk/cli/__main__.py +18 -0
- google/adk/cli/agent_graph.py +122 -0
- google/adk/cli/browser/adk_favicon.svg +17 -0
- google/adk/cli/browser/assets/audio-processor.js +51 -0
- google/adk/cli/browser/assets/config/runtime-config.json +3 -0
- google/adk/cli/browser/index.html +33 -0
- google/adk/cli/browser/main-XUU6OGCC.js +75 -0
- google/adk/cli/browser/polyfills-FFHMD2TL.js +18 -0
- google/adk/cli/browser/styles-4VDSPQ37.css +17 -0
- google/adk/cli/cli.py +181 -0
- google/adk/cli/cli_deploy.py +181 -0
- google/adk/cli/cli_eval.py +282 -0
- google/adk/cli/cli_tools_click.py +479 -0
- google/adk/cli/fast_api.py +774 -0
- google/adk/cli/media_streamer/__init__.py +19 -0
- google/adk/cli/media_streamer/index.html +228 -0
- google/adk/cli/utils/__init__.py +49 -0
- google/adk/cli/utils/envs.py +57 -0
- google/adk/cli/utils/evals.py +93 -0
- google/adk/cli/utils/logs.py +72 -0
- google/adk/code_executors/__init__.py +49 -0
- google/adk/code_executors/base_code_executor.py +97 -0
- google/adk/code_executors/code_execution_utils.py +256 -0
- google/adk/code_executors/code_executor_context.py +202 -0
- google/adk/code_executors/container_code_executor.py +196 -0
- google/adk/code_executors/unsafe_local_code_executor.py +71 -0
- google/adk/code_executors/vertex_ai_code_executor.py +234 -0
- google/adk/evaluation/__init__.py +31 -0
- google/adk/evaluation/agent_evaluator.py +329 -0
- google/adk/evaluation/evaluation_constants.py +24 -0
- google/adk/evaluation/evaluation_generator.py +270 -0
- google/adk/evaluation/response_evaluator.py +135 -0
- google/adk/evaluation/trajectory_evaluator.py +184 -0
- google/adk/events/__init__.py +21 -0
- google/adk/events/event.py +130 -0
- google/adk/events/event_actions.py +55 -0
- google/adk/examples/__init__.py +28 -0
- google/adk/examples/base_example_provider.py +35 -0
- google/adk/examples/example.py +27 -0
- google/adk/examples/example_util.py +123 -0
- google/adk/examples/vertex_ai_example_store.py +104 -0
- google/adk/flows/__init__.py +14 -0
- google/adk/flows/llm_flows/__init__.py +20 -0
- google/adk/flows/llm_flows/_base_llm_processor.py +52 -0
- google/adk/flows/llm_flows/_code_execution.py +458 -0
- google/adk/flows/llm_flows/_nl_planning.py +129 -0
- google/adk/flows/llm_flows/agent_transfer.py +132 -0
- google/adk/flows/llm_flows/audio_transcriber.py +109 -0
- google/adk/flows/llm_flows/auto_flow.py +49 -0
- google/adk/flows/llm_flows/base_llm_flow.py +559 -0
- google/adk/flows/llm_flows/basic.py +72 -0
- google/adk/flows/llm_flows/contents.py +370 -0
- google/adk/flows/llm_flows/functions.py +486 -0
- google/adk/flows/llm_flows/identity.py +47 -0
- google/adk/flows/llm_flows/instructions.py +137 -0
- google/adk/flows/llm_flows/single_flow.py +57 -0
- google/adk/memory/__init__.py +35 -0
- google/adk/memory/base_memory_service.py +74 -0
- google/adk/memory/in_memory_memory_service.py +62 -0
- google/adk/memory/vertex_ai_rag_memory_service.py +177 -0
- google/adk/models/__init__.py +31 -0
- google/adk/models/anthropic_llm.py +243 -0
- google/adk/models/base_llm.py +87 -0
- google/adk/models/base_llm_connection.py +76 -0
- google/adk/models/gemini_llm_connection.py +200 -0
- google/adk/models/google_llm.py +331 -0
- google/adk/models/lite_llm.py +673 -0
- google/adk/models/llm_request.py +98 -0
- google/adk/models/llm_response.py +111 -0
- google/adk/models/registry.py +102 -0
- google/adk/planners/__init__.py +23 -0
- google/adk/planners/base_planner.py +66 -0
- google/adk/planners/built_in_planner.py +75 -0
- google/adk/planners/plan_re_act_planner.py +208 -0
- google/adk/runners.py +456 -0
- google/adk/sessions/__init__.py +41 -0
- google/adk/sessions/base_session_service.py +133 -0
- google/adk/sessions/database_session_service.py +522 -0
- google/adk/sessions/in_memory_session_service.py +206 -0
- google/adk/sessions/session.py +54 -0
- google/adk/sessions/state.py +71 -0
- google/adk/sessions/vertex_ai_session_service.py +356 -0
- google/adk/telemetry.py +189 -0
- google/adk/tests/__init__.py +14 -0
- google/adk/tests/integration/.env.example +10 -0
- google/adk/tests/integration/__init__.py +18 -0
- google/adk/tests/integration/conftest.py +119 -0
- google/adk/tests/integration/fixture/__init__.py +14 -0
- google/adk/tests/integration/fixture/agent_with_config/__init__.py +15 -0
- google/adk/tests/integration/fixture/agent_with_config/agent.py +88 -0
- google/adk/tests/integration/fixture/callback_agent/__init__.py +15 -0
- google/adk/tests/integration/fixture/callback_agent/agent.py +105 -0
- google/adk/tests/integration/fixture/context_update_test/OWNERS +1 -0
- google/adk/tests/integration/fixture/context_update_test/__init__.py +15 -0
- google/adk/tests/integration/fixture/context_update_test/agent.py +43 -0
- google/adk/tests/integration/fixture/context_update_test/successful_test.session.json +582 -0
- google/adk/tests/integration/fixture/context_variable_agent/__init__.py +15 -0
- google/adk/tests/integration/fixture/context_variable_agent/agent.py +115 -0
- google/adk/tests/integration/fixture/customer_support_ma/__init__.py +15 -0
- google/adk/tests/integration/fixture/customer_support_ma/agent.py +172 -0
- google/adk/tests/integration/fixture/ecommerce_customer_service_agent/__init__.py +15 -0
- google/adk/tests/integration/fixture/ecommerce_customer_service_agent/agent.py +338 -0
- google/adk/tests/integration/fixture/ecommerce_customer_service_agent/order_query.test.json +69 -0
- google/adk/tests/integration/fixture/ecommerce_customer_service_agent/test_config.json +6 -0
- google/adk/tests/integration/fixture/flow_complex_spark/__init__.py +15 -0
- google/adk/tests/integration/fixture/flow_complex_spark/agent.py +182 -0
- google/adk/tests/integration/fixture/flow_complex_spark/sample.debug.log +243 -0
- google/adk/tests/integration/fixture/flow_complex_spark/sample.session.json +190 -0
- google/adk/tests/integration/fixture/hello_world_agent/__init__.py +15 -0
- google/adk/tests/integration/fixture/hello_world_agent/agent.py +95 -0
- google/adk/tests/integration/fixture/hello_world_agent/roll_die.test.json +24 -0
- google/adk/tests/integration/fixture/hello_world_agent/test_config.json +6 -0
- google/adk/tests/integration/fixture/home_automation_agent/__init__.py +15 -0
- google/adk/tests/integration/fixture/home_automation_agent/agent.py +304 -0
- google/adk/tests/integration/fixture/home_automation_agent/simple_test.test.json +5 -0
- google/adk/tests/integration/fixture/home_automation_agent/simple_test2.test.json +5 -0
- google/adk/tests/integration/fixture/home_automation_agent/test_config.json +5 -0
- google/adk/tests/integration/fixture/home_automation_agent/test_files/dependent_tool_calls.test.json +18 -0
- google/adk/tests/integration/fixture/home_automation_agent/test_files/memorizing_past_events/eval_data.test.json +17 -0
- google/adk/tests/integration/fixture/home_automation_agent/test_files/memorizing_past_events/test_config.json +6 -0
- google/adk/tests/integration/fixture/home_automation_agent/test_files/simple_multi_turn_conversation.test.json +18 -0
- google/adk/tests/integration/fixture/home_automation_agent/test_files/simple_test.test.json +17 -0
- google/adk/tests/integration/fixture/home_automation_agent/test_files/simple_test2.test.json +5 -0
- google/adk/tests/integration/fixture/home_automation_agent/test_files/test_config.json +5 -0
- google/adk/tests/integration/fixture/tool_agent/__init__.py +15 -0
- google/adk/tests/integration/fixture/tool_agent/agent.py +218 -0
- google/adk/tests/integration/fixture/tool_agent/files/Agent_test_plan.pdf +0 -0
- google/adk/tests/integration/fixture/trip_planner_agent/__init__.py +15 -0
- google/adk/tests/integration/fixture/trip_planner_agent/agent.py +110 -0
- google/adk/tests/integration/fixture/trip_planner_agent/initial.session.json +13 -0
- google/adk/tests/integration/fixture/trip_planner_agent/test_config.json +5 -0
- google/adk/tests/integration/fixture/trip_planner_agent/test_files/initial.session.json +13 -0
- google/adk/tests/integration/fixture/trip_planner_agent/test_files/test_config.json +5 -0
- google/adk/tests/integration/fixture/trip_planner_agent/test_files/trip_inquiry_sub_agent.test.json +7 -0
- google/adk/tests/integration/fixture/trip_planner_agent/trip_inquiry.test.json +19 -0
- google/adk/tests/integration/models/__init__.py +14 -0
- google/adk/tests/integration/models/test_google_llm.py +65 -0
- google/adk/tests/integration/test_callback.py +70 -0
- google/adk/tests/integration/test_context_variable.py +67 -0
- google/adk/tests/integration/test_evalute_agent_in_fixture.py +76 -0
- google/adk/tests/integration/test_multi_agent.py +28 -0
- google/adk/tests/integration/test_multi_turn.py +42 -0
- google/adk/tests/integration/test_single_agent.py +23 -0
- google/adk/tests/integration/test_sub_agent.py +26 -0
- google/adk/tests/integration/test_system_instruction.py +177 -0
- google/adk/tests/integration/test_tools.py +287 -0
- google/adk/tests/integration/test_with_test_file.py +34 -0
- google/adk/tests/integration/tools/__init__.py +14 -0
- google/adk/tests/integration/utils/__init__.py +16 -0
- google/adk/tests/integration/utils/asserts.py +75 -0
- google/adk/tests/integration/utils/test_runner.py +97 -0
- google/adk/tests/unittests/__init__.py +14 -0
- google/adk/tests/unittests/agents/__init__.py +14 -0
- google/adk/tests/unittests/agents/test_base_agent.py +407 -0
- google/adk/tests/unittests/agents/test_langgraph_agent.py +191 -0
- google/adk/tests/unittests/agents/test_llm_agent_callbacks.py +138 -0
- google/adk/tests/unittests/agents/test_llm_agent_fields.py +231 -0
- google/adk/tests/unittests/agents/test_loop_agent.py +136 -0
- google/adk/tests/unittests/agents/test_parallel_agent.py +92 -0
- google/adk/tests/unittests/agents/test_sequential_agent.py +114 -0
- google/adk/tests/unittests/artifacts/__init__.py +14 -0
- google/adk/tests/unittests/artifacts/test_artifact_service.py +276 -0
- google/adk/tests/unittests/auth/test_auth_handler.py +575 -0
- google/adk/tests/unittests/conftest.py +73 -0
- google/adk/tests/unittests/fast_api/__init__.py +14 -0
- google/adk/tests/unittests/fast_api/test_fast_api.py +269 -0
- google/adk/tests/unittests/flows/__init__.py +14 -0
- google/adk/tests/unittests/flows/llm_flows/__init__.py +14 -0
- google/adk/tests/unittests/flows/llm_flows/_test_examples.py +142 -0
- google/adk/tests/unittests/flows/llm_flows/test_agent_transfer.py +311 -0
- google/adk/tests/unittests/flows/llm_flows/test_functions_long_running.py +244 -0
- google/adk/tests/unittests/flows/llm_flows/test_functions_request_euc.py +346 -0
- google/adk/tests/unittests/flows/llm_flows/test_functions_sequential.py +93 -0
- google/adk/tests/unittests/flows/llm_flows/test_functions_simple.py +258 -0
- google/adk/tests/unittests/flows/llm_flows/test_identity.py +66 -0
- google/adk/tests/unittests/flows/llm_flows/test_instructions.py +164 -0
- google/adk/tests/unittests/flows/llm_flows/test_model_callbacks.py +142 -0
- google/adk/tests/unittests/flows/llm_flows/test_other_configs.py +46 -0
- google/adk/tests/unittests/flows/llm_flows/test_tool_callbacks.py +269 -0
- google/adk/tests/unittests/models/__init__.py +14 -0
- google/adk/tests/unittests/models/test_google_llm.py +224 -0
- google/adk/tests/unittests/models/test_litellm.py +804 -0
- google/adk/tests/unittests/models/test_models.py +60 -0
- google/adk/tests/unittests/sessions/__init__.py +14 -0
- google/adk/tests/unittests/sessions/test_session_service.py +227 -0
- google/adk/tests/unittests/sessions/test_vertex_ai_session_service.py +246 -0
- google/adk/tests/unittests/streaming/__init__.py +14 -0
- google/adk/tests/unittests/streaming/test_streaming.py +50 -0
- google/adk/tests/unittests/tools/__init__.py +14 -0
- google/adk/tests/unittests/tools/apihub_tool/clients/test_apihub_client.py +499 -0
- google/adk/tests/unittests/tools/apihub_tool/test_apihub_toolset.py +204 -0
- google/adk/tests/unittests/tools/application_integration_tool/clients/test_connections_client.py +600 -0
- google/adk/tests/unittests/tools/application_integration_tool/clients/test_integration_client.py +630 -0
- google/adk/tests/unittests/tools/application_integration_tool/test_application_integration_toolset.py +345 -0
- google/adk/tests/unittests/tools/google_api_tool/__init__.py +13 -0
- google/adk/tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py +657 -0
- google/adk/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_auto_auth_credential_exchanger.py +145 -0
- google/adk/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_base_auth_credential_exchanger.py +68 -0
- google/adk/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_oauth2_exchanger.py +153 -0
- google/adk/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py +196 -0
- google/adk/tests/unittests/tools/openapi_tool/auth/test_auth_helper.py +573 -0
- google/adk/tests/unittests/tools/openapi_tool/common/test_common.py +436 -0
- google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test.yaml +1367 -0
- google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_openapi_spec_parser.py +628 -0
- google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_openapi_toolset.py +139 -0
- google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_operation_parser.py +406 -0
- google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py +966 -0
- google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py +201 -0
- google/adk/tests/unittests/tools/retrieval/__init__.py +14 -0
- google/adk/tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py +147 -0
- google/adk/tests/unittests/tools/test_agent_tool.py +167 -0
- google/adk/tests/unittests/tools/test_base_tool.py +141 -0
- google/adk/tests/unittests/tools/test_build_function_declaration.py +277 -0
- google/adk/tests/unittests/utils.py +304 -0
- google/adk/tools/__init__.py +51 -0
- google/adk/tools/_automatic_function_calling_util.py +346 -0
- google/adk/tools/agent_tool.py +176 -0
- google/adk/tools/apihub_tool/__init__.py +19 -0
- google/adk/tools/apihub_tool/apihub_toolset.py +209 -0
- google/adk/tools/apihub_tool/clients/__init__.py +13 -0
- google/adk/tools/apihub_tool/clients/apihub_client.py +332 -0
- google/adk/tools/apihub_tool/clients/secret_client.py +115 -0
- google/adk/tools/application_integration_tool/__init__.py +19 -0
- google/adk/tools/application_integration_tool/application_integration_toolset.py +230 -0
- google/adk/tools/application_integration_tool/clients/connections_client.py +903 -0
- google/adk/tools/application_integration_tool/clients/integration_client.py +253 -0
- google/adk/tools/base_tool.py +144 -0
- google/adk/tools/built_in_code_execution_tool.py +59 -0
- google/adk/tools/crewai_tool.py +72 -0
- google/adk/tools/example_tool.py +62 -0
- google/adk/tools/exit_loop_tool.py +23 -0
- google/adk/tools/function_parameter_parse_util.py +307 -0
- google/adk/tools/function_tool.py +87 -0
- google/adk/tools/get_user_choice_tool.py +28 -0
- google/adk/tools/google_api_tool/__init__.py +14 -0
- google/adk/tools/google_api_tool/google_api_tool.py +59 -0
- google/adk/tools/google_api_tool/google_api_tool_set.py +107 -0
- google/adk/tools/google_api_tool/google_api_tool_sets.py +55 -0
- google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py +521 -0
- google/adk/tools/google_search_tool.py +68 -0
- google/adk/tools/langchain_tool.py +86 -0
- google/adk/tools/load_artifacts_tool.py +113 -0
- google/adk/tools/load_memory_tool.py +58 -0
- google/adk/tools/load_web_page.py +41 -0
- google/adk/tools/long_running_tool.py +39 -0
- google/adk/tools/mcp_tool/__init__.py +42 -0
- google/adk/tools/mcp_tool/conversion_utils.py +161 -0
- google/adk/tools/mcp_tool/mcp_tool.py +113 -0
- google/adk/tools/mcp_tool/mcp_toolset.py +272 -0
- google/adk/tools/openapi_tool/__init__.py +21 -0
- google/adk/tools/openapi_tool/auth/__init__.py +19 -0
- google/adk/tools/openapi_tool/auth/auth_helpers.py +498 -0
- google/adk/tools/openapi_tool/auth/credential_exchangers/__init__.py +25 -0
- google/adk/tools/openapi_tool/auth/credential_exchangers/auto_auth_credential_exchanger.py +105 -0
- google/adk/tools/openapi_tool/auth/credential_exchangers/base_credential_exchanger.py +55 -0
- google/adk/tools/openapi_tool/auth/credential_exchangers/oauth2_exchanger.py +117 -0
- google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py +97 -0
- google/adk/tools/openapi_tool/common/__init__.py +19 -0
- google/adk/tools/openapi_tool/common/common.py +300 -0
- google/adk/tools/openapi_tool/openapi_spec_parser/__init__.py +32 -0
- google/adk/tools/openapi_tool/openapi_spec_parser/openapi_spec_parser.py +231 -0
- google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +144 -0
- google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +260 -0
- google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +496 -0
- google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +268 -0
- google/adk/tools/preload_memory_tool.py +72 -0
- google/adk/tools/retrieval/__init__.py +36 -0
- google/adk/tools/retrieval/base_retrieval_tool.py +37 -0
- google/adk/tools/retrieval/files_retrieval.py +33 -0
- google/adk/tools/retrieval/llama_index_retrieval.py +41 -0
- google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +107 -0
- google/adk/tools/tool_context.py +90 -0
- google/adk/tools/toolbox_tool.py +46 -0
- google/adk/tools/transfer_to_agent_tool.py +21 -0
- google/adk/tools/vertex_ai_search_tool.py +96 -0
- google/adk/version.py +16 -0
- google_adk-0.0.1.dist-info/LICENSE.txt → google_adk-0.0.2.dist-info/LICENSE +32 -0
- google_adk-0.0.2.dist-info/METADATA +73 -0
- google_adk-0.0.2.dist-info/RECORD +308 -0
- {google_adk-0.0.1.dist-info → google_adk-0.0.2.dist-info}/WHEEL +1 -2
- google_adk-0.0.2.dist-info/entry_points.txt +3 -0
- agent_kit/__init__.py +0 -0
- google_adk-0.0.1.dist-info/METADATA +0 -15
- google_adk-0.0.1.dist-info/RECORD +0 -6
- google_adk-0.0.1.dist-info/top_level.txt +0 -1
@@ -0,0 +1,522 @@
|
|
1
|
+
# Copyright 2025 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import copy
|
16
|
+
from datetime import datetime
|
17
|
+
import json
|
18
|
+
import logging
|
19
|
+
from typing import Any
|
20
|
+
from typing import Optional
|
21
|
+
import uuid
|
22
|
+
|
23
|
+
from sqlalchemy import delete
|
24
|
+
from sqlalchemy import Dialect
|
25
|
+
from sqlalchemy import ForeignKeyConstraint
|
26
|
+
from sqlalchemy import func
|
27
|
+
from sqlalchemy import select
|
28
|
+
from sqlalchemy import Text
|
29
|
+
from sqlalchemy.dialects import postgresql
|
30
|
+
from sqlalchemy.engine import create_engine
|
31
|
+
from sqlalchemy.engine import Engine
|
32
|
+
from sqlalchemy.ext.mutable import MutableDict
|
33
|
+
from sqlalchemy.inspection import inspect
|
34
|
+
from sqlalchemy.orm import DeclarativeBase
|
35
|
+
from sqlalchemy.orm import Mapped
|
36
|
+
from sqlalchemy.orm import mapped_column
|
37
|
+
from sqlalchemy.orm import relationship
|
38
|
+
from sqlalchemy.orm import Session as DatabaseSessionFactory
|
39
|
+
from sqlalchemy.orm import sessionmaker
|
40
|
+
from sqlalchemy.schema import MetaData
|
41
|
+
from sqlalchemy.types import DateTime
|
42
|
+
from sqlalchemy.types import PickleType
|
43
|
+
from sqlalchemy.types import String
|
44
|
+
from sqlalchemy.types import TypeDecorator
|
45
|
+
from typing_extensions import override
|
46
|
+
from tzlocal import get_localzone
|
47
|
+
|
48
|
+
from ..events.event import Event
|
49
|
+
from .base_session_service import BaseSessionService
|
50
|
+
from .base_session_service import GetSessionConfig
|
51
|
+
from .base_session_service import ListEventsResponse
|
52
|
+
from .base_session_service import ListSessionsResponse
|
53
|
+
from .session import Session
|
54
|
+
from .state import State
|
55
|
+
|
56
|
+
logger = logging.getLogger(__name__)
|
57
|
+
|
58
|
+
|
59
|
+
class DynamicJSON(TypeDecorator):
  """A JSON-like column type that adapts to the database dialect.

  On PostgreSQL the value is stored in a native ``JSONB`` column and handed
  to the driver unchanged; on every other dialect it is stored as ``TEXT``
  and (de)serialized with ``json.dumps`` / ``json.loads``. ``None`` is always
  passed through so it maps to SQL NULL.
  """

  impl = Text  # Default implementation is TEXT

  # This type carries no per-instance configuration that changes the SQL it
  # renders, so it is safe for SQLAlchemy's statement cache. Without this
  # flag SQLAlchemy emits a warning and disables caching for every statement
  # that involves the type.
  cache_ok = True

  def load_dialect_impl(self, dialect: Dialect):
    """Returns the concrete column type: JSONB on PostgreSQL, else TEXT."""
    if dialect.name == "postgresql":
      return dialect.type_descriptor(postgresql.JSONB)
    return dialect.type_descriptor(Text)

  def process_bind_param(self, value, dialect: Dialect):
    """Converts a Python value into the value sent to the database."""
    if value is None or dialect.name == "postgresql":
      # None -> SQL NULL; JSONB handles the Python object directly.
      return value
    # Serialize to a JSON string for TEXT storage.
    return json.dumps(value)

  def process_result_value(self, value, dialect: Dialect):
    """Converts a database value back into a Python object."""
    if value is None or dialect.name == "postgresql":
      # NULL -> None; JSONB already returns the decoded object.
      return value
    # Deserialize the JSON string stored in TEXT.
    return json.loads(value)
|
88
|
+
|
89
|
+
|
90
|
+
class Base(DeclarativeBase):
  """Declarative base shared by every storage table of this service.

  Each ORM model derived from it registers its table on ``Base.metadata``,
  which is what the session service uses to create the schema.
  """
|
93
|
+
|
94
|
+
|
95
|
+
class StorageSession(Base):
  """Represents a session stored in the database.

  A row is identified by the composite primary key (app_name, user_id, id).
  """

  __tablename__ = "sessions"

  app_name: Mapped[str] = mapped_column(String, primary_key=True)
  user_id: Mapped[str] = mapped_column(String, primary_key=True)
  # Falls back to a random UUID4 string when no id is supplied.
  id: Mapped[str] = mapped_column(
      String, primary_key=True, default=lambda: str(uuid.uuid4())
  )

  # Session-scoped state. MutableDict tracks in-place changes so they are
  # flushed. `default=dict` builds a fresh dict for every row; a literal `{}`
  # would be one object shared by all inserts relying on the default (and
  # SQLAlchemy copies that same object into the instance after flush).
  state: Mapped[dict] = mapped_column(
      MutableDict.as_mutable(DynamicJSON), default=dict
  )

  # Timezone-naive DateTime columns; defaults are evaluated by the database.
  create_time: Mapped[datetime] = mapped_column(DateTime(), default=func.now())
  update_time: Mapped[datetime] = mapped_column(
      DateTime(), default=func.now(), onupdate=func.now()
  )

  storage_events: Mapped[list["StorageEvent"]] = relationship(
      "StorageEvent",
      back_populates="storage_session",
  )

  def __repr__(self):
    return f"<StorageSession(id={self.id}, update_time={self.update_time})>"
|
121
|
+
|
122
|
+
|
123
|
+
class StorageEvent(Base):
  """Represents an event stored in the database.

  Each event belongs to the session identified by (app_name, user_id,
  session_id). The foreign key is declared with ON DELETE CASCADE, so the
  database removes events when their session row is deleted.
  """

  __tablename__ = "events"

  id: Mapped[str] = mapped_column(String, primary_key=True)
  app_name: Mapped[str] = mapped_column(String, primary_key=True)
  user_id: Mapped[str] = mapped_column(String, primary_key=True)
  session_id: Mapped[str] = mapped_column(String, primary_key=True)

  invocation_id: Mapped[str] = mapped_column(String)
  author: Mapped[str] = mapped_column(String)
  branch: Mapped[Optional[str]] = mapped_column(String, nullable=True)
  timestamp: Mapped[datetime] = mapped_column(DateTime(), default=func.now())
  # A non-Optional `Mapped[dict]` makes SQLAlchemy infer NOT NULL, which
  # would reject events without content (presumably Event.content is
  # optional — confirm against the Event model). Note that create_all() does
  # not alter tables that already exist.
  content: Mapped[Optional[dict]] = mapped_column(DynamicJSON, nullable=True)
  # NOTE(review): PickleType unpickles on read, so it is only safe as long as
  # the database contents are trusted.
  actions: Mapped[dict] = mapped_column(PickleType)

  storage_session: Mapped[StorageSession] = relationship(
      "StorageSession",
      back_populates="storage_events",
  )

  __table_args__ = (
      ForeignKeyConstraint(
          ["app_name", "user_id", "session_id"],
          ["sessions.app_name", "sessions.user_id", "sessions.id"],
          ondelete="CASCADE",
      ),
  )
|
151
|
+
|
152
|
+
|
153
|
+
class StorageAppState(Base):
  """Represents an app state stored in the database.

  One row per application, keyed by app_name.
  """

  __tablename__ = "app_states"

  app_name: Mapped[str] = mapped_column(String, primary_key=True)
  # `default=dict` gives each row its own dict; a literal `{}` default would
  # be a single object shared by every insert that relies on it.
  state: Mapped[dict] = mapped_column(
      MutableDict.as_mutable(DynamicJSON), default=dict
  )
  update_time: Mapped[datetime] = mapped_column(
      DateTime(), default=func.now(), onupdate=func.now()
  )
|
164
|
+
|
165
|
+
|
166
|
+
class StorageUserState(Base):
  """Represents a user state stored in the database.

  One row per (app_name, user_id) pair.
  """

  __tablename__ = "user_states"

  app_name: Mapped[str] = mapped_column(String, primary_key=True)
  user_id: Mapped[str] = mapped_column(String, primary_key=True)
  # `default=dict` gives each row its own dict; a literal `{}` default would
  # be a single object shared by every insert that relies on it.
  state: Mapped[dict] = mapped_column(
      MutableDict.as_mutable(DynamicJSON), default=dict
  )
  update_time: Mapped[datetime] = mapped_column(
      DateTime(), default=func.now(), onupdate=func.now()
  )
|
178
|
+
|
179
|
+
|
180
|
+
class DatabaseSessionService(BaseSessionService):
|
181
|
+
"""A session service that uses a database for storage."""
|
182
|
+
|
183
|
+
def __init__(self, db_url: str):
  """Creates the database engine and makes sure all tables exist.

  Args:
    db_url: The database URL to connect to, in SQLAlchemy URL form
      ``backend[+driver]://...``. The backend must be PostgreSQL, MySQL or
      SQLite; a driver suffix such as ``postgresql+psycopg2`` is accepted.

  Raises:
    ValueError: If the URL's backend is not supported.
  """
  # 1. Create DB engine for db connection
  # 2. Create all tables based on schema
  # 3. Initialize all properties

  supported_dialects = ["postgresql", "mysql", "sqlite"]
  # Compare only the backend name so driver-qualified URLs are not rejected.
  dialect = db_url.split("://", 1)[0].split("+", 1)[0]

  if dialect not in supported_dialects:
    raise ValueError(f"Unsupported database URL: {db_url}")
  db_engine = create_engine(db_url)

  # The DateTime() columns are timezone-naive, and datetime.timestamp()
  # interprets naive values in local time, so the local zone is logged to
  # help diagnose timestamp offsets.
  local_timezone = get_localzone()
  logger.info("Local timezone: %s", local_timezone)

  self.db_engine: Engine = db_engine
  self.metadata: MetaData = MetaData()
  self.inspector = inspect(self.db_engine)

  # Factory that opens new ORM sessions bound to the engine.
  self.DatabaseSessionFactory: sessionmaker[DatabaseSessionFactory] = (
      sessionmaker(bind=self.db_engine)
  )

  # Uncomment to recreate DB every time
  # Base.metadata.drop_all(self.db_engine)
  # create_all() only creates missing tables; existing ones are not altered.
  Base.metadata.create_all(self.db_engine)
|
216
|
+
|
217
|
+
@override
def create_session(
    self,
    *,
    app_name: str,
    user_id: str,
    state: Optional[dict[str, Any]] = None,
    session_id: Optional[str] = None,
) -> Session:
  """Creates and persists a new session.

  The initial ``state`` is split by ``_extract_state_delta`` into app-, user-
  and session-scoped parts. The app and user parts are merged into the shared
  app/user state rows (created on first use), the session part is stored on
  the new session row, and everything is committed in one transaction.

  Args:
    app_name: Name of the application owning the session.
    user_id: Id of the user owning the session.
    state: Optional initial state to split and store.
    session_id: Optional id for the new session; when omitted, the id
      column's default generates a UUID4 string.

  Returns:
    The created session, whose state is the merge of app, user and session
    state.
  """
  with self.DatabaseSessionFactory() as sql_session:
    # Fetch app and user states from storage. StorageAppState has a
    # single-column primary key, so the bare app_name is its identity.
    storage_app_state = sql_session.get(StorageAppState, app_name)
    storage_user_state = sql_session.get(
        StorageUserState, (app_name, user_id)
    )

    app_state = storage_app_state.state if storage_app_state else {}
    user_state = storage_user_state.state if storage_user_state else {}

    # Create the state rows on first use.
    if storage_app_state is None:
      storage_app_state = StorageAppState(app_name=app_name, state={})
      sql_session.add(storage_app_state)
    if storage_user_state is None:
      storage_user_state = StorageUserState(
          app_name=app_name, user_id=user_id, state={}
      )
      sql_session.add(storage_user_state)

    # Split the incoming state by scope.
    app_state_delta, user_state_delta, session_state = _extract_state_delta(
        state
    )

    # Apply the deltas. For pre-existing rows these are the tracked
    # MutableDicts, so the in-place update is itself recorded.
    app_state.update(app_state_delta)
    user_state.update(user_state_delta)

    # Store app and user state (needed for rows created above, whose state
    # attribute is a different dict than the local one).
    if app_state_delta:
      storage_app_state.state = app_state
    if user_state_delta:
      storage_user_state.state = user_state

    # Store the session.
    storage_session = StorageSession(
        app_name=app_name,
        user_id=user_id,
        id=session_id,
        state=session_state,
    )
    sql_session.add(storage_session)
    sql_session.commit()

    # Reload the row so database-evaluated defaults (id, update_time) are
    # available after the commit.
    sql_session.refresh(storage_session)

    # Merge states for the response.
    merged_state = _merge_state(app_state, user_state, session_state)
    return Session(
        app_name=str(storage_session.app_name),
        user_id=str(storage_session.user_id),
        id=str(storage_session.id),
        state=merged_state,
        last_update_time=storage_session.update_time.timestamp(),
    )
|
291
|
+
|
292
|
+
@override
def get_session(
    self,
    *,
    app_name: str,
    user_id: str,
    session_id: str,
    config: Optional[GetSessionConfig] = None,
) -> Optional[Session]:
  """Loads one stored session with its merged state and its events.

  Args:
    app_name: Application that owns the session.
    user_id: User that owns the session.
    session_id: Session identifier; sessions are keyed by the triple
      ``(app_name, user_id, session_id)``.
    config: Optional event filter. If ``config.after_timestamp`` is not None
      (epoch seconds), only events with a timestamp at or after it are
      returned. If ``config.num_recent_events`` is truthy, only that many of
      the most recent events are returned.

  Returns:
    A ``Session`` whose ``state`` merges app, user, and session state and
    whose ``events`` are in chronological order, or ``None`` if no session
    with that key is stored.
  """
  with self.DatabaseSessionFactory() as sessionFactory:
    storage_session = sessionFactory.get(
        StorageSession, (app_name, user_id, session_id)
    )
    if storage_session is None:
      return None

    # Session ids are only unique per (app_name, user_id), so events must be
    # matched on the full key to avoid pulling in another user's events.
    query = sessionFactory.query(StorageEvent).filter(
        StorageEvent.app_name == app_name,
        StorageEvent.user_id == user_id,
        StorageEvent.session_id == session_id,
    )

    after_timestamp = config.after_timestamp if config else None
    if after_timestamp is not None:
      # Event rows store datetimes built with datetime.fromtimestamp, so the
      # epoch-seconds bound is converted the same way before comparing.
      query = query.filter(
          StorageEvent.timestamp >= datetime.fromtimestamp(after_timestamp)
      )

    num_recent_events = config.num_recent_events if config else None
    if num_recent_events:
      # Select the newest N rows, then flip back to chronological order.
      storage_events = (
          query.order_by(StorageEvent.timestamp.desc())
          .limit(num_recent_events)
          .all()
      )
      storage_events.reverse()
    else:
      storage_events = query.order_by(StorageEvent.timestamp.asc()).all()

    # Fetch app- and user-scoped state; either row may be absent.
    storage_app_state = sessionFactory.get(StorageAppState, app_name)
    storage_user_state = sessionFactory.get(
        StorageUserState, (app_name, user_id)
    )
    app_state = storage_app_state.state if storage_app_state else {}
    user_state = storage_user_state.state if storage_user_state else {}

    merged_state = _merge_state(app_state, user_state, storage_session.state)

    session = Session(
        app_name=app_name,
        user_id=user_id,
        id=session_id,
        state=merged_state,
        last_update_time=storage_session.update_time.timestamp(),
    )
    session.events = [convert_event(e) for e in storage_events]
    return session
|
359
|
+
|
360
|
+
@override
def list_sessions(
    self, *, app_name: str, user_id: str
) -> ListSessionsResponse:
  """Lists the sessions stored for one user of one app.

  Only session metadata is returned. Each ``Session`` is built with an empty
  ``state`` and its events are not populated here; use ``get_session`` to
  load a full session.

  Args:
    app_name: Application whose sessions are listed.
    user_id: User whose sessions are listed.

  Returns:
    A ``ListSessionsResponse`` containing one lightweight ``Session`` per
    stored session row.
  """
  with self.DatabaseSessionFactory() as sessionFactory:
    results = (
        sessionFactory.query(StorageSession)
        .filter(StorageSession.app_name == app_name)
        .filter(StorageSession.user_id == user_id)
        .all()
    )
    sessions = [
        Session(
            app_name=app_name,
            user_id=user_id,
            id=storage_session.id,
            state={},
            last_update_time=storage_session.update_time.timestamp(),
        )
        for storage_session in results
    ]
    return ListSessionsResponse(sessions=sessions)
|
383
|
+
|
384
|
+
@override
def delete_session(
    self, app_name: str, user_id: str, session_id: str
) -> None:
  """Deletes the stored session identified by (app_name, user_id, session_id).

  Issues a single bulk DELETE against the session table and commits it.
  Deleting a session that does not exist is not an error.
  """
  key_matches = (
      StorageSession.app_name == app_name,
      StorageSession.user_id == user_id,
      StorageSession.id == session_id,
  )
  with self.DatabaseSessionFactory() as db_session:
    db_session.execute(delete(StorageSession).where(*key_matches))
    db_session.commit()
|
396
|
+
|
397
|
+
@override
def append_event(self, session: Session, event: Event) -> Event:
  """Persists ``event`` and its state delta, then applies it in memory.

  Partial events without content are returned without being stored.
  Otherwise the event's ``actions.state_delta`` is split by key prefix into
  app, user, and session deltas (keys with ``State.TEMP_PREFIX`` are
  dropped). Each non-empty delta is written to its table, and the event row
  is added. All of this is committed in one transaction.

  Args:
    session: The session to append to. It must not be older than the stored
      copy. On success its ``last_update_time`` is refreshed from the
      committed ``update_time``.
    event: The event to store.

  Returns:
    The same ``event`` object.

  Raises:
    ValueError: If the session is not found in storage, or if ``session`` is
      stale (storage was updated after ``session.last_update_time``).
  """
  logger.info("Append event: %s to session %s", event, session.id)

  if event.partial and not event.content:
    return event

  with self.DatabaseSessionFactory() as sessionFactory:
    storage_session = sessionFactory.get(
        StorageSession, (session.app_name, session.user_id, session.id)
    )
    if storage_session is None:
      raise ValueError(
          f"Session {session.id} not found for app {session.app_name} and"
          f" user {session.user_id}."
      )

    # Refuse to write through a stale in-memory copy of the session.
    if storage_session.update_time.timestamp() > session.last_update_time:
      raise ValueError(
          f"Session last_update_time {session.last_update_time} is earlier"
          f" than the update_time in storage {storage_session.update_time};"
          " the session is stale."
      )

    app_state_delta, user_state_delta, session_state_delta = {}, {}, {}
    if event.actions and event.actions.state_delta:
      app_state_delta, user_state_delta, session_state_delta = (
          _extract_state_delta(event.actions.state_delta)
      )

    # Each write assigns a brand-new dict so the ORM registers the change
    # even if the state column does not track in-place mutation. Missing
    # app/user state rows are created, as create_session does.
    if app_state_delta:
      storage_app_state = sessionFactory.get(
          StorageAppState, session.app_name
      )
      if storage_app_state is None:
        storage_app_state = StorageAppState(
            app_name=session.app_name, state={}
        )
        sessionFactory.add(storage_app_state)
      storage_app_state.state = {
          **storage_app_state.state,
          **app_state_delta,
      }
    if user_state_delta:
      storage_user_state = sessionFactory.get(
          StorageUserState, (session.app_name, session.user_id)
      )
      if storage_user_state is None:
        storage_user_state = StorageUserState(
            app_name=session.app_name, user_id=session.user_id, state={}
        )
        sessionFactory.add(storage_user_state)
      storage_user_state.state = {
          **storage_user_state.state,
          **user_state_delta,
      }
    if session_state_delta:
      storage_session.state = {
          **storage_session.state,
          **session_state_delta,
      }

    # Non-partial events may have no content; store None rather than crash.
    # NOTE(review): assumes the content column is nullable — TODO confirm.
    encoded_content = (
        event.content.model_dump(exclude_none=True) if event.content else None
    )
    storage_event = StorageEvent(
        id=event.id,
        invocation_id=event.invocation_id,
        author=event.author,
        branch=event.branch,
        content=encoded_content,
        actions=event.actions,
        session_id=session.id,
        app_name=session.app_name,
        user_id=session.user_id,
        timestamp=datetime.fromtimestamp(event.timestamp),
    )
    sessionFactory.add(storage_event)

    sessionFactory.commit()
    sessionFactory.refresh(storage_session)

    # Adopt the committed update_time so the next append is not seen as stale.
    session.last_update_time = storage_session.update_time.timestamp()

    # Also apply the event to the in-memory session (parent implementation).
    super().append_event(session=session, event=event)
  return event
|
475
|
+
|
476
|
+
@override
def list_events(
    self,
    *,
    app_name: str,
    user_id: str,
    session_id: str,
) -> ListEventsResponse:
  """Lists the events of a session. Not implemented for this service yet.

  This raises explicitly instead of silently returning ``None``, which would
  violate the declared ``ListEventsResponse`` return type. ``get_session``
  populates ``Session.events`` and can be used instead.

  Raises:
    NotImplementedError: Always.
  """
  raise NotImplementedError("list_events is not implemented yet.")
|
485
|
+
|
486
|
+
|
487
|
+
def convert_event(event: StorageEvent) -> Event:
  """Builds an ``Event`` from a stored ``StorageEvent`` row.

  All fields are copied as-is except ``timestamp``. That value is converted
  with ``.timestamp()`` to epoch seconds; presumably the stored value is a
  ``datetime``.
  """
  copied_fields = dict(
      id=event.id,
      author=event.author,
      branch=event.branch,
      invocation_id=event.invocation_id,
      content=event.content,
      actions=event.actions,
  )
  return Event(timestamp=event.timestamp.timestamp(), **copied_fields)
|
498
|
+
|
499
|
+
|
500
|
+
def _extract_state_delta(state: dict):
  """Splits a flat state mapping into app, user, and session deltas.

  Keys are routed as follows:
  - Keys starting with ``State.APP_PREFIX`` go to the app delta, with the
    prefix stripped.
  - Keys starting with ``State.USER_PREFIX`` go to the user delta, with the
    prefix stripped.
  - Keys starting with ``State.TEMP_PREFIX`` are discarded.
  - All other keys go to the session delta unchanged.

  A falsy ``state`` (e.g. ``None`` or ``{}``) yields three empty dicts.

  Returns:
    A tuple ``(app_state_delta, user_state_delta, session_state_delta)``.
  """
  app_delta, user_delta, session_delta = {}, {}, {}
  for key, value in (state or {}).items():
    if key.startswith(State.APP_PREFIX):
      app_delta[key.removeprefix(State.APP_PREFIX)] = value
    elif key.startswith(State.USER_PREFIX):
      user_delta[key.removeprefix(State.USER_PREFIX)] = value
    elif key.startswith(State.TEMP_PREFIX):
      continue
    else:
      session_delta[key] = value
  return app_delta, user_delta, session_delta
|
513
|
+
|
514
|
+
|
515
|
+
def _merge_state(app_state, user_state, session_state):
  """Flattens app, user, and session state into one dict for a ``Session``.

  The session state is deep-copied. App and user keys are then added back
  with ``State.APP_PREFIX`` / ``State.USER_PREFIX`` prepended. App and user
  values are not copied, so the result shares them with the inputs. On a key
  collision, user entries win over app entries, which win over session
  entries.
  """
  merged_state = copy.deepcopy(session_state)
  merged_state.update(
      {State.APP_PREFIX + key: value for key, value in app_state.items()}
  )
  merged_state.update(
      {State.USER_PREFIX + key: value for key, value in user_state.items()}
  )
  return merged_state
|