cognee 0.5.0.dev0__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/api/client.py +1 -5
- cognee/api/v1/add/add.py +2 -1
- cognee/api/v1/cognify/cognify.py +24 -16
- cognee/api/v1/cognify/routers/__init__.py +0 -1
- cognee/api/v1/cognify/routers/get_cognify_router.py +3 -1
- cognee/api/v1/datasets/routers/get_datasets_router.py +3 -3
- cognee/api/v1/ontologies/ontologies.py +12 -37
- cognee/api/v1/ontologies/routers/get_ontology_router.py +27 -25
- cognee/api/v1/search/search.py +8 -0
- cognee/api/v1/ui/node_setup.py +360 -0
- cognee/api/v1/ui/npm_utils.py +50 -0
- cognee/api/v1/ui/ui.py +38 -68
- cognee/context_global_variables.py +61 -16
- cognee/eval_framework/Dockerfile +29 -0
- cognee/eval_framework/answer_generation/answer_generation_executor.py +10 -0
- cognee/eval_framework/answer_generation/run_question_answering_module.py +1 -1
- cognee/eval_framework/corpus_builder/task_getters/get_cascade_graph_tasks.py +0 -2
- cognee/eval_framework/corpus_builder/task_getters/get_default_tasks_by_indices.py +4 -4
- cognee/eval_framework/eval_config.py +2 -2
- cognee/eval_framework/modal_run_eval.py +16 -28
- cognee/infrastructure/databases/dataset_database_handler/__init__.py +3 -0
- cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py +80 -0
- cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +18 -0
- cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py +10 -0
- cognee/infrastructure/databases/graph/config.py +3 -0
- cognee/infrastructure/databases/graph/get_graph_engine.py +1 -0
- cognee/infrastructure/databases/graph/graph_db_interface.py +15 -0
- cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py +81 -0
- cognee/infrastructure/databases/graph/kuzu/adapter.py +228 -0
- cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +168 -0
- cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +80 -1
- cognee/infrastructure/databases/utils/__init__.py +3 -0
- cognee/infrastructure/databases/utils/get_graph_dataset_database_handler.py +10 -0
- cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +62 -48
- cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py +10 -0
- cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py +30 -0
- cognee/infrastructure/databases/vector/config.py +2 -0
- cognee/infrastructure/databases/vector/create_vector_engine.py +1 -0
- cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +8 -6
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +9 -7
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +11 -10
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +2 -0
- cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py +50 -0
- cognee/infrastructure/databases/vector/vector_db_interface.py +35 -0
- cognee/infrastructure/files/storage/s3_config.py +2 -0
- cognee/infrastructure/llm/LLMGateway.py +5 -2
- cognee/infrastructure/llm/config.py +35 -0
- cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py +2 -2
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py +23 -8
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -16
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py +5 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py +153 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +40 -37
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +39 -36
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +19 -1
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +11 -9
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +23 -21
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +42 -34
- cognee/memify_pipelines/create_triplet_embeddings.py +53 -0
- cognee/modules/cognify/config.py +2 -0
- cognee/modules/data/deletion/prune_system.py +52 -2
- cognee/modules/data/methods/delete_dataset.py +26 -0
- cognee/modules/engine/models/Triplet.py +9 -0
- cognee/modules/engine/models/__init__.py +1 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +85 -37
- cognee/modules/graph/cognee_graph/CogneeGraphElements.py +8 -3
- cognee/modules/memify/memify.py +1 -7
- cognee/modules/pipelines/operations/pipeline.py +18 -2
- cognee/modules/retrieval/__init__.py +1 -1
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +4 -0
- cognee/modules/retrieval/graph_completion_cot_retriever.py +4 -0
- cognee/modules/retrieval/graph_completion_retriever.py +10 -0
- cognee/modules/retrieval/graph_summary_completion_retriever.py +4 -0
- cognee/modules/retrieval/register_retriever.py +10 -0
- cognee/modules/retrieval/registered_community_retrievers.py +1 -0
- cognee/modules/retrieval/temporal_retriever.py +4 -0
- cognee/modules/retrieval/triplet_retriever.py +182 -0
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +42 -10
- cognee/modules/run_custom_pipeline/run_custom_pipeline.py +8 -1
- cognee/modules/search/methods/get_search_type_tools.py +54 -8
- cognee/modules/search/methods/no_access_control_search.py +4 -0
- cognee/modules/search/methods/search.py +46 -18
- cognee/modules/search/types/SearchType.py +1 -1
- cognee/modules/settings/get_settings.py +19 -0
- cognee/modules/users/methods/get_authenticated_user.py +2 -2
- cognee/modules/users/models/DatasetDatabase.py +15 -3
- cognee/shared/logging_utils.py +4 -0
- cognee/shared/rate_limiting.py +30 -0
- cognee/tasks/documents/__init__.py +0 -1
- cognee/tasks/graph/extract_graph_from_data.py +9 -10
- cognee/tasks/memify/get_triplet_datapoints.py +289 -0
- cognee/tasks/storage/add_data_points.py +142 -2
- cognee/tests/integration/retrieval/test_triplet_retriever.py +84 -0
- cognee/tests/integration/tasks/test_add_data_points.py +139 -0
- cognee/tests/integration/tasks/test_get_triplet_datapoints.py +69 -0
- cognee/tests/test_cognee_server_start.py +2 -4
- cognee/tests/test_conversation_history.py +23 -1
- cognee/tests/test_dataset_database_handler.py +137 -0
- cognee/tests/test_dataset_delete.py +76 -0
- cognee/tests/test_edge_centered_payload.py +170 -0
- cognee/tests/test_pipeline_cache.py +164 -0
- cognee/tests/test_search_db.py +37 -1
- cognee/tests/unit/api/test_ontology_endpoint.py +77 -89
- cognee/tests/unit/infrastructure/llm/test_llm_config.py +46 -0
- cognee/tests/unit/infrastructure/mock_embedding_engine.py +3 -7
- cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py +0 -5
- cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +2 -2
- cognee/tests/unit/modules/graph/cognee_graph_test.py +406 -0
- cognee/tests/unit/modules/memify_tasks/test_get_triplet_datapoints.py +214 -0
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +608 -0
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +83 -0
- cognee/tests/unit/modules/search/test_search.py +100 -0
- cognee/tests/unit/tasks/storage/test_add_data_points.py +288 -0
- {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/METADATA +76 -89
- {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/RECORD +119 -97
- {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/WHEEL +1 -1
- cognee/api/v1/cognify/code_graph_pipeline.py +0 -119
- cognee/api/v1/cognify/routers/get_code_pipeline_router.py +0 -90
- cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +0 -544
- cognee/modules/retrieval/code_retriever.py +0 -232
- cognee/tasks/code/enrich_dependency_graph_checker.py +0 -35
- cognee/tasks/code/get_local_dependencies_checker.py +0 -20
- cognee/tasks/code/get_repo_dependency_graph_checker.py +0 -35
- cognee/tasks/documents/check_permissions_on_dataset.py +0 -26
- cognee/tasks/repo_processor/__init__.py +0 -2
- cognee/tasks/repo_processor/get_local_dependencies.py +0 -335
- cognee/tasks/repo_processor/get_non_code_files.py +0 -158
- cognee/tasks/repo_processor/get_repo_file_dependencies.py +0 -243
- cognee/tests/test_delete_bmw_example.py +0 -60
- {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/entry_points.txt +0 -0
- {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/licenses/NOTICE.md +0 -0

cognee/infrastructure/llm/config.py
CHANGED

@@ -16,6 +16,7 @@ class ModelName(Enum):
     anthropic = "anthropic"
     gemini = "gemini"
     mistral = "mistral"
+    bedrock = "bedrock"


 class LLMConfig(BaseModel):
cognee/modules/settings/get_settings.py
CHANGED

@@ -77,6 +78,10 @@ def get_settings() -> SettingsDict:
             "value": "mistral",
             "label": "Mistral",
         },
+        {
+            "value": "bedrock",
+            "label": "Bedrock",
+        },
     ]

     return SettingsDict.model_validate(
@@ -157,6 +162,20 @@ def get_settings() -> SettingsDict:
                     "label": "Mistral Large 2.1",
                 },
             ],
+            "bedrock": [
+                {
+                    "value": "eu.anthropic.claude-sonnet-4-5-20250929-v1:0",
+                    "label": "Claude 4.5 Sonnet",
+                },
+                {
+                    "value": "eu.anthropic.claude-haiku-4-5-20251001-v1:0",
+                    "label": "Claude 4.5 Haiku",
+                },
+                {
+                    "value": "eu.amazon.nova-lite-v1:0",
+                    "label": "Amazon Nova Lite",
+                },
+            ],
         },
     },
     vector_db={
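The Bedrock entries above only extend the settings UI; to actually point cognee at one of the new model ids you still configure the LLM the usual way. A minimal sketch follows; the environment variable names (`LLM_PROVIDER`, `LLM_MODEL`, `LLM_API_KEY`) follow cognee's conventional LLM settings and are an assumption here, not something introduced by this diff.

```python
import os

# Assumed variable names (cognee's usual LLM config convention); the model id is
# one of the values added to get_settings() above.
os.environ["LLM_PROVIDER"] = "bedrock"
os.environ["LLM_MODEL"] = "eu.anthropic.claude-sonnet-4-5-20250929-v1:0"
os.environ["LLM_API_KEY"] = "<credentials-placeholder>"

import cognee  # import after setting the environment so the config picks it up
```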
cognee/modules/users/methods/get_authenticated_user.py
CHANGED

@@ -12,8 +12,8 @@ logger = get_logger("get_authenticated_user")

 # Check environment variable to determine authentication requirement
 REQUIRE_AUTHENTICATION = (
-    os.getenv("REQUIRE_AUTHENTICATION", "
-    or
+    os.getenv("REQUIRE_AUTHENTICATION", "true").lower() == "true"
+    or os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", "true").lower() == "true"
 )

 fastapi_users = get_fastapi_users()
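Both flags above now default to "true" and are OR-ed together, so authentication stays required unless both are explicitly disabled before the module is imported. A minimal sketch:

```python
import os

# Both checks default to "true"; disabling only one still leaves
# REQUIRE_AUTHENTICATION set to True because of the `or`.
os.environ["REQUIRE_AUTHENTICATION"] = "false"
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"

# Start or import cognee only after this point, since the flag is read at import time.
```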
cognee/modules/users/models/DatasetDatabase.py
CHANGED

@@ -1,6 +1,6 @@
 from datetime import datetime, timezone

-from sqlalchemy import Column, DateTime, String, UUID, ForeignKey
+from sqlalchemy import Column, DateTime, String, UUID, ForeignKey, JSON, text
 from cognee.infrastructure.databases.relational import Base


@@ -12,17 +12,29 @@ class DatasetDatabase(Base):
         UUID, ForeignKey("datasets.id", ondelete="CASCADE"), primary_key=True, index=True
     )

-    vector_database_name = Column(String, unique=
-    graph_database_name = Column(String, unique=
+    vector_database_name = Column(String, unique=False, nullable=False)
+    graph_database_name = Column(String, unique=False, nullable=False)

     vector_database_provider = Column(String, unique=False, nullable=False)
     graph_database_provider = Column(String, unique=False, nullable=False)

+    graph_dataset_database_handler = Column(String, unique=False, nullable=False)
+    vector_dataset_database_handler = Column(String, unique=False, nullable=False)
+
     vector_database_url = Column(String, unique=False, nullable=True)
     graph_database_url = Column(String, unique=False, nullable=True)

     vector_database_key = Column(String, unique=False, nullable=True)
     graph_database_key = Column(String, unique=False, nullable=True)

+    # configuration details for different database types. This would make it more flexible to add new database types
+    # without changing the database schema.
+    graph_database_connection_info = Column(
+        JSON, unique=False, nullable=False, server_default=text("'{}'")
+    )
+    vector_database_connection_info = Column(
+        JSON, unique=False, nullable=False, server_default=text("'{}'")
+    )
+
     created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
     updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
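To illustrate the new columns, a hypothetical row as it might be constructed. The handler identifiers and connection-info keys are assumptions chosen for the example, and the dataset foreign-key column is omitted because its name falls outside this hunk.

```python
# Illustrative sketch only: handler names and connection-info keys are assumed,
# not a schema defined by this release.
record = DatasetDatabase(
    vector_database_name="dataset_123_vectors",
    graph_database_name="dataset_123_graph",
    vector_database_provider="lancedb",
    graph_database_provider="kuzu",
    vector_dataset_database_handler="lancedb",  # assumed handler identifier
    graph_dataset_database_handler="kuzu",      # assumed handler identifier
    graph_database_connection_info={"database_path": "/data/dataset_123.kuzu"},  # free-form JSON
    vector_database_connection_info={"url": "/data/dataset_123.lance"},          # free-form JSON
)
```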
cognee/shared/logging_utils.py
CHANGED

@@ -534,6 +534,10 @@ def setup_logging(log_level=None, name=None):
     # Get a configured logger and log system information
     logger = structlog.get_logger(name if name else __name__)

+    logger.warning(
+        "From version 0.5.0 onwards, Cognee will run with multi-user access control mode set to on by default. Data isolation between different users and datasets will be enforced and data created before multi-user access control mode was turned on won't be accessible by default. To disable multi-user access control mode and regain access to old data set the environment variable ENABLE_BACKEND_ACCESS_CONTROL to false before starting Cognee. For more information, please refer to the Cognee documentation."
+    )
+
     if logs_dir is not None:
         logger.info(f"Log file created at: {log_file_path}", log_file=log_file_path)

cognee/shared/rate_limiting.py
ADDED

@@ -0,0 +1,30 @@
+from aiolimiter import AsyncLimiter
+from contextlib import nullcontext
+from cognee.infrastructure.llm.config import get_llm_config
+
+llm_config = get_llm_config()
+
+llm_rate_limiter = AsyncLimiter(
+    llm_config.llm_rate_limit_requests, llm_config.embedding_rate_limit_interval
+)
+embedding_rate_limiter = AsyncLimiter(
+    llm_config.embedding_rate_limit_requests, llm_config.embedding_rate_limit_interval
+)
+
+
+def llm_rate_limiter_context_manager():
+    global llm_rate_limiter
+    if llm_config.llm_rate_limit_enabled:
+        return llm_rate_limiter
+    else:
+        # Return a no-op context manager if rate limiting is disabled
+        return nullcontext()
+
+
+def embedding_rate_limiter_context_manager():
+    global embedding_rate_limiter
+    if llm_config.embedding_rate_limit_enabled:
+        return embedding_rate_limiter
+    else:
+        # Return a no-op context manager if rate limiting is disabled
+        return nullcontext()
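A usage sketch for the new helpers. It assumes Python 3.10+, where `contextlib.nullcontext` also supports `async with`, so the same call site works whether rate limiting is enabled or not; `call_model` and `embed` are placeholders for the real LLM and embedding calls.

```python
from cognee.shared.rate_limiting import (
    llm_rate_limiter_context_manager,
    embedding_rate_limiter_context_manager,
)

async def rate_limited_completion(prompt: str):
    # AsyncLimiter admits a fixed number of acquisitions per interval;
    # nullcontext() is a no-op when rate limiting is disabled.
    async with llm_rate_limiter_context_manager():
        return await call_model(prompt)  # placeholder for the actual LLM call

async def rate_limited_embedding(texts: list[str]):
    async with embedding_rate_limiter_context_manager():
        return await embed(texts)  # placeholder for the actual embedding call
```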
cognee/tasks/graph/extract_graph_from_data.py
CHANGED

@@ -2,9 +2,7 @@ import asyncio
 from typing import Type, List, Optional
 from pydantic import BaseModel

-from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.modules.ontology.ontology_env_config import get_ontology_env_config
-from cognee.tasks.storage import index_graph_edges
 from cognee.tasks.storage.add_data_points import add_data_points
 from cognee.modules.ontology.ontology_config import Config
 from cognee.modules.ontology.get_default_ontology_resolver import (
@@ -25,6 +23,7 @@ from cognee.tasks.graph.exceptions import (
     InvalidChunkGraphInputError,
     InvalidOntologyAdapterError,
 )
+from cognee.modules.cognify.config import get_cognify_config


 async def integrate_chunk_graphs(
@@ -67,8 +66,6 @@ async def integrate_chunk_graphs(
         type(ontology_resolver).__name__ if ontology_resolver else "None"
     )

-    graph_engine = await get_graph_engine()
-
     if graph_model is not KnowledgeGraph:
         for chunk_index, chunk_graph in enumerate(chunk_graphs):
             data_chunks[chunk_index].contains = chunk_graph
@@ -84,12 +81,13 @@ async def integrate_chunk_graphs(
         data_chunks, chunk_graphs, ontology_resolver, existing_edges_map
     )

-
-
+    cognify_config = get_cognify_config()
+    embed_triplets = cognify_config.triplet_embedding

-    if len(
-        await
-
+    if len(graph_nodes) > 0:
+        await add_data_points(
+            data_points=graph_nodes, custom_edges=graph_edges, embed_triplets=embed_triplets
+        )

     return data_chunks

@@ -99,6 +97,7 @@ async def extract_graph_from_data(
     graph_model: Type[BaseModel],
     config: Config = None,
     custom_prompt: Optional[str] = None,
+    **kwargs,
 ) -> List[DocumentChunk]:
     """
     Extracts and integrates a knowledge graph from the text content of document chunks using a specified graph model.
@@ -113,7 +112,7 @@

     chunk_graphs = await asyncio.gather(
         *[
-            extract_content_graph(chunk.text, graph_model, custom_prompt=custom_prompt)
+            extract_content_graph(chunk.text, graph_model, custom_prompt=custom_prompt, **kwargs)
             for chunk in data_chunks
         ]
     )
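Two behavioural notes from the hunks above: extra keyword arguments given to `extract_graph_from_data` are now forwarded to `extract_content_graph`, and whether the extracted graph also gets triplet embeddings is decided by the cognify config rather than by the caller. A small sketch of reading that flag (how it is populated, e.g. via an environment variable, is not shown in this diff):

```python
from cognee.modules.cognify.config import get_cognify_config

# The same flag integrate_chunk_graphs reads above; add_data_points
# receives it as embed_triplets.
cognify_config = get_cognify_config()
print("Triplet embedding enabled:", cognify_config.triplet_embedding)
```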
cognee/tasks/memify/get_triplet_datapoints.py
ADDED

@@ -0,0 +1,289 @@
+from typing import AsyncGenerator, Dict, Any, List, Optional
+from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
+from cognee.modules.engine.utils import generate_node_id
+from cognee.shared.logging_utils import get_logger
+from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
+from cognee.infrastructure.engine import DataPoint
+from cognee.modules.engine.models import Triplet
+from cognee.tasks.storage import index_data_points
+
+logger = get_logger("get_triplet_datapoints")
+
+
+def _build_datapoint_type_index_mapping() -> Dict[str, List[str]]:
+    """
+    Build a mapping of DataPoint type names to their index_fields.
+
+    Returns:
+    --------
+    - Dict[str, List[str]]: Mapping of type name to list of index field names
+    """
+    logger.debug("Building DataPoint type to index_fields mapping")
+    subclasses = get_all_subclasses(DataPoint)
+    datapoint_type_index_property = {}
+
+    for subclass in subclasses:
+        if "metadata" in subclass.model_fields:
+            metadata_field = subclass.model_fields["metadata"]
+            default = getattr(metadata_field, "default", None)
+            if isinstance(default, dict):
+                index_fields = default.get("index_fields", [])
+                if index_fields:
+                    datapoint_type_index_property[subclass.__name__] = index_fields
+                    logger.debug(
+                        f"Registered {subclass.__name__} with index_fields: {index_fields}"
+                    )
+
+    logger.info(
+        f"Found {len(datapoint_type_index_property)} DataPoint types with index_fields: "
+        f"{list(datapoint_type_index_property.keys())}"
+    )
+    return datapoint_type_index_property
+
+
+def _extract_embeddable_text(node_or_edge: Dict[str, Any], index_fields: List[str]) -> str:
+    """
+    Extract and concatenate embeddable properties from a node or edge dictionary.
+
+    Parameters:
+    -----------
+    - node_or_edge (Dict[str, Any]): Dictionary containing node or edge properties.
+    - index_fields (List[str]): List of field names to extract and concatenate.
+
+    Returns:
+    --------
+    - str: Concatenated string of all embeddable property values, or empty string if none found.
+    """
+    if not node_or_edge or not index_fields:
+        return ""
+
+    embeddable_values = []
+    for field_name in index_fields:
+        field_value = node_or_edge.get(field_name)
+        if field_value is not None:
+            field_value = str(field_value).strip()
+
+            if field_value:
+                embeddable_values.append(field_value)
+
+    return " ".join(embeddable_values) if embeddable_values else ""
+
+
+def _extract_relationship_text(
+    relationship: Dict[str, Any], datapoint_type_index_property: Dict[str, List[str]]
+) -> str:
+    """
+    Extract relationship text from edge properties.
+
+    Parameters:
+    -----------
+    - relationship (Dict[str, Any]): Dictionary containing relationship properties
+    - datapoint_type_index_property (Dict[str, List[str]]): Mapping of type to index fields
+
+    Returns:
+    --------
+    - str: Extracted relationship text or empty string
+    """
+    if not relationship:
+        return ""
+
+    edge_text = relationship.get("edge_text")
+    if edge_text and isinstance(edge_text, str) and edge_text.strip():
+        return edge_text.strip()
+
+    # Fallback to extracting from EdgeType index_fields
+    edge_type_index_fields = datapoint_type_index_property.get("EdgeType", [])
+    return _extract_embeddable_text(relationship, edge_type_index_fields)
+
+
+def _process_single_triplet(
+    triplet_datapoint: Dict[str, Any],
+    datapoint_type_index_property: Dict[str, List[str]],
+    offset: int,
+    idx: int,
+) -> tuple[Optional[Triplet], Optional[str]]:
+    """
+    Process a single triplet and create a Triplet object.
+
+    Parameters:
+    -----------
+    - triplet_datapoint (Dict[str, Any]): Raw triplet data from graph engine
+    - datapoint_type_index_property (Dict[str, List[str]]): Type to index fields mapping
+    - offset (int): Current batch offset
+    - idx (int): Index within current batch
+
+    Returns:
+    --------
+    - tuple[Optional[Triplet], Optional[str]]: (Triplet object, error message if skipped)
+    """
+    start_node = triplet_datapoint.get("start_node", {})
+    end_node = triplet_datapoint.get("end_node", {})
+    relationship = triplet_datapoint.get("relationship_properties", {})
+
+    start_node_type = start_node.get("type")
+    end_node_type = end_node.get("type")
+
+    start_index_fields = datapoint_type_index_property.get(start_node_type, [])
+    end_index_fields = datapoint_type_index_property.get(end_node_type, [])
+
+    if not start_index_fields:
+        logger.debug(
+            f"No index_fields found for start_node type '{start_node_type}' in triplet {offset + idx}"
+        )
+    if not end_index_fields:
+        logger.debug(
+            f"No index_fields found for end_node type '{end_node_type}' in triplet {offset + idx}"
+        )
+
+    start_node_id = start_node.get("id", "")
+    end_node_id = end_node.get("id", "")
+
+    if not start_node_id or not end_node_id:
+        return None, (
+            f"Skipping triplet at offset {offset + idx}: missing node IDs "
+            f"(start: {start_node_id}, end: {end_node_id})"
+        )
+
+    relationship_text = _extract_relationship_text(relationship, datapoint_type_index_property)
+    start_node_text = _extract_embeddable_text(start_node, start_index_fields)
+    end_node_text = _extract_embeddable_text(end_node, end_index_fields)
+
+    if not start_node_text and not end_node_text and not relationship_text:
+        return None, (
+            f"Skipping triplet at offset {offset + idx}: empty embeddable text "
+            f"(start_node_id: {start_node_id}, end_node_id: {end_node_id})"
+        )
+
+    embeddable_text = f"{start_node_text}->{relationship_text}->{end_node_text}".strip()
+
+    relationship_name = relationship.get("relationship_name", "")
+    triplet_id = generate_node_id(str(start_node_id) + str(relationship_name) + str(end_node_id))
+
+    triplet_obj = Triplet(
+        id=triplet_id, from_node_id=start_node_id, to_node_id=end_node_id, text=embeddable_text
+    )
+
+    return triplet_obj, None
+
+
+async def get_triplet_datapoints(
+    data,
+    triplets_batch_size: int = 100,
+) -> AsyncGenerator[Triplet, None]:
+    """
+    Async generator that yields batches of triplet datapoints with embeddable text extracted.
+
+    Each triplet in the batch includes:
+    - Original triplet structure (start_node, relationship_properties, end_node)
+    - Extracted embeddable text for each element based on index_fields
+
+    Parameters:
+    -----------
+    - triplets_batch_size (int): Number of triplets to retrieve per batch. Default is 100.
+
+    Yields:
+    -------
+    - List[Dict[str, Any]]: A batch of triplets, each enriched with embeddable text.
+    """
+    if not data or data == [{}]:
+        logger.info("Fetching graph data for current user")
+
+    logger.info(f"Starting triplet datapoints extraction with batch size: {triplets_batch_size}")
+
+    graph_engine = await get_graph_engine()
+    graph_engine_type = type(graph_engine).__name__
+    logger.debug(f"Using graph engine: {graph_engine_type}")
+
+    if not hasattr(graph_engine, "get_triplets_batch"):
+        error_msg = f"Graph adapter {graph_engine_type} does not support get_triplets_batch method"
+        logger.error(error_msg)
+        raise NotImplementedError(error_msg)
+
+    datapoint_type_index_property = _build_datapoint_type_index_mapping()
+
+    offset = 0
+    total_triplets_processed = 0
+    batch_number = 0
+
+    while True:
+        try:
+            batch_number += 1
+            logger.debug(
+                f"Fetching triplet batch {batch_number} (offset: {offset}, limit: {triplets_batch_size})"
+            )
+
+            triplets_batch = await graph_engine.get_triplets_batch(
+                offset=offset, limit=triplets_batch_size
+            )
+
+            if not triplets_batch:
+                logger.info(f"No more triplets found at offset {offset}. Processing complete.")
+                break
+
+            logger.debug(f"Retrieved {len(triplets_batch)} triplets in batch {batch_number}")
+
+            triplet_datapoints = []
+            skipped_count = 0
+
+            for idx, triplet_datapoint in enumerate(triplets_batch):
+                try:
+                    triplet_obj, error_msg = _process_single_triplet(
+                        triplet_datapoint, datapoint_type_index_property, offset, idx
+                    )
+
+                    if error_msg:
+                        logger.warning(error_msg)
+                        skipped_count += 1
+                        continue
+
+                    if triplet_obj:
+                        triplet_datapoints.append(triplet_obj)
+                        yield triplet_obj
+
+                except Exception as e:
+                    logger.warning(
+                        f"Error processing triplet at offset {offset + idx}: {e}. "
+                        f"Skipping this triplet and continuing."
+                    )
+                    skipped_count += 1
+                    continue
+
+            if skipped_count > 0:
+                logger.warning(
+                    f"Skipped {skipped_count} out of {len(triplets_batch)} triplets in batch {batch_number}"
+                )
+
+            if not triplet_datapoints:
+                logger.warning(
+                    f"No valid triplet datapoints in batch {batch_number} after processing"
+                )
+                offset += len(triplets_batch)
+                if len(triplets_batch) < triplets_batch_size:
+                    break
+                continue
+
+            total_triplets_processed += len(triplet_datapoints)
+            logger.info(
+                f"Batch {batch_number} complete: processed {len(triplet_datapoints)} triplets "
+                f"(total processed: {total_triplets_processed})"
+            )
+
+            offset += len(triplets_batch)
+            if len(triplets_batch) < triplets_batch_size:
+                logger.info(
+                    f"Last batch retrieved (got {len(triplets_batch)} < {triplets_batch_size} triplets). "
+                    f"Processing complete."
+                )
+                break
+
+        except Exception as e:
+            logger.error(
+                f"Error retrieving triplet batch {batch_number} at offset {offset}: {e}",
+                exc_info=True,
+            )
+            raise
+
+    logger.info(
+        f"Triplet datapoints extraction complete. "
+        f"Processed {total_triplets_processed} triplets across {batch_number} batch(es)."
+    )
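A consumption sketch for the new generator: it yields `Triplet` datapoints one at a time, which can then be pushed through the normal vector indexing path. Passing `data=None` follows the signature above; how the new `create_triplet_embeddings` memify pipeline wires this up is assumed rather than shown here.

```python
from cognee.tasks.memify.get_triplet_datapoints import get_triplet_datapoints
from cognee.tasks.storage import index_data_points

async def embed_existing_triplets(batch_size: int = 100) -> int:
    # Drain the async generator, then index the collected Triplet datapoints.
    triplets = []
    async for triplet in get_triplet_datapoints(data=None, triplets_batch_size=batch_size):
        triplets.append(triplet)
    if triplets:
        await index_data_points(triplets)
    return len(triplets)
```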
cognee/tasks/storage/add_data_points.py
CHANGED

@@ -1,16 +1,23 @@
 import asyncio
-from typing import List
+from typing import List, Dict, Optional
 from cognee.infrastructure.engine import DataPoint
 from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.modules.graph.utils import deduplicate_nodes_and_edges, get_graph_from_model
 from .index_data_points import index_data_points
 from .index_graph_edges import index_graph_edges
+from cognee.modules.engine.models import Triplet
+from cognee.shared.logging_utils import get_logger
 from cognee.tasks.storage.exceptions import (
     InvalidDataPointsInAddDataPointsError,
 )
+from ...modules.engine.utils import generate_node_id

+logger = get_logger("add_data_points")

-async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
+
+async def add_data_points(
+    data_points: List[DataPoint], custom_edges: Optional[List] = None, embed_triplets: bool = False
+) -> List[DataPoint]:
     """
     Add a batch of data points to the graph database by extracting nodes and edges,
     deduplicating them, and indexing them for retrieval.
@@ -23,6 +30,10 @@ async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
     Args:
         data_points (List[DataPoint]):
             A list of data points to process and insert into the graph.
+        custom_edges (List[tuple]): Custom edges between datapoints.
+        embed_triplets (bool):
+            If True, creates and indexes triplet embeddings from the graph structure.
+            Defaults to False.

     Returns:
         List[DataPoint]:
@@ -34,6 +45,7 @@ async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
         - Updates the node index via `index_data_points`.
         - Inserts nodes and edges into the graph engine.
         - Optionally updates the edge index via `index_graph_edges`.
+        - Optionally creates and indexes triplet embeddings if embed_triplets is True.
     """

     if not isinstance(data_points, list):
@@ -74,4 +86,132 @@ async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
     await graph_engine.add_edges(edges)
     await index_graph_edges(edges)

+    if isinstance(custom_edges, list) and custom_edges:
+        # This must be handled separately from datapoint edges, created a task in linear to dig deeper but (COG-3488)
+        await graph_engine.add_edges(custom_edges)
+        await index_graph_edges(custom_edges)
+        edges.extend(custom_edges)
+
+    if embed_triplets:
+        triplets = _create_triplets_from_graph(nodes, edges)
+        if triplets:
+            await index_data_points(triplets)
+            logger.info(f"Created and indexed {len(triplets)} triplets from graph structure")
+
     return data_points
+
+
+def _extract_embeddable_text_from_datapoint(data_point: DataPoint) -> str:
+    """
+    Extract embeddable text from a DataPoint using its index_fields metadata.
+    Uses the same approach as index_data_points.
+
+    Parameters:
+    -----------
+    - data_point (DataPoint): The data point to extract text from.
+
+    Returns:
+    --------
+    - str: Concatenated string of all embeddable property values, or empty string if none found.
+    """
+    if not data_point or not hasattr(data_point, "metadata"):
+        return ""
+
+    index_fields = data_point.metadata.get("index_fields", [])
+    if not index_fields:
+        return ""
+
+    embeddable_values = []
+    for field_name in index_fields:
+        field_value = getattr(data_point, field_name, None)
+        if field_value is not None:
+            field_value = str(field_value).strip()
+
+            if field_value:
+                embeddable_values.append(field_value)
+
+    return " ".join(embeddable_values) if embeddable_values else ""
+
+
+def _create_triplets_from_graph(nodes: List[DataPoint], edges: List[tuple]) -> List[Triplet]:
+    """
+    Create Triplet objects from graph nodes and edges.
+
+    This function processes graph edges and their corresponding nodes to create
+    triplet datapoints with embeddable text, similar to the triplet embeddings pipeline.
+
+    Parameters:
+    -----------
+    - nodes (List[DataPoint]): List of graph nodes extracted from data points
+    - edges (List[tuple]): List of edge tuples in format
+      (source_node_id, target_node_id, relationship_name, properties_dict)
+      Note: All edges including those from DocumentChunk.contains are already extracted
+      by get_graph_from_model and included in this list.
+
+    Returns:
+    --------
+    - List[Triplet]: List of Triplet objects ready for indexing
+    """
+    node_map: Dict[str, DataPoint] = {}
+    for node in nodes:
+        if hasattr(node, "id"):
+            node_id = str(node.id)
+            if node_id not in node_map:
+                node_map[node_id] = node
+
+    triplets = []
+    skipped_count = 0
+    seen_ids = set()
+
+    for edge_tuple in edges:
+        if len(edge_tuple) < 4:
+            continue
+
+        source_node_id, target_node_id, relationship_name, edge_properties = (
+            edge_tuple[0],
+            edge_tuple[1],
+            edge_tuple[2],
+            edge_tuple[3],
+        )
+
+        source_node = node_map.get(str(source_node_id))
+        target_node = node_map.get(str(target_node_id))
+
+        if not source_node or not target_node or relationship_name is None:
+            skipped_count += 1
+            continue
+
+        source_node_text = _extract_embeddable_text_from_datapoint(source_node)
+        target_node_text = _extract_embeddable_text_from_datapoint(target_node)
+
+        relationship_text = ""
+        if isinstance(edge_properties, dict):
+            edge_text = edge_properties.get("edge_text")
+            if edge_text and isinstance(edge_text, str) and edge_text.strip():
+                relationship_text = edge_text.strip()
+
+        if not relationship_text and relationship_name:
+            relationship_text = relationship_name
+
+        if not source_node_text and not relationship_text and not relationship_name:
+            skipped_count += 1
+            continue
+
+        embeddable_text = f"{source_node_text} -> {relationship_text}->{target_node_text}".strip()

+        triplet_id = generate_node_id(str(source_node_id) + relationship_name + str(target_node_id))
+
+        if triplet_id in seen_ids:
+            continue
+        seen_ids.add(triplet_id)
+
+        triplets.append(
+            Triplet(
+                id=triplet_id,
+                from_node_id=str(source_node_id),
+                to_node_id=str(target_node_id),
+                text=embeddable_text,
+            )
+        )
+
+    return triplets
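Finally, a hedged call sketch for the extended `add_data_points` signature. The two entity objects and their `name` attribute are stand-ins; only the tuple layout `(source_id, target_id, relationship_name, properties)` and the `edge_text` property come from the code above.

```python
from cognee.tasks.storage.add_data_points import add_data_points

async def store_with_triplets(entity_a, entity_b):
    # entity_a / entity_b are assumed DataPoint subclasses created elsewhere.
    custom_edge = (
        str(entity_a.id),
        str(entity_b.id),
        "related_to",
        {"edge_text": f"{entity_a.name} related to {entity_b.name}"},  # assumed 'name' attribute
    )
    return await add_data_points(
        data_points=[entity_a, entity_b],
        custom_edges=[custom_edge],
        embed_triplets=True,  # also build and index Triplet embeddings
    )
```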