cognee 0.3.6__py3-none-any.whl → 0.3.7.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/__init__.py +1 -0
- cognee/api/health.py +2 -12
- cognee/api/v1/add/add.py +46 -6
- cognee/api/v1/add/routers/get_add_router.py +11 -2
- cognee/api/v1/cognify/cognify.py +29 -9
- cognee/api/v1/cognify/routers/get_cognify_router.py +2 -1
- cognee/api/v1/datasets/datasets.py +11 -0
- cognee/api/v1/datasets/routers/get_datasets_router.py +8 -0
- cognee/api/v1/delete/routers/get_delete_router.py +2 -0
- cognee/api/v1/memify/routers/get_memify_router.py +2 -1
- cognee/api/v1/permissions/routers/get_permissions_router.py +6 -0
- cognee/api/v1/responses/default_tools.py +0 -1
- cognee/api/v1/responses/dispatch_function.py +1 -1
- cognee/api/v1/responses/routers/default_tools.py +0 -1
- cognee/api/v1/search/routers/get_search_router.py +3 -3
- cognee/api/v1/search/search.py +11 -9
- cognee/api/v1/settings/routers/get_settings_router.py +7 -1
- cognee/api/v1/sync/routers/get_sync_router.py +3 -0
- cognee/api/v1/ui/ui.py +45 -16
- cognee/api/v1/update/routers/get_update_router.py +3 -1
- cognee/api/v1/update/update.py +3 -3
- cognee/api/v1/users/routers/get_visualize_router.py +2 -0
- cognee/cli/_cognee.py +61 -10
- cognee/cli/commands/add_command.py +3 -3
- cognee/cli/commands/cognify_command.py +3 -3
- cognee/cli/commands/config_command.py +9 -7
- cognee/cli/commands/delete_command.py +3 -3
- cognee/cli/commands/search_command.py +3 -7
- cognee/cli/config.py +0 -1
- cognee/context_global_variables.py +5 -0
- cognee/exceptions/exceptions.py +1 -1
- cognee/infrastructure/databases/cache/__init__.py +2 -0
- cognee/infrastructure/databases/cache/cache_db_interface.py +79 -0
- cognee/infrastructure/databases/cache/config.py +44 -0
- cognee/infrastructure/databases/cache/get_cache_engine.py +67 -0
- cognee/infrastructure/databases/cache/redis/RedisAdapter.py +243 -0
- cognee/infrastructure/databases/exceptions/__init__.py +1 -0
- cognee/infrastructure/databases/exceptions/exceptions.py +18 -2
- cognee/infrastructure/databases/graph/get_graph_engine.py +1 -1
- cognee/infrastructure/databases/graph/graph_db_interface.py +5 -0
- cognee/infrastructure/databases/graph/kuzu/adapter.py +76 -47
- cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +13 -3
- cognee/infrastructure/databases/graph/neo4j_driver/deadlock_retry.py +1 -1
- cognee/infrastructure/databases/graph/neptune_driver/neptune_utils.py +1 -1
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +1 -1
- cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +21 -3
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +17 -10
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +17 -4
- cognee/infrastructure/databases/vector/embeddings/config.py +2 -3
- cognee/infrastructure/databases/vector/exceptions/exceptions.py +1 -1
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +0 -1
- cognee/infrastructure/files/exceptions.py +1 -1
- cognee/infrastructure/files/storage/LocalFileStorage.py +9 -9
- cognee/infrastructure/files/storage/S3FileStorage.py +11 -11
- cognee/infrastructure/files/utils/guess_file_type.py +6 -0
- cognee/infrastructure/llm/prompts/feedback_reaction_prompt.txt +14 -0
- cognee/infrastructure/llm/prompts/feedback_report_prompt.txt +13 -0
- cognee/infrastructure/llm/prompts/feedback_user_context_prompt.txt +5 -0
- cognee/infrastructure/llm/prompts/search_type_selector_prompt.txt +0 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +19 -9
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +17 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +17 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +32 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/__init__.py +0 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +109 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +33 -8
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +40 -18
- cognee/infrastructure/loaders/LoaderEngine.py +27 -7
- cognee/infrastructure/loaders/external/__init__.py +7 -0
- cognee/infrastructure/loaders/external/advanced_pdf_loader.py +2 -8
- cognee/infrastructure/loaders/external/beautiful_soup_loader.py +310 -0
- cognee/infrastructure/loaders/supported_loaders.py +7 -0
- cognee/modules/data/exceptions/exceptions.py +1 -1
- cognee/modules/data/methods/__init__.py +3 -0
- cognee/modules/data/methods/get_dataset_data.py +4 -1
- cognee/modules/data/methods/has_dataset_data.py +21 -0
- cognee/modules/engine/models/TableRow.py +0 -1
- cognee/modules/ingestion/save_data_to_file.py +9 -2
- cognee/modules/pipelines/exceptions/exceptions.py +1 -1
- cognee/modules/pipelines/operations/pipeline.py +12 -1
- cognee/modules/pipelines/operations/run_tasks.py +25 -197
- cognee/modules/pipelines/operations/run_tasks_base.py +7 -0
- cognee/modules/pipelines/operations/run_tasks_data_item.py +260 -0
- cognee/modules/pipelines/operations/run_tasks_distributed.py +121 -38
- cognee/modules/pipelines/operations/run_tasks_with_telemetry.py +9 -1
- cognee/modules/retrieval/EntityCompletionRetriever.py +48 -8
- cognee/modules/retrieval/base_graph_retriever.py +3 -1
- cognee/modules/retrieval/base_retriever.py +3 -1
- cognee/modules/retrieval/chunks_retriever.py +5 -1
- cognee/modules/retrieval/code_retriever.py +20 -2
- cognee/modules/retrieval/completion_retriever.py +50 -9
- cognee/modules/retrieval/cypher_search_retriever.py +11 -1
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +47 -8
- cognee/modules/retrieval/graph_completion_cot_retriever.py +152 -22
- cognee/modules/retrieval/graph_completion_retriever.py +54 -10
- cognee/modules/retrieval/lexical_retriever.py +20 -2
- cognee/modules/retrieval/natural_language_retriever.py +10 -1
- cognee/modules/retrieval/summaries_retriever.py +5 -1
- cognee/modules/retrieval/temporal_retriever.py +62 -10
- cognee/modules/retrieval/user_qa_feedback.py +3 -2
- cognee/modules/retrieval/utils/completion.py +30 -4
- cognee/modules/retrieval/utils/description_to_codepart_search.py +1 -1
- cognee/modules/retrieval/utils/session_cache.py +156 -0
- cognee/modules/search/methods/get_search_type_tools.py +0 -5
- cognee/modules/search/methods/no_access_control_search.py +12 -1
- cognee/modules/search/methods/search.py +51 -5
- cognee/modules/search/types/SearchType.py +0 -1
- cognee/modules/settings/get_settings.py +23 -0
- cognee/modules/users/methods/get_authenticated_user.py +3 -1
- cognee/modules/users/methods/get_default_user.py +1 -6
- cognee/modules/users/roles/methods/create_role.py +2 -2
- cognee/modules/users/tenants/methods/create_tenant.py +2 -2
- cognee/shared/exceptions/exceptions.py +1 -1
- cognee/shared/logging_utils.py +18 -11
- cognee/shared/utils.py +24 -2
- cognee/tasks/codingagents/coding_rule_associations.py +1 -2
- cognee/tasks/documents/exceptions/exceptions.py +1 -1
- cognee/tasks/feedback/__init__.py +13 -0
- cognee/tasks/feedback/create_enrichments.py +84 -0
- cognee/tasks/feedback/extract_feedback_interactions.py +230 -0
- cognee/tasks/feedback/generate_improved_answers.py +130 -0
- cognee/tasks/feedback/link_enrichments_to_feedback.py +67 -0
- cognee/tasks/feedback/models.py +26 -0
- cognee/tasks/graph/extract_graph_from_data.py +2 -0
- cognee/tasks/ingestion/data_item_to_text_file.py +3 -3
- cognee/tasks/ingestion/ingest_data.py +11 -5
- cognee/tasks/ingestion/save_data_item_to_storage.py +12 -1
- cognee/tasks/storage/add_data_points.py +3 -10
- cognee/tasks/storage/index_data_points.py +19 -14
- cognee/tasks/storage/index_graph_edges.py +25 -11
- cognee/tasks/web_scraper/__init__.py +34 -0
- cognee/tasks/web_scraper/config.py +26 -0
- cognee/tasks/web_scraper/default_url_crawler.py +446 -0
- cognee/tasks/web_scraper/models.py +46 -0
- cognee/tasks/web_scraper/types.py +4 -0
- cognee/tasks/web_scraper/utils.py +142 -0
- cognee/tasks/web_scraper/web_scraper_task.py +396 -0
- cognee/tests/cli_tests/cli_unit_tests/test_cli_utils.py +0 -1
- cognee/tests/integration/web_url_crawler/test_default_url_crawler.py +13 -0
- cognee/tests/integration/web_url_crawler/test_tavily_crawler.py +19 -0
- cognee/tests/integration/web_url_crawler/test_url_adding_e2e.py +344 -0
- cognee/tests/subprocesses/reader.py +25 -0
- cognee/tests/subprocesses/simple_cognify_1.py +31 -0
- cognee/tests/subprocesses/simple_cognify_2.py +31 -0
- cognee/tests/subprocesses/writer.py +32 -0
- cognee/tests/tasks/descriptive_metrics/metrics_test_utils.py +0 -2
- cognee/tests/tasks/descriptive_metrics/neo4j_metrics_test.py +8 -3
- cognee/tests/tasks/entity_extraction/entity_extraction_test.py +89 -0
- cognee/tests/tasks/web_scraping/web_scraping_test.py +172 -0
- cognee/tests/test_add_docling_document.py +56 -0
- cognee/tests/test_chromadb.py +7 -11
- cognee/tests/test_concurrent_subprocess_access.py +76 -0
- cognee/tests/test_conversation_history.py +240 -0
- cognee/tests/test_feedback_enrichment.py +174 -0
- cognee/tests/test_kuzu.py +27 -15
- cognee/tests/test_lancedb.py +7 -11
- cognee/tests/test_library.py +32 -2
- cognee/tests/test_neo4j.py +24 -16
- cognee/tests/test_neptune_analytics_vector.py +7 -11
- cognee/tests/test_permissions.py +9 -13
- cognee/tests/test_pgvector.py +4 -4
- cognee/tests/test_remote_kuzu.py +8 -11
- cognee/tests/test_s3_file_storage.py +1 -1
- cognee/tests/test_search_db.py +6 -8
- cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py +89 -0
- cognee/tests/unit/modules/retrieval/conversation_history_test.py +154 -0
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +51 -0
- {cognee-0.3.6.dist-info → cognee-0.3.7.dev1.dist-info}/METADATA +21 -6
- {cognee-0.3.6.dist-info → cognee-0.3.7.dev1.dist-info}/RECORD +178 -139
- {cognee-0.3.6.dist-info → cognee-0.3.7.dev1.dist-info}/entry_points.txt +1 -0
- distributed/Dockerfile +0 -3
- distributed/entrypoint.py +21 -9
- distributed/signal.py +5 -0
- distributed/workers/data_point_saving_worker.py +64 -34
- distributed/workers/graph_saving_worker.py +71 -47
- cognee/infrastructure/databases/graph/memgraph/memgraph_adapter.py +0 -1116
- cognee/modules/retrieval/insights_retriever.py +0 -133
- cognee/tests/test_memgraph.py +0 -109
- cognee/tests/unit/modules/retrieval/insights_retriever_test.py +0 -251
- {cognee-0.3.6.dist-info → cognee-0.3.7.dev1.dist-info}/WHEEL +0 -0
- {cognee-0.3.6.dist-info → cognee-0.3.7.dev1.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.3.6.dist-info → cognee-0.3.7.dev1.dist-info}/licenses/NOTICE.md +0 -0
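Highlights visible in this file list: a new cache database layer under cognee/infrastructure/databases/cache (interface, config, and a Redis adapter), a web-scraper task family, feedback-enrichment tasks, and session-scoped conversation history threaded through the retrieval module; the Memgraph graph adapter and the insights retriever (with their tests) are removed. Selected diffs from cognee/modules/retrieval are reproduced below; removed lines whose text was not captured by the diff renderer appear as bare "-" markers.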
--- a/cognee/modules/retrieval/graph_completion_cot_retriever.py
+++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py
@@ -1,15 +1,41 @@
+import asyncio
+import json
 from typing import Optional, List, Type, Any
+from pydantic import BaseModel
 from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
 from cognee.shared.logging_utils import get_logger

 from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
-from cognee.modules.retrieval.utils.completion import
+from cognee.modules.retrieval.utils.completion import (
+    generate_structured_completion,
+    summarize_text,
+)
+from cognee.modules.retrieval.utils.session_cache import (
+    save_conversation_history,
+    get_conversation_history,
+)
 from cognee.infrastructure.llm.LLMGateway import LLMGateway
 from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
+from cognee.context_global_variables import session_user
+from cognee.infrastructure.databases.cache.config import CacheConfig

 logger = get_logger()


+def _as_answer_text(completion: Any) -> str:
+    """Convert completion to human-readable text for validation and follow-up prompts."""
+    if isinstance(completion, str):
+        return completion
+    if isinstance(completion, BaseModel):
+        # Add notice that this is a structured response
+        json_str = completion.model_dump_json(indent=2)
+        return f"[Structured Response]\n{json_str}"
+    try:
+        return json.dumps(completion, indent=2)
+    except TypeError:
+        return str(completion)
+
+
 class GraphCompletionCotRetriever(GraphCompletionRetriever):
     """
     Handles graph completion by generating responses based on a series of interactions with
@@ -18,6 +44,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
     questions based on reasoning. The public methods are:

     - get_completion
+    - get_structured_completion

     Instance variables include:
     - validation_system_prompt_path
@@ -54,33 +81,30 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
         self.followup_system_prompt_path = followup_system_prompt_path
         self.followup_user_prompt_path = followup_user_prompt_path

-    async def
+    async def _run_cot_completion(
         self,
         query: str,
         context: Optional[List[Edge]] = None,
-
-
+        conversation_history: str = "",
+        max_iter: int = 4,
+        response_model: Type = str,
+    ) -> tuple[Any, str, List[Edge]]:
         """
-
-
-        This method interacts with a language model client to retrieve a structured response,
-        using a series of iterations to refine the answers and generate follow-up questions
-        based on reasoning derived from previous outputs. It raises exceptions if the context
-        retrieval fails or if the model encounters issues in generating outputs.
+        Run chain-of-thought completion with optional structured output.

         Parameters:
         -----------
-
-        -
-        -
-
-        -
-          follow-up questions. (default 4)
+        - query: User query
+        - context: Optional pre-fetched context edges
+        - conversation_history: Optional conversation history string
+        - max_iter: Maximum CoT iterations
+        - response_model: Type for structured output (str for plain text)

         Returns:
         --------
-
-        -
+        - completion_result: The generated completion (string or structured model)
+        - context_text: The resolved context text
+        - triplets: The list of triplets used
         """
         followup_question = ""
         triplets = []
@@ -97,16 +121,21 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
             triplets += await self.get_context(followup_question)
             context_text = await self.resolve_edges_to_text(list(set(triplets)))

-            completion = await
+            completion = await generate_structured_completion(
                 query=query,
                 context=context_text,
                 user_prompt_path=self.user_prompt_path,
                 system_prompt_path=self.system_prompt_path,
                 system_prompt=self.system_prompt,
+                conversation_history=conversation_history if conversation_history else None,
+                response_model=response_model,
             )
+
             logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
+
             if round_idx < max_iter:
-
+                answer_text = _as_answer_text(completion)
+                valid_args = {"query": query, "answer": answer_text, "context": context_text}
                 valid_user_prompt = render_prompt(
                     filename=self.validation_user_prompt_path, context=valid_args
                 )
@@ -119,7 +148,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
                     system_prompt=valid_system_prompt,
                     response_model=str,
                 )
-                followup_args = {"query": query, "answer":
+                followup_args = {"query": query, "answer": answer_text, "reasoning": reasoning}
                 followup_prompt = render_prompt(
                     filename=self.followup_user_prompt_path, context=followup_args
                 )
@@ -134,9 +163,110 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
                 f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
             )

+        return completion, context_text, triplets
+
+    async def get_structured_completion(
+        self,
+        query: str,
+        context: Optional[List[Edge]] = None,
+        session_id: Optional[str] = None,
+        max_iter: int = 4,
+        response_model: Type = str,
+    ) -> Any:
+        """
+        Generate structured completion responses based on a user query and contextual information.
+
+        This method applies the same chain-of-thought logic as get_completion but returns
+        structured output using the provided response model.
+
+        Parameters:
+        -----------
+        - query (str): The user's query to be processed and answered.
+        - context (Optional[List[Edge]]): Optional context that may assist in answering the query.
+          If not provided, it will be fetched based on the query. (default None)
+        - session_id (Optional[str]): Optional session identifier for caching. If None,
+          defaults to 'default_session'. (default None)
+        - max_iter: The maximum number of iterations to refine the answer and generate
+          follow-up questions. (default 4)
+        - response_model (Type): The Pydantic model type for structured output. (default str)
+
+        Returns:
+        --------
+        - Any: The generated structured completion based on the response model.
+        """
+        # Check if session saving is enabled
+        cache_config = CacheConfig()
+        user = session_user.get()
+        user_id = getattr(user, "id", None)
+        session_save = user_id and cache_config.caching
+
+        # Load conversation history if enabled
+        conversation_history = ""
+        if session_save:
+            conversation_history = await get_conversation_history(session_id=session_id)
+
+        completion, context_text, triplets = await self._run_cot_completion(
+            query=query,
+            context=context,
+            conversation_history=conversation_history,
+            max_iter=max_iter,
+            response_model=response_model,
+        )
+
         if self.save_interaction and context and triplets and completion:
             await self.save_qa(
-                question=query, answer=completion, context=context_text, triplets=triplets
+                question=query, answer=str(completion), context=context_text, triplets=triplets
+            )
+
+        # Save to session cache if enabled
+        if session_save:
+            context_summary = await summarize_text(context_text)
+            await save_conversation_history(
+                query=query,
+                context_summary=context_summary,
+                answer=str(completion),
+                session_id=session_id,
             )

+        return completion
+
+    async def get_completion(
+        self,
+        query: str,
+        context: Optional[List[Edge]] = None,
+        session_id: Optional[str] = None,
+        max_iter=4,
+    ) -> List[str]:
+        """
+        Generate completion responses based on a user query and contextual information.
+
+        This method interacts with a language model client to retrieve a structured response,
+        using a series of iterations to refine the answers and generate follow-up questions
+        based on reasoning derived from previous outputs. It raises exceptions if the context
+        retrieval fails or if the model encounters issues in generating outputs.
+
+        Parameters:
+        -----------
+
+        - query (str): The user's query to be processed and answered.
+        - context (Optional[Any]): Optional context that may assist in answering the query.
+          If not provided, it will be fetched based on the query. (default None)
+        - session_id (Optional[str]): Optional session identifier for caching. If None,
+          defaults to 'default_session'. (default None)
+        - max_iter: The maximum number of iterations to refine the answer and generate
+          follow-up questions. (default 4)
+
+        Returns:
+        --------
+
+        - List[str]: A list containing the generated answer to the user's query.
+        """
+        completion = await self.get_structured_completion(
+            query=query,
+            context=context,
+            session_id=session_id,
+            max_iter=max_iter,
+            response_model=str,
+        )
+
         return [completion]
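The `_as_answer_text` helper added above is what keeps the chain-of-thought loop usable with structured output: whatever `response_model` produced is normalized to plain text before it is fed into the validation and follow-up prompts. A standalone demonstration of its behavior (function body copied from the diff; the `Answer` model is a made-up example):

```python
import json
from typing import Any

from pydantic import BaseModel


def _as_answer_text(completion: Any) -> str:
    """Convert completion to human-readable text (as added in the diff above)."""
    if isinstance(completion, str):
        return completion
    if isinstance(completion, BaseModel):
        # Structured responses are labeled and dumped as indented JSON
        return f"[Structured Response]\n{completion.model_dump_json(indent=2)}"
    try:
        return json.dumps(completion, indent=2)
    except TypeError:
        return str(completion)


class Answer(BaseModel):
    # Hypothetical response model, purely for illustration
    text: str
    confidence: float


print(_as_answer_text("a plain string passes through unchanged"))
print(_as_answer_text(Answer(text="42", confidence=0.9)))  # labeled JSON dump
print(_as_answer_text({"key": "value"}))                   # JSON-serializable fallback
print(_as_answer_text(object()))                           # str() fallback on TypeError
```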
--- a/cognee/modules/retrieval/graph_completion_retriever.py
+++ b/cognee/modules/retrieval/graph_completion_retriever.py
@@ -1,20 +1,26 @@
+import asyncio
 from typing import Any, Optional, Type, List
 from uuid import NAMESPACE_OID, uuid5

 from cognee.infrastructure.engine import DataPoint
 from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
-from cognee.modules.users.methods import get_default_user
 from cognee.tasks.storage import add_data_points
 from cognee.modules.graph.utils import resolve_edges_to_text
 from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
 from cognee.modules.retrieval.base_graph_retriever import BaseGraphRetriever
 from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
-from cognee.modules.retrieval.utils.completion import generate_completion
+from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
+from cognee.modules.retrieval.utils.session_cache import (
+    save_conversation_history,
+    get_conversation_history,
+)
 from cognee.shared.logging_utils import get_logger
 from cognee.modules.retrieval.utils.extract_uuid_from_node import extract_uuid_from_node
 from cognee.modules.retrieval.utils.models import CogneeUserInteraction
 from cognee.modules.engine.models.node_set import NodeSet
 from cognee.infrastructure.databases.graph import get_graph_engine
+from cognee.context_global_variables import session_user
+from cognee.infrastructure.databases.cache.config import CacheConfig

 logger = get_logger("GraphCompletionRetriever")

@@ -118,6 +124,13 @@ class GraphCompletionRetriever(BaseGraphRetriever):
         - str: A string representing the resolved context from the retrieved triplets, or an
           empty string if no triplets are found.
         """
+        graph_engine = await get_graph_engine()
+        is_empty = await graph_engine.is_empty()
+
+        if is_empty:
+            logger.warning("Search attempt on an empty knowledge graph")
+            return []
+
         triplets = await self.get_triplets(query)

         if len(triplets) == 0:
@@ -132,6 +145,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
         self,
         query: str,
         context: Optional[List[Edge]] = None,
+        session_id: Optional[str] = None,
     ) -> List[str]:
         """
         Generates a completion using graph connections context based on a query.
@@ -142,6 +156,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
         - query (str): The query string for which a completion is generated.
         - context (Optional[Any]): Optional context to use for generating the completion; if
           not provided, context is retrieved based on the query. (default None)
+        - session_id (Optional[str]): Optional session identifier for caching. If None,
+          defaults to 'default_session'. (default None)

         Returns:
         --------
@@ -155,19 +171,47 @@ class GraphCompletionRetriever(BaseGraphRetriever):

         context_text = await resolve_edges_to_text(triplets)

-
-
-
-
-
-
-
+        cache_config = CacheConfig()
+        user = session_user.get()
+        user_id = getattr(user, "id", None)
+        session_save = user_id and cache_config.caching
+
+        if session_save:
+            conversation_history = await get_conversation_history(session_id=session_id)
+
+            context_summary, completion = await asyncio.gather(
+                summarize_text(context_text),
+                generate_completion(
+                    query=query,
+                    context=context_text,
+                    user_prompt_path=self.user_prompt_path,
+                    system_prompt_path=self.system_prompt_path,
+                    system_prompt=self.system_prompt,
+                    conversation_history=conversation_history,
+                ),
+            )
+        else:
+            completion = await generate_completion(
+                query=query,
+                context=context_text,
+                user_prompt_path=self.user_prompt_path,
+                system_prompt_path=self.system_prompt_path,
+                system_prompt=self.system_prompt,
+            )

         if self.save_interaction and context and triplets and completion:
             await self.save_qa(
                 question=query, answer=completion, context=context_text, triplets=triplets
             )

+        if session_save:
+            await save_conversation_history(
+                query=query,
+                context_summary=context_summary,
+                answer=completion,
+                session_id=session_id,
+            )
+
         return [completion]

     async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
@@ -194,7 +238,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
             belongs_to_set=interactions_node_set,
         )

-        await add_data_points(data_points=[cognee_user_interaction]
+        await add_data_points(data_points=[cognee_user_interaction])

         relationships = []
         relationship_name = "used_graph_element_to_answer"
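In the session branch above, the context summary (destined for the cache) and the answer itself are independent LLM calls, so the diff issues them concurrently with `asyncio.gather`. A self-contained sketch of that pattern, with stub coroutines standing in for `summarize_text` and `generate_completion`:

```python
import asyncio


async def summarize_stub(context: str) -> str:
    await asyncio.sleep(0.1)  # stands in for the summarization LLM call
    return f"summary of {len(context)} characters of context"


async def completion_stub(query: str, context: str) -> str:
    await asyncio.sleep(0.1)  # stands in for the main completion LLM call
    return f"answer to {query!r}"


async def main() -> None:
    context_text = "resolved graph triplets ..."
    # Both coroutines run concurrently, so total latency is roughly the
    # slower of the two calls rather than their sum.
    context_summary, completion = await asyncio.gather(
        summarize_stub(context_text),
        completion_stub("What changed in 0.3.7?", context_text),
    )
    print(context_summary)
    print(completion)


asyncio.run(main())
```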
--- a/cognee/modules/retrieval/lexical_retriever.py
+++ b/cognee/modules/retrieval/lexical_retriever.py
@@ -116,8 +116,26 @@ class LexicalRetriever(BaseRetriever):
         else:
             return [self.payloads[chunk_id] for chunk_id, _ in top_results]

-    async def get_completion(
-
+    async def get_completion(
+        self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
+    ) -> Any:
+        """
+        Returns context for the given query (retrieves if not provided).
+
+        Parameters:
+        -----------
+
+        - query (str): The query string to retrieve context for.
+        - context (Optional[Any]): Optional pre-fetched context; if None, it retrieves
+          the context for the query. (default None)
+        - session_id (Optional[str]): Optional session identifier for caching. If None,
+          defaults to 'default_session'. (default None)
+
+        Returns:
+        --------
+
+        - Any: The context, either provided or retrieved.
+        """
         if context is None:
             context = await self.get_context(query)
         return context
--- a/cognee/modules/retrieval/natural_language_retriever.py
+++ b/cognee/modules/retrieval/natural_language_retriever.py
@@ -122,10 +122,17 @@ class NaturalLanguageRetriever(BaseRetriever):
             query.
         """
         graph_engine = await get_graph_engine()
+        is_empty = await graph_engine.is_empty()
+
+        if is_empty:
+            logger.warning("Search attempt on an empty knowledge graph")
+            return []

         return await self._execute_cypher_query(query, graph_engine)

-    async def get_completion(
+    async def get_completion(
+        self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
+    ) -> Any:
         """
         Returns a completion based on the query and context.

@@ -139,6 +146,8 @@ class NaturalLanguageRetriever(BaseRetriever):
         - query (str): The natural language query to get a completion from.
         - context (Optional[Any]): The context in which to base the completion; if not
           provided, it will be retrieved using the query. (default None)
+        - session_id (Optional[str]): Optional session identifier for caching. If None,
+          defaults to 'default_session'. (default None)

         Returns:
         --------
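The `is_empty()` guard above mirrors the one added to `GraphCompletionRetriever.get_context` earlier: both retrievers now log a warning and return an empty result rather than querying a graph with no data, which reads as a guard against confusing failures when search runs before anything has been cognified (the diff itself does not state the motivation).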
--- a/cognee/modules/retrieval/summaries_retriever.py
+++ b/cognee/modules/retrieval/summaries_retriever.py
@@ -62,7 +62,9 @@ class SummariesRetriever(BaseRetriever):
         logger.info(f"Returning {len(summary_payloads)} summary payloads")
         return summary_payloads

-    async def get_completion(
+    async def get_completion(
+        self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None, **kwargs
+    ) -> Any:
         """
         Generates a completion using summaries context.

@@ -75,6 +77,8 @@ class SummariesRetriever(BaseRetriever):
         - query (str): The search query for generating the completion.
         - context (Optional[Any]): Optional context for the completion; if not provided,
           will be retrieved based on the query. (default None)
+        - session_id (Optional[str]): Optional session identifier for caching. If None,
+          defaults to 'default_session'. (default None)

         Returns:
         --------
--- a/cognee/modules/retrieval/temporal_retriever.py
+++ b/cognee/modules/retrieval/temporal_retriever.py
@@ -1,16 +1,22 @@
 import os
+import asyncio
 from typing import Any, Optional, List, Type


 from operator import itemgetter
 from cognee.infrastructure.databases.vector import get_vector_engine
-from cognee.modules.retrieval.utils.completion import generate_completion
+from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
+from cognee.modules.retrieval.utils.session_cache import (
+    save_conversation_history,
+    get_conversation_history,
+)
 from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.infrastructure.llm.prompts import render_prompt
 from cognee.infrastructure.llm import LLMGateway
 from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
 from cognee.shared.logging_utils import get_logger
-
+from cognee.context_global_variables import session_user
+from cognee.infrastructure.databases.cache.config import CacheConfig

 from cognee.tasks.temporal_graph.models import QueryInterval

@@ -137,17 +143,63 @@ class TemporalRetriever(GraphCompletionRetriever):

         return self.descriptions_to_string(top_k_events)

-    async def get_completion(
-
+    async def get_completion(
+        self, query: str, context: Optional[str] = None, session_id: Optional[str] = None
+    ) -> List[str]:
+        """
+        Generates a response using the query and optional context.
+
+        Parameters:
+        -----------
+
+        - query (str): The query string for which a completion is generated.
+        - context (Optional[str]): Optional context to use; if None, it will be
+          retrieved based on the query. (default None)
+        - session_id (Optional[str]): Optional session identifier for caching. If None,
+          defaults to 'default_session'. (default None)
+
+        Returns:
+        --------
+
+        - List[str]: A list containing the generated completion.
+        """
         if not context:
             context = await self.get_context(query=query)

         if context:
-
-
-
-
-
-
+            # Check if we need to generate context summary for caching
+            cache_config = CacheConfig()
+            user = session_user.get()
+            user_id = getattr(user, "id", None)
+            session_save = user_id and cache_config.caching
+
+            if session_save:
+                conversation_history = await get_conversation_history(session_id=session_id)
+
+                context_summary, completion = await asyncio.gather(
+                    summarize_text(context),
+                    generate_completion(
+                        query=query,
+                        context=context,
+                        user_prompt_path=self.user_prompt_path,
+                        system_prompt_path=self.system_prompt_path,
+                        conversation_history=conversation_history,
+                    ),
+                )
+            else:
+                completion = await generate_completion(
+                    query=query,
+                    context=context,
+                    user_prompt_path=self.user_prompt_path,
+                    system_prompt_path=self.system_prompt_path,
+                )
+
+            if session_save:
+                await save_conversation_history(
+                    query=query,
+                    context_summary=context_summary,
+                    answer=completion,
+                    session_id=session_id,
+                )

         return [completion]
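The session flow repeated across these retrievers pairs two helpers from the new `session_cache` module (`session_cache.py`, +156 lines, backed by the cache layer and Redis adapter added under `cognee/infrastructure/databases/cache`). A sketch of one conversational turn, with signatures taken from the diffs; it assumes a configured cache backend and an authenticated user in `session_user`:

```python
from cognee.modules.retrieval.utils.session_cache import (
    get_conversation_history,
    save_conversation_history,
)


async def qa_turn(query: str) -> str:
    # Prior turns come back as a single string, which generate_completion
    # prepends to the system prompt (see utils/completion.py below).
    history = await get_conversation_history(session_id="support-chat-1")

    answer = ...            # generate the answer, passing `history` along
    context_summary = ...   # summarize_text(...) over the retrieval context

    # Each turn stores the query, a summary of the retrieval context,
    # and the answer under the same session id for the next turn.
    await save_conversation_history(
        query=query,
        context_summary=context_summary,
        answer=answer,
        session_id="support-chat-1",
    )
    return answer
```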
--- a/cognee/modules/retrieval/user_qa_feedback.py
+++ b/cognee/modules/retrieval/user_qa_feedback.py
@@ -8,7 +8,7 @@ from cognee.shared.logging_utils import get_logger
 from cognee.modules.retrieval.base_feedback import BaseFeedback
 from cognee.modules.retrieval.utils.models import CogneeUserFeedback
 from cognee.modules.retrieval.utils.models import UserFeedbackEvaluation
-from cognee.tasks.storage import add_data_points
+from cognee.tasks.storage import add_data_points, index_graph_edges

 logger = get_logger("CompletionRetriever")

@@ -47,7 +47,7 @@ class UserQAFeedback(BaseFeedback):
             belongs_to_set=feedbacks_node_set,
         )

-        await add_data_points(data_points=[cognee_user_feedback]
+        await add_data_points(data_points=[cognee_user_feedback])

         relationships = []
         relationship_name = "gives_feedback_to"
@@ -76,6 +76,7 @@ class UserQAFeedback(BaseFeedback):
         if len(relationships) > 0:
             graph_engine = await get_graph_engine()
             await graph_engine.add_edges(relationships)
+            await index_graph_edges(relationships)
             await graph_engine.apply_feedback_weight(
                 node_ids=to_node_ids, weight=feedback_sentiment.score
             )
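The functional change here is the `index_graph_edges(relationships)` call inserted after `graph_engine.add_edges(relationships)`: presumably the new `gives_feedback_to` edges are now indexed for retrieval the same way edges created during cognify are (`index_graph_edges.py` itself also changes in this release, per the file list above).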
--- a/cognee/modules/retrieval/utils/completion.py
+++ b/cognee/modules/retrieval/utils/completion.py
@@ -1,23 +1,49 @@
-from typing import Optional
+from typing import Optional, Type, Any
 from cognee.infrastructure.llm.LLMGateway import LLMGateway
 from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt


-async def
+async def generate_structured_completion(
     query: str,
     context: str,
     user_prompt_path: str,
     system_prompt_path: str,
     system_prompt: Optional[str] = None,
-
-
+    conversation_history: Optional[str] = None,
+    response_model: Type = str,
+) -> Any:
+    """Generates a structured completion using LLM with given context and prompts."""
     args = {"question": query, "context": context}
     user_prompt = render_prompt(user_prompt_path, args)
     system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path)

+    if conversation_history:
+        #:TODO: I would separate the history and put it into the system prompt but we have to test what works best with longer convos
+        system_prompt = conversation_history + "\nTASK:" + system_prompt
+
     return await LLMGateway.acreate_structured_output(
         text_input=user_prompt,
         system_prompt=system_prompt,
+        response_model=response_model,
+    )
+
+
+async def generate_completion(
+    query: str,
+    context: str,
+    user_prompt_path: str,
+    system_prompt_path: str,
+    system_prompt: Optional[str] = None,
+    conversation_history: Optional[str] = None,
+) -> str:
+    """Generates a completion using LLM with given context and prompts."""
+    return await generate_structured_completion(
+        query=query,
+        context=context,
+        user_prompt_path=user_prompt_path,
+        system_prompt_path=system_prompt_path,
+        system_prompt=system_prompt,
+        conversation_history=conversation_history,
         response_model=str,
     )

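The one behavioral change to prompt assembly is history injection: when `conversation_history` is provided, it is prepended to the system prompt with a `TASK:` separator (the inline TODO records that this placement is still being evaluated against longer conversations). The string handling in isolation:

```python
conversation_history = (
    "USER: What does cognify do?\n"
    "ASSISTANT: It builds a knowledge graph from the added data."
)
system_prompt = "Answer the user's question using only the provided context."

# Mirrors the branch added in generate_structured_completion
if conversation_history:
    system_prompt = conversation_history + "\nTASK:" + system_prompt

print(system_prompt)
```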
--- a/cognee/modules/retrieval/utils/description_to_codepart_search.py
+++ b/cognee/modules/retrieval/utils/description_to_codepart_search.py
@@ -62,7 +62,7 @@ async def code_description_to_code_part(

    try:
        if include_docs:
-            search_results = await search(query_text=query, query_type="
+            search_results = await search(query_text=query, query_type="GRAPH_COMPLETION")

            concatenated_descriptions = " ".join(
                obj["description"]