jarvis-ai-assistant 0.1.124__py3-none-any.whl → 0.1.125__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of jarvis-ai-assistant might be problematic. Click here for more details.

Files changed (66) hide show
  1. jarvis/__init__.py +1 -1
  2. jarvis/jarvis_agent/__init__.py +18 -20
  3. jarvis/jarvis_code_agent/code_agent.py +195 -45
  4. jarvis/jarvis_code_agent/file_select.py +6 -19
  5. jarvis/jarvis_code_agent/patch.py +189 -310
  6. jarvis/jarvis_codebase/main.py +6 -2
  7. jarvis/jarvis_dev/main.py +6 -4
  8. jarvis/jarvis_git_squash/__init__.py +0 -0
  9. jarvis/jarvis_git_squash/main.py +81 -0
  10. jarvis/jarvis_lsp/cpp.py +1 -1
  11. jarvis/jarvis_lsp/go.py +1 -1
  12. jarvis/jarvis_lsp/registry.py +2 -2
  13. jarvis/jarvis_lsp/rust.py +1 -1
  14. jarvis/jarvis_multi_agent/__init__.py +1 -1
  15. jarvis/jarvis_platform/ai8.py +2 -1
  16. jarvis/jarvis_platform/base.py +19 -24
  17. jarvis/jarvis_platform/kimi.py +2 -3
  18. jarvis/jarvis_platform/ollama.py +3 -1
  19. jarvis/jarvis_platform/openai.py +1 -1
  20. jarvis/jarvis_platform/oyi.py +2 -1
  21. jarvis/jarvis_platform/registry.py +2 -1
  22. jarvis/jarvis_platform_manager/main.py +4 -6
  23. jarvis/jarvis_platform_manager/openai_test.py +0 -1
  24. jarvis/jarvis_rag/main.py +5 -2
  25. jarvis/jarvis_smart_shell/main.py +9 -4
  26. jarvis/jarvis_tools/ask_codebase.py +12 -7
  27. jarvis/jarvis_tools/ask_user.py +3 -2
  28. jarvis/jarvis_tools/base.py +21 -7
  29. jarvis/jarvis_tools/chdir.py +0 -1
  30. jarvis/jarvis_tools/code_review.py +13 -14
  31. jarvis/jarvis_tools/create_code_agent.py +2 -2
  32. jarvis/jarvis_tools/create_sub_agent.py +2 -2
  33. jarvis/jarvis_tools/execute_shell.py +3 -1
  34. jarvis/jarvis_tools/execute_shell_script.py +4 -4
  35. jarvis/jarvis_tools/file_operation.py +3 -2
  36. jarvis/jarvis_tools/git_commiter.py +5 -2
  37. jarvis/jarvis_tools/lsp_find_definition.py +1 -1
  38. jarvis/jarvis_tools/lsp_find_references.py +1 -1
  39. jarvis/jarvis_tools/lsp_get_diagnostics.py +19 -11
  40. jarvis/jarvis_tools/lsp_get_document_symbols.py +1 -1
  41. jarvis/jarvis_tools/lsp_prepare_rename.py +1 -1
  42. jarvis/jarvis_tools/lsp_validate_edit.py +1 -1
  43. jarvis/jarvis_tools/methodology.py +4 -1
  44. jarvis/jarvis_tools/rag.py +22 -15
  45. jarvis/jarvis_tools/read_code.py +4 -3
  46. jarvis/jarvis_tools/read_webpage.py +2 -1
  47. jarvis/jarvis_tools/registry.py +4 -1
  48. jarvis/jarvis_tools/{search.py → search_web.py} +5 -2
  49. jarvis/jarvis_tools/select_code_files.py +1 -1
  50. jarvis/jarvis_utils/__init__.py +19 -982
  51. jarvis/jarvis_utils/config.py +138 -0
  52. jarvis/jarvis_utils/embedding.py +201 -0
  53. jarvis/jarvis_utils/git_utils.py +120 -0
  54. jarvis/jarvis_utils/globals.py +82 -0
  55. jarvis/jarvis_utils/input.py +161 -0
  56. jarvis/jarvis_utils/methodology.py +128 -0
  57. jarvis/jarvis_utils/output.py +235 -0
  58. jarvis/jarvis_utils/utils.py +150 -0
  59. jarvis_ai_assistant-0.1.125.dist-info/METADATA +291 -0
  60. jarvis_ai_assistant-0.1.125.dist-info/RECORD +75 -0
  61. {jarvis_ai_assistant-0.1.124.dist-info → jarvis_ai_assistant-0.1.125.dist-info}/entry_points.txt +1 -0
  62. jarvis_ai_assistant-0.1.124.dist-info/METADATA +0 -460
  63. jarvis_ai_assistant-0.1.124.dist-info/RECORD +0 -65
  64. {jarvis_ai_assistant-0.1.124.dist-info → jarvis_ai_assistant-0.1.125.dist-info}/LICENSE +0 -0
  65. {jarvis_ai_assistant-0.1.124.dist-info → jarvis_ai_assistant-0.1.125.dist-info}/WHEEL +0 -0
  66. {jarvis_ai_assistant-0.1.124.dist-info → jarvis_ai_assistant-0.1.125.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,128 @@
1
+ """
2
+ Methodology Management Module
3
+ This module provides utilities for loading and searching methodologies.
4
+ It includes functions for:
5
+ - Creating methodology embeddings
6
+ - Loading and processing methodology data
7
+ - Building and searching methodology index
8
+ - Generating methodology prompts
9
+ """
10
+ import os
11
+ import yaml
12
+ import numpy as np
13
+ import faiss
14
+ from typing import Dict, Any, List
15
+ from jarvis.jarvis_utils.output import PrettyOutput, OutputType
16
+ from jarvis.jarvis_utils.embedding import load_embedding_model
17
+ from jarvis.jarvis_utils.config import dont_use_local_model
18
+ def _create_methodology_embedding(embedding_model: Any, methodology_text: str) -> np.ndarray:
19
+ """
20
+ Create embedding vector for methodology text.
21
+
22
+ Args:
23
+ embedding_model: The embedding model to use
24
+ methodology_text: The text to create embedding for
25
+
26
+ Returns:
27
+ np.ndarray: The embedding vector
28
+ """
29
+ try:
30
+ # Truncate long text
31
+ max_length = 512
32
+ text = ' '.join(methodology_text.split()[:max_length])
33
+
34
+ # 使用sentence_transformers模型获取嵌入向量
35
+ embedding = embedding_model.encode([text],
36
+ convert_to_tensor=True,
37
+ normalize_embeddings=True)
38
+ vector = np.array(embedding.cpu().numpy(), dtype=np.float32)
39
+ return vector[0] # Return first vector, because we only encoded one text
40
+ except Exception as e:
41
+ PrettyOutput.print(f"创建方法论嵌入向量失败: {str(e)}", OutputType.ERROR)
42
+ return np.zeros(1536, dtype=np.float32)
43
def make_methodology_prompt(data: Dict[str, str]) -> str:
    """
    Render methodology entries as a single prompt string.

    The prompt starts with a fixed introductory sentence, followed by one
    "Problem / Methodology" pair per entry, in the mapping's iteration order.

    Args:
        data: Mapping from problem description to its methodology text

    Returns:
        str: The introductory line plus the formatted entries
    """
    header = """This is the standard methodology for handling previous problems, if the current task is similar, you can refer to it, if not,just ignore it:\n"""
    entries = [
        f"Problem: {problem}\nMethodology: {method}\n"
        for problem, method in data.items()
    ]
    return header + "".join(entries)
57
def load_methodology(user_input: str) -> str:
    """
    Load methodology and build vector index for similarity search.

    Reads the YAML file ``~/.jarvis/methodology``. When local models are
    disabled (``dont_use_local_model()``), every stored methodology is
    returned. Otherwise each entry is embedded, the entries are indexed with
    FAISS, and up to three nearest entries to ``user_input`` are looked up;
    those with similarity >= 0.5 are returned. If none passes the threshold,
    all methodologies are returned.

    Args:
        user_input: The input text to search methodologies for

    Returns:
        str: Relevant methodology prompt, or an empty string if the file is
        missing, empty, or cannot be loaded
    """
    PrettyOutput.print("加载方法论...", OutputType.PROGRESS)
    user_jarvis_methodology = os.path.expanduser("~/.jarvis/methodology")
    if not os.path.exists(user_jarvis_methodology):
        return ""

    try:
        with open(user_jarvis_methodology, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)

        # An empty file parses to None; that is "no methodology", not an error.
        if not data:
            return ""

        if dont_use_local_model():
            return make_methodology_prompt(data)

        # Get embedding model
        embedding_model = load_embedding_model()

        # Create embedding vector for each methodology
        methodology_data: List[Dict[str, str]] = []
        vectors: List[np.ndarray] = []
        for key, value in data.items():
            methodology_text = f"{key}\n{value}"
            vectors.append(_create_methodology_embedding(embedding_model, methodology_text))
            methodology_data.append({"key": key, "value": value})

        vectors_array = np.vstack(vectors)
        # Take the dimension from the vectors that are actually indexed
        embedding_dimension = vectors_array.shape[1]
        hnsw_index = faiss.IndexHNSWFlat(embedding_dimension, 16)
        hnsw_index.hnsw.efConstruction = 40
        hnsw_index.hnsw.efSearch = 16
        methodology_index = faiss.IndexIDMap(hnsw_index)
        # FAISS expects 64-bit ids; do not rely on the platform default int
        ids = np.arange(len(methodology_data), dtype=np.int64)
        methodology_index.add_with_ids(vectors_array, ids)  # type: ignore

        query_embedding = _create_methodology_embedding(embedding_model, user_input)
        k = min(3, len(methodology_data))
        PrettyOutput.print("检索方法论...", OutputType.INFO)
        distances, indices = methodology_index.search(
            query_embedding.reshape(1, -1), k
        )  # type: ignore

        relevant_methodologies = {}
        output_lines = []
        for dist, idx in zip(distances[0], indices[0]):
            # Negative ids mark result slots the index could not fill
            if idx < 0:
                continue
            # Map distance (>= 0) to a score in (0, 1]; smaller distance -> higher score
            similarity = 1.0 / (1.0 + float(dist))
            methodology = methodology_data[int(idx)]
            output_lines.append(
                f"Methodology '{methodology['key']}' similarity: {similarity:.3f}"
            )
            if similarity >= 0.5:
                relevant_methodologies[methodology["key"]] = methodology["value"]

        if output_lines:
            PrettyOutput.print("\n".join(output_lines), OutputType.INFO)

        if relevant_methodologies:
            return make_methodology_prompt(relevant_methodologies)
        # Nothing was similar enough: fall back to all methodologies
        return make_methodology_prompt(data)
    except Exception as e:
        PrettyOutput.print(f"加载方法论失败: {str(e)}", OutputType.ERROR)
        return ""
@@ -0,0 +1,235 @@
1
+ """
2
+ Output Formatting Module
3
+ This module provides rich text formatting and display utilities for the Jarvis system.
4
+ It includes:
5
+ - OutputType enum for categorizing different types of output
6
+ - PrettyOutput class for formatting and displaying styled output
7
+ - Syntax highlighting support for various programming languages
8
+ - Panel-based display for structured output
9
+ """
10
+ from enum import Enum
11
+ from datetime import datetime
12
+ from typing import Optional
13
+ from rich.panel import Panel
14
+ from rich.box import HEAVY
15
+ from rich.text import Text
16
+ from rich.syntax import Syntax
17
+ from rich.style import Style as RichStyle
18
+ from pygments.lexers import guess_lexer
19
+ from pygments.util import ClassNotFound
20
+ from .globals import console, get_agent_list
21
class OutputType(Enum):
    """
    Categories of messages printed by the Jarvis system.

    Each member's value is its own name as a string. The category decides
    which icon and style are used when a message is displayed.
    """

    SYSTEM = "SYSTEM"      # AI assistant message
    CODE = "CODE"          # Code related output
    RESULT = "RESULT"      # Tool execution result
    ERROR = "ERROR"        # Error information
    INFO = "INFO"          # System prompt
    PLANNING = "PLANNING"  # Task planning
    PROGRESS = "PROGRESS"  # Execution progress
    SUCCESS = "SUCCESS"    # Success information
    WARNING = "WARNING"    # Warning information
    DEBUG = "DEBUG"        # Debug information
    USER = "USER"          # User input
    TOOL = "TOOL"          # Tool call
51
class PrettyOutput:
    """
    Class for formatting and displaying rich text output using the rich library.

    Provides methods for:
    - Formatting different types of output with appropriate styling
    - Syntax highlighting for code blocks
    - Panel-based display for structured content
    - Stream output for progressive display

    All methods are static; output goes to the shared ``console`` object
    imported from the package's ``globals`` module.
    """
    # Icons for different output types
    _ICONS = {
        OutputType.SYSTEM: "🤖",
        OutputType.CODE: "📝",
        OutputType.RESULT: "✨",
        OutputType.ERROR: "❌",
        OutputType.INFO: "ℹ️",
        OutputType.PLANNING: "📋",
        OutputType.PROGRESS: "⏳",
        OutputType.SUCCESS: "✅",
        OutputType.WARNING: "⚠️",
        OutputType.DEBUG: "🔍",
        OutputType.USER: "👤",
        OutputType.TOOL: "🔧",
    }
    # Language mapping for syntax highlighting:
    # pygments lexer display name -> language identifier passed to rich Syntax
    _lang_map = {
        'Python': 'python',
        'JavaScript': 'javascript',
        'TypeScript': 'typescript',
        'Java': 'java',
        'C++': 'cpp',
        'C#': 'csharp',
        'Ruby': 'ruby',
        'PHP': 'php',
        'Go': 'go',
        'Rust': 'rust',
        'Bash': 'bash',
        'HTML': 'html',
        'CSS': 'css',
        'SQL': 'sql',
        'R': 'r',
        'Kotlin': 'kotlin',
        'Swift': 'swift',
        'Scala': 'scala',
        'Perl': 'perl',
        'Lua': 'lua',
        'YAML': 'yaml',
        'JSON': 'json',
        'XML': 'xml',
        'Markdown': 'markdown',
        'Text': 'text',
        'Shell': 'bash',
        'Dockerfile': 'dockerfile',
        'Makefile': 'makefile',
        'INI': 'ini',
        'TOML': 'toml',
    }
    # Panel styles per output type. Built once at class creation instead of
    # on every print() call.
    _STYLES = {
        OutputType.SYSTEM: RichStyle(color="bright_cyan", bgcolor="#1a1a1a", frame=True, meta={"icon": "🤖"}),
        OutputType.CODE: RichStyle(color="green", bgcolor="#1a1a1a", frame=True, meta={"icon": "📝"}),
        OutputType.RESULT: RichStyle(color="bright_blue", bgcolor="#1a1a1a", frame=True, meta={"icon": "✨"}),
        OutputType.ERROR: RichStyle(color="red", frame=True, bgcolor="dark_red", meta={"icon": "❌"}),
        OutputType.INFO: RichStyle(color="gold1", frame=True, bgcolor="grey11", meta={"icon": "ℹ️"}),
        OutputType.PLANNING: RichStyle(color="purple", bold=True, frame=True, meta={"icon": "📋"}),
        OutputType.PROGRESS: RichStyle(color="white", encircle=True, frame=True, meta={"icon": "⏳"}),
        OutputType.SUCCESS: RichStyle(color="bright_green", bold=True, strike=False, meta={"icon": "✅"}),
        OutputType.WARNING: RichStyle(color="yellow", bold=True, blink2=True, bgcolor="dark_orange", meta={"icon": "⚠️"}),
        OutputType.DEBUG: RichStyle(color="grey58", dim=True, conceal=True, meta={"icon": "🔍"}),
        OutputType.USER: RichStyle(color="spring_green2", frame=True, meta={"icon": "👤"}),
        OutputType.TOOL: RichStyle(color="dark_sea_green4", bgcolor="grey19", frame=True, meta={"icon": "🔧"}),
    }

    @staticmethod
    def _detect_language(text: str, default_lang: str = 'markdown') -> str:
        """
        Detect the programming language of the given text.

        Args:
            text: The text to analyze
            default_lang: Language returned when detection fails or the
                detected lexer is not listed in ``_lang_map``

        Returns:
            str: Language identifier from ``_lang_map``, or ``default_lang``
        """
        try:
            lexer = guess_lexer(text)
            detected_lang = lexer.name
            return PrettyOutput._lang_map.get(detected_lang, default_lang)
        except Exception:
            # Detection is best-effort; any failure falls back to the default.
            return default_lang

    @staticmethod
    def _format(output_type: OutputType, timestamp: bool = True) -> Text:
        """
        Format the output header with timestamp and icon.

        Args:
            output_type: Type of output
            timestamp: Whether to include the "[HH:MM:SS][TYPE]" prefix

        Returns:
            Text: Formatted rich Text object
        """
        # NOTE(review): the enum value is used as a style name here; presumably
        # the console theme in the globals module defines it - confirm.
        formatted = Text()
        if timestamp:
            formatted.append(f"[{datetime.now().strftime('%H:%M:%S')}][{output_type.value}]", style=output_type.value)
        agent_info = get_agent_list()
        if agent_info:
            formatted.append(f"[{agent_info}]", style="blue")
        icon = PrettyOutput._ICONS.get(output_type, "")
        formatted.append(f" {icon} ", style=output_type.value)
        return formatted

    @staticmethod
    def print(text: str, output_type: OutputType, timestamp: bool = True, lang: Optional[str] = None, traceback: bool = False):
        """
        Print formatted output with styling and syntax highlighting.

        The text is rendered with syntax highlighting inside a panel whose
        title is the header built by ``_format``. For ERROR output, or when
        ``traceback`` is True, the traceback of the exception currently being
        handled is printed as well; if no exception is being handled, the
        traceback step is skipped.

        Args:
            text: The text content to print
            output_type: The type of output (affects styling)
            timestamp: Whether to show timestamp
            lang: Language for syntax highlighting; auto-detected when None
            traceback: Whether to show traceback for errors
        """
        import sys  # local import: only needed for the active-exception check

        lang = lang if lang is not None else PrettyOutput._detect_language(text, default_lang='markdown')
        header = PrettyOutput._format(output_type, timestamp)
        content = Syntax(text, lang, theme="monokai", word_wrap=True)
        style = PrettyOutput._STYLES[output_type]
        panel = Panel(
            content,
            style=style,
            border_style=style,
            title=header,
            title_align="left",
            padding=(0, 0),
            highlight=True,
            box=HEAVY,
        )
        console.print(panel)
        if traceback or output_type == OutputType.ERROR:
            # print_exception() raises ValueError when called outside an
            # except block, so only call it while an exception is active.
            if sys.exc_info()[0] is not None:
                console.print_exception()

    @staticmethod
    def section(title: str, output_type: OutputType = OutputType.INFO):
        """
        Print a section title in a styled panel.

        A blank line is printed before and after the panel.

        Args:
            title: The section title text
            output_type: The type of output (affects styling)
        """
        panel = Panel(
            Text(title, style=output_type.value, justify="center"),
            border_style=output_type.value
        )
        console.print()
        console.print(panel)
        console.print()

    @staticmethod
    def print_stream(text: str):
        """
        Print stream output without line break.

        Args:
            text: The text to print
        """
        style = PrettyOutput._get_style(OutputType.SYSTEM)
        console.print(text, style=style, end="")

    @staticmethod
    def print_stream_end():
        """
        End stream output with line break and flush the console's file.
        """
        end_style = PrettyOutput._get_style(OutputType.SUCCESS)
        console.print("\n", style=end_style)
        console.file.flush()

    @staticmethod
    def _get_style(output_type: OutputType) -> RichStyle:
        """
        Get pre-defined RichStyle for output type.

        The style is looked up on the console by the enum member's value.

        Args:
            output_type: The output type to get style for

        Returns:
            RichStyle: The corresponding style
        """
        return console.get_style(output_type.value)
@@ -0,0 +1,150 @@
1
+ import os
2
+ import time
3
+ import hashlib
4
+ from pathlib import Path
5
+ from typing import Dict, List
6
+ import psutil
7
+ from jarvis.jarvis_utils.config import get_max_token_count
8
+ from jarvis.jarvis_utils.embedding import get_context_token_count
9
+ from jarvis.jarvis_utils.input import get_single_line_input
10
+ from jarvis.jarvis_utils.output import PrettyOutput, OutputType
11
def init_env():
    """Initialize environment variables from ~/.jarvis/env file.

    Creates the .jarvis directory if it doesn't exist and loads environment
    variables from the env file. Each non-empty line that does not start
    with ``#`` or ``;`` is read as ``KEY=VALUE``; surrounding whitespace and
    quote characters are stripped from the value. Lines without ``=`` or
    with an empty key are skipped. A failure while reading the file is
    reported as a warning instead of being raised.
    """
    jarvis_dir = Path.home() / ".jarvis"
    env_file = jarvis_dir / "env"

    # Create ~/.jarvis when missing; exist_ok covers the case where another
    # process creates it between the check and the mkdir call.
    if not jarvis_dir.exists():
        jarvis_dir.mkdir(parents=True, exist_ok=True)

    if not env_file.exists():
        return

    try:
        with open(env_file, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith(("#", ";")):
                    continue
                key, separator, value = line.partition("=")
                key = key.strip()
                if not separator or not key:
                    # Not a KEY=VALUE line
                    continue
                try:
                    os.environ[key] = value.strip().strip("'").strip('"')
                except ValueError:
                    # Name or value the OS environment cannot hold
                    continue
    except Exception as e:
        PrettyOutput.print(f"警告: 读取 {env_file} 失败: {e}", OutputType.WARNING)
36
def while_success(func, sleep_time: float = 0.1):
    """Call ``func`` repeatedly until a call finishes without raising.

    After each failed call the error is reported and the loop sleeps for
    ``sleep_time`` seconds before trying again.

    Args:
        func: Zero-argument callable to execute
        sleep_time: Seconds to wait between attempts

    Returns:
        Whatever the first successful call to ``func`` returns
    """
    while True:
        try:
            result = func()
        except Exception as e:
            PrettyOutput.print(f"执行失败: {str(e)}, 等待 {sleep_time}s...", OutputType.ERROR)
            time.sleep(sleep_time)
        else:
            return result
44
def while_true(func, sleep_time: float = 0.1):
    """Call ``func`` repeatedly until it returns a truthy value.

    After each falsy result a warning is printed and the loop sleeps for
    ``sleep_time`` seconds. Exceptions raised by ``func`` are not caught.

    Args:
        func: Zero-argument callable to execute
        sleep_time: Seconds to wait between attempts

    Returns:
        The first truthy value returned by ``func``
    """
    outcome = func()
    while not outcome:
        PrettyOutput.print(f"执行失败, 等待 {sleep_time}s...", OutputType.WARNING)
        time.sleep(sleep_time)
        outcome = func()
    return outcome
53
def get_file_md5(filepath: str)->str:
    """Calculate the MD5 hash of a file's content.

    Only the first 100 MiB of the file are read, so files that differ only
    beyond that point produce the same digest.

    Args:
        filepath: Path to the file to hash

    Returns:
        str: Hexadecimal MD5 digest of (at most the first 100 MiB of) the
        file's content
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(filepath, "rb") as f:
        return hashlib.md5(f.read(100*1024*1024)).hexdigest()
63
def user_confirm(tip: str, default: bool = True) -> bool:
    """Prompt the user for confirmation with a yes/no question.

    The prompt shows ``[Y/n]`` or ``[y/N]`` depending on ``default``. The
    answer is compared case-insensitively after surrounding whitespace is
    removed: ``y`` and ``yes`` confirm, an empty answer selects ``default``,
    anything else is treated as "no".

    Args:
        tip: The message to show to the user
        default: The default response if user hits enter

    Returns:
        bool: True if user confirmed, False otherwise
    """
    suffix = "[Y/n]" if default else "[y/N]"
    answer = get_single_line_input(f"{tip} {suffix}: ").strip().lower()
    if answer == "":
        return default
    return answer in ("y", "yes")
76
def get_file_line_count(filename: str) -> int:
    """Count the number of lines in a file.

    The file is read as UTF-8 text and streamed line by line, so it is never
    held in memory as a whole.

    Args:
        filename: Path to the file to count lines for

    Returns:
        int: Number of lines in the file, 0 if file cannot be read
    """
    try:
        # Context manager guarantees the handle is closed, also on errors.
        with open(filename, "r", encoding="utf-8") as f:
            return sum(1 for _ in f)
    except Exception:
        return 0
89
def init_gpu_config() -> Dict:
    """Probe for a CUDA GPU and build the GPU configuration.

    When CUDA is available, the total memory of device 0 is recorded, a
    shared-memory estimate is computed, the per-process CUDA memory fraction
    is applied and the CUDA cache is emptied. Any exception along the way
    (including a failing ``torch`` import) is reported as a warning and the
    configuration collected so far is returned.

    Returns:
        Dict: Configuration with the keys ``has_gpu``, ``shared_memory``,
        ``device_memory`` and ``memory_fraction``
    """
    gpu_config = {
        "has_gpu": False,
        "shared_memory": 0,
        "device_memory": 0,
        "memory_fraction": 0.8,  # use 80% of the available memory by default
    }

    try:
        import torch

        if not torch.cuda.is_available():
            PrettyOutput.print("没有GPU可用, 使用CPU模式", output_type=OutputType.WARNING)
            return gpu_config

        # Query GPU information
        device_bytes = torch.cuda.get_device_properties(0).total_memory
        gpu_config["has_gpu"] = True
        gpu_config["device_memory"] = device_bytes

        # Estimate shared memory: the smaller of 50% of system memory
        # and twice the GPU memory
        host_bytes = psutil.virtual_memory().total
        gpu_config["shared_memory"] = min(host_bytes * 0.5, device_bytes * 2)

        # Configure CUDA memory allocation
        torch.cuda.set_per_process_memory_fraction(gpu_config["memory_fraction"])
        torch.cuda.empty_cache()

        summary = (
            f"GPU已初始化: {torch.cuda.get_device_name(0)}\n"
            f"设备内存: {device_bytes / 1024**3:.1f}GB\n"
            f"共享内存: {gpu_config['shared_memory'] / 1024**3:.1f}GB"
        )
        PrettyOutput.print(summary, output_type=OutputType.SUCCESS)
    except Exception as e:
        PrettyOutput.print(f"GPU初始化失败: {str(e)}", output_type=OutputType.WARNING)

    return gpu_config
130
+
131
+
132
def is_long_context(files: list) -> bool:
    """Tell whether the given files together form a long context.

    The token counts of the files are summed; the context is long once the
    running total exceeds 80% of the maximum token count. Files that cannot
    be read or counted are reported with a warning and skipped.

    Args:
        files: Paths of the files to measure

    Returns:
        bool: True if the accumulated token count exceeds the threshold
    """
    limit = get_max_token_count() * 0.8
    running_total = 0

    for file_path in files:
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                running_total += get_context_token_count(f.read())
            # Stop early: remaining files cannot bring the total back down
            if running_total > limit:
                return True
        except Exception as e:
            PrettyOutput.print(f"读取文件 {file_path} 失败: {e}", OutputType.WARNING)

    return running_total > limit