hjxdl 0.3.7__py3-none-any.whl → 0.3.9__py3-none-any.whl

This diff shows the differences between two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions exactly as they appear in their respective public registries.
hdl/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.3.7'
16
- __version_tuple__ = version_tuple = (0, 3, 7)
15
+ __version__ = version = '0.3.9'
16
+ __version_tuple__ = version_tuple = (0, 3, 9)
hdl/utils/llm/chat.py CHANGED
@@ -5,6 +5,7 @@ from concurrent.futures import ProcessPoolExecutor
5
5
  import subprocess
6
6
  from typing import Generator
7
7
  import re
8
+ import yaml
8
9
 
9
10
  from openai import OpenAI
10
11
  from PIL import Image
@@ -12,8 +13,6 @@ from PIL import Image
12
13
  from ..desc.template import FN_TEMPLATE, COT_TEMPLATE, OD_TEMPLATE
13
14
  from ..desc.func_desc import TOOL_DESC
14
15
  from .vis import draw_and_plot_boxes_from_json, to_img, to_base64
15
- # import json
16
- # import traceback
17
16
 
18
17
 
19
18
  def parse_fn_markdown(markdown_text, params_key="params"):
@@ -121,16 +120,11 @@ def run_tool_with_kwargs(tool, func_kwargs):
121
120
  return tool(**func_kwargs)
122
121
 
123
122
 
124
- class OpenAI_M():
123
+ class OpenAI_M:
125
124
  def __init__(
126
125
  self,
127
- model_path: str = "default_model",
128
- device: str='gpu',
129
- generation_kwargs: dict = None,
130
- server_host: str = None,
131
- server_ip: str = None,
132
- server_port: int = 11434,
133
- api_key: str = "dummy_key",
126
+ client_conf: dict = None,
127
+ client_conf_dir: str = None,
134
128
  tools: list = None,
135
129
  tool_desc: dict = None,
136
130
  cot_desc: str = None,
@@ -148,40 +142,43 @@ class OpenAI_M():
148
142
  server_ip (str): IP address of the server. Defaults to "172.28.1.2".
149
143
  server_port (int): Port number of the server. Defaults to 8000.
150
144
  api_key (str): API key for authentication. Defaults to "dummy_key".
145
+ use_groq (bool): Flag to use Groq client. Defaults to False.
146
+ groq_api_key (str, optional): API key for Groq client.
151
147
  tools (list, optional): List of tools to be used.
152
148
  tool_desc (dict, optional): Additional tool descriptions.
153
149
  cot_desc (str, optional): Chain of Thought description.
154
150
  *args: Additional positional arguments.
155
151
  **kwargs: Additional keyword arguments.
156
152
  """
157
- # self.model_path = model_path
158
- self.server_ip = server_ip
159
- self.server_port = server_port
160
- self.server_host = server_host
161
- if self.server_ip:
162
- self.base_url = f"http://{self.server_ip}:{str(self.server_port)}/v1"
163
- elif self.server_host:
164
- self.base_url = self.server_host
165
- self.api_key = api_key
166
-
167
- self.client = OpenAI(
168
- base_url=self.base_url,
169
- api_key=self.api_key,
170
- *args,
171
- **kwargs
172
- )
153
+
154
+ if client_conf is None:
155
+ assert client_conf_dir is not None
156
+ self.client_conf_path = client_conf_dir
157
+ self.load_clients()
158
+ else:
159
+ self.client_conf = client_conf
160
+
161
+ # self.clients = {}
162
+ for client_id, conf in self.client_conf.items():
163
+ conf[client_id]["client"] = OpenAI(
164
+ base_url=conf["host"],
165
+ api_key=conf.get("api_key", "dummy_key"),
166
+ *args,
167
+ **kwargs
168
+ )
169
+
173
170
  self.tools: list = tools if tools else []
174
171
  self.tool_desc: dict = TOOL_DESC
175
172
  if tool_desc is not None:
176
173
  self.tool_desc = self.tool_desc | tool_desc
177
174
 
178
175
  self.tool_descs = [
179
- self.tool_desc[tool.__name__]['desc']
176
+ self.tool_desc[tool]['desc']
180
177
  for tool in self.tools
181
178
  ]
182
179
  self.tool_descs_verbose = [
183
- self.tool_desc[tool.__name__]['desc']
184
- + self.tool_desc[tool.__name__]['md']
180
+ self.tool_desc[tool]['desc']
181
+ + self.tool_desc[tool]['md']
185
182
  for tool in self.tools
186
183
  ]
187
184
 
@@ -191,6 +188,18 @@ class OpenAI_M():
191
188
  self.cot_desc = cot_desc if cot_desc else COT_TEMPLATE
192
189
  self.od_desc = od_desc if od_desc else OD_TEMPLATE
193
190
 
191
+ def load_clients(self):
192
+ with open(self.client_conf_path, 'r') as file:
193
+ data = yaml.safe_load(file)
194
+
195
+ # 更新 host 字段
196
+ for _, value in data.items():
197
+ host = value.get('host', '')
198
+ port = value.get('port', '')
199
+ if not host.startswith('http') and port: # 确保有 port 才处理
200
+ value['host'] = f"http://{host}:{port}/v1"
201
+ self.client_conf = data
202
+
194
203
  def cot(
195
204
  self,
196
205
  prompt,
@@ -198,23 +207,6 @@ class OpenAI_M():
198
207
  steps: list = None,
199
208
  **kwargs
200
209
  ):
201
- """
202
- Execute a Chain of Thought (COT) process to iteratively generate steps
203
- towards solving a given prompt, utilizing tools if necessary.
204
-
205
- Args:
206
- prompt (str): The initial question or problem to solve.
207
- max_step (int, optional): Maximum number of steps to attempt. Defaults to 30.
208
- steps (list, optional): List to accumulate steps taken. Defaults to None.
209
- **kwargs: Additional keyword arguments for tool invocation.
210
-
211
- Yields:
212
- tuple: A tuple containing the current step number, accumulated information,
213
- and the list of steps taken.
214
-
215
- Raises:
216
- Exception: If an error occurs during the parsing or tool invocation process.
217
- """
218
210
  # 初始化当前信息为空字符串,用于累积后续的思考步骤和用户问题
219
211
  current_info = ""
220
212
  # 初始化步数为0,用于控制最大思考次数
@@ -235,7 +227,8 @@ class OpenAI_M():
235
227
  resp = self.invoke(
236
228
  "现有的步骤得出来的信息:\n" + current_info + "\n用户问题:" + prompt,
237
229
  sys_info=COT_TEMPLATE + self.tool_info,
238
- assis_info = "好的,我将根据用户的问题和信息给出当前需要进行的操作或最终答案"
230
+ assis_info="好的,我将根据用户的问题和信息给出当前需要进行的操作或最终答案",
231
+ **kwargs
239
232
  )
240
233
 
241
234
  # print(f"第{n_steps}步思考结果:\n{resp}\n\n")
@@ -249,7 +242,7 @@ class OpenAI_M():
249
242
  # 如果思考步骤中标记为停止思考,则打印所有步骤并返回最终答案
250
243
 
251
244
  # 如果思考步骤中包含使用工具的指示,则构造工具提示并调用agent_response方法
252
- if 'tool' in step_json and step_json['tool']:
245
+ if 'tool' in step_json and step_json['tool'] in self.tools:
253
246
  tool_prompt = step_json["tool"] \
254
247
  + step_json.get("title", "") \
255
248
  + step_json.get("content", "") \
@@ -288,6 +281,7 @@ class OpenAI_M():
288
281
 
289
282
  def get_resp(
290
283
  self,
284
+ client_id,
291
285
  prompt: str,
292
286
  sys_info: str = None,
293
287
  assis_info: str = None,
@@ -364,7 +358,7 @@ class OpenAI_M():
364
358
  })
365
359
 
366
360
  # Call the model to generate a response
367
- response = self.client.chat.completions.create(
361
+ response = self.client_conf[client_id]["client"].chat.completions.create(
368
362
  messages=messages,
369
363
  stream=stream,
370
364
  model=model,
@@ -468,7 +462,8 @@ class OpenAI_M():
468
462
  )
469
463
 
470
464
  def get_decision(
471
- self, prompt: str,
465
+ self,
466
+ prompt: str,
472
467
  **kwargs: t.Any,
473
468
  ):
474
469
  """Get decision based on the given prompt.
@@ -514,13 +509,14 @@ class OpenAI_M():
514
509
  return ""
515
510
  else:
516
511
  try:
512
+ tool_final = ''
517
513
  for tool in self.tools:
518
- if tool.__name__ == func_name:
514
+ if tool == func_name:
519
515
  tool_final = tool
520
516
  func_kwargs = decision_dict.get("params")
521
- if tool_final.__name__ == "object_detect":
517
+ if tool_final == "object_detect":
522
518
  func_kwargs["llm"] = self
523
- return tool_final(**func_kwargs)
519
+ return getattr(self.tools, tool_final)(**func_kwargs)
524
520
  except Exception as e:
525
521
  print(e)
526
522
  return ""
@@ -552,14 +548,15 @@ class OpenAI_M():
552
548
  return ""
553
549
  else:
554
550
  try:
551
+ tool_final = ''
555
552
  for tool in self.tools:
556
- if tool.__name__ == func_name:
553
+ if tool == func_name:
557
554
  tool_final = tool
558
555
  func_kwargs = decision_dict.get("params")
559
556
 
560
557
  loop = asyncio.get_running_loop()
561
558
  with ProcessPoolExecutor() as pool:
562
- result = await loop.run_in_executor(pool, run_tool_with_kwargs, tool_final, func_kwargs)
559
+ result = await loop.run_in_executor(pool, run_tool_with_kwargs, getattr(tools, tool_final), func_kwargs)
563
560
  return result
564
561
  except Exception as e:
565
562
  print(e)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: hjxdl
3
- Version: 0.3.7
3
+ Version: 0.3.9
4
4
  Summary: A collection of functions for Jupyter notebooks
5
5
  Home-page: https://github.com/huluxiaohuowa/hdl
6
6
  Author: Jianxing Hu
@@ -1,5 +1,5 @@
1
1
  hdl/__init__.py,sha256=GffnD0jLJdhkd-vo989v40N90sQbofkayRBwxc6TVhQ,72
2
- hdl/_version.py,sha256=vVN20516E2VTC9JNgtvqrQNlj5XptaB_a5z2XL8NFxg,411
2
+ hdl/_version.py,sha256=nV2HEiFwTdaOZoFEyVxxG_D8Oq_nlSmX2vHL4jK4h6w,411
3
3
  hdl/args/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
4
  hdl/args/loss_args.py,sha256=s7YzSdd7IjD24rZvvOrxLLFqMZQb9YylxKeyelSdrTk,70
5
5
  hdl/controllers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -128,7 +128,7 @@ hdl/utils/general/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU
128
128
  hdl/utils/general/glob.py,sha256=Zuf7WHU0UdUPOs9UrhxmrCiMC8GrHxQU6n3mTThv6yc,1120
129
129
  hdl/utils/general/runners.py,sha256=x7QBolp3MrqNV6L4rB6Ueybr26bqkRFZTuXhY0SwyLk,3061
130
130
  hdl/utils/llm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
131
- hdl/utils/llm/chat.py,sha256=qU-heyMG3xG4s-XCUE5kIkKJXhzre_XDEEDkLC1dWVo,26016
131
+ hdl/utils/llm/chat.py,sha256=u59sIk6e-4-uM-XTR-EjtdKM-rtxlGzYdvNINH4fDAw,25784
132
132
  hdl/utils/llm/chatgr.py,sha256=5F5PJHe8vz3iCfi4TT54DCLRi1UeJshECdVtgvvvao0,3696
133
133
  hdl/utils/llm/embs.py,sha256=Tf0FOYrOFZp7qQpEPiSCXzlgyHH0X9HVTUtsup74a9E,7174
134
134
  hdl/utils/llm/extract.py,sha256=2sK_WJzmYIc8iuWaM9DA6Nw3_6q1O4lJ5pKpcZo-bBA,6512
@@ -139,7 +139,7 @@ hdl/utils/schedulers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hS
139
139
  hdl/utils/schedulers/norm_lr.py,sha256=bDwCmdEK-WkgxQMFBiMuchv8Mm7C0-GZJ6usm-PQk14,4461
140
140
  hdl/utils/weather/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
141
141
  hdl/utils/weather/weather.py,sha256=k11o6wM15kF8b9NMlEfrg68ak-SfSYLN3nOOflFUv-I,4381
142
- hjxdl-0.3.7.dist-info/METADATA,sha256=TwlIfhWMN9vxu5k4lmfolsTjNdfwnt4BqWiRU5YWHdk,861
143
- hjxdl-0.3.7.dist-info/WHEEL,sha256=A3WOREP4zgxI0fKrHUG8DC8013e3dK3n7a6HDbcEIwE,91
144
- hjxdl-0.3.7.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
145
- hjxdl-0.3.7.dist-info/RECORD,,
142
+ hjxdl-0.3.9.dist-info/METADATA,sha256=Jfl4N6_RI8chY8-LsWR28HByq4gFzuW00k0ANjYCl-0,861
143
+ hjxdl-0.3.9.dist-info/WHEEL,sha256=A3WOREP4zgxI0fKrHUG8DC8013e3dK3n7a6HDbcEIwE,91
144
+ hjxdl-0.3.9.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
145
+ hjxdl-0.3.9.dist-info/RECORD,,
File without changes