bisheng-langchain 0.3.0b0__py3-none-any.whl → 0.3.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,9 +1,8 @@
  """Chain that runs an arbitrary python function."""
  import logging
- import os
  from typing import Callable, Dict, Optional

- import openai
+ import httpx
  from autogen import AssistantAgent
  from langchain.base_language import BaseLanguageModel

@@ -45,15 +44,6 @@ Reply "TERMINATE" in the end when everything is done.
  ):
  is_termination_msg = (is_termination_msg if is_termination_msg is not None else
  (lambda x: x.get('content') == 'TERMINATE'))
- if openai_proxy:
- openai.proxy = {'https': openai_proxy, 'http': openai_proxy}
- else:
- openai.proxy = None
- if openai_api_base:
- openai.api_base = openai_api_base
- else:
- openai.api_base = os.environ.get('OPENAI_API_BASE', 'https://api.openai.com/v1')
-
  config_list = [
  {
  'model': model_name,
@@ -63,17 +53,24 @@ Reply "TERMINATE" in the end when everything is done.
  'api_version': api_version,
  },
  ]
- llm_config = {
- 'seed': 42, # change the seed for different trials
- 'temperature': temperature,
- 'config_list': config_list,
- 'request_timeout': 120,
- }
+ if openai_proxy:
+ config_list[0]['http_client'] = httpx.Client(proxies=openai_proxy)
+ config_list[0]['http_async_client'] = httpx.AsyncClient(proxies=openai_proxy)
+
+ if llm:
+ llm_config = llm
+
+ else:
+ llm_config = {
+ 'seed': 42, # change the seed for different trials
+ 'temperature': temperature,
+ 'config_list': config_list,
+ 'request_timeout': 120,
+ }

  super().__init__(
  name,
  llm_config=llm_config,
- llm=llm,
  system_message=system_message,
  is_termination_msg=is_termination_msg,
  max_consecutive_auto_reply=None,
@@ -19,7 +19,6 @@ class AutoGenCustomRole(ConversableAgent):
  human_input_mode='NEVER',
  code_execution_config=False,
  llm_config=False,
- llm=None,
  **kwargs)
  self.func = func
  self.coroutine = coroutine
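
Note on the hunks above: the removed `openai.proxy` / `openai.api_base` globals belong to the pre-1.x OpenAI SDK; the new code attaches per-entry httpx clients to the autogen `config_list` instead. A minimal sketch of that pattern, with an assumed proxy URL and placeholder key:

    import httpx

    openai_proxy = 'http://127.0.0.1:7890'  # assumed example value
    config_list = [{'model': 'gpt-4-0613', 'api_key': '<OPENAI_API_KEY>'}]
    if openai_proxy:
        # the OpenAI v1 SDK routes traffic through explicitly supplied httpx clients
        config_list[0]['http_client'] = httpx.Client(proxies=openai_proxy)
        config_list[0]['http_async_client'] = httpx.AsyncClient(proxies=openai_proxy)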
@@ -1,9 +1,8 @@
  """Chain that runs an arbitrary python function."""
  import logging
- import os
  from typing import List, Optional

- import openai
+ import httpx
  from autogen import Agent, GroupChat, GroupChatManager
  from langchain.base_language import BaseLanguageModel

@@ -20,6 +19,7 @@ class AutoGenGroupChatManager(GroupChatManager):
  self,
  agents: List[Agent],
  max_round: int = 50,
+ llm: Optional[BaseLanguageModel] = None,
  model_name: Optional[str] = 'gpt-4-0613',
  openai_api_key: Optional[str] = '',
  openai_api_base: Optional[str] = '',
@@ -28,7 +28,6 @@ class AutoGenGroupChatManager(GroupChatManager):
  api_type: Optional[str] = None, # when llm_flag=True, need to set
  api_version: Optional[str] = None, # when llm_flag=True, need to set
  name: Optional[str] = 'chat_manager',
- llm: Optional[BaseLanguageModel] = None,
  system_message: Optional[str] = 'Group chat manager.',
  **kwargs,
  ):
@@ -37,15 +36,6 @@ class AutoGenGroupChatManager(GroupChatManager):

  groupchat = GroupChat(agents=agents, messages=[], max_round=max_round)

- if openai_proxy:
- openai.proxy = {'https': openai_proxy, 'http': openai_proxy}
- else:
- openai.proxy = None
- if openai_api_base:
- openai.api_base = openai_api_base
- else:
- openai.api_base = os.environ.get('OPENAI_API_BASE', 'https://api.openai.com/v1')
-
  config_list = [
  {
  'model': model_name,
@@ -55,17 +45,23 @@ class AutoGenGroupChatManager(GroupChatManager):
  'api_version': api_version,
  },
  ]
- llm_config = {
- 'seed': 42, # change the seed for different trials
- 'temperature': temperature,
- 'config_list': config_list,
- 'request_timeout': 120,
- }
+ if openai_proxy:
+ config_list[0]['http_client'] = httpx.Client(proxies=openai_proxy)
+ config_list[0]['http_async_client'] = httpx.AsyncClient(proxies=openai_proxy)
+
+ if llm:
+ llm_config = llm
+ else:
+ llm_config = {
+ 'seed': 42, # change the seed for different trials
+ 'temperature': temperature,
+ 'config_list': config_list,
+ 'request_timeout': 120,
+ }

  super().__init__(
  groupchat=groupchat,
  llm_config=llm_config,
- llm=llm,
  name=name,
  system_message=system_message,
  )
@@ -1,9 +1,8 @@
  """Chain that runs an arbitrary python function."""
  import logging
- import os
  from typing import Callable, Dict, Optional

- import openai
+ import httpx
  from autogen import UserProxyAgent
  from langchain.base_language import BaseLanguageModel

@@ -47,14 +46,6 @@ class AutoGenUserProxyAgent(UserProxyAgent):
  code_execution_config = False

  if llm_flag:
- if openai_proxy:
- openai.proxy = {'https': openai_proxy, 'http': openai_proxy}
- else:
- openai.proxy = None
- if openai_api_base:
- openai.api_base = openai_api_base
- else:
- openai.api_base = os.environ.get('OPENAI_API_BASE', 'https://api.openai.com/v1')
  config_list = [
  {
  'model': model_name,
@@ -64,12 +55,19 @@ class AutoGenUserProxyAgent(UserProxyAgent):
  'api_version': api_version,
  },
  ]
- llm_config = {
- 'seed': 42, # change the seed for different trials
- 'temperature': temperature,
- 'config_list': config_list,
- 'request_timeout': 120,
- }
+ if openai_proxy:
+ config_list[0]['http_client'] = httpx.Client(proxies=openai_proxy)
+ config_list[0]['http_async_client'] = httpx.AsyncClient(proxies=openai_proxy)
+
+ if llm:
+ llm_config = llm
+ else:
+ llm_config = {
+ 'seed': 42, # change the seed for different trials
+ 'temperature': temperature,
+ 'config_list': config_list,
+ 'request_timeout': 120,
+ }
  else:
  llm_config = False

@@ -80,7 +78,6 @@ class AutoGenUserProxyAgent(UserProxyAgent):
  function_map=function_map,
  code_execution_config=code_execution_config,
  llm_config=llm_config,
- llm=llm,
  system_message=system_message)


@@ -109,7 +106,6 @@ class AutoGenUser(UserProxyAgent):
  human_input_mode=human_input_mode,
  code_execution_config=code_execution_config,
  llm_config=llm_config,
- llm=None,
  system_message=system_message)


@@ -140,5 +136,4 @@ class AutoGenCoder(UserProxyAgent):
  function_map=function_map,
  code_execution_config=code_execution_config,
  llm_config=llm_config,
- llm=None,
  system_message=system_message)
@@ -163,7 +163,7 @@ class BaseHostChatLLM(BaseChatModel):
  values[
  'host_base_url'] = f"{values['host_base_url']}/{values['model_name']}/infer"
  except Exception:
- raise Exception(f'Update Decoupled status faild for model {model}')
+ raise Exception(f'Update Decoupled status failed for model {model}')

  try:
  if values['headers']:
@@ -3,7 +3,11 @@ import os
  import re

  import httpx
- from bisheng_langchain.gpts.prompts import ASSISTANT_PROMPT_OPT, BREIF_DES_PROMPT, OPENDIALOG_PROMPT
+ from bisheng_langchain.gpts.prompts import (
+ ASSISTANT_PROMPT_OPT,
+ BREIF_DES_PROMPT,
+ OPENDIALOG_PROMPT,
+ )
  from langchain_core.language_models.base import LanguageModelLike
  from langchain_openai.chat_models import ChatOpenAI
  from loguru import logger
@@ -48,16 +52,13 @@ def optimize_assistant_prompt(
  Returns:
  assistant_prompt(str):
  """
- chain = ({
- 'assistant_name': lambda x: x['assistant_name'],
- 'assistant_description': lambda x: x['assistant_description'],
- }
- | ASSISTANT_PROMPT_OPT
- | llm)
- chain_output = chain.invoke({
- 'assistant_name': assistant_name,
- 'assistant_description': assistant_description,
- })
+ chain = ASSISTANT_PROMPT_OPT | llm
+ chain_output = chain.invoke(
+ {
+ 'assistant_name': assistant_name,
+ 'assistant_description': assistant_description,
+ }
+ )
  response = chain_output.content
  assistant_prompt = parse_markdown(response)
  return assistant_prompt
@@ -67,17 +68,15 @@ def generate_opening_dialog(
  llm: LanguageModelLike,
  description: str,
  ) -> str:
- chain = ({
- 'description': lambda x: x['description'],
- }
- | OPENDIALOG_PROMPT
- | llm)
+ chain = OPENDIALOG_PROMPT | llm
  time = 0
  while time <= 3:
  try:
- chain_output = chain.invoke({
- 'description': description,
- })
+ chain_output = chain.invoke(
+ {
+ 'description': description,
+ }
+ )
  output = parse_json(chain_output.content)
  output = json.loads(output)
  opening_lines = output[0]['开场白']
@@ -101,20 +100,22 @@ def generate_breif_description(
  llm: LanguageModelLike,
  description: str,
  ) -> str:
- chain = ({
- 'description': lambda x: x['description'],
- }
- | BREIF_DES_PROMPT
- | llm)
- chain_output = chain.invoke({
- 'description': description,
- })
+ chain = BREIF_DES_PROMPT | llm
+ chain_output = chain.invoke(
+ {
+ 'description': description,
+ }
+ )
  breif_description = chain_output.content
  breif_description = breif_description.strip()
  return breif_description


  if __name__ == '__main__':
+ from dotenv import load_dotenv
+
+ load_dotenv('/app/.env', override=True)
+
  httpx_client = httpx.Client(proxies=os.getenv('OPENAI_PROXY'))
  llm = ChatOpenAI(model='gpt-4-0125-preview', temperature=0.01, http_client=httpx_client)
  # llm = ChatQWen(model="qwen1.5-72b-chat", temperature=0.01, api_key=os.getenv('QWEN_API_KEY'))
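The three chains above drop the redundant passthrough dicts and reduce to `prompt | llm`, invoked with a plain dict. A short usage sketch of the simplified pipeline (the model name mirrors the `__main__` block above; the inputs are assumed examples):

    from langchain_openai.chat_models import ChatOpenAI
    from bisheng_langchain.gpts.prompts import ASSISTANT_PROMPT_OPT

    llm = ChatOpenAI(model='gpt-4-0125-preview', temperature=0.01)
    chain = ASSISTANT_PROMPT_OPT | llm
    out = chain.invoke({
        'assistant_name': 'travel planner',           # assumed example input
        'assistant_description': 'plans short city trips',
    })
    print(out.content)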
@@ -1,6 +1,9 @@
  from bisheng_langchain.gpts.prompts.select_tools_prompt import HUMAN_MSG, SYS_MSG
- from langchain.prompts import (ChatPromptTemplate, HumanMessagePromptTemplate,
- SystemMessagePromptTemplate)
+ from langchain.prompts import (
+ ChatPromptTemplate,
+ HumanMessagePromptTemplate,
+ SystemMessagePromptTemplate,
+ )
  from langchain_core.language_models.base import LanguageModelLike
  from pydantic import BaseModel

@@ -31,19 +34,15 @@ class ToolSelector:
  HumanMessagePromptTemplate.from_template(self.human_message),
  ]

- chain = ({
- 'tool_pool': lambda x: x['tool_pool'],
- 'task_name': lambda x: x['task_name'],
- 'task_description': lambda x: x['task_description'],
- }
- | ChatPromptTemplate.from_messages(messages)
- | self.llm)
-
- chain_output = chain.invoke({
- 'tool_pool': tool_pool,
- 'task_name': task_name,
- 'task_description': task_description,
- })
+ chain = ChatPromptTemplate.from_messages(messages) | self.llm
+
+ chain_output = chain.invoke(
+ {
+ 'tool_pool': tool_pool,
+ 'task_name': task_name,
+ 'task_description': task_description,
+ }
+ )

  try:
  all_tool_name = set([tool.tool_name for tool in self.tools])
@@ -1,7 +1,11 @@
+ import json
+ import os
  import warnings
  from typing import Any, Callable, Dict, List, Optional, Tuple

  import httpx
+ import pandas as pd
+ import pymysql
  from bisheng_langchain.gpts.tools.api_tools import ALL_API_TOOLS
  from bisheng_langchain.gpts.tools.bing_search.tool import BingSearchRun
  from bisheng_langchain.gpts.tools.calculator.tool import calculator
@@ -13,6 +17,7 @@ from bisheng_langchain.gpts.tools.dalle_image_generator.tool import (
  DallEImageGenerator,
  )
  from bisheng_langchain.gpts.tools.get_current_time.tool import get_current_time
+ from dotenv import load_dotenv
  from langchain_community.tools.arxiv.tool import ArxivQueryRun
  from langchain_community.tools.bearly.tool import BearlyInterpreterTool
  from langchain_community.utilities.arxiv import ArxivAPIWrapper
@@ -54,12 +59,14 @@ def _get_bing_search(**kwargs: Any) -> BaseTool:

  def _get_dalle_image_generator(**kwargs: Any) -> Tool:
  openai_api_key = kwargs.get('openai_api_key')
+ openai_api_base = kwargs.get('openai_api_base')
  http_async_client = httpx.AsyncClient(proxies=kwargs.get('openai_proxy'))
  httpc_client = httpx.Client(proxies=kwargs.get('openai_proxy'))
  return DallEImageGenerator(
  api_wrapper=DallEAPIWrapper(
  model='dall-e-3',
  api_key=openai_api_key,
+ base_url=openai_api_base,
  http_client=httpc_client,
  http_async_client=http_async_client,
  )
@@ -78,7 +85,7 @@ def _get_native_code_interpreter(**kwargs: Any) -> Tool:
  _EXTRA_PARAM_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[Optional[str]], List[Optional[str]]]] = { # type: ignore
  'dalle_image_generator': (_get_dalle_image_generator, ['openai_api_key', 'openai_proxy'], []),
  'bing_search': (_get_bing_search, ['bing_subscription_key', 'bing_search_url'], []),
- 'code_interpreter': (_get_native_code_interpreter, ["minio"], ['files']),
+ 'bisheng_code_interpreter': (_get_native_code_interpreter, ["minio"], ['files']),
  }

  _API_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[str]]] = {**ALL_API_TOOLS} # type: ignore
@@ -159,3 +166,46 @@ def load_tools(
  def get_all_tool_names() -> List[str]:
  """Get a list of all possible tool names."""
  return list(_ALL_TOOLS.keys())
+
+
+ def get_tool_table():
+
+ load_dotenv('.sql_env', override=True)
+ db = pymysql.connect(
+ host=os.getenv('MYSQL_HOST'),
+ user=os.getenv('MYSQL_USER'),
+ password=os.getenv('MYSQL_PASSWORD'),
+ database=os.getenv('MYSQL_DATABASE'),
+ port=int(os.getenv('MYSQL_PORT')),
+ )
+ cursor = db.cursor()
+ cursor.execute("SELECT name, t.desc, tool_key, extra FROM t_gpts_tools as t;")
+ results = cursor.fetchall()
+ db.close()
+
+ df = pd.DataFrame(
+ columns=[
+ '前端工具名',
+ '前端工具描述',
+ 'tool_key',
+ 'tool参数配置',
+ 'function_name',
+ 'function_description',
+ 'function_args',
+ ]
+ )
+ for i, result in enumerate(results):
+ name, desc, tool_key, extra = result
+ if not extra:
+ extra = '{}'
+ tool_func = load_tools({tool_key: json.loads(extra)})[0]
+
+ df.loc[i, '前端工具名'] = name
+ df.loc[i, '前端工具描述'] = desc
+ df.loc[i, 'tool_key'] = tool_key
+ df.loc[i, 'tool参数配置'] = extra
+ df.loc[i, 'function_name'] = tool_func.name
+ df.loc[i, 'function_description'] = tool_func.description
+ df.loc[i, 'function_args'] = f"{tool_func.args_schema.schema()['properties']}"
+
+ return df
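
The new `get_tool_table` helper above reads tool rows from MySQL and instantiates each one through `load_tools`. A minimal sketch of that single-tool instantiation step in isolation (the tool key and the empty `extra` config are assumed example values):

    import json
    from bisheng_langchain.gpts.load_tools import load_tools

    extra = '{}'                                    # per-tool config as stored in the DB
    tool = load_tools({'get_current_time': json.loads(extra)})[0]  # assumed tool key
    print(tool.name, tool.description)
    print(tool.args_schema.schema()['properties'])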
@@ -6,7 +6,7 @@ from langchain_core.prompts.chat import (
  )

  system_template = """
- 你是一个生成开场白和预置问题的助手。接下来,你会收到一段关于任务助手的描述,你需要带入描述中的角色,以描述中的角色身份生成一段开场白,同时你还需要以描述中的角色身份生成几个预置问题。输出格式如下:
+ 你是一个生成开场白和预置问题的助手。接下来,你会收到一段关于任务助手的描述,你需要带入描述中的角色,以描述中的角色身份生成一段开场白,同时你还需要站在用户的角度生成几个用户可能的提问。输出格式如下:
  [
  {{
  "开场白": "开场白内容",
@@ -40,7 +40,7 @@ _MACRO_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[str]]] = {

  _tmp_flow = ['knowledge_retrieve']
  _TMP_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[str]]] = {
- f'flow_{name}': (FlowTools.get_api_tool, ['collection_id'])
+ f'flow_{name}': (FlowTools.get_api_tool, ['collection_id', 'description'])
  for name in _tmp_flow
  }
  ALL_API_TOOLS = {}
@@ -49,10 +49,12 @@ class APIToolBase(BaseModel):
  request_timeout=timeout)
  return values

- def run(self, query: str) -> str:
+ def run(self, query: str, **kwargs) -> str:
  """Run query through api and parse result."""
  if query:
  self.params[self.input_key] = query
+ if kwargs:
+ self.params.update(kwargs)
  if self.params:
  param = '&'.join([f'{k}={v}' for k, v in self.params.items()])
  url = self.url + '?' + param if '?' not in self.url else self.url + '&' + param
@@ -62,12 +64,14 @@ class APIToolBase(BaseModel):
  resp = self.client.get(url)
  if resp.status_code != 200:
  logger.info('api_call_fail res={}', resp.text)
- return resp.text
+ return resp.text[:10000]

- async def arun(self, query: str) -> str:
+ async def arun(self, query: str, **kwargs) -> str:
  """Run query through api and parse result."""
  if query:
  self.params[self.input_key] = query
+ if kwargs:
+ self.params.update(kwargs)
  if self.params:
  param = '&'.join([f'{k}={v}' for k, v in self.params.items()])
  url = self.url + '?' + param if '?' not in self.url else self.url + '&' + param
@@ -75,8 +79,8 @@ class APIToolBase(BaseModel):
  url = self.url
  logger.info('api_call url={}', url)
  resp = await self.async_client.aget(url)
- logger.info(resp)
- return resp
+ logger.info(resp[:10000])
+ return resp[:10000]

  @classmethod
  def get_api_tool(cls, name, **kwargs: Any) -> BaseTool:
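
`run`/`arun` now accept extra keyword arguments and fold them into the query string, and both paths cap the returned body at 10,000 characters. A standalone sketch of the URL-building step (names and values are assumed examples; this is not the class itself):

    params = {'query': 'apple', 'page': 1}   # input_key value plus **kwargs
    url = 'https://example.com/api'          # assumed endpoint
    qs = '&'.join(f'{k}={v}' for k, v in params.items())
    url = url + '?' + qs if '?' not in url else url + '&' + qs
    print(url)                               # https://example.com/api?query=apple&page=1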
@@ -1,8 +1,9 @@
  from loguru import logger
  from pydantic import BaseModel, Field
-
+ from typing import Any
  from .base import APIToolBase
-
+ from .base import MultArgsSchemaTool
+ from langchain_core.tools import BaseTool

  class FlowTools(APIToolBase):

@@ -34,10 +35,8 @@ class FlowTools(APIToolBase):
  return resp

  @classmethod
- def knowledge_retrieve(cls, collection_id: int = None) -> str:
- """
- 知识库检索工具,从内部知识库进行检索总结
- """
+ def knowledge_retrieve(cls, collection_id: int = None) -> str:
+
  flow_id = 'c7985115-a9d2-446a-9c55-40b5728ffb52'
  url = 'http://192.168.106.120:3002/api/v1/process/{}'.format(flow_id)
  input_key = 'inputs'
@@ -55,5 +54,18 @@ class FlowTools(APIToolBase):
  class InputArgs(BaseModel):
  """args_schema"""
  query: str = Field(description='questions to ask')
-
+
  return cls(url=url, params=params, input_key=input_key, args_schema=InputArgs)
+
+ @classmethod
+ def get_api_tool(cls, name, **kwargs: Any) -> BaseTool:
+ attr_name = name.split('_', 1)[-1]
+ class_method = getattr(cls, attr_name)
+ function_description = kwargs.get('description','')
+ kwargs.pop('description')
+ return MultArgsSchemaTool(name=name + '_' +str(kwargs.get('collection_id')),
+ description=function_description,
+ func=class_method(**kwargs).run,
+ coroutine=class_method(**kwargs).arun,
+ args_schema=class_method(**kwargs).args_schema)
+
@@ -81,7 +81,7 @@ class MacroData(BaseModel):
  JS_CHINA_GDP_YEARLY_URL = 'https://cdn.jin10.com/dc/reports/dc_chinese_gdp_yoy_all.js?v={}&_={}'
  t = time.time()
  r = requests.get(JS_CHINA_GDP_YEARLY_URL.format(str(int(round(t * 1000))), str(int(round(t * 1000)) + 90)))
- json_data = json.loads(r.text[r.text.find('{') : r.text.rfind('}') + 1])
+ json_data = json.loads(r.text[r.text.find('{'): r.text.rfind('}') + 1])
  date_list = [item['date'] for item in json_data['list']]
  value_list = [item['datas']['中国GDP年率报告'] for item in json_data['list']]
  value_df = pd.DataFrame(value_list)
@@ -249,6 +249,60 @@ class MacroData(BaseModel):
  temp_df = temp_df[(temp_df['月份'] >= start) & (temp_df['月份'] <= end)]
  return temp_df.to_markdown()

+ @classmethod
+ def china_pmi(cls, start_date: str = '', end_date: str = '') -> str:
+ """中国 PMI (采购经理人指数)月度统计数据。
+ 返回数据包括:月份制造业 PMI,制造业 PMI 同比增长,非制造业 PMI,非制造业 PMI 同比增长。
+ """
+ url = "https://datacenter-web.eastmoney.com/api/data/v1/get"
+ params = {
+ "columns": "REPORT_DATE,TIME,MAKE_INDEX,MAKE_SAME,NMAKE_INDEX,NMAKE_SAME",
+ "pageNumber": "1",
+ "pageSize": "2000",
+ "sortColumns": "REPORT_DATE",
+ "sortTypes": "-1",
+ "source": "WEB",
+ "client": "WEB",
+ "reportName": "RPT_ECONOMY_PMI",
+ "p": "1",
+ "pageNo": "1",
+ "pageNum": "1",
+ "_": "1669047266881",
+ }
+ r = requests.get(url, params=params)
+ data_json = r.json()
+ temp_df = pd.DataFrame(data_json["result"]["data"])
+ temp_df.columns = [
+ "-",
+ "月份",
+ "制造业-指数",
+ "制造业-同比增长",
+ "非制造业-指数",
+ "非制造业-同比增长",
+ ]
+ temp_df = temp_df[
+ [
+ "月份",
+ "制造业-指数",
+ "制造业-同比增长",
+ "非制造业-指数",
+ "非制造业-同比增长",
+ ]
+ ]
+ temp_df["制造业-指数"] = pd.to_numeric(temp_df["制造业-指数"], errors="coerce")
+ temp_df["制造业-同比增长"] = pd.to_numeric(
+ temp_df["制造业-同比增长"], errors="coerce"
+ )
+ temp_df["非制造业-指数"] = pd.to_numeric(temp_df["非制造业-指数"], errors="coerce")
+ temp_df["非制造业-同比增长"] = pd.to_numeric(
+ temp_df["非制造业-同比增长"], errors="coerce"
+ )
+ if start_date and end_date:
+ start = start_date.split('-')[0] + '年' + start_date.split('-')[1] + '月份'
+ end = end_date.split('-')[0] + '年' + end_date.split('-')[1] + '月份'
+ temp_df = temp_df[(temp_df['月份'] >= start) & (temp_df['月份'] <= end)]
+ return temp_df.to_markdown()
+
  @classmethod
  def china_money_supply(cls, start_date: str = '', end_date: str = '') -> pd.DataFrame:
  """中国货币供应量(M2,M1,M0)月度统计数据。\
@@ -376,6 +430,121 @@ M0数量(单位:亿元),M0 同比(单位:%),M0 环比(单位

  return temp_df.to_markdown()

+ @classmethod
+ def bond_zh_us_rate(cls, start_date: str = "", end_date: str = "") -> str:
+ """
+ 本接口返回指定时间段[start_date,end_date]内交易日的中美两国的 2 年、5 年、10 年、30 年、10 年-2 年收益率数据。
+ start_date表示起始日期,end_date表示结束日期,日期格式例如 2024-04-07
+ """
+ url = "https://datacenter.eastmoney.com/api/data/get"
+ params = {
+ "type": "RPTA_WEB_TREASURYYIELD",
+ "sty": "ALL",
+ "st": "SOLAR_DATE",
+ "sr": "-1",
+ "token": "894050c76af8597a853f5b408b759f5d",
+ "p": "1",
+ "ps": "500",
+ "pageNo": "1",
+ "pageNum": "1",
+ "_": "1615791534490",
+ }
+ r = requests.get(url, params=params)
+ data_json = r.json()
+ total_page = data_json["result"]["pages"]
+ big_df = pd.DataFrame()
+ for page in range(1, total_page + 1):
+ params = {
+ "type": "RPTA_WEB_TREASURYYIELD",
+ "sty": "ALL",
+ "st": "SOLAR_DATE",
+ "sr": "-1",
+ "token": "894050c76af8597a853f5b408b759f5d",
+ "p": page,
+ "ps": "500",
+ "pageNo": page,
+ "pageNum": page,
+ "_": "1615791534490",
+ }
+ r = requests.get(url, params=params)
+ data_json = r.json()
+ # 时间过滤
+ if start_date and end_date:
+ temp_data = []
+ for item in data_json["result"]["data"]:
+ if start_date <= item["SOLAR_DATE"].split(" ")[0] <= end_date:
+ temp_data.append(item)
+ elif start_date > item["SOLAR_DATE"].split(" ")[0]:
+ break
+ else:
+ continue
+ else:
+ temp_data = data_json["result"]["data"]
+ temp_df = pd.DataFrame(temp_data)
+ for col in temp_df.columns:
+ if temp_df[col].isnull().all(): # 检查列是否包含 None 或 NaN
+ temp_df[col] = pd.to_numeric(temp_df[col], errors='coerce')
+ if big_df.empty:
+ big_df = temp_df
+ else:
+ big_df = pd.concat(objs=[big_df, temp_df], ignore_index=True)
+
+ big_df.rename(
+ columns={
+ "SOLAR_DATE": "日期",
+ "EMM00166462": "中国国债收益率5年",
+ "EMM00166466": "中国国债收益率10年",
+ "EMM00166469": "中国国债收益率30年",
+ "EMM00588704": "中国国债收益率2年",
+ "EMM01276014": "中国国债收益率10年-2年",
+ "EMG00001306": "美国国债收益率2年",
+ "EMG00001308": "美国国债收益率5年",
+ "EMG00001310": "美国国债收益率10年",
+ "EMG00001312": "美国国债收益率30年",
+ "EMG01339436": "美国国债收益率10年-2年",
+ "EMM00000024": "中国GDP年增率",
+ "EMG00159635": "美国GDP年增率",
+ },
+ inplace=True,
+ )
+ big_df = big_df[
+ [
+ "日期",
+ "中国国债收益率2年",
+ "中国国债收益率5年",
+ "中国国债收益率10年",
+ "中国国债收益率30年",
+ "中国国债收益率10年-2年",
+ "中国GDP年增率",
+ "美国国债收益率2年",
+ "美国国债收益率5年",
+ "美国国债收益率10年",
+ "美国国债收益率30年",
+ "美国国债收益率10年-2年",
+ "美国GDP年增率",
+ ]
+ ]
+ big_df = big_df.drop(["中国GDP年增率", "美国GDP年增率"], axis=1)
+ big_df["日期"] = pd.to_datetime(big_df["日期"], errors="coerce")
+ big_df["中国国债收益率2年"] = pd.to_numeric(big_df["中国国债收益率2年"], errors="coerce")
+ big_df["中国国债收益率5年"] = pd.to_numeric(big_df["中国国债收益率5年"], errors="coerce")
+ big_df["中国国债收益率10年"] = pd.to_numeric(big_df["中国国债收益率10年"], errors="coerce")
+ big_df["中国国债收益率30年"] = pd.to_numeric(big_df["中国国债收益率30年"], errors="coerce")
+ big_df["中国国债收益率10年-2年"] = pd.to_numeric(big_df["中国国债收益率10年-2年"], errors="coerce")
+ # big_df["中国GDP年增率"] = pd.to_numeric(big_df["中国GDP年增率"], errors="coerce")
+ big_df["美国国债收益率2年"] = pd.to_numeric(big_df["美国国债收益率2年"], errors="coerce")
+ big_df["美国国债收益率5年"] = pd.to_numeric(big_df["美国国债收益率5年"], errors="coerce")
+ big_df["美国国债收益率10年"] = pd.to_numeric(big_df["美国国债收益率10年"], errors="coerce")
+ big_df["美国国债收益率30年"] = pd.to_numeric(big_df["美国国债收益率30年"], errors="coerce")
+ big_df["美国国债收益率10年-2年"] = pd.to_numeric(big_df["美国国债收益率10年-2年"], errors="coerce")
+ # big_df["美国GDP年增率"] = pd.to_numeric(big_df["美国GDP年增率"], errors="coerce")
+ big_df.sort_values("日期", inplace=True)
+ big_df.set_index(["日期"], inplace=True)
+ big_df = big_df[pd.to_datetime(start_date):]
+ big_df.reset_index(inplace=True)
+ big_df["日期"] = pd.to_datetime(big_df["日期"]).dt.date
+ return big_df.to_markdown()
+
  @classmethod
  def get_api_tool(cls, name: str, **kwargs: Any) -> BaseTool:
  attr_name = name.split('_', 1)[-1]
@@ -385,13 +554,15 @@ M0数量(单位:亿元),M0 同比(单位:%),M0 环比(单位


  if __name__ == '__main__':
- start_date = '2023-01-01'
- end_date = '2023-05-01'
+ tmp_start_date = '2024-01-01'
+ tmp_end_date = '2024-01-03'
  # start_date = ''
  # end_date = ''
  # print(MacroData.china_ppi(start_date=start_date, end_date=end_date))
  # print(MacroData.china_shrzgm(start_date=start_date, end_date=end_date))
  # print(MacroData.china_consumer_goods_retail(start_date=start_date, end_date=end_date))
  # print(MacroData.china_cpi(start_date=start_date, end_date=end_date))
+ # print(MacroData.china_pmi(start_date=start_date, end_date=end_date))
  # print(MacroData.china_money_supply(start_date=start_date, end_date=end_date))
- print(MacroData.china_gdp_yearly(start_date=start_date, end_date=end_date))
+ # print(MacroData.china_gdp_yearly(start_date=start_date, end_date=end_date))
+ print(MacroData.bond_zh_us_rate(start_date=tmp_start_date, end_date=tmp_end_date))
@@ -8,6 +8,7 @@ from datetime import datetime
  from typing import List, Type

  from langchain_core.pydantic_v1 import BaseModel, Field
+ from loguru import logger

  from .base import APIToolBase

@@ -139,9 +140,11 @@ class StockInfo(APIToolBase):
  ts = int(datetime.timestamp(date_obj) * 1000)
  stock = f'{stock_number}_240_{ts}'
  count = datetime.today() - date_obj
- self.url = self.url.format(stockName=stock_number, stock=stock, count=count.days)
-
- k_data = super().run('')
+ url = self.url.format(stockName=stock_number, stock=stock, count=count.days)
+ resp = self.client.get(url)
+ if resp.status_code != 200:
+ logger.info('api_call_fail res={}', resp.text)
+ k_data = resp.text
  data_array = json.loads(kLinePattern.search(k_data).group(1))
  for item in data_array:
  if item.get('day') == date:
@@ -151,7 +154,7 @@ class StockInfo(APIToolBase):
  resp = super().run(query=stock_number)
  stock = self.devideStock(resp)[0]
  if isinstance(stock, Stock):
- return json.dumps(stock.__dict__)
+ return json.dumps(stock.__dict__, ensure_ascii=False)
  else:
  return stock

@@ -168,9 +171,8 @@ class StockInfo(APIToolBase):
  ts = int(datetime.timestamp(date_obj) * 1000)
  stock = f'{stock_number}_240_{ts}'
  count = datetime.today() - date_obj
- self.url = self.url.format(stockName=stock_number, stock=stock, count=count.days)
- k_data = await super().arun('')
-
+ url = self.url.format(stockName=stock_number, stock=stock, count=count.days)
+ k_data = await self.async_client.aget(url)
  data_array = json.loads(kLinePattern.search(k_data).group(1))
  for item in data_array:
  if item.get('day') == date:
@@ -181,7 +183,7 @@ class StockInfo(APIToolBase):
  resp = await super().arun(query=stock_number)
  stock = self.devideStock(resp)[0]
  if isinstance(stock, Stock):
- return json.dumps(stock.__dict__)
+ return json.dumps(stock.__dict__, ensure_ascii=False)
  else:
  return stock

@@ -114,29 +114,7 @@ class CompanyInfo(APIToolBase):

  @classmethod
  def all_companys_by_company(cls, api_key: str, pageSize: int = 20, pageNum: int = 1):
- """可以通过公司名称获取企业人员的所有相关公司,包括其担任法人、股东、董监高的公司信息"""
- url = 'http://open.api.tianyancha.com/services/v4/open/allCompanys'
- input_key = 'name'
- params = {}
- params['pageSize'] = pageSize
- params['pageNum'] = pageNum
-
- class InputArgs(BaseModel):
- """args_schema"""
- query: str = Field(description='company name to query')
-
- return cls(url=url,
- api_key=api_key,
- params=params,
- input_key=input_key,
- args_schema=InputArgs)
-
- @classmethod
- def all_companys_by_humanname(cls,
- api_key: str,
- pageSize: int = 20,
- pageNum: int = 1) -> CompanyInfo:
- """可以通过人名获取企业人员的所有相关公司,包括其担任法人、股东、董监高的公司信息"""
+ """可以通过公司名称和人名获取企业人员的所有相关公司,包括其担任法人、股东、董监高的公司信息"""
  url = 'http://open.api.tianyancha.com/services/v4/open/allCompanys'
  input_key = 'humanName'
  params = {}
@@ -145,7 +123,8 @@ class CompanyInfo(APIToolBase):

  class InputArgs(BaseModel):
  """args_schema"""
- query: str = Field(description='human name to query')
+ query: str = Field(description='human who you want to search')
+ name: str = Field(description='company name which human worked')

  return cls(url=url,
  api_key=api_key,
@@ -1,6 +1,8 @@
+ import glob
  import itertools
  import os
  import pathlib
+ import re
  import subprocess
  import sys
  import tempfile
@@ -11,24 +13,18 @@ from pathlib import Path
  from typing import Dict, List, Optional, Tuple, Type
  from uuid import uuid4

- from autogen.code_utils import extract_code, infer_lang
+ import matplotlib
  from langchain_community.tools import Tool
  from langchain_core.pydantic_v1 import BaseModel, Field
  from loguru import logger

- try:
- from termcolor import colored
- except ImportError:
-
- def colored(x, *args, **kwargs):
- return x
-
-
+ CODE_BLOCK_PATTERN = r"```(\w*)\n(.*?)\n```"
  DEFAULT_TIMEOUT = 600
  WIN32 = sys.platform == 'win32'
  PATH_SEPARATOR = WIN32 and '\\' or '/'
  WORKING_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'extensions')
  TIMEOUT_MSG = 'Timeout'
+ UNKNOWN = "unknown"


  def _cmd(lang):
@@ -41,6 +37,61 @@ def _cmd(lang):
  raise NotImplementedError(f'{lang} not recognized in code execution')


+ def infer_lang(code):
+ """infer the language for the code.
+ TODO: make it robust.
+ """
+ if code.startswith("python ") or code.startswith("pip") or code.startswith("python3 "):
+ return "sh"
+
+ # check if code is a valid python code
+ try:
+ compile(code, "test", "exec")
+ return "python"
+ except SyntaxError:
+ # not a valid python code
+ return UNKNOWN
+
+
+ def extract_code(
+ text: str, pattern: str = CODE_BLOCK_PATTERN, detect_single_line_code: bool = False
+ ) -> List[Tuple[str, str]]:
+ """Extract code from a text.
+
+ Args:
+ text (str): The text to extract code from.
+ pattern (str, optional): The regular expression pattern for finding the
+ code block. Defaults to CODE_BLOCK_PATTERN.
+ detect_single_line_code (bool, optional): Enable the new feature for
+ extracting single line code. Defaults to False.
+
+ Returns:
+ list: A list of tuples, each containing the language and the code.
+ If there is no code block in the input text, the language would be "unknown".
+ If there is code block but the language is not specified, the language would be "".
+ """
+ if not detect_single_line_code:
+ match = re.findall(pattern, text, flags=re.DOTALL)
+ return match if match else [(UNKNOWN, text)]
+
+ # Extract both multi-line and single-line code block, separated by the | operator
+ # `{3}(\w+)?\s*([\s\S]*?)`{3}: Matches multi-line code blocks.
+ # The (\w+)? matches the language, where the ? indicates it is optional.
+ # `([^`]+)`: Matches inline code.
+ code_pattern = re.compile(r"`{3}(\w+)?\s*([\s\S]*?)`{3}|`([^`]+)`")
+ code_blocks = code_pattern.findall(text)
+
+ # Extract the individual code blocks and languages from the matched groups
+ extracted = []
+ for lang, group1, group2 in code_blocks:
+ if group1:
+ extracted.append((lang.strip(), group1.strip()))
+ elif group2:
+ extracted.append(("", group2.strip()))
+
+ return extracted
+
+
  def execute_code(
  code: Optional[str] = None,
  timeout: Optional[int] = None,
@@ -121,16 +172,66 @@ def head_file(path: str, n: int) -> List[str]:
  return []


- def upload_minio(param: dict, bucket: str, object_name: str, file_path, content_type='application/text'):
+ def upload_minio(
+ param: dict,
+ bucket: str,
+ object_name: str,
+ file_path,
+ content_type='application/text',
+ ):
  # 初始化minio
  import minio

- minio_client = minio.Minio(**param)
- logger.debug('upload_file obj={} bucket={} file_paht={}', object_name, bucket, file_path)
+ minio_client = minio.Minio(
+ endpoint=param.get('MINIO_ENDPOINT'),
+ access_key=param.get('MINIO_ACCESS_KEY'),
+ secret_key=param.get('MINIO_SECRET_KEY'),
+ secure=param.get('SCHEMA'),
+ cert_check=param.get('CERT_CHECK'),
+ )
+ minio_share = minio.Minio(
+ endpoint=param.get('MINIO_SHAREPOIN'),
+ access_key=param.get('MINIO_ACCESS_KEY'),
+ secret_key=param.get('MINIO_SECRET_KEY'),
+ secure=param.get('SCHEMA'),
+ cert_check=param.get('CERT_CHECK'),
+ )
+ logger.debug(
+ 'upload_file obj={} bucket={} file_paht={}',
+ object_name,
+ bucket,
+ file_path,
+ )
  minio_client.fput_object(
- bucket_name=bucket, object_name=object_name, file_path=file_path, content_type=content_type
+ bucket_name=bucket,
+ object_name=object_name,
+ file_path=file_path,
+ content_type=content_type,
+ )
+ return minio_share.presigned_get_object(
+ bucket_name=bucket,
+ object_name=object_name,
+ expires=timedelta(days=7),
  )
- return minio_client.presigned_get_object(bucket_name=bucket, object_name=object_name, expires=timedelta(days=7))
+
+
+ def insert_set_font_code(code: str) -> str:
+ """判断python代码中是否导入了matplotlib库,如果有则插入设置字体的代码"""
+
+ split_code = code.split('\n')
+ cache_file = matplotlib.get_cachedir()
+ font_cache = glob.glob(f'{cache_file}/fontlist*')
+
+ for cache in font_cache:
+ os.remove(cache)
+
+ # todo: 如果生成的代码中已经有了设置字体的代码,可能会导致该段代码失效
+ if 'matplotlib' in code:
+ pattern = re.compile(r'(import matplotlib|from matplotlib)')
+ index = max(i for i, line in enumerate(split_code) if pattern.search(line))
+ split_code.insert(index + 1, 'import matplotlib\nmatplotlib.rc("font", family="WenQuanYi Zen Hei")')
+
+ return '\n'.join(split_code)


  class CodeInterpreterToolArguments(BaseModel):
@@ -169,7 +270,7 @@ class FileInfo(BaseModel):
  class CodeInterpreterTool:
  """Tool for evaluating python code in native environment."""

- name = 'code_interpreter'
+ name = 'bisheng_code_interpreter'
  args_schema: Type[BaseModel] = CodeInterpreterToolArguments

  def __init__(
@@ -204,6 +305,7 @@ class CodeInterpreterTool:
  for i, code_block in enumerate(code_blocks):
  lang, code = code_block
  lang = infer_lang(code)
+ code = insert_set_font_code(code)
  temp_dir = tempfile.TemporaryDirectory()
  exitcode, logs, _ = execute_code(
  code,
@@ -215,7 +317,7 @@ class CodeInterpreterTool:
  return {'exitcode': exitcode, 'log': logs_all}

  # 获取文件
- temp_output_dir = Path(temp_dir.name) / 'output'
+ temp_output_dir = Path(temp_dir.name)
  for root, dirs, files in os.walk(temp_output_dir):
  for name in files:
  file_name = os.path.join(root, name)
@@ -236,26 +338,3 @@ class CodeInterpreterTool:
  description=self.description,
  args_schema=self.args_schema,
  )
-
-
- if __name__ == '__main__':
- code_string = """print('hha')"""
- code_blocks = extract_code(code_string)
- logger.info(code_blocks)
- logs_all = ''
- for i, code_block in enumerate(code_blocks):
- lang, code = code_block
- lang = infer_lang(code)
- print(
- colored(
- f'\n>>>>>>>> EXECUTING CODE BLOCK {i} (inferred language is {lang})...',
- 'red',
- ),
- flush=True,
- )
- exitcode, logs, image = execute_code(code, lang=lang)
- logs_all += '\n' + logs
- if exitcode != 0:
- logger.error(f'{exitcode}, {logs_all}')
-
- logger.info(logs_all)
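
The module now vendors `infer_lang` and `extract_code` instead of importing them from `autogen.code_utils`, and the old `__main__` demo is dropped. A quick usage sketch of the vendored helpers, assuming it runs in the same module where they are defined:

    text = "```python\nprint('hello')\n```"
    blocks = extract_code(text)              # [('python', "print('hello')")]
    for lang, code in blocks:
        print(infer_lang(code), code)        # prints: python print('hello')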
@@ -132,7 +132,7 @@ class ElasticKeywordsSearch(VectorStore, ABC):
  self.client.indices.delete(index=index_name)
  except elasticsearch.exceptions.NotFoundError:
  pass
-
+
  def add_texts(
  self,
  texts: Iterable[str],
@@ -195,6 +195,9 @@ class ElasticKeywordsSearch(VectorStore, ABC):
  query_strategy: str = 'match_phrase',
  must_or_should: str = 'should',
  **kwargs: Any) -> List[Document]:
+ if k == 0:
+ # pm need to control
+ return []
  docs_and_scores = self.similarity_search_with_score(query,
  k=k,
  query_strategy=query_strategy,
@@ -218,6 +221,9 @@ class ElasticKeywordsSearch(VectorStore, ABC):
  query_strategy: str = 'match_phrase',
  must_or_should: str = 'should',
  **kwargs: Any) -> List[Tuple[Document, float]]:
+ if k == 0:
+ # pm need to control
+ return []
  assert must_or_should in ['must', 'should'], 'only support must and should.'
  # llm or jiaba extract keywords
  if self.llm_chain:
@@ -288,10 +294,17 @@ class ElasticKeywordsSearch(VectorStore, ABC):
  index_name = index_name or uuid.uuid4().hex
  if llm:
  llm_chain = LLMChain(llm=llm, prompt=prompt)
- vectorsearch = cls(elasticsearch_url, index_name, llm_chain=llm_chain, drop_old=drop_old, **kwargs)
+ vectorsearch = cls(elasticsearch_url,
+ index_name,
+ llm_chain=llm_chain,
+ drop_old=drop_old,
+ **kwargs)
  else:
  vectorsearch = cls(elasticsearch_url, index_name, drop_old=drop_old, **kwargs)
- vectorsearch.add_texts(texts, metadatas=metadatas, ids=ids, refresh_indices=refresh_indices)
+ vectorsearch.add_texts(texts,
+ metadatas=metadatas,
+ ids=ids,
+ refresh_indices=refresh_indices)

  return vectorsearch

@@ -552,6 +552,9 @@ class Milvus(MilvusLangchain):
  Returns:
  List[Document]: Document results for search.
  """
+ if k == 0:
+ # pm need to control
+ return []
  if self.col is None:
  logger.debug('No existing collection to search.')
  return []
@@ -587,6 +590,9 @@ class Milvus(MilvusLangchain):
  Returns:
  List[Document]: Document results for search.
  """
+ if k == 0:
+ # pm need to control
+ return []
  if self.col is None:
  logger.debug('No existing collection to search.')
  return []
@@ -626,6 +632,9 @@ class Milvus(MilvusLangchain):
  Returns:
  List[float], List[Tuple[Document, any, any]]:
  """
+ if k == 0:
+ # pm need to control
+ return []
  if self.col is None:
  logger.debug('No existing collection to search.')
  return []
@@ -669,6 +678,9 @@ class Milvus(MilvusLangchain):
  Returns:
  List[Tuple[Document, float]]: Result doc and score.
  """
+ if k == 0:
+ # pm need to control
+ return []
  if self.col is None:
  logger.debug('No existing collection to search.')
  return []
@@ -741,6 +753,9 @@ class Milvus(MilvusLangchain):
  Returns:
  List[Document]: Document results for search.
  """
+ if k == 0:
+ # pm need to control
+ return []
  if self.col is None:
  logger.debug('No existing collection to search.')
  return []
@@ -790,6 +805,9 @@ class Milvus(MilvusLangchain):
  Returns:
  List[Document]: Document results for search.
  """
+ if k == 0:
+ # pm need to control
+ return []
  if self.col is None:
  logger.debug('No existing collection to search.')
  return []
@@ -908,7 +926,7 @@ class Milvus(MilvusLangchain):

  def _select_relevance_score_fn(self) -> Callable[[float], float]:
  return self._relevance_score_fn
-
+
  def query(self, expr: str, timeout: Optional[int] = None, **kwargs: Any) -> List[Document]:
  output_fields = self.fields[:]
  output_fields.remove(self._vector_field)
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: bisheng-langchain
- Version: 0.3.0b0
+ Version: 0.3.0rc1
  Summary: bisheng langchain modules
  Home-page: https://github.com/dataelement/bisheng
  Author: DataElem
@@ -8,10 +8,10 @@ bisheng_langchain/agents/chatglm_functions_agent/prompt.py,sha256=OiBTRUOhvhSyO2
  bisheng_langchain/agents/llm_functions_agent/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  bisheng_langchain/agents/llm_functions_agent/base.py,sha256=DbykNAk3vU2sfTPTSM2KotHygXgzAJSUmo4tA0h9ezc,12296
  bisheng_langchain/autogen_role/__init__.py,sha256=MnTGbAOK770JM9l95Qcxu93s2gNAmhlil7K9HdFG81o,430
- bisheng_langchain/autogen_role/assistant.py,sha256=VGCoxJaRxRG6ZIJa2TsxcLZbMbF4KC8PRB76DOuznNU,4736
- bisheng_langchain/autogen_role/custom.py,sha256=8xxtAzNF_N1fysyChynVD19t659Qvtcyj_LNiOrE7ew,2499
- bisheng_langchain/autogen_role/groupchat_manager.py,sha256=O9XIove5yzyF_g3K5DnF-Fasdx0sUrRWMogYgEDYJAI,2314
- bisheng_langchain/autogen_role/user.py,sha256=lISbJN5yFsUXHnDCUwr5t6R8O8K3dOMspH4l4_kITnE,5885
+ bisheng_langchain/autogen_role/assistant.py,sha256=rqUaD6fbW6d1jtzfrUQv5pJMKJgVGLagllz8LvzPCxY,4657
+ bisheng_langchain/autogen_role/custom.py,sha256=vAyEGxnmV9anyLL12v4ZB_A2VOPwdl-kjGP037I8jPw,2464
+ bisheng_langchain/autogen_role/groupchat_manager.py,sha256=AybsH3duoAFpo3bojOYVeSOE4iYkkbgmYIga6m2Jj_Y,2234
+ bisheng_langchain/autogen_role/user.py,sha256=fbaORhC7oQjxGhc2RYIWpELdIogFBsgqgQUhZsK6Osk,5715
  bisheng_langchain/chains/__init__.py,sha256=oxN2tUMt_kNxKd_FzCQ7x8xIwojtdCNNKo-DI7q0unM,759
  bisheng_langchain/chains/loader_output.py,sha256=02ZercAFaudStTZ4t7mcVkGRj5pD78HZ6NO8HbmbDH8,1903
  bisheng_langchain/chains/transform.py,sha256=G2fMqoMB62e03ES--aoVjEo06FzYWb87jCt3EOsiwwg,2805
@@ -28,7 +28,7 @@ bisheng_langchain/chains/router/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm
  bisheng_langchain/chains/router/multi_rule.py,sha256=BiFryj3-7rOxfttD-MyOkKWLCSGB9LVYd2rjOsIfQC8,375
  bisheng_langchain/chains/router/rule_router.py,sha256=R2YRUnwn7s_7DbsSn27uPn4cIV0D-5iXEORXir0tNGM,1835
  bisheng_langchain/chat_models/__init__.py,sha256=4-HTLE_SXO4hmNJu6yQxiQKBt2IFca_ezllVBLmvbEE,635
- bisheng_langchain/chat_models/host_llm.py,sha256=sBu_Vg-r7z6IJUV8Etwll4JTG5OvET-IXH7PZw8Ijrc,23152
+ bisheng_langchain/chat_models/host_llm.py,sha256=35_jTdUm85mk-t2MARZYGC8dIPVtf5XXlGfFE6hQ1Gc,23153
  bisheng_langchain/chat_models/minimax.py,sha256=JLs_f6vWD9beZYUtjD4FG28G8tZHrGUAWOwdLIuJomw,13901
  bisheng_langchain/chat_models/proxy_llm.py,sha256=wzVBZik9WC3-f7kyQ1eu3Ooibqpcocln08knf5lV1Nw,17082
  bisheng_langchain/chat_models/qwen.py,sha256=W73KxDRQBUZEzttEM4K7ZzPqbN-82O6YQmpX-HB_wZU,19971
@@ -66,9 +66,9 @@ bisheng_langchain/embeddings/interface/types.py,sha256=VdurbtsnjCPdlOjPFcK2Mg6r9
  bisheng_langchain/embeddings/interface/wenxin.py,sha256=5d9gI4enmfkD80s0FHKiDt33O0mwM8Xc5WTubnMUy8c,3104
  bisheng_langchain/gpts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  bisheng_langchain/gpts/assistant.py,sha256=KCYPU1Bs4GtWcLk9Ya2NuQyXE0Twn7-92eSBTIzpq7I,5083
- bisheng_langchain/gpts/auto_optimization.py,sha256=Vf3zzYEpVf916dYt4RV9E1uw4vTXjE7ZXogUIdxjHYU,3786
- bisheng_langchain/gpts/auto_tool_selected.py,sha256=25lFLadqQ36t63EKMEF3zJOG_jkoRB9IfP5eRkY1JZo,1777
- bisheng_langchain/gpts/load_tools.py,sha256=C7tlRLy4wAArr9qtkRl9dW6QXdspLLbcv0UvulW9A9U,6345
+ bisheng_langchain/gpts/auto_optimization.py,sha256=WNsC19rgvuDYQlSIaYThq5RqCbuobDbzCwAJW4Ksw0c,3626
+ bisheng_langchain/gpts/auto_tool_selected.py,sha256=21WETf9o0YS-QEBwv3mmZRObKWszefQkXEqAA6KzoaM,1582
+ bisheng_langchain/gpts/load_tools.py,sha256=JZpwTH5cvaLdab8-TbTxBGHug-llnCQR0wB4VsduSrs,7871
  bisheng_langchain/gpts/message_types.py,sha256=7EJOx62j9E1U67jxWgxE_I7a8IjAvvKANknXkD2gFm0,213
  bisheng_langchain/gpts/utils.py,sha256=t3YDxaJ0OYd6EKsek7PJFRYnsezwzEFK5oVU-PRbu5g,6671
  bisheng_langchain/gpts/agent_types/__init__.py,sha256=bg0zlTYGfNXoSBqcICHlzNpVQbejMYeyji_dzvP5qQ0,261
@@ -77,21 +77,21 @@ bisheng_langchain/gpts/prompts/__init__.py,sha256=IfuoxVpsSLKJtDx0aJbRgnSZYZr_kD
  bisheng_langchain/gpts/prompts/assistant_prompt_opt.py,sha256=TZsRK4XPMrUhGg0PoMyiE3wE-aG34UmlVflkCl_c0QI,4151
  bisheng_langchain/gpts/prompts/base_prompt.py,sha256=v2eO0c6RF8e6MtGdleHs5B4YTkikg6IZUuBvL2zvyOI,55
  bisheng_langchain/gpts/prompts/breif_description_prompt.py,sha256=w4A5et0jB-GkxEMQBp4i6GKX3RkVeu7NzWEjOZZAicM,5336
- bisheng_langchain/gpts/prompts/opening_dialog_prompt.py,sha256=U6SDslWuXAB1ZamLZVujpEjAY8L244IZfD2qFVRTzPM,5962
+ bisheng_langchain/gpts/prompts/opening_dialog_prompt.py,sha256=VVF0JLHtetupVB0kabiFHWDHlQaa4nFLcbYXgIBA3nw,5965
  bisheng_langchain/gpts/prompts/select_tools_prompt.py,sha256=AyvVnrLEsQy7RHuGTPkcrMUxgA98Q0TzF-xweoc7GyY,1400
  bisheng_langchain/gpts/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- bisheng_langchain/gpts/tools/api_tools/__init__.py,sha256=mrmTV5bT5R1mEx9hbMAWKzNAAC4EL6biNn53dx5lYsc,1593
- bisheng_langchain/gpts/tools/api_tools/base.py,sha256=TF5MW0e62YvcfABp_-U32ESMKvN9CXPFKqiCeaZ3xFk,3458
- bisheng_langchain/gpts/tools/api_tools/flow.py,sha256=u1_ASWlCcZarKR-293kACB_qQ1RzJuzPC3YZSl2JR-E,1814
- bisheng_langchain/gpts/tools/api_tools/macro_data.py,sha256=rlFNhjJ3HEHfWeW9Wqb27eeF1Q1Qmd2SA8VfgUK4ACs,19270
- bisheng_langchain/gpts/tools/api_tools/sina.py,sha256=A8YDLko3lptBxkGN2_e42GWfcrEgwJAZyp2wouC1Qvg,9340
- bisheng_langchain/gpts/tools/api_tools/tianyancha.py,sha256=sQbjPt8K0dLupFprWwc_z938DBB8nB7ydyIV5frWSJ0,7461
+ bisheng_langchain/gpts/tools/api_tools/__init__.py,sha256=CkEjgIFM4GIv86V1B7SsFLaB6M86c54QuO8wIRizUZ8,1608
+ bisheng_langchain/gpts/tools/api_tools/base.py,sha256=fWQSDIOVb4JZrtJ9ML9q2ycsAa-_61gXTD0MT19J1LM,3618
+ bisheng_langchain/gpts/tools/api_tools/flow.py,sha256=rHCRpaafriQomMaOqSeKjPXwVUO_nAsFDNRIjOofbuI,2486
+ bisheng_langchain/gpts/tools/api_tools/macro_data.py,sha256=FyG-qtl2ECS1CDKt6olN0eDTDM91d-UvDkMDBiVLgYQ,27429
+ bisheng_langchain/gpts/tools/api_tools/sina.py,sha256=GGA4ZYvNEpqBZ_l8MUYqgkI8xZe9XcGa9-KlHZVqr6I,9542
+ bisheng_langchain/gpts/tools/api_tools/tianyancha.py,sha256=abDAz-yAH1-2rKiSmZ6TgnrNUnpgAZpDY8oDiWfWapc,6684
  bisheng_langchain/gpts/tools/bing_search/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  bisheng_langchain/gpts/tools/bing_search/tool.py,sha256=v_VlqcMplITA5go5qWA4qZ5p43E1-1s0bzmyY7H0hqY,1710
  bisheng_langchain/gpts/tools/calculator/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  bisheng_langchain/gpts/tools/calculator/tool.py,sha256=iwGPE7jvxZg_jUL2Aq9HHwnRJrF9-ongwrsBX6uk1U0,705
  bisheng_langchain/gpts/tools/code_interpreter/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- bisheng_langchain/gpts/tools/code_interpreter/tool.py,sha256=PGipxd-qtW31GonRGfGow7nylI-osSnmBsvEJDlMUCE,8717
+ bisheng_langchain/gpts/tools/code_interpreter/tool.py,sha256=1VLkgngRR0k8YjA4eYkfPd1E7fD29tMKpqtCtn7WwYE,11443
  bisheng_langchain/gpts/tools/dalle_image_generator/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  bisheng_langchain/gpts/tools/dalle_image_generator/tool.py,sha256=mhxdNNhBESjbOy30Rnp6hQhnrV4evQpv-B1fFXcU-68,7528
  bisheng_langchain/gpts/tools/get_current_time/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -105,10 +105,10 @@ bisheng_langchain/retrievers/mix_es_vector.py,sha256=dSrrsuMPSgGiu181EOzACyIKiDX
  bisheng_langchain/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  bisheng_langchain/utils/requests.py,sha256=vWGKyNTxApVeaVdKxqACfIT1Q8wMy-jC3kUv2Ce9Mzc,8688
  bisheng_langchain/vectorstores/__init__.py,sha256=zCZgDe7LyQ0iDkfcm5UJ5NxwKQSRHnqrsjx700Fy11M,213
- bisheng_langchain/vectorstores/elastic_keywords_search.py,sha256=ACUzgeTwzVOVrm0EqBXF_VhzwrWZJbKYQgqNSW5VhbQ,12929
- bisheng_langchain/vectorstores/milvus.py,sha256=hk1XqmWoz04lltubzRcZHEcXXFMkxMeK84hH0GZoo1c,35857
+ bisheng_langchain/vectorstores/elastic_keywords_search.py,sha256=JV_GM40cYx0PtPPvH2JYxtsMV0psSW2CDKagpR4M_0o,13286
+ bisheng_langchain/vectorstores/milvus.py,sha256=lrnezKnYXhyH5M1g3a-Mcwpj9mwzAj44TKmzyUXlQYY,36297
  bisheng_langchain/vectorstores/retriever.py,sha256=hj4nAAl352EV_ANnU2OHJn7omCH3nBK82ydo14KqMH4,4353
- bisheng_langchain-0.3.0b0.dist-info/METADATA,sha256=ib2MCOn7ntlsILSxun7xakoaW3K53UmhWt750yiSZGg,2413
- bisheng_langchain-0.3.0b0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
- bisheng_langchain-0.3.0b0.dist-info/top_level.txt,sha256=Z6pPNyCo4ihyr9iqGQbH8sJiC4dAUwA_mAyGRQB5_Fs,18
- bisheng_langchain-0.3.0b0.dist-info/RECORD,,
+ bisheng_langchain-0.3.0rc1.dist-info/METADATA,sha256=vHWUJcrt2hO4QpW5o0Al8bn23d6c3zpm9yf_2NWGEmE,2414
+ bisheng_langchain-0.3.0rc1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ bisheng_langchain-0.3.0rc1.dist-info/top_level.txt,sha256=Z6pPNyCo4ihyr9iqGQbH8sJiC4dAUwA_mAyGRQB5_Fs,18
+ bisheng_langchain-0.3.0rc1.dist-info/RECORD,,