jarvis-ai-assistant 0.3.23__py3-none-any.whl → 0.3.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. jarvis/__init__.py +1 -1
  2. jarvis/jarvis_agent/__init__.py +96 -13
  3. jarvis/jarvis_agent/agent_manager.py +0 -3
  4. jarvis/jarvis_agent/jarvis.py +19 -34
  5. jarvis/jarvis_agent/main.py +2 -8
  6. jarvis/jarvis_code_agent/code_agent.py +5 -11
  7. jarvis/jarvis_code_analysis/code_review.py +12 -40
  8. jarvis/jarvis_data/config_schema.json +11 -18
  9. jarvis/jarvis_git_utils/git_commiter.py +11 -25
  10. jarvis/jarvis_mcp/sse_mcp_client.py +4 -3
  11. jarvis/jarvis_mcp/streamable_mcp_client.py +9 -8
  12. jarvis/jarvis_memory_organizer/memory_organizer.py +46 -53
  13. jarvis/jarvis_methodology/main.py +4 -2
  14. jarvis/jarvis_platform/base.py +90 -21
  15. jarvis/jarvis_platform/kimi.py +16 -22
  16. jarvis/jarvis_platform/registry.py +7 -14
  17. jarvis/jarvis_platform/tongyi.py +21 -32
  18. jarvis/jarvis_platform/yuanbao.py +15 -17
  19. jarvis/jarvis_platform_manager/main.py +14 -51
  20. jarvis/jarvis_rag/cli.py +21 -13
  21. jarvis/jarvis_rag/embedding_manager.py +138 -6
  22. jarvis/jarvis_rag/llm_interface.py +0 -2
  23. jarvis/jarvis_rag/rag_pipeline.py +41 -17
  24. jarvis/jarvis_rag/reranker.py +24 -2
  25. jarvis/jarvis_rag/retriever.py +21 -23
  26. jarvis/jarvis_smart_shell/main.py +1 -10
  27. jarvis/jarvis_tools/cli/main.py +22 -15
  28. jarvis/jarvis_tools/edit_file.py +6 -6
  29. jarvis/jarvis_tools/execute_script.py +1 -2
  30. jarvis/jarvis_tools/file_analyzer.py +12 -6
  31. jarvis/jarvis_tools/registry.py +13 -10
  32. jarvis/jarvis_tools/sub_agent.py +5 -8
  33. jarvis/jarvis_tools/sub_code_agent.py +5 -5
  34. jarvis/jarvis_utils/config.py +24 -10
  35. jarvis/jarvis_utils/input.py +8 -5
  36. jarvis/jarvis_utils/methodology.py +11 -6
  37. jarvis/jarvis_utils/utils.py +29 -12
  38. {jarvis_ai_assistant-0.3.23.dist-info → jarvis_ai_assistant-0.3.25.dist-info}/METADATA +10 -3
  39. {jarvis_ai_assistant-0.3.23.dist-info → jarvis_ai_assistant-0.3.25.dist-info}/RECORD +43 -43
  40. {jarvis_ai_assistant-0.3.23.dist-info → jarvis_ai_assistant-0.3.25.dist-info}/WHEEL +0 -0
  41. {jarvis_ai_assistant-0.3.23.dist-info → jarvis_ai_assistant-0.3.25.dist-info}/entry_points.txt +0 -0
  42. {jarvis_ai_assistant-0.3.23.dist-info → jarvis_ai_assistant-0.3.25.dist-info}/licenses/LICENSE +0 -0
  43. {jarvis_ai_assistant-0.3.23.dist-info → jarvis_ai_assistant-0.3.25.dist-info}/top_level.txt +0 -0
jarvis/jarvis_platform/yuanbao.py CHANGED
@@ -134,10 +134,11 @@ class YuanbaoPlatform(BasePlatform):
 
         for file_path in file_list:
             file_name = os.path.basename(file_path)
-            PrettyOutput.print(f"上传文件 {file_name}", OutputType.INFO)
+            log_lines: list[str] = []
+            log_lines.append(f"上传文件 {file_name}")
             try:
                 # 1. Prepare the file information
-                PrettyOutput.print(f"准备文件信息: {file_name}", OutputType.INFO)
+                log_lines.append(f"准备文件信息: {file_name}")
                 file_size = os.path.getsize(file_path)
                 file_extension = os.path.splitext(file_path)[1].lower().lstrip(".")
 
@@ -192,23 +193,23 @@ class YuanbaoPlatform(BasePlatform):
                     file_type = "code"
 
                 # 2. Generate upload information
-                PrettyOutput.print(f"获取上传信息: {file_name}", OutputType.INFO)
+                log_lines.append(f"获取上传信息: {file_name}")
                 upload_info = self._generate_upload_info(file_name)
                 if not upload_info:
-                    PrettyOutput.print(
-                        f"无法获取文件 {file_name} 的上传信息", OutputType.ERROR
-                    )
+                    log_lines.append(f"无法获取文件 {file_name} 的上传信息")
+                    PrettyOutput.print("\n".join(log_lines), OutputType.ERROR)
                     return False
 
                 # 3. Upload the file to COS
-                PrettyOutput.print(f"上传文件到云存储: {file_name}", OutputType.INFO)
+                log_lines.append(f"上传文件到云存储: {file_name}")
                 upload_success = self._upload_file_to_cos(file_path, upload_info)
                 if not upload_success:
-                    PrettyOutput.print(f"上传文件 {file_name} 失败", OutputType.ERROR)
+                    log_lines.append(f"上传文件 {file_name} 失败")
+                    PrettyOutput.print("\n".join(log_lines), OutputType.ERROR)
                     return False
 
                 # 4. Create file metadata for chat
-                PrettyOutput.print(f"生成文件元数据: {file_name}", OutputType.INFO)
+                log_lines.append(f"生成文件元数据: {file_name}")
                 file_metadata = {
                     "type": file_type,
                     "docType": file_extension if file_extension else file_type,
@@ -226,19 +227,16 @@ class YuanbaoPlatform(BasePlatform):
                             file_metadata["width"] = img.width
                             file_metadata["height"] = img.height
                     except Exception as e:
-                        PrettyOutput.print(
-                            f"无法获取图片 {file_name} 的尺寸: {str(e)}",
-                            OutputType.WARNING,
-                        )
+                        log_lines.append(f"无法获取图片 {file_name} 的尺寸: {str(e)}")
 
                 uploaded_files.append(file_metadata)
-                PrettyOutput.print(f"文件 {file_name} 上传成功", OutputType.SUCCESS)
+                log_lines.append(f"文件 {file_name} 上传成功")
+                PrettyOutput.print("\n".join(log_lines), OutputType.INFO)
                 time.sleep(3)  # 上传成功后等待3秒
 
             except Exception as e:
-                PrettyOutput.print(
-                    f"上传文件 {file_path} 时出错: {str(e)}", OutputType.ERROR
-                )
+                log_lines.append(f"上传文件 {file_path} 时出错: {str(e)}")
+                PrettyOutput.print("\n".join(log_lines), OutputType.ERROR)
                 return False
 
         self.multimedia = uploaded_files
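The yuanbao.py change above replaces per-step PrettyOutput calls with a per-file log_lines buffer that is flushed once per outcome. A minimal, self-contained sketch of that pattern follows; `upload` is a hypothetical stand-in for the real prepare/upload/metadata steps, and plain print() replaces PrettyOutput so the example runs on its own.

```python
# Sketch of the aggregated-logging pattern: buffer per-file messages and emit
# them as a single block on success or failure, instead of one boxed line per step.
from typing import Callable, List


def upload_with_single_log(file_name: str, upload: Callable[[str], bool]) -> bool:
    log_lines: List[str] = []
    log_lines.append(f"上传文件 {file_name}")
    try:
        log_lines.append(f"准备文件信息: {file_name}")
        if not upload(file_name):
            log_lines.append(f"上传文件 {file_name} 失败")
            print("\n".join(log_lines))  # emitted as one ERROR block in the real code
            return False
        log_lines.append(f"文件 {file_name} 上传成功")
        print("\n".join(log_lines))  # emitted as one INFO block in the real code
        return True
    except Exception as e:
        log_lines.append(f"上传文件 {file_name} 时出错: {e}")
        print("\n".join(log_lines))
        return False


if __name__ == "__main__":
    upload_with_single_log("example.txt", lambda name: True)
```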
jarvis/jarvis_platform_manager/main.py CHANGED
@@ -11,8 +11,6 @@ import typer
 from jarvis.jarvis_utils.config import (
     get_normal_platform_name,
     get_normal_model_name,
-    get_thinking_platform_name,
-    get_thinking_model_name,
 )
 
 from jarvis.jarvis_platform.registry import PlatformRegistry
@@ -66,7 +64,7 @@ def list_platforms(
 
 
 def chat_with_model(
-    platform_name: str, model_name: str, system_prompt: str, llm_type: str = "normal"
+    platform_name: str, model_name: str, system_prompt: str
 ) -> None:
     """与指定平台和模型进行对话。
 
@@ -74,7 +72,7 @@ def chat_with_model(
         platform_name: 平台名称
         model_name: 模型名称
         system_prompt: 系统提示语
-        llm_type: LLM类型,可选值:'normal'(普通)或 'thinking'(思考模式)
+
     """
     registry = PlatformRegistry.get_global_platform_registry()
     conversation_history: List[Dict[str, str]] = []  # 存储对话记录
@@ -360,32 +358,19 @@ def chat_command(
         None, "--platform", "-p", help="指定要使用的平台"
     ),
     model: Optional[str] = typer.Option(None, "--model", "-m", help="指定要使用的模型"),
-    llm_type: str = typer.Option(
-        "normal",
-        "-t",
-        "--llm-type",
-        help="使用的LLM类型,可选值:'normal'(普通)或 'thinking'(思考模式)",
-    ),
+
    llm_group: Optional[str] = typer.Option(
        None, "-g", "--llm-group", help="使用的模型组,覆盖配置文件中的设置"
    ),
 ) -> None:
     """与指定平台和模型聊天。"""
     # 如果未提供平台或模型参数,则从config获取默认值
-    platform = platform or (
-        get_thinking_platform_name(llm_group)
-        if llm_type == "thinking"
-        else get_normal_platform_name(llm_group)
-    )
-    model = model or (
-        get_thinking_model_name(llm_group)
-        if llm_type == "thinking"
-        else get_normal_model_name(llm_group)
-    )
+    platform = platform or get_normal_platform_name(llm_group)
+    model = model or get_normal_model_name(llm_group)
 
     if not validate_platform_model(platform, model):
         return
-    chat_with_model(platform, model, "", llm_type)
+    chat_with_model(platform, model, "")
 
 
 @app.command("service")
@@ -444,12 +429,7 @@ def role_command(
     model: Optional[str] = typer.Option(
         None, "--model", "-m", help="指定要使用的模型,覆盖角色配置"
     ),
-    llm_type: Optional[str] = typer.Option(
-        None,
-        "-t",
-        "--llm-type",
-        help="使用的LLM类型,可选值:'normal'(普通)或 'thinking'(思考模式),覆盖角色配置",
-    ),
+
    llm_group: Optional[str] = typer.Option(
        None, "-g", "--llm-group", help="使用的模型组,覆盖配置文件中的设置"
    ),
@@ -483,54 +463,37 @@
         PrettyOutput.print("无效的选择", OutputType.ERROR)
         return
 
-    # 获取llm_type,优先使用命令行参数,否则使用角色配置,默认为normal
-    role_llm_type = llm_type or selected_role.get("llm_type", "normal")
+
 
     # 初始化平台和模型
     # 如果提供了platform或model参数,优先使用命令行参数
-    # 否则,如果提供了llm_group,根据llm_type从配置中获取
+    # 否则,如果提供了 llm_group,则从配置中获取
     # 最后才使用角色配置中的platform和model
     if platform:
         platform_name = platform
     elif llm_group:
-        platform_name = (
-            get_thinking_platform_name(llm_group)
-            if role_llm_type == "thinking"
-            else get_normal_platform_name(llm_group)
-        )
+        platform_name = get_normal_platform_name(llm_group)
     else:
         platform_name = selected_role.get("platform")
         if not platform_name:
             # 如果角色配置中没有platform,使用默认配置
-            platform_name = (
-                get_thinking_platform_name()
-                if role_llm_type == "thinking"
-                else get_normal_platform_name()
-            )
+            platform_name = get_normal_platform_name()
 
     if model:
         model_name = model
     elif llm_group:
-        model_name = (
-            get_thinking_model_name(llm_group)
-            if role_llm_type == "thinking"
-            else get_normal_model_name(llm_group)
-        )
+        model_name = get_normal_model_name(llm_group)
     else:
         model_name = selected_role.get("model")
         if not model_name:
             # 如果角色配置中没有model,使用默认配置
-            model_name = (
-                get_thinking_model_name()
-                if role_llm_type == "thinking"
-                else get_normal_model_name()
-            )
+            model_name = get_normal_model_name()
 
     system_prompt = selected_role.get("system_prompt", "")
 
     # 开始对话
     PrettyOutput.print(f"已选择角色: {selected_role['name']}", OutputType.SUCCESS)
-    chat_with_model(platform_name, model_name, system_prompt, role_llm_type)
+    chat_with_model(platform_name, model_name, system_prompt)
 
 
 def main() -> None:
jarvis/jarvis_rag/cli.py CHANGED
@@ -240,6 +240,7 @@
 
     sorted_files = sorted(list(files_to_process))
     total_files = len(sorted_files)
+    loaded_msgs: List[str] = []
 
     for i, file_path in enumerate(sorted_files):
         try:
@@ -249,14 +250,15 @@
             loader = TextLoader(str(file_path), encoding="utf-8")
 
             docs_batch.extend(loader.load())
-            PrettyOutput.print(
-                f"已加载: {file_path} (文件 {i + 1}/{total_files})", OutputType.INFO
-            )
+            loaded_msgs.append(f"已加载: {file_path} (文件 {i + 1}/{total_files})")
         except Exception as e:
             PrettyOutput.print(f"加载失败 {file_path}: {e}", OutputType.WARNING)
 
         # 当批处理已满或是最后一个文件时处理批处理
         if docs_batch and (len(docs_batch) >= batch_size or (i + 1) == total_files):
+            if loaded_msgs:
+                PrettyOutput.print("\n".join(loaded_msgs), OutputType.INFO)
+                loaded_msgs = []
             PrettyOutput.print(
                 f"正在处理批次,包含 {len(docs_batch)} 个文档...", OutputType.INFO
             )
@@ -267,6 +269,10 @@
             )
             docs_batch = []  # 清空批处理
 
+    # 最后统一打印可能残留的“已加载”信息
+    if loaded_msgs:
+        PrettyOutput.print("\n".join(loaded_msgs), OutputType.INFO)
+        loaded_msgs = []
     if total_docs_added == 0:
         PrettyOutput.print("未能成功加载任何文档。", OutputType.ERROR)
         raise typer.Exit(code=1)
@@ -321,12 +327,11 @@
             )
             return
 
-        PrettyOutput.print(
-            f"知识库 '{collection_name}' 中共有 {len(sources)} 个独立文档:",
-            OutputType.INFO,
-        )
+        # 避免在循环中逐条打印,先拼接后统一打印
+        lines = [f"知识库 '{collection_name}' 中共有 {len(sources)} 个独立文档:"]
         for i, source in enumerate(sorted(list(sources)), 1):
-            PrettyOutput.print(f" {i}. {source}", OutputType.INFO)
+            lines.append(f" {i}. {source}")
+        PrettyOutput.print("\n".join(lines), OutputType.INFO)
 
 
     except Exception as e:
@@ -352,6 +357,12 @@
         None, "--db-path", help="向量数据库的路径。覆盖全局配置。"
     ),
     n_results: int = typer.Option(5, "--top-n", help="要检索的文档数量。"),
+    rewrite: bool = typer.Option(
+        True,
+        "--rewrite/--no-rewrite",
+        help="是否对查询进行LLM重写以提升召回,默认开启。",
+        show_default=True,
+    ),
 ):
     """仅从RAG知识库检索文档并打印结果。"""
     try:
@@ -366,6 +377,7 @@
             collection_name=collection_name,
             use_bm25=use_bm25,
             use_rerank=use_rerank,
+            use_query_rewrite=rewrite,
         )
 
         PrettyOutput.print(f"正在为问题检索文档: '{question}'", OutputType.INFO)
@@ -450,11 +462,7 @@
         PrettyOutput.print(f"正在查询: '{question}'", OutputType.INFO)
         answer = pipeline.query(question)
 
-        PrettyOutput.print("答案:", OutputType.INFO)
-        # 我们仍然可以使用 rich.markdown.Markdown,因为 PrettyOutput 底层使用了 rich
-        from jarvis.jarvis_utils.globals import console
-
-        console.print(Markdown(answer))
+        PrettyOutput.print(answer, OutputType.SUCCESS)
 
     except Exception as e:
         PrettyOutput.print(f"发生错误: {e}", OutputType.ERROR)
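The add_documents change above defers the per-file "已加载" messages and flushes them once per processed batch, plus a final flush after the loop. A minimal sketch of that batching pattern follows; print() stands in for PrettyOutput and the real document loading and vector-store ingestion are omitted.

```python
# Sketch of the deferred-logging batch loop used by add_documents.
from typing import List


def ingest(paths: List[str], batch_size: int = 4) -> None:
    docs_batch: List[str] = []
    loaded_msgs: List[str] = []
    total_files = len(paths)
    for i, path in enumerate(paths):
        docs_batch.append(path)  # loader.load() would extend docs_batch here
        loaded_msgs.append(f"已加载: {path} (文件 {i + 1}/{total_files})")
        if docs_batch and (len(docs_batch) >= batch_size or (i + 1) == total_files):
            if loaded_msgs:
                print("\n".join(loaded_msgs))  # one INFO block per batch
                loaded_msgs = []
            print(f"正在处理批次,包含 {len(docs_batch)} 个文档...")
            docs_batch = []  # 清空批处理
    if loaded_msgs:  # 最后统一打印可能残留的“已加载”信息
        print("\n".join(loaded_msgs))


if __name__ == "__main__":
    ingest([f"doc_{n}.md" for n in range(6)], batch_size=4)
```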
jarvis/jarvis_rag/embedding_manager.py CHANGED
@@ -1,6 +1,9 @@
 import torch
+import os
 from typing import List, cast
 from langchain_huggingface import HuggingFaceEmbeddings
+from huggingface_hub import snapshot_download
+
 
 from .cache import EmbeddingCache
 from jarvis.jarvis_utils.output import OutputType, PrettyOutput
@@ -38,12 +41,141 @@ class EmbeddingManager:
         encode_kwargs = {"normalize_embeddings": True}
 
         try:
-            return HuggingFaceEmbeddings(
-                model_name=self.model_name,
-                model_kwargs=model_kwargs,
-                encode_kwargs=encode_kwargs,
-                show_progress=True,
-            )
+            # First try to load model from local cache without any network access
+            try:
+                from sentence_transformers import SentenceTransformer
+                local_dir = None
+                # Prefer explicit local dir via env or direct path
+
+                if os.path.isdir(self.model_name):
+                    return HuggingFaceEmbeddings(
+                        model_name=self.model_name,
+                        model_kwargs=model_kwargs,
+                        encode_kwargs=encode_kwargs,
+                        show_progress=False,
+                    )
+
+                # Try common local cache directories for sentence-transformers and HF hub
+                try:
+                    home = os.path.expanduser("~")
+                    st_home = os.path.join(home, ".cache", "sentence_transformers")
+                    torch_st_home = os.path.join(home, ".cache", "torch", "sentence_transformers")
+                    # Build common name variants found in local caches
+                    org, name = (
+                        self.model_name.split("/", 1)
+                        if "/" in self.model_name
+                        else ("", self.model_name)
+                    )
+                    san1 = self.model_name.replace("/", "_")
+                    san2 = self.model_name.replace("/", "__")
+                    san3 = self.model_name.replace("/", "--")
+                    # include plain 'name' for caches that drop org prefix
+                    name_variants = list(dict.fromkeys([self.model_name, san1, san2, san3, name]))
+                    candidates = []
+                    for base in [st_home, torch_st_home]:
+                        for nv in name_variants:
+                            p = os.path.join(base, nv)
+                            if os.path.isdir(p):
+                                candidates.append(p)
+                        # Fuzzy scan cache directory for entries that include variants
+                        try:
+                            for entry in os.listdir(base):
+                                ep = os.path.join(base, entry)
+                                if not os.path.isdir(ep):
+                                    continue
+                                if (
+                                    (org and entry.startswith(f"{org}__") and name in entry)
+                                    or (san1 in entry)
+                                    or (name in entry)
+                                ):
+                                    candidates.append(ep)
+                        except Exception:
+                            pass
+
+                    # Hugging Face Hub cache snapshots
+                    hf_cache = os.path.join(home, ".cache", "huggingface", "hub")
+                    if "/" in self.model_name:
+                        org, name = self.model_name.split("/", 1)
+                        models_dir = os.path.join(hf_cache, f"models--{org}--{name}", "snapshots")
+                        if os.path.isdir(models_dir):
+                            try:
+                                snaps = sorted(
+                                    [os.path.join(models_dir, d) for d in os.listdir(models_dir)],
+                                    key=lambda p: os.path.getmtime(p),
+                                    reverse=True,
+                                )
+                            except Exception:
+                                snaps = [os.path.join(models_dir, d) for d in os.listdir(models_dir)]
+                            for sp in snaps:
+                                if os.path.isdir(sp):
+                                    candidates.append(sp)
+                                    break
+
+                    for cand in candidates:
+                        try:
+                            return HuggingFaceEmbeddings(
+                                model_name=cand,
+                                model_kwargs=model_kwargs,
+                                encode_kwargs=encode_kwargs,
+                                show_progress=False,
+                            )
+                        except Exception:
+                            continue
+                except Exception:
+                    pass
+
+                try:
+                    # Try resolve local cached directory; do not hit network
+                    local_dir = snapshot_download(repo_id=self.model_name, local_files_only=True)
+                except Exception:
+                    local_dir = None
+
+                if local_dir:
+                    return HuggingFaceEmbeddings(
+                        model_name=local_dir,
+                        model_kwargs=model_kwargs,
+                        encode_kwargs=encode_kwargs,
+                        show_progress=False,
+                    )
+
+
+
+                # Fall back to remote download if local cache not found and not offline
+                return HuggingFaceEmbeddings(
+                    model_name=self.model_name,
+                    model_kwargs=model_kwargs,
+                    encode_kwargs=encode_kwargs,
+                    show_progress=True,
+                )
+            except Exception as _e:
+                # 如果已检测到本地候选路径(直接目录 / 本地缓存快照),则视为本地加载失败,
+                # 为避免在用户期望“本地优先不联网”的情况下触发联网,直接抛错并给出修复建议。
+                had_local_candidate = False
+                try:
+                    had_local_candidate = (
+                        os.path.isdir(self.model_name)
+                        # 如果上面 snapshot_download 命中了本地缓存,会将 local_dir 设为非 None
+                        or (locals().get("local_dir") is not None)
+                    )
+                except Exception:
+                    pass
+
+                if had_local_candidate:
+                    PrettyOutput.print(
+                        "检测到本地模型路径但加载失败。为避免触发网络访问,已中止远程回退。\n"
+                        "请确认本地目录包含完整的 Transformers/Tokenizer 文件(如 config.json、model.safetensors、tokenizer.json/merges.txt 等),\n"
+                        "或在配置中将 embedding_model 设置为该本地目录,或将模型放置到默认的 Hugging Face 缓存目录(例如 ~/.cache/huggingface/hub)。",
+                        OutputType.ERROR,
+                    )
+                    raise
+
+                # 未发现任何本地候选,则保持原有行为:回退至远程下载
+                return HuggingFaceEmbeddings(
+                    model_name=self.model_name,
+                    model_kwargs=model_kwargs,
+                    encode_kwargs=encode_kwargs,
+                    show_progress=True,
+                )
         except Exception as e:
             PrettyOutput.print(
                 f"加载嵌入模型 '{self.model_name}' 时出错: {e}", OutputType.ERROR
jarvis/jarvis_rag/llm_interface.py CHANGED
@@ -47,13 +47,11 @@ class ToolAgent_LLM(LLMInterface):
         # 为代理提供一个通用的系统提示
         self.system_prompt = "You are a helpful assistant. Please answer the user's question based on the provided context. You can use tools to find more information if needed."
         self.summary_prompt = """
-<report>
 请为本次问答任务生成一个总结报告,包含以下内容:
 
 1. **原始问题**: 重述用户最开始提出的问题。
 2. **关键信息来源**: 总结你是基于哪些关键信息或文件得出的结论。
 3. **最终答案**: 给出最终的、精炼的回答。
-</report>
 """
 
     def generate(self, prompt: str, **kwargs) -> str:
jarvis/jarvis_rag/rag_pipeline.py CHANGED
@@ -34,6 +34,7 @@ class JarvisRAGPipeline:
         collection_name: str = "jarvis_rag_collection",
         use_bm25: bool = True,
         use_rerank: bool = True,
+        use_query_rewrite: bool = True,
     ):
         """
         初始化RAG管道。
@@ -69,6 +70,8 @@
         self.collection_name = collection_name
         self.use_bm25 = use_bm25
         self.use_rerank = use_rerank
+        # 查询重写开关(默认开启,可由CLI控制)
+        self.use_query_rewrite = use_query_rewrite
 
         # 延迟加载的组件
         self._embedding_manager: Optional[EmbeddingManager] = None
@@ -161,14 +164,15 @@
         if not changed and not deleted:
             return
         # 打印摘要
-        PrettyOutput.print(
-            f"检测到索引可能不一致:变更 {len(changed)} 个,删除 {len(deleted)} 个。",
-            OutputType.WARNING,
-        )
-        for p in changed[:3] if changed else []:
-            PrettyOutput.print(f" 变更: {p}", OutputType.WARNING)
-        for p in deleted[:3] if deleted else []:
-            PrettyOutput.print(f" 删除: {p}", OutputType.WARNING)
+        # 先拼接列表信息再统一打印,避免循环中逐条打印
+        lines = [
+            f"检测到索引可能不一致:变更 {len(changed)} 个,删除 {len(deleted)} 个。"
+        ]
+        if changed:
+            lines.extend([f" 变更: {p}" for p in changed[:3]])
+        if deleted:
+            lines.extend([f" 删除: {p}" for p in deleted[:3]])
+        PrettyOutput.print("\n".join(lines), OutputType.WARNING)
         # 询问用户
         if get_yes_no(
             "检测到索引变更,是否现在更新索引后再开始检索?", default=True
@@ -228,13 +232,23 @@
         """
         # 0. 检测索引变更并可选更新(在重写query之前)
         self._pre_search_update_index_if_needed()
-        # 1. 将原始查询重写为多个查询
-        rewritten_queries = self._get_query_rewriter().rewrite(query_text)
+        # 1. 将原始查询重写为多个查询(可配置)
+        if self.use_query_rewrite:
+            rewritten_queries = self._get_query_rewriter().rewrite(query_text)
+        else:
+            PrettyOutput.print(
+                "已关闭查询重写,将直接使用原始查询进行检索。",
+                OutputType.INFO,
+            )
+            rewritten_queries = [query_text]
 
         # 2. 为每个重写的查询检索初始候选文档
+        PrettyOutput.print(
+            "将为以下查询变体进行混合检索:\n" + "\n".join([f" - {q}" for q in rewritten_queries]),
+            OutputType.INFO,
+        )
         all_candidate_docs = []
         for q in rewritten_queries:
-            PrettyOutput.print(f"正在为查询变体 '{q}' 进行混合检索...", OutputType.INFO)
             candidates = self._get_retriever().retrieve(
                 q, n_results=n_results * 2, use_bm25=self.use_bm25
             )
@@ -273,9 +287,9 @@
             )
         )
         if sources:
-            PrettyOutput.print("根据以下文档回答:", OutputType.INFO)
-            for source in sources:
-                PrettyOutput.print(f" - {source}", OutputType.INFO)
+            # 合并来源列表后一次性打印,避免多次加框
+            lines = ["根据以下文档回答:"] + [f" - {source}" for source in sources]
+            PrettyOutput.print("\n".join(lines), OutputType.INFO)
 
         # 4. 创建最终提示并生成答案
         # 我们使用原始的query_text作为给LLM的最终提示
@@ -299,13 +313,23 @@
         """
         # 0. 检测索引变更并可选更新(在重写query之前)
         self._pre_search_update_index_if_needed()
-        # 1. 重写查询
-        rewritten_queries = self._get_query_rewriter().rewrite(query_text)
+        # 1. 重写查询(可配置)
+        if self.use_query_rewrite:
+            rewritten_queries = self._get_query_rewriter().rewrite(query_text)
+        else:
+            PrettyOutput.print(
+                "已关闭查询重写,将直接使用原始查询进行检索。",
+                OutputType.INFO,
+            )
+            rewritten_queries = [query_text]
 
         # 2. 检索候选文档
+        PrettyOutput.print(
+            "将为以下查询变体进行混合检索:\n" + "\n".join([f" - {q}" for q in rewritten_queries]),
+            OutputType.INFO,
+        )
         all_candidate_docs = []
         for q in rewritten_queries:
-            PrettyOutput.print(f"正在为查询变体 '{q}' 进行混合检索...", OutputType.INFO)
             candidates = self._get_retriever().retrieve(
                 q, n_results=n_results * 2, use_bm25=self.use_bm25
             )
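The rag_pipeline.py change above gates LLM query rewriting behind the new use_query_rewrite flag (wired to the CLI's --rewrite/--no-rewrite option); when it is off, retrieval runs over the original query only. A minimal sketch of that toggle follows; `rewrite` is a hypothetical stand-in for the pipeline's LLM-based query rewriter.

```python
# Sketch of the use_query_rewrite toggle: expand the query through the rewriter
# when enabled, otherwise retrieve with the raw query as the single variant.
from typing import Callable, List


def expand_queries(
    query_text: str,
    use_query_rewrite: bool,
    rewrite: Callable[[str], List[str]],
) -> List[str]:
    if use_query_rewrite:
        return rewrite(query_text)
    return [query_text]


if __name__ == "__main__":
    fake_rewrite = lambda q: [q, f"{q}(同义改写)"]
    print(expand_queries("如何配置 embedding_model?", True, fake_rewrite))
    print(expand_queries("如何配置 embedding_model?", False, fake_rewrite))
```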
jarvis/jarvis_rag/reranker.py CHANGED
@@ -1,9 +1,11 @@
 from typing import List
+import os
 
 from langchain.docstore.document import Document
 from sentence_transformers.cross_encoder import (  # type: ignore
     CrossEncoder,
 )
+from huggingface_hub import snapshot_download
 from jarvis.jarvis_utils.output import OutputType, PrettyOutput
 
 
@@ -21,8 +23,28 @@
             model_name (str): 要使用的Cross-Encoder模型的名称。
         """
         PrettyOutput.print(f"正在初始化重排模型: {model_name}...", OutputType.INFO)
-        self.model = CrossEncoder(model_name)
-        PrettyOutput.print("重排模型初始化成功。", OutputType.SUCCESS)
+        try:
+            local_dir = None
+
+            if os.path.isdir(model_name):
+                self.model = CrossEncoder(model_name)
+                PrettyOutput.print("重排模型初始化成功。", OutputType.SUCCESS)
+                return
+            try:
+                # Prefer local cache; avoid any network access
+                local_dir = snapshot_download(repo_id=model_name, local_files_only=True)
+            except Exception:
+                local_dir = None
+
+            if local_dir:
+                self.model = CrossEncoder(local_dir)
+            else:
+                self.model = CrossEncoder(model_name)
+
+            PrettyOutput.print("重排模型初始化成功。", OutputType.SUCCESS)
+        except Exception as e:
+            PrettyOutput.print(f"初始化重排模型失败: {e}", OutputType.ERROR)
+            raise
 
     def rerank(
         self, query: str, documents: List[Document], top_n: int = 5
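For context on what the Reranker wraps: CrossEncoder.predict scores (query, document) pairs, and reranking keeps the highest-scoring documents. A small sketch using the standard sentence-transformers API follows; the model id and helper function are illustrative, not the package's actual implementation.

```python
# Sketch of cross-encoder reranking: score (query, text) pairs and return the
# top_n texts sorted by descending score. Loading the model may download it
# unless it is already cached locally.
from typing import List, Tuple

from sentence_transformers.cross_encoder import CrossEncoder


def rerank_texts(
    model: CrossEncoder, query: str, texts: List[str], top_n: int = 5
) -> List[Tuple[str, float]]:
    scores = model.predict([(query, t) for t in texts])
    ranked = sorted(zip(texts, scores), key=lambda pair: pair[1], reverse=True)
    return [(t, float(s)) for t, s in ranked[:top_n]]


if __name__ == "__main__":
    model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    print(rerank_texts(model, "什么是RAG?", ["RAG 是检索增强生成。", "今天天气不错。"]))
```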