gemba-0.1.1-py3-none-any.whl → gemba-0.1.3-py3-none-any.whl

gemba/gpt_api.py CHANGED
@@ -1,12 +1,13 @@
 import os
 import sys
 import time
-import ipdb
 import logging
 from termcolor import colored
-from datetime import datetime
 import openai
 import tqdm
+from concurrent.futures import ThreadPoolExecutor
+from collections import defaultdict
+


 # class for calling OpenAI API and handling cache
@@ -58,6 +59,22 @@ class GptApi:
         for full_answer in answers:
             finish_reason = full_answer["finish_reason"]
             full_answer = full_answer["answer"]
+
+            if finish_reason != "stop":
+                print(f"No valid answer, giving score 0")
+                errors = defaultdict(list)
+                errors["critical"].append("Judge errored, giving answer score 0.")
+                parsed_answers.append({
+                    "temperature": temperature,
+                    "answer_id": answer_id,
+                    "answer": 0,
+                    "errors": errors,
+                    "prompt": prompt,
+                    "finish_reason": finish_reason,
+                    "model": model,
+                })
+                continue
+
             answer_id += 1
             answer = parse_response(full_answer)
             if isinstance(answer, tuple):
@@ -68,33 +85,32 @@ class GptApi:
                 print(f"Answer (t={temperature}): " + colored(answer, "yellow") + " (" + colored(full_answer, "blue") + ")", file=sys.stderr)
             if answer is None:
                 continue
-            parsed_answers.append(
-                {
-                    "temperature": temperature,
-                    "answer_id": answer_id,
-                    "answer": answer,
-                    "errors": errors,
-                    "prompt": prompt,
-                    "finish_reason": finish_reason,
-                    "model": model,
-                }
-            )
+            parsed_answers.append({
+                "temperature": temperature,
+                "answer_id": answer_id,
+                "answer": answer,
+                "errors": errors,
+                "prompt": prompt,
+                "finish_reason": finish_reason,
+                "model": model,
+            })

         # there was no valid answer, increase temperature and try again
         if len(parsed_answers) == 0:
+            print(f"No valid answer, increasing temperature to {temperature + 1} and trying again")
             return self.request(prompt, model, parse_response, temperature=temperature + 1, answer_id=answer_id, cache=cache)

         return parsed_answers

     def request_api(self, prompt, model, temperature=0, max_tokens=None):
         if temperature > 10:
-            return []
+            return [{"answer": None, "finish_reason": "error"}]

         # Add maximum token limit
         MAX_TOKENS_LIMIT = 4000 # Adjust this based on your model's context window
         if max_tokens and max_tokens > MAX_TOKENS_LIMIT:
             print(f"Reached maximum token limit of {MAX_TOKENS_LIMIT}", file=sys.stderr)
-            return []
+            return [{"answer": None, "finish_reason": "length"}]

         while True:
             try:
@@ -104,10 +120,10 @@ class GptApi:
                 # response was filtered
                 if hasattr(e, 'code'):
                     if e.code == 'content_filter':
-                        return []
+                        return [{"answer": None, "finish_reason": "filter"}]
                     print(e.code, file=sys.stderr)
                 if hasattr(e, 'error') and e.error['code'] == 'invalid_model_output':
-                    return []
+                    return [{"answer": None, "finish_reason": "invalid"}]

                 # frequent error is reaching the API limit
                 print(colored("Error, retrying...", "red"), file=sys.stderr)
@@ -117,7 +133,7 @@ class GptApi:
         answers = []
         for choice in response.choices:
             if choice.message.content is None:
-                return []
+                return [{"answer": None, "finish_reason": "invalid"}]
             if hasattr(choice, "message"):
                 answer = choice.message.content.strip()
             else:
@@ -127,13 +143,13 @@ class GptApi:
             if choice.finish_reason != "stop":
                 if self.verbose:
                     print(colored(f"Increasing max tokens to fit answers.", "red") + colored(answer, "blue"), file=sys.stderr)
-                print(f"Finish reason: {choice.finish_reason}", file=sys.stderr)
                 if max_tokens is None:
                     max_tokens = 500 # Set initial max_tokens if None
-                new_max_tokens = max_tokens + 200
+                new_max_tokens = max_tokens * 2
+                print(f"Finish reason: {choice.finish_reason}, increasing max tokens to {new_max_tokens}", file=sys.stderr)
                 if new_max_tokens > MAX_TOKENS_LIMIT:
                     print(f"Would exceed maximum token limit of {MAX_TOKENS_LIMIT}", file=sys.stderr)
-                    return []
+                    return [{"answer": None, "finish_reason": choice.finish_reason}]
                 return self.request_api(prompt, model, temperature=temperature, max_tokens=new_max_tokens)

             answers.append({
@@ -177,8 +193,13 @@ class GptApi:

     def bulk_request(self, df, model, parse_mqm_answer, cache, max_tokens=None):
         answers = []
-        for i, row in tqdm.tqdm(df.iterrows(), total=len(df), file=sys.stderr):
-            prompt = row["prompt"]
-            parsed_answers = self.request(prompt, model, parse_mqm_answer, cache=cache, max_tokens=max_tokens)
-            answers += parsed_answers
+        with ThreadPoolExecutor(100) as executor:
+            futures = [
+                executor.submit(self.request, row["prompt"], model, parse_mqm_answer, cache=cache, max_tokens=max_tokens)
+                for _, row in df.iterrows()
+            ]
+
+            for future in tqdm.tqdm(futures, total=len(df), file=sys.stderr):
+                answers += future.result()
+
         return answers
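
Taken together, the gpt_api.py changes do two things. First, the error paths in request_api now return a sentinel record of the form {"answer": None, "finish_reason": "..."} instead of an empty list, and request turns any non-"stop" finish reason into an answer of 0 with a "critical" error note, so a failed judge call now yields an explicit zero score with the failure recorded in errors. Second, bulk_request fans the per-row requests out over a ThreadPoolExecutor(100) and collects the futures in submission order, so the returned answers stay in row order. A minimal, self-contained sketch of that fan-out/collect pattern follows; score_prompt and the example prompts are placeholders for illustration, not part of gemba:

    from concurrent.futures import ThreadPoolExecutor

    def score_prompt(prompt: str) -> dict:
        # Stand-in for GptApi.request: pretend every prompt parses to a score of 100.
        return {"prompt": prompt, "answer": 100, "finish_reason": "stop"}

    prompts = [f"Score translation candidate {i}" for i in range(5)]

    with ThreadPoolExecutor(max_workers=100) as executor:
        # Submit everything up front so up to 100 calls run concurrently ...
        futures = [executor.submit(score_prompt, p) for p in prompts]
        # ... then read results in submission order, keeping outputs aligned with inputs.
        results = [f.result() for f in futures]

    print([r["answer"] for r in results])

Because results are read in submission order rather than completion order, the thread pool changes throughput only; the order of the answers returned by bulk_request is the same as with the old sequential loop.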
{gemba-0.1.1.dist-info → gemba-0.1.3.dist-info}/METADATA RENAMED
@@ -1,10 +1,11 @@
-Metadata-Version: 2.3
+Metadata-Version: 2.4
 Name: gemba
-Version: 0.1.1
+Version: 0.1.3
 Summary: GEMBA — GPT Estimation Metric Based Assessment
 Project-URL: Homepage, https://github.com/joelniklaus/gemba
 Author-email: Joel Niklaus <joel@niklaus.ai>
-License: MIT
+License-Expression: MIT
+License-File: LICENSE.md
 Classifier: License :: OSI Approved :: MIT License
 Classifier: Operating System :: OS Independent
 Classifier: Programming Language :: Python :: 3
{gemba-0.1.1.dist-info → gemba-0.1.3.dist-info}/RECORD RENAMED
@@ -2,13 +2,13 @@ gemba/__init__.py,sha256=0ZuEumkUMWPI5wQMY7OxLolELI9GYYlup-iJw8SwBgc,67
 gemba/gemba_da.py,sha256=YCOKKP7kZBL9e1d44Zr7aTa23BqLFvh4KDOfbNSMgOU,2360
 gemba/gemba_esa.py,sha256=nBCeFjrS24wXLOcAXHRSmZFYJSkUzRS4hfp2LEqYwp8,4461
 gemba/gemba_mqm_utils.py,sha256=qiIdJv7IDx0eeqpsTCHMoUeo8EUOhG6k-YfrzkRfxyw,9612
-gemba/gpt_api.py,sha256=mucb4iGRk8ZHMHx9uSQpaFodPgMP_r5sjkvQwot8j3M,7245
+gemba/gpt_api.py,sha256=A1GYi0vxUGmephkadI-h6v6G52uDQ7yWOFvIxSRrN8o,8380
 gemba/mtme_tools.py,sha256=xpLxCzfnLHFIxsq_LOi1Lpb-gkyFGYqFXiq9y6O315Q,4667
 gemba/prompt.py,sha256=AuPBhO2OBL3EB5I37p-GX10sx29gRw35xFAnB3bqtII,7578
 gemba/scores.py,sha256=FmmBJ-ds-abExphcVUw9qaPMnKttPWobuXNwZKLAtEs,4388
 gemba/testset.py,sha256=tDvi6xQIBXrODg02WWINrYg9jNQqruCmhBrxe9AaK48,1926
 gemba/utils.py,sha256=Re5uW5dcFj3ITWIGpxjXdAKNDKQ7i4H-Tr_s74SQgmk,4311
-gemba-0.1.1.dist-info/METADATA,sha256=crV8OU2-e1rixB9gfFGjvidoqQLlX5kRZ23sVKqLhqc,3691
-gemba-0.1.1.dist-info/WHEEL,sha256=C2FUgwZgiLbznR-k0b_5k3Ai_1aASOXDss3lzCUsUug,87
-gemba-0.1.1.dist-info/licenses/LICENSE.md,sha256=XkNv-P-7d9hgciDpvOIMiRXYYAEP7rbB6-9ahWiOmzk,20137
-gemba-0.1.1.dist-info/RECORD,,
+gemba-0.1.3.dist-info/METADATA,sha256=qyrjjVewIjFJeWKTAXw9rCMqLIj9OqCntenyw0F2oyw,3727
+gemba-0.1.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+gemba-0.1.3.dist-info/licenses/LICENSE.md,sha256=XkNv-P-7d9hgciDpvOIMiRXYYAEP7rbB6-9ahWiOmzk,20137
+gemba-0.1.3.dist-info/RECORD,,
{gemba-0.1.1.dist-info → gemba-0.1.3.dist-info}/WHEEL RENAMED
@@ -1,4 +1,4 @@
 Wheel-Version: 1.0
-Generator: hatchling 1.26.3
+Generator: hatchling 1.27.0
 Root-Is-Purelib: true
 Tag: py3-none-any