hjxdl 0.1.8__py3-none-any.whl → 0.1.10__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.
hdl/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.1.8'
16
- __version_tuple__ = version_tuple = (0, 1, 8)
15
+ __version__ = version = '0.1.10'
16
+ __version_tuple__ = version_tuple = (0, 1, 10)
hdl/utils/llm/chat.py CHANGED
@@ -8,7 +8,7 @@ from openai import OpenAI
8
8
  def chat_oai_stream(
9
9
  base_url="http://127.0.0.1:8000/v1",
10
10
  api_key="dummy_key",
11
- model="/data/models/Qwen-7B-Chat-Int4",
11
+ model="default_model",
12
12
  prompt="Who are you?",
13
13
  *args,
14
14
  **kwargs
@@ -48,7 +48,7 @@ def chat_oai_stream(
48
48
  def chat_oai_invoke(
49
49
  base_url="http://127.0.0.1:8000/v1",
50
50
  api_key="dummy_key",
51
- model="/data/models/Qwen-7B-Chat-Int4",
51
+ model="default_model",
52
52
  prompt="Who are you?",
53
53
  *args,
54
54
  **kwargs
@@ -318,26 +318,47 @@ class GGUF_M(Llama):
318
318
  class OpenAI_M():
319
319
  def __init__(
320
320
  self,
321
- model_path: str = None,
321
+ model_path: str = "default_model",
322
322
  device: str='gpu',
323
323
  generation_kwargs: dict = {},
324
324
  server_ip: str = "172.28.1.2",
325
325
  server_port: int = 8000,
326
+ api_key: str = "dummy_key",
327
+ *args,
328
+ **kwargs
326
329
  ):
327
- """Initialize the class with the specified parameters.
330
+ """Initialize the OpenAI client.
328
331
 
329
332
  Args:
330
- model_path (str, optional): Path to the model file. Defaults to None.
331
- device (str, optional): Device to run the model on. Defaults to 'gpu'.
332
- generation_kwargs (dict, optional): Additional keyword arguments for model generation. Defaults to {}.
333
- server_ip (str, optional): IP address of the server. Defaults to "172.28.1.2".
334
- server_port (int, optional): Port number of the server. Defaults to 8000.
333
+ model_path (str): Path to the model file. Defaults to "default_model".
334
+ device (str): Device to use for model inference. Defaults to 'gpu'.
335
+ generation_kwargs (dict): Additional keyword arguments for model generation. Defaults to {}.
336
+ server_ip (str): IP address of the server. Defaults to "172.28.1.2".
337
+ server_port (int): Port number of the server. Defaults to 8000.
338
+ api_key (str): API key for authentication. Defaults to "dummy_key".
339
+ *args: Variable length argument list.
340
+ **kwargs: Arbitrary keyword arguments.
341
+
342
+ Attributes:
343
+ model_path (str): Path to the model file.
344
+ server_ip (str): IP address of the server.
345
+ server_port (int): Port number of the server.
346
+ base_url (str): Base URL for API requests.
347
+ api_key (str): API key for authentication.
348
+ client (OpenAI): OpenAI client for making API requests.
335
349
  """
336
350
  self.model_path = model_path
337
351
  self.server_ip = server_ip
338
352
  self.server_port = server_port
339
353
  self.base_url = "http://{self.server_ip}:{str(self.server_port)}/v1"
340
-
354
+ self.api_key = api_key
355
+ self.client = OpenAI(
356
+ base_url=self.base_url,
357
+ api_key=self.api_key,
358
+ *args,
359
+ **kwargs
360
+ )
361
+
341
362
  def invoke(
342
363
  self,
343
364
  prompt : str,
@@ -355,33 +376,45 @@ class OpenAI_M():
355
376
  Returns:
356
377
  str: The response generated by the chatbot.
357
378
  """
358
- resp = chat_oai_invoke(
359
- base_url=self.base_url,
379
+ response = self.client.chat.completions.create(
360
380
  model=self.model_path,
361
- prompt=prompt
381
+ messages=[{
382
+ "role": "user",
383
+ "content": prompt
384
+ }],
385
+ stream=False,
386
+ **kwargs
362
387
  )
363
- return resp
364
-
388
+ return response.choices[0].message.content
389
+
365
390
  def stream(
366
391
  self,
367
392
  prompt : str,
368
393
  stop: list[str] | None = ["USER:", "ASSISTANT:"],
369
394
  # history: list = [],
370
395
  **kwargs: t.Any,
371
- ) -> str:
372
- """Generate a response by streaming conversation with the OpenAI chat model.
396
+ ):
397
+ """Generate text completion in a streaming fashion.
373
398
 
374
399
  Args:
375
- prompt (str): The prompt to start the conversation.
376
- stop (list[str], optional): List of strings that indicate when the conversation should stop. Defaults to ["USER:", "ASSISTANT:"].
377
- **kwargs: Additional keyword arguments to pass to the chat model.
400
+ prompt (str): The text prompt to generate completion for.
401
+ stop (list[str], optional): List of strings to stop streaming at. Defaults to ["USER:", "ASSISTANT:"].
402
+ **kwargs: Additional keyword arguments to pass to the completion API.
378
403
 
379
- Returns:
380
- str: The response generated by the chat model.
404
+ Yields:
405
+ str: The generated text completion in a streaming fashion.
381
406
  """
382
- resp = chat_oai_stream(
383
- base_url=self.base_url,
407
+ response = self.client.chat.completions.create(
384
408
  model=self.model_path,
385
- prompt=prompt
409
+ messages=[{
410
+ "role": "user",
411
+ "content": prompt
412
+ }],
413
+ stream=True,
414
+ **kwargs
386
415
  )
387
- return resp
416
+
417
+ for chunk in response:
418
+ content = chunk.choices[0].delta.content
419
+ if content:
420
+ yield content
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: hjxdl
3
- Version: 0.1.8
3
+ Version: 0.1.10
4
4
  Summary: A collection of functions for Jupyter notebooks
5
5
  Home-page: https://github.com/huluxiaohuowa/hdl
6
6
  Author: Jianxing Hu
@@ -1,5 +1,5 @@
1
1
  hdl/__init__.py,sha256=5sZZNySv08wwfzJcSDssGTqUn9wlmDsR6R4XB8J8mFM,70
2
- hdl/_version.py,sha256=PdJ7dZoz_SyEgX0MdrMfQYBFlGcwpemv6ibF8NKALBY,411
2
+ hdl/_version.py,sha256=0iLmzkTe5cfY4SBtaPpUzHn9tXwbwplszcfp5pHW6nU,413
3
3
  hdl/args/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
4
  hdl/args/loss_args.py,sha256=s7YzSdd7IjD24rZvvOrxLLFqMZQb9YylxKeyelSdrTk,70
5
5
  hdl/controllers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -84,12 +84,12 @@ hdl/utils/database_tools/connect.py,sha256=KUnVG-8raifEJ_N0b3c8LkTTIfn9NIyw8LX6q
84
84
  hdl/utils/general/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
85
85
  hdl/utils/general/glob.py,sha256=8-RCnt6L297wMIfn34ZAMCsGCZUjHG3MGglGZI1cX0g,491
86
86
  hdl/utils/llm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
87
- hdl/utils/llm/chat.py,sha256=gsbqWh8fTcJUENU6ZuMClZAuSOLFnD5VP8kXOxGh3Zw,13776
87
+ hdl/utils/llm/chat.py,sha256=2r3lIAf14l3obfK2ee2z6gbdp1LU-qJiV9o7rlGG0fg,14853
88
88
  hdl/utils/llm/embs.py,sha256=Tf0FOYrOFZp7qQpEPiSCXzlgyHH0X9HVTUtsup74a9E,7174
89
89
  hdl/utils/llm/extract.py,sha256=2sK_WJzmYIc8iuWaM9DA6Nw3_6q1O4lJ5pKpcZo-bBA,6512
90
90
  hdl/utils/schedulers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
91
91
  hdl/utils/schedulers/norm_lr.py,sha256=bDwCmdEK-WkgxQMFBiMuchv8Mm7C0-GZJ6usm-PQk14,4461
92
- hjxdl-0.1.8.dist-info/METADATA,sha256=a9BaE0EGy5G9EM3Tbsi4LMmIrCMFJUuDjFnmmu_nBW4,542
93
- hjxdl-0.1.8.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
94
- hjxdl-0.1.8.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
95
- hjxdl-0.1.8.dist-info/RECORD,,
92
+ hjxdl-0.1.10.dist-info/METADATA,sha256=wyNIVGQpVGcszuRGBHAUeM8AuGUmelzADJ2yphmFUas,543
93
+ hjxdl-0.1.10.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
94
+ hjxdl-0.1.10.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
95
+ hjxdl-0.1.10.dist-info/RECORD,,
File without changes