isage-rag-benchmark 0.1.0.1__cp311-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. isage_rag_benchmark-0.1.0.1.dist-info/METADATA +63 -0
  2. isage_rag_benchmark-0.1.0.1.dist-info/RECORD +48 -0
  3. isage_rag_benchmark-0.1.0.1.dist-info/WHEEL +5 -0
  4. isage_rag_benchmark-0.1.0.1.dist-info/licenses/LICENSE +21 -0
  5. isage_rag_benchmark-0.1.0.1.dist-info/top_level.txt +1 -0
  6. sage/__init__.py +0 -0
  7. sage/benchmark_rag/__init__.py +16 -0
  8. sage/benchmark_rag/_version.py +4 -0
  9. sage/benchmark_rag/config/config_bm25s.yaml +51 -0
  10. sage/benchmark_rag/config/config_dense_milvus.yaml +61 -0
  11. sage/benchmark_rag/config/config_hf.yaml +43 -0
  12. sage/benchmark_rag/config/config_mixed.yaml +53 -0
  13. sage/benchmark_rag/config/config_monitoring_demo.yaml +59 -0
  14. sage/benchmark_rag/config/config_multiplex.yaml +79 -0
  15. sage/benchmark_rag/config/config_qa_chroma.yaml +51 -0
  16. sage/benchmark_rag/config/config_ray.yaml +57 -0
  17. sage/benchmark_rag/config/config_refiner.yaml +75 -0
  18. sage/benchmark_rag/config/config_rerank.yaml +56 -0
  19. sage/benchmark_rag/config/config_selfrag.yaml +24 -0
  20. sage/benchmark_rag/config/config_source.yaml +30 -0
  21. sage/benchmark_rag/config/config_source_local.yaml +21 -0
  22. sage/benchmark_rag/config/config_sparse_milvus.yaml +49 -0
  23. sage/benchmark_rag/evaluation/__init__.py +10 -0
  24. sage/benchmark_rag/evaluation/benchmark_runner.py +337 -0
  25. sage/benchmark_rag/evaluation/config/benchmark_config.yaml +35 -0
  26. sage/benchmark_rag/evaluation/evaluate_results.py +389 -0
  27. sage/benchmark_rag/implementations/__init__.py +31 -0
  28. sage/benchmark_rag/implementations/pipelines/__init__.py +24 -0
  29. sage/benchmark_rag/implementations/pipelines/qa_bm25_retrieval.py +55 -0
  30. sage/benchmark_rag/implementations/pipelines/qa_dense_retrieval.py +56 -0
  31. sage/benchmark_rag/implementations/pipelines/qa_dense_retrieval_chroma.py +71 -0
  32. sage/benchmark_rag/implementations/pipelines/qa_dense_retrieval_milvus.py +78 -0
  33. sage/benchmark_rag/implementations/pipelines/qa_dense_retrieval_mixed.py +58 -0
  34. sage/benchmark_rag/implementations/pipelines/qa_dense_retrieval_ray.py +174 -0
  35. sage/benchmark_rag/implementations/pipelines/qa_hf_model.py +57 -0
  36. sage/benchmark_rag/implementations/pipelines/qa_monitoring_demo.py +139 -0
  37. sage/benchmark_rag/implementations/pipelines/qa_multimodal_fusion.py +318 -0
  38. sage/benchmark_rag/implementations/pipelines/qa_multiplex.py +91 -0
  39. sage/benchmark_rag/implementations/pipelines/qa_refiner.py +91 -0
  40. sage/benchmark_rag/implementations/pipelines/qa_rerank.py +76 -0
  41. sage/benchmark_rag/implementations/pipelines/qa_sparse_retrieval_milvus.py +76 -0
  42. sage/benchmark_rag/implementations/pipelines/selfrag.py +226 -0
  43. sage/benchmark_rag/implementations/tools/__init__.py +17 -0
  44. sage/benchmark_rag/implementations/tools/build_chroma_index.py +261 -0
  45. sage/benchmark_rag/implementations/tools/build_milvus_dense_index.py +86 -0
  46. sage/benchmark_rag/implementations/tools/build_milvus_index.py +59 -0
  47. sage/benchmark_rag/implementations/tools/build_milvus_sparse_index.py +85 -0
  48. sage/benchmark_rag/implementations/tools/loaders/document_loaders.py +42 -0
@@ -0,0 +1,226 @@
1
+ """
2
+ Self-RAG Pipeline Implementation
3
+
4
+ Based on the Self-RAG paper (https://arxiv.org/abs/2310.11511)
5
+ Uses pre-retrieved documents from the dataset to generate answers.
6
+
7
+ This implementation uses the Self-RAG dataset format where each item contains:
8
+ - question: The question to answer
9
+ - answers: Ground truth answers
10
+ - ctxs: Pre-retrieved documents with title and text
11
+ """
12
+
13
+ from typing import Any
14
+
15
+ from sage.common.core import MapFunction
16
+ from sage.kernel.api.local_environment import LocalEnvironment
17
+ from sage.libs.foundation.io.sink import FileSink
18
+ from sage.libs.foundation.io.source import FileSource
19
+
20
+
21
class SelfRAGRetriever(MapFunction):
    """Extract pre-retrieved evidence documents from a Self-RAG data item.

    This operator performs no actual retrieval: the Self-RAG dataset ships
    with pre-computed contexts (``ctxs``), so it only selects the top-k
    entries with non-empty text and normalizes them into ranked dicts.
    """

    def __init__(self, config: dict):
        # Number of pre-retrieved contexts to keep per question.
        self.top_k = config.get("top_k", 5)

    def execute(self, item: dict[str, Any]) -> dict[str, Any]:
        """Select the top-k usable contexts and repackage the item."""
        candidates = item.get("ctxs", [])[: self.top_k]

        docs = []
        for position, ctx in enumerate(candidates, start=1):
            body = ctx.get("text", "")
            if not body.strip():
                # Skip contexts with no usable body text; note the rank
                # still advances, mirroring positions in the raw slice.
                continue
            docs.append(
                {
                    "rank": position,
                    "title": ctx.get("title", ""),
                    "text": body,
                    "score": ctx.get("score", 1.0),
                }
            )

        return {
            "question": item["question"],
            "retrieved_docs": docs,
            "ground_truth": item.get("answers", []),
            "id": item.get("id", ""),
        }
55
+
56
+
57
class SelfRAGPromptor(MapFunction):
    """Build Self-RAG style prompts from the question and its evidence.

    Evidence paragraphs are numbered ``[rank] title`` followed by the text
    and joined with newlines; the final prompt template depends on the
    configured model family.
    """

    def __init__(self, config: dict):
        # Model family name, used only to pick the prompt template.
        self.model_name = config.get("model_name", "mistral")
        # Whether retrieved evidence is prepended to the question.
        self.use_context = config.get("use_context", True)

    def execute(self, item: dict[str, Any]) -> dict[str, Any]:
        """Attach ``prompt`` and ``context`` keys to the item."""
        question = item["question"]
        docs = item.get("retrieved_docs", [])

        context = None
        if self.use_context and docs:
            context = "\n".join(
                f"[{doc['rank']}] {doc['title']}\n{doc['text']}" for doc in docs
            )

        is_llama = "llama" in self.model_name.lower()
        if self.use_context and context:
            if is_llama:
                prompt = f"[INST]{context}\n{question}[/INST]"
            else:
                prompt = f"<s>[INST]{context}\n{question}[/INST]"
        elif is_llama:
            prompt = f"[INST]{question}[/INST]"
        else:
            prompt = f"### Instruction:\n{question}\n\n### Response:\n"

        item["prompt"] = prompt
        item["context"] = context
        return item
100
+
101
+
102
class SelfRAGGenerator(MapFunction):
    """Generate answers with vLLM for Self-RAG prompts.

    The vLLM engine is created once at construction time and reused for
    every item; decoding is greedy by default (temperature 0).
    """

    def __init__(self, config: dict):
        self.model_name = config.get("model_name", "mistralai/Mistral-7B-Instruct-v0.1")

        from vllm import LLM, SamplingParams

        self.llm = LLM(
            model=self.model_name,
            gpu_memory_utilization=config.get("gpu_memory_utilization", 0.8),
        )
        self.sampling_params = SamplingParams(
            temperature=config.get("temperature", 0),
            max_tokens=config.get("max_tokens", 100),
        )

    def execute(self, item: dict[str, Any]) -> dict[str, Any]:
        """Run one prompt through the model and store the cleaned output."""
        outputs = self.llm.generate([item["prompt"]], self.sampling_params)
        raw_text = outputs[0].outputs[0].text
        item["model_output"] = self._postprocess(raw_text)
        return item

    def _postprocess(self, text: str) -> str:
        """Trim the raw completion down to a single clean answer."""
        # Keep only the first paragraph and drop the end-of-sequence token.
        cleaned = text.split("\n\n")[0].replace("</s>", "")
        # Chat templates often emit one leading space after [/INST].
        return cleaned[1:] if cleaned.startswith(" ") else cleaned
147
+
148
+
149
def process_item(item: dict[str, Any], config: dict) -> dict[str, Any]:
    """
    Process a single item through the Self-RAG pipeline.

    This is a simplified interface for benchmark runner integration.

    Args:
        item: Data item with keys: question, answers, ctxs
        config: Pipeline configuration

    Returns:
        Result dictionary with: id, question, prediction, ground_truth
    """
    # Initialize components. NOTE(review): SelfRAGGenerator loads the vLLM
    # model in __init__, so per-item construction is expensive — callers
    # should reuse these across items where possible.
    retriever = SelfRAGRetriever(config)
    promptor = SelfRAGPromptor(config)
    generator = SelfRAGGenerator(config)

    # Run the three stages in order.
    retrieved = retriever.execute(item)
    prompted = promptor.execute(retrieved)
    result = generator.execute(prompted)

    # Format output. The generator stores its answer under "model_output"
    # (not "prediction"); the previous code read result["prediction"] and
    # always raised KeyError.
    return {
        "id": item.get("id", "unknown"),
        "question": item["question"],
        "prediction": result["model_output"],
        "ground_truth": item.get("answers", []),
        "retrieved_docs": result.get("retrieved_docs", []),
    }
180
+
181
+
182
def run_selfrag_pipeline(config_path: str):
    """Run the Self-RAG pipeline end to end on a local SAGE environment.

    Args:
        config_path: Path to configuration YAML file
    """
    import yaml

    with open(config_path) as cfg_file:
        config = yaml.safe_load(cfg_file)

    env = LocalEnvironment("selfrag_pipeline")

    # Wire up the dataflow: source -> retriever -> promptor -> generator -> sink.
    stream = env.from_source(FileSource, {"file_path": config["data_path"]})
    stream = stream.map(SelfRAGRetriever, config["retriever"])
    stream = stream.map(SelfRAGPromptor, config["promptor"])
    stream = stream.map(SelfRAGGenerator, config["generator"])
    stream.sink(FileSink, {"output_path": config["output_path"]})

    env.submit()
    env.close()
206
+
207
+
208
if __name__ == "__main__":
    import sys
    from pathlib import Path

    # Fall back to the packaged default config when no path is given.
    default_config = (
        Path(__file__).parent.parent.parent / "config" / "config_selfrag.yaml"
    )
    config_path_arg = sys.argv[1] if len(sys.argv) > 1 else str(default_config)

    print("🚀 Running Self-RAG Pipeline")
    print(f"📝 Config: {config_path_arg}")

    run_selfrag_pipeline(config_path_arg)

    print("✅ Self-RAG pipeline completed!")
@@ -0,0 +1,17 @@
1
+ """Tools and utilities for RAG implementations.
2
+
3
+ This module provides supporting tools for RAG pipelines:
4
+
5
+ Index Building:
6
+ - build_chroma_index.py: Build ChromaDB vector index
7
+ - build_milvus_index.py: Build Milvus vector index
8
+ - build_milvus_dense_index.py: Build Milvus dense vector index
9
+ - build_milvus_sparse_index.py: Build Milvus sparse vector index
10
+
11
+ Document Loaders:
12
+ - loaders/: Various document format loaders
13
+
14
+ These tools are used to prepare data and indices before running RAG benchmarks.
15
+ """
16
+
17
+ __all__: list[str] = []
@@ -0,0 +1,261 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ 知识库预加载脚本(SAGE多格式+工厂版)
4
+ 支持 txt / pdf / md / docx 文件,
5
+ 通过 LoaderFactory 动态选择 Loader,
6
+ 使用 CharacterSplitter 分块,写入 ChromaDB。
7
+ """
8
+
9
+ import os
10
+ import sys
11
+ from typing import TYPE_CHECKING, Any, Protocol, TypedDict, cast
12
+
13
+ import chromadb
14
+ import numpy as np
15
+ from numpy.typing import NDArray
16
+ from sage.libs.rag import CharacterSplitter
17
+ from sage.libs.rag.document_loaders import LoaderFactory
18
+
19
+ if TYPE_CHECKING:
20
+ from chromadb.api.types import Embeddings, Metadatas
21
+
22
+
23
class Document(TypedDict):
    """Document structure from LoaderFactory."""

    # Full text of the loaded file.
    content: str
    # Loader-provided metadata; exact keys depend on the concrete loader
    # — TODO confirm against LoaderFactory implementations.
    metadata: dict[str, Any]
28
+
29
+
30
class ChunkDocument(TypedDict):
    """Structure for chunked document."""

    # Text of one chunk produced by the splitter.
    content: str
    # Chunk-level metadata; this script stores
    # {"chunk": <1-based index>, "source": <file path>}.
    metadata: dict[str, Any]
35
+
36
+
37
class Embedder(Protocol):
    """Protocol for embedder objects."""

    def encode(
        self, texts: list[str], convert_to_numpy: bool = True, **kwargs: Any
    ) -> NDArray[np.float32] | list[list[float]]:
        """Encode texts to embeddings.

        Returns one embedding vector per input text, either as a numpy
        array or a nested Python list (implementations vary).
        """
        ...
45
+
46
+
47
# In test mode, supply a lightweight embedder so no large model is downloaded.
def _get_embedder() -> Embedder:
    """Return an object with encode(texts)->embeddings.

    A test mode controlled via environment variable takes priority, so
    CI/local tests never download large models:
    - When SAGE_EXAMPLES_MODE=test, return a simple built-in embedder
      (fixed dimension, negligible overhead).
    - Otherwise, load a real model via SentenceTransformer.
    """
    # Uses the module-level `os` import; the old function-local
    # `import os` was redundant.
    if os.environ.get("SAGE_EXAMPLES_MODE") == "test":

        class _MiniEmbedder:
            """Mini embedder for testing."""

            def __init__(self, dim: int = 8) -> None:
                self.dim = dim

            def encode(
                self, texts: list[str], convert_to_numpy: bool = True, **kwargs: Any
            ) -> list[list[float]]:
                """Produce deterministic, low-dimensional fake embeddings."""
                vecs: list[list[float]] = []
                for i, _text in enumerate(texts):
                    # Cycle through 5 base values so nearby texts get
                    # different but reproducible vectors.
                    base = float((i % 5) + 1)
                    vecs.append([base / (j + 1) for j in range(self.dim)])
                return vecs

        embedder: Embedder = _MiniEmbedder(dim=8)
        return embedder

    # Normal mode: load the real model (name overridable via env var).
    model_name = os.environ.get("SAGE_EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
    from sentence_transformers import SentenceTransformer

    model: Embedder = SentenceTransformer(model_name)  # type: ignore[assignment]
    return model
84
+
85
+
86
def _to_2dlist(
    arr: NDArray[np.float32] | list[list[float]] | list[float],
) -> list[list[float]]:
    """Normalize embeddings to a 2D Python list.

    Accepts list, numpy array, torch tensor, etc., and returns List[List[float]].

    Args:
        arr: Input array (can be numpy array, list of floats, or list of lists)

    Returns:
        2D list of floats suitable for ChromaDB

    Raises:
        ValueError: If a numpy input has more than 2 dimensions.
    """
    # Handle numpy arrays first. ndarray.tolist() yields native Python
    # floats; the previous 1D path (list(arr.astype(float))) left numpy
    # scalars in the result, contradicting the declared return type.
    if isinstance(arr, np.ndarray):
        if arr.ndim == 1:
            # 1D array -> wrap to 2D
            return [arr.astype(float).tolist()]
        if arr.ndim == 2:
            return arr.astype(float).tolist()
        raise ValueError(f"Expected 1D or 2D array, got {arr.ndim}D")

    # At this point arr must be a list (list[float] or list[list[float]]).
    if not arr:
        return []

    first_elem = arr[0]

    # A numeric first element means a flat 1D vector -> wrap to 2D.
    if isinstance(first_elem, (int, float, np.floating)):
        float_list: list[float] = arr  # type: ignore[assignment]
        return [float_list]

    # Otherwise treat as a nested list and coerce every entry to float.
    nested_list: list[list[float]] = arr  # type: ignore[assignment]
    return [[float(x) for x in row] for row in nested_list]
132
+
133
+
134
def load_knowledge_to_chromadb() -> bool:
    """Load knowledge base to ChromaDB from multiple file formats.

    Each supported file (txt/pdf/md/docx) is loaded via LoaderFactory,
    split on blank lines, embedded, and written to its own collection.

    Returns:
        True if successful, False otherwise
    """
    # Configuration parameters
    data_dir = "./data/qa"  # data lives under the shared data/qa directory
    persistence_path = "./chroma_multi_store"

    # Mapping of input file -> target collection name
    files_and_collections: list[tuple[str, str]] = [
        (os.path.join(data_dir, "qa_knowledge_base.txt"), "txt_collection"),
        (os.path.join(data_dir, "qa_knowledge_base.pdf"), "pdf_collection"),
        (os.path.join(data_dir, "qa_knowledge_base.md"), "md_collection"),
        (os.path.join(data_dir, "qa_knowledge_base.docx"), "docx_collection"),
    ]

    print("=== 预加载多格式知识库到 ChromaDB ===")
    print(f"存储路径: {persistence_path}")

    # Initialize the embedding model (no large download in test mode)
    print("\n加载嵌入模型...")
    model = _get_embedder()

    # Initialize ChromaDB
    print("初始化ChromaDB...")
    client = chromadb.PersistentClient(path=persistence_path)

    for file_path, collection_name in files_and_collections:
        if not os.path.exists(file_path):
            # Missing files are skipped rather than treated as errors.
            print(f"⚠ 文件不存在,跳过: {file_path}")
            continue

        print(f"\n=== 处理文件: {file_path} | 集合: {collection_name} ===")

        # Use the factory to pick a loader for this file type.
        # LoaderFactory.load returns dict with 'content' and 'metadata' keys
        raw_document = LoaderFactory.load(file_path)
        document: Document = {
            "content": str(raw_document.get("content", "")),
            "metadata": dict(raw_document.get("metadata", {})),
        }
        print(f"已加载文档,长度: {len(document['content'])}")

        # Split into chunks on blank lines
        splitter = CharacterSplitter(separator="\n\n")
        raw_chunks = splitter.split(document["content"])
        # Convert chunks to list of strings
        chunks: list[str] = [str(chunk) for chunk in raw_chunks]
        print(f"分块数: {len(chunks)}")

        chunk_docs: list[ChunkDocument] = [
            {
                "content": chunk,
                "metadata": {"chunk": idx + 1, "source": file_path},
            }
            for idx, chunk in enumerate(chunks)
        ]

        # Drop any stale collection, then create a fresh one
        try:
            client.delete_collection(name=collection_name)
        except Exception:
            # Collection may simply not exist yet; best-effort cleanup.
            pass

        index_type = "flat"  # options: "flat", "hnsw"
        collection = client.create_collection(
            name=collection_name, metadata={"index_type": index_type}
        )
        print(f"集合已创建,索引类型: {index_type}")

        # Embed and write
        texts: list[str] = [doc["content"] for doc in chunk_docs]
        raw_embeddings = model.encode(texts, convert_to_numpy=True)
        embeddings = _to_2dlist(raw_embeddings)

        ids: list[str] = [f"{collection_name}_chunk_{i}" for i in range(len(chunk_docs))]
        metadatas: list[dict[str, str | int | float | bool]] = [
            {
                "chunk": doc["metadata"]["chunk"],
                "source": doc["metadata"]["source"],
            }
            for doc in chunk_docs
        ]

        # ChromaDB accepts List[List[float]] for embeddings
        # Cast to satisfy type checker - the types are compatible at runtime
        collection.add(
            embeddings=cast("Embeddings", embeddings),
            documents=texts,
            metadatas=cast("Metadatas", metadatas),
            ids=ids,
        )
        print(f"✓ 已添加 {len(chunk_docs)} 个文本块")
        print(f"✓ 数据库文档数: {collection.count()}")

        # Smoke-test retrieval against the freshly built collection
        test_query = "什么是ChromaDB"
        raw_query_embedding = model.encode([test_query], convert_to_numpy=True)
        query_embedding = _to_2dlist(raw_query_embedding)

        results = collection.query(
            query_embeddings=cast("Embeddings", query_embedding), n_results=3
        )
        print(f"检索: {test_query}")

        if results and "documents" in results and results["documents"]:
            docs = results["documents"]
            # docs is List[List[Document]] where Document is str
            if len(docs) > 0:
                first_results = docs[0]
                for i, doc in enumerate(first_results):
                    print(f" {i + 1}. {doc[:100]}...")

        print("=== 完成 ===")

    print("\n=== 所有文件已处理完成 ===")
    return True
253
+
254
+
255
if __name__ == "__main__":
    # Run relative to this script's directory so ./data and the Chroma
    # store resolve consistently regardless of the invocation cwd.
    os.chdir(os.path.dirname(os.path.abspath(__file__)))
    ok = load_knowledge_to_chromadb()
    if ok:
        print("知识库已成功加载,可运行检索/问答脚本")
    else:
        print("知识库加载失败")
        sys.exit(1)
@@ -0,0 +1,86 @@
1
+ import os
2
+ import sys
3
+
4
+ from sage.common.utils.config.loader import load_config
5
+ from sage.libs.rag import CharacterSplitter
6
+ from sage.libs.rag.document_loaders import LoaderFactory
7
+ from sage.middleware.operators.rag import MilvusDenseRetriever
8
+
9
+
10
def load_knowledge_to_milvus(config):
    """Preload a multi-format knowledge base into one Milvus collection.

    All source files are chunked and written into a single shared
    collection, so per-file provenance is not preserved.
    """
    knowledge_files = config.get("preload_knowledge_file")
    if not isinstance(knowledge_files, list):
        # Allow a single path in the config; normalize to a list.
        knowledge_files = [knowledge_files]

    persistence_path = config.get("milvus_dense").get("persistence_path")
    collection_name = "qa_dense_collection"  # single shared collection

    print("=== 预加载多格式知识库到 Milvus ===")
    print(f"DB: {persistence_path}")
    print(f"统一集合: {collection_name}")

    print("初始化Milvus...")
    milvus_backend = MilvusDenseRetriever(config, collection_name=collection_name)

    all_chunks = []

    for file_path in knowledge_files:
        if not os.path.exists(file_path):
            print(f"⚠ 文件不存在,跳过: {file_path}")
            continue

        print(f"\n=== 处理文件: {file_path} ===")

        document = LoaderFactory.load(file_path)
        print(f"已加载文档,长度: {len(document['content'])}")

        splitter = CharacterSplitter({"separator": "\n\n"})
        chunks = splitter.execute(document)
        print(f"分块数: {len(chunks)}")

        all_chunks.extend(chunks)
        print(f"✓ 已准备 {len(chunks)} 个文本块")

    if all_chunks:
        milvus_backend.add_documents(all_chunks)
        print(f"\n✓ 已写入 {len(all_chunks)} 个文本块到集合 {collection_name}")
        print(f"✓ 数据库信息: {milvus_backend.get_collection_info()}")

        # Smoke-test retrieval with a couple of sample queries.
        for text_query in ("什么是ChromaDB?", "RAG 系统的主要优势是什么?"):
            results = milvus_backend.execute(text_query)
            print(f"检索结果: {results}")
    else:
        print("⚠ 没有有效的知识文件,未写入任何数据")

    print("=== 完成 ===")
    return True
+
67
+
68
+ if __name__ == "__main__":
69
+ if os.getenv("SAGE_EXAMPLES_MODE") == "test" or os.getenv("SAGE_TEST_MODE") == "true":
70
+ print("🧪 Test mode detected - build_milvus_dense_index example")
71
+ print("✅ Test passed: Example structure validated")
72
+ sys.exit(0)
73
+
74
+ config_path = "./examples/config/config_dense_milvus.yaml"
75
+ if not os.path.exists(config_path):
76
+ print(f"配置文件不存在: {config_path}")
77
+ print("Please create the configuration file first.")
78
+ sys.exit(1)
79
+
80
+ config = load_config(config_path)
81
+ result = load_knowledge_to_milvus(config["retriever"])
82
+ if result:
83
+ print("知识库已成功加载,可运行检索/问答脚本")
84
+ else:
85
+ print("知识库加载失败")
86
+ sys.exit(1)
@@ -0,0 +1,59 @@
1
+ import os
2
+ import sys
3
+
4
+ from sage.common.utils.config.loader import load_config
5
+ from sage.libs.rag import CharacterSplitter
6
+ from sage.libs.rag.document_loaders import TextLoader
7
+ from sage.middleware.operators.rag import MilvusDenseRetriever
8
+
9
+
10
def load_knowledge_to_milvus(config):
    """Load a single text knowledge file into a Milvus dense collection.

    Args:
        config: Retriever config with "preload_knowledge_file" and a
            "milvus_dense" section (persistence_path, collection_name).

    Returns:
        True on success.
    """
    knowledge_file = config.get("preload_knowledge_file")
    persistence_path = config.get("milvus_dense").get("persistence_path")
    collection_name = config.get("milvus_dense").get("collection_name")

    # Fixed: the banner previously said "ChromaDB" although this script
    # writes to Milvus.
    print("=== 预加载知识库到 Milvus ===")
    print(f"文件: {knowledge_file} | DB: {persistence_path} | 集合: {collection_name}")

    loader = TextLoader(knowledge_file)
    document = loader.load()
    print(f"已加载文本,长度: {len(document['content'])}")

    # Split on blank lines before indexing.
    splitter = CharacterSplitter({"separator": "\n\n"})
    chunks = splitter.execute(document)
    print(f"分块数: {len(chunks)}")

    print("初始化Milvus...")
    milvus_backend = MilvusDenseRetriever(config)
    milvus_backend.add_documents(chunks)
    print(f"✓ 已添加 {len(chunks)} 个文本块")
    print(f"✓ 数据库信息: {milvus_backend.get_collection_info()}")

    # Quick retrieval smoke test.
    text_query = "什么是ChromaDB?"
    results = milvus_backend.execute(text_query)
    print(f"检索结果: {results}")
    return True
38
+
39
+
40
if __name__ == "__main__":
    # Detect the automated test harness and exit early without touching Milvus.
    if os.getenv("SAGE_EXAMPLES_MODE") == "test" or os.getenv("SAGE_TEST_MODE") == "true":
        print("🧪 Test mode detected - build_milvus_index example")
        print("✅ Test passed: Example structure validated")
        sys.exit(0)

    config_path = "./examples/config/config_dense_milvus.yaml"
    if not os.path.exists(config_path):
        print(f"配置文件不存在: {config_path}")
        print("Please create the configuration file first.")
        sys.exit(1)

    config = load_config(config_path)
    if load_knowledge_to_milvus(config["retriever"]):
        print("知识库已成功加载,可运行检索/问答脚本")
    else:
        print("知识库加载失败")
        sys.exit(1)