cognee 0.5.0.dev0__py3-none-any.whl → 0.5.0.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. cognee/api/client.py +1 -5
  2. cognee/api/v1/add/add.py +2 -1
  3. cognee/api/v1/cognify/cognify.py +24 -16
  4. cognee/api/v1/cognify/routers/__init__.py +0 -1
  5. cognee/api/v1/cognify/routers/get_cognify_router.py +3 -1
  6. cognee/api/v1/datasets/routers/get_datasets_router.py +3 -3
  7. cognee/api/v1/ontologies/ontologies.py +12 -37
  8. cognee/api/v1/ontologies/routers/get_ontology_router.py +27 -25
  9. cognee/api/v1/search/search.py +4 -0
  10. cognee/api/v1/ui/node_setup.py +360 -0
  11. cognee/api/v1/ui/npm_utils.py +50 -0
  12. cognee/api/v1/ui/ui.py +38 -68
  13. cognee/context_global_variables.py +61 -16
  14. cognee/eval_framework/Dockerfile +29 -0
  15. cognee/eval_framework/answer_generation/answer_generation_executor.py +10 -0
  16. cognee/eval_framework/answer_generation/run_question_answering_module.py +1 -1
  17. cognee/eval_framework/corpus_builder/task_getters/get_cascade_graph_tasks.py +0 -2
  18. cognee/eval_framework/corpus_builder/task_getters/get_default_tasks_by_indices.py +4 -4
  19. cognee/eval_framework/eval_config.py +2 -2
  20. cognee/eval_framework/modal_run_eval.py +16 -28
  21. cognee/infrastructure/databases/dataset_database_handler/__init__.py +3 -0
  22. cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py +80 -0
  23. cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +18 -0
  24. cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py +10 -0
  25. cognee/infrastructure/databases/graph/config.py +3 -0
  26. cognee/infrastructure/databases/graph/get_graph_engine.py +1 -0
  27. cognee/infrastructure/databases/graph/graph_db_interface.py +15 -0
  28. cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py +81 -0
  29. cognee/infrastructure/databases/graph/kuzu/adapter.py +228 -0
  30. cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +168 -0
  31. cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +80 -1
  32. cognee/infrastructure/databases/utils/__init__.py +3 -0
  33. cognee/infrastructure/databases/utils/get_graph_dataset_database_handler.py +10 -0
  34. cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +62 -48
  35. cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py +10 -0
  36. cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py +30 -0
  37. cognee/infrastructure/databases/vector/config.py +2 -0
  38. cognee/infrastructure/databases/vector/create_vector_engine.py +1 -0
  39. cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +8 -6
  40. cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +9 -7
  41. cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +11 -10
  42. cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +2 -0
  43. cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py +50 -0
  44. cognee/infrastructure/databases/vector/vector_db_interface.py +35 -0
  45. cognee/infrastructure/files/storage/s3_config.py +2 -0
  46. cognee/infrastructure/llm/LLMGateway.py +5 -2
  47. cognee/infrastructure/llm/config.py +35 -0
  48. cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py +2 -2
  49. cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py +23 -8
  50. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -16
  51. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py +5 -0
  52. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py +153 -0
  53. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +40 -37
  54. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +39 -36
  55. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +19 -1
  56. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +11 -9
  57. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +23 -21
  58. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +42 -34
  59. cognee/memify_pipelines/create_triplet_embeddings.py +53 -0
  60. cognee/modules/cognify/config.py +2 -0
  61. cognee/modules/data/deletion/prune_system.py +52 -2
  62. cognee/modules/data/methods/delete_dataset.py +26 -0
  63. cognee/modules/engine/models/Triplet.py +9 -0
  64. cognee/modules/engine/models/__init__.py +1 -0
  65. cognee/modules/graph/cognee_graph/CogneeGraph.py +85 -37
  66. cognee/modules/graph/cognee_graph/CogneeGraphElements.py +8 -3
  67. cognee/modules/memify/memify.py +1 -7
  68. cognee/modules/pipelines/operations/pipeline.py +18 -2
  69. cognee/modules/retrieval/__init__.py +1 -1
  70. cognee/modules/retrieval/graph_completion_context_extension_retriever.py +4 -0
  71. cognee/modules/retrieval/graph_completion_cot_retriever.py +4 -0
  72. cognee/modules/retrieval/graph_completion_retriever.py +10 -0
  73. cognee/modules/retrieval/graph_summary_completion_retriever.py +4 -0
  74. cognee/modules/retrieval/register_retriever.py +10 -0
  75. cognee/modules/retrieval/registered_community_retrievers.py +1 -0
  76. cognee/modules/retrieval/temporal_retriever.py +4 -0
  77. cognee/modules/retrieval/triplet_retriever.py +182 -0
  78. cognee/modules/retrieval/utils/brute_force_triplet_search.py +42 -10
  79. cognee/modules/run_custom_pipeline/run_custom_pipeline.py +8 -1
  80. cognee/modules/search/methods/get_search_type_tools.py +54 -8
  81. cognee/modules/search/methods/no_access_control_search.py +4 -0
  82. cognee/modules/search/methods/search.py +21 -0
  83. cognee/modules/search/types/SearchType.py +1 -1
  84. cognee/modules/settings/get_settings.py +19 -0
  85. cognee/modules/users/methods/get_authenticated_user.py +2 -2
  86. cognee/modules/users/models/DatasetDatabase.py +15 -3
  87. cognee/shared/logging_utils.py +4 -0
  88. cognee/shared/rate_limiting.py +30 -0
  89. cognee/tasks/documents/__init__.py +0 -1
  90. cognee/tasks/graph/extract_graph_from_data.py +9 -10
  91. cognee/tasks/memify/get_triplet_datapoints.py +289 -0
  92. cognee/tasks/storage/add_data_points.py +142 -2
  93. cognee/tests/integration/retrieval/test_triplet_retriever.py +84 -0
  94. cognee/tests/integration/tasks/test_add_data_points.py +139 -0
  95. cognee/tests/integration/tasks/test_get_triplet_datapoints.py +69 -0
  96. cognee/tests/test_cognee_server_start.py +2 -4
  97. cognee/tests/test_conversation_history.py +23 -1
  98. cognee/tests/test_dataset_database_handler.py +137 -0
  99. cognee/tests/test_dataset_delete.py +76 -0
  100. cognee/tests/test_edge_centered_payload.py +170 -0
  101. cognee/tests/test_pipeline_cache.py +164 -0
  102. cognee/tests/test_search_db.py +37 -1
  103. cognee/tests/unit/api/test_ontology_endpoint.py +77 -89
  104. cognee/tests/unit/infrastructure/llm/test_llm_config.py +46 -0
  105. cognee/tests/unit/infrastructure/mock_embedding_engine.py +3 -7
  106. cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py +0 -5
  107. cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +2 -2
  108. cognee/tests/unit/modules/graph/cognee_graph_test.py +406 -0
  109. cognee/tests/unit/modules/memify_tasks/test_get_triplet_datapoints.py +214 -0
  110. cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +608 -0
  111. cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +83 -0
  112. cognee/tests/unit/tasks/storage/test_add_data_points.py +288 -0
  113. {cognee-0.5.0.dev0.dist-info → cognee-0.5.0.dev1.dist-info}/METADATA +76 -89
  114. {cognee-0.5.0.dev0.dist-info → cognee-0.5.0.dev1.dist-info}/RECORD +118 -97
  115. {cognee-0.5.0.dev0.dist-info → cognee-0.5.0.dev1.dist-info}/WHEEL +1 -1
  116. cognee/api/v1/cognify/code_graph_pipeline.py +0 -119
  117. cognee/api/v1/cognify/routers/get_code_pipeline_router.py +0 -90
  118. cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +0 -544
  119. cognee/modules/retrieval/code_retriever.py +0 -232
  120. cognee/tasks/code/enrich_dependency_graph_checker.py +0 -35
  121. cognee/tasks/code/get_local_dependencies_checker.py +0 -20
  122. cognee/tasks/code/get_repo_dependency_graph_checker.py +0 -35
  123. cognee/tasks/documents/check_permissions_on_dataset.py +0 -26
  124. cognee/tasks/repo_processor/__init__.py +0 -2
  125. cognee/tasks/repo_processor/get_local_dependencies.py +0 -335
  126. cognee/tasks/repo_processor/get_non_code_files.py +0 -158
  127. cognee/tasks/repo_processor/get_repo_file_dependencies.py +0 -243
  128. cognee/tests/test_delete_bmw_example.py +0 -60
  129. {cognee-0.5.0.dev0.dist-info → cognee-0.5.0.dev1.dist-info}/entry_points.txt +0 -0
  130. {cognee-0.5.0.dev0.dist-info → cognee-0.5.0.dev1.dist-info}/licenses/LICENSE +0 -0
  131. {cognee-0.5.0.dev0.dist-info → cognee-0.5.0.dev1.dist-info}/licenses/NOTICE.md +0 -0
@@ -1,4 +1,5 @@
1
1
  import pytest
2
+ from unittest.mock import AsyncMock
2
3
 
3
4
  from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
4
5
  from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
@@ -11,6 +12,30 @@ def setup_graph():
11
12
  return CogneeGraph()
12
13
 
13
14
 
15
@pytest.fixture
def mock_adapter():
    """Provide a bare ``AsyncMock`` that tests pass in as the database adapter."""
    return AsyncMock()
20
+
21
+
22
@pytest.fixture
def mock_vector_engine():
    """Provide an ``AsyncMock`` vector engine with an explicit ``search`` mock."""
    vector_engine = AsyncMock()
    # ``search`` is assigned explicitly rather than left to auto-creation.
    vector_engine.search = AsyncMock()
    return vector_engine
28
+
29
+
30
class MockScoredResult:
    """Stand-in for a vector search hit exposing ``id``, ``score`` and ``payload``."""

    def __init__(self, id, score, payload=None):
        self.id, self.score = id, score
        # A falsy payload (``None`` or an empty mapping) becomes a new empty dict.
        self.payload = payload if payload else {}
37
+
38
+
14
39
  def test_add_node_success(setup_graph):
15
40
  """Test successful addition of a node."""
16
41
  graph = setup_graph
@@ -73,3 +98,384 @@ def test_get_edges_nonexistent_node(setup_graph):
73
98
  graph = setup_graph
74
99
  with pytest.raises(EntityNotFoundError, match="Node with id nonexistent does not exist."):
75
100
  graph.get_edges_from_node("nonexistent")
101
+
102
+
103
@pytest.mark.asyncio
async def test_project_graph_from_db_full_graph(setup_graph, mock_adapter):
    """Project an unfiltered two-node, one-edge graph and check what is loaded."""
    # The adapter mock hands back a (nodes, edges) pair from ``get_graph_data``.
    mock_adapter.get_graph_data = AsyncMock(
        return_value=(
            [
                ("1", {"name": "Node1", "description": "First node"}),
                ("2", {"name": "Node2", "description": "Second node"}),
            ],
            [("1", "2", "CONNECTS_TO", {"relationship_name": "connects"})],
        )
    )

    await setup_graph.project_graph_from_db(
        adapter=mock_adapter,
        node_properties_to_project=["name", "description"],
        edge_properties_to_project=["relationship_name"],
    )

    assert len(setup_graph.nodes) == 2
    assert len(setup_graph.edges) == 1
    for node_id in ("1", "2"):
        assert setup_graph.get_node(node_id) is not None
    # The single projected edge must run from node "1" to node "2".
    projected_edge = setup_graph.edges[0]
    assert projected_edge.node1.id == "1"
    assert projected_edge.node2.id == "2"
130
+
131
+
132
@pytest.mark.asyncio
async def test_project_graph_from_db_id_filtered(setup_graph, mock_adapter):
    """With ``relevant_ids_to_filter`` set, the ID-filtered adapter query is used once."""
    id_filtered_fetch = AsyncMock(
        return_value=(
            [("1", {"name": "Node1"}), ("2", {"name": "Node2"})],
            [("1", "2", "CONNECTS_TO", {"relationship_name": "connects"})],
        )
    )
    mock_adapter.get_id_filtered_graph_data = id_filtered_fetch

    await setup_graph.project_graph_from_db(
        adapter=mock_adapter,
        node_properties_to_project=["name"],
        edge_properties_to_project=["relationship_name"],
        relevant_ids_to_filter=["1", "2"],
    )

    assert len(setup_graph.nodes) == 2
    assert len(setup_graph.edges) == 1
    mock_adapter.get_id_filtered_graph_data.assert_called_once()
157
+
158
+
159
@pytest.mark.asyncio
async def test_project_graph_from_db_nodeset_subgraph(setup_graph, mock_adapter):
    """With ``node_type`` and ``node_name`` set, ``get_nodeset_subgraph`` is used once."""
    subgraph_fetch = AsyncMock(
        return_value=(
            [
                ("1", {"name": "Alice", "type": "Person"}),
                ("2", {"name": "Bob", "type": "Person"}),
            ],
            [("1", "2", "KNOWS", {"relationship_name": "knows"})],
        )
    )
    mock_adapter.get_nodeset_subgraph = subgraph_fetch

    await setup_graph.project_graph_from_db(
        adapter=mock_adapter,
        node_properties_to_project=["name", "type"],
        edge_properties_to_project=["relationship_name"],
        node_type="Person",
        node_name=["Alice"],
    )

    assert len(setup_graph.nodes) == 2
    assert setup_graph.get_node("1") is not None
    assert len(setup_graph.edges) == 1
    mock_adapter.get_nodeset_subgraph.assert_called_once()
186
+
187
+
188
@pytest.mark.asyncio
async def test_project_graph_from_db_empty_graph(setup_graph, mock_adapter):
    """An adapter returning no nodes and no edges must raise ``EntityNotFoundError``."""
    mock_adapter.get_graph_data = AsyncMock(return_value=([], []))
    projection_kwargs = {
        "adapter": mock_adapter,
        "node_properties_to_project": ["name"],
        "edge_properties_to_project": [],
    }

    with pytest.raises(EntityNotFoundError, match="Empty graph projected from the database."):
        await setup_graph.project_graph_from_db(**projection_kwargs)
201
+
202
+
203
@pytest.mark.asyncio
async def test_project_graph_from_db_missing_nodes(setup_graph, mock_adapter):
    """An edge pointing at a node id absent from the node list must raise."""
    # Only node "1" exists; the edge targets "999", which was never returned.
    mock_adapter.get_graph_data = AsyncMock(
        return_value=(
            [("1", {"name": "Node1"})],
            [("1", "999", "CONNECTS_TO", {"relationship_name": "connects"})],
        )
    )
    projection_kwargs = {
        "adapter": mock_adapter,
        "node_properties_to_project": ["name"],
        "edge_properties_to_project": ["relationship_name"],
    }

    with pytest.raises(EntityNotFoundError, match="Edge references nonexistent nodes"):
        await setup_graph.project_graph_from_db(**projection_kwargs)
223
+
224
+
225
@pytest.mark.asyncio
async def test_map_vector_distances_to_graph_nodes(setup_graph):
    """Scores from a single collection are stored on the nodes with matching ids."""
    nodes = [Node("1", {"name": "Node1"}), Node("2", {"name": "Node2"})]
    for node in nodes:
        setup_graph.add_node(node)

    scored_by_collection = {
        "Entity_name": [MockScoredResult("1", 0.95), MockScoredResult("2", 0.87)],
    }

    await setup_graph.map_vector_distances_to_graph_nodes(scored_by_collection)

    assert setup_graph.get_node("1").attributes.get("vector_distance") == 0.95
    assert setup_graph.get_node("2").attributes.get("vector_distance") == 0.87
246
+
247
+
248
@pytest.mark.asyncio
async def test_map_vector_distances_partial_node_coverage(setup_graph):
    """Nodes with a search hit get its score; the node without one is expected at 3.5."""
    nodes = [
        Node("1", {"name": "Node1"}),
        Node("2", {"name": "Node2"}),
        Node("3", {"name": "Node3"}),
    ]
    for node in nodes:
        setup_graph.add_node(node)

    # Node "3" is deliberately absent from the search results.
    scored_by_collection = {
        "Entity_name": [MockScoredResult("1", 0.95), MockScoredResult("2", 0.87)],
    }

    await setup_graph.map_vector_distances_to_graph_nodes(scored_by_collection)

    assert setup_graph.get_node("1").attributes.get("vector_distance") == 0.95
    assert setup_graph.get_node("2").attributes.get("vector_distance") == 0.87
    assert setup_graph.get_node("3").attributes.get("vector_distance") == 3.5
272
+
273
+
274
@pytest.mark.asyncio
async def test_map_vector_distances_multiple_categories(setup_graph):
    """Scores from two collections are all applied; the unmatched node is expected at 3.5."""
    nodes = [Node(node_id) for node_id in ("1", "2", "3", "4")]
    for node in nodes:
        setup_graph.add_node(node)

    # Two collections cover nodes "1"-"3"; node "4" appears in neither.
    scored_by_collection = {
        "Entity_name": [MockScoredResult("1", 0.95), MockScoredResult("2", 0.87)],
        "TextSummary_text": [MockScoredResult("3", 0.92)],
    }

    await setup_graph.map_vector_distances_to_graph_nodes(scored_by_collection)

    assert setup_graph.get_node("1").attributes.get("vector_distance") == 0.95
    assert setup_graph.get_node("2").attributes.get("vector_distance") == 0.87
    assert setup_graph.get_node("3").attributes.get("vector_distance") == 0.92
    assert setup_graph.get_node("4").attributes.get("vector_distance") == 3.5
305
+
306
+
307
@pytest.mark.asyncio
async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph):
    """A result whose payload text equals the edge's ``edge_text`` sets the edge score."""
    source, target = Node("1"), Node("2")
    for node in (source, target):
        setup_graph.add_node(node)

    edge_attributes = {"edge_text": "CONNECTS_TO", "relationship_type": "connects"}
    setup_graph.add_edge(Edge(source, target, attributes=edge_attributes))

    scored_edges = [MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"})]

    await setup_graph.map_vector_distances_to_graph_edges(edge_distances=scored_edges)

    assert setup_graph.edges[0].attributes.get("vector_distance") == 0.92
331
+
332
+
333
@pytest.mark.asyncio
async def test_map_vector_distances_partial_edge_coverage(setup_graph):
    """The matched edge takes the result score; the unmatched edge is expected at 3.5."""
    first, second, third = Node("1"), Node("2"), Node("3")
    for node in (first, second, third):
        setup_graph.add_node(node)

    matched_edge = Edge(first, second, attributes={"edge_text": "CONNECTS_TO"})
    unmatched_edge = Edge(second, third, attributes={"edge_text": "DEPENDS_ON"})
    for edge in (matched_edge, unmatched_edge):
        setup_graph.add_edge(edge)

    # Only "CONNECTS_TO" has a search result; "DEPENDS_ON" has none.
    scored_edges = [MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"})]

    await setup_graph.map_vector_distances_to_graph_edges(edge_distances=scored_edges)

    assert setup_graph.edges[0].attributes.get("vector_distance") == 0.92
    assert setup_graph.edges[1].attributes.get("vector_distance") == 3.5
358
+
359
+
360
@pytest.mark.asyncio
async def test_map_vector_distances_edges_fallback_to_relationship_type(setup_graph):
    """Without ``edge_text``, the edge is matched through its ``relationship_type``."""
    source, target = Node("1"), Node("2")
    for node in (source, target):
        setup_graph.add_node(node)

    # The edge carries no "edge_text" attribute, only "relationship_type".
    setup_graph.add_edge(Edge(source, target, attributes={"relationship_type": "KNOWS"}))

    scored_edges = [MockScoredResult("e1", 0.85, payload={"text": "KNOWS"})]

    await setup_graph.map_vector_distances_to_graph_edges(edge_distances=scored_edges)

    assert setup_graph.edges[0].attributes.get("vector_distance") == 0.85
384
+
385
+
386
@pytest.mark.asyncio
async def test_map_vector_distances_no_edge_matches(setup_graph):
    """A result whose payload text matches no edge leaves the edge expected at 3.5."""
    source, target = Node("1"), Node("2")
    for node in (source, target):
        setup_graph.add_node(node)

    edge_attributes = {"edge_text": "CONNECTS_TO", "relationship_type": "connects"}
    setup_graph.add_edge(Edge(source, target, attributes=edge_attributes))

    # The payload text differs from both attributes of the only edge.
    scored_edges = [MockScoredResult("e1", 0.92, payload={"text": "SOME_OTHER_EDGE"})]

    await setup_graph.map_vector_distances_to_graph_edges(edge_distances=scored_edges)

    assert setup_graph.edges[0].attributes.get("vector_distance") == 3.5
410
+
411
+
412
@pytest.mark.asyncio
async def test_map_vector_distances_none_returns_early(setup_graph):
    """Passing ``edge_distances=None`` must not raise; the edge is expected at 3.5."""
    for node_id in ("1", "2"):
        setup_graph.add_node(Node(node_id))
    # Build the edge from the node objects stored in the graph.
    source = setup_graph.get_node("1")
    target = setup_graph.get_node("2")
    setup_graph.add_edge(Edge(source, target))

    await setup_graph.map_vector_distances_to_graph_edges(edge_distances=None)

    assert setup_graph.edges[0].attributes.get("vector_distance") == 3.5
423
+
424
+
425
@pytest.mark.asyncio
async def test_calculate_top_triplet_importances(setup_graph):
    """With ``k=2`` on a 1-2-3-4 chain, edge 3→4 is expected first and edge 2→3 second."""
    nodes = [Node(node_id) for node_id in ("1", "2", "3", "4")]
    for node, distance in zip(nodes, (0.9, 0.8, 0.7, 0.6)):
        node.add_attribute("vector_distance", distance)
    for node in nodes:
        setup_graph.add_node(node)

    first, second, third, fourth = nodes
    edge_1_2 = Edge(first, second)
    edge_2_3 = Edge(second, third)
    edge_3_4 = Edge(third, fourth)
    chain_edges = (edge_1_2, edge_2_3, edge_3_4)

    for edge, distance in zip(chain_edges, (0.85, 0.75, 0.65)):
        edge.add_attribute("vector_distance", distance)
    for edge in chain_edges:
        setup_graph.add_edge(edge)

    ranked = await setup_graph.calculate_top_triplet_importances(k=2)

    assert len(ranked) == 2

    # Distances shrink along the chain, so the last two edges are the expected picks.
    assert ranked[0] == edge_3_4
    assert ranked[1] == edge_2_3
463
+
464
+
465
@pytest.mark.asyncio
async def test_calculate_top_triplet_importances_default_distances(setup_graph):
    """Ranking still returns the only edge when no ``vector_distance`` was set explicitly."""
    source, target = Node("1"), Node("2")
    for node in (source, target):
        setup_graph.add_node(node)

    only_edge = Edge(source, target)
    setup_graph.add_edge(only_edge)

    ranked = await setup_graph.calculate_top_triplet_importances(k=1)

    assert len(ranked) == 1
    assert ranked[0] == only_edge
@@ -0,0 +1,214 @@
1
+ import sys
2
+ import pytest
3
+ from unittest.mock import AsyncMock, patch
4
+
5
+ from cognee.tasks.memify.get_triplet_datapoints import get_triplet_datapoints
6
+ from cognee.modules.engine.models import Triplet
7
+ from cognee.modules.engine.models.Entity import Entity
8
+ from cognee.infrastructure.engine import DataPoint
9
+ from cognee.modules.graph.models.EdgeType import EdgeType
10
+
11
+
12
+ get_triplet_datapoints_module = sys.modules["cognee.tasks.memify.get_triplet_datapoints"]
13
+
14
+
15
@pytest.fixture
def mock_graph_engine():
    """Provide an ``AsyncMock`` graph engine with an explicit ``get_triplets_batch`` mock."""
    graph_engine = AsyncMock()
    # Assigned explicitly so tests can set its return value or delete it.
    graph_engine.get_triplets_batch = AsyncMock()
    return graph_engine
21
+
22
+
23
@pytest.mark.asyncio
async def test_get_triplet_datapoints_success(mock_graph_engine):
    """One well-formed batch entry yields one ``Triplet`` naming both nodes and the relation."""
    alice = {
        "id": "node1",
        "type": "Entity",
        "name": "Alice",
        "description": "A person",
    }
    bob = {
        "id": "node2",
        "type": "Entity",
        "name": "Bob",
        "description": "Another person",
    }
    mock_graph_engine.get_triplets_batch.return_value = [
        {
            "start_node": alice,
            "end_node": bob,
            "relationship_properties": {"relationship_name": "knows"},
        }
    ]

    # Replace the module's engine lookup and subclass discovery for the duration of the run.
    engine_patch = patch.object(
        get_triplet_datapoints_module, "get_graph_engine", return_value=mock_graph_engine
    )
    subclasses_patch = patch.object(get_triplet_datapoints_module, "get_all_subclasses")
    with engine_patch, subclasses_patch as mock_get_subclasses:
        mock_get_subclasses.return_value = [Triplet, EdgeType, Entity]

        triplets = [
            triplet
            async for triplet in get_triplet_datapoints([{}], triplets_batch_size=100)
        ]

    assert len(triplets) == 1
    produced = triplets[0]
    assert isinstance(produced, Triplet)
    assert produced.from_node_id == "node1"
    assert produced.to_node_id == "node2"
    for fragment in ("Alice", "knows", "Bob"):
        assert fragment in produced.text
67
+
68
+
69
@pytest.mark.asyncio
async def test_get_triplet_datapoints_edge_text_priority_and_fallback(mock_graph_engine):
    """``edge_text`` replaces ``relationship_name`` in the text; without it the name is used."""

    class MockEntity(DataPoint):
        name: str
        metadata: dict = {"index_fields": ["name"]}

    # First entry carries both properties; second carries only the relationship name.
    entry_with_edge_text = {
        "start_node": {"id": "node1", "type": "Entity", "name": "Alice"},
        "end_node": {"id": "node2", "type": "Entity", "name": "Bob"},
        "relationship_properties": {
            "relationship_name": "knows",
            "edge_text": "has a close friendship with",
        },
    }
    entry_without_edge_text = {
        "start_node": {"id": "node3", "type": "Entity", "name": "Charlie"},
        "end_node": {"id": "node4", "type": "Entity", "name": "Diana"},
        "relationship_properties": {"relationship_name": "works_with"},
    }
    mock_graph_engine.get_triplets_batch.return_value = [
        entry_with_edge_text,
        entry_without_edge_text,
    ]

    engine_patch = patch.object(
        get_triplet_datapoints_module, "get_graph_engine", return_value=mock_graph_engine
    )
    subclasses_patch = patch.object(get_triplet_datapoints_module, "get_all_subclasses")
    with engine_patch, subclasses_patch as mock_get_subclasses:
        mock_get_subclasses.return_value = [Triplet, EdgeType, MockEntity]

        triplets = [
            triplet
            async for triplet in get_triplet_datapoints([{}], triplets_batch_size=100)
        ]

    assert len(triplets) == 2
    with_edge_text, name_only = triplets[0], triplets[1]
    assert "has a close friendship with" in with_edge_text.text
    assert "knows" not in with_edge_text.text
    assert "works_with" in name_only.text
113
+
114
+
115
@pytest.mark.asyncio
async def test_get_triplet_datapoints_skips_missing_node_ids(mock_graph_engine):
    """The entry whose start node has an empty id is dropped; the valid entry is kept."""

    class MockEntity(DataPoint):
        name: str
        metadata: dict = {"index_fields": ["name"]}

    entry_with_blank_start_id = {
        "start_node": {"id": "", "type": "Entity", "name": "Alice"},
        "end_node": {"id": "node2", "type": "Entity", "name": "Bob"},
        "relationship_properties": {"relationship_name": "knows"},
    }
    valid_entry = {
        "start_node": {"id": "node3", "type": "Entity", "name": "Charlie"},
        "end_node": {"id": "node4", "type": "Entity", "name": "Diana"},
        "relationship_properties": {"relationship_name": "works_with"},
    }
    mock_graph_engine.get_triplets_batch.return_value = [
        entry_with_blank_start_id,
        valid_entry,
    ]

    engine_patch = patch.object(
        get_triplet_datapoints_module, "get_graph_engine", return_value=mock_graph_engine
    )
    subclasses_patch = patch.object(get_triplet_datapoints_module, "get_all_subclasses")
    with engine_patch, subclasses_patch as mock_get_subclasses:
        mock_get_subclasses.return_value = [Triplet, EdgeType, MockEntity]

        triplets = [
            triplet
            async for triplet in get_triplet_datapoints([{}], triplets_batch_size=100)
        ]

    assert len(triplets) == 1
    assert triplets[0].from_node_id == "node3"
152
+
153
+
154
@pytest.mark.asyncio
async def test_get_triplet_datapoints_error_handling(mock_graph_engine):
    """An entry with a ``None`` start node is skipped, while a failing batch query raises."""

    class MockEntity(DataPoint):
        name: str
        metadata: dict = {"index_fields": ["name"]}

    valid_entry = {
        "start_node": {"id": "node1", "type": "Entity", "name": "Alice"},
        "end_node": {"id": "node2", "type": "Entity", "name": "Bob"},
        "relationship_properties": {"relationship_name": "knows"},
    }
    entry_with_null_start = {
        "start_node": None,
        "end_node": {"id": "node4", "type": "Entity", "name": "Diana"},
        "relationship_properties": {"relationship_name": "works_with"},
    }
    mock_graph_engine.get_triplets_batch.return_value = [valid_entry, entry_with_null_start]

    # Phase 1: malformed data inside an otherwise successful batch.
    engine_patch = patch.object(
        get_triplet_datapoints_module, "get_graph_engine", return_value=mock_graph_engine
    )
    subclasses_patch = patch.object(get_triplet_datapoints_module, "get_all_subclasses")
    with engine_patch, subclasses_patch as mock_get_subclasses:
        mock_get_subclasses.return_value = [Triplet, EdgeType, MockEntity]

        triplets = [
            triplet
            async for triplet in get_triplet_datapoints([{}], triplets_batch_size=100)
        ]

    assert len(triplets) == 1
    assert triplets[0].from_node_id == "node1"

    # Phase 2: the batch query itself fails and the exception must reach the caller.
    mock_graph_engine.get_triplets_batch.side_effect = Exception("Database connection error")

    failing_engine_patch = patch.object(
        get_triplet_datapoints_module, "get_graph_engine", return_value=mock_graph_engine
    )
    with failing_engine_patch:
        with pytest.raises(Exception, match="Database connection error"):
            async for _ in get_triplet_datapoints([{}], triplets_batch_size=100):
                pass
201
+
202
+
203
@pytest.mark.asyncio
async def test_get_triplet_datapoints_no_get_triplets_batch_method(mock_graph_engine):
    """An engine without ``get_triplets_batch`` must lead to ``NotImplementedError``."""
    # Deleting the attribute makes later access on the mock raise ``AttributeError``.
    del mock_graph_engine.get_triplets_batch

    engine_patch = patch.object(
        get_triplet_datapoints_module, "get_graph_engine", return_value=mock_graph_engine
    )
    with engine_patch:
        with pytest.raises(NotImplementedError, match="does not support get_triplets_batch"):
            async for _ in get_triplet_datapoints([{}], triplets_batch_size=100):
                pass