cognee 0.5.0.dev0__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. cognee/api/client.py +1 -5
  2. cognee/api/v1/add/add.py +2 -1
  3. cognee/api/v1/cognify/cognify.py +24 -16
  4. cognee/api/v1/cognify/routers/__init__.py +0 -1
  5. cognee/api/v1/cognify/routers/get_cognify_router.py +3 -1
  6. cognee/api/v1/datasets/routers/get_datasets_router.py +3 -3
  7. cognee/api/v1/ontologies/ontologies.py +12 -37
  8. cognee/api/v1/ontologies/routers/get_ontology_router.py +27 -25
  9. cognee/api/v1/search/search.py +8 -0
  10. cognee/api/v1/ui/node_setup.py +360 -0
  11. cognee/api/v1/ui/npm_utils.py +50 -0
  12. cognee/api/v1/ui/ui.py +38 -68
  13. cognee/context_global_variables.py +61 -16
  14. cognee/eval_framework/Dockerfile +29 -0
  15. cognee/eval_framework/answer_generation/answer_generation_executor.py +10 -0
  16. cognee/eval_framework/answer_generation/run_question_answering_module.py +1 -1
  17. cognee/eval_framework/corpus_builder/task_getters/get_cascade_graph_tasks.py +0 -2
  18. cognee/eval_framework/corpus_builder/task_getters/get_default_tasks_by_indices.py +4 -4
  19. cognee/eval_framework/eval_config.py +2 -2
  20. cognee/eval_framework/modal_run_eval.py +16 -28
  21. cognee/infrastructure/databases/dataset_database_handler/__init__.py +3 -0
  22. cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py +80 -0
  23. cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +18 -0
  24. cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py +10 -0
  25. cognee/infrastructure/databases/graph/config.py +3 -0
  26. cognee/infrastructure/databases/graph/get_graph_engine.py +1 -0
  27. cognee/infrastructure/databases/graph/graph_db_interface.py +15 -0
  28. cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py +81 -0
  29. cognee/infrastructure/databases/graph/kuzu/adapter.py +228 -0
  30. cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +168 -0
  31. cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +80 -1
  32. cognee/infrastructure/databases/utils/__init__.py +3 -0
  33. cognee/infrastructure/databases/utils/get_graph_dataset_database_handler.py +10 -0
  34. cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +62 -48
  35. cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py +10 -0
  36. cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py +30 -0
  37. cognee/infrastructure/databases/vector/config.py +2 -0
  38. cognee/infrastructure/databases/vector/create_vector_engine.py +1 -0
  39. cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +8 -6
  40. cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +9 -7
  41. cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +11 -10
  42. cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +2 -0
  43. cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py +50 -0
  44. cognee/infrastructure/databases/vector/vector_db_interface.py +35 -0
  45. cognee/infrastructure/files/storage/s3_config.py +2 -0
  46. cognee/infrastructure/llm/LLMGateway.py +5 -2
  47. cognee/infrastructure/llm/config.py +35 -0
  48. cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py +2 -2
  49. cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py +23 -8
  50. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -16
  51. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py +5 -0
  52. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py +153 -0
  53. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +40 -37
  54. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +39 -36
  55. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +19 -1
  56. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +11 -9
  57. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +23 -21
  58. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +42 -34
  59. cognee/memify_pipelines/create_triplet_embeddings.py +53 -0
  60. cognee/modules/cognify/config.py +2 -0
  61. cognee/modules/data/deletion/prune_system.py +52 -2
  62. cognee/modules/data/methods/delete_dataset.py +26 -0
  63. cognee/modules/engine/models/Triplet.py +9 -0
  64. cognee/modules/engine/models/__init__.py +1 -0
  65. cognee/modules/graph/cognee_graph/CogneeGraph.py +85 -37
  66. cognee/modules/graph/cognee_graph/CogneeGraphElements.py +8 -3
  67. cognee/modules/memify/memify.py +1 -7
  68. cognee/modules/pipelines/operations/pipeline.py +18 -2
  69. cognee/modules/retrieval/__init__.py +1 -1
  70. cognee/modules/retrieval/graph_completion_context_extension_retriever.py +4 -0
  71. cognee/modules/retrieval/graph_completion_cot_retriever.py +4 -0
  72. cognee/modules/retrieval/graph_completion_retriever.py +10 -0
  73. cognee/modules/retrieval/graph_summary_completion_retriever.py +4 -0
  74. cognee/modules/retrieval/register_retriever.py +10 -0
  75. cognee/modules/retrieval/registered_community_retrievers.py +1 -0
  76. cognee/modules/retrieval/temporal_retriever.py +4 -0
  77. cognee/modules/retrieval/triplet_retriever.py +182 -0
  78. cognee/modules/retrieval/utils/brute_force_triplet_search.py +42 -10
  79. cognee/modules/run_custom_pipeline/run_custom_pipeline.py +8 -1
  80. cognee/modules/search/methods/get_search_type_tools.py +54 -8
  81. cognee/modules/search/methods/no_access_control_search.py +4 -0
  82. cognee/modules/search/methods/search.py +46 -18
  83. cognee/modules/search/types/SearchType.py +1 -1
  84. cognee/modules/settings/get_settings.py +19 -0
  85. cognee/modules/users/methods/get_authenticated_user.py +2 -2
  86. cognee/modules/users/models/DatasetDatabase.py +15 -3
  87. cognee/shared/logging_utils.py +4 -0
  88. cognee/shared/rate_limiting.py +30 -0
  89. cognee/tasks/documents/__init__.py +0 -1
  90. cognee/tasks/graph/extract_graph_from_data.py +9 -10
  91. cognee/tasks/memify/get_triplet_datapoints.py +289 -0
  92. cognee/tasks/storage/add_data_points.py +142 -2
  93. cognee/tests/integration/retrieval/test_triplet_retriever.py +84 -0
  94. cognee/tests/integration/tasks/test_add_data_points.py +139 -0
  95. cognee/tests/integration/tasks/test_get_triplet_datapoints.py +69 -0
  96. cognee/tests/test_cognee_server_start.py +2 -4
  97. cognee/tests/test_conversation_history.py +23 -1
  98. cognee/tests/test_dataset_database_handler.py +137 -0
  99. cognee/tests/test_dataset_delete.py +76 -0
  100. cognee/tests/test_edge_centered_payload.py +170 -0
  101. cognee/tests/test_pipeline_cache.py +164 -0
  102. cognee/tests/test_search_db.py +37 -1
  103. cognee/tests/unit/api/test_ontology_endpoint.py +77 -89
  104. cognee/tests/unit/infrastructure/llm/test_llm_config.py +46 -0
  105. cognee/tests/unit/infrastructure/mock_embedding_engine.py +3 -7
  106. cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py +0 -5
  107. cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +2 -2
  108. cognee/tests/unit/modules/graph/cognee_graph_test.py +406 -0
  109. cognee/tests/unit/modules/memify_tasks/test_get_triplet_datapoints.py +214 -0
  110. cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +608 -0
  111. cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +83 -0
  112. cognee/tests/unit/modules/search/test_search.py +100 -0
  113. cognee/tests/unit/tasks/storage/test_add_data_points.py +288 -0
  114. {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/METADATA +76 -89
  115. {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/RECORD +119 -97
  116. {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/WHEEL +1 -1
  117. cognee/api/v1/cognify/code_graph_pipeline.py +0 -119
  118. cognee/api/v1/cognify/routers/get_code_pipeline_router.py +0 -90
  119. cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +0 -544
  120. cognee/modules/retrieval/code_retriever.py +0 -232
  121. cognee/tasks/code/enrich_dependency_graph_checker.py +0 -35
  122. cognee/tasks/code/get_local_dependencies_checker.py +0 -20
  123. cognee/tasks/code/get_repo_dependency_graph_checker.py +0 -35
  124. cognee/tasks/documents/check_permissions_on_dataset.py +0 -26
  125. cognee/tasks/repo_processor/__init__.py +0 -2
  126. cognee/tasks/repo_processor/get_local_dependencies.py +0 -335
  127. cognee/tasks/repo_processor/get_non_code_files.py +0 -158
  128. cognee/tasks/repo_processor/get_repo_file_dependencies.py +0 -243
  129. cognee/tests/test_delete_bmw_example.py +0 -60
  130. {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/entry_points.txt +0 -0
  131. {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/licenses/LICENSE +0 -0
  132. {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/licenses/NOTICE.md +0 -0
@@ -12,6 +12,7 @@ from contextlib import asynccontextmanager
12
12
  from concurrent.futures import ThreadPoolExecutor
13
13
  from typing import Dict, Any, List, Union, Optional, Tuple, Type
14
14
 
15
+ from cognee.exceptions import CogneeValidationError
15
16
  from cognee.shared.logging_utils import get_logger
16
17
  from cognee.infrastructure.utils.run_sync import run_sync
17
18
  from cognee.infrastructure.files.storage import get_file_storage
@@ -1186,6 +1187,11 @@ class KuzuAdapter(GraphDBInterface):
1186
1187
  A tuple with two elements: a list of tuples of (node_id, properties) and a list of
1187
1188
  tuples of (source_id, target_id, relationship_name, properties).
1188
1189
  """
1190
+
1191
+ import time
1192
+
1193
+ start_time = time.time()
1194
+
1189
1195
  try:
1190
1196
  nodes_query = """
1191
1197
  MATCH (n:Node)
@@ -1249,6 +1255,11 @@ class KuzuAdapter(GraphDBInterface):
1249
1255
  },
1250
1256
  )
1251
1257
  )
1258
+
1259
+ retrieval_time = time.time() - start_time
1260
+ logger.info(
1261
+ f"Retrieved {len(nodes)} nodes and {len(edges)} edges in {retrieval_time:.2f} seconds"
1262
+ )
1252
1263
  return formatted_nodes, formatted_edges
1253
1264
  except Exception as e:
1254
1265
  logger.error(f"Failed to get graph data: {e}")
@@ -1417,6 +1428,92 @@ class KuzuAdapter(GraphDBInterface):
1417
1428
  formatted_edges.append((source_id, target_id, rel_type, props))
1418
1429
  return formatted_nodes, formatted_edges
1419
1430
 
1431
async def get_id_filtered_graph_data(self, target_ids: list[str]):
    """
    Retrieve the edges that touch the given node IDs, together with their endpoints.

    An edge is returned when its source or its target node has an ``id`` listed in
    ``target_ids``; the returned nodes are the endpoints of those edges, i.e. the
    matched target nodes plus their direct neighbours.

    Parameters:
    -----------
    - target_ids (list[str]): Node IDs to filter on; every element must be a ``str``.

    Returns:
    --------
    A ``(nodes, edges)`` tuple:
    - nodes: list of ``(node_id, properties)`` tuples, one per distinct node.
      ``properties`` holds the node's ``name`` and ``type`` merged with the keys
      decoded from its JSON ``properties`` column. If that column cannot be decoded
      into a JSON object it is left in place unchanged and a warning is logged.
    - edges: list of ``(source_id, target_id, relationship_name, properties)``
      tuples, one per matched edge. ``source_id`` / ``target_id`` come from the
      decoded edge properties' ``source_node_id`` / ``target_node_id`` when present,
      otherwise from the matched endpoints. ``properties`` is ``{}`` when the edge
      properties are missing or cannot be decoded into a JSON object.
    Both lists are empty when ``target_ids`` is empty or the query returns nothing.

    Raises:
    -------
    - CogneeValidationError: If any element of ``target_ids`` is not a ``str``.
    - Exception: Any error raised while querying is logged and re-raised.
    """
    import time

    def _merge_node_properties(node_id, node_props: dict) -> dict:
        # Flatten the JSON-encoded "properties" column into the node dict (in place).
        raw_properties = node_props.get("properties")
        if not raw_properties:
            return node_props
        try:
            decoded = json.loads(raw_properties)
        except (json.JSONDecodeError, TypeError):
            logger.warning(f"Failed to parse properties JSON for node {node_id}")
            return node_props
        if not isinstance(decoded, dict):
            logger.warning(f"Properties JSON for node {node_id} is not an object; keeping it raw")
            return node_props
        # Drop the raw column *before* merging so that a decoded key which is itself
        # named "properties" survives instead of being deleted afterwards.
        del node_props["properties"]
        node_props.update(decoded)
        return node_props

    def _decode_edge_properties(raw_properties, source_id, target_id) -> dict:
        # Decode the edge's JSON properties; fall back to {} so the .get() lookups are safe.
        if not raw_properties:
            return {}
        try:
            decoded = json.loads(raw_properties)
        except (json.JSONDecodeError, TypeError):
            logger.warning(f"Failed to parse edge properties for {source_id}->{target_id}")
            return {}
        if not isinstance(decoded, dict):
            logger.warning(f"Failed to parse edge properties for {source_id}->{target_id}")
            return {}
        return decoded

    # perf_counter is monotonic, so the logged duration is not skewed by wall-clock changes.
    start_time = time.perf_counter()

    try:
        if not target_ids:
            logger.warning("No target IDs provided for ID-filtered graph retrieval.")
            return [], []

        if not all(isinstance(x, str) for x in target_ids):
            raise CogneeValidationError("target_ids must be a list of strings")

        query = """
        MATCH (n:Node)-[r]->(m:Node)
        WHERE n.id IN $target_ids OR m.id IN $target_ids
        RETURN n.id, {
            name: n.name,
            type: n.type,
            properties: n.properties
        }, m.id, {
            name: m.name,
            type: m.type,
            properties: m.properties
        }, r.relationship_name, r.properties
        """

        result = await self.query(query, {"target_ids": target_ids})

        if not result:
            logger.info("No data returned for the supplied IDs")
            return [], []

        nodes_dict = {}
        edges = []

        for n_id, n_props, m_id, m_props, r_type, r_props_raw in result:
            nodes_dict[n_id] = (n_id, _merge_node_properties(n_id, n_props))
            nodes_dict[m_id] = (m_id, _merge_node_properties(m_id, m_props))

            edge_props = _decode_edge_properties(r_props_raw, n_id, m_id)
            source_id = edge_props.get("source_node_id", n_id)
            target_id = edge_props.get("target_node_id", m_id)
            edges.append((source_id, target_id, r_type, edge_props))

        retrieval_time = time.perf_counter() - start_time
        logger.info(
            f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s"
        )

        return list(nodes_dict.values()), edges

    except Exception as e:
        logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}")
        raise
1516
+
1420
1517
  async def get_graph_metrics(self, include_optional=False) -> Dict[str, Any]:
1421
1518
  """
1422
1519
  Get metrics on graph structure and connectivity.
@@ -1908,3 +2005,134 @@ class KuzuAdapter(GraphDBInterface):
1908
2005
  time_ids_list = [item[0] for item in time_nodes]
1909
2006
 
1910
2007
  return ", ".join(f"'{uid}'" for uid in time_ids_list)
2008
+
2009
async def get_triplets_batch(self, offset: int, limit: int) -> list[dict[str, Any]]:
    """
    Retrieve a batch of triplets (start_node, relationship, end_node) from the graph.

    Parameters:
    -----------
    - offset (int): Number of triplets to skip before returning results.
    - limit (int): Maximum number of triplets to return.

    Returns:
    --------
    - list[dict[str, Any]]: One dict per valid triplet with keys 'start_node',
      'relationship_properties' and 'end_node'. Both node dicts are passed through
      ``self._parse_node_properties``. 'relationship_properties' holds the edge's
      decoded JSON properties plus a 'relationship_name' key (empty string when the
      edge has none). Malformed rows are skipped with a logged warning/error
      instead of failing the whole batch.

    Raises:
    -------
    - ValueError: If offset or limit are negative.
    - Exception: Re-raises any exceptions from query execution.
    """
    if offset < 0:
        raise ValueError(f"Offset must be non-negative, got {offset}")
    if limit < 0:
        raise ValueError(f"Limit must be non-negative, got {limit}")

    # NOTE(review): there is no ORDER BY, so SKIP/LIMIT paging assumes Kuzu returns the
    # MATCH rows in the same order on every call. Confirm this, or add an ORDER BY
    # (e.g. on start_node.id, end_node.id) if successive batches must be disjoint.
    query = """
    MATCH (start_node:Node)-[relationship:EDGE]->(end_node:Node)
    RETURN {
        start_node: {
            id: start_node.id,
            name: start_node.name,
            type: start_node.type,
            properties: start_node.properties
        },
        relationship_properties: {
            relationship_name: relationship.relationship_name,
            properties: relationship.properties
        },
        end_node: {
            id: end_node.id,
            name: end_node.name,
            type: end_node.type,
            properties: end_node.properties
        }
    } AS triplet
    SKIP $offset LIMIT $limit
    """

    try:
        results = await self.query(query, {"offset": offset, "limit": limit})
    except Exception as e:
        logger.error(f"Failed to execute triplet query: {str(e)}")
        logger.error(f"Query: {query}")
        logger.error(f"Parameters: offset={offset}, limit={limit}")
        raise

    def _invalid_reason(triplet: dict) -> Optional[str]:
        # Describe the first structural problem of a triplet dict, or return None if valid.
        for key in ("start_node", "relationship_properties", "end_node"):
            if key not in triplet:
                return f"missing '{key}' key"
            if not isinstance(triplet[key], dict):
                return f"'{key}' is not a dict"
        return None

    def _parse_relationship_properties(raw_rel_props: dict, idx: int) -> dict:
        # Return a copy of the relationship map with its JSON "properties" column
        # flattened in and "relationship_name" normalised to a string.
        rel_props = raw_rel_props.copy()
        relationship_name = rel_props.get("relationship_name") or ""

        if rel_props.get("properties"):
            try:
                parsed_props = json.loads(rel_props["properties"])
            except (json.JSONDecodeError, TypeError) as e:
                logger.warning(
                    f"Failed to parse relationship properties JSON for triplet at index {idx}: {e}"
                )
            else:
                if isinstance(parsed_props, dict):
                    # Remove the raw column before merging so a decoded key that is
                    # itself named "properties" is kept rather than deleted afterwards.
                    del rel_props["properties"]
                    rel_props.update(parsed_props)
                else:
                    logger.warning(
                        f"Parsed relationship properties is not a dict for triplet at index {idx}"
                    )

        # The column value always wins over a "relationship_name" decoded from JSON.
        rel_props["relationship_name"] = relationship_name
        return rel_props

    triplets = []
    # `or []` guards against self.query returning None for an empty result.
    for idx, row in enumerate(results or []):
        try:
            if not row:
                logger.warning(f"Skipping empty row at index {idx} in triplet batch")
                continue

            triplet = row[0]
            if not isinstance(triplet, dict):
                logger.warning(
                    f"Skipping invalid row at index {idx}: expected dict, got {type(triplet)}"
                )
                continue

            reason = _invalid_reason(triplet)
            if reason is not None:
                logger.warning(f"Skipping triplet at index {idx}: {reason}")
                continue

            triplet["start_node"] = self._parse_node_properties(triplet["start_node"].copy())
            triplet["relationship_properties"] = _parse_relationship_properties(
                triplet["relationship_properties"], idx
            )
            triplet["end_node"] = self._parse_node_properties(triplet["end_node"].copy())
            triplets.append(triplet)
        except Exception as e:
            # A single malformed row must not abort the whole batch.
            logger.error(f"Error processing triplet at index {idx}: {e}", exc_info=True)

    return triplets
@@ -0,0 +1,168 @@
1
+ import os
2
+ import asyncio
3
+ import requests
4
+ import base64
5
+ import hashlib
6
+ from uuid import UUID
7
+ from typing import Optional
8
+ from cryptography.fernet import Fernet
9
+
10
+ from cognee.infrastructure.databases.graph import get_graph_config
11
+ from cognee.modules.users.models import User, DatasetDatabase
12
+ from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
13
+
14
+
15
class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
    """
    Handler for a quick development PoC integration of Cognee multi-user and permission mode with Neo4j Aura databases.
    This handler creates a new Neo4j Aura instance for each Cognee dataset created.

    Improvements needed to be production ready:
    - Secret management for client credentials, currently secrets are encrypted and stored in the Cognee relational database,
      a secret manager or a similar system should be used instead.
    - ``delete_dataset`` is a no-op, so the Aura instance created for a dataset is not
      deprovisioned when the dataset is deleted.

    Quality of life improvements:
    - Allow configuration of different Neo4j Aura plans and regions.
    - Use a native async HTTP client (e.g. httpx); the blocking ``requests`` calls are
      currently run in worker threads via ``asyncio.to_thread`` so they do not block
      the event loop.
    """

    _AURA_API_BASE_URL = "https://api.neo4j.io"
    # Per-request timeout for Aura API calls so a stalled connection cannot hang forever.
    _HTTP_TIMEOUT_SECONDS = 60
    # Provisioning is polled every _POLL_INTERVAL_SECONDS, at most _POLL_ATTEMPTS times (~5 minutes).
    _POLL_ATTEMPTS = 30
    _POLL_INTERVAL_SECONDS = 10

    @staticmethod
    def _get_cipher() -> "Fernet":
        """
        Build the Fernet cipher used to encrypt/decrypt stored Aura instance passwords.

        The Fernet key is the URL-safe base64 encoding of the SHA-256 digest of
        ``NEO4J_ENCRYPTION_KEY``. When that variable is unset the placeholder
        ``"test_key"`` is still used (so previously stored passwords stay decryptable),
        but a warning is emitted because a publicly known key gives no real protection.
        """
        encryption_env_key = os.environ.get("NEO4J_ENCRYPTION_KEY")
        if encryption_env_key is None:
            import warnings

            warnings.warn(
                "NEO4J_ENCRYPTION_KEY is not set; Neo4j Aura passwords are encrypted "
                "with an insecure default key. Set NEO4J_ENCRYPTION_KEY for real use.",
                stacklevel=2,
            )
            encryption_env_key = "test_key"
        encryption_key = base64.urlsafe_b64encode(
            hashlib.sha256(encryption_env_key.encode()).digest()
        )
        return Fernet(encryption_key)

    @classmethod
    async def _wait_for_instance_running(
        cls, instance_id: str, instance_name: str, headers: dict
    ) -> None:
        """
        Poll the Aura API until the instance reports status "running".

        Raises:
            requests.HTTPError: If a status request returns an error status.
            TimeoutError: If the instance is not running after all poll attempts.
        """
        status_url = f"{cls._AURA_API_BASE_URL}/v1/instances/{instance_id}"
        status = ""
        for _ in range(cls._POLL_ATTEMPTS):
            status_resp = await asyncio.to_thread(
                requests.get, status_url, headers=headers, timeout=cls._HTTP_TIMEOUT_SECONDS
            )
            status_resp.raise_for_status()
            status = status_resp.json()["data"]["status"]
            if status.lower() == "running":
                return
            await asyncio.sleep(cls._POLL_INTERVAL_SECONDS)

        wait_minutes = cls._POLL_ATTEMPTS * cls._POLL_INTERVAL_SECONDS // 60
        raise TimeoutError(
            f"Neo4j instance '{instance_name}' did not become ready within {wait_minutes} minutes. Status: {status}"
        )

    @classmethod
    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
        """
        Create a new Neo4j Aura instance for the dataset. Return connection info that will be mapped to the dataset.

        Args:
            dataset_id: Dataset UUID; its string form (truncated to 29 characters) is used as the Aura instance name.
            user: User object who owns the dataset and is making the request (not used by this handler).

        Returns:
            dict: Connection details for the created Neo4j instance. The instance password
            is stored Fernet-encrypted under
            ``graph_database_connection_info["graph_database_password"]`` and the OAuth
            access token is returned as ``graph_database_key``.

        Raises:
            ValueError: If the configured graph provider is not "neo4j", or if
                NEO4J_CLIENT_ID / NEO4J_CLIENT_SECRET / NEO4J_TENANT_ID are not set.
            requests.HTTPError: If an Aura API call returns an error status.
            TimeoutError: If the instance is not running after the polling window.
        """
        graph_config = get_graph_config()

        if graph_config.graph_database_provider != "neo4j":
            raise ValueError(
                "Neo4jAuraDevDatasetDatabaseHandler can only be used with Neo4j graph database provider."
            )

        # TODO: Find better name to name Neo4j instance within 30 character limit
        instance_name = f"{dataset_id}"[0:29]

        # Client credentials
        client_id = os.environ.get("NEO4J_CLIENT_ID", None)
        client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None)
        tenant_id = os.environ.get("NEO4J_TENANT_ID", None)

        if client_id is None or client_secret is None or tenant_id is None:
            raise ValueError(
                "NEO4J_CLIENT_ID, NEO4J_CLIENT_SECRET, and NEO4J_TENANT_ID environment variables must be set to use Neo4j Aura DatasetDatabase Handling."
            )

        cipher = cls._get_cipher()

        # OAuth client-credentials token, requested with HTTP Basic Auth.
        token_resp = await asyncio.to_thread(
            requests.post,
            f"{cls._AURA_API_BASE_URL}/oauth/token",
            data={"grant_type": "client_credentials"},  # sent as application/x-www-form-urlencoded
            auth=(client_id, client_secret),
            timeout=cls._HTTP_TIMEOUT_SECONDS,
        )
        token_resp.raise_for_status()
        access_token = token_resp.json()["access_token"]

        headers = {
            "accept": "application/json",
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }

        # TODO: Maybe we can allow **kwargs parameter forwarding for cases like these
        # to allow different configurations between datasets
        payload = {
            "version": "5",
            "region": "europe-west1",
            "memory": "1GB",
            "name": instance_name,
            "type": "professional-db",
            "tenant_id": tenant_id,
            "cloud_provider": "gcp",
        }

        create_resp = await asyncio.to_thread(
            requests.post,
            f"{cls._AURA_API_BASE_URL}/v1/instances",
            headers=headers,
            json=payload,
            timeout=cls._HTTP_TIMEOUT_SECONDS,
        )
        # Fail with the HTTP error instead of a confusing KeyError on a missing "data" field.
        create_resp.raise_for_status()
        instance_data = create_resp.json()["data"]

        await cls._wait_for_instance_running(instance_data["id"], instance_name, headers)

        encrypted_db_password = cipher.encrypt(instance_data["password"].encode()).decode()

        return {
            "graph_database_name": "neo4j",  # Has to be 'neo4j' for Aura
            "graph_database_url": instance_data["connection_url"],
            "graph_database_provider": "neo4j",
            "graph_database_key": access_token,
            "graph_dataset_database_handler": "neo4j_aura_dev",
            "graph_database_connection_info": {
                "graph_database_username": instance_data["username"],
                "graph_database_password": encrypted_db_password,
            },
        }

    @classmethod
    async def resolve_dataset_connection_info(
        cls, dataset_database: DatasetDatabase
    ) -> DatasetDatabase:
        """
        Resolve and decrypt connection info for the Neo4j dataset database.
        In this case, decrypt the password stored in the database.

        Args:
            dataset_database: DatasetDatabase instance containing encrypted connection info.

        Returns:
            The same ``dataset_database`` object, mutated in place so that
            ``graph_database_connection_info["graph_database_password"]`` holds the
            decrypted password.
        """
        connection_info = dataset_database.graph_database_connection_info
        encrypted_password = connection_info["graph_database_password"]
        connection_info["graph_database_password"] = (
            cls._get_cipher().decrypt(encrypted_password.encode()).decode()
        )
        return dataset_database

    @classmethod
    async def delete_dataset(cls, dataset_database: DatasetDatabase):
        """
        Intentionally a no-op for this development handler.

        NOTE: the Aura instance provisioned by ``create_dataset`` is NOT deleted here,
        so it keeps running until it is removed manually.
        """
        pass
@@ -8,7 +8,7 @@ from neo4j import AsyncSession
8
8
  from neo4j import AsyncGraphDatabase
9
9
  from neo4j.exceptions import Neo4jError
10
10
  from contextlib import asynccontextmanager
11
- from typing import Optional, Any, List, Dict, Type, Tuple
11
+ from typing import Optional, Any, List, Dict, Type, Tuple, Coroutine
12
12
 
13
13
  from cognee.infrastructure.engine import DataPoint
14
14
  from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
@@ -964,6 +964,63 @@ class Neo4jAdapter(GraphDBInterface):
964
964
  logger.error(f"Error during graph data retrieval: {str(e)}")
965
965
  raise
966
966
 
967
async def get_id_filtered_graph_data(self, target_ids: list[str]):
    """
    Retrieve the relationships that touch the given node IDs, together with their endpoints.

    A relationship is returned when its start or end node has an ``id`` property listed
    in ``target_ids``; the returned nodes are the endpoints of those relationships (the
    target nodes plus their direct neighbours). A single Cypher query is used.

    Parameters:
    -----------
    - target_ids (list[str]): Node IDs to filter on.

    Returns:
    --------
    A ``(nodes, edges)`` tuple:
    - nodes: list of ``(node_id, properties)`` tuples, one per distinct node, where
      ``properties`` is the node's full property map.
    - edges: list of ``(source_id, target_id, relationship_type, properties)`` tuples.
      ``source_id`` / ``target_id`` prefer the relationship's ``source_node_id`` /
      ``target_node_id`` properties and fall back to the endpoint nodes' ``id``.
    Both lists are empty when ``target_ids`` is empty.

    Raises:
    -------
    - Exception: Any error raised while querying or reading results is logged and re-raised.
    """
    import time

    # perf_counter is monotonic, so the logged duration is not skewed by wall-clock changes.
    start_time = time.perf_counter()

    try:
        if not target_ids:
            logger.warning("No target IDs provided for ID-filtered graph retrieval.")
            return [], []

        # Anchor the pattern on the target nodes so only their relationships are
        # expanded, instead of walking every relationship in the graph and filtering
        # afterwards. The undirected pattern covers relationships where the target is
        # either endpoint; DISTINCT removes the duplicate produced when both endpoints
        # are targets.
        query = """
        MATCH (n)-[r]-()
        WHERE n.id IN $target_ids
        WITH DISTINCT r
        RETURN
            properties(startNode(r)) AS n_properties,
            properties(endNode(r)) AS m_properties,
            type(r) AS type,
            properties(r) AS properties
        """

        result = await self.query(query, {"target_ids": target_ids})

        nodes_dict = {}
        edges = []

        for record in result:
            n_props = record["n_properties"]
            m_props = record["m_properties"]
            r_props = record["properties"]
            r_type = record["type"]

            nodes_dict[n_props["id"]] = (n_props["id"], n_props)
            nodes_dict[m_props["id"]] = (m_props["id"], m_props)

            source_id = r_props.get("source_node_id", n_props["id"])
            target_id = r_props.get("target_node_id", m_props["id"])
            edges.append((source_id, target_id, r_type, r_props))

        retrieval_time = time.perf_counter() - start_time
        logger.info(
            f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s"
        )

        return list(nodes_dict.values()), edges

    except Exception as e:
        logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}")
        raise
1023
+
967
1024
  async def get_nodeset_subgraph(
968
1025
  self, node_type: Type[Any], node_name: List[str]
969
1026
  ) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
@@ -1470,3 +1527,25 @@ class Neo4jAdapter(GraphDBInterface):
1470
1527
  time_ids_list = [item["id"] for item in time_nodes if "id" in item]
1471
1528
 
1472
1529
  return ", ".join(f"'{uid}'" for uid in time_ids_list)
1530
+
1531
async def get_triplets_batch(self, offset: int, limit: int) -> list[dict[str, Any]]:
    """
    Retrieve a batch of triplets (start_node, relationship, end_node) from the graph.

    Parameters:
    -----------
    - offset (int): Number of triplets to skip before returning results.
    - limit (int): Maximum number of triplets to return.

    Returns:
    --------
    - list[dict[str, Any]]: The rows returned by ``self.query`` for the RETURN columns
      'start_node', 'relationship_properties' (the relationship's property map) and
      'end_node'.

    Raises:
    -------
    - ValueError: If offset or limit are negative.
    """
    # Validate up front (mirroring the Kuzu adapter) so callers get a clear ValueError
    # rather than a driver-level Cypher error for a negative SKIP/LIMIT.
    if offset < 0:
        raise ValueError(f"Offset must be non-negative, got {offset}")
    if limit < 0:
        raise ValueError(f"Limit must be non-negative, got {limit}")

    # NOTE(review): there is no ORDER BY, so SKIP/LIMIT paging assumes the database
    # returns rows in the same order on every call. Confirm, or add an ORDER BY if
    # successive batches must be guaranteed disjoint.
    query = f"""
    MATCH (start_node:`{BASE_LABEL}`)-[relationship]->(end_node:`{BASE_LABEL}`)
    RETURN start_node, properties(relationship) AS relationship_properties, end_node
    SKIP $offset LIMIT $limit
    """
    return await self.query(query, {"offset": offset, "limit": limit})
@@ -1 +1,4 @@
1
1
  from .get_or_create_dataset_database import get_or_create_dataset_database
2
+ from .resolve_dataset_database_connection_info import resolve_dataset_database_connection_info
3
+ from .get_graph_dataset_database_handler import get_graph_dataset_database_handler
4
+ from .get_vector_dataset_database_handler import get_vector_dataset_database_handler
@@ -0,0 +1,10 @@
1
+ from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
2
+
3
+
4
def get_graph_dataset_database_handler(dataset_database: DatasetDatabase) -> dict:
    """
    Look up the registered handler entry for a dataset's graph database.

    Parameters:
    -----------
    - dataset_database (DatasetDatabase): Row whose ``graph_dataset_database_handler``
      attribute names the handler to use.

    Returns:
    --------
    - dict: The entry registered under that name in ``supported_dataset_database_handlers``.

    Raises:
    -------
    - KeyError: If the name is not a registered handler; the message lists the
      registered handler names.
    """
    # Imported lazily, presumably to avoid an import cycle with the handler registry.
    from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
        supported_dataset_database_handlers,
    )

    handler_name = dataset_database.graph_dataset_database_handler
    try:
        return supported_dataset_database_handlers[handler_name]
    except KeyError:
        registered = ", ".join(str(name) for name in supported_dataset_database_handlers)
        raise KeyError(
            f"Unsupported graph dataset database handler '{handler_name}'. "
            f"Registered handlers: {registered}"
        ) from None
@@ -1,11 +1,9 @@
1
- import os
2
1
  from uuid import UUID
3
- from typing import Union
2
+ from typing import Union, Optional
4
3
 
5
4
  from sqlalchemy import select
6
5
  from sqlalchemy.exc import IntegrityError
7
6
 
8
- from cognee.base_config import get_base_config
9
7
  from cognee.modules.data.methods import create_dataset
10
8
  from cognee.infrastructure.databases.relational import get_relational_engine
11
9
  from cognee.infrastructure.databases.vector import get_vectordb_config
@@ -15,6 +13,53 @@ from cognee.modules.users.models import DatasetDatabase
15
13
  from cognee.modules.users.models import User
16
14
 
17
15
 
16
async def _create_dataset_with_handler(handler_name: str, dataset_id: UUID, user: User) -> dict:
    """
    Create the per-dataset database through the registered handler named ``handler_name``.

    Returns the connection-info dict produced by the handler instance's ``create_dataset``.

    Raises:
        KeyError: If ``handler_name`` is not registered; the message lists the registered names.
    """
    # Imported lazily, presumably to avoid an import cycle with the handler registry.
    from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
        supported_dataset_database_handlers,
    )

    try:
        handler = supported_dataset_database_handlers[handler_name]
    except KeyError:
        registered = ", ".join(str(name) for name in supported_dataset_database_handlers)
        raise KeyError(
            f"Unsupported dataset database handler '{handler_name}'. "
            f"Registered handlers: {registered}"
        ) from None
    return await handler["handler_instance"].create_dataset(dataset_id, user)


async def _get_vector_db_info(dataset_id: UUID, user: User) -> dict:
    """Create the vector database for ``dataset_id`` via the configured vector handler and return its connection info."""
    vector_config = get_vectordb_config()
    return await _create_dataset_with_handler(
        vector_config.vector_dataset_database_handler, dataset_id, user
    )


async def _get_graph_db_info(dataset_id: UUID, user: User) -> dict:
    """Create the graph database for ``dataset_id`` via the configured graph handler and return its connection info."""
    graph_config = get_graph_config()
    return await _create_dataset_with_handler(
        graph_config.graph_dataset_database_handler, dataset_id, user
    )
36
+
37
+
38
async def _existing_dataset_database(
    dataset_id: UUID,
    user: User,
) -> Optional[DatasetDatabase]:
    """
    Fetch the DatasetDatabase row owned by ``user`` for ``dataset_id``, if one exists.

    Args:
        dataset_id: ID of the dataset whose database row is looked up.
        user: Owner of the row; matched against ``DatasetDatabase.owner_id``.

    Returns:
        The matching DatasetDatabase row, or None when no such row exists yet.
    """
    owned_by_user = DatasetDatabase.owner_id == user.id
    for_this_dataset = DatasetDatabase.dataset_id == dataset_id
    lookup = select(DatasetDatabase).where(owned_by_user, for_this_dataset)

    async with get_relational_engine().get_async_session() as session:
        return await session.scalar(lookup)
61
+
62
+
18
63
  async def get_or_create_dataset_database(
19
64
  dataset: Union[str, UUID],
20
65
  user: User,
@@ -25,6 +70,8 @@ async def get_or_create_dataset_database(
25
70
  • If the row already exists, it is fetched and returned.
26
71
  • Otherwise a new one is created atomically and returned.
27
72
 
73
+ DatasetDatabase row contains connection and provider info for vector and graph databases.
74
+
28
75
  Parameters
29
76
  ----------
30
77
  user : User
@@ -36,59 +83,26 @@ async def get_or_create_dataset_database(
36
83
 
37
84
  dataset_id = await get_unique_dataset_id(dataset, user)
38
85
 
39
- vector_config = get_vectordb_config()
40
- graph_config = get_graph_config()
86
+ # If dataset is given as name make sure the dataset is created first
87
+ if isinstance(dataset, str):
88
+ async with db_engine.get_async_session() as session:
89
+ await create_dataset(dataset, user, session)
41
90
 
42
- # Note: for hybrid databases both graph and vector DB name have to be the same
43
- if graph_config.graph_database_provider == "kuzu":
44
- graph_db_name = f"{dataset_id}.pkl"
45
- else:
46
- graph_db_name = f"{dataset_id}"
91
+ # If dataset database already exists return it
92
+ existing_dataset_database = await _existing_dataset_database(dataset_id, user)
93
+ if existing_dataset_database:
94
+ return existing_dataset_database
47
95
 
48
- if vector_config.vector_db_provider == "lancedb":
49
- vector_db_name = f"{dataset_id}.lance.db"
50
- else:
51
- vector_db_name = f"{dataset_id}"
52
-
53
- base_config = get_base_config()
54
- databases_directory_path = os.path.join(
55
- base_config.system_root_directory, "databases", str(user.id)
56
- )
57
-
58
- # Determine vector database URL
59
- if vector_config.vector_db_provider == "lancedb":
60
- vector_db_url = os.path.join(databases_directory_path, vector_config.vector_db_name)
61
- else:
62
- vector_db_url = vector_config.vector_database_url
63
-
64
- # Determine graph database URL
96
+ graph_config_dict = await _get_graph_db_info(dataset_id, user)
97
+ vector_config_dict = await _get_vector_db_info(dataset_id, user)
65
98
 
66
99
  async with db_engine.get_async_session() as session:
67
- # Create dataset if it doesn't exist
68
- if isinstance(dataset, str):
69
- dataset = await create_dataset(dataset, user, session)
70
-
71
- # Try to fetch an existing row first
72
- stmt = select(DatasetDatabase).where(
73
- DatasetDatabase.owner_id == user.id,
74
- DatasetDatabase.dataset_id == dataset_id,
75
- )
76
- existing: DatasetDatabase = await session.scalar(stmt)
77
- if existing:
78
- return existing
79
-
80
100
  # If there are no existing rows build a new row
81
101
  record = DatasetDatabase(
82
102
  owner_id=user.id,
83
103
  dataset_id=dataset_id,
84
- vector_database_name=vector_db_name,
85
- graph_database_name=graph_db_name,
86
- vector_database_provider=vector_config.vector_db_provider,
87
- graph_database_provider=graph_config.graph_database_provider,
88
- vector_database_url=vector_db_url,
89
- graph_database_url=graph_config.graph_database_url,
90
- vector_database_key=vector_config.vector_db_key,
91
- graph_database_key=graph_config.graph_database_key,
104
+ **graph_config_dict, # Unpack graph db config
105
+ **vector_config_dict, # Unpack vector db config
92
106
  )
93
107
 
94
108
  try: