evalscope 0.16.0__py3-none-any.whl → 0.16.1__py3-none-any.whl

This diff shows the changes between publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.

Potentially problematic release. This version of evalscope might be problematic.
Files changed (61)
  1. evalscope/app/__init__.py +28 -0
  2. evalscope/{report → app}/app.py +20 -25
  3. evalscope/app/constants.py +21 -0
  4. evalscope/arguments.py +2 -1
  5. evalscope/backend/opencompass/backend_manager.py +2 -1
  6. evalscope/backend/rag_eval/cmteb/arguments.py +4 -1
  7. evalscope/backend/rag_eval/cmteb/task_template.py +19 -3
  8. evalscope/backend/rag_eval/cmteb/tasks/CustomTask.py +1 -1
  9. evalscope/backend/rag_eval/utils/embedding.py +75 -35
  10. evalscope/benchmarks/benchmark.py +1 -0
  11. evalscope/benchmarks/data_adapter.py +97 -16
  12. evalscope/benchmarks/docmath/__init__.py +0 -0
  13. evalscope/benchmarks/docmath/docmath_adapter.py +84 -0
  14. evalscope/benchmarks/docmath/utils.py +220 -0
  15. evalscope/benchmarks/frames/__init__.py +0 -0
  16. evalscope/benchmarks/frames/frames_adapter.py +90 -0
  17. evalscope/benchmarks/frames/utils.py +37 -0
  18. evalscope/benchmarks/needle_haystack/__init__.py +0 -0
  19. evalscope/benchmarks/needle_haystack/needle_haystack_adapter.py +341 -0
  20. evalscope/benchmarks/needle_haystack/utils.py +79 -0
  21. evalscope/benchmarks/tool_bench/tool_bench_adapter.py +4 -1
  22. evalscope/benchmarks/tool_bench/utils.py +5 -4
  23. evalscope/benchmarks/utils.py +25 -0
  24. evalscope/cli/start_app.py +2 -2
  25. evalscope/collections/__init__.py +35 -3
  26. evalscope/collections/evaluator.py +18 -6
  27. evalscope/config.py +8 -2
  28. evalscope/evaluator/evaluator.py +38 -27
  29. evalscope/metrics/__init__.py +3 -1
  30. evalscope/metrics/bundled_rouge_score/rouge_scorer.py +1 -1
  31. evalscope/metrics/llm_judge.py +12 -5
  32. evalscope/metrics/math_parser.py +1 -1
  33. evalscope/models/adapters/server_adapter.py +2 -6
  34. evalscope/perf/arguments.py +2 -2
  35. evalscope/perf/benchmark.py +0 -9
  36. evalscope/perf/main.py +7 -0
  37. evalscope/perf/plugin/datasets/custom.py +15 -0
  38. evalscope/perf/utils/benchmark_util.py +1 -1
  39. evalscope/perf/utils/local_server.py +1 -0
  40. evalscope/perf/utils/log_utils.py +12 -5
  41. evalscope/perf/utils/rich_display.py +1 -1
  42. evalscope/report/__init__.py +36 -4
  43. evalscope/report/combinator.py +8 -0
  44. evalscope/report/generator.py +33 -9
  45. evalscope/report/utils.py +60 -3
  46. evalscope/run.py +12 -0
  47. evalscope/utils/logger.py +1 -1
  48. evalscope/utils/utils.py +12 -0
  49. evalscope/version.py +2 -2
  50. {evalscope-0.16.0.dist-info → evalscope-0.16.1.dist-info}/METADATA +13 -11
  51. {evalscope-0.16.0.dist-info → evalscope-0.16.1.dist-info}/RECORD +61 -50
  52. tests/aigc/test_t2i.py +40 -3
  53. tests/cli/test_all.py +39 -35
  54. tests/cli/test_collection.py +7 -6
  55. tests/cli/test_run.py +21 -11
  56. tests/rag/test_mteb.py +5 -5
  57. /evalscope/{report/app_arguments.py → app/arguments.py} +0 -0
  58. {evalscope-0.16.0.dist-info → evalscope-0.16.1.dist-info}/LICENSE +0 -0
  59. {evalscope-0.16.0.dist-info → evalscope-0.16.1.dist-info}/WHEEL +0 -0
  60. {evalscope-0.16.0.dist-info → evalscope-0.16.1.dist-info}/entry_points.txt +0 -0
  61. {evalscope-0.16.0.dist-info → evalscope-0.16.1.dist-info}/top_level.txt +0 -0

evalscope/app/__init__.py ADDED

@@ -0,0 +1,28 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from evalscope.utils.import_utils import _LazyModule
+
+if TYPE_CHECKING:
+    from .app import create_app
+    from .arguments import add_argument
+
+else:
+    _import_structure = {
+        'app': [
+            'create_app',
+        ],
+        'arguments': [
+            'add_argument',
+        ],
+    }
+
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()['__file__'],
+        _import_structure,
+        module_spec=__spec__,
+        extra_objects={},
+    )
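
The new package init routes imports through evalscope's _LazyModule helper, so evalscope.app should only pull in its submodules (and their heavier dependencies) when an attribute is first accessed. A minimal usage sketch of how a caller sees this, assuming the usual lazy-module behavior rather than anything verified against _LazyModule's implementation:

import evalscope.app as app_pkg      # importing the package itself stays cheap (assumed)

create_app = app_pkg.create_app      # first access triggers the import of evalscope.app.app
add_argument = app_pkg.add_argument  # likewise for evalscope.app.arguments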

evalscope/{report → app}/app.py RENAMED

@@ -11,35 +11,15 @@ from dataclasses import dataclass
 from typing import Any, List, Union
 
 from evalscope.constants import DataCollection
-from evalscope.report import Report, ReportKey, add_argument, get_data_frame, get_report_list
+from evalscope.report import Report, ReportKey, get_data_frame, get_report_list
 from evalscope.utils.io_utils import OutputsStructure, yaml_to_dict
 from evalscope.utils.logger import configure_logging, get_logger
 from evalscope.version import __version__
+from .arguments import add_argument
+from .constants import DATASET_TOKEN, LATEX_DELIMITERS, MODEL_TOKEN, PLOTLY_THEME, REPORT_TOKEN
 
 logger = get_logger()
 
-PLOTLY_THEME = 'plotly_dark'
-REPORT_TOKEN = '@@'
-MODEL_TOKEN = '::'
-DATASET_TOKEN = ', '
-LATEX_DELIMITERS = [{
-    'left': '$$',
-    'right': '$$',
-    'display': True
-}, {
-    'left': '$',
-    'right': '$',
-    'display': False
-}, {
-    'left': '\\(',
-    'right': '\\)',
-    'display': False
-}, {
-    'left': '\\[',
-    'right': '\\]',
-    'display': True
-}]
-
 
 def scan_for_report_folders(root_path):
     """Scan for folders containing reports subdirectories"""
@@ -185,6 +165,13 @@ def get_single_dataset_df(df: pd.DataFrame, dataset_name: str):
     return df, styler
 
 
+def get_report_analysis(report_list: List[Report], dataset_name: str) -> str:
+    for report in report_list:
+        if report.dataset_name == dataset_name:
+            return report.analysis
+    return 'N/A'
+
+
 def plot_single_dataset_scores(df: pd.DataFrame):
     # TODO: add metric radio and relace category name
     plot = px.bar(
@@ -456,6 +443,10 @@ def create_single_model_tab(sidebar: SidebarComponents, lang: str):
             'zh': '数据集分数',
            'en': 'Dataset Scores'
         },
+        'report_analysis': {
+            'zh': '报告智能分析',
+            'en': 'Report Intelligent Analysis'
+        },
         'dataset_scores_table': {
             'zh': '数据集分数表',
             'en': 'Dataset Scores Table'
@@ -511,6 +502,9 @@ def create_single_model_tab(sidebar: SidebarComponents, lang: str):
         with gr.Tab(locale_dict['dataset_details'][lang]):
             dataset_radio = gr.Radio(
                 label=locale_dict['select_dataset'][lang], choices=[], show_label=True, interactive=True)
+            # show dataset details
+            with gr.Accordion(locale_dict['report_analysis'][lang], open=True):
+                report_analysis = gr.Markdown(value='N/A', show_copy_button=True)
             gr.Markdown(f'### {locale_dict["dataset_scores"][lang]}')
             dataset_plot = gr.Plot(value=None, scale=1, label=locale_dict['dataset_scores'][lang])
             gr.Markdown(f'### {locale_dict["dataset_scores_table"][lang]}')
@@ -586,15 +580,16 @@ def create_single_model_tab(sidebar: SidebarComponents, lang: str):
     @gr.on(
         triggers=[dataset_radio.change, report_list.change],
         inputs=[dataset_radio, report_list],
-        outputs=[dataset_plot, dataset_table, subset_select, data_review_df])
+        outputs=[dataset_plot, dataset_table, subset_select, data_review_df, report_analysis])
     def update_single_report_dataset(dataset_name, report_list):
         logger.debug(f'Updating single report dataset: {dataset_name}')
         report_df = get_data_frame(report_list)
+        analysis = get_report_analysis(report_list, dataset_name)
         data_score_df, styler = get_single_dataset_df(report_df, dataset_name)
         data_score_plot = plot_single_dataset_scores(data_score_df)
         subsets = data_score_df[ReportKey.subset_name].unique().tolist()
         logger.debug(f'subsets: {subsets}')
-        return data_score_plot, styler, gr.update(choices=subsets, value=None), None
+        return data_score_plot, styler, gr.update(choices=subsets, value=None), None, analysis
 
     @gr.on(
         triggers=[subset_select.change],
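
The new get_report_analysis helper only depends on the dataset_name and analysis attributes of Report, and falls back to 'N/A' when no report matches. A standalone sketch of that lookup; the FakeReport stand-in and the sample text below are illustrative only, not evalscope objects:

from dataclasses import dataclass
from typing import List

@dataclass
class FakeReport:                      # stand-in exposing the two attributes the helper reads
    dataset_name: str
    analysis: str

def get_report_analysis(report_list: List[FakeReport], dataset_name: str) -> str:
    for report in report_list:
        if report.dataset_name == dataset_name:
            return report.analysis
    return 'N/A'

reports = [FakeReport('gsm8k', 'toy analysis text')]
print(get_report_analysis(reports, 'gsm8k'))  # -> 'toy analysis text'
print(get_report_analysis(reports, 'math'))   # -> 'N/A'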

evalscope/app/constants.py ADDED

@@ -0,0 +1,21 @@
+PLOTLY_THEME = 'plotly_dark'
+REPORT_TOKEN = '@@'
+MODEL_TOKEN = '::'
+DATASET_TOKEN = ', '
+LATEX_DELIMITERS = [{
+    'left': '$$',
+    'right': '$$',
+    'display': True
+}, {
+    'left': '$',
+    'right': '$',
+    'display': False
+}, {
+    'left': '\\(',
+    'right': '\\)',
+    'display': False
+}, {
+    'left': '\\[',
+    'right': '\\]',
+    'display': True
+}]
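
These are the constants previously defined inline in app.py (removed in the first app.py hunk). The LATEX_DELIMITERS list has the shape Gradio's Markdown component accepts for its latex_delimiters argument; a minimal sketch of that kind of use, offered as an assumption since the actual call site is not shown in this diff:

import gradio as gr
from evalscope.app.constants import LATEX_DELIMITERS

# Assumed usage: let Markdown render $...$, $$...$$, \(...\) and \[...\] as math.
demo_md = gr.Markdown(value=r'Score: $\frac{3}{4}$', latex_delimiters=LATEX_DELIMITERS)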
evalscope/arguments.py CHANGED
@@ -67,7 +67,7 @@ def add_argument(parser: argparse.ArgumentParser):
     parser.add_argument('--eval-config', type=str, required=False, help='The eval task config file path for evaluation backend.')  # noqa: E501
     parser.add_argument('--stage', type=str, default='all', help='The stage of evaluation pipeline.',
                         choices=[EvalStage.ALL, EvalStage.INFER, EvalStage.REVIEW])
-    parser.add_argument('--limit', type=int, default=None, help='Max evaluation samples num for each subset.')
+    parser.add_argument('--limit', type=float, default=None, help='Max evaluation samples num for each subset.')
     parser.add_argument('--eval-batch-size', type=int, default=1, help='The batch size for evaluation.')
 
     # Cache and working directory arguments
@@ -89,6 +89,7 @@ def add_argument(parser: argparse.ArgumentParser):
     parser.add_argument('--judge-strategy', type=str, default=JudgeStrategy.AUTO, help='The judge strategy.')
     parser.add_argument('--judge-model-args', type=json.loads, default='{}', help='The judge model args, should be a json string.')  # noqa: E501
     parser.add_argument('--judge-worker-num', type=int, default=1, help='The number of workers for the judge model.')
+    parser.add_argument('--analysis-report', action='store_true', default=False, help='Generate analysis report for the evaluation results using judge model.')  # noqa: E501
     # yapf: enable
 
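
With these two changes, --limit is parsed as a float (so fractional limits are accepted on the command line) and --analysis-report is a plain store_true switch. A self-contained argparse snippet reproducing just these two options to show how they parse; this is not the full evalscope parser:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--limit', type=float, default=None,
                    help='Max evaluation samples num for each subset.')
parser.add_argument('--analysis-report', action='store_true', default=False,
                    help='Generate analysis report using the judge model.')

args = parser.parse_args(['--limit', '0.1', '--analysis-report'])
print(args.limit)            # 0.1 (a float; an integer such as 100 parses as 100.0)
print(args.analysis_report)  # True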

evalscope/backend/opencompass/backend_manager.py CHANGED

@@ -1,4 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import os
 import subprocess
 import tempfile
 from dataclasses import asdict
@@ -204,7 +205,7 @@ class OpenCompassBackendManager(BackendManager):
                 model_d['meta_template'] = get_template(model_d['meta_template'])
 
             # set the 'abbr' as the 'path' if 'abbr' is not specified
-            model_d['abbr'] = model_d['path']
+            model_d['abbr'] = os.path.basename(model_d['path'])
 
             model_config = ApiModelConfig(**model_d)
             models.append(asdict(model_config))
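
The abbr now keeps only the last component of the model path, which matters when 'path' is a hub id like org/model or a long filesystem path. For example (the paths below are illustrative, not taken from this diff):

import os

print(os.path.basename('Qwen/Qwen2.5-7B-Instruct'))  # 'Qwen2.5-7B-Instruct'
print(os.path.basename('/data/ckpts/my-sft-model'))  # 'my-sft-model'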

evalscope/backend/rag_eval/cmteb/arguments.py CHANGED

@@ -11,7 +11,9 @@ class ModelArguments:
     pooling_mode: Optional[str] = None
     max_seq_length: int = 512  # max sequence length
     # prompt for llm based model
-    prompt: str = ''
+    prompt: Optional[str] = None
+    # prompts dictionary for different tasks, if prompt is not set
+    prompts: Optional[Dict[str, str]] = None
     # model kwargs
     model_kwargs: dict = field(default_factory=dict)
     # config kwargs
@@ -33,6 +35,7 @@ class ModelArguments:
             'pooling_mode': self.pooling_mode,
             'max_seq_length': self.max_seq_length,
             'prompt': self.prompt,
+            'prompts': self.prompts,
             'model_kwargs': self.model_kwargs,
             'config_kwargs': self.config_kwargs,
             'encode_kwargs': self.encode_kwargs,
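
The new prompts field lets callers map individual task names to query-side prompts, while prompt keeps its role as a single prompt applied everywhere and, per get_prompt in the embedding.py hunk below, takes precedence when set. A hedged configuration sketch; the model id, task names and prompt strings are illustrative:

model_args = {
    'model_name_or_path': 'BAAI/bge-small-zh-v1.5',  # illustrative model id
    'prompt': None,                                  # no single global prompt
    'prompts': {                                     # per-task query prompts (example task names)
        'T2Retrieval': 'Instruct: Given a query, retrieve relevant passages. Query: ',
        'MMarcoRetrieval': 'Represent this question for retrieving supporting passages: ',
    },
}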

evalscope/backend/rag_eval/cmteb/task_template.py CHANGED

@@ -1,6 +1,6 @@
 import mteb
 import os
-from mteb.task_selection import results_to_dataframe
+from tabulate import tabulate
 
 from evalscope.backend.rag_eval import EmbeddingModel, cmteb
 from evalscope.utils.logger import get_logger
@@ -12,14 +12,27 @@ def show_results(output_folder, model, results):
     model_name = model.mteb_model_meta.model_name_as_path()
     revision = model.mteb_model_meta.revision
 
-    results_df = results_to_dataframe({model_name: {revision: results}})
+    data = []
+    for model_res in results:
+        main_res = model_res.only_main_score()
+        for split, score in main_res.scores.items():
+            for sub_score in score:
+                data.append({
+                    'Model': model_name.replace('eval__', ''),
+                    'Revision': revision,
+                    'Task Type': main_res.task_type,
+                    'Task': main_res.task_name,
+                    'Split': split,
+                    'Subset': sub_score['hf_subset'],
+                    'Main Score': sub_score['main_score'],
+                })
 
     save_path = os.path.join(
         output_folder,
         model_name,
         revision,
     )
-    logger.info(f'Evaluation results:\n{results_df.to_markdown()}')
+    logger.info(f'Evaluation results:\n{tabulate(data, headers="keys", tablefmt="grid")}')
     logger.info(f'Evaluation results saved in {os.path.abspath(save_path)}')
 
 
@@ -34,6 +47,7 @@ def one_stage_eval(
     tasks = cmteb.TaskBase.get_tasks(task_names=eval_args['tasks'], dataset_path=custom_dataset_path)
     evaluation = mteb.MTEB(tasks=tasks)
 
+    eval_args['encode_kwargs'] = model_args.get('encode_kwargs', {})
     # run evaluation
     results = evaluation.run(model, **eval_args)
 
@@ -66,6 +80,7 @@ def two_stage_eval(
         overwrite_results=True,
         hub=eval_args['hub'],
         limits=eval_args['limits'],
+        encode_kwargs=model1_args.get('encode_kwargs', {}),
     )
     # stage 2: run cross encoder
     results = evaluation.run(
@@ -77,6 +92,7 @@
         overwrite_results=True,
         hub=eval_args['hub'],
         limits=eval_args['limits'],
+        encode_kwargs=model2_args.get('encode_kwargs', {}),
     )
 
     # save and log results
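
show_results now builds a plain list of per-subset dicts and renders it with tabulate instead of mteb's removed results_to_dataframe helper. A standalone example of the same call with toy rows:

from tabulate import tabulate

rows = [
    {'Model': 'my-embedder', 'Task': 'T2Retrieval', 'Split': 'dev', 'Subset': 'default', 'Main Score': 0.71},
    {'Model': 'my-embedder', 'Task': 'CovidRetrieval', 'Split': 'dev', 'Subset': 'default', 'Main Score': 0.64},
]
# headers='keys' takes the column names from the dict keys; 'grid' draws ASCII borders.
print(tabulate(rows, headers='keys', tablefmt='grid'))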

evalscope/backend/rag_eval/cmteb/tasks/CustomTask.py CHANGED

@@ -9,7 +9,6 @@ class CustomRetrieval(AbsTaskRetrieval):
     ignore_identical_ids: bool = True
 
     def __init__(self, dataset_path: Optional[str] = 'custom_eval/text/retrieval', **kwargs):
-        super().__init__(**kwargs)
         self.metadata = TaskMetadata(
             name='CustomRetrieval',
             description='CustomRetrieval Task',
@@ -34,6 +33,7 @@ class CustomRetrieval(AbsTaskRetrieval):
             bibtex_citation='',
             descriptive_stats={},
         )
+        super().__init__(**kwargs)
 
     def load_data(self, **kwargs):
         if self.data_loaded:
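
Moving super().__init__(**kwargs) after the metadata assignment suggests the parent initializer reads self.metadata during construction, so setting the attribute first avoids an AttributeError. A generic sketch of that pattern with hypothetical classes (this is not mteb's actual AbsTaskRetrieval code):

class BaseTaskSketch:
    def __init__(self, **kwargs):
        # hypothetical parent that derives state from self.metadata
        self.task_name = self.metadata.name

class CustomTaskSketch(BaseTaskSketch):
    def __init__(self, **kwargs):
        self.metadata = type('Meta', (), {'name': 'CustomRetrieval'})()
        super().__init__(**kwargs)  # safe: self.metadata is already set

CustomTaskSketch()  # constructs without raising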

evalscope/backend/rag_eval/utils/embedding.py CHANGED

@@ -2,6 +2,7 @@ import os
 import torch
 from langchain_core.embeddings import Embeddings
 from langchain_openai.embeddings import OpenAIEmbeddings
+from mteb.encoder_interface import PromptType
 from sentence_transformers import models
 from sentence_transformers.cross_encoder import CrossEncoder
 from sentence_transformers.SentenceTransformer import SentenceTransformer
@@ -12,6 +13,7 @@ from typing import Dict, List, Optional, Union
 from evalscope.backend.rag_eval.utils.tools import download_model
 from evalscope.constants import HubType
 from evalscope.utils.logger import get_logger
+from evalscope.utils.utils import get_supported_params
 
 logger = get_logger()
 
@@ -22,14 +24,14 @@ class BaseModel(Embeddings):
         self,
         model_name_or_path: str = '',
         max_seq_length: int = 512,
-        prompt: str = '',
+        prompt: Optional[str] = None,
+        prompts: Optional[Dict[str, str]] = None,
         revision: Optional[str] = 'master',
         **kwargs,
     ):
         self.model_name_or_path = model_name_or_path
         self.max_seq_length = max_seq_length
         self.model_kwargs = kwargs.pop('model_kwargs', {})
-        self.model_kwargs['trust_remote_code'] = True
 
         self.config_kwargs = kwargs.pop('config_kwargs', {})
         self.config_kwargs['trust_remote_code'] = True
@@ -38,7 +40,9 @@ class BaseModel(Embeddings):
         self.encode_kwargs['convert_to_tensor'] = True
 
         self.prompt = prompt
+        self.prompts = prompts if prompts else {}
         self.revision = revision
+        self.framework = ['PyTorch']
 
     @property
     def mteb_model_meta(self):
@@ -46,10 +50,22 @@ class BaseModel(Embeddings):
         from mteb import ModelMeta
 
         return ModelMeta(
-            name=os.path.basename(self.model_name_or_path),
+            name='eval/' + os.path.basename(self.model_name_or_path),  # Ensure the name contains a slash
             revision=self.revision,
             languages=None,
             release_date=None,
+            n_parameters=None,
+            memory_usage_mb=None,
+            max_tokens=None,
+            embed_dim=None,
+            license=None,
+            open_weights=None,
+            public_training_code=None,
+            public_training_data=None,
+            similarity_fn_name=None,
+            use_instructions=None,
+            training_datasets=None,
+            framework=self.framework,
         )
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
@@ -61,7 +77,7 @@ class BaseModel(Embeddings):
         Returns:
             List of embeddings.
         """
-        return self.encode_corpus(texts).tolist()
+        return self.encode(texts).tolist()
 
     def embed_query(self, text: str) -> List[float]:
         """Embed query text. Compact langchain.
@@ -72,19 +88,17 @@ class BaseModel(Embeddings):
         Returns:
             Embedding.
         """
-        return self.encode_queries(text).tolist()
+        return self.encode(text).tolist()
 
     def encode(self, texts: Union[str, List[str]], **kwargs) -> List[List[float]]:
         """Embed text."""
         raise NotImplementedError
 
-    def encode_queries(self, queries: List[str], **kwargs) -> list[torch.Tensor]:
-        """Embed query text. Compact mteb."""
-        raise NotImplementedError
-
-    def encode_corpus(self, corpus: Union[List[str], List[Dict[str, str]]], **kwargs) -> list[torch.Tensor]:
-        """Embed search docs . Compact mteb."""
-        raise NotImplementedError
+    def get_prompt(self, task_name: str) -> Optional[str]:
+        """Get prompt for the given task name."""
+        if self.prompt:
+            return self.prompt
+        return self.prompts.get(task_name, None)
 
 
 class SentenceTransformerModel(BaseModel):
@@ -92,6 +106,9 @@ class SentenceTransformerModel(BaseModel):
     def __init__(self, model_name_or_path: str, pooling_mode: Optional[str] = None, **kwargs):
         super().__init__(model_name_or_path, **kwargs)
 
+        self.framework = ['Sentence Transformers', 'PyTorch']
+
+        self.model_kwargs['trust_remote_code'] = True
         if not pooling_mode:
             self.model = SentenceTransformer(
                 self.model_name_or_path,
@@ -112,36 +129,47 @@ class SentenceTransformerModel(BaseModel):
 
         self.model.max_seq_length = self.max_seq_length
 
-    def encode(self, texts: Union[str, List[str]], prompt=None, **kwargs) -> List[torch.Tensor]:
-        kwargs.pop('prompt_name', '')  # remove prompt name, use prompt
+        self.supported_encode_params = get_supported_params(self.model.encode)
+
+    def encode(self, texts: Union[str, List[str]], **kwargs) -> List[torch.Tensor]:
+        # pop unused kwargs
+        extra_params = {}
+        for key in list(kwargs.keys()):
+            if key not in self.supported_encode_params:
+                extra_params[key] = kwargs.pop(key)
         self.encode_kwargs.update(kwargs)
 
+        # set prompt if provided
+        prompt = None
+        prompt_type = extra_params.pop('prompt_type', '')
+        task_name = extra_params.pop('task_name', '')
+        if prompt_type and prompt_type == PromptType.query:
+            prompt = self.get_prompt(task_name)
+
         embeddings = self.model.encode(texts, prompt=prompt, **self.encode_kwargs)
         assert isinstance(embeddings, Tensor)
         return embeddings.cpu().detach()
 
-    def encode_queries(self, queries, **kwargs):
-        return self.encode(queries, prompt=self.prompt)
-
-    def encode_corpus(self, corpus, **kwargs):
-        if isinstance(corpus[0], dict):
-            input_texts = ['{} {}'.format(doc.get('title', ''), doc['text']).strip() for doc in corpus]
-        else:
-            input_texts = corpus
-        return self.encode(input_texts)
-
 
 class CrossEncoderModel(BaseModel):
 
     def __init__(self, model_name_or_path: str, **kwargs):
         super().__init__(model_name_or_path, **kwargs)
+
+        self.framework = ['Sentence Transformers', 'PyTorch']
+
         self.model = CrossEncoder(
             self.model_name_or_path,
             trust_remote_code=True,
             max_length=self.max_seq_length,
+            automodel_args=self.model_kwargs,
         )
+        self.supported_encode_params = get_supported_params(self.model.predict)
 
     def predict(self, sentences: List[List[str]], **kwargs) -> Tensor:
+        for key in list(kwargs.keys()):
+            if key not in self.supported_encode_params:
+                kwargs.pop(key)
         self.encode_kwargs.update(kwargs)
 
         if len(sentences[0]) == 3:  # Note: For mteb retrieval task
@@ -163,6 +191,7 @@ class APIEmbeddingModel(BaseModel):
         self.openai_api_base = kwargs.get('api_base')
         self.openai_api_key = kwargs.get('api_key')
         self.dimensions = kwargs.get('dimensions')
+        self.framework = ['API']
 
         self.model = OpenAIEmbeddings(
             model=self.model_name,
@@ -175,26 +204,37 @@ class APIEmbeddingModel(BaseModel):
 
         self.batch_size = self.encode_kwargs.get('batch_size', 10)
 
+        self.supported_encode_params = get_supported_params(self.model.embed_documents)
+
     def encode(self, texts: Union[str, List[str]], **kwargs) -> Tensor:
+        # pop unused kwargs
+        extra_params = {}
+        for key in list(kwargs.keys()):
+            if key not in self.supported_encode_params:
+                extra_params[key] = kwargs.pop(key)
+        self.encode_kwargs.update(kwargs)
+
+        # set prompt if provided
+        prompt = None
+        prompt_type = extra_params.pop('prompt_type', '')
+        task_name = extra_params.pop('task_name', '')
+        if prompt_type and prompt_type == PromptType.query:
+            prompt = self.get_prompt(task_name)
+
         if isinstance(texts, str):
             texts = [texts]
 
         embeddings: List[List[float]] = []
         for i in tqdm(range(0, len(texts), self.batch_size)):
-            response = self.model.embed_documents(texts[i:i + self.batch_size], chunk_size=self.batch_size)
+            # set prompt if provided
+            if prompt is not None:
+                batch_texts = [prompt + text for text in texts[i:i + self.batch_size]]
+            else:
+                batch_texts = texts[i:i + self.batch_size]
+            response = self.model.embed_documents(batch_texts, chunk_size=self.batch_size)
             embeddings.extend(response)
         return torch.tensor(embeddings)
 
-    def encode_queries(self, queries, **kwargs):
-        return self.encode(queries, **kwargs)
-
-    def encode_corpus(self, corpus, **kwargs):
-        if isinstance(corpus[0], dict):
-            input_texts = ['{} {}'.format(doc.get('title', ''), doc['text']).strip() for doc in corpus]
-        else:
-            input_texts = corpus
-        return self.encode(input_texts, **kwargs)
-
 
 class EmbeddingModel:
     """Custom embeddings"""

evalscope/benchmarks/benchmark.py CHANGED

@@ -28,6 +28,7 @@ class BenchmarkMeta:
     system_prompt: Optional[str] = None
     query_template: Optional[str] = None
     pretty_name: Optional[str] = None
+    description: Optional[str] = None
     filters: Optional[OrderedDict] = None
     extra_params: Optional[Dict] = field(default_factory=dict)