arionxiv 1.0.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arionxiv/__init__.py +40 -0
- arionxiv/__main__.py +10 -0
- arionxiv/arxiv_operations/__init__.py +0 -0
- arionxiv/arxiv_operations/client.py +225 -0
- arionxiv/arxiv_operations/fetcher.py +173 -0
- arionxiv/arxiv_operations/searcher.py +122 -0
- arionxiv/arxiv_operations/utils.py +293 -0
- arionxiv/cli/__init__.py +4 -0
- arionxiv/cli/commands/__init__.py +1 -0
- arionxiv/cli/commands/analyze.py +587 -0
- arionxiv/cli/commands/auth.py +365 -0
- arionxiv/cli/commands/chat.py +714 -0
- arionxiv/cli/commands/daily.py +482 -0
- arionxiv/cli/commands/fetch.py +217 -0
- arionxiv/cli/commands/library.py +295 -0
- arionxiv/cli/commands/preferences.py +426 -0
- arionxiv/cli/commands/search.py +254 -0
- arionxiv/cli/commands/settings_unified.py +1407 -0
- arionxiv/cli/commands/trending.py +41 -0
- arionxiv/cli/commands/welcome.py +168 -0
- arionxiv/cli/main.py +407 -0
- arionxiv/cli/ui/__init__.py +1 -0
- arionxiv/cli/ui/global_theme_manager.py +173 -0
- arionxiv/cli/ui/logo.py +127 -0
- arionxiv/cli/ui/splash.py +89 -0
- arionxiv/cli/ui/theme.py +32 -0
- arionxiv/cli/ui/theme_system.py +391 -0
- arionxiv/cli/utils/__init__.py +54 -0
- arionxiv/cli/utils/animations.py +522 -0
- arionxiv/cli/utils/api_client.py +583 -0
- arionxiv/cli/utils/api_config.py +505 -0
- arionxiv/cli/utils/command_suggestions.py +147 -0
- arionxiv/cli/utils/db_config_manager.py +254 -0
- arionxiv/github_actions_runner.py +206 -0
- arionxiv/main.py +23 -0
- arionxiv/prompts/__init__.py +9 -0
- arionxiv/prompts/prompts.py +247 -0
- arionxiv/rag_techniques/__init__.py +8 -0
- arionxiv/rag_techniques/basic_rag.py +1531 -0
- arionxiv/scheduler_daemon.py +139 -0
- arionxiv/server.py +1000 -0
- arionxiv/server_main.py +24 -0
- arionxiv/services/__init__.py +73 -0
- arionxiv/services/llm_client.py +30 -0
- arionxiv/services/llm_inference/__init__.py +58 -0
- arionxiv/services/llm_inference/groq_client.py +469 -0
- arionxiv/services/llm_inference/llm_utils.py +250 -0
- arionxiv/services/llm_inference/openrouter_client.py +564 -0
- arionxiv/services/unified_analysis_service.py +872 -0
- arionxiv/services/unified_auth_service.py +457 -0
- arionxiv/services/unified_config_service.py +456 -0
- arionxiv/services/unified_daily_dose_service.py +823 -0
- arionxiv/services/unified_database_service.py +1633 -0
- arionxiv/services/unified_llm_service.py +366 -0
- arionxiv/services/unified_paper_service.py +604 -0
- arionxiv/services/unified_pdf_service.py +522 -0
- arionxiv/services/unified_prompt_service.py +344 -0
- arionxiv/services/unified_scheduler_service.py +589 -0
- arionxiv/services/unified_user_service.py +954 -0
- arionxiv/utils/__init__.py +51 -0
- arionxiv/utils/api_helpers.py +200 -0
- arionxiv/utils/file_cleanup.py +150 -0
- arionxiv/utils/ip_helper.py +96 -0
- arionxiv-1.0.32.dist-info/METADATA +336 -0
- arionxiv-1.0.32.dist-info/RECORD +69 -0
- arionxiv-1.0.32.dist-info/WHEEL +5 -0
- arionxiv-1.0.32.dist-info/entry_points.txt +4 -0
- arionxiv-1.0.32.dist-info/licenses/LICENSE +21 -0
- arionxiv-1.0.32.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1531 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Basic RAG Implementation for ArionXiv
|
|
3
|
+
Provides standard Retrieval-Augmented Generation with text chunking, embedding generation, and vector search
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import asyncio
|
|
7
|
+
from datetime import datetime, timedelta
|
|
8
|
+
from typing import List, Dict, Any, Optional, Union, Tuple
|
|
9
|
+
from abc import ABC, abstractmethod
|
|
10
|
+
import logging
|
|
11
|
+
from pymongo import IndexModel
|
|
12
|
+
import os
|
|
13
|
+
|
|
14
|
+
try:
|
|
15
|
+
import numpy as np
|
|
16
|
+
from sentence_transformers import SentenceTransformer
|
|
17
|
+
ML_DEPENDENCIES_AVAILABLE = True
|
|
18
|
+
except ImportError:
|
|
19
|
+
ML_DEPENDENCIES_AVAILABLE = False
|
|
20
|
+
np = None
|
|
21
|
+
SentenceTransformer = None
|
|
22
|
+
|
|
23
|
+
try:
|
|
24
|
+
from google import genai
|
|
25
|
+
GEMINI_AVAILABLE = True
|
|
26
|
+
except ImportError:
|
|
27
|
+
GEMINI_AVAILABLE = False
|
|
28
|
+
genai = None
|
|
29
|
+
|
|
30
|
+
from rich.console import Console
|
|
31
|
+
from rich.panel import Panel
|
|
32
|
+
from rich.prompt import Prompt
|
|
33
|
+
from rich.markdown import Markdown
|
|
34
|
+
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn
|
|
35
|
+
|
|
36
|
+
# Process-wide cache of loaded embedding models, keyed by model name.
# GraniteDoclingEmbeddingProvider reads and writes it so that every provider
# instance (and therefore every chat session) in the same Python process reuses
# one in-memory copy of a model instead of loading it again.
_GLOBAL_MODEL_CACHE: Dict[str, Any] = {}
|
|
39
|
+
|
|
40
|
+
# Import theme system for consistent styling
|
|
41
|
+
try:
    from ..cli.ui.theme import (
        create_themed_console, get_theme_colors, style_text,
        print_success, print_error, print_warning, print_info
    )
    from ..cli.utils.animations import left_to_right_reveal, stream_markdown_response
    from ..cli.utils.command_suggestions import show_command_suggestions
    THEME_AVAILABLE = True
except ImportError:
    # The CLI theme/animation modules could not be imported: define minimal
    # stand-ins under the same names so the rest of this module keeps working.
    # NOTE(review): print_success / print_error / print_warning / print_info
    # have no stand-ins here, so they remain undefined on this path.
    THEME_AVAILABLE = False

    def get_theme_colors(db_service=None):
        """Return a fixed colour palette (no theme lookup in fallback mode)."""
        palette = {
            'primary': 'blue',
            'secondary': 'cyan',
            'success': 'green',
            'warning': 'yellow',
            'error': 'red',
            'muted': 'dim',
        }
        return palette

    def style_text(text, style='primary', db_service=None):
        """Wrap *text* in Rich markup for palette entry *style* (white if unknown)."""
        tag = get_theme_colors().get(style, 'white')
        return f"[{tag}]{text}[/{tag}]"

    def create_themed_console(db_service=None):
        """Return a plain Rich console."""
        return Console()

    def left_to_right_reveal(console, text, style="", duration=1.0):
        """Print *text* at once; the reveal animation is not available here."""
        console.print(text)

    def stream_markdown_response(console, text, panel_title="", border_style=None, duration=3.0):
        """Render *text* as Markdown in a bordered panel, without streaming."""
        border = border_style if border_style else get_theme_colors().get('primary', 'blue')
        console.print(Panel(Markdown(text), title=panel_title, border_style=border))

    def show_command_suggestions(console, context="general", **kwargs):
        """Do nothing: command suggestions need the CLI utilities."""
        return None
|
|
67
|
+
|
|
68
|
+
# Import API config manager to check if Gemini key is available
|
|
69
|
+
try:
|
|
70
|
+
from ..cli.utils.api_config import api_config_manager
|
|
71
|
+
API_CONFIG_AVAILABLE = True
|
|
72
|
+
except ImportError:
|
|
73
|
+
API_CONFIG_AVAILABLE = False
|
|
74
|
+
api_config_manager = None
|
|
75
|
+
|
|
76
|
+
logger = logging.getLogger(__name__)
|
|
77
|
+
|
|
78
|
+
class EmbeddingProvider(ABC):
    """Interface that every embedding backend implements.

    A provider turns text into fixed-length float vectors and can report the
    vector length and a human-readable name for itself.
    """

    @abstractmethod
    async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector for each entry of *texts*."""

    @abstractmethod
    def get_dimension(self) -> int:
        """Return the length of the vectors this provider produces."""

    @abstractmethod
    def get_name(self) -> str:
        """Return a display name identifying this provider."""
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class GeminiEmbeddingProvider(EmbeddingProvider):
    """Google Gemini embedding provider using gemini-embedding-001 model (FREE!)

    Requests ``output_dimensionality=768`` (the model default is 3072) so the
    returned vectors are smaller and match ``get_dimension()``.
    """

    # Lower-cased substrings that identify a rate-limit / quota error message.
    _RATE_LIMIT_MARKERS = ('rate limit', 'quota', '429', 'resource exhausted', 'too many')

    def __init__(self, api_key: str = None, console: Console = None):
        """Create the Gemini client.

        Args:
            api_key: Explicit API key; otherwise GEMINI_API_KEY, then GOOGLE_API_KEY.
            console: Rich console to use; a new one is created when omitted.

        Raises:
            ImportError: google-genai is not installed.
            ValueError: no API key could be found.
        """
        if not GEMINI_AVAILABLE:
            raise ImportError("google-genai not installed. Install with: pip install google-genai")

        self.api_key = api_key or os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
        if not self.api_key:
            raise ValueError("Gemini API key not found. Set GEMINI_API_KEY or GOOGLE_API_KEY environment variable.")

        # Use new genai.Client() API
        self.client = genai.Client(api_key=self.api_key)
        self.model = "gemini-embedding-001"
        self.dimension = 768  # Reduced from the model's 3072 default; must be sent with each request
        self._console = console or Console()

        logger.info("Gemini embedding provider initialized with free API")

    async def _embed_one(self, text: str) -> List[float]:
        """Embed one text through the async client, requesting ``self.dimension`` values."""
        result = await self.client.aio.models.embed_content(
            model=self.model,
            contents=text,
            # Without this the API returns the model's full-size vectors, which
            # would not match self.dimension or the zero-vector fallbacks below.
            config={"output_dimensionality": self.dimension},
        )
        return list(result.embeddings[0].values)

    async def _embed_with_retry(self, text: str, max_retries: int) -> List[float]:
        """Embed *text*, retrying rate-limit errors with exponential backoff.

        Rate-limit errors wait 4 s, then 8 s, before the next attempt; after
        ``max_retries`` attempts, or on any other error, a zero vector is returned.
        NOTE(review): a zero vector has no direction, so cosine similarity against
        it is degenerate; such chunks are effectively unretrievable and may end up
        in the embedding cache.
        """
        retries = 0
        while True:
            try:
                embedding = await self._embed_one(text)
                await asyncio.sleep(0.1)  # small pause between requests to ease rate limits
                return embedding
            except Exception as e:
                error_str = str(e).lower()
                if not any(marker in error_str for marker in self._RATE_LIMIT_MARKERS):
                    logger.debug(f"Failed to embed text: {str(e)}")
                    return [0.0] * self.dimension
                retries += 1
                if retries >= max_retries:
                    logger.warning(f"Gemini rate limit persisted after {max_retries} attempts; using zero vector")
                    return [0.0] * self.dimension
                await asyncio.sleep((2 ** retries) * 2)

    async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get embeddings using Gemini API (FREE!) with rate limit handling.

        Texts are embedded one request at a time, in groups of 10 with a 0.5 s
        pause between groups. A text that cannot be embedded gets a zero vector
        of length ``self.dimension``, so the result always aligns with *texts*.

        Returns:
            One vector per input text, in input order.
        """
        batch_size = 10
        max_retries = 3
        try:
            all_embeddings: List[List[float]] = []
            for start in range(0, len(texts), batch_size):
                for text in texts[start:start + batch_size]:
                    all_embeddings.append(await self._embed_with_retry(text, max_retries))
                if start + batch_size < len(texts):
                    await asyncio.sleep(0.5)
            return all_embeddings
        except Exception as e:
            logger.error(f"Gemini embedding failed: {str(e)}")
            raise

    def get_dimension(self) -> int:
        return self.dimension

    def get_name(self) -> str:
        return "Google-Gemini-Embedding-001-FREE"
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
class HuggingFaceEmbeddingProvider(EmbeddingProvider):
    """HuggingFace embedding provider using sentence-transformers (fallback).

    The model is loaded lazily on first use and kept on this instance.
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        """
        Args:
            model_name: sentence-transformers model id to load on first use.

        Raises:
            ImportError: sentence-transformers / numpy are not installed.
        """
        if not ML_DEPENDENCIES_AVAILABLE:
            raise ImportError(
                "ML dependencies not installed. Install with: pip install sentence-transformers numpy"
            )
        self.model_name = model_name
        self.model = None
        self._dimension = None
        self._console = Console()

    def _load_model(self):
        """Load the model once, showing a transient spinner while it loads."""
        if self.model is not None:
            return

        logger.info(f"Loading HuggingFace model: {self.model_name}")
        colors = get_theme_colors()
        self._console.print(f"[{colors['muted']}]Loading fallback model: {self.model_name}[/{colors['muted']}]")

        # Ask HuggingFace not to draw its own progress bars over our spinner, then
        # restore whatever value was set before (also when loading fails).
        # NOTE(review): huggingface_hub may read this variable only at import
        # time, in which case setting it here has no effect — confirm.
        previous_flag = os.environ.get('HF_HUB_DISABLE_PROGRESS_BARS')
        os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
        try:
            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                console=self._console,
                transient=True,
            ) as progress:
                progress.add_task(
                    f"[{colors['primary']}]Loading model...[/{colors['primary']}]",
                    total=None,
                )
                self.model = SentenceTransformer(self.model_name)
                self._dimension = self.model.get_sentence_embedding_dimension()
        finally:
            if previous_flag is None:
                os.environ.pop('HF_HUB_DISABLE_PROGRESS_BARS', None)
            else:
                os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = previous_flag

        self._console.print(f"[{colors['primary']}][OK][/{colors['primary']}] Fallback model ready")

    async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Embed *texts* with the local model; encoding runs in the default executor.

        Raises:
            Exception: any load/encode failure, after logging it.
        """
        try:
            self._load_model()
            loop = asyncio.get_running_loop()
            embeddings = await loop.run_in_executor(None, self.model.encode, texts)
            return embeddings.tolist()
        except Exception as e:
            logger.error(f"HuggingFace embedding failed: {str(e)}")
            raise

    def get_dimension(self) -> int:
        """Return the model's vector length, loading the model if needed."""
        if self._dimension is None:
            self._load_model()
        return self._dimension

    def get_name(self) -> str:
        return f"HuggingFace-{self.model_name}"
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
class GraniteDoclingEmbeddingProvider(EmbeddingProvider):
    """
    IBM Granite embedding provider - small, fast, and runs locally

    Downloads the model on first use. Loaded models are kept in the
    module-level ``_GLOBAL_MODEL_CACHE`` (shared by all instances for the life
    of the process) and on disk in the HuggingFace hub cache.
    """

    # Default model - IBM Granite 30M English (small, ~120MB download)
    DEFAULT_MODEL = "ibm-granite/granite-embedding-30m-english"

    def __init__(self, model_name: str = None):
        """
        Args:
            model_name: HuggingFace model id; defaults to ``DEFAULT_MODEL``.

        Raises:
            ImportError: sentence-transformers / numpy are not installed.
        """
        if not ML_DEPENDENCIES_AVAILABLE:
            raise ImportError(
                "ML dependencies not installed. Install with: pip install sentence-transformers numpy"
            )
        self.model_name = model_name or self.DEFAULT_MODEL
        self._dimension = None
        self._console = Console()

    @property
    def model(self):
        """Get model from global cache or None if not loaded"""
        return _GLOBAL_MODEL_CACHE.get(self.model_name)

    @model.setter
    def model(self, value):
        """Store model in global cache (None is ignored)"""
        if value is not None:
            _GLOBAL_MODEL_CACHE[self.model_name] = value

    @staticmethod
    def _hub_cache_dir():
        """Return the HuggingFace hub cache directory.

        Follows huggingface_hub's lookup order: HF_HUB_CACHE, then the legacy
        HUGGINGFACE_HUB_CACHE, then ``$HF_HOME/hub``, where HF_HOME defaults to
        ``$XDG_CACHE_HOME/huggingface`` (``~/.cache/huggingface``).
        NOTE(review): sentence-transformers may honour SENTENCE_TRANSFORMERS_HOME
        instead; this lookup only decides whether to show the download spinner.
        """
        from pathlib import Path

        explicit = os.getenv("HF_HUB_CACHE") or os.getenv("HUGGINGFACE_HUB_CACHE")
        if explicit:
            return Path(explicit).expanduser()
        hf_home = os.getenv("HF_HOME")
        if hf_home:
            return Path(hf_home).expanduser() / "hub"
        xdg_cache = os.getenv("XDG_CACHE_HOME")
        base = Path(xdg_cache).expanduser() if xdg_cache else Path.home() / ".cache"
        return base / "huggingface" / "hub"

    def _load_model(self):
        """Load the model into ``_GLOBAL_MODEL_CACHE`` unless it is already there.

        A download notice and spinner are shown only when the model is not in
        the on-disk HuggingFace cache (i.e. it has to be downloaded).

        Raises:
            Exception: whatever SentenceTransformer raises, after logging it.
        """
        # Already loaded in this process - only (re)read the dimension.
        if self.model_name in _GLOBAL_MODEL_CACHE:
            self._dimension = _GLOBAL_MODEL_CACHE[self.model_name].get_sentence_embedding_dimension()
            return

        from contextlib import nullcontext

        colors = get_theme_colors()
        primary = colors['primary']
        logger.info(f"Loading embedding model: {self.model_name}")

        # Directory name huggingface_hub uses for a model repo: models--<org>--<name>.
        model_cache_name = f"models--{self.model_name.replace('/', '--')}"
        is_cached = (self._hub_cache_dir() / model_cache_name).exists()

        if not is_cached:
            # Each closing tag repeats the full opening style ("dim <color>"):
            # Rich raises MarkupError for a closing tag that matches no open tag.
            self._console.print(
                f"[dim {primary}]Downloading embedding model: {self.model_name}[/dim {primary}]"
            )
            self._console.print(
                f"[dim {primary}](First run downloads ~120MB, uses HuggingFace cache)[/dim {primary}]"
            )

        # Suppress HuggingFace's internal progress bars while loading, then
        # restore the caller's previous setting (also on failure).
        previous_flag = os.environ.get('HF_HUB_DISABLE_PROGRESS_BARS')
        os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
        try:
            if is_cached:
                # On-disk load is quick: no spinner.
                spinner = nullcontext()
            else:
                spinner = Progress(
                    SpinnerColumn(),
                    TextColumn("[progress.description]{task.description}"),
                    console=self._console,
                    transient=True,
                )
            with spinner as progress:
                if progress is not None:
                    progress.add_task(
                        f"[bold {primary}]Downloading and initializing embedding model...[/bold {primary}]",
                        total=None,
                    )
                # NOTE(review): trust_remote_code=True executes Python code shipped
                # with the model repository; model_name can come from configuration,
                # so it should only ever point at trusted repositories.
                loaded_model = SentenceTransformer(self.model_name, trust_remote_code=True)
            self._dimension = loaded_model.get_sentence_embedding_dimension()
            _GLOBAL_MODEL_CACHE[self.model_name] = loaded_model
            logger.info(f"Embedding model loaded successfully (dimension: {self._dimension})")
        except Exception as e:
            logger.error(f"Failed to load embedding model {self.model_name}: {str(e)}")
            raise
        finally:
            if previous_flag is None:
                os.environ.pop('HF_HUB_DISABLE_PROGRESS_BARS', None)
            else:
                os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = previous_flag

    async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Embed *texts* with the cached model; encoding runs in the default executor.

        Raises:
            Exception: any load/encode failure, after logging it.
        """
        try:
            self._load_model()
            model = _GLOBAL_MODEL_CACHE.get(self.model_name)
            loop = asyncio.get_running_loop()
            embeddings = await loop.run_in_executor(None, model.encode, texts)
            return embeddings.tolist()
        except Exception as e:
            logger.error(f"Embedding generation failed for {self.model_name}: {str(e)}")
            raise

    def get_dimension(self) -> int:
        """Return the model's vector length, loading the model if needed."""
        if self._dimension is None:
            self._load_model()
        return self._dimension

    def get_name(self) -> str:
        return f"Granite-{self.model_name.split('/')[-1]}"
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
class BasicRAG:
|
|
356
|
+
"""
|
|
357
|
+
Basic RAG (Retrieval-Augmented Generation) implementation
|
|
358
|
+
Handles text chunking, embedding generation, vector search, and context retrieval
|
|
359
|
+
"""
|
|
360
|
+
|
|
361
|
+
def __init__(self, database_service, config_service, llm_client, openrouter_client=None):
    """
    Store the collaborating services and read RAG / embedding settings.

    Args:
        database_service: Database service for storing/retrieving embeddings
        config_service: Provides get_rag_config() and get_embedding_config()
        llm_client: LLM client for generating responses (Groq - fallback)
        openrouter_client: Optional OpenRouter client for the primary LLM
    """
    self.db_service = database_service
    self.config_service = config_service
    self.llm_client = llm_client
    self.openrouter_client = openrouter_client

    # Embedding providers are built on first access, not here.
    self._embedding_providers_initialized = False
    self._embedding_providers = []
    self._current_embedding_provider = None

    # RAG_LLM_PROVIDER (lower-cased) wins; otherwise use OpenRouter when it
    # reports itself available, and Groq as the last resort.
    forced_provider = os.getenv("RAG_LLM_PROVIDER", "").lower()
    if forced_provider:
        chosen_provider = forced_provider
    elif openrouter_client and openrouter_client.is_available:
        chosen_provider = "openrouter"
    else:
        chosen_provider = "groq"
    self.llm_provider = chosen_provider

    rag_config = config_service.get_rag_config()
    embedding_config = config_service.get_embedding_config()

    # Retrieval settings
    self.vector_collection = rag_config["vector_collection"]
    self.chat_collection = rag_config["chat_collection"]
    self.chunk_size = rag_config["chunk_size"]
    self.chunk_overlap = rag_config["chunk_overlap"]
    self.top_k_results = rag_config["top_k_results"]
    self.ttl_hours = rag_config["ttl_hours"]

    # Embedding settings (the whole dict is kept for lazy provider setup)
    self.embedding_batch_size = embedding_config["batch_size"]
    self.embedding_dimension = embedding_config["dimension_default"]
    self._embedding_config = embedding_config

    # Per-session in-memory vector store: {chunk_id: {text, embedding, metadata}}
    self._session_embeddings: Dict[str, Dict[str, Any]] = {}
    self._current_session_id: Optional[str] = None

    # Session storage used when the database is unavailable
    self._in_memory_sessions: Dict[str, Dict[str, Any]] = {}

    self.console = Console()

    logger.info("BasicRAG initialized (embedding providers lazy-loaded)")
|
|
416
|
+
|
|
417
|
+
@property
def embedding_providers(self):
    """Return the embedding provider list, building it on first access."""
    if self._embedding_providers_initialized:
        return self._embedding_providers
    # Mark as initialized before running setup, so setup is attempted at most once.
    self._embedding_providers_initialized = True
    self._setup_embedding_providers(self._embedding_config)
    return self._embedding_providers
|
|
424
|
+
|
|
425
|
+
@property
def current_embedding_provider(self):
    """Return the active embedding provider, setting up providers on first use."""
    needs_setup = not self._embedding_providers_initialized
    if needs_setup:
        self._embedding_providers_initialized = True
        self._setup_embedding_providers(self._embedding_config)
    return self._current_embedding_provider
|
|
432
|
+
|
|
433
|
+
@current_embedding_provider.setter
def current_embedding_provider(self, value):
    """Replace the active embedding provider (does not trigger provider setup)."""
    self._current_embedding_provider = value
|
|
437
|
+
|
|
438
|
+
def _setup_embedding_providers(self, embedding_config):
    """
    Populate ``self._embedding_providers`` in order of preference and select the first.

    Order:
    1. Gemini - when enabled, an API key is configured, and ``primary_model``
       is "gemini" (any case) or empty.
    2. ``primary_model`` itself - when HuggingFace is enabled and it names a
       model other than "gemini".
    3. Granite (``fallback_1`` or ``GraniteDoclingEmbeddingProvider.DEFAULT_MODEL``)
       - when HuggingFace is enabled and neither of the above was created.
    4. ``fallback_1`` as an additional fallback, unless already registered.

    Local models are kept in memory for the life of the process and on disk in
    the HuggingFace cache. Providers that fail to initialize are logged and
    skipped; if none can be created, the list stays empty.

    NOTE(review): ``fallback_2`` is read from the config but not used here.
    """
    primary_model = embedding_config["primary_model"]
    fallback_1 = embedding_config["fallback_1"]
    fallback_2 = embedding_config["fallback_2"]  # currently unused (see NOTE above)
    enable_gemini = embedding_config["enable_gemini"]
    enable_huggingface = embedding_config["enable_huggingface"]

    # An empty primary model means "use the default", which is Gemini.
    wants_gemini_primary = (primary_model or "").lower() in ("gemini", "")

    # Check if Gemini API key is actually available
    if API_CONFIG_AVAILABLE and api_config_manager:
        gemini_key_available = api_config_manager.is_configured("gemini")
    else:
        # Fallback: check environment variables directly
        gemini_key_available = bool(
            os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
        )

    # Primary: Gemini (if enabled AND API key is available)
    if enable_gemini and gemini_key_available and wants_gemini_primary:
        try:
            gemini_provider = GeminiEmbeddingProvider()
            self._embedding_providers.append(gemini_provider)
            logger.info("Gemini embedding provider initialized as PRIMARY (FREE API)")
        except Exception as e:
            logger.warning(f"Gemini embedding provider failed to initialize: {str(e)}")
    elif enable_gemini and not gemini_key_available:
        logger.info("Gemini API key not configured - will use local Granite model as fallback")

    if enable_huggingface:
        # An explicitly configured non-Gemini primary model is tried before the
        # Granite default, so the configured choice is honoured.
        if not wants_gemini_primary:
            try:
                primary_provider = GraniteDoclingEmbeddingProvider(model_name=primary_model)
                self._embedding_providers.append(primary_provider)
                logger.info(f"Primary HuggingFace embedding provider initialized: {primary_model}")
            except Exception as e:
                logger.warning(f"Primary embedding provider failed to initialize: {str(e)}")

        # Nothing usable yet: Granite becomes the primary provider.
        if not self._embedding_providers:
            try:
                granite_model = fallback_1 or GraniteDoclingEmbeddingProvider.DEFAULT_MODEL
                granite_provider = GraniteDoclingEmbeddingProvider(model_name=granite_model)
                self._embedding_providers.append(granite_provider)
                logger.info(f"Granite embedding provider initialized as PRIMARY (local): {granite_model}")
            except Exception as e:
                logger.warning(f"Granite embedding provider failed to initialize: {str(e)}")

        # Register fallback_1 as a secondary provider unless it is already present.
        if fallback_1 and not any(
            isinstance(p, GraniteDoclingEmbeddingProvider) and p.model_name == fallback_1
            for p in self._embedding_providers
        ):
            try:
                fallback_1_provider = GraniteDoclingEmbeddingProvider(model_name=fallback_1)
                self._embedding_providers.append(fallback_1_provider)
                logger.info(f"Fallback embedding provider initialized: {fallback_1}")
            except Exception as e:
                logger.warning(f"Fallback embedding provider failed: {str(e)}")

    if self._embedding_providers:
        self._current_embedding_provider = self._embedding_providers[0]
        logger.info(f"Using embedding provider: {self._current_embedding_provider.get_name()}")
    else:
        # No providers available - this will be handled gracefully in chat
        logger.debug("No embedding providers available - chat will show user-friendly message")
|
|
517
|
+
|
|
518
|
+
def is_embedding_available(self) -> bool:
    """Return True when at least one embedding provider could be set up for chat."""
    # Reading the property performs the lazy provider setup on first use.
    self.embedding_providers
    return len(self._embedding_providers) != 0
|
|
523
|
+
|
|
524
|
+
def get_embedding_unavailable_message(self) -> str:
    """Build a user-facing explanation of why chat embeddings are unavailable.

    The wording depends on whether the local ML dependencies are installed and
    whether a Gemini API key is configured.
    """
    if API_CONFIG_AVAILABLE and api_config_manager:
        gemini_configured = api_config_manager.is_configured("gemini")
    else:
        gemini_configured = bool(os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY"))

    if ML_DEPENDENCIES_AVAILABLE:
        return (
            "Chat feature is temporarily unavailable.\n\n"
            "Please try again later or report at:\n"
            " https://github.com/Arion-IT/ArionXiv/issues"
        )

    if gemini_configured:
        return (
            "Chat feature encountered an issue.\n\n"
            "Please try again later or report at:\n"
            " https://github.com/Arion-IT/ArionXiv/issues"
        )

    # Neither local models nor a Gemini key: point the user at settings.
    return (
        "Chat feature is temporarily unavailable.\n\n"
        "To enable this feature, please configure your Gemini API key:\n"
        " arionxiv settings\n\n"
        "If you encounter persistent issues, please report at:\n"
        " https://github.com/Arion-IT/ArionXiv/issues"
    )
|
|
553
|
+
|
|
554
|
+
async def get_embeddings(self, texts: Union[str, List[str]]) -> List[List[float]]:
    """Embed *texts*, trying each configured provider in order until one succeeds.

    The provider that succeeds becomes the current provider.

    Args:
        texts: A single string or a list of strings.

    Returns:
        One embedding per input text, or [] for empty input.

    Raises:
        RuntimeError: no embedding provider is configured, or every provider failed.
    """
    if isinstance(texts, str):
        texts = [texts]

    if not texts:
        return []

    providers = self.embedding_providers
    if not providers:
        # Fail loudly instead of falling off the end of the loop and returning None.
        raise RuntimeError("No embedding providers available")

    last_error = None
    for provider in providers:
        try:
            embeddings = await provider.get_embeddings(texts)
        except Exception as e:
            logger.warning(f"Provider {provider.get_name()} failed: {str(e)}")
            last_error = e
            continue

        if provider != self.current_embedding_provider:
            self.current_embedding_provider = provider
            logger.info(f"Switched to embedding provider: {provider.get_name()}")

        return embeddings

    raise RuntimeError(f"All embedding providers failed. Last error: {str(last_error)}") from last_error
|
|
577
|
+
|
|
578
|
+
async def get_single_embedding(self, text: str) -> List[float]:
    """Embed one string; returns [] if the underlying call produced no vectors."""
    batch = await self.get_embeddings([text])
    if not batch:
        return []
    return batch[0]
|
|
582
|
+
|
|
583
|
+
def get_embedding_dimension(self) -> int:
    """Vector length of the active provider, or the configured default when none is active."""
    provider = self.current_embedding_provider
    return provider.get_dimension() if provider else self.embedding_dimension
|
|
588
|
+
|
|
589
|
+
def get_embedding_provider_name(self) -> str:
    """Display name of the active provider, or the string "None" when there is none."""
    provider = self.current_embedding_provider
    if not provider:
        return "None"
    return provider.get_name()
|
|
594
|
+
|
|
595
|
+
def ensure_embedding_model_loaded(self):
    """Load the active provider's model up front, before batch embedding starts.

    This keeps model download/load output from interleaving with the embedding
    progress bar. Failures are logged as warnings, not raised.
    """
    provider = self.current_embedding_provider
    if not provider:
        return
    try:
        # The local providers' get_dimension() loads the model when the dimension is unknown.
        provider.get_dimension()
    except Exception as e:
        logger.warning(f"Failed to pre-load embedding model: {e}")
|
|
607
|
+
|
|
608
|
+
def _chunk_text(self, text: str) -> List[str]:
|
|
609
|
+
"""Split text into overlapping chunks"""
|
|
610
|
+
if len(text) <= self.chunk_size:
|
|
611
|
+
return [text]
|
|
612
|
+
|
|
613
|
+
chunks = []
|
|
614
|
+
start = 0
|
|
615
|
+
|
|
616
|
+
while start < len(text):
|
|
617
|
+
end = start + self.chunk_size
|
|
618
|
+
chunk = text[start:end]
|
|
619
|
+
|
|
620
|
+
if end < len(text):
|
|
621
|
+
last_period = chunk.rfind('.')
|
|
622
|
+
if last_period > self.chunk_size * 0.7:
|
|
623
|
+
chunk = chunk[:last_period + 1]
|
|
624
|
+
end = start + last_period + 1
|
|
625
|
+
|
|
626
|
+
chunks.append(chunk.strip())
|
|
627
|
+
start = end - self.chunk_overlap
|
|
628
|
+
|
|
629
|
+
if start >= len(text):
|
|
630
|
+
break
|
|
631
|
+
|
|
632
|
+
return chunks
|
|
633
|
+
|
|
634
|
+
async def add_document_to_index(self, doc_id: str, text: str, metadata: Optional[Dict[str, Any]] = None) -> bool:
    """Add a document to the in-memory vector index for the current session.

    First checks whether embeddings for ``doc_id`` are cached (24-hour TTL).
    If cached, they are loaded directly. Otherwise the text is chunked and
    embedded, and the chunks are stored in memory and written back to the cache.

    Args:
        doc_id: Document identifier; chunk ids are built as ``{doc_id}_chunk_{i}``.
        text: Full document text, used only when nothing is cached.
        metadata: Optional metadata attached to every chunk.

    Returns:
        True on success. False when there is nothing to index, when the
        embedding call returns a different number of vectors than chunks, or
        when any other error occurs. Errors are logged, not raised.
    """
    try:
        # Check if embeddings are already cached in the database
        cached_embeddings = await self._get_cached_embeddings(doc_id)

        if cached_embeddings:
            await self._load_embeddings_from_cache(cached_embeddings)
            # Cloud-cache entries come back without metadata; backfill the
            # caller's metadata so metadata filters in search still match them.
            if metadata:
                for entry in self._session_embeddings.values():
                    if entry.get('doc_id') == doc_id and not entry.get('metadata'):
                        entry['metadata'] = metadata
            logger.info(f"Loaded {len(cached_embeddings)} cached embeddings for document {doc_id}")
            return True

        # No cache - compute embeddings
        chunks = self._chunk_text(text)
        if not chunks:
            logger.warning(f"No content to index for document {doc_id}")
            return False

        embeddings = await self.get_embeddings(chunks)
        # zip() would silently drop chunks on a count mismatch, and the partial
        # result would then be cached for 24 hours.
        if len(embeddings) != len(chunks):
            raise ValueError(f"Expected {len(chunks)} embeddings, got {len(embeddings)}")

        for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
            chunk_id = f"{doc_id}_chunk_{i}"
            self._session_embeddings[chunk_id] = {
                'doc_id': doc_id,
                'chunk_id': chunk_id,
                'text': chunk,
                'embedding': embedding,
                'metadata': metadata or {}
            }

        # Save to database cache for future use (24-hour TTL)
        await self._save_embeddings_to_cache(doc_id, chunks, embeddings, metadata)

        logger.info(f"Added {len(chunks)} chunks for document {doc_id} to in-memory index and cache")
        return True

    except Exception as e:
        logger.error(f"Failed to add document {doc_id} to index: {str(e)}")
        return False
|
|
674
|
+
|
|
675
|
+
async def add_document_to_index_with_progress(self, doc_id: str, text: str, metadata: Dict[str, Any] = None, console: Console = None) -> bool:
    """Add a document to the in-memory vector index, showing a progress bar.

    First checks whether embeddings for ``doc_id`` are cached (24-hour TTL).
    If cached, they are loaded directly. Otherwise the text is chunked and
    embedded in batches under a progress bar, and the chunks are stored in
    memory and written back to the cache.

    Args:
        doc_id: Document identifier; chunk ids are built as ``{doc_id}_chunk_{i}``.
        text: Full document text, used only when nothing is cached.
        metadata: Optional metadata attached to every chunk.
        console: Rich console for output; defaults to ``self.console``.

    Returns:
        True on success. False when there is nothing to index, when an
        embedding batch returns the wrong number of vectors, or when any other
        error occurs. Errors are logged, not raised.
    """
    try:
        colors = get_theme_colors()
        console = console or self.console

        # Check if embeddings are already cached in the database
        cached_embeddings = await self._get_cached_embeddings(doc_id)

        if cached_embeddings:
            # Load from cache - much faster!
            left_to_right_reveal(console, f"Loading cached embeddings ({len(cached_embeddings)} chunks)...", style=f"bold {colors['primary']}", duration=0.8)
            await self._load_embeddings_from_cache(cached_embeddings)

            # Cloud-cache entries come back without metadata; backfill the
            # caller's metadata so metadata filters in search still match them.
            if metadata:
                for entry in self._session_embeddings.values():
                    if entry.get('doc_id') == doc_id and not entry.get('metadata'):
                        entry['metadata'] = metadata

            # Note: We intentionally do NOT pre-load the embedding model here.
            # Query embeddings will use Gemini API if available (fast, no download needed).
            # The local Granite model will only be loaded lazily if Gemini fails.

            logger.info(f"Loaded {len(cached_embeddings)} cached embeddings for document {doc_id}")
            return True

        # No cache - need to compute embeddings
        chunks = self._chunk_text(text)
        total_chunks = len(chunks)

        if total_chunks == 0:
            return False

        # Show subtle hint for large papers
        if total_chunks > 20:
            console.print(f"[white]Processing [bold {colors['primary']}]{total_chunks} chunks [/bold {colors['primary']}](this may take a moment for large papers)...[/white]")

        # Load the model BEFORE the progress bar so download output does not
        # interfere with the embedding progress display.
        self.ensure_embedding_model_loaded()

        batch_size = 5
        all_embeddings = []

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(bar_width=50),
            TaskProgressColumn(),
            TextColumn("-"),
            TimeRemainingColumn(),
            console=console,
            transient=False
        ) as progress:
            task = progress.add_task(
                f"[bold {colors['primary']}]Computing embeddings...",
                total=total_chunks
            )

            for start in range(0, total_chunks, batch_size):
                batch = chunks[start:start + batch_size]
                batch_embeddings = await self.get_embeddings(batch)
                # A short or long batch would shift every later chunk onto the
                # wrong vector, silently corrupting the index.
                if len(batch_embeddings) != len(batch):
                    raise ValueError(
                        f"Embedding batch at chunk {start} returned "
                        f"{len(batch_embeddings)} vectors for {len(batch)} chunks"
                    )
                all_embeddings.extend(batch_embeddings)
                progress.update(task, advance=len(batch))

        # Store embeddings in memory
        for i, (chunk, embedding) in enumerate(zip(chunks, all_embeddings)):
            chunk_id = f"{doc_id}_chunk_{i}"
            self._session_embeddings[chunk_id] = {
                'doc_id': doc_id,
                'chunk_id': chunk_id,
                'text': chunk,
                'embedding': embedding,
                'metadata': metadata or {}
            }

        # Save to database cache for future use (24-hour TTL)
        await self._save_embeddings_to_cache(doc_id, chunks, all_embeddings, metadata)

        logger.info(f"Added {total_chunks} chunks for document {doc_id} to in-memory index and cache")
        return True

    except Exception as e:
        logger.error(f"Failed to add document {doc_id} to index: {str(e)}")
        return False
|
|
764
|
+
|
|
765
|
+
def clear_session_embeddings(self):
    """Drop every in-memory chunk embedding and reset the current session id.

    Intended to run when a chat session ends; logs how many entries were removed.
    """
    removed = len(self._session_embeddings)
    self._session_embeddings.clear()
    self._current_session_id = None
    logger.info(f"Cleared {removed} embeddings from memory")
|
|
771
|
+
|
|
772
|
+
async def _get_cached_embeddings(self, doc_id: str) -> Optional[List[Dict[str, Any]]]:
    """Look up cached embeddings for a document (cloud API first, then local DB).

    Args:
        doc_id: Document whose chunk embeddings to look up.

    Returns:
        A list of chunk dicts with 'chunk_id', 'doc_id', 'text', 'embedding'
        and 'expires_at' keys, or None when nothing unexpired is cached or the
        lookup fails. Cloud results also carry a 'chunk_text' alias of 'text'.
    """
    try:
        # First, try the API (cloud cache - accessible across devices)
        try:
            from ..cli.utils.api_client import api_client
            api_result = await api_client.get_embeddings(doc_id)
            if api_result.get("success"):
                embeddings = api_result.get("embeddings", [])
                chunks = api_result.get("chunks", [])
                batches = api_result.get("batches", 1)

                if embeddings and chunks:
                    logger.info(f"Found {len(embeddings)} cached embeddings from cloud ({batches} batches) for {doc_id}")

                    expires_at = datetime.utcnow() + timedelta(hours=24)
                    # Use the same 'text' key as local DB records so
                    # _load_embeddings_from_cache can read the chunk text;
                    # 'chunk_text' is kept as an alias for older readers.
                    return [
                        {
                            'chunk_id': f"{doc_id}_chunk_{i}",
                            'doc_id': doc_id,
                            'text': chunk,
                            'chunk_text': chunk,
                            'embedding': embedding,
                            'expires_at': expires_at
                        }
                        for i, (embedding, chunk) in enumerate(zip(embeddings, chunks))
                    ]
        except Exception as api_err:
            logger.debug(f"Cloud cache not available, trying local: {api_err}")

        # Fall back to local database cache (unexpired entries only)
        cached = await self.db_service.find_many(
            self.vector_collection,
            {
                'doc_id': doc_id,
                'expires_at': {'$gt': datetime.utcnow()}
            },
            limit=10000  # High limit to get all chunks for large papers
        )

        if cached:
            logger.info(f"Found {len(cached)} cached embeddings from local DB for {doc_id}")
            return cached
        return None

    except Exception as e:
        logger.warning(f"Failed to check cached embeddings: {str(e)}")
        return None
|
|
819
|
+
|
|
820
|
+
async def _save_embeddings_to_cache(self, doc_id: str, chunks: List[str], embeddings: List[List[float]], metadata: Dict[str, Any] = None):
    """Persist chunk embeddings to the cloud cache and the local DB (24-hour TTL).

    The cloud save is best effort; its failure is only logged. The local copy
    is always written, replacing any existing entries for ``doc_id``. No
    exception escapes: local-save failures are logged as warnings.
    """
    try:
        # Cloud copy first; if it fails we simply rely on the local copy.
        try:
            from ..cli.utils.api_client import api_client
            outcome = await api_client.save_embeddings(doc_id, embeddings, chunks)
            if not outcome.get("success"):
                reason = outcome.get("message", "Unknown error")
                logger.warning(f"Cloud cache save failed for {doc_id}: {reason}")
            else:
                summary = outcome.get("message", "")
                logger.info(f"✓ Saved {len(embeddings)} embeddings to cloud cache for {doc_id}: {summary}")
        except Exception as api_err:
            # Expected when offline or the API is unavailable
            logger.debug(f"Using local cache only: {api_err}")

        expiry = datetime.utcnow() + timedelta(hours=24)

        # Replace, not append: drop any previous entries for this document.
        await self.db_service.delete_many(
            self.vector_collection,
            {'doc_id': doc_id}
        )

        rows = [
            {
                'doc_id': doc_id,
                'chunk_id': f"{doc_id}_chunk_{idx}",
                'text': piece,
                'embedding': vector,
                'metadata': metadata or {},
                'created_at': datetime.utcnow(),
                'expires_at': expiry
            }
            for idx, (piece, vector) in enumerate(zip(chunks, embeddings))
        ]

        if rows:
            await self.db_service.insert_many(self.vector_collection, rows)
            logger.info(f"Saved {len(rows)} embeddings to local cache for document {doc_id} (expires in 24h)")

    except Exception as e:
        logger.warning(f"Failed to save embeddings to local cache: {str(e)}")
|
|
868
|
+
|
|
869
|
+
async def _load_embeddings_from_cache(self, cached_embeddings: List[Dict[str, Any]], cached_chunks: List[str] = None):
    """Load cached embeddings into session memory.

    Args:
        cached_embeddings: Either a list of raw embedding vectors (from the API),
            or a list of dicts with 'embedding', 'text' (or 'chunk_text'),
            'chunk_id', 'doc_id' and optional 'metadata' keys (local DB / cloud lookup).
        cached_chunks: Text chunks parallel to ``cached_embeddings``; only used
            when the embeddings are raw vectors.
    """
    entries = [] if cached_embeddings is None else cached_embeddings
    first = entries[0] if len(entries) > 0 else None

    # API format: parallel lists of vectors (list/tuple/array) and text chunks
    if cached_chunks and first is not None and not isinstance(first, dict):
        loaded = 0
        for i, (embedding, chunk) in enumerate(zip(entries, cached_chunks)):
            chunk_id = f"cached_chunk_{i}"
            self._session_embeddings[chunk_id] = {
                'doc_id': 'cached',
                'chunk_id': chunk_id,
                'text': chunk,
                'embedding': embedding,
                # Raw vectors are passed in by the paper chat sessions. Tag them
                # so the {'metadata.type': 'paper'} filter used when answering
                # matches; with empty metadata every chunk would be filtered out.
                'metadata': {'type': 'paper'}
            }
            loaded += 1
        logger.info(f"Loaded {loaded} embeddings from API cache to session memory")
        return

    # Dict format: one record per chunk
    for i, doc in enumerate(entries):
        doc_id = doc.get('doc_id')
        # Fall back to a positional id so records lacking 'chunk_id' don't
        # overwrite each other under a single None key.
        chunk_id = doc.get('chunk_id') or f"{doc_id}_chunk_{i}"
        self._session_embeddings[chunk_id] = {
            'doc_id': doc_id,
            'chunk_id': chunk_id,
            # Cloud lookups have used 'chunk_text'; accept either key.
            'text': doc.get('text', doc.get('chunk_text')),
            'embedding': doc.get('embedding'),
            'metadata': doc.get('metadata') or {}
        }
    logger.info(f"Loaded {len(entries)} embeddings from local cache to session memory")
|
|
902
|
+
|
|
903
|
+
async def search_similar_documents(self, query: str, filters: Dict[str, Any] = None) -> List[Dict[str, Any]]:
    """Rank in-memory session chunks by cosine similarity to ``query``.

    Args:
        query: Text to embed and compare against each stored chunk.
        filters: Optional equality filters. Dotted keys such as
            'metadata.type' address nested dict fields.

    Returns:
        Up to ``self.top_k_results`` chunk dicts (doc_id, chunk_id, text,
        metadata, score), best score first. Returns [] on any error.
    """
    try:
        query_vector = await self.get_single_embedding(query)

        def passes_filters(entry):
            if not filters:
                return True
            for dotted_key, expected in filters.items():
                value = entry
                for part in dotted_key.split('.'):
                    value = value.get(part, {}) if isinstance(value, dict) else None
                if value != expected:
                    return False
            return True

        ranked = []
        for entry in self._session_embeddings.values():
            if not passes_filters(entry):
                continue
            vector = entry.get('embedding', [])
            if not vector:
                continue
            similarity = await self.compute_similarity(query_vector, vector)
            ranked.append({
                'doc_id': entry.get('doc_id'),
                'chunk_id': entry.get('chunk_id'),
                'text': entry.get('text'),
                'metadata': entry.get('metadata', {}),
                'score': similarity
            })

        ranked.sort(key=lambda item: item['score'], reverse=True)
        return ranked[:self.top_k_results]

    except Exception as e:
        logger.error(f"Vector search failed: {str(e)}")
        return []
|
|
944
|
+
|
|
945
|
+
async def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
    """Return the cosine similarity of two embedding vectors.

    Returns 0.0 when the ML dependencies are unavailable, when either vector
    has zero norm, or when the computation raises (the error is logged).
    """
    try:
        if not ML_DEPENDENCIES_AVAILABLE:
            return 0.0

        a, b = np.array(embedding1), np.array(embedding2)
        numerator = np.dot(a, b)
        norm_a, norm_b = np.linalg.norm(a), np.linalg.norm(b)

        # Cosine similarity is undefined for a zero vector
        if norm_a == 0 or norm_b == 0:
            return 0.0

        return float(numerator / (norm_a * norm_b))

    except Exception as e:
        logger.error(f"Similarity computation failed: {str(e)}")
        return 0.0
|
|
967
|
+
|
|
968
|
+
async def start_chat_session(self, papers: List[Dict[str, Any]], user_id: str = "default"):
    """Start interactive chat session with a single paper (v1)

    Embeddings are stored in memory during the session and cleared when done.
    Chat history is persisted to DB with 24-hour TTL for resumption.

    Args:
        papers: Candidate papers; only ``papers[0]`` is used in v1. The paper
            needs an 'arxiv_id' or 'id' and may carry '_cached_embeddings' /
            '_cached_chunks' to skip re-embedding.
        user_id: Owner of the session; also used as the session-id prefix.

    Errors are reported on the console rather than raised. Session embeddings
    are cleared on every exit path once indexing has started.
    """
    try:
        if not papers:
            colors = get_theme_colors()
            self.console.print(f"[{colors['error']}]No papers provided for chat session[/{colors['error']}]")
            return

        # V1: Limit to single paper
        paper = papers[0]
        paper_id = paper.get('arxiv_id') or paper.get('id')

        if not paper_id:
            colors = get_theme_colors()
            self.console.print(f"[{colors['error']}]Paper has no ID[/{colors['error']}]")
            return

        colors = get_theme_colors()

        # Check if cached embeddings are available or if we can generate new ones
        cached_embeddings = paper.get('_cached_embeddings')
        cached_chunks = paper.get('_cached_chunks')

        # If no cached embeddings, check if embedding providers are available
        if not cached_embeddings and not self.is_embedding_available():
            # Show graceful error message
            self.console.print(Panel(
                f"[{colors['warning']}]{self.get_embedding_unavailable_message()}[/{colors['warning']}]",
                title=f"[bold {colors['warning']}]Feature Unavailable[/bold {colors['warning']}]",
                border_style=f"bold {colors['warning']}"
            ))
            return

        # Clear any previous session embeddings
        self.clear_session_embeddings()

        # Check if cached embeddings were passed (already fetched from API/DB)
        if cached_embeddings:
            # Load cached embeddings directly into session memory
            await self._load_embeddings_from_cache(cached_embeddings, cached_chunks)
            logger.info(f"Loaded {len(cached_embeddings)} pre-cached embeddings for paper {paper_id}")
        else:
            # Generate embeddings and store in memory - with progress bar.
            # NOTE(review): if the extracted text is empty, nothing is indexed
            # and the chat starts with zero chunks and no warning.
            paper_text = self._extract_paper_text(paper)
            if paper_text:
                success = await self.add_document_to_index_with_progress(
                    paper_id,
                    paper_text,
                    {'type': 'paper', 'title': paper.get('title', '')},
                    console=self.console
                )
                if not success:
                    # Embedding failed - show graceful message
                    self.console.print(Panel(
                        f"[{colors['warning']}]{self.get_embedding_unavailable_message()}[/{colors['warning']}]",
                        title=f"[bold {colors['warning']}]Feature Unavailable[/bold {colors['warning']}]",
                        border_style=f"bold {colors['warning']}"
                    ))
                    return

        # Create unique session ID: user, paper and a short random suffix
        import uuid
        session_id = f"{user_id}_{paper_id}_{uuid.uuid4().hex[:8]}"
        self._current_session_id = session_id

        # Create session document with 24-hour TTL
        # Format authors list for display
        authors = paper.get('authors', [])
        if isinstance(authors, list):
            paper_authors = ', '.join(authors[:5])  # Limit to first 5 authors
            if len(authors) > 5:
                paper_authors += f' et al. ({len(authors)} authors)'
        else:
            paper_authors = str(authors) if authors else 'Unknown'

        session_doc = {
            'session_id': session_id,
            'paper_id': paper_id,  # Single paper in v1
            'paper_title': paper.get('title', ''),
            'paper_authors': paper_authors,
            # First 10 characters of 'published' (presumably an ISO date string - confirm)
            'paper_published': paper.get('published', '')[:10] if paper.get('published') else 'Unknown',
            'user_id': user_id,
            'created_at': datetime.utcnow(),
            'last_activity': datetime.utcnow(),
            'expires_at': datetime.utcnow() + timedelta(hours=24),  # 24-hour TTL
            'messages': []
        }

        # Store in-memory as fallback (always works); _chat_with_session reads
        # from here when the DB lookup finds nothing.
        self._in_memory_sessions[session_id] = session_doc

        # Try to persist to Vercel API first (cloud storage)
        session_saved = False
        api_error = None
        try:
            from ..cli.utils.api_client import api_client
            api_result = await api_client.create_chat_session(
                paper_id=paper_id,
                title=paper.get('title', paper_id)
            )
            if api_result.get("success"):
                # Update in-memory session with API session_id for consistency
                api_session_id = api_result.get('session_id')
                # Store the API session ID for later updates
                self._in_memory_sessions[session_id]['api_session_id'] = api_session_id
                logger.info(f"Chat session saved to cloud: {api_session_id}")
                session_saved = True
            else:
                api_error = f"API failure: {api_result}"
                logger.warning(api_error)
        except Exception as api_err:
            api_error = f"API error: {api_err}"
            logger.warning(f"Session not saved to API: {api_err}")

        # Also try local database as backup (regardless of API success)
        try:
            await self.db_service.insert_one(self.chat_collection, session_doc)
            logger.info(f"Chat session saved to local DB: {session_id}")
            session_saved = True
        except Exception as db_err:
            logger.debug(f"Session not saved to local database: {db_err}")

        # Neither cloud nor local persistence worked: the session will not
        # survive this process.
        if not session_saved:
            logger.warning(f"Chat session only stored in-memory: {session_id}")
            if api_error:
                logger.warning(f"API save failed: {api_error}")

        self.console.print(Panel(
            f"[bold {colors['primary']}]Chat Session Started[/bold {colors['primary']}]\n"
            f"Paper: [bold {colors['primary']}] {paper.get('title', paper_id)}[/bold {colors['primary']}]\n"
            f"Chunks indexed: [bold {colors['primary']}] {len(self._session_embeddings)}[/bold {colors['primary']}]\n"
            f"Type [bold {colors['primary']}]'quit'[/bold {colors['primary']}] or [bold {colors['primary']}]'exit'[/bold {colors['primary']}] to end the chat.",
            title=f"[bold]ArionXiv Paper Chat[/bold]",
            border_style=f"bold {colors['primary']}"
        ))

        try:
            await self._run_chat_loop(session_id)
        finally:
            # Always clean up embeddings when session ends
            self.clear_session_embeddings()

    except Exception as e:
        logger.error(f"Chat session failed: {str(e)}")
        colors = get_theme_colors()
        self.console.print(f"[{colors['error']}]Chat session failed: {str(e)}[/{colors['error']}]")
        # Clean up on error too
        self.clear_session_embeddings()
|
|
1120
|
+
|
|
1121
|
+
async def continue_chat_session(self, session: Dict[str, Any], paper_info: Dict[str, Any]):
    """Continue an existing chat session.

    Steps:
        1. Reload the paper embeddings, from ``paper_info['_cached_embeddings']``
           when present or otherwise by re-indexing the paper text.
        2. Restore the session's paper metadata.
        3. Try to extend the session TTL by 24 hours.
        4. Show a short recap of recent messages.
        5. Run the chat loop.

    Session embeddings are cleared when the loop ends or on error.

    Args:
        session: Stored session document. Must contain 'session_id'; may also
            hold 'paper_id', 'paper_title', 'paper_authors', 'paper_published'
            and 'messages'.
        paper_info: Paper details. May hold '_cached_embeddings' /
            '_cached_chunks' and the fields used to rebuild the paper text.
    """
    try:
        colors = get_theme_colors()
        session_id = session.get('session_id')
        paper_title = session.get('paper_title', paper_info.get('title', 'Unknown Paper'))
        messages = session.get('messages', [])

        if not session_id:
            self.console.print(f"[{colors['error']}]Invalid session: no session_id[/{colors['error']}]")
            return

        # Check if cached embeddings are available or if we can generate new ones
        cached_embeddings = paper_info.get('_cached_embeddings')
        cached_chunks = paper_info.get('_cached_chunks')

        # If no cached embeddings, check if embedding providers are available
        if not cached_embeddings and not self.is_embedding_available():
            self.console.print(Panel(
                f"[{colors['warning']}]{self.get_embedding_unavailable_message()}[/{colors['warning']}]",
                title=f"[bold {colors['warning']}]Feature Unavailable[/bold {colors['warning']}]",
                border_style=f"bold {colors['warning']}"
            ))
            return

        # Format authors for display: at most the first 5 names, like start_chat_session
        authors = paper_info.get('authors', session.get('paper_authors', []))
        if isinstance(authors, list):
            paper_authors = ', '.join(authors[:5])
            if len(authors) > 5:
                paper_authors += f' et al. ({len(authors)} authors)'
        else:
            # e.g. the pre-formatted author string saved on the session
            paper_authors = str(authors) if authors else 'Unknown'

        # Published date: first 10 characters (presumably YYYY-MM-DD);
        # str() tolerates non-string values such as datetime objects.
        published = paper_info.get('published', session.get('paper_published', ''))
        paper_published = str(published)[:10] if published else 'Unknown'

        # Update session with paper metadata (for use in _chat_with_session)
        session['paper_title'] = paper_title
        session['paper_authors'] = paper_authors
        session['paper_published'] = paper_published

        # Clear any previous session embeddings
        self.clear_session_embeddings()

        if cached_embeddings:
            # Use pre-loaded cached embeddings directly
            await self._load_embeddings_from_cache(cached_embeddings, cached_chunks)
            logger.info(f"Loaded {len(cached_embeddings)} cached embeddings for session")
        else:
            # Re-index the paper content
            paper_text = self._extract_paper_text(paper_info)
            if paper_text:
                # Fall back to the id stored on the session when paper_info lacks one
                paper_id = paper_info.get('arxiv_id') or paper_info.get('id') or session.get('paper_id')
                success = await self.add_document_to_index_with_progress(
                    paper_id,
                    paper_text,
                    {'type': 'paper', 'title': paper_title},
                    console=self.console
                )
                if not success:
                    # Embedding failed - show graceful message
                    self.console.print(Panel(
                        f"[{colors['warning']}]{self.get_embedding_unavailable_message()}[/{colors['warning']}]",
                        title=f"[bold {colors['warning']}]Feature Unavailable[/bold {colors['warning']}]",
                        border_style=f"bold {colors['warning']}"
                    ))
                    return

        self._current_session_id = session_id
        # Store session in memory so _chat_with_session can find it
        self._in_memory_sessions[session_id] = session

        # Extending the TTL is best effort: a DB failure should not block
        # resuming the chat (start_chat_session treats DB writes the same way).
        ttl_extended = False
        try:
            await self.db_service.extend_chat_session_ttl(session_id, hours=24)
            ttl_extended = True
        except Exception as ttl_err:
            logger.warning(f"Could not extend TTL for session {session_id}: {ttl_err}")

        ttl_line = "Session extended by 24 hours.\n" if ttl_extended else ""

        # Show session info with previous message count
        self.console.print(Panel(
            f"[bold {colors['primary']}]Continuing Chat Session[/bold {colors['primary']}]\n"
            f"Paper: [bold {colors['primary']}]{paper_title}[/bold {colors['primary']}]\n"
            f"Previous messages: [bold {colors['primary']}]{len(messages)}[/bold {colors['primary']}]\n"
            f"Chunks indexed: [bold {colors['primary']}]{len(self._session_embeddings)}[/bold {colors['primary']}]\n"
            f"{ttl_line}"
            f"Type [bold {colors['primary']}]'quit'[/bold {colors['primary']}] or [bold {colors['primary']}]'exit'[/bold {colors['primary']}] to end the chat.",
            title=f"[bold]ArionXiv Paper Chat - Resumed[/bold]",
            border_style=f"bold {colors['primary']}"
        ))

        # Show a summary of recent conversation if there are messages
        if messages:
            # Show last 8 Q&A pairs (16 messages total)
            num_pairs = min(8, len(messages) // 2)
            if num_pairs > 0:
                recent = messages[-(num_pairs * 2):]
            else:
                recent = messages  # Show whatever we have

            left_to_right_reveal(self.console, f"\nRecent conversation ({num_pairs} Q&A):", style=f"bold {colors['primary']}", duration=0.8)
            for msg in recent:
                role = "You" if msg.get('type') == 'user' else "Assistant"
                content = msg.get('content', '')
                # Truncate long messages for display
                display_content = content[:150] + "..." if len(content) > 150 else content
                self.console.print(f"[dim {colors['primary']}]{role}: {display_content}[/dim {colors['primary']}]")

        try:
            await self._run_chat_loop(session_id)
        finally:
            self.clear_session_embeddings()

    except Exception as e:
        logger.error(f"Continue chat session failed: {str(e)}")
        colors = get_theme_colors()
        self.console.print(f"[{colors['error']}]Failed to continue session: {str(e)}[/{colors['error']}]")
        self.clear_session_embeddings()
|
|
1244
|
+
|
|
1245
|
+
async def _run_chat_loop(self, session_id: str):
    """Run the interactive question/answer loop for a chat session.

    Reads user input until 'quit', 'exit' or 'q' (case-insensitive, surrounding
    whitespace ignored) or end-of-input. Each non-empty message is sent through
    ``_chat_with_session`` and the reply is rendered, followed by a short info
    line (chunks used, model).

    Args:
        session_id: Session key passed to ``_chat_with_session``.
    """
    colors = get_theme_colors()
    while True:
        try:
            message = Prompt.ask(f"\n[bold {colors['primary']}]You[/bold {colors['primary']}]")
        except EOFError:
            # Ctrl-D / closed stdin: end the chat gracefully instead of
            # surfacing it as a failed session.
            message = 'quit'

        message = (message or '').strip()
        if not message:
            # Nothing to ask; don't send an empty prompt to the model.
            continue

        if message.lower() in ('quit', 'exit', 'q'):
            left_to_right_reveal(self.console, "\nEnding chat session. Goodbye!", style=f"bold {colors['primary']}", duration=1.0)
            break

        # Spinner only while waiting for the answer
        with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}"), console=self.console) as progress:
            progress.add_task(f"[bold {colors['primary']}]Thinking...", total=None)
            result = await self._chat_with_session(session_id, message)

        if result.get('success'):
            stream_markdown_response(
                self.console,
                result.get('response', ''),
                panel_title=f"[bold {colors['primary']}]ArionXiv Assistant[/bold {colors['primary']}]",
                border_style=colors['primary'],
                duration=1.0
            )

            # Build info line with chunks and model name
            info_parts = []
            relevant = result.get('relevant_chunks') or 0
            if relevant > 0:
                info_parts.append(f"Used {relevant} relevant content chunks")
            if result.get('model_display'):
                info_parts.append(f"• Model: {result['model_display']}")

            if info_parts:
                info_text = " ".join(info_parts)
                left_to_right_reveal(self.console, info_text, style=f"dim {colors['muted']}", duration=1.0)
        else:
            left_to_right_reveal(self.console, f"Error: {result.get('error', 'Unknown error')}", style=f"bold {colors['error']}", duration=1.0)
|
|
1281
|
+
|
|
1282
|
+
def _show_post_chat_commands(self):
    """Print a panel of suggested follow-up CLI commands after a chat ends."""
    colors = get_theme_colors()
    accent = colors['primary']

    suggestions = (
        ("arionxiv chat", "Start a new chat session"),
        ("arionxiv search <query>", "Search for more papers"),
        ("arionxiv settings papers", "Manage your saved papers"),
        ("arionxiv trending", "See trending papers"),
        ("arionxiv daily", "Get your daily paper digest"),
    )
    body_lines = [
        f"[bold {accent}]{cmd}[/bold {accent}] [white]→ {desc}[/white]"
        for cmd, desc in suggestions
    ]

    self.console.print()
    self.console.print(Panel(
        "\n".join(body_lines),
        title=f"[bold {accent}]What's Next?[/bold {accent}]",
        border_style=f"bold {accent}",
        padding=(1, 2)
    ))
|
|
1304
|
+
|
|
1305
|
+
async def _chat_with_session(self, session_id: str, message: str) -> Dict[str, Any]:
|
|
1306
|
+
"""Process a chat message and generate response"""
|
|
1307
|
+
try:
|
|
1308
|
+
# Try database first, fall back to in-memory
|
|
1309
|
+
session = None
|
|
1310
|
+
try:
|
|
1311
|
+
session = await self.db_service.find_one(self.chat_collection, {'session_id': session_id})
|
|
1312
|
+
except Exception:
|
|
1313
|
+
pass
|
|
1314
|
+
|
|
1315
|
+
# Fall back to in-memory session
|
|
1316
|
+
if not session:
|
|
1317
|
+
session = self._in_memory_sessions.get(session_id)
|
|
1318
|
+
|
|
1319
|
+
if not session:
|
|
1320
|
+
return {'success': False, 'error': 'Session not found'}
|
|
1321
|
+
|
|
1322
|
+
relevant_chunks = await self.search_similar_documents(message, {'metadata.type': 'paper'})
|
|
1323
|
+
context = "\n\n".join([chunk['text'] for chunk in relevant_chunks[:10]]) # Increased from 5 to 10 chunks for richer context
|
|
1324
|
+
|
|
1325
|
+
# Get conversation history for context
|
|
1326
|
+
chat_history = session.get('messages', [])
|
|
1327
|
+
|
|
1328
|
+
# Get paper metadata for context
|
|
1329
|
+
paper_title = session.get('paper_title', session.get('title', 'Unknown Paper'))
|
|
1330
|
+
paper_authors = session.get('paper_authors', 'Unknown')
|
|
1331
|
+
paper_published = session.get('paper_published', 'Unknown')
|
|
1332
|
+
|
|
1333
|
+
# Determine which LLM to use and generate response
|
|
1334
|
+
model_display = ""
|
|
1335
|
+
success = False
|
|
1336
|
+
response_text = ""
|
|
1337
|
+
error_msg = ""
|
|
1338
|
+
|
|
1339
|
+
# Try OpenRouter for chat, fallback to hosted API
|
|
1340
|
+
if self.openrouter_client and self.openrouter_client.is_available:
|
|
1341
|
+
try:
|
|
1342
|
+
result = await self.openrouter_client.chat(
|
|
1343
|
+
message=message,
|
|
1344
|
+
context=context,
|
|
1345
|
+
history=chat_history,
|
|
1346
|
+
paper_title=paper_title,
|
|
1347
|
+
paper_authors=paper_authors,
|
|
1348
|
+
paper_published=paper_published
|
|
1349
|
+
)
|
|
1350
|
+
if result.get('success'):
|
|
1351
|
+
response_text, model_display, success = result['response'], result.get('model_display', 'OpenRouter'), True
|
|
1352
|
+
else:
|
|
1353
|
+
error_msg = result.get('error', 'OpenRouter failed')
|
|
1354
|
+
except Exception as e:
|
|
1355
|
+
logger.debug(f"OpenRouter error: {e}")
|
|
1356
|
+
|
|
1357
|
+
# Hosted API Fallback (using developer keys on backend)
|
|
1358
|
+
if not success:
|
|
1359
|
+
try:
|
|
1360
|
+
from ..cli.utils.api_client import api_client
|
|
1361
|
+
paper_id = session.get('arxiv_id') or session.get('paper_id')
|
|
1362
|
+
paper_title = session.get('title') or session.get('paper_title')
|
|
1363
|
+
# Pass RAG context to API for paper-aware responses
|
|
1364
|
+
result = await api_client.send_chat_message(
|
|
1365
|
+
message=message,
|
|
1366
|
+
paper_id=paper_id,
|
|
1367
|
+
session_id=session_id,
|
|
1368
|
+
context=context, # Send RAG context
|
|
1369
|
+
paper_title=paper_title # Send paper title
|
|
1370
|
+
)
|
|
1371
|
+
if result.get('success'):
|
|
1372
|
+
response_text = result['response']
|
|
1373
|
+
model_display = result.get('model', 'ArionXiv Cloud')
|
|
1374
|
+
success = True
|
|
1375
|
+
else:
|
|
1376
|
+
error_msg = result.get('error', 'Hosted API failed')
|
|
1377
|
+
except Exception as e:
|
|
1378
|
+
# Extract meaningful error message from APIClientError
|
|
1379
|
+
if hasattr(e, 'message') and e.message:
|
|
1380
|
+
# Clean up the error message for user display
|
|
1381
|
+
msg = e.message
|
|
1382
|
+
if "serverless timeout" in msg.lower():
|
|
1383
|
+
error_msg = "Chat service timeout. For reliable chat, run 'arionxiv settings api' to set your own OPENROUTER_API_KEY."
|
|
1384
|
+
elif "503" in str(getattr(e, 'status_code', '')) or "unavailable" in msg.lower():
|
|
1385
|
+
error_msg = "Chat service temporarily unavailable. Set your OPENROUTER_API_KEY via 'arionxiv settings api' for uninterrupted chat."
|
|
1386
|
+
else:
|
|
1387
|
+
error_msg = f"Chat unavailable: {msg}"
|
|
1388
|
+
elif hasattr(e, 'status_code') and e.status_code:
|
|
1389
|
+
if e.status_code == 503:
|
|
1390
|
+
error_msg = "Chat service temporarily unavailable. For reliable chat, set your OPENROUTER_API_KEY via 'arionxiv settings api'."
|
|
1391
|
+
else:
|
|
1392
|
+
error_msg = f"Chat unavailable: API error {e.status_code}"
|
|
1393
|
+
else:
|
|
1394
|
+
error_msg = f"Chat unavailable: {str(e) or 'Unknown error'}"
|
|
1395
|
+
logger.debug(f"Hosted API error: {e}")
|
|
1396
|
+
|
|
1397
|
+
if not success:
|
|
1398
|
+
return {'success': False, 'error': error_msg or 'Failed to generate response'}
|
|
1399
|
+
|
|
1400
|
+
# Update in-memory session
|
|
1401
|
+
if session_id in self._in_memory_sessions:
|
|
1402
|
+
self._in_memory_sessions[session_id]['messages'].extend([
|
|
1403
|
+
{'type': 'user', 'content': message, 'timestamp': datetime.utcnow()},
|
|
1404
|
+
{'type': 'assistant', 'content': response_text, 'timestamp': datetime.utcnow()}
|
|
1405
|
+
])
|
|
1406
|
+
self._in_memory_sessions[session_id]['last_activity'] = datetime.utcnow()
|
|
1407
|
+
|
|
1408
|
+
# Try to persist to Vercel API (cloud storage)
|
|
1409
|
+
try:
|
|
1410
|
+
from ..cli.utils.api_client import api_client
|
|
1411
|
+
# Get full message history from in-memory session
|
|
1412
|
+
if session_id in self._in_memory_sessions:
|
|
1413
|
+
# Use the API session ID (from MongoDB) for updates
|
|
1414
|
+
api_session_id = self._in_memory_sessions[session_id].get('api_session_id')
|
|
1415
|
+
if api_session_id:
|
|
1416
|
+
all_messages = self._in_memory_sessions[session_id].get('messages', [])
|
|
1417
|
+
# Convert datetime objects to ISO strings for JSON serialization
|
|
1418
|
+
serializable_messages = []
|
|
1419
|
+
for msg in all_messages:
|
|
1420
|
+
serializable_messages.append({
|
|
1421
|
+
'type': msg.get('type'),
|
|
1422
|
+
'content': msg.get('content'),
|
|
1423
|
+
'timestamp': msg.get('timestamp').isoformat() if hasattr(msg.get('timestamp'), 'isoformat') else str(msg.get('timestamp'))
|
|
1424
|
+
})
|
|
1425
|
+
await api_client.update_chat_session(api_session_id, serializable_messages)
|
|
1426
|
+
logger.debug(f"Messages saved to API for session {api_session_id}")
|
|
1427
|
+
except Exception as api_err:
|
|
1428
|
+
logger.debug(f"Failed to save messages to API: {api_err}")
|
|
1429
|
+
|
|
1430
|
+
# Try to persist to local database (may fail)
|
|
1431
|
+
try:
|
|
1432
|
+
await self.db_service.update_one(
|
|
1433
|
+
self.chat_collection,
|
|
1434
|
+
{'session_id': session_id},
|
|
1435
|
+
{
|
|
1436
|
+
'$push': {
|
|
1437
|
+
'messages': {
|
|
1438
|
+
'$each': [
|
|
1439
|
+
{'type': 'user', 'content': message, 'timestamp': datetime.utcnow()},
|
|
1440
|
+
{'type': 'assistant', 'content': response_text, 'timestamp': datetime.utcnow()}
|
|
1441
|
+
]
|
|
1442
|
+
}
|
|
1443
|
+
},
|
|
1444
|
+
'$set': {'last_activity': datetime.utcnow()}
|
|
1445
|
+
}
|
|
1446
|
+
)
|
|
1447
|
+
except Exception:
|
|
1448
|
+
pass # In-memory session is already updated
|
|
1449
|
+
|
|
1450
|
+
return {
|
|
1451
|
+
'success': True,
|
|
1452
|
+
'response': response_text,
|
|
1453
|
+
'relevant_chunks': len(relevant_chunks),
|
|
1454
|
+
'session_id': session_id,
|
|
1455
|
+
'model_display': model_display
|
|
1456
|
+
}
|
|
1457
|
+
|
|
1458
|
+
except Exception as e:
|
|
1459
|
+
logger.error(f"Chat failed for session {session_id}: {str(e)}")
|
|
1460
|
+
return {'success': False, 'error': f'Chat failed: {str(e)}'}
|
|
1461
|
+
|
|
1462
|
+
def _extract_paper_text(self, paper: Dict[str, Any]) -> str:
|
|
1463
|
+
"""Extract text content from paper for indexing"""
|
|
1464
|
+
text_parts = []
|
|
1465
|
+
|
|
1466
|
+
if paper.get('title'):
|
|
1467
|
+
text_parts.append(paper['title'])
|
|
1468
|
+
|
|
1469
|
+
if paper.get('abstract'):
|
|
1470
|
+
text_parts.append(paper['abstract'])
|
|
1471
|
+
|
|
1472
|
+
if paper.get('full_text'):
|
|
1473
|
+
text_parts.append(paper['full_text'])
|
|
1474
|
+
|
|
1475
|
+
return '\n\n'.join(text_parts)
|
|
1476
|
+
|
|
1477
|
+
def _build_chat_prompt(self, session: Dict[str, Any], message: str, context: str) -> str:
|
|
1478
|
+
"""Build chat prompt with context"""
|
|
1479
|
+
from ..prompts import format_prompt
|
|
1480
|
+
|
|
1481
|
+
chat_history = session.get('messages', [])
|
|
1482
|
+
|
|
1483
|
+
history_text = ""
|
|
1484
|
+
recent_messages = chat_history[-6:] if len(chat_history) > 6 else chat_history
|
|
1485
|
+
|
|
1486
|
+
for msg in recent_messages:
|
|
1487
|
+
role = "User" if msg['type'] == 'user' else "Assistant"
|
|
1488
|
+
history_text += f"{role}: {msg['content']}\n"
|
|
1489
|
+
|
|
1490
|
+
return format_prompt("rag_chat",
|
|
1491
|
+
context=context,
|
|
1492
|
+
history=history_text,
|
|
1493
|
+
message=message)
|
|
1494
|
+
|
|
1495
|
+
def _parse_llm_response(self, response: Any) -> Tuple[bool, str, str]:
|
|
1496
|
+
"""Normalize LLM responses that may return strings or dictionaries"""
|
|
1497
|
+
if isinstance(response, dict):
|
|
1498
|
+
if response.get('success', True) and isinstance(response.get('content'), str):
|
|
1499
|
+
content = response['content'].strip()
|
|
1500
|
+
if content:
|
|
1501
|
+
return True, content, ""
|
|
1502
|
+
return False, "", response.get('error', 'LLM response missing content')
|
|
1503
|
+
if isinstance(response, str):
|
|
1504
|
+
text = response.strip()
|
|
1505
|
+
if text and not text.startswith('Error'):
|
|
1506
|
+
return True, text, ""
|
|
1507
|
+
return False, "", text or 'LLM returned empty response'
|
|
1508
|
+
if response is None:
|
|
1509
|
+
return False, "", 'LLM returned no response'
|
|
1510
|
+
return False, "", 'Unexpected LLM response type'
|
|
1511
|
+
|
|
1512
|
+
async def cleanup_expired_data(self):
|
|
1513
|
+
"""Clean up expired embeddings and chat sessions"""
|
|
1514
|
+
try:
|
|
1515
|
+
cutoff_time = datetime.utcnow()
|
|
1516
|
+
|
|
1517
|
+
deleted_embeddings = await self.db_service.delete_many(
|
|
1518
|
+
self.vector_collection,
|
|
1519
|
+
{'expires_at': {'$lt': cutoff_time}}
|
|
1520
|
+
)
|
|
1521
|
+
|
|
1522
|
+
chat_cutoff = datetime.utcnow() - timedelta(days=7)
|
|
1523
|
+
deleted_sessions = await self.db_service.delete_many(
|
|
1524
|
+
self.chat_collection,
|
|
1525
|
+
{'last_activity': {'$lt': chat_cutoff}}
|
|
1526
|
+
)
|
|
1527
|
+
|
|
1528
|
+
logger.info(f"RAG cleanup completed: deleted {deleted_embeddings} embeddings, {deleted_sessions} sessions")
|
|
1529
|
+
|
|
1530
|
+
except Exception as e:
|
|
1531
|
+
logger.error(f"RAG cleanup failed: {str(e)}")
|