smarta2a 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. smarta2a/__init__.py +4 -4
  2. smarta2a/agent/a2a_agent.py +38 -0
  3. smarta2a/agent/a2a_mcp_server.py +37 -0
  4. smarta2a/archive/mcp_client.py +86 -0
  5. smarta2a/client/__init__.py +0 -0
  6. smarta2a/client/a2a_client.py +267 -0
  7. smarta2a/client/smart_mcp_client.py +60 -0
  8. smarta2a/client/tools_manager.py +58 -0
  9. smarta2a/history_update_strategies/__init__.py +8 -0
  10. smarta2a/history_update_strategies/append_strategy.py +10 -0
  11. smarta2a/history_update_strategies/history_update_strategy.py +15 -0
  12. smarta2a/model_providers/__init__.py +5 -0
  13. smarta2a/model_providers/base_llm_provider.py +15 -0
  14. smarta2a/model_providers/openai_provider.py +281 -0
  15. smarta2a/server/__init__.py +3 -0
  16. smarta2a/server/handler_registry.py +23 -0
  17. smarta2a/{server.py → server/server.py} +224 -254
  18. smarta2a/server/state_manager.py +34 -0
  19. smarta2a/server/subscription_service.py +109 -0
  20. smarta2a/server/task_service.py +155 -0
  21. smarta2a/state_stores/__init__.py +8 -0
  22. smarta2a/state_stores/base_state_store.py +20 -0
  23. smarta2a/state_stores/inmemory_state_store.py +21 -0
  24. smarta2a/utils/__init__.py +32 -0
  25. smarta2a/utils/prompt_helpers.py +38 -0
  26. smarta2a/utils/task_builder.py +153 -0
  27. smarta2a/utils/task_request_builder.py +114 -0
  28. smarta2a/{types.py → utils/types.py} +62 -2
  29. {smarta2a-0.2.1.dist-info → smarta2a-0.2.3.dist-info}/METADATA +13 -7
  30. smarta2a-0.2.3.dist-info/RECORD +32 -0
  31. smarta2a-0.2.1.dist-info/RECORD +0 -7
  32. {smarta2a-0.2.1.dist-info → smarta2a-0.2.3.dist-info}/WHEEL +0 -0
  33. {smarta2a-0.2.1.dist-info → smarta2a-0.2.3.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,34 @@
1
+ # Library imports
2
+ from typing import Optional, Dict, Any
3
+ from uuid import uuid4
4
+
5
+ # Local imports
6
+ from smarta2a.state_stores.base_state_store import BaseStateStore
7
+ from smarta2a.history_update_strategies.history_update_strategy import HistoryUpdateStrategy
8
+ from smarta2a.utils.types import Message, StateData
9
+
10
class StateManager:
    """Load, extend and persist per-session conversation state.

    Every history change goes through the configured ``HistoryUpdateStrategy``.
    The task and subscription services add agent replies via
    ``strategy.update_history``; this class applies the same strategy to
    incoming messages, so a single policy governs the whole history.
    """

    def __init__(self, store: Optional[BaseStateStore], history_strategy: HistoryUpdateStrategy):
        # ``store`` may be None, in which case nothing is persisted between calls.
        self.store = store
        self.strategy = history_strategy

    def init_or_get(self, session_id: Optional[str], message: Message, metadata: Dict[str, Any]) -> StateData:
        """Return the session state with ``message`` added to its history.

        Args:
            session_id: Existing session id. A new ``str(uuid4())`` is used when falsy.
            message: Incoming message to record in the history.
            metadata: Metadata merged over any stored metadata (incoming keys win).

        Returns:
            The updated ``StateData``. Without a store, this is a fresh object
            that is not persisted. With a store, it is written back through
            ``update_state`` before being returned.
        """
        sid = session_id or str(uuid4())
        if not self.store:
            return StateData(
                sessionId=sid,
                history=self.strategy.update_history([], [message]),
                metadata=metadata or {},
            )

        existing = self.store.get_state(sid) or StateData(sessionId=sid, history=[], metadata={})
        # Route through the strategy (instead of list.append) so a custom
        # history policy also applies to incoming messages.
        existing.history = self.strategy.update_history(existing.history or [], [message])
        existing.metadata = {**(existing.metadata or {}), **(metadata or {})}
        self.store.update_state(sid, existing)
        return existing

    def update(self, state: StateData):
        """Persist ``state`` under its ``sessionId``. Does nothing when there is no store."""
        if self.store:
            self.store.update_state(state.sessionId, state)

    def get_store(self) -> Optional[BaseStateStore]:
        """Return the backing state store (may be None)."""
        return self.store

    def get_strategy(self) -> HistoryUpdateStrategy:
        """Return the configured history update strategy."""
        return self.strategy
34
+
@@ -0,0 +1,109 @@
1
+ # Library imports
2
+ from typing import Optional, List, Dict, Any, AsyncGenerator, Union
3
+ from datetime import datetime
4
+ from collections import defaultdict
5
+ from uuid import uuid4
6
+ from fastapi.responses import StreamingResponse
7
+ from sse_starlette.sse import EventSourceResponse
8
+
9
+ # Local imports
10
+ from smarta2a.server.handler_registry import HandlerRegistry
11
+ from smarta2a.server.state_manager import StateManager
12
+ from smarta2a.utils.types import (
13
+ Message, StateData, SendTaskStreamingRequest, SendTaskStreamingResponse,
14
+ TaskSendParams, A2AStatus, A2AStreamResponse, TaskStatusUpdateEvent,
15
+ TaskStatus, TaskState, TaskArtifactUpdateEvent, Artifact, TextPart,
16
+ FilePart, DataPart, FileContent, MethodNotFoundError, TaskNotFoundError,
17
+ InternalError
18
+ )
19
+
20
class SubscriptionService:
    """Serve ``tasks/sendSubscribe`` by streaming handler output as SSE frames.

    Handler events are normalized into serialized ``SendTaskStreamingResponse``
    JSON. Each artifact chunk is also recorded as an agent message in the
    session history through the ``StateManager``.
    """

    def __init__(self, registry: HandlerRegistry, state_mgr: StateManager):
        self.registry = registry
        self.state_mgr = state_mgr

    async def subscribe(self, request: SendTaskStreamingRequest, state: Optional[StateData]) -> StreamingResponse:
        """Return a ``text/event-stream`` response for a streaming task request.

        If no subscription handler is registered, the stream carries a single
        ``MethodNotFoundError`` frame. Exceptions raised while streaming are
        turned into a final error frame.
        """
        handler = self.registry.get_subscription("tasks/sendSubscribe")
        if not handler:
            err = SendTaskStreamingResponse(jsonrpc="2.0", id=request.id, error=MethodNotFoundError()).model_dump_json()

            # EventSourceResponse iterates its content, so handing it the bare
            # JSON string emitted one SSE event per character. Emit exactly one
            # frame, formatted like every other frame this service produces.
            async def error_stream():
                yield f"data: {err}\n\n"

            return StreamingResponse(error_stream(), media_type="text/event-stream; charset=utf-8")

        session_id = state.sessionId if state else request.params.sessionId or str(uuid4())
        history = state.history.copy() if state else [request.params.message]
        metadata = dict(state.metadata or {}) if state else (request.params.metadata or {})

        async def event_stream():
            try:
                events = handler(request, state) if state else handler(request)
                async for ev in self._normalize(request.params, events, history.copy(), metadata.copy(), session_id):
                    yield f"data: {ev}\n\n"
            except Exception as e:
                # Heuristic mapping: error messages mentioning "not found" become TaskNotFoundError.
                err = TaskNotFoundError() if 'not found' in str(e).lower() else InternalError(data=str(e))
                msg = SendTaskStreamingResponse(jsonrpc="2.0", id=request.id, error=err).model_dump_json()
                yield f"data: {msg}\n\n"

        return StreamingResponse(event_stream(), media_type="text/event-stream; charset=utf-8")

    async def _normalize(
        self,
        params: TaskSendParams,
        events: AsyncGenerator,
        history: List[Message],
        metadata: Dict[str, Any],
        session_id: str
    ) -> AsyncGenerator[str, None]:
        """Convert raw handler events into serialized streaming responses.

        - ``SendTaskStreamingResponse`` items are passed through unchanged.
        - ``A2AStatus`` items become ``TaskStatusUpdateEvent`` results.
        - Anything else is treated as artifact content, wrapped in
          ``A2AStreamResponse`` if needed. It is emitted as a
          ``TaskArtifactUpdateEvent`` and appended to the session history.
        """
        # Per artifact index: number of chunks emitted so far, and whether a final chunk was seen.
        artifact_state = defaultdict(lambda: {"chunks": 0, "last_chunk": False})
        async for item in events:
            if isinstance(item, SendTaskStreamingResponse):
                yield item.model_dump_json()
                continue

            if isinstance(item, A2AStatus):
                te = TaskStatusUpdateEvent(
                    id=params.id,
                    status=TaskStatus(state=TaskState(item.status), timestamp=datetime.now()),
                    final=item.final or (item.status.lower() == TaskState.COMPLETED),
                    metadata=item.metadata
                )
                yield SendTaskStreamingResponse(jsonrpc="2.0", id=params.id, result=te).model_dump_json()
                continue

            content_item = item if isinstance(item, A2AStreamResponse) else A2AStreamResponse(content=item)
            parts = self._content_to_parts(content_item.content)

            idx = content_item.index
            chunk_state = artifact_state[idx]
            evt = TaskArtifactUpdateEvent(
                id=params.id,
                artifact=Artifact(
                    parts=parts,
                    index=idx,
                    # Every chunk after the first one for this index extends the
                    # existing artifact. The previous test (counter == idx)
                    # flagged the wrong chunks.
                    append=content_item.append or chunk_state["chunks"] > 0,
                    lastChunk=content_item.final or chunk_state["last_chunk"],
                    metadata=content_item.metadata
                )
            )
            if content_item.final:
                chunk_state["last_chunk"] = True
            chunk_state["chunks"] += 1

            agent_msg = Message(role="agent", parts=evt.artifact.parts, metadata=evt.artifact.metadata)
            new_hist = self.state_mgr.strategy.update_history(history, [agent_msg])
            metadata = {**metadata, **(evt.artifact.metadata or {})}
            # Keyword construction, as everywhere else StateData is built.
            # StateData is presumably a pydantic model, and those reject
            # positional field arguments.
            self.state_mgr.update(StateData(sessionId=session_id, history=new_hist, metadata=metadata))
            history = new_hist

            yield SendTaskStreamingResponse(jsonrpc="2.0", id=params.id, result=evt).model_dump_json()

    @staticmethod
    def _content_to_parts(cont: Any) -> List[Union[TextPart, FilePart, DataPart]]:
        """Flatten one chunk of stream content into parts. Unsupported values are skipped."""
        parts: List[Union[TextPart, FilePart, DataPart]] = []
        if isinstance(cont, str):
            parts.append(TextPart(text=cont))
        elif isinstance(cont, bytes):
            # NOTE(review): raw bytes are passed to FileContent.bytes; confirm
            # whether that field expects base64-encoded text.
            parts.append(FilePart(file=FileContent(bytes=cont)))
        elif isinstance(cont, (TextPart, FilePart, DataPart)):
            parts.append(cont)
        elif isinstance(cont, Artifact):
            parts.extend(cont.parts)
        elif isinstance(cont, list):
            for elem in cont:
                if isinstance(elem, str):
                    parts.append(TextPart(text=elem))
                elif isinstance(elem, (TextPart, FilePart, DataPart)):
                    parts.append(elem)
                elif isinstance(elem, Artifact):
                    parts.extend(elem.parts)
        return parts
@@ -0,0 +1,155 @@
1
+ # Library imports
2
+ from typing import Optional, Union, Any
3
+ from uuid import uuid4
4
+ from fastapi import HTTPException
5
+ from pydantic import ValidationError
6
+
7
+ # Local imports
8
+ from smarta2a.server.handler_registry import HandlerRegistry
9
+ from smarta2a.server.state_manager import StateManager
10
+ from smarta2a.utils.task_builder import TaskBuilder
11
+ from smarta2a.utils.types import (
12
+ Message, StateData, SendTaskRequest, SendTaskResponse,
13
+ GetTaskRequest, GetTaskResponse, CancelTaskRequest, CancelTaskResponse,
14
+ SetTaskPushNotificationRequest, GetTaskPushNotificationRequest,
15
+ SetTaskPushNotificationResponse, GetTaskPushNotificationResponse,
16
+ TaskPushNotificationConfig, TaskState, A2AStatus,
17
+ JSONRPCError, MethodNotFoundError, InternalError, InvalidParamsError,
18
+ TaskNotCancelableError, UnsupportedOperationError
19
+ )
20
+
21
class TaskService:
    """Dispatch the non-streaming A2A task methods to registered handlers.

    Each public method looks up its handler in the ``HandlerRegistry`` and
    normalizes the handler's return value into the matching JSON-RPC response.
    Raised errors are converted into error responses.

    NOTE(review): the ``except JSONRPCError`` clauses (and the ``raise`` of
    InvalidParamsError / TaskNotCancelableError) only work if those classes
    derive from ``BaseException``. Confirm in ``smarta2a.utils.types``.
    """

    def __init__(self, registry: HandlerRegistry, state_mgr: StateManager):
        self.registry = registry
        self.state_mgr = state_mgr
        self.builder = TaskBuilder(default_status=TaskState.COMPLETED)

    def send(self, request: SendTaskRequest, state: Optional[StateData]) -> SendTaskResponse:
        """Handle ``tasks/send``.

        The handler is called as ``handler(request, state)`` when session state
        exists, and as ``handler(request)`` otherwise. A returned
        ``SendTaskResponse`` is passed through after the task-ID check. Any
        other value becomes a ``Task`` via ``TaskBuilder``. If that task has
        artifacts, they are added to the history as one agent message and
        persisted.
        """
        handler = self.registry.get_handler("tasks/send")
        if not handler:
            return SendTaskResponse(id=request.id, error=MethodNotFoundError())

        session_id = state.sessionId if state else request.params.sessionId or str(uuid4())
        history = state.history.copy() if state else [request.params.message]
        metadata = dict(state.metadata or {}) if state else (request.params.metadata or {})

        try:
            raw = handler(request, state) if state else handler(request)
            if isinstance(raw, SendTaskResponse):
                # Same task-ID check that get() and cancel() apply to pre-built responses.
                return self._validate(raw, request)

            task = self.builder.build(
                content=raw,
                task_id=request.params.id,
                session_id=session_id,
                metadata=metadata,
                history=history
            )

            if task.artifacts:
                parts = [p for a in task.artifacts for p in a.parts]
                agent_msg = Message(role="agent", parts=parts, metadata=task.metadata)
                new_hist = self.state_mgr.strategy.update_history(history, [agent_msg])
                task.history = new_hist
                self.state_mgr.update(StateData(sessionId=session_id, history=new_hist, metadata=metadata))

            return SendTaskResponse(id=request.id, result=task)
        except JSONRPCError as e:
            return SendTaskResponse(id=request.id, error=e)
        except Exception as e:
            return SendTaskResponse(id=request.id, error=InternalError(data=str(e)))

    def get(self, request: GetTaskRequest) -> GetTaskResponse:
        """Handle ``tasks/get``, trimming history to ``historyLength`` when it is set."""
        handler = self.registry.get_handler("tasks/get")
        if not handler:
            return GetTaskResponse(id=request.id, error=MethodNotFoundError())
        try:
            raw = handler(request)
            if isinstance(raw, GetTaskResponse):
                return self._validate(raw, request)

            task = self.builder.build(
                content=raw,
                task_id=request.params.id,
                metadata=request.params.metadata or {}
            )
            return self._finalize(request, task)
        except JSONRPCError as e:
            return GetTaskResponse(id=request.id, error=e)
        except Exception as e:
            return GetTaskResponse(id=request.id, error=InternalError(data=str(e)))

    def cancel(self, request: CancelTaskRequest) -> CancelTaskResponse:
        """Handle ``tasks/cancel``.

        The resulting task must have the requested id and be in the CANCELED
        or COMPLETED state; otherwise an error response is returned. An
        ``HTTPException`` with status 405 maps to ``UnsupportedOperationError``.
        """
        handler = self.registry.get_handler("tasks/cancel")
        if not handler:
            return CancelTaskResponse(id=request.id, error=MethodNotFoundError())
        try:
            raw = handler(request)
            if isinstance(raw, CancelTaskResponse):
                return self._validate(raw, request)

            if isinstance(raw, A2AStatus):
                task = self.builder.normalize_from_status(status=raw.status, task_id=request.params.id, metadata=raw.metadata or {})
            else:
                # Handlers may return plain values (str, dict, ...) without a
                # ``metadata`` attribute, so read it defensively.
                task = self.builder.build(
                    content=raw,
                    task_id=request.params.id,
                    metadata=getattr(raw, "metadata", None) or {}
                )

            if task.id != request.params.id:
                raise InvalidParamsError(data=f"Task ID mismatch: {task.id} vs {request.params.id}")
            if task.status.state not in [TaskState.CANCELED, TaskState.COMPLETED]:
                raise TaskNotCancelableError()

            return CancelTaskResponse(id=request.id, result=task)
        except JSONRPCError as e:
            return CancelTaskResponse(id=request.id, error=e)
        except (InvalidParamsError, TaskNotCancelableError) as e:
            return CancelTaskResponse(id=request.id, error=e)
        except HTTPException as e:
            if e.status_code == 405:
                return CancelTaskResponse(id=request.id, error=UnsupportedOperationError())
            return CancelTaskResponse(id=request.id, error=InternalError(data=str(e)))
        except Exception as e:
            return CancelTaskResponse(id=request.id, error=InternalError(data=str(e)))

    def set_notification(self, request: SetTaskPushNotificationRequest) -> SetTaskPushNotificationResponse:
        """Handle ``tasks/pushNotification/set``.

        A ``None`` handler result echoes the requested config back. A response
        object is passed through. Any other value is validated as a
        ``TaskPushNotificationConfig``, mirroring ``get_notification``; it
        previously fell through and the method returned None.
        """
        handler = self.registry.get_handler("tasks/pushNotification/set")
        if not handler:
            return SetTaskPushNotificationResponse(id=request.id, error=MethodNotFoundError())
        try:
            raw = handler(request)
            if raw is None:
                return SetTaskPushNotificationResponse(id=request.id, result=request.params)
            if isinstance(raw, SetTaskPushNotificationResponse):
                return raw
            cfg = TaskPushNotificationConfig.model_validate(raw)
            return SetTaskPushNotificationResponse(id=request.id, result=cfg)
        except ValidationError as e:
            return SetTaskPushNotificationResponse(id=request.id, error=InvalidParamsError(data=e.errors()))
        except JSONRPCError as e:
            return SetTaskPushNotificationResponse(id=request.id, error=e)
        except Exception as e:
            return SetTaskPushNotificationResponse(id=request.id, error=InternalError(data=str(e)))

    def get_notification(self, request: GetTaskPushNotificationRequest) -> GetTaskPushNotificationResponse:
        """Handle ``tasks/pushNotification/get``, validating non-response results as a config."""
        handler = self.registry.get_handler("tasks/pushNotification/get")
        if not handler:
            return GetTaskPushNotificationResponse(id=request.id, error=MethodNotFoundError())
        try:
            raw = handler(request)
            if isinstance(raw, GetTaskPushNotificationResponse):
                return raw
            cfg = TaskPushNotificationConfig.model_validate(raw)
            return GetTaskPushNotificationResponse(id=request.id, result=cfg)
        except ValidationError as e:
            return GetTaskPushNotificationResponse(id=request.id, error=InvalidParamsError(data=e.errors()))
        except JSONRPCError as e:
            return GetTaskPushNotificationResponse(id=request.id, error=e)
        except Exception as e:
            return GetTaskPushNotificationResponse(id=request.id, error=InternalError(data=str(e)))

    def _validate(self, resp: Union[SendTaskResponse, GetTaskResponse, CancelTaskResponse], req) -> Any:
        """Replace ``resp`` with an InvalidParamsError response if its task id differs from the request's."""
        if resp.result and resp.result.id != req.params.id:
            return type(resp)(id=req.id, error=InvalidParamsError(data=f"Task ID mismatch: {resp.result.id} vs {req.params.id}"))
        return resp

    def _finalize(self, request: GetTaskRequest, task) -> GetTaskResponse:
        """Check the task id and keep only the last ``historyLength`` messages when it is truthy."""
        if task.id != request.params.id:
            return GetTaskResponse(id=request.id, error=InvalidParamsError(data=f"Task ID mismatch: {task.id} vs {request.params.id}"))
        if request.params.historyLength and task.history:
            task.history = task.history[-request.params.historyLength:]
        return GetTaskResponse(id=request.id, result=task)
@@ -0,0 +1,8 @@
1
+ """
2
+ State store implementations for managing conversation state.
3
+ """
4
+
5
+ from .base_state_store import BaseStateStore
6
+ from .inmemory_state_store import InMemoryStateStore
7
+
8
# Public API of the state_stores package.
__all__ = [
    "BaseStateStore",
    "InMemoryStateStore",
]
@@ -0,0 +1,20 @@
1
+ # Library imports
2
+ from abc import ABC, abstractmethod
3
+ from typing import Optional, List, Dict, Any
4
+
5
+ # Local imports
6
+ from smarta2a.utils.types import StateData, Message
7
+
8
class BaseStateStore(ABC):
    """Abstract persistence interface for per-session ``StateData``.

    The methods are synchronous. ``StateManager`` calls them without ``await``
    (for example ``self.store.get_state(sid) or ...``), and
    ``InMemoryStateStore`` implements them as plain methods.
    """

    # These were previously declared ``async``. An implementation following
    # that signature would hand StateManager an un-awaited coroutine (always
    # truthy) instead of state, so the interface now matches its callers.

    @abstractmethod
    def get_state(self, session_id: str) -> Optional["StateData"]:
        """Return the stored state for ``session_id``, or None if nothing is stored."""

    @abstractmethod
    def update_state(self, session_id: str, state_data: "StateData") -> None:
        """Store ``state_data`` as the current state for ``session_id``."""

    @abstractmethod
    def delete_state(self, session_id: str) -> None:
        """Remove any stored state for ``session_id``."""
@@ -0,0 +1,21 @@
1
+ # Library imports
2
+ from typing import Dict, Any, Optional, List
3
+ import uuid
4
+
5
+ # Local imports
6
+ from smarta2a.state_stores.base_state_store import BaseStateStore
7
+ from smarta2a.utils.types import StateData, Message
8
+
9
class InMemoryStateStore(BaseStateStore):
    """State store kept in a plain dict on the instance, keyed by session id.

    Stored objects are kept by reference, not copied.
    """

    def __init__(self):
        # session id -> most recently stored StateData
        self.states: Dict[str, StateData] = {}

    def get_state(self, session_id: str) -> Optional[StateData]:
        """Return the state stored for ``session_id``, or None when absent."""
        return self.states.get(session_id, None)

    def update_state(self, session_id: str, state_data: StateData):
        """Store ``state_data`` under ``session_id``, replacing any previous entry."""
        self.states.update({session_id: state_data})

    def delete_state(self, session_id: str):
        """Drop the entry for ``session_id``. Unknown ids are ignored."""
        self.states.pop(session_id, None)
@@ -0,0 +1,32 @@
1
+ from .types import *
2
+
3
# Names re-exported from ``.types`` for ``from smarta2a.utils import *``.
# DataPart is listed alongside its sibling part types TextPart and FilePart.
__all__ = [
    "TaskSendParams",
    "SendTaskRequest",
    "GetTaskRequest",
    "CancelTaskRequest",
    "CancelTaskResponse",
    "Task",
    "TaskStatus",
    "TaskState",
    "Artifact",
    "TextPart",
    "FilePart",
    "DataPart",
    "FileContent",
    "A2AResponse",
    "A2ARequest",
    "TaskQueryParams",
    "TaskStatusUpdateEvent",
    "TaskArtifactUpdateEvent",
    "A2AStatus",
    "A2AStreamResponse",
    "SendTaskResponse",
    "Message",
    "InternalError",
    "TaskNotFoundError",
    "SetTaskPushNotificationRequest",
    "GetTaskPushNotificationRequest",
    "SetTaskPushNotificationResponse",
    "GetTaskPushNotificationResponse",
    "TaskPushNotificationConfig"
]
@@ -0,0 +1,38 @@
1
+ # Library imports
2
+ from typing import Optional, List
3
+
4
+ # Local imports
5
+ from smarta2a.client.tools_manager import ToolsManager
6
+ from smarta2a.utils.types import AgentCard
7
+
8
def build_system_prompt(
    base_prompt: Optional[str],
    tools_manager: ToolsManager,
    mcp_server_urls_or_paths: Optional[List[str]] = None,
    agent_cards: Optional[List[AgentCard]] = None
) -> str:
    """Assemble the system prompt given to the model.

    Starts from ``base_prompt``, or a generic default when it is empty/None.
    Appends an MCP tool section when MCP servers are configured, and an
    agent-delegation section (agent listing plus A2A tool descriptions) when
    agent cards are supplied.
    """
    sections = [
        base_prompt if base_prompt else "You are a helpful assistant with access to the following tools:"
    ]

    if mcp_server_urls_or_paths:
        mcp_description = tools_manager.describe_tools("mcp")
        sections.append(f"\n\nAvailable tools:\n{mcp_description}")

    if agent_cards:
        delegation_tools = tools_manager.describe_tools("a2a")
        agent_listing = _print_agent_list(agent_cards)
        sections.append(
            f"\n\nIf needed, you can delegate parts of your task to other agents. The Agents you can use are:\n{agent_listing}\n\nUse the following tools to send tasks to an agent:\n{delegation_tools}"
        )

    return "".join(sections)
29
+
30
+
31
+ def _print_agent_list(agents: List[AgentCard]) -> None:
32
+ """Prints multiple agents with separators"""
33
+ separator = "---"
34
+ agent_strings = [agent.pretty_print(include_separators=False) for agent in agents]
35
+ full_output = [separator]
36
+ full_output.extend(agent_strings)
37
+ full_output.append(separator)
38
+ print("\n".join(full_output))
@@ -0,0 +1,153 @@
1
+ # Library imports
2
+ from uuid import uuid4
3
+ from datetime import datetime
4
+ from typing import Any, List, Optional, Dict, Union
5
+ from pydantic import ValidationError
6
+
7
+ # Local imports
8
+ from smarta2a.utils.types import (
9
+ Task,
10
+ TaskStatus,
11
+ TaskState,
12
+ Artifact,
13
+ Part,
14
+ TextPart,
15
+ FilePart,
16
+ DataPart,
17
+ Message,
18
+ A2AResponse,
19
+ )
20
+
21
class TaskBuilder:
    """Turn arbitrary handler return values into ``Task`` objects.

    ``build`` accepts a ready ``Task``, an ``A2AResponse``, a dict describing a
    ``Task``, or any other value. Anything else is normalized into artifacts
    by ``_normalize_content``.
    """

    # pydantic TypeAdapter for the ``Part`` type. It is created on first use by
    # ``_part_adapter_for`` and shared by all instances.
    _part_adapter: Any = None

    def __init__(
        self,
        default_status: TaskState = TaskState.COMPLETED,
    ):
        # Status used by the fallback path, where the handler supplied none.
        self.default_status = default_status

    def build(
        self,
        content: Any,
        task_id: str,
        session_id: Optional[str] = None,
        metadata: Optional[Dict[str,Any]] = None,
        history: Optional[List[Message]] = None,
    ) -> Task:
        """Universal task construction from various return types.

        Args:
            content: The handler's return value.
            task_id: Task id used when ``content`` does not carry its own.
            session_id: Session id. Its precedence differs per branch (see comments).
            metadata: Caller metadata. How it merges differs per branch (see comments).
            history: Messages placed before any history carried by ``content``.

        Returns:
            A ``Task``. A ``Task`` input is updated in place and returned.
        """
        history = history or []
        metadata = metadata or {}

        # 1) Full Task: keep its values, filling gaps from the caller.
        if isinstance(content, Task):
            content.sessionId = content.sessionId or session_id
            content.history = history + (content.history or [])
            content.metadata = content.metadata or metadata
            return content

        # 2) A2AResponse: its sessionId wins, and its metadata overrides the caller's.
        if isinstance(content, A2AResponse):
            status = (
                content.status
                if isinstance(content.status, TaskStatus)
                else TaskStatus(state=content.status)
            )
            return Task(
                id=task_id,
                sessionId=content.sessionId or session_id,
                status=status,
                artifacts=self._normalize_content(content.content),
                metadata={**metadata, **(content.metadata or {})},
                history=history,
            )

        # 3) Plain dict describing a Task: caller's sessionId/metadata win.
        if isinstance(content, dict):
            task = self._task_from_dict(content, task_id, session_id, metadata, history)
            if task is not None:
                return task

        # 4) Fallback: treat whatever was returned as artifact content.
        return Task(
            id=task_id,
            sessionId=session_id,
            status=TaskStatus(state=self.default_status),
            artifacts=self._normalize_content(content),
            metadata=metadata,
            history=history,
        )

    def _task_from_dict(
        self,
        content: Dict[str, Any],
        task_id: str,
        session_id: Optional[str],
        metadata: Dict[str, Any],
        history: List[Message],
    ) -> Optional[Task]:
        """Validate ``content`` as a Task, or return None if it is not Task-shaped.

        The builder's values replace the dict's ``sessionId`` / ``metadata`` /
        ``history`` keys instead of being passed next to them.
        ``Task(**content, sessionId=...)`` raised TypeError (duplicate keyword)
        whenever the dict contained one of those keys.
        """
        fields = {
            **content,
            "id": content.get("id", task_id),
            "sessionId": session_id or content.get("sessionId"),
            "metadata": metadata or content.get("metadata", {}),
            # Same ordering as the Task branch: caller history first, then the dict's own.
            "history": history + list(content.get("history") or []),
        }
        try:
            return Task.model_validate(fields)
        except ValidationError:
            return None

    def normalize_from_status(
        self, status: TaskState, task_id: str, metadata: Dict[str,Any]
    ) -> Task:
        """Build a Task for a status-only event, such as a cancellation.

        The task has an empty sessionId, no artifacts and no history. Its
        status timestamp is ``datetime.now()``.
        """
        return Task(
            id=task_id,
            sessionId="",
            status=TaskStatus(state=status, timestamp=datetime.now()),
            artifacts=[],
            metadata=metadata,
            history=[],
        )

    def _normalize_content(self, content: Any) -> List[Artifact]:
        """Turn any handler return value into a list of Artifact.

        - None -> no artifacts.
        - Artifact / list of Artifacts -> used as-is.
        - other lists -> one artifact built by ``_parts_from_mixed``.
        - str -> one TextPart artifact.
        - dict -> validated as an Artifact, otherwise wrapped in a DataPart.
        - TextPart / FilePart / DataPart -> one-part artifact.
        - anything else -> validated as an Artifact, otherwise its ``str()`` as text.
        """
        if content is None:
            # Previously became a text artifact containing "None".
            return []

        if isinstance(content, Artifact):
            return [content]

        if isinstance(content, list) and all(isinstance(a, Artifact) for a in content):
            return content

        if isinstance(content, list):
            return [Artifact(parts=self._parts_from_mixed(content))]

        if isinstance(content, str):
            return [Artifact(parts=[TextPart(text=content)])]

        if isinstance(content, dict):
            try:
                # A raw artifact dict.
                return [Artifact.model_validate(content)]
            except ValidationError:
                # Not artifact-shaped: keep the structured payload as data
                # instead of letting the ValidationError escape.
                return [Artifact(parts=[DataPart(data=content)])]

        # Explicit Part models.
        if isinstance(content, (TextPart, FilePart, DataPart)):
            return [Artifact(parts=[content])]

        # Unknown object: try pydantic validation, then fall back to text.
        try:
            return [Artifact.model_validate(content)]
        except ValidationError:
            return [Artifact(parts=[TextPart(text=str(content))])]

    def _parts_from_mixed(self, items: List[Any]) -> List[Part]:
        """Flatten a mixed list: Artifacts contribute their parts, other items go through ``_create_part``."""
        parts: List[Part] = []
        for item in items:
            if isinstance(item, Artifact):
                parts.extend(item.parts)
            else:
                parts.append(self._create_part(item))
        return parts

    @classmethod
    def _part_adapter_for(cls):
        """Return the shared pydantic TypeAdapter for ``Part``, building it on first use."""
        if cls._part_adapter is None:
            # Local import keeps this module's import block unchanged.
            from pydantic import TypeAdapter
            cls._part_adapter = TypeAdapter(Part)
        return cls._part_adapter

    def _create_part(self, item: Any) -> Part:
        """Coerce one list element into a Part. Unrecognized values become text."""
        if isinstance(item, (TextPart, FilePart, DataPart)):
            return item
        if isinstance(item, str):
            return TextPart(text=item)
        if isinstance(item, dict):
            # ``Part`` is used as a Union of the part models. A typing Union
            # has no ``model_validate``, so the old call raised AttributeError.
            # A TypeAdapter validates whether Part is a Union or a model class.
            try:
                return self._part_adapter_for().validate_python(item)
            except ValidationError:
                return TextPart(text=str(item))
        return TextPart(text=str(item))
@@ -0,0 +1,114 @@
1
+ # Library imports
2
+ from typing import Any, Literal
3
+ from uuid import uuid4
4
+
5
+ # Local imports
6
+ from smarta2a.utils.types import (
7
+ TaskPushNotificationConfig,
8
+ PushNotificationConfig,
9
+ TaskSendParams,
10
+ TextPart,
11
+ DataPart,
12
+ FilePart,
13
+ FileContent,
14
+ Message,
15
+ Part,
16
+ TaskQueryParams,
17
+ TaskIdParams,
18
+ GetTaskRequest,
19
+ CancelTaskRequest,
20
+ SetTaskPushNotificationRequest,
21
+ GetTaskPushNotificationRequest,
22
+ AuthenticationInfo,
23
+ )
24
+
25
class TaskRequestBuilder:
    """Static factory helpers that turn keyword arguments into A2A request models.

    The class is used only as a namespace; every method is a staticmethod.
    """

    @staticmethod
    def build_send_task_request(
        *,
        id: str,
        role: Literal["user", "agent"] = "user",
        text: str | None = None,
        data: dict[str, Any] | None = None,
        file_uri: str | None = None,
        session_id: str | None = None,
        accepted_output_modes: list[str] | None = None,
        push_notification: PushNotificationConfig | None = None,
        history_length: int | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> TaskSendParams:
        """Assemble ``TaskSendParams`` carrying one message.

        The message holds, in this order, a TextPart for ``text``, a DataPart
        for ``data`` and a FilePart (by URI) for ``file_uri``. Each part is
        added only when its argument is not None. A falsy ``session_id`` is
        replaced by ``uuid4().hex``.

        Note: despite the name, this returns the params model rather than a
        ``SendTaskRequest``.
        """
        text_part = TextPart(text=text) if text is not None else None
        data_part = DataPart(data=data) if data is not None else None
        file_part = FilePart(file=FileContent(uri=file_uri)) if file_uri is not None else None
        parts: list[Part] = [part for part in (text_part, data_part, file_part) if part is not None]

        outgoing = Message(role=role, parts=parts)
        return TaskSendParams(
            id=id,
            sessionId=session_id if session_id else uuid4().hex,
            message=outgoing,
            acceptedOutputModes=accepted_output_modes,
            pushNotification=push_notification,
            historyLength=history_length,
            metadata=metadata,
        )

    @staticmethod
    def get_task(
        id: str,
        history_length: int | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> GetTaskRequest:
        """Build a ``GetTaskRequest`` for task ``id``."""
        return GetTaskRequest(
            params=TaskQueryParams(id=id, historyLength=history_length, metadata=metadata)
        )

    @staticmethod
    def cancel_task(
        id: str,
        metadata: dict[str, Any] | None = None,
    ) -> CancelTaskRequest:
        """Build a ``CancelTaskRequest`` for task ``id``."""
        return CancelTaskRequest(params=TaskIdParams(id=id, metadata=metadata))

    @staticmethod
    def set_push_notification(
        id: str,
        url: str,
        token: str | None = None,
        authentication: AuthenticationInfo | dict[str, Any] | None = None,
    ) -> SetTaskPushNotificationRequest:
        """Build a ``SetTaskPushNotificationRequest`` for task ``id``.

        ``authentication`` may be an ``AuthenticationInfo`` or a dict of its
        fields. A falsy value (None or an empty dict) means no authentication.
        """
        if isinstance(authentication, AuthenticationInfo):
            auth = authentication
        elif authentication:
            auth = AuthenticationInfo(**authentication)
        else:
            auth = None

        notification = PushNotificationConfig(url=url, token=token, authentication=auth)
        return SetTaskPushNotificationRequest(
            params=TaskPushNotificationConfig(id=id, pushNotificationConfig=notification)
        )

    @staticmethod
    def get_push_notification(
        id: str,
        metadata: dict[str, Any] | None = None,
    ) -> GetTaskPushNotificationRequest:
        """Build a ``GetTaskPushNotificationRequest`` for task ``id``."""
        return GetTaskPushNotificationRequest(params=TaskIdParams(id=id, metadata=metadata))
+ return GetTaskPushNotificationRequest(params=params)