smarta2a-0.2.2-py3-none-any.whl → smarta2a-0.2.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. smarta2a/__init__.py +1 -1
  2. smarta2a/agent/a2a_agent.py +38 -0
  3. smarta2a/agent/a2a_mcp_server.py +37 -0
  4. smarta2a/archive/mcp_client.py +86 -0
  5. smarta2a/client/a2a_client.py +97 -3
  6. smarta2a/client/smart_mcp_client.py +60 -0
  7. smarta2a/client/tools_manager.py +58 -0
  8. smarta2a/examples/__init__.py +0 -0
  9. smarta2a/examples/echo_server/__init__.py +0 -0
  10. smarta2a/examples/echo_server/curl.txt +1 -0
  11. smarta2a/examples/echo_server/main.py +37 -0
  12. smarta2a/history_update_strategies/__init__.py +8 -0
  13. smarta2a/history_update_strategies/append_strategy.py +10 -0
  14. smarta2a/history_update_strategies/history_update_strategy.py +15 -0
  15. smarta2a/model_providers/__init__.py +5 -0
  16. smarta2a/model_providers/base_llm_provider.py +15 -0
  17. smarta2a/model_providers/openai_provider.py +281 -0
  18. smarta2a/server/handler_registry.py +23 -0
  19. smarta2a/server/server.py +233 -255
  20. smarta2a/server/state_manager.py +34 -0
  21. smarta2a/server/subscription_service.py +109 -0
  22. smarta2a/server/task_service.py +155 -0
  23. smarta2a/state_stores/__init__.py +8 -0
  24. smarta2a/state_stores/base_state_store.py +20 -0
  25. smarta2a/state_stores/inmemory_state_store.py +21 -0
  26. smarta2a/utils/prompt_helpers.py +38 -0
  27. smarta2a/utils/task_builder.py +153 -0
  28. smarta2a/{common → utils}/task_request_builder.py +1 -1
  29. smarta2a/{common → utils}/types.py +62 -2
  30. {smarta2a-0.2.2.dist-info → smarta2a-0.2.4.dist-info}/METADATA +12 -6
  31. smarta2a-0.2.4.dist-info/RECORD +36 -0
  32. smarta2a-0.2.2.dist-info/RECORD +0 -12
  33. smarta2a/{common → utils}/__init__.py +0 -0
  34. {smarta2a-0.2.2.dist-info → smarta2a-0.2.4.dist-info}/WHEEL +0 -0
  35. {smarta2a-0.2.2.dist-info → smarta2a-0.2.4.dist-info}/licenses/LICENSE +0 -0
smarta2a/server/server.py CHANGED
@@ -1,3 +1,4 @@
1
+ # Library imports
1
2
  from typing import Callable, Any, Optional, Dict, Union, List, AsyncGenerator
2
3
  import json
3
4
  from datetime import datetime
@@ -10,7 +11,15 @@ import uvicorn
10
11
  from fastapi.responses import StreamingResponse
11
12
  from uuid import uuid4
12
13
 
13
- from smarta2a.common.types import (
14
+ # Local imports
15
+ from smarta2a.server.handler_registry import HandlerRegistry
16
+ from smarta2a.server.state_manager import StateManager
17
+ from smarta2a.state_stores.base_state_store import BaseStateStore
18
+ from smarta2a.history_update_strategies.history_update_strategy import HistoryUpdateStrategy
19
+ from smarta2a.history_update_strategies.append_strategy import AppendStrategy
20
+ from smarta2a.utils.task_builder import TaskBuilder
21
+
22
+ from smarta2a.utils.types import (
14
23
  JSONRPCResponse,
15
24
  Task,
16
25
  Artifact,
@@ -19,6 +28,7 @@ from smarta2a.common.types import (
19
28
  FileContent,
20
29
  DataPart,
21
30
  Part,
31
+ Message,
22
32
  TaskStatus,
23
33
  TaskState,
24
34
  JSONRPCError,
@@ -50,41 +60,78 @@ from smarta2a.common.types import (
50
60
  SetTaskPushNotificationResponse,
51
61
  GetTaskPushNotificationResponse,
52
62
  TaskPushNotificationConfig,
63
+ StateData
53
64
  )
54
65
 
55
66
  class SmartA2A:
56
- def __init__(self, name: str, **fastapi_kwargs):
67
+ def __init__(self, name: str, state_store: Optional[BaseStateStore] = None, history_strategy: HistoryUpdateStrategy = AppendStrategy(), **fastapi_kwargs):
57
68
  self.name = name
58
- self.handlers: Dict[str, Callable] = {}
59
- self.subscriptions: Dict[str, Callable] = {}
69
+ self.registry = HandlerRegistry()
70
+ self.state_mgr = StateManager(state_store, history_strategy)
60
71
  self.app = FastAPI(title=name, **fastapi_kwargs)
61
72
  self.router = APIRouter()
62
- self._registered_decorators = set()
73
+ self.state_store = state_store
74
+ self.history_strategy = history_strategy
63
75
  self._setup_routes()
64
76
  self.server_config = {
65
77
  "host": "0.0.0.0",
66
78
  "port": 8000,
67
79
  "reload": False
68
80
  }
81
+ self.task_builder = TaskBuilder(default_status=TaskState.COMPLETED)
82
+
83
+ # Add this method to delegate ASGI calls
84
+ async def __call__(self, scope, receive, send):
85
+ return await self.app(scope, receive, send)
69
86
 
87
+ def on_send_task(self):
88
+ def decorator(func: Callable[[SendTaskRequest, Optional[StateData]], Any]) -> Callable:
89
+ self.registry.register("tasks/send", func)
90
+ return func
91
+ return decorator
92
+
93
+ def on_send_subscribe_task(self):
94
+ def decorator(fn: Callable[[SendTaskStreamingRequest, Optional[StateData]], Any]):
95
+ self.registry.register("tasks/sendSubscribe", fn, subscription=True)
96
+ return fn
97
+ return decorator
98
+
99
+ def task_get(self):
100
+ def decorator(fn: Callable[[GetTaskRequest], Any]):
101
+ self.registry.register("tasks/get", fn)
102
+ return fn
103
+ return decorator
104
+
105
+ def task_cancel(self):
106
+ def decorator(fn: Callable[[CancelTaskRequest], Any]):
107
+ self.registry.register("tasks/cancel", fn)
108
+ return fn
109
+ return decorator
110
+
111
+ def set_notification(self):
112
+ def decorator(fn: Callable[[SetTaskPushNotificationRequest], Any]):
113
+ self.registry.register("tasks/pushNotification/set", fn)
114
+ return fn
115
+ return decorator
116
+
117
+ def get_notification(self):
118
+ def decorator(fn: Callable[[GetTaskPushNotificationRequest], Any]):
119
+ self.registry.register("tasks/pushNotification/get", fn)
120
+ return fn
121
+ return decorator
122
+
70
123
 
71
124
  def _setup_routes(self):
72
125
  @self.app.post("/")
73
126
  async def handle_request(request: Request):
74
127
  try:
75
128
  data = await request.json()
76
- request_obj = JSONRPCRequest(**data)
129
+ req = JSONRPCRequest.model_validate(data)
130
+ #request_obj = JSONRPCRequest(**data)
77
131
  except Exception as e:
78
- return JSONRPCResponse(
79
- id=None,
80
- error=JSONRPCError(
81
- code=-32700,
82
- message="Parse error",
83
- data=str(e)
84
- )
85
- ).model_dump()
86
-
87
- response = await self.process_request(request_obj.model_dump())
132
+ return JSONRPCResponse(id=None, error=JSONRPCError(code=-32700, message="Parse error", data=str(e))).model_dump()
133
+
134
+ response = await self.process_request(req)
88
135
 
89
136
  # <-- Accept both SSE‐style responses:
90
137
  if isinstance(response, (EventSourceResponse, StreamingResponse)):
@@ -92,114 +139,117 @@ class SmartA2A:
92
139
 
93
140
  # <-- Everything else is a normal pydantic JSONRPCResponse
94
141
  return response.model_dump()
95
-
96
- def _register_handler(self, method: str, func: Callable, handler_name: str, handler_type: str = "handler"):
97
- """Shared registration logic with duplicate checking"""
98
- if method in self._registered_decorators:
99
- raise RuntimeError(
100
- f"@{handler_name} decorator for method '{method}' "
101
- f"can only be used once per SmartA2A instance"
102
- )
103
-
104
- if handler_type == "handler":
105
- self.handlers[method] = func
106
- else:
107
- self.subscriptions[method] = func
108
-
109
- self._registered_decorators.add(method)
110
-
111
- def on_send_task(self) -> Callable:
112
- def decorator(func: Callable[[SendTaskRequest], Any]) -> Callable:
113
- self._register_handler("tasks/send", func, "on_send_task", "handler")
114
- return func
115
- return decorator
116
-
117
- def on_send_subscribe_task(self) -> Callable:
118
- def decorator(func: Callable) -> Callable:
119
- self._register_handler("tasks/sendSubscribe", func, "on_send_subscribe_task", "subscription")
120
- return func
121
- return decorator
122
-
123
- def task_get(self):
124
- def decorator(func: Callable[[GetTaskRequest], Task]):
125
- self._register_handler("tasks/get", func, "task_get", "handler")
126
- return func
127
- return decorator
128
142
 
129
- def task_cancel(self):
130
- def decorator(func: Callable[[CancelTaskRequest], Task]):
131
- self._register_handler("tasks/cancel", func, "task_cancel", "handler")
132
- return func
133
- return decorator
134
-
135
- def set_notification(self):
136
- def decorator(func: Callable[[SetTaskPushNotificationRequest], None]) -> Callable:
137
- self._register_handler("tasks/pushNotification/set", func, "set_notification", "handler")
138
- return func
139
- return decorator
140
-
141
- def get_notification(self):
142
- def decorator(func: Callable[[GetTaskPushNotificationRequest], Union[TaskPushNotificationConfig, GetTaskPushNotificationResponse]]):
143
- self._register_handler("tasks/pushNotification/get", func, "get_notification", "handler")
144
- return func
145
- return decorator
146
143
 
147
- async def process_request(self, request_data: dict) -> JSONRPCResponse:
144
+ async def process_request(self, request: JSONRPCRequest) -> JSONRPCResponse:
145
+
148
146
  try:
149
- method = request_data.get("method")
147
+ method = request.method
148
+ params = request.params
149
+ state_store = self.state_mgr.get_store()
150
150
  if method == "tasks/send":
151
- return self._handle_send_task(request_data)
151
+ state_data = self.state_mgr.init_or_get(params.get("sessionId"), params.get("message"), params.get("metadata") or {})
152
+ if state_store:
153
+ return self._handle_send_task(request, state_data)
154
+ else:
155
+ return self._handle_send_task(request)
152
156
  elif method == "tasks/sendSubscribe":
153
- return await self._handle_subscribe_task(request_data)
157
+ state_data = self.state_mgr.init_or_get(params.get("sessionId"), params.get("message"), params.get("metadata") or {})
158
+ if state_store:
159
+ return await self._handle_subscribe_task(request, state_data)
160
+ else:
161
+ return await self._handle_subscribe_task(request)
154
162
  elif method == "tasks/get":
155
- return self._handle_get_task(request_data)
163
+ return self._handle_get_task(request)
156
164
  elif method == "tasks/cancel":
157
- return self._handle_cancel_task(request_data)
165
+ return self._handle_cancel_task(request)
158
166
  elif method == "tasks/pushNotification/set":
159
- return self._handle_set_notification(request_data)
167
+ return self._handle_set_notification(request)
160
168
  elif method == "tasks/pushNotification/get":
161
- return self._handle_get_notification(request_data)
169
+ return self._handle_get_notification(request)
162
170
  else:
163
- return self._error_response(
164
- request_data.get("id"),
165
- -32601,
166
- "Method not found"
167
- )
171
+ return JSONRPCResponse(id=request.id, error=MethodNotFoundError()).model_dump()
168
172
  except ValidationError as e:
169
- return self._error_response(
170
- request_data.get("id"),
171
- -32600,
172
- "Invalid params",
173
- e.errors()
174
- )
173
+ return JSONRPCResponse(id=request.id, error=InvalidParamsError(data=e.errors())).model_dump()
174
+ except HTTPException as e:
175
+ err = UnsupportedOperationError() if e.status_code == 405 else InternalError(data=str(e))
176
+ return JSONRPCResponse(id=request.id, error=err).model_dump()
177
+
175
178
 
176
- def _handle_send_task(self, request_data: dict) -> SendTaskResponse:
179
+ def _handle_send_task(self, request_data: JSONRPCRequest, state_data: Optional[StateData] = None) -> SendTaskResponse:
177
180
  try:
178
181
  # Validate request format
179
- request = SendTaskRequest.model_validate(request_data)
180
- handler = self.handlers.get("tasks/send")
182
+ request = SendTaskRequest.model_validate(request_data.model_dump())
183
+ handler = self.registry.get_handler("tasks/send")
181
184
 
182
185
  if not handler:
183
186
  return SendTaskResponse(
184
187
  id=request.id,
185
188
  error=MethodNotFoundError()
186
189
  )
190
+
191
+ user_message = request.params.message
192
+ request_metadata = request.params.metadata or {}
193
+ if state_data:
194
+ session_id = state_data.sessionId
195
+ existing_history = state_data.history.copy() or []
196
+ metadata = state_data.metadata or {} # Request metadata has already been merged so need to do it here
197
+ else:
198
+ session_id = request.params.sessionId or str(uuid4())
199
+ existing_history = [user_message]
200
+ metadata = request_metadata
201
+
187
202
 
188
203
  try:
189
- raw_result = handler(request)
190
-
204
+
205
+ if state_data:
206
+ raw_result = handler(request, state_data)
207
+ else:
208
+ raw_result = handler(request)
209
+
210
+ # Handle direct SendTaskResponse returns
191
211
  if isinstance(raw_result, SendTaskResponse):
192
212
  return raw_result
193
213
 
194
- # Use unified task builder
195
- task = self._build_task(
214
+ # Build task with updated history (before agent response)
215
+ task = self.task_builder.build(
196
216
  content=raw_result,
197
217
  task_id=request.params.id,
198
- session_id=request.params.sessionId,
199
- default_status=TaskState.COMPLETED,
200
- metadata=request.params.metadata or {}
218
+ session_id=session_id, # Always use generated session ID
219
+ metadata=metadata, # Use merged metadata
220
+ history=existing_history # History
201
221
  )
202
222
 
223
+ # Process messages through strategy
224
+ messages = []
225
+ if task.artifacts:
226
+ agent_parts = [p for a in task.artifacts for p in a.parts]
227
+ agent_message = Message(
228
+ role="agent",
229
+ parts=agent_parts,
230
+ metadata=task.metadata
231
+ )
232
+ messages.append(agent_message)
233
+
234
+ final_history = self.history_strategy.update_history(
235
+ existing_history=existing_history,
236
+ new_messages=messages
237
+ )
238
+
239
+ # Update task with final state
240
+ task.history = final_history
241
+
242
+ # State store update (if enabled)
243
+ if self.state_store:
244
+ self.state_store.update_state(
245
+ session_id=session_id,
246
+ state_data=StateData(
247
+ sessionId=session_id,
248
+ history=final_history,
249
+ metadata=metadata # Use merged metadata
250
+ )
251
+ )
252
+
203
253
  return SendTaskResponse(
204
254
  id=request.id,
205
255
  result=task
@@ -229,10 +279,11 @@ class SmartA2A:
229
279
  )
230
280
 
231
281
 
232
- async def _handle_subscribe_task(self, request_data: dict) -> Union[EventSourceResponse, SendTaskStreamingResponse]:
282
+ async def _handle_subscribe_task(self, request_data: JSONRPCRequest, state_data: Optional[StateData] = None) -> Union[EventSourceResponse, SendTaskStreamingResponse]:
233
283
  try:
234
- request = SendTaskStreamingRequest.model_validate(request_data)
235
- handler = self.subscriptions.get("tasks/sendSubscribe")
284
+ request = SendTaskStreamingRequest.model_validate(request_data.model_dump())
285
+ #handler = self.subscriptions.get("tasks/sendSubscribe")
286
+ handler = self.registry.get_subscription("tasks/sendSubscribe")
236
287
 
237
288
  if not handler:
238
289
  return SendTaskStreamingResponse(
@@ -240,15 +291,74 @@ class SmartA2A:
240
291
  id=request.id,
241
292
  error=MethodNotFoundError()
242
293
  )
294
+
295
+ user_message = request.params.message
296
+ request_metadata = request.params.metadata or {}
297
+ if state_data:
298
+ session_id = state_data.sessionId
299
+ existing_history = state_data.history.copy() or []
300
+ metadata = state_data.metadata or {} # Request metadata has already been merged so need to do it here
301
+ else:
302
+ session_id = request.params.sessionId or str(uuid4())
303
+ existing_history = [user_message]
304
+ metadata = request_metadata
305
+
243
306
 
244
307
  async def event_generator():
245
308
 
246
309
  try:
247
- raw_events = handler(request)
310
+
311
+ if state_data:
312
+ raw_events = handler(request, state_data)
313
+ else:
314
+ raw_events = handler(request)
315
+
248
316
  normalized_events = self._normalize_subscription_events(request.params, raw_events)
249
317
 
318
+ # Initialize streaming state
319
+ stream_history = existing_history.copy()
320
+ stream_metadata = metadata.copy()
321
+
250
322
  async for item in normalized_events:
251
323
  try:
324
+
325
+ # Process artifact updates
326
+ if isinstance(item, TaskArtifactUpdateEvent):
327
+ # Create agent message from artifact parts
328
+ agent_message = Message(
329
+ role="agent",
330
+ parts=[p for p in item.artifact.parts],
331
+ metadata=item.artifact.metadata
332
+ )
333
+
334
+ # Update history using strategy
335
+ new_history = self.history_strategy.update_history(
336
+ existing_history=stream_history,
337
+ new_messages=[agent_message]
338
+ )
339
+
340
+ # Merge metadata
341
+ new_metadata = {
342
+ **stream_metadata,
343
+ **(item.artifact.metadata or {})
344
+ }
345
+
346
+ # Update state store if configured
347
+ if self.state_store:
348
+ self.state_store.update_state(
349
+ session_id=session_id,
350
+ state_data=StateData(
351
+ sessionId=session_id,
352
+ history=new_history,
353
+ metadata=new_metadata
354
+ )
355
+ )
356
+
357
+ # Update streaming state
358
+ stream_history = new_history
359
+ stream_metadata = new_metadata
360
+
361
+
252
362
  if isinstance(item, SendTaskStreamingResponse):
253
363
  yield item.model_dump_json()
254
364
  continue
@@ -318,11 +428,11 @@ class SmartA2A:
318
428
  )
319
429
 
320
430
 
321
- def _handle_get_task(self, request_data: dict) -> GetTaskResponse:
431
+ def _handle_get_task(self, request_data: JSONRPCRequest) -> GetTaskResponse:
322
432
  try:
323
433
  # Validate request structure
324
- request = GetTaskRequest.model_validate(request_data)
325
- handler = self.handlers.get("tasks/get")
434
+ request = GetTaskRequest.model_validate(request_data.model_dump())
435
+ handler = self.registry.get_handler("tasks/get")
326
436
 
327
437
  if not handler:
328
438
  return GetTaskResponse(
@@ -337,11 +447,10 @@ class SmartA2A:
337
447
  return self._validate_response_id(raw_result, request)
338
448
 
339
449
  # Use unified task builder with different defaults
340
- task = self._build_task(
450
+ task = self.task_builder.build(
341
451
  content=raw_result,
342
452
  task_id=request.params.id,
343
- default_status=TaskState.COMPLETED,
344
- metadata=request.params.metadata or {}
453
+ metadata=getattr(raw_result, "metadata", {}) or {}
345
454
  )
346
455
 
347
456
  return self._finalize_task_response(request, task)
@@ -370,11 +479,11 @@ class SmartA2A:
370
479
  )
371
480
 
372
481
 
373
- def _handle_cancel_task(self, request_data: dict) -> CancelTaskResponse:
482
+ def _handle_cancel_task(self, request_data: JSONRPCRequest) -> CancelTaskResponse:
374
483
  try:
375
484
  # Validate request structure
376
- request = CancelTaskRequest.model_validate(request_data)
377
- handler = self.handlers.get("tasks/cancel")
485
+ request = CancelTaskRequest.model_validate(request_data.model_dump())
486
+ handler = self.registry.get_handler("tasks/cancel")
378
487
 
379
488
  if not handler:
380
489
  return CancelTaskResponse(
@@ -385,24 +494,26 @@ class SmartA2A:
385
494
  try:
386
495
  raw_result = handler(request)
387
496
 
497
+ cancel_task_builder = TaskBuilder(default_status=TaskState.CANCELED)
388
498
  # Handle direct CancelTaskResponse returns
389
499
  if isinstance(raw_result, CancelTaskResponse):
390
500
  return self._validate_response_id(raw_result, request)
391
501
 
392
502
  # Handle A2AStatus returns
393
503
  if isinstance(raw_result, A2AStatus):
394
- task = self._build_task_from_status(
395
- status=raw_result,
504
+ task = cancel_task_builder.normalize_from_status(
505
+ status=raw_result.status,
396
506
  task_id=request.params.id,
397
- metadata=raw_result.metadata or {}
507
+ metadata=getattr(raw_result, "metadata", {}) or {}
398
508
  )
399
509
  else:
400
510
  # Existing processing for other return types
401
- task = self._build_task(
511
+ task = cancel_task_builder.build(
402
512
  content=raw_result,
403
513
  task_id=request.params.id,
404
- metadata=raw_result.metadata or {}
514
+ metadata=getattr(raw_result, "metadata", {}) or {}
405
515
  )
516
+ print(task)
406
517
 
407
518
  # Final validation and packaging
408
519
  return self._finalize_cancel_response(request, task)
@@ -440,10 +551,10 @@ class SmartA2A:
440
551
  error=InternalError(data=str(e))
441
552
  )
442
553
 
443
- def _handle_set_notification(self, request_data: dict) -> SetTaskPushNotificationResponse:
554
+ def _handle_set_notification(self, request_data: JSONRPCRequest) -> SetTaskPushNotificationResponse:
444
555
  try:
445
- request = SetTaskPushNotificationRequest.model_validate(request_data)
446
- handler = self.handlers.get("tasks/pushNotification/set")
556
+ request = SetTaskPushNotificationRequest.model_validate(request_data.model_dump())
557
+ handler = self.registry.get_handler("tasks/pushNotification/set")
447
558
 
448
559
  if not handler:
449
560
  return SetTaskPushNotificationResponse(
@@ -485,10 +596,10 @@ class SmartA2A:
485
596
  )
486
597
 
487
598
 
488
- def _handle_get_notification(self, request_data: dict) -> GetTaskPushNotificationResponse:
599
+ def _handle_get_notification(self, request_data: JSONRPCRequest) -> GetTaskPushNotificationResponse:
489
600
  try:
490
- request = GetTaskPushNotificationRequest.model_validate(request_data)
491
- handler = self.handlers.get("tasks/pushNotification/get")
601
+ request = GetTaskPushNotificationRequest.model_validate(request_data.model_dump())
602
+ handler = self.registry.get_handler("tasks/pushNotification/get")
492
603
 
493
604
  if not handler:
494
605
  return GetTaskPushNotificationResponse(
@@ -535,139 +646,6 @@ class SmartA2A:
535
646
  error=JSONParseError(data=str(e))
536
647
  )
537
648
 
538
-
539
- def _normalize_artifacts(self, content: Any) -> List[Artifact]:
540
- """Handle both A2AResponse content and regular returns"""
541
- if isinstance(content, Artifact):
542
- return [content]
543
-
544
- if isinstance(content, list):
545
- # Handle list of artifacts
546
- if all(isinstance(item, Artifact) for item in content):
547
- return content
548
-
549
- # Handle mixed parts in list
550
- parts = []
551
- for item in content:
552
- if isinstance(item, Artifact):
553
- parts.extend(item.parts)
554
- else:
555
- parts.append(self._create_part(item))
556
- return [Artifact(parts=parts)]
557
-
558
- # Handle single part returns
559
- if isinstance(content, (str, Part, dict)):
560
- return [Artifact(parts=[self._create_part(content)])]
561
-
562
- # Handle raw artifact dicts
563
- try:
564
- return [Artifact.model_validate(content)]
565
- except ValidationError:
566
- return [Artifact(parts=[TextPart(text=str(content))])]
567
-
568
-
569
- def _build_task(
570
- self,
571
- content: Any,
572
- task_id: str,
573
- session_id: Optional[str] = None,
574
- default_status: TaskState = TaskState.COMPLETED,
575
- metadata: Optional[dict] = None
576
- ) -> Task:
577
- """Universal task construction from various return types."""
578
- if isinstance(content, Task):
579
- return content
580
-
581
- # Handle A2AResponse for sendTask case
582
- if isinstance(content, A2AResponse):
583
- status = content.status if isinstance(content.status, TaskStatus) \
584
- else TaskStatus(state=content.status)
585
- artifacts = self._normalize_content(content.content)
586
- return Task(
587
- id=task_id,
588
- sessionId=session_id or str(uuid4()), # Generate if missing
589
- status=status,
590
- artifacts=artifacts,
591
- metadata=metadata or {}
592
- )
593
-
594
- try: # Attempt direct validation for dicts
595
- return Task.model_validate(content)
596
- except ValidationError:
597
- pass
598
-
599
- # Fallback to content normalization
600
- artifacts = self._normalize_content(content)
601
- return Task(
602
- id=task_id,
603
- sessionId=session_id,
604
- status=TaskStatus(state=default_status),
605
- artifacts=artifacts,
606
- metadata=metadata or {}
607
- )
608
-
609
- def _build_task_from_status(self, status: A2AStatus, task_id: str, metadata: dict) -> Task:
610
- """Convert A2AStatus to a Task with proper cancellation state."""
611
- return Task(
612
- id=task_id,
613
- status=TaskStatus(
614
- state=TaskState(status.status),
615
- timestamp=datetime.now()
616
- ),
617
- metadata=metadata,
618
- # Include empty/default values for required fields
619
- sessionId="",
620
- artifacts=[],
621
- history=[]
622
- )
623
-
624
-
625
- def _normalize_content(self, content: Any) -> List[Artifact]:
626
- """Handle all content types for both sendTask and getTask cases."""
627
- if isinstance(content, Artifact):
628
- return [content]
629
-
630
- if isinstance(content, list):
631
- if all(isinstance(item, Artifact) for item in content):
632
- return content
633
- return [Artifact(parts=self._parts_from_mixed(content))]
634
-
635
- if isinstance(content, (str, Part, dict)):
636
- return [Artifact(parts=[self._create_part(content)])]
637
-
638
- try: # Handle raw artifact dicts
639
- return [Artifact.model_validate(content)]
640
- except ValidationError:
641
- return [Artifact(parts=[TextPart(text=str(content))])]
642
-
643
- def _parts_from_mixed(self, items: List[Any]) -> List[Part]:
644
- """Extract parts from mixed content lists."""
645
- parts = []
646
- for item in items:
647
- if isinstance(item, Artifact):
648
- parts.extend(item.parts)
649
- else:
650
- parts.append(self._create_part(item))
651
- return parts
652
-
653
-
654
- def _create_part(self, item: Any) -> Part:
655
- """Convert primitive types to proper Part models"""
656
- if isinstance(item, (TextPart, FilePart, DataPart)):
657
- return item
658
-
659
- if isinstance(item, str):
660
- return TextPart(text=item)
661
-
662
- if isinstance(item, dict):
663
- try:
664
- return Part.model_validate(item)
665
- except ValidationError:
666
- return TextPart(text=str(item))
667
-
668
- return TextPart(text=str(item))
669
-
670
-
671
649
  # Response validation helper
672
650
  def _validate_response_id(self, response: Union[SendTaskResponse, GetTaskResponse], request) -> Union[SendTaskResponse, GetTaskResponse]:
673
651
  if response.result and response.result.id != request.params.id:
@@ -678,7 +656,7 @@ class SmartA2A:
678
656
  )
679
657
  )
680
658
  return response
681
-
659
+
682
660
  # Might refactor this later
683
661
  def _finalize_task_response(self, request: GetTaskRequest, task: Task) -> GetTaskResponse:
684
662
  """Final validation and processing for getTask responses."""
@@ -721,8 +699,8 @@ class SmartA2A:
721
699
  id=request.id,
722
700
  result=task
723
701
  )
724
-
725
-
702
+
703
+
726
704
  async def _normalize_subscription_events(self, params: TaskSendParams, events: AsyncGenerator) -> AsyncGenerator[Union[SendTaskStreamingResponse, TaskStatusUpdateEvent, TaskArtifactUpdateEvent], None]:
727
705
  artifact_state = defaultdict(lambda: {"index": 0, "last_chunk": False})
728
706