arionxiv-1.0.32-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arionxiv/__init__.py +40 -0
- arionxiv/__main__.py +10 -0
- arionxiv/arxiv_operations/__init__.py +0 -0
- arionxiv/arxiv_operations/client.py +225 -0
- arionxiv/arxiv_operations/fetcher.py +173 -0
- arionxiv/arxiv_operations/searcher.py +122 -0
- arionxiv/arxiv_operations/utils.py +293 -0
- arionxiv/cli/__init__.py +4 -0
- arionxiv/cli/commands/__init__.py +1 -0
- arionxiv/cli/commands/analyze.py +587 -0
- arionxiv/cli/commands/auth.py +365 -0
- arionxiv/cli/commands/chat.py +714 -0
- arionxiv/cli/commands/daily.py +482 -0
- arionxiv/cli/commands/fetch.py +217 -0
- arionxiv/cli/commands/library.py +295 -0
- arionxiv/cli/commands/preferences.py +426 -0
- arionxiv/cli/commands/search.py +254 -0
- arionxiv/cli/commands/settings_unified.py +1407 -0
- arionxiv/cli/commands/trending.py +41 -0
- arionxiv/cli/commands/welcome.py +168 -0
- arionxiv/cli/main.py +407 -0
- arionxiv/cli/ui/__init__.py +1 -0
- arionxiv/cli/ui/global_theme_manager.py +173 -0
- arionxiv/cli/ui/logo.py +127 -0
- arionxiv/cli/ui/splash.py +89 -0
- arionxiv/cli/ui/theme.py +32 -0
- arionxiv/cli/ui/theme_system.py +391 -0
- arionxiv/cli/utils/__init__.py +54 -0
- arionxiv/cli/utils/animations.py +522 -0
- arionxiv/cli/utils/api_client.py +583 -0
- arionxiv/cli/utils/api_config.py +505 -0
- arionxiv/cli/utils/command_suggestions.py +147 -0
- arionxiv/cli/utils/db_config_manager.py +254 -0
- arionxiv/github_actions_runner.py +206 -0
- arionxiv/main.py +23 -0
- arionxiv/prompts/__init__.py +9 -0
- arionxiv/prompts/prompts.py +247 -0
- arionxiv/rag_techniques/__init__.py +8 -0
- arionxiv/rag_techniques/basic_rag.py +1531 -0
- arionxiv/scheduler_daemon.py +139 -0
- arionxiv/server.py +1000 -0
- arionxiv/server_main.py +24 -0
- arionxiv/services/__init__.py +73 -0
- arionxiv/services/llm_client.py +30 -0
- arionxiv/services/llm_inference/__init__.py +58 -0
- arionxiv/services/llm_inference/groq_client.py +469 -0
- arionxiv/services/llm_inference/llm_utils.py +250 -0
- arionxiv/services/llm_inference/openrouter_client.py +564 -0
- arionxiv/services/unified_analysis_service.py +872 -0
- arionxiv/services/unified_auth_service.py +457 -0
- arionxiv/services/unified_config_service.py +456 -0
- arionxiv/services/unified_daily_dose_service.py +823 -0
- arionxiv/services/unified_database_service.py +1633 -0
- arionxiv/services/unified_llm_service.py +366 -0
- arionxiv/services/unified_paper_service.py +604 -0
- arionxiv/services/unified_pdf_service.py +522 -0
- arionxiv/services/unified_prompt_service.py +344 -0
- arionxiv/services/unified_scheduler_service.py +589 -0
- arionxiv/services/unified_user_service.py +954 -0
- arionxiv/utils/__init__.py +51 -0
- arionxiv/utils/api_helpers.py +200 -0
- arionxiv/utils/file_cleanup.py +150 -0
- arionxiv/utils/ip_helper.py +96 -0
- arionxiv-1.0.32.dist-info/METADATA +336 -0
- arionxiv-1.0.32.dist-info/RECORD +69 -0
- arionxiv-1.0.32.dist-info/WHEEL +5 -0
- arionxiv-1.0.32.dist-info/entry_points.txt +4 -0
- arionxiv-1.0.32.dist-info/licenses/LICENSE +21 -0
- arionxiv-1.0.32.dist-info/top_level.txt +1 -0
arionxiv/services/unified_database_service.py

@@ -0,0 +1,1633 @@

```python
"""
Unified Database Service for ArionXiv
Consolidates database_client.py and sync_db_wrapper.py
Provides comprehensive database management with both async and sync interfaces
"""

import os
import asyncio
import threading
from motor.motor_asyncio import AsyncIOMotorClient
from typing import Optional, Dict, Any, List
import logging
from datetime import datetime, timedelta
from concurrent.futures import ThreadPoolExecutor
from dotenv import load_dotenv

load_dotenv()

logger = logging.getLogger(__name__)

# Import IP helper for better error messages
try:
    from ..utils.ip_helper import check_mongodb_connection_error, get_public_ip
    IP_HELPER_AVAILABLE = True
except ImportError:
    IP_HELPER_AVAILABLE = False


class UnifiedDatabaseService:
    """
    Comprehensive database service that handles:
    1. MongoDB connections with TTL-based caching
    2. Synchronous wrapper for CLI operations (sync_db_wrapper.py functionality)
    3. All CRUD operations for papers, users, sessions, and analysis
    """

    # MongoDB URI must be provided via environment variable
    # Set MONGODB_URI in .env file or as environment variable
    DEFAULT_MONGODB_URI = None

    def __init__(self):
        # Async database clients
        self.mongodb_client: Optional[AsyncIOMotorClient] = None
        self.db = None

        # Sync operations support
        self.executor = ThreadPoolExecutor(max_workers=2)

        # MongoDB connection string - uses default production URI, can be overridden by env vars
        self._db_url = None
        self.database_name = os.getenv("DATABASE_NAME", "arionxiv")

        logger.info("UnifiedDatabaseService initialized")

    @property
    def db_url(self) -> Optional[str]:
        """Get MongoDB URL - uses environment variable if set, otherwise returns None."""
        if self._db_url is None:
            # Check environment variables first (for development/testing override)
            self._db_url = os.getenv('MONGODB_URI') or os.getenv('MONGODB_URL')
            # Fall back to default production URI if available
            if not self._db_url and self.DEFAULT_MONGODB_URI:
                self._db_url = self.DEFAULT_MONGODB_URI
        return self._db_url
```
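The connection string is resolved lazily by the `db_url` property above: it prefers the `MONGODB_URI` environment variable, then `MONGODB_URL`, and only then the class-level `DEFAULT_MONGODB_URI`, which ships as `None`. A minimal sketch of that resolution order (illustrative only, not part of the package; it assumes no `.env` file overrides the environment):

```python
import os

# The package ships DEFAULT_MONGODB_URI = None, so without MONGODB_URI (or
# MONGODB_URL) in the environment or .env, db_url simply resolves to None.
os.environ.setdefault("DATABASE_NAME", "arionxiv")  # already the default

from arionxiv.services.unified_database_service import UnifiedDatabaseService

service = UnifiedDatabaseService()
print(service.db_url)          # value of MONGODB_URI / MONGODB_URL, or None
print(service.database_name)   # "arionxiv" unless DATABASE_NAME overrides it
```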
```python
    # ============================================================
    # CONNECTION MANAGEMENT (from database_client.py)
    # ============================================================

    async def connect_mongodb(self):
        """
        Connect to MongoDB Atlas with proper SSL configuration
        """

        try:
            # Check if MongoDB URI is configured
            if not self.db_url:
                # Silent debug - end users use hosted Vercel API, local MongoDB is optional
                logger.debug("No local MongoDB URI configured - this is normal for end users")
                raise ValueError("Local MongoDB not configured")

            logger.info("Attempting to connect to MongoDB Atlas...")

            # Import config service for connection parameters
            from .unified_config_service import unified_config_service
            connection_params = unified_config_service.get_mongodb_connection_config()

            logger.info(f"Connection timeout: {connection_params['connectTimeoutMS']}ms")

            # For MongoDB Atlas (mongodb+srv://) - use property to trigger lazy load
            if 'mongodb+srv://' in self.db_url:
                logger.info("Connecting to MongoDB Atlas cluster...")
                self.mongodb_client = AsyncIOMotorClient(self.db_url, **connection_params)
            else:
                logger.info("Connecting to MongoDB instance...")
                self.mongodb_client = AsyncIOMotorClient(self.db_url, **connection_params)

            # Set database
            self.db = self.mongodb_client[self.database_name]

            # Test the connection with a ping
            logger.info("Testing MongoDB connection...")
            await asyncio.wait_for(
                self.mongodb_client.admin.command('ping'),
                timeout=30.0
            )

            logger.info(f"Successfully connected to MongoDB: {self.database_name}")

            # Create indexes
            await self.create_indexes()
            logger.info("Database indexes created/verified")

        except asyncio.TimeoutError:
            logger.debug("MongoDB connection timeout")
            if IP_HELPER_AVAILABLE:
                check_mongodb_connection_error("connection timeout")
            raise Exception("MongoDB connection timeout")
        except Exception as e:
            error_msg = str(e)

            # Check for common IP whitelisting issues
            if IP_HELPER_AVAILABLE:
                check_mongodb_connection_error(error_msg)

            if "SSL handshake failed" in error_msg and "TLSV1_ALERT_INTERNAL_ERROR" in error_msg:
                logger.debug("MongoDB Atlas SSL handshake failed - check credentials")
                raise Exception("Local MongoDB connection failed")
            else:
                logger.debug(f"MongoDB connection issue: {str(e)}")
                raise Exception(f"Failed to connect to MongoDB: {str(e)}")

    def _enable_offline_mode(self):
        """Enable offline mode with in-memory storage"""
        logger.info("ArionXiv running in offline mode - using in-memory storage")
        # Create a simple in-memory storage
        self._offline_storage = {
            'papers': {},
            'users': {},
            'sessions': {},
            'analyses': {}
        }
        # Mark as offline mode
        self._offline_mode = True

    def is_offline(self) -> bool:
        """
        Check if the service is in offline mode, i.e. not connected to MongoDB
        """
        return getattr(self, '_offline_mode', False)

    async def create_indexes(self):
        """
        Create necessary indexes for collections

        Indices:
        - Papers: arxiv_id (unique), title, authors, categories, published, text search on title+abstract
        - Users: user_name (unique), email, user_id, created_at, updated_at, etc.
        - User Papers: user_name, arxiv_id, category, added_at, TTL index for cleanup
        - Cache: key (unique), expires_at (TTL)
        - Chat Sessions: user_name, created_at, TTL index for cleanup
        - Daily Analysis: user_id, analysis_type, analyzed_at, paper_id, created_at, generated_at
        - Cron Jobs: user_id, job_type (unique), status, updated_at
        - Paper Texts and Embeddings: paper_id (unique), chunk_id (unique), expires_at (TTL)
        """
        if self.db is None:
            return

        try:
            # Papers collection indexes
            await self.db.papers.create_index("arxiv_id", unique=True)
            await self.db.papers.create_index("title")
            await self.db.papers.create_index("authors")
            await self.db.papers.create_index("categories")
            await self.db.papers.create_index("published")
            await self.db.papers.create_index([("title", "text"), ("abstract", "text")])

            # Users collection indexes
            await self.db.users.create_index("user_name", unique=True)
            await self.db.users.create_index("email")

            # User papers collection indexes
            # Drop old stale indexes if they exist (migration from old field names)
            try:
                await self.db.user_papers.drop_index("user_id_1_paper_id_1")
            except Exception:
                pass  # Index doesn't exist, that's fine

            # Clean up any documents with null/missing key fields from old schema
            try:
                await self.db.user_papers.delete_many({
                    "$or": [
                        {"user_name": {"$exists": False}},
                        {"user_name": None},
                        {"arxiv_id": {"$exists": False}},
                        {"arxiv_id": None}
                    ]
                })
            except Exception:
                pass

            await self.db.user_papers.create_index([("user_name", 1), ("arxiv_id", 1)], unique=True, sparse=True)
            await self.db.user_papers.create_index("user_name")
            await self.db.user_papers.create_index("category")
            await self.db.user_papers.create_index("added_at")

            # TTL index for cleanup (papers older than 30 days will be removed)
            try:
                await self.db.user_papers.create_index("added_at", name="user_papers_ttl", expireAfterSeconds=30*24*60*60)
            except Exception:
                pass  # Index might already exist

            # MongoDB TTL-based cache collection indexes
            await self.db.cache.create_index("key", unique=True)
            try:
                await self.db.cache.create_index("expires_at", name="cache_ttl", expireAfterSeconds=0)
            except Exception:
                pass  # Index might already exist

            # Chat sessions collection indexes
            await self.db.chat_sessions.create_index("user_name")
            await self.db.chat_sessions.create_index("session_id", unique=True)
            await self.db.chat_sessions.create_index("created_at")
            try:
                # 24-hour TTL based on expires_at field
                await self.db.chat_sessions.create_index("expires_at", name="chat_sessions_ttl", expireAfterSeconds=0)
            except Exception:
                pass  # Index might already exist

            # Daily analysis collection indexes
            await self.db.daily_analysis.create_index([("user_id", 1), ("analysis_type", 1)])
            await self.db.daily_analysis.create_index([("user_id", 1), ("analyzed_at", 1)])
            await self.db.daily_analysis.create_index([("user_id", 1), ("paper_id", 1)])
            await self.db.daily_analysis.create_index("created_at")
            await self.db.daily_analysis.create_index("generated_at")

            # Cron jobs collection indexes
            await self.db.cron_jobs.create_index([("user_id", 1), ("job_type", 1)], unique=True)
            await self.db.cron_jobs.create_index("status")
            await self.db.cron_jobs.create_index("updated_at")

            # Paper texts and embeddings indexes
            await self.db.paper_texts.create_index("paper_id", unique=True)
            try:
                await self.db.paper_texts.create_index("expires_at", name="paper_texts_ttl", expireAfterSeconds=0)
            except Exception:
                pass  # Index might already exist
            await self.db.paper_embeddings.create_index("paper_id")
            await self.db.paper_embeddings.create_index("chunk_id", unique=True)
            try:
                await self.db.paper_embeddings.create_index("expires_at", name="paper_embeddings_ttl", expireAfterSeconds=0)
            except Exception:
                pass  # Index might already exist

            # Prompts collection indexes
            await self.db.prompts.create_index("prompt_name", unique=True)
            await self.db.prompts.create_index("updated_at")

            # Daily dose collection indexes
            await self.db.daily_dose.create_index("user_id")
            await self.db.daily_dose.create_index([("user_id", 1), ("generated_at", -1)])
            await self.db.daily_dose.create_index("date")
            await self.db.daily_dose.create_index("created_at")

            logger.debug("Database indexes created successfully")

        except Exception as e:
            # Suppress index creation warnings - often caused by existing data inconsistencies
            logger.debug(f"Index creation skipped: {str(e)}")

    async def connect(self):
        """
        Connect to MongoDB with fallback to offline mode

        Purpose: Establish database connections and initialize in offline mode if connection fails
        """

        logger.info("Initializing database connections...")

        # Try to connect to MongoDB with fallback to offline mode
        try:
            await self.connect_mongodb()
            logger.info(" MongoDB connection successful - running in ONLINE mode")
            self._offline_mode = False
        except Exception as e:
            # Silent - end users use hosted API, local DB is optional
            logger.debug(f"Local MongoDB not available: {str(e)}")
            self._offline_mode = True
            self.offline_papers = {}
            self.offline_users = {}
            self.offline_auth_sessions = {}

        logger.info("Database service ready!")

    async def disconnect(self):
        """Disconnect from MongoDB"""
        try:
            if self.mongodb_client:
                self.mongodb_client.close()

            logger.debug("Disconnected from databases")

        except Exception as e:
            logger.error(f"Error disconnecting from databases: {str(e)}")

    async def health_check(self) -> Dict[str, bool]:
        """Check health of database connections"""
        health = {"mongodb": False}

        # MongoDB health check
        try:
            if self.db:
                await self.mongodb_client.admin.command('ping')
                health["mongodb"] = True
        except Exception as e:
            logger.debug(f"MongoDB health check failed: {str(e)}")

        return health
```
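Together, `connect()`, `health_check()` and `disconnect()` form the async lifecycle: `connect()` attempts `connect_mongodb()` and silently falls back to offline, in-memory mode when no URI is configured or the Atlas ping fails. A minimal usage sketch, illustrative rather than part of the package:

```python
import asyncio

from arionxiv.services.unified_database_service import UnifiedDatabaseService


async def main() -> None:
    service = UnifiedDatabaseService()

    # Falls back to offline (in-memory) mode if MONGODB_URI is unset or unreachable.
    await service.connect()

    if service.is_offline():
        print("offline mode: reads come from in-memory storage")
    else:
        print(await service.health_check())  # {"mongodb": True} when the ping succeeds

    await service.disconnect()


asyncio.run(main())
```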
```python
    # ============================================================
    # BASIC CRUD OPERATIONS
    # ============================================================

    async def find_one(self, collection: str, filter_dict: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """
        Find a single document in a collection

        Args:
            collection (str): The name of the collection to search
            filter_dict (Dict[str, Any]): The filter criteria for the document

        Returns:
            Optional[Dict[str, Any]]: The found document or None if not found
        """

        if self.is_offline():
            # Offline mode - search in memory
            logger.info("Running in offline mode - searching in in-memory storage")
            storage = getattr(self, '_offline_storage', {})
            for doc_id, doc in storage.get(collection, {}).items():
                if all(doc.get(k) == v for k, v in filter_dict.items()):
                    return doc
            return None

        if self.db is None:
            logger.debug("No database connection available")
            return None

        try:
            logger.debug(f"Querying collection '{collection}' with filter: {filter_dict}")
            result = await self.db[collection].find_one(filter_dict)
            return result
        except RuntimeError as e:
            # Handle closed event loop gracefully
            if "Event loop is closed" in str(e):
                logger.debug("Query skipped - event loop closed during shutdown")
            else:
                logger.debug(f"Query error: {str(e)}")
            return None
        except Exception as e:
            logger.debug(f"Query error: {str(e)}")
            return None

    async def update_one(self, collection: str, filter_dict: Dict[str, Any], update_dict: Dict[str, Any], upsert: bool = False):
        """
        Update a single document in a collection

        Args:
            collection (str): The name of the collection to update
            filter_dict (Dict[str, Any]): The filter criteria for the document to update
            update_dict (Dict[str, Any]): The update operations to apply
            upsert (bool): If True, insert document if not found

        Returns:
            The result of the update operation
        """

        if self.is_offline():
            logger.error("Cannot update document in offline mode - Database connection required")
            raise ConnectionError("Database connection is not available. Please check your MongoDB connection string in the .env file.")

        if self.db is None:
            logger.debug("No database connection available")
            raise ConnectionError("Database connection is not available. Please check your MongoDB connection string in the .env file.")

        try:
            result = await self.db[collection].update_one(filter_dict, update_dict, upsert=upsert)
            return result
        except Exception as e:
            logger.error(f"Failed to update document: {str(e)}")
            return None

    async def insert_one(self, collection: str, document: Dict[str, Any]):
        """
        Insert a single document into a collection

        Args:
            collection (str): The name of the collection to insert into
            document (Dict[str, Any]): The document to insert

        Returns:
            The result of the insert operation
        """

        if self.db is None:
            logger.debug("No database connection available")
            return None

        try:
            result = await self.db[collection].insert_one(document)
            return result
        except Exception as e:
            logger.error(f"Failed to insert document: {str(e)}")
            return None

    async def delete_one(self, collection: str, filter_dict: Dict[str, Any]):
        """Delete a single document from a collection"""
        if self.db is None:
            logger.debug("No database connection available")
            return None

        try:
            result = await self.db[collection].delete_one(filter_dict)
            return result
        except Exception as e:
            logger.error(f"Failed to delete document: {str(e)}")
            return None

    async def aggregate(self, collection: str, pipeline: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Run an aggregation pipeline on a collection

        Args:
            collection (str): The name of the collection
            pipeline (List[Dict[str, Any]]): The aggregation pipeline stages

        Returns:
            List[Dict[str, Any]]: Results from the aggregation
        """
        if self.db is None:
            logger.debug("No database connection available")
            return []

        try:
            cursor = self.db[collection].aggregate(pipeline)
            results = await cursor.to_list(length=100)
            return results
        except Exception as e:
            logger.error(f"Aggregation failed: {str(e)}")
            return []

    async def find_many(self, collection: str, filter_dict: Dict[str, Any] = None, limit: int = 100) -> List[Dict[str, Any]]:
        """
        Find multiple documents in a collection

        Args:
            collection (str): The name of the collection to search
            filter_dict (Dict[str, Any]): The filter criteria (optional)
            limit (int): Maximum number of documents to return

        Returns:
            List[Dict[str, Any]]: List of found documents
        """
        if self.db is None:
            logger.debug("No database connection available")
            return []

        try:
            filter_dict = filter_dict or {}
            cursor = self.db[collection].find(filter_dict).limit(limit)
            results = await cursor.to_list(length=limit)
            return results
        except Exception as e:
            logger.error(f"Find many failed: {str(e)}")
            return []

    async def delete_many(self, collection: str, filter_dict: Dict[str, Any]) -> int:
        """
        Delete multiple documents from a collection

        Args:
            collection (str): The name of the collection
            filter_dict (Dict[str, Any]): The filter criteria for deletion

        Returns:
            int: Number of documents deleted
        """
        if self.db is None:
            logger.debug("No database connection available")
            return 0

        try:
            result = await self.db[collection].delete_many(filter_dict)
            return result.deleted_count
        except Exception as e:
            logger.error(f"Delete many failed: {str(e)}")
            return 0

    async def insert_many(self, collection: str, documents: List[Dict[str, Any]]) -> List[Any]:
        """
        Insert multiple documents into a collection

        Args:
            collection (str): The name of the collection
            documents (List[Dict[str, Any]]): List of documents to insert

        Returns:
            List[Any]: List of inserted document IDs
        """
        if self.db is None:
            logger.debug("No database connection available")
            return []

        if not documents:
            return []

        try:
            result = await self.db[collection].insert_many(documents)
            return result.inserted_ids
        except Exception as e:
            logger.error(f"Insert many failed: {str(e)}")
            return []
```
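The generic helpers above take a collection name plus plain filter/update documents and swallow most errors, returning `None`, `[]` or `0`; the notable exception is `update_one`, which raises `ConnectionError` when the service is offline or has no database handle. A short sketch of how they might be called (collection values and field names below are illustrative, not package defaults):

```python
from typing import Any, Dict, List

from arionxiv.services.unified_database_service import UnifiedDatabaseService


async def crud_examples(service: UnifiedDatabaseService) -> None:
    # Assumes service.connect() was awaited and succeeded; the literal values
    # ("alice", "2401.00001", "cs.CL") are placeholders.
    doc = await service.find_one("papers", {"arxiv_id": "2401.00001"})
    print(doc)

    # Upsert; raises ConnectionError if the service is offline or not connected.
    await service.update_one(
        "users",
        {"user_name": "alice"},
        {"$set": {"preferences": {"categories": ["cs.CL"]}}},
        upsert=True,
    )

    recent: List[Dict[str, Any]] = await service.find_many("papers", {"categories": "cs.LG"}, limit=10)
    removed: int = await service.delete_many("cache", {"key": "search:cs.LG"})
    print(len(recent), removed)
```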
```python
    # ============================================================
    # PAPER MANAGEMENT
    # ============================================================

    async def save_paper(self, paper_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Save a paper to the database

        Args:
            paper_data (Dict[str, Any]): The paper data to save

        Returns:
            Dict[str, Any]: Result of the save operation
        """

        if self.db is None:
            logger.debug("No database connection available")
            return {"success": False, "message": "Database not connected"}

        try:
            # Add timestamp
            paper_data["saved_at"] = datetime.utcnow()

            # Upsert the paper
            result = await self.db.papers.update_one(
                {"arxiv_id": paper_data["arxiv_id"]},
                {"$set": paper_data},
                upsert=True
            )

            logger.debug(f"Paper saved: {paper_data['arxiv_id']}")
            return {"success": True, "message": "Paper saved successfully"}

        except Exception as e:
            logger.error(f"Failed to save paper: {str(e)}")
            return {"success": False, "message": str(e)}

    async def get_paper(self, arxiv_id: str) -> Optional[Dict[str, Any]]:
        """
        Get a paper by its arXiv ID

        Args:
            arxiv_id (str): The arXiv ID of the paper to retrieve
        """

        if self.db is None:
            return None

        try:
            paper = await self.db.papers.find_one({"arxiv_id": arxiv_id})
            return paper

        except Exception as e:
            logger.error(f"Failed to get paper: {str(e)}")
            return None

    async def search_papers(self, query: str, limit: int = 50) -> List[Dict[str, Any]]:
        """
        Search papers using text search

        Args:
            query (str): The search query
            limit (int): Maximum number of papers to return

        Returns:
            List[Dict[str, Any]]: List of papers matching the query
        """

        if self.db is None:
            return []

        try:
            cursor = self.db.papers.find(
                {"$text": {"$search": query}},
                {"score": {"$meta": "textScore"}}
            ).sort([("score", {"$meta": "textScore"})]).limit(limit)

            papers = await cursor.to_list(length=limit)
            return papers

        except Exception as e:
            logger.error(f"Failed to search papers: {str(e)}")
            return []

    async def get_papers_by_category(self, category: str, limit: int = 50) -> List[Dict[str, Any]]:
        """
        Get papers by category

        Args:
            category (str): The category to filter papers by
            limit (int): Maximum number of papers to return

        Returns:
            List[Dict[str, Any]]: List of papers in the specified category
        """
        if self.db is None:
            return []

        try:
            cursor = self.db.papers.find(
                {"categories": {"$in": [category]}}
            ).sort("published", -1).limit(limit)

            papers = await cursor.to_list(length=limit)
            return papers

        except Exception as e:
            logger.error(f"Failed to get papers by category: {str(e)}")
            return []

    async def get_recent_papers(self, days: int = 7, limit: int = 50) -> List[Dict[str, Any]]:
        """
        Get papers published in the last 'days' days

        Args:
            days (int): Number of days to look back
            limit (int): Maximum number of papers to return

        Returns:
            List[Dict[str, Any]]: List of recent papers
        """

        if self.db is None:
            return []

        try:
            cutoff_date = datetime.utcnow() - timedelta(days=days)

            cursor = self.db.papers.find(
                {"published": {"$gte": cutoff_date.isoformat()}}
            ).sort("published", -1).limit(limit)

            papers = await cursor.to_list(length=limit)
            return papers

        except Exception as e:
            logger.error(f"Failed to get recent papers: {str(e)}")
            return []
```
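The paper helpers key everything on `arxiv_id` and lean on the indexes created in `create_indexes()`: `save_paper` upserts, `search_papers` uses the `$text` index over title and abstract, and `get_recent_papers` compares ISO-formatted `published` strings. A hedged sketch of the expected document shape, with invented field values:

```python
from arionxiv.services.unified_database_service import UnifiedDatabaseService


async def paper_examples(service: UnifiedDatabaseService) -> None:
    # Minimal document carrying the fields the indexes expect; values are placeholders.
    paper = {
        "arxiv_id": "2401.00001",
        "title": "An Illustrative Paper",
        "abstract": "Placeholder abstract used only for this sketch.",
        "authors": ["A. Author"],
        "categories": ["cs.CL"],
        "published": "2024-01-01T00:00:00",
    }

    result = await service.save_paper(paper)          # {"success": True, ...} on upsert
    hits = await service.search_papers("illustrative", limit=5)
    latest = await service.get_recent_papers(days=7, limit=20)
    print(result, len(hits), len(latest))
```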
```python
    # ============================================================
    # USER MANAGEMENT
    # ============================================================

    async def create_user(self, user_data: Dict[str, Any]) -> bool:
        """
        Create a new user

        Args:
            user_data (Dict[str, Any]): The user data to create an account for
        """

        if self.db is None:
            return False

        try:
            user_data["created_at"] = datetime.utcnow()
            user_data["updated_at"] = datetime.utcnow()

            await self.db.users.insert_one(user_data)
            logger.debug(f"User created: {user_data.get('email')}")
            return True

        except Exception as e:
            logger.error(f"Failed to create user: {str(e)}")
            return False

    async def get_user(self, user_name: str) -> Optional[Dict[str, Any]]:
        """Get user by primary key (user_name) with legacy fallback"""
        if self.db is None:
            return None

        if not user_name:
            return None

        try:
            query = {
                "$or": [
                    {"user_name": user_name},
                    {"username": user_name}
                ]
            }
            user = await self.db.users.find_one(query)
            return user

        except Exception as e:
            logger.error(f"Failed to get user by name: {str(e)}")
            return None

    async def get_user_by_email(self, email: str) -> Optional[Dict[str, Any]]:
        """Get user by email (non-unique helper)"""
        if self.db is None or not email:
            return None

        try:
            return await self.db.users.find_one({"email": email})
        except Exception as e:
            logger.error(f"Failed to get user by email: {str(e)}")
            return None

    async def update_user(self, user_name: str, update_data: Dict[str, Any]) -> bool:
        """Update user data keyed by user_name"""
        if self.db is None or not user_name:
            return False

        try:
            update_data["updated_at"] = datetime.utcnow()

            result = await self.db.users.update_one(
                {"user_name": user_name},
                {"$set": update_data}
            )

            return result.modified_count > 0

        except Exception as e:
            logger.error(f"Failed to update user: {str(e)}")
            return False
```
```python
    # ====================
    # USER PAPERS MANAGEMENT
    # ====================

    async def add_user_paper(self, user_name: str, arxiv_id: str, category: str = None, _retry_count: int = 0) -> bool:
        """Add a paper to user's library"""
        import asyncio

        if self.db is None:
            return False

        MAX_RETRIES = 3

        try:
            paper_data = {
                "user_name": user_name,
                "arxiv_id": arxiv_id,
                "category": category,
                "added_at": datetime.utcnow()
            }

            await self.db.user_papers.update_one(
                {"user_name": user_name, "arxiv_id": arxiv_id},
                {"$set": paper_data},
                upsert=True
            )

            logger.debug(f"Paper added to user library: {arxiv_id}")
            return True

        except Exception as e:
            error_str = str(e).lower()

            # Handle rate limit errors - retry with backoff
            if any(term in error_str for term in ['rate limit', 'too many requests', '429', 'throttl']):
                if _retry_count < MAX_RETRIES:
                    wait_time = (2 ** _retry_count) * 2  # 2, 4, 8 seconds
                    logger.info(f"Rate limit hit, waiting {wait_time}s before retry {_retry_count + 1}/{MAX_RETRIES}")
                    await asyncio.sleep(wait_time)
                    return await self.add_user_paper(user_name, arxiv_id, category, _retry_count + 1)
                else:
                    logger.error("Max retries reached for rate limit")
                    return False

            # Handle stale index error - drop the old index and retry
            if "user_id_1_paper_id_1" in str(e):
                try:
                    logger.info("Dropping stale index user_id_1_paper_id_1 and retrying...")
                    await self.db.user_papers.drop_index("user_id_1_paper_id_1")
                    # Also clean up any orphaned documents with old schema
                    await self.db.user_papers.delete_many({
                        "$or": [
                            {"user_id": {"$exists": True}},
                            {"paper_id": {"$exists": True}},
                            {"user_name": None},
                            {"arxiv_id": None}
                        ]
                    })
                    # Retry the insert
                    paper_data = {
                        "user_name": user_name,
                        "arxiv_id": arxiv_id,
                        "category": category,
                        "added_at": datetime.utcnow()
                    }
                    await self.db.user_papers.update_one(
                        {"user_name": user_name, "arxiv_id": arxiv_id},
                        {"$set": paper_data},
                        upsert=True
                    )
                    logger.info("Successfully added paper after dropping stale index")
                    return True
                except Exception as retry_error:
                    logger.error(f"Failed to add user paper after index cleanup: {str(retry_error)}")
                    return False

            logger.error(f"Failed to add user paper: {str(e)}")
            return False

    async def remove_user_paper(self, user_name: str, arxiv_id: str) -> bool:
        """Remove a paper from user's library"""
        if self.db is None:
            return False

        try:
            result = await self.db.user_papers.delete_one(
                {"user_name": user_name, "arxiv_id": arxiv_id}
            )

            return result.deleted_count > 0

        except Exception as e:
            logger.error(f"Failed to remove user paper: {str(e)}")
            return False

    async def get_user_papers(self, user_name: str, category: str = None) -> List[Dict[str, Any]]:
        """Get user's papers, optionally filtered by category"""
        if self.db is None:
            return []

        try:
            query = {"user_name": user_name}
            if category:
                query["category"] = category

            # Get user paper records
            cursor = self.db.user_papers.find(query).sort("added_at", -1)
            user_papers = await cursor.to_list(length=None)

            # Get full paper details
            if user_papers:
                arxiv_ids = [up["arxiv_id"] for up in user_papers]
                papers_cursor = self.db.papers.find({"arxiv_id": {"$in": arxiv_ids}})
                papers = await papers_cursor.to_list(length=None)

                # Create a mapping for quick lookup
                papers_dict = {p["arxiv_id"]: p for p in papers}

                # Combine user paper info with full paper details
                result = []
                for user_paper in user_papers:
                    arxiv_id = user_paper["arxiv_id"]
                    if arxiv_id in papers_dict:
                        paper = papers_dict[arxiv_id].copy()
                        paper["user_category"] = user_paper.get("category")
                        paper["added_at"] = user_paper["added_at"]
                        result.append(paper)

                return result

            return []

        except Exception as e:
            logger.error(f"Failed to get user papers: {str(e)}")
            return []

    async def get_user_paper_categories(self, user_name: str) -> List[str]:
        """Get unique categories from user's papers"""
        if self.db is None:
            return []

        try:
            pipeline = [
                {"$match": {"user_name": user_name, "category": {"$ne": None}}},
                {"$group": {"_id": "$category"}},
                {"$sort": {"_id": 1}}
            ]

            cursor = self.db.user_papers.aggregate(pipeline)
            categories = await cursor.to_list(length=None)

            return [cat["_id"] for cat in categories if cat["_id"]]

        except Exception as e:
            logger.error(f"Failed to get user paper categories: {str(e)}")
            return []
```
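`add_user_paper` retries rate-limited writes with exponential backoff: for `_retry_count` values 0, 1 and 2, `wait_time = (2 ** _retry_count) * 2` evaluates to 2 s, 4 s and 8 s, after which the method gives up. A standalone sketch of that schedule, for illustration only:

```python
MAX_RETRIES = 3  # mirrors the constant used inside add_user_paper

for retry_count in range(MAX_RETRIES):
    wait_time = (2 ** retry_count) * 2
    print(f"retry {retry_count + 1}/{MAX_RETRIES} after {wait_time}s")
# -> retry 1/3 after 2s, retry 2/3 after 4s, retry 3/3 after 8s
```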
```python
    # ====================
    # CHAT SESSIONS MANAGEMENT
    # ====================

    async def create_chat_session(self, user_name: str, session_data: Dict[str, Any]) -> str:
        """Create a new chat session"""
        if self.db is None:
            return ""

        try:
            session_data.update({
                "user_name": user_name,
                "created_at": datetime.utcnow(),
                "updated_at": datetime.utcnow()
            })

            result = await self.db.chat_sessions.insert_one(session_data)
            session_id = str(result.inserted_id)

            logger.debug(f"Chat session created: {session_id}")
            return session_id

        except Exception as e:
            logger.error(f"Failed to create chat session: {str(e)}")
            return ""

    async def get_chat_session(self, session_id: str, user_name: str = None) -> Optional[Dict[str, Any]]:
        """Get a chat session"""
        if self.db is None:
            return None

        try:
            from bson import ObjectId

            query = {"_id": ObjectId(session_id)}
            if user_name:
                query["user_name"] = user_name

            session = await self.db.chat_sessions.find_one(query)
            if session:
                session["_id"] = str(session["_id"])

            return session

        except Exception as e:
            logger.error(f"Failed to get chat session: {str(e)}")
            return None

    async def update_chat_session(self, session_id: str, update_data: Dict[str, Any], user_name: str = None) -> bool:
        """Update a chat session"""
        if self.db is None:
            return False

        try:
            from bson import ObjectId

            query = {"_id": ObjectId(session_id)}
            if user_name:
                query["user_name"] = user_name

            update_data["updated_at"] = datetime.utcnow()

            result = await self.db.chat_sessions.update_one(
                query,
                {"$set": update_data}
            )

            return result.modified_count > 0

        except Exception as e:
            logger.error(f"Failed to update chat session: {str(e)}")
            return False

    async def get_user_chat_sessions(self, user_name: str, limit: int = 10) -> List[Dict[str, Any]]:
        """Get user's recent chat sessions"""
        if self.db is None:
            return []

        try:
            cursor = self.db.chat_sessions.find(
                {"user_name": user_name}
            ).sort("updated_at", -1).limit(limit)

            sessions = await cursor.to_list(length=limit)

            # Convert ObjectId to string
            for session in sessions:
                session["_id"] = str(session["_id"])

            return sessions

        except Exception as e:
            logger.error(f"Failed to get user chat sessions: {str(e)}")
            return []

    async def get_active_chat_sessions(self, user_name: str, limit: int = 20) -> List[Dict[str, Any]]:
        """Get user's active chat sessions within the last 24 hours (not expired)"""
        if self.db is None:
            return []

        try:
            now = datetime.utcnow()

            # Find sessions that haven't expired yet
            cursor = self.db.chat_sessions.find({
                "user_id": user_name,
                "expires_at": {"$gt": now}
            }).sort("last_activity", -1).limit(limit)

            sessions = await cursor.to_list(length=limit)

            # Convert ObjectId to string and add message count
            for session in sessions:
                session["_id"] = str(session["_id"])
                session["message_count"] = len(session.get("messages", []))

            return sessions

        except Exception as e:
            logger.error(f"Failed to get active chat sessions: {str(e)}")
            return []

    async def extend_chat_session_ttl(self, session_id: str, hours: int = 24) -> bool:
        """Extend the TTL of a chat session by the specified hours"""
        if self.db is None:
            return False

        try:
            new_expiry = datetime.utcnow() + timedelta(hours=hours)

            result = await self.db.chat_sessions.update_one(
                {"session_id": session_id},
                {"$set": {
                    "expires_at": new_expiry,
                    "last_activity": datetime.utcnow()
                }}
            )

            return result.modified_count > 0

        except Exception as e:
            logger.error(f"Failed to extend chat session TTL: {str(e)}")
            return False
```
```python
    # ====================
    # DAILY ANALYSIS MANAGEMENT
    # ====================

    async def save_daily_analysis(self, user_id: str, analysis_data: Dict[str, Any]) -> Dict[str, Any]:
        """Save daily analysis for a user, replacing any existing analysis for today"""
        if self.db is None:
            return {"success": False, "message": "Database not connected"}

        try:
            # Get today's date boundaries
            today = datetime.utcnow().date()
            start_of_day = datetime.combine(today, datetime.min.time())
            end_of_day = datetime.combine(today, datetime.max.time())

            # Delete any existing analysis for today (enforce one per day rule)
            await self.db.daily_analysis.delete_many({
                "user_id": user_id,
                "generated_at": {
                    "$gte": start_of_day.isoformat(),
                    "$lte": end_of_day.isoformat()
                }
            })

            # Delete yesterday's analysis to keep only current analysis
            yesterday = today - timedelta(days=1)
            start_of_yesterday = datetime.combine(yesterday, datetime.min.time())
            end_of_yesterday = datetime.combine(yesterday, datetime.max.time())

            deleted_result = await self.db.daily_analysis.delete_many({
                "user_id": user_id,
                "generated_at": {
                    "$gte": start_of_yesterday.isoformat(),
                    "$lte": end_of_yesterday.isoformat()
                }
            })

            if deleted_result.deleted_count > 0:
                logger.info(f"Deleted {deleted_result.deleted_count} previous daily analysis records for user {user_id}")

            # Add metadata
            analysis_data.update({
                "user_id": user_id,
                "analysis_type": "daily_analysis",
                "created_at": datetime.utcnow(),
                "generated_at": datetime.utcnow().isoformat()
            })

            # Insert new analysis
            result = await self.db.daily_analysis.insert_one(analysis_data)
            analysis_id = str(result.inserted_id)

            logger.info(f"Saved daily analysis for user {user_id}, analysis_id: {analysis_id}")
            return {
                "success": True,
                "analysis_id": analysis_id,
                "message": "Daily analysis saved successfully"
            }

        except Exception as e:
            logger.error(f"Failed to save daily analysis for user {user_id}", error=str(e))
            return {"success": False, "message": str(e)}

    async def get_latest_daily_analysis(self, user_id: str) -> Dict[str, Any]:
        """Get the latest daily analysis for a user"""
        if self.db is None:
            return {"success": False, "message": "Database not connected"}

        try:
            # Get the most recent daily analysis
            analysis = await self.db.daily_analysis.find_one(
                {"user_id": user_id, "analysis_type": "daily_analysis"},
                sort=[("created_at", -1)]
            )

            if analysis:
                # Convert ObjectId to string
                analysis["_id"] = str(analysis["_id"])

                return {
                    "success": True,
                    "analysis": {
                        "id": analysis["_id"],
                        "data": analysis,
                        "created_at": analysis["created_at"]
                    }
                }
            else:
                return {
                    "success": False,
                    "message": "No daily analysis found for user"
                }

        except Exception as e:
            logger.error(f"Failed to get latest daily analysis for user {user_id}", error=str(e))
            return {"success": False, "message": str(e)}

    async def schedule_daily_job(self, user_id: str, job_config: Dict[str, Any]) -> Dict[str, Any]:
        """Schedule a daily job for a user"""
        if self.db is None:
            return {"success": False, "message": "Database not connected"}

        try:
            job_data = {
                "user_id": user_id,
                "job_type": "daily_dose",
                "config": job_config,
                "status": "active",
                "created_at": datetime.utcnow(),
                "updated_at": datetime.utcnow()
            }

            # Upsert the job configuration
            await self.db.cron_jobs.update_one(
                {"user_id": user_id, "job_type": "daily_dose"},
                {"$set": job_data},
                upsert=True
            )

            logger.info(f"Scheduled daily job for user {user_id}")
            return {"success": True, "message": "Daily job scheduled successfully"}

        except Exception as e:
            logger.error(f"Failed to schedule daily job for user {user_id}", error=str(e))
            return {"success": False, "message": str(e)}
```
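`save_daily_analysis` stores `generated_at` as an ISO-8601 string and deletes today's and yesterday's records by comparing against `datetime.combine(...).isoformat()` boundaries; the `$gte`/`$lte` filters work because same-format ISO timestamps sort lexicographically. A small sketch of those boundary values (the fixed date is illustrative):

```python
from datetime import datetime, timedelta

today = datetime(2024, 1, 15).date()  # illustrative fixed date
start_of_day = datetime.combine(today, datetime.min.time())
end_of_day = datetime.combine(today, datetime.max.time())

print(start_of_day.isoformat())  # 2024-01-15T00:00:00
print(end_of_day.isoformat())    # 2024-01-15T23:59:59.999999

# Yesterday's window is built the same way, one day earlier.
yesterday = today - timedelta(days=1)
print(yesterday.isoformat())     # 2024-01-14
```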
```python
    # ====================
    # CACHING METHODS (MongoDB TTL)
    # ====================

    async def cache_set(self, key: str, value: str, ttl: int = 3600) -> bool:
        """Set a cache value using MongoDB TTL index"""
        if self.db is None:
            return False

        try:
            expires_at = datetime.utcnow() + timedelta(seconds=ttl)

            await self.db.cache.update_one(
                {"key": key},
                {
                    "$set": {
                        "value": value,
                        "expires_at": expires_at,
                        "created_at": datetime.utcnow()
                    }
                },
                upsert=True
            )
            return True

        except Exception as e:
            logger.debug(f"Failed to set cache: {str(e)}")
            return False

    async def cache_get(self, key: str) -> Optional[str]:
        """Get a cache value from MongoDB"""
        if self.db is None:
            return None

        try:
            result = await self.db.cache.find_one({"key": key})

            # Check if expired (extra safety, TTL index should handle this)
            if result and result.get("expires_at"):
                if result["expires_at"] < datetime.utcnow():
                    # Already expired, delete it
                    await self.db.cache.delete_one({"key": key})
                    return None

            return result.get("value") if result else None

        except Exception as e:
            logger.debug(f"Failed to get cache: {str(e)}")
            return None

    async def cache_delete(self, key: str) -> bool:
        """Delete a cache value from MongoDB"""
        if self.db is None:
            return False

        try:
            result = await self.db.cache.delete_one({"key": key})
            return result.deleted_count > 0

        except Exception as e:
            logger.debug(f"Failed to delete cache: {str(e)}")
            return False
```
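These cache methods pair with the `cache_ttl` index created in `create_indexes()` (`expireAfterSeconds=0` on `expires_at`): MongoDB's TTL monitor removes entries shortly after `expires_at` passes, and `cache_get` double-checks expiry in the meantime. A usage sketch with placeholder key and value:

```python
from arionxiv.services.unified_database_service import UnifiedDatabaseService


async def cache_example(service: UnifiedDatabaseService) -> None:
    # Placeholder key/value; ttl is in seconds.
    await service.cache_set("search:cs.CL:page1", '{"results": []}', ttl=600)

    cached = await service.cache_get("search:cs.CL:page1")
    print("hit:" if cached is not None else "miss (expired or never written)", cached)

    await service.cache_delete("search:cs.CL:page1")
```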
```python
    # ====================
    # SYNCHRONOUS WRAPPERS (from sync_db_wrapper.py)
    # ====================

    def _run_async_in_thread(self, coro):
        """Run async coroutine in a separate thread with its own event loop"""
        def run_in_thread():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                result = loop.run_until_complete(coro)
                return result
            finally:
                loop.close()

        future = self.executor.submit(run_in_thread)
        return future.result(timeout=30)
```
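`_run_async_in_thread` is the bridge used by the `*_sync` methods that follow: it runs the coroutine on a fresh event loop inside the service's `ThreadPoolExecutor` worker and blocks up to 30 seconds for the result, so synchronous CLI code can call the async Motor client without an event loop of its own. A hedged sketch of a synchronous caller, assuming `MONGODB_URI` points at a reachable MongoDB:

```python
from arionxiv.services.unified_database_service import UnifiedDatabaseService

service = UnifiedDatabaseService()  # the *_sync helpers need no await

# Runs the underlying async Motor calls on a worker thread with its own event
# loop; if the service is not connected it opens a temporary client from db_url.
# The argument values are illustrative placeholders.
ok = service.save_paper_text_sync(
    paper_id="2401.00001",
    title="An Illustrative Paper",
    content="Extracted full text would go here.",
)
print("saved:", ok)
```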
def save_paper_text_sync(self, paper_id: str, title: str, content: str) -> bool:
|
|
1251
|
+
"""Synchronous wrapper for save_paper_text"""
|
|
1252
|
+
async def save_operation():
|
|
1253
|
+
if self.db is None:
|
|
1254
|
+
# Create separate client for sync operations
|
|
1255
|
+
client = AsyncIOMotorClient(self.db_url)
|
|
1256
|
+
db = client[self.database_name]
|
|
1257
|
+
else:
|
|
1258
|
+
db = self.db
|
|
1259
|
+
|
|
1260
|
+
try:
|
|
1261
|
+
expires_at = datetime.utcnow() + timedelta(hours=2)
|
|
1262
|
+
|
|
1263
|
+
paper_doc = {
|
|
1264
|
+
"paper_id": paper_id,
|
|
1265
|
+
"title": title,
|
|
1266
|
+
"content": content,
|
|
1267
|
+
"created_at": datetime.utcnow(),
|
|
1268
|
+
"expires_at": expires_at
|
|
1269
|
+
}
|
|
1270
|
+
|
|
1271
|
+
await db.paper_texts.replace_one(
|
|
1272
|
+
{"paper_id": paper_id},
|
|
1273
|
+
paper_doc,
|
|
1274
|
+
upsert=True
|
|
1275
|
+
)
|
|
1276
|
+
|
|
1277
|
+
if self.db is None:
|
|
1278
|
+
client.close()
|
|
1279
|
+
return True
|
|
1280
|
+
|
|
1281
|
+
except Exception as e:
|
|
1282
|
+
logger.error("Failed to save paper text (sync)", error=str(e))
|
|
1283
|
+
if self.db is None:
|
|
1284
|
+
client.close()
|
|
1285
|
+
return False
|
|
1286
|
+
|
|
1287
|
+
try:
|
|
1288
|
+
return self._run_async_in_thread(save_operation())
|
|
1289
|
+
except Exception as e:
|
|
1290
|
+
logger.error("Failed to save paper text (sync)", error=str(e))
|
|
1291
|
+
return False
|
|
1292
|
+
|
|
1293
|
+
def save_paper_embeddings_sync(self, paper_id: str, embeddings_data: list) -> bool:
|
|
1294
|
+
"""Synchronous wrapper for save_paper_embeddings"""
|
|
1295
|
+
async def save_operation():
|
|
1296
|
+
if self.db is None:
|
|
1297
|
+
client = AsyncIOMotorClient(self.db_url)
|
|
1298
|
+
db = client[self.database_name]
|
|
1299
|
+
else:
|
|
1300
|
+
db = self.db
|
|
1301
|
+
|
|
1302
|
+
try:
|
|
1303
|
+
expires_at = datetime.utcnow() + timedelta(hours=2)
|
|
1304
|
+
|
|
1305
|
+
# Delete existing embeddings for this paper
|
|
1306
|
+
await db.paper_embeddings.delete_many({"paper_id": paper_id})
|
|
1307
|
+
|
|
1308
|
+
# Prepare embedding documents
|
|
1309
|
+
embedding_docs = []
|
|
1310
|
+
for i, embedding_data in enumerate(embeddings_data):
|
|
1311
|
+
doc = {
|
|
1312
|
+
"paper_id": paper_id,
|
|
1313
|
+
"chunk_id": f"{paper_id}_chunk_{i}",
|
|
1314
|
+
"text": embedding_data["text"],
|
|
1315
|
+
"embedding": embedding_data["embedding"],
|
|
1316
|
+
"metadata": embedding_data.get("metadata", {}),
|
|
1317
|
+
"created_at": datetime.utcnow(),
|
|
1318
|
+
"expires_at": expires_at
|
|
1319
|
+
}
|
|
1320
|
+
embedding_docs.append(doc)
|
|
1321
|
+
|
|
1322
|
+
if embedding_docs:
|
|
1323
|
+
await db.paper_embeddings.insert_many(embedding_docs)
|
|
1324
|
+
|
|
1325
|
+
if self.db is None:
|
|
1326
|
+
client.close()
|
|
1327
|
+
return True
|
|
1328
|
+
|
|
1329
|
+
except Exception as e:
|
|
1330
|
+
logger.error("Failed to save paper embeddings (sync)", error=str(e))
|
|
1331
|
+
if self.db is None:
|
|
1332
|
+
client.close()
|
|
1333
|
+
return False
|
|
1334
|
+
|
|
1335
|
+
try:
|
|
1336
|
+
return self._run_async_in_thread(save_operation())
|
|
1337
|
+
except Exception as e:
|
|
1338
|
+
logger.error("Failed to save paper embeddings (sync)", error=str(e))
|
|
1339
|
+
return False
|
|
1340
|
+
|
|
1341
|
+
+    def save_chat_session_sync(self, session_id: str, session_data: Dict[str, Any]) -> bool:
+        """Synchronous wrapper for save_chat_session"""
+        async def save_operation():
+            if self.db is None:
+                client = AsyncIOMotorClient(self.db_url)
+                db = client[self.database_name]
+            else:
+                db = self.db
+
+            try:
+                expires_at = datetime.utcnow() + timedelta(hours=2)
+
+                session_doc = {
+                    "session_id": session_id,
+                    "user_id": session_data.get("user_id", "default"),
+                    "conversation_history": session_data.get("conversation_history", []),
+                    "paper_id": session_data.get("paper_id"),
+                    "created_at": datetime.utcnow(),
+                    "expires_at": expires_at
+                }
+
+                await db.chat_sessions.replace_one(
+                    {"session_id": session_id},
+                    session_doc,
+                    upsert=True
+                )
+
+                if self.db is None:
+                    client.close()
+                return True
+
+            except Exception as e:
+                logger.error("Failed to save chat session (sync)", error=str(e))
+                if self.db is None:
+                    client.close()
+                return False
+
+        try:
+            return self._run_async_in_thread(save_operation())
+        except Exception as e:
+            logger.error("Failed to save chat session (sync)", error=str(e))
+            return False
+
+    def load_chat_session_sync(self, session_id: str) -> Optional[Dict[str, Any]]:
+        """Synchronous wrapper for load_chat_session"""
+        async def load_operation():
+            if self.db is None:
+                client = AsyncIOMotorClient(self.db_url)
+                db = client[self.database_name]
+            else:
+                db = self.db
+
+            try:
+                session = await db.chat_sessions.find_one({"session_id": session_id})
+
+                if session and session.get("expires_at", datetime.utcnow()) > datetime.utcnow():
+                    if self.db is None:
+                        client.close()
+                    return session
+                else:
+                    if session:
+                        await db.chat_sessions.delete_one({"session_id": session_id})
+                    if self.db is None:
+                        client.close()
+                    return None
+
+            except Exception as e:
+                logger.error("Failed to load chat session (sync)", error=str(e))
+                if self.db is None:
+                    client.close()
+                return None
+
+        try:
+            return self._run_async_in_thread(load_operation())
+        except Exception as e:
+            logger.error("Failed to load chat session (sync)", error=str(e))
+            return None
+
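Together, `save_chat_session_sync` and `load_chat_session_sync` give a simple round trip: the saved document is upserted by session id and is handed back only while its two-hour `expires_at` window is still open, after which loading deletes it and returns None. A sketch of a round trip, with illustrative ids and a message format that is an assumption (the service stores whatever list the caller puts in `conversation_history`):

    session_data = {
        "user_id": "demo-user",
        "paper_id": "2401.00001",
        "conversation_history": [
            {"role": "user", "content": "Summarize the abstract."},
            {"role": "assistant", "content": "The paper introduces ..."},
        ],
    }
    unified_database_service.save_chat_session_sync("session-123", session_data)
    restored = unified_database_service.load_chat_session_sync("session-123")  # None once the TTL has passed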
+    def clear_chat_session_sync(self, session_id: str) -> bool:
+        """Synchronous wrapper for clear_chat_session"""
+        async def clear_operation():
+            if self.db is None:
+                client = AsyncIOMotorClient(self.db_url)
+                db = client[self.database_name]
+            else:
+                db = self.db
+
+            try:
+                await db.chat_sessions.delete_one({"session_id": session_id})
+                if self.db is None:
+                    client.close()
+                return True
+
+            except Exception as e:
+                logger.error("Failed to clear chat session (sync)", error=str(e))
+                if self.db is None:
+                    client.close()
+                return False
+
+        try:
+            return self._run_async_in_thread(clear_operation())
+        except Exception as e:
+            logger.error("Failed to clear chat session (sync)", error=str(e))
+            return False
+
+    def get_paper_embeddings_sync(self, paper_id: str) -> List[Dict[str, Any]]:
+        """Synchronous wrapper for get_paper_embeddings"""
+        async def get_operation():
+            if self.db is None:
+                client = AsyncIOMotorClient(self.db_url)
+                db = client[self.database_name]
+            else:
+                db = self.db
+
+            try:
+                current_time = datetime.utcnow()
+                embeddings = []
+
+                async for doc in db.paper_embeddings.find({
+                    "paper_id": paper_id,
+                    "expires_at": {"$gt": current_time}
+                }):
+                    embeddings.append(doc)
+
+                if self.db is None:
+                    client.close()
+                return embeddings
+
+            except Exception as e:
+                logger.error("Failed to get paper embeddings (sync)", error=str(e))
+                if self.db is None:
+                    client.close()
+                return []
+
+        try:
+            return self._run_async_in_thread(get_operation())
+        except Exception as e:
+            logger.error("Failed to get paper embeddings (sync)", error=str(e))
+            return []
+
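The query above only returns chunks whose `expires_at` is still in the future, so callers get the same shape of document that `save_paper_embeddings_sync` wrote. How the retrieval side consumes them is not shown in this file; a minimal, purely illustrative sketch of ranking the cached chunks against a query vector (numpy and the cosine helper are assumptions, not part of this service):

    import numpy as np

    chunks = unified_database_service.get_paper_embeddings_sync("2401.00001")
    if chunks:
        query_vec = np.array([0.1, 0.2, 0.3])  # would come from the embedding model in practice

        def cosine(a, b):
            a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
            return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-9))

        best = max(chunks, key=lambda d: cosine(query_vec, d["embedding"]))
        print(best["chunk_id"], best["text"][:80])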
+    def save_user_paper_sync(self, user_id: str, paper_data: Dict[str, Any]) -> bool:
+        """Synchronous wrapper for save_user_paper"""
+        async def save_operation():
+            if self.db is None:
+                client = AsyncIOMotorClient(self.db_url)
+                db = client[self.database_name]
+            else:
+                db = self.db
+
+            try:
+                expires_at = datetime.utcnow() + timedelta(hours=24)
+
+                paper_doc = {
+                    "user_id": user_id,
+                    "paper_id": paper_data.get("paper_id"),
+                    "title": paper_data.get("title"),
+                    "authors": paper_data.get("authors", []),
+                    "categories": paper_data.get("categories", []),
+                    "abstract": paper_data.get("abstract"),
+                    "published": paper_data.get("published"),
+                    "url": paper_data.get("url"),
+                    "added_at": datetime.utcnow(),
+                    "expires_at": expires_at
+                }
+
+                await db.user_papers.replace_one(
+                    {"user_id": user_id, "paper_id": paper_data.get("paper_id")},
+                    paper_doc,
+                    upsert=True
+                )
+
+                if self.db is None:
+                    client.close()
+                return True
+
+            except Exception as e:
+                logger.error("Failed to save user paper (sync)", error=str(e))
+                if self.db is None:
+                    client.close()
+                return False
+
+        try:
+            return self._run_async_in_thread(save_operation())
+        except Exception as e:
+            logger.error("Failed to save user paper (sync)", error=str(e))
+            return False
+
+    def get_user_papers_sync(self, user_id: str, category: str = None) -> List[Dict[str, Any]]:
+        """Synchronous wrapper for get_user_papers"""
+        async def get_operation():
+            if self.db is None:
+                client = AsyncIOMotorClient(self.db_url)
+                db = client[self.database_name]
+            else:
+                db = self.db
+
+            try:
+                query = {"user_id": user_id, "expires_at": {"$gt": datetime.utcnow()}}
+
+                if category:
+                    query["categories"] = {"$in": [category]}
+
+                papers = []
+                async for paper in db.user_papers.find(query).sort("added_at", -1):
+                    papers.append(paper)
+
+                if self.db is None:
+                    client.close()
+                return papers
+
+            except Exception as e:
+                logger.error("Failed to get user papers (sync)", error=str(e))
+                if self.db is None:
+                    client.close()
+                return []
+
+        try:
+            return self._run_async_in_thread(get_operation())
+        except Exception as e:
+            logger.error("Failed to get user papers (sync)", error=str(e))
+            return []
+
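A user's library entry is keyed by (`user_id`, `paper_id`), expires after 24 hours, and can be listed newest-first with an optional category filter. An illustrative save-then-list sequence (all field values are made up):

    unified_database_service.save_user_paper_sync("demo-user", {
        "paper_id": "2401.00001",
        "title": "An Example Paper",
        "authors": ["A. Author"],
        "categories": ["cs.CL", "cs.LG"],
        "abstract": "...",
        "published": "2024-01-01",
        "url": "https://arxiv.org/abs/2401.00001",
    })
    nlp_papers = unified_database_service.get_user_papers_sync("demo-user", category="cs.CL")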
+    def get_user_paper_categories_sync(self, user_id: str) -> List[str]:
+        """Synchronous wrapper for get_user_paper_categories"""
+        async def get_operation():
+            if self.db is None:
+                client = AsyncIOMotorClient(self.db_url)
+                db = client[self.database_name]
+            else:
+                db = self.db
+
+            try:
+                pipeline = [
+                    {"$match": {"user_id": user_id, "expires_at": {"$gt": datetime.utcnow()}}},
+                    {"$unwind": "$categories"},
+                    {"$group": {"_id": "$categories"}},
+                    {"$sort": {"_id": 1}}
+                ]
+
+                categories = []
+                async for doc in db.user_papers.aggregate(pipeline):
+                    categories.append(doc["_id"])
+
+                if self.db is None:
+                    client.close()
+                return categories
+
+            except Exception as e:
+                logger.error("Failed to get user paper categories (sync)", error=str(e))
+                if self.db is None:
+                    client.close()
+                return []
+
+        try:
+            return self._run_async_in_thread(get_operation())
+        except Exception as e:
+            logger.error("Failed to get user paper categories (sync)", error=str(e))
+            return []
+
+
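For reference, the aggregation pipeline above deduplicates and alphabetically sorts the category tags of the user's unexpired papers. Assuming `papers` were the same documents fetched into memory, the result is roughly equivalent to:

    categories = sorted({c for paper in papers for c in paper.get("categories", [])})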
+# Global instances
+unified_database_service = UnifiedDatabaseService()
+
+# Backwards compatibility
+database_client = unified_database_service
+sync_db = unified_database_service
+
+# Export commonly used methods and aliases
+database_service = unified_database_service
+save_paper = unified_database_service.save_paper
+get_paper = unified_database_service.get_paper
+search_papers = unified_database_service.search_papers
+create_user = unified_database_service.create_user
+get_user = unified_database_service.get_user
+get_user_by_email = unified_database_service.get_user_by_email
+create_chat_session = unified_database_service.create_chat_session
+get_chat_session = unified_database_service.get_chat_session
+
+__all__ = [
+    'UnifiedDatabaseService',
+    'unified_database_service',
+    'database_service',
+    'database_client',
+    'sync_db',
+    'save_paper',
+    'get_paper',
+    'search_papers',
+    'create_user',
+    'get_user',
+    'get_user_by_email',
+    'create_chat_session',
+    'get_chat_session'
+]
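The module ends by exposing one shared service instance under several aliases (`database_client`, `sync_db`, `database_service`) plus bound-method shortcuts. A minimal sketch of how another module in the package might consume these exports; the user id is illustrative, and whether a given shortcut is sync or async depends on the method it is bound to earlier in the file:

    from arionxiv.services.unified_database_service import (
        unified_database_service,
        sync_db,  # same object, kept for backwards compatibility
    )

    papers = sync_db.get_user_papers_sync("demo-user")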