gemba 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
    
gemba/gpt_api.py CHANGED
@@ -1,12 +1,13 @@
 import os
 import sys
 import time
-import ipdb
 import logging
 from termcolor import colored
-from datetime import datetime
 import openai
 import tqdm
+from concurrent.futures import ThreadPoolExecutor
+from collections import defaultdict
+
 
 
 # class for calling OpenAI API and handling cache
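Two of the new imports drive the changes below: ThreadPoolExecutor parallelizes bulk_request, and defaultdict(list) lets the judge attach error notes without first initializing each severity key. A minimal sketch of the defaultdict pattern (the "minor" entry is illustrative, not from the package):

    from collections import defaultdict

    errors = defaultdict(list)   # a missing key materializes as an empty list
    errors["critical"].append("Judge errored, giving answer score 0.")
    errors["minor"].append("awkward phrasing")   # no KeyError on a new key
    print(dict(errors))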
@@ -58,6 +59,22 @@ class GptApi:
         for full_answer in answers:
             finish_reason = full_answer["finish_reason"]
             full_answer = full_answer["answer"]
+
+            if finish_reason != "stop":
+                print(f"No valid answer, giving score 0")
+                errors = defaultdict(list)
+                errors["critical"].append("Judge errored, giving answer score 0.")
+                parsed_answers.append({
+                    "temperature": temperature,
+                    "answer_id": answer_id,
+                    "answer": 0,
+                    "errors": errors,
+                    "prompt": prompt,
+                    "finish_reason": finish_reason,
+                    "model": model,
+                })
+                continue
+
             answer_id += 1
             answer = parse_response(full_answer)
             if isinstance(answer, tuple):
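The new branch turns any generation that did not finish with "stop" into a recorded score of 0 with a "critical" error, instead of silently dropping it. Downstream code can still tell these apart from genuine zero scores via finish_reason; a hedged sketch of such a filter (the helper name is hypothetical, the field names follow the diff):

    def split_judge_failures(parsed_answers):
        """Separate real judge scores from zero-scores injected on failure."""
        scores, failures = [], []
        for a in parsed_answers:
            if a["finish_reason"] != "stop":
                failures.append(a)          # injected 0 with a critical error
            else:
                scores.append(a["answer"])
        return scores, failures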
@@ -68,33 +85,32 @@ class GptApi:
                 print(f"Answer (t={temperature}): " + colored(answer, "yellow") + " (" + colored(full_answer, "blue") + ")", file=sys.stderr)
             if answer is None:
                 continue
-            parsed_answers.append(
-                {
-                    "temperature": temperature,
-                    "answer_id": answer_id,
-                    "answer": answer,
-                    "errors": errors,
-                    "prompt": prompt,
-                    "finish_reason": finish_reason,
-                    "model": model,
-                }
-            )
+            parsed_answers.append({
+                "temperature": temperature,
+                "answer_id": answer_id,
+                "answer": answer,
+                "errors": errors,
+                "prompt": prompt,
+                "finish_reason": finish_reason,
+                "model": model,
+            })
 
         # there was no valid answer, increase temperature and try again
         if len(parsed_answers) == 0:
+            print(f"No valid answer, increasing temperature to {temperature + 1} and trying again")
             return self.request(prompt, model, parse_response, temperature=temperature + 1, answer_id=answer_id, cache=cache)
 
         return parsed_answers
 
     def request_api(self, prompt, model, temperature=0, max_tokens=None):
         if temperature > 10:
-            return []
+            return [{"answer": None, "finish_reason": "error"}]
 
         # Add maximum token limit
         MAX_TOKENS_LIMIT = 4000  # Adjust this based on your model's context window
         if max_tokens and max_tokens > MAX_TOKENS_LIMIT:
             print(f"Reached maximum token limit of {MAX_TOKENS_LIMIT}", file=sys.stderr)
-            return []
+            return [{"answer": None, "finish_reason": "length"}]
 
         while True:
             try:
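The recurring edit in request_api is a change of failure contract: every early-exit path now returns a one-element list holding a sentinel answer ({"answer": None, "finish_reason": ...}) instead of an empty list, so the finish_reason != "stop" branch above converts the failure into a score of 0 rather than looping forever on retries. A sketch of the contract as it appears from this diff (the helper and the reason set are inferred, not an authoritative API):

    FAILURE_REASONS = {"error", "length", "filter", "invalid"}

    def is_failure_sentinel(answers):
        return (
            len(answers) == 1
            and answers[0]["answer"] is None
            and answers[0]["finish_reason"] in FAILURE_REASONS
        )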
@@ -104,10 +120,10 @@ class GptApi:
                 # response was filtered
                 if hasattr(e, 'code'):
                     if e.code == 'content_filter':
-                        return []
+                        return [{"answer": None, "finish_reason": "filter"}]
                     print(e.code, file=sys.stderr)
                 if hasattr(e, 'error') and e.error['code'] == 'invalid_model_output':
-                    return []
+                    return [{"answer": None, "finish_reason": "invalid"}]
 
                 # frequent error is reaching the API limit
                 print(colored("Error, retrying...", "red"), file=sys.stderr)
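The hasattr probing exists because different openai client versions and exception types expose different attributes. The same classification can be written with getattr defaults; a sketch, not the package's code:

    def classify_openai_error(e):
        if getattr(e, "code", None) == "content_filter":
            return "filter"
        err = getattr(e, "error", None)
        if err and err.get("code") == "invalid_model_output":
            return "invalid"
        return None   # unknown error: fall through to the retry loop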
@@ -117,7 +133,7 @@ class GptApi:
         answers = []
         for choice in response.choices:
             if choice.message.content is None:
-                return []
+                return [{"answer": None, "finish_reason": "invalid"}]
             if hasattr(choice, "message"):
                 answer = choice.message.content.strip()
             else:
@@ -127,13 +143,13 @@ class GptApi:
             if choice.finish_reason != "stop":
                 if self.verbose:
                     print(colored(f"Increasing max tokens to fit answers.", "red") + colored(answer, "blue"), file=sys.stderr)
-                print(f"Finish reason: {choice.finish_reason}", file=sys.stderr)
                 if max_tokens is None:
                     max_tokens = 500  # Set initial max_tokens if None
-                new_max_tokens = max_tokens …
+                new_max_tokens = max_tokens * 2
+                print(f"Finish reason: {choice.finish_reason}, increasing max tokens to {new_max_tokens}", file=sys.stderr)
                 if new_max_tokens > MAX_TOKENS_LIMIT:
                     print(f"Would exceed maximum token limit of {MAX_TOKENS_LIMIT}", file=sys.stderr)
-                    return []
+                    return [{"answer": None, "finish_reason": choice.finish_reason}]
                 return self.request_api(prompt, model, temperature=temperature, max_tokens=new_max_tokens)
 
         answers.append({
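With max_tokens now doubling on each retry, the recursive chain reaches the 4000-token ceiling in a few calls before giving up with a sentinel answer. A quick check of the sequence under the constants in this diff:

    MAX_TOKENS_LIMIT = 4000
    max_tokens, steps = 500, []
    while max_tokens <= MAX_TOKENS_LIMIT:
        steps.append(max_tokens)
        max_tokens *= 2
    print(steps)   # [500, 1000, 2000, 4000] -> at most a few retries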
@@ -177,8 +193,13 @@ class GptApi:
 
     def bulk_request(self, df, model, parse_mqm_answer, cache, max_tokens=None):
         answers = []
-        for i, row in tqdm.tqdm(df.iterrows(), total=len(df), file=sys.stderr):
-            prompt = row["prompt"]
-            parsed_answers = self.request(prompt, model, parse_mqm_answer, cache=cache, max_tokens=max_tokens)
-            answers += parsed_answers
+        with ThreadPoolExecutor(100) as executor:
+            futures = [
+                executor.submit(self.request, row["prompt"], model, parse_mqm_answer, cache=cache, max_tokens=max_tokens)
+                for _, row in df.iterrows()
+            ]
+
+            for future in tqdm.tqdm(futures, total=len(df), file=sys.stderr):
+                answers += future.result()
+
         return answers
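bulk_request now submits one request per DataFrame row to a pool of up to 100 threads, then collects results by iterating the futures in submission order, so the flattened answers list stays aligned with the rows of df. A self-contained sketch of the same pattern (slow_call is a stand-in for the API request):

    import time
    from concurrent.futures import ThreadPoolExecutor

    def slow_call(i):   # stand-in workload; each call returns a list
        time.sleep(0.1)
        return [i * i]

    with ThreadPoolExecutor(100) as executor:
        futures = [executor.submit(slow_call, i) for i in range(5)]
        results = []
        for future in futures:          # iteration order == submission order
            results += future.result()  # blocks until that future finishes
    print(results)   # [0, 1, 4, 9, 16]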
gemba-0.1.3.dist-info/METADATA CHANGED

@@ -1,10 +1,11 @@
-Metadata-Version: 2.…
+Metadata-Version: 2.4
 Name: gemba
-Version: 0.1.1
+Version: 0.1.3
 Summary: GEMBA — GPT Estimation Metric Based Assessment
 Project-URL: Homepage, https://github.com/joelniklaus/gemba
 Author-email: Joel Niklaus <joel@niklaus.ai>
-License: MIT
+License-Expression: MIT
+License-File: LICENSE.md
 Classifier: License :: OSI Approved :: MIT License
 Classifier: Operating System :: OS Independent
 Classifier: Programming Language :: Python :: 3
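The metadata bump to version 2.4 follows PEP 639: the free-form License: MIT field is replaced by an SPDX License-Expression plus an explicit License-File entry. Assuming gemba 0.1.3 is installed, the new fields can be read back with the standard library:

    from importlib.metadata import metadata

    md = metadata("gemba")
    print(md["Metadata-Version"])     # 2.4
    print(md["License-Expression"])   # MIT
    print(md["License-File"])         # LICENSE.md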
gemba-0.1.3.dist-info/RECORD CHANGED

@@ -2,13 +2,13 @@ gemba/__init__.py,sha256=0ZuEumkUMWPI5wQMY7OxLolELI9GYYlup-iJw8SwBgc,67
 gemba/gemba_da.py,sha256=YCOKKP7kZBL9e1d44Zr7aTa23BqLFvh4KDOfbNSMgOU,2360
 gemba/gemba_esa.py,sha256=nBCeFjrS24wXLOcAXHRSmZFYJSkUzRS4hfp2LEqYwp8,4461
 gemba/gemba_mqm_utils.py,sha256=qiIdJv7IDx0eeqpsTCHMoUeo8EUOhG6k-YfrzkRfxyw,9612
-gemba/gpt_api.py,sha256=…
+gemba/gpt_api.py,sha256=A1GYi0vxUGmephkadI-h6v6G52uDQ7yWOFvIxSRrN8o,8380
 gemba/mtme_tools.py,sha256=xpLxCzfnLHFIxsq_LOi1Lpb-gkyFGYqFXiq9y6O315Q,4667
 gemba/prompt.py,sha256=AuPBhO2OBL3EB5I37p-GX10sx29gRw35xFAnB3bqtII,7578
 gemba/scores.py,sha256=FmmBJ-ds-abExphcVUw9qaPMnKttPWobuXNwZKLAtEs,4388
 gemba/testset.py,sha256=tDvi6xQIBXrODg02WWINrYg9jNQqruCmhBrxe9AaK48,1926
 gemba/utils.py,sha256=Re5uW5dcFj3ITWIGpxjXdAKNDKQ7i4H-Tr_s74SQgmk,4311
-gemba-0.1.1.dist-info/…
-gemba-0.1.1.dist-info/…
-gemba-0.1.1.dist-info/…
-gemba-0.1.1.dist-info/…
+gemba-0.1.3.dist-info/METADATA,sha256=qyrjjVewIjFJeWKTAXw9rCMqLIj9OqCntenyw0F2oyw,3727
+gemba-0.1.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+gemba-0.1.3.dist-info/licenses/LICENSE.md,sha256=XkNv-P-7d9hgciDpvOIMiRXYYAEP7rbB6-9ahWiOmzk,20137
+gemba-0.1.3.dist-info/RECORD,,
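RECORD lists each file in the wheel with an urlsafe-base64 SHA-256 digest (padding stripped) and its size in bytes, which is why the edited gpt_api.py and the renamed dist-info files all get new entries. A sketch of how one such entry is derived, per the wheel RECORD format:

    import base64, hashlib

    def record_entry(path):
        data = open(path, "rb").read()
        digest = base64.urlsafe_b64encode(
            hashlib.sha256(data).digest()).rstrip(b"=").decode()
        return f"{path},sha256={digest},{len(data)}"

    # record_entry("gemba/gpt_api.py") would reproduce the line ending in 8380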