cognee 0.3.7__py3-none-any.whl → 0.3.7.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. cognee/api/v1/add/routers/get_add_router.py +6 -1
  2. cognee/api/v1/cognify/routers/get_cognify_router.py +2 -1
  3. cognee/api/v1/datasets/routers/get_datasets_router.py +8 -0
  4. cognee/api/v1/delete/routers/get_delete_router.py +2 -0
  5. cognee/api/v1/memify/routers/get_memify_router.py +2 -1
  6. cognee/api/v1/permissions/routers/get_permissions_router.py +6 -0
  7. cognee/api/v1/search/routers/get_search_router.py +3 -3
  8. cognee/api/v1/sync/routers/get_sync_router.py +3 -0
  9. cognee/api/v1/ui/ui.py +2 -4
  10. cognee/api/v1/update/routers/get_update_router.py +2 -0
  11. cognee/api/v1/users/routers/get_visualize_router.py +2 -0
  12. cognee/infrastructure/databases/graph/kuzu/adapter.py +9 -3
  13. cognee/infrastructure/llm/prompts/feedback_reaction_prompt.txt +14 -0
  14. cognee/infrastructure/llm/prompts/feedback_report_prompt.txt +13 -0
  15. cognee/infrastructure/llm/prompts/feedback_user_context_prompt.txt +5 -0
  16. cognee/modules/pipelines/operations/run_tasks_base.py +7 -0
  17. cognee/modules/pipelines/operations/run_tasks_with_telemetry.py +9 -1
  18. cognee/modules/retrieval/graph_completion_cot_retriever.py +137 -38
  19. cognee/modules/retrieval/utils/completion.py +25 -4
  20. cognee/modules/search/methods/search.py +17 -3
  21. cognee/shared/logging_utils.py +24 -12
  22. cognee/shared/utils.py +24 -2
  23. cognee/tasks/feedback/__init__.py +13 -0
  24. cognee/tasks/feedback/create_enrichments.py +84 -0
  25. cognee/tasks/feedback/extract_feedback_interactions.py +230 -0
  26. cognee/tasks/feedback/generate_improved_answers.py +130 -0
  27. cognee/tasks/feedback/link_enrichments_to_feedback.py +67 -0
  28. cognee/tasks/feedback/models.py +26 -0
  29. cognee/tests/test_feedback_enrichment.py +174 -0
  30. cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +51 -0
  31. {cognee-0.3.7.dist-info → cognee-0.3.7.dev2.dist-info}/METADATA +1 -1
  32. {cognee-0.3.7.dist-info → cognee-0.3.7.dev2.dist-info}/RECORD +36 -26
  33. {cognee-0.3.7.dist-info → cognee-0.3.7.dev2.dist-info}/WHEEL +0 -0
  34. {cognee-0.3.7.dist-info → cognee-0.3.7.dev2.dist-info}/entry_points.txt +0 -0
  35. {cognee-0.3.7.dist-info → cognee-0.3.7.dev2.dist-info}/licenses/LICENSE +0 -0
  36. {cognee-0.3.7.dist-info → cognee-0.3.7.dev2.dist-info}/licenses/NOTICE.md +0 -0
@@ -0,0 +1,230 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Dict, List, Optional, Tuple
4
+ from uuid import UUID, uuid5, NAMESPACE_OID
5
+
6
+ from cognee.infrastructure.llm import LLMGateway
7
+ from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_prompt
8
+ from cognee.shared.logging_utils import get_logger
9
+ from cognee.infrastructure.databases.graph import get_graph_engine
10
+
11
+ from .models import FeedbackEnrichment
12
+
13
+
14
+ logger = get_logger("extract_feedback_interactions")
15
+
16
+
17
+ def _filter_negative_feedback(feedback_nodes):
18
+ """Filter for negative sentiment feedback using precise sentiment classification."""
19
+ return [
20
+ (node_id, props)
21
+ for node_id, props in feedback_nodes
22
+ if (props.get("sentiment", "").casefold() == "negative" or props.get("score", 0) < 0)
23
+ ]
24
+
25
+
26
+ def _get_normalized_id(node_id, props) -> str:
27
+ """Return Cognee node id preference: props.id → props.node_id → raw node_id."""
28
+ return str(props.get("id") or props.get("node_id") or node_id)
29
+
30
+
31
async def _fetch_feedback_and_interaction_graph_data() -> Tuple[List, List]:
    """Load feedback and interaction nodes, with edges, from the graph engine.

    Returns:
        The result of ``get_filtered_graph_data`` for the two node types
        (unpacked by the caller as ``(nodes, edges)``), or ``([], [])`` when
        anything goes wrong; the failure is logged rather than raised.
    """
    node_type_filter = {"type": ["CogneeUserFeedback", "CogneeUserInteraction"]}
    try:
        engine = await get_graph_engine()
        graph_data = await engine.get_filtered_graph_data([node_type_filter])
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to fetch filtered graph data", error=str(exc))
        return [], []
    return graph_data
40
+
41
+
42
+ def _separate_feedback_and_interaction_nodes(graph_nodes: List) -> Tuple[List, List]:
43
+ """Split nodes into feedback and interaction groups by type field."""
44
+ feedback_nodes = [
45
+ (_get_normalized_id(node_id, props), props)
46
+ for node_id, props in graph_nodes
47
+ if props.get("type") == "CogneeUserFeedback"
48
+ ]
49
+ interaction_nodes = [
50
+ (_get_normalized_id(node_id, props), props)
51
+ for node_id, props in graph_nodes
52
+ if props.get("type") == "CogneeUserInteraction"
53
+ ]
54
+ return feedback_nodes, interaction_nodes
55
+
56
+
57
+ def _match_feedback_nodes_to_interactions_by_edges(
58
+ feedback_nodes: List, interaction_nodes: List, graph_edges: List
59
+ ) -> List[Tuple[Tuple, Tuple]]:
60
+ """Match feedback to interactions using gives_feedback_to edges."""
61
+ interaction_by_id = {node_id: (node_id, props) for node_id, props in interaction_nodes}
62
+ feedback_by_id = {node_id: (node_id, props) for node_id, props in feedback_nodes}
63
+ feedback_edges = [
64
+ (source_id, target_id)
65
+ for source_id, target_id, rel, _ in graph_edges
66
+ if rel == "gives_feedback_to"
67
+ ]
68
+
69
+ feedback_interaction_pairs: List[Tuple[Tuple, Tuple]] = []
70
+ for source_id, target_id in feedback_edges:
71
+ source_id_str, target_id_str = str(source_id), str(target_id)
72
+
73
+ feedback_node = feedback_by_id.get(source_id_str)
74
+ interaction_node = interaction_by_id.get(target_id_str)
75
+
76
+ if feedback_node and interaction_node:
77
+ feedback_interaction_pairs.append((feedback_node, interaction_node))
78
+
79
+ return feedback_interaction_pairs
80
+
81
+
82
+ def _sort_pairs_by_recency_and_limit(
83
+ feedback_interaction_pairs: List[Tuple[Tuple, Tuple]], last_n_limit: Optional[int]
84
+ ) -> List[Tuple[Tuple, Tuple]]:
85
+ """Sort by interaction created_at desc with updated_at fallback, then limit."""
86
+
87
+ def _recency_key(pair):
88
+ _, (_, interaction_props) = pair
89
+ created_at = interaction_props.get("created_at") or ""
90
+ updated_at = interaction_props.get("updated_at") or ""
91
+ return (created_at, updated_at)
92
+
93
+ sorted_pairs = sorted(feedback_interaction_pairs, key=_recency_key, reverse=True)
94
+ return sorted_pairs[: last_n_limit or len(sorted_pairs)]
95
+
96
+
97
async def _generate_human_readable_context_summary(
    question_text: str, raw_context_text: str
) -> str:
    """Ask the LLM for a concise, readable summary of the given context.

    The ``feedback_user_context_prompt.txt`` template is filled with the
    question and raw context and sent through ``LLMGateway``.

    Returns:
        The LLM output, or the raw context (``""`` when that is falsy) if
        reading the prompt, rendering it or the LLM call fails.
    """
    try:
        template = read_query_prompt("feedback_user_context_prompt.txt")
        llm_input = template.format(question=question_text, context=raw_context_text)
        summary = await LLMGateway.acreate_structured_output(
            text_input=llm_input, system_prompt="", response_model=str
        )
    except Exception as exc:  # noqa: BLE001
        # Best effort: summarisation problems must not block extraction.
        logger.warning("Failed to summarize context", error=str(exc))
        return raw_context_text or ""
    return summary
110
+
111
+
112
+ def _has_required_feedback_fields(enrichment: FeedbackEnrichment) -> bool:
113
+ """Validate required fields exist in the FeedbackEnrichment DataPoint."""
114
+ return (
115
+ enrichment.question is not None
116
+ and enrichment.original_answer is not None
117
+ and enrichment.context is not None
118
+ and enrichment.feedback_text is not None
119
+ and enrichment.feedback_id is not None
120
+ and enrichment.interaction_id is not None
121
+ )
122
+
123
+
124
async def _build_feedback_interaction_record(
    feedback_node_id: str, feedback_props: Dict, interaction_node_id: str, interaction_props: Dict
) -> Optional[FeedbackEnrichment]:
    """Build one FeedbackEnrichment DataPoint for a feedback/interaction pair.

    Args:
        feedback_node_id: Id of the feedback node; must parse as a UUID.
        feedback_props: Feedback node properties; the feedback text is read
            from ``feedback`` or, failing that, ``text``.
        interaction_node_id: Id of the interaction node; must parse as a UUID.
        interaction_props: Interaction node properties; ``question``,
            ``answer`` and ``context`` are read from it.

    Returns:
        The populated enrichment, or ``None`` when the pair is incomplete or
        building it fails. Failures are logged and never raised.
    """
    try:
        question_text = interaction_props.get("question")
        original_answer_text = interaction_props.get("answer")
        raw_context_text = interaction_props.get("context", "")
        feedback_text = feedback_props.get("feedback") or feedback_props.get("text") or ""

        # ``question`` and ``original_answer`` are required ``str`` fields on
        # the model, so reject incomplete interactions before spending an LLM
        # call on a record that could never be built.
        if question_text is None or original_answer_text is None:
            logger.warning("Skipping invalid feedback item", interaction=str(interaction_node_id))
            return None

        context_summary_text = await _generate_human_readable_context_summary(
            question_text or "", raw_context_text
        )

        enrichment = FeedbackEnrichment(
            # Derived from question + interaction id, so the same inputs give the same id.
            id=str(uuid5(NAMESPACE_OID, f"{question_text}_{interaction_node_id}")),
            text="",
            question=question_text,
            original_answer=original_answer_text,
            improved_answer="",
            feedback_id=UUID(str(feedback_node_id)),
            interaction_id=UUID(str(interaction_node_id)),
            belongs_to_set=None,
            context=context_summary_text,
            feedback_text=feedback_text,
            new_context="",
            explanation="",
        )

        if not _has_required_feedback_fields(enrichment):
            logger.warning("Skipping invalid feedback item", interaction=str(interaction_node_id))
            return None
        return enrichment
    except Exception as exc:  # noqa: BLE001
        # One bad pair (e.g. a non-UUID node id) must not abort the whole batch.
        logger.error("Failed to process feedback pair", error=str(exc))
        return None
161
+
162
+
163
+ async def _build_feedback_interaction_records(
164
+ matched_feedback_interaction_pairs: List[Tuple[Tuple, Tuple]],
165
+ ) -> List[FeedbackEnrichment]:
166
+ """Build all FeedbackEnrichment DataPoints from matched pairs."""
167
+ feedback_interaction_records: List[FeedbackEnrichment] = []
168
+ for (feedback_node_id, feedback_props), (
169
+ interaction_node_id,
170
+ interaction_props,
171
+ ) in matched_feedback_interaction_pairs:
172
+ record = await _build_feedback_interaction_record(
173
+ feedback_node_id, feedback_props, interaction_node_id, interaction_props
174
+ )
175
+ if record:
176
+ feedback_interaction_records.append(record)
177
+ return feedback_interaction_records
178
+
179
+
180
async def extract_feedback_interactions(
    data: Any, last_n: Optional[int] = None
) -> List[FeedbackEnrichment]:
    """Extract negative feedback/interaction pairs as FeedbackEnrichment DataPoints.

    The task reads feedback and interaction nodes directly from the graph;
    ``data`` is only inspected in order to log that nothing was passed in.

    Args:
        data: Pipeline input; not used for the extraction itself.
        last_n: Optional cap on how many of the most recent pairs are processed.

    Returns:
        The built enrichments, or an empty list when there are no nodes, no
        negative feedback, or no feedback-to-interaction matches.
    """
    nothing_passed = not data or data == [{}]
    if nothing_passed:
        logger.info(
            "No data passed to the extraction task (extraction task fetches data from graph directly)",
            data=data,
        )

    graph_nodes, graph_edges = await _fetch_feedback_and_interaction_graph_data()
    if not graph_nodes:
        logger.warning("No graph nodes retrieved from database")
        return []

    feedback_nodes, interaction_nodes = _separate_feedback_and_interaction_nodes(graph_nodes)
    logger.info(
        "Retrieved nodes from graph",
        total_nodes=len(graph_nodes),
        feedback_nodes=len(feedback_nodes),
        interaction_nodes=len(interaction_nodes),
    )

    negative_nodes = _filter_negative_feedback(feedback_nodes)
    logger.info(
        "Filtered feedback nodes",
        total_feedback=len(feedback_nodes),
        negative_feedback=len(negative_nodes),
    )
    if not negative_nodes:
        logger.info("No negative feedback found; returning empty list")
        return []

    pairs = _match_feedback_nodes_to_interactions_by_edges(
        negative_nodes, interaction_nodes, graph_edges
    )
    if not pairs:
        logger.info("No feedback-to-interaction matches found; returning empty list")
        return []

    # Newest interactions first, optionally capped at ``last_n``.
    recent_pairs = _sort_pairs_by_recency_and_limit(pairs, last_n)
    records = await _build_feedback_interaction_records(recent_pairs)

    logger.info("Extracted feedback pairs", count=len(records))
    return records
@@ -0,0 +1,130 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import List, Optional
4
+ from pydantic import BaseModel
5
+
6
+ from cognee.infrastructure.llm import LLMGateway
7
+ from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_prompt
8
+ from cognee.modules.graph.utils import resolve_edges_to_text
9
+ from cognee.shared.logging_utils import get_logger
10
+
11
+ from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
12
+ from .models import FeedbackEnrichment
13
+
14
+
15
class ImprovedAnswerResponse(BaseModel):
    """Structured completion schema: an answer together with its explanation."""

    # Text of the newly generated answer.
    answer: str
    # Explanation returned alongside the answer.
    explanation: str
20
+
21
+
22
+ logger = get_logger("generate_improved_answers")
23
+
24
+
25
+ def _validate_input_data(enrichments: List[FeedbackEnrichment]) -> bool:
26
+ """Validate that input contains required fields for all enrichments."""
27
+ return all(
28
+ enrichment.question is not None
29
+ and enrichment.original_answer is not None
30
+ and enrichment.context is not None
31
+ and enrichment.feedback_text is not None
32
+ and enrichment.feedback_id is not None
33
+ and enrichment.interaction_id is not None
34
+ for enrichment in enrichments
35
+ )
36
+
37
+
38
def _render_reaction_prompt(
    question: str,
    context: str,
    wrong_answer: str,
    negative_feedback: str,
    prompt_location: str = "feedback_reaction_prompt.txt",
) -> str:
    """Fill the feedback reaction prompt template with the given values.

    Args:
        question: The original user question.
        context: Context shown alongside the original answer.
        wrong_answer: The answer that received negative feedback.
        negative_feedback: The feedback text itself.
        prompt_location: Prompt file handed to ``read_query_prompt``; defaults
            to the reaction prompt used so far.

    Returns:
        The rendered prompt text.
    """
    prompt_template = read_query_prompt(prompt_location)
    return prompt_template.format(
        question=question,
        context=context,
        wrong_answer=wrong_answer,
        negative_feedback=negative_feedback,
    )


async def _generate_improved_answer_for_single_interaction(
    enrichment: FeedbackEnrichment, retriever, reaction_prompt_location: str
) -> Optional[FeedbackEnrichment]:
    """Generate an improved answer for one enrichment via the retriever.

    Args:
        enrichment: Record to improve; updated in place on success.
        retriever: Object providing ``get_context``,
            ``get_structured_completion`` and ``resolve_edges_to_text``.
        reaction_prompt_location: Prompt file used to render the query.

    Returns:
        The same enrichment with ``improved_answer``, ``new_context`` and
        ``explanation`` filled in, or ``None`` when no completion was produced
        or an error occurred. Errors are logged and never raised.
    """
    try:
        # Honour the caller-supplied prompt file instead of a fixed name.
        query_text = _render_reaction_prompt(
            enrichment.question,
            enrichment.context,
            enrichment.original_answer,
            enrichment.feedback_text,
            prompt_location=reaction_prompt_location,
        )

        retrieved_context = await retriever.get_context(query_text)
        completion = await retriever.get_structured_completion(
            query=query_text,
            context=retrieved_context,
            response_model=ImprovedAnswerResponse,
            max_iter=4,
        )

        if not completion:
            logger.warning(
                "Failed to get structured completion from retriever", question=enrichment.question
            )
            return None

        # Resolve the context text only once a completion exists, and before
        # mutating the enrichment so a failure here leaves it untouched.
        new_context_text = await retriever.resolve_edges_to_text(retrieved_context)

        enrichment.improved_answer = completion.answer
        enrichment.new_context = new_context_text
        enrichment.explanation = completion.explanation
        return enrichment

    except Exception as exc:  # noqa: BLE001
        logger.error(
            "Failed to generate improved answer",
            error=str(exc),
            question=enrichment.question,
        )
        return None
90
+
91
+
92
async def generate_improved_answers(
    enrichments: List[FeedbackEnrichment],
    top_k: int = 20,
    reaction_prompt_location: str = "feedback_reaction_prompt.txt",
) -> List[FeedbackEnrichment]:
    """Generate improved answers for each enrichment with the CoT retriever.

    Args:
        enrichments: Records to improve, processed sequentially.
        top_k: Passed to the retriever as ``top_k``.
        reaction_prompt_location: Forwarded to the per-enrichment helper.

    Returns:
        The enrichments that were improved. Empty when no input is given or
        when any input record is missing a required field.
    """
    if not enrichments:
        logger.info("No enrichments provided; returning empty list")
        return []

    inputs_are_valid = _validate_input_data(enrichments)
    if not inputs_are_valid:
        logger.error("Input data validation failed; missing required fields")
        return []

    cot_retriever = GraphCompletionCotRetriever(
        top_k=top_k,
        save_interaction=False,
        user_prompt_path="graph_context_for_question.txt",
        system_prompt_path="answer_simple_question.txt",
    )

    succeeded: List[FeedbackEnrichment] = []
    for enrichment in enrichments:
        outcome = await _generate_improved_answer_for_single_interaction(
            enrichment, cot_retriever, reaction_prompt_location
        )
        if not outcome:
            logger.warning(
                "Failed to generate improved answer",
                question=enrichment.question,
                interaction_id=enrichment.interaction_id,
            )
            continue
        succeeded.append(outcome)

    logger.info("Generated improved answers", count=len(succeeded))
    return succeeded
@@ -0,0 +1,67 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import List, Tuple
4
+ from uuid import UUID
5
+
6
+ from cognee.infrastructure.databases.graph import get_graph_engine
7
+ from cognee.tasks.storage import index_graph_edges
8
+ from cognee.shared.logging_utils import get_logger
9
+
10
+ from .models import FeedbackEnrichment
11
+
12
+
13
+ logger = get_logger("link_enrichments_to_feedback")
14
+
15
+
16
+ def _create_edge_tuple(
17
+ source_id: UUID, target_id: UUID, relationship_name: str
18
+ ) -> Tuple[UUID, UUID, str, dict]:
19
+ """Create an edge tuple with proper properties structure."""
20
+ return (
21
+ source_id,
22
+ target_id,
23
+ relationship_name,
24
+ {
25
+ "relationship_name": relationship_name,
26
+ "source_node_id": source_id,
27
+ "target_node_id": target_id,
28
+ "ontology_valid": False,
29
+ },
30
+ )
31
+
32
+
33
async def link_enrichments_to_feedback(
    enrichments: List[FeedbackEnrichment],
) -> List[FeedbackEnrichment]:
    """Create graph edges from enrichments to their feedback and interaction nodes.

    For each enrichment with an id, an ``enriches_feedback`` edge is added to
    its feedback node and an ``improves_interaction`` edge to its interaction
    node, whenever the respective target id is set. The edges are written to
    the graph engine and then passed to ``index_graph_edges``.

    Returns:
        The input enrichments unchanged, or ``[]`` when none were given.
    """
    if not enrichments:
        logger.info("No enrichments provided; returning empty list")
        return []

    edges = []
    for enrichment in enrichments:
        source = enrichment.id
        candidate_targets = (
            (enrichment.feedback_id, "enriches_feedback"),
            (enrichment.interaction_id, "improves_interaction"),
        )
        edges.extend(
            _create_edge_tuple(source, target, relationship)
            for target, relationship in candidate_targets
            if source and target
        )

    if edges:
        engine = await get_graph_engine()
        await engine.add_edges(edges)
        await index_graph_edges(edges)
        logger.info("Linking enrichments to feedback", edge_count=len(edges))

    logger.info("Linked enrichments", enrichment_count=len(enrichments))
    return enrichments
@@ -0,0 +1,26 @@
1
+ from typing import List, Optional, Union
2
+ from uuid import UUID
3
+
4
+ from cognee.infrastructure.engine import DataPoint
5
+ from cognee.modules.engine.models import Entity, NodeSet
6
+ from cognee.tasks.temporal_graph.models import Event
7
+
8
+
9
class FeedbackEnrichment(DataPoint):
    """DataPoint holding one feedback-driven enrichment of a past interaction.

    Kept minimal so that it works with ``extract_graph_from_data``.
    """

    # Graph-extraction fields; ``metadata`` lists ``text`` under ``index_fields``.
    text: str
    contains: Optional[List[Union[Entity, Event]]] = None
    metadata: dict = {"index_fields": ["text"]}

    # Required: the original exchange, the improved answer, and the ids of the
    # feedback and interaction nodes this enrichment refers to.
    question: str
    original_answer: str
    improved_answer: str
    feedback_id: UUID
    interaction_id: UUID
    belongs_to_set: Optional[List[NodeSet]] = None

    # Optional free-text details; each defaults to an empty string.
    context: str = ""
    feedback_text: str = ""
    new_context: str = ""
    explanation: str = ""
@@ -0,0 +1,174 @@
1
+ """
2
+ End-to-end integration test for feedback enrichment feature.
3
+
4
+ Tests the complete feedback enrichment pipeline:
5
+ 1. Add data and cognify
6
+ 2. Run search with save_interaction=True to create CogneeUserInteraction nodes
7
+ 3. Submit feedback to create CogneeUserFeedback nodes
8
+ 4. Run memify with feedback enrichment tasks to create FeedbackEnrichment nodes
9
+ 5. Verify all nodes and edges are properly created and linked in the graph
10
+ """
11
+
12
+ import os
13
+ import pathlib
14
+ from collections import Counter
15
+
16
+ import cognee
17
+ from cognee.infrastructure.databases.graph import get_graph_engine
18
+ from cognee.modules.pipelines.tasks.task import Task
19
+ from cognee.modules.search.types import SearchType
20
+ from cognee.shared.data_models import KnowledgeGraph
21
+ from cognee.shared.logging_utils import get_logger
22
+ from cognee.tasks.feedback.create_enrichments import create_enrichments
23
+ from cognee.tasks.feedback.extract_feedback_interactions import (
24
+ extract_feedback_interactions,
25
+ )
26
+ from cognee.tasks.feedback.generate_improved_answers import generate_improved_answers
27
+ from cognee.tasks.feedback.link_enrichments_to_feedback import (
28
+ link_enrichments_to_feedback,
29
+ )
30
+ from cognee.tasks.graph import extract_graph_from_data
31
+ from cognee.tasks.storage import add_data_points
32
+
33
+ logger = get_logger()
34
+
35
+
36
async def main():
    """Run the feedback enrichment pipeline end to end and check the graph.

    Adds and cognifies a document, runs a search that saves the interaction,
    submits feedback, runs memify with the feedback tasks, and asserts that
    FeedbackEnrichment nodes and edges touching them exist afterwards.
    """

    def _nodes_of_type(nodes, type_name):
        # Select ``(node_id, props)`` pairs whose ``type`` property matches.
        return [(node_id, props) for node_id, props in nodes if props.get("type") == type_name]

    tests_dir = pathlib.Path(__file__).parent
    data_directory_path = str((tests_dir / ".data_storage/test_feedback_enrichment").resolve())
    cognee_directory_path = str((tests_dir / ".cognee_system/test_feedback_enrichment").resolve())

    cognee.config.data_root_directory(data_directory_path)
    cognee.config.system_root_directory(cognee_directory_path)

    # Prune existing data and system state before the run.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    dataset_name = "feedback_enrichment_test"

    await cognee.add("Cognee turns documents into AI memory.", dataset_name)
    await cognee.cognify([dataset_name])

    # Search with save_interaction=True (creates the CogneeUserInteraction node).
    question_text = "Say something."
    search_results = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION,
        query_text=question_text,
        save_interaction=True,
    )
    assert len(search_results) > 0, "Search should return non-empty results"

    # Submit feedback on the last interaction (creates the CogneeUserFeedback node).
    feedback_text = "This answer was completely useless, my feedback is definitely negative."
    await cognee.search(
        query_type=SearchType.FEEDBACK,
        query_text=feedback_text,
        last_k=1,
    )

    graph_engine = await get_graph_engine()
    nodes_before, edges_before = await graph_engine.get_graph_data()

    interaction_nodes_before = _nodes_of_type(nodes_before, "CogneeUserInteraction")
    feedback_nodes_before = _nodes_of_type(nodes_before, "CogneeUserFeedback")
    edge_types_before = Counter(edge[2] for edge in edges_before)

    assert len(interaction_nodes_before) >= 1, (
        f"Expected at least 1 CogneeUserInteraction node, found {len(interaction_nodes_before)}"
    )
    assert len(feedback_nodes_before) >= 1, (
        f"Expected at least 1 CogneeUserFeedback node, found {len(feedback_nodes_before)}"
    )

    for _, feedback_props in feedback_nodes_before:
        logger.info(
            "Feedback node created",
            feedback=feedback_props.get("feedback", ""),
            sentiment=feedback_props.get("sentiment", ""),
            score=feedback_props.get("score", 0),
        )

    assert edge_types_before.get("gives_feedback_to", 0) >= 1, (
        f"Expected at least 1 'gives_feedback_to' edge, found {edge_types_before.get('gives_feedback_to', 0)}"
    )

    extraction_tasks = [Task(extract_feedback_interactions, last_n=5)]
    enrichment_tasks = [
        Task(generate_improved_answers, top_k=20),
        Task(create_enrichments),
        Task(
            extract_graph_from_data,
            graph_model=KnowledgeGraph,
            task_config={"batch_size": 10},
        ),
        Task(add_data_points, task_config={"batch_size": 10}),
        Task(link_enrichments_to_feedback),
    ]

    await cognee.memify(
        extraction_tasks=extraction_tasks,
        enrichment_tasks=enrichment_tasks,
        data=[{}],
        dataset="feedback_enrichment_test_memify",
    )

    nodes_after, edges_after = await graph_engine.get_graph_data()
    enrichment_nodes = _nodes_of_type(nodes_after, "FeedbackEnrichment")

    assert len(enrichment_nodes) >= 1, (
        f"Expected at least 1 FeedbackEnrichment node, found {len(enrichment_nodes)}"
    )

    for node_id, props in enrichment_nodes:
        assert "text" in props, f"FeedbackEnrichment node {node_id} missing 'text' property"

    # Collect edges with an enrichment node at either endpoint.
    enrichment_node_ids = {node_id for node_id, _ in enrichment_nodes}
    edges_with_enrichments = [
        edge
        for edge in edges_after
        if edge[0] in enrichment_node_ids or edge[1] in enrichment_node_ids
    ]

    assert len(edges_with_enrichments) >= 1, (
        f"Expected enrichment nodes to have at least 1 edge, found {len(edges_with_enrichments)}"
    )

    # Clean up the test data and system state.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    logger.info("All feedback enrichment tests passed successfully")
169
+
170
+
171
+ if __name__ == "__main__":
172
+ import asyncio
173
+
174
+ asyncio.run(main())
@@ -2,6 +2,7 @@ import os
2
2
  import pytest
3
3
  import pathlib
4
4
  from typing import Optional, Union
5
+ from pydantic import BaseModel
5
6
 
6
7
  import cognee
7
8
  from cognee.low_level import setup, DataPoint
@@ -10,6 +11,11 @@ from cognee.tasks.storage import add_data_points
10
11
  from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
11
12
 
12
13
 
14
+ class TestAnswer(BaseModel):
15
+ answer: str
16
+ explanation: str
17
+
18
+
13
19
  class TestGraphCompletionCoTRetriever:
14
20
  @pytest.mark.asyncio
15
21
  async def test_graph_completion_cot_context_simple(self):
@@ -168,3 +174,48 @@ class TestGraphCompletionCoTRetriever:
168
174
  assert all(isinstance(item, str) and item.strip() for item in answer), (
169
175
  "Answer must contain only non-empty strings"
170
176
  )
177
+
178
+ @pytest.mark.asyncio
179
+ async def test_get_structured_completion(self):
180
+ system_directory_path = os.path.join(
181
+ pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion"
182
+ )
183
+ cognee.config.system_root_directory(system_directory_path)
184
+ data_directory_path = os.path.join(
185
+ pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion"
186
+ )
187
+ cognee.config.data_root_directory(data_directory_path)
188
+
189
+ await cognee.prune.prune_data()
190
+ await cognee.prune.prune_system(metadata=True)
191
+ await setup()
192
+
193
+ class Company(DataPoint):
194
+ name: str
195
+
196
+ class Person(DataPoint):
197
+ name: str
198
+ works_for: Company
199
+
200
+ company1 = Company(name="Figma")
201
+ person1 = Person(name="Steve Rodger", works_for=company1)
202
+
203
+ entities = [company1, person1]
204
+ await add_data_points(entities)
205
+
206
+ retriever = GraphCompletionCotRetriever()
207
+
208
+ # Test with string response model (default)
209
+ string_answer = await retriever.get_structured_completion("Who works at Figma?")
210
+ assert isinstance(string_answer, str), f"Expected str, got {type(string_answer).__name__}"
211
+ assert string_answer.strip(), "Answer should not be empty"
212
+
213
+ # Test with structured response model
214
+ structured_answer = await retriever.get_structured_completion(
215
+ "Who works at Figma?", response_model=TestAnswer
216
+ )
217
+ assert isinstance(structured_answer, TestAnswer), (
218
+ f"Expected TestAnswer, got {type(structured_answer).__name__}"
219
+ )
220
+ assert structured_answer.answer.strip(), "Answer field should not be empty"
221
+ assert structured_answer.explanation.strip(), "Explanation field should not be empty"