cognee 0.5.1__py3-none-any.whl → 0.5.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. cognee/api/v1/add/add.py +2 -1
  2. cognee/api/v1/datasets/routers/get_datasets_router.py +1 -0
  3. cognee/api/v1/memify/routers/get_memify_router.py +1 -0
  4. cognee/api/v1/search/search.py +0 -4
  5. cognee/infrastructure/databases/relational/config.py +16 -1
  6. cognee/infrastructure/databases/relational/create_relational_engine.py +13 -3
  7. cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +24 -2
  8. cognee/infrastructure/databases/vector/create_vector_engine.py +9 -2
  9. cognee/infrastructure/llm/LLMGateway.py +0 -13
  10. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -12
  11. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +31 -25
  12. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +132 -7
  13. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +5 -5
  14. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +2 -6
  15. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +58 -13
  16. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +0 -1
  17. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -131
  18. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py +10 -0
  19. cognee/modules/data/models/Data.py +2 -1
  20. cognee/modules/retrieval/triplet_retriever.py +1 -1
  21. cognee/modules/retrieval/utils/brute_force_triplet_search.py +0 -18
  22. cognee/modules/search/methods/search.py +18 -25
  23. cognee/tasks/ingestion/data_item.py +8 -0
  24. cognee/tasks/ingestion/ingest_data.py +12 -1
  25. cognee/tasks/ingestion/save_data_item_to_storage.py +5 -0
  26. cognee/tests/integration/retrieval/test_chunks_retriever.py +252 -0
  27. cognee/tests/integration/retrieval/test_graph_completion_retriever.py +268 -0
  28. cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +226 -0
  29. cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +218 -0
  30. cognee/tests/integration/retrieval/test_rag_completion_retriever.py +254 -0
  31. cognee/tests/{unit/modules/retrieval/structured_output_test.py → integration/retrieval/test_structured_output.py} +87 -77
  32. cognee/tests/integration/retrieval/test_summaries_retriever.py +184 -0
  33. cognee/tests/integration/retrieval/test_temporal_retriever.py +306 -0
  34. cognee/tests/integration/retrieval/test_triplet_retriever.py +35 -0
  35. cognee/tests/test_custom_data_label.py +68 -0
  36. cognee/tests/test_search_db.py +334 -181
  37. cognee/tests/unit/eval_framework/benchmark_adapters_test.py +25 -0
  38. cognee/tests/unit/eval_framework/corpus_builder_test.py +33 -4
  39. cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py +69 -0
  40. cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +181 -199
  41. cognee/tests/unit/modules/retrieval/conversation_history_test.py +338 -0
  42. cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +454 -162
  43. cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +674 -156
  44. cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +625 -200
  45. cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +319 -203
  46. cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +189 -155
  47. cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +539 -58
  48. cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +218 -9
  49. cognee/tests/unit/modules/retrieval/test_completion.py +343 -0
  50. cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py +157 -0
  51. cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py +312 -0
  52. cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +246 -0
  53. {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/METADATA +1 -1
  54. {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/RECORD +58 -45
  55. cognee/tests/unit/modules/search/test_search.py +0 -100
  56. {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/WHEEL +0 -0
  57. {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/entry_points.txt +0 -0
  58. {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/LICENSE +0 -0
  59. {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/NOTICE.md +0 -0
@@ -1,223 +1,648 @@
1
- import os
2
1
  import pytest
3
- import pathlib
4
- from typing import Optional, Union
2
+ from unittest.mock import AsyncMock, patch, MagicMock
3
+ from uuid import UUID
5
4
 
6
- import cognee
7
- from cognee.low_level import setup, DataPoint
8
- from cognee.modules.graph.utils import resolve_edges_to_text
9
- from cognee.tasks.storage import add_data_points
10
5
  from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
6
+ from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
11
7
 
12
8
 
13
- class TestGraphCompletionRetriever:
14
- @pytest.mark.asyncio
15
- async def test_graph_completion_context_simple(self):
16
- system_directory_path = os.path.join(
17
- pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context_simple"
18
- )
19
- cognee.config.system_root_directory(system_directory_path)
20
- data_directory_path = os.path.join(
21
- pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context_simple"
22
- )
23
- cognee.config.data_root_directory(data_directory_path)
24
-
25
- await cognee.prune.prune_data()
26
- await cognee.prune.prune_system(metadata=True)
27
- await setup()
28
-
29
- class Company(DataPoint):
30
- name: str
31
- description: str
32
-
33
- class Person(DataPoint):
34
- name: str
35
- description: str
36
- works_for: Company
37
-
38
- company1 = Company(name="Figma", description="Figma is a company")
39
- company2 = Company(name="Canva", description="Canvas is a company")
40
- person1 = Person(
41
- name="Steve Rodger",
42
- description="This is description about Steve Rodger",
43
- works_for=company1,
44
- )
45
- person2 = Person(
46
- name="Ike Loma", description="This is description about Ike Loma", works_for=company1
47
- )
48
- person3 = Person(
49
- name="Jason Statham",
50
- description="This is description about Jason Statham",
51
- works_for=company1,
52
- )
53
- person4 = Person(
54
- name="Mike Broski",
55
- description="This is description about Mike Broski",
56
- works_for=company2,
57
- )
58
- person5 = Person(
59
- name="Christina Mayer",
60
- description="This is description about Christina Mayer",
61
- works_for=company2,
62
- )
9
+ @pytest.fixture
10
+ def mock_edge():
11
+ """Create a mock edge."""
12
+ edge = MagicMock(spec=Edge)
13
+ return edge
63
14
 
64
- entities = [company1, company2, person1, person2, person3, person4, person5]
65
15
 
66
- await add_data_points(entities)
16
+ @pytest.mark.asyncio
17
+ async def test_get_triplets_success(mock_edge):
18
+ """Test successful retrieval of triplets."""
19
+ retriever = GraphCompletionRetriever(top_k=5)
67
20
 
68
- retriever = GraphCompletionRetriever()
21
+ with patch(
22
+ "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
23
+ return_value=[mock_edge],
24
+ ) as mock_search:
25
+ triplets = await retriever.get_triplets("test query")
69
26
 
70
- context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?"))
27
+ assert len(triplets) == 1
28
+ assert triplets[0] == mock_edge
29
+ mock_search.assert_awaited_once()
71
30
 
72
- # Ensure the top-level sections are present
73
- assert "Nodes:" in context, "Missing 'Nodes:' section in context"
74
- assert "Connections:" in context, "Missing 'Connections:' section in context"
75
31
 
76
- # --- Nodes headers ---
77
- assert "Node: Steve Rodger" in context, "Missing node header for Steve Rodger"
78
- assert "Node: Figma" in context, "Missing node header for Figma"
79
- assert "Node: Ike Loma" in context, "Missing node header for Ike Loma"
80
- assert "Node: Jason Statham" in context, "Missing node header for Jason Statham"
81
- assert "Node: Mike Broski" in context, "Missing node header for Mike Broski"
82
- assert "Node: Canva" in context, "Missing node header for Canva"
83
- assert "Node: Christina Mayer" in context, "Missing node header for Christina Mayer"
32
+ @pytest.mark.asyncio
33
+ async def test_get_triplets_empty_results():
34
+ """Test that empty list is returned when no triplets are found."""
35
+ retriever = GraphCompletionRetriever()
84
36
 
85
- # --- Node contents ---
86
- assert (
87
- "__node_content_start__\nThis is description about Steve Rodger\n__node_content_end__"
88
- in context
89
- ), "Description block for Steve Rodger altered"
90
- assert "__node_content_start__\nFigma is a company\n__node_content_end__" in context, (
91
- "Description block for Figma altered"
92
- )
93
- assert (
94
- "__node_content_start__\nThis is description about Ike Loma\n__node_content_end__"
95
- in context
96
- ), "Description block for Ike Loma altered"
97
- assert (
98
- "__node_content_start__\nThis is description about Jason Statham\n__node_content_end__"
99
- in context
100
- ), "Description block for Jason Statham altered"
101
- assert (
102
- "__node_content_start__\nThis is description about Mike Broski\n__node_content_end__"
103
- in context
104
- ), "Description block for Mike Broski altered"
105
- assert "__node_content_start__\nCanvas is a company\n__node_content_end__" in context, (
106
- "Description block for Canva altered"
107
- )
108
- assert (
109
- "__node_content_start__\nThis is description about Christina Mayer\n__node_content_end__"
110
- in context
111
- ), "Description block for Christina Mayer altered"
112
-
113
- # --- Connections ---
114
- assert "Steve Rodger --[works_for]--> Figma" in context, (
115
- "Connection Steve Rodger→Figma missing or changed"
116
- )
117
- assert "Ike Loma --[works_for]--> Figma" in context, (
118
- "Connection Ike Loma→Figma missing or changed"
119
- )
120
- assert "Jason Statham --[works_for]--> Figma" in context, (
121
- "Connection Jason Statham→Figma missing or changed"
122
- )
123
- assert "Mike Broski --[works_for]--> Canva" in context, (
124
- "Connection Mike Broski→Canva missing or changed"
125
- )
126
- assert "Christina Mayer --[works_for]--> Canva" in context, (
127
- "Connection Christina Mayer→Canva missing or changed"
128
- )
129
-
130
- @pytest.mark.asyncio
131
- async def test_graph_completion_context_complex(self):
132
- system_directory_path = os.path.join(
133
- pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context_complex"
134
- )
135
- cognee.config.system_root_directory(system_directory_path)
136
- data_directory_path = os.path.join(
137
- pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context_complex"
138
- )
139
- cognee.config.data_root_directory(data_directory_path)
37
+ with patch(
38
+ "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
39
+ return_value=[],
40
+ ):
41
+ triplets = await retriever.get_triplets("test query")
140
42
 
141
- await cognee.prune.prune_data()
142
- await cognee.prune.prune_system(metadata=True)
143
- await setup()
43
+ assert triplets == []
144
44
 
145
- class Company(DataPoint):
146
- name: str
147
- metadata: dict = {"index_fields": ["name"]}
148
45
 
149
- class Car(DataPoint):
150
- brand: str
151
- model: str
152
- year: int
46
+ @pytest.mark.asyncio
47
+ async def test_get_triplets_top_k_parameter():
48
+ """Test that top_k parameter is passed to brute_force_triplet_search."""
49
+ retriever = GraphCompletionRetriever(top_k=10)
153
50
 
154
- class Location(DataPoint):
155
- country: str
156
- city: str
51
+ with patch(
52
+ "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
53
+ return_value=[],
54
+ ) as mock_search:
55
+ await retriever.get_triplets("test query")
157
56
 
158
- class Home(DataPoint):
159
- location: Location
160
- rooms: int
161
- sqm: int
57
+ call_kwargs = mock_search.call_args[1]
58
+ assert call_kwargs["top_k"] == 10
162
59
 
163
- class Person(DataPoint):
164
- name: str
165
- works_for: Company
166
- owns: Optional[list[Union[Car, Home]]] = None
167
60
 
168
- company1 = Company(name="Figma")
169
- company2 = Company(name="Canva")
170
-
171
- person1 = Person(name="Mike Rodger", works_for=company1)
172
- person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]
173
-
174
- person2 = Person(name="Ike Loma", works_for=company1)
175
- person2.owns = [
176
- Car(brand="Tesla", model="Model S", year=2021),
177
- Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
178
- ]
179
-
180
- person3 = Person(name="Jason Statham", works_for=company1)
181
-
182
- person4 = Person(name="Mike Broski", works_for=company2)
183
- person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]
184
-
185
- person5 = Person(name="Christina Mayer", works_for=company2)
186
- person5.owns = [Car(brand="Honda", model="Civic", year=2023)]
187
-
188
- entities = [company1, company2, person1, person2, person3, person4, person5]
189
-
190
- await add_data_points(entities)
191
-
192
- retriever = GraphCompletionRetriever(top_k=20)
193
-
194
- context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?"))
195
-
196
- print(context)
197
-
198
- assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
199
- assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
200
- assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"
201
-
202
- @pytest.mark.asyncio
203
- async def test_get_graph_completion_context_on_empty_graph(self):
204
- system_directory_path = os.path.join(
205
- pathlib.Path(__file__).parent,
206
- ".cognee_system/test_get_graph_completion_context_on_empty_graph",
61
+ @pytest.mark.asyncio
62
+ async def test_get_context_success(mock_edge):
63
+ """Test successful retrieval of context."""
64
+ retriever = GraphCompletionRetriever()
65
+
66
+ mock_graph_engine = AsyncMock()
67
+ mock_graph_engine.is_empty = AsyncMock(return_value=False)
68
+
69
+ with (
70
+ patch(
71
+ "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
72
+ return_value=mock_graph_engine,
73
+ ),
74
+ patch(
75
+ "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
76
+ return_value=[mock_edge],
77
+ ),
78
+ ):
79
+ context = await retriever.get_context("test query")
80
+
81
+ assert isinstance(context, list)
82
+ assert len(context) == 1
83
+ assert context[0] == mock_edge
84
+
85
+
86
+ @pytest.mark.asyncio
87
+ async def test_get_context_empty_results():
88
+ """Test that empty list is returned when no context is found."""
89
+ retriever = GraphCompletionRetriever()
90
+
91
+ mock_graph_engine = AsyncMock()
92
+ mock_graph_engine.is_empty = AsyncMock(return_value=False)
93
+
94
+ with (
95
+ patch(
96
+ "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
97
+ return_value=mock_graph_engine,
98
+ ),
99
+ patch(
100
+ "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
101
+ return_value=[],
102
+ ),
103
+ ):
104
+ context = await retriever.get_context("test query")
105
+
106
+ assert context == []
107
+
108
+
109
+ @pytest.mark.asyncio
110
+ async def test_get_context_empty_graph():
111
+ """Test that empty list is returned when graph is empty."""
112
+ retriever = GraphCompletionRetriever()
113
+
114
+ mock_graph_engine = AsyncMock()
115
+ mock_graph_engine.is_empty = AsyncMock(return_value=True)
116
+
117
+ with patch(
118
+ "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
119
+ return_value=mock_graph_engine,
120
+ ):
121
+ context = await retriever.get_context("test query")
122
+
123
+ assert context == []
124
+
125
+
126
+ @pytest.mark.asyncio
127
+ async def test_resolve_edges_to_text(mock_edge):
128
+ """Test resolve_edges_to_text method."""
129
+ retriever = GraphCompletionRetriever()
130
+
131
+ with patch(
132
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
133
+ return_value="Resolved text",
134
+ ) as mock_resolve:
135
+ result = await retriever.resolve_edges_to_text([mock_edge])
136
+
137
+ assert result == "Resolved text"
138
+ mock_resolve.assert_awaited_once_with([mock_edge])
139
+
140
+
141
+ @pytest.mark.asyncio
142
+ async def test_init_defaults():
143
+ """Test GraphCompletionRetriever initialization with defaults."""
144
+ retriever = GraphCompletionRetriever()
145
+
146
+ assert retriever.top_k == 5
147
+ assert retriever.user_prompt_path == "graph_context_for_question.txt"
148
+ assert retriever.system_prompt_path == "answer_simple_question.txt"
149
+ assert retriever.node_type is None
150
+ assert retriever.node_name is None
151
+
152
+
153
+ @pytest.mark.asyncio
154
+ async def test_init_custom_params():
155
+ """Test GraphCompletionRetriever initialization with custom parameters."""
156
+ retriever = GraphCompletionRetriever(
157
+ top_k=10,
158
+ user_prompt_path="custom_user.txt",
159
+ system_prompt_path="custom_system.txt",
160
+ system_prompt="Custom prompt",
161
+ node_type=str,
162
+ node_name=["node1"],
163
+ save_interaction=True,
164
+ wide_search_top_k=200,
165
+ triplet_distance_penalty=5.0,
166
+ )
167
+
168
+ assert retriever.top_k == 10
169
+ assert retriever.user_prompt_path == "custom_user.txt"
170
+ assert retriever.system_prompt_path == "custom_system.txt"
171
+ assert retriever.system_prompt == "Custom prompt"
172
+ assert retriever.node_type is str
173
+ assert retriever.node_name == ["node1"]
174
+ assert retriever.save_interaction is True
175
+ assert retriever.wide_search_top_k == 200
176
+ assert retriever.triplet_distance_penalty == 5.0
177
+
178
+
179
+ @pytest.mark.asyncio
180
+ async def test_init_none_top_k():
181
+ """Test GraphCompletionRetriever initialization with None top_k."""
182
+ retriever = GraphCompletionRetriever(top_k=None)
183
+
184
+ assert retriever.top_k == 5 # None defaults to 5
185
+
186
+
187
+ @pytest.mark.asyncio
188
+ async def test_convert_retrieved_objects_to_context(mock_edge):
189
+ """Test convert_retrieved_objects_to_context method."""
190
+ retriever = GraphCompletionRetriever()
191
+
192
+ with patch(
193
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
194
+ return_value="Resolved text",
195
+ ) as mock_resolve:
196
+ result = await retriever.convert_retrieved_objects_to_context([mock_edge])
197
+
198
+ assert result == "Resolved text"
199
+ mock_resolve.assert_awaited_once_with([mock_edge])
200
+
201
+
202
+ @pytest.mark.asyncio
203
+ async def test_get_completion_without_context(mock_edge):
204
+ """Test get_completion retrieves context when not provided."""
205
+ mock_graph_engine = AsyncMock()
206
+ mock_graph_engine.is_empty = AsyncMock(return_value=False)
207
+
208
+ retriever = GraphCompletionRetriever()
209
+
210
+ with (
211
+ patch(
212
+ "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
213
+ return_value=mock_graph_engine,
214
+ ),
215
+ patch(
216
+ "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
217
+ return_value=[mock_edge],
218
+ ),
219
+ patch(
220
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
221
+ return_value="Resolved context",
222
+ ),
223
+ patch(
224
+ "cognee.modules.retrieval.graph_completion_retriever.generate_completion",
225
+ return_value="Generated answer",
226
+ ),
227
+ patch(
228
+ "cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
229
+ ) as mock_cache_config,
230
+ ):
231
+ mock_config = MagicMock()
232
+ mock_config.caching = False
233
+ mock_cache_config.return_value = mock_config
234
+
235
+ completion = await retriever.get_completion("test query")
236
+
237
+ assert isinstance(completion, list)
238
+ assert len(completion) == 1
239
+ assert completion[0] == "Generated answer"
240
+
241
+
242
+ @pytest.mark.asyncio
243
+ async def test_get_completion_with_provided_context(mock_edge):
244
+ """Test get_completion uses provided context."""
245
+ retriever = GraphCompletionRetriever()
246
+
247
+ with (
248
+ patch(
249
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
250
+ return_value="Resolved context",
251
+ ),
252
+ patch(
253
+ "cognee.modules.retrieval.graph_completion_retriever.generate_completion",
254
+ return_value="Generated answer",
255
+ ),
256
+ patch(
257
+ "cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
258
+ ) as mock_cache_config,
259
+ ):
260
+ mock_config = MagicMock()
261
+ mock_config.caching = False
262
+ mock_cache_config.return_value = mock_config
263
+
264
+ completion = await retriever.get_completion("test query", context=[mock_edge])
265
+
266
+ assert isinstance(completion, list)
267
+ assert len(completion) == 1
268
+ assert completion[0] == "Generated answer"
269
+
270
+
271
+ @pytest.mark.asyncio
272
+ async def test_get_completion_with_session(mock_edge):
273
+ """Test get_completion with session caching enabled."""
274
+ mock_graph_engine = AsyncMock()
275
+ mock_graph_engine.is_empty = AsyncMock(return_value=False)
276
+
277
+ retriever = GraphCompletionRetriever()
278
+
279
+ mock_user = MagicMock()
280
+ mock_user.id = "test-user-id"
281
+
282
+ with (
283
+ patch(
284
+ "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
285
+ return_value=mock_graph_engine,
286
+ ),
287
+ patch(
288
+ "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
289
+ return_value=[mock_edge],
290
+ ),
291
+ patch(
292
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
293
+ return_value="Resolved context",
294
+ ),
295
+ patch(
296
+ "cognee.modules.retrieval.graph_completion_retriever.get_conversation_history",
297
+ return_value="Previous conversation",
298
+ ),
299
+ patch(
300
+ "cognee.modules.retrieval.graph_completion_retriever.summarize_text",
301
+ return_value="Context summary",
302
+ ),
303
+ patch(
304
+ "cognee.modules.retrieval.graph_completion_retriever.generate_completion",
305
+ return_value="Generated answer",
306
+ ),
307
+ patch(
308
+ "cognee.modules.retrieval.graph_completion_retriever.save_conversation_history",
309
+ ) as mock_save,
310
+ patch(
311
+ "cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
312
+ ) as mock_cache_config,
313
+ patch(
314
+ "cognee.modules.retrieval.graph_completion_retriever.session_user"
315
+ ) as mock_session_user,
316
+ ):
317
+ mock_config = MagicMock()
318
+ mock_config.caching = True
319
+ mock_cache_config.return_value = mock_config
320
+ mock_session_user.get.return_value = mock_user
321
+
322
+ completion = await retriever.get_completion("test query", session_id="test_session")
323
+
324
+ assert isinstance(completion, list)
325
+ assert len(completion) == 1
326
+ assert completion[0] == "Generated answer"
327
+ mock_save.assert_awaited_once()
328
+
329
+
330
+ @pytest.mark.asyncio
331
+ async def test_get_completion_with_response_model(mock_edge):
332
+ """Test get_completion with custom response model."""
333
+ from pydantic import BaseModel
334
+
335
+ class TestModel(BaseModel):
336
+ answer: str
337
+
338
+ mock_graph_engine = AsyncMock()
339
+ mock_graph_engine.is_empty = AsyncMock(return_value=False)
340
+
341
+ retriever = GraphCompletionRetriever()
342
+
343
+ with (
344
+ patch(
345
+ "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
346
+ return_value=mock_graph_engine,
347
+ ),
348
+ patch(
349
+ "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
350
+ return_value=[mock_edge],
351
+ ),
352
+ patch(
353
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
354
+ return_value="Resolved context",
355
+ ),
356
+ patch(
357
+ "cognee.modules.retrieval.graph_completion_retriever.generate_completion",
358
+ return_value=TestModel(answer="Test answer"),
359
+ ),
360
+ patch(
361
+ "cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
362
+ ) as mock_cache_config,
363
+ ):
364
+ mock_config = MagicMock()
365
+ mock_config.caching = False
366
+ mock_cache_config.return_value = mock_config
367
+
368
+ completion = await retriever.get_completion("test query", response_model=TestModel)
369
+
370
+ assert isinstance(completion, list)
371
+ assert len(completion) == 1
372
+ assert isinstance(completion[0], TestModel)
373
+
374
+
375
+ @pytest.mark.asyncio
376
+ async def test_get_completion_empty_context(mock_edge):
377
+ """Test get_completion with empty context."""
378
+ mock_graph_engine = AsyncMock()
379
+ mock_graph_engine.is_empty = AsyncMock(return_value=False)
380
+
381
+ retriever = GraphCompletionRetriever()
382
+
383
+ with (
384
+ patch(
385
+ "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
386
+ return_value=mock_graph_engine,
387
+ ),
388
+ patch(
389
+ "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
390
+ return_value=[],
391
+ ),
392
+ patch(
393
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
394
+ return_value="",
395
+ ),
396
+ patch(
397
+ "cognee.modules.retrieval.graph_completion_retriever.generate_completion",
398
+ return_value="Generated answer",
399
+ ),
400
+ patch(
401
+ "cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
402
+ ) as mock_cache_config,
403
+ ):
404
+ mock_config = MagicMock()
405
+ mock_config.caching = False
406
+ mock_cache_config.return_value = mock_config
407
+
408
+ completion = await retriever.get_completion("test query")
409
+
410
+ assert isinstance(completion, list)
411
+ assert len(completion) == 1
412
+
413
+
414
+ @pytest.mark.asyncio
415
+ async def test_save_qa(mock_edge):
416
+ """Test save_qa method."""
417
+ mock_graph_engine = AsyncMock()
418
+ mock_graph_engine.add_edges = AsyncMock()
419
+
420
+ retriever = GraphCompletionRetriever()
421
+
422
+ mock_node1 = MagicMock()
423
+ mock_node2 = MagicMock()
424
+ mock_edge.node1 = mock_node1
425
+ mock_edge.node2 = mock_node2
426
+
427
+ with (
428
+ patch(
429
+ "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
430
+ return_value=mock_graph_engine,
431
+ ),
432
+ patch(
433
+ "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node",
434
+ side_effect=["uuid1", "uuid2"],
435
+ ),
436
+ patch(
437
+ "cognee.modules.retrieval.graph_completion_retriever.add_data_points",
438
+ ) as mock_add_data,
439
+ ):
440
+ await retriever.save_qa(
441
+ question="Test question",
442
+ answer="Test answer",
443
+ context="Test context",
444
+ triplets=[mock_edge],
207
445
  )
208
- cognee.config.system_root_directory(system_directory_path)
209
- data_directory_path = os.path.join(
210
- pathlib.Path(__file__).parent,
211
- ".data_storage/test_get_graph_completion_context_on_empty_graph",
212
- )
213
- cognee.config.data_root_directory(data_directory_path)
214
-
215
- await cognee.prune.prune_data()
216
- await cognee.prune.prune_system(metadata=True)
217
446
 
218
- retriever = GraphCompletionRetriever()
447
+ mock_add_data.assert_awaited_once()
448
+ mock_graph_engine.add_edges.assert_awaited_once()
449
+
450
+
451
+ @pytest.mark.asyncio
452
+ async def test_save_qa_no_triplet_ids(mock_edge):
453
+ """Test save_qa when triplets have no extractable IDs."""
454
+ mock_graph_engine = AsyncMock()
455
+ mock_graph_engine.add_edges = AsyncMock()
456
+
457
+ retriever = GraphCompletionRetriever()
458
+
459
+ mock_node1 = MagicMock()
460
+ mock_node2 = MagicMock()
461
+ mock_edge.node1 = mock_node1
462
+ mock_edge.node2 = mock_node2
463
+
464
+ with (
465
+ patch(
466
+ "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
467
+ return_value=mock_graph_engine,
468
+ ),
469
+ patch(
470
+ "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node",
471
+ return_value=None,
472
+ ),
473
+ patch(
474
+ "cognee.modules.retrieval.graph_completion_retriever.add_data_points",
475
+ ) as mock_add_data,
476
+ ):
477
+ await retriever.save_qa(
478
+ question="Test question",
479
+ answer="Test answer",
480
+ context="Test context",
481
+ triplets=[mock_edge],
482
+ )
219
483
 
220
- await setup()
484
+ mock_add_data.assert_awaited_once()
485
+ mock_graph_engine.add_edges.assert_not_called()
486
+
487
+
488
+ @pytest.mark.asyncio
489
+ async def test_save_qa_empty_triplets():
490
+ """Test save_qa with empty triplets list."""
491
+ mock_graph_engine = AsyncMock()
492
+ mock_graph_engine.add_edges = AsyncMock()
493
+
494
+ retriever = GraphCompletionRetriever()
495
+
496
+ with (
497
+ patch(
498
+ "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
499
+ return_value=mock_graph_engine,
500
+ ),
501
+ patch(
502
+ "cognee.modules.retrieval.graph_completion_retriever.add_data_points",
503
+ ) as mock_add_data,
504
+ ):
505
+ await retriever.save_qa(
506
+ question="Test question",
507
+ answer="Test answer",
508
+ context="Test context",
509
+ triplets=[],
510
+ )
221
511
 
222
- context = await retriever.get_context("Who works at Figma?")
223
- assert context == [], "Context should be empty on an empty graph"
512
+ mock_add_data.assert_awaited_once()
513
+ mock_graph_engine.add_edges.assert_not_called()
514
+
515
+
516
+ @pytest.mark.asyncio
517
+ async def test_get_completion_with_save_interaction_no_completion(mock_edge):
518
+ """Test get_completion with save_interaction but no completion."""
519
+ mock_graph_engine = AsyncMock()
520
+ mock_graph_engine.is_empty = AsyncMock(return_value=False)
521
+
522
+ retriever = GraphCompletionRetriever(save_interaction=True)
523
+
524
+ with (
525
+ patch(
526
+ "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
527
+ return_value=mock_graph_engine,
528
+ ),
529
+ patch(
530
+ "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
531
+ return_value=[mock_edge],
532
+ ),
533
+ patch(
534
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
535
+ return_value="Resolved context",
536
+ ),
537
+ patch(
538
+ "cognee.modules.retrieval.graph_completion_retriever.generate_completion",
539
+ return_value=None, # No completion
540
+ ),
541
+ patch(
542
+ "cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
543
+ ) as mock_cache_config,
544
+ ):
545
+ mock_config = MagicMock()
546
+ mock_config.caching = False
547
+ mock_cache_config.return_value = mock_config
548
+
549
+ completion = await retriever.get_completion("test query")
550
+
551
+ assert isinstance(completion, list)
552
+ assert len(completion) == 1
553
+ assert completion[0] is None
554
+
555
+
556
+ @pytest.mark.asyncio
557
+ async def test_get_completion_with_save_interaction_no_context(mock_edge):
558
+ """Test get_completion with save_interaction but no context provided."""
559
+ mock_graph_engine = AsyncMock()
560
+ mock_graph_engine.is_empty = AsyncMock(return_value=False)
561
+
562
+ retriever = GraphCompletionRetriever(save_interaction=True)
563
+
564
+ with (
565
+ patch(
566
+ "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
567
+ return_value=mock_graph_engine,
568
+ ),
569
+ patch(
570
+ "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
571
+ return_value=[mock_edge],
572
+ ),
573
+ patch(
574
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
575
+ return_value="Resolved context",
576
+ ),
577
+ patch(
578
+ "cognee.modules.retrieval.graph_completion_retriever.generate_completion",
579
+ return_value="Generated answer",
580
+ ),
581
+ patch(
582
+ "cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
583
+ ) as mock_cache_config,
584
+ ):
585
+ mock_config = MagicMock()
586
+ mock_config.caching = False
587
+ mock_cache_config.return_value = mock_config
588
+
589
+ completion = await retriever.get_completion("test query", context=None)
590
+
591
+ assert isinstance(completion, list)
592
+ assert len(completion) == 1
593
+
594
+
595
+ @pytest.mark.asyncio
596
+ async def test_get_completion_with_save_interaction_all_conditions_met(mock_edge):
597
+ """Test get_completion with save_interaction when all conditions are met (line 216)."""
598
+ mock_graph_engine = AsyncMock()
599
+ mock_graph_engine.is_empty = AsyncMock(return_value=False)
600
+
601
+ retriever = GraphCompletionRetriever(save_interaction=True)
602
+
603
+ mock_node1 = MagicMock()
604
+ mock_node2 = MagicMock()
605
+ mock_edge.node1 = mock_node1
606
+ mock_edge.node2 = mock_node2
607
+
608
+ with (
609
+ patch(
610
+ "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
611
+ return_value=mock_graph_engine,
612
+ ),
613
+ patch(
614
+ "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
615
+ return_value=[mock_edge],
616
+ ),
617
+ patch(
618
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
619
+ return_value="Resolved context",
620
+ ),
621
+ patch(
622
+ "cognee.modules.retrieval.graph_completion_retriever.generate_completion",
623
+ return_value="Generated answer",
624
+ ),
625
+ patch(
626
+ "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node",
627
+ side_effect=[
628
+ UUID("550e8400-e29b-41d4-a716-446655440000"),
629
+ UUID("550e8400-e29b-41d4-a716-446655440001"),
630
+ ],
631
+ ),
632
+ patch(
633
+ "cognee.modules.retrieval.graph_completion_retriever.add_data_points",
634
+ ) as mock_add_data,
635
+ patch(
636
+ "cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
637
+ ) as mock_cache_config,
638
+ ):
639
+ mock_config = MagicMock()
640
+ mock_config.caching = False
641
+ mock_cache_config.return_value = mock_config
642
+
643
+ completion = await retriever.get_completion("test query", context=[mock_edge])
644
+
645
+ assert isinstance(completion, list)
646
+ assert len(completion) == 1
647
+ assert completion[0] == "Generated answer"
648
+ mock_add_data.assert_awaited_once()