arxiv-pulse 0.5.0__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
arxiv_pulse/config.py CHANGED
@@ -1,5 +1,4 @@
1
1
  import os
2
- import warnings
3
2
 
4
3
 
5
4
  class Config:
@@ -7,8 +6,8 @@ class Config:
7
6
  DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///data/arxiv_papers.db")
8
7
 
9
8
  # Crawler
10
- MAX_RESULTS_INITIAL = int(os.getenv("MAX_RESULTS_INITIAL", 100))
11
- MAX_RESULTS_DAILY = int(os.getenv("MAX_RESULTS_DAILY", 20))
9
+ MAX_RESULTS_INITIAL = int(os.getenv("MAX_RESULTS_INITIAL", 10000))
10
+ MAX_RESULTS_DAILY = int(os.getenv("MAX_RESULTS_DAILY", 500))
12
11
 
13
12
  # Search queries - use semicolon as separator to allow commas in queries
14
13
  SEARCH_QUERIES_RAW = os.getenv(
@@ -29,7 +28,6 @@ class Config:
29
28
  SUMMARY_MAX_TOKENS = int(os.getenv("SUMMARY_MAX_TOKENS", 2000))
30
29
 
31
30
  # Report generation settings
32
- SUMMARY_SENTENCES_LIMIT = int(os.getenv("SUMMARY_SENTENCES_LIMIT", 3))
33
31
  TOKEN_PRICE_PER_MILLION = float(os.getenv("TOKEN_PRICE_PER_MILLION", 3.0))
34
32
 
35
33
  # Paths
@@ -40,13 +38,13 @@ class Config:
40
38
  REPORT_MAX_PAPERS = int(os.getenv("REPORT_MAX_PAPERS", "50"))
41
39
 
42
40
  # ArXiv API
43
- ARXIV_MAX_RESULTS = 1000
44
- ARXIV_SORT_BY = "submittedDate"
45
- ARXIV_SORT_ORDER = "descending"
41
+ ARXIV_MAX_RESULTS = int(os.getenv("ARXIV_MAX_RESULTS", 30000))
42
+ ARXIV_SORT_BY = os.getenv("ARXIV_SORT_BY", "submittedDate")
43
+ ARXIV_SORT_ORDER = os.getenv("ARXIV_SORT_ORDER", "descending")
46
44
 
47
45
  # Sync configuration
48
46
  YEARS_BACK = int(os.getenv("YEARS_BACK", 3)) # Years to look back for initial sync
49
- IMPORTANT_PAPERS_FILE = os.getenv("IMPORTANT_PAPERS_FILE", "important_papers.txt")
47
+ IMPORTANT_PAPERS_FILE = os.getenv("IMPORTANT_PAPERS_FILE", "data/important_papers.txt")
50
48
 
51
49
  @classmethod
52
50
  def validate(cls):
arxiv_pulse/models.py CHANGED
@@ -11,7 +11,7 @@ from sqlalchemy import (
11
11
  )
12
12
  from sqlalchemy.ext.declarative import declarative_base
13
13
  from sqlalchemy.orm import sessionmaker
14
- from datetime import datetime, timedelta
14
+ from datetime import datetime, timedelta, timezone
15
15
  import json
16
16
  from typing import Optional
17
17
 
@@ -48,8 +48,12 @@ class Paper(Base):
48
48
  summary = Column(Text)
49
49
 
50
50
  # Metadata
51
- created_at = Column(DateTime, default=datetime.utcnow)
52
- updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
51
+ created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc).replace(tzinfo=None))
52
+ updated_at = Column(
53
+ DateTime,
54
+ default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
55
+ onupdate=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
56
+ )
53
57
 
54
58
  def to_dict(self):
55
59
  """Convert to dictionary"""
@@ -112,8 +116,12 @@ class TranslationCache(Base):
112
116
  source_text_hash = Column(String(64), nullable=False, unique=True, index=True)
113
117
  translated_text = Column(Text, nullable=False)
114
118
  target_language = Column(String(10), default="zh")
115
- created_at = Column(DateTime, default=datetime.utcnow)
116
- updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
119
+ created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc).replace(tzinfo=None))
120
+ updated_at = Column(
121
+ DateTime,
122
+ default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
123
+ onupdate=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
124
+ )
117
125
 
118
126
  def __repr__(self):
119
127
  return f"<TranslationCache(id={self.id}, hash={self.source_text_hash[:16]}...)>"
@@ -147,7 +155,7 @@ class Database:
147
155
  if paper:
148
156
  for key, value in kwargs.items():
149
157
  setattr(paper, key, value)
150
- paper.updated_at = datetime.utcnow()
158
+ paper.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
151
159
  session.commit()
152
160
  return True
153
161
  return False
@@ -155,7 +163,7 @@ class Database:
155
163
  def get_recent_papers(self, days=7, limit=100):
156
164
  """Get recent papers"""
157
165
  with self.get_session() as session:
158
- cutoff_date = datetime.utcnow() - timedelta(days=days)
166
+ cutoff_date = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(days=days)
159
167
  return (
160
168
  session.query(Paper)
161
169
  .filter(Paper.published >= cutoff_date)
@@ -233,7 +241,7 @@ class Database:
233
241
  if existing:
234
242
  # 更新现有缓存
235
243
  existing.translated_text = translated_text
236
- existing.updated_at = datetime.utcnow()
244
+ existing.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
237
245
  else:
238
246
  # 创建新缓存
239
247
  cache_entry = TranslationCache(
@@ -249,7 +257,7 @@ class Database:
249
257
  def clear_old_translation_cache(self, days_old: int = 30) -> int:
250
258
  """清理旧的翻译缓存"""
251
259
  with self.get_session() as session:
252
- cutoff_date = datetime.utcnow() - timedelta(days=days_old)
260
+ cutoff_date = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(days=days_old)
253
261
  deleted_count = session.query(TranslationCache).filter(TranslationCache.updated_at < cutoff_date).delete()
254
262
  session.commit()
255
263
  return deleted_count
@@ -11,7 +11,7 @@
11
11
  [error] - 错误信息(简洁)
12
12
  [debug] - 调试信息(默认不显示)
13
13
 
14
- 所有详细日志同时写入日志文件,控制台只显示简洁信息。
14
+ 所有输出仅显示在控制台,不写入日志文件。
15
15
  """
16
16
 
17
17
  import sys
@@ -72,34 +72,30 @@ class OutputManager:
72
72
  if not self._initialized:
73
73
  self._initialized = True
74
74
  self._console_enabled = True
75
- self._file_logger = None
76
- self._min_level = OutputLevel.DO # 默认显示DO及以上(包括DONE, TIPS, INFO等)
75
+ # 从环境变量读取日志级别,默认为INFO
76
+ log_level = os.getenv("LOG_LEVEL", "INFO").upper()
77
+ level_map = {
78
+ "DEBUG": OutputLevel.DEBUG,
79
+ "INFO": OutputLevel.INFO,
80
+ "WARNING": OutputLevel.WARN,
81
+ "WARN": OutputLevel.WARN,
82
+ "ERROR": OutputLevel.ERROR,
83
+ "DO": OutputLevel.DO,
84
+ "DONE": OutputLevel.DONE,
85
+ "TIPS": OutputLevel.TIPS,
86
+ }
87
+ self._min_level = level_map.get(log_level, OutputLevel.INFO)
77
88
  self._suppressed_modules = set()
78
- self._setup_file_logger()
89
+ # 创建一个基本的日志记录器(不写入文件)
90
+ self._file_logger = logging.getLogger("arxiv_pulse")
91
+ self._file_logger.setLevel(logging.DEBUG)
92
+ # 添加NullHandler避免"No handlers"警告
93
+ if not self._file_logger.handlers:
94
+ self._file_logger.addHandler(logging.NullHandler())
79
95
 
80
96
  # 抑制第三方库的详细日志
81
97
  self._suppress_third_party_logs()
82
98
 
83
- def _setup_file_logger(self):
84
- """设置文件日志记录器"""
85
- # 创建日志目录
86
- os.makedirs("logs", exist_ok=True)
87
-
88
- # 配置文件日志记录器
89
- self._file_logger = logging.getLogger("arxiv_crawler")
90
- self._file_logger.setLevel(logging.DEBUG)
91
-
92
- # 移除现有处理器
93
- for handler in self._file_logger.handlers[:]:
94
- self._file_logger.removeHandler(handler)
95
-
96
- # 添加文件处理器
97
- file_handler = logging.FileHandler("logs/arxiv_pulse.log", encoding="utf-8")
98
- file_handler.setLevel(logging.DEBUG)
99
- formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
100
- file_handler.setFormatter(formatter)
101
- self._file_logger.addHandler(file_handler)
102
-
103
99
  def _suppress_third_party_logs(self):
104
100
  """抑制第三方库的详细日志"""
105
101
  # 设置第三方库的日志级别为WARNING或更高
@@ -117,16 +113,18 @@ class OutputManager:
117
113
  return False
118
114
 
119
115
  # 检查级别
116
+ # 数字越小表示级别越低(越不重要)
120
117
  level_order = {
121
- OutputLevel.DO: 0,
122
- OutputLevel.DONE: 1,
123
- OutputLevel.TIPS: 2,
124
- OutputLevel.INFO: 3,
125
- OutputLevel.WARN: 4,
126
- OutputLevel.ERROR: 5,
127
- OutputLevel.DEBUG: 6,
118
+ OutputLevel.DEBUG: 0, # 最低级别
119
+ OutputLevel.INFO: 1,
120
+ OutputLevel.WARN: 2,
121
+ OutputLevel.ERROR: 3,
122
+ OutputLevel.DO: 4, # 操作提示,通常显示
123
+ OutputLevel.DONE: 5,
124
+ OutputLevel.TIPS: 6,
128
125
  }
129
126
 
127
+ # 只有级别数字 >= 最小级别数字的才显示
130
128
  return level_order[level] >= level_order[self._min_level]
131
129
 
132
130
  def _output(
@@ -137,28 +135,6 @@ class OutputManager:
137
135
  details: Optional[Dict[str, Any]] = None,
138
136
  ):
139
137
  """统一输出方法"""
140
- # 记录到文件日志
141
- log_level = {
142
- OutputLevel.DO: logging.INFO,
143
- OutputLevel.DONE: logging.INFO,
144
- OutputLevel.TIPS: logging.INFO,
145
- OutputLevel.INFO: logging.INFO,
146
- OutputLevel.WARN: logging.WARNING,
147
- OutputLevel.ERROR: logging.ERROR,
148
- OutputLevel.DEBUG: logging.DEBUG,
149
- }[level]
150
-
151
- # 构建详细日志消息
152
- log_message = message
153
- if module:
154
- log_message = f"[{module}] {message}"
155
- if details:
156
- details_str = " ".join(f"{k}={v}" for k, v in details.items())
157
- log_message = f"{log_message} | {details_str}"
158
-
159
- # 写入文件日志
160
- self._file_logger.log(log_level, log_message)
161
-
162
138
  # 控制台输出
163
139
  if self._console_enabled and self._should_output(level, module):
164
140
  # 获取标签和颜色
@@ -228,7 +204,15 @@ class OutputManager:
228
204
  @classmethod
229
205
  def get_file_logger(cls) -> logging.Logger:
230
206
  """获取文件日志记录器"""
231
- return cls()._file_logger
207
+ instance = cls()
208
+ if instance._file_logger is None:
209
+ # 创建基本的日志记录器作为回退
210
+ instance._file_logger = logging.getLogger("arxiv_pulse_fallback")
211
+ instance._file_logger.setLevel(logging.DEBUG)
212
+ if not instance._file_logger.handlers:
213
+ instance._file_logger.addHandler(logging.NullHandler())
214
+ assert instance._file_logger is not None
215
+ return instance._file_logger
232
216
 
233
217
 
234
218
  # 简化别名
@@ -1,7 +1,6 @@
1
1
  import json
2
2
  import pandas as pd
3
- from datetime import datetime, timedelta
4
- import markdown
3
+ from datetime import datetime, timedelta, timezone
5
4
  from typing import Dict, List, Any, Optional
6
5
  import logging
7
6
  import os
@@ -21,7 +20,6 @@ class ReportGenerator:
21
20
  self.total_tokens_used = 0 # 总token使用量
22
21
  self.total_cost = 0.0 # 总费用(元)
23
22
  self.token_price_per_million = Config.TOKEN_PRICE_PER_MILLION # 每百万token价格,可从配置覆盖
24
- self.summary_sentences_limit = Config.SUMMARY_SENTENCES_LIMIT # 摘要句子数限制
25
23
 
26
24
  # 抑制第三方库的详细日志
27
25
  import logging
@@ -211,47 +209,6 @@ class ReportGenerator:
211
209
  # 确保分数在1-5之间
212
210
  return max(1, min(5, score))
213
211
 
214
- def _truncate_to_sentences(self, text: str, max_sentences: Optional[int] = None) -> str:
215
- """将文本截断为指定数量的句子(支持中英文)"""
216
- if not text:
217
- return ""
218
-
219
- if max_sentences is None:
220
- max_sentences = self.summary_sentences_limit
221
-
222
- import re
223
-
224
- # 支持中英文句子分隔符:句号、问号、感叹号、分号、省略号
225
- # 英文: . ? ! ; ... 中文: 。!?;…
226
- pattern = r"([。!?;…\.\?!;]+|\.{3,})"
227
- parts = re.split(pattern, text)
228
-
229
- sentences = []
230
- current = ""
231
- for i, part in enumerate(parts):
232
- current += part
233
- if i % 2 == 1: # 分隔符部分
234
- sentences.append(current)
235
- current = ""
236
-
237
- # 如果最后还有未结束的句子
238
- if current:
239
- sentences.append(current)
240
-
241
- # 如果分割失败,按长度简单截断
242
- if len(sentences) == 0:
243
- return text[:200] + "..." if len(text) > 200 else text
244
-
245
- # 取前max_sentences句
246
- result = "".join(sentences[:max_sentences])
247
-
248
- # 如果截断后比原文本短很多,添加省略号
249
- if len(result) < len(text) * 0.8:
250
- # 移除末尾的句子分隔符,添加省略号
251
- result = result.rstrip("。!?;….?!;") + "…"
252
-
253
- return result
254
-
255
212
  def translate_text(self, text: str, target_lang: str = "zh") -> str:
256
213
  """使用DeepSeek或OpenAI API翻译文本,优先使用缓存"""
257
214
  if not text or not text.strip():
@@ -364,7 +321,7 @@ class ReportGenerator:
364
321
 
365
322
  with self.db.get_session() as session:
366
323
  # Get papers from last 24 hours
367
- cutoff = datetime.utcnow() - timedelta(hours=24)
324
+ cutoff = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=24)
368
325
  new_papers = (
369
326
  session.query(Paper)
370
327
  .filter(Paper.created_at >= cutoff)
@@ -406,7 +363,7 @@ class ReportGenerator:
406
363
 
407
364
  with self.db.get_session() as session:
408
365
  # Get papers from last 7 days
409
- cutoff = datetime.utcnow() - timedelta(days=7)
366
+ cutoff = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(days=7)
410
367
  recent_papers = (
411
368
  session.query(Paper)
412
369
  .filter(Paper.created_at >= cutoff)
@@ -2,11 +2,9 @@
2
2
  增强搜索引擎 - 提供高级搜索和过滤功能
3
3
  """
4
4
 
5
- import json
6
- from datetime import datetime, timedelta
7
- from typing import List, Dict, Any, Optional, Union
5
+ from datetime import datetime, timedelta, timezone
6
+ from typing import List, Dict, Any, Optional
8
7
  from dataclasses import dataclass, field
9
- from pathlib import Path
10
8
 
11
9
  from sqlalchemy import and_, or_, not_, func, desc, asc
12
10
  from sqlalchemy.orm import Session
@@ -21,7 +19,7 @@ class SearchFilter:
21
19
 
22
20
  # 文本搜索
23
21
  query: Optional[str] = None
24
- search_fields: List[str] = field(default_factory=lambda: ["title", "abstract", "categories", "search_query"])
22
+ search_fields: List[str] = field(default_factory=lambda: ["title", "abstract"])
25
23
 
26
24
  # 分类过滤
27
25
  categories: Optional[List[str]] = None
@@ -64,31 +62,117 @@ class SearchEngine:
64
62
  self.session = db_session
65
63
 
66
64
  def build_text_filter(self, query: str, search_fields: List[str], match_all: bool = False):
67
- """构建文本搜索过滤器"""
65
+ """构建文本搜索过滤器,简单模糊匹配(支持单词拆分)"""
68
66
  if not query or not search_fields:
69
67
  return None
70
68
 
71
- filters = []
69
+ # 将查询转换为小写进行不区分大小写的匹配
70
+ query_lower = query.lower()
71
+
72
+ # 拆分为单词(按非字母数字字符,保留中文)
73
+ import re
74
+
75
+ # 使用正则表达式分割,保留中文字符(支持Unicode)
76
+ words = re.split(r"[^\w]+", query_lower, flags=re.UNICODE)
77
+ # 过滤掉空字符串和过短的单词(长度>1)
78
+ words = [w for w in words if w and len(w) > 1]
79
+
80
+ # 如果没有有效的单词,使用整个查询作为单个单词
81
+ if not words:
82
+ words = [query_lower]
83
+
84
+ # 如果只有一个单词,使用简单的字段间OR逻辑
85
+ if len(words) == 1:
86
+ word = words[0]
87
+ field_filters = []
88
+ for field in search_fields:
89
+ if field == "title":
90
+ field_filters.append(Paper.title.ilike(f"%{word}%"))
91
+ elif field == "abstract":
92
+ field_filters.append(Paper.abstract.ilike(f"%{word}%"))
93
+ elif field == "categories":
94
+ field_filters.append(Paper.categories.ilike(f"%{word}%"))
95
+ elif field == "search_query":
96
+ field_filters.append(Paper.search_query.ilike(f"%{word}%"))
97
+ elif field == "authors":
98
+ field_filters.append(Paper.authors.ilike(f"%{word}%"))
99
+
100
+ if field_filters:
101
+ return or_(*field_filters)
102
+ return None
103
+
104
+ # 多个单词:首先尝试短语匹配(整个查询字符串)
105
+ phrase_filters = []
72
106
  for field in search_fields:
73
107
  if field == "title":
74
- filters.append(Paper.title.contains(query))
108
+ phrase_filters.append(Paper.title.ilike(f"%{query_lower}%"))
75
109
  elif field == "abstract":
76
- filters.append(Paper.abstract.contains(query))
110
+ phrase_filters.append(Paper.abstract.ilike(f"%{query_lower}%"))
77
111
  elif field == "categories":
78
- filters.append(Paper.categories.contains(query))
112
+ phrase_filters.append(Paper.categories.ilike(f"%{query_lower}%"))
79
113
  elif field == "search_query":
80
- filters.append(Paper.search_query.contains(query))
114
+ phrase_filters.append(Paper.search_query.ilike(f"%{query_lower}%"))
81
115
  elif field == "authors":
82
- # 作者字段是JSON字符串,需要特殊处理
83
- filters.append(Paper.authors.contains(query))
84
-
85
- if not filters:
116
+ phrase_filters.append(Paper.authors.ilike(f"%{query_lower}%"))
117
+
118
+ # 尝试顺序匹配(单词按顺序出现,中间可间隔)
119
+ sequence_filters = []
120
+ if len(words) > 1:
121
+ # 构建模式:%word1%word2%word3%
122
+ sequence_pattern = "%" + "%".join(words) + "%"
123
+ for field in search_fields:
124
+ if field == "title":
125
+ sequence_filters.append(Paper.title.ilike(sequence_pattern))
126
+ elif field == "abstract":
127
+ sequence_filters.append(Paper.abstract.ilike(sequence_pattern))
128
+ elif field == "categories":
129
+ sequence_filters.append(Paper.categories.ilike(sequence_pattern))
130
+ elif field == "search_query":
131
+ sequence_filters.append(Paper.search_query.ilike(sequence_pattern))
132
+ elif field == "authors":
133
+ sequence_filters.append(Paper.authors.ilike(sequence_pattern))
134
+
135
+ # 然后添加单词AND匹配(所有单词必须在同一个字段中出现)
136
+ word_and_filters = []
137
+ for field in search_fields:
138
+ if field == "title":
139
+ # 标题必须包含所有单词
140
+ title_filters = [Paper.title.ilike(f"%{word}%") for word in words]
141
+ if title_filters:
142
+ word_and_filters.append(and_(*title_filters))
143
+ elif field == "abstract":
144
+ # 摘要必须包含所有单词
145
+ abstract_filters = [Paper.abstract.ilike(f"%{word}%") for word in words]
146
+ if abstract_filters:
147
+ word_and_filters.append(and_(*abstract_filters))
148
+ elif field == "categories":
149
+ # 分类必须包含所有单词(通常分类搜索是单个词)
150
+ category_filters = [Paper.categories.ilike(f"%{word}%") for word in words]
151
+ if category_filters:
152
+ word_and_filters.append(and_(*category_filters))
153
+ elif field == "search_query":
154
+ search_query_filters = [Paper.search_query.ilike(f"%{word}%") for word in words]
155
+ if search_query_filters:
156
+ word_and_filters.append(and_(*search_query_filters))
157
+ elif field == "authors":
158
+ author_filters = [Paper.authors.ilike(f"%{word}%") for word in words]
159
+ if author_filters:
160
+ word_and_filters.append(and_(*author_filters))
161
+
162
+ # 组合所有过滤器:短语匹配 OR 顺序匹配 OR 单词AND匹配
163
+ all_filters = []
164
+ if phrase_filters:
165
+ all_filters.append(or_(*phrase_filters))
166
+ if sequence_filters:
167
+ all_filters.append(or_(*sequence_filters))
168
+ if word_and_filters:
169
+ all_filters.append(or_(*word_and_filters))
170
+
171
+ if not all_filters:
86
172
  return None
87
173
 
88
- if match_all:
89
- return and_(*filters)
90
- else:
91
- return or_(*filters)
174
+ # 使用OR逻辑连接所有匹配类型
175
+ return or_(*all_filters)
92
176
 
93
177
  def build_category_filter(
94
178
  self,
@@ -147,7 +231,7 @@ class SearchEngine:
147
231
  filters = []
148
232
 
149
233
  if days_back:
150
- cutoff_date = datetime.utcnow() - timedelta(days=days_back)
234
+ cutoff_date = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(days=days_back)
151
235
  filters.append(Paper.published >= cutoff_date)
152
236
 
153
237
  if date_from:
@@ -290,38 +374,6 @@ class SearchEngine:
290
374
  output.error("相似论文搜索失败", details={"exception": str(e)})
291
375
  return []
292
376
 
293
- # 简化的相似性搜索:基于共同关键词或分类
294
- # 在实际应用中,可以使用更复杂的文本相似性算法
295
- all_papers = self.session.query(Paper).filter(Paper.arxiv_id != paper_id).all()
296
-
297
- # 计算简单相似度:分类重叠
298
- similar_papers = []
299
- target_cats = set(target_paper.categories.split()) if target_paper.categories else set()
300
-
301
- for paper in all_papers:
302
- if not paper.categories:
303
- continue
304
-
305
- paper_cats = set(paper.categories.split())
306
- common_cats = target_cats.intersection(paper_cats)
307
-
308
- if common_cats:
309
- # 简单相似度分数:共同分类数 / 总分类数
310
- similarity = len(common_cats) / max(len(target_cats), len(paper_cats))
311
- if similarity >= threshold:
312
- # 临时存储相似度分数
313
- paper.similarity_score = similarity
314
- similar_papers.append(paper)
315
-
316
- # 按相似度排序
317
- similar_papers.sort(key=lambda x: getattr(x, "similarity_score", 0), reverse=True)
318
-
319
- return similar_papers[:limit]
320
-
321
- except Exception as e:
322
- output.error("相似论文搜索失败", details={"exception": str(e)})
323
- return []
324
-
325
377
  def get_search_history(self, limit: int = 10) -> List[Dict[str, Any]]:
326
378
  """获取搜索历史(从数据库中的search_query字段提取)"""
327
379
  try:
@@ -360,7 +412,7 @@ class SearchEngine:
360
412
  output.error("获取搜索历史失败", details={"exception": str(e)})
361
413
  return []
362
414
 
363
- def save_search_query(self, query: str, description: str = None):
415
+ def save_search_query(self, query: str, description: Optional[str] = None):
364
416
  """保存搜索查询到历史(简单实现)"""
365
417
  # 这里可以扩展为保存到单独的搜索历史表
366
418
  # 目前依赖于Paper表中的search_query字段
arxiv_pulse/summarizer.py CHANGED
@@ -1,4 +1,3 @@
1
- import openai
2
1
  import json
3
2
  import logging
4
3
  from typing import List, Dict, Any, Optional