nbrag 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nbrag/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ """
2
+ nbrag — Agentic RAG MCP Server
3
+
4
+ AI-driven multi-round code retrieval with 12 complementary tools.
5
+ """
6
+
7
# Must match the version of the published distribution (this wheel is 0.2.0).
__version__ = "0.2.0"
nbrag/__main__.py ADDED
@@ -0,0 +1,5 @@
1
+ """python -m nbrag 入口。"""
2
+
3
from nbrag.server import main

# Only start the server when executed (``python -m nbrag``); merely importing
# ``nbrag.__main__`` (e.g. by tooling or tests) must not launch it.
if __name__ == "__main__":
    main()
nbrag/chunker.py ADDED
@@ -0,0 +1,311 @@
1
+ """
2
+ RAG 分块增强模块 —— 文本切分 + 行号计算 + AST scope 解析 + 头部上下文注入。
3
+
4
+ 从 core.py 拆分出来,专注于「把原始文本变成 enriched chunks」这一步。
5
+ core.py 专注于存储和检索。
6
+ """
7
+
8
+ import os
9
+ import ast
10
+ import bisect
11
+ import warnings
12
+
13
+ from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
14
+
15
+
16
# Default splitter window, in characters.
DEFAULT_CHUNK_SIZE = 1500
DEFAULT_CHUNK_OVERLAP = 200

# ─── File type mapping ────────────────────────────────────

# Extensions collected by default when the caller gives no explicit filter.
# Every entry is lowercase and starts with "." because lookups compare against
# ``os.path.splitext(name)[1].lower()``.
# NOTE(review): ".env" files frequently hold secrets — confirm that indexing
# them by default is intended.
TEXT_EXTENSIONS = {
    ".txt", ".md", ".py", ".js", ".ts", ".jsx", ".tsx", ".java", ".go",
    ".rs", ".c", ".cpp", ".h", ".hpp", ".cs", ".rb", ".php", ".sh",
    ".json", ".yaml", ".yml", ".toml", ".ini", ".cfg", ".conf",
    ".csv", ".tsv", ".xml", ".html", ".css", ".sql", ".r", ".lua",
    ".swift", ".kt", ".scala", ".dart", ".vue", ".svelte",
    ".rst", ".tex", ".log", ".env", ".bat", ".ps1",
    # Solidity has a dedicated splitter mapping, so it must be collectable too.
    ".sol",
}
29
+
30
# File extension -> langchain ``Language`` whose separators drive the splitter.
# Extensions that are absent here fall back to the generic text separators.
_EXT_TO_LANG = {
    ".py": Language.PYTHON,
    ".js": Language.JS,
    ".ts": Language.TS,
    ".jsx": Language.JS,
    ".tsx": Language.TS,
    ".java": Language.JAVA,
    ".go": Language.GO,
    ".rs": Language.RUST,
    ".rb": Language.RUBY,
    ".cpp": Language.CPP,
    ".c": Language.C,
    ".h": Language.C,
    ".hpp": Language.CPP,
    ".cs": Language.CSHARP,
    ".php": Language.PHP,
    ".swift": Language.SWIFT,
    ".kt": Language.KOTLIN,
    ".scala": Language.SCALA,
    ".lua": Language.LUA,
    ".sol": Language.SOL,
    ".md": Language.MARKDOWN,
    ".html": Language.HTML,
    ".tex": Language.LATEX,
    ".rst": Language.RST,
}
42
+
43
+
44
+ # ─── 文本切分 ─────────────────────────────────────────────
45
+
46
def chunk_text(text, chunk_size=DEFAULT_CHUNK_SIZE, overlap=DEFAULT_CHUNK_OVERLAP,
               file_ext=""):
    """Split *text* with the strategy best suited to its file type.

    Extensions listed in ``_EXT_TO_LANG`` use the language-aware separators of
    ``RecursiveCharacterTextSplitter.from_language``; everything else is split
    on paragraphs, lines and (CJK/ASCII) punctuation. No header is injected.

    Args:
        text: Raw file content. Leading/trailing whitespace is stripped first.
        chunk_size: Target chunk size passed to the splitter.
        overlap: Chunk overlap passed to the splitter.
        file_ext: File extension. Case-insensitive; a missing leading dot is
            added ("py" -> ".py"), and ``None`` is treated as "no extension".

    Returns:
        A list of chunk strings; empty when *text* is blank.
    """
    text = text.strip()
    if not text:
        return []

    # Normalize the same way collect_files() does, so "py", ".PY" and ".py"
    # all select the Python splitter instead of silently degrading to the
    # generic text splitter.
    ext = (file_ext or "").strip().lower()
    if ext and not ext.startswith("."):
        ext = "." + ext

    lang = _EXT_TO_LANG.get(ext)
    if lang:
        splitter = RecursiveCharacterTextSplitter.from_language(
            language=lang, chunk_size=chunk_size, chunk_overlap=overlap,
        )
    else:
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=overlap,
            separators=["\n\n", "\n", "。", ".", ";", ";", ",", ",", " ", ""],
        )
    return splitter.split_text(text)
66
+
67
+
68
+ # ─── 行号计算 ─────────────────────────────────────────────
69
+
70
def _build_line_offsets(text):
    """Return the character offset at which each line of *text* starts.

    ``result[i]`` is the start offset of line ``i + 1``. One extra trailing
    entry (one past the end of the last line) is always present, so the list
    has ``number_of_lines + 1`` items.
    """
    starts = [0]
    cursor = 0
    for segment in text.split('\n'):
        # +1 accounts for the newline that split() removed.
        cursor += len(segment) + 1
        starts.append(cursor)
    return starts
76
+
77
+
78
def compute_line_ranges(full_text, chunks, overlap=DEFAULT_CHUNK_OVERLAP):
    """Locate every chunk in *full_text* and report its 1-based line span.

    Chunks are searched for in order, starting just before the end of the
    previous match (minus *overlap*), and falling back to a search from the
    beginning of the text. Only the first 200 characters of the stripped
    chunk are used as the search key.

    Returns:
        ``[(start_line, end_line), ...]`` aligned with *chunks*; ``(0, 0)``
        marks a blank chunk or one that could not be found.
    """
    offsets = _build_line_offsets(full_text)
    not_found = (0, 0)
    spans = []
    cursor = 0

    for piece in chunks:
        body = piece.strip()
        if not body:
            spans.append(not_found)
            continue

        probe = body[:200]
        hit = full_text.find(probe, cursor)
        if hit < 0:
            # Not found after the cursor: retry from the top of the text.
            hit = full_text.find(probe)
        if hit < 0:
            spans.append(not_found)
            continue

        first_line = bisect.bisect_right(offsets, hit)
        last_line = bisect.bisect_right(offsets, hit + len(body))
        spans.append((first_line, last_line))
        # Step forward by the chunk length minus the overlap, at least 1.
        cursor = hit + max(1, len(body) - overlap)

    return spans
110
+
111
+
112
+ # ─── Python AST scope 解析 ────────────────────────────────
113
+
114
def _extract_signature(node):
    """Render a compact, single-line signature for a class or function node.

    Args:
        node: An ``ast`` node. Only ``ClassDef``, ``FunctionDef`` and
            ``AsyncFunctionDef`` produce output.

    Returns:
        * ``"class Name(Base1, Base2)"`` for classes (attribute bases are
          shortened to their last component, anything else becomes ``"?"``);
        * ``"def name(a, /, b: int=..., *args, k=..., **kw) -> T"`` for
          functions. Defaults are abbreviated to ``=...`` and only simple-name
          annotations are shown;
        * ``""`` for any other node type.
    """
    if isinstance(node, ast.ClassDef):
        if not node.bases:
            return f"class {node.name}"
        bases = []
        for base in node.bases:
            if isinstance(base, ast.Name):
                bases.append(base.id)
            elif isinstance(base, ast.Attribute):
                bases.append(base.attr)
            else:
                bases.append("?")
        return f"class {node.name}({', '.join(bases)})"

    if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
        return ""

    def _render(arg, has_default):
        """Format one parameter as ``name[: Annotation][=...]``."""
        text = arg.arg
        if isinstance(arg.annotation, ast.Name):
            text += f": {arg.annotation.id}"
        if has_default:
            text += "=..."
        return text

    args = node.args
    # Positional-only parameters come FIRST in the real signature, and
    # ``args.defaults`` applies to the tail of (posonlyargs + args).
    posonly = list(getattr(args, "posonlyargs", []))
    positional = posonly + list(args.args)
    first_default = len(positional) - len(args.defaults)

    params = []
    for index, arg in enumerate(positional):
        params.append(_render(arg, index >= first_default))
        if posonly and index == len(posonly) - 1:
            params.append("/")

    if args.vararg:
        params.append(f"*{args.vararg.arg}")
    if args.kwonlyargs:
        if not args.vararg:
            # A bare "*" separates keyword-only parameters when there is no *args.
            params.append("*")
        # kw_defaults is aligned with kwonlyargs; None means "no default".
        for kw, default in zip(args.kwonlyargs, args.kw_defaults):
            params.append(_render(kw, default is not None))
    if args.kwarg:
        params.append(f"**{args.kwarg.arg}")

    prefix = "async def" if isinstance(node, ast.AsyncFunctionDef) else "def"
    sig = f"{prefix} {node.name}({', '.join(params)})"
    if isinstance(node.returns, ast.Name):
        sig += f" -> {node.returns.id}"
    return sig
160
+
161
+
162
def _build_python_scope_map(text):
    """Parse Python source and list every class/function scope it defines.

    Definitions are found at any statement depth, including those nested in
    ``if``/``try``/``with``/loop blocks, and are reported in source (pre-order)
    order.

    Args:
        text: Python source code.

    Returns:
        ``[(start_line, end_line, scope_str, signature), ...]`` where
        ``scope_str`` is the dotted chain of enclosing definitions
        (``"MyClass"``, ``"MyClass.my_method"``, ``"my_function"``) and
        ``signature`` comes from ``_extract_signature``. Returns ``[]`` when
        the source cannot be parsed.
    """
    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", SyntaxWarning)
            tree = ast.parse(text)
    except (SyntaxError, ValueError, RecursionError):
        # ValueError: e.g. null bytes on older interpreters;
        # RecursionError: pathologically nested source.
        return []

    definition_types = (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)
    scopes = []

    # Explicit stack instead of recursion so deep nesting cannot overflow.
    pending = [(tree, "")]
    while pending:
        node, chain = pending.pop()
        if isinstance(node, definition_types):
            scope_str = f"{chain}.{node.name}" if chain else node.name
            start = node.lineno
            end = getattr(node, "end_lineno", None) or start
            scopes.append((start, end, scope_str, _extract_signature(node)))
            chain = scope_str
        # Expressions cannot contain def/class statements, so skip them.
        children = [
            child for child in ast.iter_child_nodes(node)
            if not isinstance(child, ast.expr)
        ]
        # Reversed so that popping visits children in source order.
        pending.extend((child, chain) for child in reversed(children))

    return scopes
191
+
192
+
193
def get_python_scope_at_line(text, line_number):
    """Return the dotted scope chain that contains *line_number*.

    The source is parsed on every call. Examples of results: ``"MyClass"``,
    ``"MyClass.my_method"``, or ``""`` for module-level lines.
    """
    innermost = _find_scope_in_map(_build_python_scope_map(text), line_number)
    return innermost[0]
201
+
202
+
203
+ # ─── 头部上下文注入 ───────────────────────────────────────
204
+
205
def _format_header(file_path, line_start, line_end, scope="", signature="",
                   parent_class_sig=""):
    """Build the one-line ``# [File: ...] ...`` header prepended to a chunk.

    A dotted *scope* is rendered as ``[Class: ...] [Method: ...]`` (the class
    tag shows *parent_class_sig* when given, otherwise the scope prefix); an
    undotted scope becomes ``[Scope: ...]``. ``[Sig: ...]`` and
    ``[Lines: a-b]`` are appended only when their values are truthy.
    """
    tags = [f"[File: {file_path}]"]

    if scope:
        owner, dot, member = scope.rpartition(".")
        if dot:
            tags.append(f"[Class: {parent_class_sig or owner}]")
            tags.append(f"[Method: {member}]")
        else:
            tags.append(f"[Scope: {scope}]")

    if signature:
        tags.append(f"[Sig: {signature}]")

    if line_start and line_end:
        tags.append(f"[Lines: {line_start}-{line_end}]")

    return "# " + " ".join(tags)
224
+
225
+
226
def _find_scope_in_map(scope_map, line_number):
    """Find the innermost scope of a prebuilt scope map covering a line.

    "Innermost" means the covering entry with the smallest line span; on a
    tie the earliest entry wins.

    Returns:
        ``(scope_str, signature, parent_sig)``. ``parent_sig`` is the
        signature of the first entry whose scope string equals the dotted
        prefix of ``scope_str`` (empty when the scope is not dotted or no such
        entry exists). All three are ``""`` when no entry covers the line.
    """
    def _sig_of(item):
        # Entries may be 3-tuples without a signature.
        return item[3] if len(item) > 3 else ""

    covering = [item for item in scope_map if item[0] <= line_number <= item[1]]
    if not covering:
        return "", "", ""

    innermost = min(covering, key=lambda item: item[1] - item[0])
    scope_name = innermost[2]

    owner_sig = ""
    if "." in scope_name:
        owner = scope_name.rsplit(".", 1)[0]
        owner_sig = next(
            (_sig_of(item) for item in scope_map if item[2] == owner), ""
        )
    return scope_name, _sig_of(innermost), owner_sig
250
+
251
+
252
def enrich_chunks(chunks, full_text, file_path="", file_ext="",
                  overlap=DEFAULT_CHUNK_OVERLAP):
    """Prefix every chunk with a context header and collect its metadata.

    The header (file, scope, signature, line range) is injected before
    embedding, so the vectors carry that context too.

    Args:
        chunks: Chunk strings produced from *full_text*.
        full_text: The original file content the chunks were cut from.
        file_path: Path shown in the ``[File: ...]`` tag.
        file_ext: File extension; case-insensitive, leading dot optional.
            Python scope/signature tags are only produced for ``.py``.
        overlap: The chunk overlap that was used when splitting. It is
            forwarded to ``compute_line_ranges`` so chunk positions are
            searched from the right place when a non-default overlap is used.

    Returns:
        ``(enriched_chunks, line_ranges, scopes)``, three lists aligned with
        *chunks*; three empty lists when *chunks* is empty.
    """
    if not chunks:
        return [], [], []

    line_ranges = compute_line_ranges(full_text, chunks, overlap)

    # Accept "py", ".PY" and ".py" alike, matching collect_files().
    ext = (file_ext or "").strip().lower()
    if ext and not ext.startswith("."):
        ext = "." + ext
    scope_map = _build_python_scope_map(full_text) if ext == ".py" else []

    enriched = []
    scopes = []
    for chunk, (ls, le) in zip(chunks, line_ranges):
        # ls == 0 means the chunk could not be located in full_text.
        if scope_map and ls > 0:
            scope, sig, parent_cls_sig = _find_scope_in_map(scope_map, ls)
        else:
            scope, sig, parent_cls_sig = "", "", ""
        scopes.append(scope)
        header = _format_header(file_path, ls, le, scope, sig, parent_cls_sig)
        enriched.append(f"{header}\n{chunk}")

    return enriched, line_ranges, scopes
280
+
281
+
282
def collect_files(path, file_extensions=None):
    """Collect text files under *path* (a single file or a directory tree).

    Args:
        path: File or directory path. Directories are walked recursively.
        file_extensions: Optional extension filter. May be a list such as
            ``[".py", "md"]`` or a single string (``"py"``, ``"py, md"``).
            Matching is case-insensitive and a missing leading dot is added.
            Blank entries are ignored. When omitted, ``TEXT_EXTENSIONS`` is
            used.

    Returns:
        A list of matching file paths. Directories are visited in sorted
        order and files are sorted within each directory, so the result is
        deterministic across platforms. A single file that does not match
        yields ``[]``.
    """
    if file_extensions is None:
        allowed = TEXT_EXTENSIONS
    else:
        if isinstance(file_extensions, str):
            # A bare string would otherwise be iterated character by character.
            file_extensions = file_extensions.replace(",", " ").split()
        allowed = set()
        for ext in file_extensions:
            ext = ext.strip().lower()
            if not ext or ext == ".":
                continue
            if not ext.startswith("."):
                ext = "." + ext
            allowed.add(ext)

    if os.path.isfile(path):
        ext = os.path.splitext(path)[1].lower()
        return [path] if ext in allowed else []

    files = []
    for root, dirs, fnames in os.walk(path):
        # In-place sort makes os.walk descend in a stable order.
        dirs.sort()
        for fn in sorted(fnames):
            if os.path.splitext(fn)[1].lower() in allowed:
                files.append(os.path.join(root, fn))
    return files
nbrag/config.py ADDED
@@ -0,0 +1,169 @@
1
+ """
2
+ 配置加载模块 — CLI > 环境变量 > YAML 配置文件 > 默认值。
3
+
4
+ 最小启动只需要一个环境变量:
5
+ export NBRAG_API_KEY=sk-xxx
6
+ uvx nbrag
7
+ """
8
+
9
+ import os
10
+ from dataclasses import dataclass, field
11
+
12
+
13
@dataclass
class EmbeddingConfig:
    """Connection settings for the embedding API."""

    # API key; empty until supplied via CLI, environment or YAML.
    api_key: str = ""
    # Base URL of the API endpoint.
    base_url: str = "https://api.siliconflow.cn/v1"
    # Embedding model identifier.
    model: str = "BAAI/bge-m3"
18
+
19
+
20
@dataclass
class RerankConfig:
    """Settings for the reranking model."""

    # Rerank model identifier.
    model: str = "BAAI/bge-reranker-v2-m3"
23
+
24
+
25
@dataclass
class StorageConfig:
    """Filesystem locations used for persistence."""

    # Database directory.
    db_path: str = "./rag_db"
    # Raw file directory; when left empty RagConfig fills in
    # "<db_path>/raw_files".
    raw_files_path: str = ""
29
+
30
+
31
@dataclass
class ChunkingConfig:
    """Text splitting parameters."""

    # Target chunk size.
    chunk_size: int = 1500
    # Overlap between consecutive chunks.
    chunk_overlap: int = 200
35
+
36
+
37
@dataclass
class RagConfig:
    """Top-level configuration bundling all sub-sections."""

    embedding: EmbeddingConfig = field(default_factory=EmbeddingConfig)
    rerank: RerankConfig = field(default_factory=RerankConfig)
    storage: StorageConfig = field(default_factory=StorageConfig)
    chunking: ChunkingConfig = field(default_factory=ChunkingConfig)

    def __post_init__(self):
        """Derive the raw file directory from ``db_path`` when it is unset."""
        storage = self.storage
        if storage.raw_files_path:
            return
        storage.raw_files_path = os.path.join(storage.db_path, "raw_files")
47
+
48
+
49
# Process-wide configuration cache; stays None until load_config() has run.
_config: "RagConfig | None" = None
50
+
51
+
52
def _load_yaml(path):
    """Load a YAML configuration file as a dict.

    Loading is best-effort: a missing path, a missing PyYAML installation, an
    unreadable or malformed file, or a document whose top level is not a
    mapping all yield ``{}`` so that startup can continue with other
    configuration sources. Unlike a silent fallback, every failure on an
    existing file is reported through ``logging`` so a broken config does not
    go unnoticed.

    Args:
        path: Path to the YAML file; falsy or non-existent paths are skipped.

    Returns:
        The parsed top-level mapping, or ``{}``.
    """
    import logging

    if not path or not os.path.isfile(path):
        return {}

    log = logging.getLogger(__name__)

    try:
        import yaml
    except ImportError:
        log.warning("PyYAML is not installed; ignoring config file %s", path)
        return {}

    try:
        with open(path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
    except Exception as exc:  # best-effort boundary: never block startup
        log.warning("Could not load config file %s: %s", path, exc)
        return {}

    if isinstance(data, dict):
        return data
    if data is not None:
        log.warning("Config file %s does not contain a mapping; ignoring it", path)
    return {}
63
+
64
+
65
def _find_config_file():
    """Return the first existing config file in priority order, else None.

    The current working directory (``nbrag_config.yaml`` / ``.yml``) is
    checked before ``~/.config/nbrag/config.yaml`` / ``.yml``.
    """
    cwd = os.getcwd()
    lookup_order = (
        os.path.join(cwd, "nbrag_config.yaml"),
        os.path.join(cwd, "nbrag_config.yml"),
        os.path.expanduser("~/.config/nbrag/config.yaml"),
        os.path.expanduser("~/.config/nbrag/config.yml"),
    )
    return next((p for p in lookup_order if os.path.isfile(p)), None)
77
+
78
+
79
def _resolve_env_ref(value):
    """Expand a whole-string ``${VAR_NAME}`` reference from the environment.

    Only a string that starts with ``${`` and ends with ``}`` is treated as a
    reference; it resolves to the variable's value, or ``""`` when the
    variable is unset. Every other value (including non-strings) is returned
    unchanged.
    """
    is_reference = (
        isinstance(value, str)
        and value.startswith("${")
        and value.endswith("}")
    )
    if not is_reference:
        return value
    return os.environ.get(value[2:-1], "")
87
+
88
+
89
def load_config(cli_args=None) -> RagConfig:
    """Build, cache and return the process-wide configuration.

    Each setting is resolved with the precedence
    CLI argument > environment variable > YAML file > built-in default.
    Only ``config`` (YAML path), ``api_key`` and ``db_path`` are read from
    *cli_args*; everything else comes from ``NBRAG_*`` environment variables
    or the YAML file.

    Integer settings skip candidates that are missing, not parseable, or out
    of range, and then fall through to the next source. ``chunk_size`` must be
    >= 1; ``chunk_overlap`` must be >= 0, so an explicit 0 disables overlap.

    Args:
        cli_args: Optional namespace-like object (e.g. from ``argparse``).

    Returns:
        The new ``RagConfig``, which is also stored in the module-level
        ``_config`` cache.
    """
    global _config

    def _cli(name):
        """Return a CLI attribute, or None when absent or no CLI args given."""
        return getattr(cli_args, name, None) if cli_args else None

    def _section(data, name):
        """Return a YAML section as a dict, tolerating null/scalar sections."""
        value = data.get(name)
        return value if isinstance(value, dict) else {}

    def _first_int(candidates, default, minimum):
        """Return the first candidate that parses to an int >= minimum."""
        for raw in candidates:
            if raw is None or isinstance(raw, bool):
                continue
            try:
                number = int(raw.strip()) if isinstance(raw, str) else int(raw)
            except (TypeError, ValueError):
                continue
            if number >= minimum:
                return number
        return default

    env = os.environ.get

    yaml_path = _cli("config") or env("NBRAG_CONFIG") or _find_config_file()
    yaml_data = _load_yaml(yaml_path)

    embedding_data = _section(yaml_data, "embedding")
    rerank_data = _section(yaml_data, "rerank")
    storage_data = _section(yaml_data, "storage")
    chunking_data = _section(yaml_data, "chunking")

    api_key = (
        _cli("api_key")
        or env("NBRAG_API_KEY", "")
        # The YAML value may be a "${VAR}" reference instead of a literal key.
        or _resolve_env_ref(embedding_data.get("api_key", ""))
    )

    base_url = (
        env("NBRAG_BASE_URL", "")
        or embedding_data.get("base_url", "")
        or "https://api.siliconflow.cn/v1"
    )

    embedding_model = (
        env("NBRAG_EMBEDDING_MODEL", "")
        or embedding_data.get("model", "")
        or "BAAI/bge-m3"
    )

    rerank_model = (
        env("NBRAG_RERANK_MODEL", "")
        or rerank_data.get("model", "")
        or "BAAI/bge-reranker-v2-m3"
    )

    db_path = (
        _cli("db_path")
        or env("NBRAG_DB_PATH", "")
        or storage_data.get("db_path", "")
        or "./rag_db"
    )

    # Left empty on purpose when unset: RagConfig derives it from db_path.
    raw_files_path = (
        env("NBRAG_RAW_FILES_PATH", "")
        or storage_data.get("raw_files_path", "")
        or ""
    )

    # NOTE: the environment lookup must NOT default to "0" — the string "0" is
    # truthy and would short-circuit the chain, forcing the value to 0.
    chunk_size = _first_int(
        (env("NBRAG_CHUNK_SIZE"), chunking_data.get("chunk_size")),
        default=1500, minimum=1,
    )
    chunk_overlap = _first_int(
        (env("NBRAG_CHUNK_OVERLAP"), chunking_data.get("chunk_overlap")),
        default=200, minimum=0,
    )

    _config = RagConfig(
        embedding=EmbeddingConfig(api_key=api_key, base_url=base_url, model=embedding_model),
        rerank=RerankConfig(model=rerank_model),
        storage=StorageConfig(db_path=db_path, raw_files_path=raw_files_path),
        chunking=ChunkingConfig(chunk_size=chunk_size, chunk_overlap=chunk_overlap),
    )
    return _config
162
+
163
+
164
def get_config() -> RagConfig:
    """Return the cached configuration, loading it on first use.

    When nothing has been loaded yet, ``load_config()`` is called without CLI
    arguments, so only the environment, YAML file and defaults apply.
    """
    if _config is None:
        # load_config() stores its result in the module-level cache and
        # returns that same object.
        return load_config()
    return _config