smarta2a 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smarta2a/__init__.py +4 -4
- smarta2a/agent/a2a_agent.py +38 -0
- smarta2a/agent/a2a_mcp_server.py +37 -0
- smarta2a/archive/mcp_client.py +86 -0
- smarta2a/client/__init__.py +0 -0
- smarta2a/client/a2a_client.py +267 -0
- smarta2a/client/smart_mcp_client.py +60 -0
- smarta2a/client/tools_manager.py +58 -0
- smarta2a/history_update_strategies/__init__.py +8 -0
- smarta2a/history_update_strategies/append_strategy.py +10 -0
- smarta2a/history_update_strategies/history_update_strategy.py +15 -0
- smarta2a/model_providers/__init__.py +5 -0
- smarta2a/model_providers/base_llm_provider.py +15 -0
- smarta2a/model_providers/openai_provider.py +281 -0
- smarta2a/server/__init__.py +3 -0
- smarta2a/server/handler_registry.py +23 -0
- smarta2a/{server.py → server/server.py} +224 -254
- smarta2a/server/state_manager.py +34 -0
- smarta2a/server/subscription_service.py +109 -0
- smarta2a/server/task_service.py +155 -0
- smarta2a/state_stores/__init__.py +8 -0
- smarta2a/state_stores/base_state_store.py +20 -0
- smarta2a/state_stores/inmemory_state_store.py +21 -0
- smarta2a/utils/__init__.py +32 -0
- smarta2a/utils/prompt_helpers.py +38 -0
- smarta2a/utils/task_builder.py +153 -0
- smarta2a/utils/task_request_builder.py +114 -0
- smarta2a/{types.py → utils/types.py} +62 -2
- {smarta2a-0.2.1.dist-info → smarta2a-0.2.3.dist-info}/METADATA +13 -7
- smarta2a-0.2.3.dist-info/RECORD +32 -0
- smarta2a-0.2.1.dist-info/RECORD +0 -7
- {smarta2a-0.2.1.dist-info → smarta2a-0.2.3.dist-info}/WHEEL +0 -0
- {smarta2a-0.2.1.dist-info → smarta2a-0.2.3.dist-info}/licenses/LICENSE +0 -0
@@ -1,6 +1,6 @@
|
|
1
|
+
# Library imports
|
1
2
|
from typing import Callable, Any, Optional, Dict, Union, List, AsyncGenerator
|
2
3
|
import json
|
3
|
-
import inspect
|
4
4
|
from datetime import datetime
|
5
5
|
from collections import defaultdict
|
6
6
|
from fastapi import FastAPI, Request, HTTPException, APIRouter
|
@@ -11,8 +11,15 @@ import uvicorn
|
|
11
11
|
from fastapi.responses import StreamingResponse
|
12
12
|
from uuid import uuid4
|
13
13
|
|
14
|
+
# Local imports
|
15
|
+
from smarta2a.server.handler_registry import HandlerRegistry
|
16
|
+
from smarta2a.server.state_manager import StateManager
|
17
|
+
from smarta2a.state_stores.base_state_store import BaseStateStore
|
18
|
+
from smarta2a.history_update_strategies.history_update_strategy import HistoryUpdateStrategy
|
19
|
+
from smarta2a.history_update_strategies.append_strategy import AppendStrategy
|
20
|
+
from smarta2a.utils.task_builder import TaskBuilder
|
14
21
|
|
15
|
-
from .types import (
|
22
|
+
from smarta2a.utils.types import (
|
16
23
|
JSONRPCResponse,
|
17
24
|
Task,
|
18
25
|
Artifact,
|
@@ -21,6 +28,7 @@ from .types import (
|
|
21
28
|
FileContent,
|
22
29
|
DataPart,
|
23
30
|
Part,
|
31
|
+
Message,
|
24
32
|
TaskStatus,
|
25
33
|
TaskState,
|
26
34
|
JSONRPCError,
|
@@ -39,7 +47,6 @@ from .types import (
|
|
39
47
|
JSONParseError,
|
40
48
|
InvalidRequestError,
|
41
49
|
MethodNotFoundError,
|
42
|
-
ContentTypeNotSupportedError,
|
43
50
|
InternalError,
|
44
51
|
UnsupportedOperationError,
|
45
52
|
TaskNotFoundError,
|
@@ -53,41 +60,75 @@ from .types import (
|
|
53
60
|
SetTaskPushNotificationResponse,
|
54
61
|
GetTaskPushNotificationResponse,
|
55
62
|
TaskPushNotificationConfig,
|
63
|
+
StateData
|
56
64
|
)
|
57
65
|
|
58
66
|
class SmartA2A:
|
59
|
-
def __init__(self, name: str, **fastapi_kwargs):
|
67
|
+
def __init__(self, name: str, state_store: Optional[BaseStateStore] = None, history_strategy: HistoryUpdateStrategy = AppendStrategy(), **fastapi_kwargs):
|
60
68
|
self.name = name
|
61
|
-
self.
|
62
|
-
self.
|
69
|
+
self.registry = HandlerRegistry()
|
70
|
+
self.state_mgr = StateManager(state_store, history_strategy)
|
63
71
|
self.app = FastAPI(title=name, **fastapi_kwargs)
|
64
72
|
self.router = APIRouter()
|
65
|
-
self.
|
73
|
+
self.state_store = state_store
|
74
|
+
self.history_strategy = history_strategy
|
66
75
|
self._setup_routes()
|
67
76
|
self.server_config = {
|
68
77
|
"host": "0.0.0.0",
|
69
78
|
"port": 8000,
|
70
79
|
"reload": False
|
71
80
|
}
|
81
|
+
self.task_builder = TaskBuilder(default_status=TaskState.COMPLETED)
|
72
82
|
|
83
|
+
|
84
|
+
def on_send_task(self):
|
85
|
+
def decorator(func: Callable[[SendTaskRequest, Optional[StateData]], Any]) -> Callable:
|
86
|
+
self.registry.register("tasks/send", func)
|
87
|
+
return func
|
88
|
+
return decorator
|
89
|
+
|
90
|
+
def on_send_subscribe_task(self):
|
91
|
+
def decorator(fn: Callable[[SendTaskStreamingRequest, Optional[StateData]], Any]):
|
92
|
+
self.registry.register("tasks/sendSubscribe", fn, subscription=True)
|
93
|
+
return fn
|
94
|
+
return decorator
|
95
|
+
|
96
|
+
def task_get(self):
|
97
|
+
def decorator(fn: Callable[[GetTaskRequest], Any]):
|
98
|
+
self.registry.register("tasks/get", fn)
|
99
|
+
return fn
|
100
|
+
return decorator
|
101
|
+
|
102
|
+
def task_cancel(self):
|
103
|
+
def decorator(fn: Callable[[CancelTaskRequest], Any]):
|
104
|
+
self.registry.register("tasks/cancel", fn)
|
105
|
+
return fn
|
106
|
+
return decorator
|
107
|
+
|
108
|
+
def set_notification(self):
|
109
|
+
def decorator(fn: Callable[[SetTaskPushNotificationRequest], Any]):
|
110
|
+
self.registry.register("tasks/pushNotification/set", fn)
|
111
|
+
return fn
|
112
|
+
return decorator
|
113
|
+
|
114
|
+
def get_notification(self):
|
115
|
+
def decorator(fn: Callable[[GetTaskPushNotificationRequest], Any]):
|
116
|
+
self.registry.register("tasks/pushNotification/get", fn)
|
117
|
+
return fn
|
118
|
+
return decorator
|
119
|
+
|
73
120
|
|
74
121
|
def _setup_routes(self):
|
75
122
|
@self.app.post("/")
|
76
123
|
async def handle_request(request: Request):
|
77
124
|
try:
|
78
125
|
data = await request.json()
|
79
|
-
|
126
|
+
req = JSONRPCRequest.model_validate(data)
|
127
|
+
#request_obj = JSONRPCRequest(**data)
|
80
128
|
except Exception as e:
|
81
|
-
return JSONRPCResponse(
|
82
|
-
|
83
|
-
|
84
|
-
code=-32700,
|
85
|
-
message="Parse error",
|
86
|
-
data=str(e)
|
87
|
-
)
|
88
|
-
).model_dump()
|
89
|
-
|
90
|
-
response = await self.process_request(request_obj.model_dump())
|
129
|
+
return JSONRPCResponse(id=None, error=JSONRPCError(code=-32700, message="Parse error", data=str(e))).model_dump()
|
130
|
+
|
131
|
+
response = await self.process_request(req)
|
91
132
|
|
92
133
|
# <-- Accept both SSE‐style responses:
|
93
134
|
if isinstance(response, (EventSourceResponse, StreamingResponse)):
|
@@ -95,114 +136,117 @@ class SmartA2A:
|
|
95
136
|
|
96
137
|
# <-- Everything else is a normal pydantic JSONRPCResponse
|
97
138
|
return response.model_dump()
|
98
|
-
|
99
|
-
def _register_handler(self, method: str, func: Callable, handler_name: str, handler_type: str = "handler"):
|
100
|
-
"""Shared registration logic with duplicate checking"""
|
101
|
-
if method in self._registered_decorators:
|
102
|
-
raise RuntimeError(
|
103
|
-
f"@{handler_name} decorator for method '{method}' "
|
104
|
-
f"can only be used once per SmartA2A instance"
|
105
|
-
)
|
106
|
-
|
107
|
-
if handler_type == "handler":
|
108
|
-
self.handlers[method] = func
|
109
|
-
else:
|
110
|
-
self.subscriptions[method] = func
|
111
|
-
|
112
|
-
self._registered_decorators.add(method)
|
113
|
-
|
114
|
-
def on_send_task(self) -> Callable:
|
115
|
-
def decorator(func: Callable[[SendTaskRequest], Any]) -> Callable:
|
116
|
-
self._register_handler("tasks/send", func, "on_send_task", "handler")
|
117
|
-
return func
|
118
|
-
return decorator
|
119
|
-
|
120
|
-
def on_send_subscribe_task(self) -> Callable:
|
121
|
-
def decorator(func: Callable) -> Callable:
|
122
|
-
self._register_handler("tasks/sendSubscribe", func, "on_send_subscribe_task", "subscription")
|
123
|
-
return func
|
124
|
-
return decorator
|
125
|
-
|
126
|
-
def task_get(self):
|
127
|
-
def decorator(func: Callable[[GetTaskRequest], Task]):
|
128
|
-
self._register_handler("tasks/get", func, "task_get", "handler")
|
129
|
-
return func
|
130
|
-
return decorator
|
131
139
|
|
132
|
-
def task_cancel(self):
|
133
|
-
def decorator(func: Callable[[CancelTaskRequest], Task]):
|
134
|
-
self._register_handler("tasks/cancel", func, "task_cancel", "handler")
|
135
|
-
return func
|
136
|
-
return decorator
|
137
|
-
|
138
|
-
def set_notification(self):
|
139
|
-
def decorator(func: Callable[[SetTaskPushNotificationRequest], None]) -> Callable:
|
140
|
-
self._register_handler("tasks/pushNotification/set", func, "set_notification", "handler")
|
141
|
-
return func
|
142
|
-
return decorator
|
143
|
-
|
144
|
-
def get_notification(self):
|
145
|
-
def decorator(func: Callable[[GetTaskPushNotificationRequest], Union[TaskPushNotificationConfig, GetTaskPushNotificationResponse]]):
|
146
|
-
self._register_handler("tasks/pushNotification/get", func, "get_notification", "handler")
|
147
|
-
return func
|
148
|
-
return decorator
|
149
140
|
|
150
|
-
async def process_request(self,
|
141
|
+
async def process_request(self, request: JSONRPCRequest) -> JSONRPCResponse:
|
142
|
+
|
151
143
|
try:
|
152
|
-
method =
|
144
|
+
method = request.method
|
145
|
+
params = request.params
|
146
|
+
state_store = self.state_mgr.get_store()
|
153
147
|
if method == "tasks/send":
|
154
|
-
|
148
|
+
state_data = self.state_mgr.init_or_get(params.get("sessionId"), params.get("message"), params.get("metadata") or {})
|
149
|
+
if state_store:
|
150
|
+
return self._handle_send_task(request, state_data)
|
151
|
+
else:
|
152
|
+
return self._handle_send_task(request)
|
155
153
|
elif method == "tasks/sendSubscribe":
|
156
|
-
|
154
|
+
state_data = self.state_mgr.init_or_get(params.get("sessionId"), params.get("message"), params.get("metadata") or {})
|
155
|
+
if state_store:
|
156
|
+
return await self._handle_subscribe_task(request, state_data)
|
157
|
+
else:
|
158
|
+
return await self._handle_subscribe_task(request)
|
157
159
|
elif method == "tasks/get":
|
158
|
-
return self._handle_get_task(
|
160
|
+
return self._handle_get_task(request)
|
159
161
|
elif method == "tasks/cancel":
|
160
|
-
return self._handle_cancel_task(
|
162
|
+
return self._handle_cancel_task(request)
|
161
163
|
elif method == "tasks/pushNotification/set":
|
162
|
-
return self._handle_set_notification(
|
164
|
+
return self._handle_set_notification(request)
|
163
165
|
elif method == "tasks/pushNotification/get":
|
164
|
-
return self._handle_get_notification(
|
166
|
+
return self._handle_get_notification(request)
|
165
167
|
else:
|
166
|
-
return
|
167
|
-
request_data.get("id"),
|
168
|
-
-32601,
|
169
|
-
"Method not found"
|
170
|
-
)
|
168
|
+
return JSONRPCResponse(id=request.id, error=MethodNotFoundError()).model_dump()
|
171
169
|
except ValidationError as e:
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
e.errors()
|
177
|
-
)
|
170
|
+
return JSONRPCResponse(id=request.id, error=InvalidParamsError(data=e.errors())).model_dump()
|
171
|
+
except HTTPException as e:
|
172
|
+
err = UnsupportedOperationError() if e.status_code == 405 else InternalError(data=str(e))
|
173
|
+
return JSONRPCResponse(id=request.id, error=err).model_dump()
|
178
174
|
|
179
|
-
|
175
|
+
|
176
|
+
def _handle_send_task(self, request_data: JSONRPCRequest, state_data: Optional[StateData] = None) -> SendTaskResponse:
|
180
177
|
try:
|
181
178
|
# Validate request format
|
182
|
-
request = SendTaskRequest.model_validate(request_data)
|
183
|
-
handler = self.
|
179
|
+
request = SendTaskRequest.model_validate(request_data.model_dump())
|
180
|
+
handler = self.registry.get_handler("tasks/send")
|
184
181
|
|
185
182
|
if not handler:
|
186
183
|
return SendTaskResponse(
|
187
184
|
id=request.id,
|
188
185
|
error=MethodNotFoundError()
|
189
186
|
)
|
187
|
+
|
188
|
+
user_message = request.params.message
|
189
|
+
request_metadata = request.params.metadata or {}
|
190
|
+
if state_data:
|
191
|
+
session_id = state_data.sessionId
|
192
|
+
existing_history = state_data.history.copy() or []
|
193
|
+
metadata = state_data.metadata or {} # Request metadata has already been merged so need to do it here
|
194
|
+
else:
|
195
|
+
session_id = request.params.sessionId or str(uuid4())
|
196
|
+
existing_history = [user_message]
|
197
|
+
metadata = request_metadata
|
198
|
+
|
190
199
|
|
191
200
|
try:
|
192
|
-
|
193
|
-
|
201
|
+
|
202
|
+
if state_data:
|
203
|
+
raw_result = handler(request, state_data)
|
204
|
+
else:
|
205
|
+
raw_result = handler(request)
|
206
|
+
|
207
|
+
# Handle direct SendTaskResponse returns
|
194
208
|
if isinstance(raw_result, SendTaskResponse):
|
195
209
|
return raw_result
|
196
210
|
|
197
|
-
#
|
198
|
-
task = self.
|
211
|
+
# Build task with updated history (before agent response)
|
212
|
+
task = self.task_builder.build(
|
199
213
|
content=raw_result,
|
200
214
|
task_id=request.params.id,
|
201
|
-
session_id=
|
202
|
-
|
203
|
-
|
215
|
+
session_id=session_id, # Always use generated session ID
|
216
|
+
metadata=metadata, # Use merged metadata
|
217
|
+
history=existing_history # History
|
204
218
|
)
|
205
219
|
|
220
|
+
# Process messages through strategy
|
221
|
+
messages = []
|
222
|
+
if task.artifacts:
|
223
|
+
agent_parts = [p for a in task.artifacts for p in a.parts]
|
224
|
+
agent_message = Message(
|
225
|
+
role="agent",
|
226
|
+
parts=agent_parts,
|
227
|
+
metadata=task.metadata
|
228
|
+
)
|
229
|
+
messages.append(agent_message)
|
230
|
+
|
231
|
+
final_history = self.history_strategy.update_history(
|
232
|
+
existing_history=existing_history,
|
233
|
+
new_messages=messages
|
234
|
+
)
|
235
|
+
|
236
|
+
# Update task with final state
|
237
|
+
task.history = final_history
|
238
|
+
|
239
|
+
# State store update (if enabled)
|
240
|
+
if self.state_store:
|
241
|
+
self.state_store.update_state(
|
242
|
+
session_id=session_id,
|
243
|
+
state_data=StateData(
|
244
|
+
sessionId=session_id,
|
245
|
+
history=final_history,
|
246
|
+
metadata=metadata # Use merged metadata
|
247
|
+
)
|
248
|
+
)
|
249
|
+
|
206
250
|
return SendTaskResponse(
|
207
251
|
id=request.id,
|
208
252
|
result=task
|
@@ -232,10 +276,11 @@ class SmartA2A:
|
|
232
276
|
)
|
233
277
|
|
234
278
|
|
235
|
-
async def _handle_subscribe_task(self, request_data:
|
279
|
+
async def _handle_subscribe_task(self, request_data: JSONRPCRequest, state_data: Optional[StateData] = None) -> Union[EventSourceResponse, SendTaskStreamingResponse]:
|
236
280
|
try:
|
237
|
-
request = SendTaskStreamingRequest.model_validate(request_data)
|
238
|
-
handler = self.subscriptions.get("tasks/sendSubscribe")
|
281
|
+
request = SendTaskStreamingRequest.model_validate(request_data.model_dump())
|
282
|
+
#handler = self.subscriptions.get("tasks/sendSubscribe")
|
283
|
+
handler = self.registry.get_subscription("tasks/sendSubscribe")
|
239
284
|
|
240
285
|
if not handler:
|
241
286
|
return SendTaskStreamingResponse(
|
@@ -243,15 +288,74 @@ class SmartA2A:
|
|
243
288
|
id=request.id,
|
244
289
|
error=MethodNotFoundError()
|
245
290
|
)
|
291
|
+
|
292
|
+
user_message = request.params.message
|
293
|
+
request_metadata = request.params.metadata or {}
|
294
|
+
if state_data:
|
295
|
+
session_id = state_data.sessionId
|
296
|
+
existing_history = state_data.history.copy() or []
|
297
|
+
metadata = state_data.metadata or {} # Request metadata has already been merged so need to do it here
|
298
|
+
else:
|
299
|
+
session_id = request.params.sessionId or str(uuid4())
|
300
|
+
existing_history = [user_message]
|
301
|
+
metadata = request_metadata
|
302
|
+
|
246
303
|
|
247
304
|
async def event_generator():
|
248
305
|
|
249
306
|
try:
|
250
|
-
|
307
|
+
|
308
|
+
if state_data:
|
309
|
+
raw_events = handler(request, state_data)
|
310
|
+
else:
|
311
|
+
raw_events = handler(request)
|
312
|
+
|
251
313
|
normalized_events = self._normalize_subscription_events(request.params, raw_events)
|
252
314
|
|
315
|
+
# Initialize streaming state
|
316
|
+
stream_history = existing_history.copy()
|
317
|
+
stream_metadata = metadata.copy()
|
318
|
+
|
253
319
|
async for item in normalized_events:
|
254
320
|
try:
|
321
|
+
|
322
|
+
# Process artifact updates
|
323
|
+
if isinstance(item, TaskArtifactUpdateEvent):
|
324
|
+
# Create agent message from artifact parts
|
325
|
+
agent_message = Message(
|
326
|
+
role="agent",
|
327
|
+
parts=[p for p in item.artifact.parts],
|
328
|
+
metadata=item.artifact.metadata
|
329
|
+
)
|
330
|
+
|
331
|
+
# Update history using strategy
|
332
|
+
new_history = self.history_strategy.update_history(
|
333
|
+
existing_history=stream_history,
|
334
|
+
new_messages=[agent_message]
|
335
|
+
)
|
336
|
+
|
337
|
+
# Merge metadata
|
338
|
+
new_metadata = {
|
339
|
+
**stream_metadata,
|
340
|
+
**(item.artifact.metadata or {})
|
341
|
+
}
|
342
|
+
|
343
|
+
# Update state store if configured
|
344
|
+
if self.state_store:
|
345
|
+
self.state_store.update_state(
|
346
|
+
session_id=session_id,
|
347
|
+
state_data=StateData(
|
348
|
+
sessionId=session_id,
|
349
|
+
history=new_history,
|
350
|
+
metadata=new_metadata
|
351
|
+
)
|
352
|
+
)
|
353
|
+
|
354
|
+
# Update streaming state
|
355
|
+
stream_history = new_history
|
356
|
+
stream_metadata = new_metadata
|
357
|
+
|
358
|
+
|
255
359
|
if isinstance(item, SendTaskStreamingResponse):
|
256
360
|
yield item.model_dump_json()
|
257
361
|
continue
|
@@ -321,11 +425,11 @@ class SmartA2A:
|
|
321
425
|
)
|
322
426
|
|
323
427
|
|
324
|
-
def _handle_get_task(self, request_data:
|
428
|
+
def _handle_get_task(self, request_data: JSONRPCRequest) -> GetTaskResponse:
|
325
429
|
try:
|
326
430
|
# Validate request structure
|
327
|
-
request = GetTaskRequest.model_validate(request_data)
|
328
|
-
handler = self.
|
431
|
+
request = GetTaskRequest.model_validate(request_data.model_dump())
|
432
|
+
handler = self.registry.get_handler("tasks/get")
|
329
433
|
|
330
434
|
if not handler:
|
331
435
|
return GetTaskResponse(
|
@@ -340,10 +444,9 @@ class SmartA2A:
|
|
340
444
|
return self._validate_response_id(raw_result, request)
|
341
445
|
|
342
446
|
# Use unified task builder with different defaults
|
343
|
-
task = self.
|
447
|
+
task = self.task_builder.build(
|
344
448
|
content=raw_result,
|
345
449
|
task_id=request.params.id,
|
346
|
-
default_status=TaskState.COMPLETED,
|
347
450
|
metadata=request.params.metadata or {}
|
348
451
|
)
|
349
452
|
|
@@ -373,11 +476,11 @@ class SmartA2A:
|
|
373
476
|
)
|
374
477
|
|
375
478
|
|
376
|
-
def _handle_cancel_task(self, request_data:
|
479
|
+
def _handle_cancel_task(self, request_data: JSONRPCRequest) -> CancelTaskResponse:
|
377
480
|
try:
|
378
481
|
# Validate request structure
|
379
|
-
request = CancelTaskRequest.model_validate(request_data)
|
380
|
-
handler = self.
|
482
|
+
request = CancelTaskRequest.model_validate(request_data.model_dump())
|
483
|
+
handler = self.registry.get_handler("tasks/cancel")
|
381
484
|
|
382
485
|
if not handler:
|
383
486
|
return CancelTaskResponse(
|
@@ -394,14 +497,14 @@ class SmartA2A:
|
|
394
497
|
|
395
498
|
# Handle A2AStatus returns
|
396
499
|
if isinstance(raw_result, A2AStatus):
|
397
|
-
task = self.
|
398
|
-
status=raw_result,
|
500
|
+
task = self.task_builder.normalize_from_status(
|
501
|
+
status=raw_result.status,
|
399
502
|
task_id=request.params.id,
|
400
503
|
metadata=raw_result.metadata or {}
|
401
504
|
)
|
402
505
|
else:
|
403
506
|
# Existing processing for other return types
|
404
|
-
task = self.
|
507
|
+
task = self.task_builder.build(
|
405
508
|
content=raw_result,
|
406
509
|
task_id=request.params.id,
|
407
510
|
metadata=raw_result.metadata or {}
|
@@ -443,10 +546,10 @@ class SmartA2A:
|
|
443
546
|
error=InternalError(data=str(e))
|
444
547
|
)
|
445
548
|
|
446
|
-
def _handle_set_notification(self, request_data:
|
549
|
+
def _handle_set_notification(self, request_data: JSONRPCRequest) -> SetTaskPushNotificationResponse:
|
447
550
|
try:
|
448
|
-
request = SetTaskPushNotificationRequest.model_validate(request_data)
|
449
|
-
handler = self.
|
551
|
+
request = SetTaskPushNotificationRequest.model_validate(request_data.model_dump())
|
552
|
+
handler = self.registry.get_handler("tasks/pushNotification/set")
|
450
553
|
|
451
554
|
if not handler:
|
452
555
|
return SetTaskPushNotificationResponse(
|
@@ -488,10 +591,10 @@ class SmartA2A:
|
|
488
591
|
)
|
489
592
|
|
490
593
|
|
491
|
-
def _handle_get_notification(self, request_data:
|
594
|
+
def _handle_get_notification(self, request_data: JSONRPCRequest) -> GetTaskPushNotificationResponse:
|
492
595
|
try:
|
493
|
-
request = GetTaskPushNotificationRequest.model_validate(request_data)
|
494
|
-
handler = self.
|
596
|
+
request = GetTaskPushNotificationRequest.model_validate(request_data.model_dump())
|
597
|
+
handler = self.registry.get_handler("tasks/pushNotification/get")
|
495
598
|
|
496
599
|
if not handler:
|
497
600
|
return GetTaskPushNotificationResponse(
|
@@ -538,139 +641,6 @@ class SmartA2A:
|
|
538
641
|
error=JSONParseError(data=str(e))
|
539
642
|
)
|
540
643
|
|
541
|
-
|
542
|
-
def _normalize_artifacts(self, content: Any) -> List[Artifact]:
|
543
|
-
"""Handle both A2AResponse content and regular returns"""
|
544
|
-
if isinstance(content, Artifact):
|
545
|
-
return [content]
|
546
|
-
|
547
|
-
if isinstance(content, list):
|
548
|
-
# Handle list of artifacts
|
549
|
-
if all(isinstance(item, Artifact) for item in content):
|
550
|
-
return content
|
551
|
-
|
552
|
-
# Handle mixed parts in list
|
553
|
-
parts = []
|
554
|
-
for item in content:
|
555
|
-
if isinstance(item, Artifact):
|
556
|
-
parts.extend(item.parts)
|
557
|
-
else:
|
558
|
-
parts.append(self._create_part(item))
|
559
|
-
return [Artifact(parts=parts)]
|
560
|
-
|
561
|
-
# Handle single part returns
|
562
|
-
if isinstance(content, (str, Part, dict)):
|
563
|
-
return [Artifact(parts=[self._create_part(content)])]
|
564
|
-
|
565
|
-
# Handle raw artifact dicts
|
566
|
-
try:
|
567
|
-
return [Artifact.model_validate(content)]
|
568
|
-
except ValidationError:
|
569
|
-
return [Artifact(parts=[TextPart(text=str(content))])]
|
570
|
-
|
571
|
-
|
572
|
-
def _build_task(
|
573
|
-
self,
|
574
|
-
content: Any,
|
575
|
-
task_id: str,
|
576
|
-
session_id: Optional[str] = None,
|
577
|
-
default_status: TaskState = TaskState.COMPLETED,
|
578
|
-
metadata: Optional[dict] = None
|
579
|
-
) -> Task:
|
580
|
-
"""Universal task construction from various return types."""
|
581
|
-
if isinstance(content, Task):
|
582
|
-
return content
|
583
|
-
|
584
|
-
# Handle A2AResponse for sendTask case
|
585
|
-
if isinstance(content, A2AResponse):
|
586
|
-
status = content.status if isinstance(content.status, TaskStatus) \
|
587
|
-
else TaskStatus(state=content.status)
|
588
|
-
artifacts = self._normalize_content(content.content)
|
589
|
-
return Task(
|
590
|
-
id=task_id,
|
591
|
-
sessionId=session_id or str(uuid4()), # Generate if missing
|
592
|
-
status=status,
|
593
|
-
artifacts=artifacts,
|
594
|
-
metadata=metadata or {}
|
595
|
-
)
|
596
|
-
|
597
|
-
try: # Attempt direct validation for dicts
|
598
|
-
return Task.model_validate(content)
|
599
|
-
except ValidationError:
|
600
|
-
pass
|
601
|
-
|
602
|
-
# Fallback to content normalization
|
603
|
-
artifacts = self._normalize_content(content)
|
604
|
-
return Task(
|
605
|
-
id=task_id,
|
606
|
-
sessionId=session_id,
|
607
|
-
status=TaskStatus(state=default_status),
|
608
|
-
artifacts=artifacts,
|
609
|
-
metadata=metadata or {}
|
610
|
-
)
|
611
|
-
|
612
|
-
def _build_task_from_status(self, status: A2AStatus, task_id: str, metadata: dict) -> Task:
|
613
|
-
"""Convert A2AStatus to a Task with proper cancellation state."""
|
614
|
-
return Task(
|
615
|
-
id=task_id,
|
616
|
-
status=TaskStatus(
|
617
|
-
state=TaskState(status.status),
|
618
|
-
timestamp=datetime.now()
|
619
|
-
),
|
620
|
-
metadata=metadata,
|
621
|
-
# Include empty/default values for required fields
|
622
|
-
sessionId="",
|
623
|
-
artifacts=[],
|
624
|
-
history=[]
|
625
|
-
)
|
626
|
-
|
627
|
-
|
628
|
-
def _normalize_content(self, content: Any) -> List[Artifact]:
|
629
|
-
"""Handle all content types for both sendTask and getTask cases."""
|
630
|
-
if isinstance(content, Artifact):
|
631
|
-
return [content]
|
632
|
-
|
633
|
-
if isinstance(content, list):
|
634
|
-
if all(isinstance(item, Artifact) for item in content):
|
635
|
-
return content
|
636
|
-
return [Artifact(parts=self._parts_from_mixed(content))]
|
637
|
-
|
638
|
-
if isinstance(content, (str, Part, dict)):
|
639
|
-
return [Artifact(parts=[self._create_part(content)])]
|
640
|
-
|
641
|
-
try: # Handle raw artifact dicts
|
642
|
-
return [Artifact.model_validate(content)]
|
643
|
-
except ValidationError:
|
644
|
-
return [Artifact(parts=[TextPart(text=str(content))])]
|
645
|
-
|
646
|
-
def _parts_from_mixed(self, items: List[Any]) -> List[Part]:
|
647
|
-
"""Extract parts from mixed content lists."""
|
648
|
-
parts = []
|
649
|
-
for item in items:
|
650
|
-
if isinstance(item, Artifact):
|
651
|
-
parts.extend(item.parts)
|
652
|
-
else:
|
653
|
-
parts.append(self._create_part(item))
|
654
|
-
return parts
|
655
|
-
|
656
|
-
|
657
|
-
def _create_part(self, item: Any) -> Part:
|
658
|
-
"""Convert primitive types to proper Part models"""
|
659
|
-
if isinstance(item, (TextPart, FilePart, DataPart)):
|
660
|
-
return item
|
661
|
-
|
662
|
-
if isinstance(item, str):
|
663
|
-
return TextPart(text=item)
|
664
|
-
|
665
|
-
if isinstance(item, dict):
|
666
|
-
try:
|
667
|
-
return Part.model_validate(item)
|
668
|
-
except ValidationError:
|
669
|
-
return TextPart(text=str(item))
|
670
|
-
|
671
|
-
return TextPart(text=str(item))
|
672
|
-
|
673
|
-
|
674
644
|
# Response validation helper
|
675
645
|
def _validate_response_id(self, response: Union[SendTaskResponse, GetTaskResponse], request) -> Union[SendTaskResponse, GetTaskResponse]:
|
676
646
|
if response.result and response.result.id != request.params.id:
|
@@ -681,7 +651,7 @@ class SmartA2A:
|
|
681
651
|
)
|
682
652
|
)
|
683
653
|
return response
|
684
|
-
|
654
|
+
|
685
655
|
# Might refactor this later
|
686
656
|
def _finalize_task_response(self, request: GetTaskRequest, task: Task) -> GetTaskResponse:
|
687
657
|
"""Final validation and processing for getTask responses."""
|
@@ -724,8 +694,8 @@ class SmartA2A:
|
|
724
694
|
id=request.id,
|
725
695
|
result=task
|
726
696
|
)
|
727
|
-
|
728
|
-
|
697
|
+
|
698
|
+
|
729
699
|
async def _normalize_subscription_events(self, params: TaskSendParams, events: AsyncGenerator) -> AsyncGenerator[Union[SendTaskStreamingResponse, TaskStatusUpdateEvent, TaskArtifactUpdateEvent], None]:
|
730
700
|
artifact_state = defaultdict(lambda: {"index": 0, "last_chunk": False})
|
731
701
|
|