cite-agent 1.0.4__py3-none-any.whl → 1.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cite-agent might be problematic; see the registry's advisory page for more details.

Files changed (42):
  1. cite_agent/__init__.py +1 -1
  2. cite_agent/account_client.py +19 -46
  3. cite_agent/agent_backend_only.py +30 -4
  4. cite_agent/cli.py +24 -26
  5. cite_agent/cli_conversational.py +294 -0
  6. cite_agent/enhanced_ai_agent.py +2776 -118
  7. cite_agent/setup_config.py +5 -21
  8. cite_agent/streaming_ui.py +252 -0
  9. {cite_agent-1.0.4.dist-info → cite_agent-1.0.5.dist-info}/METADATA +4 -3
  10. cite_agent-1.0.5.dist-info/RECORD +50 -0
  11. {cite_agent-1.0.4.dist-info → cite_agent-1.0.5.dist-info}/top_level.txt +1 -0
  12. src/__init__.py +1 -0
  13. src/services/__init__.py +132 -0
  14. src/services/auth_service/__init__.py +3 -0
  15. src/services/auth_service/auth_manager.py +33 -0
  16. src/services/graph/__init__.py +1 -0
  17. src/services/graph/knowledge_graph.py +194 -0
  18. src/services/llm_service/__init__.py +5 -0
  19. src/services/llm_service/llm_manager.py +495 -0
  20. src/services/paper_service/__init__.py +5 -0
  21. src/services/paper_service/openalex.py +231 -0
  22. src/services/performance_service/__init__.py +1 -0
  23. src/services/performance_service/rust_performance.py +395 -0
  24. src/services/research_service/__init__.py +23 -0
  25. src/services/research_service/chatbot.py +2056 -0
  26. src/services/research_service/citation_manager.py +436 -0
  27. src/services/research_service/context_manager.py +1441 -0
  28. src/services/research_service/conversation_manager.py +597 -0
  29. src/services/research_service/critical_paper_detector.py +577 -0
  30. src/services/research_service/enhanced_research.py +121 -0
  31. src/services/research_service/enhanced_synthesizer.py +375 -0
  32. src/services/research_service/query_generator.py +777 -0
  33. src/services/research_service/synthesizer.py +1273 -0
  34. src/services/search_service/__init__.py +5 -0
  35. src/services/search_service/indexer.py +186 -0
  36. src/services/search_service/search_engine.py +342 -0
  37. src/services/simple_enhanced_main.py +287 -0
  38. cite_agent/__distribution__.py +0 -7
  39. cite_agent-1.0.4.dist-info/RECORD +0 -23
  40. {cite_agent-1.0.4.dist-info → cite_agent-1.0.5.dist-info}/WHEEL +0 -0
  41. {cite_agent-1.0.4.dist-info → cite_agent-1.0.5.dist-info}/entry_points.txt +0 -0
  42. {cite_agent-1.0.4.dist-info → cite_agent-1.0.5.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,597 @@
1
+ # src/services/research_service/conversation_manager.py
2
+
3
+ import logging
4
+ import re
5
+ import asyncio
6
+ from typing import Dict, List, Any, Optional
7
+ import json
8
+ import uuid
9
+ from datetime import datetime, timezone
10
+
11
+ from src.services.llm_service.llm_manager import LLMManager
12
+
13
# Module-level logger named after this module. No handlers or formatters are
# attached here; output configuration is left to whoever configures logging.
logger = logging.getLogger(__name__)
15
+
16
+
17
+ def _utc_timestamp() -> str:
18
+ return datetime.now(timezone.utc).isoformat()
19
+
20
class ResearchConversationManager:
    """
    Manage research-focused chat conversations backed by Redis.

    Each conversation is a JSON-serializable dict holding a system prompt built
    from a set of papers and a synthesis, plus the user/assistant message
    history. Conversations are written to Redis with a TTL and kept in a
    bounded in-memory cache (``active_conversations``) for fast access.

    Inputs are checked against length limits and a small blocklist of
    script-like patterns, and stored text goes through ``_sanitize_text``,
    which escapes ``<``/``>`` and drops control characters.
    """

    # Limits enforced by the validators below.
    MAX_SESSION_ID_LENGTH = 100
    MAX_PAPERS = 50
    MAX_PAPER_ID_LENGTH = 100
    MAX_MESSAGE_LENGTH = 5000
    # Redis expiry for stored conversations (24 hours).
    CONVERSATION_TTL_SECONDS = 86400

    _SESSION_ID_FORBIDDEN = re.compile(r'[<>"\']')
    # DOTALL so a <script> block split across several lines is still rejected.
    _DANGEROUS_PATTERNS = tuple(
        re.compile(raw, re.IGNORECASE | re.DOTALL)
        for raw in (
            r'<script.*?>.*?</script>',  # Script tags
            r'javascript:',              # JavaScript protocol
            r'data:text/html',           # Data URLs
            r'vbscript:',                # VBScript
        )
    )

    # Answering instructions appended to every system prompt.
    _ANSWER_GUIDELINES = (
        "1. Reference specific papers by title when drawing from their findings",
        "2. Acknowledge contradictions or disagreements between papers when they exist",
        "3. Clearly state when information is not covered in the papers analyzed",
        "4. Provide nuanced, balanced perspectives that reflect the research literature",
        "5. Suggest additional areas to explore when the user's questions go beyond the current papers",
    )

    def __init__(self, llm_manager: "LLMManager", redis_client,
                 max_active_conversations: Optional[int] = 1000):
        """
        Initialize the manager.

        Args:
            llm_manager: LLM manager used to generate assistant replies.
            redis_client: Async Redis client; ``get``, ``set(..., ex=...)`` and
                ``ping`` are awaited on it.
            max_active_conversations: Maximum number of conversations kept in
                the in-memory cache. Least recently used entries are evicted
                first; they remain in Redis until their TTL expires. ``None``
                disables the limit.

        Raises:
            ValueError: If a required dependency is missing (falsy) or the
                cache limit is negative.
        """
        if not llm_manager:
            raise ValueError("LLM manager instance is required")
        if not redis_client:
            raise ValueError("Redis client instance is required")
        if max_active_conversations is not None and max_active_conversations < 0:
            raise ValueError("max_active_conversations must be non-negative or None")

        self.llm_manager = llm_manager
        self.redis_client = redis_client
        self.max_active_conversations = max_active_conversations
        self.active_conversations: Dict[str, Dict[str, Any]] = {}

    # ------------------------------------------------------------------ #
    # Validation and sanitization
    # ------------------------------------------------------------------ #

    def _validate_session_id(self, session_id: str) -> None:
        """
        Validate a research session ID.

        Raises:
            ValueError: If the ID is not a non-blank string of at most
                ``MAX_SESSION_ID_LENGTH`` characters free of ``<>"'``.
        """
        if not isinstance(session_id, str):
            raise ValueError("Session ID must be a string")
        if not session_id.strip():
            raise ValueError("Session ID cannot be empty")
        if len(session_id) > self.MAX_SESSION_ID_LENGTH:
            raise ValueError(f"Session ID too long (max {self.MAX_SESSION_ID_LENGTH} characters)")
        if self._SESSION_ID_FORBIDDEN.search(session_id):
            raise ValueError("Session ID contains invalid characters")

    def _validate_papers(self, papers: List[Dict]) -> None:
        """
        Validate the list of papers passed to ``create_conversation``.

        Raises:
            ValueError: If ``papers`` is not a list, has more than
                ``MAX_PAPERS`` entries, contains a non-dict, or contains an
                ``"id"`` whose string form exceeds ``MAX_PAPER_ID_LENGTH``.
        """
        if not isinstance(papers, list):
            raise ValueError("Papers must be a list")
        if len(papers) > self.MAX_PAPERS:
            raise ValueError(f"Too many papers (max {self.MAX_PAPERS})")

        for i, paper in enumerate(papers):
            if not isinstance(paper, dict):
                raise ValueError(f"Paper at index {i} must be a dictionary")
            if "id" in paper and len(str(paper["id"])) > self.MAX_PAPER_ID_LENGTH:
                raise ValueError(
                    f"Paper ID at index {i} too long (max {self.MAX_PAPER_ID_LENGTH} characters)"
                )

    def _validate_message_content(self, content: str) -> None:
        """
        Validate a chat message before it is stored.

        Raises:
            ValueError: If the content is not a non-blank string of at most
                ``MAX_MESSAGE_LENGTH`` characters, or matches one of
                ``_DANGEROUS_PATTERNS`` (case-insensitive, across newlines).
        """
        if not isinstance(content, str):
            raise ValueError("Message content must be a string")
        if not content.strip():
            raise ValueError("Message content cannot be empty")
        if len(content) > self.MAX_MESSAGE_LENGTH:
            raise ValueError(f"Message content too long (max {self.MAX_MESSAGE_LENGTH} characters)")

        for pattern in self._DANGEROUS_PATTERNS:
            if pattern.search(content):
                raise ValueError(
                    f"Message content contains potentially dangerous patterns: {pattern.pattern}"
                )

    def _sanitize_text(self, text: str, max_length: int = 5000) -> str:
        """
        Truncate, escape ``<``/``>``, drop control characters and strip.

        Control characters below U+0020 are removed except newline, carriage
        return and tab. Note that ``max_length`` caps the input *before*
        escaping, so the result can be longer than ``max_length`` when the
        text contains ``<`` or ``>``. ``&`` is not escaped.

        Raises:
            ValueError: If ``text`` is not a string.
        """
        if not isinstance(text, str):
            raise ValueError("Text must be a string")

        escaped = text[:max_length].replace('<', '&lt;').replace('>', '&gt;')
        kept = ''.join(ch for ch in escaped if ord(ch) >= 32 or ch in '\n\r\t')
        return kept.strip()

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    async def create_conversation(self,
                                  session_id: str,
                                  papers: List[Dict],
                                  synthesis: Dict) -> str:
        """
        Create and persist a new conversation seeded with a research system prompt.

        Args:
            session_id: Research session ID (validated, then sanitized and
                embedded in the conversation ID).
            papers: Paper dicts to describe in the system prompt.
            synthesis: Research synthesis; its ``"synthesis"`` value is quoted
                in the system prompt.

        Returns:
            The new conversation ID, ``conv_<session>_<8 hex chars>``.

        Raises:
            ValueError: If any input is invalid.
            Exception: Whatever the Redis client raises while storing; it is
                logged and re-raised.
        """
        try:
            self._validate_session_id(session_id)
            self._validate_papers(papers)
            if not isinstance(synthesis, dict):
                raise ValueError("Synthesis must be a dictionary")

            sanitized_session_id = self._sanitize_text(session_id, max_length=self.MAX_SESSION_ID_LENGTH)
            conversation_id = f"conv_{sanitized_session_id}_{uuid.uuid4().hex[:8]}"

            # One clock read so created_at == updated_at for a fresh conversation.
            now = _utc_timestamp()
            conversation = {
                "id": conversation_id,
                "session_id": sanitized_session_id,
                "created_at": now,
                "updated_at": now,
                "messages": [
                    {"role": "system", "content": self._create_system_message(papers, synthesis)}
                ],
                "paper_ids": [paper["id"] for paper in papers if "id" in paper],
                "metadata": {
                    "paper_count": len(papers),
                    "synthesis_available": bool(synthesis),
                },
            }

            # Also populates the in-memory cache on success.
            await self._store_conversation(conversation)
            return conversation_id

        except ValueError as exc:
            logger.error("Invalid input for conversation creation: %s", exc)
            raise
        except Exception:
            logger.exception("Error creating conversation")
            raise

    async def add_message(self,
                          conversation_id: str,
                          content: str,
                          role: str = "user") -> Dict[str, Any]:
        """
        Append a message, generate an assistant reply and persist both.

        The conversation is only modified if the reply is generated *and*
        stored; on failure the in-place changes are rolled back so the cached
        copy keeps matching Redis and a caller retry does not duplicate the
        message.

        Args:
            conversation_id: Conversation ID.
            content: Message text (validated, then sanitized).
            role: ``"user"`` or ``"assistant"``.

        Returns:
            ``{"response": <reply>, "conversation_id": <id>}`` on success, or
            ``{"error": <message>}`` for invalid input, an unknown
            conversation, or a generation/storage failure. Those cases are
            reported in the return value rather than raised.
        """
        try:
            if not isinstance(conversation_id, str) or not conversation_id.strip():
                raise ValueError("Conversation ID must be a non-empty string")
            self._validate_message_content(content)
            if role not in ("user", "assistant"):
                raise ValueError("Role must be 'user' or 'assistant'")

            sanitized_content = self._sanitize_text(content, max_length=self.MAX_MESSAGE_LENGTH)

            conversation = await self._get_conversation(conversation_id)
            if not conversation:
                return {"error": "Conversation not found"}

            messages = conversation["messages"]
            previous_updated_at = conversation.get("updated_at")
            incoming = {"role": role, "content": sanitized_content, "timestamp": _utc_timestamp()}
            messages.append(incoming)
            reply: Optional[Dict[str, Any]] = None

            try:
                # NOTE(review): a reply is generated even when role == "assistant";
                # confirm that is intended.
                response = await self._generate_response_with_retry(messages)
                reply = {"role": "assistant", "content": response, "timestamp": _utc_timestamp()}
                messages.append(reply)
                conversation["updated_at"] = _utc_timestamp()
                await self._store_conversation(conversation)
            except Exception as exc:
                self._discard_messages(messages, incoming, reply)
                conversation["updated_at"] = previous_updated_at
                logger.error("Error generating response: %s", exc)
                return {"error": f"Failed to generate response: {str(exc)}"}

            return {"response": response, "conversation_id": conversation_id}

        except ValueError as exc:
            logger.error("Invalid input for message addition: %s", exc)
            return {"error": str(exc)}
        except Exception as exc:
            logger.error("Error adding message: %s", exc)
            return {"error": str(exc)}

    async def _generate_response_with_retry(self, messages: List[Dict], max_retries: int = 3,
                                            retry_delay: float = 1.0) -> str:
        """
        Ask the LLM manager for a reply, retrying on failure.

        Non-system messages are passed to ``llm_manager.generate_synthesis`` as
        a list of ``{"content": ...}`` dicts plus a space-joined query string.

        Args:
            messages: Conversation messages (dicts with ``role``/``content``).
            max_retries: Total number of attempts; values below 1 mean one attempt.
            retry_delay: Seconds to sleep between attempts.

        Returns:
            The reply text: ``str(response["summary"])`` when the manager
            returns a dict with a ``"summary"`` key, else ``str(response)``.

        Raises:
            ConnectionError: If every attempt fails (chained to the last error).
        """
        attempts = max(1, max_retries)

        # NOTE(review): the system message (paper context) is excluded and the
        # history is flattened into generate_synthesis(papers, query); confirm
        # this is the intended LLMManager entry point for chat.
        history = [msg for msg in messages if msg.get("role") != "system"]
        payload = [{"content": msg["content"]} for msg in history]
        query = " ".join(msg["content"] for msg in history)

        last_error: Optional[Exception] = None
        for attempt in range(1, attempts + 1):
            try:
                response = await self.llm_manager.generate_synthesis(payload, query)
            except Exception as exc:
                last_error = exc
                logger.warning("Response generation attempt %d failed: %s", attempt, exc)
                if attempt < attempts:
                    await asyncio.sleep(retry_delay)
                continue

            # Always store text so history rendering never sees a non-str content.
            if isinstance(response, dict) and "summary" in response:
                return str(response["summary"])
            return str(response)

        logger.error("All response generation attempts failed")
        raise ConnectionError(
            f"Failed to generate response after {attempts} attempts: {str(last_error)}"
        ) from last_error

    async def get_conversation_history(self, conversation_id: str) -> Dict[str, Any]:
        """
        Return the visible (non-system) history of a conversation.

        Message contents are re-sanitized for display; ``None`` content is
        rendered as an empty string.

        Returns:
            A dict with ``id``, ``messages``, ``created_at``, ``updated_at`` and
            ``metadata``, or ``{"error": <message>}`` for an invalid ID, an
            unknown conversation, or an unexpected failure (not raised).
        """
        try:
            if not isinstance(conversation_id, str) or not conversation_id.strip():
                raise ValueError("Conversation ID must be a non-empty string")

            conversation = await self._get_conversation(conversation_id)
            if not conversation:
                return {"error": "Conversation not found"}

            visible = []
            for message in conversation.get("messages", []):
                if message.get("role") == "system":
                    continue
                raw = message.get("content")
                text = "" if raw is None else str(raw)
                visible.append({
                    **message,
                    "content": self._sanitize_text(text, max_length=self.MAX_MESSAGE_LENGTH),
                })

            return {
                "id": conversation_id,
                "messages": visible,
                "created_at": conversation.get("created_at"),
                "updated_at": conversation.get("updated_at"),
                "metadata": conversation.get("metadata", {}),
            }

        except ValueError as exc:
            logger.error("Invalid input for conversation history: %s", exc)
            return {"error": str(exc)}
        except Exception as exc:
            logger.error("Error getting conversation history: %s", exc)
            return {"error": str(exc)}

    # ------------------------------------------------------------------ #
    # System prompt construction
    # ------------------------------------------------------------------ #

    def _create_system_message(self, papers: List[Dict], synthesis: Dict) -> str:
        """
        Build the system prompt describing the papers and the synthesis.

        Papers are numbered by their position in ``papers`` (non-dict entries
        are skipped but still consume a number). If building the prompt fails
        unexpectedly, a short generic prompt is returned instead.
        """
        try:
            paper_summaries = [
                self._summarize_paper(index, paper)
                for index, paper in enumerate(papers, 1)
                if isinstance(paper, dict)
            ]

            raw_synthesis = synthesis.get("synthesis") if isinstance(synthesis, dict) else None
            synthesis_text = self._sanitize_text(
                str(raw_synthesis if raw_synthesis is not None else "No synthesis available"),
                max_length=2000,
            )

            rule = "=" * 40
            lines = [
                "You are a research assistant discussing a collection of academic papers on a specific topic.",
                "",
                "You have analyzed these papers:",
                "",
                rule,
                "\n\n".join(paper_summaries),
                rule,
                "",
                "Research synthesis:",
                synthesis_text,
                "",
                "When answering questions:",
                *self._ANSWER_GUIDELINES,
                "",
                "Maintain a scholarly, informative tone while being conversational and accessible.",
                "",  # trailing newline
            ]
            return "\n".join(lines)

        except Exception as exc:
            logger.error("Error creating system message: %s", exc)
            return "You are a research assistant. Please ask me about the research papers."

    def _summarize_paper(self, index: int, paper: Dict[str, Any]) -> str:
        """Render one paper (title, authors, up to 5 findings, methodology) for the prompt."""
        title = self._sanitize_text(str(paper.get("title") or "Untitled"), max_length=200)
        summary = f"Paper {index}: {title}"

        authors = self._format_authors(paper.get("authors"))
        if authors:
            summary += f" by {authors}"

        if "main_findings" in paper:
            findings = paper["main_findings"]
            if isinstance(findings, list):
                summary += "\nMain findings:\n"
                for j, finding in enumerate(findings[:5], 1):  # Limit to 5 findings
                    summary += f" {j}. {self._sanitize_text(str(finding), max_length=200)}\n"
            else:
                summary += f"\nMain finding: {self._sanitize_text(str(findings), max_length=200)}\n"

        if "methodology" in paper:
            methodology = self._sanitize_text(str(paper["methodology"]), max_length=200)
            summary += f"\nMethodology: {methodology}\n"

        return summary

    def _format_authors(self, authors: Any) -> str:
        """
        Join author names into one sanitized, comma-separated string.

        Accepts a list whose entries are strings or dicts with a ``"name"``
        key; ``None`` entries, dicts without a name and blank names are
        skipped. Anything other than a list yields ``""``.
        """
        if not isinstance(authors, list):
            return ""

        names = []
        for author in authors:
            if isinstance(author, dict):
                author = author.get("name")
            if author is None:
                continue
            name = self._sanitize_text(str(author), max_length=200)
            if name:
                names.append(name)
        return ", ".join(names)

    # ------------------------------------------------------------------ #
    # Storage
    # ------------------------------------------------------------------ #

    async def _store_conversation(self, conversation: Dict) -> None:
        """
        Write a conversation to Redis (with TTL) and then to the memory cache.

        The cache is only updated after Redis accepts the write.

        Raises:
            Exception: Serialization or Redis errors are logged and re-raised.
        """
        conversation_id = conversation["id"]

        try:
            payload = json.dumps(conversation)
            await self.redis_client.set(
                f"conversation:{conversation_id}",
                payload,
                ex=self.CONVERSATION_TTL_SECONDS,
            )
        except Exception as exc:
            logger.error("Error storing conversation %s: %s", conversation_id, exc)
            raise

        self._cache_conversation(conversation_id, conversation)
        logger.debug("Stored conversation: %s", conversation_id)

    async def _get_conversation(self, conversation_id: str) -> Optional[Dict]:
        """
        Look a conversation up in the memory cache, falling back to Redis.

        Returns:
            The conversation dict (the cached object itself, so callers may
            mutate it in place), or ``None`` if it is missing, unreadable, or
            Redis fails.
        """
        cached = self.active_conversations.get(conversation_id)
        if cached is not None:
            self._cache_conversation(conversation_id, cached)  # mark as recently used
            return cached

        try:
            raw = await self.redis_client.get(f"conversation:{conversation_id}")
        except Exception as exc:
            logger.warning("Error retrieving conversation from Redis: %s", exc)
            return None
        if not raw:
            return None

        try:
            conversation = json.loads(raw)
        except (TypeError, ValueError) as exc:
            logger.warning("Discarding unreadable conversation %s: %s", conversation_id, exc)
            return None
        if not isinstance(conversation, dict) or not isinstance(conversation.get("messages"), list):
            logger.warning("Discarding malformed conversation %s", conversation_id)
            return None

        self._cache_conversation(conversation_id, conversation)
        return conversation

    def _cache_conversation(self, conversation_id: str, conversation: Dict[str, Any]) -> None:
        """
        Insert or refresh a conversation in the in-memory cache.

        The cache is a plain dict used in insertion order: refreshing an entry
        moves it to the end, and when ``max_active_conversations`` is set the
        oldest entries are evicted until the cache fits. Only conversations
        already stored in (or loaded from) Redis are cached, so evicted ones
        can be reloaded from there until their TTL expires.
        """
        cache = self.active_conversations
        cache.pop(conversation_id, None)
        cache[conversation_id] = conversation

        limit = self.max_active_conversations
        if limit is None:
            return
        while len(cache) > limit:
            del cache[next(iter(cache))]

    @staticmethod
    def _discard_messages(messages: List[Dict[str, Any]], *targets: Optional[Dict[str, Any]]) -> None:
        """Remove ``targets`` from ``messages`` in place, matching by identity (``None`` ignored)."""
        doomed = {id(target) for target in targets if target is not None}
        messages[:] = [message for message in messages if id(message) not in doomed]

    # ------------------------------------------------------------------ #
    # Operations
    # ------------------------------------------------------------------ #

    async def health_check(self) -> Dict[str, Any]:
        """
        Report the health of the LLM manager, Redis and the memory cache.

        Returns:
            A dict with overall ``status`` (``healthy``/``degraded``/``error``),
            ``timestamp`` and per-component details.
        """
        try:
            health_status = {
                "status": "healthy",
                "timestamp": _utc_timestamp(),
                "components": {},
            }

            try:
                llm_health = await self.llm_manager.health_check()
                health_status["components"]["llm_manager"] = llm_health
                if llm_health.get("status") != "healthy":
                    health_status["status"] = "degraded"
            except Exception as exc:
                health_status["components"]["llm_manager"] = {"status": "error", "error": str(exc)}
                health_status["status"] = "degraded"

            try:
                await self.redis_client.ping()
                health_status["components"]["redis"] = {"status": "healthy"}
            except Exception as exc:
                health_status["components"]["redis"] = {"status": "error", "error": str(exc)}
                health_status["status"] = "degraded"

            health_status["components"]["active_conversations"] = {
                "status": "healthy",
                "count": len(self.active_conversations),
                "limit": self.max_active_conversations,
            }
            return health_status

        except Exception as exc:
            logger.error("Health check failed: %s", exc)
            return {
                "status": "error",
                "error": str(exc),
                "timestamp": _utc_timestamp(),
            }

    async def cleanup(self):
        """Drop all in-memory cached conversations (Redis copies are untouched)."""
        try:
            self.active_conversations.clear()
        except Exception as exc:
            logger.error("Error during cleanup: %s", exc)