cite-agent 1.0.4-py3-none-any.whl → 1.0.5-py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as they appear in the supported public registries. It is provided for informational purposes only.

Potentially problematic release.


This version of cite-agent might be problematic.

Files changed (42)
  1. cite_agent/__init__.py +1 -1
  2. cite_agent/account_client.py +19 -46
  3. cite_agent/agent_backend_only.py +30 -4
  4. cite_agent/cli.py +24 -26
  5. cite_agent/cli_conversational.py +294 -0
  6. cite_agent/enhanced_ai_agent.py +2776 -118
  7. cite_agent/setup_config.py +5 -21
  8. cite_agent/streaming_ui.py +252 -0
  9. {cite_agent-1.0.4.dist-info → cite_agent-1.0.5.dist-info}/METADATA +4 -3
  10. cite_agent-1.0.5.dist-info/RECORD +50 -0
  11. {cite_agent-1.0.4.dist-info → cite_agent-1.0.5.dist-info}/top_level.txt +1 -0
  12. src/__init__.py +1 -0
  13. src/services/__init__.py +132 -0
  14. src/services/auth_service/__init__.py +3 -0
  15. src/services/auth_service/auth_manager.py +33 -0
  16. src/services/graph/__init__.py +1 -0
  17. src/services/graph/knowledge_graph.py +194 -0
  18. src/services/llm_service/__init__.py +5 -0
  19. src/services/llm_service/llm_manager.py +495 -0
  20. src/services/paper_service/__init__.py +5 -0
  21. src/services/paper_service/openalex.py +231 -0
  22. src/services/performance_service/__init__.py +1 -0
  23. src/services/performance_service/rust_performance.py +395 -0
  24. src/services/research_service/__init__.py +23 -0
  25. src/services/research_service/chatbot.py +2056 -0
  26. src/services/research_service/citation_manager.py +436 -0
  27. src/services/research_service/context_manager.py +1441 -0
  28. src/services/research_service/conversation_manager.py +597 -0
  29. src/services/research_service/critical_paper_detector.py +577 -0
  30. src/services/research_service/enhanced_research.py +121 -0
  31. src/services/research_service/enhanced_synthesizer.py +375 -0
  32. src/services/research_service/query_generator.py +777 -0
  33. src/services/research_service/synthesizer.py +1273 -0
  34. src/services/search_service/__init__.py +5 -0
  35. src/services/search_service/indexer.py +186 -0
  36. src/services/search_service/search_engine.py +342 -0
  37. src/services/simple_enhanced_main.py +287 -0
  38. cite_agent/__distribution__.py +0 -7
  39. cite_agent-1.0.4.dist-info/RECORD +0 -23
  40. {cite_agent-1.0.4.dist-info → cite_agent-1.0.5.dist-info}/WHEEL +0 -0
  41. {cite_agent-1.0.4.dist-info → cite_agent-1.0.5.dist-info}/entry_points.txt +0 -0
  42. {cite_agent-1.0.4.dist-info → cite_agent-1.0.5.dist-info}/licenses/LICENSE +0 -0
src/services/search_service/__init__.py
@@ -0,0 +1,5 @@
+"""Search service package exposing the high-level search engine."""
+
+from .search_engine import SearchEngine
+
+__all__ = ["SearchEngine"]
src/services/search_service/indexer.py
@@ -0,0 +1,186 @@
+#indexer.py
+from typing import List, Dict, Optional
+import asyncio
+from datetime import datetime, timezone
+import redis.asyncio as redis
+import json
+
+from ...utils.logger import logger, log_operation
+from ...storage.db.operations import DatabaseOperations
+from .vector_search import VectorSearchEngine
+
+
+def _utc_timestamp() -> str:
+    return datetime.now(timezone.utc).isoformat()
+
+class DocumentIndexer:
+    def __init__(self, db_ops: DatabaseOperations, vector_search: VectorSearchEngine, redis_url: str):
+        #logger.info("Initializing DocumentIndexer")
+        self.db = db_ops
+        self.vector_search = vector_search
+        self.redis_client = redis.from_url(redis_url)
+        self.indexing_queue = asyncio.Queue()
+        self.active_sessions: Dict[str, Dict] = {}
+        self._running = False
+        #logger.info("DocumentIndexer initialized")
+
+    @log_operation("start_indexing")
+    async def start(self):
+        """Start the indexing process."""
+        #logger.info("Starting indexing service")
+        self._running = True
+        try:
+            await asyncio.gather(
+                self._process_queue(),
+                self._monitor_research_sessions()
+            )
+        except Exception as e:
+            logger.error(f"Error in indexing service: {str(e)}")
+            self._running = False
+            raise
+
+    @log_operation("monitor_sessions")
+    async def _monitor_research_sessions(self):
+        """Monitor active research sessions for new documents."""
+        while self._running:
+            try:
+                # Subscribe to research session updates
+                async for message in self.redis_client.subscribe("research_updates"):
+                    update = json.loads(message["data"])
+                    session_id = update.get("session_id")
+
+                    if update.get("type") == "new_papers":
+                        await self._handle_new_papers(session_id, update.get("papers", []))
+
+            except Exception as e:
+                logger.error(f"Error monitoring sessions: {str(e)}")
+                await asyncio.sleep(1)
+
+    async def _handle_new_papers(self, session_id: str, papers: List[str]):
+        """Handle new papers added to a research session."""
+        #logger.info(f"Processing new papers for session {session_id}")
+        for paper_id in papers:
+            await self.queue_document(paper_id, session_id)
+
+    @log_operation("queue_document")
+    async def queue_document(self, doc_id: str, session_id: Optional[str] = None):
+        """Queue a document for indexing."""
+        #logger.info(f"Queuing document for indexing: {doc_id}")
+        await self.indexing_queue.put({
+            "doc_id": doc_id,
+            "session_id": session_id,
+            "queued_at": _utc_timestamp()
+        })
+
+        if session_id:
+            await self._update_session_progress(session_id, "queued", doc_id)
+
+    async def _process_queue(self):
+        """Process documents in the indexing queue."""
+        #logger.info("Starting queue processing")
+        while self._running:
+            try:
+                item = await self.indexing_queue.get()
+                doc_id = item["doc_id"]
+                session_id = item.get("session_id")
+
+                #logger.info(f"Processing document: {doc_id}")
+
+                if session_id:
+                    await self._update_session_progress(session_id, "processing", doc_id)
+
+                processed_doc = await self.db.get_processed_paper(doc_id)
+                if not processed_doc:
+                    #logger.warning(f"Processed content not found: {doc_id}")
+                    continue
+
+                success = await self._index_document(doc_id, processed_doc.content)
+
+                if success:
+                    #logger.info(f"Successfully indexed: {doc_id}")
+                    await self.db.update_paper_status(doc_id, "indexed")
+                    if session_id:
+                        await self._update_session_progress(session_id, "completed", doc_id)
+                else:
+                    #logger.error(f"Failed to index: {doc_id}")
+                    if session_id:
+                        await self._update_session_progress(session_id, "failed", doc_id)
+
+            except asyncio.CancelledError:
+                break
+            except Exception as e:
+                logger.error(f"Error processing document: {str(e)}")
+            finally:
+                self.indexing_queue.task_done()
+
+    async def _update_session_progress(self, session_id: str, status: str, doc_id: str):
+        """Update indexing progress for research session."""
+        try:
+            progress_key = f"indexing_progress:{session_id}"
+            await self.redis_client.hset(
+                progress_key,
+                mapping={
+                    doc_id: json.dumps({
+                        "status": status,
+                        "updated_at": _utc_timestamp()
+                    })
+                }
+            )
+
+            # Publish update
+            await self.redis_client.publish(
+                "indexing_updates",
+                json.dumps({
+                    "session_id": session_id,
+                    "doc_id": doc_id,
+                    "status": status
+                })
+            )
+        except Exception as e:
+            logger.error(f"Error updating session progress: {str(e)}")
+
+    @log_operation("batch_index")
+    async def batch_index(self, doc_ids: List[str], session_id: Optional[str] = None) -> Dict[str, bool]:
+        """Batch index multiple documents with session tracking."""
+        #logger.info(f"Starting batch indexing: {len(doc_ids)} documents")
+        results = {}
+
+        for doc_id in doc_ids:
+            await self.queue_document(doc_id, session_id)
+            results[doc_id] = True
+
+        await self.indexing_queue.join()
+        return results
+
+    @log_operation("reindex_all")
+    async def reindex_all(self):
+        """Reindex all documents with progress tracking."""
+        #logger.info("Starting full reindexing")
+        try:
+            papers = await self.db.search_papers({"status": "processed"})
+            doc_ids = [paper.id for paper in papers]
+
+            if not doc_ids:
+                logger.info("No documents found for reindexing")
+                return
+
+            total = len(doc_ids)
+            progress_key = "reindex_progress"
+
+            await self.redis_client.set(progress_key, "0")
+            results = await self.batch_index(doc_ids)
+
+            success_count = sum(1 for success in results.values() if success)
+            await self.redis_client.set(progress_key, str(success_count / total * 100))
+
+            #logger.info(f"Reindexing completed: {success_count}/{total} successful")
+
+        except Exception as e:
+            logger.error(f"Error during reindexing: {str(e)}")
+            raise
+
+    async def cleanup(self):
+        """Cleanup indexing resources."""
+        #logger.info("Cleaning up indexer resources")
+        self._running = False
+        await self.redis_client.close()
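
For orientation, a minimal wiring sketch for the DocumentIndexer above (not part of the package): db_ops, vector_search, and the Redis URL are assumed stand-ins for whatever the application actually injects, and the queued document ID is hypothetical.

import asyncio

async def run_indexer(db_ops, vector_search):
    # Assumed dependencies: db_ops and vector_search satisfy the interfaces imported by indexer.py.
    indexer = DocumentIndexer(db_ops, vector_search, redis_url="redis://localhost:6379")
    worker = asyncio.create_task(indexer.start())   # runs _process_queue and session monitoring
    await indexer.queue_document("paper-123")       # hypothetical document ID
    await indexer.indexing_queue.join()             # wait until the queued item has been handled
    await indexer.cleanup()                         # stops the loops and closes the Redis client
    worker.cancel()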
src/services/search_service/search_engine.py
@@ -0,0 +1,342 @@
+"""High-level scholarly search orchestration service.
+
+The engine aggregates OpenAlex, optional PubMed, and lightweight web results to
+provide the advanced research surface required by the beta release. Emphasis is
+placed on resilience: network failures or missing API keys degrade gracefully
+instead of crashing the request path.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+import os
+from dataclasses import dataclass
+from datetime import datetime
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
+
+import httpx
+
+from ..paper_service import OpenAlexClient
+from ..performance_service.rust_performance import HighPerformanceService
+
+logger = logging.getLogger(__name__)
+
+_PUBMED_BASE = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
+_DDG_PROXY = os.getenv("DDG_PROXY_URL", "https://ddg-webapp-aagd.vercel.app/search")
+
+
+@dataclass(slots=True)
+class SearchResult:
+    """Canonical representation of a scholarly item."""
+
+    id: str
+    title: str
+    abstract: str
+    source: str
+    authors: List[str]
+    year: Optional[int]
+    doi: str
+    url: str
+    relevance: float
+    citations: int
+    keywords: List[str]
+    metadata: Dict[str, Any]
+
+
+class SearchEngine:
+    """Advanced scholarly search with optional enrichment."""
+
+    def __init__(
+        self,
+        *,
+        openalex_client: Optional[OpenAlexClient] = None,
+        performance_service: Optional[HighPerformanceService] = None,
+        redis_url: str = os.getenv("REDIS_URL", "redis://localhost:6379"),
+        timeout: float = float(os.getenv("SEARCH_TIMEOUT", "12.0")),
+    ) -> None:
+        self.openalex = openalex_client or OpenAlexClient()
+        self.performance = performance_service or HighPerformanceService()
+        self.redis_url = redis_url
+        self.timeout = timeout
+        self._session = httpx.AsyncClient(timeout=timeout)
+        self._lock = asyncio.Lock()
+
+    # ------------------------------------------------------------------
+    async def search_papers(
+        self,
+        query: str,
+        *,
+        limit: int = 10,
+        sources: Optional[Sequence[str]] = None,
+        include_metadata: bool = True,
+        include_abstracts: bool = True,
+        include_citations: bool = True,
+    ) -> Dict[str, Any]:
+        """Search across configured scholarly sources."""
+
+        if not query or not query.strip():
+            raise ValueError("Query must be a non-empty string")
+
+        limit = max(1, min(limit, 100))
+        sources = sources or ("openalex",)
+        gathered: List[SearchResult] = []
+
+        if "openalex" in sources:
+            gathered.extend(await self._search_openalex(query, limit))
+
+        if "pubmed" in sources:
+            try:
+                gathered.extend(await self._search_pubmed(query, limit))
+            except Exception as exc:  # pragma: no cover - optional remote dependency
+                logger.info("PubMed search failed", extra={"error": str(exc)})
+
+        deduped = self._deduplicate_results(gathered)
+        sorted_results = sorted(deduped, key=lambda item: item.relevance, reverse=True)[:limit]
+
+        payload = {
+            "query": query,
+            "count": len(sorted_results),
+            "sources_used": list(dict.fromkeys([res.source for res in sorted_results])),
+            "papers": [self._result_to_payload(res, include_metadata, include_abstracts, include_citations) for res in sorted_results],
+            "timestamp": datetime.utcnow().isoformat() + "Z",
+        }
+        return payload
+
+    async def web_search(self, query: str, *, num_results: int = 5) -> List[Dict[str, Any]]:
+        """Perform a lightweight DuckDuckGo-backed web search."""
+
+        params = {"q": query, "max_results": max(1, min(num_results, 10))}
+        try:
+            response = await self._session.get(_DDG_PROXY, params=params)
+            response.raise_for_status()
+            data = response.json()
+            results = data.get("results", []) if isinstance(data, dict) else []
+            formatted = [
+                {
+                    "title": item.get("title", ""),
+                    "url": item.get("href") or item.get("link") or "",
+                    "snippet": item.get("body") or item.get("snippet") or "",
+                    "source": item.get("source", "duckduckgo"),
+                }
+                for item in results[:num_results]
+            ]
+            return formatted
+        except Exception as exc:  # pragma: no cover - network optional
+            logger.info("Web search fallback", extra={"error": str(exc)})
+            return []
+
+    async def fetch_paper_bundle(self, paper_ids: Iterable[str]) -> List[Dict[str, Any]]:
+        """Convenience helper for fetching OpenAlex metadata for multiple IDs."""
+
+        papers = await self.openalex.get_papers_bulk(paper_ids)
+        formatted: List[Dict[str, Any]] = []
+        for paper in papers:
+            formatted.append(self._format_openalex_work(paper))
+        return formatted
+
+    async def close(self) -> None:
+        await self.openalex.close()
+        try:
+            await self._session.aclose()
+        except Exception:
+            pass
+
+    # ------------------------------------------------------------------
+    async def _search_openalex(self, query: str, limit: int) -> List[SearchResult]:
+        payload = await self.openalex.search_works(
+            query,
+            limit=limit,
+            filters={"type": "journal-article"},
+        )
+        results = payload.get("results", []) if isinstance(payload, dict) else []
+        formatted = [self._format_openalex_work(item) for item in results]
+
+        if formatted:
+            try:
+                combined = "\n".join(res.abstract for res in formatted if res.abstract)
+                if combined:
+                    keywords = await self.performance.extract_keywords(combined, max_keywords=10)
+                    for res in formatted:
+                        res.metadata.setdefault("query_keywords", keywords)
+            except Exception:
+                # Keyword enrichment is best-effort
+                pass
+
+        return formatted
+
+    async def _search_pubmed(self, query: str, limit: int) -> List[SearchResult]:
+        params = {
+            "db": "pubmed",
+            "term": query,
+            "retmax": max(1, min(limit, 50)),
+            "retmode": "json",
+            "sort": "relevance",
+        }
+        search_url = f"{_PUBMED_BASE}/esearch.fcgi"
+        response = await self._session.get(search_url, params=params)
+        response.raise_for_status()
+        id_list = response.json().get("esearchresult", {}).get("idlist", [])
+        if not id_list:
+            return []
+
+        summary_params = {
+            "db": "pubmed",
+            "id": ",".join(id_list[:limit]),
+            "retmode": "json",
+        }
+        summary_url = f"{_PUBMED_BASE}/esummary.fcgi"
+        summary_resp = await self._session.get(summary_url, params=summary_params)
+        summary_resp.raise_for_status()
+        summaries = summary_resp.json().get("result", {})
+
+        results: List[SearchResult] = []
+        for pmid in id_list[:limit]:
+            raw = summaries.get(pmid)
+            if not isinstance(raw, dict):
+                continue
+            title = raw.get("title", "")
+            abstract = raw.get("elocationid", "") or raw.get("source", "")
+            authors = [author.get("name", "") for author in raw.get("authors", []) if isinstance(author, dict)]
+            results.append(
+                SearchResult(
+                    id=f"PMID:{pmid}",
+                    title=title,
+                    abstract=abstract,
+                    source="pubmed",
+                    authors=[a for a in authors if a],
+                    year=self._safe_int(raw.get("pubdate", "")[:4]),
+                    doi=raw.get("elocationid", ""),
+                    url=f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/",
+                    relevance=0.6,
+                    citations=raw.get("pmc", 0) or 0,
+                    keywords=[],
+                    metadata={"pmid": pmid},
+                )
+            )
+        return results
+
+    def _deduplicate_results(self, results: List[SearchResult]) -> List[SearchResult]:
+        seen: Dict[str, SearchResult] = {}
+        for result in results:
+            key = result.doi.lower() if result.doi else result.title.lower()
+            if key in seen:
+                existing = seen[key]
+                if result.relevance > existing.relevance:
+                    seen[key] = result
+            else:
+                seen[key] = result
+        return list(seen.values())
+
+    def _format_openalex_work(self, work: Dict[str, Any]) -> SearchResult:
+        title = work.get("title", "")
+        abstract = self._extract_openalex_abstract(work)
+        authors = [
+            auth.get("author", {}).get("display_name", "")
+            for auth in work.get("authorships", [])
+            if isinstance(auth, dict)
+        ]
+        doi = work.get("doi", "") or ""
+        url = work.get("id", "")
+        concepts = [concept.get("display_name", "") for concept in work.get("concepts", []) if isinstance(concept, dict)]
+        citations = work.get("cited_by_count", 0) or 0
+        relevance = work.get("relevance_score", 0.0) or 0.5 + min(0.4, citations / 1000)
+
+        keywords = concepts[:5] or self._quick_keywords(abstract, limit=5)
+
+        metadata = {
+            "openalex_id": url,
+            "open_access": work.get("open_access", {}).get("is_oa", False),
+            "primary_location": work.get("primary_location"),
+        }
+
+        return SearchResult(
+            id=url.split("/")[-1] if url else doi or title,
+            title=title,
+            abstract=abstract,
+            source="openalex",
+            authors=[a for a in authors if a],
+            year=work.get("publication_year"),
+            doi=doi.replace("https://doi.org/", ""),
+            url=url or f"https://openalex.org/{title[:50].replace(' ', '_')}",
+            relevance=float(relevance),
+            citations=citations,
+            keywords=keywords,
+            metadata=metadata,
+        )
+
+    def _result_to_payload(
+        self,
+        result: SearchResult,
+        include_metadata: bool,
+        include_abstracts: bool,
+        include_citations: bool,
+    ) -> Dict[str, Any]:
+        payload = {
+            "id": result.id,
+            "title": result.title,
+            "source": result.source,
+            "authors": result.authors,
+            "year": result.year,
+            "doi": result.doi,
+            "url": result.url,
+            "keywords": result.keywords,
+            "relevance": result.relevance,
+        }
+        if include_abstracts:
+            payload["abstract"] = result.abstract
+        if include_citations:
+            payload["citations_count"] = result.citations
+        if include_metadata:
+            payload["metadata"] = result.metadata
+        return payload
+
+    def _extract_openalex_abstract(self, work: Dict[str, Any]) -> str:
+        inverted = work.get("abstract_inverted_index")
+        if isinstance(inverted, dict) and inverted:
+            # Convert inverted index back to human-readable abstract
+            index_map: Dict[int, str] = {}
+            for token, positions in inverted.items():
+                for position in positions:
+                    index_map[position] = token
+            abstract_tokens = [token for _, token in sorted(index_map.items())]
+            return " ".join(abstract_tokens)
+        return work.get("abstract", "") or ""
+
+    def _safe_int(self, value: Any) -> Optional[int]:
+        try:
+            return int(value)
+        except Exception:
+            return None
+
+    def _quick_keywords(self, text: str, limit: int = 5) -> List[str]:
+        import re
+        from collections import Counter
+
+        if not text:
+            return []
+
+        words = re.findall(r"[a-zA-Z]{3,}", text.lower())
+        stop_words = {
+            "the",
+            "and",
+            "for",
+            "with",
+            "from",
+            "that",
+            "have",
+            "this",
+            "were",
+            "also",
+            "into",
+            "which",
+            "their",
+            "between",
+            "within",
+        }
+        filtered = [word for word in words if word not in stop_words]
+        most_common = Counter(filtered).most_common(limit)
+        return [word for word, _ in most_common]
+
+
+__all__ = ["SearchEngine"]
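
To make the search path concrete, a short usage sketch (not part of the diff): it assumes the package is importable as src.services.search_service, per the file list above, and that OpenAlex is reachable; the query and limit are example values only.

import asyncio
from src.services.search_service import SearchEngine  # import path assumed from the file list above

async def demo():
    engine = SearchEngine()
    try:
        # "openalex" is the default source; "pubmed" is optional and its failures are logged, not raised
        payload = await engine.search_papers("citation graph analysis", limit=5, sources=("openalex", "pubmed"))
        for paper in payload["papers"]:
            print(paper["title"], paper["doi"])
    finally:
        await engine.close()  # closes the OpenAlex client and the shared httpx session

asyncio.run(demo())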