agent-starter-pack 0.18.2__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent_starter_pack/agents/{langgraph_base_react → adk_a2a_base}/.template/templateconfig.yaml +5 -12
- agent_starter_pack/agents/adk_a2a_base/README.md +37 -0
- agent_starter_pack/{frontends/streamlit/frontend/style/app_markdown.py → agents/adk_a2a_base/app/__init__.py} +3 -23
- agent_starter_pack/agents/adk_a2a_base/app/agent.py +70 -0
- agent_starter_pack/agents/adk_a2a_base/notebooks/adk_a2a_app_testing.ipynb +583 -0
- agent_starter_pack/agents/{crewai_coding_crew/notebooks/evaluating_crewai_agent.ipynb → adk_a2a_base/notebooks/evaluating_adk_agent.ipynb} +163 -199
- agent_starter_pack/agents/adk_a2a_base/tests/integration/test_agent.py +58 -0
- agent_starter_pack/agents/adk_base/app/__init__.py +2 -2
- agent_starter_pack/agents/adk_base/app/agent.py +3 -0
- agent_starter_pack/agents/adk_base/notebooks/adk_app_testing.ipynb +13 -28
- agent_starter_pack/agents/adk_live/app/__init__.py +17 -0
- agent_starter_pack/agents/adk_live/app/agent.py +3 -0
- agent_starter_pack/agents/agentic_rag/app/__init__.py +2 -2
- agent_starter_pack/agents/agentic_rag/app/agent.py +3 -0
- agent_starter_pack/agents/agentic_rag/notebooks/adk_app_testing.ipynb +13 -28
- agent_starter_pack/agents/{crewai_coding_crew → langgraph_base}/.template/templateconfig.yaml +12 -9
- agent_starter_pack/agents/langgraph_base/README.md +30 -0
- agent_starter_pack/agents/langgraph_base/app/__init__.py +17 -0
- agent_starter_pack/agents/{langgraph_base_react → langgraph_base}/app/agent.py +4 -4
- agent_starter_pack/agents/{langgraph_base_react → langgraph_base}/tests/integration/test_agent.py +1 -1
- agent_starter_pack/base_template/.gitignore +4 -2
- agent_starter_pack/base_template/Makefile +110 -16
- agent_starter_pack/base_template/README.md +97 -12
- agent_starter_pack/base_template/deployment/terraform/dev/apis.tf +4 -6
- agent_starter_pack/base_template/deployment/terraform/dev/providers.tf +5 -1
- agent_starter_pack/base_template/deployment/terraform/dev/variables.tf +5 -3
- agent_starter_pack/base_template/deployment/terraform/dev/{% if cookiecutter.is_adk %}telemetry.tf{% else %}unused_telemetry.tf{% endif %} +193 -0
- agent_starter_pack/base_template/deployment/terraform/github.tf +16 -9
- agent_starter_pack/base_template/deployment/terraform/locals.tf +7 -7
- agent_starter_pack/base_template/deployment/terraform/providers.tf +5 -1
- agent_starter_pack/base_template/deployment/terraform/sql/completions.sql +138 -0
- agent_starter_pack/base_template/deployment/terraform/storage.tf +0 -9
- agent_starter_pack/base_template/deployment/terraform/variables.tf +15 -19
- agent_starter_pack/base_template/deployment/terraform/{% if cookiecutter.cicd_runner == 'google_cloud_build' %}build_triggers.tf{% else %}unused_build_triggers.tf{% endif %} +20 -22
- agent_starter_pack/base_template/deployment/terraform/{% if cookiecutter.is_adk %}telemetry.tf{% else %}unused_telemetry.tf{% endif %} +206 -0
- agent_starter_pack/base_template/pyproject.toml +5 -17
- agent_starter_pack/base_template/{% if cookiecutter.cicd_runner == 'github_actions' %}.github{% else %}unused_github{% endif %}/workflows/deploy-to-prod.yaml +19 -4
- agent_starter_pack/base_template/{% if cookiecutter.cicd_runner == 'github_actions' %}.github{% else %}unused_github{% endif %}/workflows/staging.yaml +36 -11
- agent_starter_pack/base_template/{% if cookiecutter.cicd_runner == 'google_cloud_build' %}.cloudbuild{% else %}unused_.cloudbuild{% endif %}/deploy-to-prod.yaml +24 -5
- agent_starter_pack/base_template/{% if cookiecutter.cicd_runner == 'google_cloud_build' %}.cloudbuild{% else %}unused_.cloudbuild{% endif %}/staging.yaml +44 -9
- agent_starter_pack/base_template/{{cookiecutter.agent_directory}}/app_utils/telemetry.py +96 -0
- agent_starter_pack/base_template/{{cookiecutter.agent_directory}}/{utils → app_utils}/typing.py +4 -6
- agent_starter_pack/{agents/crewai_coding_crew/app/crew/config/agents.yaml → base_template/{{cookiecutter.agent_directory}}/app_utils/{% if cookiecutter.is_a2a and cookiecutter.agent_name == 'langgraph_base' %}converters{% else %}unused_converters{% endif %}/__init__.py } +9 -23
- agent_starter_pack/base_template/{{cookiecutter.agent_directory}}/app_utils/{% if cookiecutter.is_a2a and cookiecutter.agent_name == 'langgraph_base' %}converters{% else %}unused_converters{% endif %}/part_converter.py +138 -0
- agent_starter_pack/base_template/{{cookiecutter.agent_directory}}/app_utils/{% if cookiecutter.is_a2a and cookiecutter.agent_name == 'langgraph_base' %}executor{% else %}unused_executor{% endif %}/__init__.py +13 -0
- agent_starter_pack/base_template/{{cookiecutter.agent_directory}}/app_utils/{% if cookiecutter.is_a2a and cookiecutter.agent_name == 'langgraph_base' %}executor{% else %}unused_executor{% endif %}/a2a_agent_executor.py +265 -0
- agent_starter_pack/base_template/{{cookiecutter.agent_directory}}/app_utils/{% if cookiecutter.is_a2a and cookiecutter.agent_name == 'langgraph_base' %}executor{% else %}unused_executor{% endif %}/task_result_aggregator.py +152 -0
- agent_starter_pack/cli/commands/create.py +40 -4
- agent_starter_pack/cli/commands/enhance.py +1 -1
- agent_starter_pack/cli/commands/register_gemini_enterprise.py +1070 -0
- agent_starter_pack/cli/main.py +2 -0
- agent_starter_pack/cli/utils/cicd.py +20 -4
- agent_starter_pack/cli/utils/template.py +257 -25
- agent_starter_pack/deployment_targets/agent_engine/tests/integration/test_agent_engine_app.py +113 -16
- agent_starter_pack/deployment_targets/agent_engine/tests/load_test/README.md +2 -2
- agent_starter_pack/deployment_targets/agent_engine/tests/load_test/load_test.py +178 -9
- agent_starter_pack/deployment_targets/agent_engine/tests/{% if cookiecutter.is_a2a %}helpers.py{% else %}unused_helpers.py{% endif %} +138 -0
- agent_starter_pack/deployment_targets/agent_engine/{{cookiecutter.agent_directory}}/agent_engine_app.py +193 -307
- agent_starter_pack/deployment_targets/agent_engine/{{cookiecutter.agent_directory}}/app_utils/deploy.py +414 -0
- agent_starter_pack/deployment_targets/agent_engine/{{cookiecutter.agent_directory}}/{utils → app_utils}/{% if cookiecutter.is_adk_live %}expose_app.py{% else %}unused_expose_app.py{% endif %} +13 -14
- agent_starter_pack/deployment_targets/cloud_run/Dockerfile +4 -1
- agent_starter_pack/deployment_targets/cloud_run/deployment/terraform/dev/service.tf +85 -86
- agent_starter_pack/deployment_targets/cloud_run/deployment/terraform/service.tf +139 -107
- agent_starter_pack/deployment_targets/cloud_run/tests/integration/test_server_e2e.py +228 -12
- agent_starter_pack/deployment_targets/cloud_run/tests/load_test/README.md +4 -4
- agent_starter_pack/deployment_targets/cloud_run/tests/load_test/load_test.py +92 -12
- agent_starter_pack/deployment_targets/cloud_run/{{cookiecutter.agent_directory}}/{server.py → fast_api_app.py} +194 -121
- agent_starter_pack/frontends/adk_live_react/frontend/package-lock.json +18 -18
- agent_starter_pack/frontends/adk_live_react/frontend/src/multimodal-live-types.ts +5 -3
- agent_starter_pack/resources/docs/adk-cheatsheet.md +198 -41
- agent_starter_pack/resources/locks/uv-adk_a2a_base-agent_engine.lock +4966 -0
- agent_starter_pack/resources/locks/uv-adk_a2a_base-cloud_run.lock +5011 -0
- agent_starter_pack/resources/locks/uv-adk_base-agent_engine.lock +1443 -709
- agent_starter_pack/resources/locks/uv-adk_base-cloud_run.lock +1058 -874
- agent_starter_pack/resources/locks/uv-adk_live-agent_engine.lock +1443 -709
- agent_starter_pack/resources/locks/uv-adk_live-cloud_run.lock +1058 -874
- agent_starter_pack/resources/locks/uv-agentic_rag-agent_engine.lock +1568 -749
- agent_starter_pack/resources/locks/uv-agentic_rag-cloud_run.lock +1123 -929
- agent_starter_pack/resources/locks/{uv-langgraph_base_react-agent_engine.lock → uv-langgraph_base-agent_engine.lock} +1714 -1689
- agent_starter_pack/resources/locks/{uv-langgraph_base_react-cloud_run.lock → uv-langgraph_base-cloud_run.lock} +1285 -2374
- agent_starter_pack/utils/watch_and_rebuild.py +1 -1
- {agent_starter_pack-0.18.2.dist-info → agent_starter_pack-0.21.0.dist-info}/METADATA +3 -6
- {agent_starter_pack-0.18.2.dist-info → agent_starter_pack-0.21.0.dist-info}/RECORD +89 -93
- agent_starter_pack-0.21.0.dist-info/entry_points.txt +2 -0
- llm.txt +4 -5
- agent_starter_pack/agents/crewai_coding_crew/README.md +0 -34
- agent_starter_pack/agents/crewai_coding_crew/app/agent.py +0 -47
- agent_starter_pack/agents/crewai_coding_crew/app/crew/config/tasks.yaml +0 -37
- agent_starter_pack/agents/crewai_coding_crew/app/crew/crew.py +0 -71
- agent_starter_pack/agents/crewai_coding_crew/tests/integration/test_agent.py +0 -47
- agent_starter_pack/agents/langgraph_base_react/README.md +0 -9
- agent_starter_pack/agents/langgraph_base_react/notebooks/evaluating_langgraph_agent.ipynb +0 -1574
- agent_starter_pack/base_template/deployment/terraform/dev/log_sinks.tf +0 -69
- agent_starter_pack/base_template/deployment/terraform/log_sinks.tf +0 -79
- agent_starter_pack/base_template/{{cookiecutter.agent_directory}}/utils/tracing.py +0 -155
- agent_starter_pack/cli/utils/register_gemini_enterprise.py +0 -406
- agent_starter_pack/deployment_targets/agent_engine/deployment/terraform/{% if not cookiecutter.is_adk_live %}service.tf{% else %}unused_service.tf{% endif %} +0 -82
- agent_starter_pack/deployment_targets/agent_engine/notebooks/intro_agent_engine.ipynb +0 -1025
- agent_starter_pack/deployment_targets/agent_engine/{{cookiecutter.agent_directory}}/utils/deployment.py +0 -99
- agent_starter_pack/frontends/streamlit/frontend/side_bar.py +0 -214
- agent_starter_pack/frontends/streamlit/frontend/streamlit_app.py +0 -265
- agent_starter_pack/frontends/streamlit/frontend/utils/chat_utils.py +0 -67
- agent_starter_pack/frontends/streamlit/frontend/utils/local_chat_history.py +0 -127
- agent_starter_pack/frontends/streamlit/frontend/utils/message_editing.py +0 -59
- agent_starter_pack/frontends/streamlit/frontend/utils/multimodal_utils.py +0 -217
- agent_starter_pack/frontends/streamlit/frontend/utils/stream_handler.py +0 -310
- agent_starter_pack/frontends/streamlit/frontend/utils/title_summary.py +0 -94
- agent_starter_pack/resources/locks/uv-crewai_coding_crew-agent_engine.lock +0 -6650
- agent_starter_pack/resources/locks/uv-crewai_coding_crew-cloud_run.lock +0 -7825
- agent_starter_pack-0.18.2.dist-info/entry_points.txt +0 -3
- /agent_starter_pack/agents/{crewai_coding_crew → langgraph_base}/notebooks/evaluating_langgraph_agent.ipynb +0 -0
- /agent_starter_pack/base_template/{{cookiecutter.agent_directory}}/{utils → app_utils}/gcs.py +0 -0
- {agent_starter_pack-0.18.2.dist-info → agent_starter_pack-0.21.0.dist-info}/WHEEL +0 -0
- {agent_starter_pack-0.18.2.dist-info → agent_starter_pack-0.21.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -139,6 +139,149 @@ class RemoteAgentUser(WebSocketUser):
|
|
|
139
139
|
# Set the host via command line: locust -f load_test.py --host=https://your-deployed-service.run.app
|
|
140
140
|
host = "http://localhost:8000" # Default for local testing
|
|
141
141
|
{%- else %}
|
|
142
|
+
{%- if cookiecutter.is_a2a %}
|
|
143
|
+
|
|
144
|
+
import json
import logging
import os
import time
import uuid

from locust import HttpUser, between, task
|
|
150
|
+
|
|
151
|
+
# Configure logging for the load-test process.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Initialize Vertex AI and load agent config: the deployed engine's resource
# name is read from deployment_metadata.json in the working directory.
with open("deployment_metadata.json") as f:
    remote_agent_engine_id = json.load(f)["remote_agent_engine_id"]

# Assumes a resource name shaped like
# projects/<project>/locations/<location>/reasoningEngines/<engine>
# (odd-indexed segments are the IDs) — confirm against the deploy step.
parts = remote_agent_engine_id.split("/")
project_id, location, engine_id = parts[1], parts[3], parts[5]

# Regional Vertex AI endpoint plus the engine's A2A REST prefix.
base_url = "https://" + location + "-aiplatform.googleapis.com"
a2a_base_path = (
    f"/v1beta1/projects/{project_id}/locations/{location}"
    f"/reasoningEngines/{engine_id}/a2a/v1"
)

logger.info("Using remote agent engine ID: %s", remote_agent_engine_id)
logger.info("Using base URL: %s", base_url)
logger.info("Using API base path: %s", a2a_base_path)
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
class SendMessageUser(HttpUser):
    """Simulates a user interacting with the send message API.

    Each task sends one A2A ``message:send`` request and then polls
    ``tasks/{id}`` until the task completes, ends in a non-success state, or
    the poll budget is exhausted. Completion is reported as a custom ``E2E``
    request measuring send-to-completion latency. Exhausting the budget is
    reported as a ``TIMEOUT`` request.
    """

    wait_time = between(1, 3)  # Wait 1-3 seconds between tasks
    host = base_url  # Set the base host URL for Locust

    # Poll budget; subclasses may override these.
    max_polls = 20  # Maximum number of poll attempts
    poll_interval = 0.5  # Seconds between polls

    # States meaning "not finished yet, keep polling". A freshly created task
    # may still be SUBMITTED before it becomes WORKING; that is not a failure.
    _IN_PROGRESS_STATES = frozenset({"TASK_STATE_SUBMITTED", "TASK_STATE_WORKING"})

    @task
    def send_message_and_poll(self) -> None:
        """Simulates a chat interaction: sends a message and polls for completion."""
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.environ['_AUTH_TOKEN']}",
        }

        e2e_start_time = time.time()
        task_id = self._send_message(headers)
        if task_id is None:
            return
        self._poll_until_complete(task_id, headers, e2e_start_time)

    def _send_message(self, headers: dict[str, str]) -> str | None:
        """Send one A2A message; return the created task ID, or None on failure.

        Any non-200 status, undecodable body, or body without ``task.id`` is
        recorded as a failure of the send request.
        """
        data = {
            "message": {
                # A2A message IDs identify a single message, so generate a
                # fresh one instead of sharing a constant across all users.
                "messageId": f"msg-{uuid.uuid4()}",
                "content": [{"text": "Hello! What's the weather in New York?"}],
                "role": "ROLE_USER",
            }
        }

        with self.client.post(
            f"{a2a_base_path}/message:send",
            headers=headers,
            json=data,
            catch_response=True,
            name="/v1/message:send",
        ) as response:
            if response.status_code != 200:
                response.failure(
                    f"Send failed with status code: {response.status_code}"
                )
                return None

            try:
                task_id = response.json()["task"]["id"]
            except (ValueError, KeyError, TypeError) as e:
                # A 200 without a usable task ID is still a failed send; do not
                # let it be counted as a success in the stats.
                logger.error("Failed to extract task ID: %s", e)
                response.failure(f"Invalid response format: {e}")
                return None

            response.success()
            return task_id

    def _poll_until_complete(
        self, task_id: str, headers: dict[str, str], e2e_start_time: float
    ) -> None:
        """Poll the task until it completes, fails, or the poll budget runs out."""
        for poll_count in range(1, self.max_polls + 1):
            time.sleep(self.poll_interval)

            with self.client.get(
                f"{a2a_base_path}/tasks/{task_id}",
                headers=headers,
                catch_response=True,
                name="/v1/tasks/{id}",
            ) as poll_response:
                if poll_response.status_code != 200:
                    poll_response.failure(
                        f"Poll failed with status code: {poll_response.status_code}"
                    )
                    return

                try:
                    poll_data = poll_response.json()
                    task_state = poll_data["status"]["state"]
                except (ValueError, KeyError, TypeError) as e:
                    logger.error("Failed to extract task state: %s", e)
                    poll_response.failure(f"Invalid response format: {e}")
                    return

                if task_state == "TASK_STATE_COMPLETED":
                    poll_response.success()

                    # Fire custom event for end-to-end metrics
                    e2e_duration = (time.time() - e2e_start_time) * 1000
                    self.environment.events.request.fire(
                        request_type="E2E",
                        name="message:send_and_complete",
                        response_time=e2e_duration,
                        response_length=len(json.dumps(poll_data)),
                        response=poll_response,
                        context={"poll_count": poll_count},
                    )
                    return

                if task_state not in self._IN_PROGRESS_STATES:
                    poll_response.failure(f"Task failed with state: {task_state}")
                    return

                poll_response.success()

        # Timeout - task didn't complete in time
        self.environment.events.request.fire(
            request_type="TIMEOUT",
            name="message:timeout",
            response_time=(time.time() - e2e_start_time) * 1000,
            response_length=0,
            response=None,
            context={"poll_count": self.max_polls},
            exception=TimeoutError(
                f"Task did not complete after {self.max_polls} polls"
            ),
        )
|
|
284
|
+
{%- else %}
|
|
142
285
|
|
|
143
286
|
import json
|
|
144
287
|
import logging
|
|
@@ -222,6 +365,7 @@ class ChatStreamUser(HttpUser):
|
|
|
222
365
|
) as response:
|
|
223
366
|
if response.status_code == 200:
|
|
224
367
|
events = []
|
|
368
|
+
has_error = False
|
|
225
369
|
for line in response.iter_lines():
|
|
226
370
|
if line:
|
|
227
371
|
line_str = line.decode("utf-8")
|
|
@@ -236,20 +380,45 @@ class ChatStreamUser(HttpUser):
|
|
|
236
380
|
response=response,
|
|
237
381
|
context={},
|
|
238
382
|
)
|
|
383
|
+
|
|
384
|
+
# Check for error responses in the JSON payload
|
|
385
|
+
try:
|
|
386
|
+
event_data = json.loads(line_str)
|
|
387
|
+
if isinstance(event_data, dict) and "code" in event_data:
|
|
388
|
+
# Flag any non-2xx codes as errors
|
|
389
|
+
if event_data["code"] >= 400:
|
|
390
|
+
has_error = True
|
|
391
|
+
error_msg = event_data.get(
|
|
392
|
+
"message", "Unknown error"
|
|
393
|
+
)
|
|
394
|
+
response.failure(f"Error in response: {error_msg}")
|
|
395
|
+
logger.error(
|
|
396
|
+
"Received error response: code=%s, message=%s",
|
|
397
|
+
event_data["code"],
|
|
398
|
+
error_msg,
|
|
399
|
+
)
|
|
400
|
+
except json.JSONDecodeError:
|
|
401
|
+
# If it's not valid JSON, continue processing
|
|
402
|
+
pass
|
|
403
|
+
|
|
239
404
|
end_time = time.time()
|
|
240
405
|
total_time = end_time - start_time
|
|
241
|
-
|
|
242
|
-
|
|
406
|
+
|
|
407
|
+
# Only fire success event if no errors were found
|
|
408
|
+
if not has_error:
|
|
409
|
+
self.environment.events.request.fire(
|
|
410
|
+
request_type="POST",
|
|
243
411
|
{%- if cookiecutter.is_adk %}
|
|
244
|
-
|
|
412
|
+
name="/streamQuery end",
|
|
245
413
|
{%- else %}
|
|
246
|
-
|
|
414
|
+
name="/stream_messages end",
|
|
247
415
|
{%- endif %}
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
416
|
+
response_time=total_time * 1000, # Convert to milliseconds
|
|
417
|
+
response_length=len(events),
|
|
418
|
+
response=response,
|
|
419
|
+
context={},
|
|
420
|
+
)
|
|
253
421
|
else:
|
|
254
422
|
response.failure(f"Unexpected status code: {response.status_code}")
|
|
255
423
|
{%- endif %}
|
|
424
|
+
{%- endif %}
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
# mypy: disable-error-code="arg-type"
|
|
16
|
+
|
|
17
|
+
"""Helper functions for testing AgentEngineApp with A2A protocol."""
|
|
18
|
+
|
|
19
|
+
import asyncio
|
|
20
|
+
import json
|
|
21
|
+
from collections.abc import Awaitable, Callable
|
|
22
|
+
from typing import TYPE_CHECKING, Any
|
|
23
|
+
|
|
24
|
+
from starlette.requests import Request
|
|
25
|
+
|
|
26
|
+
if TYPE_CHECKING:
|
|
27
|
+
from {{cookiecutter.agent_directory}}.agent_engine_app import AgentEngineApp
|
|
28
|
+
|
|
29
|
+
# Test constants
# Default polling budget for poll_task_completion: up to POLL_MAX_ATTEMPTS
# status checks, spaced POLL_INTERVAL_SECONDS apart.
POLL_MAX_ATTEMPTS: int = 30
POLL_INTERVAL_SECONDS: float = 1.0
# Bucket name for test artifacts; presumably a placeholder rather than a real
# bucket — confirm against the tests that use it.
TEST_ARTIFACTS_BUCKET: str = "test-artifacts-bucket"
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def receive_wrapper(data: dict[str, Any] | None) -> Callable[[], Awaitable[dict]]:
|
|
36
|
+
"""Creates a mock ASGI receive callable for testing.
|
|
37
|
+
|
|
38
|
+
Args:
|
|
39
|
+
data: Dictionary to encode as JSON request body
|
|
40
|
+
|
|
41
|
+
Returns:
|
|
42
|
+
Async callable that returns mock ASGI receive message
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
async def receive() -> dict:
|
|
46
|
+
byte_data = json.dumps(data).encode("utf-8")
|
|
47
|
+
return {"type": "http.request", "body": byte_data, "more_body": False}
|
|
48
|
+
|
|
49
|
+
return receive
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def build_post_request(
    data: dict[str, Any] | None = None, path_params: dict[str, str] | None = None
) -> Request:
    """Builds a mock Starlette Request object for a POST request with JSON data.

    Args:
        data: JSON data to include in request body
        path_params: Path parameters to include in request scope

    Returns:
        Mock Starlette Request object
    """
    # ``method``, ``path``, ``query_string`` and ``headers`` are required keys
    # of an ASGI HTTP scope. Starlette reads them lazily (``request.method``,
    # ``request.url``, ``request.query_params``), so leaving any out only
    # surfaces as a KeyError when a handler happens to touch it.
    scope: dict[str, Any] = {
        "type": "http",
        "http_version": "1.1",
        "method": "POST",
        "path": "/",
        "query_string": b"",
        "headers": [(b"content-type", b"application/json")],
        "app": None,
    }
    if path_params:
        scope["path_params"] = path_params
    receiver = receive_wrapper(data)
    return Request(scope, receiver)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def build_get_request(path_params: dict[str, str] | None) -> Request:
    """Builds a mock Starlette Request object for a GET request.

    Args:
        path_params: Path parameters to include in request scope

    Returns:
        Mock Starlette Request object whose body is empty: its ``receive``
        callable immediately reports ``http.disconnect``.
    """
    # Populate every required ASGI HTTP scope key; previously ``headers`` was
    # missing, so touching ``request.headers`` raised KeyError.
    scope: dict[str, Any] = {
        "type": "http",
        "http_version": "1.1",
        "method": "GET",
        "path": "/",
        "query_string": b"",
        "headers": [],
        "app": None,
    }
    if path_params:
        scope["path_params"] = path_params

    async def receive() -> dict:
        return {"type": "http.disconnect"}

    return Request(scope, receive)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
async def poll_task_completion(
    agent_app: "AgentEngineApp",
    task_id: str,
    max_attempts: int = POLL_MAX_ATTEMPTS,
    interval: float = POLL_INTERVAL_SECONDS,
) -> dict[str, Any]:
    """Poll for task completion and return final response.

    Args:
        agent_app: The AgentEngineApp instance to poll
        task_id: The task ID to poll for
        max_attempts: Maximum number of polling attempts
        interval: Seconds to wait between polls

    Returns:
        Final task response when completed

    Raises:
        AssertionError: If the task ends in a terminal non-success state
            (failed, cancelled or rejected) or does not complete within
            ``max_attempts`` polls.
    """
    # Terminal states that can never become COMPLETED. Continuing to poll them
    # would only burn the whole timeout before failing with a vaguer message.
    terminal_failure_states = {
        "TASK_STATE_FAILED",
        "TASK_STATE_CANCELLED",
        "TASK_STATE_REJECTED",
    }

    for attempt in range(max_attempts):
        poll_request = build_get_request({"id": task_id})
        response = await agent_app.on_get_task(
            request=poll_request,
            context=None,
        )

        task_state = response.get("status", {}).get("state", "")

        if task_state == "TASK_STATE_COMPLETED":
            return response
        if task_state in terminal_failure_states:
            raise AssertionError(f"Task failed: {response}")

        # Skip the pointless sleep after the final attempt.
        if attempt < max_attempts - 1:
            await asyncio.sleep(interval)

    raise AssertionError(
        f"Task did not complete within {max_attempts * interval} seconds"
    )
|