gemba 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their public registry, and is provided for informational purposes only.
gemba/gpt_api.py CHANGED
@@ -6,6 +6,8 @@ from termcolor import colored
 import openai
 import tqdm
 from concurrent.futures import ThreadPoolExecutor
+from collections import defaultdict
+
 
 
 # class for calling OpenAI API and handling cache
@@ -57,6 +59,22 @@ class GptApi:
         for full_answer in answers:
             finish_reason = full_answer["finish_reason"]
             full_answer = full_answer["answer"]
+
+            if finish_reason != "stop":
+                print(f"No valid answer, giving score 0")
+                errors = defaultdict(list)
+                errors["critical"].append("Judge errored, giving answer score 0.")
+                parsed_answers.append({
+                    "temperature": temperature,
+                    "answer_id": answer_id,
+                    "answer": 0,
+                    "errors": errors,
+                    "prompt": prompt,
+                    "finish_reason": finish_reason,
+                    "model": model,
+                })
+                continue
+
             answer_id += 1
             answer = parse_response(full_answer)
             if isinstance(answer, tuple):
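
Note: the new branch above means a judge failure no longer vanishes from the results; it is recorded as a score of 0 with a critical error attached. A minimal, hedged sketch of that path (the surrounding variables such as `temperature`, `answer_id`, `prompt`, and `model` are stubbed with illustrative values):

```python
from collections import defaultdict

def score_failed_answer(temperature, answer_id, prompt, model, finish_reason):
    # Mirrors the diff: any finish_reason other than "stop" is treated as
    # a judge failure, scored 0, and flagged with a critical error.
    errors = defaultdict(list)
    errors["critical"].append("Judge errored, giving answer score 0.")
    return {
        "temperature": temperature,
        "answer_id": answer_id,
        "answer": 0,
        "errors": errors,
        "prompt": prompt,
        "finish_reason": finish_reason,
        "model": model,
    }

# Illustrative call: a truncated completion arrives with finish_reason "length".
print(score_failed_answer(0, 1, "Score this translation.", "gpt-4", "length")["answer"])  # 0
```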
@@ -67,33 +85,32 @@ class GptApi:
             print(f"Answer (t={temperature}): " + colored(answer, "yellow") + " (" + colored(full_answer, "blue") + ")", file=sys.stderr)
             if answer is None:
                 continue
-            parsed_answers.append(
-                {
-                    "temperature": temperature,
-                    "answer_id": answer_id,
-                    "answer": answer,
-                    "errors": errors,
-                    "prompt": prompt,
-                    "finish_reason": finish_reason,
-                    "model": model,
-                }
-            )
+            parsed_answers.append({
+                "temperature": temperature,
+                "answer_id": answer_id,
+                "answer": answer,
+                "errors": errors,
+                "prompt": prompt,
+                "finish_reason": finish_reason,
+                "model": model,
+            })
 
         # there was no valid answer, increase temperature and try again
         if len(parsed_answers) == 0:
+            print(f"No valid answer, increasing temperature to {temperature + 1} and trying again")
             return self.request(prompt, model, parse_response, temperature=temperature + 1, answer_id=answer_id, cache=cache)
 
         return parsed_answers
 
     def request_api(self, prompt, model, temperature=0, max_tokens=None):
         if temperature > 10:
-            return []
+            return [{"answer": None, "finish_reason": "error"}]
 
         # Add maximum token limit
         MAX_TOKENS_LIMIT = 4000  # Adjust this based on your model's context window
         if max_tokens and max_tokens > MAX_TOKENS_LIMIT:
             print(f"Reached maximum token limit of {MAX_TOKENS_LIMIT}", file=sys.stderr)
-            return []
+            return [{"answer": None, "finish_reason": "length"}]
 
         while True:
             try:
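
With the score-0 branch in place, the temperature-escalation retry above now fires only when every completion parses to `None`; judge failures short-circuit earlier with a recorded score of 0. A hedged toy model of that control flow (`parse_stub` is a stand-in, not the package's `parse_response`):

```python
def parse_stub(prompt, temperature):
    # Pretend the model only yields a parseable score at temperature >= 2.
    return [{"answer": 95, "finish_reason": "stop"}] if temperature >= 2 else []

def toy_request(prompt, temperature=0):
    if temperature > 10:
        # request_api now returns a sentinel instead of [], so callers
        # always see a finish_reason even on the give-up path.
        return [{"answer": None, "finish_reason": "error"}]
    parsed_answers = parse_stub(prompt, temperature)
    if len(parsed_answers) == 0:
        print(f"No valid answer, increasing temperature to {temperature + 1} and trying again")
        return toy_request(prompt, temperature + 1)
    return parsed_answers

print(toy_request("Score this translation."))  # retries at t=1, succeeds at t=2
```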
@@ -103,10 +120,10 @@ class GptApi:
                 # response was filtered
                 if hasattr(e, 'code'):
                     if e.code == 'content_filter':
-                        return []
+                        return [{"answer": None, "finish_reason": "filter"}]
                     print(e.code, file=sys.stderr)
                 if hasattr(e, 'error') and e.error['code'] == 'invalid_model_output':
-                    return []
+                    return [{"answer": None, "finish_reason": "invalid"}]
 
                 # frequent error is reaching the API limit
                 print(colored("Error, retrying...", "red"), file=sys.stderr)
@@ -116,7 +133,7 @@ class GptApi:
         answers = []
         for choice in response.choices:
             if choice.message.content is None:
-                return []
+                return [{"answer": None, "finish_reason": "invalid"}]
             if hasattr(choice, "message"):
                 answer = choice.message.content.strip()
             else:
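
Taken together, the hunks above replace every empty-list early return in `request_api` with a one-element sentinel, so `request` can always read a `finish_reason`. Collected for reference (the reason strings are exactly those in the diff; the `failure` helper is illustrative, not package code):

```python
def failure(reason):
    # Shape of every early return in request_api after this change:
    #   "error"   - temperature escalated past 10
    #   "length"  - requested max_tokens already above MAX_TOKENS_LIMIT
    #   "filter"  - the API's content filter blocked the response
    #   "invalid" - the model returned no usable content
    return [{"answer": None, "finish_reason": reason}]

assert failure("filter") == [{"answer": None, "finish_reason": "filter"}]
```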
@@ -126,13 +143,13 @@ class GptApi:
             if choice.finish_reason != "stop":
                 if self.verbose:
                     print(colored(f"Increasing max tokens to fit answers.", "red") + colored(answer, "blue"), file=sys.stderr)
-                print(f"Finish reason: {choice.finish_reason}", file=sys.stderr)
                 if max_tokens is None:
                     max_tokens = 500  # Set initial max_tokens if None
-                new_max_tokens = max_tokens + 200
+                new_max_tokens = max_tokens * 2
+                print(f"Finish reason: {choice.finish_reason}, increasing max tokens to {new_max_tokens}", file=sys.stderr)
                 if new_max_tokens > MAX_TOKENS_LIMIT:
                     print(f"Would exceed maximum token limit of {MAX_TOKENS_LIMIT}", file=sys.stderr)
-                    return []
+                    return [{"answer": None, "finish_reason": choice.finish_reason}]
                 return self.request_api(prompt, model, temperature=temperature, max_tokens=new_max_tokens)
 
             answers.append({
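
The growth rule for truncated answers also changes from additive (`+ 200`) to multiplicative (`* 2`), so the retry schedule reaches the 4000-token cap in three steps instead of roughly seventeen. A small sketch of the implied schedule (the helper is illustrative, not package code):

```python
MAX_TOKENS_LIMIT = 4000  # value from the diff

def next_max_tokens(max_tokens):
    if max_tokens is None:
        max_tokens = 500  # seed value from the diff
    new_max_tokens = max_tokens * 2
    return new_max_tokens if new_max_tokens <= MAX_TOKENS_LIMIT else None

schedule, mt = [], None
while (mt := next_max_tokens(mt)) is not None:
    schedule.append(mt)
print(schedule)  # [1000, 2000, 4000] -- then the sentinel answer is returned
```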
gemba-0.1.2.dist-info/METADATA → gemba-0.1.3.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: gemba
-Version: 0.1.2
+Version: 0.1.3
 Summary: GEMBA — GPT Estimation Metric Based Assessment
 Project-URL: Homepage, https://github.com/joelniklaus/gemba
 Author-email: Joel Niklaus <joel@niklaus.ai>
gemba-0.1.2.dist-info/RECORD → gemba-0.1.3.dist-info/RECORD RENAMED
@@ -2,13 +2,13 @@ gemba/__init__.py,sha256=0ZuEumkUMWPI5wQMY7OxLolELI9GYYlup-iJw8SwBgc,67
 gemba/gemba_da.py,sha256=YCOKKP7kZBL9e1d44Zr7aTa23BqLFvh4KDOfbNSMgOU,2360
 gemba/gemba_esa.py,sha256=nBCeFjrS24wXLOcAXHRSmZFYJSkUzRS4hfp2LEqYwp8,4461
 gemba/gemba_mqm_utils.py,sha256=qiIdJv7IDx0eeqpsTCHMoUeo8EUOhG6k-YfrzkRfxyw,9612
-gemba/gpt_api.py,sha256=UJGXQBnRLBujLGdQhr6HUvbvWYQIxqmQqa_JG8iS0Uc,7394
+gemba/gpt_api.py,sha256=A1GYi0vxUGmephkadI-h6v6G52uDQ7yWOFvIxSRrN8o,8380
 gemba/mtme_tools.py,sha256=xpLxCzfnLHFIxsq_LOi1Lpb-gkyFGYqFXiq9y6O315Q,4667
 gemba/prompt.py,sha256=AuPBhO2OBL3EB5I37p-GX10sx29gRw35xFAnB3bqtII,7578
 gemba/scores.py,sha256=FmmBJ-ds-abExphcVUw9qaPMnKttPWobuXNwZKLAtEs,4388
 gemba/testset.py,sha256=tDvi6xQIBXrODg02WWINrYg9jNQqruCmhBrxe9AaK48,1926
 gemba/utils.py,sha256=Re5uW5dcFj3ITWIGpxjXdAKNDKQ7i4H-Tr_s74SQgmk,4311
-gemba-0.1.2.dist-info/METADATA,sha256=98Ge9LVScGEzoTyv6gQICfY4KA8V0Gq3927gcEPE5xI,3727
-gemba-0.1.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-gemba-0.1.2.dist-info/licenses/LICENSE.md,sha256=XkNv-P-7d9hgciDpvOIMiRXYYAEP7rbB6-9ahWiOmzk,20137
-gemba-0.1.2.dist-info/RECORD,,
+gemba-0.1.3.dist-info/METADATA,sha256=qyrjjVewIjFJeWKTAXw9rCMqLIj9OqCntenyw0F2oyw,3727
+gemba-0.1.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+gemba-0.1.3.dist-info/licenses/LICENSE.md,sha256=XkNv-P-7d9hgciDpvOIMiRXYYAEP7rbB6-9ahWiOmzk,20137
+gemba-0.1.3.dist-info/RECORD,,
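
Only `gpt_api.py` changes content (7394 → 8380 bytes); the dist-info entries are renamed for the new version and, apart from METADATA, keep their hashes. As an aside, the `sha256=` values in RECORD follow the wheel spec: URL-safe base64 of the raw digest with the `=` padding stripped. A quick sketch:

```python
import base64
import hashlib

def record_hash(data: bytes) -> str:
    # RECORD-style hash per the wheel spec: urlsafe base64 of the raw
    # sha256 digest, with trailing "=" padding removed.
    digest = hashlib.sha256(data).digest()
    return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")

print("sha256=" + record_hash(b"example file contents"))
```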