gemba 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
gemba/gpt_api.py CHANGED
@@ -6,6 +6,8 @@ from termcolor import colored
6
6
  import openai
7
7
  import tqdm
8
8
  from concurrent.futures import ThreadPoolExecutor
9
+ from collections import defaultdict
10
+
9
11
 
10
12
 
11
13
  # class for calling OpenAI API and handling cache
@@ -57,6 +59,22 @@ class GptApi:
57
59
  for full_answer in answers:
58
60
  finish_reason = full_answer["finish_reason"]
59
61
  full_answer = full_answer["answer"]
62
+
63
+ if finish_reason != "stop":
64
+ print(f"No valid answer, giving score 0")
65
+ errors = defaultdict(list)
66
+ errors["critical"].append("Judge errored, giving answer score 0.")
67
+ parsed_answers.append({
68
+ "temperature": temperature,
69
+ "answer_id": answer_id,
70
+ "answer": 0,
71
+ "errors": errors,
72
+ "prompt": prompt,
73
+ "finish_reason": finish_reason,
74
+ "model": model,
75
+ })
76
+ continue
77
+
60
78
  answer_id += 1
61
79
  answer = parse_response(full_answer)
62
80
  if isinstance(answer, tuple):
@@ -67,33 +85,32 @@ class GptApi:
67
85
  print(f"Answer (t={temperature}): " + colored(answer, "yellow") + " (" + colored(full_answer, "blue") + ")", file=sys.stderr)
68
86
  if answer is None:
69
87
  continue
70
- parsed_answers.append(
71
- {
72
- "temperature": temperature,
73
- "answer_id": answer_id,
74
- "answer": answer,
75
- "errors": errors,
76
- "prompt": prompt,
77
- "finish_reason": finish_reason,
78
- "model": model,
79
- }
80
- )
88
+ parsed_answers.append({
89
+ "temperature": temperature,
90
+ "answer_id": answer_id,
91
+ "answer": answer,
92
+ "errors": errors,
93
+ "prompt": prompt,
94
+ "finish_reason": finish_reason,
95
+ "model": model,
96
+ })
81
97
 
82
98
  # there was no valid answer, increase temperature and try again
83
99
  if len(parsed_answers) == 0:
100
+ print(f"No valid answer, increasing temperature to {temperature + 1} and trying again")
84
101
  return self.request(prompt, model, parse_response, temperature=temperature + 1, answer_id=answer_id, cache=cache)
85
102
 
86
103
  return parsed_answers
87
104
 
88
105
  def request_api(self, prompt, model, temperature=0, max_tokens=None):
89
106
  if temperature > 10:
90
- return []
107
+ return [{"answer": None, "finish_reason": "error"}]
91
108
 
92
109
  # Add maximum token limit
93
110
  MAX_TOKENS_LIMIT = 4000 # Adjust this based on your model's context window
94
111
  if max_tokens and max_tokens > MAX_TOKENS_LIMIT:
95
112
  print(f"Reached maximum token limit of {MAX_TOKENS_LIMIT}", file=sys.stderr)
96
- return []
113
+ return [{"answer": None, "finish_reason": "length"}]
97
114
 
98
115
  while True:
99
116
  try:
@@ -103,10 +120,10 @@ class GptApi:
103
120
  # response was filtered
104
121
  if hasattr(e, 'code'):
105
122
  if e.code == 'content_filter':
106
- return []
123
+ return [{"answer": None, "finish_reason": "filter"}]
107
124
  print(e.code, file=sys.stderr)
108
125
  if hasattr(e, 'error') and e.error['code'] == 'invalid_model_output':
109
- return []
126
+ return [{"answer": None, "finish_reason": "invalid"}]
110
127
 
111
128
  # frequent error is reaching the API limit
112
129
  print(colored("Error, retrying...", "red"), file=sys.stderr)
@@ -116,7 +133,7 @@ class GptApi:
116
133
  answers = []
117
134
  for choice in response.choices:
118
135
  if choice.message.content is None:
119
- return []
136
+ return [{"answer": None, "finish_reason": "invalid"}]
120
137
  if hasattr(choice, "message"):
121
138
  answer = choice.message.content.strip()
122
139
  else:
@@ -126,13 +143,13 @@ class GptApi:
126
143
  if choice.finish_reason != "stop":
127
144
  if self.verbose:
128
145
  print(colored(f"Increasing max tokens to fit answers.", "red") + colored(answer, "blue"), file=sys.stderr)
129
- print(f"Finish reason: {choice.finish_reason}", file=sys.stderr)
130
146
  if max_tokens is None:
131
147
  max_tokens = 500 # Set initial max_tokens if None
132
- new_max_tokens = max_tokens + 200
148
+ new_max_tokens = max_tokens * 2
149
+ print(f"Finish reason: {choice.finish_reason}, increasing max tokens to {new_max_tokens}", file=sys.stderr)
133
150
  if new_max_tokens > MAX_TOKENS_LIMIT:
134
151
  print(f"Would exceed maximum token limit of {MAX_TOKENS_LIMIT}", file=sys.stderr)
135
- return []
152
+ return [{"answer": None, "finish_reason": choice.finish_reason}]
136
153
  return self.request_api(prompt, model, temperature=temperature, max_tokens=new_max_tokens)
137
154
 
138
155
  answers.append({
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gemba
3
- Version: 0.1.2
3
+ Version: 0.1.3
4
4
  Summary: GEMBA — GPT Estimation Metric Based Assessment
5
5
  Project-URL: Homepage, https://github.com/joelniklaus/gemba
6
6
  Author-email: Joel Niklaus <joel@niklaus.ai>
@@ -2,13 +2,13 @@ gemba/__init__.py,sha256=0ZuEumkUMWPI5wQMY7OxLolELI9GYYlup-iJw8SwBgc,67
2
2
  gemba/gemba_da.py,sha256=YCOKKP7kZBL9e1d44Zr7aTa23BqLFvh4KDOfbNSMgOU,2360
3
3
  gemba/gemba_esa.py,sha256=nBCeFjrS24wXLOcAXHRSmZFYJSkUzRS4hfp2LEqYwp8,4461
4
4
  gemba/gemba_mqm_utils.py,sha256=qiIdJv7IDx0eeqpsTCHMoUeo8EUOhG6k-YfrzkRfxyw,9612
5
- gemba/gpt_api.py,sha256=UJGXQBnRLBujLGdQhr6HUvbvWYQIxqmQqa_JG8iS0Uc,7394
5
+ gemba/gpt_api.py,sha256=A1GYi0vxUGmephkadI-h6v6G52uDQ7yWOFvIxSRrN8o,8380
6
6
  gemba/mtme_tools.py,sha256=xpLxCzfnLHFIxsq_LOi1Lpb-gkyFGYqFXiq9y6O315Q,4667
7
7
  gemba/prompt.py,sha256=AuPBhO2OBL3EB5I37p-GX10sx29gRw35xFAnB3bqtII,7578
8
8
  gemba/scores.py,sha256=FmmBJ-ds-abExphcVUw9qaPMnKttPWobuXNwZKLAtEs,4388
9
9
  gemba/testset.py,sha256=tDvi6xQIBXrODg02WWINrYg9jNQqruCmhBrxe9AaK48,1926
10
10
  gemba/utils.py,sha256=Re5uW5dcFj3ITWIGpxjXdAKNDKQ7i4H-Tr_s74SQgmk,4311
11
- gemba-0.1.2.dist-info/METADATA,sha256=98Ge9LVScGEzoTyv6gQICfY4KA8V0Gq3927gcEPE5xI,3727
12
- gemba-0.1.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
13
- gemba-0.1.2.dist-info/licenses/LICENSE.md,sha256=XkNv-P-7d9hgciDpvOIMiRXYYAEP7rbB6-9ahWiOmzk,20137
14
- gemba-0.1.2.dist-info/RECORD,,
11
+ gemba-0.1.3.dist-info/METADATA,sha256=qyrjjVewIjFJeWKTAXw9rCMqLIj9OqCntenyw0F2oyw,3727
12
+ gemba-0.1.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
13
+ gemba-0.1.3.dist-info/licenses/LICENSE.md,sha256=XkNv-P-7d9hgciDpvOIMiRXYYAEP7rbB6-9ahWiOmzk,20137
14
+ gemba-0.1.3.dist-info/RECORD,,
File without changes