smarta2a 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smarta2a/__init__.py +1 -1
- smarta2a/agent/a2a_agent.py +38 -0
- smarta2a/agent/a2a_mcp_server.py +37 -0
- smarta2a/archive/mcp_client.py +86 -0
- smarta2a/client/a2a_client.py +97 -3
- smarta2a/client/smart_mcp_client.py +60 -0
- smarta2a/client/tools_manager.py +58 -0
- smarta2a/examples/__init__.py +0 -0
- smarta2a/examples/echo_server/__init__.py +0 -0
- smarta2a/examples/echo_server/curl.txt +1 -0
- smarta2a/examples/echo_server/main.py +37 -0
- smarta2a/history_update_strategies/__init__.py +8 -0
- smarta2a/history_update_strategies/append_strategy.py +10 -0
- smarta2a/history_update_strategies/history_update_strategy.py +15 -0
- smarta2a/model_providers/__init__.py +5 -0
- smarta2a/model_providers/base_llm_provider.py +15 -0
- smarta2a/model_providers/openai_provider.py +281 -0
- smarta2a/server/handler_registry.py +23 -0
- smarta2a/server/server.py +233 -255
- smarta2a/server/state_manager.py +34 -0
- smarta2a/server/subscription_service.py +109 -0
- smarta2a/server/task_service.py +155 -0
- smarta2a/state_stores/__init__.py +8 -0
- smarta2a/state_stores/base_state_store.py +20 -0
- smarta2a/state_stores/inmemory_state_store.py +21 -0
- smarta2a/utils/prompt_helpers.py +38 -0
- smarta2a/utils/task_builder.py +153 -0
- smarta2a/{common → utils}/task_request_builder.py +1 -1
- smarta2a/{common → utils}/types.py +62 -2
- {smarta2a-0.2.2.dist-info → smarta2a-0.2.4.dist-info}/METADATA +12 -6
- smarta2a-0.2.4.dist-info/RECORD +36 -0
- smarta2a-0.2.2.dist-info/RECORD +0 -12
- /smarta2a/{common → utils}/__init__.py +0 -0
- {smarta2a-0.2.2.dist-info → smarta2a-0.2.4.dist-info}/WHEEL +0 -0
- {smarta2a-0.2.2.dist-info → smarta2a-0.2.4.dist-info}/licenses/LICENSE +0 -0
smarta2a/server/server.py
CHANGED
@@ -1,3 +1,4 @@
|
|
1
|
+
# Library imports
|
1
2
|
from typing import Callable, Any, Optional, Dict, Union, List, AsyncGenerator
|
2
3
|
import json
|
3
4
|
from datetime import datetime
|
@@ -10,7 +11,15 @@ import uvicorn
|
|
10
11
|
from fastapi.responses import StreamingResponse
|
11
12
|
from uuid import uuid4
|
12
13
|
|
13
|
-
|
14
|
+
# Local imports
|
15
|
+
from smarta2a.server.handler_registry import HandlerRegistry
|
16
|
+
from smarta2a.server.state_manager import StateManager
|
17
|
+
from smarta2a.state_stores.base_state_store import BaseStateStore
|
18
|
+
from smarta2a.history_update_strategies.history_update_strategy import HistoryUpdateStrategy
|
19
|
+
from smarta2a.history_update_strategies.append_strategy import AppendStrategy
|
20
|
+
from smarta2a.utils.task_builder import TaskBuilder
|
21
|
+
|
22
|
+
from smarta2a.utils.types import (
|
14
23
|
JSONRPCResponse,
|
15
24
|
Task,
|
16
25
|
Artifact,
|
@@ -19,6 +28,7 @@ from smarta2a.common.types import (
|
|
19
28
|
FileContent,
|
20
29
|
DataPart,
|
21
30
|
Part,
|
31
|
+
Message,
|
22
32
|
TaskStatus,
|
23
33
|
TaskState,
|
24
34
|
JSONRPCError,
|
@@ -50,41 +60,78 @@ from smarta2a.common.types import (
|
|
50
60
|
SetTaskPushNotificationResponse,
|
51
61
|
GetTaskPushNotificationResponse,
|
52
62
|
TaskPushNotificationConfig,
|
63
|
+
StateData
|
53
64
|
)
|
54
65
|
|
55
66
|
class SmartA2A:
|
56
|
-
def __init__(self, name: str, **fastapi_kwargs):
|
67
|
+
def __init__(self, name: str, state_store: Optional[BaseStateStore] = None, history_strategy: HistoryUpdateStrategy = AppendStrategy(), **fastapi_kwargs):
|
57
68
|
self.name = name
|
58
|
-
self.
|
59
|
-
self.
|
69
|
+
self.registry = HandlerRegistry()
|
70
|
+
self.state_mgr = StateManager(state_store, history_strategy)
|
60
71
|
self.app = FastAPI(title=name, **fastapi_kwargs)
|
61
72
|
self.router = APIRouter()
|
62
|
-
self.
|
73
|
+
self.state_store = state_store
|
74
|
+
self.history_strategy = history_strategy
|
63
75
|
self._setup_routes()
|
64
76
|
self.server_config = {
|
65
77
|
"host": "0.0.0.0",
|
66
78
|
"port": 8000,
|
67
79
|
"reload": False
|
68
80
|
}
|
81
|
+
self.task_builder = TaskBuilder(default_status=TaskState.COMPLETED)
|
82
|
+
|
83
|
+
# Delegate ASGI calls to the wrapped FastAPI app so the SmartA2A instance itself can be served
|
84
|
+
async def __call__(self, scope, receive, send):
|
85
|
+
return await self.app(scope, receive, send)
|
69
86
|
|
87
|
+
def on_send_task(self):
|
88
|
+
def decorator(func: Callable[[SendTaskRequest, Optional[StateData]], Any]) -> Callable:
|
89
|
+
self.registry.register("tasks/send", func)
|
90
|
+
return func
|
91
|
+
return decorator
|
92
|
+
|
93
|
+
def on_send_subscribe_task(self):
|
94
|
+
def decorator(fn: Callable[[SendTaskStreamingRequest, Optional[StateData]], Any]):
|
95
|
+
self.registry.register("tasks/sendSubscribe", fn, subscription=True)
|
96
|
+
return fn
|
97
|
+
return decorator
|
98
|
+
|
99
|
+
def task_get(self):
|
100
|
+
def decorator(fn: Callable[[GetTaskRequest], Any]):
|
101
|
+
self.registry.register("tasks/get", fn)
|
102
|
+
return fn
|
103
|
+
return decorator
|
104
|
+
|
105
|
+
def task_cancel(self):
|
106
|
+
def decorator(fn: Callable[[CancelTaskRequest], Any]):
|
107
|
+
self.registry.register("tasks/cancel", fn)
|
108
|
+
return fn
|
109
|
+
return decorator
|
110
|
+
|
111
|
+
def set_notification(self):
|
112
|
+
def decorator(fn: Callable[[SetTaskPushNotificationRequest], Any]):
|
113
|
+
self.registry.register("tasks/pushNotification/set", fn)
|
114
|
+
return fn
|
115
|
+
return decorator
|
116
|
+
|
117
|
+
def get_notification(self):
|
118
|
+
def decorator(fn: Callable[[GetTaskPushNotificationRequest], Any]):
|
119
|
+
self.registry.register("tasks/pushNotification/get", fn)
|
120
|
+
return fn
|
121
|
+
return decorator
|
122
|
+
|
70
123
|
|
71
124
|
def _setup_routes(self):
|
72
125
|
@self.app.post("/")
|
73
126
|
async def handle_request(request: Request):
|
74
127
|
try:
|
75
128
|
data = await request.json()
|
76
|
-
|
129
|
+
req = JSONRPCRequest.model_validate(data)
|
130
|
+
#request_obj = JSONRPCRequest(**data)
|
77
131
|
except Exception as e:
|
78
|
-
return JSONRPCResponse(
|
79
|
-
|
80
|
-
|
81
|
-
code=-32700,
|
82
|
-
message="Parse error",
|
83
|
-
data=str(e)
|
84
|
-
)
|
85
|
-
).model_dump()
|
86
|
-
|
87
|
-
response = await self.process_request(request_obj.model_dump())
|
132
|
+
return JSONRPCResponse(id=None, error=JSONRPCError(code=-32700, message="Parse error", data=str(e))).model_dump()
|
133
|
+
|
134
|
+
response = await self.process_request(req)
|
88
135
|
|
89
136
|
# <-- Accept both SSE‐style responses:
|
90
137
|
if isinstance(response, (EventSourceResponse, StreamingResponse)):
|
@@ -92,114 +139,117 @@ class SmartA2A:
|
|
92
139
|
|
93
140
|
# <-- Everything else is a normal pydantic JSONRPCResponse
|
94
141
|
return response.model_dump()
|
95
|
-
|
96
|
-
def _register_handler(self, method: str, func: Callable, handler_name: str, handler_type: str = "handler"):
|
97
|
-
"""Shared registration logic with duplicate checking"""
|
98
|
-
if method in self._registered_decorators:
|
99
|
-
raise RuntimeError(
|
100
|
-
f"@{handler_name} decorator for method '{method}' "
|
101
|
-
f"can only be used once per SmartA2A instance"
|
102
|
-
)
|
103
|
-
|
104
|
-
if handler_type == "handler":
|
105
|
-
self.handlers[method] = func
|
106
|
-
else:
|
107
|
-
self.subscriptions[method] = func
|
108
|
-
|
109
|
-
self._registered_decorators.add(method)
|
110
|
-
|
111
|
-
def on_send_task(self) -> Callable:
|
112
|
-
def decorator(func: Callable[[SendTaskRequest], Any]) -> Callable:
|
113
|
-
self._register_handler("tasks/send", func, "on_send_task", "handler")
|
114
|
-
return func
|
115
|
-
return decorator
|
116
|
-
|
117
|
-
def on_send_subscribe_task(self) -> Callable:
|
118
|
-
def decorator(func: Callable) -> Callable:
|
119
|
-
self._register_handler("tasks/sendSubscribe", func, "on_send_subscribe_task", "subscription")
|
120
|
-
return func
|
121
|
-
return decorator
|
122
|
-
|
123
|
-
def task_get(self):
|
124
|
-
def decorator(func: Callable[[GetTaskRequest], Task]):
|
125
|
-
self._register_handler("tasks/get", func, "task_get", "handler")
|
126
|
-
return func
|
127
|
-
return decorator
|
128
142
|
|
129
|
-
def task_cancel(self):
|
130
|
-
def decorator(func: Callable[[CancelTaskRequest], Task]):
|
131
|
-
self._register_handler("tasks/cancel", func, "task_cancel", "handler")
|
132
|
-
return func
|
133
|
-
return decorator
|
134
|
-
|
135
|
-
def set_notification(self):
|
136
|
-
def decorator(func: Callable[[SetTaskPushNotificationRequest], None]) -> Callable:
|
137
|
-
self._register_handler("tasks/pushNotification/set", func, "set_notification", "handler")
|
138
|
-
return func
|
139
|
-
return decorator
|
140
|
-
|
141
|
-
def get_notification(self):
|
142
|
-
def decorator(func: Callable[[GetTaskPushNotificationRequest], Union[TaskPushNotificationConfig, GetTaskPushNotificationResponse]]):
|
143
|
-
self._register_handler("tasks/pushNotification/get", func, "get_notification", "handler")
|
144
|
-
return func
|
145
|
-
return decorator
|
146
143
|
|
147
|
-
async def process_request(self,
|
144
|
+
async def process_request(self, request: JSONRPCRequest) -> JSONRPCResponse:
|
145
|
+
|
148
146
|
try:
|
149
|
-
method =
|
147
|
+
method = request.method
|
148
|
+
params = request.params
|
149
|
+
state_store = self.state_mgr.get_store()
|
150
150
|
if method == "tasks/send":
|
151
|
-
|
151
|
+
state_data = self.state_mgr.init_or_get(params.get("sessionId"), params.get("message"), params.get("metadata") or {})
|
152
|
+
if state_store:
|
153
|
+
return self._handle_send_task(request, state_data)
|
154
|
+
else:
|
155
|
+
return self._handle_send_task(request)
|
152
156
|
elif method == "tasks/sendSubscribe":
|
153
|
-
|
157
|
+
state_data = self.state_mgr.init_or_get(params.get("sessionId"), params.get("message"), params.get("metadata") or {})
|
158
|
+
if state_store:
|
159
|
+
return await self._handle_subscribe_task(request, state_data)
|
160
|
+
else:
|
161
|
+
return await self._handle_subscribe_task(request)
|
154
162
|
elif method == "tasks/get":
|
155
|
-
return self._handle_get_task(
|
163
|
+
return self._handle_get_task(request)
|
156
164
|
elif method == "tasks/cancel":
|
157
|
-
return self._handle_cancel_task(
|
165
|
+
return self._handle_cancel_task(request)
|
158
166
|
elif method == "tasks/pushNotification/set":
|
159
|
-
return self._handle_set_notification(
|
167
|
+
return self._handle_set_notification(request)
|
160
168
|
elif method == "tasks/pushNotification/get":
|
161
|
-
return self._handle_get_notification(
|
169
|
+
return self._handle_get_notification(request)
|
162
170
|
else:
|
163
|
-
return
|
164
|
-
request_data.get("id"),
|
165
|
-
-32601,
|
166
|
-
"Method not found"
|
167
|
-
)
|
171
|
+
return JSONRPCResponse(id=request.id, error=MethodNotFoundError()).model_dump()
|
168
172
|
except ValidationError as e:
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
)
|
173
|
+
return JSONRPCResponse(id=request.id, error=InvalidParamsError(data=e.errors())).model_dump()
|
174
|
+
except HTTPException as e:
|
175
|
+
err = UnsupportedOperationError() if e.status_code == 405 else InternalError(data=str(e))
|
176
|
+
return JSONRPCResponse(id=request.id, error=err).model_dump()
|
177
|
+
|
175
178
|
|
176
|
-
def _handle_send_task(self, request_data:
|
179
|
+
def _handle_send_task(self, request_data: JSONRPCRequest, state_data: Optional[StateData] = None) -> SendTaskResponse:
|
177
180
|
try:
|
178
181
|
# Validate request format
|
179
|
-
request = SendTaskRequest.model_validate(request_data)
|
180
|
-
handler = self.
|
182
|
+
request = SendTaskRequest.model_validate(request_data.model_dump())
|
183
|
+
handler = self.registry.get_handler("tasks/send")
|
181
184
|
|
182
185
|
if not handler:
|
183
186
|
return SendTaskResponse(
|
184
187
|
id=request.id,
|
185
188
|
error=MethodNotFoundError()
|
186
189
|
)
|
190
|
+
|
191
|
+
user_message = request.params.message
|
192
|
+
request_metadata = request.params.metadata or {}
|
193
|
+
if state_data:
|
194
|
+
session_id = state_data.sessionId
|
195
|
+
existing_history = state_data.history.copy() or []
|
196
|
+
metadata = state_data.metadata or {} # Request metadata has already been merged, so no need to do it here
|
197
|
+
else:
|
198
|
+
session_id = request.params.sessionId or str(uuid4())
|
199
|
+
existing_history = [user_message]
|
200
|
+
metadata = request_metadata
|
201
|
+
|
187
202
|
|
188
203
|
try:
|
189
|
-
|
190
|
-
|
204
|
+
|
205
|
+
if state_data:
|
206
|
+
raw_result = handler(request, state_data)
|
207
|
+
else:
|
208
|
+
raw_result = handler(request)
|
209
|
+
|
210
|
+
# Handle direct SendTaskResponse returns
|
191
211
|
if isinstance(raw_result, SendTaskResponse):
|
192
212
|
return raw_result
|
193
213
|
|
194
|
-
#
|
195
|
-
task = self.
|
214
|
+
# Build task with updated history (before agent response)
|
215
|
+
task = self.task_builder.build(
|
196
216
|
content=raw_result,
|
197
217
|
task_id=request.params.id,
|
198
|
-
session_id=
|
199
|
-
|
200
|
-
|
218
|
+
session_id=session_id, # Always use generated session ID
|
219
|
+
metadata=metadata, # Use merged metadata
|
220
|
+
history=existing_history # History
|
201
221
|
)
|
202
222
|
|
223
|
+
# Process messages through strategy
|
224
|
+
messages = []
|
225
|
+
if task.artifacts:
|
226
|
+
agent_parts = [p for a in task.artifacts for p in a.parts]
|
227
|
+
agent_message = Message(
|
228
|
+
role="agent",
|
229
|
+
parts=agent_parts,
|
230
|
+
metadata=task.metadata
|
231
|
+
)
|
232
|
+
messages.append(agent_message)
|
233
|
+
|
234
|
+
final_history = self.history_strategy.update_history(
|
235
|
+
existing_history=existing_history,
|
236
|
+
new_messages=messages
|
237
|
+
)
|
238
|
+
|
239
|
+
# Update task with final state
|
240
|
+
task.history = final_history
|
241
|
+
|
242
|
+
# State store update (if enabled)
|
243
|
+
if self.state_store:
|
244
|
+
self.state_store.update_state(
|
245
|
+
session_id=session_id,
|
246
|
+
state_data=StateData(
|
247
|
+
sessionId=session_id,
|
248
|
+
history=final_history,
|
249
|
+
metadata=metadata # Use merged metadata
|
250
|
+
)
|
251
|
+
)
|
252
|
+
|
203
253
|
return SendTaskResponse(
|
204
254
|
id=request.id,
|
205
255
|
result=task
|
@@ -229,10 +279,11 @@ class SmartA2A:
|
|
229
279
|
)
|
230
280
|
|
231
281
|
|
232
|
-
async def _handle_subscribe_task(self, request_data:
|
282
|
+
async def _handle_subscribe_task(self, request_data: JSONRPCRequest, state_data: Optional[StateData] = None) -> Union[EventSourceResponse, SendTaskStreamingResponse]:
|
233
283
|
try:
|
234
|
-
request = SendTaskStreamingRequest.model_validate(request_data)
|
235
|
-
handler = self.subscriptions.get("tasks/sendSubscribe")
|
284
|
+
request = SendTaskStreamingRequest.model_validate(request_data.model_dump())
|
285
|
+
#handler = self.subscriptions.get("tasks/sendSubscribe")
|
286
|
+
handler = self.registry.get_subscription("tasks/sendSubscribe")
|
236
287
|
|
237
288
|
if not handler:
|
238
289
|
return SendTaskStreamingResponse(
|
@@ -240,15 +291,74 @@ class SmartA2A:
|
|
240
291
|
id=request.id,
|
241
292
|
error=MethodNotFoundError()
|
242
293
|
)
|
294
|
+
|
295
|
+
user_message = request.params.message
|
296
|
+
request_metadata = request.params.metadata or {}
|
297
|
+
if state_data:
|
298
|
+
session_id = state_data.sessionId
|
299
|
+
existing_history = state_data.history.copy() or []
|
300
|
+
metadata = state_data.metadata or {} # Request metadata has already been merged, so no need to do it here
|
301
|
+
else:
|
302
|
+
session_id = request.params.sessionId or str(uuid4())
|
303
|
+
existing_history = [user_message]
|
304
|
+
metadata = request_metadata
|
305
|
+
|
243
306
|
|
244
307
|
async def event_generator():
|
245
308
|
|
246
309
|
try:
|
247
|
-
|
310
|
+
|
311
|
+
if state_data:
|
312
|
+
raw_events = handler(request, state_data)
|
313
|
+
else:
|
314
|
+
raw_events = handler(request)
|
315
|
+
|
248
316
|
normalized_events = self._normalize_subscription_events(request.params, raw_events)
|
249
317
|
|
318
|
+
# Initialize streaming state
|
319
|
+
stream_history = existing_history.copy()
|
320
|
+
stream_metadata = metadata.copy()
|
321
|
+
|
250
322
|
async for item in normalized_events:
|
251
323
|
try:
|
324
|
+
|
325
|
+
# Process artifact updates
|
326
|
+
if isinstance(item, TaskArtifactUpdateEvent):
|
327
|
+
# Create agent message from artifact parts
|
328
|
+
agent_message = Message(
|
329
|
+
role="agent",
|
330
|
+
parts=[p for p in item.artifact.parts],
|
331
|
+
metadata=item.artifact.metadata
|
332
|
+
)
|
333
|
+
|
334
|
+
# Update history using strategy
|
335
|
+
new_history = self.history_strategy.update_history(
|
336
|
+
existing_history=stream_history,
|
337
|
+
new_messages=[agent_message]
|
338
|
+
)
|
339
|
+
|
340
|
+
# Merge metadata
|
341
|
+
new_metadata = {
|
342
|
+
**stream_metadata,
|
343
|
+
**(item.artifact.metadata or {})
|
344
|
+
}
|
345
|
+
|
346
|
+
# Update state store if configured
|
347
|
+
if self.state_store:
|
348
|
+
self.state_store.update_state(
|
349
|
+
session_id=session_id,
|
350
|
+
state_data=StateData(
|
351
|
+
sessionId=session_id,
|
352
|
+
history=new_history,
|
353
|
+
metadata=new_metadata
|
354
|
+
)
|
355
|
+
)
|
356
|
+
|
357
|
+
# Update streaming state
|
358
|
+
stream_history = new_history
|
359
|
+
stream_metadata = new_metadata
|
360
|
+
|
361
|
+
|
252
362
|
if isinstance(item, SendTaskStreamingResponse):
|
253
363
|
yield item.model_dump_json()
|
254
364
|
continue
|
@@ -318,11 +428,11 @@ class SmartA2A:
|
|
318
428
|
)
|
319
429
|
|
320
430
|
|
321
|
-
def _handle_get_task(self, request_data:
|
431
|
+
def _handle_get_task(self, request_data: JSONRPCRequest) -> GetTaskResponse:
|
322
432
|
try:
|
323
433
|
# Validate request structure
|
324
|
-
request = GetTaskRequest.model_validate(request_data)
|
325
|
-
handler = self.
|
434
|
+
request = GetTaskRequest.model_validate(request_data.model_dump())
|
435
|
+
handler = self.registry.get_handler("tasks/get")
|
326
436
|
|
327
437
|
if not handler:
|
328
438
|
return GetTaskResponse(
|
@@ -337,11 +447,10 @@ class SmartA2A:
|
|
337
447
|
return self._validate_response_id(raw_result, request)
|
338
448
|
|
339
449
|
# Use unified task builder with different defaults
|
340
|
-
task = self.
|
450
|
+
task = self.task_builder.build(
|
341
451
|
content=raw_result,
|
342
452
|
task_id=request.params.id,
|
343
|
-
|
344
|
-
metadata=request.params.metadata or {}
|
453
|
+
metadata=getattr(raw_result, "metadata", {}) or {}
|
345
454
|
)
|
346
455
|
|
347
456
|
return self._finalize_task_response(request, task)
|
@@ -370,11 +479,11 @@ class SmartA2A:
|
|
370
479
|
)
|
371
480
|
|
372
481
|
|
373
|
-
def _handle_cancel_task(self, request_data:
|
482
|
+
def _handle_cancel_task(self, request_data: JSONRPCRequest) -> CancelTaskResponse:
|
374
483
|
try:
|
375
484
|
# Validate request structure
|
376
|
-
request = CancelTaskRequest.model_validate(request_data)
|
377
|
-
handler = self.
|
485
|
+
request = CancelTaskRequest.model_validate(request_data.model_dump())
|
486
|
+
handler = self.registry.get_handler("tasks/cancel")
|
378
487
|
|
379
488
|
if not handler:
|
380
489
|
return CancelTaskResponse(
|
@@ -385,24 +494,26 @@ class SmartA2A:
|
|
385
494
|
try:
|
386
495
|
raw_result = handler(request)
|
387
496
|
|
497
|
+
cancel_task_builder = TaskBuilder(default_status=TaskState.CANCELED)
|
388
498
|
# Handle direct CancelTaskResponse returns
|
389
499
|
if isinstance(raw_result, CancelTaskResponse):
|
390
500
|
return self._validate_response_id(raw_result, request)
|
391
501
|
|
392
502
|
# Handle A2AStatus returns
|
393
503
|
if isinstance(raw_result, A2AStatus):
|
394
|
-
task =
|
395
|
-
status=raw_result,
|
504
|
+
task = cancel_task_builder.normalize_from_status(
|
505
|
+
status=raw_result.status,
|
396
506
|
task_id=request.params.id,
|
397
|
-
metadata=raw_result
|
507
|
+
metadata=getattr(raw_result, "metadata", {}) or {}
|
398
508
|
)
|
399
509
|
else:
|
400
510
|
# Existing processing for other return types
|
401
|
-
task =
|
511
|
+
task = cancel_task_builder.build(
|
402
512
|
content=raw_result,
|
403
513
|
task_id=request.params.id,
|
404
|
-
metadata=raw_result
|
514
|
+
metadata=getattr(raw_result, "metadata", {}) or {}
|
405
515
|
)
|
516
|
+
print(task)
|
406
517
|
|
407
518
|
# Final validation and packaging
|
408
519
|
return self._finalize_cancel_response(request, task)
|
@@ -440,10 +551,10 @@ class SmartA2A:
|
|
440
551
|
error=InternalError(data=str(e))
|
441
552
|
)
|
442
553
|
|
443
|
-
def _handle_set_notification(self, request_data:
|
554
|
+
def _handle_set_notification(self, request_data: JSONRPCRequest) -> SetTaskPushNotificationResponse:
|
444
555
|
try:
|
445
|
-
request = SetTaskPushNotificationRequest.model_validate(request_data)
|
446
|
-
handler = self.
|
556
|
+
request = SetTaskPushNotificationRequest.model_validate(request_data.model_dump())
|
557
|
+
handler = self.registry.get_handler("tasks/pushNotification/set")
|
447
558
|
|
448
559
|
if not handler:
|
449
560
|
return SetTaskPushNotificationResponse(
|
@@ -485,10 +596,10 @@ class SmartA2A:
|
|
485
596
|
)
|
486
597
|
|
487
598
|
|
488
|
-
def _handle_get_notification(self, request_data:
|
599
|
+
def _handle_get_notification(self, request_data: JSONRPCRequest) -> GetTaskPushNotificationResponse:
|
489
600
|
try:
|
490
|
-
request = GetTaskPushNotificationRequest.model_validate(request_data)
|
491
|
-
handler = self.
|
601
|
+
request = GetTaskPushNotificationRequest.model_validate(request_data.model_dump())
|
602
|
+
handler = self.registry.get_handler("tasks/pushNotification/get")
|
492
603
|
|
493
604
|
if not handler:
|
494
605
|
return GetTaskPushNotificationResponse(
|
@@ -535,139 +646,6 @@ class SmartA2A:
|
|
535
646
|
error=JSONParseError(data=str(e))
|
536
647
|
)
|
537
648
|
|
538
|
-
|
539
|
-
def _normalize_artifacts(self, content: Any) -> List[Artifact]:
|
540
|
-
"""Handle both A2AResponse content and regular returns"""
|
541
|
-
if isinstance(content, Artifact):
|
542
|
-
return [content]
|
543
|
-
|
544
|
-
if isinstance(content, list):
|
545
|
-
# Handle list of artifacts
|
546
|
-
if all(isinstance(item, Artifact) for item in content):
|
547
|
-
return content
|
548
|
-
|
549
|
-
# Handle mixed parts in list
|
550
|
-
parts = []
|
551
|
-
for item in content:
|
552
|
-
if isinstance(item, Artifact):
|
553
|
-
parts.extend(item.parts)
|
554
|
-
else:
|
555
|
-
parts.append(self._create_part(item))
|
556
|
-
return [Artifact(parts=parts)]
|
557
|
-
|
558
|
-
# Handle single part returns
|
559
|
-
if isinstance(content, (str, Part, dict)):
|
560
|
-
return [Artifact(parts=[self._create_part(content)])]
|
561
|
-
|
562
|
-
# Handle raw artifact dicts
|
563
|
-
try:
|
564
|
-
return [Artifact.model_validate(content)]
|
565
|
-
except ValidationError:
|
566
|
-
return [Artifact(parts=[TextPart(text=str(content))])]
|
567
|
-
|
568
|
-
|
569
|
-
def _build_task(
|
570
|
-
self,
|
571
|
-
content: Any,
|
572
|
-
task_id: str,
|
573
|
-
session_id: Optional[str] = None,
|
574
|
-
default_status: TaskState = TaskState.COMPLETED,
|
575
|
-
metadata: Optional[dict] = None
|
576
|
-
) -> Task:
|
577
|
-
"""Universal task construction from various return types."""
|
578
|
-
if isinstance(content, Task):
|
579
|
-
return content
|
580
|
-
|
581
|
-
# Handle A2AResponse for sendTask case
|
582
|
-
if isinstance(content, A2AResponse):
|
583
|
-
status = content.status if isinstance(content.status, TaskStatus) \
|
584
|
-
else TaskStatus(state=content.status)
|
585
|
-
artifacts = self._normalize_content(content.content)
|
586
|
-
return Task(
|
587
|
-
id=task_id,
|
588
|
-
sessionId=session_id or str(uuid4()), # Generate if missing
|
589
|
-
status=status,
|
590
|
-
artifacts=artifacts,
|
591
|
-
metadata=metadata or {}
|
592
|
-
)
|
593
|
-
|
594
|
-
try: # Attempt direct validation for dicts
|
595
|
-
return Task.model_validate(content)
|
596
|
-
except ValidationError:
|
597
|
-
pass
|
598
|
-
|
599
|
-
# Fallback to content normalization
|
600
|
-
artifacts = self._normalize_content(content)
|
601
|
-
return Task(
|
602
|
-
id=task_id,
|
603
|
-
sessionId=session_id,
|
604
|
-
status=TaskStatus(state=default_status),
|
605
|
-
artifacts=artifacts,
|
606
|
-
metadata=metadata or {}
|
607
|
-
)
|
608
|
-
|
609
|
-
def _build_task_from_status(self, status: A2AStatus, task_id: str, metadata: dict) -> Task:
|
610
|
-
"""Convert A2AStatus to a Task with proper cancellation state."""
|
611
|
-
return Task(
|
612
|
-
id=task_id,
|
613
|
-
status=TaskStatus(
|
614
|
-
state=TaskState(status.status),
|
615
|
-
timestamp=datetime.now()
|
616
|
-
),
|
617
|
-
metadata=metadata,
|
618
|
-
# Include empty/default values for required fields
|
619
|
-
sessionId="",
|
620
|
-
artifacts=[],
|
621
|
-
history=[]
|
622
|
-
)
|
623
|
-
|
624
|
-
|
625
|
-
def _normalize_content(self, content: Any) -> List[Artifact]:
|
626
|
-
"""Handle all content types for both sendTask and getTask cases."""
|
627
|
-
if isinstance(content, Artifact):
|
628
|
-
return [content]
|
629
|
-
|
630
|
-
if isinstance(content, list):
|
631
|
-
if all(isinstance(item, Artifact) for item in content):
|
632
|
-
return content
|
633
|
-
return [Artifact(parts=self._parts_from_mixed(content))]
|
634
|
-
|
635
|
-
if isinstance(content, (str, Part, dict)):
|
636
|
-
return [Artifact(parts=[self._create_part(content)])]
|
637
|
-
|
638
|
-
try: # Handle raw artifact dicts
|
639
|
-
return [Artifact.model_validate(content)]
|
640
|
-
except ValidationError:
|
641
|
-
return [Artifact(parts=[TextPart(text=str(content))])]
|
642
|
-
|
643
|
-
def _parts_from_mixed(self, items: List[Any]) -> List[Part]:
|
644
|
-
"""Extract parts from mixed content lists."""
|
645
|
-
parts = []
|
646
|
-
for item in items:
|
647
|
-
if isinstance(item, Artifact):
|
648
|
-
parts.extend(item.parts)
|
649
|
-
else:
|
650
|
-
parts.append(self._create_part(item))
|
651
|
-
return parts
|
652
|
-
|
653
|
-
|
654
|
-
def _create_part(self, item: Any) -> Part:
|
655
|
-
"""Convert primitive types to proper Part models"""
|
656
|
-
if isinstance(item, (TextPart, FilePart, DataPart)):
|
657
|
-
return item
|
658
|
-
|
659
|
-
if isinstance(item, str):
|
660
|
-
return TextPart(text=item)
|
661
|
-
|
662
|
-
if isinstance(item, dict):
|
663
|
-
try:
|
664
|
-
return Part.model_validate(item)
|
665
|
-
except ValidationError:
|
666
|
-
return TextPart(text=str(item))
|
667
|
-
|
668
|
-
return TextPart(text=str(item))
|
669
|
-
|
670
|
-
|
671
649
|
# Response validation helper
|
672
650
|
def _validate_response_id(self, response: Union[SendTaskResponse, GetTaskResponse], request) -> Union[SendTaskResponse, GetTaskResponse]:
|
673
651
|
if response.result and response.result.id != request.params.id:
|
@@ -678,7 +656,7 @@ class SmartA2A:
|
|
678
656
|
)
|
679
657
|
)
|
680
658
|
return response
|
681
|
-
|
659
|
+
|
682
660
|
# Might refactor this later
|
683
661
|
def _finalize_task_response(self, request: GetTaskRequest, task: Task) -> GetTaskResponse:
|
684
662
|
"""Final validation and processing for getTask responses."""
|
@@ -721,8 +699,8 @@ class SmartA2A:
|
|
721
699
|
id=request.id,
|
722
700
|
result=task
|
723
701
|
)
|
724
|
-
|
725
|
-
|
702
|
+
|
703
|
+
|
726
704
|
async def _normalize_subscription_events(self, params: TaskSendParams, events: AsyncGenerator) -> AsyncGenerator[Union[SendTaskStreamingResponse, TaskStatusUpdateEvent, TaskArtifactUpdateEvent], None]:
|
727
705
|
artifact_state = defaultdict(lambda: {"index": 0, "last_chunk": False})
|
728
706
|
|