openbb-pydantic-ai 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openbb_pydantic_ai/__init__.py +52 -0
- openbb_pydantic_ai/_adapter.py +307 -0
- openbb_pydantic_ai/_config.py +58 -0
- openbb_pydantic_ai/_dependencies.py +73 -0
- openbb_pydantic_ai/_event_builder.py +183 -0
- openbb_pydantic_ai/_event_stream.py +363 -0
- openbb_pydantic_ai/_event_stream_components.py +155 -0
- openbb_pydantic_ai/_event_stream_helpers.py +422 -0
- openbb_pydantic_ai/_exceptions.py +61 -0
- openbb_pydantic_ai/_message_transformer.py +127 -0
- openbb_pydantic_ai/_serializers.py +110 -0
- openbb_pydantic_ai/_toolsets.py +264 -0
- openbb_pydantic_ai/_types.py +39 -0
- openbb_pydantic_ai/_utils.py +132 -0
- openbb_pydantic_ai/_widget_registry.py +145 -0
- openbb_pydantic_ai-0.1.1.dist-info/METADATA +139 -0
- openbb_pydantic_ai-0.1.1.dist-info/RECORD +18 -0
- openbb_pydantic_ai-0.1.1.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
"""Content serialization utilities for OpenBB Pydantic AI adapter."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from typing import Any, cast
|
|
7
|
+
|
|
8
|
+
from openbb_ai.models import LlmClientFunctionCallResultMessage
|
|
9
|
+
|
|
10
|
+
from ._types import SerializedContent
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ContentSerializer:
|
|
14
|
+
"""Handles serialization and parsing of content across the adapter."""
|
|
15
|
+
|
|
16
|
+
@staticmethod
|
|
17
|
+
def serialize_result(
|
|
18
|
+
message: LlmClientFunctionCallResultMessage,
|
|
19
|
+
) -> SerializedContent:
|
|
20
|
+
"""Serialize a function call result message into a content dictionary.
|
|
21
|
+
|
|
22
|
+
Parameters
|
|
23
|
+
----------
|
|
24
|
+
message : LlmClientFunctionCallResultMessage
|
|
25
|
+
The function call result message to serialize
|
|
26
|
+
|
|
27
|
+
Returns
|
|
28
|
+
-------
|
|
29
|
+
SerializedContent
|
|
30
|
+
A typed dictionary containing input_arguments, data, and
|
|
31
|
+
optionally extra_state
|
|
32
|
+
"""
|
|
33
|
+
data: list[Any] = []
|
|
34
|
+
for item in message.data:
|
|
35
|
+
if hasattr(item, "model_dump"):
|
|
36
|
+
data.append(item.model_dump(mode="json", exclude_none=True))
|
|
37
|
+
else:
|
|
38
|
+
data.append(item)
|
|
39
|
+
|
|
40
|
+
content: SerializedContent = cast(
|
|
41
|
+
SerializedContent,
|
|
42
|
+
{
|
|
43
|
+
"input_arguments": message.input_arguments,
|
|
44
|
+
"data": data,
|
|
45
|
+
},
|
|
46
|
+
)
|
|
47
|
+
if message.extra_state:
|
|
48
|
+
content["extra_state"] = message.extra_state
|
|
49
|
+
return content
|
|
50
|
+
|
|
51
|
+
@staticmethod
|
|
52
|
+
def parse_json(raw_content: str) -> Any:
|
|
53
|
+
"""Parse JSON content, returning the original string if parsing fails.
|
|
54
|
+
|
|
55
|
+
Parameters
|
|
56
|
+
----------
|
|
57
|
+
raw_content : str
|
|
58
|
+
The raw JSON string to parse
|
|
59
|
+
|
|
60
|
+
Returns
|
|
61
|
+
-------
|
|
62
|
+
Any
|
|
63
|
+
Parsed JSON object or original string if parsing fails
|
|
64
|
+
"""
|
|
65
|
+
try:
|
|
66
|
+
return json.loads(raw_content)
|
|
67
|
+
except (json.JSONDecodeError, ValueError):
|
|
68
|
+
return raw_content
|
|
69
|
+
|
|
70
|
+
@staticmethod
|
|
71
|
+
def to_string(content: Any) -> str | None:
|
|
72
|
+
"""Convert content to string with JSON fallback.
|
|
73
|
+
|
|
74
|
+
Parameters
|
|
75
|
+
----------
|
|
76
|
+
content : Any
|
|
77
|
+
Content to stringify
|
|
78
|
+
|
|
79
|
+
Returns
|
|
80
|
+
-------
|
|
81
|
+
str | None
|
|
82
|
+
String representation or None if content is None
|
|
83
|
+
"""
|
|
84
|
+
if content is None:
|
|
85
|
+
return None
|
|
86
|
+
if isinstance(content, str):
|
|
87
|
+
return content
|
|
88
|
+
try:
|
|
89
|
+
return json.dumps(content, default=str)
|
|
90
|
+
except (TypeError, ValueError):
|
|
91
|
+
return str(content)
|
|
92
|
+
|
|
93
|
+
@staticmethod
|
|
94
|
+
def to_json(value: Any) -> str:
|
|
95
|
+
"""Convert value to JSON string with fallback to str().
|
|
96
|
+
|
|
97
|
+
Parameters
|
|
98
|
+
----------
|
|
99
|
+
value : Any
|
|
100
|
+
Value to convert to JSON
|
|
101
|
+
|
|
102
|
+
Returns
|
|
103
|
+
-------
|
|
104
|
+
str
|
|
105
|
+
JSON string representation
|
|
106
|
+
"""
|
|
107
|
+
try:
|
|
108
|
+
return json.dumps(value, default=str)
|
|
109
|
+
except (TypeError, ValueError):
|
|
110
|
+
return str(value)
|
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
"""Toolset implementations for OpenBB widgets and visualization."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
from collections.abc import Mapping, Sequence
|
|
7
|
+
from typing import Any, Literal
|
|
8
|
+
|
|
9
|
+
from openbb_ai.helpers import chart, table
|
|
10
|
+
from openbb_ai.models import Undefined, Widget, WidgetCollection, WidgetParam
|
|
11
|
+
from pydantic_ai import CallDeferred, Tool
|
|
12
|
+
from pydantic_ai.tools import RunContext
|
|
13
|
+
from pydantic_ai.toolsets import FunctionToolset
|
|
14
|
+
|
|
15
|
+
from ._dependencies import OpenBBDeps
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _base_param_schema(param: WidgetParam) -> dict[str, Any]:
|
|
19
|
+
"""Build the base JSON schema for a widget parameter."""
|
|
20
|
+
type_mapping: dict[str, dict[str, Any]] = {
|
|
21
|
+
"string": {"type": "string"},
|
|
22
|
+
"text": {"type": "string"},
|
|
23
|
+
"number": {"type": "number"},
|
|
24
|
+
"integer": {"type": "integer"},
|
|
25
|
+
"boolean": {"type": "boolean"},
|
|
26
|
+
"date": {"type": "string", "format": "date"},
|
|
27
|
+
"ticker": {"type": "string"},
|
|
28
|
+
"endpoint": {"type": "string"},
|
|
29
|
+
}
|
|
30
|
+
|
|
31
|
+
schema = type_mapping.get(param.type, {"type": "string"})
|
|
32
|
+
schema = dict(schema) # copy
|
|
33
|
+
schema["description"] = param.description
|
|
34
|
+
|
|
35
|
+
if param.options:
|
|
36
|
+
schema["enum"] = list(param.options)
|
|
37
|
+
|
|
38
|
+
if param.get_options:
|
|
39
|
+
schema.setdefault(
|
|
40
|
+
"description",
|
|
41
|
+
param.description + " (options retrieved dynamically)",
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
if param.default_value is not Undefined.UNDEFINED:
|
|
45
|
+
schema["default"] = param.default_value
|
|
46
|
+
|
|
47
|
+
if param.current_value is not None and param.multi_select is False:
|
|
48
|
+
schema.setdefault("examples", []).append(param.current_value)
|
|
49
|
+
|
|
50
|
+
return schema
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _param_schema(param: WidgetParam) -> tuple[dict[str, Any], bool]:
|
|
54
|
+
"""Return the schema for a parameter and whether it's required."""
|
|
55
|
+
schema = _base_param_schema(param)
|
|
56
|
+
|
|
57
|
+
if param.multi_select:
|
|
58
|
+
schema = {
|
|
59
|
+
"type": "array",
|
|
60
|
+
"items": schema,
|
|
61
|
+
"description": schema.get("description"),
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
is_required = param.default_value is Undefined.UNDEFINED
|
|
65
|
+
return schema, is_required
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _widget_schema(widget: Widget) -> dict[str, Any]:
|
|
69
|
+
properties: dict[str, Any] = {}
|
|
70
|
+
required: list[str] = []
|
|
71
|
+
|
|
72
|
+
for param in widget.params:
|
|
73
|
+
schema, is_required = _param_schema(param)
|
|
74
|
+
properties[param.name] = schema
|
|
75
|
+
if is_required:
|
|
76
|
+
required.append(param.name)
|
|
77
|
+
|
|
78
|
+
widget_schema: dict[str, Any] = {
|
|
79
|
+
"type": "object",
|
|
80
|
+
"title": widget.name,
|
|
81
|
+
"properties": properties,
|
|
82
|
+
"additionalProperties": False,
|
|
83
|
+
}
|
|
84
|
+
if required:
|
|
85
|
+
widget_schema["required"] = required
|
|
86
|
+
|
|
87
|
+
return widget_schema
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def _slugify(value: str) -> str:
|
|
91
|
+
slug = re.sub(r"[^0-9A-Za-z]+", "_", value).strip("_")
|
|
92
|
+
return slug.lower() or "value"
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def build_widget_tool_name(widget: Widget) -> str:
|
|
96
|
+
"""Generate a deterministic tool name for a widget.
|
|
97
|
+
|
|
98
|
+
The tool name is constructed as: openbb_widget_{origin}_{widget_id}
|
|
99
|
+
where both origin and widget_id are slugified.
|
|
100
|
+
|
|
101
|
+
Parameters
|
|
102
|
+
----------
|
|
103
|
+
widget : Widget
|
|
104
|
+
The widget to generate a tool name for
|
|
105
|
+
|
|
106
|
+
Returns
|
|
107
|
+
-------
|
|
108
|
+
str
|
|
109
|
+
A unique, deterministic tool name string
|
|
110
|
+
"""
|
|
111
|
+
origin_slug = _slugify(widget.origin)
|
|
112
|
+
widget_slug = _slugify(widget.widget_id)
|
|
113
|
+
return f"openbb_widget_{origin_slug}_{widget_slug}"
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def build_widget_tool(widget: Widget) -> Tool:
|
|
117
|
+
"""Create a deferred tool for a widget.
|
|
118
|
+
|
|
119
|
+
This creates a Pydantic AI tool that will be called by the LLM but
|
|
120
|
+
executed by the OpenBB Workspace frontend (deferred execution).
|
|
121
|
+
|
|
122
|
+
Parameters
|
|
123
|
+
----------
|
|
124
|
+
widget : Widget
|
|
125
|
+
The widget to create a tool for
|
|
126
|
+
|
|
127
|
+
Returns
|
|
128
|
+
-------
|
|
129
|
+
Tool
|
|
130
|
+
A Tool configured for deferred execution
|
|
131
|
+
"""
|
|
132
|
+
tool_name = build_widget_tool_name(widget)
|
|
133
|
+
schema = _widget_schema(widget)
|
|
134
|
+
description = widget.description or widget.name
|
|
135
|
+
|
|
136
|
+
async def _call_widget(ctx: RunContext[OpenBBDeps], **input_arguments: Any) -> None:
|
|
137
|
+
# Ensure we have a tool call id for deferred execution
|
|
138
|
+
if ctx.tool_call_id is None:
|
|
139
|
+
raise RuntimeError("Deferred widget tools require a tool call id.")
|
|
140
|
+
raise CallDeferred
|
|
141
|
+
|
|
142
|
+
_call_widget.__name__ = f"call_widget_{widget.uuid}"
|
|
143
|
+
|
|
144
|
+
return Tool.from_schema(
|
|
145
|
+
function=_call_widget,
|
|
146
|
+
name=tool_name,
|
|
147
|
+
description=description,
|
|
148
|
+
json_schema=schema,
|
|
149
|
+
takes_ctx=True,
|
|
150
|
+
)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
class WidgetToolset(FunctionToolset[OpenBBDeps]):
|
|
154
|
+
"""Toolset that exposes widgets as deferred tools."""
|
|
155
|
+
|
|
156
|
+
def __init__(self, widgets: Sequence[Widget]):
|
|
157
|
+
super().__init__()
|
|
158
|
+
self._widgets_by_tool: dict[str, Widget] = {}
|
|
159
|
+
|
|
160
|
+
for widget in widgets:
|
|
161
|
+
tool = build_widget_tool(widget)
|
|
162
|
+
self.add_tool(tool)
|
|
163
|
+
self._widgets_by_tool[tool.name] = widget
|
|
164
|
+
|
|
165
|
+
@property
|
|
166
|
+
def widgets_by_tool(self) -> Mapping[str, Widget]:
|
|
167
|
+
return self._widgets_by_tool
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
class VisualizationToolset(FunctionToolset[OpenBBDeps]):
|
|
171
|
+
"""Toolset exposing helper utilities for charts and tables."""
|
|
172
|
+
|
|
173
|
+
def __init__(self) -> None:
|
|
174
|
+
super().__init__()
|
|
175
|
+
|
|
176
|
+
def _create_table(
|
|
177
|
+
data: list[dict[str, Any]],
|
|
178
|
+
name: str | None = None,
|
|
179
|
+
description: str | None = None,
|
|
180
|
+
):
|
|
181
|
+
"""Create a table artifact to display in OpenBB Workspace."""
|
|
182
|
+
|
|
183
|
+
return table(data=data, name=name, description=description)
|
|
184
|
+
|
|
185
|
+
def _create_chart(
|
|
186
|
+
type: Literal["line", "bar", "scatter", "pie", "donut"],
|
|
187
|
+
data: list[dict[str, Any]],
|
|
188
|
+
x_key: str | None = None,
|
|
189
|
+
y_keys: list[str] | None = None,
|
|
190
|
+
angle_key: str | None = None,
|
|
191
|
+
callout_label_key: str | None = None,
|
|
192
|
+
name: str | None = None,
|
|
193
|
+
description: str | None = None,
|
|
194
|
+
):
|
|
195
|
+
"""Create a chart artifact (line, bar, scatter, pie, donut).
|
|
196
|
+
|
|
197
|
+
Raises
|
|
198
|
+
------
|
|
199
|
+
ValueError
|
|
200
|
+
If required parameters for the given chart ``type`` are missing.
|
|
201
|
+
"""
|
|
202
|
+
|
|
203
|
+
if type in {"line", "bar", "scatter"}:
|
|
204
|
+
if not x_key:
|
|
205
|
+
raise ValueError(
|
|
206
|
+
"x_key is required for line, bar, and scatter charts"
|
|
207
|
+
)
|
|
208
|
+
if not y_keys:
|
|
209
|
+
raise ValueError(
|
|
210
|
+
"y_keys is required for line, bar, and scatter charts"
|
|
211
|
+
)
|
|
212
|
+
elif type in {"pie", "donut"}:
|
|
213
|
+
if not angle_key:
|
|
214
|
+
raise ValueError("angle_key is required for pie and donut charts")
|
|
215
|
+
if not callout_label_key:
|
|
216
|
+
raise ValueError(
|
|
217
|
+
"callout_label_key is required for pie and donut charts"
|
|
218
|
+
)
|
|
219
|
+
|
|
220
|
+
return chart(
|
|
221
|
+
type=type,
|
|
222
|
+
data=data,
|
|
223
|
+
x_key=x_key,
|
|
224
|
+
y_keys=y_keys,
|
|
225
|
+
angle_key=angle_key,
|
|
226
|
+
callout_label_key=callout_label_key,
|
|
227
|
+
name=name,
|
|
228
|
+
description=description,
|
|
229
|
+
)
|
|
230
|
+
|
|
231
|
+
self.add_function(_create_table, name="openbb_create_table")
|
|
232
|
+
self.add_function(_create_chart, name="openbb_create_chart")
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def build_widget_toolsets(
|
|
236
|
+
collection: WidgetCollection | None,
|
|
237
|
+
) -> tuple[FunctionToolset[OpenBBDeps], ...]:
|
|
238
|
+
"""Create toolsets for each widget priority group plus visualization tools.
|
|
239
|
+
|
|
240
|
+
Widgets are organized into separate toolsets by priority (primary, secondary, extra)
|
|
241
|
+
to allow control over tool selection. The visualization toolset is always
|
|
242
|
+
included for creating charts and tables.
|
|
243
|
+
|
|
244
|
+
Parameters
|
|
245
|
+
----------
|
|
246
|
+
collection : WidgetCollection | None
|
|
247
|
+
Widget collection with priority groups, or None
|
|
248
|
+
|
|
249
|
+
Returns
|
|
250
|
+
-------
|
|
251
|
+
tuple[FunctionToolset[OpenBBDeps], ...]
|
|
252
|
+
Toolsets including widget toolsets and visualization toolset
|
|
253
|
+
"""
|
|
254
|
+
if collection is None:
|
|
255
|
+
return (VisualizationToolset(),)
|
|
256
|
+
|
|
257
|
+
toolsets: list[FunctionToolset[OpenBBDeps]] = []
|
|
258
|
+
for widgets in (collection.primary, collection.secondary, collection.extra):
|
|
259
|
+
if widgets:
|
|
260
|
+
toolsets.append(WidgetToolset(widgets))
|
|
261
|
+
|
|
262
|
+
toolsets.append(VisualizationToolset())
|
|
263
|
+
|
|
264
|
+
return tuple(toolsets)
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""Type definitions and protocols for OpenBB Pydantic AI adapter."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import sys
|
|
6
|
+
from typing import Any, Protocol, TypedDict
|
|
7
|
+
|
|
8
|
+
if sys.version_info >= (3, 11):
|
|
9
|
+
from typing import NotRequired
|
|
10
|
+
else:
|
|
11
|
+
from typing_extensions import NotRequired # type: ignore[assignment]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SerializedContent(TypedDict):
|
|
15
|
+
"""Structure for serialized tool result content."""
|
|
16
|
+
|
|
17
|
+
input_arguments: dict[str, Any]
|
|
18
|
+
data: list[Any]
|
|
19
|
+
extra_state: NotRequired[dict[str, Any]]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ToolCallMetadata(TypedDict):
|
|
23
|
+
"""Metadata for tracking tool calls in flight."""
|
|
24
|
+
|
|
25
|
+
tool_call_id: str
|
|
26
|
+
widget_uuid: str
|
|
27
|
+
widget_id: str
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class TextStreamCallback(Protocol):
|
|
31
|
+
"""Protocol for callbacks that mark text as having been streamed."""
|
|
32
|
+
|
|
33
|
+
def __call__(self) -> None:
|
|
34
|
+
"""Mark that text has been streamed."""
|
|
35
|
+
...
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# Type alias for structured detail entries
|
|
39
|
+
DetailEntry = dict[str, Any] | str
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
"""Utility functions for OpenBB Pydantic AI UI adapter."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import hashlib
|
|
6
|
+
import json
|
|
7
|
+
from collections.abc import Mapping, Sequence
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
from ._config import (
|
|
11
|
+
MAX_ARG_DISPLAY_CHARS,
|
|
12
|
+
MAX_ARG_PREVIEW_ITEMS,
|
|
13
|
+
)
|
|
14
|
+
from ._serializers import ContentSerializer
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def hash_tool_call(function: str, input_arguments: dict[str, Any]) -> str:
|
|
18
|
+
"""Generate a deterministic hash-based ID for a tool call.
|
|
19
|
+
|
|
20
|
+
This creates a unique identifier by hashing the function name and arguments,
|
|
21
|
+
ensuring consistent tool call IDs across message history and deferred results.
|
|
22
|
+
|
|
23
|
+
Parameters
|
|
24
|
+
----------
|
|
25
|
+
function : str
|
|
26
|
+
The name of the function/tool being called
|
|
27
|
+
input_arguments : dict[str, Any]
|
|
28
|
+
The arguments passed to the tool
|
|
29
|
+
|
|
30
|
+
Returns
|
|
31
|
+
-------
|
|
32
|
+
str
|
|
33
|
+
A string combining the function name with a 16-character hash digest
|
|
34
|
+
"""
|
|
35
|
+
payload = json.dumps(
|
|
36
|
+
{"function": function, "input_arguments": input_arguments},
|
|
37
|
+
sort_keys=True,
|
|
38
|
+
default=str,
|
|
39
|
+
)
|
|
40
|
+
digest = hashlib.sha256(payload.encode("utf-8")).hexdigest()
|
|
41
|
+
return f"{function}_{digest[:16]}"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def normalize_args(args: Any) -> dict[str, Any]:
|
|
45
|
+
"""Normalize tool call arguments to a dictionary."""
|
|
46
|
+
if isinstance(args, dict):
|
|
47
|
+
return args
|
|
48
|
+
if isinstance(args, str):
|
|
49
|
+
try:
|
|
50
|
+
parsed = json.loads(args)
|
|
51
|
+
if isinstance(parsed, dict):
|
|
52
|
+
return parsed
|
|
53
|
+
except ValueError:
|
|
54
|
+
pass
|
|
55
|
+
return {}
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def get_str(mapping: Mapping[str, Any], *keys: str) -> str | None:
|
|
59
|
+
"""Return the first string value found for the given keys."""
|
|
60
|
+
for key in keys:
|
|
61
|
+
value = mapping.get(key)
|
|
62
|
+
if isinstance(value, str):
|
|
63
|
+
return value
|
|
64
|
+
return None
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def get_str_list(mapping: Mapping[str, Any], *keys: str) -> list[str] | None:
|
|
68
|
+
"""Return the first list of strings (or single string) found for the keys."""
|
|
69
|
+
for key in keys:
|
|
70
|
+
value = mapping.get(key)
|
|
71
|
+
if isinstance(value, str):
|
|
72
|
+
return [value]
|
|
73
|
+
if isinstance(value, list):
|
|
74
|
+
items = [item for item in value if isinstance(item, str)]
|
|
75
|
+
if items:
|
|
76
|
+
return items
|
|
77
|
+
return None
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _truncate(value: str, max_chars: int = 160) -> str:
|
|
81
|
+
if len(value) <= max_chars:
|
|
82
|
+
return value
|
|
83
|
+
return value[: max_chars - 3] + "..."
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _json_dump(value: Any) -> str:
|
|
87
|
+
return ContentSerializer.to_json(value)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def format_arg_value(
|
|
91
|
+
value: Any,
|
|
92
|
+
*,
|
|
93
|
+
max_chars: int = MAX_ARG_DISPLAY_CHARS,
|
|
94
|
+
max_items: int = MAX_ARG_PREVIEW_ITEMS,
|
|
95
|
+
) -> str:
|
|
96
|
+
"""Summarize nested structures so reasoning details stay readable."""
|
|
97
|
+
|
|
98
|
+
if isinstance(value, str):
|
|
99
|
+
return _truncate(value, max_chars)
|
|
100
|
+
|
|
101
|
+
if isinstance(value, (int, float, bool)) or value is None:
|
|
102
|
+
return _json_dump(value)
|
|
103
|
+
|
|
104
|
+
if isinstance(value, Mapping):
|
|
105
|
+
keys = list(value.keys())
|
|
106
|
+
preview_keys = keys[:max_items]
|
|
107
|
+
preview = {k: value[k] for k in preview_keys}
|
|
108
|
+
suffix = "..." if len(keys) > max_items else ""
|
|
109
|
+
return _truncate(
|
|
110
|
+
f"dict(keys={preview_keys}{suffix}, sample={_json_dump(preview)})",
|
|
111
|
+
max_chars,
|
|
112
|
+
)
|
|
113
|
+
|
|
114
|
+
if isinstance(value, Sequence) and not isinstance(value, (bytes, bytearray)):
|
|
115
|
+
seq = list(value)
|
|
116
|
+
preview = seq[:max_items]
|
|
117
|
+
suffix = "..." if len(seq) > max_items else ""
|
|
118
|
+
return _truncate(
|
|
119
|
+
f"list(len={len(seq)}{suffix}, sample={_json_dump(preview)})",
|
|
120
|
+
max_chars,
|
|
121
|
+
)
|
|
122
|
+
|
|
123
|
+
return _truncate(_json_dump(value), max_chars)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def format_args(args: Mapping[str, Any]) -> dict[str, str]:
|
|
127
|
+
"""Format a mapping of arguments into readable key/value strings."""
|
|
128
|
+
|
|
129
|
+
formatted: dict[str, str] = {}
|
|
130
|
+
for key, value in args.items():
|
|
131
|
+
formatted[key] = format_arg_value(value)
|
|
132
|
+
return formatted
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
"""Widget registry for centralized widget discovery and lookup."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Iterator, Mapping, Sequence
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
from openbb_ai.models import (
|
|
9
|
+
LlmClientFunctionCallResultMessage,
|
|
10
|
+
Widget,
|
|
11
|
+
WidgetCollection,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
from ._config import GET_WIDGET_DATA_TOOL_NAME
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from pydantic_ai.toolsets import FunctionToolset
|
|
18
|
+
|
|
19
|
+
from ._dependencies import OpenBBDeps
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class WidgetRegistry:
|
|
23
|
+
"""Centralized registry for widget discovery and lookup."""
|
|
24
|
+
|
|
25
|
+
def __init__(
|
|
26
|
+
self,
|
|
27
|
+
collection: WidgetCollection | None = None,
|
|
28
|
+
toolsets: Sequence[FunctionToolset[OpenBBDeps]] | None = None,
|
|
29
|
+
):
|
|
30
|
+
"""Initialize widget registry from collection and toolsets.
|
|
31
|
+
|
|
32
|
+
Parameters
|
|
33
|
+
----------
|
|
34
|
+
collection : WidgetCollection | None
|
|
35
|
+
Widget collection with priority groups
|
|
36
|
+
toolsets : Sequence[FunctionToolset[OpenBBDeps]] | None
|
|
37
|
+
Widget toolsets
|
|
38
|
+
"""
|
|
39
|
+
self._by_tool_name: dict[str, Widget] = {}
|
|
40
|
+
self._by_uuid: dict[str, Widget] = {}
|
|
41
|
+
|
|
42
|
+
# Build lookup from toolsets
|
|
43
|
+
if toolsets:
|
|
44
|
+
for toolset in toolsets:
|
|
45
|
+
widgets = getattr(toolset, "widgets_by_tool", None)
|
|
46
|
+
if widgets:
|
|
47
|
+
for tool_name, widget in widgets.items():
|
|
48
|
+
self._by_tool_name[tool_name] = widget
|
|
49
|
+
self._by_uuid[str(widget.uuid)] = widget
|
|
50
|
+
|
|
51
|
+
# Also index from collection if provided
|
|
52
|
+
if collection:
|
|
53
|
+
for widget in self._iter_collection(collection):
|
|
54
|
+
self._by_uuid[str(widget.uuid)] = widget
|
|
55
|
+
|
|
56
|
+
@staticmethod
|
|
57
|
+
def _iter_collection(collection: WidgetCollection) -> Iterator[Widget]:
|
|
58
|
+
"""Iterate all widgets in a collection."""
|
|
59
|
+
for group in (collection.primary, collection.secondary, collection.extra):
|
|
60
|
+
yield from group
|
|
61
|
+
|
|
62
|
+
def find_by_tool_name(self, name: str) -> Widget | None:
|
|
63
|
+
"""Find a widget by its tool name.
|
|
64
|
+
|
|
65
|
+
Parameters
|
|
66
|
+
----------
|
|
67
|
+
name : str
|
|
68
|
+
The tool name to search for
|
|
69
|
+
|
|
70
|
+
Returns
|
|
71
|
+
-------
|
|
72
|
+
Widget | None
|
|
73
|
+
The widget if found, None otherwise
|
|
74
|
+
"""
|
|
75
|
+
return self._by_tool_name.get(name)
|
|
76
|
+
|
|
77
|
+
def find_by_uuid(self, uuid: str) -> Widget | None:
|
|
78
|
+
"""Find a widget by its UUID string.
|
|
79
|
+
|
|
80
|
+
Parameters
|
|
81
|
+
----------
|
|
82
|
+
uuid : str
|
|
83
|
+
The UUID to search for
|
|
84
|
+
|
|
85
|
+
Returns
|
|
86
|
+
-------
|
|
87
|
+
Widget | None
|
|
88
|
+
The widget if found, None otherwise
|
|
89
|
+
"""
|
|
90
|
+
return self._by_uuid.get(uuid)
|
|
91
|
+
|
|
92
|
+
def find_for_result(
|
|
93
|
+
self, result: LlmClientFunctionCallResultMessage
|
|
94
|
+
) -> Widget | None:
|
|
95
|
+
"""Find the widget that produced a result message.
|
|
96
|
+
|
|
97
|
+
Parameters
|
|
98
|
+
----------
|
|
99
|
+
result : LlmClientFunctionCallResultMessage
|
|
100
|
+
The result message to find a widget for
|
|
101
|
+
|
|
102
|
+
Returns
|
|
103
|
+
-------
|
|
104
|
+
Widget | None
|
|
105
|
+
The widget if found, None otherwise
|
|
106
|
+
"""
|
|
107
|
+
# Check direct tool name match
|
|
108
|
+
widget = self.find_by_tool_name(result.function)
|
|
109
|
+
if widget is not None:
|
|
110
|
+
return widget
|
|
111
|
+
|
|
112
|
+
# Check if it's a get_widget_data call
|
|
113
|
+
if result.function == GET_WIDGET_DATA_TOOL_NAME:
|
|
114
|
+
data_sources = result.input_arguments.get("data_sources", [])
|
|
115
|
+
if data_sources:
|
|
116
|
+
widget_uuid = data_sources[0].get("widget_uuid")
|
|
117
|
+
if widget_uuid:
|
|
118
|
+
return self.find_by_uuid(widget_uuid)
|
|
119
|
+
|
|
120
|
+
return None
|
|
121
|
+
|
|
122
|
+
def iter_all(self) -> Iterator[Widget]:
|
|
123
|
+
"""Iterate all registered widgets.
|
|
124
|
+
|
|
125
|
+
Returns
|
|
126
|
+
-------
|
|
127
|
+
Iterator[Widget]
|
|
128
|
+
Iterator over all widgets
|
|
129
|
+
"""
|
|
130
|
+
# Use dict to deduplicate by UUID
|
|
131
|
+
seen = set()
|
|
132
|
+
for widget in self._by_uuid.values():
|
|
133
|
+
if str(widget.uuid) not in seen:
|
|
134
|
+
seen.add(str(widget.uuid))
|
|
135
|
+
yield widget
|
|
136
|
+
|
|
137
|
+
def as_mapping(self) -> Mapping[str, Widget]:
|
|
138
|
+
"""Get widget lookup as a read-only mapping by tool name.
|
|
139
|
+
|
|
140
|
+
Returns
|
|
141
|
+
-------
|
|
142
|
+
Mapping[str, Widget]
|
|
143
|
+
Read-only mapping from tool names to widgets
|
|
144
|
+
"""
|
|
145
|
+
return self._by_tool_name
|