dm-aioaiagent 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
from typing import Optional, Literal
|
|
4
|
+
from typing_extensions import TypedDict
|
|
5
|
+
from pydantic import BaseModel, Field
|
|
6
|
+
from threading import Thread
|
|
7
|
+
from langchain_openai import ChatOpenAI
|
|
8
|
+
from langchain_core.tools import BaseTool
|
|
9
|
+
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
|
10
|
+
from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, AIMessage, ToolMessage
|
|
11
|
+
from langgraph.graph import StateGraph
|
|
12
|
+
from dm_logger import DMLogger
|
|
13
|
+
|
|
14
|
+
__all__ = ["DMAIAgent"]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Message(TypedDict):
    # Public chat-message shape accepted by DMAIAgent.run().
    # "user" entries become HumanMessage, "ai" entries become AIMessage.
    role: Literal["user", "ai"]
    # Plain message text; entries with empty content are skipped by the agent.
    content: str
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class InnerState(BaseModel):
    """Internal graph state accumulated while a single run executes."""

    # LangChain message objects (system/human/ai/tool) built up during the run.
    # default_factory avoids relying on pydantic deep-copying a shared instance.
    messages: list[BaseMessage] = Field(default_factory=list)
    # Tool-call records appended by the tool-execution node.
    # NOTE(review): entries actually carry tool_name/tool_args/tool_response
    # keys, not Message's role/content — annotation kept for compatibility,
    # but confirm the intended schema.
    context: list[Message] = Field(default_factory=list)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class InputState(BaseModel):
    """Input schema of the compiled graph: user messages plus inner state."""

    # Conversation supplied by the caller as role/content dicts.
    messages: list[Message]
    # Fresh inner state per model instance; default_factory replaces the
    # shared-instance default (pydantic copied it, but the factory is the
    # documented idiom and avoids the ambiguity).
    inner_state: Optional[InnerState] = Field(default_factory=InnerState)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class OutputState(TypedDict):
    # Final result returned by DMAIAgent.run().
    # Text content of the last message produced by the LLM ("" if none).
    answer: str
    # Tool-call records collected during the run (empty when no tools ran).
    context: list[Message]
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class DMAIAgent:
    """Synchronous OpenAI agent built on a compiled LangGraph workflow.

    Graph shape: Prepare messages -> Invoke LLM -> (Execute tool -> Invoke LLM)* -> Exit.
    Tool calls requested by the model are executed (one thread each) and fed
    back to the LLM until it answers without further tool calls.
    """

    # Default logger name; subclasses may override (see the async variant).
    agent_name = "AIAgent"
    # Roles accepted in input messages; anything else is silently skipped.
    _allowed_roles = ("user", "ai")

    def __init__(
        self,
        system_message: str = "You are a helpful assistant.",
        tools: Optional[list[BaseTool]] = None,
        *,
        model: str = "gpt-4o-mini",
        temperature: float = 1,
        agent_name: Optional[str] = None,
        input_output_logging: bool = True
    ):
        """Build the prompt/LLM chain and compile the LangGraph workflow.

        :param system_message: system prompt prepended to every conversation.
        :param tools: optional LangChain tools to bind to the model.
        :param model: OpenAI model name.
        :param temperature: sampling temperature; fractional values are now
            preserved (previously truncated by an int() cast).
        :param agent_name: overrides the logger name.
        :param input_output_logging: log input messages and the final answer.
        :raises EnvironmentError: if OPENAI_API_KEY is not set.
        """
        if not os.getenv("OPENAI_API_KEY"):
            raise EnvironmentError("OPENAI_API_KEY environment variable is not set!")

        self._logger = DMLogger(agent_name or self.agent_name)
        self._input_output_logging = input_output_logging
        self._is_tools_exists = bool(tools)

        prompt = ChatPromptTemplate.from_messages([SystemMessage(content=system_message),
                                                   MessagesPlaceholder(variable_name="messages")])
        # BUGFIX: temperature was cast with int(), silently truncating values
        # such as 0.7 down to 0; cast with float() to preserve them.
        llm = ChatOpenAI(model=str(model), temperature=float(temperature))
        if self._is_tools_exists:
            self._tool_map = {t.name: t for t in tools}
            llm = llm.bind_tools(tools)
        self._agent = prompt | llm

        workflow = StateGraph(input=InputState, output=OutputState)
        workflow.add_node("Prepare messages", self._prepare_messages_node)
        workflow.add_node("Invoke LLM", self._invoke_llm_node)
        workflow.add_node("Execute tool", self._execute_tool_node)
        workflow.add_node("Exit", self._exit_node)

        workflow.add_edge("Prepare messages", "Invoke LLM")
        workflow.add_conditional_edges(source="Invoke LLM",
                                       path=self._messages_router,
                                       path_map={"execute_tool": "Execute tool", "exit": "Exit"})
        workflow.add_edge("Execute tool", "Invoke LLM")
        workflow.set_entry_point("Prepare messages")
        workflow.set_finish_point("Exit")
        self._graph = workflow.compile()

    def run(self, messages: list[Message]) -> OutputState:
        """Run the graph synchronously; returns {"answer": ..., "context": ...}."""
        return self._graph.invoke({"messages": messages})

    def _prepare_messages_node(self, state: InputState) -> InputState:
        """Convert input role/content dicts into LangChain message objects."""
        # NOTE(review): empty input falls back to a hard-coded "Привіт"
        # ("Hello" in Ukrainian) user message — looks like a debug leftover;
        # confirm intent before removing it.
        state.messages = state.messages or [{"role": "user", "content": "Привіт"}]
        state.inner_state = InnerState()
        if self._input_output_logging:
            self._logger.debug(input_messages=state.messages)

        for item in state.messages:
            role = item.get("role")
            content = item.get("content")
            # Skip malformed entries instead of failing the whole run.
            if not role or role not in self._allowed_roles or not content:
                continue
            MessageClass = AIMessage if role == "ai" else HumanMessage
            state.inner_state.messages.append(MessageClass(content))
        return state

    def _invoke_llm_node(self, state: InputState) -> InputState:
        """Invoke the LLM chain with the accumulated messages."""
        self._logger.debug("Run node: Invoke LLM")
        ai_response = self._agent.invoke({"messages": state.inner_state.messages})
        state.inner_state.messages.append(ai_response)
        return state

    def _execute_tool_node(self, state: InputState) -> InputState:
        """Execute every tool call from the last AI message, one thread each."""
        self._logger.debug("Run node: Execute tool")
        threads = []
        for tool_call in state.inner_state.messages[-1].tool_calls:
            tool_id = tool_call["id"]
            tool_name = tool_call["name"]
            tool_args = tool_call["args"]

            # Default args bind the current loop values (late-binding closure fix).
            def tool_callback(tool_id=tool_id, tool_name=tool_name, tool_args=tool_args) -> None:
                self._logger.debug("Invoke tool", tool_id=tool_id, tool_name=tool_name, tool_args=tool_args)
                if tool_name in self._tool_map:
                    try:
                        tool_response = self._tool_map[tool_name].run(tool_args)
                    except Exception as e:
                        self._logger.error(e, tool_id=tool_id)
                        tool_response = "Tool executed with an error!"
                else:
                    # CONSISTENCY FIX: the message now includes the tool name,
                    # matching the async client (was a bare "Tool not found!"
                    # in a placeholder-less f-string).
                    tool_response = f"Tool '{tool_name}' not found!"
                self._logger.debug(f"Tool response:\n{tool_response}", tool_id=tool_id)

                # list.append is atomic in CPython, so concurrent callbacks
                # appending here do not corrupt the lists.
                state.inner_state.context.append({"tool_name": tool_name,
                                                  "tool_args": json.dumps(tool_args, ensure_ascii=False),
                                                  "tool_response": tool_response})
                tool_message = ToolMessage(content=str(tool_response), name=tool_name, tool_call_id=tool_id)
                state.inner_state.messages.append(tool_message)

            threads.append(Thread(target=tool_callback, daemon=True))

        for t in threads:
            t.start()
        for t in threads:
            t.join()

        return state

    def _exit_node(self, state: InputState) -> OutputState:
        """Build the final OutputState from the last message plus tool context."""
        answer = state.inner_state.messages[-1].content if state.inner_state.messages else ""
        if self._input_output_logging:
            self._logger.debug(f"Answer:\n{answer}")
        return OutputState(answer=answer, context=state.inner_state.context)

    def _messages_router(self, state: InputState) -> str:
        """Route to tool execution when the last AI message requested tools."""
        if self._is_tools_exists and state.inner_state.messages[-1].tool_calls:
            return "execute_tool"
        return "exit"

    def print_graph(self) -> None:
        """Print an ASCII rendering of the compiled graph."""
        self._graph.get_graph().print_ascii()

    def save_graph_image(self, path: str) -> None:
        """Render the graph to a Mermaid PNG at *path*; errors are logged, not raised."""
        try:
            image = self._graph.get_graph().draw_mermaid_png()
            with open(str(path), "wb") as f:
                f.write(image)
        except Exception as e:
            self._logger.error(e)

    def set_logger(self, logger) -> None:
        """Replace the internal logger if it exposes callable debug/info/warning/error."""
        required = ("debug", "info", "warning", "error")
        if all(callable(getattr(logger, method, None)) for method in required):
            self._logger = logger
        else:
            print("Invalid logger")
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import sys
|
|
3
|
+
import asyncio
|
|
4
|
+
from langchain_core.messages import ToolMessage
|
|
5
|
+
|
|
6
|
+
from .ai_agent import DMAIAgent, InputState, OutputState
|
|
7
|
+
|
|
8
|
+
__all__ = ["DMAioAIAgent"]
|
|
9
|
+
|
|
10
|
+
# Force the selector-based event loop on Windows — presumably because the
# default Proactor loop is incompatible with a dependency's socket usage;
# TODO(review): confirm which library requires this.
if sys.platform == "win32":
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class DMAioAIAgent(DMAIAgent):
    """Asynchronous variant of DMAIAgent.

    Overrides the LLM and tool nodes with coroutine versions so the compiled
    graph can be driven with ``ainvoke``; tool calls run concurrently.
    """

    agent_name = "AsyncAIAgent"

    async def run(self, messages: list[dict[str, str]]) -> OutputState:
        """Run the compiled graph asynchronously and return its OutputState."""
        return await self._graph.ainvoke({"messages": messages})

    async def _invoke_llm_node(self, state: InputState) -> InputState:
        """Await the LLM chain on the accumulated messages."""
        self._logger.debug("Run node: Invoke LLM")
        reply = await self._agent.ainvoke({"messages": state.inner_state.messages})
        state.inner_state.messages.append(reply)
        return state

    async def _execute_tool_node(self, state: InputState) -> InputState:
        """Execute all tool calls of the last AI message concurrently."""
        self._logger.debug("Run node: Execute tool")

        async def call_one(tool_id, tool_name, tool_args) -> None:
            # Run a single tool and record both the context entry and the
            # ToolMessage for the LLM.
            self._logger.debug("Invoke tool", tool_id=tool_id, tool_name=tool_name, tool_args=tool_args)
            if tool_name in self._tool_map:
                try:
                    tool_response = await self._tool_map[tool_name].arun(tool_args)
                except Exception as e:
                    self._logger.error(e, tool_id=tool_id)
                    tool_response = "Tool executed with an error!"
            else:
                tool_response = f"Tool '{tool_name}' not found!"
            self._logger.debug(f"Tool response:\n{tool_response}", tool_id=tool_id)

            state.inner_state.context.append({"tool_name": tool_name,
                                              "tool_args": json.dumps(tool_args, ensure_ascii=False),
                                              "tool_response": tool_response})
            state.inner_state.messages.append(
                ToolMessage(content=str(tool_response), name=tool_name, tool_call_id=tool_id))

        calls = state.inner_state.messages[-1].tool_calls
        await asyncio.gather(*(call_one(c["id"], c["name"], c["args"]) for c in calls))
        return state
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: dm-aioaiagent
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: This is my custom aioaiagent client
|
|
5
|
+
Home-page: https://pypi.org/project/dm-aioaiagent
|
|
6
|
+
Author: dimka4621
|
|
7
|
+
Author-email: mismartconfig@gmail.com
|
|
8
|
+
Project-URL: GitHub, https://github.com/MykhLibs/dm-aioaiagent
|
|
9
|
+
Keywords: dm aioaiagent
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
12
|
+
Classifier: Operating System :: OS Independent
|
|
13
|
+
Requires-Python: >=3.9
|
|
14
|
+
Description-Content-Type: text/markdown
|
|
15
|
+
Requires-Dist: dm-logger==0.5.3
|
|
16
|
+
Requires-Dist: python-dotenv==1.0.1
|
|
17
|
+
Requires-Dist: pydantic==2.9.2
|
|
18
|
+
Requires-Dist: langchain==0.3.0
|
|
19
|
+
Requires-Dist: langchain-core==0.3.5
|
|
20
|
+
Requires-Dist: langgraph==0.2.23
|
|
21
|
+
Requires-Dist: langchain-community==0.3.0
|
|
22
|
+
Requires-Dist: langchain-openai==0.2.0
|
|
23
|
+
|
|
24
|
+
# DM-aioaiagent
|
|
25
|
+
|
|
26
|
+
## Urls
|
|
27
|
+
|
|
28
|
+
* [PyPI](https://pypi.org/project/dm-aioaiagent)
|
|
29
|
+
* [GitHub](https://github.com/MykhLibs/dm-aioaiagent)
|
|
30
|
+
|
|
31
|
+
### * Package contains both `asynchronous` and `synchronous` clients
|
|
32
|
+
|
|
33
|
+
## Usage
|
|
34
|
+
|
|
35
|
+
### Example of using DMAioAIAgent without tools
|
|
36
|
+
|
|
37
|
+
The synchronous analogue of `DMAioAIAgent` is the `DMAIAgent` client.
|
|
38
|
+
|
|
39
|
+
```python
|
|
40
|
+
import asyncio
|
|
41
|
+
from dm_aioaiagent import DMAioAIAgent
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
async def main():
|
|
45
|
+
# define a system message
|
|
46
|
+
system_message = "Your custom system message with role, backstory and goal"
|
|
47
|
+
|
|
48
|
+
# define a list of tools, if you want to use them
|
|
49
|
+
tools = [...]
|
|
50
|
+
|
|
51
|
+
# define a openai model, default is "gpt-4o-mini"
|
|
52
|
+
model_name = "gpt-4o"
|
|
53
|
+
|
|
54
|
+
# create an agent
|
|
55
|
+
ai_agent = DMAioAIAgent(system_message, tools, model=model_name)
|
|
56
|
+
|
|
57
|
+
# define the conversation messages
|
|
58
|
+
messages = [
|
|
59
|
+
{"role": "user", "content": "Hello!"},
|
|
60
|
+
{"role": "ai", "content": "How can I help you?"},
|
|
61
|
+
{"role": "user", "content": "I want to know the weather in Kyiv"},
|
|
62
|
+
]
|
|
63
|
+
|
|
64
|
+
# start the agent
|
|
65
|
+
state = await ai_agent.run(messages)
|
|
66
|
+
|
|
67
|
+
# print answer
|
|
68
|
+
print(state["answer"])
|
|
69
|
+
|
|
70
|
+
# if you define tools, you can see the context of the tools
|
|
71
|
+
print(state["context"])
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
if __name__ == "__main__":
|
|
75
|
+
asyncio.run(main())
|
|
76
|
+
```
|
|
77
|
+
|
|
78
|
+
### Set custom logger
|
|
79
|
+
|
|
80
|
+
_If you want to set up a custom logger_
|
|
81
|
+
|
|
82
|
+
```python
|
|
83
|
+
from dm_aioaiagent import DMAioAIAgent
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
# create custom logger
|
|
87
|
+
class MyLogger:
|
|
88
|
+
def debug(self, message):
|
|
89
|
+
pass
|
|
90
|
+
|
|
91
|
+
def info(self, message):
|
|
92
|
+
pass
|
|
93
|
+
|
|
94
|
+
def warning(self, message):
|
|
95
|
+
print(message)
|
|
96
|
+
|
|
97
|
+
def error(self, message):
|
|
98
|
+
print(message)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
# create agent
|
|
102
|
+
ai_agent = DMAioAIAgent()
|
|
103
|
+
|
|
104
|
+
# set up custom logger for this agent
|
|
105
|
+
ai_agent.set_logger(MyLogger())
|
|
106
|
+
```
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
dm_aioaiagent/__init__.py,sha256=Hhgz18uMAnR0ezLOukoWCe1sFS4Q4LC26xbRKdu9FFw,118
|
|
2
|
+
dm_aioaiagent/ai_agent.py,sha256=R08RvKc80yyxIgX7VRioqdwhQoLAdeTFtr2AJ5mPsvc,6956
|
|
3
|
+
dm_aioaiagent/async_ai_agent.py,sha256=eGRg5IWR-lz8py_n6z4yB2cdKHwZNbcngIUxuOGN-hc,2396
|
|
4
|
+
dm_aioaiagent-0.1.0.dist-info/METADATA,sha256=2oH3lfL4sigpzhY8EgK3Tzoe1pXOoGNj0lcW1L542Mc,2519
|
|
5
|
+
dm_aioaiagent-0.1.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
|
|
6
|
+
dm_aioaiagent-0.1.0.dist-info/top_level.txt,sha256=CbasLH0KI7zA77XwT6JDCnmRascxKNGvUVV9MgYjHAU,14
|
|
7
|
+
dm_aioaiagent-0.1.0.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
dm_aioaiagent
|