gpt_batch-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gpt_batch/__init__.py +3 -0
- gpt_batch/batcher.py +85 -0
- gpt_batch-0.1.0.dist-info/METADATA +30 -0
- gpt_batch-0.1.0.dist-info/RECORD +8 -0
- gpt_batch-0.1.0.dist-info/WHEEL +5 -0
- gpt_batch-0.1.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +0 -0
- tests/test_batcher.py +23 -0
gpt_batch/__init__.py
ADDED
gpt_batch/batcher.py
ADDED
@@ -0,0 +1,85 @@
from openai import OpenAI
from concurrent.futures import ThreadPoolExecutor, wait
from tqdm import tqdm


class GPTBatcher:

    def __init__(self, api_key, model_name="gpt-3.5-turbo-0125", system_prompt="",
                 temperature=1, num_workers=64, timeout_duration=60, retry_attempts=2):
        self.client = OpenAI(api_key=api_key)
        self.model_name = model_name
        self.system_prompt = system_prompt
        self.temperature = temperature
        self.num_workers = num_workers
        self.timeout_duration = timeout_duration
        self.retry_attempts = retry_attempts
        self.miss_index = []  # indices that never received an answer

    def get_attitude(self, ask_text):
        # ask_text is an (index, message) pair; the index travels with the
        # request so results can be re-ordered after concurrent execution.
        index, ask_text = ask_text
        completion = self.client.chat.completions.create(
            model=self.model_name,
            messages=[
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": ask_text}
            ],
            temperature=self.temperature,
        )
        return (index, completion.choices[0].message.content)

    def process_attitude(self, message_list):
        new_list = []
        num_workers = self.num_workers
        timeout_duration = self.timeout_duration
        retry_attempts = self.retry_attempts

        executor = ThreadPoolExecutor(max_workers=num_workers)
        message_chunks = list(self.chunk_list(message_list, num_workers))
        for chunk in tqdm(message_chunks, desc="Processing messages"):
            future_to_message = {executor.submit(self.get_attitude, message): message
                                 for message in chunk}
            for _ in range(retry_attempts):
                done, not_done = wait(future_to_message.keys(), timeout=timeout_duration)
                for future in not_done:
                    future.cancel()
                new_list.extend(future.result() for future in done if future.done())
                if len(not_done) == 0:
                    break
                # Resubmit the (index, message) pairs whose futures timed out.
                future_to_message = {executor.submit(self.get_attitude, future_to_message[future]): future_to_message[future]
                                     for future in not_done}
        executor.shutdown(wait=False)
        return new_list

    def complete_attitude_list(self, attitude_list, max_length):
        completed_list = []
        current_index = 0
        for item in attitude_list:
            index, value = item
            # Fill in missing indices
            while current_index < index:
                completed_list.append((current_index, None))
                current_index += 1
            # Add the current element from the list
            completed_list.append(item)
            current_index = index + 1
        # Pad the tail and remember which indices were never answered.
        while current_index < max_length:
            print("Filling in missing index", current_index)
            self.miss_index.append(current_index)
            completed_list.append((current_index, None))
            current_index += 1
        return completed_list

    def chunk_list(self, lst, n):
        """Yield successive n-sized chunks from lst."""
        for i in range(0, len(lst), n):
            yield lst[i:i + n]

    def handle_message_list(self, message_list):
        # Attach indices, fan out the requests, then restore the original
        # order and fill gaps with None.
        indexed_list = [(index, data) for index, data in enumerate(message_list)]
        max_length = len(indexed_list)
        attitude_list = self.process_attitude(indexed_list)
        attitude_list.sort(key=lambda x: x[0])
        attitude_list = self.complete_attitude_list(attitude_list, max_length)
        attitude_list = [x[1] for x in attitude_list]
        return attitude_list
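The index bookkeeping above can be exercised without any network traffic; a minimal sketch, assuming only the class as defined here (the placeholder key is an assumption — the OpenAI client is constructed but never called):

```python
from gpt_batch.batcher import GPTBatcher

# Constructing the batcher only stores the key; no request is sent until a
# chat completion is actually created, so a placeholder key suffices here.
batcher = GPTBatcher(api_key="sk-placeholder")

# Pretend indices 0 and 3 of a 5-message batch succeeded and the rest timed out.
partial = [(0, "agree"), (3, "disagree")]
completed = batcher.complete_attitude_list(partial, max_length=5)

print(completed)
# [(0, 'agree'), (1, None), (2, None), (3, 'disagree'), (4, None)]
print(batcher.miss_index)
# [4] -- only gaps at the tail of the batch are recorded as misses
```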
gpt_batch-0.1.0.dist-info/METADATA
ADDED
@@ -0,0 +1,30 @@
Metadata-Version: 2.1
Name: gpt-batch
Version: 0.1.0
Summary: A package for batch processing with the OpenAI API.
Home-page: https://github.com/fengsxy/gpt_batch
Author: Ted Yu
Author-email: liddlerain@gmail.com
Description-Content-Type: text/markdown
Requires-Dist: openai
Requires-Dist: tqdm

# GPT Batcher

A simple tool to batch process messages using OpenAI's GPT models.

## Installation

Clone this repository and run a local install (for example, `pip install .`).

## Usage

Here's how to use the `GPTBatcher`:

```python
from gpt_batch.batcher import GPTBatcher

batcher = GPTBatcher(api_key='your_key_here', model_name='gpt-3.5-turbo-1106')
result = batcher.handle_message_list(['your', 'list', 'of', 'messages'])
print(result)
```
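The constructor also exposes the batching knobs defined in `batcher.py`; a sketch with illustrative values (these numbers are assumptions, not documented defaults):

```python
from gpt_batch.batcher import GPTBatcher

batcher = GPTBatcher(
    api_key='your_key_here',
    model_name='gpt-3.5-turbo-1106',
    system_prompt='Answer each question in one short sentence.',
    temperature=0,        # favour reproducible answers
    num_workers=16,       # concurrent requests per chunk
    timeout_duration=30,  # seconds to wait for a chunk before retrying
    retry_attempts=2,     # wait/resubmit rounds per chunk
)
results = batcher.handle_message_list(['first question', 'second question'])

# Messages that never completed come back as None; gaps at the end of the
# batch are additionally recorded in batcher.miss_index.
failed = [i for i, r in enumerate(results) if r is None]
```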
gpt_batch-0.1.0.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
gpt_batch/__init__.py,sha256=pPeIKAbsqVso-jQOnhXWABziotoABCeDR0QrwyzPGv0,61
gpt_batch/batcher.py,sha256=kzL0OT8dpWUEOiKMs8BptqVGBrlLVcv4UWrWfEEihjg,3720
tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
tests/test_batcher.py,sha256=uq4GUAtlpofAFRMAUJOCSp5V-qBXpbVR26e_9k5h7-I,1117
gpt_batch-0.1.0.dist-info/METADATA,sha256=Ns-Dz7Z4cIZFfmp9DpAFpF6zLeaAFa3DWtsNWPv0Peg,707
gpt_batch-0.1.0.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
gpt_batch-0.1.0.dist-info/top_level.txt,sha256=FtvJB_L9W_S6jL4G8Em_YWphG1wdKAF20BHUrf4B0yM,16
gpt_batch-0.1.0.dist-info/RECORD,,
tests/__init__.py
ADDED
File without changes
tests/test_batcher.py
ADDED
@@ -0,0 +1,23 @@
import pytest
from gpt_batch import GPTBatcher
import os

def test_handle_message_list():
    # Initialize the GPTBatcher with credentials taken from the environment
    api_key = os.getenv('TEST_KEY')
    if not api_key:
        raise ValueError("API key must be set in the environment variables")
    batcher = GPTBatcher(api_key=api_key, model_name='gpt-3.5-turbo-1106', system_prompt="Your task is to discuss privacy questions and provide persuasive answers with supporting reasons.")
    message_list = ["I think privacy is important", "I don't think privacy is important"]

    # Call the method under test
    results = batcher.handle_message_list(message_list)

    # Verify the number of results and that each answer has some substance
    assert len(results) == 2, "There should be two results, one for each message"
    assert all(len(result) >= 2 for result in results), "Each result should be at least two characters long"

if __name__ == "__main__":
    pytest.main()
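The test above needs a live `TEST_KEY`. A network-free variant is possible by stubbing out `get_attitude`; the sketch below is an assumption about how one might do that with pytest's `monkeypatch` fixture, not part of the package:

```python
from gpt_batch import GPTBatcher

def test_handle_message_list_offline(monkeypatch):
    # A placeholder key is enough: the OpenAI client is constructed but never called.
    batcher = GPTBatcher(api_key="sk-placeholder")

    # Replace the network call with a deterministic stub that keeps the
    # (index, message) contract of the real get_attitude.
    def fake_get_attitude(ask_text):
        index, text = ask_text
        return (index, f"echo: {text}")

    monkeypatch.setattr(batcher, "get_attitude", fake_get_attitude)

    results = batcher.handle_message_list(["a", "b", "c"])
    assert results == ["echo: a", "echo: b", "echo: c"]
```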