openbb-pydantic-ai 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,422 @@
1
+ """Helper utilities for OpenBB event stream transformations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from dataclasses import dataclass
7
+ from typing import Any, Literal, Mapping, cast
8
+ from uuid import uuid4
9
+
10
+ from openbb_ai.helpers import chart, message_chunk, reasoning_step, table
11
+ from openbb_ai.models import (
12
+ SSE,
13
+ ClientArtifact,
14
+ LlmClientFunctionCallResultMessage,
15
+ MessageArtifactSSE,
16
+ Widget,
17
+ )
18
+
19
+ from ._config import GET_WIDGET_DATA_TOOL_NAME
20
+ from ._event_builder import EventBuilder
21
+ from ._serializers import ContentSerializer
22
+ from ._types import SerializedContent, TextStreamCallback
23
+ from ._utils import (
24
+ format_arg_value,
25
+ format_args,
26
+ get_str,
27
+ get_str_list,
28
+ )
29
+
30
+
31
+ @dataclass(slots=True)
32
+ class ToolCallInfo:
33
+ """Metadata captured when a tool call event is received.
34
+
35
+ Attributes
36
+ ----------
37
+ tool_name : str
38
+ Name of the tool being called
39
+ args : dict[str, Any]
40
+ Arguments passed to the tool
41
+ widget : Widget | None
42
+ Associated widget if this is a widget tool call, None otherwise
43
+ """
44
+
45
+ tool_name: str
46
+ args: dict[str, Any]
47
+ widget: Widget | None = None
48
+
49
+
50
def find_widget_for_result(
    result_message: LlmClientFunctionCallResultMessage,
    widget_lookup: Mapping[str, Widget],
) -> Widget | None:
    """Locate the widget that produced a deferred result message.

    First tries a direct lookup by the result's function (tool) name. If that
    fails and the result comes from a ``get_widget_data`` call, the
    ``widget_uuid`` of the first entry in ``data_sources`` is matched against
    the UUIDs of the known widgets.

    Parameters
    ----------
    result_message : LlmClientFunctionCallResultMessage
        The result message to find a widget for
    widget_lookup : Mapping[str, Widget]
        Mapping from tool names to widgets

    Returns
    -------
    Widget | None
        The widget that produced the result, or None if not found (including
        when the data source entry is malformed or has no ``widget_uuid``)
    """
    widget = widget_lookup.get(result_message.function)
    if widget is not None:
        return widget

    if result_message.function != GET_WIDGET_DATA_TOOL_NAME:
        return None

    data_sources = result_message.input_arguments.get("data_sources") or []
    if not isinstance(data_sources, (list, tuple)) or not data_sources:
        return None

    first_source = data_sources[0]
    # Guard against malformed payloads instead of raising AttributeError.
    if not isinstance(first_source, Mapping):
        return None

    widget_uuid = first_source.get("widget_uuid")
    if widget_uuid is None:
        return None

    # Normalise both sides to ``str`` so a UUID object and its string form
    # compare equal.
    target_uuid = str(widget_uuid)
    for candidate in widget_lookup.values():
        if str(candidate.uuid) == target_uuid:
            return candidate

    return None
85
+
86
+
87
def extract_widget_args(
    result_message: LlmClientFunctionCallResultMessage,
) -> dict[str, Any]:
    """Extract the arguments originally supplied to a widget invocation.

    For ``get_widget_data`` calls, returns the ``input_args`` of the first
    data source. If that key is missing or ``None``, returns an empty dict.
    For direct widget calls, or a ``get_widget_data`` call whose first data
    source is absent or not a mapping, returns ``input_arguments`` unchanged.

    Parameters
    ----------
    result_message : LlmClientFunctionCallResultMessage
        The result message to extract arguments from

    Returns
    -------
    dict[str, Any]
        The widget invocation arguments
    """
    if result_message.function == GET_WIDGET_DATA_TOOL_NAME:
        data_sources = result_message.input_arguments.get("data_sources", [])
        if data_sources and isinstance(data_sources[0], Mapping):
            # ``or {}`` also covers an explicit ``"input_args": None``, which
            # would otherwise leak ``None`` out of a dict-typed function.
            return data_sources[0].get("input_args") or {}
    return result_message.input_arguments
110
+
111
+
112
def serialized_content_from_result(
    result_message: LlmClientFunctionCallResultMessage,
) -> SerializedContent:
    """Convert a function-call result message into structured content.

    Delegates entirely to ``ContentSerializer.serialize_result``. The wrapper
    exists so event-stream code reads in terms of its own intent.

    Parameters
    ----------
    result_message : LlmClientFunctionCallResultMessage
        Result message to be serialized

    Returns
    -------
    SerializedContent
        Typed dictionary with input_arguments, data, and optional extra_state
    """
    serialized: SerializedContent = ContentSerializer.serialize_result(
        result_message
    )
    return serialized
131
+
132
+
133
def handle_generic_tool_result(
    info: ToolCallInfo,
    content: Any,
    *,
    mark_streamed_text: TextStreamCallback,
) -> list[SSE]:
    """Build the SSE events describing a non-widget tool result.

    Three strategies are tried in order:

    1. Structured ``data`` entries (via ``tool_result_events_from_content``),
       prefixed with a "returned" reasoning step.
    2. A chart/table artifact detected by ``artifact_from_output``.
    3. A single reasoning step whose details hold the formatted call
       arguments and, if the content stringifies to something non-empty,
       the formatted result.

    Parameters
    ----------
    info : ToolCallInfo
        Metadata about the tool call
    content : Any
        The tool result content to process
    mark_streamed_text : TextStreamCallback
        Callback to mark that text has been streamed

    Returns
    -------
    list[SSE]
        List of SSE events representing the tool result
    """
    title = f"Tool '{info.tool_name}' returned"

    structured_events = tool_result_events_from_content(
        content, mark_streamed_text=mark_streamed_text
    )
    if structured_events:
        return [reasoning_step(title), *structured_events]

    artifact = artifact_from_output(content)
    if artifact is not None:
        # Message artifacts are folded into the reasoning step itself; any
        # other artifact event is emitted right after the step.
        if isinstance(artifact, MessageArtifactSSE):
            return [EventBuilder.reasoning_with_artifacts(title, [artifact.data])]
        return [reasoning_step(title), artifact]

    details: dict[str, Any] | None = None
    formatted_args = format_args(info.args) if info.args else None
    if formatted_args:
        details = formatted_args.copy()

    if ContentSerializer.to_string(content):
        if details is None:
            details = {}
        details["Result"] = format_arg_value(content)

    return [reasoning_step(title, details=details)]
196
+
197
+
198
def tool_result_events_from_content(
    content: Any,
    *,
    mark_streamed_text: TextStreamCallback,
) -> list[SSE]:
    """Turn a structured tool result into SSE events.

    Only a dict with a list under ``data`` is handled. Each dict entry in
    that list may add a status reasoning step and data-item events. All
    table artifacts collected along the way are emitted last, in a single
    "Data retrieved" reasoning step.

    Parameters
    ----------
    content : Any
        The tool result content to transform
    mark_streamed_text : TextStreamCallback
        Callback to mark that text has been streamed

    Returns
    -------
    list[SSE]
        List of SSE events; empty when content is not structured as expected
    """
    if not isinstance(content, dict):
        return []

    entries = content.get("data") or []
    if not isinstance(entries, list):
        return []

    collected_events: list[SSE] = []
    collected_artifacts: list[ClientArtifact] = []

    for entry in (candidate for candidate in entries if isinstance(candidate, dict)):
        status_event = _process_command_result(entry)
        if status_event:
            collected_events.append(status_event)

        new_artifacts, new_events = _process_data_items(entry, mark_streamed_text)
        collected_artifacts += new_artifacts
        collected_events += new_events

    if collected_artifacts:
        collected_events.append(
            EventBuilder.reasoning_with_artifacts("Data retrieved", collected_artifacts)
        )

    return collected_events
248
+
249
+
250
def artifact_from_output(output: Any) -> SSE | None:
    """Create an artifact SSE from generic tool output payloads.

    Detects and creates appropriate artifacts (charts or tables) from structured
    output. Supports various chart types (line, bar, scatter, pie, donut) and
    table formats.

    Parameters
    ----------
    output : Any
        The tool output to convert to an artifact. Can be:
        - dict with 'type' and 'data' for charts
        - dict with 'table' key for tables (every row must be a dict)
        - dict with a list of dict rows under 'data' for tables
        - list of dicts for automatic table creation

    Returns
    -------
    SSE | None
        A chart or table artifact event, or None if output format is not recognized

    Notes
    -----
    Chart types require specific keys:
    - line/bar/scatter: x_key and y_keys required
    - pie/donut: angle_key and callout_label_key required

    A recognised chart ``type`` with no dict rows, or with missing required
    keys, returns None rather than falling back to a table.
    """
    if isinstance(output, dict):
        chart_type = output.get("type")
        data = output.get("data")

        if isinstance(chart_type, str) and chart_type in {
            "line",
            "bar",
            "scatter",
            "pie",
            "donut",
        }:
            # Non-dict rows are dropped; a chart needs at least one usable row.
            rows = (
                [row for row in data if isinstance(row, dict)]
                if isinstance(data, list)
                else []
            )
            if not rows:
                return None

            chart_type_literal = cast(
                Literal["line", "bar", "scatter", "pie", "donut"], chart_type
            )

            # Accept both snake_case and camelCase key spellings.
            x_key = get_str(output, "x_key", "xKey")
            y_keys = get_str_list(output, "y_keys", "yKeys", "y_key", "yKey")
            angle_key = get_str(output, "angle_key", "angleKey")
            callout_label_key = get_str(output, "callout_label_key", "calloutLabelKey")

            if chart_type_literal in {"line", "bar", "scatter"}:
                if not x_key or not y_keys:
                    return None
            elif chart_type_literal in {"pie", "donut"}:
                if not angle_key or not callout_label_key:
                    return None

            return chart(
                type=chart_type_literal,
                data=rows,
                x_key=x_key,
                y_keys=y_keys,
                angle_key=angle_key,
                callout_label_key=callout_label_key,
                name=output.get("name"),
                description=output.get("description"),
            )

        # The explicit 'table' key gets the same "every row is a dict" check
        # as 'data' and bare lists, so scalar lists never become tables.
        explicit_table = output.get("table")
        if isinstance(explicit_table, list) and all(
            isinstance(row, dict) for row in explicit_table
        ):
            table_data = explicit_table
        elif isinstance(data, list) and all(isinstance(row, dict) for row in data):
            table_data = data
        else:
            table_data = None

        # Empty lists pass the ``all`` check vacuously but produce no table.
        if table_data:
            return table(
                data=table_data,
                name=output.get("name"),
                description=output.get("description"),
            )

    if (
        isinstance(output, list)
        and output
        and all(isinstance(item, dict) for item in output)
    ):
        return table(data=output, name=None, description=None)

    return None
343
+
344
+
345
+ def _process_command_result(entry: dict[str, Any]) -> SSE | None:
346
+ """Process command result status messages.
347
+
348
+ Parameters
349
+ ----------
350
+ entry : dict[str, Any]
351
+ Data entry potentially containing status and message fields
352
+
353
+ Returns
354
+ -------
355
+ SSE | None
356
+ Reasoning step event if status/message found, None otherwise
357
+ """
358
+ status = entry.get("status")
359
+ message = entry.get("message")
360
+ if status and message:
361
+ return reasoning_step(f"[{status}] {message}")
362
+ return None
363
+
364
+
365
+ def _process_data_items(
366
+ entry: dict[str, Any], mark_streamed_text: TextStreamCallback
367
+ ) -> tuple[list[ClientArtifact], list[SSE]]:
368
+ """Process data items from a data entry into artifacts or SSE events.
369
+
370
+ Parses JSON content from items and converts them into table artifacts
371
+ for list-of-dicts data, or message chunks for other content types.
372
+
373
+ Parameters
374
+ ----------
375
+ entry : dict[str, Any]
376
+ Data entry containing an 'items' field with content
377
+ mark_streamed_text : TextStreamCallback
378
+ Callback to mark that text has been streamed
379
+
380
+ Returns
381
+ -------
382
+ tuple[list[ClientArtifact], list[SSE]]
383
+ Tuple of (artifacts created, events emitted)
384
+ """
385
+ items = entry.get("items")
386
+ if not isinstance(items, list):
387
+ return [], []
388
+
389
+ artifacts: list[ClientArtifact] = []
390
+ events: list[SSE] = []
391
+
392
+ for item in items:
393
+ if not isinstance(item, dict):
394
+ continue
395
+
396
+ raw_content = item.get("content")
397
+ if not isinstance(raw_content, str):
398
+ continue
399
+
400
+ parsed = ContentSerializer.parse_json(raw_content)
401
+
402
+ if (
403
+ isinstance(parsed, list)
404
+ and parsed
405
+ and all(isinstance(row, dict) for row in parsed)
406
+ ):
407
+ artifacts.append(
408
+ ClientArtifact(
409
+ type="table",
410
+ name=item.get("name") or f"Table_{uuid4().hex[:4]}",
411
+ description=item.get("description") or "Widget data",
412
+ content=parsed,
413
+ )
414
+ )
415
+ elif isinstance(parsed, dict):
416
+ mark_streamed_text()
417
+ events.append(message_chunk(json.dumps(parsed)))
418
+ else:
419
+ mark_streamed_text()
420
+ events.append(message_chunk(raw_content))
421
+
422
+ return artifacts, events
@@ -0,0 +1,61 @@
1
+ """Custom exceptions for OpenBB Pydantic AI adapter."""
2
+
3
+ from __future__ import annotations
4
+
5
+
6
class OpenBBPydanticAIError(Exception):
    """Base exception for OpenBB Pydantic AI adapter errors.

    Every other exception defined in this module derives from it.
    """


class WidgetNotFoundError(OpenBBPydanticAIError):
    """Raised when a widget cannot be found by tool name or UUID."""

    def __init__(self, identifier: str, lookup_type: str = "tool_name"):
        """Initialize with widget identifier and lookup type.

        Parameters
        ----------
        identifier : str
            The widget identifier that was not found
        lookup_type : str
            The type of lookup performed (tool_name, uuid, etc.)
        """
        super().__init__(f"Widget not found by {lookup_type}: {identifier}")
        self.identifier = identifier
        self.lookup_type = lookup_type

    def __reduce__(self) -> tuple[object, ...]:
        # BaseException.__reduce__ rebuilds from ``self.args`` (just the
        # formatted message), which would wrap the message a second time.
        # Rebuild from the original constructor arguments instead so that
        # pickle/copy round-trips are faithful.
        return (type(self), (self.identifier, self.lookup_type), self.__dict__)


class InvalidToolCallError(OpenBBPydanticAIError):
    """Raised when a tool call is malformed or invalid."""

    def __init__(self, tool_name: str, reason: str):
        """Initialize with tool name and reason.

        Parameters
        ----------
        tool_name : str
            The name of the tool that had an invalid call
        reason : str
            The reason why the call is invalid
        """
        super().__init__(f"Invalid tool call for '{tool_name}': {reason}")
        self.tool_name = tool_name
        self.reason = reason

    def __reduce__(self) -> tuple[object, ...]:
        # Without this, unpickling/copying calls ``cls(message)`` and fails
        # with TypeError because ``reason`` is a required argument.
        return (type(self), (self.tool_name, self.reason), self.__dict__)


class SerializationError(OpenBBPydanticAIError):
    """Raised when content serialization or deserialization fails."""

    def __init__(self, content_type: str, reason: str):
        """Initialize with content type and reason.

        Parameters
        ----------
        content_type : str
            The type of content being serialized
        reason : str
            The reason for the failure
        """
        super().__init__(f"Serialization failed for {content_type}: {reason}")
        self.content_type = content_type
        self.reason = reason

    def __reduce__(self) -> tuple[object, ...]:
        # Same rationale as InvalidToolCallError.__reduce__.
        return (type(self), (self.content_type, self.reason), self.__dict__)
@@ -0,0 +1,127 @@
1
+ """Message transformation utilities for OpenBB Pydantic AI adapter."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Sequence
6
+
7
+ from openbb_ai.models import (
8
+ LlmClientFunctionCall,
9
+ LlmClientFunctionCallResultMessage,
10
+ LlmClientMessage,
11
+ LlmMessage,
12
+ RoleEnum,
13
+ )
14
+ from pydantic_ai.messages import (
15
+ ModelMessage,
16
+ TextPart,
17
+ ToolCallPart,
18
+ ToolReturnPart,
19
+ UserPromptPart,
20
+ )
21
+ from pydantic_ai.ui import MessagesBuilder
22
+
23
+ from ._serializers import ContentSerializer
24
+
25
+
26
class MessageTransformer:
    """Transforms OpenBB messages to Pydantic AI messages.

    Manages tool call ID consistency across message history. Tool calls and
    their results both derive their ID from ``hash_tool_call(function,
    input_arguments)``, optionally remapped through ``tool_call_id_overrides``.
    A call and its result with the same function and arguments therefore
    share an ID.
    """

    def __init__(self, tool_call_id_overrides: dict[str, str] | None = None):
        """Initialize transformer with optional tool call ID overrides.

        Parameters
        ----------
        tool_call_id_overrides : dict[str, str] | None
            Mapping from hash-based IDs to actual tool call IDs for consistency.
            The mapping is used by reference, including when it is empty.
        """
        self._overrides = {} if tool_call_id_overrides is None else tool_call_id_overrides

    def transform_batch(self, messages: Sequence[LlmMessage]) -> list[ModelMessage]:
        """Transform a batch of OpenBB messages to Pydantic AI messages.

        Messages that are neither ``LlmClientMessage`` nor
        ``LlmClientFunctionCallResultMessage`` are skipped.

        Parameters
        ----------
        messages : Sequence[LlmMessage]
            List of OpenBB messages to transform

        Returns
        -------
        list[ModelMessage]
            List of Pydantic AI messages
        """
        builder = MessagesBuilder()
        for message in messages:
            if isinstance(message, LlmClientMessage):
                self._add_client_message(builder, message)
            elif isinstance(message, LlmClientFunctionCallResultMessage):
                self._add_result_message(builder, message)
        return builder.messages

    def _resolve_tool_call_id(self, function: str, input_arguments: object) -> str:
        """Return the tool call ID for a function name and its arguments.

        The hash-based ID is replaced by its entry in the overrides mapping
        when one exists.
        """
        # Kept as a function-level import like the original code, presumably
        # to avoid an import cycle with ``_utils``. Confirm before hoisting.
        from ._utils import hash_tool_call

        base_id = hash_tool_call(function, input_arguments)
        return self._overrides.get(base_id, base_id)

    def _add_client_message(
        self, builder: MessagesBuilder, message: LlmClientMessage
    ) -> None:
        """Add a client message to the builder.

        Function-call content becomes a ``ToolCallPart``. String content from
        a human becomes a ``UserPromptPart``, and from any other role a
        ``TextPart``. Content of any other type is not added.

        Parameters
        ----------
        builder : MessagesBuilder
            The message builder to add to
        message : LlmClientMessage
            The client message to add
        """
        content = message.content

        if isinstance(content, LlmClientFunctionCall):
            tool_call_id = self._resolve_tool_call_id(
                content.function, content.input_arguments
            )
            builder.add(
                ToolCallPart(
                    tool_name=content.function,
                    tool_call_id=tool_call_id,
                    args=content.input_arguments,
                )
            )
            return

        if not isinstance(content, str):
            return

        if message.role == RoleEnum.human:
            builder.add(UserPromptPart(content=content))
        else:
            # AI messages and every other role are recorded as model text.
            builder.add(TextPart(content=content))

    def _add_result_message(
        self,
        builder: MessagesBuilder,
        message: LlmClientFunctionCallResultMessage,
    ) -> None:
        """Add a function call result message to the builder.

        The result is added as a ``ToolReturnPart`` whose ID is resolved the
        same way as the matching call's, with content serialized via
        ``ContentSerializer.serialize_result``.

        Parameters
        ----------
        builder : MessagesBuilder
            The message builder to add to
        message : LlmClientFunctionCallResultMessage
            The result message to add
        """
        tool_call_id = self._resolve_tool_call_id(
            message.function, message.input_arguments
        )
        builder.add(
            ToolReturnPart(
                tool_name=message.function,
                tool_call_id=tool_call_id,
                content=ContentSerializer.serialize_result(message),
            )
        )