graphiti-core 0.12.0rc5__py3-none-any.whl → 0.12.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
- graphiti_core/driver/__init__.py +17 -0
- graphiti_core/driver/driver.py +66 -0
- graphiti_core/driver/falkordb_driver.py +131 -0
- graphiti_core/driver/neo4j_driver.py +61 -0
- graphiti_core/edges.py +26 -26
- graphiti_core/embedder/azure_openai.py +64 -0
- graphiti_core/graph_queries.py +149 -0
- graphiti_core/graphiti.py +21 -8
- graphiti_core/graphiti_types.py +2 -2
- graphiti_core/helpers.py +9 -3
- graphiti_core/llm_client/__init__.py +16 -0
- graphiti_core/llm_client/azure_openai_client.py +73 -0
- graphiti_core/nodes.py +31 -31
- graphiti_core/prompts/dedupe_nodes.py +5 -1
- graphiti_core/prompts/extract_edges.py +2 -0
- graphiti_core/prompts/extract_nodes.py +2 -0
- graphiti_core/search/search.py +6 -10
- graphiti_core/search/search_utils.py +243 -187
- graphiti_core/utils/bulk_utils.py +21 -11
- graphiti_core/utils/maintenance/community_operations.py +6 -7
- graphiti_core/utils/maintenance/edge_operations.py +68 -3
- graphiti_core/utils/maintenance/graph_data_operations.py +13 -42
- graphiti_core/utils/maintenance/node_operations.py +19 -5
- {graphiti_core-0.12.0rc5.dist-info → graphiti_core-0.12.2.dist-info}/METADATA +4 -3
- {graphiti_core-0.12.0rc5.dist-info → graphiti_core-0.12.2.dist-info}/RECORD +28 -21
- {graphiti_core-0.12.0rc5.dist-info → graphiti_core-0.12.2.dist-info}/LICENSE +0 -0
- {graphiti_core-0.12.0rc5.dist-info → graphiti_core-0.12.2.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Database query utilities for different graph database backends.
|
|
3
|
+
|
|
4
|
+
This module provides database-agnostic query generation for Neo4j and FalkorDB,
|
|
5
|
+
supporting index creation, fulltext search, and bulk operations.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
from typing_extensions import LiteralString
|
|
11
|
+
|
|
12
|
+
from graphiti_core.models.edges.edge_db_queries import (
|
|
13
|
+
ENTITY_EDGE_SAVE_BULK,
|
|
14
|
+
)
|
|
15
|
+
from graphiti_core.models.nodes.node_db_queries import (
|
|
16
|
+
ENTITY_NODE_SAVE_BULK,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
# Mapping from Neo4j fulltext index names to FalkorDB node labels
|
|
20
|
+
NEO4J_TO_FALKORDB_MAPPING: dict[str, str] = {
    # Neo4j fulltext index over Entity nodes (name + summary).
    'node_name_and_summary': 'Entity',
    # Neo4j fulltext index over Community nodes.
    'community_name': 'Community',
    # Neo4j fulltext index over Episodic nodes.
    'episode_content': 'Episodic',
    # Neo4j fulltext index over RELATES_TO relationships (name + fact).
    'edge_name_and_fact': 'RELATES_TO',
}
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def get_range_indices(db_type: str = 'neo4j') -> list[LiteralString]:
    """Return the index-creation statements for the property (range) indices.

    Args:
        db_type: Backend selector. ``'falkordb'`` yields FalkorDB statements
            (one composite statement per label / relationship type); any other
            value yields the Neo4j statements (one named, idempotent
            ``IF NOT EXISTS`` statement per property).

    Returns:
        The statements, in the order they are listed below.
    """
    if db_type != 'falkordb':
        return [
            # uuid lookups
            'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
            'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
            'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
            'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
            'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
            'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
            # group_id lookups
            'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
            'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
            'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
            'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
            # node name / timestamp properties
            'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
            'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
            'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
            'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
            # RELATES_TO name / timestamp properties
            'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
            'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
            'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
            'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
            'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
        ]

    # FalkorDB: unnamed indices, several properties per statement.
    return [
        # Entity node
        'CREATE INDEX FOR (n:Entity) ON (n.uuid, n.group_id, n.name, n.created_at)',
        # Episodic node
        'CREATE INDEX FOR (n:Episodic) ON (n.uuid, n.group_id, n.created_at, n.valid_at)',
        # Community node
        'CREATE INDEX FOR (n:Community) ON (n.uuid)',
        # RELATES_TO edge
        'CREATE INDEX FOR ()-[e:RELATES_TO]-() ON (e.uuid, e.group_id, e.name, e.created_at, e.expired_at, e.valid_at, e.invalid_at)',
        # MENTIONS edge
        'CREATE INDEX FOR ()-[e:MENTIONS]-() ON (e.uuid, e.group_id)',
        # HAS_MEMBER edge
        'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
    ]
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def get_fulltext_indices(db_type: str = 'neo4j') -> list[LiteralString]:
    """Return the statements that create the fulltext indices.

    Args:
        db_type: Backend selector. ``'falkordb'`` yields unnamed FalkorDB
            fulltext indices; any other value yields the named Neo4j indices
            (``episode_content``, ``node_name_and_summary``,
            ``community_name``, ``edge_name_and_fact``).

    Returns:
        Four statements covering Episodic, Entity and Community nodes and
        RELATES_TO relationships, in that order.
    """
    if db_type != 'falkordb':
        # Each Neo4j statement spans two source lines; the embedded newline is
        # part of the string.
        return [
            """CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
            FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
            """CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
            FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
            """CREATE FULLTEXT INDEX community_name IF NOT EXISTS
            FOR (n:Community) ON EACH [n.name, n.group_id]""",
            """CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
            FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
        ]

    return [
        """CREATE FULLTEXT INDEX FOR (e:Episodic) ON (e.content, e.source, e.source_description, e.group_id)""",
        """CREATE FULLTEXT INDEX FOR (n:Entity) ON (n.name, n.summary, n.group_id)""",
        """CREATE FULLTEXT INDEX FOR (n:Community) ON (n.name, n.group_id)""",
        """CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
    ]
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def get_nodes_query(db_type: str = 'neo4j', name: str = '', query: str | None = None) -> str:
|
|
90
|
+
if db_type == 'falkordb':
|
|
91
|
+
label = NEO4J_TO_FALKORDB_MAPPING[name]
|
|
92
|
+
return f"CALL db.idx.fulltext.queryNodes('{label}', {query})"
|
|
93
|
+
else:
|
|
94
|
+
return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def get_vector_cosine_func_query(vec1, vec2, db_type: str = 'neo4j') -> str:
    """Build the Cypher expression that scores two vectors by cosine similarity.

    Args:
        vec1: Cypher expression for the first vector, inserted verbatim.
        vec2: Cypher expression for the second vector, inserted verbatim; the
            FalkorDB variant wraps it in ``vecf32(...)``.
        db_type: Backend selector; ``'falkordb'`` or anything else for Neo4j.

    Returns:
        The similarity expression as a string.
    """
    if db_type != 'falkordb':
        return f'vector.similarity.cosine({vec1}, {vec2})'

    # FalkorDB exposes a cosine *distance* function instead. Per the original
    # author's note, Neo4j's function is a normalized cosine similarity, so the
    # distance is rescaled as (2 - distance) / 2 to line the two backends up.
    distance_call = f'vec.cosineDistance({vec1}, vecf32({vec2}))'
    return f'(2 - {distance_call})/2'
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def get_relationships_query(name: str, db_type: str = 'neo4j') -> str:
    """Build the fulltext relationship-search procedure call for a backend.

    Args:
        name: Neo4j fulltext index name. For FalkorDB it is translated to a
            relationship type through ``NEO4J_TO_FALKORDB_MAPPING``.
        db_type: Backend selector; ``'falkordb'`` or anything else for Neo4j.

    Returns:
        The ``CALL ...`` clause as a string. Both variants read the search
        text from the ``$query`` parameter; only the Neo4j variant passes a
        ``{limit: $limit}`` options map.

    Raises:
        KeyError: For FalkorDB, when ``name`` is not a key of
            ``NEO4J_TO_FALKORDB_MAPPING``.
    """
    if db_type != 'falkordb':
        return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'

    rel_type = NEO4J_TO_FALKORDB_MAPPING[name]
    return f"CALL db.idx.fulltext.queryRelationships('{rel_type}', $query)"
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def get_entity_node_save_bulk_query(nodes, db_type: str = 'neo4j') -> str | Any:
|
|
114
|
+
if db_type == 'falkordb':
|
|
115
|
+
queries = []
|
|
116
|
+
for node in nodes:
|
|
117
|
+
for label in node['labels']:
|
|
118
|
+
queries.append(
|
|
119
|
+
(
|
|
120
|
+
f"""
|
|
121
|
+
UNWIND $nodes AS node
|
|
122
|
+
MERGE (n:Entity {{uuid: node.uuid}})
|
|
123
|
+
SET n:{label}
|
|
124
|
+
SET n = node
|
|
125
|
+
WITH n, node
|
|
126
|
+
SET n.name_embedding = vecf32(node.name_embedding)
|
|
127
|
+
RETURN n.uuid AS uuid
|
|
128
|
+
""",
|
|
129
|
+
{'nodes': [node]},
|
|
130
|
+
)
|
|
131
|
+
)
|
|
132
|
+
return queries
|
|
133
|
+
else:
|
|
134
|
+
return ENTITY_NODE_SAVE_BULK
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def get_entity_edge_save_bulk_query(db_type: str = 'neo4j') -> str:
    """Return the query used to bulk-save entity (RELATES_TO) edges.

    Args:
        db_type: Backend selector; ``'falkordb'`` or anything else for Neo4j.

    Returns:
        For Neo4j, the shared ``ENTITY_EDGE_SAVE_BULK`` query string.
        For FalkorDB, a statement that unwinds ``$entity_edges``, matches both
        endpoint ``Entity`` nodes by uuid, merges the ``RELATES_TO``
        relationship on its uuid, overwrites its properties (storing the fact
        embedding via ``vecf32``) and returns each edge uuid.
    """
    if db_type != 'falkordb':
        return ENTITY_EDGE_SAVE_BULK

    falkordb_statement = """
        UNWIND $entity_edges AS edge
        MATCH (source:Entity {uuid: edge.source_node_uuid})
        MATCH (target:Entity {uuid: edge.target_node_uuid})
        MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
        SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
        created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at, fact_embedding: vecf32(edge.fact_embedding)}
        WITH r, edge
        RETURN edge.uuid AS uuid"""
    return falkordb_statement
|
graphiti_core/graphiti.py
CHANGED
|
@@ -19,12 +19,13 @@ from datetime import datetime
|
|
|
19
19
|
from time import time
|
|
20
20
|
|
|
21
21
|
from dotenv import load_dotenv
|
|
22
|
-
from neo4j import AsyncGraphDatabase
|
|
23
22
|
from pydantic import BaseModel
|
|
24
23
|
from typing_extensions import LiteralString
|
|
25
24
|
|
|
26
25
|
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
|
27
26
|
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
|
27
|
+
from graphiti_core.driver.driver import GraphDriver
|
|
28
|
+
from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
|
28
29
|
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
|
29
30
|
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
|
30
31
|
from graphiti_core.graphiti_types import GraphitiClients
|
|
@@ -62,6 +63,7 @@ from graphiti_core.utils.maintenance.community_operations import (
|
|
|
62
63
|
update_community,
|
|
63
64
|
)
|
|
64
65
|
from graphiti_core.utils.maintenance.edge_operations import (
|
|
66
|
+
build_duplicate_of_edges,
|
|
65
67
|
build_episodic_edges,
|
|
66
68
|
extract_edges,
|
|
67
69
|
resolve_extracted_edge,
|
|
@@ -94,12 +96,13 @@ class Graphiti:
|
|
|
94
96
|
def __init__(
|
|
95
97
|
self,
|
|
96
98
|
uri: str,
|
|
97
|
-
user: str,
|
|
98
|
-
password: str,
|
|
99
|
+
user: str | None = None,
|
|
100
|
+
password: str | None = None,
|
|
99
101
|
llm_client: LLMClient | None = None,
|
|
100
102
|
embedder: EmbedderClient | None = None,
|
|
101
103
|
cross_encoder: CrossEncoderClient | None = None,
|
|
102
104
|
store_raw_episode_content: bool = True,
|
|
105
|
+
graph_driver: GraphDriver | None = None,
|
|
103
106
|
):
|
|
104
107
|
"""
|
|
105
108
|
Initialize a Graphiti instance.
|
|
@@ -137,7 +140,9 @@ class Graphiti:
|
|
|
137
140
|
Make sure to set the OPENAI_API_KEY environment variable before initializing
|
|
138
141
|
Graphiti if you're using the default OpenAIClient.
|
|
139
142
|
"""
|
|
140
|
-
|
|
143
|
+
|
|
144
|
+
self.driver = graph_driver if graph_driver else Neo4jDriver(uri, user, password)
|
|
145
|
+
|
|
141
146
|
self.database = DEFAULT_DATABASE
|
|
142
147
|
self.store_raw_episode_content = store_raw_episode_content
|
|
143
148
|
if llm_client:
|
|
@@ -371,7 +376,7 @@ class Graphiti:
|
|
|
371
376
|
)
|
|
372
377
|
|
|
373
378
|
# Extract edges and resolve nodes
|
|
374
|
-
(nodes, uuid_map), extracted_edges = await semaphore_gather(
|
|
379
|
+
(nodes, uuid_map, node_duplicates), extracted_edges = await semaphore_gather(
|
|
375
380
|
resolve_extracted_nodes(
|
|
376
381
|
self.clients,
|
|
377
382
|
extracted_nodes,
|
|
@@ -380,7 +385,13 @@ class Graphiti:
|
|
|
380
385
|
entity_types,
|
|
381
386
|
),
|
|
382
387
|
extract_edges(
|
|
383
|
-
self.clients,
|
|
388
|
+
self.clients,
|
|
389
|
+
episode,
|
|
390
|
+
extracted_nodes,
|
|
391
|
+
previous_episodes,
|
|
392
|
+
edge_type_map or edge_type_map_default,
|
|
393
|
+
group_id,
|
|
394
|
+
edge_types,
|
|
384
395
|
),
|
|
385
396
|
)
|
|
386
397
|
|
|
@@ -400,7 +411,9 @@ class Graphiti:
|
|
|
400
411
|
),
|
|
401
412
|
)
|
|
402
413
|
|
|
403
|
-
|
|
414
|
+
duplicate_of_edges = build_duplicate_of_edges(episode, now, node_duplicates)
|
|
415
|
+
|
|
416
|
+
entity_edges = resolved_edges + invalidated_edges + duplicate_of_edges
|
|
404
417
|
|
|
405
418
|
episodic_edges = build_episodic_edges(nodes, episode, now)
|
|
406
419
|
|
|
@@ -687,7 +700,7 @@ class Graphiti:
|
|
|
687
700
|
if edge.fact_embedding is None:
|
|
688
701
|
await edge.generate_embedding(self.embedder)
|
|
689
702
|
|
|
690
|
-
resolved_nodes, uuid_map = await resolve_extracted_nodes(
|
|
703
|
+
resolved_nodes, uuid_map, _ = await resolve_extracted_nodes(
|
|
691
704
|
self.clients,
|
|
692
705
|
[source_node, target_node],
|
|
693
706
|
)
|
graphiti_core/graphiti_types.py
CHANGED
|
@@ -14,16 +14,16 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
-
from neo4j import AsyncDriver
|
|
18
17
|
from pydantic import BaseModel, ConfigDict
|
|
19
18
|
|
|
20
19
|
from graphiti_core.cross_encoder import CrossEncoderClient
|
|
20
|
+
from graphiti_core.driver.driver import GraphDriver
|
|
21
21
|
from graphiti_core.embedder import EmbedderClient
|
|
22
22
|
from graphiti_core.llm_client import LLMClient
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
class GraphitiClients(BaseModel):
|
|
26
|
-
driver:
|
|
26
|
+
driver: GraphDriver
|
|
27
27
|
llm_client: LLMClient
|
|
28
28
|
embedder: EmbedderClient
|
|
29
29
|
cross_encoder: CrossEncoderClient
|
graphiti_core/helpers.py
CHANGED
|
@@ -27,7 +27,7 @@ from typing_extensions import LiteralString
|
|
|
27
27
|
|
|
28
28
|
load_dotenv()
|
|
29
29
|
|
|
30
|
-
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE',
|
|
30
|
+
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', 'neo4j')
|
|
31
31
|
USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
|
|
32
32
|
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
|
|
33
33
|
MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0))
|
|
@@ -38,8 +38,14 @@ RUNTIME_QUERY: LiteralString = (
|
|
|
38
38
|
)
|
|
39
39
|
|
|
40
40
|
|
|
41
|
-
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
|
|
42
|
-
return
|
|
41
|
+
def parse_db_date(neo_date: neo4j_time.DateTime | str | None) -> datetime | None:
    """Convert a date value read from the graph database into a ``datetime``.

    Args:
        neo_date: A ``neo4j_time.DateTime``, an ISO-8601 string, or a falsy
            value such as ``None``.

    Returns:
        The native ``datetime`` for a ``neo4j_time.DateTime``; the result of
        ``datetime.fromisoformat`` for any other truthy value; ``None`` for a
        falsy value.
    """
    # Driver-native temporal values know how to convert themselves.
    if isinstance(neo_date, neo4j_time.DateTime):
        return neo_date.to_native()

    # Anything else that is truthy is handed to the ISO-8601 parser.
    if neo_date:
        return datetime.fromisoformat(neo_date)

    return None
|
|
43
49
|
|
|
44
50
|
|
|
45
51
|
def lucene_sanitize(query: str) -> str:
|
|
@@ -1,3 +1,19 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
1
17
|
from .client import LLMClient
|
|
2
18
|
from .config import LLMConfig
|
|
3
19
|
from .errors import RateLimitError
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import json
|
|
18
|
+
import logging
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
from openai import AsyncAzureOpenAI
|
|
22
|
+
from openai.types.chat import ChatCompletionMessageParam
|
|
23
|
+
from pydantic import BaseModel
|
|
24
|
+
|
|
25
|
+
from ..prompts.models import Message
|
|
26
|
+
from .client import LLMClient
|
|
27
|
+
from .config import LLMConfig, ModelSize
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class AzureOpenAILLMClient(LLMClient):
    """LLMClient adapter that sends chat completions through an AsyncAzureOpenAI client."""

    def __init__(self, azure_client: AsyncAzureOpenAI, config: LLMConfig | None = None):
        """Store the Azure client and initialise the base client with caching off.

        Args:
            azure_client: Pre-configured ``AsyncAzureOpenAI`` instance used for
                every request.
            config: Optional LLM configuration forwarded to ``LLMClient``.
        """
        super().__init__(config, cache=False)
        self.azure_client = azure_client

    async def _generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int = 1024,
        model_size: ModelSize = ModelSize.medium,
    ) -> dict[str, Any]:
        """Request a JSON-object chat completion and return it parsed.

        Args:
            messages: Prompt messages. Each message's ``content`` is replaced
                in place by its cleaned form. Only ``'user'`` and ``'system'``
                roles are forwarded; messages with any other role are left out
                of the request.
            response_model: Accepted for interface compatibility; not used in
                this method.
            max_tokens: Token limit passed to the completion call.
            model_size: Accepted for interface compatibility; not used in this
                method.

        Returns:
            The first choice's message content decoded with ``json.loads``
            (``'{}'`` is decoded when the content is empty).

        Raises:
            Exception: Any error from the request or from JSON decoding is
                logged and re-raised unchanged.
        """
        chat_messages: list[ChatCompletionMessageParam] = []
        for msg in messages:
            # The cleaned text is written back onto the message object itself.
            msg.content = self._clean_input(msg.content)
            if msg.role == 'user':
                chat_messages.append({'role': 'user', 'content': msg.content})
            elif msg.role == 'system':
                chat_messages.append({'role': 'system', 'content': msg.content})

        # Fall back to fixed defaults when the client has no model / temperature set.
        model_name = self.model or 'gpt-4o-mini'
        if self.temperature is not None:
            temperature = float(self.temperature)
        else:
            temperature = 0.7

        try:
            response = await self.azure_client.chat.completions.create(
                model=model_name,
                messages=chat_messages,
                temperature=temperature,
                max_tokens=max_tokens,
                response_format={'type': 'json_object'},
            )
            raw_content = response.choices[0].message.content or '{}'
            return json.loads(raw_content)
        except Exception as e:
            logger.error(f'Error in Azure OpenAI LLM response: {e}')
            raise
|
graphiti_core/nodes.py
CHANGED
|
@@ -22,13 +22,13 @@ from time import time
|
|
|
22
22
|
from typing import Any
|
|
23
23
|
from uuid import uuid4
|
|
24
24
|
|
|
25
|
-
from neo4j import AsyncDriver
|
|
26
25
|
from pydantic import BaseModel, Field
|
|
27
26
|
from typing_extensions import LiteralString
|
|
28
27
|
|
|
28
|
+
from graphiti_core.driver.driver import GraphDriver
|
|
29
29
|
from graphiti_core.embedder import EmbedderClient
|
|
30
30
|
from graphiti_core.errors import NodeNotFoundError
|
|
31
|
-
from graphiti_core.helpers import DEFAULT_DATABASE
|
|
31
|
+
from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
|
|
32
32
|
from graphiti_core.models.nodes.node_db_queries import (
|
|
33
33
|
COMMUNITY_NODE_SAVE,
|
|
34
34
|
ENTITY_NODE_SAVE,
|
|
@@ -94,9 +94,9 @@ class Node(BaseModel, ABC):
|
|
|
94
94
|
created_at: datetime = Field(default_factory=lambda: utc_now())
|
|
95
95
|
|
|
96
96
|
@abstractmethod
|
|
97
|
-
async def save(self, driver:
|
|
97
|
+
async def save(self, driver: GraphDriver): ...
|
|
98
98
|
|
|
99
|
-
async def delete(self, driver:
|
|
99
|
+
async def delete(self, driver: GraphDriver):
|
|
100
100
|
result = await driver.execute_query(
|
|
101
101
|
"""
|
|
102
102
|
MATCH (n:Entity|Episodic|Community {uuid: $uuid})
|
|
@@ -119,7 +119,7 @@ class Node(BaseModel, ABC):
|
|
|
119
119
|
return False
|
|
120
120
|
|
|
121
121
|
@classmethod
|
|
122
|
-
async def delete_by_group_id(cls, driver:
|
|
122
|
+
async def delete_by_group_id(cls, driver: GraphDriver, group_id: str):
|
|
123
123
|
await driver.execute_query(
|
|
124
124
|
"""
|
|
125
125
|
MATCH (n:Entity|Episodic|Community {group_id: $group_id})
|
|
@@ -132,10 +132,10 @@ class Node(BaseModel, ABC):
|
|
|
132
132
|
return 'SUCCESS'
|
|
133
133
|
|
|
134
134
|
@classmethod
|
|
135
|
-
async def get_by_uuid(cls, driver:
|
|
135
|
+
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
|
|
136
136
|
|
|
137
137
|
@classmethod
|
|
138
|
-
async def get_by_uuids(cls, driver:
|
|
138
|
+
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): ...
|
|
139
139
|
|
|
140
140
|
|
|
141
141
|
class EpisodicNode(Node):
|
|
@@ -150,7 +150,7 @@ class EpisodicNode(Node):
|
|
|
150
150
|
default_factory=list,
|
|
151
151
|
)
|
|
152
152
|
|
|
153
|
-
async def save(self, driver:
|
|
153
|
+
async def save(self, driver: GraphDriver):
|
|
154
154
|
result = await driver.execute_query(
|
|
155
155
|
EPISODIC_NODE_SAVE,
|
|
156
156
|
uuid=self.uuid,
|
|
@@ -165,12 +165,12 @@ class EpisodicNode(Node):
|
|
|
165
165
|
database_=DEFAULT_DATABASE,
|
|
166
166
|
)
|
|
167
167
|
|
|
168
|
-
logger.debug(f'Saved Node to
|
|
168
|
+
logger.debug(f'Saved Node to Graph: {self.uuid}')
|
|
169
169
|
|
|
170
170
|
return result
|
|
171
171
|
|
|
172
172
|
@classmethod
|
|
173
|
-
async def get_by_uuid(cls, driver:
|
|
173
|
+
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
174
174
|
records, _, _ = await driver.execute_query(
|
|
175
175
|
"""
|
|
176
176
|
MATCH (e:Episodic {uuid: $uuid})
|
|
@@ -197,7 +197,7 @@ class EpisodicNode(Node):
|
|
|
197
197
|
return episodes[0]
|
|
198
198
|
|
|
199
199
|
@classmethod
|
|
200
|
-
async def get_by_uuids(cls, driver:
|
|
200
|
+
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
201
201
|
records, _, _ = await driver.execute_query(
|
|
202
202
|
"""
|
|
203
203
|
MATCH (e:Episodic) WHERE e.uuid IN $uuids
|
|
@@ -224,7 +224,7 @@ class EpisodicNode(Node):
|
|
|
224
224
|
@classmethod
|
|
225
225
|
async def get_by_group_ids(
|
|
226
226
|
cls,
|
|
227
|
-
driver:
|
|
227
|
+
driver: GraphDriver,
|
|
228
228
|
group_ids: list[str],
|
|
229
229
|
limit: int | None = None,
|
|
230
230
|
uuid_cursor: str | None = None,
|
|
@@ -263,7 +263,7 @@ class EpisodicNode(Node):
|
|
|
263
263
|
return episodes
|
|
264
264
|
|
|
265
265
|
@classmethod
|
|
266
|
-
async def get_by_entity_node_uuid(cls, driver:
|
|
266
|
+
async def get_by_entity_node_uuid(cls, driver: GraphDriver, entity_node_uuid: str):
|
|
267
267
|
records, _, _ = await driver.execute_query(
|
|
268
268
|
"""
|
|
269
269
|
MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
|
|
@@ -304,7 +304,7 @@ class EntityNode(Node):
|
|
|
304
304
|
|
|
305
305
|
return self.name_embedding
|
|
306
306
|
|
|
307
|
-
async def load_name_embedding(self, driver:
|
|
307
|
+
async def load_name_embedding(self, driver: GraphDriver):
|
|
308
308
|
query: LiteralString = """
|
|
309
309
|
MATCH (n:Entity {uuid: $uuid})
|
|
310
310
|
RETURN n.name_embedding AS name_embedding
|
|
@@ -318,7 +318,7 @@ class EntityNode(Node):
|
|
|
318
318
|
|
|
319
319
|
self.name_embedding = records[0]['name_embedding']
|
|
320
320
|
|
|
321
|
-
async def save(self, driver:
|
|
321
|
+
async def save(self, driver: GraphDriver):
|
|
322
322
|
entity_data: dict[str, Any] = {
|
|
323
323
|
'uuid': self.uuid,
|
|
324
324
|
'name': self.name,
|
|
@@ -337,16 +337,16 @@ class EntityNode(Node):
|
|
|
337
337
|
database_=DEFAULT_DATABASE,
|
|
338
338
|
)
|
|
339
339
|
|
|
340
|
-
logger.debug(f'Saved Node to
|
|
340
|
+
logger.debug(f'Saved Node to Graph: {self.uuid}')
|
|
341
341
|
|
|
342
342
|
return result
|
|
343
343
|
|
|
344
344
|
@classmethod
|
|
345
|
-
async def get_by_uuid(cls, driver:
|
|
345
|
+
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
346
346
|
query = (
|
|
347
347
|
"""
|
|
348
|
-
|
|
349
|
-
|
|
348
|
+
MATCH (n:Entity {uuid: $uuid})
|
|
349
|
+
"""
|
|
350
350
|
+ ENTITY_NODE_RETURN
|
|
351
351
|
)
|
|
352
352
|
records, _, _ = await driver.execute_query(
|
|
@@ -364,7 +364,7 @@ class EntityNode(Node):
|
|
|
364
364
|
return nodes[0]
|
|
365
365
|
|
|
366
366
|
@classmethod
|
|
367
|
-
async def get_by_uuids(cls, driver:
|
|
367
|
+
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
368
368
|
records, _, _ = await driver.execute_query(
|
|
369
369
|
"""
|
|
370
370
|
MATCH (n:Entity) WHERE n.uuid IN $uuids
|
|
@@ -382,7 +382,7 @@ class EntityNode(Node):
|
|
|
382
382
|
@classmethod
|
|
383
383
|
async def get_by_group_ids(
|
|
384
384
|
cls,
|
|
385
|
-
driver:
|
|
385
|
+
driver: GraphDriver,
|
|
386
386
|
group_ids: list[str],
|
|
387
387
|
limit: int | None = None,
|
|
388
388
|
uuid_cursor: str | None = None,
|
|
@@ -416,7 +416,7 @@ class CommunityNode(Node):
|
|
|
416
416
|
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
|
|
417
417
|
summary: str = Field(description='region summary of member nodes', default_factory=str)
|
|
418
418
|
|
|
419
|
-
async def save(self, driver:
|
|
419
|
+
async def save(self, driver: GraphDriver):
|
|
420
420
|
result = await driver.execute_query(
|
|
421
421
|
COMMUNITY_NODE_SAVE,
|
|
422
422
|
uuid=self.uuid,
|
|
@@ -428,7 +428,7 @@ class CommunityNode(Node):
|
|
|
428
428
|
database_=DEFAULT_DATABASE,
|
|
429
429
|
)
|
|
430
430
|
|
|
431
|
-
logger.debug(f'Saved Node to
|
|
431
|
+
logger.debug(f'Saved Node to Graph: {self.uuid}')
|
|
432
432
|
|
|
433
433
|
return result
|
|
434
434
|
|
|
@@ -441,7 +441,7 @@ class CommunityNode(Node):
|
|
|
441
441
|
|
|
442
442
|
return self.name_embedding
|
|
443
443
|
|
|
444
|
-
async def load_name_embedding(self, driver:
|
|
444
|
+
async def load_name_embedding(self, driver: GraphDriver):
|
|
445
445
|
query: LiteralString = """
|
|
446
446
|
MATCH (c:Community {uuid: $uuid})
|
|
447
447
|
RETURN c.name_embedding AS name_embedding
|
|
@@ -456,7 +456,7 @@ class CommunityNode(Node):
|
|
|
456
456
|
self.name_embedding = records[0]['name_embedding']
|
|
457
457
|
|
|
458
458
|
@classmethod
|
|
459
|
-
async def get_by_uuid(cls, driver:
|
|
459
|
+
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
460
460
|
records, _, _ = await driver.execute_query(
|
|
461
461
|
"""
|
|
462
462
|
MATCH (n:Community {uuid: $uuid})
|
|
@@ -480,7 +480,7 @@ class CommunityNode(Node):
|
|
|
480
480
|
return nodes[0]
|
|
481
481
|
|
|
482
482
|
@classmethod
|
|
483
|
-
async def get_by_uuids(cls, driver:
|
|
483
|
+
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
484
484
|
records, _, _ = await driver.execute_query(
|
|
485
485
|
"""
|
|
486
486
|
MATCH (n:Community) WHERE n.uuid IN $uuids
|
|
@@ -503,7 +503,7 @@ class CommunityNode(Node):
|
|
|
503
503
|
@classmethod
|
|
504
504
|
async def get_by_group_ids(
|
|
505
505
|
cls,
|
|
506
|
-
driver:
|
|
506
|
+
driver: GraphDriver,
|
|
507
507
|
group_ids: list[str],
|
|
508
508
|
limit: int | None = None,
|
|
509
509
|
uuid_cursor: str | None = None,
|
|
@@ -542,8 +542,8 @@ class CommunityNode(Node):
|
|
|
542
542
|
def get_episodic_node_from_record(record: Any) -> EpisodicNode:
|
|
543
543
|
return EpisodicNode(
|
|
544
544
|
content=record['content'],
|
|
545
|
-
created_at=record['created_at']
|
|
546
|
-
valid_at=(record['valid_at']
|
|
545
|
+
created_at=parse_db_date(record['created_at']), # type: ignore
|
|
546
|
+
valid_at=parse_db_date(record['valid_at']), # type: ignore
|
|
547
547
|
uuid=record['uuid'],
|
|
548
548
|
group_id=record['group_id'],
|
|
549
549
|
source=EpisodeType.from_str(record['source']),
|
|
@@ -559,7 +559,7 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
|
|
|
559
559
|
name=record['name'],
|
|
560
560
|
group_id=record['group_id'],
|
|
561
561
|
labels=record['labels'],
|
|
562
|
-
created_at=record['created_at']
|
|
562
|
+
created_at=parse_db_date(record['created_at']), # type: ignore
|
|
563
563
|
summary=record['summary'],
|
|
564
564
|
attributes=record['attributes'],
|
|
565
565
|
)
|
|
@@ -580,7 +580,7 @@ def get_community_node_from_record(record: Any) -> CommunityNode:
|
|
|
580
580
|
name=record['name'],
|
|
581
581
|
group_id=record['group_id'],
|
|
582
582
|
name_embedding=record['name_embedding'],
|
|
583
|
-
created_at=record['created_at']
|
|
583
|
+
created_at=parse_db_date(record['created_at']), # type: ignore
|
|
584
584
|
summary=record['summary'],
|
|
585
585
|
)
|
|
586
586
|
|
|
@@ -26,12 +26,16 @@ class NodeDuplicate(BaseModel):
|
|
|
26
26
|
id: int = Field(..., description='integer id of the entity')
|
|
27
27
|
duplicate_idx: int = Field(
|
|
28
28
|
...,
|
|
29
|
-
description='idx of the duplicate
|
|
29
|
+
description='idx of the duplicate entity. If no duplicate entities are found, default to -1.',
|
|
30
30
|
)
|
|
31
31
|
name: str = Field(
|
|
32
32
|
...,
|
|
33
33
|
description='Name of the entity. Should be the most complete and descriptive name possible.',
|
|
34
34
|
)
|
|
35
|
+
additional_duplicates: list[int] = Field(
|
|
36
|
+
...,
|
|
37
|
+
description='idx of additional duplicate entities. Use this list if the entity has multiple duplicates among existing entities.',
|
|
38
|
+
)
|
|
35
39
|
|
|
36
40
|
|
|
37
41
|
class NodeResolutions(BaseModel):
|
|
@@ -97,6 +97,8 @@ Only extract facts that:
|
|
|
97
97
|
- The FACT TYPES provide a list of the most important types of facts, make sure to extract facts of these types
|
|
98
98
|
- The FACT TYPES are not an exhaustive list, extract all facts from the message even if they do not fit into one
|
|
99
99
|
of the FACT TYPES
|
|
100
|
+
- The FACT TYPES each contain their fact_type_signature which represents the entity types which that fact_type is defined for.
|
|
101
|
+
A Type of Entity in the signature represents any extracted entity (it is a generic universal type for all entities).
|
|
100
102
|
|
|
101
103
|
You may use information from the PREVIOUS MESSAGES only to disambiguate references or support continuity.
|
|
102
104
|
|
|
@@ -90,6 +90,8 @@ def extract_message(context: dict[str, Any]) -> list[Message]:
|
|
|
90
90
|
Instructions:
|
|
91
91
|
|
|
92
92
|
You are given a conversation context and a CURRENT MESSAGE. Your task is to extract **entity nodes** mentioned **explicitly or implicitly** in the CURRENT MESSAGE.
|
|
93
|
+
Pronoun references such as he/she/they or this/that/those should be disambiguated to the names of the
|
|
94
|
+
reference entities.
|
|
93
95
|
|
|
94
96
|
1. **Speaker Extraction**: Always extract the speaker (the part before the colon `:` in each dialogue line) as the first entity node.
|
|
95
97
|
- If the speaker is mentioned again in the message, treat both mentions as a **single entity**.
|