cite_agent-1.0.4-py3-none-any.whl → cite_agent-1.0.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cite-agent might be problematic.

Files changed (42)
  1. cite_agent/__init__.py +1 -1
  2. cite_agent/account_client.py +19 -46
  3. cite_agent/agent_backend_only.py +30 -4
  4. cite_agent/cli.py +24 -26
  5. cite_agent/cli_conversational.py +294 -0
  6. cite_agent/enhanced_ai_agent.py +2776 -118
  7. cite_agent/setup_config.py +5 -21
  8. cite_agent/streaming_ui.py +252 -0
  9. {cite_agent-1.0.4.dist-info → cite_agent-1.0.5.dist-info}/METADATA +4 -3
  10. cite_agent-1.0.5.dist-info/RECORD +50 -0
  11. {cite_agent-1.0.4.dist-info → cite_agent-1.0.5.dist-info}/top_level.txt +1 -0
  12. src/__init__.py +1 -0
  13. src/services/__init__.py +132 -0
  14. src/services/auth_service/__init__.py +3 -0
  15. src/services/auth_service/auth_manager.py +33 -0
  16. src/services/graph/__init__.py +1 -0
  17. src/services/graph/knowledge_graph.py +194 -0
  18. src/services/llm_service/__init__.py +5 -0
  19. src/services/llm_service/llm_manager.py +495 -0
  20. src/services/paper_service/__init__.py +5 -0
  21. src/services/paper_service/openalex.py +231 -0
  22. src/services/performance_service/__init__.py +1 -0
  23. src/services/performance_service/rust_performance.py +395 -0
  24. src/services/research_service/__init__.py +23 -0
  25. src/services/research_service/chatbot.py +2056 -0
  26. src/services/research_service/citation_manager.py +436 -0
  27. src/services/research_service/context_manager.py +1441 -0
  28. src/services/research_service/conversation_manager.py +597 -0
  29. src/services/research_service/critical_paper_detector.py +577 -0
  30. src/services/research_service/enhanced_research.py +121 -0
  31. src/services/research_service/enhanced_synthesizer.py +375 -0
  32. src/services/research_service/query_generator.py +777 -0
  33. src/services/research_service/synthesizer.py +1273 -0
  34. src/services/search_service/__init__.py +5 -0
  35. src/services/search_service/indexer.py +186 -0
  36. src/services/search_service/search_engine.py +342 -0
  37. src/services/simple_enhanced_main.py +287 -0
  38. cite_agent/__distribution__.py +0 -7
  39. cite_agent-1.0.4.dist-info/RECORD +0 -23
  40. {cite_agent-1.0.4.dist-info → cite_agent-1.0.5.dist-info}/WHEEL +0 -0
  41. {cite_agent-1.0.4.dist-info → cite_agent-1.0.5.dist-info}/entry_points.txt +0 -0
  42. {cite_agent-1.0.4.dist-info → cite_agent-1.0.5.dist-info}/licenses/LICENSE +0 -0
src/services/research_service/enhanced_synthesizer.py
@@ -0,0 +1,375 @@
+ """Advanced research synthesiser used by the SophisticatedResearchEngine."""
+
+ from __future__ import annotations
+
+ import asyncio
+ import logging
+ import math
+ import statistics
+ import uuid
+ from dataclasses import dataclass
+ from datetime import datetime, timezone
+ from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
+
+ from src.services.graph.knowledge_graph import KnowledgeGraph
+ from src.services.llm_service import LLMManager
+ from src.services.paper_service import OpenAlexClient
+ from src.services.performance_service.rust_performance import HighPerformanceService
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass(slots=True)
+ class SynthesizerPaper:
+     """Internal representation of a paper ready for synthesis."""
+
+     paper_id: str
+     title: str
+     abstract: str
+     year: Optional[int]
+     doi: str
+     authors: List[str]
+     keywords: List[str]
+     url: str
+     fallback: bool = False
+
+     def to_context_block(self) -> Dict[str, Any]:
+         return {
+             "title": self.title,
+             "content": self.abstract,
+             "authors": self.authors,
+             "year": self.year,
+             "keywords": self.keywords,
+         }
+
+
+ class EnhancedSynthesizer:
+     """High-fidelity synthesis engine with KG enrichment and telemetry."""
+
+     def __init__(
+         self,
+         *,
+         llm_manager: Optional[LLMManager] = None,
+         openalex_client: Optional[OpenAlexClient] = None,
+         knowledge_graph: Optional[KnowledgeGraph] = None,
+         performance_service: Optional[HighPerformanceService] = None,
+         redis_url: Optional[str] = None,
+     ) -> None:
+         self.llm = llm_manager or LLMManager(redis_url=redis_url or "redis://localhost:6379")
+         self.openalex = openalex_client or OpenAlexClient()
+         self.kg = knowledge_graph or KnowledgeGraph()
+         self.performance = performance_service or HighPerformanceService()
+         self._cache: Dict[str, Tuple[float, Dict[str, Any]]] = {}
+         self._cache_lock = asyncio.Lock()
+
+     # ------------------------------------------------------------------
+     async def synthesize_research(
+         self,
+         *,
+         papers: Sequence[Dict[str, Any]],
+         max_words: int = 500,
+         style: str = "comprehensive",
+         context: Optional[Dict[str, Any]] = None,
+         include_visualizations: bool = True,
+         include_topic_modeling: bool = True,
+         include_quality_assessment: bool = True,
+     ) -> Dict[str, Any]:
+         """Produce an advanced synthesis result.
+
+         Args:
+             papers: Pre-fetched paper payloads (OpenAlex-compatible dictionaries).
+             max_words: Target word count for the generated synthesis.
+             style: Output style hint ("comprehensive", "executive", ...).
+             context: Additional directives (focus areas, custom prompts, query).
+             include_visualizations: Attach a publication-year histogram and paper URLs.
+             include_topic_modeling: Extract keywords via the performance service.
+             include_quality_assessment: Attach DOI coverage and recency statistics.
+
+         Returns:
+             Rich synthesis dictionary consumed by the API layer.
+         """
+         context = context or {}
+         focus_terms = context.get("focus") or context.get("custom_prompt") or ""
+         trace_id = context.get("trace_id") or str(uuid.uuid4())
+
+         normalized_papers = await self._prepare_papers(papers)
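+         # Serve identical requests from the in-process TTL cache before
+         # paying for another LLM call.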
+         cache_key = self._make_cache_key(normalized_papers, max_words, style, focus_terms)
+         cached = await self._read_cache(cache_key)
+         if cached is not None:
+             cached["routing_metadata"]["cached"] = True
+             cached["trace_id"] = trace_id
+             return cached
+
+         llm_prompt = self._build_prompt(len(normalized_papers), max_words, style, focus_terms)
+         llm_context = [paper.to_context_block() for paper in normalized_papers]
+         llm_result = await self.llm.generate_synthesis(llm_context, llm_prompt)
+         summary_text = llm_result.get("summary", "")
+
+         key_findings = self._extract_key_findings(summary_text, max_items=6)
+         if include_topic_modeling:
+             try:
+                 keywords = await self.performance.extract_keywords(summary_text, max_keywords=12)
+             except Exception:
+                 keywords = self._fallback_keywords(summary_text)
+         else:
+             keywords = self._fallback_keywords(summary_text)
+
+         citations = self._build_citations(normalized_papers)
+         confidence = self._estimate_confidence(normalized_papers, summary_text, key_findings)
+         domain_alignment = self._infer_domain(normalized_papers, focus_terms)
+         relevance_score = self._score_relevance(summary_text, focus_terms or context.get("original_query", ""))
+
+         metadata: Dict[str, Any] = {
+             "keywords": keywords,
+             "paper_sample_size": len(normalized_papers),
+             "domain_alignment": domain_alignment,
+             "confidence": confidence,
+             "generated_at": datetime.now(timezone.utc).isoformat(),
+         }
+
+         if include_visualizations:
+             metadata["visualizations"] = self._visualization_payload(normalized_papers)
+         if include_quality_assessment:
+             metadata["quality_assessment"] = self._quality_assessment(normalized_papers)
+
+         await self._upsert_knowledge_graph(normalized_papers, key_findings)
+
+         routing_metadata = {
+             "routing_decision": {
+                 "model": llm_result.get("model"),
+                 "provider": llm_result.get("provider"),
+                 "complexity": "advanced",
+                 "strategy": "advanced_synthesizer",
+             },
+             "usage": llm_result.get("usage", {}),
+             "latency": llm_result.get("latency"),
+             "cached": llm_result.get("cached", False),
+         }
+
+         result = {
+             "summary": summary_text,
+             "word_count": len(summary_text.split()),
+             "key_findings": key_findings,
+             "citations_used": citations,
+             "metadata": metadata,
+             "routing_metadata": routing_metadata,
+             "confidence": confidence,
+             "relevance_score": relevance_score,
+             "trace_id": trace_id,
+         }
+
+         await self._write_cache(cache_key, result)
+         return result
+
+     # ------------------------------------------------------------------
+     async def _prepare_papers(self, raw_papers: Sequence[Dict[str, Any]]) -> List[SynthesizerPaper]:
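+         # Normalise all payloads concurrently; any that fail are re-fetched
+         # from OpenAlex in one bulk request below.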
+         tasks = [self._normalize_single_paper(paper) for paper in raw_papers]
+         results = await asyncio.gather(*tasks, return_exceptions=True)
+         normalized: List[SynthesizerPaper] = []
+         missing_ids: List[str] = []
+         for entry, raw in zip(results, raw_papers):
+             if isinstance(entry, SynthesizerPaper):
+                 normalized.append(entry)
+             else:
+                 paper_id = str(raw.get("id") or raw.get("paper_id") or raw.get("doi") or uuid.uuid4())
+                 missing_ids.append(paper_id)
+         if missing_ids:
+             fetched = await self.openalex.get_papers_bulk(missing_ids)
+             for payload in fetched:
+                 normalized.append(await self._normalize_single_paper(payload, allow_fetch=False))
+         normalized = [paper for paper in normalized if paper.abstract]
+         deduped: Dict[str, SynthesizerPaper] = {}
+         for paper in normalized:
+             deduped[paper.paper_id] = paper
+         return list(deduped.values())
+
+     async def _normalize_single_paper(self, payload: Dict[str, Any], *, allow_fetch: bool = True) -> SynthesizerPaper:
+         paper_id = str(payload.get("id") or payload.get("paper_id") or payload.get("doi") or uuid.uuid4())
+         title = payload.get("title") or payload.get("display_name") or f"Paper {paper_id}"
+         abstract = self._extract_abstract(payload)
+         if not abstract and allow_fetch:
+             fetched = await self.openalex.get_paper_by_id(paper_id)
+             return await self._normalize_single_paper(fetched or {"id": paper_id}, allow_fetch=False)
+         doi = (payload.get("doi") or "").replace("https://doi.org/", "")
+         authors = self._extract_authors(payload)
+         keywords = self._extract_keywords(payload)
+         url = payload.get("id") or payload.get("url") or f"https://openalex.org/{paper_id}"
+         fallback = bool(payload.get("fallback"))
+         return SynthesizerPaper(
+             paper_id=paper_id,
+             title=title,
+             abstract=abstract,
+             year=payload.get("publication_year") or payload.get("year"),
+             doi=doi,
+             authors=authors,
+             keywords=keywords,
+             url=url,
+             fallback=fallback,
+         )
+
+     def _build_prompt(self, paper_count: int, max_words: int, style: str, focus: str) -> str:
+         focus_clause = f"Focus on: {focus}." if focus else ""
+         return (
+             "Synthesise the following {paper_count} scholarly papers into a {style} briefing. "
+             "Highlight methodological strengths, contradictions, effect sizes, and remaining research gaps. "
+             "Return no more than {max_words} words. {focus_clause}"
+         ).format(paper_count=paper_count, style=style, max_words=max_words, focus_clause=focus_clause)
+
+     def _extract_key_findings(self, summary: str, *, max_items: int) -> List[str]:
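+         # Prefer bullet-like lines that mention evidence or effect direction;
+         # fall back to the summary's leading sentences when none qualify.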
+         lines = [line.strip("-• ") for line in summary.splitlines() if line.strip()]
+         findings: List[str] = []
+         for line in lines:
+             if len(findings) >= max_items:
+                 break
+             if len(line.split()) < 6:
+                 continue
+             if any(keyword in line.lower() for keyword in ("finding", "evidence", "increase", "decrease", "%")):
+                 findings.append(line)
+             elif line.endswith("."):
+                 findings.append(line)
+         if not findings:
+             sentences = summary.split(". ")
+             findings = [sentence.strip() for sentence in sentences[:max_items] if sentence.strip()]
+         return findings
+
+     def _build_citations(self, papers: Iterable[SynthesizerPaper]) -> Dict[str, str]:
+         citations: Dict[str, str] = {}
+         for idx, paper in enumerate(papers, start=1):
+             reference = paper.url or ("https://doi.org/" + paper.doi if paper.doi else paper.title)
+             citations[f"[{idx}]"] = reference
+         return citations
+
+     def _estimate_confidence(self, papers: Sequence[SynthesizerPaper], summary: str, findings: Sequence[str]) -> float:
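+         # Heuristic: start at 0.5, add bonuses for corpus size (log-scaled) and
+         # finding count, subtract penalties for fallback sources and very short
+         # summaries, then clamp to [0.1, 0.99].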
+         fallback_penalty = 0.2 if any(paper.fallback for paper in papers) else 0.0
+         diversity_bonus = min(0.3, math.log10(len(papers) + 1) / 2)
+         finding_bonus = min(0.3, len(findings) / 10)
+         length_penalty = 0.1 if len(summary.split()) < 150 else 0.0
+         score = 0.5 + diversity_bonus + finding_bonus - fallback_penalty - length_penalty
+         return max(0.1, min(0.99, round(score, 3)))
+
+     def _infer_domain(self, papers: Sequence[SynthesizerPaper], focus: str) -> str:
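+         # Coarse keyword-bucket routing; corpora that match no known domain
+         # terms fall back to the first keyword (truncated) or "general".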
+         keywords = [kw.lower() for paper in papers for kw in paper.keywords]
+         if focus:
+             keywords.extend(focus.lower().split())
+         if not keywords:
+             return "general"
+         if any(term in keywords for term in ("finance", "market", "equity", "stock")):
+             return "quantitative_finance"
+         if any(term in keywords for term in ("polymer", "resin", "manufacturing", "materials")):
+             return "advanced_materials"
+         if any(term in keywords for term in ("nlp", "language", "transformer", "model")):
+             return "ai_research"
+         return keywords[0][:32]
+
+     def _score_relevance(self, summary: str, query: str) -> float:
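+         # Fraction of distinct query tokens (longer than 3 characters) that
+         # appear verbatim in the summary.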
+         if not summary or not query:
+             return 0.0
+         summary_lower = summary.lower()
+         tokens = {token for token in query.lower().split() if len(token) > 3}
+         if not tokens:
+             return 0.0
+         matches = sum(1 for token in tokens if token in summary_lower)
+         return round(matches / len(tokens), 3)
+
+     def _visualization_payload(self, papers: Sequence[SynthesizerPaper]) -> Dict[str, Any]:
+         years = [paper.year for paper in papers if paper.year]
+         histogram: Dict[int, int] = {}
+         for year in years:
+             histogram[year] = histogram.get(year, 0) + 1
+         return {
+             "publication_histogram": histogram,
+             "paper_urls": [paper.url for paper in papers],
+         }
+
+     def _quality_assessment(self, papers: Sequence[SynthesizerPaper]) -> Dict[str, Any]:
+         with_doi = [paper for paper in papers if paper.doi]
+         avg_year = statistics.mean([paper.year for paper in papers if paper.year] or [datetime.now().year])
+         return {
+             "avg_publication_year": round(avg_year, 1),
+             "doi_coverage": round(len(with_doi) / max(1, len(papers)), 2),
+             "sample_size": len(papers),
+         }
+
+     async def _upsert_knowledge_graph(self, papers: Sequence[SynthesizerPaper], findings: Sequence[str]) -> None:
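+         # Best-effort enrichment: upsert Paper/Author/Finding entities plus
+         # authored/supports edges; any failure is logged and swallowed.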
+         try:
+             for paper in papers:
+                 entity_id = await self.kg.upsert_entity(
+                     "Paper",
+                     {
+                         "id": paper.paper_id,
+                         "title": paper.title,
+                         "year": paper.year,
+                         "doi": paper.doi,
+                         "keywords": paper.keywords,
+                     },
+                 )
+                 for author in paper.authors:
+                     author_id = await self.kg.upsert_entity("Author", {"name": author})
+                     await self.kg.upsert_relationship("authored", author_id, entity_id)
+             for finding in findings:
+                 finding_id = await self.kg.upsert_entity("Finding", {"summary": finding})
+                 for paper in papers:
+                     await self.kg.upsert_relationship("supports", paper.paper_id, finding_id)
+         except Exception as exc:  # pragma: no cover - KG is best-effort
+             logger.info("Knowledge graph enrichment failed", extra={"error": str(exc)})
+
+     def _fallback_keywords(self, text: str) -> List[str]:
+         import re
+         from collections import Counter
+
+         if not text:
+             return []
+         words = re.findall(r"[a-zA-Z]{4,}", text.lower())
+         stop_words = {"this", "that", "with", "from", "into", "their", "which"}
+         filtered = [word for word in words if word not in stop_words]
+         return [word for word, _ in Counter(filtered).most_common(10)]
+
+     def _extract_abstract(self, payload: Dict[str, Any]) -> str:
+         inverted = payload.get("abstract_inverted_index")
+         if isinstance(inverted, dict):
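+             # OpenAlex ships abstracts as an inverted index (token -> list of
+             # positions); rebuild the text by ordering tokens by position.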
+             index_map: Dict[int, str] = {}
+             for token, positions in inverted.items():
+                 for pos in positions:
+                     index_map[pos] = token
+             return " ".join(index_map[idx] for idx in sorted(index_map))
+         return payload.get("abstract", "") or payload.get("summary", "") or ""
+
+     def _extract_authors(self, payload: Dict[str, Any]) -> List[str]:
+         authors: List[str] = []
+         for authorship in payload.get("authorships", []):
+             if not isinstance(authorship, dict):
+                 continue
+             name = authorship.get("author", {}).get("display_name")
+             if name:
+                 authors.append(name)
+         if not authors and payload.get("authors"):
+             authors.extend([str(author) for author in payload.get("authors", [])])
+         return authors
+
+     def _extract_keywords(self, payload: Dict[str, Any]) -> List[str]:
+         keywords: List[str] = []
+         for concept in payload.get("concepts", []) or []:
+             if isinstance(concept, dict) and concept.get("display_name"):
+                 keywords.append(concept["display_name"])
+         if payload.get("keywords"):
+             keywords.extend([str(keyword) for keyword in payload.get("keywords", [])])
+         return keywords[:10]
+
+     def _make_cache_key(self, papers: Sequence[SynthesizerPaper], max_words: int, style: str, focus: str) -> str:
+         paper_ids = ",".join(sorted(paper.paper_id for paper in papers))
+         return f"synth:{paper_ids}:{max_words}:{style}:{focus}".lower()
+
+     async def _read_cache(self, key: str) -> Optional[Dict[str, Any]]:
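+         # Expired entries are evicted lazily on read; hits return a shallow
+         # copy of the stored payload.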
+         async with self._cache_lock:
+             entry = self._cache.get(key)
+             if not entry:
+                 return None
+             expires_at, value = entry
+             if datetime.now(timezone.utc).timestamp() > expires_at:
+                 self._cache.pop(key, None)
+                 return None
+             return dict(value)
+
+     async def _write_cache(self, key: str, value: Dict[str, Any], ttl: int = 900) -> None:
+         async with self._cache_lock:
+             self._cache[key] = (datetime.now(timezone.utc).timestamp() + ttl, dict(value))
+
+
+ __all__ = ["EnhancedSynthesizer"]
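
For context, here is a minimal usage sketch of the new synthesizer, not taken from the package itself. It assumes the wheel's src layout is importable, that the default LLMManager can reach Redis at redis://localhost:6379, and uses an illustrative OpenAlex-style payload with a placeholder work ID.

    import asyncio

    from src.services.research_service.enhanced_synthesizer import EnhancedSynthesizer

    async def main() -> None:
        synthesizer = EnhancedSynthesizer()  # default LLM, OpenAlex, KG, and perf services
        papers = [
            {
                "id": "https://openalex.org/W0000000000",  # placeholder work ID
                "title": "An example paper",
                "abstract": "We study synthesis quality across research corpora.",
                "publication_year": 2021,
                "authorships": [{"author": {"display_name": "A. Researcher"}}],
            }
        ]
        result = await synthesizer.synthesize_research(
            papers=papers,
            max_words=300,
            style="executive",
            context={"focus": "methodology"},
        )
        print(result["summary"])
        print(result["confidence"], result["relevance_score"])

    asyncio.run(main())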