gpt-batch 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gpt_batch/batcher.py +101 -56
- {gpt_batch-0.1.6.dist-info → gpt_batch-0.1.8.dist-info}/METADATA +18 -2
- gpt_batch-0.1.8.dist-info/RECORD +8 -0
- {gpt_batch-0.1.6.dist-info → gpt_batch-0.1.8.dist-info}/WHEEL +1 -1
- tests/test_batcher.py +22 -0
- gpt_batch-0.1.6.dist-info/RECORD +0 -8
- {gpt_batch-0.1.6.dist-info → gpt_batch-0.1.8.dist-info}/top_level.txt +0 -0
gpt_batch/batcher.py
CHANGED
@@ -1,57 +1,73 @@
|
|
1
1
|
from openai import OpenAI
|
2
|
+
import anthropic
|
2
3
|
from concurrent.futures import ThreadPoolExecutor, wait
|
3
4
|
from functools import partial
|
4
5
|
from tqdm import tqdm
|
6
|
+
import re
|
5
7
|
|
6
8
|
class GPTBatcher:
|
7
9
|
"""
|
8
|
-
A class to handle batching and sending requests to the OpenAI GPT model efficiently.
|
10
|
+
A class to handle batching and sending requests to the OpenAI GPT model and Anthropic Claude models efficiently.
|
9
11
|
|
10
12
|
Attributes:
|
11
|
-
client
|
12
|
-
|
13
|
+
client: The client instance to communicate with the API (OpenAI or Anthropic).
|
14
|
+
is_claude (bool): Flag to indicate if using a Claude model.
|
15
|
+
model_name (str): The name of the model to be used. Default is 'gpt-3.5-turbo-0125'.
|
13
16
|
system_prompt (str): Initial prompt or context to be used with the model. Default is an empty string.
|
14
17
|
temperature (float): Controls the randomness of the model's responses. Higher values lead to more diverse outputs. Default is 1.
|
15
18
|
num_workers (int): Number of worker threads used for handling concurrent requests. Default is 64.
|
16
19
|
timeout_duration (int): Maximum time (in seconds) to wait for a response from the API before timing out. Default is 60 seconds.
|
17
20
|
retry_attempts (int): Number of retries if a request fails. Default is 2.
|
18
21
|
miss_index (list): Tracks the indices of requests that failed to process correctly.
|
19
|
-
|
20
|
-
Parameters:
|
21
|
-
api_key (str): API key for authenticating requests to the OpenAI API.
|
22
|
-
model_name (str, optional): Specifies the GPT model version. Default is 'gpt-3.5-turbo-0125'.
|
23
|
-
system_prompt (str, optional): Initial text or question to seed the model with. Default is empty.
|
24
|
-
temperature (float, optional): Sets the creativity of the responses. Default is 1.
|
25
|
-
num_workers (int, optional): Number of parallel workers for request handling. Default is 64.
|
26
|
-
timeout_duration (int, optional): Timeout for API responses in seconds. Default is 60.
|
27
|
-
retry_attempts (int, optional): How many times to retry a failed request. Default is 2.
|
28
22
|
"""
|
29
23
|
|
30
24
|
def __init__(self, api_key, model_name="gpt-3.5-turbo-0125", system_prompt="",temperature=1,num_workers=64,timeout_duration=60,retry_attempts=2,api_base_url=None):
|
31
25
|
|
32
|
-
self.
|
26
|
+
self.is_claude = bool(re.search(r'claude', model_name, re.IGNORECASE))
|
27
|
+
|
28
|
+
if self.is_claude:
|
29
|
+
self.client = anthropic.Anthropic(api_key=api_key)
|
30
|
+
# Anthropic doesn't support custom base URL the same way
|
31
|
+
# If needed, this could be implemented differently
|
32
|
+
else:
|
33
|
+
self.client = OpenAI(api_key=api_key)
|
34
|
+
if api_base_url:
|
35
|
+
self.client.base_url = api_base_url
|
36
|
+
|
33
37
|
self.model_name = model_name
|
34
38
|
self.system_prompt = system_prompt
|
35
39
|
self.temperature = temperature
|
36
40
|
self.num_workers = num_workers
|
37
41
|
self.timeout_duration = timeout_duration
|
38
42
|
self.retry_attempts = retry_attempts
|
39
|
-
self.miss_index =[]
|
40
|
-
if api_base_url:
|
41
|
-
self.client.base_url = api_base_url
|
43
|
+
self.miss_index = []
|
42
44
|
|
43
45
|
def get_attitude(self, ask_text):
|
44
46
|
index, ask_text = ask_text
|
45
47
|
try:
|
46
|
-
|
47
|
-
|
48
|
-
messages
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
48
|
+
if self.is_claude:
|
49
|
+
# Use the Anthropic Claude API
|
50
|
+
message = self.client.messages.create(
|
51
|
+
model=self.model_name,
|
52
|
+
max_tokens=1024, # You can make this configurable if needed
|
53
|
+
messages=[
|
54
|
+
{"role": "user", "content": ask_text}
|
55
|
+
],
|
56
|
+
system=self.system_prompt if self.system_prompt else None,
|
57
|
+
temperature=self.temperature,
|
58
|
+
)
|
59
|
+
return (index, message.content[0].text)
|
60
|
+
else:
|
61
|
+
# Use the OpenAI API as before
|
62
|
+
completion = self.client.chat.completions.create(
|
63
|
+
model=self.model_name,
|
64
|
+
messages=[
|
65
|
+
{"role": "system", "content": self.system_prompt},
|
66
|
+
{"role": "user", "content": ask_text}
|
67
|
+
],
|
68
|
+
temperature=self.temperature,
|
69
|
+
)
|
70
|
+
return (index, completion.choices[0].message.content)
|
55
71
|
except Exception as e:
|
56
72
|
print(f"Error occurred: {e}")
|
57
73
|
self.miss_index.append(index)
|
@@ -61,7 +77,7 @@ class GPTBatcher:
|
|
61
77
|
new_list = []
|
62
78
|
num_workers = self.num_workers
|
63
79
|
timeout_duration = self.timeout_duration
|
64
|
-
retry_attempts =
|
80
|
+
retry_attempts = self.retry_attempts
|
65
81
|
|
66
82
|
executor = ThreadPoolExecutor(max_workers=num_workers)
|
67
83
|
message_chunks = list(self.chunk_list(message_list, num_workers))
|
@@ -75,14 +91,14 @@ class GPTBatcher:
|
|
75
91
|
new_list.extend(future.result() for future in done if future.done())
|
76
92
|
if len(not_done) == 0:
|
77
93
|
break
|
78
|
-
future_to_message = {executor.submit(self.get_attitude, future_to_message[future]): future for future in not_done}
|
94
|
+
future_to_message = {executor.submit(self.get_attitude, future_to_message[future]): future_to_message[future] for future in not_done}
|
79
95
|
except Exception as e:
|
80
96
|
print(f"Error occurred: {e}")
|
81
97
|
finally:
|
82
98
|
executor.shutdown(wait=False)
|
83
99
|
return new_list
|
84
100
|
|
85
|
-
def complete_attitude_list(self,attitude_list, max_length):
|
101
|
+
def complete_attitude_list(self, attitude_list, max_length):
|
86
102
|
completed_list = []
|
87
103
|
current_index = 0
|
88
104
|
for item in attitude_list:
|
@@ -106,7 +122,7 @@ class GPTBatcher:
|
|
106
122
|
for i in range(0, len(lst), n):
|
107
123
|
yield lst[i:i + n]
|
108
124
|
|
109
|
-
def handle_message_list(self,message_list):
|
125
|
+
def handle_message_list(self, message_list):
|
110
126
|
indexed_list = [(index, data) for index, data in enumerate(message_list)]
|
111
127
|
max_length = len(indexed_list)
|
112
128
|
attitude_list = self.process_attitude(indexed_list)
|
@@ -115,32 +131,50 @@ class GPTBatcher:
|
|
115
131
|
attitude_list = [x[1] for x in attitude_list]
|
116
132
|
return attitude_list
|
117
133
|
|
118
|
-
def
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
134
|
+
def get_embedding(self, text):
|
135
|
+
index, text = text
|
136
|
+
try:
|
137
|
+
if self.is_claude:
|
138
|
+
# Use Anthropic's embedding API if available
|
139
|
+
# Note: As of March 2025, make sure to check Anthropic's latest API
|
140
|
+
# for embeddings, as the format might have changed
|
141
|
+
response = self.client.embeddings.create(
|
142
|
+
model=self.model_name,
|
143
|
+
input=text
|
144
|
+
)
|
145
|
+
return (index, response.embedding)
|
146
|
+
else:
|
147
|
+
# Use OpenAI's embedding API
|
148
|
+
response = self.client.embeddings.create(
|
149
|
+
input=text,
|
150
|
+
model=self.model_name
|
151
|
+
)
|
152
|
+
return (index, response.data[0].embedding)
|
153
|
+
except Exception as e:
|
154
|
+
print(f"Error getting embedding: {e}")
|
155
|
+
self.miss_index.append(index)
|
156
|
+
return (index, None)
|
157
|
+
|
158
|
+
def process_embedding(self, message_list):
|
159
|
+
new_list = []
|
160
|
+
executor = ThreadPoolExecutor(max_workers=self.num_workers)
|
161
|
+
# Split message_list into chunks
|
162
|
+
message_chunks = list(self.chunk_list(message_list, self.num_workers))
|
163
|
+
fixed_get_embedding = partial(self.get_embedding)
|
164
|
+
for chunk in tqdm(message_chunks, desc="Processing messages"):
|
165
|
+
future_to_message = {executor.submit(fixed_get_embedding, message): message for message in chunk}
|
166
|
+
for i in range(self.retry_attempts):
|
167
|
+
done, not_done = wait(future_to_message.keys(), timeout=self.timeout_duration)
|
168
|
+
for future in not_done:
|
169
|
+
future.cancel()
|
170
|
+
new_list.extend(future.result() for future in done if future.done())
|
171
|
+
if len(not_done) == 0:
|
172
|
+
break
|
173
|
+
future_to_message = {executor.submit(fixed_get_embedding, future_to_message[future]): future_to_message[future] for future in not_done}
|
174
|
+
executor.shutdown(wait=False)
|
175
|
+
return new_list
|
142
176
|
|
143
|
-
def handle_embedding_list(self,message_list):
|
177
|
+
def handle_embedding_list(self, message_list):
|
144
178
|
indexed_list = [(index, data) for index, data in enumerate(message_list)]
|
145
179
|
max_length = len(indexed_list)
|
146
180
|
attitude_list = self.process_embedding(indexed_list)
|
@@ -152,5 +186,16 @@ class GPTBatcher:
|
|
152
186
|
def get_miss_index(self):
|
153
187
|
return self.miss_index
|
154
188
|
|
155
|
-
|
156
|
-
|
189
|
+
# Example usage:
|
190
|
+
if __name__ == "__main__":
|
191
|
+
# For OpenAI
|
192
|
+
openai_batcher = GPTBatcher(
|
193
|
+
api_key="your_openai_api_key",
|
194
|
+
model_name="gpt-4-turbo"
|
195
|
+
)
|
196
|
+
|
197
|
+
# For Claude
|
198
|
+
claude_batcher = GPTBatcher(
|
199
|
+
api_key="your_anthropic_api_key",
|
200
|
+
model_name="claude-3-7-sonnet-20250219"
|
201
|
+
)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: gpt-batch
|
3
|
-
Version: 0.1.
|
3
|
+
Version: 0.1.8
|
4
4
|
Summary: A package for batch processing with OpenAI API.
|
5
5
|
Home-page: https://github.com/fengsxy/gpt_batch
|
6
6
|
Author: Ted Yu
|
@@ -10,8 +10,8 @@ Platform: UNKNOWN
|
|
10
10
|
Description-Content-Type: text/markdown
|
11
11
|
Requires-Dist: openai
|
12
12
|
Requires-Dist: tqdm
|
13
|
+
Requires-Dist: anthropic
|
13
14
|
|
14
|
-
Certainly! Here's a clean and comprehensive README for your `GPTBatcher` tool, formatted in Markdown:
|
15
15
|
|
16
16
|
```markdown
|
17
17
|
# GPT Batcher
|
@@ -62,6 +62,22 @@ print(result)
|
|
62
62
|
# Expected output: ["embedding_1", "embedding_2", "embedding_3", "embedding_4"]
|
63
63
|
```
|
64
64
|
|
65
|
+
### Handling Message Lists with a Different API
|
66
|
+
|
67
|
+
This example demonstrates how to send a list of questions and receive answers from a different API endpoint by passing `api_base_url`:
|
68
|
+
|
69
|
+
```python
|
70
|
+
from gpt_batch.batcher import GPTBatcher
|
71
|
+
|
72
|
+
# Initialize the batcher
|
73
|
+
batcher = GPTBatcher(api_key='sk-', model_name='deepseek-chat',api_base_url="https://api.deepseek.com/v1")
|
74
|
+
|
75
|
+
|
76
|
+
# Send a list of messages and receive answers
|
77
|
+
result = batcher.handle_message_list(['question_1', 'question_2', 'question_3', 'question_4'])
|
78
|
+
|
79
|
+
# Expected output: ["answer_1", "answer_2", "answer_3", "answer_4"]
|
80
|
+
```
|
65
81
|
## Configuration
|
66
82
|
|
67
83
|
The `GPTBatcher` class can be customized with several parameters to adjust its performance and behavior:
|
@@ -0,0 +1,8 @@
|
|
1
|
+
gpt_batch/__init__.py,sha256=zGDItktTxKLSQr44GY78dl5LKsSJig0Q59dzusqhU0U,59
|
2
|
+
gpt_batch/batcher.py,sha256=sruKqK_tY5WNMXvejNz_OAw6AHtJIqcvDKbMbX1rX_M,8959
|
3
|
+
tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
|
+
tests/test_batcher.py,sha256=FAlV_T5dT2OcM28yWPWWU9yIN0D7SBeYP1oTkvvzXKk,4077
|
5
|
+
gpt_batch-0.1.8.dist-info/METADATA,sha256=IIBetdVEVrUx30cxrzB4ZkAy0mGp3PVueJc9tm2WcZA,3401
|
6
|
+
gpt_batch-0.1.8.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
7
|
+
gpt_batch-0.1.8.dist-info/top_level.txt,sha256=FtvJB_L9W_S6jL4G8Em_YWphG1wdKAF20BHUrf4B0yM,16
|
8
|
+
gpt_batch-0.1.8.dist-info/RECORD,,
|
tests/test_batcher.py
CHANGED
@@ -51,5 +51,27 @@ def test_get_miss_index():
|
|
51
51
|
miss_index = batcher.get_miss_index()
|
52
52
|
assert miss_index == [], "The miss index should be empty"
|
53
53
|
# Optionally, you can add a test configuration if you have specific needs
|
54
|
+
|
55
|
+
|
56
|
+
def test_claude_handle_message_list():
|
57
|
+
# Initialize the GPTBatcher with Claude model
|
58
|
+
api_key = os.getenv('ANTHROPIC_API_KEY')
|
59
|
+
if not api_key:
|
60
|
+
raise ValueError("Anthropic API key must be set in the environment variables as ANTHROPIC_API_KEY")
|
61
|
+
|
62
|
+
batcher = GPTBatcher(
|
63
|
+
api_key=api_key,
|
64
|
+
model_name='claude-3-7-sonnet-20250219',
|
65
|
+
system_prompt="Your task is to discuss privacy questions and provide persuasive answers with supporting reasons."
|
66
|
+
)
|
67
|
+
message_list = ["I think privacy is important", "I don't think privacy is important"]
|
68
|
+
|
69
|
+
# Call the method under test
|
70
|
+
results = batcher.handle_message_list(message_list)
|
71
|
+
|
72
|
+
# Assertions to verify the length of the results and the structure of each item
|
73
|
+
assert len(results) == 2, "There should be two results, one for each message"
|
74
|
+
assert all(isinstance(result, str) and len(result) > 0 for result in results if result is not None), "Each result should be a non-empty string if not None"
|
75
|
+
assert batcher.is_claude, "Should recognize model as Claude"
|
54
76
|
if __name__ == "__main__":
|
55
77
|
pytest.main()
|
gpt_batch-0.1.6.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
|
|
1
|
-
gpt_batch/__init__.py,sha256=zGDItktTxKLSQr44GY78dl5LKsSJig0Q59dzusqhU0U,59
|
2
|
-
gpt_batch/batcher.py,sha256=jKLK-iuByg3Mc2ZungT5aZYzO60c5yO-YXCOf_70O6w,7591
|
3
|
-
tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
|
-
tests/test_batcher.py,sha256=N88RZrSuBaDti6Lry7xipyGXHn3jKg85O12mjcHHZA0,3006
|
5
|
-
gpt_batch-0.1.6.dist-info/METADATA,sha256=Q0EhkVe8YbKac3JjhASu3_wY3y9hV_YJqqwVEzlf9wc,2932
|
6
|
-
gpt_batch-0.1.6.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
7
|
-
gpt_batch-0.1.6.dist-info/top_level.txt,sha256=FtvJB_L9W_S6jL4G8Em_YWphG1wdKAF20BHUrf4B0yM,16
|
8
|
-
gpt_batch-0.1.6.dist-info/RECORD,,
|
File without changes
|