smarta2a 0.2.0__tar.gz → 0.2.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {smarta2a-0.2.0 → smarta2a-0.2.2}/PKG-INFO +2 -2
- {smarta2a-0.2.0 → smarta2a-0.2.2}/README.md +1 -1
- {smarta2a-0.2.0 → smarta2a-0.2.2}/pyproject.toml +1 -1
- smarta2a-0.2.2/smarta2a/__init__.py +10 -0
- smarta2a-0.2.2/smarta2a/client/__init__.py +0 -0
- smarta2a-0.2.2/smarta2a/client/a2a_client.py +173 -0
- smarta2a-0.2.2/smarta2a/common/__init__.py +32 -0
- smarta2a-0.2.2/smarta2a/common/task_request_builder.py +114 -0
- smarta2a-0.2.2/smarta2a/server/__init__.py +3 -0
- {smarta2a-0.2.0/smarta2a → smarta2a-0.2.2/smarta2a/server}/server.py +118 -4
- {smarta2a-0.2.0 → smarta2a-0.2.2}/tests/test_server.py +150 -4
- smarta2a-0.2.2/tests/test_task_request_builder.py +130 -0
- smarta2a-0.2.0/smarta2a/__init__.py +0 -10
- {smarta2a-0.2.0 → smarta2a-0.2.2}/.gitignore +0 -0
- {smarta2a-0.2.0 → smarta2a-0.2.2}/LICENSE +0 -0
- {smarta2a-0.2.0 → smarta2a-0.2.2}/requirements.txt +0 -0
- {smarta2a-0.2.0/smarta2a → smarta2a-0.2.2/smarta2a/common}/types.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: smarta2a
|
3
|
-
Version: 0.2.0
|
3
|
+
Version: 0.2.2
|
4
4
|
Summary: A Python package for creating servers and clients following Google's Agent2Agent protocol
|
5
5
|
Project-URL: Homepage, https://github.com/siddharthsma/smarta2a
|
6
6
|
Project-URL: Bug Tracker, https://github.com/siddharthsma/smarta2a/issues
|
@@ -45,7 +45,7 @@ pip install smarta2a
|
|
45
45
|
## Simple Echo Server Implementation
|
46
46
|
|
47
47
|
```python
|
48
|
-
from smarta2a import SmartA2A
|
48
|
+
from smarta2a.server import SmartA2A
|
49
49
|
|
50
50
|
app = SmartA2A("EchoServer")
|
51
51
|
|
File without changes
|
@@ -0,0 +1,173 @@
|
|
1
|
+
# Library imports
|
2
|
+
from typing import Any, Literal, AsyncIterable
|
3
|
+
import httpx
|
4
|
+
import json
|
5
|
+
from httpx_sse import connect_sse
|
6
|
+
|
7
|
+
# Local imports
|
8
|
+
from smarta2a.common.types import (
|
9
|
+
PushNotificationConfig,
|
10
|
+
SendTaskStreamingResponse,
|
11
|
+
SendTaskResponse,
|
12
|
+
SendTaskStreamingRequest,
|
13
|
+
SendTaskRequest,
|
14
|
+
JSONRPCRequest,
|
15
|
+
A2AClientJSONError,
|
16
|
+
A2AClientHTTPError,
|
17
|
+
AgentCard,
|
18
|
+
AuthenticationInfo,
|
19
|
+
GetTaskResponse,
|
20
|
+
CancelTaskResponse,
|
21
|
+
SetTaskPushNotificationResponse,
|
22
|
+
GetTaskPushNotificationResponse,
|
23
|
+
)
|
24
|
+
from smarta2a.common.task_request_builder import TaskRequestBuilder
|
25
|
+
|
26
|
+
|
27
|
+
class A2AClient:
    """Client for an A2A (Agent2Agent) JSON-RPC server.

    Each public method builds a request with ``TaskRequestBuilder``, POSTs it
    as JSON to ``self.url`` and parses the reply into the matching
    ``smarta2a.common.types`` response model.
    """

    def __init__(self, agent_card: AgentCard | None = None, url: str | None = None):
        """Target ``agent_card.url`` when an agent card is given, otherwise ``url``.

        Raises:
            ValueError: if neither ``agent_card`` nor ``url`` is supplied.
        """
        if agent_card:
            self.url = agent_card.url
        elif url:
            self.url = url
        else:
            raise ValueError("Must provide either agent_card or url")

    async def send(
        self,
        *,
        id: str,
        role: Literal["user", "agent"] = "user",
        text: str | None = None,
        data: dict[str, Any] | None = None,
        file_uri: str | None = None,
        session_id: str | None = None,
        accepted_output_modes: list[str] | None = None,
        push_notification: PushNotificationConfig | None = None,
        history_length: int | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> SendTaskResponse:
        """Send a ``SendTaskRequest`` and return the parsed ``SendTaskResponse``.

        Message parts come from ``text``, ``data`` and ``file_uri`` (each
        optional); see ``TaskRequestBuilder.build_send_task_request``.

        Raises:
            A2AClientHTTPError: on a non-2xx HTTP status.
            A2AClientJSONError: if the response body is not valid JSON.
        """
        params = TaskRequestBuilder.build_send_task_request(
            id=id,
            role=role,
            text=text,
            data=data,
            file_uri=file_uri,
            session_id=session_id,
            accepted_output_modes=accepted_output_modes,
            push_notification=push_notification,
            history_length=history_length,
            metadata=metadata,
        )
        request = SendTaskRequest(params=params)
        return SendTaskResponse(**await self._send_request(request))

    def subscribe(
        self,
        *,
        id: str,
        role: Literal["user", "agent"] = "user",
        text: str | None = None,
        data: dict[str, Any] | None = None,
        file_uri: str | None = None,
        session_id: str | None = None,
        accepted_output_modes: list[str] | None = None,
        push_notification: PushNotificationConfig | None = None,
        history_length: int | None = None,
        metadata: dict[str, Any] | None = None,
    ):
        """Send a ``SendTaskStreamingRequest`` and yield each server-sent event.

        This is a synchronous generator: it performs blocking I/O with
        ``httpx.Client`` and no timeout. Each event's ``data`` is parsed into a
        ``SendTaskStreamingResponse``.

        Raises:
            A2AClientJSONError: if an event's data is not valid JSON.
            A2AClientHTTPError: on any httpx transport error, including a
                failure to connect.
        """
        params = TaskRequestBuilder.build_send_task_request(
            id=id,
            role=role,
            text=text,
            data=data,
            file_uri=file_uri,
            session_id=session_id,
            accepted_output_modes=accepted_output_modes,
            push_notification=push_notification,
            history_length=history_length,
            metadata=metadata,
        )
        request = SendTaskStreamingRequest(params=params)
        yield from self._iter_sse(request)

    async def get_task(
        self,
        *,
        id: str,
        history_length: int | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> GetTaskResponse:
        """Fetch task ``id`` (optionally limiting returned history)."""
        req = TaskRequestBuilder.get_task(id, history_length, metadata)
        raw = await self._send_request(req)
        return GetTaskResponse(**raw)

    async def cancel_task(
        self,
        *,
        id: str,
        metadata: dict[str, Any] | None = None,
    ) -> CancelTaskResponse:
        """Request cancellation of task ``id``."""
        req = TaskRequestBuilder.cancel_task(id, metadata)
        raw = await self._send_request(req)
        return CancelTaskResponse(**raw)

    async def set_push_notification(
        self,
        *,
        id: str,
        url: str,
        token: str | None = None,
        authentication: AuthenticationInfo | dict[str, Any] | None = None,
    ) -> SetTaskPushNotificationResponse:
        """Register push-notification delivery to ``url`` for task ``id``.

        ``authentication`` may be an ``AuthenticationInfo`` or a plain dict of
        its fields.
        """
        req = TaskRequestBuilder.set_push_notification(id, url, token, authentication)
        raw = await self._send_request(req)
        return SetTaskPushNotificationResponse(**raw)

    async def get_push_notification(
        self,
        *,
        id: str,
        metadata: dict[str, Any] | None = None,
    ) -> GetTaskPushNotificationResponse:
        """Fetch the push-notification config registered for task ``id``."""
        req = TaskRequestBuilder.get_push_notification(id, metadata)
        raw = await self._send_request(req)
        return GetTaskPushNotificationResponse(**raw)

    async def _send_request(self, request: JSONRPCRequest) -> dict[str, Any]:
        """POST ``request`` as JSON and return the decoded JSON response body.

        Raises:
            A2AClientHTTPError: on a non-2xx HTTP status.
            A2AClientJSONError: if the body is not valid JSON.
        """
        async with httpx.AsyncClient() as client:
            try:
                # Agent work can be slow, so allow 30 s rather than httpx's
                # shorter default timeout.
                response = await client.post(
                    self.url, json=request.model_dump(), timeout=30
                )
                response.raise_for_status()
                return response.json()
            except httpx.HTTPStatusError as e:
                raise A2AClientHTTPError(e.response.status_code, str(e)) from e
            except json.JSONDecodeError as e:
                raise A2AClientJSONError(str(e)) from e

    def _iter_sse(self, request: JSONRPCRequest):
        """Blocking SSE POST of ``request``; yield a parsed response per event.

        The ``try`` spans the connection setup as well as the iteration, so a
        connect failure raised by ``connect_sse`` is translated too.
        """
        try:
            with httpx.Client(timeout=None) as client:
                with connect_sse(
                    client, "POST", self.url, json=request.model_dump()
                ) as event_source:
                    for sse in event_source.iter_sse():
                        yield SendTaskStreamingResponse(**json.loads(sse.data))
        except json.JSONDecodeError as e:
            raise A2AClientJSONError(str(e)) from e
        except httpx.RequestError as e:
            raise A2AClientHTTPError(400, str(e)) from e

    async def _send_streaming_request(self, request: JSONRPCRequest) -> AsyncIterable[SendTaskStreamingResponse]:
        """Async SSE POST of ``request``; yield a parsed response per event.

        Uses ``httpx.AsyncClient`` so that awaiting events does not block the
        running event loop.

        Raises:
            A2AClientJSONError: if an event's data is not valid JSON.
            A2AClientHTTPError: on any httpx transport error, including a
                failure to connect.
        """
        from httpx_sse import aconnect_sse

        try:
            async with httpx.AsyncClient(timeout=None) as client:
                async with aconnect_sse(
                    client, "POST", self.url, json=request.model_dump()
                ) as event_source:
                    async for sse in event_source.aiter_sse():
                        yield SendTaskStreamingResponse(**json.loads(sse.data))
        except json.JSONDecodeError as e:
            raise A2AClientJSONError(str(e)) from e
        except httpx.RequestError as e:
            raise A2AClientHTTPError(400, str(e)) from e
|
@@ -0,0 +1,32 @@
|
|
1
|
+
from .types import *
|
2
|
+
|
3
|
+
# Public names re-exported (via ``from .types import *`` above) when a caller
# does ``from smarta2a.common import *``.
__all__ = [
    # Task send / query / cancel
    "TaskSendParams",
    "SendTaskRequest",
    "GetTaskRequest",
    "CancelTaskRequest",
    "CancelTaskResponse",
    # Task model and message parts
    "Task",
    "TaskStatus",
    "TaskState",
    "Artifact",
    "TextPart",
    "FilePart",
    "FileContent",
    # Request/response wrappers and streaming events
    "A2AResponse",
    "A2ARequest",
    "TaskQueryParams",
    "TaskStatusUpdateEvent",
    "TaskArtifactUpdateEvent",
    "A2AStatus",
    "A2AStreamResponse",
    "SendTaskResponse",
    "Message",
    # Errors
    "InternalError",
    "TaskNotFoundError",
    # Push notifications
    "SetTaskPushNotificationRequest",
    "GetTaskPushNotificationRequest",
    "SetTaskPushNotificationResponse",
    "GetTaskPushNotificationResponse",
    "TaskPushNotificationConfig",
]
|
@@ -0,0 +1,114 @@
|
|
1
|
+
# Library imports
|
2
|
+
from typing import Any, Literal
|
3
|
+
from uuid import uuid4
|
4
|
+
|
5
|
+
# Local imports
|
6
|
+
from smarta2a.common.types import (
|
7
|
+
TaskPushNotificationConfig,
|
8
|
+
PushNotificationConfig,
|
9
|
+
TaskSendParams,
|
10
|
+
TextPart,
|
11
|
+
DataPart,
|
12
|
+
FilePart,
|
13
|
+
FileContent,
|
14
|
+
Message,
|
15
|
+
Part,
|
16
|
+
TaskQueryParams,
|
17
|
+
TaskIdParams,
|
18
|
+
GetTaskRequest,
|
19
|
+
CancelTaskRequest,
|
20
|
+
SetTaskPushNotificationRequest,
|
21
|
+
GetTaskPushNotificationRequest,
|
22
|
+
AuthenticationInfo,
|
23
|
+
)
|
24
|
+
|
25
|
+
class TaskRequestBuilder:
    """Static factories that assemble A2A request/params models from plain arguments."""

    @staticmethod
    def build_send_task_request(
        *,
        id: str,
        role: Literal["user", "agent"] = "user",
        text: str | None = None,
        data: dict[str, Any] | None = None,
        file_uri: str | None = None,
        session_id: str | None = None,
        accepted_output_modes: list[str] | None = None,
        push_notification: PushNotificationConfig | None = None,
        history_length: int | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> TaskSendParams:
        """Build ``TaskSendParams`` carrying a single ``Message``.

        The message's parts are, in this order: a ``TextPart`` for ``text``, a
        ``DataPart`` for ``data`` and a URI-backed ``FilePart`` for
        ``file_uri``. Each is included only when its argument is not ``None``.
        A falsy ``session_id`` is replaced with a fresh ``uuid4().hex``.
        """
        candidates = (
            TextPart(text=text) if text is not None else None,
            DataPart(data=data) if data is not None else None,
            FilePart(file=FileContent(uri=file_uri)) if file_uri is not None else None,
        )
        parts: list[Part] = [part for part in candidates if part is not None]
        msg = Message(role=role, parts=parts)

        return TaskSendParams(
            id=id,
            sessionId=session_id if session_id else uuid4().hex,
            message=msg,
            acceptedOutputModes=accepted_output_modes,
            pushNotification=push_notification,
            historyLength=history_length,
            metadata=metadata,
        )

    @staticmethod
    def get_task(
        id: str,
        history_length: int | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> GetTaskRequest:
        """Build a ``GetTaskRequest`` for task ``id``."""
        return GetTaskRequest(
            params=TaskQueryParams(
                id=id,
                historyLength=history_length,
                metadata=metadata,
            )
        )

    @staticmethod
    def cancel_task(
        id: str,
        metadata: dict[str, Any] | None = None,
    ) -> CancelTaskRequest:
        """Build a ``CancelTaskRequest`` for task ``id``."""
        return CancelTaskRequest(params=TaskIdParams(id=id, metadata=metadata))

    @staticmethod
    def set_push_notification(
        id: str,
        url: str,
        token: str | None = None,
        authentication: AuthenticationInfo | dict[str, Any] | None = None,
    ) -> SetTaskPushNotificationRequest:
        """Build a ``SetTaskPushNotificationRequest`` targeting ``url``.

        ``authentication`` may be an ``AuthenticationInfo`` (used as-is) or a
        dict of its fields. A falsy value (``None`` or ``{}``) means no
        authentication.
        """
        if isinstance(authentication, AuthenticationInfo):
            auth_info = authentication
        elif authentication:
            auth_info = AuthenticationInfo(**authentication)
        else:
            auth_info = None

        notification = PushNotificationConfig(
            url=url,
            token=token,
            authentication=auth_info,
        )
        return SetTaskPushNotificationRequest(
            params=TaskPushNotificationConfig(
                id=id,
                pushNotificationConfig=notification,
            )
        )

    @staticmethod
    def get_push_notification(
        id: str,
        metadata: dict[str, Any] | None = None,
    ) -> GetTaskPushNotificationRequest:
        """Build a ``GetTaskPushNotificationRequest`` for task ``id``."""
        return GetTaskPushNotificationRequest(
            params=TaskIdParams(id=id, metadata=metadata)
        )
|
@@ -1,6 +1,5 @@
|
|
1
1
|
from typing import Callable, Any, Optional, Dict, Union, List, AsyncGenerator
|
2
2
|
import json
|
3
|
-
import inspect
|
4
3
|
from datetime import datetime
|
5
4
|
from collections import defaultdict
|
6
5
|
from fastapi import FastAPI, Request, HTTPException, APIRouter
|
@@ -11,8 +10,7 @@ import uvicorn
|
|
11
10
|
from fastapi.responses import StreamingResponse
|
12
11
|
from uuid import uuid4
|
13
12
|
|
14
|
-
|
15
|
-
from .types import (
|
13
|
+
from smarta2a.common.types import (
|
16
14
|
JSONRPCResponse,
|
17
15
|
Task,
|
18
16
|
Artifact,
|
@@ -39,7 +37,6 @@ from .types import (
|
|
39
37
|
JSONParseError,
|
40
38
|
InvalidRequestError,
|
41
39
|
MethodNotFoundError,
|
42
|
-
ContentTypeNotSupportedError,
|
43
40
|
InternalError,
|
44
41
|
UnsupportedOperationError,
|
45
42
|
TaskNotFoundError,
|
@@ -48,6 +45,11 @@ from .types import (
|
|
48
45
|
A2AStatus,
|
49
46
|
A2AStreamResponse,
|
50
47
|
TaskSendParams,
|
48
|
+
SetTaskPushNotificationRequest,
|
49
|
+
GetTaskPushNotificationRequest,
|
50
|
+
SetTaskPushNotificationResponse,
|
51
|
+
GetTaskPushNotificationResponse,
|
52
|
+
TaskPushNotificationConfig,
|
51
53
|
)
|
52
54
|
|
53
55
|
class SmartA2A:
|
@@ -129,6 +131,18 @@ class SmartA2A:
|
|
129
131
|
self._register_handler("tasks/cancel", func, "task_cancel", "handler")
|
130
132
|
return func
|
131
133
|
return decorator
|
134
|
+
|
135
|
+
def set_notification(self):
    """Return a decorator registering the ``tasks/pushNotification/set`` handler.

    The decorated function receives the validated
    ``SetTaskPushNotificationRequest``. It may return ``None``, in which case
    the request's params are echoed back as the result. It may instead return
    a complete ``SetTaskPushNotificationResponse``, which is sent as-is. The
    function itself is returned unchanged.
    """
    def decorator(
        func: Callable[
            [SetTaskPushNotificationRequest],
            Optional[SetTaskPushNotificationResponse],
        ],
    ) -> Callable:
        self._register_handler("tasks/pushNotification/set", func, "set_notification", "handler")
        return func
    return decorator
|
140
|
+
|
141
|
+
def get_notification(self):
    """Return a decorator registering the ``tasks/pushNotification/get`` handler.

    The decorated function is called with the validated
    ``GetTaskPushNotificationRequest``. It should return either a
    ``TaskPushNotificationConfig`` or a full
    ``GetTaskPushNotificationResponse``. The function is returned unchanged.
    """
    method_name = "tasks/pushNotification/get"

    def register(func: Callable[[GetTaskPushNotificationRequest], Union[TaskPushNotificationConfig, GetTaskPushNotificationResponse]]):
        self._register_handler(method_name, func, "get_notification", "handler")
        return func

    return register
|
132
146
|
|
133
147
|
async def process_request(self, request_data: dict) -> JSONRPCResponse:
|
134
148
|
try:
|
@@ -141,6 +155,10 @@ class SmartA2A:
|
|
141
155
|
return self._handle_get_task(request_data)
|
142
156
|
elif method == "tasks/cancel":
|
143
157
|
return self._handle_cancel_task(request_data)
|
158
|
+
elif method == "tasks/pushNotification/set":
|
159
|
+
return self._handle_set_notification(request_data)
|
160
|
+
elif method == "tasks/pushNotification/get":
|
161
|
+
return self._handle_get_notification(request_data)
|
144
162
|
else:
|
145
163
|
return self._error_response(
|
146
164
|
request_data.get("id"),
|
@@ -422,6 +440,102 @@ class SmartA2A:
|
|
422
440
|
error=InternalError(data=str(e))
|
423
441
|
)
|
424
442
|
|
443
|
+
def _handle_set_notification(self, request_data: dict) -> SetTaskPushNotificationResponse:
    """Handle a ``tasks/pushNotification/set`` JSON-RPC request.

    Validates ``request_data``, calls the registered handler and normalizes
    its return value into a ``SetTaskPushNotificationResponse``:

    * ``None``: success, echoing ``request.params`` back as the result;
    * a ``SetTaskPushNotificationResponse``: returned unchanged;
    * anything else: validated as a ``TaskPushNotificationConfig`` and
      wrapped as the result (the same rule ``_handle_get_notification`` uses).

    Failures are reported as error responses:

    * malformed request -> ``InvalidRequestError``;
    * no handler registered -> ``MethodNotFoundError``;
    * a pydantic ``ValidationError`` from the handler or its result ->
      ``InvalidParamsError``;
    * a ``JSONRPCError`` raised by the handler is passed through;
    * any other exception -> ``InternalError``.
    """
    try:
        request = SetTaskPushNotificationRequest.model_validate(request_data)
        handler = self.handlers.get("tasks/pushNotification/set")

        if not handler:
            return SetTaskPushNotificationResponse(
                id=request.id,
                error=MethodNotFoundError()
            )

        try:
            # Execute handler (may or may not return something)
            raw_result = handler(request)

            # If handler returns nothing - build success response from request params
            if raw_result is None:
                return SetTaskPushNotificationResponse(
                    id=request.id,
                    result=request.params
                )

            # If handler returns a full response object
            if isinstance(raw_result, SetTaskPushNotificationResponse):
                return raw_result

            # Any other value (e.g. a TaskPushNotificationConfig or a dict)
            # must validate as the result model. Without this step the method
            # would fall through and return None instead of a response.
            config = TaskPushNotificationConfig.model_validate(raw_result)
            return SetTaskPushNotificationResponse(
                id=request.id,
                result=config
            )

        except ValidationError as e:
            return SetTaskPushNotificationResponse(
                id=request.id,
                error=InvalidParamsError(data=e.errors())
            )
        except Exception as e:
            if isinstance(e, JSONRPCError):
                return SetTaskPushNotificationResponse(
                    id=request.id,
                    error=e
                )
            return SetTaskPushNotificationResponse(
                id=request.id,
                error=InternalError(data=str(e))
            )

    except ValidationError as e:
        return SetTaskPushNotificationResponse(
            id=request_data.get("id"),
            error=InvalidRequestError(data=e.errors())
        )
|
486
|
+
|
487
|
+
|
488
|
+
def _handle_get_notification(self, request_data: dict) -> GetTaskPushNotificationResponse:
    """Handle a ``tasks/pushNotification/get`` JSON-RPC request.

    The registered handler may return a complete
    ``GetTaskPushNotificationResponse``, which is passed through untouched.
    Any other return value is validated as a ``TaskPushNotificationConfig``
    and wrapped as the result.

    Failures become error responses:

    * request fails validation          -> ``InvalidRequestError``
    * no handler registered             -> ``MethodNotFoundError``
    * handler output fails validation   -> ``InvalidParamsError``
    * handler raises a ``JSONRPCError`` -> that error
    * handler raises anything else      -> ``InternalError``
    """
    try:
        request = GetTaskPushNotificationRequest.model_validate(request_data)
        handler = self.handlers.get("tasks/pushNotification/get")
        if not handler:
            return GetTaskPushNotificationResponse(id=request.id, error=MethodNotFoundError())

        try:
            outcome = handler(request)
            if isinstance(outcome, GetTaskPushNotificationResponse):
                return outcome
            return GetTaskPushNotificationResponse(
                id=request.id,
                result=TaskPushNotificationConfig.model_validate(outcome),
            )
        except ValidationError as exc:
            return GetTaskPushNotificationResponse(
                id=request.id,
                error=InvalidParamsError(data=exc.errors()),
            )
        except Exception as exc:
            rpc_error = exc if isinstance(exc, JSONRPCError) else InternalError(data=str(exc))
            return GetTaskPushNotificationResponse(id=request.id, error=rpc_error)

    except ValidationError as exc:
        return GetTaskPushNotificationResponse(
            id=request_data.get("id"),
            error=InvalidRequestError(data=exc.errors()),
        )
    except json.JSONDecodeError as exc:
        # request_data is typed as an already-decoded dict, so this branch is
        # defensive.
        return GetTaskPushNotificationResponse(
            id=request_data.get("id"),
            error=JSONParseError(data=str(exc)),
        )
|
537
|
+
|
538
|
+
|
425
539
|
def _normalize_artifacts(self, content: Any) -> List[Artifact]:
|
426
540
|
"""Handle both A2AResponse content and regular returns"""
|
427
541
|
if isinstance(content, Artifact):
|
@@ -2,8 +2,8 @@ import pytest
|
|
2
2
|
import json
|
3
3
|
import requests
|
4
4
|
from fastapi.testclient import TestClient
|
5
|
-
from smarta2a import SmartA2A
|
6
|
-
from smarta2a.types import (
|
5
|
+
from smarta2a.server import SmartA2A
|
6
|
+
from smarta2a.common.types import (
|
7
7
|
TaskSendParams,
|
8
8
|
SendTaskRequest,
|
9
9
|
GetTaskRequest,
|
@@ -18,13 +18,18 @@ from smarta2a.types import (
|
|
18
18
|
FileContent,
|
19
19
|
A2AResponse,
|
20
20
|
A2ARequest,
|
21
|
-
TaskQueryParams,
|
22
21
|
TaskStatusUpdateEvent,
|
23
22
|
TaskArtifactUpdateEvent,
|
24
23
|
A2AStatus,
|
25
24
|
A2AStreamResponse,
|
26
25
|
SendTaskResponse,
|
27
|
-
Message
|
26
|
+
Message,
|
27
|
+
InternalError,
|
28
|
+
SetTaskPushNotificationRequest,
|
29
|
+
GetTaskPushNotificationRequest,
|
30
|
+
SetTaskPushNotificationResponse,
|
31
|
+
GetTaskPushNotificationResponse,
|
32
|
+
TaskPushNotificationConfig
|
28
33
|
)
|
29
34
|
|
30
35
|
@pytest.fixture
|
@@ -500,4 +505,145 @@ def test_send_task_content_access():
|
|
500
505
|
assert request.content == request.params.message.parts
|
501
506
|
|
502
507
|
|
508
|
+
def test_set_notification_success(a2a_server, client):
    """A handler returning None gets the request's params echoed as the result."""
    callback_url = "https://example.com/callback"

    @a2a_server.set_notification()
    def handle_set(req: SetTaskPushNotificationRequest):
        # Returns nothing; only checks the request reached the handler intact.
        assert req.params.id == "test123"

    payload = {
        "jsonrpc": "2.0",
        "id": 1,
        "method": "tasks/pushNotification/set",
        "params": {
            "id": "test123",
            "pushNotificationConfig": {
                "url": callback_url,
                "authentication": {"schemes": ["jwt"]},
            },
        },
    }

    result = client.post("/", json=payload).json()["result"]
    pushed = result["pushNotificationConfig"]

    assert result["id"] == "test123"
    assert pushed["url"] == callback_url
    assert pushed["authentication"]["schemes"] == ["jwt"]
|
535
|
+
|
536
|
+
def test_set_notification_custom_response(a2a_server, client):
    """A handler-built SetTaskPushNotificationResponse is returned unchanged."""
    @a2a_server.set_notification()
    def handle_set(req):
        return SetTaskPushNotificationResponse(
            id=req.id,
            result=TaskPushNotificationConfig(
                id="test123",
                pushNotificationConfig={
                    "url": "custom-url",
                    "token": "secret"
                }
            )
        )

    response = client.post("/", json={
        "jsonrpc": "2.0",
        "id": 2,
        "method": "tasks/pushNotification/set",
        "params": {
            "id": "test123",
            "pushNotificationConfig": {"url": "https://example.com"}
        }
    }).json()

    result = response["result"]
    config = result["pushNotificationConfig"]
    assert result["id"] == "test123"
    # The handler's values, not the request's "https://example.com", are returned.
    assert config["url"] == "custom-url"
    # Exact match: a substring check would also accept e.g. "not-secret".
    assert config["token"] == "secret"
|
563
|
+
|
564
|
+
|
565
|
+
# --- Get Notification Tests ---
|
566
|
+
|
567
|
+
def test_get_notification_success(a2a_server, client):
    """A returned TaskPushNotificationConfig is wrapped as the response result."""
    @a2a_server.get_notification()
    def handle_get(req: GetTaskPushNotificationRequest):
        stored = {"url": "https://test.com", "token": "abc123"}
        return TaskPushNotificationConfig(id=req.params.id, pushNotificationConfig=stored)

    payload = {
        "jsonrpc": "2.0",
        "id": 4,
        "method": "tasks/pushNotification/get",
        "params": {"id": "test456"},
    }
    result = client.post("/", json=payload).json()["result"]

    assert result["id"] == "test456"
    assert result["pushNotificationConfig"]["url"] == "https://test.com"
|
590
|
+
|
591
|
+
def test_get_notification_direct_response(a2a_server, client):
    """A handler-built GetTaskPushNotificationResponse is passed through as-is."""
    @a2a_server.get_notification()
    def handle_get(req):
        return GetTaskPushNotificationResponse(
            id=req.id,
            result=TaskPushNotificationConfig(
                id=req.params.id,
                pushNotificationConfig={
                    "url": "direct-response.example",
                    "authentication": {"schemes": ["basic"]}
                }
            )
        )

    response = client.post("/", json={
        "jsonrpc": "2.0",
        "id": 5,
        "method": "tasks/pushNotification/get",
        "params": {"id": "test789"}
    }).json()

    config = response["result"]["pushNotificationConfig"]
    # Exact comparisons: substring/membership checks would hide extra or
    # altered values.
    assert config["url"] == "direct-response.example"
    assert config["authentication"]["schemes"] == ["basic"]
|
615
|
+
|
616
|
+
def test_get_notification_validation_error(a2a_server, client):
    """A handler result that is not a valid config yields Invalid params (-32602)."""
    @a2a_server.get_notification()
    def handle_get(req):
        return {"invalid": "config"}

    payload = {
        "jsonrpc": "2.0",
        "id": 6,
        "method": "tasks/pushNotification/get",
        "params": {"id": "test999"},
    }
    error = client.post("/", json=payload).json()["error"]

    assert error["code"] == -32602  # Invalid params
|
630
|
+
|
631
|
+
|
632
|
+
def test_get_notification_error_propagation(a2a_server, client):
    """An exception raised by the handler surfaces as an Internal error (-32603)."""
    @a2a_server.get_notification()
    def handle_get(req):
        raise InternalError(message="Storage failure")

    payload = {
        "jsonrpc": "2.0",
        "id": 7,
        "method": "tasks/pushNotification/get",
        "params": {"id": "test-error"},
    }
    error = client.post("/", json=payload).json()["error"]

    assert error["code"] == -32603  # Internal error code
|
646
|
+
|
647
|
+
|
648
|
+
|
503
649
|
|
@@ -0,0 +1,130 @@
|
|
1
|
+
import pytest
|
2
|
+
from uuid import UUID
|
3
|
+
from smarta2a.common.types import (
|
4
|
+
TaskSendParams,
|
5
|
+
Message,
|
6
|
+
TextPart,
|
7
|
+
DataPart,
|
8
|
+
FilePart,
|
9
|
+
GetTaskRequest,
|
10
|
+
TaskQueryParams,
|
11
|
+
CancelTaskRequest,
|
12
|
+
TaskIdParams,
|
13
|
+
SetTaskPushNotificationRequest,
|
14
|
+
PushNotificationConfig,
|
15
|
+
AuthenticationInfo,
|
16
|
+
GetTaskPushNotificationRequest,
|
17
|
+
)
|
18
|
+
from smarta2a.common.task_request_builder import TaskRequestBuilder
|
19
|
+
|
20
|
+
class TestTaskRequestBuilder:
    """Unit tests for the static factories on ``TaskRequestBuilder``."""

    def test_build_send_task_request_with_text(self):
        # A single text part plus explicit role, session and metadata.
        params = TaskRequestBuilder.build_send_task_request(
            id="task123",
            text="Hello world",
            role="agent",
            session_id="session_456",
            metadata={"key": "value"},
        )

        assert isinstance(params, TaskSendParams)
        assert params.id == "task123"
        assert params.sessionId == "session_456"
        assert params.metadata == {"key": "value"}

        msg = params.message
        assert isinstance(msg, Message)
        assert msg.role == "agent"
        assert len(msg.parts) == 1
        (only_part,) = msg.parts
        assert isinstance(only_part, TextPart)
        assert only_part.text == "Hello world"

    def test_build_send_task_request_with_all_parts(self):
        # Text, data and file arguments each contribute one part.
        params = TaskRequestBuilder.build_send_task_request(
            id="task123",
            text="Hello",
            data={"key": "value"},
            file_uri="file:///data.csv",
        )

        built = params.message.parts
        assert len(built) == 3

        def has(predicate):
            return any(predicate(p) for p in built)

        assert has(lambda p: isinstance(p, TextPart))
        assert has(lambda p: isinstance(p, DataPart) and p.data == {"key": "value"})
        assert has(lambda p: isinstance(p, FilePart) and p.file.uri == "file:///data.csv")

    def test_build_send_task_request_default_session_id(self):
        # Without session_id the builder generates a UUID4 hex string.
        params = TaskRequestBuilder.build_send_task_request(id="task123", text="Hi")
        generated = params.sessionId

        assert len(generated) == 32
        try:
            UUID(generated, version=4)
        except ValueError:
            pytest.fail("sessionId is not a valid UUID4 hex string")

    def test_get_task(self):
        built = TaskRequestBuilder.get_task(
            id="task123",
            history_length=5,
            metadata={"key": "value"},
        )

        assert isinstance(built, GetTaskRequest)
        query = built.params
        assert isinstance(query, TaskQueryParams)
        assert query.id == "task123"
        assert query.historyLength == 5
        assert query.metadata == {"key": "value"}

    def test_cancel_task(self):
        built = TaskRequestBuilder.cancel_task(
            id="task123",
            metadata={"key": "value"},
        )

        assert isinstance(built, CancelTaskRequest)
        id_params = built.params
        assert isinstance(id_params, TaskIdParams)
        assert id_params.id == "task123"
        assert id_params.metadata == {"key": "value"}

    def test_set_push_notification_with_authentication_info(self):
        # An AuthenticationInfo instance is passed through ('schemes' is required).
        auth_info = AuthenticationInfo(
            schemes=["https", "bearer"],
            credentials="token123",
        )
        built = TaskRequestBuilder.set_push_notification(
            id="task123",
            url="https://example.com",
            token="auth_token",
            authentication=auth_info,
        )

        auth = built.params.pushNotificationConfig.authentication
        assert isinstance(auth, AuthenticationInfo)
        assert auth.schemes == ["https", "bearer"]
        assert auth.credentials == "token123"

    def test_set_push_notification_with_dict(self):
        # A plain dict is converted into AuthenticationInfo ('schemes' is required).
        auth_dict = {"schemes": ["basic"], "credentials": "user:pass"}
        built = TaskRequestBuilder.set_push_notification(
            id="task123",
            url="https://example.com",
            authentication=auth_dict,
        )

        auth = built.params.pushNotificationConfig.authentication
        assert auth.schemes == ["basic"]
        assert auth.credentials == "user:pass"

    def test_get_push_notification(self):
        built = TaskRequestBuilder.get_push_notification(
            id="task123",
            metadata={"key": "value"},
        )

        assert isinstance(built, GetTaskPushNotificationRequest)
        id_params = built.params
        assert isinstance(id_params, TaskIdParams)
        assert id_params.id == "task123"
        assert id_params.metadata == {"key": "value"}
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|