cognee 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/__init__.py +2 -0
- cognee/api/client.py +28 -3
- cognee/api/health.py +10 -13
- cognee/api/v1/add/add.py +3 -1
- cognee/api/v1/add/routers/get_add_router.py +12 -37
- cognee/api/v1/cloud/routers/__init__.py +1 -0
- cognee/api/v1/cloud/routers/get_checks_router.py +23 -0
- cognee/api/v1/cognify/code_graph_pipeline.py +9 -4
- cognee/api/v1/cognify/cognify.py +50 -3
- cognee/api/v1/cognify/routers/get_cognify_router.py +1 -1
- cognee/api/v1/datasets/routers/get_datasets_router.py +15 -4
- cognee/api/v1/memify/__init__.py +0 -0
- cognee/api/v1/memify/routers/__init__.py +1 -0
- cognee/api/v1/memify/routers/get_memify_router.py +100 -0
- cognee/api/v1/notebooks/routers/__init__.py +1 -0
- cognee/api/v1/notebooks/routers/get_notebooks_router.py +96 -0
- cognee/api/v1/search/routers/get_search_router.py +20 -1
- cognee/api/v1/search/search.py +11 -4
- cognee/api/v1/sync/__init__.py +17 -0
- cognee/api/v1/sync/routers/__init__.py +3 -0
- cognee/api/v1/sync/routers/get_sync_router.py +241 -0
- cognee/api/v1/sync/sync.py +877 -0
- cognee/api/v1/ui/__init__.py +1 -0
- cognee/api/v1/ui/ui.py +529 -0
- cognee/api/v1/users/routers/get_auth_router.py +13 -1
- cognee/base_config.py +10 -1
- cognee/cli/_cognee.py +93 -0
- cognee/infrastructure/databases/graph/config.py +10 -4
- cognee/infrastructure/databases/graph/kuzu/adapter.py +135 -0
- cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +89 -0
- cognee/infrastructure/databases/relational/__init__.py +2 -0
- cognee/infrastructure/databases/relational/get_async_session.py +15 -0
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +6 -1
- cognee/infrastructure/databases/relational/with_async_session.py +25 -0
- cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +1 -1
- cognee/infrastructure/databases/vector/config.py +13 -6
- cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +1 -1
- cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +2 -6
- cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py +4 -1
- cognee/infrastructure/files/storage/LocalFileStorage.py +9 -0
- cognee/infrastructure/files/storage/S3FileStorage.py +5 -0
- cognee/infrastructure/files/storage/StorageManager.py +7 -1
- cognee/infrastructure/files/storage/storage.py +16 -0
- cognee/infrastructure/llm/LLMGateway.py +18 -0
- cognee/infrastructure/llm/config.py +4 -2
- cognee/infrastructure/llm/prompts/extract_query_time.txt +15 -0
- cognee/infrastructure/llm/prompts/generate_event_entity_prompt.txt +25 -0
- cognee/infrastructure/llm/prompts/generate_event_graph_prompt.txt +30 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/__init__.py +2 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/extract_event_entities.py +44 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/knowledge_graph/__init__.py +1 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/knowledge_graph/extract_event_graph.py +46 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -1
- cognee/infrastructure/utils/run_sync.py +8 -1
- cognee/modules/chunking/models/DocumentChunk.py +4 -3
- cognee/modules/cloud/exceptions/CloudApiKeyMissingError.py +15 -0
- cognee/modules/cloud/exceptions/CloudConnectionError.py +15 -0
- cognee/modules/cloud/exceptions/__init__.py +2 -0
- cognee/modules/cloud/operations/__init__.py +1 -0
- cognee/modules/cloud/operations/check_api_key.py +25 -0
- cognee/modules/data/deletion/prune_system.py +1 -1
- cognee/modules/data/methods/check_dataset_name.py +1 -1
- cognee/modules/data/methods/get_dataset_data.py +1 -1
- cognee/modules/data/methods/load_or_create_datasets.py +1 -1
- cognee/modules/engine/models/Event.py +16 -0
- cognee/modules/engine/models/Interval.py +8 -0
- cognee/modules/engine/models/Timestamp.py +13 -0
- cognee/modules/engine/models/__init__.py +3 -0
- cognee/modules/engine/utils/__init__.py +2 -0
- cognee/modules/engine/utils/generate_event_datapoint.py +46 -0
- cognee/modules/engine/utils/generate_timestamp_datapoint.py +51 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +2 -2
- cognee/modules/graph/utils/__init__.py +1 -0
- cognee/modules/graph/utils/resolve_edges_to_text.py +71 -0
- cognee/modules/memify/__init__.py +1 -0
- cognee/modules/memify/memify.py +118 -0
- cognee/modules/notebooks/methods/__init__.py +5 -0
- cognee/modules/notebooks/methods/create_notebook.py +26 -0
- cognee/modules/notebooks/methods/delete_notebook.py +13 -0
- cognee/modules/notebooks/methods/get_notebook.py +21 -0
- cognee/modules/notebooks/methods/get_notebooks.py +18 -0
- cognee/modules/notebooks/methods/update_notebook.py +17 -0
- cognee/modules/notebooks/models/Notebook.py +53 -0
- cognee/modules/notebooks/models/__init__.py +1 -0
- cognee/modules/notebooks/operations/__init__.py +1 -0
- cognee/modules/notebooks/operations/run_in_local_sandbox.py +55 -0
- cognee/modules/pipelines/layers/reset_dataset_pipeline_run_status.py +19 -3
- cognee/modules/pipelines/operations/pipeline.py +1 -0
- cognee/modules/pipelines/operations/run_tasks.py +17 -41
- cognee/modules/retrieval/base_graph_retriever.py +18 -0
- cognee/modules/retrieval/base_retriever.py +1 -1
- cognee/modules/retrieval/code_retriever.py +8 -0
- cognee/modules/retrieval/coding_rules_retriever.py +31 -0
- cognee/modules/retrieval/completion_retriever.py +9 -3
- cognee/modules/retrieval/context_providers/TripletSearchContextProvider.py +1 -0
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +23 -14
- cognee/modules/retrieval/graph_completion_cot_retriever.py +21 -11
- cognee/modules/retrieval/graph_completion_retriever.py +32 -65
- cognee/modules/retrieval/graph_summary_completion_retriever.py +3 -1
- cognee/modules/retrieval/insights_retriever.py +14 -3
- cognee/modules/retrieval/summaries_retriever.py +1 -1
- cognee/modules/retrieval/temporal_retriever.py +152 -0
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +7 -32
- cognee/modules/retrieval/utils/completion.py +10 -3
- cognee/modules/search/methods/get_search_type_tools.py +168 -0
- cognee/modules/search/methods/no_access_control_search.py +47 -0
- cognee/modules/search/methods/search.py +219 -139
- cognee/modules/search/types/SearchResult.py +21 -0
- cognee/modules/search/types/SearchType.py +2 -0
- cognee/modules/search/types/__init__.py +1 -0
- cognee/modules/search/utils/__init__.py +2 -0
- cognee/modules/search/utils/prepare_search_result.py +41 -0
- cognee/modules/search/utils/transform_context_to_graph.py +38 -0
- cognee/modules/sync/__init__.py +1 -0
- cognee/modules/sync/methods/__init__.py +23 -0
- cognee/modules/sync/methods/create_sync_operation.py +53 -0
- cognee/modules/sync/methods/get_sync_operation.py +107 -0
- cognee/modules/sync/methods/update_sync_operation.py +248 -0
- cognee/modules/sync/models/SyncOperation.py +142 -0
- cognee/modules/sync/models/__init__.py +3 -0
- cognee/modules/users/__init__.py +0 -1
- cognee/modules/users/methods/__init__.py +4 -1
- cognee/modules/users/methods/create_user.py +26 -1
- cognee/modules/users/methods/get_authenticated_user.py +36 -42
- cognee/modules/users/methods/get_default_user.py +3 -1
- cognee/modules/users/permissions/methods/get_specific_user_permission_datasets.py +2 -1
- cognee/root_dir.py +19 -0
- cognee/shared/logging_utils.py +1 -1
- cognee/tasks/codingagents/__init__.py +0 -0
- cognee/tasks/codingagents/coding_rule_associations.py +127 -0
- cognee/tasks/ingestion/save_data_item_to_storage.py +23 -0
- cognee/tasks/memify/__init__.py +2 -0
- cognee/tasks/memify/extract_subgraph.py +7 -0
- cognee/tasks/memify/extract_subgraph_chunks.py +11 -0
- cognee/tasks/repo_processor/get_repo_file_dependencies.py +52 -27
- cognee/tasks/temporal_graph/__init__.py +1 -0
- cognee/tasks/temporal_graph/add_entities_to_event.py +85 -0
- cognee/tasks/temporal_graph/enrich_events.py +34 -0
- cognee/tasks/temporal_graph/extract_events_and_entities.py +32 -0
- cognee/tasks/temporal_graph/extract_knowledge_graph_from_events.py +41 -0
- cognee/tasks/temporal_graph/models.py +49 -0
- cognee/tests/test_kuzu.py +4 -4
- cognee/tests/test_neo4j.py +4 -4
- cognee/tests/test_permissions.py +3 -3
- cognee/tests/test_relational_db_migration.py +7 -5
- cognee/tests/test_search_db.py +18 -24
- cognee/tests/test_temporal_graph.py +167 -0
- cognee/tests/unit/api/__init__.py +1 -0
- cognee/tests/unit/api/test_conditional_authentication_endpoints.py +246 -0
- cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +18 -2
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +13 -16
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +11 -16
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +5 -4
- cognee/tests/unit/modules/retrieval/insights_retriever_test.py +4 -2
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +18 -2
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +225 -0
- cognee/tests/unit/modules/users/__init__.py +1 -0
- cognee/tests/unit/modules/users/test_conditional_authentication.py +277 -0
- cognee/tests/unit/processing/utils/utils_test.py +20 -1
- {cognee-0.2.4.dist-info → cognee-0.3.0.dist-info}/METADATA +8 -6
- {cognee-0.2.4.dist-info → cognee-0.3.0.dist-info}/RECORD +165 -90
- cognee/tests/unit/modules/search/search_methods_test.py +0 -225
- {cognee-0.2.4.dist-info → cognee-0.3.0.dist-info}/WHEEL +0 -0
- {cognee-0.2.4.dist-info → cognee-0.3.0.dist-info}/entry_points.txt +0 -0
- {cognee-0.2.4.dist-info → cognee-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.2.4.dist-info → cognee-0.3.0.dist-info}/licenses/NOTICE.md +0 -0
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
from typing import
|
|
1
|
+
from typing import Optional, List, Type
|
|
2
|
+
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
|
2
3
|
from cognee.shared.logging_utils import get_logger
|
|
3
4
|
|
|
4
5
|
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
|
@@ -32,6 +33,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|
|
32
33
|
validation_system_prompt_path: str = "cot_validation_system_prompt.txt",
|
|
33
34
|
followup_system_prompt_path: str = "cot_followup_system_prompt.txt",
|
|
34
35
|
followup_user_prompt_path: str = "cot_followup_user_prompt.txt",
|
|
36
|
+
system_prompt: Optional[str] = None,
|
|
35
37
|
top_k: Optional[int] = 5,
|
|
36
38
|
node_type: Optional[Type] = None,
|
|
37
39
|
node_name: Optional[List[str]] = None,
|
|
@@ -40,6 +42,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|
|
40
42
|
super().__init__(
|
|
41
43
|
user_prompt_path=user_prompt_path,
|
|
42
44
|
system_prompt_path=system_prompt_path,
|
|
45
|
+
system_prompt=system_prompt,
|
|
43
46
|
top_k=top_k,
|
|
44
47
|
node_type=node_type,
|
|
45
48
|
node_name=node_name,
|
|
@@ -51,8 +54,11 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|
|
51
54
|
self.followup_user_prompt_path = followup_user_prompt_path
|
|
52
55
|
|
|
53
56
|
async def get_completion(
|
|
54
|
-
self,
|
|
55
|
-
|
|
57
|
+
self,
|
|
58
|
+
query: str,
|
|
59
|
+
context: Optional[List[Edge]] = None,
|
|
60
|
+
max_iter=4,
|
|
61
|
+
) -> str:
|
|
56
62
|
"""
|
|
57
63
|
Generate completion responses based on a user query and contextual information.
|
|
58
64
|
|
|
@@ -77,25 +83,29 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|
|
77
83
|
"""
|
|
78
84
|
followup_question = ""
|
|
79
85
|
triplets = []
|
|
80
|
-
completion =
|
|
86
|
+
completion = ""
|
|
81
87
|
|
|
82
88
|
for round_idx in range(max_iter + 1):
|
|
83
89
|
if round_idx == 0:
|
|
84
90
|
if context is None:
|
|
85
|
-
|
|
91
|
+
triplets = await self.get_context(query)
|
|
92
|
+
context_text = await self.resolve_edges_to_text(triplets)
|
|
93
|
+
else:
|
|
94
|
+
context_text = await self.resolve_edges_to_text(context)
|
|
86
95
|
else:
|
|
87
|
-
triplets += await self.
|
|
88
|
-
|
|
96
|
+
triplets += await self.get_context(followup_question)
|
|
97
|
+
context_text = await self.resolve_edges_to_text(list(set(triplets)))
|
|
89
98
|
|
|
90
99
|
completion = await generate_completion(
|
|
91
100
|
query=query,
|
|
92
|
-
context=
|
|
101
|
+
context=context_text,
|
|
93
102
|
user_prompt_path=self.user_prompt_path,
|
|
94
103
|
system_prompt_path=self.system_prompt_path,
|
|
104
|
+
system_prompt=self.system_prompt,
|
|
95
105
|
)
|
|
96
106
|
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
|
|
97
107
|
if round_idx < max_iter:
|
|
98
|
-
valid_args = {"query": query, "answer": completion, "context":
|
|
108
|
+
valid_args = {"query": query, "answer": completion, "context": context_text}
|
|
99
109
|
valid_user_prompt = LLMGateway.render_prompt(
|
|
100
110
|
filename=self.validation_user_prompt_path, context=valid_args
|
|
101
111
|
)
|
|
@@ -125,7 +135,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|
|
125
135
|
|
|
126
136
|
if self.save_interaction and context and triplets and completion:
|
|
127
137
|
await self.save_qa(
|
|
128
|
-
question=query, answer=completion, context=
|
|
138
|
+
question=query, answer=completion, context=context_text, triplets=triplets
|
|
129
139
|
)
|
|
130
140
|
|
|
131
|
-
return
|
|
141
|
+
return completion
|
|
@@ -1,15 +1,15 @@
|
|
|
1
|
-
from typing import Any, Optional, Type, List
|
|
2
|
-
from collections import Counter
|
|
1
|
+
from typing import Any, Optional, Type, List
|
|
3
2
|
from uuid import NAMESPACE_OID, uuid5
|
|
4
|
-
import string
|
|
5
3
|
|
|
6
4
|
from cognee.infrastructure.engine import DataPoint
|
|
5
|
+
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
|
6
|
+
from cognee.modules.users.methods import get_default_user
|
|
7
7
|
from cognee.tasks.storage import add_data_points
|
|
8
|
+
from cognee.modules.graph.utils import resolve_edges_to_text
|
|
8
9
|
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
|
|
9
|
-
from cognee.modules.retrieval.
|
|
10
|
+
from cognee.modules.retrieval.base_graph_retriever import BaseGraphRetriever
|
|
10
11
|
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
|
|
11
12
|
from cognee.modules.retrieval.utils.completion import generate_completion
|
|
12
|
-
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
|
|
13
13
|
from cognee.shared.logging_utils import get_logger
|
|
14
14
|
from cognee.modules.retrieval.utils.extract_uuid_from_node import extract_uuid_from_node
|
|
15
15
|
from cognee.modules.retrieval.utils.models import CogneeUserInteraction
|
|
@@ -19,7 +19,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine
|
|
|
19
19
|
logger = get_logger("GraphCompletionRetriever")
|
|
20
20
|
|
|
21
21
|
|
|
22
|
-
class GraphCompletionRetriever(
|
|
22
|
+
class GraphCompletionRetriever(BaseGraphRetriever):
|
|
23
23
|
"""
|
|
24
24
|
Retriever for handling graph-based completion searches.
|
|
25
25
|
|
|
@@ -36,6 +36,7 @@ class GraphCompletionRetriever(BaseRetriever):
|
|
|
36
36
|
self,
|
|
37
37
|
user_prompt_path: str = "graph_context_for_question.txt",
|
|
38
38
|
system_prompt_path: str = "answer_simple_question.txt",
|
|
39
|
+
system_prompt: Optional[str] = None,
|
|
39
40
|
top_k: Optional[int] = 5,
|
|
40
41
|
node_type: Optional[Type] = None,
|
|
41
42
|
node_name: Optional[List[str]] = None,
|
|
@@ -45,26 +46,11 @@ class GraphCompletionRetriever(BaseRetriever):
|
|
|
45
46
|
self.save_interaction = save_interaction
|
|
46
47
|
self.user_prompt_path = user_prompt_path
|
|
47
48
|
self.system_prompt_path = system_prompt_path
|
|
49
|
+
self.system_prompt = system_prompt
|
|
48
50
|
self.top_k = top_k if top_k is not None else 5
|
|
49
51
|
self.node_type = node_type
|
|
50
52
|
self.node_name = node_name
|
|
51
53
|
|
|
52
|
-
def _get_nodes(self, retrieved_edges: list) -> dict:
|
|
53
|
-
"""Creates a dictionary of nodes with their names and content."""
|
|
54
|
-
nodes = {}
|
|
55
|
-
for edge in retrieved_edges:
|
|
56
|
-
for node in (edge.node1, edge.node2):
|
|
57
|
-
if node.id not in nodes:
|
|
58
|
-
text = node.attributes.get("text")
|
|
59
|
-
if text:
|
|
60
|
-
name = self._get_title(text)
|
|
61
|
-
content = text
|
|
62
|
-
else:
|
|
63
|
-
name = node.attributes.get("name", "Unnamed Node")
|
|
64
|
-
content = node.attributes.get("description", name)
|
|
65
|
-
nodes[node.id] = {"node": node, "name": name, "content": content}
|
|
66
|
-
return nodes
|
|
67
|
-
|
|
68
54
|
async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
|
|
69
55
|
"""
|
|
70
56
|
Converts retrieved graph edges into a human-readable string format.
|
|
@@ -79,18 +65,9 @@ class GraphCompletionRetriever(BaseRetriever):
|
|
|
79
65
|
|
|
80
66
|
- str: A formatted string representation of the nodes and their connections.
|
|
81
67
|
"""
|
|
82
|
-
|
|
83
|
-
node_section = "\n".join(
|
|
84
|
-
f"Node: {info['name']}\n__node_content_start__\n{info['content']}\n__node_content_end__\n"
|
|
85
|
-
for info in nodes.values()
|
|
86
|
-
)
|
|
87
|
-
connection_section = "\n".join(
|
|
88
|
-
f"{nodes[edge.node1.id]['name']} --[{edge.attributes['relationship_type']}]--> {nodes[edge.node2.id]['name']}"
|
|
89
|
-
for edge in retrieved_edges
|
|
90
|
-
)
|
|
91
|
-
return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}"
|
|
68
|
+
return await resolve_edges_to_text(retrieved_edges)
|
|
92
69
|
|
|
93
|
-
async def get_triplets(self, query: str) ->
|
|
70
|
+
async def get_triplets(self, query: str) -> List[Edge]:
|
|
94
71
|
"""
|
|
95
72
|
Retrieves relevant graph triplets based on a query string.
|
|
96
73
|
|
|
@@ -105,7 +82,7 @@ class GraphCompletionRetriever(BaseRetriever):
|
|
|
105
82
|
- list: A list of found triplets that match the query.
|
|
106
83
|
"""
|
|
107
84
|
subclasses = get_all_subclasses(DataPoint)
|
|
108
|
-
vector_index_collections = []
|
|
85
|
+
vector_index_collections: List[str] = []
|
|
109
86
|
|
|
110
87
|
for subclass in subclasses:
|
|
111
88
|
if "metadata" in subclass.model_fields:
|
|
@@ -116,8 +93,11 @@ class GraphCompletionRetriever(BaseRetriever):
|
|
|
116
93
|
for field_name in index_fields:
|
|
117
94
|
vector_index_collections.append(f"{subclass.__name__}_{field_name}")
|
|
118
95
|
|
|
96
|
+
user = await get_default_user()
|
|
97
|
+
|
|
119
98
|
found_triplets = await brute_force_triplet_search(
|
|
120
99
|
query,
|
|
100
|
+
user=user,
|
|
121
101
|
top_k=self.top_k,
|
|
122
102
|
collections=vector_index_collections or None,
|
|
123
103
|
node_type=self.node_type,
|
|
@@ -126,7 +106,7 @@ class GraphCompletionRetriever(BaseRetriever):
|
|
|
126
106
|
|
|
127
107
|
return found_triplets
|
|
128
108
|
|
|
129
|
-
async def get_context(self, query: str) ->
|
|
109
|
+
async def get_context(self, query: str) -> List[Edge]:
|
|
130
110
|
"""
|
|
131
111
|
Retrieves and resolves graph triplets into context based on a query.
|
|
132
112
|
|
|
@@ -145,13 +125,17 @@ class GraphCompletionRetriever(BaseRetriever):
|
|
|
145
125
|
|
|
146
126
|
if len(triplets) == 0:
|
|
147
127
|
logger.warning("Empty context was provided to the completion")
|
|
148
|
-
return
|
|
128
|
+
return []
|
|
149
129
|
|
|
150
|
-
context = await self.resolve_edges_to_text(triplets)
|
|
130
|
+
# context = await self.resolve_edges_to_text(triplets)
|
|
151
131
|
|
|
152
|
-
return
|
|
132
|
+
return triplets
|
|
153
133
|
|
|
154
|
-
async def get_completion(
|
|
134
|
+
async def get_completion(
|
|
135
|
+
self,
|
|
136
|
+
query: str,
|
|
137
|
+
context: Optional[List[Edge]] = None,
|
|
138
|
+
) -> Any:
|
|
155
139
|
"""
|
|
156
140
|
Generates a completion using graph connections context based on a query.
|
|
157
141
|
|
|
@@ -167,44 +151,27 @@ class GraphCompletionRetriever(BaseRetriever):
|
|
|
167
151
|
|
|
168
152
|
- Any: A generated completion based on the query and context provided.
|
|
169
153
|
"""
|
|
170
|
-
triplets =
|
|
154
|
+
triplets = context
|
|
155
|
+
|
|
156
|
+
if triplets is None:
|
|
157
|
+
triplets = await self.get_context(query)
|
|
171
158
|
|
|
172
|
-
|
|
173
|
-
context, triplets = await self.get_context(query)
|
|
159
|
+
context_text = await resolve_edges_to_text(triplets)
|
|
174
160
|
|
|
175
161
|
completion = await generate_completion(
|
|
176
162
|
query=query,
|
|
177
|
-
context=
|
|
163
|
+
context=context_text,
|
|
178
164
|
user_prompt_path=self.user_prompt_path,
|
|
179
165
|
system_prompt_path=self.system_prompt_path,
|
|
166
|
+
system_prompt=self.system_prompt,
|
|
180
167
|
)
|
|
181
168
|
|
|
182
169
|
if self.save_interaction and context and triplets and completion:
|
|
183
170
|
await self.save_qa(
|
|
184
|
-
question=query, answer=completion, context=
|
|
171
|
+
question=query, answer=completion, context=context_text, triplets=triplets
|
|
185
172
|
)
|
|
186
173
|
|
|
187
|
-
return
|
|
188
|
-
|
|
189
|
-
def _top_n_words(self, text, stop_words=None, top_n=3, separator=", "):
|
|
190
|
-
"""Concatenates the top N frequent words in text."""
|
|
191
|
-
if stop_words is None:
|
|
192
|
-
stop_words = DEFAULT_STOP_WORDS
|
|
193
|
-
|
|
194
|
-
words = [word.lower().strip(string.punctuation) for word in text.split()]
|
|
195
|
-
|
|
196
|
-
if stop_words:
|
|
197
|
-
words = [word for word in words if word and word not in stop_words]
|
|
198
|
-
|
|
199
|
-
top_words = [word for word, freq in Counter(words).most_common(top_n)]
|
|
200
|
-
|
|
201
|
-
return separator.join(top_words)
|
|
202
|
-
|
|
203
|
-
def _get_title(self, text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
|
|
204
|
-
"""Creates a title, by combining first words with most frequent words from the text."""
|
|
205
|
-
first_n_words = text.split()[:first_n_words]
|
|
206
|
-
top_n_words = self._top_n_words(text, top_n=top_n_words)
|
|
207
|
-
return f"{' '.join(first_n_words)}... [{top_n_words}]"
|
|
174
|
+
return completion
|
|
208
175
|
|
|
209
176
|
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
|
|
210
177
|
"""
|
|
@@ -21,6 +21,7 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
|
|
|
21
21
|
user_prompt_path: str = "graph_context_for_question.txt",
|
|
22
22
|
system_prompt_path: str = "answer_simple_question.txt",
|
|
23
23
|
summarize_prompt_path: str = "summarize_search_results.txt",
|
|
24
|
+
system_prompt: Optional[str] = None,
|
|
24
25
|
top_k: Optional[int] = 5,
|
|
25
26
|
node_type: Optional[Type] = None,
|
|
26
27
|
node_name: Optional[List[str]] = None,
|
|
@@ -34,6 +35,7 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
|
|
|
34
35
|
node_type=node_type,
|
|
35
36
|
node_name=node_name,
|
|
36
37
|
save_interaction=save_interaction,
|
|
38
|
+
system_prompt=system_prompt,
|
|
37
39
|
)
|
|
38
40
|
self.summarize_prompt_path = summarize_prompt_path
|
|
39
41
|
|
|
@@ -57,4 +59,4 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
|
|
|
57
59
|
- str: A summary string representing the content of the retrieved edges.
|
|
58
60
|
"""
|
|
59
61
|
direct_text = await super().resolve_edges_to_text(retrieved_edges)
|
|
60
|
-
return await summarize_text(direct_text, self.summarize_prompt_path)
|
|
62
|
+
return await summarize_text(direct_text, self.summarize_prompt_path, self.system_prompt)
|
|
@@ -1,17 +1,18 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
from typing import Any, Optional
|
|
3
3
|
|
|
4
|
+
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
|
|
5
|
+
from cognee.modules.retrieval.base_graph_retriever import BaseGraphRetriever
|
|
4
6
|
from cognee.shared.logging_utils import get_logger
|
|
5
7
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
|
6
8
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
|
7
|
-
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
|
8
9
|
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
|
9
10
|
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
|
|
10
11
|
|
|
11
12
|
logger = get_logger("InsightsRetriever")
|
|
12
13
|
|
|
13
14
|
|
|
14
|
-
class InsightsRetriever(
|
|
15
|
+
class InsightsRetriever(BaseGraphRetriever):
|
|
15
16
|
"""
|
|
16
17
|
Retriever for handling graph connection-based insights.
|
|
17
18
|
|
|
@@ -95,7 +96,17 @@ class InsightsRetriever(BaseRetriever):
|
|
|
95
96
|
unique_node_connections_map[unique_id] = True
|
|
96
97
|
unique_node_connections.append(node_connection)
|
|
97
98
|
|
|
98
|
-
return
|
|
99
|
+
return [
|
|
100
|
+
Edge(
|
|
101
|
+
node1=Node(node_id=connection[0]["id"], attributes=connection[0]),
|
|
102
|
+
node2=Node(node_id=connection[2]["id"], attributes=connection[2]),
|
|
103
|
+
attributes={
|
|
104
|
+
**connection[1],
|
|
105
|
+
"relationship_type": connection[1]["relationship_name"],
|
|
106
|
+
},
|
|
107
|
+
)
|
|
108
|
+
for connection in unique_node_connections
|
|
109
|
+
]
|
|
99
110
|
|
|
100
111
|
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
|
101
112
|
"""
|
|
@@ -62,7 +62,7 @@ class SummariesRetriever(BaseRetriever):
|
|
|
62
62
|
logger.info(f"Returning {len(summary_payloads)} summary payloads")
|
|
63
63
|
return summary_payloads
|
|
64
64
|
|
|
65
|
-
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
|
65
|
+
async def get_completion(self, query: str, context: Optional[Any] = None, **kwargs) -> Any:
|
|
66
66
|
"""
|
|
67
67
|
Generates a completion using summaries context.
|
|
68
68
|
|
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Any, Optional, List, Type
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
from operator import itemgetter
|
|
6
|
+
from cognee.infrastructure.databases.vector import get_vector_engine
|
|
7
|
+
from cognee.modules.retrieval.utils.completion import generate_completion
|
|
8
|
+
from cognee.infrastructure.databases.graph import get_graph_engine
|
|
9
|
+
from cognee.infrastructure.llm import LLMGateway
|
|
10
|
+
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
|
11
|
+
from cognee.shared.logging_utils import get_logger
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
from cognee.tasks.temporal_graph.models import QueryInterval
|
|
15
|
+
|
|
16
|
+
logger = get_logger()
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class TemporalRetriever(GraphCompletionRetriever):
|
|
20
|
+
"""
|
|
21
|
+
Handles graph completion by generating responses based on a series of interactions with
|
|
22
|
+
a language model. This class extends from GraphCompletionRetriever and is designed to
|
|
23
|
+
manage the retrieval and validation process for user queries, integrating follow-up
|
|
24
|
+
questions based on reasoning. The public methods are:
|
|
25
|
+
|
|
26
|
+
- get_completion
|
|
27
|
+
|
|
28
|
+
Instance variables include:
|
|
29
|
+
- validation_system_prompt_path
|
|
30
|
+
- validation_user_prompt_path
|
|
31
|
+
- followup_system_prompt_path
|
|
32
|
+
- followup_user_prompt_path
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
def __init__(
|
|
36
|
+
self,
|
|
37
|
+
user_prompt_path: str = "graph_context_for_question.txt",
|
|
38
|
+
system_prompt_path: str = "answer_simple_question.txt",
|
|
39
|
+
time_extraction_prompt_path: str = "extract_query_time.txt",
|
|
40
|
+
top_k: Optional[int] = 5,
|
|
41
|
+
node_type: Optional[Type] = None,
|
|
42
|
+
node_name: Optional[List[str]] = None,
|
|
43
|
+
):
|
|
44
|
+
super().__init__(
|
|
45
|
+
user_prompt_path=user_prompt_path,
|
|
46
|
+
system_prompt_path=system_prompt_path,
|
|
47
|
+
top_k=top_k,
|
|
48
|
+
node_type=node_type,
|
|
49
|
+
node_name=node_name,
|
|
50
|
+
)
|
|
51
|
+
self.user_prompt_path = user_prompt_path
|
|
52
|
+
self.system_prompt_path = system_prompt_path
|
|
53
|
+
self.time_extraction_prompt_path = time_extraction_prompt_path
|
|
54
|
+
self.top_k = top_k if top_k is not None else 5
|
|
55
|
+
self.node_type = node_type
|
|
56
|
+
self.node_name = node_name
|
|
57
|
+
|
|
58
|
+
def descriptions_to_string(self, results):
|
|
59
|
+
descs = []
|
|
60
|
+
for entry in results:
|
|
61
|
+
d = entry.get("description")
|
|
62
|
+
if d:
|
|
63
|
+
descs.append(d.strip())
|
|
64
|
+
return "\n#####################\n".join(descs)
|
|
65
|
+
|
|
66
|
+
async def extract_time_from_query(self, query: str):
|
|
67
|
+
prompt_path = self.time_extraction_prompt_path
|
|
68
|
+
|
|
69
|
+
if os.path.isabs(prompt_path):
|
|
70
|
+
base_directory = os.path.dirname(prompt_path)
|
|
71
|
+
prompt_path = os.path.basename(prompt_path)
|
|
72
|
+
else:
|
|
73
|
+
base_directory = None
|
|
74
|
+
|
|
75
|
+
system_prompt = LLMGateway.render_prompt(prompt_path, {}, base_directory=base_directory)
|
|
76
|
+
|
|
77
|
+
interval = await LLMGateway.acreate_structured_output(query, system_prompt, QueryInterval)
|
|
78
|
+
|
|
79
|
+
time_from = interval.starts_at
|
|
80
|
+
time_to = interval.ends_at
|
|
81
|
+
|
|
82
|
+
return time_from, time_to
|
|
83
|
+
|
|
84
|
+
async def filter_top_k_events(self, relevant_events, scored_results):
|
|
85
|
+
# Build a score lookup from vector search results
|
|
86
|
+
score_lookup = {res.payload["id"]: res.score for res in scored_results}
|
|
87
|
+
|
|
88
|
+
events_with_scores = []
|
|
89
|
+
for event in relevant_events[0]["events"]:
|
|
90
|
+
score = score_lookup.get(event["id"], float("inf"))
|
|
91
|
+
events_with_scores.append({**event, "score": score})
|
|
92
|
+
|
|
93
|
+
events_with_scores.sort(key=itemgetter("score"))
|
|
94
|
+
|
|
95
|
+
return events_with_scores[: self.top_k]
|
|
96
|
+
|
|
97
|
+
async def get_context(self, query: str) -> Any:
|
|
98
|
+
"""Retrieves context based on the query."""
|
|
99
|
+
|
|
100
|
+
time_from, time_to = await self.extract_time_from_query(query)
|
|
101
|
+
|
|
102
|
+
graph_engine = await get_graph_engine()
|
|
103
|
+
|
|
104
|
+
triplets = []
|
|
105
|
+
|
|
106
|
+
if time_from and time_to:
|
|
107
|
+
ids = await graph_engine.collect_time_ids(time_from=time_from, time_to=time_to)
|
|
108
|
+
elif time_from:
|
|
109
|
+
ids = await graph_engine.collect_time_ids(time_from=time_from)
|
|
110
|
+
elif time_to:
|
|
111
|
+
ids = await graph_engine.collect_time_ids(time_to=time_to)
|
|
112
|
+
else:
|
|
113
|
+
logger.info(
|
|
114
|
+
"No timestamps identified based on the query, performing retrieval using triplet search on events and entities."
|
|
115
|
+
)
|
|
116
|
+
triplets = await self.get_context(query)
|
|
117
|
+
return await self.resolve_edges_to_text(triplets)
|
|
118
|
+
|
|
119
|
+
if ids:
|
|
120
|
+
relevant_events = await graph_engine.collect_events(ids=ids)
|
|
121
|
+
else:
|
|
122
|
+
logger.info(
|
|
123
|
+
"No events identified based on timestamp filtering, performing retrieval using triplet search on events and entities."
|
|
124
|
+
)
|
|
125
|
+
triplets = await self.get_context(query)
|
|
126
|
+
return await self.resolve_edges_to_text(triplets)
|
|
127
|
+
|
|
128
|
+
vector_engine = get_vector_engine()
|
|
129
|
+
query_vector = (await vector_engine.embedding_engine.embed_text([query]))[0]
|
|
130
|
+
|
|
131
|
+
vector_search_results = await vector_engine.search(
|
|
132
|
+
collection_name="Event_name", query_vector=query_vector, limit=0
|
|
133
|
+
)
|
|
134
|
+
|
|
135
|
+
top_k_events = await self.filter_top_k_events(relevant_events, vector_search_results)
|
|
136
|
+
|
|
137
|
+
return self.descriptions_to_string(top_k_events)
|
|
138
|
+
|
|
139
|
+
async def get_completion(self, query: str, context: Optional[str] = None) -> str:
    """Generate an LLM answer for ``query`` grounded in retrieved context.

    Args:
        query: The natural-language question.
        context: Pre-computed context. When falsy (``None`` or empty),
            context is retrieved with ``self.get_context(query=query)``.

    Returns:
        The result of ``generate_completion`` using this retriever's user and
        system prompt paths, or an empty string when no context could be
        obtained (the LLM is not called in that case).
    """
    if not context:
        context = await self.get_context(query=query)

    # Without this guard `completion` was only bound inside `if context:`,
    # so an empty retrieval raised UnboundLocalError at the return.
    if not context:
        return ""

    return await generate_completion(
        query=query,
        context=context,
        user_prompt_path=self.user_prompt_path,
        system_prompt_path=self.system_prompt_path,
    )
|
|
@@ -8,7 +8,7 @@ from cognee.infrastructure.databases.vector.exceptions import CollectionNotFound
|
|
|
8
8
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
|
9
9
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
|
10
10
|
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
|
11
|
-
from cognee.modules.
|
|
11
|
+
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
|
12
12
|
from cognee.modules.users.models import User
|
|
13
13
|
from cognee.shared.utils import send_telemetry
|
|
14
14
|
|
|
@@ -63,9 +63,10 @@ async def get_memory_fragment(
|
|
|
63
63
|
if properties_to_project is None:
|
|
64
64
|
properties_to_project = ["id", "description", "name", "type", "text"]
|
|
65
65
|
|
|
66
|
+
memory_fragment = CogneeGraph()
|
|
67
|
+
|
|
66
68
|
try:
|
|
67
69
|
graph_engine = await get_graph_engine()
|
|
68
|
-
memory_fragment = CogneeGraph()
|
|
69
70
|
|
|
70
71
|
await memory_fragment.project_graph_from_db(
|
|
71
72
|
graph_engine,
|
|
@@ -87,41 +88,15 @@ async def get_memory_fragment(
|
|
|
87
88
|
|
|
88
89
|
|
|
89
90
|
async def brute_force_triplet_search(
|
|
90
|
-
query: str,
|
|
91
|
-
user: User = None,
|
|
92
|
-
top_k: int = 5,
|
|
93
|
-
collections: List[str] = None,
|
|
94
|
-
properties_to_project: List[str] = None,
|
|
95
|
-
memory_fragment: Optional[CogneeGraph] = None,
|
|
96
|
-
node_type: Optional[Type] = None,
|
|
97
|
-
node_name: Optional[List[str]] = None,
|
|
98
|
-
) -> list:
|
|
99
|
-
if user is None:
|
|
100
|
-
user = await get_default_user()
|
|
101
|
-
|
|
102
|
-
retrieved_results = await brute_force_search(
|
|
103
|
-
query,
|
|
104
|
-
user,
|
|
105
|
-
top_k,
|
|
106
|
-
collections=collections,
|
|
107
|
-
properties_to_project=properties_to_project,
|
|
108
|
-
memory_fragment=memory_fragment,
|
|
109
|
-
node_type=node_type,
|
|
110
|
-
node_name=node_name,
|
|
111
|
-
)
|
|
112
|
-
return retrieved_results
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
async def brute_force_search(
|
|
116
91
|
query: str,
|
|
117
92
|
user: User,
|
|
118
|
-
top_k: int,
|
|
119
|
-
collections: List[str] = None,
|
|
120
|
-
properties_to_project: List[str] = None,
|
|
93
|
+
top_k: int = 5,
|
|
94
|
+
collections: Optional[List[str]] = None,
|
|
95
|
+
properties_to_project: Optional[List[str]] = None,
|
|
121
96
|
memory_fragment: Optional[CogneeGraph] = None,
|
|
122
97
|
node_type: Optional[Type] = None,
|
|
123
98
|
node_name: Optional[List[str]] = None,
|
|
124
|
-
) ->
|
|
99
|
+
) -> List[Edge]:
|
|
125
100
|
"""
|
|
126
101
|
Performs a brute force search to retrieve the top triplets from the graph.
|
|
127
102
|
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
from typing import Optional
|
|
1
2
|
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
|
2
3
|
|
|
3
4
|
|
|
@@ -6,11 +7,14 @@ async def generate_completion(
|
|
|
6
7
|
context: str,
|
|
7
8
|
user_prompt_path: str,
|
|
8
9
|
system_prompt_path: str,
|
|
10
|
+
system_prompt: Optional[str] = None,
|
|
9
11
|
) -> str:
|
|
10
12
|
"""Generates a completion using LLM with given context and prompts."""
|
|
11
13
|
args = {"question": query, "context": context}
|
|
12
14
|
user_prompt = LLMGateway.render_prompt(user_prompt_path, args)
|
|
13
|
-
system_prompt =
|
|
15
|
+
system_prompt = (
|
|
16
|
+
system_prompt if system_prompt else LLMGateway.read_query_prompt(system_prompt_path)
|
|
17
|
+
)
|
|
14
18
|
|
|
15
19
|
return await LLMGateway.acreate_structured_output(
|
|
16
20
|
text_input=user_prompt,
|
|
@@ -21,10 +25,13 @@ async def generate_completion(
|
|
|
21
25
|
|
|
22
26
|
async def summarize_text(
|
|
23
27
|
text: str,
|
|
24
|
-
|
|
28
|
+
system_prompt_path: str = "summarize_search_results.txt",
|
|
29
|
+
system_prompt: str = None,
|
|
25
30
|
) -> str:
|
|
26
31
|
"""Summarizes text using LLM with the specified prompt."""
|
|
27
|
-
system_prompt =
|
|
32
|
+
system_prompt = (
|
|
33
|
+
system_prompt if system_prompt else LLMGateway.read_query_prompt(system_prompt_path)
|
|
34
|
+
)
|
|
28
35
|
|
|
29
36
|
return await LLMGateway.acreate_structured_output(
|
|
30
37
|
text_input=text,
|