agent-starter-pack 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of agent-starter-pack might be problematic. Click here for more details.

Files changed (72)
  1. {agent_starter_pack-0.2.2.dist-info → agent_starter_pack-0.3.0.dist-info}/METADATA +14 -16
  2. {agent_starter_pack-0.2.2.dist-info → agent_starter_pack-0.3.0.dist-info}/RECORD +69 -54
  3. agents/adk_base/README.md +14 -0
  4. agents/adk_base/app/agent.py +66 -0
  5. agents/adk_base/notebooks/adk_app_testing.ipynb +305 -0
  6. agents/adk_base/template/.templateconfig.yaml +21 -0
  7. agents/adk_base/tests/integration/test_agent.py +58 -0
  8. agents/agentic_rag/README.md +1 -0
  9. agents/agentic_rag/app/agent.py +44 -89
  10. agents/agentic_rag/app/templates.py +0 -25
  11. agents/agentic_rag/notebooks/adk_app_testing.ipynb +305 -0
  12. agents/agentic_rag/template/.templateconfig.yaml +3 -1
  13. agents/agentic_rag/tests/integration/test_agent.py +34 -27
  14. agents/langgraph_base_react/README.md +1 -1
  15. agents/langgraph_base_react/template/.templateconfig.yaml +1 -1
  16. src/base_template/Makefile +15 -4
  17. src/base_template/README.md +8 -2
  18. src/base_template/app/__init__.py +3 -0
  19. src/base_template/app/utils/tracing.py +11 -1
  20. src/base_template/app/utils/typing.py +54 -4
  21. src/base_template/deployment/README.md +4 -1
  22. src/base_template/deployment/cd/deploy-to-prod.yaml +3 -3
  23. src/base_template/deployment/cd/staging.yaml +4 -4
  24. src/base_template/deployment/ci/pr_checks.yaml +1 -1
  25. src/base_template/deployment/terraform/build_triggers.tf +3 -0
  26. src/base_template/deployment/terraform/dev/variables.tf +4 -0
  27. src/base_template/deployment/terraform/dev/vars/env.tfvars +0 -3
  28. src/base_template/deployment/terraform/variables.tf +4 -0
  29. src/base_template/deployment/terraform/vars/env.tfvars +0 -4
  30. src/base_template/pyproject.toml +5 -3
  31. src/{deployment_targets/agent_engine → base_template}/tests/unit/test_dummy.py +2 -1
  32. src/cli/commands/create.py +45 -11
  33. src/cli/commands/setup_cicd.py +25 -6
  34. src/cli/utils/gcp.py +1 -1
  35. src/cli/utils/template.py +27 -25
  36. src/data_ingestion/README.md +37 -50
  37. src/data_ingestion/data_ingestion_pipeline/components/ingest_data.py +2 -1
  38. src/deployment_targets/agent_engine/app/agent_engine_app.py +68 -22
  39. src/deployment_targets/agent_engine/app/utils/gcs.py +1 -1
  40. src/deployment_targets/agent_engine/tests/integration/test_agent_engine_app.py +63 -0
  41. src/deployment_targets/agent_engine/tests/load_test/load_test.py +9 -2
  42. src/deployment_targets/cloud_run/Dockerfile +1 -1
  43. src/deployment_targets/cloud_run/app/server.py +41 -15
  44. src/deployment_targets/cloud_run/tests/integration/test_server_e2e.py +60 -3
  45. src/deployment_targets/cloud_run/tests/load_test/README.md +1 -1
  46. src/deployment_targets/cloud_run/tests/load_test/load_test.py +57 -24
  47. src/frontends/live_api_react/frontend/package-lock.json +3 -3
  48. src/frontends/streamlit/frontend/utils/stream_handler.py +3 -3
  49. src/frontends/streamlit_adk/frontend/side_bar.py +214 -0
  50. src/frontends/streamlit_adk/frontend/streamlit_app.py +314 -0
  51. src/frontends/streamlit_adk/frontend/style/app_markdown.py +37 -0
  52. src/frontends/streamlit_adk/frontend/utils/chat_utils.py +84 -0
  53. src/frontends/streamlit_adk/frontend/utils/local_chat_history.py +110 -0
  54. src/frontends/streamlit_adk/frontend/utils/message_editing.py +61 -0
  55. src/frontends/streamlit_adk/frontend/utils/multimodal_utils.py +223 -0
  56. src/frontends/streamlit_adk/frontend/utils/stream_handler.py +311 -0
  57. src/frontends/streamlit_adk/frontend/utils/title_summary.py +129 -0
  58. src/resources/locks/uv-adk_base-agent_engine.lock +5335 -0
  59. src/resources/locks/uv-adk_base-cloud_run.lock +5927 -0
  60. src/resources/locks/uv-agentic_rag-agent_engine.lock +939 -732
  61. src/resources/locks/uv-agentic_rag-cloud_run.lock +1087 -907
  62. src/resources/locks/uv-crewai_coding_crew-agent_engine.lock +778 -671
  63. src/resources/locks/uv-crewai_coding_crew-cloud_run.lock +852 -753
  64. src/resources/locks/uv-langgraph_base_react-agent_engine.lock +665 -591
  65. src/resources/locks/uv-langgraph_base_react-cloud_run.lock +842 -743
  66. src/resources/locks/uv-live_api-cloud_run.lock +830 -731
  67. agents/agentic_rag/notebooks/evaluating_langgraph_agent.ipynb +0 -1561
  68. src/base_template/tests/unit/test_utils/test_tracing_exporter.py +0 -140
  69. src/deployment_targets/cloud_run/tests/unit/test_server.py +0 -124
  70. {agent_starter_pack-0.2.2.dist-info → agent_starter_pack-0.3.0.dist-info}/WHEEL +0 -0
  71. {agent_starter_pack-0.2.2.dist-info → agent_starter_pack-0.3.0.dist-info}/entry_points.txt +0 -0
  72. {agent_starter_pack-0.2.2.dist-info → agent_starter_pack-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,305 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# ADK Application Testing\n",
8
+ "\n",
9
+ "This notebook demonstrates how to test an ADK (Agent Development Kit) application.\n",
10
+ "It covers both local and remote testing, for both Agent Engine and Cloud Run deployments.\n",
11
+ "\n",
12
+ "<img src=\"https://github.com/GoogleCloudPlatform/agent-starter-pack/blob/main/docs/images/adk_logo.png?raw=true\" width=\"400\">\n"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "markdown",
17
+ "metadata": {},
18
+ "source": [
19
+ "### Import libraries"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": 1,
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "import json\n",
29
+ "\n",
30
+ "import requests\n",
31
+ "import vertexai.agent_engines"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "markdown",
36
+ "metadata": {},
37
+ "source": [
38
+ "## If you are using Agent Engine\n",
39
+ "See more documentation at [Agent Engine Overview](https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/overview)"
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "markdown",
44
+ "metadata": {},
45
+ "source": [
46
+ "### Local Testing\n",
47
+ "\n",
48
+ "You can directly import the AgentEngineApp class within your environment."
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": 2,
54
+ "metadata": {},
55
+ "outputs": [],
56
+ "source": [
57
+ "from app.agent import root_agent\n",
58
+ "from app.agent_engine_app import AgentEngineApp\n",
59
+ "\n",
60
+ "agent_engine = AgentEngineApp(agent=root_agent)"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "execution_count": null,
66
+ "metadata": {},
67
+ "outputs": [],
68
+ "source": [
69
+ "for event in agent_engine.stream_query(message=\"hi!\", user_id=\"test\"):\n",
70
+ " print(event)"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "markdown",
75
+ "metadata": {},
76
+ "source": [
77
+ "### Remote Testing"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": 4,
83
+ "metadata": {},
84
+ "outputs": [],
85
+ "source": [
86
+ "# Replace with your Agent Engine ID\n",
87
+ "AGENT_ENGINE_ID = \"projects/PROJECT_ID/locations/us-central1/reasoningEngines/ENGINE_ID\""
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": 5,
93
+ "metadata": {},
94
+ "outputs": [],
95
+ "source": [
96
+ "remote_agent_engine = vertexai.agent_engines.get(AGENT_ENGINE_ID)"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": null,
102
+ "metadata": {},
103
+ "outputs": [],
104
+ "source": [
105
+ "for event in remote_agent_engine.stream_query(message=\"hi!\", user_id=\"test\"):\n",
106
+ " print(event)"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "markdown",
111
+ "metadata": {},
112
+ "source": [
113
+ "## If you are using Cloud Run"
114
+ ]
115
+ },
116
+ {
117
+ "cell_type": "markdown",
118
+ "metadata": {},
119
+ "source": [
120
+ "### Local Testing\n",
121
+ "\n",
122
+ "> You can run the application locally via the `make backend` command."
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "markdown",
127
+ "metadata": {},
128
+ "source": [
129
+ "#### Create a session\n",
130
+ "Create a new session with user preferences and state information.\n"
131
+ ]
132
+ },
133
+ {
134
+ "cell_type": "code",
135
+ "execution_count": null,
136
+ "metadata": {},
137
+ "outputs": [],
138
+ "source": [
139
+ "user_id = \"test_user_123\"\n",
140
+ "session_id = \"test_session_456\"\n",
141
+ "session_data = {\"state\": {\"preferred_language\": \"English\", \"visit_count\": 1}}\n",
142
+ "\n",
143
+ "session_url = f\"http://127.0.0.1:8000/apps/app/users/{user_id}/sessions/{session_id}\"\n",
144
+ "headers = {\"Content-Type\": \"application/json\"}\n",
145
+ "\n",
146
+ "session_response = requests.post(session_url, headers=headers, json=session_data)\n",
147
+ "print(f\"Session creation status code: {session_response.status_code}\")"
148
+ ]
149
+ },
150
+ {
151
+ "cell_type": "markdown",
152
+ "metadata": {},
153
+ "source": [
154
+ "#### Send a message\n",
155
+ "Send a message to the backend service and receive a streaming response\n"
156
+ ]
157
+ },
158
+ {
159
+ "cell_type": "code",
160
+ "execution_count": null,
161
+ "metadata": {},
162
+ "outputs": [],
163
+ "source": [
164
+ "message_data = {\n",
165
+ " \"app_name\": \"app\",\n",
166
+ " \"user_id\": user_id,\n",
167
+ " \"session_id\": session_id,\n",
168
+ " \"new_message\": {\"role\": \"user\", \"parts\": [{\"text\": \"Hello! Weather in New york?\"}]},\n",
169
+ " \"streaming\": True,\n",
170
+ "}\n",
171
+ "\n",
172
+ "message_url = \"http://127.0.0.1:8000/run_sse\"\n",
173
+ "message_response = requests.post(\n",
174
+ " message_url, headers=headers, json=message_data, stream=True\n",
175
+ ")\n",
176
+ "\n",
177
+ "print(f\"Message send status code: {message_response.status_code}\")\n",
178
+ "\n",
179
+ "# Print streamed response\n",
180
+ "for line in message_response.iter_lines():\n",
181
+ " if line:\n",
182
+ " line_str = line.decode(\"utf-8\")\n",
183
+ " if line_str.startswith(\"data: \"):\n",
184
+ " event_json = line_str[6:]\n",
185
+ " event = json.loads(event_json)\n",
186
+ " print(f\"Received event: {event}\")"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "markdown",
191
+ "metadata": {},
192
+ "source": [
193
+ "### Remote Testing\n",
194
+ "\n",
195
+ "For more information about authenticating HTTPS requests to Cloud Run services, see:\n",
196
+ "[Cloud Run Authentication Documentation](https://cloud.google.com/run/docs/triggering/https-request)\n",
197
+ "\n",
198
+ "Remote testing involves using a deployed service URL instead of localhost.\n",
199
+ "\n",
200
+ "Authentication is handled using GCP identity tokens instead of local credentials."
201
+ ]
202
+ },
203
+ {
204
+ "cell_type": "code",
205
+ "execution_count": 20,
206
+ "metadata": {},
207
+ "outputs": [],
208
+ "source": [
209
+ "ID_TOKEN = get_ipython().getoutput(\"gcloud auth print-identity-token -q\")[0]"
210
+ ]
211
+ },
212
+ {
213
+ "cell_type": "code",
214
+ "execution_count": null,
215
+ "metadata": {},
216
+ "outputs": [],
217
+ "source": [
218
+ "SERVICE_URL = \"YOUR_SERVICE_URL_HERE\" # Replace with your Cloud Run service URL"
219
+ ]
220
+ },
221
+ {
222
+ "cell_type": "markdown",
223
+ "metadata": {},
224
+ "source": [
225
+ "First, create a session:"
226
+ ]
227
+ },
228
+ {
229
+ "cell_type": "code",
230
+ "execution_count": null,
231
+ "metadata": {},
232
+ "outputs": [],
233
+ "source": [
234
+ "user_id = \"test_user_123\"\n",
235
+ "session_id = \"test_session_456\"\n",
236
+ "session_data = {\"state\": {\"preferred_language\": \"English\", \"visit_count\": 1}}\n",
237
+ "\n",
238
+ "session_url = f\"{SERVICE_URL}/apps/app/users/{user_id}/sessions/{session_id}\"\n",
239
+ "headers = {\"Content-Type\": \"application/json\", \"Authorization\": f\"Bearer {ID_TOKEN}\"}\n",
240
+ "\n",
241
+ "session_response = requests.post(session_url, headers=headers, json=session_data)\n",
242
+ "print(f\"Session creation status code: {session_response.status_code}\")"
243
+ ]
244
+ },
245
+ {
246
+ "cell_type": "markdown",
247
+ "metadata": {},
248
+ "source": [
249
+ "Then you can send a message:"
250
+ ]
251
+ },
252
+ {
253
+ "cell_type": "code",
254
+ "execution_count": null,
255
+ "metadata": {},
256
+ "outputs": [],
257
+ "source": [
258
+ "message_data = {\n",
259
+ " \"app_name\": \"app\",\n",
260
+ " \"user_id\": user_id,\n",
261
+ " \"session_id\": session_id,\n",
262
+ " \"new_message\": {\"role\": \"user\", \"parts\": [{\"text\": \"Hello! Weather in New york?\"}]},\n",
263
+ " \"streaming\": True,\n",
264
+ "}\n",
265
+ "\n",
266
+ "message_url = f\"{SERVICE_URL}/run_sse\"\n",
267
+ "message_response = requests.post(\n",
268
+ " message_url, headers=headers, json=message_data, stream=True\n",
269
+ ")\n",
270
+ "\n",
271
+ "print(f\"Message send status code: {message_response.status_code}\")\n",
272
+ "\n",
273
+ "# Print streamed response\n",
274
+ "for line in message_response.iter_lines():\n",
275
+ " if line:\n",
276
+ " line_str = line.decode(\"utf-8\")\n",
277
+ " if line_str.startswith(\"data: \"):\n",
278
+ " event_json = line_str[6:]\n",
279
+ " event = json.loads(event_json)\n",
280
+ " print(f\"Received event: {event}\")"
281
+ ]
282
+ }
283
+ ],
284
+ "metadata": {
285
+ "kernelspec": {
286
+ "display_name": ".venv",
287
+ "language": "python",
288
+ "name": "python3"
289
+ },
290
+ "language_info": {
291
+ "codemirror_mode": {
292
+ "name": "ipython",
293
+ "version": 3
294
+ },
295
+ "file_extension": ".py",
296
+ "mimetype": "text/x-python",
297
+ "name": "python",
298
+ "nbconvert_exporter": "python",
299
+ "pygments_lexer": "ipython3",
300
+ "version": "3.12.8"
301
+ }
302
+ },
303
+ "nbformat": 4,
304
+ "nbformat_minor": 2
305
+ }
@@ -0,0 +1,21 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ description: "An agent implementing a base ReAct agent using Google's Agent Development Kit"
16
+ settings:
17
+ requires_data_ingestion: false
18
+ deployment_targets: ["agent_engine", "cloud_run"]
19
+ extra_dependencies: ["google-adk~=0.1.0"]
20
+ tags: ["adk"]
21
+ frontend_type: "streamlit_adk"
@@ -0,0 +1,58 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # mypy: disable-error-code="union-attr"
16
+ from google.adk.agents.run_config import RunConfig, StreamingMode
17
+ from google.adk.runners import Runner
18
+ from google.adk.sessions import InMemorySessionService
19
+ from google.genai import types
20
+
21
+ from app.agent import root_agent
22
+
23
+
24
+ def test_agent_stream() -> None:
25
+ """
26
+ Integration test for the agent stream functionality.
27
+ Tests that the agent returns valid streaming responses.
28
+ """
29
+
30
+ session_service = InMemorySessionService()
31
+
32
+ session = session_service.create_session(user_id="test_user", app_name="test")
33
+ runner = Runner(agent=root_agent, session_service=session_service, app_name="test")
34
+
35
+ message = types.Content(
36
+ role="user", parts=[types.Part.from_text(text="Why is the sky blue?")]
37
+ )
38
+
39
+ events = list(
40
+ runner.run(
41
+ new_message=message,
42
+ user_id="test_user",
43
+ session_id=session.id,
44
+ run_config=RunConfig(streaming_mode=StreamingMode.SSE),
45
+ )
46
+ )
47
+ assert len(events) > 0, "Expected at least one message"
48
+
49
+ has_text_content = False
50
+ for event in events:
51
+ if (
52
+ event.content
53
+ and event.content.parts
54
+ and any(part.text for part in event.content.parts)
55
+ ):
56
+ has_text_content = True
57
+ break
58
+ assert has_text_content, "Expected at least one message with text content"
@@ -14,6 +14,7 @@ The agent implements the following architecture:
14
14
 
15
15
  ### Key Features
16
16
 
17
+ - **Built on Agent Development Kit (ADK):** ADK is a flexible, modular framework for developing and deploying AI agents. It integrates with the Google ecosystem and Gemini models, supporting various LLMs and open-source AI tools, enabling both simple and complex agent architectures.
17
18
  - **Flexible Datastore Options:** Choose between Vertex AI Search or Vertex AI Vector Search for efficient data storage and retrieval based on your specific needs.
18
19
  - **Automated Data Ingestion Pipeline:** Automates the process of ingesting data from input sources.
19
20
  - **Custom Embeddings:** Generates embeddings using Vertex AI Embeddings and incorporates them into your data for enhanced semantic search.
@@ -17,25 +17,22 @@ import os
17
17
 
18
18
  import google
19
19
  import vertexai
20
- from langchain_core.documents import Document
21
- from langchain_core.messages import BaseMessage
22
- from langchain_core.runnables import RunnableConfig
23
- from langchain_core.tools import tool
24
- from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings
25
- from langgraph.graph import END, MessagesState, StateGraph
26
- from langgraph.prebuilt import ToolNode
20
+ from google.adk.agents import Agent
21
+ from langchain_google_vertexai import VertexAIEmbeddings
27
22
 
28
23
  from app.retrievers import get_compressor, get_retriever
29
- from app.templates import format_docs, inspect_conversation_template, rag_template
24
+ from app.templates import format_docs
30
25
 
31
26
  EMBEDDING_MODEL = "text-embedding-005"
32
27
  LOCATION = "us-central1"
33
28
  LLM = "gemini-2.0-flash-001"
34
29
 
35
- # Initialize Google Cloud and Vertex AI
36
30
  credentials, project_id = google.auth.default()
37
- vertexai.init(project=project_id, location=LOCATION)
31
+ os.environ.setdefault("GOOGLE_CLOUD_PROJECT", project_id)
32
+ os.environ.setdefault("GOOGLE_CLOUD_LOCATION", LOCATION)
33
+ os.environ.setdefault("GOOGLE_GENAI_USE_VERTEXAI", "True")
38
34
 
35
+ vertexai.init(project=project_id, location=LOCATION)
39
36
  embedding = VertexAIEmbeddings(
40
37
  project=project_id, location=LOCATION, model_name=EMBEDDING_MODEL
41
38
  )
@@ -45,7 +42,7 @@ EMBEDDING_COLUMN = "embedding"
45
42
  TOP_K = 5
46
43
 
47
44
  data_store_region = os.getenv("DATA_STORE_REGION", "us")
48
- data_store_id = os.getenv("DATA_STORE_ID", "sample-datastore")
45
+ data_store_id = os.getenv("DATA_STORE_ID", "{{cookiecutter.project_name}}-datastore")
49
46
 
50
47
  retriever = get_retriever(
51
48
  project_id=project_id,
@@ -56,9 +53,15 @@ retriever = get_retriever(
56
53
  max_documents=10,
57
54
  )
58
55
  {% elif cookiecutter.datastore_type == "vertex_ai_vector_search" %}
59
- vector_search_index = os.getenv("VECTOR_SEARCH_INDEX")
60
- vector_search_index_endpoint = os.getenv("VECTOR_SEARCH_INDEX_ENDPOINT")
61
- vector_search_bucket = os.getenv("VECTOR_SEARCH_BUCKET")
56
+ vector_search_index = os.getenv(
57
+ "VECTOR_SEARCH_INDEX", "{{cookiecutter.project_name}}-vector-search"
58
+ )
59
+ vector_search_index_endpoint = os.getenv(
60
+ "VECTOR_SEARCH_INDEX_ENDPOINT", "{{cookiecutter.project_name}}-vector-search-endpoint"
61
+ )
62
+ vector_search_bucket = os.getenv(
63
+ "VECTOR_SEARCH_BUCKET", f"{project_id}-{{cookiecutter.project_name}}-vs"
64
+ )
62
65
 
63
66
  retriever = get_retriever(
64
67
  project_id=project_id,
@@ -74,8 +77,7 @@ compressor = get_compressor(
74
77
  )
75
78
 
76
79
 
77
- @tool(response_format="content_and_artifact")
78
- def retrieve_docs(query: str) -> tuple[str, list[Document]]:
80
+ def retrieve_docs(query: str) -> str:
79
81
  """
80
82
  Useful for retrieving relevant documents based on a query.
81
83
  Use this when you need additional information to answer a question.
@@ -84,78 +86,31 @@ def retrieve_docs(query: str) -> tuple[str, list[Document]]:
84
86
  query (str): The user's question or search query.
85
87
 
86
88
  Returns:
87
- List[Document]: A list of the top-ranked Document objects, limited to TOP_K (5) results.
88
- """
89
- # Use the retriever to fetch relevant documents based on the query
90
- retrieved_docs = retriever.invoke(query)
91
- # Re-rank docs with Vertex AI Rank for better relevance
92
- ranked_docs = compressor.compress_documents(documents=retrieved_docs, query=query)
93
- # Format ranked documents into a consistent structure for LLM consumption
94
- formatted_docs = format_docs.format(docs=ranked_docs)
95
- return (formatted_docs, ranked_docs)
96
-
97
-
98
- @tool
99
- def should_continue() -> None:
100
- """
101
- Use this tool if you determine that you have enough context to respond to the questions of the user.
89
+ str: Formatted string containing relevant document content retrieved and ranked based on the query.
102
90
  """
103
- return None
104
-
105
-
106
- tools = [retrieve_docs, should_continue]
107
-
108
- llm = ChatVertexAI(model=LLM, temperature=0, max_tokens=1024, streaming=True)
109
-
110
- # Set up conversation inspector
111
- inspect_conversation = inspect_conversation_template | llm.bind_tools(
112
- tools, tool_choice="any"
91
+ try:
92
+ # Use the retriever to fetch relevant documents based on the query
93
+ retrieved_docs = retriever.invoke(query)
94
+ # Re-rank docs with Vertex AI Rank for better relevance
95
+ ranked_docs = compressor.compress_documents(
96
+ documents=retrieved_docs, query=query
97
+ )
98
+ # Format ranked documents into a consistent structure for LLM consumption
99
+ formatted_docs = format_docs.format(docs=ranked_docs)
100
+ except Exception as e:
101
+ return f"Calling retrieval tool with query:\n\n{query}\n\nraised the following error:\n\n{type(e)}: {e}"
102
+
103
+ return formatted_docs
104
+
105
+
106
+ instruction = """You are an AI assistant for question-answering tasks.
107
+ Answer to the best of your ability using the context provided.
108
+ Leverage the Tools you are provided to answer questions.
109
+ If you already know the answer to a question, you can respond directly without using the tools."""
110
+
111
+ root_agent = Agent(
112
+ name="root_agent",
113
+ model="gemini-2.0-flash",
114
+ instruction=instruction,
115
+ tools=[retrieve_docs],
113
116
  )
114
-
115
- # Set up response chain
116
- response_chain = rag_template | llm
117
-
118
-
119
- def inspect_conversation_node(
120
- state: MessagesState, config: RunnableConfig
121
- ) -> dict[str, BaseMessage]:
122
- """Inspects the conversation state and returns the next message using the conversation inspector."""
123
- response = inspect_conversation.invoke(state, config)
124
- return {"messages": response}
125
-
126
-
127
- def generate_node(
128
- state: MessagesState, config: RunnableConfig
129
- ) -> dict[str, BaseMessage]:
130
- """Generates a response using the RAG template and returns it as a message."""
131
- response = response_chain.invoke(state, config)
132
- return {"messages": response}
133
-
134
-
135
- # Flow:
136
- # 1. Start with agent node that inspects conversation using inspect_conversation_node
137
- # 2. Agent node connects to tools node which can either:
138
- # - Retrieve relevant docs using retrieve_docs tool
139
- # - End tool usage with should_continue tool
140
- # 3. Tools node connects to generate node which produces final response
141
- # 4. Generate node connects to END to complete the workflow
142
-
143
- workflow = StateGraph(MessagesState)
144
- workflow.add_node("agent", inspect_conversation_node)
145
- workflow.add_node("generate", generate_node)
146
- workflow.set_entry_point("agent")
147
-
148
- workflow.add_node(
149
- "tools",
150
- ToolNode(
151
- tools=tools,
152
- # With False, tool errors won't be caught by LangGraph
153
- handle_tool_errors=True,
154
- ),
155
- )
156
- workflow.add_edge("agent", "tools")
157
- workflow.add_edge("tools", "generate")
158
-
159
- workflow.add_edge("generate", END)
160
-
161
- agent = workflow.compile()
@@ -13,8 +13,6 @@
13
13
  # limitations under the License.
14
14
 
15
15
  from langchain_core.prompts import (
16
- ChatPromptTemplate,
17
- MessagesPlaceholder,
18
16
  PromptTemplate,
19
17
  )
20
18
 
@@ -28,26 +26,3 @@ format_docs = PromptTemplate.from_template(
28
26
  """,
29
27
  template_format="jinja2",
30
28
  )
31
-
32
- inspect_conversation_template = ChatPromptTemplate.from_messages(
33
- [
34
- (
35
- "system",
36
- """You are an AI assistant tasked with analyzing the conversation"""
37
- """ and determining the best course of action.""",
38
- ),
39
- MessagesPlaceholder(variable_name="messages"),
40
- ]
41
- )
42
-
43
- rag_template = ChatPromptTemplate.from_messages(
44
- [
45
- (
46
- "system",
47
- """You are an AI assistant for question-answering tasks."""
48
- """ Answer to the best of your ability using the context provided."""
49
- """ Leverage the Tools you are provided to answer questions.""",
50
- ),
51
- MessagesPlaceholder(variable_name="messages"),
52
- ]
53
- )