openbb-pydantic-ai 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,52 @@
1
+ """Pydantic AI UI adapter for OpenBB Workspace."""
2
+
3
+ from importlib.metadata import PackageNotFoundError, version
4
+
5
+ from ._adapter import OpenBBAIAdapter
6
+ from ._config import GET_WIDGET_DATA_TOOL_NAME
7
+ from ._dependencies import OpenBBDeps, build_deps_from_request
8
+ from ._event_builder import EventBuilder
9
+ from ._event_stream import OpenBBAIEventStream
10
+ from ._exceptions import (
11
+ InvalidToolCallError,
12
+ OpenBBPydanticAIError,
13
+ SerializationError,
14
+ WidgetNotFoundError,
15
+ )
16
+ from ._message_transformer import MessageTransformer
17
+ from ._serializers import ContentSerializer
18
+ from ._toolsets import (
19
+ WidgetToolset,
20
+ build_widget_tool,
21
+ build_widget_tool_name,
22
+ build_widget_toolsets,
23
+ )
24
+ from ._types import SerializedContent, TextStreamCallback
25
+ from ._widget_registry import WidgetRegistry
26
+
27
# Resolve the installed distribution version; fall back to a placeholder when
# package metadata is unavailable (PackageNotFoundError).
try:
    _installed_version = version("openbb-pydantic-ai")
except PackageNotFoundError:
    _installed_version = "0.0.0"
__version__ = _installed_version
del _installed_version
31
+
32
# Public API of the package; every name is re-exported from a private module.
__all__ = [
    # Adapter, event stream and dependencies
    "OpenBBAIAdapter",
    "OpenBBAIEventStream",
    "OpenBBDeps",
    "build_deps_from_request",
    # Widget toolsets
    "WidgetToolset",
    "build_widget_tool",
    "build_widget_tool_name",
    "build_widget_toolsets",
    "GET_WIDGET_DATA_TOOL_NAME",
    # Helpers
    "EventBuilder",
    "ContentSerializer",
    "MessageTransformer",
    "WidgetRegistry",
    # Exceptions
    "OpenBBPydanticAIError",
    "WidgetNotFoundError",
    "InvalidToolCallError",
    "SerializationError",
    # Types
    "SerializedContent",
    "TextStreamCallback",
]
@@ -0,0 +1,307 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Sequence
4
+ from dataclasses import KW_ONLY, dataclass, field
5
+ from functools import cached_property
6
+ from typing import Any, cast
7
+
8
+ from openbb_ai.models import (
9
+ SSE,
10
+ LlmClientFunctionCallResultMessage,
11
+ LlmMessage,
12
+ QueryRequest,
13
+ )
14
+ from pydantic_ai import DeferredToolResults
15
+ from pydantic_ai.messages import (
16
+ ModelMessage,
17
+ SystemPromptPart,
18
+ )
19
+ from pydantic_ai.toolsets import AbstractToolset, CombinedToolset, FunctionToolset
20
+ from pydantic_ai.ui import UIAdapter
21
+
22
+ from ._dependencies import OpenBBDeps, build_deps_from_request
23
+ from ._event_stream import OpenBBAIEventStream
24
+ from ._message_transformer import MessageTransformer
25
+ from ._serializers import ContentSerializer
26
+ from ._toolsets import build_widget_toolsets
27
+ from ._utils import hash_tool_call
28
+ from ._widget_registry import WidgetRegistry
29
+
30
+
31
@dataclass(slots=True)
class OpenBBAIAdapter(UIAdapter[QueryRequest, LlmMessage, SSE, OpenBBDeps, Any]):
    """UI adapter that bridges OpenBB Workspace requests with Pydantic AI.

    ``__post_init__`` splits the request messages into model history and
    trailing client tool results, builds a ``MessageTransformer`` that reuses
    the tool call IDs reported with those results, and indexes the request's
    widgets in a ``WidgetRegistry``.
    """

    _: KW_ONLY
    accept: str | None = None

    # Initialized in __post_init__
    _transformer: MessageTransformer = field(init=False)
    _registry: WidgetRegistry = field(init=False)
    _base_messages: list[LlmMessage] = field(init=False, default_factory=list)
    _pending_results: list[LlmClientFunctionCallResultMessage] = field(
        init=False, default_factory=list
    )

    def __post_init__(self) -> None:
        """Split the request messages and set up the transformer and registry."""
        base, pending = self._split_messages(self.run_input.messages)
        self._base_messages = base
        self._pending_results = pending

        # Map hash(function, arguments) -> tool call ID taken from each result,
        # so the transformed history uses consistent tool call IDs.
        tool_call_id_overrides: dict[str, str] = {}
        for message in self._base_messages:
            if isinstance(message, LlmClientFunctionCallResultMessage):
                key = hash_tool_call(message.function, message.input_arguments)
                tool_call_id_overrides[key] = self._tool_call_id_from_result(message)

        # Pending results only fill keys not already set from the history.
        for message in self._pending_results:
            key = hash_tool_call(message.function, message.input_arguments)
            tool_call_id_overrides.setdefault(
                key,
                self._tool_call_id_from_result(message),
            )

        self._transformer = MessageTransformer(tool_call_id_overrides)
        self._registry = WidgetRegistry(
            collection=self.run_input.widgets,
            toolsets=self._widget_toolsets,
        )

    @classmethod
    def build_run_input(cls, body: bytes) -> QueryRequest:
        """Parse a raw JSON request body into a ``QueryRequest``."""
        return QueryRequest.model_validate_json(body)

    @classmethod
    def load_messages(cls, messages: Sequence[LlmMessage]) -> list[ModelMessage]:
        """Convert OpenBB messages to Pydantic AI messages.

        Note: This creates a transformer without overrides for standalone use.
        """
        transformer = MessageTransformer()
        return transformer.transform_batch(messages)

    @staticmethod
    def _split_messages(
        messages: Sequence[LlmMessage],
    ) -> tuple[list[LlmMessage], list[LlmClientFunctionCallResultMessage]]:
        """Split messages into model history and trailing tool results.

        The trailing run of ``LlmClientFunctionCallResultMessage`` entries
        (everything after the last message of any other type) is returned as
        pending. Those results are *also* left in the returned history, so the
        next model call still sees the complete tool call/result exchange.

        Parameters
        ----------
        messages : Sequence[LlmMessage]
            Full message sequence from the request.

        Returns
        -------
        tuple[list[LlmMessage], list[LlmClientFunctionCallResultMessage]]
            ``(history, pending)``: ``history`` is a copy of all messages and
            ``pending`` holds the trailing results in chronological order.
        """
        base = list(messages)
        pending: list[LlmClientFunctionCallResultMessage] = []
        # Walk backwards, stopping at the first non-result message.
        for message in reversed(base):
            if not isinstance(message, LlmClientFunctionCallResultMessage):
                break
            pending.append(message)
        # Collected newest-first; restore chronological order in O(n) instead
        # of repeated insert(0, ...).
        pending.reverse()
        return base, pending

    def _tool_call_id_from_result(
        self, message: LlmClientFunctionCallResultMessage
    ) -> str:
        """Return ``extra_state["tool_call_id"]`` if it is a string, else a hash.

        The fallback hashes the function name and input arguments, matching
        the keys used for the transformer's ID overrides.
        """
        extra_id = (
            message.extra_state.get("tool_call_id") if message.extra_state else None
        )
        if isinstance(extra_id, str):
            return extra_id
        return hash_tool_call(message.function, message.input_arguments)

    @cached_property
    def deps(self) -> OpenBBDeps:
        """Dependencies built from the run input (cached per adapter)."""
        return build_deps_from_request(self.run_input)

    @cached_property
    def deferred_tool_results(self) -> DeferredToolResults | None:
        """Build deferred tool results from pending result messages.

        Because ``_split_messages`` keeps the pending results in the base
        history, ``_pending_results_are_in_history`` is True whenever results
        are pending, so this currently returns ``None`` in practice; the
        results reach the model through the message history instead.
        """
        if not self._pending_results:
            return None

        # When those trailing results already sit in the base history, skip
        # emitting DeferredToolResults; resending them would show up as a
        # conflicting duplicate tool response upstream.
        if self._pending_results_are_in_history():
            return None

        results = DeferredToolResults()
        for message in self._pending_results:
            actual_id = self._tool_call_id_from_result(message)
            results.calls[actual_id] = ContentSerializer.serialize_result(message)
        return results

    def _pending_results_are_in_history(self) -> bool:
        """Return True if the pending results are the tail of the base history.

        Compares by identity (``is``), not equality.
        """
        pending_count = len(self._pending_results)
        if pending_count == 0 or pending_count > len(self._base_messages):
            return False
        tail = self._base_messages[-pending_count:]
        return all(
            in_history is result
            for in_history, result in zip(tail, self._pending_results, strict=True)
        )

    @cached_property
    def _widget_toolsets(self) -> tuple[FunctionToolset[OpenBBDeps], ...]:
        """Toolsets built from the request's widgets (cached per adapter)."""
        return build_widget_toolsets(self.run_input.widgets)

    def build_event_stream(self) -> OpenBBAIEventStream:
        """Create the event stream that converts agent events to OpenBB SSE."""
        return OpenBBAIEventStream(
            run_input=self.run_input,
            widget_registry=self._registry,
            pending_results=self._pending_results,
        )

    @cached_property
    def messages(self) -> list[ModelMessage]:
        """Build message history with context prompts.

        The context system prompt (if any) is added first, followed by every
        part of the transformed base messages.
        """
        from pydantic_ai.ui import MessagesBuilder

        builder = MessagesBuilder()
        self._add_context_prompts(builder)

        # Use transformer to convert messages with ID overrides
        transformed = self._transformer.transform_batch(self._base_messages)
        for msg in transformed:
            for part in msg.parts:
                builder.add(part)

        return builder.messages

    def _add_context_prompts(self, builder) -> None:
        """Add a system prompt with workspace context, URLs, and dashboard info.

        Nothing is added when none of those are present.
        """
        lines: list[str] = []

        if self.deps.context:
            lines.append("Workspace context:")
            for ctx in self.deps.context:
                row_count = len(ctx.data.items) if ctx.data and ctx.data.items else 0
                summary = f"- {ctx.name} ({row_count} rows): {ctx.description}"
                lines.append(summary)

        if self.deps.urls:
            joined = ", ".join(self.deps.urls)
            lines.append(f"Relevant URLs: {joined}")

        workspace_state = self.deps.workspace_state
        if workspace_state and workspace_state.current_dashboard_info:
            dashboard = workspace_state.current_dashboard_info
            lines.append(
                f"Active dashboard: {dashboard.name} (tab {dashboard.current_tab_id})"
            )

        if lines:
            builder.add(SystemPromptPart(content="\n".join(lines)))

    @cached_property
    def toolset(self) -> AbstractToolset[OpenBBDeps] | None:
        """Build combined toolset from widget toolsets.

        Returns ``None`` when there are no widget toolsets, and returns the
        single toolset directly when there is exactly one.
        """
        if not self._widget_toolsets:
            return None
        if len(self._widget_toolsets) == 1:
            return self._widget_toolsets[0]
        combined = CombinedToolset(self._widget_toolsets)
        return cast(AbstractToolset[OpenBBDeps], combined)

    @cached_property
    def state(self) -> dict[str, Any] | None:
        """Extract workspace state as a dictionary (``None`` fields dropped)."""
        if self.run_input.workspace_state is None:
            return None
        return self.run_input.workspace_state.model_dump(exclude_none=True)

    def _resolve_run_defaults(
        self,
        deps: Any,
        deferred_tool_results: DeferredToolResults | None,
        message_history: Sequence[ModelMessage] | None,
    ) -> tuple[Any, DeferredToolResults | None, Sequence[ModelMessage]]:
        """Substitute adapter-derived values for run arguments left as ``None``.

        Only ``None`` triggers a default, so explicitly passed falsy values
        (for example an empty ``message_history``) are respected rather than
        silently replaced.
        """
        if deps is None:
            deps = self.deps
        if deferred_tool_results is None:
            deferred_tool_results = self.deferred_tool_results
        if message_history is None:
            # NOTE(review): confirm how UIAdapter.run_stream_native combines
            # message_history with self.messages; if the base class already
            # appends self.messages, defaulting to it here duplicates history.
            message_history = self.messages
        return deps, deferred_tool_results, message_history

    def run_stream_native(
        self,
        *,
        output_type=None,
        message_history=None,
        deferred_tool_results=None,
        model=None,
        deps=None,
        model_settings=None,
        usage_limits=None,
        usage=None,
        infer_name=True,
        toolsets=None,
        builtin_tools=None,
    ):
        """
        Run the agent with OpenBB-specific defaults for
        deps, messages, and deferred results.

        ``deps``, ``deferred_tool_results`` and ``message_history`` default to
        the adapter's own values only when passed as ``None``.
        """
        deps, deferred_tool_results, message_history = self._resolve_run_defaults(
            deps, deferred_tool_results, message_history
        )

        # Explicit two-argument super(): @dataclass(slots=True) rebuilds the
        # class, and a zero-argument super() would still bind to the pre-slots
        # class object, raising TypeError on CPython < 3.14 (gh-90562).
        return super(OpenBBAIAdapter, self).run_stream_native(
            output_type=output_type,
            message_history=message_history,
            deferred_tool_results=deferred_tool_results,
            model=model,
            deps=deps,
            model_settings=model_settings,
            usage_limits=usage_limits,
            usage=usage,
            infer_name=infer_name,
            toolsets=toolsets,
            builtin_tools=builtin_tools,
        )

    def run_stream(
        self,
        *,
        output_type=None,
        message_history=None,
        deferred_tool_results=None,
        model=None,
        deps=None,
        model_settings=None,
        usage_limits=None,
        usage=None,
        infer_name=True,
        toolsets=None,
        builtin_tools=None,
        on_complete=None,
    ):
        """Run the agent and stream protocol-specific events with OpenBB defaults.

        ``deps``, ``deferred_tool_results`` and ``message_history`` default to
        the adapter's own values only when passed as ``None``.
        """
        deps, deferred_tool_results, message_history = self._resolve_run_defaults(
            deps, deferred_tool_results, message_history
        )

        # Explicit two-argument super(): see run_stream_native (gh-90562).
        return super(OpenBBAIAdapter, self).run_stream(
            output_type=output_type,
            message_history=message_history,
            deferred_tool_results=deferred_tool_results,
            model=model,
            deps=deps,
            model_settings=model_settings,
            usage_limits=usage_limits,
            usage=usage,
            infer_name=infer_name,
            toolsets=toolsets,
            builtin_tools=builtin_tools,
            on_complete=on_complete,
        )
@@ -0,0 +1,58 @@
1
+ """Centralized configuration for OpenBB Pydantic AI adapter."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Mapping
6
+
7
# Tool name constants
GET_WIDGET_DATA_TOOL_NAME = "get_widget_data"

# Keys excluded from citation details.
CITATION_EXCLUDED_FIELDS = frozenset(
    {
        "lastupdated",
        "source",
        "id",
        "uuid",
        "storedfileuuid",
        "datakey",
        "originalfilename",
        "extension",
        "category",
        "subcategory",
        "transcript_url",
    }
)

# Status update details exclude the same keys as citations, plus ``url``.
STATUS_UPDATE_EXCLUDED_FIELDS = CITATION_EXCLUDED_FIELDS | frozenset({"url"})

# Widget parameter ``type`` -> JSON schema fragment. Each entry is its own
# dict object (no sharing between keys).
PARAM_TYPE_SCHEMA_MAP: Mapping[str, dict[str, Any]] = {
    "string": {"type": "string"},
    "text": {"type": "string"},
    "number": {"type": "number"},
    "integer": {"type": "integer"},
    "boolean": {"type": "boolean"},
    "date": {"type": "string", "format": "date"},
    "ticker": {"type": "string"},
    "endpoint": {"type": "string"},
}

# Content formatting limits (names suggest argument-preview truncation;
# confirm against their users).
MAX_ARG_DISPLAY_CHARS = 160
MAX_ARG_PREVIEW_ITEMS = 2
@@ -0,0 +1,73 @@
1
+ """Dependency injection container for OpenBB workspace context."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Any, Iterable, Sequence
7
+
8
+ from openbb_ai.models import (
9
+ QueryRequest,
10
+ RawContext,
11
+ Widget,
12
+ WidgetCollection,
13
+ WorkspaceState,
14
+ )
15
+
16
+
17
+ @dataclass(slots=True)
18
+ class OpenBBDeps:
19
+ """Dependency container passed to Pydantic AI runs.
20
+
21
+ The dependency bundle exposes OpenBB Workspace specific context so that
22
+ system prompts, tools, and output validators can access widget metadata or
23
+ other request scoped information via ``RunContext[OpenBBDeps]``.
24
+
25
+ Attributes:
26
+ widgets: Collection of available widgets organized by priority
27
+ context: Workspace context data (datasets, documents, etc.)
28
+ urls: Relevant URLs for the current request
29
+ workspace_state: Current workspace state including dashboard info
30
+ timezone: User's timezone (defaults to UTC)
31
+ state: Serialized workspace state as dictionary
32
+ """
33
+
34
+ widgets: WidgetCollection | None = None
35
+ context: list[RawContext] | None = None
36
+ urls: list[str] | None = None
37
+ workspace_state: WorkspaceState | None = None
38
+ timezone: str = "UTC"
39
+ state: dict[str, Any] = field(default_factory=dict)
40
+
41
+ def iter_widgets(self) -> Iterable[Widget]:
42
+ """Yield all widgets across priority groups (primary, secondary, extra)."""
43
+ if not self.widgets:
44
+ return
45
+
46
+ for group in (self.widgets.primary, self.widgets.secondary, self.widgets.extra):
47
+ yield from group
48
+
49
+ def get_widget_by_uuid(self, widget_uuid: str) -> Widget | None:
50
+ """Find a widget by its UUID string."""
51
+ for widget in self.iter_widgets():
52
+ if str(widget.uuid) == widget_uuid:
53
+ return widget
54
+ return None
55
+
56
+
57
def build_deps_from_request(request: QueryRequest) -> OpenBBDeps:
    """Create an ``OpenBBDeps`` from an incoming ``QueryRequest``.

    Parameters
    ----------
    request : QueryRequest
        The parsed workspace request.

    Returns
    -------
    OpenBBDeps
        Dependencies whose ``context`` and ``urls`` are list copies (or
        ``None``), and whose ``state`` is ``workspace_state`` dumped with
        ``exclude_none=True`` (or ``{}`` when there is no workspace state).
    """
    raw_context = request.context
    raw_urls = request.urls
    workspace = request.workspace_state

    return OpenBBDeps(
        widgets=request.widgets,
        context=None if raw_context is None else list(raw_context),
        urls=None if raw_urls is None else list(raw_urls),
        workspace_state=workspace,
        timezone=request.timezone,
        state={} if workspace is None else workspace.model_dump(exclude_none=True),
    )
@@ -0,0 +1,183 @@
1
+ """Event builder utilities for consistent SSE event creation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Literal
6
+
7
+ from openbb_ai.models import (
8
+ Citation,
9
+ CitationCollection,
10
+ CitationCollectionSSE,
11
+ ClientArtifact,
12
+ MessageArtifactSSE,
13
+ MessageChunkSSE,
14
+ MessageChunkSSEData,
15
+ StatusUpdateSSE,
16
+ StatusUpdateSSEData,
17
+ )
18
+
19
+
20
class EventBuilder:
    """Provides consistent interface for creating SSE events.

    All methods are static. Status-style events (reasoning, warning, error)
    are built by :meth:`reasoning` and carry ``group="reasoning"``.
    """

    @staticmethod
    def reasoning(
        message: str,
        *,
        details: dict[str, Any] | list[dict[str, Any] | str] | None = None,
        event_type: Literal["INFO", "WARNING", "ERROR"] = "INFO",
        artifacts: list[ClientArtifact] | None = None,
        hidden: bool = False,
    ) -> StatusUpdateSSE:
        """Create a reasoning/status update event.

        Parameters
        ----------
        message : str
            The reasoning message
        details : dict | list | None
            Optional details. A single dict is wrapped in a one-element list;
            a list is shallow-copied, so appending to the caller's list later
            does not affect the event.
        event_type : Literal["INFO", "WARNING", "ERROR"]
            The type of event
        artifacts : list[ClientArtifact] | None
            Optional artifacts to include
        hidden : bool
            Whether to hide this event

        Returns
        -------
        StatusUpdateSSE
            A status update SSE event in the ``"reasoning"`` group
        """
        normalized_details: list[dict[str, Any] | str] | None
        if details is None:
            normalized_details = None
        elif isinstance(details, dict):
            # StatusUpdateSSEData takes a list of details.
            normalized_details = [details]
        else:
            normalized_details = list(details)

        return StatusUpdateSSE(
            data=StatusUpdateSSEData(
                eventType=event_type,
                message=message,
                group="reasoning",
                details=normalized_details,
                artifacts=artifacts,
                hidden=hidden,
            )
        )

    @staticmethod
    def reasoning_with_artifacts(
        message: str,
        artifacts: list[ClientArtifact],
    ) -> StatusUpdateSSE:
        """Create a reasoning event with inline artifacts.

        Parameters
        ----------
        message : str
            The reasoning message
        artifacts : list[ClientArtifact]
            Artifacts to include inline

        Returns
        -------
        StatusUpdateSSE
            An ``INFO`` status update SSE event with artifacts
        """
        return EventBuilder.reasoning(
            message,
            event_type="INFO",
            artifacts=artifacts,
        )

    @staticmethod
    def message(content: str) -> MessageChunkSSE:
        """Create a message chunk event.

        Parameters
        ----------
        content : str
            The message content, sent as the chunk's ``delta``

        Returns
        -------
        MessageChunkSSE
            A message chunk SSE event
        """
        return MessageChunkSSE(data=MessageChunkSSEData(delta=content))

    @staticmethod
    def artifact(artifact: ClientArtifact) -> MessageArtifactSSE:
        """Create an artifact event.

        Parameters
        ----------
        artifact : ClientArtifact
            The artifact to send

        Returns
        -------
        MessageArtifactSSE
            An artifact SSE event
        """
        return MessageArtifactSSE(data=artifact)

    @staticmethod
    def citations(citation_list: list[Citation]) -> CitationCollectionSSE:
        """Create a citation collection event.

        Parameters
        ----------
        citation_list : list[Citation]
            List of citations

        Returns
        -------
        CitationCollectionSSE
            A citation collection SSE event
        """
        return CitationCollectionSSE(data=CitationCollection(citations=citation_list))

    @staticmethod
    def error(
        message: str,
        *,
        details: dict[str, Any] | list[dict[str, Any] | str] | None = None,
    ) -> StatusUpdateSSE:
        """Create an error event.

        Parameters
        ----------
        message : str
            The error message
        details : dict | list | None
            Optional error details, normalized as in :meth:`reasoning`

        Returns
        -------
        StatusUpdateSSE
            An ``ERROR`` status update SSE event
        """
        return EventBuilder.reasoning(message, event_type="ERROR", details=details)

    @staticmethod
    def warning(
        message: str,
        *,
        details: dict[str, Any] | list[dict[str, Any] | str] | None = None,
    ) -> StatusUpdateSSE:
        """Create a warning event.

        Parameters
        ----------
        message : str
            The warning message
        details : dict | list | None
            Optional warning details, normalized as in :meth:`reasoning`

        Returns
        -------
        StatusUpdateSSE
            A ``WARNING`` status update SSE event
        """
        return EventBuilder.reasoning(message, event_type="WARNING", details=details)