cognee-0.5.0.dev1-py3-none-any.whl → cognee-0.5.1.dev0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56):
  1. cognee/api/v1/add/add.py +2 -1
  2. cognee/api/v1/datasets/routers/get_datasets_router.py +1 -0
  3. cognee/api/v1/memify/routers/get_memify_router.py +1 -0
  4. cognee/infrastructure/databases/relational/config.py +16 -1
  5. cognee/infrastructure/databases/relational/create_relational_engine.py +13 -3
  6. cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +24 -2
  7. cognee/infrastructure/databases/vector/create_vector_engine.py +9 -2
  8. cognee/infrastructure/llm/LLMGateway.py +0 -13
  9. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -12
  10. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +31 -25
  11. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +132 -7
  12. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +5 -5
  13. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +2 -6
  14. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +58 -13
  15. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +0 -1
  16. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -131
  17. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py +10 -0
  18. cognee/modules/data/models/Data.py +2 -1
  19. cognee/modules/retrieval/triplet_retriever.py +1 -1
  20. cognee/modules/retrieval/utils/brute_force_triplet_search.py +0 -18
  21. cognee/tasks/ingestion/data_item.py +8 -0
  22. cognee/tasks/ingestion/ingest_data.py +12 -1
  23. cognee/tasks/ingestion/save_data_item_to_storage.py +5 -0
  24. cognee/tests/integration/retrieval/test_chunks_retriever.py +252 -0
  25. cognee/tests/integration/retrieval/test_graph_completion_retriever.py +268 -0
  26. cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +226 -0
  27. cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +218 -0
  28. cognee/tests/integration/retrieval/test_rag_completion_retriever.py +254 -0
  29. cognee/tests/{unit/modules/retrieval/structured_output_test.py → integration/retrieval/test_structured_output.py} +87 -77
  30. cognee/tests/integration/retrieval/test_summaries_retriever.py +184 -0
  31. cognee/tests/integration/retrieval/test_temporal_retriever.py +306 -0
  32. cognee/tests/integration/retrieval/test_triplet_retriever.py +35 -0
  33. cognee/tests/test_custom_data_label.py +68 -0
  34. cognee/tests/test_search_db.py +334 -181
  35. cognee/tests/unit/eval_framework/benchmark_adapters_test.py +25 -0
  36. cognee/tests/unit/eval_framework/corpus_builder_test.py +33 -4
  37. cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py +69 -0
  38. cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +181 -199
  39. cognee/tests/unit/modules/retrieval/conversation_history_test.py +338 -0
  40. cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +454 -162
  41. cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +674 -156
  42. cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +625 -200
  43. cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +319 -203
  44. cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +189 -155
  45. cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +539 -58
  46. cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +218 -9
  47. cognee/tests/unit/modules/retrieval/test_completion.py +343 -0
  48. cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py +157 -0
  49. cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py +312 -0
  50. cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +246 -0
  51. {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/METADATA +1 -1
  52. {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/RECORD +56 -42
  53. {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/WHEEL +0 -0
  54. {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/entry_points.txt +0 -0
  55. {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/LICENSE +0 -0
  56. {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/NOTICE.md +0 -0
@@ -1,170 +1,688 @@
1
- import os
2
1
  import pytest
3
- import pathlib
4
- from typing import Optional, Union
5
-
6
- import cognee
7
- from cognee.low_level import setup, DataPoint
8
- from cognee.modules.graph.utils import resolve_edges_to_text
9
- from cognee.tasks.storage import add_data_points
10
- from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
11
-
12
-
13
- class TestGraphCompletionCoTRetriever:
14
- @pytest.mark.asyncio
15
- async def test_graph_completion_cot_context_simple(self):
16
- system_directory_path = os.path.join(
17
- pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_cot_context_simple"
2
+ from unittest.mock import AsyncMock, patch, MagicMock
3
+ from uuid import UUID
4
+
5
+ from cognee.modules.retrieval.graph_completion_cot_retriever import (
6
+ GraphCompletionCotRetriever,
7
+ _as_answer_text,
8
+ )
9
+ from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
10
+ from cognee.infrastructure.llm.LLMGateway import LLMGateway
11
+
12
+
13
+ @pytest.fixture
14
+ def mock_edge():
15
+ """Create a mock edge."""
16
+ edge = MagicMock(spec=Edge)
17
+ return edge
18
+
19
+
20
+ @pytest.mark.asyncio
21
+ async def test_get_triplets_inherited(mock_edge):
22
+ """Test that get_triplets is inherited from parent class."""
23
+ retriever = GraphCompletionCotRetriever()
24
+
25
+ with patch(
26
+ "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
27
+ return_value=[mock_edge],
28
+ ):
29
+ triplets = await retriever.get_triplets("test query")
30
+
31
+ assert len(triplets) == 1
32
+ assert triplets[0] == mock_edge
33
+
34
+
35
+ @pytest.mark.asyncio
36
+ async def test_init_custom_params():
37
+ """Test GraphCompletionCotRetriever initialization with custom parameters."""
38
+ retriever = GraphCompletionCotRetriever(
39
+ top_k=10,
40
+ user_prompt_path="custom_user.txt",
41
+ system_prompt_path="custom_system.txt",
42
+ validation_user_prompt_path="custom_validation_user.txt",
43
+ validation_system_prompt_path="custom_validation_system.txt",
44
+ followup_system_prompt_path="custom_followup_system.txt",
45
+ followup_user_prompt_path="custom_followup_user.txt",
46
+ )
47
+
48
+ assert retriever.top_k == 10
49
+ assert retriever.user_prompt_path == "custom_user.txt"
50
+ assert retriever.system_prompt_path == "custom_system.txt"
51
+ assert retriever.validation_user_prompt_path == "custom_validation_user.txt"
52
+ assert retriever.validation_system_prompt_path == "custom_validation_system.txt"
53
+ assert retriever.followup_system_prompt_path == "custom_followup_system.txt"
54
+ assert retriever.followup_user_prompt_path == "custom_followup_user.txt"
55
+
56
+
57
+ @pytest.mark.asyncio
58
+ async def test_init_defaults():
59
+ """Test GraphCompletionCotRetriever initialization with defaults."""
60
+ retriever = GraphCompletionCotRetriever()
61
+
62
+ assert retriever.validation_user_prompt_path == "cot_validation_user_prompt.txt"
63
+ assert retriever.validation_system_prompt_path == "cot_validation_system_prompt.txt"
64
+ assert retriever.followup_system_prompt_path == "cot_followup_system_prompt.txt"
65
+ assert retriever.followup_user_prompt_path == "cot_followup_user_prompt.txt"
66
+
67
+
68
+ @pytest.mark.asyncio
69
+ async def test_run_cot_completion_round_zero_with_context(mock_edge):
70
+ """Test _run_cot_completion round 0 with provided context."""
71
+ retriever = GraphCompletionCotRetriever()
72
+
73
+ with (
74
+ patch(
75
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
76
+ return_value="Resolved context",
77
+ ),
78
+ patch(
79
+ "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
80
+ return_value="Generated answer",
81
+ ),
82
+ patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
83
+ patch(
84
+ "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
85
+ return_value="Generated answer",
86
+ ),
87
+ patch(
88
+ "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
89
+ return_value="Rendered prompt",
90
+ ),
91
+ patch(
92
+ "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
93
+ return_value="System prompt",
94
+ ),
95
+ patch.object(
96
+ LLMGateway,
97
+ "acreate_structured_output",
98
+ new_callable=AsyncMock,
99
+ side_effect=["validation_result", "followup_question"],
100
+ ),
101
+ ):
102
+ completion, context_text, triplets = await retriever._run_cot_completion(
103
+ query="test query",
104
+ context=[mock_edge],
105
+ max_iter=1,
18
106
  )
19
- cognee.config.system_root_directory(system_directory_path)
20
- data_directory_path = os.path.join(
21
- pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_cot_context_simple"
22
- )
23
- cognee.config.data_root_directory(data_directory_path)
24
-
25
- await cognee.prune.prune_data()
26
- await cognee.prune.prune_system(metadata=True)
27
- await setup()
28
-
29
- class Company(DataPoint):
30
- name: str
31
-
32
- class Person(DataPoint):
33
- name: str
34
- works_for: Company
35
-
36
- company1 = Company(name="Figma")
37
- company2 = Company(name="Canva")
38
- person1 = Person(name="Steve Rodger", works_for=company1)
39
- person2 = Person(name="Ike Loma", works_for=company1)
40
- person3 = Person(name="Jason Statham", works_for=company1)
41
- person4 = Person(name="Mike Broski", works_for=company2)
42
- person5 = Person(name="Christina Mayer", works_for=company2)
43
-
44
- entities = [company1, company2, person1, person2, person3, person4, person5]
45
-
46
- await add_data_points(entities)
47
-
48
- retriever = GraphCompletionCotRetriever()
49
107
 
50
- context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?"))
51
-
52
- assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
53
- assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
54
-
55
- answer = await retriever.get_completion("Who works at Canva?")
56
-
57
- assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
58
- assert all(isinstance(item, str) and item.strip() for item in answer), (
59
- "Answer must contain only non-empty strings"
108
+ assert completion == "Generated answer"
109
+ assert context_text == "Resolved context"
110
+ assert len(triplets) >= 1
111
+
112
+
113
+ @pytest.mark.asyncio
114
+ async def test_run_cot_completion_round_zero_without_context(mock_edge):
115
+ """Test _run_cot_completion round 0 without provided context."""
116
+ mock_graph_engine = AsyncMock()
117
+ mock_graph_engine.is_empty = AsyncMock(return_value=False)
118
+
119
+ retriever = GraphCompletionCotRetriever()
120
+
121
+ with (
122
+ patch(
123
+ "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
124
+ return_value=mock_graph_engine,
125
+ ),
126
+ patch(
127
+ "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
128
+ return_value=[mock_edge],
129
+ ),
130
+ patch(
131
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
132
+ return_value="Resolved context",
133
+ ),
134
+ patch(
135
+ "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
136
+ return_value="Generated answer",
137
+ ),
138
+ ):
139
+ completion, context_text, triplets = await retriever._run_cot_completion(
140
+ query="test query",
141
+ context=None,
142
+ max_iter=1,
60
143
  )
61
144
 
62
- @pytest.mark.asyncio
63
- async def test_graph_completion_cot_context_complex(self):
64
- system_directory_path = os.path.join(
65
- pathlib.Path(__file__).parent,
66
- ".cognee_system/test_graph_completion_cot_context_complex",
145
+ assert completion == "Generated answer"
146
+ assert context_text == "Resolved context"
147
+ assert len(triplets) >= 1
148
+
149
+
150
+ @pytest.mark.asyncio
151
+ async def test_run_cot_completion_multiple_rounds(mock_edge):
152
+ """Test _run_cot_completion with multiple rounds."""
153
+ retriever = GraphCompletionCotRetriever()
154
+
155
+ mock_edge2 = MagicMock(spec=Edge)
156
+
157
+ with (
158
+ patch(
159
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
160
+ return_value="Resolved context",
161
+ ),
162
+ patch(
163
+ "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
164
+ return_value="Generated answer",
165
+ ),
166
+ patch.object(
167
+ retriever,
168
+ "get_context",
169
+ new_callable=AsyncMock,
170
+ side_effect=[[mock_edge], [mock_edge2]],
171
+ ),
172
+ patch(
173
+ "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
174
+ return_value="Rendered prompt",
175
+ ),
176
+ patch(
177
+ "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
178
+ return_value="System prompt",
179
+ ),
180
+ patch.object(
181
+ LLMGateway,
182
+ "acreate_structured_output",
183
+ new_callable=AsyncMock,
184
+ side_effect=[
185
+ "validation_result",
186
+ "followup_question",
187
+ "validation_result2",
188
+ "followup_question2",
189
+ ],
190
+ ),
191
+ patch(
192
+ "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
193
+ return_value="Generated answer",
194
+ ),
195
+ ):
196
+ completion, context_text, triplets = await retriever._run_cot_completion(
197
+ query="test query",
198
+ context=[mock_edge],
199
+ max_iter=2,
67
200
  )
68
- cognee.config.system_root_directory(system_directory_path)
69
- data_directory_path = os.path.join(
70
- pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_cot_context_complex"
71
- )
72
- cognee.config.data_root_directory(data_directory_path)
73
-
74
- await cognee.prune.prune_data()
75
- await cognee.prune.prune_system(metadata=True)
76
- await setup()
77
-
78
- class Company(DataPoint):
79
- name: str
80
- metadata: dict = {"index_fields": ["name"]}
81
-
82
- class Car(DataPoint):
83
- brand: str
84
- model: str
85
- year: int
86
-
87
- class Location(DataPoint):
88
- country: str
89
- city: str
90
-
91
- class Home(DataPoint):
92
- location: Location
93
- rooms: int
94
- sqm: int
95
-
96
- class Person(DataPoint):
97
- name: str
98
- works_for: Company
99
- owns: Optional[list[Union[Car, Home]]] = None
100
201
 
101
- company1 = Company(name="Figma")
102
- company2 = Company(name="Canva")
103
-
104
- person1 = Person(name="Mike Rodger", works_for=company1)
105
- person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]
106
-
107
- person2 = Person(name="Ike Loma", works_for=company1)
108
- person2.owns = [
109
- Car(brand="Tesla", model="Model S", year=2021),
110
- Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
111
- ]
112
-
113
- person3 = Person(name="Jason Statham", works_for=company1)
114
-
115
- person4 = Person(name="Mike Broski", works_for=company2)
116
- person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]
117
-
118
- person5 = Person(name="Christina Mayer", works_for=company2)
119
- person5.owns = [Car(brand="Honda", model="Civic", year=2023)]
120
-
121
- entities = [company1, company2, person1, person2, person3, person4, person5]
122
-
123
- await add_data_points(entities)
124
-
125
- retriever = GraphCompletionCotRetriever(top_k=20)
126
-
127
- context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?"))
128
-
129
- print(context)
130
-
131
- assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
132
- assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
133
- assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"
134
-
135
- answer = await retriever.get_completion("Who works at Figma?")
136
-
137
- assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
138
- assert all(isinstance(item, str) and item.strip() for item in answer), (
139
- "Answer must contain only non-empty strings"
202
+ assert completion == "Generated answer"
203
+ assert context_text == "Resolved context"
204
+ assert len(triplets) >= 1
205
+
206
+
207
+ @pytest.mark.asyncio
208
+ async def test_run_cot_completion_with_conversation_history(mock_edge):
209
+ """Test _run_cot_completion with conversation history."""
210
+ retriever = GraphCompletionCotRetriever()
211
+
212
+ with (
213
+ patch(
214
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
215
+ return_value="Resolved context",
216
+ ),
217
+ patch(
218
+ "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
219
+ return_value="Generated answer",
220
+ ) as mock_generate,
221
+ ):
222
+ completion, context_text, triplets = await retriever._run_cot_completion(
223
+ query="test query",
224
+ context=[mock_edge],
225
+ conversation_history="Previous conversation",
226
+ max_iter=1,
140
227
  )
141
228
 
142
- @pytest.mark.asyncio
143
- async def test_get_graph_completion_cot_context_on_empty_graph(self):
144
- system_directory_path = os.path.join(
145
- pathlib.Path(__file__).parent,
146
- ".cognee_system/test_get_graph_completion_cot_context_on_empty_graph",
229
+ assert completion == "Generated answer"
230
+ call_kwargs = mock_generate.call_args[1]
231
+ assert call_kwargs.get("conversation_history") == "Previous conversation"
232
+
233
+
234
+ @pytest.mark.asyncio
235
+ async def test_run_cot_completion_with_response_model(mock_edge):
236
+ """Test _run_cot_completion with custom response model."""
237
+ from pydantic import BaseModel
238
+
239
+ class TestModel(BaseModel):
240
+ answer: str
241
+
242
+ retriever = GraphCompletionCotRetriever()
243
+
244
+ with (
245
+ patch(
246
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
247
+ return_value="Resolved context",
248
+ ),
249
+ patch(
250
+ "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
251
+ return_value=TestModel(answer="Test answer"),
252
+ ),
253
+ ):
254
+ completion, context_text, triplets = await retriever._run_cot_completion(
255
+ query="test query",
256
+ context=[mock_edge],
257
+ response_model=TestModel,
258
+ max_iter=1,
147
259
  )
148
- cognee.config.system_root_directory(system_directory_path)
149
- data_directory_path = os.path.join(
150
- pathlib.Path(__file__).parent,
151
- ".data_storage/test_get_graph_completion_cot_context_on_empty_graph",
152
- )
153
- cognee.config.data_root_directory(data_directory_path)
154
-
155
- await cognee.prune.prune_data()
156
- await cognee.prune.prune_system(metadata=True)
157
260
 
158
- retriever = GraphCompletionCotRetriever()
159
-
160
- await setup()
161
-
162
- context = await retriever.get_context("Who works at Figma?")
163
- assert context == [], "Context should be empty on an empty graph"
261
+ assert isinstance(completion, TestModel)
262
+ assert completion.answer == "Test answer"
263
+
264
+
265
+ @pytest.mark.asyncio
266
+ async def test_run_cot_completion_empty_conversation_history(mock_edge):
267
+ """Test _run_cot_completion with empty conversation history."""
268
+ retriever = GraphCompletionCotRetriever()
269
+
270
+ with (
271
+ patch(
272
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
273
+ return_value="Resolved context",
274
+ ),
275
+ patch(
276
+ "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
277
+ return_value="Generated answer",
278
+ ) as mock_generate,
279
+ ):
280
+ completion, context_text, triplets = await retriever._run_cot_completion(
281
+ query="test query",
282
+ context=[mock_edge],
283
+ conversation_history="",
284
+ max_iter=1,
285
+ )
164
286
 
165
- answer = await retriever.get_completion("Who works at Figma?")
287
+ assert completion == "Generated answer"
288
+ # Verify conversation_history was passed as None when empty
289
+ call_kwargs = mock_generate.call_args[1]
290
+ assert call_kwargs.get("conversation_history") is None
291
+
292
+
293
+ @pytest.mark.asyncio
294
+ async def test_get_completion_without_context(mock_edge):
295
+ """Test get_completion retrieves context when not provided."""
296
+ mock_graph_engine = AsyncMock()
297
+ mock_graph_engine.is_empty = AsyncMock(return_value=False)
298
+
299
+ retriever = GraphCompletionCotRetriever()
300
+
301
+ with (
302
+ patch(
303
+ "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
304
+ return_value=mock_graph_engine,
305
+ ),
306
+ patch(
307
+ "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
308
+ return_value=[mock_edge],
309
+ ),
310
+ patch(
311
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
312
+ return_value="Resolved context",
313
+ ),
314
+ patch(
315
+ "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
316
+ return_value="Generated answer",
317
+ ),
318
+ patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
319
+ patch(
320
+ "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
321
+ return_value="Generated answer",
322
+ ),
323
+ patch(
324
+ "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
325
+ return_value="Rendered prompt",
326
+ ),
327
+ patch(
328
+ "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
329
+ return_value="System prompt",
330
+ ),
331
+ patch.object(
332
+ LLMGateway,
333
+ "acreate_structured_output",
334
+ new_callable=AsyncMock,
335
+ side_effect=["validation_result", "followup_question"],
336
+ ),
337
+ patch(
338
+ "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
339
+ ) as mock_cache_config,
340
+ ):
341
+ mock_config = MagicMock()
342
+ mock_config.caching = False
343
+ mock_cache_config.return_value = mock_config
344
+
345
+ completion = await retriever.get_completion("test query", max_iter=1)
346
+
347
+ assert isinstance(completion, list)
348
+ assert len(completion) == 1
349
+ assert completion[0] == "Generated answer"
350
+
351
+
352
+ @pytest.mark.asyncio
353
+ async def test_get_completion_with_provided_context(mock_edge):
354
+ """Test get_completion uses provided context."""
355
+ retriever = GraphCompletionCotRetriever()
356
+
357
+ with (
358
+ patch(
359
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
360
+ return_value="Resolved context",
361
+ ),
362
+ patch(
363
+ "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
364
+ return_value="Generated answer",
365
+ ),
366
+ patch(
367
+ "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
368
+ ) as mock_cache_config,
369
+ ):
370
+ mock_config = MagicMock()
371
+ mock_config.caching = False
372
+ mock_cache_config.return_value = mock_config
373
+
374
+ completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1)
375
+
376
+ assert isinstance(completion, list)
377
+ assert len(completion) == 1
378
+ assert completion[0] == "Generated answer"
379
+
380
+
381
+ @pytest.mark.asyncio
382
+ async def test_get_completion_with_session(mock_edge):
383
+ """Test get_completion with session caching enabled."""
384
+ mock_graph_engine = AsyncMock()
385
+ mock_graph_engine.is_empty = AsyncMock(return_value=False)
386
+
387
+ retriever = GraphCompletionCotRetriever()
388
+
389
+ mock_user = MagicMock()
390
+ mock_user.id = "test-user-id"
391
+
392
+ with (
393
+ patch(
394
+ "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
395
+ return_value=mock_graph_engine,
396
+ ),
397
+ patch(
398
+ "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
399
+ return_value=[mock_edge],
400
+ ),
401
+ patch(
402
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
403
+ return_value="Resolved context",
404
+ ),
405
+ patch(
406
+ "cognee.modules.retrieval.graph_completion_cot_retriever.get_conversation_history",
407
+ return_value="Previous conversation",
408
+ ),
409
+ patch(
410
+ "cognee.modules.retrieval.graph_completion_cot_retriever.summarize_text",
411
+ return_value="Context summary",
412
+ ),
413
+ patch(
414
+ "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
415
+ return_value="Generated answer",
416
+ ),
417
+ patch(
418
+ "cognee.modules.retrieval.graph_completion_cot_retriever.save_conversation_history",
419
+ ) as mock_save,
420
+ patch(
421
+ "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
422
+ ) as mock_cache_config,
423
+ patch(
424
+ "cognee.modules.retrieval.graph_completion_cot_retriever.session_user"
425
+ ) as mock_session_user,
426
+ ):
427
+ mock_config = MagicMock()
428
+ mock_config.caching = True
429
+ mock_cache_config.return_value = mock_config
430
+ mock_session_user.get.return_value = mock_user
431
+
432
+ completion = await retriever.get_completion(
433
+ "test query", session_id="test_session", max_iter=1
434
+ )
166
435
 
167
- assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
168
- assert all(isinstance(item, str) and item.strip() for item in answer), (
169
- "Answer must contain only non-empty strings"
436
+ assert isinstance(completion, list)
437
+ assert len(completion) == 1
438
+ assert completion[0] == "Generated answer"
439
+ mock_save.assert_awaited_once()
440
+
441
+
442
+ @pytest.mark.asyncio
443
+ async def test_get_completion_with_save_interaction(mock_edge):
444
+ """Test get_completion with save_interaction enabled."""
445
+ mock_graph_engine = AsyncMock()
446
+ mock_graph_engine.is_empty = AsyncMock(return_value=False)
447
+ mock_graph_engine.add_edges = AsyncMock()
448
+
449
+ retriever = GraphCompletionCotRetriever(save_interaction=True)
450
+
451
+ mock_node1 = MagicMock()
452
+ mock_node2 = MagicMock()
453
+ mock_edge.node1 = mock_node1
454
+ mock_edge.node2 = mock_node2
455
+
456
+ with (
457
+ patch(
458
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
459
+ return_value="Resolved context",
460
+ ),
461
+ patch(
462
+ "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
463
+ return_value="Generated answer",
464
+ ),
465
+ patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
466
+ patch(
467
+ "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
468
+ return_value="Generated answer",
469
+ ),
470
+ patch(
471
+ "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
472
+ return_value="Rendered prompt",
473
+ ),
474
+ patch(
475
+ "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
476
+ return_value="System prompt",
477
+ ),
478
+ patch.object(
479
+ LLMGateway,
480
+ "acreate_structured_output",
481
+ new_callable=AsyncMock,
482
+ side_effect=["validation_result", "followup_question"],
483
+ ),
484
+ patch(
485
+ "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node",
486
+ side_effect=[
487
+ UUID("550e8400-e29b-41d4-a716-446655440000"),
488
+ UUID("550e8400-e29b-41d4-a716-446655440001"),
489
+ ],
490
+ ),
491
+ patch(
492
+ "cognee.modules.retrieval.graph_completion_retriever.add_data_points",
493
+ ) as mock_add_data,
494
+ patch(
495
+ "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
496
+ ) as mock_cache_config,
497
+ ):
498
+ mock_config = MagicMock()
499
+ mock_config.caching = False
500
+ mock_cache_config.return_value = mock_config
501
+
502
+ # Pass context so save_interaction condition is met
503
+ completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1)
504
+
505
+ assert isinstance(completion, list)
506
+ assert len(completion) == 1
507
+ mock_add_data.assert_awaited_once()
508
+
509
+
510
+ @pytest.mark.asyncio
511
+ async def test_get_completion_with_response_model(mock_edge):
512
+ """Test get_completion with custom response model."""
513
+ from pydantic import BaseModel
514
+
515
+ class TestModel(BaseModel):
516
+ answer: str
517
+
518
+ mock_graph_engine = AsyncMock()
519
+ mock_graph_engine.is_empty = AsyncMock(return_value=False)
520
+
521
+ retriever = GraphCompletionCotRetriever()
522
+
523
+ with (
524
+ patch(
525
+ "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
526
+ return_value=mock_graph_engine,
527
+ ),
528
+ patch(
529
+ "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
530
+ return_value=[mock_edge],
531
+ ),
532
+ patch(
533
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
534
+ return_value="Resolved context",
535
+ ),
536
+ patch(
537
+ "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
538
+ return_value=TestModel(answer="Test answer"),
539
+ ),
540
+ patch(
541
+ "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
542
+ ) as mock_cache_config,
543
+ ):
544
+ mock_config = MagicMock()
545
+ mock_config.caching = False
546
+ mock_cache_config.return_value = mock_config
547
+
548
+ completion = await retriever.get_completion(
549
+ "test query", response_model=TestModel, max_iter=1
170
550
  )
551
+
552
+ assert isinstance(completion, list)
553
+ assert len(completion) == 1
554
+ assert isinstance(completion[0], TestModel)
555
+
556
+
557
+ @pytest.mark.asyncio
558
+ async def test_get_completion_with_session_no_user_id(mock_edge):
559
+ """Test get_completion with session config but no user ID."""
560
+ mock_graph_engine = AsyncMock()
561
+ mock_graph_engine.is_empty = AsyncMock(return_value=False)
562
+
563
+ retriever = GraphCompletionCotRetriever()
564
+
565
+ with (
566
+ patch(
567
+ "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
568
+ return_value=mock_graph_engine,
569
+ ),
570
+ patch(
571
+ "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
572
+ return_value=[mock_edge],
573
+ ),
574
+ patch(
575
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
576
+ return_value="Resolved context",
577
+ ),
578
+ patch(
579
+ "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
580
+ return_value="Generated answer",
581
+ ),
582
+ patch(
583
+ "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
584
+ ) as mock_cache_config,
585
+ patch(
586
+ "cognee.modules.retrieval.graph_completion_cot_retriever.session_user"
587
+ ) as mock_session_user,
588
+ ):
589
+ mock_config = MagicMock()
590
+ mock_config.caching = True
591
+ mock_cache_config.return_value = mock_config
592
+ mock_session_user.get.return_value = None # No user
593
+
594
+ completion = await retriever.get_completion("test query", max_iter=1)
595
+
596
+ assert isinstance(completion, list)
597
+ assert len(completion) == 1
598
+
599
+
600
+ @pytest.mark.asyncio
601
+ async def test_get_completion_with_save_interaction_no_context(mock_edge):
602
+ """Test get_completion with save_interaction but no context provided."""
603
+ retriever = GraphCompletionCotRetriever(save_interaction=True)
604
+
605
+ with (
606
+ patch(
607
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
608
+ return_value="Resolved context",
609
+ ),
610
+ patch(
611
+ "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
612
+ return_value="Generated answer",
613
+ ),
614
+ patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
615
+ patch(
616
+ "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
617
+ return_value="Generated answer",
618
+ ),
619
+ patch(
620
+ "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
621
+ return_value="Rendered prompt",
622
+ ),
623
+ patch(
624
+ "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
625
+ return_value="System prompt",
626
+ ),
627
+ patch.object(
628
+ LLMGateway,
629
+ "acreate_structured_output",
630
+ new_callable=AsyncMock,
631
+ side_effect=["validation_result", "followup_question"],
632
+ ),
633
+ patch(
634
+ "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
635
+ ) as mock_cache_config,
636
+ ):
637
+ mock_config = MagicMock()
638
+ mock_config.caching = False
639
+ mock_cache_config.return_value = mock_config
640
+
641
+ completion = await retriever.get_completion("test query", context=None, max_iter=1)
642
+
643
+ assert isinstance(completion, list)
644
+ assert len(completion) == 1
645
+
646
+
647
+ @pytest.mark.asyncio
648
+ async def test_as_answer_text_with_typeerror():
649
+ """Test _as_answer_text handles TypeError when json.dumps fails."""
650
+ non_serializable = {1, 2, 3}
651
+
652
+ result = _as_answer_text(non_serializable)
653
+
654
+ assert isinstance(result, str)
655
+ assert result == str(non_serializable)
656
+
657
+
658
+ @pytest.mark.asyncio
659
+ async def test_as_answer_text_with_string():
660
+ """Test _as_answer_text with string input."""
661
+ result = _as_answer_text("test string")
662
+ assert result == "test string"
663
+
664
+
665
+ @pytest.mark.asyncio
666
+ async def test_as_answer_text_with_dict():
667
+ """Test _as_answer_text with dictionary input."""
668
+ test_dict = {"key": "value", "number": 42}
669
+ result = _as_answer_text(test_dict)
670
+ assert isinstance(result, str)
671
+ assert "key" in result
672
+ assert "value" in result
673
+
674
+
675
+ @pytest.mark.asyncio
676
+ async def test_as_answer_text_with_basemodel():
677
+ """Test _as_answer_text with Pydantic BaseModel input."""
678
+ from pydantic import BaseModel
679
+
680
+ class TestModel(BaseModel):
681
+ answer: str
682
+
683
+ test_model = TestModel(answer="test answer")
684
+ result = _as_answer_text(test_model)
685
+
686
+ assert isinstance(result, str)
687
+ assert "[Structured Response]" in result
688
+ assert "test answer" in result