kagent-adk 0.7.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kagent/adk/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ import importlib.metadata
2
+
3
+ from ._a2a import KAgentApp
4
+ from .types import AgentConfig
5
+
try:
    # Version string of the installed "kagent_adk" distribution.
    __version__ = importlib.metadata.version("kagent_adk")
except importlib.metadata.PackageNotFoundError:
    # Imported from a source tree without installed distribution metadata;
    # fall back to a placeholder instead of failing the whole package import.
    __version__ = "0.0.0+unknown"

__all__ = ["KAgentApp", "AgentConfig"]
kagent/adk/_a2a.py ADDED
@@ -0,0 +1,178 @@
1
+ #! /usr/bin/env python3
2
+ import faulthandler
3
+ import logging
4
+ import os
5
+ from typing import Any, Callable, List, Optional
6
+
7
+ import httpx
8
+ from a2a.server.apps import A2AFastAPIApplication
9
+ from a2a.server.request_handlers import DefaultRequestHandler
10
+ from a2a.server.tasks import InMemoryTaskStore
11
+ from a2a.types import AgentCard
12
+ from agentsts.adk import ADKSTSIntegration, ADKTokenPropagationPlugin
13
+ from fastapi import FastAPI, Request
14
+ from fastapi.responses import PlainTextResponse
15
+ from google.adk.agents import BaseAgent
16
+ from google.adk.apps import App
17
+ from google.adk.artifacts import InMemoryArtifactService
18
+ from google.adk.plugins import BasePlugin
19
+ from google.adk.runners import Runner
20
+ from google.adk.sessions import InMemorySessionService
21
+ from google.genai import types
22
+
23
+ from kagent.core.a2a import (
24
+ KAgentRequestContextBuilder,
25
+ KAgentTaskStore,
26
+ get_a2a_max_content_length,
27
+ )
28
+
29
+ from ._agent_executor import A2aAgentExecutor, A2aAgentExecutorConfig
30
+ from ._lifespan import LifespanManager
31
+ from ._session_service import KAgentSessionService
32
+ from ._token import KAgentTokenService
33
+
# Module logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
35
+
36
+
def health_check(request: Request) -> PlainTextResponse:
    """Answer the health/readiness probe.

    The incoming request is not inspected; the response body is always "OK".
    """
    status_text = "OK"
    return PlainTextResponse(status_text)
39
+
40
+
def thread_dump(request: Request) -> PlainTextResponse:
    """Return the current traceback of every thread as plain text.

    ``faulthandler.dump_traceback`` writes directly to an OS-level file
    descriptor (it calls ``file.fileno()``), so it cannot target an in-memory
    ``io.StringIO``; that raises ``io.UnsupportedOperation``. The dump is
    therefore written to an anonymous temporary file and read back.
    """
    import tempfile

    with tempfile.TemporaryFile() as dump_file:
        faulthandler.dump_traceback(file=dump_file, all_threads=True)
        # faulthandler wrote through the raw descriptor; rewind and read it all.
        dump_file.seek(0)
        text = dump_file.read().decode("utf-8", errors="replace")
    return PlainTextResponse(text)
48
+
49
+
# Backend URL read once from the KAGENT_URL environment variable at import
# time; when set and non-empty it is used instead of KAgentApp's kagent_url.
kagent_url_override = os.environ.get("KAGENT_URL")
51
+
52
+
class KAgentApp:
    """Wires an ADK agent factory into a FastAPI application that serves A2A."""

    def __init__(
        self,
        root_agent_factory: Callable[[], BaseAgent],
        agent_card: AgentCard,
        kagent_url: str,
        app_name: str,
        lifespan: Optional[Callable[[Any], Any]] = None,
        plugins: Optional[List[BasePlugin]] = None,
        stream: bool = False,
    ):
        """Initialize the KAgent application.

        Args:
            root_agent_factory: Root agent factory function that returns a new agent instance
            agent_card: Agent card configuration for A2A protocol
            kagent_url: URL of the KAgent backend server; a non-empty
                ``KAGENT_URL`` environment variable takes precedence in ``build``
            app_name: Application name for identification
            lifespan: Optional lifespan function
            plugins: Optional list of plugins
            stream: Whether to stream the response
        """
        self.root_agent_factory = root_agent_factory
        self.kagent_url = kagent_url
        self.app_name = app_name
        self.agent_card = agent_card
        self._lifespan = lifespan
        self.plugins = plugins if plugins is not None else []
        self.stream = stream

    def build(self, local=False) -> FastAPI:
        """Assemble the FastAPI application that serves this agent.

        Args:
            local: When True, sessions and tasks are kept in memory and no
                KAgent backend HTTP client or token service is created.

        Returns:
            A FastAPI app exposing ``/health``, ``/thread_dump`` and the A2A routes.
        """
        from contextlib import asynccontextmanager

        token_service = None
        http_client = None
        if local:
            session_service = InMemorySessionService()
            task_store = InMemoryTaskStore()
        else:
            token_service = KAgentTokenService(self.app_name)
            http_client = httpx.AsyncClient(
                # TODO: add user and agent headers
                base_url=kagent_url_override or self.kagent_url,
                event_hooks=token_service.event_hooks(),
            )
            session_service = KAgentSessionService(http_client)
            task_store = KAgentTaskStore(http_client)

        def create_runner() -> Runner:
            # Builds a fresh agent and Runner; the executor calls this per request.
            root_agent = self.root_agent_factory()
            adk_app = App(name=self.app_name, root_agent=root_agent, plugins=self.plugins)

            return Runner(
                app=adk_app,
                session_service=session_service,
                artifact_service=InMemoryArtifactService(),
            )

        agent_executor = A2aAgentExecutor(
            runner=create_runner,
            config=A2aAgentExecutorConfig(stream=self.stream),
        )

        request_context_builder = KAgentRequestContextBuilder(task_store=task_store)
        request_handler = DefaultRequestHandler(
            agent_executor=agent_executor,
            task_store=task_store,
            request_context_builder=request_context_builder,
        )

        max_content_length = get_a2a_max_content_length()
        a2a_app = A2AFastAPIApplication(
            agent_card=self.agent_card,
            http_handler=request_handler,
            max_content_length=max_content_length,
        )

        faulthandler.enable()

        lifespan_manager = LifespanManager()
        if http_client is not None:
            # Registered first so it is the outermost lifespan: the shared
            # backend client is closed only after the other lifespans exit.
            @asynccontextmanager
            async def close_http_client(_app):
                try:
                    yield
                finally:
                    await http_client.aclose()

            lifespan_manager.add(close_http_client)
        lifespan_manager.add(self._lifespan)
        if token_service is not None:
            lifespan_manager.add(token_service.lifespan())

        app = FastAPI(lifespan=lifespan_manager)

        # Health check/readiness probe
        app.add_route("/health", methods=["GET"], route=health_check)
        app.add_route("/thread_dump", methods=["GET"], route=thread_dump)
        a2a_app.add_routes_to_app(app)

        return app

    async def test(self, task: str):
        """Run ``task`` once against a fresh agent with in-memory services.

        Every event produced by the run is logged as JSON. The runner is built
        like the one in ``build`` (same plugins) and is always closed afterwards.

        Args:
            task: The user query sent to the agent.
        """
        session_service = InMemorySessionService()
        session_id = "12345"
        user_id = "admin"
        await session_service.create_session(
            app_name=self.app_name,
            session_id=session_id,
            user_id=user_id,
        )

        root_agent = self.root_agent_factory()
        runner = Runner(
            app=App(name=self.app_name, root_agent=root_agent, plugins=self.plugins),
            session_service=session_service,
            artifact_service=InMemoryArtifactService(),
        )
        try:
            logger.info("\n>>> User Query: %s", task)

            # Prepare the user's message in ADK format
            content = types.Content(role="user", parts=[types.Part(text=task)])
            # run_async executes the agent logic and yields Events.
            async for event in runner.run_async(
                user_id=user_id,
                session_id=session_id,
                new_message=content,
            ):
                logger.info(" [Event] %s", event.model_dump_json())
        finally:
            # Release runner resources (e.g. MCP toolset connections).
            await runner.close()
kagent/adk/_agent_executor.py ADDED
@@ -0,0 +1,335 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ import logging
5
+ import uuid
6
+ from datetime import datetime, timezone
7
+ from typing import Any, Awaitable, Callable, Optional
8
+
9
+ from a2a.server.agent_execution import AgentExecutor
10
+ from a2a.server.agent_execution.context import RequestContext
11
+ from a2a.server.events.event_queue import EventQueue
12
+ from a2a.types import (
13
+ Artifact,
14
+ Message,
15
+ Part,
16
+ Role,
17
+ TaskArtifactUpdateEvent,
18
+ TaskState,
19
+ TaskStatus,
20
+ TaskStatusUpdateEvent,
21
+ TextPart,
22
+ )
23
+ from google.adk.events import Event, EventActions
24
+ from google.adk.runners import Runner
25
+ from google.adk.utils.context_utils import Aclosing
26
+ from opentelemetry import trace
27
+ from pydantic import BaseModel
28
+ from typing_extensions import override
29
+
30
+ from kagent.core.a2a import TaskResultAggregator, get_kagent_metadata_key
31
+ from kagent.core.tracing._span_processor import (
32
+ clear_kagent_span_attributes,
33
+ set_kagent_span_attributes,
34
+ )
35
+
36
+ from .converters.event_converter import convert_event_to_a2a_events
37
+ from .converters.request_converter import convert_a2a_request_to_adk_run_args
38
+
# Logger named "kagent_adk." followed by this module's import path.
logger = logging.getLogger(f"kagent_adk.{__name__}")
40
+
41
+
class A2aAgentExecutorConfig(BaseModel):
    """Options consumed by A2aAgentExecutor.

    Attributes:
        stream: Passed as ``stream`` to ``convert_a2a_request_to_adk_run_args``
            for every request. Defaults to ``False``.
    """

    stream: bool = False
46
+
47
+
# This class is a copy of the A2aAgentExecutor class in the ADK sdk,
# with the following changes:
# - The runner is ALWAYS a callable that returns a Runner instance
# - The runner is cleaned up at the end of the execution
class A2aAgentExecutor(AgentExecutor):
    """An AgentExecutor that runs an ADK Agent against an A2A request and
    publishes updates to an event queue.

    A fresh Runner is obtained from the ``runner`` factory for every request
    and closed once that request has been handled.
    """

    def __init__(
        self,
        *,
        runner: Callable[..., Runner | Awaitable[Runner]],
        config: Optional[A2aAgentExecutorConfig] = None,
    ):
        """Create the executor.

        Args:
            runner: Callable invoked with no arguments once per request; it may
                return a ``Runner`` or an awaitable that resolves to one.
            config: Executor options. ``None`` behaves like ``stream=False``.
        """
        super().__init__()
        self._runner = runner
        self._config = config

    async def _resolve_runner(self) -> Runner:
        """Call the runner factory, awaiting its result if needed.

        Raises:
            TypeError: If the factory is not callable or does not produce a ``Runner``.
        """
        if not callable(self._runner):
            raise TypeError(
                f"Runner must be a Runner instance or a callable that returns a Runner, got {type(self._runner)}"
            )

        resolved_runner = self._runner()
        # Async factories hand back an awaitable (e.g. a coroutine).
        if inspect.isawaitable(resolved_runner):
            resolved_runner = await resolved_runner

        # Ensure we got a Runner instance
        if not isinstance(resolved_runner, Runner):
            raise TypeError(f"Callable must return a Runner instance, got {type(resolved_runner)}")

        return resolved_runner

    @override
    async def cancel(self, context: RequestContext, event_queue: EventQueue):
        """Cancel the execution (not supported).

        Raises:
            NotImplementedError: Always.
        """
        # TODO: Implement proper cancellation logic if needed
        raise NotImplementedError("Cancellation is not supported")

    @override
    async def execute(
        self,
        context: RequestContext,
        event_queue: EventQueue,
    ):
        """Executes an A2A request and publishes updates to the event queue
        specified. It runs as following:
          * Takes the input from the A2A request
          * Convert the input to ADK input content, and runs the ADK agent
          * Collects output events of the underlying ADK Agent
          * Converts the ADK output events into A2A task updates
          * Publishes the updates back to A2A server via event queue

        Exceptions raised while running the agent (``_handle_request``) are
        logged and reported as a final ``failed`` status rather than re-raised.

        Raises:
            ValueError: If the request carries no message.
        """
        if not context.message:
            raise ValueError("A2A request must have a message")

        # Convert the a2a request to ADK run args
        stream = self._config.stream if self._config is not None else False
        run_args = convert_a2a_request_to_adk_run_args(context, stream=stream)

        # Prepare span attributes.
        span_attributes = {}
        if run_args.get("user_id"):
            span_attributes["kagent.user_id"] = run_args["user_id"]
        if context.task_id:
            span_attributes["gen_ai.task.id"] = context.task_id
        if run_args.get("session_id"):
            span_attributes["gen_ai.conversation.id"] = run_args["session_id"]

        # Set kagent span attributes for all spans in context.
        context_token = set_kagent_span_attributes(span_attributes)
        # Bound before the try so cleanup knows whether a runner exists; a
        # failure before resolution must not raise UnboundLocalError in the
        # finally block and mask the original exception.
        runner: Optional[Runner] = None
        try:
            # for new task, create a task submitted event
            if not context.current_task:
                await event_queue.enqueue_event(
                    TaskStatusUpdateEvent(
                        task_id=context.task_id,
                        status=TaskStatus(
                            state=TaskState.submitted,
                            message=context.message,
                            timestamp=datetime.now(timezone.utc).isoformat(),
                        ),
                        context_id=context.context_id,
                        final=False,
                    )
                )

            # Handle the request and publish updates to the event queue
            runner = await self._resolve_runner()
            try:
                await self._handle_request(context, event_queue, runner, run_args)
            except Exception as e:
                logger.error("Error handling A2A request: %s", e, exc_info=True)
                await self._publish_failure(context, event_queue, self._to_user_error_message(e))
        finally:
            clear_kagent_span_attributes(context_token)
            # close the runner which cleans up the mcptoolsets
            # since the runner is created for each a2a request
            # and the mcptoolsets are not shared between requests
            # this is necessary to gracefully handle mcp toolset connections
            if runner is not None:
                await runner.close()

    @staticmethod
    def _to_user_error_message(error: Exception) -> str:
        """Map an exception to the text published in the failed task status.

        LiteLLM JSON-parsing/connection errors that mention function calling
        (common with Ollama models that don't support function calling) are
        replaced by an actionable hint; anything else is reported as ``str(error)``.
        """
        error_message = str(error)
        looks_like_litellm_parse_error = (
            "JSONDecodeError" in error_message
            or "Unterminated string" in error_message
            or "APIConnectionError" in error_message
        )
        mentions_function_calling = "function_call" in error_message.lower() or "json.loads" in error_message
        if looks_like_litellm_parse_error and mentions_function_calling:
            return (
                "The model does not support function calling properly. "
                "This error typically occurs when using Ollama models with tools. "
                "Please either:\n"
                "1. Remove tools from the agent configuration, or\n"
                "2. Use a model that supports function calling (e.g., OpenAI, Anthropic, or Gemini models)."
            )
        return error_message

    async def _publish_failure(
        self,
        context: RequestContext,
        event_queue: EventQueue,
        error_message: str,
    ) -> None:
        """Enqueue a final ``failed`` status carrying ``error_message``.

        Enqueue errors are logged and swallowed so the original failure is not
        replaced by a secondary one.
        """
        try:
            await event_queue.enqueue_event(
                TaskStatusUpdateEvent(
                    task_id=context.task_id,
                    status=TaskStatus(
                        state=TaskState.failed,
                        timestamp=datetime.now(timezone.utc).isoformat(),
                        message=Message(
                            message_id=str(uuid.uuid4()),
                            role=Role.agent,
                            parts=[Part(TextPart(text=error_message))],
                        ),
                    ),
                    context_id=context.context_id,
                    final=True,
                )
            )
        except Exception as enqueue_error:
            logger.error("Failed to publish failure event: %s", enqueue_error, exc_info=True)

    async def _handle_request(
        self,
        context: RequestContext,
        event_queue: EventQueue,
        runner: Runner,
        run_args: dict[str, Any],
    ):
        """Run the agent for one request and publish the resulting A2A events.

        Publishes a ``working`` status, then every converted ADK event, then a
        final event: an artifact plus ``completed`` status when the aggregated
        state is still ``working`` with message parts, otherwise the aggregated
        state and message.
        """
        # ensure the session exists
        session = await self._prepare_session(context, run_args, runner)

        # set request headers to session state; call_context may be absent
        call_context = context.call_context
        headers = call_context.state.get("headers", {}) if call_context is not None else {}
        state_changes = {
            "headers": headers,
        }

        actions_with_update = EventActions(state_delta=state_changes)
        system_event = Event(
            invocation_id="header_update",
            author="system",
            actions=actions_with_update,
        )

        await runner.session_service.append_event(session, system_event)

        # create invocation context
        invocation_context = runner._new_invocation_context(
            session=session,
            new_message=run_args["new_message"],
            run_config=run_args["run_config"],
        )

        # publish the task working event
        await event_queue.enqueue_event(
            TaskStatusUpdateEvent(
                task_id=context.task_id,
                status=TaskStatus(
                    state=TaskState.working,
                    timestamp=datetime.now(timezone.utc).isoformat(),
                ),
                context_id=context.context_id,
                final=False,
                metadata={
                    get_kagent_metadata_key("app_name"): runner.app_name,
                    get_kagent_metadata_key("user_id"): run_args["user_id"],
                    get_kagent_metadata_key("session_id"): run_args["session_id"],
                },
            )
        )

        task_result_aggregator = TaskResultAggregator()
        async with Aclosing(runner.run_async(**run_args)) as agen:
            async for adk_event in agen:
                for a2a_event in convert_event_to_a2a_events(
                    adk_event, invocation_context, context.task_id, context.context_id
                ):
                    # Only aggregate non-partial events to avoid duplicates from streaming chunks
                    # Partial events are sent to frontend for display but not accumulated
                    if not adk_event.partial:
                        task_result_aggregator.process_event(a2a_event)
                    await event_queue.enqueue_event(a2a_event)

        # publish the task result event - this is final
        if (
            task_result_aggregator.task_state == TaskState.working
            and task_result_aggregator.task_status_message is not None
            and task_result_aggregator.task_status_message.parts
        ):
            # if task is still working properly, publish the artifact update event as
            # the final result according to a2a protocol.
            await event_queue.enqueue_event(
                TaskArtifactUpdateEvent(
                    task_id=context.task_id,
                    last_chunk=True,
                    context_id=context.context_id,
                    artifact=Artifact(
                        artifact_id=str(uuid.uuid4()),
                        parts=task_result_aggregator.task_status_message.parts,
                    ),
                )
            )
            # publish the final status update event
            await event_queue.enqueue_event(
                TaskStatusUpdateEvent(
                    task_id=context.task_id,
                    status=TaskStatus(
                        state=TaskState.completed,
                        timestamp=datetime.now(timezone.utc).isoformat(),
                    ),
                    context_id=context.context_id,
                    final=True,
                )
            )
        else:
            await event_queue.enqueue_event(
                TaskStatusUpdateEvent(
                    task_id=context.task_id,
                    status=TaskStatus(
                        state=task_result_aggregator.task_state,
                        timestamp=datetime.now(timezone.utc).isoformat(),
                        message=task_result_aggregator.task_status_message,
                    ),
                    context_id=context.context_id,
                    final=True,
                )
            )

    @staticmethod
    def _session_name_from_message(message: Optional[Message]) -> Optional[str]:
        """Derive a session name from the first non-empty TextPart of ``message``.

        The stripped text is cut to 20 characters, with "..." appended when
        truncated (matching UI behavior). Returns None when no such part exists.
        """
        if not message or not message.parts:
            return None
        for part in message.parts:
            # A2A parts have a .root property that contains the actual part (TextPart, FilePart, etc.)
            if not isinstance(part, Part):
                continue
            root_part = part.root
            if isinstance(root_part, TextPart) and root_part.text:
                text = root_part.text.strip()
                return text[:20] + ("..." if len(text) > 20 else "")
        return None

    async def _prepare_session(self, context: RequestContext, run_args: dict[str, Any], runner: Runner):
        """Fetch the session for this request, creating it if it does not exist.

        ``run_args["session_id"]`` is updated with the returned session's id.
        """
        session_id = run_args["session_id"]
        user_id = run_args["user_id"]
        session = await runner.session_service.get_session(
            app_name=runner.app_name,
            user_id=user_id,
            session_id=session_id,
        )

        if session is None:
            session = await runner.session_service.create_session(
                app_name=runner.app_name,
                user_id=user_id,
                state={"session_name": self._session_name_from_message(context.message)},
                session_id=session_id,
            )

            # Update run_args with the new session_id
            run_args["session_id"] = session.id

        return session
kagent/adk/_lifespan.py ADDED
@@ -0,0 +1,36 @@
1
+ """Lifespan manager for composing multiple FastAPI lifespans."""
2
+
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
from typing import AsyncIterator, Callable, List, Optional

from fastapi import FastAPI
7
+
8
+
class LifespanManager:
    """
    A simple lifespan manager that composes multiple FastAPI lifespans.
    Inspired by https://github.com/uriyyo/fastapi-lifespan-manager

    Lifespans are entered in the order they were added and exited in reverse
    order. An exception raised while the app runs, or while entering a later
    lifespan, is propagated into every lifespan that was already entered.
    """

    def __init__(self) -> None:
        # Each entry is called with the app and must return an async context manager.
        self._lifespans: List[Callable[["FastAPI"], AbstractAsyncContextManager[None]]] = []

    def add(self, lifespan: Optional[Callable[["FastAPI"], AbstractAsyncContextManager[None]]]) -> None:
        """Register a lifespan; ``None`` is ignored so optional lifespans can be passed as-is."""
        if lifespan is not None:
            self._lifespans.append(lifespan)

    @asynccontextmanager
    async def __call__(self, app: "FastAPI") -> AsyncIterator[None]:
        """Compose all lifespans into a single context manager."""
        # AsyncExitStack unwinds in reverse order and delivers any in-flight
        # exception to each lifespan's __aexit__.
        async with AsyncExitStack() as stack:
            for lifespan in self._lifespans:
                await stack.enter_async_context(lifespan(app))
            yield