graphiti-core 0.20.4__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of graphiti-core has been flagged as possibly problematic.
- graphiti_core/driver/driver.py +28 -0
- graphiti_core/driver/falkordb_driver.py +112 -0
- graphiti_core/driver/kuzu_driver.py +1 -0
- graphiti_core/driver/neo4j_driver.py +10 -2
- graphiti_core/driver/neptune_driver.py +4 -6
- graphiti_core/edges.py +67 -7
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/graph_queries.py +35 -6
- graphiti_core/graphiti.py +27 -23
- graphiti_core/graphiti_types.py +0 -1
- graphiti_core/helpers.py +2 -2
- graphiti_core/llm_client/client.py +19 -4
- graphiti_core/llm_client/gemini_client.py +4 -2
- graphiti_core/llm_client/openai_base_client.py +3 -2
- graphiti_core/llm_client/openai_generic_client.py +3 -2
- graphiti_core/models/edges/edge_db_queries.py +36 -16
- graphiti_core/models/nodes/node_db_queries.py +30 -10
- graphiti_core/nodes.py +126 -25
- graphiti_core/prompts/dedupe_edges.py +40 -29
- graphiti_core/prompts/dedupe_nodes.py +51 -34
- graphiti_core/prompts/eval.py +3 -3
- graphiti_core/prompts/extract_edges.py +17 -9
- graphiti_core/prompts/extract_nodes.py +10 -9
- graphiti_core/prompts/prompt_helpers.py +3 -3
- graphiti_core/prompts/summarize_nodes.py +5 -5
- graphiti_core/search/search_filters.py +53 -0
- graphiti_core/search/search_helpers.py +5 -7
- graphiti_core/search/search_utils.py +227 -57
- graphiti_core/utils/bulk_utils.py +168 -69
- graphiti_core/utils/maintenance/community_operations.py +8 -20
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +187 -50
- graphiti_core/utils/maintenance/graph_data_operations.py +9 -5
- graphiti_core/utils/maintenance/node_operations.py +244 -88
- graphiti_core/utils/maintenance/temporal_operations.py +0 -4
- {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0.dist-info}/METADATA +7 -1
- {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0.dist-info}/RECORD +39 -38
- {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0.dist-info}/WHEEL +0 -0
- {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0.dist-info}/licenses/LICENSE +0 -0
graphiti_core/driver/driver.py
CHANGED
@@ -16,13 +16,25 @@ limitations under the License.
 
 import copy
 import logging
+import os
 from abc import ABC, abstractmethod
 from collections.abc import Coroutine
 from enum import Enum
 from typing import Any
 
+from dotenv import load_dotenv
+
 logger = logging.getLogger(__name__)
 
+DEFAULT_SIZE = 10
+
+load_dotenv()
+
+ENTITY_INDEX_NAME = os.environ.get('ENTITY_INDEX_NAME', 'entities')
+EPISODE_INDEX_NAME = os.environ.get('EPISODE_INDEX_NAME', 'episodes')
+COMMUNITY_INDEX_NAME = os.environ.get('COMMUNITY_INDEX_NAME', 'communities')
+ENTITY_EDGE_INDEX_NAME = os.environ.get('ENTITY_EDGE_INDEX_NAME', 'entity_edges')
+
 
 class GraphProvider(Enum):
     NEO4J = 'neo4j'

@@ -61,6 +73,7 @@ class GraphDriver(ABC):
         ''  # Neo4j (default) syntax does not require a prefix for fulltext queries
     )
     _database: str
+    aoss_client: Any  # type: ignore
 
     @abstractmethod
     def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:

@@ -87,3 +100,18 @@ class GraphDriver(ABC):
         cloned._database = database
 
         return cloned
+
+    def build_fulltext_query(
+        self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128
+    ) -> str:
+        """
+        Specific fulltext query builder for database providers.
+        Only implemented by providers that need custom fulltext query building.
+        """
+        raise NotImplementedError(f'build_fulltext_query not implemented for {self.provider}')
+
+    async def save_to_aoss(self, name: str, data: list[dict]) -> int:
+        return 0
+
+    async def clear_aoss_indices(self):
+        return 1
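The new index-name constants are resolved from the environment at import time (after load_dotenv()), so they can be overridden before graphiti_core.driver.driver is first imported. A minimal sketch of that, with hypothetical override values:

import os

# Hypothetical overrides; the shipped defaults are 'entities', 'episodes', 'communities', 'entity_edges'.
os.environ['ENTITY_INDEX_NAME'] = 'my_entities'
os.environ['ENTITY_EDGE_INDEX_NAME'] = 'my_entity_edges'

from graphiti_core.driver.driver import ENTITY_EDGE_INDEX_NAME, ENTITY_INDEX_NAME

print(ENTITY_INDEX_NAME)       # 'my_entities'
print(ENTITY_EDGE_INDEX_NAME)  # 'my_entity_edges'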
graphiti_core/driver/falkordb_driver.py
CHANGED

@@ -36,6 +36,42 @@ from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
 
 logger = logging.getLogger(__name__)
 
+STOPWORDS = [
+    'a',
+    'is',
+    'the',
+    'an',
+    'and',
+    'are',
+    'as',
+    'at',
+    'be',
+    'but',
+    'by',
+    'for',
+    'if',
+    'in',
+    'into',
+    'it',
+    'no',
+    'not',
+    'of',
+    'on',
+    'or',
+    'such',
+    'that',
+    'their',
+    'then',
+    'there',
+    'these',
+    'they',
+    'this',
+    'to',
+    'was',
+    'will',
+    'with',
+]
+
 
 class FalkorDriverSession(GraphDriverSession):
     provider = GraphProvider.FALKORDB

@@ -74,6 +110,7 @@ class FalkorDriverSession(GraphDriverSession):
 
 class FalkorDriver(GraphDriver):
     provider = GraphProvider.FALKORDB
+    aoss_client: None = None
 
     def __init__(
         self,

@@ -166,3 +203,78 @@ class FalkorDriver(GraphDriver):
         cloned = FalkorDriver(falkor_db=self.client, database=database)
 
         return cloned
+
+    def sanitize(self, query: str) -> str:
+        """
+        Replace FalkorDB special characters with whitespace.
+        Based on FalkorDB tokenization rules: ,.<>{}[]"':;!@#$%^&*()-+=~
+        """
+        # FalkorDB separator characters that break text into tokens
+        separator_map = str.maketrans(
+            {
+                ',': ' ',
+                '.': ' ',
+                '<': ' ',
+                '>': ' ',
+                '{': ' ',
+                '}': ' ',
+                '[': ' ',
+                ']': ' ',
+                '"': ' ',
+                "'": ' ',
+                ':': ' ',
+                ';': ' ',
+                '!': ' ',
+                '@': ' ',
+                '#': ' ',
+                '$': ' ',
+                '%': ' ',
+                '^': ' ',
+                '&': ' ',
+                '*': ' ',
+                '(': ' ',
+                ')': ' ',
+                '-': ' ',
+                '+': ' ',
+                '=': ' ',
+                '~': ' ',
+                '?': ' ',
+            }
+        )
+        sanitized = query.translate(separator_map)
+        # Clean up multiple spaces
+        sanitized = ' '.join(sanitized.split())
+        return sanitized
+
+    def build_fulltext_query(
+        self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128
+    ) -> str:
+        """
+        Build a fulltext query string for FalkorDB using RedisSearch syntax.
+        FalkorDB uses RedisSearch-like syntax where:
+        - Field queries use @ prefix: @field:value
+        - Multiple values for same field: (@field:value1|value2)
+        - Text search doesn't need @ prefix for content fields
+        - AND is implicit with space: (@group_id:value) (text)
+        - OR uses pipe within parentheses: (@group_id:value1|value2)
+        """
+        if group_ids is None or len(group_ids) == 0:
+            group_filter = ''
+        else:
+            group_values = '|'.join(group_ids)
+            group_filter = f'(@group_id:{group_values})'
+
+        sanitized_query = self.sanitize(query)
+
+        # Remove stopwords from the sanitized query
+        query_words = sanitized_query.split()
+        filtered_words = [word for word in query_words if word.lower() not in STOPWORDS]
+        sanitized_query = ' | '.join(filtered_words)
+
+        # If the query is too long return no query
+        if len(sanitized_query.split(' ')) + len(group_ids or '') >= max_query_length:
+            return ''
+
+        full_query = group_filter + ' (' + sanitized_query + ')'
+
+        return full_query
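To make the new FalkorDB query building concrete, here is a small sketch of what build_fulltext_query produces for one sample input. The no-argument FalkorDriver() constructor (local default connection) is an assumption; the resulting string follows the sanitize, stopword and group-filter logic shown above.

from graphiti_core.driver.falkordb_driver import FalkorDriver

driver = FalkorDriver()  # assumes a FalkorDB instance reachable with default connection settings

query = driver.build_fulltext_query('Alice works at Acme, Inc.', group_ids=['group-1', 'group-2'])

# sanitize() replaces ',' and '.' with whitespace  -> 'Alice works at Acme Inc'
# stopword filtering drops 'at', joins with ' | '  -> 'Alice | works | Acme | Inc'
# the group filter is prepended, so the result is:
# '(@group_id:group-1|group-2) (Alice | works | Acme | Inc)'
print(query)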
graphiti_core/driver/neo4j_driver.py
CHANGED

@@ -29,7 +29,13 @@ logger = logging.getLogger(__name__)
 class Neo4jDriver(GraphDriver):
     provider = GraphProvider.NEO4J
 
-    def __init__(
+    def __init__(
+        self,
+        uri: str,
+        user: str | None,
+        password: str | None,
+        database: str = 'neo4j',
+    ):
         super().__init__()
         self.client = AsyncGraphDatabase.driver(
             uri=uri,

@@ -37,6 +43,8 @@ class Neo4jDriver(GraphDriver):
         )
         self._database = database
 
+        self.aoss_client = None
+
     async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
         # Check if database_ is provided in kwargs.
         # If not populated, set the value to retain backwards compatibility

@@ -60,7 +68,7 @@ class Neo4jDriver(GraphDriver):
     async def close(self) -> None:
         return await self.client.close()
 
-    def delete_all_indexes(self) -> Coroutine
+    def delete_all_indexes(self) -> Coroutine:
         return self.client.execute_query(
             'CALL db.indexes() YIELD name DROP INDEX name',
         )
graphiti_core/driver/neptune_driver.py
CHANGED

@@ -257,15 +257,13 @@
             if name.lower() == index['index_name']:
                 to_index = []
                 for d in data:
-                    item = {'_index': name}
+                    item = {'_index': name, '_id': d['uuid']}
                     for p in index['body']['mappings']['properties']:
-
+                        if p in d:
+                            item[p] = d[p]
                     to_index.append(item)
                 success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
-
-                return success
-            else:
-                return 0
+                return success
 
         return 0
 
graphiti_core/edges.py
CHANGED
@@ -25,7 +25,7 @@ from uuid import uuid4
 from pydantic import BaseModel, Field
 from typing_extensions import LiteralString
 
-from graphiti_core.driver.driver import GraphDriver, GraphProvider
+from graphiti_core.driver.driver import ENTITY_EDGE_INDEX_NAME, GraphDriver, GraphProvider
 from graphiti_core.embedder import EmbedderClient
 from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
 from graphiti_core.helpers import parse_db_date

@@ -77,6 +77,13 @@ class Edge(BaseModel, ABC):
             uuid=self.uuid,
         )
 
+        if driver.aoss_client:
+            await driver.aoss_client.delete(
+                index=ENTITY_EDGE_INDEX_NAME,
+                id=self.uuid,
+                params={'routing': self.group_id},
+            )
+
         logger.debug(f'Deleted Edge: {self.uuid}')
 
     @classmethod

@@ -108,6 +115,12 @@ class Edge(BaseModel, ABC):
             uuids=uuids,
         )
 
+        if driver.aoss_client:
+            await driver.aoss_client.delete_by_query(
+                index=ENTITY_EDGE_INDEX_NAME,
+                body={'query': {'terms': {'uuid': uuids}}},
+            )
+
         logger.debug(f'Deleted Edges: {uuids}')
 
     def __hash__(self):

@@ -255,6 +268,21 @@ class EntityEdge(Edge):
                 MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
                 RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
             """
+        elif driver.aoss_client:
+            resp = await driver.aoss_client.search(
+                body={
+                    'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
+                    'size': 1,
+                },
+                index=ENTITY_EDGE_INDEX_NAME,
+                params={'routing': self.group_id},
+            )
+
+            if resp['hits']['hits']:
+                self.fact_embedding = resp['hits']['hits'][0]['_source']['fact_embedding']
+                return
+            else:
+                raise EdgeNotFoundError(self.uuid)
 
         if driver.provider == GraphProvider.KUZU:
             query = """

@@ -292,14 +320,14 @@ class EntityEdge(Edge):
         if driver.provider == GraphProvider.KUZU:
             edge_data['attributes'] = json.dumps(self.attributes)
             result = await driver.execute_query(
-                get_entity_edge_save_query(driver.provider),
+                get_entity_edge_save_query(driver.provider, has_aoss=bool(driver.aoss_client)),
                 **edge_data,
             )
         else:
             edge_data.update(self.attributes or {})
 
-            if driver.
-                driver.save_to_aoss(
+            if driver.aoss_client:
+                await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, [edge_data])  # pyright: ignore reportAttributeAccessIssue
 
             result = await driver.execute_query(
                 get_entity_edge_save_query(driver.provider),

@@ -336,6 +364,35 @@ class EntityEdge(Edge):
             raise EdgeNotFoundError(uuid)
         return edges[0]
 
+    @classmethod
+    async def get_between_nodes(
+        cls, driver: GraphDriver, source_node_uuid: str, target_node_uuid: str
+    ):
+        match_query = """
+            MATCH (n:Entity {uuid: $source_node_uuid})-[e:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
+        """
+        if driver.provider == GraphProvider.KUZU:
+            match_query = """
+                MATCH (n:Entity {uuid: $source_node_uuid})
+                    -[:RELATES_TO]->(e:RelatesToNode_)
+                    -[:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
+            """
+
+        records, _, _ = await driver.execute_query(
+            match_query
+            + """
+            RETURN
+            """
+            + get_entity_edge_return_query(driver.provider),
+            source_node_uuid=source_node_uuid,
+            target_node_uuid=target_node_uuid,
+            routing_='r',
+        )
+
+        edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
+
+        return edges
+
     @classmethod
     async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
         if len(uuids) == 0:

@@ -587,8 +644,11 @@ def get_community_edge_from_record(record: Any):
 
 
 async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]):
-
+    # filter out falsey values from edges
+    filtered_edges = [edge for edge in edges if edge.fact]
+
+    if len(filtered_edges) == 0:
         return
-    fact_embeddings = await embedder.create_batch([edge.fact for edge in
-    for edge, fact_embedding in zip(
+    fact_embeddings = await embedder.create_batch([edge.fact for edge in filtered_edges])
+    for edge, fact_embedding in zip(filtered_edges, fact_embeddings, strict=True):
        edge.fact_embedding = fact_embedding
graphiti_core/embedder/client.py
CHANGED
@@ -14,12 +14,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+import os
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
 
 from pydantic import BaseModel, Field
 
-EMBEDDING_DIM = 1024
+EMBEDDING_DIM = int(os.getenv('EMBEDDING_DIM', 1024))
 
 
 class EmbedderConfig(BaseModel):
graphiti_core/graph_queries.py
CHANGED
@@ -71,12 +71,41 @@ def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
 
 def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
     if provider == GraphProvider.FALKORDB:
-
-
-
-
-
-
+        from typing import cast
+
+        from graphiti_core.driver.falkordb_driver import STOPWORDS
+
+        # Convert to string representation for embedding in queries
+        stopwords_str = str(STOPWORDS)
+
+        # Use type: ignore to satisfy LiteralString requirement while maintaining single source of truth
+        return cast(
+            list[LiteralString],
+            [
+                f"""CALL db.idx.fulltext.createNodeIndex(
+                    {{
+                        label: 'Episodic',
+                        stopwords: {stopwords_str}
+                    }},
+                    'content', 'source', 'source_description', 'group_id'
+                )""",
+                f"""CALL db.idx.fulltext.createNodeIndex(
+                    {{
+                        label: 'Entity',
+                        stopwords: {stopwords_str}
+                    }},
+                    'name', 'summary', 'group_id'
+                )""",
+                f"""CALL db.idx.fulltext.createNodeIndex(
+                    {{
+                        label: 'Community',
+                        stopwords: {stopwords_str}
+                    }},
+                    'name', 'group_id'
+                )""",
+                """CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
+            ],
+        )
 
     if provider == GraphProvider.KUZU:
         return [
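The FalkorDB branch now builds its index-creation statements from the STOPWORDS list shared with falkordb_driver.py, so the node indices and build_fulltext_query filter the same words. A quick sketch of calling it (requires the FalkorDB driver dependency to be importable):

from graphiti_core.driver.driver import GraphProvider
from graphiti_core.graph_queries import get_fulltext_indices

queries = get_fulltext_indices(GraphProvider.FALKORDB)
print(len(queries))  # 4: Episodic, Entity and Community node indices plus the RELATES_TO edge index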
graphiti_core/graphiti.py
CHANGED
@@ -60,9 +60,7 @@ from graphiti_core.search.search_config_recipes import (
 from graphiti_core.search.search_filters import SearchFilters
 from graphiti_core.search.search_utils import (
     RELEVANT_SCHEMA_LIMIT,
-    get_edge_invalidation_candidates,
     get_mentioned_nodes,
-    get_relevant_edges,
 )
 from graphiti_core.telemetry import capture_event
 from graphiti_core.utils.bulk_utils import (

@@ -81,7 +79,6 @@ from graphiti_core.utils.maintenance.community_operations import (
     update_community,
 )
 from graphiti_core.utils.maintenance.edge_operations import (
-    build_duplicate_of_edges,
     build_episodic_edges,
     extract_edges,
     resolve_extracted_edge,

@@ -139,7 +136,6 @@ class Graphiti:
         store_raw_episode_content: bool = True,
         graph_driver: GraphDriver | None = None,
         max_coroutines: int | None = None,
-        ensure_ascii: bool = False,
     ):
         """
         Initialize a Graphiti instance.

@@ -172,10 +168,6 @@ class Graphiti:
         max_coroutines : int | None, optional
             The maximum number of concurrent operations allowed. Overrides SEMAPHORE_LIMIT set in the environment.
             If not set, the Graphiti default is used.
-        ensure_ascii : bool, optional
-            Whether to escape non-ASCII characters in JSON serialization for prompts. Defaults to False.
-            Set as False to preserve non-ASCII characters (e.g., Korean, Japanese, Chinese) in their
-            original form, making them readable in LLM logs and improving model understanding.
 
         Returns
         -------

@@ -205,7 +197,6 @@ class Graphiti:
 
         self.store_raw_episode_content = store_raw_episode_content
         self.max_coroutines = max_coroutines
-        self.ensure_ascii = ensure_ascii
         if llm_client:
             self.llm_client = llm_client
         else:

@@ -224,7 +215,6 @@ class Graphiti:
             llm_client=self.llm_client,
             embedder=self.embedder,
             cross_encoder=self.cross_encoder,
-            ensure_ascii=self.ensure_ascii,
         )
 
         # Capture telemetry event

@@ -458,12 +448,12 @@ class Graphiti:
             start = time()
             now = utc_now()
 
-            # if group_id is None, use the default group id by the provider
-            group_id = group_id or get_default_group_id(self.driver.provider)
             validate_entity_types(entity_types)
 
             validate_excluded_entity_types(excluded_entity_types, entity_types)
             validate_group_id(group_id)
+            # if group_id is None, use the default group id by the provider
+            group_id = group_id or get_default_group_id(self.driver.provider)
 
             previous_episodes = (
                 await self.retrieve_episodes(

@@ -505,7 +495,7 @@ class Graphiti:
             )
 
             # Extract edges and resolve nodes
-            (nodes, uuid_map,
+            (nodes, uuid_map, _), extracted_edges = await semaphore_gather(
                 resolve_extracted_nodes(
                     self.clients,
                     extracted_nodes,

@@ -542,9 +532,7 @@ class Graphiti:
                 max_coroutines=self.max_coroutines,
             )
 
-
-
-            entity_edges = resolved_edges + invalidated_edges + duplicate_of_edges
+            entity_edges = resolved_edges + invalidated_edges
 
             episodic_edges = build_episodic_edges(nodes, episode.uuid, now)
 

@@ -564,9 +552,7 @@ class Graphiti:
             if update_communities:
                 communities, community_edges = await semaphore_gather(
                     *[
-                        update_community(
-                            self.driver, self.llm_client, self.embedder, node, self.ensure_ascii
-                        )
+                        update_community(self.driver, self.llm_client, self.embedder, node)
                         for node in nodes
                     ],
                     max_coroutines=self.max_coroutines,

@@ -1037,10 +1023,28 @@ class Graphiti:
 
         updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
 
-
+        valid_edges = await EntityEdge.get_between_nodes(
+            self.driver, edge.source_node_uuid, edge.target_node_uuid
+        )
+
+        related_edges = (
+            await search(
+                self.clients,
+                updated_edge.fact,
+                group_ids=[updated_edge.group_id],
+                config=EDGE_HYBRID_SEARCH_RRF,
+                search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
+            )
+        ).edges
         existing_edges = (
-            await
-
+            await search(
+                self.clients,
+                updated_edge.fact,
+                group_ids=[updated_edge.group_id],
+                config=EDGE_HYBRID_SEARCH_RRF,
+                search_filter=SearchFilters(),
+            )
+        ).edges
 
         resolved_edge, invalidated_edges, _ = await resolve_extracted_edge(
             self.llm_client,

@@ -1057,7 +1061,7 @@ class Graphiti:
                 group_id=edge.group_id,
             ),
             None,
-
+            None,
         )
 
         edges: list[EntityEdge] = [resolved_edge] + invalidated_edges
graphiti_core/graphiti_types.py
CHANGED
graphiti_core/helpers.py
CHANGED
@@ -54,7 +54,7 @@ def get_default_group_id(provider: GraphProvider) -> str:
     For most databases, the default group id is an empty string, while there are database types that require a specific default group id.
     """
     if provider == GraphProvider.FALKORDB:
-        return '_'
+        return '\\_'
     else:
         return ''
 

@@ -116,7 +116,7 @@ async def semaphore_gather(
     return await asyncio.gather(*(_wrap_coroutine(coroutine) for coroutine in coroutines))
 
 
-def validate_group_id(group_id: str) -> bool:
+def validate_group_id(group_id: str | None) -> bool:
     """
     Validate that a group_id contains only ASCII alphanumeric characters, dashes, and underscores.
 
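A small sketch of the changed FalkorDB default group id (the escaping rationale is an assumption; presumably the backslash keeps the underscore intact in FalkorDB fulltext queries):

from graphiti_core.driver.driver import GraphProvider
from graphiti_core.helpers import get_default_group_id

print(get_default_group_id(GraphProvider.FALKORDB))     # \_
print(repr(get_default_group_id(GraphProvider.NEO4J)))  # ''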
graphiti_core/llm_client/client.py
CHANGED

@@ -32,9 +32,23 @@ from .errors import RateLimitError
 DEFAULT_TEMPERATURE = 0
 DEFAULT_CACHE_DIR = './llm_cache'
 
-
-
-
+
+def get_extraction_language_instruction(group_id: str | None = None) -> str:
+    """Returns instruction for language extraction behavior.
+
+    Override this function to customize language extraction:
+    - Return empty string to disable multilingual instructions
+    - Return custom instructions for specific language requirements
+    - Use group_id to provide different instructions per group/partition
+
+    Args:
+        group_id: Optional partition identifier for the graph
+
+    Returns:
+        str: Language instruction to append to system messages
+    """
+    return '\n\nAny extracted information should be returned in the same language as it was written in.'
+
 
 logger = logging.getLogger(__name__)
 

@@ -132,6 +146,7 @@ class LLMClient(ABC):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
     ) -> dict[str, typing.Any]:
         if max_tokens is None:
             max_tokens = self.max_tokens

@@ -145,7 +160,7 @@ class LLMClient(ABC):
         )
 
         # Add multilingual extraction instructions
-        messages[0].content +=
+        messages[0].content += get_extraction_language_instruction(group_id)
 
         if self.cache_enabled and self.cache_dir is not None:
             cache_key = self._get_cache_key(messages)
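The docstring above describes get_extraction_language_instruction as an override point. One way to do that is to replace the module-level function; this is a hedged sketch of that pattern, not an API documented by the package, and modules that import the function by name (for example gemini_client) bind it at import time and would need to be patched separately.

import graphiti_core.llm_client.client as llm_client_module


def fixed_language_instruction(group_id: str | None = None) -> str:
    # Hypothetical override: force extracted output into one language regardless of input.
    return '\n\nReturn all extracted information in English.'


llm_client_module.get_extraction_language_instruction = fixed_language_instruction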
graphiti_core/llm_client/gemini_client.py
CHANGED

@@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, ClassVar
 from pydantic import BaseModel
 
 from ..prompts.models import Message
-from .client import
+from .client import LLMClient, get_extraction_language_instruction
 from .config import LLMConfig, ModelSize
 from .errors import RateLimitError
 

@@ -357,6 +357,7 @@ class GeminiClient(LLMClient):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
     ) -> dict[str, typing.Any]:
         """
         Generate a response from the Gemini language model with retry logic and error handling.

@@ -367,6 +368,7 @@ class GeminiClient(LLMClient):
             response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
             max_tokens (int | None): The maximum number of tokens to generate in the response.
             model_size (ModelSize): The size of the model to use (small or medium).
+            group_id (str | None): Optional partition identifier for the graph.
 
         Returns:
             dict[str, typing.Any]: The response from the language model.

@@ -376,7 +378,7 @@ class GeminiClient(LLMClient):
         last_output = None
 
         # Add multilingual extraction instructions
-        messages[0].content +=
+        messages[0].content += get_extraction_language_instruction(group_id)
 
         while retry_count < self.MAX_RETRIES:
             try: