cognee 0.5.1__py3-none-any.whl → 0.5.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. cognee/api/v1/add/add.py +2 -1
  2. cognee/api/v1/datasets/routers/get_datasets_router.py +1 -0
  3. cognee/api/v1/memify/routers/get_memify_router.py +1 -0
  4. cognee/api/v1/search/search.py +0 -4
  5. cognee/infrastructure/databases/relational/config.py +16 -1
  6. cognee/infrastructure/databases/relational/create_relational_engine.py +13 -3
  7. cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +24 -2
  8. cognee/infrastructure/databases/vector/create_vector_engine.py +9 -2
  9. cognee/infrastructure/llm/LLMGateway.py +0 -13
  10. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -12
  11. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +31 -25
  12. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +132 -7
  13. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +5 -5
  14. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +2 -6
  15. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +58 -13
  16. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +0 -1
  17. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -131
  18. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py +10 -0
  19. cognee/modules/data/models/Data.py +2 -1
  20. cognee/modules/retrieval/triplet_retriever.py +1 -1
  21. cognee/modules/retrieval/utils/brute_force_triplet_search.py +0 -18
  22. cognee/modules/search/methods/search.py +18 -25
  23. cognee/tasks/ingestion/data_item.py +8 -0
  24. cognee/tasks/ingestion/ingest_data.py +12 -1
  25. cognee/tasks/ingestion/save_data_item_to_storage.py +5 -0
  26. cognee/tests/integration/retrieval/test_chunks_retriever.py +252 -0
  27. cognee/tests/integration/retrieval/test_graph_completion_retriever.py +268 -0
  28. cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +226 -0
  29. cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +218 -0
  30. cognee/tests/integration/retrieval/test_rag_completion_retriever.py +254 -0
  31. cognee/tests/{unit/modules/retrieval/structured_output_test.py → integration/retrieval/test_structured_output.py} +87 -77
  32. cognee/tests/integration/retrieval/test_summaries_retriever.py +184 -0
  33. cognee/tests/integration/retrieval/test_temporal_retriever.py +306 -0
  34. cognee/tests/integration/retrieval/test_triplet_retriever.py +35 -0
  35. cognee/tests/test_custom_data_label.py +68 -0
  36. cognee/tests/test_search_db.py +334 -181
  37. cognee/tests/unit/eval_framework/benchmark_adapters_test.py +25 -0
  38. cognee/tests/unit/eval_framework/corpus_builder_test.py +33 -4
  39. cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py +69 -0
  40. cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +181 -199
  41. cognee/tests/unit/modules/retrieval/conversation_history_test.py +338 -0
  42. cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +454 -162
  43. cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +674 -156
  44. cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +625 -200
  45. cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +319 -203
  46. cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +189 -155
  47. cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +539 -58
  48. cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +218 -9
  49. cognee/tests/unit/modules/retrieval/test_completion.py +343 -0
  50. cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py +157 -0
  51. cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py +312 -0
  52. cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +246 -0
  53. {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/METADATA +1 -1
  54. {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/RECORD +58 -45
  55. cognee/tests/unit/modules/search/test_search.py +0 -100
  56. {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/WHEEL +0 -0
  57. {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/entry_points.txt +0 -0
  58. {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/LICENSE +0 -0
  59. {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/NOTICE.md +0 -0
@@ -0,0 +1,157 @@
1
+ import pytest
2
+ from unittest.mock import AsyncMock, patch, MagicMock
3
+
4
+ from cognee.modules.retrieval.graph_summary_completion_retriever import (
5
+ GraphSummaryCompletionRetriever,
6
+ )
7
+ from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
8
+
9
+
10
@pytest.fixture
def mock_edge():
    """Provide a MagicMock constrained to the ``Edge`` interface."""
    return MagicMock(spec=Edge)
15
+
16
+
17
class TestGraphSummaryCompletionRetriever:
    """Unit tests for GraphSummaryCompletionRetriever construction and edge summarisation."""

    # Patch targets: the parent-class edge resolver the summary retriever calls
    # through ``super()``, and ``summarize_text`` as looked up by the module under test.
    _SUPER_RESOLVE = (
        "cognee.modules.retrieval.graph_completion_retriever."
        "GraphCompletionRetriever.resolve_edges_to_text"
    )
    _SUMMARIZE = "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text"

    @staticmethod
    def _patch_async(target, return_value):
        """Return a patcher replacing *target* with an AsyncMock resolving to *return_value*."""
        return patch(target, new_callable=AsyncMock, return_value=return_value)

    @pytest.mark.asyncio
    async def test_init_defaults(self):
        """Test GraphSummaryCompletionRetriever initialization with defaults."""
        retriever = GraphSummaryCompletionRetriever()

        assert retriever.summarize_prompt_path == "summarize_search_results.txt"
        assert retriever.user_prompt_path == "graph_context_for_question.txt"
        assert retriever.system_prompt_path == "answer_simple_question.txt"
        assert retriever.top_k == 5
        assert retriever.save_interaction is False

    @pytest.mark.asyncio
    async def test_init_custom_params(self):
        """Test GraphSummaryCompletionRetriever initialization with custom parameters."""
        retriever = GraphSummaryCompletionRetriever(
            user_prompt_path="custom_user.txt",
            system_prompt_path="custom_system.txt",
            summarize_prompt_path="custom_summarize.txt",
            system_prompt="Custom system prompt",
            top_k=10,
            save_interaction=True,
            wide_search_top_k=200,
            triplet_distance_penalty=2.5,
        )

        assert retriever.summarize_prompt_path == "custom_summarize.txt"
        assert retriever.user_prompt_path == "custom_user.txt"
        assert retriever.system_prompt_path == "custom_system.txt"
        assert retriever.top_k == 10
        assert retriever.save_interaction is True

    @pytest.mark.asyncio
    async def test_resolve_edges_to_text_calls_super_and_summarizes(self, mock_edge):
        """Test resolve_edges_to_text calls super method and then summarizes."""
        retriever = GraphSummaryCompletionRetriever(
            summarize_prompt_path="custom_summarize.txt",
            system_prompt="Custom system prompt",
        )

        with (
            self._patch_async(self._SUPER_RESOLVE, "Resolved edges text") as mock_super_resolve,
            self._patch_async(self._SUMMARIZE, "Summarized text") as mock_summarize,
        ):
            result = await retriever.resolve_edges_to_text([mock_edge])

        assert result == "Summarized text"
        mock_super_resolve.assert_awaited_once_with([mock_edge])
        mock_summarize.assert_awaited_once_with(
            "Resolved edges text",
            "custom_summarize.txt",
            "Custom system prompt",
        )

    @pytest.mark.asyncio
    async def test_resolve_edges_to_text_with_default_system_prompt(self, mock_edge):
        """Test resolve_edges_to_text uses None for system_prompt when not provided."""
        retriever = GraphSummaryCompletionRetriever()

        with (
            self._patch_async(self._SUPER_RESOLVE, "Resolved edges text") as mock_super_resolve,
            self._patch_async(self._SUMMARIZE, "Summarized text") as mock_summarize,
        ):
            result = await retriever.resolve_edges_to_text([mock_edge])

        assert result == "Summarized text"
        mock_super_resolve.assert_awaited_once_with([mock_edge])
        mock_summarize.assert_awaited_once_with(
            "Resolved edges text",
            "summarize_search_results.txt",
            None,
        )

    @pytest.mark.asyncio
    async def test_resolve_edges_to_text_with_empty_edges(self):
        """Test resolve_edges_to_text forwards an empty edge list and still summarizes."""
        retriever = GraphSummaryCompletionRetriever()

        with (
            self._patch_async(self._SUPER_RESOLVE, "") as mock_super_resolve,
            self._patch_async(self._SUMMARIZE, "Empty summary") as mock_summarize,
        ):
            result = await retriever.resolve_edges_to_text([])

        assert result == "Empty summary"
        mock_super_resolve.assert_awaited_once_with([])
        mock_summarize.assert_awaited_once_with(
            "",
            "summarize_search_results.txt",
            None,
        )

    @pytest.mark.asyncio
    async def test_resolve_edges_to_text_with_multiple_edges(self, mock_edge):
        """Test resolve_edges_to_text forwards every edge to the parent resolver."""
        retriever = GraphSummaryCompletionRetriever()

        edges = [mock_edge, MagicMock(spec=Edge), MagicMock(spec=Edge)]

        with (
            self._patch_async(
                self._SUPER_RESOLVE, "Multiple edges resolved text"
            ) as mock_super_resolve,
            self._patch_async(self._SUMMARIZE, "Multiple edges summarized") as mock_summarize,
        ):
            result = await retriever.resolve_edges_to_text(list(edges))

        assert result == "Multiple edges summarized"
        # The whole list, in order, must reach the parent resolver.
        mock_super_resolve.assert_awaited_once_with(edges)
        mock_summarize.assert_awaited_once_with(
            "Multiple edges resolved text",
            "summarize_search_results.txt",
            None,
        )
@@ -0,0 +1,312 @@
1
+ import pytest
2
+ from unittest.mock import AsyncMock, patch, MagicMock
3
+ from uuid import UUID, NAMESPACE_OID, uuid5
4
+
5
+ from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
6
+ from cognee.modules.retrieval.utils.models import UserFeedbackEvaluation, UserFeedbackSentiment
7
+ from cognee.modules.engine.models import NodeSet
8
+
9
+
10
@pytest.fixture
def mock_feedback_evaluation():
    """Provide a stand-in ``UserFeedbackEvaluation`` with positive sentiment and score 4.5.

    Tests may overwrite ``evaluation.value`` and ``score`` to model other outcomes.
    """
    sentiment = MagicMock()
    sentiment.value = "positive"

    fake_evaluation = MagicMock(spec=UserFeedbackEvaluation)
    fake_evaluation.evaluation = sentiment
    fake_evaluation.score = 4.5
    return fake_evaluation
18
+
19
+
20
@pytest.fixture
def mock_graph_engine():
    """Provide an async graph-engine double with no prior user interactions.

    ``get_last_user_interaction_ids`` resolves to an empty list by default (tests
    override it when they need interactions); ``add_edges`` and
    ``apply_feedback_weight`` are bare awaitable mocks so their calls can be asserted.
    """
    graph_engine = AsyncMock()
    graph_engine.configure_mock(
        get_last_user_interaction_ids=AsyncMock(return_value=[]),
        add_edges=AsyncMock(),
        apply_feedback_weight=AsyncMock(),
    )
    return graph_engine
28
+
29
+
30
class TestUserQAFeedback:
    """Unit tests for UserQAFeedback.add_feedback with the LLM, graph engine and storage mocked."""

    @staticmethod
    def _patchers(evaluation, graph_engine):
        """Build the patchers that isolate UserQAFeedback from external services.

        Returns a 4-tuple ``(llm, graph_engine, add_data_points, index_graph_edges)``;
        entering a patcher yields its mock. The LLM structured-output call resolves to
        *evaluation* and ``get_graph_engine`` is patched to return *graph_engine*.
        ``index_graph_edges`` is patched in every test so that no test can reach the
        real indexer, even one that expects no edges to be created.
        """
        module = "cognee.modules.retrieval.user_qa_feedback"
        return (
            patch(
                f"{module}.LLMGateway.acreate_structured_output",
                new_callable=AsyncMock,
                return_value=evaluation,
            ),
            patch(f"{module}.get_graph_engine", return_value=graph_engine),
            patch(f"{module}.add_data_points", new_callable=AsyncMock),
            patch(f"{module}.index_graph_edges", new_callable=AsyncMock),
        )

    @pytest.mark.asyncio
    async def test_init_default(self):
        """Test UserQAFeedback initialization with default last_k."""
        retriever = UserQAFeedback()
        assert retriever.last_k == 1

    @pytest.mark.asyncio
    async def test_init_custom_last_k(self):
        """Test UserQAFeedback initialization with custom last_k."""
        retriever = UserQAFeedback(last_k=5)
        assert retriever.last_k == 5

    @pytest.mark.asyncio
    async def test_add_feedback_success_with_relationships(
        self, mock_feedback_evaluation, mock_graph_engine
    ):
        """Test add_feedback links the feedback node to every recent interaction."""
        interaction_ids = [
            str(UUID("550e8400-e29b-41d4-a716-446655440000")),
            str(UUID("550e8400-e29b-41d4-a716-446655440001")),
        ]
        # Hand the code under test its own copy so the expectations below stay intact.
        mock_graph_engine.get_last_user_interaction_ids = AsyncMock(
            return_value=list(interaction_ids)
        )
        feedback_text = "This answer was helpful"

        llm_p, engine_p, add_p, index_p = self._patchers(
            mock_feedback_evaluation, mock_graph_engine
        )
        with llm_p, engine_p, add_p as mock_add_data, index_p as mock_index_edges:
            retriever = UserQAFeedback(last_k=2)
            result = await retriever.add_feedback(feedback_text)

        assert result == [feedback_text]
        mock_add_data.assert_awaited_once()
        mock_graph_engine.add_edges.assert_awaited_once()
        mock_index_edges.assert_awaited_once()
        mock_graph_engine.apply_feedback_weight.assert_awaited_once()

        # Verify every edge (not only the first) runs feedback -> interaction, in the
        # order the interaction IDs were returned.
        feedback_id = uuid5(NAMESPACE_OID, name=feedback_text)
        edges = mock_graph_engine.add_edges.call_args[0][0]
        assert len(edges) == len(interaction_ids)
        for edge, interaction_id in zip(edges, interaction_ids):
            assert edge[0] == feedback_id
            assert edge[1] == UUID(interaction_id)
            assert edge[2] == "gives_feedback_to"
            assert edge[3]["relationship_name"] == "gives_feedback_to"
            assert edge[3]["ontology_valid"] is False

        # Verify apply_feedback_weight was called with every interaction node ID.
        weighted_ids = mock_graph_engine.apply_feedback_weight.call_args[1]["node_ids"]
        assert len(weighted_ids) == 2
        for interaction_id in interaction_ids:
            assert interaction_id in weighted_ids

    @pytest.mark.asyncio
    async def test_add_feedback_success_no_relationships(
        self, mock_feedback_evaluation, mock_graph_engine
    ):
        """Test add_feedback successfully creates feedback without relationships."""
        mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[])
        feedback_text = "This answer was helpful"

        llm_p, engine_p, add_p, index_p = self._patchers(
            mock_feedback_evaluation, mock_graph_engine
        )
        with llm_p, engine_p, add_p as mock_add_data, index_p as mock_index_edges:
            retriever = UserQAFeedback(last_k=1)
            result = await retriever.add_feedback(feedback_text)

        assert result == [feedback_text]
        mock_add_data.assert_awaited_once()
        # No interactions to link to: no edges, no indexing, no weight update.
        mock_graph_engine.add_edges.assert_not_awaited()
        mock_index_edges.assert_not_awaited()
        mock_graph_engine.apply_feedback_weight.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_add_feedback_creates_correct_feedback_node(
        self, mock_feedback_evaluation, mock_graph_engine
    ):
        """Test add_feedback creates CogneeUserFeedback with correct attributes."""
        mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[])
        feedback_text = "This was a negative experience"
        mock_feedback_evaluation.evaluation.value = "negative"
        mock_feedback_evaluation.score = -3.0

        llm_p, engine_p, add_p, index_p = self._patchers(
            mock_feedback_evaluation, mock_graph_engine
        )
        with llm_p, engine_p, add_p as mock_add_data, index_p:
            retriever = UserQAFeedback()
            await retriever.add_feedback(feedback_text)

        # Verify add_data_points was called with the expected feedback node.
        data_points = mock_add_data.call_args[1]["data_points"]
        assert len(data_points) == 1
        feedback_node = data_points[0]
        assert feedback_node.id == uuid5(NAMESPACE_OID, name=feedback_text)
        assert feedback_node.feedback == feedback_text
        assert feedback_node.sentiment == "negative"
        assert feedback_node.score == -3.0
        assert isinstance(feedback_node.belongs_to_set, NodeSet)
        assert feedback_node.belongs_to_set.name == "UserQAFeedbacks"

    @pytest.mark.asyncio
    async def test_add_feedback_calls_llm_with_correct_prompt(
        self, mock_feedback_evaluation, mock_graph_engine
    ):
        """Test add_feedback calls LLM with correct sentiment analysis prompt."""
        mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[])
        feedback_text = "Great answer!"

        llm_p, engine_p, add_p, index_p = self._patchers(
            mock_feedback_evaluation, mock_graph_engine
        )
        with llm_p as mock_llm, engine_p, add_p, index_p:
            retriever = UserQAFeedback()
            await retriever.add_feedback(feedback_text)

        mock_llm.assert_awaited_once()
        call_kwargs = mock_llm.call_args[1]
        assert call_kwargs["text_input"] == feedback_text
        assert "sentiment analysis assistant" in call_kwargs["system_prompt"]
        assert call_kwargs["response_model"] == UserFeedbackEvaluation

    @pytest.mark.asyncio
    async def test_add_feedback_uses_last_k_parameter(
        self, mock_feedback_evaluation, mock_graph_engine
    ):
        """Test add_feedback uses last_k parameter when getting interaction IDs."""
        mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[])
        feedback_text = "Test feedback"

        llm_p, engine_p, add_p, index_p = self._patchers(
            mock_feedback_evaluation, mock_graph_engine
        )
        with llm_p, engine_p, add_p, index_p:
            retriever = UserQAFeedback(last_k=5)
            await retriever.add_feedback(feedback_text)

        mock_graph_engine.get_last_user_interaction_ids.assert_awaited_once_with(limit=5)

    @pytest.mark.asyncio
    async def test_add_feedback_with_single_interaction(
        self, mock_feedback_evaluation, mock_graph_engine
    ):
        """Test add_feedback with single interaction ID."""
        interaction_id = str(UUID("550e8400-e29b-41d4-a716-446655440000"))
        mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[interaction_id])
        feedback_text = "Test feedback"

        llm_p, engine_p, add_p, index_p = self._patchers(
            mock_feedback_evaluation, mock_graph_engine
        )
        with llm_p, engine_p, add_p, index_p:
            retriever = UserQAFeedback()
            result = await retriever.add_feedback(feedback_text)

        assert result == [feedback_text]
        # Assert the call happened before inspecting call_args, so a regression fails
        # as an assertion rather than a TypeError on ``None``.
        mock_graph_engine.add_edges.assert_awaited_once()
        edges = mock_graph_engine.add_edges.call_args[0][0]
        assert len(edges) == 1
        assert edges[0][1] == UUID(interaction_id)

    @pytest.mark.asyncio
    async def test_add_feedback_applies_weight_correctly(
        self, mock_feedback_evaluation, mock_graph_engine
    ):
        """Test add_feedback applies feedback weight with correct score."""
        interaction_id = str(UUID("550e8400-e29b-41d4-a716-446655440000"))
        mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[interaction_id])
        mock_feedback_evaluation.score = 4.5
        feedback_text = "Positive feedback"

        llm_p, engine_p, add_p, index_p = self._patchers(
            mock_feedback_evaluation, mock_graph_engine
        )
        with llm_p, engine_p, add_p, index_p:
            retriever = UserQAFeedback()
            await retriever.add_feedback(feedback_text)

        mock_graph_engine.apply_feedback_weight.assert_awaited_once_with(
            node_ids=[interaction_id], weight=4.5
        )