cognee 0.5.1__py3-none-any.whl → 0.5.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/api/v1/add/add.py +2 -1
- cognee/api/v1/datasets/routers/get_datasets_router.py +1 -0
- cognee/api/v1/memify/routers/get_memify_router.py +1 -0
- cognee/api/v1/search/search.py +0 -4
- cognee/infrastructure/databases/relational/config.py +16 -1
- cognee/infrastructure/databases/relational/create_relational_engine.py +13 -3
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +24 -2
- cognee/infrastructure/databases/vector/create_vector_engine.py +9 -2
- cognee/infrastructure/llm/LLMGateway.py +0 -13
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -12
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +31 -25
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +132 -7
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +5 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +2 -6
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +58 -13
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +0 -1
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -131
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py +10 -0
- cognee/modules/data/models/Data.py +2 -1
- cognee/modules/retrieval/triplet_retriever.py +1 -1
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +0 -18
- cognee/modules/search/methods/search.py +18 -25
- cognee/tasks/ingestion/data_item.py +8 -0
- cognee/tasks/ingestion/ingest_data.py +12 -1
- cognee/tasks/ingestion/save_data_item_to_storage.py +5 -0
- cognee/tests/integration/retrieval/test_chunks_retriever.py +252 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever.py +268 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +226 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +218 -0
- cognee/tests/integration/retrieval/test_rag_completion_retriever.py +254 -0
- cognee/tests/{unit/modules/retrieval/structured_output_test.py → integration/retrieval/test_structured_output.py} +87 -77
- cognee/tests/integration/retrieval/test_summaries_retriever.py +184 -0
- cognee/tests/integration/retrieval/test_temporal_retriever.py +306 -0
- cognee/tests/integration/retrieval/test_triplet_retriever.py +35 -0
- cognee/tests/test_custom_data_label.py +68 -0
- cognee/tests/test_search_db.py +334 -181
- cognee/tests/unit/eval_framework/benchmark_adapters_test.py +25 -0
- cognee/tests/unit/eval_framework/corpus_builder_test.py +33 -4
- cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py +69 -0
- cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +181 -199
- cognee/tests/unit/modules/retrieval/conversation_history_test.py +338 -0
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +454 -162
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +674 -156
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +625 -200
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +319 -203
- cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +189 -155
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +539 -58
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +218 -9
- cognee/tests/unit/modules/retrieval/test_completion.py +343 -0
- cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py +157 -0
- cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py +312 -0
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +246 -0
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/METADATA +1 -1
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/RECORD +58 -45
- cognee/tests/unit/modules/search/test_search.py +0 -100
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/WHEEL +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/entry_points.txt +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/NOTICE.md +0 -0
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
from unittest.mock import AsyncMock, patch, MagicMock
|
|
3
|
+
|
|
4
|
+
from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
|
5
|
+
GraphSummaryCompletionRetriever,
|
|
6
|
+
)
|
|
7
|
+
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@pytest.fixture
def mock_edge():
    """Provide a spec'd Edge test double for the retriever tests."""
    return MagicMock(spec=Edge)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class TestGraphSummaryCompletionRetriever:
    """Unit tests for GraphSummaryCompletionRetriever construction and summarization."""

    # Patch targets used by the resolve_edges_to_text tests (extracted to
    # avoid repeating the long dotted paths in every test).
    _SUPER_RESOLVE = (
        "cognee.modules.retrieval.graph_completion_retriever."
        "GraphCompletionRetriever.resolve_edges_to_text"
    )
    _SUMMARIZE = "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text"

    @pytest.mark.asyncio
    async def test_init_defaults(self):
        """A default-constructed retriever exposes the documented defaults."""
        sut = GraphSummaryCompletionRetriever()

        assert sut.summarize_prompt_path == "summarize_search_results.txt"
        assert sut.user_prompt_path == "graph_context_for_question.txt"
        assert sut.system_prompt_path == "answer_simple_question.txt"
        assert sut.top_k == 5
        assert sut.save_interaction is False

    @pytest.mark.asyncio
    async def test_init_custom_params(self):
        """Constructor keyword arguments are stored on the instance."""
        sut = GraphSummaryCompletionRetriever(
            user_prompt_path="custom_user.txt",
            system_prompt_path="custom_system.txt",
            summarize_prompt_path="custom_summarize.txt",
            system_prompt="Custom system prompt",
            top_k=10,
            save_interaction=True,
            wide_search_top_k=200,
            triplet_distance_penalty=2.5,
        )

        assert sut.summarize_prompt_path == "custom_summarize.txt"
        assert sut.user_prompt_path == "custom_user.txt"
        assert sut.system_prompt_path == "custom_system.txt"
        assert sut.top_k == 10
        assert sut.save_interaction is True

    @pytest.mark.asyncio
    async def test_resolve_edges_to_text_calls_super_and_summarizes(self, mock_edge):
        """resolve_edges_to_text delegates to the parent, then summarizes its output."""
        sut = GraphSummaryCompletionRetriever(
            summarize_prompt_path="custom_summarize.txt",
            system_prompt="Custom system prompt",
        )

        with patch(
            self._SUPER_RESOLVE,
            new_callable=AsyncMock,
            return_value="Resolved edges text",
        ) as super_resolve:
            with patch(
                self._SUMMARIZE,
                new_callable=AsyncMock,
                return_value="Summarized text",
            ) as summarize:
                outcome = await sut.resolve_edges_to_text([mock_edge])

        assert outcome == "Summarized text"
        super_resolve.assert_awaited_once_with([mock_edge])
        summarize.assert_awaited_once_with(
            "Resolved edges text",
            "custom_summarize.txt",
            "Custom system prompt",
        )

    @pytest.mark.asyncio
    async def test_resolve_edges_to_text_with_default_system_prompt(self, mock_edge):
        """Without an explicit system_prompt, None is forwarded to summarize_text."""
        sut = GraphSummaryCompletionRetriever()

        with patch(
            self._SUPER_RESOLVE,
            new_callable=AsyncMock,
            return_value="Resolved edges text",
        ):
            with patch(
                self._SUMMARIZE,
                new_callable=AsyncMock,
                return_value="Summarized text",
            ) as summarize:
                await sut.resolve_edges_to_text([mock_edge])

        summarize.assert_awaited_once_with(
            "Resolved edges text",
            "summarize_search_results.txt",
            None,
        )

    @pytest.mark.asyncio
    async def test_resolve_edges_to_text_with_empty_edges(self):
        """An empty edge list still flows through summarization unchanged."""
        sut = GraphSummaryCompletionRetriever()

        with patch(
            self._SUPER_RESOLVE,
            new_callable=AsyncMock,
            return_value="",
        ):
            with patch(
                self._SUMMARIZE,
                new_callable=AsyncMock,
                return_value="Empty summary",
            ) as summarize:
                outcome = await sut.resolve_edges_to_text([])

        assert outcome == "Empty summary"
        summarize.assert_awaited_once_with(
            "",
            "summarize_search_results.txt",
            None,
        )

    @pytest.mark.asyncio
    async def test_resolve_edges_to_text_with_multiple_edges(self, mock_edge):
        """Several edges are resolved together and summarized once."""
        sut = GraphSummaryCompletionRetriever()
        second_edge = MagicMock(spec=Edge)
        third_edge = MagicMock(spec=Edge)

        with patch(
            self._SUPER_RESOLVE,
            new_callable=AsyncMock,
            return_value="Multiple edges resolved text",
        ):
            with patch(
                self._SUMMARIZE,
                new_callable=AsyncMock,
                return_value="Multiple edges summarized",
            ) as summarize:
                outcome = await sut.resolve_edges_to_text(
                    [mock_edge, second_edge, third_edge]
                )

        assert outcome == "Multiple edges summarized"
        summarize.assert_awaited_once_with(
            "Multiple edges resolved text",
            "summarize_search_results.txt",
            None,
        )
@@ -0,0 +1,312 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
from unittest.mock import AsyncMock, patch, MagicMock
|
|
3
|
+
from uuid import UUID, NAMESPACE_OID, uuid5
|
|
4
|
+
|
|
5
|
+
from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
|
|
6
|
+
from cognee.modules.retrieval.utils.models import UserFeedbackEvaluation, UserFeedbackSentiment
|
|
7
|
+
from cognee.modules.engine.models import NodeSet
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@pytest.fixture
def mock_feedback_evaluation():
    """Provide a UserFeedbackEvaluation double preset to a positive sentiment."""
    stub = MagicMock(spec=UserFeedbackEvaluation)
    stub.evaluation = MagicMock()
    stub.evaluation.value = "positive"
    stub.score = 4.5
    return stub
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@pytest.fixture
def mock_graph_engine():
    """Provide an async graph-engine double with the methods the feedback flow uses."""
    stub = AsyncMock()
    stub.apply_feedback_weight = AsyncMock()
    stub.add_edges = AsyncMock()
    stub.get_last_user_interaction_ids = AsyncMock(return_value=[])
    return stub
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class TestUserQAFeedback:
    """Unit tests for UserQAFeedback sentiment capture and graph wiring."""

    # Patch targets shared by the add_feedback tests (extracted to avoid
    # repeating the long dotted paths in every test).
    _LLM = "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output"
    _ENGINE = "cognee.modules.retrieval.user_qa_feedback.get_graph_engine"
    _ADD_POINTS = "cognee.modules.retrieval.user_qa_feedback.add_data_points"
    _INDEX_EDGES = "cognee.modules.retrieval.user_qa_feedback.index_graph_edges"

    @pytest.mark.asyncio
    async def test_init_default(self):
        """last_k defaults to 1."""
        assert UserQAFeedback().last_k == 1

    @pytest.mark.asyncio
    async def test_init_custom_last_k(self):
        """A custom last_k is stored on the instance."""
        assert UserQAFeedback(last_k=5).last_k == 5

    @pytest.mark.asyncio
    async def test_add_feedback_success_with_relationships(
        self, mock_feedback_evaluation, mock_graph_engine
    ):
        """add_feedback links the feedback node to every recent interaction."""
        first = str(UUID("550e8400-e29b-41d4-a716-446655440000"))
        second = str(UUID("550e8400-e29b-41d4-a716-446655440001"))
        mock_graph_engine.get_last_user_interaction_ids = AsyncMock(
            return_value=[first, second]
        )
        text = "This answer was helpful"

        with patch(self._LLM, new_callable=AsyncMock, return_value=mock_feedback_evaluation):
            with patch(self._ENGINE, return_value=mock_graph_engine):
                with patch(self._ADD_POINTS, new_callable=AsyncMock) as add_points:
                    with patch(self._INDEX_EDGES, new_callable=AsyncMock) as index_edges:
                        sut = UserQAFeedback(last_k=2)
                        outcome = await sut.add_feedback(text)

        assert outcome == [text]
        add_points.assert_awaited_once()
        mock_graph_engine.add_edges.assert_awaited_once()
        index_edges.assert_awaited_once()
        mock_graph_engine.apply_feedback_weight.assert_awaited_once()

        # The first edge tuple is (feedback_id, interaction_id, name, properties).
        edges = mock_graph_engine.add_edges.call_args[0][0]
        assert len(edges) == 2
        assert edges[0][0] == uuid5(NAMESPACE_OID, name=text)
        assert edges[0][1] == UUID(first)
        assert edges[0][2] == "gives_feedback_to"
        assert edges[0][3]["relationship_name"] == "gives_feedback_to"
        assert edges[0][3]["ontology_valid"] is False

        weighted_ids = mock_graph_engine.apply_feedback_weight.call_args[1]["node_ids"]
        assert len(weighted_ids) == 2
        assert first in weighted_ids
        assert second in weighted_ids

    @pytest.mark.asyncio
    async def test_add_feedback_success_no_relationships(
        self, mock_feedback_evaluation, mock_graph_engine
    ):
        """With no recent interactions, no edges are created or indexed."""
        mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[])
        text = "This answer was helpful"

        with patch(self._LLM, new_callable=AsyncMock, return_value=mock_feedback_evaluation):
            with patch(self._ENGINE, return_value=mock_graph_engine):
                with patch(self._ADD_POINTS, new_callable=AsyncMock) as add_points:
                    with patch(self._INDEX_EDGES, new_callable=AsyncMock) as index_edges:
                        sut = UserQAFeedback(last_k=1)
                        outcome = await sut.add_feedback(text)

        assert outcome == [text]
        add_points.assert_awaited_once()
        # Edge creation, indexing and weighting are skipped without interactions.
        mock_graph_engine.add_edges.assert_not_awaited()
        index_edges.assert_not_awaited()
        mock_graph_engine.apply_feedback_weight.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_add_feedback_creates_correct_feedback_node(
        self, mock_feedback_evaluation, mock_graph_engine
    ):
        """The stored CogneeUserFeedback node carries the evaluated attributes."""
        mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[])
        text = "This was a negative experience"
        mock_feedback_evaluation.evaluation.value = "negative"
        mock_feedback_evaluation.score = -3.0

        with patch(self._LLM, new_callable=AsyncMock, return_value=mock_feedback_evaluation):
            with patch(self._ENGINE, return_value=mock_graph_engine):
                with patch(self._ADD_POINTS, new_callable=AsyncMock) as add_points:
                    sut = UserQAFeedback()
                    await sut.add_feedback(text)

        data_points = add_points.call_args[1]["data_points"]
        assert len(data_points) == 1
        node = data_points[0]
        assert node.id == uuid5(NAMESPACE_OID, name=text)
        assert node.feedback == text
        assert node.sentiment == "negative"
        assert node.score == -3.0
        assert isinstance(node.belongs_to_set, NodeSet)
        assert node.belongs_to_set.name == "UserQAFeedbacks"

    @pytest.mark.asyncio
    async def test_add_feedback_calls_llm_with_correct_prompt(
        self, mock_feedback_evaluation, mock_graph_engine
    ):
        """The LLM is invoked with the feedback text and a sentiment prompt."""
        mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[])
        text = "Great answer!"

        with patch(
            self._LLM, new_callable=AsyncMock, return_value=mock_feedback_evaluation
        ) as llm:
            with patch(self._ENGINE, return_value=mock_graph_engine):
                with patch(self._ADD_POINTS, new_callable=AsyncMock):
                    sut = UserQAFeedback()
                    await sut.add_feedback(text)

        llm.assert_awaited_once()
        kwargs = llm.call_args[1]
        assert kwargs["text_input"] == text
        assert "sentiment analysis assistant" in kwargs["system_prompt"]
        assert kwargs["response_model"] == UserFeedbackEvaluation

    @pytest.mark.asyncio
    async def test_add_feedback_uses_last_k_parameter(
        self, mock_feedback_evaluation, mock_graph_engine
    ):
        """last_k is forwarded as the limit when fetching recent interactions."""
        mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[])
        text = "Test feedback"

        with patch(self._LLM, new_callable=AsyncMock, return_value=mock_feedback_evaluation):
            with patch(self._ENGINE, return_value=mock_graph_engine):
                with patch(self._ADD_POINTS, new_callable=AsyncMock):
                    sut = UserQAFeedback(last_k=5)
                    await sut.add_feedback(text)

        mock_graph_engine.get_last_user_interaction_ids.assert_awaited_once_with(limit=5)

    @pytest.mark.asyncio
    async def test_add_feedback_with_single_interaction(
        self, mock_feedback_evaluation, mock_graph_engine
    ):
        """A single interaction produces exactly one feedback edge."""
        interaction = str(UUID("550e8400-e29b-41d4-a716-446655440000"))
        mock_graph_engine.get_last_user_interaction_ids = AsyncMock(
            return_value=[interaction]
        )
        text = "Test feedback"

        with patch(self._LLM, new_callable=AsyncMock, return_value=mock_feedback_evaluation):
            with patch(self._ENGINE, return_value=mock_graph_engine):
                with patch(self._ADD_POINTS, new_callable=AsyncMock):
                    with patch(self._INDEX_EDGES, new_callable=AsyncMock):
                        sut = UserQAFeedback()
                        outcome = await sut.add_feedback(text)

        assert outcome == [text]
        edges = mock_graph_engine.add_edges.call_args[0][0]
        assert len(edges) == 1
        assert edges[0][1] == UUID(interaction)

    @pytest.mark.asyncio
    async def test_add_feedback_applies_weight_correctly(
        self, mock_feedback_evaluation, mock_graph_engine
    ):
        """The evaluated score is applied as the feedback weight."""
        interaction = str(UUID("550e8400-e29b-41d4-a716-446655440000"))
        mock_graph_engine.get_last_user_interaction_ids = AsyncMock(
            return_value=[interaction]
        )
        mock_feedback_evaluation.score = 4.5
        text = "Positive feedback"

        with patch(self._LLM, new_callable=AsyncMock, return_value=mock_feedback_evaluation):
            with patch(self._ENGINE, return_value=mock_graph_engine):
                with patch(self._ADD_POINTS, new_callable=AsyncMock):
                    with patch(self._INDEX_EDGES, new_callable=AsyncMock):
                        sut = UserQAFeedback()
                        await sut.add_feedback(text)

        mock_graph_engine.apply_feedback_weight.assert_awaited_once_with(
            node_ids=[interaction], weight=4.5
        )