openbb-pydantic-ai 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openbb_pydantic_ai/__init__.py +52 -0
- openbb_pydantic_ai/_adapter.py +307 -0
- openbb_pydantic_ai/_config.py +58 -0
- openbb_pydantic_ai/_dependencies.py +73 -0
- openbb_pydantic_ai/_event_builder.py +183 -0
- openbb_pydantic_ai/_event_stream.py +363 -0
- openbb_pydantic_ai/_event_stream_components.py +155 -0
- openbb_pydantic_ai/_event_stream_helpers.py +422 -0
- openbb_pydantic_ai/_exceptions.py +61 -0
- openbb_pydantic_ai/_message_transformer.py +127 -0
- openbb_pydantic_ai/_serializers.py +110 -0
- openbb_pydantic_ai/_toolsets.py +264 -0
- openbb_pydantic_ai/_types.py +39 -0
- openbb_pydantic_ai/_utils.py +132 -0
- openbb_pydantic_ai/_widget_registry.py +145 -0
- openbb_pydantic_ai-0.1.1.dist-info/METADATA +139 -0
- openbb_pydantic_ai-0.1.1.dist-info/RECORD +18 -0
- openbb_pydantic_ai-0.1.1.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,422 @@
|
|
|
1
|
+
"""Helper utilities for OpenBB event stream transformations."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from typing import Any, Literal, Mapping, cast
|
|
8
|
+
from uuid import uuid4
|
|
9
|
+
|
|
10
|
+
from openbb_ai.helpers import chart, message_chunk, reasoning_step, table
|
|
11
|
+
from openbb_ai.models import (
|
|
12
|
+
SSE,
|
|
13
|
+
ClientArtifact,
|
|
14
|
+
LlmClientFunctionCallResultMessage,
|
|
15
|
+
MessageArtifactSSE,
|
|
16
|
+
Widget,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
from ._config import GET_WIDGET_DATA_TOOL_NAME
|
|
20
|
+
from ._event_builder import EventBuilder
|
|
21
|
+
from ._serializers import ContentSerializer
|
|
22
|
+
from ._types import SerializedContent, TextStreamCallback
|
|
23
|
+
from ._utils import (
|
|
24
|
+
format_arg_value,
|
|
25
|
+
format_args,
|
|
26
|
+
get_str,
|
|
27
|
+
get_str_list,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass(slots=True)
|
|
32
|
+
class ToolCallInfo:
|
|
33
|
+
"""Metadata captured when a tool call event is received.
|
|
34
|
+
|
|
35
|
+
Attributes
|
|
36
|
+
----------
|
|
37
|
+
tool_name : str
|
|
38
|
+
Name of the tool being called
|
|
39
|
+
args : dict[str, Any]
|
|
40
|
+
Arguments passed to the tool
|
|
41
|
+
widget : Widget | None
|
|
42
|
+
Associated widget if this is a widget tool call, None otherwise
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
tool_name: str
|
|
46
|
+
args: dict[str, Any]
|
|
47
|
+
widget: Widget | None = None
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def find_widget_for_result(
    result_message: LlmClientFunctionCallResultMessage,
    widget_lookup: Mapping[str, Widget],
) -> Widget | None:
    """Locate the widget that produced a deferred result message.

    Attempts to find the widget first by direct tool name match. If that
    fails and the result comes from a ``get_widget_data`` call, the
    ``widget_uuid`` of the first data source is compared against the UUID of
    every widget in ``widget_lookup``.

    Parameters
    ----------
    result_message : LlmClientFunctionCallResultMessage
        The result message to find a widget for
    widget_lookup : Mapping[str, Widget]
        Mapping from tool names to widgets

    Returns
    -------
    Widget | None
        The widget that produced the result, or None if not found. None is
        also returned when the ``data_sources`` payload is missing or
        malformed, instead of raising.
    """
    widget = widget_lookup.get(result_message.function)
    if widget is not None:
        return widget

    if result_message.function != GET_WIDGET_DATA_TOOL_NAME:
        return None

    # The data_sources payload originates from the client; tolerate
    # unexpected shapes rather than raising KeyError/TypeError/AttributeError.
    data_sources = result_message.input_arguments.get("data_sources")
    if not isinstance(data_sources, (list, tuple)) or not data_sources:
        return None

    first_source = data_sources[0]
    if not isinstance(first_source, Mapping):
        return None

    widget_uuid = first_source.get("widget_uuid")
    if widget_uuid is None:
        return None

    # Compare string forms on both sides so str and uuid.UUID values match.
    target_uuid = str(widget_uuid)
    for candidate in widget_lookup.values():
        if str(candidate.uuid) == target_uuid:
            return candidate

    return None
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def extract_widget_args(
    result_message: LlmClientFunctionCallResultMessage,
) -> dict[str, Any]:
    """Extract the arguments originally supplied to a widget invocation.

    For get_widget_data calls, extracts the ``input_args`` from the first data
    source. For direct widget calls, or for a get_widget_data call without a
    non-empty ``data_sources`` list, returns ``input_arguments`` directly.

    Parameters
    ----------
    result_message : LlmClientFunctionCallResultMessage
        The result message to extract arguments from

    Returns
    -------
    dict[str, Any]
        The widget invocation arguments. An empty dict is returned when the
        first data source is not a mapping or its ``input_args`` is missing
        or not a dict (e.g. ``null``).
    """
    if result_message.function == GET_WIDGET_DATA_TOOL_NAME:
        data_sources = result_message.input_arguments.get("data_sources")
        if isinstance(data_sources, (list, tuple)) and data_sources:
            first_source = data_sources[0]
            input_args = (
                first_source.get("input_args")
                if isinstance(first_source, Mapping)
                else None
            )
            # Normalise missing/null/malformed values so the declared
            # dict return type always holds.
            return input_args if isinstance(input_args, dict) else {}
    return result_message.input_arguments
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def serialized_content_from_result(
    result_message: LlmClientFunctionCallResultMessage,
) -> SerializedContent:
    """Convert a function-call result message into ``SerializedContent``.

    All work is delegated to ``ContentSerializer.serialize_result``; this
    wrapper only gives event-stream code a descriptively named entry point.

    Parameters
    ----------
    result_message : LlmClientFunctionCallResultMessage
        Result message to be serialized.

    Returns
    -------
    SerializedContent
        Whatever ``ContentSerializer.serialize_result`` produces for the
        message, returned unchanged.
    """
    serialized = ContentSerializer.serialize_result(result_message)
    return serialized
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def handle_generic_tool_result(
    info: ToolCallInfo,
    content: Any,
    *,
    mark_streamed_text: TextStreamCallback,
) -> list[SSE]:
    """Build the SSE events describing the result of a non-widget tool.

    Three strategies are tried in order:

    1. Structured payloads (``{"data": [...]}``) handled by
       ``tool_result_events_from_content``, prefixed with a reasoning step.
    2. A chart/table artifact detected by ``artifact_from_output``.
    3. A single reasoning step whose ``details`` hold the formatted call
       arguments and, when the content stringifies to something non-empty,
       the formatted result under the ``"Result"`` key.

    Parameters
    ----------
    info : ToolCallInfo
        Name and arguments of the tool call that produced ``content``.
    content : Any
        Raw value returned by the tool.
    mark_streamed_text : TextStreamCallback
        Forwarded to ``tool_result_events_from_content``.

    Returns
    -------
    list[SSE]
        Events to emit for this tool result.
    """
    label = f"Tool '{info.tool_name}' returned"

    # Strategy 1: structured data entries.
    structured_events = tool_result_events_from_content(
        content, mark_streamed_text=mark_streamed_text
    )
    if structured_events:
        return [reasoning_step(label), *structured_events]

    # Strategy 2: chart or table artifact built directly from the output.
    artifact = artifact_from_output(content)
    if isinstance(artifact, MessageArtifactSSE):
        # Attach the artifact payload to the reasoning step itself.
        return [EventBuilder.reasoning_with_artifacts(label, [artifact.data])]
    if artifact is not None:
        return [reasoning_step(label), artifact]

    # Strategy 3: plain reasoning step with formatted args/result as details.
    formatted_args = format_args(info.args) if info.args else None
    details: dict[str, Any] | None = (
        formatted_args.copy() if formatted_args else None
    )

    if ContentSerializer.to_string(content):
        if not details:
            details = {}
        details["Result"] = format_arg_value(content)

    return [reasoning_step(label, details=details)]
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def tool_result_events_from_content(
    content: Any,
    *,
    mark_streamed_text: TextStreamCallback,
) -> list[SSE]:
    """Convert a structured tool payload into SSE events.

    Only dict payloads whose ``"data"`` value is a list are handled. For each
    dict entry of that list, a status/message pair becomes a reasoning step
    and the entry's items become table artifacts or message chunks. All
    collected artifacts are emitted together, last, in one "Data retrieved"
    reasoning step.

    Parameters
    ----------
    content : Any
        Tool result to inspect.
    mark_streamed_text : TextStreamCallback
        Passed through to the per-entry item processing.

    Returns
    -------
    list[SSE]
        Generated events; empty when ``content`` has no usable ``data`` list.
    """
    entries = content.get("data") if isinstance(content, dict) else None
    if not isinstance(entries, list):
        return []

    emitted: list[SSE] = []
    gathered: list[ClientArtifact] = []

    for entry in entries:
        if not isinstance(entry, dict):
            continue  # ignore non-dict entries

        status_event = _process_command_result(entry)
        if status_event:
            emitted.append(status_event)

        new_artifacts, new_events = _process_data_items(entry, mark_streamed_text)
        gathered += new_artifacts
        emitted += new_events

    if gathered:
        emitted.append(EventBuilder.reasoning_with_artifacts("Data retrieved", gathered))

    return emitted
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def artifact_from_output(output: Any) -> SSE | None:
    """Create an artifact SSE from generic tool output payloads.

    Detects and creates appropriate artifacts (charts or tables) from structured
    output. Supports various chart types (line, bar, scatter, pie, donut) and
    table formats.

    Parameters
    ----------
    output : Any
        The tool output to convert to an artifact. Can be:
        - dict with 'type' and 'data' for charts
        - dict with 'table' key for tables
        - dict with a 'data' list of dicts for tables
        - list of dicts for automatic table creation

    Returns
    -------
    SSE | None
        A chart or table artifact event, or None if output format is not recognized

    Notes
    -----
    Chart types require specific keys:
    - line/bar/scatter: x_key and y_keys required
    - pie/donut: angle_key and callout_label_key required

    Non-dict rows in chart data are dropped. Table data must be a non-empty
    list whose every element is a dict; otherwise no table is produced.
    When a list-valued ``'table'`` key is present it takes precedence over
    ``'data'``.
    """

    def is_row_list(value: Any) -> bool:
        # A table needs a non-empty list with one dict per row.
        return (
            isinstance(value, list)
            and bool(value)
            and all(isinstance(row, dict) for row in value)
        )

    if isinstance(output, dict):
        chart_type = output.get("type")
        data = output.get("data")

        if isinstance(chart_type, str) and chart_type in {
            "line",
            "bar",
            "scatter",
            "pie",
            "donut",
        }:
            rows = (
                [row for row in data if isinstance(row, dict)]
                if isinstance(data, list)
                else []
            )
            if not rows:
                return None

            chart_type_literal = cast(
                Literal["line", "bar", "scatter", "pie", "donut"], chart_type
            )

            # Accept both snake_case and camelCase key names.
            x_key = get_str(output, "x_key", "xKey")
            y_keys = get_str_list(output, "y_keys", "yKeys", "y_key", "yKey")
            angle_key = get_str(output, "angle_key", "angleKey")
            callout_label_key = get_str(output, "callout_label_key", "calloutLabelKey")

            if chart_type_literal in {"line", "bar", "scatter"}:
                if not x_key or not y_keys:
                    return None
            elif not angle_key or not callout_label_key:
                # pie / donut
                return None

            return chart(
                type=chart_type_literal,
                data=rows,
                x_key=x_key,
                y_keys=y_keys,
                angle_key=angle_key,
                callout_label_key=callout_label_key,
                name=output.get("name"),
                description=output.get("description"),
            )

        # An explicit list-valued "table" key wins over "data"; either way the
        # rows must all be dicts before handing them to table().
        table_data = output.get("table")
        if not isinstance(table_data, list):
            table_data = data

        if is_row_list(table_data):
            return table(
                data=table_data,
                name=output.get("name"),
                description=output.get("description"),
            )
        return None

    if is_row_list(output):
        return table(data=output, name=None, description=None)

    return None
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
def _process_command_result(entry: dict[str, Any]) -> SSE | None:
|
|
346
|
+
"""Process command result status messages.
|
|
347
|
+
|
|
348
|
+
Parameters
|
|
349
|
+
----------
|
|
350
|
+
entry : dict[str, Any]
|
|
351
|
+
Data entry potentially containing status and message fields
|
|
352
|
+
|
|
353
|
+
Returns
|
|
354
|
+
-------
|
|
355
|
+
SSE | None
|
|
356
|
+
Reasoning step event if status/message found, None otherwise
|
|
357
|
+
"""
|
|
358
|
+
status = entry.get("status")
|
|
359
|
+
message = entry.get("message")
|
|
360
|
+
if status and message:
|
|
361
|
+
return reasoning_step(f"[{status}] {message}")
|
|
362
|
+
return None
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
def _process_data_items(
|
|
366
|
+
entry: dict[str, Any], mark_streamed_text: TextStreamCallback
|
|
367
|
+
) -> tuple[list[ClientArtifact], list[SSE]]:
|
|
368
|
+
"""Process data items from a data entry into artifacts or SSE events.
|
|
369
|
+
|
|
370
|
+
Parses JSON content from items and converts them into table artifacts
|
|
371
|
+
for list-of-dicts data, or message chunks for other content types.
|
|
372
|
+
|
|
373
|
+
Parameters
|
|
374
|
+
----------
|
|
375
|
+
entry : dict[str, Any]
|
|
376
|
+
Data entry containing an 'items' field with content
|
|
377
|
+
mark_streamed_text : TextStreamCallback
|
|
378
|
+
Callback to mark that text has been streamed
|
|
379
|
+
|
|
380
|
+
Returns
|
|
381
|
+
-------
|
|
382
|
+
tuple[list[ClientArtifact], list[SSE]]
|
|
383
|
+
Tuple of (artifacts created, events emitted)
|
|
384
|
+
"""
|
|
385
|
+
items = entry.get("items")
|
|
386
|
+
if not isinstance(items, list):
|
|
387
|
+
return [], []
|
|
388
|
+
|
|
389
|
+
artifacts: list[ClientArtifact] = []
|
|
390
|
+
events: list[SSE] = []
|
|
391
|
+
|
|
392
|
+
for item in items:
|
|
393
|
+
if not isinstance(item, dict):
|
|
394
|
+
continue
|
|
395
|
+
|
|
396
|
+
raw_content = item.get("content")
|
|
397
|
+
if not isinstance(raw_content, str):
|
|
398
|
+
continue
|
|
399
|
+
|
|
400
|
+
parsed = ContentSerializer.parse_json(raw_content)
|
|
401
|
+
|
|
402
|
+
if (
|
|
403
|
+
isinstance(parsed, list)
|
|
404
|
+
and parsed
|
|
405
|
+
and all(isinstance(row, dict) for row in parsed)
|
|
406
|
+
):
|
|
407
|
+
artifacts.append(
|
|
408
|
+
ClientArtifact(
|
|
409
|
+
type="table",
|
|
410
|
+
name=item.get("name") or f"Table_{uuid4().hex[:4]}",
|
|
411
|
+
description=item.get("description") or "Widget data",
|
|
412
|
+
content=parsed,
|
|
413
|
+
)
|
|
414
|
+
)
|
|
415
|
+
elif isinstance(parsed, dict):
|
|
416
|
+
mark_streamed_text()
|
|
417
|
+
events.append(message_chunk(json.dumps(parsed)))
|
|
418
|
+
else:
|
|
419
|
+
mark_streamed_text()
|
|
420
|
+
events.append(message_chunk(raw_content))
|
|
421
|
+
|
|
422
|
+
return artifacts, events
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
"""Custom exceptions for OpenBB Pydantic AI adapter."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class OpenBBPydanticAIError(Exception):
    """Base exception for OpenBB Pydantic AI adapter errors."""


class WidgetNotFoundError(OpenBBPydanticAIError):
    """Raised when a widget cannot be found by tool name or UUID."""

    def __init__(self, identifier: str, lookup_type: str = "tool_name"):
        """Initialize with widget identifier and lookup type.

        Parameters
        ----------
        identifier : str
            The widget identifier that was not found
        lookup_type : str
            The type of lookup performed (tool_name, uuid, etc.)
        """
        super().__init__(f"Widget not found by {lookup_type}: {identifier}")
        self.identifier = identifier
        self.lookup_type = lookup_type

    def __reduce__(self) -> tuple[object, ...]:
        # BaseException's default reduction replays ``self.args`` (the
        # formatted message) into ``__init__``, which would double-wrap the
        # message on unpickling/copying. Rebuild from the constructor args.
        return (type(self), (self.identifier, self.lookup_type), self.__dict__)


class InvalidToolCallError(OpenBBPydanticAIError):
    """Raised when a tool call is malformed or invalid."""

    def __init__(self, tool_name: str, reason: str):
        """Initialize with tool name and reason.

        Parameters
        ----------
        tool_name : str
            The name of the tool that had an invalid call
        reason : str
            The reason why the call is invalid
        """
        super().__init__(f"Invalid tool call for '{tool_name}': {reason}")
        self.tool_name = tool_name
        self.reason = reason

    def __reduce__(self) -> tuple[object, ...]:
        # Without this, pickle/copy call ``__init__(message)`` and fail with
        # TypeError because ``reason`` is a required argument.
        return (type(self), (self.tool_name, self.reason), self.__dict__)


class SerializationError(OpenBBPydanticAIError):
    """Raised when content serialization or deserialization fails."""

    def __init__(self, content_type: str, reason: str):
        """Initialize with content type and reason.

        Parameters
        ----------
        content_type : str
            The type of content being serialized
        reason : str
            The reason for the failure
        """
        super().__init__(f"Serialization failed for {content_type}: {reason}")
        self.content_type = content_type
        self.reason = reason

    def __reduce__(self) -> tuple[object, ...]:
        # Without this, pickle/copy call ``__init__(message)`` and fail with
        # TypeError because ``reason`` is a required argument.
        return (type(self), (self.content_type, self.reason), self.__dict__)
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
"""Message transformation utilities for OpenBB Pydantic AI adapter."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Sequence
|
|
6
|
+
|
|
7
|
+
from openbb_ai.models import (
|
|
8
|
+
LlmClientFunctionCall,
|
|
9
|
+
LlmClientFunctionCallResultMessage,
|
|
10
|
+
LlmClientMessage,
|
|
11
|
+
LlmMessage,
|
|
12
|
+
RoleEnum,
|
|
13
|
+
)
|
|
14
|
+
from pydantic_ai.messages import (
|
|
15
|
+
ModelMessage,
|
|
16
|
+
TextPart,
|
|
17
|
+
ToolCallPart,
|
|
18
|
+
ToolReturnPart,
|
|
19
|
+
UserPromptPart,
|
|
20
|
+
)
|
|
21
|
+
from pydantic_ai.ui import MessagesBuilder
|
|
22
|
+
|
|
23
|
+
from ._serializers import ContentSerializer
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class MessageTransformer:
    """Convert OpenBB chat history into Pydantic AI ``ModelMessage`` objects.

    Tool calls and their results are linked by an ID computed with
    ``hash_tool_call(function, input_arguments)``. An optional override map
    replaces such computed IDs with caller-supplied ones.
    """

    def __init__(self, tool_call_id_overrides: dict[str, str] | None = None):
        """Store the optional tool-call ID override map.

        Parameters
        ----------
        tool_call_id_overrides : dict[str, str] | None
            Maps hash-derived tool call IDs to the IDs that should be used
            instead. A falsy value (None or empty) means "no overrides".
        """
        self._overrides = tool_call_id_overrides if tool_call_id_overrides else {}

    def _tool_call_id(self, function: str, input_arguments) -> str:
        """Return the tool call ID for a function call, honoring overrides."""
        # Imported lazily, as in the original per-method imports.
        from ._utils import hash_tool_call

        derived = hash_tool_call(function, input_arguments)
        return self._overrides.get(derived, derived)

    def transform_batch(self, messages: Sequence[LlmMessage]) -> list[ModelMessage]:
        """Translate a sequence of OpenBB messages into Pydantic AI messages.

        Client messages and function-call result messages are converted;
        messages of any other type are skipped.

        Parameters
        ----------
        messages : Sequence[LlmMessage]
            OpenBB messages in conversation order.

        Returns
        -------
        list[ModelMessage]
            The messages accumulated by a ``MessagesBuilder``.
        """
        builder = MessagesBuilder()
        for msg in messages:
            if isinstance(msg, LlmClientMessage):
                self._add_client_message(builder, msg)
            elif isinstance(msg, LlmClientFunctionCallResultMessage):
                self._add_result_message(builder, msg)
        return builder.messages

    def _add_client_message(
        self, builder: MessagesBuilder, message: LlmClientMessage
    ) -> None:
        """Append the part corresponding to one client message.

        A function-call payload becomes a ``ToolCallPart``. String content
        becomes a ``UserPromptPart`` for the human role and a ``TextPart`` for
        every other role. Other content types add nothing.

        Parameters
        ----------
        builder : MessagesBuilder
            Builder receiving the new part.
        message : LlmClientMessage
            Client message to convert.
        """
        payload = message.content

        if isinstance(payload, LlmClientFunctionCall):
            call_id = self._tool_call_id(payload.function, payload.input_arguments)
            builder.add(
                ToolCallPart(
                    tool_name=payload.function,
                    tool_call_id=call_id,
                    args=payload.input_arguments,
                )
            )
            return

        if not isinstance(payload, str):
            return

        if message.role == RoleEnum.human:
            builder.add(UserPromptPart(content=payload))
        else:
            # AI and any other role are both rendered as model text.
            builder.add(TextPart(content=payload))

    def _add_result_message(
        self,
        builder: MessagesBuilder,
        message: LlmClientFunctionCallResultMessage,
    ) -> None:
        """Append a ``ToolReturnPart`` for a function-call result message.

        Parameters
        ----------
        builder : MessagesBuilder
            Builder receiving the new part.
        message : LlmClientFunctionCallResultMessage
            Result message; its content is serialized with
            ``ContentSerializer.serialize_result``.
        """
        call_id = self._tool_call_id(message.function, message.input_arguments)
        builder.add(
            ToolReturnPart(
                tool_name=message.function,
                tool_call_id=call_id,
                content=ContentSerializer.serialize_result(message),
            )
        )
|