evalscope 0.16.0__py3-none-any.whl → 0.16.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evalscope/app/__init__.py +28 -0
- evalscope/{report → app}/app.py +20 -25
- evalscope/app/constants.py +21 -0
- evalscope/arguments.py +2 -1
- evalscope/backend/opencompass/backend_manager.py +2 -1
- evalscope/backend/rag_eval/cmteb/arguments.py +4 -1
- evalscope/backend/rag_eval/cmteb/task_template.py +19 -3
- evalscope/backend/rag_eval/cmteb/tasks/CustomTask.py +1 -1
- evalscope/backend/rag_eval/utils/embedding.py +75 -35
- evalscope/benchmarks/benchmark.py +1 -0
- evalscope/benchmarks/data_adapter.py +97 -16
- evalscope/benchmarks/docmath/__init__.py +0 -0
- evalscope/benchmarks/docmath/docmath_adapter.py +84 -0
- evalscope/benchmarks/docmath/utils.py +220 -0
- evalscope/benchmarks/frames/__init__.py +0 -0
- evalscope/benchmarks/frames/frames_adapter.py +90 -0
- evalscope/benchmarks/frames/utils.py +37 -0
- evalscope/benchmarks/needle_haystack/__init__.py +0 -0
- evalscope/benchmarks/needle_haystack/needle_haystack_adapter.py +341 -0
- evalscope/benchmarks/needle_haystack/utils.py +79 -0
- evalscope/benchmarks/tool_bench/tool_bench_adapter.py +4 -1
- evalscope/benchmarks/tool_bench/utils.py +5 -4
- evalscope/benchmarks/utils.py +25 -0
- evalscope/cli/start_app.py +2 -2
- evalscope/collections/__init__.py +35 -3
- evalscope/collections/evaluator.py +18 -6
- evalscope/config.py +8 -2
- evalscope/evaluator/evaluator.py +38 -27
- evalscope/metrics/__init__.py +3 -1
- evalscope/metrics/bundled_rouge_score/rouge_scorer.py +1 -1
- evalscope/metrics/llm_judge.py +12 -5
- evalscope/metrics/math_parser.py +1 -1
- evalscope/models/adapters/server_adapter.py +2 -6
- evalscope/perf/arguments.py +2 -2
- evalscope/perf/benchmark.py +0 -9
- evalscope/perf/main.py +7 -0
- evalscope/perf/plugin/datasets/custom.py +15 -0
- evalscope/perf/utils/benchmark_util.py +1 -1
- evalscope/perf/utils/local_server.py +1 -0
- evalscope/perf/utils/log_utils.py +12 -5
- evalscope/perf/utils/rich_display.py +1 -1
- evalscope/report/__init__.py +36 -4
- evalscope/report/combinator.py +8 -0
- evalscope/report/generator.py +33 -9
- evalscope/report/utils.py +60 -3
- evalscope/run.py +12 -0
- evalscope/utils/logger.py +1 -1
- evalscope/utils/utils.py +12 -0
- evalscope/version.py +2 -2
- {evalscope-0.16.0.dist-info → evalscope-0.16.1.dist-info}/METADATA +13 -11
- {evalscope-0.16.0.dist-info → evalscope-0.16.1.dist-info}/RECORD +61 -50
- tests/aigc/test_t2i.py +40 -3
- tests/cli/test_all.py +39 -35
- tests/cli/test_collection.py +7 -6
- tests/cli/test_run.py +21 -11
- tests/rag/test_mteb.py +5 -5
- /evalscope/{report/app_arguments.py → app/arguments.py} +0 -0
- {evalscope-0.16.0.dist-info → evalscope-0.16.1.dist-info}/LICENSE +0 -0
- {evalscope-0.16.0.dist-info → evalscope-0.16.1.dist-info}/WHEEL +0 -0
- {evalscope-0.16.0.dist-info → evalscope-0.16.1.dist-info}/entry_points.txt +0 -0
- {evalscope-0.16.0.dist-info → evalscope-0.16.1.dist-info}/top_level.txt +0 -0
evalscope/app/__init__.py
ADDED
@@ -0,0 +1,28 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from evalscope.utils.import_utils import _LazyModule
+
+if TYPE_CHECKING:
+    from .app import create_app
+    from .arguments import add_argument
+
+else:
+    _import_structure = {
+        'app': [
+            'create_app',
+        ],
+        'arguments': [
+            'add_argument',
+        ],
+    }
+
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()['__file__'],
+        _import_structure,
+        module_spec=__spec__,
+        extra_objects={},
+    )
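The new `evalscope/app/__init__.py` wires the package up for lazy imports. As a rough analogue of what `_LazyModule` does (its real implementation lives in `evalscope.utils.import_utils` and is not shown in this diff), a PEP 562 module-level `__getattr__` achieves the same deferred loading:

import importlib

# Assumed analogue of _LazyModule, not evalscope's actual code: map each
# public attribute to the submodule that defines it, import on first access.
_import_structure = {'app': ['create_app'], 'arguments': ['add_argument']}
_attr_to_module = {attr: mod for mod, attrs in _import_structure.items() for attr in attrs}

def __getattr__(name):
    module_name = _attr_to_module.get(name)
    if module_name is None:
        raise AttributeError(f'module {__name__!r} has no attribute {name!r}')
    # Import the submodule lazily, then cache the resolved attribute.
    module = importlib.import_module(f'.{module_name}', __name__)
    value = getattr(module, name)
    globals()[name] = value
    return value

The effect is that importing `evalscope.app` stays cheap; the heavy Gradio-based `app` submodule only loads when `create_app` is first accessed.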
evalscope/{report → app}/app.py
RENAMED
@@ -11,35 +11,15 @@ from dataclasses import dataclass
 from typing import Any, List, Union
 
 from evalscope.constants import DataCollection
-from evalscope.report import Report, ReportKey,
+from evalscope.report import Report, ReportKey, get_data_frame, get_report_list
 from evalscope.utils.io_utils import OutputsStructure, yaml_to_dict
 from evalscope.utils.logger import configure_logging, get_logger
 from evalscope.version import __version__
+from .arguments import add_argument
+from .constants import DATASET_TOKEN, LATEX_DELIMITERS, MODEL_TOKEN, PLOTLY_THEME, REPORT_TOKEN
 
 logger = get_logger()
 
-PLOTLY_THEME = 'plotly_dark'
-REPORT_TOKEN = '@@'
-MODEL_TOKEN = '::'
-DATASET_TOKEN = ', '
-LATEX_DELIMITERS = [{
-    'left': '$$',
-    'right': '$$',
-    'display': True
-}, {
-    'left': '$',
-    'right': '$',
-    'display': False
-}, {
-    'left': '\\(',
-    'right': '\\)',
-    'display': False
-}, {
-    'left': '\\[',
-    'right': '\\]',
-    'display': True
-}]
-
 
 def scan_for_report_folders(root_path):
     """Scan for folders containing reports subdirectories"""
@@ -185,6 +165,13 @@ def get_single_dataset_df(df: pd.DataFrame, dataset_name: str):
     return df, styler
 
 
+def get_report_analysis(report_list: List[Report], dataset_name: str) -> str:
+    for report in report_list:
+        if report.dataset_name == dataset_name:
+            return report.analysis
+    return 'N/A'
+
+
 def plot_single_dataset_scores(df: pd.DataFrame):
     # TODO: add metric radio and relace category name
     plot = px.bar(
@@ -456,6 +443,10 @@ def create_single_model_tab(sidebar: SidebarComponents, lang: str):
             'zh': '数据集分数',
            'en': 'Dataset Scores'
         },
+        'report_analysis': {
+            'zh': '报告智能分析',
+            'en': 'Report Intelligent Analysis'
+        },
         'dataset_scores_table': {
             'zh': '数据集分数表',
             'en': 'Dataset Scores Table'
@@ -511,6 +502,9 @@ def create_single_model_tab(sidebar: SidebarComponents, lang: str):
         with gr.Tab(locale_dict['dataset_details'][lang]):
             dataset_radio = gr.Radio(
                 label=locale_dict['select_dataset'][lang], choices=[], show_label=True, interactive=True)
+            # show dataset details
+            with gr.Accordion(locale_dict['report_analysis'][lang], open=True):
+                report_analysis = gr.Markdown(value='N/A', show_copy_button=True)
             gr.Markdown(f'### {locale_dict["dataset_scores"][lang]}')
             dataset_plot = gr.Plot(value=None, scale=1, label=locale_dict['dataset_scores'][lang])
             gr.Markdown(f'### {locale_dict["dataset_scores_table"][lang]}')
@@ -586,15 +580,16 @@ def create_single_model_tab(sidebar: SidebarComponents, lang: str):
         @gr.on(
             triggers=[dataset_radio.change, report_list.change],
             inputs=[dataset_radio, report_list],
-            outputs=[dataset_plot, dataset_table, subset_select, data_review_df])
+            outputs=[dataset_plot, dataset_table, subset_select, data_review_df, report_analysis])
         def update_single_report_dataset(dataset_name, report_list):
             logger.debug(f'Updating single report dataset: {dataset_name}')
             report_df = get_data_frame(report_list)
+            analysis = get_report_analysis(report_list, dataset_name)
             data_score_df, styler = get_single_dataset_df(report_df, dataset_name)
             data_score_plot = plot_single_dataset_scores(data_score_df)
             subsets = data_score_df[ReportKey.subset_name].unique().tolist()
             logger.debug(f'subsets: {subsets}')
-            return data_score_plot, styler, gr.update(choices=subsets, value=None), None
+            return data_score_plot, styler, gr.update(choices=subsets, value=None), None, analysis
 
         @gr.on(
             triggers=[subset_select.change],
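The app.py change above threads one more output (`report_analysis`) through the Gradio event handler. A minimal, self-contained sketch of that pattern — not evalscope code; labels, choices, and the analysis text are placeholders:

import gradio as gr

with gr.Blocks() as demo:
    dataset_radio = gr.Radio(label='Select dataset', choices=['gsm8k', 'mmlu'])
    # The diff adds an Accordion holding a Markdown pane for the judge-written analysis.
    with gr.Accordion('Report Intelligent Analysis', open=True):
        report_analysis = gr.Markdown(value='N/A')

    def on_dataset_change(dataset_name):
        # The real handler looks up Report.analysis for the selected dataset
        # (see get_report_analysis above); this stub just echoes the choice.
        return f'Analysis for **{dataset_name}** goes here.'

    # Returning one value per declared output updates the Markdown pane.
    dataset_radio.change(on_dataset_change, inputs=dataset_radio, outputs=report_analysis)

demo.launch()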
evalscope/app/constants.py
ADDED
@@ -0,0 +1,21 @@
+PLOTLY_THEME = 'plotly_dark'
+REPORT_TOKEN = '@@'
+MODEL_TOKEN = '::'
+DATASET_TOKEN = ', '
+LATEX_DELIMITERS = [{
+    'left': '$$',
+    'right': '$$',
+    'display': True
+}, {
+    'left': '$',
+    'right': '$',
+    'display': False
+}, {
+    'left': '\\(',
+    'right': '\\)',
+    'display': False
+}, {
+    'left': '\\[',
+    'right': '\\]',
+    'display': True
+}]
evalscope/arguments.py
CHANGED
@@ -67,7 +67,7 @@ def add_argument(parser: argparse.ArgumentParser):
     parser.add_argument('--eval-config', type=str, required=False, help='The eval task config file path for evaluation backend.') # noqa: E501
     parser.add_argument('--stage', type=str, default='all', help='The stage of evaluation pipeline.',
                         choices=[EvalStage.ALL, EvalStage.INFER, EvalStage.REVIEW])
-    parser.add_argument('--limit', type=
+    parser.add_argument('--limit', type=float, default=None, help='Max evaluation samples num for each subset.')
     parser.add_argument('--eval-batch-size', type=int, default=1, help='The batch size for evaluation.')
 
     # Cache and working directory arguments
@@ -89,6 +89,7 @@ def add_argument(parser: argparse.ArgumentParser):
     parser.add_argument('--judge-strategy', type=str, default=JudgeStrategy.AUTO, help='The judge strategy.')
     parser.add_argument('--judge-model-args', type=json.loads, default='{}', help='The judge model args, should be a json string.') # noqa: E501
     parser.add_argument('--judge-worker-num', type=int, default=1, help='The number of workers for the judge model.')
+    parser.add_argument('--analysis-report', action='store_true', default=False, help='Generate analysis report for the evaluation results using judge model.') # noqa: E501
     # yapf: enable
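Two flag-level changes here: `--limit` is now parsed as a float, and a new `--analysis-report` switch is introduced. A small argparse check of how they parse (the fractional-limit interpretation is an assumption based on the type change, not documented in this diff):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--limit', type=float, default=None,
                    help='Max evaluation samples num for each subset.')
parser.add_argument('--analysis-report', action='store_true', default=False)

args = parser.parse_args(['--limit', '0.5', '--analysis-report'])
assert args.limit == 0.5              # float permits fractional limits, presumably a share of a subset
assert args.analysis_report is True   # argparse maps '--analysis-report' to dest 'analysis_report'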
evalscope/backend/opencompass/backend_manager.py
CHANGED
@@ -1,4 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import os
 import subprocess
 import tempfile
 from dataclasses import asdict
@@ -204,7 +205,7 @@ class OpenCompassBackendManager(BackendManager):
             model_d['meta_template'] = get_template(model_d['meta_template'])
 
             # set the 'abbr' as the 'path' if 'abbr' is not specified
-            model_d['abbr'] = model_d['path']
+            model_d['abbr'] = os.path.basename(model_d['path'])
 
             model_config = ApiModelConfig(**model_d)
             models.append(asdict(model_config))
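The backend_manager.py change shortens the OpenCompass model abbreviation to the last path component, which matters for hub-style paths (the example path is illustrative):

import os

# Before: abbr == 'Qwen/Qwen2.5-7B-Instruct'; after: just the model name.
print(os.path.basename('Qwen/Qwen2.5-7B-Instruct'))  # -> Qwen2.5-7B-Instruct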
evalscope/backend/rag_eval/cmteb/arguments.py
CHANGED
@@ -11,7 +11,9 @@ class ModelArguments:
     pooling_mode: Optional[str] = None
     max_seq_length: int = 512 # max sequence length
     # prompt for llm based model
-    prompt: str =
+    prompt: Optional[str] = None
+    # prompts dictionary for different tasks, if prompt is not set
+    prompts: Optional[Dict[str, str]] = None
     # model kwargs
     model_kwargs: dict = field(default_factory=dict)
     # config kwargs
@@ -33,6 +35,7 @@ class ModelArguments:
             'pooling_mode': self.pooling_mode,
             'max_seq_length': self.max_seq_length,
             'prompt': self.prompt,
+            'prompts': self.prompts,
             'model_kwargs': self.model_kwargs,
             'config_kwargs': self.config_kwargs,
             'encode_kwargs': self.encode_kwargs,
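`prompts` complements `prompt`: a single prompt applies to every task, otherwise the per-task mapping is consulted. The selection logic appears verbatim as `BaseModel.get_prompt` in the embedding.py diff below; a standalone restatement (the task name and prompt strings are placeholders):

from typing import Dict, Optional

def get_prompt(prompt: Optional[str], prompts: Dict[str, str], task_name: str) -> Optional[str]:
    # A task-independent prompt always wins; fall back to the per-task table.
    if prompt:
        return prompt
    return prompts.get(task_name)

assert get_prompt(None, {'T2Retrieval': 'query: '}, 'T2Retrieval') == 'query: '
assert get_prompt('global: ', {'T2Retrieval': 'query: '}, 'T2Retrieval') == 'global: '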
evalscope/backend/rag_eval/cmteb/task_template.py
CHANGED
@@ -1,6 +1,6 @@
 import mteb
 import os
-from
+from tabulate import tabulate
 
 from evalscope.backend.rag_eval import EmbeddingModel, cmteb
 from evalscope.utils.logger import get_logger
@@ -12,14 +12,27 @@ def show_results(output_folder, model, results):
     model_name = model.mteb_model_meta.model_name_as_path()
     revision = model.mteb_model_meta.revision
 
-
+    data = []
+    for model_res in results:
+        main_res = model_res.only_main_score()
+        for split, score in main_res.scores.items():
+            for sub_score in score:
+                data.append({
+                    'Model': model_name.replace('eval__', ''),
+                    'Revision': revision,
+                    'Task Type': main_res.task_type,
+                    'Task': main_res.task_name,
+                    'Split': split,
+                    'Subset': sub_score['hf_subset'],
+                    'Main Score': sub_score['main_score'],
+                })
 
     save_path = os.path.join(
         output_folder,
         model_name,
         revision,
     )
-    logger.info(f'Evaluation results:\n{
+    logger.info(f'Evaluation results:\n{tabulate(data, headers="keys", tablefmt="grid")}')
     logger.info(f'Evaluation results saved in {os.path.abspath(save_path)}')
 
 
@@ -34,6 +47,7 @@ def one_stage_eval(
     tasks = cmteb.TaskBase.get_tasks(task_names=eval_args['tasks'], dataset_path=custom_dataset_path)
     evaluation = mteb.MTEB(tasks=tasks)
 
+    eval_args['encode_kwargs'] = model_args.get('encode_kwargs', {})
     # run evaluation
     results = evaluation.run(model, **eval_args)
 
@@ -66,6 +80,7 @@ def two_stage_eval(
         overwrite_results=True,
         hub=eval_args['hub'],
         limits=eval_args['limits'],
+        encode_kwargs=model1_args.get('encode_kwargs', {}),
     )
     # stage 2: run cross encoder
     results = evaluation.run(
@@ -77,6 +92,7 @@ def two_stage_eval(
         overwrite_results=True,
         hub=eval_args['hub'],
         limits=eval_args['limits'],
+        encode_kwargs=model2_args.get('encode_kwargs', {}),
     )
 
     # save and log results
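show_results now flattens each MTEB result into one row per (task, split, subset) and renders it with tabulate. A standalone illustration of that table shape (all values are made up):

from tabulate import tabulate

data = [{
    'Model': 'bge-small-zh',        # placeholder model
    'Revision': 'master',
    'Task Type': 'Classification',
    'Task': 'TNews',
    'Split': 'validation',
    'Subset': 'default',
    'Main Score': 0.512,
}]
# headers='keys' takes column names from the dict keys, as in the diff above.
print(tabulate(data, headers='keys', tablefmt='grid'))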
evalscope/backend/rag_eval/cmteb/tasks/CustomTask.py
CHANGED
@@ -9,7 +9,6 @@ class CustomRetrieval(AbsTaskRetrieval):
     ignore_identical_ids: bool = True
 
     def __init__(self, dataset_path: Optional[str] = 'custom_eval/text/retrieval', **kwargs):
-        super().__init__(**kwargs)
         self.metadata = TaskMetadata(
             name='CustomRetrieval',
             description='CustomRetrieval Task',
@@ -34,6 +33,7 @@ class CustomRetrieval(AbsTaskRetrieval):
             bibtex_citation='',
             descriptive_stats={},
         )
+        super().__init__(**kwargs)
 
     def load_data(self, **kwargs):
         if self.data_loaded:
evalscope/backend/rag_eval/utils/embedding.py
CHANGED
@@ -2,6 +2,7 @@ import os
 import torch
 from langchain_core.embeddings import Embeddings
 from langchain_openai.embeddings import OpenAIEmbeddings
+from mteb.encoder_interface import PromptType
 from sentence_transformers import models
 from sentence_transformers.cross_encoder import CrossEncoder
 from sentence_transformers.SentenceTransformer import SentenceTransformer
@@ -12,6 +13,7 @@ from typing import Dict, List, Optional, Union
 from evalscope.backend.rag_eval.utils.tools import download_model
 from evalscope.constants import HubType
 from evalscope.utils.logger import get_logger
+from evalscope.utils.utils import get_supported_params
 
 logger = get_logger()
 
@@ -22,14 +24,14 @@ class BaseModel(Embeddings):
         self,
         model_name_or_path: str = '',
         max_seq_length: int = 512,
-        prompt: str =
+        prompt: Optional[str] = None,
+        prompts: Optional[Dict[str, str]] = None,
         revision: Optional[str] = 'master',
         **kwargs,
     ):
         self.model_name_or_path = model_name_or_path
         self.max_seq_length = max_seq_length
         self.model_kwargs = kwargs.pop('model_kwargs', {})
-        self.model_kwargs['trust_remote_code'] = True
 
         self.config_kwargs = kwargs.pop('config_kwargs', {})
         self.config_kwargs['trust_remote_code'] = True
@@ -38,7 +40,9 @@ class BaseModel(Embeddings):
         self.encode_kwargs['convert_to_tensor'] = True
 
         self.prompt = prompt
+        self.prompts = prompts if prompts else {}
         self.revision = revision
+        self.framework = ['PyTorch']
 
     @property
     def mteb_model_meta(self):
@@ -46,10 +50,22 @@ class BaseModel(Embeddings):
         from mteb import ModelMeta
 
         return ModelMeta(
-            name=os.path.basename(self.model_name_or_path),
+            name='eval/' + os.path.basename(self.model_name_or_path),  # Ensure the name contains a slash
            revision=self.revision,
             languages=None,
             release_date=None,
+            n_parameters=None,
+            memory_usage_mb=None,
+            max_tokens=None,
+            embed_dim=None,
+            license=None,
+            open_weights=None,
+            public_training_code=None,
+            public_training_data=None,
+            similarity_fn_name=None,
+            use_instructions=None,
+            training_datasets=None,
+            framework=self.framework,
         )
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
@@ -61,7 +77,7 @@ class BaseModel(Embeddings):
         Returns:
             List of embeddings.
         """
-        return self.
+        return self.encode(texts).tolist()
 
     def embed_query(self, text: str) -> List[float]:
         """Embed query text. Compact langchain.
@@ -72,19 +88,17 @@ class BaseModel(Embeddings):
         Returns:
             Embedding.
         """
-        return self.
+        return self.encode(text).tolist()
 
     def encode(self, texts: Union[str, List[str]], **kwargs) -> List[List[float]]:
         """Embed text."""
         raise NotImplementedError
 
-    def
-        """
-
-
-
-        """Embed search docs . Compact mteb."""
-        raise NotImplementedError
+    def get_prompt(self, task_name: str) -> Optional[str]:
+        """Get prompt for the given task name."""
+        if self.prompt:
+            return self.prompt
+        return self.prompts.get(task_name, None)
 
 
 class SentenceTransformerModel(BaseModel):
@@ -92,6 +106,9 @@ class SentenceTransformerModel(BaseModel):
     def __init__(self, model_name_or_path: str, pooling_mode: Optional[str] = None, **kwargs):
         super().__init__(model_name_or_path, **kwargs)
 
+        self.framework = ['Sentence Transformers', 'PyTorch']
+
+        self.model_kwargs['trust_remote_code'] = True
         if not pooling_mode:
             self.model = SentenceTransformer(
                 self.model_name_or_path,
@@ -112,36 +129,47 @@ class SentenceTransformerModel(BaseModel):
 
         self.model.max_seq_length = self.max_seq_length
 
-
-
+        self.supported_encode_params = get_supported_params(self.model.encode)
+
+    def encode(self, texts: Union[str, List[str]], **kwargs) -> List[torch.Tensor]:
+        # pop unused kwargs
+        extra_params = {}
+        for key in list(kwargs.keys()):
+            if key not in self.supported_encode_params:
+                extra_params[key] = kwargs.pop(key)
         self.encode_kwargs.update(kwargs)
+
+        # set prompt if provided
+        prompt = None
+        prompt_type = extra_params.pop('prompt_type', '')
+        task_name = extra_params.pop('task_name', '')
+        if prompt_type and prompt_type == PromptType.query:
+            prompt = self.get_prompt(task_name)
+
         embeddings = self.model.encode(texts, prompt=prompt, **self.encode_kwargs)
         assert isinstance(embeddings, Tensor)
         return embeddings.cpu().detach()
 
-    def encode_queries(self, queries, **kwargs):
-        return self.encode(queries, prompt=self.prompt)
-
-    def encode_corpus(self, corpus, **kwargs):
-        if isinstance(corpus[0], dict):
-            input_texts = ['{} {}'.format(doc.get('title', ''), doc['text']).strip() for doc in corpus]
-        else:
-            input_texts = corpus
-        return self.encode(input_texts)
-
 
 class CrossEncoderModel(BaseModel):
 
     def __init__(self, model_name_or_path: str, **kwargs):
         super().__init__(model_name_or_path, **kwargs)
+
+        self.framework = ['Sentence Transformers', 'PyTorch']
+
         self.model = CrossEncoder(
             self.model_name_or_path,
             trust_remote_code=True,
             max_length=self.max_seq_length,
+            automodel_args=self.model_kwargs,
         )
+        self.supported_encode_params = get_supported_params(self.model.predict)
 
     def predict(self, sentences: List[List[str]], **kwargs) -> Tensor:
+        for key in list(kwargs.keys()):
+            if key not in self.supported_encode_params:
+                kwargs.pop(key)
         self.encode_kwargs.update(kwargs)
 
         if len(sentences[0]) == 3:  # Note: For mteb retrieval task
@@ -163,6 +191,7 @@ class APIEmbeddingModel(BaseModel):
         self.openai_api_base = kwargs.get('api_base')
         self.openai_api_key = kwargs.get('api_key')
         self.dimensions = kwargs.get('dimensions')
+        self.framework = ['API']
 
         self.model = OpenAIEmbeddings(
             model=self.model_name,
@@ -175,26 +204,37 @@ class APIEmbeddingModel(BaseModel):
 
         self.batch_size = self.encode_kwargs.get('batch_size', 10)
 
+        self.supported_encode_params = get_supported_params(self.model.embed_documents)
+
     def encode(self, texts: Union[str, List[str]], **kwargs) -> Tensor:
+        # pop unused kwargs
+        extra_params = {}
+        for key in list(kwargs.keys()):
+            if key not in self.supported_encode_params:
+                extra_params[key] = kwargs.pop(key)
+        self.encode_kwargs.update(kwargs)
+
+        # set prompt if provided
+        prompt = None
+        prompt_type = extra_params.pop('prompt_type', '')
+        task_name = extra_params.pop('task_name', '')
+        if prompt_type and prompt_type == PromptType.query:
+            prompt = self.get_prompt(task_name)
+
         if isinstance(texts, str):
             texts = [texts]
 
         embeddings: List[List[float]] = []
         for i in tqdm(range(0, len(texts), self.batch_size)):
-
+            # set prompt if provided
+            if prompt is not None:
+                batch_texts = [prompt + text for text in texts[i:i + self.batch_size]]
+            else:
+                batch_texts = texts[i:i + self.batch_size]
+            response = self.model.embed_documents(batch_texts, chunk_size=self.batch_size)
             embeddings.extend(response)
         return torch.tensor(embeddings)
 
-    def encode_queries(self, queries, **kwargs):
-        return self.encode(queries, **kwargs)
-
-    def encode_corpus(self, corpus, **kwargs):
-        if isinstance(corpus[0], dict):
-            input_texts = ['{} {}'.format(doc.get('title', ''), doc['text']).strip() for doc in corpus]
-        else:
-            input_texts = corpus
-        return self.encode(input_texts, **kwargs)
-
 
 class EmbeddingModel:
     """Custom embeddings"""
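Several call sites above filter kwargs through `get_supported_params`, which is added to `evalscope/utils/utils.py` in this release (+12 lines) but whose body is not shown in this diff. Judging by usage, it returns the parameter names a callable accepts; a plausible sketch, assuming signature introspection:

import inspect
from typing import Callable, List

def get_supported_params(func: Callable) -> List[str]:
    """Assumed behavior: list the parameter names `func` accepts, so callers
    can split kwargs into supported ones and extras like prompt_type/task_name."""
    return list(inspect.signature(func).parameters.keys())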
evalscope/benchmarks/benchmark.py
CHANGED
@@ -28,6 +28,7 @@ class BenchmarkMeta:
     system_prompt: Optional[str] = None
     query_template: Optional[str] = None
     pretty_name: Optional[str] = None
+    description: Optional[str] = None
     filters: Optional[OrderedDict] = None
     extra_params: Optional[Dict] = field(default_factory=dict)