gpt-batch 0.1.5__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gpt_batch/batcher.py CHANGED
@@ -1,63 +1,83 @@
1
1
  from openai import OpenAI
2
+ import anthropic
2
3
  from concurrent.futures import ThreadPoolExecutor, wait
3
4
  from functools import partial
4
5
  from tqdm import tqdm
6
+ import re
5
7
 
6
8
  class GPTBatcher:
7
9
  """
8
- A class to handle batching and sending requests to the OpenAI GPT model efficiently.
10
+ A class to handle batching and sending requests to the OpenAI GPT model and Anthropic Claude models efficiently.
9
11
 
10
12
  Attributes:
11
- client (OpenAI): The client instance to communicate with the OpenAI API using the provided API key.
12
- model_name (str): The name of the GPT model to be used. Default is 'gpt-3.5-turbo-0125'.
13
+ client: The client instance to communicate with the API (OpenAI or Anthropic).
14
+ is_claude (bool): Flag to indicate if using a Claude model.
15
+ model_name (str): The name of the model to be used. Default is 'gpt-3.5-turbo-0125'.
13
16
  system_prompt (str): Initial prompt or context to be used with the model. Default is an empty string.
14
17
  temperature (float): Controls the randomness of the model's responses. Higher values lead to more diverse outputs. Default is 1.
15
18
  num_workers (int): Number of worker threads used for handling concurrent requests. Default is 64.
16
19
  timeout_duration (int): Maximum time (in seconds) to wait for a response from the API before timing out. Default is 60 seconds.
17
20
  retry_attempts (int): Number of retries if a request fails. Default is 2.
18
21
  miss_index (list): Tracks the indices of requests that failed to process correctly.
19
-
20
- Parameters:
21
- api_key (str): API key for authenticating requests to the OpenAI API.
22
- model_name (str, optional): Specifies the GPT model version. Default is 'gpt-3.5-turbo-0125'.
23
- system_prompt (str, optional): Initial text or question to seed the model with. Default is empty.
24
- temperature (float, optional): Sets the creativity of the responses. Default is 1.
25
- num_workers (int, optional): Number of parallel workers for request handling. Default is 64.
26
- timeout_duration (int, optional): Timeout for API responses in seconds. Default is 60.
27
- retry_attempts (int, optional): How many times to retry a failed request. Default is 2.
28
22
  """
29
23
 
30
24
  def __init__(self, api_key, model_name="gpt-3.5-turbo-0125", system_prompt="",temperature=1,num_workers=64,timeout_duration=60,retry_attempts=2,api_base_url=None):
31
25
 
32
- self.client = OpenAI(api_key=api_key)
26
+ self.is_claude = bool(re.search(r'claude', model_name, re.IGNORECASE))
27
+
28
+ if self.is_claude:
29
+ self.client = anthropic.Anthropic(api_key=api_key)
30
+ # Anthropic doesn't support custom base URL the same way
31
+ # If needed, this could be implemented differently
32
+ else:
33
+ self.client = OpenAI(api_key=api_key)
34
+ if api_base_url:
35
+ self.client.base_url = api_base_url
36
+
33
37
  self.model_name = model_name
34
38
  self.system_prompt = system_prompt
35
39
  self.temperature = temperature
36
40
  self.num_workers = num_workers
37
41
  self.timeout_duration = timeout_duration
38
42
  self.retry_attempts = retry_attempts
39
- self.miss_index =[]
40
- if api_base_url:
41
- self.client.base_url = api_base_url
43
+ self.miss_index = []
42
44
 
43
45
  def get_attitude(self, ask_text):
44
46
  index, ask_text = ask_text
45
-
46
- completion = self.client.chat.completions.create(
47
- model=self.model_name,
48
- messages=[
49
- {"role": "system", "content": self.system_prompt},
50
- {"role": "user", "content": ask_text}
51
- ],
52
- temperature=self.temperature,
53
- )
54
- return (index, completion.choices[0].message.content)
47
+ try:
48
+ if self.is_claude:
49
+ # Use the Anthropic Claude API
50
+ message = self.client.messages.create(
51
+ model=self.model_name,
52
+ max_tokens=1024, # You can make this configurable if needed
53
+ messages=[
54
+ {"role": "user", "content": ask_text}
55
+ ],
56
+ system=self.system_prompt if self.system_prompt else None,
57
+ temperature=self.temperature,
58
+ )
59
+ return (index, message.content[0].text)
60
+ else:
61
+ # Use the OpenAI API as before
62
+ completion = self.client.chat.completions.create(
63
+ model=self.model_name,
64
+ messages=[
65
+ {"role": "system", "content": self.system_prompt},
66
+ {"role": "user", "content": ask_text}
67
+ ],
68
+ temperature=self.temperature,
69
+ )
70
+ return (index, completion.choices[0].message.content)
71
+ except Exception as e:
72
+ print(f"Error occurred: {e}")
73
+ self.miss_index.append(index)
74
+ return (index, None)
55
75
 
56
76
  def process_attitude(self, message_list):
57
77
  new_list = []
58
78
  num_workers = self.num_workers
59
79
  timeout_duration = self.timeout_duration
60
- retry_attempts = 2
80
+ retry_attempts = self.retry_attempts
61
81
 
62
82
  executor = ThreadPoolExecutor(max_workers=num_workers)
63
83
  message_chunks = list(self.chunk_list(message_list, num_workers))
@@ -71,14 +91,14 @@ class GPTBatcher:
71
91
  new_list.extend(future.result() for future in done if future.done())
72
92
  if len(not_done) == 0:
73
93
  break
74
- future_to_message = {executor.submit(self.get_attitude, future_to_message[future]): future for future in not_done}
94
+ future_to_message = {executor.submit(self.get_attitude, future_to_message[future]): future_to_message[future] for future in not_done}
75
95
  except Exception as e:
76
96
  print(f"Error occurred: {e}")
77
97
  finally:
78
98
  executor.shutdown(wait=False)
79
99
  return new_list
80
100
 
81
- def complete_attitude_list(self,attitude_list, max_length):
101
+ def complete_attitude_list(self, attitude_list, max_length):
82
102
  completed_list = []
83
103
  current_index = 0
84
104
  for item in attitude_list:
@@ -102,7 +122,7 @@ class GPTBatcher:
102
122
  for i in range(0, len(lst), n):
103
123
  yield lst[i:i + n]
104
124
 
105
- def handle_message_list(self,message_list):
125
+ def handle_message_list(self, message_list):
106
126
  indexed_list = [(index, data) for index, data in enumerate(message_list)]
107
127
  max_length = len(indexed_list)
108
128
  attitude_list = self.process_attitude(indexed_list)
@@ -111,32 +131,50 @@ class GPTBatcher:
111
131
  attitude_list = [x[1] for x in attitude_list]
112
132
  return attitude_list
113
133
 
114
- def process_embedding(self,message_list):
115
- new_list = []
116
- executor = ThreadPoolExecutor(max_workers=self.num_workers)
117
- # Split message_list into chunks
118
- message_chunks = list(self.chunk_list(message_list, self.num_workers))
119
- fixed_get_embedding = partial(self.get_embedding)
120
- for chunk in tqdm(message_chunks, desc="Processing messages"):
121
- future_to_message = {executor.submit(fixed_get_embedding, message): message for message in chunk}
122
- for i in range(self.retry_attempts):
123
- done, not_done = wait(future_to_message.keys(), timeout=self.timeout_duration)
124
- for future in not_done:
125
- future.cancel()
126
- new_list.extend(future.result() for future in done if future.done())
127
- if len(not_done) == 0:
128
- break
129
- future_to_message = {executor.submit(fixed_get_embedding, future_to_message[future]): future_to_message[future] for future in not_done}
130
- executor.shutdown(wait=False)
131
- return new_list
132
- def get_embedding(self,text):
133
- index,text = text
134
- response = self.client.embeddings.create(
135
- input=text,
136
- model=self.model_name)
137
- return (index,response.data[0].embedding)
134
+ def get_embedding(self, text):
135
+ index, text = text
136
+ try:
137
+ if self.is_claude:
138
+ # Use Anthropic's embedding API if available
139
+ # Note: As of March 2025, make sure to check Anthropic's latest API
140
+ # for embeddings, as the format might have changed
141
+ response = self.client.embeddings.create(
142
+ model=self.model_name,
143
+ input=text
144
+ )
145
+ return (index, response.embedding)
146
+ else:
147
+ # Use OpenAI's embedding API
148
+ response = self.client.embeddings.create(
149
+ input=text,
150
+ model=self.model_name
151
+ )
152
+ return (index, response.data[0].embedding)
153
+ except Exception as e:
154
+ print(f"Error getting embedding: {e}")
155
+ self.miss_index.append(index)
156
+ return (index, None)
157
+
158
+ def process_embedding(self, message_list):
159
+ new_list = []
160
+ executor = ThreadPoolExecutor(max_workers=self.num_workers)
161
+ # Split message_list into chunks
162
+ message_chunks = list(self.chunk_list(message_list, self.num_workers))
163
+ fixed_get_embedding = partial(self.get_embedding)
164
+ for chunk in tqdm(message_chunks, desc="Processing messages"):
165
+ future_to_message = {executor.submit(fixed_get_embedding, message): message for message in chunk}
166
+ for i in range(self.retry_attempts):
167
+ done, not_done = wait(future_to_message.keys(), timeout=self.timeout_duration)
168
+ for future in not_done:
169
+ future.cancel()
170
+ new_list.extend(future.result() for future in done if future.done())
171
+ if len(not_done) == 0:
172
+ break
173
+ future_to_message = {executor.submit(fixed_get_embedding, future_to_message[future]): future_to_message[future] for future in not_done}
174
+ executor.shutdown(wait=False)
175
+ return new_list
138
176
 
139
- def handle_embedding_list(self,message_list):
177
+ def handle_embedding_list(self, message_list):
140
178
  indexed_list = [(index, data) for index, data in enumerate(message_list)]
141
179
  max_length = len(indexed_list)
142
180
  attitude_list = self.process_embedding(indexed_list)
@@ -148,5 +186,16 @@ class GPTBatcher:
148
186
  def get_miss_index(self):
149
187
  return self.miss_index
150
188
 
151
- # Add other necessary methods similar to the above, refactored to fit within this class structure.
152
-
189
+ # Example usage:
190
+ if __name__ == "__main__":
191
+ # For OpenAI
192
+ openai_batcher = GPTBatcher(
193
+ api_key="your_openai_api_key",
194
+ model_name="gpt-4-turbo"
195
+ )
196
+
197
+ # For Claude
198
+ claude_batcher = GPTBatcher(
199
+ api_key="your_anthropic_api_key",
200
+ model_name="claude-3-7-sonnet-20250219"
201
+ )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: gpt-batch
3
- Version: 0.1.5
3
+ Version: 0.1.8
4
4
  Summary: A package for batch processing with OpenAI API.
5
5
  Home-page: https://github.com/fengsxy/gpt_batch
6
6
  Author: Ted Yu
@@ -10,8 +10,8 @@ Platform: UNKNOWN
10
10
  Description-Content-Type: text/markdown
11
11
  Requires-Dist: openai
12
12
  Requires-Dist: tqdm
13
+ Requires-Dist: anthropic
13
14
 
14
- Certainly! Here's a clean and comprehensive README for your `GPTBatcher` tool, formatted in Markdown:
15
15
 
16
16
  ```markdown
17
17
  # GPT Batcher
@@ -62,6 +62,22 @@ print(result)
62
62
  # Expected output: ["embedding_1", "embedding_2", "embedding_3", "embedding_4"]
63
63
  ```
64
64
 
65
 + ### Handling Message Lists with a Different API
66
+
67
 + This example demonstrates how to send a list of questions and receive answers using a different API provider:
68
+
69
+ ```python
70
+ from gpt_batch.batcher import GPTBatcher
71
+
72
+ # Initialize the batcher
73
+ batcher = GPTBatcher(api_key='sk-', model_name='deepseek-chat',api_base_url="https://api.deepseek.com/v1")
74
+
75
+
76
+ # Send a list of messages and receive answers
77
+ result = batcher.handle_message_list(['question_1', 'question_2', 'question_3', 'question_4'])
78
+
79
+ # Expected output: ["answer_1", "answer_2", "answer_3", "answer_4"]
80
+ ```
65
81
  ## Configuration
66
82
 
67
83
  The `GPTBatcher` class can be customized with several parameters to adjust its performance and behavior:
@@ -0,0 +1,8 @@
1
+ gpt_batch/__init__.py,sha256=zGDItktTxKLSQr44GY78dl5LKsSJig0Q59dzusqhU0U,59
2
+ gpt_batch/batcher.py,sha256=sruKqK_tY5WNMXvejNz_OAw6AHtJIqcvDKbMbX1rX_M,8959
3
+ tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
+ tests/test_batcher.py,sha256=FAlV_T5dT2OcM28yWPWWU9yIN0D7SBeYP1oTkvvzXKk,4077
5
+ gpt_batch-0.1.8.dist-info/METADATA,sha256=IIBetdVEVrUx30cxrzB4ZkAy0mGp3PVueJc9tm2WcZA,3401
6
+ gpt_batch-0.1.8.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
7
+ gpt_batch-0.1.8.dist-info/top_level.txt,sha256=FtvJB_L9W_S6jL4G8Em_YWphG1wdKAF20BHUrf4B0yM,16
8
+ gpt_batch-0.1.8.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: bdist_wheel (0.43.0)
2
+ Generator: bdist_wheel (0.45.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
tests/test_batcher.py CHANGED
@@ -51,5 +51,27 @@ def test_get_miss_index():
51
51
  miss_index = batcher.get_miss_index()
52
52
  assert miss_index == [], "The miss index should be empty"
53
53
  # Optionally, you can add a test configuration if you have specific needs
54
+
55
+
56
+ def test_claude_handle_message_list():
57
+ # Initialize the GPTBatcher with Claude model
58
+ api_key = os.getenv('ANTHROPIC_API_KEY')
59
+ if not api_key:
60
+ raise ValueError("Anthropic API key must be set in the environment variables as ANTHROPIC_API_KEY")
61
+
62
+ batcher = GPTBatcher(
63
+ api_key=api_key,
64
+ model_name='claude-3-7-sonnet-20250219',
65
+ system_prompt="Your task is to discuss privacy questions and provide persuasive answers with supporting reasons."
66
+ )
67
+ message_list = ["I think privacy is important", "I don't think privacy is important"]
68
+
69
+ # Call the method under test
70
+ results = batcher.handle_message_list(message_list)
71
+
72
+ # Assertions to verify the length of the results and the structure of each item
73
+ assert len(results) == 2, "There should be two results, one for each message"
74
+ assert all(isinstance(result, str) and len(result) > 0 for result in results if result is not None), "Each result should be a non-empty string if not None"
75
+ assert batcher.is_claude, "Should recognize model as Claude"
54
76
  if __name__ == "__main__":
55
77
  pytest.main()
@@ -1,8 +0,0 @@
1
- gpt_batch/__init__.py,sha256=zGDItktTxKLSQr44GY78dl5LKsSJig0Q59dzusqhU0U,59
2
- gpt_batch/batcher.py,sha256=YvOX1V_9iX5jTX7xZOhjOWeT0IUlv9c-_UkLW2s1wFo,7395
3
- tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- tests/test_batcher.py,sha256=N88RZrSuBaDti6Lry7xipyGXHn3jKg85O12mjcHHZA0,3006
5
- gpt_batch-0.1.5.dist-info/METADATA,sha256=kb524fTeHxmmYZM6MnhG_G_gqwLRy4ptdVmjIcaIdac,2932
6
- gpt_batch-0.1.5.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
7
- gpt_batch-0.1.5.dist-info/top_level.txt,sha256=FtvJB_L9W_S6jL4G8Em_YWphG1wdKAF20BHUrf4B0yM,16
8
- gpt_batch-0.1.5.dist-info/RECORD,,