code-graph-builder 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- code_graph_builder/__init__.py +82 -0
- code_graph_builder/builder.py +366 -0
- code_graph_builder/cgb_cli.py +32 -0
- code_graph_builder/cli.py +564 -0
- code_graph_builder/commands_cli.py +1288 -0
- code_graph_builder/config.py +340 -0
- code_graph_builder/constants.py +708 -0
- code_graph_builder/embeddings/__init__.py +40 -0
- code_graph_builder/embeddings/qwen3_embedder.py +573 -0
- code_graph_builder/embeddings/vector_store.py +584 -0
- code_graph_builder/examples/__init__.py +0 -0
- code_graph_builder/examples/example_configuration.py +276 -0
- code_graph_builder/examples/example_kuzu_usage.py +109 -0
- code_graph_builder/examples/example_semantic_search_full.py +347 -0
- code_graph_builder/examples/generate_wiki.py +915 -0
- code_graph_builder/examples/graph_export_example.py +100 -0
- code_graph_builder/examples/rag_example.py +206 -0
- code_graph_builder/examples/test_cli_demo.py +129 -0
- code_graph_builder/examples/test_embedding_api.py +153 -0
- code_graph_builder/examples/test_kuzu_local.py +190 -0
- code_graph_builder/examples/test_rag_redis.py +390 -0
- code_graph_builder/graph_updater.py +605 -0
- code_graph_builder/guidance/__init__.py +1 -0
- code_graph_builder/guidance/agent.py +123 -0
- code_graph_builder/guidance/prompts.py +74 -0
- code_graph_builder/guidance/toolset.py +264 -0
- code_graph_builder/language_spec.py +536 -0
- code_graph_builder/mcp/__init__.py +21 -0
- code_graph_builder/mcp/api_doc_generator.py +764 -0
- code_graph_builder/mcp/file_editor.py +207 -0
- code_graph_builder/mcp/pipeline.py +777 -0
- code_graph_builder/mcp/server.py +161 -0
- code_graph_builder/mcp/tools.py +1800 -0
- code_graph_builder/models.py +115 -0
- code_graph_builder/parser_loader.py +344 -0
- code_graph_builder/parsers/__init__.py +7 -0
- code_graph_builder/parsers/call_processor.py +306 -0
- code_graph_builder/parsers/call_resolver.py +139 -0
- code_graph_builder/parsers/definition_processor.py +796 -0
- code_graph_builder/parsers/factory.py +119 -0
- code_graph_builder/parsers/import_processor.py +293 -0
- code_graph_builder/parsers/structure_processor.py +145 -0
- code_graph_builder/parsers/type_inference.py +143 -0
- code_graph_builder/parsers/utils.py +134 -0
- code_graph_builder/rag/__init__.py +68 -0
- code_graph_builder/rag/camel_agent.py +429 -0
- code_graph_builder/rag/client.py +298 -0
- code_graph_builder/rag/config.py +239 -0
- code_graph_builder/rag/cypher_generator.py +67 -0
- code_graph_builder/rag/llm_backend.py +210 -0
- code_graph_builder/rag/markdown_generator.py +352 -0
- code_graph_builder/rag/prompt_templates.py +440 -0
- code_graph_builder/rag/rag_engine.py +640 -0
- code_graph_builder/rag/review_report.md +172 -0
- code_graph_builder/rag/tests/__init__.py +3 -0
- code_graph_builder/rag/tests/test_camel_agent.py +313 -0
- code_graph_builder/rag/tests/test_client.py +221 -0
- code_graph_builder/rag/tests/test_config.py +177 -0
- code_graph_builder/rag/tests/test_markdown_generator.py +240 -0
- code_graph_builder/rag/tests/test_prompt_templates.py +160 -0
- code_graph_builder/services/__init__.py +39 -0
- code_graph_builder/services/graph_service.py +465 -0
- code_graph_builder/services/kuzu_service.py +665 -0
- code_graph_builder/services/memory_service.py +171 -0
- code_graph_builder/settings.py +75 -0
- code_graph_builder/tests/ACCEPTANCE_CRITERIA_PHASE2.md +401 -0
- code_graph_builder/tests/__init__.py +1 -0
- code_graph_builder/tests/run_acceptance_check.py +378 -0
- code_graph_builder/tests/test_api_find.py +231 -0
- code_graph_builder/tests/test_api_find_integration.py +226 -0
- code_graph_builder/tests/test_basic.py +78 -0
- code_graph_builder/tests/test_c_api_extraction.py +388 -0
- code_graph_builder/tests/test_call_resolution_scenarios.py +504 -0
- code_graph_builder/tests/test_embedder.py +411 -0
- code_graph_builder/tests/test_integration_semantic.py +434 -0
- code_graph_builder/tests/test_mcp_protocol.py +298 -0
- code_graph_builder/tests/test_mcp_user_flow.py +190 -0
- code_graph_builder/tests/test_rag.py +404 -0
- code_graph_builder/tests/test_settings.py +135 -0
- code_graph_builder/tests/test_step1_graph_build.py +264 -0
- code_graph_builder/tests/test_step2_api_docs.py +323 -0
- code_graph_builder/tests/test_step3_embedding.py +278 -0
- code_graph_builder/tests/test_vector_store.py +552 -0
- code_graph_builder/tools/__init__.py +40 -0
- code_graph_builder/tools/graph_query.py +495 -0
- code_graph_builder/tools/semantic_search.py +387 -0
- code_graph_builder/types.py +333 -0
- code_graph_builder/utils/__init__.py +0 -0
- code_graph_builder/utils/path_utils.py +30 -0
- code_graph_builder-0.2.0.dist-info/METADATA +321 -0
- code_graph_builder-0.2.0.dist-info/RECORD +93 -0
- code_graph_builder-0.2.0.dist-info/WHEEL +4 -0
- code_graph_builder-0.2.0.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
"""Tests for prompt templates."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import pytest
|
|
6
|
+
|
|
7
|
+
from code_graph_builder.rag.prompt_templates import (
|
|
8
|
+
CodeAnalysisPrompts,
|
|
9
|
+
CodeContext,
|
|
10
|
+
RAGPrompts,
|
|
11
|
+
create_code_context,
|
|
12
|
+
get_default_prompts,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class TestCodeContext:
    """Unit tests covering CodeContext construction and formatting."""

    def test_basic_creation(self):
        """A context built from keyword arguments keeps every field."""
        context = CodeContext(
            source_code="def foo(): pass",
            file_path="test.py",
            qualified_name="test.foo",
            entity_type="Function",
        )
        expected = {
            "source_code": "def foo(): pass",
            "file_path": "test.py",
            "qualified_name": "test.foo",
            "entity_type": "Function",
        }
        for attr, value in expected.items():
            assert getattr(context, attr) == value

    def test_format_context(self):
        """Formatting a fully-populated context renders every section."""
        context = CodeContext(
            source_code="def foo(): pass",
            file_path="test.py",
            qualified_name="test.foo",
            entity_type="Function",
            docstring="Test function",
            callers=["caller1", "caller2"],
            callees=["callee1"],
        )
        rendered = context.format_context()
        for needle in (
            "Entity: test.foo",
            "Type: Function",
            "File: test.py",
            "Documentation:",
            "def foo(): pass",
            "Called By:",
            "Calls:",
        ):
            assert needle in rendered

    def test_format_context_minimal(self):
        """Formatting still works when only source code is supplied."""
        rendered = CodeContext(source_code="x = 1").format_context()
        assert "Source Code:" in rendered
        assert "x = 1" in rendered
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class TestCodeAnalysisPrompts:
    """Unit tests for the CodeAnalysisPrompts formatting helpers."""

    def test_get_system_prompt(self):
        """The system prompt is non-empty and names the analyst persona."""
        system_text = CodeAnalysisPrompts().get_system_prompt()
        assert "expert code analyst" in system_text.lower()
        assert system_text

    def test_format_explain_prompt(self):
        """The explain prompt mentions explaining and embeds the code."""
        context = CodeContext(source_code="def foo(): pass")
        rendered = CodeAnalysisPrompts().format_explain_prompt(context)
        assert "explain" in rendered.lower()
        assert "def foo(): pass" in rendered

    def test_format_query_prompt(self):
        """The query prompt embeds both the question and the code."""
        context = CodeContext(source_code="def foo(): pass")
        rendered = CodeAnalysisPrompts().format_query_prompt(
            "What does this do?", context
        )
        assert "What does this do?" in rendered
        assert "def foo(): pass" in rendered

    def test_format_documentation_prompt(self):
        """The documentation prompt mentions documentation."""
        context = CodeContext(source_code="def foo(): pass")
        rendered = CodeAnalysisPrompts().format_documentation_prompt(context)
        assert "documentation" in rendered.lower()

    def test_format_architecture_prompt(self):
        """The architecture prompt mentions architecture."""
        context = CodeContext(source_code="class Foo: pass")
        rendered = CodeAnalysisPrompts().format_architecture_prompt(context)
        assert "architecture" in rendered.lower()

    def test_format_summary_prompt(self):
        """The summary prompt mentions a summary."""
        context = CodeContext(source_code="def foo(): pass")
        rendered = CodeAnalysisPrompts().format_summary_prompt(context)
        assert "summary" in rendered.lower()

    def test_format_multi_context_prompt(self):
        """Multi-context prompts number each context and keep the question."""
        contexts = [
            CodeContext(source_code=code)
            for code in ("def foo(): pass", "def bar(): pass")
        ]
        rendered = CodeAnalysisPrompts().format_multi_context_prompt(
            "Compare these", contexts
        )
        for needle in ("Compare these", "Context 1", "Context 2"):
            assert needle in rendered
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class TestRAGPrompts:
    """Unit tests for RAGPrompts query formatting."""

    def test_format_rag_query_with_contexts(self):
        """Contexts contribute their qualified names to the user prompt."""
        context = CodeContext(
            source_code="def foo(): pass",
            qualified_name="test.foo",
        )
        system_text, user_text = RAGPrompts().format_rag_query(
            "Explain this", [context]
        )
        assert system_text
        assert "Explain this" in user_text
        assert "test.foo" in user_text

    def test_format_rag_query_no_contexts(self):
        """An empty context list yields an explicit 'no code' marker."""
        system_text, user_text = RAGPrompts().format_rag_query("Explain this", [])
        assert "No relevant code" in user_text
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
class TestConvenienceFunctions:
    """Unit tests for the module-level convenience helpers."""

    def test_get_default_prompts(self):
        """The default factory returns a RAGPrompts instance."""
        assert isinstance(get_default_prompts(), RAGPrompts)

    def test_create_code_context(self):
        """The context factory returns a populated CodeContext."""
        context = create_code_context(
            source_code="def foo(): pass",
            file_path="test.py",
            qualified_name="test.foo",
        )
        assert isinstance(context, CodeContext)
        assert context.source_code == "def foo(): pass"
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""Code Graph Builder - Services."""
|
|
2
|
+
|
|
3
|
+
from typing import Protocol, runtime_checkable
|
|
4
|
+
|
|
5
|
+
from ..types import PropertyDict, PropertyValue, ResultRow
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@runtime_checkable
class IngestorProtocol(Protocol):
    """Protocol for graph data ingestors.

    Structural interface for objects that buffer graph writes and flush
    them to a backing store in batches.

    NOTE(review): ``MemgraphIngestor`` in this package exposes flattened
    signatures (``ensure_node_batch(label, id_key, id_val, props)`` and a
    seven-argument ``ensure_relationship_batch``) rather than the spec-tuple
    shapes declared here, so it does not satisfy this protocol for static
    type checkers — confirm which shape callers rely on.
    ``@runtime_checkable`` isinstance checks verify method *presence* only,
    not signatures, so they still pass.
    """

    # Queue a node for insertion; implementations may flush automatically
    # once an internal threshold is reached.
    def ensure_node_batch(self, label: str, properties: PropertyDict) -> None: ...

    # Queue a relationship between two endpoints, each identified by a
    # (label, key, value) spec; edge properties are optional.
    def ensure_relationship_batch(
        self,
        from_spec: tuple[str, str, PropertyValue],
        rel_type: str,
        to_spec: tuple[str, str, PropertyValue],
        properties: PropertyDict | None = None,
    ) -> None: ...

    # Write any buffered nodes/relationships to the backing store.
    def flush_all(self) -> None: ...
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@runtime_checkable
class QueryProtocol(Protocol):
    """Protocol for graph query operations.

    NOTE(review): ``MemgraphIngestor`` (re-exported below) implements
    ``fetch_all`` but no ``execute_write`` as far as graph_service.py
    shows — verify the intended implementers before relying on
    isinstance checks against this protocol.
    """

    # Run a read query and return all rows as column-name -> value dicts.
    def fetch_all(
        self, query: str, params: PropertyDict | None = None
    ) -> list[ResultRow]: ...

    # Run a write query; any results are discarded.
    def execute_write(self, query: str, params: PropertyDict | None = None) -> None: ...
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# Import implementation
|
|
37
|
+
from .graph_service import MemgraphIngestor
|
|
38
|
+
|
|
39
|
+
__all__ = ["IngestorProtocol", "QueryProtocol", "MemgraphIngestor"]
|
|
@@ -0,0 +1,465 @@
|
|
|
1
|
+
"""Graph service for connecting to and interacting with Memgraph."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import types
|
|
6
|
+
from collections import defaultdict
|
|
7
|
+
from collections.abc import Generator, Sequence
|
|
8
|
+
from contextlib import contextmanager
|
|
9
|
+
from datetime import UTC, datetime
|
|
10
|
+
from typing import TYPE_CHECKING
|
|
11
|
+
|
|
12
|
+
from loguru import logger
|
|
13
|
+
|
|
14
|
+
from ..types import (
|
|
15
|
+
BatchParams,
|
|
16
|
+
BatchWrapper,
|
|
17
|
+
ColumnDescriptor,
|
|
18
|
+
CursorProtocol,
|
|
19
|
+
GraphData,
|
|
20
|
+
GraphMetadata,
|
|
21
|
+
GraphNode,
|
|
22
|
+
NodeBatchRow,
|
|
23
|
+
PropertyDict,
|
|
24
|
+
PropertyValue,
|
|
25
|
+
RelBatchRow,
|
|
26
|
+
ResultRow,
|
|
27
|
+
ResultValue,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
if TYPE_CHECKING:
|
|
31
|
+
import mgclient
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class MemgraphIngestor:
|
|
35
|
+
"""Ingestor for writing code graph data to Memgraph."""
|
|
36
|
+
|
|
37
|
+
def __init__(self, host: str, port: int, batch_size: int = 1000):
|
|
38
|
+
self._host = host
|
|
39
|
+
self._port = port
|
|
40
|
+
if batch_size < 1:
|
|
41
|
+
raise ValueError("batch_size must be at least 1")
|
|
42
|
+
self.batch_size = batch_size
|
|
43
|
+
self.conn: mgclient.Connection | None = None
|
|
44
|
+
self.node_buffer: list[tuple[str, dict[str, PropertyValue]]] = []
|
|
45
|
+
self.relationship_buffer: list[
|
|
46
|
+
tuple[
|
|
47
|
+
tuple[str, str, PropertyValue],
|
|
48
|
+
str,
|
|
49
|
+
tuple[str, str, PropertyValue],
|
|
50
|
+
dict[str, PropertyValue] | None,
|
|
51
|
+
]
|
|
52
|
+
] = []
|
|
53
|
+
|
|
54
|
+
def __enter__(self) -> MemgraphIngestor:
    """Open the Memgraph connection and return self."""
    # Imported lazily so the module can be loaded without mgclient
    # installed; the driver is only needed once a connection is opened.
    import mgclient

    logger.info(f"Connecting to Memgraph at {self._host}:{self._port}")
    self.conn = mgclient.connect(host=self._host, port=self._port)
    # Autocommit: each statement commits immediately; no transaction
    # spans a batch.
    self.conn.autocommit = True
    logger.info("Connected to Memgraph")
    return self
|
|
62
|
+
|
|
63
|
+
def __exit__(
    self,
    exc_type: type | None,
    exc_val: Exception | None,
    exc_tb: types.TracebackType | None,
) -> None:
    """Flush buffered data and close the connection on context exit.

    Buffers are flushed even when the body raised; a failure during that
    best-effort flush is logged rather than masking the original error.
    Returning None means the original exception is never suppressed.
    """
    if exc_type:
        logger.exception(f"Exception during ingest: {exc_val}")
        try:
            self.flush_all()
        except Exception as flush_err:
            logger.error(f"Flush error during exception handling: {flush_err}")
    else:
        # Normal exit: flush errors propagate to the caller.
        self.flush_all()
    if self.conn:
        self.conn.close()
        logger.info("Disconnected from Memgraph")
|
|
80
|
+
|
|
81
|
+
@contextmanager
def _get_cursor(self) -> Generator[CursorProtocol, None, None]:
    """Yield a cursor on the open connection, closing it afterwards.

    Raises:
        ConnectionError: If called before __enter__ opened a connection.
    """
    if not self.conn:
        raise ConnectionError("Not connected to database")
    cursor: CursorProtocol | None = None
    try:
        cursor = self.conn.cursor()
        yield cursor
    finally:
        # Close even if the body raised; the None guard covers the case
        # where cursor() itself failed before assignment.
        if cursor:
            cursor.close()
|
|
92
|
+
|
|
93
|
+
def _cursor_to_results(self, cursor: CursorProtocol) -> list[ResultRow]:
|
|
94
|
+
if not cursor.description:
|
|
95
|
+
return []
|
|
96
|
+
column_names = [desc.name for desc in cursor.description]
|
|
97
|
+
return [
|
|
98
|
+
dict[str, ResultValue](zip(column_names, row)) for row in cursor.fetchall()
|
|
99
|
+
]
|
|
100
|
+
|
|
101
|
+
def _execute_query(
    self,
    query: str,
    params: dict[str, PropertyValue] | None = None,
) -> list[ResultRow]:
    """Execute one Cypher statement and return its rows.

    Args:
        query: Cypher text.
        params: Query parameters; defaults to an empty dict.

    Returns:
        Rows as column-name -> value dicts (empty for write statements).

    Raises:
        ConnectionError: If no connection is open (via _get_cursor).
        Exception: Driver errors are re-raised; errors whose message
            contains "already exists" are re-raised without logging
            (presumably duplicate index/constraint creation — treated
            as expected noise).
    """
    params = params or {}
    with self._get_cursor() as cursor:
        try:
            cursor.execute(query, params)
            return self._cursor_to_results(cursor)
        except Exception as e:
            # Skip the noisy log for expected "already exists" errors,
            # but still propagate so the caller can decide.
            if "already exists" not in str(e).lower():
                logger.error(f"Query error: {e}")
                logger.error(f"Query: {query}")
                logger.error(f"Params: {params}")
            raise
|
|
117
|
+
|
|
118
|
+
def _execute_batch(self, query: str, params_list: Sequence[BatchParams]) -> None:
    """Run ``query`` once per row of ``params_list`` via UNWIND.

    The statement sent to the server is ``UNWIND $batch AS row`` followed
    by ``query``, so the caller's query text should reference ``row`` and
    must NOT repeat the UNWIND clause itself.

    NOTE(review): unlike _execute_query, failures here are logged and
    swallowed — the whole batch is silently dropped — and "already exists"
    errors are not even logged. Confirm this best-effort policy is
    intentional for bulk ingest before relying on it.
    """
    # No-op when disconnected or there is nothing to send.
    if not self.conn or not params_list:
        return
    cursor = None
    try:
        cursor = self.conn.cursor()
        cursor.execute(
            f"UNWIND $batch AS row\n{query}",
            BatchWrapper(batch=params_list),
        )
    except Exception as e:
        if "already exists" not in str(e).lower():
            logger.error(f"Batch error: {e}")
            logger.error(f"Query: {query}")
    finally:
        if cursor:
            cursor.close()
|
|
135
|
+
|
|
136
|
+
def fetch_all(
    self,
    query: str,
    params: dict[str, PropertyValue] | None = None,
) -> list[ResultRow]:
    """Execute a query and return all results.

    Thin public wrapper over _execute_query; see it for error and
    logging semantics.
    """
    return self._execute_query(query, params)
|
|
143
|
+
|
|
144
|
+
def ensure_node_batch(
|
|
145
|
+
self,
|
|
146
|
+
label: str,
|
|
147
|
+
id_key: str,
|
|
148
|
+
id_val: PropertyValue,
|
|
149
|
+
props: dict[str, PropertyValue],
|
|
150
|
+
) -> None:
|
|
151
|
+
"""Queue a node for batch insertion."""
|
|
152
|
+
unique_id = f"{label}:{id_key}:{id_val}"
|
|
153
|
+
self.node_buffer.append((unique_id, {"label": label, "id_key": id_key, "id_val": id_val, **props}))
|
|
154
|
+
if len(self.node_buffer) >= self.batch_size:
|
|
155
|
+
self.flush_nodes()
|
|
156
|
+
|
|
157
|
+
def ensure_relationship_batch(
|
|
158
|
+
self,
|
|
159
|
+
from_label: str,
|
|
160
|
+
from_key: str,
|
|
161
|
+
from_val: PropertyValue,
|
|
162
|
+
rel_type: str,
|
|
163
|
+
to_label: str,
|
|
164
|
+
to_key: str,
|
|
165
|
+
to_val: PropertyValue,
|
|
166
|
+
props: dict[str, PropertyValue] | None = None,
|
|
167
|
+
) -> None:
|
|
168
|
+
"""Queue a relationship for batch insertion."""
|
|
169
|
+
from_id = (from_label, from_key, from_val)
|
|
170
|
+
to_id = (to_label, to_key, to_val)
|
|
171
|
+
self.relationship_buffer.append((from_id, rel_type, to_id, props))
|
|
172
|
+
if len(self.relationship_buffer) >= self.batch_size:
|
|
173
|
+
self.flush_relationships()
|
|
174
|
+
|
|
175
|
+
def flush_nodes(self) -> None:
    """Flush buffered nodes to the database.

    Fixes vs. the previous revision:
    - Rows are grouped by (label, id_key) rather than label alone: the
      MERGE key is baked into the query text, so nodes sharing a label but
      keyed on different properties must not share a batch (previously the
      first node's key was applied to the whole group).
    - The UNWIND clause is no longer repeated here — _execute_batch already
      prepends ``UNWIND $batch AS row``, and duplicating it expanded every
      batch into a batch-squared cross-product.
    - Buffered dicts are copied instead of being mutated via pop().
    """
    if not self.node_buffer:
        return

    grouped: defaultdict[tuple[str, str], list[dict]] = defaultdict(list)
    for _unique_id, node_data in self.node_buffer:
        row = dict(node_data)  # keep the buffered entry intact
        label = row.pop("label")
        grouped[(label, row.get("id_key", "id"))].append(row)

    for (label, id_key), nodes in grouped.items():
        # NOTE(review): id_key/id_val remain in `row`, so `SET n += row`
        # also writes them as node properties (as the previous revision
        # did) — confirm whether downstream relies on that before
        # stripping them.
        query = f"""
            MERGE (n:{label} {{{id_key}: row.id_val}})
            SET n += row
        """
        self._execute_batch(query, nodes)

    logger.debug(f"Flushed {len(self.node_buffer)} nodes")
    self.node_buffer.clear()
|
|
196
|
+
|
|
197
|
+
def flush_relationships(self) -> None:
    """Flush buffered relationships to the database.

    Fixes vs. the previous revision:
    - Rows are grouped by (rel_type, from_key, to_key) instead of rel_type
      alone: the MATCH keys are baked into the query text, so previously
      the first row's keys were applied to the whole group.
    - ``props`` is always present (defaulting to {}) so
      ``SET r += row.props`` never evaluates against a missing/null map —
      previously a batch with no properties could fail server-side and be
      silently dropped by _execute_batch.
    - The UNWIND clause is no longer repeated here — _execute_batch already
      prepends ``UNWIND $batch AS row``; duplicating it expanded every
      batch into a batch-squared cross-product.
    """
    if not self.relationship_buffer:
        return

    grouped: defaultdict[tuple[str, str, str], list[dict]] = defaultdict(list)
    for (from_label, from_key, from_val), rel_type, (to_label, to_key, to_val), props in self.relationship_buffer:
        grouped[(rel_type, from_key, to_key)].append(
            {
                "from_label": from_label,
                "from_key": from_key,
                "from_val": from_val,
                "to_label": to_label,
                "to_key": to_key,
                "to_val": to_val,
                "props": props or {},
            }
        )

    for (rel_type, from_key, to_key), rels in grouped.items():
        # NOTE(review): the MATCH patterns are unlabelled (as before), so
        # any node with a matching key property can be linked — confirm the
        # key properties are globally unique across labels.
        query = f"""
            MATCH (a {{{from_key}: row.from_val}})
            MATCH (b {{{to_key}: row.to_val}})
            MERGE (a)-[r:{rel_type}]->(b)
            SET r += row.props
        """
        self._execute_batch(query, rels)

    logger.debug(f"Flushed {len(self.relationship_buffer)} relationships")
    self.relationship_buffer.clear()
|
|
232
|
+
|
|
233
|
+
def flush_all(self) -> None:
    """Flush all buffered data to the database.

    Nodes are flushed before relationships so that both endpoints exist
    by the time the relationship MATCH queries run.
    """
    self.flush_nodes()
    self.flush_relationships()
|
|
237
|
+
|
|
238
|
+
def clean_database(self) -> None:
    """Delete all data from the database.

    Irreversible: removes every node and every relationship via
    DETACH DELETE.
    """
    logger.warning("Cleaning database - deleting all nodes and relationships")
    self._execute_query("MATCH (n) DETACH DELETE n;")
    logger.info("Database cleaned")
|
|
243
|
+
|
|
244
|
+
def list_projects(self) -> list[str]:
|
|
245
|
+
"""List all projects in the database."""
|
|
246
|
+
results = self._execute_query(
|
|
247
|
+
"MATCH (p:Project) RETURN p.name AS name ORDER BY p.name"
|
|
248
|
+
)
|
|
249
|
+
return [row["name"] for row in results if row.get("name")]
|
|
250
|
+
|
|
251
|
+
def delete_project(self, project_name: str) -> None:
    """Delete a project and all its related data.

    Traverses containment edges (CONTAINS_*) from the Project node and
    definition edges (DEFINES / DEFINES_METHOD) from each container, then
    DETACH DELETEs everything reached plus the project itself.

    NOTE(review): the "deleted" log line is emitted even when no project
    with that name exists (the MATCH simply finds nothing).
    """
    logger.info(f"Deleting project: {project_name}")
    query = """
        MATCH (p:Project {name: $project_name})
        OPTIONAL MATCH (p)-[:CONTAINS_PACKAGE|CONTAINS_FOLDER|CONTAINS_FILE|CONTAINS_MODULE*]->(container)
        OPTIONAL MATCH (container)-[:DEFINES|DEFINES_METHOD*]->(defined)
        DETACH DELETE p, container, defined
    """
    self._execute_query(query, {"project_name": project_name})
    logger.info(f"Project {project_name} deleted")
|
|
262
|
+
|
|
263
|
+
def export_graph_to_dict(self) -> GraphData:
    """Export the entire graph as a dictionary.

    Returns:
        GraphData with raw node rows (internal id, labels, properties),
        relationship rows (endpoint ids, type, properties), and metadata
        carrying the totals plus a UTC export timestamp.

    NOTE(review): materializes the whole graph in memory — confirm this
    is acceptable for the expected graph sizes.
    """
    logger.info("Exporting graph to dictionary")

    nodes_query = """
        MATCH (n)
        RETURN id(n) as node_id, labels(n) as labels, properties(n) as properties
    """
    nodes = self._execute_query(nodes_query)

    rels_query = """
        MATCH (a)-[r]->(b)
        RETURN id(a) as from_id, id(b) as to_id, type(r) as type, properties(r) as properties
    """
    relationships = self._execute_query(rels_query)

    metadata = GraphMetadata(
        total_nodes=len(nodes),
        total_relationships=len(relationships),
        exported_at=datetime.now(UTC).isoformat(),
    )

    logger.info(
        f"Exported {len(nodes)} nodes and {len(relationships)} relationships"
    )

    return GraphData(
        nodes=nodes,
        relationships=relationships,
        metadata=metadata,
    )
|
|
294
|
+
|
|
295
|
+
def get_node_by_id(self, node_id: int) -> GraphNode | None:
|
|
296
|
+
"""Get a node by its internal ID.
|
|
297
|
+
|
|
298
|
+
Args:
|
|
299
|
+
node_id: Memgraph internal node ID
|
|
300
|
+
|
|
301
|
+
Returns:
|
|
302
|
+
GraphNode if found, None otherwise
|
|
303
|
+
"""
|
|
304
|
+
query = """
|
|
305
|
+
MATCH (n)
|
|
306
|
+
WHERE id(n) = $node_id
|
|
307
|
+
RETURN id(n) as node_id, labels(n) as labels, properties(n) as props
|
|
308
|
+
"""
|
|
309
|
+
results = self._execute_query(query, {"node_id": node_id})
|
|
310
|
+
|
|
311
|
+
if not results:
|
|
312
|
+
return None
|
|
313
|
+
|
|
314
|
+
row = results[0]
|
|
315
|
+
return self._row_to_graph_node(row)
|
|
316
|
+
|
|
317
|
+
def get_nodes_by_ids(self, node_ids: list[int]) -> list[GraphNode]:
|
|
318
|
+
"""Get multiple nodes by their internal IDs.
|
|
319
|
+
|
|
320
|
+
Args:
|
|
321
|
+
node_ids: List of Memgraph internal node IDs
|
|
322
|
+
|
|
323
|
+
Returns:
|
|
324
|
+
List of GraphNode objects
|
|
325
|
+
"""
|
|
326
|
+
if not node_ids:
|
|
327
|
+
return []
|
|
328
|
+
|
|
329
|
+
query = """
|
|
330
|
+
MATCH (n)
|
|
331
|
+
WHERE id(n) IN $node_ids
|
|
332
|
+
RETURN id(n) as node_id, labels(n) as labels, properties(n) as props
|
|
333
|
+
"""
|
|
334
|
+
results = self._execute_query(query, {"node_ids": node_ids})
|
|
335
|
+
|
|
336
|
+
return [self._row_to_graph_node(row) for row in results if row]
|
|
337
|
+
|
|
338
|
+
def search_nodes(
|
|
339
|
+
self,
|
|
340
|
+
query_str: str,
|
|
341
|
+
label: str | None = None,
|
|
342
|
+
limit: int = 10,
|
|
343
|
+
) -> list[GraphNode]:
|
|
344
|
+
"""Search nodes by name or qualified name.
|
|
345
|
+
|
|
346
|
+
Args:
|
|
347
|
+
query_str: Search query string
|
|
348
|
+
label: Optional node label filter
|
|
349
|
+
limit: Maximum number of results
|
|
350
|
+
|
|
351
|
+
Returns:
|
|
352
|
+
List of matching GraphNode objects
|
|
353
|
+
"""
|
|
354
|
+
if label:
|
|
355
|
+
cypher = """
|
|
356
|
+
MATCH (n:$label)
|
|
357
|
+
WHERE n.name CONTAINS $query OR n.qualified_name CONTAINS $query
|
|
358
|
+
RETURN id(n) as node_id, labels(n) as labels, properties(n) as props
|
|
359
|
+
LIMIT $limit
|
|
360
|
+
"""
|
|
361
|
+
cypher = cypher.replace("$label", label)
|
|
362
|
+
else:
|
|
363
|
+
cypher = """
|
|
364
|
+
MATCH (n)
|
|
365
|
+
WHERE n.name CONTAINS $query OR n.qualified_name CONTAINS $query
|
|
366
|
+
RETURN id(n) as node_id, labels(n) as labels, properties(n) as props
|
|
367
|
+
LIMIT $limit
|
|
368
|
+
"""
|
|
369
|
+
|
|
370
|
+
results = self._execute_query(
|
|
371
|
+
cypher, {"query": query_str, "limit": limit}
|
|
372
|
+
)
|
|
373
|
+
|
|
374
|
+
return [self._row_to_graph_node(row) for row in results if row]
|
|
375
|
+
|
|
376
|
+
def get_node_relationships(
|
|
377
|
+
self,
|
|
378
|
+
node_id: int,
|
|
379
|
+
rel_type: str | None = None,
|
|
380
|
+
direction: str = "both",
|
|
381
|
+
) -> list[tuple[GraphNode, str, str]]:
|
|
382
|
+
"""Get relationships for a node.
|
|
383
|
+
|
|
384
|
+
Args:
|
|
385
|
+
node_id: Node identifier
|
|
386
|
+
rel_type: Optional relationship type filter
|
|
387
|
+
direction: "out", "in", or "both"
|
|
388
|
+
|
|
389
|
+
Returns:
|
|
390
|
+
List of (related_node, relationship_type, direction) tuples
|
|
391
|
+
"""
|
|
392
|
+
results = []
|
|
393
|
+
|
|
394
|
+
if direction in ("out", "both"):
|
|
395
|
+
if rel_type:
|
|
396
|
+
query = f"""
|
|
397
|
+
MATCH (n)-[r:{rel_type}]->(m)
|
|
398
|
+
WHERE id(n) = $node_id
|
|
399
|
+
RETURN id(m) as node_id, labels(m) as labels, properties(m) as props,
|
|
400
|
+
type(r) as rel_type, "out" as direction
|
|
401
|
+
"""
|
|
402
|
+
else:
|
|
403
|
+
query = """
|
|
404
|
+
MATCH (n)-[r]->(m)
|
|
405
|
+
WHERE id(n) = $node_id
|
|
406
|
+
RETURN id(m) as node_id, labels(m) as labels, properties(m) as props,
|
|
407
|
+
type(r) as rel_type, "out" as direction
|
|
408
|
+
"""
|
|
409
|
+
rows = self._execute_query(query, {"node_id": node_id})
|
|
410
|
+
for row in rows:
|
|
411
|
+
node = self._row_to_graph_node(row)
|
|
412
|
+
results.append((node, row.get("rel_type", "UNKNOWN"), "out"))
|
|
413
|
+
|
|
414
|
+
if direction in ("in", "both"):
|
|
415
|
+
if rel_type:
|
|
416
|
+
query = f"""
|
|
417
|
+
MATCH (n)<-[r:{rel_type}]-(m)
|
|
418
|
+
WHERE id(n) = $node_id
|
|
419
|
+
RETURN id(m) as node_id, labels(m) as labels, properties(m) as props,
|
|
420
|
+
type(r) as rel_type, "in" as direction
|
|
421
|
+
"""
|
|
422
|
+
else:
|
|
423
|
+
query = """
|
|
424
|
+
MATCH (n)<-[r]-(m)
|
|
425
|
+
WHERE id(n) = $node_id
|
|
426
|
+
RETURN id(m) as node_id, labels(m) as labels, properties(m) as props,
|
|
427
|
+
type(r) as rel_type, "in" as direction
|
|
428
|
+
"""
|
|
429
|
+
rows = self._execute_query(query, {"node_id": node_id})
|
|
430
|
+
for row in rows:
|
|
431
|
+
node = self._row_to_graph_node(row)
|
|
432
|
+
results.append((node, row.get("rel_type", "UNKNOWN"), "in"))
|
|
433
|
+
|
|
434
|
+
return results
|
|
435
|
+
|
|
436
|
+
def _row_to_graph_node(self, row: ResultRow) -> GraphNode:
    """Convert a query result row to GraphNode.

    Args:
        row: Query result row exposing node_id / labels / props columns.

    Returns:
        GraphNode instance; missing or malformed columns fall back to
        empty defaults rather than raising.
    """
    raw_props = row.get("props", {})
    properties = dict(raw_props) if isinstance(raw_props, dict) else {}

    raw_labels = row.get("labels", [])
    labels = raw_labels if isinstance(raw_labels, list) else []

    return GraphNode(
        node_id=row.get("node_id", 0),
        labels=labels,
        qualified_name=properties.get("qualified_name", ""),
        name=properties.get("name", ""),
        path=properties.get("path"),
        start_line=properties.get("start_line"),
        end_line=properties.get("end_line"),
        docstring=properties.get("docstring"),
        properties=properties,
    )
|