jarvis-ai-assistant 0.1.102__py3-none-any.whl → 0.1.104__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of jarvis-ai-assistant might be problematic. Click here for more details.
- jarvis/__init__.py +1 -1
- jarvis/agent.py +138 -117
- jarvis/jarvis_code_agent/code_agent.py +234 -0
- jarvis/{jarvis_coder → jarvis_code_agent}/file_select.py +19 -22
- jarvis/jarvis_code_agent/patch.py +120 -0
- jarvis/jarvis_code_agent/relevant_files.py +97 -0
- jarvis/jarvis_codebase/main.py +871 -0
- jarvis/jarvis_platform/main.py +5 -3
- jarvis/jarvis_rag/main.py +818 -0
- jarvis/jarvis_smart_shell/main.py +2 -2
- jarvis/models/ai8.py +3 -1
- jarvis/models/kimi.py +36 -30
- jarvis/models/ollama.py +17 -11
- jarvis/models/openai.py +15 -12
- jarvis/models/oyi.py +24 -7
- jarvis/models/registry.py +1 -25
- jarvis/tools/__init__.py +0 -6
- jarvis/tools/ask_codebase.py +96 -0
- jarvis/tools/ask_user.py +1 -9
- jarvis/tools/chdir.py +2 -37
- jarvis/tools/code_review.py +210 -0
- jarvis/tools/create_code_test_agent.py +115 -0
- jarvis/tools/create_ctags_agent.py +164 -0
- jarvis/tools/create_sub_agent.py +2 -2
- jarvis/tools/execute_shell.py +2 -2
- jarvis/tools/file_operation.py +2 -2
- jarvis/tools/find_in_codebase.py +78 -0
- jarvis/tools/git_commiter.py +68 -0
- jarvis/tools/methodology.py +3 -3
- jarvis/tools/rag.py +141 -0
- jarvis/tools/read_code.py +116 -0
- jarvis/tools/read_webpage.py +1 -1
- jarvis/tools/registry.py +47 -31
- jarvis/tools/search.py +8 -6
- jarvis/tools/select_code_files.py +4 -4
- jarvis/utils.py +375 -85
- {jarvis_ai_assistant-0.1.102.dist-info → jarvis_ai_assistant-0.1.104.dist-info}/METADATA +107 -32
- jarvis_ai_assistant-0.1.104.dist-info/RECORD +50 -0
- jarvis_ai_assistant-0.1.104.dist-info/entry_points.txt +11 -0
- jarvis/jarvis_code_agent/main.py +0 -200
- jarvis/jarvis_coder/git_utils.py +0 -123
- jarvis/jarvis_coder/patch_handler.py +0 -340
- jarvis/jarvis_github/main.py +0 -232
- jarvis/tools/create_code_sub_agent.py +0 -56
- jarvis/tools/execute_code_modification.py +0 -70
- jarvis/tools/find_files.py +0 -119
- jarvis/tools/generate_tool.py +0 -174
- jarvis/tools/thinker.py +0 -151
- jarvis_ai_assistant-0.1.102.dist-info/RECORD +0 -46
- jarvis_ai_assistant-0.1.102.dist-info/entry_points.txt +0 -6
- /jarvis/{jarvis_coder → jarvis_codebase}/__init__.py +0 -0
- /jarvis/{jarvis_github → jarvis_rag}/__init__.py +0 -0
- {jarvis_ai_assistant-0.1.102.dist-info → jarvis_ai_assistant-0.1.104.dist-info}/LICENSE +0 -0
- {jarvis_ai_assistant-0.1.102.dist-info → jarvis_ai_assistant-0.1.104.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.1.102.dist-info → jarvis_ai_assistant-0.1.104.dist-info}/top_level.txt +0 -0
jarvis/utils.py
CHANGED
|
@@ -1,32 +1,68 @@
|
|
|
1
1
|
import hashlib
|
|
2
2
|
from pathlib import Path
|
|
3
|
-
import sys
|
|
4
3
|
import time
|
|
5
4
|
import os
|
|
6
5
|
from enum import Enum
|
|
7
6
|
from datetime import datetime
|
|
8
|
-
from typing import Any
|
|
7
|
+
from typing import Any, Dict, Optional
|
|
9
8
|
import colorama
|
|
10
9
|
from colorama import Fore, Style as ColoramaStyle
|
|
11
10
|
import numpy as np
|
|
12
11
|
from prompt_toolkit import PromptSession
|
|
13
12
|
from prompt_toolkit.styles import Style as PromptStyle
|
|
14
13
|
from prompt_toolkit.formatted_text import FormattedText
|
|
14
|
+
from sentence_transformers import SentenceTransformer
|
|
15
|
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
|
16
|
+
import torch
|
|
15
17
|
import yaml
|
|
18
|
+
import faiss
|
|
19
|
+
from pygments.lexers import guess_lexer
|
|
20
|
+
from pygments.util import ClassNotFound
|
|
21
|
+
|
|
22
|
+
from rich.console import Console
|
|
23
|
+
from rich.theme import Theme
|
|
24
|
+
from rich.panel import Panel
|
|
25
|
+
from rich.text import Text
|
|
26
|
+
from rich.traceback import install as install_rich_traceback
|
|
27
|
+
from rich.syntax import Syntax
|
|
28
|
+
|
|
29
|
+
from prompt_toolkit.completion import Completer, Completion, PathCompleter
|
|
30
|
+
from prompt_toolkit.document import Document
|
|
31
|
+
from fuzzywuzzy import fuzz
|
|
16
32
|
|
|
17
33
|
# 初始化colorama
|
|
18
34
|
colorama.init()
|
|
19
35
|
|
|
20
36
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
21
|
-
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
|
22
37
|
|
|
23
38
|
current_agent = []
|
|
24
39
|
|
|
40
|
+
# Install rich traceback handler
|
|
41
|
+
install_rich_traceback()
|
|
42
|
+
|
|
43
|
+
# Create console with custom theme
|
|
44
|
+
custom_theme = Theme({
|
|
45
|
+
"info": "yellow",
|
|
46
|
+
"warning": "yellow",
|
|
47
|
+
"error": "red",
|
|
48
|
+
"success": "green",
|
|
49
|
+
"system": "cyan",
|
|
50
|
+
"code": "green",
|
|
51
|
+
"result": "blue",
|
|
52
|
+
"planning": "magenta",
|
|
53
|
+
"progress": "white",
|
|
54
|
+
"debug": "blue",
|
|
55
|
+
"user": "green",
|
|
56
|
+
"tool": "yellow",
|
|
57
|
+
})
|
|
58
|
+
|
|
59
|
+
console = Console(theme=custom_theme)
|
|
60
|
+
|
|
25
61
|
def add_agent(agent_name: str):
|
|
26
62
|
current_agent.append(agent_name)
|
|
27
63
|
|
|
28
|
-
def
|
|
29
|
-
return current_agent
|
|
64
|
+
def get_agent_list():
|
|
65
|
+
return ']['.join(current_agent) if current_agent else "No Agent"
|
|
30
66
|
|
|
31
67
|
def delete_current_agent():
|
|
32
68
|
current_agent.pop()
|
|
@@ -46,25 +82,9 @@ class OutputType(Enum):
|
|
|
46
82
|
TOOL = "tool" # Tool call
|
|
47
83
|
|
|
48
84
|
class PrettyOutput:
|
|
49
|
-
"""
|
|
85
|
+
"""Pretty output using rich"""
|
|
50
86
|
|
|
51
|
-
#
|
|
52
|
-
COLORS = {
|
|
53
|
-
OutputType.SYSTEM: Fore.CYAN, # Cyan - AI assistant
|
|
54
|
-
OutputType.CODE: Fore.GREEN, # Green - Code
|
|
55
|
-
OutputType.RESULT: Fore.BLUE, # Blue - Result
|
|
56
|
-
OutputType.ERROR: Fore.RED, # Red - Error
|
|
57
|
-
OutputType.INFO: Fore.YELLOW, # Yellow - Prompt
|
|
58
|
-
OutputType.PLANNING: Fore.MAGENTA, # Magenta - Planning
|
|
59
|
-
OutputType.PROGRESS: Fore.WHITE, # White - Progress
|
|
60
|
-
OutputType.SUCCESS: Fore.GREEN, # Green - Success
|
|
61
|
-
OutputType.WARNING: Fore.YELLOW, # Yellow - Warning
|
|
62
|
-
OutputType.DEBUG: Fore.BLUE, # Blue - Debug
|
|
63
|
-
OutputType.USER: Fore.GREEN, # Green - User
|
|
64
|
-
OutputType.TOOL: Fore.YELLOW, # Yellow - Tool
|
|
65
|
-
}
|
|
66
|
-
|
|
67
|
-
# 图标方案
|
|
87
|
+
# Icons for different output types
|
|
68
88
|
ICONS = {
|
|
69
89
|
OutputType.SYSTEM: "🤖", # Robot - AI assistant
|
|
70
90
|
OutputType.CODE: "📝", # Notebook - Code
|
|
@@ -79,67 +99,106 @@ class PrettyOutput:
|
|
|
79
99
|
OutputType.USER: "👤", # User - User
|
|
80
100
|
OutputType.TOOL: "🔧", # Wrench - Tool
|
|
81
101
|
}
|
|
82
|
-
|
|
83
|
-
#
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
102
|
+
|
|
103
|
+
# Common language mapping dictionary
|
|
104
|
+
_lang_map = {
|
|
105
|
+
'Python': 'python',
|
|
106
|
+
'JavaScript': 'javascript',
|
|
107
|
+
'TypeScript': 'typescript',
|
|
108
|
+
'Java': 'java',
|
|
109
|
+
'C++': 'cpp',
|
|
110
|
+
'C#': 'csharp',
|
|
111
|
+
'Ruby': 'ruby',
|
|
112
|
+
'PHP': 'php',
|
|
113
|
+
'Go': 'go',
|
|
114
|
+
'Rust': 'rust',
|
|
115
|
+
'Bash': 'bash',
|
|
116
|
+
'HTML': 'html',
|
|
117
|
+
'CSS': 'css',
|
|
118
|
+
'SQL': 'sql',
|
|
119
|
+
'R': 'r',
|
|
120
|
+
'Kotlin': 'kotlin',
|
|
121
|
+
'Swift': 'swift',
|
|
122
|
+
'Scala': 'scala',
|
|
123
|
+
'Perl': 'perl',
|
|
124
|
+
'Lua': 'lua',
|
|
125
|
+
'YAML': 'yaml',
|
|
126
|
+
'JSON': 'json',
|
|
127
|
+
'XML': 'xml',
|
|
128
|
+
'Markdown': 'markdown',
|
|
129
|
+
'Text': 'text',
|
|
130
|
+
'Shell': 'bash',
|
|
131
|
+
'Dockerfile': 'dockerfile',
|
|
132
|
+
'Makefile': 'makefile',
|
|
133
|
+
'INI': 'ini',
|
|
134
|
+
'TOML': 'toml',
|
|
97
135
|
}
|
|
98
136
|
|
|
99
137
|
@staticmethod
|
|
100
|
-
def
|
|
101
|
-
"""
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
138
|
+
def _detect_language(text: str, default_lang: str = 'markdown') -> str:
|
|
139
|
+
"""Helper method to detect language and map it to syntax highlighting name"""
|
|
140
|
+
try:
|
|
141
|
+
lexer = guess_lexer(text)
|
|
142
|
+
detected_lang = lexer.name
|
|
143
|
+
return PrettyOutput._lang_map.get(detected_lang, default_lang)
|
|
144
|
+
except ClassNotFound:
|
|
145
|
+
return default_lang
|
|
146
|
+
except Exception:
|
|
147
|
+
return default_lang
|
|
148
|
+
|
|
149
|
+
@staticmethod
|
|
150
|
+
def format(text: str, output_type: OutputType, timestamp: bool = True) -> Text:
|
|
151
|
+
"""Format output text using rich Text"""
|
|
152
|
+
# Create rich Text object
|
|
153
|
+
formatted = Text()
|
|
108
154
|
|
|
109
|
-
#
|
|
110
|
-
|
|
155
|
+
# Add timestamp and agent info
|
|
156
|
+
if timestamp:
|
|
157
|
+
formatted.append(f"[{datetime.now().strftime('%H:%M:%S')}] ", style="white")
|
|
158
|
+
formatted.append(f"[{get_agent_list()}]", style="blue")
|
|
159
|
+
# Add icon
|
|
160
|
+
icon = PrettyOutput.ICONS.get(output_type, "")
|
|
161
|
+
formatted.append(f"{icon} ", style=output_type.value)
|
|
111
162
|
|
|
112
|
-
return
|
|
163
|
+
return formatted
|
|
113
164
|
|
|
114
165
|
@staticmethod
|
|
115
|
-
def print(text: str, output_type: OutputType, timestamp: bool = True):
|
|
116
|
-
"""Print formatted output"""
|
|
117
|
-
|
|
166
|
+
def print(text: str, output_type: OutputType, timestamp: bool = True, lang: Optional[str] = None):
|
|
167
|
+
"""Print formatted output using rich console"""
|
|
168
|
+
# Get formatted header
|
|
169
|
+
lang = lang if lang is not None else PrettyOutput._detect_language(text, default_lang='markdown')
|
|
170
|
+
header = PrettyOutput.format("", output_type, timestamp)
|
|
171
|
+
|
|
172
|
+
content = Syntax(text, lang, theme="monokai")
|
|
173
|
+
|
|
174
|
+
# Print panel with appropriate border style
|
|
175
|
+
border_style = "red" if output_type == OutputType.ERROR else output_type.value
|
|
176
|
+
console.print(Panel(content, border_style=border_style, title=header, title_align="left", highlight=True))
|
|
177
|
+
|
|
178
|
+
# Print stack trace for errors
|
|
118
179
|
if output_type == OutputType.ERROR:
|
|
119
|
-
|
|
120
|
-
PrettyOutput.print(f"Error trace: {traceback.format_exc()}", OutputType.INFO)
|
|
180
|
+
console.print_exception()
|
|
121
181
|
|
|
122
182
|
@staticmethod
|
|
123
183
|
def section(title: str, output_type: OutputType = OutputType.INFO):
|
|
124
|
-
"""Print
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
print(
|
|
184
|
+
"""Print section title in a panel"""
|
|
185
|
+
panel = Panel(
|
|
186
|
+
Text(title, style=output_type.value, justify="center"),
|
|
187
|
+
border_style=output_type.value
|
|
188
|
+
)
|
|
189
|
+
console.print()
|
|
190
|
+
console.print(panel)
|
|
191
|
+
console.print()
|
|
130
192
|
|
|
131
193
|
@staticmethod
|
|
132
194
|
def print_stream(text: str):
|
|
133
|
-
"""Print stream output
|
|
134
|
-
|
|
135
|
-
sys.stdout.write(f"{color}{text}{ColoramaStyle.RESET_ALL}")
|
|
136
|
-
sys.stdout.flush()
|
|
195
|
+
"""Print stream output without line break"""
|
|
196
|
+
console.print(text, style="system", end="")
|
|
137
197
|
|
|
138
198
|
@staticmethod
|
|
139
199
|
def print_stream_end():
|
|
140
|
-
"""
|
|
141
|
-
|
|
142
|
-
sys.stdout.flush()
|
|
200
|
+
"""End stream output with line break"""
|
|
201
|
+
console.print()
|
|
143
202
|
|
|
144
203
|
def get_single_line_input(tip: str) -> str:
|
|
145
204
|
"""Get single line input, support direction key, history function, etc."""
|
|
@@ -149,14 +208,87 @@ def get_single_line_input(tip: str) -> str:
|
|
|
149
208
|
})
|
|
150
209
|
return session.prompt(f"{tip}", style=style)
|
|
151
210
|
|
|
211
|
+
def make_choice_input(tip: str, choices: list) -> str:
|
|
212
|
+
"""Get choice input, support direction key, history function, etc."""
|
|
213
|
+
session = PromptSession(history=None)
|
|
214
|
+
style = PromptStyle.from_dict({
|
|
215
|
+
'prompt': 'ansicyan',
|
|
216
|
+
})
|
|
217
|
+
return session.prompt(f"{tip}", style=style)
|
|
218
|
+
|
|
219
|
+
class FileCompleter(Completer):
|
|
220
|
+
"""Custom completer for file paths with fuzzy matching."""
|
|
221
|
+
def __init__(self):
|
|
222
|
+
self.path_completer = PathCompleter()
|
|
223
|
+
|
|
224
|
+
def get_completions(self, document: Document, complete_event):
|
|
225
|
+
text = document.text_before_cursor
|
|
226
|
+
cursor_pos = document.cursor_position
|
|
227
|
+
|
|
228
|
+
# Find all @ positions in text
|
|
229
|
+
at_positions = [i for i, char in enumerate(text) if char == '@']
|
|
230
|
+
|
|
231
|
+
if not at_positions:
|
|
232
|
+
return
|
|
233
|
+
|
|
234
|
+
# Get the last @ position
|
|
235
|
+
current_at_pos = at_positions[-1]
|
|
236
|
+
|
|
237
|
+
# If cursor is not after the last @, don't complete
|
|
238
|
+
if cursor_pos <= current_at_pos:
|
|
239
|
+
return
|
|
240
|
+
|
|
241
|
+
# Check if there's a space after @
|
|
242
|
+
text_after_at = text[current_at_pos + 1:cursor_pos]
|
|
243
|
+
if ' ' in text_after_at:
|
|
244
|
+
return
|
|
245
|
+
|
|
246
|
+
# Get the text after the current @
|
|
247
|
+
file_path = text_after_at.strip()
|
|
248
|
+
|
|
249
|
+
# Get all possible files from current directory
|
|
250
|
+
all_files = []
|
|
251
|
+
for root, _, files in os.walk('.'):
|
|
252
|
+
for f in files:
|
|
253
|
+
path = os.path.join(root, f)
|
|
254
|
+
# Remove ./ from the beginning
|
|
255
|
+
path = path[2:] if path.startswith('./') else path
|
|
256
|
+
all_files.append(path)
|
|
257
|
+
|
|
258
|
+
# If no input after @, show all files
|
|
259
|
+
# Otherwise use fuzzy matching
|
|
260
|
+
if not file_path:
|
|
261
|
+
scored_files = [(path, 100) for path in all_files]
|
|
262
|
+
else:
|
|
263
|
+
scored_files = [
|
|
264
|
+
(path, fuzz.ratio(file_path.lower(), path.lower()))
|
|
265
|
+
for path in all_files
|
|
266
|
+
]
|
|
267
|
+
scored_files.sort(key=lambda x: x[1], reverse=True)
|
|
268
|
+
|
|
269
|
+
# Return completions for files
|
|
270
|
+
for path, score in scored_files:
|
|
271
|
+
if not file_path or score > 30: # Show all if no input, otherwise filter by score
|
|
272
|
+
completion = Completion(
|
|
273
|
+
text=path,
|
|
274
|
+
start_position=-len(file_path),
|
|
275
|
+
display=f"{path}" if not file_path else f"{path} ({score}%)",
|
|
276
|
+
display_meta="File"
|
|
277
|
+
)
|
|
278
|
+
yield completion
|
|
279
|
+
|
|
152
280
|
def get_multiline_input(tip: str) -> str:
|
|
153
|
-
"""Get multi-line input, support direction key, history function,
|
|
154
|
-
print(f"{Fore.GREEN}{tip}{ColoramaStyle.RESET_ALL}")
|
|
281
|
+
"""Get multi-line input, support direction key, history function, and file completion.
|
|
155
282
|
|
|
156
|
-
|
|
157
|
-
|
|
283
|
+
Args:
|
|
284
|
+
tip: The prompt tip to display
|
|
285
|
+
|
|
286
|
+
Returns:
|
|
287
|
+
str: The entered text
|
|
288
|
+
"""
|
|
289
|
+
print(f"{Fore.GREEN}{tip}{ColoramaStyle.RESET_ALL}")
|
|
158
290
|
|
|
159
|
-
#
|
|
291
|
+
# Define prompt style
|
|
160
292
|
style = PromptStyle.from_dict({
|
|
161
293
|
'prompt': 'ansicyan',
|
|
162
294
|
})
|
|
@@ -164,28 +296,34 @@ def get_multiline_input(tip: str) -> str:
|
|
|
164
296
|
lines = []
|
|
165
297
|
try:
|
|
166
298
|
while True:
|
|
167
|
-
#
|
|
299
|
+
# Set prompt
|
|
168
300
|
prompt = FormattedText([
|
|
169
301
|
('class:prompt', '... ' if lines else '>>> ')
|
|
170
302
|
])
|
|
171
303
|
|
|
172
|
-
#
|
|
304
|
+
# Create new session with new completer for each line
|
|
305
|
+
session = PromptSession(
|
|
306
|
+
history=None, # Use default history
|
|
307
|
+
completer=FileCompleter() # New completer instance for each line
|
|
308
|
+
)
|
|
309
|
+
|
|
310
|
+
# Get input with completion support
|
|
173
311
|
line = session.prompt(
|
|
174
312
|
prompt,
|
|
175
313
|
style=style,
|
|
176
314
|
).strip()
|
|
177
315
|
|
|
178
|
-
#
|
|
316
|
+
# Handle empty line
|
|
179
317
|
if not line:
|
|
180
|
-
if not lines: #
|
|
318
|
+
if not lines: # First line is empty
|
|
181
319
|
return ""
|
|
182
|
-
break #
|
|
320
|
+
break # End multi-line input
|
|
183
321
|
|
|
184
322
|
lines.append(line)
|
|
185
323
|
|
|
186
324
|
except KeyboardInterrupt:
|
|
187
|
-
PrettyOutput.print("
|
|
188
|
-
return "
|
|
325
|
+
PrettyOutput.print("Input cancelled", OutputType.INFO)
|
|
326
|
+
return ""
|
|
189
327
|
|
|
190
328
|
return "\n".join(lines)
|
|
191
329
|
|
|
@@ -239,6 +377,73 @@ def find_git_root(dir="."):
|
|
|
239
377
|
os.chdir(curr_dir)
|
|
240
378
|
return ret
|
|
241
379
|
|
|
380
|
+
def has_uncommitted_changes():
|
|
381
|
+
# Check working directory changes
|
|
382
|
+
working_changes = os.popen("git diff --exit-code").read().strip() != ""
|
|
383
|
+
# Check staged changes
|
|
384
|
+
staged_changes = os.popen("git diff --cached --exit-code").read().strip() != ""
|
|
385
|
+
return working_changes or staged_changes
|
|
386
|
+
|
|
387
|
+
def load_embedding_model():
|
|
388
|
+
model_name = "BAAI/bge-m3"
|
|
389
|
+
cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
try:
|
|
393
|
+
# Load model
|
|
394
|
+
embedding_model = SentenceTransformer(
|
|
395
|
+
model_name,
|
|
396
|
+
cache_folder=cache_dir,
|
|
397
|
+
local_files_only=True
|
|
398
|
+
)
|
|
399
|
+
except Exception as e:
|
|
400
|
+
# Load model
|
|
401
|
+
embedding_model = SentenceTransformer(
|
|
402
|
+
model_name,
|
|
403
|
+
cache_folder=cache_dir,
|
|
404
|
+
local_files_only=False
|
|
405
|
+
)
|
|
406
|
+
|
|
407
|
+
return embedding_model
|
|
408
|
+
|
|
409
|
+
def load_rerank_model():
|
|
410
|
+
"""Load reranking model"""
|
|
411
|
+
model_name = "BAAI/bge-reranker-v2-m3"
|
|
412
|
+
cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
|
|
413
|
+
|
|
414
|
+
PrettyOutput.print(f"Loading reranking model: {model_name}...", OutputType.INFO)
|
|
415
|
+
|
|
416
|
+
try:
|
|
417
|
+
# Load model and tokenizer
|
|
418
|
+
tokenizer = AutoTokenizer.from_pretrained(
|
|
419
|
+
model_name,
|
|
420
|
+
cache_dir=cache_dir,
|
|
421
|
+
local_files_only=True
|
|
422
|
+
)
|
|
423
|
+
model = AutoModelForSequenceClassification.from_pretrained(
|
|
424
|
+
model_name,
|
|
425
|
+
cache_dir=cache_dir,
|
|
426
|
+
local_files_only=True
|
|
427
|
+
)
|
|
428
|
+
except Exception as e:
|
|
429
|
+
# Load model and tokenizer
|
|
430
|
+
tokenizer = AutoTokenizer.from_pretrained(
|
|
431
|
+
model_name,
|
|
432
|
+
cache_dir=cache_dir,
|
|
433
|
+
local_files_only=False
|
|
434
|
+
)
|
|
435
|
+
model = AutoModelForSequenceClassification.from_pretrained(
|
|
436
|
+
model_name,
|
|
437
|
+
cache_dir=cache_dir,
|
|
438
|
+
local_files_only=False
|
|
439
|
+
)
|
|
440
|
+
|
|
441
|
+
# Use GPU if available
|
|
442
|
+
if torch.cuda.is_available():
|
|
443
|
+
model = model.cuda()
|
|
444
|
+
model.eval()
|
|
445
|
+
|
|
446
|
+
return model, tokenizer
|
|
242
447
|
|
|
243
448
|
def get_max_context_length():
|
|
244
449
|
return int(os.getenv('JARVIS_MAX_CONTEXT_LENGTH', '131072')) # 默认128k
|
|
@@ -270,6 +475,10 @@ def get_file_md5(filepath: str)->str:
|
|
|
270
475
|
return hashlib.md5(open(filepath, "rb").read(100*1024*1024)).hexdigest()
|
|
271
476
|
|
|
272
477
|
|
|
478
|
+
def dont_use_local_model():
|
|
479
|
+
return os.getenv('JARVIS_DONT_USE_LOCAL_MODEL', 'false') == 'true'
|
|
480
|
+
|
|
481
|
+
|
|
273
482
|
def _create_methodology_embedding(embedding_model: Any, methodology_text: str) -> np.ndarray:
|
|
274
483
|
"""Create embedding vector for methodology text"""
|
|
275
484
|
try:
|
|
@@ -290,19 +499,77 @@ def _create_methodology_embedding(embedding_model: Any, methodology_text: str) -
|
|
|
290
499
|
|
|
291
500
|
def load_methodology(user_input: str) -> str:
|
|
292
501
|
"""Load methodology and build vector index"""
|
|
502
|
+
PrettyOutput.print("Loading methodology...", OutputType.PROGRESS)
|
|
293
503
|
user_jarvis_methodology = os.path.expanduser("~/.jarvis/methodology")
|
|
294
504
|
if not os.path.exists(user_jarvis_methodology):
|
|
295
505
|
return ""
|
|
506
|
+
|
|
507
|
+
def make_methodology_prompt(data: Dict) -> str:
|
|
508
|
+
ret = """This is the standard methodology for handling previous problems, if the current task is similar, you can refer to it, if not,just ignore it:\n"""
|
|
509
|
+
for key, value in data.items():
|
|
510
|
+
ret += f"Problem: {key}\nMethodology: {value}\n"
|
|
511
|
+
return ret
|
|
296
512
|
|
|
297
513
|
try:
|
|
298
514
|
with open(user_jarvis_methodology, "r", encoding="utf-8") as f:
|
|
299
515
|
data = yaml.safe_load(f)
|
|
300
516
|
|
|
517
|
+
if dont_use_local_model():
|
|
518
|
+
return make_methodology_prompt(data)
|
|
519
|
+
|
|
520
|
+
# Reset data structure
|
|
521
|
+
methodology_data = []
|
|
522
|
+
vectors = []
|
|
523
|
+
ids = []
|
|
524
|
+
|
|
525
|
+
# Get embedding model
|
|
526
|
+
embedding_model = load_embedding_model()
|
|
301
527
|
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
528
|
+
# Create test embedding to get correct dimension
|
|
529
|
+
test_embedding = _create_methodology_embedding(embedding_model, "test")
|
|
530
|
+
embedding_dimension = len(test_embedding)
|
|
531
|
+
|
|
532
|
+
# Create embedding vector for each methodology
|
|
533
|
+
for i, (key, value) in enumerate(data.items()):
|
|
534
|
+
methodology_text = f"{key}\n{value}"
|
|
535
|
+
embedding = _create_methodology_embedding(embedding_model, methodology_text)
|
|
536
|
+
vectors.append(embedding)
|
|
537
|
+
ids.append(i)
|
|
538
|
+
methodology_data.append({"key": key, "value": value})
|
|
539
|
+
|
|
540
|
+
if vectors:
|
|
541
|
+
vectors_array = np.vstack(vectors)
|
|
542
|
+
# Use correct dimension from test embedding
|
|
543
|
+
hnsw_index = faiss.IndexHNSWFlat(embedding_dimension, 16)
|
|
544
|
+
hnsw_index.hnsw.efConstruction = 40
|
|
545
|
+
hnsw_index.hnsw.efSearch = 16
|
|
546
|
+
methodology_index = faiss.IndexIDMap(hnsw_index)
|
|
547
|
+
methodology_index.add_with_ids(vectors_array, np.array(ids)) # type: ignore
|
|
548
|
+
query_embedding = _create_methodology_embedding(embedding_model, user_input)
|
|
549
|
+
k = min(3, len(methodology_data))
|
|
550
|
+
PrettyOutput.print(f"Retrieving methodology...", OutputType.INFO)
|
|
551
|
+
distances, indices = methodology_index.search(
|
|
552
|
+
query_embedding.reshape(1, -1), k
|
|
553
|
+
) # type: ignore
|
|
554
|
+
|
|
555
|
+
relevant_methodologies = {}
|
|
556
|
+
output_lines = []
|
|
557
|
+
for dist, idx in zip(distances[0], indices[0]):
|
|
558
|
+
if idx >= 0:
|
|
559
|
+
similarity = 1.0 / (1.0 + float(dist))
|
|
560
|
+
methodology = methodology_data[idx]
|
|
561
|
+
output_lines.append(
|
|
562
|
+
f"Methodology '{methodology['key']}' similarity: {similarity:.3f}"
|
|
563
|
+
)
|
|
564
|
+
if similarity >= 0.5:
|
|
565
|
+
relevant_methodologies[methodology["key"]] = methodology["value"]
|
|
566
|
+
|
|
567
|
+
if output_lines:
|
|
568
|
+
PrettyOutput.print("\n".join(output_lines), OutputType.INFO)
|
|
569
|
+
|
|
570
|
+
if relevant_methodologies:
|
|
571
|
+
return make_methodology_prompt(relevant_methodologies)
|
|
572
|
+
return make_methodology_prompt(data)
|
|
306
573
|
|
|
307
574
|
except Exception as e:
|
|
308
575
|
PrettyOutput.print(f"Error loading methodology: {str(e)}", OutputType.ERROR)
|
|
@@ -310,5 +577,28 @@ def load_methodology(user_input: str) -> str:
|
|
|
310
577
|
PrettyOutput.print(f"Error trace: {traceback.format_exc()}", OutputType.INFO)
|
|
311
578
|
return ""
|
|
312
579
|
|
|
313
|
-
def
|
|
314
|
-
return
|
|
580
|
+
def is_auto_complete() -> bool:
|
|
581
|
+
return os.getenv('JARVIS_AUTO_COMPLETE', 'false') == 'true'
|
|
582
|
+
|
|
583
|
+
def is_disable_codebase() -> bool:
|
|
584
|
+
return os.getenv('JARVIS_DISABLE_CODEBASE', 'false') == 'true'
|
|
585
|
+
|
|
586
|
+
def user_confirm(tip: str, default: bool = True) -> bool:
|
|
587
|
+
"""Prompt the user for confirmation.
|
|
588
|
+
|
|
589
|
+
Args:
|
|
590
|
+
tip: The message to show to the user
|
|
591
|
+
default: The default response if user hits enter
|
|
592
|
+
|
|
593
|
+
Returns:
|
|
594
|
+
bool: True if user confirmed, False otherwise
|
|
595
|
+
"""
|
|
596
|
+
suffix = "[Y/n]" if default else "[y/N]"
|
|
597
|
+
ret = get_single_line_input(f"{tip} {suffix}: ")
|
|
598
|
+
return default if ret == "" else ret.lower() == "y"
|
|
599
|
+
|
|
600
|
+
def get_file_line_count(filename: str) -> int:
|
|
601
|
+
try:
|
|
602
|
+
return len(open(filename, "r", encoding="utf-8").readlines())
|
|
603
|
+
except Exception as e:
|
|
604
|
+
return 0
|