isage-middleware 0.2.4.3__cp311-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isage_middleware-0.2.4.3.dist-info/METADATA +266 -0
- isage_middleware-0.2.4.3.dist-info/RECORD +94 -0
- isage_middleware-0.2.4.3.dist-info/WHEEL +5 -0
- isage_middleware-0.2.4.3.dist-info/top_level.txt +1 -0
- sage/middleware/__init__.py +59 -0
- sage/middleware/_version.py +6 -0
- sage/middleware/components/__init__.py +30 -0
- sage/middleware/components/extensions_compat.py +141 -0
- sage/middleware/components/sage_db/__init__.py +116 -0
- sage/middleware/components/sage_db/backend.py +136 -0
- sage/middleware/components/sage_db/service.py +15 -0
- sage/middleware/components/sage_flow/__init__.py +76 -0
- sage/middleware/components/sage_flow/python/__init__.py +14 -0
- sage/middleware/components/sage_flow/python/micro_service/__init__.py +4 -0
- sage/middleware/components/sage_flow/python/micro_service/sage_flow_service.py +88 -0
- sage/middleware/components/sage_flow/python/sage_flow.py +30 -0
- sage/middleware/components/sage_flow/service.py +14 -0
- sage/middleware/components/sage_mem/__init__.py +83 -0
- sage/middleware/components/sage_sias/__init__.py +59 -0
- sage/middleware/components/sage_sias/continual_learner.py +184 -0
- sage/middleware/components/sage_sias/coreset_selector.py +302 -0
- sage/middleware/components/sage_sias/types.py +94 -0
- sage/middleware/components/sage_tsdb/__init__.py +81 -0
- sage/middleware/components/sage_tsdb/python/__init__.py +21 -0
- sage/middleware/components/sage_tsdb/python/_sage_tsdb.pyi +17 -0
- sage/middleware/components/sage_tsdb/python/algorithms/__init__.py +17 -0
- sage/middleware/components/sage_tsdb/python/algorithms/base.py +51 -0
- sage/middleware/components/sage_tsdb/python/algorithms/out_of_order_join.py +248 -0
- sage/middleware/components/sage_tsdb/python/algorithms/window_aggregator.py +296 -0
- sage/middleware/components/sage_tsdb/python/micro_service/__init__.py +7 -0
- sage/middleware/components/sage_tsdb/python/micro_service/sage_tsdb_service.py +365 -0
- sage/middleware/components/sage_tsdb/python/sage_tsdb.py +523 -0
- sage/middleware/components/sage_tsdb/service.py +17 -0
- sage/middleware/components/vector_stores/__init__.py +25 -0
- sage/middleware/components/vector_stores/chroma.py +483 -0
- sage/middleware/components/vector_stores/chroma_adapter.py +185 -0
- sage/middleware/components/vector_stores/milvus.py +677 -0
- sage/middleware/operators/__init__.py +56 -0
- sage/middleware/operators/agent/__init__.py +24 -0
- sage/middleware/operators/agent/planning/__init__.py +5 -0
- sage/middleware/operators/agent/planning/llm_adapter.py +41 -0
- sage/middleware/operators/agent/planning/planner_adapter.py +98 -0
- sage/middleware/operators/agent/planning/router.py +107 -0
- sage/middleware/operators/agent/runtime.py +296 -0
- sage/middleware/operators/agentic/__init__.py +41 -0
- sage/middleware/operators/agentic/config.py +254 -0
- sage/middleware/operators/agentic/planning_operator.py +125 -0
- sage/middleware/operators/agentic/refined_searcher.py +132 -0
- sage/middleware/operators/agentic/runtime.py +241 -0
- sage/middleware/operators/agentic/timing_operator.py +125 -0
- sage/middleware/operators/agentic/tool_selection_operator.py +127 -0
- sage/middleware/operators/context/__init__.py +17 -0
- sage/middleware/operators/context/critic_evaluation.py +16 -0
- sage/middleware/operators/context/model_context.py +565 -0
- sage/middleware/operators/context/quality_label.py +12 -0
- sage/middleware/operators/context/search_query_results.py +61 -0
- sage/middleware/operators/context/search_result.py +42 -0
- sage/middleware/operators/context/search_session.py +79 -0
- sage/middleware/operators/filters/__init__.py +26 -0
- sage/middleware/operators/filters/context_sink.py +387 -0
- sage/middleware/operators/filters/context_source.py +376 -0
- sage/middleware/operators/filters/evaluate_filter.py +83 -0
- sage/middleware/operators/filters/tool_filter.py +74 -0
- sage/middleware/operators/llm/__init__.py +18 -0
- sage/middleware/operators/llm/sagellm_generator.py +432 -0
- sage/middleware/operators/rag/__init__.py +147 -0
- sage/middleware/operators/rag/arxiv.py +331 -0
- sage/middleware/operators/rag/chunk.py +13 -0
- sage/middleware/operators/rag/document_loaders.py +23 -0
- sage/middleware/operators/rag/evaluate.py +658 -0
- sage/middleware/operators/rag/generator.py +340 -0
- sage/middleware/operators/rag/index_builder/__init__.py +48 -0
- sage/middleware/operators/rag/index_builder/builder.py +363 -0
- sage/middleware/operators/rag/index_builder/manifest.py +101 -0
- sage/middleware/operators/rag/index_builder/storage.py +131 -0
- sage/middleware/operators/rag/pipeline.py +46 -0
- sage/middleware/operators/rag/profiler.py +59 -0
- sage/middleware/operators/rag/promptor.py +400 -0
- sage/middleware/operators/rag/refiner.py +231 -0
- sage/middleware/operators/rag/reranker.py +364 -0
- sage/middleware/operators/rag/retriever.py +1308 -0
- sage/middleware/operators/rag/searcher.py +37 -0
- sage/middleware/operators/rag/types.py +28 -0
- sage/middleware/operators/rag/writer.py +80 -0
- sage/middleware/operators/tools/__init__.py +71 -0
- sage/middleware/operators/tools/arxiv_paper_searcher.py +175 -0
- sage/middleware/operators/tools/arxiv_searcher.py +102 -0
- sage/middleware/operators/tools/duckduckgo_searcher.py +105 -0
- sage/middleware/operators/tools/image_captioner.py +104 -0
- sage/middleware/operators/tools/nature_news_fetcher.py +224 -0
- sage/middleware/operators/tools/searcher_tool.py +514 -0
- sage/middleware/operators/tools/text_detector.py +185 -0
- sage/middleware/operators/tools/url_text_extractor.py +104 -0
- sage/middleware/py.typed +2 -0
|
@@ -0,0 +1,364 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from transformers import (
|
|
3
|
+
AutoModelForCausalLM,
|
|
4
|
+
AutoModelForSequenceClassification,
|
|
5
|
+
AutoTokenizer,
|
|
6
|
+
)
|
|
7
|
+
|
|
8
|
+
from sage.common.core.functions import MapFunction as MapOperator
|
|
9
|
+
from sage.libs.rag.types import (
|
|
10
|
+
RAGInput,
|
|
11
|
+
RAGResponse,
|
|
12
|
+
create_rag_response,
|
|
13
|
+
extract_query,
|
|
14
|
+
extract_results,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class BGEReranker(MapOperator):
    """
    Rerank retrieved documents with a cross-encoder sequence-classification model
    (e.g. ``BAAI/bge-reranker-v2-m3``).

    Every (query, document) pair is scored by the model. Documents are sorted by
    score in descending order and the best ``top_k`` are kept.

    Input: a ``RAGInput``; the query and the retrieved documents are obtained via
    ``extract_query`` / ``extract_results``.
    Output: a ``RAGResponse`` built by ``create_rag_response(query, docs)``.

    Attributes:
        config: Configuration dictionary. Must contain ``"model_name"``; the
            optional ``"topk"`` / ``"top_k"`` keys select how many documents to keep.
        device: ``"cuda"`` if a GPU is available, otherwise ``"cpu"``.
        tokenizer: Tokenizer used to encode query-document pairs.
        model: The reranking model, moved to ``device`` and set to eval mode.

    ``self.logger`` is not assigned here; presumably it is provided by
    ``MapOperator``.
    """

    def __init__(self, config, **kwargs):
        """
        Initialize the reranker and load its model.

        :param config: Configuration dict; ``config["model_name"]`` names the
            Hugging Face model to load.
        :param kwargs: Forwarded unchanged to ``MapOperator.__init__``.
        """
        super().__init__(**kwargs)
        self.config = config
        # Prefer the GPU when one is available.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.tokenizer, self.model = self._load_model(self.config["model_name"])
        self.model = self.model.to(self.device)
        self.model.eval()  # inference only: disable dropout and similar training behavior

    def _load_model(self, model_name: str):
        """
        Load the tokenizer and sequence-classification model for the reranker.

        :param model_name: Name or path of the pre-trained model.
        :return: ``(tokenizer, model)`` tuple.
        :raises RuntimeError: If either the tokenizer or the model fails to load.
        """
        try:
            self.logger.info(f"Loading reranker: {model_name}")
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForSequenceClassification.from_pretrained(model_name)
            return tokenizer, model
        except Exception as e:
            self.logger.error(f"Failed to load model {model_name}: {str(e)}")
            raise RuntimeError(f"Model loading failed: {str(e)}") from e

    def execute(self, data: RAGInput) -> RAGResponse:
        """
        Rerank the documents carried by ``data``.

        Steps:
        1. Extract the query and the document list.
        2. Build query-document pairs.
        3. Score every pair with the model (no autograd).
        4. Sort by score (descending) and keep the top ``top_k``.

        :param data: RAGInput - standardized input format.
        :return: RAGResponse containing {"query": str, "results": List[str]}.
            An empty query yields ``create_rag_response("", [])``; an empty
            document set yields ``create_rag_response(query, [])``.
        :raises RuntimeError: Wrapping any error raised while reranking.
        """
        try:
            # Extract the data using the standardized helper functions.
            query = extract_query(data)
            doc_set = extract_results(data)

            if not query:
                self.logger.error("Missing 'query' field in input")
                return create_rag_response("", [])

            # "topk" takes precedence; otherwise fall back to "top_k", then 3.
            top_k = self.config.get("topk") or self.config.get("top_k", 3)

            if not doc_set:
                self.logger.info(
                    "BGEReranker received empty document set, returning empty results"
                )
                return create_rag_response(query, [])

            pairs = [(query, doc) for doc in doc_set]
            raw_inputs = self.tokenizer(
                pairs,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
            inputs = {
                k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                for k, v in raw_inputs.items()
            }

            # Pure inference: skip autograd bookkeeping to save memory and time.
            with torch.no_grad():
                logits = self.model(**inputs).logits
            # One score per pair; convert to plain floats for sorting/logging.
            scores = logits.view(-1).float().tolist()

            scored_docs = [
                {"text": doc, "relevance_score": score}
                for doc, score in zip(doc_set, scores, strict=False)
            ]
            reranked_docs = sorted(
                scored_docs, key=lambda x: x["relevance_score"], reverse=True
            )[:top_k]
            reranked_docs_list = [doc["text"] for doc in reranked_docs]

            self.logger.info(
                f"\033[32m[ {self.__class__.__name__}]: Rerank Results: {reranked_docs_list}\033[0m "
            )
            self.logger.debug(
                f"Top score: {reranked_docs[0]['relevance_score'] if reranked_docs else 'N/A'}"
            )

        except Exception as e:
            raise RuntimeError(f"BGEReranker error: {str(e)}") from e

        # Always return the standardized response format.
        return create_rag_response(query, reranked_docs_list)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
class LLMbased_Reranker(MapOperator):
    """
    Rerank retrieved documents with a causal-LM reranker
    (default ``BAAI/bge-reranker-v2-gemma``).

    Each (query, passage) pair is wrapped in a prompt asking whether the passage
    answers the query with 'Yes' or 'No'. The pair's score is the model's logit
    for the 'Yes' token at the final sequence position. Documents are sorted by
    that score in descending order and the best ``top_k`` are kept.

    Input: a ``RAGInput``; the query and the retrieved documents are obtained via
    ``extract_query`` / ``extract_results``.
    Output: a ``RAGResponse`` built by ``create_rag_response(query, docs)``.

    Attributes:
        config: Configuration dictionary; ``"topk"`` (or ``"top_k"``) selects how
            many documents to keep.
        device: ``"cuda"`` if a GPU is available, otherwise ``"cpu"``.
        tokenizer: Tokenizer used to build the prompted inputs.
        model: The causal LM used for scoring, moved to ``device``.
        yes_loc: Token ID of 'Yes' (first sub-token), whose logit is the score.

    ``self.logger`` is not assigned here; presumably it is provided by
    ``MapOperator``.
    """

    def __init__(self, config, model_name: str = "BAAI/bge-reranker-v2-gemma", **kwargs):
        """
        Initialize the reranker and load its model.

        :param config: Configuration dict (see class docstring). Note that the
            model is chosen by ``model_name``, not by ``config``.
        :param model_name: Name of the pre-trained causal LM to load
            (default ``"BAAI/bge-reranker-v2-gemma"``).
        :param kwargs: Forwarded unchanged to ``MapOperator.__init__``, as
            ``BGEReranker`` does.
        """
        super().__init__(**kwargs)
        self.config = config
        # Prefer the GPU when one is available.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.tokenizer, self.model = self._load_model(model_name)
        self.model = self.model.to(self.device)  # type: ignore[arg-type]
        self.model.eval()  # inference only: disable dropout and similar training behavior

        # The 'Yes' logit at the last position is used as the relevance score.
        self.yes_loc = self.tokenizer("Yes", add_special_tokens=False)["input_ids"][0]

    def _load_model(self, model_name: str):
        """
        Load the tokenizer and causal-LM model for the reranker.

        :param model_name: Name or path of the pre-trained model.
        :return: ``(tokenizer, model)`` tuple.
        :raises RuntimeError: If either the tokenizer or the model fails to load.
        """
        try:
            self.logger.info(f"Loading reranker: {model_name}")
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForCausalLM.from_pretrained(model_name)
            return tokenizer, model
        except Exception as e:
            self.logger.error(f"Failed to load model {model_name}: {str(e)}")
            raise RuntimeError(f"Model loading failed: {str(e)}") from e

    def get_inputs(self, pairs, tokenizer, prompt=None, max_length=1024):
        """
        Build padded model inputs for a batch of query-passage pairs.

        Each sequence is laid out as
        ``<bos> "A: {query}" \\n "B: {passage}" \\n {prompt}``.

        :param pairs: Iterable of ``(query, passage)`` pairs.
        :param tokenizer: Tokenizer used to encode the text.
        :param prompt: Instruction appended after each pair; defaults to a
            'Yes'/'No' answerability prompt.
        :param max_length: Token budget for the query+passage part; the separator
            and prompt are appended on top of it.
        :return: The tokenizer's padded batch with ``return_tensors="pt"``.
        """
        if prompt is None:
            prompt = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'."

        sep = "\n"
        prompt_inputs = tokenizer(prompt, return_tensors=None, add_special_tokens=False)[
            "input_ids"
        ]
        sep_inputs = tokenizer(sep, return_tensors=None, add_special_tokens=False)["input_ids"]

        inputs = []
        for query, passage in pairs:
            # The query may use at most 3/4 of the budget, leaving room for the passage.
            query_inputs = tokenizer(
                f"A: {query}",
                return_tensors=None,
                add_special_tokens=False,
                max_length=max_length * 3 // 4,
                truncation=True,
            )
            passage_inputs = tokenizer(
                f"B: {passage}",
                return_tensors=None,
                add_special_tokens=False,
                max_length=max_length,
                truncation=True,
            )

            # Combine query and passage, truncating only the passage to fit max_length.
            item = tokenizer.prepare_for_model(
                [tokenizer.bos_token_id] + query_inputs["input_ids"],
                sep_inputs + passage_inputs["input_ids"],
                truncation="only_second",
                max_length=max_length,
                padding=False,
                return_attention_mask=False,
                return_token_type_ids=False,
                add_special_tokens=False,
            )
            # Append the prompt after truncation so the instruction is never cut off.
            item["input_ids"] = item["input_ids"] + sep_inputs + prompt_inputs
            item["attention_mask"] = [1] * len(item["input_ids"])
            inputs.append(item)

        return tokenizer.pad(
            inputs,
            padding=True,
            max_length=max_length + len(sep_inputs) + len(prompt_inputs),
            pad_to_multiple_of=8,
            return_tensors="pt",
        )

    def execute(self, data: RAGInput) -> RAGResponse:
        """
        Rerank the documents carried by ``data``.

        Steps:
        1. Extract the query and the document list.
        2. Build query-document pairs and the prompted model inputs.
        3. Score each pair by the 'Yes' logit at the final position (no autograd).
        4. Sort by score (descending) and keep the top ``top_k``.

        :param data: RAGInput - standardized input format.
        :return: RAGResponse containing {"query": str, "results": List[str]}.
            An empty query yields ``create_rag_response("", [])``; an empty
            document set yields ``create_rag_response(query, [])``.
        :raises RuntimeError: Wrapping any error raised while reranking.
        """
        try:
            # Extract the data using the standardized helper functions.
            query = extract_query(data)
            retrieved_docs = extract_results(data)

            if not query:
                self.logger.error("Missing 'query' field in input")
                return create_rag_response("", [])

            # "topk" takes precedence when set; otherwise fall back to "top_k", then 3.
            top_k = self.config.get("topk")
            if top_k is None:
                top_k = self.config.get("top_k", 3)

            # tokenizer.pad() cannot handle an empty batch, so short-circuit here.
            if not retrieved_docs:
                self.logger.info(
                    "LLMbased_Reranker received empty document set, returning empty results"
                )
                return create_rag_response(query, [])

            pairs = [[query, doc] for doc in retrieved_docs]

            with torch.no_grad():
                raw_inputs = self.get_inputs(pairs, self.tokenizer)
                inputs = {k: v.to(self.device) for k, v in raw_inputs.items()}
                # NOTE(review): indexing position -1 assumes the tokenizer pads on the
                # left (as in the model's reference usage) - confirm for other models.
                scores = (
                    self.model(**inputs, return_dict=True)
                    .logits[:, -1, self.yes_loc]
                    .view(-1)
                    .float()
                    .tolist()
                )

            scored_docs = [
                {"text": doc, "relevance_score": score}
                for doc, score in zip(retrieved_docs, scores, strict=False)
            ]
            reranked_docs = sorted(
                scored_docs, key=lambda x: x["relevance_score"], reverse=True
            )[:top_k]
            reranked_docs_list = [doc["text"] for doc in reranked_docs]

            self.logger.info(
                f"\033[32m[ {self.__class__.__name__}]: Rerank Results: {reranked_docs_list}\033[0m "
            )
            self.logger.debug(
                f"Top score: {reranked_docs[0]['relevance_score'] if reranked_docs else 'N/A'}"
            )

        except Exception as e:
            self.logger.error(f"{str(e)} when RerankerFuncton")
            raise RuntimeError(f"Reranker error: {str(e)}") from e

        # Always return the standardized response format.
        return create_rag_response(query, reranked_docs_list)
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
# if __name__ == '__main__':
|
|
326
|
+
|
|
327
|
+
# # 设置配置
|
|
328
|
+
# config1 = {
|
|
329
|
+
# "reranker": {
|
|
330
|
+
# "model_name":"BAAI/bge-reranker-v2-m3",
|
|
331
|
+
# "top_k": 3
|
|
332
|
+
# }
|
|
333
|
+
# }
|
|
334
|
+
|
|
335
|
+
# config2 = {
|
|
336
|
+
# "reranker": {
|
|
337
|
+
# "model_name":"BAAI/bge-reranker-v2-gemma",
|
|
338
|
+
# "top_k": 3
|
|
339
|
+
# }
|
|
340
|
+
# }
|
|
341
|
+
|
|
342
|
+
# # 创建实例
|
|
343
|
+
# # reranker = BGEReranker(config)
|
|
344
|
+
# reranker = LLMbased_Reranker(config2)
|
|
345
|
+
# # 测试数据
|
|
346
|
+
# query = "What is the capital of France?"
|
|
347
|
+
# docs = [
|
|
348
|
+
# "Paris is the capital of France.",
|
|
349
|
+
# "Berlin is a city in Germany.",
|
|
350
|
+
# "The Eiffel Tower is located in Paris.",
|
|
351
|
+
# "France is a country in Western Europe.",
|
|
352
|
+
# "Madrid is the capital of Spain."
|
|
353
|
+
# ]
|
|
354
|
+
|
|
355
|
+
# # 执行重排
|
|
356
|
+
# input_data = (query, docs)
|
|
357
|
+
# output = reranker.execute(input_data)
|
|
358
|
+
|
|
359
|
+
# # 输出结果
|
|
360
|
+
# result_query, result_docs = output
|
|
361
|
+
# print("Query:", result_query)
|
|
362
|
+
# print("Top-k Re-ranked Documents:")
|
|
363
|
+
# for i, doc in enumerate(result_docs, 1):
|
|
364
|
+
# print(f"{i}. {doc}")
|