eval-framework 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eval_framework/__init__.py +7 -0
- eval_framework/base_config.py +36 -0
- eval_framework/context/__init__.py +0 -0
- eval_framework/context/determined.py +170 -0
- eval_framework/context/eval.py +114 -0
- eval_framework/context/local.py +52 -0
- eval_framework/evaluation_generator.py +231 -0
- eval_framework/exceptions.py +2 -0
- eval_framework/external/ifeval_impl/README.md +5 -0
- eval_framework/external/ifeval_impl/instructions.py +1523 -0
- eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
- eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
- eval_framework/external/ifeval_impl/utils.py +135 -0
- eval_framework/llm/__init__.py +0 -0
- eval_framework/llm/aleph_alpha.py +323 -0
- eval_framework/llm/base.py +58 -0
- eval_framework/llm/huggingface.py +332 -0
- eval_framework/llm/mistral.py +73 -0
- eval_framework/llm/models.py +16 -0
- eval_framework/llm/openai.py +205 -0
- eval_framework/llm/vllm.py +438 -0
- eval_framework/logger.py +3 -0
- eval_framework/main.py +187 -0
- eval_framework/metrics/__init__.py +0 -0
- eval_framework/metrics/base.py +40 -0
- eval_framework/metrics/completion/__init__.py +1 -0
- eval_framework/metrics/completion/accuracy_completion.py +16 -0
- eval_framework/metrics/completion/bleu.py +76 -0
- eval_framework/metrics/completion/chrf.py +62 -0
- eval_framework/metrics/completion/code_assertion.py +44 -0
- eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
- eval_framework/metrics/completion/comet.py +56 -0
- eval_framework/metrics/completion/concordance_index.py +38 -0
- eval_framework/metrics/completion/csv_format.py +102 -0
- eval_framework/metrics/completion/cwe_accuracy.py +49 -0
- eval_framework/metrics/completion/exponential_similarity.py +65 -0
- eval_framework/metrics/completion/f1.py +42 -0
- eval_framework/metrics/completion/format_checker.py +56 -0
- eval_framework/metrics/completion/grid_difference.py +77 -0
- eval_framework/metrics/completion/ifeval.py +73 -0
- eval_framework/metrics/completion/json_format.py +171 -0
- eval_framework/metrics/completion/language_checker.py +74 -0
- eval_framework/metrics/completion/length_control.py +83 -0
- eval_framework/metrics/completion/math_reasoning_completion.py +303 -0
- eval_framework/metrics/completion/niah_accuracy.py +163 -0
- eval_framework/metrics/completion/placeholder_checker.py +27 -0
- eval_framework/metrics/completion/repetition.py +88 -0
- eval_framework/metrics/completion/rouge_1.py +35 -0
- eval_framework/metrics/completion/rouge_2.py +45 -0
- eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
- eval_framework/metrics/completion/rouge_l.py +52 -0
- eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
- eval_framework/metrics/completion/ter.py +67 -0
- eval_framework/metrics/completion/text_counter.py +182 -0
- eval_framework/metrics/efficiency/__init__.py +0 -0
- eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
- eval_framework/metrics/llm/__init__.py +0 -0
- eval_framework/metrics/llm/base.py +8 -0
- eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
- eval_framework/metrics/llm/graders/comparison_grader.py +146 -0
- eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
- eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
- eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
- eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
- eval_framework/metrics/llm/graders/language.py +56 -0
- eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
- eval_framework/metrics/llm/graders/models.py +74 -0
- eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
- eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
- eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
- eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
- eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
- eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
- eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
- eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
- eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
- eval_framework/metrics/llm/llm_judge_mtbench_pair.py +205 -0
- eval_framework/metrics/llm/llm_judge_mtbench_single.py +188 -0
- eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
- eval_framework/metrics/llm/llm_judge_sql.py +394 -0
- eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
- eval_framework/metrics/loglikelihood/__init__.py +0 -0
- eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
- eval_framework/metrics/loglikelihood/probability_mass.py +56 -0
- eval_framework/py.typed +0 -0
- eval_framework/response_generator.py +416 -0
- eval_framework/result_processors/__init__.py +0 -0
- eval_framework/result_processors/base.py +74 -0
- eval_framework/result_processors/hf_processor.py +87 -0
- eval_framework/result_processors/result_processor.py +129 -0
- eval_framework/run.py +314 -0
- eval_framework/run_direct.py +42 -0
- eval_framework/shared/types.py +227 -0
- eval_framework/tasks/__init__.py +6 -0
- eval_framework/tasks/base.py +314 -0
- eval_framework/tasks/benchmarks/__init__.py +0 -0
- eval_framework/tasks/benchmarks/arc.py +46 -0
- eval_framework/tasks/benchmarks/arc_de.py +46 -0
- eval_framework/tasks/benchmarks/arc_fi.py +46 -0
- eval_framework/tasks/benchmarks/belebele.py +60 -0
- eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
- eval_framework/tasks/benchmarks/casehold.py +47 -0
- eval_framework/tasks/benchmarks/chembench.py +85 -0
- eval_framework/tasks/benchmarks/copa.py +39 -0
- eval_framework/tasks/benchmarks/duc.py +91 -0
- eval_framework/tasks/benchmarks/flores200.py +62 -0
- eval_framework/tasks/benchmarks/flores_plus.py +84 -0
- eval_framework/tasks/benchmarks/gpqa.py +177 -0
- eval_framework/tasks/benchmarks/gsm8k.py +148 -0
- eval_framework/tasks/benchmarks/hellaswag.py +44 -0
- eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
- eval_framework/tasks/benchmarks/humaneval.py +97 -0
- eval_framework/tasks/benchmarks/ifeval.py +78 -0
- eval_framework/tasks/benchmarks/include.py +119 -0
- eval_framework/tasks/benchmarks/infinitebench.py +302 -0
- eval_framework/tasks/benchmarks/math_reasoning.py +569 -0
- eval_framework/tasks/benchmarks/mbpp.py +192 -0
- eval_framework/tasks/benchmarks/mmlu.py +190 -0
- eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
- eval_framework/tasks/benchmarks/mmlu_pro.py +139 -0
- eval_framework/tasks/benchmarks/mmmlu.py +529 -0
- eval_framework/tasks/benchmarks/openbookqa.py +37 -0
- eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
- eval_framework/tasks/benchmarks/pawsx.py +65 -0
- eval_framework/tasks/benchmarks/piqa.py +39 -0
- eval_framework/tasks/benchmarks/quality.py +56 -0
- eval_framework/tasks/benchmarks/sciq.py +44 -0
- eval_framework/tasks/benchmarks/sphyr.py +75 -0
- eval_framework/tasks/benchmarks/squad.py +89 -0
- eval_framework/tasks/benchmarks/struct_eval.py +110 -0
- eval_framework/tasks/benchmarks/tablebench.py +117 -0
- eval_framework/tasks/benchmarks/triviaqa.py +42 -0
- eval_framework/tasks/benchmarks/truthfulqa.py +95 -0
- eval_framework/tasks/benchmarks/winogender.py +39 -0
- eval_framework/tasks/benchmarks/winogrande.py +44 -0
- eval_framework/tasks/benchmarks/winox.py +57 -0
- eval_framework/tasks/benchmarks/wmt.py +160 -0
- eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
- eval_framework/tasks/eval_config.py +112 -0
- eval_framework/tasks/perturbation.py +83 -0
- eval_framework/tasks/registry.py +186 -0
- eval_framework/tasks/task_loader.py +80 -0
- eval_framework/tasks/task_names.py +138 -0
- eval_framework/tasks/utils.py +578 -0
- eval_framework/utils/constants.py +9 -0
- eval_framework/utils/generate_task_docs.py +229 -0
- eval_framework/utils/helpers.py +3 -0
- eval_framework/utils/logging.py +50 -0
- eval_framework/utils/packaging.py +52 -0
- eval_framework-0.2.0.dist-info/METADATA +514 -0
- eval_framework-0.2.0.dist-info/RECORD +161 -0
- eval_framework-0.2.0.dist-info/WHEEL +4 -0
- eval_framework-0.2.0.dist-info/entry_points.txt +3 -0
- template_formatting/README.md +83 -0
- template_formatting/__init__.py +0 -0
- template_formatting/formatter.py +536 -0
- template_formatting/mistral_formatter.py +159 -0
- template_formatting/py.typed +0 -0
- template_formatting/tests/test_formatter_eval.py +408 -0
- template_formatting/tests/test_formatter_scaling.py +253 -0
- template_formatting/tests/test_mistral_formatter.py +136 -0
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
# mypy: ignore-errors
|
|
2
|
+
|
|
3
|
+
import dataclasses
|
|
4
|
+
|
|
5
|
+
from eval_framework.external.ifeval_impl import instructions_registry
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclasses.dataclass
|
|
9
|
+
class InputExample:
|
|
10
|
+
key: int
|
|
11
|
+
instruction_id_list: list[str]
|
|
12
|
+
prompt: str
|
|
13
|
+
kwargs: list[dict[str, str | int | None]]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclasses.dataclass()
class OutputExample:
    """Outcome of checking one response against the instructions of its prompt."""

    # Registry ids of the checked instructions.
    instruction_id_list: list[str]
    # Prompt the response was produced for.
    prompt: str
    # Response text that was checked.
    response: str
    # True only if every instruction was followed.
    follow_all_instructions: bool
    # One verdict per instruction, aligned with ``instruction_id_list``.
    follow_instruction_list: list[bool]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def test_instruction_following_strict(
    inp,
    response,
):
    """Check whether ``response`` satisfies each instruction attached to ``inp``, verbatim.

    Every instruction id is looked up in ``instructions_registry.INSTRUCTION_DICT``,
    instantiated, configured from the matching entry of ``inp.kwargs`` and asked whether
    the unmodified response follows it. A blank (whitespace-only) response fails every
    instruction without the checker being called.

    Args:
        inp: ``InputExample`` with the prompt, instruction ids and per-instruction kwargs.
        response: Model output to check.

    Returns:
        ``OutputExample`` with one boolean per instruction and their conjunction.
    """
    verdicts = []

    for position, instr_id in enumerate(inp.instruction_id_list):
        checker = instructions_registry.INSTRUCTION_DICT[instr_id](instr_id)

        # Drop falsy kwargs before building the description. This removes the ``None``
        # placeholders (which would otherwise be unexpected keyword arguments) and also
        # 0 / "" / empty containers.
        settings = {name: value for name, value in inp.kwargs[position].items() if value}
        checker.build_description(**settings)

        # Instructions whose argument list includes ``prompt`` are rebuilt with the real prompt.
        accepted_args = checker.get_instruction_args()
        if accepted_args and "prompt" in accepted_args:
            checker.build_description(prompt=inp.prompt)

        has_text = bool(response.strip())
        verdicts.append(has_text and bool(checker.check_following(response)))

    return OutputExample(
        instruction_id_list=inp.instruction_id_list,
        prompt=inp.prompt,
        response=response,
        follow_all_instructions=all(verdicts),
        follow_instruction_list=verdicts,
    )
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def test_instruction_following_loose(
    inp,
    response,
):
    """Lenient variant of the strict check, giving an upper bound on instruction following.

    An instruction counts as followed if at least one of eight relaxed versions of the
    response passes its checker:
    - the response as-is, and with every ``*`` removed;
    - the response with its first line, last line, or both dropped (then stripped);
    - each of those three line-dropped versions with ``*`` removed as well.
    Blank candidates never pass.

    Args:
        inp: ``InputExample`` with the prompt, instruction ids and per-instruction kwargs.
        response: Model output to check.

    Returns:
        ``OutputExample`` with one boolean per instruction and their conjunction.
    """
    lines = response.split("\n")
    without_first = "\n".join(lines[1:]).strip()
    without_last = "\n".join(lines[:-1]).strip()
    without_both = "\n".join(lines[1:-1]).strip()
    # Candidates are tried in this order; checking stops at the first one that passes.
    candidates = [
        response,
        response.replace("*", ""),
        without_first,
        without_last,
        without_both,
        without_first.replace("*", ""),
        without_last.replace("*", ""),
        without_both.replace("*", ""),
    ]

    verdicts = []
    for position, instr_id in enumerate(inp.instruction_id_list):
        checker = instructions_registry.INSTRUCTION_DICT[instr_id](instr_id)

        # Drop falsy kwargs before building the description. This removes the ``None``
        # placeholders (which would otherwise be unexpected keyword arguments) and also
        # 0 / "" / empty containers.
        settings = {name: value for name, value in inp.kwargs[position].items() if value}
        checker.build_description(**settings)

        # Instructions whose argument list includes ``prompt`` are rebuilt with the real prompt.
        accepted_args = checker.get_instruction_args()
        if accepted_args and "prompt" in accepted_args:
            checker.build_description(prompt=inp.prompt)

        verdicts.append(
            any(candidate.strip() and checker.check_following(candidate) for candidate in candidates)
        )

    return OutputExample(
        instruction_id_list=inp.instruction_id_list,
        prompt=inp.prompt,
        response=response,
        follow_all_instructions=all(verdicts),
        follow_instruction_list=verdicts,
    )
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def process_results(input, results):
    """Score the first completion in ``results`` with both the strict and the loose check.

    Args:
        input: Task item exposing ``key``, ``instruction_id_list``, ``prompt`` and
            ``additional_kwargs`` (the name shadows the builtin but is kept for callers).
        results: Sequence of model completions; only ``results[0]`` is scored.

    Returns:
        Dict with prompt-level verdicts (``prompt_level_*_acc``, one bool) and
        instruction-level verdicts (``inst_level_*_acc``, a list of bools) for both modes.
    """
    first_completion = results[0]
    example = InputExample(
        key=input.key,
        instruction_id_list=input.instruction_id_list,
        prompt=input.prompt,
        kwargs=input.additional_kwargs,
    )

    strict = test_instruction_following_strict(example, first_completion)
    loose = test_instruction_following_loose(example, first_completion)

    scores = {}
    scores["prompt_level_strict_acc"] = strict.follow_all_instructions
    scores["inst_level_strict_acc"] = strict.follow_instruction_list
    scores["prompt_level_loose_acc"] = loose.follow_all_instructions
    scores["inst_level_loose_acc"] = loose.follow_instruction_list
    return scores
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def agg_inst_level_acc(items):
    """Pool instruction-level verdicts across examples and return the fraction followed.

    Args:
        items: Iterable of per-example lists of booleans (one entry per instruction),
            such as the ``inst_level_*_acc`` values returned by ``process_results``.

    Returns:
        The share of all instructions that were followed. Instructions are pooled, so
        examples with more instructions weigh more. When there are no instructions at
        all, 0.0 is returned instead of raising ``ZeroDivisionError``.
    """
    flat_items = [item for sublist in items for item in sublist]
    if not flat_items:
        return 0.0
    return sum(flat_items) / len(flat_items)
|
|
File without changes
|
|
@@ -0,0 +1,323 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
import random
|
|
6
|
+
import re
|
|
7
|
+
import time
|
|
8
|
+
import traceback
|
|
9
|
+
from collections.abc import Callable, Sequence
|
|
10
|
+
|
|
11
|
+
import aiohttp
|
|
12
|
+
from aleph_alpha_client import (
|
|
13
|
+
AsyncClient,
|
|
14
|
+
BusyError,
|
|
15
|
+
Client,
|
|
16
|
+
CompletionRequest,
|
|
17
|
+
CompletionResponse,
|
|
18
|
+
EvaluationRequest,
|
|
19
|
+
EvaluationResponse,
|
|
20
|
+
Prompt,
|
|
21
|
+
)
|
|
22
|
+
from aleph_alpha_client.prompt import Text
|
|
23
|
+
from dotenv import load_dotenv
|
|
24
|
+
|
|
25
|
+
from eval_framework.llm.base import BaseLLM
|
|
26
|
+
from eval_framework.shared.types import Error, PromptTooLongException, RawCompletion, RawLoglikelihood
|
|
27
|
+
from eval_framework.tasks.base import Sample
|
|
28
|
+
from eval_framework.tasks.utils import raise_errors
|
|
29
|
+
from template_formatting.formatter import BaseFormatter, Llama3Formatter, Message
|
|
30
|
+
|
|
31
|
+
# Load variables from a local ``.env`` file into the process environment, so settings
# read later via os.getenv (AA_INFERENCE_ENDPOINT, AA_TOKEN) can be provided there.
load_dotenv()

logger = logging.getLogger(name=__name__)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def safe_json_loads(s: str) -> dict:
    """Parse ``s`` as a JSON object, returning ``{}`` whenever that is not possible.

    Used to extract structured error details (e.g. a ``"code"`` field) from exception
    arguments, so it must never raise. Invalid JSON, non-string input, and JSON values
    that are not objects (lists, numbers, strings, ``null``) all yield ``{}``. Callers
    can therefore always chain ``.get(...)`` on the result.
    """
    try:
        parsed = json.loads(s)
    except (TypeError, ValueError):
        # ValueError covers json.JSONDecodeError and decoding errors for bytes input.
        return {}
    return parsed if isinstance(parsed, dict) else {}
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class AlephAlphaAPIModel(BaseLLM):
    """Model wrapper that queries an Aleph Alpha inference API.

    The API host and token are read from the ``AA_INFERENCE_ENDPOINT`` and ``AA_TOKEN``
    environment variables; placeholder values are used when they are unset. Subclasses
    set ``LLM_NAME`` (the model name sent with every request). They may also set
    ``DEFAULT_FORMATTER``, a zero-argument factory used when no formatter is passed in.
    """

    LLM_NAME: str
    DEFAULT_FORMATTER: Callable[[], BaseFormatter] | None = None

    def __init__(
        self,
        formatter: BaseFormatter | None = None,
        checkpoint_name: str | None = None,
        # Please see README.md for tips if adapting the following parameters.
        max_retries: int = 100,
        max_async_concurrent_requests: int = 32,
        request_timeout_seconds: int = 30 * 60 + 5,
        queue_full_timeout_seconds: int = 30 * 60 + 5,
    ) -> None:
        """Configure the wrapper and probe the model once.

        Args:
            formatter: Turns chat messages into a prompt string. Defaults to
                ``DEFAULT_FORMATTER()``.
            checkpoint_name: Model name to query; defaults to ``LLM_NAME``.
            max_retries: Attempt budget for timeout-like failures (see ``_request_with_backoff``).
            max_async_concurrent_requests: Maximum number of requests in flight at once.
            request_timeout_seconds: Client-side timeout for a single request.
            queue_full_timeout_seconds: How long ``QUEUE_FULL`` responses keep being retried.

        Raises:
            ValueError: If neither ``formatter`` nor ``DEFAULT_FORMATTER`` is set.
            RuntimeError: If the availability probe request fails.
        """
        self._formatter: BaseFormatter
        if formatter is None:
            if self.DEFAULT_FORMATTER is None:
                raise ValueError("Either formatter or default formatter must be specified")
            self._formatter = self.DEFAULT_FORMATTER()
        else:
            self._formatter = formatter
        self._llm_name = checkpoint_name or self.LLM_NAME
        self.max_async_concurrent_requests = max_async_concurrent_requests
        self.max_retries = max_retries
        self.request_timeout_seconds = request_timeout_seconds
        self.queue_full_timeout_seconds = queue_full_timeout_seconds
        self._validate_model_availability()

    def _validate_model_availability(self) -> None:
        """Send a one-token completion for an empty prompt to check that the model is served.

        Raises:
            RuntimeError: Wraps whatever the probe raised; the original exception is chained.
        """
        try:
            # 'Client' object does not support the context manager protocol
            client = Client(
                host=os.getenv("AA_INFERENCE_ENDPOINT", "dummy_endpoint"),
                token=os.getenv("AA_TOKEN", "dummy"),
            )

            request = CompletionRequest(
                prompt=Prompt.from_text(""),
                maximum_tokens=1,
            )
            client.complete(request, model=self._llm_name)
            logger.info(f"Model '{self._llm_name}' available and loaded.")
        except Exception as e:
            raise RuntimeError(f"Model '{self._llm_name}' is not available: {e}") from e

    @staticmethod
    def _error_code(e: BaseException) -> str:
        """Return the API error code carried as JSON in ``e.args[1]``, or ``""`` if there is none."""
        if len(e.args) < 2:
            return ""
        payload = safe_json_loads(e.args[1])
        # Guard against non-dict payloads so this never raises inside an ``except`` handler.
        code = payload.get("code") if isinstance(payload, dict) else None
        return "" if code is None else str(code)

    async def _request_with_backoff(
        self, client: AsyncClient, request: CompletionRequest | EvaluationRequest, id: int
    ) -> CompletionResponse | EvaluationResponse:
        """Send ``request`` to the Aleph Alpha API, retrying transient failures.

        - ``QUEUE_FULL`` responses are retried without consuming an attempt, as long as
          ``queue_full_timeout_seconds`` have not elapsed. The clock starts at the first
          ``QUEUE_FULL`` and is reset after every timeout-type retry.
        - Timeouts, ``TIMEOUT_TASK``, 502/504 gateway errors and ``aiohttp`` client
          errors consume one attempt each, up to ``max_retries``.
        - Any other failure, or an exhausted budget, re-raises the last exception.

        Raises:
            ValueError: If ``request`` is neither a completion nor an evaluation request.
        """
        num_attempts = 0
        start_time: float | None = None

        while True:
            try:
                if isinstance(request, CompletionRequest):
                    return await client.complete(request, model=self._llm_name)
                if isinstance(request, EvaluationRequest):
                    return await client.evaluate(request, model=self._llm_name)
                raise ValueError(f"Unsupported request type: {type(request)}")

            # asyncio.TimeoutError is only an alias of TimeoutError from Python 3.11 on;
            # listing both keeps asyncio/aiohttp timeouts retryable on older interpreters.
            except (TimeoutError, asyncio.TimeoutError, BusyError, RuntimeError, aiohttp.ClientError) as e:
                status_code = self._error_code(e)
                str_e = str(e)
                reason = status_code or str_e[:256]

                if status_code == "QUEUE_FULL":
                    # Worker not available or missed a heartbeat (inference longer than scheduler's
                    # API_MODEL_AVAILABLE_TIMEOUT_DURATION_MILLIS) or the scheduler is overloaded.
                    if start_time is None:
                        start_time = time.monotonic()
                    elapsed = time.monotonic() - start_time
                    if elapsed <= self.queue_full_timeout_seconds:
                        logger.info(
                            f"Request {id}: {reason} - retrying: attempt"
                            f" {num_attempts}/{self.max_retries}, elapsed {elapsed:.1f} sec"
                        )
                        # don't count as retry (request returns immediately, so just wait a bit not to DoS the server)
                        await asyncio.sleep(random.randint(5, 30))
                        continue

                elif (
                    status_code == "TIMEOUT_TASK"
                    or isinstance(e, (TimeoutError, asyncio.TimeoutError, aiohttp.ClientError))
                    or "502 Bad Gateway" in str_e
                    or "504 Gateway Time-out" in str_e
                ):
                    # client timeout, either because task too long in a queue or inference too long
                    # (scheduler's API_CLIENT_TIMEOUT_DURATION_MILLIS). Retrying for the "inference too long"
                    # case makes no sense but we unfortunately don't know which case has happened.
                    num_attempts += 1
                    start_time = None
                    if num_attempts < self.max_retries:
                        logger.info(f"Request {id}: {reason} - retrying: attempt {num_attempts}/{self.max_retries}")
                        await asyncio.sleep(random.randint(5, 30))
                        continue

                raise

    async def _process_request_with_client(
        self,
        client: AsyncClient,
        semaphore: asyncio.Semaphore,
        request: CompletionRequest | EvaluationRequest,
        id: int,
    ) -> RawCompletion | tuple[EvaluationRequest, EvaluationResponse | Error]:
        """Run one request under ``semaphore`` and convert the outcome.

        Completion requests become a ``RawCompletion``; on failure it has an empty
        completion and ``raw_completion_error`` set. Evaluation requests are returned as
        ``(request, response_or_error)`` so ``logprobs`` can regroup them per sample.
        When ``raise_errors()`` is true, failures are re-raised instead of being converted.
        """
        async with semaphore:
            try:
                response = await self._request_with_backoff(client=client, request=request, id=id)
                logger.info(f"Request {id}: Success")
            except Exception as e:
                if raise_errors():
                    raise
                logger.info(f"Request {id}: Failure: {str(e)[:256]}")
                status_code = self._error_code(e)
                if status_code == "PROMPT_TOO_LONG":
                    error = Error(
                        error_class=PromptTooLongException.__name__,
                        message="Prompt exceeded context size.",
                        traceback=traceback.format_exc(),
                    )
                else:
                    # Fall back to the exception class when the API supplied no error code,
                    # so the reported error class is never an empty string.
                    error = Error(
                        error_class=status_code or e.__class__.__name__,
                        message=str(e),
                        traceback=traceback.format_exc(),
                    )

                if isinstance(request, CompletionRequest):
                    assert isinstance(request.prompt.items[0], Text)
                    return RawCompletion(
                        prompt=request.prompt.items[0].text,
                        prompt_sequence_positions=None,
                        completion="",
                        completion_sequence_positions=0,
                        raw_completion_error=error,
                    )
                return (request, error)

        # Completion responses can directly be converted to RawCompletion
        if isinstance(request, CompletionRequest):
            assert isinstance(request.prompt.items[0], Text) and isinstance(response, CompletionResponse)
            assert len(response.completions) == 1
            prompt = request.prompt.items[0].text
            completion = response.completions[0].completion or ""
            prompt_sequence_positions: int | None = None
            completion_sequence_positions: int | None = None

            # Support workaround in api-worker-transformer's scaling generator to return the correct number of tokens.
            # These are part of the completion string; those in CompletionResponse are invalid in this case.
            m = re.match(r"\uf8c9(\d+),(\d+)\uf8c9(.*)", completion, re.DOTALL)
            if m is not None:
                num_input_tokens, num_completion_tokens, completion = m.groups()
                prompt_sequence_positions = int(num_input_tokens)
                completion_sequence_positions = int(num_completion_tokens)
            else:
                prompt_sequence_positions = response.num_tokens_prompt_total if response else None
                completion_sequence_positions = response.num_tokens_generated if response else None

            return RawCompletion(
                prompt=prompt,
                prompt_sequence_positions=prompt_sequence_positions,
                completion=completion,
                completion_sequence_positions=completion_sequence_positions,
            )

        # Evaluation responses must be assembled from individual choice requests later
        assert isinstance(response, EvaluationResponse)
        return (request, response)

    async def _process_requests(
        self, requests: list[CompletionRequest] | list[EvaluationRequest]
    ) -> list[RawCompletion | tuple[EvaluationRequest, EvaluationResponse | Error]]:
        """Send all ``requests`` concurrently, at most ``max_async_concurrent_requests`` at a time.

        Results are returned in the same order as ``requests``.
        """
        semaphore = asyncio.Semaphore(self.max_async_concurrent_requests)
        async with AsyncClient(
            host=os.getenv("AA_INFERENCE_ENDPOINT", "dummy_endpoint"),
            nice=True,
            request_timeout_seconds=self.request_timeout_seconds,
            token=os.getenv("AA_TOKEN", "dummy"),
            total_retries=0,  # we have a custom retry policy in _request_with_backoff()
        ) as client:
            tasks = (
                self._process_request_with_client(client, semaphore, request, i)
                for i, request in enumerate(requests)
            )
            responses = await asyncio.gather(*tasks)  # guarantees order of responses
        return responses

    def generate_from_messages(
        self,
        messages: list[Sequence[Message]],
        stop_sequences: list[str] | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
    ) -> list[RawCompletion]:
        """Format each conversation with the formatter and request one completion for it.

        ``temperature`` defaults to 0.0 when not given. Failed requests are returned as
        ``RawCompletion`` objects with ``raw_completion_error`` set, unless
        ``raise_errors()`` is true.
        """
        if temperature is None:
            effective_temperature = 0.0  # Current default, TODO: refactor to use model's default
            logger.info(
                f"Using default temperature value: {effective_temperature} as no custom temperature value was provided"
            )
        else:
            effective_temperature = temperature

        requests = []

        for single_messages in messages:
            requests.append(
                CompletionRequest(
                    prompt=Prompt.from_text(self._formatter.format(single_messages, output_mode="string")),
                    maximum_tokens=max_tokens,
                    stop_sequences=stop_sequences,
                    temperature=effective_temperature,
                )
            )

        responses = asyncio.run(self._process_requests(requests))
        return responses  # type: ignore

    def logprobs(self, samples: list[Sample]) -> list[RawLoglikelihood]:
        """Score every candidate completion of every sample.

        Sends one evaluation request per (sample, choice) pair, all concurrently, then
        regroups the flat results per sample. For each sample:
        - the prompt length is the response's total prompt tokens minus the tokens of
          the expected completion;
        - if any choice fails, the sample gets that error and empty log-probabilities.
        """
        samples_prompt: list[str] = []
        evaluation_requests: list[EvaluationRequest] = []
        results: list[RawLoglikelihood] = []

        # evaluate all choices independently in flattened list
        for sample in samples:
            prompt: str = self._formatter.format(sample.messages, output_mode="string") if sample.messages else ""
            samples_prompt.append(prompt)
            for choice in sample.possible_completions or []:
                evaluation_requests.append(
                    EvaluationRequest(prompt=Prompt.from_text(prompt), completion_expected=choice)
                )

        evaluation_responses = asyncio.run(self._process_requests(evaluation_requests))
        evaluation_responses_iter = iter(evaluation_responses)

        # assemble choices to RawLoglikelihood from a flattened list for all possible choice replies
        for sample, prompt in zip(samples, samples_prompt, strict=True):
            choices_log_probs: dict[str, float] = {}
            choices_sequence_positions: dict[str, int] = {}
            prompt_sequence_positions: int | None = 0
            error = None

            for choice in sample.possible_completions or []:
                # Always advance the iterator, even after an error, to stay aligned with the flat list.
                request, response = next(evaluation_responses_iter)
                if error is not None:
                    continue
                if isinstance(response, Error):  # failure for one choice leads to failure of the whole sample
                    error = response
                    prompt_sequence_positions = None
                    choices_log_probs = {}
                    choices_sequence_positions = {}
                else:
                    assert isinstance(request, EvaluationRequest) and isinstance(response, EvaluationResponse)
                    assert isinstance(request.prompt.items[0], Text)
                    assert prompt == request.prompt.items[0].text, f"{prompt} != {request.prompt.items[0].text}"
                    assert choice == request.completion_expected, f"{choice} != {request.completion_expected}"
                    prompt_sequence_positions = response.num_tokens_prompt_total - response.result["token_count"]
                    choices_log_probs[choice] = response.result["log_probability"]
                    choices_sequence_positions[choice] = response.result["token_count"]

            results.append(
                RawLoglikelihood(
                    prompt=prompt,
                    prompt_sequence_positions=prompt_sequence_positions,
                    loglikelihoods=choices_log_probs,
                    loglikelihoods_sequence_positions=choices_sequence_positions,
                    raw_loglikelihood_error=error,
                )
            )

        return results
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
class Llama31_8B_Instruct_API(AlephAlphaAPIModel):
    """Llama 3.1 8B Instruct served through the Aleph Alpha API, formatted with ``Llama3Formatter`` by default."""

    DEFAULT_FORMATTER = Llama3Formatter
    LLM_NAME = "llama-3.1-8b-instruct"
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from collections.abc import Sequence
|
|
3
|
+
|
|
4
|
+
from eval_framework.shared.types import RawCompletion, RawLoglikelihood
|
|
5
|
+
from eval_framework.tasks.base import Sample
|
|
6
|
+
from template_formatting.formatter import Message
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class BaseLLM(ABC):
    """Abstract interface shared by all evaluated models.

    Implementations provide free-form generation (``generate_from_messages``) and
    scoring of candidate completions (``logprobs``). ``generate`` adapts task samples
    to the former.
    """

    @property
    def name(self) -> str:
        """Name of the model, used for the results folder and to identify eval results.

        Defaults to the concrete class name. Subclasses should override it with something
        more specific, such as the checkpoint name or the Hugging Face model name.
        """
        return self.__class__.__name__

    @abstractmethod
    def generate_from_messages(
        self,
        messages: list[Sequence[Message]],
        stop_sequences: list[str] | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
    ) -> list[RawCompletion]:
        """Produce completions for the given conversations.

        ``stop_sequences`` and ``max_tokens`` are supplied by the task when it defines
        them. Implementations should override or extend them with model-specific
        settings, in particular the stop tokens of the evaluated checkpoint. Examples
        are ``<|eot_id|>`` for an instruction-tuned Llama 3.1 and ``<|endoftext|>`` for
        a pretrained Llama 3.1.

        Errors are expected to propagate; they are caught and reported when the eval
        runs. This includes sequence-length problems: raise whenever something prevents
        the task from completing as expected.

        Important: completions must be detokenized and must NOT contain special tokens.

        Returns:
            list[RawCompletion] (presumably one per conversation, in input order).
        """
        raise NotImplementedError

    @abstractmethod
    def logprobs(self, samples: list[Sample]) -> list[RawLoglikelihood]:
        """Compute log-likelihoods for the candidate completions of each sample.

        Errors are expected to propagate; they are caught and reported when the eval
        runs. This includes sequence-length problems: raise whenever something prevents
        the task from completing as expected.
        """
        raise NotImplementedError

    def generate(
        self,
        samples: list[Sample],
        stop_sequences: list[str] | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
    ) -> list[RawCompletion]:
        """Generate completions for task samples by forwarding their ``messages``.

        Thin adapter over ``generate_from_messages``. The remaining arguments are passed
        through unchanged, positionally.
        """
        conversations: list[Sequence[Message]] = [item.messages for item in samples]
        return self.generate_from_messages(conversations, stop_sequences, max_tokens, temperature)
|