agent-starter-pack 0.18.2__py3-none-any.whl → 0.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of agent-starter-pack might be problematic. Click here for more details.
- agent_starter_pack/agents/adk_a2a_base/.template/templateconfig.yaml +22 -0
- agent_starter_pack/agents/adk_a2a_base/README.md +22 -0
- agent_starter_pack/agents/adk_a2a_base/app/__init__.py +17 -0
- agent_starter_pack/agents/adk_a2a_base/app/agent.py +70 -0
- agent_starter_pack/agents/adk_a2a_base/notebooks/adk_a2a_app_testing.ipynb +600 -0
- agent_starter_pack/agents/adk_a2a_base/notebooks/evaluating_adk_agent.ipynb +1535 -0
- agent_starter_pack/agents/adk_a2a_base/tests/integration/test_agent.py +58 -0
- agent_starter_pack/base_template/.gitignore +1 -1
- agent_starter_pack/base_template/Makefile +11 -6
- agent_starter_pack/base_template/README.md +1 -1
- agent_starter_pack/base_template/{% if cookiecutter.cicd_runner == 'github_actions' %}.github{% else %}unused_github{% endif %}/workflows/deploy-to-prod.yaml +10 -2
- agent_starter_pack/base_template/{% if cookiecutter.cicd_runner == 'github_actions' %}.github{% else %}unused_github{% endif %}/workflows/staging.yaml +26 -5
- agent_starter_pack/base_template/{% if cookiecutter.cicd_runner == 'google_cloud_build' %}.cloudbuild{% else %}unused_.cloudbuild{% endif %}/deploy-to-prod.yaml +18 -3
- agent_starter_pack/base_template/{% if cookiecutter.cicd_runner == 'google_cloud_build' %}.cloudbuild{% else %}unused_.cloudbuild{% endif %}/staging.yaml +34 -3
- agent_starter_pack/cli/utils/cicd.py +20 -4
- agent_starter_pack/cli/utils/register_gemini_enterprise.py +79 -84
- agent_starter_pack/cli/utils/template.py +2 -0
- agent_starter_pack/deployment_targets/agent_engine/tests/integration/test_agent_engine_app.py +104 -2
- agent_starter_pack/deployment_targets/agent_engine/tests/load_test/load_test.py +144 -0
- agent_starter_pack/deployment_targets/agent_engine/tests/{% if cookiecutter.is_adk_a2a %}helpers.py{% else %}unused_helpers.py{% endif %} +138 -0
- agent_starter_pack/deployment_targets/agent_engine/{{cookiecutter.agent_directory}}/agent_engine_app.py +88 -4
- agent_starter_pack/deployment_targets/agent_engine/{{cookiecutter.agent_directory}}/utils/deployment.py +4 -0
- agent_starter_pack/deployment_targets/cloud_run/Dockerfile +3 -0
- agent_starter_pack/deployment_targets/cloud_run/deployment/terraform/dev/service.tf +7 -0
- agent_starter_pack/deployment_targets/cloud_run/deployment/terraform/service.tf +16 -2
- agent_starter_pack/deployment_targets/cloud_run/tests/integration/test_server_e2e.py +218 -1
- agent_starter_pack/deployment_targets/cloud_run/tests/load_test/README.md +2 -2
- agent_starter_pack/deployment_targets/cloud_run/tests/load_test/load_test.py +51 -4
- agent_starter_pack/deployment_targets/cloud_run/{{cookiecutter.agent_directory}}/server.py +66 -0
- agent_starter_pack/resources/locks/uv-adk_a2a_base-agent_engine.lock +4224 -0
- agent_starter_pack/resources/locks/uv-adk_a2a_base-cloud_run.lock +4819 -0
- agent_starter_pack/resources/locks/uv-adk_base-agent_engine.lock +230 -236
- agent_starter_pack/resources/locks/uv-adk_base-cloud_run.lock +290 -296
- agent_starter_pack/resources/locks/uv-adk_live-agent_engine.lock +230 -236
- agent_starter_pack/resources/locks/uv-adk_live-cloud_run.lock +290 -296
- agent_starter_pack/resources/locks/uv-agentic_rag-agent_engine.lock +234 -239
- agent_starter_pack/resources/locks/uv-agentic_rag-cloud_run.lock +294 -299
- agent_starter_pack/resources/locks/uv-crewai_coding_crew-agent_engine.lock +221 -228
- agent_starter_pack/resources/locks/uv-crewai_coding_crew-cloud_run.lock +279 -286
- agent_starter_pack/resources/locks/uv-langgraph_base_react-agent_engine.lock +226 -233
- agent_starter_pack/resources/locks/uv-langgraph_base_react-cloud_run.lock +298 -305
- {agent_starter_pack-0.18.2.dist-info → agent_starter_pack-0.19.0.dist-info}/METADATA +2 -1
- {agent_starter_pack-0.18.2.dist-info → agent_starter_pack-0.19.0.dist-info}/RECORD +46 -36
- {agent_starter_pack-0.18.2.dist-info → agent_starter_pack-0.19.0.dist-info}/WHEEL +0 -0
- {agent_starter_pack-0.18.2.dist-info → agent_starter_pack-0.19.0.dist-info}/entry_points.txt +0 -0
- {agent_starter_pack-0.18.2.dist-info → agent_starter_pack-0.19.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -226,6 +226,20 @@ from typing import Any
|
|
|
226
226
|
|
|
227
227
|
import pytest
|
|
228
228
|
import requests
|
|
229
|
+
{%- if cookiecutter.is_adk_a2a %}
|
|
230
|
+
from a2a.types import (
|
|
231
|
+
JSONRPCErrorResponse,
|
|
232
|
+
Message,
|
|
233
|
+
MessageSendParams,
|
|
234
|
+
Part,
|
|
235
|
+
Role,
|
|
236
|
+
SendMessageRequest,
|
|
237
|
+
SendMessageResponse,
|
|
238
|
+
SendStreamingMessageRequest,
|
|
239
|
+
SendStreamingMessageResponse,
|
|
240
|
+
TextPart,
|
|
241
|
+
)
|
|
242
|
+
{%- endif %}
|
|
229
243
|
from requests.exceptions import RequestException
|
|
230
244
|
|
|
231
245
|
# Configure logging
|
|
@@ -233,7 +247,10 @@ logging.basicConfig(level=logging.INFO)
|
|
|
233
247
|
logger = logging.getLogger(__name__)
|
|
234
248
|
|
|
235
249
|
BASE_URL = "http://127.0.0.1:8000/"
|
|
236
|
-
{%- if cookiecutter.
|
|
250
|
+
{%- if cookiecutter.is_adk_a2a %}
|
|
251
|
+
A2A_RPC_URL = BASE_URL + "a2a/{{cookiecutter.agent_directory}}/"
|
|
252
|
+
AGENT_CARD_URL = A2A_RPC_URL + ".well-known/agent-card.json"
|
|
253
|
+
{%- elif cookiecutter.is_adk %}
|
|
237
254
|
STREAM_URL = BASE_URL + "run_sse"
|
|
238
255
|
{%- else %}
|
|
239
256
|
STREAM_URL = BASE_URL + "stream_messages"
|
|
@@ -292,7 +309,11 @@ def wait_for_server(timeout: int = 90, interval: int = 1) -> bool:
|
|
|
292
309
|
start_time = time.time()
|
|
293
310
|
while time.time() - start_time < timeout:
|
|
294
311
|
try:
|
|
312
|
+
{%- if cookiecutter.is_adk_a2a %}
|
|
313
|
+
response = requests.get(AGENT_CARD_URL, timeout=10)
|
|
314
|
+
{%- else %}
|
|
295
315
|
response = requests.get("http://127.0.0.1:8000/docs", timeout=10)
|
|
316
|
+
{%- endif %}
|
|
296
317
|
if response.status_code == 200:
|
|
297
318
|
logger.info("Server is ready")
|
|
298
319
|
return True
|
|
@@ -323,6 +344,82 @@ def server_fixture(request: Any) -> Iterator[subprocess.Popen[str]]:
|
|
|
323
344
|
|
|
324
345
|
|
|
325
346
|
def test_chat_stream(server_fixture: subprocess.Popen[str]) -> None:
|
|
347
|
+
{%- if cookiecutter.is_adk_a2a %}
|
|
348
|
+
"""Test the chat stream functionality using A2A JSON-RPC protocol."""
|
|
349
|
+
logger.info("Starting chat stream test")
|
|
350
|
+
|
|
351
|
+
message = Message(
|
|
352
|
+
message_id=f"msg-user-{uuid.uuid4()}",
|
|
353
|
+
role=Role.user,
|
|
354
|
+
parts=[Part(root=TextPart(text="What's the weather in San Francisco?"))],
|
|
355
|
+
)
|
|
356
|
+
|
|
357
|
+
request = SendStreamingMessageRequest(
|
|
358
|
+
id="test-req-001",
|
|
359
|
+
params=MessageSendParams(message=message),
|
|
360
|
+
)
|
|
361
|
+
|
|
362
|
+
# Send the request
|
|
363
|
+
response = requests.post(
|
|
364
|
+
A2A_RPC_URL,
|
|
365
|
+
headers=HEADERS,
|
|
366
|
+
json=request.model_dump(mode="json", exclude_none=True),
|
|
367
|
+
stream=True,
|
|
368
|
+
timeout=60,
|
|
369
|
+
)
|
|
370
|
+
assert response.status_code == 200
|
|
371
|
+
|
|
372
|
+
# Parse streaming JSON-RPC responses
|
|
373
|
+
responses: list[SendStreamingMessageResponse] = []
|
|
374
|
+
|
|
375
|
+
for line in response.iter_lines():
|
|
376
|
+
if line:
|
|
377
|
+
line_str = line.decode("utf-8")
|
|
378
|
+
if line_str.startswith("data: "):
|
|
379
|
+
event_json = line_str[6:]
|
|
380
|
+
json_data = json.loads(event_json)
|
|
381
|
+
streaming_response = SendStreamingMessageResponse.model_validate(
|
|
382
|
+
json_data
|
|
383
|
+
)
|
|
384
|
+
responses.append(streaming_response)
|
|
385
|
+
|
|
386
|
+
assert responses, "No responses received from stream"
|
|
387
|
+
|
|
388
|
+
# Check for final status update
|
|
389
|
+
final_responses = [
|
|
390
|
+
r.root
|
|
391
|
+
for r in responses
|
|
392
|
+
if hasattr(r.root, "result")
|
|
393
|
+
and hasattr(r.root.result, "final")
|
|
394
|
+
and r.root.result.final is True
|
|
395
|
+
]
|
|
396
|
+
assert final_responses, "No final response received"
|
|
397
|
+
|
|
398
|
+
final_response = final_responses[-1]
|
|
399
|
+
assert final_response.result.kind == "status-update"
|
|
400
|
+
assert hasattr(final_response.result, "status")
|
|
401
|
+
assert final_response.result.status.state == "completed"
|
|
402
|
+
|
|
403
|
+
# Check for artifact content
|
|
404
|
+
artifact_responses = [
|
|
405
|
+
r.root
|
|
406
|
+
for r in responses
|
|
407
|
+
if hasattr(r.root, "result") and r.root.result.kind == "artifact-update"
|
|
408
|
+
]
|
|
409
|
+
assert artifact_responses, "No artifact content received in stream"
|
|
410
|
+
|
|
411
|
+
# Verify text content is in the artifact
|
|
412
|
+
artifact_response = artifact_responses[-1]
|
|
413
|
+
assert hasattr(artifact_response.result, "artifact")
|
|
414
|
+
artifact = artifact_response.result.artifact
|
|
415
|
+
assert artifact.parts, "Artifact has no parts"
|
|
416
|
+
|
|
417
|
+
has_text = any(
|
|
418
|
+
part.root.kind == "text" and hasattr(part.root, "text") and part.root.text
|
|
419
|
+
for part in artifact.parts
|
|
420
|
+
)
|
|
421
|
+
assert has_text, "No text content found in artifact"
|
|
422
|
+
{%- else %}
|
|
326
423
|
"""Test the chat stream functionality."""
|
|
327
424
|
logger.info("Starting chat stream test")
|
|
328
425
|
{% if cookiecutter.is_adk %}
|
|
@@ -417,6 +514,94 @@ def test_chat_stream(server_fixture: subprocess.Popen[str]) -> None:
|
|
|
417
514
|
break
|
|
418
515
|
assert has_content, "At least one message should have content"
|
|
419
516
|
{%- endif %}
|
|
517
|
+
{%- endif %}
|
|
518
|
+
|
|
519
|
+
|
|
520
|
+
{%- if cookiecutter.is_adk_a2a %}
|
|
521
|
+
|
|
522
|
+
|
|
523
|
+
def test_chat_non_streaming(server_fixture: subprocess.Popen[str]) -> None:
|
|
524
|
+
"""Test the non-streaming chat functionality using A2A JSON-RPC protocol."""
|
|
525
|
+
logger.info("Starting non-streaming chat test")
|
|
526
|
+
|
|
527
|
+
message = Message(
|
|
528
|
+
message_id=f"msg-user-{uuid.uuid4()}",
|
|
529
|
+
role=Role.user,
|
|
530
|
+
parts=[Part(root=TextPart(text="What's the weather in San Francisco?"))],
|
|
531
|
+
)
|
|
532
|
+
|
|
533
|
+
request = SendMessageRequest(
|
|
534
|
+
id="test-req-002",
|
|
535
|
+
params=MessageSendParams(message=message),
|
|
536
|
+
)
|
|
537
|
+
|
|
538
|
+
response = requests.post(
|
|
539
|
+
A2A_RPC_URL,
|
|
540
|
+
headers=HEADERS,
|
|
541
|
+
json=request.model_dump(mode="json", exclude_none=True),
|
|
542
|
+
timeout=60,
|
|
543
|
+
)
|
|
544
|
+
assert response.status_code == 200
|
|
545
|
+
|
|
546
|
+
# Parse the single JSON-RPC response
|
|
547
|
+
response_data = response.json()
|
|
548
|
+
message_response = SendMessageResponse.model_validate(response_data)
|
|
549
|
+
logger.info(f"Received response: {message_response}")
|
|
550
|
+
|
|
551
|
+
# For non-streaming, the result is a Task object
|
|
552
|
+
json_rpc_resp = message_response.root
|
|
553
|
+
assert hasattr(json_rpc_resp, "result")
|
|
554
|
+
task = json_rpc_resp.result
|
|
555
|
+
assert task.kind == "task"
|
|
556
|
+
assert hasattr(task, "status")
|
|
557
|
+
assert task.status.state == "completed"
|
|
558
|
+
|
|
559
|
+
# Check that we got artifacts (the final agent output)
|
|
560
|
+
assert hasattr(task, "artifacts")
|
|
561
|
+
assert task.artifacts, "No artifacts in task"
|
|
562
|
+
|
|
563
|
+
# Verify we got text content in the artifact
|
|
564
|
+
artifact = task.artifacts[0]
|
|
565
|
+
assert artifact.parts, "Artifact has no parts"
|
|
566
|
+
|
|
567
|
+
has_text = any(
|
|
568
|
+
part.root.kind == "text" and hasattr(part.root, "text") and part.root.text
|
|
569
|
+
for part in artifact.parts
|
|
570
|
+
)
|
|
571
|
+
assert has_text, "No text content found in artifact"
|
|
572
|
+
|
|
573
|
+
|
|
574
|
+
def test_chat_stream_error_handling(server_fixture: subprocess.Popen[str]) -> None:
|
|
575
|
+
"""Test the chat stream error handling with invalid A2A request."""
|
|
576
|
+
logger.info("Starting chat stream error handling test")
|
|
577
|
+
|
|
578
|
+
invalid_data = {
|
|
579
|
+
"jsonrpc": "2.0",
|
|
580
|
+
"id": "test-error-001",
|
|
581
|
+
"method": "message/send",
|
|
582
|
+
"params": {
|
|
583
|
+
"message": {
|
|
584
|
+
"role": "user",
|
|
585
|
+
# Missing required 'parts' field
|
|
586
|
+
"messageId": f"msg-user-{uuid.uuid4()}",
|
|
587
|
+
}
|
|
588
|
+
},
|
|
589
|
+
}
|
|
590
|
+
|
|
591
|
+
response = requests.post(
|
|
592
|
+
A2A_RPC_URL, headers=HEADERS, json=invalid_data, timeout=10
|
|
593
|
+
)
|
|
594
|
+
assert response.status_code == 200
|
|
595
|
+
|
|
596
|
+
response_data = response.json()
|
|
597
|
+
error_response = JSONRPCErrorResponse.model_validate(response_data)
|
|
598
|
+
assert "error" in response_data, "Expected JSON-RPC error in response"
|
|
599
|
+
|
|
600
|
+
# Assert error for invalid parameters
|
|
601
|
+
assert error_response.error.code == -32602
|
|
602
|
+
|
|
603
|
+
logger.info("Error handling test completed successfully")
|
|
604
|
+
{%- else %}
|
|
420
605
|
|
|
421
606
|
|
|
422
607
|
def test_chat_stream_error_handling(server_fixture: subprocess.Popen[str]) -> None:
|
|
@@ -433,6 +618,7 @@ def test_chat_stream_error_handling(server_fixture: subprocess.Popen[str]) -> No
|
|
|
433
618
|
f"Expected status code 422, got {response.status_code}"
|
|
434
619
|
)
|
|
435
620
|
logger.info("Error handling test completed successfully")
|
|
621
|
+
{%- endif %}
|
|
436
622
|
|
|
437
623
|
|
|
438
624
|
def test_collect_feedback(server_fixture: subprocess.Popen[str]) -> None:
|
|
@@ -455,6 +641,37 @@ def test_collect_feedback(server_fixture: subprocess.Popen[str]) -> None:
|
|
|
455
641
|
FEEDBACK_URL, json=feedback_data, headers=HEADERS, timeout=10
|
|
456
642
|
)
|
|
457
643
|
assert response.status_code == 200
|
|
644
|
+
|
|
645
|
+
|
|
646
|
+
{%- if cookiecutter.is_adk_a2a %}
|
|
647
|
+
|
|
648
|
+
|
|
649
|
+
def test_a2a_agent_json_generation(server_fixture: subprocess.Popen[str]) -> None:
|
|
650
|
+
"""
|
|
651
|
+
Test that the agent.json file is automatically generated and served correctly
|
|
652
|
+
via the well-known URI.
|
|
653
|
+
"""
|
|
654
|
+
# Verify the A2A endpoint serves the agent card
|
|
655
|
+
response = requests.get(AGENT_CARD_URL, timeout=10)
|
|
656
|
+
assert response.status_code == 200, f"A2A endpoint returned {response.status_code}"
|
|
657
|
+
|
|
658
|
+
# Validate required fields in served agent card
|
|
659
|
+
served_agent_card = response.json()
|
|
660
|
+
required_fields = [
|
|
661
|
+
"name",
|
|
662
|
+
"description",
|
|
663
|
+
"skills",
|
|
664
|
+
"capabilities",
|
|
665
|
+
"url",
|
|
666
|
+
"version",
|
|
667
|
+
]
|
|
668
|
+
for field in required_fields:
|
|
669
|
+
assert field in served_agent_card, (
|
|
670
|
+
f"Missing required field in served agent card: {field}"
|
|
671
|
+
)
|
|
672
|
+
|
|
673
|
+
|
|
674
|
+
{%- endif %}
|
|
458
675
|
{%- if cookiecutter.session_type == "agent_engine" %}
|
|
459
676
|
|
|
460
677
|
|
|
@@ -99,7 +99,7 @@ uv run uvicorn {{cookiecutter.agent_directory}}.server:app --host 0.0.0.0 --port
|
|
|
99
99
|
Using another terminal tab, This is suggested to avoid conflicts with the existing application python environment.
|
|
100
100
|
|
|
101
101
|
```bash
|
|
102
|
-
python3 -m venv .locust_env && source .locust_env/bin/activate && pip install locust==2.31.1
|
|
102
|
+
python3 -m venv .locust_env && source .locust_env/bin/activate && pip install locust==2.31.1{%- if cookiecutter.is_adk_a2a %} a2a-sdk~=0.3.9{%- endif %}
|
|
103
103
|
```
|
|
104
104
|
|
|
105
105
|
**3. Execute the Load Test:**
|
|
@@ -150,7 +150,7 @@ export _ID_TOKEN=$(gcloud auth print-identity-token -q)
|
|
|
150
150
|
**3. Execute the Load Test:**
|
|
151
151
|
Create virtual environment with Locust:
|
|
152
152
|
```bash
|
|
153
|
-
python3 -m venv .locust_env && source .locust_env/bin/activate && pip install locust==2.31.1
|
|
153
|
+
python3 -m venv .locust_env && source .locust_env/bin/activate && pip install locust==2.31.1{%- if cookiecutter.is_adk_a2a %} a2a-sdk~=0.3.9{%- endif %}
|
|
154
154
|
```
|
|
155
155
|
|
|
156
156
|
Execute load tests. The following command executes the same load test parameters as the local test but targets your remote Cloud Run instance.
|
|
@@ -140,7 +140,19 @@ class RemoteAgentUser(WebSocketUser):
|
|
|
140
140
|
|
|
141
141
|
import os
|
|
142
142
|
import time
|
|
143
|
-
{%- if cookiecutter.
|
|
143
|
+
{%- if cookiecutter.is_adk_a2a %}
|
|
144
|
+
import uuid
|
|
145
|
+
|
|
146
|
+
from a2a.types import (
|
|
147
|
+
Message,
|
|
148
|
+
MessageSendParams,
|
|
149
|
+
Part,
|
|
150
|
+
Role,
|
|
151
|
+
SendStreamingMessageRequest,
|
|
152
|
+
TextPart,
|
|
153
|
+
)
|
|
154
|
+
from locust import HttpUser, between, task
|
|
155
|
+
{%- elif cookiecutter.is_adk %}
|
|
144
156
|
import uuid
|
|
145
157
|
|
|
146
158
|
import requests
|
|
@@ -149,11 +161,17 @@ from locust import HttpUser, between, task
|
|
|
149
161
|
|
|
150
162
|
from locust import HttpUser, between, task
|
|
151
163
|
{%- endif %}
|
|
152
|
-
{
|
|
164
|
+
{%- if cookiecutter.is_adk_a2a %}
|
|
165
|
+
|
|
166
|
+
ENDPOINT = "/a2a/{{cookiecutter.agent_directory}}/"
|
|
167
|
+
{%- elif cookiecutter.is_adk %}
|
|
168
|
+
|
|
153
169
|
ENDPOINT = "/run_sse"
|
|
154
|
-
{
|
|
170
|
+
{%- else %}
|
|
171
|
+
|
|
155
172
|
ENDPOINT = "/stream_messages"
|
|
156
|
-
{
|
|
173
|
+
{%- endif %}
|
|
174
|
+
|
|
157
175
|
|
|
158
176
|
class ChatStreamUser(HttpUser):
|
|
159
177
|
"""Simulates a user interacting with the chat stream API."""
|
|
@@ -162,6 +180,34 @@ class ChatStreamUser(HttpUser):
|
|
|
162
180
|
|
|
163
181
|
@task
|
|
164
182
|
def chat_stream(self) -> None:
|
|
183
|
+
{%- if cookiecutter.is_adk_a2a %}
|
|
184
|
+
"""Simulates a chat stream interaction using A2A protocol."""
|
|
185
|
+
headers = {"Content-Type": "application/json"}
|
|
186
|
+
if os.environ.get("_ID_TOKEN"):
|
|
187
|
+
headers["Authorization"] = f"Bearer {os.environ['_ID_TOKEN']}"
|
|
188
|
+
|
|
189
|
+
message = Message(
|
|
190
|
+
message_id=f"msg-user-{uuid.uuid4()}",
|
|
191
|
+
role=Role.user,
|
|
192
|
+
parts=[Part(root=TextPart(text="Hello! What's the weather in New York?"))],
|
|
193
|
+
)
|
|
194
|
+
|
|
195
|
+
request = SendStreamingMessageRequest(
|
|
196
|
+
id=f"req-{uuid.uuid4()}",
|
|
197
|
+
params=MessageSendParams(message=message),
|
|
198
|
+
)
|
|
199
|
+
|
|
200
|
+
start_time = time.time()
|
|
201
|
+
|
|
202
|
+
with self.client.post(
|
|
203
|
+
ENDPOINT,
|
|
204
|
+
name=f"{ENDPOINT} message",
|
|
205
|
+
headers=headers,
|
|
206
|
+
json=request.model_dump(mode="json", exclude_none=True),
|
|
207
|
+
catch_response=True,
|
|
208
|
+
stream=True,
|
|
209
|
+
) as response:
|
|
210
|
+
{%- else %}
|
|
165
211
|
"""Simulates a chat stream interaction."""
|
|
166
212
|
headers = {"Content-Type": "application/json"}
|
|
167
213
|
if os.environ.get("_ID_TOKEN"):
|
|
@@ -218,6 +264,7 @@ class ChatStreamUser(HttpUser):
|
|
|
218
264
|
stream=True,
|
|
219
265
|
params={"alt": "sse"},
|
|
220
266
|
) as response:
|
|
267
|
+
{%- endif %}
|
|
221
268
|
if response.status_code == 200:
|
|
222
269
|
events = []
|
|
223
270
|
for line in response.iter_lines():
|
|
@@ -291,10 +291,28 @@ async def serve_frontend_spa(full_path: str) -> FileResponse:
|
|
|
291
291
|
)
|
|
292
292
|
{% elif cookiecutter.is_adk %}
|
|
293
293
|
import os
|
|
294
|
+
{%- if cookiecutter.is_adk_a2a %}
|
|
295
|
+
from collections.abc import AsyncIterator
|
|
296
|
+
from contextlib import asynccontextmanager
|
|
297
|
+
{%- endif %}
|
|
294
298
|
|
|
295
299
|
import google.auth
|
|
300
|
+
{%- if cookiecutter.is_adk_a2a %}
|
|
301
|
+
from a2a.server.apps import A2AFastAPIApplication
|
|
302
|
+
from a2a.server.request_handlers import DefaultRequestHandler
|
|
303
|
+
from a2a.server.tasks import InMemoryTaskStore
|
|
304
|
+
from a2a.types import AgentCapabilities, AgentCard
|
|
305
|
+
{%- endif %}
|
|
296
306
|
from fastapi import FastAPI
|
|
307
|
+
{%- if cookiecutter.is_adk_a2a %}
|
|
308
|
+
from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor
|
|
309
|
+
from google.adk.a2a.utils.agent_card_builder import AgentCardBuilder
|
|
310
|
+
from google.adk.artifacts.gcs_artifact_service import GcsArtifactService
|
|
311
|
+
from google.adk.runners import Runner
|
|
312
|
+
from google.adk.sessions import InMemorySessionService
|
|
313
|
+
{%- else %}
|
|
297
314
|
from google.adk.cli.fast_api import get_fast_api_app
|
|
315
|
+
{%- endif %}
|
|
298
316
|
from google.cloud import logging as google_cloud_logging
|
|
299
317
|
from opentelemetry import trace
|
|
300
318
|
from opentelemetry.sdk.trace import TracerProvider, export
|
|
@@ -302,6 +320,9 @@ from opentelemetry.sdk.trace import TracerProvider, export
|
|
|
302
320
|
from vertexai import agent_engines
|
|
303
321
|
{%- endif %}
|
|
304
322
|
|
|
323
|
+
{% if cookiecutter.is_adk_a2a -%}
|
|
324
|
+
from {{cookiecutter.agent_directory}}.agent import app as adk_app
|
|
325
|
+
{% endif -%}
|
|
305
326
|
from {{cookiecutter.agent_directory}}.utils.gcs import create_bucket_if_not_exists
|
|
306
327
|
from {{cookiecutter.agent_directory}}.utils.tracing import CloudTraceLoggingSpanExporter
|
|
307
328
|
from {{cookiecutter.agent_directory}}.utils.typing import Feedback
|
|
@@ -309,9 +330,11 @@ from {{cookiecutter.agent_directory}}.utils.typing import Feedback
|
|
|
309
330
|
_, project_id = google.auth.default()
|
|
310
331
|
logging_client = google_cloud_logging.Client()
|
|
311
332
|
logger = logging_client.logger(__name__)
|
|
333
|
+
{%- if not cookiecutter.is_adk_a2a %}
|
|
312
334
|
allow_origins = (
|
|
313
335
|
os.getenv("ALLOW_ORIGINS", "").split(",") if os.getenv("ALLOW_ORIGINS") else None
|
|
314
336
|
)
|
|
337
|
+
{%- endif %}
|
|
315
338
|
|
|
316
339
|
bucket_name = f"gs://{project_id}-{{cookiecutter.project_name}}-logs"
|
|
317
340
|
create_bucket_if_not_exists(
|
|
@@ -323,6 +346,48 @@ processor = export.BatchSpanProcessor(CloudTraceLoggingSpanExporter())
|
|
|
323
346
|
provider.add_span_processor(processor)
|
|
324
347
|
trace.set_tracer_provider(provider)
|
|
325
348
|
|
|
349
|
+
{%- if cookiecutter.is_adk_a2a %}
|
|
350
|
+
|
|
351
|
+
runner = Runner(
|
|
352
|
+
app=adk_app,
|
|
353
|
+
artifact_service=GcsArtifactService(bucket_name=bucket_name),
|
|
354
|
+
session_service=InMemorySessionService(),
|
|
355
|
+
)
|
|
356
|
+
|
|
357
|
+
request_handler = DefaultRequestHandler(
|
|
358
|
+
agent_executor=A2aAgentExecutor(runner=runner), task_store=InMemoryTaskStore()
|
|
359
|
+
)
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
async def build_dynamic_agent_card() -> AgentCard:
|
|
363
|
+
"""Builds the Agent Card dynamically from the root_agent."""
|
|
364
|
+
agent_card_builder = AgentCardBuilder(
|
|
365
|
+
agent=adk_app.root_agent,
|
|
366
|
+
capabilities=AgentCapabilities(streaming=True),
|
|
367
|
+
rpc_url=f"{os.getenv('APP_URL', 'http://0.0.0.0:8000')}/a2a/{adk_app.name}/",
|
|
368
|
+
agent_version=os.getenv("AGENT_VERSION", "0.1.0"),
|
|
369
|
+
)
|
|
370
|
+
agent_card = await agent_card_builder.build()
|
|
371
|
+
return agent_card
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
@asynccontextmanager
|
|
375
|
+
async def lifespan(app_instance: FastAPI) -> AsyncIterator[None]:
|
|
376
|
+
agent_card = await build_dynamic_agent_card()
|
|
377
|
+
a2a_app = A2AFastAPIApplication(
|
|
378
|
+
agent_card=agent_card, http_handler=request_handler
|
|
379
|
+
).build()
|
|
380
|
+
app_instance.mount(f"/a2a/{adk_app.name}", a2a_app)
|
|
381
|
+
yield
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
app = FastAPI(
|
|
385
|
+
title="{{cookiecutter.project_name}}",
|
|
386
|
+
description="API for interacting with the Agent {{cookiecutter.project_name}}",
|
|
387
|
+
lifespan=lifespan,
|
|
388
|
+
)
|
|
389
|
+
{%- else %}
|
|
390
|
+
|
|
326
391
|
AGENT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
327
392
|
|
|
328
393
|
{%- if cookiecutter.session_type == "alloydb" %}
|
|
@@ -366,6 +431,7 @@ app: FastAPI = get_fast_api_app(
|
|
|
366
431
|
)
|
|
367
432
|
app.title = "{{cookiecutter.project_name}}"
|
|
368
433
|
app.description = "API for interacting with the Agent {{cookiecutter.project_name}}"
|
|
434
|
+
{%- endif %}
|
|
369
435
|
{% else %}
|
|
370
436
|
import logging
|
|
371
437
|
import os
|