gpt-batch 0.1.8__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gpt_batch/batcher.py CHANGED
@@ -20,8 +20,7 @@ class GPTBatcher:
20
20
  retry_attempts (int): Number of retries if a request fails. Default is 2.
21
21
  miss_index (list): Tracks the indices of requests that failed to process correctly.
22
22
  """
23
-
24
- def __init__(self, api_key, model_name="gpt-3.5-turbo-0125", system_prompt="",temperature=1,num_workers=64,timeout_duration=60,retry_attempts=2,api_base_url=None):
23
+ def __init__(self, api_key, model_name="gpt-3.5-turbo-0125", system_prompt="",temperature=1,num_workers=64,timeout_duration=60,retry_attempts=2,api_base_url=None,**kwargs):
25
24
 
26
25
  self.is_claude = bool(re.search(r'claude', model_name, re.IGNORECASE))
27
26
 
@@ -41,6 +40,7 @@ class GPTBatcher:
41
40
  self.timeout_duration = timeout_duration
42
41
  self.retry_attempts = retry_attempts
43
42
  self.miss_index = []
43
+ self.extra_params = kwargs
44
44
 
45
45
  def get_attitude(self, ask_text):
46
46
  index, ask_text = ask_text
@@ -55,6 +55,7 @@ class GPTBatcher:
55
55
  ],
56
56
  system=self.system_prompt if self.system_prompt else None,
57
57
  temperature=self.temperature,
58
+ **self.extra_params
58
59
  )
59
60
  return (index, message.content[0].text)
60
61
  else:
@@ -66,6 +67,7 @@ class GPTBatcher:
66
67
  {"role": "user", "content": ask_text}
67
68
  ],
68
69
  temperature=self.temperature,
70
+ **self.extra_params
69
71
  )
70
72
  return (index, completion.choices[0].message.content)
71
73
  except Exception as e:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: gpt-batch
3
- Version: 0.1.8
3
+ Version: 0.1.9
4
4
  Summary: A package for batch processing with OpenAI API.
5
5
  Home-page: https://github.com/fengsxy/gpt_batch
6
6
  Author: Ted Yu
@@ -0,0 +1,8 @@
1
+ gpt_batch/__init__.py,sha256=zGDItktTxKLSQr44GY78dl5LKsSJig0Q59dzusqhU0U,59
2
+ gpt_batch/batcher.py,sha256=y8B4hIeQJQ16G5PvlNgHE-CtVQzHPhpBssOAg7npQLA,9083
3
+ tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
+ tests/test_batcher.py,sha256=yRwqe2_VTp4wXXeQRhyDPZ2NZ-H3SSCDAxlNNXh3Aro,5314
5
+ gpt_batch-0.1.9.dist-info/METADATA,sha256=30t3VH_tY1mNWnzBPuQWKmD1o3bcA9yh3htvAWBgyok,3401
6
+ gpt_batch-0.1.9.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
7
+ gpt_batch-0.1.9.dist-info/top_level.txt,sha256=FtvJB_L9W_S6jL4G8Em_YWphG1wdKAF20BHUrf4B0yM,16
8
+ gpt_batch-0.1.9.dist-info/RECORD,,
tests/test_batcher.py CHANGED
@@ -18,6 +18,27 @@ def test_handle_message_list():
18
18
  assert len(results) == 2, "There should be two results, one for each message"
19
19
  assert all(len(result) >= 2 for result in results), "Each result should be at least two elements"
20
20
 
21
+
22
+ def test_json_format():
23
+ import json
24
+ # Initialize the GPTBatcher with hypothetical valid credentials
25
+ #api_key = #get from system environment
26
+ api_key = os.getenv('TEST_KEY')
27
+ if not api_key:
28
+ raise ValueError("API key must be set in the environment variables")
29
+ batcher = GPTBatcher(api_key=api_key, model_name='gpt-3.5-turbo-1106', system_prompt="Your task is to discuss privacy questions and provide persuasive answers with supporting reasons.",response_format={ "type": "json_object" })
30
+ message_list = ["return me a random json object", "return me a random json object"]
31
+
32
+ # Call the method under test
33
+ results = batcher.handle_message_list(message_list)
34
+ # Assertions to verify the length of the results and the structure of each item
35
+ assert len(results) == 2, "There should be two results, one for each message"
36
+ assert all(len(result) >= 2 for result in results), "Each result should be at least two elements"
37
+ #assert all(isinstance(result, dict) and 'json' in result for result in results), "Each result should be a JSON object with 'json' key"
38
+ assert all(isinstance(json.loads(result), dict) for result in results), "Each result should be a JSON object with 'json' key"
39
+
40
+
41
+
21
42
  def test_handle_embedding_list():
22
43
  # Initialize the GPTBatcher with hypothetical valid credentials
23
44
  #api_key = #get from system environment
@@ -1,8 +0,0 @@
1
- gpt_batch/__init__.py,sha256=zGDItktTxKLSQr44GY78dl5LKsSJig0Q59dzusqhU0U,59
2
- gpt_batch/batcher.py,sha256=sruKqK_tY5WNMXvejNz_OAw6AHtJIqcvDKbMbX1rX_M,8959
3
- tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- tests/test_batcher.py,sha256=FAlV_T5dT2OcM28yWPWWU9yIN0D7SBeYP1oTkvvzXKk,4077
5
- gpt_batch-0.1.8.dist-info/METADATA,sha256=IIBetdVEVrUx30cxrzB4ZkAy0mGp3PVueJc9tm2WcZA,3401
6
- gpt_batch-0.1.8.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
7
- gpt_batch-0.1.8.dist-info/top_level.txt,sha256=FtvJB_L9W_S6jL4G8Em_YWphG1wdKAF20BHUrf4B0yM,16
8
- gpt_batch-0.1.8.dist-info/RECORD,,