hjxdl 0.1.8__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
hdl/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.1.8'
16
- __version_tuple__ = version_tuple = (0, 1, 8)
15
+ __version__ = version = '0.1.9'
16
+ __version_tuple__ = version_tuple = (0, 1, 9)
hdl/utils/llm/chat.py CHANGED
@@ -8,7 +8,7 @@ from openai import OpenAI
8
8
  def chat_oai_stream(
9
9
  base_url="http://127.0.0.1:8000/v1",
10
10
  api_key="dummy_key",
11
- model="/data/models/Qwen-7B-Chat-Int4",
11
+ model="default_model",
12
12
  prompt="Who are you?",
13
13
  *args,
14
14
  **kwargs
@@ -48,7 +48,7 @@ def chat_oai_stream(
48
48
  def chat_oai_invoke(
49
49
  base_url="http://127.0.0.1:8000/v1",
50
50
  api_key="dummy_key",
51
- model="/data/models/Qwen-7B-Chat-Int4",
51
+ model="default_model",
52
52
  prompt="Who are you?",
53
53
  *args,
54
54
  **kwargs
@@ -318,11 +318,14 @@ class GGUF_M(Llama):
318
318
  class OpenAI_M():
319
319
  def __init__(
320
320
  self,
321
- model_path: str = None,
321
+ model_path: str = "default_model",
322
322
  device: str='gpu',
323
323
  generation_kwargs: dict = {},
324
324
  server_ip: str = "172.28.1.2",
325
325
  server_port: int = 8000,
326
+ api_key: str = "dummy_key",
327
+ *args,
328
+ **kwargs
326
329
  ):
327
330
  """Initialize the class with the specified parameters.
328
331
 
@@ -337,7 +340,14 @@ class OpenAI_M():
337
340
  self.server_ip = server_ip
338
341
  self.server_port = server_port
339
342
  self.base_url = "http://{self.server_ip}:{str(self.server_port)}/v1"
340
-
343
+ self.api_key = api_key
344
+ self.client = OpenAI(
345
+ base_url=self.base_url,
346
+ api_key=self.api_key,
347
+ *args,
348
+ **kwargs
349
+ )
350
+
341
351
  def invoke(
342
352
  self,
343
353
  prompt : str,
@@ -355,33 +365,45 @@ class OpenAI_M():
355
365
  Returns:
356
366
  str: The response generated by the chatbot.
357
367
  """
358
- resp = chat_oai_invoke(
359
- base_url=self.base_url,
368
+ response = self.client.chat.completions.create(
360
369
  model=self.model_path,
361
- prompt=prompt
370
+ messages=[{
371
+ "role": "user",
372
+ "content": prompt
373
+ }],
374
+ stream=False,
375
+ **kwargs
362
376
  )
363
- return resp
364
-
377
+ return response.choices[0].message.content
378
+
365
379
  def stream(
366
380
  self,
367
381
  prompt : str,
368
382
  stop: list[str] | None = ["USER:", "ASSISTANT:"],
369
383
  # history: list = [],
370
384
  **kwargs: t.Any,
371
- ) -> str:
372
- """Generate a response by streaming conversation with the OpenAI chat model.
385
+ ):
386
+ """Generate text completion in a streaming fashion.
373
387
 
374
388
  Args:
375
- prompt (str): The prompt to start the conversation.
376
- stop (list[str], optional): List of strings that indicate when the conversation should stop. Defaults to ["USER:", "ASSISTANT:"].
377
- **kwargs: Additional keyword arguments to pass to the chat model.
389
+ prompt (str): The text prompt to generate completion for.
390
+ stop (list[str], optional): List of strings to stop streaming at. Defaults to ["USER:", "ASSISTANT:"].
391
+ **kwargs: Additional keyword arguments to pass to the completion API.
378
392
 
379
- Returns:
380
- str: The response generated by the chat model.
393
+ Yields:
394
+ str: The generated text completion in a streaming fashion.
381
395
  """
382
- resp = chat_oai_stream(
383
- base_url=self.base_url,
396
+ response = self.client.chat.completions.create(
384
397
  model=self.model_path,
385
- prompt=prompt
398
+ messages=[{
399
+ "role": "user",
400
+ "content": prompt
401
+ }],
402
+ stream=True,
403
+ **kwargs
386
404
  )
387
- return resp
405
+
406
+ for chunk in response:
407
+ content = chunk.choices[0].delta.content
408
+ if content:
409
+ yield content
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: hjxdl
3
- Version: 0.1.8
3
+ Version: 0.1.9
4
4
  Summary: A collection of functions for Jupyter notebooks
5
5
  Home-page: https://github.com/huluxiaohuowa/hdl
6
6
  Author: Jianxing Hu
@@ -1,5 +1,5 @@
1
1
  hdl/__init__.py,sha256=5sZZNySv08wwfzJcSDssGTqUn9wlmDsR6R4XB8J8mFM,70
2
- hdl/_version.py,sha256=PdJ7dZoz_SyEgX0MdrMfQYBFlGcwpemv6ibF8NKALBY,411
2
+ hdl/_version.py,sha256=NWmu2cvzOcqY9v-ee-qFLmtXRczssdN-cFGZ9qMNSmY,411
3
3
  hdl/args/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
4
  hdl/args/loss_args.py,sha256=s7YzSdd7IjD24rZvvOrxLLFqMZQb9YylxKeyelSdrTk,70
5
5
  hdl/controllers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -84,12 +84,12 @@ hdl/utils/database_tools/connect.py,sha256=KUnVG-8raifEJ_N0b3c8LkTTIfn9NIyw8LX6q
84
84
  hdl/utils/general/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
85
85
  hdl/utils/general/glob.py,sha256=8-RCnt6L297wMIfn34ZAMCsGCZUjHG3MGglGZI1cX0g,491
86
86
  hdl/utils/llm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
87
- hdl/utils/llm/chat.py,sha256=gsbqWh8fTcJUENU6ZuMClZAuSOLFnD5VP8kXOxGh3Zw,13776
87
+ hdl/utils/llm/chat.py,sha256=67Mx1Q077BOscQbHiDxJpjBQalJKO8SQMLcRch5_Xj8,14352
88
88
  hdl/utils/llm/embs.py,sha256=Tf0FOYrOFZp7qQpEPiSCXzlgyHH0X9HVTUtsup74a9E,7174
89
89
  hdl/utils/llm/extract.py,sha256=2sK_WJzmYIc8iuWaM9DA6Nw3_6q1O4lJ5pKpcZo-bBA,6512
90
90
  hdl/utils/schedulers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
91
91
  hdl/utils/schedulers/norm_lr.py,sha256=bDwCmdEK-WkgxQMFBiMuchv8Mm7C0-GZJ6usm-PQk14,4461
92
- hjxdl-0.1.8.dist-info/METADATA,sha256=a9BaE0EGy5G9EM3Tbsi4LMmIrCMFJUuDjFnmmu_nBW4,542
93
- hjxdl-0.1.8.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
94
- hjxdl-0.1.8.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
95
- hjxdl-0.1.8.dist-info/RECORD,,
92
+ hjxdl-0.1.9.dist-info/METADATA,sha256=O6Utim9uYtAH6dTsbrPQrgkdzCOHjwsYaGUHn6_noIM,542
93
+ hjxdl-0.1.9.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
94
+ hjxdl-0.1.9.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
95
+ hjxdl-0.1.9.dist-info/RECORD,,
File without changes