graphiti-core 0.20.3__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/driver/driver.py +28 -0
- graphiti_core/driver/falkordb_driver.py +112 -0
- graphiti_core/driver/kuzu_driver.py +1 -0
- graphiti_core/driver/neo4j_driver.py +10 -2
- graphiti_core/driver/neptune_driver.py +4 -6
- graphiti_core/edges.py +67 -7
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/graph_queries.py +35 -6
- graphiti_core/graphiti.py +36 -24
- graphiti_core/graphiti_types.py +0 -1
- graphiti_core/helpers.py +2 -2
- graphiti_core/llm_client/client.py +19 -4
- graphiti_core/llm_client/gemini_client.py +4 -2
- graphiti_core/llm_client/openai_base_client.py +3 -2
- graphiti_core/llm_client/openai_generic_client.py +3 -2
- graphiti_core/models/edges/edge_db_queries.py +36 -16
- graphiti_core/models/nodes/node_db_queries.py +30 -10
- graphiti_core/nodes.py +126 -25
- graphiti_core/prompts/dedupe_edges.py +40 -29
- graphiti_core/prompts/dedupe_nodes.py +51 -34
- graphiti_core/prompts/eval.py +3 -3
- graphiti_core/prompts/extract_edges.py +17 -9
- graphiti_core/prompts/extract_nodes.py +10 -9
- graphiti_core/prompts/prompt_helpers.py +3 -3
- graphiti_core/prompts/summarize_nodes.py +5 -5
- graphiti_core/search/search_filters.py +53 -0
- graphiti_core/search/search_helpers.py +5 -7
- graphiti_core/search/search_utils.py +227 -57
- graphiti_core/utils/bulk_utils.py +168 -69
- graphiti_core/utils/maintenance/community_operations.py +8 -20
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +187 -50
- graphiti_core/utils/maintenance/graph_data_operations.py +9 -5
- graphiti_core/utils/maintenance/node_operations.py +244 -88
- graphiti_core/utils/maintenance/temporal_operations.py +0 -4
- {graphiti_core-0.20.3.dist-info → graphiti_core-0.21.0.dist-info}/METADATA +7 -1
- {graphiti_core-0.20.3.dist-info → graphiti_core-0.21.0.dist-info}/RECORD +39 -38
- {graphiti_core-0.20.3.dist-info → graphiti_core-0.21.0.dist-info}/WHEEL +0 -0
- {graphiti_core-0.20.3.dist-info → graphiti_core-0.21.0.dist-info}/licenses/LICENSE +0 -0
graphiti_core/driver/driver.py
CHANGED
|
@@ -16,13 +16,25 @@ limitations under the License.
|
|
|
16
16
|
|
|
17
17
|
import copy
|
|
18
18
|
import logging
|
|
19
|
+
import os
|
|
19
20
|
from abc import ABC, abstractmethod
|
|
20
21
|
from collections.abc import Coroutine
|
|
21
22
|
from enum import Enum
|
|
22
23
|
from typing import Any
|
|
23
24
|
|
|
25
|
+
from dotenv import load_dotenv
|
|
26
|
+
|
|
24
27
|
logger = logging.getLogger(__name__)
|
|
25
28
|
|
|
29
|
+
DEFAULT_SIZE = 10
|
|
30
|
+
|
|
31
|
+
load_dotenv()
|
|
32
|
+
|
|
33
|
+
ENTITY_INDEX_NAME = os.environ.get('ENTITY_INDEX_NAME', 'entities')
|
|
34
|
+
EPISODE_INDEX_NAME = os.environ.get('EPISODE_INDEX_NAME', 'episodes')
|
|
35
|
+
COMMUNITY_INDEX_NAME = os.environ.get('COMMUNITY_INDEX_NAME', 'communities')
|
|
36
|
+
ENTITY_EDGE_INDEX_NAME = os.environ.get('ENTITY_EDGE_INDEX_NAME', 'entity_edges')
|
|
37
|
+
|
|
26
38
|
|
|
27
39
|
class GraphProvider(Enum):
|
|
28
40
|
NEO4J = 'neo4j'
|
|
@@ -61,6 +73,7 @@ class GraphDriver(ABC):
|
|
|
61
73
|
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
|
|
62
74
|
)
|
|
63
75
|
_database: str
|
|
76
|
+
aoss_client: Any # type: ignore
|
|
64
77
|
|
|
65
78
|
@abstractmethod
|
|
66
79
|
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
|
|
@@ -87,3 +100,18 @@ class GraphDriver(ABC):
|
|
|
87
100
|
cloned._database = database
|
|
88
101
|
|
|
89
102
|
return cloned
|
|
103
|
+
|
|
104
|
+
def build_fulltext_query(
|
|
105
|
+
self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128
|
|
106
|
+
) -> str:
|
|
107
|
+
"""
|
|
108
|
+
Specific fulltext query builder for database providers.
|
|
109
|
+
Only implemented by providers that need custom fulltext query building.
|
|
110
|
+
"""
|
|
111
|
+
raise NotImplementedError(f'build_fulltext_query not implemented for {self.provider}')
|
|
112
|
+
|
|
113
|
+
async def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
|
114
|
+
return 0
|
|
115
|
+
|
|
116
|
+
async def clear_aoss_indices(self):
|
|
117
|
+
return 1
|
|
graphiti_core/driver/falkordb_driver.py
CHANGED
|
@@ -36,6 +36,42 @@ from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
|
|
|
36
36
|
|
|
37
37
|
logger = logging.getLogger(__name__)
|
|
38
38
|
|
|
39
|
+
STOPWORDS = [
|
|
40
|
+
'a',
|
|
41
|
+
'is',
|
|
42
|
+
'the',
|
|
43
|
+
'an',
|
|
44
|
+
'and',
|
|
45
|
+
'are',
|
|
46
|
+
'as',
|
|
47
|
+
'at',
|
|
48
|
+
'be',
|
|
49
|
+
'but',
|
|
50
|
+
'by',
|
|
51
|
+
'for',
|
|
52
|
+
'if',
|
|
53
|
+
'in',
|
|
54
|
+
'into',
|
|
55
|
+
'it',
|
|
56
|
+
'no',
|
|
57
|
+
'not',
|
|
58
|
+
'of',
|
|
59
|
+
'on',
|
|
60
|
+
'or',
|
|
61
|
+
'such',
|
|
62
|
+
'that',
|
|
63
|
+
'their',
|
|
64
|
+
'then',
|
|
65
|
+
'there',
|
|
66
|
+
'these',
|
|
67
|
+
'they',
|
|
68
|
+
'this',
|
|
69
|
+
'to',
|
|
70
|
+
'was',
|
|
71
|
+
'will',
|
|
72
|
+
'with',
|
|
73
|
+
]
|
|
74
|
+
|
|
39
75
|
|
|
40
76
|
class FalkorDriverSession(GraphDriverSession):
|
|
41
77
|
provider = GraphProvider.FALKORDB
|
|
@@ -74,6 +110,7 @@ class FalkorDriverSession(GraphDriverSession):
|
|
|
74
110
|
|
|
75
111
|
class FalkorDriver(GraphDriver):
|
|
76
112
|
provider = GraphProvider.FALKORDB
|
|
113
|
+
aoss_client: None = None
|
|
77
114
|
|
|
78
115
|
def __init__(
|
|
79
116
|
self,
|
|
@@ -166,3 +203,78 @@ class FalkorDriver(GraphDriver):
|
|
|
166
203
|
cloned = FalkorDriver(falkor_db=self.client, database=database)
|
|
167
204
|
|
|
168
205
|
return cloned
|
|
206
|
+
|
|
207
|
+
def sanitize(self, query: str) -> str:
|
|
208
|
+
"""
|
|
209
|
+
Replace FalkorDB special characters with whitespace.
|
|
210
|
+
Based on FalkorDB tokenization rules: ,.<>{}[]"':;!@#$%^&*()-+=~
|
|
211
|
+
"""
|
|
212
|
+
# FalkorDB separator characters that break text into tokens
|
|
213
|
+
separator_map = str.maketrans(
|
|
214
|
+
{
|
|
215
|
+
',': ' ',
|
|
216
|
+
'.': ' ',
|
|
217
|
+
'<': ' ',
|
|
218
|
+
'>': ' ',
|
|
219
|
+
'{': ' ',
|
|
220
|
+
'}': ' ',
|
|
221
|
+
'[': ' ',
|
|
222
|
+
']': ' ',
|
|
223
|
+
'"': ' ',
|
|
224
|
+
"'": ' ',
|
|
225
|
+
':': ' ',
|
|
226
|
+
';': ' ',
|
|
227
|
+
'!': ' ',
|
|
228
|
+
'@': ' ',
|
|
229
|
+
'#': ' ',
|
|
230
|
+
'$': ' ',
|
|
231
|
+
'%': ' ',
|
|
232
|
+
'^': ' ',
|
|
233
|
+
'&': ' ',
|
|
234
|
+
'*': ' ',
|
|
235
|
+
'(': ' ',
|
|
236
|
+
')': ' ',
|
|
237
|
+
'-': ' ',
|
|
238
|
+
'+': ' ',
|
|
239
|
+
'=': ' ',
|
|
240
|
+
'~': ' ',
|
|
241
|
+
'?': ' ',
|
|
242
|
+
}
|
|
243
|
+
)
|
|
244
|
+
sanitized = query.translate(separator_map)
|
|
245
|
+
# Clean up multiple spaces
|
|
246
|
+
sanitized = ' '.join(sanitized.split())
|
|
247
|
+
return sanitized
|
|
248
|
+
|
|
249
|
+
def build_fulltext_query(
|
|
250
|
+
self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128
|
|
251
|
+
) -> str:
|
|
252
|
+
"""
|
|
253
|
+
Build a fulltext query string for FalkorDB using RedisSearch syntax.
|
|
254
|
+
FalkorDB uses RedisSearch-like syntax where:
|
|
255
|
+
- Field queries use @ prefix: @field:value
|
|
256
|
+
- Multiple values for same field: (@field:value1|value2)
|
|
257
|
+
- Text search doesn't need @ prefix for content fields
|
|
258
|
+
- AND is implicit with space: (@group_id:value) (text)
|
|
259
|
+
- OR uses pipe within parentheses: (@group_id:value1|value2)
|
|
260
|
+
"""
|
|
261
|
+
if group_ids is None or len(group_ids) == 0:
|
|
262
|
+
group_filter = ''
|
|
263
|
+
else:
|
|
264
|
+
group_values = '|'.join(group_ids)
|
|
265
|
+
group_filter = f'(@group_id:{group_values})'
|
|
266
|
+
|
|
267
|
+
sanitized_query = self.sanitize(query)
|
|
268
|
+
|
|
269
|
+
# Remove stopwords from the sanitized query
|
|
270
|
+
query_words = sanitized_query.split()
|
|
271
|
+
filtered_words = [word for word in query_words if word.lower() not in STOPWORDS]
|
|
272
|
+
sanitized_query = ' | '.join(filtered_words)
|
|
273
|
+
|
|
274
|
+
# If the query is too long return no query
|
|
275
|
+
if len(sanitized_query.split(' ')) + len(group_ids or '') >= max_query_length:
|
|
276
|
+
return ''
|
|
277
|
+
|
|
278
|
+
full_query = group_filter + ' (' + sanitized_query + ')'
|
|
279
|
+
|
|
280
|
+
return full_query
|
|
graphiti_core/driver/neo4j_driver.py
CHANGED
|
@@ -29,7 +29,13 @@ logger = logging.getLogger(__name__)
|
|
|
29
29
|
class Neo4jDriver(GraphDriver):
|
|
30
30
|
provider = GraphProvider.NEO4J
|
|
31
31
|
|
|
32
|
-
def __init__(
|
|
32
|
+
def __init__(
|
|
33
|
+
self,
|
|
34
|
+
uri: str,
|
|
35
|
+
user: str | None,
|
|
36
|
+
password: str | None,
|
|
37
|
+
database: str = 'neo4j',
|
|
38
|
+
):
|
|
33
39
|
super().__init__()
|
|
34
40
|
self.client = AsyncGraphDatabase.driver(
|
|
35
41
|
uri=uri,
|
|
@@ -37,6 +43,8 @@ class Neo4jDriver(GraphDriver):
|
|
|
37
43
|
)
|
|
38
44
|
self._database = database
|
|
39
45
|
|
|
46
|
+
self.aoss_client = None
|
|
47
|
+
|
|
40
48
|
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
|
|
41
49
|
# Check if database_ is provided in kwargs.
|
|
42
50
|
# If not populated, set the value to retain backwards compatibility
|
|
@@ -60,7 +68,7 @@ class Neo4jDriver(GraphDriver):
|
|
|
60
68
|
async def close(self) -> None:
|
|
61
69
|
return await self.client.close()
|
|
62
70
|
|
|
63
|
-
def delete_all_indexes(self) -> Coroutine
|
|
71
|
+
def delete_all_indexes(self) -> Coroutine:
|
|
64
72
|
return self.client.execute_query(
|
|
65
73
|
'CALL db.indexes() YIELD name DROP INDEX name',
|
|
66
74
|
)
|
|
graphiti_core/driver/neptune_driver.py
CHANGED
|
@@ -257,15 +257,13 @@ class NeptuneDriver(GraphDriver):
|
|
|
257
257
|
if name.lower() == index['index_name']:
|
|
258
258
|
to_index = []
|
|
259
259
|
for d in data:
|
|
260
|
-
item = {'_index': name}
|
|
260
|
+
item = {'_index': name, '_id': d['uuid']}
|
|
261
261
|
for p in index['body']['mappings']['properties']:
|
|
262
|
-
|
|
262
|
+
if p in d:
|
|
263
|
+
item[p] = d[p]
|
|
263
264
|
to_index.append(item)
|
|
264
265
|
success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
|
|
265
|
-
|
|
266
|
-
return success
|
|
267
|
-
else:
|
|
268
|
-
return 0
|
|
266
|
+
return success
|
|
269
267
|
|
|
270
268
|
return 0
|
|
271
269
|
|
graphiti_core/edges.py
CHANGED
|
@@ -25,7 +25,7 @@ from uuid import uuid4
|
|
|
25
25
|
from pydantic import BaseModel, Field
|
|
26
26
|
from typing_extensions import LiteralString
|
|
27
27
|
|
|
28
|
-
from graphiti_core.driver.driver import GraphDriver, GraphProvider
|
|
28
|
+
from graphiti_core.driver.driver import ENTITY_EDGE_INDEX_NAME, GraphDriver, GraphProvider
|
|
29
29
|
from graphiti_core.embedder import EmbedderClient
|
|
30
30
|
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
|
31
31
|
from graphiti_core.helpers import parse_db_date
|
|
@@ -77,6 +77,13 @@ class Edge(BaseModel, ABC):
|
|
|
77
77
|
uuid=self.uuid,
|
|
78
78
|
)
|
|
79
79
|
|
|
80
|
+
if driver.aoss_client:
|
|
81
|
+
await driver.aoss_client.delete(
|
|
82
|
+
index=ENTITY_EDGE_INDEX_NAME,
|
|
83
|
+
id=self.uuid,
|
|
84
|
+
params={'routing': self.group_id},
|
|
85
|
+
)
|
|
86
|
+
|
|
80
87
|
logger.debug(f'Deleted Edge: {self.uuid}')
|
|
81
88
|
|
|
82
89
|
@classmethod
|
|
@@ -108,6 +115,12 @@ class Edge(BaseModel, ABC):
|
|
|
108
115
|
uuids=uuids,
|
|
109
116
|
)
|
|
110
117
|
|
|
118
|
+
if driver.aoss_client:
|
|
119
|
+
await driver.aoss_client.delete_by_query(
|
|
120
|
+
index=ENTITY_EDGE_INDEX_NAME,
|
|
121
|
+
body={'query': {'terms': {'uuid': uuids}}},
|
|
122
|
+
)
|
|
123
|
+
|
|
111
124
|
logger.debug(f'Deleted Edges: {uuids}')
|
|
112
125
|
|
|
113
126
|
def __hash__(self):
|
|
@@ -255,6 +268,21 @@ class EntityEdge(Edge):
|
|
|
255
268
|
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
|
256
269
|
RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
|
|
257
270
|
"""
|
|
271
|
+
elif driver.aoss_client:
|
|
272
|
+
resp = await driver.aoss_client.search(
|
|
273
|
+
body={
|
|
274
|
+
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
|
|
275
|
+
'size': 1,
|
|
276
|
+
},
|
|
277
|
+
index=ENTITY_EDGE_INDEX_NAME,
|
|
278
|
+
params={'routing': self.group_id},
|
|
279
|
+
)
|
|
280
|
+
|
|
281
|
+
if resp['hits']['hits']:
|
|
282
|
+
self.fact_embedding = resp['hits']['hits'][0]['_source']['fact_embedding']
|
|
283
|
+
return
|
|
284
|
+
else:
|
|
285
|
+
raise EdgeNotFoundError(self.uuid)
|
|
258
286
|
|
|
259
287
|
if driver.provider == GraphProvider.KUZU:
|
|
260
288
|
query = """
|
|
@@ -292,14 +320,14 @@ class EntityEdge(Edge):
|
|
|
292
320
|
if driver.provider == GraphProvider.KUZU:
|
|
293
321
|
edge_data['attributes'] = json.dumps(self.attributes)
|
|
294
322
|
result = await driver.execute_query(
|
|
295
|
-
get_entity_edge_save_query(driver.provider),
|
|
323
|
+
get_entity_edge_save_query(driver.provider, has_aoss=bool(driver.aoss_client)),
|
|
296
324
|
**edge_data,
|
|
297
325
|
)
|
|
298
326
|
else:
|
|
299
327
|
edge_data.update(self.attributes or {})
|
|
300
328
|
|
|
301
|
-
if driver.
|
|
302
|
-
driver.save_to_aoss(
|
|
329
|
+
if driver.aoss_client:
|
|
330
|
+
await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, [edge_data]) # pyright: ignore reportAttributeAccessIssue
|
|
303
331
|
|
|
304
332
|
result = await driver.execute_query(
|
|
305
333
|
get_entity_edge_save_query(driver.provider),
|
|
@@ -336,6 +364,35 @@ class EntityEdge(Edge):
|
|
|
336
364
|
raise EdgeNotFoundError(uuid)
|
|
337
365
|
return edges[0]
|
|
338
366
|
|
|
367
|
+
@classmethod
|
|
368
|
+
async def get_between_nodes(
|
|
369
|
+
cls, driver: GraphDriver, source_node_uuid: str, target_node_uuid: str
|
|
370
|
+
):
|
|
371
|
+
match_query = """
|
|
372
|
+
MATCH (n:Entity {uuid: $source_node_uuid})-[e:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
|
|
373
|
+
"""
|
|
374
|
+
if driver.provider == GraphProvider.KUZU:
|
|
375
|
+
match_query = """
|
|
376
|
+
MATCH (n:Entity {uuid: $source_node_uuid})
|
|
377
|
+
-[:RELATES_TO]->(e:RelatesToNode_)
|
|
378
|
+
-[:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
|
|
379
|
+
"""
|
|
380
|
+
|
|
381
|
+
records, _, _ = await driver.execute_query(
|
|
382
|
+
match_query
|
|
383
|
+
+ """
|
|
384
|
+
RETURN
|
|
385
|
+
"""
|
|
386
|
+
+ get_entity_edge_return_query(driver.provider),
|
|
387
|
+
source_node_uuid=source_node_uuid,
|
|
388
|
+
target_node_uuid=target_node_uuid,
|
|
389
|
+
routing_='r',
|
|
390
|
+
)
|
|
391
|
+
|
|
392
|
+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
|
393
|
+
|
|
394
|
+
return edges
|
|
395
|
+
|
|
339
396
|
@classmethod
|
|
340
397
|
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
341
398
|
if len(uuids) == 0:
|
|
@@ -587,8 +644,11 @@ def get_community_edge_from_record(record: Any):
|
|
|
587
644
|
|
|
588
645
|
|
|
589
646
|
async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]):
|
|
590
|
-
|
|
647
|
+
# filter out falsey values from edges
|
|
648
|
+
filtered_edges = [edge for edge in edges if edge.fact]
|
|
649
|
+
|
|
650
|
+
if len(filtered_edges) == 0:
|
|
591
651
|
return
|
|
592
|
-
fact_embeddings = await embedder.create_batch([edge.fact for edge in
|
|
593
|
-
for edge, fact_embedding in zip(
|
|
652
|
+
fact_embeddings = await embedder.create_batch([edge.fact for edge in filtered_edges])
|
|
653
|
+
for edge, fact_embedding in zip(filtered_edges, fact_embeddings, strict=True):
|
|
594
654
|
edge.fact_embedding = fact_embedding
|
graphiti_core/embedder/client.py
CHANGED
|
@@ -14,12 +14,13 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
+
import os
|
|
17
18
|
from abc import ABC, abstractmethod
|
|
18
19
|
from collections.abc import Iterable
|
|
19
20
|
|
|
20
21
|
from pydantic import BaseModel, Field
|
|
21
22
|
|
|
22
|
-
EMBEDDING_DIM = 1024
|
|
23
|
+
EMBEDDING_DIM = int(os.getenv('EMBEDDING_DIM', 1024))
|
|
23
24
|
|
|
24
25
|
|
|
25
26
|
class EmbedderConfig(BaseModel):
|
graphiti_core/graph_queries.py
CHANGED
|
@@ -71,12 +71,41 @@ def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
|
|
|
71
71
|
|
|
72
72
|
def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
|
|
73
73
|
if provider == GraphProvider.FALKORDB:
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
74
|
+
from typing import cast
|
|
75
|
+
|
|
76
|
+
from graphiti_core.driver.falkordb_driver import STOPWORDS
|
|
77
|
+
|
|
78
|
+
# Convert to string representation for embedding in queries
|
|
79
|
+
stopwords_str = str(STOPWORDS)
|
|
80
|
+
|
|
81
|
+
# Use type: ignore to satisfy LiteralString requirement while maintaining single source of truth
|
|
82
|
+
return cast(
|
|
83
|
+
list[LiteralString],
|
|
84
|
+
[
|
|
85
|
+
f"""CALL db.idx.fulltext.createNodeIndex(
|
|
86
|
+
{{
|
|
87
|
+
label: 'Episodic',
|
|
88
|
+
stopwords: {stopwords_str}
|
|
89
|
+
}},
|
|
90
|
+
'content', 'source', 'source_description', 'group_id'
|
|
91
|
+
)""",
|
|
92
|
+
f"""CALL db.idx.fulltext.createNodeIndex(
|
|
93
|
+
{{
|
|
94
|
+
label: 'Entity',
|
|
95
|
+
stopwords: {stopwords_str}
|
|
96
|
+
}},
|
|
97
|
+
'name', 'summary', 'group_id'
|
|
98
|
+
)""",
|
|
99
|
+
f"""CALL db.idx.fulltext.createNodeIndex(
|
|
100
|
+
{{
|
|
101
|
+
label: 'Community',
|
|
102
|
+
stopwords: {stopwords_str}
|
|
103
|
+
}},
|
|
104
|
+
'name', 'group_id'
|
|
105
|
+
)""",
|
|
106
|
+
"""CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
|
|
107
|
+
],
|
|
108
|
+
)
|
|
80
109
|
|
|
81
110
|
if provider == GraphProvider.KUZU:
|
|
82
111
|
return [
|
graphiti_core/graphiti.py
CHANGED
|
@@ -60,9 +60,7 @@ from graphiti_core.search.search_config_recipes import (
|
|
|
60
60
|
from graphiti_core.search.search_filters import SearchFilters
|
|
61
61
|
from graphiti_core.search.search_utils import (
|
|
62
62
|
RELEVANT_SCHEMA_LIMIT,
|
|
63
|
-
get_edge_invalidation_candidates,
|
|
64
63
|
get_mentioned_nodes,
|
|
65
|
-
get_relevant_edges,
|
|
66
64
|
)
|
|
67
65
|
from graphiti_core.telemetry import capture_event
|
|
68
66
|
from graphiti_core.utils.bulk_utils import (
|
|
@@ -81,7 +79,6 @@ from graphiti_core.utils.maintenance.community_operations import (
|
|
|
81
79
|
update_community,
|
|
82
80
|
)
|
|
83
81
|
from graphiti_core.utils.maintenance.edge_operations import (
|
|
84
|
-
build_duplicate_of_edges,
|
|
85
82
|
build_episodic_edges,
|
|
86
83
|
extract_edges,
|
|
87
84
|
resolve_extracted_edge,
|
|
@@ -122,6 +119,11 @@ class AddBulkEpisodeResults(BaseModel):
|
|
|
122
119
|
community_edges: list[CommunityEdge]
|
|
123
120
|
|
|
124
121
|
|
|
122
|
+
class AddTripletResults(BaseModel):
|
|
123
|
+
nodes: list[EntityNode]
|
|
124
|
+
edges: list[EntityEdge]
|
|
125
|
+
|
|
126
|
+
|
|
125
127
|
class Graphiti:
|
|
126
128
|
def __init__(
|
|
127
129
|
self,
|
|
@@ -134,7 +136,6 @@ class Graphiti:
|
|
|
134
136
|
store_raw_episode_content: bool = True,
|
|
135
137
|
graph_driver: GraphDriver | None = None,
|
|
136
138
|
max_coroutines: int | None = None,
|
|
137
|
-
ensure_ascii: bool = False,
|
|
138
139
|
):
|
|
139
140
|
"""
|
|
140
141
|
Initialize a Graphiti instance.
|
|
@@ -167,10 +168,6 @@ class Graphiti:
|
|
|
167
168
|
max_coroutines : int | None, optional
|
|
168
169
|
The maximum number of concurrent operations allowed. Overrides SEMAPHORE_LIMIT set in the environment.
|
|
169
170
|
If not set, the Graphiti default is used.
|
|
170
|
-
ensure_ascii : bool, optional
|
|
171
|
-
Whether to escape non-ASCII characters in JSON serialization for prompts. Defaults to False.
|
|
172
|
-
Set as False to preserve non-ASCII characters (e.g., Korean, Japanese, Chinese) in their
|
|
173
|
-
original form, making them readable in LLM logs and improving model understanding.
|
|
174
171
|
|
|
175
172
|
Returns
|
|
176
173
|
-------
|
|
@@ -200,7 +197,6 @@ class Graphiti:
|
|
|
200
197
|
|
|
201
198
|
self.store_raw_episode_content = store_raw_episode_content
|
|
202
199
|
self.max_coroutines = max_coroutines
|
|
203
|
-
self.ensure_ascii = ensure_ascii
|
|
204
200
|
if llm_client:
|
|
205
201
|
self.llm_client = llm_client
|
|
206
202
|
else:
|
|
@@ -219,7 +215,6 @@ class Graphiti:
|
|
|
219
215
|
llm_client=self.llm_client,
|
|
220
216
|
embedder=self.embedder,
|
|
221
217
|
cross_encoder=self.cross_encoder,
|
|
222
|
-
ensure_ascii=self.ensure_ascii,
|
|
223
218
|
)
|
|
224
219
|
|
|
225
220
|
# Capture telemetry event
|
|
@@ -453,12 +448,12 @@ class Graphiti:
|
|
|
453
448
|
start = time()
|
|
454
449
|
now = utc_now()
|
|
455
450
|
|
|
456
|
-
# if group_id is None, use the default group id by the provider
|
|
457
|
-
group_id = group_id or get_default_group_id(self.driver.provider)
|
|
458
451
|
validate_entity_types(entity_types)
|
|
459
452
|
|
|
460
453
|
validate_excluded_entity_types(excluded_entity_types, entity_types)
|
|
461
454
|
validate_group_id(group_id)
|
|
455
|
+
# if group_id is None, use the default group id by the provider
|
|
456
|
+
group_id = group_id or get_default_group_id(self.driver.provider)
|
|
462
457
|
|
|
463
458
|
previous_episodes = (
|
|
464
459
|
await self.retrieve_episodes(
|
|
@@ -500,7 +495,7 @@ class Graphiti:
|
|
|
500
495
|
)
|
|
501
496
|
|
|
502
497
|
# Extract edges and resolve nodes
|
|
503
|
-
(nodes, uuid_map,
|
|
498
|
+
(nodes, uuid_map, _), extracted_edges = await semaphore_gather(
|
|
504
499
|
resolve_extracted_nodes(
|
|
505
500
|
self.clients,
|
|
506
501
|
extracted_nodes,
|
|
@@ -537,9 +532,7 @@ class Graphiti:
|
|
|
537
532
|
max_coroutines=self.max_coroutines,
|
|
538
533
|
)
|
|
539
534
|
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
entity_edges = resolved_edges + invalidated_edges + duplicate_of_edges
|
|
535
|
+
entity_edges = resolved_edges + invalidated_edges
|
|
543
536
|
|
|
544
537
|
episodic_edges = build_episodic_edges(nodes, episode.uuid, now)
|
|
545
538
|
|
|
@@ -559,9 +552,7 @@ class Graphiti:
|
|
|
559
552
|
if update_communities:
|
|
560
553
|
communities, community_edges = await semaphore_gather(
|
|
561
554
|
*[
|
|
562
|
-
update_community(
|
|
563
|
-
self.driver, self.llm_client, self.embedder, node, self.ensure_ascii
|
|
564
|
-
)
|
|
555
|
+
update_community(self.driver, self.llm_client, self.embedder, node)
|
|
565
556
|
for node in nodes
|
|
566
557
|
],
|
|
567
558
|
max_coroutines=self.max_coroutines,
|
|
@@ -1015,7 +1006,9 @@ class Graphiti:
|
|
|
1015
1006
|
|
|
1016
1007
|
return SearchResults(edges=edges, nodes=nodes)
|
|
1017
1008
|
|
|
1018
|
-
async def add_triplet(
|
|
1009
|
+
async def add_triplet(
|
|
1010
|
+
self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode
|
|
1011
|
+
) -> AddTripletResults:
|
|
1019
1012
|
if source_node.name_embedding is None:
|
|
1020
1013
|
await source_node.generate_name_embedding(self.embedder)
|
|
1021
1014
|
if target_node.name_embedding is None:
|
|
@@ -1030,10 +1023,28 @@ class Graphiti:
|
|
|
1030
1023
|
|
|
1031
1024
|
updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
|
|
1032
1025
|
|
|
1033
|
-
|
|
1026
|
+
valid_edges = await EntityEdge.get_between_nodes(
|
|
1027
|
+
self.driver, edge.source_node_uuid, edge.target_node_uuid
|
|
1028
|
+
)
|
|
1029
|
+
|
|
1030
|
+
related_edges = (
|
|
1031
|
+
await search(
|
|
1032
|
+
self.clients,
|
|
1033
|
+
updated_edge.fact,
|
|
1034
|
+
group_ids=[updated_edge.group_id],
|
|
1035
|
+
config=EDGE_HYBRID_SEARCH_RRF,
|
|
1036
|
+
search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
|
|
1037
|
+
)
|
|
1038
|
+
).edges
|
|
1034
1039
|
existing_edges = (
|
|
1035
|
-
await
|
|
1036
|
-
|
|
1040
|
+
await search(
|
|
1041
|
+
self.clients,
|
|
1042
|
+
updated_edge.fact,
|
|
1043
|
+
group_ids=[updated_edge.group_id],
|
|
1044
|
+
config=EDGE_HYBRID_SEARCH_RRF,
|
|
1045
|
+
search_filter=SearchFilters(),
|
|
1046
|
+
)
|
|
1047
|
+
).edges
|
|
1037
1048
|
|
|
1038
1049
|
resolved_edge, invalidated_edges, _ = await resolve_extracted_edge(
|
|
1039
1050
|
self.llm_client,
|
|
@@ -1050,7 +1061,7 @@ class Graphiti:
|
|
|
1050
1061
|
group_id=edge.group_id,
|
|
1051
1062
|
),
|
|
1052
1063
|
None,
|
|
1053
|
-
|
|
1064
|
+
None,
|
|
1054
1065
|
)
|
|
1055
1066
|
|
|
1056
1067
|
edges: list[EntityEdge] = [resolved_edge] + invalidated_edges
|
|
@@ -1059,6 +1070,7 @@ class Graphiti:
|
|
|
1059
1070
|
await create_entity_node_embeddings(self.embedder, nodes)
|
|
1060
1071
|
|
|
1061
1072
|
await add_nodes_and_edges_bulk(self.driver, [], [], nodes, edges, self.embedder)
|
|
1073
|
+
return AddTripletResults(edges=edges, nodes=nodes)
|
|
1062
1074
|
|
|
1063
1075
|
async def remove_episode(self, episode_uuid: str):
|
|
1064
1076
|
# Find the episode to be deleted
|
graphiti_core/graphiti_types.py
CHANGED
graphiti_core/helpers.py
CHANGED
|
@@ -54,7 +54,7 @@ def get_default_group_id(provider: GraphProvider) -> str:
|
|
|
54
54
|
For most databases, the default group id is an empty string, while there are database types that require a specific default group id.
|
|
55
55
|
"""
|
|
56
56
|
if provider == GraphProvider.FALKORDB:
|
|
57
|
-
return '_'
|
|
57
|
+
return '\\_'
|
|
58
58
|
else:
|
|
59
59
|
return ''
|
|
60
60
|
|
|
@@ -116,7 +116,7 @@ async def semaphore_gather(
|
|
|
116
116
|
return await asyncio.gather(*(_wrap_coroutine(coroutine) for coroutine in coroutines))
|
|
117
117
|
|
|
118
118
|
|
|
119
|
-
def validate_group_id(group_id: str) -> bool:
|
|
119
|
+
def validate_group_id(group_id: str | None) -> bool:
|
|
120
120
|
"""
|
|
121
121
|
Validate that a group_id contains only ASCII alphanumeric characters, dashes, and underscores.
|
|
122
122
|
|
|
graphiti_core/llm_client/client.py
CHANGED
|
@@ -32,9 +32,23 @@ from .errors import RateLimitError
|
|
|
32
32
|
DEFAULT_TEMPERATURE = 0
|
|
33
33
|
DEFAULT_CACHE_DIR = './llm_cache'
|
|
34
34
|
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
35
|
+
|
|
36
|
+
def get_extraction_language_instruction(group_id: str | None = None) -> str:
|
|
37
|
+
"""Returns instruction for language extraction behavior.
|
|
38
|
+
|
|
39
|
+
Override this function to customize language extraction:
|
|
40
|
+
- Return empty string to disable multilingual instructions
|
|
41
|
+
- Return custom instructions for specific language requirements
|
|
42
|
+
- Use group_id to provide different instructions per group/partition
|
|
43
|
+
|
|
44
|
+
Args:
|
|
45
|
+
group_id: Optional partition identifier for the graph
|
|
46
|
+
|
|
47
|
+
Returns:
|
|
48
|
+
str: Language instruction to append to system messages
|
|
49
|
+
"""
|
|
50
|
+
return '\n\nAny extracted information should be returned in the same language as it was written in.'
|
|
51
|
+
|
|
38
52
|
|
|
39
53
|
logger = logging.getLogger(__name__)
|
|
40
54
|
|
|
@@ -132,6 +146,7 @@ class LLMClient(ABC):
|
|
|
132
146
|
response_model: type[BaseModel] | None = None,
|
|
133
147
|
max_tokens: int | None = None,
|
|
134
148
|
model_size: ModelSize = ModelSize.medium,
|
|
149
|
+
group_id: str | None = None,
|
|
135
150
|
) -> dict[str, typing.Any]:
|
|
136
151
|
if max_tokens is None:
|
|
137
152
|
max_tokens = self.max_tokens
|
|
@@ -145,7 +160,7 @@ class LLMClient(ABC):
|
|
|
145
160
|
)
|
|
146
161
|
|
|
147
162
|
# Add multilingual extraction instructions
|
|
148
|
-
messages[0].content +=
|
|
163
|
+
messages[0].content += get_extraction_language_instruction(group_id)
|
|
149
164
|
|
|
150
165
|
if self.cache_enabled and self.cache_dir is not None:
|
|
151
166
|
cache_key = self._get_cache_key(messages)
|