kagent-adk 0.7.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kagent/adk/__init__.py +8 -0
- kagent/adk/_a2a.py +178 -0
- kagent/adk/_agent_executor.py +335 -0
- kagent/adk/_lifespan.py +36 -0
- kagent/adk/_session_service.py +178 -0
- kagent/adk/_token.py +80 -0
- kagent/adk/artifacts/__init__.py +13 -0
- kagent/adk/artifacts/artifacts_toolset.py +56 -0
- kagent/adk/artifacts/return_artifacts_tool.py +160 -0
- kagent/adk/artifacts/session_path.py +106 -0
- kagent/adk/artifacts/stage_artifacts_tool.py +170 -0
- kagent/adk/cli.py +249 -0
- kagent/adk/converters/__init__.py +0 -0
- kagent/adk/converters/error_mappings.py +60 -0
- kagent/adk/converters/event_converter.py +322 -0
- kagent/adk/converters/part_converter.py +206 -0
- kagent/adk/converters/request_converter.py +35 -0
- kagent/adk/models/__init__.py +3 -0
- kagent/adk/models/_openai.py +564 -0
- kagent/adk/models/_ssl.py +245 -0
- kagent/adk/sandbox_code_executer.py +77 -0
- kagent/adk/skill_fetcher.py +103 -0
- kagent/adk/tools/README.md +217 -0
- kagent/adk/tools/__init__.py +15 -0
- kagent/adk/tools/bash_tool.py +74 -0
- kagent/adk/tools/file_tools.py +192 -0
- kagent/adk/tools/skill_tool.py +104 -0
- kagent/adk/tools/skills_plugin.py +49 -0
- kagent/adk/tools/skills_toolset.py +68 -0
- kagent/adk/types.py +268 -0
- kagent_adk-0.7.11.dist-info/METADATA +35 -0
- kagent_adk-0.7.11.dist-info/RECORD +34 -0
- kagent_adk-0.7.11.dist-info/WHEEL +4 -0
- kagent_adk-0.7.11.dist-info/entry_points.txt +2 -0
kagent/adk/__init__.py
ADDED
kagent/adk/_a2a.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
#! /usr/bin/env python3
|
|
2
|
+
import faulthandler
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
from typing import Any, Callable, List, Optional
|
|
6
|
+
|
|
7
|
+
import httpx
|
|
8
|
+
from a2a.server.apps import A2AFastAPIApplication
|
|
9
|
+
from a2a.server.request_handlers import DefaultRequestHandler
|
|
10
|
+
from a2a.server.tasks import InMemoryTaskStore
|
|
11
|
+
from a2a.types import AgentCard
|
|
12
|
+
from agentsts.adk import ADKSTSIntegration, ADKTokenPropagationPlugin
|
|
13
|
+
from fastapi import FastAPI, Request
|
|
14
|
+
from fastapi.responses import PlainTextResponse
|
|
15
|
+
from google.adk.agents import BaseAgent
|
|
16
|
+
from google.adk.apps import App
|
|
17
|
+
from google.adk.artifacts import InMemoryArtifactService
|
|
18
|
+
from google.adk.plugins import BasePlugin
|
|
19
|
+
from google.adk.runners import Runner
|
|
20
|
+
from google.adk.sessions import InMemorySessionService
|
|
21
|
+
from google.genai import types
|
|
22
|
+
|
|
23
|
+
from kagent.core.a2a import (
|
|
24
|
+
KAgentRequestContextBuilder,
|
|
25
|
+
KAgentTaskStore,
|
|
26
|
+
get_a2a_max_content_length,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
from ._agent_executor import A2aAgentExecutor, A2aAgentExecutorConfig
|
|
30
|
+
from ._lifespan import LifespanManager
|
|
31
|
+
from ._session_service import KAgentSessionService
|
|
32
|
+
from ._token import KAgentTokenService
|
|
33
|
+
|
|
34
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def health_check(request: Request) -> PlainTextResponse:
    """Answer a health probe with a plain-text ``OK`` body.

    Args:
        request: The incoming request; it is not inspected.

    Returns:
        A plain-text response whose body is ``"OK"``.
    """
    body = "OK"
    return PlainTextResponse(body)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def dump_thread_tracebacks() -> str:
    """Return the tracebacks of all running threads as text.

    ``faulthandler.dump_traceback`` writes directly to a file descriptor, so
    the target must expose a working ``fileno()``. An ``io.StringIO`` buffer
    does not, which is why a real temporary file is used here.

    Returns:
        The traceback dump produced by ``faulthandler``, decoded as UTF-8
        (undecodable bytes are replaced rather than raising).
    """
    import tempfile

    with tempfile.TemporaryFile(mode="w+b") as dump_file:
        faulthandler.dump_traceback(file=dump_file)
        # faulthandler advanced the descriptor's offset; rewind before reading.
        dump_file.seek(0)
        return dump_file.read().decode("utf-8", errors="replace")


def thread_dump(request: Request) -> PlainTextResponse:
    """Return a plain-text dump of every thread's current traceback.

    Args:
        request: The incoming request; it is not inspected.

    Returns:
        A plain-text response containing the ``faulthandler`` traceback dump.
    """
    return PlainTextResponse(dump_thread_tracebacks())
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
# Value of the KAGENT_URL environment variable, read once at import time
# (None when unset). KAgentApp.build() prefers it over its own kagent_url.
kagent_url_override = os.getenv("KAGENT_URL")
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class KAgentApp:
    """Builds a FastAPI application that serves an ADK agent over A2A."""

    def __init__(
        self,
        root_agent_factory: Callable[[], BaseAgent],
        agent_card: AgentCard,
        kagent_url: str,
        app_name: str,
        lifespan: Optional[Callable[[Any], Any]] = None,
        plugins: Optional[List[BasePlugin]] = None,
        stream: bool = False,
    ):
        """Initialize the KAgent application.

        Args:
            root_agent_factory: Root agent factory function that returns a new agent instance
            agent_card: Agent card configuration for A2A protocol
            kagent_url: URL of the KAgent backend server
            app_name: Application name for identification
            lifespan: Optional lifespan function
            plugins: Optional list of plugins; ``None`` is stored as an empty list
            stream: Whether to stream the response
        """
        self.root_agent_factory = root_agent_factory
        self.kagent_url = kagent_url
        self.app_name = app_name
        self.agent_card = agent_card
        self._lifespan = lifespan
        self.plugins = plugins if plugins is not None else []
        self.stream = stream

    def build(self, local: bool = False) -> FastAPI:
        """Assemble the FastAPI application.

        Args:
            local: When True, sessions and tasks are kept in memory and no
                token service or HTTP client is created. When False, sessions
                and tasks are backed by the KAgent services reached through an
                HTTP client whose base URL is ``KAGENT_URL`` (if set) or
                ``kagent_url``.

        Returns:
            A FastAPI app with ``/health``, ``/thread_dump`` and the A2A routes.
        """
        session_service = InMemorySessionService()
        token_service = None
        if not local:
            token_service = KAgentTokenService(self.app_name)
            http_client = httpx.AsyncClient(
                # TODO: add user and agent headers
                base_url=kagent_url_override or self.kagent_url,
                event_hooks=token_service.event_hooks(),
            )
            session_service = KAgentSessionService(http_client)

        def create_runner() -> Runner:
            # A fresh agent and Runner are built on every call; the executor
            # invokes this once per request and closes the Runner afterwards.
            root_agent = self.root_agent_factory()
            adk_app = App(name=self.app_name, root_agent=root_agent, plugins=self.plugins)

            return Runner(
                app=adk_app,
                session_service=session_service,
                artifact_service=InMemoryArtifactService(),
            )

        agent_executor = A2aAgentExecutor(
            runner=create_runner,
            config=A2aAgentExecutorConfig(stream=self.stream),
        )

        task_store = InMemoryTaskStore()
        if not local:
            task_store = KAgentTaskStore(http_client)

        request_context_builder = KAgentRequestContextBuilder(task_store=task_store)
        request_handler = DefaultRequestHandler(
            agent_executor=agent_executor,
            task_store=task_store,
            request_context_builder=request_context_builder,
        )

        max_content_length = get_a2a_max_content_length()
        a2a_app = A2AFastAPIApplication(
            agent_card=self.agent_card,
            http_handler=request_handler,
            max_content_length=max_content_length,
        )

        faulthandler.enable()

        # LifespanManager.add ignores None, so an unset user lifespan is fine.
        lifespan_manager = LifespanManager()
        lifespan_manager.add(self._lifespan)
        if not local:
            lifespan_manager.add(token_service.lifespan())

        app = FastAPI(lifespan=lifespan_manager)

        # Health check/readiness probe
        app.add_route("/health", methods=["GET"], route=health_check)
        app.add_route("/thread_dump", methods=["GET"], route=thread_dump)
        a2a_app.add_routes_to_app(app)

        return app

    async def test(self, task: str):
        """Run a single query against the agent in memory and log every event.

        Args:
            task: The user query text sent to the agent.
        """
        session_service = InMemorySessionService()
        session_id = "12345"
        user_id = "admin"
        await session_service.create_session(
            app_name=self.app_name,
            session_id=session_id,
            user_id=user_id,
        )

        root_agent = self.root_agent_factory()
        runner = Runner(
            agent=root_agent,
            app_name=self.app_name,
            session_service=session_service,
            artifact_service=InMemoryArtifactService(),
        )

        logger.info("\n>>> User Query: %s", task)

        # Prepare the user's message in ADK format
        content = types.Content(role="user", parts=[types.Part(text=task)])
        try:
            # run_async executes the agent logic and yields events; each one
            # is logged as JSON.
            async for event in runner.run_async(
                user_id=user_id,
                session_id=session_id,
                new_message=content,
            ):
                logger.info(" [Event] %s", event.model_dump_json())
        finally:
            # Close the runner, as A2aAgentExecutor.execute does, so the
            # resources it holds are released even when the run fails.
            await runner.close()
|
|
@@ -0,0 +1,335 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
import logging
|
|
5
|
+
import uuid
|
|
6
|
+
from datetime import datetime, timezone
|
|
7
|
+
from typing import Any, Awaitable, Callable, Optional
|
|
8
|
+
|
|
9
|
+
from a2a.server.agent_execution import AgentExecutor
|
|
10
|
+
from a2a.server.agent_execution.context import RequestContext
|
|
11
|
+
from a2a.server.events.event_queue import EventQueue
|
|
12
|
+
from a2a.types import (
|
|
13
|
+
Artifact,
|
|
14
|
+
Message,
|
|
15
|
+
Part,
|
|
16
|
+
Role,
|
|
17
|
+
TaskArtifactUpdateEvent,
|
|
18
|
+
TaskState,
|
|
19
|
+
TaskStatus,
|
|
20
|
+
TaskStatusUpdateEvent,
|
|
21
|
+
TextPart,
|
|
22
|
+
)
|
|
23
|
+
from google.adk.events import Event, EventActions
|
|
24
|
+
from google.adk.runners import Runner
|
|
25
|
+
from google.adk.utils.context_utils import Aclosing
|
|
26
|
+
from opentelemetry import trace
|
|
27
|
+
from pydantic import BaseModel
|
|
28
|
+
from typing_extensions import override
|
|
29
|
+
|
|
30
|
+
from kagent.core.a2a import TaskResultAggregator, get_kagent_metadata_key
|
|
31
|
+
from kagent.core.tracing._span_processor import (
|
|
32
|
+
clear_kagent_span_attributes,
|
|
33
|
+
set_kagent_span_attributes,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
from .converters.event_converter import convert_event_to_a2a_events
|
|
37
|
+
from .converters.request_converter import convert_a2a_request_to_adk_run_args
|
|
38
|
+
|
|
39
|
+
# Module-level logger; its name is this module's __name__ prefixed with "kagent_adk.".
logger = logging.getLogger("kagent_adk." + __name__)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class A2aAgentExecutorConfig(BaseModel):
    """Configuration for the A2aAgentExecutor."""

    # Forwarded as the ``stream`` argument when the executor converts an A2A
    # request into ADK run arguments.
    stream: bool = False
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
# This class is a copy of the A2aAgentExecutor class in the ADK sdk,
|
|
49
|
+
# with the following changes:
|
|
50
|
+
# - The runner is ALWAYS a callable that returns a Runner instance
|
|
51
|
+
# - The runner is cleaned up at the end of the execution
|
|
52
|
+
class A2aAgentExecutor(AgentExecutor):
    """An AgentExecutor that runs an ADK Agent against an A2A request and
    publishes updates to an event queue.

    A new Runner is obtained from the configured factory for every request and
    closed when the request finishes.
    """

    def __init__(
        self,
        *,
        runner: Callable[..., Runner | Awaitable[Runner]],
        config: Optional[A2aAgentExecutorConfig] = None,
    ):
        """Store the runner factory and the optional configuration.

        Args:
            runner: Callable returning a Runner, or an awaitable resolving to one.
            config: Optional executor configuration; when omitted, streaming is off.
        """
        super().__init__()
        self._runner = runner
        self._config = config

    async def _resolve_runner(self) -> Runner:
        """Call the runner factory and return the Runner it produces.

        Returns:
            The Runner returned by the factory (awaited first if awaitable).

        Raises:
            TypeError: If the stored runner is not callable, or if the factory
                does not produce a Runner instance.
        """
        if not callable(self._runner):
            raise TypeError(
                f"Runner must be a Runner instance or a callable that returns a Runner, got {type(self._runner)}"
            )

        result = self._runner()

        # The factory may return any awaitable (coroutine, Task, Future, ...),
        # matching the Awaitable[Runner] annotation.
        resolved_runner = await result if inspect.isawaitable(result) else result

        if not isinstance(resolved_runner, Runner):
            raise TypeError(f"Callable must return a Runner instance, got {type(resolved_runner)}")

        return resolved_runner

    @staticmethod
    def _describe_failure(error: Exception) -> str:
        """Return the message to publish for a failed request.

        The text is ``str(error)`` unless it looks like a JSON parsing error
        tied to function calling (common with Ollama models that don't support
        function calling), in which case a fixed explanatory message is used.
        """
        error_message = str(error)
        looks_like_parse_error = (
            "JSONDecodeError" in error_message
            or "Unterminated string" in error_message
            or "APIConnectionError" in error_message
        )
        related_to_function_calling = "function_call" in error_message.lower() or "json.loads" in error_message
        if looks_like_parse_error and related_to_function_calling:
            return (
                "The model does not support function calling properly. "
                "This error typically occurs when using Ollama models with tools. "
                "Please either:\n"
                "1. Remove tools from the agent configuration, or\n"
                "2. Use a model that supports function calling (e.g., OpenAI, Anthropic, or Gemini models)."
            )
        return error_message

    @override
    async def cancel(self, context: RequestContext, event_queue: EventQueue):
        """Cancel the execution.

        Raises:
            NotImplementedError: Always; cancellation is not supported.
        """
        # TODO: Implement proper cancellation logic if needed
        raise NotImplementedError("Cancellation is not supported")

    @override
    async def execute(
        self,
        context: RequestContext,
        event_queue: EventQueue,
    ):
        """Executes an A2A request and publishes updates to the event queue
        specified. It runs as following:
        * Takes the input from the A2A request
        * Convert the input to ADK input content, and runs the ADK agent
        * Collects output events of the underlying ADK Agent
        * Converts the ADK output events into A2A task updates
        * Publishes the updates back to A2A server via event queue

        Errors raised while handling the request are logged and reported as a
        final ``failed`` status event instead of being re-raised.

        Raises:
            ValueError: If the request carries no message.
        """
        if not context.message:
            raise ValueError("A2A request must have a message")

        # Convert the a2a request to ADK run args
        stream = self._config.stream if self._config is not None else False
        run_args = convert_a2a_request_to_adk_run_args(context, stream=stream)

        # Prepare span attributes.
        span_attributes = {}
        if run_args.get("user_id"):
            span_attributes["kagent.user_id"] = run_args["user_id"]
        if context.task_id:
            span_attributes["gen_ai.task.id"] = context.task_id
        if run_args.get("session_id"):
            span_attributes["gen_ai.conversation.id"] = run_args["session_id"]

        # Set kagent span attributes for all spans in context.
        context_token = set_kagent_span_attributes(span_attributes)
        # Stays None until the factory succeeds, so the cleanup below never
        # touches an unbound name when an earlier step raises.
        runner: Optional[Runner] = None
        try:
            # for new task, create a task submitted event
            if not context.current_task:
                await event_queue.enqueue_event(
                    TaskStatusUpdateEvent(
                        task_id=context.task_id,
                        status=TaskStatus(
                            state=TaskState.submitted,
                            message=context.message,
                            timestamp=datetime.now(timezone.utc).isoformat(),
                        ),
                        context_id=context.context_id,
                        final=False,
                    )
                )

            # Handle the request and publish updates to the event queue
            runner = await self._resolve_runner()
            try:
                await self._handle_request(context, event_queue, runner, run_args)
            except Exception as e:
                logger.error("Error handling A2A request: %s", e, exc_info=True)

                error_message = self._describe_failure(e)
                # Publish failure event
                try:
                    await event_queue.enqueue_event(
                        TaskStatusUpdateEvent(
                            task_id=context.task_id,
                            status=TaskStatus(
                                state=TaskState.failed,
                                timestamp=datetime.now(timezone.utc).isoformat(),
                                message=Message(
                                    message_id=str(uuid.uuid4()),
                                    role=Role.agent,
                                    parts=[Part(TextPart(text=error_message))],
                                ),
                            ),
                            context_id=context.context_id,
                            final=True,
                        )
                    )
                except Exception as enqueue_error:
                    logger.error("Failed to publish failure event: %s", enqueue_error, exc_info=True)
        finally:
            clear_kagent_span_attributes(context_token)
            # close the runner which cleans up the mcptoolsets
            # since the runner is created for each a2a request
            # and the mcptoolsets are not shared between requests
            # this is necessary to gracefully handle mcp toolset connections
            if runner is not None:
                await runner.close()

    async def _handle_request(
        self,
        context: RequestContext,
        event_queue: EventQueue,
        runner: Runner,
        run_args: dict[str, Any],
    ):
        """Run the agent for one request and publish the resulting A2A events.

        Args:
            context: The A2A request context.
            event_queue: Queue that receives every published event.
            runner: The Runner created for this request.
            run_args: ADK run arguments; ``session_id`` may be updated in place
                while the session is prepared.
        """
        # ensure the session exists
        session = await self._prepare_session(context, run_args, runner)

        # set request headers to session state; a request without a call
        # context simply contributes no headers
        call_context = context.call_context
        headers = call_context.state.get("headers", {}) if call_context is not None else {}
        state_changes = {
            "headers": headers,
        }

        actions_with_update = EventActions(state_delta=state_changes)
        system_event = Event(
            invocation_id="header_update",
            author="system",
            actions=actions_with_update,
        )

        await runner.session_service.append_event(session, system_event)

        # create invocation context
        invocation_context = runner._new_invocation_context(
            session=session,
            new_message=run_args["new_message"],
            run_config=run_args["run_config"],
        )

        # publish the task working event
        await event_queue.enqueue_event(
            TaskStatusUpdateEvent(
                task_id=context.task_id,
                status=TaskStatus(
                    state=TaskState.working,
                    timestamp=datetime.now(timezone.utc).isoformat(),
                ),
                context_id=context.context_id,
                final=False,
                metadata={
                    get_kagent_metadata_key("app_name"): runner.app_name,
                    get_kagent_metadata_key("user_id"): run_args["user_id"],
                    get_kagent_metadata_key("session_id"): run_args["session_id"],
                },
            )
        )

        task_result_aggregator = TaskResultAggregator()
        async with Aclosing(runner.run_async(**run_args)) as agen:
            async for adk_event in agen:
                for a2a_event in convert_event_to_a2a_events(
                    adk_event, invocation_context, context.task_id, context.context_id
                ):
                    # Only aggregate non-partial events to avoid duplicates from streaming chunks
                    # Partial events are sent to frontend for display but not accumulated
                    if not adk_event.partial:
                        task_result_aggregator.process_event(a2a_event)
                    await event_queue.enqueue_event(a2a_event)

        # publish the task result event - this is final
        final_message = task_result_aggregator.task_status_message
        if (
            task_result_aggregator.task_state == TaskState.working
            and final_message is not None
            and final_message.parts
        ):
            # if task is still working properly, publish the artifact update event as
            # the final result according to a2a protocol.
            await event_queue.enqueue_event(
                TaskArtifactUpdateEvent(
                    task_id=context.task_id,
                    last_chunk=True,
                    context_id=context.context_id,
                    artifact=Artifact(
                        artifact_id=str(uuid.uuid4()),
                        parts=final_message.parts,
                    ),
                )
            )
            # publish the final status update event
            await event_queue.enqueue_event(
                TaskStatusUpdateEvent(
                    task_id=context.task_id,
                    status=TaskStatus(
                        state=TaskState.completed,
                        timestamp=datetime.now(timezone.utc).isoformat(),
                    ),
                    context_id=context.context_id,
                    final=True,
                )
            )
        else:
            await event_queue.enqueue_event(
                TaskStatusUpdateEvent(
                    task_id=context.task_id,
                    status=TaskStatus(
                        state=task_result_aggregator.task_state,
                        timestamp=datetime.now(timezone.utc).isoformat(),
                        message=final_message,
                    ),
                    context_id=context.context_id,
                    final=True,
                )
            )

    async def _prepare_session(self, context: RequestContext, run_args: dict[str, Any], runner: Runner):
        """Fetch the session for this request, creating it when it does not exist.

        A newly created session gets a ``session_name`` state entry derived
        from the first non-empty text part of the request message (``None``
        when there is none). ``run_args["session_id"]`` is overwritten with
        the id of the session that is returned.

        Returns:
            The existing or newly created session.
        """
        session_id = run_args["session_id"]
        # create a new session if not exists
        user_id = run_args["user_id"]
        session = await runner.session_service.get_session(
            app_name=runner.app_name,
            user_id=user_id,
            session_id=session_id,
        )

        if session is None:
            # Extract session name from the first TextPart (like the UI does)
            session_name = None
            if context.message and context.message.parts:
                for part in context.message.parts:
                    # A2A parts have a .root property that contains the actual part (TextPart, FilePart, etc.)
                    if isinstance(part, Part):
                        root_part = part.root
                        if isinstance(root_part, TextPart) and root_part.text:
                            # Take first 20 chars + "..." if longer (matching UI behavior)
                            text = root_part.text.strip()
                            session_name = text[:20] + ("..." if len(text) > 20 else "")
                            break

            session = await runner.session_service.create_session(
                app_name=runner.app_name,
                user_id=user_id,
                state={"session_name": session_name},
                session_id=session_id,
            )

            # Update run_args with the new session_id
            run_args["session_id"] = session.id

        return session
|
kagent/adk/_lifespan.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
"""Lifespan manager for composing multiple FastAPI lifespans."""
|
|
2
|
+
|
|
3
|
+
from contextlib import asynccontextmanager
|
|
4
|
+
from typing import AsyncIterator, Callable, List
|
|
5
|
+
|
|
6
|
+
from fastapi import FastAPI
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class LifespanManager:
    """
    A simple lifespan manager that composes multiple FastAPI lifespans.
    Inspired by https://github.com/uriyyo/fastapi-lifespan-manager

    Lifespans are entered in the order they were added and exited in reverse
    order.
    """

    def __init__(self) -> None:
        self._lifespans: List[Callable[[FastAPI], AsyncIterator[None]]] = []

    def add(self, lifespan: Callable[[FastAPI], AsyncIterator[None]]) -> None:
        """Add a context manager to the manager.

        Args:
            lifespan: Callable that takes the app and returns an async context
                manager. ``None`` is accepted and silently ignored.
        """
        if lifespan is not None:
            self._lifespans.append(lifespan)

    @asynccontextmanager
    async def __call__(self, app: FastAPI) -> AsyncIterator[None]:
        """Compose all lifespans into a single context manager.

        Args:
            app: The application passed to every registered lifespan.
        """
        from contextlib import AsyncExitStack

        # The exit stack unwinds every entered lifespan in reverse order,
        # including when an exception or cancellation is thrown in at the
        # yield below or when a later lifespan fails while starting up.
        async with AsyncExitStack() as stack:
            for lifespan in self._lifespans:
                await stack.enter_async_context(lifespan(app))
            yield
|