isage-middleware 0.2.4.3 (cp311-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. isage_middleware-0.2.4.3.dist-info/METADATA +266 -0
  2. isage_middleware-0.2.4.3.dist-info/RECORD +94 -0
  3. isage_middleware-0.2.4.3.dist-info/WHEEL +5 -0
  4. isage_middleware-0.2.4.3.dist-info/top_level.txt +1 -0
  5. sage/middleware/__init__.py +59 -0
  6. sage/middleware/_version.py +6 -0
  7. sage/middleware/components/__init__.py +30 -0
  8. sage/middleware/components/extensions_compat.py +141 -0
  9. sage/middleware/components/sage_db/__init__.py +116 -0
  10. sage/middleware/components/sage_db/backend.py +136 -0
  11. sage/middleware/components/sage_db/service.py +15 -0
  12. sage/middleware/components/sage_flow/__init__.py +76 -0
  13. sage/middleware/components/sage_flow/python/__init__.py +14 -0
  14. sage/middleware/components/sage_flow/python/micro_service/__init__.py +4 -0
  15. sage/middleware/components/sage_flow/python/micro_service/sage_flow_service.py +88 -0
  16. sage/middleware/components/sage_flow/python/sage_flow.py +30 -0
  17. sage/middleware/components/sage_flow/service.py +14 -0
  18. sage/middleware/components/sage_mem/__init__.py +83 -0
  19. sage/middleware/components/sage_sias/__init__.py +59 -0
  20. sage/middleware/components/sage_sias/continual_learner.py +184 -0
  21. sage/middleware/components/sage_sias/coreset_selector.py +302 -0
  22. sage/middleware/components/sage_sias/types.py +94 -0
  23. sage/middleware/components/sage_tsdb/__init__.py +81 -0
  24. sage/middleware/components/sage_tsdb/python/__init__.py +21 -0
  25. sage/middleware/components/sage_tsdb/python/_sage_tsdb.pyi +17 -0
  26. sage/middleware/components/sage_tsdb/python/algorithms/__init__.py +17 -0
  27. sage/middleware/components/sage_tsdb/python/algorithms/base.py +51 -0
  28. sage/middleware/components/sage_tsdb/python/algorithms/out_of_order_join.py +248 -0
  29. sage/middleware/components/sage_tsdb/python/algorithms/window_aggregator.py +296 -0
  30. sage/middleware/components/sage_tsdb/python/micro_service/__init__.py +7 -0
  31. sage/middleware/components/sage_tsdb/python/micro_service/sage_tsdb_service.py +365 -0
  32. sage/middleware/components/sage_tsdb/python/sage_tsdb.py +523 -0
  33. sage/middleware/components/sage_tsdb/service.py +17 -0
  34. sage/middleware/components/vector_stores/__init__.py +25 -0
  35. sage/middleware/components/vector_stores/chroma.py +483 -0
  36. sage/middleware/components/vector_stores/chroma_adapter.py +185 -0
  37. sage/middleware/components/vector_stores/milvus.py +677 -0
  38. sage/middleware/operators/__init__.py +56 -0
  39. sage/middleware/operators/agent/__init__.py +24 -0
  40. sage/middleware/operators/agent/planning/__init__.py +5 -0
  41. sage/middleware/operators/agent/planning/llm_adapter.py +41 -0
  42. sage/middleware/operators/agent/planning/planner_adapter.py +98 -0
  43. sage/middleware/operators/agent/planning/router.py +107 -0
  44. sage/middleware/operators/agent/runtime.py +296 -0
  45. sage/middleware/operators/agentic/__init__.py +41 -0
  46. sage/middleware/operators/agentic/config.py +254 -0
  47. sage/middleware/operators/agentic/planning_operator.py +125 -0
  48. sage/middleware/operators/agentic/refined_searcher.py +132 -0
  49. sage/middleware/operators/agentic/runtime.py +241 -0
  50. sage/middleware/operators/agentic/timing_operator.py +125 -0
  51. sage/middleware/operators/agentic/tool_selection_operator.py +127 -0
  52. sage/middleware/operators/context/__init__.py +17 -0
  53. sage/middleware/operators/context/critic_evaluation.py +16 -0
  54. sage/middleware/operators/context/model_context.py +565 -0
  55. sage/middleware/operators/context/quality_label.py +12 -0
  56. sage/middleware/operators/context/search_query_results.py +61 -0
  57. sage/middleware/operators/context/search_result.py +42 -0
  58. sage/middleware/operators/context/search_session.py +79 -0
  59. sage/middleware/operators/filters/__init__.py +26 -0
  60. sage/middleware/operators/filters/context_sink.py +387 -0
  61. sage/middleware/operators/filters/context_source.py +376 -0
  62. sage/middleware/operators/filters/evaluate_filter.py +83 -0
  63. sage/middleware/operators/filters/tool_filter.py +74 -0
  64. sage/middleware/operators/llm/__init__.py +18 -0
  65. sage/middleware/operators/llm/sagellm_generator.py +432 -0
  66. sage/middleware/operators/rag/__init__.py +147 -0
  67. sage/middleware/operators/rag/arxiv.py +331 -0
  68. sage/middleware/operators/rag/chunk.py +13 -0
  69. sage/middleware/operators/rag/document_loaders.py +23 -0
  70. sage/middleware/operators/rag/evaluate.py +658 -0
  71. sage/middleware/operators/rag/generator.py +340 -0
  72. sage/middleware/operators/rag/index_builder/__init__.py +48 -0
  73. sage/middleware/operators/rag/index_builder/builder.py +363 -0
  74. sage/middleware/operators/rag/index_builder/manifest.py +101 -0
  75. sage/middleware/operators/rag/index_builder/storage.py +131 -0
  76. sage/middleware/operators/rag/pipeline.py +46 -0
  77. sage/middleware/operators/rag/profiler.py +59 -0
  78. sage/middleware/operators/rag/promptor.py +400 -0
  79. sage/middleware/operators/rag/refiner.py +231 -0
  80. sage/middleware/operators/rag/reranker.py +364 -0
  81. sage/middleware/operators/rag/retriever.py +1308 -0
  82. sage/middleware/operators/rag/searcher.py +37 -0
  83. sage/middleware/operators/rag/types.py +28 -0
  84. sage/middleware/operators/rag/writer.py +80 -0
  85. sage/middleware/operators/tools/__init__.py +71 -0
  86. sage/middleware/operators/tools/arxiv_paper_searcher.py +175 -0
  87. sage/middleware/operators/tools/arxiv_searcher.py +102 -0
  88. sage/middleware/operators/tools/duckduckgo_searcher.py +105 -0
  89. sage/middleware/operators/tools/image_captioner.py +104 -0
  90. sage/middleware/operators/tools/nature_news_fetcher.py +224 -0
  91. sage/middleware/operators/tools/searcher_tool.py +514 -0
  92. sage/middleware/operators/tools/text_detector.py +185 -0
  93. sage/middleware/operators/tools/url_text_extractor.py +104 -0
  94. sage/middleware/py.typed +2 -0
sage/middleware/operators/rag/reranker.py
@@ -0,0 +1,364 @@
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

from sage.common.core.functions import MapFunction as MapOperator
from sage.libs.rag.types import (
    RAGInput,
    RAGResponse,
    create_rag_response,
    extract_query,
    extract_results,
)


class BGEReranker(MapOperator):
    """
    A reranker that uses the BAAI/bge-reranker-v2-m3 model to reorder a list of retrieved documents.
    The model assigns relevance scores to the documents and ranks them accordingly.

    Input: RAGInput containing the query and the retrieved documents
    Output: RAGResponse containing the query and the top-k reranked documents

    Attributes:
        logger: Logger for logging error and information messages.
        config: Configuration dictionary containing reranker settings (model name, top_k, etc.).
        device: Device ('cuda' or 'cpu') where the model will be loaded.
        tokenizer: Tokenizer used to preprocess input queries and documents.
        model: The pre-trained reranking model.
    """

    def __init__(self, config, **kwargs):
        """
        Initializes the BGEReranker with configuration settings and loads the model.

        :param config: Dictionary containing configuration options, including model name and device settings.
        """
        super().__init__(**kwargs)
        self.config = config
        self.device = (
            "cuda" if torch.cuda.is_available() else "cpu"
        )  # Set device to GPU if available, otherwise CPU

        # Load tokenizer and model using the provided model name
        self.tokenizer, self.model = self._load_model(self.config["model_name"])
        self.model = self.model.to(self.device)
        self.model.eval()  # Set the model to evaluation mode

    def _load_model(self, model_name: str):
        """
        Loads the tokenizer and model for the reranker.

        :param model_name: Name of the pre-trained model to load.
        :return: Tuple containing the tokenizer and the model.
        """
        try:
            self.logger.info(f"Loading reranker: {model_name}")
            tokenizer = AutoTokenizer.from_pretrained(model_name)  # Load the tokenizer
            model = AutoModelForSequenceClassification.from_pretrained(model_name)  # Load the model
            return tokenizer, model
        except Exception as e:
            self.logger.error(f"Failed to load model {model_name}: {str(e)}")
            raise RuntimeError(f"Model loading failed: {str(e)}")

    def execute(self, data: RAGInput) -> RAGResponse:
        """
        Executes the reranking process:
        1. Unpacks the input data (query and list of documents).
        2. Generates query-document pairs.
        3. Calculates relevance scores using the model.
        4. Sorts documents based on their relevance scores.

        :param data: RAGInput - standardized input format
        :return: RAGResponse containing {"query": str, "results": List[str]}
        """
        try:
            # Extract the query and documents via the standardized helpers
            query = extract_query(data)
            doc_set = extract_results(data)

            if not query:
                self.logger.error("Missing 'query' field in input")
                return create_rag_response("", [])

            top_k = self.config.get("topk") or self.config.get(
                "top_k", 3
            )  # Get the top-k parameter for reranking

            # Handle empty document set case
            if not doc_set:
                print("BGEReranker received empty document set, returning empty results")
                # Return the standardized response format
                return create_rag_response(query, [])

            # Generate query-document pairs for scoring
            pairs = [(query, doc) for doc in doc_set]

            # Tokenize the pairs and move inputs to the appropriate device
            raw_inputs = self.tokenizer(
                pairs,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
            inputs = {
                k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                for k, v in raw_inputs.items()
            }

            # Perform inference and calculate scores
            scores = self.model(**inputs).logits.view(-1).float()

            # Create a list of scored documents
            scored_docs = [
                {"text": doc, "relevance_score": score}
                for doc, score in zip(doc_set, scores, strict=False)
            ]

            # Sort the documents by relevance score in descending order
            reranked_docs = sorted(scored_docs, key=lambda x: x["relevance_score"], reverse=True)[
                :top_k
            ]
            reranked_docs_list = [doc["text"] for doc in reranked_docs]
            self.logger.info(
                f"\033[32m[{self.__class__.__name__}]: Rerank Results: {reranked_docs_list}\033[0m"
            )
            self.logger.debug(
                f"Top score: {reranked_docs[0]['relevance_score'] if reranked_docs else 'N/A'}"
            )

            print(f"Rerank Results: {reranked_docs_list}")

        except Exception as e:
            raise RuntimeError(f"BGEReranker error: {str(e)}")

        # Return the standardized response format
        return create_rag_response(query, reranked_docs_list)


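# A minimal standalone sketch of the cross-encoder scoring step used in
# BGEReranker.execute() above, written against the plain transformers API
# (the model name is taken from the class docstring; the helper name is
# illustrative and not part of the package). It assumes the model weights can
# be downloaded and is not wired into the SAGE operator framework.
def _sketch_bge_cross_encoder_rerank(query, docs, top_k=3):
    model_name = "BAAI/bge-reranker-v2-m3"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name).eval()

    # Tokenize each (query, doc) pair jointly; the model emits one relevance logit per pair.
    pairs = [(query, d) for d in docs]
    with torch.no_grad():
        inputs = tokenizer(pairs, padding=True, truncation=True, max_length=512, return_tensors="pt")
        scores = model(**inputs).logits.view(-1).float()

    # Highest score first, keep the top_k documents.
    ranked = sorted(zip(scores.tolist(), docs), key=lambda p: p[0], reverse=True)
    return [doc for _, doc in ranked[:top_k]]

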
class LLMbased_Reranker(MapOperator):
    """
    A reranker that uses the BAAI/bge-reranker-v2-gemma model to determine whether a retrieved document contains an answer to a given query.
    It scores the documents with 'Yes' or 'No' predictions based on whether the document answers the query.

    Input: RAGInput containing the query and the retrieved documents
    Output: RAGResponse containing the query and the top-k reranked documents

    Attributes:
        logger: Logger for logging error and information messages.
        config: Configuration dictionary containing reranker settings (model name, top_k, etc.).
        device: Device ('cuda' or 'cpu') where the model will be loaded.
        tokenizer: Tokenizer used to preprocess input queries and documents.
        model: The pre-trained reranking model.
        yes_loc: Token ID representing 'Yes' (used for scoring).
    """

    def __init__(self, config, model_name: str = "BAAI/bge-reranker-v2-gemma"):
        """
        Initializes the LLMbased_Reranker with configuration settings and loads the model.

        :param config: Dictionary containing configuration options, including model name and device settings.
        :param model_name: Name of the pre-trained model to load (default is "BAAI/bge-reranker-v2-gemma").
        """
        super().__init__()
        self.config = config
        self.device = (
            "cuda" if torch.cuda.is_available() else "cpu"
        )  # Set device to GPU if available, otherwise CPU

        # Load tokenizer and model using the provided model name
        self.tokenizer, self.model = self._load_model(model_name)
        self.model = self.model.to(self.device)  # type: ignore[arg-type]

        # Get the token ID for the 'Yes' token (used for classification)
        self.yes_loc = self.tokenizer("Yes", add_special_tokens=False)["input_ids"][0]

    def _load_model(self, model_name: str):
        """
        Loads the tokenizer and model for the reranker.

        :param model_name: Name of the pre-trained model to load.
        :return: Tuple containing the tokenizer and the model.
        """
        try:
            self.logger.info(f"Loading reranker: {model_name}")
            tokenizer = AutoTokenizer.from_pretrained(model_name)  # Load the tokenizer
            model = AutoModelForCausalLM.from_pretrained(model_name)  # Load the model
            return tokenizer, model
        except Exception as e:
            self.logger.error(f"Failed to load model {model_name}: {str(e)}")
            raise RuntimeError(f"Model loading failed: {str(e)}")

    def get_inputs(self, pairs, tokenizer, prompt=None, max_length=1024):
        """
        Prepares the input for the model, including the prompt and the query-document pairs.

        :param pairs: List of query-document pairs.
        :param tokenizer: The tokenizer used to process the input data.
        :param prompt: Optional prompt to guide the model (defaults to a generic query-passage prompt).
        :param max_length: Maximum length of the tokenized input sequences.
        :return: A tensor of tokenized inputs, ready for model inference.
        """
        if prompt is None:
            prompt = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'."

        sep = "\n"
        prompt_inputs = tokenizer(prompt, return_tensors=None, add_special_tokens=False)[
            "input_ids"
        ]
        sep_inputs = tokenizer(sep, return_tensors=None, add_special_tokens=False)["input_ids"]

        inputs = []
        for query, passage in pairs:
            query_inputs = tokenizer(
                f"A: {query}",
                return_tensors=None,
                add_special_tokens=False,
                max_length=max_length * 3 // 4,
                truncation=True,
            )
            passage_inputs = tokenizer(
                f"B: {passage}",
                return_tensors=None,
                add_special_tokens=False,
                max_length=max_length,
                truncation=True,
            )

            item = tokenizer.prepare_for_model(
                [tokenizer.bos_token_id] + query_inputs["input_ids"],
                sep_inputs + passage_inputs["input_ids"],
                truncation="only_second",
                max_length=max_length,
                padding=False,
                return_attention_mask=False,
                return_token_type_ids=False,
                add_special_tokens=False,
            )
            item["input_ids"] = item["input_ids"] + sep_inputs + prompt_inputs
            item["attention_mask"] = [1] * len(item["input_ids"])
            inputs.append(item)

        return tokenizer.pad(
            inputs,
            padding=True,
            max_length=max_length + len(sep_inputs) + len(prompt_inputs),
            pad_to_multiple_of=8,
            return_tensors="pt",
        )

    # @torch.inference_mode()
    def execute(self, data: RAGInput) -> RAGResponse:
        """
        Executes the reranking process:
        1. Unpacks the input data (query and list of documents).
        2. Generates query-document pairs for classification.
        3. Calculates relevance scores based on 'Yes'/'No' predictions.
        4. Sorts documents based on their relevance scores.

        :param data: RAGInput - standardized input format
        :return: RAGResponse containing {"query": str, "results": List[str]}
        """
        try:
            # Extract the query and documents via the standardized helpers
            query = extract_query(data)
            doc_set = extract_results(data)

            if not query:
                self.logger.error("Missing 'query' field in input")
                return create_rag_response("", [])

            doc_set = [doc_set]  # Wrap doc_set in a list for processing
            top_k = self.config["topk"]  # Get the top-k parameter for reranking
            emit_docs = []  # Initialize the list to store reranked documents

            for retrieved_docs in doc_set:
                # Generate query-document pairs for classification
                pairs = [[query, doc] for doc in retrieved_docs]

                # Tokenize the pairs and move inputs to the appropriate device
                with torch.no_grad():
                    raw_inputs = self.get_inputs(pairs, self.tokenizer)
                    inputs = {k: v.to(self.device) for k, v in raw_inputs.items()}

                    scores = (
                        self.model(**inputs, return_dict=True)
                        .logits[:, -1, self.yes_loc]
                        .view(-1)
                        .float()
                    )

                # Create a list of scored documents
                scored_docs = [
                    {"text": doc, "relevance_score": score}
                    for doc, score in zip(retrieved_docs, scores, strict=False)
                ]

                # Sort the documents by relevance score in descending order
                reranked_docs = sorted(
                    scored_docs, key=lambda x: x["relevance_score"], reverse=True
                )[:top_k]
                reranked_docs_list = [doc["text"] for doc in reranked_docs]
                emit_docs.append(reranked_docs_list)
                self.logger.info(
                    f"\033[32m[{self.__class__.__name__}]: Rerank Results: {reranked_docs_list}\033[0m"
                )
                self.logger.debug(
                    f"Top score: {reranked_docs[0]['relevance_score'] if reranked_docs else 'N/A'}"
                )

        except Exception as e:
            self.logger.error(f"{str(e)} in LLMbased_Reranker.execute")
            raise RuntimeError(f"Reranker error: {str(e)}")

        emit_docs = emit_docs[0]  # Only return the first set of reranked documents

        # Return the standardized response format
        return create_rag_response(query, emit_docs)


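# A stripped-down sketch of the 'Yes'-token scoring rule used by
# LLMbased_Reranker.execute() above: the relevance of a (query, passage) pair
# is the logit the causal LM assigns to the token "Yes" at the final position.
# Illustration only (the helper name is not part of the package); it builds
# the prompt with plain string formatting instead of the prepare_for_model /
# padding logic in get_inputs(), and assumes the default model can be downloaded.
def _sketch_yes_token_score(query, passage, model_name="BAAI/bge-reranker-v2-gemma"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name).eval()

    yes_id = tokenizer("Yes", add_special_tokens=False)["input_ids"][0]
    prompt = (
        "Given a query A and a passage B, determine whether the passage contains "
        "an answer to the query by providing a prediction of either 'Yes' or 'No'."
    )
    text = f"A: {query}\nB: {passage}\n{prompt}"

    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
    with torch.no_grad():
        logits = model(**inputs).logits  # (1, seq_len, vocab_size)

    # A higher 'Yes' logit at the final position means the passage more likely answers the query.
    return logits[0, -1, yes_id].item()

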
# if __name__ == '__main__':

#     # Set up configurations
#     config1 = {
#         "reranker": {
#             "model_name": "BAAI/bge-reranker-v2-m3",
#             "top_k": 3
#         }
#     }

#     config2 = {
#         "reranker": {
#             "model_name": "BAAI/bge-reranker-v2-gemma",
#             "top_k": 3
#         }
#     }

#     # Create an instance
#     # reranker = BGEReranker(config)
#     reranker = LLMbased_Reranker(config2)
#     # Test data
#     query = "What is the capital of France?"
#     docs = [
#         "Paris is the capital of France.",
#         "Berlin is a city in Germany.",
#         "The Eiffel Tower is located in Paris.",
#         "France is a country in Western Europe.",
#         "Madrid is the capital of Spain."
#     ]

#     # Run the reranker
#     input_data = (query, docs)
#     output = reranker.execute(input_data)

#     # Print the results
#     result_query, result_docs = output
#     print("Query:", result_query)
#     print("Top-k Re-ranked Documents:")
#     for i, doc in enumerate(result_docs, 1):
#         print(f"{i}. {doc}")
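

# The commented demo above still uses the old tuple interface
# ((query, docs) in, (query, docs) out). With the current execute() signature
# the call would look roughly like the sketch below, assuming RAGInput and
# RAGResponse are the dict shapes described in the docstrings
# ({"query": str, "results": List[str]}), and noting that __init__ reads
# config["model_name"] / config["topk"] directly, i.e. a flat config rather
# than the nested "reranker" block used in the demo above.
#
# if __name__ == "__main__":
#     config = {"model_name": "BAAI/bge-reranker-v2-m3", "topk": 3}
#     reranker = BGEReranker(config)
#
#     output = reranker.execute({
#         "query": "What is the capital of France?",
#         "results": [
#             "Paris is the capital of France.",
#             "Berlin is a city in Germany.",
#             "Madrid is the capital of Spain.",
#         ],
#     })
#
#     print("Query:", output["query"])
#     print("Top-k Re-ranked Documents:")
#     for i, doc in enumerate(output["results"], 1):
#         print(f"{i}. {doc}")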