gpt-batch 0.1.8__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gpt_batch/batcher.py +4 -2
- {gpt_batch-0.1.8.dist-info → gpt_batch-0.1.9.dist-info}/METADATA +1 -1
- gpt_batch-0.1.9.dist-info/RECORD +8 -0
- tests/test_batcher.py +21 -0
- gpt_batch-0.1.8.dist-info/RECORD +0 -8
- {gpt_batch-0.1.8.dist-info → gpt_batch-0.1.9.dist-info}/WHEEL +0 -0
- {gpt_batch-0.1.8.dist-info → gpt_batch-0.1.9.dist-info}/top_level.txt +0 -0
gpt_batch/batcher.py
CHANGED
@@ -20,8 +20,7 @@ class GPTBatcher:
|
|
20
20
|
retry_attempts (int): Number of retries if a request fails. Default is 2.
|
21
21
|
miss_index (list): Tracks the indices of requests that failed to process correctly.
|
22
22
|
"""
|
23
|
-
|
24
|
-
def __init__(self, api_key, model_name="gpt-3.5-turbo-0125", system_prompt="",temperature=1,num_workers=64,timeout_duration=60,retry_attempts=2,api_base_url=None):
|
23
|
+
def __init__(self, api_key, model_name="gpt-3.5-turbo-0125", system_prompt="",temperature=1,num_workers=64,timeout_duration=60,retry_attempts=2,api_base_url=None,**kwargs):
|
25
24
|
|
26
25
|
self.is_claude = bool(re.search(r'claude', model_name, re.IGNORECASE))
|
27
26
|
|
@@ -41,6 +40,7 @@ class GPTBatcher:
|
|
41
40
|
self.timeout_duration = timeout_duration
|
42
41
|
self.retry_attempts = retry_attempts
|
43
42
|
self.miss_index = []
|
43
|
+
self.extra_params = kwargs
|
44
44
|
|
45
45
|
def get_attitude(self, ask_text):
|
46
46
|
index, ask_text = ask_text
|
@@ -55,6 +55,7 @@ class GPTBatcher:
|
|
55
55
|
],
|
56
56
|
system=self.system_prompt if self.system_prompt else None,
|
57
57
|
temperature=self.temperature,
|
58
|
+
**self.extra_params
|
58
59
|
)
|
59
60
|
return (index, message.content[0].text)
|
60
61
|
else:
|
@@ -66,6 +67,7 @@ class GPTBatcher:
|
|
66
67
|
{"role": "user", "content": ask_text}
|
67
68
|
],
|
68
69
|
temperature=self.temperature,
|
70
|
+
**self.extra_params
|
69
71
|
)
|
70
72
|
return (index, completion.choices[0].message.content)
|
71
73
|
except Exception as e:
|
@@ -0,0 +1,8 @@
|
|
1
|
+
gpt_batch/__init__.py,sha256=zGDItktTxKLSQr44GY78dl5LKsSJig0Q59dzusqhU0U,59
|
2
|
+
gpt_batch/batcher.py,sha256=y8B4hIeQJQ16G5PvlNgHE-CtVQzHPhpBssOAg7npQLA,9083
|
3
|
+
tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
|
+
tests/test_batcher.py,sha256=yRwqe2_VTp4wXXeQRhyDPZ2NZ-H3SSCDAxlNNXh3Aro,5314
|
5
|
+
gpt_batch-0.1.9.dist-info/METADATA,sha256=30t3VH_tY1mNWnzBPuQWKmD1o3bcA9yh3htvAWBgyok,3401
|
6
|
+
gpt_batch-0.1.9.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
7
|
+
gpt_batch-0.1.9.dist-info/top_level.txt,sha256=FtvJB_L9W_S6jL4G8Em_YWphG1wdKAF20BHUrf4B0yM,16
|
8
|
+
gpt_batch-0.1.9.dist-info/RECORD,,
|
tests/test_batcher.py
CHANGED
@@ -18,6 +18,27 @@ def test_handle_message_list():
|
|
18
18
|
assert len(results) == 2, "There should be two results, one for each message"
|
19
19
|
assert all(len(result) >= 2 for result in results), "Each result should be at least two elements"
|
20
20
|
|
21
|
+
|
22
|
+
def test_json_format():
|
23
|
+
import json
|
24
|
+
# Initialize the GPTBatcher with hypothetical valid credentials
|
25
|
+
#api_key = #get from system environment
|
26
|
+
api_key = os.getenv('TEST_KEY')
|
27
|
+
if not api_key:
|
28
|
+
raise ValueError("API key must be set in the environment variables")
|
29
|
+
batcher = GPTBatcher(api_key=api_key, model_name='gpt-3.5-turbo-1106', system_prompt="Your task is to discuss privacy questions and provide persuasive answers with supporting reasons.",response_format={ "type": "json_object" })
|
30
|
+
message_list = ["return me a random json object", "return me a random json object"]
|
31
|
+
|
32
|
+
# Call the method under test
|
33
|
+
results = batcher.handle_message_list(message_list)
|
34
|
+
# Assertions to verify the length of the results and the structure of each item
|
35
|
+
assert len(results) == 2, "There should be two results, one for each message"
|
36
|
+
assert all(len(result) >= 2 for result in results), "Each result should be at least two elements"
|
37
|
+
#assert all(isinstance(result, dict) and 'json' in result for result in results), "Each result should be a JSON object with 'json' key"
|
38
|
+
assert all(isinstance(json.loads(result), dict) for result in results), "Each result should be a JSON object with 'json' key"
|
39
|
+
|
40
|
+
|
41
|
+
|
21
42
|
def test_handle_embedding_list():
|
22
43
|
# Initialize the GPTBatcher with hypothetical valid credentials
|
23
44
|
#api_key = #get from system environment
|
gpt_batch-0.1.8.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
|
|
1
|
-
gpt_batch/__init__.py,sha256=zGDItktTxKLSQr44GY78dl5LKsSJig0Q59dzusqhU0U,59
|
2
|
-
gpt_batch/batcher.py,sha256=sruKqK_tY5WNMXvejNz_OAw6AHtJIqcvDKbMbX1rX_M,8959
|
3
|
-
tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
|
-
tests/test_batcher.py,sha256=FAlV_T5dT2OcM28yWPWWU9yIN0D7SBeYP1oTkvvzXKk,4077
|
5
|
-
gpt_batch-0.1.8.dist-info/METADATA,sha256=IIBetdVEVrUx30cxrzB4ZkAy0mGp3PVueJc9tm2WcZA,3401
|
6
|
-
gpt_batch-0.1.8.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
7
|
-
gpt_batch-0.1.8.dist-info/top_level.txt,sha256=FtvJB_L9W_S6jL4G8Em_YWphG1wdKAF20BHUrf4B0yM,16
|
8
|
-
gpt_batch-0.1.8.dist-info/RECORD,,
|
File without changes
|
File without changes
|