agent-brain-rag 1.2.0__py3-none-any.whl → 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {agent_brain_rag-1.2.0.dist-info → agent_brain_rag-3.0.0.dist-info}/METADATA +55 -18
- agent_brain_rag-3.0.0.dist-info/RECORD +56 -0
- {agent_brain_rag-1.2.0.dist-info → agent_brain_rag-3.0.0.dist-info}/WHEEL +1 -1
- {agent_brain_rag-1.2.0.dist-info → agent_brain_rag-3.0.0.dist-info}/entry_points.txt +0 -1
- agent_brain_server/__init__.py +1 -1
- agent_brain_server/api/main.py +146 -45
- agent_brain_server/api/routers/__init__.py +2 -0
- agent_brain_server/api/routers/health.py +85 -21
- agent_brain_server/api/routers/index.py +108 -36
- agent_brain_server/api/routers/jobs.py +111 -0
- agent_brain_server/config/provider_config.py +352 -0
- agent_brain_server/config/settings.py +22 -5
- agent_brain_server/indexing/__init__.py +21 -0
- agent_brain_server/indexing/bm25_index.py +15 -2
- agent_brain_server/indexing/document_loader.py +45 -4
- agent_brain_server/indexing/embedding.py +86 -135
- agent_brain_server/indexing/graph_extractors.py +582 -0
- agent_brain_server/indexing/graph_index.py +536 -0
- agent_brain_server/job_queue/__init__.py +11 -0
- agent_brain_server/job_queue/job_service.py +317 -0
- agent_brain_server/job_queue/job_store.py +427 -0
- agent_brain_server/job_queue/job_worker.py +434 -0
- agent_brain_server/locking.py +101 -8
- agent_brain_server/models/__init__.py +28 -0
- agent_brain_server/models/graph.py +253 -0
- agent_brain_server/models/health.py +30 -3
- agent_brain_server/models/job.py +289 -0
- agent_brain_server/models/query.py +16 -3
- agent_brain_server/project_root.py +1 -1
- agent_brain_server/providers/__init__.py +64 -0
- agent_brain_server/providers/base.py +251 -0
- agent_brain_server/providers/embedding/__init__.py +23 -0
- agent_brain_server/providers/embedding/cohere.py +163 -0
- agent_brain_server/providers/embedding/ollama.py +150 -0
- agent_brain_server/providers/embedding/openai.py +118 -0
- agent_brain_server/providers/exceptions.py +95 -0
- agent_brain_server/providers/factory.py +157 -0
- agent_brain_server/providers/summarization/__init__.py +41 -0
- agent_brain_server/providers/summarization/anthropic.py +87 -0
- agent_brain_server/providers/summarization/gemini.py +96 -0
- agent_brain_server/providers/summarization/grok.py +95 -0
- agent_brain_server/providers/summarization/ollama.py +114 -0
- agent_brain_server/providers/summarization/openai.py +87 -0
- agent_brain_server/runtime.py +2 -2
- agent_brain_server/services/indexing_service.py +39 -0
- agent_brain_server/services/query_service.py +203 -0
- agent_brain_server/storage/__init__.py +18 -2
- agent_brain_server/storage/graph_store.py +519 -0
- agent_brain_server/storage/vector_store.py +35 -0
- agent_brain_server/storage_paths.py +5 -3
- agent_brain_rag-1.2.0.dist-info/RECORD +0 -31
|
@@ -0,0 +1,536 @@
|
|
|
1
|
+
"""Graph index manager for GraphRAG (Feature 113).
|
|
2
|
+
|
|
3
|
+
Manages graph index building and querying for the knowledge graph.
|
|
4
|
+
Coordinates between extractors, graph store, and vector store.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
from datetime import datetime, timezone
|
|
9
|
+
from typing import Any, Callable, Optional
|
|
10
|
+
|
|
11
|
+
from agent_brain_server.config import settings
|
|
12
|
+
from agent_brain_server.indexing.graph_extractors import (
|
|
13
|
+
CodeMetadataExtractor,
|
|
14
|
+
LLMEntityExtractor,
|
|
15
|
+
get_code_extractor,
|
|
16
|
+
get_llm_extractor,
|
|
17
|
+
)
|
|
18
|
+
from agent_brain_server.models.graph import (
|
|
19
|
+
GraphIndexStatus,
|
|
20
|
+
GraphQueryContext,
|
|
21
|
+
GraphTriple,
|
|
22
|
+
)
|
|
23
|
+
from agent_brain_server.storage.graph_store import (
|
|
24
|
+
GraphStoreManager,
|
|
25
|
+
get_graph_store_manager,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
# Logger for this module; records are emitted under the module's dotted name.
logger: logging.Logger = logging.getLogger(name=__name__)


# Signature of progress hooks: callback(current, total, message) -> None.
ProgressCallback = Callable[[int, int, str], None]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class GraphIndexManager:
    """Manages graph index building and querying.

    Coordinates:
    - Entity extraction from documents (LLM and code metadata)
    - Triplet storage in GraphStoreManager
    - Graph-based retrieval for queries

    All public operations are no-ops (returning empty/zero values) when
    ``settings.ENABLE_GRAPH_INDEX`` is False.

    Attributes:
        graph_store: The underlying graph store manager.
        llm_extractor: LLM-based entity extractor.
        code_extractor: Code metadata extractor.
    """

    # Question/glue words that are never useful graph entities. Entity matching
    # is a case-insensitive substring test, so letting e.g. "The" or "How"
    # through would match unrelated nodes ("authenticate", "show_dialog").
    _QUERY_STOPWORDS: frozenset[str] = frozenset(
        {
            "what",
            "where",
            "when",
            "which",
            "that",
            "this",
            "have",
            "does",
            "with",
            "from",
            "about",
            "into",
            # Short words only reachable via the Capitalized / ALL_CAPS checks
            # (the lowercase pass already requires more than 3 characters).
            "the",
            "how",
            "why",
            "who",
            "and",
            "are",
            "can",
            "for",
            "was",
        }
    )

    def __init__(
        self,
        graph_store: Optional[GraphStoreManager] = None,
        llm_extractor: Optional[LLMEntityExtractor] = None,
        code_extractor: Optional[CodeMetadataExtractor] = None,
    ) -> None:
        """Initialize graph index manager.

        Args:
            graph_store: Graph store manager (defaults to singleton).
            llm_extractor: LLM extractor (defaults to singleton).
            code_extractor: Code extractor (defaults to singleton).
        """
        self.graph_store = graph_store or get_graph_store_manager()
        self.llm_extractor = llm_extractor or get_llm_extractor()
        self.code_extractor = code_extractor or get_code_extractor()
        self._last_build_time: Optional[datetime] = None
        self._last_triplet_count: int = 0

    def build_from_documents(
        self,
        documents: list[Any],
        progress_callback: Optional[ProgressCallback] = None,
    ) -> int:
        """Build graph index from documents.

        Extracts entities and relationships from each document chunk, adds
        them to the graph store, then persists the store.

        Args:
            documents: Document chunks (objects or dicts) with text and
                metadata.
            progress_callback: Optional callback(current, total, message),
                invoked once per document before it is processed.

        Returns:
            Number of triplets for which ``add_triplet`` reported success;
            0 when graph indexing is disabled.
        """
        if not settings.ENABLE_GRAPH_INDEX:
            logger.debug("Graph indexing disabled, skipping build")
            return 0

        # Initialize the store lazily on the first build.
        if not self.graph_store.is_initialized:
            self.graph_store.initialize()

        total_triplets = 0
        total_docs = len(documents)

        logger.info("Building graph index from %d documents", total_docs)

        for idx, doc in enumerate(documents, start=1):
            if progress_callback:
                progress_callback(
                    idx,
                    total_docs,
                    f"Extracting entities: {idx}/{total_docs}",
                )

            for triplet in self._extract_from_document(doc):
                added = self.graph_store.add_triplet(
                    subject=triplet.subject,
                    predicate=triplet.predicate,
                    obj=triplet.object,
                    subject_type=triplet.subject_type,
                    object_type=triplet.object_type,
                    source_chunk_id=triplet.source_chunk_id,
                )
                if added:
                    total_triplets += 1

        # Persist the graph and remember build statistics.
        self.graph_store.persist()
        self._last_build_time = datetime.now(timezone.utc)
        self._last_triplet_count = total_triplets

        logger.info(
            "Graph index built: %d triplets from %d docs", total_triplets, total_docs
        )

        return total_triplets

    def _extract_from_document(self, doc: Any) -> list[GraphTriple]:
        """Extract triplets from a single document.

        Code chunks (``metadata["source_type"] == "code"``) go through the code
        extractor when ``GRAPH_USE_CODE_METADATA`` is set. Pattern-based text
        extraction is added when a ``language`` is known. Any non-empty text
        also goes through the LLM extractor when ``GRAPH_USE_LLM_EXTRACTION``
        is set.

        Args:
            doc: Document with text content and metadata.

        Returns:
            List of GraphTriple objects (code-derived first, then LLM-derived).
        """
        triplets: list[GraphTriple] = []

        text = self._get_document_text(doc)
        metadata = self._get_document_metadata(doc)
        chunk_id = self._get_document_id(doc)
        source_type = metadata.get("source_type", "doc")
        language = metadata.get("language")

        # 1. Code metadata extraction (fast, deterministic).
        if source_type == "code" and settings.GRAPH_USE_CODE_METADATA:
            triplets.extend(
                self.code_extractor.extract_from_metadata(
                    metadata, source_chunk_id=chunk_id
                )
            )

            # Also try pattern-based extraction from the text.
            if language:
                triplets.extend(
                    self.code_extractor.extract_from_text(
                        text, language=language, source_chunk_id=chunk_id
                    )
                )

        # 2. LLM extraction (slower, more comprehensive).
        if settings.GRAPH_USE_LLM_EXTRACTION and text:
            triplets.extend(
                self.llm_extractor.extract_triplets(text, source_chunk_id=chunk_id)
            )

        return triplets

    def _get_document_text(self, doc: Any) -> str:
        """Get text content from a document object or dict.

        Checks, in order, ``doc.text``, ``doc.get_content()``,
        ``doc.page_content``, and the ``text``/``content`` keys of a dict.
        A ``None`` value yields "" rather than the string "None".
        """
        if hasattr(doc, "text"):
            raw = doc.text
        elif hasattr(doc, "get_content"):
            raw = doc.get_content()
        elif hasattr(doc, "page_content"):
            raw = doc.page_content
        elif isinstance(doc, dict):
            raw = doc.get("text", doc.get("content", ""))
            return str(raw) if raw else ""
        else:
            return str(doc)
        return "" if raw is None else str(raw)

    def _get_document_metadata(self, doc: Any) -> dict[str, Any]:
        """Get metadata from a document as a plain dict (``{}`` if absent)."""
        if hasattr(doc, "metadata"):
            meta = doc.metadata
            if hasattr(meta, "to_dict"):
                result = meta.to_dict()
                return dict(result) if result else {}
            elif isinstance(meta, dict):
                return dict(meta)
        elif isinstance(doc, dict):
            meta = doc.get("metadata", {})
            return dict(meta) if meta else {}
        return {}

    def _get_document_id(self, doc: Any) -> Optional[str]:
        """Get the document/chunk ID as a string, or None if unavailable.

        Checks ``chunk_id``, ``id_`` and ``node_id`` attributes, then the
        ``chunk_id`` / ``id`` keys of a dict. A falsy ``chunk_id`` in a dict
        falls back to ``id``.
        """
        if hasattr(doc, "chunk_id"):
            val = doc.chunk_id
        elif hasattr(doc, "id_"):
            val = doc.id_
        elif hasattr(doc, "node_id"):
            val = doc.node_id
        elif isinstance(doc, dict):
            val = doc.get("chunk_id") or doc.get("id")
        else:
            return None
        return str(val) if val else None

    def query(
        self,
        query_text: str,
        top_k: int = 10,
        traversal_depth: int = 2,
    ) -> list[dict[str, Any]]:
        """Query the graph for related entities and documents.

        Heuristically extracts entity-like terms from ``query_text`` and
        returns stored triplets whose subject or object contains one of them
        (case-insensitive substring match).

        Args:
            query_text: Natural language query.
            top_k: Maximum number of results to return (also the per-entity
                cap).
            traversal_depth: Forwarded to the relationship lookup. Multi-hop
                traversal is not implemented yet, so only direct matches are
                returned regardless of this value.

        Returns:
            Result dicts (keys: entity, subject, predicate, object,
            source_chunk_id, relationship_path, graph_score). They are
            de-duplicated by source chunk, or by relationship path when there
            is no chunk id, and kept in discovery order, at most ``top_k``
            long. Empty when graph indexing is disabled or the store is not
            ready.
        """
        if not settings.ENABLE_GRAPH_INDEX:
            return []

        if not self.graph_store.is_initialized:
            logger.debug("Graph store not initialized for query")
            return []

        if self.graph_store.graph_store is None:
            return []

        query_entities = self._extract_query_entities(query_text)
        logger.debug("Graph query entities: %s", query_entities)

        results: list[dict[str, Any]] = []
        for entity in query_entities:
            results.extend(
                self._find_entity_relationships(entity, traversal_depth, top_k)
            )

        # De-duplicate (first occurrence wins); no re-ordering is performed.
        seen_keys: set[str] = set()
        unique_results: list[dict[str, Any]] = []
        for result in results:
            dedup_key = result.get("source_chunk_id") or result.get(
                "relationship_path", ""
            )
            if not dedup_key:
                # No dedup key available, still include the result.
                unique_results.append(result)
            elif dedup_key not in seen_keys:
                seen_keys.add(dedup_key)
                unique_results.append(result)

        return unique_results[:top_k]

    def _extract_query_entities(self, query_text: str) -> list[str]:
        """Extract potential entity names from query text.

        Makes two passes over the whitespace-separated words, with
        punctuation stripped:

        1. Identifier-shaped words keep their casing: CamelCase/PascalCase,
           ALL_CAPS (> 2 chars), other Capitalized words (> 2 chars) and
           lowercase snake_case.
        2. Any remaining lowercased word longer than 3 characters.

        Both passes drop stopwords. Duplicates are removed case-insensitively,
        because graph matching is case-insensitive and a second spelling would
        only repeat the same scan.

        Args:
            query_text: Query text to analyze.

        Returns:
            Up to 10 potential entity names, identifier-shaped ones first.
        """
        import re

        stopwords = self._QUERY_STOPWORDS
        entities: list[str] = []
        seen: set[str] = set()

        def _add(candidate: str) -> None:
            key = candidate.lower()
            if key in stopwords or key in seen:
                return
            seen.add(key)
            entities.append(candidate)

        cleaned_words = [re.sub(r"[^\w]", "", word) for word in query_text.split()]

        # Pass 1: identifier-shaped words.
        for clean_word in cleaned_words:
            if not clean_word:
                continue
            if (
                re.match(r"^[A-Z][a-z]+[A-Z]", clean_word)  # CamelCase
                or (re.match(r"^[A-Z_]+$", clean_word) and len(clean_word) > 2)
                or (clean_word[0].isupper() and len(clean_word) > 2)
                or ("_" in clean_word and clean_word.islower())  # snake_case
            ):
                _add(clean_word)

        # Pass 2: other significant (lowercased) terms.
        for clean_word in cleaned_words:
            lowered = clean_word.lower()
            if len(lowered) > 3:
                _add(lowered)

        return entities[:10]  # Limit to prevent query explosion

    def _find_entity_relationships(
        self,
        entity: str,
        depth: int,
        max_results: int,
    ) -> list[dict[str, Any]]:
        """Find stored triplets whose subject or object mentions ``entity``.

        Args:
            entity: Entity name to search for (case-insensitive substring).
            depth: Traversal depth. Currently unused; only direct matches are
                returned.
            max_results: Maximum results for this entity; values <= 0 yield
                no results.

        Returns:
            List of result dictionaries. Empty if the store is unavailable or
            exposes no triplets. If reading the store raises, the matches found
            so far are returned and the error is logged as a warning.
        """
        results: list[dict[str, Any]] = []
        graph_store = self.graph_store.graph_store

        if graph_store is None or max_results <= 0:
            return results

        try:
            if hasattr(graph_store, "get_triplets"):
                triplets = graph_store.get_triplets()
            elif hasattr(graph_store, "_relationships"):
                # Fallback: read the store's private ``_relationships`` collection.
                triplets = graph_store._relationships
            else:
                return results

            entity_lower = entity.lower()
            for triplet in triplets:
                # ``or ""`` guards against None fields: calling .lower() on None
                # would raise and abort the search for this entity.
                subject = str(self._get_triplet_field(triplet, "subject", "") or "")
                obj = str(self._get_triplet_field(triplet, "object", "") or "")
                if entity_lower not in subject.lower() and entity_lower not in obj.lower():
                    continue

                results.append(
                    {
                        "entity": entity,
                        "subject": self._get_triplet_field(triplet, "subject", ""),
                        "predicate": self._get_triplet_field(triplet, "predicate", ""),
                        "object": self._get_triplet_field(triplet, "object", ""),
                        "source_chunk_id": self._get_triplet_field(
                            triplet, "source_chunk_id", None
                        ),
                        "relationship_path": self._format_relationship_path(triplet),
                        "graph_score": 1.0,  # Direct match
                    }
                )
                if len(results) >= max_results:
                    break  # No need to scan the rest of the graph.

        except Exception as e:
            logger.warning("Error querying graph store: %s", e)

        return results

    def _get_triplet_field(self, triplet: Any, field: str, default: Any) -> Any:
        """Get field from triplet (handles both dict and object forms)."""
        if isinstance(triplet, dict):
            return triplet.get(field, default)
        return getattr(triplet, field, default)

    def _format_relationship_path(self, triplet: Any) -> str:
        """Format a triplet as ``"subject -> predicate -> object"``."""
        subject = self._get_triplet_field(triplet, "subject", "?")
        predicate = self._get_triplet_field(triplet, "predicate", "?")
        obj = self._get_triplet_field(triplet, "object", "?")
        return f"{subject} -> {predicate} -> {obj}"

    def get_graph_context(
        self,
        query_text: str,
        top_k: int = 5,
        traversal_depth: int = 2,
    ) -> GraphQueryContext:
        """Get graph context for a query.

        Aggregates ``query`` results into a GraphQueryContext that can be used
        to augment retrieval results.

        Args:
            query_text: Natural language query.
            top_k: Maximum results queried, and maximum entries kept in each
                returned list.
            traversal_depth: Graph traversal depth (forwarded to ``query``).

        Returns:
            GraphQueryContext with related entities, relationship paths,
            subgraph triplets and the mean graph score (capped at 1.0). An
            empty context when graph indexing is disabled or nothing matched.
        """
        if not settings.ENABLE_GRAPH_INDEX:
            return GraphQueryContext()

        results = self.query(query_text, top_k=top_k, traversal_depth=traversal_depth)

        if not results:
            return GraphQueryContext()

        related_entities: list[str] = []
        relationship_paths: list[str] = []
        subgraph_triplets: list[GraphTriple] = []

        seen_entities: set[str] = set()
        for result in results:
            # Collect unique subject/object entities in discovery order.
            for entity_field in ("subject", "object"):
                entity = result.get(entity_field)
                if entity and entity not in seen_entities:
                    seen_entities.add(entity)
                    related_entities.append(entity)

            path = result.get("relationship_path")
            if path and path not in relationship_paths:
                relationship_paths.append(path)

            # Best-effort: a result that does not form a valid GraphTriple is
            # skipped, but logged instead of silently swallowed.
            try:
                triplet = GraphTriple(
                    subject=result.get("subject", ""),
                    predicate=result.get("predicate", ""),
                    object=result.get("object", ""),
                    source_chunk_id=result.get("source_chunk_id"),
                )
            except Exception as exc:
                logger.debug("Skipping invalid graph triplet %r: %s", result, exc)
            else:
                subgraph_triplets.append(triplet)

        scores = [r.get("graph_score", 0.0) for r in results if r.get("graph_score")]
        avg_score = sum(scores) / len(scores) if scores else 0.0

        return GraphQueryContext(
            related_entities=related_entities[:top_k],
            relationship_paths=relationship_paths[:top_k],
            subgraph_triplets=subgraph_triplets[:top_k],
            graph_score=min(avg_score, 1.0),
        )

    def get_status(self) -> GraphIndexStatus:
        """Get current graph index status.

        Returns:
            GraphIndexStatus with entity/relationship counts; zero counts and
            ``enabled=False`` when graph indexing is disabled.
        """
        if not settings.ENABLE_GRAPH_INDEX:
            return GraphIndexStatus(
                enabled=False,
                initialized=False,
                entity_count=0,
                relationship_count=0,
                store_type=settings.GRAPH_STORE_TYPE,
            )

        return GraphIndexStatus(
            enabled=True,
            initialized=self.graph_store.is_initialized,
            entity_count=self.graph_store.entity_count,
            relationship_count=self.graph_store.relationship_count,
            last_updated=self.graph_store.last_updated,
            store_type=self.graph_store.store_type,
        )

    def clear(self) -> None:
        """Clear the graph index (only if enabled and initialized)."""
        if settings.ENABLE_GRAPH_INDEX and self.graph_store.is_initialized:
            self.graph_store.clear()
            self._last_build_time = None
            self._last_triplet_count = 0
            logger.info("Graph index cleared")
|
|
519
|
+
|
|
520
|
+
|
|
521
|
+
# Lazily created, process-wide GraphIndexManager. Populated on the first call
# to get_graph_index_manager() and discarded by reset_graph_index_manager().
_graph_index_manager: Optional[GraphIndexManager] = None


def get_graph_index_manager() -> GraphIndexManager:
    """Return the shared GraphIndexManager, constructing it on first use.

    The instance is created with no arguments and returned by every later
    call until reset_graph_index_manager() clears it.
    """
    global _graph_index_manager
    manager = _graph_index_manager
    if manager is None:
        manager = GraphIndexManager()
        _graph_index_manager = manager
    return manager


def reset_graph_index_manager() -> None:
    """Forget the cached GraphIndexManager so the next lookup builds a new one.

    Intended for tests that need a fresh manager.
    """
    global _graph_index_manager
    _graph_index_manager = None
|