jarvis-ai-assistant 0.7.16__py3-none-any.whl → 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jarvis/__init__.py +1 -1
- jarvis/jarvis_agent/__init__.py +567 -222
- jarvis/jarvis_agent/agent_manager.py +19 -12
- jarvis/jarvis_agent/builtin_input_handler.py +79 -11
- jarvis/jarvis_agent/config_editor.py +7 -2
- jarvis/jarvis_agent/event_bus.py +24 -13
- jarvis/jarvis_agent/events.py +19 -1
- jarvis/jarvis_agent/file_context_handler.py +67 -64
- jarvis/jarvis_agent/file_methodology_manager.py +38 -24
- jarvis/jarvis_agent/jarvis.py +186 -114
- jarvis/jarvis_agent/language_extractors/__init__.py +8 -1
- jarvis/jarvis_agent/language_extractors/c_extractor.py +7 -4
- jarvis/jarvis_agent/language_extractors/cpp_extractor.py +9 -4
- jarvis/jarvis_agent/language_extractors/go_extractor.py +7 -4
- jarvis/jarvis_agent/language_extractors/java_extractor.py +27 -20
- jarvis/jarvis_agent/language_extractors/javascript_extractor.py +22 -17
- jarvis/jarvis_agent/language_extractors/python_extractor.py +7 -4
- jarvis/jarvis_agent/language_extractors/rust_extractor.py +7 -4
- jarvis/jarvis_agent/language_extractors/typescript_extractor.py +22 -17
- jarvis/jarvis_agent/language_support_info.py +250 -219
- jarvis/jarvis_agent/main.py +19 -23
- jarvis/jarvis_agent/memory_manager.py +9 -6
- jarvis/jarvis_agent/methodology_share_manager.py +21 -15
- jarvis/jarvis_agent/output_handler.py +4 -2
- jarvis/jarvis_agent/prompt_builder.py +7 -6
- jarvis/jarvis_agent/prompt_manager.py +113 -8
- jarvis/jarvis_agent/prompts.py +317 -85
- jarvis/jarvis_agent/protocols.py +5 -2
- jarvis/jarvis_agent/run_loop.py +192 -32
- jarvis/jarvis_agent/session_manager.py +7 -3
- jarvis/jarvis_agent/share_manager.py +23 -13
- jarvis/jarvis_agent/shell_input_handler.py +12 -8
- jarvis/jarvis_agent/stdio_redirect.py +25 -26
- jarvis/jarvis_agent/task_analyzer.py +29 -23
- jarvis/jarvis_agent/task_list.py +869 -0
- jarvis/jarvis_agent/task_manager.py +26 -23
- jarvis/jarvis_agent/tool_executor.py +6 -5
- jarvis/jarvis_agent/tool_share_manager.py +24 -14
- jarvis/jarvis_agent/user_interaction.py +3 -3
- jarvis/jarvis_agent/utils.py +9 -1
- jarvis/jarvis_agent/web_bridge.py +37 -17
- jarvis/jarvis_agent/web_output_sink.py +5 -2
- jarvis/jarvis_agent/web_server.py +165 -36
- jarvis/jarvis_c2rust/__init__.py +1 -1
- jarvis/jarvis_c2rust/cli.py +260 -141
- jarvis/jarvis_c2rust/collector.py +37 -18
- jarvis/jarvis_c2rust/constants.py +60 -0
- jarvis/jarvis_c2rust/library_replacer.py +242 -1010
- jarvis/jarvis_c2rust/library_replacer_checkpoint.py +133 -0
- jarvis/jarvis_c2rust/library_replacer_llm.py +287 -0
- jarvis/jarvis_c2rust/library_replacer_loader.py +191 -0
- jarvis/jarvis_c2rust/library_replacer_output.py +134 -0
- jarvis/jarvis_c2rust/library_replacer_prompts.py +124 -0
- jarvis/jarvis_c2rust/library_replacer_utils.py +188 -0
- jarvis/jarvis_c2rust/llm_module_agent.py +98 -1044
- jarvis/jarvis_c2rust/llm_module_agent_apply.py +170 -0
- jarvis/jarvis_c2rust/llm_module_agent_executor.py +288 -0
- jarvis/jarvis_c2rust/llm_module_agent_loader.py +170 -0
- jarvis/jarvis_c2rust/llm_module_agent_prompts.py +268 -0
- jarvis/jarvis_c2rust/llm_module_agent_types.py +57 -0
- jarvis/jarvis_c2rust/llm_module_agent_utils.py +150 -0
- jarvis/jarvis_c2rust/llm_module_agent_validator.py +119 -0
- jarvis/jarvis_c2rust/loaders.py +28 -10
- jarvis/jarvis_c2rust/models.py +5 -2
- jarvis/jarvis_c2rust/optimizer.py +192 -1974
- jarvis/jarvis_c2rust/optimizer_build_fix.py +286 -0
- jarvis/jarvis_c2rust/optimizer_clippy.py +766 -0
- jarvis/jarvis_c2rust/optimizer_config.py +49 -0
- jarvis/jarvis_c2rust/optimizer_docs.py +183 -0
- jarvis/jarvis_c2rust/optimizer_options.py +48 -0
- jarvis/jarvis_c2rust/optimizer_progress.py +469 -0
- jarvis/jarvis_c2rust/optimizer_report.py +52 -0
- jarvis/jarvis_c2rust/optimizer_unsafe.py +309 -0
- jarvis/jarvis_c2rust/optimizer_utils.py +469 -0
- jarvis/jarvis_c2rust/optimizer_visibility.py +185 -0
- jarvis/jarvis_c2rust/scanner.py +229 -166
- jarvis/jarvis_c2rust/transpiler.py +531 -2732
- jarvis/jarvis_c2rust/transpiler_agents.py +503 -0
- jarvis/jarvis_c2rust/transpiler_build.py +1294 -0
- jarvis/jarvis_c2rust/transpiler_codegen.py +204 -0
- jarvis/jarvis_c2rust/transpiler_compile.py +146 -0
- jarvis/jarvis_c2rust/transpiler_config.py +178 -0
- jarvis/jarvis_c2rust/transpiler_context.py +122 -0
- jarvis/jarvis_c2rust/transpiler_executor.py +516 -0
- jarvis/jarvis_c2rust/transpiler_generation.py +278 -0
- jarvis/jarvis_c2rust/transpiler_git.py +163 -0
- jarvis/jarvis_c2rust/transpiler_mod_utils.py +225 -0
- jarvis/jarvis_c2rust/transpiler_modules.py +336 -0
- jarvis/jarvis_c2rust/transpiler_planning.py +394 -0
- jarvis/jarvis_c2rust/transpiler_review.py +1196 -0
- jarvis/jarvis_c2rust/transpiler_symbols.py +176 -0
- jarvis/jarvis_c2rust/utils.py +269 -79
- jarvis/jarvis_code_agent/after_change.py +233 -0
- jarvis/jarvis_code_agent/build_validation_config.py +37 -30
- jarvis/jarvis_code_agent/builtin_rules.py +68 -0
- jarvis/jarvis_code_agent/code_agent.py +976 -1517
- jarvis/jarvis_code_agent/code_agent_build.py +227 -0
- jarvis/jarvis_code_agent/code_agent_diff.py +246 -0
- jarvis/jarvis_code_agent/code_agent_git.py +525 -0
- jarvis/jarvis_code_agent/code_agent_impact.py +177 -0
- jarvis/jarvis_code_agent/code_agent_lint.py +283 -0
- jarvis/jarvis_code_agent/code_agent_llm.py +159 -0
- jarvis/jarvis_code_agent/code_agent_postprocess.py +105 -0
- jarvis/jarvis_code_agent/code_agent_prompts.py +46 -0
- jarvis/jarvis_code_agent/code_agent_rules.py +305 -0
- jarvis/jarvis_code_agent/code_analyzer/__init__.py +52 -48
- jarvis/jarvis_code_agent/code_analyzer/base_language.py +12 -10
- jarvis/jarvis_code_agent/code_analyzer/build_validator/__init__.py +12 -11
- jarvis/jarvis_code_agent/code_analyzer/build_validator/base.py +16 -12
- jarvis/jarvis_code_agent/code_analyzer/build_validator/cmake.py +26 -17
- jarvis/jarvis_code_agent/code_analyzer/build_validator/detector.py +558 -104
- jarvis/jarvis_code_agent/code_analyzer/build_validator/fallback.py +27 -16
- jarvis/jarvis_code_agent/code_analyzer/build_validator/go.py +22 -18
- jarvis/jarvis_code_agent/code_analyzer/build_validator/java_gradle.py +21 -16
- jarvis/jarvis_code_agent/code_analyzer/build_validator/java_maven.py +20 -16
- jarvis/jarvis_code_agent/code_analyzer/build_validator/makefile.py +27 -16
- jarvis/jarvis_code_agent/code_analyzer/build_validator/nodejs.py +47 -23
- jarvis/jarvis_code_agent/code_analyzer/build_validator/python.py +71 -37
- jarvis/jarvis_code_agent/code_analyzer/build_validator/rust.py +162 -35
- jarvis/jarvis_code_agent/code_analyzer/build_validator/validator.py +111 -57
- jarvis/jarvis_code_agent/code_analyzer/build_validator.py +18 -12
- jarvis/jarvis_code_agent/code_analyzer/context_manager.py +185 -183
- jarvis/jarvis_code_agent/code_analyzer/context_recommender.py +2 -1
- jarvis/jarvis_code_agent/code_analyzer/dependency_analyzer.py +24 -15
- jarvis/jarvis_code_agent/code_analyzer/file_ignore.py +227 -141
- jarvis/jarvis_code_agent/code_analyzer/impact_analyzer.py +321 -247
- jarvis/jarvis_code_agent/code_analyzer/language_registry.py +37 -29
- jarvis/jarvis_code_agent/code_analyzer/language_support.py +21 -13
- jarvis/jarvis_code_agent/code_analyzer/languages/__init__.py +15 -9
- jarvis/jarvis_code_agent/code_analyzer/languages/c_cpp_language.py +75 -45
- jarvis/jarvis_code_agent/code_analyzer/languages/go_language.py +87 -52
- jarvis/jarvis_code_agent/code_analyzer/languages/java_language.py +84 -51
- jarvis/jarvis_code_agent/code_analyzer/languages/javascript_language.py +94 -64
- jarvis/jarvis_code_agent/code_analyzer/languages/python_language.py +109 -71
- jarvis/jarvis_code_agent/code_analyzer/languages/rust_language.py +97 -63
- jarvis/jarvis_code_agent/code_analyzer/languages/typescript_language.py +103 -69
- jarvis/jarvis_code_agent/code_analyzer/llm_context_recommender.py +271 -268
- jarvis/jarvis_code_agent/code_analyzer/symbol_extractor.py +76 -64
- jarvis/jarvis_code_agent/code_analyzer/tree_sitter_extractor.py +92 -19
- jarvis/jarvis_code_agent/diff_visualizer.py +998 -0
- jarvis/jarvis_code_agent/lint.py +223 -524
- jarvis/jarvis_code_agent/rule_share_manager.py +158 -0
- jarvis/jarvis_code_agent/rules/clean_code.md +144 -0
- jarvis/jarvis_code_agent/rules/code_review.md +115 -0
- jarvis/jarvis_code_agent/rules/documentation.md +165 -0
- jarvis/jarvis_code_agent/rules/generate_rules.md +52 -0
- jarvis/jarvis_code_agent/rules/performance.md +158 -0
- jarvis/jarvis_code_agent/rules/refactoring.md +139 -0
- jarvis/jarvis_code_agent/rules/security.md +160 -0
- jarvis/jarvis_code_agent/rules/tdd.md +78 -0
- jarvis/jarvis_code_agent/test_rules/cpp_test.md +118 -0
- jarvis/jarvis_code_agent/test_rules/go_test.md +98 -0
- jarvis/jarvis_code_agent/test_rules/java_test.md +99 -0
- jarvis/jarvis_code_agent/test_rules/javascript_test.md +113 -0
- jarvis/jarvis_code_agent/test_rules/php_test.md +117 -0
- jarvis/jarvis_code_agent/test_rules/python_test.md +91 -0
- jarvis/jarvis_code_agent/test_rules/ruby_test.md +102 -0
- jarvis/jarvis_code_agent/test_rules/rust_test.md +86 -0
- jarvis/jarvis_code_agent/utils.py +36 -26
- jarvis/jarvis_code_analysis/checklists/loader.py +21 -21
- jarvis/jarvis_code_analysis/code_review.py +64 -33
- jarvis/jarvis_data/config_schema.json +285 -192
- jarvis/jarvis_git_squash/main.py +8 -6
- jarvis/jarvis_git_utils/git_commiter.py +53 -76
- jarvis/jarvis_mcp/__init__.py +5 -2
- jarvis/jarvis_mcp/sse_mcp_client.py +40 -30
- jarvis/jarvis_mcp/stdio_mcp_client.py +27 -19
- jarvis/jarvis_mcp/streamable_mcp_client.py +35 -26
- jarvis/jarvis_memory_organizer/memory_organizer.py +78 -55
- jarvis/jarvis_methodology/main.py +48 -39
- jarvis/jarvis_multi_agent/__init__.py +56 -23
- jarvis/jarvis_multi_agent/main.py +15 -18
- jarvis/jarvis_platform/base.py +179 -111
- jarvis/jarvis_platform/human.py +27 -16
- jarvis/jarvis_platform/kimi.py +52 -45
- jarvis/jarvis_platform/openai.py +101 -40
- jarvis/jarvis_platform/registry.py +51 -33
- jarvis/jarvis_platform/tongyi.py +68 -38
- jarvis/jarvis_platform/yuanbao.py +59 -43
- jarvis/jarvis_platform_manager/main.py +68 -76
- jarvis/jarvis_platform_manager/service.py +24 -14
- jarvis/jarvis_rag/README_CONFIG.md +314 -0
- jarvis/jarvis_rag/README_DYNAMIC_LOADING.md +311 -0
- jarvis/jarvis_rag/README_ONLINE_MODELS.md +230 -0
- jarvis/jarvis_rag/__init__.py +57 -4
- jarvis/jarvis_rag/cache.py +3 -1
- jarvis/jarvis_rag/cli.py +48 -68
- jarvis/jarvis_rag/embedding_interface.py +39 -0
- jarvis/jarvis_rag/embedding_manager.py +7 -230
- jarvis/jarvis_rag/embeddings/__init__.py +41 -0
- jarvis/jarvis_rag/embeddings/base.py +114 -0
- jarvis/jarvis_rag/embeddings/cohere.py +66 -0
- jarvis/jarvis_rag/embeddings/edgefn.py +117 -0
- jarvis/jarvis_rag/embeddings/local.py +260 -0
- jarvis/jarvis_rag/embeddings/openai.py +62 -0
- jarvis/jarvis_rag/embeddings/registry.py +293 -0
- jarvis/jarvis_rag/llm_interface.py +8 -6
- jarvis/jarvis_rag/query_rewriter.py +8 -9
- jarvis/jarvis_rag/rag_pipeline.py +61 -52
- jarvis/jarvis_rag/reranker.py +7 -75
- jarvis/jarvis_rag/reranker_interface.py +32 -0
- jarvis/jarvis_rag/rerankers/__init__.py +41 -0
- jarvis/jarvis_rag/rerankers/base.py +109 -0
- jarvis/jarvis_rag/rerankers/cohere.py +67 -0
- jarvis/jarvis_rag/rerankers/edgefn.py +140 -0
- jarvis/jarvis_rag/rerankers/jina.py +79 -0
- jarvis/jarvis_rag/rerankers/local.py +89 -0
- jarvis/jarvis_rag/rerankers/registry.py +293 -0
- jarvis/jarvis_rag/retriever.py +58 -43
- jarvis/jarvis_sec/__init__.py +66 -141
- jarvis/jarvis_sec/agents.py +21 -17
- jarvis/jarvis_sec/analysis.py +80 -33
- jarvis/jarvis_sec/checkers/__init__.py +7 -13
- jarvis/jarvis_sec/checkers/c_checker.py +356 -164
- jarvis/jarvis_sec/checkers/rust_checker.py +47 -29
- jarvis/jarvis_sec/cli.py +43 -21
- jarvis/jarvis_sec/clustering.py +430 -272
- jarvis/jarvis_sec/file_manager.py +99 -55
- jarvis/jarvis_sec/parsers.py +9 -6
- jarvis/jarvis_sec/prompts.py +4 -3
- jarvis/jarvis_sec/report.py +44 -22
- jarvis/jarvis_sec/review.py +180 -107
- jarvis/jarvis_sec/status.py +50 -41
- jarvis/jarvis_sec/types.py +3 -0
- jarvis/jarvis_sec/utils.py +160 -83
- jarvis/jarvis_sec/verification.py +411 -181
- jarvis/jarvis_sec/workflow.py +132 -21
- jarvis/jarvis_smart_shell/main.py +28 -41
- jarvis/jarvis_stats/cli.py +14 -12
- jarvis/jarvis_stats/stats.py +28 -19
- jarvis/jarvis_stats/storage.py +14 -8
- jarvis/jarvis_stats/visualizer.py +12 -7
- jarvis/jarvis_tools/base.py +5 -2
- jarvis/jarvis_tools/clear_memory.py +13 -9
- jarvis/jarvis_tools/cli/main.py +23 -18
- jarvis/jarvis_tools/edit_file.py +572 -873
- jarvis/jarvis_tools/execute_script.py +10 -7
- jarvis/jarvis_tools/file_analyzer.py +7 -8
- jarvis/jarvis_tools/meta_agent.py +287 -0
- jarvis/jarvis_tools/methodology.py +5 -3
- jarvis/jarvis_tools/read_code.py +305 -1438
- jarvis/jarvis_tools/read_symbols.py +50 -17
- jarvis/jarvis_tools/read_webpage.py +19 -18
- jarvis/jarvis_tools/registry.py +435 -156
- jarvis/jarvis_tools/retrieve_memory.py +16 -11
- jarvis/jarvis_tools/save_memory.py +8 -6
- jarvis/jarvis_tools/search_web.py +31 -31
- jarvis/jarvis_tools/sub_agent.py +32 -28
- jarvis/jarvis_tools/sub_code_agent.py +44 -60
- jarvis/jarvis_tools/task_list_manager.py +1811 -0
- jarvis/jarvis_tools/virtual_tty.py +29 -19
- jarvis/jarvis_utils/__init__.py +4 -0
- jarvis/jarvis_utils/builtin_replace_map.py +2 -1
- jarvis/jarvis_utils/clipboard.py +9 -8
- jarvis/jarvis_utils/collections.py +331 -0
- jarvis/jarvis_utils/config.py +699 -194
- jarvis/jarvis_utils/dialogue_recorder.py +294 -0
- jarvis/jarvis_utils/embedding.py +6 -3
- jarvis/jarvis_utils/file_processors.py +7 -1
- jarvis/jarvis_utils/fzf.py +9 -3
- jarvis/jarvis_utils/git_utils.py +71 -42
- jarvis/jarvis_utils/globals.py +116 -32
- jarvis/jarvis_utils/http.py +6 -2
- jarvis/jarvis_utils/input.py +318 -83
- jarvis/jarvis_utils/jsonnet_compat.py +119 -104
- jarvis/jarvis_utils/methodology.py +37 -28
- jarvis/jarvis_utils/output.py +201 -44
- jarvis/jarvis_utils/utils.py +986 -628
- {jarvis_ai_assistant-0.7.16.dist-info → jarvis_ai_assistant-1.0.2.dist-info}/METADATA +49 -33
- jarvis_ai_assistant-1.0.2.dist-info/RECORD +304 -0
- jarvis/jarvis_code_agent/code_analyzer/structured_code.py +0 -556
- jarvis/jarvis_tools/generate_new_tool.py +0 -205
- jarvis/jarvis_tools/lsp_client.py +0 -1552
- jarvis/jarvis_tools/rewrite_file.py +0 -105
- jarvis_ai_assistant-0.7.16.dist-info/RECORD +0 -218
- {jarvis_ai_assistant-0.7.16.dist-info → jarvis_ai_assistant-1.0.2.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.7.16.dist-info → jarvis_ai_assistant-1.0.2.dist-info}/entry_points.txt +0 -0
- {jarvis_ai_assistant-0.7.16.dist-info → jarvis_ai_assistant-1.0.2.dist-info}/licenses/LICENSE +0 -0
- {jarvis_ai_assistant-0.7.16.dist-info → jarvis_ai_assistant-1.0.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,260 @@
|
|
|
1
|
+
"""
|
|
2
|
+
本地嵌入模型实现。
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from jarvis.jarvis_utils.output import PrettyOutput
|
|
6
|
+
|
|
7
|
+
import os
|
|
8
|
+
from typing import List
|
|
9
|
+
from typing import Optional
|
|
10
|
+
from typing import cast
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
from huggingface_hub import snapshot_download
|
|
14
|
+
from langchain_huggingface import HuggingFaceEmbeddings
|
|
15
|
+
|
|
16
|
+
from ..cache import EmbeddingCache
|
|
17
|
+
from ..embedding_interface import EmbeddingInterface
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class LocalEmbeddingModel(EmbeddingInterface):
    """Local Hugging Face embedding model with an on-disk embedding cache.

    The model is loaded through ``HuggingFaceEmbeddings``; copies that are
    already on disk are preferred over a network download. Document embeddings
    are looked up in an ``EmbeddingCache`` before being computed.
    """

    def __init__(
        self, model_name: str, cache_dir: str, max_length: Optional[int] = None
    ):
        """Initialise the model and its embedding cache.

        Args:
            model_name: Hugging Face model name, or a path to a local model
                directory.
            cache_dir: Directory used to store the embedding cache.
            max_length: Maximum input length of the model (in tokens); stored
                on the instance for document splitting.
        """
        self.model_name = model_name
        self.max_length = max_length

        PrettyOutput.auto_print(f"ℹ️ 初始化嵌入管理器, 模型: '{self.model_name}'...")

        # The model name is the cache salt so that embeddings produced by
        # different models never collide.
        self.cache = EmbeddingCache(cache_dir=cache_dir, salt=self.model_name)
        self.model = self._load_model()

    def _find_cached_model_dirs(self) -> List[str]:
        """Collect local directories that may hold a copy of the model.

        Looks in the sentence-transformers cache directories and in the
        Hugging Face hub snapshot cache under the user's home directory.

        Returns:
            Candidate directories in priority order, without duplicates.
        """
        home = os.path.expanduser("~")
        cache_roots = [
            os.path.join(home, ".cache", "sentence_transformers"),
            os.path.join(home, ".cache", "torch", "sentence_transformers"),
        ]

        if "/" in self.model_name:
            org, name = self.model_name.split("/", 1)
        else:
            org, name = "", self.model_name

        # Spellings under which caches are known to store "org/name".
        san1 = self.model_name.replace("/", "_")
        san2 = self.model_name.replace("/", "__")
        san3 = self.model_name.replace("/", "--")
        # The plain name covers caches that drop the org prefix.
        name_variants = list(
            dict.fromkeys([self.model_name, san1, san2, san3, name])
        )

        candidates: List[str] = []
        for base in cache_roots:
            for variant in name_variants:
                path = os.path.join(base, variant)
                if os.path.isdir(path):
                    candidates.append(path)

            # Fuzzy scan: accept any cache entry whose name contains a variant.
            # NOTE(review): a bare substring match on the model name can also
            # select a different model (e.g. a fine-tuned copy) - confirm this
            # looseness is intended.
            try:
                entries = os.listdir(base)
            except OSError:
                # The cache root is missing or unreadable: nothing to scan.
                continue
            for entry in entries:
                entry_path = os.path.join(base, entry)
                if not os.path.isdir(entry_path):
                    continue
                if san1 in entry or name in entry:
                    candidates.append(entry_path)

        # Hugging Face hub cache: models--{org}--{name}/snapshots/<revision>.
        if "/" in self.model_name:
            snapshots_dir = os.path.join(
                home,
                ".cache",
                "huggingface",
                "hub",
                f"models--{org}--{name}",
                "snapshots",
            )
            if os.path.isdir(snapshots_dir):
                snaps = [
                    os.path.join(snapshots_dir, d)
                    for d in os.listdir(snapshots_dir)
                ]
                try:
                    # Most recently modified snapshot first.
                    snaps = sorted(snaps, key=os.path.getmtime, reverse=True)
                except OSError:
                    # Keep the directory-listing order when mtimes are unreadable.
                    pass
                for snap in snaps:
                    if os.path.isdir(snap):
                        candidates.append(snap)
                        break

        # The exact-match and fuzzy scans can yield the same directory twice;
        # drop duplicates so a broken copy is not loaded repeatedly.
        return list(dict.fromkeys(candidates))

    def _load_model(self) -> HuggingFaceEmbeddings:
        """Load the Hugging Face embedding model, preferring local copies.

        Resolution order: ``model_name`` as a directory, known local cache
        directories, the hub cache resolved without network access, and
        finally a remote download.

        Returns:
            The loaded ``HuggingFaceEmbeddings`` instance.

        Raises:
            Exception: Re-raises whatever the final loading attempt raised. A
                failure to load an explicit local path or a resolved local
                snapshot is re-raised without falling back to a download.
        """
        model_kwargs = {"device": "cuda" if torch.cuda.is_available() else "cpu"}
        encode_kwargs = {"normalize_embeddings": True}

        def build(model_path: str, show_progress: bool) -> HuggingFaceEmbeddings:
            return HuggingFaceEmbeddings(
                model_name=model_path,
                model_kwargs=model_kwargs,
                encode_kwargs=encode_kwargs,
                show_progress=show_progress,
            )

        # Set once snapshot_download resolves a local snapshot; read by the
        # error handler below to decide whether a download fallback is allowed.
        local_dir: Optional[str] = None

        try:
            try:
                # An explicit local directory is used as-is.
                if os.path.isdir(self.model_name):
                    return build(self.model_name, False)

                # Best effort: any failure while probing the caches just means
                # there are no candidates to try.
                try:
                    candidates = self._find_cached_model_dirs()
                except Exception:
                    candidates = []
                for cand in candidates:
                    try:
                        return build(cand, False)
                    except Exception:
                        # Incomplete or incompatible copy: try the next one.
                        continue

                try:
                    # Resolve the cached snapshot directory without network access.
                    local_dir = snapshot_download(
                        repo_id=self.model_name, local_files_only=True
                    )
                except Exception:
                    local_dir = None

                if local_dir:
                    return build(local_dir, False)

                # Nothing usable on disk: download the model.
                return build(self.model_name, True)
            except Exception:
                # If a local path was identified but failed to load, abort
                # rather than silently going to the network.
                had_local_candidate = False
                try:
                    had_local_candidate = (
                        os.path.isdir(self.model_name) or local_dir is not None
                    )
                except Exception:
                    # isdir rejects non-path values; treat as "no local copy".
                    pass

                if had_local_candidate:
                    PrettyOutput.auto_print(
                        "❌ 检测到本地模型路径但加载失败。为避免触发网络访问,已中止远程回退。\n"
                        "请确认本地目录包含完整的 Transformers/Tokenizer 文件(如 config.json、model.safetensors、tokenizer.json/merges.txt 等),\n"
                        "或在配置中将 embedding_model 设置为该本地目录,或将模型放置到默认的 Hugging Face 缓存目录(例如 ~/.cache/huggingface/hub)。"
                    )
                    raise

                # No local copy was found: retry the remote download.
                return build(self.model_name, True)
        except Exception as e:
            PrettyOutput.auto_print(f"❌ 加载嵌入模型 '{self.model_name}' 时出错: {e}")
            PrettyOutput.auto_print(
                "⚠️ 请确保您已安装 'sentence_transformers' 和 'torch'。"
            )
            raise

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute embeddings for a list of documents, using the cache.

        Args:
            texts: Documents (strings) to embed.

        Returns:
            One embedding per document, in the order of ``texts``.
        """
        if not texts:
            return []

        # One entry per text; None marks a cache miss.
        cached_embeddings = self.cache.get_batch(texts)

        misses = [
            (index, text)
            for index, (text, cached) in enumerate(zip(texts, cached_embeddings))
            if cached is None
        ]

        if not misses:
            PrettyOutput.auto_print(
                f"✅ 缓存命中。所有 {len(texts)} 个文档的嵌入均从缓存中检索。"
            )
            return cast(List[List[float]], cached_embeddings)

        texts_to_embed = [text for _, text in misses]
        PrettyOutput.auto_print(
            f"ℹ️ 缓存未命中。正在为 {len(texts_to_embed)}/{len(texts)} 个文档计算嵌入。"
        )
        new_embeddings = self.model.embed_documents(texts_to_embed)

        # Persist the freshly computed embeddings for later calls.
        self.cache.set_batch(texts_to_embed, new_embeddings)

        # Fill the gaps so the result lines up with ``texts``.
        for (index, _), embedding in zip(misses, new_embeddings):
            cached_embeddings[index] = embedding

        return cast(List[List[float]], cached_embeddings)

    def embed_query(self, text: str) -> List[float]:
        """Compute the embedding of a single query; the result is not cached."""
        return self.model.embed_query(text)
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
"""
|
|
2
|
+
OpenAI 嵌入模型实现。
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import List
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
from .base import OnlineEmbeddingModel
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class OpenAIEmbeddingModel(OnlineEmbeddingModel):
    """Embedding model backed by the OpenAI embeddings API."""

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: str = "text-embedding-3-small",
        base_url: Optional[str] = None,
        **kwargs,
    ):
        """Initialise the OpenAI embedding model.

        Args:
            api_key: OpenAI API key. The parent class is given
                ``OPENAI_API_KEY`` as the environment variable name to use
                when this is None.
            model_name: Model to use (e.g. 'text-embedding-3-small').
            base_url: Base URL of the API, for custom endpoints.
            **kwargs: Extra arguments forwarded to the parent class.
        """
        super().__init__(
            api_key=api_key,
            api_key_env="OPENAI_API_KEY",
            base_url=base_url,
            model_name=model_name,
            **kwargs,
        )

    def _call_api(self, texts: List[str], is_query: bool = False) -> List[List[float]]:
        """Request embeddings for ``texts`` from the OpenAI API.

        Args:
            texts: Texts to embed.
            is_query: Not used by this implementation; the same request is
                sent for documents and queries.

        Returns:
            One embedding per item of the API response.

        Raises:
            ImportError: If the ``openai`` package is not installed.
            RuntimeError: If the API call fails for any other reason.
        """
        try:
            # Imported lazily so the package is only required when this
            # backend is actually used.
            from openai import OpenAI

            # presumably api_key/base_url/model_name are stored on the
            # instance by the parent constructor - verify against base.py
            client = OpenAI(api_key=self.api_key, base_url=self.base_url)

            response = client.embeddings.create(
                model=self.model_name,
                input=texts,
            )

            return [item.embedding for item in response.data]
        except ImportError as err:
            raise ImportError(
                "使用 OpenAIEmbeddingModel 需要安装 openai 包: pip install openai"
            ) from err
        except Exception as err:
            # Chain the cause so the original traceback stays attached.
            raise RuntimeError(f"调用 OpenAI API 时出错: {err}") from err
|
|
@@ -0,0 +1,293 @@
|
|
|
1
|
+
"""
|
|
2
|
+
嵌入模型注册表,支持动态加载自定义嵌入模型实现。
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import importlib
|
|
6
|
+
import inspect
|
|
7
|
+
|
|
8
|
+
from jarvis.jarvis_utils.output import PrettyOutput
|
|
9
|
+
|
|
10
|
+
# -*- coding: utf-8 -*-
|
|
11
|
+
import os
|
|
12
|
+
import sys
|
|
13
|
+
from typing import Dict
|
|
14
|
+
from typing import List
|
|
15
|
+
from typing import Optional
|
|
16
|
+
from typing import Type
|
|
17
|
+
|
|
18
|
+
from jarvis.jarvis_utils.config import get_data_dir
|
|
19
|
+
|
|
20
|
+
from ..embedding_interface import EmbeddingInterface
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class EmbeddingRegistry:
|
|
24
|
+
"""嵌入模型注册表,支持动态加载自定义嵌入模型实现"""
|
|
25
|
+
|
|
26
|
+
global_registry: Optional["EmbeddingRegistry"] = None
|
|
27
|
+
|
|
28
|
+
@staticmethod
def get_embedding_dir() -> str:
    """Return the directory that holds user-defined embedding models.

    The directory is ``<data dir>/embeddings``. When it does not exist yet it
    is created together with an empty ``__init__.py``.

    Returns:
        The directory path, or an empty string when it could not be created.
    """
    embedding_dir = os.path.join(get_data_dir(), "embeddings")
    if not os.path.exists(embedding_dir):
        try:
            # exist_ok: another process may create the directory between the
            # existence check above and this call.
            os.makedirs(embedding_dir, exist_ok=True)
            # Create __init__.py so the directory is a Python package.
            with open(
                os.path.join(embedding_dir, "__init__.py"), "w", errors="ignore"
            ):
                pass
        except Exception as e:
            PrettyOutput.auto_print(f"❌ 创建嵌入模型目录失败: {str(e)}")
            return ""
    return embedding_dir
|
|
44
|
+
|
|
45
|
+
@staticmethod
def check_embedding_implementation(
    embedding_class: Type[EmbeddingInterface],
) -> bool:
    """Check that an embedding class provides every required method.

    A method is reported when it is absent, not callable, or when its number
    of parameters (ignoring ``self``) differs from the expected one.

    Args:
        embedding_class: Embedding model class to inspect.

    Returns:
        True when all required methods are present with the expected number
        of parameters, otherwise False (a warning is printed).
    """
    expected_params = {
        "embed_documents": ("texts",),
        "embed_query": ("text",),
    }

    problems = []
    for method_name, param_names in expected_params.items():
        method = getattr(embedding_class, method_name, None)
        if not callable(method):
            # Covers both a missing attribute and a non-callable one.
            problems.append(method_name)
            continue

        # Only the parameter count is compared, not the names.
        declared = [
            param
            for param in inspect.signature(method).parameters
            if param != "self"
        ]
        if len(declared) != len(param_names):
            problems.append(f"{method_name}(parameter mismatch)")

    if not problems:
        return True

    PrettyOutput.auto_print(
        f"⚠️ 嵌入模型 {embedding_class.__name__} 缺少必要的方法: {', '.join(problems)}"
    )
    return False
|
|
87
|
+
|
|
88
|
+
@staticmethod
def load_embeddings_from_dir(
    directory: str,
) -> Dict[str, Type[EmbeddingInterface]]:
    """Load embedding model classes from the modules of a directory.

    Every ``*.py`` file whose name does not start with ``__`` is imported and
    scanned for subclasses of ``EmbeddingInterface`` that pass
    ``check_embedding_implementation``. Import and registration errors are
    collected and printed once at the end; they never abort the scan.

    Args:
        directory: Path of the directory to scan.

    Returns:
        Mapping from class name to embedding model class.
    """
    embeddings: Dict[str, Type[EmbeddingInterface]] = {}

    if not os.path.exists(directory):
        PrettyOutput.auto_print(f"⚠️ 嵌入模型目录不存在: {directory}")
        return embeddings

    package_name = None
    if directory == os.path.dirname(__file__):
        # Built-in implementations are imported by their qualified package
        # name, so the package directory is deliberately NOT put on sys.path:
        # doing so would expose its modules (openai, cohere, base, ...) as
        # top-level modules.
        package_name = "jarvis.jarvis_rag.embeddings"
    elif directory not in sys.path:
        # Other directories are imported by bare module name and therefore
        # have to be on sys.path.
        sys.path.append(directory)

    error_lines = []
    for filename in os.listdir(directory):
        if not filename.endswith(".py") or filename.startswith("__"):
            continue

        module_name = filename[:-3]  # strip the ".py" suffix
        try:
            if package_name:
                module = importlib.import_module(f"{package_name}.{module_name}")
            else:
                module = importlib.import_module(module_name)

            for _, obj in inspect.getmembers(module, inspect.isclass):
                # Keep subclasses of EmbeddingInterface, not the interface itself.
                if not issubclass(obj, EmbeddingInterface) or obj is EmbeddingInterface:
                    continue
                if not EmbeddingRegistry.check_embedding_implementation(obj):
                    continue
                try:
                    # The class name is the registration name.
                    embeddings[obj.__name__] = obj
                except Exception as e:
                    error_lines.append(f"注册嵌入模型失败 {obj.__name__}: {str(e)}")
        except Exception as e:
            # A broken module must not prevent the others from loading.
            error_lines.append(f"加载嵌入模型 {module_name} 失败: {str(e)}")

    if error_lines:
        joined_errors = "\n".join(error_lines)
        PrettyOutput.auto_print(f"❌ {joined_errors}")
    return embeddings
|
|
158
|
+
|
|
159
|
+
@staticmethod
def get_global_registry() -> "EmbeddingRegistry":
    """Return the shared registry, creating it on first use."""
    registry = EmbeddingRegistry.global_registry
    if registry is None:
        registry = EmbeddingRegistry()
        EmbeddingRegistry.global_registry = registry
    return registry
|
|
165
|
+
|
|
166
|
+
def __init__(self) -> None:
    """Build the registry from the user directory and the built-in package.

    The user directory is loaded first and the built-in directory second, so
    a built-in class replaces a user class registered under the same name.
    """
    self.embeddings: Dict[str, Type[EmbeddingInterface]] = {}

    source_dirs = (
        EmbeddingRegistry.get_embedding_dir(),  # user-defined implementations
        os.path.dirname(__file__),  # built-in implementations
    )
    for source_dir in source_dirs:
        # get_embedding_dir() returns "" when the directory is unavailable.
        if not source_dir or not os.path.exists(source_dir):
            continue
        loaded = EmbeddingRegistry.load_embeddings_from_dir(source_dir)
        for embedding_name, embedding_class in loaded.items():
            self.register_embedding(embedding_name, embedding_class)
|
|
187
|
+
|
|
188
|
+
def register_embedding(
    self, name: str, embedding_class: Type[EmbeddingInterface]
) -> None:
    """Store an embedding model class under the given name.

    An existing entry with the same name is replaced.

    Args:
        name: Name the embedding model is registered under.
        embedding_class: Embedding model class to associate with ``name``.
    """
    known_embeddings = self.embeddings
    known_embeddings[name] = embedding_class
|
|
198
|
+
|
|
199
|
+
def create_embedding(
    self, name: str, *args, **kwargs
) -> Optional[EmbeddingInterface]:
    """Instantiate the embedding model registered under ``name``.

    Args:
        name: Registered name of the embedding model.
        *args: Positional arguments forwarded to the model constructor.
        **kwargs: Keyword arguments forwarded to the model constructor.

    Returns:
        The new model instance, or ``None`` when the name is not registered
        or the constructor raises (a message is printed in both cases).
    """
    if name in self.embeddings:
        try:
            model_class = self.embeddings[name]
            return model_class(*args, **kwargs)
        except Exception as exc:
            # Construction is best-effort: report the failure and return None.
            PrettyOutput.auto_print(f"❌ 创建嵌入模型失败: {str(exc)}")
    else:
        PrettyOutput.auto_print(f"⚠️ 未找到嵌入模型: {name}")
    return None
|
|
222
|
+
|
|
223
|
+
def get_available_embeddings(self) -> List[str]:
    """Return the names of all registered embedding models as a new list."""
    registered_names: List[str] = [*self.embeddings.keys()]
    return registered_names
|
|
226
|
+
|
|
227
|
+
@staticmethod
def create_from_config() -> Optional[EmbeddingInterface]:
    """Create an embedding model instance from the configuration system.

    Reads the embedding type, model name and extra embedding config, turns
    them into constructor keyword arguments, and asks the global registry
    to build the model.

    Returns:
        The embedding model instance, or ``None`` if the registry could not
        create it.
    """
    from jarvis.jarvis_utils.config import (
        get_rag_embedding_cache_path,
        get_rag_embedding_config,
        get_rag_embedding_max_length,
        get_rag_embedding_model,
        get_rag_embedding_type,
    )

    embedding_type = get_rag_embedding_type()
    model_name = get_rag_embedding_model()
    embedding_config = get_rag_embedding_config()

    registry = EmbeddingRegistry.get_global_registry()

    # Start from the model name, then layer the user-supplied config on top.
    create_kwargs = {"model_name": model_name}
    create_kwargs.update(embedding_config)

    # Provider-specific config keys are folded into the generic constructor
    # parameter names. An explicitly configured generic key wins; otherwise
    # the first provider key found (in the order listed) supplies the value.
    # Every provider key is removed from the kwargs either way.
    alias_groups = (
        (
            "api_key",
            (
                "openai_api_key",
                "cohere_api_key",
                "edgefn_api_key",
                "jina_api_key",
            ),
        ),
        (
            "base_url",
            (
                "openai_api_base",
                "cohere_api_base",
                "edgefn_api_base",
                "jina_api_base",
            ),
        ),
    )
    for param_key, config_keys in alias_groups:
        for config_key in config_keys:
            if config_key not in create_kwargs:
                continue
            provider_value = create_kwargs.pop(config_key)
            if param_key not in create_kwargs:
                create_kwargs[param_key] = provider_value

    # Fall back to the configured default max_length (passed as a string)
    # when the embedding config did not specify one.
    if "max_length" not in create_kwargs:
        create_kwargs["max_length"] = str(get_rag_embedding_max_length())

    # LocalEmbeddingModel additionally receives the cache directory.
    if embedding_type == "LocalEmbeddingModel":
        create_kwargs["cache_dir"] = get_rag_embedding_cache_path()

    return registry.create_embedding(embedding_type, **create_kwargs)
|
|
@@ -1,8 +1,10 @@
|
|
|
1
|
-
from abc import ABC
|
|
1
|
+
from abc import ABC
|
|
2
|
+
from abc import abstractmethod
|
|
2
3
|
|
|
3
4
|
from jarvis.jarvis_agent import Agent as JarvisAgent
|
|
4
5
|
from jarvis.jarvis_platform.base import BasePlatform
|
|
5
6
|
from jarvis.jarvis_platform.registry import PlatformRegistry
|
|
7
|
+
from jarvis.jarvis_utils.output import PrettyOutput
|
|
6
8
|
|
|
7
9
|
|
|
8
10
|
class LLMInterface(ABC):
|
|
@@ -38,7 +40,7 @@ class ToolAgent_LLM(LLMInterface):
|
|
|
38
40
|
"""
|
|
39
41
|
初始化工具-代理 LLM 包装器。
|
|
40
42
|
"""
|
|
41
|
-
|
|
43
|
+
PrettyOutput.auto_print("ℹ️ 已初始化工具 Agent 作为最终应答者。")
|
|
42
44
|
self.allowed_tools = ["read_code", "execute_script"]
|
|
43
45
|
# 为代理提供一个通用的系统提示
|
|
44
46
|
self.system_prompt = "You are a helpful assistant. Please answer the user's question based on the provided context. You can use tools to find more information if needed."
|
|
@@ -78,7 +80,7 @@ class ToolAgent_LLM(LLMInterface):
|
|
|
78
80
|
return str(final_answer)
|
|
79
81
|
|
|
80
82
|
except Exception as e:
|
|
81
|
-
|
|
83
|
+
PrettyOutput.auto_print(f"❌ Agent 在执行过程中发生错误: {e}")
|
|
82
84
|
return "错误: Agent 未能成功生成回答。"
|
|
83
85
|
|
|
84
86
|
|
|
@@ -98,11 +100,11 @@ class JarvisPlatform_LLM(LLMInterface):
|
|
|
98
100
|
self.registry = PlatformRegistry.get_global_platform_registry()
|
|
99
101
|
self.platform: BasePlatform = self.registry.get_cheap_platform()
|
|
100
102
|
self.platform.set_suppress_output(False) # 确保模型没有控制台输出
|
|
101
|
-
|
|
103
|
+
PrettyOutput.auto_print(
|
|
102
104
|
f"ℹ️ 已初始化 Jarvis 平台 LLM(cheap),模型: {self.platform.name()}"
|
|
103
105
|
)
|
|
104
106
|
except Exception as e:
|
|
105
|
-
|
|
107
|
+
PrettyOutput.auto_print(f"❌ 初始化 Jarvis 平台 LLM 失败: {e}")
|
|
106
108
|
raise
|
|
107
109
|
|
|
108
110
|
def generate(self, prompt: str, **kwargs) -> str:
|
|
@@ -120,5 +122,5 @@ class JarvisPlatform_LLM(LLMInterface):
|
|
|
120
122
|
# 使用健壮的chat_until_success方法
|
|
121
123
|
return self.platform.chat_until_success(prompt)
|
|
122
124
|
except Exception as e:
|
|
123
|
-
|
|
125
|
+
PrettyOutput.auto_print(f"❌ 调用 Jarvis 平台模型时发生错误: {e}")
|
|
124
126
|
return "错误: 无法从本地LLM获取响应。"
|