cognee 0.2.3.dev1__py3-none-any.whl → 0.3.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/__init__.py +1 -0
- cognee/__main__.py +4 -0
- cognee/api/client.py +28 -3
- cognee/api/health.py +10 -13
- cognee/api/v1/add/add.py +20 -6
- cognee/api/v1/add/routers/get_add_router.py +12 -37
- cognee/api/v1/cloud/routers/__init__.py +1 -0
- cognee/api/v1/cloud/routers/get_checks_router.py +23 -0
- cognee/api/v1/cognify/code_graph_pipeline.py +14 -3
- cognee/api/v1/cognify/cognify.py +67 -105
- cognee/api/v1/cognify/routers/get_cognify_router.py +11 -3
- cognee/api/v1/datasets/routers/get_datasets_router.py +16 -5
- cognee/api/v1/memify/routers/__init__.py +1 -0
- cognee/api/v1/memify/routers/get_memify_router.py +100 -0
- cognee/api/v1/notebooks/routers/__init__.py +1 -0
- cognee/api/v1/notebooks/routers/get_notebooks_router.py +96 -0
- cognee/api/v1/responses/default_tools.py +4 -0
- cognee/api/v1/responses/dispatch_function.py +6 -1
- cognee/api/v1/responses/models.py +1 -1
- cognee/api/v1/search/routers/get_search_router.py +20 -1
- cognee/api/v1/search/search.py +17 -4
- cognee/api/v1/sync/__init__.py +17 -0
- cognee/api/v1/sync/routers/__init__.py +3 -0
- cognee/api/v1/sync/routers/get_sync_router.py +241 -0
- cognee/api/v1/sync/sync.py +877 -0
- cognee/api/v1/users/routers/get_auth_router.py +13 -1
- cognee/base_config.py +10 -1
- cognee/cli/__init__.py +10 -0
- cognee/cli/_cognee.py +180 -0
- cognee/cli/commands/__init__.py +1 -0
- cognee/cli/commands/add_command.py +80 -0
- cognee/cli/commands/cognify_command.py +128 -0
- cognee/cli/commands/config_command.py +225 -0
- cognee/cli/commands/delete_command.py +80 -0
- cognee/cli/commands/search_command.py +149 -0
- cognee/cli/config.py +33 -0
- cognee/cli/debug.py +21 -0
- cognee/cli/echo.py +45 -0
- cognee/cli/exceptions.py +23 -0
- cognee/cli/minimal_cli.py +97 -0
- cognee/cli/reference.py +26 -0
- cognee/cli/suppress_logging.py +12 -0
- cognee/eval_framework/corpus_builder/corpus_builder_executor.py +2 -2
- cognee/eval_framework/eval_config.py +1 -1
- cognee/infrastructure/databases/graph/config.py +10 -4
- cognee/infrastructure/databases/graph/get_graph_engine.py +4 -9
- cognee/infrastructure/databases/graph/kuzu/adapter.py +199 -2
- cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +138 -0
- cognee/infrastructure/databases/relational/__init__.py +2 -0
- cognee/infrastructure/databases/relational/get_async_session.py +15 -0
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +6 -1
- cognee/infrastructure/databases/relational/with_async_session.py +25 -0
- cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +1 -1
- cognee/infrastructure/databases/vector/config.py +13 -6
- cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +6 -4
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +16 -7
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +5 -5
- cognee/infrastructure/databases/vector/embeddings/config.py +2 -2
- cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +2 -6
- cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py +10 -7
- cognee/infrastructure/files/storage/LocalFileStorage.py +9 -0
- cognee/infrastructure/files/storage/S3FileStorage.py +5 -0
- cognee/infrastructure/files/storage/StorageManager.py +7 -1
- cognee/infrastructure/files/storage/storage.py +16 -0
- cognee/infrastructure/files/utils/get_data_file_path.py +14 -9
- cognee/infrastructure/files/utils/get_file_metadata.py +2 -1
- cognee/infrastructure/llm/LLMGateway.py +32 -5
- cognee/infrastructure/llm/config.py +6 -4
- cognee/infrastructure/llm/prompts/extract_query_time.txt +15 -0
- cognee/infrastructure/llm/prompts/generate_event_entity_prompt.txt +25 -0
- cognee/infrastructure/llm/prompts/generate_event_graph_prompt.txt +30 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/knowledge_graph/extract_content_graph.py +16 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/__init__.py +2 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/extract_event_entities.py +44 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/knowledge_graph/__init__.py +1 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/knowledge_graph/extract_content_graph.py +19 -15
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/knowledge_graph/extract_event_graph.py +46 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +3 -3
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +3 -3
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +2 -2
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +14 -8
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +6 -4
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +28 -4
- cognee/infrastructure/llm/tokenizer/Gemini/adapter.py +2 -2
- cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py +3 -3
- cognee/infrastructure/llm/tokenizer/Mistral/adapter.py +3 -3
- cognee/infrastructure/llm/tokenizer/TikToken/adapter.py +6 -6
- cognee/infrastructure/llm/utils.py +7 -7
- cognee/infrastructure/utils/run_sync.py +8 -1
- cognee/modules/chunking/models/DocumentChunk.py +4 -3
- cognee/modules/cloud/exceptions/CloudApiKeyMissingError.py +15 -0
- cognee/modules/cloud/exceptions/CloudConnectionError.py +15 -0
- cognee/modules/cloud/exceptions/__init__.py +2 -0
- cognee/modules/cloud/operations/__init__.py +1 -0
- cognee/modules/cloud/operations/check_api_key.py +25 -0
- cognee/modules/data/deletion/prune_system.py +1 -1
- cognee/modules/data/methods/__init__.py +2 -0
- cognee/modules/data/methods/check_dataset_name.py +1 -1
- cognee/modules/data/methods/create_authorized_dataset.py +19 -0
- cognee/modules/data/methods/get_authorized_dataset.py +11 -5
- cognee/modules/data/methods/get_authorized_dataset_by_name.py +16 -0
- cognee/modules/data/methods/get_dataset_data.py +1 -1
- cognee/modules/data/methods/load_or_create_datasets.py +2 -20
- cognee/modules/engine/models/Event.py +16 -0
- cognee/modules/engine/models/Interval.py +8 -0
- cognee/modules/engine/models/Timestamp.py +13 -0
- cognee/modules/engine/models/__init__.py +3 -0
- cognee/modules/engine/utils/__init__.py +2 -0
- cognee/modules/engine/utils/generate_event_datapoint.py +46 -0
- cognee/modules/engine/utils/generate_timestamp_datapoint.py +51 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +2 -2
- cognee/modules/graph/methods/get_formatted_graph_data.py +3 -2
- cognee/modules/graph/utils/__init__.py +1 -0
- cognee/modules/graph/utils/resolve_edges_to_text.py +71 -0
- cognee/modules/memify/__init__.py +1 -0
- cognee/modules/memify/memify.py +118 -0
- cognee/modules/notebooks/methods/__init__.py +5 -0
- cognee/modules/notebooks/methods/create_notebook.py +26 -0
- cognee/modules/notebooks/methods/delete_notebook.py +13 -0
- cognee/modules/notebooks/methods/get_notebook.py +21 -0
- cognee/modules/notebooks/methods/get_notebooks.py +18 -0
- cognee/modules/notebooks/methods/update_notebook.py +17 -0
- cognee/modules/notebooks/models/Notebook.py +53 -0
- cognee/modules/notebooks/models/__init__.py +1 -0
- cognee/modules/notebooks/operations/__init__.py +1 -0
- cognee/modules/notebooks/operations/run_in_local_sandbox.py +55 -0
- cognee/modules/pipelines/__init__.py +1 -1
- cognee/modules/pipelines/exceptions/tasks.py +18 -0
- cognee/modules/pipelines/layers/__init__.py +1 -0
- cognee/modules/pipelines/layers/check_pipeline_run_qualification.py +59 -0
- cognee/modules/pipelines/layers/pipeline_execution_mode.py +127 -0
- cognee/modules/pipelines/layers/reset_dataset_pipeline_run_status.py +28 -0
- cognee/modules/pipelines/layers/resolve_authorized_user_dataset.py +34 -0
- cognee/modules/pipelines/layers/resolve_authorized_user_datasets.py +55 -0
- cognee/modules/pipelines/layers/setup_and_check_environment.py +41 -0
- cognee/modules/pipelines/layers/validate_pipeline_tasks.py +20 -0
- cognee/modules/pipelines/methods/__init__.py +2 -0
- cognee/modules/pipelines/methods/get_pipeline_runs_by_dataset.py +34 -0
- cognee/modules/pipelines/methods/reset_pipeline_run_status.py +16 -0
- cognee/modules/pipelines/operations/__init__.py +0 -1
- cognee/modules/pipelines/operations/log_pipeline_run_initiated.py +1 -1
- cognee/modules/pipelines/operations/pipeline.py +24 -138
- cognee/modules/pipelines/operations/run_tasks.py +17 -41
- cognee/modules/retrieval/base_feedback.py +11 -0
- cognee/modules/retrieval/base_graph_retriever.py +18 -0
- cognee/modules/retrieval/base_retriever.py +1 -1
- cognee/modules/retrieval/code_retriever.py +8 -0
- cognee/modules/retrieval/coding_rules_retriever.py +31 -0
- cognee/modules/retrieval/completion_retriever.py +9 -3
- cognee/modules/retrieval/context_providers/TripletSearchContextProvider.py +1 -0
- cognee/modules/retrieval/cypher_search_retriever.py +1 -9
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +29 -13
- cognee/modules/retrieval/graph_completion_cot_retriever.py +30 -13
- cognee/modules/retrieval/graph_completion_retriever.py +107 -56
- cognee/modules/retrieval/graph_summary_completion_retriever.py +5 -1
- cognee/modules/retrieval/insights_retriever.py +14 -3
- cognee/modules/retrieval/natural_language_retriever.py +0 -4
- cognee/modules/retrieval/summaries_retriever.py +1 -1
- cognee/modules/retrieval/temporal_retriever.py +152 -0
- cognee/modules/retrieval/user_qa_feedback.py +83 -0
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +7 -32
- cognee/modules/retrieval/utils/completion.py +10 -3
- cognee/modules/retrieval/utils/extract_uuid_from_node.py +18 -0
- cognee/modules/retrieval/utils/models.py +40 -0
- cognee/modules/search/methods/get_search_type_tools.py +168 -0
- cognee/modules/search/methods/no_access_control_search.py +47 -0
- cognee/modules/search/methods/search.py +239 -118
- cognee/modules/search/types/SearchResult.py +21 -0
- cognee/modules/search/types/SearchType.py +3 -0
- cognee/modules/search/types/__init__.py +1 -0
- cognee/modules/search/utils/__init__.py +2 -0
- cognee/modules/search/utils/prepare_search_result.py +41 -0
- cognee/modules/search/utils/transform_context_to_graph.py +38 -0
- cognee/modules/settings/get_settings.py +2 -2
- cognee/modules/sync/__init__.py +1 -0
- cognee/modules/sync/methods/__init__.py +23 -0
- cognee/modules/sync/methods/create_sync_operation.py +53 -0
- cognee/modules/sync/methods/get_sync_operation.py +107 -0
- cognee/modules/sync/methods/update_sync_operation.py +248 -0
- cognee/modules/sync/models/SyncOperation.py +142 -0
- cognee/modules/sync/models/__init__.py +3 -0
- cognee/modules/users/__init__.py +0 -1
- cognee/modules/users/methods/__init__.py +4 -1
- cognee/modules/users/methods/create_user.py +26 -1
- cognee/modules/users/methods/get_authenticated_user.py +36 -42
- cognee/modules/users/methods/get_default_user.py +3 -1
- cognee/modules/users/permissions/methods/get_specific_user_permission_datasets.py +2 -1
- cognee/root_dir.py +19 -0
- cognee/shared/CodeGraphEntities.py +1 -0
- cognee/shared/logging_utils.py +143 -32
- cognee/shared/utils.py +0 -1
- cognee/tasks/codingagents/coding_rule_associations.py +127 -0
- cognee/tasks/graph/extract_graph_from_data.py +6 -2
- cognee/tasks/ingestion/save_data_item_to_storage.py +23 -0
- cognee/tasks/memify/__init__.py +2 -0
- cognee/tasks/memify/extract_subgraph.py +7 -0
- cognee/tasks/memify/extract_subgraph_chunks.py +11 -0
- cognee/tasks/repo_processor/get_local_dependencies.py +2 -0
- cognee/tasks/repo_processor/get_repo_file_dependencies.py +144 -47
- cognee/tasks/storage/add_data_points.py +33 -3
- cognee/tasks/temporal_graph/__init__.py +1 -0
- cognee/tasks/temporal_graph/add_entities_to_event.py +85 -0
- cognee/tasks/temporal_graph/enrich_events.py +34 -0
- cognee/tasks/temporal_graph/extract_events_and_entities.py +32 -0
- cognee/tasks/temporal_graph/extract_knowledge_graph_from_events.py +41 -0
- cognee/tasks/temporal_graph/models.py +49 -0
- cognee/tests/integration/cli/__init__.py +3 -0
- cognee/tests/integration/cli/test_cli_integration.py +331 -0
- cognee/tests/integration/documents/PdfDocument_test.py +2 -2
- cognee/tests/integration/documents/TextDocument_test.py +2 -4
- cognee/tests/integration/documents/UnstructuredDocument_test.py +5 -8
- cognee/tests/{test_deletion.py → test_delete_hard.py} +0 -37
- cognee/tests/test_delete_soft.py +85 -0
- cognee/tests/test_kuzu.py +2 -2
- cognee/tests/test_neo4j.py +2 -2
- cognee/tests/test_permissions.py +3 -3
- cognee/tests/test_relational_db_migration.py +7 -5
- cognee/tests/test_search_db.py +136 -23
- cognee/tests/test_temporal_graph.py +167 -0
- cognee/tests/unit/api/__init__.py +1 -0
- cognee/tests/unit/api/test_conditional_authentication_endpoints.py +246 -0
- cognee/tests/unit/cli/__init__.py +3 -0
- cognee/tests/unit/cli/test_cli_commands.py +483 -0
- cognee/tests/unit/cli/test_cli_edge_cases.py +625 -0
- cognee/tests/unit/cli/test_cli_main.py +173 -0
- cognee/tests/unit/cli/test_cli_runner.py +62 -0
- cognee/tests/unit/cli/test_cli_utils.py +127 -0
- cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +18 -2
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +12 -15
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +10 -15
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +4 -3
- cognee/tests/unit/modules/retrieval/insights_retriever_test.py +4 -2
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +18 -2
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +225 -0
- cognee/tests/unit/modules/users/__init__.py +1 -0
- cognee/tests/unit/modules/users/test_conditional_authentication.py +277 -0
- cognee/tests/unit/processing/utils/utils_test.py +20 -1
- {cognee-0.2.3.dev1.dist-info → cognee-0.3.0.dev0.dist-info}/METADATA +13 -9
- {cognee-0.2.3.dev1.dist-info → cognee-0.3.0.dev0.dist-info}/RECORD +245 -135
- cognee-0.3.0.dev0.dist-info/entry_points.txt +2 -0
- cognee/infrastructure/databases/graph/networkx/adapter.py +0 -1017
- cognee/infrastructure/pipeline/models/Operation.py +0 -60
- cognee/notebooks/github_analysis_step_by_step.ipynb +0 -37
- cognee/tests/tasks/descriptive_metrics/networkx_metrics_test.py +0 -7
- cognee/tests/unit/modules/search/search_methods_test.py +0 -223
- /cognee/{infrastructure/databases/graph/networkx → api/v1/memify}/__init__.py +0 -0
- /cognee/{infrastructure/pipeline/models → tasks/codingagents}/__init__.py +0 -0
- {cognee-0.2.3.dev1.dist-info → cognee-0.3.0.dev0.dist-info}/WHEEL +0 -0
- {cognee-0.2.3.dev1.dist-info → cognee-0.3.0.dev0.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.2.3.dev1.dist-info → cognee-0.3.0.dev0.dist-info}/licenses/NOTICE.md +0 -0

cognee/modules/retrieval/base_graph_retriever.py
@@ -0,0 +1,18 @@
+from typing import List, Optional
+from abc import ABC, abstractmethod
+
+from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
+
+
+class BaseGraphRetriever(ABC):
+    """Base class for all graph based retrievers."""
+
+    @abstractmethod
+    async def get_context(self, query: str) -> List[Edge]:
+        """Retrieves triplets based on the query."""
+        pass
+
+    @abstractmethod
+    async def get_completion(self, query: str, context: Optional[List[Edge]] = None) -> str:
+        """Generates a response using the query and optional context (triplets)."""
+        pass
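
The new BaseGraphRetriever abstract class is the interface the graph-based retrievers below (graph completion, insights) now implement. A minimal sketch of a custom subclass, assuming an installed cognee 0.3.0 environment; the EchoGraphRetriever name and its trivial logic are illustrative, not part of this diff:

import asyncio
from typing import List, Optional

from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.retrieval.base_graph_retriever import BaseGraphRetriever


class EchoGraphRetriever(BaseGraphRetriever):
    """Hypothetical retriever: returns no triplets and echoes the query back."""

    async def get_context(self, query: str) -> List[Edge]:
        # A real implementation would run a graph or vector search here.
        return []

    async def get_completion(self, query: str, context: Optional[List[Edge]] = None) -> str:
        triplets = context if context is not None else await self.get_context(query)
        # A real implementation would resolve the triplets to text and call an LLM.
        return f"No graph context found for: {query}" if not triplets else "..."


asyncio.run(EchoGraphRetriever().get_completion("example query"))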

cognee/modules/retrieval/code_retriever.py
@@ -94,7 +94,15 @@ class CodeRetriever(BaseRetriever):
                 {"id": res.id, "score": res.score, "payload": res.payload}
             )

+        existing_collection = []
         for collection in self.classes_and_functions_collections:
+            if await vector_engine.has_collection(collection):
+                existing_collection.append(collection)
+
+        if not existing_collection:
+            raise RuntimeError("No collection found for code retriever")
+
+        for collection in existing_collection:
             logger.debug(f"Searching {collection} collection with general query")
             search_results_code = await vector_engine.search(
                 collection, query, limit=self.top_k

cognee/modules/retrieval/coding_rules_retriever.py
@@ -0,0 +1,31 @@
+import asyncio
+from functools import reduce
+from typing import List, Optional
+from cognee.shared.logging_utils import get_logger
+from cognee.tasks.codingagents.coding_rule_associations import get_existing_rules
+
+logger = get_logger("CodingRulesRetriever")
+
+
+class CodingRulesRetriever:
+    """Retriever for handling codeing rule based searches."""
+
+    def __init__(self, rules_nodeset_name: Optional[List[str]] = None):
+        if isinstance(rules_nodeset_name, list):
+            if not rules_nodeset_name:
+                # If there is no provided nodeset set to coding_agent_rules
+                rules_nodeset_name = ["coding_agent_rules"]
+
+        self.rules_nodeset_name = rules_nodeset_name
+        """Initialize retriever with search parameters."""
+
+    async def get_existing_rules(self, query_text):
+        if self.rules_nodeset_name:
+            rules_list = await asyncio.gather(
+                *[
+                    get_existing_rules(rules_nodeset_name=nodeset)
+                    for nodeset in self.rules_nodeset_name
+                ]
+            )
+
+            return reduce(lambda x, y: x + y, rules_list, [])
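
A minimal sketch of calling the new coding-rules retriever, assuming a cognee 0.3.0 setup where rules have already been stored under the default coding_agent_rules node set; the query string and main wrapper are illustrative only:

import asyncio

from cognee.modules.retrieval.coding_rules_retriever import CodingRulesRetriever


async def main():
    # An empty list falls back to the default ["coding_agent_rules"] node set.
    retriever = CodingRulesRetriever(rules_nodeset_name=[])
    rules = await retriever.get_existing_rules("how should test files be named?")
    for rule in rules or []:
        print(rule)


asyncio.run(main())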

cognee/modules/retrieval/completion_retriever.py
@@ -23,12 +23,14 @@ class CompletionRetriever(BaseRetriever):
         self,
         user_prompt_path: str = "context_for_question.txt",
         system_prompt_path: str = "answer_simple_question.txt",
+        system_prompt: Optional[str] = None,
         top_k: Optional[int] = 1,
     ):
         """Initialize retriever with optional custom prompt paths."""
         self.user_prompt_path = user_prompt_path
         self.system_prompt_path = system_prompt_path
         self.top_k = top_k if top_k is not None else 1
+        self.system_prompt = system_prompt

     async def get_context(self, query: str) -> str:
         """
@@ -65,7 +67,7 @@ class CompletionRetriever(BaseRetriever):
             logger.error("DocumentChunk_text collection not found")
             raise NoDataError("No data found in the system, please add data first.") from error

-    async def get_completion(self, query: str, context: Optional[Any] = None) ->
+    async def get_completion(self, query: str, context: Optional[Any] = None) -> str:
         """
         Generates an LLM completion using the context.

@@ -88,6 +90,10 @@ class CompletionRetriever(BaseRetriever):
             context = await self.get_context(query)

         completion = await generate_completion(
-            query,
+            query=query,
+            context=context,
+            user_prompt_path=self.user_prompt_path,
+            system_prompt_path=self.system_prompt_path,
+            system_prompt=self.system_prompt,
         )
-        return
+        return completion
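
Several retrievers in this release (RAG completion, graph completion, CoT, context extension, summary) gain a system_prompt argument that is passed through to generate_completion alongside the prompt-file paths. A hedged usage sketch with CompletionRetriever, assuming data has already been added and cognified and an LLM provider is configured; the prompt and query strings are illustrative:

import asyncio

from cognee.modules.retrieval.completion_retriever import CompletionRetriever


async def main():
    retriever = CompletionRetriever(
        top_k=3,
        # Inline system prompt supplied in addition to answer_simple_question.txt.
        system_prompt="Answer in one short sentence and mention which chunk you used.",
    )
    answer = await retriever.get_completion("What does the onboarding document say about VPN access?")
    print(answer)


asyncio.run(main())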

cognee/modules/retrieval/cypher_search_retriever.py
@@ -1,6 +1,5 @@
 from typing import Any, Optional
 from cognee.infrastructure.databases.graph import get_graph_engine
-from cognee.infrastructure.databases.graph.networkx.adapter import NetworkXAdapter
 from cognee.modules.retrieval.base_retriever import BaseRetriever
 from cognee.modules.retrieval.utils.completion import generate_completion
 from cognee.modules.retrieval.exceptions import SearchTypeNotSupported, CypherSearchError
@@ -31,8 +30,7 @@ class CypherSearchRetriever(BaseRetriever):
         """
         Retrieves relevant context using a cypher query.

-        If
-        any error occurs during execution, logs the error and raises CypherSearchError.
+        If any error occurs during execution, logs the error and raises CypherSearchError.

         Parameters:
         -----------
@@ -46,12 +44,6 @@ class CypherSearchRetriever(BaseRetriever):
         """
         try:
             graph_engine = await get_graph_engine()
-
-            if isinstance(graph_engine, NetworkXAdapter):
-                raise SearchTypeNotSupported(
-                    "CYPHER search type not supported for NetworkXAdapter."
-                )
-
             result = await graph_engine.query(query)
         except Exception as e:
             logger.error("Failed to execture cypher search retrieval: %s", str(e))

cognee/modules/retrieval/graph_completion_context_extension_retriever.py
@@ -1,4 +1,5 @@
-from typing import
+from typing import Optional, List, Type
+from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
 from cognee.shared.logging_utils import get_logger
 from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
 from cognee.modules.retrieval.utils.completion import generate_completion
@@ -26,9 +27,11 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
         self,
         user_prompt_path: str = "graph_context_for_question.txt",
         system_prompt_path: str = "answer_simple_question.txt",
+        system_prompt: Optional[str] = None,
         top_k: Optional[int] = 5,
         node_type: Optional[Type] = None,
         node_name: Optional[List[str]] = None,
+        save_interaction: bool = False,
     ):
         super().__init__(
             user_prompt_path=user_prompt_path,
@@ -36,11 +39,16 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
             top_k=top_k,
             node_type=node_type,
             node_name=node_name,
+            save_interaction=save_interaction,
+            system_prompt=system_prompt,
         )

     async def get_completion(
-        self,
-
+        self,
+        query: str,
+        context: Optional[List[Edge]] = None,
+        context_extension_rounds=4,
+    ) -> str:
         """
         Extends the context for a given query by retrieving related triplets and generating new
         completions based on them.
@@ -65,11 +73,12 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
         - List[str]: A list containing the generated answer based on the query and the
           extended context.
         """
-        triplets =
+        triplets = context
+
+        if triplets is None:
+            triplets = await self.get_context(query)

-
-        triplets += await self.get_triplets(query)
-        context = await self.resolve_edges_to_text(triplets)
+        context_text = await self.resolve_edges_to_text(triplets)

         round_idx = 1

@@ -81,14 +90,15 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
             )
             completion = await generate_completion(
                 query=query,
-                context=
+                context=context_text,
                 user_prompt_path=self.user_prompt_path,
                 system_prompt_path=self.system_prompt_path,
+                system_prompt=self.system_prompt,
             )

-            triplets += await self.
+            triplets += await self.get_context(completion)
             triplets = list(set(triplets))
-
+            context_text = await self.resolve_edges_to_text(triplets)

             num_triplets = len(triplets)

@@ -105,11 +115,17 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):

             round_idx += 1

-
+        completion = await generate_completion(
             query=query,
-            context=
+            context=context_text,
             user_prompt_path=self.user_prompt_path,
             system_prompt_path=self.system_prompt_path,
+            system_prompt=self.system_prompt,
         )

-
+        if self.save_interaction and context_text and triplets and completion:
+            await self.save_qa(
+                question=query, answer=completion, context=context_text, triplets=triplets
+            )
+
+        return completion
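
A minimal usage sketch for the extension retriever, assuming an already-cognified dataset and a configured LLM provider; the query text and round count are illustrative, not part of this diff:

import asyncio

from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
    GraphCompletionContextExtensionRetriever,
)


async def main():
    retriever = GraphCompletionContextExtensionRetriever(top_k=5, save_interaction=False)
    # Each extension round re-queries the graph with the previous answer and
    # re-resolves the accumulated triplets into fresh context text.
    answer = await retriever.get_completion(
        "How are pipeline runs qualified before execution?",
        context_extension_rounds=2,
    )
    print(answer)


asyncio.run(main())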

cognee/modules/retrieval/graph_completion_cot_retriever.py
@@ -1,4 +1,5 @@
-from typing import
+from typing import Optional, List, Type
+from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
 from cognee.shared.logging_utils import get_logger

 from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
@@ -32,16 +33,20 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
         validation_system_prompt_path: str = "cot_validation_system_prompt.txt",
         followup_system_prompt_path: str = "cot_followup_system_prompt.txt",
         followup_user_prompt_path: str = "cot_followup_user_prompt.txt",
+        system_prompt: Optional[str] = None,
         top_k: Optional[int] = 5,
         node_type: Optional[Type] = None,
         node_name: Optional[List[str]] = None,
+        save_interaction: bool = False,
     ):
         super().__init__(
             user_prompt_path=user_prompt_path,
             system_prompt_path=system_prompt_path,
+            system_prompt=system_prompt,
             top_k=top_k,
             node_type=node_type,
             node_name=node_name,
+            save_interaction=save_interaction,
         )
         self.validation_system_prompt_path = validation_system_prompt_path
         self.validation_user_prompt_path = validation_user_prompt_path
@@ -49,8 +54,11 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
         self.followup_user_prompt_path = followup_user_prompt_path

     async def get_completion(
-        self,
-
+        self,
+        query: str,
+        context: Optional[List[Edge]] = None,
+        max_iter=4,
+    ) -> str:
         """
         Generate completion responses based on a user query and contextual information.

@@ -75,25 +83,29 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
         """
         followup_question = ""
         triplets = []
-
+        completion = ""

         for round_idx in range(max_iter + 1):
             if round_idx == 0:
                 if context is None:
-
+                    triplets = await self.get_context(query)
+                    context_text = await self.resolve_edges_to_text(triplets)
+                else:
+                    context_text = await self.resolve_edges_to_text(context)
             else:
-                triplets += await self.
-
+                triplets += await self.get_context(followup_question)
+                context_text = await self.resolve_edges_to_text(list(set(triplets)))

-
+            completion = await generate_completion(
                 query=query,
-                context=
+                context=context_text,
                 user_prompt_path=self.user_prompt_path,
                 system_prompt_path=self.system_prompt_path,
+                system_prompt=self.system_prompt,
             )
-            logger.info(f"Chain-of-thought: round {round_idx} - answer: {
+            logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
             if round_idx < max_iter:
-                valid_args = {"query": query, "answer":
+                valid_args = {"query": query, "answer": completion, "context": context_text}
                 valid_user_prompt = LLMGateway.render_prompt(
                     filename=self.validation_user_prompt_path, context=valid_args
                 )
@@ -106,7 +118,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
                     system_prompt=valid_system_prompt,
                     response_model=str,
                 )
-                followup_args = {"query": query, "answer":
+                followup_args = {"query": query, "answer": completion, "reasoning": reasoning}
                 followup_prompt = LLMGateway.render_prompt(
                     filename=self.followup_user_prompt_path, context=followup_args
                 )
@@ -121,4 +133,9 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
                 f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
             )

-
+        if self.save_interaction and context and triplets and completion:
+            await self.save_qa(
+                question=query, answer=completion, context=context_text, triplets=triplets
+            )
+
+        return completion
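
A minimal sketch of driving the chain-of-thought retriever, assuming data has been cognified and an LLM provider is configured; the question and iteration count are illustrative:

import asyncio

from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever


async def main():
    retriever = GraphCompletionCotRetriever(top_k=5)
    # Each round validates the previous answer, derives a follow-up question,
    # and pulls additional triplets for it before answering again.
    answer = await retriever.get_completion("Which modules handle dataset sync?", max_iter=2)
    print(answer)


asyncio.run(main())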

cognee/modules/retrieval/graph_completion_retriever.py
@@ -1,19 +1,25 @@
 from typing import Any, Optional, Type, List
-from
-import string
+from uuid import NAMESPACE_OID, uuid5

 from cognee.infrastructure.engine import DataPoint
+from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
+from cognee.modules.users.methods import get_default_user
+from cognee.tasks.storage import add_data_points
+from cognee.modules.graph.utils import resolve_edges_to_text
 from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
-from cognee.modules.retrieval.
+from cognee.modules.retrieval.base_graph_retriever import BaseGraphRetriever
 from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
 from cognee.modules.retrieval.utils.completion import generate_completion
-from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
 from cognee.shared.logging_utils import get_logger
+from cognee.modules.retrieval.utils.extract_uuid_from_node import extract_uuid_from_node
+from cognee.modules.retrieval.utils.models import CogneeUserInteraction
+from cognee.modules.engine.models.node_set import NodeSet
+from cognee.infrastructure.databases.graph import get_graph_engine

 logger = get_logger("GraphCompletionRetriever")


-class GraphCompletionRetriever(
+class GraphCompletionRetriever(BaseGraphRetriever):
     """
     Retriever for handling graph-based completion searches.

@@ -30,33 +36,21 @@ class GraphCompletionRetriever(BaseRetriever):
         self,
         user_prompt_path: str = "graph_context_for_question.txt",
         system_prompt_path: str = "answer_simple_question.txt",
+        system_prompt: Optional[str] = None,
         top_k: Optional[int] = 5,
         node_type: Optional[Type] = None,
         node_name: Optional[List[str]] = None,
+        save_interaction: bool = False,
     ):
         """Initialize retriever with prompt paths and search parameters."""
+        self.save_interaction = save_interaction
         self.user_prompt_path = user_prompt_path
         self.system_prompt_path = system_prompt_path
+        self.system_prompt = system_prompt
         self.top_k = top_k if top_k is not None else 5
         self.node_type = node_type
         self.node_name = node_name

-    def _get_nodes(self, retrieved_edges: list) -> dict:
-        """Creates a dictionary of nodes with their names and content."""
-        nodes = {}
-        for edge in retrieved_edges:
-            for node in (edge.node1, edge.node2):
-                if node.id not in nodes:
-                    text = node.attributes.get("text")
-                    if text:
-                        name = self._get_title(text)
-                        content = text
-                    else:
-                        name = node.attributes.get("name", "Unnamed Node")
-                        content = node.attributes.get("description", name)
-                    nodes[node.id] = {"node": node, "name": name, "content": content}
-        return nodes
-
     async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
         """
         Converts retrieved graph edges into a human-readable string format.
@@ -71,18 +65,9 @@ class GraphCompletionRetriever(BaseRetriever):

         - str: A formatted string representation of the nodes and their connections.
         """
-
-        node_section = "\n".join(
-            f"Node: {info['name']}\n__node_content_start__\n{info['content']}\n__node_content_end__\n"
-            for info in nodes.values()
-        )
-        connection_section = "\n".join(
-            f"{nodes[edge.node1.id]['name']} --[{edge.attributes['relationship_type']}]--> {nodes[edge.node2.id]['name']}"
-            for edge in retrieved_edges
-        )
-        return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}"
+        return await resolve_edges_to_text(retrieved_edges)

-    async def get_triplets(self, query: str) ->
+    async def get_triplets(self, query: str) -> List[Edge]:
         """
         Retrieves relevant graph triplets based on a query string.

@@ -97,7 +82,7 @@ class GraphCompletionRetriever(BaseRetriever):
         - list: A list of found triplets that match the query.
         """
         subclasses = get_all_subclasses(DataPoint)
-        vector_index_collections = []
+        vector_index_collections: List[str] = []

         for subclass in subclasses:
             if "metadata" in subclass.model_fields:
@@ -108,8 +93,11 @@ class GraphCompletionRetriever(BaseRetriever):
                 for field_name in index_fields:
                     vector_index_collections.append(f"{subclass.__name__}_{field_name}")

+        user = await get_default_user()
+
         found_triplets = await brute_force_triplet_search(
             query,
+            user=user,
             top_k=self.top_k,
             collections=vector_index_collections or None,
             node_type=self.node_type,
@@ -118,7 +106,7 @@ class GraphCompletionRetriever(BaseRetriever):

         return found_triplets

-    async def get_context(self, query: str) ->
+    async def get_context(self, query: str) -> List[Edge]:
         """
         Retrieves and resolves graph triplets into context based on a query.

@@ -137,11 +125,17 @@ class GraphCompletionRetriever(BaseRetriever):

         if len(triplets) == 0:
             logger.warning("Empty context was provided to the completion")
-            return
+            return []
+
+        # context = await self.resolve_edges_to_text(triplets)

-        return
+        return triplets

-    async def get_completion(
+    async def get_completion(
+        self,
+        query: str,
+        context: Optional[List[Edge]] = None,
+    ) -> Any:
         """
         Generates a completion using graph connections context based on a query.

@@ -157,33 +151,90 @@ class GraphCompletionRetriever(BaseRetriever):

         - Any: A generated completion based on the query and context provided.
         """
-
-
+        triplets = context
+
+        if triplets is None:
+            triplets = await self.get_context(query)
+
+        context_text = await resolve_edges_to_text(triplets)

         completion = await generate_completion(
             query=query,
-            context=
+            context=context_text,
             user_prompt_path=self.user_prompt_path,
             system_prompt_path=self.system_prompt_path,
+            system_prompt=self.system_prompt,
         )
-        return [completion]

-
-
-
-
+        if self.save_interaction and context and triplets and completion:
+            await self.save_qa(
+                question=query, answer=completion, context=context_text, triplets=triplets
+            )

-
+        return completion

-
-
-
-
+    async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
+        """
+        Saves a question and answer pair for later analysis or storage.
+        Parameters:
+        -----------
+        - question (str): The question text.
+        - answer (str): The answer text.
+        - context (str): The context text.
+        - triplets (List): A list of triples retrieved from the graph.
+        """
+        nodeset_name = "Interactions"
+        interactions_node_set = NodeSet(
+            id=uuid5(NAMESPACE_OID, name=nodeset_name), name=nodeset_name
+        )
+        source_id = uuid5(NAMESPACE_OID, name=(question + answer + context))

-
+        cognee_user_interaction = CogneeUserInteraction(
+            id=source_id,
+            question=question,
+            answer=answer,
+            context=context,
+            belongs_to_set=interactions_node_set,
+        )

-
-
-
-
-
+        await add_data_points(data_points=[cognee_user_interaction], update_edge_collection=False)
+
+        relationships = []
+        relationship_name = "used_graph_element_to_answer"
+        for triplet in triplets:
+            target_id_1 = extract_uuid_from_node(triplet.node1)
+            target_id_2 = extract_uuid_from_node(triplet.node2)
+            if target_id_1 and target_id_2:
+                relationships.append(
+                    (
+                        source_id,
+                        target_id_1,
+                        relationship_name,
+                        {
+                            "relationship_name": relationship_name,
+                            "source_node_id": source_id,
+                            "target_node_id": target_id_1,
+                            "ontology_valid": False,
+                            "feedback_weight": 0,
+                        },
+                    )
+                )
+
+                relationships.append(
+                    (
+                        source_id,
+                        target_id_2,
+                        relationship_name,
+                        {
+                            "relationship_name": relationship_name,
+                            "source_node_id": source_id,
+                            "target_node_id": target_id_2,
+                            "ontology_valid": False,
+                            "feedback_weight": 0,
+                        },
+                    )
+                )
+
+        if len(relationships) > 0:
+            graph_engine = await get_graph_engine()
+            await graph_engine.add_edges(relationships)
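
Graph completion can now persist question/answer interactions back into the graph via save_qa when save_interaction=True; note that the guard in get_completion only fires when a context was passed in explicitly. A hedged sketch under those assumptions, with an illustrative query and an already-built knowledge graph:

import asyncio

from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever


async def main():
    retriever = GraphCompletionRetriever(top_k=5, save_interaction=True)
    question = "What does the sync module store per operation?"
    # Fetch triplets first so they can be passed as explicit context,
    # which is what allows save_qa to record the interaction.
    triplets = await retriever.get_context(question)
    answer = await retriever.get_completion(question, context=triplets)
    print(answer)


asyncio.run(main())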

cognee/modules/retrieval/graph_summary_completion_retriever.py
@@ -21,9 +21,11 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
         user_prompt_path: str = "graph_context_for_question.txt",
         system_prompt_path: str = "answer_simple_question.txt",
         summarize_prompt_path: str = "summarize_search_results.txt",
+        system_prompt: Optional[str] = None,
         top_k: Optional[int] = 5,
         node_type: Optional[Type] = None,
         node_name: Optional[List[str]] = None,
+        save_interaction: bool = False,
     ):
         """Initialize retriever with default prompt paths and search parameters."""
         super().__init__(
@@ -32,6 +34,8 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
             top_k=top_k,
             node_type=node_type,
             node_name=node_name,
+            save_interaction=save_interaction,
+            system_prompt=system_prompt,
         )
         self.summarize_prompt_path = summarize_prompt_path

@@ -55,4 +59,4 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
         - str: A summary string representing the content of the retrieved edges.
         """
         direct_text = await super().resolve_edges_to_text(retrieved_edges)
-        return await summarize_text(direct_text, self.summarize_prompt_path)
+        return await summarize_text(direct_text, self.summarize_prompt_path, self.system_prompt)

cognee/modules/retrieval/insights_retriever.py
@@ -1,17 +1,18 @@
 import asyncio
 from typing import Any, Optional

+from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
+from cognee.modules.retrieval.base_graph_retriever import BaseGraphRetriever
 from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.infrastructure.databases.vector import get_vector_engine
-from cognee.modules.retrieval.base_retriever import BaseRetriever
 from cognee.modules.retrieval.exceptions.exceptions import NoDataError
 from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError

 logger = get_logger("InsightsRetriever")


-class InsightsRetriever(
+class InsightsRetriever(BaseGraphRetriever):
     """
     Retriever for handling graph connection-based insights.

@@ -95,7 +96,17 @@ class InsightsRetriever(BaseRetriever):
             unique_node_connections_map[unique_id] = True
             unique_node_connections.append(node_connection)

-        return
+        return [
+            Edge(
+                node1=Node(node_id=connection[0]["id"], attributes=connection[0]),
+                node2=Node(node_id=connection[2]["id"], attributes=connection[2]),
+                attributes={
+                    **connection[1],
+                    "relationship_type": connection[1]["relationship_name"],
+                },
+            )
+            for connection in unique_node_connections
+        ]

     async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
         """