smarta2a 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
smarta2a/server/server.py CHANGED
@@ -1,3 +1,4 @@
1
+ # Library imports
1
2
  from typing import Callable, Any, Optional, Dict, Union, List, AsyncGenerator
2
3
  import json
3
4
  from datetime import datetime
@@ -10,7 +11,15 @@ import uvicorn
10
11
  from fastapi.responses import StreamingResponse
11
12
  from uuid import uuid4
12
13
 
13
- from smarta2a.common.types import (
14
+ # Local imports
15
+ from smarta2a.server.handler_registry import HandlerRegistry
16
+ from smarta2a.server.state_manager import StateManager
17
+ from smarta2a.state_stores.base_state_store import BaseStateStore
18
+ from smarta2a.history_update_strategies.history_update_strategy import HistoryUpdateStrategy
19
+ from smarta2a.history_update_strategies.append_strategy import AppendStrategy
20
+ from smarta2a.utils.task_builder import TaskBuilder
21
+
22
+ from smarta2a.utils.types import (
14
23
  JSONRPCResponse,
15
24
  Task,
16
25
  Artifact,
@@ -19,6 +28,7 @@ from smarta2a.common.types import (
19
28
  FileContent,
20
29
  DataPart,
21
30
  Part,
31
+ Message,
22
32
  TaskStatus,
23
33
  TaskState,
24
34
  JSONRPCError,
@@ -50,41 +60,75 @@ from smarta2a.common.types import (
50
60
  SetTaskPushNotificationResponse,
51
61
  GetTaskPushNotificationResponse,
52
62
  TaskPushNotificationConfig,
63
+ StateData
53
64
  )
54
65
 
55
66
  class SmartA2A:
56
- def __init__(self, name: str, **fastapi_kwargs):
67
+ def __init__(self, name: str, state_store: Optional[BaseStateStore] = None, history_strategy: HistoryUpdateStrategy = AppendStrategy(), **fastapi_kwargs):
57
68
  self.name = name
58
- self.handlers: Dict[str, Callable] = {}
59
- self.subscriptions: Dict[str, Callable] = {}
69
+ self.registry = HandlerRegistry()
70
+ self.state_mgr = StateManager(state_store, history_strategy)
60
71
  self.app = FastAPI(title=name, **fastapi_kwargs)
61
72
  self.router = APIRouter()
62
- self._registered_decorators = set()
73
+ self.state_store = state_store
74
+ self.history_strategy = history_strategy
63
75
  self._setup_routes()
64
76
  self.server_config = {
65
77
  "host": "0.0.0.0",
66
78
  "port": 8000,
67
79
  "reload": False
68
80
  }
81
+ self.task_builder = TaskBuilder(default_status=TaskState.COMPLETED)
69
82
 
83
+
84
+ def on_send_task(self):
85
+ def decorator(func: Callable[[SendTaskRequest, Optional[StateData]], Any]) -> Callable:
86
+ self.registry.register("tasks/send", func)
87
+ return func
88
+ return decorator
89
+
90
+ def on_send_subscribe_task(self):
91
+ def decorator(fn: Callable[[SendTaskStreamingRequest, Optional[StateData]], Any]):
92
+ self.registry.register("tasks/sendSubscribe", fn, subscription=True)
93
+ return fn
94
+ return decorator
95
+
96
+ def task_get(self):
97
+ def decorator(fn: Callable[[GetTaskRequest], Any]):
98
+ self.registry.register("tasks/get", fn)
99
+ return fn
100
+ return decorator
101
+
102
+ def task_cancel(self):
103
+ def decorator(fn: Callable[[CancelTaskRequest], Any]):
104
+ self.registry.register("tasks/cancel", fn)
105
+ return fn
106
+ return decorator
107
+
108
+ def set_notification(self):
109
+ def decorator(fn: Callable[[SetTaskPushNotificationRequest], Any]):
110
+ self.registry.register("tasks/pushNotification/set", fn)
111
+ return fn
112
+ return decorator
113
+
114
+ def get_notification(self):
115
+ def decorator(fn: Callable[[GetTaskPushNotificationRequest], Any]):
116
+ self.registry.register("tasks/pushNotification/get", fn)
117
+ return fn
118
+ return decorator
119
+
70
120
 
71
121
  def _setup_routes(self):
72
122
  @self.app.post("/")
73
123
  async def handle_request(request: Request):
74
124
  try:
75
125
  data = await request.json()
76
- request_obj = JSONRPCRequest(**data)
126
+ req = JSONRPCRequest.model_validate(data)
127
+ #request_obj = JSONRPCRequest(**data)
77
128
  except Exception as e:
78
- return JSONRPCResponse(
79
- id=None,
80
- error=JSONRPCError(
81
- code=-32700,
82
- message="Parse error",
83
- data=str(e)
84
- )
85
- ).model_dump()
86
-
87
- response = await self.process_request(request_obj.model_dump())
129
+ return JSONRPCResponse(id=None, error=JSONRPCError(code=-32700, message="Parse error", data=str(e))).model_dump()
130
+
131
+ response = await self.process_request(req)
88
132
 
89
133
  # <-- Accept both SSE‐style responses:
90
134
  if isinstance(response, (EventSourceResponse, StreamingResponse)):
@@ -92,114 +136,117 @@ class SmartA2A:
92
136
 
93
137
  # <-- Everything else is a normal pydantic JSONRPCResponse
94
138
  return response.model_dump()
95
-
96
- def _register_handler(self, method: str, func: Callable, handler_name: str, handler_type: str = "handler"):
97
- """Shared registration logic with duplicate checking"""
98
- if method in self._registered_decorators:
99
- raise RuntimeError(
100
- f"@{handler_name} decorator for method '{method}' "
101
- f"can only be used once per SmartA2A instance"
102
- )
103
-
104
- if handler_type == "handler":
105
- self.handlers[method] = func
106
- else:
107
- self.subscriptions[method] = func
108
-
109
- self._registered_decorators.add(method)
110
-
111
- def on_send_task(self) -> Callable:
112
- def decorator(func: Callable[[SendTaskRequest], Any]) -> Callable:
113
- self._register_handler("tasks/send", func, "on_send_task", "handler")
114
- return func
115
- return decorator
116
-
117
- def on_send_subscribe_task(self) -> Callable:
118
- def decorator(func: Callable) -> Callable:
119
- self._register_handler("tasks/sendSubscribe", func, "on_send_subscribe_task", "subscription")
120
- return func
121
- return decorator
122
-
123
- def task_get(self):
124
- def decorator(func: Callable[[GetTaskRequest], Task]):
125
- self._register_handler("tasks/get", func, "task_get", "handler")
126
- return func
127
- return decorator
128
139
 
129
- def task_cancel(self):
130
- def decorator(func: Callable[[CancelTaskRequest], Task]):
131
- self._register_handler("tasks/cancel", func, "task_cancel", "handler")
132
- return func
133
- return decorator
134
-
135
- def set_notification(self):
136
- def decorator(func: Callable[[SetTaskPushNotificationRequest], None]) -> Callable:
137
- self._register_handler("tasks/pushNotification/set", func, "set_notification", "handler")
138
- return func
139
- return decorator
140
-
141
- def get_notification(self):
142
- def decorator(func: Callable[[GetTaskPushNotificationRequest], Union[TaskPushNotificationConfig, GetTaskPushNotificationResponse]]):
143
- self._register_handler("tasks/pushNotification/get", func, "get_notification", "handler")
144
- return func
145
- return decorator
146
140
 
147
- async def process_request(self, request_data: dict) -> JSONRPCResponse:
141
+ async def process_request(self, request: JSONRPCRequest) -> JSONRPCResponse:
142
+
148
143
  try:
149
- method = request_data.get("method")
144
+ method = request.method
145
+ params = request.params
146
+ state_store = self.state_mgr.get_store()
150
147
  if method == "tasks/send":
151
- return self._handle_send_task(request_data)
148
+ state_data = self.state_mgr.init_or_get(params.get("sessionId"), params.get("message"), params.get("metadata") or {})
149
+ if state_store:
150
+ return self._handle_send_task(request, state_data)
151
+ else:
152
+ return self._handle_send_task(request)
152
153
  elif method == "tasks/sendSubscribe":
153
- return await self._handle_subscribe_task(request_data)
154
+ state_data = self.state_mgr.init_or_get(params.get("sessionId"), params.get("message"), params.get("metadata") or {})
155
+ if state_store:
156
+ return await self._handle_subscribe_task(request, state_data)
157
+ else:
158
+ return await self._handle_subscribe_task(request)
154
159
  elif method == "tasks/get":
155
- return self._handle_get_task(request_data)
160
+ return self._handle_get_task(request)
156
161
  elif method == "tasks/cancel":
157
- return self._handle_cancel_task(request_data)
162
+ return self._handle_cancel_task(request)
158
163
  elif method == "tasks/pushNotification/set":
159
- return self._handle_set_notification(request_data)
164
+ return self._handle_set_notification(request)
160
165
  elif method == "tasks/pushNotification/get":
161
- return self._handle_get_notification(request_data)
166
+ return self._handle_get_notification(request)
162
167
  else:
163
- return self._error_response(
164
- request_data.get("id"),
165
- -32601,
166
- "Method not found"
167
- )
168
+ return JSONRPCResponse(id=request.id, error=MethodNotFoundError()).model_dump()
168
169
  except ValidationError as e:
169
- return self._error_response(
170
- request_data.get("id"),
171
- -32600,
172
- "Invalid params",
173
- e.errors()
174
- )
170
+ return JSONRPCResponse(id=request.id, error=InvalidParamsError(data=e.errors())).model_dump()
171
+ except HTTPException as e:
172
+ err = UnsupportedOperationError() if e.status_code == 405 else InternalError(data=str(e))
173
+ return JSONRPCResponse(id=request.id, error=err).model_dump()
174
+
175
175
 
176
- def _handle_send_task(self, request_data: dict) -> SendTaskResponse:
176
+ def _handle_send_task(self, request_data: JSONRPCRequest, state_data: Optional[StateData] = None) -> SendTaskResponse:
177
177
  try:
178
178
  # Validate request format
179
- request = SendTaskRequest.model_validate(request_data)
180
- handler = self.handlers.get("tasks/send")
179
+ request = SendTaskRequest.model_validate(request_data.model_dump())
180
+ handler = self.registry.get_handler("tasks/send")
181
181
 
182
182
  if not handler:
183
183
  return SendTaskResponse(
184
184
  id=request.id,
185
185
  error=MethodNotFoundError()
186
186
  )
187
+
188
+ user_message = request.params.message
189
+ request_metadata = request.params.metadata or {}
190
+ if state_data:
191
+ session_id = state_data.sessionId
192
+ existing_history = state_data.history.copy() or []
193
+ metadata = state_data.metadata or {} # Request metadata has already been merged so need to do it here
194
+ else:
195
+ session_id = request.params.sessionId or str(uuid4())
196
+ existing_history = [user_message]
197
+ metadata = request_metadata
198
+
187
199
 
188
200
  try:
189
- raw_result = handler(request)
190
-
201
+
202
+ if state_data:
203
+ raw_result = handler(request, state_data)
204
+ else:
205
+ raw_result = handler(request)
206
+
207
+ # Handle direct SendTaskResponse returns
191
208
  if isinstance(raw_result, SendTaskResponse):
192
209
  return raw_result
193
210
 
194
- # Use unified task builder
195
- task = self._build_task(
211
+ # Build task with updated history (before agent response)
212
+ task = self.task_builder.build(
196
213
  content=raw_result,
197
214
  task_id=request.params.id,
198
- session_id=request.params.sessionId,
199
- default_status=TaskState.COMPLETED,
200
- metadata=request.params.metadata or {}
215
+ session_id=session_id, # Always use generated session ID
216
+ metadata=metadata, # Use merged metadata
217
+ history=existing_history # History
201
218
  )
202
219
 
220
+ # Process messages through strategy
221
+ messages = []
222
+ if task.artifacts:
223
+ agent_parts = [p for a in task.artifacts for p in a.parts]
224
+ agent_message = Message(
225
+ role="agent",
226
+ parts=agent_parts,
227
+ metadata=task.metadata
228
+ )
229
+ messages.append(agent_message)
230
+
231
+ final_history = self.history_strategy.update_history(
232
+ existing_history=existing_history,
233
+ new_messages=messages
234
+ )
235
+
236
+ # Update task with final state
237
+ task.history = final_history
238
+
239
+ # State store update (if enabled)
240
+ if self.state_store:
241
+ self.state_store.update_state(
242
+ session_id=session_id,
243
+ state_data=StateData(
244
+ sessionId=session_id,
245
+ history=final_history,
246
+ metadata=metadata # Use merged metadata
247
+ )
248
+ )
249
+
203
250
  return SendTaskResponse(
204
251
  id=request.id,
205
252
  result=task
@@ -229,10 +276,11 @@ class SmartA2A:
229
276
  )
230
277
 
231
278
 
232
- async def _handle_subscribe_task(self, request_data: dict) -> Union[EventSourceResponse, SendTaskStreamingResponse]:
279
+ async def _handle_subscribe_task(self, request_data: JSONRPCRequest, state_data: Optional[StateData] = None) -> Union[EventSourceResponse, SendTaskStreamingResponse]:
233
280
  try:
234
- request = SendTaskStreamingRequest.model_validate(request_data)
235
- handler = self.subscriptions.get("tasks/sendSubscribe")
281
+ request = SendTaskStreamingRequest.model_validate(request_data.model_dump())
282
+ #handler = self.subscriptions.get("tasks/sendSubscribe")
283
+ handler = self.registry.get_subscription("tasks/sendSubscribe")
236
284
 
237
285
  if not handler:
238
286
  return SendTaskStreamingResponse(
@@ -240,15 +288,74 @@ class SmartA2A:
240
288
  id=request.id,
241
289
  error=MethodNotFoundError()
242
290
  )
291
+
292
+ user_message = request.params.message
293
+ request_metadata = request.params.metadata or {}
294
+ if state_data:
295
+ session_id = state_data.sessionId
296
+ existing_history = state_data.history.copy() or []
297
+ metadata = state_data.metadata or {} # Request metadata has already been merged so need to do it here
298
+ else:
299
+ session_id = request.params.sessionId or str(uuid4())
300
+ existing_history = [user_message]
301
+ metadata = request_metadata
302
+
243
303
 
244
304
  async def event_generator():
245
305
 
246
306
  try:
247
- raw_events = handler(request)
307
+
308
+ if state_data:
309
+ raw_events = handler(request, state_data)
310
+ else:
311
+ raw_events = handler(request)
312
+
248
313
  normalized_events = self._normalize_subscription_events(request.params, raw_events)
249
314
 
315
+ # Initialize streaming state
316
+ stream_history = existing_history.copy()
317
+ stream_metadata = metadata.copy()
318
+
250
319
  async for item in normalized_events:
251
320
  try:
321
+
322
+ # Process artifact updates
323
+ if isinstance(item, TaskArtifactUpdateEvent):
324
+ # Create agent message from artifact parts
325
+ agent_message = Message(
326
+ role="agent",
327
+ parts=[p for p in item.artifact.parts],
328
+ metadata=item.artifact.metadata
329
+ )
330
+
331
+ # Update history using strategy
332
+ new_history = self.history_strategy.update_history(
333
+ existing_history=stream_history,
334
+ new_messages=[agent_message]
335
+ )
336
+
337
+ # Merge metadata
338
+ new_metadata = {
339
+ **stream_metadata,
340
+ **(item.artifact.metadata or {})
341
+ }
342
+
343
+ # Update state store if configured
344
+ if self.state_store:
345
+ self.state_store.update_state(
346
+ session_id=session_id,
347
+ state_data=StateData(
348
+ sessionId=session_id,
349
+ history=new_history,
350
+ metadata=new_metadata
351
+ )
352
+ )
353
+
354
+ # Update streaming state
355
+ stream_history = new_history
356
+ stream_metadata = new_metadata
357
+
358
+
252
359
  if isinstance(item, SendTaskStreamingResponse):
253
360
  yield item.model_dump_json()
254
361
  continue
@@ -318,11 +425,11 @@ class SmartA2A:
318
425
  )
319
426
 
320
427
 
321
- def _handle_get_task(self, request_data: dict) -> GetTaskResponse:
428
+ def _handle_get_task(self, request_data: JSONRPCRequest) -> GetTaskResponse:
322
429
  try:
323
430
  # Validate request structure
324
- request = GetTaskRequest.model_validate(request_data)
325
- handler = self.handlers.get("tasks/get")
431
+ request = GetTaskRequest.model_validate(request_data.model_dump())
432
+ handler = self.registry.get_handler("tasks/get")
326
433
 
327
434
  if not handler:
328
435
  return GetTaskResponse(
@@ -337,10 +444,9 @@ class SmartA2A:
337
444
  return self._validate_response_id(raw_result, request)
338
445
 
339
446
  # Use unified task builder with different defaults
340
- task = self._build_task(
447
+ task = self.task_builder.build(
341
448
  content=raw_result,
342
449
  task_id=request.params.id,
343
- default_status=TaskState.COMPLETED,
344
450
  metadata=request.params.metadata or {}
345
451
  )
346
452
 
@@ -370,11 +476,11 @@ class SmartA2A:
370
476
  )
371
477
 
372
478
 
373
- def _handle_cancel_task(self, request_data: dict) -> CancelTaskResponse:
479
+ def _handle_cancel_task(self, request_data: JSONRPCRequest) -> CancelTaskResponse:
374
480
  try:
375
481
  # Validate request structure
376
- request = CancelTaskRequest.model_validate(request_data)
377
- handler = self.handlers.get("tasks/cancel")
482
+ request = CancelTaskRequest.model_validate(request_data.model_dump())
483
+ handler = self.registry.get_handler("tasks/cancel")
378
484
 
379
485
  if not handler:
380
486
  return CancelTaskResponse(
@@ -391,14 +497,14 @@ class SmartA2A:
391
497
 
392
498
  # Handle A2AStatus returns
393
499
  if isinstance(raw_result, A2AStatus):
394
- task = self._build_task_from_status(
395
- status=raw_result,
500
+ task = self.task_builder.normalize_from_status(
501
+ status=raw_result.status,
396
502
  task_id=request.params.id,
397
503
  metadata=raw_result.metadata or {}
398
504
  )
399
505
  else:
400
506
  # Existing processing for other return types
401
- task = self._build_task(
507
+ task = self.task_builder.build(
402
508
  content=raw_result,
403
509
  task_id=request.params.id,
404
510
  metadata=raw_result.metadata or {}
@@ -440,10 +546,10 @@ class SmartA2A:
440
546
  error=InternalError(data=str(e))
441
547
  )
442
548
 
443
- def _handle_set_notification(self, request_data: dict) -> SetTaskPushNotificationResponse:
549
+ def _handle_set_notification(self, request_data: JSONRPCRequest) -> SetTaskPushNotificationResponse:
444
550
  try:
445
- request = SetTaskPushNotificationRequest.model_validate(request_data)
446
- handler = self.handlers.get("tasks/pushNotification/set")
551
+ request = SetTaskPushNotificationRequest.model_validate(request_data.model_dump())
552
+ handler = self.registry.get_handler("tasks/pushNotification/set")
447
553
 
448
554
  if not handler:
449
555
  return SetTaskPushNotificationResponse(
@@ -485,10 +591,10 @@ class SmartA2A:
485
591
  )
486
592
 
487
593
 
488
- def _handle_get_notification(self, request_data: dict) -> GetTaskPushNotificationResponse:
594
+ def _handle_get_notification(self, request_data: JSONRPCRequest) -> GetTaskPushNotificationResponse:
489
595
  try:
490
- request = GetTaskPushNotificationRequest.model_validate(request_data)
491
- handler = self.handlers.get("tasks/pushNotification/get")
596
+ request = GetTaskPushNotificationRequest.model_validate(request_data.model_dump())
597
+ handler = self.registry.get_handler("tasks/pushNotification/get")
492
598
 
493
599
  if not handler:
494
600
  return GetTaskPushNotificationResponse(
@@ -535,139 +641,6 @@ class SmartA2A:
535
641
  error=JSONParseError(data=str(e))
536
642
  )
537
643
 
538
-
539
- def _normalize_artifacts(self, content: Any) -> List[Artifact]:
540
- """Handle both A2AResponse content and regular returns"""
541
- if isinstance(content, Artifact):
542
- return [content]
543
-
544
- if isinstance(content, list):
545
- # Handle list of artifacts
546
- if all(isinstance(item, Artifact) for item in content):
547
- return content
548
-
549
- # Handle mixed parts in list
550
- parts = []
551
- for item in content:
552
- if isinstance(item, Artifact):
553
- parts.extend(item.parts)
554
- else:
555
- parts.append(self._create_part(item))
556
- return [Artifact(parts=parts)]
557
-
558
- # Handle single part returns
559
- if isinstance(content, (str, Part, dict)):
560
- return [Artifact(parts=[self._create_part(content)])]
561
-
562
- # Handle raw artifact dicts
563
- try:
564
- return [Artifact.model_validate(content)]
565
- except ValidationError:
566
- return [Artifact(parts=[TextPart(text=str(content))])]
567
-
568
-
569
- def _build_task(
570
- self,
571
- content: Any,
572
- task_id: str,
573
- session_id: Optional[str] = None,
574
- default_status: TaskState = TaskState.COMPLETED,
575
- metadata: Optional[dict] = None
576
- ) -> Task:
577
- """Universal task construction from various return types."""
578
- if isinstance(content, Task):
579
- return content
580
-
581
- # Handle A2AResponse for sendTask case
582
- if isinstance(content, A2AResponse):
583
- status = content.status if isinstance(content.status, TaskStatus) \
584
- else TaskStatus(state=content.status)
585
- artifacts = self._normalize_content(content.content)
586
- return Task(
587
- id=task_id,
588
- sessionId=session_id or str(uuid4()), # Generate if missing
589
- status=status,
590
- artifacts=artifacts,
591
- metadata=metadata or {}
592
- )
593
-
594
- try: # Attempt direct validation for dicts
595
- return Task.model_validate(content)
596
- except ValidationError:
597
- pass
598
-
599
- # Fallback to content normalization
600
- artifacts = self._normalize_content(content)
601
- return Task(
602
- id=task_id,
603
- sessionId=session_id,
604
- status=TaskStatus(state=default_status),
605
- artifacts=artifacts,
606
- metadata=metadata or {}
607
- )
608
-
609
- def _build_task_from_status(self, status: A2AStatus, task_id: str, metadata: dict) -> Task:
610
- """Convert A2AStatus to a Task with proper cancellation state."""
611
- return Task(
612
- id=task_id,
613
- status=TaskStatus(
614
- state=TaskState(status.status),
615
- timestamp=datetime.now()
616
- ),
617
- metadata=metadata,
618
- # Include empty/default values for required fields
619
- sessionId="",
620
- artifacts=[],
621
- history=[]
622
- )
623
-
624
-
625
- def _normalize_content(self, content: Any) -> List[Artifact]:
626
- """Handle all content types for both sendTask and getTask cases."""
627
- if isinstance(content, Artifact):
628
- return [content]
629
-
630
- if isinstance(content, list):
631
- if all(isinstance(item, Artifact) for item in content):
632
- return content
633
- return [Artifact(parts=self._parts_from_mixed(content))]
634
-
635
- if isinstance(content, (str, Part, dict)):
636
- return [Artifact(parts=[self._create_part(content)])]
637
-
638
- try: # Handle raw artifact dicts
639
- return [Artifact.model_validate(content)]
640
- except ValidationError:
641
- return [Artifact(parts=[TextPart(text=str(content))])]
642
-
643
- def _parts_from_mixed(self, items: List[Any]) -> List[Part]:
644
- """Extract parts from mixed content lists."""
645
- parts = []
646
- for item in items:
647
- if isinstance(item, Artifact):
648
- parts.extend(item.parts)
649
- else:
650
- parts.append(self._create_part(item))
651
- return parts
652
-
653
-
654
- def _create_part(self, item: Any) -> Part:
655
- """Convert primitive types to proper Part models"""
656
- if isinstance(item, (TextPart, FilePart, DataPart)):
657
- return item
658
-
659
- if isinstance(item, str):
660
- return TextPart(text=item)
661
-
662
- if isinstance(item, dict):
663
- try:
664
- return Part.model_validate(item)
665
- except ValidationError:
666
- return TextPart(text=str(item))
667
-
668
- return TextPart(text=str(item))
669
-
670
-
671
644
  # Response validation helper
672
645
  def _validate_response_id(self, response: Union[SendTaskResponse, GetTaskResponse], request) -> Union[SendTaskResponse, GetTaskResponse]:
673
646
  if response.result and response.result.id != request.params.id:
@@ -678,7 +651,7 @@ class SmartA2A:
678
651
  )
679
652
  )
680
653
  return response
681
-
654
+
682
655
  # Might refactor this later
683
656
  def _finalize_task_response(self, request: GetTaskRequest, task: Task) -> GetTaskResponse:
684
657
  """Final validation and processing for getTask responses."""
@@ -721,8 +694,8 @@ class SmartA2A:
721
694
  id=request.id,
722
695
  result=task
723
696
  )
724
-
725
-
697
+
698
+
726
699
  async def _normalize_subscription_events(self, params: TaskSendParams, events: AsyncGenerator) -> AsyncGenerator[Union[SendTaskStreamingResponse, TaskStatusUpdateEvent, TaskArtifactUpdateEvent], None]:
727
700
  artifact_state = defaultdict(lambda: {"index": 0, "last_chunk": False})
728
701
 
@@ -0,0 +1,34 @@
1
+ # Library imports
2
+ from typing import Optional, Dict, Any
3
+ from uuid import uuid4
4
+
5
+ # Local imports
6
+ from smarta2a.state_stores.base_state_store import BaseStateStore
7
+ from smarta2a.history_update_strategies.history_update_strategy import HistoryUpdateStrategy
8
+ from smarta2a.utils.types import Message, StateData
9
+
10
+ class StateManager:
11
+ def __init__(self, store: Optional[BaseStateStore], history_strategy: HistoryUpdateStrategy):
12
+ self.store = store
13
+ self.strategy = history_strategy
14
+
15
+ def init_or_get(self, session_id: Optional[str], message: Message, metadata: Dict[str, Any]) -> StateData:
16
+ sid = session_id or str(uuid4())
17
+ if not self.store:
18
+ return StateData(sessionId=sid, history=[message], metadata=metadata or {})
19
+ existing = self.store.get_state(sid) or StateData(sessionId=sid, history=[], metadata={})
20
+ existing.history.append(message)
21
+ existing.metadata = {**(existing.metadata or {}), **(metadata or {})}
22
+ self.store.update_state(sid, existing)
23
+ return existing
24
+
25
+ def update(self, state: StateData):
26
+ if self.store:
27
+ self.store.update_state(state.sessionId, state)
28
+
29
+ def get_store(self) -> Optional[BaseStateStore]:
30
+ return self.store
31
+
32
+ def get_strategy(self) -> HistoryUpdateStrategy:
33
+ return self.strategy
34
+