gemba 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gemba/gpt_api.py
CHANGED
@@ -6,6 +6,8 @@ from termcolor import colored
 import openai
 import tqdm
 from concurrent.futures import ThreadPoolExecutor
+from collections import defaultdict
+


 # class for calling OpenAI API and handling cache
@@ -57,6 +59,22 @@ class GptApi:
         for full_answer in answers:
             finish_reason = full_answer["finish_reason"]
             full_answer = full_answer["answer"]
+
+            if finish_reason != "stop":
+                print(f"No valid answer, giving score 0")
+                errors = defaultdict(list)
+                errors["critical"].append("Judge errored, giving answer score 0.")
+                parsed_answers.append({
+                    "temperature": temperature,
+                    "answer_id": answer_id,
+                    "answer": 0,
+                    "errors": errors,
+                    "prompt": prompt,
+                    "finish_reason": finish_reason,
+                    "model": model,
+                })
+                continue
+
             answer_id += 1
             answer = parse_response(full_answer)
             if isinstance(answer, tuple):
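The branch added above converts any answer that did not finish with "stop" into an explicit zero score carrying a critical error, instead of silently dropping it. A minimal sketch of that conversion, runnable on its own; the argument defaults here are stand-ins, not values from the package:

from collections import defaultdict

def zero_score_entry(finish_reason, temperature=0, answer_id=0,
                     prompt="<prompt>", model="<model>"):
    # Mirror the added branch: log a critical error and score the answer 0.
    errors = defaultdict(list)
    errors["critical"].append("Judge errored, giving answer score 0.")
    return {
        "temperature": temperature,
        "answer_id": answer_id,
        "answer": 0,
        "errors": errors,
        "prompt": prompt,
        "finish_reason": finish_reason,
        "model": model,
    }

entry = zero_score_entry("filter")
assert entry["answer"] == 0 and entry["errors"]["critical"]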
@@ -67,33 +85,32 @@ class GptApi:
                 print(f"Answer (t={temperature}): " + colored(answer, "yellow") + " (" + colored(full_answer, "blue") + ")", file=sys.stderr)
             if answer is None:
                 continue
-            parsed_answers.append(
-                {
-                    "temperature": temperature,
-                    "answer_id": answer_id,
-                    "answer": answer,
-                    "errors": errors,
-                    "prompt": prompt,
-                    "finish_reason": finish_reason,
-                    "model": model,
-                }
-            )
+            parsed_answers.append({
+                "temperature": temperature,
+                "answer_id": answer_id,
+                "answer": answer,
+                "errors": errors,
+                "prompt": prompt,
+                "finish_reason": finish_reason,
+                "model": model,
+            })

         # there was no valid answer, increase temperature and try again
         if len(parsed_answers) == 0:
+            print(f"No valid answer, increasing temperature to {temperature + 1} and trying again")
             return self.request(prompt, model, parse_response, temperature=temperature + 1, answer_id=answer_id, cache=cache)

         return parsed_answers

     def request_api(self, prompt, model, temperature=0, max_tokens=None):
         if temperature > 10:
-            return []
+            return [{"answer": None, "finish_reason": "error"}]

         # Add maximum token limit
         MAX_TOKENS_LIMIT = 4000  # Adjust this based on your model's context window
         if max_tokens and max_tokens > MAX_TOKENS_LIMIT:
             print(f"Reached maximum token limit of {MAX_TOKENS_LIMIT}", file=sys.stderr)
-            return []
+            return [{"answer": None, "finish_reason": "length"}]

         while True:
             try:
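Taken together with the earlier hunk, the retry loop now terminates: each round with no parseable answer recurses with temperature + 1, and the temperature > 10 guard returns a sentinel answer instead of an empty list, which the finish_reason != "stop" branch then scores as 0. A sketch of that escalation ladder, flattening the request/request_api split into one function for illustration (worst case assumed: no round ever parses):

def request_sketch(temperature=0):
    if temperature > 10:
        # the new guard: a sentinel instead of an empty list
        return [{"answer": None, "finish_reason": "error"}]
    parsed_answers = []  # assume nothing parsed this round
    if len(parsed_answers) == 0:
        print(f"No valid answer, increasing temperature to {temperature + 1} and trying again")
        return request_sketch(temperature + 1)
    return parsed_answers

print(request_sketch())  # gives up after temperature 10 with the "error" sentinel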
@@ -103,10 +120,10 @@ class GptApi:
            # response was filtered
            if hasattr(e, 'code'):
                if e.code == 'content_filter':
-                    return []
+                    return [{"answer": None, "finish_reason": "filter"}]
                print(e.code, file=sys.stderr)
            if hasattr(e, 'error') and e.error['code'] == 'invalid_model_output':
-                return []
+                return [{"answer": None, "finish_reason": "invalid"}]

            # frequent error is reaching the API limit
            print(colored("Error, retrying...", "red"), file=sys.stderr)
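Across this and the surrounding hunks, every failure path in request_api now returns a one-element sentinel rather than []. A summary of the contract as I read it from the diff (the annotations are mine, not strings from the package):

# finish_reason values the failure paths emit:
#   "error"   - temperature escalated past 10 without a usable answer
#   "length"  - requested max_tokens exceeded MAX_TOKENS_LIMIT
#   "filter"  - the API rejected the content (content_filter)
#   "invalid" - the model returned no usable output
def sentinel(finish_reason):
    # Shape shared by every failure return, e.g. [{"answer": None, "finish_reason": "filter"}]
    return [{"answer": None, "finish_reason": finish_reason}]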
@@ -116,7 +133,7 @@ class GptApi:
         answers = []
         for choice in response.choices:
             if choice.message.content is None:
-                return []
+                return [{"answer": None, "finish_reason": "invalid"}]
             if hasattr(choice, "message"):
                 answer = choice.message.content.strip()
             else:
@@ -126,13 +143,13 @@ class GptApi:
             if choice.finish_reason != "stop":
                 if self.verbose:
                     print(colored(f"Increasing max tokens to fit answers.", "red") + colored(answer, "blue"), file=sys.stderr)
-                print(f"Finish reason: {choice.finish_reason}", file=sys.stderr)
                 if max_tokens is None:
                     max_tokens = 500  # Set initial max_tokens if None
-                new_max_tokens = max_tokens
+                new_max_tokens = max_tokens * 2
+                print(f"Finish reason: {choice.finish_reason}, increasing max tokens to {new_max_tokens}", file=sys.stderr)
                 if new_max_tokens > MAX_TOKENS_LIMIT:
                     print(f"Would exceed maximum token limit of {MAX_TOKENS_LIMIT}", file=sys.stderr)
-                    return []
+                    return [{"answer": None, "finish_reason": choice.finish_reason}]
                 return self.request_api(prompt, model, temperature=temperature, max_tokens=new_max_tokens)

             answers.append({
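The truncation retry previously set new_max_tokens = max_tokens, so a cut-off answer retried with the same budget and could loop without progress; doubling fixes that, and the MAX_TOKENS_LIMIT check bounds the recursion. A sketch of the resulting budget schedule, using the constants from the diff:

MAX_TOKENS_LIMIT = 4000  # constant from the diff

def retry_budgets(initial=500):
    # new_max_tokens across truncation retries: 500 doubles to 1000,
    # then 2000, then 4000; the next doubling (8000) would exceed the
    # limit, so request_api returns a sentinel instead of recursing.
    budget = initial * 2
    while budget <= MAX_TOKENS_LIMIT:
        yield budget
        budget *= 2

print(list(retry_budgets()))  # [1000, 2000, 4000]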
gemba-0.1.2.dist-info/RECORD → gemba-0.1.3.dist-info/RECORD
CHANGED
@@ -2,13 +2,13 @@ gemba/__init__.py,sha256=0ZuEumkUMWPI5wQMY7OxLolELI9GYYlup-iJw8SwBgc,67
 gemba/gemba_da.py,sha256=YCOKKP7kZBL9e1d44Zr7aTa23BqLFvh4KDOfbNSMgOU,2360
 gemba/gemba_esa.py,sha256=nBCeFjrS24wXLOcAXHRSmZFYJSkUzRS4hfp2LEqYwp8,4461
 gemba/gemba_mqm_utils.py,sha256=qiIdJv7IDx0eeqpsTCHMoUeo8EUOhG6k-YfrzkRfxyw,9612
-gemba/gpt_api.py,sha256=
+gemba/gpt_api.py,sha256=A1GYi0vxUGmephkadI-h6v6G52uDQ7yWOFvIxSRrN8o,8380
 gemba/mtme_tools.py,sha256=xpLxCzfnLHFIxsq_LOi1Lpb-gkyFGYqFXiq9y6O315Q,4667
 gemba/prompt.py,sha256=AuPBhO2OBL3EB5I37p-GX10sx29gRw35xFAnB3bqtII,7578
 gemba/scores.py,sha256=FmmBJ-ds-abExphcVUw9qaPMnKttPWobuXNwZKLAtEs,4388
 gemba/testset.py,sha256=tDvi6xQIBXrODg02WWINrYg9jNQqruCmhBrxe9AaK48,1926
 gemba/utils.py,sha256=Re5uW5dcFj3ITWIGpxjXdAKNDKQ7i4H-Tr_s74SQgmk,4311
-gemba-0.1.
-gemba-0.1.
-gemba-0.1.
-gemba-0.1.
+gemba-0.1.3.dist-info/METADATA,sha256=qyrjjVewIjFJeWKTAXw9rCMqLIj9OqCntenyw0F2oyw,3727
+gemba-0.1.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+gemba-0.1.3.dist-info/licenses/LICENSE.md,sha256=XkNv-P-7d9hgciDpvOIMiRXYYAEP7rbB6-9ahWiOmzk,20137
+gemba-0.1.3.dist-info/RECORD,,
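For reference, RECORD hashes such as the new gemba/gpt_api.py entry are sha256 digests in URL-safe base64 with the "=" padding stripped, followed by the file size in bytes, per the wheel spec. A small checker, assuming you have the unpacked 0.1.3 wheel locally:

import base64
import hashlib
import os

def record_entry(path):
    # Reproduce a RECORD line: "path,sha256=<urlsafe-b64 digest, no padding>,<size>"
    with open(path, "rb") as f:
        digest = hashlib.sha256(f.read()).digest()
    b64 = base64.urlsafe_b64encode(digest).decode().rstrip("=")
    return f"{path},sha256={b64},{os.path.getsize(path)}"

# record_entry("gemba/gpt_api.py") should match
# "gemba/gpt_api.py,sha256=A1GYi0vxUGmephkadI-h6v6G52uDQ7yWOFvIxSRrN8o,8380"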