jarvis-ai-assistant 0.1.124__py3-none-any.whl → 0.1.125__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of jarvis-ai-assistant might be problematic.
- jarvis/__init__.py +1 -1
- jarvis/jarvis_agent/__init__.py +18 -20
- jarvis/jarvis_code_agent/code_agent.py +195 -45
- jarvis/jarvis_code_agent/file_select.py +6 -19
- jarvis/jarvis_code_agent/patch.py +189 -310
- jarvis/jarvis_codebase/main.py +6 -2
- jarvis/jarvis_dev/main.py +6 -4
- jarvis/jarvis_git_squash/__init__.py +0 -0
- jarvis/jarvis_git_squash/main.py +81 -0
- jarvis/jarvis_lsp/cpp.py +1 -1
- jarvis/jarvis_lsp/go.py +1 -1
- jarvis/jarvis_lsp/registry.py +2 -2
- jarvis/jarvis_lsp/rust.py +1 -1
- jarvis/jarvis_multi_agent/__init__.py +1 -1
- jarvis/jarvis_platform/ai8.py +2 -1
- jarvis/jarvis_platform/base.py +19 -24
- jarvis/jarvis_platform/kimi.py +2 -3
- jarvis/jarvis_platform/ollama.py +3 -1
- jarvis/jarvis_platform/openai.py +1 -1
- jarvis/jarvis_platform/oyi.py +2 -1
- jarvis/jarvis_platform/registry.py +2 -1
- jarvis/jarvis_platform_manager/main.py +4 -6
- jarvis/jarvis_platform_manager/openai_test.py +0 -1
- jarvis/jarvis_rag/main.py +5 -2
- jarvis/jarvis_smart_shell/main.py +9 -4
- jarvis/jarvis_tools/ask_codebase.py +12 -7
- jarvis/jarvis_tools/ask_user.py +3 -2
- jarvis/jarvis_tools/base.py +21 -7
- jarvis/jarvis_tools/chdir.py +0 -1
- jarvis/jarvis_tools/code_review.py +13 -14
- jarvis/jarvis_tools/create_code_agent.py +2 -2
- jarvis/jarvis_tools/create_sub_agent.py +2 -2
- jarvis/jarvis_tools/execute_shell.py +3 -1
- jarvis/jarvis_tools/execute_shell_script.py +4 -4
- jarvis/jarvis_tools/file_operation.py +3 -2
- jarvis/jarvis_tools/git_commiter.py +5 -2
- jarvis/jarvis_tools/lsp_find_definition.py +1 -1
- jarvis/jarvis_tools/lsp_find_references.py +1 -1
- jarvis/jarvis_tools/lsp_get_diagnostics.py +19 -11
- jarvis/jarvis_tools/lsp_get_document_symbols.py +1 -1
- jarvis/jarvis_tools/lsp_prepare_rename.py +1 -1
- jarvis/jarvis_tools/lsp_validate_edit.py +1 -1
- jarvis/jarvis_tools/methodology.py +4 -1
- jarvis/jarvis_tools/rag.py +22 -15
- jarvis/jarvis_tools/read_code.py +4 -3
- jarvis/jarvis_tools/read_webpage.py +2 -1
- jarvis/jarvis_tools/registry.py +4 -1
- jarvis/jarvis_tools/{search.py → search_web.py} +5 -2
- jarvis/jarvis_tools/select_code_files.py +1 -1
- jarvis/jarvis_utils/__init__.py +19 -982
- jarvis/jarvis_utils/config.py +138 -0
- jarvis/jarvis_utils/embedding.py +201 -0
- jarvis/jarvis_utils/git_utils.py +120 -0
- jarvis/jarvis_utils/globals.py +82 -0
- jarvis/jarvis_utils/input.py +161 -0
- jarvis/jarvis_utils/methodology.py +128 -0
- jarvis/jarvis_utils/output.py +235 -0
- jarvis/jarvis_utils/utils.py +150 -0
- jarvis_ai_assistant-0.1.125.dist-info/METADATA +291 -0
- jarvis_ai_assistant-0.1.125.dist-info/RECORD +75 -0
- {jarvis_ai_assistant-0.1.124.dist-info → jarvis_ai_assistant-0.1.125.dist-info}/entry_points.txt +1 -0
- jarvis_ai_assistant-0.1.124.dist-info/METADATA +0 -460
- jarvis_ai_assistant-0.1.124.dist-info/RECORD +0 -65
- {jarvis_ai_assistant-0.1.124.dist-info → jarvis_ai_assistant-0.1.125.dist-info}/LICENSE +0 -0
- {jarvis_ai_assistant-0.1.124.dist-info → jarvis_ai_assistant-0.1.125.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.1.124.dist-info → jarvis_ai_assistant-0.1.125.dist-info}/top_level.txt +0 -0
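The headline change in this release is structural: the 987-line jarvis/jarvis_utils/__init__.py is split into focused submodules (config, embedding, git_utils, globals, input, methodology, output, utils), as the diff below shows. Downstream code that imported helpers straight from jarvis.jarvis_utils will presumably need to target the new submodules instead. A minimal migration sketch, assuming the helpers kept their names and landed where the new module docstring suggests (the exact homes are an assumption, not confirmed by this diff):

# Hypothetical import migration for the 0.1.125 layout.
# Module locations are assumed from the new docstring and file list above.

# Before (0.1.124): everything was defined in the package __init__
# from jarvis.jarvis_utils import PrettyOutput, OutputType, get_max_token_count

# After (0.1.125): import from the focused submodules (assumed locations)
from jarvis.jarvis_utils.output import PrettyOutput, OutputType  # output formatting helpers
from jarvis.jarvis_utils.config import get_max_token_count       # env-var backed settings
from jarvis.jarvis_utils.input import get_multiline_input        # prompt_toolkit input helpers

PrettyOutput.print("hello from the 0.1.125 layout", OutputType.INFO)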
jarvis/jarvis_utils/__init__.py
CHANGED
@@ -1,987 +1,24 @@
-
-
-
+"""
+Jarvis Utils Module
+This module provides utility functions and classes used throughout the Jarvis system.
+It includes various helper functions, configuration management, and common operations.
+The module is organized into several submodules:
+- config: Configuration management
+- embedding: Text embedding utilities
+- git_utils: Git repository operations
+- input: User input handling
+- methodology: Methodology management
+- output: Output formatting
+- utils: General utilities
+"""
 import os
-from enum import Enum
-from datetime import datetime
-from typing import Any, Dict, List, Optional, Tuple
 import colorama
-from colorama import Fore, Style as ColoramaStyle
-import numpy as np
-from prompt_toolkit import PromptSession
-from prompt_toolkit.styles import Style as PromptStyle
-from prompt_toolkit.formatted_text import FormattedText
-from sentence_transformers import SentenceTransformer
-from tqdm import tqdm
-from transformers import AutoModelForSequenceClassification, AutoTokenizer
-import torch
-import yaml
-import faiss
-from pygments.lexers import guess_lexer
-from pygments.util import ClassNotFound
-import psutil
-from rich.console import Console
-from rich.theme import Theme
-from rich.panel import Panel
-from rich.box import HEAVY
-from rich.text import Text
 from rich.traceback import install as install_rich_traceback
-from
-
-
-from prompt_toolkit.completion import Completer, Completion, PathCompleter
-from prompt_toolkit.document import Document
-from fuzzywuzzy import process
-from prompt_toolkit.key_binding import KeyBindings
-
-# Initialize colorama
+# Re-export from new modules
+# These imports are required for project functionality and may be used dynamically
+# Initialize colorama for cross-platform colored text
 colorama.init()
-
+# Disable tokenizers parallelism to avoid issues with multiprocessing
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-
-current_agent_name = ""
-
-# Install rich traceback handler
-install_rich_traceback()
-
-# Create console with custom theme
-custom_theme = Theme({
-    "INFO": "yellow",
-    "WARNING": "yellow",
-    "ERROR": "red",
-    "SUCCESS": "green",
-    "SYSTEM": "cyan",
-    "CODE": "green",
-    "RESULT": "blue",
-    "PLANNING": "magenta",
-    "PROGRESS": "white",
-    "DEBUG": "blue",
-    "USER": "green",
-    "TOOL": "yellow",
-})
-
-console = Console(theme=custom_theme)
-
-def make_agent_name(agent_name: str):
-    if agent_name in global_agents:
-        i = 1
-        while f"{agent_name}_{i}" in global_agents:
-            i += 1
-        return f"{agent_name}_{i}"
-    else:
-        return agent_name
-
-def set_agent(agent_name: str, agent: Any):
-    global_agents.add(agent_name)
-    global current_agent_name
-    current_agent_name = agent_name
-
-def get_agent_list():
-    return "[" + str(len(global_agents)) + "]" + current_agent_name if global_agents else ""
-
-def delete_agent(agent_name: str):
-    if agent_name in global_agents:
-        global_agents.remove(agent_name)
-        global current_agent_name
-        current_agent_name = ""
-
-class OutputType(Enum):
-    SYSTEM = "SYSTEM"  # AI assistant message
-    CODE = "CODE"  # Code related
-    RESULT = "RESULT"  # Tool execution result
-    ERROR = "ERROR"  # Error information
-    INFO = "INFO"  # System prompt
-    PLANNING = "PLANNING"  # Task planning
-    PROGRESS = "PROGRESS"  # Execution progress
-    SUCCESS = "SUCCESS"  # Success information
-    WARNING = "WARNING"  # Warning information
-    DEBUG = "DEBUG"  # Debug information
-    USER = "USER"  # User input
-    TOOL = "TOOL"  # Tool call
-
-class PrettyOutput:
-    """Pretty output using rich"""
-
-    # Icons for different output types
-    _ICONS = {
-        OutputType.SYSTEM: "🤖",  # Robot - AI assistant
-        OutputType.CODE: "📝",  # Notebook - Code
-        OutputType.RESULT: "✨",  # Flash - Result
-        OutputType.ERROR: "❌",  # Error - Error
-        OutputType.INFO: "ℹ️",  # Info - Prompt
-        OutputType.PLANNING: "📋",  # Clipboard - Planning
-        OutputType.PROGRESS: "⏳",  # Hourglass - Progress
-        OutputType.SUCCESS: "✅",  # Checkmark - Success
-        OutputType.WARNING: "⚠️",  # Warning - Warning
-        OutputType.DEBUG: "🔍",  # Magnifying glass - Debug
-        OutputType.USER: "👤",  # User - User
-        OutputType.TOOL: "🔧",  # Wrench - Tool
-    }
-
-    # Common language mapping dictionary
-    _lang_map = {
-        'Python': 'python',
-        'JavaScript': 'javascript',
-        'TypeScript': 'typescript',
-        'Java': 'java',
-        'C++': 'cpp',
-        'C#': 'csharp',
-        'Ruby': 'ruby',
-        'PHP': 'php',
-        'Go': 'go',
-        'Rust': 'rust',
-        'Bash': 'bash',
-        'HTML': 'html',
-        'CSS': 'css',
-        'SQL': 'sql',
-        'R': 'r',
-        'Kotlin': 'kotlin',
-        'Swift': 'swift',
-        'Scala': 'scala',
-        'Perl': 'perl',
-        'Lua': 'lua',
-        'YAML': 'yaml',
-        'JSON': 'json',
-        'XML': 'xml',
-        'Markdown': 'markdown',
-        'Text': 'text',
-        'Shell': 'bash',
-        'Dockerfile': 'dockerfile',
-        'Makefile': 'makefile',
-        'INI': 'ini',
-        'TOML': 'toml',
-    }
-
-    @staticmethod
-    def _detect_language(text: str, default_lang: str = 'markdown') -> str:
-        """Helper method to detect language and map it to syntax highlighting name"""
-        try:
-            lexer = guess_lexer(text)
-            detected_lang = lexer.name
-            return PrettyOutput._lang_map.get(detected_lang, default_lang)
-        except ClassNotFound:
-            return default_lang
-        except Exception:
-            return default_lang
-
-    @staticmethod
-    def _format(output_type: OutputType, timestamp: bool = True) -> Text:
-        """Format output text using rich Text"""
-        # Create rich Text object
-        formatted = Text()
-
-        # Add timestamp and agent info
-        if timestamp:
-            formatted.append(f"[{datetime.now().strftime('%H:%M:%S')}][{output_type.value}]", style=output_type.value)
-            agent_info = get_agent_list()
-            if agent_info:  # Only add brackets if there's agent info
-                formatted.append(f"[{agent_info}]", style="blue")
-        # Add icon
-        icon = PrettyOutput._ICONS.get(output_type, "")
-        formatted.append(f" {icon} ", style=output_type.value)
-
-        return formatted
-
-    @staticmethod
-    def print(text: str, output_type: OutputType, timestamp: bool = True, lang: Optional[str] = None, traceback: bool = False):
-        """Print formatted output using rich console with styling
-
-        Args:
-            text: The text content to print
-            output_type: The type of output (affects styling)
-            timestamp: Whether to show timestamp
-            lang: Language for syntax highlighting
-            traceback: Whether to show traceback for errors
-        """
-
-
-        # Define styles for different output types
-        # Define styles for different output types
-        styles = {
-            OutputType.SYSTEM: RichStyle(
-                color="bright_cyan",
-                bold=True,
-            ),
-            OutputType.CODE: RichStyle(
-                color="green",
-                bgcolor="#1a1a1a",
-                frame=True
-            ),
-            OutputType.RESULT: RichStyle(
-                color="bright_blue",
-                bold=True,
-                bgcolor="navy_blue"
-            ),
-            OutputType.ERROR: RichStyle(
-                color="red",
-                bold=True,
-                blink=True,
-                bgcolor="dark_red",
-            ),
-            OutputType.INFO: RichStyle(
-                color="gold1",
-                dim=True,
-                bgcolor="grey11",
-            ),
-            OutputType.PLANNING: RichStyle(
-                color="purple",
-                bold=True,
-            ),
-            OutputType.PROGRESS: RichStyle(
-                color="white",
-                encircle=True,
-            ),
-            OutputType.SUCCESS: RichStyle(
-                color="bright_green",
-                bold=True,
-                strike=False,
-                meta={"icon": "✓"},
-            ),
-            OutputType.WARNING: RichStyle(
-                color="yellow",
-                bold=True,
-                blink2=True,
-                bgcolor="dark_orange",
-            ),
-            OutputType.DEBUG: RichStyle(
-                color="grey58",
-                dim=True,
-                conceal=True
-            ),
-            OutputType.USER: RichStyle(
-                color="spring_green2",
-                reverse=True,
-                frame=True,
-            ),
-            OutputType.TOOL: RichStyle(
-                color="dark_sea_green4",
-                bgcolor="grey19",
-            )
-        }
-
-        # Get formatted header
-        lang = lang if lang is not None else PrettyOutput._detect_language(text, default_lang='markdown')
-        header = PrettyOutput._format(output_type, timestamp)
-
-        # Create syntax highlighted content
-        content = Syntax(
-            text,
-            lang,
-            theme="monokai",
-            word_wrap=True,
-        )
-
-        # Create panel with styling
-        panel = Panel(
-            content,
-            style=styles[output_type],
-            border_style=styles[output_type],
-            title=header,
-            title_align="left",
-            padding=(0, 0),
-            highlight=True,
-            box=HEAVY,
-        )
-
-        # Print panel
-        console.print(panel)
-
-        # Print stack trace for errors if requested
-        if traceback or output_type == OutputType.ERROR:
-            console.print_exception()
-
-    @staticmethod
-    def section(title: str, output_type: OutputType = OutputType.INFO):
-        """Print section title in a panel"""
-        panel = Panel(
-            Text(title, style=output_type.value, justify="center"),
-            border_style=output_type.value
-        )
-        console.print()
-        console.print(panel)
-        console.print()
-
-    @staticmethod
-    def print_stream(text: str):
-        """Print stream output without line break"""
-        # Use the progress-type style
-        style = PrettyOutput._get_style(OutputType.SYSTEM)
-        console.print(text, style=style, end="")
-
-    @staticmethod
-    def print_stream_end():
-        """End stream output with line break"""
-        # End-of-stream marker style
-        end_style = PrettyOutput._get_style(OutputType.SUCCESS)
-        console.print("\n", style=end_style)
-        console.file.flush()
-
-    @staticmethod
-    def _get_style(output_type: OutputType) -> RichStyle:
-        """Get pre-defined RichStyle for output type"""
-        return console.get_style(output_type.value)
-
-def get_single_line_input(tip: str) -> str:
-    """Get single line input, support direction key, history function, etc."""
-    session = PromptSession(history=None)
-    style = PromptStyle.from_dict({
-        'prompt': 'ansicyan',
-    })
-    return session.prompt(f"{tip}", style=style)
-
-class FileCompleter(Completer):
-    """Custom completer for file paths with fuzzy matching."""
-    def __init__(self):
-        self.path_completer = PathCompleter()
-        self.max_suggestions = 10
-        self.min_score = 10
-
-    def get_completions(self, document: Document, complete_event):
-        text = document.text_before_cursor
-        cursor_pos = document.cursor_position
-
-        # Find all @ positions in text
-        at_positions = [i for i, char in enumerate(text) if char == '@']
-
-        if not at_positions:
-            return
-
-        # Get the last @ position
-        current_at_pos = at_positions[-1]
-
-        # If cursor is not after the last @, don't complete
-        if cursor_pos <= current_at_pos:
-            return
-
-        # Check if there's a space after @
-        text_after_at = text[current_at_pos + 1:cursor_pos]
-        if ' ' in text_after_at:
-            return
-
-        # Get the text after the current @
-        file_path = text_after_at.strip()
-
-        # Count the characters to replace (including the @ symbol)
-        replace_length = len(text_after_at) + 1  # +1 to include the @ symbol
-
-        # Get all possible files using git ls-files only
-        all_files = []
-        try:
-            # Use git ls-files to get tracked files
-            import subprocess
-            result = subprocess.run(['git', 'ls-files'],
-                                    stdout=subprocess.PIPE,
-                                    stderr=subprocess.PIPE,
-                                    text=True)
-            if result.returncode == 0:
-                all_files = [line.strip() for line in result.stdout.splitlines() if line.strip()]
-        except Exception:
-            # If git command fails, just use an empty list
-            pass
-
-        # If no input after @, show all files
-        # Otherwise use fuzzy matching
-        if not file_path:
-            scored_files = [(path, 100) for path in all_files[:self.max_suggestions]]
-        else:
-            scored_files_data = process.extract(file_path, all_files, limit=self.max_suggestions)
-            scored_files = [
-                (m[0], m[1])
-                for m in scored_files_data
-            ]
-            # Sort by score and take top results
-            scored_files.sort(key=lambda x: x[1], reverse=True)
-            scored_files = scored_files[:self.max_suggestions]
-
-        # Return completions for files
-        for path, score in scored_files:
-            if not file_path or score > self.min_score:
-                display_text = path  # Display without backticks
-                if file_path and score < 100:
-                    display_text = f"{path} ({score}%)"
-                completion = Completion(
-                    text=f"'{path}'",  # Wrap the path in single quotes
-                    start_position=-replace_length,
-                    display=display_text,
-                    display_meta="File"
-                )
-                yield completion
-
-def get_multiline_input(tip: str) -> str:
-    """Get multi-line input with enhanced completion confirmation"""
-    # Input instructions
-    PrettyOutput.section("User input - use @ to trigger file completion, Tab to select a completion, Ctrl+J to submit, Ctrl+C to cancel input", OutputType.USER)
-
-    print(f"{Fore.GREEN}{tip}{ColoramaStyle.RESET_ALL}")
-
-    # Custom key bindings
-    bindings = KeyBindings()
-
-    @bindings.add('enter')
-    def _(event):
-        # When the completion menu is open, Enter confirms the completion
-        if event.current_buffer.complete_state:
-            event.current_buffer.apply_completion(event.current_buffer.complete_state.current_completion)
-        else:
-            # Otherwise insert a newline
-            event.current_buffer.insert_text('\n')
-
-    @bindings.add('c-j')  # Changed to a supported key combination
-    def _(event):
-        # Submit the input with Ctrl+J
-        event.current_buffer.validate_and_handle()
-
-    style = PromptStyle.from_dict({
-        'prompt': 'ansicyan',
-    })
-
-    try:
-        session = PromptSession(
-            history=None,
-            completer=FileCompleter(),
-            key_bindings=bindings,
-            complete_while_typing=True,
-            multiline=True,  # Enable native multi-line support
-            vi_mode=False,
-            mouse_support=False
-        )
-
-        prompt = FormattedText([
-            ('class:prompt', '>>> ')
-        ])
-
-        # Get the multi-line input in a single prompt call
-        text = session.prompt(
-            prompt,
-            style=style,
-        ).strip()
-
-        return text
-
-    except KeyboardInterrupt:
-        PrettyOutput.print("Input cancelled", OutputType.INFO)
-        return ""
-
-def init_env():
-    """Load environment variables from ~/.jarvis/env"""
-    jarvis_dir = Path.home() / ".jarvis"
-    env_file = jarvis_dir / "env"
-
-    # Check if ~/.jarvis directory exists
-    if not jarvis_dir.exists():
-        jarvis_dir.mkdir(parents=True)
-
-    if env_file.exists():
-        try:
-            with open(env_file, "r", encoding="utf-8") as f:
-                for line in f:
-                    line = line.strip()
-                    if line and not line.startswith(("#", ";")):
-                        try:
-                            key, value = line.split("=", 1)
-                            os.environ[key.strip()] = value.strip().strip("'").strip('"')
-                        except ValueError:
-                            continue
-        except Exception as e:
-            PrettyOutput.print(f"Warning: failed to read {env_file}: {e}", OutputType.WARNING)
-
-
-def while_success(func, sleep_time: float = 0.1):
-    while True:
-        try:
-            return func()
-        except Exception as e:
-            PrettyOutput.print(f"Execution failed: {str(e)}, waiting {sleep_time}s...", OutputType.ERROR)
-            time.sleep(sleep_time)
-            continue
-
-def while_true(func, sleep_time: float = 0.1):
-    """Loop execution function, until the function returns True"""
-    while True:
-        ret = func()
-        if ret:
-            break
-        PrettyOutput.print(f"Execution failed, waiting {sleep_time}s...", OutputType.WARNING)
-        time.sleep(sleep_time)
-    return ret
-
-def find_git_root(start_dir="."):
-    """Change to git root directory of the given path"""
-    os.chdir(start_dir)
-    git_root = os.popen("git rev-parse --show-toplevel").read().strip()
-    os.chdir(git_root)
-    return git_root
-
-def has_uncommitted_changes():
-    import subprocess
-    # Add all changes silently
-    subprocess.run(["git", "add", "."], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
-
-    # Check working directory changes
-    working_changes = subprocess.run(["git", "diff", "--exit-code"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode != 0
-
-    # Check staged changes
-    staged_changes = subprocess.run(["git", "diff", "--cached", "--exit-code"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode != 0
-
-    # Reset changes silently
-    subprocess.run(["git", "reset"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
-
-    return working_changes or staged_changes
-def get_commits_between(start_hash: str, end_hash: str) -> List[Tuple[str, str]]:
-    """Get list of commits between two commit hashes
-
-    Args:
-        start_hash: Starting commit hash (exclusive)
-        end_hash: Ending commit hash (inclusive)
-
-    Returns:
-        List[Tuple[str, str]]: List of (commit_hash, commit_message) tuples
-    """
-    try:
-        import subprocess
-        # Use git log with pretty format to get hash and message
-        result = subprocess.run(
-            ['git', 'log', f'{start_hash}..{end_hash}', '--pretty=format:%H|%s'],
-            stdout=subprocess.PIPE,
-            stderr=subprocess.PIPE,
-            text=True
-        )
-        if result.returncode != 0:
-            PrettyOutput.print(f"Failed to get commit history: {result.stderr}", OutputType.ERROR)
-            return []
-
-        commits = []
-        for line in result.stdout.splitlines():
-            if '|' in line:
-                commit_hash, message = line.split('|', 1)
-                commits.append((commit_hash, message))
-        return commits
-
-    except Exception as e:
-        PrettyOutput.print(f"Exception while getting commit history: {str(e)}", OutputType.ERROR)
-        return []
-def get_latest_commit_hash() -> str:
-    """Get the latest commit hash of the current git repository
-
-    Returns:
-        str: The commit hash, or empty string if not in a git repo or error occurs
-    """
-    try:
-        import subprocess
-        result = subprocess.run(
-            ['git', 'rev-parse', 'HEAD'],
-            stdout=subprocess.PIPE,
-            stderr=subprocess.PIPE,
-            text=True
-        )
-        if result.returncode == 0:
-            return result.stdout.strip()
-        return ""
-    except Exception:
-        return ""
-
-def load_embedding_model():
-    model_name = "BAAI/bge-m3"
-    cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
-
-
-    try:
-        # Load model
-        embedding_model = SentenceTransformer(
-            model_name,
-            cache_folder=cache_dir,
-            local_files_only=True
-        )
-    except Exception as e:
-        # Load model
-        embedding_model = SentenceTransformer(
-            model_name,
-            cache_folder=cache_dir,
-            local_files_only=False
-        )
-
-    return embedding_model
-
-def load_tokenizer():
-    """Load tokenizer"""
-    model_name = "gpt2"
-    cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
-
-    try:
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_name,
-            cache_dir=cache_dir,
-            local_files_only=True
-        )
-    except Exception as e:
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_name,
-            cache_dir=cache_dir,
-            local_files_only=False
-        )
-
-    return tokenizer
-
-def load_rerank_model():
-    """Load reranking model"""
-    model_name = "BAAI/bge-reranker-v2-m3"
-    cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
-
-    PrettyOutput.print(f"Loading reranker model: {model_name}...", OutputType.INFO)
-
-    try:
-        # Load model and tokenizer
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_name,
-            cache_dir=cache_dir,
-            local_files_only=True
-        )
-        model = AutoModelForSequenceClassification.from_pretrained(
-            model_name,
-            cache_dir=cache_dir,
-            local_files_only=True
-        )
-    except Exception as e:
-        # Load model and tokenizer
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_name,
-            cache_dir=cache_dir,
-            local_files_only=False
-        )
-        model = AutoModelForSequenceClassification.from_pretrained(
-            model_name,
-            cache_dir=cache_dir,
-            local_files_only=False
-        )
-
-    # Use GPU if available
-    if torch.cuda.is_available():
-        model = model.cuda()
-    model.eval()
-
-    return model, tokenizer
-
-
-
-def is_long_context(files: list) -> bool:
-    """Check if the file list belongs to a long context (total characters exceed 80% of the maximum context length)"""
-    max_token_count = get_max_token_count()
-    threshold = max_token_count * 0.8
-    total_tokens = 0
-
-    for file_path in files:
-        try:
-            with open(file_path, 'r', encoding='utf-8') as f:
-                content = f.read()
-                total_tokens += get_context_token_count(content)
-
-                if total_tokens > threshold:
-                    return True
-        except Exception as e:
-            PrettyOutput.print(f"Failed to read file {file_path}: {e}", OutputType.WARNING)
-            continue
-
-    return total_tokens > threshold
-
-
-
-def get_file_md5(filepath: str)->str:
-    return hashlib.md5(open(filepath, "rb").read(100*1024*1024)).hexdigest()
-
-
-
-
-def _create_methodology_embedding(embedding_model: Any, methodology_text: str) -> np.ndarray:
-    """Create embedding vector for methodology text"""
-    try:
-        # Truncate long text
-        max_length = 512
-        text = ' '.join(methodology_text.split()[:max_length])
-
-        # Get the embedding vector from the sentence_transformers model
-        embedding = embedding_model.encode([text],
-                                           convert_to_tensor=True,
-                                           normalize_embeddings=True)
-        vector = np.array(embedding.cpu().numpy(), dtype=np.float32)
-        return vector[0]  # Return first vector, because we only encoded one text
-    except Exception as e:
-        PrettyOutput.print(f"Failed to create methodology embedding: {str(e)}", OutputType.ERROR)
-        return np.zeros(1536, dtype=np.float32)
-
-
-def load_methodology(user_input: str) -> str:
-    """Load methodology and build vector index"""
-    PrettyOutput.print("Loading methodologies...", OutputType.PROGRESS)
-    user_jarvis_methodology = os.path.expanduser("~/.jarvis/methodology")
-    if not os.path.exists(user_jarvis_methodology):
-        return ""
-
-    def make_methodology_prompt(data: Dict) -> str:
-        ret = """This is the standard methodology for handling previous problems, if the current task is similar, you can refer to it, if not,just ignore it:\n"""
-        for key, value in data.items():
-            ret += f"Problem: {key}\nMethodology: {value}\n"
-        return ret
-
-    try:
-        with open(user_jarvis_methodology, "r", encoding="utf-8") as f:
-            data = yaml.safe_load(f)
-
-        if dont_use_local_model():
-            return make_methodology_prompt(data)
-
-        # Reset data structure
-        methodology_data = []
-        vectors = []
-        ids = []
-
-        # Get embedding model
-        embedding_model = load_embedding_model()
-
-        # Create test embedding to get correct dimension
-        test_embedding = _create_methodology_embedding(embedding_model, "test")
-        embedding_dimension = len(test_embedding)
-
-        # Create embedding vector for each methodology
-        for i, (key, value) in enumerate(data.items()):
-            methodology_text = f"{key}\n{value}"
-            embedding = _create_methodology_embedding(embedding_model, methodology_text)
-            vectors.append(embedding)
-            ids.append(i)
-            methodology_data.append({"key": key, "value": value})
-
-        if vectors:
-            vectors_array = np.vstack(vectors)
-            # Use correct dimension from test embedding
-            hnsw_index = faiss.IndexHNSWFlat(embedding_dimension, 16)
-            hnsw_index.hnsw.efConstruction = 40
-            hnsw_index.hnsw.efSearch = 16
-            methodology_index = faiss.IndexIDMap(hnsw_index)
-            methodology_index.add_with_ids(vectors_array, np.array(ids))  # type: ignore
-            query_embedding = _create_methodology_embedding(embedding_model, user_input)
-            k = min(3, len(methodology_data))
-            PrettyOutput.print(f"Searching methodologies...", OutputType.INFO)
-            distances, indices = methodology_index.search(
-                query_embedding.reshape(1, -1), k
-            )  # type: ignore
-
-            relevant_methodologies = {}
-            output_lines = []
-            for dist, idx in zip(distances[0], indices[0]):
-                if idx >= 0:
-                    similarity = 1.0 / (1.0 + float(dist))
-                    methodology = methodology_data[idx]
-                    output_lines.append(
-                        f"Methodology '{methodology['key']}' similarity: {similarity:.3f}"
-                    )
-                    if similarity >= 0.5:
-                        relevant_methodologies[methodology["key"]] = methodology["value"]
-
-            if output_lines:
-                PrettyOutput.print("\n".join(output_lines), OutputType.INFO)
-
-            if relevant_methodologies:
-                return make_methodology_prompt(relevant_methodologies)
-        return make_methodology_prompt(data)
-
-    except Exception as e:
-        PrettyOutput.print(f"Failed to load methodologies: {str(e)}", OutputType.ERROR)
-        return ""
-
-
-def user_confirm(tip: str, default: bool = True) -> bool:
-    """Prompt the user for confirmation.
-
-    Args:
-        tip: The message to show to the user
-        default: The default response if user hits enter
-
-    Returns:
-        bool: True if user confirmed, False otherwise
-    """
-    suffix = "[Y/n]" if default else "[y/N]"
-    ret = get_single_line_input(f"{tip} {suffix}: ")
-    return default if ret == "" else ret.lower() == "y"
-
-def get_file_line_count(filename: str) -> int:
-    try:
-        return len(open(filename, "r", encoding="utf-8").readlines())
-    except Exception as e:
-        return 0
-
-
-def init_gpu_config() -> Dict:
-    """Initialize GPU configuration based on available hardware
-
-    Returns:
-        Dict: GPU configuration including memory sizes and availability
-    """
-    config = {
-        "has_gpu": False,
-        "shared_memory": 0,
-        "device_memory": 0,
-        "memory_fraction": 0.8  # Use 80% of available memory by default
-    }
-
-    try:
-        import torch
-        if torch.cuda.is_available():
-            # Get GPU information
-            gpu_mem = torch.cuda.get_device_properties(0).total_memory
-            config["has_gpu"] = True
-            config["device_memory"] = gpu_mem
-
-            # Estimate shared memory (usually a portion of system memory)
-
-            system_memory = psutil.virtual_memory().total
-            config["shared_memory"] = min(system_memory * 0.5, gpu_mem * 2)  # The smaller of 50% of system memory and 2x GPU memory
-
-            # Configure CUDA memory allocation
-            torch.cuda.set_per_process_memory_fraction(config["memory_fraction"])
-            torch.cuda.empty_cache()
-
-            PrettyOutput.print(
-                f"GPU initialized: {torch.cuda.get_device_name(0)}\n"
-                f"Device memory: {gpu_mem / 1024**3:.1f}GB\n"
-                f"Shared memory: {config['shared_memory'] / 1024**3:.1f}GB",
-                output_type=OutputType.SUCCESS
-            )
-        else:
-            PrettyOutput.print("No GPU available, using CPU mode", output_type=OutputType.WARNING)
-    except Exception as e:
-        PrettyOutput.print(f"GPU initialization failed: {str(e)}", output_type=OutputType.WARNING)
-
-    return config
-
-
-def get_embedding(embedding_model: Any, text: str) -> np.ndarray:
-    """Get the vector representation of the text"""
-    embedding = embedding_model.encode(text,
-                                       normalize_embeddings=True,
-                                       show_progress_bar=False)
-    return np.array(embedding, dtype=np.float32)
-
-def get_embedding_batch(embedding_model: Any, texts: List[str]) -> np.ndarray:
-    """Get embeddings for a batch of texts efficiently"""
-    try:
-        all_vectors = []
-        for text in texts:
-            vectors = get_embedding_with_chunks(embedding_model, text)
-            all_vectors.extend(vectors)
-        return np.vstack(all_vectors)
-    except Exception as e:
-        PrettyOutput.print(f"Batch embedding failed: {str(e)}", OutputType.ERROR)
-        return np.zeros((0, embedding_model.get_sentence_embedding_dimension()), dtype=np.float32)
-
-
-
-def get_max_token_count():
-    return int(os.getenv('JARVIS_MAX_TOKEN_COUNT', '131072'))  # Default 128k
-
-def get_thread_count():
-    return int(os.getenv('JARVIS_THREAD_COUNT', '1'))
-
-def dont_use_local_model():
-    return os.getenv('JARVIS_DONT_USE_LOCAL_MODEL', 'false') == 'true'
-
-def is_auto_complete() -> bool:
-    return os.getenv('JARVIS_AUTO_COMPLETE', 'false') == 'true'
-
-def is_use_methodology() -> bool:
-    return os.getenv('JARVIS_USE_METHODOLOGY', 'true') == 'true'
-
-def is_record_methodology() -> bool:
-    return os.getenv('JARVIS_RECORD_METHODOLOGY', 'true') == 'true'
-
-def is_need_summary() -> bool:
-    return os.getenv('JARVIS_NEED_SUMMARY', 'true') == 'true'
-
-def get_min_paragraph_length() -> int:
-    return int(os.getenv('JARVIS_MIN_PARAGRAPH_LENGTH', '50'))
-
-def get_max_paragraph_length() -> int:
-    return int(os.getenv('JARVIS_MAX_PARAGRAPH_LENGTH', '12800'))
-
-def get_shell_name() -> str:
-    return os.getenv('SHELL', 'bash')
-
-def get_normal_platform_name() -> str:
-    return os.getenv('JARVIS_PLATFORM', 'kimi')
-
-def get_normal_model_name() -> str:
-    return os.getenv('JARVIS_MODEL', 'kimi')
-
-def get_codegen_platform_name() -> str:
-    return os.getenv('JARVIS_CODEGEN_PLATFORM', os.getenv('JARVIS_PLATFORM', 'kimi'))
-
-def get_codegen_model_name() -> str:
-    return os.getenv('JARVIS_CODEGEN_MODEL', os.getenv('JARVIS_MODEL', 'kimi'))
-
-def get_thinking_platform_name() -> str:
-    return os.getenv('JARVIS_THINKING_PLATFORM', os.getenv('JARVIS_PLATFORM', 'kimi'))
-
-def get_thinking_model_name() -> str:
-    return os.getenv('JARVIS_THINKING_MODEL', os.getenv('JARVIS_MODEL', 'kimi'))
-
-def get_cheap_platform_name() -> str:
-    return os.getenv('JARVIS_CHEAP_PLATFORM', os.getenv('JARVIS_PLATFORM', 'kimi'))
-
-def get_cheap_model_name() -> str:
-    return os.getenv('JARVIS_CHEAP_MODEL', os.getenv('JARVIS_MODEL', 'kimi'))
-
-def is_execute_tool_confirm() -> bool:
-    return os.getenv('JARVIS_EXECUTE_TOOL_CONFIRM', 'false') == 'true'
-
-def split_text_into_chunks(text: str, max_length: int = 512) -> List[str]:
-    """Split text into chunks with overlapping windows"""
-    chunks = []
-    start = 0
-    while start < len(text):
-        end = start + max_length
-        # Find the nearest sentence boundary
-        if end < len(text):
-            while end > start and text[end] not in {'.', '!', '?', '\n'}:
-                end -= 1
-            if end == start:  # No punctuation found, hard cut
-                end = start + max_length
-        chunk = text[start:end]
-        chunks.append(chunk)
-        # Overlap 20% of the window
-        start = end - int(max_length * 0.2)
-    return chunks
-
-def get_embedding_with_chunks(embedding_model: Any, text: str) -> List[np.ndarray]:
-    """Get embeddings for text chunks"""
-    chunks = split_text_into_chunks(text, 512)
-    if not chunks:
-        return []
-
-    vectors = []
-    for chunk in chunks:
-        vector = get_embedding(embedding_model, chunk)
-        vectors.append(vector)
-    return vectors
-
-
-def get_context_token_count(text: str) -> int:
-    """Get the token count of the text using the tokenizer
-
-    Args:
-        text: The input text to count tokens for
-
-    Returns:
-        int: The number of tokens in the text
-    """
-    try:
-        # Use a fast tokenizer that's good at general text
-        tokenizer = load_tokenizer()
-        chunks = split_text_into_chunks(text, 512)
-        return sum([len(tokenizer.encode(chunk)) for chunk in chunks])
-
-    except Exception as e:
-        PrettyOutput.print(f"Failed to count tokens: {str(e)}", OutputType.WARNING)
-        # Fallback to rough character-based estimate
-        return len(text) // 4  # Rough estimate of 4 chars per token
-
-
-
+# Install rich traceback handler for better error messages
+install_rich_traceback()