cognee 0.3.7__py3-none-any.whl → 0.3.7.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/api/v1/add/routers/get_add_router.py +6 -1
- cognee/api/v1/cognify/routers/get_cognify_router.py +2 -1
- cognee/api/v1/datasets/routers/get_datasets_router.py +8 -0
- cognee/api/v1/delete/routers/get_delete_router.py +2 -0
- cognee/api/v1/memify/routers/get_memify_router.py +2 -1
- cognee/api/v1/permissions/routers/get_permissions_router.py +6 -0
- cognee/api/v1/search/routers/get_search_router.py +3 -3
- cognee/api/v1/sync/routers/get_sync_router.py +3 -0
- cognee/api/v1/ui/ui.py +2 -4
- cognee/api/v1/update/routers/get_update_router.py +2 -0
- cognee/api/v1/users/routers/get_visualize_router.py +2 -0
- cognee/infrastructure/databases/graph/kuzu/adapter.py +9 -3
- cognee/infrastructure/llm/prompts/feedback_reaction_prompt.txt +14 -0
- cognee/infrastructure/llm/prompts/feedback_report_prompt.txt +13 -0
- cognee/infrastructure/llm/prompts/feedback_user_context_prompt.txt +5 -0
- cognee/modules/pipelines/operations/run_tasks_base.py +7 -0
- cognee/modules/pipelines/operations/run_tasks_with_telemetry.py +9 -1
- cognee/modules/retrieval/graph_completion_cot_retriever.py +137 -38
- cognee/modules/retrieval/utils/completion.py +25 -4
- cognee/modules/search/methods/search.py +17 -3
- cognee/shared/logging_utils.py +24 -12
- cognee/shared/utils.py +24 -2
- cognee/tasks/feedback/__init__.py +13 -0
- cognee/tasks/feedback/create_enrichments.py +84 -0
- cognee/tasks/feedback/extract_feedback_interactions.py +230 -0
- cognee/tasks/feedback/generate_improved_answers.py +130 -0
- cognee/tasks/feedback/link_enrichments_to_feedback.py +67 -0
- cognee/tasks/feedback/models.py +26 -0
- cognee/tests/test_feedback_enrichment.py +174 -0
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +51 -0
- {cognee-0.3.7.dist-info → cognee-0.3.7.dev2.dist-info}/METADATA +1 -1
- {cognee-0.3.7.dist-info → cognee-0.3.7.dev2.dist-info}/RECORD +36 -26
- {cognee-0.3.7.dist-info → cognee-0.3.7.dev2.dist-info}/WHEEL +0 -0
- {cognee-0.3.7.dist-info → cognee-0.3.7.dev2.dist-info}/entry_points.txt +0 -0
- {cognee-0.3.7.dist-info → cognee-0.3.7.dev2.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.3.7.dist-info → cognee-0.3.7.dev2.dist-info}/licenses/NOTICE.md +0 -0
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
4
|
+
from uuid import UUID, uuid5, NAMESPACE_OID
|
|
5
|
+
|
|
6
|
+
from cognee.infrastructure.llm import LLMGateway
|
|
7
|
+
from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_prompt
|
|
8
|
+
from cognee.shared.logging_utils import get_logger
|
|
9
|
+
from cognee.infrastructure.databases.graph import get_graph_engine
|
|
10
|
+
|
|
11
|
+
from .models import FeedbackEnrichment
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
logger = get_logger("extract_feedback_interactions")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _filter_negative_feedback(feedback_nodes):
|
|
18
|
+
"""Filter for negative sentiment feedback using precise sentiment classification."""
|
|
19
|
+
return [
|
|
20
|
+
(node_id, props)
|
|
21
|
+
for node_id, props in feedback_nodes
|
|
22
|
+
if (props.get("sentiment", "").casefold() == "negative" or props.get("score", 0) < 0)
|
|
23
|
+
]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _get_normalized_id(node_id, props) -> str:
|
|
27
|
+
"""Return Cognee node id preference: props.id → props.node_id → raw node_id."""
|
|
28
|
+
return str(props.get("id") or props.get("node_id") or node_id)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
async def _fetch_feedback_and_interaction_graph_data() -> Tuple[List, List]:
    """Fetch CogneeUserFeedback/CogneeUserInteraction nodes and their edges.

    Returns empty (nodes, edges) lists instead of raising when the graph
    engine is unavailable or the filtered query fails.
    """
    node_type_filter = [{"type": ["CogneeUserFeedback", "CogneeUserInteraction"]}]
    try:
        engine = await get_graph_engine()
        return await engine.get_filtered_graph_data(node_type_filter)
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to fetch filtered graph data", error=str(exc))
        return [], []
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _separate_feedback_and_interaction_nodes(graph_nodes: List) -> Tuple[List, List]:
    """Split nodes into feedback and interaction groups by their type field.

    Each returned entry is (normalized_id, props); nodes of any other type
    are ignored.
    """
    feedback_nodes: List = []
    interaction_nodes: List = []
    for raw_id, props in graph_nodes:
        node_type = props.get("type")
        if node_type == "CogneeUserFeedback":
            feedback_nodes.append((_get_normalized_id(raw_id, props), props))
        elif node_type == "CogneeUserInteraction":
            interaction_nodes.append((_get_normalized_id(raw_id, props), props))
    return feedback_nodes, interaction_nodes
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _match_feedback_nodes_to_interactions_by_edges(
|
|
58
|
+
feedback_nodes: List, interaction_nodes: List, graph_edges: List
|
|
59
|
+
) -> List[Tuple[Tuple, Tuple]]:
|
|
60
|
+
"""Match feedback to interactions using gives_feedback_to edges."""
|
|
61
|
+
interaction_by_id = {node_id: (node_id, props) for node_id, props in interaction_nodes}
|
|
62
|
+
feedback_by_id = {node_id: (node_id, props) for node_id, props in feedback_nodes}
|
|
63
|
+
feedback_edges = [
|
|
64
|
+
(source_id, target_id)
|
|
65
|
+
for source_id, target_id, rel, _ in graph_edges
|
|
66
|
+
if rel == "gives_feedback_to"
|
|
67
|
+
]
|
|
68
|
+
|
|
69
|
+
feedback_interaction_pairs: List[Tuple[Tuple, Tuple]] = []
|
|
70
|
+
for source_id, target_id in feedback_edges:
|
|
71
|
+
source_id_str, target_id_str = str(source_id), str(target_id)
|
|
72
|
+
|
|
73
|
+
feedback_node = feedback_by_id.get(source_id_str)
|
|
74
|
+
interaction_node = interaction_by_id.get(target_id_str)
|
|
75
|
+
|
|
76
|
+
if feedback_node and interaction_node:
|
|
77
|
+
feedback_interaction_pairs.append((feedback_node, interaction_node))
|
|
78
|
+
|
|
79
|
+
return feedback_interaction_pairs
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _sort_pairs_by_recency_and_limit(
|
|
83
|
+
feedback_interaction_pairs: List[Tuple[Tuple, Tuple]], last_n_limit: Optional[int]
|
|
84
|
+
) -> List[Tuple[Tuple, Tuple]]:
|
|
85
|
+
"""Sort by interaction created_at desc with updated_at fallback, then limit."""
|
|
86
|
+
|
|
87
|
+
def _recency_key(pair):
|
|
88
|
+
_, (_, interaction_props) = pair
|
|
89
|
+
created_at = interaction_props.get("created_at") or ""
|
|
90
|
+
updated_at = interaction_props.get("updated_at") or ""
|
|
91
|
+
return (created_at, updated_at)
|
|
92
|
+
|
|
93
|
+
sorted_pairs = sorted(feedback_interaction_pairs, key=_recency_key, reverse=True)
|
|
94
|
+
return sorted_pairs[: last_n_limit or len(sorted_pairs)]
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
async def _generate_human_readable_context_summary(
    question_text: str, raw_context_text: str
) -> str:
    """Generate a concise human-readable summary for given context.

    Never raises: on any failure (missing prompt, LLM error) it falls back to
    the raw context text, or "" when that is falsy.
    """
    try:
        prompt = read_query_prompt("feedback_user_context_prompt.txt")
        rendered = prompt.format(question=question_text, context=raw_context_text)
        # NOTE(review): response_model=str relies on the gateway accepting a
        # plain string as a "structured" output model — confirm LLMGateway
        # supports this; any failure here is caught below.
        return await LLMGateway.acreate_structured_output(
            text_input=rendered, system_prompt="", response_model=str
        )
    except Exception as exc:  # noqa: BLE001
        logger.warning("Failed to summarize context", error=str(exc))
        return raw_context_text or ""
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _has_required_feedback_fields(enrichment: FeedbackEnrichment) -> bool:
|
|
113
|
+
"""Validate required fields exist in the FeedbackEnrichment DataPoint."""
|
|
114
|
+
return (
|
|
115
|
+
enrichment.question is not None
|
|
116
|
+
and enrichment.original_answer is not None
|
|
117
|
+
and enrichment.context is not None
|
|
118
|
+
and enrichment.feedback_text is not None
|
|
119
|
+
and enrichment.feedback_id is not None
|
|
120
|
+
and enrichment.interaction_id is not None
|
|
121
|
+
)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
async def _build_feedback_interaction_record(
    feedback_node_id: str, feedback_props: Dict, interaction_node_id: str, interaction_props: Dict
) -> Optional[FeedbackEnrichment]:
    """Build a single FeedbackEnrichment DataPoint with context summary.

    Reads question/answer/context from the interaction node and the feedback
    text from the feedback node, summarizes the context via the LLM helper,
    and validates the assembled DataPoint. Never raises: returns None on any
    failure or when required fields are missing.
    """
    try:
        question_text = interaction_props.get("question")
        original_answer_text = interaction_props.get("answer")
        raw_context_text = interaction_props.get("context", "")
        # Feedback text may live under "feedback" or "text" depending on the producer.
        feedback_text = feedback_props.get("feedback") or feedback_props.get("text") or ""

        context_summary_text = await _generate_human_readable_context_summary(
            question_text or "", raw_context_text
        )

        # Deterministic id derived from question + interaction so re-running the
        # pipeline produces the same enrichment node.
        # NOTE(review): assumes node ids are UUID strings — UUID(...) raises for
        # anything else, which lands in the except branch below; confirm.
        enrichment = FeedbackEnrichment(
            id=str(uuid5(NAMESPACE_OID, f"{question_text}_{interaction_node_id}")),
            text="",
            question=question_text,
            original_answer=original_answer_text,
            improved_answer="",
            feedback_id=UUID(str(feedback_node_id)),
            interaction_id=UUID(str(interaction_node_id)),
            belongs_to_set=None,
            context=context_summary_text,
            feedback_text=feedback_text,
            new_context="",
            explanation="",
        )

        if _has_required_feedback_fields(enrichment):
            return enrichment
        else:
            logger.warning("Skipping invalid feedback item", interaction=str(interaction_node_id))
            return None
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to process feedback pair", error=str(exc))
        return None
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
async def _build_feedback_interaction_records(
    matched_feedback_interaction_pairs: List[Tuple[Tuple, Tuple]],
) -> List[FeedbackEnrichment]:
    """Build FeedbackEnrichment DataPoints from matched pairs, skipping failures."""
    records: List[FeedbackEnrichment] = []
    for feedback_entry, interaction_entry in matched_feedback_interaction_pairs:
        feedback_node_id, feedback_props = feedback_entry
        interaction_node_id, interaction_props = interaction_entry
        enrichment = await _build_feedback_interaction_record(
            feedback_node_id, feedback_props, interaction_node_id, interaction_props
        )
        # The builder returns None for invalid or failed pairs; drop those.
        if enrichment:
            records.append(enrichment)
    return records
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
async def extract_feedback_interactions(
    data: Any, last_n: Optional[int] = None
) -> List[FeedbackEnrichment]:
    """Extract negative feedback-interaction pairs and create FeedbackEnrichment DataPoints.

    Pipeline entry point. Reads feedback/interaction nodes directly from the
    graph engine (`data` is only logged, never read), keeps negative feedback,
    matches it to interactions via `gives_feedback_to` edges, sorts the pairs
    newest-first and truncates to `last_n` before building enrichments.

    Args:
        data: Incoming pipeline payload; unused beyond an informational log.
        last_n: Optional cap on the number of most recent pairs to process.

    Returns:
        List of FeedbackEnrichment DataPoints (possibly empty).
    """
    # [{}] is the placeholder payload used when running via memify; the task
    # deliberately continues because it fetches its own data from the graph.
    if not data or data == [{}]:
        logger.info(
            "No data passed to the extraction task (extraction task fetches data from graph directly)",
            data=data,
        )

    graph_nodes, graph_edges = await _fetch_feedback_and_interaction_graph_data()
    if not graph_nodes:
        logger.warning("No graph nodes retrieved from database")
        return []

    feedback_nodes, interaction_nodes = _separate_feedback_and_interaction_nodes(graph_nodes)
    logger.info(
        "Retrieved nodes from graph",
        total_nodes=len(graph_nodes),
        feedback_nodes=len(feedback_nodes),
        interaction_nodes=len(interaction_nodes),
    )

    # Only negative feedback is enriched; positive/neutral feedback is dropped.
    negative_feedback_nodes = _filter_negative_feedback(feedback_nodes)
    logger.info(
        "Filtered feedback nodes",
        total_feedback=len(feedback_nodes),
        negative_feedback=len(negative_feedback_nodes),
    )

    if not negative_feedback_nodes:
        logger.info("No negative feedback found; returning empty list")
        return []

    matched_feedback_interaction_pairs = _match_feedback_nodes_to_interactions_by_edges(
        negative_feedback_nodes, interaction_nodes, graph_edges
    )
    if not matched_feedback_interaction_pairs:
        logger.info("No feedback-to-interaction matches found; returning empty list")
        return []

    # Limit BEFORE building records: record building calls the LLM per pair,
    # so `last_n` bounds the expensive work.
    matched_feedback_interaction_pairs = _sort_pairs_by_recency_and_limit(
        matched_feedback_interaction_pairs, last_n
    )

    feedback_interaction_records = await _build_feedback_interaction_records(
        matched_feedback_interaction_pairs
    )

    logger.info("Extracted feedback pairs", count=len(feedback_interaction_records))
    return feedback_interaction_records
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import List, Optional
|
|
4
|
+
from pydantic import BaseModel
|
|
5
|
+
|
|
6
|
+
from cognee.infrastructure.llm import LLMGateway
|
|
7
|
+
from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_prompt
|
|
8
|
+
from cognee.modules.graph.utils import resolve_edges_to_text
|
|
9
|
+
from cognee.shared.logging_utils import get_logger
|
|
10
|
+
|
|
11
|
+
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
|
12
|
+
from .models import FeedbackEnrichment
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ImprovedAnswerResponse(BaseModel):
    """Response model for improved answer generation containing answer and explanation."""

    # The corrected answer produced in reaction to the negative feedback.
    answer: str
    # The model's rationale for why the new answer improves on the original.
    explanation: str
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
logger = get_logger("generate_improved_answers")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _validate_input_data(enrichments: List[FeedbackEnrichment]) -> bool:
|
|
26
|
+
"""Validate that input contains required fields for all enrichments."""
|
|
27
|
+
return all(
|
|
28
|
+
enrichment.question is not None
|
|
29
|
+
and enrichment.original_answer is not None
|
|
30
|
+
and enrichment.context is not None
|
|
31
|
+
and enrichment.feedback_text is not None
|
|
32
|
+
and enrichment.feedback_id is not None
|
|
33
|
+
and enrichment.interaction_id is not None
|
|
34
|
+
for enrichment in enrichments
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _render_reaction_prompt(
    question: str, context: str, wrong_answer: str, negative_feedback: str
) -> str:
    """Render the feedback reaction prompt with the provided variables."""
    template = read_query_prompt("feedback_reaction_prompt.txt")
    substitutions = {
        "question": question,
        "context": context,
        "wrong_answer": wrong_answer,
        "negative_feedback": negative_feedback,
    }
    return template.format(**substitutions)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
async def _generate_improved_answer_for_single_interaction(
    enrichment: FeedbackEnrichment, retriever, reaction_prompt_location: str
) -> Optional[FeedbackEnrichment]:
    """Generate improved answer for a single enrichment using structured retriever completion.

    Mutates `enrichment` in place (improved_answer, new_context, explanation)
    and returns it, or returns None on any failure.

    NOTE(review): `reaction_prompt_location` is accepted but never used here —
    the prompt file name is hard-coded inside _render_reaction_prompt. Confirm
    whether it was meant to be forwarded.
    """
    try:
        query_text = _render_reaction_prompt(
            enrichment.question,
            enrichment.context,
            enrichment.original_answer,
            enrichment.feedback_text,
        )

        # Retrieve fresh graph context for the reaction query, then request a
        # structured (answer, explanation) completion over it.
        retrieved_context = await retriever.get_context(query_text)
        completion = await retriever.get_structured_completion(
            query=query_text,
            context=retrieved_context,
            response_model=ImprovedAnswerResponse,
            max_iter=4,
        )
        # Flatten the retrieved graph edges into human-readable text for storage.
        new_context_text = await retriever.resolve_edges_to_text(retrieved_context)

        if completion:
            enrichment.improved_answer = completion.answer
            enrichment.new_context = new_context_text
            enrichment.explanation = completion.explanation
            return enrichment
        else:
            logger.warning(
                "Failed to get structured completion from retriever", question=enrichment.question
            )
            return None

    except Exception as exc:  # noqa: BLE001
        logger.error(
            "Failed to generate improved answer",
            error=str(exc),
            question=enrichment.question,
        )
        return None
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
async def generate_improved_answers(
    enrichments: List[FeedbackEnrichment],
    top_k: int = 20,
    reaction_prompt_location: str = "feedback_reaction_prompt.txt",
) -> List[FeedbackEnrichment]:
    """Generate improved answers using CoT retriever and LLM.

    Args:
        enrichments: FeedbackEnrichment items from the extraction task; every
            required field must be non-None or the whole batch is rejected.
        top_k: Number of graph results the retriever considers per query.
        reaction_prompt_location: Prompt file name. NOTE(review): forwarded to
            the per-item helper but never read there (the helper hard-codes
            the prompt name) — confirm intent.

    Returns:
        The enrichments whose improved_answer/new_context/explanation were
        filled in; items that failed generation are dropped.
    """
    if not enrichments:
        logger.info("No enrichments provided; returning empty list")
        return []

    # All-or-nothing validation: one incomplete item rejects the whole batch.
    if not _validate_input_data(enrichments):
        logger.error("Input data validation failed; missing required fields")
        return []

    # One retriever instance is shared across all items; save_interaction=False
    # avoids creating new interaction nodes while enriching old ones.
    retriever = GraphCompletionCotRetriever(
        top_k=top_k,
        save_interaction=False,
        user_prompt_path="graph_context_for_question.txt",
        system_prompt_path="answer_simple_question.txt",
    )

    improved_answers: List[FeedbackEnrichment] = []

    for enrichment in enrichments:
        result = await _generate_improved_answer_for_single_interaction(
            enrichment, retriever, reaction_prompt_location
        )

        if result:
            improved_answers.append(result)
        else:
            logger.warning(
                "Failed to generate improved answer",
                question=enrichment.question,
                interaction_id=enrichment.interaction_id,
            )

    logger.info("Generated improved answers", count=len(improved_answers))
    return improved_answers
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import List, Tuple
|
|
4
|
+
from uuid import UUID
|
|
5
|
+
|
|
6
|
+
from cognee.infrastructure.databases.graph import get_graph_engine
|
|
7
|
+
from cognee.tasks.storage import index_graph_edges
|
|
8
|
+
from cognee.shared.logging_utils import get_logger
|
|
9
|
+
|
|
10
|
+
from .models import FeedbackEnrichment
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
logger = get_logger("link_enrichments_to_feedback")
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _create_edge_tuple(
|
|
17
|
+
source_id: UUID, target_id: UUID, relationship_name: str
|
|
18
|
+
) -> Tuple[UUID, UUID, str, dict]:
|
|
19
|
+
"""Create an edge tuple with proper properties structure."""
|
|
20
|
+
return (
|
|
21
|
+
source_id,
|
|
22
|
+
target_id,
|
|
23
|
+
relationship_name,
|
|
24
|
+
{
|
|
25
|
+
"relationship_name": relationship_name,
|
|
26
|
+
"source_node_id": source_id,
|
|
27
|
+
"target_node_id": target_id,
|
|
28
|
+
"ontology_valid": False,
|
|
29
|
+
},
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
async def link_enrichments_to_feedback(
    enrichments: List[FeedbackEnrichment],
) -> List[FeedbackEnrichment]:
    """Manually create edges from enrichments to original feedback/interaction nodes."""
    if not enrichments:
        logger.info("No enrichments provided; returning empty list")
        return []

    relationships = []
    for enrichment in enrichments:
        source_id = enrichment.id
        if not source_id:
            # Without an enrichment id there is nothing to link from.
            continue
        if enrichment.feedback_id:
            relationships.append(
                _create_edge_tuple(source_id, enrichment.feedback_id, "enriches_feedback")
            )
        if enrichment.interaction_id:
            relationships.append(
                _create_edge_tuple(source_id, enrichment.interaction_id, "improves_interaction")
            )

    if relationships:
        graph_engine = await get_graph_engine()
        await graph_engine.add_edges(relationships)
        await index_graph_edges(relationships)
        logger.info("Linking enrichments to feedback", edge_count=len(relationships))

    logger.info("Linked enrichments", enrichment_count=len(enrichments))
    return enrichments
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from typing import List, Optional, Union
|
|
2
|
+
from uuid import UUID
|
|
3
|
+
|
|
4
|
+
from cognee.infrastructure.engine import DataPoint
|
|
5
|
+
from cognee.modules.engine.models import Entity, NodeSet
|
|
6
|
+
from cognee.tasks.temporal_graph.models import Event
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class FeedbackEnrichment(DataPoint):
    """Minimal DataPoint for feedback enrichment that works with extract_graph_from_data."""

    # Primary payload; listed in metadata["index_fields"], so it is the field
    # used for vector indexing. Created empty and filled by downstream tasks
    # (presumably create_enrichments — confirm).
    text: str
    # Entities/events extracted from `text` by the graph-extraction task, if any.
    contains: Optional[List[Union[Entity, Event]]] = None
    metadata: dict = {"index_fields": ["text"]}

    question: str
    original_answer: str
    # Filled in by generate_improved_answers; starts as "".
    improved_answer: str
    # Graph node ids of the originating feedback and interaction nodes,
    # used by link_enrichments_to_feedback to create linking edges.
    feedback_id: UUID
    interaction_id: UUID
    belongs_to_set: Optional[List[NodeSet]] = None

    # Free-text companions populated across the pipeline; default to empty.
    context: str = ""
    feedback_text: str = ""
    new_context: str = ""
    explanation: str = ""
|
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
"""
|
|
2
|
+
End-to-end integration test for feedback enrichment feature.
|
|
3
|
+
|
|
4
|
+
Tests the complete feedback enrichment pipeline:
|
|
5
|
+
1. Add data and cognify
|
|
6
|
+
2. Run search with save_interaction=True to create CogneeUserInteraction nodes
|
|
7
|
+
3. Submit feedback to create CogneeUserFeedback nodes
|
|
8
|
+
4. Run memify with feedback enrichment tasks to create FeedbackEnrichment nodes
|
|
9
|
+
5. Verify all nodes and edges are properly created and linked in the graph
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import os
|
|
13
|
+
import pathlib
|
|
14
|
+
from collections import Counter
|
|
15
|
+
|
|
16
|
+
import cognee
|
|
17
|
+
from cognee.infrastructure.databases.graph import get_graph_engine
|
|
18
|
+
from cognee.modules.pipelines.tasks.task import Task
|
|
19
|
+
from cognee.modules.search.types import SearchType
|
|
20
|
+
from cognee.shared.data_models import KnowledgeGraph
|
|
21
|
+
from cognee.shared.logging_utils import get_logger
|
|
22
|
+
from cognee.tasks.feedback.create_enrichments import create_enrichments
|
|
23
|
+
from cognee.tasks.feedback.extract_feedback_interactions import (
|
|
24
|
+
extract_feedback_interactions,
|
|
25
|
+
)
|
|
26
|
+
from cognee.tasks.feedback.generate_improved_answers import generate_improved_answers
|
|
27
|
+
from cognee.tasks.feedback.link_enrichments_to_feedback import (
|
|
28
|
+
link_enrichments_to_feedback,
|
|
29
|
+
)
|
|
30
|
+
from cognee.tasks.graph import extract_graph_from_data
|
|
31
|
+
from cognee.tasks.storage import add_data_points
|
|
32
|
+
|
|
33
|
+
logger = get_logger()
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
async def main():
    """End-to-end check of the feedback enrichment pipeline.

    Flow: isolate storage dirs → add+cognify a document → search with
    save_interaction → submit negative feedback → verify feedback/interaction
    nodes and edges exist → run memify with the enrichment tasks → verify
    FeedbackEnrichment nodes were created and linked.
    """
    # Use test-local data/system directories so this run cannot touch other state.
    data_directory_path = str(
        pathlib.Path(
            os.path.join(
                pathlib.Path(__file__).parent,
                ".data_storage/test_feedback_enrichment",
            )
        ).resolve()
    )
    cognee_directory_path = str(
        pathlib.Path(
            os.path.join(
                pathlib.Path(__file__).parent,
                ".cognee_system/test_feedback_enrichment",
            )
        ).resolve()
    )

    cognee.config.data_root_directory(data_directory_path)
    cognee.config.system_root_directory(cognee_directory_path)

    # Start from a clean slate.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    dataset_name = "feedback_enrichment_test"

    await cognee.add("Cognee turns documents into AI memory.", dataset_name)
    await cognee.cognify([dataset_name])

    # save_interaction=True should create a CogneeUserInteraction node.
    question_text = "Say something."
    result = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION,
        query_text=question_text,
        save_interaction=True,
    )

    assert len(result) > 0, "Search should return non-empty results"

    # FEEDBACK search attaches feedback to the last_k saved interactions.
    feedback_text = "This answer was completely useless, my feedback is definitely negative."
    await cognee.search(
        query_type=SearchType.FEEDBACK,
        query_text=feedback_text,
        last_k=1,
    )

    graph_engine = await get_graph_engine()
    nodes_before, edges_before = await graph_engine.get_graph_data()

    interaction_nodes_before = [
        (node_id, props)
        for node_id, props in nodes_before
        if props.get("type") == "CogneeUserInteraction"
    ]
    feedback_nodes_before = [
        (node_id, props)
        for node_id, props in nodes_before
        if props.get("type") == "CogneeUserFeedback"
    ]

    edge_types_before = Counter(edge[2] for edge in edges_before)

    assert len(interaction_nodes_before) >= 1, (
        f"Expected at least 1 CogneeUserInteraction node, found {len(interaction_nodes_before)}"
    )
    assert len(feedback_nodes_before) >= 1, (
        f"Expected at least 1 CogneeUserFeedback node, found {len(feedback_nodes_before)}"
    )

    # NOTE(review): this loop rebinds `feedback_text`, shadowing the submitted
    # feedback string above — harmless here (it is only logged afterwards),
    # but a distinct name would be clearer.
    for node_id, props in feedback_nodes_before:
        sentiment = props.get("sentiment", "")
        score = props.get("score", 0)
        feedback_text = props.get("feedback", "")
        logger.info(
            "Feedback node created",
            feedback=feedback_text,
            sentiment=sentiment,
            score=score,
        )

    assert edge_types_before.get("gives_feedback_to", 0) >= 1, (
        f"Expected at least 1 'gives_feedback_to' edge, found {edge_types_before.get('gives_feedback_to', 0)}"
    )

    # Feedback enrichment pipeline: extract pairs, improve answers, build
    # enrichment text, extract a graph from it, persist, then link back.
    extraction_tasks = [Task(extract_feedback_interactions, last_n=5)]
    enrichment_tasks = [
        Task(generate_improved_answers, top_k=20),
        Task(create_enrichments),
        Task(
            extract_graph_from_data,
            graph_model=KnowledgeGraph,
            task_config={"batch_size": 10},
        ),
        Task(add_data_points, task_config={"batch_size": 10}),
        Task(link_enrichments_to_feedback),
    ]

    # data=[{}] is a placeholder; the extraction task reads from the graph itself.
    await cognee.memify(
        extraction_tasks=extraction_tasks,
        enrichment_tasks=enrichment_tasks,
        data=[{}],
        dataset="feedback_enrichment_test_memify",
    )

    nodes_after, edges_after = await graph_engine.get_graph_data()

    enrichment_nodes = [
        (node_id, props)
        for node_id, props in nodes_after
        if props.get("type") == "FeedbackEnrichment"
    ]

    assert len(enrichment_nodes) >= 1, (
        f"Expected at least 1 FeedbackEnrichment node, found {len(enrichment_nodes)}"
    )

    for node_id, props in enrichment_nodes:
        assert "text" in props, f"FeedbackEnrichment node {node_id} missing 'text' property"

    # Every enrichment should be connected to something (feedback/interaction).
    enrichment_node_ids = {node_id for node_id, _ in enrichment_nodes}
    edges_with_enrichments = [
        edge
        for edge in edges_after
        if edge[0] in enrichment_node_ids or edge[1] in enrichment_node_ids
    ]

    assert len(edges_with_enrichments) >= 1, (
        f"Expected enrichment nodes to have at least 1 edge, found {len(edges_with_enrichments)}"
    )

    # Clean up so subsequent test runs start fresh.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    logger.info("All feedback enrichment tests passed successfully")
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
if __name__ == "__main__":
    import asyncio

    # Allow running this integration test directly as a script.
    asyncio.run(main())
|
|
@@ -2,6 +2,7 @@ import os
|
|
|
2
2
|
import pytest
|
|
3
3
|
import pathlib
|
|
4
4
|
from typing import Optional, Union
|
|
5
|
+
from pydantic import BaseModel
|
|
5
6
|
|
|
6
7
|
import cognee
|
|
7
8
|
from cognee.low_level import setup, DataPoint
|
|
@@ -10,6 +11,11 @@ from cognee.tasks.storage import add_data_points
|
|
|
10
11
|
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
|
11
12
|
|
|
12
13
|
|
|
14
|
+
class TestAnswer(BaseModel):
    # Structured response model used to exercise get_structured_completion
    # with a non-string response_model.
    answer: str
    explanation: str
|
|
17
|
+
|
|
18
|
+
|
|
13
19
|
class TestGraphCompletionCoTRetriever:
|
|
14
20
|
@pytest.mark.asyncio
|
|
15
21
|
async def test_graph_completion_cot_context_simple(self):
|
|
@@ -168,3 +174,48 @@ class TestGraphCompletionCoTRetriever:
|
|
|
168
174
|
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
|
169
175
|
"Answer must contain only non-empty strings"
|
|
170
176
|
)
|
|
177
|
+
|
|
178
|
+
@pytest.mark.asyncio
async def test_get_structured_completion(self):
    """get_structured_completion returns a str by default and a pydantic
    model instance when response_model is supplied."""
    # Test-local storage directories keep this case isolated from others.
    system_directory_path = os.path.join(
        pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion"
    )
    cognee.config.system_root_directory(system_directory_path)
    data_directory_path = os.path.join(
        pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion"
    )
    cognee.config.data_root_directory(data_directory_path)

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    # Minimal two-node graph: a person working for a company.
    class Company(DataPoint):
        name: str

    class Person(DataPoint):
        name: str
        works_for: Company

    company1 = Company(name="Figma")
    person1 = Person(name="Steve Rodger", works_for=company1)

    entities = [company1, person1]
    await add_data_points(entities)

    retriever = GraphCompletionCotRetriever()

    # Test with string response model (default)
    string_answer = await retriever.get_structured_completion("Who works at Figma?")
    assert isinstance(string_answer, str), f"Expected str, got {type(string_answer).__name__}"
    assert string_answer.strip(), "Answer should not be empty"

    # Test with structured response model
    structured_answer = await retriever.get_structured_completion(
        "Who works at Figma?", response_model=TestAnswer
    )
    assert isinstance(structured_answer, TestAnswer), (
        f"Expected TestAnswer, got {type(structured_answer).__name__}"
    )
    assert structured_answer.answer.strip(), "Answer field should not be empty"
    assert structured_answer.explanation.strip(), "Explanation field should not be empty"
|