gemba 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
gemba/gpt_api.py
CHANGED
@@ -1,12 +1,13 @@
|
|
1
1
|
import os
|
2
2
|
import sys
|
3
3
|
import time
|
4
|
-
import ipdb
|
5
4
|
import logging
|
6
5
|
from termcolor import colored
|
7
|
-
from datetime import datetime
|
8
6
|
import openai
|
9
7
|
import tqdm
|
8
|
+
from concurrent.futures import ThreadPoolExecutor
|
9
|
+
from collections import defaultdict
|
10
|
+
|
10
11
|
|
11
12
|
|
12
13
|
# class for calling OpenAI API and handling cache
|
@@ -58,6 +59,22 @@ class GptApi:
|
|
58
59
|
for full_answer in answers:
|
59
60
|
finish_reason = full_answer["finish_reason"]
|
60
61
|
full_answer = full_answer["answer"]
|
62
|
+
|
63
|
+
if finish_reason != "stop":
|
64
|
+
print(f"No valid answer, giving score 0")
|
65
|
+
errors = defaultdict(list)
|
66
|
+
errors["critical"].append("Judge errored, giving answer score 0.")
|
67
|
+
parsed_answers.append({
|
68
|
+
"temperature": temperature,
|
69
|
+
"answer_id": answer_id,
|
70
|
+
"answer": 0,
|
71
|
+
"errors": errors,
|
72
|
+
"prompt": prompt,
|
73
|
+
"finish_reason": finish_reason,
|
74
|
+
"model": model,
|
75
|
+
})
|
76
|
+
continue
|
77
|
+
|
61
78
|
answer_id += 1
|
62
79
|
answer = parse_response(full_answer)
|
63
80
|
if isinstance(answer, tuple):
|
@@ -68,33 +85,32 @@ class GptApi:
|
|
68
85
|
print(f"Answer (t={temperature}): " + colored(answer, "yellow") + " (" + colored(full_answer, "blue") + ")", file=sys.stderr)
|
69
86
|
if answer is None:
|
70
87
|
continue
|
71
|
-
parsed_answers.append(
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
}
|
81
|
-
)
|
88
|
+
parsed_answers.append({
|
89
|
+
"temperature": temperature,
|
90
|
+
"answer_id": answer_id,
|
91
|
+
"answer": answer,
|
92
|
+
"errors": errors,
|
93
|
+
"prompt": prompt,
|
94
|
+
"finish_reason": finish_reason,
|
95
|
+
"model": model,
|
96
|
+
})
|
82
97
|
|
83
98
|
# there was no valid answer, increase temperature and try again
|
84
99
|
if len(parsed_answers) == 0:
|
100
|
+
print(f"No valid answer, increasing temperature to {temperature + 1} and trying again")
|
85
101
|
return self.request(prompt, model, parse_response, temperature=temperature + 1, answer_id=answer_id, cache=cache)
|
86
102
|
|
87
103
|
return parsed_answers
|
88
104
|
|
89
105
|
def request_api(self, prompt, model, temperature=0, max_tokens=None):
|
90
106
|
if temperature > 10:
|
91
|
-
return []
|
107
|
+
return [{"answer": None, "finish_reason": "error"}]
|
92
108
|
|
93
109
|
# Add maximum token limit
|
94
110
|
MAX_TOKENS_LIMIT = 4000 # Adjust this based on your model's context window
|
95
111
|
if max_tokens and max_tokens > MAX_TOKENS_LIMIT:
|
96
112
|
print(f"Reached maximum token limit of {MAX_TOKENS_LIMIT}", file=sys.stderr)
|
97
|
-
return []
|
113
|
+
return [{"answer": None, "finish_reason": "length"}]
|
98
114
|
|
99
115
|
while True:
|
100
116
|
try:
|
@@ -104,10 +120,10 @@ class GptApi:
|
|
104
120
|
# response was filtered
|
105
121
|
if hasattr(e, 'code'):
|
106
122
|
if e.code == 'content_filter':
|
107
|
-
return []
|
123
|
+
return [{"answer": None, "finish_reason": "filter"}]
|
108
124
|
print(e.code, file=sys.stderr)
|
109
125
|
if hasattr(e, 'error') and e.error['code'] == 'invalid_model_output':
|
110
|
-
return []
|
126
|
+
return [{"answer": None, "finish_reason": "invalid"}]
|
111
127
|
|
112
128
|
# frequent error is reaching the API limit
|
113
129
|
print(colored("Error, retrying...", "red"), file=sys.stderr)
|
@@ -117,7 +133,7 @@ class GptApi:
|
|
117
133
|
answers = []
|
118
134
|
for choice in response.choices:
|
119
135
|
if choice.message.content is None:
|
120
|
-
return []
|
136
|
+
return [{"answer": None, "finish_reason": "invalid"}]
|
121
137
|
if hasattr(choice, "message"):
|
122
138
|
answer = choice.message.content.strip()
|
123
139
|
else:
|
@@ -127,13 +143,13 @@ class GptApi:
|
|
127
143
|
if choice.finish_reason != "stop":
|
128
144
|
if self.verbose:
|
129
145
|
print(colored(f"Increasing max tokens to fit answers.", "red") + colored(answer, "blue"), file=sys.stderr)
|
130
|
-
print(f"Finish reason: {choice.finish_reason}", file=sys.stderr)
|
131
146
|
if max_tokens is None:
|
132
147
|
max_tokens = 500 # Set initial max_tokens if None
|
133
|
-
new_max_tokens = max_tokens
|
148
|
+
new_max_tokens = max_tokens * 2
|
149
|
+
print(f"Finish reason: {choice.finish_reason}, increasing max tokens to {new_max_tokens}", file=sys.stderr)
|
134
150
|
if new_max_tokens > MAX_TOKENS_LIMIT:
|
135
151
|
print(f"Would exceed maximum token limit of {MAX_TOKENS_LIMIT}", file=sys.stderr)
|
136
|
-
return []
|
152
|
+
return [{"answer": None, "finish_reason": choice.finish_reason}]
|
137
153
|
return self.request_api(prompt, model, temperature=temperature, max_tokens=new_max_tokens)
|
138
154
|
|
139
155
|
answers.append({
|
@@ -177,8 +193,13 @@ class GptApi:
|
|
177
193
|
|
178
194
|
def bulk_request(self, df, model, parse_mqm_answer, cache, max_tokens=None):
|
179
195
|
answers = []
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
196
|
+
with ThreadPoolExecutor(100) as executor:
|
197
|
+
futures = [
|
198
|
+
executor.submit(self.request, row["prompt"], model, parse_mqm_answer, cache=cache, max_tokens=max_tokens)
|
199
|
+
for _, row in df.iterrows()
|
200
|
+
]
|
201
|
+
|
202
|
+
for future in tqdm.tqdm(futures, total=len(df), file=sys.stderr):
|
203
|
+
answers += future.result()
|
204
|
+
|
184
205
|
return answers
|
@@ -1,10 +1,11 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.4
|
2
2
|
Name: gemba
|
3
|
-
Version: 0.1.
|
3
|
+
Version: 0.1.3
|
4
4
|
Summary: GEMBA — GPT Estimation Metric Based Assessment
|
5
5
|
Project-URL: Homepage, https://github.com/joelniklaus/gemba
|
6
6
|
Author-email: Joel Niklaus <joel@niklaus.ai>
|
7
|
-
License: MIT
|
7
|
+
License-Expression: MIT
|
8
|
+
License-File: LICENSE.md
|
8
9
|
Classifier: License :: OSI Approved :: MIT License
|
9
10
|
Classifier: Operating System :: OS Independent
|
10
11
|
Classifier: Programming Language :: Python :: 3
|
@@ -2,13 +2,13 @@ gemba/__init__.py,sha256=0ZuEumkUMWPI5wQMY7OxLolELI9GYYlup-iJw8SwBgc,67
|
|
2
2
|
gemba/gemba_da.py,sha256=YCOKKP7kZBL9e1d44Zr7aTa23BqLFvh4KDOfbNSMgOU,2360
|
3
3
|
gemba/gemba_esa.py,sha256=nBCeFjrS24wXLOcAXHRSmZFYJSkUzRS4hfp2LEqYwp8,4461
|
4
4
|
gemba/gemba_mqm_utils.py,sha256=qiIdJv7IDx0eeqpsTCHMoUeo8EUOhG6k-YfrzkRfxyw,9612
|
5
|
-
gemba/gpt_api.py,sha256=
|
5
|
+
gemba/gpt_api.py,sha256=A1GYi0vxUGmephkadI-h6v6G52uDQ7yWOFvIxSRrN8o,8380
|
6
6
|
gemba/mtme_tools.py,sha256=xpLxCzfnLHFIxsq_LOi1Lpb-gkyFGYqFXiq9y6O315Q,4667
|
7
7
|
gemba/prompt.py,sha256=AuPBhO2OBL3EB5I37p-GX10sx29gRw35xFAnB3bqtII,7578
|
8
8
|
gemba/scores.py,sha256=FmmBJ-ds-abExphcVUw9qaPMnKttPWobuXNwZKLAtEs,4388
|
9
9
|
gemba/testset.py,sha256=tDvi6xQIBXrODg02WWINrYg9jNQqruCmhBrxe9AaK48,1926
|
10
10
|
gemba/utils.py,sha256=Re5uW5dcFj3ITWIGpxjXdAKNDKQ7i4H-Tr_s74SQgmk,4311
|
11
|
-
gemba-0.1.
|
12
|
-
gemba-0.1.
|
13
|
-
gemba-0.1.
|
14
|
-
gemba-0.1.
|
11
|
+
gemba-0.1.3.dist-info/METADATA,sha256=qyrjjVewIjFJeWKTAXw9rCMqLIj9OqCntenyw0F2oyw,3727
|
12
|
+
gemba-0.1.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
13
|
+
gemba-0.1.3.dist-info/licenses/LICENSE.md,sha256=XkNv-P-7d9hgciDpvOIMiRXYYAEP7rbB6-9ahWiOmzk,20137
|
14
|
+
gemba-0.1.3.dist-info/RECORD,,
|
File without changes
|