gemba 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
gemba/gpt_api.py
CHANGED
@@ -1,12 +1,13 @@
 import os
 import sys
 import time
-import ipdb
 import logging
 from termcolor import colored
-from datetime import datetime
 import openai
 import tqdm
+from concurrent.futures import ThreadPoolExecutor
+from collections import defaultdict
+


 # class for calling OpenAI API and handling cache
@@ -58,6 +59,22 @@ class GptApi:
         for full_answer in answers:
             finish_reason = full_answer["finish_reason"]
             full_answer = full_answer["answer"]
+
+            if finish_reason != "stop":
+                print(f"No valid answer, giving score 0")
+                errors = defaultdict(list)
+                errors["critical"].append("Judge errored, giving answer score 0.")
+                parsed_answers.append({
+                    "temperature": temperature,
+                    "answer_id": answer_id,
+                    "answer": 0,
+                    "errors": errors,
+                    "prompt": prompt,
+                    "finish_reason": finish_reason,
+                    "model": model,
+                })
+                continue
+
             answer_id += 1
             answer = parse_response(full_answer)
             if isinstance(answer, tuple):
@@ -68,33 +85,32 @@ class GptApi:
             print(f"Answer (t={temperature}): " + colored(answer, "yellow") + " (" + colored(full_answer, "blue") + ")", file=sys.stderr)
             if answer is None:
                 continue
-            parsed_answers.append(
-                {
-                    ...
-                }
-            )
+            parsed_answers.append({
+                "temperature": temperature,
+                "answer_id": answer_id,
+                "answer": answer,
+                "errors": errors,
+                "prompt": prompt,
+                "finish_reason": finish_reason,
+                "model": model,
+            })

         # there was no valid answer, increase temperature and try again
         if len(parsed_answers) == 0:
+            print(f"No valid answer, increasing temperature to {temperature + 1} and trying again")
             return self.request(prompt, model, parse_response, temperature=temperature + 1, answer_id=answer_id, cache=cache)

         return parsed_answers

     def request_api(self, prompt, model, temperature=0, max_tokens=None):
         if temperature > 10:
-            return []
+            return [{"answer": None, "finish_reason": "error"}]

         # Add maximum token limit
         MAX_TOKENS_LIMIT = 4000  # Adjust this based on your model's context window
         if max_tokens and max_tokens > MAX_TOKENS_LIMIT:
             print(f"Reached maximum token limit of {MAX_TOKENS_LIMIT}", file=sys.stderr)
-            return []
+            return [{"answer": None, "finish_reason": "length"}]

         while True:
             try:
@@ -104,10 +120,10 @@ class GptApi:
                 # response was filtered
                 if hasattr(e, 'code'):
                     if e.code == 'content_filter':
-                        return []
+                        return [{"answer": None, "finish_reason": "filter"}]
                     print(e.code, file=sys.stderr)
                 if hasattr(e, 'error') and e.error['code'] == 'invalid_model_output':
-                    return []
+                    return [{"answer": None, "finish_reason": "invalid"}]

                 # frequent error is reaching the API limit
                 print(colored("Error, retrying...", "red"), file=sys.stderr)
@@ -117,7 +133,7 @@ class GptApi:
         answers = []
         for choice in response.choices:
             if choice.message.content is None:
-                return []
+                return [{"answer": None, "finish_reason": "invalid"}]
             if hasattr(choice, "message"):
                 answer = choice.message.content.strip()
             else:
@@ -127,13 +143,13 @@ class GptApi:
             if choice.finish_reason != "stop":
                 if self.verbose:
                     print(colored(f"Increasing max tokens to fit answers.", "red") + colored(answer, "blue"), file=sys.stderr)
-                print(f"Finish reason: {choice.finish_reason}", file=sys.stderr)
                 if max_tokens is None:
                     max_tokens = 500  # Set initial max_tokens if None
-                new_max_tokens = max_tokens
+                new_max_tokens = max_tokens * 2
+                print(f"Finish reason: {choice.finish_reason}, increasing max tokens to {new_max_tokens}", file=sys.stderr)
                 if new_max_tokens > MAX_TOKENS_LIMIT:
                     print(f"Would exceed maximum token limit of {MAX_TOKENS_LIMIT}", file=sys.stderr)
-                    return []
+                    return [{"answer": None, "finish_reason": choice.finish_reason}]
                 return self.request_api(prompt, model, temperature=temperature, max_tokens=new_max_tokens)

             answers.append({
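In 0.1.1 the truncated-answer retry recursed with an unchanged budget (new_max_tokens = max_tokens); 0.1.3 doubles the budget on every truncated answer and gives up once it would exceed MAX_TOKENS_LIMIT. A minimal sketch of that growth follows; next_budget is a hypothetical helper, not part of gemba, and only the doubling and the cap mirror the diff above.

# Sketch of the 0.1.3 token-budget growth (illustrative helper, not package code).
MAX_TOKENS_LIMIT = 4000

def next_budget(max_tokens):
    """Return the next max_tokens to retry with, or None once the cap would be exceeded."""
    if max_tokens is None:
        max_tokens = 500  # same initial value as request_api uses
    new_max_tokens = max_tokens * 2
    return new_max_tokens if new_max_tokens <= MAX_TOKENS_LIMIT else None

# Starting from the 500 default, retries use 1000, 2000, then 4000 tokens,
# after which request_api returns a sentinel dict instead of retrying forever.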
@@ -177,8 +193,13 @@ class GptApi:

     def bulk_request(self, df, model, parse_mqm_answer, cache, max_tokens=None):
         answers = []
-        ...
+        with ThreadPoolExecutor(100) as executor:
+            futures = [
+                executor.submit(self.request, row["prompt"], model, parse_mqm_answer, cache=cache, max_tokens=max_tokens)
+                for _, row in df.iterrows()
+            ]
+
+            for future in tqdm.tqdm(futures, total=len(df), file=sys.stderr):
+                answers += future.result()
+
         return answers
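Taken together, the gpt_api.py changes replace every empty-list early return with a sentinel dict carrying a finish_reason, turn non-"stop" judgements into score-0 answers with a "critical" error, and parallelise bulk_request over a thread pool. A hedged usage sketch follows; the prompt text, the parse_score helper, the model name, cache=None, and the no-argument GptApi() construction are illustrative assumptions, while the bulk_request signature and the shape of the returned dicts come from the diff.

# Hypothetical usage sketch of the 0.1.3 behaviour; everything marked as an
# assumption is illustrative rather than part of the package.
import pandas as pd
from gemba.gpt_api import GptApi

def parse_score(answer):
    # Illustrative parser: read a numeric score from the raw model answer.
    try:
        return float(answer.strip())
    except ValueError:
        return None

df = pd.DataFrame({"prompt": ["Score this translation from 0 to 100: ..."]})
api = GptApi()  # assumption: default construction as in earlier versions

# bulk_request now submits one request per row via ThreadPoolExecutor(100)
# and collects the results in order with tqdm.
answers = api.bulk_request(df, "gpt-4", parse_score, cache=None)  # model and cache are assumptions

for a in answers:
    # Failed judgements now arrive as answer 0 with a "critical" error and the
    # finish_reason attached, instead of being silently dropped.
    if a.get("errors") and a["errors"].get("critical"):
        print("judge failed:", a["finish_reason"])
    else:
        print("score:", a["answer"])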
gemba-0.1.1.dist-info/METADATA → gemba-0.1.3.dist-info/METADATA
CHANGED
@@ -1,10 +1,11 @@
-Metadata-Version: 2.
+Metadata-Version: 2.4
 Name: gemba
-Version: 0.1.1
+Version: 0.1.3
 Summary: GEMBA — GPT Estimation Metric Based Assessment
 Project-URL: Homepage, https://github.com/joelniklaus/gemba
 Author-email: Joel Niklaus <joel@niklaus.ai>
-License: MIT
+License-Expression: MIT
+License-File: LICENSE.md
 Classifier: License :: OSI Approved :: MIT License
 Classifier: Operating System :: OS Independent
 Classifier: Programming Language :: Python :: 3
gemba-0.1.1.dist-info/RECORD → gemba-0.1.3.dist-info/RECORD
CHANGED
@@ -2,13 +2,13 @@ gemba/__init__.py,sha256=0ZuEumkUMWPI5wQMY7OxLolELI9GYYlup-iJw8SwBgc,67
 gemba/gemba_da.py,sha256=YCOKKP7kZBL9e1d44Zr7aTa23BqLFvh4KDOfbNSMgOU,2360
 gemba/gemba_esa.py,sha256=nBCeFjrS24wXLOcAXHRSmZFYJSkUzRS4hfp2LEqYwp8,4461
 gemba/gemba_mqm_utils.py,sha256=qiIdJv7IDx0eeqpsTCHMoUeo8EUOhG6k-YfrzkRfxyw,9612
-gemba/gpt_api.py,sha256=
+gemba/gpt_api.py,sha256=A1GYi0vxUGmephkadI-h6v6G52uDQ7yWOFvIxSRrN8o,8380
 gemba/mtme_tools.py,sha256=xpLxCzfnLHFIxsq_LOi1Lpb-gkyFGYqFXiq9y6O315Q,4667
 gemba/prompt.py,sha256=AuPBhO2OBL3EB5I37p-GX10sx29gRw35xFAnB3bqtII,7578
 gemba/scores.py,sha256=FmmBJ-ds-abExphcVUw9qaPMnKttPWobuXNwZKLAtEs,4388
 gemba/testset.py,sha256=tDvi6xQIBXrODg02WWINrYg9jNQqruCmhBrxe9AaK48,1926
 gemba/utils.py,sha256=Re5uW5dcFj3ITWIGpxjXdAKNDKQ7i4H-Tr_s74SQgmk,4311
-gemba-0.1.
-gemba-0.1.
-gemba-0.1.
-gemba-0.1.
+gemba-0.1.3.dist-info/METADATA,sha256=qyrjjVewIjFJeWKTAXw9rCMqLIj9OqCntenyw0F2oyw,3727
+gemba-0.1.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+gemba-0.1.3.dist-info/licenses/LICENSE.md,sha256=XkNv-P-7d9hgciDpvOIMiRXYYAEP7rbB6-9ahWiOmzk,20137
+gemba-0.1.3.dist-info/RECORD,,