arionxiv-1.0.32-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. arionxiv/__init__.py +40 -0
  2. arionxiv/__main__.py +10 -0
  3. arionxiv/arxiv_operations/__init__.py +0 -0
  4. arionxiv/arxiv_operations/client.py +225 -0
  5. arionxiv/arxiv_operations/fetcher.py +173 -0
  6. arionxiv/arxiv_operations/searcher.py +122 -0
  7. arionxiv/arxiv_operations/utils.py +293 -0
  8. arionxiv/cli/__init__.py +4 -0
  9. arionxiv/cli/commands/__init__.py +1 -0
  10. arionxiv/cli/commands/analyze.py +587 -0
  11. arionxiv/cli/commands/auth.py +365 -0
  12. arionxiv/cli/commands/chat.py +714 -0
  13. arionxiv/cli/commands/daily.py +482 -0
  14. arionxiv/cli/commands/fetch.py +217 -0
  15. arionxiv/cli/commands/library.py +295 -0
  16. arionxiv/cli/commands/preferences.py +426 -0
  17. arionxiv/cli/commands/search.py +254 -0
  18. arionxiv/cli/commands/settings_unified.py +1407 -0
  19. arionxiv/cli/commands/trending.py +41 -0
  20. arionxiv/cli/commands/welcome.py +168 -0
  21. arionxiv/cli/main.py +407 -0
  22. arionxiv/cli/ui/__init__.py +1 -0
  23. arionxiv/cli/ui/global_theme_manager.py +173 -0
  24. arionxiv/cli/ui/logo.py +127 -0
  25. arionxiv/cli/ui/splash.py +89 -0
  26. arionxiv/cli/ui/theme.py +32 -0
  27. arionxiv/cli/ui/theme_system.py +391 -0
  28. arionxiv/cli/utils/__init__.py +54 -0
  29. arionxiv/cli/utils/animations.py +522 -0
  30. arionxiv/cli/utils/api_client.py +583 -0
  31. arionxiv/cli/utils/api_config.py +505 -0
  32. arionxiv/cli/utils/command_suggestions.py +147 -0
  33. arionxiv/cli/utils/db_config_manager.py +254 -0
  34. arionxiv/github_actions_runner.py +206 -0
  35. arionxiv/main.py +23 -0
  36. arionxiv/prompts/__init__.py +9 -0
  37. arionxiv/prompts/prompts.py +247 -0
  38. arionxiv/rag_techniques/__init__.py +8 -0
  39. arionxiv/rag_techniques/basic_rag.py +1531 -0
  40. arionxiv/scheduler_daemon.py +139 -0
  41. arionxiv/server.py +1000 -0
  42. arionxiv/server_main.py +24 -0
  43. arionxiv/services/__init__.py +73 -0
  44. arionxiv/services/llm_client.py +30 -0
  45. arionxiv/services/llm_inference/__init__.py +58 -0
  46. arionxiv/services/llm_inference/groq_client.py +469 -0
  47. arionxiv/services/llm_inference/llm_utils.py +250 -0
  48. arionxiv/services/llm_inference/openrouter_client.py +564 -0
  49. arionxiv/services/unified_analysis_service.py +872 -0
  50. arionxiv/services/unified_auth_service.py +457 -0
  51. arionxiv/services/unified_config_service.py +456 -0
  52. arionxiv/services/unified_daily_dose_service.py +823 -0
  53. arionxiv/services/unified_database_service.py +1633 -0
  54. arionxiv/services/unified_llm_service.py +366 -0
  55. arionxiv/services/unified_paper_service.py +604 -0
  56. arionxiv/services/unified_pdf_service.py +522 -0
  57. arionxiv/services/unified_prompt_service.py +344 -0
  58. arionxiv/services/unified_scheduler_service.py +589 -0
  59. arionxiv/services/unified_user_service.py +954 -0
  60. arionxiv/utils/__init__.py +51 -0
  61. arionxiv/utils/api_helpers.py +200 -0
  62. arionxiv/utils/file_cleanup.py +150 -0
  63. arionxiv/utils/ip_helper.py +96 -0
  64. arionxiv-1.0.32.dist-info/METADATA +336 -0
  65. arionxiv-1.0.32.dist-info/RECORD +69 -0
  66. arionxiv-1.0.32.dist-info/WHEEL +5 -0
  67. arionxiv-1.0.32.dist-info/entry_points.txt +4 -0
  68. arionxiv-1.0.32.dist-info/licenses/LICENSE +21 -0
  69. arionxiv-1.0.32.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1531 @@
1
+ """
2
+ Basic RAG Implementation for ArionXiv
3
+ Provides standard Retrieval-Augmented Generation with text chunking, embedding generation, and vector search
4
+ """
5
+
6
+ import asyncio
7
+ from datetime import datetime, timedelta
8
+ from typing import List, Dict, Any, Optional, Union, Tuple
9
+ from abc import ABC, abstractmethod
10
+ import logging
11
+ from pymongo import IndexModel
12
+ import os
13
+
14
+ try:
15
+ import numpy as np
16
+ from sentence_transformers import SentenceTransformer
17
+ ML_DEPENDENCIES_AVAILABLE = True
18
+ except ImportError:
19
+ ML_DEPENDENCIES_AVAILABLE = False
20
+ np = None
21
+ SentenceTransformer = None
22
+
23
+ try:
24
+ from google import genai
25
+ GEMINI_AVAILABLE = True
26
+ except ImportError:
27
+ GEMINI_AVAILABLE = False
28
+ genai = None
29
+
30
+ from rich.console import Console
31
+ from rich.panel import Panel
32
+ from rich.prompt import Prompt
33
+ from rich.markdown import Markdown
34
+ from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn
35
+
36
+ # Global cache for loaded embedding models to avoid reloading across sessions
37
+ # This persists the model in memory once loaded
38
+ _GLOBAL_MODEL_CACHE: Dict[str, Any] = {}
39
+
40
+ # Import theme system for consistent styling
41
+ try:
42
+ from ..cli.ui.theme import (
43
+ create_themed_console, get_theme_colors, style_text,
44
+ print_success, print_error, print_warning, print_info
45
+ )
46
+ from ..cli.utils.animations import left_to_right_reveal, stream_markdown_response
47
+ from ..cli.utils.command_suggestions import show_command_suggestions
48
+ THEME_AVAILABLE = True
49
+ except ImportError:
50
+ THEME_AVAILABLE = False
51
+ def get_theme_colors(db_service=None):
52
+ return {'primary': 'blue', 'secondary': 'cyan', 'success': 'green',
53
+ 'warning': 'yellow', 'error': 'red', 'muted': 'dim'}
54
+ def style_text(text, style='primary', db_service=None):
55
+ colors = get_theme_colors()
56
+ return f"[{colors.get(style, 'white')}]{text}[/{colors.get(style, 'white')}]"
57
+ def create_themed_console(db_service=None):
58
+ return Console()
59
+ def left_to_right_reveal(console, text, style="", duration=1.0):
60
+ console.print(text)
61
+ def stream_markdown_response(console, text, panel_title="", border_style=None, duration=3.0):
62
+ colors = get_theme_colors()
63
+ actual_style = border_style or colors.get('primary', 'blue')
64
+ console.print(Panel(Markdown(text), title=panel_title, border_style=actual_style))
65
+ def show_command_suggestions(console, context="general", **kwargs):
66
+ pass # No-op fallback
67
+
68
+ # Import API config manager to check if Gemini key is available
69
+ try:
70
+ from ..cli.utils.api_config import api_config_manager
71
+ API_CONFIG_AVAILABLE = True
72
+ except ImportError:
73
+ API_CONFIG_AVAILABLE = False
74
+ api_config_manager = None
75
+
76
+ logger = logging.getLogger(__name__)
77
+
78
+ class EmbeddingProvider(ABC):
79
+ """Abstract base class for embedding providers"""
80
+
81
+ @abstractmethod
82
+ async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
83
+ """Get embeddings for a list of texts"""
84
+ pass
85
+
86
+ @abstractmethod
87
+ def get_dimension(self) -> int:
88
+ """Get the dimension of embeddings"""
89
+ pass
90
+
91
+ @abstractmethod
92
+ def get_name(self) -> str:
93
+ """Get provider name"""
94
+ pass
95
+
96
+
97
+ class GeminiEmbeddingProvider(EmbeddingProvider):
98
+ """Google Gemini embedding provider using gemini-embedding-001 model (FREE!)
99
+
100
+ Uses output_dimensionality=768 for efficient storage (default is 3072).
101
+ """
102
+
103
+ def __init__(self, api_key: str = None, console: Console = None):
104
+ if not GEMINI_AVAILABLE:
105
+ raise ImportError("google-genai not installed. Install with: pip install google-genai")
106
+
107
+ self.api_key = api_key or os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
108
+ if not self.api_key:
109
+ raise ValueError("Gemini API key not found. Set GEMINI_API_KEY or GOOGLE_API_KEY environment variable.")
110
+
111
+ # Use new genai.Client() API
112
+ self.client = genai.Client(api_key=self.api_key)
113
+ self.model = "gemini-embedding-001"
114
+ self.dimension = 768 # Using reduced dimensionality for efficiency
115
+ self._console = console or Console()
116
+
117
+ logger.info("Gemini embedding provider initialized with free API")
118
+
119
+ async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
120
+ """Get embeddings using Gemini API (FREE!) with rate limit handling"""
121
+ try:
122
+ batch_size = 10
123
+ all_embeddings = []
124
+ max_retries = 3
125
+
126
+ for i in range(0, len(texts), batch_size):
127
+ batch = texts[i:i + batch_size]
128
+ batch_embeddings = []
129
+
130
+ for text in batch:
131
+ retries = 0
132
+ while retries < max_retries:
133
+ try:
134
+ # New API: client.models.embed_content()
135
+ result = self.client.models.embed_content(
136
+ model=self.model,
137
+ contents=text,
138
+ config={"output_dimensionality": self.dimension})  # request the reduced 768-dim output noted in the class docstring
139
+ # New API returns result.embeddings[0].values
140
+ batch_embeddings.append(list(result.embeddings[0].values))
141
+ await asyncio.sleep(0.1)
142
+ break # Success, exit retry loop
143
+ except Exception as e:
144
+ error_str = str(e).lower()
145
+ # Check for rate limit errors - silently retry with backoff
146
+ if any(term in error_str for term in ['rate limit', 'quota', '429', 'resource exhausted', 'too many']):
147
+ retries += 1
148
+ if retries < max_retries:
149
+ wait_time = (2 ** retries) * 2 # Exponential backoff: 4s after the first failure, 8s after the second
150
+ await asyncio.sleep(wait_time)
151
+ else:
152
+ # Max retries reached, use fallback
153
+ batch_embeddings.append([0.0] * self.dimension)
154
+ else:
155
+ logger.debug(f"Failed to embed text: {str(e)}")
156
+ batch_embeddings.append([0.0] * self.dimension)
157
+ break
158
+
159
+ all_embeddings.extend(batch_embeddings)
160
+
161
+ if i + batch_size < len(texts):
162
+ await asyncio.sleep(0.5)
163
+
164
+ return all_embeddings
165
+
166
+ except Exception as e:
167
+ logger.error(f"Gemini embedding failed: {str(e)}")
168
+ raise
169
+
170
+ def get_dimension(self) -> int:
171
+ return self.dimension
172
+
173
+ def get_name(self) -> str:
174
+ return "Google-Gemini-Embedding-001-FREE"
175
+
176
+
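A minimal usage sketch for the provider above (not part of the packaged module; it assumes a GEMINI_API_KEY or GOOGLE_API_KEY is already exported, exactly as the constructor requires):

import asyncio

async def _demo_gemini_provider():
    provider = GeminiEmbeddingProvider()  # raises ValueError when no key is configured
    vectors = await provider.get_embeddings(["retrieval-augmented generation"])
    # one 768-dimensional vector per input text
    print(provider.get_name(), provider.get_dimension(), len(vectors[0]))

# asyncio.run(_demo_gemini_provider())  # uncomment to call the live API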
177
+ class HuggingFaceEmbeddingProvider(EmbeddingProvider):
178
+ """HuggingFace embedding provider using sentence-transformers (fallback)"""
179
+
180
+ def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
181
+ if not ML_DEPENDENCIES_AVAILABLE:
182
+ raise ImportError(
183
+ "ML dependencies not installed. Install with: pip install sentence-transformers numpy"
184
+ )
185
+ self.model_name = model_name
186
+ self.model = None
187
+ self._dimension = None
188
+ self._console = Console()
189
+
190
+ def _load_model(self):
191
+ """Lazy load the model"""
192
+ if self.model is None:
193
+ logger.info(f"Loading HuggingFace model: {self.model_name}")
194
+ colors = get_theme_colors()
195
+ self._console.print(f"[{colors['muted']}]Loading fallback model: {self.model_name}[/{colors['muted']}]")
196
+
197
+ # Suppress HuggingFace's internal progress bars
198
+ import os
199
+ os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
200
+
201
+ with Progress(
202
+ SpinnerColumn(),
203
+ TextColumn("[progress.description]{task.description}"),
204
+ console=self._console,
205
+ transient=True
206
+ ) as progress:
207
+ task = progress.add_task(
208
+ f"[{colors['primary']}]Loading model...[/{colors['primary']}]",
209
+ total=None
210
+ )
211
+ self.model = SentenceTransformer(self.model_name)
212
+ self._dimension = self.model.get_sentence_embedding_dimension()
213
+
214
+ # Re-enable progress bars for other operations
215
+ os.environ.pop('HF_HUB_DISABLE_PROGRESS_BARS', None)
216
+
217
+ self._console.print(f"[{colors['primary']}][OK][/{colors['primary']}] Fallback model ready")
218
+
219
+ async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
220
+ """Get embeddings using HuggingFace model"""
221
+ try:
222
+ self._load_model()
223
+ loop = asyncio.get_event_loop()
224
+ embeddings = await loop.run_in_executor(None, self.model.encode, texts)
225
+ return embeddings.tolist()
226
+ except Exception as e:
227
+ logger.error(f"HuggingFace embedding failed: {str(e)}")
228
+ raise
229
+
230
+ def get_dimension(self) -> int:
231
+ if self._dimension is None:
232
+ self._load_model()
233
+ return self._dimension
234
+
235
+ def get_name(self) -> str:
236
+ return f"HuggingFace-{self.model_name}"
237
+
238
+
239
+ class GraniteDoclingEmbeddingProvider(EmbeddingProvider):
240
+ """
241
+ IBM Granite embedding provider - small, fast, and runs locally
242
+
243
+ Downloads the model on first use. Model is kept in memory during
244
+ the session and uses HuggingFace's default cache (~/.cache/huggingface/).
245
+ """
246
+
247
+ # Default model - IBM Granite 30M English (small, ~120MB download)
248
+ DEFAULT_MODEL = "ibm-granite/granite-embedding-30m-english"
249
+
250
+ def __init__(self, model_name: str = None):
251
+ if not ML_DEPENDENCIES_AVAILABLE:
252
+ raise ImportError(
253
+ "ML dependencies not installed. Install with: pip install sentence-transformers numpy"
254
+ )
255
+ self.model_name = model_name or self.DEFAULT_MODEL
256
+ self._dimension = None
257
+ self._console = Console()
258
+
259
+ @property
260
+ def model(self):
261
+ """Get model from global cache or None if not loaded"""
262
+ return _GLOBAL_MODEL_CACHE.get(self.model_name)
263
+
264
+ @model.setter
265
+ def model(self, value):
266
+ """Store model in global cache"""
267
+ if value is not None:
268
+ _GLOBAL_MODEL_CACHE[self.model_name] = value
269
+
270
+ def _load_model(self):
271
+ """Lazy load the model with progress indicator - uses global cache"""
272
+ # Check global cache first - model persists across sessions
273
+ if self.model_name in _GLOBAL_MODEL_CACHE:
274
+ self._dimension = _GLOBAL_MODEL_CACHE[self.model_name].get_sentence_embedding_dimension()
275
+ return # Model already in memory, no loading needed
276
+
277
+ colors = get_theme_colors()
278
+ logger.info(f"Loading embedding model: {self.model_name}")
279
+
280
+ # Check if model is already cached by HuggingFace on disk
281
+ from pathlib import Path
282
+ cache_dir = Path.home() / ".cache" / "huggingface" / "hub"
283
+ model_cache_name = f"models--{self.model_name.replace('/', '--')}"
284
+ is_cached = (cache_dir / model_cache_name).exists()
285
+
286
+ if not is_cached:
287
+ # First time - show download message
288
+ self._console.print(
289
+ f"[dim {colors['primary']}]Downloading embedding model: {self.model_name}[/dim {colors['primary']}]"
290
+ )
291
+ self._console.print(
292
+ f"[dim {colors['primary']}](First run downloads ~120MB, uses HuggingFace cache)[/{colors['primary']}]"
293
+ )
294
+
295
+ try:
296
+ # Suppress HuggingFace's internal progress bars to avoid flickering
297
+ import os
298
+ os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
299
+
300
+ if is_cached:
301
+ # Model is on disk - load silently (fast operation, no spinner needed)
302
+ loaded_model = SentenceTransformer(self.model_name, trust_remote_code=True)
303
+ self._dimension = loaded_model.get_sentence_embedding_dimension()
304
+ _GLOBAL_MODEL_CACHE[self.model_name] = loaded_model
305
+ else:
306
+ # First time download - show progress spinner
307
+ with Progress(
308
+ SpinnerColumn(),
309
+ TextColumn("[progress.description]{task.description}"),
310
+ console=self._console,
311
+ transient=True
312
+ ) as progress:
313
+ task = progress.add_task(
314
+ f"[bold {colors['primary']}]Downloading and initializing embedding model...[/bold {colors['primary']}]",
315
+ total=None
316
+ )
317
+ loaded_model = SentenceTransformer(self.model_name, trust_remote_code=True)
318
+ self._dimension = loaded_model.get_sentence_embedding_dimension()
319
+ _GLOBAL_MODEL_CACHE[self.model_name] = loaded_model
320
+
321
+ # Re-enable progress bars for other operations
322
+ os.environ.pop('HF_HUB_DISABLE_PROGRESS_BARS', None)
323
+
324
+ # self._console.print(
325
+ # f"[{colors['primary']}][OK][/{colors['primary']}] Embedding model ready "
326
+ # f"(dimension: {self._dimension})"
327
+ # )
328
+ logger.info(f"Embedding model loaded successfully (dimension: {self._dimension})")
329
+
330
+ except Exception as e:
331
+ logger.error(f"Failed to load embedding model {self.model_name}: {str(e)}")
332
+ raise
333
+
334
+ async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
335
+ """Get embeddings using the configured embedding model"""
336
+ try:
337
+ self._load_model()
338
+ model = _GLOBAL_MODEL_CACHE.get(self.model_name)
339
+ loop = asyncio.get_event_loop()
340
+ embeddings = await loop.run_in_executor(None, model.encode, texts)
341
+ return embeddings.tolist()
342
+ except Exception as e:
343
+ logger.error(f"Embedding generation failed for {self.model_name}: {str(e)}")
344
+ raise
345
+
346
+ def get_dimension(self) -> int:
347
+ if self._dimension is None:
348
+ self._load_model()
349
+ return self._dimension
350
+
351
+ def get_name(self) -> str:
352
+ return f"Granite-{self.model_name.split('/')[-1]}"
353
+
354
+
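A small sketch of how the module-level cache behaves (illustration only, not part of the packaged module; the first call downloads and loads the sentence-transformers model, later instances reuse the same in-memory object):

import asyncio

async def _demo_granite_cache():
    p1 = GraniteDoclingEmbeddingProvider()
    await p1.get_embeddings(["first call loads the model"])  # populates _GLOBAL_MODEL_CACHE
    p2 = GraniteDoclingEmbeddingProvider()                   # new object, same model name
    assert p2.model is p1.model                              # served from the shared cache

# asyncio.run(_demo_granite_cache())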
355
+ class BasicRAG:
356
+ """
357
+ Basic RAG (Retrieval-Augmented Generation) implementation
358
+ Handles text chunking, embedding generation, vector search, and context retrieval
359
+ """
360
+
361
+ def __init__(self, database_service, config_service, llm_client, openrouter_client=None):
362
+ """
363
+ Initialize BasicRAG with required services
364
+
365
+ Args:
366
+ database_service: Database service for storing/retrieving embeddings
367
+ config_service: Configuration service for RAG settings
368
+ llm_client: LLM client for generating responses (Groq - fallback)
369
+ openrouter_client: OpenRouter client for primary LLM (Kimi K2)
370
+ """
371
+ self.db_service = database_service
372
+ self.config_service = config_service
373
+ self.llm_client = llm_client
374
+ self.openrouter_client = openrouter_client
375
+
376
+ # Lazy initialization flags for embedding providers
377
+ self._embedding_providers_initialized = False
378
+ self._embedding_providers = []
379
+ self._current_embedding_provider = None
380
+
381
+ # Use OpenRouter as primary if available, otherwise fall back to Groq
382
+ # Can be overridden with RAG_LLM_PROVIDER env var
383
+ env_provider = os.getenv("RAG_LLM_PROVIDER", "").lower()
384
+ if env_provider:
385
+ self.llm_provider = env_provider
386
+ elif openrouter_client and openrouter_client.is_available:
387
+ self.llm_provider = "openrouter"
388
+ else:
389
+ self.llm_provider = "groq"
390
+
391
+ rag_config = config_service.get_rag_config()
392
+ embedding_config = config_service.get_embedding_config()
393
+
394
+ self.vector_collection = rag_config["vector_collection"]
395
+ self.chat_collection = rag_config["chat_collection"]
396
+ self.chunk_size = rag_config["chunk_size"]
397
+ self.chunk_overlap = rag_config["chunk_overlap"]
398
+ self.top_k_results = rag_config["top_k_results"]
399
+ self.ttl_hours = rag_config["ttl_hours"]
400
+
401
+ self.embedding_batch_size = embedding_config["batch_size"]
402
+ self.embedding_dimension = embedding_config["dimension_default"]
403
+ self._embedding_config = embedding_config
404
+
405
+ # In-memory embedding storage for current chat session
406
+ # Format: {chunk_id: {text, embedding, metadata}}
407
+ self._session_embeddings: Dict[str, Dict[str, Any]] = {}
408
+ self._current_session_id: Optional[str] = None
409
+
410
+ # In-memory session storage (fallback when database unavailable)
411
+ self._in_memory_sessions: Dict[str, Dict[str, Any]] = {}
412
+
413
+ self.console = Console()
414
+
415
+ logger.info("BasicRAG initialized (embedding providers lazy-loaded)")
416
+
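The provider choice above can be pinned from the environment before constructing BasicRAG (sketch; the constructor records any RAG_LLM_PROVIDER value verbatim instead of auto-detecting OpenRouter/Groq):

import os

os.environ["RAG_LLM_PROVIDER"] = "groq"  # or "openrouter"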
417
+ @property
418
+ def embedding_providers(self):
419
+ """Lazy initialize embedding providers"""
420
+ if not self._embedding_providers_initialized:
421
+ self._embedding_providers_initialized = True
422
+ self._setup_embedding_providers(self._embedding_config)
423
+ return self._embedding_providers
424
+
425
+ @property
426
+ def current_embedding_provider(self):
427
+ """Get current embedding provider (lazy init if needed)"""
428
+ if not self._embedding_providers_initialized:
429
+ self._embedding_providers_initialized = True
430
+ self._setup_embedding_providers(self._embedding_config)
431
+ return self._current_embedding_provider
432
+
433
+ @current_embedding_provider.setter
434
+ def current_embedding_provider(self, value):
435
+ """Set current embedding provider"""
436
+ self._current_embedding_provider = value
437
+
438
+ def _setup_embedding_providers(self, embedding_config):
439
+ """
440
+ Setup embedding providers in order of preference
441
+
442
+ Order:
443
+ 1. Gemini (FREE API, if an API key is configured)
444
+ 2. Granite/HuggingFace fallback models (run locally, cached on disk by HuggingFace)
445
+
446
+ If the Gemini API key is not available, automatically falls back to the
447
+ local Granite model, which is downloaded once and then reused from the
448
+ HuggingFace disk cache on subsequent runs.
449
+ """
450
+ primary_model = embedding_config["primary_model"]
451
+ fallback_1 = embedding_config["fallback_1"]
452
+ fallback_2 = embedding_config["fallback_2"]
453
+ enable_gemini = embedding_config["enable_gemini"]
454
+ enable_huggingface = embedding_config["enable_huggingface"]
455
+
456
+ # Check if Gemini API key is actually available
457
+ gemini_key_available = False
458
+ if API_CONFIG_AVAILABLE and api_config_manager:
459
+ gemini_key_available = api_config_manager.is_configured("gemini")
460
+ else:
461
+ # Fallback: check environment variable directly
462
+ gemini_key_available = bool(
463
+ os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
464
+ )
465
+
466
+ # Primary: Gemini (if enabled AND API key is available)
467
+ if enable_gemini and gemini_key_available and (primary_model.lower() == "gemini" or primary_model == ""):
468
+ try:
469
+ gemini_provider = GeminiEmbeddingProvider()
470
+ self._embedding_providers.append(gemini_provider)
471
+ logger.info("Gemini embedding provider initialized as PRIMARY (FREE API)")
472
+ except Exception as e:
473
+ logger.warning(f"Gemini embedding provider failed to initialize: {str(e)}")
474
+ elif enable_gemini and not gemini_key_available:
475
+ logger.info("Gemini API key not configured - will use local Granite model as fallback")
476
+
477
+ # If Gemini is not available OR primary is a HuggingFace model, use Granite
478
+ if enable_huggingface:
479
+ # If Gemini failed/unavailable, Granite becomes primary
480
+ if not self._embedding_providers:
481
+ try:
482
+ # Use Granite as primary when Gemini is unavailable
483
+ granite_model = fallback_1 or GraniteDoclingEmbeddingProvider.DEFAULT_MODEL
484
+ granite_provider = GraniteDoclingEmbeddingProvider(model_name=granite_model)
485
+ self._embedding_providers.append(granite_provider)
486
+ logger.info(f"Granite embedding provider initialized as PRIMARY (local): {granite_model}")
487
+ except Exception as e:
488
+ logger.warning(f"Granite embedding provider failed to initialize: {str(e)}")
489
+
490
+ # If primary is explicitly a HuggingFace model (not "gemini"), add it
491
+ elif primary_model.lower() != "gemini" and primary_model != "":
492
+ try:
493
+ primary_provider = GraniteDoclingEmbeddingProvider(model_name=primary_model)
494
+ self._embedding_providers.append(primary_provider)
495
+ logger.info(f"Primary HuggingFace embedding provider initialized: {primary_model}")
496
+ except Exception as e:
497
+ logger.warning(f"Primary embedding provider failed to initialize: {str(e)}")
498
+
499
+ # Add fallback (Granite) if not already primary
500
+ if fallback_1 and not any(
501
+ isinstance(p, GraniteDoclingEmbeddingProvider) and p.model_name == fallback_1
502
+ for p in self._embedding_providers
503
+ ):
504
+ try:
505
+ fallback_1_provider = GraniteDoclingEmbeddingProvider(model_name=fallback_1)
506
+ self._embedding_providers.append(fallback_1_provider)
507
+ logger.info(f"Fallback embedding provider initialized: {fallback_1}")
508
+ except Exception as e:
509
+ logger.warning(f"Fallback embedding provider failed: {str(e)}")
510
+
511
+ if self._embedding_providers:
512
+ self._current_embedding_provider = self._embedding_providers[0]
513
+ logger.info(f"Using embedding provider: {self._current_embedding_provider.get_name()}")
514
+ else:
515
+ # No providers available - this will be handled gracefully in chat
516
+ logger.debug("No embedding providers available - chat will show user-friendly message")
517
+
518
+ def is_embedding_available(self) -> bool:
519
+ """Check if any embedding provider is available for chat"""
520
+ # Trigger lazy initialization
521
+ _ = self.embedding_providers
522
+ return len(self._embedding_providers) > 0
523
+
524
+ def get_embedding_unavailable_message(self) -> str:
525
+ """Get user-friendly message explaining why embeddings are unavailable"""
526
+ # Check if Gemini API key is configured
527
+ gemini_configured = False
528
+ if API_CONFIG_AVAILABLE and api_config_manager:
529
+ gemini_configured = api_config_manager.is_configured("gemini")
530
+ else:
531
+ gemini_configured = bool(os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY"))
532
+
533
+ if not ML_DEPENDENCIES_AVAILABLE and not gemini_configured:
534
+ return (
535
+ "Chat feature is temporarily unavailable.\n\n"
536
+ "To enable this feature, please configure your Gemini API key:\n"
537
+ " arionxiv settings\n\n"
538
+ "If you encounter persistent issues, please report at:\n"
539
+ " https://github.com/Arion-IT/ArionXiv/issues"
540
+ )
541
+ elif not ML_DEPENDENCIES_AVAILABLE:
542
+ return (
543
+ "Chat feature encountered an issue.\n\n"
544
+ "Please try again later or report at:\n"
545
+ " https://github.com/Arion-IT/ArionXiv/issues"
546
+ )
547
+ else:
548
+ return (
549
+ "Chat feature is temporarily unavailable.\n\n"
550
+ "Please try again later or report at:\n"
551
+ " https://github.com/Arion-IT/ArionXiv/issues"
552
+ )
553
+
554
+ async def get_embeddings(self, texts: Union[str, List[str]]) -> List[List[float]]:
555
+ """Get embeddings with automatic fallback"""
556
+ if isinstance(texts, str):
557
+ texts = [texts]
558
+
559
+ if not texts:
560
+ return []
561
+
562
+ for i, provider in enumerate(self.embedding_providers):
563
+ try:
564
+ embeddings = await provider.get_embeddings(texts)
565
+
566
+ if provider != self.current_embedding_provider:
567
+ self.current_embedding_provider = provider
568
+ logger.info(f"Switched to embedding provider: {provider.get_name()}")
569
+
570
+ return embeddings
571
+
572
+ except Exception as e:
573
+ logger.warning(f"Provider {provider.get_name()} failed: {str(e)}")
574
+ if i == len(self.embedding_providers) - 1:
575
+ raise RuntimeError(f"All embedding providers failed. Last error: {str(e)}")
576
+ continue
577
+
578
+ async def get_single_embedding(self, text: str) -> List[float]:
579
+ """Get embedding for a single text"""
580
+ embeddings = await self.get_embeddings([text])
581
+ return embeddings[0] if embeddings else []
582
+
583
+ def get_embedding_dimension(self) -> int:
584
+ """Get embedding dimension"""
585
+ if self.current_embedding_provider:
586
+ return self.current_embedding_provider.get_dimension()
587
+ return self.embedding_dimension
588
+
589
+ def get_embedding_provider_name(self) -> str:
590
+ """Get current provider name"""
591
+ if self.current_embedding_provider:
592
+ return self.current_embedding_provider.get_name()
593
+ return "None"
594
+
595
+ def ensure_embedding_model_loaded(self):
596
+ """Ensure the embedding model is loaded before starting batch operations.
597
+
598
+ This prevents the model download progress from interfering with
599
+ the embedding computation progress bar.
600
+ """
601
+ if self.current_embedding_provider:
602
+ # Trigger model loading by calling get_dimension which internally calls _load_model
603
+ try:
604
+ self.current_embedding_provider.get_dimension()
605
+ except Exception as e:
606
+ logger.warning(f"Failed to pre-load embedding model: {e}")
607
+
608
+ def _chunk_text(self, text: str) -> List[str]:
609
+ """Split text into overlapping chunks"""
610
+ if len(text) <= self.chunk_size:
611
+ return [text]
612
+
613
+ chunks = []
614
+ start = 0
615
+
616
+ while start < len(text):
617
+ end = start + self.chunk_size
618
+ chunk = text[start:end]
619
+
620
+ if end < len(text):
621
+ last_period = chunk.rfind('.')
622
+ if last_period > self.chunk_size * 0.7:
623
+ chunk = chunk[:last_period + 1]
624
+ end = start + last_period + 1
625
+
626
+ chunks.append(chunk.strip())
627
+ start = end - self.chunk_overlap
628
+
629
+ if start >= len(text):
630
+ break
631
+
632
+ return chunks
633
+
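A rough sketch of how the sliding window above advances (illustration only; assumes chunk_size=1000 and chunk_overlap=200, whereas the real values come from config_service.get_rag_config(), and the method additionally snaps a chunk's end back to the last period when one falls in the final 30% of the window):

text = "x" * 2500
chunk_size, chunk_overlap = 1000, 200
starts, start = [], 0
while start < len(text):
    starts.append(start)
    start += chunk_size - chunk_overlap  # each step keeps a 200-character overlap with the previous chunk
print(starts)  # [0, 800, 1600, 2400] -> four overlapping chunks for 2500 characters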
634
+ async def add_document_to_index(self, doc_id: str, text: str, metadata: Dict[str, Any] = None) -> bool:
635
+ """Add document to in-memory vector index for current session
636
+
637
+ First checks if embeddings are cached in the database (24-hour TTL).
638
+ If cached, loads them directly. Otherwise, computes and caches them.
639
+ """
640
+ try:
641
+ # Check if embeddings are already cached in the database
642
+ cached_embeddings = await self._get_cached_embeddings(doc_id)
643
+
644
+ if cached_embeddings:
645
+ # Load from cache
646
+ await self._load_embeddings_from_cache(cached_embeddings)
647
+ logger.info(f"Loaded {len(cached_embeddings)} cached embeddings for document {doc_id}")
648
+ return True
649
+
650
+ # No cache - compute embeddings
651
+ chunks = self._chunk_text(text)
652
+ embeddings = await self.get_embeddings(chunks)
653
+
654
+ for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
655
+ chunk_id = f"{doc_id}_chunk_{i}"
656
+ # Store in memory
657
+ self._session_embeddings[chunk_id] = {
658
+ 'doc_id': doc_id,
659
+ 'chunk_id': chunk_id,
660
+ 'text': chunk,
661
+ 'embedding': embedding,
662
+ 'metadata': metadata or {}
663
+ }
664
+
665
+ # Save to database cache for future use (24-hour TTL)
666
+ await self._save_embeddings_to_cache(doc_id, chunks, embeddings, metadata)
667
+
668
+ logger.info(f"Added {len(chunks)} chunks for document {doc_id} to in-memory index and cache")
669
+ return True
670
+
671
+ except Exception as e:
672
+ logger.error(f"Failed to add document {doc_id} to index: {str(e)}")
673
+ return False
674
+
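An end-to-end sketch of the index-then-retrieve flow these methods implement (illustration only; the BasicRAG instance is assumed to be built elsewhere with the package's database, config and LLM services, and the arXiv id is made up):

async def _demo_index_and_ask(rag: "BasicRAG"):
    await rag.add_document_to_index(
        "2401.00001",  # hypothetical paper id
        "Transformers use attention. Scores are softmax-normalized over keys. ...",
        {"type": "paper", "title": "A Hypothetical Paper"},
    )
    hits = await rag.search_similar_documents("How are attention scores normalized?")
    for hit in hits:  # at most top_k_results chunks, best first
        print(f"{hit['score']:.3f}  {hit['text'][:60]}")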
675
+ async def add_document_to_index_with_progress(self, doc_id: str, text: str, metadata: Dict[str, Any] = None, console: Console = None) -> bool:
676
+ """Add document to in-memory vector index with progress bar
677
+
678
+ First checks if embeddings are cached in the database (24-hour TTL).
679
+ If cached, loads them directly. Otherwise, computes and caches them.
680
+ """
681
+ try:
682
+ colors = get_theme_colors()
683
+ console = console or self.console
684
+
685
+ # Check if embeddings are already cached in the database
686
+ cached_embeddings = await self._get_cached_embeddings(doc_id)
687
+
688
+ if cached_embeddings:
689
+ # Load from cache - much faster!
690
+ left_to_right_reveal(console, f"Loading cached embeddings ({len(cached_embeddings)} chunks)...", style=f"bold {colors['primary']}", duration=0.8)
691
+ await self._load_embeddings_from_cache(cached_embeddings)
692
+
693
+ # Note: We intentionally do NOT pre-load the embedding model here.
694
+ # Query embeddings will use Gemini API if available (fast, no download needed).
695
+ # The local Granite model will only be loaded lazily if Gemini fails.
696
+
697
+ logger.info(f"Loaded {len(cached_embeddings)} cached embeddings for document {doc_id}")
698
+ return True
699
+
700
+ # No cache - need to compute embeddings
701
+ # First, chunk the text
702
+ chunks = self._chunk_text(text)
703
+ total_chunks = len(chunks)
704
+
705
+ if total_chunks == 0:
706
+ return False
707
+
708
+ # Show subtle hint for large papers
709
+ if total_chunks > 20:
710
+ console.print(f"[white]Processing [bold {colors['primary']}]{total_chunks} chunks [/bold {colors['primary']}](this may take a moment for large papers)...[/white]")
711
+
712
+ # Ensure embedding model is loaded BEFORE showing the computation progress bar
713
+ # This prevents model download progress from interfering with embedding progress
714
+ self.ensure_embedding_model_loaded()
715
+
716
+ # Create progress bar for embedding computation
717
+ with Progress(
718
+ SpinnerColumn(),
719
+ TextColumn("[progress.description]{task.description}"),
720
+ BarColumn(bar_width=50),
721
+ TaskProgressColumn(),
722
+ TextColumn("-"),
723
+ TimeRemainingColumn(),
724
+ console=console,
725
+ transient=False
726
+ ) as progress:
727
+ task = progress.add_task(
728
+ f"[bold {colors['primary']}]Computing embeddings...",
729
+ total=total_chunks
730
+ )
731
+
732
+ # Process chunks in batches for the API
733
+ batch_size = 5
734
+ all_embeddings = []
735
+
736
+ for i in range(0, total_chunks, batch_size):
737
+ batch = chunks[i:i + batch_size]
738
+ batch_embeddings = await self.get_embeddings(batch)
739
+ all_embeddings.extend(batch_embeddings)
740
+
741
+ # Update progress for each chunk in the batch
742
+ progress.update(task, advance=len(batch))
743
+
744
+ # Store embeddings in memory
745
+ for i, (chunk, embedding) in enumerate(zip(chunks, all_embeddings)):
746
+ chunk_id = f"{doc_id}_chunk_{i}"
747
+ self._session_embeddings[chunk_id] = {
748
+ 'doc_id': doc_id,
749
+ 'chunk_id': chunk_id,
750
+ 'text': chunk,
751
+ 'embedding': embedding,
752
+ 'metadata': metadata or {}
753
+ }
754
+
755
+ # Save to database cache for future use (24-hour TTL)
756
+ await self._save_embeddings_to_cache(doc_id, chunks, all_embeddings, metadata)
757
+
758
+ logger.info(f"Added {total_chunks} chunks for document {doc_id} to in-memory index and cache")
759
+ return True
760
+
761
+ except Exception as e:
762
+ logger.error(f"Failed to add document {doc_id} to index: {str(e)}")
763
+ return False
764
+
765
+ def clear_session_embeddings(self):
766
+ """Clear in-memory embeddings when chat session ends"""
767
+ count = len(self._session_embeddings)
768
+ self._session_embeddings.clear()
769
+ self._current_session_id = None
770
+ logger.info(f"Cleared {count} embeddings from memory")
771
+
772
+ async def _get_cached_embeddings(self, doc_id: str) -> Optional[List[Dict[str, Any]]]:
773
+ """Check if embeddings for a document are cached (tries API first, then local DB)"""
774
+ try:
775
+ # First, try to get from API (cloud cache - accessible across devices)
776
+ try:
777
+ from ..cli.utils.api_client import api_client
778
+ api_result = await api_client.get_embeddings(doc_id)
779
+ if api_result.get("success"):
780
+ embeddings = api_result.get("embeddings", [])
781
+ chunks = api_result.get("chunks", [])
782
+ batches = api_result.get("batches", 1)
783
+
784
+ if embeddings and chunks:
785
+ logger.info(f"Found {len(embeddings)} cached embeddings from cloud ({batches} batches) for {doc_id}")
786
+
787
+ # Convert to the format expected by _load_embeddings_from_cache
788
+ cached = []
789
+ for i, (embedding, chunk) in enumerate(zip(embeddings, chunks)):
790
+ cached.append({
791
+ 'chunk_id': f"{doc_id}_chunk_{i}",
792
+ 'doc_id': doc_id,
793
+ 'chunk_text': chunk,
794
+ 'embedding': embedding,
795
+ 'expires_at': datetime.utcnow() + timedelta(hours=24)
796
+ })
797
+ return cached
798
+ except Exception as api_err:
799
+ logger.debug(f"Cloud cache not available, trying local: {api_err}")
800
+
801
+ # Fall back to local database cache
802
+ cached = await self.db_service.find_many(
803
+ self.vector_collection,
804
+ {
805
+ 'doc_id': doc_id,
806
+ 'expires_at': {'$gt': datetime.utcnow()}
807
+ },
808
+ limit=10000 # High limit to get all chunks for large papers
809
+ )
810
+
811
+ if cached and len(cached) > 0:
812
+ logger.info(f"Found {len(cached)} cached embeddings from local DB for {doc_id}")
813
+ return cached
814
+ return None
815
+
816
+ except Exception as e:
817
+ logger.warning(f"Failed to check cached embeddings: {str(e)}")
818
+ return None
819
+
820
+ async def _save_embeddings_to_cache(self, doc_id: str, chunks: List[str], embeddings: List[List[float]], metadata: Dict[str, Any] = None):
821
+ """Save embeddings to API and local database with 24-hour TTL"""
822
+ try:
823
+ # First, try to save to API (cloud storage - accessible across devices)
824
+ api_saved = False
825
+ try:
826
+ from ..cli.utils.api_client import api_client
827
+ api_result = await api_client.save_embeddings(doc_id, embeddings, chunks)
828
+ if api_result.get("success"):
829
+ batches = api_result.get("message", "")
830
+ logger.info(f"✓ Saved {len(embeddings)} embeddings to cloud cache for {doc_id}: {batches}")
831
+ api_saved = True
832
+ else:
833
+ error_msg = api_result.get("message", "Unknown error")
834
+ logger.warning(f"Cloud cache save failed for {doc_id}: {error_msg}")
835
+ except Exception as api_err:
836
+ # Silently fall back to local cache - this is expected when offline or API unavailable
837
+ logger.debug(f"Using local cache only: {api_err}")
838
+
839
+ # Always save to local DB as backup
840
+ expires_at = datetime.utcnow() + timedelta(hours=24)
841
+
842
+ # Delete any existing embeddings for this document first
843
+ await self.db_service.delete_many(
844
+ self.vector_collection,
845
+ {'doc_id': doc_id}
846
+ )
847
+
848
+ # Save new embeddings
849
+ documents = []
850
+ for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
851
+ chunk_id = f"{doc_id}_chunk_{i}"
852
+ documents.append({
853
+ 'doc_id': doc_id,
854
+ 'chunk_id': chunk_id,
855
+ 'text': chunk,
856
+ 'embedding': embedding,
857
+ 'metadata': metadata or {},
858
+ 'created_at': datetime.utcnow(),
859
+ 'expires_at': expires_at
860
+ })
861
+
862
+ if documents:
863
+ await self.db_service.insert_many(self.vector_collection, documents)
864
+ logger.info(f"Saved {len(documents)} embeddings to local cache for document {doc_id} (expires in 24h)")
865
+
866
+ except Exception as e:
867
+ logger.warning(f"Failed to save embeddings to local cache: {str(e)}")
868
+
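The cache lookup above filters on expires_at explicitly, so correctness does not depend on MongoDB purging anything; if expired chunks should also be deleted automatically, the usual companion is a TTL index on that field (sketch only; the package may already create one elsewhere, which would explain the IndexModel import at the top of the file):

from pymongo import ASCENDING, IndexModel

ttl_index = IndexModel([("expires_at", ASCENDING)], expireAfterSeconds=0)
# e.g. db[rag.vector_collection].create_indexes([ttl_index])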
869
+ async def _load_embeddings_from_cache(self, cached_embeddings: List[Dict[str, Any]], cached_chunks: List[str] = None):
870
+ """Load cached embeddings into session memory
871
+
872
+ Args:
873
+ cached_embeddings: Either a list of raw embedding vectors (from API), or
874
+ a list of dict objects with 'embedding', 'text', etc. (from local DB)
875
+ cached_chunks: Optional list of text chunks (only provided when embeddings are raw vectors from API)
876
+ """
877
+ # Handle API format (parallel lists of embeddings and chunks)
878
+ if cached_chunks and cached_embeddings and isinstance(cached_embeddings[0], list):
879
+ # API format: embeddings is a list of vectors, chunks is a list of strings
880
+ for i, (embedding, chunk) in enumerate(zip(cached_embeddings, cached_chunks)):
881
+ chunk_id = f"cached_chunk_{i}"
882
+ self._session_embeddings[chunk_id] = {
883
+ 'doc_id': 'cached',
884
+ 'chunk_id': chunk_id,
885
+ 'text': chunk,
886
+ 'embedding': embedding,
887
+ 'metadata': {}
888
+ }
889
+ logger.info(f"Loaded {len(cached_embeddings)} embeddings from API cache to session memory")
890
+ else:
891
+ # Local DB format: list of dicts with 'embedding', 'text', etc.
892
+ for doc in cached_embeddings:
893
+ chunk_id = doc.get('chunk_id')
894
+ self._session_embeddings[chunk_id] = {
895
+ 'doc_id': doc.get('doc_id'),
896
+ 'chunk_id': chunk_id,
897
+ 'text': doc.get('text') or doc.get('chunk_text'),  # cloud cache entries store the chunk under 'chunk_text'
898
+ 'embedding': doc.get('embedding'),
899
+ 'metadata': doc.get('metadata', {})
900
+ }
901
+ logger.info(f"Loaded {len(cached_embeddings)} embeddings from local cache to session memory")
902
+
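The two shapes this loader accepts, side by side (illustration only; the vectors and texts are made up):

vectors = [[0.12, 0.05, 0.33], [0.08, 0.41, 0.27]]
chunks = ["first chunk text", "second chunk text"]
# Cloud/API shape: parallel lists, chunk texts passed separately
#   await rag._load_embeddings_from_cache(vectors, chunks)
local_docs = [{"chunk_id": "2401.00001_chunk_0", "doc_id": "2401.00001",
               "text": "first chunk text", "embedding": vectors[0], "metadata": {}}]
# Local-DB shape: one dict per chunk, no second argument
#   await rag._load_embeddings_from_cache(local_docs)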
903
+ async def search_similar_documents(self, query: str, filters: Dict[str, Any] = None) -> List[Dict[str, Any]]:
904
+ """Search for similar documents using cosine similarity (in-memory)"""
905
+ try:
906
+ query_embedding = await self.get_single_embedding(query)
907
+
908
+ # Search in-memory embeddings
909
+ scored_docs = []
910
+ for chunk_id, doc in self._session_embeddings.items():
911
+ # Apply metadata filters if provided
912
+ if filters:
913
+ match = True
914
+ for key, value in filters.items():
915
+ # Handle nested keys like 'metadata.type'
916
+ keys = key.split('.')
917
+ doc_value = doc
918
+ for k in keys:
919
+ doc_value = doc_value.get(k, {}) if isinstance(doc_value, dict) else None
920
+ if doc_value != value:
921
+ match = False
922
+ break
923
+ if not match:
924
+ continue
925
+
926
+ doc_embedding = doc.get('embedding', [])
927
+ if doc_embedding:
928
+ score = await self.compute_similarity(query_embedding, doc_embedding)
929
+ scored_docs.append({
930
+ 'doc_id': doc.get('doc_id'),
931
+ 'chunk_id': doc.get('chunk_id'),
932
+ 'text': doc.get('text'),
933
+ 'metadata': doc.get('metadata', {}),
934
+ 'score': score
935
+ })
936
+
937
+ # Sort by score descending and take top k
938
+ scored_docs.sort(key=lambda x: x['score'], reverse=True)
939
+ return scored_docs[:self.top_k_results]
940
+
941
+ except Exception as e:
942
+ logger.error(f"Vector search failed: {str(e)}")
943
+ return []
944
+
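The filters argument walks dotted paths into each stored chunk record (sketch; 'metadata.type' is the same key the chat session passes to restrict retrieval to paper chunks):

async def _demo_filtered_search(rag: "BasicRAG"):
    return await rag.search_similar_documents(
        "which dataset was used?",
        filters={"metadata.type": "paper"},  # matches chunk["metadata"]["type"] == "paper"
    )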
945
+ async def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
946
+ """Compute cosine similarity between two embeddings"""
947
+ try:
948
+ if not ML_DEPENDENCIES_AVAILABLE:
949
+ return 0.0
950
+
951
+ vec1 = np.array(embedding1)
952
+ vec2 = np.array(embedding2)
953
+
954
+ dot_product = np.dot(vec1, vec2)
955
+ norm1 = np.linalg.norm(vec1)
956
+ norm2 = np.linalg.norm(vec2)
957
+
958
+ if norm1 == 0 or norm2 == 0:
959
+ return 0.0
960
+
961
+ similarity = dot_product / (norm1 * norm2)
962
+ return float(similarity)
963
+
964
+ except Exception as e:
965
+ logger.error(f"Similarity computation failed: {str(e)}")
966
+ return 0.0
967
+
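A worked instance of the formula above (illustration only; assumes numpy is installed, mirroring the ML_DEPENDENCIES_AVAILABLE path):

import numpy as np

a = np.array([1.0, 0.0, 1.0])
b = np.array([1.0, 1.0, 0.0])
cos = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
print(round(float(cos), 3))  # 0.5 -- parallel vectors give 1.0, orthogonal vectors give 0.0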
968
+ async def start_chat_session(self, papers: List[Dict[str, Any]], user_id: str = "default"):
969
+ """Start interactive chat session with a single paper (v1)
970
+
971
+ Embeddings are stored in memory during the session and cleared when done.
972
+ Chat history is persisted to DB with 24-hour TTL for resumption.
973
+ """
974
+ try:
975
+ if not papers:
976
+ colors = get_theme_colors()
977
+ self.console.print(f"[{colors['error']}]No papers provided for chat session[/{colors['error']}]")
978
+ return
979
+
980
+ # V1: Limit to single paper
981
+ paper = papers[0]
982
+ paper_id = paper.get('arxiv_id') or paper.get('id')
983
+
984
+ if not paper_id:
985
+ colors = get_theme_colors()
986
+ self.console.print(f"[{colors['error']}]Paper has no ID[/{colors['error']}]")
987
+ return
988
+
989
+ colors = get_theme_colors()
990
+
991
+ # Check if cached embeddings are available or if we can generate new ones
992
+ cached_embeddings = paper.get('_cached_embeddings')
993
+ cached_chunks = paper.get('_cached_chunks')
994
+
995
+ # If no cached embeddings, check if embedding providers are available
996
+ if not cached_embeddings and not self.is_embedding_available():
997
+ # Show graceful error message
998
+ self.console.print(Panel(
999
+ f"[{colors['warning']}]{self.get_embedding_unavailable_message()}[/{colors['warning']}]",
1000
+ title=f"[bold {colors['warning']}]Feature Unavailable[/bold {colors['warning']}]",
1001
+ border_style=f"bold {colors['warning']}"
1002
+ ))
1003
+ return
1004
+
1005
+ # Clear any previous session embeddings
1006
+ self.clear_session_embeddings()
1007
+
1008
+ # Check if cached embeddings were passed (already fetched from API/DB)
1009
+ if cached_embeddings:
1010
+ # Load cached embeddings directly into session memory
1011
+ await self._load_embeddings_from_cache(cached_embeddings, cached_chunks)
1012
+ logger.info(f"Loaded {len(cached_embeddings)} pre-cached embeddings for paper {paper_id}")
1013
+ else:
1014
+ # Generate embeddings and store in memory - with progress bar
1015
+ paper_text = self._extract_paper_text(paper)
1016
+ if paper_text:
1017
+ success = await self.add_document_to_index_with_progress(
1018
+ paper_id,
1019
+ paper_text,
1020
+ {'type': 'paper', 'title': paper.get('title', '')},
1021
+ console=self.console
1022
+ )
1023
+ if not success:
1024
+ # Embedding failed - show graceful message
1025
+ self.console.print(Panel(
1026
+ f"[{colors['warning']}]{self.get_embedding_unavailable_message()}[/{colors['warning']}]",
1027
+ title=f"[bold {colors['warning']}]Feature Unavailable[/bold {colors['warning']}]",
1028
+ border_style=f"bold {colors['warning']}"
1029
+ ))
1030
+ return
1031
+
1032
+ # Create unique session ID
1033
+ import uuid
1034
+ session_id = f"{user_id}_{paper_id}_{uuid.uuid4().hex[:8]}"
1035
+ self._current_session_id = session_id
1036
+
1037
+ # Create session document with 24-hour TTL
1038
+ # Format authors list for display
1039
+ authors = paper.get('authors', [])
1040
+ if isinstance(authors, list):
1041
+ paper_authors = ', '.join(authors[:5]) # Limit to first 5 authors
1042
+ if len(authors) > 5:
1043
+ paper_authors += f' et al. ({len(authors)} authors)'
1044
+ else:
1045
+ paper_authors = str(authors) if authors else 'Unknown'
1046
+
1047
+ session_doc = {
1048
+ 'session_id': session_id,
1049
+ 'paper_id': paper_id, # Single paper in v1
1050
+ 'paper_title': paper.get('title', ''),
1051
+ 'paper_authors': paper_authors,
1052
+ 'paper_published': paper.get('published', '')[:10] if paper.get('published') else 'Unknown',
1053
+ 'user_id': user_id,
1054
+ 'created_at': datetime.utcnow(),
1055
+ 'last_activity': datetime.utcnow(),
1056
+ 'expires_at': datetime.utcnow() + timedelta(hours=24), # 24-hour TTL
1057
+ 'messages': []
1058
+ }
1059
+
1060
+ # Store in-memory as fallback (always works)
1061
+ self._in_memory_sessions[session_id] = session_doc
1062
+
1063
+ # Try to persist to Vercel API first (cloud storage)
1064
+ session_saved = False
1065
+ api_error = None
1066
+ try:
1067
+ from ..cli.utils.api_client import api_client
1068
+ api_result = await api_client.create_chat_session(
1069
+ paper_id=paper_id,
1070
+ title=paper.get('title', paper_id)
1071
+ )
1072
+ if api_result.get("success"):
1073
+ # Update in-memory session with API session_id for consistency
1074
+ api_session_id = api_result.get('session_id')
1075
+ # Store the API session ID for later updates
1076
+ self._in_memory_sessions[session_id]['api_session_id'] = api_session_id
1077
+ logger.info(f"Chat session saved to cloud: {api_session_id}")
1078
+ session_saved = True
1079
+ else:
1080
+ api_error = f"API failure: {api_result}"
1081
+ logger.warning(api_error)
1082
+ except Exception as api_err:
1083
+ api_error = f"API error: {api_err}"
1084
+ logger.warning(f"Session not saved to API: {api_err}")
1085
+
1086
+ # Also try local database as backup (regardless of API success)
1087
+ try:
1088
+ await self.db_service.insert_one(self.chat_collection, session_doc)
1089
+ logger.info(f"Chat session saved to local DB: {session_id}")
1090
+ session_saved = True
1091
+ except Exception as db_err:
1092
+ logger.debug(f"Session not saved to local database: {db_err}")
1093
+
1094
+ if not session_saved:
1095
+ logger.warning(f"Chat session only stored in-memory: {session_id}")
1096
+ if api_error:
1097
+ logger.warning(f"API save failed: {api_error}")
1098
+
1099
+ self.console.print(Panel(
1100
+ f"[bold {colors['primary']}]Chat Session Started[/bold {colors['primary']}]\n"
1101
+ f"Paper: [bold {colors['primary']}] {paper.get('title', paper_id)}[/bold {colors['primary']}]\n"
1102
+ f"Chunks indexed: [bold {colors['primary']}] {len(self._session_embeddings)}[/bold {colors['primary']}]\n"
1103
+ f"Type [bold {colors['primary']}]'quit'[/bold {colors['primary']}] or [bold {colors['primary']}]'exit'[/bold {colors['primary']}] to end the chat.",
1104
+ title=f"[bold]ArionXiv Paper Chat[/bold]",
1105
+ border_style=f"bold {colors['primary']}"
1106
+ ))
1107
+
1108
+ try:
1109
+ await self._run_chat_loop(session_id)
1110
+ finally:
1111
+ # Always clean up embeddings when session ends
1112
+ self.clear_session_embeddings()
1113
+
1114
+ except Exception as e:
1115
+ logger.error(f"Chat session failed: {str(e)}")
1116
+ colors = get_theme_colors()
1117
+ self.console.print(f"[{colors['error']}]Chat session failed: {str(e)}[/{colors['error']}]")
1118
+ # Clean up on error too
1119
+ self.clear_session_embeddings()
1120
+
1121
+ async def continue_chat_session(self, session: Dict[str, Any], paper_info: Dict[str, Any]):
1122
+ """Continue an existing chat session
1123
+
1124
+ Reloads the paper embeddings and continues the conversation.
1125
+ Extends the session TTL by 24 hours.
1126
+ """
1127
+ try:
1128
+ colors = get_theme_colors()
1129
+ session_id = session.get('session_id')
1130
+ paper_title = session.get('paper_title', paper_info.get('title', 'Unknown Paper'))
1131
+ messages = session.get('messages', [])
1132
+
1133
+ if not session_id:
1134
+ self.console.print(f"[{colors['error']}]Invalid session: no session_id[/{colors['error']}]")
1135
+ return
1136
+
1137
+ # Check if cached embeddings are available or if we can generate new ones
1138
+ cached_embeddings = paper_info.get('_cached_embeddings')
1139
+ cached_chunks = paper_info.get('_cached_chunks')
1140
+
1141
+ # If no cached embeddings, check if embedding providers are available
1142
+ if not cached_embeddings and not self.is_embedding_available():
1143
+ # Show graceful error message
1144
+ self.console.print(Panel(
1145
+ f"[{colors['warning']}]{self.get_embedding_unavailable_message()}[/{colors['warning']}]",
1146
+ title=f"[bold {colors['warning']}]Feature Unavailable[/bold {colors['warning']}]",
1147
+ border_style=f"bold {colors['warning']}"
1148
+ ))
1149
+ return
1150
+
1151
+ # Extract and format paper metadata for context
1152
+ # Format authors list for display
1153
+ authors = paper_info.get('authors', session.get('paper_authors', []))
1154
+ if isinstance(authors, list):
1155
+ paper_authors = ', '.join(authors[:5]) # Limit to first 5 authors
1156
+ if len(authors) > 5:
1157
+ paper_authors += f' et al. ({len(authors)} authors)'
1158
+ else:
1159
+ paper_authors = str(authors) if authors else 'Unknown'
1160
+
1161
+ # Get published date
1162
+ published = paper_info.get('published', session.get('paper_published', ''))
1163
+ paper_published = published[:10] if published else 'Unknown'
1164
+
1165
+ # Update session with paper metadata (for use in _chat_with_session)
1166
+ session['paper_title'] = paper_title
1167
+ session['paper_authors'] = paper_authors
1168
+ session['paper_published'] = paper_published
1169
+
1170
+ # Clear any previous session embeddings
1171
+ self.clear_session_embeddings()
1172
+
1173
+ # Use cached embeddings if available, otherwise generate new ones
1174
+ if cached_embeddings:
1175
+ # Use pre-loaded cached embeddings directly
1176
+ await self._load_embeddings_from_cache(cached_embeddings, cached_chunks)
1177
+ logger.info(f"Loaded {len(cached_embeddings)} cached embeddings for session")
1178
+ else:
1179
+ # Re-index the paper content
1180
+ paper_text = self._extract_paper_text(paper_info)
1181
+ if paper_text:
1182
+ paper_id = paper_info.get('arxiv_id') or paper_info.get('id')
1183
+ success = await self.add_document_to_index_with_progress(
1184
+ paper_id,
1185
+ paper_text,
1186
+ {'type': 'paper', 'title': paper_title},
1187
+ console=self.console
1188
+ )
1189
+ if not success:
1190
+ # Embedding failed - show graceful message
1191
+ self.console.print(Panel(
1192
+ f"[{colors['warning']}]{self.get_embedding_unavailable_message()}[/{colors['warning']}]",
1193
+ title=f"[bold {colors['warning']}]Feature Unavailable[/bold {colors['warning']}]",
1194
+ border_style=f"bold {colors['warning']}"
1195
+ ))
1196
+ return
1197
+
1198
+ self._current_session_id = session_id
1199
+ # Store session in memory so _chat_with_session can find it
1200
+ self._in_memory_sessions[session_id] = session
1201
+
1202
+ # Extend the session TTL by 24 hours
1203
+ await self.db_service.extend_chat_session_ttl(session_id, hours=24)
1204
+
1205
+ # Show session info with previous message count
1206
+ self.console.print(Panel(
1207
+ f"[bold {colors['primary']}]Continuing Chat Session[/bold {colors['primary']}]\n"
1208
+ f"Paper: [bold {colors['primary']}]{paper_title}[/bold {colors['primary']}]\n"
1209
+ f"Previous messages: [bold {colors['primary']}]{len(messages)}[/bold {colors['primary']}]\n"
1210
+ f"Chunks indexed: [bold {colors['primary']}]{len(self._session_embeddings)}[/bold {colors['primary']}]\n"
1211
+ f"Session extended by 24 hours.\n"
1212
+ f"Type [bold {colors['primary']}]'quit'[/bold {colors['primary']}] or [bold {colors['primary']}]'exit'[/bold {colors['primary']}] to end the chat.",
1213
+ title=f"[bold]ArionXiv Paper Chat - Resumed[/bold]",
1214
+ border_style=f"bold {colors['primary']}"
1215
+ ))
1216
+
1217
+ # Show a summary of recent conversation if there are messages
1218
+ if messages:
1219
+ # Show last 8 Q&A pairs (16 messages total)
1220
+ num_pairs = min(8, len(messages) // 2)
1221
+ if num_pairs > 0:
1222
+ recent = messages[-(num_pairs * 2):]
1223
+ else:
1224
+ recent = messages # Show whatever we have
1225
+
1226
+ left_to_right_reveal(self.console, f"\nRecent conversation ({num_pairs} Q&A):", style=f"bold {colors['primary']}", duration=0.8)
1227
+ for msg in recent:
1228
+ role = "You" if msg.get('type') == 'user' else "Assistant"
1229
+ content = msg.get('content', '')
1230
+ # Truncate long messages for display
1231
+ display_content = content[:150] + "..." if len(content) > 150 else content
1232
+ self.console.print(f"[dim {colors['primary']}]{role}: {display_content}[/dim {colors['primary']}]")
1233
+
1234
+ try:
1235
+ await self._run_chat_loop(session_id)
1236
+ finally:
1237
+ self.clear_session_embeddings()
1238
+
1239
+ except Exception as e:
1240
+ logger.error(f"Continue chat session failed: {str(e)}")
1241
+ colors = get_theme_colors()
1242
+ self.console.print(f"[{colors['error']}]Failed to continue session: {str(e)}[/{colors['error']}]")
1243
+ self.clear_session_embeddings()
1244
+
1245
+ async def _run_chat_loop(self, session_id: str):
1246
+ """Run the chat interaction loop"""
1247
+ colors = get_theme_colors()
1248
+ while True:
1249
+ message = Prompt.ask(f"\n[bold {colors['primary']}]You[/bold {colors['primary']}]")
1250
+
1251
+ if message.lower() in ['quit', 'exit', 'q']:
1252
+ left_to_right_reveal(self.console, "\nEnding chat session. Goodbye!", style=f"bold {colors['primary']}", duration=1.0)
1253
+ break
1254
+
1255
+ with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}"), console=self.console) as progress:
1256
+ task = progress.add_task(f"[bold {colors['primary']}]Thinking...", total=None)
1257
+ result = await self._chat_with_session(session_id, message)
1258
+
1259
+ if result['success']:
1260
+ # Stream the response into a themed panel
1261
+ stream_markdown_response(
1262
+ self.console,
1263
+ result['response'],
1264
+ panel_title=f"[bold {colors['primary']}]ArionXiv Assistant[/bold {colors['primary']}]",
1265
+ border_style=colors['primary'],
1266
+ duration=1.0
1267
+ )
1268
+
1269
+ # Build info line with chunks and model name
1270
+ info_parts = []
1271
+ if result['relevant_chunks'] > 0:
1272
+ info_parts.append(f"Used {result['relevant_chunks']} relevant content chunks")
1273
+ if result.get('model_display'):
1274
+ info_parts.append(f"• Model: {result['model_display']}")
1275
+
1276
+ if info_parts:
1277
+ info_text = " ".join(info_parts)
1278
+ left_to_right_reveal(self.console, info_text, style=f"dim {colors['muted']}", duration=1.0)
1279
+ else:
1280
+ left_to_right_reveal(self.console, f"Error: {result['error']}", style=f"bold {colors['error']}", duration=1.0)
1281
+
1282
+ def _show_post_chat_commands(self):
1283
+ """Show helpful commands after chat session ends"""
1284
+ colors = get_theme_colors()
1285
+
1286
+ commands = [
1287
+ ("arionxiv chat", "Start a new chat session"),
1288
+ ("arionxiv search <query>", "Search for more papers"),
1289
+ ("arionxiv settings papers", "Manage your saved papers"),
1290
+ ("arionxiv trending", "See trending papers"),
1291
+ ("arionxiv daily", "Get your daily paper digest"),
1292
+ ]
1293
+
1294
+ self.console.print()
1295
+ self.console.print(Panel(
1296
+ "\n".join([
1297
+ f"[bold {colors['primary']}]{cmd}[/bold {colors['primary']}] [white]→ {desc}[/white]"
1298
+ for cmd, desc in commands
1299
+ ]),
1300
+ title=f"[bold {colors['primary']}]What's Next?[/bold {colors['primary']}]",
1301
+ border_style=f"bold {colors['primary']}",
1302
+ padding=(1, 2)
1303
+ ))
1304
+
1305
+ async def _chat_with_session(self, session_id: str, message: str) -> Dict[str, Any]:
1306
+ """Process a chat message and generate response"""
1307
+ try:
1308
+ # Try database first, fall back to in-memory
1309
+ session = None
1310
+ try:
1311
+ session = await self.db_service.find_one(self.chat_collection, {'session_id': session_id})
1312
+ except Exception:
1313
+ pass
1314
+
1315
+ # Fall back to in-memory session
1316
+ if not session:
1317
+ session = self._in_memory_sessions.get(session_id)
1318
+
1319
+ if not session:
1320
+ return {'success': False, 'error': 'Session not found'}
1321
+
+             relevant_chunks = await self.search_similar_documents(message, {'metadata.type': 'paper'})
+             context = "\n\n".join([chunk['text'] for chunk in relevant_chunks[:10]])  # Increased from 5 to 10 chunks for richer context
+
+             # Get conversation history for context
+             chat_history = session.get('messages', [])
+
+             # Get paper metadata for context
+             paper_title = session.get('paper_title', session.get('title', 'Unknown Paper'))
+             paper_authors = session.get('paper_authors', 'Unknown')
+             paper_published = session.get('paper_published', 'Unknown')
+
+             # Determine which LLM to use and generate response
+             model_display = ""
+             success = False
+             response_text = ""
+             error_msg = ""
+
+             # Try OpenRouter for chat, fallback to hosted API
+             if self.openrouter_client and self.openrouter_client.is_available:
+                 try:
+                     result = await self.openrouter_client.chat(
+                         message=message,
+                         context=context,
+                         history=chat_history,
+                         paper_title=paper_title,
+                         paper_authors=paper_authors,
+                         paper_published=paper_published
+                     )
+                     if result.get('success'):
+                         response_text, model_display, success = result['response'], result.get('model_display', 'OpenRouter'), True
+                     else:
+                         error_msg = result.get('error', 'OpenRouter failed')
+                 except Exception as e:
+                     logger.debug(f"OpenRouter error: {e}")
+
+             # Hosted API Fallback (using developer keys on backend)
+             if not success:
+                 try:
+                     from ..cli.utils.api_client import api_client
+                     paper_id = session.get('arxiv_id') or session.get('paper_id')
+                     paper_title = session.get('title') or session.get('paper_title')
+                     # Pass RAG context to API for paper-aware responses
+                     result = await api_client.send_chat_message(
+                         message=message,
+                         paper_id=paper_id,
+                         session_id=session_id,
+                         context=context,  # Send RAG context
+                         paper_title=paper_title  # Send paper title
+                     )
+                     if result.get('success'):
+                         response_text = result['response']
+                         model_display = result.get('model', 'ArionXiv Cloud')
+                         success = True
+                     else:
+                         error_msg = result.get('error', 'Hosted API failed')
+                 except Exception as e:
+                     # Extract meaningful error message from APIClientError
+                     if hasattr(e, 'message') and e.message:
+                         # Clean up the error message for user display
+                         msg = e.message
+                         if "serverless timeout" in msg.lower():
+                             error_msg = "Chat service timeout. For reliable chat, run 'arionxiv settings api' to set your own OPENROUTER_API_KEY."
+                         elif "503" in str(getattr(e, 'status_code', '')) or "unavailable" in msg.lower():
+                             error_msg = "Chat service temporarily unavailable. Set your OPENROUTER_API_KEY via 'arionxiv settings api' for uninterrupted chat."
+                         else:
+                             error_msg = f"Chat unavailable: {msg}"
+                     elif hasattr(e, 'status_code') and e.status_code:
+                         if e.status_code == 503:
+                             error_msg = "Chat service temporarily unavailable. For reliable chat, set your OPENROUTER_API_KEY via 'arionxiv settings api'."
+                         else:
+                             error_msg = f"Chat unavailable: API error {e.status_code}"
+                     else:
+                         error_msg = f"Chat unavailable: {str(e) or 'Unknown error'}"
+                     logger.debug(f"Hosted API error: {e}")
+
+             if not success:
+                 return {'success': False, 'error': error_msg or 'Failed to generate response'}
+
+             # Update in-memory session
+             if session_id in self._in_memory_sessions:
+                 self._in_memory_sessions[session_id]['messages'].extend([
+                     {'type': 'user', 'content': message, 'timestamp': datetime.utcnow()},
+                     {'type': 'assistant', 'content': response_text, 'timestamp': datetime.utcnow()}
+                 ])
+                 self._in_memory_sessions[session_id]['last_activity'] = datetime.utcnow()
+
+             # Try to persist to Vercel API (cloud storage)
+             try:
+                 from ..cli.utils.api_client import api_client
+                 # Get full message history from in-memory session
+                 if session_id in self._in_memory_sessions:
+                     # Use the API session ID (from MongoDB) for updates
+                     api_session_id = self._in_memory_sessions[session_id].get('api_session_id')
+                     if api_session_id:
+                         all_messages = self._in_memory_sessions[session_id].get('messages', [])
+                         # Convert datetime objects to ISO strings for JSON serialization
+                         serializable_messages = []
+                         for msg in all_messages:
+                             serializable_messages.append({
+                                 'type': msg.get('type'),
+                                 'content': msg.get('content'),
+                                 'timestamp': msg.get('timestamp').isoformat() if hasattr(msg.get('timestamp'), 'isoformat') else str(msg.get('timestamp'))
+                             })
+                         await api_client.update_chat_session(api_session_id, serializable_messages)
+                         logger.debug(f"Messages saved to API for session {api_session_id}")
+             except Exception as api_err:
+                 logger.debug(f"Failed to save messages to API: {api_err}")
+
+             # Try to persist to local database (may fail)
+             try:
+                 await self.db_service.update_one(
+                     self.chat_collection,
+                     {'session_id': session_id},
+                     {
+                         '$push': {
+                             'messages': {
+                                 '$each': [
+                                     {'type': 'user', 'content': message, 'timestamp': datetime.utcnow()},
+                                     {'type': 'assistant', 'content': response_text, 'timestamp': datetime.utcnow()}
+                                 ]
+                             }
+                         },
+                         '$set': {'last_activity': datetime.utcnow()}
+                     }
+                 )
+             except Exception:
+                 pass  # In-memory session is already updated
+
+             return {
+                 'success': True,
+                 'response': response_text,
+                 'relevant_chunks': len(relevant_chunks),
+                 'session_id': session_id,
+                 'model_display': model_display
+             }
+
+         except Exception as e:
+             logger.error(f"Chat failed for session {session_id}: {str(e)}")
+             return {'success': False, 'error': f'Chat failed: {str(e)}'}
+
+     def _extract_paper_text(self, paper: Dict[str, Any]) -> str:
+         """Extract text content from paper for indexing"""
+         text_parts = []
+
+         if paper.get('title'):
+             text_parts.append(paper['title'])
+
+         if paper.get('abstract'):
+             text_parts.append(paper['abstract'])
+
+         if paper.get('full_text'):
+             text_parts.append(paper['full_text'])
+
+         return '\n\n'.join(text_parts)
+
+     def _build_chat_prompt(self, session: Dict[str, Any], message: str, context: str) -> str:
+         """Build chat prompt with context"""
+         from ..prompts import format_prompt
+
+         chat_history = session.get('messages', [])
+
+         history_text = ""
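+         # Keep only the most recent 6 messages of history in the prompt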
+         recent_messages = chat_history[-6:] if len(chat_history) > 6 else chat_history
+
+         for msg in recent_messages:
+             role = "User" if msg['type'] == 'user' else "Assistant"
+             history_text += f"{role}: {msg['content']}\n"
+
+         return format_prompt("rag_chat",
+                              context=context,
+                              history=history_text,
+                              message=message)
+
+     def _parse_llm_response(self, response: Any) -> Tuple[bool, str, str]:
+         """Normalize LLM responses that may return strings or dictionaries"""
+         if isinstance(response, dict):
+             if response.get('success', True) and isinstance(response.get('content'), str):
+                 content = response['content'].strip()
+                 if content:
+                     return True, content, ""
+             return False, "", response.get('error', 'LLM response missing content')
+         if isinstance(response, str):
+             text = response.strip()
+             if text and not text.startswith('Error'):
+                 return True, text, ""
+             return False, "", text or 'LLM returned empty response'
+         if response is None:
+             return False, "", 'LLM returned no response'
+         return False, "", 'Unexpected LLM response type'
+
+     async def cleanup_expired_data(self):
+         """Clean up expired embeddings and chat sessions"""
+         try:
+             cutoff_time = datetime.utcnow()
+
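+             # Delete vector embeddings whose 'expires_at' timestamp has passed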
+             deleted_embeddings = await self.db_service.delete_many(
+                 self.vector_collection,
+                 {'expires_at': {'$lt': cutoff_time}}
+             )
+
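+             # Delete chat sessions with no activity in the last 7 days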
+             chat_cutoff = datetime.utcnow() - timedelta(days=7)
+             deleted_sessions = await self.db_service.delete_many(
+                 self.chat_collection,
+                 {'last_activity': {'$lt': chat_cutoff}}
+             )
+
+             logger.info(f"RAG cleanup completed: deleted {deleted_embeddings} embeddings, {deleted_sessions} sessions")
+
+         except Exception as e:
+             logger.error(f"RAG cleanup failed: {str(e)}")