smarta2a 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fasta2a/__init__.py ADDED
@@ -0,0 +1,9 @@
"""
fasta2a - A Python package for implementing an A2A (Agent2Agent) server.
"""

__version__ = "0.1.0"

from .server import FastA2A

# Public API for ``from fasta2a import *``. Every name listed here must exist
# as an attribute of this package; an unknown name (previously "models")
# makes the star-import raise AttributeError.
__all__ = ["FastA2A"]
fasta2a/server.py ADDED
@@ -0,0 +1,710 @@
1
+ from typing import Callable, Any, Optional, Dict, Union, List, AsyncGenerator
2
+ import json
3
+ import inspect
4
+ from datetime import datetime
5
+ from collections import defaultdict
6
+ from fastapi import FastAPI, Request, HTTPException, APIRouter
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from sse_starlette.sse import EventSourceResponse
9
+ from pydantic import ValidationError
10
+ import uvicorn
11
+ from fastapi.responses import StreamingResponse
12
+ from uuid import uuid4
13
+
14
+
15
+ from .types import (
16
+ JSONRPCResponse,
17
+ Task,
18
+ Artifact,
19
+ TextPart,
20
+ FilePart,
21
+ FileContent,
22
+ DataPart,
23
+ Part,
24
+ TaskStatus,
25
+ TaskState,
26
+ JSONRPCError,
27
+ SendTaskResponse,
28
+ JSONRPCRequest,
29
+ A2AResponse,
30
+ SendTaskRequest,
31
+ SendTaskStreamingRequest,
32
+ SendTaskStreamingResponse,
33
+ GetTaskRequest,
34
+ GetTaskResponse,
35
+ CancelTaskRequest,
36
+ CancelTaskResponse,
37
+ TaskStatusUpdateEvent,
38
+ TaskArtifactUpdateEvent,
39
+ JSONParseError,
40
+ InvalidRequestError,
41
+ MethodNotFoundError,
42
+ ContentTypeNotSupportedError,
43
+ InternalError,
44
+ UnsupportedOperationError,
45
+ TaskNotFoundError,
46
+ InvalidParamsError,
47
+ TaskNotCancelableError,
48
+ A2AStatus,
49
+ A2AStreamResponse,
50
+ TaskSendParams,
51
+ )
52
+
53
class FastA2A:
    """FastAPI-based server for the A2A (Agent2Agent) JSON-RPC protocol.

    All JSON-RPC traffic is POSTed to ``/`` and dispatched on its ``method``
    field to the handlers registered with ``on_send_task``,
    ``on_send_subscribe_task``, ``task_get`` and ``task_cancel``. Handler
    return values are normalised into protocol types (``Task`` results or
    streaming update events) before being sent back to the client.
    """

    # Concrete Part model classes. ``Part`` itself is an ``Annotated``
    # discriminated union: it cannot be used with ``isinstance`` (TypeError)
    # and has no ``model_validate``.
    _PART_TYPES = (TextPart, FilePart, DataPart)

    def __init__(self, name: str, **fastapi_kwargs):
        """Create the FastAPI app and register the JSON-RPC endpoint.

        Args:
            name: Server name; also used as the FastAPI ``title``.
            **fastapi_kwargs: Forwarded verbatim to ``FastAPI(...)``.
        """
        self.name = name
        # JSON-RPC method name -> handler for request/response methods.
        self.handlers: Dict[str, Callable] = {}
        # JSON-RPC method name -> handler for streaming (SSE) methods.
        self.subscriptions: Dict[str, Callable] = {}
        self.app = FastAPI(title=name, **fastapi_kwargs)
        self.router = APIRouter()
        # Methods that already have a handler; used to reject duplicates.
        self._registered_decorators = set()
        self._setup_routes()
        self.server_config = {
            "host": "0.0.0.0",
            "port": 8000,
            "reload": False,
        }

    def _setup_routes(self):
        """Register the single JSON-RPC endpoint (``POST /``) on ``self.app``."""

        @self.app.post("/")
        async def handle_request(request: Request):
            # Undecodable or malformed JSON is a JSON-RPC "Parse error".
            # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
            try:
                data = await request.json()
            except ValueError as e:
                return JSONRPCResponse(
                    id=None,
                    error=JSONRPCError(
                        code=-32700,
                        message="Parse error",
                        data=str(e),
                    ),
                ).model_dump()

            # Well-formed JSON that is not a valid request object (missing
            # "method", not an object, ...) is "Invalid Request", not a parse
            # error.
            try:
                request_obj = JSONRPCRequest.model_validate(data)
            except ValidationError as e:
                return JSONRPCResponse(
                    id=self._safe_request_id(data),
                    error=InvalidRequestError(data=e.errors()),
                ).model_dump()

            response = await self.process_request(request_obj.model_dump())

            # Streaming handlers return an SSE response object; pass it through.
            if isinstance(response, (EventSourceResponse, StreamingResponse)):
                return response

            # Everything else is a pydantic JSON-RPC response model.
            return response.model_dump()

    @staticmethod
    def _safe_request_id(data: Any) -> Optional[Union[int, str]]:
        """Best-effort extraction of a usable JSON-RPC ``id`` from a raw payload.

        Returns ``None`` when ``data`` is not a dict, or when its ``id`` is not
        an int/str (bools are rejected even though they subclass int).
        """
        request_id = data.get("id") if isinstance(data, dict) else None
        if isinstance(request_id, (int, str)) and not isinstance(request_id, bool):
            return request_id
        return None

    def _register_handler(self, method: str, func: Callable, handler_name: str, handler_type: str = "handler"):
        """Store ``func`` as the handler for ``method``.

        Args:
            method: JSON-RPC method name, e.g. ``"tasks/send"``.
            func: User handler.
            handler_name: Decorator name, used only in the error message.
            handler_type: ``"handler"`` stores into ``self.handlers``; any other
                value stores into ``self.subscriptions``.

        Raises:
            RuntimeError: if ``method`` already has a handler on this instance.
        """
        if method in self._registered_decorators:
            raise RuntimeError(
                f"@{handler_name} decorator for method '{method}' "
                f"can only be used once per FastA2A instance"
            )

        if handler_type == "handler":
            self.handlers[method] = func
        else:
            self.subscriptions[method] = func

        self._registered_decorators.add(method)

    def on_send_task(self) -> Callable:
        """Decorator registering the ``tasks/send`` handler."""
        def decorator(func: Callable[[SendTaskRequest], Any]) -> Callable:
            self._register_handler("tasks/send", func, "on_send_task", "handler")
            return func
        return decorator

    def on_send_subscribe_task(self) -> Callable:
        """Decorator registering the ``tasks/sendSubscribe`` (streaming) handler.

        The handler is iterated with ``async for``, so it must return an async
        iterable.
        """
        def decorator(func: Callable) -> Callable:
            self._register_handler("tasks/sendSubscribe", func, "on_send_subscribe_task", "subscription")
            return func
        return decorator

    def task_get(self):
        """Decorator registering the ``tasks/get`` handler."""
        def decorator(func: Callable[[GetTaskRequest], Task]):
            self._register_handler("tasks/get", func, "task_get", "handler")
            return func
        return decorator

    def task_cancel(self):
        """Decorator registering the ``tasks/cancel`` handler."""
        def decorator(func: Callable[[CancelTaskRequest], Task]):
            self._register_handler("tasks/cancel", func, "task_cancel", "handler")
            return func
        return decorator

    async def process_request(self, request_data: dict) -> JSONRPCResponse:
        """Dispatch a parsed JSON-RPC request dict to the matching handler.

        Returns a JSON-RPC response model. For ``tasks/sendSubscribe`` it
        returns a ``StreamingResponse``. An unknown method yields a -32601
        error response.
        """
        request_id = request_data.get("id")
        try:
            method = request_data.get("method")
            if method == "tasks/send":
                return self._handle_send_task(request_data)
            if method == "tasks/sendSubscribe":
                return await self._handle_subscribe_task(request_data)
            if method == "tasks/get":
                return self._handle_get_task(request_data)
            if method == "tasks/cancel":
                return self._handle_cancel_task(request_data)
            return self._error_response(request_id, -32601, "Method not found")
        except ValidationError as e:
            # -32600 is "Invalid Request"; use the matching typed error.
            return JSONRPCResponse(
                id=request_id,
                error=InvalidRequestError(data=e.errors()),
            )

    def _error_response(self, request_id: Optional[Union[int, str]], code: int, message: str, data: Any = None) -> JSONRPCResponse:
        """Build a JSON-RPC response that carries only an error object."""
        return JSONRPCResponse(
            id=request_id,
            error=JSONRPCError(code=code, message=message, data=data),
        )

    @staticmethod
    def _error_from_exception(exc: Exception) -> JSONRPCError:
        """Map an exception raised by a user handler to a JSON-RPC error.

        ``HTTPException`` with status 405 becomes ``UnsupportedOperationError``.
        Anything else becomes ``InternalError`` carrying ``str(exc)``.
        """
        if isinstance(exc, HTTPException) and exc.status_code == 405:
            return UnsupportedOperationError()
        return InternalError(data=str(exc))

    def _handle_send_task(self, request_data: dict) -> SendTaskResponse:
        """Handle ``tasks/send``: validate, call the handler, wrap the result in a Task."""
        try:
            request = SendTaskRequest.model_validate(request_data)
        except ValidationError as e:
            return SendTaskResponse(
                id=request_data.get("id"),
                error=InvalidRequestError(data=e.errors()),
            )

        handler = self.handlers.get("tasks/send")
        if not handler:
            return SendTaskResponse(id=request.id, error=MethodNotFoundError())

        try:
            raw_result = handler(request)

            # A fully formed response from the handler is passed through untouched.
            if isinstance(raw_result, SendTaskResponse):
                return raw_result

            task = self._build_task(
                content=raw_result,
                task_id=request.params.id,
                session_id=request.params.sessionId,
                default_status=TaskState.COMPLETED,
                metadata=request.params.metadata or {},
            )
            return SendTaskResponse(id=request.id, result=task)
        except Exception as e:
            return SendTaskResponse(id=request.id, error=self._error_from_exception(e))

    async def _handle_subscribe_task(self, request_data: dict) -> Union[StreamingResponse, SendTaskStreamingResponse]:
        """Handle ``tasks/sendSubscribe`` by streaming handler events as SSE.

        Validation failures and a missing handler return a single
        ``SendTaskStreamingResponse`` error. Otherwise a ``text/event-stream``
        ``StreamingResponse`` is returned. Each SSE ``data:`` line holds one
        serialised ``SendTaskStreamingResponse``.
        """
        try:
            request = SendTaskStreamingRequest.model_validate(request_data)
        except ValidationError as e:
            return SendTaskStreamingResponse(
                jsonrpc="2.0",
                id=request_data.get("id"),
                error=InvalidRequestError(data=e.errors()),
            )

        handler = self.subscriptions.get("tasks/sendSubscribe")
        if not handler:
            return SendTaskStreamingResponse(
                jsonrpc="2.0",
                id=request.id,
                error=MethodNotFoundError(),
            )

        async def event_generator():
            """Yield each streamed event as a JSON-serialised JSON-RPC response."""
            try:
                raw_events = handler(request)
                normalized_events = self._normalize_subscription_events(request.params, raw_events)

                async for item in normalized_events:
                    # A bad individual event becomes an error event; the stream continues.
                    try:
                        if isinstance(item, SendTaskStreamingResponse):
                            yield item.model_dump_json()
                            continue

                        if not isinstance(item, (TaskStatusUpdateEvent, TaskArtifactUpdateEvent)):
                            raise ValueError(f"Invalid event type: {type(item).__name__}")

                        yield SendTaskStreamingResponse(
                            jsonrpc="2.0",
                            id=request.id,
                            result=item,
                        ).model_dump_json()
                    except Exception as e:
                        yield SendTaskStreamingResponse(
                            jsonrpc="2.0",
                            id=request.id,
                            error=InternalError(data=str(e)),
                        ).model_dump_json()

            except Exception as e:
                # The handler or the stream itself failed: emit one final error event.
                error = self._error_from_exception(e)
                if "not found" in str(e).lower():
                    error = TaskNotFoundError()
                yield SendTaskStreamingResponse(
                    jsonrpc="2.0",
                    id=request.id,
                    error=error,
                ).model_dump_json()

        async def sse_stream():
            async for chunk in event_generator():
                # Each chunk is already JSON; SSE frames it as "data: <payload>\n\n".
                yield (f"data: {chunk}\n\n").encode("utf-8")

        return StreamingResponse(
            sse_stream(),
            media_type="text/event-stream; charset=utf-8",
        )

    def _handle_get_task(self, request_data: dict) -> GetTaskResponse:
        """Handle ``tasks/get``: validate, call the handler, return the Task."""
        try:
            request = GetTaskRequest.model_validate(request_data)
        except ValidationError as e:
            return GetTaskResponse(
                id=request_data.get("id"),
                error=InvalidRequestError(data=e.errors()),
            )

        handler = self.handlers.get("tasks/get")
        if not handler:
            return GetTaskResponse(id=request.id, error=MethodNotFoundError())

        try:
            raw_result = handler(request)

            if isinstance(raw_result, GetTaskResponse):
                return self._validate_response_id(raw_result, request)

            task = self._build_task(
                content=raw_result,
                task_id=request.params.id,
                default_status=TaskState.COMPLETED,
                metadata=request.params.metadata or {},
            )
            return self._finalize_task_response(request, task)
        except Exception as e:
            return GetTaskResponse(id=request.id, error=self._error_from_exception(e))

    def _handle_cancel_task(self, request_data: dict) -> CancelTaskResponse:
        """Handle ``tasks/cancel``: validate, call the handler, check cancel state."""
        try:
            request = CancelTaskRequest.model_validate(request_data)
        except ValidationError as e:
            return CancelTaskResponse(
                id=request_data.get("id"),
                error=InvalidRequestError(data=e.errors()),
            )

        handler = self.handlers.get("tasks/cancel")
        if not handler:
            return CancelTaskResponse(id=request.id, error=MethodNotFoundError())

        try:
            raw_result = handler(request)

            if isinstance(raw_result, CancelTaskResponse):
                return self._validate_response_id(raw_result, request)

            if isinstance(raw_result, A2AStatus):
                task = self._build_task_from_status(
                    status=raw_result,
                    task_id=request.params.id,
                    metadata=raw_result.metadata or {},
                )
            else:
                # Plain returns (str, list, dict, ...) have no ``metadata``
                # attribute, so read it defensively.
                task = self._build_task(
                    content=raw_result,
                    task_id=request.params.id,
                    metadata=getattr(raw_result, "metadata", None) or {},
                )

            return self._finalize_cancel_response(request, task)
        except Exception as e:
            return CancelTaskResponse(id=request.id, error=self._error_from_exception(e))

    def _normalize_artifacts(self, content: Any) -> List[Artifact]:
        """Normalise handler content into artifacts.

        Kept for compatibility; identical to ``_normalize_content``.
        """
        return self._normalize_content(content)

    def _build_task(
        self,
        content: Any,
        task_id: str,
        session_id: Optional[str] = None,
        default_status: TaskState = TaskState.COMPLETED,
        metadata: Optional[dict] = None
    ) -> Task:
        """Build a ``Task`` from a handler's return value.

        - A ``Task`` is returned unchanged.
        - An ``A2AResponse`` supplies status and content. A missing
          ``session_id`` is replaced by a fresh uuid4.
        - Anything that validates as a ``Task`` (e.g. a dict) is used as-is.
        - Otherwise the content becomes artifacts with status ``default_status``.
        """
        if isinstance(content, Task):
            return content

        if isinstance(content, A2AResponse):
            status = content.status if isinstance(content.status, TaskStatus) \
                else TaskStatus(state=content.status)
            artifacts = self._normalize_content(content.content)
            return Task(
                id=task_id,
                sessionId=session_id or str(uuid4()),
                status=status,
                artifacts=artifacts,
                metadata=metadata or {}
            )

        try:
            return Task.model_validate(content)
        except ValidationError:
            pass

        artifacts = self._normalize_content(content)
        return Task(
            id=task_id,
            sessionId=session_id,
            status=TaskStatus(state=default_status),
            artifacts=artifacts,
            metadata=metadata or {}
        )

    def _build_task_from_status(self, status: A2AStatus, task_id: str, metadata: dict) -> Task:
        """Convert an ``A2AStatus`` into a ``Task`` with empty artifacts/history."""
        return Task(
            id=task_id,
            status=TaskStatus(
                state=TaskState(status.status),
                timestamp=datetime.now()
            ),
            metadata=metadata,
            sessionId="",
            artifacts=[],
            history=[]
        )

    def _normalize_content(self, content: Any) -> List[Artifact]:
        """Turn arbitrary handler content into a list of ``Artifact`` objects.

        - Artifacts (single or a list of them) are used as-is.
        - Mixed lists are flattened into the parts of a single artifact.
        - Strings, parts, and part-shaped dicts become a one-part artifact.
        - Artifact-shaped dicts are validated as artifacts.
        - Anything else becomes a text part holding ``str(content)``.
        """
        if isinstance(content, Artifact):
            return [content]

        if isinstance(content, list):
            if all(isinstance(item, Artifact) for item in content):
                return content
            return [Artifact(parts=self._parts_from_mixed(content))]

        if isinstance(content, str) or isinstance(content, self._PART_TYPES):
            return [Artifact(parts=[self._create_part(content)])]

        if isinstance(content, dict):
            # A dict may describe a single part ({"type": "text", ...}) or a
            # whole artifact ({"parts": [...]}); try the part shape first.
            part = self._validate_part(content)
            if part is not None:
                return [Artifact(parts=[part])]

        try:
            return [Artifact.model_validate(content)]
        except ValidationError:
            return [Artifact(parts=[TextPart(text=str(content))])]

    def _parts_from_mixed(self, items: List[Any]) -> List[Part]:
        """Flatten a mixed list: artifacts contribute their parts, others are converted."""
        parts = []
        for item in items:
            if isinstance(item, Artifact):
                parts.extend(item.parts)
            else:
                parts.append(self._create_part(item))
        return parts

    def _validate_part(self, item: dict) -> Optional[Part]:
        """Validate ``item`` against the discriminated ``Part`` union.

        Returns the validated part, or ``None`` if ``item`` is not a valid part.
        The part is validated through ``Artifact.parts``, because ``Part``
        itself is an ``Annotated`` alias without ``model_validate``.
        """
        try:
            return Artifact.model_validate({"parts": [item]}).parts[0]
        except ValidationError:
            return None

    def _create_part(self, item: Any) -> Part:
        """Convert a primitive or part-like value into a concrete Part model."""
        if isinstance(item, self._PART_TYPES):
            return item

        if isinstance(item, str):
            return TextPart(text=item)

        if isinstance(item, dict):
            part = self._validate_part(item)
            if part is not None:
                return part

        return TextPart(text=str(item))

    def _validate_response_id(self, response: Union[SendTaskResponse, GetTaskResponse], request) -> Union[SendTaskResponse, GetTaskResponse]:
        """Check that a handler-built response's task id matches the request.

        On mismatch, returns a response of the same type carrying
        ``InvalidParamsError``.
        """
        if response.result and response.result.id != request.params.id:
            return type(response)(
                id=request.id,
                error=InvalidParamsError(
                    data=f"Task ID mismatch: {response.result.id} vs {request.params.id}"
                )
            )
        return response

    def _finalize_task_response(self, request: GetTaskRequest, task: Task) -> GetTaskResponse:
        """Validate the task id and apply ``historyLength`` for ``tasks/get``."""
        if task.id != request.params.id:
            return GetTaskResponse(
                id=request.id,
                error=InvalidParamsError(
                    data=f"Task ID mismatch: {task.id} vs {request.params.id}"
                )
            )

        # Trim on a copy: ``task`` may be the handler's own stored object, and
        # assigning to it would permanently truncate the stored history.
        if request.params.historyLength and task.history:
            task = task.model_copy(
                update={"history": task.history[-request.params.historyLength:]}
            )

        return GetTaskResponse(id=request.id, result=task)

    def _finalize_cancel_response(self, request: CancelTaskRequest, task: Task) -> CancelTaskResponse:
        """Validate the task id and require a canceled/completed state for ``tasks/cancel``."""
        if task.id != request.params.id:
            return CancelTaskResponse(
                id=request.id,
                error=InvalidParamsError(
                    data=f"Task ID mismatch: {task.id} vs {request.params.id}"
                )
            )

        if task.status.state not in [TaskState.CANCELED, TaskState.COMPLETED]:
            return CancelTaskResponse(
                id=request.id,
                error=TaskNotCancelableError()
            )

        return CancelTaskResponse(id=request.id, result=task)

    async def _normalize_subscription_events(self, params: TaskSendParams, events: AsyncGenerator) -> AsyncGenerator[Union[SendTaskStreamingResponse, TaskStatusUpdateEvent, TaskArtifactUpdateEvent], None]:
        """Convert items from a streaming handler into protocol events.

        - ``SendTaskStreamingResponse`` and protocol events pass through.
        - ``A2AStatus`` becomes ``TaskStatusUpdateEvent``.
        - Content (str/bytes/parts/artifacts/lists/``A2AStreamResponse``)
          becomes ``TaskArtifactUpdateEvent``.
        - Unsupported types produce an ``InvalidParamsError`` response.
        """
        # Per-artifact-index bookkeeping across chunks.
        artifact_state = defaultdict(lambda: {"index": 0, "last_chunk": False})

        async for item in events:
            if isinstance(item, SendTaskStreamingResponse):
                yield item
                continue

            if isinstance(item, A2AStatus):
                yield TaskStatusUpdateEvent(
                    id=params.id,
                    status=TaskStatus(
                        state=TaskState(item.status),
                        timestamp=datetime.now()
                    ),
                    final=item.final or (item.status.lower() == TaskState.COMPLETED),
                    metadata=item.metadata
                )

            elif isinstance(item, (A2AStreamResponse, str, bytes, TextPart, FilePart, DataPart, Artifact, list)):
                if not isinstance(item, A2AStreamResponse):
                    item = A2AStreamResponse(content=item)

                parts = []
                content = item.content

                if isinstance(content, str):
                    parts.append(TextPart(text=content))
                elif isinstance(content, bytes):
                    # NOTE(review): A2AStreamResponse.content has no bytes
                    # member, so pydantic likely coerces bytes to str before
                    # this point; this branch may be unreachable — confirm.
                    parts.append(FilePart(file=FileContent(bytes=content)))
                elif isinstance(content, (TextPart, FilePart, DataPart)):
                    parts.append(content)
                elif isinstance(content, Artifact):
                    parts = content.parts
                elif isinstance(content, list):
                    # Elements of other types are silently dropped.
                    for elem in content:
                        if isinstance(elem, str):
                            parts.append(TextPart(text=elem))
                        elif isinstance(elem, (TextPart, FilePart, DataPart)):
                            parts.append(elem)
                        elif isinstance(elem, Artifact):
                            parts.extend(elem.parts)

                artifact_idx = item.index
                state = artifact_state[artifact_idx]

                yield TaskArtifactUpdateEvent(
                    id=params.id,
                    artifact=Artifact(
                        parts=parts,
                        index=artifact_idx,
                        # NOTE(review): state["index"] starts at 0, so the first
                        # chunk of artifact 0 is flagged append=True — confirm intent.
                        append=item.append or state["index"] == artifact_idx,
                        lastChunk=item.final or state["last_chunk"],
                        metadata=item.metadata
                    )
                )

                if item.final:
                    state["last_chunk"] = True
                    state["index"] += 1

            elif isinstance(item, (TaskStatusUpdateEvent, TaskArtifactUpdateEvent)):
                yield item

            else:
                yield SendTaskStreamingResponse(
                    jsonrpc="2.0",
                    id=params.id,  # the JSON-RPC request id is not available here
                    error=InvalidParamsError(
                        data=f"Unsupported event type: {type(item).__name__}"
                    )
                )

    def configure(self, **kwargs):
        """Update ``server_config`` (``host``, ``port``, ``reload``) used by ``run``."""
        self.server_config.update(kwargs)

    def add_cors_middleware(self, **kwargs):
        """Add Starlette's ``CORSMiddleware``, forwarding only non-None kwargs."""
        self.app.add_middleware(
            CORSMiddleware,
            **{k: v for k, v in kwargs.items() if v is not None}
        )

    def run(self):
        """Serve ``self.app`` with uvicorn using ``server_config``.

        NOTE(review): uvicorn's reload mode expects an import string rather
        than an app object, so ``reload=True`` may not work here — confirm.
        """
        uvicorn.run(
            self.app,
            host=self.server_config["host"],
            port=self.server_config["port"],
            reload=self.server_config["reload"]
        )
fasta2a/types.py ADDED
@@ -0,0 +1,424 @@
1
+ from typing import Union, Any
2
+ from pydantic import BaseModel, Field, TypeAdapter
3
+ from typing import Literal, List, Annotated, Optional
4
+ from datetime import datetime
5
+ from pydantic import model_validator, ConfigDict, field_serializer, field_validator
6
+ from uuid import uuid4
7
+ from enum import Enum
8
+ from typing_extensions import Self
9
+
10
+
11
class TaskState(str, Enum):
    # Lifecycle states of an A2A task. Members are also str instances, so
    # TaskState.COMPLETED == "completed" holds.
    SUBMITTED = "submitted"
    WORKING = "working"
    INPUT_REQUIRED = "input-required"  # hyphenated on the wire
    COMPLETED = "completed"
    CANCELED = "canceled"
    FAILED = "failed"
    UNKNOWN = "unknown"
19
+
20
+
21
# --- Message parts ---------------------------------------------------------

class TextPart(BaseModel):
    # Plain-text part; ``type`` is the discriminator for ``Part``.
    type: Literal["text"] = "text"
    text: str
    metadata: Optional[dict[str, Any]] = None


class FileContent(BaseModel):
    # File payload: exactly one of ``bytes`` / ``uri`` must be non-empty.
    name: Optional[str] = None
    mimeType: Optional[str] = None
    bytes: Optional[str] = None
    uri: Optional[str] = None

    @model_validator(mode="after")
    def check_content(self) -> Self:
        """Reject payloads with neither or both of ``bytes`` and ``uri`` set."""
        has_bytes = bool(self.bytes)
        has_uri = bool(self.uri)
        if not has_bytes and not has_uri:
            raise ValueError("Either 'bytes' or 'uri' must be present in the file data")
        if has_bytes and has_uri:
            raise ValueError(
                "Only one of 'bytes' or 'uri' can be present in the file data"
            )
        return self


class FilePart(BaseModel):
    type: Literal["file"] = "file"
    file: FileContent
    metadata: Optional[dict[str, Any]] = None


class DataPart(BaseModel):
    type: Literal["data"] = "data"
    data: dict[str, Any]
    metadata: Optional[dict[str, Any]] = None


# Discriminated union of all part kinds, selected by the ``type`` field.
# This is an Annotated alias, not a class: it cannot be used with isinstance.
Part = Annotated[Union[TextPart, FilePart, DataPart], Field(discriminator="type")]


# --- Messages, status, artifacts, tasks ------------------------------------

class Message(BaseModel):
    role: Literal["user", "agent"]
    parts: list[Part]
    metadata: Optional[dict[str, Any]] = None


class TaskStatus(BaseModel):
    # ``timestamp`` defaults to the local time at construction and is
    # serialised as an ISO-8601 string.
    state: TaskState
    message: Optional[Message] = None
    timestamp: datetime = Field(default_factory=datetime.now)

    @field_serializer("timestamp")
    def serialize_dt(self, dt: datetime, _info):
        """Emit the timestamp in ISO-8601 form."""
        iso_text = dt.isoformat()
        return iso_text


class Artifact(BaseModel):
    name: Optional[str] = None
    description: Optional[str] = None
    parts: list[Part]
    metadata: Optional[dict[str, Any]] = None
    index: int = 0
    append: Optional[bool] = None
    lastChunk: Optional[bool] = None


class Task(BaseModel):
    id: str
    sessionId: Optional[str] = None
    status: TaskStatus
    artifacts: Optional[list[Artifact]] = None
    history: Optional[list[Message]] = None
    metadata: Optional[dict[str, Any]] = None
92
+
93
+
94
# --- Streaming update events -----------------------------------------------

class TaskStatusUpdateEvent(BaseModel):
    id: str
    status: TaskStatus
    final: bool = False
    metadata: Optional[dict[str, Any]] = None


class TaskArtifactUpdateEvent(BaseModel):
    id: str
    artifact: Artifact
    metadata: Optional[dict[str, Any]] = None


# --- Push notification / authentication ------------------------------------

class AuthenticationInfo(BaseModel):
    # Unknown extra keys are kept on the model rather than rejected.
    model_config = ConfigDict(extra="allow")

    schemes: list[str]
    credentials: Optional[str] = None


class PushNotificationConfig(BaseModel):
    url: str
    token: Optional[str] = None
    authentication: Optional[AuthenticationInfo] = None


# --- Request parameter objects ---------------------------------------------

class TaskIdParams(BaseModel):
    id: str
    metadata: Optional[dict[str, Any]] = None


class TaskQueryParams(TaskIdParams):
    historyLength: Optional[int] = None


class TaskSendParams(BaseModel):
    # A missing ``sessionId`` is filled with a fresh uuid4 hex string.
    id: str
    sessionId: str = Field(default_factory=lambda: uuid4().hex)
    message: Message
    acceptedOutputModes: Optional[list[str]] = None
    pushNotification: Optional[PushNotificationConfig] = None
    historyLength: Optional[int] = None
    metadata: Optional[dict[str, Any]] = None


class TaskPushNotificationConfig(BaseModel):
    id: str
    pushNotificationConfig: PushNotificationConfig
142
+
143
+
144
+ ## Custom Mixins
145
+
146
class ContentMixin(BaseModel):
    @property
    def content(self) -> Optional[List[Part]]:
        """Shortcut to the message parts of a request.

        ``self.params.message.parts`` is preferred, then ``self.message.parts``.
        If the model has no ``params`` attribute, or any step of the lookup
        raises ``AttributeError``, the result is ``None``.
        """
        try:
            params = self.params
            if hasattr(params, 'message'):
                return params.message.parts
            if hasattr(self, 'message'):
                return self.message.parts
        except AttributeError:
            return None
        return None
159
+
160
+ ## RPC Messages
161
+
162
+
163
class JSONRPCMessage(BaseModel):
    # Common JSON-RPC 2.0 envelope; ``id`` defaults to a fresh uuid4 hex string.
    jsonrpc: Literal["2.0"] = "2.0"
    id: Union[int, str, None] = Field(default_factory=lambda: uuid4().hex)


class JSONRPCRequest(JSONRPCMessage):
    method: str
    params: Optional[dict[str, Any]] = None


class JSONRPCError(BaseModel):
    code: int
    message: str
    data: Optional[Any] = None


class JSONRPCResponse(JSONRPCMessage):
    result: Optional[Any] = None
    error: Optional[JSONRPCError] = None


# --- Typed request/response pairs, one per A2A method ----------------------

class SendTaskRequest(JSONRPCRequest, ContentMixin):
    method: Literal["tasks/send"] = "tasks/send"
    params: TaskSendParams


class SendTaskResponse(JSONRPCResponse):
    result: Optional[Task] = None


class SendTaskStreamingRequest(JSONRPCRequest, ContentMixin):
    method: Literal["tasks/sendSubscribe"] = "tasks/sendSubscribe"
    params: TaskSendParams


class SendTaskStreamingResponse(JSONRPCResponse):
    result: Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent, None] = None


class GetTaskRequest(JSONRPCRequest, ContentMixin):
    method: Literal["tasks/get"] = "tasks/get"
    params: TaskQueryParams


class GetTaskResponse(JSONRPCResponse):
    result: Optional[Task] = None


class CancelTaskRequest(JSONRPCRequest, ContentMixin):
    method: Literal["tasks/cancel"] = "tasks/cancel"
    params: TaskIdParams


class CancelTaskResponse(JSONRPCResponse):
    result: Optional[Task] = None


class SetTaskPushNotificationRequest(JSONRPCRequest, ContentMixin):
    method: Literal["tasks/pushNotification/set"] = "tasks/pushNotification/set"
    params: TaskPushNotificationConfig


class SetTaskPushNotificationResponse(JSONRPCResponse):
    result: Optional[TaskPushNotificationConfig] = None


class GetTaskPushNotificationRequest(JSONRPCRequest, ContentMixin):
    method: Literal["tasks/pushNotification/get"] = "tasks/pushNotification/get"
    params: TaskIdParams


class GetTaskPushNotificationResponse(JSONRPCResponse):
    result: Optional[TaskPushNotificationConfig] = None


class TaskResubscriptionRequest(JSONRPCRequest, ContentMixin):
    method: Literal["tasks/resubscribe"] = "tasks/resubscribe"
    params: TaskIdParams


# Validator that picks the concrete request class from the ``method`` field.
A2ARequest = TypeAdapter(
    Annotated[
        Union[
            SendTaskRequest,
            GetTaskRequest,
            CancelTaskRequest,
            SetTaskPushNotificationRequest,
            GetTaskPushNotificationRequest,
            TaskResubscriptionRequest,
            SendTaskStreamingRequest,
        ],
        Field(discriminator="method"),
    ]
)
257
+
258
+ ## Error types
259
+
260
+
261
# Standard JSON-RPC 2.0 errors (-32700 .. -32603).

class JSONParseError(JSONRPCError):
    code: int = -32700
    message: str = "Invalid JSON payload"
    data: Optional[Any] = None


class InvalidRequestError(JSONRPCError):
    code: int = -32600
    message: str = "Request payload validation error"
    data: Optional[Any] = None


class MethodNotFoundError(JSONRPCError):
    code: int = -32601
    message: str = "Method not found"
    data: None = None


class InvalidParamsError(JSONRPCError):
    code: int = -32602
    message: str = "Invalid parameters"
    data: Optional[Any] = None


class InternalError(JSONRPCError):
    code: int = -32603
    message: str = "Internal error"
    data: Optional[Any] = None


# A2A-specific errors (-32001 .. -32005); these carry no ``data``.

class TaskNotFoundError(JSONRPCError):
    code: int = -32001
    message: str = "Task not found"
    data: None = None


class TaskNotCancelableError(JSONRPCError):
    code: int = -32002
    message: str = "Task cannot be canceled"
    data: None = None


class PushNotificationNotSupportedError(JSONRPCError):
    code: int = -32003
    message: str = "Push Notification is not supported"
    data: None = None


class UnsupportedOperationError(JSONRPCError):
    code: int = -32004
    message: str = "This operation is not supported"
    data: None = None


class ContentTypeNotSupportedError(JSONRPCError):
    code: int = -32005
    message: str = "Incompatible content types"
    data: None = None
319
+
320
+
321
# --- Agent card (agent self-description) -----------------------------------

class AgentProvider(BaseModel):
    organization: str
    url: Optional[str] = None


class AgentCapabilities(BaseModel):
    # Every capability flag defaults to off.
    streaming: bool = False
    pushNotifications: bool = False
    stateTransitionHistory: bool = False


class AgentAuthentication(BaseModel):
    schemes: list[str]
    credentials: Optional[str] = None


class AgentSkill(BaseModel):
    id: str
    name: str
    description: Optional[str] = None
    tags: Optional[list[str]] = None
    examples: Optional[list[str]] = None
    inputModes: Optional[list[str]] = None
    outputModes: Optional[list[str]] = None


class AgentCard(BaseModel):
    name: str
    description: Optional[str] = None
    url: str
    provider: Optional[AgentProvider] = None
    version: str
    documentationUrl: Optional[str] = None
    capabilities: AgentCapabilities
    authentication: Optional[AgentAuthentication] = None
    # pydantic copies these list defaults per instance, so they are not shared.
    defaultInputModes: list[str] = ["text"]
    defaultOutputModes: list[str] = ["text"]
    skills: list[AgentSkill]
359
+
360
+
361
class A2AClientError(Exception):
    """Root of the A2A client exception hierarchy."""


class A2AClientHTTPError(A2AClientError):
    """Client error carrying an HTTP status code and a message."""

    def __init__(self, status_code: int, message: str):
        rendered = f"HTTP Error {status_code}: {message}"
        super().__init__(rendered)
        self.status_code = status_code
        self.message = message


class A2AClientJSONError(A2AClientError):
    """Client error describing a JSON problem."""

    def __init__(self, message: str):
        rendered = f"JSON Error: {message}"
        super().__init__(rendered)
        self.message = message


class MissingAPIKeyError(Exception):
    """Exception for missing API key."""
382
+
# ---------------------------------------------------------------------------
# Beyond this point, the types are not part of the A2A protocol.
# They are used to help with the implementation of the server.
# ---------------------------------------------------------------------------

class A2AResponse(BaseModel):
    # Convenience return type for handlers: a status plus content that the
    # server turns into Task artifacts. A string ``status`` is converted to a
    # ``TaskStatus`` by ``validate_state``.
    status: Union[TaskStatus, str]
    content: Union[str, List[Any], Part, Artifact, List[Part], List[Artifact]]

    @model_validator(mode="after")
    def validate_state(self) -> 'A2AResponse':
        """Coerce a string ``status`` (case-insensitive) into a ``TaskStatus``.

        Raises:
            ValueError: if the string is not a valid ``TaskState`` value.
        """
        if isinstance(self.status, str):
            try:
                self.status = TaskStatus(state=self.status.lower())
            except ValueError:
                raise ValueError(f"Invalid state: {self.status}")
        return self


class A2AStatus(BaseModel):
    # Status update a handler can return or stream. ``status`` is normalised
    # to a lower-case ``TaskState`` value; a "completed" status always forces
    # ``final`` to True.
    status: str
    metadata: dict[str, Any] | None = None
    final: bool = False

    @field_validator('status')
    def validate_status(cls, v):
        """Lower-case ``v`` and check it is a ``TaskState`` value."""
        valid_states = {e.value for e in TaskState}
        if v.lower() not in valid_states:
            raise ValueError(f"Invalid status: {v}. Valid states: {valid_states}")
        return v.lower()

    @model_validator(mode='after')
    def set_final_for_completed(self) -> 'A2AStatus':
        """Force ``final=True`` when the status is "completed".

        A model-level validator is used because a ``field_validator`` on
        ``final`` does not run when ``final`` is left at its default.
        """
        if self.status == TaskState.COMPLETED.value:
            self.final = True
        return self


class A2AStreamResponse(BaseModel):
    # One chunk of streamed artifact content, addressed by artifact ``index``.
    content: Union[str, Part, List[Union[str, Part]], Artifact]
    index: int = 0
    append: bool = False
    final: bool = False
    metadata: dict[str, Any] | None = None
@@ -0,0 +1,97 @@
1
+ Metadata-Version: 2.4
2
+ Name: smarta2a
3
+ Version: 0.1.0
4
+ Summary: A Python package for creating servers and clients following Google's Agent2Agent protocol
5
+ Project-URL: Homepage, https://github.com/siddharthsma/fasta2a
6
+ Project-URL: Bug Tracker, https://github.com/siddharthsma/fasta2a/issues
7
+ Author-email: Siddharth Ambegaonkar <siddharthsma@gmail.com>
8
+ License-File: LICENSE
9
+ Classifier: License :: OSI Approved :: MIT License
10
+ Classifier: Operating System :: OS Independent
11
+ Classifier: Programming Language :: Python :: 3
12
+ Requires-Python: >=3.8
13
+ Requires-Dist: fastapi
14
+ Requires-Dist: pydantic
15
+ Requires-Dist: sse-starlette
16
+ Requires-Dist: uvicorn
17
+ Description-Content-Type: text/markdown
18
+
19
+ # FastA2A
20
+
21
+ A Python package for creating a server following Google's Agent2Agent protocol
22
+
23
+ ## Features
24
+
25
+ ✅ **Full A2A Protocol Compliance** - Implements all required endpoints and response formats
26
+
27
+ ⚡ **Decorator-Driven Development** - Rapid endpoint configuration with type safety
28
+
29
+ 🧩 **Automatic Protocol Conversion** - Simple returns become valid A2A responses
30
+
31
+ 🔀 **Flexible Response Handling** - Support for Tasks, Artifacts, Streaming, and raw protocol types if needed!
32
+
33
+ 🛡️ **Built-in Validation** - Automatic Pydantic validation of A2A schemas
34
+
35
+ ⚡ **Single File Setup** - Get compliant in <10 lines of code
36
+
37
+ 🌍 **Production Ready** - CORS, async support, and error handling included
38
+
39
+ ## Installation
40
+
41
+ ```bash
42
+ pip install smarta2a
43
+ ```
44
+
45
+ ## Simple Echo Server Implementation
46
+
47
+ ```python
48
+ from fasta2a import FastA2A
49
+
50
+ app = FastA2A("EchoServer")
51
+
52
+ @app.on_send_task()
53
+ def handle_task(request):
54
+ """Echo the input text back as a completed task"""
55
+ input_text = request.content[0].text
56
+ return f"Echo: {input_text}"
57
+
58
+ if __name__ == "__main__":
59
+ app.run()
60
+ ```
61
+
62
+ This automatically constructs the following response:
63
+
64
+ ```json
65
+ {
66
+ "jsonrpc": "2.0",
67
+ "id": "test",
68
+ "result": {
69
+ "id": "echo-task",
70
+ "status": {"state": "completed"},
71
+ "artifacts": [{
72
+ "parts": [{"type": "text", "text": "Echo: Hello!"}]
73
+ }]
74
+ }
75
+ }
76
+ ```
77
+
78
+ ## Development
79
+
80
+ To set up the development environment:
81
+
82
+ ```bash
83
+ # Clone the repository
84
+ git clone https://github.com/siddharthsma/fasta2a.git
85
+ cd fasta2a
86
+
87
+ # Create and activate virtual environment
88
+ python -m venv venv
89
+ source venv/bin/activate # On Windows: venv\Scripts\activate
90
+
91
+ # Install development dependencies
92
+ pip install -e ".[dev]"
93
+ ```
94
+
95
+ ## License
96
+
97
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
@@ -0,0 +1,7 @@
1
+ fasta2a/__init__.py,sha256=lW8fJ0XHZJVZC4Oy18UxJjxtxuSco878tV6wAKiCzw0,150
2
+ fasta2a/server.py,sha256=Ge0eh6Go8P9LUmQwmFysTx2YSvRoWk3UIkVh7o-mAHU,26115
3
+ fasta2a/types.py,sha256=_UuFtOsnHIIqfQ2m_FiIBBp141iYmhpPGgxE0jmHSHg,10807
4
+ smarta2a-0.1.0.dist-info/METADATA,sha256=a8EA5KXN0YIwnkpoLSOJMh2gi8Vpn0XLes6lWvlYa0Y,2469
5
+ smarta2a-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
6
+ smarta2a-0.1.0.dist-info/licenses/LICENSE,sha256=ECMEVHuFkvpEmH-_A9HSxs_UnnsUqpCkiAYNHPCf2z0,1078
7
+ smarta2a-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.27.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Siddharth Ambegaonkar
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.