evalscope 0.6.0__py3-none-any.whl → 0.6.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- evalscope/backend/opencompass/tasks/eval_datasets.py +1 -0
- evalscope/backend/rag_eval/cmteb/tasks/Clustering.py +96 -96
- evalscope/backend/rag_eval/cmteb/tasks/Reranking.py +70 -71
- evalscope/backend/rag_eval/ragas/tasks/testset_generation.py +120 -100
- evalscope/backend/rag_eval/utils/__init__.py +0 -0
- evalscope/backend/rag_eval/utils/clip.py +149 -0
- evalscope/backend/rag_eval/utils/embedding.py +183 -0
- evalscope/backend/rag_eval/utils/llm.py +72 -0
- evalscope/backend/rag_eval/utils/tools.py +63 -0
- evalscope/metrics/bundled_rouge_score/rouge_scorer.py +1 -1
- evalscope/version.py +2 -2
- {evalscope-0.6.0.dist-info → evalscope-0.6.1.dist-info}/METADATA +14 -13
- {evalscope-0.6.0.dist-info → evalscope-0.6.1.dist-info}/RECORD +16 -11
- {evalscope-0.6.0.dist-info → evalscope-0.6.1.dist-info}/WHEEL +1 -1
- {evalscope-0.6.0.dist-info → evalscope-0.6.1.dist-info}/entry_points.txt +0 -0
- {evalscope-0.6.0.dist-info → evalscope-0.6.1.dist-info}/top_level.txt +0 -0
evalscope/backend/rag_eval/ragas/tasks/testset_generation.py

```diff
@@ -1,15 +1,15 @@
-import os
 import asyncio
+import os
+
 import pandas as pd
-from tqdm import tqdm
-from ragas.llms import LangchainLLMWrapper
 from ragas.embeddings import LangchainEmbeddingsWrapper
-from .translate_prompt import translate_prompts
-from evalscope.utils.logger import get_logger
-from evalscope.backend.rag_eval.ragas.arguments import TestsetGenerationArguments
-from evalscope.backend.rag_eval import EmbeddingModel, LLM, ChatOpenAI
+from ragas.llms import LangchainLLMWrapper
+from tqdm import tqdm
 
-
+from evalscope.backend.rag_eval import LLM, ChatOpenAI, EmbeddingModel
+from evalscope.backend.rag_eval.ragas.arguments import TestsetGenerationArguments
+from evalscope.utils.logger import get_logger
+from .translate_prompt import translate_prompts
 
 logger = get_logger()
 
@@ -17,116 +17,110 @@ logger = get_logger()
 def get_transform(llm, embedding, language):
     """
     Creates and returns a default set of transforms for processing a knowledge graph.
-
-    This function defines a series of transformation steps to be applied to a
-    knowledge graph, including extracting summaries, keyphrases, titles,
-    headlines, and embeddings, as well as building similarity relationships
-    between nodes.
-
-    The transforms are applied in the following order:
-    1. Parallel extraction of summaries and headlines
-    2. Embedding of summaries for document nodes
-    3. Splitting of headlines
-    4. Parallel extraction of embeddings, keyphrases, and titles
-    5. Building cosine similarity relationships between nodes
-    6. Building cosine similarity relationships between summaries
-
-    Returns
-    -------
-    Transforms
-        A list of transformation steps to be applied to the knowledge graph.
-
     """
     from ragas.testset.transforms.engine import Parallel
     from ragas.testset.transforms.extractors import (
         EmbeddingExtractor,
         HeadlinesExtractor,
-        KeyphrasesExtractor,
         SummaryExtractor,
-        TitleExtractor,
     )
-    from ragas.testset.transforms.relationship_builders import (
+    from ragas.testset.transforms.extractors.llm_based import NERExtractor, ThemesExtractor
+    from ragas.testset.transforms.relationship_builders import (
         CosineSimilarityBuilder,
-        SummaryCosineSimilarityBuilder,
+        OverlapScoreBuilder,
     )
     from ragas.testset.transforms.splitters import HeadlineSplitter
+    from ragas.testset.transforms.filters import CustomNodeFilter
     from ragas.testset.graph import NodeType
+    from ragas.utils import num_tokens_from_string
+
+    def summary_filter(node):
+        return (node.type == NodeType.DOCUMENT and num_tokens_from_string(node.properties['page_content']) > 500)
 
-
-
-
-    title_extractor = TitleExtractor(llm=llm)
+    summary_extractor = SummaryExtractor(llm=llm, filter_nodes=lambda node: summary_filter(node))
+    ner_extractor = NERExtractor(llm=llm, filter_nodes=lambda node: node.type == NodeType.CHUNK)
+    theme_extractor = ThemesExtractor(llm=llm)
     headline_extractor = HeadlinesExtractor(llm=llm)
 
     asyncio.run(
         translate_prompts(
             prompts=[
                 summary_extractor,
-
-
+                theme_extractor,
+                ner_extractor,
                 headline_extractor,
             ],
             target_lang=language,
             llm=llm,
             adapt_instruction=True,
-        )
-
+        ))
+
+    splitter = HeadlineSplitter(min_tokens=500)
 
-
-
-    cosine_sim_builder = CosineSimilarityBuilder(threshold=0.8)
-    summary_embedder = EmbeddingExtractor(
-        name='summary_embedder',
-        filter_nodes=lambda node: True if node.type == NodeType.DOCUMENT else False,
+    summary_emb_extractor = EmbeddingExtractor(
+        embedding_model=embedding,
         property_name='summary_embedding',
         embed_property_name='summary',
-
+        filter_nodes=lambda node: summary_filter(node),
     )
-    summary_cosine_sim_builder = SummaryCosineSimilarityBuilder(threshold=0.6)
 
-
+    cosine_sim_builder = CosineSimilarityBuilder(
+        property_name='summary_embedding',
+        new_property_name='summary_similarity',
+        threshold=0.7,
+        filter_nodes=lambda node: summary_filter(node),
+    )
+
+    ner_overlap_sim = OverlapScoreBuilder(threshold=0.01, filter_nodes=lambda node: node.type == NodeType.CHUNK)
+
+    node_filter = CustomNodeFilter(llm=llm, filter_nodes=lambda node: node.type == NodeType.CHUNK)
+
     transforms = [
-
-
-
-
-
-
+        headline_extractor,
+        splitter,
+        summary_extractor,
+        node_filter,
+        Parallel(summary_emb_extractor, theme_extractor, ner_extractor),
+        Parallel(cosine_sim_builder, ner_overlap_sim),
     ]
+
     return transforms
 
 
 def get_distribution(llm, distribution, language):
-    from ragas.testset.synthesizers.
-
-
+    from ragas.testset.synthesizers.multi_hop import (
+        MultiHopAbstractQuerySynthesizer,
+        MultiHopSpecificQuerySynthesizer,
     )
-    from ragas.testset.synthesizers.
+    from ragas.testset.synthesizers.single_hop.specific import (
+        SingleHopSpecificQuerySynthesizer, )
 
-
-
-
+    single_hop = SingleHopSpecificQuerySynthesizer(llm=llm)
+    multi_hop_abs = MultiHopAbstractQuerySynthesizer(llm=llm)
+    multi_hop_spec = MultiHopSpecificQuerySynthesizer(llm=llm)
 
     asyncio.run(
         translate_prompts(
             prompts=[
-
-
-
+                single_hop,
+                multi_hop_abs,
+                multi_hop_spec,
            ],
             target_lang=language,
             llm=llm,
             adapt_instruction=True,
-        )
-
-
-
-
-
-
+        ))
+
+    mapping = {
+        'simple': single_hop,
+        'multi_context': multi_hop_abs,
+        'reasoning': multi_hop_spec,
+    }
+
+    return [(mapping[key], distribution[key]) for key in mapping if key in distribution]
 
 
-def get_knowledge_graph(documents, transforms, local_file):
+def get_knowledge_graph(documents, transforms, local_file, run_config):
     from ragas.testset.graph import KnowledgeGraph, Node, NodeType
     from ragas.testset.transforms import apply_transforms
 
```
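The rewritten `get_transform` replaces the old summary/keyphrase/title pipeline with a staged one: headline extraction, headline splitting, summarisation, and LLM-based node filtering run sequentially, then two `Parallel` groups handle the embedding/theme/NER extractors and the two relationship builders. A minimal pure-Python sketch of that execution model (the `Stage`/`Parallel` stubs below are illustrative, not ragas' actual engine):

```python
import asyncio

class Stage:
    """Illustrative stand-in for a ragas transform step."""

    def __init__(self, name):
        self.name = name

    async def apply(self, log):
        log.append(self.name)  # record when this stage ran

class Parallel:
    """Runs its member stages concurrently, like ragas' Parallel engine."""

    def __init__(self, *stages):
        self.stages = stages

    async def apply(self, log):
        await asyncio.gather(*(s.apply(log) for s in self.stages))

async def apply_all(transforms, log):
    for t in transforms:  # sequential outer loop, concurrent inner groups
        await t.apply(log)

log = []
asyncio.run(apply_all([
    Stage('headline_extractor'),
    Stage('splitter'),
    Stage('summary_extractor'),
    Stage('node_filter'),
    Parallel(Stage('summary_emb_extractor'), Stage('theme_extractor'), Stage('ner_extractor')),
    Parallel(Stage('cosine_sim_builder'), Stage('ner_overlap_sim')),
], log))
print(log)
```

Each real transform also takes a `filter_nodes` predicate, which is how the new code restricts summary embedding to long documents (over 500 tokens) and NER/overlap scoring to chunk nodes.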
```diff
@@ -148,7 +142,7 @@ def get_knowledge_graph(documents, transforms, local_file):
     kg = KnowledgeGraph(nodes=nodes)
 
     # apply transforms and update the knowledge graph
-    apply_transforms(kg, transforms)
+    apply_transforms(kg, transforms, run_config=run_config)
 
     # save the knowledge graph
     output_path = os.path.dirname(local_file)
```
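`apply_transforms` now receives the caller's `RunConfig` instead of ragas' default. A sketch of the call site, assuming the `ragas.run_config.RunConfig` import path; the parameter values are the ones `generate_testset` uses below:

```python
from ragas.run_config import RunConfig

# One shared config governs timeouts, tenacity retries/backoff, and worker
# concurrency for every transform pass over the knowledge graph.
run_config = RunConfig(timeout=600, max_retries=3, max_wait=120, max_workers=1, log_tenacity=True)

# kg and transforms come from get_knowledge_graph / get_transform:
# apply_transforms(kg, transforms, run_config=run_config)
```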
```diff
@@ -158,6 +152,39 @@ def get_knowledge_graph(documents, transforms, local_file):
     return kg
 
 
+def get_persona(llm, kg, language):
+    from evalscope.backend.rag_eval.ragas.prompts.persona_prompt import PersonaGenerationPromptZH
+    from ragas.testset.persona import generate_personas_from_kg, PersonaGenerationPrompt
+    from ragas.testset.graph import Node
+
+    def filter(node: Node) -> bool:
+        if (node.type.name == 'DOCUMENT' and node.properties.get('summary_embedding') is not None):
+            return True
+        else:
+            return False
+
+    if language == 'chinese':
+        persona_prompt = PersonaGenerationPromptZH()
+    else:
+        persona_prompt = PersonaGenerationPrompt()
+    # NOTE: can't translate this yet
+    # asyncio.run(
+    #     translate_prompts(
+    #         prompts=[persona_prompt],
+    #         target_lang=language,
+    #         llm=llm,
+    #         adapt_instruction=True,
+    #     ))
+
+    return generate_personas_from_kg(
+        llm=llm,
+        kg=kg,
+        num_personas=3,
+        persona_generation_prompt=persona_prompt,
+        filter_fn=filter,
+    )
+
+
 def load_data(file_path):
     from langchain_community.document_loaders import UnstructuredFileLoader
 
```
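`get_persona` only feeds persona generation from document nodes that already carry a `summary_embedding`, i.e. the ones the summary embedding extractor actually processed. A self-contained restatement of that predicate against stub nodes (`SimpleNamespace` stands in for `ragas.testset.graph.Node`):

```python
from types import SimpleNamespace

def is_persona_candidate(node) -> bool:
    # Same test as the filter inside get_persona.
    return node.type.name == 'DOCUMENT' and node.properties.get('summary_embedding') is not None

doc = SimpleNamespace(type=SimpleNamespace(name='DOCUMENT'),
                      properties={'summary_embedding': [0.1, 0.2]})
chunk = SimpleNamespace(type=SimpleNamespace(name='CHUNK'), properties={})

print(is_persona_candidate(doc))    # True  -> used for persona generation
print(is_persona_candidate(chunk))  # False -> ignored
```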
```diff
@@ -178,32 +205,31 @@ def generate_testset(args: TestsetGenerationArguments) -> None:
     generator_llm = LLM.load(**args.generator_llm)
     embeddings = EmbeddingModel.load(**args.embeddings)
 
+    wrapped_llm = LangchainLLMWrapper(generator_llm)
+    wrapped_embeddings = LangchainEmbeddingsWrapper(embeddings)
+
     # Change resulting question type distribution
-    distributions = get_distribution(
-        LangchainLLMWrapper(generator_llm), args.distribution, args.language
-    )
+    distributions = get_distribution(wrapped_llm, args.distribution, args.language)
 
+    run_config = RunConfig(timeout=600, max_retries=3, max_wait=120, max_workers=1, log_tenacity=True)
     # get transforms
     transforms = get_transform(
-
-
+        wrapped_llm,
+        wrapped_embeddings,
         args.language,
     )
 
     # get knowledge graph
-    knowledge_graph = get_knowledge_graph(documents, transforms, args.knowledge_graph)
+    knowledge_graph = get_knowledge_graph(documents, transforms, args.knowledge_graph, run_config)
 
-
-
-    )
+    persona_list = get_persona(llm=wrapped_llm, kg=knowledge_graph, language=args.language)
+
+    generator = TestsetGenerator(llm=wrapped_llm, knowledge_graph=knowledge_graph, persona_list=persona_list)
 
-    runconfig = RunConfig(
-        timeout=600, max_retries=3, max_wait=120, max_workers=1, log_tenacity=True
-    )
     testset = generator.generate(
         testset_size=args.test_size,
         query_distribution=distributions,
-        run_config=
+        run_config=run_config,
         with_debugging_logs=True,
         raise_exceptions=True,
     )
```
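The distribution keys `generate_testset` understands are exactly the three in `get_distribution`'s mapping. A hedged sketch of the argument fields this function reads (the field names appear in the diff; the concrete values and paths here are placeholders):

```python
# Placeholder values only; the real arguments object is
# evalscope.backend.rag_eval.ragas.arguments.TestsetGenerationArguments,
# which also carries generator_llm= and embeddings= model-loading kwargs.
args = dict(
    distribution={'simple': 0.5, 'multi_context': 0.25, 'reasoning': 0.25},
    language='chinese',            # 'chinese' selects PersonaGenerationPromptZH
    knowledge_graph='outputs/knowledge_graph.json',
    test_size=10,
    output_file='outputs/testset.json',
)
```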
```diff
@@ -212,9 +238,7 @@ def generate_testset(args: TestsetGenerationArguments) -> None:
     testset_df = testset.to_pandas()
     output_path = os.path.dirname(args.output_file)
     os.makedirs(output_path, exist_ok=True)
-    testset_df.to_json(
-        args.output_file, indent=4, index=False, orient='records', force_ascii=False
-    )
+    testset_df.to_json(args.output_file, indent=4, index=False, orient='records', force_ascii=False)
 
     # get answer
     testset_with_answer = get_answer(testset_df, generator_llm, args.language)
```
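Because the testset is written with `orient='records'`, it round-trips through pandas; a small sketch (the path is a placeholder):

```python
import pandas as pd

# Mirrors the to_json call above; each record is one generated sample.
testset_df = pd.read_json('outputs/testset.json', orient='records')
print(testset_df.columns.tolist())
```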
```diff
@@ -243,21 +267,17 @@ Answer:
         contexts = '\n'.join(row['reference_contexts'])
 
         # Combine question and contexts as input for the LLM
-        input_text = template.format(
-            language=language, question=question, contexts=contexts
-        )
+        input_text = template.format(language=language, question=question, contexts=contexts)
 
         # Generate the answer using the generator LLM
         answer = generator_llm.invoke(input_text)
         if isinstance(generator_llm, ChatOpenAI):
             answer = answer.content
-        items.append(
-
-
-
-
-
-        }
-        )
+        items.append({
+            'user_input': question,
+            'retrieved_contexts': row['reference_contexts'],
+            'response': answer,
+            'reference': row['reference'],
+        })
 
     return pd.DataFrame.from_dict(items)
```

evalscope/backend/rag_eval/utils/__init__.py

File without changes (new, empty file).
evalscope/backend/rag_eval/utils/clip.py

```diff
@@ -0,0 +1,149 @@
+import os
+import torch
+import torch.nn.functional as F
+from typing import List
+from PIL import Image
+from evalscope.backend.rag_eval.utils.tools import download_model, PIL_to_base64
+from transformers import AutoModel, AutoProcessor
+from langchain_core.embeddings import Embeddings
+
+
+class VisionModel:
+    @staticmethod
+    def load(**kw):
+        api_base = kw.get("api_base", None)
+        if api_base:
+
+            return VLMAPI(
+                model_name=kw.get("model_name", ""),
+                openai_api_base=api_base,
+                openai_api_key=kw.get("api_key", "EMPTY"),
+                prompt=kw.get("prompt", None),
+            )
+        else:
+            return CLIPModel(**kw)
+
+
+class VLMAPI:
+    def __init__(self, model_name, openai_api_base, openai_api_key, prompt=None):
+        from langchain_openai import ChatOpenAI
+        from langchain_core.prompts import ChatPromptTemplate
+
+        self.model_name = model_name
+        self.model = ChatOpenAI(
+            model_name=model_name,
+            openai_api_base=openai_api_base,
+            openai_api_key=openai_api_key,
+        )
+        self.default_prompt = "Please describe this image in general. Directly provide the description, do not include prefix like 'This image depicts'"
+        self.prompt = ChatPromptTemplate.from_messages(
+            [
+                ("system", prompt if prompt else self.default_prompt),
+                (
+                    "user",
+                    [
+                        {
+                            "type": "image_url",
+                            "image_url": {"url": "data:image/jpeg;base64,{image_data}"},
+                        }
+                    ],
+                ),
+            ]
+        )
+        self.chain = self.prompt | self.model
+        self.transform = PIL_to_base64
+
+    def encode_image(self, images):
+        captions = []
+        for image in images:
+            response = self.chain.invoke({"image_data": image})
+            captions.append(response.content)
+        return captions
+
+
+class CLIPModel(Embeddings):
+    def __init__(
+        self,
+        model_name: str,
+        revision: str = "master",
+        hub="modelscope",
+        device="cpu",
+    ):
+        self.device = device
+        self.model_name = model_name
+        self.revision = revision
+
+        # Download the model if it doesn't exist locally
+        if not os.path.exists(model_name) and hub == "modelscope":
+            model_name = download_model(self.model_name, self.revision)
+
+        # Load the model and processor
+        self.model = AutoModel.from_pretrained(model_name).to(self.device)
+        self.processor = AutoProcessor.from_pretrained(model_name)
+        self.transform = self.processor.image_processor
+        self.tokenizer = self.processor.tokenizer
+
+    def encode_text(self, batch_texts: List[str] | List[List[str]]):
+        if isinstance(batch_texts[0], list):
+            batch_texts = [
+                text for _, texts in enumerate(batch_texts) for text in texts
+            ]
+        # Ensure that the input texts are within the token limit
+        max_length = self.tokenizer.model_max_length
+        if not max_length or max_length > 0xFFFFFF:
+            max_length = 512
+        encoded_inputs = self.tokenizer(
+            text=batch_texts,
+            max_length=max_length,
+            padding=True,
+            truncation=True,
+            return_tensors="pt",
+        )
+
+        inputs = {k: v.to(self.device) for k, v in encoded_inputs.items()}
+
+        with torch.no_grad():
+            text_features = self.model.get_text_features(**inputs)
+            text_features = F.normalize(text_features, p=2, dim=-1)
+        return text_features
+
+    def encode_image(self, image):
+        batch_images = torch.stack([d["pixel_values"][0] for d in image])
+        batch_images = batch_images.to(self.device)
+        with torch.no_grad():
+            image_features = self.model.get_image_features(batch_images)
+            image_features = F.normalize(image_features, p=2, dim=-1)
+        return image_features
+
+    def embed_documents(self, texts):
+        text_features = self.encode_text(texts)
+        return text_features.cpu().numpy().tolist()
+
+    def embed_query(self, text):
+        text_features = self.encode_text([text])
+        return text_features.cpu().numpy().tolist()[0]
+
+    def embed_image(self, uris: List[str]):
+        # read image and transform
+        images = [Image.open(image_path) for image_path in uris]
+        transformed_images = [
+            self.transform(
+                image,
+                return_tensors="pt",
+            )
+            for image in images
+        ]
+        image_features = self.encode_image(transformed_images)
+        return image_features.cpu().numpy().tolist()
+
+
+if __name__ == "__main__":
+    model = CLIPModel("AI-ModelScope/chinese-clip-vit-large-patch14-336px")
+    model.embed_image(
+        [
+            "custom_eval/multimodal/images/AMNH.jpg",
+            "custom_eval/multimodal/images/AMNH.jpg",
+        ]
+    )
+    model.encode_text(["我喜欢吃饭" * 1000])
+    print("done")
```
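`CLIPModel` implements langchain's `Embeddings` interface and adds `embed_image`, so text and image vectors land in the same L2-normalized space. A usage sketch (the model ID comes from the file's own `__main__` block; the image path is a placeholder):

```python
import numpy as np
from evalscope.backend.rag_eval.utils.clip import CLIPModel

model = CLIPModel('AI-ModelScope/chinese-clip-vit-large-patch14-336px')

text_vec = np.array(model.embed_query('a museum facade'))
img_vec = np.array(model.embed_image(['custom_eval/multimodal/images/AMNH.jpg'])[0])

# Vectors are unit-length, so the dot product is the cosine similarity.
print(float(text_vec @ img_vec))
```

`VisionModel.load` picks between the two backends: pass `api_base` to caption images through an OpenAI-compatible VLM endpoint, or omit it to embed locally with CLIP.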
evalscope/backend/rag_eval/utils/embedding.py

```diff
@@ -0,0 +1,183 @@
+import os
+import torch
+from typing import List, Optional, Union, Dict
+from sentence_transformers import models
+from sentence_transformers.SentenceTransformer import SentenceTransformer
+from sentence_transformers.cross_encoder import CrossEncoder
+from torch import Tensor
+from evalscope.backend.rag_eval.utils.tools import download_model
+from evalscope.utils.logger import get_logger
+from langchain_core.embeddings import Embeddings
+
+logger = get_logger()
+
+
+class BaseModel(Embeddings):
+    def __init__(
+        self,
+        model_name_or_path: str,
+        max_seq_length: int = 512,
+        prompt: str = '',
+        revision: Optional[str] = None,
+        **kwargs,
+    ):
+        self.model_name_or_path = model_name_or_path
+        self.max_seq_length = max_seq_length
+        self.model_kwargs = kwargs.pop('model_kwargs', {})
+        self.model_kwargs['trust_remote_code'] = True
+
+        self.config_kwargs = kwargs.pop('config_kwargs', {})
+        self.config_kwargs['trust_remote_code'] = True
+
+        self.encode_kwargs = kwargs.pop('encode_kwargs', {})
+        self.encode_kwargs['convert_to_tensor'] = True
+
+        self.prompt = prompt
+        self.revision = revision
+
+    @property
+    def mteb_model_meta(self):
+        """Model metadata for MTEB (Multilingual Task Embeddings Benchmark)"""
+        from mteb import ModelMeta
+
+        return ModelMeta(
+            name=os.path.basename(self.model_name_or_path),
+            revision=self.revision,
+            languages=None,
+            release_date=None,
+        )
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Embed search docs. Compact langchain.
+
+        Args:
+            texts: List of text to embed.
+
+        Returns:
+            List of embeddings.
+        """
+        return self.encode_corpus(texts).tolist()
+
+    def embed_query(self, text: str) -> List[float]:
+        """Embed query text. Compact langchain.
+
+        Args:
+            text: Text to embed.
+
+        Returns:
+            Embedding.
+        """
+        return self.encode_queries(text).tolist()
+
+    def encode(self, texts: Union[str, List[str]], **kwargs) -> List[List[float]]:
+        """Embed text."""
+        raise NotImplementedError
+
+    def encode_queries(self, queries: List[str], **kwargs) -> list[torch.Tensor]:
+        """Embed query text. Compact mteb."""
+        raise NotImplementedError
+
+    def encode_corpus(self, corpus: List[str] | List[Dict[str, str]], **kwargs) -> list[torch.Tensor]:
+        """Embed search docs . Compact mteb."""
+        raise NotImplementedError
+
+
+class SentenceTransformerModel(BaseModel):
+    def __init__(
+        self, model_name_or_path: str, pooling_mode: Optional[str] = None, **kwargs
+    ):
+        super().__init__(model_name_or_path, **kwargs)
+
+        if not pooling_mode:
+            self.model = SentenceTransformer(
+                self.model_name_or_path,
+                config_kwargs=self.config_kwargs,
+                model_kwargs=self.model_kwargs,
+            )
+        else:
+            word_embedding_model = models.Transformer(
+                self.model_name_or_path,
+                config_args=self.config_kwargs,
+                model_args=self.model_kwargs,
+            )
+            pooling_model = models.Pooling(
+                word_embedding_model.get_word_embedding_dimension(),
+                pooling_mode=pooling_mode,
+            )
+            self.model = SentenceTransformer(
+                modules=[word_embedding_model, pooling_model],
+            )
+
+        self.model.max_seq_length = self.max_seq_length
+
+    def encode(self, texts: Union[str, List[str]], prompt=None, **kwargs) -> List[torch.Tensor]:
+        kwargs.pop('prompt_name', '')  # remove prompt name, use prompt
+        self.encode_kwargs.update(kwargs)
+
+        embeddings = self.model.encode(texts, prompt=prompt, **self.encode_kwargs)
+        assert isinstance(embeddings, Tensor)
+        return embeddings.cpu().detach()
+
+    def encode_queries(self, queries, **kwargs):
+        return self.encode(queries, prompt=self.prompt)
+
+    def encode_corpus(self, corpus, **kwargs):
+        if isinstance(corpus[0], dict):
+            input_texts = ['{} {}'.format(doc.get('title', ''), doc['text']).strip() for doc in corpus]
+        else:
+            input_texts = corpus
+        return self.encode(input_texts)
+
+
+class CrossEncoderModel(BaseModel):
+    def __init__(self, model_name_or_path: str, **kwargs):
+        super().__init__(model_name_or_path, **kwargs)
+        self.model = CrossEncoder(
+            self.model_name_or_path,
+            trust_remote_code=True,
+            max_length=self.max_seq_length,
+        )
+
+    def predict(self, sentences: List[List[str]], **kwargs) -> List[List[float]]:
+        self.encode_kwargs.update(kwargs)
+
+        if len(sentences[0]) == 3:  # Note: For mteb retrieval task
+            processed_sentences = []
+            for query, docs, instruction in sentences:
+                if isinstance(docs, dict):
+                    docs = docs['text']
+                processed_sentences.append((self.prompt + query, docs))
+            sentences = processed_sentences
+        embeddings = self.model.predict(sentences, **self.encode_kwargs)
+        assert isinstance(embeddings, Tensor)
+        return embeddings
+
+
+class EmbeddingModel:
+    """Custom embeddings"""
+
+    @staticmethod
+    def load(
+        model_name_or_path: str = '',
+        is_cross_encoder: bool = False,
+        hub: str = 'modelscope',
+        revision: Optional[str] = 'master',
+        **kwargs,
+    ):
+        # If model path does not exist and hub is 'modelscope', download the model
+        if not os.path.exists(model_name_or_path) and hub == 'modelscope':
+            model_name_or_path = download_model(model_name_or_path, revision)
+
+        # Return different model instances based on whether it is a cross-encoder and pooling mode
+        if is_cross_encoder:
+            return CrossEncoderModel(
+                model_name_or_path,
+                revision=revision,
+                **kwargs,
+            )
+        else:
+            return SentenceTransformerModel(
+                model_name_or_path,
+                revision=revision,
+                **kwargs,
+            )
```
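`EmbeddingModel.load` is the single entry point: it downloads from ModelScope when the path does not exist locally, then returns a `SentenceTransformerModel` (bi-encoder) by default or a `CrossEncoderModel` when `is_cross_encoder=True`. A usage sketch with placeholder model IDs:

```python
from evalscope.backend.rag_eval.utils.embedding import EmbeddingModel

# Bi-encoder: langchain-style embed_* plus mteb-style encode_* methods.
embedder = EmbeddingModel.load(model_name_or_path='AI-ModelScope/bge-large-zh')
doc_vecs = embedder.embed_documents(['今天天气不错', '明天可能下雨'])

# Cross-encoder: scores (query, passage) pairs directly, useful for reranking.
reranker = EmbeddingModel.load(model_name_or_path='AI-ModelScope/bge-reranker-large',
                               is_cross_encoder=True)
scores = reranker.predict([['今天天气如何', '今天天气不错']])
```

Both classes subclass `BaseModel`, so the same object satisfies langchain's `Embeddings` interface (used by the ragas backend) and the `encode_queries`/`encode_corpus` contract of the CMTEB tasks also touched in this release.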