hjxdl 0.3.8__py3-none-any.whl → 0.3.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hdl/_version.py +2 -2
- hdl/utils/llm/chat.py +52 -55
- {hjxdl-0.3.8.dist-info → hjxdl-0.3.9.dist-info}/METADATA +1 -1
- {hjxdl-0.3.8.dist-info → hjxdl-0.3.9.dist-info}/RECORD +6 -6
- {hjxdl-0.3.8.dist-info → hjxdl-0.3.9.dist-info}/WHEEL +0 -0
- {hjxdl-0.3.8.dist-info → hjxdl-0.3.9.dist-info}/top_level.txt +0 -0
hdl/_version.py
CHANGED
hdl/utils/llm/chat.py
CHANGED
@@ -5,6 +5,7 @@ from concurrent.futures import ProcessPoolExecutor
|
|
5
5
|
import subprocess
|
6
6
|
from typing import Generator
|
7
7
|
import re
|
8
|
+
import yaml
|
8
9
|
|
9
10
|
from openai import OpenAI
|
10
11
|
from PIL import Image
|
@@ -12,8 +13,6 @@ from PIL import Image
|
|
12
13
|
from ..desc.template import FN_TEMPLATE, COT_TEMPLATE, OD_TEMPLATE
|
13
14
|
from ..desc.func_desc import TOOL_DESC
|
14
15
|
from .vis import draw_and_plot_boxes_from_json, to_img, to_base64
|
15
|
-
# import json
|
16
|
-
# import traceback
|
17
16
|
|
18
17
|
|
19
18
|
def parse_fn_markdown(markdown_text, params_key="params"):
|
@@ -121,16 +120,11 @@ def run_tool_with_kwargs(tool, func_kwargs):
|
|
121
120
|
return tool(**func_kwargs)
|
122
121
|
|
123
122
|
|
124
|
-
class OpenAI_M
|
123
|
+
class OpenAI_M:
|
125
124
|
def __init__(
|
126
125
|
self,
|
127
|
-
|
128
|
-
|
129
|
-
generation_kwargs: dict = None,
|
130
|
-
server_host: str = None,
|
131
|
-
server_ip: str = None,
|
132
|
-
server_port: int = 11434,
|
133
|
-
api_key: str = "dummy_key",
|
126
|
+
client_conf: dict = None,
|
127
|
+
client_conf_dir: str = None,
|
134
128
|
tools: list = None,
|
135
129
|
tool_desc: dict = None,
|
136
130
|
cot_desc: str = None,
|
@@ -148,40 +142,43 @@ class OpenAI_M():
|
|
148
142
|
server_ip (str): IP address of the server. Defaults to "172.28.1.2".
|
149
143
|
server_port (int): Port number of the server. Defaults to 8000.
|
150
144
|
api_key (str): API key for authentication. Defaults to "dummy_key".
|
145
|
+
use_groq (bool): Flag to use Groq client. Defaults to False.
|
146
|
+
groq_api_key (str, optional): API key for Groq client.
|
151
147
|
tools (list, optional): List of tools to be used.
|
152
148
|
tool_desc (dict, optional): Additional tool descriptions.
|
153
149
|
cot_desc (str, optional): Chain of Thought description.
|
154
150
|
*args: Additional positional arguments.
|
155
151
|
**kwargs: Additional keyword arguments.
|
156
152
|
"""
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
self.
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
153
|
+
|
154
|
+
if client_conf is None:
|
155
|
+
assert client_conf_dir is not None
|
156
|
+
self.client_conf_path = client_conf_dir
|
157
|
+
self.load_clients()
|
158
|
+
else:
|
159
|
+
self.client_conf = client_conf
|
160
|
+
|
161
|
+
# self.clients = {}
|
162
|
+
for client_id, conf in self.client_conf.items():
|
163
|
+
conf[client_id]["client"] = OpenAI(
|
164
|
+
base_url=conf["host"],
|
165
|
+
api_key=conf.get("api_key", "dummy_key"),
|
166
|
+
*args,
|
167
|
+
**kwargs
|
168
|
+
)
|
169
|
+
|
173
170
|
self.tools: list = tools if tools else []
|
174
171
|
self.tool_desc: dict = TOOL_DESC
|
175
172
|
if tool_desc is not None:
|
176
173
|
self.tool_desc = self.tool_desc | tool_desc
|
177
174
|
|
178
175
|
self.tool_descs = [
|
179
|
-
self.tool_desc[tool
|
176
|
+
self.tool_desc[tool]['desc']
|
180
177
|
for tool in self.tools
|
181
178
|
]
|
182
179
|
self.tool_descs_verbose = [
|
183
|
-
self.tool_desc[tool
|
184
|
-
+ self.tool_desc[tool
|
180
|
+
self.tool_desc[tool]['desc']
|
181
|
+
+ self.tool_desc[tool]['md']
|
185
182
|
for tool in self.tools
|
186
183
|
]
|
187
184
|
|
@@ -191,6 +188,18 @@ class OpenAI_M():
|
|
191
188
|
self.cot_desc = cot_desc if cot_desc else COT_TEMPLATE
|
192
189
|
self.od_desc = od_desc if od_desc else OD_TEMPLATE
|
193
190
|
|
191
|
+
def load_clients(self):
|
192
|
+
with open(self.client_conf_path, 'r') as file:
|
193
|
+
data = yaml.safe_load(file)
|
194
|
+
|
195
|
+
# 更新 host 字段
|
196
|
+
for _, value in data.items():
|
197
|
+
host = value.get('host', '')
|
198
|
+
port = value.get('port', '')
|
199
|
+
if not host.startswith('http') and port: # 确保有 port 才处理
|
200
|
+
value['host'] = f"http://{host}:{port}/v1"
|
201
|
+
self.client_conf = data
|
202
|
+
|
194
203
|
def cot(
|
195
204
|
self,
|
196
205
|
prompt,
|
@@ -198,23 +207,6 @@ class OpenAI_M():
|
|
198
207
|
steps: list = None,
|
199
208
|
**kwargs
|
200
209
|
):
|
201
|
-
"""
|
202
|
-
Execute a Chain of Thought (COT) process to iteratively generate steps
|
203
|
-
towards solving a given prompt, utilizing tools if necessary.
|
204
|
-
|
205
|
-
Args:
|
206
|
-
prompt (str): The initial question or problem to solve.
|
207
|
-
max_step (int, optional): Maximum number of steps to attempt. Defaults to 30.
|
208
|
-
steps (list, optional): List to accumulate steps taken. Defaults to None.
|
209
|
-
**kwargs: Additional keyword arguments for tool invocation.
|
210
|
-
|
211
|
-
Yields:
|
212
|
-
tuple: A tuple containing the current step number, accumulated information,
|
213
|
-
and the list of steps taken.
|
214
|
-
|
215
|
-
Raises:
|
216
|
-
Exception: If an error occurs during the parsing or tool invocation process.
|
217
|
-
"""
|
218
210
|
# 初始化当前信息为空字符串,用于累积后续的思考步骤和用户问题
|
219
211
|
current_info = ""
|
220
212
|
# 初始化步数为0,用于控制最大思考次数
|
@@ -235,7 +227,8 @@ class OpenAI_M():
|
|
235
227
|
resp = self.invoke(
|
236
228
|
"现有的步骤得出来的信息:\n" + current_info + "\n用户问题:" + prompt,
|
237
229
|
sys_info=COT_TEMPLATE + self.tool_info,
|
238
|
-
assis_info
|
230
|
+
assis_info="好的,我将根据用户的问题和信息给出当前需要进行的操作或最终答案",
|
231
|
+
**kwargs
|
239
232
|
)
|
240
233
|
|
241
234
|
# print(f"第{n_steps}步思考结果:\n{resp}\n\n")
|
@@ -249,7 +242,7 @@ class OpenAI_M():
|
|
249
242
|
# 如果思考步骤中标记为停止思考,则打印所有步骤并返回最终答案
|
250
243
|
|
251
244
|
# 如果思考步骤中包含使用工具的指示,则构造工具提示并调用agent_response方法
|
252
|
-
if 'tool' in step_json and step_json['tool']:
|
245
|
+
if 'tool' in step_json and step_json['tool'] in self.tools:
|
253
246
|
tool_prompt = step_json["tool"] \
|
254
247
|
+ step_json.get("title", "") \
|
255
248
|
+ step_json.get("content", "") \
|
@@ -288,6 +281,7 @@ class OpenAI_M():
|
|
288
281
|
|
289
282
|
def get_resp(
|
290
283
|
self,
|
284
|
+
client_id,
|
291
285
|
prompt: str,
|
292
286
|
sys_info: str = None,
|
293
287
|
assis_info: str = None,
|
@@ -364,7 +358,7 @@ class OpenAI_M():
|
|
364
358
|
})
|
365
359
|
|
366
360
|
# Call the model to generate a response
|
367
|
-
response = self.client.chat.completions.create(
|
361
|
+
response = self.client_conf[client_id]["client"].chat.completions.create(
|
368
362
|
messages=messages,
|
369
363
|
stream=stream,
|
370
364
|
model=model,
|
@@ -468,7 +462,8 @@ class OpenAI_M():
|
|
468
462
|
)
|
469
463
|
|
470
464
|
def get_decision(
|
471
|
-
self,
|
465
|
+
self,
|
466
|
+
prompt: str,
|
472
467
|
**kwargs: t.Any,
|
473
468
|
):
|
474
469
|
"""Get decision based on the given prompt.
|
@@ -514,13 +509,14 @@ class OpenAI_M():
|
|
514
509
|
return ""
|
515
510
|
else:
|
516
511
|
try:
|
512
|
+
tool_final = ''
|
517
513
|
for tool in self.tools:
|
518
|
-
if tool
|
514
|
+
if tool == func_name:
|
519
515
|
tool_final = tool
|
520
516
|
func_kwargs = decision_dict.get("params")
|
521
|
-
if tool_final
|
517
|
+
if tool_final == "object_detect":
|
522
518
|
func_kwargs["llm"] = self
|
523
|
-
return tool_final(**func_kwargs)
|
519
|
+
return getattr(self.tools, tool_final)(**func_kwargs)
|
524
520
|
except Exception as e:
|
525
521
|
print(e)
|
526
522
|
return ""
|
@@ -552,14 +548,15 @@ class OpenAI_M():
|
|
552
548
|
return ""
|
553
549
|
else:
|
554
550
|
try:
|
551
|
+
tool_final = ''
|
555
552
|
for tool in self.tools:
|
556
|
-
if tool
|
553
|
+
if tool == func_name:
|
557
554
|
tool_final = tool
|
558
555
|
func_kwargs = decision_dict.get("params")
|
559
556
|
|
560
557
|
loop = asyncio.get_running_loop()
|
561
558
|
with ProcessPoolExecutor() as pool:
|
562
|
-
result = await loop.run_in_executor(pool, run_tool_with_kwargs, tool_final, func_kwargs)
|
559
|
+
result = await loop.run_in_executor(pool, run_tool_with_kwargs, getattr(tools, tool_final), func_kwargs)
|
563
560
|
return result
|
564
561
|
except Exception as e:
|
565
562
|
print(e)
|
@@ -1,5 +1,5 @@
|
|
1
1
|
hdl/__init__.py,sha256=GffnD0jLJdhkd-vo989v40N90sQbofkayRBwxc6TVhQ,72
|
2
|
-
hdl/_version.py,sha256=
|
2
|
+
hdl/_version.py,sha256=nV2HEiFwTdaOZoFEyVxxG_D8Oq_nlSmX2vHL4jK4h6w,411
|
3
3
|
hdl/args/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
4
|
hdl/args/loss_args.py,sha256=s7YzSdd7IjD24rZvvOrxLLFqMZQb9YylxKeyelSdrTk,70
|
5
5
|
hdl/controllers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -128,7 +128,7 @@ hdl/utils/general/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU
|
|
128
128
|
hdl/utils/general/glob.py,sha256=Zuf7WHU0UdUPOs9UrhxmrCiMC8GrHxQU6n3mTThv6yc,1120
|
129
129
|
hdl/utils/general/runners.py,sha256=x7QBolp3MrqNV6L4rB6Ueybr26bqkRFZTuXhY0SwyLk,3061
|
130
130
|
hdl/utils/llm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
131
|
-
hdl/utils/llm/chat.py,sha256=
|
131
|
+
hdl/utils/llm/chat.py,sha256=u59sIk6e-4-uM-XTR-EjtdKM-rtxlGzYdvNINH4fDAw,25784
|
132
132
|
hdl/utils/llm/chatgr.py,sha256=5F5PJHe8vz3iCfi4TT54DCLRi1UeJshECdVtgvvvao0,3696
|
133
133
|
hdl/utils/llm/embs.py,sha256=Tf0FOYrOFZp7qQpEPiSCXzlgyHH0X9HVTUtsup74a9E,7174
|
134
134
|
hdl/utils/llm/extract.py,sha256=2sK_WJzmYIc8iuWaM9DA6Nw3_6q1O4lJ5pKpcZo-bBA,6512
|
@@ -139,7 +139,7 @@ hdl/utils/schedulers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hS
|
|
139
139
|
hdl/utils/schedulers/norm_lr.py,sha256=bDwCmdEK-WkgxQMFBiMuchv8Mm7C0-GZJ6usm-PQk14,4461
|
140
140
|
hdl/utils/weather/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
141
141
|
hdl/utils/weather/weather.py,sha256=k11o6wM15kF8b9NMlEfrg68ak-SfSYLN3nOOflFUv-I,4381
|
142
|
-
hjxdl-0.3.
|
143
|
-
hjxdl-0.3.
|
144
|
-
hjxdl-0.3.
|
145
|
-
hjxdl-0.3.
|
142
|
+
hjxdl-0.3.9.dist-info/METADATA,sha256=Jfl4N6_RI8chY8-LsWR28HByq4gFzuW00k0ANjYCl-0,861
|
143
|
+
hjxdl-0.3.9.dist-info/WHEEL,sha256=A3WOREP4zgxI0fKrHUG8DC8013e3dK3n7a6HDbcEIwE,91
|
144
|
+
hjxdl-0.3.9.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
|
145
|
+
hjxdl-0.3.9.dist-info/RECORD,,
|
File without changes
|
File without changes
|