eval-framework 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eval_framework/__init__.py +7 -0
- eval_framework/base_config.py +36 -0
- eval_framework/context/__init__.py +0 -0
- eval_framework/context/determined.py +177 -0
- eval_framework/context/eval.py +121 -0
- eval_framework/context/local.py +78 -0
- eval_framework/evaluation_generator.py +234 -0
- eval_framework/exceptions.py +2 -0
- eval_framework/external/ifeval_impl/README.md +5 -0
- eval_framework/external/ifeval_impl/instructions.py +1523 -0
- eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
- eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
- eval_framework/external/ifeval_impl/utils.py +135 -0
- eval_framework/llm/__init__.py +0 -0
- eval_framework/llm/aleph_alpha.py +432 -0
- eval_framework/llm/base.py +180 -0
- eval_framework/llm/huggingface.py +418 -0
- eval_framework/llm/mistral.py +88 -0
- eval_framework/llm/models.py +28 -0
- eval_framework/llm/openai.py +400 -0
- eval_framework/llm/vllm.py +554 -0
- eval_framework/logger.py +3 -0
- eval_framework/main.py +166 -0
- eval_framework/metrics/__init__.py +0 -0
- eval_framework/metrics/base.py +40 -0
- eval_framework/metrics/completion/__init__.py +1 -0
- eval_framework/metrics/completion/accuracy_completion.py +16 -0
- eval_framework/metrics/completion/aidanbench.py +28 -0
- eval_framework/metrics/completion/bleu.py +76 -0
- eval_framework/metrics/completion/chrf.py +62 -0
- eval_framework/metrics/completion/code_assertion.py +44 -0
- eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
- eval_framework/metrics/completion/comet.py +56 -0
- eval_framework/metrics/completion/concordance_index.py +38 -0
- eval_framework/metrics/completion/csv_format.py +102 -0
- eval_framework/metrics/completion/cwe_accuracy.py +49 -0
- eval_framework/metrics/completion/exponential_similarity.py +65 -0
- eval_framework/metrics/completion/f1.py +42 -0
- eval_framework/metrics/completion/format_checker.py +56 -0
- eval_framework/metrics/completion/grid_difference.py +77 -0
- eval_framework/metrics/completion/ifeval.py +73 -0
- eval_framework/metrics/completion/json_format.py +179 -0
- eval_framework/metrics/completion/language_checker.py +74 -0
- eval_framework/metrics/completion/length_control.py +83 -0
- eval_framework/metrics/completion/math_reasoning_completion.py +307 -0
- eval_framework/metrics/completion/niah_accuracy.py +163 -0
- eval_framework/metrics/completion/placeholder_checker.py +27 -0
- eval_framework/metrics/completion/repetition.py +88 -0
- eval_framework/metrics/completion/rouge_1.py +35 -0
- eval_framework/metrics/completion/rouge_2.py +45 -0
- eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
- eval_framework/metrics/completion/rouge_l.py +52 -0
- eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
- eval_framework/metrics/completion/ter.py +67 -0
- eval_framework/metrics/completion/text_counter.py +182 -0
- eval_framework/metrics/efficiency/__init__.py +0 -0
- eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
- eval_framework/metrics/llm/__init__.py +0 -0
- eval_framework/metrics/llm/base.py +34 -0
- eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
- eval_framework/metrics/llm/graders/coherence_grader.py +115 -0
- eval_framework/metrics/llm/graders/comparison_grader.py +198 -0
- eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
- eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
- eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
- eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
- eval_framework/metrics/llm/graders/language.py +56 -0
- eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
- eval_framework/metrics/llm/graders/models.py +74 -0
- eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
- eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
- eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
- eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
- eval_framework/metrics/llm/llm_judge_coherence.py +44 -0
- eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
- eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
- eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
- eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
- eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
- eval_framework/metrics/llm/llm_judge_mtbench_pair.py +306 -0
- eval_framework/metrics/llm/llm_judge_mtbench_single.py +210 -0
- eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
- eval_framework/metrics/llm/llm_judge_sql.py +394 -0
- eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
- eval_framework/metrics/llm/utils.py +20 -0
- eval_framework/metrics/loglikelihood/__init__.py +0 -0
- eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
- eval_framework/metrics/loglikelihood/base.py +50 -0
- eval_framework/metrics/loglikelihood/confidence_weighted_accuracy.py +25 -0
- eval_framework/metrics/loglikelihood/dcs.py +43 -0
- eval_framework/metrics/loglikelihood/probability_mass.py +53 -0
- eval_framework/metrics/loglikelihood/ternary.py +42 -0
- eval_framework/py.typed +0 -0
- eval_framework/response_generator.py +351 -0
- eval_framework/result_processors/__init__.py +0 -0
- eval_framework/result_processors/base.py +88 -0
- eval_framework/result_processors/hf_uploader.py +75 -0
- eval_framework/result_processors/result_processor.py +129 -0
- eval_framework/result_processors/wandb_uploader.py +137 -0
- eval_framework/run.py +369 -0
- eval_framework/run_direct.py +42 -0
- eval_framework/shared/types.py +227 -0
- eval_framework/tasks/__init__.py +6 -0
- eval_framework/tasks/base.py +392 -0
- eval_framework/tasks/benchmarks/__init__.py +0 -0
- eval_framework/tasks/benchmarks/aidanbench.py +211 -0
- eval_framework/tasks/benchmarks/arc.py +70 -0
- eval_framework/tasks/benchmarks/arc_de.py +46 -0
- eval_framework/tasks/benchmarks/arc_fi.py +46 -0
- eval_framework/tasks/benchmarks/belebele.py +60 -0
- eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
- eval_framework/tasks/benchmarks/casehold.py +47 -0
- eval_framework/tasks/benchmarks/chembench.py +85 -0
- eval_framework/tasks/benchmarks/copa.py +64 -0
- eval_framework/tasks/benchmarks/duc.py +91 -0
- eval_framework/tasks/benchmarks/flores200.py +133 -0
- eval_framework/tasks/benchmarks/flores_plus.py +84 -0
- eval_framework/tasks/benchmarks/gpqa.py +201 -0
- eval_framework/tasks/benchmarks/gsm8k.py +150 -0
- eval_framework/tasks/benchmarks/hellaswag.py +69 -0
- eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
- eval_framework/tasks/benchmarks/humaneval.py +97 -0
- eval_framework/tasks/benchmarks/ifeval.py +78 -0
- eval_framework/tasks/benchmarks/include.py +119 -0
- eval_framework/tasks/benchmarks/infinitebench.py +302 -0
- eval_framework/tasks/benchmarks/math_reasoning.py +580 -0
- eval_framework/tasks/benchmarks/mbpp.py +192 -0
- eval_framework/tasks/benchmarks/mmlu.py +215 -0
- eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
- eval_framework/tasks/benchmarks/mmlu_pro.py +164 -0
- eval_framework/tasks/benchmarks/mmmlu.py +529 -0
- eval_framework/tasks/benchmarks/openbookqa.py +85 -0
- eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
- eval_framework/tasks/benchmarks/pawsx.py +65 -0
- eval_framework/tasks/benchmarks/piqa.py +64 -0
- eval_framework/tasks/benchmarks/quality.py +56 -0
- eval_framework/tasks/benchmarks/sciq.py +110 -0
- eval_framework/tasks/benchmarks/sphyr.py +79 -0
- eval_framework/tasks/benchmarks/squad.py +211 -0
- eval_framework/tasks/benchmarks/struct_eval.py +116 -0
- eval_framework/tasks/benchmarks/tablebench.py +117 -0
- eval_framework/tasks/benchmarks/triviaqa.py +42 -0
- eval_framework/tasks/benchmarks/truthfulqa.py +119 -0
- eval_framework/tasks/benchmarks/winogender.py +64 -0
- eval_framework/tasks/benchmarks/winogrande.py +69 -0
- eval_framework/tasks/benchmarks/winox.py +57 -0
- eval_framework/tasks/benchmarks/wmt.py +160 -0
- eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
- eval_framework/tasks/eval_config.py +136 -0
- eval_framework/tasks/perturbation.py +83 -0
- eval_framework/tasks/registry.py +186 -0
- eval_framework/tasks/task_loader.py +81 -0
- eval_framework/tasks/task_names.py +324 -0
- eval_framework/tasks/utils.py +584 -0
- eval_framework/utils/constants.py +9 -0
- eval_framework/utils/file_ops.py +245 -0
- eval_framework/utils/generate_task_docs.py +244 -0
- eval_framework/utils/helpers.py +32 -0
- eval_framework/utils/logging.py +62 -0
- eval_framework/utils/packaging.py +52 -0
- eval_framework/utils/tqdm_handler.py +14 -0
- eval_framework-0.2.7.dist-info/METADATA +548 -0
- eval_framework-0.2.7.dist-info/RECORD +170 -0
- eval_framework-0.2.7.dist-info/WHEEL +4 -0
- eval_framework-0.2.7.dist-info/entry_points.txt +3 -0
- template_formatting/README.md +83 -0
- template_formatting/__init__.py +0 -0
- template_formatting/formatter.py +537 -0
- template_formatting/mistral_formatter.py +159 -0
- template_formatting/py.typed +0 -0
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
# mypy: ignore-errors
|
|
2
|
+
|
|
3
|
+
import dataclasses
|
|
4
|
+
|
|
5
|
+
from eval_framework.external.ifeval_impl import instructions_registry
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclasses.dataclass
|
|
9
|
+
class InputExample:
|
|
10
|
+
key: int
|
|
11
|
+
instruction_id_list: list[str]
|
|
12
|
+
prompt: str
|
|
13
|
+
kwargs: list[dict[str, str | int | None]]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclasses.dataclass
class OutputExample:
    """Outcome of checking one response against the instructions of its prompt."""

    # Ids of the instructions that were checked.
    instruction_id_list: list[str]
    # Prompt the response belongs to.
    prompt: str
    # Response text that was checked.
    response: str
    # True only when every entry of follow_instruction_list is True.
    follow_all_instructions: bool
    # One verdict per instruction, in the order of instruction_id_list.
    follow_instruction_list: list[bool]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def test_instruction_following_strict(
    inp,
    response,
):
    """Check the unmodified response against every instruction of ``inp``.

    Args:
        inp: Example exposing ``instruction_id_list``, ``prompt`` and ``kwargs``
            (one kwargs dict per instruction, in the same order).
        response: Response text to verify.

    Returns:
        An ``OutputExample`` with one verdict per instruction and their conjunction.
    """
    verdicts = []

    for position, instruction_id in enumerate(inp.instruction_id_list):
        checker = instructions_registry.INSTRUCTION_DICT[instruction_id](instruction_id)

        # Only truthy kwargs are forwarded: this drops None (which build_description would
        # otherwise receive as an unexpected keyword argument) and also 0 and empty strings.
        build_kwargs = {name: value for name, value in inp.kwargs[position].items() if value}
        checker.build_description(**build_kwargs)

        # Instructions that declare a "prompt" argument are rebuilt with the prompt text.
        instruction_args = checker.get_instruction_args()
        if instruction_args and "prompt" in instruction_args:
            checker.build_description(prompt=inp.prompt)

        # A blank response never follows an instruction; the checker is not consulted for it.
        verdicts.append(bool(response.strip() and checker.check_following(response)))

    return OutputExample(
        instruction_id_list=inp.instruction_id_list,
        prompt=inp.prompt,
        response=response,
        follow_all_instructions=all(verdicts),
        follow_instruction_list=verdicts,
    )
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def test_instruction_following_loose(
    inp,
    response,
):
    """Check several relaxed variants of the response against every instruction of ``inp``.

    An instruction counts as followed when at least one variant passes. The variants are
    the response itself, the response with its first line, last line, or both removed,
    and each of those four with all ``*`` characters stripped.

    Args:
        inp: Example exposing ``instruction_id_list``, ``prompt`` and ``kwargs``
            (one kwargs dict per instruction, in the same order).
        response: Response text to verify.

    Returns:
        An ``OutputExample`` with one verdict per instruction and their conjunction.
    """
    lines = response.split("\n")
    trimmed_variants = [
        "\n".join(lines[1:]).strip(),  # first line removed
        "\n".join(lines[:-1]).strip(),  # last line removed
        "\n".join(lines[1:-1]).strip(),  # first and last line removed
    ]
    candidates = [
        response,
        response.replace("*", ""),
        *trimmed_variants,
        *(variant.replace("*", "") for variant in trimmed_variants),
    ]

    verdicts = []

    for position, instruction_id in enumerate(inp.instruction_id_list):
        checker = instructions_registry.INSTRUCTION_DICT[instruction_id](instruction_id)

        # Only truthy kwargs are forwarded: this drops None (which build_description would
        # otherwise receive as an unexpected keyword argument) and also 0 and empty strings.
        build_kwargs = {name: value for name, value in inp.kwargs[position].items() if value}
        checker.build_description(**build_kwargs)

        # Instructions that declare a "prompt" argument are rebuilt with the prompt text.
        instruction_args = checker.get_instruction_args()
        if instruction_args and "prompt" in instruction_args:
            checker.build_description(prompt=inp.prompt)

        # Stop at the first non-blank variant that passes.
        verdicts.append(
            any(candidate.strip() and checker.check_following(candidate) for candidate in candidates)
        )

    return OutputExample(
        instruction_id_list=inp.instruction_id_list,
        prompt=inp.prompt,
        response=response,
        follow_all_instructions=all(verdicts),
        follow_instruction_list=verdicts,
    )
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def process_results(input, results):
    """Score the first result against the instructions of ``input`` in strict and loose mode.

    Args:
        input: Object exposing ``key``, ``instruction_id_list``, ``prompt`` and
            ``additional_kwargs``.
        results: Sequence whose first element is the response text.

    Returns:
        Dict with prompt-level (single bool) and instruction-level (list of bools)
        verdicts for both modes.
    """
    response = results[0]
    example = InputExample(
        key=input.key,
        instruction_id_list=input.instruction_id_list,
        prompt=input.prompt,
        kwargs=input.additional_kwargs,
    )

    strict = test_instruction_following_strict(example, response)
    loose = test_instruction_following_loose(example, response)

    scores = {}
    scores["prompt_level_strict_acc"] = strict.follow_all_instructions
    scores["inst_level_strict_acc"] = strict.follow_instruction_list
    scores["prompt_level_loose_acc"] = loose.follow_all_instructions
    scores["inst_level_loose_acc"] = loose.follow_instruction_list
    return scores
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def agg_inst_level_acc(items):
    """Return the mean of all instruction-level verdicts across examples.

    Args:
        items: Iterable of per-example verdict lists (booleans).

    Returns:
        Fraction of true verdicts after flattening, or 0.0 when there are no
        verdicts at all (previously this raised ZeroDivisionError).
    """
    flat_items = [item for sublist in items for item in sublist]
    if not flat_items:
        # Nothing to average: avoid dividing by zero.
        return 0.0
    return sum(flat_items) / len(flat_items)
|
|
File without changes
|
|
@@ -0,0 +1,432 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import math
|
|
5
|
+
import os
|
|
6
|
+
import random
|
|
7
|
+
import re
|
|
8
|
+
import time
|
|
9
|
+
import traceback
|
|
10
|
+
from collections.abc import Callable, Sequence
|
|
11
|
+
|
|
12
|
+
import aiohttp
|
|
13
|
+
from aleph_alpha_client import (
|
|
14
|
+
AsyncClient,
|
|
15
|
+
BusyError,
|
|
16
|
+
Client,
|
|
17
|
+
CompletionRequest,
|
|
18
|
+
CompletionResponse,
|
|
19
|
+
Prompt,
|
|
20
|
+
)
|
|
21
|
+
from aleph_alpha_client.prompt import Text
|
|
22
|
+
from dotenv import load_dotenv
|
|
23
|
+
|
|
24
|
+
from eval_framework.llm.base import BaseLLM
|
|
25
|
+
from eval_framework.shared.types import Error, PromptTooLongException, RawCompletion, RawLoglikelihood
|
|
26
|
+
from eval_framework.tasks.base import Sample
|
|
27
|
+
from eval_framework.tasks.utils import raise_errors
|
|
28
|
+
from template_formatting.formatter import BaseFormatter, Llama3Formatter, Message
|
|
29
|
+
|
|
30
|
+
load_dotenv()
|
|
31
|
+
|
|
32
|
+
logger = logging.getLogger(__name__)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def safe_json_loads(s: str) -> dict[str, str]:
    """Parse ``s`` as a JSON object, returning an empty dict on any failure.

    Args:
        s: Text expected to hold a JSON object. Non-string values are tolerated.

    Returns:
        The decoded object when ``s`` is valid JSON whose top-level value is a dict;
        otherwise ``{}``. Callers chain ``.get(...)`` on the result, so a non-dict
        payload (list, string, number, null) must not be passed through.
    """
    try:
        parsed = json.loads(s)
    except (ValueError, TypeError):
        # ValueError covers json.JSONDecodeError; TypeError covers non-str/bytes input.
        return {}
    return parsed if isinstance(parsed, dict) else {}
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class AlephAlphaAPIModel(BaseLLM):
|
|
43
|
+
LLM_NAME: str
|
|
44
|
+
DEFAULT_FORMATTER: Callable[[], BaseFormatter] | None = None
|
|
45
|
+
BYTES_PER_TOKEN: float = 4.0 # rule of thumb according to https://platform.openai.com/tokenizer
|
|
46
|
+
|
|
47
|
+
def __init__(
    self,
    formatter: BaseFormatter | None = None,
    checkpoint_name: str | None = None,
    temperature: float | None = None,
    # Please see README.md for tips if adapting the following parameters.
    max_retries: int = 100,
    max_async_concurrent_requests: int = 32,
    request_timeout_seconds: int = 30 * 60 + 5,
    queue_full_timeout_seconds: int = 30 * 60 + 5,
    bytes_per_token: float | None = None,
    token: str = os.getenv("AA_TOKEN", "dummy"),
    base_url: str = os.getenv("AA_INFERENCE_ENDPOINT", "dummy_endpoint"),
) -> None:
    """Configure the model wrapper and verify that the model can be reached.

    Args:
        formatter: Prompt formatter; falls back to ``DEFAULT_FORMATTER()`` when omitted.
        checkpoint_name: Model name to query; falls back to ``LLM_NAME``.
        temperature: Default sampling temperature; 0.0 when omitted.
        max_retries: Retry budget used by ``_request_with_backoff``.
        max_async_concurrent_requests: Size of the semaphore limiting in-flight requests.
        request_timeout_seconds: Timeout passed to the async client.
        queue_full_timeout_seconds: How long QUEUE_FULL responses are retried.
        bytes_per_token: Bytes per token of the model; must be positive when given.
        token: API token. The default is read from ``AA_TOKEN`` once, when this
            function is defined.
        base_url: API host. The default is read from ``AA_INFERENCE_ENDPOINT`` once,
            when this function is defined.

    Raises:
        ValueError: If neither a formatter nor a default formatter exists, or if
            ``bytes_per_token`` is not positive.
        RuntimeError: If the availability check request fails.
    """
    self._formatter: BaseFormatter
    if formatter is not None:
        self._formatter = formatter
    elif self.DEFAULT_FORMATTER is not None:
        self._formatter = self.DEFAULT_FORMATTER()
    else:
        raise ValueError("Either formatter or default formatter must be specified")

    self._llm_name = checkpoint_name or self.LLM_NAME
    self._temperature = 0.0 if temperature is None else temperature
    self.max_async_concurrent_requests = max_async_concurrent_requests
    self.max_retries = max_retries
    self.request_timeout_seconds = request_timeout_seconds
    self.queue_full_timeout_seconds = queue_full_timeout_seconds
    self.token = token
    self.base_url = base_url
    self._validate_model_availability(base_url, token)

    # set bytes_per_token_scalar for non-standard models
    if bytes_per_token is not None and bytes_per_token <= 0:
        raise ValueError("bytes_per_token must be positive")
    effective_bytes_per_token = self.BYTES_PER_TOKEN if bytes_per_token is None else bytes_per_token
    self.bytes_per_token_scalar = 4.0 / effective_bytes_per_token
|
|
83
|
+
|
|
84
|
+
def _validate_model_availability(self, base_url: str, token: str) -> None:
    """
    Validate that the model name is available by making a test request.

    Sends a one-token completion request with an empty prompt for ``self._llm_name``.

    Args:
        base_url: Host of the inference API.
        token: API token used for the request.

    Raises:
        RuntimeError: If the request fails for any reason; the original exception
            is attached as ``__cause__``.
    """
    try:
        # 'Client' object does not support the context manager protocol
        client = Client(
            host=base_url,
            token=token,
        )

        request = CompletionRequest(
            prompt=Prompt.from_text(""),
            maximum_tokens=1,
        )
        client.complete(request, model=self._llm_name)
        logger.info(f"Model '{self._llm_name}' available and loaded.")
    except Exception as e:
        # Chain explicitly so the underlying failure is reported as the direct cause.
        raise RuntimeError(f"Model '{self._llm_name}' is not available: {e}") from e
|
|
103
|
+
|
|
104
|
+
async def _request_with_backoff(
    self, client: AsyncClient, request: CompletionRequest, id: int
) -> CompletionResponse:
    """
    Query Aleph-Alpha API with complete. Retry with back-off until it responds.

    Two retry regimes exist:
      * ``QUEUE_FULL`` answers are retried without consuming the retry budget, for at
        most ``queue_full_timeout_seconds`` measured from the first QUEUE_FULL seen
        since the last timeout-type failure.
      * Timeout-type failures (``TIMEOUT_TASK``, ``TimeoutError``, 502/504 gateway
        messages, any ``aiohttp.ClientError``) each consume one attempt, up to
        ``max_retries``.

    Anything else, or an exhausted budget, re-raises the caught exception.
    """
    attempts_used = 0
    queue_full_since: float | None = None

    while True:
        try:
            return await client.complete(request, model=self._llm_name)

        except (TimeoutError, BusyError, RuntimeError, aiohttp.ClientError) as e:
            # The second exception argument, when present, is expected to be a JSON body with a "code".
            status_code: str = safe_json_loads(e.args[1]).get("code", "") if len(e.args) >= 2 else ""
            error_text = str(e)

            if status_code == "QUEUE_FULL":
                # Worker not available or missed a heartbeat (inference longer than scheduler's
                # API_MODEL_AVAILABLE_TIMEOUT_DURATION_MILLIS) or the scheduler is overloaded.
                if queue_full_since is None:
                    queue_full_since = time.time()
                waited = time.time() - queue_full_since
                if waited <= self.queue_full_timeout_seconds:
                    logger.info(
                        f"Request {id}: {status_code or error_text[:256]} - retrying: attempt"
                        f" {attempts_used}/{self.max_retries}, elapsed {waited:.1f} sec"
                    )
                    # don't count as retry (request returns immediately, so just wait a bit not to DoS the server)
                    await asyncio.sleep(random.randint(5, 30))
                    continue
            else:
                looks_like_timeout = (
                    status_code == "TIMEOUT_TASK"
                    or isinstance(e, (TimeoutError, aiohttp.ClientError))
                    or "502 Bad Gateway" in error_text
                    or "504 Gateway Time-out" in error_text
                )
                if looks_like_timeout:
                    # client timeout, either because task too long in a queue or inference too long
                    # (scheduler's API_CLIENT_TIMEOUT_DURATION_MILLIS). Retrying for the "inference too long"
                    # case makes no sense but we unfortunately don't know which case has happened.
                    attempts_used += 1
                    queue_full_since = None
                    if attempts_used < self.max_retries:
                        logger.info(f"Request {id}: TIMEOUT_TASK - retrying: attempt {attempts_used}/{self.max_retries}")
                        await asyncio.sleep(random.randint(5, 30))
                        continue

            raise e
|
|
153
|
+
|
|
154
|
+
def _error_from_exception(self, e: Exception) -> Error:
    """Build an ``Error`` record describing ``e``.

    When the exception carries at least two arguments, the second one is parsed as a
    JSON body and its ``"code"`` is used: ``PROMPT_TOO_LONG`` maps to a
    ``PromptTooLongException`` record, any other non-empty code becomes the error class.
    Otherwise the exception's class name is used. The traceback is whatever
    ``traceback.format_exc()`` reports at call time.
    """
    formatted_traceback = traceback.format_exc()
    exception_name = e.__class__.__name__

    if len(e.args) < 2:
        return Error(error_class=exception_name, message=str(e), traceback=formatted_traceback)

    status_code: str = safe_json_loads(e.args[1]).get("code", "")
    if status_code == "PROMPT_TOO_LONG":
        return Error(
            error_class=PromptTooLongException.__name__,
            message="Prompt exceeded context size.",
            traceback=formatted_traceback,
        )

    return Error(error_class=status_code or exception_name, message=str(e), traceback=formatted_traceback)
|
|
170
|
+
|
|
171
|
+
async def _process_request_with_client(
    self,
    client: AsyncClient,
    semaphore: asyncio.Semaphore,
    request: CompletionRequest,
    id: int,
) -> tuple[CompletionRequest, CompletionResponse | Error]:
    """Run one request under the semaphore and pair it with its outcome.

    Returns ``(request, response)`` on success. On failure the exception is re-raised
    when ``raise_errors()`` is true; otherwise ``(request, Error)`` is returned.
    """
    async with semaphore:
        try:
            outcome = await self._request_with_backoff(client=client, request=request, id=id)
            logger.info(f"Request {id}: Success")
            return (request, outcome)
        except Exception as exc:
            if raise_errors():
                raise exc
            failure_summary = str(exc)[:256]
            logger.info(f"Request {id}: Failure: {failure_summary}")
            return (request, self._error_from_exception(exc))
|
|
189
|
+
|
|
190
|
+
async def _process_requests(
    self,
    requests: list[CompletionRequest],
) -> list[tuple[CompletionRequest, CompletionResponse | Error]]:
    """Send all requests concurrently and return (request, outcome) pairs in input order.

    Concurrency is capped by a semaphore of size ``max_async_concurrent_requests``.
    """
    limiter = asyncio.Semaphore(self.max_async_concurrent_requests)
    client_options = {
        "host": self.base_url,
        "nice": True,
        "request_timeout_seconds": self.request_timeout_seconds,
        "token": self.token,
        "total_retries": 0,  # we have a custom retry policy in _request_with_backoff()
    }
    async with AsyncClient(**client_options) as client:
        pending = [
            self._process_request_with_client(client, limiter, request, index)
            for index, request in enumerate(requests)
        ]
        outcomes = await asyncio.gather(*pending)  # guarantees order of responses
    return list(outcomes)
|
|
214
|
+
|
|
215
|
+
def _response_to_raw_completion(
    self, request: CompletionRequest, response: CompletionResponse | Error
) -> RawCompletion:
    """Turn one (request, outcome) pair into a ``RawCompletion``.

    An ``Error`` outcome yields an empty completion carrying the error. Otherwise the
    single completion text is returned together with prompt/completion token counts.
    """
    first_item = request.prompt.items[0]
    assert isinstance(first_item, Text)
    prompt = first_item.text

    if isinstance(response, Error):
        return RawCompletion(
            prompt=prompt,
            prompt_sequence_positions=None,
            completion="",
            completion_sequence_positions=0,
            raw_completion_error=response,
        )

    assert len(response.completions) == 1
    completion = response.completions[0].completion or ""
    prompt_sequence_positions: int | None = None
    completion_sequence_positions: int | None = None

    # Support workaround in api-worker-transformer's scaling generator to return the correct number of tokens.
    # These are part of the completion string; those in CompletionResponse are invalid in this case.
    # Expected layout: <marker><prompt tokens>,<completion tokens><marker><actual completion>.
    embedded_counts = re.match(r"\uf8c9(\d+),(\d+)\uf8c9(.*)", completion, re.DOTALL)
    if embedded_counts is None:
        prompt_sequence_positions = response.num_tokens_prompt_total if response else None
        completion_sequence_positions = response.num_tokens_generated if response else None
    else:
        prompt_count_text, completion_count_text, completion = embedded_counts.groups()
        prompt_sequence_positions = int(prompt_count_text)
        completion_sequence_positions = int(completion_count_text)

    return RawCompletion(
        prompt=prompt,
        prompt_sequence_positions=prompt_sequence_positions,
        completion=completion,
        completion_sequence_positions=completion_sequence_positions,
    )
|
|
253
|
+
|
|
254
|
+
def generate_from_messages(
    self,
    messages: list[Sequence[Message]],
    stop_sequences: list[str] | None = None,
    max_tokens: int | None = None,
    temperature: float | None = None,
) -> list[RawCompletion]:
    """Generate one completion per conversation in ``messages``.

    Args:
        messages: Conversations; each is formatted to a string prompt by the formatter.
        stop_sequences: Stop sequences forwarded to every request.
        max_tokens: Token budget before scaling; ``None`` is forwarded as ``None``.
        temperature: Overrides the instance temperature when given.

    Returns:
        One ``RawCompletion`` per conversation, in input order.
    """
    effective_temperature = self._temperature if temperature is None else temperature

    # Adjust max tokens based on bytes_per_token_scalar so that non-standard models generate full responses
    scaled_max_tokens = None if max_tokens is None else math.ceil(max_tokens * self.bytes_per_token_scalar)

    requests: list[CompletionRequest] = [
        CompletionRequest(
            prompt=Prompt.from_text(self._formatter.format(conversation, output_mode="string")),
            maximum_tokens=scaled_max_tokens,
            stop_sequences=stop_sequences,
            temperature=effective_temperature,
        )
        for conversation in messages
    ]

    outcomes = asyncio.run(self._process_requests(requests))
    return [self._response_to_raw_completion(sent, received) for sent, received in outcomes]
|
|
280
|
+
|
|
281
|
+
def logprobs(self, samples: list[Sample]) -> list[RawLoglikelihood]:
    """Compute the log-likelihood of every possible completion of every sample.

    For each sample the formatted prompt is concatenated with each choice and sent as an
    echo request with zero generated tokens; the choice's log-probability is then read
    from the echoed tokens. If any choice of a sample fails, the sample's result carries
    the first error and empty log-likelihood maps.

    Returns:
        One ``RawLoglikelihood`` per sample, in input order.
    """
    prompts: list[str] = [
        self._formatter.format(sample.messages, output_mode="string") if sample.messages else ""
        for sample in samples
    ]
    completion_requests: list[CompletionRequest] = [
        CompletionRequest(
            prompt=Prompt.from_text(prompt + choice),
            maximum_tokens=0,
            temperature=0.0,
            log_probs=0,
            echo=True,
            tokens=True,
        )
        for sample, prompt in zip(samples, prompts, strict=True)
        for choice in sample.possible_completions or []
    ]

    completion_responses: list[tuple[CompletionRequest, CompletionResponse | Error]] = []
    if completion_requests:
        completion_responses = asyncio.run(self._process_requests(completion_requests))
    # Responses come back in request order, i.e. sample by sample, choice by choice.
    pending_responses = iter(completion_responses)

    results: list[RawLoglikelihood] = []
    for sample_idx, (sample, prompt) in enumerate(zip(samples, prompts, strict=True)):
        choices_log_probs: dict[str, float] = {}
        choices_sequence_positions: dict[str, int] = {}
        prompt_sequence_positions: int | None = 0
        number_of_initial_choices_tokens: int | None = None
        error: Error | None = None

        for choice in sample.possible_completions or []:
            # Always consume the response so the iterator stays aligned with the choices.
            request, response = next(pending_responses)
            assert isinstance(request, CompletionRequest)
            if error is not None:
                continue

            if isinstance(response, Error):
                error = response
                continue

            try:
                logprob, choice_token_count = self._extract_choice_logprob_from_completion(
                    prompt=prompt,
                    choice=choice,
                    response=response,
                )
                choices_log_probs[choice] = logprob
                choices_sequence_positions[choice] = choice_token_count
                if number_of_initial_choices_tokens is None:
                    number_of_initial_choices_tokens = choice_token_count

                self._check_choices_token_count(sample_idx, choice_token_count, number_of_initial_choices_tokens)

            except Exception as exc:
                if raise_errors():
                    raise
                error = Error(
                    error_class=exc.__class__.__name__,
                    message=str(exc),
                    traceback=traceback.format_exc(),
                )

        if error is not None:
            # A failed sample reports no partial results.
            prompt_sequence_positions = None
            choices_log_probs = {}
            choices_sequence_positions = {}

        results.append(
            RawLoglikelihood(
                prompt=prompt,
                prompt_sequence_positions=prompt_sequence_positions,
                loglikelihoods=choices_log_probs,
                loglikelihoods_sequence_positions=choices_sequence_positions,
                raw_loglikelihood_error=error,
            )
        )

    return results
|
|
363
|
+
|
|
364
|
+
@staticmethod
|
|
365
|
+
def _check_choices_token_count(
|
|
366
|
+
sample_idx: int, choice_token_count: int, number_of_initial_choices_tokens: int | None
|
|
367
|
+
) -> None:
|
|
368
|
+
if number_of_initial_choices_tokens is not None:
|
|
369
|
+
if choice_token_count != number_of_initial_choices_tokens:
|
|
370
|
+
logger.warning(
|
|
371
|
+
"Choice token count differed between choices for sample %s (%s vs %s). Using latest value.",
|
|
372
|
+
sample_idx,
|
|
373
|
+
choice_token_count,
|
|
374
|
+
number_of_initial_choices_tokens,
|
|
375
|
+
)
|
|
376
|
+
|
|
377
|
+
@staticmethod
def _extract_choice_logprob_from_completion(
    prompt: str, choice: str, response: CompletionResponse
) -> tuple[float, int]:
    """Sum the log-probabilities of the tokens that make up ``choice``.

    The response is expected to echo ``prompt + choice`` as tokens with one
    single-entry log-prob dict per token. The response is only read, never modified.

    Args:
        prompt: Prompt text that precedes the choice.
        choice: Completion text whose log-likelihood is wanted.
        response: Completion response for the request built from ``prompt + choice``.

    Returns:
        ``(total_logprob, choice_token_count)``.

    Raises:
        ValueError: If the response lacks completions, log-probs or tokens, if tokens
            and log-probs differ in length, if the echoed text differs from
            ``prompt + choice``, if the prompt cannot be aligned to a token boundary,
            or if a log-prob entry does not hold exactly one item.
    """
    if not response.completions:
        raise ValueError("Completion response did not contain any choices.")
    completion_result = response.completions[0]
    if completion_result.log_probs is None:
        raise ValueError("Completion result did not include log_probs.")
    if completion_result.completion_tokens is None:
        raise ValueError("Completion result did not include completion_tokens.")

    tokens = list(completion_result.completion_tokens)
    log_prob_entries = list(completion_result.log_probs)

    if len(tokens) != len(log_prob_entries):
        raise ValueError("Mismatch between completion tokens and log_prob entries.")

    combined_text = "".join(tokens)
    expected_text = prompt + choice
    if combined_text != expected_text:
        raise ValueError("Completion tokens differed from prompt + choice text.")

    prompt_token_count = AlephAlphaAPIModel._count_prompt_tokens_from_sequence(tokens, prompt)
    choice_token_count = len(tokens) - prompt_token_count
    if choice_token_count < 0:
        raise ValueError("Choice token count computed as negative.")

    total_logprob = 0.0
    for entry in log_prob_entries[prompt_token_count:]:
        assert isinstance(entry, dict)
        if len(entry) != 1:
            raise ValueError("Log_probs entry did not contain exactly one key-value pair.")
        # Read the single value without popitem(): the entry dicts belong to the caller's
        # response object and must stay intact so the response can be inspected again.
        (value,) = entry.values()
        assert isinstance(value, float)
        total_logprob += value

    return total_logprob, choice_token_count
|
|
415
|
+
|
|
416
|
+
@staticmethod
|
|
417
|
+
def _count_prompt_tokens_from_sequence(tokens: Sequence[str], prompt: str) -> int:
|
|
418
|
+
if not prompt:
|
|
419
|
+
return 0
|
|
420
|
+
current_text = ""
|
|
421
|
+
for idx, token in enumerate(tokens):
|
|
422
|
+
current_text += token
|
|
423
|
+
if current_text == prompt:
|
|
424
|
+
return idx + 1
|
|
425
|
+
if len(current_text) > len(prompt):
|
|
426
|
+
break
|
|
427
|
+
raise ValueError("Unable to align completion tokens with prompt text.")
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
class Llama31_8B_Instruct_API(AlephAlphaAPIModel):
    """Aleph Alpha API model preset for ``llama-3.1-8b-instruct`` using ``Llama3Formatter``."""

    # Model name used when no checkpoint_name is passed to the constructor.
    LLM_NAME = "llama-3.1-8b-instruct"
    # Formatter class instantiated when no formatter is passed to the constructor.
    DEFAULT_FORMATTER = Llama3Formatter
|