cognee 0.5.0__py3-none-any.whl → 0.5.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131) hide show
  1. cognee/api/client.py +5 -1
  2. cognee/api/v1/add/add.py +1 -2
  3. cognee/api/v1/cognify/code_graph_pipeline.py +119 -0
  4. cognee/api/v1/cognify/cognify.py +16 -24
  5. cognee/api/v1/cognify/routers/__init__.py +1 -0
  6. cognee/api/v1/cognify/routers/get_code_pipeline_router.py +90 -0
  7. cognee/api/v1/cognify/routers/get_cognify_router.py +1 -3
  8. cognee/api/v1/datasets/routers/get_datasets_router.py +3 -3
  9. cognee/api/v1/ontologies/ontologies.py +37 -12
  10. cognee/api/v1/ontologies/routers/get_ontology_router.py +25 -27
  11. cognee/api/v1/search/search.py +0 -4
  12. cognee/api/v1/ui/ui.py +68 -38
  13. cognee/context_global_variables.py +16 -61
  14. cognee/eval_framework/answer_generation/answer_generation_executor.py +0 -10
  15. cognee/eval_framework/answer_generation/run_question_answering_module.py +1 -1
  16. cognee/eval_framework/corpus_builder/task_getters/get_cascade_graph_tasks.py +2 -0
  17. cognee/eval_framework/corpus_builder/task_getters/get_default_tasks_by_indices.py +4 -4
  18. cognee/eval_framework/eval_config.py +2 -2
  19. cognee/eval_framework/modal_run_eval.py +28 -16
  20. cognee/infrastructure/databases/graph/config.py +0 -3
  21. cognee/infrastructure/databases/graph/get_graph_engine.py +0 -1
  22. cognee/infrastructure/databases/graph/graph_db_interface.py +0 -15
  23. cognee/infrastructure/databases/graph/kuzu/adapter.py +0 -228
  24. cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +1 -80
  25. cognee/infrastructure/databases/utils/__init__.py +0 -3
  26. cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +48 -62
  27. cognee/infrastructure/databases/vector/config.py +0 -2
  28. cognee/infrastructure/databases/vector/create_vector_engine.py +0 -1
  29. cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +6 -8
  30. cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +7 -9
  31. cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +10 -11
  32. cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +544 -0
  33. cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +0 -2
  34. cognee/infrastructure/databases/vector/vector_db_interface.py +0 -35
  35. cognee/infrastructure/files/storage/s3_config.py +0 -2
  36. cognee/infrastructure/llm/LLMGateway.py +2 -5
  37. cognee/infrastructure/llm/config.py +0 -35
  38. cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py +2 -2
  39. cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py +8 -23
  40. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +16 -17
  41. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +37 -40
  42. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +36 -39
  43. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +1 -19
  44. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +9 -11
  45. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +21 -23
  46. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +34 -42
  47. cognee/modules/cognify/config.py +0 -2
  48. cognee/modules/data/deletion/prune_system.py +2 -52
  49. cognee/modules/data/methods/delete_dataset.py +0 -26
  50. cognee/modules/engine/models/__init__.py +0 -1
  51. cognee/modules/graph/cognee_graph/CogneeGraph.py +37 -85
  52. cognee/modules/graph/cognee_graph/CogneeGraphElements.py +3 -8
  53. cognee/modules/memify/memify.py +7 -1
  54. cognee/modules/pipelines/operations/pipeline.py +2 -18
  55. cognee/modules/retrieval/__init__.py +1 -1
  56. cognee/modules/retrieval/code_retriever.py +232 -0
  57. cognee/modules/retrieval/graph_completion_context_extension_retriever.py +0 -4
  58. cognee/modules/retrieval/graph_completion_cot_retriever.py +0 -4
  59. cognee/modules/retrieval/graph_completion_retriever.py +0 -10
  60. cognee/modules/retrieval/graph_summary_completion_retriever.py +0 -4
  61. cognee/modules/retrieval/temporal_retriever.py +0 -4
  62. cognee/modules/retrieval/utils/brute_force_triplet_search.py +10 -42
  63. cognee/modules/run_custom_pipeline/run_custom_pipeline.py +1 -8
  64. cognee/modules/search/methods/get_search_type_tools.py +8 -54
  65. cognee/modules/search/methods/no_access_control_search.py +0 -4
  66. cognee/modules/search/methods/search.py +0 -21
  67. cognee/modules/search/types/SearchType.py +1 -1
  68. cognee/modules/settings/get_settings.py +0 -19
  69. cognee/modules/users/methods/get_authenticated_user.py +2 -2
  70. cognee/modules/users/models/DatasetDatabase.py +3 -15
  71. cognee/shared/logging_utils.py +0 -4
  72. cognee/tasks/code/enrich_dependency_graph_checker.py +35 -0
  73. cognee/tasks/code/get_local_dependencies_checker.py +20 -0
  74. cognee/tasks/code/get_repo_dependency_graph_checker.py +35 -0
  75. cognee/tasks/documents/__init__.py +1 -0
  76. cognee/tasks/documents/check_permissions_on_dataset.py +26 -0
  77. cognee/tasks/graph/extract_graph_from_data.py +10 -9
  78. cognee/tasks/repo_processor/__init__.py +2 -0
  79. cognee/tasks/repo_processor/get_local_dependencies.py +335 -0
  80. cognee/tasks/repo_processor/get_non_code_files.py +158 -0
  81. cognee/tasks/repo_processor/get_repo_file_dependencies.py +243 -0
  82. cognee/tasks/storage/add_data_points.py +2 -142
  83. cognee/tests/test_cognee_server_start.py +4 -2
  84. cognee/tests/test_conversation_history.py +1 -23
  85. cognee/tests/test_delete_bmw_example.py +60 -0
  86. cognee/tests/test_search_db.py +1 -37
  87. cognee/tests/unit/api/test_ontology_endpoint.py +89 -77
  88. cognee/tests/unit/infrastructure/mock_embedding_engine.py +7 -3
  89. cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py +5 -0
  90. cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +2 -2
  91. cognee/tests/unit/modules/graph/cognee_graph_test.py +0 -406
  92. {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/METADATA +89 -76
  93. {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/RECORD +97 -118
  94. {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/WHEEL +1 -1
  95. cognee/api/v1/ui/node_setup.py +0 -360
  96. cognee/api/v1/ui/npm_utils.py +0 -50
  97. cognee/eval_framework/Dockerfile +0 -29
  98. cognee/infrastructure/databases/dataset_database_handler/__init__.py +0 -3
  99. cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py +0 -80
  100. cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +0 -18
  101. cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py +0 -10
  102. cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py +0 -81
  103. cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +0 -168
  104. cognee/infrastructure/databases/utils/get_graph_dataset_database_handler.py +0 -10
  105. cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py +0 -10
  106. cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py +0 -30
  107. cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py +0 -50
  108. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py +0 -5
  109. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py +0 -153
  110. cognee/memify_pipelines/create_triplet_embeddings.py +0 -53
  111. cognee/modules/engine/models/Triplet.py +0 -9
  112. cognee/modules/retrieval/register_retriever.py +0 -10
  113. cognee/modules/retrieval/registered_community_retrievers.py +0 -1
  114. cognee/modules/retrieval/triplet_retriever.py +0 -182
  115. cognee/shared/rate_limiting.py +0 -30
  116. cognee/tasks/memify/get_triplet_datapoints.py +0 -289
  117. cognee/tests/integration/retrieval/test_triplet_retriever.py +0 -84
  118. cognee/tests/integration/tasks/test_add_data_points.py +0 -139
  119. cognee/tests/integration/tasks/test_get_triplet_datapoints.py +0 -69
  120. cognee/tests/test_dataset_database_handler.py +0 -137
  121. cognee/tests/test_dataset_delete.py +0 -76
  122. cognee/tests/test_edge_centered_payload.py +0 -170
  123. cognee/tests/test_pipeline_cache.py +0 -164
  124. cognee/tests/unit/infrastructure/llm/test_llm_config.py +0 -46
  125. cognee/tests/unit/modules/memify_tasks/test_get_triplet_datapoints.py +0 -214
  126. cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +0 -608
  127. cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +0 -83
  128. cognee/tests/unit/tasks/storage/test_add_data_points.py +0 -288
  129. {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/entry_points.txt +0 -0
  130. {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/licenses/LICENSE +0 -0
  131. {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/licenses/NOTICE.md +0 -0
@@ -1,608 +0,0 @@
1
- import pytest
2
- from unittest.mock import AsyncMock, patch
3
-
4
- from cognee.modules.retrieval.utils.brute_force_triplet_search import (
5
- brute_force_triplet_search,
6
- get_memory_fragment,
7
- )
8
- from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
9
- from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
10
-
11
-
12
- class MockScoredResult:
13
- """Mock class for vector search results."""
14
-
15
- def __init__(self, id, score, payload=None):
16
- self.id = id
17
- self.score = score
18
- self.payload = payload or {}
19
-
20
-
21
- @pytest.mark.asyncio
22
- async def test_brute_force_triplet_search_empty_query():
23
- """Test that empty query raises ValueError."""
24
- with pytest.raises(ValueError, match="The query must be a non-empty string."):
25
- await brute_force_triplet_search(query="")
26
-
27
-
28
- @pytest.mark.asyncio
29
- async def test_brute_force_triplet_search_none_query():
30
- """Test that None query raises ValueError."""
31
- with pytest.raises(ValueError, match="The query must be a non-empty string."):
32
- await brute_force_triplet_search(query=None)
33
-
34
-
35
- @pytest.mark.asyncio
36
- async def test_brute_force_triplet_search_negative_top_k():
37
- """Test that negative top_k raises ValueError."""
38
- with pytest.raises(ValueError, match="top_k must be a positive integer."):
39
- await brute_force_triplet_search(query="test query", top_k=-1)
40
-
41
-
42
- @pytest.mark.asyncio
43
- async def test_brute_force_triplet_search_zero_top_k():
44
- """Test that zero top_k raises ValueError."""
45
- with pytest.raises(ValueError, match="top_k must be a positive integer."):
46
- await brute_force_triplet_search(query="test query", top_k=0)
47
-
48
-
49
- @pytest.mark.asyncio
50
- async def test_brute_force_triplet_search_wide_search_limit_global_search():
51
- """Test that wide_search_limit is applied for global search (node_name=None)."""
52
- mock_vector_engine = AsyncMock()
53
- mock_vector_engine.embedding_engine = AsyncMock()
54
- mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
55
- mock_vector_engine.search = AsyncMock(return_value=[])
56
-
57
- with patch(
58
- "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
59
- return_value=mock_vector_engine,
60
- ):
61
- await brute_force_triplet_search(
62
- query="test",
63
- node_name=None, # Global search
64
- wide_search_top_k=75,
65
- )
66
-
67
- for call in mock_vector_engine.search.call_args_list:
68
- assert call[1]["limit"] == 75
69
-
70
-
71
- @pytest.mark.asyncio
72
- async def test_brute_force_triplet_search_wide_search_limit_filtered_search():
73
- """Test that wide_search_limit is None for filtered search (node_name provided)."""
74
- mock_vector_engine = AsyncMock()
75
- mock_vector_engine.embedding_engine = AsyncMock()
76
- mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
77
- mock_vector_engine.search = AsyncMock(return_value=[])
78
-
79
- with patch(
80
- "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
81
- return_value=mock_vector_engine,
82
- ):
83
- await brute_force_triplet_search(
84
- query="test",
85
- node_name=["Node1"],
86
- wide_search_top_k=50,
87
- )
88
-
89
- for call in mock_vector_engine.search.call_args_list:
90
- assert call[1]["limit"] is None
91
-
92
-
93
- @pytest.mark.asyncio
94
- async def test_brute_force_triplet_search_wide_search_default():
95
- """Test that wide_search_top_k defaults to 100."""
96
- mock_vector_engine = AsyncMock()
97
- mock_vector_engine.embedding_engine = AsyncMock()
98
- mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
99
- mock_vector_engine.search = AsyncMock(return_value=[])
100
-
101
- with patch(
102
- "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
103
- return_value=mock_vector_engine,
104
- ):
105
- await brute_force_triplet_search(query="test", node_name=None)
106
-
107
- for call in mock_vector_engine.search.call_args_list:
108
- assert call[1]["limit"] == 100
109
-
110
-
111
- @pytest.mark.asyncio
112
- async def test_brute_force_triplet_search_default_collections():
113
- """Test that default collections are used when none provided."""
114
- mock_vector_engine = AsyncMock()
115
- mock_vector_engine.embedding_engine = AsyncMock()
116
- mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
117
- mock_vector_engine.search = AsyncMock(return_value=[])
118
-
119
- with patch(
120
- "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
121
- return_value=mock_vector_engine,
122
- ):
123
- await brute_force_triplet_search(query="test")
124
-
125
- expected_collections = [
126
- "Entity_name",
127
- "TextSummary_text",
128
- "EntityType_name",
129
- "DocumentChunk_text",
130
- "EdgeType_relationship_name",
131
- ]
132
-
133
- call_collections = [
134
- call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
135
- ]
136
- assert call_collections == expected_collections
137
-
138
-
139
- @pytest.mark.asyncio
140
- async def test_brute_force_triplet_search_custom_collections():
141
- """Test that custom collections are used when provided."""
142
- mock_vector_engine = AsyncMock()
143
- mock_vector_engine.embedding_engine = AsyncMock()
144
- mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
145
- mock_vector_engine.search = AsyncMock(return_value=[])
146
-
147
- custom_collections = ["CustomCol1", "CustomCol2"]
148
-
149
- with patch(
150
- "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
151
- return_value=mock_vector_engine,
152
- ):
153
- await brute_force_triplet_search(query="test", collections=custom_collections)
154
-
155
- call_collections = [
156
- call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
157
- ]
158
- assert set(call_collections) == set(custom_collections) | {"EdgeType_relationship_name"}
159
-
160
-
161
- @pytest.mark.asyncio
162
- async def test_brute_force_triplet_search_always_includes_edge_collection():
163
- """Test that EdgeType_relationship_name is always searched even when not in collections."""
164
- mock_vector_engine = AsyncMock()
165
- mock_vector_engine.embedding_engine = AsyncMock()
166
- mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
167
- mock_vector_engine.search = AsyncMock(return_value=[])
168
-
169
- collections_without_edge = ["Entity_name", "TextSummary_text"]
170
-
171
- with patch(
172
- "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
173
- return_value=mock_vector_engine,
174
- ):
175
- await brute_force_triplet_search(query="test", collections=collections_without_edge)
176
-
177
- call_collections = [
178
- call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
179
- ]
180
- assert "EdgeType_relationship_name" in call_collections
181
- assert set(call_collections) == set(collections_without_edge) | {
182
- "EdgeType_relationship_name"
183
- }
184
-
185
-
186
- @pytest.mark.asyncio
187
- async def test_brute_force_triplet_search_all_collections_empty():
188
- """Test that empty list is returned when all collections return no results."""
189
- mock_vector_engine = AsyncMock()
190
- mock_vector_engine.embedding_engine = AsyncMock()
191
- mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
192
- mock_vector_engine.search = AsyncMock(return_value=[])
193
-
194
- with patch(
195
- "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
196
- return_value=mock_vector_engine,
197
- ):
198
- results = await brute_force_triplet_search(query="test")
199
- assert results == []
200
-
201
-
202
- # Tests for query embedding
203
-
204
-
205
- @pytest.mark.asyncio
206
- async def test_brute_force_triplet_search_embeds_query():
207
- """Test that query is embedded before searching."""
208
- query_text = "test query"
209
- expected_vector = [0.1, 0.2, 0.3]
210
-
211
- mock_vector_engine = AsyncMock()
212
- mock_vector_engine.embedding_engine = AsyncMock()
213
- mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[expected_vector])
214
- mock_vector_engine.search = AsyncMock(return_value=[])
215
-
216
- with patch(
217
- "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
218
- return_value=mock_vector_engine,
219
- ):
220
- await brute_force_triplet_search(query=query_text)
221
-
222
- mock_vector_engine.embedding_engine.embed_text.assert_called_once_with([query_text])
223
-
224
- for call in mock_vector_engine.search.call_args_list:
225
- assert call[1]["query_vector"] == expected_vector
226
-
227
-
228
- @pytest.mark.asyncio
229
- async def test_brute_force_triplet_search_extracts_node_ids_global_search():
230
- """Test that node IDs are extracted from search results for global search."""
231
- scored_results = [
232
- MockScoredResult("node1", 0.95),
233
- MockScoredResult("node2", 0.87),
234
- MockScoredResult("node3", 0.92),
235
- ]
236
-
237
- mock_vector_engine = AsyncMock()
238
- mock_vector_engine.embedding_engine = AsyncMock()
239
- mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
240
- mock_vector_engine.search = AsyncMock(return_value=scored_results)
241
-
242
- mock_fragment = AsyncMock(
243
- map_vector_distances_to_graph_nodes=AsyncMock(),
244
- map_vector_distances_to_graph_edges=AsyncMock(),
245
- calculate_top_triplet_importances=AsyncMock(return_value=[]),
246
- )
247
-
248
- with (
249
- patch(
250
- "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
251
- return_value=mock_vector_engine,
252
- ),
253
- patch(
254
- "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
255
- return_value=mock_fragment,
256
- ) as mock_get_fragment_fn,
257
- ):
258
- await brute_force_triplet_search(query="test", node_name=None)
259
-
260
- call_kwargs = mock_get_fragment_fn.call_args[1]
261
- assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"}
262
-
263
-
264
- @pytest.mark.asyncio
265
- async def test_brute_force_triplet_search_reuses_provided_fragment():
266
- """Test that provided memory fragment is reused instead of creating new one."""
267
- provided_fragment = AsyncMock(
268
- map_vector_distances_to_graph_nodes=AsyncMock(),
269
- map_vector_distances_to_graph_edges=AsyncMock(),
270
- calculate_top_triplet_importances=AsyncMock(return_value=[]),
271
- )
272
-
273
- mock_vector_engine = AsyncMock()
274
- mock_vector_engine.embedding_engine = AsyncMock()
275
- mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
276
- mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])
277
-
278
- with (
279
- patch(
280
- "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
281
- return_value=mock_vector_engine,
282
- ),
283
- patch(
284
- "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment"
285
- ) as mock_get_fragment,
286
- ):
287
- await brute_force_triplet_search(
288
- query="test",
289
- memory_fragment=provided_fragment,
290
- node_name=["node"],
291
- )
292
-
293
- mock_get_fragment.assert_not_called()
294
-
295
-
296
- @pytest.mark.asyncio
297
- async def test_brute_force_triplet_search_creates_fragment_when_not_provided():
298
- """Test that memory fragment is created when not provided."""
299
- mock_vector_engine = AsyncMock()
300
- mock_vector_engine.embedding_engine = AsyncMock()
301
- mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
302
- mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])
303
-
304
- mock_fragment = AsyncMock(
305
- map_vector_distances_to_graph_nodes=AsyncMock(),
306
- map_vector_distances_to_graph_edges=AsyncMock(),
307
- calculate_top_triplet_importances=AsyncMock(return_value=[]),
308
- )
309
-
310
- with (
311
- patch(
312
- "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
313
- return_value=mock_vector_engine,
314
- ),
315
- patch(
316
- "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
317
- return_value=mock_fragment,
318
- ) as mock_get_fragment,
319
- ):
320
- await brute_force_triplet_search(query="test", node_name=["node"])
321
-
322
- mock_get_fragment.assert_called_once()
323
-
324
-
325
- @pytest.mark.asyncio
326
- async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation():
327
- """Test that custom top_k is passed to importance calculation."""
328
- mock_vector_engine = AsyncMock()
329
- mock_vector_engine.embedding_engine = AsyncMock()
330
- mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
331
- mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])
332
-
333
- mock_fragment = AsyncMock(
334
- map_vector_distances_to_graph_nodes=AsyncMock(),
335
- map_vector_distances_to_graph_edges=AsyncMock(),
336
- calculate_top_triplet_importances=AsyncMock(return_value=[]),
337
- )
338
-
339
- with (
340
- patch(
341
- "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
342
- return_value=mock_vector_engine,
343
- ),
344
- patch(
345
- "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
346
- return_value=mock_fragment,
347
- ),
348
- ):
349
- custom_top_k = 15
350
- await brute_force_triplet_search(query="test", top_k=custom_top_k, node_name=["n"])
351
-
352
- mock_fragment.calculate_top_triplet_importances.assert_called_once_with(k=custom_top_k)
353
-
354
-
355
- @pytest.mark.asyncio
356
- async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found():
357
- """Test that get_memory_fragment returns empty graph when entity not found."""
358
- mock_graph_engine = AsyncMock()
359
- mock_graph_engine.project_graph_from_db = AsyncMock(
360
- side_effect=EntityNotFoundError("Entity not found")
361
- )
362
-
363
- with patch(
364
- "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
365
- return_value=mock_graph_engine,
366
- ):
367
- fragment = await get_memory_fragment()
368
-
369
- assert isinstance(fragment, CogneeGraph)
370
- assert len(fragment.nodes) == 0
371
-
372
-
373
- @pytest.mark.asyncio
374
- async def test_get_memory_fragment_returns_empty_graph_on_error():
375
- """Test that get_memory_fragment returns empty graph on generic error."""
376
- mock_graph_engine = AsyncMock()
377
- mock_graph_engine.project_graph_from_db = AsyncMock(side_effect=Exception("Generic error"))
378
-
379
- with patch(
380
- "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
381
- return_value=mock_graph_engine,
382
- ):
383
- fragment = await get_memory_fragment()
384
-
385
- assert isinstance(fragment, CogneeGraph)
386
- assert len(fragment.nodes) == 0
387
-
388
-
389
- @pytest.mark.asyncio
390
- async def test_brute_force_triplet_search_deduplicates_node_ids():
391
- """Test that duplicate node IDs across collections are deduplicated."""
392
-
393
- def search_side_effect(*args, **kwargs):
394
- collection_name = kwargs.get("collection_name")
395
- if collection_name == "Entity_name":
396
- return [
397
- MockScoredResult("node1", 0.95),
398
- MockScoredResult("node2", 0.87),
399
- ]
400
- elif collection_name == "TextSummary_text":
401
- return [
402
- MockScoredResult("node1", 0.90),
403
- MockScoredResult("node3", 0.92),
404
- ]
405
- else:
406
- return []
407
-
408
- mock_vector_engine = AsyncMock()
409
- mock_vector_engine.embedding_engine = AsyncMock()
410
- mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
411
- mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
412
-
413
- mock_fragment = AsyncMock(
414
- map_vector_distances_to_graph_nodes=AsyncMock(),
415
- map_vector_distances_to_graph_edges=AsyncMock(),
416
- calculate_top_triplet_importances=AsyncMock(return_value=[]),
417
- )
418
-
419
- with (
420
- patch(
421
- "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
422
- return_value=mock_vector_engine,
423
- ),
424
- patch(
425
- "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
426
- return_value=mock_fragment,
427
- ) as mock_get_fragment_fn,
428
- ):
429
- await brute_force_triplet_search(query="test", node_name=None)
430
-
431
- call_kwargs = mock_get_fragment_fn.call_args[1]
432
- assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"}
433
- assert len(call_kwargs["relevant_ids_to_filter"]) == 3
434
-
435
-
436
- @pytest.mark.asyncio
437
- async def test_brute_force_triplet_search_excludes_edge_collection():
438
- """Test that EdgeType_relationship_name collection is excluded from ID extraction."""
439
-
440
- def search_side_effect(*args, **kwargs):
441
- collection_name = kwargs.get("collection_name")
442
- if collection_name == "Entity_name":
443
- return [MockScoredResult("node1", 0.95)]
444
- elif collection_name == "EdgeType_relationship_name":
445
- return [MockScoredResult("edge1", 0.88)]
446
- else:
447
- return []
448
-
449
- mock_vector_engine = AsyncMock()
450
- mock_vector_engine.embedding_engine = AsyncMock()
451
- mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
452
- mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
453
-
454
- mock_fragment = AsyncMock(
455
- map_vector_distances_to_graph_nodes=AsyncMock(),
456
- map_vector_distances_to_graph_edges=AsyncMock(),
457
- calculate_top_triplet_importances=AsyncMock(return_value=[]),
458
- )
459
-
460
- with (
461
- patch(
462
- "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
463
- return_value=mock_vector_engine,
464
- ),
465
- patch(
466
- "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
467
- return_value=mock_fragment,
468
- ) as mock_get_fragment_fn,
469
- ):
470
- await brute_force_triplet_search(
471
- query="test",
472
- node_name=None,
473
- collections=["Entity_name", "EdgeType_relationship_name"],
474
- )
475
-
476
- call_kwargs = mock_get_fragment_fn.call_args[1]
477
- assert call_kwargs["relevant_ids_to_filter"] == ["node1"]
478
-
479
-
480
- @pytest.mark.asyncio
481
- async def test_brute_force_triplet_search_skips_nodes_without_ids():
482
- """Test that nodes without ID attribute are skipped."""
483
-
484
- class ScoredResultNoId:
485
- """Mock result without id attribute."""
486
-
487
- def __init__(self, score):
488
- self.score = score
489
-
490
- def search_side_effect(*args, **kwargs):
491
- collection_name = kwargs.get("collection_name")
492
- if collection_name == "Entity_name":
493
- return [
494
- MockScoredResult("node1", 0.95),
495
- ScoredResultNoId(0.90),
496
- MockScoredResult("node2", 0.87),
497
- ]
498
- else:
499
- return []
500
-
501
- mock_vector_engine = AsyncMock()
502
- mock_vector_engine.embedding_engine = AsyncMock()
503
- mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
504
- mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
505
-
506
- mock_fragment = AsyncMock(
507
- map_vector_distances_to_graph_nodes=AsyncMock(),
508
- map_vector_distances_to_graph_edges=AsyncMock(),
509
- calculate_top_triplet_importances=AsyncMock(return_value=[]),
510
- )
511
-
512
- with (
513
- patch(
514
- "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
515
- return_value=mock_vector_engine,
516
- ),
517
- patch(
518
- "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
519
- return_value=mock_fragment,
520
- ) as mock_get_fragment_fn,
521
- ):
522
- await brute_force_triplet_search(query="test", node_name=None)
523
-
524
- call_kwargs = mock_get_fragment_fn.call_args[1]
525
- assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
526
-
527
-
528
- @pytest.mark.asyncio
529
- async def test_brute_force_triplet_search_handles_tuple_results():
530
- """Test that both list and tuple results are handled correctly."""
531
-
532
- def search_side_effect(*args, **kwargs):
533
- collection_name = kwargs.get("collection_name")
534
- if collection_name == "Entity_name":
535
- return (
536
- MockScoredResult("node1", 0.95),
537
- MockScoredResult("node2", 0.87),
538
- )
539
- else:
540
- return []
541
-
542
- mock_vector_engine = AsyncMock()
543
- mock_vector_engine.embedding_engine = AsyncMock()
544
- mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
545
- mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
546
-
547
- mock_fragment = AsyncMock(
548
- map_vector_distances_to_graph_nodes=AsyncMock(),
549
- map_vector_distances_to_graph_edges=AsyncMock(),
550
- calculate_top_triplet_importances=AsyncMock(return_value=[]),
551
- )
552
-
553
- with (
554
- patch(
555
- "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
556
- return_value=mock_vector_engine,
557
- ),
558
- patch(
559
- "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
560
- return_value=mock_fragment,
561
- ) as mock_get_fragment_fn,
562
- ):
563
- await brute_force_triplet_search(query="test", node_name=None)
564
-
565
- call_kwargs = mock_get_fragment_fn.call_args[1]
566
- assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
567
-
568
-
569
- @pytest.mark.asyncio
570
- async def test_brute_force_triplet_search_mixed_empty_collections():
571
- """Test ID extraction with mixed empty and non-empty collections."""
572
-
573
- def search_side_effect(*args, **kwargs):
574
- collection_name = kwargs.get("collection_name")
575
- if collection_name == "Entity_name":
576
- return [MockScoredResult("node1", 0.95)]
577
- elif collection_name == "TextSummary_text":
578
- return []
579
- elif collection_name == "EntityType_name":
580
- return [MockScoredResult("node2", 0.92)]
581
- else:
582
- return []
583
-
584
- mock_vector_engine = AsyncMock()
585
- mock_vector_engine.embedding_engine = AsyncMock()
586
- mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
587
- mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
588
-
589
- mock_fragment = AsyncMock(
590
- map_vector_distances_to_graph_nodes=AsyncMock(),
591
- map_vector_distances_to_graph_edges=AsyncMock(),
592
- calculate_top_triplet_importances=AsyncMock(return_value=[]),
593
- )
594
-
595
- with (
596
- patch(
597
- "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
598
- return_value=mock_vector_engine,
599
- ),
600
- patch(
601
- "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
602
- return_value=mock_fragment,
603
- ) as mock_get_fragment_fn,
604
- ):
605
- await brute_force_triplet_search(query="test", node_name=None)
606
-
607
- call_kwargs = mock_get_fragment_fn.call_args[1]
608
- assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
@@ -1,83 +0,0 @@
1
- import pytest
2
- from unittest.mock import AsyncMock, patch, MagicMock
3
-
4
- from cognee.modules.retrieval.triplet_retriever import TripletRetriever
5
- from cognee.modules.retrieval.exceptions.exceptions import NoDataError
6
- from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
7
-
8
-
9
- @pytest.fixture
10
- def mock_vector_engine():
11
- """Create a mock vector engine."""
12
- engine = AsyncMock()
13
- engine.has_collection = AsyncMock(return_value=True)
14
- engine.search = AsyncMock()
15
- return engine
16
-
17
-
18
- @pytest.mark.asyncio
19
- async def test_get_context_success(mock_vector_engine):
20
- """Test successful retrieval of triplet context."""
21
- mock_result1 = MagicMock()
22
- mock_result1.payload = {"text": "Alice knows Bob"}
23
- mock_result2 = MagicMock()
24
- mock_result2.payload = {"text": "Bob works at Tech Corp"}
25
-
26
- mock_vector_engine.search.return_value = [mock_result1, mock_result2]
27
-
28
- retriever = TripletRetriever(top_k=5)
29
-
30
- with patch(
31
- "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
32
- return_value=mock_vector_engine,
33
- ):
34
- context = await retriever.get_context("test query")
35
-
36
- assert context == "Alice knows Bob\nBob works at Tech Corp"
37
- mock_vector_engine.search.assert_awaited_once_with("Triplet_text", "test query", limit=5)
38
-
39
-
40
- @pytest.mark.asyncio
41
- async def test_get_context_no_collection(mock_vector_engine):
42
- """Test that NoDataError is raised when Triplet_text collection doesn't exist."""
43
- mock_vector_engine.has_collection.return_value = False
44
-
45
- retriever = TripletRetriever()
46
-
47
- with patch(
48
- "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
49
- return_value=mock_vector_engine,
50
- ):
51
- with pytest.raises(NoDataError, match="create_triplet_embeddings"):
52
- await retriever.get_context("test query")
53
-
54
-
55
- @pytest.mark.asyncio
56
- async def test_get_context_empty_results(mock_vector_engine):
57
- """Test that empty string is returned when no triplets are found."""
58
- mock_vector_engine.search.return_value = []
59
-
60
- retriever = TripletRetriever()
61
-
62
- with patch(
63
- "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
64
- return_value=mock_vector_engine,
65
- ):
66
- context = await retriever.get_context("test query")
67
-
68
- assert context == ""
69
-
70
-
71
- @pytest.mark.asyncio
72
- async def test_get_context_collection_not_found_error(mock_vector_engine):
73
- """Test that CollectionNotFoundError is converted to NoDataError."""
74
- mock_vector_engine.has_collection.side_effect = CollectionNotFoundError("Collection not found")
75
-
76
- retriever = TripletRetriever()
77
-
78
- with patch(
79
- "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
80
- return_value=mock_vector_engine,
81
- ):
82
- with pytest.raises(NoDataError, match="No data found"):
83
- await retriever.get_context("test query")