isage-rag-benchmark 0.1.0.1__cp311-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isage_rag_benchmark-0.1.0.1.dist-info/METADATA +63 -0
- isage_rag_benchmark-0.1.0.1.dist-info/RECORD +48 -0
- isage_rag_benchmark-0.1.0.1.dist-info/WHEEL +5 -0
- isage_rag_benchmark-0.1.0.1.dist-info/licenses/LICENSE +21 -0
- isage_rag_benchmark-0.1.0.1.dist-info/top_level.txt +1 -0
- sage/__init__.py +0 -0
- sage/benchmark_rag/__init__.py +16 -0
- sage/benchmark_rag/_version.py +4 -0
- sage/benchmark_rag/config/config_bm25s.yaml +51 -0
- sage/benchmark_rag/config/config_dense_milvus.yaml +61 -0
- sage/benchmark_rag/config/config_hf.yaml +43 -0
- sage/benchmark_rag/config/config_mixed.yaml +53 -0
- sage/benchmark_rag/config/config_monitoring_demo.yaml +59 -0
- sage/benchmark_rag/config/config_multiplex.yaml +79 -0
- sage/benchmark_rag/config/config_qa_chroma.yaml +51 -0
- sage/benchmark_rag/config/config_ray.yaml +57 -0
- sage/benchmark_rag/config/config_refiner.yaml +75 -0
- sage/benchmark_rag/config/config_rerank.yaml +56 -0
- sage/benchmark_rag/config/config_selfrag.yaml +24 -0
- sage/benchmark_rag/config/config_source.yaml +30 -0
- sage/benchmark_rag/config/config_source_local.yaml +21 -0
- sage/benchmark_rag/config/config_sparse_milvus.yaml +49 -0
- sage/benchmark_rag/evaluation/__init__.py +10 -0
- sage/benchmark_rag/evaluation/benchmark_runner.py +337 -0
- sage/benchmark_rag/evaluation/config/benchmark_config.yaml +35 -0
- sage/benchmark_rag/evaluation/evaluate_results.py +389 -0
- sage/benchmark_rag/implementations/__init__.py +31 -0
- sage/benchmark_rag/implementations/pipelines/__init__.py +24 -0
- sage/benchmark_rag/implementations/pipelines/qa_bm25_retrieval.py +55 -0
- sage/benchmark_rag/implementations/pipelines/qa_dense_retrieval.py +56 -0
- sage/benchmark_rag/implementations/pipelines/qa_dense_retrieval_chroma.py +71 -0
- sage/benchmark_rag/implementations/pipelines/qa_dense_retrieval_milvus.py +78 -0
- sage/benchmark_rag/implementations/pipelines/qa_dense_retrieval_mixed.py +58 -0
- sage/benchmark_rag/implementations/pipelines/qa_dense_retrieval_ray.py +174 -0
- sage/benchmark_rag/implementations/pipelines/qa_hf_model.py +57 -0
- sage/benchmark_rag/implementations/pipelines/qa_monitoring_demo.py +139 -0
- sage/benchmark_rag/implementations/pipelines/qa_multimodal_fusion.py +318 -0
- sage/benchmark_rag/implementations/pipelines/qa_multiplex.py +91 -0
- sage/benchmark_rag/implementations/pipelines/qa_refiner.py +91 -0
- sage/benchmark_rag/implementations/pipelines/qa_rerank.py +76 -0
- sage/benchmark_rag/implementations/pipelines/qa_sparse_retrieval_milvus.py +76 -0
- sage/benchmark_rag/implementations/pipelines/selfrag.py +226 -0
- sage/benchmark_rag/implementations/tools/__init__.py +17 -0
- sage/benchmark_rag/implementations/tools/build_chroma_index.py +261 -0
- sage/benchmark_rag/implementations/tools/build_milvus_dense_index.py +86 -0
- sage/benchmark_rag/implementations/tools/build_milvus_index.py +59 -0
- sage/benchmark_rag/implementations/tools/build_milvus_sparse_index.py +85 -0
- sage/benchmark_rag/implementations/tools/loaders/document_loaders.py +42 -0
|
@@ -0,0 +1,226 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Self-RAG Pipeline Implementation
|
|
3
|
+
|
|
4
|
+
Based on the Self-RAG paper (https://arxiv.org/abs/2310.11511)
|
|
5
|
+
Uses pre-retrieved documents from the dataset to generate answers.
|
|
6
|
+
|
|
7
|
+
This implementation uses the Self-RAG dataset format where each item contains:
|
|
8
|
+
- question: The question to answer
|
|
9
|
+
- answers: Ground truth answers
|
|
10
|
+
- ctxs: Pre-retrieved documents with title and text
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
from sage.common.core import MapFunction
|
|
16
|
+
from sage.kernel.api.local_environment import LocalEnvironment
|
|
17
|
+
from sage.libs.foundation.io.sink import FileSink
|
|
18
|
+
from sage.libs.foundation.io.source import FileSource
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class SelfRAGRetriever(MapFunction):
    """Self-RAG retriever - surfaces pre-retrieved documents from the dataset.

    Traditional retrievers run a search; this one merely unpacks the
    pre-computed ``ctxs`` entries shipped with each Self-RAG dataset item.
    """

    def __init__(self, config: dict):
        # Number of pre-retrieved contexts to keep per question.
        self.top_k = config.get("top_k", 5)

    def execute(self, item: dict[str, Any]) -> dict[str, Any]:
        """Unpack the pre-retrieved contexts of a single data item."""
        candidates = item.get("ctxs", [])[: self.top_k]
        # Rank reflects the position among the top-k candidates, so ranks can
        # have gaps when entries with missing/blank text are filtered out.
        retrieved_docs = [
            {
                "rank": position + 1,
                "title": ctx.get("title", ""),
                "text": ctx["text"],
                "score": ctx.get("score", 1.0),
            }
            for position, ctx in enumerate(candidates)
            if "text" in ctx and ctx["text"].strip()
        ]
        return {
            "question": item["question"],
            "retrieved_docs": retrieved_docs,
            "ground_truth": item.get("answers", []),
            "id": item.get("id", ""),
        }
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class SelfRAGPromptor(MapFunction):
    """Self-RAG promptor - assembles prompts with numbered evidence blocks.

    Produces prompts in the Self-RAG format, where each retrieved paragraph
    appears as a "[rank] title\\ntext" entry ahead of the question.
    """

    def __init__(self, config: dict):
        self.model_name = config.get("model_name", "mistral")
        self.use_context = config.get("use_context", True)

    def execute(self, item: dict[str, Any]) -> dict[str, Any]:
        """Attach ``prompt`` and ``context`` fields to the item."""
        question = item["question"]
        docs = item.get("retrieved_docs", [])

        # Join the evidence paragraphs into a single context string.
        context = None
        if self.use_context and docs:
            context = "\n".join(
                f"[{doc['rank']}] {doc['title']}\n{doc['text']}" for doc in docs
            )

        # Prompt template depends on the model family (llama vs. others).
        is_llama = "llama" in self.model_name.lower()
        if self.use_context and context:
            prompt = (
                f"[INST]{context}\n{question}[/INST]"
                if is_llama
                else f"<s>[INST]{context}\n{question}[/INST]"
            )
        elif is_llama:
            prompt = f"[INST]{question}[/INST]"
        else:
            prompt = f"### Instruction:\n{question}\n\n### Response:\n"

        item["prompt"] = prompt
        item["context"] = context
        return item
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class SelfRAGGenerator(MapFunction):
    """Self-RAG generator - produces answers via vLLM inference."""

    def __init__(self, config: dict):
        self.model_name = config.get("model_name", "mistralai/Mistral-7B-Instruct-v0.1")

        # Imported lazily so the module can be loaded without vLLM installed.
        from vllm import LLM, SamplingParams

        self.llm = LLM(
            model=self.model_name,
            gpu_memory_utilization=config.get("gpu_memory_utilization", 0.8),
        )
        self.sampling_params = SamplingParams(
            temperature=config.get("temperature", 0),
            max_tokens=config.get("max_tokens", 100),
        )

    def execute(self, item: dict[str, Any]) -> dict[str, Any]:
        """Generate and post-process the answer for one prompted item."""
        generations = self.llm.generate([item["prompt"]], self.sampling_params)
        item["model_output"] = self._postprocess(generations[0].outputs[0].text)
        return item

    def _postprocess(self, text: str) -> str:
        """Clean up raw model output."""
        # Keep only the first paragraph, drop end-of-sequence markers,
        # then remove a single leading space if present.
        first_paragraph = text.partition("\n\n")[0]
        cleaned = first_paragraph.replace("</s>", "")
        return cleaned.removeprefix(" ")
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def process_item(item: dict[str, Any], config: dict) -> dict[str, Any]:
    """Process a single item through the Self-RAG pipeline.

    This is a simplified interface for benchmark runner integration.

    Args:
        item: Data item with keys: question, answers, ctxs
        config: Pipeline configuration

    Returns:
        Result dictionary with: id, question, prediction, ground_truth,
        retrieved_docs
    """
    # Initialize components
    retriever = SelfRAGRetriever(config)
    promptor = SelfRAGPromptor(config)
    generator = SelfRAGGenerator(config)

    # Process through pipeline
    retrieved = retriever.execute(item)
    prompted = promptor.execute(retrieved)
    result = generator.execute(prompted)

    # Format output
    return {
        "id": item.get("id", "unknown"),
        "question": item["question"],
        # Bug fix: the generator stores its answer under "model_output";
        # reading result["prediction"] always raised KeyError here.
        "prediction": result["model_output"],
        "ground_truth": item.get("answers", []),
        "retrieved_docs": result.get("retrieved_docs", []),
    }
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def run_selfrag_pipeline(config_path: str):
    """Run the Self-RAG pipeline end to end.

    Args:
        config_path: Path to configuration YAML file
    """
    import yaml

    with open(config_path) as cfg_file:
        config = yaml.safe_load(cfg_file)

    env = LocalEnvironment("selfrag_pipeline")

    # Build the dataflow step by step: source -> retrieve -> prompt -> generate -> sink.
    stream = env.from_source(FileSource, {"file_path": config["data_path"]})
    stream = stream.map(SelfRAGRetriever, config["retriever"])
    stream = stream.map(SelfRAGPromptor, config["promptor"])
    stream = stream.map(SelfRAGGenerator, config["generator"])
    stream.sink(FileSink, {"output_path": config["output_path"]})

    env.submit()
    env.close()
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
if __name__ == "__main__":
    import sys
    from pathlib import Path

    # Fall back to the packaged default config when no path is given.
    default_config = Path(__file__).parent.parent.parent / "config" / "config_selfrag.yaml"
    config_path_arg = sys.argv[1] if len(sys.argv) > 1 else str(default_config)

    print("🚀 Running Self-RAG Pipeline")
    print(f"📝 Config: {config_path_arg}")

    run_selfrag_pipeline(config_path_arg)

    print("✅ Self-RAG pipeline completed!")
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""Tools and utilities for RAG implementations.
|
|
2
|
+
|
|
3
|
+
This module provides supporting tools for RAG pipelines:
|
|
4
|
+
|
|
5
|
+
Index Building:
|
|
6
|
+
- build_chroma_index.py: Build ChromaDB vector index
|
|
7
|
+
- build_milvus_index.py: Build Milvus vector index
|
|
8
|
+
- build_milvus_dense_index.py: Build Milvus dense vector index
|
|
9
|
+
- build_milvus_sparse_index.py: Build Milvus sparse vector index
|
|
10
|
+
|
|
11
|
+
Document Loaders:
|
|
12
|
+
- loaders/: Various document format loaders
|
|
13
|
+
|
|
14
|
+
These tools are used to prepare data and indices before running RAG benchmarks.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
__all__: list[str] = []
|
|
@@ -0,0 +1,261 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
知识库预加载脚本(SAGE多格式+工厂版)
|
|
4
|
+
支持 txt / pdf / md / docx 文件,
|
|
5
|
+
通过 LoaderFactory 动态选择 Loader,
|
|
6
|
+
使用 CharacterSplitter 分块,写入 ChromaDB。
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import os
|
|
10
|
+
import sys
|
|
11
|
+
from typing import TYPE_CHECKING, Any, Protocol, TypedDict, cast
|
|
12
|
+
|
|
13
|
+
import chromadb
|
|
14
|
+
import numpy as np
|
|
15
|
+
from numpy.typing import NDArray
|
|
16
|
+
from sage.libs.rag import CharacterSplitter
|
|
17
|
+
from sage.libs.rag.document_loaders import LoaderFactory
|
|
18
|
+
|
|
19
|
+
if TYPE_CHECKING:
|
|
20
|
+
from chromadb.api.types import Embeddings, Metadatas
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Payload shape returned by LoaderFactory.load: raw document text plus
# whatever metadata the loader attached.
Document = TypedDict("Document", {"content": str, "metadata": dict[str, Any]})
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# One chunk emitted by the splitter: chunk text plus chunk-index/source metadata.
ChunkDocument = TypedDict("ChunkDocument", {"content": str, "metadata": dict[str, Any]})
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class Embedder(Protocol):
|
|
38
|
+
"""Protocol for embedder objects."""
|
|
39
|
+
|
|
40
|
+
def encode(
|
|
41
|
+
self, texts: list[str], convert_to_numpy: bool = True, **kwargs: Any
|
|
42
|
+
) -> NDArray[np.float32] | list[list[float]]:
|
|
43
|
+
"""Encode texts to embeddings."""
|
|
44
|
+
...
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
# 在测试模式下避免下载大型模型,提供轻量级嵌入器
|
|
48
|
+
def _get_embedder() -> Embedder:
|
|
49
|
+
"""Return an object with encode(texts)->embeddings.
|
|
50
|
+
|
|
51
|
+
优先使用环境变量控制的测试模式,避免在CI/本地测试中下载大型模型。
|
|
52
|
+
- 当 SAGE_EXAMPLES_MODE=test 时,返回一个简单的内置嵌入器(固定维度、小开销)。
|
|
53
|
+
- 否则,使用 SentenceTransformer 加载真实模型。
|
|
54
|
+
"""
|
|
55
|
+
import os
|
|
56
|
+
|
|
57
|
+
if os.environ.get("SAGE_EXAMPLES_MODE") == "test":
|
|
58
|
+
|
|
59
|
+
class _MiniEmbedder:
|
|
60
|
+
"""Mini embedder for testing."""
|
|
61
|
+
|
|
62
|
+
def __init__(self, dim: int = 8) -> None:
|
|
63
|
+
self.dim = dim
|
|
64
|
+
|
|
65
|
+
def encode(
|
|
66
|
+
self, texts: list[str], convert_to_numpy: bool = True, **kwargs: Any
|
|
67
|
+
) -> list[list[float]]:
|
|
68
|
+
"""生成确定性、低维的伪嵌入以便测试通过."""
|
|
69
|
+
vecs: list[list[float]] = []
|
|
70
|
+
for i, _text in enumerate(texts):
|
|
71
|
+
base = float((i % 5) + 1)
|
|
72
|
+
vecs.append([base / (j + 1) for j in range(self.dim)])
|
|
73
|
+
return vecs
|
|
74
|
+
|
|
75
|
+
embedder: Embedder = _MiniEmbedder(dim=8)
|
|
76
|
+
return embedder
|
|
77
|
+
|
|
78
|
+
# 正常模式:使用真实模型(如可选通过环境变量覆盖模型名)
|
|
79
|
+
model_name = os.environ.get("SAGE_EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
|
|
80
|
+
from sentence_transformers import SentenceTransformer
|
|
81
|
+
|
|
82
|
+
model: Embedder = SentenceTransformer(model_name) # type: ignore[assignment]
|
|
83
|
+
return model
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _to_2dlist(
|
|
87
|
+
arr: NDArray[np.float32] | list[list[float]] | list[float],
|
|
88
|
+
) -> list[list[float]]:
|
|
89
|
+
"""Normalize embeddings to a 2D Python list.
|
|
90
|
+
|
|
91
|
+
Accepts list, numpy array, torch tensor, etc., and returns List[List[float]].
|
|
92
|
+
|
|
93
|
+
Args:
|
|
94
|
+
arr: Input array (can be numpy array, list of floats, or list of lists)
|
|
95
|
+
|
|
96
|
+
Returns:
|
|
97
|
+
2D list of floats suitable for ChromaDB
|
|
98
|
+
"""
|
|
99
|
+
# Handle numpy arrays first
|
|
100
|
+
if isinstance(arr, np.ndarray):
|
|
101
|
+
# Convert to Python list
|
|
102
|
+
if arr.ndim == 1:
|
|
103
|
+
# 1D array -> wrap to 2D
|
|
104
|
+
result_1d: list[float] = list(arr.astype(float))
|
|
105
|
+
return [result_1d]
|
|
106
|
+
elif arr.ndim == 2:
|
|
107
|
+
# 2D array -> convert directly
|
|
108
|
+
result_2d: list[list[float]] = [[float(x) for x in row] for row in arr]
|
|
109
|
+
return result_2d
|
|
110
|
+
else:
|
|
111
|
+
raise ValueError(f"Expected 1D or 2D array, got {arr.ndim}D")
|
|
112
|
+
|
|
113
|
+
# At this point arr must be a list (either list[float] or list[list[float]])
|
|
114
|
+
# Check if it's empty
|
|
115
|
+
if not arr:
|
|
116
|
+
return []
|
|
117
|
+
|
|
118
|
+
# Check the type of the first element to determine structure
|
|
119
|
+
first_elem = arr[0]
|
|
120
|
+
|
|
121
|
+
# If first element is a number, this is a 1D list[float] -> wrap to 2D
|
|
122
|
+
if isinstance(first_elem, (int, float, np.floating)):
|
|
123
|
+
# We know arr is list[float] because first element is a number
|
|
124
|
+
float_list: list[float] = arr # type: ignore[assignment]
|
|
125
|
+
return [float_list]
|
|
126
|
+
|
|
127
|
+
# Otherwise, first element is a list/sequence
|
|
128
|
+
# We know arr is list[list[float]]
|
|
129
|
+
nested_list: list[list[float]] = arr # type: ignore[assignment]
|
|
130
|
+
# Ensure all inner elements are properly converted to float
|
|
131
|
+
return [[float(x) for x in row] for row in nested_list]
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def load_knowledge_to_chromadb() -> bool:
    """Load the multi-format knowledge base into ChromaDB.

    Walks a fixed list of (file, collection) pairs, loads each document via
    LoaderFactory, splits it into chunks, embeds the chunks and writes them to
    a per-format ChromaDB collection, then runs a smoke-test query.

    Returns:
        True if the run completed (missing files are skipped, not fatal).
    """
    # Configuration: data lives in the shared data/qa directory.
    data_dir = "./data/qa"
    persistence_path = "./chroma_multi_store"

    # File-to-collection mapping, one collection per source format.
    files_and_collections: list[tuple[str, str]] = [
        (os.path.join(data_dir, "qa_knowledge_base.txt"), "txt_collection"),
        (os.path.join(data_dir, "qa_knowledge_base.pdf"), "pdf_collection"),
        (os.path.join(data_dir, "qa_knowledge_base.md"), "md_collection"),
        (os.path.join(data_dir, "qa_knowledge_base.docx"), "docx_collection"),
    ]

    print("=== 预加载多格式知识库到 ChromaDB ===")
    print(f"存储路径: {persistence_path}")

    # Embedding model (test mode avoids downloading a large model).
    print("\n加载嵌入模型...")
    model = _get_embedder()

    # Initialize ChromaDB.
    print("初始化ChromaDB...")
    client = chromadb.PersistentClient(path=persistence_path)

    # Hoisted out of the loop: the splitter does not depend on the file.
    splitter = CharacterSplitter(separator="\n\n")

    for file_path, collection_name in files_and_collections:
        if not os.path.exists(file_path):
            print(f"⚠ 文件不存在,跳过: {file_path}")
            continue

        print(f"\n=== 处理文件: {file_path} | 集合: {collection_name} ===")

        # LoaderFactory.load returns a dict with 'content' and 'metadata' keys.
        raw_document = LoaderFactory.load(file_path)
        document: Document = {
            "content": str(raw_document.get("content", "")),
            "metadata": dict(raw_document.get("metadata", {})),
        }
        print(f"已加载文档,长度: {len(document['content'])}")

        # Split into chunks and normalize each chunk to a string.
        raw_chunks = splitter.split(document["content"])
        chunks: list[str] = [str(chunk) for chunk in raw_chunks]
        print(f"分块数: {len(chunks)}")

        chunk_docs: list[ChunkDocument] = [
            {
                "content": chunk,
                "metadata": {"chunk": idx + 1, "source": file_path},
            }
            for idx, chunk in enumerate(chunks)
        ]

        # Recreate the collection from scratch so reruns don't accumulate rows.
        try:
            client.delete_collection(name=collection_name)
        except Exception:
            pass  # collection may not exist yet - best-effort cleanup

        index_type = "flat"  # options: "flat", "hnsw"
        collection = client.create_collection(
            name=collection_name, metadata={"index_type": index_type}
        )
        print(f"集合已创建,索引类型: {index_type}")

        # Embed and write.
        texts: list[str] = [doc["content"] for doc in chunk_docs]
        raw_embeddings = model.encode(texts, convert_to_numpy=True)
        embeddings = _to_2dlist(raw_embeddings)

        ids: list[str] = [f"{collection_name}_chunk_{i}" for i in range(len(chunk_docs))]
        metadatas: list[dict[str, str | int | float | bool]] = [
            {
                "chunk": doc["metadata"]["chunk"],
                "source": doc["metadata"]["source"],
            }
            for doc in chunk_docs
        ]

        # ChromaDB accepts List[List[float]] for embeddings; the casts only
        # satisfy the type checker - the types are compatible at runtime.
        collection.add(
            embeddings=cast("Embeddings", embeddings),
            documents=texts,
            metadatas=cast("Metadatas", metadatas),
            ids=ids,
        )
        print(f"✓ 已添加 {len(chunk_docs)} 个文本块")
        print(f"✓ 数据库文档数: {collection.count()}")

        # Smoke-test retrieval against the freshly built collection.
        test_query = "什么是ChromaDB"
        raw_query_embedding = model.encode([test_query], convert_to_numpy=True)
        query_embedding = _to_2dlist(raw_query_embedding)

        results = collection.query(
            query_embeddings=cast("Embeddings", query_embedding), n_results=3
        )
        print(f"检索: {test_query}")

        if results and "documents" in results and results["documents"]:
            docs = results["documents"]
            # docs is List[List[str]]: one inner list per query.
            if len(docs) > 0:
                for i, doc in enumerate(docs[0]):
                    print(f" {i + 1}. {doc[:100]}...")

        print("=== 完成 ===")

    print("\n=== 所有文件已处理完成 ===")
    return True
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
if __name__ == "__main__":
    # Run relative to this script's directory so the relative data paths resolve.
    os.chdir(os.path.dirname(os.path.abspath(__file__)))
    ok = load_knowledge_to_chromadb()
    if not ok:
        print("知识库加载失败")
        sys.exit(1)
    print("知识库已成功加载,可运行检索/问答脚本")
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
|
|
4
|
+
from sage.common.utils.config.loader import load_config
|
|
5
|
+
from sage.libs.rag import CharacterSplitter
|
|
6
|
+
from sage.libs.rag.document_loaders import LoaderFactory
|
|
7
|
+
from sage.middleware.operators.rag import MilvusDenseRetriever
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def load_knowledge_to_milvus(config):
    """Load the multi-format knowledge base into a single Milvus collection.

    All files are written into one shared collection; per-file source
    information is not preserved.

    Args:
        config: Retriever configuration with ``preload_knowledge_file`` (path
            or list of paths) and a ``milvus_dense`` section.

    Returns:
        True when the run completed (missing files are skipped, not fatal).
    """
    knowledge_files = config.get("preload_knowledge_file")
    if not isinstance(knowledge_files, list):
        # Allow a single path in the config; normalize to a list.
        knowledge_files = [knowledge_files]

    persistence_path = config.get("milvus_dense").get("persistence_path")
    collection_name = "qa_dense_collection"  # single shared collection

    print("=== 预加载多格式知识库到 Milvus ===")
    print(f"DB: {persistence_path}")
    print(f"统一集合: {collection_name}")

    print("初始化Milvus...")
    milvus_backend = MilvusDenseRetriever(config, collection_name=collection_name)

    # Hoisted out of the loop: the splitter does not depend on the file.
    splitter = CharacterSplitter({"separator": "\n\n"})

    all_chunks = []
    for file_path in knowledge_files:
        if not os.path.exists(file_path):
            print(f"⚠ 文件不存在,跳过: {file_path}")
            continue

        print(f"\n=== 处理文件: {file_path} ===")

        document = LoaderFactory.load(file_path)
        print(f"已加载文档,长度: {len(document['content'])}")

        chunks = splitter.execute(document)
        print(f"分块数: {len(chunks)}")

        all_chunks.extend(chunks)
        print(f"✓ 已准备 {len(chunks)} 个文本块")

    if all_chunks:
        milvus_backend.add_documents(all_chunks)
        print(f"\n✓ 已写入 {len(all_chunks)} 个文本块到集合 {collection_name}")
        print(f"✓ 数据库信息: {milvus_backend.get_collection_info()}")

        # Smoke-test retrieval with two sample queries (previously duplicated
        # code blocks, folded into a loop with identical output).
        for text_query in ("什么是ChromaDB?", "RAG 系统的主要优势是什么?"):
            results = milvus_backend.execute(text_query)
            print(f"检索结果: {results}")
    else:
        print("⚠ 没有有效的知识文件,未写入任何数据")

    print("=== 完成 ===")
    return True
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
if __name__ == "__main__":
    # Short-circuit in CI/test mode: only validate that the example loads.
    if os.getenv("SAGE_EXAMPLES_MODE") == "test" or os.getenv("SAGE_TEST_MODE") == "true":
        print("🧪 Test mode detected - build_milvus_dense_index example")
        print("✅ Test passed: Example structure validated")
        sys.exit(0)

    config_path = "./examples/config/config_dense_milvus.yaml"
    if not os.path.exists(config_path):
        print(f"配置文件不存在: {config_path}")
        print("Please create the configuration file first.")
        sys.exit(1)

    retriever_config = load_config(config_path)["retriever"]
    if load_knowledge_to_milvus(retriever_config):
        print("知识库已成功加载,可运行检索/问答脚本")
    else:
        print("知识库加载失败")
        sys.exit(1)
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
|
|
4
|
+
from sage.common.utils.config.loader import load_config
|
|
5
|
+
from sage.libs.rag import CharacterSplitter
|
|
6
|
+
from sage.libs.rag.document_loaders import TextLoader
|
|
7
|
+
from sage.middleware.operators.rag import MilvusDenseRetriever
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def load_knowledge_to_milvus(config):
    """Load the text knowledge base into Milvus and run a smoke-test query.

    Args:
        config: Retriever configuration with ``preload_knowledge_file`` and a
            ``milvus_dense`` section (persistence_path, collection_name).

    Returns:
        True when loading and the test query completed.
    """
    knowledge_file = config.get("preload_knowledge_file")
    persistence_path = config.get("milvus_dense").get("persistence_path")
    collection_name = config.get("milvus_dense").get("collection_name")

    # Fixed copy-paste banner: this script targets Milvus, not ChromaDB.
    print("=== 预加载知识库到 Milvus ===")
    print(f"文件: {knowledge_file} | DB: {persistence_path} | 集合: {collection_name}")

    loader = TextLoader(knowledge_file)
    document = loader.load()
    print(f"已加载文本,长度: {len(document['content'])}")

    splitter = CharacterSplitter({"separator": "\n\n"})
    chunks = splitter.execute(document)
    print(f"分块数: {len(chunks)}")

    print("初始化Milvus...")
    milvus_backend = MilvusDenseRetriever(config)
    milvus_backend.add_documents(chunks)
    print(f"✓ 已添加 {len(chunks)} 个文本块")
    print(f"✓ 数据库信息: {milvus_backend.get_collection_info()}")

    # Smoke-test retrieval.
    text_query = "什么是ChromaDB?"
    results = milvus_backend.execute(text_query)
    print(f"检索结果: {results}")
    return True
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
if __name__ == "__main__":
    # Short-circuit in CI/test mode: only validate that the example loads.
    if os.getenv("SAGE_EXAMPLES_MODE") == "test" or os.getenv("SAGE_TEST_MODE") == "true":
        print("🧪 Test mode detected - build_milvus_index example")
        print("✅ Test passed: Example structure validated")
        sys.exit(0)

    config_path = "./examples/config/config_dense_milvus.yaml"
    if not os.path.exists(config_path):
        print(f"配置文件不存在: {config_path}")
        print("Please create the configuration file first.")
        sys.exit(1)

    retriever_config = load_config(config_path)["retriever"]
    if load_knowledge_to_milvus(retriever_config):
        print("知识库已成功加载,可运行检索/问答脚本")
    else:
        print("知识库加载失败")
        sys.exit(1)
|