mlchat 1.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1 @@
1
+ recursive-include mlchat *
mlchat-1.0.0/PKG-INFO ADDED
@@ -0,0 +1,19 @@
1
+ Metadata-Version: 2.4
2
+ Name: mlchat
3
+ Version: 1.0.0
4
+ Summary: openai+mcp调用工具
5
+ Home-page: https://www.python.org
6
+ Author: mengling
7
+ Author-email: 1321443305@qq.com
8
+ Requires-Python: >=3.8
9
+ Requires-Dist: loguru
10
+ Requires-Dist: openai
11
+ Requires-Dist: mcp>=1.25.0
12
+ Requires-Dist: pydantic
13
+ Requires-Dist: json-repair
14
+ Dynamic: author
15
+ Dynamic: author-email
16
+ Dynamic: home-page
17
+ Dynamic: requires-dist
18
+ Dynamic: requires-python
19
+ Dynamic: summary
@@ -0,0 +1,229 @@
1
+ import asyncio
2
+ import anyio
3
+ from pydantic import BaseModel
4
+ import json
5
+ import warnings
6
+ from mcp import ClientSession, StdioServerParameters
7
+ from openai import AsyncOpenAI, NOT_GIVEN
8
+ from loguru import logger as _log
9
+ from typing import AsyncIterator
10
+ from contextlib import AsyncExitStack
11
+ from mcp.client.stdio import stdio_client
12
+ from mcp.client.sse import sse_client
13
+ from mcp.client.streamable_http import streamable_http_client
14
+ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk,ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction
15
+ from .extension import DialogueMessager, aget_trace, aget_title, ToolCeil, Messager
16
+
17
+
18
class _ToolCallStream(BaseModel):
    """Accumulator for one tool call delivered incrementally over a streamed response.

    OpenAI streams a tool call as a sequence of ChoiceDeltaToolCall chunks:
    the first chunk carries the call id and function name, later chunks carry
    argument fragments with id unset. This model stitches them back together.
    """
    id: str
    name: str
    arguments: str =''  # JSON argument string, grown chunk by chunk

    def arg_add(self, data: ChoiceDeltaToolCall):
        """Fold one streamed delta into this accumulator.

        A delta without an id is an argument continuation; a delta with an id
        starts (or restarts) the call, resetting the argument buffer.
        """
        if not data: return
        if not data.id:
            # Continuation chunk: append the argument fragment.
            self.arguments += data.function.arguments or ''
        else:
            # New call segment: adopt id/name and restart the argument buffer.
            self.id = data.id
            if data.function.name: self.name = data.function.name
            self.arguments = data.function.arguments or ''

    def to_cdtc(self, index:int)->ChoiceDeltaToolCall:
        """Materialize the accumulated pieces as a complete ChoiceDeltaToolCall.

        Empty arguments default to '{}' so downstream json.loads always succeeds.
        """
        return ChoiceDeltaToolCall(index=index, id=self.id,
                    function=ChoiceDeltaToolCallFunction(arguments=self.arguments or '{}', name=self.name),
                    type='function')
36
+
37
+
38
class MCPClient:
    """Streaming chat client for an OpenAI-compatible endpoint that can invoke
    tools hosted on one or more MCP servers (stdio, SSE, or streamable HTTP).
    """

    def __init__(self, base_url, model, api_key='EMPTY', ban_tools:list=None, logger=None):
        """
        Args:
            base_url: OpenAI-compatible API base URL.
            model: model name passed to chat.completions.create.
            api_key: API key; 'EMPTY' works for local/no-auth deployments.
            ban_tools: tool names that must never be exposed to the model.
            logger: loguru-compatible logger; defaults to the module logger.
        """
        self.exit_stack = AsyncExitStack()
        self.model = model
        self.aclient = AsyncOpenAI(api_key=api_key, base_url=base_url)
        self.server_session:dict[str, ClientSession] = {}  # one session per connected MCP server
        self.server_tool:dict[str, list[str]] = {}         # server name -> tool keys it contributed
        self.available_tool:dict[str, ToolCeil] = {}       # tool title (or name) -> tool schema
        self.tool_session = {}                             # tool name -> owning session, for dispatch
        self.ban_tools=ban_tools
        self.logger=logger or _log

    @property
    def tool_keys(self)->list[str]:
        """Keys (title or name) of all registered tools."""
        return list(self.available_tool.keys())

    @property
    def tool_values(self)->list[ToolCeil]:
        """Schemas of all registered tools."""
        return list(self.available_tool.values())

    async def _connect_to_server(self, server_name, session:ClientSession):
        """Initialize a freshly opened session and register its tools."""
        await session.initialize()
        self.server_session[server_name]=session
        # Refresh the tool mapping from the server's advertised tool list.
        response = await session.list_tools()
        server_tools = []
        for tool in response.tools:
            if self.ban_tools and tool.name in self.ban_tools: continue
            self.logger.debug({"name": tool.name,
                               "description": tool.description,
                               "parameters": tool.inputSchema})
            # Build the unified tool registry, preferring the human-readable title
            # as the key when the server provides one.
            key = tool.title or tool.name
            self.available_tool[key] = ToolCeil(name=tool.name, description=tool.description, parameters=tool.inputSchema)
            self.tool_session[tool.name] = session
            server_tools.append(key)
        self.server_tool[server_name] = server_tools
        self.logger.info(f"已连接到MCP服务器 - {server_name}\n{server_tools}")

    async def connect_to_stdio_server(self, server_name:str, command:str, *args: str, env:dict=None):
        """Spawn an MCP server as a subprocess and connect over stdio."""
        server_params = StdioServerParameters(command=command, args=args, env=env)
        read_stream, write_stream = await self.exit_stack.enter_async_context(stdio_client(server_params))
        session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
        return await self._connect_to_server(server_name, session)

    async def connect_to_http_server(self, server_name:str, url:str, **kwargs):
        """Connect to an MCP server over the streamable-HTTP transport."""
        read_stream, write_stream,_ = await self.exit_stack.enter_async_context(streamable_http_client(url, **kwargs))
        session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
        return await self._connect_to_server(server_name, session)

    async def connect_to_sse_server(self, server_name:str, url:str, **kwargs):
        """Connect to an MCP server over the (legacy) SSE transport."""
        read_stream, write_stream = await self.exit_stack.enter_async_context(sse_client(url, **kwargs))
        session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
        return await self._connect_to_server(server_name, session)

    async def connect_to_config(self, config_or_path:dict|str):
        """Connect to every server listed under the standard `mcpServers` key.

        Accepts either a parsed config dict or a path to a JSON config file.
        Entries with `command` use stdio; entries with `url` use SSE when
        `type` is 'sse', otherwise streamable HTTP.
        """
        if isinstance(config_or_path, str):
            with open(config_or_path, encoding="utf-8") as f:
                config = json.load(f)
        else:
            config = config_or_path
        for server_name, server_config in config['mcpServers'].items():
            if server_config.get("command"):
                await self.connect_to_stdio_server(server_name, server_config["command"], *server_config.get("args",[]),
                                                   env=server_config.get("env"))
            elif server_config.get("url"):
                if server_config.get("type", 'http').lower() == 'sse':
                    await self.connect_to_sse_server(server_name, server_config["url"])
                else:
                    await self.connect_to_http_server(server_name, server_config["url"])
            else:
                warnings.warn(f"未指定command或url, 无法连接到 MCP 服务器 {server_name}")

    def set_tool_description(self, title_or_name:str, description:str):
        """Override the description of a registered tool (raises KeyError if unknown)."""
        self.available_tool[title_or_name].description = description

    async def _call_tool(self, tool_call):
        """Execute one tool call on the session that owns the tool.

        Returns (call_id, tool_name, parsed_args, result_content).
        """
        tool_name = tool_call.function.name
        tool_args = json.loads(tool_call.function.arguments)
        # Find the server session that owns this tool by name.
        session:ClientSession = self.tool_session[tool_name]
        self.logger.info(f"tool: {tool_name} args: {tool_args}")
        try:
            result = await session.call_tool(tool_name, tool_args)
        except anyio.ClosedResourceError:
            self.logger.info(f'{tool_name} 重新连接...')
            # NOTE(review): re-running initialize() on a session whose underlying
            # stream was closed may not re-establish the transport — confirm
            # against the mcp SDK's reconnection semantics.
            await session.initialize()
            result = await session.call_tool(tool_name, tool_args)
        return tool_call.id, tool_name, tool_args, result.content

    def get_tool_choice(self, *tools:str):
        """Build a `tool_choice` payload forcing the given tool key(s).

        No tools -> 'auto'; one tool -> standard forced-function form; several
        tools -> 'allowed_tools' form. NOTE(review): not every model/provider
        supports forcing one or multiple tools — the 'allowed_tools' shape in
        particular is provider-specific; verify against the target API.
        """
        if not tools:
            return 'auto'
        elif len(tools)==1:
            return {
                "type": "function",
                "function": {"name": self.available_tool[tools[0]].name}
            }
        else:
            return {
                'type': 'allowed_tools',
                'allowed_tools': {
                    'mode':'required',
                    'tools':[{ "type": "function", "function": { "name": self.available_tool[tool].name }} for tool in tools]
                }
            }

    async def chat(self, messages:list[Messager|dict], max_tool_num=3, use_tool:str|list[str]=None, **kwargs)->AsyncIterator[Messager]:
        """Stream a chat turn, transparently executing any tool calls the model makes.

        Loops up to `max_tool_num` tool rounds; the final round is issued with
        no tools so the model must answer in plain text.

        Args:
            messages: conversation history (Messager objects or plain dicts).
            max_tool_num: maximum number of tool-call rounds. Defaults to 3.
            use_tool: tool title/name (or list of them) to force on the FIRST
                round only; unknown names are silently ignored.
            **kwargs: forwarded to chat.completions.create.

        Yields:
            Messager: assistant messages (content grows chunk by chunk; the
            final one of each round carries `tool_calls`), followed by one
            role='tool' Messager per tool result.
        """
        the_messages = []
        for message in messages:
            mer = Messager(**message) if isinstance(message, dict) else message
            the_messages.append(mer.model_dump(exclude_none=True))
        # Tool-call loop: one extra iteration so the last round runs tool-free.
        for i in range(max_tool_num+1):
            tool_choice = 'auto'
            message = Messager(role="assistant")
            # Past the tool budget: do not offer tools on the final round.
            if i<max_tool_num:
                # Only the first round honors the forced-tool request.
                if i==0 and use_tool:
                    available_tools, uses=[], []
                    for ut in ([use_tool] if isinstance(use_tool, str) else use_tool):
                        tool = self.available_tool.get(ut)
                        if tool:
                            available_tools.append(tool)
                            uses.append(ut)
                    tool_choice = self.get_tool_choice(*uses)
                else:
                    available_tools = self.tool_values
            else:
                available_tools = None
            tcdt:dict[int, _ToolCallStream] = {}  # tool-call index -> streamed accumulator
            chunk:ChatCompletionChunk
            async for chunk in await self.aclient.chat.completions.create(
                model=self.model,
                messages=the_messages,
                tools=available_tools and [tool.to_tool() for tool in available_tools],
                tool_choice=tool_choice if available_tools else NOT_GIVEN,
                **kwargs,
                stream=True
            ):
                if chunk.choices:
                    message.chunk = chunk.choices[0].delta.content or ''
                    tool_calls = chunk.choices[0].delta.tool_calls
                else:
                    # Keep-alive / usage chunks carry no choices.
                    message.chunk, tool_calls = '', None
                if tool_calls:
                    for tool_call in tool_calls:
                        if tcdt.get(tool_call.index) is None:
                            tcdt[tool_call.index] = _ToolCallStream(id=tool_call.id,name=tool_call.function.name)
                        tcdt[tool_call.index].arg_add(tool_call)
                message.content += message.chunk
                yield message
            message.tool_calls = [data.to_cdtc(index) for index,data in tcdt.items()] or None
            yield message
            the_messages.append(message.model_dump(exclude_none=True))
            # No tool calls -> the model answered; we are done.
            if not message.tool_calls: break
            tool_choice = 'auto'  # redundant: reset again at the top of the loop
            # Execute all requested tool calls concurrently.
            for tool_call_id, tool_name, tool_args, rcs in await asyncio.gather(*[self._call_tool(tool_call) for tool_call in message.tool_calls]):
                # Append tool results to the history; only text content is handled for now.
                for rc in rcs:
                    tmessage = Messager(role="tool", content=rc.text, name=tool_name, args=tool_args, tool_call_id=tool_call_id)
                    yield tmessage
                    the_messages.append(tmessage.model_dump(exclude_none=True))

    async def close(self):
        """Tear down every server connection opened through the exit stack."""
        await self.exit_stack.aclose()
        self.server_session.clear()

    async def get_summary_title(self, historys:list[DialogueMessager])->str:
        """Generate a short summary title from the recent history ('' if empty)."""
        if not historys: return ''
        return await aget_title(self.aclient, self.model, historys)

    async def get_traces(self, historys:list[DialogueMessager], trace_num:int=3)->list[str]:
        """Generate follow-up question suggestions from the recent history ([] if empty)."""
        if not historys: return []
        return await aget_trace(self.aclient, self.model, historys, trace_num=trace_num)
@@ -0,0 +1,122 @@
1
+ from functools import cached_property
2
+ from typing import Literal
3
+ from openai import AsyncOpenAI
4
+ from pydantic import BaseModel, field_validator
5
+ from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
6
+
7
+
8
class DialogueMessager(BaseModel):
    """Minimal dialogue turn (user/assistant only), used for history summarization."""
    role: Literal['assistant', 'user']
    content: str = ''
11
+
12
class Messager(DialogueMessager):
    """Full chat message: extends DialogueMessager with system/developer/tool
    roles and the streaming / tool-call bookkeeping fields."""
    role: Literal['developer', 'system', 'assistant', 'user', 'tool']
    chunk: str|None = None                              # latest streamed text delta
    name: str|None = None                               # tool name (role == 'tool')
    args: dict|list|None = None                         # parsed tool arguments (role == 'tool')
    tool_call_id: str|None = None                       # id linking a tool result to its call
    tool_calls: list[ChoiceDeltaToolCall]|None = None   # calls requested by the assistant

    @field_validator('content', mode='before')
    @classmethod
    def _strip(cls, value, data):
        # Fix: only strip actual strings. The original called value.strip()
        # unconditionally, turning a wrong-typed content into AttributeError
        # instead of a proper pydantic validation error.
        return value.strip() if isinstance(value, str) else value

    @property
    def is_tool_messager(self)->bool:
        """True for a tool-result message."""
        return self.role == 'tool'

    @cached_property
    def is_assistant(self):
        """True for an assistant message (cached; role never changes)."""
        return self.role == 'assistant'

    @property
    def is_dialogue(self):
        """True when this is a non-empty user/assistant turn worth showing."""
        return bool(self.content and self.role in ('assistant','user'))

    @property
    def debug_log(self)->str:
        """One-line summary for logging: tool calls for empty assistant
        messages, otherwise the first line of content truncated to 100 chars."""
        if self.role == 'assistant' and not self.content:
            # Fix: tool_calls may still be None for an empty assistant message
            # (e.g. mid-stream); the original joined None and raised TypeError.
            log = '; '.join(f'{tool.function.name} {tool.function.arguments}' for tool in (self.tool_calls or []))
        else:
            log = self.content.split('\n')
            if len(log)>1 or len(log[0])>100:
                log = log[0][:100]+'...'
            else:
                log = log[0]
        return f"{self.role}: {log}"
47
+
48
class ToolCeil(BaseModel):
    """Schema of a single MCP tool, ready to be exposed to the OpenAI API."""
    name: str
    description: str
    parameters: dict  # JSON-schema of the tool's input

    def to_tool(self)->dict:
        """Render this tool in the OpenAI function-calling `tools` format."""
        function_spec = {
            "name": self.name,
            "description": self.description,
            "parameters": self.parameters,
        }
        return {"type": "function", "function": function_spec}
62
+
63
def _get_chat_history_text(messagers:list[DialogueMessager])->str:
    """Flatten a message history into a 'role: content' transcript,
    one line per message, skipping messages with empty content."""
    lines = []
    for messager in messagers:
        if messager.content:
            lines.append(f'{messager.role}: {messager.content}')
    return '\n'.join(lines)
65
+
66
async def _get_result(aclient:AsyncOpenAI, model:str, prompt:str)->dict:
    """Send *prompt* as a single user message and parse the (possibly
    slightly malformed) JSON reply leniently with json_repair."""
    # Imported lazily so json_repair is only required when this helper runs.
    import json_repair
    completion = await aclient.chat.completions.create(
        model=model,
        messages=[{'role': 'user', 'content': prompt}],
        stream=False,
    )
    reply_text = completion.choices[0].message.content
    return json_repair.loads(reply_text)
72
+
73
# Prompt template for suggesting follow-up questions. Filled via str.format
# with `trace_num` and `chat_history`; literal JSON braces in the output
# example are escaped as `{{ }}` so .format leaves them intact.
_trace_tp = '''### Task:
Suggest {trace_num} relevant follow-up questions or prompts that the user might naturally ask next in this conversation as a **user**, based on the chat history, to help continue or deepen the discussion.
### Guidelines:
- Write all follow-up questions from the user’s point of view, directed to the assistant.
- Make questions concise, clear, and directly related to the discussed topic(s).
- Only suggest follow-ups that make sense given the chat content and do not repeat what was already covered.
- If the conversation is very short or not specific, suggest more general (but relevant) follow-ups the user might ask.
- Use the conversation's primary language; default to English if multilingual.
- Response must be a JSON array of strings, no extra text or formatting.
### Output:
JSON format: {{ "follow_ups": ["Question 1?", "Question 2?", "Question 3?"] }}
### Chat History:
<chat_history>
{chat_history}
</chat_history>'''
88
+
89
async def aget_trace(aclient:AsyncOpenAI, model:str, messagers:list[DialogueMessager], trace_num:int=3)->list[str]:
    """Generate follow-up question suggestions from the chat history."""
    prompt = _trace_tp.format(trace_num=trace_num,
                              chat_history=_get_chat_history_text(messagers))
    parsed = await _get_result(aclient, model, prompt)
    return parsed.get('follow_ups', [])
93
+
94
# Prompt template for summarizing a conversation into a short title.
# Filled via str.format with `chat_history`.
# BUG FIX: the example lines previously used single `{ }` braces, which
# str.format parses as replacement fields — _title_tp.format(...) raised
# KeyError on every call, so aget_title could never succeed. All literal
# braces are now escaped as `{{ }}`.
_title_tp ='''### Task:
Generate a concise, 3-5 word title with an emoji summarizing the chat history.
### Guidelines:
- The title should clearly represent the main theme or subject of the conversation.
- Use emojis that enhance understanding of the topic, but avoid quotation marks or special formatting.
- Write the title in the chat's primary language; default to English if multilingual.
- Prioritize accuracy over excessive creativity; keep it clear and simple.
- Your entire response must consist solely of the JSON object, without any introductory or concluding text.
- The output must be a single, raw JSON object, without any markdown code fences or other encapsulating text.
- Ensure no conversational text, affirmations, or explanations precede or follow the raw JSON output, as this will cause direct parsing failure.
### Output:
JSON format: {{ "title": "your concise title here" }}
### Examples:
- {{ "title": "📉 Stock Market Trends" }},
- {{ "title": "🍪 Perfect Chocolate Chip Recipe" }},
- {{ "title": "Evolution of Music Streaming" }},
- {{ "title": "Remote Work Productivity Tips" }},
- {{ "title": "Artificial Intelligence in Healthcare" }},
- {{ "title": "🎮 Video Game Development Insights" }}
### Chat History:
<chat_history>
{chat_history}
</chat_history>'''
117
+
118
async def aget_title(aclient:AsyncOpenAI, model:str, messagers:list[DialogueMessager])->str:
    """Generate a short summary title from the chat history."""
    prompt = _title_tp.format(chat_history=_get_chat_history_text(messagers))
    parsed = await _get_result(aclient, model, prompt)
    return parsed.get('title', '')
122
+
@@ -0,0 +1,19 @@
1
+ Metadata-Version: 2.4
2
+ Name: mlchat
3
+ Version: 1.0.0
4
+ Summary: openai+mcp调用工具
5
+ Home-page: https://www.python.org
6
+ Author: mengling
7
+ Author-email: 1321443305@qq.com
8
+ Requires-Python: >=3.8
9
+ Requires-Dist: loguru
10
+ Requires-Dist: openai
11
+ Requires-Dist: mcp>=1.25.0
12
+ Requires-Dist: pydantic
13
+ Requires-Dist: json-repair
14
+ Dynamic: author
15
+ Dynamic: author-email
16
+ Dynamic: home-page
17
+ Dynamic: requires-dist
18
+ Dynamic: requires-python
19
+ Dynamic: summary
@@ -0,0 +1,9 @@
1
+ MANIFEST.in
2
+ setup.py
3
+ mlchat/__init__.py
4
+ mlchat/extension.py
5
+ mlchat.egg-info/PKG-INFO
6
+ mlchat.egg-info/SOURCES.txt
7
+ mlchat.egg-info/dependency_links.txt
8
+ mlchat.egg-info/requires.txt
9
+ mlchat.egg-info/top_level.txt
@@ -0,0 +1,5 @@
1
+ loguru
2
+ openai
3
+ mcp>=1.25.0
4
+ pydantic
5
+ json-repair
@@ -0,0 +1 @@
1
+ mlchat
mlchat-1.0.0/setup.cfg ADDED
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
mlchat-1.0.0/setup.py ADDED
@@ -0,0 +1,24 @@
1
import setuptools
import json

# Packaging notes (translated from the original Chinese comments):
# Reference: https://www.jb51.net/article/202841.htm
# Place this file and MANIFEST.in in the directory next to the package dir.
# The package must contain an __init__.py to be importable after pip install.
# pip install --upgrade setuptools wheel -i https://pypi.douban.com/simple
# python setup.py sdist bdist_wheel
# pip install twine
# twine upload dist/*
'''
python setup.py sdist bdist_wheel
twine upload -u user -p password dist/*
'''

name = 'mlchat'

# NOTE(review): this reads packaging metadata from ../config.json — a file
# OUTSIDE the sdist — and unpacks the parsed JSON as exactly two elements
# [default_info, per_package_options]. Building from the released tarball
# will fail because that file is not shipped; confirm intended workflow.
with open('../config.json', encoding='utf-8') as file:
    definfo, opmap = json.loads(file.read())
setuptools.setup(
    name=name,
    packages=setuptools.find_packages(),
    # Merge the shared defaults with this package's own options; the
    # per-package options win on key collisions.
    **{**definfo, **opmap[name]}
)