openbb-pydantic-ai 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openbb_pydantic_ai/__init__.py +52 -0
- openbb_pydantic_ai/_adapter.py +307 -0
- openbb_pydantic_ai/_config.py +58 -0
- openbb_pydantic_ai/_dependencies.py +73 -0
- openbb_pydantic_ai/_event_builder.py +183 -0
- openbb_pydantic_ai/_event_stream.py +363 -0
- openbb_pydantic_ai/_event_stream_components.py +155 -0
- openbb_pydantic_ai/_event_stream_helpers.py +422 -0
- openbb_pydantic_ai/_exceptions.py +61 -0
- openbb_pydantic_ai/_message_transformer.py +127 -0
- openbb_pydantic_ai/_serializers.py +110 -0
- openbb_pydantic_ai/_toolsets.py +264 -0
- openbb_pydantic_ai/_types.py +39 -0
- openbb_pydantic_ai/_utils.py +132 -0
- openbb_pydantic_ai/_widget_registry.py +145 -0
- openbb_pydantic_ai-0.1.1.dist-info/METADATA +139 -0
- openbb_pydantic_ai-0.1.1.dist-info/RECORD +18 -0
- openbb_pydantic_ai-0.1.1.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,363 @@
|
|
|
1
|
+
"""Event stream transformer for OpenBB Workspace SSE protocol."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from collections.abc import AsyncIterator
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
from openbb_ai.helpers import (
|
|
11
|
+
citations,
|
|
12
|
+
cite,
|
|
13
|
+
get_widget_data,
|
|
14
|
+
message_chunk,
|
|
15
|
+
reasoning_step,
|
|
16
|
+
)
|
|
17
|
+
from openbb_ai.models import (
|
|
18
|
+
SSE,
|
|
19
|
+
LlmClientFunctionCallResultMessage,
|
|
20
|
+
MessageArtifactSSE,
|
|
21
|
+
MessageChunkSSE,
|
|
22
|
+
QueryRequest,
|
|
23
|
+
StatusUpdateSSE,
|
|
24
|
+
WidgetRequest,
|
|
25
|
+
)
|
|
26
|
+
from pydantic_ai import DeferredToolRequests
|
|
27
|
+
from pydantic_ai.messages import (
|
|
28
|
+
FunctionToolCallEvent,
|
|
29
|
+
FunctionToolResultEvent,
|
|
30
|
+
RetryPromptPart,
|
|
31
|
+
TextPart,
|
|
32
|
+
TextPartDelta,
|
|
33
|
+
ThinkingPart,
|
|
34
|
+
ThinkingPartDelta,
|
|
35
|
+
ToolReturnPart,
|
|
36
|
+
)
|
|
37
|
+
from pydantic_ai.run import AgentRunResultEvent
|
|
38
|
+
from pydantic_ai.ui import UIEventStream
|
|
39
|
+
|
|
40
|
+
from ._config import GET_WIDGET_DATA_TOOL_NAME
|
|
41
|
+
from ._dependencies import OpenBBDeps
|
|
42
|
+
from ._event_stream_components import (
|
|
43
|
+
CitationCollector,
|
|
44
|
+
ThinkingBuffer,
|
|
45
|
+
ToolCallTracker,
|
|
46
|
+
)
|
|
47
|
+
from ._event_stream_helpers import (
|
|
48
|
+
ToolCallInfo,
|
|
49
|
+
artifact_from_output,
|
|
50
|
+
extract_widget_args,
|
|
51
|
+
handle_generic_tool_result,
|
|
52
|
+
serialized_content_from_result,
|
|
53
|
+
tool_result_events_from_content,
|
|
54
|
+
)
|
|
55
|
+
from ._utils import format_args, normalize_args
|
|
56
|
+
from ._widget_registry import WidgetRegistry
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _encode_sse(event: SSE) -> str:
|
|
60
|
+
payload = event.model_dump()
|
|
61
|
+
return f"event: {payload['event']}\ndata: {payload['data']}\n\n"
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@dataclass
class OpenBBAIEventStream(UIEventStream[QueryRequest, SSE, OpenBBDeps, Any]):
    """Transform native Pydantic AI events into OpenBB SSE events.

    Text parts become message chunks and thinking parts become reasoning
    steps. Tool calls and results become reasoning steps, widget data
    requests or result events. Widget citations are collected during the
    run and emitted once in ``after_stream``.
    """

    # Registry for widget lookup and discovery.
    widget_registry: WidgetRegistry = field(default_factory=WidgetRegistry)
    # Deferred tool results provided upfront; replayed once by ``before_stream``.
    pending_results: list[LlmClientFunctionCallResultMessage] = field(
        default_factory=list
    )

    # State management components
    _tool_calls: ToolCallTracker = field(init=False, default_factory=ToolCallTracker)
    _citations: CitationCollector = field(init=False, default_factory=CitationCollector)
    _thinking: ThinkingBuffer = field(init=False, default_factory=ThinkingBuffer)

    # Simple state flags
    # Set once any text chunk has been sent to the client.
    _has_streamed_text: bool = field(init=False, default=False)
    # Final string output held back until ``after_stream`` when no text was streamed.
    _final_output_pending: str | None = field(init=False, default=None)
    # Guards ``before_stream`` so ``pending_results`` are replayed only once.
    _deferred_results_emitted: bool = field(init=False, default=False)

    def encode_event(self, event: SSE) -> str:
        """Encode an SSE model as a wire-format SSE frame."""
        return _encode_sse(event)

    def _record_text_streamed(self) -> None:
        """Record that text content has been streamed to the client."""
        self._has_streamed_text = True

    async def before_stream(self) -> AsyncIterator[SSE]:
        """Emit tool results for any deferred results provided upfront.

        Runs at most once per stream instance.
        """
        if self._deferred_results_emitted:
            return

        self._deferred_results_emitted = True

        # Process any pending deferred tool results from previous requests
        for result_message in self.pending_results:
            async for event in self._process_deferred_result(result_message):
                yield event

    async def _process_deferred_result(
        self, result_message: LlmClientFunctionCallResultMessage
    ) -> AsyncIterator[SSE]:
        """Process a single deferred result message and yield SSE events.

        When the result maps to a known widget a citation is recorded.
        Otherwise a WARNING reasoning step is emitted first. In both cases the
        result content is then turned into events by ``_widget_result_events``.
        """
        widget = self.widget_registry.find_for_result(result_message)

        widget_args = extract_widget_args(result_message)
        content = serialized_content_from_result(result_message)
        call_info = ToolCallInfo(
            tool_name=result_message.function,
            args=widget_args,
            widget=widget,
        )

        if widget is not None:
            self._citations.add(cite(widget, widget_args))
        else:
            details = format_args(widget_args)
            yield reasoning_step(
                f"Received result for '{result_message.function}' "
                "without widget metadata",
                details=details if details else None,
                event_type="WARNING",
            )

        for event in self._widget_result_events(call_info, content):
            yield event

    async def on_error(self, error: Exception) -> AsyncIterator[SSE]:
        """Surface an error raised during the run as an ERROR reasoning step."""
        yield reasoning_step(str(error), event_type="ERROR")

    async def handle_text_start(
        self, part: TextPart, follows_text: bool = False
    ) -> AsyncIterator[SSE]:
        """Stream the initial content of a text part, if it has any."""
        if part.content:
            self._record_text_streamed()
            yield message_chunk(part.content)

    async def handle_text_delta(self, delta: TextPartDelta) -> AsyncIterator[SSE]:
        """Stream a non-empty text delta."""
        if delta.content_delta:
            self._record_text_streamed()
            yield message_chunk(delta.content_delta)

    async def handle_thinking_start(
        self,
        part: ThinkingPart,
        follows_thinking: bool = False,
    ) -> AsyncIterator[SSE]:
        """Start buffering a new thinking part; nothing is emitted yet."""
        self._thinking.clear()
        if part.content:
            self._thinking.append(part.content)
        return
        # Unreachable; its presence makes this function an async generator.
        yield  # pragma: no cover

    async def handle_thinking_delta(
        self,
        delta: ThinkingPartDelta,
    ) -> AsyncIterator[SSE]:
        """Buffer a non-empty thinking delta; nothing is emitted yet."""
        if delta.content_delta:
            self._thinking.append(delta.content_delta)
        return
        # Unreachable; its presence makes this function an async generator.
        yield  # pragma: no cover

    async def handle_thinking_end(
        self,
        part: ThinkingPart,
        followed_by_thinking: bool = False,
    ) -> AsyncIterator[SSE]:
        """Emit a completed thinking part as one "Thinking" reasoning step.

        The part's final content is preferred. The buffered deltas are the
        fallback. The buffer is cleared afterwards.
        """
        content = part.content or self._thinking.get_content()
        if content:
            yield reasoning_step("Thinking", details={"Thinking": content})

        self._thinking.clear()

    async def handle_run_result(
        self, event: AgentRunResultEvent[Any]
    ) -> AsyncIterator[SSE]:
        """Handle agent run result events, including deferred tool requests.

        Deferred tool requests become widget data requests. Otherwise the
        output is emitted as an artifact when possible. A plain string output
        is held for ``after_stream`` if no text was streamed during the run.
        """
        result = event.result
        output = getattr(result, "output", None)

        if isinstance(output, DeferredToolRequests):
            async for sse_event in self._handle_deferred_tool_requests(output):
                yield sse_event
            return

        artifact = self._artifact_from_output(output)
        if artifact is not None:
            yield artifact
            return

        if isinstance(output, str) and output and not self._has_streamed_text:
            self._final_output_pending = output

    async def _handle_deferred_tool_requests(
        self, output: DeferredToolRequests
    ) -> AsyncIterator[SSE]:
        """Process deferred tool requests and yield widget request events.

        For each deferred call backed by a registered widget:

        - a "Requesting widget" reasoning step is emitted;
        - the call is registered with the tool-call tracker;
        - the call is added to a single ``get_widget_data`` event.

        That event carries the tool-call IDs in ``extra_state``.
        """
        widget_requests: list[WidgetRequest] = []
        tool_call_ids: list[dict[str, Any]] = []

        for call in output.calls:
            widget = self.widget_registry.find_by_tool_name(call.tool_name)
            if widget is None:
                # Nothing can be requested from the client for this call; warn
                # instead of dropping it without any feedback.
                yield reasoning_step(
                    f"No widget found for deferred tool '{call.tool_name}'",
                    event_type="WARNING",
                )
                continue

            args = normalize_args(call.args)
            widget_requests.append(WidgetRequest(widget=widget, input_arguments=args))
            self._tool_calls.register_call(
                tool_call_id=call.tool_call_id,
                tool_name=call.tool_name,
                args=args,
                widget=widget,
            )
            tool_call_ids.append(
                {
                    "tool_call_id": call.tool_call_id,
                    "widget_uuid": str(widget.uuid),
                    "widget_id": widget.widget_id,
                }
            )

            # Create details dict with widget info and arguments for display
            details = {
                "Origin": widget.origin,
                "Widget Id": widget.widget_id,
                **format_args(args),
            }
            yield reasoning_step(
                f"Requesting widget '{widget.name}'",
                details=details,
            )

        if widget_requests:
            sse = get_widget_data(widget_requests)
            sse.data.extra_state = {"tool_calls": tool_call_ids}
            yield sse

    async def handle_function_tool_call(
        self, event: FunctionToolCallEvent
    ) -> AsyncIterator[SSE]:
        """Surface non-widget tool calls as reasoning steps."""

        part = event.part
        tool_name = part.tool_name

        # Widget tools and the get-widget-data tool are not announced here.
        widget = self.widget_registry.find_by_tool_name(tool_name)
        if widget is not None or tool_name == GET_WIDGET_DATA_TOOL_NAME:
            return

        tool_call_id = part.tool_call_id
        if not tool_call_id or self._tool_calls.has_pending(tool_call_id):
            return

        args = normalize_args(part.args)
        self._tool_calls.register_call(
            tool_call_id=tool_call_id,
            tool_name=tool_name,
            args=args,
        )

        formatted_args = format_args(args)
        details = formatted_args if formatted_args else None
        yield reasoning_step(f"Calling tool '{tool_name}'", details=details)

    async def handle_function_tool_result(
        self, event: FunctionToolResultEvent
    ) -> AsyncIterator[SSE]:
        """Translate a tool result into SSE events.

        Retry prompts become ERROR reasoning steps. SSE objects returned by a
        tool are passed through unchanged. Widget results record a citation
        and are rendered via ``_widget_result_events``. Other results go
        through ``handle_generic_tool_result``.
        """
        result_part = event.result

        if isinstance(result_part, RetryPromptPart):
            # This call attempt is over; release its tracker entry so a retried
            # call reusing the same ID is registered (and announced) afresh.
            retry_call_id = getattr(result_part, "tool_call_id", None)
            if retry_call_id:
                self._tool_calls.get_call_info(retry_call_id)
            if result_part.content:
                content = result_part.content
                message = (
                    content
                    if isinstance(content, str)
                    else json.dumps(content, default=str)
                )
                yield reasoning_step(message, event_type="ERROR")
            return

        if not isinstance(result_part, ToolReturnPart):
            return

        tool_call_id = result_part.tool_call_id
        if not tool_call_id:
            return

        # Pop the tracker entry up front so it is released on every path below,
        # including when the tool returned a ready-made SSE event.
        call_info = self._tool_calls.get_call_info(tool_call_id)

        if isinstance(
            result_part.content, (MessageArtifactSSE, MessageChunkSSE, StatusUpdateSSE)
        ):
            yield result_part.content
            return

        if call_info is None:
            return

        if call_info.widget is not None:
            # Collect citation for later emission (at the end)
            self._citations.add(cite(call_info.widget, call_info.args))

            for sse in self._widget_result_events(call_info, result_part.content):
                yield sse
            return

        for sse in handle_generic_tool_result(
            call_info,
            result_part.content,
            mark_streamed_text=self._record_text_streamed,
        ):
            yield sse

    async def after_stream(self) -> AsyncIterator[SSE]:
        """Flush buffered state once the run has finished.

        Leftover thinking is emitted in the same "Thinking" shape that
        ``handle_thinking_end`` uses. A held-back final string output is sent
        if no text was streamed. Collected citations are then emitted as one
        event.
        """
        content = self._thinking.get_content()
        if content:
            yield reasoning_step("Thinking", details={"Thinking": content})
        self._thinking.clear()

        if self._final_output_pending and not self._has_streamed_text:
            yield message_chunk(self._final_output_pending)

        self._final_output_pending = None

        # Emit all citations at the end
        if self._citations.has_citations():
            yield citations(self._citations.get_all())
            self._citations.clear()

    def _artifact_from_output(self, output: Any) -> SSE | None:
        """Create an artifact (chart or table) from agent output if possible."""
        return artifact_from_output(output)

    def _widget_result_events(
        self,
        call_info: ToolCallInfo,
        content: Any,
    ) -> list[SSE]:
        """Emit SSE events for widget results with graceful fallbacks.

        Events derived directly from the content are preferred. If none are
        produced, the generic tool-result handler is used instead.
        """

        events = tool_result_events_from_content(
            content, mark_streamed_text=self._record_text_streamed
        )
        if events:
            return events

        return handle_generic_tool_result(
            call_info,
            content,
            mark_streamed_text=self._record_text_streamed,
        )
|
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
"""State management components for OpenBB event stream."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from openbb_ai.models import Citation, Widget
|
|
9
|
+
|
|
10
|
+
from ._event_stream_helpers import ToolCallInfo
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class ThinkingBuffer:
    """Collect streamed thinking fragments until they are flushed."""

    # Fragments in arrival order; joined only when content is requested.
    _buffer: list[str] = field(default_factory=list, init=False)

    def append(self, content: str) -> None:
        """Store one more thinking fragment at the end of the buffer.

        Parameters
        ----------
        content : str
            Fragment to store
        """
        fragments = self._buffer
        fragments.append(content)

    def get_content(self) -> str:
        """Join every stored fragment into one string.

        Returns
        -------
        str
            All fragments concatenated in arrival order ("" when empty)
        """
        joined = "".join(self._buffer)
        return joined

    def clear(self) -> None:
        """Drop every stored fragment (the list is emptied in place)."""
        del self._buffer[:]

    def is_empty(self) -> bool:
        """Tell whether no fragment is currently stored.

        Returns
        -------
        bool
            True when the buffer holds nothing
        """
        return not self._buffer
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@dataclass
class CitationCollector:
    """Gather citations during a stream so they can be emitted together."""

    # Citations in the order they were added (duplicates are kept).
    _citations: list[Citation] = field(default_factory=list, init=False)

    def add(self, citation: Citation) -> None:
        """Record one more citation.

        Parameters
        ----------
        citation : Citation
            Citation to keep for later emission
        """
        store = self._citations
        store.append(citation)

    def get_all(self) -> list[Citation]:
        """Return a shallow copy of every recorded citation.

        Returns
        -------
        list[Citation]
            A new list; mutating it does not affect the collector
        """
        return list(self._citations)

    def clear(self) -> None:
        """Forget every recorded citation (the list is emptied in place)."""
        del self._citations[:]

    def has_citations(self) -> bool:
        """Tell whether at least one citation has been recorded.

        Returns
        -------
        bool
            True when the collector is non-empty
        """
        return bool(self._citations)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@dataclass
class ToolCallTracker:
    """Keep per-call metadata, keyed by tool call ID, until it is consumed."""

    # tool_call_id -> ToolCallInfo captured when the call was registered.
    _pending: dict[str, ToolCallInfo] = field(default_factory=dict, init=False)

    def register_call(
        self,
        tool_call_id: str,
        tool_name: str,
        args: dict[str, Any],
        widget: Widget | None = None,
    ) -> None:
        """Remember the metadata of a call whose result has not arrived yet.

        Registering an ID that is already present replaces the earlier entry.

        Parameters
        ----------
        tool_call_id : str
            Identifier used later to look the call up
        tool_name : str
            Name of the invoked tool
        args : dict[str, Any]
            Arguments the tool was invoked with
        widget : Widget | None
            Widget backing the tool, when it is a widget tool
        """
        info = ToolCallInfo(tool_name=tool_name, args=args, widget=widget)
        self._pending[tool_call_id] = info

    def get_call_info(self, tool_call_id: str) -> ToolCallInfo | None:
        """Take (look up and remove) the metadata stored for a call ID.

        Parameters
        ----------
        tool_call_id : str
            Identifier of the call

        Returns
        -------
        ToolCallInfo | None
            The stored metadata, or None when the ID is unknown
        """
        try:
            return self._pending.pop(tool_call_id)
        except KeyError:
            return None

    def has_pending(self, tool_call_id: str) -> bool:
        """Tell whether metadata is currently stored for a call ID.

        Parameters
        ----------
        tool_call_id : str
            Identifier of the call

        Returns
        -------
        bool
            True while the ID is registered and not yet taken
        """
        registered = tool_call_id in self._pending
        return registered
|