jarvis-ai-assistant 0.7.8__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. jarvis/__init__.py +1 -1
  2. jarvis/jarvis_agent/__init__.py +567 -222
  3. jarvis/jarvis_agent/agent_manager.py +19 -12
  4. jarvis/jarvis_agent/builtin_input_handler.py +79 -11
  5. jarvis/jarvis_agent/config_editor.py +7 -2
  6. jarvis/jarvis_agent/event_bus.py +24 -13
  7. jarvis/jarvis_agent/events.py +19 -1
  8. jarvis/jarvis_agent/file_context_handler.py +67 -64
  9. jarvis/jarvis_agent/file_methodology_manager.py +38 -24
  10. jarvis/jarvis_agent/jarvis.py +186 -114
  11. jarvis/jarvis_agent/language_extractors/__init__.py +8 -1
  12. jarvis/jarvis_agent/language_extractors/c_extractor.py +7 -4
  13. jarvis/jarvis_agent/language_extractors/cpp_extractor.py +9 -4
  14. jarvis/jarvis_agent/language_extractors/go_extractor.py +7 -4
  15. jarvis/jarvis_agent/language_extractors/java_extractor.py +27 -20
  16. jarvis/jarvis_agent/language_extractors/javascript_extractor.py +22 -17
  17. jarvis/jarvis_agent/language_extractors/python_extractor.py +7 -4
  18. jarvis/jarvis_agent/language_extractors/rust_extractor.py +7 -4
  19. jarvis/jarvis_agent/language_extractors/typescript_extractor.py +22 -17
  20. jarvis/jarvis_agent/language_support_info.py +250 -219
  21. jarvis/jarvis_agent/main.py +19 -23
  22. jarvis/jarvis_agent/memory_manager.py +9 -6
  23. jarvis/jarvis_agent/methodology_share_manager.py +21 -15
  24. jarvis/jarvis_agent/output_handler.py +4 -2
  25. jarvis/jarvis_agent/prompt_builder.py +7 -6
  26. jarvis/jarvis_agent/prompt_manager.py +113 -8
  27. jarvis/jarvis_agent/prompts.py +317 -85
  28. jarvis/jarvis_agent/protocols.py +5 -2
  29. jarvis/jarvis_agent/run_loop.py +192 -32
  30. jarvis/jarvis_agent/session_manager.py +7 -3
  31. jarvis/jarvis_agent/share_manager.py +23 -13
  32. jarvis/jarvis_agent/shell_input_handler.py +12 -8
  33. jarvis/jarvis_agent/stdio_redirect.py +25 -26
  34. jarvis/jarvis_agent/task_analyzer.py +29 -23
  35. jarvis/jarvis_agent/task_list.py +869 -0
  36. jarvis/jarvis_agent/task_manager.py +26 -23
  37. jarvis/jarvis_agent/tool_executor.py +6 -5
  38. jarvis/jarvis_agent/tool_share_manager.py +24 -14
  39. jarvis/jarvis_agent/user_interaction.py +3 -3
  40. jarvis/jarvis_agent/utils.py +9 -1
  41. jarvis/jarvis_agent/web_bridge.py +37 -17
  42. jarvis/jarvis_agent/web_output_sink.py +5 -2
  43. jarvis/jarvis_agent/web_server.py +165 -36
  44. jarvis/jarvis_c2rust/__init__.py +1 -1
  45. jarvis/jarvis_c2rust/cli.py +260 -141
  46. jarvis/jarvis_c2rust/collector.py +37 -18
  47. jarvis/jarvis_c2rust/constants.py +60 -0
  48. jarvis/jarvis_c2rust/library_replacer.py +242 -1010
  49. jarvis/jarvis_c2rust/library_replacer_checkpoint.py +133 -0
  50. jarvis/jarvis_c2rust/library_replacer_llm.py +287 -0
  51. jarvis/jarvis_c2rust/library_replacer_loader.py +191 -0
  52. jarvis/jarvis_c2rust/library_replacer_output.py +134 -0
  53. jarvis/jarvis_c2rust/library_replacer_prompts.py +124 -0
  54. jarvis/jarvis_c2rust/library_replacer_utils.py +188 -0
  55. jarvis/jarvis_c2rust/llm_module_agent.py +98 -1044
  56. jarvis/jarvis_c2rust/llm_module_agent_apply.py +170 -0
  57. jarvis/jarvis_c2rust/llm_module_agent_executor.py +288 -0
  58. jarvis/jarvis_c2rust/llm_module_agent_loader.py +170 -0
  59. jarvis/jarvis_c2rust/llm_module_agent_prompts.py +268 -0
  60. jarvis/jarvis_c2rust/llm_module_agent_types.py +57 -0
  61. jarvis/jarvis_c2rust/llm_module_agent_utils.py +150 -0
  62. jarvis/jarvis_c2rust/llm_module_agent_validator.py +119 -0
  63. jarvis/jarvis_c2rust/loaders.py +28 -10
  64. jarvis/jarvis_c2rust/models.py +5 -2
  65. jarvis/jarvis_c2rust/optimizer.py +192 -1974
  66. jarvis/jarvis_c2rust/optimizer_build_fix.py +286 -0
  67. jarvis/jarvis_c2rust/optimizer_clippy.py +766 -0
  68. jarvis/jarvis_c2rust/optimizer_config.py +49 -0
  69. jarvis/jarvis_c2rust/optimizer_docs.py +183 -0
  70. jarvis/jarvis_c2rust/optimizer_options.py +48 -0
  71. jarvis/jarvis_c2rust/optimizer_progress.py +469 -0
  72. jarvis/jarvis_c2rust/optimizer_report.py +52 -0
  73. jarvis/jarvis_c2rust/optimizer_unsafe.py +309 -0
  74. jarvis/jarvis_c2rust/optimizer_utils.py +469 -0
  75. jarvis/jarvis_c2rust/optimizer_visibility.py +185 -0
  76. jarvis/jarvis_c2rust/scanner.py +229 -166
  77. jarvis/jarvis_c2rust/transpiler.py +531 -2732
  78. jarvis/jarvis_c2rust/transpiler_agents.py +503 -0
  79. jarvis/jarvis_c2rust/transpiler_build.py +1294 -0
  80. jarvis/jarvis_c2rust/transpiler_codegen.py +204 -0
  81. jarvis/jarvis_c2rust/transpiler_compile.py +146 -0
  82. jarvis/jarvis_c2rust/transpiler_config.py +178 -0
  83. jarvis/jarvis_c2rust/transpiler_context.py +122 -0
  84. jarvis/jarvis_c2rust/transpiler_executor.py +516 -0
  85. jarvis/jarvis_c2rust/transpiler_generation.py +278 -0
  86. jarvis/jarvis_c2rust/transpiler_git.py +163 -0
  87. jarvis/jarvis_c2rust/transpiler_mod_utils.py +225 -0
  88. jarvis/jarvis_c2rust/transpiler_modules.py +336 -0
  89. jarvis/jarvis_c2rust/transpiler_planning.py +394 -0
  90. jarvis/jarvis_c2rust/transpiler_review.py +1196 -0
  91. jarvis/jarvis_c2rust/transpiler_symbols.py +176 -0
  92. jarvis/jarvis_c2rust/utils.py +269 -79
  93. jarvis/jarvis_code_agent/after_change.py +233 -0
  94. jarvis/jarvis_code_agent/build_validation_config.py +37 -30
  95. jarvis/jarvis_code_agent/builtin_rules.py +68 -0
  96. jarvis/jarvis_code_agent/code_agent.py +976 -1517
  97. jarvis/jarvis_code_agent/code_agent_build.py +227 -0
  98. jarvis/jarvis_code_agent/code_agent_diff.py +246 -0
  99. jarvis/jarvis_code_agent/code_agent_git.py +525 -0
  100. jarvis/jarvis_code_agent/code_agent_impact.py +177 -0
  101. jarvis/jarvis_code_agent/code_agent_lint.py +283 -0
  102. jarvis/jarvis_code_agent/code_agent_llm.py +159 -0
  103. jarvis/jarvis_code_agent/code_agent_postprocess.py +105 -0
  104. jarvis/jarvis_code_agent/code_agent_prompts.py +46 -0
  105. jarvis/jarvis_code_agent/code_agent_rules.py +305 -0
  106. jarvis/jarvis_code_agent/code_analyzer/__init__.py +52 -48
  107. jarvis/jarvis_code_agent/code_analyzer/base_language.py +12 -10
  108. jarvis/jarvis_code_agent/code_analyzer/build_validator/__init__.py +12 -11
  109. jarvis/jarvis_code_agent/code_analyzer/build_validator/base.py +16 -12
  110. jarvis/jarvis_code_agent/code_analyzer/build_validator/cmake.py +26 -17
  111. jarvis/jarvis_code_agent/code_analyzer/build_validator/detector.py +558 -104
  112. jarvis/jarvis_code_agent/code_analyzer/build_validator/fallback.py +27 -16
  113. jarvis/jarvis_code_agent/code_analyzer/build_validator/go.py +22 -18
  114. jarvis/jarvis_code_agent/code_analyzer/build_validator/java_gradle.py +21 -16
  115. jarvis/jarvis_code_agent/code_analyzer/build_validator/java_maven.py +20 -16
  116. jarvis/jarvis_code_agent/code_analyzer/build_validator/makefile.py +27 -16
  117. jarvis/jarvis_code_agent/code_analyzer/build_validator/nodejs.py +47 -23
  118. jarvis/jarvis_code_agent/code_analyzer/build_validator/python.py +71 -37
  119. jarvis/jarvis_code_agent/code_analyzer/build_validator/rust.py +162 -35
  120. jarvis/jarvis_code_agent/code_analyzer/build_validator/validator.py +111 -57
  121. jarvis/jarvis_code_agent/code_analyzer/build_validator.py +18 -12
  122. jarvis/jarvis_code_agent/code_analyzer/context_manager.py +185 -183
  123. jarvis/jarvis_code_agent/code_analyzer/context_recommender.py +2 -1
  124. jarvis/jarvis_code_agent/code_analyzer/dependency_analyzer.py +24 -15
  125. jarvis/jarvis_code_agent/code_analyzer/file_ignore.py +227 -141
  126. jarvis/jarvis_code_agent/code_analyzer/impact_analyzer.py +321 -247
  127. jarvis/jarvis_code_agent/code_analyzer/language_registry.py +37 -29
  128. jarvis/jarvis_code_agent/code_analyzer/language_support.py +21 -13
  129. jarvis/jarvis_code_agent/code_analyzer/languages/__init__.py +15 -9
  130. jarvis/jarvis_code_agent/code_analyzer/languages/c_cpp_language.py +75 -45
  131. jarvis/jarvis_code_agent/code_analyzer/languages/go_language.py +87 -52
  132. jarvis/jarvis_code_agent/code_analyzer/languages/java_language.py +84 -51
  133. jarvis/jarvis_code_agent/code_analyzer/languages/javascript_language.py +94 -64
  134. jarvis/jarvis_code_agent/code_analyzer/languages/python_language.py +109 -71
  135. jarvis/jarvis_code_agent/code_analyzer/languages/rust_language.py +97 -63
  136. jarvis/jarvis_code_agent/code_analyzer/languages/typescript_language.py +103 -69
  137. jarvis/jarvis_code_agent/code_analyzer/llm_context_recommender.py +271 -268
  138. jarvis/jarvis_code_agent/code_analyzer/symbol_extractor.py +76 -64
  139. jarvis/jarvis_code_agent/code_analyzer/tree_sitter_extractor.py +92 -19
  140. jarvis/jarvis_code_agent/diff_visualizer.py +998 -0
  141. jarvis/jarvis_code_agent/lint.py +223 -524
  142. jarvis/jarvis_code_agent/rule_share_manager.py +158 -0
  143. jarvis/jarvis_code_agent/rules/clean_code.md +144 -0
  144. jarvis/jarvis_code_agent/rules/code_review.md +115 -0
  145. jarvis/jarvis_code_agent/rules/documentation.md +165 -0
  146. jarvis/jarvis_code_agent/rules/generate_rules.md +52 -0
  147. jarvis/jarvis_code_agent/rules/performance.md +158 -0
  148. jarvis/jarvis_code_agent/rules/refactoring.md +139 -0
  149. jarvis/jarvis_code_agent/rules/security.md +160 -0
  150. jarvis/jarvis_code_agent/rules/tdd.md +78 -0
  151. jarvis/jarvis_code_agent/test_rules/cpp_test.md +118 -0
  152. jarvis/jarvis_code_agent/test_rules/go_test.md +98 -0
  153. jarvis/jarvis_code_agent/test_rules/java_test.md +99 -0
  154. jarvis/jarvis_code_agent/test_rules/javascript_test.md +113 -0
  155. jarvis/jarvis_code_agent/test_rules/php_test.md +117 -0
  156. jarvis/jarvis_code_agent/test_rules/python_test.md +91 -0
  157. jarvis/jarvis_code_agent/test_rules/ruby_test.md +102 -0
  158. jarvis/jarvis_code_agent/test_rules/rust_test.md +86 -0
  159. jarvis/jarvis_code_agent/utils.py +36 -26
  160. jarvis/jarvis_code_analysis/checklists/loader.py +21 -21
  161. jarvis/jarvis_code_analysis/code_review.py +64 -33
  162. jarvis/jarvis_data/config_schema.json +285 -192
  163. jarvis/jarvis_git_squash/main.py +8 -6
  164. jarvis/jarvis_git_utils/git_commiter.py +53 -76
  165. jarvis/jarvis_mcp/__init__.py +5 -2
  166. jarvis/jarvis_mcp/sse_mcp_client.py +40 -30
  167. jarvis/jarvis_mcp/stdio_mcp_client.py +27 -19
  168. jarvis/jarvis_mcp/streamable_mcp_client.py +35 -26
  169. jarvis/jarvis_memory_organizer/memory_organizer.py +78 -55
  170. jarvis/jarvis_methodology/main.py +48 -39
  171. jarvis/jarvis_multi_agent/__init__.py +56 -23
  172. jarvis/jarvis_multi_agent/main.py +15 -18
  173. jarvis/jarvis_platform/base.py +179 -111
  174. jarvis/jarvis_platform/human.py +27 -16
  175. jarvis/jarvis_platform/kimi.py +52 -45
  176. jarvis/jarvis_platform/openai.py +101 -40
  177. jarvis/jarvis_platform/registry.py +51 -33
  178. jarvis/jarvis_platform/tongyi.py +68 -38
  179. jarvis/jarvis_platform/yuanbao.py +59 -43
  180. jarvis/jarvis_platform_manager/main.py +68 -76
  181. jarvis/jarvis_platform_manager/service.py +24 -14
  182. jarvis/jarvis_rag/README_CONFIG.md +314 -0
  183. jarvis/jarvis_rag/README_DYNAMIC_LOADING.md +311 -0
  184. jarvis/jarvis_rag/README_ONLINE_MODELS.md +230 -0
  185. jarvis/jarvis_rag/__init__.py +57 -4
  186. jarvis/jarvis_rag/cache.py +3 -1
  187. jarvis/jarvis_rag/cli.py +48 -68
  188. jarvis/jarvis_rag/embedding_interface.py +39 -0
  189. jarvis/jarvis_rag/embedding_manager.py +7 -230
  190. jarvis/jarvis_rag/embeddings/__init__.py +41 -0
  191. jarvis/jarvis_rag/embeddings/base.py +114 -0
  192. jarvis/jarvis_rag/embeddings/cohere.py +66 -0
  193. jarvis/jarvis_rag/embeddings/edgefn.py +117 -0
  194. jarvis/jarvis_rag/embeddings/local.py +260 -0
  195. jarvis/jarvis_rag/embeddings/openai.py +62 -0
  196. jarvis/jarvis_rag/embeddings/registry.py +293 -0
  197. jarvis/jarvis_rag/llm_interface.py +8 -6
  198. jarvis/jarvis_rag/query_rewriter.py +8 -9
  199. jarvis/jarvis_rag/rag_pipeline.py +61 -52
  200. jarvis/jarvis_rag/reranker.py +7 -75
  201. jarvis/jarvis_rag/reranker_interface.py +32 -0
  202. jarvis/jarvis_rag/rerankers/__init__.py +41 -0
  203. jarvis/jarvis_rag/rerankers/base.py +109 -0
  204. jarvis/jarvis_rag/rerankers/cohere.py +67 -0
  205. jarvis/jarvis_rag/rerankers/edgefn.py +140 -0
  206. jarvis/jarvis_rag/rerankers/jina.py +79 -0
  207. jarvis/jarvis_rag/rerankers/local.py +89 -0
  208. jarvis/jarvis_rag/rerankers/registry.py +293 -0
  209. jarvis/jarvis_rag/retriever.py +58 -43
  210. jarvis/jarvis_sec/__init__.py +66 -141
  211. jarvis/jarvis_sec/agents.py +21 -17
  212. jarvis/jarvis_sec/analysis.py +80 -33
  213. jarvis/jarvis_sec/checkers/__init__.py +7 -13
  214. jarvis/jarvis_sec/checkers/c_checker.py +356 -164
  215. jarvis/jarvis_sec/checkers/rust_checker.py +47 -29
  216. jarvis/jarvis_sec/cli.py +43 -21
  217. jarvis/jarvis_sec/clustering.py +430 -272
  218. jarvis/jarvis_sec/file_manager.py +99 -55
  219. jarvis/jarvis_sec/parsers.py +9 -6
  220. jarvis/jarvis_sec/prompts.py +4 -3
  221. jarvis/jarvis_sec/report.py +44 -22
  222. jarvis/jarvis_sec/review.py +180 -107
  223. jarvis/jarvis_sec/status.py +50 -41
  224. jarvis/jarvis_sec/types.py +3 -0
  225. jarvis/jarvis_sec/utils.py +160 -83
  226. jarvis/jarvis_sec/verification.py +411 -181
  227. jarvis/jarvis_sec/workflow.py +132 -21
  228. jarvis/jarvis_smart_shell/main.py +28 -41
  229. jarvis/jarvis_stats/cli.py +14 -12
  230. jarvis/jarvis_stats/stats.py +28 -19
  231. jarvis/jarvis_stats/storage.py +14 -8
  232. jarvis/jarvis_stats/visualizer.py +12 -7
  233. jarvis/jarvis_tools/base.py +5 -2
  234. jarvis/jarvis_tools/clear_memory.py +13 -9
  235. jarvis/jarvis_tools/cli/main.py +23 -18
  236. jarvis/jarvis_tools/edit_file.py +572 -873
  237. jarvis/jarvis_tools/execute_script.py +10 -7
  238. jarvis/jarvis_tools/file_analyzer.py +7 -8
  239. jarvis/jarvis_tools/meta_agent.py +287 -0
  240. jarvis/jarvis_tools/methodology.py +5 -3
  241. jarvis/jarvis_tools/read_code.py +305 -1438
  242. jarvis/jarvis_tools/read_symbols.py +50 -17
  243. jarvis/jarvis_tools/read_webpage.py +19 -18
  244. jarvis/jarvis_tools/registry.py +435 -156
  245. jarvis/jarvis_tools/retrieve_memory.py +16 -11
  246. jarvis/jarvis_tools/save_memory.py +8 -6
  247. jarvis/jarvis_tools/search_web.py +31 -31
  248. jarvis/jarvis_tools/sub_agent.py +32 -28
  249. jarvis/jarvis_tools/sub_code_agent.py +44 -60
  250. jarvis/jarvis_tools/task_list_manager.py +1811 -0
  251. jarvis/jarvis_tools/virtual_tty.py +29 -19
  252. jarvis/jarvis_utils/__init__.py +4 -0
  253. jarvis/jarvis_utils/builtin_replace_map.py +2 -1
  254. jarvis/jarvis_utils/clipboard.py +9 -8
  255. jarvis/jarvis_utils/collections.py +331 -0
  256. jarvis/jarvis_utils/config.py +699 -194
  257. jarvis/jarvis_utils/dialogue_recorder.py +294 -0
  258. jarvis/jarvis_utils/embedding.py +6 -3
  259. jarvis/jarvis_utils/file_processors.py +7 -1
  260. jarvis/jarvis_utils/fzf.py +9 -3
  261. jarvis/jarvis_utils/git_utils.py +71 -42
  262. jarvis/jarvis_utils/globals.py +116 -32
  263. jarvis/jarvis_utils/http.py +6 -2
  264. jarvis/jarvis_utils/input.py +318 -83
  265. jarvis/jarvis_utils/jsonnet_compat.py +119 -104
  266. jarvis/jarvis_utils/methodology.py +37 -28
  267. jarvis/jarvis_utils/output.py +201 -44
  268. jarvis/jarvis_utils/utils.py +986 -628
  269. {jarvis_ai_assistant-0.7.8.dist-info → jarvis_ai_assistant-1.0.2.dist-info}/METADATA +49 -33
  270. jarvis_ai_assistant-1.0.2.dist-info/RECORD +304 -0
  271. jarvis/jarvis_code_agent/code_analyzer/structured_code.py +0 -556
  272. jarvis/jarvis_tools/generate_new_tool.py +0 -205
  273. jarvis/jarvis_tools/lsp_client.py +0 -1552
  274. jarvis/jarvis_tools/rewrite_file.py +0 -105
  275. jarvis_ai_assistant-0.7.8.dist-info/RECORD +0 -218
  276. {jarvis_ai_assistant-0.7.8.dist-info → jarvis_ai_assistant-1.0.2.dist-info}/WHEEL +0 -0
  277. {jarvis_ai_assistant-0.7.8.dist-info → jarvis_ai_assistant-1.0.2.dist-info}/entry_points.txt +0 -0
  278. {jarvis_ai_assistant-0.7.8.dist-info → jarvis_ai_assistant-1.0.2.dist-info}/licenses/LICENSE +0 -0
  279. {jarvis_ai_assistant-0.7.8.dist-info → jarvis_ai_assistant-1.0.2.dist-info}/top_level.txt +0 -0
@@ -1,233 +1,10 @@
1
- import torch
2
- import os
3
- from typing import List, cast
4
- from langchain_huggingface import HuggingFaceEmbeddings
5
- from huggingface_hub import snapshot_download
1
+ """
2
+ 向后兼容模块:保持 EmbeddingManager 的导入路径。
6
3
 
4
+ 新的实现已移动到 embeddings/ 目录。
5
+ """
7
6
 
8
- from .cache import EmbeddingCache
7
+ from .embeddings.local import LocalEmbeddingModel
9
8
 
10
-
# NOTE(review): removed in 1.0.2 — EmbeddingManager is now an alias of LocalEmbeddingModel imported from .embeddings.local.
11
- class EmbeddingManager:
12
- """
13
- Manages loading and use of a local embedding model, with caching.
14
-
15
- Loads the specified model from Hugging Face and uses a disk-based cache
16
- to avoid recomputing embeddings for identical text.
17
- """
18
-
19
- def __init__(self, model_name: str, cache_dir: str):
20
- """
21
- Initialize the EmbeddingManager.
22
-
23
- Args:
24
- model_name: Name of the Hugging Face model to load.
25
- cache_dir: Directory used to store the embedding cache.
26
- """
27
- self.model_name = model_name
28
-
29
- print(
30
- f"ℹ️ 初始化嵌入管理器, 模型: '{self.model_name}'..."
31
- )
32
-
33
- # The cache salt is the model name, to prevent collisions between models
34
- self.cache = EmbeddingCache(cache_dir=cache_dir, salt=self.model_name)
35
- self.model = self._load_model()
36
-
37
- def _load_model(self) -> HuggingFaceEmbeddings:
38
- """Load the Hugging Face embedding model according to configuration."""
39
- model_kwargs = {"device": "cuda" if torch.cuda.is_available() else "cpu"}
40
- encode_kwargs = {"normalize_embeddings": True}
41
-
42
- try:
43
- # First try to load model from local cache without any network access
44
- try:
45
- local_dir = None
46
- # Prefer explicit local dir via env or direct path
47
-
48
- if os.path.isdir(self.model_name):
49
- return HuggingFaceEmbeddings(
50
- model_name=self.model_name,
51
- model_kwargs=model_kwargs,
52
- encode_kwargs=encode_kwargs,
53
- show_progress=False,
54
- )
55
-
56
- # Try common local cache directories for sentence-transformers and HF hub
57
- try:
58
- home = os.path.expanduser("~")
59
- st_home = os.path.join(home, ".cache", "sentence_transformers")
60
- torch_st_home = os.path.join(home, ".cache", "torch", "sentence_transformers")
61
- # Build common name variants found in local caches
62
- org, name = (
63
- self.model_name.split("/", 1)
64
- if "/" in self.model_name
65
- else ("", self.model_name)
66
- )
67
- san1 = self.model_name.replace("/", "_")
68
- san2 = self.model_name.replace("/", "__")
69
- san3 = self.model_name.replace("/", "--")
70
- # include plain 'name' for caches that drop org prefix
71
- name_variants = list(dict.fromkeys([self.model_name, san1, san2, san3, name]))
72
- candidates = []
73
- for base in [st_home, torch_st_home]:
74
- for nv in name_variants:
75
- p = os.path.join(base, nv)
76
- if os.path.isdir(p):
77
- candidates.append(p)
78
- # Fuzzy scan cache directory for entries that include variants
79
- try:
80
- for entry in os.listdir(base):
81
- ep = os.path.join(base, entry)
82
- if not os.path.isdir(ep):
83
- continue
84
- if (
85
- (org and entry.startswith(f"{org}__") and name in entry)
86
- or (san1 in entry)
87
- or (name in entry)
88
- ):
89
- candidates.append(ep)
90
- except Exception:
91
- pass
92
-
93
- # Hugging Face Hub cache snapshots
94
- hf_cache = os.path.join(home, ".cache", "huggingface", "hub")
95
- if "/" in self.model_name:
96
- org, name = self.model_name.split("/", 1)
97
- models_dir = os.path.join(hf_cache, f"models--{org}--{name}", "snapshots")
98
- if os.path.isdir(models_dir):
99
- try:
100
- snaps = sorted(
101
- [os.path.join(models_dir, d) for d in os.listdir(models_dir)],
102
- key=lambda p: os.path.getmtime(p),
103
- reverse=True,
104
- )
105
- except Exception:
106
- snaps = [os.path.join(models_dir, d) for d in os.listdir(models_dir)]
107
- for sp in snaps:
108
- if os.path.isdir(sp):
109
- candidates.append(sp)
110
- break
111
-
112
- for cand in candidates:
113
- try:
114
- return HuggingFaceEmbeddings(
115
- model_name=cand,
116
- model_kwargs=model_kwargs,
117
- encode_kwargs=encode_kwargs,
118
- show_progress=False,
119
- )
120
- except Exception:
121
- continue
122
- except Exception:
123
- pass
124
-
125
- try:
126
- # Try resolve local cached directory; do not hit network
127
- local_dir = snapshot_download(repo_id=self.model_name, local_files_only=True)
128
- except Exception:
129
- local_dir = None
130
-
131
- if local_dir:
132
- return HuggingFaceEmbeddings(
133
- model_name=local_dir,
134
- model_kwargs=model_kwargs,
135
- encode_kwargs=encode_kwargs,
136
- show_progress=False,
137
- )
138
-
139
-
140
-
141
- # Fall back to remote download if local cache not found and not offline
142
- return HuggingFaceEmbeddings(
143
- model_name=self.model_name,
144
- model_kwargs=model_kwargs,
145
- encode_kwargs=encode_kwargs,
146
- show_progress=True,
147
- )
148
- except Exception as _e:
149
- # If a local candidate path was detected (direct directory / local cache snapshot), treat this as a local load failure;
150
- # to avoid network access when the user expects local-first offline loading, raise directly with a remediation hint.
151
- had_local_candidate = False
152
- try:
153
- had_local_candidate = (
154
- os.path.isdir(self.model_name)
155
- # If snapshot_download above hit the local cache, local_dir was set to a non-None value
156
- or (locals().get("local_dir") is not None)
157
- )
158
- except Exception:
159
- pass
160
-
161
- if had_local_candidate:
162
- print(
163
- "❌ 检测到本地模型路径但加载失败。为避免触发网络访问,已中止远程回退。\n"
164
- "请确认本地目录包含完整的 Transformers/Tokenizer 文件(如 config.json、model.safetensors、tokenizer.json/merges.txt 等),\n"
165
- "或在配置中将 embedding_model 设置为该本地目录,或将模型放置到默认的 Hugging Face 缓存目录(例如 ~/.cache/huggingface/hub)。"
166
- )
167
- raise
168
-
169
- # No local candidate found: keep the original behavior and fall back to remote download
170
- return HuggingFaceEmbeddings(
171
- model_name=self.model_name,
172
- model_kwargs=model_kwargs,
173
- encode_kwargs=encode_kwargs,
174
- show_progress=True,
175
- )
176
- except Exception as e:
177
- print(
178
- f"❌ 加载嵌入模型 '{self.model_name}' 时出错: {e}"
179
- )
180
- print(
181
- "⚠️ 请确保您已安装 'sentence_transformers' 和 'torch'。"
182
- )
183
- raise
184
-
185
- def embed_documents(self, texts: List[str]) -> List[List[float]]:
186
- """
187
- Compute embeddings for a list of documents, using the cache.
188
-
189
- Args:
190
- texts: List of documents (strings) to embed.
191
-
192
- Returns:
193
- A list of embeddings, one per document.
194
- """
195
- if not texts:
196
- return []
197
-
198
- # Look up which embeddings already exist in the cache
199
- cached_embeddings = self.cache.get_batch(texts)
200
-
201
- texts_to_embed = []
202
- indices_to_embed = []
203
- for i, (text, cached) in enumerate(zip(texts, cached_embeddings)):
204
- if cached is None:
205
- texts_to_embed.append(text)
206
- indices_to_embed.append(i)
207
-
208
- # Compute embeddings for texts missing from the cache
209
- if texts_to_embed:
210
- print(
211
- f"ℹ️ 缓存未命中。正在为 {len(texts_to_embed)}/{len(texts)} 个文档计算嵌入。"
212
- )
213
- new_embeddings = self.model.embed_documents(texts_to_embed)
214
-
215
- # Store the new embeddings in the cache
216
- self.cache.set_batch(texts_to_embed, new_embeddings)
217
-
218
- # Put the new embeddings back into their slots in the result list
219
- for i, embedding in zip(indices_to_embed, new_embeddings):
220
- cached_embeddings[i] = embedding
221
- else:
222
- print(
223
- f"✅ 缓存命中。所有 {len(texts)} 个文档的嵌入均从缓存中检索。"
224
- )
225
-
226
- return cast(List[List[float]], cached_embeddings)
227
-
228
- def embed_query(self, text: str) -> List[float]:
229
- """
230
- Compute the embedding for a single query.
231
- Queries are not cached here; caching could be added if needed.
232
- """
233
- return self.model.embed_query(text)
9
+ # Backward compatibility: keep EmbeddingManager as an alias of LocalEmbeddingModel
10
+ EmbeddingManager = LocalEmbeddingModel
@@ -0,0 +1,41 @@
"""
Embedding model implementations.

Provides local and online embedding models and supports dynamically loading
custom models through the registry.
"""

from .base import OnlineEmbeddingModel  # noqa: F401
from .local import LocalEmbeddingModel
from .registry import EmbeddingRegistry

# Backward-compatible alias.
EmbeddingManager = LocalEmbeddingModel

_base_exports = [
    "OnlineEmbeddingModel",
    "LocalEmbeddingModel",
    "EmbeddingManager",  # backward compatibility
    "EmbeddingRegistry",
]

# Online model implementations are optional. Each one is imported on its own
# so that an ImportError in one backend does not hide the others from
# ``__all__`` (best-effort: a missing backend is simply not exported).
try:
    from .cohere import CohereEmbeddingModel  # noqa: F401
except ImportError:
    pass
try:
    from .edgefn import EdgeFnEmbeddingModel  # noqa: F401
except ImportError:
    pass
try:
    from .openai import OpenAIEmbeddingModel  # noqa: F401
except ImportError:
    pass

# Export only the optional models that were actually imported.
_base_exports += [
    name
    for name in (
        "OpenAIEmbeddingModel",
        "CohereEmbeddingModel",
        "EdgeFnEmbeddingModel",
    )
    if name in globals()
]

# Dynamically loaded models (via the registry).
_registry = EmbeddingRegistry.get_global_registry()
_dynamic_exports = _registry.get_available_embeddings()

# De-duplicate while keeping order, in case the registry also reports names
# already listed above. NOTE(review): assumes the registry yields name strings.
__all__ = list(dict.fromkeys([*_base_exports, *_dynamic_exports]))
@@ -0,0 +1,114 @@
"""
Base class implementation for online embedding models.
"""

import os
from typing import List
from typing import Optional

from ..embedding_interface import EmbeddingInterface


class OnlineEmbeddingModel(EmbeddingInterface):
    """
    Base class for embedding models served by a remote API.

    Resolves the API key, stores the shared configuration, and batches
    ``embed_documents`` calls. Subclasses implement ``_call_api`` with the
    provider-specific request logic.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_key_env: Optional[str] = None,
        base_url: Optional[str] = None,
        model_name: str = "text-embedding-3-small",
        batch_size: int = 100,
        max_length: Optional[int] = None,
        **kwargs,
    ):
        """
        Initialize the online embedding model.

        Args:
            api_key: API key. When falsy, it is read from ``api_key_env``.
            api_key_env: Name of the environment variable holding the API key.
            base_url: Base URL of the API.
            model_name: Name of the model to use.
            batch_size: Number of texts sent per API call by ``embed_documents``.
            max_length: Maximum model input length (tokens), intended for
                document splitting; stored only.
            **kwargs: Extra configuration (e.g. from ``embedding_config``);
                accepted for compatibility and not used by this base class.

        Raises:
            ValueError: If no API key is given and none is found in the
                environment variable ``api_key_env``.
        """
        # ``api_key`` and ``base_url`` are named parameters, so they can never
        # arrive through ``**kwargs``: the explicit argument wins, then the env var.
        self.api_key = api_key or (os.getenv(api_key_env) if api_key_env else None)
        if not self.api_key:
            if api_key_env:
                raise ValueError(
                    f"API密钥未提供。请通过api_key参数、embedding_config配置或环境变量{api_key_env}提供。"
                )
            # No env var configured: avoid rendering "环境变量None" in the message.
            raise ValueError("API密钥未提供。请通过api_key参数或embedding_config配置提供。")

        self.base_url = base_url
        self.model_name = model_name
        self.batch_size = batch_size
        self.max_length = max_length

    def _call_api(self, texts: List[str], is_query: bool = False) -> List[List[float]]:
        """
        Call the online API to obtain embeddings.

        Args:
            texts: Texts to embed.
            is_query: Whether the texts are queries (some APIs use different
                modes or endpoints for queries and documents).

        Returns:
            One embedding vector per input text.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError("子类必须实现 _call_api 方法")

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Compute embeddings for a list of documents.

        Texts are sent to ``_call_api`` in chunks of ``self.batch_size``.

        Args:
            texts: Documents (strings) to embed.

        Returns:
            A list of embeddings, one per document, in input order.

        Raises:
            ValueError: If ``batch_size`` is not a positive integer.
        """
        if not texts:
            return []
        if self.batch_size < 1:
            # range() with step 0 raises, and a negative step yields nothing,
            # which would silently return no embeddings at all.
            raise ValueError(f"batch_size 必须为正整数,当前为: {self.batch_size}")

        embeddings: List[List[float]] = []
        for start in range(0, len(texts), self.batch_size):
            batch = texts[start : start + self.batch_size]
            embeddings.extend(self._call_api(batch, is_query=False))
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """
        Compute the embedding for a single query.

        Args:
            text: Query text to embed.

        Returns:
            The query's embedding vector, or an empty list if the API returned nothing.
        """
        embeddings = self._call_api([text], is_query=True)
        return embeddings[0] if embeddings else []
@@ -0,0 +1,66 @@
"""
Cohere embedding model implementation.
"""

from typing import List
from typing import Optional
from typing import cast

from .base import OnlineEmbeddingModel


class CohereEmbeddingModel(OnlineEmbeddingModel):
    """
    Embedding model backed by Cohere's embed API.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: str = "embed-english-v3.0",
        input_type: str = "search_document",
        **kwargs,
    ):
        """
        Initialize the Cohere embedding model.

        Args:
            api_key: Cohere API key. When falsy, read from ``COHERE_API_KEY``.
            model_name: Name of the model to use.
            input_type: Cohere ``input_type`` used when embedding documents
                (e.g. ``'search_document'``); queries always use ``'search_query'``.
            **kwargs: Forwarded to ``OnlineEmbeddingModel``.
        """
        super().__init__(
            api_key=api_key,
            api_key_env="COHERE_API_KEY",
            model_name=model_name,
            **kwargs,
        )
        self.input_type = input_type

    def _get_client(self):
        """Return a cached ``cohere.Client``, creating it on first use."""
        client = getattr(self, "_client", None)
        if client is None:
            import cohere

            client = cohere.Client(api_key=self.api_key)
            self._client = client
        return client

    def _call_api(self, texts: List[str], is_query: bool = False) -> List[List[float]]:
        """
        Call the Cohere API to obtain embeddings.

        Args:
            texts: Texts to embed.
            is_query: When True, embed with ``input_type='search_query'``;
                otherwise use the configured ``self.input_type``.

        Returns:
            One embedding vector per input text.

        Raises:
            ImportError: If the ``cohere`` package is not installed.
            RuntimeError: If creating the client or calling the API fails.
        """
        # Previously the constructor's input_type was stored but never used.
        input_type = "search_query" if is_query else self.input_type
        try:
            client = self._get_client()
            response = client.embed(
                texts=texts,
                model=self.model_name,
                input_type=input_type,
            )
        except ImportError as e:
            raise ImportError(
                "使用 CohereEmbeddingModel 需要安装 cohere 包: pip install cohere"
            ) from e
        except Exception as e:
            raise RuntimeError(f"调用 Cohere API 时出错: {e}") from e

        return cast(List[List[float]], response.embeddings)
@@ -0,0 +1,117 @@
"""
EdgeFn embedding model implementation.
"""

from typing import Any
from typing import List
from typing import Optional

from .base import OnlineEmbeddingModel


class EdgeFnEmbeddingModel(OnlineEmbeddingModel):
    """
    Embedding model backed by the EdgeFn embeddings API.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: str = "BAAI/bge-m3",
        base_url: str = "https://api.edgefn.net/v1/embeddings",
        **kwargs,
    ):
        """
        Initialize the EdgeFn embedding model.

        Args:
            api_key: EdgeFn API key. When falsy, read from ``EDGEFN_API_KEY``.
            model_name: Model name (e.g. ``'BAAI/bge-m3'``).
            base_url: URL that embedding requests are POSTed to.
            **kwargs: Forwarded to ``OnlineEmbeddingModel``.
        """
        super().__init__(
            api_key=api_key,
            api_key_env="EDGEFN_API_KEY",
            base_url=base_url,
            model_name=model_name,
            **kwargs,
        )

    @staticmethod
    def _extract_embeddings(data: Any) -> List[List[float]]:
        """
        Normalize the supported EdgeFn response shapes into a list of vectors.

        Supported shapes:
          - ``{"data": [{"embedding": [...]}, ...]}``
          - ``{"data": [[...], [...]]}``
          - ``{"embedding": [...]}`` (single embedding)
          - ``[[...], ...]`` (bare list)

        Raises:
            ValueError: If the response has none of the supported shapes.
        """
        if isinstance(data, dict) and "data" in data:
            items = data["data"]
            if not isinstance(items, list) or not items:
                raise ValueError("EdgeFn API 返回的 data 格式不正确")
            first = items[0]
            if isinstance(first, dict) and "embedding" in first:
                # If every item carries an integer ``index`` (as OpenAI-style
                # responses do), order by it so vectors line up with the inputs.
                if all(
                    isinstance(item, dict) and isinstance(item.get("index"), int)
                    for item in items
                ):
                    items = sorted(items, key=lambda item: item["index"])
                return [item["embedding"] for item in items]
            # The items are the embedding vectors themselves.
            return items
        if isinstance(data, dict) and "embedding" in data:
            # Non-standard single-embedding response.
            return [data["embedding"]]
        if isinstance(data, list):
            # The API returned the embedding list directly.
            return data
        raise ValueError(
            f"EdgeFn API 返回了意外的格式。响应键: {list(data.keys()) if isinstance(data, dict) else '非字典类型'}"
        )

    def _call_api(self, texts: List[str], is_query: bool = False) -> List[List[float]]:
        """
        Call the EdgeFn API to obtain embeddings.

        Args:
            texts: Texts to embed.
            is_query: Whether the texts are queries (EdgeFn does not distinguish
                queries from documents, so this is ignored).

        Returns:
            One embedding vector per input text, in input order.

        Raises:
            ImportError: If the ``requests`` package is not installed.
            RuntimeError: On HTTP/network errors, an unrecognized response shape,
                or when the number of returned vectors differs from ``len(texts)``.
        """
        if not texts:
            return []

        try:
            import requests
        except ImportError as e:
            raise ImportError(
                "使用 EdgeFnEmbeddingModel 需要安装 requests 包: pip install requests"
            ) from e

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        # The API accepts batches; a single text is sent as a plain string.
        payload = {
            "model": self.model_name,
            "input": texts[0] if len(texts) == 1 else texts,
        }

        try:
            response = requests.post(
                self.base_url, headers=headers, json=payload, timeout=60
            )
            response.raise_for_status()
            embeddings = self._extract_embeddings(response.json())
            # Guard against silently misaligned results: callers rely on one
            # vector per input text.
            if len(embeddings) != len(texts):
                raise ValueError(
                    f"EdgeFn API 返回的嵌入数量({len(embeddings)})与输入文本数量({len(texts)})不一致"
                )
        except requests.exceptions.RequestException as e:
            error_msg = str(e)
            error_response = getattr(e, "response", None)
            if error_response is not None:
                try:
                    error_detail = error_response.json()
                    error_msg = f"{error_msg} - 详情: {error_detail}"
                except Exception:
                    error_msg = f"{error_msg} - 响应状态码: {error_response.status_code}"
            raise RuntimeError(f"调用 EdgeFn API 时出错: {error_msg}") from e
        except Exception as e:
            raise RuntimeError(f"处理 EdgeFn API 响应时出错: {e}") from e

        return embeddings