beswarm 0.2.39__py3-none-any.whl → 0.2.41__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of beswarm might be problematic. Click here for more details.
- beswarm/agents/planact.py +13 -30
- beswarm/aient/setup.py +1 -1
- beswarm/aient/src/aient/core/request.py +4 -2
- beswarm/aient/src/aient/core/response.py +14 -5
- beswarm/aient/src/aient/models/chatgpt.py +154 -49
- beswarm/aient/src/aient/plugins/write_file.py +6 -1
- beswarm/aient/test/test_API.py +1 -1
- beswarm/taskmanager.py +207 -160
- beswarm/tools/__init__.py +4 -1
- beswarm/tools/completion.py +1 -1
- beswarm/tools/search_arxiv.py +58 -3
- beswarm/tools/subtasks.py +100 -7
- beswarm/tools/write_csv.py +35 -0
- beswarm/utils.py +46 -0
- {beswarm-0.2.39.dist-info → beswarm-0.2.41.dist-info}/METADATA +1 -1
- {beswarm-0.2.39.dist-info → beswarm-0.2.41.dist-info}/RECORD +18 -17
- {beswarm-0.2.39.dist-info → beswarm-0.2.41.dist-info}/WHEEL +0 -0
- {beswarm-0.2.39.dist-info → beswarm-0.2.41.dist-info}/top_level.txt +0 -0
beswarm/taskmanager.py
CHANGED
|
@@ -17,210 +17,257 @@ class TaskStatus(Enum):
|
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
class TaskManager:
|
|
20
|
-
"""
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
20
|
+
"""
|
|
21
|
+
一个带并发控制的异步任务管理器。
|
|
22
|
+
它管理任务的生命周期,并通过一个固定大小的工作者池来控制并发执行的任务数量。
|
|
23
|
+
"""
|
|
24
|
+
def __init__(self, concurrency_limit=3):
|
|
25
|
+
if concurrency_limit <= 0:
|
|
26
|
+
raise ValueError("并发限制必须大于0")
|
|
27
|
+
|
|
28
|
+
self.concurrency_limit = concurrency_limit
|
|
29
|
+
self.tasks_cache = {} # 存储所有任务的状态和元数据, key: task_id
|
|
30
|
+
|
|
31
|
+
self._pending_queue = asyncio.Queue() # 内部待办任务队列
|
|
32
|
+
self._results_queue = asyncio.Queue() # 内部已完成任务结果队列
|
|
33
|
+
self._workers = [] # 持有工作者任务的引用
|
|
34
|
+
self._is_running = False # 标记工作者池是否在运行
|
|
24
35
|
self.root_path = None
|
|
25
|
-
self.
|
|
36
|
+
self.cache_dir = None
|
|
37
|
+
self.task_cache_file = None
|
|
38
|
+
|
|
39
|
+
print(f"TaskManager 初始化,并发限制为: {self.concurrency_limit}")
|
|
26
40
|
|
|
27
41
|
def set_root_path(self, root_path):
|
|
28
|
-
|
|
42
|
+
"""设置工作根目录并加载持久化的任务状态。"""
|
|
43
|
+
if self.root_path is not None:
|
|
29
44
|
return
|
|
30
45
|
self.root_path = Path(root_path)
|
|
31
46
|
self.cache_dir = self.root_path / ".beswarm"
|
|
47
|
+
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
|
32
48
|
self.task_cache_file = self.cache_dir / "tasks.json"
|
|
33
|
-
|
|
34
|
-
self.
|
|
49
|
+
|
|
50
|
+
self._load_tasks_from_cache()
|
|
35
51
|
self.set_task_cache("root_path", str(self.root_path))
|
|
36
|
-
self.resume_all_running_task()
|
|
37
52
|
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
53
|
+
# 启动工作者池
|
|
54
|
+
self.start()
|
|
55
|
+
# 恢复中断的任务
|
|
56
|
+
self.resume_interrupted_tasks()
|
|
57
|
+
|
|
58
|
+
def start(self):
|
|
59
|
+
"""启动并发工作者池。"""
|
|
60
|
+
if self._is_running:
|
|
61
|
+
return
|
|
62
|
+
|
|
63
|
+
self._is_running = True
|
|
64
|
+
for i in range(self.concurrency_limit):
|
|
65
|
+
worker = asyncio.create_task(self._worker_loop(f"Worker-{i+1}"))
|
|
66
|
+
self._workers.append(worker)
|
|
67
|
+
print(f"已启动 {self.concurrency_limit} 个并发工作者。")
|
|
68
|
+
|
|
69
|
+
async def stop(self):
|
|
70
|
+
"""优雅地停止所有工作者。"""
|
|
71
|
+
if not self._is_running:
|
|
72
|
+
return
|
|
73
|
+
|
|
74
|
+
print("\n正在停止 TaskManager...")
|
|
75
|
+
await self._pending_queue.join()
|
|
76
|
+
|
|
77
|
+
for worker in self._workers:
|
|
78
|
+
worker.cancel()
|
|
79
|
+
|
|
80
|
+
await asyncio.gather(*self._workers, return_exceptions=True)
|
|
81
|
+
|
|
82
|
+
self._is_running = False
|
|
83
|
+
print("所有工作者已停止。")
|
|
84
|
+
|
|
85
|
+
async def _worker_loop(self, worker_name: str):
|
|
86
|
+
"""每个工作者的主循环,从队列中拉取并执行任务。"""
|
|
87
|
+
print(f"[{worker_name}] 已就绪,等待任务...")
|
|
88
|
+
while self._is_running:
|
|
89
|
+
try:
|
|
90
|
+
task_id, coro = await self._pending_queue.get()
|
|
91
|
+
|
|
92
|
+
print(f"[{worker_name}] 领到了任务 <{task_id[:8]}>,开始执行...")
|
|
93
|
+
self._update_task_status(task_id, TaskStatus.RUNNING)
|
|
46
94
|
|
|
47
|
-
|
|
48
|
-
|
|
95
|
+
try:
|
|
96
|
+
result = await coro
|
|
97
|
+
self._handle_task_completion(task_id, TaskStatus.DONE, result)
|
|
98
|
+
except Exception as e:
|
|
99
|
+
self._handle_task_completion(task_id, TaskStatus.ERROR, e)
|
|
100
|
+
finally:
|
|
101
|
+
self._pending_queue.task_done()
|
|
49
102
|
|
|
103
|
+
except asyncio.CancelledError:
|
|
104
|
+
print(f"[{worker_name}] 被取消,正在退出...")
|
|
105
|
+
break
|
|
106
|
+
except Exception as e:
|
|
107
|
+
print(f"[{worker_name}] 循环中遇到严重错误: {e}")
|
|
108
|
+
|
|
109
|
+
def _handle_task_completion(self, task_id, status, result):
|
|
110
|
+
"""统一处理任务完成的内部函数。"""
|
|
111
|
+
if status == TaskStatus.DONE:
|
|
112
|
+
print(f"✅ 任务 <{task_id[:8]}> 执行成功。")
|
|
113
|
+
else: # ERROR
|
|
114
|
+
print(f"❌ 任务 <{task_id[:8]}> 执行失败: {result}")
|
|
115
|
+
|
|
116
|
+
self._update_task_status(task_id, status, result=str(result))
|
|
117
|
+
self._results_queue.put_nowait((task_id, status, result))
|
|
118
|
+
|
|
119
|
+
def set_task_cache(self, *keys_and_value):
|
|
120
|
+
"""设置可嵌套的任务缓存。"""
|
|
121
|
+
if len(keys_and_value) < 2: return
|
|
122
|
+
keys, value = keys_and_value[:-1], keys_and_value[-1]
|
|
50
123
|
d = self.tasks_cache
|
|
51
|
-
# 遍历到倒数第二个键,确保路径存在
|
|
52
124
|
for key in keys[:-1]:
|
|
53
125
|
d = d.setdefault(key, {})
|
|
54
|
-
|
|
55
|
-
# 在最后一个键上设置值
|
|
56
126
|
d[keys[-1]] = value
|
|
57
|
-
self.
|
|
127
|
+
self._save_tasks_to_cache()
|
|
58
128
|
|
|
59
|
-
def
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
def read_tasks_cache(self):
|
|
63
|
-
content = self.task_cache_file.read_text(encoding="utf-8")
|
|
129
|
+
def _save_tasks_to_cache(self):
|
|
130
|
+
"""将任务缓存持久化到文件。"""
|
|
131
|
+
if not self.task_cache_file: return
|
|
64
132
|
try:
|
|
65
|
-
self.
|
|
66
|
-
|
|
67
|
-
|
|
133
|
+
with self.task_cache_file.open('w', encoding='utf-8') as f:
|
|
134
|
+
json.dump(self.tasks_cache, f, indent=4, ensure_ascii=False)
|
|
135
|
+
except Exception as e:
|
|
136
|
+
print(f"警告:无法将任务状态持久化到文件: {e}")
|
|
68
137
|
|
|
69
|
-
def
|
|
70
|
-
"""
|
|
71
|
-
|
|
138
|
+
def _load_tasks_from_cache(self):
|
|
139
|
+
"""从文件加载任务缓存。"""
|
|
140
|
+
if not self.task_cache_file or not self.task_cache_file.exists():
|
|
141
|
+
self.tasks_cache = {}
|
|
142
|
+
return
|
|
143
|
+
try:
|
|
144
|
+
content = self.task_cache_file.read_text(encoding='utf-8')
|
|
145
|
+
if content:
|
|
146
|
+
self.tasks_cache = json.loads(content)
|
|
147
|
+
else:
|
|
148
|
+
self.tasks_cache = {}
|
|
149
|
+
except (FileNotFoundError, json.JSONDecodeError):
|
|
150
|
+
self.tasks_cache = {}
|
|
151
|
+
print("警告:任务缓存文件不存在或格式错误,将使用空缓存。")
|
|
72
152
|
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
153
|
+
async def get_next_result(self):
|
|
154
|
+
"""异步获取下一个完成的任务结果。"""
|
|
155
|
+
return await self._results_queue.get()
|
|
76
156
|
|
|
77
|
-
|
|
78
|
-
|
|
157
|
+
def create_tasks_batch(self, task_coro_func, tasks_params_list):
|
|
158
|
+
"""
|
|
159
|
+
批量创建任务,但不是立即执行,而是将它们放入待处理队列。
|
|
79
160
|
"""
|
|
161
|
+
if not self._is_running:
|
|
162
|
+
raise RuntimeError("TaskManager尚未启动。请先调用 start() 方法。")
|
|
163
|
+
|
|
80
164
|
task_ids = []
|
|
81
|
-
for
|
|
82
|
-
|
|
83
|
-
|
|
165
|
+
for params in tasks_params_list:
|
|
166
|
+
task_id = str(uuid.uuid4())
|
|
167
|
+
coro = task_coro_func(**params)
|
|
168
|
+
|
|
169
|
+
# 初始化任务状态为 PENDING
|
|
170
|
+
self._update_task_status(task_id, TaskStatus.PENDING, args=params)
|
|
171
|
+
|
|
172
|
+
# 将任务定义放入队列
|
|
173
|
+
self._pending_queue.put_nowait((task_id, coro))
|
|
84
174
|
task_ids.append(task_id)
|
|
85
|
-
|
|
86
|
-
|
|
175
|
+
|
|
176
|
+
print(f"已将 {len(task_ids)} 个新任务加入待处理队列。队列当前大小: {self._pending_queue.qsize()}")
|
|
87
177
|
return task_ids
|
|
88
178
|
|
|
89
|
-
def
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
task_id = self.resume_task(task_id, registry.tools["worker"], tasks_params)
|
|
179
|
+
def create_tasks(self, task_coro_func, tasks_params_list):
|
|
180
|
+
"""批量将任务放入待处理队列。"""
|
|
181
|
+
if not self._is_running:
|
|
182
|
+
raise RuntimeError("TaskManager尚未启动。请先在 set_root_path 后确保其已启动。")
|
|
94
183
|
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
task = self.tasks_cache.get(task_id)
|
|
100
|
-
if not task:
|
|
101
|
-
return TaskStatus.NOT_FOUND
|
|
184
|
+
task_ids = []
|
|
185
|
+
for params in tasks_params_list:
|
|
186
|
+
task_id = str(uuid.uuid4())
|
|
187
|
+
coro = task_coro_func(**params)
|
|
102
188
|
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
self.set_task_cache(task_id, "status", TaskStatus.RUNNING.value)
|
|
107
|
-
print(f"任务已恢复: ID={task_id}, Name={task_id}")
|
|
108
|
-
print(f"args: {args}")
|
|
109
|
-
print(f"self.tasks_cache: {json.dumps(self.tasks_cache, ensure_ascii=False, indent=4)}")
|
|
110
|
-
return task_id
|
|
189
|
+
self._update_task_status(task_id, TaskStatus.PENDING, args=params)
|
|
190
|
+
self._pending_queue.put_nowait((task_id, coro))
|
|
191
|
+
task_ids.append(task_id)
|
|
111
192
|
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
创建并注册一个新任务。
|
|
193
|
+
print(f"已将 {len(task_ids)} 个新任务加入待处理队列。队列当前大小: {self._pending_queue.qsize()}")
|
|
194
|
+
return task_ids
|
|
115
195
|
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
196
|
+
def resume_interrupted_tasks(self):
|
|
197
|
+
"""在启动时,恢复所有处于 PENDING 或 RUNNING 状态的旧任务。"""
|
|
198
|
+
interrupted_tasks = [
|
|
199
|
+
(tid, info) for tid, info in self.tasks_cache.items()
|
|
200
|
+
if tid != "root_path" and info.get("status") in [TaskStatus.PENDING.value, TaskStatus.RUNNING.value]
|
|
201
|
+
]
|
|
119
202
|
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
"""
|
|
123
|
-
if task_id == None:
|
|
124
|
-
task_id = str(uuid.uuid4())
|
|
125
|
-
task_name = f"Task-{task_id[:8]}"
|
|
203
|
+
if not interrupted_tasks:
|
|
204
|
+
return
|
|
126
205
|
|
|
127
|
-
|
|
128
|
-
|
|
206
|
+
print(f"检测到 {len(interrupted_tasks)} 个中断的任务,正在恢复...")
|
|
207
|
+
worker_fun = registry.tools["worker"]
|
|
129
208
|
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
print(f"任务已创建: ID={task_id}, Name={task_name}")
|
|
137
|
-
return task_id
|
|
209
|
+
for task_id, task_info in interrupted_tasks:
|
|
210
|
+
args = task_info.get("args")
|
|
211
|
+
if not args:
|
|
212
|
+
print(f"警告:任务 <{task_id[:8]}> 缺少参数,无法恢复。")
|
|
213
|
+
self._update_task_status(task_id, TaskStatus.ERROR, result="缺少参数,无法恢复")
|
|
214
|
+
continue
|
|
138
215
|
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
216
|
+
coro = worker_fun(**args)
|
|
217
|
+
self._update_task_status(task_id, TaskStatus.PENDING)
|
|
218
|
+
self._pending_queue.put_nowait((task_id, coro))
|
|
142
219
|
|
|
143
|
-
|
|
144
|
-
task_id (str): 要查询的任务ID。
|
|
220
|
+
print(f"{len(interrupted_tasks)} 个中断的任务已重新加入队列。")
|
|
145
221
|
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
if not task:
|
|
151
|
-
return TaskStatus.NOT_FOUND
|
|
222
|
+
def resume_task(self, task_id, goal):
|
|
223
|
+
"""恢复一个指定的任务,实质上是创建一个新任务并替换旧的记录,但ID保持不变。"""
|
|
224
|
+
if task_id not in self.tasks_cache:
|
|
225
|
+
return f"任务 {task_id} 不存在"
|
|
152
226
|
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
return TaskStatus.ERROR
|
|
158
|
-
else:
|
|
159
|
-
return TaskStatus.DONE
|
|
227
|
+
old_task_info = self.tasks_cache.get(task_id, {})
|
|
228
|
+
tasks_params = old_task_info.get("args", {})
|
|
229
|
+
if not tasks_params:
|
|
230
|
+
return f"<tool_error>任务 {task_id} 缺少参数信息,无法恢复。</tool_error>"
|
|
160
231
|
|
|
161
|
-
|
|
162
|
-
#
|
|
163
|
-
# 这里我们简化处理,认为未完成的就是运行中。
|
|
164
|
-
return TaskStatus.RUNNING
|
|
232
|
+
tasks_params["goal"] = goal
|
|
233
|
+
tasks_params["cache_messages"] = True # 恢复时强制使用缓存
|
|
165
234
|
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
task = self.tasks.get(task_id)
|
|
169
|
-
if self.get_task_status(task_id) == TaskStatus.DONE:
|
|
170
|
-
return task.result()
|
|
171
|
-
elif self.get_task_status(task_id) == TaskStatus.ERROR:
|
|
172
|
-
return task.exception()
|
|
173
|
-
return None
|
|
174
|
-
|
|
175
|
-
def _on_task_done(self, task_id, task):
|
|
176
|
-
"""私有回调函数,在任务完成时将结果放入队列。"""
|
|
177
|
-
try:
|
|
178
|
-
# 将元组 (task_id, status, result) 放入队列
|
|
179
|
-
self.results_queue.put_nowait(
|
|
180
|
-
(task_id, TaskStatus.DONE, task.result())
|
|
181
|
-
)
|
|
182
|
-
self.set_task_cache(task_id, "status", TaskStatus.DONE.value)
|
|
183
|
-
self.set_task_cache(task_id, "result", task.result())
|
|
184
|
-
except asyncio.CancelledError:
|
|
185
|
-
self.results_queue.put_nowait(
|
|
186
|
-
(task_id, TaskStatus.CANCELLED, None)
|
|
187
|
-
)
|
|
188
|
-
self.set_task_cache(task_id, "status", TaskStatus.CANCELLED.value)
|
|
189
|
-
except Exception as e:
|
|
190
|
-
self.results_queue.put_nowait(
|
|
191
|
-
(task_id, TaskStatus.ERROR, e)
|
|
192
|
-
)
|
|
193
|
-
self.set_task_cache(task_id, "status", TaskStatus.ERROR.value)
|
|
194
|
-
self.set_task_cache(task_id, "result", str(e))
|
|
235
|
+
worker_fun = registry.tools["worker"]
|
|
236
|
+
coro = worker_fun(**tasks_params)
|
|
195
237
|
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
等待并返回下一个完成的任务结果。
|
|
238
|
+
self._update_task_status(task_id, TaskStatus.PENDING, args=tasks_params)
|
|
239
|
+
self._pending_queue.put_nowait((task_id, coro))
|
|
199
240
|
|
|
200
|
-
|
|
241
|
+
print(f"任务 <{task_id[:8]}> 已被重新加入队列等待恢复执行。")
|
|
242
|
+
return f"任务 {task_id} 已恢复"
|
|
201
243
|
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
244
|
+
def _update_task_status(self, task_id, status: TaskStatus, args=None, result=None):
|
|
245
|
+
"""统一更新任务状态缓存并持久化。"""
|
|
246
|
+
if task_id not in self.tasks_cache:
|
|
247
|
+
self.tasks_cache[task_id] = {}
|
|
206
248
|
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
249
|
+
current_task = self.tasks_cache[task_id]
|
|
250
|
+
current_task['status'] = status.value
|
|
251
|
+
if args is not None:
|
|
252
|
+
current_task['args'] = args
|
|
253
|
+
if result is not None:
|
|
254
|
+
current_task['result'] = result
|
|
210
255
|
|
|
211
|
-
|
|
212
|
-
task_id (str): 要查询的任务ID。
|
|
256
|
+
self._save_tasks_to_cache()
|
|
213
257
|
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
258
|
+
def get_task_status(self, task_id):
|
|
259
|
+
"""查询特定任务的状态。"""
|
|
260
|
+
task_info = self.tasks_cache.get(task_id)
|
|
261
|
+
if not task_info:
|
|
262
|
+
return TaskStatus.NOT_FOUND
|
|
263
|
+
return TaskStatus(task_info.get("status", "NOT_FOUND"))
|
|
264
|
+
|
|
265
|
+
def get_task_result(self, task_id):
|
|
266
|
+
"""获取已完成任务的结果。"""
|
|
267
|
+
task_info = self.tasks_cache.get(task_id)
|
|
268
|
+
if not task_info or task_info.get("status") not in [TaskStatus.DONE.value, TaskStatus.ERROR.value]:
|
|
269
|
+
return None
|
|
270
|
+
return task_info.get("result")
|
|
224
271
|
|
|
225
272
|
async def main():
|
|
226
273
|
manager = TaskManager()
|
beswarm/tools/__init__.py
CHANGED
|
@@ -3,11 +3,12 @@ from .search_web import search_web
|
|
|
3
3
|
from .completion import task_complete
|
|
4
4
|
from .search_arxiv import search_arxiv
|
|
5
5
|
from .repomap import get_code_repo_map
|
|
6
|
+
from .write_csv import append_row_to_csv
|
|
6
7
|
from .request_input import request_admin_input
|
|
7
8
|
from .screenshot import save_screenshot_to_file
|
|
8
9
|
from .worker import worker, worker_gen, chatgroup
|
|
9
10
|
from .click import find_and_click_element, scroll_screen
|
|
10
|
-
from .subtasks import create_task, resume_task, get_all_tasks_status, get_task_result
|
|
11
|
+
from .subtasks import create_task, resume_task, get_all_tasks_status, get_task_result, create_tasks_from_csv
|
|
11
12
|
|
|
12
13
|
#显式导入 aient.plugins 中的所需内容
|
|
13
14
|
from ..aient.src.aient.plugins import (
|
|
@@ -47,12 +48,14 @@ __all__ = [
|
|
|
47
48
|
"list_directory",
|
|
48
49
|
"get_task_result",
|
|
49
50
|
"get_url_content",
|
|
51
|
+
"append_row_to_csv",
|
|
50
52
|
"set_readonly_path",
|
|
51
53
|
"get_code_repo_map",
|
|
52
54
|
"run_python_script",
|
|
53
55
|
"get_search_results",
|
|
54
56
|
"request_admin_input",
|
|
55
57
|
"get_all_tasks_status",
|
|
58
|
+
"create_tasks_from_csv",
|
|
56
59
|
"find_and_click_element",
|
|
57
60
|
"download_read_arxiv_pdf",
|
|
58
61
|
"save_screenshot_to_file",
|
beswarm/tools/completion.py
CHANGED
beswarm/tools/search_arxiv.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import requests
|
|
2
|
+
import csv
|
|
2
3
|
from datetime import datetime
|
|
3
4
|
from ..aient.src.aient.plugins import register_tool
|
|
4
5
|
|
|
@@ -43,10 +44,11 @@ NoProp: Training Neural Networks without Back-propagation or Forward-propagation
|
|
|
43
44
|
包含搜索结果的字典列表,每个字典包含论文的标题、作者、摘要、发布日期、最后更新日期、arXiv ID、类别和PDF链接等信息
|
|
44
45
|
"""
|
|
45
46
|
try:
|
|
46
|
-
base_url = "
|
|
47
|
+
base_url = "https://export.arxiv.org/api/query"
|
|
47
48
|
|
|
48
49
|
# 构建查询参数
|
|
49
|
-
search_query =
|
|
50
|
+
search_query = query
|
|
51
|
+
# search_query = f"all:{query}"
|
|
50
52
|
|
|
51
53
|
# 添加类别过滤
|
|
52
54
|
if categories:
|
|
@@ -64,6 +66,8 @@ NoProp: Training Neural Networks without Back-propagation or Forward-propagation
|
|
|
64
66
|
else:
|
|
65
67
|
search_query += f" AND au:\"{authors}\""
|
|
66
68
|
|
|
69
|
+
print(search_query)
|
|
70
|
+
|
|
67
71
|
# 添加日期过滤
|
|
68
72
|
# arXiv API不直接支持日期范围过滤,需要在结果中过滤
|
|
69
73
|
|
|
@@ -130,6 +134,27 @@ NoProp: Training Neural Networks without Back-propagation or Forward-propagation
|
|
|
130
134
|
if cat_term not in categories_list:
|
|
131
135
|
categories_list.append(cat_term)
|
|
132
136
|
|
|
137
|
+
# 应用严格的类别过滤,确保论文的所有类别都符合用户的要求
|
|
138
|
+
if categories:
|
|
139
|
+
user_specified_categories = categories if isinstance(categories, list) else [categories]
|
|
140
|
+
|
|
141
|
+
allowed_prefixes = []
|
|
142
|
+
for pattern in user_specified_categories:
|
|
143
|
+
if pattern.endswith('*'):
|
|
144
|
+
allowed_prefixes.append(pattern[:-1])
|
|
145
|
+
else:
|
|
146
|
+
allowed_prefixes.append(pattern)
|
|
147
|
+
|
|
148
|
+
all_paper_categories_match = True
|
|
149
|
+
for paper_cat in categories_list:
|
|
150
|
+
# 检查当前论文的每个分类是否至少匹配一个用户指定的模式前缀
|
|
151
|
+
if not any(paper_cat.startswith(prefix) for prefix in allowed_prefixes):
|
|
152
|
+
all_paper_categories_match = False
|
|
153
|
+
break
|
|
154
|
+
|
|
155
|
+
if not all_paper_categories_match:
|
|
156
|
+
continue # 如果有任何一个分类不匹配,就跳过这篇论文
|
|
157
|
+
|
|
133
158
|
# 获取摘要
|
|
134
159
|
abstract = ""
|
|
135
160
|
if include_abstract:
|
|
@@ -167,9 +192,11 @@ if __name__ == '__main__':
|
|
|
167
192
|
# python -m beswarm.tools.search_arxiv
|
|
168
193
|
test_query = "NoProp"
|
|
169
194
|
test_query = '"Attention Is All You Need"'
|
|
195
|
+
test_query = '(all:"sparse autoencoders" OR all:"sparse autoencoder" OR (all:SAE AND NOT au:SAE))'
|
|
196
|
+
|
|
170
197
|
print(f"使用关键词 '{test_query}' 测试搜索...")
|
|
171
198
|
|
|
172
|
-
search_results = search_arxiv(query=test_query, max_results=
|
|
199
|
+
search_results = search_arxiv(query=test_query, max_results=1000, categories='cs*', sort_by='lastUpdatedDate')
|
|
173
200
|
|
|
174
201
|
if isinstance(search_results, str):
|
|
175
202
|
# 如果返回的是错误信息字符串,则打印错误
|
|
@@ -183,9 +210,37 @@ if __name__ == '__main__':
|
|
|
183
210
|
print(f" 作者: {', '.join(paper['authors'])}")
|
|
184
211
|
print(f" 发布日期: {paper['published_date']}")
|
|
185
212
|
print(f" arXiv ID: {paper['arxiv_id']}")
|
|
213
|
+
print(f" 领域: {paper['categories']}")
|
|
186
214
|
print(f" PDF链接: {paper['pdf_url']}")
|
|
187
215
|
print(f" 摘要: {paper['abstract'][:150]}...") # 打印摘要前150个字符
|
|
188
216
|
print("-" * 20)
|
|
217
|
+
|
|
218
|
+
# 将结果保存到CSV文件
|
|
219
|
+
csv_filename = 'arxiv_search_results.csv'
|
|
220
|
+
print(f"\n正在将 {len(search_results)} 条结果保存到 {csv_filename}...")
|
|
221
|
+
|
|
222
|
+
try:
|
|
223
|
+
with open(csv_filename, mode='w', newline='', encoding='utf-8') as csv_file:
|
|
224
|
+
# 使用第一个数据项的键作为CSV文件的标题
|
|
225
|
+
fieldnames = search_results[0].keys()
|
|
226
|
+
writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
|
|
227
|
+
|
|
228
|
+
writer.writeheader()
|
|
229
|
+
for paper in search_results:
|
|
230
|
+
# 转换列表为字符串以便写入CSV
|
|
231
|
+
paper_for_csv = paper.copy()
|
|
232
|
+
if 'authors' in paper_for_csv and isinstance(paper_for_csv['authors'], list):
|
|
233
|
+
paper_for_csv['authors'] = ', '.join(paper_for_csv['authors'])
|
|
234
|
+
if 'categories' in paper_for_csv and isinstance(paper_for_csv['categories'], list):
|
|
235
|
+
paper_for_csv['categories'] = ', '.join(paper_for_csv['categories'])
|
|
236
|
+
|
|
237
|
+
writer.writerow(paper_for_csv)
|
|
238
|
+
|
|
239
|
+
print(f"结果已成功保存到 {csv_filename}")
|
|
240
|
+
|
|
241
|
+
except IOError as e:
|
|
242
|
+
print(f"错误:无法写入文件 {csv_filename}: {e}")
|
|
243
|
+
|
|
189
244
|
else:
|
|
190
245
|
print("未找到相关论文。")
|
|
191
246
|
else:
|