cite-agent 1.0.3-py3-none-any.whl → 1.0.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cite-agent might be problematic.
- cite_agent/__init__.py +1 -1
- cite_agent/agent_backend_only.py +30 -4
- cite_agent/cli.py +24 -26
- cite_agent/cli_conversational.py +294 -0
- cite_agent/enhanced_ai_agent.py +2776 -118
- cite_agent/streaming_ui.py +252 -0
- {cite_agent-1.0.3.dist-info → cite_agent-1.0.5.dist-info}/METADATA +4 -3
- cite_agent-1.0.5.dist-info/RECORD +50 -0
- {cite_agent-1.0.3.dist-info → cite_agent-1.0.5.dist-info}/top_level.txt +1 -0
- src/__init__.py +1 -0
- src/services/__init__.py +132 -0
- src/services/auth_service/__init__.py +3 -0
- src/services/auth_service/auth_manager.py +33 -0
- src/services/graph/__init__.py +1 -0
- src/services/graph/knowledge_graph.py +194 -0
- src/services/llm_service/__init__.py +5 -0
- src/services/llm_service/llm_manager.py +495 -0
- src/services/paper_service/__init__.py +5 -0
- src/services/paper_service/openalex.py +231 -0
- src/services/performance_service/__init__.py +1 -0
- src/services/performance_service/rust_performance.py +395 -0
- src/services/research_service/__init__.py +23 -0
- src/services/research_service/chatbot.py +2056 -0
- src/services/research_service/citation_manager.py +436 -0
- src/services/research_service/context_manager.py +1441 -0
- src/services/research_service/conversation_manager.py +597 -0
- src/services/research_service/critical_paper_detector.py +577 -0
- src/services/research_service/enhanced_research.py +121 -0
- src/services/research_service/enhanced_synthesizer.py +375 -0
- src/services/research_service/query_generator.py +777 -0
- src/services/research_service/synthesizer.py +1273 -0
- src/services/search_service/__init__.py +5 -0
- src/services/search_service/indexer.py +186 -0
- src/services/search_service/search_engine.py +342 -0
- src/services/simple_enhanced_main.py +287 -0
- cite_agent/__distribution__.py +0 -7
- cite_agent-1.0.3.dist-info/RECORD +0 -23
- {cite_agent-1.0.3.dist-info → cite_agent-1.0.5.dist-info}/WHEEL +0 -0
- {cite_agent-1.0.3.dist-info → cite_agent-1.0.5.dist-info}/entry_points.txt +0 -0
- {cite_agent-1.0.3.dist-info → cite_agent-1.0.5.dist-info}/licenses/LICENSE +0 -0
src/services/graph/knowledge_graph.py
@@ -0,0 +1,194 @@
+"""Lightweight async knowledge graph implementation used by the research synthesizer.
+
+The production design originally assumed an external graph database, but the launch-ready
+runtime needs a dependable in-process implementation that works without external services.
+This module provides a minimal yet functional directed multigraph using in-memory storage.
+
+The implementation focuses on the operations exercised by ``ResearchSynthesizer``:
+
+* ``upsert_entity`` – register/update an entity node with typed metadata
+* ``upsert_relationship`` – connect two entities with rich relationship properties
+* ``get_entity`` / ``get_relationships`` – helper APIs for diagnostics and future features
+
+Data is persisted in memory and optionally mirrored to a JSON file on disk so the graph can
+survive multiple sessions during local development. All public methods are ``async`` to keep
+parity with the historical interface and to allow easy replacement with an external graph
+backend in the future.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import json
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+
+__all__ = ["KnowledgeGraph", "GraphEntity", "GraphRelationship"]
+
+
+@dataclass
+class GraphEntity:
+    """Represents a node in the knowledge graph."""
+
+    entity_id: str
+    entity_type: str
+    properties: Dict[str, Any] = field(default_factory=dict)
+
+    def to_dict(self) -> Dict[str, Any]:
+        return {
+            "id": self.entity_id,
+            "type": self.entity_type,
+            "properties": self.properties,
+        }
+
+
+@dataclass
+class GraphRelationship:
+    """Represents a directed, typed relationship between two entities."""
+
+    rel_type: str
+    source_id: str
+    target_id: str
+    properties: Dict[str, Any] = field(default_factory=dict)
+
+    def to_dict(self) -> Dict[str, Any]:
+        return {
+            "type": self.rel_type,
+            "source": self.source_id,
+            "target": self.target_id,
+            "properties": self.properties,
+        }
+
+
+class KnowledgeGraph:
+    """A simple async-safe in-memory knowledge graph."""
+
+    def __init__(self, *, persistence_path: Optional[Path] = None) -> None:
+        self._entities: Dict[str, GraphEntity] = {}
+        # Adjacency list keyed by (source_id, rel_type) -> list[target_id, props]
+        self._relationships: List[GraphRelationship] = []
+        self._lock = asyncio.Lock()
+        self._persistence_path = persistence_path
+        if self._persistence_path:
+            self._load_from_disk()
+
+    # ------------------------------------------------------------------
+    # Persistence helpers
+    # ------------------------------------------------------------------
+    def _load_from_disk(self) -> None:
+        if not self._persistence_path or not self._persistence_path.exists():
+            return
+        try:
+            payload = json.loads(self._persistence_path.read_text())
+        except Exception:
+            return
+
+        for entity in payload.get("entities", []):
+            graph_entity = GraphEntity(
+                entity_id=entity["id"],
+                entity_type=entity.get("type", "Unknown"),
+                properties=entity.get("properties", {}),
+            )
+            self._entities[graph_entity.entity_id] = graph_entity
+
+        for rel in payload.get("relationships", []):
+            graph_rel = GraphRelationship(
+                rel_type=rel.get("type", "related_to"),
+                source_id=rel.get("source"),
+                target_id=rel.get("target"),
+                properties=rel.get("properties", {}),
+            )
+            self._relationships.append(graph_rel)
+
+    def _persist(self) -> None:
+        if not self._persistence_path:
+            return
+        data = {
+            "entities": [entity.to_dict() for entity in self._entities.values()],
+            "relationships": [rel.to_dict() for rel in self._relationships],
+        }
+        try:
+            self._persistence_path.parent.mkdir(parents=True, exist_ok=True)
+            self._persistence_path.write_text(json.dumps(data, indent=2, sort_keys=True))
+        except Exception:
+            # Persistence failures should never stop the conversation flow
+            pass
+
+    # ------------------------------------------------------------------
+    # Public API
+    # ------------------------------------------------------------------
+    async def upsert_entity(self, entity_type: str, properties: Dict[str, Any]) -> str:
+        """Create or update an entity.
+
+        Args:
+            entity_type: Semantic type (e.g., "Paper", "Author").
+            properties: Arbitrary metadata. ``properties['id']`` is optional; when missing
+                a deterministic identifier is derived from ``properties['external_id']`` or
+                a hash of the payload.
+        Returns:
+            The entity identifier stored in the graph.
+        """
+
+        async with self._lock:
+            entity_id = _determine_entity_id(entity_type, properties)
+            entity = self._entities.get(entity_id)
+            if entity:
+                entity.properties.update(properties)
+            else:
+                entity = GraphEntity(entity_id=entity_id, entity_type=entity_type, properties=properties)
+                self._entities[entity_id] = entity
+            self._persist()
+            return entity_id
+
+    async def upsert_relationship(
+        self,
+        rel_type: str,
+        source_id: str,
+        target_id: str,
+        properties: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[str, str, str]:
+        """Create or update a directed relationship between two entities."""
+
+        properties = properties or {}
+        async with self._lock:
+            relationship = GraphRelationship(
+                rel_type=rel_type,
+                source_id=source_id,
+                target_id=target_id,
+                properties=properties,
+            )
+            self._relationships.append(relationship)
+            self._persist()
+            return (relationship.rel_type, relationship.source_id, relationship.target_id)
+
+    async def get_entity(self, entity_id: str) -> Optional[GraphEntity]:
+        async with self._lock:
+            return self._entities.get(entity_id)
+
+    async def get_relationships(self, entity_id: str) -> List[GraphRelationship]:
+        async with self._lock:
+            return [rel for rel in self._relationships if rel.source_id == entity_id or rel.target_id == entity_id]
+
+    async def stats(self) -> Dict[str, Any]:
+        async with self._lock:
+            return {
+                "entities": len(self._entities),
+                "relationships": len(self._relationships),
+            }
+
+
+def _determine_entity_id(entity_type: str, properties: Dict[str, Any]) -> str:
+    """Best-effort deterministic identifier for an entity."""
+
+    # Preferred explicit IDs
+    for key in ("id", "external_id", "paper_id", "author_id", "identifier"):
+        value = properties.get(key)
+        if value:
+            return str(value)
+
+    # Fall back to hashed representation (order-stable via JSON dumps)
+    import hashlib
+
+    payload = json.dumps({"type": entity_type, "properties": properties}, sort_keys=True)
+    return f"{entity_type}:{hashlib.md5(payload.encode('utf-8')).hexdigest()}"
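For orientation, here is a minimal usage sketch of the module above; it is not part of the package diff. It exercises ``upsert_entity``, ``upsert_relationship``, ``get_relationships``, and ``stats`` exactly as they appear in the hunk, but the entity types, property values, and the ``graph.json`` persistence path are illustrative assumptions, and the import path simply mirrors the ``src/services`` layout added in this release.

import asyncio
from pathlib import Path

from src.services.graph.knowledge_graph import KnowledgeGraph


async def demo() -> None:
    # Optional on-disk mirror; the path here is purely illustrative.
    graph = KnowledgeGraph(persistence_path=Path("graph.json"))

    # Explicit "id" keys give deterministic identifiers; without one,
    # _determine_entity_id falls back to a hash of the payload.
    paper_id = await graph.upsert_entity("Paper", {"id": "W1234", "title": "Example study"})
    author_id = await graph.upsert_entity("Author", {"id": "A5678", "name": "J. Doe"})

    # Relationships are stored as a flat list of directed, typed edges.
    await graph.upsert_relationship("authored_by", paper_id, author_id, {"position": "first"})

    print(await graph.stats())                     # {'entities': 2, 'relationships': 1}
    print(await graph.get_relationships(paper_id))


asyncio.run(demo())

Because every public method takes the same ``asyncio.Lock``, concurrent upserts from the synthesizer serialize cleanly, and swapping in an external graph backend later only requires preserving these coroutine signatures.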
src/services/llm_service/llm_manager.py
@@ -0,0 +1,495 @@
+"""Unified large language model management utilities.
+
+This module exposes :class:`LLMManager`, a production-ready orchestration layer that
+coordinates multiple LLM providers (Groq, OpenAI, Anthropic) while providing
+advanced routing, caching, observability, and graceful fallbacks when GPU-backed
+models are unavailable. The implementation is intentionally defensive: it never
+raises provider-specific exceptions to callers and instead downgrades to a
+high-quality heuristic summariser so the broader research pipeline can continue
+functioning in constrained environments (including unit tests).
+"""
+
+from __future__ import annotations
+
+import asyncio
+import hashlib
+import logging
+import os
+import time
+from dataclasses import dataclass
+from datetime import datetime, timezone
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+logger = logging.getLogger(__name__)
+
+try:  # Optional dependency – only loaded when available
+    from groq import Groq  # type: ignore
+except Exception:  # pragma: no cover - optional provider
+    Groq = None  # type: ignore
+
+try:  # OpenAI >=1.x
+    from openai import AsyncOpenAI  # type: ignore
+except Exception:  # pragma: no cover - optional provider
+    AsyncOpenAI = None  # type: ignore
+
+try:  # Anthropic python client
+    from anthropic import AsyncAnthropic  # type: ignore
+except Exception:  # pragma: no cover - optional provider
+    AsyncAnthropic = None  # type: ignore
+
+# Default models for each provider. These can be overridden via environment
+# variables or method arguments but serve as sensible, production-tested
+# defaults that balance latency and quality.
+DEFAULT_PROVIDER_MODELS: Dict[str, str] = {
+    "groq": os.getenv("NA_GROQ_MODEL", "llama-3.1-70b-versatile"),
+    "openai": os.getenv("NA_OPENAI_MODEL", "gpt-4.1-mini"),
+    "anthropic": os.getenv("NA_ANTHROPIC_MODEL", "claude-3-5-sonnet-20241022"),
+}
+
+# Maximum tokens for synthesis generations. Exposed for easy tuning via env.
+DEFAULT_MAX_TOKENS = int(os.getenv("NA_MAX_SYNTHESIS_TOKENS", "2048"))
+DEFAULT_TEMPERATURE = float(os.getenv("NA_SYNTHESIS_TEMPERATURE", "0.2"))
+
+
+@dataclass(slots=True)
+class ProviderSelection:
+    """Information about the provider/model combination chosen for a request."""
+
+    provider: str
+    model: str
+    reason: str
+
+
+class LLMManager:
+    """Unified interface across Groq, OpenAI, Anthropic, and heuristic fallbacks.
+
+    The manager exposes a coroutine-based API that can be safely used inside
+    FastAPI endpoints or background workers. Each call records latency and
+    usage metadata which is returned to callers so that higher levels can make
+    routing decisions or surface telemetry.
+    """
+
+    _PROVIDER_ENV_KEYS: Dict[str, Tuple[str, ...]] = {
+        "groq": ("GROQ_API_KEY", "NA_GROQ_API_KEY"),
+        "openai": ("OPENAI_API_KEY", "NA_OPENAI_API_KEY"),
+        "anthropic": ("ANTHROPIC_API_KEY", "NA_ANTHROPIC_API_KEY"),
+    }
+
+    def __init__(
+        self,
+        *,
+        redis_url: str = os.getenv("REDIS_URL", "redis://localhost:6379"),
+        default_provider: Optional[str] = None,
+        default_model: Optional[str] = None,
+        cache_ttl: int = 900,
+    ) -> None:
+        self.redis_url = redis_url
+        self._default_provider = (default_provider or os.getenv("NA_LLM_PROVIDER") or "groq").lower()
+        self._default_model = default_model or DEFAULT_PROVIDER_MODELS.get(self._default_provider, "")
+        self._cache_ttl = cache_ttl
+        self._cache: Dict[str, Tuple[float, Dict[str, Any]]] = {}
+        self._cache_lock = asyncio.Lock()
+        self._client_lock = asyncio.Lock()
+        self._clients: Dict[str, Any] = {}
+        self._last_health_check: Dict[str, Dict[str, Any]] = {}
+
+        # Lazily-created loop for running sync provider clients (Groq) in a
+        # thread pool. We reuse the default loop to avoid spawning threads per
+        # request.
+        self._loop = asyncio.get_event_loop()
+
+    # ------------------------------------------------------------------
+    # Public API
+    # ------------------------------------------------------------------
+    async def generate_synthesis(
+        self,
+        documents: Iterable[Dict[str, Any]],
+        prompt: str,
+        *,
+        provider: Optional[str] = None,
+        model: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+    ) -> Dict[str, Any]:
+        """Generate a synthesis across documents using the best available LLM.
+
+        Returns a dictionary containing the summary, metadata about the route
+        taken, usage information, and latency. The structure is intentionally
+        aligned with what the API layer expects when presenting advanced
+        synthesis results.
+        """
+
+        documents = list(documents or [])
+        serialized_context = self._serialize_documents(documents)
+        cache_key = self._make_cache_key("synthesis", serialized_context, prompt, provider, model)
+
+        cached = await self._read_cache(cache_key)
+        if cached is not None:
+            cached_copy = dict(cached)
+            cached_copy["cached"] = True
+            return cached_copy
+
+        selection = await self._select_provider(provider, model)
+        start = time.perf_counter()
+
+        try:
+            summary, usage = await self._invoke_provider(
+                selection,
+                self._build_messages(serialized_context, prompt),
+                temperature or DEFAULT_TEMPERATURE,
+                max_tokens or DEFAULT_MAX_TOKENS,
+            )
+        except Exception as exc:  # pragma: no cover - defensive guard
+            logger.warning(
+                "LLM provider invocation failed; falling back to heuristic",
+                extra={"provider": selection.provider, "model": selection.model, "error": str(exc)},
+            )
+            summary = self._heuristic_summary(serialized_context, prompt)
+            usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, "fallback": True}
+            selection = ProviderSelection(provider="heuristic", model="text-rank", reason=str(exc))
+
+        latency = time.perf_counter() - start
+        result = {
+            "summary": summary.strip(),
+            "provider": selection.provider,
+            "model": selection.model,
+            "reason": selection.reason,
+            "usage": usage,
+            "latency": latency,
+            "timestamp": datetime.now(timezone.utc).isoformat(),
+            "cached": False,
+        }
+
+        await self._write_cache(cache_key, result)
+        return result
+
+    async def generate_text(
+        self,
+        prompt: str,
+        *,
+        provider: Optional[str] = None,
+        model: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+    ) -> str:
+        """Generate free-form text using the same routing heuristics."""
+
+        result = await self.generate_synthesis(
+            documents=[],
+            prompt=prompt,
+            provider=provider,
+            model=model,
+            temperature=temperature,
+            max_tokens=max_tokens,
+        )
+        return result.get("summary", "")
+
+    async def health_check(self) -> Dict[str, Any]:
+        """Return current provider availability and cached connectivity info."""
+
+        statuses = {}
+        for provider in ("groq", "openai", "anthropic"):
+            statuses[provider] = {
+                "configured": self._get_api_key(provider) is not None,
+                "client_initialized": provider in self._clients,
+                "last_error": None,
+            }
+
+        self._last_health_check = statuses
+        return statuses
+
+    async def close(self) -> None:
+        """Close any underlying async clients (OpenAI/Anthropic)."""
+
+        async with self._client_lock:
+            openai_client = self._clients.get("openai")
+            if openai_client and hasattr(openai_client, "close"):
+                try:
+                    await openai_client.close()  # type: ignore[attr-defined]
+                except Exception:
+                    pass
+            anthropic_client = self._clients.get("anthropic")
+            if anthropic_client and hasattr(anthropic_client, "close"):
+                try:
+                    await anthropic_client.close()  # type: ignore[attr-defined]
+                except Exception:
+                    pass
+            # Groq client is synchronous – nothing to close
+            self._clients.clear()
+
+    # ------------------------------------------------------------------
+    # Provider selection & invocation
+    # ------------------------------------------------------------------
+    async def _select_provider(
+        self,
+        provider: Optional[str],
+        model: Optional[str],
+    ) -> ProviderSelection:
+        """Select the best available provider/model pair for the request."""
+
+        candidate_order = []
+        if provider:
+            candidate_order.append(provider.lower())
+        if self._default_provider not in candidate_order:
+            candidate_order.append(self._default_provider)
+        candidate_order.extend(["groq", "openai", "anthropic"])
+
+        seen = set()
+        for candidate in candidate_order:
+            if candidate in seen:
+                continue
+            seen.add(candidate)
+            api_key = self._get_api_key(candidate)
+            if not api_key:
+                continue
+            selected_model = model or self._default_model or DEFAULT_PROVIDER_MODELS.get(candidate)
+            if not selected_model:
+                continue
+            if await self._ensure_client(candidate, api_key):
+                reason = "requested" if candidate == provider else "fallback"
+                return ProviderSelection(provider=candidate, model=selected_model, reason=reason)
+
+        logger.warning("No LLM providers configured; using heuristic summariser")
+        return ProviderSelection(provider="heuristic", model="text-rank", reason="no-provider-configured")
+
+    async def _ensure_client(self, provider: str, api_key: str) -> bool:
+        """Instantiate and cache provider clients lazily."""
+
+        if provider == "heuristic":
+            return True
+
+        async with self._client_lock:
+            if provider in self._clients:
+                return True
+
+            try:
+                if provider == "groq":
+                    if Groq is None:  # pragma: no cover - optional provider
+                        raise RuntimeError("groq package not installed")
+                    self._clients[provider] = Groq(api_key=api_key)
+                    return True
+
+                if provider == "openai":
+                    if AsyncOpenAI is None:  # pragma: no cover - optional provider
+                        raise RuntimeError("openai package not installed")
+                    self._clients[provider] = AsyncOpenAI(api_key=api_key)
+                    return True
+
+                if provider == "anthropic":
+                    if AsyncAnthropic is None:  # pragma: no cover - optional provider
+                        raise RuntimeError("anthropic package not installed")
+                    self._clients[provider] = AsyncAnthropic(api_key=api_key)
+                    return True
+
+                raise ValueError(f"Unknown provider: {provider}")
+            except Exception as exc:  # pragma: no cover - provider bootstrap is optional
+                logger.warning("Failed to initialise LLM provider", extra={"provider": provider, "error": str(exc)})
+                self._clients.pop(provider, None)
+                return False
+
+    async def _invoke_provider(
+        self,
+        selection: ProviderSelection,
+        messages: List[Dict[str, Any]],
+        temperature: float,
+        max_tokens: int,
+    ) -> Tuple[str, Dict[str, Any]]:
+        """Invoke the selected provider and normalise the response."""
+
+        if selection.provider == "heuristic":
+            return self._heuristic_summary(messages[-1]["content"], ""), {  # type: ignore[index]
+                "prompt_tokens": 0,
+                "completion_tokens": 0,
+                "total_tokens": 0,
+                "fallback": True,
+            }
+
+        client = self._clients.get(selection.provider)
+        if client is None:
+            raise RuntimeError(f"Provider {selection.provider} not initialised")
+
+        if selection.provider == "groq":
+            return await self._invoke_groq(client, selection.model, messages, temperature, max_tokens)
+        if selection.provider == "openai":
+            return await self._invoke_openai(client, selection.model, messages, temperature, max_tokens)
+        if selection.provider == "anthropic":
+            return await self._invoke_anthropic(client, selection.model, messages, temperature, max_tokens)
+
+        raise ValueError(f"Unsupported provider: {selection.provider}")
+
+    async def _invoke_groq(self, client: Any, model: str, messages: List[Dict[str, Any]], temperature: float, max_tokens: int) -> Tuple[str, Dict[str, Any]]:
+        """Invoke Groq's chat completion API (synchronous client)."""
+
+        def _call() -> Tuple[str, Dict[str, Any]]:
+            response = client.chat.completions.create(
+                model=model,
+                messages=messages,
+                temperature=temperature,
+                max_tokens=max_tokens,
+            )
+            message = response.choices[0].message.content if response.choices else ""
+            usage = getattr(response, "usage", None)
+            normalised_usage = {
+                "prompt_tokens": getattr(usage, "prompt_tokens", 0),
+                "completion_tokens": getattr(usage, "completion_tokens", 0),
+                "total_tokens": getattr(usage, "total_tokens", 0),
+            }
+            return message or "", normalised_usage
+
+        return await asyncio.to_thread(_call)
+
+    async def _invoke_openai(self, client: Any, model: str, messages: List[Dict[str, Any]], temperature: float, max_tokens: int) -> Tuple[str, Dict[str, Any]]:
+        response = await client.chat.completions.create(  # type: ignore[attr-defined]
+            model=model,
+            messages=messages,
+            temperature=temperature,
+            max_tokens=max_tokens,
+        )
+        choice = response.choices[0] if response.choices else None
+        message = choice.message.content if choice and choice.message else ""
+        usage = getattr(response, "usage", None) or {}
+        normalised_usage = {
+            "prompt_tokens": getattr(usage, "prompt_tokens", 0),
+            "completion_tokens": getattr(usage, "completion_tokens", 0),
+            "total_tokens": getattr(usage, "total_tokens", 0),
+        }
+        return message or "", normalised_usage
+
+    async def _invoke_anthropic(self, client: Any, model: str, messages: List[Dict[str, Any]], temperature: float, max_tokens: int) -> Tuple[str, Dict[str, Any]]:
+        system_prompt = """You are an advanced research assistant that creates meticulous literature syntheses."""
+        anthropic_messages = []
+        for msg in messages:
+            role = msg.get("role")
+            content = msg.get("content", "")
+            if role == "system":
+                system_prompt = content
+                continue
+            anthropic_messages.append({"role": "user" if role == "user" else "assistant", "content": content})
+
+        response = await client.messages.create(  # type: ignore[attr-defined]
+            model=model,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            system=system_prompt,
+            messages=anthropic_messages,
+        )
+        text = ""
+        if response.content:
+            content_block = response.content[0]
+            text = getattr(content_block, "text", "") or getattr(content_block, "input_text", "")
+        usage = getattr(response, "usage", None) or {}
+        normalised_usage = {
+            "prompt_tokens": getattr(usage, "input_tokens", 0),
+            "completion_tokens": getattr(usage, "output_tokens", 0),
+            "total_tokens": getattr(usage, "input_tokens", 0) + getattr(usage, "output_tokens", 0),
+        }
+        return text, normalised_usage
+
+    # ------------------------------------------------------------------
+    # Caching helpers
+    # ------------------------------------------------------------------
+    def _make_cache_key(self, namespace: str, *parts: Any) -> str:
+        digest = hashlib.sha256()
+        digest.update(namespace.encode("utf-8"))
+        for part in parts:
+            data = part if isinstance(part, str) else repr(part)
+            digest.update(b"|")
+            digest.update(data.encode("utf-8", errors="ignore"))
+        return digest.hexdigest()
+
+    async def _read_cache(self, key: str) -> Optional[Dict[str, Any]]:
+        async with self._cache_lock:
+            entry = self._cache.get(key)
+            if not entry:
+                return None
+            expires_at, value = entry
+            if time.time() > expires_at:
+                self._cache.pop(key, None)
+                return None
+            return dict(value)
+
+    async def _write_cache(self, key: str, value: Dict[str, Any]) -> None:
+        async with self._cache_lock:
+            self._cache[key] = (time.time() + self._cache_ttl, dict(value))
+
+    # ------------------------------------------------------------------
+    # Prompt + context utilities
+    # ------------------------------------------------------------------
+    def _serialize_documents(self, documents: List[Dict[str, Any]]) -> str:
+        if not documents:
+            return ""
+        blocks = []
+        for idx, document in enumerate(documents, start=1):
+            title = document.get("title") or document.get("name") or f"Document {idx}"
+            section_lines = [f"### {title}".strip()]
+            if document.get("authors"):
+                authors = ", ".join(
+                    a.get("name", "") if isinstance(a, dict) else str(a)
+                    for a in document.get("authors", [])
+                )
+                if authors:
+                    section_lines.append(f"*Authors:* {authors}")
+            if document.get("year"):
+                section_lines.append(f"*Year:* {document['year']}")
+            abstract = document.get("abstract") or document.get("content") or document.get("text")
+            if abstract:
+                section_lines.append("\n" + str(abstract).strip())
+            if document.get("highlights"):
+                section_lines.append("\nKey Findings:\n- " + "\n- ".join(map(str, document["highlights"])))
+            blocks.append("\n".join(section_lines).strip())
+        return "\n\n".join(blocks)
+
+    def _build_messages(self, serialized_context: str, prompt: str) -> List[Dict[str, Any]]:
+        system_prompt = (
+            "You are Nocturnal Archive's synthesis orchestrator. "
+            "Produce rigorous, citation-ready summaries that emphasise methodology, "
+            "effect sizes, limitations, and consensus versus disagreement."
+        )
+        user_prompt = (
+            prompt.format(context=serialized_context)
+            if "{context}" in prompt
+            else f"{prompt.strip()}\n\nContext:\n{serialized_context.strip()}"
+        ).strip()
+        return [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt},
+        ]
+
+    def _heuristic_summary(self, serialized_context: str, prompt: str) -> str:
+        """Fallback summariser using a TextRank-style scoring over sentences."""
+
+        import re
+        from collections import Counter, defaultdict
+
+        text = serialized_context or prompt
+        sentences = [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]
+        if not sentences:
+            return text.strip()
+
+        words = re.findall(r"[a-zA-Z0-9']+", text.lower())
+        frequencies = Counter(words)
+        max_freq = max(frequencies.values() or [1])
+        for key in frequencies:
+            frequencies[key] /= max_freq
+
+        sentence_scores: Dict[str, float] = defaultdict(float)
+        for sentence in sentences:
+            for word in re.findall(r"[a-zA-Z0-9']+", sentence.lower()):
+                sentence_scores[sentence] += frequencies.get(word, 0.0)
+
+        top_sentences = sorted(sentence_scores.items(), key=lambda kv: kv[1], reverse=True)[: min(5, len(sentences))]
+        ordered = sorted(top_sentences, key=lambda kv: sentences.index(kv[0]))
+        return " ".join(sentence for sentence, _ in ordered).strip()
+
+    # ------------------------------------------------------------------
+    # Misc utilities
+    # ------------------------------------------------------------------
+    def _get_api_key(self, provider: str) -> Optional[str]:
+        for env_key in self._PROVIDER_ENV_KEYS.get(provider, ()):  # type: ignore[arg-type]
+            value = os.getenv(env_key)
+            if value:
+                return value
+        return None
+
+
+__all__ = ["LLMManager"]
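And, again for orientation only, a sketch of how the manager above might be driven. The assumptions: the ``src`` package from this wheel is importable, no provider API key is set (so the heuristic summariser handles the request, as the module's own fallback logic describes), and the document dictionary shape follows ``_serialize_documents``. Routing to a hosted model is controlled by the environment variables read in the code (``GROQ_API_KEY`` / ``OPENAI_API_KEY`` / ``ANTHROPIC_API_KEY``, plus ``NA_LLM_PROVIDER`` and the ``NA_*_MODEL`` overrides).

import asyncio

from src.services.llm_service.llm_manager import LLMManager


async def demo() -> None:
    # With no provider API key configured, _select_provider returns the
    # "heuristic" route and the TextRank-style summariser produces the output.
    manager = LLMManager(cache_ttl=300)

    documents = [
        {
            # Illustrative document shape; keys match _serialize_documents().
            "title": "Example paper",
            "authors": [{"name": "J. Doe"}],
            "year": 2023,
            "abstract": "A short abstract used only for this sketch.",
            "highlights": ["Finding one", "Finding two"],
        }
    ]

    # "{context}" in the prompt is replaced with the serialized documents.
    result = await manager.generate_synthesis(
        documents,
        "Summarise the key findings.\n\n{context}",
    )
    print(result["provider"], result["model"], result["cached"])
    print(result["summary"])

    print(await manager.health_check())
    await manager.close()


asyncio.run(demo())

A second identical call within the 300-second TTL would be served from the in-memory cache with ``result["cached"]`` set to ``True``, which is how the module keeps repeated synthesis requests cheap.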