openbb-pydantic-ai 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,363 @@
1
+ """Event stream transformer for OpenBB Workspace SSE protocol."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from collections.abc import AsyncIterator
7
+ from dataclasses import dataclass, field
8
+ from typing import Any
9
+
10
+ from openbb_ai.helpers import (
11
+ citations,
12
+ cite,
13
+ get_widget_data,
14
+ message_chunk,
15
+ reasoning_step,
16
+ )
17
+ from openbb_ai.models import (
18
+ SSE,
19
+ LlmClientFunctionCallResultMessage,
20
+ MessageArtifactSSE,
21
+ MessageChunkSSE,
22
+ QueryRequest,
23
+ StatusUpdateSSE,
24
+ WidgetRequest,
25
+ )
26
+ from pydantic_ai import DeferredToolRequests
27
+ from pydantic_ai.messages import (
28
+ FunctionToolCallEvent,
29
+ FunctionToolResultEvent,
30
+ RetryPromptPart,
31
+ TextPart,
32
+ TextPartDelta,
33
+ ThinkingPart,
34
+ ThinkingPartDelta,
35
+ ToolReturnPart,
36
+ )
37
+ from pydantic_ai.run import AgentRunResultEvent
38
+ from pydantic_ai.ui import UIEventStream
39
+
40
+ from ._config import GET_WIDGET_DATA_TOOL_NAME
41
+ from ._dependencies import OpenBBDeps
42
+ from ._event_stream_components import (
43
+ CitationCollector,
44
+ ThinkingBuffer,
45
+ ToolCallTracker,
46
+ )
47
+ from ._event_stream_helpers import (
48
+ ToolCallInfo,
49
+ artifact_from_output,
50
+ extract_widget_args,
51
+ handle_generic_tool_result,
52
+ serialized_content_from_result,
53
+ tool_result_events_from_content,
54
+ )
55
+ from ._utils import format_args, normalize_args
56
+ from ._widget_registry import WidgetRegistry
57
+
58
+
59
def _encode_sse(event: SSE) -> str:
    """Render *event* as one Server-Sent Events frame.

    The event is dumped with ``model_dump()`` and its ``"event"`` and
    ``"data"`` entries are interpolated into an ``event:`` / ``data:`` pair
    terminated by a blank line.
    """
    dumped = event.model_dump()
    event_name = dumped["event"]
    data = dumped["data"]
    return f"event: {event_name}\ndata: {data}\n\n"
62
+
63
+
64
@dataclass
class OpenBBAIEventStream(UIEventStream[QueryRequest, SSE, OpenBBDeps, Any]):
    """Transform native Pydantic AI events into OpenBB SSE events.

    The stream keeps a small amount of per-run state: the tool calls that are
    awaiting a result, the citations collected from widget results, buffered
    thinking text, and whether any answer text has already been sent.
    """

    widget_registry: WidgetRegistry = field(default_factory=WidgetRegistry)
    """Registry for widget lookup and discovery."""
    pending_results: list[LlmClientFunctionCallResultMessage] = field(
        default_factory=list
    )
    """Deferred tool results replayed once by ``before_stream``."""

    # State management components
    _tool_calls: ToolCallTracker = field(init=False, default_factory=ToolCallTracker)
    _citations: CitationCollector = field(init=False, default_factory=CitationCollector)
    _thinking: ThinkingBuffer = field(init=False, default_factory=ThinkingBuffer)

    # Simple state flags
    _has_streamed_text: bool = field(init=False, default=False)
    _final_output_pending: str | None = field(init=False, default=None)
    _deferred_results_emitted: bool = field(init=False, default=False)

    def encode_event(self, event: SSE) -> str:
        """Encode a single SSE event as wire text via ``_encode_sse``."""
        return _encode_sse(event)

    def _record_text_streamed(self) -> None:
        """Record that text content has been streamed to the client."""
        self._has_streamed_text = True

    async def before_stream(self) -> AsyncIterator[SSE]:
        """Emit tool results for any deferred results provided upfront.

        Guarded by ``_deferred_results_emitted`` so the pending results are
        replayed at most once per stream instance.
        """
        if self._deferred_results_emitted:
            return

        self._deferred_results_emitted = True

        # Process any pending deferred tool results from previous requests
        for result_message in self.pending_results:
            async for event in self._process_deferred_result(result_message):
                yield event

    async def _process_deferred_result(
        self, result_message: LlmClientFunctionCallResultMessage
    ) -> AsyncIterator[SSE]:
        """Process a single deferred result message and yield SSE events.

        When the registry resolves a widget for the result, a citation is
        collected for later emission; otherwise a WARNING reasoning step is
        yielded. In both cases the result content is then converted to events.
        """
        widget = self.widget_registry.find_for_result(result_message)

        widget_args = extract_widget_args(result_message)
        content = serialized_content_from_result(result_message)
        call_info = ToolCallInfo(
            tool_name=result_message.function,
            args=widget_args,
            widget=widget,
        )

        if widget is not None:
            citation = cite(widget, widget_args)
            self._citations.add(citation)
        else:
            details = format_args(widget_args)
            yield reasoning_step(
                f"Received result for '{result_message.function}' "
                "without widget metadata",
                details=details if details else None,
                event_type="WARNING",
            )

        for event in self._widget_result_events(call_info, content):
            yield event

    async def on_error(self, error: Exception) -> AsyncIterator[SSE]:
        """Surface *error* to the client as an ERROR reasoning step."""
        # str() of an exception raised without arguments is empty; fall back
        # to the class name so the client never receives a blank error.
        message = str(error) or type(error).__name__
        yield reasoning_step(message, event_type="ERROR")

    async def handle_text_start(
        self, part: TextPart, follows_text: bool = False
    ) -> AsyncIterator[SSE]:
        """Emit the initial content of a text part, if it has any."""
        if part.content:
            self._record_text_streamed()
            yield message_chunk(part.content)

    async def handle_text_delta(self, delta: TextPartDelta) -> AsyncIterator[SSE]:
        """Emit a non-empty text delta as a message chunk."""
        if delta.content_delta:
            self._record_text_streamed()
            yield message_chunk(delta.content_delta)

    async def handle_thinking_start(
        self,
        part: ThinkingPart,
        follows_thinking: bool = False,
    ) -> AsyncIterator[SSE]:
        """Reset the thinking buffer and seed it with the part's content.

        Yields nothing; the buffered text is emitted by ``handle_thinking_end``.
        """
        self._thinking.clear()
        if part.content:
            self._thinking.append(part.content)
        return
        # Unreachable yield keeps this function an async generator.
        yield  # pragma: no cover

    async def handle_thinking_delta(
        self,
        delta: ThinkingPartDelta,
    ) -> AsyncIterator[SSE]:
        """Append a non-empty thinking delta to the buffer; yields nothing."""
        if delta.content_delta:
            self._thinking.append(delta.content_delta)
        return
        # Unreachable yield keeps this function an async generator.
        yield  # pragma: no cover

    async def handle_thinking_end(
        self,
        part: ThinkingPart,
        followed_by_thinking: bool = False,
    ) -> AsyncIterator[SSE]:
        """Emit the finished thinking text as a reasoning step and reset the buffer.

        The part's own content is preferred; the buffered deltas are used only
        when the part carries no content.
        """
        content = part.content or self._thinking.get_content()

        if content:
            yield reasoning_step("Thinking", details={"Thinking": content})

        self._thinking.clear()

    async def handle_run_result(
        self, event: AgentRunResultEvent[Any]
    ) -> AsyncIterator[SSE]:
        """Handle agent run result events, including deferred tool requests.

        Deferred tool requests become widget-data requests, outputs that
        ``artifact_from_output`` can convert are emitted as artifacts, and a
        plain string output is held back for ``after_stream`` when no text has
        been streamed yet.
        """
        result = event.result
        output = getattr(result, "output", None)

        if isinstance(output, DeferredToolRequests):
            async for sse_event in self._handle_deferred_tool_requests(output):
                yield sse_event
            return

        artifact = self._artifact_from_output(output)
        if artifact is not None:
            yield artifact
            return

        if isinstance(output, str) and output and not self._has_streamed_text:
            self._final_output_pending = output

    async def _handle_deferred_tool_requests(
        self, output: DeferredToolRequests
    ) -> AsyncIterator[SSE]:
        """Process deferred tool requests and yield widget request events.

        Calls whose tool name does not resolve to a widget are skipped. For the
        rest a reasoning step is yielded per call, followed by one combined
        widget-data request carrying the tool-call ids in ``extra_state``.
        """
        widget_requests: list[WidgetRequest] = []
        tool_call_ids: list[dict[str, Any]] = []

        for call in output.calls:
            widget = self.widget_registry.find_by_tool_name(call.tool_name)
            if widget is None:
                continue

            args = normalize_args(call.args)
            widget_requests.append(WidgetRequest(widget=widget, input_arguments=args))
            self._tool_calls.register_call(
                tool_call_id=call.tool_call_id,
                tool_name=call.tool_name,
                args=args,
                widget=widget,
            )
            tool_call_ids.append(
                {
                    "tool_call_id": call.tool_call_id,
                    "widget_uuid": str(widget.uuid),
                    "widget_id": widget.widget_id,
                }
            )

            # Create details dict with widget info and arguments for display
            details = {
                "Origin": widget.origin,
                "Widget Id": widget.widget_id,
                **format_args(args),
            }
            yield reasoning_step(
                f"Requesting widget '{widget.name}'",
                details=details,
            )

        if widget_requests:
            sse = get_widget_data(widget_requests)
            sse.data.extra_state = {"tool_calls": tool_call_ids}
            yield sse

    async def handle_function_tool_call(
        self, event: FunctionToolCallEvent
    ) -> AsyncIterator[SSE]:
        """Surface non-widget tool calls as reasoning steps.

        Widget tools and the widget-data tool are ignored here, as are calls
        without an id or with an id that is already being tracked.
        """
        part = event.part
        tool_name = part.tool_name

        widget = self.widget_registry.find_by_tool_name(tool_name)
        if widget is not None or tool_name == GET_WIDGET_DATA_TOOL_NAME:
            return

        tool_call_id = part.tool_call_id
        if not tool_call_id or self._tool_calls.has_pending(tool_call_id):
            return

        args = normalize_args(part.args)
        self._tool_calls.register_call(
            tool_call_id=tool_call_id,
            tool_name=tool_name,
            args=args,
        )

        formatted_args = format_args(args)
        details = formatted_args if formatted_args else None
        yield reasoning_step(f"Calling tool '{tool_name}'", details=details)

    async def handle_function_tool_result(
        self, event: FunctionToolResultEvent
    ) -> AsyncIterator[SSE]:
        """Translate a tool result into SSE events.

        Retry prompts become ERROR reasoning steps. Results that already are
        SSE objects are forwarded unchanged. Widget results collect a citation
        and are rendered through ``_widget_result_events``; every other tracked
        result goes through ``handle_generic_tool_result``.
        """
        result_part = event.result

        if isinstance(result_part, RetryPromptPart):
            if result_part.content:
                content = result_part.content
                message = (
                    content
                    if isinstance(content, str)
                    else json.dumps(content, default=str)
                )
                yield reasoning_step(message, event_type="ERROR")
            return

        if not isinstance(result_part, ToolReturnPart):
            return

        tool_call_id = result_part.tool_call_id
        if not tool_call_id:
            return

        # Pop the tracker entry before any early return so a completed call
        # never lingers as "pending" (get_call_info removes the entry).
        call_info = self._tool_calls.get_call_info(tool_call_id)

        if isinstance(
            result_part.content, (MessageArtifactSSE, MessageChunkSSE, StatusUpdateSSE)
        ):
            yield result_part.content
            return

        if call_info is None:
            return

        if call_info.widget is not None:
            # Collect citation for later emission (at the end)
            citation = cite(call_info.widget, call_info.args)
            self._citations.add(citation)

            for sse in self._widget_result_events(call_info, result_part.content):
                yield sse
            return

        for sse in handle_generic_tool_result(
            call_info,
            result_part.content,
            mark_streamed_text=self._record_text_streamed,
        ):
            yield sse

    async def after_stream(self) -> AsyncIterator[SSE]:
        """Flush remaining state once the run has finished.

        Emits, in order: leftover buffered thinking text, the held-back final
        output when no text was streamed, and all collected citations.
        """
        if not self._thinking.is_empty():
            content = self._thinking.get_content()
            if content:
                yield reasoning_step(content)
            self._thinking.clear()

        if self._final_output_pending and not self._has_streamed_text:
            yield message_chunk(self._final_output_pending)

        self._final_output_pending = None

        # Emit all citations at the end
        if self._citations.has_citations():
            yield citations(self._citations.get_all())
            self._citations.clear()

    def _artifact_from_output(self, output: Any) -> SSE | None:
        """Create an artifact (chart or table) from agent output if possible."""
        return artifact_from_output(output)

    def _widget_result_events(
        self,
        call_info: ToolCallInfo,
        content: Any,
    ) -> list[SSE]:
        """Emit SSE events for widget results with graceful fallbacks.

        ``tool_result_events_from_content`` is tried first; when it produces no
        events the content is handed to ``handle_generic_tool_result``.
        """
        events = tool_result_events_from_content(
            content, mark_streamed_text=self._record_text_streamed
        )
        if events:
            return events

        return handle_generic_tool_result(
            call_info,
            content,
            mark_streamed_text=self._record_text_streamed,
        )
@@ -0,0 +1,155 @@
1
+ """State management components for OpenBB event stream."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Any
7
+
8
+ from openbb_ai.models import Citation, Widget
9
+
10
+ from ._event_stream_helpers import ToolCallInfo
11
+
12
+
13
@dataclass
class ThinkingBuffer:
    """Accumulates fragments of thinking text while a response streams."""

    _buffer: list[str] = field(default_factory=list, init=False)

    def append(self, content: str) -> None:
        """Store one more fragment at the end of the buffer.

        Parameters
        ----------
        content : str
            Fragment to store
        """
        self._buffer += [content]

    def get_content(self) -> str:
        """Return every stored fragment joined in insertion order.

        Returns
        -------
        str
            The fragments concatenated without a separator
        """
        fragments = self._buffer
        return "".join(fragments)

    def clear(self) -> None:
        """Drop all stored fragments."""
        del self._buffer[:]

    def is_empty(self) -> bool:
        """Report whether no fragment has been stored.

        Returns
        -------
        bool
            True when the buffer holds no fragments
        """
        return not self._buffer
52
+
53
+
54
@dataclass
class CitationCollector:
    """Collects citations during streaming so they can be emitted together."""

    _citations: list[Citation] = field(default_factory=list, init=False)

    def add(self, citation: Citation) -> None:
        """Record one citation.

        Parameters
        ----------
        citation : Citation
            Citation to record
        """
        self._citations += [citation]

    def get_all(self) -> list[Citation]:
        """Return a snapshot of the recorded citations.

        Returns
        -------
        list[Citation]
            A new list; mutating it does not affect the collector
        """
        return list(self._citations)

    def clear(self) -> None:
        """Forget every recorded citation."""
        del self._citations[:]

    def has_citations(self) -> bool:
        """Report whether at least one citation is recorded.

        Returns
        -------
        bool
            True when the collection is non-empty
        """
        return bool(self._citations)
93
+
94
+
95
@dataclass
class ToolCallTracker:
    """Keeps metadata for tool calls that are awaiting a result, keyed by call ID."""

    _pending: dict[str, ToolCallInfo] = field(default_factory=dict, init=False)

    def register_call(
        self,
        tool_call_id: str,
        tool_name: str,
        args: dict[str, Any],
        widget: Widget | None = None,
    ) -> None:
        """Store metadata for a tool call, replacing any entry with the same ID.

        Parameters
        ----------
        tool_call_id : str
            Identifier the entry is stored under
        tool_name : str
            Name of the tool being called
        args : dict[str, Any]
            Arguments passed to the tool
        widget : Widget | None
            Widget associated with the call, if any
        """
        info = ToolCallInfo(tool_name=tool_name, args=args, widget=widget)
        self._pending[tool_call_id] = info

    def get_call_info(self, tool_call_id: str) -> ToolCallInfo | None:
        """Take the metadata for a call ID out of the tracker.

        Parameters
        ----------
        tool_call_id : str
            Identifier to look up

        Returns
        -------
        ToolCallInfo | None
            The stored metadata, which is removed by this call, or None when
            the ID is unknown
        """
        try:
            return self._pending.pop(tool_call_id)
        except KeyError:
            return None

    def has_pending(self, tool_call_id: str) -> bool:
        """Report whether a call ID currently has stored metadata.

        Parameters
        ----------
        tool_call_id : str
            Identifier to check

        Returns
        -------
        bool
            True when the ID is present in the tracker
        """
        known_ids = self._pending
        return tool_call_id in known_ids