openbb-agent 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openbb_agent-0.1.0/PKG-INFO +29 -0
- openbb_agent-0.1.0/openbb_agent/__init__.py +0 -0
- openbb_agent-0.1.0/openbb_agent/agent.py +440 -0
- openbb_agent-0.1.0/openbb_agent/cache.py +74 -0
- openbb_agent-0.1.0/openbb_agent/config.py +49 -0
- openbb_agent-0.1.0/openbb_agent/crews/__init__.py +0 -0
- openbb_agent-0.1.0/openbb_agent/data/disabled_bash_commands.txt +8 -0
- openbb_agent-0.1.0/openbb_agent/flows/__init__.py +0 -0
- openbb_agent-0.1.0/openbb_agent/hooks.py +47 -0
- openbb_agent-0.1.0/openbb_agent/main.py +165 -0
- openbb_agent-0.1.0/openbb_agent/memory/__init__.py +4 -0
- openbb_agent-0.1.0/openbb_agent/memory/config/__init__.py +17 -0
- openbb_agent-0.1.0/openbb_agent/memory/config/embedder_config.py +26 -0
- openbb_agent-0.1.0/openbb_agent/memory/config/enums.py +22 -0
- openbb_agent-0.1.0/openbb_agent/memory/config/llm_config.py +31 -0
- openbb_agent-0.1.0/openbb_agent/memory/config/mem0_config_builder.py +118 -0
- openbb_agent-0.1.0/openbb_agent/memory/config/memory_config.py +12 -0
- openbb_agent-0.1.0/openbb_agent/memory/config/vector_store_config.py +33 -0
- openbb_agent-0.1.0/openbb_agent/memory/ids.py +15 -0
- openbb_agent-0.1.0/openbb_agent/memory/service.py +121 -0
- openbb_agent-0.1.0/openbb_agent/memory/summarizer.py +39 -0
- openbb_agent-0.1.0/openbb_agent/models.py +334 -0
- openbb_agent-0.1.0/openbb_agent/obb_model/Nasdaq_models.py +128 -0
- openbb_agent-0.1.0/openbb_agent/obb_model/__init__.py +0 -0
- openbb_agent-0.1.0/openbb_agent/obb_model/commodity_models.py +16 -0
- openbb_agent-0.1.0/openbb_agent/obb_model/currency_models.py +59 -0
- openbb_agent-0.1.0/openbb_agent/obb_model/derivatives_models.py +45 -0
- openbb_agent-0.1.0/openbb_agent/obb_model/economy_models.py +74 -0
- openbb_agent-0.1.0/openbb_agent/obb_model/equity_models.py +28 -0
- openbb_agent-0.1.0/openbb_agent/obb_model/etf_models.py +50 -0
- openbb_agent-0.1.0/openbb_agent/obb_model/index_models.py +133 -0
- openbb_agent-0.1.0/openbb_agent/obb_model/news_models.py +42 -0
- openbb_agent-0.1.0/openbb_agent/obb_model/rate_model.py +92 -0
- openbb_agent-0.1.0/openbb_agent/prompts/__init__.py +4 -0
- openbb_agent-0.1.0/openbb_agent/prompts/prompt_manager.py +61 -0
- openbb_agent-0.1.0/openbb_agent/prompts/render.py +110 -0
- openbb_agent-0.1.0/openbb_agent/prompts/templates/base.json +1 -0
- openbb_agent-0.1.0/openbb_agent/prompts/templates/mem0_fact_extraction.json +3 -0
- openbb_agent-0.1.0/openbb_agent/tools/__init__.py +31 -0
- openbb_agent-0.1.0/openbb_agent/tools/bash.py +56 -0
- openbb_agent-0.1.0/openbb_agent/tools/chart_table.py +217 -0
- openbb_agent-0.1.0/openbb_agent/tools/crew_flow.py +86 -0
- openbb_agent-0.1.0/openbb_agent/tools/delegate.py +67 -0
- openbb_agent-0.1.0/openbb_agent/tools/file_io.py +83 -0
- openbb_agent-0.1.0/openbb_agent/tools/mcp.py +65 -0
- openbb_agent-0.1.0/openbb_agent/tools/widget.py +94 -0
- openbb_agent-0.1.0/openbb_agent.egg-info/PKG-INFO +29 -0
- openbb_agent-0.1.0/openbb_agent.egg-info/SOURCES.txt +52 -0
- openbb_agent-0.1.0/openbb_agent.egg-info/dependency_links.txt +1 -0
- openbb_agent-0.1.0/openbb_agent.egg-info/entry_points.txt +2 -0
- openbb_agent-0.1.0/openbb_agent.egg-info/requires.txt +15 -0
- openbb_agent-0.1.0/openbb_agent.egg-info/top_level.txt +1 -0
- openbb_agent-0.1.0/pyproject.toml +49 -0
- openbb_agent-0.1.0/setup.cfg +4 -0
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: openbb-agent
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: OpenBB Terminal Pro copilot agent
|
|
5
|
+
Author-email: hjm_g <ghj791990@gmail.com>
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
Project-URL: Repository, https://github.com/Findworth/openbb_agent
|
|
8
|
+
Keywords: openbb,finance,agent,copilot
|
|
9
|
+
Classifier: Development Status :: 3 - Alpha
|
|
10
|
+
Classifier: Intended Audience :: Developers
|
|
11
|
+
Classifier: Programming Language :: Python :: 3
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
15
|
+
Requires-Python: >=3.10
|
|
16
|
+
Requires-Dist: openbb-ai
|
|
17
|
+
Requires-Dist: fastapi
|
|
18
|
+
Requires-Dist: uvicorn[standard]
|
|
19
|
+
Requires-Dist: sse-starlette
|
|
20
|
+
Requires-Dist: python-dotenv
|
|
21
|
+
Requires-Dist: pydantic-settings
|
|
22
|
+
Requires-Dist: magentic
|
|
23
|
+
Requires-Dist: crewai
|
|
24
|
+
Requires-Dist: redis
|
|
25
|
+
Requires-Dist: mem0ai
|
|
26
|
+
Requires-Dist: chromadb
|
|
27
|
+
Provides-Extra: dev
|
|
28
|
+
Requires-Dist: ruff; extra == "dev"
|
|
29
|
+
Requires-Dist: pytest; extra == "dev"
|
|
File without changes
|
|
@@ -0,0 +1,440 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import inspect
|
|
3
|
+
import json
|
|
4
|
+
from typing import AsyncGenerator
|
|
5
|
+
|
|
6
|
+
from magentic import (
|
|
7
|
+
AsyncStreamedResponse,
|
|
8
|
+
Chat,
|
|
9
|
+
AsyncStreamedStr,
|
|
10
|
+
FunctionCall,
|
|
11
|
+
OpenaiChatModel,
|
|
12
|
+
)
|
|
13
|
+
from .prompts import PromptManager, render_system_prompt
|
|
14
|
+
from .memory.summarizer import summarize_and_store
|
|
15
|
+
from .models import (
|
|
16
|
+
ClientFunctionCallError,
|
|
17
|
+
LlmClientMessage,
|
|
18
|
+
QueryRequest,
|
|
19
|
+
Citation,
|
|
20
|
+
CitationCollection,
|
|
21
|
+
CitationCollectionSSE,
|
|
22
|
+
DataContent,
|
|
23
|
+
DataFileReferences,
|
|
24
|
+
DataSourceRequest,
|
|
25
|
+
FunctionCallSSE,
|
|
26
|
+
FunctionCallSSEData,
|
|
27
|
+
LlmClientFunctionCallResultMessage,
|
|
28
|
+
LlmFunctionCall,
|
|
29
|
+
MessageChunkSSE,
|
|
30
|
+
MessageChunkSSEData,
|
|
31
|
+
StatusUpdateSSE,
|
|
32
|
+
StatusUpdateSSEData,
|
|
33
|
+
Widget,
|
|
34
|
+
)
|
|
35
|
+
from magentic import (
|
|
36
|
+
AssistantMessage,
|
|
37
|
+
FunctionResultMessage,
|
|
38
|
+
AnyMessage,
|
|
39
|
+
SystemMessage,
|
|
40
|
+
UserMessage,
|
|
41
|
+
)
|
|
42
|
+
from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Protocol, cast
|
|
43
|
+
from openai import AsyncOpenAI
|
|
44
|
+
from openai.types.chat import (
|
|
45
|
+
ChatCompletionMessageParam,
|
|
46
|
+
ChatCompletionSystemMessageParam,
|
|
47
|
+
ChatCompletionUserMessageParam,
|
|
48
|
+
ChatCompletionAssistantMessageParam,
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def reasoning_step(
    event_type: Literal["INFO", "WARNING", "ERROR"],
    message: str,
    details: dict[str, Any] | None = None,
) -> StatusUpdateSSE:
    """Build a status-update SSE event for one reasoning step.

    Args:
        event_type: Severity label, forwarded as ``eventType``.
        message: Status text shown for this step.
        details: Optional mapping with extra information. A truthy mapping is
            wrapped in a one-element list; ``None`` or an empty mapping
            results in an empty list.

    Returns:
        A ``StatusUpdateSSE`` wrapping the assembled ``StatusUpdateSSEData``.
    """
    detail_items = [] if not details else [details]
    payload = StatusUpdateSSEData(
        eventType=event_type,
        message=message,
        details=detail_items,
    )
    return StatusUpdateSSE(data=payload)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class WrappedFunctionProtocol(Protocol):
    """Structural interface of the callables built by ``remote_function_call``."""

    async def execute_post_processing(
        self,
        data: list[DataContent | DataFileReferences | ClientFunctionCallError],
    ) -> str:
        """Convert the data returned for a function call into a string."""
        ...

    def execute_callbacks(
        self,
        function_call_result: LlmClientFunctionCallResultMessage,
        request: QueryRequest,
    ) -> AsyncGenerator[Any, None]:
        """Return an async generator over the events emitted by callbacks."""
        ...

    def __call__(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[FunctionCallSSE | StatusUpdateSSE, None]:
        """Invoke the wrapped function and return its event stream."""
        ...
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def remote_function_call(
    function: Literal["get_widget_data"],
    output_formatter: Callable[..., Awaitable[str]] | None = None,
    callbacks: list[Callable[..., Awaitable[Any]]] | None = None,
) -> Callable:
    """Decorator factory that turns an async generator into a remote tool.

    The decorated function is replaced by a callable object that hides the
    ``request`` parameter from its public signature, injects the current
    request when called, and converts the first ``DataSourceRequest`` it
    yields into a ``FunctionCallSSE`` for the client to execute.

    Args:
        function: Name of the client-side function to request. Only
            ``"get_widget_data"`` is accepted.
        output_formatter: Optional coroutine used by
            ``execute_post_processing`` to format returned data.
        callbacks: Optional callables run by ``execute_callbacks``. Async
            generator functions have their events forwarded; anything else is
            awaited and its result discarded.

    Returns:
        A decorator producing a ``WrappedFunctionProtocol`` instance.

    Raises:
        ValueError: If ``function`` is not a supported name.
    """
    if function not in ["get_widget_data"]:
        raise ValueError(f"Unsupported function: {function}. Must be 'get_widget_data'.")

    def outer_wrapper(func: Callable) -> WrappedFunctionProtocol:
        class InnerWrapper(WrappedFunctionProtocol):
            def __init__(self):
                # Mirror the wrapped function's identity so that tooling which
                # inspects name/signature/doc sees the original function.
                self.__name__ = func.__name__
                self.__signature__ = self._mask_signature(func)
                self.__doc__ = func.__doc__
                self.local_function = func
                self.function = function
                self.post_process_function = output_formatter
                self.callbacks = callbacks
                self._request = None

            @property
            def request(self) -> QueryRequest:
                return self._request

            @request.setter
            def request(self, request: QueryRequest):
                self._request = request

            def _mask_signature(self, func: Callable):
                """Return ``func``'s signature without the ``request`` parameter."""
                signature = inspect.signature(func)
                masked_params = [p for p in signature.parameters.values() if p.name != "request"]
                return signature.replace(parameters=masked_params)

            async def execute_callbacks(
                self,
                function_call_result: LlmClientFunctionCallResultMessage,
                request: QueryRequest,
            ) -> AsyncGenerator[Any, None]:
                """Run every registered callback for a function-call result.

                Both callback kinds receive the ``request`` passed in here.
                The ``request`` attribute is only assigned when the tool
                itself is invoked, so it may be unset or belong to an earlier
                request at this point.
                """
                if self.callbacks:
                    for callback in self.callbacks:
                        if inspect.isasyncgenfunction(callback):
                            async for event in callback(function_call_result, request):
                                yield event
                        else:
                            await callback(function_call_result, request)

            async def execute_post_processing(
                self,
                data: list[DataContent | DataFileReferences | ClientFunctionCallError],
            ) -> str:
                """Format ``data`` with the output formatter, else ``str(data)``."""
                if self.post_process_function:
                    return await self.post_process_function(data)
                return str(data)

            async def __call__(
                self, *args, **kwargs
            ) -> AsyncGenerator[FunctionCallSSE | StatusUpdateSSE, None]:
                """Run the wrapped generator and translate its events.

                Status updates and unrecognised events pass through unchanged.
                The first ``DataSourceRequest`` is converted into a
                ``FunctionCallSSE`` and ends the stream.
                """
                # Bind against the masked signature so the recorded arguments
                # never include ``request``.
                bound_args = self.__signature__.bind(*args, **kwargs).arguments
                async for event in func(*args, request=self._request, **kwargs):
                    if isinstance(event, StatusUpdateSSE):
                        yield event
                    elif isinstance(event, DataSourceRequest):
                        yield FunctionCallSSE(
                            data=FunctionCallSSEData(
                                function=self.function,
                                input_arguments={
                                    "data_sources": [
                                        DataSourceRequest(
                                            widget_uuid=event.widget_uuid,
                                            origin=event.origin,
                                            id=event.id,
                                            input_args=event.input_args,
                                        )
                                    ]
                                },
                                # State echoed back with the result so the
                                # call can be reconstructed later.
                                extra_state={
                                    "copilot_function_call_arguments": {**bound_args},
                                    "_locally_bound_function": func.__name__,
                                },
                            )
                        )
                        return
                    else:
                        yield event

        return InnerWrapper()

    return outer_wrapper
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def get_remote_data(
    widget: Widget,
    input_arguments: dict[str, Any],
) -> DataSourceRequest:
    """Describe a data request for ``widget`` with the given input arguments.

    The widget's UUID is sent as a string; its origin and widget id are
    copied as they are.
    """
    request_fields = {
        "widget_uuid": str(widget.uuid),
        "origin": widget.origin,
        "id": widget.widget_id,
        "input_args": input_arguments,
    }
    return DataSourceRequest(**request_fields)
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def _last_user_message(request: QueryRequest) -> str:
    """Return the text of the most recent human message in ``request``.

    Messages are scanned from newest to oldest. An empty string is returned
    when there is no human message or when the newest one has non-string
    content.
    """
    # Call the builtin directly: subscripting ``reversed`` with a type is not
    # supported at runtime and would raise before any message is examined.
    for m in reversed(request.messages):
        if isinstance(m, LlmClientMessage) and m.role == "human":
            return m.content if isinstance(m.content, str) else ""
    return ""
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
async def _get_memory_context(
    memory_service,
    request: QueryRequest,
    memory_ids: dict[str, str | None],
) -> str:
    """Look up memories relevant to the latest user message.

    Returns an empty string when no memory service is supplied, when the
    service reports itself unavailable, or when there is no user text to
    search with. Otherwise the service's ``search`` method runs in a worker
    thread and its results are rendered by ``format_context``.
    """
    if not (memory_service and memory_service.is_available()):
        return ""

    user_text = _last_user_message(request)
    if not user_text:
        return ""

    search_kwargs: dict = {
        "query": user_text,
        "user_id": memory_ids.get("user_id") or "default",
        "limit": 5,
    }
    # Optional scopes are only forwarded when they carry a truthy value.
    for scope in ("agent_id", "app_id"):
        if memory_ids.get(scope):
            search_kwargs[scope] = memory_ids[scope]

    hits = await asyncio.to_thread(memory_service.search, **search_kwargs)
    return memory_service.format_context(hits)
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
async def run_copilot(
    request: QueryRequest,
    memory_ids: dict[str, str | None],
    functions: list,
    model,
    memory_service=None,
) -> AsyncGenerator[dict, None]:
    """Answer ``request`` with an ``OpenBBAgent`` and stream its events.

    The system prompt is rendered from the request, its widgets and any
    memory context. Flow/crew instructions are included when the request's
    ``workspace_options`` contain ``"run-crew"`` or ``"run-flow"``. After the
    agent's stream is exhausted, ``summarize_and_store`` is awaited if a
    memory service was supplied.
    """
    memory_context = await _get_memory_context(memory_service, request, memory_ids)

    options = getattr(request, "workspace_options", []) or []
    wants_flows_or_crews = any(flag in options for flag in ("run-crew", "run-flow"))
    prompt_manager = PromptManager(
        include_widgets=True,
        include_flows_crews=wants_flows_or_crews,
    )

    agent = OpenBBAgent(
        query_request=request,
        system_prompt=render_system_prompt(
            request=request,
            widget_collection=request.widgets,
            memory_context=memory_context,
            prompt_manager=prompt_manager,
        ),
        functions=functions,
        model=model,
        prompt_manager=prompt_manager,
    )
    async for event in agent.run():
        yield event

    if not memory_service:
        return
    await summarize_and_store(
        memory_service,
        request,
        memory_ids=memory_ids,
        n_rounds=3,
    )
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def get_wrapped_function(
    function_name: str, functions: list[Any]
) -> WrappedFunctionProtocol:
    """Return the first entry of ``functions`` whose ``__name__`` matches.

    Raises:
        ValueError: If nothing matches, or the first match is falsy.
    """
    candidates = [fn for fn in functions if fn.__name__ == function_name]
    if candidates and candidates[0]:
        return candidates[0]
    raise ValueError(f"Local function not found: {function_name}")
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
class OpenBBAgent:
|
|
265
|
+
def __init__(
|
|
266
|
+
self,
|
|
267
|
+
query_request: QueryRequest,
|
|
268
|
+
system_prompt: str,
|
|
269
|
+
functions: list[Callable] | None = None,
|
|
270
|
+
chat_class: type[Chat] | None = None,
|
|
271
|
+
model: str | OpenaiChatModel | None = None,
|
|
272
|
+
hooks: Any = None,
|
|
273
|
+
prompt_manager: PromptManager | None = None,
|
|
274
|
+
**kwargs: Any,
|
|
275
|
+
):
|
|
276
|
+
self.request = query_request
|
|
277
|
+
self.widgets = query_request.widgets
|
|
278
|
+
self.system_prompt = system_prompt
|
|
279
|
+
self.functions = functions
|
|
280
|
+
self.chat_class = chat_class or Chat
|
|
281
|
+
self._model = model if model is not None else OpenaiChatModel()
|
|
282
|
+
self._chat: Chat | None = None
|
|
283
|
+
self._citations: CitationCollection | None = None
|
|
284
|
+
self._messages: list[AnyMessage] = []
|
|
285
|
+
self._hooks = hooks
|
|
286
|
+
self._kwargs = kwargs
|
|
287
|
+
self.prompt_manager = prompt_manager or PromptManager()
|
|
288
|
+
|
|
289
|
+
async def run(self, max_completions: int = 10) -> AsyncGenerator[dict, None]:
|
|
290
|
+
self._messages = await self._handle_request()
|
|
291
|
+
self._citations = await self._handle_callbacks()
|
|
292
|
+
self._chat = self.chat_class(
|
|
293
|
+
messages=self._messages,
|
|
294
|
+
output_types=[AsyncStreamedResponse],
|
|
295
|
+
functions=self.functions if self.functions else None,
|
|
296
|
+
model=self._model, # type: ignore[arg-type]
|
|
297
|
+
**self._kwargs,
|
|
298
|
+
)
|
|
299
|
+
async for event in self._execute(max_completions=max_completions):
|
|
300
|
+
yield event.model_dump()
|
|
301
|
+
if self._citations.citations:
|
|
302
|
+
yield CitationCollectionSSE(data=self._citations).model_dump()
|
|
303
|
+
|
|
304
|
+
async def _handle_callbacks(self) -> CitationCollection:
|
|
305
|
+
if not self.functions:
|
|
306
|
+
return CitationCollection(citations=[])
|
|
307
|
+
citations: list[Citation] = []
|
|
308
|
+
if isinstance(self.request.messages[-1], LlmClientFunctionCallResultMessage):
|
|
309
|
+
wrapped_function = get_wrapped_function(
|
|
310
|
+
function_name=self.request.messages[-1].function,
|
|
311
|
+
functions=self.functions,
|
|
312
|
+
)
|
|
313
|
+
if hasattr(wrapped_function, "execute_callbacks"):
|
|
314
|
+
async for event in wrapped_function.execute_callbacks( # type: ignore
|
|
315
|
+
request=self.request, function_call_result=self.request.messages[-1]
|
|
316
|
+
):
|
|
317
|
+
if isinstance(event, Citation):
|
|
318
|
+
citations.append(event)
|
|
319
|
+
return CitationCollection(citations=citations)
|
|
320
|
+
|
|
321
|
+
async def _handle_request(self) -> list[AnyMessage]:
|
|
322
|
+
chat_messages: list[AnyMessage] = [SystemMessage(self.system_prompt)]
|
|
323
|
+
for message in self.request.messages:
|
|
324
|
+
match message:
|
|
325
|
+
case LlmClientMessage(role="human"):
|
|
326
|
+
chat_messages.append(UserMessage(content=message.content))
|
|
327
|
+
case LlmClientMessage(role="ai") if isinstance(message.content, str):
|
|
328
|
+
chat_messages.append(AssistantMessage(content=message.content))
|
|
329
|
+
case LlmClientMessage(role="ai") if isinstance(message.content, LlmFunctionCall):
|
|
330
|
+
pass
|
|
331
|
+
case LlmClientFunctionCallResultMessage(role="tool"):
|
|
332
|
+
if not self.functions:
|
|
333
|
+
continue
|
|
334
|
+
fn_name = message.extra_state.get("_locally_bound_function") or message.function
|
|
335
|
+
wrapped_function = get_wrapped_function(
|
|
336
|
+
function_name=fn_name,
|
|
337
|
+
functions=self.functions,
|
|
338
|
+
)
|
|
339
|
+
function_call = FunctionCall(
|
|
340
|
+
function=wrapped_function,
|
|
341
|
+
**message.extra_state.get("copilot_function_call_arguments", {}),
|
|
342
|
+
)
|
|
343
|
+
chat_messages.append(AssistantMessage(function_call))
|
|
344
|
+
if hasattr(wrapped_function, "execute_post_processing"):
|
|
345
|
+
content = await wrapped_function.execute_post_processing(message.data)
|
|
346
|
+
else:
|
|
347
|
+
content = str(message.data)
|
|
348
|
+
chat_messages.append(
|
|
349
|
+
FunctionResultMessage(
|
|
350
|
+
content=content,
|
|
351
|
+
function_call=function_call,
|
|
352
|
+
)
|
|
353
|
+
)
|
|
354
|
+
case _:
|
|
355
|
+
raise ValueError(f"Unsupported message type: {message}")
|
|
356
|
+
return chat_messages
|
|
357
|
+
|
|
358
|
+
async def _handle_text_stream(
|
|
359
|
+
self, stream: AsyncStreamedStr
|
|
360
|
+
) -> AsyncGenerator[MessageChunkSSE, None]:
|
|
361
|
+
self._chat = cast(Chat, self._chat)
|
|
362
|
+
async for chunk in stream:
|
|
363
|
+
yield MessageChunkSSE(data=MessageChunkSSEData(delta=chunk))
|
|
364
|
+
|
|
365
|
+
async def _handle_function_call(
|
|
366
|
+
self, function_call: FunctionCall
|
|
367
|
+
) -> AsyncGenerator[FunctionCallSSE | StatusUpdateSSE | Any, None]:
|
|
368
|
+
self._chat = cast(Chat, self._chat)
|
|
369
|
+
function_call_result: str = ""
|
|
370
|
+
if not isinstance(self._chat.last_message, AssistantMessage):
|
|
371
|
+
raise ValueError("Last message is not an assistant message")
|
|
372
|
+
if hasattr(function_call.function, "request"):
|
|
373
|
+
function_call.function.request = self.request
|
|
374
|
+
if self._hooks is not None and hasattr(self._hooks, "run"):
|
|
375
|
+
await self._hooks.run("pre_tool_call", function_call=function_call)
|
|
376
|
+
yield reasoning_step(
|
|
377
|
+
"INFO",
|
|
378
|
+
"Tool execution started",
|
|
379
|
+
{"function": getattr(function_call.function, "__name__", "unknown")},
|
|
380
|
+
)
|
|
381
|
+
async for event in function_call():
|
|
382
|
+
if isinstance(event, StatusUpdateSSE):
|
|
383
|
+
yield event
|
|
384
|
+
elif isinstance(event, FunctionCallSSE):
|
|
385
|
+
yield event
|
|
386
|
+
return
|
|
387
|
+
elif hasattr(event, "event") and getattr(event, "event", None) == "copilotMessageArtifact":
|
|
388
|
+
yield event
|
|
389
|
+
else:
|
|
390
|
+
function_call_result += str(event)
|
|
391
|
+
yield reasoning_step(
|
|
392
|
+
"INFO",
|
|
393
|
+
"Tool execution completed",
|
|
394
|
+
{"function": getattr(function_call.function, "__name__", "unknown")},
|
|
395
|
+
)
|
|
396
|
+
if self._hooks is not None and hasattr(self._hooks, "run"):
|
|
397
|
+
await self._hooks.run(
|
|
398
|
+
"post_tool_call",
|
|
399
|
+
function_call=function_call,
|
|
400
|
+
result=function_call_result,
|
|
401
|
+
)
|
|
402
|
+
self._chat = self._chat.add_message(
|
|
403
|
+
FunctionResultMessage(
|
|
404
|
+
content=function_call_result,
|
|
405
|
+
function_call=function_call,
|
|
406
|
+
)
|
|
407
|
+
)
|
|
408
|
+
|
|
409
|
+
async def _execute(
|
|
410
|
+
self, max_completions: int
|
|
411
|
+
) -> AsyncGenerator[MessageChunkSSE | FunctionCallSSE | StatusUpdateSSE, None]:
|
|
412
|
+
completion_count = 0
|
|
413
|
+
while completion_count < max_completions:
|
|
414
|
+
completion_count += 1
|
|
415
|
+
yield reasoning_step(
|
|
416
|
+
"INFO",
|
|
417
|
+
f"LLM reasoning (round {completion_count})...",
|
|
418
|
+
{"completion_count": completion_count},
|
|
419
|
+
)
|
|
420
|
+
if self._hooks is not None and hasattr(self._hooks, "run"):
|
|
421
|
+
await self._hooks.run("pre_llm", request=self.request)
|
|
422
|
+
self._chat = await cast(Chat, self._chat).asubmit()
|
|
423
|
+
if self._hooks is not None and hasattr(self._hooks, "run"):
|
|
424
|
+
await self._hooks.run("post_llm", chat=self._chat, request=self.request)
|
|
425
|
+
if isinstance(self._chat.last_message.content, AsyncStreamedResponse):
|
|
426
|
+
async for item in self._chat.last_message.content:
|
|
427
|
+
if isinstance(item, AsyncStreamedStr):
|
|
428
|
+
async for event in self._handle_text_stream(item):
|
|
429
|
+
yield event
|
|
430
|
+
elif isinstance(item, FunctionCall):
|
|
431
|
+
yield reasoning_step(
|
|
432
|
+
"INFO",
|
|
433
|
+
"Executing tool/function call...",
|
|
434
|
+
{"function": getattr(item.function, "__name__", str(item.function))},
|
|
435
|
+
)
|
|
436
|
+
async for event in self._handle_function_call(item):
|
|
437
|
+
yield event
|
|
438
|
+
if isinstance(event, FunctionCallSSE):
|
|
439
|
+
return
|
|
440
|
+
return
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
import hashlib
|
|
2
|
+
import json
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# ``.env.cache`` sits two directory levels above this module's resolved path.
_env_cache_path = Path(__file__).resolve().parents[1] / ".env.cache"


class FunctionCacheConfig(BaseSettings):
    """Settings for the function-result cache, loaded via pydantic-settings."""

    model_config = SettingsConfigDict(
        extra="ignore",
        env_file=_env_cache_path,
        env_file_encoding="utf-8",
    )

    # An empty URL leaves the cache disabled.
    redis_url: str = ""
    # Default lifetime applied to stored entries, in seconds.
    ttl_seconds: int = 3600
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class FunctionCache:
    """Best-effort cache for function results, backed by Redis.

    Every failure path (no URL configured, ``redis`` not importable, any
    error raised while reading or writing) degrades to a cache miss or a
    skipped write instead of raising.
    """

    def __init__(self, config: FunctionCacheConfig | None = None) -> None:
        """Copy the Redis URL and default TTL out of ``config``.

        A fresh ``FunctionCacheConfig`` is created when ``config`` is falsy.
        The Redis client itself is created on first use.
        """
        settings = config or FunctionCacheConfig()
        self._redis_url = settings.redis_url
        self._ttl = settings.ttl_seconds
        self._client: Any = None

    def _get_client(self) -> Any:
        """Return the cached client, creating it on first use.

        Returns ``None`` when no URL is configured or ``redis`` cannot be
        imported.
        """
        if self._client is None:
            if not self._redis_url:
                return None
            try:
                import redis.asyncio as redis

                self._client = redis.from_url(self._redis_url, decode_responses=True)
            except ImportError:
                return None
        return self._client

    def _cache_key(self, func_name: str, args: dict[str, Any]) -> str:
        """Derive the key ``fc:<func_name>:<sha256 hex digest>``.

        The digest covers the function name and the JSON encoding of the
        argument items sorted by key, so argument order does not matter.
        """
        ordered_items = sorted(args.items())
        encoded_args = json.dumps(ordered_items, sort_keys=True, default=str)
        digest = hashlib.sha256(f"{func_name}:{encoded_args}".encode()).hexdigest()
        return f"fc:{func_name}:{digest}"

    async def get(self, func_name: str, args: dict[str, Any]) -> str | None:
        """Return the cached value for the call, or ``None`` on miss or error."""
        client = self._get_client()
        if client is None:
            return None
        try:
            return await client.get(self._cache_key(func_name, args))
        except Exception:
            return None

    async def set(
        self,
        func_name: str,
        args: dict[str, Any],
        value: str,
        ttl: int | None = None,
    ) -> None:
        """Store ``value`` for the call; errors are deliberately ignored.

        A falsy ``ttl`` (``None`` or ``0``) falls back to the configured
        default lifetime.
        """
        client = self._get_client()
        if client is None:
            return
        expiry = ttl or self._ttl
        try:
            await client.set(self._cache_key(func_name, args), value, ex=expiry)
        except Exception:
            pass

    def is_enabled(self) -> bool:
        """Return whether a non-empty Redis URL is configured."""
        return bool(self._redis_url)
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
from functools import lru_cache
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
from magentic import OpenaiChatModel
|
|
5
|
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ModelConfig(BaseSettings):
    """LLM provider settings, loaded via pydantic-settings from ``.env``."""

    model_config = SettingsConfigDict(
        extra="ignore",
        env_file=".env",
        env_file_encoding="utf-8",
    )

    # OpenAI-compatible provider credentials and optional endpoint override.
    openai_api_key: str = ""
    openai_base_url: str | None = None

    # DeepSeek credentials and endpoint.
    deepseek_api_key: str = ""
    deepseek_base_url: str = "https://api.deepseek.com"

    # Model used when the caller does not name one.
    model_name: str = "deepseek-chat"
    delegate_model_name: str | None = None
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def get_chat_model(
    api_key: str | None = None,
    base_url: str | None = None,
    model: str | None = None,
) -> OpenaiChatModel:
    """Create an ``OpenaiChatModel`` for the requested or configured model.

    When the model name contains ``"deepseek"`` (case-insensitive), the
    DeepSeek key and base URL from the configuration are the fallbacks;
    otherwise the OpenAI ones are. Explicit arguments take precedence, and
    empty fallbacks are passed on as ``None``.
    """
    settings = get_config()
    chosen_model = model or settings.model_name

    if "deepseek" in chosen_model.lower():
        fallback_key, fallback_url = settings.deepseek_api_key, settings.deepseek_base_url
    else:
        fallback_key, fallback_url = settings.openai_api_key, settings.openai_base_url

    return OpenaiChatModel(
        model=chosen_model,
        api_key=api_key or fallback_key or None,
        base_url=base_url or fallback_url or None,
    )
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@lru_cache
def _get_config() -> ModelConfig:
    """Build the model configuration once; later calls reuse the cached object."""
    settings = ModelConfig()
    return settings
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def get_config() -> ModelConfig:
    """Return the shared ``ModelConfig`` held by the cached ``_get_config``."""
    shared_settings = _get_config()
    return shared_settings
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
from typing import Any, Awaitable, Callable
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class HookType(str, Enum):
    """Points in the agent loop at which hooks can be run.

    Members are also ``str`` instances, so they compare equal to their
    string values.
    """

    # Around each LLM completion.
    PRE_LLM = "pre_llm"
    POST_LLM = "post_llm"

    # Around each tool execution.
    PRE_TOOL_CALL = "pre_tool_call"
    POST_TOOL_CALL = "post_tool_call"
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# A hook is any callable whose call result can be awaited and yields nothing.
HookFunc = Callable[..., Awaitable[None]]


class HooksRegistry:
    """Keeps an ordered list of hook functions for every ``HookType``."""

    def __init__(self) -> None:
        """Start with an empty hook list for each hook type."""
        self._hooks: dict[HookType, list[HookFunc]] = {}
        for kind in HookType:
            self._hooks[kind] = []

    def register(self, hook_type: HookType) -> Callable[[HookFunc], HookFunc]:
        """Return a decorator that appends a function to ``hook_type``'s list.

        The decorated function is returned unchanged.
        """

        def decorator(fn: HookFunc) -> HookFunc:
            bucket = self._hooks[hook_type]
            bucket.append(fn)
            return fn

        return decorator

    async def run(
        self,
        hook_type: HookType | str,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Await each hook registered for ``hook_type``, in registration order.

        String names are converted with ``HookType(...)``, which raises
        ``ValueError`` for an unknown name. All positional and keyword
        arguments are passed to every hook.
        """
        if isinstance(hook_type, str):
            key = HookType(hook_type)
        else:
            key = hook_type
        registered = self._hooks.get(key, [])
        for hook in registered:
            await hook(*args, **kwargs)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def register_hook(hook_type: HookType, registry: HooksRegistry | None = None) -> Callable[[HookFunc], HookFunc]:
    """Return a decorator registering a function for ``hook_type``.

    The function is appended to ``registry`` when one is given, otherwise to
    the module-level default registry, and is returned unchanged.
    """
    target = registry or _default_registry
    return target.register(hook_type)


# Shared registry used by ``register_hook`` when no registry is passed.
_default_registry = HooksRegistry()
|