gemba 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
gemba/gpt_api.py CHANGED
@@ -1,12 +1,13 @@
  import os
  import sys
  import time
- import ipdb
  import logging
  from termcolor import colored
- from datetime import datetime
  import openai
  import tqdm
+ from concurrent.futures import ThreadPoolExecutor
+ from collections import defaultdict
+


  # class for calling OpenAI API and handling cache
@@ -58,6 +59,22 @@ class GptApi:
  for full_answer in answers:
  finish_reason = full_answer["finish_reason"]
  full_answer = full_answer["answer"]
+
+ if finish_reason != "stop":
+ print(f"No valid answer, giving score 0")
+ errors = defaultdict(list)
+ errors["critical"].append("Judge errored, giving answer score 0.")
+ parsed_answers.append({
+ "temperature": temperature,
+ "answer_id": answer_id,
+ "answer": 0,
+ "errors": errors,
+ "prompt": prompt,
+ "finish_reason": finish_reason,
+ "model": model,
+ })
+ continue
+
  answer_id += 1
  answer = parse_response(full_answer)
  if isinstance(answer, tuple):
@@ -68,33 +85,32 @@ class GptApi:
  print(f"Answer (t={temperature}): " + colored(answer, "yellow") + " (" + colored(full_answer, "blue") + ")", file=sys.stderr)
  if answer is None:
  continue
- parsed_answers.append(
- {
- "temperature": temperature,
- "answer_id": answer_id,
- "answer": answer,
- "errors": errors,
- "prompt": prompt,
- "finish_reason": finish_reason,
- "model": model,
- }
- )
+ parsed_answers.append({
+ "temperature": temperature,
+ "answer_id": answer_id,
+ "answer": answer,
+ "errors": errors,
+ "prompt": prompt,
+ "finish_reason": finish_reason,
+ "model": model,
+ })

  # there was no valid answer, increase temperature and try again
  if len(parsed_answers) == 0:
+ print(f"No valid answer, increasing temperature to {temperature + 1} and trying again")
  return self.request(prompt, model, parse_response, temperature=temperature + 1, answer_id=answer_id, cache=cache)

  return parsed_answers

  def request_api(self, prompt, model, temperature=0, max_tokens=None):
  if temperature > 10:
- return []
+ return [{"answer": None, "finish_reason": "error"}]

  # Add maximum token limit
  MAX_TOKENS_LIMIT = 4000 # Adjust this based on your model's context window
  if max_tokens and max_tokens > MAX_TOKENS_LIMIT:
  print(f"Reached maximum token limit of {MAX_TOKENS_LIMIT}", file=sys.stderr)
- return []
+ return [{"answer": None, "finish_reason": "length"}]

  while True:
  try:
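When nothing parses, request() retries itself with the temperature raised by one, and request_api() now gives up once the temperature exceeds 10 by returning a sentinel answer (finish_reason "error") instead of an empty list, which the caller's new finish_reason branch records as a score of 0. A minimal, illustrative sketch of that escalation loop, with a hypothetical fake_judge standing in for the API call:

# Illustrative sketch only; `fake_judge` is a hypothetical stand-in for the API call.
def fake_judge(temperature):
    # pretend the judge only returns a parseable answer at temperature >= 3
    return "79" if temperature >= 3 else None

def request(temperature=0):
    if temperature > 10:
        # request_api's guard: a sentinel answer instead of an empty list;
        # the caller's finish_reason check turns this into a score of 0
        return [{"answer": None, "finish_reason": "error"}]
    answer = fake_judge(temperature)
    if answer is None:
        # no valid answer: raise the temperature and try again, as in GptApi.request
        return request(temperature + 1)
    return [{"answer": answer, "finish_reason": "stop"}]

print(request())  # [{'answer': '79', 'finish_reason': 'stop'}]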
@@ -104,10 +120,10 @@ class GptApi:
  # response was filtered
  if hasattr(e, 'code'):
  if e.code == 'content_filter':
- return []
+ return [{"answer": None, "finish_reason": "filter"}]
  print(e.code, file=sys.stderr)
  if hasattr(e, 'error') and e.error['code'] == 'invalid_model_output':
- return []
+ return [{"answer": None, "finish_reason": "invalid"}]

  # frequent error is reaching the API limit
  print(colored("Error, retrying...", "red"), file=sys.stderr)
@@ -117,7 +133,7 @@ class GptApi:
  answers = []
  for choice in response.choices:
  if choice.message.content is None:
- return []
+ return [{"answer": None, "finish_reason": "invalid"}]
  if hasattr(choice, "message"):
  answer = choice.message.content.strip()
  else:
@@ -127,13 +143,13 @@ class GptApi:
  if choice.finish_reason != "stop":
  if self.verbose:
  print(colored(f"Increasing max tokens to fit answers.", "red") + colored(answer, "blue"), file=sys.stderr)
- print(f"Finish reason: {choice.finish_reason}", file=sys.stderr)
  if max_tokens is None:
  max_tokens = 500 # Set initial max_tokens if None
- new_max_tokens = max_tokens + 200
+ new_max_tokens = max_tokens * 2
+ print(f"Finish reason: {choice.finish_reason}, increasing max tokens to {new_max_tokens}", file=sys.stderr)
  if new_max_tokens > MAX_TOKENS_LIMIT:
  print(f"Would exceed maximum token limit of {MAX_TOKENS_LIMIT}", file=sys.stderr)
- return []
+ return [{"answer": None, "finish_reason": choice.finish_reason}]
  return self.request_api(prompt, model, temperature=temperature, max_tokens=new_max_tokens)

  answers.append({
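With max_tokens now doubling instead of growing by 200, a truncated answer escalates the budget geometrically, so the 4000-token cap is reached after a few retries rather than roughly eighteen. A short sketch of the new growth schedule, assuming the initial value of 500 and the cap of 4000 shown in the diff:

# Illustrative sketch of the token-budget schedule; values taken from the diff above.
max_tokens, MAX_TOKENS_LIMIT = 500, 4000
budgets = []
while max_tokens <= MAX_TOKENS_LIMIT:
    budgets.append(max_tokens)
    max_tokens *= 2  # new_max_tokens = max_tokens * 2, as in request_api
print(budgets)  # [500, 1000, 2000, 4000] -- the cap is hit after three doublings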
@@ -177,8 +193,13 @@ class GptApi:

  def bulk_request(self, df, model, parse_mqm_answer, cache, max_tokens=None):
  answers = []
- for i, row in tqdm.tqdm(df.iterrows(), total=len(df), file=sys.stderr):
- prompt = row["prompt"]
- parsed_answers = self.request(prompt, model, parse_mqm_answer, cache=cache, max_tokens=max_tokens)
- answers += parsed_answers
+ with ThreadPoolExecutor(100) as executor:
+ futures = [
+ executor.submit(self.request, row["prompt"], model, parse_mqm_answer, cache=cache, max_tokens=max_tokens)
+ for _, row in df.iterrows()
+ ]
+
+ for future in tqdm.tqdm(futures, total=len(df), file=sys.stderr):
+ answers += future.result()
+
  return answers
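bulk_request now submits every prompt to a 100-worker thread pool up front and then collects the futures in submission order, so the answers list keeps the dataframe's row order even though the underlying API calls complete concurrently. A self-contained sketch of the same submit-then-collect pattern, with a hypothetical slow_request in place of GptApi.request:

# Illustrative sketch; `slow_request` is a hypothetical stand-in for GptApi.request.
import time
from concurrent.futures import ThreadPoolExecutor
import tqdm

def slow_request(prompt):
    time.sleep(0.1)  # stand-in for one API round trip
    return [{"prompt": prompt, "answer": len(prompt)}]

prompts = [f"prompt {i}" for i in range(10)]
answers = []
with ThreadPoolExecutor(8) as executor:
    # fire off all requests first ...
    futures = [executor.submit(slow_request, p) for p in prompts]
    # ... then gather results in submission order, so output order matches input order
    for future in tqdm.tqdm(futures, total=len(prompts)):
        answers += future.result()
print(len(answers))  # 10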
gemba-0.1.1.dist-info/METADATA → gemba-0.1.3.dist-info/METADATA CHANGED
@@ -1,10 +1,11 @@
- Metadata-Version: 2.3
+ Metadata-Version: 2.4
  Name: gemba
- Version: 0.1.1
+ Version: 0.1.3
  Summary: GEMBA — GPT Estimation Metric Based Assessment
  Project-URL: Homepage, https://github.com/joelniklaus/gemba
  Author-email: Joel Niklaus <joel@niklaus.ai>
- License: MIT
+ License-Expression: MIT
+ License-File: LICENSE.md
  Classifier: License :: OSI Approved :: MIT License
  Classifier: Operating System :: OS Independent
  Classifier: Programming Language :: Python :: 3
gemba-0.1.1.dist-info/RECORD → gemba-0.1.3.dist-info/RECORD CHANGED
@@ -2,13 +2,13 @@ gemba/__init__.py,sha256=0ZuEumkUMWPI5wQMY7OxLolELI9GYYlup-iJw8SwBgc,67
  gemba/gemba_da.py,sha256=YCOKKP7kZBL9e1d44Zr7aTa23BqLFvh4KDOfbNSMgOU,2360
  gemba/gemba_esa.py,sha256=nBCeFjrS24wXLOcAXHRSmZFYJSkUzRS4hfp2LEqYwp8,4461
  gemba/gemba_mqm_utils.py,sha256=qiIdJv7IDx0eeqpsTCHMoUeo8EUOhG6k-YfrzkRfxyw,9612
- gemba/gpt_api.py,sha256=mucb4iGRk8ZHMHx9uSQpaFodPgMP_r5sjkvQwot8j3M,7245
+ gemba/gpt_api.py,sha256=A1GYi0vxUGmephkadI-h6v6G52uDQ7yWOFvIxSRrN8o,8380
  gemba/mtme_tools.py,sha256=xpLxCzfnLHFIxsq_LOi1Lpb-gkyFGYqFXiq9y6O315Q,4667
  gemba/prompt.py,sha256=AuPBhO2OBL3EB5I37p-GX10sx29gRw35xFAnB3bqtII,7578
  gemba/scores.py,sha256=FmmBJ-ds-abExphcVUw9qaPMnKttPWobuXNwZKLAtEs,4388
  gemba/testset.py,sha256=tDvi6xQIBXrODg02WWINrYg9jNQqruCmhBrxe9AaK48,1926
  gemba/utils.py,sha256=Re5uW5dcFj3ITWIGpxjXdAKNDKQ7i4H-Tr_s74SQgmk,4311
- gemba-0.1.1.dist-info/METADATA,sha256=crV8OU2-e1rixB9gfFGjvidoqQLlX5kRZ23sVKqLhqc,3691
- gemba-0.1.1.dist-info/WHEEL,sha256=C2FUgwZgiLbznR-k0b_5k3Ai_1aASOXDss3lzCUsUug,87
- gemba-0.1.1.dist-info/licenses/LICENSE.md,sha256=XkNv-P-7d9hgciDpvOIMiRXYYAEP7rbB6-9ahWiOmzk,20137
- gemba-0.1.1.dist-info/RECORD,,
+ gemba-0.1.3.dist-info/METADATA,sha256=qyrjjVewIjFJeWKTAXw9rCMqLIj9OqCntenyw0F2oyw,3727
+ gemba-0.1.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ gemba-0.1.3.dist-info/licenses/LICENSE.md,sha256=XkNv-P-7d9hgciDpvOIMiRXYYAEP7rbB6-9ahWiOmzk,20137
+ gemba-0.1.3.dist-info/RECORD,,
gemba-0.1.1.dist-info/WHEEL → gemba-0.1.3.dist-info/WHEEL CHANGED
@@ -1,4 +1,4 @@
  Wheel-Version: 1.0
- Generator: hatchling 1.26.3
+ Generator: hatchling 1.27.0
  Root-Is-Purelib: true
  Tag: py3-none-any