smarta2a 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smarta2a/__init__.py +1 -1
- smarta2a/agent/a2a_agent.py +38 -0
- smarta2a/agent/a2a_mcp_server.py +37 -0
- smarta2a/archive/mcp_client.py +86 -0
- smarta2a/client/a2a_client.py +97 -3
- smarta2a/client/smart_mcp_client.py +60 -0
- smarta2a/client/tools_manager.py +58 -0
- smarta2a/history_update_strategies/__init__.py +8 -0
- smarta2a/history_update_strategies/append_strategy.py +10 -0
- smarta2a/history_update_strategies/history_update_strategy.py +15 -0
- smarta2a/model_providers/__init__.py +5 -0
- smarta2a/model_providers/base_llm_provider.py +15 -0
- smarta2a/model_providers/openai_provider.py +281 -0
- smarta2a/server/handler_registry.py +23 -0
- smarta2a/server/server.py +225 -252
- smarta2a/server/state_manager.py +34 -0
- smarta2a/server/subscription_service.py +109 -0
- smarta2a/server/task_service.py +155 -0
- smarta2a/state_stores/__init__.py +8 -0
- smarta2a/state_stores/base_state_store.py +20 -0
- smarta2a/state_stores/inmemory_state_store.py +21 -0
- smarta2a/utils/prompt_helpers.py +38 -0
- smarta2a/utils/task_builder.py +153 -0
- smarta2a/{common → utils}/task_request_builder.py +1 -1
- smarta2a/{common → utils}/types.py +62 -2
- {smarta2a-0.2.2.dist-info → smarta2a-0.2.3.dist-info}/METADATA +12 -6
- smarta2a-0.2.3.dist-info/RECORD +32 -0
- smarta2a-0.2.2.dist-info/RECORD +0 -12
- /smarta2a/{common → utils}/__init__.py +0 -0
- {smarta2a-0.2.2.dist-info → smarta2a-0.2.3.dist-info}/WHEEL +0 -0
- {smarta2a-0.2.2.dist-info → smarta2a-0.2.3.dist-info}/licenses/LICENSE +0 -0
smarta2a/server/server.py
CHANGED
@@ -1,3 +1,4 @@
|
|
1
|
+
# Library imports
|
1
2
|
from typing import Callable, Any, Optional, Dict, Union, List, AsyncGenerator
|
2
3
|
import json
|
3
4
|
from datetime import datetime
|
@@ -10,7 +11,15 @@ import uvicorn
|
|
10
11
|
from fastapi.responses import StreamingResponse
|
11
12
|
from uuid import uuid4
|
12
13
|
|
13
|
-
|
14
|
+
# Local imports
|
15
|
+
from smarta2a.server.handler_registry import HandlerRegistry
|
16
|
+
from smarta2a.server.state_manager import StateManager
|
17
|
+
from smarta2a.state_stores.base_state_store import BaseStateStore
|
18
|
+
from smarta2a.history_update_strategies.history_update_strategy import HistoryUpdateStrategy
|
19
|
+
from smarta2a.history_update_strategies.append_strategy import AppendStrategy
|
20
|
+
from smarta2a.utils.task_builder import TaskBuilder
|
21
|
+
|
22
|
+
from smarta2a.utils.types import (
|
14
23
|
JSONRPCResponse,
|
15
24
|
Task,
|
16
25
|
Artifact,
|
@@ -19,6 +28,7 @@ from smarta2a.common.types import (
|
|
19
28
|
FileContent,
|
20
29
|
DataPart,
|
21
30
|
Part,
|
31
|
+
Message,
|
22
32
|
TaskStatus,
|
23
33
|
TaskState,
|
24
34
|
JSONRPCError,
|
@@ -50,41 +60,75 @@ from smarta2a.common.types import (
|
|
50
60
|
SetTaskPushNotificationResponse,
|
51
61
|
GetTaskPushNotificationResponse,
|
52
62
|
TaskPushNotificationConfig,
|
63
|
+
StateData
|
53
64
|
)
|
54
65
|
|
55
66
|
class SmartA2A:
|
56
|
-
def __init__(self, name: str, **fastapi_kwargs):
|
67
|
+
def __init__(self, name: str, state_store: Optional[BaseStateStore] = None, history_strategy: HistoryUpdateStrategy = AppendStrategy(), **fastapi_kwargs):
|
57
68
|
self.name = name
|
58
|
-
self.
|
59
|
-
self.
|
69
|
+
self.registry = HandlerRegistry()
|
70
|
+
self.state_mgr = StateManager(state_store, history_strategy)
|
60
71
|
self.app = FastAPI(title=name, **fastapi_kwargs)
|
61
72
|
self.router = APIRouter()
|
62
|
-
self.
|
73
|
+
self.state_store = state_store
|
74
|
+
self.history_strategy = history_strategy
|
63
75
|
self._setup_routes()
|
64
76
|
self.server_config = {
|
65
77
|
"host": "0.0.0.0",
|
66
78
|
"port": 8000,
|
67
79
|
"reload": False
|
68
80
|
}
|
81
|
+
self.task_builder = TaskBuilder(default_status=TaskState.COMPLETED)
|
69
82
|
|
83
|
+
|
84
|
+
def on_send_task(self):
|
85
|
+
def decorator(func: Callable[[SendTaskRequest, Optional[StateData]], Any]) -> Callable:
|
86
|
+
self.registry.register("tasks/send", func)
|
87
|
+
return func
|
88
|
+
return decorator
|
89
|
+
|
90
|
+
def on_send_subscribe_task(self):
|
91
|
+
def decorator(fn: Callable[[SendTaskStreamingRequest, Optional[StateData]], Any]):
|
92
|
+
self.registry.register("tasks/sendSubscribe", fn, subscription=True)
|
93
|
+
return fn
|
94
|
+
return decorator
|
95
|
+
|
96
|
+
def task_get(self):
|
97
|
+
def decorator(fn: Callable[[GetTaskRequest], Any]):
|
98
|
+
self.registry.register("tasks/get", fn)
|
99
|
+
return fn
|
100
|
+
return decorator
|
101
|
+
|
102
|
+
def task_cancel(self):
|
103
|
+
def decorator(fn: Callable[[CancelTaskRequest], Any]):
|
104
|
+
self.registry.register("tasks/cancel", fn)
|
105
|
+
return fn
|
106
|
+
return decorator
|
107
|
+
|
108
|
+
def set_notification(self):
|
109
|
+
def decorator(fn: Callable[[SetTaskPushNotificationRequest], Any]):
|
110
|
+
self.registry.register("tasks/pushNotification/set", fn)
|
111
|
+
return fn
|
112
|
+
return decorator
|
113
|
+
|
114
|
+
def get_notification(self):
|
115
|
+
def decorator(fn: Callable[[GetTaskPushNotificationRequest], Any]):
|
116
|
+
self.registry.register("tasks/pushNotification/get", fn)
|
117
|
+
return fn
|
118
|
+
return decorator
|
119
|
+
|
70
120
|
|
71
121
|
def _setup_routes(self):
|
72
122
|
@self.app.post("/")
|
73
123
|
async def handle_request(request: Request):
|
74
124
|
try:
|
75
125
|
data = await request.json()
|
76
|
-
|
126
|
+
req = JSONRPCRequest.model_validate(data)
|
127
|
+
#request_obj = JSONRPCRequest(**data)
|
77
128
|
except Exception as e:
|
78
|
-
return JSONRPCResponse(
|
79
|
-
|
80
|
-
|
81
|
-
code=-32700,
|
82
|
-
message="Parse error",
|
83
|
-
data=str(e)
|
84
|
-
)
|
85
|
-
).model_dump()
|
86
|
-
|
87
|
-
response = await self.process_request(request_obj.model_dump())
|
129
|
+
return JSONRPCResponse(id=None, error=JSONRPCError(code=-32700, message="Parse error", data=str(e))).model_dump()
|
130
|
+
|
131
|
+
response = await self.process_request(req)
|
88
132
|
|
89
133
|
# <-- Accept both SSE‐style responses:
|
90
134
|
if isinstance(response, (EventSourceResponse, StreamingResponse)):
|
@@ -92,114 +136,117 @@ class SmartA2A:
|
|
92
136
|
|
93
137
|
# <-- Everything else is a normal pydantic JSONRPCResponse
|
94
138
|
return response.model_dump()
|
95
|
-
|
96
|
-
def _register_handler(self, method: str, func: Callable, handler_name: str, handler_type: str = "handler"):
|
97
|
-
"""Shared registration logic with duplicate checking"""
|
98
|
-
if method in self._registered_decorators:
|
99
|
-
raise RuntimeError(
|
100
|
-
f"@{handler_name} decorator for method '{method}' "
|
101
|
-
f"can only be used once per SmartA2A instance"
|
102
|
-
)
|
103
|
-
|
104
|
-
if handler_type == "handler":
|
105
|
-
self.handlers[method] = func
|
106
|
-
else:
|
107
|
-
self.subscriptions[method] = func
|
108
|
-
|
109
|
-
self._registered_decorators.add(method)
|
110
|
-
|
111
|
-
def on_send_task(self) -> Callable:
|
112
|
-
def decorator(func: Callable[[SendTaskRequest], Any]) -> Callable:
|
113
|
-
self._register_handler("tasks/send", func, "on_send_task", "handler")
|
114
|
-
return func
|
115
|
-
return decorator
|
116
|
-
|
117
|
-
def on_send_subscribe_task(self) -> Callable:
|
118
|
-
def decorator(func: Callable) -> Callable:
|
119
|
-
self._register_handler("tasks/sendSubscribe", func, "on_send_subscribe_task", "subscription")
|
120
|
-
return func
|
121
|
-
return decorator
|
122
|
-
|
123
|
-
def task_get(self):
|
124
|
-
def decorator(func: Callable[[GetTaskRequest], Task]):
|
125
|
-
self._register_handler("tasks/get", func, "task_get", "handler")
|
126
|
-
return func
|
127
|
-
return decorator
|
128
139
|
|
129
|
-
def task_cancel(self):
|
130
|
-
def decorator(func: Callable[[CancelTaskRequest], Task]):
|
131
|
-
self._register_handler("tasks/cancel", func, "task_cancel", "handler")
|
132
|
-
return func
|
133
|
-
return decorator
|
134
|
-
|
135
|
-
def set_notification(self):
|
136
|
-
def decorator(func: Callable[[SetTaskPushNotificationRequest], None]) -> Callable:
|
137
|
-
self._register_handler("tasks/pushNotification/set", func, "set_notification", "handler")
|
138
|
-
return func
|
139
|
-
return decorator
|
140
|
-
|
141
|
-
def get_notification(self):
|
142
|
-
def decorator(func: Callable[[GetTaskPushNotificationRequest], Union[TaskPushNotificationConfig, GetTaskPushNotificationResponse]]):
|
143
|
-
self._register_handler("tasks/pushNotification/get", func, "get_notification", "handler")
|
144
|
-
return func
|
145
|
-
return decorator
|
146
140
|
|
147
|
-
async def process_request(self,
|
141
|
+
async def process_request(self, request: JSONRPCRequest) -> JSONRPCResponse:
|
142
|
+
|
148
143
|
try:
|
149
|
-
method =
|
144
|
+
method = request.method
|
145
|
+
params = request.params
|
146
|
+
state_store = self.state_mgr.get_store()
|
150
147
|
if method == "tasks/send":
|
151
|
-
|
148
|
+
state_data = self.state_mgr.init_or_get(params.get("sessionId"), params.get("message"), params.get("metadata") or {})
|
149
|
+
if state_store:
|
150
|
+
return self._handle_send_task(request, state_data)
|
151
|
+
else:
|
152
|
+
return self._handle_send_task(request)
|
152
153
|
elif method == "tasks/sendSubscribe":
|
153
|
-
|
154
|
+
state_data = self.state_mgr.init_or_get(params.get("sessionId"), params.get("message"), params.get("metadata") or {})
|
155
|
+
if state_store:
|
156
|
+
return await self._handle_subscribe_task(request, state_data)
|
157
|
+
else:
|
158
|
+
return await self._handle_subscribe_task(request)
|
154
159
|
elif method == "tasks/get":
|
155
|
-
return self._handle_get_task(
|
160
|
+
return self._handle_get_task(request)
|
156
161
|
elif method == "tasks/cancel":
|
157
|
-
return self._handle_cancel_task(
|
162
|
+
return self._handle_cancel_task(request)
|
158
163
|
elif method == "tasks/pushNotification/set":
|
159
|
-
return self._handle_set_notification(
|
164
|
+
return self._handle_set_notification(request)
|
160
165
|
elif method == "tasks/pushNotification/get":
|
161
|
-
return self._handle_get_notification(
|
166
|
+
return self._handle_get_notification(request)
|
162
167
|
else:
|
163
|
-
return
|
164
|
-
request_data.get("id"),
|
165
|
-
-32601,
|
166
|
-
"Method not found"
|
167
|
-
)
|
168
|
+
return JSONRPCResponse(id=request.id, error=MethodNotFoundError()).model_dump()
|
168
169
|
except ValidationError as e:
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
)
|
170
|
+
return JSONRPCResponse(id=request.id, error=InvalidParamsError(data=e.errors())).model_dump()
|
171
|
+
except HTTPException as e:
|
172
|
+
err = UnsupportedOperationError() if e.status_code == 405 else InternalError(data=str(e))
|
173
|
+
return JSONRPCResponse(id=request.id, error=err).model_dump()
|
174
|
+
|
175
175
|
|
176
|
-
def _handle_send_task(self, request_data:
|
176
|
+
def _handle_send_task(self, request_data: JSONRPCRequest, state_data: Optional[StateData] = None) -> SendTaskResponse:
|
177
177
|
try:
|
178
178
|
# Validate request format
|
179
|
-
request = SendTaskRequest.model_validate(request_data)
|
180
|
-
handler = self.
|
179
|
+
request = SendTaskRequest.model_validate(request_data.model_dump())
|
180
|
+
handler = self.registry.get_handler("tasks/send")
|
181
181
|
|
182
182
|
if not handler:
|
183
183
|
return SendTaskResponse(
|
184
184
|
id=request.id,
|
185
185
|
error=MethodNotFoundError()
|
186
186
|
)
|
187
|
+
|
188
|
+
user_message = request.params.message
|
189
|
+
request_metadata = request.params.metadata or {}
|
190
|
+
if state_data:
|
191
|
+
session_id = state_data.sessionId
|
192
|
+
existing_history = state_data.history.copy() or []
|
193
|
+
metadata = state_data.metadata or {} # Request metadata has already been merged so need to do it here
|
194
|
+
else:
|
195
|
+
session_id = request.params.sessionId or str(uuid4())
|
196
|
+
existing_history = [user_message]
|
197
|
+
metadata = request_metadata
|
198
|
+
|
187
199
|
|
188
200
|
try:
|
189
|
-
|
190
|
-
|
201
|
+
|
202
|
+
if state_data:
|
203
|
+
raw_result = handler(request, state_data)
|
204
|
+
else:
|
205
|
+
raw_result = handler(request)
|
206
|
+
|
207
|
+
# Handle direct SendTaskResponse returns
|
191
208
|
if isinstance(raw_result, SendTaskResponse):
|
192
209
|
return raw_result
|
193
210
|
|
194
|
-
#
|
195
|
-
task = self.
|
211
|
+
# Build task with updated history (before agent response)
|
212
|
+
task = self.task_builder.build(
|
196
213
|
content=raw_result,
|
197
214
|
task_id=request.params.id,
|
198
|
-
session_id=
|
199
|
-
|
200
|
-
|
215
|
+
session_id=session_id, # Always use generated session ID
|
216
|
+
metadata=metadata, # Use merged metadata
|
217
|
+
history=existing_history # History
|
201
218
|
)
|
202
219
|
|
220
|
+
# Process messages through strategy
|
221
|
+
messages = []
|
222
|
+
if task.artifacts:
|
223
|
+
agent_parts = [p for a in task.artifacts for p in a.parts]
|
224
|
+
agent_message = Message(
|
225
|
+
role="agent",
|
226
|
+
parts=agent_parts,
|
227
|
+
metadata=task.metadata
|
228
|
+
)
|
229
|
+
messages.append(agent_message)
|
230
|
+
|
231
|
+
final_history = self.history_strategy.update_history(
|
232
|
+
existing_history=existing_history,
|
233
|
+
new_messages=messages
|
234
|
+
)
|
235
|
+
|
236
|
+
# Update task with final state
|
237
|
+
task.history = final_history
|
238
|
+
|
239
|
+
# State store update (if enabled)
|
240
|
+
if self.state_store:
|
241
|
+
self.state_store.update_state(
|
242
|
+
session_id=session_id,
|
243
|
+
state_data=StateData(
|
244
|
+
sessionId=session_id,
|
245
|
+
history=final_history,
|
246
|
+
metadata=metadata # Use merged metadata
|
247
|
+
)
|
248
|
+
)
|
249
|
+
|
203
250
|
return SendTaskResponse(
|
204
251
|
id=request.id,
|
205
252
|
result=task
|
@@ -229,10 +276,11 @@ class SmartA2A:
|
|
229
276
|
)
|
230
277
|
|
231
278
|
|
232
|
-
async def _handle_subscribe_task(self, request_data:
|
279
|
+
async def _handle_subscribe_task(self, request_data: JSONRPCRequest, state_data: Optional[StateData] = None) -> Union[EventSourceResponse, SendTaskStreamingResponse]:
|
233
280
|
try:
|
234
|
-
request = SendTaskStreamingRequest.model_validate(request_data)
|
235
|
-
handler = self.subscriptions.get("tasks/sendSubscribe")
|
281
|
+
request = SendTaskStreamingRequest.model_validate(request_data.model_dump())
|
282
|
+
#handler = self.subscriptions.get("tasks/sendSubscribe")
|
283
|
+
handler = self.registry.get_subscription("tasks/sendSubscribe")
|
236
284
|
|
237
285
|
if not handler:
|
238
286
|
return SendTaskStreamingResponse(
|
@@ -240,15 +288,74 @@ class SmartA2A:
|
|
240
288
|
id=request.id,
|
241
289
|
error=MethodNotFoundError()
|
242
290
|
)
|
291
|
+
|
292
|
+
user_message = request.params.message
|
293
|
+
request_metadata = request.params.metadata or {}
|
294
|
+
if state_data:
|
295
|
+
session_id = state_data.sessionId
|
296
|
+
existing_history = state_data.history.copy() or []
|
297
|
+
metadata = state_data.metadata or {} # Request metadata has already been merged so need to do it here
|
298
|
+
else:
|
299
|
+
session_id = request.params.sessionId or str(uuid4())
|
300
|
+
existing_history = [user_message]
|
301
|
+
metadata = request_metadata
|
302
|
+
|
243
303
|
|
244
304
|
async def event_generator():
|
245
305
|
|
246
306
|
try:
|
247
|
-
|
307
|
+
|
308
|
+
if state_data:
|
309
|
+
raw_events = handler(request, state_data)
|
310
|
+
else:
|
311
|
+
raw_events = handler(request)
|
312
|
+
|
248
313
|
normalized_events = self._normalize_subscription_events(request.params, raw_events)
|
249
314
|
|
315
|
+
# Initialize streaming state
|
316
|
+
stream_history = existing_history.copy()
|
317
|
+
stream_metadata = metadata.copy()
|
318
|
+
|
250
319
|
async for item in normalized_events:
|
251
320
|
try:
|
321
|
+
|
322
|
+
# Process artifact updates
|
323
|
+
if isinstance(item, TaskArtifactUpdateEvent):
|
324
|
+
# Create agent message from artifact parts
|
325
|
+
agent_message = Message(
|
326
|
+
role="agent",
|
327
|
+
parts=[p for p in item.artifact.parts],
|
328
|
+
metadata=item.artifact.metadata
|
329
|
+
)
|
330
|
+
|
331
|
+
# Update history using strategy
|
332
|
+
new_history = self.history_strategy.update_history(
|
333
|
+
existing_history=stream_history,
|
334
|
+
new_messages=[agent_message]
|
335
|
+
)
|
336
|
+
|
337
|
+
# Merge metadata
|
338
|
+
new_metadata = {
|
339
|
+
**stream_metadata,
|
340
|
+
**(item.artifact.metadata or {})
|
341
|
+
}
|
342
|
+
|
343
|
+
# Update state store if configured
|
344
|
+
if self.state_store:
|
345
|
+
self.state_store.update_state(
|
346
|
+
session_id=session_id,
|
347
|
+
state_data=StateData(
|
348
|
+
sessionId=session_id,
|
349
|
+
history=new_history,
|
350
|
+
metadata=new_metadata
|
351
|
+
)
|
352
|
+
)
|
353
|
+
|
354
|
+
# Update streaming state
|
355
|
+
stream_history = new_history
|
356
|
+
stream_metadata = new_metadata
|
357
|
+
|
358
|
+
|
252
359
|
if isinstance(item, SendTaskStreamingResponse):
|
253
360
|
yield item.model_dump_json()
|
254
361
|
continue
|
@@ -318,11 +425,11 @@ class SmartA2A:
|
|
318
425
|
)
|
319
426
|
|
320
427
|
|
321
|
-
def _handle_get_task(self, request_data:
|
428
|
+
def _handle_get_task(self, request_data: JSONRPCRequest) -> GetTaskResponse:
|
322
429
|
try:
|
323
430
|
# Validate request structure
|
324
|
-
request = GetTaskRequest.model_validate(request_data)
|
325
|
-
handler = self.
|
431
|
+
request = GetTaskRequest.model_validate(request_data.model_dump())
|
432
|
+
handler = self.registry.get_handler("tasks/get")
|
326
433
|
|
327
434
|
if not handler:
|
328
435
|
return GetTaskResponse(
|
@@ -337,10 +444,9 @@ class SmartA2A:
|
|
337
444
|
return self._validate_response_id(raw_result, request)
|
338
445
|
|
339
446
|
# Use unified task builder with different defaults
|
340
|
-
task = self.
|
447
|
+
task = self.task_builder.build(
|
341
448
|
content=raw_result,
|
342
449
|
task_id=request.params.id,
|
343
|
-
default_status=TaskState.COMPLETED,
|
344
450
|
metadata=request.params.metadata or {}
|
345
451
|
)
|
346
452
|
|
@@ -370,11 +476,11 @@ class SmartA2A:
|
|
370
476
|
)
|
371
477
|
|
372
478
|
|
373
|
-
def _handle_cancel_task(self, request_data:
|
479
|
+
def _handle_cancel_task(self, request_data: JSONRPCRequest) -> CancelTaskResponse:
|
374
480
|
try:
|
375
481
|
# Validate request structure
|
376
|
-
request = CancelTaskRequest.model_validate(request_data)
|
377
|
-
handler = self.
|
482
|
+
request = CancelTaskRequest.model_validate(request_data.model_dump())
|
483
|
+
handler = self.registry.get_handler("tasks/cancel")
|
378
484
|
|
379
485
|
if not handler:
|
380
486
|
return CancelTaskResponse(
|
@@ -391,14 +497,14 @@ class SmartA2A:
|
|
391
497
|
|
392
498
|
# Handle A2AStatus returns
|
393
499
|
if isinstance(raw_result, A2AStatus):
|
394
|
-
task = self.
|
395
|
-
status=raw_result,
|
500
|
+
task = self.task_builder.normalize_from_status(
|
501
|
+
status=raw_result.status,
|
396
502
|
task_id=request.params.id,
|
397
503
|
metadata=raw_result.metadata or {}
|
398
504
|
)
|
399
505
|
else:
|
400
506
|
# Existing processing for other return types
|
401
|
-
task = self.
|
507
|
+
task = self.task_builder.build(
|
402
508
|
content=raw_result,
|
403
509
|
task_id=request.params.id,
|
404
510
|
metadata=raw_result.metadata or {}
|
@@ -440,10 +546,10 @@ class SmartA2A:
|
|
440
546
|
error=InternalError(data=str(e))
|
441
547
|
)
|
442
548
|
|
443
|
-
def _handle_set_notification(self, request_data:
|
549
|
+
def _handle_set_notification(self, request_data: JSONRPCRequest) -> SetTaskPushNotificationResponse:
|
444
550
|
try:
|
445
|
-
request = SetTaskPushNotificationRequest.model_validate(request_data)
|
446
|
-
handler = self.
|
551
|
+
request = SetTaskPushNotificationRequest.model_validate(request_data.model_dump())
|
552
|
+
handler = self.registry.get_handler("tasks/pushNotification/set")
|
447
553
|
|
448
554
|
if not handler:
|
449
555
|
return SetTaskPushNotificationResponse(
|
@@ -485,10 +591,10 @@ class SmartA2A:
|
|
485
591
|
)
|
486
592
|
|
487
593
|
|
488
|
-
def _handle_get_notification(self, request_data:
|
594
|
+
def _handle_get_notification(self, request_data: JSONRPCRequest) -> GetTaskPushNotificationResponse:
|
489
595
|
try:
|
490
|
-
request = GetTaskPushNotificationRequest.model_validate(request_data)
|
491
|
-
handler = self.
|
596
|
+
request = GetTaskPushNotificationRequest.model_validate(request_data.model_dump())
|
597
|
+
handler = self.registry.get_handler("tasks/pushNotification/get")
|
492
598
|
|
493
599
|
if not handler:
|
494
600
|
return GetTaskPushNotificationResponse(
|
@@ -535,139 +641,6 @@ class SmartA2A:
|
|
535
641
|
error=JSONParseError(data=str(e))
|
536
642
|
)
|
537
643
|
|
538
|
-
|
539
|
-
def _normalize_artifacts(self, content: Any) -> List[Artifact]:
|
540
|
-
"""Handle both A2AResponse content and regular returns"""
|
541
|
-
if isinstance(content, Artifact):
|
542
|
-
return [content]
|
543
|
-
|
544
|
-
if isinstance(content, list):
|
545
|
-
# Handle list of artifacts
|
546
|
-
if all(isinstance(item, Artifact) for item in content):
|
547
|
-
return content
|
548
|
-
|
549
|
-
# Handle mixed parts in list
|
550
|
-
parts = []
|
551
|
-
for item in content:
|
552
|
-
if isinstance(item, Artifact):
|
553
|
-
parts.extend(item.parts)
|
554
|
-
else:
|
555
|
-
parts.append(self._create_part(item))
|
556
|
-
return [Artifact(parts=parts)]
|
557
|
-
|
558
|
-
# Handle single part returns
|
559
|
-
if isinstance(content, (str, Part, dict)):
|
560
|
-
return [Artifact(parts=[self._create_part(content)])]
|
561
|
-
|
562
|
-
# Handle raw artifact dicts
|
563
|
-
try:
|
564
|
-
return [Artifact.model_validate(content)]
|
565
|
-
except ValidationError:
|
566
|
-
return [Artifact(parts=[TextPart(text=str(content))])]
|
567
|
-
|
568
|
-
|
569
|
-
def _build_task(
|
570
|
-
self,
|
571
|
-
content: Any,
|
572
|
-
task_id: str,
|
573
|
-
session_id: Optional[str] = None,
|
574
|
-
default_status: TaskState = TaskState.COMPLETED,
|
575
|
-
metadata: Optional[dict] = None
|
576
|
-
) -> Task:
|
577
|
-
"""Universal task construction from various return types."""
|
578
|
-
if isinstance(content, Task):
|
579
|
-
return content
|
580
|
-
|
581
|
-
# Handle A2AResponse for sendTask case
|
582
|
-
if isinstance(content, A2AResponse):
|
583
|
-
status = content.status if isinstance(content.status, TaskStatus) \
|
584
|
-
else TaskStatus(state=content.status)
|
585
|
-
artifacts = self._normalize_content(content.content)
|
586
|
-
return Task(
|
587
|
-
id=task_id,
|
588
|
-
sessionId=session_id or str(uuid4()), # Generate if missing
|
589
|
-
status=status,
|
590
|
-
artifacts=artifacts,
|
591
|
-
metadata=metadata or {}
|
592
|
-
)
|
593
|
-
|
594
|
-
try: # Attempt direct validation for dicts
|
595
|
-
return Task.model_validate(content)
|
596
|
-
except ValidationError:
|
597
|
-
pass
|
598
|
-
|
599
|
-
# Fallback to content normalization
|
600
|
-
artifacts = self._normalize_content(content)
|
601
|
-
return Task(
|
602
|
-
id=task_id,
|
603
|
-
sessionId=session_id,
|
604
|
-
status=TaskStatus(state=default_status),
|
605
|
-
artifacts=artifacts,
|
606
|
-
metadata=metadata or {}
|
607
|
-
)
|
608
|
-
|
609
|
-
def _build_task_from_status(self, status: A2AStatus, task_id: str, metadata: dict) -> Task:
|
610
|
-
"""Convert A2AStatus to a Task with proper cancellation state."""
|
611
|
-
return Task(
|
612
|
-
id=task_id,
|
613
|
-
status=TaskStatus(
|
614
|
-
state=TaskState(status.status),
|
615
|
-
timestamp=datetime.now()
|
616
|
-
),
|
617
|
-
metadata=metadata,
|
618
|
-
# Include empty/default values for required fields
|
619
|
-
sessionId="",
|
620
|
-
artifacts=[],
|
621
|
-
history=[]
|
622
|
-
)
|
623
|
-
|
624
|
-
|
625
|
-
def _normalize_content(self, content: Any) -> List[Artifact]:
|
626
|
-
"""Handle all content types for both sendTask and getTask cases."""
|
627
|
-
if isinstance(content, Artifact):
|
628
|
-
return [content]
|
629
|
-
|
630
|
-
if isinstance(content, list):
|
631
|
-
if all(isinstance(item, Artifact) for item in content):
|
632
|
-
return content
|
633
|
-
return [Artifact(parts=self._parts_from_mixed(content))]
|
634
|
-
|
635
|
-
if isinstance(content, (str, Part, dict)):
|
636
|
-
return [Artifact(parts=[self._create_part(content)])]
|
637
|
-
|
638
|
-
try: # Handle raw artifact dicts
|
639
|
-
return [Artifact.model_validate(content)]
|
640
|
-
except ValidationError:
|
641
|
-
return [Artifact(parts=[TextPart(text=str(content))])]
|
642
|
-
|
643
|
-
def _parts_from_mixed(self, items: List[Any]) -> List[Part]:
|
644
|
-
"""Extract parts from mixed content lists."""
|
645
|
-
parts = []
|
646
|
-
for item in items:
|
647
|
-
if isinstance(item, Artifact):
|
648
|
-
parts.extend(item.parts)
|
649
|
-
else:
|
650
|
-
parts.append(self._create_part(item))
|
651
|
-
return parts
|
652
|
-
|
653
|
-
|
654
|
-
def _create_part(self, item: Any) -> Part:
|
655
|
-
"""Convert primitive types to proper Part models"""
|
656
|
-
if isinstance(item, (TextPart, FilePart, DataPart)):
|
657
|
-
return item
|
658
|
-
|
659
|
-
if isinstance(item, str):
|
660
|
-
return TextPart(text=item)
|
661
|
-
|
662
|
-
if isinstance(item, dict):
|
663
|
-
try:
|
664
|
-
return Part.model_validate(item)
|
665
|
-
except ValidationError:
|
666
|
-
return TextPart(text=str(item))
|
667
|
-
|
668
|
-
return TextPart(text=str(item))
|
669
|
-
|
670
|
-
|
671
644
|
# Response validation helper
|
672
645
|
def _validate_response_id(self, response: Union[SendTaskResponse, GetTaskResponse], request) -> Union[SendTaskResponse, GetTaskResponse]:
|
673
646
|
if response.result and response.result.id != request.params.id:
|
@@ -678,7 +651,7 @@ class SmartA2A:
|
|
678
651
|
)
|
679
652
|
)
|
680
653
|
return response
|
681
|
-
|
654
|
+
|
682
655
|
# Might refactor this later
|
683
656
|
def _finalize_task_response(self, request: GetTaskRequest, task: Task) -> GetTaskResponse:
|
684
657
|
"""Final validation and processing for getTask responses."""
|
@@ -721,8 +694,8 @@ class SmartA2A:
|
|
721
694
|
id=request.id,
|
722
695
|
result=task
|
723
696
|
)
|
724
|
-
|
725
|
-
|
697
|
+
|
698
|
+
|
726
699
|
async def _normalize_subscription_events(self, params: TaskSendParams, events: AsyncGenerator) -> AsyncGenerator[Union[SendTaskStreamingResponse, TaskStatusUpdateEvent, TaskArtifactUpdateEvent], None]:
|
727
700
|
artifact_state = defaultdict(lambda: {"index": 0, "last_chunk": False})
|
728
701
|
|
@@ -0,0 +1,34 @@
|
|
1
|
+
# Library imports
|
2
|
+
from typing import Optional, Dict, Any
|
3
|
+
from uuid import uuid4
|
4
|
+
|
5
|
+
# Local imports
|
6
|
+
from smarta2a.state_stores.base_state_store import BaseStateStore
|
7
|
+
from smarta2a.history_update_strategies.history_update_strategy import HistoryUpdateStrategy
|
8
|
+
from smarta2a.utils.types import Message, StateData
|
9
|
+
|
10
|
+
class StateManager:
|
11
|
+
def __init__(self, store: Optional[BaseStateStore], history_strategy: HistoryUpdateStrategy):
|
12
|
+
self.store = store
|
13
|
+
self.strategy = history_strategy
|
14
|
+
|
15
|
+
def init_or_get(self, session_id: Optional[str], message: Message, metadata: Dict[str, Any]) -> StateData:
|
16
|
+
sid = session_id or str(uuid4())
|
17
|
+
if not self.store:
|
18
|
+
return StateData(sessionId=sid, history=[message], metadata=metadata or {})
|
19
|
+
existing = self.store.get_state(sid) or StateData(sessionId=sid, history=[], metadata={})
|
20
|
+
existing.history.append(message)
|
21
|
+
existing.metadata = {**(existing.metadata or {}), **(metadata or {})}
|
22
|
+
self.store.update_state(sid, existing)
|
23
|
+
return existing
|
24
|
+
|
25
|
+
def update(self, state: StateData):
|
26
|
+
if self.store:
|
27
|
+
self.store.update_state(state.sessionId, state)
|
28
|
+
|
29
|
+
def get_store(self) -> Optional[BaseStateStore]:
|
30
|
+
return self.store
|
31
|
+
|
32
|
+
def get_strategy(self) -> HistoryUpdateStrategy:
|
33
|
+
return self.strategy
|
34
|
+
|