jarvis-ai-assistant 0.7.16__py3-none-any.whl → 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jarvis/__init__.py +1 -1
- jarvis/jarvis_agent/__init__.py +567 -222
- jarvis/jarvis_agent/agent_manager.py +19 -12
- jarvis/jarvis_agent/builtin_input_handler.py +79 -11
- jarvis/jarvis_agent/config_editor.py +7 -2
- jarvis/jarvis_agent/event_bus.py +24 -13
- jarvis/jarvis_agent/events.py +19 -1
- jarvis/jarvis_agent/file_context_handler.py +67 -64
- jarvis/jarvis_agent/file_methodology_manager.py +38 -24
- jarvis/jarvis_agent/jarvis.py +186 -114
- jarvis/jarvis_agent/language_extractors/__init__.py +8 -1
- jarvis/jarvis_agent/language_extractors/c_extractor.py +7 -4
- jarvis/jarvis_agent/language_extractors/cpp_extractor.py +9 -4
- jarvis/jarvis_agent/language_extractors/go_extractor.py +7 -4
- jarvis/jarvis_agent/language_extractors/java_extractor.py +27 -20
- jarvis/jarvis_agent/language_extractors/javascript_extractor.py +22 -17
- jarvis/jarvis_agent/language_extractors/python_extractor.py +7 -4
- jarvis/jarvis_agent/language_extractors/rust_extractor.py +7 -4
- jarvis/jarvis_agent/language_extractors/typescript_extractor.py +22 -17
- jarvis/jarvis_agent/language_support_info.py +250 -219
- jarvis/jarvis_agent/main.py +19 -23
- jarvis/jarvis_agent/memory_manager.py +9 -6
- jarvis/jarvis_agent/methodology_share_manager.py +21 -15
- jarvis/jarvis_agent/output_handler.py +4 -2
- jarvis/jarvis_agent/prompt_builder.py +7 -6
- jarvis/jarvis_agent/prompt_manager.py +113 -8
- jarvis/jarvis_agent/prompts.py +317 -85
- jarvis/jarvis_agent/protocols.py +5 -2
- jarvis/jarvis_agent/run_loop.py +192 -32
- jarvis/jarvis_agent/session_manager.py +7 -3
- jarvis/jarvis_agent/share_manager.py +23 -13
- jarvis/jarvis_agent/shell_input_handler.py +12 -8
- jarvis/jarvis_agent/stdio_redirect.py +25 -26
- jarvis/jarvis_agent/task_analyzer.py +29 -23
- jarvis/jarvis_agent/task_list.py +869 -0
- jarvis/jarvis_agent/task_manager.py +26 -23
- jarvis/jarvis_agent/tool_executor.py +6 -5
- jarvis/jarvis_agent/tool_share_manager.py +24 -14
- jarvis/jarvis_agent/user_interaction.py +3 -3
- jarvis/jarvis_agent/utils.py +9 -1
- jarvis/jarvis_agent/web_bridge.py +37 -17
- jarvis/jarvis_agent/web_output_sink.py +5 -2
- jarvis/jarvis_agent/web_server.py +165 -36
- jarvis/jarvis_c2rust/__init__.py +1 -1
- jarvis/jarvis_c2rust/cli.py +260 -141
- jarvis/jarvis_c2rust/collector.py +37 -18
- jarvis/jarvis_c2rust/constants.py +60 -0
- jarvis/jarvis_c2rust/library_replacer.py +242 -1010
- jarvis/jarvis_c2rust/library_replacer_checkpoint.py +133 -0
- jarvis/jarvis_c2rust/library_replacer_llm.py +287 -0
- jarvis/jarvis_c2rust/library_replacer_loader.py +191 -0
- jarvis/jarvis_c2rust/library_replacer_output.py +134 -0
- jarvis/jarvis_c2rust/library_replacer_prompts.py +124 -0
- jarvis/jarvis_c2rust/library_replacer_utils.py +188 -0
- jarvis/jarvis_c2rust/llm_module_agent.py +98 -1044
- jarvis/jarvis_c2rust/llm_module_agent_apply.py +170 -0
- jarvis/jarvis_c2rust/llm_module_agent_executor.py +288 -0
- jarvis/jarvis_c2rust/llm_module_agent_loader.py +170 -0
- jarvis/jarvis_c2rust/llm_module_agent_prompts.py +268 -0
- jarvis/jarvis_c2rust/llm_module_agent_types.py +57 -0
- jarvis/jarvis_c2rust/llm_module_agent_utils.py +150 -0
- jarvis/jarvis_c2rust/llm_module_agent_validator.py +119 -0
- jarvis/jarvis_c2rust/loaders.py +28 -10
- jarvis/jarvis_c2rust/models.py +5 -2
- jarvis/jarvis_c2rust/optimizer.py +192 -1974
- jarvis/jarvis_c2rust/optimizer_build_fix.py +286 -0
- jarvis/jarvis_c2rust/optimizer_clippy.py +766 -0
- jarvis/jarvis_c2rust/optimizer_config.py +49 -0
- jarvis/jarvis_c2rust/optimizer_docs.py +183 -0
- jarvis/jarvis_c2rust/optimizer_options.py +48 -0
- jarvis/jarvis_c2rust/optimizer_progress.py +469 -0
- jarvis/jarvis_c2rust/optimizer_report.py +52 -0
- jarvis/jarvis_c2rust/optimizer_unsafe.py +309 -0
- jarvis/jarvis_c2rust/optimizer_utils.py +469 -0
- jarvis/jarvis_c2rust/optimizer_visibility.py +185 -0
- jarvis/jarvis_c2rust/scanner.py +229 -166
- jarvis/jarvis_c2rust/transpiler.py +531 -2732
- jarvis/jarvis_c2rust/transpiler_agents.py +503 -0
- jarvis/jarvis_c2rust/transpiler_build.py +1294 -0
- jarvis/jarvis_c2rust/transpiler_codegen.py +204 -0
- jarvis/jarvis_c2rust/transpiler_compile.py +146 -0
- jarvis/jarvis_c2rust/transpiler_config.py +178 -0
- jarvis/jarvis_c2rust/transpiler_context.py +122 -0
- jarvis/jarvis_c2rust/transpiler_executor.py +516 -0
- jarvis/jarvis_c2rust/transpiler_generation.py +278 -0
- jarvis/jarvis_c2rust/transpiler_git.py +163 -0
- jarvis/jarvis_c2rust/transpiler_mod_utils.py +225 -0
- jarvis/jarvis_c2rust/transpiler_modules.py +336 -0
- jarvis/jarvis_c2rust/transpiler_planning.py +394 -0
- jarvis/jarvis_c2rust/transpiler_review.py +1196 -0
- jarvis/jarvis_c2rust/transpiler_symbols.py +176 -0
- jarvis/jarvis_c2rust/utils.py +269 -79
- jarvis/jarvis_code_agent/after_change.py +233 -0
- jarvis/jarvis_code_agent/build_validation_config.py +37 -30
- jarvis/jarvis_code_agent/builtin_rules.py +68 -0
- jarvis/jarvis_code_agent/code_agent.py +976 -1517
- jarvis/jarvis_code_agent/code_agent_build.py +227 -0
- jarvis/jarvis_code_agent/code_agent_diff.py +246 -0
- jarvis/jarvis_code_agent/code_agent_git.py +525 -0
- jarvis/jarvis_code_agent/code_agent_impact.py +177 -0
- jarvis/jarvis_code_agent/code_agent_lint.py +283 -0
- jarvis/jarvis_code_agent/code_agent_llm.py +159 -0
- jarvis/jarvis_code_agent/code_agent_postprocess.py +105 -0
- jarvis/jarvis_code_agent/code_agent_prompts.py +46 -0
- jarvis/jarvis_code_agent/code_agent_rules.py +305 -0
- jarvis/jarvis_code_agent/code_analyzer/__init__.py +52 -48
- jarvis/jarvis_code_agent/code_analyzer/base_language.py +12 -10
- jarvis/jarvis_code_agent/code_analyzer/build_validator/__init__.py +12 -11
- jarvis/jarvis_code_agent/code_analyzer/build_validator/base.py +16 -12
- jarvis/jarvis_code_agent/code_analyzer/build_validator/cmake.py +26 -17
- jarvis/jarvis_code_agent/code_analyzer/build_validator/detector.py +558 -104
- jarvis/jarvis_code_agent/code_analyzer/build_validator/fallback.py +27 -16
- jarvis/jarvis_code_agent/code_analyzer/build_validator/go.py +22 -18
- jarvis/jarvis_code_agent/code_analyzer/build_validator/java_gradle.py +21 -16
- jarvis/jarvis_code_agent/code_analyzer/build_validator/java_maven.py +20 -16
- jarvis/jarvis_code_agent/code_analyzer/build_validator/makefile.py +27 -16
- jarvis/jarvis_code_agent/code_analyzer/build_validator/nodejs.py +47 -23
- jarvis/jarvis_code_agent/code_analyzer/build_validator/python.py +71 -37
- jarvis/jarvis_code_agent/code_analyzer/build_validator/rust.py +162 -35
- jarvis/jarvis_code_agent/code_analyzer/build_validator/validator.py +111 -57
- jarvis/jarvis_code_agent/code_analyzer/build_validator.py +18 -12
- jarvis/jarvis_code_agent/code_analyzer/context_manager.py +185 -183
- jarvis/jarvis_code_agent/code_analyzer/context_recommender.py +2 -1
- jarvis/jarvis_code_agent/code_analyzer/dependency_analyzer.py +24 -15
- jarvis/jarvis_code_agent/code_analyzer/file_ignore.py +227 -141
- jarvis/jarvis_code_agent/code_analyzer/impact_analyzer.py +321 -247
- jarvis/jarvis_code_agent/code_analyzer/language_registry.py +37 -29
- jarvis/jarvis_code_agent/code_analyzer/language_support.py +21 -13
- jarvis/jarvis_code_agent/code_analyzer/languages/__init__.py +15 -9
- jarvis/jarvis_code_agent/code_analyzer/languages/c_cpp_language.py +75 -45
- jarvis/jarvis_code_agent/code_analyzer/languages/go_language.py +87 -52
- jarvis/jarvis_code_agent/code_analyzer/languages/java_language.py +84 -51
- jarvis/jarvis_code_agent/code_analyzer/languages/javascript_language.py +94 -64
- jarvis/jarvis_code_agent/code_analyzer/languages/python_language.py +109 -71
- jarvis/jarvis_code_agent/code_analyzer/languages/rust_language.py +97 -63
- jarvis/jarvis_code_agent/code_analyzer/languages/typescript_language.py +103 -69
- jarvis/jarvis_code_agent/code_analyzer/llm_context_recommender.py +271 -268
- jarvis/jarvis_code_agent/code_analyzer/symbol_extractor.py +76 -64
- jarvis/jarvis_code_agent/code_analyzer/tree_sitter_extractor.py +92 -19
- jarvis/jarvis_code_agent/diff_visualizer.py +998 -0
- jarvis/jarvis_code_agent/lint.py +223 -524
- jarvis/jarvis_code_agent/rule_share_manager.py +158 -0
- jarvis/jarvis_code_agent/rules/clean_code.md +144 -0
- jarvis/jarvis_code_agent/rules/code_review.md +115 -0
- jarvis/jarvis_code_agent/rules/documentation.md +165 -0
- jarvis/jarvis_code_agent/rules/generate_rules.md +52 -0
- jarvis/jarvis_code_agent/rules/performance.md +158 -0
- jarvis/jarvis_code_agent/rules/refactoring.md +139 -0
- jarvis/jarvis_code_agent/rules/security.md +160 -0
- jarvis/jarvis_code_agent/rules/tdd.md +78 -0
- jarvis/jarvis_code_agent/test_rules/cpp_test.md +118 -0
- jarvis/jarvis_code_agent/test_rules/go_test.md +98 -0
- jarvis/jarvis_code_agent/test_rules/java_test.md +99 -0
- jarvis/jarvis_code_agent/test_rules/javascript_test.md +113 -0
- jarvis/jarvis_code_agent/test_rules/php_test.md +117 -0
- jarvis/jarvis_code_agent/test_rules/python_test.md +91 -0
- jarvis/jarvis_code_agent/test_rules/ruby_test.md +102 -0
- jarvis/jarvis_code_agent/test_rules/rust_test.md +86 -0
- jarvis/jarvis_code_agent/utils.py +36 -26
- jarvis/jarvis_code_analysis/checklists/loader.py +21 -21
- jarvis/jarvis_code_analysis/code_review.py +64 -33
- jarvis/jarvis_data/config_schema.json +285 -192
- jarvis/jarvis_git_squash/main.py +8 -6
- jarvis/jarvis_git_utils/git_commiter.py +53 -76
- jarvis/jarvis_mcp/__init__.py +5 -2
- jarvis/jarvis_mcp/sse_mcp_client.py +40 -30
- jarvis/jarvis_mcp/stdio_mcp_client.py +27 -19
- jarvis/jarvis_mcp/streamable_mcp_client.py +35 -26
- jarvis/jarvis_memory_organizer/memory_organizer.py +78 -55
- jarvis/jarvis_methodology/main.py +48 -39
- jarvis/jarvis_multi_agent/__init__.py +56 -23
- jarvis/jarvis_multi_agent/main.py +15 -18
- jarvis/jarvis_platform/base.py +179 -111
- jarvis/jarvis_platform/human.py +27 -16
- jarvis/jarvis_platform/kimi.py +52 -45
- jarvis/jarvis_platform/openai.py +101 -40
- jarvis/jarvis_platform/registry.py +51 -33
- jarvis/jarvis_platform/tongyi.py +68 -38
- jarvis/jarvis_platform/yuanbao.py +59 -43
- jarvis/jarvis_platform_manager/main.py +68 -76
- jarvis/jarvis_platform_manager/service.py +24 -14
- jarvis/jarvis_rag/README_CONFIG.md +314 -0
- jarvis/jarvis_rag/README_DYNAMIC_LOADING.md +311 -0
- jarvis/jarvis_rag/README_ONLINE_MODELS.md +230 -0
- jarvis/jarvis_rag/__init__.py +57 -4
- jarvis/jarvis_rag/cache.py +3 -1
- jarvis/jarvis_rag/cli.py +48 -68
- jarvis/jarvis_rag/embedding_interface.py +39 -0
- jarvis/jarvis_rag/embedding_manager.py +7 -230
- jarvis/jarvis_rag/embeddings/__init__.py +41 -0
- jarvis/jarvis_rag/embeddings/base.py +114 -0
- jarvis/jarvis_rag/embeddings/cohere.py +66 -0
- jarvis/jarvis_rag/embeddings/edgefn.py +117 -0
- jarvis/jarvis_rag/embeddings/local.py +260 -0
- jarvis/jarvis_rag/embeddings/openai.py +62 -0
- jarvis/jarvis_rag/embeddings/registry.py +293 -0
- jarvis/jarvis_rag/llm_interface.py +8 -6
- jarvis/jarvis_rag/query_rewriter.py +8 -9
- jarvis/jarvis_rag/rag_pipeline.py +61 -52
- jarvis/jarvis_rag/reranker.py +7 -75
- jarvis/jarvis_rag/reranker_interface.py +32 -0
- jarvis/jarvis_rag/rerankers/__init__.py +41 -0
- jarvis/jarvis_rag/rerankers/base.py +109 -0
- jarvis/jarvis_rag/rerankers/cohere.py +67 -0
- jarvis/jarvis_rag/rerankers/edgefn.py +140 -0
- jarvis/jarvis_rag/rerankers/jina.py +79 -0
- jarvis/jarvis_rag/rerankers/local.py +89 -0
- jarvis/jarvis_rag/rerankers/registry.py +293 -0
- jarvis/jarvis_rag/retriever.py +58 -43
- jarvis/jarvis_sec/__init__.py +66 -141
- jarvis/jarvis_sec/agents.py +21 -17
- jarvis/jarvis_sec/analysis.py +80 -33
- jarvis/jarvis_sec/checkers/__init__.py +7 -13
- jarvis/jarvis_sec/checkers/c_checker.py +356 -164
- jarvis/jarvis_sec/checkers/rust_checker.py +47 -29
- jarvis/jarvis_sec/cli.py +43 -21
- jarvis/jarvis_sec/clustering.py +430 -272
- jarvis/jarvis_sec/file_manager.py +99 -55
- jarvis/jarvis_sec/parsers.py +9 -6
- jarvis/jarvis_sec/prompts.py +4 -3
- jarvis/jarvis_sec/report.py +44 -22
- jarvis/jarvis_sec/review.py +180 -107
- jarvis/jarvis_sec/status.py +50 -41
- jarvis/jarvis_sec/types.py +3 -0
- jarvis/jarvis_sec/utils.py +160 -83
- jarvis/jarvis_sec/verification.py +411 -181
- jarvis/jarvis_sec/workflow.py +132 -21
- jarvis/jarvis_smart_shell/main.py +28 -41
- jarvis/jarvis_stats/cli.py +14 -12
- jarvis/jarvis_stats/stats.py +28 -19
- jarvis/jarvis_stats/storage.py +14 -8
- jarvis/jarvis_stats/visualizer.py +12 -7
- jarvis/jarvis_tools/base.py +5 -2
- jarvis/jarvis_tools/clear_memory.py +13 -9
- jarvis/jarvis_tools/cli/main.py +23 -18
- jarvis/jarvis_tools/edit_file.py +572 -873
- jarvis/jarvis_tools/execute_script.py +10 -7
- jarvis/jarvis_tools/file_analyzer.py +7 -8
- jarvis/jarvis_tools/meta_agent.py +287 -0
- jarvis/jarvis_tools/methodology.py +5 -3
- jarvis/jarvis_tools/read_code.py +305 -1438
- jarvis/jarvis_tools/read_symbols.py +50 -17
- jarvis/jarvis_tools/read_webpage.py +19 -18
- jarvis/jarvis_tools/registry.py +435 -156
- jarvis/jarvis_tools/retrieve_memory.py +16 -11
- jarvis/jarvis_tools/save_memory.py +8 -6
- jarvis/jarvis_tools/search_web.py +31 -31
- jarvis/jarvis_tools/sub_agent.py +32 -28
- jarvis/jarvis_tools/sub_code_agent.py +44 -60
- jarvis/jarvis_tools/task_list_manager.py +1811 -0
- jarvis/jarvis_tools/virtual_tty.py +29 -19
- jarvis/jarvis_utils/__init__.py +4 -0
- jarvis/jarvis_utils/builtin_replace_map.py +2 -1
- jarvis/jarvis_utils/clipboard.py +9 -8
- jarvis/jarvis_utils/collections.py +331 -0
- jarvis/jarvis_utils/config.py +699 -194
- jarvis/jarvis_utils/dialogue_recorder.py +294 -0
- jarvis/jarvis_utils/embedding.py +6 -3
- jarvis/jarvis_utils/file_processors.py +7 -1
- jarvis/jarvis_utils/fzf.py +9 -3
- jarvis/jarvis_utils/git_utils.py +71 -42
- jarvis/jarvis_utils/globals.py +116 -32
- jarvis/jarvis_utils/http.py +6 -2
- jarvis/jarvis_utils/input.py +318 -83
- jarvis/jarvis_utils/jsonnet_compat.py +119 -104
- jarvis/jarvis_utils/methodology.py +37 -28
- jarvis/jarvis_utils/output.py +201 -44
- jarvis/jarvis_utils/utils.py +986 -628
- {jarvis_ai_assistant-0.7.16.dist-info → jarvis_ai_assistant-1.0.2.dist-info}/METADATA +49 -33
- jarvis_ai_assistant-1.0.2.dist-info/RECORD +304 -0
- jarvis/jarvis_code_agent/code_analyzer/structured_code.py +0 -556
- jarvis/jarvis_tools/generate_new_tool.py +0 -205
- jarvis/jarvis_tools/lsp_client.py +0 -1552
- jarvis/jarvis_tools/rewrite_file.py +0 -105
- jarvis_ai_assistant-0.7.16.dist-info/RECORD +0 -218
- {jarvis_ai_assistant-0.7.16.dist-info → jarvis_ai_assistant-1.0.2.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.7.16.dist-info → jarvis_ai_assistant-1.0.2.dist-info}/entry_points.txt +0 -0
- {jarvis_ai_assistant-0.7.16.dist-info → jarvis_ai_assistant-1.0.2.dist-info}/licenses/LICENSE +0 -0
- {jarvis_ai_assistant-0.7.16.dist-info → jarvis_ai_assistant-1.0.2.dist-info}/top_level.txt +0 -0
|
@@ -1,4 +1,7 @@
|
|
|
1
1
|
from typing import List
|
|
2
|
+
|
|
3
|
+
from jarvis.jarvis_utils.output import PrettyOutput
|
|
4
|
+
|
|
2
5
|
from .llm_interface import LLMInterface
|
|
3
6
|
|
|
4
7
|
|
|
@@ -56,9 +59,7 @@ English version of the query
|
|
|
56
59
|
一个经过重写、搜索优化的查询列表。
|
|
57
60
|
"""
|
|
58
61
|
prompt = self.rewrite_prompt_template.format(query=query)
|
|
59
|
-
|
|
60
|
-
"ℹ️ 正在将原始查询重写为多个搜索查询..."
|
|
61
|
-
)
|
|
62
|
+
PrettyOutput.auto_print("ℹ️ 正在将原始查询重写为多个搜索查询...")
|
|
62
63
|
|
|
63
64
|
import re
|
|
64
65
|
|
|
@@ -77,18 +78,18 @@ English version of the query
|
|
|
77
78
|
rewritten_queries = [
|
|
78
79
|
line.strip() for line in content.split("\n") if line.strip()
|
|
79
80
|
]
|
|
80
|
-
|
|
81
|
+
PrettyOutput.auto_print(
|
|
81
82
|
f"✅ 成功从LLM响应中提取到内容 (尝试 {attempts}/{max_retries})。"
|
|
82
83
|
)
|
|
83
84
|
break # 提取成功,退出循环
|
|
84
85
|
else:
|
|
85
|
-
|
|
86
|
+
PrettyOutput.auto_print(
|
|
86
87
|
f"⚠️ 未能从LLM响应中提取内容。正在重试... ({attempts}/{max_retries})"
|
|
87
88
|
)
|
|
88
89
|
|
|
89
90
|
# 如果所有重试都失败,则跳过重写步骤
|
|
90
91
|
if not rewritten_queries:
|
|
91
|
-
|
|
92
|
+
PrettyOutput.auto_print(
|
|
92
93
|
"❌ 所有重试均失败。跳过查询重写,将仅使用原始查询。"
|
|
93
94
|
)
|
|
94
95
|
|
|
@@ -96,7 +97,5 @@ English version of the query
|
|
|
96
97
|
if query not in rewritten_queries:
|
|
97
98
|
rewritten_queries.insert(0, query)
|
|
98
99
|
|
|
99
|
-
|
|
100
|
-
f"✅ 生成了 {len(rewritten_queries)} 个查询变体。"
|
|
101
|
-
)
|
|
100
|
+
PrettyOutput.auto_print(f"✅ 生成了 {len(rewritten_queries)} 个查询变体。")
|
|
102
101
|
return rewritten_queries
|
|
@@ -1,20 +1,27 @@
|
|
|
1
1
|
import os
|
|
2
|
-
from typing import List
|
|
2
|
+
from typing import List
|
|
3
|
+
from typing import Optional
|
|
3
4
|
|
|
4
5
|
from langchain.docstore.document import Document
|
|
5
6
|
|
|
6
|
-
from .
|
|
7
|
-
from .
|
|
7
|
+
from jarvis.jarvis_utils.config import get_rag_embedding_cache_path
|
|
8
|
+
from jarvis.jarvis_utils.config import get_rag_embedding_model
|
|
9
|
+
from jarvis.jarvis_utils.config import get_rag_rerank_model
|
|
10
|
+
from jarvis.jarvis_utils.config import get_rag_vector_db_path
|
|
11
|
+
from jarvis.jarvis_utils.output import PrettyOutput
|
|
12
|
+
from jarvis.jarvis_utils.utils import get_yes_no
|
|
13
|
+
|
|
14
|
+
from .embedding_interface import EmbeddingInterface
|
|
15
|
+
from .embeddings import EmbeddingManager
|
|
16
|
+
from .embeddings import EmbeddingRegistry
|
|
17
|
+
from .llm_interface import JarvisPlatform_LLM
|
|
18
|
+
from .llm_interface import LLMInterface
|
|
19
|
+
from .llm_interface import ToolAgent_LLM
|
|
8
20
|
from .query_rewriter import QueryRewriter
|
|
9
|
-
from .
|
|
21
|
+
from .reranker_interface import RerankerInterface
|
|
22
|
+
from .rerankers import Reranker
|
|
23
|
+
from .rerankers import RerankerRegistry
|
|
10
24
|
from .retriever import ChromaRetriever
|
|
11
|
-
from jarvis.jarvis_utils.config import (
|
|
12
|
-
get_rag_embedding_model,
|
|
13
|
-
get_rag_rerank_model,
|
|
14
|
-
get_rag_vector_db_path,
|
|
15
|
-
get_rag_embedding_cache_path,
|
|
16
|
-
)
|
|
17
|
-
from jarvis.jarvis_utils.utils import get_yes_no
|
|
18
25
|
|
|
19
26
|
|
|
20
27
|
class JarvisRAGPipeline:
|
|
@@ -73,27 +80,31 @@ class JarvisRAGPipeline:
|
|
|
73
80
|
self.use_query_rewrite = use_query_rewrite
|
|
74
81
|
|
|
75
82
|
# 延迟加载的组件
|
|
76
|
-
self._embedding_manager: Optional[
|
|
83
|
+
self._embedding_manager: Optional[EmbeddingInterface] = None
|
|
77
84
|
self._retriever: Optional[ChromaRetriever] = None
|
|
78
|
-
self._reranker: Optional[
|
|
85
|
+
self._reranker: Optional[RerankerInterface] = None
|
|
79
86
|
self._query_rewriter: Optional[QueryRewriter] = None
|
|
80
87
|
|
|
81
|
-
|
|
82
|
-
"✅ JarvisRAGPipeline 初始化成功 (模型按需加载)."
|
|
83
|
-
)
|
|
88
|
+
PrettyOutput.auto_print("✅ JarvisRAGPipeline 初始化成功 (模型按需加载).")
|
|
84
89
|
|
|
85
|
-
def _get_embedding_manager(self) ->
|
|
90
|
+
def _get_embedding_manager(self) -> EmbeddingInterface:
|
|
86
91
|
if self._embedding_manager is None:
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
92
|
+
# 尝试从配置创建模型
|
|
93
|
+
embedding_from_config = EmbeddingRegistry.create_from_config()
|
|
94
|
+
if embedding_from_config:
|
|
95
|
+
self._embedding_manager = embedding_from_config
|
|
96
|
+
else:
|
|
97
|
+
# 回退到传统方式(向后兼容)
|
|
98
|
+
sanitized_model_name = self.embedding_model_name.replace(
|
|
99
|
+
"/", "_"
|
|
100
|
+
).replace("\\", "_")
|
|
101
|
+
_final_cache_path = os.path.join(
|
|
102
|
+
get_rag_embedding_cache_path(), sanitized_model_name
|
|
103
|
+
)
|
|
104
|
+
self._embedding_manager = EmbeddingManager(
|
|
105
|
+
model_name=self.embedding_model_name,
|
|
106
|
+
cache_dir=_final_cache_path,
|
|
107
|
+
)
|
|
97
108
|
return self._embedding_manager
|
|
98
109
|
|
|
99
110
|
def _get_retriever(self) -> ChromaRetriever:
|
|
@@ -137,9 +148,15 @@ class JarvisRAGPipeline:
|
|
|
137
148
|
chroma_client = chromadb.PersistentClient(path=_final_db_path)
|
|
138
149
|
return chroma_client.get_collection(name=self.collection_name)
|
|
139
150
|
|
|
140
|
-
def _get_reranker(self) ->
|
|
151
|
+
def _get_reranker(self) -> RerankerInterface:
|
|
141
152
|
if self._reranker is None:
|
|
142
|
-
|
|
153
|
+
# 尝试从配置创建模型
|
|
154
|
+
reranker_from_config = RerankerRegistry.create_from_config()
|
|
155
|
+
if reranker_from_config:
|
|
156
|
+
self._reranker = reranker_from_config
|
|
157
|
+
else:
|
|
158
|
+
# 回退到传统方式(向后兼容)
|
|
159
|
+
self._reranker = Reranker(model_name=get_rag_rerank_model())
|
|
143
160
|
return self._reranker
|
|
144
161
|
|
|
145
162
|
def _get_query_rewriter(self) -> QueryRewriter:
|
|
@@ -171,19 +188,19 @@ class JarvisRAGPipeline:
|
|
|
171
188
|
lines.extend([f" 变更: {p}" for p in changed[:3]])
|
|
172
189
|
if deleted:
|
|
173
190
|
lines.extend([f" 删除: {p}" for p in deleted[:3]])
|
|
174
|
-
joined_lines =
|
|
175
|
-
|
|
191
|
+
joined_lines = "\n".join(lines)
|
|
192
|
+
PrettyOutput.auto_print(f"⚠️ {joined_lines}")
|
|
176
193
|
# 询问用户
|
|
177
194
|
if get_yes_no(
|
|
178
195
|
"检测到索引变更,是否现在更新索引后再开始检索?", default=True
|
|
179
196
|
):
|
|
180
197
|
retriever.update_index_for_changes(changed, deleted)
|
|
181
198
|
else:
|
|
182
|
-
|
|
199
|
+
PrettyOutput.auto_print(
|
|
183
200
|
"ℹ️ 已跳过索引更新,将直接使用当前索引进行检索。"
|
|
184
201
|
)
|
|
185
202
|
except Exception as e:
|
|
186
|
-
|
|
203
|
+
PrettyOutput.auto_print(f"⚠️ 检索前索引检查失败:{e}")
|
|
187
204
|
|
|
188
205
|
def add_documents(self, documents: List[Document]):
|
|
189
206
|
"""
|
|
@@ -236,16 +253,12 @@ class JarvisRAGPipeline:
|
|
|
236
253
|
if self.use_query_rewrite:
|
|
237
254
|
rewritten_queries = self._get_query_rewriter().rewrite(query_text)
|
|
238
255
|
else:
|
|
239
|
-
|
|
240
|
-
"ℹ️ 已关闭查询重写,将直接使用原始查询进行检索。"
|
|
241
|
-
)
|
|
256
|
+
PrettyOutput.auto_print("ℹ️ 已关闭查询重写,将直接使用原始查询进行检索。")
|
|
242
257
|
rewritten_queries = [query_text]
|
|
243
258
|
|
|
244
259
|
# 2. 为每个重写的查询检索初始候选文档
|
|
245
|
-
query_lines =
|
|
246
|
-
|
|
247
|
-
f"ℹ️ 将为以下查询变体进行混合检索:\n{query_lines}"
|
|
248
|
-
)
|
|
260
|
+
query_lines = "\n".join([f" - {q}" for q in rewritten_queries])
|
|
261
|
+
PrettyOutput.auto_print(f"ℹ️ 将为以下查询变体进行混合检索:\n{query_lines}")
|
|
249
262
|
all_candidate_docs = []
|
|
250
263
|
for q in rewritten_queries:
|
|
251
264
|
candidates = self._get_retriever().retrieve(
|
|
@@ -262,7 +275,7 @@ class JarvisRAGPipeline:
|
|
|
262
275
|
|
|
263
276
|
# 3. 根据*原始*查询对统一的候选池进行重排
|
|
264
277
|
if self.use_rerank:
|
|
265
|
-
|
|
278
|
+
PrettyOutput.auto_print(
|
|
266
279
|
f"ℹ️ 正在对 {len(unique_candidate_docs)} 个候选文档进行重排(基于原始问题)..."
|
|
267
280
|
)
|
|
268
281
|
retrieved_docs = self._get_reranker().rerank(
|
|
@@ -287,14 +300,14 @@ class JarvisRAGPipeline:
|
|
|
287
300
|
if sources:
|
|
288
301
|
# 合并来源列表后一次性打印,避免多次加框
|
|
289
302
|
lines = ["根据以下文档回答:"] + [f" - {source}" for source in sources]
|
|
290
|
-
joined_lines =
|
|
291
|
-
|
|
303
|
+
joined_lines = "\n".join(lines)
|
|
304
|
+
PrettyOutput.auto_print(f"ℹ️ {joined_lines}")
|
|
292
305
|
|
|
293
306
|
# 4. 创建最终提示并生成答案
|
|
294
307
|
# 我们使用原始的query_text作为给LLM的最终提示
|
|
295
308
|
prompt = self._create_prompt(query_text, retrieved_docs)
|
|
296
309
|
|
|
297
|
-
|
|
310
|
+
PrettyOutput.auto_print("ℹ️ 正在从LLM生成答案...")
|
|
298
311
|
answer = self.llm.generate(prompt)
|
|
299
312
|
|
|
300
313
|
return answer
|
|
@@ -316,16 +329,12 @@ class JarvisRAGPipeline:
|
|
|
316
329
|
if self.use_query_rewrite:
|
|
317
330
|
rewritten_queries = self._get_query_rewriter().rewrite(query_text)
|
|
318
331
|
else:
|
|
319
|
-
|
|
320
|
-
"ℹ️ 已关闭查询重写,将直接使用原始查询进行检索。"
|
|
321
|
-
)
|
|
332
|
+
PrettyOutput.auto_print("ℹ️ 已关闭查询重写,将直接使用原始查询进行检索。")
|
|
322
333
|
rewritten_queries = [query_text]
|
|
323
334
|
|
|
324
335
|
# 2. 检索候选文档
|
|
325
|
-
query_lines =
|
|
326
|
-
|
|
327
|
-
f"ℹ️ 将为以下查询变体进行混合检索:\n{query_lines}"
|
|
328
|
-
)
|
|
336
|
+
query_lines = "\n".join([f" - {q}" for q in rewritten_queries])
|
|
337
|
+
PrettyOutput.auto_print(f"ℹ️ 将为以下查询变体进行混合检索:\n{query_lines}")
|
|
329
338
|
all_candidate_docs = []
|
|
330
339
|
for q in rewritten_queries:
|
|
331
340
|
candidates = self._get_retriever().retrieve(
|
|
@@ -341,7 +350,7 @@ class JarvisRAGPipeline:
|
|
|
341
350
|
|
|
342
351
|
# 3. 重排
|
|
343
352
|
if self.use_rerank:
|
|
344
|
-
|
|
353
|
+
PrettyOutput.auto_print(
|
|
345
354
|
f"ℹ️ 正在对 {len(unique_candidate_docs)} 个候选文档进行重排..."
|
|
346
355
|
)
|
|
347
356
|
retrieved_docs = self._get_reranker().rerank(
|
jarvis/jarvis_rag/reranker.py
CHANGED
|
@@ -1,78 +1,10 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
1
|
+
"""
|
|
2
|
+
向后兼容模块:保持 Reranker 的导入路径。
|
|
3
3
|
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
CrossEncoder,
|
|
7
|
-
)
|
|
8
|
-
from huggingface_hub import snapshot_download
|
|
4
|
+
新的实现已移动到 rerankers/ 目录。
|
|
5
|
+
"""
|
|
9
6
|
|
|
7
|
+
from .rerankers.local import LocalReranker
|
|
10
8
|
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
一个重排器类,使用Cross-Encoder模型根据文档与给定查询的相关性
|
|
14
|
-
对文档进行重新评分和排序。
|
|
15
|
-
"""
|
|
16
|
-
|
|
17
|
-
def __init__(self, model_name: str):
|
|
18
|
-
"""
|
|
19
|
-
初始化重排器。
|
|
20
|
-
|
|
21
|
-
参数:
|
|
22
|
-
model_name (str): 要使用的Cross-Encoder模型的名称。
|
|
23
|
-
"""
|
|
24
|
-
print(f"ℹ️ 正在初始化重排模型: {model_name}...")
|
|
25
|
-
try:
|
|
26
|
-
local_dir = None
|
|
27
|
-
|
|
28
|
-
if os.path.isdir(model_name):
|
|
29
|
-
self.model = CrossEncoder(model_name)
|
|
30
|
-
print("✅ 重排模型初始化成功。")
|
|
31
|
-
return
|
|
32
|
-
try:
|
|
33
|
-
# Prefer local cache; avoid any network access
|
|
34
|
-
local_dir = snapshot_download(repo_id=model_name, local_files_only=True)
|
|
35
|
-
except Exception:
|
|
36
|
-
local_dir = None
|
|
37
|
-
|
|
38
|
-
if local_dir:
|
|
39
|
-
self.model = CrossEncoder(local_dir)
|
|
40
|
-
else:
|
|
41
|
-
self.model = CrossEncoder(model_name)
|
|
42
|
-
|
|
43
|
-
print("✅ 重排模型初始化成功。")
|
|
44
|
-
except Exception as e:
|
|
45
|
-
print(f"❌ 初始化重排模型失败: {e}")
|
|
46
|
-
raise
|
|
47
|
-
|
|
48
|
-
def rerank(
|
|
49
|
-
self, query: str, documents: List[Document], top_n: int = 5
|
|
50
|
-
) -> List[Document]:
|
|
51
|
-
"""
|
|
52
|
-
根据文档与查询的相关性对文档列表进行重排。
|
|
53
|
-
|
|
54
|
-
参数:
|
|
55
|
-
query (str): 用户的查询。
|
|
56
|
-
documents (List[Document]): 从初始搜索中检索到的文档列表。
|
|
57
|
-
top_n (int): 重排后要返回的顶部文档数。
|
|
58
|
-
|
|
59
|
-
返回:
|
|
60
|
-
List[Document]: 一个已排序的最相关文档列表。
|
|
61
|
-
"""
|
|
62
|
-
if not documents:
|
|
63
|
-
return []
|
|
64
|
-
|
|
65
|
-
# 创建 [查询, 文档内容] 对用于评分
|
|
66
|
-
pairs = [[query, doc.page_content] for doc in documents]
|
|
67
|
-
|
|
68
|
-
# 从Cross-Encoder模型获取分数
|
|
69
|
-
scores = self.model.predict(pairs)
|
|
70
|
-
|
|
71
|
-
# 将文档与它们的分数结合并排序
|
|
72
|
-
doc_with_scores = list(zip(documents, scores))
|
|
73
|
-
doc_with_scores.sort(key=lambda x: x[1], reverse=True) # type: ignore
|
|
74
|
-
|
|
75
|
-
# 返回前N个文档
|
|
76
|
-
reranked_docs = [doc for doc, score in doc_with_scores[:top_n]]
|
|
77
|
-
|
|
78
|
-
return reranked_docs
|
|
9
|
+
# 向后兼容:保持 Reranker 作为 LocalReranker 的别名
|
|
10
|
+
Reranker = LocalReranker
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from abc import ABC
|
|
2
|
+
from abc import abstractmethod
|
|
3
|
+
from typing import List
|
|
4
|
+
|
|
5
|
+
from langchain.docstore.document import Document
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class RerankerInterface(ABC):
|
|
9
|
+
"""
|
|
10
|
+
重排模型接口的抽象基类。
|
|
11
|
+
|
|
12
|
+
该类定义了重排模型的标准接口,支持本地模型和在线模型(API)的实现。
|
|
13
|
+
任何重排模型提供商(如sentence-transformers CrossEncoder、Cohere Rerank API等)
|
|
14
|
+
都应作为该接口的子类来实现。
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
@abstractmethod
|
|
18
|
+
def rerank(
|
|
19
|
+
self, query: str, documents: List[Document], top_n: int = 5
|
|
20
|
+
) -> List[Document]:
|
|
21
|
+
"""
|
|
22
|
+
根据文档与查询的相关性对文档列表进行重排。
|
|
23
|
+
|
|
24
|
+
参数:
|
|
25
|
+
query: 用户的查询。
|
|
26
|
+
documents: 从初始搜索中检索到的文档列表。
|
|
27
|
+
top_n: 重排后要返回的顶部文档数。
|
|
28
|
+
|
|
29
|
+
返回:
|
|
30
|
+
一个已排序的最相关文档列表。
|
|
31
|
+
"""
|
|
32
|
+
pass
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
"""
|
|
2
|
+
重排模型实现模块。
|
|
3
|
+
|
|
4
|
+
包含本地和在线重排模型的实现,支持动态加载自定义模型。
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from .base import OnlineReranker # noqa: F401
|
|
8
|
+
from .local import LocalReranker
|
|
9
|
+
from .registry import RerankerRegistry
|
|
10
|
+
|
|
11
|
+
# 向后兼容别名
|
|
12
|
+
Reranker = LocalReranker
|
|
13
|
+
|
|
14
|
+
# 在线模型实现(可选导入)
|
|
15
|
+
try:
|
|
16
|
+
from .cohere import CohereReranker # noqa: F401
|
|
17
|
+
from .edgefn import EdgeFnReranker # noqa: F401
|
|
18
|
+
from .jina import JinaReranker # noqa: F401
|
|
19
|
+
|
|
20
|
+
_base_exports = [
|
|
21
|
+
"OnlineReranker",
|
|
22
|
+
"LocalReranker",
|
|
23
|
+
"Reranker", # 向后兼容
|
|
24
|
+
"RerankerRegistry",
|
|
25
|
+
"CohereReranker",
|
|
26
|
+
"JinaReranker",
|
|
27
|
+
"EdgeFnReranker",
|
|
28
|
+
]
|
|
29
|
+
except ImportError:
|
|
30
|
+
_base_exports = [
|
|
31
|
+
"OnlineReranker",
|
|
32
|
+
"LocalReranker",
|
|
33
|
+
"Reranker", # 向后兼容
|
|
34
|
+
"RerankerRegistry",
|
|
35
|
+
]
|
|
36
|
+
|
|
37
|
+
# 动态加载的模型(通过 registry)
|
|
38
|
+
_registry = RerankerRegistry.get_global_registry()
|
|
39
|
+
_dynamic_exports = _registry.get_available_rerankers()
|
|
40
|
+
|
|
41
|
+
__all__ = _base_exports + _dynamic_exports
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
"""
|
|
2
|
+
在线重排模型的基类实现。
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
from typing import List
|
|
7
|
+
from typing import Optional
|
|
8
|
+
|
|
9
|
+
from langchain.docstore.document import Document
|
|
10
|
+
|
|
11
|
+
from ..reranker_interface import RerankerInterface
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class OnlineReranker(RerankerInterface):
|
|
15
|
+
"""
|
|
16
|
+
在线重排模型的基类实现。
|
|
17
|
+
|
|
18
|
+
这是一个抽象基类,定义了在线重排模型的基本结构。
|
|
19
|
+
子类需要实现具体的API调用逻辑。
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
def __init__(
|
|
23
|
+
self,
|
|
24
|
+
api_key: Optional[str] = None,
|
|
25
|
+
api_key_env: Optional[str] = None,
|
|
26
|
+
base_url: Optional[str] = None,
|
|
27
|
+
model_name: Optional[str] = None,
|
|
28
|
+
top_n: int = 5,
|
|
29
|
+
max_length: Optional[int] = None,
|
|
30
|
+
**kwargs,
|
|
31
|
+
):
|
|
32
|
+
"""
|
|
33
|
+
初始化在线重排模型。
|
|
34
|
+
|
|
35
|
+
参数:
|
|
36
|
+
api_key: API密钥。如果为None,将从环境变量中读取。
|
|
37
|
+
api_key_env: 用于读取API密钥的环境变量名。
|
|
38
|
+
base_url: API的基础URL。
|
|
39
|
+
model_name: 要使用的模型名称。
|
|
40
|
+
top_n: 默认返回的顶部文档数。
|
|
41
|
+
max_length: 模型的最大输入长度(token数),用于文档处理。
|
|
42
|
+
**kwargs: 其他配置参数(可能包含从 reranker_config 传入的配置)。
|
|
43
|
+
"""
|
|
44
|
+
# 优先从 kwargs 中读取 api_key(可能来自 reranker_config)
|
|
45
|
+
# 如果没有,则使用传入的 api_key 参数
|
|
46
|
+
# 最后才从环境变量读取(向后兼容)
|
|
47
|
+
self.api_key = (
|
|
48
|
+
kwargs.get("api_key")
|
|
49
|
+
or api_key
|
|
50
|
+
or (os.getenv(api_key_env) if api_key_env else None)
|
|
51
|
+
)
|
|
52
|
+
if not self.api_key:
|
|
53
|
+
raise ValueError(
|
|
54
|
+
f"API密钥未提供。请通过api_key参数、reranker_config配置或环境变量{api_key_env}提供。"
|
|
55
|
+
)
|
|
56
|
+
|
|
57
|
+
# 如果 base_url 在 kwargs 中,优先使用
|
|
58
|
+
if "base_url" in kwargs and kwargs["base_url"]:
|
|
59
|
+
self.base_url = kwargs["base_url"]
|
|
60
|
+
else:
|
|
61
|
+
self.base_url = base_url
|
|
62
|
+
|
|
63
|
+
self.model_name = model_name
|
|
64
|
+
self.default_top_n = top_n
|
|
65
|
+
self.max_length = max_length
|
|
66
|
+
|
|
67
|
+
def _call_api(self, query: str, documents: List[str]) -> List[tuple[int, float]]:
|
|
68
|
+
"""
|
|
69
|
+
调用在线API获取重排分数。
|
|
70
|
+
|
|
71
|
+
参数:
|
|
72
|
+
query: 查询文本。
|
|
73
|
+
documents: 文档文本列表。
|
|
74
|
+
|
|
75
|
+
返回:
|
|
76
|
+
包含 (索引, 分数) 元组的列表,按分数降序排序。
|
|
77
|
+
|
|
78
|
+
注意:
|
|
79
|
+
子类必须实现此方法。
|
|
80
|
+
"""
|
|
81
|
+
raise NotImplementedError("子类必须实现 _call_api 方法")
|
|
82
|
+
|
|
83
|
+
def rerank(
|
|
84
|
+
self, query: str, documents: List[Document], top_n: int = 5
|
|
85
|
+
) -> List[Document]:
|
|
86
|
+
"""
|
|
87
|
+
根据文档与查询的相关性对文档列表进行重排。
|
|
88
|
+
|
|
89
|
+
参数:
|
|
90
|
+
query: 用户的查询。
|
|
91
|
+
documents: 从初始搜索中检索到的文档列表。
|
|
92
|
+
top_n: 重排后要返回的顶部文档数。
|
|
93
|
+
|
|
94
|
+
返回:
|
|
95
|
+
一个已排序的最相关文档列表。
|
|
96
|
+
"""
|
|
97
|
+
if not documents:
|
|
98
|
+
return []
|
|
99
|
+
|
|
100
|
+
# 提取文档内容
|
|
101
|
+
doc_texts = [doc.page_content for doc in documents]
|
|
102
|
+
|
|
103
|
+
# 调用API获取重排分数
|
|
104
|
+
scored_indices = self._call_api(query, doc_texts)
|
|
105
|
+
|
|
106
|
+
# 根据分数排序并返回前top_n个文档
|
|
107
|
+
reranked_docs = [documents[idx] for idx, _ in scored_indices[:top_n]]
|
|
108
|
+
|
|
109
|
+
return reranked_docs
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Cohere 重排模型实现。
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import List
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
from .base import OnlineReranker
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class CohereReranker(OnlineReranker):
|
|
12
|
+
"""
|
|
13
|
+
Cohere 重排模型的实现。
|
|
14
|
+
|
|
15
|
+
使用 Cohere 的 rerank API。
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
def __init__(
|
|
19
|
+
self,
|
|
20
|
+
api_key: Optional[str] = None,
|
|
21
|
+
model_name: str = "rerank-english-v3.0",
|
|
22
|
+
**kwargs,
|
|
23
|
+
):
|
|
24
|
+
"""
|
|
25
|
+
初始化 Cohere 重排模型。
|
|
26
|
+
|
|
27
|
+
参数:
|
|
28
|
+
api_key: Cohere API密钥。如果为None,将从COHERE_API_KEY环境变量读取。
|
|
29
|
+
model_name: 要使用的模型名称(如 'rerank-english-v3.0')。
|
|
30
|
+
**kwargs: 传递给父类的其他参数。
|
|
31
|
+
"""
|
|
32
|
+
super().__init__(
|
|
33
|
+
api_key=api_key,
|
|
34
|
+
api_key_env="COHERE_API_KEY",
|
|
35
|
+
model_name=model_name,
|
|
36
|
+
**kwargs,
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
def _call_api(self, query: str, documents: List[str]) -> List[tuple[int, float]]:
|
|
40
|
+
"""
|
|
41
|
+
调用 Cohere API 获取重排分数。
|
|
42
|
+
"""
|
|
43
|
+
try:
|
|
44
|
+
import cohere
|
|
45
|
+
|
|
46
|
+
client = cohere.Client(api_key=self.api_key)
|
|
47
|
+
|
|
48
|
+
response = client.rerank(
|
|
49
|
+
model=self.model_name,
|
|
50
|
+
query=query,
|
|
51
|
+
documents=documents,
|
|
52
|
+
top_n=len(documents), # 获取所有文档的分数
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
# 返回 (索引, 分数) 元组列表,按分数降序排序
|
|
56
|
+
results = [
|
|
57
|
+
(result.index, result.relevance_score) for result in response.results
|
|
58
|
+
]
|
|
59
|
+
results.sort(key=lambda x: x[1], reverse=True)
|
|
60
|
+
|
|
61
|
+
return results
|
|
62
|
+
except ImportError:
|
|
63
|
+
raise ImportError(
|
|
64
|
+
"使用 CohereReranker 需要安装 cohere 包: pip install cohere"
|
|
65
|
+
)
|
|
66
|
+
except Exception as e:
|
|
67
|
+
raise RuntimeError(f"调用 Cohere API 时出错: {e}")
|