a2a-adapter 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,756 @@
1
+ """
2
+ LangGraph adapter for A2A Protocol.
3
+
4
+ This adapter enables LangGraph compiled workflows to be exposed as A2A-compliant
5
+ agents with support for both streaming and non-streaming modes.
6
+
7
+ Supports two modes:
8
+ - Synchronous (default): Blocks until workflow completes, returns Message
9
+ - Async Task Mode: Returns Task immediately, processes in background, supports polling
10
+ """
11
+
12
+ import asyncio
13
+ import json
14
+ import logging
15
+ import uuid
16
+ from datetime import datetime, timezone
17
+ from typing import Any, AsyncIterator, Dict
18
+
19
+ from a2a.types import (
20
+ Message,
21
+ MessageSendParams,
22
+ Task,
23
+ TaskState,
24
+ TaskStatus,
25
+ TextPart,
26
+ Role,
27
+ Part,
28
+ )
29
+ from ..adapter import BaseAgentAdapter
30
+
31
+ # Lazy import for TaskStore to avoid hard dependency
32
+ try:
33
+ from a2a.server.tasks import TaskStore, InMemoryTaskStore
34
+ _HAS_TASK_STORE = True
35
+ except ImportError:
36
+ _HAS_TASK_STORE = False
37
+ TaskStore = None # type: ignore
38
+ InMemoryTaskStore = None # type: ignore
39
+
40
+ logger = logging.getLogger(__name__)
41
+
42
+
43
class LangGraphAgentAdapter(BaseAgentAdapter):
    """
    Expose a LangGraph compiled workflow (StateGraph.compile()) as an A2A agent.

    Works with LangGraph's CompiledGraph objects and supports three
    execution patterns:

    1. **Synchronous** (default): blocks until the workflow finishes and
       returns a Message with the final result. Best for short workflows
       (< 30 seconds).
    2. **Streaming**: yields intermediate chunks via the graph's astream()
       method, for real-time feedback during execution.
    3. **Async task mode** (``async_mode=True``): immediately returns a Task
       in the "working" state, runs the workflow in the background, and lets
       clients poll ``get_task()`` for updates. Best for long-running
       workflows.

    Example:
        >>> from langgraph.graph import StateGraph
        >>> from typing import TypedDict
        >>>
        >>> class State(TypedDict):
        ...     messages: list
        ...     output: str
        >>>
        >>> def process(state: State) -> State:
        ...     return {"output": f"Processed: {state['messages'][-1]}"}
        >>>
        >>> builder = StateGraph(State)
        >>> builder.add_node("process", process)
        >>> builder.set_entry_point("process")
        >>> builder.set_finish_point("process")
        >>> graph = builder.compile()
        >>>
        >>> adapter = LangGraphAgentAdapter(graph=graph)
    """

    def __init__(
        self,
        graph: Any,  # CompiledGraph; typed loosely to avoid a hard dependency
        input_key: str = "messages",
        output_key: str | None = None,
        state_key: str | None = None,
        async_mode: bool = False,
        task_store: "TaskStore | None" = None,
        async_timeout: int = 300,
    ):
        """
        Initialize the LangGraph adapter.

        Args:
            graph: A LangGraph CompiledGraph (result of StateGraph.compile()).
            input_key: State-dict key used for input messages (default:
                "messages"); use "input" for simple string-input workflows.
            output_key: Key to extract from the final state. When None, common
                keys such as "output", "response" and "messages" are tried.
            state_key: Key used when extracting state for streaming events;
                defaults to output_key (or auto-detection when both are None).
            async_mode: When True, return a Task immediately and run the
                workflow in the background instead of blocking.
            task_store: TaskStore used to persist task state. When omitted in
                async mode, an InMemoryTaskStore is created.
            async_timeout: Background execution timeout in seconds
                (default: 300).

        Raises:
            ImportError: If async_mode is requested but the installed A2A SDK
                lacks task-store support.
        """
        self.graph = graph
        self.input_key = input_key
        self.output_key = output_key
        self.state_key = state_key or output_key

        # Async task mode configuration.
        self.async_mode = async_mode
        self.async_timeout = async_timeout
        self._background_tasks: Dict[str, "asyncio.Task[None]"] = {}
        self._cancelled_tasks: set[str] = set()

        # Async mode needs somewhere to persist task state; sync mode keeps
        # whatever was passed (possibly None).
        if not async_mode:
            self.task_store = task_store  # type: ignore
            return
        if not _HAS_TASK_STORE:
            raise ImportError(
                "Async task mode requires the A2A SDK with task support. "
                "Install with: pip install a2a-sdk"
            )
        self.task_store: "TaskStore" = task_store or InMemoryTaskStore()
138
+ async def handle(self, params: MessageSendParams) -> Message | Task:
139
+ """
140
+ Handle a non-streaming A2A message request.
141
+
142
+ In sync mode (default): Blocks until workflow completes, returns Message.
143
+ In async mode: Returns Task immediately, processes in background.
144
+ """
145
+ if self.async_mode:
146
+ return await self._handle_async(params)
147
+ else:
148
+ return await self._handle_sync(params)
149
+
150
+ async def _handle_sync(self, params: MessageSendParams) -> Message:
151
+ """Handle request synchronously - blocks until workflow completes."""
152
+ framework_input = await self.to_framework(params)
153
+ framework_output = await self.call_framework(framework_input, params)
154
+ result = await self.from_framework(framework_output, params)
155
+
156
+ # In sync mode, always return Message
157
+ if isinstance(result, Task):
158
+ if result.status and result.status.message:
159
+ return result.status.message
160
+ return Message(
161
+ role=Role.agent,
162
+ message_id=str(uuid.uuid4()),
163
+ context_id=result.context_id,
164
+ parts=[Part(root=TextPart(text="Workflow completed"))],
165
+ )
166
+ return result
167
+
168
    async def _handle_async(self, params: MessageSendParams) -> Task:
        """
        Handle request asynchronously - returns Task immediately, processes in background.

        Creates and persists a Task in the "working" state, then launches a
        background coroutine (wrapped in a timeout) that runs the workflow
        and updates the stored task when it finishes.
        """
        # Generate IDs; fall back to a fresh context id when the incoming
        # message does not carry one.
        task_id = str(uuid.uuid4())
        context_id = self._extract_context_id(params) or str(uuid.uuid4())

        # Extract the initial message for history
        initial_message = None
        if hasattr(params, "message") and params.message:
            initial_message = params.message

        # Create initial task with "working" state
        now = datetime.now(timezone.utc).isoformat()
        task = Task(
            id=task_id,
            context_id=context_id,
            status=TaskStatus(
                state=TaskState.working,
                timestamp=now,
            ),
            history=[initial_message] if initial_message else None,
        )

        # Save initial task state
        await self.task_store.save(task)
        logger.debug("Created async task %s with state=working", task_id)

        # Start background processing with timeout
        bg_task = asyncio.create_task(
            self._execute_workflow_with_timeout(task_id, context_id, params)
        )
        # Keep a strong reference so the asyncio task is not garbage-collected
        # before it completes.
        self._background_tasks[task_id] = bg_task

        # Clean up background task reference when done
        def _on_task_done(t: "asyncio.Task[None]") -> None:
            self._background_tasks.pop(task_id, None)
            self._cancelled_tasks.discard(task_id)
            # Calling .exception() also marks the exception as retrieved so
            # asyncio does not emit "exception was never retrieved" warnings.
            if not t.cancelled():
                exc = t.exception()
                if exc:
                    logger.error(
                        "Unhandled exception in background task %s: %s",
                        task_id,
                        exc,
                    )

        bg_task.add_done_callback(_on_task_done)

        return task
219
+
220
    async def _execute_workflow_with_timeout(
        self,
        task_id: str,
        context_id: str,
        params: MessageSendParams,
    ) -> None:
        """Execute the workflow with a timeout wrapper.

        Wraps the background execution in ``asyncio.wait_for``; on timeout
        the stored task is overwritten with a failed state (unless it was
        cancelled first).

        Args:
            task_id: ID of the task being executed.
            context_id: A2A context the task belongs to.
            params: Original message parameters for the workflow.
        """
        try:
            await asyncio.wait_for(
                self._execute_workflow_background(task_id, context_id, params),
                timeout=self.async_timeout,
            )
        except asyncio.TimeoutError:
            # A cancellation can race with the timeout; cancellation wins and
            # the task must not be re-marked as failed afterwards.
            if task_id in self._cancelled_tasks:
                logger.debug("Task %s was cancelled, not marking as failed", task_id)
                return

            logger.error("Task %s timed out after %s seconds", task_id, self.async_timeout)
            now = datetime.now(timezone.utc).isoformat()
            error_message = Message(
                role=Role.agent,
                message_id=str(uuid.uuid4()),
                context_id=context_id,
                parts=[Part(root=TextPart(text=f"Workflow timed out after {self.async_timeout} seconds"))],
            )

            # Persist the terminal failed state so pollers observe it.
            timeout_task = Task(
                id=task_id,
                context_id=context_id,
                status=TaskStatus(
                    state=TaskState.failed,
                    message=error_message,
                    timestamp=now,
                ),
            )
            await self.task_store.save(timeout_task)
256
+
257
    async def _execute_workflow_background(
        self,
        task_id: str,
        context_id: str,
        params: MessageSendParams,
    ) -> None:
        """Execute the LangGraph workflow in the background and update task state.

        Runs the to_framework -> call_framework pipeline, then saves the task
        as completed (with response message and history) or failed.
        Cancellation is honoured at the checkpoints below.

        Args:
            task_id: ID of the task being executed.
            context_id: A2A context the task belongs to.
            params: Original message parameters for the workflow.
        """
        try:
            logger.debug("Starting background execution for task %s", task_id)

            # Execute the workflow
            framework_input = await self.to_framework(params)
            framework_output = await self.call_framework(framework_input, params)

            # Check if task was cancelled during execution
            if task_id in self._cancelled_tasks:
                logger.debug("Task %s was cancelled during execution", task_id)
                return

            # Convert to message
            response_text = self._extract_output_text(framework_output)
            response_message = Message(
                role=Role.agent,
                message_id=str(uuid.uuid4()),
                context_id=context_id,
                parts=[Part(root=TextPart(text=response_text))],
            )

            # Build history: original user message (when present) + response.
            history = []
            if hasattr(params, "message") and params.message:
                history.append(params.message)
            history.append(response_message)

            # Update task to completed state
            now = datetime.now(timezone.utc).isoformat()
            completed_task = Task(
                id=task_id,
                context_id=context_id,
                status=TaskStatus(
                    state=TaskState.completed,
                    message=response_message,
                    timestamp=now,
                ),
                history=history,
            )

            await self.task_store.save(completed_task)
            logger.debug("Task %s completed successfully", task_id)

        except asyncio.CancelledError:
            # Re-raise so asyncio records the task as cancelled, not done.
            logger.debug("Task %s was cancelled", task_id)
            raise

        except Exception as e:
            # cancel_task() may have provoked this failure; the canceled
            # state must not be overwritten with "failed".
            if task_id in self._cancelled_tasks:
                logger.debug("Task %s was cancelled, not marking as failed", task_id)
                return

            logger.error("Task %s failed: %s", task_id, e)
            now = datetime.now(timezone.utc).isoformat()
            error_message = Message(
                role=Role.agent,
                message_id=str(uuid.uuid4()),
                context_id=context_id,
                parts=[Part(root=TextPart(text=f"Workflow failed: {str(e)}"))],
            )

            failed_task = Task(
                id=task_id,
                context_id=context_id,
                status=TaskStatus(
                    state=TaskState.failed,
                    message=error_message,
                    timestamp=now,
                ),
            )

            await self.task_store.save(failed_task)
336
+
337
+ # ---------- Input mapping ----------
338
+
339
+ async def to_framework(self, params: MessageSendParams) -> Dict[str, Any]:
340
+ """
341
+ Convert A2A message parameters to LangGraph state input.
342
+
343
+ Supports two common input patterns:
344
+ 1. Messages-based: {"messages": [{"role": "user", "content": "..."}]}
345
+ 2. Simple input: {"input": "user message text"}
346
+
347
+ Args:
348
+ params: A2A message parameters
349
+
350
+ Returns:
351
+ Dictionary with graph state input
352
+ """
353
+ user_message = ""
354
+
355
+ # Extract message from A2A params (new format with message.parts)
356
+ if hasattr(params, "message") and params.message:
357
+ msg = params.message
358
+ if hasattr(msg, "parts") and msg.parts:
359
+ text_parts = []
360
+ for part in msg.parts:
361
+ # Handle Part(root=TextPart(...)) structure
362
+ if hasattr(part, "root") and hasattr(part.root, "text"):
363
+ text_parts.append(part.root.text)
364
+ # Handle direct TextPart
365
+ elif hasattr(part, "text"):
366
+ text_parts.append(part.text)
367
+ user_message = self._join_text_parts(text_parts)
368
+
369
+ # Legacy support for messages array (deprecated)
370
+ elif getattr(params, "messages", None):
371
+ last = params.messages[-1]
372
+ content = getattr(last, "content", "")
373
+ if isinstance(content, str):
374
+ user_message = content.strip()
375
+ elif isinstance(content, list):
376
+ text_parts = []
377
+ for item in content:
378
+ txt = getattr(item, "text", None)
379
+ if txt and isinstance(txt, str) and txt.strip():
380
+ text_parts.append(txt.strip())
381
+ user_message = self._join_text_parts(text_parts)
382
+
383
+ # Build graph input based on input_key
384
+ if self.input_key == "messages":
385
+ # LangGraph message format (for chat-like workflows)
386
+ # Try to use LangChain message format if available
387
+ try:
388
+ from langchain_core.messages import HumanMessage
389
+ return {"messages": [HumanMessage(content=user_message)]}
390
+ except ImportError:
391
+ # Fallback to dict format
392
+ return {"messages": [{"role": "user", "content": user_message}]}
393
+ else:
394
+ # Simple input key (e.g., "input", "query", etc.)
395
+ return {self.input_key: user_message}
396
+
397
+ @staticmethod
398
+ def _join_text_parts(parts: list[str]) -> str:
399
+ """Join text parts into a single string."""
400
+ if not parts:
401
+ return ""
402
+ text = " ".join(p.strip() for p in parts if p)
403
+ return text.strip()
404
+
405
+ def _extract_context_id(self, params: MessageSendParams) -> str | None:
406
+ """Extract context_id from MessageSendParams."""
407
+ if hasattr(params, "message") and params.message:
408
+ return getattr(params.message, "context_id", None)
409
+ return None
410
+
411
+ # ---------- Framework call ----------
412
+
413
+ async def call_framework(
414
+ self, framework_input: Dict[str, Any], params: MessageSendParams
415
+ ) -> Dict[str, Any]:
416
+ """
417
+ Execute the LangGraph workflow with the provided input.
418
+
419
+ Args:
420
+ framework_input: Input state dictionary for the graph
421
+ params: Original A2A parameters (for context)
422
+
423
+ Returns:
424
+ Final state from the graph execution
425
+
426
+ Raises:
427
+ Exception: If graph execution fails
428
+ """
429
+ logger.debug("Invoking LangGraph with input: %s", framework_input)
430
+ result = await self.graph.ainvoke(framework_input)
431
+ logger.debug("LangGraph returned state with keys: %s", list(result.keys()) if isinstance(result, dict) else type(result).__name__)
432
+ return result
433
+
434
+ # ---------- Output mapping ----------
435
+
436
+ async def from_framework(
437
+ self, framework_output: Dict[str, Any], params: MessageSendParams
438
+ ) -> Message | Task:
439
+ """
440
+ Convert LangGraph final state to A2A Message.
441
+
442
+ Args:
443
+ framework_output: Final state from graph execution
444
+ params: Original A2A parameters
445
+
446
+ Returns:
447
+ A2A Message with the workflow's response
448
+ """
449
+ response_text = self._extract_output_text(framework_output)
450
+ context_id = self._extract_context_id(params)
451
+
452
+ return Message(
453
+ role=Role.agent,
454
+ message_id=str(uuid.uuid4()),
455
+ context_id=context_id,
456
+ parts=[Part(root=TextPart(text=response_text))],
457
+ )
458
+
459
+ def _extract_output_text(self, framework_output: Any) -> str:
460
+ """
461
+ Extract text content from LangGraph state.
462
+
463
+ Handles various output patterns:
464
+ - output_key specified: extract that key
465
+ - "messages" key: extract last message content
466
+ - Common keys: "output", "response", "result", "answer"
467
+
468
+ Args:
469
+ framework_output: Final state from the graph
470
+
471
+ Returns:
472
+ Extracted text string
473
+ """
474
+ if not isinstance(framework_output, dict):
475
+ return str(framework_output)
476
+
477
+ # Use output_key if specified
478
+ if self.output_key and self.output_key in framework_output:
479
+ return self._extract_value_text(framework_output[self.output_key])
480
+
481
+ # Try "messages" key (common in chat workflows)
482
+ if "messages" in framework_output:
483
+ messages = framework_output["messages"]
484
+ if messages and len(messages) > 0:
485
+ last_message = messages[-1]
486
+ return self._extract_message_content(last_message)
487
+
488
+ # Try common output keys
489
+ for key in ["output", "response", "result", "answer", "text", "content"]:
490
+ if key in framework_output:
491
+ return self._extract_value_text(framework_output[key])
492
+
493
+ # Fallback: serialize entire state (excluding internal keys)
494
+ clean_state = {k: v for k, v in framework_output.items() if not k.startswith("_")}
495
+ return json.dumps(clean_state, indent=2, default=str)
496
+
497
+ def _extract_value_text(self, value: Any) -> str:
498
+ """Extract text from a value (handles strings, dicts, lists)."""
499
+ if isinstance(value, str):
500
+ return value
501
+ if isinstance(value, dict):
502
+ # Try common text keys in dict
503
+ for key in ["text", "content", "output"]:
504
+ if key in value:
505
+ return str(value[key])
506
+ return json.dumps(value, indent=2, default=str)
507
+ if isinstance(value, list):
508
+ # Join list items
509
+ return "\n".join(self._extract_value_text(item) for item in value)
510
+ return str(value)
511
+
512
+ def _extract_message_content(self, message: Any) -> str:
513
+ """Extract content from a message object (LangChain or dict)."""
514
+ # LangChain message with content attribute
515
+ if hasattr(message, "content"):
516
+ content = message.content
517
+ if isinstance(content, str):
518
+ return content
519
+ elif isinstance(content, list):
520
+ text_parts = []
521
+ for item in content:
522
+ if isinstance(item, str):
523
+ text_parts.append(item)
524
+ elif hasattr(item, "text"):
525
+ text_parts.append(item.text)
526
+ return " ".join(text_parts)
527
+ return str(content)
528
+
529
+ # Dict message
530
+ if isinstance(message, dict):
531
+ if "content" in message:
532
+ return str(message["content"])
533
+ if "text" in message:
534
+ return str(message["text"])
535
+
536
+ return str(message)
537
+
538
    # ---------- Streaming support ----------

    async def handle_stream(
        self, params: MessageSendParams
    ) -> AsyncIterator[Dict[str, Any]]:
        """
        Handle a streaming A2A message request.

        Uses LangGraph's astream() or astream_events() to yield intermediate
        results as the workflow executes.

        Args:
            params: A2A message parameters

        Yields:
            Server-Sent Events compatible dictionaries with streaming chunks
        """
        framework_input = await self.to_framework(params)
        context_id = self._extract_context_id(params)
        message_id = str(uuid.uuid4())

        logger.debug("Starting LangGraph stream with input: %s", framework_input)

        accumulated_text = ""
        last_state = None

        # Stream from LangGraph.
        # NOTE(review): with the default stream_mode, astream() yields
        # {node_name: state_update} chunks rather than the raw state —
        # confirm _extract_streaming_text handles that shape for this graph.
        async for state in self.graph.astream(framework_input):
            last_state = state

            # Extract text from current state
            text = self._extract_streaming_text(state)

            if text and text != accumulated_text:
                # Calculate the new chunk (delta); when the new text is not a
                # pure extension of what was already sent, emit it wholesale.
                new_content = text[len(accumulated_text):] if text.startswith(accumulated_text) else text
                accumulated_text = text

                if new_content:
                    yield {
                        "event": "message",
                        "data": json.dumps({
                            "type": "content",
                            "content": new_content,
                        }),
                    }

        # Use final state if we have it, otherwise use accumulated
        final_text = self._extract_output_text(last_state) if last_state else accumulated_text

        # Send final message with complete response
        final_message = Message(
            role=Role.agent,
            message_id=message_id,
            context_id=context_id,
            parts=[Part(root=TextPart(text=final_text))],
        )

        # Send completion event
        yield {
            "event": "done",
            "data": json.dumps({
                "status": "completed",
                "message": final_message.model_dump() if hasattr(final_message, "model_dump") else str(final_message),
            }),
        }

        logger.debug("LangGraph stream completed")
606
+
607
+ def _extract_streaming_text(self, state: Any) -> str:
608
+ """
609
+ Extract text from intermediate streaming state.
610
+
611
+ Args:
612
+ state: Intermediate state from astream()
613
+
614
+ Returns:
615
+ Current text content
616
+ """
617
+ if not isinstance(state, dict):
618
+ return str(state)
619
+
620
+ # Use state_key if specified
621
+ if self.state_key and self.state_key in state:
622
+ return self._extract_value_text(state[self.state_key])
623
+
624
+ # Try messages (for chat workflows)
625
+ if "messages" in state:
626
+ messages = state["messages"]
627
+ if messages:
628
+ # Get content from last message
629
+ last = messages[-1]
630
+ return self._extract_message_content(last)
631
+
632
+ # Try common keys
633
+ for key in ["output", "response", "text", "content"]:
634
+ if key in state:
635
+ return self._extract_value_text(state[key])
636
+
637
+ return ""
638
+
639
+ def supports_streaming(self) -> bool:
640
+ """
641
+ Check if the graph supports streaming.
642
+
643
+ Returns:
644
+ True if the graph has an astream method
645
+ """
646
+ return hasattr(self.graph, "astream")
647
+
648
+ # ---------- Async Task Support ----------
649
+
650
+ def supports_async_tasks(self) -> bool:
651
+ """Check if this adapter supports async task execution."""
652
+ return self.async_mode
653
+
654
+ async def get_task(self, task_id: str) -> Task | None:
655
+ """
656
+ Get the current status of a task by ID.
657
+
658
+ Args:
659
+ task_id: The ID of the task to retrieve
660
+
661
+ Returns:
662
+ The Task object with current status, or None if not found
663
+
664
+ Raises:
665
+ RuntimeError: If async mode is not enabled
666
+ """
667
+ if not self.async_mode:
668
+ raise RuntimeError(
669
+ "get_task() is only available in async mode. "
670
+ "Initialize adapter with async_mode=True"
671
+ )
672
+
673
+ task = await self.task_store.get(task_id)
674
+ if task:
675
+ logger.debug("Retrieved task %s with state=%s", task_id, task.status.state)
676
+ else:
677
+ logger.debug("Task %s not found", task_id)
678
+ return task
679
+
680
    async def cancel_task(self, task_id: str) -> Task | None:
        """
        Attempt to cancel a running task.

        Marks the task as cancelled, cancels the background asyncio task if
        it is still running, then persists a "canceled" status.

        Args:
            task_id: The ID of the task to cancel

        Returns:
            The updated Task object with state="canceled", or None if not found

        Raises:
            RuntimeError: If async mode is not enabled
        """
        if not self.async_mode:
            raise RuntimeError(
                "cancel_task() is only available in async mode. "
                "Initialize adapter with async_mode=True"
            )

        # Mark task as cancelled to prevent race conditions
        self._cancelled_tasks.add(task_id)

        # Cancel the background task if still running
        bg_task = self._background_tasks.get(task_id)
        if bg_task and not bg_task.done():
            bg_task.cancel()
            logger.debug("Cancelling background task for %s", task_id)
            try:
                await bg_task
            except asyncio.CancelledError:
                pass
            except Exception:
                # Failures are already persisted by the background worker;
                # nothing further to do here.
                pass

        # Update task state to canceled
        task = await self.task_store.get(task_id)
        if task:
            now = datetime.now(timezone.utc).isoformat()
            canceled_task = Task(
                id=task_id,
                context_id=task.context_id,
                status=TaskStatus(
                    state=TaskState.canceled,
                    timestamp=now,
                ),
                history=task.history,
            )
            await self.task_store.save(canceled_task)
            logger.debug("Task %s marked as canceled", task_id)
            return canceled_task

        return None
729
+
730
+ # ---------- Lifecycle ----------
731
+
732
+ async def close(self) -> None:
733
+ """Cancel pending background tasks."""
734
+ for task_id in self._background_tasks:
735
+ self._cancelled_tasks.add(task_id)
736
+
737
+ tasks_to_cancel = []
738
+ for task_id, bg_task in list(self._background_tasks.items()):
739
+ if not bg_task.done():
740
+ bg_task.cancel()
741
+ tasks_to_cancel.append(bg_task)
742
+ logger.debug("Cancelling background task %s during close", task_id)
743
+
744
+ if tasks_to_cancel:
745
+ await asyncio.gather(*tasks_to_cancel, return_exceptions=True)
746
+
747
+ self._background_tasks.clear()
748
+ self._cancelled_tasks.clear()
749
+
750
+ async def __aenter__(self):
751
+ """Async context manager entry."""
752
+ return self
753
+
754
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
755
+ """Async context manager exit."""
756
+ await self.close()