evalscope 0.17.1__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of evalscope has been flagged by the registry scanner as possibly problematic.
- evalscope/__init__.py +4 -1
- evalscope/api/benchmark/__init__.py +3 -0
- evalscope/api/benchmark/adapters/__init__.py +5 -0
- evalscope/api/benchmark/adapters/default_data_adapter.py +684 -0
- evalscope/api/benchmark/adapters/image_edit_adapter.py +82 -0
- evalscope/api/benchmark/adapters/multi_choice_adapter.py +83 -0
- evalscope/api/benchmark/adapters/text2image_adapter.py +156 -0
- evalscope/api/benchmark/adapters/vision_language_adapter.py +6 -0
- evalscope/api/benchmark/benchmark.py +356 -0
- evalscope/api/benchmark/meta.py +121 -0
- evalscope/api/dataset/__init__.py +2 -0
- evalscope/api/dataset/dataset.py +349 -0
- evalscope/api/dataset/loader.py +262 -0
- evalscope/api/dataset/utils.py +143 -0
- evalscope/api/evaluator/__init__.py +3 -0
- evalscope/api/evaluator/cache.py +378 -0
- evalscope/api/evaluator/evaluator.py +56 -0
- evalscope/api/evaluator/state.py +275 -0
- evalscope/api/filter/__init__.py +1 -0
- evalscope/api/filter/filter.py +72 -0
- evalscope/api/messages/__init__.py +12 -0
- evalscope/api/messages/chat_message.py +243 -0
- evalscope/api/messages/content.py +102 -0
- evalscope/api/messages/utils.py +35 -0
- evalscope/api/metric/__init__.py +2 -0
- evalscope/api/metric/metric.py +55 -0
- evalscope/api/metric/scorer.py +113 -0
- evalscope/api/mixin/__init__.py +1 -0
- evalscope/api/mixin/llm_judge_mixin.py +168 -0
- evalscope/api/model/__init__.py +12 -0
- evalscope/api/model/generate_config.py +155 -0
- evalscope/api/model/model.py +386 -0
- evalscope/api/model/model_output.py +285 -0
- evalscope/api/registry.py +182 -0
- evalscope/api/tool/__init__.py +3 -0
- evalscope/api/tool/tool_call.py +101 -0
- evalscope/api/tool/tool_info.py +173 -0
- evalscope/api/tool/utils.py +64 -0
- evalscope/app/app.py +3 -0
- evalscope/app/ui/app_ui.py +2 -1
- evalscope/app/ui/multi_model.py +50 -25
- evalscope/app/ui/single_model.py +26 -14
- evalscope/app/utils/data_utils.py +43 -27
- evalscope/app/utils/env_utils.py +12 -0
- evalscope/app/utils/text_utils.py +14 -14
- evalscope/app/utils/visualization.py +9 -4
- evalscope/arguments.py +7 -10
- evalscope/backend/opencompass/api_meta_template.py +2 -1
- evalscope/backend/opencompass/backend_manager.py +6 -5
- evalscope/backend/rag_eval/clip_benchmark/dataset_builder.py +10 -10
- evalscope/backend/rag_eval/clip_benchmark/task_template.py +8 -4
- evalscope/backend/rag_eval/ragas/task_template.py +2 -1
- evalscope/backend/rag_eval/ragas/tasks/build_distribution.py +2 -1
- evalscope/backend/rag_eval/ragas/tasks/build_transform.py +7 -4
- evalscope/backend/rag_eval/ragas/tasks/testset_generation.py +2 -1
- evalscope/backend/rag_eval/ragas/tasks/translate_prompt.py +2 -1
- evalscope/backend/rag_eval/utils/embedding.py +10 -1
- evalscope/backend/rag_eval/utils/llm.py +13 -12
- evalscope/benchmarks/__init__.py +0 -2
- evalscope/benchmarks/aime/aime24_adapter.py +38 -40
- evalscope/benchmarks/aime/aime25_adapter.py +34 -40
- evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +86 -60
- evalscope/benchmarks/arc/arc_adapter.py +34 -147
- evalscope/benchmarks/arena_hard/arena_hard_adapter.py +96 -70
- evalscope/benchmarks/arena_hard/utils.py +37 -1
- evalscope/benchmarks/bbh/bbh_adapter.py +72 -144
- evalscope/benchmarks/bfcl/bfcl_adapter.py +188 -171
- evalscope/benchmarks/bfcl/generation.py +222 -0
- evalscope/benchmarks/ceval/ceval_adapter.py +93 -162
- evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +85 -82
- evalscope/benchmarks/cmmlu/cmmlu_adapter.py +34 -125
- evalscope/benchmarks/competition_math/competition_math_adapter.py +56 -108
- evalscope/benchmarks/data_collection/data_collection_adapter.py +187 -45
- evalscope/benchmarks/docmath/docmath_adapter.py +109 -51
- evalscope/benchmarks/docmath/utils.py +4 -5
- evalscope/benchmarks/drop/drop_adapter.py +88 -40
- evalscope/benchmarks/frames/frames_adapter.py +136 -52
- evalscope/benchmarks/general_arena/general_arena_adapter.py +140 -98
- evalscope/benchmarks/general_arena/utils.py +23 -27
- evalscope/benchmarks/general_mcq/general_mcq_adapter.py +40 -101
- evalscope/benchmarks/general_qa/general_qa_adapter.py +73 -134
- evalscope/benchmarks/gpqa/gpqa_adapter.py +61 -100
- evalscope/benchmarks/gpqa/{chain_of_thought.txt → prompt.py} +12 -5
- evalscope/benchmarks/gsm8k/gsm8k_adapter.py +62 -142
- evalscope/benchmarks/hellaswag/hellaswag_adapter.py +35 -124
- evalscope/benchmarks/hle/hle_adapter.py +127 -93
- evalscope/benchmarks/humaneval/humaneval_adapter.py +86 -55
- evalscope/benchmarks/ifeval/ifeval_adapter.py +69 -40
- evalscope/benchmarks/ifeval/instructions.py +109 -64
- evalscope/benchmarks/ifeval/instructions_registry.py +1 -1
- evalscope/benchmarks/ifeval/instructions_util.py +2 -3
- evalscope/benchmarks/ifeval/utils.py +6 -7
- evalscope/benchmarks/image_edit/gedit/__init__.py +0 -0
- evalscope/benchmarks/image_edit/gedit/gedit_adapter.py +138 -0
- evalscope/benchmarks/image_edit/gedit/utils.py +372 -0
- evalscope/benchmarks/image_edit/gedit/vie_prompts.py +406 -0
- evalscope/benchmarks/iquiz/iquiz_adapter.py +30 -65
- evalscope/benchmarks/live_code_bench/evaluate_utils.py +2 -2
- evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +121 -71
- evalscope/benchmarks/live_code_bench/load_utils.py +13 -21
- evalscope/benchmarks/live_code_bench/testing_util.py +6 -2
- evalscope/benchmarks/maritime_bench/maritime_bench_adapter.py +49 -75
- evalscope/benchmarks/math_500/math_500_adapter.py +41 -48
- evalscope/benchmarks/math_vista/__init__.py +0 -0
- evalscope/benchmarks/math_vista/math_vista_adapter.py +129 -0
- evalscope/benchmarks/mmlu/mmlu_adapter.py +32 -205
- evalscope/benchmarks/mmlu_pro/mmlu_pro_adapter.py +80 -99
- evalscope/benchmarks/mmlu_redux/mmlu_redux_adapter.py +64 -110
- evalscope/benchmarks/mmmu/__init__.py +0 -0
- evalscope/benchmarks/mmmu/mmmu_adapter.py +159 -0
- evalscope/benchmarks/mmmu_pro/__init__.py +0 -0
- evalscope/benchmarks/mmmu_pro/mmmu_pro_adapter.py +129 -0
- evalscope/benchmarks/musr/musr_adapter.py +33 -64
- evalscope/benchmarks/needle_haystack/needle_haystack_adapter.py +196 -152
- evalscope/benchmarks/process_bench/process_bench_adapter.py +144 -76
- evalscope/benchmarks/race/race_adapter.py +33 -119
- evalscope/benchmarks/simple_qa/simple_qa_adapter.py +72 -70
- evalscope/benchmarks/super_gpqa/{five_shot_prompt.txt → prompt.py} +14 -16
- evalscope/benchmarks/super_gpqa/super_gpqa_adapter.py +73 -117
- evalscope/benchmarks/super_gpqa/utils.py +2 -1
- evalscope/benchmarks/tau_bench/generation.py +147 -0
- evalscope/benchmarks/tau_bench/tau_bench_adapter.py +114 -60
- evalscope/benchmarks/text2image/__init__.py +0 -0
- evalscope/benchmarks/text2image/evalmuse_adapter.py +78 -0
- evalscope/benchmarks/text2image/genai_bench_adapter.py +53 -0
- evalscope/benchmarks/text2image/general_t2i_adapter.py +42 -0
- evalscope/benchmarks/text2image/hpdv2_adapter.py +52 -0
- evalscope/benchmarks/text2image/tifa_adapter.py +27 -0
- evalscope/benchmarks/tool_bench/tool_bench_adapter.py +91 -70
- evalscope/benchmarks/trivia_qa/trivia_qa_adapter.py +56 -124
- evalscope/benchmarks/truthful_qa/truthful_qa_adapter.py +70 -266
- evalscope/benchmarks/winogrande/winogrande_adapter.py +28 -54
- evalscope/cli/cli.py +2 -0
- evalscope/cli/start_app.py +7 -1
- evalscope/cli/start_perf.py +7 -1
- evalscope/cli/start_server.py +6 -3
- evalscope/collections/__init__.py +2 -10
- evalscope/collections/sampler.py +10 -10
- evalscope/collections/schema.py +13 -11
- evalscope/config.py +157 -57
- evalscope/constants.py +37 -61
- evalscope/evaluator/__init__.py +1 -1
- evalscope/evaluator/evaluator.py +275 -419
- evalscope/filters/__init__.py +2 -0
- evalscope/filters/extraction.py +126 -0
- evalscope/filters/selection.py +57 -0
- evalscope/metrics/__init__.py +13 -13
- evalscope/metrics/llm_judge.py +47 -33
- evalscope/metrics/math_parser.py +27 -22
- evalscope/metrics/metric.py +307 -0
- evalscope/metrics/metrics.py +22 -18
- evalscope/metrics/t2v_metrics/__init__.py +0 -52
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +4 -2
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +9 -13
- evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +2 -1
- evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +3 -2
- evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +2 -1
- evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +2 -2
- evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +2 -1
- evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +4 -2
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +10 -5
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +4 -2
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +2 -1
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +15 -9
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +4 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +15 -10
- evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +9 -6
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +2 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +4 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +4 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +3 -9
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +16 -10
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +3 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +4 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +8 -4
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +47 -25
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +12 -7
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +23 -17
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +33 -23
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +2 -1
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +46 -30
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +69 -37
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +7 -5
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +6 -4
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +7 -5
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +3 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +5 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +17 -13
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +35 -19
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +14 -12
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +63 -52
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +63 -38
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +6 -3
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +6 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +3 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +15 -13
- evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +3 -2
- evalscope/models/__init__.py +6 -29
- evalscope/models/image_edit_model.py +125 -0
- evalscope/models/mockllm.py +65 -0
- evalscope/models/model_apis.py +67 -0
- evalscope/models/modelscope.py +455 -0
- evalscope/models/openai_compatible.py +126 -0
- evalscope/models/text2image_model.py +124 -0
- evalscope/models/utils/openai.py +701 -0
- evalscope/perf/benchmark.py +4 -1
- evalscope/perf/http_client.py +4 -2
- evalscope/perf/plugin/api/custom_api.py +5 -4
- evalscope/perf/plugin/api/openai_api.py +11 -9
- evalscope/perf/plugin/datasets/custom.py +2 -1
- evalscope/perf/plugin/datasets/flickr8k.py +1 -1
- evalscope/perf/plugin/datasets/kontext_bench.py +1 -1
- evalscope/perf/plugin/datasets/line_by_line.py +2 -1
- evalscope/perf/plugin/datasets/longalpaca.py +2 -1
- evalscope/perf/plugin/datasets/openqa.py +4 -2
- evalscope/perf/utils/benchmark_util.py +15 -10
- evalscope/perf/utils/db_util.py +9 -6
- evalscope/perf/utils/local_server.py +11 -3
- evalscope/perf/utils/rich_display.py +16 -10
- evalscope/report/__init__.py +2 -3
- evalscope/report/combinator.py +18 -12
- evalscope/report/generator.py +51 -35
- evalscope/report/{utils.py → report.py} +8 -6
- evalscope/run.py +33 -47
- evalscope/summarizer.py +1 -1
- evalscope/third_party/toolbench_static/llm/swift_infer.py +0 -4
- evalscope/utils/__init__.py +21 -2
- evalscope/utils/chat_service.py +3 -2
- evalscope/utils/deprecation_utils.py +12 -1
- evalscope/utils/function_utils.py +29 -0
- evalscope/utils/import_utils.py +23 -1
- evalscope/utils/io_utils.py +142 -6
- evalscope/utils/json_schema.py +208 -0
- evalscope/utils/logger.py +51 -12
- evalscope/utils/model_utils.py +11 -7
- evalscope/utils/multi_choices.py +288 -0
- evalscope/utils/url_utils.py +65 -0
- evalscope/version.py +2 -2
- {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/METADATA +108 -62
- {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/RECORD +258 -226
- tests/benchmark/test_eval.py +385 -0
- tests/benchmark/test_image_edit.py +65 -0
- tests/{aigc → benchmark}/test_t2i.py +22 -4
- tests/benchmark/test_vlm.py +80 -0
- tests/cli/test_all.py +85 -47
- tests/cli/test_collection.py +20 -8
- tests/cli/test_custom.py +22 -15
- tests/cli/test_reasoning.py +81 -0
- tests/common.py +73 -0
- tests/perf/test_perf.py +4 -2
- tests/rag/test_clip_benchmark.py +0 -2
- evalscope/benchmarks/aigc/t2i/base.py +0 -56
- evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +0 -78
- evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +0 -58
- evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +0 -58
- evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +0 -57
- evalscope/benchmarks/aigc/t2i/tifa_adapter.py +0 -37
- evalscope/benchmarks/arc/ai2_arc.py +0 -151
- evalscope/benchmarks/benchmark.py +0 -81
- evalscope/benchmarks/ceval/ceval_exam.py +0 -146
- evalscope/benchmarks/cmmlu/cmmlu.py +0 -161
- evalscope/benchmarks/cmmlu/samples.jsonl +0 -5
- evalscope/benchmarks/competition_math/competition_math.py +0 -79
- evalscope/benchmarks/data_adapter.py +0 -528
- evalscope/benchmarks/filters.py +0 -59
- evalscope/benchmarks/gsm8k/gsm8k.py +0 -121
- evalscope/benchmarks/hellaswag/hellaswag.py +0 -112
- evalscope/benchmarks/humaneval/humaneval.py +0 -79
- evalscope/benchmarks/mmlu/mmlu.py +0 -160
- evalscope/benchmarks/mmlu/samples.jsonl +0 -5
- evalscope/benchmarks/process_bench/critique_template.txt +0 -13
- evalscope/benchmarks/race/race.py +0 -104
- evalscope/benchmarks/race/samples.jsonl +0 -5
- evalscope/benchmarks/super_gpqa/zero_shot_prompt.txt +0 -4
- evalscope/benchmarks/trivia_qa/trivia_qa.py +0 -89
- evalscope/benchmarks/truthful_qa/truthful_qa.py +0 -163
- evalscope/benchmarks/utils.py +0 -60
- evalscope/collections/evaluator.py +0 -375
- evalscope/metrics/completion_parsers.py +0 -227
- evalscope/metrics/named_metrics.py +0 -55
- evalscope/models/adapters/__init__.py +0 -14
- evalscope/models/adapters/base_adapter.py +0 -84
- evalscope/models/adapters/bfcl_adapter.py +0 -246
- evalscope/models/adapters/chat_adapter.py +0 -207
- evalscope/models/adapters/choice_adapter.py +0 -222
- evalscope/models/adapters/custom_adapter.py +0 -71
- evalscope/models/adapters/server_adapter.py +0 -236
- evalscope/models/adapters/t2i_adapter.py +0 -79
- evalscope/models/adapters/tau_bench_adapter.py +0 -189
- evalscope/models/custom/__init__.py +0 -4
- evalscope/models/custom/custom_model.py +0 -50
- evalscope/models/custom/dummy_model.py +0 -99
- evalscope/models/local_model.py +0 -128
- evalscope/models/register.py +0 -41
- tests/cli/test_run.py +0 -489
- /evalscope/{benchmarks/aigc → api}/__init__.py +0 -0
- /evalscope/benchmarks/{aigc/t2i → image_edit}/__init__.py +0 -0
- {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/LICENSE +0 -0
- {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/WHEEL +0 -0
- {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/entry_points.txt +0 -0
- {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/top_level.txt +0 -0
- /tests/{aigc → benchmark}/__init__.py +0 -0
evalscope/utils/json_schema.py
ADDED
@@ -0,0 +1,208 @@
+import types
+import typing
+from copy import deepcopy
+from dataclasses import is_dataclass
+from datetime import date, datetime, time
+from enum import EnumMeta
+from pydantic import BaseModel, Field
+from typing import (
+    Any,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Set,
+    Tuple,
+    Type,
+    Union,
+    cast,
+    get_args,
+    get_origin,
+    get_type_hints,
+    is_typeddict,
+)
+
+JSONType = Literal['string', 'integer', 'number', 'boolean', 'array', 'object', 'null']
+"""Valid types within JSON schema."""
+
+
+class JSONSchema(BaseModel):
+    """JSON Schema for type."""
+
+    type: Optional[JSONType] = Field(default=None)
+    """JSON type of tool parameter."""
+
+    format: Optional[str] = Field(default=None)
+    """Format of the parameter (e.g. date-time)."""
+
+    description: Optional[str] = Field(default=None)
+    """Parameter description."""
+
+    default: Any = Field(default=None)
+    """Default value for parameter."""
+
+    enum: Optional[List[Any]] = Field(default=None)
+    """Valid values for enum parameters."""
+
+    items: Optional['JSONSchema'] = Field(default=None)
+    """Valid type for array parameters."""
+
+    properties: Optional[Dict[str, 'JSONSchema']] = Field(default=None)
+    """Valid fields for object parameters."""
+
+    additionalProperties: Optional[Union['JSONSchema', bool]] = Field(default=None)
+    """Are additional properties allowed?"""
+
+    anyOf: Optional[List['JSONSchema']] = Field(default=None)
+    """Valid types for union parameters."""
+
+    required: Optional[List[str]] = Field(default=None)
+    """Required fields for object parameters."""
+
+
+def json_schema(t: Type[Any]) -> JSONSchema:
+    """Provide a JSON Schema for the specified type.
+
+    Schemas can be automatically inferred for a wide variety of
+    Python class types including Pydantic BaseModel, dataclasses,
+    and typed dicts.
+
+    Args:
+        t: Python type
+
+    Returns:
+        JSON Schema for type.
+    """
+    origin = get_origin(t)
+    args = get_args(t)
+
+    if origin is None:
+        if t is int:
+            return JSONSchema(type='integer')
+        elif t is float:
+            return JSONSchema(type='number')
+        elif t is str:
+            return JSONSchema(type='string')
+        elif t is bool:
+            return JSONSchema(type='boolean')
+        elif t is datetime:
+            return JSONSchema(type='string', format='date-time')
+        elif t is date:
+            return JSONSchema(type='string', format='date')
+        elif t is time:
+            return JSONSchema(type='string', format='time')
+        elif t is list or t is set:
+            return JSONSchema(type='array', items=JSONSchema())
+        elif t is dict:
+            return JSONSchema(type='object', additionalProperties=JSONSchema())
+        elif (is_dataclass(t) or is_typeddict(t) or (isinstance(t, type) and issubclass(t, BaseModel))):
+            return cls_json_schema(t)
+        elif isinstance(t, EnumMeta):
+            return JSONSchema(enum=[item.value for item in t])
+        elif t is type(None):
+            return JSONSchema(type='null')
+        else:
+            return JSONSchema()
+    elif (origin is list or origin is List or origin is tuple or origin is Tuple or origin is set or origin is Set):
+        return JSONSchema(type='array', items=json_schema(args[0]) if args else JSONSchema())
+    elif origin is dict or origin is Dict:
+        return JSONSchema(
+            type='object',
+            additionalProperties=json_schema(args[1]) if len(args) > 1 else JSONSchema(),
+        )
+    elif origin is Union or origin is types.UnionType:
+        return JSONSchema(anyOf=[json_schema(arg) for arg in args])
+    elif origin is Optional:
+        return JSONSchema(anyOf=[json_schema(arg) for arg in args] + [JSONSchema(type='null')])
+    elif origin is typing.Literal:
+        return JSONSchema(enum=list(args))
+
+    return JSONSchema()  # Default case if we can't determine the type
+
+
+def cls_json_schema(cls: Type[Any]) -> JSONSchema:
+    properties: Dict[str, JSONSchema] = {}
+    required: List[str] = []
+
+    if is_dataclass(cls):
+        fields = cls.__dataclass_fields__  # type: ignore
+        for name, field in fields.items():
+            properties[name] = json_schema(field.type)  # type: ignore
+            if field.default == field.default_factory:
+                required.append(name)
+    elif isinstance(cls, type) and issubclass(cls, BaseModel):
+        schema = cls.model_json_schema()
+        schema = resolve_schema_references(schema)
+        for name, prop in schema.get('properties', {}).items():
+            properties[name] = JSONSchema(**prop)
+        required = schema.get('required', [])
+    elif is_typeddict(cls):
+        annotations = get_type_hints(cls)
+        for name, type_hint in annotations.items():
+            properties[name] = json_schema(type_hint)
+            if name in cls.__required_keys__:
+                required.append(name)
+
+    return JSONSchema(
+        type='object',
+        properties=properties,
+        required=required if required else None,
+        additionalProperties=False,
+    )
+
+
+def python_type_to_json_type(python_type: Optional[str]) -> JSONType:
+    if python_type == 'str':
+        return 'string'
+    elif python_type == 'int':
+        return 'integer'
+    elif python_type == 'float':
+        return 'number'
+    elif python_type == 'bool':
+        return 'boolean'
+    elif python_type == 'list':
+        return 'array'
+    elif python_type == 'dict':
+        return 'object'
+    elif python_type == 'None':
+        return 'null'
+    elif python_type is None:
+        # treat 'unknown' as string as anything can be converted to string
+        return 'string'
+    else:
+        raise ValueError(f'Unsupported type: {python_type} for Python to JSON conversion.')
+
+
+def resolve_schema_references(schema: Dict[str, Any]) -> Dict[str, Any]:
+    """Resolves all $ref references in a JSON schema by inlining the definitions."""
+    schema = deepcopy(schema)
+    definitions = schema.pop('$defs', {})
+
+    def _resolve_refs(obj: Any) -> Any:
+        if isinstance(obj, dict):
+            if '$ref' in obj and obj['$ref'].startswith('#/$defs/'):
+                ref_key = obj['$ref'].split('/')[-1]
+                if ref_key in definitions:
+                    # Replace with a deep copy of the definition
+                    resolved = deepcopy(definitions[ref_key])
+                    # Process any nested references in the definition
+                    resolved = _resolve_refs(resolved)
+
+                    # Merge in the current object fields, which should take priority
+                    # This means that if you have e.g.
+                    # {"$ref": "#/$defs/SubType", "description": "subtype of type SubType"},
+                    # and SubType resolves to
+                    # {"description": "The SubType Class", "parameters": {"param1": {"type": "string"}}},
+                    # the final result will be:
+                    # {"description": "subtype of type SubType", "parameters": {"param1": {"type": "string"}}}
+                    return resolved | {k: o for k, o in obj.items() if k != '$ref'}
+
+            # Process all entries in the dictionary
+            return {k: _resolve_refs(v) for k, v in obj.items()}
+        elif isinstance(obj, list):
+            return [_resolve_refs(item) for item in obj]
+        else:
+            return obj
+
+    return cast(Dict[str, Any], _resolve_refs(schema))
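For context, a minimal sketch of how the new helper behaves; the dataclass below is a made-up example, not part of the package:

from dataclasses import dataclass
from typing import Optional

from evalscope.utils.json_schema import json_schema

@dataclass
class SearchParams:  # hypothetical type, for illustration only
    query: str
    max_results: int = 10
    language: Optional[str] = None

schema = json_schema(SearchParams)
# JSONSchema is a pydantic model, so it can be dumped to a plain dict.
# Expected shape: {'type': 'object', 'properties': {...}, 'required': ['query'],
# 'additionalProperties': False} -- only 'query' lacks a default, so only it is required.
print(schema.model_dump(exclude_none=True))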
evalscope/utils/logger.py
CHANGED
@@ -1,18 +1,27 @@
+import colorlog
 import importlib.util as iutil
 import logging
 import os
-from
+from logging import Logger
+from typing import List, Optional

 init_loggers = {}
+# Define log formats
+data_format = '%Y-%m-%d %H:%M:%S'
+# For console output
+color_detailed_format = '%(asctime)s - %(name)s - %(filename)s - %(funcName)s - %(lineno)d - %(log_color)s%(levelname)s%(reset)s: %(message)s'  # noqa:E501
+color_simple_format = '%(asctime)s - %(name)s - %(log_color)s%(levelname)s%(reset)s: %(message)s'
+color_detailed_formatter = colorlog.ColoredFormatter(color_detailed_format, datefmt=data_format)
+color_simple_formatter = colorlog.ColoredFormatter(color_simple_format, datefmt=data_format)
+# For file output
+detailed_format = '%(asctime)s - %(name)s - %(filename)s - %(funcName)s - %(lineno)d - %(levelname)s: %(message)s'  # noqa:E501
+simple_format = '%(asctime)s - %(name)s - %(levelname)s: %(message)s'
+plain_detailed_formatter = logging.Formatter(detailed_format, datefmt=data_format)
+plain_simple_formatter = logging.Formatter(simple_format, datefmt=data_format)

-detailed_format = '%(asctime)s - %(name)s - %(filename)s - %(funcName)s - %(lineno)d - %(levelname)s - %(message)s'
-simple_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-
-detailed_formatter = logging.Formatter(detailed_format)
-simple_formatter = logging.Formatter(simple_format)
 DEFAULT_LEVEL = logging.DEBUG if os.getenv('EVALSCOPE_LOG_LEVEL', 'INFO') == 'DEBUG' else logging.INFO

-logging.basicConfig(format=simple_format, level=
+logging.basicConfig(format=simple_format, level=logging.INFO, force=True)

 # set logging level
 logging.getLogger('datasets').setLevel(logging.WARNING)
@@ -20,7 +29,13 @@ logging.getLogger('httpx').setLevel(logging.WARNING)
 logging.getLogger('modelscope').setLevel(logging.ERROR)


-def get_logger(
+def get_logger(
+    log_file: Optional[str] = None,
+    name: Optional[str] = None,
+    log_level: int = DEFAULT_LEVEL,
+    file_mode: str = 'w',
+    force=False
+):
     """Get logging logger

     Args:
@@ -31,7 +46,10 @@ def get_logger(log_file: Optional[str] = None, log_level: int = DEFAULT_LEVEL, f
            specified (if filemode is unspecified, it defaults to 'w').
     """

-
+    if name:
+        logger_name = f"evalscope.{name.split('.')[-1]}"
+    else:
+        logger_name = 'evalscope'
     logger = logging.getLogger(logger_name)
     logger.propagate = False

@@ -40,7 +58,15 @@ def get_logger(log_file: Optional[str] = None, log_level: int = DEFAULT_LEVEL, f
     logger.setLevel(log_level)
     for handler in logger.handlers:
         handler.setLevel(log_level)
-
+        # Use the matching formatter for each type of handler
+        if isinstance(handler, logging.FileHandler):
+            handler.setFormatter(
+                plain_detailed_formatter if log_level == logging.DEBUG else plain_simple_formatter
+            )
+        else:
+            handler.setFormatter(
+                color_detailed_formatter if log_level == logging.DEBUG else color_simple_formatter
+            )
     add_file_handler_if_needed(logger, log_file, file_mode, log_level)
     return logger

@@ -66,7 +92,11 @@ def get_logger(log_file: Optional[str] = None, log_level: int = DEFAULT_LEVEL, f
         handlers.append(file_handler)

     for handler in handlers:
-
+        # Use the matching formatter for each type of handler
+        if isinstance(handler, logging.FileHandler):
+            handler.setFormatter(plain_detailed_formatter if log_level == logging.DEBUG else plain_simple_formatter)
+        else:
+            handler.setFormatter(color_detailed_formatter if log_level == logging.DEBUG else color_simple_formatter)
         handler.setLevel(log_level)
         logger.addHandler(handler)

@@ -102,6 +132,15 @@ def add_file_handler_if_needed(logger, log_file, file_mode, log_level):

     if is_worker0 and log_file is not None:
         file_handler = logging.FileHandler(log_file, file_mode)
-        file_handler.setFormatter(
+        file_handler.setFormatter(plain_detailed_formatter if log_level == logging.DEBUG else plain_simple_formatter)
         file_handler.setLevel(log_level)
         logger.addHandler(file_handler)
+
+
+def warn_once(logger: Logger, message: str) -> None:
+    if message not in _warned:
+        logger.warning(message)
+        _warned.append(message)
+
+
+_warned: List[str] = []
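A small usage sketch of the reworked logger; the file path and module name here are illustrative:

from evalscope.utils.logger import get_logger, warn_once

# Console handlers get the colorlog formatters; file handlers get the plain ones.
logger = get_logger(log_file='outputs/run.log', name='my_module')  # illustrative values
logger.info('hello')

# warn_once() deduplicates by exact message text via the module-level _warned list.
warn_once(logger, 'legacy option is deprecated')
warn_once(logger, 'legacy option is deprecated')  # second call is suppressed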
evalscope/utils/model_utils.py
CHANGED
@@ -1,10 +1,10 @@
 import numpy as np
-import os
 import random
-import torch
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

+from evalscope.utils.import_utils import check_import
+
 if TYPE_CHECKING:
     from transformers import GenerationConfig

@@ -69,8 +69,12 @@ def seed_everything(seed: int):
     """
     random.seed(seed)
     np.random.seed(seed)
-
-    if torch
-    torch
-
-    torch.
+
+    if check_import('torch'):
+        import torch
+
+        torch.manual_seed(seed)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(seed)
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
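The net effect of this hunk is that torch becomes an optional dependency for seeding; a sketch of the resulting behavior:

from evalscope.utils.model_utils import seed_everything

# Always seeds Python's `random` and numpy; if `check_import('torch')` succeeds,
# it additionally seeds torch CPU/CUDA RNGs and forces deterministic cuDNN.
seed_everything(42)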
evalscope/utils/multi_choices.py
ADDED
@@ -0,0 +1,288 @@
+# flake8: noqa: E501
+import re
+from typing import List, Optional, Union
+
+from evalscope.api.evaluator import Choices, Target, TaskState
+
+FEW_SHOT_TEMPLATE = r"""Here are some examples of how to answer similar questions:
+
+{fewshot}
+
+""".lstrip()
+
+CHINESE_FEW_SHOT_TEMPLATE = r"""以下是一些示例问题:
+
+{fewshot}
+
+""".lstrip()
+
+CHINESE_SINGLE_ANSWER_TEMPLATE = r"""回答下面的单项选择题,请选出其中的正确答案。你的回答的最后一行应该是这样的格式:"答案:LETTER"(不带引号),其中 LETTER 是 {letters} 中的一个。
+
+问题:{question}
+选项:
+{choices}
+""".lstrip()
+
+CHINESE_SINGLE_ANSWER_TEMPLATE_COT = r"""回答下面的单项选择题,请选出其中的正确答案。你的回答的最后一行应该是这样的格式:"答案:LETTER"(不带引号),其中 LETTER 是 {letters} 中的一个。请在回答前进行一步步思考。
+
+问题:{question}
+选项:
+{choices}
+""".lstrip()
+
+SINGLE_ANSWER_TEMPLATE = r"""
+Answer the following multiple choice question. The entire content of your response should be of the following format: 'ANSWER: $LETTER' (without quotes) where LETTER is one of {letters}.
+
+{question}
+
+{choices}
+""".strip()
+
+SINGLE_ANSWER_TEMPLATE_COT = r"""
+Answer the following multiple choice question. The last line of your response should be of the following format: 'ANSWER: $LETTER' (without quotes) where LETTER is one of {letters}. Think step by step before answering.
+
+{question}
+
+{choices}
+""".strip()
+
+MULTIPLE_ANSWER_TEMPLATE = r"""
+Answer the following multiple choice question where multiple answers may be correct. The entire content of your response should be of the following format: 'ANSWER: $LETTERS' (without quotes) where LETTERS is one or more of {letters}.
+
+{question}
+
+{choices}
+""".strip()
+
+MULTIPLE_ANSWER_TEMPLATE_COT = r"""
+Answer the following multiple choice question where multiple answers may be correct. The last line of your response should be of the following format: 'ANSWER: $LETTERS' (without quotes) where LETTERS is one or more of {letters}. Think step by step before answering.
+
+{question}
+
+{choices}
+""".strip()
+
+
+def unshuffle_choices(choices: Choices) -> Choices:
+    # `sorted` returns `list[Choice]`, but for consistency we wrap this back
+    # into a `Choices` object
+    return Choices(sorted(choices, key=lambda choice: choice.original_position))
+
+
+def answer_options(choices: Choices) -> str:
+    r"""
+    Returns the `choices` formatted as a multiple choice question, e.g.:
+
+    ["choice 1", "choice 2", "choice 3"] ->
+    "A) choice 1\nB) choice 2\nC) choice 3"
+    """
+    indexes = list(range(len(choices)))
+
+    return '\n'.join([f'{answer_character(i)}) {choices[j].value}' for i, j in enumerate(indexes)])
+
+
+def prompt(question: str, choices: Union[Choices, List[str]], template: str, fewshot: Optional[str] = None) -> str:
+    if isinstance(choices, list):
+        choices = Choices(choices)
+
+    choices_text = answer_options(choices)
+    letters = ','.join(answer_character(i) for i in range(len(choices)))
+    if not fewshot:
+        return template.format(
+            choices=choices_text,
+            letters=letters,
+            question=question,
+        )
+    else:
+        return template.format(
+            choices=choices_text,
+            letters=letters,
+            question=question,
+            fewshot=fewshot,
+        )
+
+
+def format_example(
+    question: str,
+    choices: Choices,
+    answer: Target,
+) -> str:
+    """Format a single example for few-shot learning.
+
+    Args:
+        question (str): The question text.
+        choices (Choices): The list of choices.
+        answer (Target): The correct answers.
+
+    Returns:
+        str: Formatted example string.
+    """
+    choices_text = answer_options(choices)
+    return f'{question}\n{choices_text}\nANSWER: {answer.text}'
+
+
+def _fallback_parse_answer(completion: str) -> Optional[set[str]]:
+    # Fallback to find the last upper case letter
+    for letter in reversed(completion):
+        if letter.isupper():
+            return {letter}
+    return None
+
+
+def parse_answers(state: TaskState, multiple_correct: bool = False) -> set[str]:
+    """
+    Convenience function for extracting answers from the state output.
+
+    The generated response must be in the format 'ANSWER: <answers>',
+    otherwise we can't extract what the model thinks is "true". We can be a
+    bit flexible whether these are "AB" vs "A,B" vs "A B".
+
+    However, if the answer isn't in the expected format the model has
+    failed in the task so we'll ultimately just mark it as incorrect
+    """
+    # First check whether the string strictly ends with the expected answer
+    # In this case, we're looking for a single line which contains the expected
+    # ANSWER: <answer> string with only whitespace or a period/full stop at the end.
+    match = re.search(
+        r'(?i)^ANSWER\s*:\s*([A-Za-z\d ,]+)\s*(?:$|\n|\.)',
+        state.output.completion,
+        flags=re.MULTILINE,
+    )
+
+    # If we couldn't match the strict version, we can try the less strict
+    # version for backward compatibility
+    if match is None:
+        match = re.search(
+            r'(?i)ANSWER\s*:\s*([A-Za-z\d ,]+)(?:[^\w]|\n|$|\.)',
+            state.output.completion,
+        )
+
+    if match is None:
+        fallback_answer = _fallback_parse_answer(state.output.completion)
+        if fallback_answer:
+            return fallback_answer
+
+    if match is None:
+        return set()
+
+    matched = match.group(1)
+
+    # Strip trailing period / full stop
+    matched = matched.strip()
+    matched = matched.rstrip('.')
+
+    allowed_options = set(answer_character(i) for i in range(len(state.choices)))
+
+    if multiple_correct:
+        # Match must contain only the allowed choices
+        # (may be separated by commas, spaces, the word 'and', or nothing at all)
+
+        matched = matched.replace(' and ', '')
+
+        matched = matched.replace(' ', '')
+
+        split_comma = set(matched.split(','))
+        if split_comma.issubset(allowed_options):
+            answers = split_comma
+            return answers
+
+        split_nothing = set(matched)
+        if split_nothing.issubset(allowed_options):
+            answers = split_nothing
+            return answers
+
+    else:
+        # Match must contain a single letter in the allowed choices
+        if matched in allowed_options:
+            answers = {matched}
+            return answers
+
+    return set()
+
+
+def parse_answers_zh(state: TaskState, multiple_correct: bool = False) -> set[str]:
+    """
+    Convenience function for extracting answers from the state output in Chinese format.
+
+    The generated response must be in the format '答案:选项',
+    otherwise we can't extract what the model thinks is "true". We can be a
+    bit flexible whether these are "AB" vs "A,B" vs "A B".
+    """
+    # Simple pattern to capture answers with optional bold markdown
+    pattern = r'答案\s*[::]\s*([A-Za-z0-9,,]+)'
+    match = re.search(pattern, state.output.completion, flags=re.MULTILINE)
+
+    if match is None:
+        fallback_answer = _fallback_parse_answer(state.output.completion)
+        if fallback_answer:
+            return fallback_answer
+
+    if match is None:
+        return set()
+
+    matched = match.group(1).strip().rstrip('。.')
+    allowed_options = set(answer_character(i) for i in range(len(state.choices)))
+
+    if multiple_correct:
+        # Handle comma-separated or continuous letters
+        matched = matched.replace(' 和 ', '').replace(' ', '').replace(',', ',')
+        answers = set(matched.split(',')) if ',' in matched else set(matched)
+        return answers if answers.issubset(allowed_options) else set()
+    else:
+        # Single answer
+        return {matched} if matched in allowed_options else set()
+
+
+def set_choices_based_on_generated_response(state: TaskState, answers: set[str]) -> None:
+    true_answers = [answer_index(letter) for letter in answers]
+
+    for i in range(len(state.choices)):
+        if i in true_answers:
+            state.choices.mark_choice(i, True)
+        else:
+            state.choices.mark_choice(i, False)
+
+
+def valid_template(template: str) -> bool:
+    """Check if a template has the required capture groups for a multiple choice question"""
+    return bool(re.search(r'\{question\}', template) and re.search(r'\{choices\}', template))
+
+
+class MultipleChoiceTemplate:
+    """
+    Templates for multiple choice questions.
+    """
+
+    SINGLE_ANSWER = SINGLE_ANSWER_TEMPLATE
+    SINGLE_ANSWER_COT = SINGLE_ANSWER_TEMPLATE_COT
+    MULTIPLE_ANSWER = MULTIPLE_ANSWER_TEMPLATE
+    MULTIPLE_ANSWER_COT = MULTIPLE_ANSWER_TEMPLATE_COT
+    CHINESE_FEW_SHOT_TEMPLATE = CHINESE_FEW_SHOT_TEMPLATE
+    CHINESE_SINGLE_ANSWER_TEMPLATE = CHINESE_SINGLE_ANSWER_TEMPLATE
+    CHINESE_SINGLE_ANSWER_TEMPLATE_COT = CHINESE_SINGLE_ANSWER_TEMPLATE_COT
+
+
+def answer_character(index: int) -> str:
+    r"""
+    Helper to go from array index to char, for example:
+
+    0 -> 'A', 1 -> 'B', etc
+    """
+    if index < 26:
+        return chr(ord('A') + index)
+    else:
+        return str(index - 25)
+
+
+def answer_index(char: str) -> int:
+    r"""
+    Helper to go from char to array index, for example:
+
+    'A' -> 0, 'B' -> 1, etc
+    """
+    if char.isalpha() or char == ',' or char == ' ':
+        return ord(char.upper()) - ord('A')
+    elif char.isnumeric():
+        return 25 + int(char)
+    else:
+        raise ValueError(f'Unexpected multiple choice answer: {char} (must be a letter or number)')
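A sketch of the prompt-building side of this module; the question and options are made up, and `prompt` wraps a plain list into `Choices` itself:

from evalscope.utils.multi_choices import MultipleChoiceTemplate, answer_character, prompt

text = prompt(
    question='Which planet is known as the Red Planet?',  # illustrative
    choices=['Venus', 'Mars', 'Jupiter'],
    template=MultipleChoiceTemplate.SINGLE_ANSWER,
)
# The rendered prompt asks for a final line of the form 'ANSWER: B', with the
# options rendered as "A) Venus\nB) Mars\nC) Jupiter" and letters "A,B,C".
print(text)
assert answer_character(1) == 'B'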
evalscope/utils/url_utils.py
ADDED
@@ -0,0 +1,65 @@
+import base64
+import httpx
+import mimetypes
+import re
+
+
+def is_http_url(url: str) -> bool:
+    return url.startswith('http://') or url.startswith('https://')
+
+
+def is_data_uri(url: str) -> bool:
+    pattern = r'^data:([^;]+);base64,.*'
+    return re.match(pattern, url) is not None
+
+
+def data_uri_mime_type(data_url: str) -> str | None:
+    pattern = r'^data:([^;]+);.*'
+    match = re.match(pattern, data_url)
+    if match:
+        mime_type = match.group(1)
+        return mime_type
+    else:
+        return None
+
+
+def data_uri_to_base64(data_uri: str) -> str:
+    pattern = r'^data:[^,]+,'
+    stripped_uri = re.sub(pattern, '', data_uri)
+    return stripped_uri
+
+
+def file_as_data(file: str) -> tuple[bytes, str]:
+    if is_data_uri(file):
+        # resolve mime type and base64 content
+        mime_type = data_uri_mime_type(file) or 'image/png'
+        file_base64 = data_uri_to_base64(file)
+        file_bytes = base64.b64decode(file_base64)
+    else:
+        # guess mime type; need strict=False for webp images
+        type, _ = mimetypes.guess_type(file, strict=False)
+        if type:
+            mime_type = type
+        else:
+            mime_type = 'image/png'
+
+        # handle url or file
+        if is_http_url(file):
+            client = httpx.Client()
+            file_bytes = client.get(file).content
+        else:
+            with open(file, 'rb') as f:
+                file_bytes = f.read()
+
+    # return bytes and type
+    return file_bytes, mime_type
+
+
+def file_as_data_uri(file: str) -> str:
+    if is_data_uri(file):
+        return file
+    else:
+        bytes, mime_type = file_as_data(file)
+        base64_file = base64.b64encode(bytes).decode('utf-8')
+        file = f'data:{mime_type};base64,{base64_file}'
+        return file