jarvis-ai-assistant 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,252 @@
1
+ # -*- coding: utf-8 -*-
2
+ import json
3
+ import shutil
4
+ from pathlib import Path
5
+ from typing import Any, Dict, List, Optional
6
+
7
+ from jarvis.jarvis_utils.config import get_data_dir
8
+ from jarvis.jarvis_utils.output import OutputType, PrettyOutput
9
+ from jarvis.jarvis_utils.globals import (
10
+ clear_short_term_memories,
11
+ get_short_term_memories,
12
+ short_term_memories,
13
+ )
14
+
15
+
16
+ class ClearMemoryTool:
17
+ """清除记忆工具,用于批量清除指定的记忆"""
18
+
19
+ name = "clear_memory"
20
+ description = """批量清除指定的记忆。
21
+
22
+ 支持的清除方式:
23
+ 1. 按记忆类型清除所有记忆
24
+ 2. 按标签清除特定记忆
25
+ 3. 按记忆ID清除单个记忆
26
+
27
+ 支持的记忆类型:
28
+ - project_long_term: 项目长期记忆
29
+ - global_long_term: 全局长期记忆
30
+ - short_term: 短期记忆
31
+ - all: 所有类型的记忆
32
+
33
+ 注意:清除操作不可恢复,请谨慎使用
34
+ """
35
+
36
+ parameters = {
37
+ "type": "object",
38
+ "properties": {
39
+ "memory_types": {
40
+ "type": "array",
41
+ "items": {
42
+ "type": "string",
43
+ "enum": [
44
+ "project_long_term",
45
+ "global_long_term",
46
+ "short_term",
47
+ "all",
48
+ ],
49
+ },
50
+ "description": "要清除的记忆类型列表",
51
+ },
52
+ "tags": {
53
+ "type": "array",
54
+ "items": {"type": "string"},
55
+ "description": "要清除的记忆标签列表(可选,如果指定则只清除带有这些标签的记忆)",
56
+ },
57
+ "memory_ids": {
58
+ "type": "array",
59
+ "items": {"type": "string"},
60
+ "description": "要清除的具体记忆ID列表(可选)",
61
+ },
62
+ "confirm": {
63
+ "type": "boolean",
64
+ "description": "确认清除操作(必须为true才会执行清除)",
65
+ "default": False,
66
+ },
67
+ },
68
+ "required": ["memory_types", "confirm"],
69
+ }
70
+
71
+ def __init__(self):
72
+ """初始化清除记忆工具"""
73
+ self.project_memory_dir = Path(".jarvis/memory")
74
+ self.global_memory_dir = Path(get_data_dir()) / "memory"
75
+
76
+ def _get_memory_dir(self, memory_type: str) -> Path:
77
+ """根据记忆类型获取存储目录"""
78
+ if memory_type == "project_long_term":
79
+ return self.project_memory_dir
80
+ elif memory_type in ["global_long_term", "short_term"]:
81
+ return self.global_memory_dir / memory_type
82
+ else:
83
+ raise ValueError(f"未知的记忆类型: {memory_type}")
84
+
85
+ def _clear_short_term_memories(
86
+ self, tags: Optional[List[str]] = None, memory_ids: Optional[List[str]] = None
87
+ ) -> Dict[str, int]:
88
+ """清除短期记忆"""
89
+ global short_term_memories
90
+
91
+ initial_count = len(short_term_memories)
92
+ removed_count = 0
93
+
94
+ if memory_ids:
95
+ # 按ID清除
96
+ new_memories = []
97
+ for memory in short_term_memories:
98
+ if memory.get("id") not in memory_ids:
99
+ new_memories.append(memory)
100
+ else:
101
+ removed_count += 1
102
+ short_term_memories[:] = new_memories
103
+ elif tags:
104
+ # 按标签清除
105
+ new_memories = []
106
+ for memory in short_term_memories:
107
+ memory_tags = memory.get("tags", [])
108
+ if not any(tag in memory_tags for tag in tags):
109
+ new_memories.append(memory)
110
+ else:
111
+ removed_count += 1
112
+ short_term_memories[:] = new_memories
113
+ else:
114
+ # 清除所有
115
+ clear_short_term_memories()
116
+ removed_count = initial_count
117
+
118
+ return {"total": initial_count, "removed": removed_count}
119
+
120
+ def _clear_long_term_memories(
121
+ self,
122
+ memory_type: str,
123
+ tags: Optional[List[str]] = None,
124
+ memory_ids: Optional[List[str]] = None,
125
+ ) -> Dict[str, int]:
126
+ """清除长期记忆"""
127
+ memory_dir = self._get_memory_dir(memory_type)
128
+
129
+ if not memory_dir.exists():
130
+ return {"total": 0, "removed": 0}
131
+
132
+ total_count = 0
133
+ removed_count = 0
134
+
135
+ # 获取所有记忆文件
136
+ memory_files = list(memory_dir.glob("*.json"))
137
+ total_count = len(memory_files)
138
+
139
+ for memory_file in memory_files:
140
+ try:
141
+ # 读取记忆内容
142
+ with open(memory_file, "r", encoding="utf-8") as f:
143
+ memory_data = json.load(f)
144
+
145
+ should_remove = False
146
+
147
+ if memory_ids:
148
+ # 按ID判断
149
+ if memory_data.get("id") in memory_ids:
150
+ should_remove = True
151
+ elif tags:
152
+ # 按标签判断
153
+ memory_tags = memory_data.get("tags", [])
154
+ if any(tag in memory_tags for tag in tags):
155
+ should_remove = True
156
+ else:
157
+ # 清除所有
158
+ should_remove = True
159
+
160
+ if should_remove:
161
+ memory_file.unlink()
162
+ removed_count += 1
163
+
164
+ except Exception as e:
165
+ PrettyOutput.print(
166
+ f"处理记忆文件 {memory_file} 时出错: {str(e)}", OutputType.WARNING
167
+ )
168
+
169
+ # 如果目录为空,可以删除目录
170
+ if not any(memory_dir.iterdir()) and memory_dir != self.project_memory_dir:
171
+ memory_dir.rmdir()
172
+
173
+ return {"total": total_count, "removed": removed_count}
174
+
175
+ def execute(self, args: Dict[str, Any]) -> Dict[str, Any]:
176
+ """执行清除记忆操作"""
177
+ try:
178
+ memory_types = args.get("memory_types", [])
179
+ tags = args.get("tags", [])
180
+ memory_ids = args.get("memory_ids", [])
181
+ confirm = args.get("confirm", False)
182
+
183
+ if not confirm:
184
+ return {
185
+ "success": False,
186
+ "stdout": "",
187
+ "stderr": "必须设置 confirm=true 才能执行清除操作",
188
+ }
189
+
190
+ # 确定要清除的记忆类型
191
+ if "all" in memory_types:
192
+ types_to_clear = ["project_long_term", "global_long_term", "short_term"]
193
+ else:
194
+ types_to_clear = memory_types
195
+
196
+ # 统计结果
197
+ results = {}
198
+ total_removed = 0
199
+
200
+ # 清除各类型的记忆
201
+ for memory_type in types_to_clear:
202
+ if memory_type == "short_term":
203
+ result = self._clear_short_term_memories(tags, memory_ids)
204
+ else:
205
+ result = self._clear_long_term_memories(
206
+ memory_type, tags, memory_ids
207
+ )
208
+
209
+ results[memory_type] = result
210
+ total_removed += result["removed"]
211
+
212
+ # 生成结果报告
213
+ PrettyOutput.print(
214
+ f"记忆清除完成,共清除 {total_removed} 条记忆", OutputType.SUCCESS
215
+ )
216
+
217
+ # 详细报告
218
+ report = f"# 记忆清除报告\n\n"
219
+ report += f"**总计清除**: {total_removed} 条记忆\n\n"
220
+
221
+ if tags:
222
+ report += f"**使用标签过滤**: {', '.join(tags)}\n\n"
223
+
224
+ if memory_ids:
225
+ report += f"**指定记忆ID**: {', '.join(memory_ids)}\n\n"
226
+
227
+ report += "## 详细结果\n\n"
228
+
229
+ for memory_type, result in results.items():
230
+ report += f"### {memory_type}\n"
231
+ report += f"- 原有记忆: {result['total']} 条\n"
232
+ report += f"- 已清除: {result['removed']} 条\n"
233
+ report += f"- 剩余: {result['total'] - result['removed']} 条\n\n"
234
+
235
+ # 在终端显示摘要
236
+ for memory_type, result in results.items():
237
+ if result["removed"] > 0:
238
+ PrettyOutput.print(
239
+ f"{memory_type}: 清除了 {result['removed']}/{result['total']} 条记忆",
240
+ OutputType.INFO,
241
+ )
242
+
243
+ return {
244
+ "success": True,
245
+ "stdout": report,
246
+ "stderr": "",
247
+ }
248
+
249
+ except Exception as e:
250
+ error_msg = f"清除记忆失败: {str(e)}"
251
+ PrettyOutput.print(error_msg, OutputType.ERROR)
252
+ return {"success": False, "stdout": "", "stderr": error_msg}
@@ -107,17 +107,13 @@ arguments:
107
107
 
108
108
 
109
109
  class OutputHandlerProtocol(Protocol):
110
- def name(self) -> str:
111
- ...
110
+ def name(self) -> str: ...
112
111
 
113
- def can_handle(self, response: str) -> bool:
114
- ...
112
+ def can_handle(self, response: str) -> bool: ...
115
113
 
116
- def prompt(self) -> str:
117
- ...
114
+ def prompt(self) -> str: ...
118
115
 
119
- def handle(self, response: str, agent: Any) -> Tuple[bool, Any]:
120
- ...
116
+ def handle(self, response: str, agent: Any) -> Tuple[bool, Any]: ...
121
117
 
122
118
 
123
119
  class ToolRegistry(OutputHandlerProtocol):
@@ -138,9 +134,7 @@ class ToolRegistry(OutputHandlerProtocol):
138
134
  try:
139
135
  tools_prompt += " <tool>\n"
140
136
  tools_prompt += f" <name>名称: {tool['name']}</name>\n"
141
- tools_prompt += (
142
- f" <description>描述: {tool['description']}</description>\n"
143
- )
137
+ tools_prompt += f" <description>描述: {tool['description']}</description>\n"
144
138
  tools_prompt += " <parameters>\n"
145
139
  tools_prompt += " <yaml>|\n"
146
140
 
@@ -202,15 +196,15 @@ class ToolRegistry(OutputHandlerProtocol):
202
196
  """从数据目录获取工具调用统计"""
203
197
  from jarvis.jarvis_stats.stats import StatsManager
204
198
  from datetime import datetime, timedelta
205
-
199
+
206
200
  # 获取所有工具的统计数据
207
201
  tool_stats = {}
208
202
  tools = self.get_all_tools()
209
-
203
+
210
204
  # 获取所有历史数据(从很早的时间开始)
211
205
  end_time = datetime.now()
212
206
  start_time = datetime(2000, 1, 1) # 使用一个足够早的时间
213
-
207
+
214
208
  for tool in tools:
215
209
  tool_name = tool["name"]
216
210
  # 获取该工具的统计数据
@@ -218,22 +212,22 @@ class ToolRegistry(OutputHandlerProtocol):
218
212
  metric_name=tool_name,
219
213
  start_time=start_time,
220
214
  end_time=end_time,
221
- tags={"group": "tool"}
215
+ tags={"group": "tool"},
222
216
  )
223
-
217
+
224
218
  # 计算总调用次数
225
219
  if stats_data and "records" in stats_data:
226
220
  total_count = sum(record["value"] for record in stats_data["records"])
227
221
  tool_stats[tool_name] = int(total_count)
228
222
  else:
229
223
  tool_stats[tool_name] = 0
230
-
224
+
231
225
  return tool_stats
232
226
 
233
227
  def _update_tool_stats(self, name: str) -> None:
234
228
  """更新工具调用统计"""
235
229
  from jarvis.jarvis_stats.stats import StatsManager
236
-
230
+
237
231
  StatsManager.increment(name, group="tool")
238
232
 
239
233
  def use_tools(self, name: List[str]) -> None:
@@ -292,7 +286,9 @@ class ToolRegistry(OutputHandlerProtocol):
292
286
  config = yaml.safe_load(open(file_path, "r", encoding="utf-8"))
293
287
  self.register_mcp_tool_by_config(config)
294
288
  except Exception as e:
295
- PrettyOutput.print(f"文件 {file_path} 加载失败: {str(e)}", OutputType.WARNING)
289
+ PrettyOutput.print(
290
+ f"文件 {file_path} 加载失败: {str(e)}", OutputType.WARNING
291
+ )
296
292
 
297
293
  def _load_builtin_tools(self) -> None:
298
294
  """从内置工具目录加载工具"""
@@ -308,8 +304,33 @@ class ToolRegistry(OutputHandlerProtocol):
308
304
 
309
305
  def _load_external_tools(self) -> None:
310
306
  """从jarvis_data/tools和配置的目录加载外部工具"""
307
+ from jarvis.jarvis_utils.config import get_central_tool_repo
308
+
311
309
  tool_dirs = [str(Path(get_data_dir()) / "tools")] + get_tool_load_dirs()
312
310
 
311
+ # 如果配置了中心工具仓库,将其添加到加载路径
312
+ central_repo = get_central_tool_repo()
313
+ if central_repo:
314
+ # 中心工具仓库存储在数据目录下的特定位置
315
+ central_repo_path = os.path.join(get_data_dir(), "central_tool_repo")
316
+ tool_dirs.append(central_repo_path)
317
+
318
+ # 确保中心工具仓库被克隆/更新
319
+ if not os.path.exists(central_repo_path):
320
+ try:
321
+ import subprocess
322
+
323
+ PrettyOutput.print(
324
+ f"正在克隆中心工具仓库: {central_repo}", OutputType.INFO
325
+ )
326
+ subprocess.run(
327
+ ["git", "clone", central_repo, central_repo_path], check=True
328
+ )
329
+ except Exception as e:
330
+ PrettyOutput.print(
331
+ f"克隆中心工具仓库失败: {str(e)}", OutputType.ERROR
332
+ )
333
+
313
334
  # --- 全局每日更新检查 ---
314
335
  daily_check_git_updates(tool_dirs, "tools")
315
336
 
@@ -662,6 +683,10 @@ class ToolRegistry(OutputHandlerProtocol):
662
683
  parameters: 工具参数定义
663
684
  func: 工具执行函数
664
685
  """
686
+ if name in self.tools:
687
+ PrettyOutput.print(
688
+ f"警告: 工具 '{name}' 已存在,将被覆盖", OutputType.WARNING
689
+ )
665
690
  self.tools[name] = Tool(name, description, parameters, func)
666
691
 
667
692
  def get_tool(self, name: str) -> Optional[Tool]:
@@ -735,7 +760,9 @@ class ToolRegistry(OutputHandlerProtocol):
735
760
  """
736
761
  if len(output.splitlines()) > 60:
737
762
  lines = output.splitlines()
738
- return "\n".join(lines[:30] + ["\n...内容太长,已截取前后30行...\n"] + lines[-30:])
763
+ return "\n".join(
764
+ lines[:30] + ["\n...内容太长,已截取前后30行...\n"] + lines[-30:]
765
+ )
739
766
  return output
740
767
 
741
768
  def handle_tool_calls(self, tool_call: Dict[str, Any], agent: Any) -> str:
@@ -22,6 +22,7 @@ class RetrieveMemoryTool:
22
22
  - all: 从所有类型中检索
23
23
 
24
24
  可以通过标签过滤检索结果,支持多个标签(满足任一标签即可)
25
+ 注意:标签数量建议不要超过10个,以保证检索效率
25
26
  """
26
27
 
27
28
  parameters = {
@@ -149,35 +150,40 @@ class RetrieveMemoryTool:
149
150
  # 格式化为Markdown输出
150
151
  markdown_output = f"# 记忆检索结果\n\n"
151
152
  markdown_output += f"**检索到 {len(all_memories)} 条记忆**\n\n"
152
-
153
+
153
154
  if tags:
154
155
  markdown_output += f"**使用标签过滤**: {', '.join(tags)}\n\n"
155
-
156
+
156
157
  markdown_output += f"**记忆类型**: {', '.join(types_to_search)}\n\n"
157
-
158
+
158
159
  markdown_output += "---\n\n"
159
-
160
+
160
161
  # 输出所有记忆
161
162
  for i, memory in enumerate(all_memories):
162
163
  markdown_output += f"## {i+1}. {memory.get('id', '未知ID')}\n\n"
163
164
  markdown_output += f"**类型**: {memory.get('type', '未知类型')}\n\n"
164
165
  markdown_output += f"**标签**: {', '.join(memory.get('tags', []))}\n\n"
165
- markdown_output += f"**创建时间**: {memory.get('created_at', '未知时间')}\n\n"
166
-
166
+ markdown_output += (
167
+ f"**创建时间**: {memory.get('created_at', '未知时间')}\n\n"
168
+ )
169
+
167
170
  # 内容部分
168
- content = memory.get('content', '')
171
+ content = memory.get("content", "")
169
172
  if content:
170
173
  markdown_output += f"**内容**:\n\n{content}\n\n"
171
-
174
+
172
175
  # 如果有额外的元数据
173
- metadata = {k: v for k, v in memory.items()
174
- if k not in ['id', 'type', 'tags', 'created_at', 'content']}
176
+ metadata = {
177
+ k: v
178
+ for k, v in memory.items()
179
+ if k not in ["id", "type", "tags", "created_at", "content"]
180
+ }
175
181
  if metadata:
176
182
  markdown_output += f"**其他信息**:\n"
177
183
  for key, value in metadata.items():
178
184
  markdown_output += f"- {key}: {value}\n"
179
185
  markdown_output += "\n"
180
-
186
+
181
187
  markdown_output += "---\n\n"
182
188
 
183
189
  # 如果记忆较多,在终端显示摘要