cognee 0.5.1__py3-none-any.whl → 0.5.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. cognee/api/v1/add/add.py +2 -1
  2. cognee/api/v1/datasets/routers/get_datasets_router.py +1 -0
  3. cognee/api/v1/memify/routers/get_memify_router.py +1 -0
  4. cognee/api/v1/search/search.py +0 -4
  5. cognee/infrastructure/databases/relational/config.py +16 -1
  6. cognee/infrastructure/databases/relational/create_relational_engine.py +13 -3
  7. cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +24 -2
  8. cognee/infrastructure/databases/vector/create_vector_engine.py +9 -2
  9. cognee/infrastructure/llm/LLMGateway.py +0 -13
  10. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -12
  11. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +31 -25
  12. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +132 -7
  13. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +5 -5
  14. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +2 -6
  15. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +58 -13
  16. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +0 -1
  17. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -131
  18. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py +10 -0
  19. cognee/modules/data/models/Data.py +2 -1
  20. cognee/modules/retrieval/triplet_retriever.py +1 -1
  21. cognee/modules/retrieval/utils/brute_force_triplet_search.py +0 -18
  22. cognee/modules/search/methods/search.py +18 -25
  23. cognee/tasks/ingestion/data_item.py +8 -0
  24. cognee/tasks/ingestion/ingest_data.py +12 -1
  25. cognee/tasks/ingestion/save_data_item_to_storage.py +5 -0
  26. cognee/tests/integration/retrieval/test_chunks_retriever.py +252 -0
  27. cognee/tests/integration/retrieval/test_graph_completion_retriever.py +268 -0
  28. cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +226 -0
  29. cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +218 -0
  30. cognee/tests/integration/retrieval/test_rag_completion_retriever.py +254 -0
  31. cognee/tests/{unit/modules/retrieval/structured_output_test.py → integration/retrieval/test_structured_output.py} +87 -77
  32. cognee/tests/integration/retrieval/test_summaries_retriever.py +184 -0
  33. cognee/tests/integration/retrieval/test_temporal_retriever.py +306 -0
  34. cognee/tests/integration/retrieval/test_triplet_retriever.py +35 -0
  35. cognee/tests/test_custom_data_label.py +68 -0
  36. cognee/tests/test_search_db.py +334 -181
  37. cognee/tests/unit/eval_framework/benchmark_adapters_test.py +25 -0
  38. cognee/tests/unit/eval_framework/corpus_builder_test.py +33 -4
  39. cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py +69 -0
  40. cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +181 -199
  41. cognee/tests/unit/modules/retrieval/conversation_history_test.py +338 -0
  42. cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +454 -162
  43. cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +674 -156
  44. cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +625 -200
  45. cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +319 -203
  46. cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +189 -155
  47. cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +539 -58
  48. cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +218 -9
  49. cognee/tests/unit/modules/retrieval/test_completion.py +343 -0
  50. cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py +157 -0
  51. cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py +312 -0
  52. cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +246 -0
  53. {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/METADATA +1 -1
  54. {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/RECORD +58 -45
  55. cognee/tests/unit/modules/search/test_search.py +0 -100
  56. {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/WHEEL +0 -0
  57. {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/entry_points.txt +0 -0
  58. {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/LICENSE +0 -0
  59. {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/NOTICE.md +0 -0
@@ -1,177 +1,469 @@
1
- import os
2
1
  import pytest
3
- import pathlib
4
- from typing import Optional, Union
2
+ from unittest.mock import AsyncMock, patch, MagicMock
3
+ from uuid import UUID
5
4
 
6
- import cognee
7
- from cognee.low_level import setup, DataPoint
8
- from cognee.tasks.storage import add_data_points
9
- from cognee.modules.graph.utils import resolve_edges_to_text
10
5
  from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
11
6
  GraphCompletionContextExtensionRetriever,
12
7
  )
13
-
14
-
15
- class TestGraphCompletionWithContextExtensionRetriever:
16
- @pytest.mark.asyncio
17
- async def test_graph_completion_extension_context_simple(self):
18
- system_directory_path = os.path.join(
19
- pathlib.Path(__file__).parent,
20
- ".cognee_system/test_graph_completion_extension_context_simple",
21
- )
22
- cognee.config.system_root_directory(system_directory_path)
23
- data_directory_path = os.path.join(
24
- pathlib.Path(__file__).parent,
25
- ".data_storage/test_graph_completion_extension_context_simple",
8
+ from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
9
+
10
+
11
+ @pytest.fixture
12
+ def mock_edge():
13
+ """Create a mock edge."""
14
+ edge = MagicMock(spec=Edge)
15
+ return edge
16
+
17
+
18
+ @pytest.mark.asyncio
19
+ async def test_get_triplets_inherited(mock_edge):
20
+ """Test that get_triplets is inherited from parent class."""
21
+ retriever = GraphCompletionContextExtensionRetriever()
22
+
23
+ with patch(
24
+ "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
25
+ return_value=[mock_edge],
26
+ ):
27
+ triplets = await retriever.get_triplets("test query")
28
+
29
+ assert len(triplets) == 1
30
+ assert triplets[0] == mock_edge
31
+
32
+
33
+ @pytest.mark.asyncio
34
+ async def test_init_defaults():
35
+ """Test GraphCompletionContextExtensionRetriever initialization with defaults."""
36
+ retriever = GraphCompletionContextExtensionRetriever()
37
+
38
+ assert retriever.top_k == 5
39
+ assert retriever.user_prompt_path == "graph_context_for_question.txt"
40
+ assert retriever.system_prompt_path == "answer_simple_question.txt"
41
+
42
+
43
+ @pytest.mark.asyncio
44
+ async def test_init_custom_params():
45
+ """Test GraphCompletionContextExtensionRetriever initialization with custom parameters."""
46
+ retriever = GraphCompletionContextExtensionRetriever(
47
+ top_k=10,
48
+ user_prompt_path="custom_user.txt",
49
+ system_prompt_path="custom_system.txt",
50
+ system_prompt="Custom prompt",
51
+ node_type=str,
52
+ node_name=["node1"],
53
+ save_interaction=True,
54
+ wide_search_top_k=200,
55
+ triplet_distance_penalty=5.0,
56
+ )
57
+
58
+ assert retriever.top_k == 10
59
+ assert retriever.user_prompt_path == "custom_user.txt"
60
+ assert retriever.system_prompt_path == "custom_system.txt"
61
+ assert retriever.system_prompt == "Custom prompt"
62
+ assert retriever.node_type is str
63
+ assert retriever.node_name == ["node1"]
64
+ assert retriever.save_interaction is True
65
+ assert retriever.wide_search_top_k == 200
66
+ assert retriever.triplet_distance_penalty == 5.0
67
+
68
+
69
+ @pytest.mark.asyncio
70
+ async def test_get_completion_without_context(mock_edge):
71
+ """Test get_completion retrieves context when not provided."""
72
+ mock_graph_engine = AsyncMock()
73
+ mock_graph_engine.is_empty = AsyncMock(return_value=False)
74
+
75
+ retriever = GraphCompletionContextExtensionRetriever()
76
+
77
+ with (
78
+ patch(
79
+ "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
80
+ return_value=mock_graph_engine,
81
+ ),
82
+ patch(
83
+ "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
84
+ return_value=[mock_edge],
85
+ ),
86
+ patch(
87
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
88
+ return_value="Resolved context",
89
+ ),
90
+ patch(
91
+ "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
92
+ return_value="Generated answer",
93
+ ),
94
+ patch(
95
+ "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
96
+ ) as mock_cache_config,
97
+ ):
98
+ mock_config = MagicMock()
99
+ mock_config.caching = False
100
+ mock_cache_config.return_value = mock_config
101
+
102
+ completion = await retriever.get_completion("test query", context_extension_rounds=1)
103
+
104
+ assert isinstance(completion, list)
105
+ assert len(completion) == 1
106
+ assert completion[0] == "Generated answer"
107
+
108
+
109
+ @pytest.mark.asyncio
110
+ async def test_get_completion_with_provided_context(mock_edge):
111
+ """Test get_completion uses provided context."""
112
+ retriever = GraphCompletionContextExtensionRetriever()
113
+
114
+ with (
115
+ patch(
116
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
117
+ return_value="Resolved context",
118
+ ),
119
+ patch(
120
+ "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
121
+ return_value="Generated answer",
122
+ ),
123
+ patch(
124
+ "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
125
+ ) as mock_cache_config,
126
+ ):
127
+ mock_config = MagicMock()
128
+ mock_config.caching = False
129
+ mock_cache_config.return_value = mock_config
130
+
131
+ completion = await retriever.get_completion(
132
+ "test query", context=[mock_edge], context_extension_rounds=1
26
133
  )
27
- cognee.config.data_root_directory(data_directory_path)
28
-
29
- await cognee.prune.prune_data()
30
- await cognee.prune.prune_system(metadata=True)
31
- await setup()
32
-
33
- class Company(DataPoint):
34
- name: str
35
-
36
- class Person(DataPoint):
37
- name: str
38
- works_for: Company
39
-
40
- company1 = Company(name="Figma")
41
- company2 = Company(name="Canva")
42
- person1 = Person(name="Steve Rodger", works_for=company1)
43
- person2 = Person(name="Ike Loma", works_for=company1)
44
- person3 = Person(name="Jason Statham", works_for=company1)
45
- person4 = Person(name="Mike Broski", works_for=company2)
46
- person5 = Person(name="Christina Mayer", works_for=company2)
47
-
48
- entities = [company1, company2, person1, person2, person3, person4, person5]
49
-
50
- await add_data_points(entities)
51
-
52
- retriever = GraphCompletionContextExtensionRetriever()
53
-
54
- context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?"))
55
134
 
56
- assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
57
- assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
58
-
59
- answer = await retriever.get_completion("Who works at Canva?")
60
-
61
- assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
62
- assert all(isinstance(item, str) and item.strip() for item in answer), (
63
- "Answer must contain only non-empty strings"
135
+ assert isinstance(completion, list)
136
+ assert len(completion) == 1
137
+ assert completion[0] == "Generated answer"
138
+
139
+
140
+ @pytest.mark.asyncio
141
+ async def test_get_completion_context_extension_rounds(mock_edge):
142
+ """Test get_completion with multiple context extension rounds."""
143
+ mock_graph_engine = AsyncMock()
144
+ mock_graph_engine.is_empty = AsyncMock(return_value=False)
145
+
146
+ retriever = GraphCompletionContextExtensionRetriever()
147
+
148
+ # Create a second edge for extension rounds
149
+ mock_edge2 = MagicMock(spec=Edge)
150
+
151
+ with (
152
+ patch(
153
+ "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
154
+ return_value=mock_graph_engine,
155
+ ),
156
+ patch.object(
157
+ retriever,
158
+ "get_context",
159
+ new_callable=AsyncMock,
160
+ side_effect=[[mock_edge], [mock_edge2]],
161
+ ),
162
+ patch(
163
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
164
+ side_effect=["Resolved context", "Extended context"], # Different contexts
165
+ ),
166
+ patch(
167
+ "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
168
+ side_effect=[
169
+ "Extension query",
170
+ "Generated answer",
171
+ ], # Query for extension, then final answer
172
+ ),
173
+ patch(
174
+ "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
175
+ ) as mock_cache_config,
176
+ ):
177
+ mock_config = MagicMock()
178
+ mock_config.caching = False
179
+ mock_cache_config.return_value = mock_config
180
+
181
+ completion = await retriever.get_completion("test query", context_extension_rounds=1)
182
+
183
+ assert isinstance(completion, list)
184
+ assert len(completion) == 1
185
+ assert completion[0] == "Generated answer"
186
+
187
+
188
+ @pytest.mark.asyncio
189
+ async def test_get_completion_context_extension_stops_early(mock_edge):
190
+ """Test get_completion stops early when no new triplets found."""
191
+ mock_graph_engine = AsyncMock()
192
+ mock_graph_engine.is_empty = AsyncMock(return_value=False)
193
+
194
+ retriever = GraphCompletionContextExtensionRetriever()
195
+
196
+ with (
197
+ patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
198
+ patch(
199
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
200
+ return_value="Resolved context",
201
+ ),
202
+ patch(
203
+ "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
204
+ side_effect=[
205
+ "Extension query",
206
+ "Generated answer",
207
+ ],
208
+ ),
209
+ patch(
210
+ "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
211
+ ) as mock_cache_config,
212
+ ):
213
+ mock_config = MagicMock()
214
+ mock_config.caching = False
215
+ mock_cache_config.return_value = mock_config
216
+
217
+ # When get_context returns same triplets, the loop should stop early
218
+ completion = await retriever.get_completion(
219
+ "test query", context=[mock_edge], context_extension_rounds=4
64
220
  )
65
221
 
66
- @pytest.mark.asyncio
67
- async def test_graph_completion_extension_context_complex(self):
68
- system_directory_path = os.path.join(
69
- pathlib.Path(__file__).parent,
70
- ".cognee_system/test_graph_completion_extension_context_complex",
71
- )
72
- cognee.config.system_root_directory(system_directory_path)
73
- data_directory_path = os.path.join(
74
- pathlib.Path(__file__).parent,
75
- ".data_storage/test_graph_completion_extension_context_complex",
222
+ assert isinstance(completion, list)
223
+ assert len(completion) == 1
224
+ assert completion[0] == "Generated answer"
225
+
226
+
227
+ @pytest.mark.asyncio
228
+ async def test_get_completion_with_session(mock_edge):
229
+ """Test get_completion with session caching enabled."""
230
+ mock_graph_engine = AsyncMock()
231
+ mock_graph_engine.is_empty = AsyncMock(return_value=False)
232
+
233
+ retriever = GraphCompletionContextExtensionRetriever()
234
+
235
+ mock_user = MagicMock()
236
+ mock_user.id = "test-user-id"
237
+
238
+ with (
239
+ patch(
240
+ "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
241
+ return_value=mock_graph_engine,
242
+ ),
243
+ patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
244
+ patch(
245
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
246
+ return_value="Resolved context",
247
+ ),
248
+ patch(
249
+ "cognee.modules.retrieval.graph_completion_context_extension_retriever.get_conversation_history",
250
+ return_value="Previous conversation",
251
+ ),
252
+ patch(
253
+ "cognee.modules.retrieval.graph_completion_context_extension_retriever.summarize_text",
254
+ return_value="Context summary",
255
+ ),
256
+ patch(
257
+ "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
258
+ side_effect=[
259
+ "Extension query",
260
+ "Generated answer",
261
+ ], # Extension query, then final answer
262
+ ),
263
+ patch(
264
+ "cognee.modules.retrieval.graph_completion_context_extension_retriever.save_conversation_history",
265
+ ) as mock_save,
266
+ patch(
267
+ "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
268
+ ) as mock_cache_config,
269
+ patch(
270
+ "cognee.modules.retrieval.graph_completion_context_extension_retriever.session_user"
271
+ ) as mock_session_user,
272
+ ):
273
+ mock_config = MagicMock()
274
+ mock_config.caching = True
275
+ mock_cache_config.return_value = mock_config
276
+ mock_session_user.get.return_value = mock_user
277
+
278
+ completion = await retriever.get_completion(
279
+ "test query", session_id="test_session", context_extension_rounds=1
76
280
  )
77
- cognee.config.data_root_directory(data_directory_path)
78
-
79
- await cognee.prune.prune_data()
80
- await cognee.prune.prune_system(metadata=True)
81
- await setup()
82
-
83
- class Company(DataPoint):
84
- name: str
85
- metadata: dict = {"index_fields": ["name"]}
86
-
87
- class Car(DataPoint):
88
- brand: str
89
- model: str
90
- year: int
91
-
92
- class Location(DataPoint):
93
- country: str
94
- city: str
95
-
96
- class Home(DataPoint):
97
- location: Location
98
- rooms: int
99
- sqm: int
100
-
101
- class Person(DataPoint):
102
- name: str
103
- works_for: Company
104
- owns: Optional[list[Union[Car, Home]]] = None
105
-
106
- company1 = Company(name="Figma")
107
- company2 = Company(name="Canva")
108
281
 
109
- person1 = Person(name="Mike Rodger", works_for=company1)
110
- person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]
111
-
112
- person2 = Person(name="Ike Loma", works_for=company1)
113
- person2.owns = [
114
- Car(brand="Tesla", model="Model S", year=2021),
115
- Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
116
- ]
117
-
118
- person3 = Person(name="Jason Statham", works_for=company1)
119
-
120
- person4 = Person(name="Mike Broski", works_for=company2)
121
- person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]
122
-
123
- person5 = Person(name="Christina Mayer", works_for=company2)
124
- person5.owns = [Car(brand="Honda", model="Civic", year=2023)]
125
-
126
- entities = [company1, company2, person1, person2, person3, person4, person5]
127
-
128
- await add_data_points(entities)
129
-
130
- retriever = GraphCompletionContextExtensionRetriever(top_k=20)
131
-
132
- context = await resolve_edges_to_text(
133
- await retriever.get_context("Who works at Figma and drives Tesla?")
134
- )
135
-
136
- print(context)
137
-
138
- assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
139
- assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
140
- assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"
141
-
142
- answer = await retriever.get_completion("Who works at Figma?")
143
-
144
- assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
145
- assert all(isinstance(item, str) and item.strip() for item in answer), (
146
- "Answer must contain only non-empty strings"
282
+ assert isinstance(completion, list)
283
+ assert len(completion) == 1
284
+ assert completion[0] == "Generated answer"
285
+ mock_save.assert_awaited_once()
286
+
287
+
288
+ @pytest.mark.asyncio
289
+ async def test_get_completion_with_save_interaction(mock_edge):
290
+ """Test get_completion with save_interaction enabled."""
291
+ mock_graph_engine = AsyncMock()
292
+ mock_graph_engine.is_empty = AsyncMock(return_value=False)
293
+ mock_graph_engine.add_edges = AsyncMock()
294
+
295
+ retriever = GraphCompletionContextExtensionRetriever(save_interaction=True)
296
+
297
+ mock_node1 = MagicMock()
298
+ mock_node2 = MagicMock()
299
+ mock_edge.node1 = mock_node1
300
+ mock_edge.node2 = mock_node2
301
+
302
+ with (
303
+ patch(
304
+ "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
305
+ return_value=mock_graph_engine,
306
+ ),
307
+ patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
308
+ patch(
309
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
310
+ return_value="Resolved context",
311
+ ),
312
+ patch(
313
+ "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
314
+ side_effect=[
315
+ "Extension query",
316
+ "Generated answer",
317
+ ], # Extension query, then final answer
318
+ ),
319
+ patch(
320
+ "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node",
321
+ side_effect=[
322
+ UUID("550e8400-e29b-41d4-a716-446655440000"),
323
+ UUID("550e8400-e29b-41d4-a716-446655440001"),
324
+ ],
325
+ ),
326
+ patch(
327
+ "cognee.modules.retrieval.graph_completion_retriever.add_data_points",
328
+ ) as mock_add_data,
329
+ patch(
330
+ "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
331
+ ) as mock_cache_config,
332
+ ):
333
+ mock_config = MagicMock()
334
+ mock_config.caching = False
335
+ mock_cache_config.return_value = mock_config
336
+
337
+ completion = await retriever.get_completion(
338
+ "test query", context=[mock_edge], context_extension_rounds=1
147
339
  )
148
340
 
149
- @pytest.mark.asyncio
150
- async def test_get_graph_completion_extension_context_on_empty_graph(self):
151
- system_directory_path = os.path.join(
152
- pathlib.Path(__file__).parent,
153
- ".cognee_system/test_get_graph_completion_extension_context_on_empty_graph",
341
+ assert isinstance(completion, list)
342
+ assert len(completion) == 1
343
+ mock_add_data.assert_awaited_once()
344
+
345
+
346
+ @pytest.mark.asyncio
347
+ async def test_get_completion_with_response_model(mock_edge):
348
+ """Test get_completion with custom response model."""
349
+ from pydantic import BaseModel
350
+
351
+ class TestModel(BaseModel):
352
+ answer: str
353
+
354
+ mock_graph_engine = AsyncMock()
355
+ mock_graph_engine.is_empty = AsyncMock(return_value=False)
356
+
357
+ retriever = GraphCompletionContextExtensionRetriever()
358
+
359
+ with (
360
+ patch(
361
+ "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
362
+ return_value=mock_graph_engine,
363
+ ),
364
+ patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
365
+ patch(
366
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
367
+ return_value="Resolved context",
368
+ ),
369
+ patch(
370
+ "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
371
+ side_effect=[
372
+ "Extension query",
373
+ TestModel(answer="Test answer"),
374
+ ], # Extension query, then final answer
375
+ ),
376
+ patch(
377
+ "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
378
+ ) as mock_cache_config,
379
+ ):
380
+ mock_config = MagicMock()
381
+ mock_config.caching = False
382
+ mock_cache_config.return_value = mock_config
383
+
384
+ completion = await retriever.get_completion(
385
+ "test query", response_model=TestModel, context_extension_rounds=1
154
386
  )
155
- cognee.config.system_root_directory(system_directory_path)
156
- data_directory_path = os.path.join(
157
- pathlib.Path(__file__).parent,
158
- ".data_storage/test_get_graph_completion_extension_context_on_empty_graph",
159
- )
160
- cognee.config.data_root_directory(data_directory_path)
161
-
162
- await cognee.prune.prune_data()
163
- await cognee.prune.prune_system(metadata=True)
164
387
 
165
- retriever = GraphCompletionContextExtensionRetriever()
166
-
167
- await setup()
168
-
169
- context = await retriever.get_context("Who works at Figma?")
170
- assert context == [], "Context should be empty on an empty graph"
171
-
172
- answer = await retriever.get_completion("Who works at Figma?")
173
-
174
- assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
175
- assert all(isinstance(item, str) and item.strip() for item in answer), (
176
- "Answer must contain only non-empty strings"
177
- )
388
+ assert isinstance(completion, list)
389
+ assert len(completion) == 1
390
+ assert isinstance(completion[0], TestModel)
391
+
392
+
393
+ @pytest.mark.asyncio
394
+ async def test_get_completion_with_session_no_user_id(mock_edge):
395
+ """Test get_completion with session config but no user ID."""
396
+ mock_graph_engine = AsyncMock()
397
+ mock_graph_engine.is_empty = AsyncMock(return_value=False)
398
+
399
+ retriever = GraphCompletionContextExtensionRetriever()
400
+
401
+ with (
402
+ patch(
403
+ "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
404
+ return_value=mock_graph_engine,
405
+ ),
406
+ patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
407
+ patch(
408
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
409
+ return_value="Resolved context",
410
+ ),
411
+ patch(
412
+ "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
413
+ side_effect=[
414
+ "Extension query",
415
+ "Generated answer",
416
+ ], # Extension query, then final answer
417
+ ),
418
+ patch(
419
+ "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
420
+ ) as mock_cache_config,
421
+ patch(
422
+ "cognee.modules.retrieval.graph_completion_context_extension_retriever.session_user"
423
+ ) as mock_session_user,
424
+ ):
425
+ mock_config = MagicMock()
426
+ mock_config.caching = True
427
+ mock_cache_config.return_value = mock_config
428
+ mock_session_user.get.return_value = None # No user
429
+
430
+ completion = await retriever.get_completion("test query", context_extension_rounds=1)
431
+
432
+ assert isinstance(completion, list)
433
+ assert len(completion) == 1
434
+
435
+
436
+ @pytest.mark.asyncio
437
+ async def test_get_completion_zero_extension_rounds(mock_edge):
438
+ """Test get_completion with zero context extension rounds."""
439
+ mock_graph_engine = AsyncMock()
440
+ mock_graph_engine.is_empty = AsyncMock(return_value=False)
441
+
442
+ retriever = GraphCompletionContextExtensionRetriever()
443
+
444
+ with (
445
+ patch(
446
+ "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
447
+ return_value=mock_graph_engine,
448
+ ),
449
+ patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
450
+ patch(
451
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
452
+ return_value="Resolved context",
453
+ ),
454
+ patch(
455
+ "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
456
+ return_value="Generated answer",
457
+ ),
458
+ patch(
459
+ "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
460
+ ) as mock_cache_config,
461
+ ):
462
+ mock_config = MagicMock()
463
+ mock_config.caching = False
464
+ mock_cache_config.return_value = mock_config
465
+
466
+ completion = await retriever.get_completion("test query", context_extension_rounds=0)
467
+
468
+ assert isinstance(completion, list)
469
+ assert len(completion) == 1