openbb-pydantic-ai 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,110 @@
1
+ """Content serialization utilities for OpenBB Pydantic AI adapter."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from typing import Any, cast
7
+
8
+ from openbb_ai.models import LlmClientFunctionCallResultMessage
9
+
10
+ from ._types import SerializedContent
11
+
12
+
13
class ContentSerializer:
    """Stateless helpers for converting adapter payloads to and from text.

    Every member is a ``staticmethod``; the class only acts as a namespace.
    """

    @staticmethod
    def serialize_result(
        message: LlmClientFunctionCallResultMessage,
    ) -> SerializedContent:
        """Convert a function call result message into a ``SerializedContent`` dict.

        Parameters
        ----------
        message : LlmClientFunctionCallResultMessage
            Result message whose ``input_arguments``, ``data`` and
            ``extra_state`` are copied into the output.

        Returns
        -------
        SerializedContent
            Mapping with ``input_arguments`` and ``data`` keys; ``extra_state``
            is added only when the message carries a truthy value.
        """
        # Items exposing ``model_dump`` (pydantic models) are dumped to
        # JSON-compatible dicts with None fields dropped; anything else is
        # passed through unchanged.
        data: list[Any] = [
            item.model_dump(mode="json", exclude_none=True)
            if hasattr(item, "model_dump")
            else item
            for item in message.data
        ]
        content = cast(
            SerializedContent,
            {"input_arguments": message.input_arguments, "data": data},
        )
        extra_state = message.extra_state
        if extra_state:
            content["extra_state"] = extra_state
        return content

    @staticmethod
    def parse_json(raw_content: str) -> Any:
        """Decode ``raw_content`` as JSON, or hand it back untouched.

        Parameters
        ----------
        raw_content : str
            Text that may or may not contain valid JSON.

        Returns
        -------
        Any
            The decoded JSON value, or ``raw_content`` itself when it is not
            valid JSON.
        """
        try:
            parsed = json.loads(raw_content)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            return raw_content
        return parsed

    @staticmethod
    def to_string(content: Any) -> str | None:
        """Render ``content`` as text, preferring a JSON encoding.

        Parameters
        ----------
        content : Any
            Value to render.

        Returns
        -------
        str | None
            ``None`` for ``None``, the value itself for strings, otherwise the
            result of ``ContentSerializer.to_json``.
        """
        if content is None or isinstance(content, str):
            return content
        return ContentSerializer.to_json(content)

    @staticmethod
    def to_json(value: Any) -> str:
        """Encode ``value`` as JSON, degrading to ``str(value)`` on failure.

        Parameters
        ----------
        value : Any
            Value to encode; unknown objects are stringified via ``default=str``.

        Returns
        -------
        str
            JSON text, or ``str(value)`` when encoding is impossible.
        """
        try:
            encoded = json.dumps(value, default=str)
        except (TypeError, ValueError):
            # e.g. unsupported dict key types (TypeError) or circular
            # references (ValueError).
            encoded = str(value)
        return encoded
@@ -0,0 +1,264 @@
1
+ """Toolset implementations for OpenBB widgets and visualization."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+ from collections.abc import Mapping, Sequence
7
+ from typing import Any, Literal
8
+
9
+ from openbb_ai.helpers import chart, table
10
+ from openbb_ai.models import Undefined, Widget, WidgetCollection, WidgetParam
11
+ from pydantic_ai import CallDeferred, Tool
12
+ from pydantic_ai.tools import RunContext
13
+ from pydantic_ai.toolsets import FunctionToolset
14
+
15
+ from ._dependencies import OpenBBDeps
16
+
17
+
18
# JSON schema fragments per widget parameter type; unknown types fall back
# to a plain string. Callers must copy before mutating.
_PARAM_TYPE_SCHEMAS: dict[str, dict[str, Any]] = {
    "string": {"type": "string"},
    "text": {"type": "string"},
    "number": {"type": "number"},
    "integer": {"type": "integer"},
    "boolean": {"type": "boolean"},
    "date": {"type": "string", "format": "date"},
    "ticker": {"type": "string"},
    "endpoint": {"type": "string"},
}


def _base_param_schema(param: WidgetParam) -> dict[str, Any]:
    """Build the single-value JSON schema for a widget parameter.

    Parameters
    ----------
    param : WidgetParam
        Parameter whose ``type``, ``description``, ``options``,
        ``get_options``, ``default_value``, ``current_value`` and
        ``multi_select`` attributes are read.

    Returns
    -------
    dict[str, Any]
        A fresh schema dict containing the type fragment, ``description``, and
        optionally ``enum``, ``default`` and ``examples``.
    """
    # Copy so the shared table is never mutated by the edits below.
    schema = dict(_PARAM_TYPE_SCHEMAS.get(param.type, {"type": "string"}))
    schema["description"] = param.description

    if param.options:
        schema["enum"] = list(param.options)

    if param.get_options:
        # "description" is always present at this point, so the hint must
        # overwrite it; a setdefault here would never take effect.
        note = "(options retrieved dynamically)"
        schema["description"] = (
            f"{param.description} {note}" if param.description else note
        )

    if param.default_value is not Undefined.UNDEFINED:
        schema["default"] = param.default_value

    # Offer the current selection as an example for single-select params only.
    if param.current_value is not None and param.multi_select is False:
        schema.setdefault("examples", []).append(param.current_value)

    return schema
51
+
52
+
53
def _param_schema(param: WidgetParam) -> tuple[dict[str, Any], bool]:
    """Return ``(schema, is_required)`` for one widget parameter.

    Multi-select parameters are wrapped in an array schema whose ``items`` is
    the single-value schema and whose ``description`` mirrors it. A parameter
    is required exactly when its ``default_value`` is ``Undefined.UNDEFINED``.
    """
    item_schema = _base_param_schema(param)
    required = param.default_value is Undefined.UNDEFINED

    if not param.multi_select:
        return item_schema, required

    array_schema = {
        "type": "array",
        "items": item_schema,
        "description": item_schema.get("description"),
    }
    return array_schema, required
66
+
67
+
68
def _widget_schema(widget: Widget) -> dict[str, Any]:
    """Build the JSON object schema describing a widget's parameters.

    Each entry of ``widget.params`` becomes a property (later duplicates of a
    name overwrite earlier ones). Extra properties are disallowed, and a
    ``required`` list is included only when at least one parameter has no
    default.
    """
    entries = [(param.name, *_param_schema(param)) for param in widget.params]
    required = [name for name, _schema, is_required in entries if is_required]

    return {
        "type": "object",
        "title": widget.name,
        "properties": {name: schema for name, schema, _required in entries},
        "additionalProperties": False,
        **({"required": required} if required else {}),
    }
88
+
89
+
90
+ def _slugify(value: str) -> str:
91
+ slug = re.sub(r"[^0-9A-Za-z]+", "_", value).strip("_")
92
+ return slug.lower() or "value"
93
+
94
+
95
def build_widget_tool_name(widget: Widget) -> str:
    """Return the deterministic tool name used to expose ``widget``.

    The name has the form ``openbb_widget_{origin}_{widget_id}``, with both
    parts passed through ``_slugify``.

    Parameters
    ----------
    widget : Widget
        Widget whose ``origin`` and ``widget_id`` form the name.

    Returns
    -------
    str
        The generated tool name. NOTE(review): widgets whose origin/widget_id
        slugify to the same text receive the same name.
    """
    name_parts = (
        "openbb_widget",
        _slugify(widget.origin),
        _slugify(widget.widget_id),
    )
    return "_".join(name_parts)
114
+
115
+
116
def build_widget_tool(widget: Widget) -> Tool:
    """Wrap ``widget`` in a Pydantic AI tool whose execution is deferred.

    The model sees a tool named by ``build_widget_tool_name`` with a JSON
    schema derived from the widget's parameters. Calling it never runs any
    local logic: the handler raises ``CallDeferred`` so the call is executed
    elsewhere (the OpenBB Workspace frontend).

    Parameters
    ----------
    widget : Widget
        Widget to expose; its description (or name, as a fallback) becomes
        the tool description.

    Returns
    -------
    Tool
        Tool built with ``Tool.from_schema`` and ``takes_ctx=True``.
    """
    tool_name = build_widget_tool_name(widget)
    json_schema = _widget_schema(widget)

    async def _call_widget(ctx: RunContext[OpenBBDeps], **input_arguments: Any) -> None:
        # Deferral needs a tool call id; refuse to defer a call without one.
        if ctx.tool_call_id is not None:
            raise CallDeferred
        raise RuntimeError("Deferred widget tools require a tool call id.")

    _call_widget.__name__ = f"call_widget_{widget.uuid}"

    return Tool.from_schema(
        function=_call_widget,
        name=tool_name,
        description=widget.description or widget.name,
        json_schema=json_schema,
        takes_ctx=True,
    )
151
+
152
+
153
class WidgetToolset(FunctionToolset[OpenBBDeps]):
    """Toolset holding one deferred tool per widget.

    Tools are created with ``build_widget_tool``. The toolset also records
    which widget each generated tool name belongs to, available through
    ``widgets_by_tool``.
    """

    def __init__(self, widgets: Sequence[Widget]):
        super().__init__()
        self._widgets_by_tool: dict[str, Widget] = {}

        for widget in widgets:
            deferred_tool = build_widget_tool(widget)
            self.add_tool(deferred_tool)
            # Only remember the widget once add_tool has accepted the tool.
            self._widgets_by_tool[deferred_tool.name] = widget

    @property
    def widgets_by_tool(self) -> Mapping[str, Widget]:
        """Mapping from generated tool name to the widget it wraps."""
        return self._widgets_by_tool
168
+
169
+
170
class VisualizationToolset(FunctionToolset[OpenBBDeps]):
    """Toolset that lets the model create table and chart artifacts.

    Registers ``openbb_create_table`` and ``openbb_create_chart``, thin
    wrappers around the ``openbb_ai`` ``table`` and ``chart`` helpers.
    """

    def __init__(self) -> None:
        super().__init__()

        # NOTE: the signatures and docstrings of the nested functions are
        # presumably surfaced to the model as the tool definitions by
        # pydantic_ai function tools; keep them accurate.
        def _create_table(
            data: list[dict[str, Any]],
            name: str | None = None,
            description: str | None = None,
        ):
            """Create a table artifact to display in OpenBB Workspace."""
            return table(data=data, name=name, description=description)

        def _create_chart(
            type: Literal["line", "bar", "scatter", "pie", "donut"],
            data: list[dict[str, Any]],
            x_key: str | None = None,
            y_keys: list[str] | None = None,
            angle_key: str | None = None,
            callout_label_key: str | None = None,
            name: str | None = None,
            description: str | None = None,
        ):
            """Create a chart artifact (line, bar, scatter, pie, donut).

            Raises
            ------
            ValueError
                If required parameters for the given chart ``type`` are missing.
            """
            axis_chart = type in {"line", "bar", "scatter"}
            radial_chart = type in {"pie", "donut"}

            if axis_chart and not x_key:
                raise ValueError(
                    "x_key is required for line, bar, and scatter charts"
                )
            if axis_chart and not y_keys:
                raise ValueError(
                    "y_keys is required for line, bar, and scatter charts"
                )
            if radial_chart and not angle_key:
                raise ValueError("angle_key is required for pie and donut charts")
            if radial_chart and not callout_label_key:
                raise ValueError(
                    "callout_label_key is required for pie and donut charts"
                )

            return chart(
                type=type,
                data=data,
                x_key=x_key,
                y_keys=y_keys,
                angle_key=angle_key,
                callout_label_key=callout_label_key,
                name=name,
                description=description,
            )

        registrations = (
            ("openbb_create_table", _create_table),
            ("openbb_create_chart", _create_chart),
        )
        for tool_name, handler in registrations:
            self.add_function(handler, name=tool_name)
233
+
234
+
235
def build_widget_toolsets(
    collection: WidgetCollection | None,
) -> tuple[FunctionToolset[OpenBBDeps], ...]:
    """Assemble the toolsets for a widget collection plus visualization tools.

    One ``WidgetToolset`` is built per non-empty priority group, in the order
    primary, secondary, extra, so tool selection can be controlled per group.
    A ``VisualizationToolset`` is always appended last; it is the only
    toolset when ``collection`` is ``None``.

    Parameters
    ----------
    collection : WidgetCollection | None
        Widget collection with priority groups, or None.

    Returns
    -------
    tuple[FunctionToolset[OpenBBDeps], ...]
        Widget toolsets followed by the visualization toolset.
    """
    if collection is None:
        return (VisualizationToolset(),)

    priority_groups = (collection.primary, collection.secondary, collection.extra)
    widget_toolsets: list[FunctionToolset[OpenBBDeps]] = [
        WidgetToolset(group) for group in priority_groups if group
    ]
    return (*widget_toolsets, VisualizationToolset())
@@ -0,0 +1,39 @@
1
+ """Type definitions and protocols for OpenBB Pydantic AI adapter."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import sys
6
+ from typing import Any, Protocol, TypedDict
7
+
8
+ if sys.version_info >= (3, 11):
9
+ from typing import NotRequired
10
+ else:
11
+ from typing_extensions import NotRequired # type: ignore[assignment]
12
+
13
+
14
class SerializedContent(TypedDict):
    """Structure for serialized tool result content.

    This is the shape returned by ``ContentSerializer.serialize_result`` for an
    ``LlmClientFunctionCallResultMessage``.
    """

    # Arguments the function was called with, copied from the message.
    input_arguments: dict[str, Any]
    # Result items; items with ``model_dump`` are stored as JSON-mode dumps
    # (None fields excluded), everything else is stored as-is.
    data: list[Any]
    # Only present when the source message carried a truthy ``extra_state``.
    extra_state: NotRequired[dict[str, Any]]
20
+
21
+
22
class ToolCallMetadata(TypedDict):
    """Metadata for tracking tool calls in flight.

    All keys are required. NOTE(review): no producer or consumer is visible in
    this module; confirm the exact semantics against the adapter code.
    """

    # Identifier of the tracked tool call (presumably the id the model
    # assigned to the call — confirm).
    tool_call_id: str
    # UUID of the widget the call targets, stored as a string.
    widget_uuid: str
    # The widget's ``widget_id`` (a separate field from its UUID).
    widget_id: str
28
+
29
+
30
class TextStreamCallback(Protocol):
    """Protocol for callbacks that mark text as having been streamed.

    Any zero-argument callable returning ``None`` satisfies it structurally.
    """

    def __call__(self) -> None:
        """Mark that text has been streamed."""
        ...
36
+
37
+
38
# Type alias for structured detail entries: either a free-form mapping or a
# plain string. This union is evaluated at import time (it is not an
# annotation), so the ``X | Y`` form requires Python 3.10+.
DetailEntry = dict[str, Any] | str
@@ -0,0 +1,132 @@
1
+ """Utility functions for OpenBB Pydantic AI UI adapter."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import hashlib
6
+ import json
7
+ from collections.abc import Mapping, Sequence
8
+ from typing import Any
9
+
10
+ from ._config import (
11
+ MAX_ARG_DISPLAY_CHARS,
12
+ MAX_ARG_PREVIEW_ITEMS,
13
+ )
14
+ from ._serializers import ContentSerializer
15
+
16
+
17
def hash_tool_call(function: str, input_arguments: dict[str, Any]) -> str:
    """Derive a stable tool call ID from a function name and its arguments.

    The name and arguments are serialized as key-sorted JSON (unknown values
    stringified), hashed with SHA-256, and the first 16 hex characters are
    appended to the function name. Equal inputs therefore always yield the
    same ID, regardless of argument ordering.

    Parameters
    ----------
    function : str
        The name of the function/tool being called.
    input_arguments : dict[str, Any]
        The arguments passed to the tool.

    Returns
    -------
    str
        ``"{function}_{16 hex chars}"``.
    """
    canonical = json.dumps(
        dict(input_arguments=input_arguments, function=function),
        default=str,
        sort_keys=True,
    )
    fingerprint = hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:16]
    return "_".join((function, fingerprint))
42
+
43
+
44
def normalize_args(args: Any) -> dict[str, Any]:
    """Coerce tool call arguments into a ``dict``.

    Parameters
    ----------
    args : Any
        A ``dict`` (returned as-is, same object), any other mapping (copied
        into a new dict), or a JSON ``str``/``bytes`` encoding an object.

    Returns
    -------
    dict[str, Any]
        The arguments, or ``{}`` for ``None``, malformed JSON, JSON that is
        not an object, or any other input type.
    """
    if isinstance(args, dict):
        return args
    if isinstance(args, Mapping):
        # Accept read-only / custom mappings instead of discarding them.
        return dict(args)
    if isinstance(args, (str, bytes, bytearray)):
        try:
            parsed = json.loads(args)
        except ValueError:  # covers JSONDecodeError and invalid UTF-8 bytes
            return {}
        if isinstance(parsed, dict):
            return parsed
    return {}
56
+
57
+
58
def get_str(mapping: Mapping[str, Any], *keys: str) -> str | None:
    """Look up ``keys`` in order and return the first value that is a ``str``.

    Non-string values are skipped; ``None`` is returned when no key yields a
    string.
    """
    candidates = (mapping.get(key) for key in keys)
    return next((value for value in candidates if isinstance(value, str)), None)
65
+
66
+
67
def get_str_list(mapping: Mapping[str, Any], *keys: str) -> list[str] | None:
    """Look up ``keys`` in order and return the first usable list of strings.

    A string value is wrapped in a one-element list. A list value contributes
    its string elements, and is used only if at least one element is a
    string. Returns ``None`` when no key qualifies.
    """
    for key in keys:
        candidate = mapping.get(key)
        if isinstance(candidate, str):
            return [candidate]
        strings = (
            [element for element in candidate if isinstance(element, str)]
            if isinstance(candidate, list)
            else []
        )
        if strings:
            return strings
    return None
78
+
79
+
80
+ def _truncate(value: str, max_chars: int = 160) -> str:
81
+ if len(value) <= max_chars:
82
+ return value
83
+ return value[: max_chars - 3] + "..."
84
+
85
+
86
def _json_dump(value: Any) -> str:
    """JSON-encode ``value`` for previews via ``ContentSerializer.to_json``.

    That helper stringifies unknown objects and falls back to ``str(value)``
    when encoding fails.
    """
    encoded = ContentSerializer.to_json(value)
    return encoded
88
+
89
+
90
def format_arg_value(
    value: Any,
    *,
    max_chars: int = MAX_ARG_DISPLAY_CHARS,
    max_items: int = MAX_ARG_PREVIEW_ITEMS,
) -> str:
    """Summarize nested structures so reasoning details stay readable.

    Strings are truncated. Scalars (int/float/bool/None) are JSON-encoded.
    Mappings and sequences are summarized by size plus a sample of at most
    ``max_items`` entries. Anything else is JSON-encoded (falling back to
    ``str``). Every result is limited to ``max_chars`` characters.

    Parameters
    ----------
    value : Any
        The argument value to summarize.
    max_chars : int
        Maximum length of the returned string.
    max_items : int
        Maximum number of keys/elements sampled; negative values act as 0.

    Returns
    -------
    str
        The preview string.
    """
    from itertools import islice

    if isinstance(value, str):
        return _truncate(value, max_chars)

    if isinstance(value, (int, float, bool)) or value is None:
        # Huge ints can serialize to very long strings; keep the limit.
        return _truncate(_json_dump(value), max_chars)

    limit = max(max_items, 0)

    if isinstance(value, Mapping):
        # Materialize only the sampled keys; len() gives the total count
        # without copying every key of a large mapping.
        preview_keys = list(islice(value, limit))
        preview = {k: value[k] for k in preview_keys}
        suffix = "..." if len(value) > limit else ""
        return _truncate(
            f"dict(keys={preview_keys}{suffix}, sample={_json_dump(preview)})",
            max_chars,
        )

    if isinstance(value, Sequence) and not isinstance(value, (bytes, bytearray)):
        # Avoid copying an entire (possibly huge) sequence such as a long
        # list or range just to show a few items.
        length = len(value)
        preview_items = list(islice(value, limit))
        suffix = "..." if length > limit else ""
        return _truncate(
            f"list(len={length}{suffix}, sample={_json_dump(preview_items)})",
            max_chars,
        )

    return _truncate(_json_dump(value), max_chars)
124
+
125
+
126
def format_args(args: Mapping[str, Any]) -> dict[str, str]:
    """Map each argument name to its ``format_arg_value`` preview string."""
    return {name: format_arg_value(raw) for name, raw in args.items()}
@@ -0,0 +1,145 @@
1
+ """Widget registry for centralized widget discovery and lookup."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Iterator, Mapping, Sequence
6
+ from typing import TYPE_CHECKING
7
+
8
+ from openbb_ai.models import (
9
+ LlmClientFunctionCallResultMessage,
10
+ Widget,
11
+ WidgetCollection,
12
+ )
13
+
14
+ from ._config import GET_WIDGET_DATA_TOOL_NAME
15
+
16
+ if TYPE_CHECKING:
17
+ from pydantic_ai.toolsets import FunctionToolset
18
+
19
+ from ._dependencies import OpenBBDeps
20
+
21
+
22
class WidgetRegistry:
    """Centralized registry for widget discovery and lookup.

    Widgets are indexed by tool name (from toolsets exposing a
    ``widgets_by_tool`` mapping) and by ``str(widget.uuid)`` (from those
    toolsets and from an optional ``WidgetCollection``).
    """

    def __init__(
        self,
        collection: WidgetCollection | None = None,
        toolsets: Sequence[FunctionToolset[OpenBBDeps]] | None = None,
    ):
        """Initialize widget registry from collection and toolsets.

        Parameters
        ----------
        collection : WidgetCollection | None
            Widget collection with priority groups; indexed by UUID only.
        toolsets : Sequence[FunctionToolset[OpenBBDeps]] | None
            Toolsets; those exposing ``widgets_by_tool`` are indexed by tool
            name and UUID, the rest are ignored.
        """
        self._by_tool_name: dict[str, Widget] = {}
        self._by_uuid: dict[str, Widget] = {}

        for toolset in toolsets or ():
            widgets = getattr(toolset, "widgets_by_tool", None)
            if not widgets:
                continue
            for tool_name, widget in widgets.items():
                self._by_tool_name[tool_name] = widget
                self._by_uuid[str(widget.uuid)] = widget

        # Also index from collection if provided
        if collection:
            for widget in self._iter_collection(collection):
                self._by_uuid[str(widget.uuid)] = widget

    @staticmethod
    def _iter_collection(collection: WidgetCollection) -> Iterator[Widget]:
        """Yield widgets from the primary, secondary and extra groups, in order.

        Empty or ``None`` groups are skipped.
        """
        for group in (collection.primary, collection.secondary, collection.extra):
            if group:
                yield from group

    def find_by_tool_name(self, name: str) -> Widget | None:
        """Find a widget by its tool name.

        Parameters
        ----------
        name : str
            The tool name to search for

        Returns
        -------
        Widget | None
            The widget if found, None otherwise
        """
        return self._by_tool_name.get(name)

    def find_by_uuid(self, uuid: object) -> Widget | None:
        """Find a widget by its UUID.

        Parameters
        ----------
        uuid : object
            The UUID string, or any object whose ``str()`` is that string
            (e.g. ``uuid.UUID``); the index is keyed by ``str(widget.uuid)``.

        Returns
        -------
        Widget | None
            The widget if found, None otherwise
        """
        return self._by_uuid.get(str(uuid))

    def find_for_result(
        self, result: LlmClientFunctionCallResultMessage
    ) -> Widget | None:
        """Find the widget that produced a result message.

        A direct tool-name match wins. For ``GET_WIDGET_DATA_TOOL_NAME`` calls,
        the ``widget_uuid`` of the first ``data_sources`` entry is used.
        Malformed ``data_sources`` yield ``None`` instead of raising.

        Parameters
        ----------
        result : LlmClientFunctionCallResultMessage
            The result message to find a widget for

        Returns
        -------
        Widget | None
            The widget if found, None otherwise
        """
        widget = self.find_by_tool_name(result.function)
        if widget is not None:
            return widget

        if result.function != GET_WIDGET_DATA_TOOL_NAME:
            return None

        data_sources = (result.input_arguments or {}).get("data_sources")
        if (
            not isinstance(data_sources, Sequence)
            or isinstance(data_sources, (str, bytes))
            or not data_sources
        ):
            return None

        first_source = data_sources[0]
        if isinstance(first_source, Mapping):
            widget_uuid = first_source.get("widget_uuid")
        else:
            # Tolerate model-like entries exposing the field as an attribute.
            widget_uuid = getattr(first_source, "widget_uuid", None)
        if not widget_uuid:
            return None
        return self.find_by_uuid(widget_uuid)

    def iter_all(self) -> Iterator[Widget]:
        """Iterate all registered widgets, once per UUID, in insertion order.

        Returns
        -------
        Iterator[Widget]
            Iterator over all widgets
        """
        # _by_uuid is keyed by str(widget.uuid), so it is already de-duplicated.
        yield from self._by_uuid.values()

    def as_mapping(self) -> Mapping[str, Widget]:
        """Get widget lookup as a read-only mapping by tool name.

        Returns
        -------
        Mapping[str, Widget]
            Read-only live view from tool names to widgets
        """
        from types import MappingProxyType

        # A proxy prevents callers from mutating the registry's internal index.
        return MappingProxyType(self._by_tool_name)