graphiti-core 0.12.0rc5__py3-none-any.whl → 0.12.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
- graphiti_core/driver/__init__.py +17 -0
- graphiti_core/driver/driver.py +66 -0
- graphiti_core/driver/falkordb_driver.py +131 -0
- graphiti_core/driver/neo4j_driver.py +61 -0
- graphiti_core/edges.py +26 -26
- graphiti_core/embedder/azure_openai.py +64 -0
- graphiti_core/graph_queries.py +149 -0
- graphiti_core/graphiti.py +8 -4
- graphiti_core/graphiti_types.py +2 -2
- graphiti_core/helpers.py +9 -3
- graphiti_core/llm_client/__init__.py +16 -0
- graphiti_core/llm_client/azure_openai_client.py +73 -0
- graphiti_core/nodes.py +31 -31
- graphiti_core/search/search.py +6 -10
- graphiti_core/search/search_utils.py +243 -187
- graphiti_core/utils/bulk_utils.py +20 -11
- graphiti_core/utils/maintenance/community_operations.py +6 -7
- graphiti_core/utils/maintenance/edge_operations.py +0 -1
- graphiti_core/utils/maintenance/graph_data_operations.py +13 -42
- graphiti_core/utils/maintenance/node_operations.py +0 -1
- {graphiti_core-0.12.0rc5.dist-info → graphiti_core-0.12.1.dist-info}/METADATA +4 -3
- {graphiti_core-0.12.0rc5.dist-info → graphiti_core-0.12.1.dist-info}/RECORD +25 -18
- {graphiti_core-0.12.0rc5.dist-info → graphiti_core-0.12.1.dist-info}/LICENSE +0 -0
- {graphiti_core-0.12.0rc5.dist-info → graphiti_core-0.12.1.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Database query utilities for different graph database backends.
|
|
3
|
+
|
|
4
|
+
This module provides database-agnostic query generation for Neo4j and FalkorDB,
|
|
5
|
+
supporting index creation, fulltext search, and bulk operations.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
from typing_extensions import LiteralString
|
|
11
|
+
|
|
12
|
+
from graphiti_core.models.edges.edge_db_queries import (
|
|
13
|
+
ENTITY_EDGE_SAVE_BULK,
|
|
14
|
+
)
|
|
15
|
+
from graphiti_core.models.nodes.node_db_queries import (
|
|
16
|
+
ENTITY_NODE_SAVE_BULK,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
# Mapping from Neo4j fulltext index names to FalkorDB node labels
|
|
20
|
+
NEO4J_TO_FALKORDB_MAPPING = {
|
|
21
|
+
'node_name_and_summary': 'Entity',
|
|
22
|
+
'community_name': 'Community',
|
|
23
|
+
'episode_content': 'Episodic',
|
|
24
|
+
'edge_name_and_fact': 'RELATES_TO',
|
|
25
|
+
}
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def get_range_indices(db_type: str = 'neo4j') -> list[LiteralString]:
|
|
29
|
+
if db_type == 'falkordb':
|
|
30
|
+
return [
|
|
31
|
+
# Entity node
|
|
32
|
+
'CREATE INDEX FOR (n:Entity) ON (n.uuid, n.group_id, n.name, n.created_at)',
|
|
33
|
+
# Episodic node
|
|
34
|
+
'CREATE INDEX FOR (n:Episodic) ON (n.uuid, n.group_id, n.created_at, n.valid_at)',
|
|
35
|
+
# Community node
|
|
36
|
+
'CREATE INDEX FOR (n:Community) ON (n.uuid)',
|
|
37
|
+
# RELATES_TO edge
|
|
38
|
+
'CREATE INDEX FOR ()-[e:RELATES_TO]-() ON (e.uuid, e.group_id, e.name, e.created_at, e.expired_at, e.valid_at, e.invalid_at)',
|
|
39
|
+
# MENTIONS edge
|
|
40
|
+
'CREATE INDEX FOR ()-[e:MENTIONS]-() ON (e.uuid, e.group_id)',
|
|
41
|
+
# HAS_MEMBER edge
|
|
42
|
+
'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
|
|
43
|
+
]
|
|
44
|
+
else:
|
|
45
|
+
return [
|
|
46
|
+
'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
|
|
47
|
+
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
|
|
48
|
+
'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
|
|
49
|
+
'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
|
|
50
|
+
'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
|
|
51
|
+
'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
|
|
52
|
+
'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
|
|
53
|
+
'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
|
|
54
|
+
'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
|
|
55
|
+
'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
|
|
56
|
+
'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
|
|
57
|
+
'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
|
|
58
|
+
'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
|
|
59
|
+
'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
|
|
60
|
+
'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
|
|
61
|
+
'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
|
|
62
|
+
'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
|
|
63
|
+
'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
|
|
64
|
+
'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
|
|
65
|
+
]
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def get_fulltext_indices(db_type: str = 'neo4j') -> list[LiteralString]:
|
|
69
|
+
if db_type == 'falkordb':
|
|
70
|
+
return [
|
|
71
|
+
"""CREATE FULLTEXT INDEX FOR (e:Episodic) ON (e.content, e.source, e.source_description, e.group_id)""",
|
|
72
|
+
"""CREATE FULLTEXT INDEX FOR (n:Entity) ON (n.name, n.summary, n.group_id)""",
|
|
73
|
+
"""CREATE FULLTEXT INDEX FOR (n:Community) ON (n.name, n.group_id)""",
|
|
74
|
+
"""CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
|
|
75
|
+
]
|
|
76
|
+
else:
|
|
77
|
+
return [
|
|
78
|
+
"""CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
|
|
79
|
+
FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
|
|
80
|
+
"""CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
|
|
81
|
+
FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
|
|
82
|
+
"""CREATE FULLTEXT INDEX community_name IF NOT EXISTS
|
|
83
|
+
FOR (n:Community) ON EACH [n.name, n.group_id]""",
|
|
84
|
+
"""CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
|
|
85
|
+
FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
|
|
86
|
+
]
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def get_nodes_query(db_type: str = 'neo4j', name: str = '', query: str | None = None) -> str:
|
|
90
|
+
if db_type == 'falkordb':
|
|
91
|
+
label = NEO4J_TO_FALKORDB_MAPPING[name]
|
|
92
|
+
return f"CALL db.idx.fulltext.queryNodes('{label}', {query})"
|
|
93
|
+
else:
|
|
94
|
+
return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def get_vector_cosine_func_query(vec1, vec2, db_type: str = 'neo4j') -> str:
|
|
98
|
+
if db_type == 'falkordb':
|
|
99
|
+
# FalkorDB uses a different syntax for regular cosine similarity and Neo4j uses normalized cosine similarity
|
|
100
|
+
return f'(2 - vec.cosineDistance({vec1}, vecf32({vec2})))/2'
|
|
101
|
+
else:
|
|
102
|
+
return f'vector.similarity.cosine({vec1}, {vec2})'
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def get_relationships_query(name: str, db_type: str = 'neo4j') -> str:
|
|
106
|
+
if db_type == 'falkordb':
|
|
107
|
+
label = NEO4J_TO_FALKORDB_MAPPING[name]
|
|
108
|
+
return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)"
|
|
109
|
+
else:
|
|
110
|
+
return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def get_entity_node_save_bulk_query(nodes, db_type: str = 'neo4j') -> str | Any:
|
|
114
|
+
if db_type == 'falkordb':
|
|
115
|
+
queries = []
|
|
116
|
+
for node in nodes:
|
|
117
|
+
for label in node['labels']:
|
|
118
|
+
queries.append(
|
|
119
|
+
(
|
|
120
|
+
f"""
|
|
121
|
+
UNWIND $nodes AS node
|
|
122
|
+
MERGE (n:Entity {{uuid: node.uuid}})
|
|
123
|
+
SET n:{label}
|
|
124
|
+
SET n = node
|
|
125
|
+
WITH n, node
|
|
126
|
+
SET n.name_embedding = vecf32(node.name_embedding)
|
|
127
|
+
RETURN n.uuid AS uuid
|
|
128
|
+
""",
|
|
129
|
+
{'nodes': [node]},
|
|
130
|
+
)
|
|
131
|
+
)
|
|
132
|
+
return queries
|
|
133
|
+
else:
|
|
134
|
+
return ENTITY_NODE_SAVE_BULK
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def get_entity_edge_save_bulk_query(db_type: str = 'neo4j') -> str:
|
|
138
|
+
if db_type == 'falkordb':
|
|
139
|
+
return """
|
|
140
|
+
UNWIND $entity_edges AS edge
|
|
141
|
+
MATCH (source:Entity {uuid: edge.source_node_uuid})
|
|
142
|
+
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
|
143
|
+
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
|
|
144
|
+
SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
|
|
145
|
+
created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at, fact_embedding: vecf32(edge.fact_embedding)}
|
|
146
|
+
WITH r, edge
|
|
147
|
+
RETURN edge.uuid AS uuid"""
|
|
148
|
+
else:
|
|
149
|
+
return ENTITY_EDGE_SAVE_BULK
|
graphiti_core/graphiti.py
CHANGED
|
@@ -19,12 +19,13 @@ from datetime import datetime
|
|
|
19
19
|
from time import time
|
|
20
20
|
|
|
21
21
|
from dotenv import load_dotenv
|
|
22
|
-
from neo4j import AsyncGraphDatabase
|
|
23
22
|
from pydantic import BaseModel
|
|
24
23
|
from typing_extensions import LiteralString
|
|
25
24
|
|
|
26
25
|
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
|
27
26
|
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
|
27
|
+
from graphiti_core.driver.driver import GraphDriver
|
|
28
|
+
from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
|
28
29
|
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
|
29
30
|
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
|
30
31
|
from graphiti_core.graphiti_types import GraphitiClients
|
|
@@ -94,12 +95,13 @@ class Graphiti:
|
|
|
94
95
|
def __init__(
|
|
95
96
|
self,
|
|
96
97
|
uri: str,
|
|
97
|
-
user: str,
|
|
98
|
-
password: str,
|
|
98
|
+
user: str | None = None,
|
|
99
|
+
password: str | None = None,
|
|
99
100
|
llm_client: LLMClient | None = None,
|
|
100
101
|
embedder: EmbedderClient | None = None,
|
|
101
102
|
cross_encoder: CrossEncoderClient | None = None,
|
|
102
103
|
store_raw_episode_content: bool = True,
|
|
104
|
+
graph_driver: GraphDriver | None = None,
|
|
103
105
|
):
|
|
104
106
|
"""
|
|
105
107
|
Initialize a Graphiti instance.
|
|
@@ -137,7 +139,9 @@ class Graphiti:
|
|
|
137
139
|
Make sure to set the OPENAI_API_KEY environment variable before initializing
|
|
138
140
|
Graphiti if you're using the default OpenAIClient.
|
|
139
141
|
"""
|
|
140
|
-
|
|
142
|
+
|
|
143
|
+
self.driver = graph_driver if graph_driver else Neo4jDriver(uri, user, password)
|
|
144
|
+
|
|
141
145
|
self.database = DEFAULT_DATABASE
|
|
142
146
|
self.store_raw_episode_content = store_raw_episode_content
|
|
143
147
|
if llm_client:
|
graphiti_core/graphiti_types.py
CHANGED
|
@@ -14,16 +14,16 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
-
from neo4j import AsyncDriver
|
|
18
17
|
from pydantic import BaseModel, ConfigDict
|
|
19
18
|
|
|
20
19
|
from graphiti_core.cross_encoder import CrossEncoderClient
|
|
20
|
+
from graphiti_core.driver.driver import GraphDriver
|
|
21
21
|
from graphiti_core.embedder import EmbedderClient
|
|
22
22
|
from graphiti_core.llm_client import LLMClient
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
class GraphitiClients(BaseModel):
|
|
26
|
-
driver:
|
|
26
|
+
driver: GraphDriver
|
|
27
27
|
llm_client: LLMClient
|
|
28
28
|
embedder: EmbedderClient
|
|
29
29
|
cross_encoder: CrossEncoderClient
|
graphiti_core/helpers.py
CHANGED
|
@@ -27,7 +27,7 @@ from typing_extensions import LiteralString
|
|
|
27
27
|
|
|
28
28
|
load_dotenv()
|
|
29
29
|
|
|
30
|
-
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE',
|
|
30
|
+
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', 'neo4j')
|
|
31
31
|
USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
|
|
32
32
|
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
|
|
33
33
|
MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0))
|
|
@@ -38,8 +38,14 @@ RUNTIME_QUERY: LiteralString = (
|
|
|
38
38
|
)
|
|
39
39
|
|
|
40
40
|
|
|
41
|
-
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
|
|
42
|
-
return
|
|
41
|
+
def parse_db_date(neo_date: neo4j_time.DateTime | str | None) -> datetime | None:
|
|
42
|
+
return (
|
|
43
|
+
neo_date.to_native()
|
|
44
|
+
if isinstance(neo_date, neo4j_time.DateTime)
|
|
45
|
+
else datetime.fromisoformat(neo_date)
|
|
46
|
+
if neo_date
|
|
47
|
+
else None
|
|
48
|
+
)
|
|
43
49
|
|
|
44
50
|
|
|
45
51
|
def lucene_sanitize(query: str) -> str:
|
|
@@ -1,3 +1,19 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
1
17
|
from .client import LLMClient
|
|
2
18
|
from .config import LLMConfig
|
|
3
19
|
from .errors import RateLimitError
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import json
|
|
18
|
+
import logging
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
from openai import AsyncAzureOpenAI
|
|
22
|
+
from openai.types.chat import ChatCompletionMessageParam
|
|
23
|
+
from pydantic import BaseModel
|
|
24
|
+
|
|
25
|
+
from ..prompts.models import Message
|
|
26
|
+
from .client import LLMClient
|
|
27
|
+
from .config import LLMConfig, ModelSize
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class AzureOpenAILLMClient(LLMClient):
|
|
33
|
+
"""Wrapper class for AsyncAzureOpenAI that implements the LLMClient interface."""
|
|
34
|
+
|
|
35
|
+
def __init__(self, azure_client: AsyncAzureOpenAI, config: LLMConfig | None = None):
|
|
36
|
+
super().__init__(config, cache=False)
|
|
37
|
+
self.azure_client = azure_client
|
|
38
|
+
|
|
39
|
+
async def _generate_response(
|
|
40
|
+
self,
|
|
41
|
+
messages: list[Message],
|
|
42
|
+
response_model: type[BaseModel] | None = None,
|
|
43
|
+
max_tokens: int = 1024,
|
|
44
|
+
model_size: ModelSize = ModelSize.medium,
|
|
45
|
+
) -> dict[str, Any]:
|
|
46
|
+
"""Generate response using Azure OpenAI client."""
|
|
47
|
+
# Convert messages to OpenAI format
|
|
48
|
+
openai_messages: list[ChatCompletionMessageParam] = []
|
|
49
|
+
for message in messages:
|
|
50
|
+
message.content = self._clean_input(message.content)
|
|
51
|
+
if message.role == 'user':
|
|
52
|
+
openai_messages.append({'role': 'user', 'content': message.content})
|
|
53
|
+
elif message.role == 'system':
|
|
54
|
+
openai_messages.append({'role': 'system', 'content': message.content})
|
|
55
|
+
|
|
56
|
+
# Ensure model is a string
|
|
57
|
+
model_name = self.model if self.model else 'gpt-4o-mini'
|
|
58
|
+
|
|
59
|
+
try:
|
|
60
|
+
response = await self.azure_client.chat.completions.create(
|
|
61
|
+
model=model_name,
|
|
62
|
+
messages=openai_messages,
|
|
63
|
+
temperature=float(self.temperature) if self.temperature is not None else 0.7,
|
|
64
|
+
max_tokens=max_tokens,
|
|
65
|
+
response_format={'type': 'json_object'},
|
|
66
|
+
)
|
|
67
|
+
result = response.choices[0].message.content or '{}'
|
|
68
|
+
|
|
69
|
+
# Parse JSON response
|
|
70
|
+
return json.loads(result)
|
|
71
|
+
except Exception as e:
|
|
72
|
+
logger.error(f'Error in Azure OpenAI LLM response: {e}')
|
|
73
|
+
raise
|
graphiti_core/nodes.py
CHANGED
|
@@ -22,13 +22,13 @@ from time import time
|
|
|
22
22
|
from typing import Any
|
|
23
23
|
from uuid import uuid4
|
|
24
24
|
|
|
25
|
-
from neo4j import AsyncDriver
|
|
26
25
|
from pydantic import BaseModel, Field
|
|
27
26
|
from typing_extensions import LiteralString
|
|
28
27
|
|
|
28
|
+
from graphiti_core.driver.driver import GraphDriver
|
|
29
29
|
from graphiti_core.embedder import EmbedderClient
|
|
30
30
|
from graphiti_core.errors import NodeNotFoundError
|
|
31
|
-
from graphiti_core.helpers import DEFAULT_DATABASE
|
|
31
|
+
from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
|
|
32
32
|
from graphiti_core.models.nodes.node_db_queries import (
|
|
33
33
|
COMMUNITY_NODE_SAVE,
|
|
34
34
|
ENTITY_NODE_SAVE,
|
|
@@ -94,9 +94,9 @@ class Node(BaseModel, ABC):
|
|
|
94
94
|
created_at: datetime = Field(default_factory=lambda: utc_now())
|
|
95
95
|
|
|
96
96
|
@abstractmethod
|
|
97
|
-
async def save(self, driver:
|
|
97
|
+
async def save(self, driver: GraphDriver): ...
|
|
98
98
|
|
|
99
|
-
async def delete(self, driver:
|
|
99
|
+
async def delete(self, driver: GraphDriver):
|
|
100
100
|
result = await driver.execute_query(
|
|
101
101
|
"""
|
|
102
102
|
MATCH (n:Entity|Episodic|Community {uuid: $uuid})
|
|
@@ -119,7 +119,7 @@ class Node(BaseModel, ABC):
|
|
|
119
119
|
return False
|
|
120
120
|
|
|
121
121
|
@classmethod
|
|
122
|
-
async def delete_by_group_id(cls, driver:
|
|
122
|
+
async def delete_by_group_id(cls, driver: GraphDriver, group_id: str):
|
|
123
123
|
await driver.execute_query(
|
|
124
124
|
"""
|
|
125
125
|
MATCH (n:Entity|Episodic|Community {group_id: $group_id})
|
|
@@ -132,10 +132,10 @@ class Node(BaseModel, ABC):
|
|
|
132
132
|
return 'SUCCESS'
|
|
133
133
|
|
|
134
134
|
@classmethod
|
|
135
|
-
async def get_by_uuid(cls, driver:
|
|
135
|
+
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
|
|
136
136
|
|
|
137
137
|
@classmethod
|
|
138
|
-
async def get_by_uuids(cls, driver:
|
|
138
|
+
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): ...
|
|
139
139
|
|
|
140
140
|
|
|
141
141
|
class EpisodicNode(Node):
|
|
@@ -150,7 +150,7 @@ class EpisodicNode(Node):
|
|
|
150
150
|
default_factory=list,
|
|
151
151
|
)
|
|
152
152
|
|
|
153
|
-
async def save(self, driver:
|
|
153
|
+
async def save(self, driver: GraphDriver):
|
|
154
154
|
result = await driver.execute_query(
|
|
155
155
|
EPISODIC_NODE_SAVE,
|
|
156
156
|
uuid=self.uuid,
|
|
@@ -165,12 +165,12 @@ class EpisodicNode(Node):
|
|
|
165
165
|
database_=DEFAULT_DATABASE,
|
|
166
166
|
)
|
|
167
167
|
|
|
168
|
-
logger.debug(f'Saved Node to
|
|
168
|
+
logger.debug(f'Saved Node to Graph: {self.uuid}')
|
|
169
169
|
|
|
170
170
|
return result
|
|
171
171
|
|
|
172
172
|
@classmethod
|
|
173
|
-
async def get_by_uuid(cls, driver:
|
|
173
|
+
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
174
174
|
records, _, _ = await driver.execute_query(
|
|
175
175
|
"""
|
|
176
176
|
MATCH (e:Episodic {uuid: $uuid})
|
|
@@ -197,7 +197,7 @@ class EpisodicNode(Node):
|
|
|
197
197
|
return episodes[0]
|
|
198
198
|
|
|
199
199
|
@classmethod
|
|
200
|
-
async def get_by_uuids(cls, driver:
|
|
200
|
+
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
201
201
|
records, _, _ = await driver.execute_query(
|
|
202
202
|
"""
|
|
203
203
|
MATCH (e:Episodic) WHERE e.uuid IN $uuids
|
|
@@ -224,7 +224,7 @@ class EpisodicNode(Node):
|
|
|
224
224
|
@classmethod
|
|
225
225
|
async def get_by_group_ids(
|
|
226
226
|
cls,
|
|
227
|
-
driver:
|
|
227
|
+
driver: GraphDriver,
|
|
228
228
|
group_ids: list[str],
|
|
229
229
|
limit: int | None = None,
|
|
230
230
|
uuid_cursor: str | None = None,
|
|
@@ -263,7 +263,7 @@ class EpisodicNode(Node):
|
|
|
263
263
|
return episodes
|
|
264
264
|
|
|
265
265
|
@classmethod
|
|
266
|
-
async def get_by_entity_node_uuid(cls, driver:
|
|
266
|
+
async def get_by_entity_node_uuid(cls, driver: GraphDriver, entity_node_uuid: str):
|
|
267
267
|
records, _, _ = await driver.execute_query(
|
|
268
268
|
"""
|
|
269
269
|
MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
|
|
@@ -304,7 +304,7 @@ class EntityNode(Node):
|
|
|
304
304
|
|
|
305
305
|
return self.name_embedding
|
|
306
306
|
|
|
307
|
-
async def load_name_embedding(self, driver:
|
|
307
|
+
async def load_name_embedding(self, driver: GraphDriver):
|
|
308
308
|
query: LiteralString = """
|
|
309
309
|
MATCH (n:Entity {uuid: $uuid})
|
|
310
310
|
RETURN n.name_embedding AS name_embedding
|
|
@@ -318,7 +318,7 @@ class EntityNode(Node):
|
|
|
318
318
|
|
|
319
319
|
self.name_embedding = records[0]['name_embedding']
|
|
320
320
|
|
|
321
|
-
async def save(self, driver:
|
|
321
|
+
async def save(self, driver: GraphDriver):
|
|
322
322
|
entity_data: dict[str, Any] = {
|
|
323
323
|
'uuid': self.uuid,
|
|
324
324
|
'name': self.name,
|
|
@@ -337,16 +337,16 @@ class EntityNode(Node):
|
|
|
337
337
|
database_=DEFAULT_DATABASE,
|
|
338
338
|
)
|
|
339
339
|
|
|
340
|
-
logger.debug(f'Saved Node to
|
|
340
|
+
logger.debug(f'Saved Node to Graph: {self.uuid}')
|
|
341
341
|
|
|
342
342
|
return result
|
|
343
343
|
|
|
344
344
|
@classmethod
|
|
345
|
-
async def get_by_uuid(cls, driver:
|
|
345
|
+
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
346
346
|
query = (
|
|
347
347
|
"""
|
|
348
|
-
|
|
349
|
-
|
|
348
|
+
MATCH (n:Entity {uuid: $uuid})
|
|
349
|
+
"""
|
|
350
350
|
+ ENTITY_NODE_RETURN
|
|
351
351
|
)
|
|
352
352
|
records, _, _ = await driver.execute_query(
|
|
@@ -364,7 +364,7 @@ class EntityNode(Node):
|
|
|
364
364
|
return nodes[0]
|
|
365
365
|
|
|
366
366
|
@classmethod
|
|
367
|
-
async def get_by_uuids(cls, driver:
|
|
367
|
+
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
368
368
|
records, _, _ = await driver.execute_query(
|
|
369
369
|
"""
|
|
370
370
|
MATCH (n:Entity) WHERE n.uuid IN $uuids
|
|
@@ -382,7 +382,7 @@ class EntityNode(Node):
|
|
|
382
382
|
@classmethod
|
|
383
383
|
async def get_by_group_ids(
|
|
384
384
|
cls,
|
|
385
|
-
driver:
|
|
385
|
+
driver: GraphDriver,
|
|
386
386
|
group_ids: list[str],
|
|
387
387
|
limit: int | None = None,
|
|
388
388
|
uuid_cursor: str | None = None,
|
|
@@ -416,7 +416,7 @@ class CommunityNode(Node):
|
|
|
416
416
|
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
|
|
417
417
|
summary: str = Field(description='region summary of member nodes', default_factory=str)
|
|
418
418
|
|
|
419
|
-
async def save(self, driver:
|
|
419
|
+
async def save(self, driver: GraphDriver):
|
|
420
420
|
result = await driver.execute_query(
|
|
421
421
|
COMMUNITY_NODE_SAVE,
|
|
422
422
|
uuid=self.uuid,
|
|
@@ -428,7 +428,7 @@ class CommunityNode(Node):
|
|
|
428
428
|
database_=DEFAULT_DATABASE,
|
|
429
429
|
)
|
|
430
430
|
|
|
431
|
-
logger.debug(f'Saved Node to
|
|
431
|
+
logger.debug(f'Saved Node to Graph: {self.uuid}')
|
|
432
432
|
|
|
433
433
|
return result
|
|
434
434
|
|
|
@@ -441,7 +441,7 @@ class CommunityNode(Node):
|
|
|
441
441
|
|
|
442
442
|
return self.name_embedding
|
|
443
443
|
|
|
444
|
-
async def load_name_embedding(self, driver:
|
|
444
|
+
async def load_name_embedding(self, driver: GraphDriver):
|
|
445
445
|
query: LiteralString = """
|
|
446
446
|
MATCH (c:Community {uuid: $uuid})
|
|
447
447
|
RETURN c.name_embedding AS name_embedding
|
|
@@ -456,7 +456,7 @@ class CommunityNode(Node):
|
|
|
456
456
|
self.name_embedding = records[0]['name_embedding']
|
|
457
457
|
|
|
458
458
|
@classmethod
|
|
459
|
-
async def get_by_uuid(cls, driver:
|
|
459
|
+
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
460
460
|
records, _, _ = await driver.execute_query(
|
|
461
461
|
"""
|
|
462
462
|
MATCH (n:Community {uuid: $uuid})
|
|
@@ -480,7 +480,7 @@ class CommunityNode(Node):
|
|
|
480
480
|
return nodes[0]
|
|
481
481
|
|
|
482
482
|
@classmethod
|
|
483
|
-
async def get_by_uuids(cls, driver:
|
|
483
|
+
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
484
484
|
records, _, _ = await driver.execute_query(
|
|
485
485
|
"""
|
|
486
486
|
MATCH (n:Community) WHERE n.uuid IN $uuids
|
|
@@ -503,7 +503,7 @@ class CommunityNode(Node):
|
|
|
503
503
|
@classmethod
|
|
504
504
|
async def get_by_group_ids(
|
|
505
505
|
cls,
|
|
506
|
-
driver:
|
|
506
|
+
driver: GraphDriver,
|
|
507
507
|
group_ids: list[str],
|
|
508
508
|
limit: int | None = None,
|
|
509
509
|
uuid_cursor: str | None = None,
|
|
@@ -542,8 +542,8 @@ class CommunityNode(Node):
|
|
|
542
542
|
def get_episodic_node_from_record(record: Any) -> EpisodicNode:
|
|
543
543
|
return EpisodicNode(
|
|
544
544
|
content=record['content'],
|
|
545
|
-
created_at=record['created_at']
|
|
546
|
-
valid_at=(record['valid_at']
|
|
545
|
+
created_at=parse_db_date(record['created_at']), # type: ignore
|
|
546
|
+
valid_at=parse_db_date(record['valid_at']), # type: ignore
|
|
547
547
|
uuid=record['uuid'],
|
|
548
548
|
group_id=record['group_id'],
|
|
549
549
|
source=EpisodeType.from_str(record['source']),
|
|
@@ -559,7 +559,7 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
|
|
|
559
559
|
name=record['name'],
|
|
560
560
|
group_id=record['group_id'],
|
|
561
561
|
labels=record['labels'],
|
|
562
|
-
created_at=record['created_at']
|
|
562
|
+
created_at=parse_db_date(record['created_at']), # type: ignore
|
|
563
563
|
summary=record['summary'],
|
|
564
564
|
attributes=record['attributes'],
|
|
565
565
|
)
|
|
@@ -580,7 +580,7 @@ def get_community_node_from_record(record: Any) -> CommunityNode:
|
|
|
580
580
|
name=record['name'],
|
|
581
581
|
group_id=record['group_id'],
|
|
582
582
|
name_embedding=record['name_embedding'],
|
|
583
|
-
created_at=record['created_at']
|
|
583
|
+
created_at=parse_db_date(record['created_at']), # type: ignore
|
|
584
584
|
summary=record['summary'],
|
|
585
585
|
)
|
|
586
586
|
|
graphiti_core/search/search.py
CHANGED
|
@@ -18,9 +18,8 @@ import logging
|
|
|
18
18
|
from collections import defaultdict
|
|
19
19
|
from time import time
|
|
20
20
|
|
|
21
|
-
from neo4j import AsyncDriver
|
|
22
|
-
|
|
23
21
|
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
|
22
|
+
from graphiti_core.driver.driver import GraphDriver
|
|
24
23
|
from graphiti_core.edges import EntityEdge
|
|
25
24
|
from graphiti_core.errors import SearchRerankerError
|
|
26
25
|
from graphiti_core.graphiti_types import GraphitiClients
|
|
@@ -94,7 +93,7 @@ async def search(
|
|
|
94
93
|
)
|
|
95
94
|
|
|
96
95
|
# if group_ids is empty, set it to None
|
|
97
|
-
group_ids = group_ids if group_ids else None
|
|
96
|
+
group_ids = group_ids if group_ids and group_ids != [''] else None
|
|
98
97
|
edges, nodes, episodes, communities = await semaphore_gather(
|
|
99
98
|
edge_search(
|
|
100
99
|
driver,
|
|
@@ -160,7 +159,7 @@ async def search(
|
|
|
160
159
|
|
|
161
160
|
|
|
162
161
|
async def edge_search(
|
|
163
|
-
driver:
|
|
162
|
+
driver: GraphDriver,
|
|
164
163
|
cross_encoder: CrossEncoderClient,
|
|
165
164
|
query: str,
|
|
166
165
|
query_vector: list[float],
|
|
@@ -174,7 +173,6 @@ async def edge_search(
|
|
|
174
173
|
) -> list[EntityEdge]:
|
|
175
174
|
if config is None:
|
|
176
175
|
return []
|
|
177
|
-
|
|
178
176
|
search_results: list[list[EntityEdge]] = list(
|
|
179
177
|
await semaphore_gather(
|
|
180
178
|
*[
|
|
@@ -261,7 +259,7 @@ async def edge_search(
|
|
|
261
259
|
|
|
262
260
|
|
|
263
261
|
async def node_search(
|
|
264
|
-
driver:
|
|
262
|
+
driver: GraphDriver,
|
|
265
263
|
cross_encoder: CrossEncoderClient,
|
|
266
264
|
query: str,
|
|
267
265
|
query_vector: list[float],
|
|
@@ -275,7 +273,6 @@ async def node_search(
|
|
|
275
273
|
) -> list[EntityNode]:
|
|
276
274
|
if config is None:
|
|
277
275
|
return []
|
|
278
|
-
|
|
279
276
|
search_results: list[list[EntityNode]] = list(
|
|
280
277
|
await semaphore_gather(
|
|
281
278
|
*[
|
|
@@ -344,7 +341,7 @@ async def node_search(
|
|
|
344
341
|
|
|
345
342
|
|
|
346
343
|
async def episode_search(
|
|
347
|
-
driver:
|
|
344
|
+
driver: GraphDriver,
|
|
348
345
|
cross_encoder: CrossEncoderClient,
|
|
349
346
|
query: str,
|
|
350
347
|
_query_vector: list[float],
|
|
@@ -356,7 +353,6 @@ async def episode_search(
|
|
|
356
353
|
) -> list[EpisodicNode]:
|
|
357
354
|
if config is None:
|
|
358
355
|
return []
|
|
359
|
-
|
|
360
356
|
search_results: list[list[EpisodicNode]] = list(
|
|
361
357
|
await semaphore_gather(
|
|
362
358
|
*[
|
|
@@ -392,7 +388,7 @@ async def episode_search(
|
|
|
392
388
|
|
|
393
389
|
|
|
394
390
|
async def community_search(
|
|
395
|
-
driver:
|
|
391
|
+
driver: GraphDriver,
|
|
396
392
|
cross_encoder: CrossEncoderClient,
|
|
397
393
|
query: str,
|
|
398
394
|
query_vector: list[float],
|