code-graph-builder 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93) hide show
  1. code_graph_builder/__init__.py +82 -0
  2. code_graph_builder/builder.py +366 -0
  3. code_graph_builder/cgb_cli.py +32 -0
  4. code_graph_builder/cli.py +564 -0
  5. code_graph_builder/commands_cli.py +1288 -0
  6. code_graph_builder/config.py +340 -0
  7. code_graph_builder/constants.py +708 -0
  8. code_graph_builder/embeddings/__init__.py +40 -0
  9. code_graph_builder/embeddings/qwen3_embedder.py +573 -0
  10. code_graph_builder/embeddings/vector_store.py +584 -0
  11. code_graph_builder/examples/__init__.py +0 -0
  12. code_graph_builder/examples/example_configuration.py +276 -0
  13. code_graph_builder/examples/example_kuzu_usage.py +109 -0
  14. code_graph_builder/examples/example_semantic_search_full.py +347 -0
  15. code_graph_builder/examples/generate_wiki.py +915 -0
  16. code_graph_builder/examples/graph_export_example.py +100 -0
  17. code_graph_builder/examples/rag_example.py +206 -0
  18. code_graph_builder/examples/test_cli_demo.py +129 -0
  19. code_graph_builder/examples/test_embedding_api.py +153 -0
  20. code_graph_builder/examples/test_kuzu_local.py +190 -0
  21. code_graph_builder/examples/test_rag_redis.py +390 -0
  22. code_graph_builder/graph_updater.py +605 -0
  23. code_graph_builder/guidance/__init__.py +1 -0
  24. code_graph_builder/guidance/agent.py +123 -0
  25. code_graph_builder/guidance/prompts.py +74 -0
  26. code_graph_builder/guidance/toolset.py +264 -0
  27. code_graph_builder/language_spec.py +536 -0
  28. code_graph_builder/mcp/__init__.py +21 -0
  29. code_graph_builder/mcp/api_doc_generator.py +764 -0
  30. code_graph_builder/mcp/file_editor.py +207 -0
  31. code_graph_builder/mcp/pipeline.py +777 -0
  32. code_graph_builder/mcp/server.py +161 -0
  33. code_graph_builder/mcp/tools.py +1800 -0
  34. code_graph_builder/models.py +115 -0
  35. code_graph_builder/parser_loader.py +344 -0
  36. code_graph_builder/parsers/__init__.py +7 -0
  37. code_graph_builder/parsers/call_processor.py +306 -0
  38. code_graph_builder/parsers/call_resolver.py +139 -0
  39. code_graph_builder/parsers/definition_processor.py +796 -0
  40. code_graph_builder/parsers/factory.py +119 -0
  41. code_graph_builder/parsers/import_processor.py +293 -0
  42. code_graph_builder/parsers/structure_processor.py +145 -0
  43. code_graph_builder/parsers/type_inference.py +143 -0
  44. code_graph_builder/parsers/utils.py +134 -0
  45. code_graph_builder/rag/__init__.py +68 -0
  46. code_graph_builder/rag/camel_agent.py +429 -0
  47. code_graph_builder/rag/client.py +298 -0
  48. code_graph_builder/rag/config.py +239 -0
  49. code_graph_builder/rag/cypher_generator.py +67 -0
  50. code_graph_builder/rag/llm_backend.py +210 -0
  51. code_graph_builder/rag/markdown_generator.py +352 -0
  52. code_graph_builder/rag/prompt_templates.py +440 -0
  53. code_graph_builder/rag/rag_engine.py +640 -0
  54. code_graph_builder/rag/review_report.md +172 -0
  55. code_graph_builder/rag/tests/__init__.py +3 -0
  56. code_graph_builder/rag/tests/test_camel_agent.py +313 -0
  57. code_graph_builder/rag/tests/test_client.py +221 -0
  58. code_graph_builder/rag/tests/test_config.py +177 -0
  59. code_graph_builder/rag/tests/test_markdown_generator.py +240 -0
  60. code_graph_builder/rag/tests/test_prompt_templates.py +160 -0
  61. code_graph_builder/services/__init__.py +39 -0
  62. code_graph_builder/services/graph_service.py +465 -0
  63. code_graph_builder/services/kuzu_service.py +665 -0
  64. code_graph_builder/services/memory_service.py +171 -0
  65. code_graph_builder/settings.py +75 -0
  66. code_graph_builder/tests/ACCEPTANCE_CRITERIA_PHASE2.md +401 -0
  67. code_graph_builder/tests/__init__.py +1 -0
  68. code_graph_builder/tests/run_acceptance_check.py +378 -0
  69. code_graph_builder/tests/test_api_find.py +231 -0
  70. code_graph_builder/tests/test_api_find_integration.py +226 -0
  71. code_graph_builder/tests/test_basic.py +78 -0
  72. code_graph_builder/tests/test_c_api_extraction.py +388 -0
  73. code_graph_builder/tests/test_call_resolution_scenarios.py +504 -0
  74. code_graph_builder/tests/test_embedder.py +411 -0
  75. code_graph_builder/tests/test_integration_semantic.py +434 -0
  76. code_graph_builder/tests/test_mcp_protocol.py +298 -0
  77. code_graph_builder/tests/test_mcp_user_flow.py +190 -0
  78. code_graph_builder/tests/test_rag.py +404 -0
  79. code_graph_builder/tests/test_settings.py +135 -0
  80. code_graph_builder/tests/test_step1_graph_build.py +264 -0
  81. code_graph_builder/tests/test_step2_api_docs.py +323 -0
  82. code_graph_builder/tests/test_step3_embedding.py +278 -0
  83. code_graph_builder/tests/test_vector_store.py +552 -0
  84. code_graph_builder/tools/__init__.py +40 -0
  85. code_graph_builder/tools/graph_query.py +495 -0
  86. code_graph_builder/tools/semantic_search.py +387 -0
  87. code_graph_builder/types.py +333 -0
  88. code_graph_builder/utils/__init__.py +0 -0
  89. code_graph_builder/utils/path_utils.py +30 -0
  90. code_graph_builder-0.2.0.dist-info/METADATA +321 -0
  91. code_graph_builder-0.2.0.dist-info/RECORD +93 -0
  92. code_graph_builder-0.2.0.dist-info/WHEEL +4 -0
  93. code_graph_builder-0.2.0.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,160 @@
1
+ """Tests for prompt templates."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import pytest
6
+
7
+ from code_graph_builder.rag.prompt_templates import (
8
+ CodeAnalysisPrompts,
9
+ CodeContext,
10
+ RAGPrompts,
11
+ create_code_context,
12
+ get_default_prompts,
13
+ )
14
+
15
+
16
class TestCodeContext:
    """Unit tests for constructing and formatting ``CodeContext`` objects."""

    def test_basic_creation(self):
        """Constructor arguments are readable back as attributes."""
        expected = {
            "source_code": "def foo(): pass",
            "file_path": "test.py",
            "qualified_name": "test.foo",
            "entity_type": "Function",
        }
        ctx = CodeContext(**expected)
        for attr_name, value in expected.items():
            assert getattr(ctx, attr_name) == value

    def test_format_context(self):
        """A fully populated context renders each labelled section."""
        ctx = CodeContext(
            source_code="def foo(): pass",
            file_path="test.py",
            qualified_name="test.foo",
            entity_type="Function",
            docstring="Test function",
            callers=["caller1", "caller2"],
            callees=["callee1"],
        )
        rendered = ctx.format_context()
        expected_fragments = (
            "Entity: test.foo",
            "Type: Function",
            "File: test.py",
            "Documentation:",
            "def foo(): pass",
            "Called By:",
            "Calls:",
        )
        for fragment in expected_fragments:
            assert fragment in rendered

    def test_format_context_minimal(self):
        """A context holding only source code still renders that code."""
        rendered = CodeContext(source_code="x = 1").format_context()
        for fragment in ("Source Code:", "x = 1"):
            assert fragment in rendered
58
+
59
+
60
class TestCodeAnalysisPrompts:
    """Unit tests for the prompt builders exposed by ``CodeAnalysisPrompts``."""

    def test_get_system_prompt(self):
        """The system prompt is non-empty and mentions the analyst persona."""
        system_prompt = CodeAnalysisPrompts().get_system_prompt()
        assert "expert code analyst" in system_prompt.lower()
        assert len(system_prompt) > 0

    def test_format_explain_prompt(self):
        """The explain prompt contains the instruction and the source code."""
        snippet = "def foo(): pass"
        builder = CodeAnalysisPrompts()
        rendered = builder.format_explain_prompt(CodeContext(source_code=snippet))
        assert "explain" in rendered.lower()
        assert snippet in rendered

    def test_format_query_prompt(self):
        """The query prompt embeds both the question and the source code."""
        question = "What does this do?"
        snippet = "def foo(): pass"
        builder = CodeAnalysisPrompts()
        rendered = builder.format_query_prompt(
            question, CodeContext(source_code=snippet)
        )
        assert question in rendered
        assert snippet in rendered

    def test_format_documentation_prompt(self):
        """The documentation prompt mentions documentation."""
        builder = CodeAnalysisPrompts()
        context = CodeContext(source_code="def foo(): pass")
        assert "documentation" in builder.format_documentation_prompt(context).lower()

    def test_format_architecture_prompt(self):
        """The architecture prompt mentions architecture."""
        builder = CodeAnalysisPrompts()
        context = CodeContext(source_code="class Foo: pass")
        assert "architecture" in builder.format_architecture_prompt(context).lower()

    def test_format_summary_prompt(self):
        """The summary prompt mentions summary."""
        builder = CodeAnalysisPrompts()
        context = CodeContext(source_code="def foo(): pass")
        assert "summary" in builder.format_summary_prompt(context).lower()

    def test_format_multi_context_prompt(self):
        """The multi-context prompt carries the question and numbers each context."""
        builder = CodeAnalysisPrompts()
        snippets = ("def foo(): pass", "def bar(): pass")
        contexts = [CodeContext(source_code=snippet) for snippet in snippets]
        rendered = builder.format_multi_context_prompt("Compare these", contexts)
        for fragment in ("Compare these", "Context 1", "Context 2"):
            assert fragment in rendered
118
+
119
+
120
class TestRAGPrompts:
    """Unit tests for ``RAGPrompts.format_rag_query``."""

    def test_format_rag_query_with_contexts(self):
        """With a context supplied, the user prompt names the query and the entity."""
        builder = RAGPrompts()
        context = CodeContext(
            source_code="def foo(): pass",
            qualified_name="test.foo",
        )
        system, user = builder.format_rag_query("Explain this", [context])
        assert len(system) > 0
        for fragment in ("Explain this", "test.foo"):
            assert fragment in user

    def test_format_rag_query_no_contexts(self):
        """With no contexts, the user prompt says that no relevant code was found."""
        # Unpacking also checks that a (system, user) pair is still returned.
        _system, user = RAGPrompts().format_rag_query("Explain this", [])
        assert "No relevant code" in user
142
+
143
+
144
class TestConvenienceFunctions:
    """Unit tests for the module-level convenience helpers."""

    def test_get_default_prompts(self):
        """``get_default_prompts`` returns a ``RAGPrompts`` instance."""
        assert isinstance(get_default_prompts(), RAGPrompts)

    def test_create_code_context(self):
        """``create_code_context`` returns a ``CodeContext`` holding the given source."""
        snippet = "def foo(): pass"
        ctx = create_code_context(
            source_code=snippet,
            file_path="test.py",
            qualified_name="test.foo",
        )
        assert isinstance(ctx, CodeContext)
        assert ctx.source_code == snippet
@@ -0,0 +1,39 @@
1
+ """Code Graph Builder - Services."""
2
+
3
+ from typing import Protocol, runtime_checkable
4
+
5
+ from ..types import PropertyDict, PropertyValue, ResultRow
6
+
7
+
8
@runtime_checkable
class IngestorProtocol(Protocol):
    """Structural interface for graph data ingestors.

    Because the protocol is ``runtime_checkable``, ``isinstance`` checks only
    verify that the three methods exist; signatures are not compared.
    """

    def ensure_node_batch(self, label: str, properties: PropertyDict) -> None:
        """Accept a node described by ``label`` and ``properties``."""

    def ensure_relationship_batch(
        self,
        from_spec: tuple[str, str, PropertyValue],
        rel_type: str,
        to_spec: tuple[str, str, PropertyValue],
        properties: PropertyDict | None = None,
    ) -> None:
        """Accept a ``rel_type`` relationship from ``from_spec`` to ``to_spec``."""

    def flush_all(self) -> None:
        """Flush whatever the implementation has pending."""
23
+
24
+
25
@runtime_checkable
class QueryProtocol(Protocol):
    """Structural interface for graph query operations.

    Because the protocol is ``runtime_checkable``, ``isinstance`` checks only
    verify that both methods exist; signatures are not compared.
    """

    def fetch_all(
        self,
        query: str,
        params: PropertyDict | None = None,
    ) -> list[ResultRow]:
        """Run ``query`` with optional ``params`` and return the result rows."""

    def execute_write(
        self,
        query: str,
        params: PropertyDict | None = None,
    ) -> None:
        """Run a write ``query`` with optional ``params``."""
34
+
35
+
36
+ # Import implementation
37
+ from .graph_service import MemgraphIngestor
38
+
39
+ __all__ = ["IngestorProtocol", "QueryProtocol", "MemgraphIngestor"]
@@ -0,0 +1,465 @@
1
+ """Graph service for connecting to and interacting with Memgraph."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import types
6
+ from collections import defaultdict
7
+ from collections.abc import Generator, Sequence
8
+ from contextlib import contextmanager
9
+ from datetime import UTC, datetime
10
+ from typing import TYPE_CHECKING
11
+
12
+ from loguru import logger
13
+
14
+ from ..types import (
15
+ BatchParams,
16
+ BatchWrapper,
17
+ ColumnDescriptor,
18
+ CursorProtocol,
19
+ GraphData,
20
+ GraphMetadata,
21
+ GraphNode,
22
+ NodeBatchRow,
23
+ PropertyDict,
24
+ PropertyValue,
25
+ RelBatchRow,
26
+ ResultRow,
27
+ ResultValue,
28
+ )
29
+
30
+ if TYPE_CHECKING:
31
+ import mgclient
32
+
33
+
34
+ class MemgraphIngestor:
35
+ """Ingestor for writing code graph data to Memgraph."""
36
+
37
+ def __init__(self, host: str, port: int, batch_size: int = 1000):
38
+ self._host = host
39
+ self._port = port
40
+ if batch_size < 1:
41
+ raise ValueError("batch_size must be at least 1")
42
+ self.batch_size = batch_size
43
+ self.conn: mgclient.Connection | None = None
44
+ self.node_buffer: list[tuple[str, dict[str, PropertyValue]]] = []
45
+ self.relationship_buffer: list[
46
+ tuple[
47
+ tuple[str, str, PropertyValue],
48
+ str,
49
+ tuple[str, str, PropertyValue],
50
+ dict[str, PropertyValue] | None,
51
+ ]
52
+ ] = []
53
+
54
def __enter__(self) -> MemgraphIngestor:
    """Open the Memgraph connection, enable autocommit and return ``self``."""
    # Local import: at module level mgclient is only imported for type checking.
    import mgclient

    logger.info(f"Connecting to Memgraph at {self._host}:{self._port}")
    endpoint = {"host": self._host, "port": self._port}
    self.conn = mgclient.connect(**endpoint)
    self.conn.autocommit = True
    logger.info("Connected to Memgraph")
    return self
62
+
63
def __exit__(
    self,
    exc_type: type | None,
    exc_val: Exception | None,
    exc_tb: types.TracebackType | None,
) -> None:
    """Flush pending data and close the connection.

    When the ``with`` body raised, the exception is logged and a best-effort
    flush is attempted; a failure of that flush is logged and not raised.
    When the body finished cleanly, a failing flush propagates to the caller.
    In every case the connection is closed afterwards.

    Args:
        exc_type: Exception class raised in the ``with`` body, if any.
        exc_val: Exception instance raised in the ``with`` body, if any.
        exc_tb: Traceback of that exception, if any.
    """
    try:
        if exc_type:
            logger.exception(f"Exception during ingest: {exc_val}")
            try:
                self.flush_all()
            except Exception as flush_err:
                logger.error(f"Flush error during exception handling: {flush_err}")
        else:
            self.flush_all()
    finally:
        # Close in ``finally`` so a failing flush on the clean-exit path
        # cannot leave the connection open.
        if self.conn:
            self.conn.close()
            logger.info("Disconnected from Memgraph")
80
+
81
+ @contextmanager
82
+ def _get_cursor(self) -> Generator[CursorProtocol, None, None]:
83
+ if not self.conn:
84
+ raise ConnectionError("Not connected to database")
85
+ cursor: CursorProtocol | None = None
86
+ try:
87
+ cursor = self.conn.cursor()
88
+ yield cursor
89
+ finally:
90
+ if cursor:
91
+ cursor.close()
92
+
93
def _cursor_to_results(self, cursor: CursorProtocol) -> list[ResultRow]:
    """Convert the rows of an executed cursor into dicts keyed by column name.

    Returns an empty list when the cursor has no ``description``.
    """
    if not cursor.description:
        return []

    names = [column.name for column in cursor.description]
    rows: list[ResultRow] = []
    for values in cursor.fetchall():
        record: dict[str, ResultValue] = dict(zip(names, values))
        rows.append(record)
    return rows
100
+
101
+ def _execute_query(
102
+ self,
103
+ query: str,
104
+ params: dict[str, PropertyValue] | None = None,
105
+ ) -> list[ResultRow]:
106
+ params = params or {}
107
+ with self._get_cursor() as cursor:
108
+ try:
109
+ cursor.execute(query, params)
110
+ return self._cursor_to_results(cursor)
111
+ except Exception as e:
112
+ if "already exists" not in str(e).lower():
113
+ logger.error(f"Query error: {e}")
114
+ logger.error(f"Query: {query}")
115
+ logger.error(f"Params: {params}")
116
+ raise
117
+
118
+ def _execute_batch(self, query: str, params_list: Sequence[BatchParams]) -> None:
119
+ if not self.conn or not params_list:
120
+ return
121
+ cursor = None
122
+ try:
123
+ cursor = self.conn.cursor()
124
+ cursor.execute(
125
+ f"UNWIND $batch AS row\n{query}",
126
+ BatchWrapper(batch=params_list),
127
+ )
128
+ except Exception as e:
129
+ if "already exists" not in str(e).lower():
130
+ logger.error(f"Batch error: {e}")
131
+ logger.error(f"Query: {query}")
132
+ finally:
133
+ if cursor:
134
+ cursor.close()
135
+
136
+ def fetch_all(
137
+ self,
138
+ query: str,
139
+ params: dict[str, PropertyValue] | None = None,
140
+ ) -> list[ResultRow]:
141
+ """Execute a query and return all results."""
142
+ return self._execute_query(query, params)
143
+
144
+ def ensure_node_batch(
145
+ self,
146
+ label: str,
147
+ id_key: str,
148
+ id_val: PropertyValue,
149
+ props: dict[str, PropertyValue],
150
+ ) -> None:
151
+ """Queue a node for batch insertion."""
152
+ unique_id = f"{label}:{id_key}:{id_val}"
153
+ self.node_buffer.append((unique_id, {"label": label, "id_key": id_key, "id_val": id_val, **props}))
154
+ if len(self.node_buffer) >= self.batch_size:
155
+ self.flush_nodes()
156
+
157
+ def ensure_relationship_batch(
158
+ self,
159
+ from_label: str,
160
+ from_key: str,
161
+ from_val: PropertyValue,
162
+ rel_type: str,
163
+ to_label: str,
164
+ to_key: str,
165
+ to_val: PropertyValue,
166
+ props: dict[str, PropertyValue] | None = None,
167
+ ) -> None:
168
+ """Queue a relationship for batch insertion."""
169
+ from_id = (from_label, from_key, from_val)
170
+ to_id = (to_label, to_key, to_val)
171
+ self.relationship_buffer.append((from_id, rel_type, to_id, props))
172
+ if len(self.relationship_buffer) >= self.batch_size:
173
+ self.flush_relationships()
174
+
175
+ def flush_nodes(self) -> None:
176
+ """Flush buffered nodes to the database."""
177
+ if not self.node_buffer:
178
+ return
179
+
180
+ # Group by label for efficient batching
181
+ by_label: defaultdict[str, list[dict]] = defaultdict(list)
182
+ for _unique_id, node_data in self.node_buffer:
183
+ label = node_data.pop("label")
184
+ by_label[label].append(node_data)
185
+
186
+ for label, nodes in by_label.items():
187
+ query = f"""
188
+ UNWIND $batch AS row
189
+ MERGE (n:{label} {{{nodes[0].get('id_key', 'id')}: row.id_val}})
190
+ SET n += row
191
+ """
192
+ self._execute_batch(query, nodes)
193
+
194
+ logger.debug(f"Flushed {len(self.node_buffer)} nodes")
195
+ self.node_buffer.clear()
196
+
197
+ def flush_relationships(self) -> None:
198
+ """Flush buffered relationships to the database."""
199
+ if not self.relationship_buffer:
200
+ return
201
+
202
+ # Group by type for efficient batching
203
+ by_type: defaultdict[
204
+ str,
205
+ list[dict],
206
+ ] = defaultdict(list)
207
+ for (from_label, from_key, from_val), rel_type, (to_label, to_key, to_val), props in self.relationship_buffer:
208
+ row: dict = {
209
+ "from_label": from_label,
210
+ "from_key": from_key,
211
+ "from_val": from_val,
212
+ "to_label": to_label,
213
+ "to_key": to_key,
214
+ "to_val": to_val,
215
+ }
216
+ if props:
217
+ row["props"] = props
218
+ by_type[rel_type].append(row)
219
+
220
+ for rel_type, rels in by_type.items():
221
+ query = f"""
222
+ UNWIND $batch AS row
223
+ MATCH (a {{{rels[0].get('from_key', 'id')}: row.from_val}})
224
+ MATCH (b {{{rels[0].get('to_key', 'id')}: row.to_val}})
225
+ MERGE (a)-[r:{rel_type}]->(b)
226
+ SET r += row.props
227
+ """
228
+ self._execute_batch(query, rels)
229
+
230
+ logger.debug(f"Flushed {len(self.relationship_buffer)} relationships")
231
+ self.relationship_buffer.clear()
232
+
233
+ def flush_all(self) -> None:
234
+ """Flush all buffered data to the database."""
235
+ self.flush_nodes()
236
+ self.flush_relationships()
237
+
238
def clean_database(self) -> None:
    """Delete every node and relationship in the database."""
    wipe_query = "MATCH (n) DETACH DELETE n;"
    logger.warning("Cleaning database - deleting all nodes and relationships")
    self._execute_query(wipe_query)
    logger.info("Database cleaned")
243
+
244
+ def list_projects(self) -> list[str]:
245
+ """List all projects in the database."""
246
+ results = self._execute_query(
247
+ "MATCH (p:Project) RETURN p.name AS name ORDER BY p.name"
248
+ )
249
+ return [row["name"] for row in results if row.get("name")]
250
+
251
def delete_project(self, project_name: str) -> None:
    """Delete the named project node plus its containers and their definitions.

    Args:
        project_name: Value of the ``name`` property of the ``Project`` node.
    """
    logger.info(f"Deleting project: {project_name}")
    # Containers are reached over CONTAINS_* edges of any depth, and
    # definitions over DEFINES / DEFINES_METHOD edges from those containers.
    query = """
    MATCH (p:Project {name: $project_name})
    OPTIONAL MATCH (p)-[:CONTAINS_PACKAGE|CONTAINS_FOLDER|CONTAINS_FILE|CONTAINS_MODULE*]->(container)
    OPTIONAL MATCH (container)-[:DEFINES|DEFINES_METHOD*]->(defined)
    DETACH DELETE p, container, defined
    """
    bindings = {"project_name": project_name}
    self._execute_query(query, bindings)
    logger.info(f"Project {project_name} deleted")
262
+
263
def export_graph_to_dict(self) -> GraphData:
    """Export every node and relationship together with summary metadata.

    Returns:
        ``GraphData`` holding the node rows, the relationship rows and a
        ``GraphMetadata`` with both counts and a UTC ISO-8601 timestamp.
    """
    logger.info("Exporting graph to dictionary")

    nodes_query = """
    MATCH (n)
    RETURN id(n) as node_id, labels(n) as labels, properties(n) as properties
    """
    rels_query = """
    MATCH (a)-[r]->(b)
    RETURN id(a) as from_id, id(b) as to_id, type(r) as type, properties(r) as properties
    """

    # Nodes are fetched before relationships, as two separate queries.
    nodes = self._execute_query(nodes_query)
    relationships = self._execute_query(rels_query)

    metadata = GraphMetadata(
        total_nodes=len(nodes),
        total_relationships=len(relationships),
        exported_at=datetime.now(UTC).isoformat(),
    )

    logger.info(
        f"Exported {len(nodes)} nodes and {len(relationships)} relationships"
    )

    return GraphData(nodes=nodes, relationships=relationships, metadata=metadata)
294
+
295
+ def get_node_by_id(self, node_id: int) -> GraphNode | None:
296
+ """Get a node by its internal ID.
297
+
298
+ Args:
299
+ node_id: Memgraph internal node ID
300
+
301
+ Returns:
302
+ GraphNode if found, None otherwise
303
+ """
304
+ query = """
305
+ MATCH (n)
306
+ WHERE id(n) = $node_id
307
+ RETURN id(n) as node_id, labels(n) as labels, properties(n) as props
308
+ """
309
+ results = self._execute_query(query, {"node_id": node_id})
310
+
311
+ if not results:
312
+ return None
313
+
314
+ row = results[0]
315
+ return self._row_to_graph_node(row)
316
+
317
+ def get_nodes_by_ids(self, node_ids: list[int]) -> list[GraphNode]:
318
+ """Get multiple nodes by their internal IDs.
319
+
320
+ Args:
321
+ node_ids: List of Memgraph internal node IDs
322
+
323
+ Returns:
324
+ List of GraphNode objects
325
+ """
326
+ if not node_ids:
327
+ return []
328
+
329
+ query = """
330
+ MATCH (n)
331
+ WHERE id(n) IN $node_ids
332
+ RETURN id(n) as node_id, labels(n) as labels, properties(n) as props
333
+ """
334
+ results = self._execute_query(query, {"node_ids": node_ids})
335
+
336
+ return [self._row_to_graph_node(row) for row in results if row]
337
+
338
+ def search_nodes(
339
+ self,
340
+ query_str: str,
341
+ label: str | None = None,
342
+ limit: int = 10,
343
+ ) -> list[GraphNode]:
344
+ """Search nodes by name or qualified name.
345
+
346
+ Args:
347
+ query_str: Search query string
348
+ label: Optional node label filter
349
+ limit: Maximum number of results
350
+
351
+ Returns:
352
+ List of matching GraphNode objects
353
+ """
354
+ if label:
355
+ cypher = """
356
+ MATCH (n:$label)
357
+ WHERE n.name CONTAINS $query OR n.qualified_name CONTAINS $query
358
+ RETURN id(n) as node_id, labels(n) as labels, properties(n) as props
359
+ LIMIT $limit
360
+ """
361
+ cypher = cypher.replace("$label", label)
362
+ else:
363
+ cypher = """
364
+ MATCH (n)
365
+ WHERE n.name CONTAINS $query OR n.qualified_name CONTAINS $query
366
+ RETURN id(n) as node_id, labels(n) as labels, properties(n) as props
367
+ LIMIT $limit
368
+ """
369
+
370
+ results = self._execute_query(
371
+ cypher, {"query": query_str, "limit": limit}
372
+ )
373
+
374
+ return [self._row_to_graph_node(row) for row in results if row]
375
+
376
+ def get_node_relationships(
377
+ self,
378
+ node_id: int,
379
+ rel_type: str | None = None,
380
+ direction: str = "both",
381
+ ) -> list[tuple[GraphNode, str, str]]:
382
+ """Get relationships for a node.
383
+
384
+ Args:
385
+ node_id: Node identifier
386
+ rel_type: Optional relationship type filter
387
+ direction: "out", "in", or "both"
388
+
389
+ Returns:
390
+ List of (related_node, relationship_type, direction) tuples
391
+ """
392
+ results = []
393
+
394
+ if direction in ("out", "both"):
395
+ if rel_type:
396
+ query = f"""
397
+ MATCH (n)-[r:{rel_type}]->(m)
398
+ WHERE id(n) = $node_id
399
+ RETURN id(m) as node_id, labels(m) as labels, properties(m) as props,
400
+ type(r) as rel_type, "out" as direction
401
+ """
402
+ else:
403
+ query = """
404
+ MATCH (n)-[r]->(m)
405
+ WHERE id(n) = $node_id
406
+ RETURN id(m) as node_id, labels(m) as labels, properties(m) as props,
407
+ type(r) as rel_type, "out" as direction
408
+ """
409
+ rows = self._execute_query(query, {"node_id": node_id})
410
+ for row in rows:
411
+ node = self._row_to_graph_node(row)
412
+ results.append((node, row.get("rel_type", "UNKNOWN"), "out"))
413
+
414
+ if direction in ("in", "both"):
415
+ if rel_type:
416
+ query = f"""
417
+ MATCH (n)<-[r:{rel_type}]-(m)
418
+ WHERE id(n) = $node_id
419
+ RETURN id(m) as node_id, labels(m) as labels, properties(m) as props,
420
+ type(r) as rel_type, "in" as direction
421
+ """
422
+ else:
423
+ query = """
424
+ MATCH (n)<-[r]-(m)
425
+ WHERE id(n) = $node_id
426
+ RETURN id(m) as node_id, labels(m) as labels, properties(m) as props,
427
+ type(r) as rel_type, "in" as direction
428
+ """
429
+ rows = self._execute_query(query, {"node_id": node_id})
430
+ for row in rows:
431
+ node = self._row_to_graph_node(row)
432
+ results.append((node, row.get("rel_type", "UNKNOWN"), "in"))
433
+
434
+ return results
435
+
436
def _row_to_graph_node(self, row: ResultRow) -> GraphNode:
    """Build a GraphNode from a query result row.

    The row is read through its ``node_id``, ``labels`` and ``props`` keys.
    ``props`` falls back to an empty dict when it is not a dict, and
    ``labels`` falls back to an empty list when it is not a list.

    Args:
        row: Query result row

    Returns:
        GraphNode instance
    """
    raw_props = row.get("props", {})
    properties = dict(raw_props) if isinstance(raw_props, dict) else {}

    raw_labels = row.get("labels", [])
    labels = raw_labels if isinstance(raw_labels, list) else []

    return GraphNode(
        node_id=row.get("node_id", 0),
        labels=labels,
        qualified_name=properties.get("qualified_name", ""),
        name=properties.get("name", ""),
        path=properties.get("path"),
        start_line=properties.get("start_line"),
        end_line=properties.get("end_line"),
        docstring=properties.get("docstring"),
        properties=properties,
    )