arxiv-pulse 0.5.0__py3-none-any.whl → 0.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arxiv_pulse/.ENV.TEMPLATE +93 -41
- arxiv_pulse/__version__.py +2 -2
- arxiv_pulse/arxiv_crawler.py +65 -23
- arxiv_pulse/cli.py +228 -433
- arxiv_pulse/config.py +6 -8
- arxiv_pulse/models.py +17 -9
- arxiv_pulse/output_manager.py +38 -54
- arxiv_pulse/report_generator.py +3 -46
- arxiv_pulse/search_engine.py +105 -53
- arxiv_pulse/summarizer.py +0 -1
- {arxiv_pulse-0.5.0.dist-info → arxiv_pulse-0.6.1.dist-info}/METADATA +61 -124
- arxiv_pulse-0.6.1.dist-info/RECORD +17 -0
- arxiv_pulse-0.5.0.dist-info/RECORD +0 -17
- {arxiv_pulse-0.5.0.dist-info → arxiv_pulse-0.6.1.dist-info}/WHEEL +0 -0
- {arxiv_pulse-0.5.0.dist-info → arxiv_pulse-0.6.1.dist-info}/entry_points.txt +0 -0
- {arxiv_pulse-0.5.0.dist-info → arxiv_pulse-0.6.1.dist-info}/licenses/LICENSE +0 -0
- {arxiv_pulse-0.5.0.dist-info → arxiv_pulse-0.6.1.dist-info}/top_level.txt +0 -0
arxiv_pulse/config.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import os
|
|
2
|
-
import warnings
|
|
3
2
|
|
|
4
3
|
|
|
5
4
|
class Config:
|
|
@@ -7,8 +6,8 @@ class Config:
|
|
|
7
6
|
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///data/arxiv_papers.db")
|
|
8
7
|
|
|
9
8
|
# Crawler
|
|
10
|
-
MAX_RESULTS_INITIAL = int(os.getenv("MAX_RESULTS_INITIAL",
|
|
11
|
-
MAX_RESULTS_DAILY = int(os.getenv("MAX_RESULTS_DAILY",
|
|
9
|
+
MAX_RESULTS_INITIAL = int(os.getenv("MAX_RESULTS_INITIAL", 10000))
|
|
10
|
+
MAX_RESULTS_DAILY = int(os.getenv("MAX_RESULTS_DAILY", 500))
|
|
12
11
|
|
|
13
12
|
# Search queries - use semicolon as separator to allow commas in queries
|
|
14
13
|
SEARCH_QUERIES_RAW = os.getenv(
|
|
@@ -29,7 +28,6 @@ class Config:
|
|
|
29
28
|
SUMMARY_MAX_TOKENS = int(os.getenv("SUMMARY_MAX_TOKENS", 2000))
|
|
30
29
|
|
|
31
30
|
# Report generation settings
|
|
32
|
-
SUMMARY_SENTENCES_LIMIT = int(os.getenv("SUMMARY_SENTENCES_LIMIT", 3))
|
|
33
31
|
TOKEN_PRICE_PER_MILLION = float(os.getenv("TOKEN_PRICE_PER_MILLION", 3.0))
|
|
34
32
|
|
|
35
33
|
# Paths
|
|
@@ -40,13 +38,13 @@ class Config:
|
|
|
40
38
|
REPORT_MAX_PAPERS = int(os.getenv("REPORT_MAX_PAPERS", "50"))
|
|
41
39
|
|
|
42
40
|
# ArXiv API
|
|
43
|
-
ARXIV_MAX_RESULTS =
|
|
44
|
-
ARXIV_SORT_BY = "submittedDate"
|
|
45
|
-
ARXIV_SORT_ORDER = "descending"
|
|
41
|
+
ARXIV_MAX_RESULTS = int(os.getenv("ARXIV_MAX_RESULTS", 30000))
|
|
42
|
+
ARXIV_SORT_BY = os.getenv("ARXIV_SORT_BY", "submittedDate")
|
|
43
|
+
ARXIV_SORT_ORDER = os.getenv("ARXIV_SORT_ORDER", "descending")
|
|
46
44
|
|
|
47
45
|
# Sync configuration
|
|
48
46
|
YEARS_BACK = int(os.getenv("YEARS_BACK", 3)) # Years to look back for initial sync
|
|
49
|
-
IMPORTANT_PAPERS_FILE = os.getenv("IMPORTANT_PAPERS_FILE", "important_papers.txt")
|
|
47
|
+
IMPORTANT_PAPERS_FILE = os.getenv("IMPORTANT_PAPERS_FILE", "data/important_papers.txt")
|
|
50
48
|
|
|
51
49
|
@classmethod
|
|
52
50
|
def validate(cls):
|
arxiv_pulse/models.py
CHANGED
|
@@ -11,7 +11,7 @@ from sqlalchemy import (
|
|
|
11
11
|
)
|
|
12
12
|
from sqlalchemy.ext.declarative import declarative_base
|
|
13
13
|
from sqlalchemy.orm import sessionmaker
|
|
14
|
-
from datetime import datetime, timedelta
|
|
14
|
+
from datetime import datetime, timedelta, timezone
|
|
15
15
|
import json
|
|
16
16
|
from typing import Optional
|
|
17
17
|
|
|
@@ -48,8 +48,12 @@ class Paper(Base):
|
|
|
48
48
|
summary = Column(Text)
|
|
49
49
|
|
|
50
50
|
# Metadata
|
|
51
|
-
created_at = Column(DateTime, default=datetime.
|
|
52
|
-
updated_at = Column(
|
|
51
|
+
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc).replace(tzinfo=None))
|
|
52
|
+
updated_at = Column(
|
|
53
|
+
DateTime,
|
|
54
|
+
default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
|
|
55
|
+
onupdate=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
|
|
56
|
+
)
|
|
53
57
|
|
|
54
58
|
def to_dict(self):
|
|
55
59
|
"""Convert to dictionary"""
|
|
@@ -112,8 +116,12 @@ class TranslationCache(Base):
|
|
|
112
116
|
source_text_hash = Column(String(64), nullable=False, unique=True, index=True)
|
|
113
117
|
translated_text = Column(Text, nullable=False)
|
|
114
118
|
target_language = Column(String(10), default="zh")
|
|
115
|
-
created_at = Column(DateTime, default=datetime.
|
|
116
|
-
updated_at = Column(
|
|
119
|
+
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc).replace(tzinfo=None))
|
|
120
|
+
updated_at = Column(
|
|
121
|
+
DateTime,
|
|
122
|
+
default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
|
|
123
|
+
onupdate=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
|
|
124
|
+
)
|
|
117
125
|
|
|
118
126
|
def __repr__(self):
|
|
119
127
|
return f"<TranslationCache(id={self.id}, hash={self.source_text_hash[:16]}...)>"
|
|
@@ -147,7 +155,7 @@ class Database:
|
|
|
147
155
|
if paper:
|
|
148
156
|
for key, value in kwargs.items():
|
|
149
157
|
setattr(paper, key, value)
|
|
150
|
-
paper.updated_at = datetime.
|
|
158
|
+
paper.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
|
151
159
|
session.commit()
|
|
152
160
|
return True
|
|
153
161
|
return False
|
|
@@ -155,7 +163,7 @@ class Database:
|
|
|
155
163
|
def get_recent_papers(self, days=7, limit=100):
|
|
156
164
|
"""Get recent papers"""
|
|
157
165
|
with self.get_session() as session:
|
|
158
|
-
cutoff_date = datetime.
|
|
166
|
+
cutoff_date = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(days=days)
|
|
159
167
|
return (
|
|
160
168
|
session.query(Paper)
|
|
161
169
|
.filter(Paper.published >= cutoff_date)
|
|
@@ -233,7 +241,7 @@ class Database:
|
|
|
233
241
|
if existing:
|
|
234
242
|
# 更新现有缓存
|
|
235
243
|
existing.translated_text = translated_text
|
|
236
|
-
existing.updated_at = datetime.
|
|
244
|
+
existing.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
|
237
245
|
else:
|
|
238
246
|
# 创建新缓存
|
|
239
247
|
cache_entry = TranslationCache(
|
|
@@ -249,7 +257,7 @@ class Database:
|
|
|
249
257
|
def clear_old_translation_cache(self, days_old: int = 30) -> int:
|
|
250
258
|
"""清理旧的翻译缓存"""
|
|
251
259
|
with self.get_session() as session:
|
|
252
|
-
cutoff_date = datetime.
|
|
260
|
+
cutoff_date = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(days=days_old)
|
|
253
261
|
deleted_count = session.query(TranslationCache).filter(TranslationCache.updated_at < cutoff_date).delete()
|
|
254
262
|
session.commit()
|
|
255
263
|
return deleted_count
|
arxiv_pulse/output_manager.py
CHANGED
|
@@ -11,7 +11,7 @@
|
|
|
11
11
|
[error] - 错误信息(简洁)
|
|
12
12
|
[debug] - 调试信息(默认不显示)
|
|
13
13
|
|
|
14
|
-
|
|
14
|
+
所有输出仅显示在控制台,不写入日志文件。
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
17
|
import sys
|
|
@@ -72,34 +72,30 @@ class OutputManager:
|
|
|
72
72
|
if not self._initialized:
|
|
73
73
|
self._initialized = True
|
|
74
74
|
self._console_enabled = True
|
|
75
|
-
|
|
76
|
-
|
|
75
|
+
# 从环境变量读取日志级别,默认为INFO
|
|
76
|
+
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
|
|
77
|
+
level_map = {
|
|
78
|
+
"DEBUG": OutputLevel.DEBUG,
|
|
79
|
+
"INFO": OutputLevel.INFO,
|
|
80
|
+
"WARNING": OutputLevel.WARN,
|
|
81
|
+
"WARN": OutputLevel.WARN,
|
|
82
|
+
"ERROR": OutputLevel.ERROR,
|
|
83
|
+
"DO": OutputLevel.DO,
|
|
84
|
+
"DONE": OutputLevel.DONE,
|
|
85
|
+
"TIPS": OutputLevel.TIPS,
|
|
86
|
+
}
|
|
87
|
+
self._min_level = level_map.get(log_level, OutputLevel.INFO)
|
|
77
88
|
self._suppressed_modules = set()
|
|
78
|
-
|
|
89
|
+
# 创建一个基本的日志记录器(不写入文件)
|
|
90
|
+
self._file_logger = logging.getLogger("arxiv_pulse")
|
|
91
|
+
self._file_logger.setLevel(logging.DEBUG)
|
|
92
|
+
# 添加NullHandler避免"No handlers"警告
|
|
93
|
+
if not self._file_logger.handlers:
|
|
94
|
+
self._file_logger.addHandler(logging.NullHandler())
|
|
79
95
|
|
|
80
96
|
# 抑制第三方库的详细日志
|
|
81
97
|
self._suppress_third_party_logs()
|
|
82
98
|
|
|
83
|
-
def _setup_file_logger(self):
|
|
84
|
-
"""设置文件日志记录器"""
|
|
85
|
-
# 创建日志目录
|
|
86
|
-
os.makedirs("logs", exist_ok=True)
|
|
87
|
-
|
|
88
|
-
# 配置文件日志记录器
|
|
89
|
-
self._file_logger = logging.getLogger("arxiv_crawler")
|
|
90
|
-
self._file_logger.setLevel(logging.DEBUG)
|
|
91
|
-
|
|
92
|
-
# 移除现有处理器
|
|
93
|
-
for handler in self._file_logger.handlers[:]:
|
|
94
|
-
self._file_logger.removeHandler(handler)
|
|
95
|
-
|
|
96
|
-
# 添加文件处理器
|
|
97
|
-
file_handler = logging.FileHandler("logs/arxiv_pulse.log", encoding="utf-8")
|
|
98
|
-
file_handler.setLevel(logging.DEBUG)
|
|
99
|
-
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
|
100
|
-
file_handler.setFormatter(formatter)
|
|
101
|
-
self._file_logger.addHandler(file_handler)
|
|
102
|
-
|
|
103
99
|
def _suppress_third_party_logs(self):
|
|
104
100
|
"""抑制第三方库的详细日志"""
|
|
105
101
|
# 设置第三方库的日志级别为WARNING或更高
|
|
@@ -117,16 +113,18 @@ class OutputManager:
|
|
|
117
113
|
return False
|
|
118
114
|
|
|
119
115
|
# 检查级别
|
|
116
|
+
# 数字越小表示级别越低(越不重要)
|
|
120
117
|
level_order = {
|
|
121
|
-
OutputLevel.
|
|
122
|
-
OutputLevel.
|
|
123
|
-
OutputLevel.
|
|
124
|
-
OutputLevel.
|
|
125
|
-
OutputLevel.
|
|
126
|
-
OutputLevel.
|
|
127
|
-
OutputLevel.
|
|
118
|
+
OutputLevel.DEBUG: 0, # 最低级别
|
|
119
|
+
OutputLevel.INFO: 1,
|
|
120
|
+
OutputLevel.WARN: 2,
|
|
121
|
+
OutputLevel.ERROR: 3,
|
|
122
|
+
OutputLevel.DO: 4, # 操作提示,通常显示
|
|
123
|
+
OutputLevel.DONE: 5,
|
|
124
|
+
OutputLevel.TIPS: 6,
|
|
128
125
|
}
|
|
129
126
|
|
|
127
|
+
# 只有级别数字 >= 最小级别数字的才显示
|
|
130
128
|
return level_order[level] >= level_order[self._min_level]
|
|
131
129
|
|
|
132
130
|
def _output(
|
|
@@ -137,28 +135,6 @@ class OutputManager:
|
|
|
137
135
|
details: Optional[Dict[str, Any]] = None,
|
|
138
136
|
):
|
|
139
137
|
"""统一输出方法"""
|
|
140
|
-
# 记录到文件日志
|
|
141
|
-
log_level = {
|
|
142
|
-
OutputLevel.DO: logging.INFO,
|
|
143
|
-
OutputLevel.DONE: logging.INFO,
|
|
144
|
-
OutputLevel.TIPS: logging.INFO,
|
|
145
|
-
OutputLevel.INFO: logging.INFO,
|
|
146
|
-
OutputLevel.WARN: logging.WARNING,
|
|
147
|
-
OutputLevel.ERROR: logging.ERROR,
|
|
148
|
-
OutputLevel.DEBUG: logging.DEBUG,
|
|
149
|
-
}[level]
|
|
150
|
-
|
|
151
|
-
# 构建详细日志消息
|
|
152
|
-
log_message = message
|
|
153
|
-
if module:
|
|
154
|
-
log_message = f"[{module}] {message}"
|
|
155
|
-
if details:
|
|
156
|
-
details_str = " ".join(f"{k}={v}" for k, v in details.items())
|
|
157
|
-
log_message = f"{log_message} | {details_str}"
|
|
158
|
-
|
|
159
|
-
# 写入文件日志
|
|
160
|
-
self._file_logger.log(log_level, log_message)
|
|
161
|
-
|
|
162
138
|
# 控制台输出
|
|
163
139
|
if self._console_enabled and self._should_output(level, module):
|
|
164
140
|
# 获取标签和颜色
|
|
@@ -228,7 +204,15 @@ class OutputManager:
|
|
|
228
204
|
@classmethod
|
|
229
205
|
def get_file_logger(cls) -> logging.Logger:
|
|
230
206
|
"""获取文件日志记录器"""
|
|
231
|
-
|
|
207
|
+
instance = cls()
|
|
208
|
+
if instance._file_logger is None:
|
|
209
|
+
# 创建基本的日志记录器作为回退
|
|
210
|
+
instance._file_logger = logging.getLogger("arxiv_pulse_fallback")
|
|
211
|
+
instance._file_logger.setLevel(logging.DEBUG)
|
|
212
|
+
if not instance._file_logger.handlers:
|
|
213
|
+
instance._file_logger.addHandler(logging.NullHandler())
|
|
214
|
+
assert instance._file_logger is not None
|
|
215
|
+
return instance._file_logger
|
|
232
216
|
|
|
233
217
|
|
|
234
218
|
# 简化别名
|
arxiv_pulse/report_generator.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import pandas as pd
|
|
3
|
-
from datetime import datetime, timedelta
|
|
4
|
-
import markdown
|
|
3
|
+
from datetime import datetime, timedelta, timezone
|
|
5
4
|
from typing import Dict, List, Any, Optional
|
|
6
5
|
import logging
|
|
7
6
|
import os
|
|
@@ -21,7 +20,6 @@ class ReportGenerator:
|
|
|
21
20
|
self.total_tokens_used = 0 # 总token使用量
|
|
22
21
|
self.total_cost = 0.0 # 总费用(元)
|
|
23
22
|
self.token_price_per_million = Config.TOKEN_PRICE_PER_MILLION # 每百万token价格,可从配置覆盖
|
|
24
|
-
self.summary_sentences_limit = Config.SUMMARY_SENTENCES_LIMIT # 摘要句子数限制
|
|
25
23
|
|
|
26
24
|
# 抑制第三方库的详细日志
|
|
27
25
|
import logging
|
|
@@ -211,47 +209,6 @@ class ReportGenerator:
|
|
|
211
209
|
# 确保分数在1-5之间
|
|
212
210
|
return max(1, min(5, score))
|
|
213
211
|
|
|
214
|
-
def _truncate_to_sentences(self, text: str, max_sentences: Optional[int] = None) -> str:
|
|
215
|
-
"""将文本截断为指定数量的句子(支持中英文)"""
|
|
216
|
-
if not text:
|
|
217
|
-
return ""
|
|
218
|
-
|
|
219
|
-
if max_sentences is None:
|
|
220
|
-
max_sentences = self.summary_sentences_limit
|
|
221
|
-
|
|
222
|
-
import re
|
|
223
|
-
|
|
224
|
-
# 支持中英文句子分隔符:句号、问号、感叹号、分号、省略号
|
|
225
|
-
# 英文: . ? ! ; ... 中文: 。!?;…
|
|
226
|
-
pattern = r"([。!?;…\.\?!;]+|\.{3,})"
|
|
227
|
-
parts = re.split(pattern, text)
|
|
228
|
-
|
|
229
|
-
sentences = []
|
|
230
|
-
current = ""
|
|
231
|
-
for i, part in enumerate(parts):
|
|
232
|
-
current += part
|
|
233
|
-
if i % 2 == 1: # 分隔符部分
|
|
234
|
-
sentences.append(current)
|
|
235
|
-
current = ""
|
|
236
|
-
|
|
237
|
-
# 如果最后还有未结束的句子
|
|
238
|
-
if current:
|
|
239
|
-
sentences.append(current)
|
|
240
|
-
|
|
241
|
-
# 如果分割失败,按长度简单截断
|
|
242
|
-
if len(sentences) == 0:
|
|
243
|
-
return text[:200] + "..." if len(text) > 200 else text
|
|
244
|
-
|
|
245
|
-
# 取前max_sentences句
|
|
246
|
-
result = "".join(sentences[:max_sentences])
|
|
247
|
-
|
|
248
|
-
# 如果截断后比原文本短很多,添加省略号
|
|
249
|
-
if len(result) < len(text) * 0.8:
|
|
250
|
-
# 移除末尾的句子分隔符,添加省略号
|
|
251
|
-
result = result.rstrip("。!?;….?!;") + "…"
|
|
252
|
-
|
|
253
|
-
return result
|
|
254
|
-
|
|
255
212
|
def translate_text(self, text: str, target_lang: str = "zh") -> str:
|
|
256
213
|
"""使用DeepSeek或OpenAI API翻译文本,优先使用缓存"""
|
|
257
214
|
if not text or not text.strip():
|
|
@@ -364,7 +321,7 @@ class ReportGenerator:
|
|
|
364
321
|
|
|
365
322
|
with self.db.get_session() as session:
|
|
366
323
|
# Get papers from last 24 hours
|
|
367
|
-
cutoff = datetime.
|
|
324
|
+
cutoff = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=24)
|
|
368
325
|
new_papers = (
|
|
369
326
|
session.query(Paper)
|
|
370
327
|
.filter(Paper.created_at >= cutoff)
|
|
@@ -406,7 +363,7 @@ class ReportGenerator:
|
|
|
406
363
|
|
|
407
364
|
with self.db.get_session() as session:
|
|
408
365
|
# Get papers from last 7 days
|
|
409
|
-
cutoff = datetime.
|
|
366
|
+
cutoff = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(days=7)
|
|
410
367
|
recent_papers = (
|
|
411
368
|
session.query(Paper)
|
|
412
369
|
.filter(Paper.created_at >= cutoff)
|
arxiv_pulse/search_engine.py
CHANGED
|
@@ -2,11 +2,9 @@
|
|
|
2
2
|
增强搜索引擎 - 提供高级搜索和过滤功能
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
import
|
|
6
|
-
from
|
|
7
|
-
from typing import List, Dict, Any, Optional, Union
|
|
5
|
+
from datetime import datetime, timedelta, timezone
|
|
6
|
+
from typing import List, Dict, Any, Optional
|
|
8
7
|
from dataclasses import dataclass, field
|
|
9
|
-
from pathlib import Path
|
|
10
8
|
|
|
11
9
|
from sqlalchemy import and_, or_, not_, func, desc, asc
|
|
12
10
|
from sqlalchemy.orm import Session
|
|
@@ -21,7 +19,7 @@ class SearchFilter:
|
|
|
21
19
|
|
|
22
20
|
# 文本搜索
|
|
23
21
|
query: Optional[str] = None
|
|
24
|
-
search_fields: List[str] = field(default_factory=lambda: ["title", "abstract"
|
|
22
|
+
search_fields: List[str] = field(default_factory=lambda: ["title", "abstract"])
|
|
25
23
|
|
|
26
24
|
# 分类过滤
|
|
27
25
|
categories: Optional[List[str]] = None
|
|
@@ -64,31 +62,117 @@ class SearchEngine:
|
|
|
64
62
|
self.session = db_session
|
|
65
63
|
|
|
66
64
|
def build_text_filter(self, query: str, search_fields: List[str], match_all: bool = False):
|
|
67
|
-
"""
|
|
65
|
+
"""构建文本搜索过滤器,简单模糊匹配(支持单词拆分)"""
|
|
68
66
|
if not query or not search_fields:
|
|
69
67
|
return None
|
|
70
68
|
|
|
71
|
-
|
|
69
|
+
# 将查询转换为小写进行不区分大小写的匹配
|
|
70
|
+
query_lower = query.lower()
|
|
71
|
+
|
|
72
|
+
# 拆分为单词(按非字母数字字符,保留中文)
|
|
73
|
+
import re
|
|
74
|
+
|
|
75
|
+
# 使用正则表达式分割,保留中文字符(支持Unicode)
|
|
76
|
+
words = re.split(r"[^\w]+", query_lower, flags=re.UNICODE)
|
|
77
|
+
# 过滤掉空字符串和过短的单词(长度>1)
|
|
78
|
+
words = [w for w in words if w and len(w) > 1]
|
|
79
|
+
|
|
80
|
+
# 如果没有有效的单词,使用整个查询作为单个单词
|
|
81
|
+
if not words:
|
|
82
|
+
words = [query_lower]
|
|
83
|
+
|
|
84
|
+
# 如果只有一个单词,使用简单的字段间OR逻辑
|
|
85
|
+
if len(words) == 1:
|
|
86
|
+
word = words[0]
|
|
87
|
+
field_filters = []
|
|
88
|
+
for field in search_fields:
|
|
89
|
+
if field == "title":
|
|
90
|
+
field_filters.append(Paper.title.ilike(f"%{word}%"))
|
|
91
|
+
elif field == "abstract":
|
|
92
|
+
field_filters.append(Paper.abstract.ilike(f"%{word}%"))
|
|
93
|
+
elif field == "categories":
|
|
94
|
+
field_filters.append(Paper.categories.ilike(f"%{word}%"))
|
|
95
|
+
elif field == "search_query":
|
|
96
|
+
field_filters.append(Paper.search_query.ilike(f"%{word}%"))
|
|
97
|
+
elif field == "authors":
|
|
98
|
+
field_filters.append(Paper.authors.ilike(f"%{word}%"))
|
|
99
|
+
|
|
100
|
+
if field_filters:
|
|
101
|
+
return or_(*field_filters)
|
|
102
|
+
return None
|
|
103
|
+
|
|
104
|
+
# 多个单词:首先尝试短语匹配(整个查询字符串)
|
|
105
|
+
phrase_filters = []
|
|
72
106
|
for field in search_fields:
|
|
73
107
|
if field == "title":
|
|
74
|
-
|
|
108
|
+
phrase_filters.append(Paper.title.ilike(f"%{query_lower}%"))
|
|
75
109
|
elif field == "abstract":
|
|
76
|
-
|
|
110
|
+
phrase_filters.append(Paper.abstract.ilike(f"%{query_lower}%"))
|
|
77
111
|
elif field == "categories":
|
|
78
|
-
|
|
112
|
+
phrase_filters.append(Paper.categories.ilike(f"%{query_lower}%"))
|
|
79
113
|
elif field == "search_query":
|
|
80
|
-
|
|
114
|
+
phrase_filters.append(Paper.search_query.ilike(f"%{query_lower}%"))
|
|
81
115
|
elif field == "authors":
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
116
|
+
phrase_filters.append(Paper.authors.ilike(f"%{query_lower}%"))
|
|
117
|
+
|
|
118
|
+
# 尝试顺序匹配(单词按顺序出现,中间可间隔)
|
|
119
|
+
sequence_filters = []
|
|
120
|
+
if len(words) > 1:
|
|
121
|
+
# 构建模式:%word1%word2%word3%
|
|
122
|
+
sequence_pattern = "%" + "%".join(words) + "%"
|
|
123
|
+
for field in search_fields:
|
|
124
|
+
if field == "title":
|
|
125
|
+
sequence_filters.append(Paper.title.ilike(sequence_pattern))
|
|
126
|
+
elif field == "abstract":
|
|
127
|
+
sequence_filters.append(Paper.abstract.ilike(sequence_pattern))
|
|
128
|
+
elif field == "categories":
|
|
129
|
+
sequence_filters.append(Paper.categories.ilike(sequence_pattern))
|
|
130
|
+
elif field == "search_query":
|
|
131
|
+
sequence_filters.append(Paper.search_query.ilike(sequence_pattern))
|
|
132
|
+
elif field == "authors":
|
|
133
|
+
sequence_filters.append(Paper.authors.ilike(sequence_pattern))
|
|
134
|
+
|
|
135
|
+
# 然后添加单词AND匹配(所有单词必须在同一个字段中出现)
|
|
136
|
+
word_and_filters = []
|
|
137
|
+
for field in search_fields:
|
|
138
|
+
if field == "title":
|
|
139
|
+
# 标题必须包含所有单词
|
|
140
|
+
title_filters = [Paper.title.ilike(f"%{word}%") for word in words]
|
|
141
|
+
if title_filters:
|
|
142
|
+
word_and_filters.append(and_(*title_filters))
|
|
143
|
+
elif field == "abstract":
|
|
144
|
+
# 摘要必须包含所有单词
|
|
145
|
+
abstract_filters = [Paper.abstract.ilike(f"%{word}%") for word in words]
|
|
146
|
+
if abstract_filters:
|
|
147
|
+
word_and_filters.append(and_(*abstract_filters))
|
|
148
|
+
elif field == "categories":
|
|
149
|
+
# 分类必须包含所有单词(通常分类搜索是单个词)
|
|
150
|
+
category_filters = [Paper.categories.ilike(f"%{word}%") for word in words]
|
|
151
|
+
if category_filters:
|
|
152
|
+
word_and_filters.append(and_(*category_filters))
|
|
153
|
+
elif field == "search_query":
|
|
154
|
+
search_query_filters = [Paper.search_query.ilike(f"%{word}%") for word in words]
|
|
155
|
+
if search_query_filters:
|
|
156
|
+
word_and_filters.append(and_(*search_query_filters))
|
|
157
|
+
elif field == "authors":
|
|
158
|
+
author_filters = [Paper.authors.ilike(f"%{word}%") for word in words]
|
|
159
|
+
if author_filters:
|
|
160
|
+
word_and_filters.append(and_(*author_filters))
|
|
161
|
+
|
|
162
|
+
# 组合所有过滤器:短语匹配 OR 顺序匹配 OR 单词AND匹配
|
|
163
|
+
all_filters = []
|
|
164
|
+
if phrase_filters:
|
|
165
|
+
all_filters.append(or_(*phrase_filters))
|
|
166
|
+
if sequence_filters:
|
|
167
|
+
all_filters.append(or_(*sequence_filters))
|
|
168
|
+
if word_and_filters:
|
|
169
|
+
all_filters.append(or_(*word_and_filters))
|
|
170
|
+
|
|
171
|
+
if not all_filters:
|
|
86
172
|
return None
|
|
87
173
|
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
else:
|
|
91
|
-
return or_(*filters)
|
|
174
|
+
# 使用OR逻辑连接所有匹配类型
|
|
175
|
+
return or_(*all_filters)
|
|
92
176
|
|
|
93
177
|
def build_category_filter(
|
|
94
178
|
self,
|
|
@@ -147,7 +231,7 @@ class SearchEngine:
|
|
|
147
231
|
filters = []
|
|
148
232
|
|
|
149
233
|
if days_back:
|
|
150
|
-
cutoff_date = datetime.
|
|
234
|
+
cutoff_date = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(days=days_back)
|
|
151
235
|
filters.append(Paper.published >= cutoff_date)
|
|
152
236
|
|
|
153
237
|
if date_from:
|
|
@@ -290,38 +374,6 @@ class SearchEngine:
|
|
|
290
374
|
output.error("相似论文搜索失败", details={"exception": str(e)})
|
|
291
375
|
return []
|
|
292
376
|
|
|
293
|
-
# 简化的相似性搜索:基于共同关键词或分类
|
|
294
|
-
# 在实际应用中,可以使用更复杂的文本相似性算法
|
|
295
|
-
all_papers = self.session.query(Paper).filter(Paper.arxiv_id != paper_id).all()
|
|
296
|
-
|
|
297
|
-
# 计算简单相似度:分类重叠
|
|
298
|
-
similar_papers = []
|
|
299
|
-
target_cats = set(target_paper.categories.split()) if target_paper.categories else set()
|
|
300
|
-
|
|
301
|
-
for paper in all_papers:
|
|
302
|
-
if not paper.categories:
|
|
303
|
-
continue
|
|
304
|
-
|
|
305
|
-
paper_cats = set(paper.categories.split())
|
|
306
|
-
common_cats = target_cats.intersection(paper_cats)
|
|
307
|
-
|
|
308
|
-
if common_cats:
|
|
309
|
-
# 简单相似度分数:共同分类数 / 总分类数
|
|
310
|
-
similarity = len(common_cats) / max(len(target_cats), len(paper_cats))
|
|
311
|
-
if similarity >= threshold:
|
|
312
|
-
# 临时存储相似度分数
|
|
313
|
-
paper.similarity_score = similarity
|
|
314
|
-
similar_papers.append(paper)
|
|
315
|
-
|
|
316
|
-
# 按相似度排序
|
|
317
|
-
similar_papers.sort(key=lambda x: getattr(x, "similarity_score", 0), reverse=True)
|
|
318
|
-
|
|
319
|
-
return similar_papers[:limit]
|
|
320
|
-
|
|
321
|
-
except Exception as e:
|
|
322
|
-
output.error("相似论文搜索失败", details={"exception": str(e)})
|
|
323
|
-
return []
|
|
324
|
-
|
|
325
377
|
def get_search_history(self, limit: int = 10) -> List[Dict[str, Any]]:
|
|
326
378
|
"""获取搜索历史(从数据库中的search_query字段提取)"""
|
|
327
379
|
try:
|
|
@@ -360,7 +412,7 @@ class SearchEngine:
|
|
|
360
412
|
output.error("获取搜索历史失败", details={"exception": str(e)})
|
|
361
413
|
return []
|
|
362
414
|
|
|
363
|
-
def save_search_query(self, query: str, description: str = None):
|
|
415
|
+
def save_search_query(self, query: str, description: Optional[str] = None):
|
|
364
416
|
"""保存搜索查询到历史(简单实现)"""
|
|
365
417
|
# 这里可以扩展为保存到单独的搜索历史表
|
|
366
418
|
# 目前依赖于Paper表中的search_query字段
|
arxiv_pulse/summarizer.py
CHANGED