cognee 0.5.1.dev0__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/__init__.py +2 -0
- cognee/alembic/README +1 -0
- cognee/alembic/env.py +107 -0
- cognee/alembic/script.py.mako +26 -0
- cognee/alembic/versions/1a58b986e6e1_enable_delete_for_old_tutorial_notebooks.py +52 -0
- cognee/alembic/versions/1d0bb7fede17_add_pipeline_run_status.py +33 -0
- cognee/alembic/versions/1daae0df1866_incremental_loading.py +48 -0
- cognee/alembic/versions/211ab850ef3d_add_sync_operations_table.py +118 -0
- cognee/alembic/versions/45957f0a9849_add_notebook_table.py +46 -0
- cognee/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py +333 -0
- cognee/alembic/versions/482cd6517ce4_add_default_user.py +30 -0
- cognee/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py +98 -0
- cognee/alembic/versions/8057ae7329c2_initial_migration.py +25 -0
- cognee/alembic/versions/9e7a3cb85175_loader_separation.py +104 -0
- cognee/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py +38 -0
- cognee/alembic/versions/ab7e313804ae_permission_system_rework.py +236 -0
- cognee/alembic/versions/b9274c27a25a_kuzu_11_migration.py +75 -0
- cognee/alembic/versions/c946955da633_multi_tenant_support.py +137 -0
- cognee/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py +51 -0
- cognee/alembic/versions/e4ebee1091e7_expand_data_model_info.py +140 -0
- cognee/alembic.ini +117 -0
- cognee/api/v1/add/routers/get_add_router.py +2 -0
- cognee/api/v1/cognify/cognify.py +11 -6
- cognee/api/v1/cognify/routers/get_cognify_router.py +8 -0
- cognee/api/v1/config/config.py +60 -0
- cognee/api/v1/datasets/routers/get_datasets_router.py +45 -3
- cognee/api/v1/memify/routers/get_memify_router.py +2 -0
- cognee/api/v1/search/routers/get_search_router.py +21 -6
- cognee/api/v1/search/search.py +25 -5
- cognee/api/v1/sync/routers/get_sync_router.py +3 -3
- cognee/cli/commands/add_command.py +1 -1
- cognee/cli/commands/cognify_command.py +6 -0
- cognee/cli/commands/config_command.py +1 -1
- cognee/context_global_variables.py +5 -1
- cognee/eval_framework/answer_generation/answer_generation_executor.py +7 -8
- cognee/infrastructure/databases/cache/cache_db_interface.py +38 -1
- cognee/infrastructure/databases/cache/config.py +6 -0
- cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py +21 -0
- cognee/infrastructure/databases/cache/get_cache_engine.py +9 -3
- cognee/infrastructure/databases/cache/redis/RedisAdapter.py +60 -1
- cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +7 -0
- cognee/infrastructure/databases/graph/get_graph_engine.py +29 -1
- cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +62 -27
- cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +17 -4
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +2 -1
- cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +2 -0
- cognee/infrastructure/databases/vector/config.py +6 -0
- cognee/infrastructure/databases/vector/create_vector_engine.py +69 -22
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +64 -9
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +13 -2
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +16 -3
- cognee/infrastructure/databases/vector/models/ScoredResult.py +3 -3
- cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +16 -3
- cognee/infrastructure/databases/vector/pgvector/PGVectorDatasetDatabaseHandler.py +86 -0
- cognee/infrastructure/databases/vector/pgvector/create_db_and_tables.py +81 -2
- cognee/infrastructure/databases/vector/vector_db_interface.py +8 -0
- cognee/infrastructure/files/utils/get_data_file_path.py +33 -27
- cognee/infrastructure/llm/prompts/extract_query_time.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_event_entity_prompt.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_event_graph_prompt.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt.txt +2 -2
- cognee/infrastructure/llm/prompts/generate_graph_prompt_guided.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt_oneshot.txt +2 -2
- cognee/infrastructure/llm/prompts/generate_graph_prompt_simple.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt_strict.txt +1 -1
- cognee/infrastructure/llm/prompts/search_type_selector_prompt.txt +6 -6
- cognee/infrastructure/llm/prompts/test.txt +1 -1
- cognee/infrastructure/llm/prompts/translate_content.txt +19 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +24 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llama_cpp/adapter.py +191 -0
- cognee/modules/chunking/models/DocumentChunk.py +0 -1
- cognee/modules/cognify/config.py +2 -0
- cognee/modules/data/models/Data.py +1 -0
- cognee/modules/engine/models/Entity.py +0 -1
- cognee/modules/engine/operations/setup.py +6 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +150 -37
- cognee/modules/graph/cognee_graph/CogneeGraphElements.py +48 -2
- cognee/modules/graph/utils/__init__.py +1 -0
- cognee/modules/graph/utils/get_entity_nodes_from_triplets.py +12 -0
- cognee/modules/notebooks/methods/__init__.py +1 -0
- cognee/modules/notebooks/methods/create_notebook.py +0 -34
- cognee/modules/notebooks/methods/create_tutorial_notebooks.py +191 -0
- cognee/modules/notebooks/methods/get_notebooks.py +12 -8
- cognee/modules/notebooks/tutorials/cognee-basics/cell-1.md +3 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-2.md +10 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-3.md +7 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-4.py +28 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-5.py +3 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-6.py +9 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-7.py +17 -0
- cognee/modules/notebooks/tutorials/cognee-basics/config.json +4 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-1.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-10.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-11.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-12.py +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-13.md +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-14.py +6 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-15.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-16.py +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-2.md +9 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-3.md +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-4.md +9 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-5.md +5 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-6.py +13 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-7.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-8.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-9.py +31 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/config.json +4 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/copilot_conversations.json +107 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/guido_contributions.json +976 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/my_developer_rules.md +79 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/pep_style_guide.md +74 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/zen_principles.md +74 -0
- cognee/modules/retrieval/EntityCompletionRetriever.py +51 -38
- cognee/modules/retrieval/__init__.py +0 -1
- cognee/modules/retrieval/base_retriever.py +66 -10
- cognee/modules/retrieval/chunks_retriever.py +57 -49
- cognee/modules/retrieval/coding_rules_retriever.py +12 -5
- cognee/modules/retrieval/completion_retriever.py +29 -28
- cognee/modules/retrieval/cypher_search_retriever.py +25 -20
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +42 -46
- cognee/modules/retrieval/graph_completion_cot_retriever.py +68 -51
- cognee/modules/retrieval/graph_completion_retriever.py +78 -63
- cognee/modules/retrieval/graph_summary_completion_retriever.py +2 -0
- cognee/modules/retrieval/lexical_retriever.py +34 -12
- cognee/modules/retrieval/natural_language_retriever.py +18 -15
- cognee/modules/retrieval/summaries_retriever.py +51 -34
- cognee/modules/retrieval/temporal_retriever.py +59 -49
- cognee/modules/retrieval/triplet_retriever.py +31 -32
- cognee/modules/retrieval/utils/access_tracking.py +88 -0
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +99 -85
- cognee/modules/retrieval/utils/node_edge_vector_search.py +174 -0
- cognee/modules/search/methods/__init__.py +1 -0
- cognee/modules/search/methods/get_retriever_output.py +53 -0
- cognee/modules/search/methods/get_search_type_retriever_instance.py +252 -0
- cognee/modules/search/methods/search.py +90 -215
- cognee/modules/search/models/SearchResultPayload.py +67 -0
- cognee/modules/search/types/SearchResult.py +1 -8
- cognee/modules/search/types/SearchType.py +1 -2
- cognee/modules/search/types/__init__.py +1 -1
- cognee/modules/search/utils/__init__.py +1 -2
- cognee/modules/search/utils/transform_insights_to_graph.py +2 -2
- cognee/modules/search/utils/{transform_context_to_graph.py → transform_triplets_to_graph.py} +2 -2
- cognee/modules/users/authentication/default/default_transport.py +11 -1
- cognee/modules/users/authentication/get_api_auth_backend.py +2 -1
- cognee/modules/users/authentication/get_client_auth_backend.py +2 -1
- cognee/modules/users/methods/create_user.py +0 -9
- cognee/modules/users/permissions/methods/has_user_management_permission.py +29 -0
- cognee/modules/visualization/cognee_network_visualization.py +1 -1
- cognee/run_migrations.py +48 -0
- cognee/shared/exceptions/__init__.py +1 -3
- cognee/shared/exceptions/exceptions.py +11 -1
- cognee/shared/usage_logger.py +332 -0
- cognee/shared/utils.py +12 -5
- cognee/tasks/chunks/__init__.py +9 -0
- cognee/tasks/cleanup/cleanup_unused_data.py +172 -0
- cognee/tasks/graph/__init__.py +7 -0
- cognee/tasks/memify/__init__.py +8 -0
- cognee/tasks/memify/extract_usage_frequency.py +613 -0
- cognee/tasks/summarization/models.py +0 -2
- cognee/tasks/temporal_graph/__init__.py +0 -1
- cognee/tasks/translation/__init__.py +96 -0
- cognee/tasks/translation/config.py +110 -0
- cognee/tasks/translation/detect_language.py +190 -0
- cognee/tasks/translation/exceptions.py +62 -0
- cognee/tasks/translation/models.py +72 -0
- cognee/tasks/translation/providers/__init__.py +44 -0
- cognee/tasks/translation/providers/azure_provider.py +192 -0
- cognee/tasks/translation/providers/base.py +85 -0
- cognee/tasks/translation/providers/google_provider.py +158 -0
- cognee/tasks/translation/providers/llm_provider.py +143 -0
- cognee/tasks/translation/translate_content.py +282 -0
- cognee/tasks/web_scraper/default_url_crawler.py +6 -2
- cognee/tests/cli_tests/cli_unit_tests/test_cli_commands.py +1 -0
- cognee/tests/cli_tests/cli_unit_tests/test_cli_edge_cases.py +3 -0
- cognee/tests/integration/retrieval/test_brute_force_triplet_search_with_cognify.py +62 -0
- cognee/tests/integration/retrieval/test_chunks_retriever.py +115 -16
- cognee/tests/integration/retrieval/test_graph_completion_retriever.py +13 -5
- cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +22 -20
- cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +23 -24
- cognee/tests/integration/retrieval/test_rag_completion_retriever.py +70 -5
- cognee/tests/integration/retrieval/test_structured_output.py +62 -18
- cognee/tests/integration/retrieval/test_summaries_retriever.py +20 -9
- cognee/tests/integration/retrieval/test_temporal_retriever.py +38 -8
- cognee/tests/integration/retrieval/test_triplet_retriever.py +13 -4
- cognee/tests/integration/shared/test_usage_logger_integration.py +255 -0
- cognee/tests/tasks/translation/README.md +147 -0
- cognee/tests/tasks/translation/__init__.py +1 -0
- cognee/tests/tasks/translation/config_test.py +93 -0
- cognee/tests/tasks/translation/detect_language_test.py +118 -0
- cognee/tests/tasks/translation/providers_test.py +151 -0
- cognee/tests/tasks/translation/translate_content_test.py +213 -0
- cognee/tests/test_chromadb.py +1 -1
- cognee/tests/test_cleanup_unused_data.py +165 -0
- cognee/tests/test_delete_by_id.py +6 -6
- cognee/tests/test_extract_usage_frequency.py +308 -0
- cognee/tests/test_kuzu.py +17 -7
- cognee/tests/test_lancedb.py +3 -1
- cognee/tests/test_library.py +1 -1
- cognee/tests/test_neo4j.py +17 -7
- cognee/tests/test_neptune_analytics_vector.py +3 -1
- cognee/tests/test_permissions.py +172 -187
- cognee/tests/test_pgvector.py +3 -1
- cognee/tests/test_relational_db_migration.py +15 -1
- cognee/tests/test_remote_kuzu.py +3 -1
- cognee/tests/test_s3_file_storage.py +1 -1
- cognee/tests/test_search_db.py +97 -110
- cognee/tests/test_usage_logger_e2e.py +268 -0
- cognee/tests/unit/api/test_get_raw_data_endpoint.py +206 -0
- cognee/tests/unit/eval_framework/answer_generation_test.py +4 -3
- cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py +2 -0
- cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +42 -2
- cognee/tests/unit/modules/graph/cognee_graph_test.py +329 -31
- cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +31 -59
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +70 -33
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +72 -52
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +27 -33
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +28 -15
- cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +37 -42
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +48 -64
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +263 -24
- cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py +273 -0
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +30 -16
- cognee/tests/unit/modules/search/test_get_search_type_retriever_instance.py +125 -0
- cognee/tests/unit/modules/search/test_search.py +176 -0
- cognee/tests/unit/modules/search/test_search_prepare_search_result_contract.py +190 -0
- cognee/tests/unit/modules/users/test_tutorial_notebook_creation.py +511 -297
- cognee/tests/unit/shared/test_usage_logger.py +241 -0
- cognee/tests/unit/users/permissions/test_has_user_management_permission.py +46 -0
- {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dist-info}/METADATA +22 -17
- {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dist-info}/RECORD +235 -147
- cognee/api/.env.example +0 -5
- cognee/modules/retrieval/base_graph_retriever.py +0 -24
- cognee/modules/search/methods/get_search_type_tools.py +0 -223
- cognee/modules/search/methods/no_access_control_search.py +0 -62
- cognee/modules/search/utils/prepare_search_result.py +0 -63
- cognee/tests/test_feedback_enrichment.py +0 -174
- {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dist-info}/WHEEL +0 -0
- {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dist-info}/entry_points.txt +0 -0
- {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dist-info}/licenses/NOTICE.md +0 -0
|
@@ -8,6 +8,7 @@ from cognee.modules.retrieval.utils.session_cache import (
|
|
|
8
8
|
save_conversation_history,
|
|
9
9
|
get_conversation_history,
|
|
10
10
|
)
|
|
11
|
+
from cognee.modules.retrieval.utils.access_tracking import update_node_access_timestamps
|
|
11
12
|
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
|
12
13
|
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
|
13
14
|
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
|
@@ -20,10 +21,6 @@ logger = get_logger("CompletionRetriever")
|
|
|
20
21
|
class CompletionRetriever(BaseRetriever):
|
|
21
22
|
"""
|
|
22
23
|
Retriever for handling LLM-based completion searches.
|
|
23
|
-
|
|
24
|
-
Public methods:
|
|
25
|
-
- get_context(query: str) -> str
|
|
26
|
-
- get_completion(query: str, context: Optional[Any] = None) -> Any
|
|
27
24
|
"""
|
|
28
25
|
|
|
29
26
|
def __init__(
|
|
@@ -32,14 +29,31 @@ class CompletionRetriever(BaseRetriever):
|
|
|
32
29
|
system_prompt_path: str = "answer_simple_question.txt",
|
|
33
30
|
system_prompt: Optional[str] = None,
|
|
34
31
|
top_k: Optional[int] = 1,
|
|
32
|
+
session_id: Optional[str] = None,
|
|
33
|
+
response_model: Type = str,
|
|
35
34
|
):
|
|
36
35
|
"""Initialize retriever with optional custom prompt paths."""
|
|
37
36
|
self.user_prompt_path = user_prompt_path
|
|
38
37
|
self.system_prompt_path = system_prompt_path
|
|
39
38
|
self.top_k = top_k if top_k is not None else 1
|
|
40
39
|
self.system_prompt = system_prompt
|
|
40
|
+
self.session_id = session_id
|
|
41
|
+
self.response_model = response_model
|
|
42
|
+
|
|
43
|
+
async def get_retrieved_objects(self, query: str) -> Any:
|
|
44
|
+
vector_engine = get_vector_engine()
|
|
45
|
+
|
|
46
|
+
try:
|
|
47
|
+
found_chunks = await vector_engine.search(
|
|
48
|
+
"DocumentChunk_text", query, limit=self.top_k, include_payload=True
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
return found_chunks
|
|
52
|
+
except CollectionNotFoundError as error:
|
|
53
|
+
logger.error("DocumentChunk_text collection not found")
|
|
54
|
+
raise NoDataError("No data found in the system, please add data first.") from error
|
|
41
55
|
|
|
42
|
-
async def
|
|
56
|
+
async def get_context_from_objects(self, query: str, retrieved_objects: Any) -> str:
|
|
43
57
|
"""
|
|
44
58
|
Retrieves relevant document chunks as context.
|
|
45
59
|
|
|
@@ -58,28 +72,18 @@ class CompletionRetriever(BaseRetriever):
|
|
|
58
72
|
- str: A string containing the combined text of the retrieved document chunks, or an
|
|
59
73
|
empty string if none are found.
|
|
60
74
|
"""
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
|
|
65
|
-
|
|
66
|
-
if len(found_chunks) == 0:
|
|
67
|
-
return ""
|
|
68
|
-
|
|
69
|
-
# Combine all chunks text returned from vector search (number of chunks is determined by top_k
|
|
70
|
-
chunks_payload = [found_chunk.payload["text"] for found_chunk in found_chunks]
|
|
75
|
+
if retrieved_objects:
|
|
76
|
+
# Combine all chunks text returned from vector search (number of chunks is determined by top_k)
|
|
77
|
+
chunks_payload = [found_chunk.payload["text"] for found_chunk in retrieved_objects]
|
|
71
78
|
combined_context = "\n".join(chunks_payload)
|
|
72
79
|
return combined_context
|
|
73
|
-
|
|
74
|
-
logger.error("DocumentChunk_text collection not found")
|
|
75
|
-
raise NoDataError("No data found in the system, please add data first.") from error
|
|
80
|
+
return ""
|
|
76
81
|
|
|
77
|
-
async def
|
|
82
|
+
async def get_completion_from_context(
|
|
78
83
|
self,
|
|
79
84
|
query: str,
|
|
85
|
+
retrieved_objects: Any,
|
|
80
86
|
context: Optional[Any] = None,
|
|
81
|
-
session_id: Optional[str] = None,
|
|
82
|
-
response_model: Type = str,
|
|
83
87
|
) -> List[Any]:
|
|
84
88
|
"""
|
|
85
89
|
Generates an LLM completion using the context.
|
|
@@ -102,9 +106,6 @@ class CompletionRetriever(BaseRetriever):
|
|
|
102
106
|
|
|
103
107
|
- Any: The generated completion based on the provided query and context.
|
|
104
108
|
"""
|
|
105
|
-
if context is None:
|
|
106
|
-
context = await self.get_context(query)
|
|
107
|
-
|
|
108
109
|
# Check if we need to generate context summary for caching
|
|
109
110
|
cache_config = CacheConfig()
|
|
110
111
|
user = session_user.get()
|
|
@@ -112,7 +113,7 @@ class CompletionRetriever(BaseRetriever):
|
|
|
112
113
|
session_save = user_id and cache_config.caching
|
|
113
114
|
|
|
114
115
|
if session_save:
|
|
115
|
-
conversation_history = await get_conversation_history(session_id=session_id)
|
|
116
|
+
conversation_history = await get_conversation_history(session_id=self.session_id)
|
|
116
117
|
|
|
117
118
|
context_summary, completion = await asyncio.gather(
|
|
118
119
|
summarize_text(context),
|
|
@@ -123,7 +124,7 @@ class CompletionRetriever(BaseRetriever):
|
|
|
123
124
|
system_prompt_path=self.system_prompt_path,
|
|
124
125
|
system_prompt=self.system_prompt,
|
|
125
126
|
conversation_history=conversation_history,
|
|
126
|
-
response_model=response_model,
|
|
127
|
+
response_model=self.response_model,
|
|
127
128
|
),
|
|
128
129
|
)
|
|
129
130
|
else:
|
|
@@ -133,7 +134,7 @@ class CompletionRetriever(BaseRetriever):
|
|
|
133
134
|
user_prompt_path=self.user_prompt_path,
|
|
134
135
|
system_prompt_path=self.system_prompt_path,
|
|
135
136
|
system_prompt=self.system_prompt,
|
|
136
|
-
response_model=response_model,
|
|
137
|
+
response_model=self.response_model,
|
|
137
138
|
)
|
|
138
139
|
|
|
139
140
|
if session_save:
|
|
@@ -141,7 +142,7 @@ class CompletionRetriever(BaseRetriever):
|
|
|
141
142
|
query=query,
|
|
142
143
|
context_summary=context_summary,
|
|
143
144
|
answer=completion,
|
|
144
|
-
session_id=session_id,
|
|
145
|
+
session_id=self.session_id,
|
|
145
146
|
)
|
|
146
147
|
|
|
147
148
|
return [completion]
|
|
@@ -23,12 +23,29 @@ class CypherSearchRetriever(BaseRetriever):
|
|
|
23
23
|
self,
|
|
24
24
|
user_prompt_path: str = "context_for_question.txt",
|
|
25
25
|
system_prompt_path: str = "answer_simple_question.txt",
|
|
26
|
+
session_id: Optional[str] = None,
|
|
26
27
|
):
|
|
27
28
|
"""Initialize retriever with optional custom prompt paths."""
|
|
28
29
|
self.user_prompt_path = user_prompt_path
|
|
29
30
|
self.system_prompt_path = system_prompt_path
|
|
31
|
+
self.session_id = session_id
|
|
30
32
|
|
|
31
|
-
async def
|
|
33
|
+
async def get_retrieved_objects(self, query: str) -> Any:
|
|
34
|
+
try:
|
|
35
|
+
graph_engine = await get_graph_engine()
|
|
36
|
+
is_empty = await graph_engine.is_empty()
|
|
37
|
+
|
|
38
|
+
if is_empty:
|
|
39
|
+
logger.warning("Search attempt on an empty knowledge graph")
|
|
40
|
+
return []
|
|
41
|
+
|
|
42
|
+
result = await graph_engine.query(query)
|
|
43
|
+
except Exception as e:
|
|
44
|
+
logger.error("Failed to execture cypher search retrieval: %s", str(e))
|
|
45
|
+
raise CypherSearchError() from e
|
|
46
|
+
return result
|
|
47
|
+
|
|
48
|
+
async def get_context_from_objects(self, query: str, retrieved_objects: Any) -> Any:
|
|
32
49
|
"""
|
|
33
50
|
Retrieves relevant context using a cypher query.
|
|
34
51
|
|
|
@@ -44,22 +61,12 @@ class CypherSearchRetriever(BaseRetriever):
|
|
|
44
61
|
|
|
45
62
|
- Any: The result of the cypher query execution.
|
|
46
63
|
"""
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
if is_empty:
|
|
52
|
-
logger.warning("Search attempt on an empty knowledge graph")
|
|
53
|
-
return []
|
|
54
|
-
|
|
55
|
-
result = jsonable_encoder(await graph_engine.query(query))
|
|
56
|
-
except Exception as e:
|
|
57
|
-
logger.error("Failed to execture cypher search retrieval: %s", str(e))
|
|
58
|
-
raise CypherSearchError() from e
|
|
59
|
-
return result
|
|
64
|
+
# TODO: Do we want to return a string response here?
|
|
65
|
+
# return jsonable_encoder(retrieved_objects)
|
|
66
|
+
return None
|
|
60
67
|
|
|
61
|
-
async def
|
|
62
|
-
self, query: str,
|
|
68
|
+
async def get_completion_from_context(
|
|
69
|
+
self, query: str, retrieved_objects: Any, context: Optional[Any] = None
|
|
63
70
|
) -> Any:
|
|
64
71
|
"""
|
|
65
72
|
Returns the graph connections context.
|
|
@@ -72,7 +79,6 @@ class CypherSearchRetriever(BaseRetriever):
|
|
|
72
79
|
- query (str): The query to retrieve context.
|
|
73
80
|
- context (Optional[Any]): Optional context to use, otherwise fetched using the
|
|
74
81
|
query. (default None)
|
|
75
|
-
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
|
76
82
|
defaults to 'default_session'. (default None)
|
|
77
83
|
|
|
78
84
|
Returns:
|
|
@@ -80,6 +86,5 @@ class CypherSearchRetriever(BaseRetriever):
|
|
|
80
86
|
|
|
81
87
|
- Any: The context, either provided or retrieved.
|
|
82
88
|
"""
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
return context
|
|
89
|
+
# TODO: Do we want to generate a completion using LLM here?
|
|
90
|
+
return None
|
|
@@ -18,16 +18,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|
|
18
18
|
"""
|
|
19
19
|
Handles graph context completion for question answering tasks, extending context based
|
|
20
20
|
on retrieved triplets.
|
|
21
|
-
|
|
22
|
-
Public methods:
|
|
23
|
-
- get_completion
|
|
24
|
-
|
|
25
|
-
Instance variables:
|
|
26
|
-
- user_prompt_path
|
|
27
|
-
- system_prompt_path
|
|
28
|
-
- top_k
|
|
29
|
-
- node_type
|
|
30
|
-
- node_name
|
|
31
21
|
"""
|
|
32
22
|
|
|
33
23
|
def __init__(
|
|
@@ -41,6 +31,9 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|
|
41
31
|
save_interaction: bool = False,
|
|
42
32
|
wide_search_top_k: Optional[int] = 100,
|
|
43
33
|
triplet_distance_penalty: Optional[float] = 3.5,
|
|
34
|
+
context_extension_rounds: int = 4,
|
|
35
|
+
session_id: Optional[str] = None,
|
|
36
|
+
response_model: Type = str,
|
|
44
37
|
):
|
|
45
38
|
super().__init__(
|
|
46
39
|
user_prompt_path=user_prompt_path,
|
|
@@ -52,53 +45,38 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|
|
52
45
|
system_prompt=system_prompt,
|
|
53
46
|
wide_search_top_k=wide_search_top_k,
|
|
54
47
|
triplet_distance_penalty=triplet_distance_penalty,
|
|
48
|
+
session_id=session_id,
|
|
49
|
+
response_model=response_model,
|
|
55
50
|
)
|
|
56
51
|
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
context_extension_rounds=4,
|
|
63
|
-
response_model: Type = str,
|
|
64
|
-
) -> List[Any]:
|
|
52
|
+
# context_extension_rounds: The maximum number of rounds to extend the context with
|
|
53
|
+
# new triplets before halting. (default 4)
|
|
54
|
+
self.context_extension_rounds = context_extension_rounds
|
|
55
|
+
|
|
56
|
+
async def get_retrieved_objects(self, query: str) -> List[Edge]:
|
|
65
57
|
"""
|
|
66
58
|
Extends the context for a given query by retrieving related triplets and generating new
|
|
67
59
|
completions based on them.
|
|
68
60
|
|
|
69
|
-
The method runs for a specified number of rounds to enhance
|
|
61
|
+
The method runs for a specified number of rounds to enhance results until no new
|
|
70
62
|
triplets are found or the maximum rounds are reached. It retrieves triplet suggestions
|
|
71
63
|
based on a generated completion from previous iterations, logging the process of context
|
|
72
64
|
extension.
|
|
73
65
|
|
|
74
66
|
Parameters:
|
|
75
67
|
-----------
|
|
76
|
-
|
|
77
68
|
- query (str): The input query for which the completion is generated.
|
|
78
|
-
- context (Optional[Any]): The existing context to use for enhancing the query; if
|
|
79
|
-
None, it will be initialized from triplets generated for the query. (default None)
|
|
80
|
-
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
|
81
|
-
defaults to 'default_session'. (default None)
|
|
82
|
-
- context_extension_rounds: The maximum number of rounds to extend the context with
|
|
83
|
-
new triplets before halting. (default 4)
|
|
84
|
-
- response_model (Type): The Pydantic model type for structured output. (default str)
|
|
85
69
|
|
|
86
70
|
Returns:
|
|
87
71
|
--------
|
|
88
|
-
|
|
89
|
-
- List[str]: A list containing the generated answer based on the query and the
|
|
90
|
-
extended context.
|
|
72
|
+
- List[Edge]: A list of retrieved triplet edges relevant to the query.
|
|
91
73
|
"""
|
|
92
|
-
triplets = context
|
|
93
|
-
|
|
94
|
-
if triplets is None:
|
|
95
|
-
triplets = await self.get_context(query)
|
|
96
74
|
|
|
75
|
+
triplets = await self.get_triplets(query)
|
|
97
76
|
context_text = await self.resolve_edges_to_text(triplets)
|
|
98
|
-
|
|
99
77
|
round_idx = 1
|
|
100
78
|
|
|
101
|
-
while round_idx <= context_extension_rounds:
|
|
79
|
+
while round_idx <= self.context_extension_rounds:
|
|
102
80
|
prev_size = len(triplets)
|
|
103
81
|
|
|
104
82
|
logger.info(
|
|
@@ -112,7 +90,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|
|
112
90
|
system_prompt=self.system_prompt,
|
|
113
91
|
)
|
|
114
92
|
|
|
115
|
-
triplets += await self.
|
|
93
|
+
triplets += await self.get_triplets(completion)
|
|
116
94
|
triplets = list(set(triplets))
|
|
117
95
|
context_text = await self.resolve_edges_to_text(triplets)
|
|
118
96
|
|
|
@@ -131,6 +109,24 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|
|
131
109
|
|
|
132
110
|
round_idx += 1
|
|
133
111
|
|
|
112
|
+
return triplets
|
|
113
|
+
|
|
114
|
+
async def get_completion_from_context(
|
|
115
|
+
self,
|
|
116
|
+
query: str,
|
|
117
|
+
retrieved_objects: List[Edge],
|
|
118
|
+
context: str,
|
|
119
|
+
) -> List[Any]:
|
|
120
|
+
"""
|
|
121
|
+
Returns a human readable answer based on the provided query and extended context derived from the retrieved objects.
|
|
122
|
+
|
|
123
|
+
Returns:
|
|
124
|
+
--------
|
|
125
|
+
|
|
126
|
+
- List[str]: A list containing the generated answer based on the query and the
|
|
127
|
+
extended context.
|
|
128
|
+
"""
|
|
129
|
+
|
|
134
130
|
# Check if we need to generate context summary for caching
|
|
135
131
|
cache_config = CacheConfig()
|
|
136
132
|
user = session_user.get()
|
|
@@ -138,33 +134,33 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|
|
138
134
|
session_save = user_id and cache_config.caching
|
|
139
135
|
|
|
140
136
|
if session_save:
|
|
141
|
-
conversation_history = await get_conversation_history(session_id=session_id)
|
|
137
|
+
conversation_history = await get_conversation_history(session_id=self.session_id)
|
|
142
138
|
|
|
143
139
|
context_summary, completion = await asyncio.gather(
|
|
144
|
-
summarize_text(
|
|
140
|
+
summarize_text(context),
|
|
145
141
|
generate_completion(
|
|
146
142
|
query=query,
|
|
147
|
-
context=
|
|
143
|
+
context=context,
|
|
148
144
|
user_prompt_path=self.user_prompt_path,
|
|
149
145
|
system_prompt_path=self.system_prompt_path,
|
|
150
146
|
system_prompt=self.system_prompt,
|
|
151
147
|
conversation_history=conversation_history,
|
|
152
|
-
response_model=response_model,
|
|
148
|
+
response_model=self.response_model,
|
|
153
149
|
),
|
|
154
150
|
)
|
|
155
151
|
else:
|
|
156
152
|
completion = await generate_completion(
|
|
157
153
|
query=query,
|
|
158
|
-
context=
|
|
154
|
+
context=context,
|
|
159
155
|
user_prompt_path=self.user_prompt_path,
|
|
160
156
|
system_prompt_path=self.system_prompt_path,
|
|
161
157
|
system_prompt=self.system_prompt,
|
|
162
|
-
response_model=response_model,
|
|
158
|
+
response_model=self.response_model,
|
|
163
159
|
)
|
|
164
160
|
|
|
165
|
-
if self.save_interaction and
|
|
161
|
+
if self.save_interaction and context and retrieved_objects and completion:
|
|
166
162
|
await self.save_qa(
|
|
167
|
-
question=query, answer=completion, context=
|
|
163
|
+
question=query, answer=completion, context=context, triplets=retrieved_objects
|
|
168
164
|
)
|
|
169
165
|
|
|
170
166
|
if session_save:
|
|
@@ -172,7 +168,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|
|
172
168
|
query=query,
|
|
173
169
|
context_summary=context_summary,
|
|
174
170
|
answer=completion,
|
|
175
|
-
session_id=session_id,
|
|
171
|
+
session_id=self.session_id,
|
|
176
172
|
)
|
|
177
173
|
|
|
178
174
|
return [completion]
|
|
@@ -18,6 +18,7 @@ from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
|
|
18
18
|
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
|
|
19
19
|
from cognee.context_global_variables import session_user
|
|
20
20
|
from cognee.infrastructure.databases.cache.config import CacheConfig
|
|
21
|
+
from cognee.exceptions.exceptions import CogneeValidationError
|
|
21
22
|
|
|
22
23
|
logger = get_logger()
|
|
23
24
|
|
|
@@ -67,6 +68,9 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|
|
67
68
|
save_interaction: bool = False,
|
|
68
69
|
wide_search_top_k: Optional[int] = 100,
|
|
69
70
|
triplet_distance_penalty: Optional[float] = 3.5,
|
|
71
|
+
max_iter: int = 4,
|
|
72
|
+
session_id: Optional[str] = None,
|
|
73
|
+
response_model: Type = str,
|
|
70
74
|
):
|
|
71
75
|
super().__init__(
|
|
72
76
|
user_prompt_path=user_prompt_path,
|
|
@@ -78,19 +82,68 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|
|
78
82
|
save_interaction=save_interaction,
|
|
79
83
|
wide_search_top_k=wide_search_top_k,
|
|
80
84
|
triplet_distance_penalty=triplet_distance_penalty,
|
|
85
|
+
session_id=session_id,
|
|
86
|
+
response_model=response_model,
|
|
81
87
|
)
|
|
82
88
|
self.validation_system_prompt_path = validation_system_prompt_path
|
|
83
89
|
self.validation_user_prompt_path = validation_user_prompt_path
|
|
84
90
|
self.followup_system_prompt_path = followup_system_prompt_path
|
|
85
91
|
self.followup_user_prompt_path = followup_user_prompt_path
|
|
92
|
+
self.completion = []
|
|
93
|
+
self.max_iter = max_iter
|
|
94
|
+
|
|
95
|
+
async def get_retrieved_objects(self, query: str) -> List[Edge]:
|
|
96
|
+
"""
|
|
97
|
+
Run chain-of-thought completion with optional structured output.
|
|
98
|
+
|
|
99
|
+
Parameters:
|
|
100
|
+
-----------
|
|
101
|
+
- query: User query
|
|
102
|
+
|
|
103
|
+
Returns:
|
|
104
|
+
--------
|
|
105
|
+
- List of retrieved edges
|
|
106
|
+
"""
|
|
107
|
+
# Check if session saving is enabled
|
|
108
|
+
cache_config = CacheConfig()
|
|
109
|
+
user = session_user.get()
|
|
110
|
+
user_id = getattr(user, "id", None)
|
|
111
|
+
session_save = user_id and cache_config.caching
|
|
112
|
+
|
|
113
|
+
# Load conversation history if enabled
|
|
114
|
+
conversation_history = ""
|
|
115
|
+
if session_save:
|
|
116
|
+
conversation_history = await get_conversation_history(session_id=self.session_id)
|
|
117
|
+
|
|
118
|
+
completion, context_text, triplets = await self._run_cot_completion(
|
|
119
|
+
query=query,
|
|
120
|
+
conversation_history=conversation_history,
|
|
121
|
+
)
|
|
122
|
+
|
|
123
|
+
# Note: completion info is stored to reduce the need to call LLM again in get_completion_from_context
|
|
124
|
+
self.completion = completion
|
|
125
|
+
|
|
126
|
+
if self.save_interaction and context_text and triplets and completion:
|
|
127
|
+
await self.save_qa(
|
|
128
|
+
question=query, answer=str(completion), context=context_text, triplets=triplets
|
|
129
|
+
)
|
|
130
|
+
|
|
131
|
+
# Save to session cache if enabled
|
|
132
|
+
if session_save:
|
|
133
|
+
context_summary = await summarize_text(context_text)
|
|
134
|
+
await save_conversation_history(
|
|
135
|
+
query=query,
|
|
136
|
+
context_summary=context_summary,
|
|
137
|
+
answer=str(completion),
|
|
138
|
+
session_id=self.session_id,
|
|
139
|
+
)
|
|
140
|
+
|
|
141
|
+
return triplets
|
|
86
142
|
|
|
87
143
|
async def _run_cot_completion(
|
|
88
144
|
self,
|
|
89
145
|
query: str,
|
|
90
|
-
context: Optional[List[Edge]] = None,
|
|
91
146
|
conversation_history: str = "",
|
|
92
|
-
max_iter: int = 4,
|
|
93
|
-
response_model: Type = str,
|
|
94
147
|
) -> tuple[Any, str, List[Edge]]:
|
|
95
148
|
"""
|
|
96
149
|
Run chain-of-thought completion with optional structured output.
|
|
@@ -113,15 +166,12 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|
|
113
166
|
triplets = []
|
|
114
167
|
completion = ""
|
|
115
168
|
|
|
116
|
-
for round_idx in range(max_iter + 1):
|
|
169
|
+
for round_idx in range(self.max_iter + 1):
|
|
117
170
|
if round_idx == 0:
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
context_text = await self.resolve_edges_to_text(triplets)
|
|
121
|
-
else:
|
|
122
|
-
context_text = await self.resolve_edges_to_text(context)
|
|
171
|
+
triplets = await self.get_triplets(query)
|
|
172
|
+
context_text = await self.resolve_edges_to_text(triplets)
|
|
123
173
|
else:
|
|
124
|
-
triplets += await self.
|
|
174
|
+
triplets += await self.get_triplets(followup_question)
|
|
125
175
|
context_text = await self.resolve_edges_to_text(list(set(triplets)))
|
|
126
176
|
|
|
127
177
|
completion = await generate_completion(
|
|
@@ -131,12 +181,12 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|
|
131
181
|
system_prompt_path=self.system_prompt_path,
|
|
132
182
|
system_prompt=self.system_prompt,
|
|
133
183
|
conversation_history=conversation_history if conversation_history else None,
|
|
134
|
-
response_model=response_model,
|
|
184
|
+
response_model=self.response_model,
|
|
135
185
|
)
|
|
136
186
|
|
|
137
187
|
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
|
|
138
188
|
|
|
139
|
-
if round_idx < max_iter:
|
|
189
|
+
if round_idx < self.max_iter:
|
|
140
190
|
answer_text = _as_answer_text(completion)
|
|
141
191
|
valid_args = {"query": query, "answer": answer_text, "context": context_text}
|
|
142
192
|
valid_user_prompt = render_prompt(
|
|
@@ -168,13 +218,11 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|
|
168
218
|
|
|
169
219
|
return completion, context_text, triplets
|
|
170
220
|
|
|
171
|
-
async def
|
|
221
|
+
async def get_completion_from_context(
|
|
172
222
|
self,
|
|
173
223
|
query: str,
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
max_iter=4,
|
|
177
|
-
response_model: Type = str,
|
|
224
|
+
retrieved_objects: List[Edge],
|
|
225
|
+
context: str,
|
|
178
226
|
) -> List[Any]:
|
|
179
227
|
"""
|
|
180
228
|
Generate completion responses based on a user query and contextual information.
|
|
@@ -202,38 +250,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|
|
202
250
|
|
|
203
251
|
- List[str]: A list containing the generated answer to the user's query.
|
|
204
252
|
"""
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
user_id = getattr(user, "id", None)
|
|
209
|
-
session_save = user_id and cache_config.caching
|
|
210
|
-
|
|
211
|
-
# Load conversation history if enabled
|
|
212
|
-
conversation_history = ""
|
|
213
|
-
if session_save:
|
|
214
|
-
conversation_history = await get_conversation_history(session_id=session_id)
|
|
215
|
-
|
|
216
|
-
completion, context_text, triplets = await self._run_cot_completion(
|
|
217
|
-
query=query,
|
|
218
|
-
context=context,
|
|
219
|
-
conversation_history=conversation_history,
|
|
220
|
-
max_iter=max_iter,
|
|
221
|
-
response_model=response_model,
|
|
222
|
-
)
|
|
223
|
-
|
|
224
|
-
if self.save_interaction and context and triplets and completion:
|
|
225
|
-
await self.save_qa(
|
|
226
|
-
question=query, answer=str(completion), context=context_text, triplets=triplets
|
|
227
|
-
)
|
|
228
|
-
|
|
229
|
-
# Save to session cache if enabled
|
|
230
|
-
if session_save:
|
|
231
|
-
context_summary = await summarize_text(context_text)
|
|
232
|
-
await save_conversation_history(
|
|
233
|
-
query=query,
|
|
234
|
-
context_summary=context_summary,
|
|
235
|
-
answer=str(completion),
|
|
236
|
-
session_id=session_id,
|
|
237
|
-
)
|
|
238
|
-
|
|
253
|
+
if not retrieved_objects:
|
|
254
|
+
raise CogneeValidationError("No context retrieved to generate completion.")
|
|
255
|
+
completion = self.completion
|
|
239
256
|
return [completion]
|