smarta2a 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smarta2a/__init__.py +4 -4
- smarta2a/agent/a2a_agent.py +38 -0
- smarta2a/agent/a2a_mcp_server.py +37 -0
- smarta2a/archive/mcp_client.py +86 -0
- smarta2a/client/__init__.py +0 -0
- smarta2a/client/a2a_client.py +267 -0
- smarta2a/client/smart_mcp_client.py +60 -0
- smarta2a/client/tools_manager.py +58 -0
- smarta2a/history_update_strategies/__init__.py +8 -0
- smarta2a/history_update_strategies/append_strategy.py +10 -0
- smarta2a/history_update_strategies/history_update_strategy.py +15 -0
- smarta2a/model_providers/__init__.py +5 -0
- smarta2a/model_providers/base_llm_provider.py +15 -0
- smarta2a/model_providers/openai_provider.py +281 -0
- smarta2a/server/__init__.py +3 -0
- smarta2a/server/handler_registry.py +23 -0
- smarta2a/{server.py → server/server.py} +224 -254
- smarta2a/server/state_manager.py +34 -0
- smarta2a/server/subscription_service.py +109 -0
- smarta2a/server/task_service.py +155 -0
- smarta2a/state_stores/__init__.py +8 -0
- smarta2a/state_stores/base_state_store.py +20 -0
- smarta2a/state_stores/inmemory_state_store.py +21 -0
- smarta2a/utils/__init__.py +32 -0
- smarta2a/utils/prompt_helpers.py +38 -0
- smarta2a/utils/task_builder.py +153 -0
- smarta2a/utils/task_request_builder.py +114 -0
- smarta2a/{types.py → utils/types.py} +62 -2
- {smarta2a-0.2.1.dist-info → smarta2a-0.2.3.dist-info}/METADATA +13 -7
- smarta2a-0.2.3.dist-info/RECORD +32 -0
- smarta2a-0.2.1.dist-info/RECORD +0 -7
- {smarta2a-0.2.1.dist-info → smarta2a-0.2.3.dist-info}/WHEEL +0 -0
- {smarta2a-0.2.1.dist-info → smarta2a-0.2.3.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,34 @@
|
|
1
|
+
# Library imports
|
2
|
+
from typing import Optional, Dict, Any
|
3
|
+
from uuid import uuid4
|
4
|
+
|
5
|
+
# Local imports
|
6
|
+
from smarta2a.state_stores.base_state_store import BaseStateStore
|
7
|
+
from smarta2a.history_update_strategies.history_update_strategy import HistoryUpdateStrategy
|
8
|
+
from smarta2a.utils.types import Message, StateData
|
9
|
+
|
10
|
+
class StateManager:
    """Pair an optional state store with a history-update strategy.

    With a store configured, session state is loaded from and written back
    to it. Without one, nothing is persisted and every call just builds a
    fresh ``StateData`` around the incoming message.
    """

    def __init__(self, store: Optional[BaseStateStore], history_strategy: HistoryUpdateStrategy):
        self.strategy = history_strategy
        self.store = store

    def init_or_get(self, session_id: Optional[str], message: Message, metadata: Dict[str, Any]) -> StateData:
        """Return the session state with ``message`` appended to its history.

        Args:
            session_id: Existing session id; a new UUID4 string is generated
                when it is missing or empty.
            message: Incoming message to record.
            metadata: Request metadata; merged over any stored metadata
                (incoming keys win).

        Returns:
            The stored-and-updated state when a store is configured,
            otherwise a new, unpersisted ``StateData``.
        """
        resolved_id = session_id if session_id else str(uuid4())
        incoming_meta = metadata or {}

        # Stateless mode: hand back a one-message state and persist nothing.
        if not self.store:
            return StateData(sessionId=resolved_id, history=[message], metadata=incoming_meta)

        current = self.store.get_state(resolved_id)
        if not current:
            current = StateData(sessionId=resolved_id, history=[], metadata={})

        current.history.append(message)

        merged_meta = dict(current.metadata or {})
        merged_meta.update(incoming_meta)
        current.metadata = merged_meta

        self.store.update_state(resolved_id, current)
        return current

    def update(self, state: StateData):
        """Persist ``state`` under its own session id; no-op without a store."""
        if not self.store:
            return
        self.store.update_state(state.sessionId, state)

    def get_store(self) -> Optional[BaseStateStore]:
        """Return the configured store (may be ``None``)."""
        return self.store

    def get_strategy(self) -> HistoryUpdateStrategy:
        """Return the configured history-update strategy."""
        return self.strategy
|
34
|
+
|
@@ -0,0 +1,109 @@
|
|
1
|
+
# Library imports
|
2
|
+
from typing import Optional, List, Dict, Any, AsyncGenerator, Union
|
3
|
+
from datetime import datetime
|
4
|
+
from collections import defaultdict
|
5
|
+
from uuid import uuid4
|
6
|
+
from fastapi.responses import StreamingResponse
|
7
|
+
from sse_starlette.sse import EventSourceResponse
|
8
|
+
|
9
|
+
# Local imports
|
10
|
+
from smarta2a.server.handler_registry import HandlerRegistry
|
11
|
+
from smarta2a.server.state_manager import StateManager
|
12
|
+
from smarta2a.utils.types import (
|
13
|
+
Message, StateData, SendTaskStreamingRequest, SendTaskStreamingResponse,
|
14
|
+
TaskSendParams, A2AStatus, A2AStreamResponse, TaskStatusUpdateEvent,
|
15
|
+
TaskStatus, TaskState, TaskArtifactUpdateEvent, Artifact, TextPart,
|
16
|
+
FilePart, DataPart, FileContent, MethodNotFoundError, TaskNotFoundError,
|
17
|
+
InternalError
|
18
|
+
)
|
19
|
+
|
20
|
+
class SubscriptionService:
    """Run the registered ``tasks/sendSubscribe`` handler and stream its output.

    Items yielded by the handler are normalised into
    ``SendTaskStreamingResponse`` JSON envelopes. Artifact content is also
    appended to the session history via the state manager's strategy and
    written back through ``StateManager.update``.
    """

    def __init__(self, registry: HandlerRegistry, state_mgr: StateManager):
        self.registry = registry
        self.state_mgr = state_mgr

    async def subscribe(self, request: SendTaskStreamingRequest, state: Optional[StateData]) -> StreamingResponse:
        """Build the streaming HTTP response for one subscription request.

        Args:
            request: The incoming ``tasks/sendSubscribe`` JSON-RPC request.
            state: Previously loaded session state, or ``None`` when the
                server runs without state; it decides whether the handler is
                called with one or two arguments.

        Returns:
            A ``text/event-stream`` response. When no handler is registered
            the stream carries a single ``MethodNotFoundError`` envelope.
        """
        handler = self.registry.get_subscription("tasks/sendSubscribe")
        if not handler:
            err = SendTaskStreamingResponse(jsonrpc="2.0", id=request.id, error=MethodNotFoundError()).model_dump_json()

            # EventSourceResponse iterates its content; a bare string would be
            # sent one character per event, so wrap the payload in a
            # single-item async generator.
            async def error_stream():
                yield err

            return EventSourceResponse(error_stream())

        # Parsed as: state.sessionId if state else (params.sessionId or new uuid).
        session_id = state.sessionId if state else request.params.sessionId or str(uuid4())
        history = state.history.copy() if state else [request.params.message]
        metadata = state.metadata.copy() if state else (request.params.metadata or {})

        async def event_stream():
            try:
                events = handler(request, state) if state else handler(request)
                async for ev in self._normalize(request.params, events, history.copy(), metadata.copy(), session_id):
                    yield f"data: {ev}\n\n"
            except Exception as e:
                # Errors are reported in-band as a final SSE event; the text
                # match on "not found" selects TaskNotFoundError.
                err = TaskNotFoundError() if 'not found' in str(e).lower() else InternalError(data=str(e))
                msg = SendTaskStreamingResponse(jsonrpc="2.0", id=request.id, error=err).model_dump_json()
                yield f"data: {msg}\n\n"

        return StreamingResponse(event_stream(), media_type="text/event-stream; charset=utf-8")

    async def _normalize(
        self,
        params: TaskSendParams,
        events: AsyncGenerator,
        history: List[Message],
        metadata: Dict[str, Any],
        session_id: str
    ) -> AsyncGenerator[str, None]:
        """Convert handler output into JSON-encoded streaming envelopes.

        Handles three kinds of item: ready-made ``SendTaskStreamingResponse``
        objects (passed through), ``A2AStatus`` (turned into a status-update
        event) and everything else (treated as artifact content).

        NOTE(review): envelopes built here use ``params.id`` as the JSON-RPC
        ``id`` while the error envelopes in ``subscribe`` use ``request.id``
        -- confirm which one clients expect.

        Yields:
            One JSON string per event (without the SSE ``data:`` framing).
        """
        # Per-artifact-index bookkeeping: chunks seen so far and whether a
        # final chunk has already been emitted.
        artifact_state = defaultdict(lambda: {"index": 0, "last_chunk": False})
        async for item in events:
            if isinstance(item, SendTaskStreamingResponse):
                yield item.model_dump_json()
                continue

            if isinstance(item, A2AStatus):
                te = TaskStatusUpdateEvent(
                    id=params.id,
                    status=TaskStatus(state=TaskState(item.status), timestamp=datetime.now()),
                    final=item.final or (item.status.lower() == TaskState.COMPLETED),
                    metadata=item.metadata
                )
                yield SendTaskStreamingResponse(jsonrpc="2.0", id=params.id, result=te).model_dump_json()
                continue

            # Anything else is artifact content; wrap bare values so the rest
            # of the loop can rely on the A2AStreamResponse fields.
            content_item = item
            if not isinstance(item, A2AStreamResponse):
                content_item = A2AStreamResponse(content=item)

            parts: List[Union[TextPart, FilePart, DataPart]] = []
            cont = content_item.content
            if isinstance(cont, str):
                parts.append(TextPart(text=cont))
            elif isinstance(cont, bytes):
                parts.append(FilePart(file=FileContent(bytes=cont)))
            elif isinstance(cont, (TextPart, FilePart, DataPart)):
                parts.append(cont)
            elif isinstance(cont, Artifact):
                parts.extend(cont.parts)
            elif isinstance(cont, list):
                # List elements of any other type are silently skipped.
                for elem in cont:
                    if isinstance(elem, str):
                        parts.append(TextPart(text=elem))
                    elif isinstance(elem, (TextPart, FilePart, DataPart)):
                        parts.append(elem)
                    elif isinstance(elem, Artifact):
                        parts.extend(elem.parts)

            idx = content_item.index
            chunk_state = artifact_state[idx]
            evt = TaskArtifactUpdateEvent(
                id=params.id,
                artifact=Artifact(
                    parts=parts,
                    index=idx,
                    # append is also set when the per-index chunk counter
                    # equals the artifact index itself.
                    append=content_item.append or (chunk_state["index"] == idx),
                    lastChunk=content_item.final or chunk_state["last_chunk"],
                    metadata=content_item.metadata
                )
            )
            if content_item.final:
                chunk_state["last_chunk"] = True
            chunk_state["index"] += 1

            # Record the chunk as an agent message and persist the session.
            agent_msg = Message(role="agent", parts=evt.artifact.parts, metadata=evt.artifact.metadata)
            new_hist = self.state_mgr.strategy.update_history(history, [agent_msg])
            metadata = {**metadata, **(evt.artifact.metadata or {})}
            # Keyword arguments, matching the other StateData call sites.
            self.state_mgr.update(StateData(sessionId=session_id, history=new_hist, metadata=metadata))
            history = new_hist

            yield SendTaskStreamingResponse(jsonrpc="2.0", id=params.id, result=evt).model_dump_json()
|
@@ -0,0 +1,155 @@
|
|
1
|
+
# Library imports
|
2
|
+
from typing import Optional, Union, Any
|
3
|
+
from uuid import uuid4
|
4
|
+
from fastapi import HTTPException
|
5
|
+
from pydantic import ValidationError
|
6
|
+
|
7
|
+
# Local imports
|
8
|
+
from smarta2a.server.handler_registry import HandlerRegistry
|
9
|
+
from smarta2a.server.state_manager import StateManager
|
10
|
+
from smarta2a.utils.task_builder import TaskBuilder
|
11
|
+
from smarta2a.utils.types import (
|
12
|
+
Message, StateData, SendTaskRequest, SendTaskResponse,
|
13
|
+
GetTaskRequest, GetTaskResponse, CancelTaskRequest, CancelTaskResponse,
|
14
|
+
SetTaskPushNotificationRequest, GetTaskPushNotificationRequest,
|
15
|
+
SetTaskPushNotificationResponse, GetTaskPushNotificationResponse,
|
16
|
+
TaskPushNotificationConfig, TaskState, A2AStatus,
|
17
|
+
JSONRPCError, MethodNotFoundError, InternalError, InvalidParamsError,
|
18
|
+
TaskNotCancelableError, UnsupportedOperationError
|
19
|
+
)
|
20
|
+
|
21
|
+
class TaskService:
    """Dispatch the non-streaming task methods to their registered handlers.

    Each public method looks up one handler in the registry, calls it, and
    turns the result (or any raised error) into the matching JSON-RPC
    response object.
    """

    def __init__(self, registry: HandlerRegistry, state_mgr: StateManager):
        self.registry = registry
        self.state_mgr = state_mgr
        # Handler results that carry no status are reported as COMPLETED.
        self.builder = TaskBuilder(default_status=TaskState.COMPLETED)

    def send(self, request: SendTaskRequest, state: Optional[StateData]) -> SendTaskResponse:
        """Handle ``tasks/send``.

        Args:
            request: The JSON-RPC request.
            state: Loaded session state, or ``None`` in stateless mode; it
                decides whether the handler gets one or two arguments.

        Returns:
            The handler's own ``SendTaskResponse`` if it returned one,
            otherwise a response wrapping a ``Task`` built from the result.
            Errors are returned in the response, never raised.
        """
        handler = self.registry.get_handler("tasks/send")
        if not handler:
            return SendTaskResponse(id=request.id, error=MethodNotFoundError())

        # Parsed as: state.sessionId if state else (params.sessionId or new uuid).
        session_id = state.sessionId if state else request.params.sessionId or str(uuid4())
        history = state.history.copy() if state else [request.params.message]
        metadata = state.metadata.copy() if state else (request.params.metadata or {})

        try:
            raw = handler(request, state) if state else handler(request)
            if isinstance(raw, SendTaskResponse):
                return raw

            task = self.builder.build(
                content=raw,
                task_id=request.params.id,
                session_id=session_id,
                metadata=metadata,
                history=history
            )

            if task.artifacts:
                # Flatten all artifact parts into one agent message, add it to
                # the history and persist the updated session.
                parts = [p for a in task.artifacts for p in a.parts]
                agent_msg = Message(role="agent", parts=parts, metadata=task.metadata)
                new_hist = self.state_mgr.strategy.update_history(history, [agent_msg])
                task.history = new_hist
                self.state_mgr.update(StateData(sessionId=session_id, history=new_hist, metadata=metadata))

            return SendTaskResponse(id=request.id, result=task)
        except JSONRPCError as e:
            return SendTaskResponse(id=request.id, error=e)
        except Exception as e:
            return SendTaskResponse(id=request.id, error=InternalError(data=str(e)))

    def get(self, request: GetTaskRequest) -> GetTaskResponse:
        """Handle ``tasks/get``; the returned task id must match the request."""
        handler = self.registry.get_handler("tasks/get")
        if not handler:
            return GetTaskResponse(id=request.id, error=MethodNotFoundError())
        try:
            raw = handler(request)
            if isinstance(raw, GetTaskResponse):
                return self._validate(raw, request)

            task = self.builder.build(
                content=raw,
                task_id=request.params.id,
                metadata=request.params.metadata or {}
            )
            return self._finalize(request, task)
        except JSONRPCError as e:
            return GetTaskResponse(id=request.id, error=e)
        except Exception as e:
            return GetTaskResponse(id=request.id, error=InternalError(data=str(e)))

    def cancel(self, request: CancelTaskRequest) -> CancelTaskResponse:
        """Handle ``tasks/cancel``.

        The handler may return a ``CancelTaskResponse``, an ``A2AStatus`` or
        any content ``TaskBuilder.build`` accepts. The resulting task must
        carry the requested id and be in the CANCELED or COMPLETED state.
        """
        handler = self.registry.get_handler("tasks/cancel")
        if not handler:
            return CancelTaskResponse(id=request.id, error=MethodNotFoundError())
        try:
            raw = handler(request)
            if isinstance(raw, CancelTaskResponse):
                return self._validate(raw, request)

            if isinstance(raw, A2AStatus):
                task = self.builder.normalize_from_status(status=raw.status, task_id=request.params.id, metadata=raw.metadata or {})
            else:
                # Arbitrary results (str, dict, ...) need not have a
                # ``metadata`` attribute; fall back to an empty dict.
                task = self.builder.build(content=raw, task_id=request.params.id, metadata=getattr(raw, "metadata", None) or {})

            if task.id != request.params.id:
                raise InvalidParamsError(data=f"Task ID mismatch: {task.id} vs {request.params.id}")
            if task.status.state not in [TaskState.CANCELED, TaskState.COMPLETED]:
                raise TaskNotCancelableError()

            return CancelTaskResponse(id=request.id, result=task)
        except JSONRPCError as e:
            return CancelTaskResponse(id=request.id, error=e)
        # NOTE(review): presumably already covered by the JSONRPCError clause
        # above if these derive from it -- confirm against the type definitions.
        except (InvalidParamsError, TaskNotCancelableError) as e:
            return CancelTaskResponse(id=request.id, error=e)
        except HTTPException as e:
            # A 405 raised by the handler is reported as "unsupported".
            if e.status_code == 405:
                return CancelTaskResponse(id=request.id, error=UnsupportedOperationError())
            return CancelTaskResponse(id=request.id, error=InternalError(data=str(e)))
        except Exception as e:
            return CancelTaskResponse(id=request.id, error=InternalError(data=str(e)))

    def set_notification(self, request: SetTaskPushNotificationRequest) -> SetTaskPushNotificationResponse:
        """Handle ``tasks/pushNotification/set``.

        A handler returning ``None`` signals success and the request params
        are echoed back. A ready-made response is passed through. Any other
        value is validated as a ``TaskPushNotificationConfig`` so that a
        response is always returned.
        """
        handler = self.registry.get_handler("tasks/pushNotification/set")
        if not handler:
            return SetTaskPushNotificationResponse(id=request.id, error=MethodNotFoundError())
        try:
            raw = handler(request)
            if raw is None:
                return SetTaskPushNotificationResponse(id=request.id, result=request.params)
            if isinstance(raw, SetTaskPushNotificationResponse):
                return raw
            # Same treatment get_notification gives to non-response results.
            cfg = TaskPushNotificationConfig.model_validate(raw)
            return SetTaskPushNotificationResponse(id=request.id, result=cfg)
        except ValidationError as e:
            return SetTaskPushNotificationResponse(id=request.id, error=InvalidParamsError(data=e.errors()))
        except JSONRPCError as e:
            return SetTaskPushNotificationResponse(id=request.id, error=e)
        except Exception as e:
            return SetTaskPushNotificationResponse(id=request.id, error=InternalError(data=str(e)))

    def get_notification(self, request: GetTaskPushNotificationRequest) -> GetTaskPushNotificationResponse:
        """Handle ``tasks/pushNotification/get``.

        Non-response handler results are validated as a
        ``TaskPushNotificationConfig``; validation failures are reported as
        ``InvalidParamsError``.
        """
        handler = self.registry.get_handler("tasks/pushNotification/get")
        if not handler:
            return GetTaskPushNotificationResponse(id=request.id, error=MethodNotFoundError())
        try:
            raw = handler(request)
            if isinstance(raw, GetTaskPushNotificationResponse):
                return raw
            cfg = TaskPushNotificationConfig.model_validate(raw)
            return GetTaskPushNotificationResponse(id=request.id, result=cfg)
        except ValidationError as e:
            return GetTaskPushNotificationResponse(id=request.id, error=InvalidParamsError(data=e.errors()))
        except JSONRPCError as e:
            return GetTaskPushNotificationResponse(id=request.id, error=e)
        except Exception as e:
            return GetTaskPushNotificationResponse(id=request.id, error=InternalError(data=str(e)))

    def _validate(self, resp: Union[SendTaskResponse, GetTaskResponse, CancelTaskResponse], req) -> Any:
        """Return ``resp`` unless its task id differs from the requested one.

        On mismatch, a new response of the same type carrying an
        ``InvalidParamsError`` is returned instead.
        """
        if resp.result and resp.result.id != req.params.id:
            return type(resp)(id=req.id, error=InvalidParamsError(data=f"Task ID mismatch: {resp.result.id} vs {req.params.id}"))
        return resp

    def _finalize(self, request: GetTaskRequest, task) -> GetTaskResponse:
        """Check the task id and trim history to ``historyLength`` entries."""
        if task.id != request.params.id:
            return GetTaskResponse(id=request.id, error=InvalidParamsError(data=f"Task ID mismatch: {task.id} vs {request.params.id}"))
        # A missing or zero historyLength leaves the history untouched.
        if request.params.historyLength and task.history:
            task.history = task.history[-request.params.historyLength:]
        return GetTaskResponse(id=request.id, result=task)
|
@@ -0,0 +1,20 @@
|
|
1
|
+
# Library imports
|
2
|
+
from abc import ABC, abstractmethod
|
3
|
+
from typing import Optional, List, Dict, Any
|
4
|
+
|
5
|
+
# Local imports
|
6
|
+
from smarta2a.utils.types import StateData, Message
|
7
|
+
|
8
|
+
class BaseStateStore(ABC):
    """Abstract interface for session-state storage backends.

    The methods are declared synchronous: ``StateManager`` calls
    ``get_state`` / ``update_state`` without awaiting them, and
    ``InMemoryStateStore`` implements all three as plain functions.
    """

    @abstractmethod
    def get_state(self, session_id: str) -> Optional[StateData]:
        """Return the state stored for ``session_id``, or ``None`` if absent."""

    @abstractmethod
    def update_state(self, session_id: str, state_data: StateData) -> None:
        """Store ``state_data`` as the state for ``session_id``."""

    @abstractmethod
    def delete_state(self, session_id: str) -> None:
        """Remove any state stored for ``session_id``."""
|
@@ -0,0 +1,21 @@
|
|
1
|
+
# Library imports
|
2
|
+
from typing import Dict, Any, Optional, List
|
3
|
+
import uuid
|
4
|
+
|
5
|
+
# Local imports
|
6
|
+
from smarta2a.state_stores.base_state_store import BaseStateStore
|
7
|
+
from smarta2a.utils.types import StateData, Message
|
8
|
+
|
9
|
+
class InMemoryStateStore(BaseStateStore):
    """State store backed by a plain dict keyed by session id.

    Contents live only as long as this object does.
    """

    def __init__(self):
        self.states: Dict[str, StateData] = {}

    def get_state(self, session_id: str) -> Optional[StateData]:
        """Return the stored state for ``session_id``, or ``None``."""
        return self.states.get(session_id, None)

    def update_state(self, session_id: str, state_data: StateData):
        """Insert or overwrite the state for ``session_id``."""
        self.states[session_id] = state_data

    def delete_state(self, session_id: str):
        """Drop the state for ``session_id``; unknown ids are ignored."""
        self.states.pop(session_id, None)
|
@@ -0,0 +1,32 @@
|
|
1
|
+
from .types import *
|
2
|
+
|
3
|
+
# Names exported by a star-import of this package.
__all__ = [
    "TaskSendParams", "SendTaskRequest", "GetTaskRequest", "CancelTaskRequest",
    "CancelTaskResponse", "Task", "TaskStatus", "TaskState",
    "Artifact", "TextPart", "FilePart", "FileContent",
    "A2AResponse", "A2ARequest", "TaskQueryParams",
    "TaskStatusUpdateEvent", "TaskArtifactUpdateEvent",
    "A2AStatus", "A2AStreamResponse", "SendTaskResponse", "Message",
    "InternalError", "TaskNotFoundError",
    "SetTaskPushNotificationRequest", "GetTaskPushNotificationRequest",
    "SetTaskPushNotificationResponse", "GetTaskPushNotificationResponse",
    "TaskPushNotificationConfig",
]
|
@@ -0,0 +1,38 @@
|
|
1
|
+
# Library imports
|
2
|
+
from typing import Optional, List
|
3
|
+
|
4
|
+
# Local imports
|
5
|
+
from smarta2a.client.tools_manager import ToolsManager
|
6
|
+
from smarta2a.utils.types import AgentCard
|
7
|
+
|
8
|
+
def build_system_prompt(
    base_prompt: Optional[str],
    tools_manager: ToolsManager,
    mcp_server_urls_or_paths: Optional[List[str]] = None,
    agent_cards: Optional[List[AgentCard]] = None
) -> str:
    """Compose the final system prompt from the base prompt and tool listings.

    Args:
        base_prompt: Opening text; a generic assistant sentence is used when
            it is ``None`` or empty.
        tools_manager: Source of the tool descriptions
            (``describe_tools("mcp")`` / ``describe_tools("a2a")``).
        mcp_server_urls_or_paths: When non-empty, the MCP tool description
            is appended.
        agent_cards: When non-empty, the agent listing and the A2A tool
            description are appended.

    Returns:
        The assembled prompt text.
    """
    header = base_prompt or "You are a helpful assistant with access to the following tools:"

    if mcp_server_urls_or_paths:
        mcp_tools_desc = tools_manager.describe_tools("mcp")
        header += f"\n\nAvailable tools:\n{mcp_tools_desc}"

    if agent_cards:
        a2a_tools_desc = tools_manager.describe_tools("a2a")
        header += f"\n\nIf needed, you can delegate parts of your task to other agents. The Agents you can use are:\n{_print_agent_list(agent_cards)}\n\nUse the following tools to send tasks to an agent:\n{a2a_tools_desc}"

    return header


def _print_agent_list(agents: List[AgentCard]) -> str:
    """Return the agents' descriptions, one per line, framed by ``---`` lines.

    The text is returned rather than printed because the caller embeds it in
    the system prompt; printing and returning ``None`` would put the literal
    text "None" into the prompt.
    """
    separator = "---"
    agent_strings = [agent.pretty_print(include_separators=False) for agent in agents]
    return "\n".join([separator, *agent_strings, separator])
|
@@ -0,0 +1,153 @@
|
|
1
|
+
# Library imports
|
2
|
+
from uuid import uuid4
|
3
|
+
from datetime import datetime
|
4
|
+
from typing import Any, List, Optional, Dict, Union
|
5
|
+
from pydantic import ValidationError
|
6
|
+
|
7
|
+
# Local imports
|
8
|
+
from smarta2a.utils.types import (
|
9
|
+
Task,
|
10
|
+
TaskStatus,
|
11
|
+
TaskState,
|
12
|
+
Artifact,
|
13
|
+
Part,
|
14
|
+
TextPart,
|
15
|
+
FilePart,
|
16
|
+
DataPart,
|
17
|
+
Message,
|
18
|
+
A2AResponse,
|
19
|
+
)
|
20
|
+
|
21
|
+
class TaskBuilder:
    """Build ``Task`` objects from the different shapes a handler may return."""

    def __init__(
        self,
        default_status: TaskState = TaskState.COMPLETED,
    ):
        # Status assigned when the handler result carries none of its own.
        self.default_status = default_status

    def build(
        self,
        content: Any,
        task_id: str,
        session_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
        history: Optional[List[Message]] = None,
    ) -> Task:
        """Universal task construction from various return types.

        Args:
            content: Handler result: a ``Task``, an ``A2AResponse``, a dict
                describing a task, or any artifact-like content.
            task_id: Id for tasks that are constructed here.
            session_id: Session id used when the content supplies none.
            metadata: Caller metadata; defaults to an empty dict.
            history: Prior messages; defaults to an empty list.

        Returns:
            The resulting ``Task``.
        """
        history = history or []
        metadata = metadata or {}

        # 1) Already a full Task: fill in session id, history and metadata.
        if isinstance(content, Task):
            content.sessionId = content.sessionId or session_id
            content.history = history + (content.history or [])
            content.metadata = content.metadata or metadata
            return content

        # 2) An A2AResponse: extract status and content.
        if isinstance(content, A2AResponse):
            # Prefer the sessionId inside the A2AResponse.
            sid = content.sessionId or session_id
            # Response metadata wins over the builder-call metadata.
            md = {**(metadata or {}), **(content.metadata or {})}
            status = (
                content.status
                if isinstance(content.status, TaskStatus)
                else TaskStatus(state=content.status)
            )
            artifacts = self._normalize_content(content.content)
            return Task(
                id=task_id,
                sessionId=sid,
                status=status,
                artifacts=artifacts,
                metadata=md,
                history=history,
            )

        # 3) A plain dict describing a Task. Merge the overrides into one
        #    mapping first: passing them as separate keywords next to
        #    ``**content`` raises TypeError whenever the dict already has a
        #    "sessionId", "metadata" or "history" key.
        if isinstance(content, dict):
            task_fields = {
                **content,
                "sessionId": session_id or content.get("sessionId"),
                "metadata": metadata or content.get("metadata", {}),
                # Same ordering as the Task branch: caller history first.
                "history": history + list(content.get("history") or []),
            }
            try:
                return Task(**task_fields)
            except ValidationError:
                # Not a task description; fall through and treat the dict as
                # artifact content.
                pass

        # 4) Fallback: treat whatever was returned as artifact content.
        artifacts = self._normalize_content(content)
        return Task(
            id=task_id,
            sessionId=session_id,
            status=TaskStatus(state=self.default_status),
            artifacts=artifacts,
            metadata=metadata,
            history=history,
        )

    def normalize_from_status(
        self, status: TaskState, task_id: str, metadata: Dict[str, Any]
    ) -> Task:
        """Build a Task when only a cancellation or status-only event occurs.

        The task has an empty session id, no artifacts, no history, and a
        status timestamped with the current local time.
        """
        return Task(
            id=task_id,
            sessionId="",
            status=TaskStatus(state=status, timestamp=datetime.now()),
            artifacts=[],
            metadata=metadata,
            history=[],
        )

    def _normalize_content(self, content: Any) -> List[Artifact]:
        """Turn any handler return value into a list of Artifact."""
        if isinstance(content, Artifact):
            return [content]

        if isinstance(content, list) and all(isinstance(a, Artifact) for a in content):
            return content

        # Mixed list: collapse everything into the parts of one artifact.
        if isinstance(content, list):
            return [Artifact(parts=self._parts_from_mixed(content))]

        if isinstance(content, str):
            return [Artifact(parts=[TextPart(text=content)])]

        if isinstance(content, dict):
            # Raw artifact dict.
            return [Artifact.model_validate(content)]

        # Explicit part objects.
        if isinstance(content, (TextPart, FilePart, DataPart)):
            return [Artifact(parts=[content])]

        # Unknown object: try validating it as an Artifact, else use its text.
        try:
            return [Artifact.model_validate(content)]
        except ValidationError:
            return [Artifact(parts=[TextPart(text=str(content))])]

    def _parts_from_mixed(self, items: List[Any]) -> List[Part]:
        """Flatten a mixed list into parts; artifacts contribute their parts."""
        parts: List[Part] = []
        for item in items:
            if isinstance(item, Artifact):
                parts.extend(item.parts)
            else:
                parts.append(self._create_part(item))
        return parts

    def _create_part(self, item: Any) -> Part:
        """Coerce one item into a part, falling back to its ``str()`` text."""
        from pydantic import TypeAdapter
        from smarta2a.utils.types import Part as UnionPart

        if isinstance(item, (TextPart, FilePart, DataPart)):
            return item
        if isinstance(item, str):
            return TextPart(text=item)
        if isinstance(item, dict):
            try:
                # ``Part`` appears to be a typing alias rather than a model
                # class, so it has no ``model_validate``; TypeAdapter
                # validates against either kind of type.
                return TypeAdapter(UnionPart).validate_python(item)
            except ValidationError:
                return TextPart(text=str(item))
        return TextPart(text=str(item))
|
@@ -0,0 +1,114 @@
|
|
1
|
+
# Library imports
|
2
|
+
from typing import Any, Literal
|
3
|
+
from uuid import uuid4
|
4
|
+
|
5
|
+
# Local imports
|
6
|
+
from smarta2a.utils.types import (
|
7
|
+
TaskPushNotificationConfig,
|
8
|
+
PushNotificationConfig,
|
9
|
+
TaskSendParams,
|
10
|
+
TextPart,
|
11
|
+
DataPart,
|
12
|
+
FilePart,
|
13
|
+
FileContent,
|
14
|
+
Message,
|
15
|
+
Part,
|
16
|
+
TaskQueryParams,
|
17
|
+
TaskIdParams,
|
18
|
+
GetTaskRequest,
|
19
|
+
CancelTaskRequest,
|
20
|
+
SetTaskPushNotificationRequest,
|
21
|
+
GetTaskPushNotificationRequest,
|
22
|
+
AuthenticationInfo,
|
23
|
+
)
|
24
|
+
|
25
|
+
class TaskRequestBuilder:
    """Static factory helpers that assemble task request/parameter objects."""

    @staticmethod
    def build_send_task_request(
        *,
        id: str,
        role: Literal["user", "agent"] = "user",
        text: str | None = None,
        data: dict[str, Any] | None = None,
        file_uri: str | None = None,
        session_id: str | None = None,
        accepted_output_modes: list[str] | None = None,
        push_notification: PushNotificationConfig | None = None,
        history_length: int | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> TaskSendParams:
        """Build ``TaskSendParams`` holding a single message.

        The message gets one part per supplied input, in the fixed order
        text, data, file. A random hex session id is generated when
        ``session_id`` is not given.
        """
        # (value, factory) pairs in emission order; absent inputs are skipped.
        part_sources = (
            (text, lambda value: TextPart(text=value)),
            (data, lambda value: DataPart(data=value)),
            (file_uri, lambda value: FilePart(file=FileContent(uri=value))),
        )
        parts: list[Part] = [make(value) for value, make in part_sources if value is not None]

        return TaskSendParams(
            id=id,
            sessionId=session_id or uuid4().hex,
            message=Message(role=role, parts=parts),
            acceptedOutputModes=accepted_output_modes,
            pushNotification=push_notification,
            historyLength=history_length,
            metadata=metadata,
        )

    @staticmethod
    def get_task(
        id: str,
        history_length: int | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> GetTaskRequest:
        """Build a ``GetTaskRequest`` for the task ``id``."""
        return GetTaskRequest(
            params=TaskQueryParams(
                id=id,
                historyLength=history_length,
                metadata=metadata,
            )
        )

    @staticmethod
    def cancel_task(
        id: str,
        metadata: dict[str, Any] | None = None,
    ) -> CancelTaskRequest:
        """Build a ``CancelTaskRequest`` for the task ``id``."""
        return CancelTaskRequest(params=TaskIdParams(id=id, metadata=metadata))

    @staticmethod
    def set_push_notification(
        id: str,
        url: str,
        token: str | None = None,
        authentication: AuthenticationInfo | dict[str, Any] | None = None,
    ) -> SetTaskPushNotificationRequest:
        """Build a ``SetTaskPushNotificationRequest``.

        ``authentication`` may be an ``AuthenticationInfo``, a dict of its
        fields, or ``None``; an empty dict is treated like ``None``.
        """
        if isinstance(authentication, AuthenticationInfo):
            auth = authentication
        elif authentication:
            auth = AuthenticationInfo(**authentication)
        else:
            auth = None

        notification = PushNotificationConfig(
            url=url,
            token=token,
            authentication=auth,
        )
        config = TaskPushNotificationConfig(id=id, pushNotificationConfig=notification)
        return SetTaskPushNotificationRequest(params=config)

    @staticmethod
    def get_push_notification(
        id: str,
        metadata: dict[str, Any] | None = None,
    ) -> GetTaskPushNotificationRequest:
        """Build a ``GetTaskPushNotificationRequest`` for the task ``id``."""
        return GetTaskPushNotificationRequest(params=TaskIdParams(id=id, metadata=metadata))
|