nbrag 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nbrag/__init__.py +7 -0
- nbrag/__main__.py +5 -0
- nbrag/chunker.py +311 -0
- nbrag/config.py +169 -0
- nbrag/core.py +1059 -0
- nbrag/loggers.py +4 -0
- nbrag/server.py +594 -0
- nbrag-0.2.0.dist-info/METADATA +324 -0
- nbrag-0.2.0.dist-info/RECORD +12 -0
- nbrag-0.2.0.dist-info/WHEEL +4 -0
- nbrag-0.2.0.dist-info/entry_points.txt +2 -0
- nbrag-0.2.0.dist-info/licenses/LICENSE +21 -0
nbrag/__init__.py
ADDED
nbrag/__main__.py
ADDED
nbrag/chunker.py
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
1
|
+
"""
|
|
2
|
+
RAG 分块增强模块 —— 文本切分 + 行号计算 + AST scope 解析 + 头部上下文注入。
|
|
3
|
+
|
|
4
|
+
从 core.py 拆分出来,专注于「把原始文本变成 enriched chunks」这一步。
|
|
5
|
+
core.py 专注于存储和检索。
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
import ast
|
|
10
|
+
import bisect
|
|
11
|
+
import warnings
|
|
12
|
+
|
|
13
|
+
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# Default ``chunk_size`` handed to the text splitter by chunk_text.
DEFAULT_CHUNK_SIZE = 1500
# Default ``chunk_overlap`` for the splitter; compute_line_ranges also uses it
# (in characters) to decide how far to advance its search cursor per chunk.
DEFAULT_CHUNK_OVERLAP = 200
|
|
18
|
+
|
|
19
|
+
# ─── 文件类型映射 ─────────────────────────────────────────
|
|
20
|
+
|
|
21
|
+
# Lower-case file suffixes that collect_files treats as indexable text when the
# caller does not supply an explicit suffix list.
TEXT_EXTENSIONS = {
    ".txt", ".md", ".py", ".js", ".ts", ".jsx", ".tsx", ".java", ".go",
    ".rs", ".c", ".cpp", ".h", ".hpp", ".cs", ".rb", ".php", ".sh",
    ".json", ".yaml", ".yml", ".toml", ".ini", ".cfg", ".conf",
    ".csv", ".tsv", ".xml", ".html", ".css", ".sql", ".r", ".lua",
    ".swift", ".kt", ".scala", ".dart", ".vue", ".svelte",
    ".rst", ".tex", ".log", ".env", ".bat", ".ps1",
    # Solidity has a dedicated splitter entry in _EXT_TO_LANG, so it must also
    # be collectable by default; otherwise that mapping is unreachable.
    ".sol",
}
|
|
29
|
+
|
|
30
|
+
# Maps a lower-case file suffix to the langchain ``Language`` whose separators
# chunk_text should use; suffixes not listed here fall back to the generic
# separator list inside chunk_text.
_EXT_TO_LANG = {
    ".py": Language.PYTHON, ".js": Language.JS, ".ts": Language.TS,
    ".jsx": Language.JS, ".tsx": Language.TS,
    ".java": Language.JAVA, ".go": Language.GO,
    ".rs": Language.RUST, ".rb": Language.RUBY,
    ".cpp": Language.CPP, ".c": Language.C, ".h": Language.C, ".hpp": Language.CPP,
    ".cs": Language.CSHARP, ".php": Language.PHP,
    ".swift": Language.SWIFT, ".kt": Language.KOTLIN, ".scala": Language.SCALA,
    ".lua": Language.LUA, ".sol": Language.SOL,
    ".md": Language.MARKDOWN, ".html": Language.HTML,
    ".tex": Language.LATEX, ".rst": Language.RST,
}
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# ─── 文本切分 ─────────────────────────────────────────────
|
|
45
|
+
|
|
46
|
+
def chunk_text(text, chunk_size=DEFAULT_CHUNK_SIZE, overlap=DEFAULT_CHUNK_OVERLAP,
               file_ext=""):
    """Split ``text`` into chunks using a splitter chosen from ``file_ext``.

    Suffixes present in ``_EXT_TO_LANG`` use the language-aware splitter;
    every other suffix uses a generic separator list that also contains
    CJK punctuation. No header is injected here.

    Args:
        text: Raw text to split; leading/trailing whitespace is removed first.
        chunk_size: Passed to the splitter as ``chunk_size``.
        overlap: Passed to the splitter as ``chunk_overlap``.
        file_ext: File suffix such as ".py"; matched case-insensitively.

    Returns:
        The list produced by the splitter, or ``[]`` for blank input.
    """
    body = text.strip()
    if not body:
        return []

    language = _EXT_TO_LANG.get(file_ext.lower())
    if not language:
        # Unknown suffix: paragraph -> line -> sentence -> word -> character.
        generic_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=overlap,
            separators=["\n\n", "\n", "。", ".", ";", ";", ",", ",", " ", ""],
        )
        return generic_splitter.split_text(body)

    language_splitter = RecursiveCharacterTextSplitter.from_language(
        language=language,
        chunk_size=chunk_size,
        chunk_overlap=overlap,
    )
    return language_splitter.split_text(body)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
# ─── 行号计算 ─────────────────────────────────────────────
|
|
69
|
+
|
|
70
|
+
def _build_line_offsets(text):
|
|
71
|
+
"""构建行号偏移表: line_offsets[i] = 第 i+1 行的起始字符位置。"""
|
|
72
|
+
offsets = [0]
|
|
73
|
+
for line in text.split('\n'):
|
|
74
|
+
offsets.append(offsets[-1] + len(line) + 1)
|
|
75
|
+
return offsets
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def compute_line_ranges(full_text, chunks, overlap=DEFAULT_CHUNK_OVERLAP):
    """Locate each chunk in ``full_text`` and return its 1-based line range.

    Chunks are searched for sequentially: the search cursor advances by the
    chunk length minus ``overlap`` after every hit so repeated text resolves
    to the occurrence that follows the previous chunk. If nothing is found
    after the cursor, the whole text is searched as a fallback.

    Args:
        full_text: The original document text.
        chunks: Iterable of chunk strings, in document order.
        overlap: Assumed overlap (in characters) between consecutive chunks.

    Returns:
        A list of ``(start_line, end_line)`` tuples, one per chunk.
        ``(0, 0)`` marks a blank chunk or one that could not be located.
    """
    line_offsets = _build_line_offsets(full_text)
    # The offset table carries one trailing entry past the final line, so the
    # number of real lines is one less than its length.
    last_line = len(line_offsets) - 1
    ranges = []
    search_start = 0

    for chunk in chunks:
        stripped = chunk.strip()
        if not stripped:
            ranges.append((0, 0))
            continue

        # Only a prefix is matched so that small differences near the tail of
        # a chunk do not prevent it from being located.
        needle = stripped[:200]
        pos = full_text.find(needle, search_start)
        if pos == -1:
            pos = full_text.find(needle)
        if pos == -1:
            ranges.append((0, 0))
            continue

        start_line = bisect.bisect_right(line_offsets, pos)
        end_line = bisect.bisect_right(line_offsets, pos + len(stripped))
        # Because only the prefix was matched, pos + len(stripped) can point
        # past the end of the text; never report a line that does not exist.
        end_line = min(end_line, last_line)
        ranges.append((start_line, end_line))

        # Always move forward by at least one character to avoid re-matching
        # the same position for very short chunks.
        search_start = pos + max(1, len(stripped) - overlap)

    return ranges
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
# ─── Python AST scope 解析 ────────────────────────────────
|
|
113
|
+
|
|
114
|
+
def _extract_signature(node):
|
|
115
|
+
"""从 AST 节点提取函数签名字符串。ClassDef 返回基类列表。"""
|
|
116
|
+
if isinstance(node, ast.ClassDef):
|
|
117
|
+
if node.bases:
|
|
118
|
+
bases = []
|
|
119
|
+
for b in node.bases:
|
|
120
|
+
if isinstance(b, ast.Name):
|
|
121
|
+
bases.append(b.id)
|
|
122
|
+
elif isinstance(b, ast.Attribute):
|
|
123
|
+
bases.append(b.attr)
|
|
124
|
+
else:
|
|
125
|
+
bases.append("?")
|
|
126
|
+
return f"class {node.name}({', '.join(bases)})"
|
|
127
|
+
return f"class {node.name}"
|
|
128
|
+
|
|
129
|
+
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
|
130
|
+
args = node.args
|
|
131
|
+
params = []
|
|
132
|
+
all_args = args.args + args.posonlyargs
|
|
133
|
+
defaults_offset = len(all_args) - len(args.defaults)
|
|
134
|
+
for i, arg in enumerate(all_args):
|
|
135
|
+
p = arg.arg
|
|
136
|
+
if arg.annotation and isinstance(arg.annotation, ast.Name):
|
|
137
|
+
p += f": {arg.annotation.id}"
|
|
138
|
+
if i >= defaults_offset:
|
|
139
|
+
p += "=..."
|
|
140
|
+
params.append(p)
|
|
141
|
+
if args.vararg:
|
|
142
|
+
params.append(f"*{args.vararg.arg}")
|
|
143
|
+
if args.kwonlyargs:
|
|
144
|
+
if not args.vararg:
|
|
145
|
+
params.append("*")
|
|
146
|
+
for kw in args.kwonlyargs:
|
|
147
|
+
p = kw.arg
|
|
148
|
+
if kw.annotation and isinstance(kw.annotation, ast.Name):
|
|
149
|
+
p += f": {kw.annotation.id}"
|
|
150
|
+
params.append(p)
|
|
151
|
+
if args.kwarg:
|
|
152
|
+
params.append(f"**{args.kwarg.arg}")
|
|
153
|
+
prefix = "async def" if isinstance(node, ast.AsyncFunctionDef) else "def"
|
|
154
|
+
sig = f"{prefix} {node.name}({', '.join(params)})"
|
|
155
|
+
if node.returns and isinstance(node.returns, ast.Name):
|
|
156
|
+
sig += f" -> {node.returns.id}"
|
|
157
|
+
return sig
|
|
158
|
+
|
|
159
|
+
return ""
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def _build_python_scope_map(text):
|
|
163
|
+
"""解析 Python 代码,构建 scope 列表。
|
|
164
|
+
|
|
165
|
+
返回: [(start_line, end_line, scope_str, signature), ...]
|
|
166
|
+
scope_str 示例: "MyClass", "MyClass.my_method", "my_function"
|
|
167
|
+
signature 示例: "def my_method(self, x, y=...)"
|
|
168
|
+
"""
|
|
169
|
+
try:
|
|
170
|
+
with warnings.catch_warnings():
|
|
171
|
+
warnings.simplefilter("ignore", SyntaxWarning)
|
|
172
|
+
tree = ast.parse(text)
|
|
173
|
+
except SyntaxError:
|
|
174
|
+
return []
|
|
175
|
+
|
|
176
|
+
scopes = []
|
|
177
|
+
|
|
178
|
+
def _walk(node, parent_chain=""):
|
|
179
|
+
for child in ast.iter_child_nodes(node):
|
|
180
|
+
if isinstance(child, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
|
|
181
|
+
name = child.name
|
|
182
|
+
scope_str = f"{parent_chain}.{name}" if parent_chain else name
|
|
183
|
+
start = child.lineno
|
|
184
|
+
end = child.end_lineno if hasattr(child, 'end_lineno') and child.end_lineno else start
|
|
185
|
+
sig = _extract_signature(child)
|
|
186
|
+
scopes.append((start, end, scope_str, sig))
|
|
187
|
+
_walk(child, scope_str)
|
|
188
|
+
|
|
189
|
+
_walk(tree)
|
|
190
|
+
return scopes
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def get_python_scope_at_line(text, line_number):
    """Return the dotted scope chain that contains ``line_number`` in ``text``.

    Examples of return values: ``"MyClass"``, ``"MyClass.my_method"``, or
    ``""`` when no scope from the scope map contains the line.
    """
    # Only the scope name is needed; signature and parent signature are dropped.
    return _find_scope_in_map(_build_python_scope_map(text), line_number)[0]
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
# ─── 头部上下文注入 ───────────────────────────────────────
|
|
204
|
+
|
|
205
|
+
def _format_header(file_path, line_start, line_end, scope="", signature="",
|
|
206
|
+
parent_class_sig=""):
|
|
207
|
+
"""格式化 chunk 头部注入字符串。"""
|
|
208
|
+
parts = [f"[File: {file_path}]"]
|
|
209
|
+
if scope:
|
|
210
|
+
if "." in scope:
|
|
211
|
+
class_part, method_part = scope.rsplit(".", 1)
|
|
212
|
+
if parent_class_sig:
|
|
213
|
+
parts.append(f"[Class: {parent_class_sig}]")
|
|
214
|
+
else:
|
|
215
|
+
parts.append(f"[Class: {class_part}]")
|
|
216
|
+
parts.append(f"[Method: {method_part}]")
|
|
217
|
+
else:
|
|
218
|
+
parts.append(f"[Scope: {scope}]")
|
|
219
|
+
if signature:
|
|
220
|
+
parts.append(f"[Sig: {signature}]")
|
|
221
|
+
if line_start and line_end:
|
|
222
|
+
parts.append(f"[Lines: {line_start}-{line_end}]")
|
|
223
|
+
return "# " + " ".join(parts)
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def _find_scope_in_map(scope_map, line_number):
|
|
227
|
+
"""从预构建的 scope_map 中查找指定行号的最内层 scope。
|
|
228
|
+
返回 (scope_str, signature, parent_class_sig) 元组。
|
|
229
|
+
parent_class_sig: 如果当前 scope 是方法(如 Cls.method),返回父类的签名(含继承链)。"""
|
|
230
|
+
best_scope = ""
|
|
231
|
+
best_sig = ""
|
|
232
|
+
best_size = float('inf')
|
|
233
|
+
parent_class_sig = ""
|
|
234
|
+
for entry in scope_map:
|
|
235
|
+
s_start, s_end, s_str = entry[0], entry[1], entry[2]
|
|
236
|
+
sig = entry[3] if len(entry) > 3 else ""
|
|
237
|
+
if s_start <= line_number <= s_end:
|
|
238
|
+
size = s_end - s_start
|
|
239
|
+
if size < best_size:
|
|
240
|
+
best_size = size
|
|
241
|
+
best_scope = s_str
|
|
242
|
+
best_sig = sig
|
|
243
|
+
if "." in best_scope:
|
|
244
|
+
class_name = best_scope.rsplit(".", 1)[0]
|
|
245
|
+
for entry in scope_map:
|
|
246
|
+
if entry[2] == class_name:
|
|
247
|
+
parent_class_sig = entry[3] if len(entry) > 3 else ""
|
|
248
|
+
break
|
|
249
|
+
return best_scope, best_sig, parent_class_sig
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def enrich_chunks(chunks, full_text, file_path="", file_ext="", overlap=None):
    """Prefix every chunk with a context header and report its location.

    The header is added before embedding so the vector also encodes file,
    scope and line information. Scope information is only available for
    ``.py`` files that parse successfully.

    Args:
        chunks: Chunk strings in document order.
        full_text: The text the chunks were cut from.
        file_path: Path shown in each header.
        file_ext: File suffix (case-insensitive); ".py" enables scope lookup.
        overlap: Overlap that was used when splitting, forwarded to
            ``compute_line_ranges`` so chunk positions are matched with the
            same assumption. ``None`` keeps that function's default.

    Returns:
        ``(enriched_chunks, line_ranges, scopes)`` - three lists parallel to
        ``chunks``; ``line_ranges`` and ``scopes`` are meant for metadata.
    """
    if not chunks:
        return [], [], []

    if overlap is None:
        line_ranges = compute_line_ranges(full_text, chunks)
    else:
        line_ranges = compute_line_ranges(full_text, chunks, overlap)

    scope_map = []
    if file_ext.lower() == ".py":
        scope_map = _build_python_scope_map(full_text)

    enriched = []
    scopes = []
    for chunk, (ls, le) in zip(chunks, line_ranges):
        # A start line of 0 means the chunk could not be located.
        if scope_map and ls > 0:
            scope, sig, parent_cls_sig = _find_scope_in_map(scope_map, ls)
        else:
            scope, sig, parent_cls_sig = "", "", ""
        scopes.append(scope)
        header = _format_header(file_path, ls, le, scope, sig, parent_cls_sig)
        enriched.append(f"{header}\n{chunk}")

    return enriched, line_ranges, scopes
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def collect_files(path, file_extensions=None):
    """Collect text files under ``path`` (a single file or a directory tree).

    Args:
        path: File or directory path. Directories are walked recursively.
        file_extensions: Optional suffixes to accept, e.g. ``[".py", ".md"]``.
            A single string is treated as one suffix. Matching is
            case-insensitive and a missing leading "." is added
            ("py" -> ".py"); blank entries are ignored. When omitted, the
            module-level ``TEXT_EXTENSIONS`` set is used.

    Returns:
        A list of matching file paths. For a directory the order is
        deterministic: directories and file names are visited sorted.
    """
    if file_extensions is None:
        allowed = TEXT_EXTENSIONS
    else:
        if isinstance(file_extensions, str):
            # Iterating a bare string would yield single characters.
            file_extensions = [file_extensions]
        allowed = set()
        for ext in file_extensions:
            ext = ext.strip().lower()
            if not ext:
                continue
            allowed.add(ext if ext.startswith(".") else "." + ext)

    if os.path.isfile(path):
        if os.path.splitext(path)[1].lower() in allowed:
            return [path]
        return []

    files = []
    for root, dirs, fnames in os.walk(path):
        # Sorting in place makes os.walk descend in a stable order.
        dirs.sort()
        for fn in sorted(fnames):
            if os.path.splitext(fn)[1].lower() in allowed:
                files.append(os.path.join(root, fn))
    return files
|
nbrag/config.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
"""
|
|
2
|
+
配置加载模块 — CLI > 环境变量 > YAML 配置文件 > 默认值。
|
|
3
|
+
|
|
4
|
+
最小启动只需要一个环境变量:
|
|
5
|
+
export NBRAG_API_KEY=sk-xxx
|
|
6
|
+
uvx nbrag
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import os
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class EmbeddingConfig:
    """Settings for the embedding service: credential, base URL and model."""

    # API key; empty unless supplied via CLI, NBRAG_API_KEY or the YAML file.
    api_key: str = ""
    # Base URL of the API endpoint.
    base_url: str = "https://api.siliconflow.cn/v1"
    # Name of the embedding model.
    model: str = "BAAI/bge-m3"
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class RerankConfig:
    """Settings for the reranker."""

    # Name of the rerank model.
    model: str = "BAAI/bge-reranker-v2-m3"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class StorageConfig:
    """Filesystem locations used for storage."""

    # Directory of the database.
    db_path: str = "./rag_db"
    # Directory for raw files; when left empty, RagConfig.__post_init__
    # fills it in as <db_path>/raw_files.
    raw_files_path: str = ""
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
class ChunkingConfig:
    """Chunking parameters: chunk size and overlap between chunks."""

    # Chunk size setting (default 1500).
    chunk_size: int = 1500
    # Overlap between consecutive chunks (default 200).
    chunk_overlap: int = 200
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass
class RagConfig:
    """Top-level configuration grouping embedding, rerank, storage and chunking."""

    embedding: EmbeddingConfig = field(default_factory=EmbeddingConfig)
    rerank: RerankConfig = field(default_factory=RerankConfig)
    storage: StorageConfig = field(default_factory=StorageConfig)
    chunking: ChunkingConfig = field(default_factory=ChunkingConfig)

    def __post_init__(self):
        """Default ``storage.raw_files_path`` to ``<db_path>/raw_files`` when unset."""
        storage = self.storage
        if storage.raw_files_path:
            return
        storage.raw_files_path = os.path.join(storage.db_path, "raw_files")
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# Module-level cache of the active configuration; stays None until
# load_config() has run (get_config() triggers that on first use).
_config: "RagConfig | None" = None
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _load_yaml(path):
|
|
53
|
+
"""加载 YAML 配置文件,返回 dict(文件不存在返回空 dict)。"""
|
|
54
|
+
if not path or not os.path.isfile(path):
|
|
55
|
+
return {}
|
|
56
|
+
try:
|
|
57
|
+
import yaml
|
|
58
|
+
with open(path, "r", encoding="utf-8") as f:
|
|
59
|
+
data = yaml.safe_load(f)
|
|
60
|
+
return data if isinstance(data, dict) else {}
|
|
61
|
+
except Exception:
|
|
62
|
+
return {}
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _find_config_file():
|
|
66
|
+
"""按优先级查找配置文件。"""
|
|
67
|
+
candidates = [
|
|
68
|
+
os.path.join(os.getcwd(), "nbrag_config.yaml"),
|
|
69
|
+
os.path.join(os.getcwd(), "nbrag_config.yml"),
|
|
70
|
+
os.path.expanduser("~/.config/nbrag/config.yaml"),
|
|
71
|
+
os.path.expanduser("~/.config/nbrag/config.yml"),
|
|
72
|
+
]
|
|
73
|
+
for c in candidates:
|
|
74
|
+
if os.path.isfile(c):
|
|
75
|
+
return c
|
|
76
|
+
return None
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _resolve_env_ref(value):
|
|
80
|
+
"""解析 ${VAR_NAME} 环境变量引用。"""
|
|
81
|
+
if not isinstance(value, str):
|
|
82
|
+
return value
|
|
83
|
+
if value.startswith("${") and value.endswith("}"):
|
|
84
|
+
var_name = value[2:-1]
|
|
85
|
+
return os.environ.get(var_name, "")
|
|
86
|
+
return value
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _yaml_section(yaml_data, name):
    """Return the mapping stored under ``name``, or ``{}`` if absent or not a mapping."""
    section = yaml_data.get(name)
    return section if isinstance(section, dict) else {}


def _int_setting(env_name, yaml_value, default, allow_zero=False):
    """Pick an integer setting: environment variable, then YAML value, then default.

    Unset/blank values are skipped. Non-positive values are skipped as well,
    except that 0 is accepted when ``allow_zero`` is true.

    Raises:
        ValueError: If a provided value cannot be converted to an integer.
    """
    candidates = (
        ("environment", os.environ.get(env_name)),
        ("YAML config", yaml_value),
    )
    for origin, raw in candidates:
        if raw is None or (isinstance(raw, str) and not raw.strip()):
            continue
        try:
            number = int(raw)
        except (TypeError, ValueError) as exc:
            raise ValueError(
                f"Invalid integer {raw!r} from {origin} (setting {env_name})"
            ) from exc
        if number > 0 or (allow_zero and number == 0):
            return number
    return default


def load_config(cli_args=None) -> RagConfig:
    """Build the configuration and store it as the module-level config.

    Precedence for each setting is CLI argument (where one exists) >
    environment variable > YAML config file > built-in default.

    Args:
        cli_args: Optional object (e.g. an argparse namespace) whose
            ``config``, ``api_key`` and ``db_path`` attributes are read when
            present.

    Returns:
        The newly built ``RagConfig``.

    Raises:
        ValueError: If a chunk size/overlap value is not an integer.
    """
    global _config

    def _cli(name):
        # Missing attributes and a missing cli_args both count as "not given".
        return getattr(cli_args, name, None) if cli_args else None

    yaml_path = _cli("config") or os.environ.get("NBRAG_CONFIG") or _find_config_file()
    yaml_data = _load_yaml(yaml_path)

    # A section written as ``embedding:`` with no value parses to None, so
    # every section is normalised to a dict before use.
    embedding_data = _yaml_section(yaml_data, "embedding")
    rerank_data = _yaml_section(yaml_data, "rerank")
    storage_data = _yaml_section(yaml_data, "storage")
    chunking_data = _yaml_section(yaml_data, "chunking")

    api_key = (
        _cli("api_key")
        or os.environ.get("NBRAG_API_KEY", "")
        or _resolve_env_ref(embedding_data.get("api_key", ""))
    )

    base_url = (
        os.environ.get("NBRAG_BASE_URL", "")
        or embedding_data.get("base_url", "")
        or "https://api.siliconflow.cn/v1"
    )

    embedding_model = (
        os.environ.get("NBRAG_EMBEDDING_MODEL", "")
        or embedding_data.get("model", "")
        or "BAAI/bge-m3"
    )

    rerank_model = (
        os.environ.get("NBRAG_RERANK_MODEL", "")
        or rerank_data.get("model", "")
        or "BAAI/bge-reranker-v2-m3"
    )

    db_path = (
        _cli("db_path")
        or os.environ.get("NBRAG_DB_PATH", "")
        or storage_data.get("db_path", "")
        or "./rag_db"
    )

    # Left empty on purpose when unset: RagConfig.__post_init__ derives it.
    raw_files_path = (
        os.environ.get("NBRAG_RAW_FILES_PATH", "")
        or storage_data.get("raw_files_path", "")
        or ""
    )

    # The env lookups must not default to the string "0": a non-empty string
    # is truthy and would mask the YAML value and the built-in default.
    chunk_size = _int_setting(
        "NBRAG_CHUNK_SIZE", chunking_data.get("chunk_size"), 1500
    )
    chunk_overlap = _int_setting(
        "NBRAG_CHUNK_OVERLAP", chunking_data.get("chunk_overlap"), 200,
        allow_zero=True,
    )

    _config = RagConfig(
        embedding=EmbeddingConfig(api_key=api_key, base_url=base_url, model=embedding_model),
        rerank=RerankConfig(model=rerank_model),
        storage=StorageConfig(db_path=db_path, raw_files_path=raw_files_path),
        chunking=ChunkingConfig(chunk_size=chunk_size, chunk_overlap=chunk_overlap),
    )
    return _config
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def get_config() -> RagConfig:
    """Return the active configuration, loading it on first use.

    When nothing has been loaded yet, ``load_config()`` is called without
    CLI arguments.
    """
    if _config is not None:
        return _config
    load_config()
    return _config
|