cognee 0.3.5-py3-none-any.whl → 0.3.7-py3-none-any.whl

This diff compares the contents of two publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
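A comparison like this can be reproduced locally with nothing but the Python standard library. The sketch below is illustrative only: it assumes both wheel files have already been downloaded into the current directory (for example with `pip download cognee==0.3.5 --no-deps`), and the filenames it uses are placeholders rather than anything taken from this page.

    # compare_wheels.py -- minimal sketch; filenames below are illustrative.
    import difflib
    import zipfile

    OLD = "cognee-0.3.5-py3-none-any.whl"
    NEW = "cognee-0.3.7-py3-none-any.whl"

    def read_members(path):
        """Return {archive member name: decoded text} for the files in a wheel."""
        with zipfile.ZipFile(path) as archive:
            return {
                name: archive.read(name).decode("utf-8", errors="replace")
                for name in archive.namelist()
                if not name.endswith("/")
            }

    old_files = read_members(OLD)
    new_files = read_members(NEW)

    # Walk every path present in either wheel and print a unified diff for
    # anything that was added, removed, or changed between the two versions.
    for name in sorted(old_files.keys() | new_files.keys()):
        old_text = old_files.get(name, "")
        new_text = new_files.get(name, "")
        if old_text == new_text:
            continue
        diff = difflib.unified_diff(
            old_text.splitlines(keepends=True),
            new_text.splitlines(keepends=True),
            fromfile=f"0.3.5/{name}",
            tofile=f"0.3.7/{name}",
        )
        print("".join(diff))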
Files changed (161)
  1. cognee/__init__.py +1 -0
  2. cognee/api/health.py +2 -12
  3. cognee/api/v1/add/add.py +46 -6
  4. cognee/api/v1/add/routers/get_add_router.py +5 -1
  5. cognee/api/v1/cognify/cognify.py +29 -9
  6. cognee/api/v1/datasets/datasets.py +11 -0
  7. cognee/api/v1/responses/default_tools.py +0 -1
  8. cognee/api/v1/responses/dispatch_function.py +1 -1
  9. cognee/api/v1/responses/routers/default_tools.py +0 -1
  10. cognee/api/v1/search/search.py +11 -9
  11. cognee/api/v1/settings/routers/get_settings_router.py +7 -1
  12. cognee/api/v1/ui/ui.py +47 -16
  13. cognee/api/v1/update/routers/get_update_router.py +1 -1
  14. cognee/api/v1/update/update.py +3 -3
  15. cognee/cli/_cognee.py +61 -10
  16. cognee/cli/commands/add_command.py +3 -3
  17. cognee/cli/commands/cognify_command.py +3 -3
  18. cognee/cli/commands/config_command.py +9 -7
  19. cognee/cli/commands/delete_command.py +3 -3
  20. cognee/cli/commands/search_command.py +3 -7
  21. cognee/cli/config.py +0 -1
  22. cognee/context_global_variables.py +5 -0
  23. cognee/exceptions/exceptions.py +1 -1
  24. cognee/infrastructure/databases/cache/__init__.py +2 -0
  25. cognee/infrastructure/databases/cache/cache_db_interface.py +79 -0
  26. cognee/infrastructure/databases/cache/config.py +44 -0
  27. cognee/infrastructure/databases/cache/get_cache_engine.py +67 -0
  28. cognee/infrastructure/databases/cache/redis/RedisAdapter.py +243 -0
  29. cognee/infrastructure/databases/exceptions/__init__.py +1 -0
  30. cognee/infrastructure/databases/exceptions/exceptions.py +18 -2
  31. cognee/infrastructure/databases/graph/get_graph_engine.py +1 -1
  32. cognee/infrastructure/databases/graph/graph_db_interface.py +5 -0
  33. cognee/infrastructure/databases/graph/kuzu/adapter.py +67 -44
  34. cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +13 -3
  35. cognee/infrastructure/databases/graph/neo4j_driver/deadlock_retry.py +1 -1
  36. cognee/infrastructure/databases/graph/neptune_driver/neptune_utils.py +1 -1
  37. cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +1 -1
  38. cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +21 -3
  39. cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +17 -10
  40. cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +17 -4
  41. cognee/infrastructure/databases/vector/embeddings/config.py +2 -3
  42. cognee/infrastructure/databases/vector/exceptions/exceptions.py +1 -1
  43. cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +0 -1
  44. cognee/infrastructure/files/exceptions.py +1 -1
  45. cognee/infrastructure/files/storage/LocalFileStorage.py +9 -9
  46. cognee/infrastructure/files/storage/S3FileStorage.py +11 -11
  47. cognee/infrastructure/files/utils/guess_file_type.py +6 -0
  48. cognee/infrastructure/llm/prompts/search_type_selector_prompt.txt +0 -5
  49. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +19 -9
  50. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +17 -5
  51. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +17 -5
  52. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +32 -0
  53. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/__init__.py +0 -0
  54. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +109 -0
  55. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +33 -8
  56. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +40 -18
  57. cognee/infrastructure/loaders/LoaderEngine.py +27 -7
  58. cognee/infrastructure/loaders/external/__init__.py +7 -0
  59. cognee/infrastructure/loaders/external/advanced_pdf_loader.py +2 -8
  60. cognee/infrastructure/loaders/external/beautiful_soup_loader.py +310 -0
  61. cognee/infrastructure/loaders/supported_loaders.py +7 -0
  62. cognee/modules/data/exceptions/exceptions.py +1 -1
  63. cognee/modules/data/methods/__init__.py +3 -0
  64. cognee/modules/data/methods/get_dataset_data.py +4 -1
  65. cognee/modules/data/methods/has_dataset_data.py +21 -0
  66. cognee/modules/engine/models/TableRow.py +0 -1
  67. cognee/modules/ingestion/save_data_to_file.py +9 -2
  68. cognee/modules/pipelines/exceptions/exceptions.py +1 -1
  69. cognee/modules/pipelines/operations/pipeline.py +12 -1
  70. cognee/modules/pipelines/operations/run_tasks.py +25 -197
  71. cognee/modules/pipelines/operations/run_tasks_data_item.py +260 -0
  72. cognee/modules/pipelines/operations/run_tasks_distributed.py +121 -38
  73. cognee/modules/retrieval/EntityCompletionRetriever.py +48 -8
  74. cognee/modules/retrieval/base_graph_retriever.py +3 -1
  75. cognee/modules/retrieval/base_retriever.py +3 -1
  76. cognee/modules/retrieval/chunks_retriever.py +5 -1
  77. cognee/modules/retrieval/code_retriever.py +20 -2
  78. cognee/modules/retrieval/completion_retriever.py +50 -9
  79. cognee/modules/retrieval/cypher_search_retriever.py +11 -1
  80. cognee/modules/retrieval/graph_completion_context_extension_retriever.py +47 -8
  81. cognee/modules/retrieval/graph_completion_cot_retriever.py +32 -1
  82. cognee/modules/retrieval/graph_completion_retriever.py +54 -10
  83. cognee/modules/retrieval/lexical_retriever.py +20 -2
  84. cognee/modules/retrieval/natural_language_retriever.py +10 -1
  85. cognee/modules/retrieval/summaries_retriever.py +5 -1
  86. cognee/modules/retrieval/temporal_retriever.py +62 -10
  87. cognee/modules/retrieval/user_qa_feedback.py +3 -2
  88. cognee/modules/retrieval/utils/completion.py +5 -0
  89. cognee/modules/retrieval/utils/description_to_codepart_search.py +1 -1
  90. cognee/modules/retrieval/utils/session_cache.py +156 -0
  91. cognee/modules/search/methods/get_search_type_tools.py +0 -5
  92. cognee/modules/search/methods/no_access_control_search.py +12 -1
  93. cognee/modules/search/methods/search.py +34 -2
  94. cognee/modules/search/types/SearchType.py +0 -1
  95. cognee/modules/settings/get_settings.py +23 -0
  96. cognee/modules/users/methods/get_authenticated_user.py +3 -1
  97. cognee/modules/users/methods/get_default_user.py +1 -6
  98. cognee/modules/users/roles/methods/create_role.py +2 -2
  99. cognee/modules/users/tenants/methods/create_tenant.py +2 -2
  100. cognee/shared/exceptions/exceptions.py +1 -1
  101. cognee/tasks/codingagents/coding_rule_associations.py +1 -2
  102. cognee/tasks/documents/exceptions/exceptions.py +1 -1
  103. cognee/tasks/graph/extract_graph_from_data.py +2 -0
  104. cognee/tasks/ingestion/data_item_to_text_file.py +3 -3
  105. cognee/tasks/ingestion/ingest_data.py +11 -5
  106. cognee/tasks/ingestion/save_data_item_to_storage.py +12 -1
  107. cognee/tasks/storage/add_data_points.py +3 -10
  108. cognee/tasks/storage/index_data_points.py +19 -14
  109. cognee/tasks/storage/index_graph_edges.py +25 -11
  110. cognee/tasks/web_scraper/__init__.py +34 -0
  111. cognee/tasks/web_scraper/config.py +26 -0
  112. cognee/tasks/web_scraper/default_url_crawler.py +446 -0
  113. cognee/tasks/web_scraper/models.py +46 -0
  114. cognee/tasks/web_scraper/types.py +4 -0
  115. cognee/tasks/web_scraper/utils.py +142 -0
  116. cognee/tasks/web_scraper/web_scraper_task.py +396 -0
  117. cognee/tests/cli_tests/cli_unit_tests/test_cli_utils.py +0 -1
  118. cognee/tests/integration/web_url_crawler/test_default_url_crawler.py +13 -0
  119. cognee/tests/integration/web_url_crawler/test_tavily_crawler.py +19 -0
  120. cognee/tests/integration/web_url_crawler/test_url_adding_e2e.py +344 -0
  121. cognee/tests/subprocesses/reader.py +25 -0
  122. cognee/tests/subprocesses/simple_cognify_1.py +31 -0
  123. cognee/tests/subprocesses/simple_cognify_2.py +31 -0
  124. cognee/tests/subprocesses/writer.py +32 -0
  125. cognee/tests/tasks/descriptive_metrics/metrics_test_utils.py +0 -2
  126. cognee/tests/tasks/descriptive_metrics/neo4j_metrics_test.py +8 -3
  127. cognee/tests/tasks/entity_extraction/entity_extraction_test.py +89 -0
  128. cognee/tests/tasks/web_scraping/web_scraping_test.py +172 -0
  129. cognee/tests/test_add_docling_document.py +56 -0
  130. cognee/tests/test_chromadb.py +7 -11
  131. cognee/tests/test_concurrent_subprocess_access.py +76 -0
  132. cognee/tests/test_conversation_history.py +240 -0
  133. cognee/tests/test_kuzu.py +27 -15
  134. cognee/tests/test_lancedb.py +7 -11
  135. cognee/tests/test_library.py +32 -2
  136. cognee/tests/test_neo4j.py +24 -16
  137. cognee/tests/test_neptune_analytics_vector.py +7 -11
  138. cognee/tests/test_permissions.py +9 -13
  139. cognee/tests/test_pgvector.py +4 -4
  140. cognee/tests/test_remote_kuzu.py +8 -11
  141. cognee/tests/test_s3_file_storage.py +1 -1
  142. cognee/tests/test_search_db.py +6 -8
  143. cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py +89 -0
  144. cognee/tests/unit/modules/retrieval/conversation_history_test.py +154 -0
  145. {cognee-0.3.5.dist-info → cognee-0.3.7.dist-info}/METADATA +22 -7
  146. {cognee-0.3.5.dist-info → cognee-0.3.7.dist-info}/RECORD +155 -128
  147. {cognee-0.3.5.dist-info → cognee-0.3.7.dist-info}/entry_points.txt +1 -0
  148. distributed/Dockerfile +0 -3
  149. distributed/entrypoint.py +21 -9
  150. distributed/signal.py +5 -0
  151. distributed/workers/data_point_saving_worker.py +64 -34
  152. distributed/workers/graph_saving_worker.py +71 -47
  153. cognee/infrastructure/databases/graph/memgraph/memgraph_adapter.py +0 -1116
  154. cognee/modules/retrieval/insights_retriever.py +0 -133
  155. cognee/tests/test_memgraph.py +0 -109
  156. cognee/tests/unit/modules/retrieval/insights_retriever_test.py +0 -251
  157. distributed/poetry.lock +0 -12238
  158. distributed/pyproject.toml +0 -185
  159. {cognee-0.3.5.dist-info → cognee-0.3.7.dist-info}/WHEEL +0 -0
  160. {cognee-0.3.5.dist-info → cognee-0.3.7.dist-info}/licenses/LICENSE +0 -0
  161. {cognee-0.3.5.dist-info → cognee-0.3.7.dist-info}/licenses/NOTICE.md +0 -0
cognee/infrastructure/databases/graph/memgraph/memgraph_adapter.py (removed)
@@ -1,1116 +0,0 @@
1
- """Memgraph Adapter for Graph Database"""
2
-
3
- import json
4
- from cognee.shared.logging_utils import get_logger, ERROR
5
- import asyncio
6
- from textwrap import dedent
7
- from typing import Optional, Any, List, Dict, Type, Tuple
8
- from contextlib import asynccontextmanager
9
- from uuid import UUID
10
- from neo4j import AsyncSession
11
- from neo4j import AsyncGraphDatabase
12
- from neo4j.exceptions import Neo4jError
13
- from cognee.infrastructure.engine import DataPoint
14
- from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
15
- from cognee.modules.storage.utils import JSONEncoder
16
- from cognee.infrastructure.databases.exceptions.exceptions import NodesetFilterNotSupportedError
17
-
18
- logger = get_logger("MemgraphAdapter", level=ERROR)
19
-
20
-
21
- class MemgraphAdapter(GraphDBInterface):
22
- """
23
- Handles interaction with a Memgraph database through various graph operations.
24
-
25
- Public methods include:
26
- - get_session
27
- - query
28
- - has_node
29
- - add_node
30
- - add_nodes
31
- - extract_node
32
- - extract_nodes
33
- - delete_node
34
- - delete_nodes
35
- - has_edge
36
- - has_edges
37
- - add_edge
38
- - add_edges
39
- - get_edges
40
- - get_disconnected_nodes
41
- - get_predecessors
42
- - get_successors
43
- - get_neighbours
44
- - get_connections
45
- - remove_connection_to_predecessors_of
46
- - remove_connection_to_successors_of
47
- - delete_graph
48
- - serialize_properties
49
- - get_model_independent_graph_data
50
- - get_graph_data
51
- - get_nodeset_subgraph
52
- - get_filtered_graph_data
53
- - get_node_labels_string
54
- - get_relationship_labels_string
55
- - get_graph_metrics
56
- """
57
-
58
- def __init__(
59
- self,
60
- graph_database_url: str,
61
- graph_database_username: Optional[str] = None,
62
- graph_database_password: Optional[str] = None,
63
- driver: Optional[Any] = None,
64
- ):
65
- # Only use auth if both username and password are provided
66
- auth = None
67
- if graph_database_username and graph_database_password:
68
- auth = (graph_database_username, graph_database_password)
69
-
70
- self.driver = driver or AsyncGraphDatabase.driver(
71
- graph_database_url,
72
- auth=auth,
73
- max_connection_lifetime=120,
74
- )
75
-
76
- @asynccontextmanager
77
- async def get_session(self) -> AsyncSession:
78
- """
79
- Manage a session with the database, yielding the session for use in operations.
80
- """
81
- async with self.driver.session() as session:
82
- yield session
83
-
84
- async def query(
85
- self,
86
- query: str,
87
- params: Optional[Dict[str, Any]] = None,
88
- ) -> List[Dict[str, Any]]:
89
- """
90
- Execute a provided query on the Memgraph database and return the results.
91
-
92
- Parameters:
93
- -----------
94
-
95
- - query (str): The Cypher query to be executed against the database.
96
- - params (Optional[Dict[str, Any]]): Optional parameters to be used in the query.
97
- (default None)
98
-
99
- Returns:
100
- --------
101
-
102
- - List[Dict[str, Any]]: A list of dictionaries representing the result set of the
103
- query.
104
- """
105
- try:
106
- async with self.get_session() as session:
107
- result = await session.run(query, params)
108
- data = await result.data()
109
- return data
110
- except Neo4jError as error:
111
- logger.error("Memgraph query error: %s", error, exc_info=True)
112
- raise error
113
-
114
- async def has_node(self, node_id: str) -> bool:
115
- """
116
- Determine if a node with the given ID exists in the database.
117
-
118
- Parameters:
119
- -----------
120
-
121
- - node_id (str): The ID of the node to check for existence.
122
-
123
- Returns:
124
- --------
125
-
126
- - bool: True if the node exists; otherwise, False.
127
- """
128
- results = await self.query(
129
- """
130
- MATCH (n)
131
- WHERE n.id = $node_id
132
- RETURN COUNT(n) > 0 AS node_exists
133
- """,
134
- {"node_id": node_id},
135
- )
136
- return results[0]["node_exists"] if len(results) > 0 else False
137
-
138
- async def add_node(self, node: DataPoint):
139
- """
140
- Add a new node to the database with specified properties.
141
-
142
- Parameters:
143
- -----------
144
-
145
- - node (DataPoint): The DataPoint object representing the node to add.
146
-
147
- Returns:
148
- --------
149
-
150
- The result of the node addition, including its internal ID and node ID.
151
- """
152
- serialized_properties = self.serialize_properties(node.model_dump())
153
-
154
- query = """
155
- MERGE (node {id: $node_id})
156
- ON CREATE SET node:$node_label, node += $properties, node.updated_at = timestamp()
157
- ON MATCH SET node:$node_label, node += $properties, node.updated_at = timestamp()
158
- RETURN ID(node) AS internal_id, node.id AS nodeId
159
- """
160
-
161
- params = {
162
- "node_id": str(node.id),
163
- "node_label": type(node).__name__,
164
- "properties": serialized_properties,
165
- }
166
- return await self.query(query, params)
167
-
168
- async def add_nodes(self, nodes: list[DataPoint]) -> None:
169
- """
170
- Add multiple nodes to the database in a single operation.
171
-
172
- Parameters:
173
- -----------
174
-
175
- - nodes (list[DataPoint]): A list of DataPoint objects representing the nodes to
176
- add.
177
-
178
- Returns:
179
- --------
180
-
181
- - None: None.
182
- """
183
- query = """
184
- UNWIND $nodes AS node
185
- MERGE (n {id: node.node_id})
186
- ON CREATE SET n:node.label, n += node.properties, n.updated_at = timestamp()
187
- ON MATCH SET n:node.label, n += node.properties, n.updated_at = timestamp()
188
- RETURN ID(n) AS internal_id, n.id AS nodeId
189
- """
190
-
191
- nodes = [
192
- {
193
- "node_id": str(node.id),
194
- "label": type(node).__name__,
195
- "properties": self.serialize_properties(node.model_dump()),
196
- }
197
- for node in nodes
198
- ]
199
-
200
- results = await self.query(query, dict(nodes=nodes))
201
- return results
202
-
203
- async def extract_node(self, node_id: str):
204
- """
205
- Retrieve a single node based on its ID.
206
-
207
- Parameters:
208
- -----------
209
-
210
- - node_id (str): The ID of the node to retrieve.
211
-
212
- Returns:
213
- --------
214
-
215
- The node corresponding to the provided ID, or None if not found.
216
- """
217
- results = await self.extract_nodes([node_id])
218
-
219
- return results[0] if len(results) > 0 else None
220
-
221
- async def extract_nodes(self, node_ids: List[str]):
222
- """
223
- Retrieve multiple nodes based on their IDs.
224
-
225
- Parameters:
226
- -----------
227
-
228
- - node_ids (List[str]): A list of IDs for the nodes to retrieve.
229
-
230
- Returns:
231
- --------
232
-
233
- A list of nodes corresponding to the provided IDs.
234
- """
235
- query = """
236
- UNWIND $node_ids AS id
237
- MATCH (node {id: id})
238
- RETURN node"""
239
-
240
- params = {"node_ids": node_ids}
241
-
242
- results = await self.query(query, params)
243
-
244
- return [result["node"] for result in results]
245
-
246
- async def delete_node(self, node_id: str):
247
- """
248
- Delete a node from the database based on its ID.
249
-
250
- Parameters:
251
- -----------
252
-
253
- - node_id (str): The ID of the node to delete.
254
-
255
- Returns:
256
- --------
257
-
258
- None.
259
- """
260
- sanitized_id = node_id.replace(":", "_")
261
-
262
- query = "MATCH (node: {{id: $node_id}}) DETACH DELETE node"
263
- params = {"node_id": sanitized_id}
264
-
265
- return await self.query(query, params)
266
-
267
- async def delete_nodes(self, node_ids: list[str]) -> None:
268
- """
269
- Delete multiple nodes from the database based on their IDs.
270
-
271
- Parameters:
272
- -----------
273
-
274
- - node_ids (list[str]): A list of IDs for the nodes to delete.
275
-
276
- Returns:
277
- --------
278
-
279
- - None: None.
280
- """
281
- query = """
282
- UNWIND $node_ids AS id
283
- MATCH (node {id: id})
284
- DETACH DELETE node"""
285
-
286
- params = {"node_ids": node_ids}
287
-
288
- return await self.query(query, params)
289
-
290
- async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> bool:
291
- """
292
- Check if a directed edge exists between two nodes identified by their IDs.
293
-
294
- Parameters:
295
- -----------
296
-
297
- - from_node (UUID): The ID of the source node.
298
- - to_node (UUID): The ID of the target node.
299
- - edge_label (str): The label of the edge to check.
300
-
301
- Returns:
302
- --------
303
-
304
- - bool: True if the edge exists; otherwise, False.
305
- """
306
- query = """
307
- MATCH (from_node)-[relationship]->(to_node)
308
- WHERE from_node.id = $from_node_id AND to_node.id = $to_node_id AND type(relationship) = $edge_label
309
- RETURN COUNT(relationship) > 0 AS edge_exists
310
- """
311
-
312
- params = {
313
- "from_node_id": str(from_node),
314
- "to_node_id": str(to_node),
315
- "edge_label": edge_label,
316
- }
317
-
318
- records = await self.query(query, params)
319
- return records[0]["edge_exists"] if records else False
320
-
321
- async def has_edges(self, edges):
322
- """
323
- Check for the existence of multiple edges based on provided criteria.
324
-
325
- Parameters:
326
- -----------
327
-
328
- - edges: A list of edges to verify existence for.
329
-
330
- Returns:
331
- --------
332
-
333
- A list of boolean values indicating the existence of each edge.
334
- """
335
- query = """
336
- UNWIND $edges AS edge
337
- MATCH (a)-[r]->(b)
338
- WHERE id(a) = edge.from_node AND id(b) = edge.to_node AND type(r) = edge.relationship_name
339
- RETURN edge.from_node AS from_node, edge.to_node AS to_node, edge.relationship_name AS relationship_name, count(r) > 0 AS edge_exists
340
- """
341
-
342
- try:
343
- params = {
344
- "edges": [
345
- {
346
- "from_node": str(edge[0]),
347
- "to_node": str(edge[1]),
348
- "relationship_name": edge[2],
349
- }
350
- for edge in edges
351
- ],
352
- }
353
-
354
- results = await self.query(query, params)
355
- return [result["edge_exists"] for result in results]
356
- except Neo4jError as error:
357
- logger.error("Memgraph query error: %s", error, exc_info=True)
358
- raise error
359
-
360
- async def add_edge(
361
- self,
362
- from_node: UUID,
363
- to_node: UUID,
364
- relationship_name: str,
365
- edge_properties: Optional[Dict[str, Any]] = None,
366
- ):
367
- """
368
- Add a directed edge between two nodes with optional properties.
369
-
370
- Parameters:
371
- -----------
372
-
373
- - from_node (UUID): The ID of the source node.
374
- - to_node (UUID): The ID of the target node.
375
- - relationship_name (str): The type/label of the relationship to create.
376
- - edge_properties (Optional[Dict[str, Any]]): Optional properties associated with
377
- the edge. (default None)
378
-
379
- Returns:
380
- --------
381
-
382
- The result of the edge addition operation, including relationship details.
383
- """
384
-
385
- exists = await asyncio.gather(self.has_node(str(from_node)), self.has_node(str(to_node)))
386
-
387
- if not all(exists):
388
- return None
389
-
390
- serialized_properties = self.serialize_properties(edge_properties or {})
391
-
392
- query = dedent(
393
- f"""\
394
- MATCH (from_node {{id: $from_node}}),
395
- (to_node {{id: $to_node}})
396
- WHERE from_node IS NOT NULL AND to_node IS NOT NULL
397
- MERGE (from_node)-[r:{relationship_name}]->(to_node)
398
- ON CREATE SET r += $properties, r.updated_at = timestamp()
399
- ON MATCH SET r += $properties, r.updated_at = timestamp()
400
- RETURN r
401
- """
402
- )
403
-
404
- params = {
405
- "from_node": str(from_node),
406
- "to_node": str(to_node),
407
- "relationship_name": relationship_name,
408
- "properties": serialized_properties,
409
- }
410
-
411
- return await self.query(query, params)
412
-
413
- async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
414
- """
415
- Batch add multiple edges between nodes, enforcing specified relationships.
416
-
417
- Parameters:
418
- -----------
419
-
420
- - edges (list[tuple[str, str, str, dict[str, Any]]): A list of tuples containing
421
- specifications for each edge to add.
422
-
423
- Returns:
424
- --------
425
-
426
- - None: None.
427
- """
428
- query = """
429
- UNWIND $edges AS edge
430
- MATCH (from_node {id: edge.from_node})
431
- MATCH (to_node {id: edge.to_node})
432
- CALL merge.relationship(
433
- from_node,
434
- edge.relationship_name,
435
- {
436
- source_node_id: edge.from_node,
437
- target_node_id: edge.to_node
438
- },
439
- edge.properties,
440
- to_node,
441
- {}
442
- ) YIELD rel
443
- RETURN rel"""
444
-
445
- edges = [
446
- {
447
- "from_node": str(edge[0]),
448
- "to_node": str(edge[1]),
449
- "relationship_name": edge[2],
450
- "properties": {
451
- **(edge[3] if edge[3] else {}),
452
- "source_node_id": str(edge[0]),
453
- "target_node_id": str(edge[1]),
454
- },
455
- }
456
- for edge in edges
457
- ]
458
-
459
- try:
460
- results = await self.query(query, dict(edges=edges))
461
- return results
462
- except Neo4jError as error:
463
- logger.error("Memgraph query error: %s", error, exc_info=True)
464
- raise error
465
-
466
- async def get_edges(self, node_id: str):
467
- """
468
- Retrieve all edges connected to a specific node identified by its ID.
469
-
470
- Parameters:
471
- -----------
472
-
473
- - node_id (str): The ID of the node for which to retrieve connected edges.
474
-
475
- Returns:
476
- --------
477
-
478
- A list of tuples representing the edges connected to the node.
479
- """
480
- query = """
481
- MATCH (n {id: $node_id})-[r]-(m)
482
- RETURN n, r, m
483
- """
484
-
485
- results = await self.query(query, dict(node_id=node_id))
486
-
487
- return [
488
- (result["n"]["id"], result["m"]["id"], {"relationship_name": result["r"][1]})
489
- for result in results
490
- ]
491
-
492
- async def get_disconnected_nodes(self) -> list[str]:
493
- """
494
- Identify nodes in the graph that do not belong to the largest connected component.
495
-
496
- Returns:
497
- --------
498
-
499
- - list[str]: A list of IDs representing the disconnected nodes.
500
- """
501
- query = """
502
- // Step 1: Collect all nodes
503
- MATCH (n)
504
- WITH COLLECT(n) AS nodes
505
-
506
- // Step 2: Find all connected components
507
- WITH nodes
508
- CALL {
509
- WITH nodes
510
- UNWIND nodes AS startNode
511
- MATCH path = (startNode)-[*]-(connectedNode)
512
- WITH startNode, COLLECT(DISTINCT connectedNode) AS component
513
- RETURN component
514
- }
515
-
516
- // Step 3: Aggregate components
517
- WITH COLLECT(component) AS components
518
-
519
- // Step 4: Identify the largest connected component
520
- UNWIND components AS component
521
- WITH component
522
- ORDER BY SIZE(component) DESC
523
- LIMIT 1
524
- WITH component AS largestComponent
525
-
526
- // Step 5: Find nodes not in the largest connected component
527
- MATCH (n)
528
- WHERE NOT n IN largestComponent
529
- RETURN COLLECT(ID(n)) AS ids
530
- """
531
-
532
- results = await self.query(query)
533
- return results[0]["ids"] if len(results) > 0 else []
534
-
535
- async def get_predecessors(self, node_id: str, edge_label: str = None) -> list[str]:
536
- """
537
- Retrieve all predecessors of a node based on its ID and optional edge label.
538
-
539
- Parameters:
540
- -----------
541
-
542
- - node_id (str): The ID of the node to find predecessors for.
543
- - edge_label (str): Optional edge label to filter predecessors. (default None)
544
-
545
- Returns:
546
- --------
547
-
548
- - list[str]: A list of predecessor node IDs.
549
- """
550
- if edge_label is not None:
551
- query = """
552
- MATCH (node)<-[r]-(predecessor)
553
- WHERE node.id = $node_id AND type(r) = $edge_label
554
- RETURN predecessor
555
- """
556
-
557
- results = await self.query(
558
- query,
559
- dict(
560
- node_id=node_id,
561
- edge_label=edge_label,
562
- ),
563
- )
564
-
565
- return [result["predecessor"] for result in results]
566
- else:
567
- query = """
568
- MATCH (node)<-[r]-(predecessor)
569
- WHERE node.id = $node_id
570
- RETURN predecessor
571
- """
572
-
573
- results = await self.query(
574
- query,
575
- dict(
576
- node_id=node_id,
577
- ),
578
- )
579
-
580
- return [result["predecessor"] for result in results]
581
-
582
- async def get_successors(self, node_id: str, edge_label: str = None) -> list[str]:
583
- """
584
- Retrieve all successors of a node based on its ID and optional edge label.
585
-
586
- Parameters:
587
- -----------
588
-
589
- - node_id (str): The ID of the node to find successors for.
590
- - edge_label (str): Optional edge label to filter successors. (default None)
591
-
592
- Returns:
593
- --------
594
-
595
- - list[str]: A list of successor node IDs.
596
- """
597
- if edge_label is not None:
598
- query = """
599
- MATCH (node)-[r]->(successor)
600
- WHERE node.id = $node_id AND type(r) = $edge_label
601
- RETURN successor
602
- """
603
-
604
- results = await self.query(
605
- query,
606
- dict(
607
- node_id=node_id,
608
- edge_label=edge_label,
609
- ),
610
- )
611
-
612
- return [result["successor"] for result in results]
613
- else:
614
- query = """
615
- MATCH (node)-[r]->(successor)
616
- WHERE node.id = $node_id
617
- RETURN successor
618
- """
619
-
620
- results = await self.query(
621
- query,
622
- dict(
623
- node_id=node_id,
624
- ),
625
- )
626
-
627
- return [result["successor"] for result in results]
628
-
629
- async def get_neighbors(self, node_id: str) -> List[Dict[str, Any]]:
630
- """
631
- Get both predecessors and successors of a node.
632
-
633
- Parameters:
634
- -----------
635
-
636
- - node_id (str): The ID of the node to find neighbors for.
637
-
638
- Returns:
639
- --------
640
-
641
- - List[Dict[str, Any]]: A combined list of neighbor node IDs.
642
- """
643
- predecessors, successors = await asyncio.gather(
644
- self.get_predecessors(node_id), self.get_successors(node_id)
645
- )
646
-
647
- return predecessors + successors
648
-
649
- async def get_node(self, node_id: str) -> Optional[Dict[str, Any]]:
650
- """Get a single node by ID."""
651
- query = """
652
- MATCH (node {id: $node_id})
653
- RETURN node
654
- """
655
- results = await self.query(query, {"node_id": node_id})
656
- return results[0]["node"] if results else None
657
-
658
- async def get_nodes(self, node_ids: List[str]) -> List[Dict[str, Any]]:
659
- """Get multiple nodes by their IDs."""
660
- query = """
661
- UNWIND $node_ids AS id
662
- MATCH (node {id: id})
663
- RETURN node
664
- """
665
- results = await self.query(query, {"node_ids": node_ids})
666
- return [result["node"] for result in results]
667
-
668
- async def get_connections(self, node_id: UUID) -> list:
669
- """
670
- Retrieve connections for a given node, including both predecessors and successors.
671
-
672
- Parameters:
673
- -----------
674
-
675
- - node_id (UUID): The ID of the node for which to retrieve connections.
676
-
677
- Returns:
678
- --------
679
-
680
- - list: A list of connections associated with the node.
681
- """
682
- predecessors_query = """
683
- MATCH (node)<-[relation]-(neighbour)
684
- WHERE node.id = $node_id
685
- RETURN neighbour, relation, node
686
- """
687
- successors_query = """
688
- MATCH (node)-[relation]->(neighbour)
689
- WHERE node.id = $node_id
690
- RETURN node, relation, neighbour
691
- """
692
-
693
- predecessors, successors = await asyncio.gather(
694
- self.query(predecessors_query, dict(node_id=str(node_id))),
695
- self.query(successors_query, dict(node_id=str(node_id))),
696
- )
697
-
698
- connections = []
699
-
700
- for neighbour in predecessors:
701
- neighbour = neighbour["relation"]
702
- connections.append((neighbour[0], {"relationship_name": neighbour[1]}, neighbour[2]))
703
-
704
- for neighbour in successors:
705
- neighbour = neighbour["relation"]
706
- connections.append((neighbour[0], {"relationship_name": neighbour[1]}, neighbour[2]))
707
-
708
- return connections
709
-
710
- async def remove_connection_to_predecessors_of(
711
- self, node_ids: list[str], edge_label: str
712
- ) -> None:
713
- """
714
- Remove specified connections to the predecessors of the given node IDs.
715
-
716
- Parameters:
717
- -----------
718
-
719
- - node_ids (list[str]): A list of node IDs from which to remove predecessor
720
- connections.
721
- - edge_label (str): The label of the edges to remove.
722
-
723
- Returns:
724
- --------
725
-
726
- - None: None.
727
- """
728
- query = f"""
729
- UNWIND $node_ids AS nid
730
- MATCH (node {id: nid})-[r]->(predecessor)
731
- WHERE type(r) = $edge_label
732
- DELETE r;
733
- """
734
-
735
- params = {"node_ids": node_ids, "edge_label": edge_label}
736
-
737
- return await self.query(query, params)
738
-
739
- async def remove_connection_to_successors_of(
740
- self, node_ids: list[str], edge_label: str
741
- ) -> None:
742
- """
743
- Remove specified connections to the successors of the given node IDs.
744
-
745
- Parameters:
746
- -----------
747
-
748
- - node_ids (list[str]): A list of node IDs from which to remove successor
749
- connections.
750
- - edge_label (str): The label of the edges to remove.
751
-
752
- Returns:
753
- --------
754
-
755
- - None: None.
756
- """
757
- query = f"""
758
- UNWIND $node_ids AS id
759
- MATCH (node:`{id}`)<-[r:{edge_label}]-(successor)
760
- DELETE r;
761
- """
762
-
763
- params = {"node_ids": node_ids}
764
-
765
- return await self.query(query, params)
766
-
767
- async def delete_graph(self):
768
- """
769
- Completely delete the graph from the database, removing all nodes and edges.
770
-
771
- Returns:
772
- --------
773
-
774
- None.
775
- """
776
- query = """MATCH (node)
777
- DETACH DELETE node;"""
778
-
779
- return await self.query(query)
780
-
781
- def serialize_properties(self, properties=dict()):
782
- """
783
- Convert property values to a suitable representation for storage.
784
-
785
- Parameters:
786
- -----------
787
-
788
- - properties: A dictionary of properties to serialize. (default dict())
789
-
790
- Returns:
791
- --------
792
-
793
- A dictionary of serialized properties.
794
- """
795
- serialized_properties = {}
796
-
797
- for property_key, property_value in properties.items():
798
- if isinstance(property_value, UUID):
799
- serialized_properties[property_key] = str(property_value)
800
- continue
801
-
802
- if isinstance(property_value, dict):
803
- serialized_properties[property_key] = json.dumps(property_value, cls=JSONEncoder)
804
- continue
805
-
806
- serialized_properties[property_key] = property_value
807
-
808
- return serialized_properties
809
-
810
- async def get_model_independent_graph_data(self):
811
- """
812
- Fetch nodes and relationships without any specific model filtering.
813
-
814
- Returns:
815
- --------
816
-
817
- A tuple containing nodes and edges as collections.
818
- """
819
- query_nodes = "MATCH (n) RETURN collect(n) AS nodes"
820
- nodes = await self.query(query_nodes)
821
-
822
- query_edges = "MATCH (n)-[r]->(m) RETURN collect([n, r, m]) AS elements"
823
- edges = await self.query(query_edges)
824
-
825
- return (nodes, edges)
826
-
827
- async def get_graph_data(self):
828
- """
829
- Retrieve all nodes and edges from the graph, including their properties.
830
-
831
- Returns:
832
- --------
833
-
834
- A tuple containing lists of nodes and edges.
835
- """
836
- query = "MATCH (n) RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties"
837
-
838
- result = await self.query(query)
839
-
840
- nodes = [
841
- (
842
- record["id"],
843
- record["properties"],
844
- )
845
- for record in result
846
- ]
847
-
848
- query = """
849
- MATCH (n)-[r]->(m)
850
- RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
851
- """
852
- result = await self.query(query)
853
- edges = [
854
- (
855
- record["source"],
856
- record["target"],
857
- record["type"],
858
- record["properties"],
859
- )
860
- for record in result
861
- ]
862
-
863
- return (nodes, edges)
864
-
865
- async def get_nodeset_subgraph(
866
- self, node_type: Type[Any], node_name: List[str]
867
- ) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
868
- """
869
- Throw an error indicating that node set filtering is not supported.
870
-
871
- Parameters:
872
- -----------
873
-
874
- - node_type (Type[Any]): The type of nodes to filter.
875
- - node_name (List[str]): A list of node names to filter.
876
- """
877
- raise NodesetFilterNotSupportedError
878
-
879
- async def get_filtered_graph_data(self, attribute_filters):
880
- """
881
- Fetch nodes and relationships based on specified attribute filters.
882
-
883
- Parameters:
884
- -----------
885
-
886
- - attribute_filters: A list of criteria to filter nodes and relationships.
887
-
888
- Returns:
889
- --------
890
-
891
- A tuple containing filtered nodes and edges.
892
- """
893
- where_clauses = []
894
- for attribute, values in attribute_filters[0].items():
895
- values_str = ", ".join(
896
- f"'{value}'" if isinstance(value, str) else str(value) for value in values
897
- )
898
- where_clauses.append(f"n.{attribute} IN [{values_str}]")
899
-
900
- where_clause = " AND ".join(where_clauses)
901
-
902
- query_nodes = f"""
903
- MATCH (n)
904
- WHERE {where_clause}
905
- RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties
906
- """
907
- result_nodes = await self.query(query_nodes)
908
-
909
- nodes = [
910
- (
911
- record["id"],
912
- record["properties"],
913
- )
914
- for record in result_nodes
915
- ]
916
-
917
- query_edges = f"""
918
- MATCH (n)-[r]->(m)
919
- WHERE {where_clause} AND {where_clause.replace("n.", "m.")}
920
- RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
921
- """
922
- result_edges = await self.query(query_edges)
923
-
924
- edges = [
925
- (
926
- record["source"],
927
- record["target"],
928
- record["type"],
929
- record["properties"],
930
- )
931
- for record in result_edges
932
- ]
933
-
934
- return (nodes, edges)
935
-
936
- async def get_node_labels_string(self):
937
- """
938
- Retrieve a string representation of all unique node labels in the graph.
939
-
940
- Returns:
941
- --------
942
-
943
- A string containing unique node labels.
944
- """
945
- node_labels_query = """
946
- MATCH (n)
947
- WITH DISTINCT labels(n) AS labelList
948
- UNWIND labelList AS label
949
- RETURN collect(DISTINCT label) AS labels;
950
- """
951
- node_labels_result = await self.query(node_labels_query)
952
- node_labels = node_labels_result[0]["labels"] if node_labels_result else []
953
-
954
- if not node_labels:
955
- raise ValueError("No node labels found in the database")
956
-
957
- node_labels_str = "[" + ", ".join(f"'{label}'" for label in node_labels) + "]"
958
- return node_labels_str
959
-
960
- async def get_relationship_labels_string(self):
961
- """
962
- Retrieve a string representation of all unique relationship types in the graph.
963
-
964
- Returns:
965
- --------
966
-
967
- A string containing unique relationship types.
968
- """
969
- relationship_types_query = (
970
- "MATCH ()-[r]->() RETURN collect(DISTINCT type(r)) AS relationships;"
971
- )
972
- relationship_types_result = await self.query(relationship_types_query)
973
- relationship_types = (
974
- relationship_types_result[0]["relationships"] if relationship_types_result else []
975
- )
976
-
977
- if not relationship_types:
978
- raise ValueError("No relationship types found in the database.")
979
-
980
- relationship_types_undirected_str = (
981
- "{"
982
- + ", ".join(f"{rel}" + ": {orientation: 'UNDIRECTED'}" for rel in relationship_types)
983
- + "}"
984
- )
985
- return relationship_types_undirected_str
986
-
987
- async def get_graph_metrics(self, include_optional=False):
988
- """
989
- Calculate and return various metrics of the graph, including mandatory and optional
990
- metrics.
991
-
992
- Parameters:
993
- -----------
994
-
995
- - include_optional: Specify whether to include optional metrics in the results.
996
- (default False)
997
-
998
- Returns:
999
- --------
1000
-
1001
- A dictionary containing calculated graph metrics.
1002
- """
1003
-
1004
- try:
1005
- # Basic metrics
1006
- node_count = await self.query("MATCH (n) RETURN count(n)")
1007
- edge_count = await self.query("MATCH ()-[r]->() RETURN count(r)")
1008
- num_nodes = node_count[0][0] if node_count else 0
1009
- num_edges = edge_count[0][0] if edge_count else 0
1010
-
1011
- # Calculate mandatory metrics
1012
- mandatory_metrics = {
1013
- "num_nodes": num_nodes,
1014
- "num_edges": num_edges,
1015
- "mean_degree": (2 * num_edges) / num_nodes if num_nodes > 0 else 0,
1016
- "edge_density": (num_edges) / (num_nodes * (num_nodes - 1)) if num_nodes > 1 else 0,
1017
- }
1018
-
1019
- # Calculate connected components
1020
- components_query = """
1021
- MATCH (n:Node)
1022
- WITH n.id AS node_id
1023
- MATCH path = (n)-[:EDGE*0..]-()
1024
- WITH COLLECT(DISTINCT node_id) AS component
1025
- RETURN COLLECT(component) AS components
1026
- """
1027
- components_result = await self.query(components_query)
1028
- component_sizes = (
1029
- [len(comp) for comp in components_result[0][0]] if components_result else []
1030
- )
1031
-
1032
- mandatory_metrics.update(
1033
- {
1034
- "num_connected_components": len(component_sizes),
1035
- "sizes_of_connected_components": component_sizes,
1036
- }
1037
- )
1038
-
1039
- if include_optional:
1040
- # Self-loops
1041
- self_loops_query = """
1042
- MATCH (n:Node)-[r:EDGE]->(n)
1043
- RETURN COUNT(r)
1044
- """
1045
- self_loops = await self.query(self_loops_query)
1046
- num_selfloops = self_loops[0][0] if self_loops else 0
1047
-
1048
- # Shortest paths (simplified for Kuzu)
1049
- paths_query = """
1050
- MATCH (n:Node), (m:Node)
1051
- WHERE n.id < m.id
1052
- MATCH path = (n)-[:EDGE*]-(m)
1053
- RETURN MIN(LENGTH(path)) AS length
1054
- """
1055
- paths = await self.query(paths_query)
1056
- path_lengths = [p[0] for p in paths if p[0] is not None]
1057
-
1058
- # Local clustering coefficient
1059
- clustering_query = """
1060
- /// Step 1: Get each node with its neighbors and degree
1061
- MATCH (n:Node)-[:EDGE]-(neighbor)
1062
- WITH n, COLLECT(DISTINCT neighbor) AS neighbors, COUNT(DISTINCT neighbor) AS degree
1063
-
1064
- // Step 2: Pair up neighbors and check if they are connected
1065
- UNWIND neighbors AS n1
1066
- UNWIND neighbors AS n2
1067
- WITH n, degree, n1, n2
1068
- WHERE id(n1) < id(n2) // avoid duplicate pairs
1069
-
1070
- // Step 3: Use OPTIONAL MATCH to see if n1 and n2 are connected
1071
- OPTIONAL MATCH (n1)-[:EDGE]-(n2)
1072
- WITH n, degree, COUNT(n2) AS triangle_count
1073
-
1074
- // Step 4: Compute local clustering coefficient
1075
- WITH n, degree,
1076
- CASE WHEN degree <= 1 THEN 0.0
1077
- ELSE (1.0 * triangle_count) / (degree * (degree - 1) / 2.0)
1078
- END AS local_cc
1079
-
1080
- // Step 5: Compute average
1081
- RETURN AVG(local_cc) AS avg_clustering_coefficient
1082
- """
1083
- clustering = await self.query(clustering_query)
1084
-
1085
- optional_metrics = {
1086
- "num_selfloops": num_selfloops,
1087
- "diameter": max(path_lengths) if path_lengths else -1,
1088
- "avg_shortest_path_length": sum(path_lengths) / len(path_lengths)
1089
- if path_lengths
1090
- else -1,
1091
- "avg_clustering": clustering[0][0] if clustering and clustering[0][0] else -1,
1092
- }
1093
- else:
1094
- optional_metrics = {
1095
- "num_selfloops": -1,
1096
- "diameter": -1,
1097
- "avg_shortest_path_length": -1,
1098
- "avg_clustering": -1,
1099
- }
1100
-
1101
- return {**mandatory_metrics, **optional_metrics}
1102
-
1103
- except Exception as e:
1104
- logger.error(f"Failed to get graph metrics: {e}")
1105
- return {
1106
- "num_nodes": 0,
1107
- "num_edges": 0,
1108
- "mean_degree": 0,
1109
- "edge_density": 0,
1110
- "num_connected_components": 0,
1111
- "sizes_of_connected_components": [],
1112
- "num_selfloops": -1,
1113
- "diameter": -1,
1114
- "avg_shortest_path_length": -1,
1115
- "avg_clustering": -1,
1116
- }
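The removed MemgraphAdapter above was essentially a thin asynchronous wrapper around the official neo4j Python driver (Memgraph accepts Bolt connections), so projects that depended on it can reproduce the same connect-and-query pattern directly. A minimal sketch, assuming a local Memgraph instance is reachable at bolt://localhost:7687 and the neo4j driver is installed; the URL, credentials, and query are placeholders, not part of this release:

    # memgraph_query.py -- minimal sketch of the pattern used by the removed adapter.
    import asyncio
    from neo4j import AsyncGraphDatabase

    async def main():
        # Auth is optional for a default local Memgraph instance, mirroring the
        # removed adapter's "only use auth if both username and password are set".
        driver = AsyncGraphDatabase.driver("bolt://localhost:7687", auth=None)
        try:
            async with driver.session() as session:
                result = await session.run(
                    "MATCH (n) WHERE n.id = $node_id RETURN COUNT(n) > 0 AS node_exists",
                    {"node_id": "example-id"},
                )
                records = await result.data()
                print(records[0]["node_exists"] if records else False)
        finally:
            await driver.close()

    if __name__ == "__main__":
        asyncio.run(main())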