cognee 0.2.1.dev7__py3-none-any.whl → 0.2.2.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/api/client.py +44 -4
- cognee/api/health.py +332 -0
- cognee/api/v1/add/add.py +5 -2
- cognee/api/v1/add/routers/get_add_router.py +3 -0
- cognee/api/v1/cognify/code_graph_pipeline.py +3 -1
- cognee/api/v1/cognify/cognify.py +8 -0
- cognee/api/v1/cognify/routers/get_cognify_router.py +8 -1
- cognee/api/v1/config/config.py +3 -1
- cognee/api/v1/datasets/routers/get_datasets_router.py +2 -8
- cognee/api/v1/delete/delete.py +16 -12
- cognee/api/v1/responses/routers/get_responses_router.py +3 -1
- cognee/api/v1/search/search.py +10 -0
- cognee/api/v1/settings/routers/get_settings_router.py +0 -2
- cognee/base_config.py +1 -0
- cognee/eval_framework/evaluation/direct_llm_eval_adapter.py +5 -6
- cognee/infrastructure/databases/graph/config.py +2 -0
- cognee/infrastructure/databases/graph/get_graph_engine.py +58 -12
- cognee/infrastructure/databases/graph/graph_db_interface.py +15 -10
- cognee/infrastructure/databases/graph/kuzu/adapter.py +43 -16
- cognee/infrastructure/databases/graph/kuzu/kuzu_migrate.py +281 -0
- cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +151 -77
- cognee/infrastructure/databases/graph/neptune_driver/__init__.py +15 -0
- cognee/infrastructure/databases/graph/neptune_driver/adapter.py +1427 -0
- cognee/infrastructure/databases/graph/neptune_driver/exceptions.py +115 -0
- cognee/infrastructure/databases/graph/neptune_driver/neptune_utils.py +224 -0
- cognee/infrastructure/databases/graph/networkx/adapter.py +3 -3
- cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +449 -0
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +11 -3
- cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +8 -3
- cognee/infrastructure/databases/vector/create_vector_engine.py +31 -23
- cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +3 -1
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +21 -6
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +4 -3
- cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py +3 -1
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +22 -16
- cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +36 -34
- cognee/infrastructure/databases/vector/vector_db_interface.py +78 -7
- cognee/infrastructure/files/utils/get_data_file_path.py +39 -0
- cognee/infrastructure/files/utils/guess_file_type.py +2 -2
- cognee/infrastructure/files/utils/open_data_file.py +4 -23
- cognee/infrastructure/llm/LLMGateway.py +137 -0
- cognee/infrastructure/llm/__init__.py +14 -4
- cognee/infrastructure/llm/config.py +29 -1
- cognee/infrastructure/llm/prompts/answer_hotpot_question.txt +1 -1
- cognee/infrastructure/llm/prompts/answer_hotpot_using_cognee_search.txt +1 -1
- cognee/infrastructure/llm/prompts/answer_simple_question.txt +1 -1
- cognee/infrastructure/llm/prompts/answer_simple_question_restricted.txt +1 -1
- cognee/infrastructure/llm/prompts/categorize_categories.txt +1 -1
- cognee/infrastructure/llm/prompts/classify_content.txt +1 -1
- cognee/infrastructure/llm/prompts/context_for_question.txt +1 -1
- cognee/infrastructure/llm/prompts/graph_context_for_question.txt +1 -1
- cognee/infrastructure/llm/prompts/natural_language_retriever_system.txt +1 -1
- cognee/infrastructure/llm/prompts/patch_gen_instructions.txt +1 -1
- cognee/infrastructure/llm/prompts/search_type_selector_prompt.txt +130 -0
- cognee/infrastructure/llm/prompts/summarize_code.txt +2 -2
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/__init__.py +57 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/async_client.py +533 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/config.py +94 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/globals.py +37 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/inlinedbaml.py +21 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/parser.py +131 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/runtime.py +266 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/stream_types.py +137 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/sync_client.py +550 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/tracing.py +26 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/type_builder.py +962 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/type_map.py +52 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/types.py +166 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extract_categories.baml +109 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extract_content_graph.baml +343 -0
- cognee/{modules/data → infrastructure/llm/structured_output_framework/baml/baml_src}/extraction/__init__.py +1 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/extract_summary.py +89 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/knowledge_graph/extract_content_graph.py +33 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/generators.baml +18 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/__init__.py +3 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/extract_categories.py +12 -0
- cognee/{modules/data → infrastructure/llm/structured_output_framework/litellm_instructor}/extraction/extract_summary.py +16 -7
- cognee/{modules/data → infrastructure/llm/structured_output_framework/litellm_instructor}/extraction/knowledge_graph/extract_content_graph.py +7 -6
- cognee/infrastructure/llm/{anthropic → structured_output_framework/litellm_instructor/llm/anthropic}/adapter.py +10 -4
- cognee/infrastructure/llm/{gemini → structured_output_framework/litellm_instructor/llm/gemini}/adapter.py +6 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/__init__.py +0 -0
- cognee/infrastructure/llm/{generic_llm_api → structured_output_framework/litellm_instructor/llm/generic_llm_api}/adapter.py +7 -3
- cognee/infrastructure/llm/{get_llm_client.py → structured_output_framework/litellm_instructor/llm/get_llm_client.py} +18 -6
- cognee/infrastructure/llm/{llm_interface.py → structured_output_framework/litellm_instructor/llm/llm_interface.py} +2 -2
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/__init__.py +0 -0
- cognee/infrastructure/llm/{ollama → structured_output_framework/litellm_instructor/llm/ollama}/adapter.py +4 -2
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/__init__.py +0 -0
- cognee/infrastructure/llm/{openai → structured_output_framework/litellm_instructor/llm/openai}/adapter.py +6 -4
- cognee/infrastructure/llm/{rate_limiter.py → structured_output_framework/litellm_instructor/llm/rate_limiter.py} +0 -5
- cognee/infrastructure/llm/tokenizer/Gemini/adapter.py +4 -2
- cognee/infrastructure/llm/tokenizer/TikToken/adapter.py +7 -3
- cognee/infrastructure/llm/tokenizer/__init__.py +4 -0
- cognee/infrastructure/llm/utils.py +3 -1
- cognee/infrastructure/loaders/LoaderEngine.py +156 -0
- cognee/infrastructure/loaders/LoaderInterface.py +73 -0
- cognee/infrastructure/loaders/__init__.py +18 -0
- cognee/infrastructure/loaders/core/__init__.py +7 -0
- cognee/infrastructure/loaders/core/audio_loader.py +98 -0
- cognee/infrastructure/loaders/core/image_loader.py +114 -0
- cognee/infrastructure/loaders/core/text_loader.py +90 -0
- cognee/infrastructure/loaders/create_loader_engine.py +32 -0
- cognee/infrastructure/loaders/external/__init__.py +22 -0
- cognee/infrastructure/loaders/external/pypdf_loader.py +96 -0
- cognee/infrastructure/loaders/external/unstructured_loader.py +127 -0
- cognee/infrastructure/loaders/get_loader_engine.py +18 -0
- cognee/infrastructure/loaders/supported_loaders.py +18 -0
- cognee/infrastructure/loaders/use_loader.py +21 -0
- cognee/infrastructure/loaders/utils/__init__.py +0 -0
- cognee/modules/data/methods/__init__.py +1 -0
- cognee/modules/data/methods/get_authorized_dataset.py +23 -0
- cognee/modules/data/models/Data.py +13 -3
- cognee/modules/data/processing/document_types/AudioDocument.py +2 -2
- cognee/modules/data/processing/document_types/ImageDocument.py +2 -2
- cognee/modules/data/processing/document_types/PdfDocument.py +4 -11
- cognee/modules/data/processing/document_types/UnstructuredDocument.py +2 -5
- cognee/modules/engine/utils/generate_edge_id.py +5 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +45 -35
- cognee/modules/graph/methods/get_formatted_graph_data.py +8 -2
- cognee/modules/graph/utils/get_graph_from_model.py +93 -101
- cognee/modules/ingestion/data_types/TextData.py +8 -2
- cognee/modules/ingestion/save_data_to_file.py +1 -1
- cognee/modules/pipelines/exceptions/__init__.py +1 -0
- cognee/modules/pipelines/exceptions/exceptions.py +12 -0
- cognee/modules/pipelines/models/DataItemStatus.py +5 -0
- cognee/modules/pipelines/models/PipelineRunInfo.py +6 -0
- cognee/modules/pipelines/models/__init__.py +1 -0
- cognee/modules/pipelines/operations/pipeline.py +10 -2
- cognee/modules/pipelines/operations/run_tasks.py +252 -20
- cognee/modules/pipelines/operations/run_tasks_distributed.py +1 -1
- cognee/modules/retrieval/chunks_retriever.py +23 -1
- cognee/modules/retrieval/code_retriever.py +66 -9
- cognee/modules/retrieval/completion_retriever.py +11 -9
- cognee/modules/retrieval/context_providers/TripletSearchContextProvider.py +0 -2
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +0 -2
- cognee/modules/retrieval/graph_completion_cot_retriever.py +8 -9
- cognee/modules/retrieval/graph_completion_retriever.py +1 -1
- cognee/modules/retrieval/insights_retriever.py +4 -0
- cognee/modules/retrieval/natural_language_retriever.py +9 -15
- cognee/modules/retrieval/summaries_retriever.py +23 -1
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +23 -4
- cognee/modules/retrieval/utils/completion.py +6 -9
- cognee/modules/retrieval/utils/description_to_codepart_search.py +2 -3
- cognee/modules/search/methods/search.py +5 -1
- cognee/modules/search/operations/__init__.py +1 -0
- cognee/modules/search/operations/select_search_type.py +42 -0
- cognee/modules/search/types/SearchType.py +1 -0
- cognee/modules/settings/get_settings.py +0 -8
- cognee/modules/settings/save_vector_db_config.py +1 -1
- cognee/shared/data_models.py +3 -1
- cognee/shared/logging_utils.py +0 -5
- cognee/tasks/chunk_naive_llm_classifier/chunk_naive_llm_classifier.py +2 -2
- cognee/tasks/documents/extract_chunks_from_documents.py +10 -12
- cognee/tasks/entity_completion/entity_extractors/llm_entity_extractor.py +4 -6
- cognee/tasks/graph/cascade_extract/utils/extract_content_nodes_and_relationship_names.py +4 -6
- cognee/tasks/graph/cascade_extract/utils/extract_edge_triplets.py +6 -7
- cognee/tasks/graph/cascade_extract/utils/extract_nodes.py +4 -7
- cognee/tasks/graph/extract_graph_from_code.py +3 -2
- cognee/tasks/graph/extract_graph_from_data.py +4 -3
- cognee/tasks/graph/infer_data_ontology.py +5 -6
- cognee/tasks/ingestion/data_item_to_text_file.py +79 -0
- cognee/tasks/ingestion/ingest_data.py +91 -61
- cognee/tasks/ingestion/resolve_data_directories.py +3 -0
- cognee/tasks/repo_processor/get_repo_file_dependencies.py +3 -0
- cognee/tasks/storage/index_data_points.py +1 -1
- cognee/tasks/storage/index_graph_edges.py +4 -1
- cognee/tasks/summarization/summarize_code.py +2 -3
- cognee/tasks/summarization/summarize_text.py +3 -2
- cognee/tests/test_cognee_server_start.py +12 -7
- cognee/tests/test_deduplication.py +2 -2
- cognee/tests/test_deletion.py +58 -17
- cognee/tests/test_graph_visualization_permissions.py +161 -0
- cognee/tests/test_neptune_analytics_graph.py +309 -0
- cognee/tests/test_neptune_analytics_hybrid.py +176 -0
- cognee/tests/{test_weaviate.py → test_neptune_analytics_vector.py} +86 -11
- cognee/tests/test_pgvector.py +5 -5
- cognee/tests/test_s3.py +1 -6
- cognee/tests/unit/infrastructure/databases/test_rate_limiter.py +11 -10
- cognee/tests/unit/infrastructure/databases/vector/__init__.py +0 -0
- cognee/tests/unit/infrastructure/mock_embedding_engine.py +1 -1
- cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py +5 -5
- cognee/tests/unit/infrastructure/test_rate_limiting_realistic.py +6 -4
- cognee/tests/unit/infrastructure/test_rate_limiting_retry.py +1 -1
- cognee/tests/unit/interfaces/graph/get_graph_from_model_unit_test.py +61 -3
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +84 -9
- cognee/tests/unit/modules/search/search_methods_test.py +55 -0
- {cognee-0.2.1.dev7.dist-info → cognee-0.2.2.dev1.dist-info}/METADATA +13 -9
- {cognee-0.2.1.dev7.dist-info → cognee-0.2.2.dev1.dist-info}/RECORD +203 -164
- cognee/infrastructure/databases/vector/pinecone/adapter.py +0 -8
- cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py +0 -514
- cognee/infrastructure/databases/vector/qdrant/__init__.py +0 -2
- cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py +0 -527
- cognee/infrastructure/databases/vector/weaviate_db/__init__.py +0 -1
- cognee/modules/data/extraction/extract_categories.py +0 -14
- cognee/tests/test_qdrant.py +0 -99
- distributed/Dockerfile +0 -34
- distributed/app.py +0 -4
- distributed/entrypoint.py +0 -71
- distributed/entrypoint.sh +0 -5
- distributed/modal_image.py +0 -11
- distributed/queues.py +0 -5
- distributed/tasks/queued_add_data_points.py +0 -13
- distributed/tasks/queued_add_edges.py +0 -13
- distributed/tasks/queued_add_nodes.py +0 -13
- distributed/test.py +0 -28
- distributed/utils.py +0 -19
- distributed/workers/data_point_saving_worker.py +0 -93
- distributed/workers/graph_saving_worker.py +0 -104
- /cognee/infrastructure/databases/{graph/memgraph → hybrid/neptune_analytics}/__init__.py +0 -0
- /cognee/infrastructure/{llm → databases/vector/embeddings}/embedding_rate_limiter.py +0 -0
- /cognee/infrastructure/{databases/vector/pinecone → llm/structured_output_framework}/__init__.py +0 -0
- /cognee/infrastructure/llm/{anthropic → structured_output_framework/baml/baml_src}/__init__.py +0 -0
- /cognee/infrastructure/llm/{gemini/__init__.py → structured_output_framework/baml/baml_src/extraction/extract_categories.py} +0 -0
- /cognee/infrastructure/llm/{generic_llm_api → structured_output_framework/baml/baml_src/extraction/knowledge_graph}/__init__.py +0 -0
- /cognee/infrastructure/llm/{ollama → structured_output_framework/litellm_instructor}/__init__.py +0 -0
- /cognee/{modules/data → infrastructure/llm/structured_output_framework/litellm_instructor}/extraction/knowledge_graph/__init__.py +0 -0
- /cognee/{modules/data → infrastructure/llm/structured_output_framework/litellm_instructor}/extraction/texts.json +0 -0
- /cognee/infrastructure/llm/{openai → structured_output_framework/litellm_instructor/llm}/__init__.py +0 -0
- {distributed → cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic}/__init__.py +0 -0
- {distributed/tasks → cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini}/__init__.py +0 -0
- /cognee/modules/data/{extraction/knowledge_graph → methods}/add_model_class_to_graph.py +0 -0
- {cognee-0.2.1.dev7.dist-info → cognee-0.2.2.dev1.dist-info}/WHEEL +0 -0
- {cognee-0.2.1.dev7.dist-info → cognee-0.2.2.dev1.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.2.1.dev7.dist-info → cognee-0.2.2.dev1.dist-info}/licenses/NOTICE.md +0 -0
cognee/infrastructure/databases/graph/neptune_driver/adapter.py

@@ -0,0 +1,1427 @@
"""Neptune Analytics Adapter for Graph Database"""

import json
from typing import Optional, Any, List, Dict, Type, Tuple
from uuid import UUID
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.graph.graph_db_interface import (
    GraphDBInterface,
    record_graph_changes,
    NodeData,
    EdgeData,
    Node,
)
from cognee.modules.storage.utils import JSONEncoder
from cognee.infrastructure.engine import DataPoint
from botocore.config import Config

from .exceptions import (
    NeptuneAnalyticsConfigurationError,
)
from .neptune_utils import (
    validate_graph_id,
    validate_aws_region,
    build_neptune_config,
    format_neptune_error,
)

logger = get_logger("NeptuneGraphDB")

try:
    from langchain_aws import NeptuneAnalyticsGraph

    LANGCHAIN_AWS_AVAILABLE = True
except ImportError:
    logger.warning("langchain_aws not available. Neptune Analytics functionality will be limited.")
    LANGCHAIN_AWS_AVAILABLE = False

NEPTUNE_ENDPOINT_URL = "neptune-graph://"


class NeptuneGraphDB(GraphDBInterface):
    """
    Adapter for interacting with the Amazon Neptune Analytics graph store.
    This class provides methods for querying, adding, and deleting nodes and edges using the langchain_aws library.
    """

    _GRAPH_NODE_LABEL = "COGNEE_NODE"

    def __init__(
        self,
        graph_id: str,
        region: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
    ):
        """
        Initialize the Neptune Analytics adapter.

        Parameters:
        -----------
        - graph_id (str): The Neptune Analytics graph identifier
        - region (Optional[str]): AWS region where the graph is located (default: us-east-1)
        - aws_access_key_id (Optional[str]): AWS access key ID
        - aws_secret_access_key (Optional[str]): AWS secret access key
        - aws_session_token (Optional[str]): AWS session token for temporary credentials

        Raises:
        -------
        - NeptuneAnalyticsConfigurationError: If configuration parameters are invalid
        """
        # Validate import
        if not LANGCHAIN_AWS_AVAILABLE:
            raise ImportError(
                "langchain_aws is not available. Please install it to use Neptune Analytics."
            )

        # Validate configuration
        if not validate_graph_id(graph_id):
            raise NeptuneAnalyticsConfigurationError(message=f'Invalid graph ID: "{graph_id}"')

        if region and not validate_aws_region(region):
            raise NeptuneAnalyticsConfigurationError(message=f'Invalid AWS region: "{region}"')

        self.graph_id = graph_id
        self.region = region
        self.aws_access_key_id = aws_access_key_id
        self.aws_secret_access_key = aws_secret_access_key
        self.aws_session_token = aws_session_token

        # Build configuration
        self.config = build_neptune_config(
            graph_id=self.graph_id,
            region=self.region,
            aws_access_key_id=self.aws_access_key_id,
            aws_secret_access_key=self.aws_secret_access_key,
            aws_session_token=self.aws_session_token,
        )

        # Initialize Neptune Analytics client using langchain_aws
        self._client: NeptuneAnalyticsGraph = self._initialize_client()
        logger.info(
            f'Initialized Neptune Analytics adapter for graph: "{graph_id}" in region: "{self.region}"'
        )

    def _initialize_client(self) -> Optional[NeptuneAnalyticsGraph]:
        """
        Initialize the Neptune Analytics client using langchain_aws.

        Returns:
        --------
        - NeptuneAnalyticsGraph: The initialized client; raises
          NeptuneAnalyticsConfigurationError if the client cannot be created.
        """
        try:
            # Initialize the Neptune Analytics Graph client
            client_config = {
                "graph_identifier": self.graph_id,
                "config": Config(user_agent_appid="Cognee"),
            }
            # Add AWS credentials if provided
            if self.region:
                client_config["region_name"] = self.region
            if self.aws_access_key_id:
                client_config["aws_access_key_id"] = self.aws_access_key_id
            if self.aws_secret_access_key:
                client_config["aws_secret_access_key"] = self.aws_secret_access_key
            if self.aws_session_token:
                client_config["aws_session_token"] = self.aws_session_token

            client = NeptuneAnalyticsGraph(**client_config)
            logger.info("Successfully initialized Neptune Analytics client")
            return client

        except Exception as e:
            raise NeptuneAnalyticsConfigurationError(
                message=f"Failed to initialize Neptune Analytics client: {format_neptune_error(e)}"
            ) from e
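
    # Usage (an illustrative sketch; the graph ID below is hypothetical and AWS
    # credentials are assumed to come from the environment):
    #
    #   adapter = NeptuneGraphDB(graph_id="g-abc123example", region="us-east-1")
    #   rows = await adapter.query("MATCH (n:COGNEE_NODE) RETURN count(n) AS n")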

    @staticmethod
    def _serialize_properties(properties: Dict[str, Any]) -> Dict[str, Any]:
        """
        Serialize properties for Neptune Analytics storage.

        Parameters:
        -----------
        - properties (Dict[str, Any]): Properties to serialize.

        Returns:
        --------
        - Dict[str, Any]: Serialized properties.
        """
        serialized_properties = {}

        for property_key, property_value in properties.items():
            if isinstance(property_value, UUID):
                serialized_properties[property_key] = str(property_value)
                continue

            if isinstance(property_value, (dict, list)):
                serialized_properties[property_key] = json.dumps(property_value, cls=JSONEncoder)
                continue

            serialized_properties[property_key] = property_value

        return serialized_properties
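
    # For example, a property map like {"id": UUID(...), "metadata": {"k": "v"}}
    # comes back as {"id": "<uuid string>", "metadata": '{"k": "v"}'}: UUIDs are
    # stringified and nested dicts/lists are JSON-encoded, so every stored
    # property value is a flat scalar.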

    async def query(self, query: str, params: Optional[Dict[str, Any]] = None) -> List[Any]:
        """
        Execute a query against the Neptune Analytics database and return the results.

        Parameters:
        -----------
        - query (str): The query string to execute against the database.
        - params (Optional[Dict[str, Any]]): A dictionary of parameters to be used in the query.

        Returns:
        --------
        - List[Any]: A list of results from the query execution.
        """
        try:
            # Execute the query using the Neptune Analytics client.
            # The langchain_aws NeptuneAnalyticsGraph supports openCypher queries.
            if params is None:
                params = {}
            logger.debug(f"executing na query:\nquery={query}\n")
            result = self._client.query(query, params)

            # Convert the result to the list format expected by the interface
            if isinstance(result, list):
                return result
            elif isinstance(result, dict):
                return [result]
            else:
                return [{"result": result}]

        except Exception as e:
            error_msg = format_neptune_error(e)
            logger.error(f"Neptune Analytics query failed: {error_msg}")
            raise Exception(f"Query execution failed: {error_msg}") from e

    async def add_node(self, node: DataPoint) -> None:
        """
        Add a single node with specified properties to the graph.

        Parameters:
        -----------
        - node (DataPoint): The DataPoint object to be added to the graph.
        """
        try:
            # Prepare node properties with the ID and graph type
            serialized_properties = self._serialize_properties(node.model_dump())

            query = f"""
            MERGE (n:{self._GRAPH_NODE_LABEL} {{`~id`: $node_id}})
            ON CREATE SET n = $properties, n.updated_at = timestamp()
            ON MATCH SET n += $properties, n.updated_at = timestamp()
            RETURN n
            """

            params = {
                "node_id": str(node.id),
                "properties": serialized_properties,
            }

            result = await self.query(query, params)
            logger.debug(f"Successfully added/updated node {node.id}: {str(result)}")

        except Exception as e:
            error_msg = format_neptune_error(e)
            logger.error(f"Failed to add node {node.id}: {error_msg}")
            raise Exception(f"Failed to add node: {error_msg}") from e
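
    # Note: `~id` is Neptune's reserved element-ID key, so the MERGE above pins
    # each vertex to the DataPoint's UUID; the id(n) comparisons used by the
    # read and delete queries in this adapter match against that same value.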

    @record_graph_changes
    async def add_nodes(self, nodes: List[DataPoint]) -> None:
        """
        Add multiple nodes to the graph in a single operation.

        Parameters:
        -----------
        - nodes (List[DataPoint]): A list of DataPoint objects to be added to the graph.
        """
        if not nodes:
            logger.debug("No nodes to add")
            return

        try:
            # Build bulk node creation query using UNWIND
            query = f"""
            UNWIND $nodes AS node
            MERGE (n:{self._GRAPH_NODE_LABEL} {{`~id`: node.node_id}})
            ON CREATE SET n = node.properties, n.updated_at = timestamp()
            ON MATCH SET n += node.properties, n.updated_at = timestamp()
            RETURN count(n) AS nodes_processed
            """

            # Prepare node data for bulk operation
            params = {
                "nodes": [
                    {
                        "node_id": str(node.id),
                        "properties": self._serialize_properties(node.model_dump()),
                    }
                    for node in nodes
                ]
            }
            result = await self.query(query, params)

            processed_count = result[0].get("nodes_processed", 0) if result else 0
            logger.debug(f"Successfully processed {processed_count} nodes in bulk operation")

        except Exception as e:
            error_msg = format_neptune_error(e)
            logger.error(f"Failed to add nodes in bulk: {error_msg}")
            # Fallback to individual node creation
            logger.info("Falling back to individual node creation")
            for node in nodes:
                try:
                    await self.add_node(node)
                except Exception as node_error:
                    logger.error(
                        f"Failed to add individual node {node.id}: {format_neptune_error(node_error)}"
                    )
                    continue
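
    # The UNWIND form upserts the whole batch in a single round trip; only if
    # that one bulk query fails does the adapter degrade to issuing one MERGE
    # per node via add_node().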

    async def delete_node(self, node_id: str) -> None:
        """
        Delete a specified node from the graph by its ID.

        Parameters:
        -----------
        - node_id (str): Unique identifier for the node to delete.
        """
        try:
            # Build openCypher query to delete the node and all its relationships
            query = f"""
            MATCH (n:{self._GRAPH_NODE_LABEL})
            WHERE id(n) = $node_id
            DETACH DELETE n
            """

            params = {"node_id": node_id}

            await self.query(query, params)
            logger.debug(f"Successfully deleted node: {node_id}")

        except Exception as e:
            error_msg = format_neptune_error(e)
            logger.error(f"Failed to delete node {node_id}: {error_msg}")
            raise Exception(f"Failed to delete node: {error_msg}") from e

    async def delete_nodes(self, node_ids: List[str]) -> None:
        """
        Delete multiple nodes from the graph by their identifiers.

        Parameters:
        -----------
        - node_ids (List[str]): A list of unique identifiers for the nodes to delete.
        """
        if not node_ids:
            logger.debug("No nodes to delete")
            return

        try:
            # Build bulk node deletion query using UNWIND
            query = f"""
            UNWIND $node_ids AS node_id
            MATCH (n:{self._GRAPH_NODE_LABEL})
            WHERE id(n) = node_id
            DETACH DELETE n
            """

            params = {"node_ids": node_ids}
            await self.query(query, params)
            logger.debug(f"Successfully deleted {len(node_ids)} nodes in bulk operation")

        except Exception as e:
            error_msg = format_neptune_error(e)
            logger.error(f"Failed to delete nodes in bulk: {error_msg}")
            # Fallback to individual node deletion
            logger.info("Falling back to individual node deletion")
            for node_id in node_ids:
                try:
                    await self.delete_node(node_id)
                except Exception as node_error:
                    logger.error(
                        f"Failed to delete individual node {node_id}: {format_neptune_error(node_error)}"
                    )
                    continue

    async def get_node(self, node_id: str) -> Optional[NodeData]:
        """
        Retrieve a single node from the graph using its ID.

        Parameters:
        -----------
        - node_id (str): Unique identifier of the node to retrieve.

        Returns:
        --------
        - Optional[NodeData]: The node data if found, None otherwise.
        """
        try:
            # Build openCypher query to retrieve the node
            query = f"""
            MATCH (n:{self._GRAPH_NODE_LABEL})
            WHERE id(n) = $node_id
            RETURN n
            """
            params = {"node_id": node_id}

            result = await self.query(query, params)

            if result and len(result) == 1:
                # Extract node properties from the result
                logger.debug(f"Successfully retrieved node: {node_id}")
                return result[0]["n"]
            else:
                if not result:
                    logger.debug(f"Node not found: {node_id}")
                elif len(result) > 1:
                    logger.debug(f"Only one node expected, multiple returned: {node_id}")
                return None

        except Exception as e:
            error_msg = format_neptune_error(e)
            logger.error(f"Failed to get node {node_id}: {error_msg}")
            raise Exception(f"Failed to get node: {error_msg}") from e

    async def get_nodes(self, node_ids: List[str]) -> List[NodeData]:
        """
        Retrieve multiple nodes from the graph using their IDs.

        Parameters:
        -----------
        - node_ids (List[str]): A list of unique identifiers for the nodes to retrieve.

        Returns:
        --------
        - List[NodeData]: A list of node data for the found nodes.
        """
        if not node_ids:
            logger.debug("No node IDs provided")
            return []

        try:
            # Build bulk node-retrieval openCypher query using UNWIND
            query = f"""
            UNWIND $node_ids AS node_id
            MATCH (n:{self._GRAPH_NODE_LABEL})
            WHERE id(n) = node_id
            RETURN n
            """

            params = {"node_ids": node_ids}
            result = await self.query(query, params)

            # Extract node data from results
            nodes = [record["n"] for record in result]

            logger.debug(
                f"Successfully retrieved {len(nodes)} nodes out of {len(node_ids)} requested"
            )
            return nodes

        except Exception as e:
            error_msg = format_neptune_error(e)
            logger.error(f"Failed to get nodes in bulk: {error_msg}")
            # Fallback to individual node retrieval
            logger.info("Falling back to individual node retrieval")
            nodes = []
            for node_id in node_ids:
                try:
                    node_data = await self.get_node(node_id)
                    if node_data:
                        nodes.append(node_data)
                except Exception as node_error:
                    logger.error(
                        f"Failed to get individual node {node_id}: {format_neptune_error(node_error)}"
                    )
                    continue
            return nodes

    async def extract_node(self, node_id: str):
        """
        Retrieve a single node based on its ID.

        Parameters:
        -----------
        - node_id (str): The ID of the node to retrieve.

        Returns:
        --------
        - Optional[Dict[str, Any]]: The requested node as a dictionary, or None if it does
          not exist.
        """
        results = await self.extract_nodes([node_id])

        return results[0] if len(results) > 0 else None

    async def extract_nodes(self, node_ids: List[str]):
        """
        Retrieve multiple nodes from the database by their IDs.

        Parameters:
        -----------
        - node_ids (List[str]): A list of IDs for the nodes to retrieve.

        Returns:
        --------
        A list of nodes represented as dictionaries.
        """
        query = f"""
        UNWIND $node_ids AS id
        MATCH (node :{self._GRAPH_NODE_LABEL}) WHERE id(node) = id
        RETURN node"""

        params = {"node_ids": node_ids}

        results = await self.query(query, params)

        return [result["node"] for result in results]

    async def add_edge(
        self,
        source_id: str,
        target_id: str,
        relationship_name: str,
        properties: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Create a new edge between two nodes in the graph.

        Parameters:
        -----------
        - source_id (str): The unique identifier of the source node.
        - target_id (str): The unique identifier of the target node.
        - relationship_name (str): The name of the relationship to be established by the edge.
        - properties (Optional[Dict[str, Any]]): Optional dictionary of properties associated with the edge.
        """
        try:
            # Build openCypher query to create the edge:
            # first match both existing nodes, then merge the relationship between them.

            # Prepare edge properties
            edge_props = properties or {}
            serialized_properties = self._serialize_properties(edge_props)

            query = f"""
            MATCH (source:{self._GRAPH_NODE_LABEL})
            WHERE id(source) = $source_id
            MATCH (target:{self._GRAPH_NODE_LABEL})
            WHERE id(target) = $target_id
            MERGE (source)-[r:{relationship_name}]->(target)
            ON CREATE SET r = $properties, r.updated_at = timestamp()
            ON MATCH SET r = $properties, r.updated_at = timestamp()
            RETURN r
            """

            params = {
                "source_id": source_id,
                "target_id": target_id,
                "properties": serialized_properties,
            }
            await self.query(query, params)
            logger.debug(
                f"Successfully added edge: {source_id} -[{relationship_name}]-> {target_id}"
            )

        except Exception as e:
            error_msg = format_neptune_error(e)
            logger.error(f"Failed to add edge {source_id} -> {target_id}: {error_msg}")
            raise Exception(f"Failed to add edge: {error_msg}") from e

    @record_graph_changes
    async def add_edges(self, edges: List[Tuple[str, str, str, Optional[Dict[str, Any]]]]) -> None:
        """
        Add multiple edges to the graph in a single operation.

        Parameters:
        -----------
        - edges (List[EdgeData]): A list of EdgeData objects representing edges to be added.
        """
        if not edges:
            logger.debug("No edges to add")
            return

        edges_by_relationship: dict[str, list] = {}
        for edge in edges:
            relationship_name = edge[2]
            if edges_by_relationship.get(relationship_name, None):
                edges_by_relationship[relationship_name].append(edge)
            else:
                edges_by_relationship[relationship_name] = [edge]

        results = {}
        for relationship_name, edges_for_relationship in edges_by_relationship.items():
            try:
                # Create the bulk-edge openCypher query using UNWIND
                query = f"""
                UNWIND $edges AS edge
                MATCH (source:{self._GRAPH_NODE_LABEL})
                WHERE id(source) = edge.from_node
                MATCH (target:{self._GRAPH_NODE_LABEL})
                WHERE id(target) = edge.to_node
                MERGE (source)-[r:{relationship_name}]->(target)
                ON CREATE SET r = edge.properties, r.updated_at = timestamp()
                ON MATCH SET r = edge.properties, r.updated_at = timestamp()
                RETURN count(*) AS edges_processed
                """

                # Prepare edges data for bulk operation
                params = {
                    "edges": [
                        {
                            "from_node": str(edge[0]),
                            "to_node": str(edge[1]),
                            "relationship_name": relationship_name,
                            "properties": self._serialize_properties(
                                edge[3] if len(edge) > 3 and edge[3] else {}
                            ),
                        }
                        for edge in edges_for_relationship
                    ]
                }
                results[relationship_name] = await self.query(query, params)
            except Exception as e:
                logger.error(
                    f"Failed to add edges for relationship {relationship_name}: {format_neptune_error(e)}"
                )
                logger.info("Falling back to individual edge creation")
                for edge in edges_for_relationship:
                    try:
                        source_id, target_id, relationship_name = edge[0], edge[1], edge[2]
                        properties = edge[3] if len(edge) > 3 else {}
                        await self.add_edge(source_id, target_id, relationship_name, properties)
                    except Exception as edge_error:
                        logger.error(
                            f"Failed to add individual edge {edge[0]} -> {edge[1]}: {format_neptune_error(edge_error)}"
                        )
                        continue

        processed_count = 0
        for result in results.values():
            processed_count += result[0].get("edges_processed", 0) if result else 0
        logger.debug(f"Successfully processed {processed_count} edges in bulk operation")
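
    # Edges are grouped by relationship type because openCypher cannot take a
    # relationship type as a query parameter: each group gets its own UNWIND
    # batch with the type interpolated directly into the query text.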

    async def delete_graph(self) -> None:
        """
        Delete all nodes and edges from the graph database.

        Returns:
        --------
        The result of the query execution, typically indicating success or failure.
        """
        try:
            # Build openCypher query to delete the graph
            query = f"MATCH (n:{self._GRAPH_NODE_LABEL}) DETACH DELETE n"
            await self.query(query)

        except Exception as e:
            error_msg = format_neptune_error(e)
            logger.error(f"Failed to delete graph: {error_msg}")
            raise Exception(f"Failed to delete graph: {error_msg}") from e

    async def get_graph_data(self) -> Tuple[List[Node], List[EdgeData]]:
        """
        Retrieve all nodes and edges within the graph.

        Returns:
        --------
        - Tuple[List[Node], List[EdgeData]]: A tuple containing all nodes and edges in the graph.
        """
        try:
            # Query to get all nodes
            nodes_query = f"""
            MATCH (n:{self._GRAPH_NODE_LABEL})
            RETURN id(n) AS node_id, properties(n) AS properties
            """

            # Query to get all edges
            edges_query = f"""
            MATCH (source:{self._GRAPH_NODE_LABEL})-[r]->(target:{self._GRAPH_NODE_LABEL})
            RETURN id(source) AS source_id, id(target) AS target_id, type(r) AS relationship_name, properties(r) AS properties
            """

            # Execute both queries
            nodes_result = await self.query(nodes_query)
            edges_result = await self.query(edges_query)

            # Format nodes as (node_id, properties) tuples
            nodes = [(result["node_id"], result["properties"]) for result in nodes_result]

            # Format edges as (source_id, target_id, relationship_name, properties) tuples
            edges = [
                (
                    result["source_id"],
                    result["target_id"],
                    result["relationship_name"],
                    result["properties"],
                )
                for result in edges_result
            ]

            logger.debug(f"Retrieved {len(nodes)} nodes and {len(edges)} edges from graph")
            return (nodes, edges)

        except Exception as e:
            error_msg = format_neptune_error(e)
            logger.error(f"Failed to get graph data: {error_msg}")
            raise Exception(f"Failed to get graph data: {error_msg}") from e

    async def get_graph_metrics(self, include_optional: bool = False) -> Dict[str, Any]:
        """
        Fetch metrics and statistics of the graph, possibly including optional details.

        Parameters:
        -----------
        - include_optional (bool): Flag indicating whether to include optional metrics or not.

        Returns:
        --------
        - Dict[str, Any]: A dictionary containing graph metrics and statistics.
        """
        num_nodes, num_edges = await self._get_model_independent_graph_data()
        num_clusters, list_cluster_sizes = await self._get_connected_components_stat()

        mandatory_metrics = {
            "num_nodes": num_nodes,
            "num_edges": num_edges,
            "mean_degree": (2 * num_edges) / num_nodes if num_nodes != 0 else None,
            "edge_density": num_edges * 1.0 / (num_nodes * (num_nodes - 1))
            if num_nodes > 1
            else None,
            "num_connected_components": num_clusters,
            "sizes_of_connected_components": list_cluster_sizes,
        }

        optional_metrics = {
            "num_selfloops": -1,
            "diameter": -1,
            "avg_shortest_path_length": -1,
            "avg_clustering": -1,
        }

        if include_optional:
            optional_metrics["num_selfloops"] = await self._count_self_loops()
            # Unsupported due to long-running queries when computing the shortest path for each node in the graph:
            # optional_metrics['diameter']
            # optional_metrics['avg_shortest_path_length']
            #
            # Unsupported due to incompatible algorithm: localClusteringCoefficient
            # optional_metrics['avg_clustering']

        return mandatory_metrics | optional_metrics
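
    # mean_degree is 2E/N (every edge contributes to the degree of both
    # endpoints); edge_density uses the directed-graph denominator N*(N-1),
    # i.e. the fraction of ordered node pairs that are connected.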

    async def has_edge(self, source_id: str, target_id: str, relationship_name: str) -> bool:
        """
        Verify if an edge exists between two specified nodes.

        Parameters:
        -----------
        - source_id (str): Unique identifier of the source node.
        - target_id (str): Unique identifier of the target node.
        - relationship_name (str): Name of the relationship to verify.

        Returns:
        --------
        - bool: True if the edge exists, False otherwise.
        """
        try:
            # Build openCypher query to check if the edge exists
            query = f"""
            MATCH (source:{self._GRAPH_NODE_LABEL})-[r:{relationship_name}]->(target:{self._GRAPH_NODE_LABEL})
            WHERE id(source) = $source_id AND id(target) = $target_id
            RETURN COUNT(r) > 0 AS edge_exists
            """

            params = {
                "source_id": source_id,
                "target_id": target_id,
            }

            result = await self.query(query, params)

            if result and len(result) > 0:
                edge_exists = result.pop().get("edge_exists", False)
                logger.debug(
                    f"Edge existence check for "
                    f"{source_id} -[{relationship_name}]-> {target_id}: {edge_exists}"
                )
                return edge_exists
            else:
                return False

        except Exception as e:
            error_msg = format_neptune_error(e)
            logger.error(f"Failed to check edge existence {source_id} -> {target_id}: {error_msg}")
            return False

    async def has_edges(self, edges: List[EdgeData]) -> List[EdgeData]:
        """
        Determine the existence of multiple edges in the graph.

        Parameters:
        -----------
        - edges (List[EdgeData]): A list of EdgeData objects to check for existence in the graph.

        Returns:
        --------
        - A list of existence flags, one per edge found in the graph.
        """
        query = f"""
        UNWIND $edges AS edge
        MATCH (a:{self._GRAPH_NODE_LABEL})-[r]->(b:{self._GRAPH_NODE_LABEL})
        WHERE id(a) = edge.from_node AND id(b) = edge.to_node AND type(r) = edge.relationship_name
        RETURN edge.from_node AS from_node, edge.to_node AS to_node, edge.relationship_name AS relationship_name, count(r) > 0 AS edge_exists
        """

        try:
            params = {
                "edges": [
                    {
                        "from_node": str(edge[0]),
                        "to_node": str(edge[1]),
                        "relationship_name": edge[2],
                    }
                    for edge in edges
                ],
            }

            results = await self.query(query, params)
            logger.debug(f"Found {len(results)} existing edges out of {len(edges)} checked")
            return [result["edge_exists"] for result in results]

        except Exception as e:
            error_msg = format_neptune_error(e)
            logger.error(f"Failed to check edges existence: {error_msg}")
            return []

    async def get_edges(self, node_id: str) -> List[EdgeData]:
        """
        Retrieve all edges that are connected to the specified node.

        Parameters:
        -----------
        - node_id (str): Unique identifier of the node whose edges are to be retrieved.

        Returns:
        --------
        - List[EdgeData]: A list of EdgeData objects representing edges connected to the node.
        """
        try:
            # Query to get all edges connected to the node (both incoming and outgoing)
            query = f"""
            MATCH (n:{self._GRAPH_NODE_LABEL})-[r]-(m:{self._GRAPH_NODE_LABEL})
            WHERE id(n) = $node_id
            RETURN
                id(n) AS source_id,
                id(m) AS target_id,
                type(r) AS relationship_name,
                properties(r) AS properties
            """

            params = {"node_id": node_id}
            result = await self.query(query, params)

            # Format edges as EdgeData tuples: (source_id, target_id, relationship_name, properties)
            edges = [self._convert_relationship_to_edge(record) for record in result]

            logger.debug(f"Retrieved {len(edges)} edges for node: {node_id}")
            return edges

        except Exception as e:
            error_msg = format_neptune_error(e)
            logger.error(f"Failed to get edges for node {node_id}: {error_msg}")
            raise Exception(f"Failed to get edges: {error_msg}") from e

    async def get_disconnected_nodes(self) -> list[str]:
        """
        Find and return nodes that are not connected to any other nodes in the graph.

        Returns:
        --------
        - list[str]: A list of IDs of disconnected nodes.
        """
        query = f"""
        MATCH (n :{self._GRAPH_NODE_LABEL})
        WHERE NOT (n)--()
        RETURN COLLECT(ID(n)) as ids
        """

        results = await self.query(query)
        return results[0]["ids"] if len(results) > 0 else []
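
    # `WHERE NOT (n)--()` keeps only vertices with no relationship in either
    # direction, so the single COLLECT row holds the IDs of all isolated nodes.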

    async def get_predecessors(self, node_id: str, edge_label: str = "") -> list[str]:
        """
        Retrieve the predecessor nodes of a specified node based on an optional edge label.

        Parameters:
        -----------
        - node_id (str): The ID of the node whose predecessors are to be retrieved.
        - edge_label (str): Optional edge label to filter predecessors. (default "")

        Returns:
        --------
        - list[str]: A list of predecessor node IDs.
        """

        edge_label = f" :{edge_label}" if edge_label else ""
        query = f"""
        MATCH (node)<-[r{edge_label}]-(predecessor)
        WHERE node.id = $node_id
        RETURN predecessor
        """

        results = await self.query(query, {"node_id": node_id})

        return [result["predecessor"] for result in results]

    async def get_successors(self, node_id: str, edge_label: str = "") -> list[str]:
        """
        Retrieve the successor nodes of a specified node based on an optional edge label.

        Parameters:
        -----------
        - node_id (str): The ID of the node whose successors are to be retrieved.
        - edge_label (str): Optional edge label to filter successors. (default "")

        Returns:
        --------
        - list[str]: A list of successor node IDs.
        """

        edge_label = f" :{edge_label}" if edge_label else ""
        query = f"""
        MATCH (node)-[r{edge_label}]->(successor)
        WHERE node.id = $node_id
        RETURN successor
        """

        results = await self.query(query, {"node_id": node_id})

        return [result["successor"] for result in results]

    async def get_neighbors(self, node_id: str) -> List[NodeData]:
        """
        Get all neighboring nodes connected to the specified node.

        Parameters:
        -----------
        - node_id (str): Unique identifier of the node for which to retrieve neighbors.

        Returns:
        --------
        - List[NodeData]: A list of NodeData objects representing neighboring nodes.
        """
        try:
            # Query to get all neighboring nodes (both incoming and outgoing connections)
            query = f"""
            MATCH (n:{self._GRAPH_NODE_LABEL})-[r]-(neighbor:{self._GRAPH_NODE_LABEL})
            WHERE id(n) = $node_id
            RETURN DISTINCT id(neighbor) AS neighbor_id, properties(neighbor) AS properties
            """

            params = {"node_id": node_id}
            result = await self.query(query, params)

            # Format neighbors as NodeData objects
            neighbors = [
                {"id": neighbor["neighbor_id"], **neighbor["properties"]} for neighbor in result
            ]

            logger.debug(f"Retrieved {len(neighbors)} neighbors for node: {node_id}")
            return neighbors

        except Exception as e:
            error_msg = format_neptune_error(e)
            logger.error(f"Failed to get neighbors for node {node_id}: {error_msg}")
            raise Exception(f"Failed to get neighbors: {error_msg}") from e

    async def get_nodeset_subgraph(
        self, node_type: Type[Any], node_name: List[str]
    ) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
        """
        Fetch a subgraph consisting of a specific set of nodes and their relationships.

        Parameters:
        -----------
        - node_type (Type[Any]): The type of nodes to include in the subgraph.
        - node_name (List[str]): A list of names of the nodes to include in the subgraph.

        Returns:
        --------
        - Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]: A tuple containing nodes and edges of the subgraph.
        """
        try:
            # Query to get nodes by name and their connected subgraph
            query = f"""
            UNWIND $names AS wantedName
            MATCH (n:{self._GRAPH_NODE_LABEL})
            WHERE n.name = wantedName AND n.type = $type
            WITH collect(DISTINCT n) AS primary
            UNWIND primary AS p
            OPTIONAL MATCH (p)-[r]-(nbr:{self._GRAPH_NODE_LABEL})
            WITH primary, collect(DISTINCT nbr) AS nbrs, collect(DISTINCT r) AS rels
            WITH primary + nbrs AS nodelist, rels
            UNWIND nodelist AS node
            WITH collect(DISTINCT node) AS nodes, rels
            MATCH (a:{self._GRAPH_NODE_LABEL})-[r]-(b:{self._GRAPH_NODE_LABEL})
            WHERE a IN nodes AND b IN nodes
            WITH nodes, collect(DISTINCT r) AS all_rels
            RETURN
                [n IN nodes | {{
                    id: id(n),
                    properties: properties(n)
                }}] AS rawNodes,
                [r IN all_rels | {{
                    source_id: id(startNode(r)),
                    target_id: id(endNode(r)),
                    type: type(r),
                    properties: properties(r)
                }}] AS rawRels
            """

            params = {"names": node_name, "type": node_type.__name__}

            result = await self.query(query, params)

            if not result:
                logger.debug(f"No subgraph found for node type {node_type} with names {node_name}")
                return ([], [])

            raw_nodes = result[0]["rawNodes"]
            raw_rels = result[0]["rawRels"]

            # Format nodes as (node_id, properties) tuples
            nodes = [(n["id"], n["properties"]) for n in raw_nodes]

            # Format edges as (source_id, target_id, relationship_name, properties) tuples
            edges = [(r["source_id"], r["target_id"], r["type"], r["properties"]) for r in raw_rels]

            logger.debug(
                f"Retrieved subgraph with {len(nodes)} nodes and {len(edges)} edges for type {node_type.__name__}"
            )
            return (nodes, edges)

        except Exception as e:
            error_msg = format_neptune_error(e)
            logger.error(f"Failed to get nodeset subgraph for type {node_type}: {error_msg}")
            raise Exception(f"Failed to get nodeset subgraph: {error_msg}") from e
|
|
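A minimal usage sketch of `get_nodeset_subgraph`, assuming an initialized adapter. The `Entity` class below is a local stand-in, since the adapter only reads `node_type.__name__`.

```python
# Hypothetical sketch: fetch named nodes of one type plus their one-hop
# neighborhood. `adapter` is assumed to be an initialized adapter instance.
class Entity:
    """Local stand-in for a cognee node model; only __name__ is consumed."""


async def show_subgraph(adapter) -> None:
    nodes, edges = await adapter.get_nodeset_subgraph(Entity, ["Alice", "Bob"])
    for node_id, properties in nodes:
        print(node_id, properties.get("name"))
    for source_id, target_id, relationship_name, _props in edges:
        print(f"{source_id} -[{relationship_name}]-> {target_id}")
```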
+    async def get_connections(self, node_id: UUID) -> list:
+        """
+        Get all nodes connected to a specified node and their relationship details.
+
+        Parameters:
+        -----------
+        - node_id (UUID): Unique identifier of the node for which to retrieve connections.
+
+        Returns:
+        --------
+        - List[Tuple[NodeData, Dict[str, Any], NodeData]]: A list of tuples containing connected nodes and relationship details.
+        """
+        try:
+            # Query to get all connections (both incoming and outgoing)
+            query = f"""
+            MATCH (source:{self._GRAPH_NODE_LABEL})-[r]->(target:{self._GRAPH_NODE_LABEL})
+            WHERE id(source) = $node_id OR id(target) = $node_id
+            RETURN
+                id(source) AS source_id,
+                properties(source) AS source_props,
+                id(target) AS target_id,
+                properties(target) AS target_props,
+                type(r) AS relationship_name,
+                properties(r) AS relationship_props
+            """
+
+            params = {"node_id": str(node_id)}
+            result = await self.query(query, params)
+
+            connections = []
+            for record in result:
+                # Return as (source_node, relationship, target_node)
+                connections.append(
+                    (
+                        {"id": record["source_id"], **record["source_props"]},
+                        {
+                            "relationship_name": record["relationship_name"],
+                            **record["relationship_props"],
+                        },
+                        {"id": record["target_id"], **record["target_props"]},
+                    )
+                )
+
+            logger.debug(f"Retrieved {len(connections)} connections for node: {node_id}")
+            return connections
+
+        except Exception as e:
+            error_msg = format_neptune_error(e)
+            logger.error(f"Failed to get connections for node {node_id}: {error_msg}")
+            raise Exception(f"Failed to get connections: {error_msg}") from e
+
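The triples returned by `get_connections` unpack directly; a sketch, assuming an initialized adapter:

```python
# Hypothetical sketch: each record is a (source, relationship, target) triple
# of plain dicts, as assembled in the method above.
from uuid import UUID


async def print_connections(adapter, node_id: UUID) -> None:
    for source, relationship, target in await adapter.get_connections(node_id):
        print(source["id"], "-[", relationship["relationship_name"], "]->", target["id"])
```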
+    async def remove_connection_to_predecessors_of(self, node_ids: list[str], edge_label: str):
+        """
+        Remove connections (edges) to all predecessors of specified nodes based on edge label.
+
+        Parameters:
+        -----------
+
+        - node_ids (list[str]): A list of IDs of nodes from which connections are to be
+          removed.
+        - edge_label (str): The label of the edges to remove.
+
+        """
+        query = f"""
+        UNWIND $node_ids AS node_id
+        MATCH ({{`~id`: node_id}})-[r:{edge_label}]->(predecessor)
+        DELETE r;
+        """
+        params = {"node_ids": node_ids}
+        await self.query(query, params)
+
+    async def remove_connection_to_successors_of(self, node_ids: list[str], edge_label: str):
+        """
+        Remove connections (edges) to all successors of specified nodes based on edge label.
+
+        Parameters:
+        -----------
+
+        - node_ids (list[str]): A list of IDs of nodes from which connections are to be
+          removed.
+        - edge_label (str): The label of the edges to remove.
+
+        """
+        query = f"""
+        UNWIND $node_ids AS node_id
+        MATCH ({{`~id`: node_id}})<-[r:{edge_label}]-(successor)
+        DELETE r;
+        """
+        params = {"node_ids": node_ids}
+        await self.query(query, params)
+
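Both removal helpers interpolate `edge_label` directly into the openCypher text, so the label must come from trusted code. A usage sketch under that assumption:

```python
# Hypothetical sketch: prune edges with known labels around a set of node ids.
# The labels shown ('is_part_of', 'contains') are illustrative values.
async def prune_edges(adapter, node_ids: list[str]) -> None:
    await adapter.remove_connection_to_predecessors_of(node_ids, "is_part_of")
    await adapter.remove_connection_to_successors_of(node_ids, "contains")
```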
+    async def get_node_labels_string(self):
+        """
+        Fetch all node labels from the database and return them as a formatted string.
+
+        Returns:
+        --------
+
+        A formatted string of node labels.
+
+        Raises:
+        -------
+        ValueError: If no node labels are found in the database.
+        """
+        node_labels_query = (
+            "CALL neptune.graph.pg_schema() YIELD schema RETURN schema.nodeLabels as labels "
+        )
+        node_labels_result = await self.query(node_labels_query)
+        node_labels = node_labels_result[0]["labels"] if node_labels_result else []
+
+        if not node_labels:
+            raise ValueError("No node labels found in the database")
+
+        return str(node_labels)
+
+    async def get_relationship_labels_string(self):
+        """
+        Fetch all relationship types from the database and return them as a formatted string.
+
+        Returns:
+        --------
+
+        A formatted string of relationship types.
+        """
+        relationship_types_query = (
+            "CALL neptune.graph.pg_schema() YIELD schema RETURN schema.edgeLabels as relationships "
+        )
+        relationship_types_result = await self.query(relationship_types_query)
+        relationship_types = (
+            relationship_types_result[0]["relationships"] if relationship_types_result else []
+        )
+
+        if not relationship_types:
+            raise ValueError("No relationship types found in the database.")
+
+        relationship_types_undirected_str = (
+            "{"
+            + ", ".join(f"{rel}" + ": {orientation: 'UNDIRECTED'}" for rel in relationship_types)
+            + "}"
+        )
+        return relationship_types_undirected_str
+
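The undirected relationship string built above can be reproduced without a database connection; a standalone sketch of the same join logic over hypothetical relationship types:

```python
# Replicates the string-building step above; the relationship types listed are
# hypothetical. The f-string form here is equivalent to the concatenation used
# in the adapter.
relationship_types = ["contains", "is_part_of"]

undirected = (
    "{"
    + ", ".join(f"{rel}: {{orientation: 'UNDIRECTED'}}" for rel in relationship_types)
    + "}"
)
print(undirected)
# {contains: {orientation: 'UNDIRECTED'}, is_part_of: {orientation: 'UNDIRECTED'}}
```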
+    async def drop_graph(self, graph_name="myGraph"):
+        """
+        Drop an existing graph from the database based on its name.
+
+        Note: This method is currently a placeholder because GDS (Graph Data Science)
+        projection is not supported in Neptune Analytics.
+
+        Parameters:
+        -----------
+
+        - graph_name: The name of the graph to drop, defaults to 'myGraph'. (default
+          'myGraph')
+        """
+        pass
+
+    async def graph_exists(self, graph_name="myGraph"):
+        """
+        Check if a graph with a given name exists in the database.
+
+        Note: This method is currently a placeholder because GDS (Graph Data Science)
+        projection is not supported in Neptune Analytics.
+
+        Parameters:
+        -----------
+
+        - graph_name: The name of the graph to check for existence, defaults to 'myGraph'.
+          (default 'myGraph')
+
+        Returns:
+        --------
+
+        True if the graph exists, otherwise False.
+        """
+        pass
+
+    async def project_entire_graph(self, graph_name="myGraph"):
+        """
+        Project all node labels and relationship types into an in-memory graph using GDS.
+
+        Note: This method is currently a placeholder because GDS (Graph Data Science)
+        projection is not supported in Neptune Analytics.
+        """
+        pass
+
+    async def get_filtered_graph_data(self, attribute_filters: list[dict[str, list]]):
+        """
+        Fetch nodes and edges filtered by specific attribute criteria.
+
+        Parameters:
+        -----------
+
+        - attribute_filters: A list of dictionaries representing attributes and associated
+          values for filtering.
+
+        Returns:
+        --------
+
+        A tuple containing filtered nodes and edges based on the specified criteria.
+        """
+        where_clauses_n = []
+        where_clauses_m = []
+        for attribute, values in attribute_filters[0].items():
+            values_str = ", ".join(
+                f"'{value}'" if isinstance(value, str) else str(value) for value in values
+            )
+            where_clauses_n.append(f"n.{attribute} IN [{values_str}]")
+            where_clauses_m.append(f"m.{attribute} IN [{values_str}]")
+
+        node_where_clauses_n_str = " AND ".join(where_clauses_n)
+        node_where_clauses_m_str = " AND ".join(where_clauses_m)
+        edge_where_clause = f"{node_where_clauses_n_str} AND {node_where_clauses_m_str}"
+
+        query_nodes = f"""
+        MATCH (n :{self._GRAPH_NODE_LABEL})
+        WHERE {node_where_clauses_n_str}
+        RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties
+        """
+        result_nodes = await self.query(query_nodes)
+
+        nodes = [
+            (
+                record["id"],
+                record["properties"],
+            )
+            for record in result_nodes
+        ]
+
+        query_edges = f"""
+        MATCH (n :{self._GRAPH_NODE_LABEL})-[r]->(m :{self._GRAPH_NODE_LABEL})
+        WHERE {edge_where_clause}
+        RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
+        """
+        result_edges = await self.query(query_edges)
+
+        edges = [
+            (
+                record["source"],
+                record["target"],
+                record["type"],
+                record["properties"],
+            )
+            for record in result_edges
+        ]
+
+        return (nodes, edges)
+
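Note that `get_filtered_graph_data` interpolates filter values into the query text rather than binding them as parameters, so values should be trusted, internally generated strings. A usage sketch under that assumption:

```python
# Hypothetical sketch: restrict the graph to nodes whose `type` property is one
# of the listed values; only the first filter dict is consumed by the adapter.
async def show_filtered(adapter) -> None:
    attribute_filters = [{"type": ["Entity", "EntityType"]}]
    nodes, edges = await adapter.get_filtered_graph_data(attribute_filters)
    print(f"{len(nodes)} nodes and {len(edges)} edges match the filter")
```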
+    async def get_degree_one_nodes(self, node_type: str):
+        """
+        Fetch nodes of a specified type that have exactly one connection.
+
+        Parameters:
+        -----------
+
+        - node_type (str): The type of nodes to retrieve, must be 'Entity' or 'EntityType'.
+
+        Returns:
+        --------
+
+        A list of nodes with exactly one connection of the specified type.
+        """
+        if not node_type or node_type not in ["Entity", "EntityType"]:
+            raise ValueError("node_type must be either 'Entity' or 'EntityType'")
+
+        query = f"""
+        MATCH (n :{self._GRAPH_NODE_LABEL})
+        WHERE size((n)--()) = 1
+            AND n.type = $node_type
+        RETURN n
+        """
+        result = await self.query(query, {"node_type": node_type})
+        return [record["n"] for record in result] if result else []
+
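A short sketch of `get_degree_one_nodes`, assuming an initialized adapter:

```python
# Hypothetical sketch: list leaf entities, i.e. nodes with exactly one edge.
async def show_leaves(adapter) -> None:
    for node in await adapter.get_degree_one_nodes("Entity"):
        print(node)
```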
+    async def get_document_subgraph(self, data_id: str):
+        """
+        Retrieve a subgraph related to a document identified by its content hash, including
+        related entities and chunks.
+
+        Parameters:
+        -----------
+
+        - data_id (str): The document_id identifying the document whose subgraph should be
+          retrieved.
+
+        Returns:
+        --------
+
+        The subgraph data as a dictionary, or None if not found.
+        """
+        query = f"""
+        MATCH (doc)
+        WHERE (doc:{self._GRAPH_NODE_LABEL})
+            AND doc.type in ['TextDocument', 'PdfDocument']
+            AND doc.id = $data_id
+
+        OPTIONAL MATCH (doc)<-[:is_part_of]-(chunk {{type: 'DocumentChunk'}})
+
+        // Alternative to WHERE NOT EXISTS
+        OPTIONAL MATCH (chunk)-[:contains]->(entity {{type: 'Entity'}})
+        OPTIONAL MATCH (entity)<-[:contains]-(otherChunk {{type: 'DocumentChunk'}})-[:is_part_of]->(otherDoc)
+        WHERE otherDoc.type in ['TextDocument', 'PdfDocument']
+            AND otherDoc.id <> doc.id
+
+        OPTIONAL MATCH (chunk)<-[:made_from]-(made_node {{type: 'TextSummary'}})
+
+        // Alternative to WHERE NOT EXISTS
+        OPTIONAL MATCH (entity)-[:is_a]->(type {{type: 'EntityType'}})
+        OPTIONAL MATCH (type)<-[:is_a]-(otherEntity {{type: 'Entity'}})<-[:contains]-(otherChunk {{type: 'DocumentChunk'}})-[:is_part_of]->(otherDoc)
+        WHERE otherDoc.type in ['TextDocument', 'PdfDocument']
+            AND otherDoc.id <> doc.id
+
+        // Alternative to WHERE NOT EXISTS
+        WITH doc, entity, chunk, made_node, type, otherDoc
+        WHERE otherDoc IS NULL
+
+        RETURN
+            collect(DISTINCT doc) as document,
+            collect(DISTINCT chunk) as chunks,
+            collect(DISTINCT entity) as orphan_entities,
+            collect(DISTINCT made_node) as made_from_nodes,
+            collect(DISTINCT type) as orphan_types
+        """
+        result = await self.query(query, {"data_id": data_id})
+        return result[0] if result else None
+
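A sketch of how the subgraph record might be inspected, for example before a delete; the `data_id` value is a placeholder:

```python
# Hypothetical sketch: report what would be orphaned by removing a document.
async def show_document_subgraph(adapter, data_id: str) -> None:
    subgraph = await adapter.get_document_subgraph(data_id)
    if subgraph is None:
        print("document not found")
        return
    print(len(subgraph["chunks"]), "chunks,",
          len(subgraph["orphan_entities"]), "orphan entities")
```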
+    async def _get_model_independent_graph_data(self):
+        """
+        Retrieve the basic graph data without considering the model specifics, returning node
+        and edge counts.
+
+        Returns:
+        --------
+
+        A tuple of node and edge counts.
+        """
+        query_string = f"""
+        MATCH (n :{self._GRAPH_NODE_LABEL})
+        WITH count(n) AS nodeCount
+        MATCH (a :{self._GRAPH_NODE_LABEL})-[r]->(b :{self._GRAPH_NODE_LABEL})
+        RETURN nodeCount AS numVertices, count(r) AS numEdges
+        """
+        query_response = await self.query(query_string)
+        num_nodes = query_response[0].get("numVertices")
+        num_edges = query_response[0].get("numEdges")
+
+        return (num_nodes, num_edges)
+
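An internal-helper sketch; `_get_model_independent_graph_data` is private, so this only illustrates the shape of its return value:

```python
# Hypothetical sketch: the helper returns a (num_nodes, num_edges) pair.
async def show_counts(adapter) -> None:
    num_nodes, num_edges = await adapter._get_model_independent_graph_data()
    print(f"{num_nodes} vertices, {num_edges} edges")
```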
+    async def _get_connected_components_stat(self):
+        """
+        Retrieve statistics about connected components in the graph.
+
+        This method analyzes the graph to find all connected components
+        and returns both the sizes of each component and the total number of components.
+
+        Returns:
+        --------
+        tuple[list[int], int]
+            A tuple containing:
+            - A list of sizes for each connected component (descending order).
+            - The total number of connected components.
+            Returns ([], 0) if no connected components are found.
+        """
+        query = f"""
+        MATCH (n :{self._GRAPH_NODE_LABEL})
+        CALL neptune.algo.wcc(n, {{}})
+        YIELD node, component
+        RETURN component, count(*) AS size
+        ORDER BY size DESC
+        """
+
+        result = await self.query(query)
+        size_connected_components = [record["size"] for record in result] if result else []
+        num_connected_components = len(result)
+
+        return (size_connected_components, num_connected_components)
+
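Likewise for the WCC helper; a sketch showing how the (sizes, count) pair could be summarized:

```python
# Hypothetical sketch: sizes are sorted descending, so sizes[0] is the largest
# component when any exist.
async def show_connectivity(adapter) -> None:
    sizes, count = await adapter._get_connected_components_stat()
    if count:
        print(f"{count} components; largest holds {sizes[0]} nodes")
    else:
        print("graph is empty")
```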
+    async def _count_self_loops(self):
+        """
+        Count the number of self-loop relationships in the Neptune Analytics graph backend.
+
+        This function executes an openCypher query to find and count all edge relationships that
+        begin and end at the same node (self-loops). It returns the count of such relationships,
+        or 0 if no results are found.
+
+        Returns:
+        --------
+
+        The count of self-loop relationships found in the database, or 0 if none were found.
+        """
+        query = f"""
+        MATCH (n :{self._GRAPH_NODE_LABEL})-[r]->(n :{self._GRAPH_NODE_LABEL})
+        RETURN count(r) AS adapter_loop_count;
+        """
+        result = await self.query(query)
+        return result[0]["adapter_loop_count"] if result else 0
+
+    @staticmethod
+    def _convert_relationship_to_edge(relationship: dict) -> EdgeData:
+        return (
+            relationship["source_id"],
+            relationship["target_id"],
+            relationship["relationship_name"],
+            relationship["properties"],
+        )
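For reference, the tuple shape `_convert_relationship_to_edge` produces, demonstrated on a hand-built relationship dict (all values illustrative):

```python
# Standalone illustration of the EdgeData tuple layout produced above.
relationship = {
    "source_id": "node-a",
    "target_id": "node-b",
    "relationship_name": "contains",
    "properties": {"weight": 1},
}
edge = (
    relationship["source_id"],
    relationship["target_id"],
    relationship["relationship_name"],
    relationship["properties"],
)
print(edge)  # ('node-a', 'node-b', 'contains', {'weight': 1})
```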