evalscope 0.6.0__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,15 +1,15 @@
- import os
  import asyncio
+ import os
+
  import pandas as pd
- from tqdm import tqdm
- from ragas.llms import LangchainLLMWrapper
  from ragas.embeddings import LangchainEmbeddingsWrapper
- from .translate_prompt import translate_prompts
- from evalscope.utils.logger import get_logger
- from evalscope.backend.rag_eval.ragas.arguments import TestsetGenerationArguments
- from evalscope.backend.rag_eval import EmbeddingModel, LLM, ChatOpenAI
+ from ragas.llms import LangchainLLMWrapper
+ from tqdm import tqdm

- os.environ['DO_NOT_TRACK'] = 'true'
+ from evalscope.backend.rag_eval import LLM, ChatOpenAI, EmbeddingModel
+ from evalscope.backend.rag_eval.ragas.arguments import TestsetGenerationArguments
+ from evalscope.utils.logger import get_logger
+ from .translate_prompt import translate_prompts

  logger = get_logger()

@@ -17,116 +17,110 @@ logger = get_logger()
  def get_transform(llm, embedding, language):
      """
      Creates and returns a default set of transforms for processing a knowledge graph.
-
-     This function defines a series of transformation steps to be applied to a
-     knowledge graph, including extracting summaries, keyphrases, titles,
-     headlines, and embeddings, as well as building similarity relationships
-     between nodes.
-
-     The transforms are applied in the following order:
-     1. Parallel extraction of summaries and headlines
-     2. Embedding of summaries for document nodes
-     3. Splitting of headlines
-     4. Parallel extraction of embeddings, keyphrases, and titles
-     5. Building cosine similarity relationships between nodes
-     6. Building cosine similarity relationships between summaries
-
-     Returns
-     -------
-     Transforms
-         A list of transformation steps to be applied to the knowledge graph.
-
      """
      from ragas.testset.transforms.engine import Parallel
      from ragas.testset.transforms.extractors import (
          EmbeddingExtractor,
          HeadlinesExtractor,
-         KeyphrasesExtractor,
          SummaryExtractor,
-         TitleExtractor,
      )
-     from ragas.testset.transforms.relationship_builders.cosine import (
+     from ragas.testset.transforms.extractors.llm_based import NERExtractor, ThemesExtractor
+     from ragas.testset.transforms.relationship_builders import (
          CosineSimilarityBuilder,
-         SummaryCosineSimilarityBuilder,
+         OverlapScoreBuilder,
      )
      from ragas.testset.transforms.splitters import HeadlineSplitter
+     from ragas.testset.transforms.filters import CustomNodeFilter
      from ragas.testset.graph import NodeType
+     from ragas.utils import num_tokens_from_string
+
+     def summary_filter(node):
+         return (node.type == NodeType.DOCUMENT and num_tokens_from_string(node.properties['page_content']) > 500)

-     # define the transforms
-     summary_extractor = SummaryExtractor(llm=llm)
-     keyphrase_extractor = KeyphrasesExtractor(llm=llm)
-     title_extractor = TitleExtractor(llm=llm)
+     summary_extractor = SummaryExtractor(llm=llm, filter_nodes=lambda node: summary_filter(node))
+     ner_extractor = NERExtractor(llm=llm, filter_nodes=lambda node: node.type == NodeType.CHUNK)
+     theme_extractor = ThemesExtractor(llm=llm)
      headline_extractor = HeadlinesExtractor(llm=llm)

      asyncio.run(
          translate_prompts(
              prompts=[
                  summary_extractor,
-                 keyphrase_extractor,
-                 title_extractor,
+                 theme_extractor,
+                 ner_extractor,
                  headline_extractor,
              ],
              target_lang=language,
              llm=llm,
              adapt_instruction=True,
-         )
-     )
+         ))
+
+     splitter = HeadlineSplitter(min_tokens=500)

-     embedding_extractor = EmbeddingExtractor(embedding_model=embedding)
-     headline_splitter = HeadlineSplitter()
-     cosine_sim_builder = CosineSimilarityBuilder(threshold=0.8)
-     summary_embedder = EmbeddingExtractor(
-         name='summary_embedder',
-         filter_nodes=lambda node: True if node.type == NodeType.DOCUMENT else False,
+     summary_emb_extractor = EmbeddingExtractor(
+         embedding_model=embedding,
          property_name='summary_embedding',
          embed_property_name='summary',
-         embedding_model=embedding,
+         filter_nodes=lambda node: summary_filter(node),
      )
-     summary_cosine_sim_builder = SummaryCosineSimilarityBuilder(threshold=0.6)

-     # specify the transforms and their order to be applied
+     cosine_sim_builder = CosineSimilarityBuilder(
+         property_name='summary_embedding',
+         new_property_name='summary_similarity',
+         threshold=0.7,
+         filter_nodes=lambda node: summary_filter(node),
+     )
+
+     ner_overlap_sim = OverlapScoreBuilder(threshold=0.01, filter_nodes=lambda node: node.type == NodeType.CHUNK)
+
+     node_filter = CustomNodeFilter(llm=llm, filter_nodes=lambda node: node.type == NodeType.CHUNK)
+
      transforms = [
-         Parallel(summary_extractor, headline_extractor),
-         summary_embedder,
-         headline_splitter,
-         Parallel(embedding_extractor, keyphrase_extractor, title_extractor),
-         cosine_sim_builder,
-         summary_cosine_sim_builder,
+         headline_extractor,
+         splitter,
+         summary_extractor,
+         node_filter,
+         Parallel(summary_emb_extractor, theme_extractor, ner_extractor),
+         Parallel(cosine_sim_builder, ner_overlap_sim),
      ]
+
      return transforms


  def get_distribution(llm, distribution, language):
-     from ragas.testset.synthesizers.abstract_query import (
-         AbstractQuerySynthesizer,
-         ComparativeAbstractQuerySynthesizer,
+     from ragas.testset.synthesizers.multi_hop import (
+         MultiHopAbstractQuerySynthesizer,
+         MultiHopSpecificQuerySynthesizer,
      )
-     from ragas.testset.synthesizers.specific_query import SpecificQuerySynthesizer
+     from ragas.testset.synthesizers.single_hop.specific import (
+         SingleHopSpecificQuerySynthesizer, )

-     abstract = AbstractQuerySynthesizer(llm=llm)
-     comparative = ComparativeAbstractQuerySynthesizer(llm=llm)
-     specific = SpecificQuerySynthesizer(llm=llm)
+     single_hop = SingleHopSpecificQuerySynthesizer(llm=llm)
+     multi_hop_abs = MultiHopAbstractQuerySynthesizer(llm=llm)
+     multi_hop_spec = MultiHopSpecificQuerySynthesizer(llm=llm)

      asyncio.run(
          translate_prompts(
              prompts=[
-                 abstract,
-                 comparative,
-                 specific,
+                 single_hop,
+                 multi_hop_abs,
+                 multi_hop_spec,
              ],
              target_lang=language,
              llm=llm,
              adapt_instruction=True,
-         )
-     )
-     return [
-         (abstract, distribution['simple']),
-         (comparative, distribution['multi_context']),
-         (specific, distribution['reasoning']),
-     ]
+         ))
+
+     mapping = {
+         'simple': single_hop,
+         'multi_context': multi_hop_abs,
+         'reasoning': multi_hop_spec,
+     }
+
+     return [(mapping[key], distribution[key]) for key in mapping if key in distribution]


- def get_knowledge_graph(documents, transforms, local_file):
+ def get_knowledge_graph(documents, transforms, local_file, run_config):
      from ragas.testset.graph import KnowledgeGraph, Node, NodeType
      from ragas.testset.transforms import apply_transforms

@@ -148,7 +142,7 @@ def get_knowledge_graph(documents, transforms, local_file):
      kg = KnowledgeGraph(nodes=nodes)

      # apply transforms and update the knowledge graph
-     apply_transforms(kg, transforms)
+     apply_transforms(kg, transforms, run_config=run_config)

      # save the knowledge graph
      output_path = os.path.dirname(local_file)
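
For readers tracking the `get_distribution` rewrite above: the 0.6.0 version always returned three `(synthesizer, weight)` pairs and raised a `KeyError` if any of the three keys was missing from the distribution dict, while the 0.6.1 version builds the pairs from a mapping and simply skips absent keys. A minimal standalone sketch of that selection logic (synthesizers are stubbed with strings and the weights are made-up values, not evalscope defaults):

    # Sketch only: mirrors the key-based selection in the new get_distribution.
    # In the real code each key maps to a ragas synthesizer instance built from `llm`.
    mapping = {
        'simple': 'SingleHopSpecificQuerySynthesizer',
        'multi_context': 'MultiHopAbstractQuerySynthesizer',
        'reasoning': 'MultiHopSpecificQuerySynthesizer',
    }
    distribution = {'simple': 0.5, 'reasoning': 0.5}  # hypothetical user config

    # Keys absent from the user-supplied distribution are skipped instead of failing.
    pairs = [(mapping[key], distribution[key]) for key in mapping if key in distribution]
    print(pairs)  # [('SingleHopSpecificQuerySynthesizer', 0.5), ('MultiHopSpecificQuerySynthesizer', 0.5)]
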
@@ -158,6 +152,39 @@ def get_knowledge_graph(documents, transforms, local_file):
      return kg


+ def get_persona(llm, kg, language):
+     from evalscope.backend.rag_eval.ragas.prompts.persona_prompt import PersonaGenerationPromptZH
+     from ragas.testset.persona import generate_personas_from_kg, PersonaGenerationPrompt
+     from ragas.testset.graph import Node
+
+     def filter(node: Node) -> bool:
+         if (node.type.name == 'DOCUMENT' and node.properties.get('summary_embedding') is not None):
+             return True
+         else:
+             return False
+
+     if language == 'chinese':
+         persona_prompt = PersonaGenerationPromptZH()
+     else:
+         persona_prompt = PersonaGenerationPrompt()
+     # NOTE: can't translate this yet
+     # asyncio.run(
+     #     translate_prompts(
+     #         prompts=[persona_prompt],
+     #         target_lang=language,
+     #         llm=llm,
+     #         adapt_instruction=True,
+     #     ))
+
+     return generate_personas_from_kg(
+         llm=llm,
+         kg=kg,
+         num_personas=3,
+         persona_generation_prompt=persona_prompt,
+         filter_fn=filter,
+     )
+
+
  def load_data(file_path):
      from langchain_community.document_loaders import UnstructuredFileLoader

@@ -178,32 +205,31 @@ def generate_testset(args: TestsetGenerationArguments) -> None:
      generator_llm = LLM.load(**args.generator_llm)
      embeddings = EmbeddingModel.load(**args.embeddings)

+     wrapped_llm = LangchainLLMWrapper(generator_llm)
+     wrapped_embeddings = LangchainEmbeddingsWrapper(embeddings)
+
      # Change resulting question type distribution
-     distributions = get_distribution(
-         LangchainLLMWrapper(generator_llm), args.distribution, args.language
-     )
+     distributions = get_distribution(wrapped_llm, args.distribution, args.language)

+     run_config = RunConfig(timeout=600, max_retries=3, max_wait=120, max_workers=1, log_tenacity=True)
      # get transforms
      transforms = get_transform(
-         LangchainLLMWrapper(generator_llm),
-         LangchainEmbeddingsWrapper(embeddings),
+         wrapped_llm,
+         wrapped_embeddings,
          args.language,
      )

      # get knowledge graph
-     knowledge_graph = get_knowledge_graph(documents, transforms, args.knowledge_graph)
+     knowledge_graph = get_knowledge_graph(documents, transforms, args.knowledge_graph, run_config)

-     generator = TestsetGenerator.from_langchain(
-         generator_llm, embeddings, knowledge_graph
-     )
+     persona_list = get_persona(llm=wrapped_llm, kg=knowledge_graph, language=args.language)
+
+     generator = TestsetGenerator(llm=wrapped_llm, knowledge_graph=knowledge_graph, persona_list=persona_list)

-     runconfig = RunConfig(
-         timeout=600, max_retries=3, max_wait=120, max_workers=1, log_tenacity=True
-     )
      testset = generator.generate(
          testset_size=args.test_size,
          query_distribution=distributions,
-         run_config=runconfig,
+         run_config=run_config,
          with_debugging_logs=True,
          raise_exceptions=True,
      )
@@ -212,9 +238,7 @@ def generate_testset(args: TestsetGenerationArguments) -> None:
      testset_df = testset.to_pandas()
      output_path = os.path.dirname(args.output_file)
      os.makedirs(output_path, exist_ok=True)
-     testset_df.to_json(
-         args.output_file, indent=4, index=False, orient='records', force_ascii=False
-     )
+     testset_df.to_json(args.output_file, indent=4, index=False, orient='records', force_ascii=False)

      # get answer
      testset_with_answer = get_answer(testset_df, generator_llm, args.language)
@@ -243,21 +267,17 @@ Answer:
          contexts = '\n'.join(row['reference_contexts'])

          # Combine question and contexts as input for the LLM
-         input_text = template.format(
-             language=language, question=question, contexts=contexts
-         )
+         input_text = template.format(language=language, question=question, contexts=contexts)

          # Generate the answer using the generator LLM
          answer = generator_llm.invoke(input_text)
          if isinstance(generator_llm, ChatOpenAI):
              answer = answer.content
-         items.append(
-             {
-                 'user_input': question,
-                 'retrieved_contexts': row['reference_contexts'],
-                 'response': answer,
-                 'reference': row['reference'],
-             }
-         )
+         items.append({
+             'user_input': question,
+             'retrieved_contexts': row['reference_contexts'],
+             'response': answer,
+             'reference': row['reference'],
+         })

      return pd.DataFrame.from_dict(items)
File without changes
@@ -0,0 +1,149 @@
+ import os
+ import torch
+ import torch.nn.functional as F
+ from typing import List
+ from PIL import Image
+ from evalscope.backend.rag_eval.utils.tools import download_model, PIL_to_base64
+ from transformers import AutoModel, AutoProcessor
+ from langchain_core.embeddings import Embeddings
+
+
+ class VisionModel:
+     @staticmethod
+     def load(**kw):
+         api_base = kw.get("api_base", None)
+         if api_base:
+
+             return VLMAPI(
+                 model_name=kw.get("model_name", ""),
+                 openai_api_base=api_base,
+                 openai_api_key=kw.get("api_key", "EMPTY"),
+                 prompt=kw.get("prompt", None),
+             )
+         else:
+             return CLIPModel(**kw)
+
+
+ class VLMAPI:
+     def __init__(self, model_name, openai_api_base, openai_api_key, prompt=None):
+         from langchain_openai import ChatOpenAI
+         from langchain_core.prompts import ChatPromptTemplate
+
+         self.model_name = model_name
+         self.model = ChatOpenAI(
+             model_name=model_name,
+             openai_api_base=openai_api_base,
+             openai_api_key=openai_api_key,
+         )
+         self.default_prompt = "Please describe this image in general. Directly provide the description, do not include prefix like 'This image depicts'"
+         self.prompt = ChatPromptTemplate.from_messages(
+             [
+                 ("system", prompt if prompt else self.default_prompt),
+                 (
+                     "user",
+                     [
+                         {
+                             "type": "image_url",
+                             "image_url": {"url": "data:image/jpeg;base64,{image_data}"},
+                         }
+                     ],
+                 ),
+             ]
+         )
+         self.chain = self.prompt | self.model
+         self.transform = PIL_to_base64
+
+     def encode_image(self, images):
+         captions = []
+         for image in images:
+             response = self.chain.invoke({"image_data": image})
+             captions.append(response.content)
+         return captions
+
+
+ class CLIPModel(Embeddings):
+     def __init__(
+         self,
+         model_name: str,
+         revision: str = "master",
+         hub="modelscope",
+         device="cpu",
+     ):
+         self.device = device
+         self.model_name = model_name
+         self.revision = revision
+
+         # Download the model if it doesn't exist locally
+         if not os.path.exists(model_name) and hub == "modelscope":
+             model_name = download_model(self.model_name, self.revision)
+
+         # Load the model and processor
+         self.model = AutoModel.from_pretrained(model_name).to(self.device)
+         self.processor = AutoProcessor.from_pretrained(model_name)
+         self.transform = self.processor.image_processor
+         self.tokenizer = self.processor.tokenizer
+
+     def encode_text(self, batch_texts: List[str] | List[List[str]]):
+         if isinstance(batch_texts[0], list):
+             batch_texts = [
+                 text for _, texts in enumerate(batch_texts) for text in texts
+             ]
+         # Ensure that the input texts are within the token limit
+         max_length = self.tokenizer.model_max_length
+         if not max_length or max_length > 0xFFFFFF:
+             max_length = 512
+         encoded_inputs = self.tokenizer(
+             text=batch_texts,
+             max_length=max_length,
+             padding=True,
+             truncation=True,
+             return_tensors="pt",
+         )
+
+         inputs = {k: v.to(self.device) for k, v in encoded_inputs.items()}
+
+         with torch.no_grad():
+             text_features = self.model.get_text_features(**inputs)
+             text_features = F.normalize(text_features, p=2, dim=-1)
+         return text_features
+
+     def encode_image(self, image):
+         batch_images = torch.stack([d["pixel_values"][0] for d in image])
+         batch_images = batch_images.to(self.device)
+         with torch.no_grad():
+             image_features = self.model.get_image_features(batch_images)
+             image_features = F.normalize(image_features, p=2, dim=-1)
+         return image_features
+
+     def embed_documents(self, texts):
+         text_features = self.encode_text(texts)
+         return text_features.cpu().numpy().tolist()
+
+     def embed_query(self, text):
+         text_features = self.encode_text([text])
+         return text_features.cpu().numpy().tolist()[0]
+
+     def embed_image(self, uris: List[str]):
+         # read image and transform
+         images = [Image.open(image_path) for image_path in uris]
+         transformed_images = [
+             self.transform(
+                 image,
+                 return_tensors="pt",
+             )
+             for image in images
+         ]
+         image_features = self.encode_image(transformed_images)
+         return image_features.cpu().numpy().tolist()
+
+
+ if __name__ == "__main__":
+     model = CLIPModel("AI-ModelScope/chinese-clip-vit-large-patch14-336px")
+     model.embed_image(
+         [
+             "custom_eval/multimodal/images/AMNH.jpg",
+             "custom_eval/multimodal/images/AMNH.jpg",
+         ]
+     )
+     model.encode_text(["我喜欢吃饭" * 1000])
+     print("done")
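
The new vision module above exposes a single factory, `VisionModel.load`, which returns an OpenAI-compatible VLM captioner (`VLMAPI`) when `api_base` is given, and a local CLIP embedder (`CLIPModel`) otherwise. A hedged usage sketch; the import path is assumed (this diff does not show file names), and the endpoint and model name for the API branch are placeholders, while the CLIP checkpoint and image path are taken from the file's own `__main__` block:

    # Import path and the endpoint/model values below are illustrative assumptions.
    from evalscope.backend.rag_eval.utils.clip import VisionModel  # assumed module path

    # With api_base set, load() returns a VLMAPI captioner backed by an OpenAI-compatible endpoint.
    captioner = VisionModel.load(
        model_name='qwen-vl-plus',            # placeholder model name
        api_base='http://localhost:8000/v1',  # placeholder endpoint
        api_key='EMPTY',
    )

    # Without api_base, load() builds a local CLIPModel (fetched from ModelScope by default).
    clip = VisionModel.load(model_name='AI-ModelScope/chinese-clip-vit-large-patch14-336px')
    query_vector = clip.embed_query('a museum entrance')  # text-side CLIP embedding
    image_vectors = clip.embed_image(['custom_eval/multimodal/images/AMNH.jpg'])
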
@@ -0,0 +1,183 @@
+ import os
+ import torch
+ from typing import List, Optional, Union, Dict
+ from sentence_transformers import models
+ from sentence_transformers.SentenceTransformer import SentenceTransformer
+ from sentence_transformers.cross_encoder import CrossEncoder
+ from torch import Tensor
+ from evalscope.backend.rag_eval.utils.tools import download_model
+ from evalscope.utils.logger import get_logger
+ from langchain_core.embeddings import Embeddings
+
+ logger = get_logger()
+
+
+ class BaseModel(Embeddings):
+     def __init__(
+         self,
+         model_name_or_path: str,
+         max_seq_length: int = 512,
+         prompt: str = '',
+         revision: Optional[str] = None,
+         **kwargs,
+     ):
+         self.model_name_or_path = model_name_or_path
+         self.max_seq_length = max_seq_length
+         self.model_kwargs = kwargs.pop('model_kwargs', {})
+         self.model_kwargs['trust_remote_code'] = True
+
+         self.config_kwargs = kwargs.pop('config_kwargs', {})
+         self.config_kwargs['trust_remote_code'] = True
+
+         self.encode_kwargs = kwargs.pop('encode_kwargs', {})
+         self.encode_kwargs['convert_to_tensor'] = True
+
+         self.prompt = prompt
+         self.revision = revision
+
+     @property
+     def mteb_model_meta(self):
+         """Model metadata for MTEB (Multilingual Task Embeddings Benchmark)"""
+         from mteb import ModelMeta
+
+         return ModelMeta(
+             name=os.path.basename(self.model_name_or_path),
+             revision=self.revision,
+             languages=None,
+             release_date=None,
+         )
+
+     def embed_documents(self, texts: List[str]) -> List[List[float]]:
+         """Embed search docs. Compact langchain.
+
+         Args:
+             texts: List of text to embed.
+
+         Returns:
+             List of embeddings.
+         """
+         return self.encode_corpus(texts).tolist()
+
+     def embed_query(self, text: str) -> List[float]:
+         """Embed query text. Compact langchain.
+
+         Args:
+             text: Text to embed.
+
+         Returns:
+             Embedding.
+         """
+         return self.encode_queries(text).tolist()
+
+     def encode(self, texts: Union[str, List[str]], **kwargs) -> List[List[float]]:
+         """Embed text."""
+         raise NotImplementedError
+
+     def encode_queries(self, queries: List[str], **kwargs) -> list[torch.Tensor]:
+         """Embed query text. Compact mteb."""
+         raise NotImplementedError
+
+     def encode_corpus(self, corpus: List[str] | List[Dict[str, str]], **kwargs) -> list[torch.Tensor]:
+         """Embed search docs . Compact mteb."""
+         raise NotImplementedError
+
+
+ class SentenceTransformerModel(BaseModel):
+     def __init__(
+         self, model_name_or_path: str, pooling_mode: Optional[str] = None, **kwargs
+     ):
+         super().__init__(model_name_or_path, **kwargs)
+
+         if not pooling_mode:
+             self.model = SentenceTransformer(
+                 self.model_name_or_path,
+                 config_kwargs=self.config_kwargs,
+                 model_kwargs=self.model_kwargs,
+             )
+         else:
+             word_embedding_model = models.Transformer(
+                 self.model_name_or_path,
+                 config_args=self.config_kwargs,
+                 model_args=self.model_kwargs,
+             )
+             pooling_model = models.Pooling(
+                 word_embedding_model.get_word_embedding_dimension(),
+                 pooling_mode=pooling_mode,
+             )
+             self.model = SentenceTransformer(
+                 modules=[word_embedding_model, pooling_model],
+             )
+
+         self.model.max_seq_length = self.max_seq_length
+
+     def encode(self, texts: Union[str, List[str]], prompt=None, **kwargs) -> List[torch.Tensor]:
+         kwargs.pop('prompt_name', '')  # remove prompt name, use prompt
+         self.encode_kwargs.update(kwargs)
+
+         embeddings = self.model.encode(texts, prompt=prompt, **self.encode_kwargs)
+         assert isinstance(embeddings, Tensor)
+         return embeddings.cpu().detach()
+
+     def encode_queries(self, queries, **kwargs):
+         return self.encode(queries, prompt=self.prompt)
+
+     def encode_corpus(self, corpus, **kwargs):
+         if isinstance(corpus[0], dict):
+             input_texts = ['{} {}'.format(doc.get('title', ''), doc['text']).strip() for doc in corpus]
+         else:
+             input_texts = corpus
+         return self.encode(input_texts)
+
+
+ class CrossEncoderModel(BaseModel):
+     def __init__(self, model_name_or_path: str, **kwargs):
+         super().__init__(model_name_or_path, **kwargs)
+         self.model = CrossEncoder(
+             self.model_name_or_path,
+             trust_remote_code=True,
+             max_length=self.max_seq_length,
+         )
+
+     def predict(self, sentences: List[List[str]], **kwargs) -> List[List[float]]:
+         self.encode_kwargs.update(kwargs)
+
+         if len(sentences[0]) == 3:  # Note: For mteb retrieval task
+             processed_sentences = []
+             for query, docs, instruction in sentences:
+                 if isinstance(docs, dict):
+                     docs = docs['text']
+                 processed_sentences.append((self.prompt + query, docs))
+             sentences = processed_sentences
+         embeddings = self.model.predict(sentences, **self.encode_kwargs)
+         assert isinstance(embeddings, Tensor)
+         return embeddings
+
+
+ class EmbeddingModel:
+     """Custom embeddings"""
+
+     @staticmethod
+     def load(
+         model_name_or_path: str = '',
+         is_cross_encoder: bool = False,
+         hub: str = 'modelscope',
+         revision: Optional[str] = 'master',
+         **kwargs,
+     ):
+         # If model path does not exist and hub is 'modelscope', download the model
+         if not os.path.exists(model_name_or_path) and hub == 'modelscope':
+             model_name_or_path = download_model(model_name_or_path, revision)
+
+         # Return different model instances based on whether it is a cross-encoder and pooling mode
+         if is_cross_encoder:
+             return CrossEncoderModel(
+                 model_name_or_path,
+                 revision=revision,
+                 **kwargs,
+             )
+         else:
+             return SentenceTransformerModel(
+                 model_name_or_path,
+                 revision=revision,
+                 **kwargs,
+             )
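
The `EmbeddingModel.load` factory above mirrors the vision loader: it resolves the checkpoint (downloading from ModelScope when the path does not exist locally) and returns either a `SentenceTransformerModel` bi-encoder or a `CrossEncoderModel` reranker. A usage sketch, importing `EmbeddingModel` the same way the testset-generation module above does; the model IDs are placeholders, not evalscope defaults:

    # Model IDs below are illustrative placeholders.
    from evalscope.backend.rag_eval import EmbeddingModel  # same import as used above

    # Bi-encoder: a SentenceTransformer wrapped behind the langchain Embeddings interface.
    embedder = EmbeddingModel.load(model_name_or_path='AI-ModelScope/bge-large-zh')  # placeholder ID
    doc_vectors = embedder.embed_documents(['RAG evaluation needs an embedding model.'])
    query_vector = embedder.embed_query('how is the testset generated?')

    # Cross-encoder: predict() scores (query, passage) pairs for reranking-style evaluation.
    reranker = EmbeddingModel.load(model_name_or_path='BAAI/bge-reranker-base', is_cross_encoder=True)  # placeholder ID
    scores = reranker.predict([['what is RAG?', 'RAG couples retrieval with generation.']])
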