bisheng-langchain 0.2.3.2__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bisheng_langchain/agents/llm_functions_agent/base.py +1 -1
- bisheng_langchain/chains/__init__.py +2 -1
- bisheng_langchain/chains/transform.py +85 -0
- bisheng_langchain/chat_models/host_llm.py +19 -5
- bisheng_langchain/chat_models/qwen.py +29 -8
- bisheng_langchain/document_loaders/custom_kv.py +1 -1
- bisheng_langchain/embeddings/host_embedding.py +9 -11
- bisheng_langchain/gpts/__init__.py +0 -0
- bisheng_langchain/gpts/agent_types/__init__.py +10 -0
- bisheng_langchain/gpts/agent_types/llm_functions_agent.py +220 -0
- bisheng_langchain/gpts/assistant.py +137 -0
- bisheng_langchain/gpts/auto_optimization.py +130 -0
- bisheng_langchain/gpts/auto_tool_selected.py +54 -0
- bisheng_langchain/gpts/load_tools.py +161 -0
- bisheng_langchain/gpts/message_types.py +11 -0
- bisheng_langchain/gpts/prompts/__init__.py +15 -0
- bisheng_langchain/gpts/prompts/assistant_prompt_opt.py +95 -0
- bisheng_langchain/gpts/prompts/base_prompt.py +1 -0
- bisheng_langchain/gpts/prompts/breif_description_prompt.py +104 -0
- bisheng_langchain/gpts/prompts/opening_dialog_prompt.py +118 -0
- bisheng_langchain/gpts/prompts/select_tools_prompt.py +29 -0
- bisheng_langchain/gpts/tools/__init__.py +0 -0
- bisheng_langchain/gpts/tools/api_tools/__init__.py +50 -0
- bisheng_langchain/gpts/tools/api_tools/base.py +90 -0
- bisheng_langchain/gpts/tools/api_tools/flow.py +59 -0
- bisheng_langchain/gpts/tools/api_tools/macro_data.py +397 -0
- bisheng_langchain/gpts/tools/api_tools/sina.py +221 -0
- bisheng_langchain/gpts/tools/api_tools/tianyancha.py +160 -0
- bisheng_langchain/gpts/tools/bing_search/__init__.py +0 -0
- bisheng_langchain/gpts/tools/bing_search/tool.py +55 -0
- bisheng_langchain/gpts/tools/calculator/__init__.py +0 -0
- bisheng_langchain/gpts/tools/calculator/tool.py +25 -0
- bisheng_langchain/gpts/tools/code_interpreter/__init__.py +0 -0
- bisheng_langchain/gpts/tools/code_interpreter/tool.py +261 -0
- bisheng_langchain/gpts/tools/dalle_image_generator/__init__.py +0 -0
- bisheng_langchain/gpts/tools/dalle_image_generator/tool.py +181 -0
- bisheng_langchain/gpts/tools/get_current_time/__init__.py +0 -0
- bisheng_langchain/gpts/tools/get_current_time/tool.py +23 -0
- bisheng_langchain/gpts/utils.py +197 -0
- bisheng_langchain/utils/requests.py +5 -1
- bisheng_langchain/vectorstores/milvus.py +1 -1
- {bisheng_langchain-0.2.3.2.dist-info → bisheng_langchain-0.3.0.dist-info}/METADATA +5 -2
- {bisheng_langchain-0.2.3.2.dist-info → bisheng_langchain-0.3.0.dist-info}/RECORD +45 -12
- {bisheng_langchain-0.2.3.2.dist-info → bisheng_langchain-0.3.0.dist-info}/WHEEL +0 -0
- {bisheng_langchain-0.2.3.2.dist-info → bisheng_langchain-0.3.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,160 @@
|
|
1
|
+
"""tianyancha api"""
|
2
|
+
from __future__ import annotations
|
3
|
+
|
4
|
+
from typing import Any, Dict, Type
|
5
|
+
|
6
|
+
from bisheng_langchain.utils.requests import Requests, RequestsWrapper
|
7
|
+
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
8
|
+
|
9
|
+
from .base import APIToolBase
|
10
|
+
|
11
|
+
|
12
|
+
class InputArgs(BaseModel):
    """args_schema"""

    # Required search keyword. Per its description it may be a company name,
    # a company id, a registration number or a unified social credit code.
    query: str = Field(..., description='搜索关键字(公司名称、公司id、注册号或社会统一信用代码)')
|
15
|
+
|
16
|
+
|
17
|
+
class CompanyInfo(APIToolBase):
    """Tianyancha open-API tool; each factory classmethod binds one endpoint."""

    # NOTE(review): the Chinese docstrings on the factory classmethods are kept
    # byte-for-byte. They read like LLM-facing tool descriptions and may be
    # consumed through ``__doc__`` elsewhere in the package -- confirm before
    # translating them. English summaries are given as comments instead.
    # NOTE(review): every endpoint below is plain ``http://`` while the API key
    # is sent in the Authorization header -- confirm whether https is accepted.

    # API key sent as the ``Authorization`` header; required (see build_header).
    api_key: str = None
    # Default argument schema: a single ``query`` string.
    args_schema: Type[BaseModel] = InputArgs

    @root_validator(pre=True)
    def build_header(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Inject the API key into the request headers before field validation.

        Args:
            values: Raw constructor keyword arguments.

        Returns:
            ``values`` with ``headers`` replaced by a new dict that carries the
            ``Authorization`` entry.

        Raises:
            ValueError: If ``api_key`` is missing or empty.
        """
        if not values.get('api_key'):
            raise ValueError('Parameters api_key should be specified give.')

        # Work on a copy so a dict supplied by the caller is never mutated;
        # ``or {}`` also tolerates an explicit ``headers=None``.
        headers = dict(values.get('headers') or {})
        headers['Authorization'] = values['api_key']
        values['headers'] = headers
        return values

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Create the sync and async HTTP clients when none were supplied.

        Both clients are built from the validated ``headers`` and use
        ``request_timeout`` (30 when the key is absent from ``values``).
        """
        timeout = values.get('request_timeout', 30)
        if not values.get('client'):
            values['client'] = Requests(headers=values['headers'], request_timeout=timeout)
        if not values.get('async_client'):
            values['async_client'] = RequestsWrapper(headers=values['headers'],
                                                     request_timeout=timeout)
        return values

    # Search companies by keyword; results include name or ID, type, founding
    # date, operating status and unified social credit code.
    @classmethod
    def search_company(cls, api_key: str, pageNum: int = 1, pageSize: int = 20) -> CompanyInfo:
        """可以通过关键词获取企业列表,企业列表包括公司名称或ID、类型、成立日期、经营状态、统一社会信用代码等字段的详细信息"""
        url = 'http://open.api.tianyancha.com/services/open/search/2.0'
        input_key = 'word'
        params = {'pageSize': pageSize, 'pageNum': pageNum}

        return cls(url=url, api_key=api_key, input_key=input_key, params=params)

    # Basic company information by company name or ID (type, founding date,
    # status, registered capital, legal representative, registration numbers).
    @classmethod
    def get_company_baseinfo(cls, api_key: str) -> CompanyInfo:
        """可以通过公司名称或ID获取企业基本信息,企业基本信息包括公司名称或ID、类型、成立日期、经营状态、注册资本、法人、工商注册号、统一社会信用代码、组织机构代码、纳税人识别号等字段信息"""
        url = 'http://open.api.tianyancha.com/services/open/ic/baseinfo/normal'
        input_key = 'keyword'
        params = {}

        return cls(url=url, api_key=api_key, input_key=input_key, params=params)

    # Intellectual-property data by company name or ID: trademarks, patents,
    # work copyrights, software copyrights, website filings.
    @classmethod
    def ip_rights(cls, api_key: str) -> CompanyInfo:
        """可以通过公司名称或ID获取包含商标、专利、作品著作权、软件著作权、网站备案等维度的相关信息"""
        url = 'http://open.api.tianyancha.com/services/open/cb/ipr/3.0'
        input_key = 'keyword'

        return cls(url=url, api_key=api_key, input_key=input_key, params={})

    # Judicial-risk data by company name or ID: lawsuits, court and hearing
    # announcements, dishonest/enforced parties, case filings, service notices.
    # input_key and params are left to the APIToolBase defaults here.
    @classmethod
    def judicial_risk(cls, api_key: str) -> CompanyInfo:
        """可以通过公司名称或ID获取包含法律诉讼、法院公告、开庭公告、失信人、被执行人、立案信息、送达公告等维度的相关信息"""
        url = 'http://open.api.tianyancha.com/services/open/cb/judicial/2.0'
        return cls(url=url, api_key=api_key)

    # Business-registration data by company name or ID: basic info, key
    # personnel, shareholders, outbound investments, branches.
    @classmethod
    def ic_info(cls, api_key: str) -> CompanyInfo:
        """可以通过公司名称或ID获取包含企业基本信息、主要人员、股东信息、对外投资、分支机构等维度的相关信息"""
        url = 'http://open.api.tianyancha.com/services/open/cb/ic/2.0'

        return cls(url=url, api_key=api_key)

    # Lawsuit records by company name or ID: case name, cause of action,
    # role in the case, case number.
    @classmethod
    def law_suit_case(cls, api_key: str, pageSize: int = 20, pageNum: int = 1) -> CompanyInfo:
        """可以通过公司名称或ID获取企业法律诉讼信息,法律诉讼包括案件名称、案由、案件身份、案号等字段的详细信息"""
        url = 'http://open.api.tianyancha.com/services/open/jr/lawSuit/3.0'
        params = {'pageSize': pageSize, 'pageNum': pageNum}
        return cls(url=url, api_key=api_key, params=params)

    # Change records by company name or ID: registration change items and the
    # before/after values.
    @classmethod
    def company_change_info(cls,
                            api_key: str,
                            pageSize: int = 20,
                            pageNum: int = 1) -> CompanyInfo:
        """可以通过公司名称或ID获取企业变更记录,变更记录包括工商变更事项、变更前后信息等字段的详细信息"""
        url = 'http://open.api.tianyancha.com/services/open/ic/changeinfo/2.0'
        params = {'pageSize': pageSize, 'pageNum': pageNum}
        return cls(url=url, api_key=api_key, params=params)

    # Shareholder data by company name or ID: shareholder name, contribution
    # ratio and amount, total number of shareholders.
    @classmethod
    def company_holders(cls, api_key: str, pageSize: int = 20, pageNum: int = 1) -> CompanyInfo:
        """可以通过公司名称或ID获取企业股东信息,股东信息包括股东名、出资比例、出资金额、股东总数等字段的详细信息"""
        url = 'http://open.api.tianyancha.com/services/open/ic/holder/2.0'
        params = {'pageSize': pageSize, 'pageNum': pageNum}
        return cls(url=url, api_key=api_key, params=params)

    # All companies related to a company's personnel (as legal representative,
    # shareholder, director/supervisor/executive), looked up by company name.
    @classmethod
    def all_companys_by_company(cls, api_key: str, pageSize: int = 20,
                                pageNum: int = 1) -> CompanyInfo:
        """可以通过公司名称获取企业人员的所有相关公司,包括其担任法人、股东、董监高的公司信息"""
        url = 'http://open.api.tianyancha.com/services/v4/open/allCompanys'
        input_key = 'name'
        params = {'pageSize': pageSize, 'pageNum': pageNum}

        # Local schema (shadows the module-level InputArgs) so the argument
        # description matches this endpoint.
        class InputArgs(BaseModel):
            """args_schema"""
            query: str = Field(description='company name to query')

        return cls(url=url,
                   api_key=api_key,
                   params=params,
                   input_key=input_key,
                   args_schema=InputArgs)

    # Same endpoint as all_companys_by_company, but looked up by person name.
    @classmethod
    def all_companys_by_humanname(cls,
                                  api_key: str,
                                  pageSize: int = 20,
                                  pageNum: int = 1) -> CompanyInfo:
        """可以通过人名获取企业人员的所有相关公司,包括其担任法人、股东、董监高的公司信息"""
        url = 'http://open.api.tianyancha.com/services/v4/open/allCompanys'
        input_key = 'humanName'
        params = {'pageSize': pageSize, 'pageNum': pageNum}

        # Local schema (shadows the module-level InputArgs) so the argument
        # description matches this endpoint.
        class InputArgs(BaseModel):
            """args_schema"""
            query: str = Field(description='human name to query')

        return cls(url=url,
                   api_key=api_key,
                   params=params,
                   input_key=input_key,
                   args_schema=InputArgs)

    # Tianyan risk list by keyword (company name, id, registration number or
    # unified social credit code): own, surrounding and early-warning risks.
    @classmethod
    def riskinfo(cls, api_key: str) -> CompanyInfo:
        """可以通过关键字(公司名称、公司id、注册号或社会统一信用代码)获取企业相关天眼风险列表,包括企业自身/周边/预警风险信息。"""
        url = 'http://open.api.tianyancha.com/services/open/risk/riskInfo/2.0'
        return cls(url=url, api_key=api_key)
|
File without changes
|
@@ -0,0 +1,55 @@
|
|
1
|
+
"""Tool for the Bing search API."""
|
2
|
+
|
3
|
+
from typing import Optional, Type
|
4
|
+
|
5
|
+
from langchain.pydantic_v1 import BaseModel, Field
|
6
|
+
from langchain_community.utilities.bing_search import BingSearchAPIWrapper
|
7
|
+
from langchain_core.callbacks import CallbackManagerForToolRun
|
8
|
+
from langchain_core.tools import BaseTool
|
9
|
+
|
10
|
+
|
11
|
+
class BingSearchInput(BaseModel):
    # Argument schema used by the Bing tools in this module: one required
    # free-text search query.
    query: str = Field(..., description="query to look up in Bing search")
|
13
|
+
|
14
|
+
|
15
|
+
class BingSearchRun(BaseTool):
    """Tool that sends a query to the Bing search API and returns the wrapper's text result."""

    name: str = "bing_search"
    # The three sentences are joined with single spaces, giving exactly the
    # same text as the original adjacent string literals.
    description: str = " ".join((
        "A wrapper around Bing Search.",
        "Useful for when you need to answer questions about current events.",
        "Input should be a search query.",
    ))
    args_schema: Type[BaseModel] = BingSearchInput
    # Wrapper that performs the Bing request; must be supplied by the caller.
    api_wrapper: BingSearchAPIWrapper

    def _run(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> str:
        """Run ``query`` through the API wrapper and return its result.

        ``run_manager`` is accepted for the tool interface but is not used.
        """
        answer = self.api_wrapper.run(query)
        return answer
|
34
|
+
|
35
|
+
|
36
|
+
class BingSearchResults(BaseTool):
    """Tool that queries the Bing Search API and returns the result list as a string."""

    name: str = "bing_search_results_json"
    description: str = (
        "A wrapper around Bing Search. "
        "Useful for when you need to answer questions about current events. "
        "Input should be a search query. Output is a JSON array of the query results"
    )
    # Maximum number of results requested from the wrapper per query.
    num_results: int = 4
    # Annotated like BingSearchRun.args_schema. Without the annotation the
    # class-valued attribute is presumably not picked up as a field override by
    # pydantic, so the schema may be ignored -- NOTE(review): confirm.
    args_schema: Type[BaseModel] = BingSearchInput
    # Wrapper that performs the Bing request; must be supplied by the caller.
    api_wrapper: BingSearchAPIWrapper

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Return ``str()`` of the wrapper's results for ``query``.

        NOTE(review): the description advertises JSON, but ``str()`` yields the
        Python repr of the result object; kept as-is to preserve the output
        format callers already receive.
        ``run_manager`` is accepted for the tool interface but is not used.
        """
        return str(self.api_wrapper.results(query, self.num_results))
|
File without changes
|
@@ -0,0 +1,25 @@
|
|
1
|
+
import math
|
2
|
+
from math import *
|
3
|
+
|
4
|
+
import sympy
|
5
|
+
from langchain.pydantic_v1 import BaseModel, Field
|
6
|
+
from langchain.tools import tool
|
7
|
+
from sympy import *
|
8
|
+
|
9
|
+
|
10
|
+
class CalculatorInput(BaseModel):
    # Argument schema for the calculator tool: one required expression string.
    expression: str = Field(
        ...,
        description="The input to this tool should be a mathematical expression using only Python's built-in mathematical operators.",
        example='200*7',
    )
|
15
|
+
|
16
|
+
|
17
|
+
@tool("calculator", args_schema=CalculatorInput)
def calculator(expression):
    """Useful to perform any mathematical calculations,
    like sum, minus, multiplication, division, etc
    """
    # The docstring above is kept verbatim: the ``@tool`` decorator uses it as
    # the tool description shown to the model.
    #
    # SECURITY NOTE(review): ``eval`` executes arbitrary Python with this
    # module's globals (everything star-imported from ``math`` and ``sympy``
    # plus builtins). The expression comes from model output, i.e. untrusted
    # input. It is deliberately left in place so existing expressions keep
    # working, but it should be replaced by a restricted evaluator.
    try:
        return eval(expression)
    except SyntaxError:
        return "Error: Invalid syntax in mathematical expression"
    except (ArithmeticError, NameError, TypeError, ValueError) as exc:
        # Report ordinary evaluation failures (division by zero, unknown name,
        # bad operand types, ...) as text, like the syntax error above,
        # instead of letting them abort the tool call.
        return f"Error: {type(exc).__name__}: {exc}"
|
File without changes
|
@@ -0,0 +1,261 @@
|
|
1
|
+
import itertools
|
2
|
+
import os
|
3
|
+
import pathlib
|
4
|
+
import subprocess
|
5
|
+
import sys
|
6
|
+
import tempfile
|
7
|
+
from concurrent.futures import ThreadPoolExecutor, TimeoutError
|
8
|
+
from datetime import timedelta
|
9
|
+
from hashlib import md5
|
10
|
+
from pathlib import Path
|
11
|
+
from typing import Dict, List, Optional, Tuple, Type
|
12
|
+
from uuid import uuid4
|
13
|
+
|
14
|
+
from autogen.code_utils import extract_code, infer_lang
|
15
|
+
from langchain_community.tools import Tool
|
16
|
+
from langchain_core.pydantic_v1 import BaseModel, Field
|
17
|
+
from loguru import logger
|
18
|
+
|
19
|
+
try:
|
20
|
+
from termcolor import colored
|
21
|
+
except ImportError:
|
22
|
+
|
23
|
+
def colored(x, *args, **kwargs):
|
24
|
+
return x
|
25
|
+
|
26
|
+
|
27
|
+
# Timeout value used by execute_code when the caller passes none.
DEFAULT_TIMEOUT = 600
# Text returned in the log slot when execution times out.
TIMEOUT_MSG = 'Timeout'
# True when running on Windows.
WIN32 = sys.platform == 'win32'
# Separator appended to the work-dir prefix that is stripped from error logs.
PATH_SEPARATOR = '\\' if WIN32 else '/'
# Default directory for generated scripts: "extensions" next to this module.
WORKING_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'extensions')
|
32
|
+
|
33
|
+
|
34
|
+
def _cmd(lang):
    """Return the interpreter command name for a language tag.

    Tags starting with ``python`` and the tags ``bash``, ``sh`` and
    ``powershell`` are returned unchanged; ``shell`` maps to ``sh`` and
    ``ps1`` maps to ``powershell``.

    Raises:
        NotImplementedError: For any other tag.
    """
    passthrough = ('bash', 'sh', 'powershell')
    aliases = {'shell': 'sh', 'ps1': 'powershell'}

    if lang.startswith('python') or lang in passthrough:
        return lang
    command = aliases.get(lang)
    if command is None:
        raise NotImplementedError(f'{lang} not recognized in code execution')
    return command
|
42
|
+
|
43
|
+
|
44
|
+
def execute_code(
    code: Optional[str] = None,
    timeout: Optional[int] = None,
    filename: Optional[str] = None,
    work_dir: Optional[str] = None,
    lang: Optional[str] = 'python',
) -> Tuple[int, str, str]:
    """Run a script in a subprocess and return its exit code and output.

    Args:
        code: Source text to run. When given it is written to ``filename``
            (or to a generated ``tmp_code_<md5>`` file) inside ``work_dir``.
        timeout: Seconds to wait for the process; ``DEFAULT_TIMEOUT`` when
            falsy.
        filename: Script path relative to ``work_dir``. When omitted, a
            temporary file is generated and removed after the run.
        work_dir: Directory used as the process cwd; ``WORKING_DIR`` when
            omitted. An ``output`` sub-directory is created next to the script.
        lang: Language tag; ``python*`` runs with ``sys.executable``, other
            tags are resolved by ``_cmd``. ``None`` is treated as ``'python'``.

    Returns:
        ``(returncode, logs, None)``. ``logs`` is stdout on success and stderr
        (with the script/work-dir path stripped) on failure. On timeout the
        result is ``(1, TIMEOUT_MSG, None)``.

    Raises:
        AssertionError: If neither ``code`` nor ``filename`` is provided.
        NotImplementedError: If ``lang`` is not a recognised language tag.
    """
    if code is None and filename is None:
        error_msg = f'Either {code=} or {filename=} must be provided.'
        logger.error(error_msg)
        raise AssertionError(error_msg)

    timeout = timeout or DEFAULT_TIMEOUT
    lang = lang or 'python'
    original_filename = filename

    if filename is None:
        code_hash = md5(code.encode()).hexdigest()
        # create a file with an automatically generated name
        filename = f"tmp_code_{code_hash}.{'py' if lang.startswith('python') else lang}"
    if work_dir is None:
        work_dir = WORKING_DIR
    filepath = os.path.join(work_dir, filename)
    file_dir = os.path.dirname(filepath)
    os.makedirs(file_dir, exist_ok=True)
    (Path(file_dir) / 'output').mkdir(exist_ok=True, parents=True)
    if code is not None:
        with open(filepath, 'w', encoding='utf-8') as fout:
            fout.write(code)

    try:
        cmd = [
            sys.executable if lang.startswith('python') else _cmd(lang),
            f'.\\{filename}' if WIN32 else filename,
        ]
        # subprocess enforces the timeout itself: it kills the child and raises
        # TimeoutExpired. The previous thread-pool approach kept waiting for the
        # child when leaving the executor's ``with`` block and never killed it,
        # and on Windows no timeout was applied at all.
        result = subprocess.run(
            cmd,
            cwd=work_dir,
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return 1, TIMEOUT_MSG, None
    finally:
        # Remove the generated script on every exit path, including failures
        # to resolve or launch the interpreter.
        if original_filename is None:
            os.remove(filepath)

    if result.returncode:
        logs = result.stderr
        if original_filename is None:
            # Hide the generated file name from tracebacks.
            abs_path = str(pathlib.Path(filepath).absolute())
            logs = logs.replace(str(abs_path), '').replace(filename, '')
        else:
            # Make paths in tracebacks relative to the working directory.
            abs_path = str(pathlib.Path(work_dir).absolute()) + PATH_SEPARATOR
            logs = logs.replace(str(abs_path), '')
    else:
        logs = result.stdout
    return result.returncode, logs, None
|
113
|
+
|
114
|
+
|
115
|
+
def head_file(path: str, n: int) -> List[str]:
    """Return at most the first ``n`` lines of the file at ``path``.

    Lines keep their trailing newline. Any error while opening or reading the
    file (or an invalid ``n``) yields an empty list instead of raising.
    """
    try:
        with open(path, 'r') as handle:
            first_lines = itertools.islice(handle, n)
            return list(map(str, first_lines))
    except Exception:
        return []
|
122
|
+
|
123
|
+
|
124
|
+
def upload_minio(param: dict, bucket: str, object_name: str, file_path, content_type='application/text'):
    """Upload a local file to MinIO and return a presigned download URL.

    Args:
        param: Keyword arguments passed straight to ``minio.Minio``.
        bucket: Target bucket name.
        object_name: Object key to store the file under.
        file_path: Path of the local file to upload.
        content_type: Content type recorded for the object.

    Returns:
        A presigned GET URL for the uploaded object that expires after 7 days.
    """
    # Imported here, as in the original, so the module can be loaded without
    # the minio package being imported up front.
    import minio

    client = minio.Minio(**param)
    logger.debug('upload_file obj={} bucket={} file_paht={}', object_name, bucket, file_path)
    client.fput_object(
        bucket_name=bucket,
        object_name=object_name,
        file_path=file_path,
        content_type=content_type,
    )
    link_lifetime = timedelta(days=7)
    return client.presigned_get_object(
        bucket_name=bucket,
        object_name=object_name,
        expires=link_lifetime,
    )
|
134
|
+
|
135
|
+
|
136
|
+
class CodeInterpreterToolArguments(BaseModel):
    """Arguments accepted by the code interpreter tool."""

    # Required: the Python source to execute, as plain text (not markdown).
    python_code: str = Field(
        ...,
        description=(
            'The pure python script to be evaluated. '
            'The contents will be in main.py. '
            'It should not be in markdown format.'
        ),
        example="print('Hello World')",
    )
|
148
|
+
|
149
|
+
|
150
|
+
# Prompt text describing the code interpreter tool to the model. The backslash
# line continuations join everything into a single line. The sentence ending in
# `output/` now closes with ". " so it no longer runs straight into "print()".
base_description = """Evaluates python code in native environment. \
You must send the whole script every time and print your outputs. \
Script should be pure python code that can be evaluated. \
It should be in python format NOT markdown. \
The code should NOT be wrapped in backticks. \
If you have any files outputted write them to "output/" relative to the execution \
path. Output can only be read from the directory, stdout, and stdin. \
Do not use things like plot.show() as it will \
not work instead write them out `output/`. \
print() any output and results so you can capture the output."""  # noqa: T201
|
160
|
+
|
161
|
+
|
162
|
+
class FileInfo(BaseModel):
    """Describes one file made available to the code interpreter."""

    # Path of the file; its first lines are previewed in the tool description.
    source_path: str
    # Free-text explanation of the file, shown alongside the preview.
    description: str
|
167
|
+
|
168
|
+
|
169
|
+
class CodeInterpreterTool:
    """Tool for evaluating python code in native environment."""

    name = 'code_interpreter'
    args_schema: Type[BaseModel] = CodeInterpreterToolArguments

    def __init__(
        self,
        minio: Dict[str, object],
        files: Optional[Dict[str, FileInfo]] = None,
    ) -> None:
        """Store the MinIO settings and the optional input-file catalogue.

        Args:
            minio: Keyword arguments for the MinIO client. When falsy, produced
                files are returned as local paths instead of being uploaded.
            files: Mapping of files available to the executed code.
        """
        self.minio = minio
        self.files = files if files else {}

    @property
    def file_description(self) -> str:
        """Describe the available files, or return '' when there are none."""
        # Type check first so a non-dict value never reaches ``.items()``.
        if not isinstance(self.files, dict) or not self.files:
            return ''
        lines = ['The following files available in the evaluation environment:']
        for file_info in self.files.values():
            peek_content = head_file(file_info.source_path, 4)
            lines.append(
                f'- path: `{file_info.source_path}` \n first four lines: {peek_content}'
                f' \n description: `{file_info.description}`'
            )
        return '\n'.join(lines)

    @property
    def description(self) -> str:
        """Full tool description: the base prompt plus the file listing."""
        return (base_description + '\n\n' + self.file_description).strip()

    def _run(self, code_string: str) -> dict:
        """Execute every code block found in ``code_string``.

        Each block runs in its own temporary working directory. Files the code
        writes to ``output/`` are uploaded to MinIO when configured, otherwise
        moved to a persistent temporary directory.

        Returns:
            ``{'exitcode': 0, 'log': ..., 'file_list': [...]}`` on success, or
            ``{'exitcode': <code>, 'log': ...}`` as soon as a block fails.
        """
        import shutil  # only needed for the no-MinIO fallback below

        logs_all = ''
        file_list = []
        for _declared_lang, code in extract_code(code_string):
            # The language is always inferred from the code itself, as before.
            lang = infer_lang(code)
            # The context manager removes the working directory on every exit
            # path, including the early return on failure.
            with tempfile.TemporaryDirectory() as work_dir:
                exitcode, logs, _ = execute_code(
                    code,
                    work_dir=work_dir,
                    lang=lang,
                )
                logs_all += '\n' + logs
                if exitcode != 0:
                    return {'exitcode': exitcode, 'log': logs_all}

                # Collect the files the code produced.
                output_dir = Path(work_dir) / 'output'
                produced = [
                    Path(root) / file_name
                    for root, _dirs, file_names in os.walk(output_dir)
                    for file_name in file_names
                ]
                if self.minio:
                    for path in produced:
                        # ``suffix`` is '' for extension-less files, so the
                        # object name never picks up directory components.
                        object_name = f'{uuid4().hex}{path.suffix}'
                        file_list.append(
                            upload_minio(self.minio, 'bisheng', object_name, str(path)))
                elif produced:
                    # Move results out of the working directory before it is
                    # deleted; otherwise the returned paths would not exist.
                    # The caller owns (and should eventually remove) this dir.
                    keep_dir = Path(tempfile.mkdtemp(prefix='code_interpreter_'))
                    for path in produced:
                        target = keep_dir / path.relative_to(output_dir)
                        target.parent.mkdir(parents=True, exist_ok=True)
                        shutil.move(str(path), str(target))
                        file_list.append(str(target))

        return {'exitcode': 0, 'log': logs_all, 'file_list': file_list}

    def as_tool(self) -> Tool:
        """Wrap ``_run`` as a ``Tool`` with this object's name, description and schema."""
        return Tool.from_function(
            func=self._run,
            name=self.name,
            description=self.description,
            args_schema=self.args_schema,
        )
|
239
|
+
|
240
|
+
|
241
|
+
if __name__ == '__main__':
    # Manual smoke test: extract the code blocks from a tiny snippet, run each
    # one with execute_code and log the combined output.
    code_string = """print('hha')"""
    code_blocks = extract_code(code_string)
    logger.info(code_blocks)
    logs_all = ''
    for i, (_, code) in enumerate(code_blocks):
        lang = infer_lang(code)
        banner = colored(
            f'\n>>>>>>>> EXECUTING CODE BLOCK {i} (inferred language is {lang})...',
            'red',
        )
        print(banner, flush=True)
        exitcode, logs, _ = execute_code(code, lang=lang)
        logs_all += '\n' + logs
        # A failing block is logged but does not stop the remaining blocks.
        if exitcode != 0:
            logger.error(f'{exitcode}, {logs_all}')

    logger.info(logs_all)
|
File without changes
|