beswarm 0.2.39-py3-none-any.whl → 0.2.41-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of beswarm might be problematic. Click here for more details.

beswarm/taskmanager.py CHANGED
@@ -17,210 +17,257 @@ class TaskStatus(Enum):
17
17
 
18
18
 
19
19
  class TaskManager:
20
- """一个简单的异步任务管理器"""
21
- def __init__(self):
22
- self.tasks = {} # 使用字典来存储任务,key是task_id, value是task对象
23
- self.results_queue = asyncio.Queue()
20
+ """
21
+ 一个带并发控制的异步任务管理器。
22
+ 它管理任务的生命周期,并通过一个固定大小的工作者池来控制并发执行的任务数量。
23
+ """
24
+ def __init__(self, concurrency_limit=3):
25
+ if concurrency_limit <= 0:
26
+ raise ValueError("并发限制必须大于0")
27
+
28
+ self.concurrency_limit = concurrency_limit
29
+ self.tasks_cache = {} # 存储所有任务的状态和元数据, key: task_id
30
+
31
+ self._pending_queue = asyncio.Queue() # 内部待办任务队列
32
+ self._results_queue = asyncio.Queue() # 内部已完成任务结果队列
33
+ self._workers = [] # 持有工作者任务的引用
34
+ self._is_running = False # 标记工作者池是否在运行
24
35
  self.root_path = None
25
- self.tasks_cache = {}
36
+ self.cache_dir = None
37
+ self.task_cache_file = None
38
+
39
+ print(f"TaskManager 初始化,并发限制为: {self.concurrency_limit}")
26
40
 
27
41
  def set_root_path(self, root_path):
28
- if self.root_path:
42
+ """设置工作根目录并加载持久化的任务状态。"""
43
+ if self.root_path is not None:
29
44
  return
30
45
  self.root_path = Path(root_path)
31
46
  self.cache_dir = self.root_path / ".beswarm"
47
+ self.cache_dir.mkdir(parents=True, exist_ok=True)
32
48
  self.task_cache_file = self.cache_dir / "tasks.json"
33
- self.task_cache_file.touch(exist_ok=True)
34
- self.read_tasks_cache()
49
+
50
+ self._load_tasks_from_cache()
35
51
  self.set_task_cache("root_path", str(self.root_path))
36
- self.resume_all_running_task()
37
52
 
38
- def set_task_cache(self, *keys_and_value):
39
- """
40
- 设置可嵌套的任务缓存。
41
- 接受无限个键和一个值,例如 set_task_cache('a', 'b', 'c', value)
42
- 会转换为 tasks_cache['a']['b']['c'] = value
43
- """
44
- if len(keys_and_value) < 2:
45
- return # 至少需要一个键和一个值
53
+ # 启动工作者池
54
+ self.start()
55
+ # 恢复中断的任务
56
+ self.resume_interrupted_tasks()
57
+
58
+ def start(self):
59
+ """启动并发工作者池。"""
60
+ if self._is_running:
61
+ return
62
+
63
+ self._is_running = True
64
+ for i in range(self.concurrency_limit):
65
+ worker = asyncio.create_task(self._worker_loop(f"Worker-{i+1}"))
66
+ self._workers.append(worker)
67
+ print(f"已启动 {self.concurrency_limit} 个并发工作者。")
68
+
69
+ async def stop(self):
70
+ """优雅地停止所有工作者。"""
71
+ if not self._is_running:
72
+ return
73
+
74
+ print("\n正在停止 TaskManager...")
75
+ await self._pending_queue.join()
76
+
77
+ for worker in self._workers:
78
+ worker.cancel()
79
+
80
+ await asyncio.gather(*self._workers, return_exceptions=True)
81
+
82
+ self._is_running = False
83
+ print("所有工作者已停止。")
84
+
85
+ async def _worker_loop(self, worker_name: str):
86
+ """每个工作者的主循环,从队列中拉取并执行任务。"""
87
+ print(f"[{worker_name}] 已就绪,等待任务...")
88
+ while self._is_running:
89
+ try:
90
+ task_id, coro = await self._pending_queue.get()
91
+
92
+ print(f"[{worker_name}] 领到了任务 <{task_id[:8]}>,开始执行...")
93
+ self._update_task_status(task_id, TaskStatus.RUNNING)
46
94
 
47
- keys = keys_and_value[:-1]
48
- value = keys_and_value[-1]
95
+ try:
96
+ result = await coro
97
+ self._handle_task_completion(task_id, TaskStatus.DONE, result)
98
+ except Exception as e:
99
+ self._handle_task_completion(task_id, TaskStatus.ERROR, e)
100
+ finally:
101
+ self._pending_queue.task_done()
49
102
 
103
+ except asyncio.CancelledError:
104
+ print(f"[{worker_name}] 被取消,正在退出...")
105
+ break
106
+ except Exception as e:
107
+ print(f"[{worker_name}] 循环中遇到严重错误: {e}")
108
+
109
+ def _handle_task_completion(self, task_id, status, result):
110
+ """统一处理任务完成的内部函数。"""
111
+ if status == TaskStatus.DONE:
112
+ print(f"✅ 任务 <{task_id[:8]}> 执行成功。")
113
+ else: # ERROR
114
+ print(f"❌ 任务 <{task_id[:8]}> 执行失败: {result}")
115
+
116
+ self._update_task_status(task_id, status, result=str(result))
117
+ self._results_queue.put_nowait((task_id, status, result))
118
+
119
+ def set_task_cache(self, *keys_and_value):
120
+ """设置可嵌套的任务缓存。"""
121
+ if len(keys_and_value) < 2: return
122
+ keys, value = keys_and_value[:-1], keys_and_value[-1]
50
123
  d = self.tasks_cache
51
- # 遍历到倒数第二个键,确保路径存在
52
124
  for key in keys[:-1]:
53
125
  d = d.setdefault(key, {})
54
-
55
- # 在最后一个键上设置值
56
126
  d[keys[-1]] = value
57
- self.save_tasks_cache()
127
+ self._save_tasks_to_cache()
58
128
 
59
- def save_tasks_cache(self):
60
- self.task_cache_file.write_text(json.dumps(self.tasks_cache, ensure_ascii=False, indent=4), encoding="utf-8")
61
-
62
- def read_tasks_cache(self):
63
- content = self.task_cache_file.read_text(encoding="utf-8")
129
+ def _save_tasks_to_cache(self):
130
+ """将任务缓存持久化到文件。"""
131
+ if not self.task_cache_file: return
64
132
  try:
65
- self.tasks_cache = json.loads(content) if content else {}
66
- except json.JSONDecodeError:
67
- raise ValueError("任务缓存文件格式错误")
133
+ with self.task_cache_file.open('w', encoding='utf-8') as f:
134
+ json.dump(self.tasks_cache, f, indent=4, ensure_ascii=False)
135
+ except Exception as e:
136
+ print(f"警告:无法将任务状态持久化到文件: {e}")
68
137
 
69
- def create_tasks(self, task_coro, tasks_params):
70
- """
71
- 批量创建并注册任务。
138
+ def _load_tasks_from_cache(self):
139
+ """从文件加载任务缓存。"""
140
+ if not self.task_cache_file or not self.task_cache_file.exists():
141
+ self.tasks_cache = {}
142
+ return
143
+ try:
144
+ content = self.task_cache_file.read_text(encoding='utf-8')
145
+ if content:
146
+ self.tasks_cache = json.loads(content)
147
+ else:
148
+ self.tasks_cache = {}
149
+ except (FileNotFoundError, json.JSONDecodeError):
150
+ self.tasks_cache = {}
151
+ print("警告:任务缓存文件不存在或格式错误,将使用空缓存。")
72
152
 
73
- Args:
74
- task_coro: 用于创建任务的协程函数。
75
- tasks_params (list): 包含任务参数的列表。
153
+ async def get_next_result(self):
154
+ """异步获取下一个完成的任务结果。"""
155
+ return await self._results_queue.get()
76
156
 
77
- Returns:
78
- list: 创建的任务ID列表。
157
+ def create_tasks_batch(self, task_coro_func, tasks_params_list):
158
+ """
159
+ 批量创建任务,但不是立即执行,而是将它们放入待处理队列。
79
160
  """
161
+ if not self._is_running:
162
+ raise RuntimeError("TaskManager尚未启动。请先调用 start() 方法。")
163
+
80
164
  task_ids = []
81
- for args in tasks_params:
82
- coro = task_coro(**args)
83
- task_id = self.create_task(coro)
165
+ for params in tasks_params_list:
166
+ task_id = str(uuid.uuid4())
167
+ coro = task_coro_func(**params)
168
+
169
+ # 初始化任务状态为 PENDING
170
+ self._update_task_status(task_id, TaskStatus.PENDING, args=params)
171
+
172
+ # 将任务定义放入队列
173
+ self._pending_queue.put_nowait((task_id, coro))
84
174
  task_ids.append(task_id)
85
- self.set_task_cache(task_id, "args", args)
86
- self.set_task_cache(task_id, "status", TaskStatus.RUNNING.value)
175
+
176
+ print(f"已将 {len(task_ids)} 个新任务加入待处理队列。队列当前大小: {self._pending_queue.qsize()}")
87
177
  return task_ids
88
178
 
89
- def resume_all_running_task(self):
90
- running_task_id_list = [task_id for task_id, task in self.tasks_cache.items() if task_id != "root_path" and task.get("status") == "RUNNING"]
91
- for task_id in running_task_id_list:
92
- tasks_params = self.tasks_cache[task_id]["args"]
93
- task_id = self.resume_task(task_id, registry.tools["worker"], tasks_params)
179
+ def create_tasks(self, task_coro_func, tasks_params_list):
180
+ """批量将任务放入待处理队列。"""
181
+ if not self._is_running:
182
+ raise RuntimeError("TaskManager尚未启动。请先在 set_root_path 后确保其已启动。")
94
183
 
95
- def resume_task(self, task_id, task_coro, args):
96
- """
97
- 恢复一个任务。
98
- """
99
- task = self.tasks_cache.get(task_id)
100
- if not task:
101
- return TaskStatus.NOT_FOUND
184
+ task_ids = []
185
+ for params in tasks_params_list:
186
+ task_id = str(uuid.uuid4())
187
+ coro = task_coro_func(**params)
102
188
 
103
- coro = task_coro(**args)
104
- task_id = self.create_task(coro, task_id)
105
- self.set_task_cache(task_id, "args", args)
106
- self.set_task_cache(task_id, "status", TaskStatus.RUNNING.value)
107
- print(f"任务已恢复: ID={task_id}, Name={task_id}")
108
- print(f"args: {args}")
109
- print(f"self.tasks_cache: {json.dumps(self.tasks_cache, ensure_ascii=False, indent=4)}")
110
- return task_id
189
+ self._update_task_status(task_id, TaskStatus.PENDING, args=params)
190
+ self._pending_queue.put_nowait((task_id, coro))
191
+ task_ids.append(task_id)
111
192
 
112
- def create_task(self, coro, task_id=None):
113
- """
114
- 创建并注册一个新任务。
193
+ print(f"已将 {len(task_ids)} 个新任务加入待处理队列。队列当前大小: {self._pending_queue.qsize()}")
194
+ return task_ids
115
195
 
116
- Args:
117
- coro: 要执行的协程。
118
- name (str, optional): 任务的可读名称。 Defaults to None.
196
+ def resume_interrupted_tasks(self):
197
+ """在启动时,恢复所有处于 PENDING 或 RUNNING 状态的旧任务。"""
198
+ interrupted_tasks = [
199
+ (tid, info) for tid, info in self.tasks_cache.items()
200
+ if tid != "root_path" and info.get("status") in [TaskStatus.PENDING.value, TaskStatus.RUNNING.value]
201
+ ]
119
202
 
120
- Returns:
121
- str: 任务的唯一ID。
122
- """
123
- if task_id == None:
124
- task_id = str(uuid.uuid4())
125
- task_name = f"Task-{task_id[:8]}"
203
+ if not interrupted_tasks:
204
+ return
126
205
 
127
- # 使用 asyncio.create_task() 创建任务
128
- task = asyncio.create_task(coro, name=task_name)
206
+ print(f"检测到 {len(interrupted_tasks)} 个中断的任务,正在恢复...")
207
+ worker_fun = registry.tools["worker"]
129
208
 
130
- # 将任务存储在管理器中
131
- # 当任务完成时,通过回调函数将结果放入队列
132
- task.add_done_callback(
133
- lambda t: self._on_task_done(task_id, t)
134
- )
135
- self.tasks[task_id] = task
136
- print(f"任务已创建: ID={task_id}, Name={task_name}")
137
- return task_id
209
+ for task_id, task_info in interrupted_tasks:
210
+ args = task_info.get("args")
211
+ if not args:
212
+ print(f"警告:任务 <{task_id[:8]}> 缺少参数,无法恢复。")
213
+ self._update_task_status(task_id, TaskStatus.ERROR, result="缺少参数,无法恢复")
214
+ continue
138
215
 
139
- def get_task_status(self, task_id):
140
- """
141
- 查询特定任务的状态。
216
+ coro = worker_fun(**args)
217
+ self._update_task_status(task_id, TaskStatus.PENDING)
218
+ self._pending_queue.put_nowait((task_id, coro))
142
219
 
143
- Args:
144
- task_id (str): 要查询的任务ID。
220
+ print(f"{len(interrupted_tasks)} 个中断的任务已重新加入队列。")
145
221
 
146
- Returns:
147
- TaskStatus: 任务的当前状态。
148
- """
149
- task = self.tasks.get(task_id)
150
- if not task:
151
- return TaskStatus.NOT_FOUND
222
+ def resume_task(self, task_id, goal):
223
+ """恢复一个指定的任务,实质上是创建一个新任务并替换旧的记录,但ID保持不变。"""
224
+ if task_id not in self.tasks_cache:
225
+ return f"任务 {task_id} 不存在"
152
226
 
153
- if task.done():
154
- if task.cancelled():
155
- return TaskStatus.CANCELLED
156
- elif task.exception() is not None:
157
- return TaskStatus.ERROR
158
- else:
159
- return TaskStatus.DONE
227
+ old_task_info = self.tasks_cache.get(task_id, {})
228
+ tasks_params = old_task_info.get("args", {})
229
+ if not tasks_params:
230
+ return f"<tool_error>任务 {task_id} 缺少参数信息,无法恢复。</tool_error>"
160
231
 
161
- # asyncio.Task 没有直接的 'RUNNING' 状态。
162
- # 如果任务还没有完成,它要么是等待执行(PENDING),要么是正在执行(RUNNING)。
163
- # 这里我们简化处理,认为未完成的就是运行中。
164
- return TaskStatus.RUNNING
232
+ tasks_params["goal"] = goal
233
+ tasks_params["cache_messages"] = True # 恢复时强制使用缓存
165
234
 
166
- def get_task_result(self, task_id):
167
- """获取已完成任务的结果,如果任务未完成或出错则返回相应信息。"""
168
- task = self.tasks.get(task_id)
169
- if self.get_task_status(task_id) == TaskStatus.DONE:
170
- return task.result()
171
- elif self.get_task_status(task_id) == TaskStatus.ERROR:
172
- return task.exception()
173
- return None
174
-
175
- def _on_task_done(self, task_id, task):
176
- """私有回调函数,在任务完成时将结果放入队列。"""
177
- try:
178
- # 将元组 (task_id, status, result) 放入队列
179
- self.results_queue.put_nowait(
180
- (task_id, TaskStatus.DONE, task.result())
181
- )
182
- self.set_task_cache(task_id, "status", TaskStatus.DONE.value)
183
- self.set_task_cache(task_id, "result", task.result())
184
- except asyncio.CancelledError:
185
- self.results_queue.put_nowait(
186
- (task_id, TaskStatus.CANCELLED, None)
187
- )
188
- self.set_task_cache(task_id, "status", TaskStatus.CANCELLED.value)
189
- except Exception as e:
190
- self.results_queue.put_nowait(
191
- (task_id, TaskStatus.ERROR, e)
192
- )
193
- self.set_task_cache(task_id, "status", TaskStatus.ERROR.value)
194
- self.set_task_cache(task_id, "result", str(e))
235
+ worker_fun = registry.tools["worker"]
236
+ coro = worker_fun(**tasks_params)
195
237
 
196
- async def get_next_result(self):
197
- """
198
- 等待并返回下一个完成的任务结果。
238
+ self._update_task_status(task_id, TaskStatus.PENDING, args=tasks_params)
239
+ self._pending_queue.put_nowait((task_id, coro))
199
240
 
200
- 如果所有任务都已提交,但没有任务完成,此方法将异步等待。
241
+ print(f"任务 <{task_id[:8]}> 已被重新加入队列等待恢复执行。")
242
+ return f"任务 {task_id} 已恢复"
201
243
 
202
- Returns:
203
- tuple: 一个包含 (task_id, status, result) 的元组。
204
- """
205
- return await self.results_queue.get()
244
+ def _update_task_status(self, task_id, status: TaskStatus, args=None, result=None):
245
+ """统一更新任务状态缓存并持久化。"""
246
+ if task_id not in self.tasks_cache:
247
+ self.tasks_cache[task_id] = {}
206
248
 
207
- def get_task_index(self, task_id):
208
- """
209
- 获取任务在任务字典中的插入顺序索引。
249
+ current_task = self.tasks_cache[task_id]
250
+ current_task['status'] = status.value
251
+ if args is not None:
252
+ current_task['args'] = args
253
+ if result is not None:
254
+ current_task['result'] = result
210
255
 
211
- Args:
212
- task_id (str): 要查询的任务ID。
256
+ self._save_tasks_to_cache()
213
257
 
214
- Returns:
215
- int: 任务的索引(从0开始),如果未找到则返回-1。
216
- """
217
- try:
218
- # 将字典的键转换为列表并查找索引
219
- task_ids_list = list(self.tasks.keys())
220
- return task_ids_list.index(task_id)
221
- except ValueError:
222
- # 如果任务ID不存在,则返回-1
223
- return -1
258
+ def get_task_status(self, task_id):
259
+ """查询特定任务的状态。"""
260
+ task_info = self.tasks_cache.get(task_id)
261
+ if not task_info:
262
+ return TaskStatus.NOT_FOUND
263
+ return TaskStatus(task_info.get("status", "NOT_FOUND"))
264
+
265
+ def get_task_result(self, task_id):
266
+ """获取已完成任务的结果。"""
267
+ task_info = self.tasks_cache.get(task_id)
268
+ if not task_info or task_info.get("status") not in [TaskStatus.DONE.value, TaskStatus.ERROR.value]:
269
+ return None
270
+ return task_info.get("result")
224
271
 
225
272
  async def main():
226
273
  manager = TaskManager()
beswarm/tools/__init__.py CHANGED
@@ -3,11 +3,12 @@ from .search_web import search_web
3
3
  from .completion import task_complete
4
4
  from .search_arxiv import search_arxiv
5
5
  from .repomap import get_code_repo_map
6
+ from .write_csv import append_row_to_csv
6
7
  from .request_input import request_admin_input
7
8
  from .screenshot import save_screenshot_to_file
8
9
  from .worker import worker, worker_gen, chatgroup
9
10
  from .click import find_and_click_element, scroll_screen
10
- from .subtasks import create_task, resume_task, get_all_tasks_status, get_task_result
11
+ from .subtasks import create_task, resume_task, get_all_tasks_status, get_task_result, create_tasks_from_csv
11
12
 
12
13
  #显式导入 aient.plugins 中的所需内容
13
14
  from ..aient.src.aient.plugins import (
@@ -47,12 +48,14 @@ __all__ = [
47
48
  "list_directory",
48
49
  "get_task_result",
49
50
  "get_url_content",
51
+ "append_row_to_csv",
50
52
  "set_readonly_path",
51
53
  "get_code_repo_map",
52
54
  "run_python_script",
53
55
  "get_search_results",
54
56
  "request_admin_input",
55
57
  "get_all_tasks_status",
58
+ "create_tasks_from_csv",
56
59
  "find_and_click_element",
57
60
  "download_read_arxiv_pdf",
58
61
  "save_screenshot_to_file",
@@ -9,7 +9,7 @@ def task_complete(message: str) -> str:
9
9
  它标志着一个任务的成功结束,并将最终的输出传递给用户或调用者。
10
10
 
11
11
  Args:
12
- message (str): 任务完成的信息或最终结果。
12
+ message (str): 任务完成的信息或最终结果。必填字段。
13
13
 
14
14
  Returns:
15
15
  str: 传入的任务完成信息。
@@ -1,4 +1,5 @@
1
1
  import requests
2
+ import csv
2
3
  from datetime import datetime
3
4
  from ..aient.src.aient.plugins import register_tool
4
5
 
@@ -43,10 +44,11 @@ NoProp: Training Neural Networks without Back-propagation or Forward-propagation
43
44
  包含搜索结果的字典列表,每个字典包含论文的标题、作者、摘要、发布日期、最后更新日期、arXiv ID、类别和PDF链接等信息
44
45
  """
45
46
  try:
46
- base_url = "http://export.arxiv.org/api/query"
47
+ base_url = "https://export.arxiv.org/api/query"
47
48
 
48
49
  # 构建查询参数
49
- search_query = f"all:{query}"
50
+ search_query = query
51
+ # search_query = f"all:{query}"
50
52
 
51
53
  # 添加类别过滤
52
54
  if categories:
@@ -64,6 +66,8 @@ NoProp: Training Neural Networks without Back-propagation or Forward-propagation
64
66
  else:
65
67
  search_query += f" AND au:\"{authors}\""
66
68
 
69
+ print(search_query)
70
+
67
71
  # 添加日期过滤
68
72
  # arXiv API不直接支持日期范围过滤,需要在结果中过滤
69
73
 
@@ -130,6 +134,27 @@ NoProp: Training Neural Networks without Back-propagation or Forward-propagation
130
134
  if cat_term not in categories_list:
131
135
  categories_list.append(cat_term)
132
136
 
137
+ # 应用严格的类别过滤,确保论文的所有类别都符合用户的要求
138
+ if categories:
139
+ user_specified_categories = categories if isinstance(categories, list) else [categories]
140
+
141
+ allowed_prefixes = []
142
+ for pattern in user_specified_categories:
143
+ if pattern.endswith('*'):
144
+ allowed_prefixes.append(pattern[:-1])
145
+ else:
146
+ allowed_prefixes.append(pattern)
147
+
148
+ all_paper_categories_match = True
149
+ for paper_cat in categories_list:
150
+ # 检查当前论文的每个分类是否至少匹配一个用户指定的模式前缀
151
+ if not any(paper_cat.startswith(prefix) for prefix in allowed_prefixes):
152
+ all_paper_categories_match = False
153
+ break
154
+
155
+ if not all_paper_categories_match:
156
+ continue # 如果有任何一个分类不匹配,就跳过这篇论文
157
+
133
158
  # 获取摘要
134
159
  abstract = ""
135
160
  if include_abstract:
@@ -167,9 +192,11 @@ if __name__ == '__main__':
167
192
  # python -m beswarm.tools.search_arxiv
168
193
  test_query = "NoProp"
169
194
  test_query = '"Attention Is All You Need"'
195
+ test_query = '(all:"sparse autoencoders" OR all:"sparse autoencoder" OR (all:SAE AND NOT au:SAE))'
196
+
170
197
  print(f"使用关键词 '{test_query}' 测试搜索...")
171
198
 
172
- search_results = search_arxiv(query=test_query, max_results=50, sort_by='lastUpdatedDate')
199
+ search_results = search_arxiv(query=test_query, max_results=1000, categories='cs*', sort_by='lastUpdatedDate')
173
200
 
174
201
  if isinstance(search_results, str):
175
202
  # 如果返回的是错误信息字符串,则打印错误
@@ -183,9 +210,37 @@ if __name__ == '__main__':
183
210
  print(f" 作者: {', '.join(paper['authors'])}")
184
211
  print(f" 发布日期: {paper['published_date']}")
185
212
  print(f" arXiv ID: {paper['arxiv_id']}")
213
+ print(f" 领域: {paper['categories']}")
186
214
  print(f" PDF链接: {paper['pdf_url']}")
187
215
  print(f" 摘要: {paper['abstract'][:150]}...") # 打印摘要前150个字符
188
216
  print("-" * 20)
217
+
218
+ # 将结果保存到CSV文件
219
+ csv_filename = 'arxiv_search_results.csv'
220
+ print(f"\n正在将 {len(search_results)} 条结果保存到 {csv_filename}...")
221
+
222
+ try:
223
+ with open(csv_filename, mode='w', newline='', encoding='utf-8') as csv_file:
224
+ # 使用第一个数据项的键作为CSV文件的标题
225
+ fieldnames = search_results[0].keys()
226
+ writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
227
+
228
+ writer.writeheader()
229
+ for paper in search_results:
230
+ # 转换列表为字符串以便写入CSV
231
+ paper_for_csv = paper.copy()
232
+ if 'authors' in paper_for_csv and isinstance(paper_for_csv['authors'], list):
233
+ paper_for_csv['authors'] = ', '.join(paper_for_csv['authors'])
234
+ if 'categories' in paper_for_csv and isinstance(paper_for_csv['categories'], list):
235
+ paper_for_csv['categories'] = ', '.join(paper_for_csv['categories'])
236
+
237
+ writer.writerow(paper_for_csv)
238
+
239
+ print(f"结果已成功保存到 {csv_filename}")
240
+
241
+ except IOError as e:
242
+ print(f"错误:无法写入文件 {csv_filename}: {e}")
243
+
189
244
  else:
190
245
  print("未找到相关论文。")
191
246
  else: