smarta2a 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smarta2a/agent/a2a_agent.py +25 -15
- smarta2a/agent/a2a_human.py +56 -0
- smarta2a/archive/smart_mcp_client.py +47 -0
- smarta2a/archive/subscription_service.py +85 -0
- smarta2a/{server → archive}/task_service.py +17 -8
- smarta2a/client/a2a_client.py +35 -8
- smarta2a/client/mcp_client.py +3 -0
- smarta2a/history_update_strategies/rolling_window_strategy.py +16 -0
- smarta2a/model_providers/__init__.py +1 -1
- smarta2a/model_providers/base_llm_provider.py +3 -3
- smarta2a/model_providers/openai_provider.py +126 -89
- smarta2a/nats-server.conf +12 -0
- smarta2a/server/json_rpc_request_processor.py +130 -0
- smarta2a/server/nats_client.py +49 -0
- smarta2a/server/request_handler.py +667 -0
- smarta2a/server/send_task_handler.py +174 -0
- smarta2a/server/server.py +124 -726
- smarta2a/server/state_manager.py +173 -19
- smarta2a/server/webhook_request_processor.py +112 -0
- smarta2a/state_stores/base_state_store.py +3 -3
- smarta2a/state_stores/inmemory_state_store.py +21 -7
- smarta2a/utils/agent_discovery_manager.py +121 -0
- smarta2a/utils/prompt_helpers.py +1 -1
- smarta2a/utils/tools_manager.py +108 -0
- smarta2a/utils/types.py +18 -3
- smarta2a-0.4.0.dist-info/METADATA +402 -0
- smarta2a-0.4.0.dist-info/RECORD +41 -0
- smarta2a-0.4.0.dist-info/licenses/LICENSE +35 -0
- smarta2a/client/tools_manager.py +0 -62
- smarta2a/examples/__init__.py +0 -0
- smarta2a/examples/echo_server/__init__.py +0 -0
- smarta2a/examples/echo_server/curl.txt +0 -1
- smarta2a/examples/echo_server/main.py +0 -39
- smarta2a/examples/openai_delegator_agent/__init__.py +0 -0
- smarta2a/examples/openai_delegator_agent/main.py +0 -41
- smarta2a/examples/openai_weather_agent/__init__.py +0 -0
- smarta2a/examples/openai_weather_agent/main.py +0 -32
- smarta2a/server/subscription_service.py +0 -109
- smarta2a-0.3.0.dist-info/METADATA +0 -103
- smarta2a-0.3.0.dist-info/RECORD +0 -40
- smarta2a-0.3.0.dist-info/licenses/LICENSE +0 -21
- {smarta2a-0.3.0.dist-info → smarta2a-0.4.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,667 @@
|
|
1
|
+
# Library imports
|
2
|
+
import json
|
3
|
+
from typing import Optional, Union, AsyncGenerator
|
4
|
+
from uuid import uuid4
|
5
|
+
from datetime import datetime
|
6
|
+
from collections import defaultdict
|
7
|
+
from fastapi.responses import StreamingResponse
|
8
|
+
from sse_starlette.sse import EventSourceResponse
|
9
|
+
|
10
|
+
# Local imports
|
11
|
+
from smarta2a.utils.types import (
|
12
|
+
SendTaskRequest,
|
13
|
+
SendTaskStreamingRequest,
|
14
|
+
GetTaskRequest,
|
15
|
+
CancelTaskRequest,
|
16
|
+
SetTaskPushNotificationRequest,
|
17
|
+
GetTaskPushNotificationRequest,
|
18
|
+
SendTaskResponse,
|
19
|
+
SendTaskStreamingResponse,
|
20
|
+
GetTaskResponse,
|
21
|
+
CancelTaskResponse,
|
22
|
+
SetTaskPushNotificationResponse,
|
23
|
+
GetTaskPushNotificationResponse,
|
24
|
+
Task,
|
25
|
+
TaskStatus,
|
26
|
+
TaskState,
|
27
|
+
Message,
|
28
|
+
StateData,
|
29
|
+
TaskSendParams,
|
30
|
+
TaskPushNotificationConfig,
|
31
|
+
A2AStatus,
|
32
|
+
A2AStreamResponse,
|
33
|
+
TextPart,
|
34
|
+
FilePart,
|
35
|
+
DataPart,
|
36
|
+
Artifact,
|
37
|
+
FileContent,
|
38
|
+
TaskStatusUpdateEvent,
|
39
|
+
TaskArtifactUpdateEvent
|
40
|
+
)
|
41
|
+
from smarta2a.utils.types import (
|
42
|
+
TaskNotFoundError,
|
43
|
+
MethodNotFoundError,
|
44
|
+
InvalidParamsError,
|
45
|
+
InternalError,
|
46
|
+
JSONRPCError,
|
47
|
+
TaskNotCancelableError
|
48
|
+
)
|
49
|
+
from smarta2a.utils.task_builder import TaskBuilder
|
50
|
+
from smarta2a.server.handler_registry import HandlerRegistry
|
51
|
+
from smarta2a.server.state_manager import StateManager
|
52
|
+
from smarta2a.client.a2a_client import A2AClient
|
53
|
+
|
54
|
+
class RequestHandler:
|
55
|
+
def __init__(self, registry: HandlerRegistry, state_manager: Optional[StateManager] = None):
    """Wire the request handler to its collaborators.

    Args:
        registry: Lookup used by every ``handle_*`` method to find the
            user-registered handler for a JSON-RPC method name.
        state_manager: Optional state manager; the send/subscribe handlers
            only touch it when they are given ``state_data``.
    """
    self.registry, self.state_manager = registry, state_manager
    # Results that carry no explicit status are built as COMPLETED tasks.
    self.task_builder = TaskBuilder(default_status=TaskState.COMPLETED)
    # Client used to forward finished tasks to push-notification webhooks.
    self.a2a_aclient = A2AClient()
|
60
|
+
|
61
|
+
async def handle_send_task(self, request: SendTaskRequest, state_data: Optional[StateData] = None) -> SendTaskResponse:
    """Run the registered ``tasks/send`` handler and package its result.

    Args:
        request: The incoming JSON-RPC request.
        state_data: Previously stored state for this task, if any. When
            given, histories/metadata are taken from it and the updated
            state is written back through ``self.state_manager``.

    Returns:
        A ``SendTaskResponse`` holding either the built task or an error.
        Exceptions raised while processing are converted into an error
        response rather than propagated.
    """
    try:
        handler = self.registry.get_handler("tasks/send")
        if not handler:
            return SendTaskResponse(
                id=request.id,
                error=MethodNotFoundError()
            )

        # Flag set on the registered handler; gates the webhook forward below.
        forward_to_webhook = handler.forward_to_webhook

        # Extract parameters from request
        task_id = request.params.id
        session_id = request.params.sessionId or str(uuid4())
        user_message = Message.model_validate(request.params.message)
        request_metadata = request.params.metadata or {}
        push_notification_config = request.params.pushNotification

        if state_data:
            # Apply the empty-list fallback BEFORE copying: the stored
            # histories may be None, and None has no .copy().
            task_history = list(state_data.task.history or [])
            context_history = list(state_data.context_history or [])
            metadata = state_data.task.metadata or {}
            push_notification_config = push_notification_config or state_data.push_notification_config
        else:
            # No stored state: build a task from scratch, seeding the
            # history with the user's message.
            seed_task = Task(
                id=task_id,
                sessionId=session_id,
                status=TaskStatus(state=TaskState.WORKING),
                history=[user_message],
                metadata=request_metadata
            )
            task_history = seed_task.history.copy()
            context_history = []
            metadata = request_metadata.copy()

        # The handler only receives state_data when there is some.
        if state_data:
            raw_result = await handler(request, state_data)
        else:
            raw_result = await handler(request)

        # Handlers may return a fully formed response; pass it through as-is.
        if isinstance(raw_result, SendTaskResponse):
            return raw_result

        # Build task with updated history (before agent response).
        # SmartA2A overwrites the artifacts each time a new task is built.
        # This is because it assumes the last artifact is what matters.
        # Also, the history (derived from the artifacts) contains all the
        # messages anyway.
        task = self.task_builder.build(
            content=raw_result,
            task_id=task_id,
            session_id=session_id,
            metadata=metadata,
            history=task_history
        )

        # Fold every artifact part into a single agent message.
        messages = []
        if task.artifacts:
            agent_parts = [p for a in task.artifacts for p in a.parts]
            messages.append(
                Message(
                    role="agent",
                    parts=agent_parts,
                    metadata=task.metadata
                )
            )

        # Update Task history with a simple append
        task_history.extend(messages)

        if state_data:
            # Context history goes through the configured strategy - this is
            # the history that will be passed to an LLM call.
            history_strategy = self.state_manager.get_history_strategy()
            context_history = history_strategy.update_history(
                existing_history=context_history,
                new_messages=messages
            )

        # Update task with final state
        task.history = task_history

        # State store update (if enabled). push_notification_config was
        # already merged with the stored config above.
        if state_data:
            await self.state_manager.update_state(
                state_data=StateData(
                    task_id=task_id,
                    task=task,
                    context_history=context_history,
                    push_notification_config=push_notification_config,
                )
            )

        # If push_notification_config is set send the task to the push notification url
        if push_notification_config and forward_to_webhook:
            # Best effort: a webhook failure must not fail the request.
            # NOTE(review): the call is not awaited - confirm that
            # A2AClient.send_to_webhook is synchronous.
            try:
                self.a2a_aclient.send_to_webhook(webhook_url=push_notification_config.url,id=task_id,task=task.model_dump())
            except Exception as e:
                print(f"Error sending task to webhook: {e}")

        # Send the task back to the client
        return SendTaskResponse(
            id=request.id,
            result=task
        )
    except Exception as e:
        # JSON-RPC errors are returned verbatim; anything else is wrapped.
        # (Kept as an isinstance check rather than an except clause in case
        # JSONRPCError is not an exception class.)
        if isinstance(e, JSONRPCError):
            return SendTaskResponse(
                id=request.id,
                error=e
            )
        return SendTaskResponse(
            id=request.id,
            error=InternalError(data=str(e))
        )
|
181
|
+
|
182
|
+
|
183
|
+
|
184
|
+
async def handle_subscribe_task(self, request: SendTaskStreamingRequest, state_data: Optional[StateData] = None) -> Union[StreamingResponse, EventSourceResponse, SendTaskStreamingResponse]:
    """Run the registered ``tasks/sendSubscribe`` handler as an SSE stream.

    Args:
        request: The incoming streaming JSON-RPC request.
        state_data: Previously stored state for this task, if any. When
            given, each event also updates the state through
            ``self.state_manager``.

    Returns:
        A ``SendTaskStreamingResponse`` error when no subscription handler
        is registered; otherwise a ``StreamingResponse`` whose body is a
        sequence of ``data: <json>`` server-sent events.
    """
    handler = self.registry.get_subscription("tasks/sendSubscribe")

    if not handler:
        return SendTaskStreamingResponse(
            jsonrpc="2.0",
            id=request.id,
            error=MethodNotFoundError()
        )

    # Flag set on the registered handler; gates the webhook forward below.
    forward_to_webhook = handler.forward_to_webhook

    # Extract parameters from request
    task_id = request.params.id
    session_id = request.params.sessionId or str(uuid4())
    user_message = Message.model_validate(request.params.message)
    request_metadata = request.params.metadata or {}
    push_notification_config = request.params.pushNotification

    if state_data:
        task = state_data.task
        # Apply the empty-list fallback BEFORE copying: the stored
        # histories may be None, and None has no .copy().
        task_history = list(task.history or [])
        context_history = list(state_data.context_history or [])
        # Request metadata has already been merged so no need to do it here
        metadata = task.metadata or {}
        push_notification_config = push_notification_config or state_data.push_notification_config
    else:
        task = Task(
            id=task_id,
            sessionId=session_id,
            status=TaskStatus(state=TaskState.WORKING),
            artifacts=[],
            history=[user_message],
            metadata=request_metadata
        )
        task_history = task.history.copy()
        context_history = []
        metadata = request_metadata

    async def event_generator():
        """Yield one JSON string per event, updating ``task`` as it goes."""
        try:
            # The handler only receives state_data when there is some.
            if state_data:
                raw_events = handler(request, state_data)
            else:
                raw_events = handler(request)

            normalized_events = self._normalize_subscription_events(request.params, raw_events)

            # Streaming state, advanced after every event.
            task_stream_history = task_history.copy()
            stream_metadata = metadata.copy()
            context_stream_history = context_history.copy()
            history_strategy = self.state_manager.get_history_strategy() if state_data else None

            async for item in normalized_events:
                try:
                    # The normalizer passes fully formed responses (including
                    # its own "unsupported event type" error) straight
                    # through; forward them instead of rejecting them.
                    if isinstance(item, SendTaskStreamingResponse):
                        yield item.model_dump_json()
                        continue

                    if isinstance(item, TaskArtifactUpdateEvent):
                        # Create agent message from artifact parts
                        agent_message = Message(
                            role="agent",
                            parts=list(item.artifact.parts),
                            metadata=item.artifact.metadata
                        )

                        # Task history is a plain append; context history
                        # goes through the configured strategy.
                        new_task_history = task_stream_history + [agent_message]
                        if state_data:
                            new_context_history = history_strategy.update_history(
                                existing_history=context_stream_history,
                                new_messages=[agent_message]
                            )

                        # Artifact metadata wins over accumulated metadata.
                        new_metadata = {
                            **stream_metadata,
                            **(item.artifact.metadata or {})
                        }

                        # A stored task may carry artifacts=None.
                        if task.artifacts is None:
                            task.artifacts = []
                        task.artifacts.append(item.artifact)
                        task.metadata = new_metadata
                        task.history = new_task_history

                        # Update state store if configured, keeping the
                        # push-notification config alongside the task.
                        if state_data:
                            await self.state_manager.update_state(
                                state_data=StateData(
                                    task_id=task_id,
                                    task=task,
                                    context_history=new_context_history,
                                    push_notification_config=push_notification_config,
                                )
                            )

                        # Update streaming state
                        task_stream_history = new_task_history
                        if state_data:
                            context_stream_history = new_context_history
                        stream_metadata = new_metadata

                        # If push_notification_config is set send the task to the push notification url
                        if push_notification_config and forward_to_webhook:
                            # Best effort: a webhook failure must not break
                            # the stream.
                            # NOTE(review): handle_send_task passes
                            # task.model_dump() here while this passes the
                            # Task object, and neither call is awaited -
                            # confirm against A2AClient.send_to_webhook.
                            try:
                                self.a2a_aclient.send_to_webhook(webhook_url=push_notification_config.url,id=task_id,task=task)
                            except Exception as e:
                                print(f"Error sending task to webhook: {e}")

                    elif isinstance(item, TaskStatusUpdateEvent):
                        # Update task with new status; metadata is a fresh
                        # copy of what has accumulated so far.
                        task.status = item.status
                        task.metadata = {**stream_metadata}

                        # Update state store if configured
                        if state_data:
                            await self.state_manager.update_state(
                                state_data=StateData(
                                    task_id=task_id,
                                    task=task,
                                    context_history=context_stream_history,
                                    push_notification_config=push_notification_config,
                                )
                            )

                    else:
                        raise ValueError(f"Invalid event type: {type(item).__name__}")

                    yield SendTaskStreamingResponse(
                        jsonrpc="2.0",
                        id=request.id,
                        result=item
                    ).model_dump_json()

                except Exception as e:
                    # A failure on one event is reported in-band and the
                    # stream carries on with the next event.
                    yield SendTaskStreamingResponse(
                        jsonrpc="2.0",
                        id=request.id,
                        error=InternalError(data=str(e))
                    ).model_dump_json()

        except Exception as e:
            # Failure of the stream itself ends it with a single error event.
            error = InternalError(data=str(e))
            if "not found" in str(e).lower():
                error = TaskNotFoundError()
            yield SendTaskStreamingResponse(
                jsonrpc="2.0",
                id=request.id,
                error=error
            ).model_dump_json()

    async def sse_stream():
        async for chunk in event_generator():
            # each chunk is already JSON; SSE wants "data: <payload>\n\n"
            yield (f"data: {chunk}\n\n").encode("utf-8")

    return StreamingResponse(
        sse_stream(),
        media_type="text/event-stream; charset=utf-8"
    )
|
363
|
+
|
364
|
+
|
365
|
+
def handle_get_task(self, request: GetTaskRequest) -> GetTaskResponse:
    """Run the registered ``tasks/get`` handler and package its result.

    Returns a ``MethodNotFoundError`` response when no handler is
    registered. A ``GetTaskResponse`` returned by the handler is only
    checked for a matching task id; any other return value is built into a
    task first. Exceptions become error responses.
    """
    get_handler = self.registry.get_handler("tasks/get")
    if not get_handler:
        return GetTaskResponse(id=request.id, error=MethodNotFoundError())

    try:
        outcome = get_handler(request)

        # Fully formed response: just make sure it is about the right task.
        if isinstance(outcome, GetTaskResponse):
            return self._validate_response_id(outcome, request)

        # Anything else goes through the shared task builder.
        built_task = self.task_builder.build(
            content=outcome,
            task_id=request.params.id,
            metadata=getattr(outcome, "metadata", {}) or {},
        )
        return self._finalize_task_response(request, built_task)
    except Exception as exc:
        # JSON-RPC errors are returned verbatim; anything else is wrapped.
        failure = exc if isinstance(exc, JSONRPCError) else InternalError(data=str(exc))
        return GetTaskResponse(id=request.id, error=failure)
|
403
|
+
|
404
|
+
|
405
|
+
|
406
|
+
def handle_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse:
    """Run the registered ``tasks/cancel`` handler and package its result.

    Returns a ``MethodNotFoundError`` response when no handler is
    registered. The handler may return a ``CancelTaskResponse`` (checked
    for a matching task id), an ``A2AStatus``, or any other content the
    task builder accepts. Exceptions become error responses.
    """
    cancel_handler = self.registry.get_handler("tasks/cancel")
    if not cancel_handler:
        return CancelTaskResponse(id=request.id, error=MethodNotFoundError())

    try:
        outcome = cancel_handler(request)

        # Tasks built here default to CANCELED rather than COMPLETED.
        builder = TaskBuilder(default_status=TaskState.CANCELED)

        # Fully formed response: just make sure it is about the right task.
        if isinstance(outcome, CancelTaskResponse):
            return self._validate_response_id(outcome, request)

        extra = getattr(outcome, "metadata", {}) or {}
        if isinstance(outcome, A2AStatus):
            cancelled = builder.normalize_from_status(
                status=outcome.status,
                task_id=request.params.id,
                metadata=extra,
            )
        else:
            cancelled = builder.build(
                content=outcome,
                task_id=request.params.id,
                metadata=extra,
            )

        # Final validation and packaging
        return self._finalize_cancel_response(request, cancelled)
    except Exception as exc:
        # JSON-RPC errors are returned verbatim; anything else is wrapped.
        failure = exc if isinstance(exc, JSONRPCError) else InternalError(data=str(exc))
        return CancelTaskResponse(id=request.id, error=failure)
|
454
|
+
|
455
|
+
def handle_set_notification(self, request: SetTaskPushNotificationRequest) -> SetTaskPushNotificationResponse:
    """Run the registered ``tasks/pushNotification/set`` handler.

    Returns a ``MethodNotFoundError`` response when no handler is
    registered. The handler may return:

    * ``None`` - success is reported by echoing ``request.params``;
    * a ``SetTaskPushNotificationResponse`` - returned unchanged;
    * anything else - validated as a ``TaskPushNotificationConfig`` and
      wrapped in a response (a validation failure becomes an
      ``InternalError`` response).

    Every path returns a response object; exceptions become error
    responses.
    """
    handler = self.registry.get_handler("tasks/pushNotification/set")

    if not handler:
        return SetTaskPushNotificationResponse(
            id=request.id,
            error=MethodNotFoundError()
        )

    try:
        # Execute handler (may or may not return something)
        raw_result = handler(request)

        # If handler returns nothing - build success response from request params
        if raw_result is None:
            return SetTaskPushNotificationResponse(
                id=request.id,
                result=request.params
            )

        # If handler returns a full response object
        if isinstance(raw_result, SetTaskPushNotificationResponse):
            return raw_result

        # Any other return value used to fall off the end of the method
        # and yield None; treat it as a config, the same way
        # handle_get_notification treats its handler's return value.
        config = TaskPushNotificationConfig.model_validate(raw_result)
        return SetTaskPushNotificationResponse(
            id=request.id,
            result=config
        )

    except Exception as e:
        # JSON-RPC errors are returned verbatim; anything else is wrapped.
        if isinstance(e, JSONRPCError):
            return SetTaskPushNotificationResponse(
                id=request.id,
                error=e
            )
        return SetTaskPushNotificationResponse(
            id=request.id,
            error=InternalError(data=str(e))
        )
|
490
|
+
|
491
|
+
|
492
|
+
|
493
|
+
def handle_get_notification(self, request: GetTaskPushNotificationRequest) -> GetTaskPushNotificationResponse:
    """Run the registered ``tasks/pushNotification/get`` handler.

    Returns a ``MethodNotFoundError`` response when no handler is
    registered. A ``GetTaskPushNotificationResponse`` from the handler is
    returned unchanged; any other value is validated as a
    ``TaskPushNotificationConfig`` and wrapped. Exceptions become error
    responses.
    """
    notification_handler = self.registry.get_handler("tasks/pushNotification/get")
    if not notification_handler:
        return GetTaskPushNotificationResponse(id=request.id, error=MethodNotFoundError())

    try:
        outcome = notification_handler(request)
        if isinstance(outcome, GetTaskPushNotificationResponse):
            return outcome

        # Validate the raw value as a TaskPushNotificationConfig
        validated = TaskPushNotificationConfig.model_validate(outcome)
        return GetTaskPushNotificationResponse(id=request.id, result=validated)
    except Exception as exc:
        # JSON-RPC errors are returned verbatim; anything else is wrapped.
        failure = exc if isinstance(exc, JSONRPCError) else InternalError(data=str(exc))
        return GetTaskPushNotificationResponse(id=request.id, error=failure)
|
526
|
+
|
527
|
+
|
528
|
+
'''
|
529
|
+
Private methods beyond this point
|
530
|
+
'''
|
531
|
+
|
532
|
+
# Response validation helper
|
533
|
+
def _validate_response_id(self, response: Union[SendTaskResponse, GetTaskResponse], request) -> Union[SendTaskResponse, GetTaskResponse]:
    """Check that a handler-supplied response is about the requested task.

    Returns ``response`` unchanged when it has no result or its result id
    equals ``request.params.id``; otherwise returns a new response of the
    same class carrying an ``InvalidParamsError``.
    """
    result = response.result
    expected_id = request.params.id
    if not result or result.id == expected_id:
        return response

    response_cls = type(response)
    mismatch = InvalidParamsError(
        data=f"Task ID mismatch: {response.result.id} vs {request.params.id}"
    )
    return response_cls(id=request.id, error=mismatch)
|
542
|
+
|
543
|
+
# Might refactor this later
|
544
|
+
def _finalize_task_response(self, request: GetTaskRequest, task: Task) -> GetTaskResponse:
    """Validate a built task and wrap it in a ``GetTaskResponse``.

    A task whose id differs from ``request.params.id`` yields an
    ``InvalidParamsError`` response. When the request has a truthy
    ``historyLength`` and the task has history, the history is trimmed in
    place to its last ``historyLength`` entries.
    """
    wanted_id = request.params.id
    if task.id != wanted_id:
        mismatch = InvalidParamsError(
            data=f"Task ID mismatch: {task.id} vs {request.params.id}"
        )
        return GetTaskResponse(id=request.id, error=mismatch)

    # Keep only the tail of the history when a length was asked for.
    keep = request.params.historyLength
    if keep and task.history:
        task.history = task.history[-keep:]

    return GetTaskResponse(id=request.id, result=task)
|
563
|
+
|
564
|
+
def _finalize_cancel_response(self, request: CancelTaskRequest, task: Task) -> CancelTaskResponse:
    """Validate a built task and wrap it in a ``CancelTaskResponse``.

    A task whose id differs from ``request.params.id`` yields an
    ``InvalidParamsError`` response; a task whose state is neither
    CANCELED nor COMPLETED yields a ``TaskNotCancelableError`` response.
    """
    if task.id != request.params.id:
        mismatch = InvalidParamsError(
            data=f"Task ID mismatch: {task.id} vs {request.params.id}"
        )
        return CancelTaskResponse(id=request.id, error=mismatch)

    # Only these two end states are accepted as the outcome of a cancel.
    acceptable_states = [TaskState.CANCELED, TaskState.COMPLETED]
    if task.status.state in acceptable_states:
        return CancelTaskResponse(id=request.id, result=task)

    return CancelTaskResponse(id=request.id, error=TaskNotCancelableError())
|
585
|
+
|
586
|
+
|
587
|
+
async def _normalize_subscription_events(self, params: TaskSendParams, events: AsyncGenerator) -> AsyncGenerator[Union[SendTaskStreamingResponse, TaskStatusUpdateEvent, TaskArtifactUpdateEvent], None]:
    """Turn whatever a subscription handler yields into protocol events.

    For each item from ``events``:

    * ``SendTaskStreamingResponse`` and ready-made status/artifact update
      events are yielded unchanged;
    * ``A2AStatus`` becomes a ``TaskStatusUpdateEvent``;
    * stream content (``A2AStreamResponse``, ``str``, ``bytes``, a part,
      an ``Artifact`` or a ``list``) becomes a ``TaskArtifactUpdateEvent``;
    * anything else becomes a ``SendTaskStreamingResponse`` carrying an
      ``InvalidParamsError``.
    """
    part_types = (TextPart, FilePart, DataPart)
    content_types = (A2AStreamResponse, str, bytes, TextPart, FilePart, DataPart, Artifact, list)
    # Per artifact index: how many chunks were emitted and whether the
    # final chunk has been seen.
    tracker = defaultdict(lambda: {"index": 0, "last_chunk": False})

    async for event in events:
        # Pass through fully formed responses immediately
        if isinstance(event, SendTaskStreamingResponse):
            yield event
            continue

        # Handle protocol status updates
        if isinstance(event, A2AStatus):
            is_final = event.final or (event.status.lower() == TaskState.COMPLETED)
            yield TaskStatusUpdateEvent(
                id=params.id,
                status=TaskStatus(
                    state=TaskState(event.status),
                    timestamp=datetime.now()
                ),
                final=is_final,
                metadata=event.metadata
            )
            continue

        # Handle stream content
        if isinstance(event, content_types):
            # Wrap bare content so the rest of the branch has one shape.
            chunk = event if isinstance(event, A2AStreamResponse) else A2AStreamResponse(content=event)
            payload = chunk.content

            # Process content into parts
            parts = []
            if isinstance(payload, str):
                parts.append(TextPart(text=payload))
            elif isinstance(payload, bytes):
                parts.append(FilePart(file=FileContent(bytes=payload)))
            elif isinstance(payload, part_types):
                parts.append(payload)
            elif isinstance(payload, Artifact):
                parts = payload.parts
            elif isinstance(payload, list):
                # List entries of any other type are skipped.
                for entry in payload:
                    if isinstance(entry, str):
                        parts.append(TextPart(text=entry))
                    elif isinstance(entry, part_types):
                        parts.append(entry)
                    elif isinstance(entry, Artifact):
                        parts.extend(entry.parts)

            # Track artifact state
            slot = chunk.index
            seen = tracker[slot]

            # NOTE(review): "append" compares the per-artifact chunk counter
            # with the artifact index itself - confirm this is intended.
            should_append = chunk.append or seen["index"] == slot
            is_last = chunk.final or seen["last_chunk"]
            yield TaskArtifactUpdateEvent(
                id=params.id,
                artifact=Artifact(
                    parts=parts,
                    index=slot,
                    append=should_append,
                    lastChunk=is_last,
                    metadata=chunk.metadata
                )
            )

            # Update artifact state tracking
            if chunk.final:
                seen["last_chunk"] = True
            seen["index"] += 1
            continue

        # Pass through protocol events directly
        if isinstance(event, (TaskStatusUpdateEvent, TaskArtifactUpdateEvent)):
            yield event
            continue

        # Handle invalid types
        yield SendTaskStreamingResponse(
            jsonrpc="2.0",
            id=params.id,  # Typically comes from request, but using params.id as fallback
            error=InvalidParamsError(
                data=f"Unsupported event type: {type(event).__name__}"
            )
        )
|