arxiv-pulse 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,356 @@
1
+ import openai
2
+ import json
3
+ import logging
4
+ from typing import List, Dict, Any, Optional
5
+ from tqdm import tqdm
6
+ import time
7
+ import re
8
+
9
+ from arxiv_pulse.models import Database, Paper
10
+ from arxiv_pulse.config import Config
11
+ from arxiv_pulse.output_manager import output
12
+
13
+ # Use the root logger's configuration (kept for backward compatibility)
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class PaperSummarizer:
    """Generate and persist structured summaries for stored papers.

    When ``Config.AI_API_KEY`` is set, each paper is summarized through an
    OpenAI-compatible chat-completions endpoint. If no key is configured, or
    the call fails, a local heuristic summary is used instead. Summaries are
    stored as JSON strings through ``self.db.update_paper``. Token counters
    accumulate for the lifetime of the instance.
    """

    # Filler words that extract_keywords() never reports.
    _STOPWORDS = frozenset(
        {
            "this",
            "that",
            "with",
            "from",
            "have",
            "which",
            "there",
            "their",
            "about",
            "using",
            "based",
            "approach",
            "method",
            "study",
            "paper",
            "research",
            "results",
            "show",
            "find",
            "found",
            "propose",
            "proposed",
        }
    )

    # User prompt sent to the model; ``title`` and ``abstract`` are filled in
    # per paper. The text asks for a JSON object with the fields
    # key_findings, methodology, relevance, impact and keywords.
    _PROMPT_TEMPLATE = (
        "\n"
        "请用结构化的格式总结以下研究论文:\n"
        "\n"
        "标题: {title}\n"
        "\n"
        "摘要: {abstract}\n"
        "\n"
        "请提供:\n"
        "1. 关键发现/贡献(要点列表)\n"
        "2. 使用的方法论/方法\n"
        "3. 与凝聚态物理、DFT、机器学习或力场的相关性\n"
        "4. 潜在影响/重要性\n"
        "\n"
        "请将回答格式化为JSON对象,包含以下字段:\n"
        "- key_findings: 字符串数组\n"
        "- methodology: 字符串\n"
        "- relevance: 字符串\n"
        "- impact: 字符串\n"
        "- keywords: 相关关键词数组(5-10个)\n"
    )

    def __init__(self):
        """Create the database handle, reset token counters and quiet HTTP logs."""
        self.db = Database()
        self.config = Config

        # Running token totals across every call made through this instance.
        self.total_prompt_tokens = 0
        self.total_completion_tokens = 0
        self.total_tokens = 0

        # Suppress verbose logging from the third-party HTTP libraries.
        for noisy_logger in ("httpx", "httpcore"):
            logging.getLogger(noisy_logger).setLevel(logging.WARNING)

        # With openai 2.x the client is built per request, so there is nothing
        # to configure here; only warn when the AI path is unavailable.
        if not self.config.AI_API_KEY:
            output.warn("AI API密钥未设置,使用基础总结")

    def extract_keywords(self, text: str, max_keywords: int = 10) -> List[str]:
        """Return the most frequent non-stop-words of ``text``.

        Words are runs of at least four ASCII letters in the lower-cased text.
        They are ranked by frequency; ties keep first-appearance order.

        Args:
            text: Text to scan.
            max_keywords: Maximum number of keywords returned; values below
                one yield an empty list.

        Returns:
            Up to ``max_keywords`` lower-case keywords, most frequent first.
        """
        word_freq: Dict[str, int] = {}
        for word in re.findall(r"\b[a-z]{4,}\b", text.lower()):
            if word not in self._STOPWORDS:
                word_freq[word] = word_freq.get(word, 0) + 1

        # sorted() is stable, so equally frequent words stay in text order.
        ranked = sorted(word_freq.items(), key=lambda item: item[1], reverse=True)
        return [word for word, _ in ranked[: max(max_keywords, 0)]]

    def basic_summary(self, paper: "Paper") -> str:
        """Build a summary without any AI call.

        The summary text is the first three sentences of the abstract when it
        has more than three, otherwise the abstract itself (cut to 500
        characters plus ``...`` when longer).

        Args:
            paper: Object exposing ``title`` and ``abstract``; either may be
                empty or ``None``.

        Returns:
            JSON string with the keys ``summary``, ``keywords``, ``method``
            (always ``"basic"``) and ``key_findings`` (always empty).
        """
        abstract_str = str(paper.abstract).strip() if paper.abstract else ""
        title_str = str(paper.title) if paper.title else ""

        # Split on whitespace that follows sentence-ending punctuation, so the
        # punctuation is kept and decimals such as "0.5" are not broken apart.
        sentences = [s for s in re.split(r"(?<=[.!?])\s+", abstract_str) if s]
        if len(sentences) > 3:
            summary = " ".join(sentences[:3])
        elif len(abstract_str) > 500:
            summary = abstract_str[:500] + "..."
        else:
            summary = abstract_str

        return json.dumps(
            {
                # Previously the computed extract was discarded and an empty
                # string was stored instead.
                "summary": summary,
                "keywords": self.extract_keywords(f"{title_str} {abstract_str}"),
                "method": "basic",
                "key_findings": [],
            }
        )

    @staticmethod
    def _clean_json_response(text: str) -> str:
        """Strip Markdown code-fence markers from a model reply."""
        text = text.strip()
        # Prefer a fenced block explicitly tagged as json ...
        tagged = re.search(r"```json\s*(.*?)\s*```", text, re.DOTALL)
        if tagged:
            return tagged.group(1).strip()
        # ... then any fenced block without a language tag.
        untagged = re.search(r"```\s*(.*?)\s*```", text, re.DOTALL)
        if untagged:
            return untagged.group(1).strip()
        # Unbalanced fences: peel off whatever marker is present.
        if text.startswith("```json"):
            text = text[7:].strip()
        if text.startswith("```"):
            text = text[3:].strip()
        if text.endswith("```"):
            text = text[:-3].strip()
        return text

    def _record_token_usage(self, response: Any, prompt: str) -> None:
        """Add the token usage of one API response to the running totals and report it."""
        usage = getattr(response, "usage", None)
        if usage:
            self.total_prompt_tokens += usage.prompt_tokens
            self.total_completion_tokens += usage.completion_tokens
            self.total_tokens += usage.total_tokens
            output.info(
                f"Token 使用: 本次 提示 {usage.prompt_tokens}, 完成 {usage.completion_tokens}, 总计 {usage.total_tokens} | "
                f"累计 提示 {self.total_prompt_tokens}, 完成 {self.total_completion_tokens}, 总计 {self.total_tokens}"
            )
            return

        # No usage data in the response: estimate at roughly 4 characters per
        # token for the prompt plus half of the completion budget.
        estimated_tokens = len(prompt) // 4 + self.config.SUMMARY_MAX_TOKENS // 2
        self.total_tokens += estimated_tokens
        output.info(f"Token 使用: 估算约 {estimated_tokens} tokens | 累计总计 {self.total_tokens} tokens")

    def _parse_summary(self, result: str, title: str, abstract: str) -> Dict[str, Any]:
        """Turn a model reply into a summary dict, wrapping it when it is not a JSON object."""
        # Try the fence-stripped text first, then the raw reply.
        for candidate in (self._clean_json_response(result), result):
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, dict):
                return parsed

        # Not a JSON object: keep the reply text instead of discarding it.
        return {
            "summary": result.strip(),
            "key_findings": [],
            "methodology": "",
            "relevance": "",
            "impact": "",
            "keywords": self.extract_keywords(f"{title} {abstract}"),
        }

    def deepseek_summary(self, paper: "Paper") -> Optional[str]:
        """Summarize ``paper`` through the configured chat-completions API.

        Args:
            paper: Object exposing ``title``, ``abstract`` and ``arxiv_id``.

        Returns:
            JSON string of the summary, or ``None`` when no API key is
            configured, the request fails, or the reply has no content.
            Callers treat ``None`` as "use the basic summary".
        """
        if not self.config.AI_API_KEY:
            return None

        # A missing field becomes empty text rather than the literal "None".
        title = paper.title or ""
        abstract = paper.abstract or ""
        prompt = self._PROMPT_TEMPLATE.format(title=title, abstract=abstract)

        try:
            output.do(f"总结论文: {paper.arxiv_id}")

            # Imported here so the client library is only needed on the AI path.
            import openai

            client = openai.OpenAI(api_key=self.config.AI_API_KEY, base_url=self.config.AI_BASE_URL)
            response = client.chat.completions.create(
                model=self.config.AI_MODEL,
                messages=[
                    {
                        "role": "system",
                        "content": "你是一个总结物理学和计算科学论文的研究助手。",
                    },
                    {"role": "user", "content": prompt},
                ],
                max_tokens=self.config.SUMMARY_MAX_TOKENS,
                temperature=0.3,
            )
            self._record_token_usage(response, prompt)

            result = response.choices[0].message.content
            if not result:
                # Routed through the handler below so the caller falls back
                # to the basic summary instead of storing an empty one.
                raise ValueError("empty response content")

            return json.dumps(self._parse_summary(result, title, abstract))

        except Exception as e:
            # Boundary handler: any API or parsing failure is reported and
            # signalled with None so summarization can continue.
            output.error(f"DeepSeek API 错误: {paper.arxiv_id}", details={"exception": str(e)})
            return None

    def summarize_paper(self, paper: "Paper") -> bool:
        """Summarize one paper and store the result.

        Args:
            paper: Object exposing ``title``, ``abstract`` and ``arxiv_id``.

        Returns:
            ``True`` when the summary was stored, ``False`` when the database
            update reported failure or any error occurred.
        """
        try:
            # Returns None by itself when no API key is configured.
            summary_json = self.deepseek_summary(paper)

            if not summary_json:
                summary_json = self.basic_summary(paper)
                # NOTE(review): this adds an estimate (title + abstract length
                # / 4) to total_tokens although no API call was made.
                text_length = len(str(paper.title or "")) + len(str(paper.abstract or ""))
                estimated_tokens = text_length // 4
                self.total_tokens += estimated_tokens
                output.info(f"基础总结Token估算: 约 {estimated_tokens} tokens | 累计总计 {self.total_tokens} tokens")

            # Keywords are also stored in their own column.
            try:
                summary_data = json.loads(summary_json)
            except (json.JSONDecodeError, TypeError):
                summary_data = None
            keywords = summary_data.get("keywords", []) if isinstance(summary_data, dict) else []

            success = self.db.update_paper(
                paper.arxiv_id,
                summarized=True,
                summary=summary_json,
                keywords=json.dumps(keywords),
            )
            if success:
                output.done(f"总结完成: {paper.arxiv_id}")
                return True
            return False

        except Exception as e:
            output.error(f"总结论文失败: {paper.arxiv_id}", details={"exception": str(e)})
            return False

    def summarize_pending_papers(self, limit: int = 20) -> Dict[str, Any]:
        """Summarize up to ``limit`` papers returned by ``get_papers_to_summarize``.

        Returns:
            Dict with ``total_processed``, ``successful`` and ``failed`` counts.
        """
        papers = self.db.get_papers_to_summarize(limit=limit)
        output.do(f"找到 {len(papers)} 篇需要总结的论文")

        successful = 0
        failed = 0
        for paper in tqdm(papers, desc="Summarizing papers"):
            if self.summarize_paper(paper):
                successful += 1
            else:
                failed += 1

            # Rate limiting between papers.
            time.sleep(0.5)

        return {
            "total_processed": len(papers),
            "successful": successful,
            "failed": failed,
        }

    def get_summary_stats(self) -> Dict[str, Any]:
        """Return paper counts, average stored summary length and token totals."""
        with self.db.get_session() as session:
            total = session.query(Paper).count()
            # One query serves both the count and the length average.
            papers = session.query(Paper).filter_by(summarized=True).all()
            summarized = len(papers)

            avg_summary_length = 0
            if papers:
                avg_summary_length = sum(len(p.summary or "") for p in papers) / summarized

        return {
            "total_papers": total,
            "summarized_papers": summarized,
            "summarization_rate": summarized / total if total > 0 else 0,
            "avg_summary_length": avg_summary_length,
            "token_usage": {
                "total_prompt_tokens": self.total_prompt_tokens,
                "total_completion_tokens": self.total_completion_tokens,
                "total_tokens": self.total_tokens,
            },
        }
310
+
311
+
312
def main():
    """Smoke-test the summarizer on the first unsummarized paper and print stats."""
    summarizer = PaperSummarizer()

    print("Testing paper summarizer...")

    # Get a paper to summarize
    db = Database()
    with db.get_session() as session:
        paper = session.query(Paper).filter_by(summarized=False).first()

        if paper:
            print("\nPaper to summarize:")
            # Title and abstract may be missing; slice an empty string then.
            print(f"Title: {(paper.title or '')[:100]}...")
            print(f"Abstract preview: {(paper.abstract or '')[:200]}...")

            # Test summarization
            print("\nSummarizing...")
            success = summarizer.summarize_paper(paper)

            if success:
                print("Summary successful!")

                # Get updated paper
                # NOTE(review): the update is written through the summarizer's
                # own Database handle; confirm this session re-reads the row
                # rather than returning a cached object.
                updated = session.query(Paper).filter_by(arxiv_id=paper.arxiv_id).first()
                if updated and updated.summary:
                    try:
                        summary_data = json.loads(updated.summary)
                        print(f"\nSummary: {summary_data.get('summary', '')[:300]}...")
                        print(f"Keywords: {summary_data.get('keywords', [])[:5]}")
                    except (json.JSONDecodeError, TypeError, AttributeError):
                        # Stored summary is not the expected JSON object:
                        # show the raw text instead.
                        print(f"Summary: {updated.summary[:300]}...")
        else:
            print("No papers available for summarization")

    # Get stats
    stats = summarizer.get_summary_stats()
    print("\nSummarization stats:")
    print(f"Total papers: {stats['total_papers']}")
    print(f"Summarized: {stats['summarized_papers']}")
    print(f"Rate: {stats['summarization_rate']:.1%}")


if __name__ == "__main__":
    main()