smarta2a 0.3.1__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. smarta2a/agent/a2a_agent.py +25 -15
  2. smarta2a/agent/a2a_human.py +56 -0
  3. smarta2a/archive/smart_mcp_client.py +47 -0
  4. smarta2a/archive/subscription_service.py +85 -0
  5. smarta2a/{server → archive}/task_service.py +17 -8
  6. smarta2a/client/a2a_client.py +33 -6
  7. smarta2a/history_update_strategies/rolling_window_strategy.py +16 -0
  8. smarta2a/model_providers/__init__.py +1 -1
  9. smarta2a/model_providers/base_llm_provider.py +3 -3
  10. smarta2a/model_providers/openai_provider.py +126 -89
  11. smarta2a/server/json_rpc_request_processor.py +130 -0
  12. smarta2a/server/nats_client.py +49 -0
  13. smarta2a/server/request_handler.py +667 -0
  14. smarta2a/server/send_task_handler.py +174 -0
  15. smarta2a/server/server.py +124 -726
  16. smarta2a/server/state_manager.py +171 -20
  17. smarta2a/server/webhook_request_processor.py +112 -0
  18. smarta2a/state_stores/base_state_store.py +3 -3
  19. smarta2a/state_stores/inmemory_state_store.py +21 -7
  20. smarta2a/utils/agent_discovery_manager.py +121 -0
  21. smarta2a/utils/prompt_helpers.py +1 -1
  22. smarta2a/{client → utils}/tools_manager.py +39 -8
  23. smarta2a/utils/types.py +17 -3
  24. {smarta2a-0.3.1.dist-info → smarta2a-0.4.1.dist-info}/METADATA +7 -4
  25. smarta2a-0.4.1.dist-info/RECORD +40 -0
  26. smarta2a-0.4.1.dist-info/licenses/LICENSE +35 -0
  27. smarta2a/examples/__init__.py +0 -0
  28. smarta2a/examples/echo_server/__init__.py +0 -0
  29. smarta2a/examples/echo_server/curl.txt +0 -1
  30. smarta2a/examples/echo_server/main.py +0 -39
  31. smarta2a/examples/openai_airbnb_agent/__init__.py +0 -0
  32. smarta2a/examples/openai_airbnb_agent/main.py +0 -33
  33. smarta2a/examples/openai_delegator_agent/__init__.py +0 -0
  34. smarta2a/examples/openai_delegator_agent/main.py +0 -51
  35. smarta2a/examples/openai_weather_agent/__init__.py +0 -0
  36. smarta2a/examples/openai_weather_agent/main.py +0 -32
  37. smarta2a/server/subscription_service.py +0 -109
  38. smarta2a-0.3.1.dist-info/RECORD +0 -42
  39. smarta2a-0.3.1.dist-info/licenses/LICENSE +0 -21
  40. {smarta2a-0.3.1.dist-info → smarta2a-0.4.1.dist-info}/WHEEL +0 -0
smarta2a/server/server.py CHANGED
@@ -1,77 +1,53 @@
1
1
  # Library imports
2
- from typing import Callable, Any, Optional, Dict, Union, List, AsyncGenerator
3
- import json
4
- from datetime import datetime
5
- from collections import defaultdict
6
- from fastapi import FastAPI, Request, HTTPException, APIRouter
2
+ from typing import Callable, Any, Optional
3
+ from fastapi import FastAPI, Request, APIRouter
7
4
  from fastapi.middleware.cors import CORSMiddleware
8
5
  from sse_starlette.sse import EventSourceResponse
9
- from pydantic import ValidationError
10
6
  import uvicorn
11
7
  from fastapi.responses import StreamingResponse
12
- from uuid import uuid4
13
8
 
14
9
  # Local imports
15
10
  from smarta2a.server.handler_registry import HandlerRegistry
16
11
  from smarta2a.server.state_manager import StateManager
17
- from smarta2a.state_stores.base_state_store import BaseStateStore
18
- from smarta2a.history_update_strategies.history_update_strategy import HistoryUpdateStrategy
19
- from smarta2a.history_update_strategies.append_strategy import AppendStrategy
20
12
  from smarta2a.utils.task_builder import TaskBuilder
13
+ from smarta2a.server.json_rpc_request_processor import JSONRPCRequestProcessor
14
+ from smarta2a.server.webhook_request_processor import WebhookRequestProcessor
21
15
 
22
16
  from smarta2a.utils.types import (
23
17
  JSONRPCResponse,
24
- Task,
25
- Artifact,
26
- TextPart,
27
- FilePart,
28
- FileContent,
29
- DataPart,
30
- Part,
31
- Message,
32
- TaskStatus,
33
18
  TaskState,
34
19
  JSONRPCError,
35
- SendTaskResponse,
36
20
  JSONRPCRequest,
37
- A2AResponse,
38
21
  SendTaskRequest,
39
22
  SendTaskStreamingRequest,
40
- SendTaskStreamingResponse,
41
23
  GetTaskRequest,
42
- GetTaskResponse,
43
24
  CancelTaskRequest,
44
- CancelTaskResponse,
45
- TaskStatusUpdateEvent,
46
- TaskArtifactUpdateEvent,
47
- JSONParseError,
48
- InvalidRequestError,
49
- MethodNotFoundError,
50
- InternalError,
51
- UnsupportedOperationError,
52
- TaskNotFoundError,
53
- InvalidParamsError,
54
- TaskNotCancelableError,
55
- A2AStatus,
56
- A2AStreamResponse,
57
- TaskSendParams,
58
25
  SetTaskPushNotificationRequest,
59
26
  GetTaskPushNotificationRequest,
60
- SetTaskPushNotificationResponse,
61
- GetTaskPushNotificationResponse,
62
- TaskPushNotificationConfig,
63
- StateData
27
+ SetTaskPushNotificationRequest,
28
+ GetTaskPushNotificationRequest,
29
+ StateData,
30
+ AgentCard,
31
+ WebhookRequest,
32
+ WebhookResponse
64
33
  )
65
34
 
66
35
  class SmartA2A:
67
- def __init__(self, name: str, state_store: Optional[BaseStateStore] = None, history_strategy: HistoryUpdateStrategy = AppendStrategy(), **fastapi_kwargs):
36
+ def __init__(self,
37
+ name: str,
38
+ agent_card: Optional[AgentCard] = None,
39
+ state_manager: Optional[StateManager] = None,
40
+ has_frontend: bool = False,
41
+ **fastapi_kwargs
42
+ ):
68
43
  self.name = name
69
44
  self.registry = HandlerRegistry()
70
- self.state_mgr = StateManager(state_store, history_strategy)
45
+ self.agent_card = agent_card
46
+ self.state_mgr = state_manager
71
47
  self.app = FastAPI(title=name, **fastapi_kwargs)
72
48
  self.router = APIRouter()
73
- self.state_store = state_store
74
- self.history_strategy = history_strategy
49
+ self.has_frontend = has_frontend
50
+ self._setup_cors()
75
51
  self._setup_routes()
76
52
  self.server_config = {
77
53
  "host": "0.0.0.0",
@@ -79,6 +55,7 @@ class SmartA2A:
79
55
  "reload": False
80
56
  }
81
57
  self.task_builder = TaskBuilder(default_status=TaskState.COMPLETED)
58
+ self.webhook_fn = None
82
59
 
83
60
  # Add this method to delegate ASGI calls
84
61
  async def __call__(self, scope, receive, send):
@@ -86,719 +63,140 @@ class SmartA2A:
86
63
 
87
64
  def on_event(self, event_name: str):
88
65
  return self.app.on_event(event_name)
89
-
90
- def on_send_task(self):
91
- def decorator(func: Callable[[SendTaskRequest, Optional[StateData]], Any]) -> Callable:
92
- self.registry.register("tasks/send", func)
93
- return func
94
- return decorator
95
-
96
- def on_send_subscribe_task(self):
97
- def decorator(fn: Callable[[SendTaskStreamingRequest, Optional[StateData]], Any]):
98
- self.registry.register("tasks/sendSubscribe", fn, subscription=True)
99
- return fn
100
- return decorator
101
66
 
102
- def task_get(self):
103
- def decorator(fn: Callable[[GetTaskRequest], Any]):
104
- self.registry.register("tasks/get", fn)
105
- return fn
106
- return decorator
67
+ def configure(self, **kwargs):
68
+ self.server_config.update(kwargs)
107
69
 
108
- def task_cancel(self):
109
- def decorator(fn: Callable[[CancelTaskRequest], Any]):
110
- self.registry.register("tasks/cancel", fn)
111
- return fn
112
- return decorator
70
+ def add_cors_middleware(self, **kwargs):
71
+ self.app.add_middleware(
72
+ CORSMiddleware,
73
+ **{k: v for k, v in kwargs.items() if v is not None}
74
+ )
113
75
 
114
- def set_notification(self):
115
- def decorator(fn: Callable[[SetTaskPushNotificationRequest], Any]):
116
- self.registry.register("tasks/pushNotification/set", fn)
117
- return fn
118
- return decorator
76
+ def run(self):
77
+ uvicorn.run(
78
+ self.app,
79
+ host=self.server_config["host"],
80
+ port=self.server_config["port"],
81
+ reload=self.server_config["reload"]
82
+ )
119
83
 
120
- def get_notification(self):
121
- def decorator(fn: Callable[[GetTaskPushNotificationRequest], Any]):
122
- self.registry.register("tasks/pushNotification/get", fn)
123
- return fn
124
- return decorator
125
-
84
+ def _setup_cors(self):
85
+ self.app.add_middleware(
86
+ CORSMiddleware,
87
+ allow_origins=["http://localhost:3000"],
88
+ allow_credentials=True,
89
+ allow_methods=["*"],
90
+ allow_headers=["*"],
91
+ )
126
92
 
127
93
  def _setup_routes(self):
128
- @self.app.post("/")
94
+ @self.app.on_event("startup")
95
+ async def on_startup():
96
+ if self.state_mgr:
97
+ await self.state_mgr.load()
98
+
99
+ @self.app.post("/rpc")
129
100
  async def handle_request(request: Request):
130
101
  try:
131
102
  data = await request.json()
132
103
  req = JSONRPCRequest.model_validate(data)
133
- #request_obj = JSONRPCRequest(**data)
134
104
  except Exception as e:
135
105
  return JSONRPCResponse(id=None, error=JSONRPCError(code=-32700, message="Parse error", data=str(e))).model_dump()
136
-
137
- response = await self.process_request(req)
106
+
107
+ #response = await self.process_request(req)
108
+ response = await JSONRPCRequestProcessor(self.registry, self.state_mgr).process_request(req)
138
109
 
139
110
  # <-- Accept both SSE‐style responses:
140
111
  if isinstance(response, (EventSourceResponse, StreamingResponse)):
141
112
  return response
142
-
113
+ print(response)
143
114
  # <-- Everything else is a normal pydantic JSONRPCResponse
144
115
  return response.model_dump()
145
-
146
-
147
- async def process_request(self, request: JSONRPCRequest) -> JSONRPCResponse:
148
116
 
149
- try:
150
- method = request.method
151
- params = request.params
152
- state_store = self.state_mgr.get_store()
153
- if method == "tasks/send":
154
- state_data = self.state_mgr.init_or_get(params.get("sessionId"), params.get("message"), params.get("metadata") or {})
155
- if state_store:
156
- return await self._handle_send_task(request, state_data)
157
- else:
158
- return await self._handle_send_task(request)
159
- elif method == "tasks/sendSubscribe":
160
- state_data = self.state_mgr.init_or_get(params.get("sessionId"), params.get("message"), params.get("metadata") or {})
161
- if state_store:
162
- return await self._handle_subscribe_task(request, state_data)
163
- else:
164
- return await self._handle_subscribe_task(request)
165
- elif method == "tasks/get":
166
- return self._handle_get_task(request)
167
- elif method == "tasks/cancel":
168
- return self._handle_cancel_task(request)
169
- elif method == "tasks/pushNotification/set":
170
- return self._handle_set_notification(request)
171
- elif method == "tasks/pushNotification/get":
172
- return self._handle_get_notification(request)
173
- else:
174
- return JSONRPCResponse(id=request.id, error=MethodNotFoundError()).model_dump()
175
- except ValidationError as e:
176
- return JSONRPCResponse(id=request.id, error=InvalidParamsError(data=e.errors())).model_dump()
177
- except HTTPException as e:
178
- err = UnsupportedOperationError() if e.status_code == 405 else InternalError(data=str(e))
179
- return JSONRPCResponse(id=request.id, error=err).model_dump()
180
-
181
-
182
- async def _handle_send_task(self, request_data: JSONRPCRequest, state_data: Optional[StateData] = None) -> SendTaskResponse:
183
- try:
184
- # Validate request format
185
- request = SendTaskRequest.model_validate(request_data.model_dump())
186
- handler = self.registry.get_handler("tasks/send")
187
-
188
- if not handler:
189
- return SendTaskResponse(
190
- id=request.id,
191
- error=MethodNotFoundError()
192
- )
193
-
194
- user_message = request.params.message
195
- request_metadata = request.params.metadata or {}
196
- if state_data:
197
- session_id = state_data.sessionId
198
- existing_history = state_data.history.copy() or []
199
- metadata = state_data.metadata or {} # Request metadata has already been merged, so no need to do it here
200
- else:
201
- session_id = request.params.sessionId or str(uuid4())
202
- existing_history = [user_message]
203
- metadata = request_metadata
204
-
205
-
206
- try:
207
-
208
- if state_data:
209
- raw_result = await handler(request, state_data)
210
- else:
211
- raw_result = await handler(request)
212
-
213
- # Handle direct SendTaskResponse returns
214
- if isinstance(raw_result, SendTaskResponse):
215
- return raw_result
216
-
217
- # Build task with updated history (before agent response)
218
- task = self.task_builder.build(
219
- content=raw_result,
220
- task_id=request.params.id,
221
- session_id=session_id, # Always use generated session ID
222
- metadata=metadata, # Use merged metadata
223
- history=existing_history # History
224
- )
225
-
226
- # Process messages through strategy
227
- messages = []
228
- if task.artifacts:
229
- agent_parts = [p for a in task.artifacts for p in a.parts]
230
- agent_message = Message(
231
- role="agent",
232
- parts=agent_parts,
233
- metadata=task.metadata
234
- )
235
- messages.append(agent_message)
236
-
237
- final_history = self.history_strategy.update_history(
238
- existing_history=existing_history,
239
- new_messages=messages
240
- )
241
-
242
- # Update task with final state
243
- task.history = final_history
244
-
245
- # State store update (if enabled)
246
- if self.state_store:
247
- self.state_store.update_state(
248
- session_id=session_id,
249
- state_data=StateData(
250
- sessionId=session_id,
251
- history=final_history,
252
- metadata=metadata # Use merged metadata
253
- )
254
- )
255
-
256
- return SendTaskResponse(
257
- id=request.id,
258
- result=task
259
- )
260
-
261
- except Exception as e:
262
- # Handle case where handler returns SendTaskResponse with error
263
- if isinstance(e, JSONRPCError):
264
- return SendTaskResponse(
265
- id=request.id,
266
- error=e
267
- )
268
- return SendTaskResponse(
269
- id=request.id,
270
- error=InternalError(data=str(e))
271
- )
272
-
273
- except ValidationError as e:
274
- return SendTaskResponse(
275
- id=request_data.get("id"),
276
- error=InvalidRequestError(data=e.errors())
277
- )
278
- except json.JSONDecodeError as e:
279
- return SendTaskResponse(
280
- id=request_data.get("id"),
281
- error=JSONParseError(data=str(e))
282
- )
283
-
117
+ # Add agent.json endpoint if card exists
118
+ if self.agent_card is not None:
119
+ @self.app.get("/.well-known/agent.json", response_model=AgentCard)
120
+ async def get_agent_card():
121
+ """Return the agent's service description"""
122
+ return self.agent_card
284
123
 
285
- async def _handle_subscribe_task(self, request_data: JSONRPCRequest, state_data: Optional[StateData] = None) -> Union[EventSourceResponse, SendTaskStreamingResponse]:
286
- try:
287
- request = SendTaskStreamingRequest.model_validate(request_data.model_dump())
288
- #handler = self.subscriptions.get("tasks/sendSubscribe")
289
- handler = self.registry.get_subscription("tasks/sendSubscribe")
290
-
291
- if not handler:
292
- return SendTaskStreamingResponse(
293
- jsonrpc="2.0",
294
- id=request.id,
295
- error=MethodNotFoundError()
296
- )
297
-
298
- user_message = request.params.message
299
- request_metadata = request.params.metadata or {}
300
- if state_data:
301
- session_id = state_data.sessionId
302
- existing_history = state_data.history.copy() or []
303
- metadata = state_data.metadata or {} # Request metadata has already been merged, so no need to do it here
304
- else:
305
- session_id = request.params.sessionId or str(uuid4())
306
- existing_history = [user_message]
307
- metadata = request_metadata
308
-
309
-
310
- async def event_generator():
311
-
312
- try:
313
-
314
- if state_data:
315
- raw_events = handler(request, state_data)
316
- else:
317
- raw_events = handler(request)
318
-
319
- normalized_events = self._normalize_subscription_events(request.params, raw_events)
320
-
321
- # Initialize streaming state
322
- stream_history = existing_history.copy()
323
- stream_metadata = metadata.copy()
324
-
325
- async for item in normalized_events:
326
- try:
327
-
328
- # Process artifact updates
329
- if isinstance(item, TaskArtifactUpdateEvent):
330
- # Create agent message from artifact parts
331
- agent_message = Message(
332
- role="agent",
333
- parts=[p for p in item.artifact.parts],
334
- metadata=item.artifact.metadata
335
- )
336
-
337
- # Update history using strategy
338
- new_history = self.history_strategy.update_history(
339
- existing_history=stream_history,
340
- new_messages=[agent_message]
341
- )
342
-
343
- # Merge metadata
344
- new_metadata = {
345
- **stream_metadata,
346
- **(item.artifact.metadata or {})
347
- }
348
-
349
- # Update state store if configured
350
- if self.state_store:
351
- self.state_store.update_state(
352
- session_id=session_id,
353
- state_data=StateData(
354
- sessionId=session_id,
355
- history=new_history,
356
- metadata=new_metadata
357
- )
358
- )
359
-
360
- # Update streaming state
361
- stream_history = new_history
362
- stream_metadata = new_metadata
363
-
364
-
365
- if isinstance(item, SendTaskStreamingResponse):
366
- yield item.model_dump_json()
367
- continue
368
-
369
- # Add validation for proper event types
370
- if not isinstance(item, (TaskStatusUpdateEvent, TaskArtifactUpdateEvent)):
371
- raise ValueError(f"Invalid event type: {type(item).__name__}")
372
-
373
- yield SendTaskStreamingResponse(
374
- jsonrpc="2.0",
375
- id=request.id,
376
- result=item
377
- ).model_dump_json()
378
-
379
- except Exception as e:
380
- yield SendTaskStreamingResponse(
381
- jsonrpc="2.0",
382
- id=request.id,
383
- error=InternalError(data=str(e))
384
- ).model_dump_json()
385
-
386
-
387
- except Exception as e:
388
- error = InternalError(data=str(e))
389
- if "not found" in str(e).lower():
390
- error = TaskNotFoundError()
391
- yield SendTaskStreamingResponse(
392
- jsonrpc="2.0",
393
- id=request.id,
394
- error=error
395
- ).model_dump_json()
396
-
397
- async def sse_stream():
398
- async for chunk in event_generator():
399
- # each chunk is already JSON; SSE wants "data: <payload>\n\n"
400
- yield (f"data: {chunk}\n\n").encode("utf-8")
401
-
402
- return StreamingResponse(
403
- sse_stream(),
404
- media_type="text/event-stream; charset=utf-8"
405
- )
406
-
407
-
408
- except ValidationError as e:
409
- return SendTaskStreamingResponse(
410
- jsonrpc="2.0",
411
- id=request_data.get("id"),
412
- error=InvalidRequestError(data=e.errors())
413
- )
414
- except json.JSONDecodeError as e:
415
- return SendTaskStreamingResponse(
416
- jsonrpc="2.0",
417
- id=request_data.get("id"),
418
- error=JSONParseError(data=str(e))
419
- )
420
- except HTTPException as e:
421
- if e.status_code == 405:
422
- return SendTaskStreamingResponse(
423
- jsonrpc="2.0",
424
- id=request_data.get("id"),
425
- error=UnsupportedOperationError()
426
- )
427
- return SendTaskStreamingResponse(
428
- jsonrpc="2.0",
429
- id=request_data.get("id"),
430
- error=InternalError(data=str(e))
431
- )
432
-
433
-
434
- def _handle_get_task(self, request_data: JSONRPCRequest) -> GetTaskResponse:
435
- try:
436
- # Validate request structure
437
- request = GetTaskRequest.model_validate(request_data.model_dump())
438
- handler = self.registry.get_handler("tasks/get")
439
-
440
- if not handler:
441
- return GetTaskResponse(
442
- id=request.id,
443
- error=MethodNotFoundError()
444
- )
445
-
446
- try:
447
- raw_result = handler(request)
448
-
449
- if isinstance(raw_result, GetTaskResponse):
450
- return self._validate_response_id(raw_result, request)
451
-
452
- # Use unified task builder with different defaults
453
- task = self.task_builder.build(
454
- content=raw_result,
455
- task_id=request.params.id,
456
- metadata=getattr(raw_result, "metadata", {}) or {}
457
- )
458
-
459
- return self._finalize_task_response(request, task)
460
-
461
- except Exception as e:
462
- # Handle case where handler returns SendTaskResponse with error
463
- if isinstance(e, JSONRPCError):
464
- return GetTaskResponse(
465
- id=request.id,
466
- error=e
467
- )
468
- return GetTaskResponse(
469
- id=request.id,
470
- error=InternalError(data=str(e))
471
- )
472
-
473
- except ValidationError as e:
474
- return GetTaskResponse(
475
- id=request_data.get("id"),
476
- error=InvalidRequestError(data=e.errors())
477
- )
478
- except json.JSONDecodeError as e:
479
- return GetTaskResponse(
480
- id=request_data.get("id"),
481
- error=JSONParseError(data=str(e))
482
- )
483
124
 
484
-
485
- def _handle_cancel_task(self, request_data: JSONRPCRequest) -> CancelTaskResponse:
486
- try:
487
- # Validate request structure
488
- request = CancelTaskRequest.model_validate(request_data.model_dump())
489
- handler = self.registry.get_handler("tasks/cancel")
490
-
491
- if not handler:
492
- return CancelTaskResponse(
493
- id=request.id,
494
- error=MethodNotFoundError()
495
- )
496
-
125
+ @self.app.post("/webhook")
126
+ async def handle_webhook(request: Request):
497
127
  try:
498
- raw_result = handler(request)
499
-
500
- cancel_task_builder = TaskBuilder(default_status=TaskState.CANCELED)
501
- # Handle direct CancelTaskResponse returns
502
- if isinstance(raw_result, CancelTaskResponse):
503
- return self._validate_response_id(raw_result, request)
504
-
505
- # Handle A2AStatus returns
506
- if isinstance(raw_result, A2AStatus):
507
- task = cancel_task_builder.normalize_from_status(
508
- status=raw_result.status,
509
- task_id=request.params.id,
510
- metadata=getattr(raw_result, "metadata", {}) or {}
511
- )
512
- else:
513
- # Existing processing for other return types
514
- task = cancel_task_builder.build(
515
- content=raw_result,
516
- task_id=request.params.id,
517
- metadata=getattr(raw_result, "metadata", {}) or {}
518
- )
519
-
520
- # Final validation and packaging
521
- return self._finalize_cancel_response(request, task)
522
-
128
+ data = await request.json()
129
+ req = WebhookRequest.model_validate(data)
523
130
  except Exception as e:
524
- # Handle case where handler returns SendTaskResponse with error
525
- if isinstance(e, JSONRPCError):
526
- return CancelTaskResponse(
527
- id=request.id,
528
- error=e
529
- )
530
- return CancelTaskResponse(
531
- id=request.id,
532
- error=InternalError(data=str(e))
533
- )
534
-
535
- except ValidationError as e:
536
- return CancelTaskResponse(
537
- id=request_data.get("id"),
538
- error=InvalidRequestError(data=e.errors())
539
- )
540
- except json.JSONDecodeError as e:
541
- return CancelTaskResponse(
542
- id=request_data.get("id"),
543
- error=JSONParseError(data=str(e))
544
- )
545
- except HTTPException as e:
546
- if e.status_code == 405:
547
- return CancelTaskResponse(
548
- id=request_data.get("id"),
549
- error=UnsupportedOperationError()
550
- )
551
- return CancelTaskResponse(
552
- id=request_data.get("id"),
553
- error=InternalError(data=str(e))
554
- )
555
-
556
- def _handle_set_notification(self, request_data: JSONRPCRequest) -> SetTaskPushNotificationResponse:
557
- try:
558
- request = SetTaskPushNotificationRequest.model_validate(request_data.model_dump())
559
- handler = self.registry.get_handler("tasks/pushNotification/set")
131
+ return WebhookResponse(accepted=False, error=str(e)).model_dump()
560
132
 
561
- if not handler:
562
- return SetTaskPushNotificationResponse(
563
- id=request.id,
564
- error=MethodNotFoundError()
565
- )
566
-
567
- try:
568
- # Execute handler (may or may not return something)
569
- raw_result = handler(request)
570
-
571
- # If handler returns nothing - build success response from request params
572
- if raw_result is None:
573
- return SetTaskPushNotificationResponse(
574
- id=request.id,
575
- result=request.params
576
- )
577
-
578
- # If handler returns a full response object
579
- if isinstance(raw_result, SetTaskPushNotificationResponse):
580
- return raw_result
581
-
133
+ response = await WebhookRequestProcessor(self.webhook_fn, self.state_mgr).process_request(req)
582
134
 
583
- except Exception as e:
584
- if isinstance(e, JSONRPCError):
585
- return SetTaskPushNotificationResponse(
586
- id=request.id,
587
- error=e
588
- )
589
- return SetTaskPushNotificationResponse(
590
- id=request.id,
591
- error=InternalError(data=str(e))
592
- )
593
-
594
- except ValidationError as e:
595
- return SetTaskPushNotificationResponse(
596
- id=request_data.get("id"),
597
- error=InvalidRequestError(data=e.errors())
598
- )
599
-
600
-
601
- def _handle_get_notification(self, request_data: JSONRPCRequest) -> GetTaskPushNotificationResponse:
602
- try:
603
- request = GetTaskPushNotificationRequest.model_validate(request_data.model_dump())
604
- handler = self.registry.get_handler("tasks/pushNotification/get")
135
+ return response.model_dump()
605
136
 
606
- if not handler:
607
- return GetTaskPushNotificationResponse(
608
- id=request.id,
609
- error=MethodNotFoundError()
610
- )
611
137
 
612
- try:
613
- raw_result = handler(request)
614
-
615
- if isinstance(raw_result, GetTaskPushNotificationResponse):
616
- return raw_result
617
- else:
618
- # Validate raw_result as TaskPushNotificationConfig
619
- config = TaskPushNotificationConfig.model_validate(raw_result)
620
- return GetTaskPushNotificationResponse(
621
- id=request.id,
622
- result=config
623
- )
624
- except ValidationError as e:
625
- return GetTaskPushNotificationResponse(
626
- id=request.id,
627
- error=InvalidParamsError(data=e.errors())
628
- )
629
- except Exception as e:
630
- if isinstance(e, JSONRPCError):
631
- return GetTaskPushNotificationResponse(
632
- id=request.id,
633
- error=e
634
- )
635
- return GetTaskPushNotificationResponse(
636
- id=request.id,
637
- error=InternalError(data=str(e))
638
- )
639
-
640
- except ValidationError as e:
641
- return GetTaskPushNotificationResponse(
642
- id=request_data.get("id"),
643
- error=InvalidRequestError(data=e.errors())
644
- )
645
- except json.JSONDecodeError as e:
646
- return GetTaskPushNotificationResponse(
647
- id=request_data.get("id"),
648
- error=JSONParseError(data=str(e))
649
- )
650
-
651
- # Response validation helper
652
- def _validate_response_id(self, response: Union[SendTaskResponse, GetTaskResponse], request) -> Union[SendTaskResponse, GetTaskResponse]:
653
- if response.result and response.result.id != request.params.id:
654
- return type(response)(
655
- id=request.id,
656
- error=InvalidParamsError(
657
- data=f"Task ID mismatch: {response.result.id} vs {request.params.id}"
658
- )
659
- )
660
- return response
661
-
662
- # Might refactor this later
663
- def _finalize_task_response(self, request: GetTaskRequest, task: Task) -> GetTaskResponse:
664
- """Final validation and processing for getTask responses."""
665
- # Validate task ID matches request
666
- if task.id != request.params.id:
667
- return GetTaskResponse(
668
- id=request.id,
669
- error=InvalidParamsError(
670
- data=f"Task ID mismatch: {task.id} vs {request.params.id}"
671
- )
672
- )
673
-
674
- # Apply history length filtering
675
- if request.params.historyLength and task.history:
676
- task.history = task.history[-request.params.historyLength:]
677
-
678
- return GetTaskResponse(
679
- id=request.id,
680
- result=task
681
- )
682
-
683
- def _finalize_cancel_response(self, request: CancelTaskRequest, task: Task) -> CancelTaskResponse:
684
- """Final validation and processing for cancel responses."""
685
- if task.id != request.params.id:
686
- return CancelTaskResponse(
687
- id=request.id,
688
- error=InvalidParamsError(
689
- data=f"Task ID mismatch: {task.id} vs {request.params.id}"
690
- )
691
- )
138
+ '''
692
139
 
693
- # Ensure cancellation-specific requirements are met
694
- if task.status.state not in [TaskState.CANCELED, TaskState.COMPLETED]:
695
- return CancelTaskResponse(
696
- id=request.id,
697
- error=TaskNotCancelableError()
698
- )
140
+ if self.has_frontend:
141
+ if not os.path.exists("frontend/index.html"):
142
+ raise FileNotFoundError("frontend/index.html does not exist")
143
+
144
+ @self.app.get("/")
145
+ async def get_frontend():
146
+ return FileResponse("frontend/index.html")
147
+ '''
699
148
 
700
- return CancelTaskResponse(
701
- id=request.id,
702
- result=task
703
- )
149
+ '''
150
+ Setup the decorators for the various A2A methods.
151
+ '''
152
+ def on_send_task(self,forward_to_webhook: bool = False):
153
+ def decorator(fn: Callable[[SendTaskRequest, Optional[StateData]], Any]) -> Callable:
154
+ fn.forward_to_webhook = forward_to_webhook
155
+ self.registry.register("tasks/send", fn)
156
+ return fn
157
+ return decorator
704
158
 
159
+ def on_send_subscribe_task(self,forward_to_webhook: bool = False):
160
+ def decorator(fn: Callable[[SendTaskStreamingRequest, Optional[StateData]], Any]):
161
+ fn.forward_to_webhook = forward_to_webhook
162
+ self.registry.register("tasks/sendSubscribe", fn, subscription=True)
163
+ return fn
164
+ return decorator
705
165
 
706
- async def _normalize_subscription_events(self, params: TaskSendParams, events: AsyncGenerator) -> AsyncGenerator[Union[SendTaskStreamingResponse, TaskStatusUpdateEvent, TaskArtifactUpdateEvent], None]:
707
- artifact_state = defaultdict(lambda: {"index": 0, "last_chunk": False})
708
-
709
- async for item in events:
710
- # Pass through fully formed responses immediately
711
- if isinstance(item, SendTaskStreamingResponse):
712
- yield item
713
- continue
166
+ def task_get(self):
167
+ def decorator(fn: Callable[[GetTaskRequest], Any]):
168
+ self.registry.register("tasks/get", fn)
169
+ return fn
170
+ return decorator
714
171
 
715
- # Handle protocol status updates
716
- if isinstance(item, A2AStatus):
717
- yield TaskStatusUpdateEvent(
718
- id=params.id,
719
- status=TaskStatus(
720
- state=TaskState(item.status),
721
- timestamp=datetime.now()
722
- ),
723
- final=item.final or (item.status.lower() == TaskState.COMPLETED),
724
- metadata=item.metadata
725
- )
726
-
727
- # Handle stream content
728
- elif isinstance(item, (A2AStreamResponse, str, bytes, TextPart, FilePart, DataPart, Artifact, list)):
729
- # Convert to A2AStreamResponse if needed
730
- if not isinstance(item, A2AStreamResponse):
731
- item = A2AStreamResponse(content=item)
172
+ def task_cancel(self):
173
+ def decorator(fn: Callable[[CancelTaskRequest], Any]):
174
+ self.registry.register("tasks/cancel", fn)
175
+ return fn
176
+ return decorator
732
177
 
733
- # Process content into parts
734
- parts = []
735
- content = item.content
736
-
737
- if isinstance(content, str):
738
- parts.append(TextPart(text=content))
739
- elif isinstance(content, bytes):
740
- parts.append(FilePart(file=FileContent(bytes=content)))
741
- elif isinstance(content, (TextPart, FilePart, DataPart)):
742
- parts.append(content)
743
- elif isinstance(content, Artifact):
744
- parts = content.parts
745
- elif isinstance(content, list):
746
- for elem in content:
747
- if isinstance(elem, str):
748
- parts.append(TextPart(text=elem))
749
- elif isinstance(elem, (TextPart, FilePart, DataPart)):
750
- parts.append(elem)
751
- elif isinstance(elem, Artifact):
752
- parts.extend(elem.parts)
178
+ def set_notification(self):
179
+ def decorator(fn: Callable[[SetTaskPushNotificationRequest], Any]):
180
+ self.registry.register("tasks/pushNotification/set", fn)
181
+ return fn
182
+ return decorator
753
183
 
754
- # Track artifact state
755
- artifact_idx = item.index
756
- state = artifact_state[artifact_idx]
757
-
758
- yield TaskArtifactUpdateEvent(
759
- id=params.id,
760
- artifact=Artifact(
761
- parts=parts,
762
- index=artifact_idx,
763
- append=item.append or state["index"] == artifact_idx,
764
- lastChunk=item.final or state["last_chunk"],
765
- metadata=item.metadata
766
- )
767
- )
768
-
769
- # Update artifact state tracking
770
- if item.final:
771
- state["last_chunk"] = True
772
- state["index"] += 1
773
-
774
- # Pass through protocol events directly
775
- elif isinstance(item, (TaskStatusUpdateEvent, TaskArtifactUpdateEvent)):
776
- yield item
777
-
778
- # Handle invalid types
779
- else:
780
- yield SendTaskStreamingResponse(
781
- jsonrpc="2.0",
782
- id=params.id, # Typically comes from request, but using params.id as fallback
783
- error=InvalidParamsError(
784
- data=f"Unsupported event type: {type(item).__name__}"
785
- )
786
- )
184
+ def get_notification(self):
185
+ def decorator(fn: Callable[[GetTaskPushNotificationRequest], Any]):
186
+ self.registry.register("tasks/pushNotification/get", fn)
187
+ return fn
188
+ return decorator
189
+
190
+ '''
191
+ This is outside of the A2A protocol spec. A callback allows re-triggering of an existing task by an external service.
192
+ If a state store is provided, the callback will use the push notification config to call another callback.
193
+ This effectively allows backward communication.
194
+ '''
195
+
196
+ def webhook(self):
197
+ def decorator(fn: Callable[[WebhookRequest], Any]):
198
+ self.webhook_fn = fn
199
+ return fn
200
+ return decorator
201
+
787
202
 
788
-
789
- def configure(self, **kwargs):
790
- self.server_config.update(kwargs)
791
-
792
- def add_cors_middleware(self, **kwargs):
793
- self.app.add_middleware(
794
- CORSMiddleware,
795
- **{k: v for k, v in kwargs.items() if v is not None}
796
- )
797
-
798
- def run(self):
799
- uvicorn.run(
800
- self.app,
801
- host=self.server_config["host"],
802
- port=self.server_config["port"],
803
- reload=self.server_config["reload"]
804
- )