graphiti-core: graphiti_core-0.20.4-py3-none-any.whl → graphiti_core-0.21.0rc1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core has been flagged as potentially problematic; see the package's registry page for details.
- graphiti_core/driver/driver.py +170 -0
- graphiti_core/driver/falkordb_driver.py +1 -0
- graphiti_core/driver/kuzu_driver.py +1 -0
- graphiti_core/driver/neo4j_driver.py +48 -2
- graphiti_core/driver/neptune_driver.py +25 -44
- graphiti_core/edges.py +18 -3
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/models/edges/edge_db_queries.py +36 -16
- graphiti_core/models/nodes/node_db_queries.py +30 -10
- graphiti_core/nodes.py +26 -18
- graphiti_core/search/search_filters.py +45 -0
- graphiti_core/search/search_utils.py +200 -56
- graphiti_core/utils/bulk_utils.py +15 -2
- graphiti_core/utils/maintenance/graph_data_operations.py +7 -5
- {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0rc1.dist-info}/METADATA +4 -1
- {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0rc1.dist-info}/RECORD +18 -18
- {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0rc1.dist-info}/WHEEL +0 -0
- {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0rc1.dist-info}/licenses/LICENSE +0 -0
graphiti_core/nodes.py
CHANGED
|
@@ -273,20 +273,6 @@ class EpisodicNode(Node):
|
|
|
273
273
|
)
|
|
274
274
|
|
|
275
275
|
async def save(self, driver: GraphDriver):
|
|
276
|
-
if driver.provider == GraphProvider.NEPTUNE:
|
|
277
|
-
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
|
278
|
-
'episode_content',
|
|
279
|
-
[
|
|
280
|
-
{
|
|
281
|
-
'uuid': self.uuid,
|
|
282
|
-
'group_id': self.group_id,
|
|
283
|
-
'source': self.source.value,
|
|
284
|
-
'content': self.content,
|
|
285
|
-
'source_description': self.source_description,
|
|
286
|
-
}
|
|
287
|
-
],
|
|
288
|
-
)
|
|
289
|
-
|
|
290
276
|
episode_args = {
|
|
291
277
|
'uuid': self.uuid,
|
|
292
278
|
'name': self.name,
|
|
@@ -299,6 +285,12 @@ class EpisodicNode(Node):
|
|
|
299
285
|
'source': self.source.value,
|
|
300
286
|
}
|
|
301
287
|
|
|
288
|
+
if driver.aoss_client:
|
|
289
|
+
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
|
290
|
+
'episodes',
|
|
291
|
+
[episode_args],
|
|
292
|
+
)
|
|
293
|
+
|
|
302
294
|
result = await driver.execute_query(
|
|
303
295
|
get_episode_node_save_query(driver.provider), **episode_args
|
|
304
296
|
)
|
|
@@ -433,6 +425,22 @@ class EntityNode(Node):
|
|
|
433
425
|
MATCH (n:Entity {uuid: $uuid})
|
|
434
426
|
RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
|
|
435
427
|
"""
|
|
428
|
+
elif driver.aoss_client:
|
|
429
|
+
resp = driver.aoss_client.search(
|
|
430
|
+
body={
|
|
431
|
+
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
|
|
432
|
+
'size': 1,
|
|
433
|
+
},
|
|
434
|
+
index='entities',
|
|
435
|
+
routing=self.group_id,
|
|
436
|
+
)
|
|
437
|
+
|
|
438
|
+
if resp['hits']['hits']:
|
|
439
|
+
self.name_embedding = resp['hits']['hits'][0]['_source']['name_embedding']
|
|
440
|
+
return
|
|
441
|
+
else:
|
|
442
|
+
raise NodeNotFoundError(self.uuid)
|
|
443
|
+
|
|
436
444
|
else:
|
|
437
445
|
query: LiteralString = """
|
|
438
446
|
MATCH (n:Entity {uuid: $uuid})
|
|
@@ -470,11 +478,11 @@ class EntityNode(Node):
|
|
|
470
478
|
entity_data.update(self.attributes or {})
|
|
471
479
|
labels = ':'.join(self.labels + ['Entity'])
|
|
472
480
|
|
|
473
|
-
if driver.
|
|
474
|
-
driver.save_to_aoss('
|
|
481
|
+
if driver.aoss_client:
|
|
482
|
+
driver.save_to_aoss('entities', [entity_data]) # pyright: ignore reportAttributeAccessIssue
|
|
475
483
|
|
|
476
484
|
result = await driver.execute_query(
|
|
477
|
-
get_entity_node_save_query(driver.provider, labels),
|
|
485
|
+
get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)),
|
|
478
486
|
entity_data=entity_data,
|
|
479
487
|
)
|
|
480
488
|
|
|
@@ -570,7 +578,7 @@ class CommunityNode(Node):
|
|
|
570
578
|
async def save(self, driver: GraphDriver):
|
|
571
579
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
572
580
|
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
|
573
|
-
'
|
|
581
|
+
'communities',
|
|
574
582
|
[{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
|
|
575
583
|
)
|
|
576
584
|
result = await driver.execute_query(
|
|
@@ -54,6 +54,16 @@ class SearchFilters(BaseModel):
|
|
|
54
54
|
expired_at: list[list[DateFilter]] | None = Field(default=None)
|
|
55
55
|
|
|
56
56
|
|
|
57
|
+
def cypher_to_opensearch_operator(op: ComparisonOperator) -> str:
|
|
58
|
+
mapping = {
|
|
59
|
+
ComparisonOperator.greater_than: 'gt',
|
|
60
|
+
ComparisonOperator.less_than: 'lt',
|
|
61
|
+
ComparisonOperator.greater_than_equal: 'gte',
|
|
62
|
+
ComparisonOperator.less_than_equal: 'lte',
|
|
63
|
+
}
|
|
64
|
+
return mapping.get(op, op.value)
|
|
65
|
+
|
|
66
|
+
|
|
57
67
|
def node_search_filter_query_constructor(
|
|
58
68
|
filters: SearchFilters,
|
|
59
69
|
provider: GraphProvider,
|
|
@@ -234,3 +244,38 @@ def edge_search_filter_query_constructor(
|
|
|
234
244
|
filter_queries.append(expired_at_filter)
|
|
235
245
|
|
|
236
246
|
return filter_queries, filter_params
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def build_aoss_node_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
|
|
250
|
+
filters = [{'terms': {'group_id': group_ids}}]
|
|
251
|
+
|
|
252
|
+
if search_filters.node_labels:
|
|
253
|
+
filters.append({'terms': {'node_labels': search_filters.node_labels}})
|
|
254
|
+
|
|
255
|
+
return filters
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def build_aoss_edge_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
|
|
259
|
+
filters: list[dict] = [{'terms': {'group_id': group_ids}}]
|
|
260
|
+
|
|
261
|
+
if search_filters.edge_types:
|
|
262
|
+
filters.append({'terms': {'edge_types': search_filters.edge_types}})
|
|
263
|
+
|
|
264
|
+
for field in ['valid_at', 'invalid_at', 'created_at', 'expired_at']:
|
|
265
|
+
ranges = getattr(search_filters, field)
|
|
266
|
+
if ranges:
|
|
267
|
+
# OR of ANDs
|
|
268
|
+
should_clauses = []
|
|
269
|
+
for and_group in ranges:
|
|
270
|
+
and_filters = []
|
|
271
|
+
for df in and_group: # df is a DateFilter
|
|
272
|
+
range_query = {
|
|
273
|
+
'range': {
|
|
274
|
+
field: {cypher_to_opensearch_operator(df.comparison_operator): df.date}
|
|
275
|
+
}
|
|
276
|
+
}
|
|
277
|
+
and_filters.append(range_query)
|
|
278
|
+
should_clauses.append({'bool': {'filter': and_filters}})
|
|
279
|
+
filters.append({'bool': {'should': should_clauses, 'minimum_should_match': 1}})
|
|
280
|
+
|
|
281
|
+
return filters
|
|
@@ -51,6 +51,8 @@ from graphiti_core.nodes import (
|
|
|
51
51
|
)
|
|
52
52
|
from graphiti_core.search.search_filters import (
|
|
53
53
|
SearchFilters,
|
|
54
|
+
build_aoss_edge_filters,
|
|
55
|
+
build_aoss_node_filters,
|
|
54
56
|
edge_search_filter_query_constructor,
|
|
55
57
|
node_search_filter_query_constructor,
|
|
56
58
|
)
|
|
@@ -200,7 +202,6 @@ async def edge_fulltext_search(
|
|
|
200
202
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
201
203
|
res = driver.run_aoss_query('edge_name_and_fact', query) # pyright: ignore reportAttributeAccessIssue
|
|
202
204
|
if res['hits']['total']['value'] > 0:
|
|
203
|
-
# Calculate Cosine similarity then return the edge ids
|
|
204
205
|
input_ids = []
|
|
205
206
|
for r in res['hits']['hits']:
|
|
206
207
|
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
|
|
@@ -208,11 +209,11 @@ async def edge_fulltext_search(
|
|
|
208
209
|
# Match the edge ids and return the values
|
|
209
210
|
query = (
|
|
210
211
|
"""
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
212
|
+
UNWIND $ids as id
|
|
213
|
+
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
214
|
+
WHERE e.group_id IN $group_ids
|
|
215
|
+
AND id(e)=id
|
|
216
|
+
"""
|
|
216
217
|
+ filter_query
|
|
217
218
|
+ """
|
|
218
219
|
AND id(e)=id
|
|
@@ -244,6 +245,31 @@ async def edge_fulltext_search(
|
|
|
244
245
|
)
|
|
245
246
|
else:
|
|
246
247
|
return []
|
|
248
|
+
elif driver.aoss_client:
|
|
249
|
+
route = group_ids[0] if group_ids else None
|
|
250
|
+
filters = build_aoss_edge_filters(group_ids or [], search_filter)
|
|
251
|
+
res = driver.aoss_client.search(
|
|
252
|
+
index='entity_edges',
|
|
253
|
+
routing=route,
|
|
254
|
+
_source=['uuid'],
|
|
255
|
+
query={
|
|
256
|
+
'bool': {
|
|
257
|
+
'filter': filters,
|
|
258
|
+
'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}],
|
|
259
|
+
}
|
|
260
|
+
},
|
|
261
|
+
)
|
|
262
|
+
if res['hits']['total']['value'] > 0:
|
|
263
|
+
input_uuids = {}
|
|
264
|
+
for r in res['hits']['hits']:
|
|
265
|
+
input_uuids[r['_source']['uuid']] = r['_score']
|
|
266
|
+
|
|
267
|
+
# Get edges
|
|
268
|
+
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
|
|
269
|
+
entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
|
270
|
+
return entity_edges
|
|
271
|
+
else:
|
|
272
|
+
return []
|
|
247
273
|
else:
|
|
248
274
|
query = (
|
|
249
275
|
get_relationships_query('edge_name_and_fact', limit=limit, provider=driver.provider)
|
|
@@ -318,8 +344,8 @@ async def edge_similarity_search(
|
|
|
318
344
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
319
345
|
query = (
|
|
320
346
|
"""
|
|
321
|
-
|
|
322
|
-
|
|
347
|
+
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
348
|
+
"""
|
|
323
349
|
+ filter_query
|
|
324
350
|
+ """
|
|
325
351
|
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
|
|
@@ -377,6 +403,32 @@ async def edge_similarity_search(
|
|
|
377
403
|
)
|
|
378
404
|
else:
|
|
379
405
|
return []
|
|
406
|
+
elif driver.aoss_client:
|
|
407
|
+
route = group_ids[0] if group_ids else None
|
|
408
|
+
filters = build_aoss_edge_filters(group_ids or [], search_filter)
|
|
409
|
+
res = driver.aoss_client.search(
|
|
410
|
+
index='entity_edges',
|
|
411
|
+
routing=route,
|
|
412
|
+
_source=['uuid'],
|
|
413
|
+
knn={
|
|
414
|
+
'field': 'fact_embedding',
|
|
415
|
+
'query_vector': search_vector,
|
|
416
|
+
'k': limit,
|
|
417
|
+
'num_candidates': 1000,
|
|
418
|
+
},
|
|
419
|
+
query={'bool': {'filter': filters}},
|
|
420
|
+
)
|
|
421
|
+
|
|
422
|
+
if res['hits']['total']['value'] > 0:
|
|
423
|
+
input_uuids = {}
|
|
424
|
+
for r in res['hits']['hits']:
|
|
425
|
+
input_uuids[r['_source']['uuid']] = r['_score']
|
|
426
|
+
|
|
427
|
+
# Get edges
|
|
428
|
+
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
|
|
429
|
+
entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
|
430
|
+
return entity_edges
|
|
431
|
+
|
|
380
432
|
else:
|
|
381
433
|
query = (
|
|
382
434
|
match_query
|
|
@@ -563,7 +615,6 @@ async def node_fulltext_search(
|
|
|
563
615
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
564
616
|
res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
|
|
565
617
|
if res['hits']['total']['value'] > 0:
|
|
566
|
-
# Calculate Cosine similarity then return the edge ids
|
|
567
618
|
input_ids = []
|
|
568
619
|
for r in res['hits']['hits']:
|
|
569
620
|
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
|
|
@@ -571,11 +622,11 @@ async def node_fulltext_search(
|
|
|
571
622
|
# Match the edge ides and return the values
|
|
572
623
|
query = (
|
|
573
624
|
"""
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
625
|
+
UNWIND $ids as i
|
|
626
|
+
MATCH (n:Entity)
|
|
627
|
+
WHERE n.uuid=i.id
|
|
628
|
+
RETURN
|
|
629
|
+
"""
|
|
579
630
|
+ get_entity_node_return_query(driver.provider)
|
|
580
631
|
+ """
|
|
581
632
|
ORDER BY i.score DESC
|
|
@@ -592,6 +643,41 @@ async def node_fulltext_search(
|
|
|
592
643
|
)
|
|
593
644
|
else:
|
|
594
645
|
return []
|
|
646
|
+
elif driver.aoss_client:
|
|
647
|
+
route = group_ids[0] if group_ids else None
|
|
648
|
+
filters = build_aoss_node_filters(group_ids or [], search_filter)
|
|
649
|
+
res = driver.aoss_client.search(
|
|
650
|
+
'entities',
|
|
651
|
+
routing=route,
|
|
652
|
+
_source=['uuid'],
|
|
653
|
+
query={
|
|
654
|
+
'bool': {
|
|
655
|
+
'filter': filters,
|
|
656
|
+
'must': [
|
|
657
|
+
{
|
|
658
|
+
'multi_match': {
|
|
659
|
+
'query': query,
|
|
660
|
+
'field': ['name', 'summary'],
|
|
661
|
+
'operator': 'or',
|
|
662
|
+
}
|
|
663
|
+
}
|
|
664
|
+
],
|
|
665
|
+
}
|
|
666
|
+
},
|
|
667
|
+
limit=limit,
|
|
668
|
+
)
|
|
669
|
+
|
|
670
|
+
if res['hits']['total']['value'] > 0:
|
|
671
|
+
input_uuids = {}
|
|
672
|
+
for r in res['hits']['hits']:
|
|
673
|
+
input_uuids[r['_source']['uuid']] = r['_score']
|
|
674
|
+
|
|
675
|
+
# Get nodes
|
|
676
|
+
entities = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
|
|
677
|
+
entities.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
|
678
|
+
return entities
|
|
679
|
+
else:
|
|
680
|
+
return []
|
|
595
681
|
else:
|
|
596
682
|
query = (
|
|
597
683
|
get_nodes_query(
|
|
@@ -648,8 +734,8 @@ async def node_similarity_search(
|
|
|
648
734
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
649
735
|
query = (
|
|
650
736
|
"""
|
|
651
|
-
|
|
652
|
-
|
|
737
|
+
MATCH (n:Entity)
|
|
738
|
+
"""
|
|
653
739
|
+ filter_query
|
|
654
740
|
+ """
|
|
655
741
|
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
|
@@ -678,11 +764,11 @@ async def node_similarity_search(
|
|
|
678
764
|
# Match the edge ides and return the values
|
|
679
765
|
query = (
|
|
680
766
|
"""
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
767
|
+
UNWIND $ids as i
|
|
768
|
+
MATCH (n:Entity)
|
|
769
|
+
WHERE id(n)=i.id
|
|
770
|
+
RETURN
|
|
771
|
+
"""
|
|
686
772
|
+ get_entity_node_return_query(driver.provider)
|
|
687
773
|
+ """
|
|
688
774
|
ORDER BY i.score DESC
|
|
@@ -700,11 +786,36 @@ async def node_similarity_search(
|
|
|
700
786
|
)
|
|
701
787
|
else:
|
|
702
788
|
return []
|
|
789
|
+
elif driver.aoss_client:
|
|
790
|
+
route = group_ids[0] if group_ids else None
|
|
791
|
+
filters = build_aoss_node_filters(group_ids or [], search_filter)
|
|
792
|
+
res = driver.aoss_client.search(
|
|
793
|
+
index='entities',
|
|
794
|
+
routing=route,
|
|
795
|
+
_source=['uuid'],
|
|
796
|
+
knn={
|
|
797
|
+
'field': 'fact_embedding',
|
|
798
|
+
'query_vector': search_vector,
|
|
799
|
+
'k': limit,
|
|
800
|
+
'num_candidates': 1000,
|
|
801
|
+
},
|
|
802
|
+
query={'bool': {'filter': filters}},
|
|
803
|
+
)
|
|
804
|
+
|
|
805
|
+
if res['hits']['total']['value'] > 0:
|
|
806
|
+
input_uuids = {}
|
|
807
|
+
for r in res['hits']['hits']:
|
|
808
|
+
input_uuids[r['_source']['uuid']] = r['_score']
|
|
809
|
+
|
|
810
|
+
# Get edges
|
|
811
|
+
entity_nodes = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
|
|
812
|
+
entity_nodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
|
813
|
+
return entity_nodes
|
|
703
814
|
else:
|
|
704
815
|
query = (
|
|
705
816
|
"""
|
|
706
|
-
|
|
707
|
-
|
|
817
|
+
MATCH (n:Entity)
|
|
818
|
+
"""
|
|
708
819
|
+ filter_query
|
|
709
820
|
+ """
|
|
710
821
|
WITH n, """
|
|
@@ -843,7 +954,6 @@ async def episode_fulltext_search(
|
|
|
843
954
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
844
955
|
res = driver.run_aoss_query('episode_content', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
|
|
845
956
|
if res['hits']['total']['value'] > 0:
|
|
846
|
-
# Calculate Cosine similarity then return the edge ids
|
|
847
957
|
input_ids = []
|
|
848
958
|
for r in res['hits']['hits']:
|
|
849
959
|
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
|
|
@@ -852,7 +962,7 @@ async def episode_fulltext_search(
|
|
|
852
962
|
query = """
|
|
853
963
|
UNWIND $ids as i
|
|
854
964
|
MATCH (e:Episodic)
|
|
855
|
-
WHERE e.uuid=i.
|
|
965
|
+
WHERE e.uuid=i.uuid
|
|
856
966
|
RETURN
|
|
857
967
|
e.content AS content,
|
|
858
968
|
e.created_at AS created_at,
|
|
@@ -876,6 +986,40 @@ async def episode_fulltext_search(
|
|
|
876
986
|
)
|
|
877
987
|
else:
|
|
878
988
|
return []
|
|
989
|
+
elif driver.aoss_client:
|
|
990
|
+
route = group_ids[0] if group_ids else None
|
|
991
|
+
res = driver.aoss_client.search(
|
|
992
|
+
'episodes',
|
|
993
|
+
routing=route,
|
|
994
|
+
_source=['uuid'],
|
|
995
|
+
query={
|
|
996
|
+
'bool': {
|
|
997
|
+
'filter': {'terms': group_ids},
|
|
998
|
+
'must': [
|
|
999
|
+
{
|
|
1000
|
+
'multi_match': {
|
|
1001
|
+
'query': query,
|
|
1002
|
+
'field': ['name', 'content'],
|
|
1003
|
+
'operator': 'or',
|
|
1004
|
+
}
|
|
1005
|
+
}
|
|
1006
|
+
],
|
|
1007
|
+
}
|
|
1008
|
+
},
|
|
1009
|
+
limit=limit,
|
|
1010
|
+
)
|
|
1011
|
+
|
|
1012
|
+
if res['hits']['total']['value'] > 0:
|
|
1013
|
+
input_uuids = {}
|
|
1014
|
+
for r in res['hits']['hits']:
|
|
1015
|
+
input_uuids[r['_source']['uuid']] = r['_score']
|
|
1016
|
+
|
|
1017
|
+
# Get nodes
|
|
1018
|
+
episodes = await EpisodicNode.get_by_uuids(driver, list(input_uuids.keys()))
|
|
1019
|
+
episodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
|
1020
|
+
return episodes
|
|
1021
|
+
else:
|
|
1022
|
+
return []
|
|
879
1023
|
else:
|
|
880
1024
|
query = (
|
|
881
1025
|
get_nodes_query('episode_content', '$query', limit=limit, provider=driver.provider)
|
|
@@ -1003,8 +1147,8 @@ async def community_similarity_search(
|
|
|
1003
1147
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
1004
1148
|
query = (
|
|
1005
1149
|
"""
|
|
1006
|
-
|
|
1007
|
-
|
|
1150
|
+
MATCH (n:Community)
|
|
1151
|
+
"""
|
|
1008
1152
|
+ group_filter_query
|
|
1009
1153
|
+ """
|
|
1010
1154
|
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
|
@@ -1063,8 +1207,8 @@ async def community_similarity_search(
|
|
|
1063
1207
|
|
|
1064
1208
|
query = (
|
|
1065
1209
|
"""
|
|
1066
|
-
|
|
1067
|
-
|
|
1210
|
+
MATCH (c:Community)
|
|
1211
|
+
"""
|
|
1068
1212
|
+ group_filter_query
|
|
1069
1213
|
+ """
|
|
1070
1214
|
WITH c,
|
|
@@ -1206,9 +1350,9 @@ async def get_relevant_nodes(
|
|
|
1206
1350
|
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
|
|
1207
1351
|
query = (
|
|
1208
1352
|
"""
|
|
1209
|
-
|
|
1210
|
-
|
|
1211
|
-
|
|
1353
|
+
UNWIND $nodes AS node
|
|
1354
|
+
MATCH (n:Entity {group_id: $group_id})
|
|
1355
|
+
"""
|
|
1212
1356
|
+ filter_query
|
|
1213
1357
|
+ """
|
|
1214
1358
|
WITH node, n, """
|
|
@@ -1253,9 +1397,9 @@ async def get_relevant_nodes(
|
|
|
1253
1397
|
else:
|
|
1254
1398
|
query = (
|
|
1255
1399
|
"""
|
|
1256
|
-
|
|
1257
|
-
|
|
1258
|
-
|
|
1400
|
+
UNWIND $nodes AS node
|
|
1401
|
+
MATCH (n:Entity {group_id: $group_id})
|
|
1402
|
+
"""
|
|
1259
1403
|
+ filter_query
|
|
1260
1404
|
+ """
|
|
1261
1405
|
WITH node, n, """
|
|
@@ -1344,9 +1488,9 @@ async def get_relevant_edges(
|
|
|
1344
1488
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
1345
1489
|
query = (
|
|
1346
1490
|
"""
|
|
1347
|
-
|
|
1348
|
-
|
|
1349
|
-
|
|
1491
|
+
UNWIND $edges AS edge
|
|
1492
|
+
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
|
1493
|
+
"""
|
|
1350
1494
|
+ filter_query
|
|
1351
1495
|
+ """
|
|
1352
1496
|
WITH e, edge
|
|
@@ -1416,9 +1560,9 @@ async def get_relevant_edges(
|
|
|
1416
1560
|
|
|
1417
1561
|
query = (
|
|
1418
1562
|
"""
|
|
1419
|
-
|
|
1420
|
-
|
|
1421
|
-
|
|
1563
|
+
UNWIND $edges AS edge
|
|
1564
|
+
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
|
|
1565
|
+
"""
|
|
1422
1566
|
+ filter_query
|
|
1423
1567
|
+ """
|
|
1424
1568
|
WITH e, edge, n, m, """
|
|
@@ -1454,9 +1598,9 @@ async def get_relevant_edges(
|
|
|
1454
1598
|
else:
|
|
1455
1599
|
query = (
|
|
1456
1600
|
"""
|
|
1457
|
-
|
|
1458
|
-
|
|
1459
|
-
|
|
1601
|
+
UNWIND $edges AS edge
|
|
1602
|
+
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
|
1603
|
+
"""
|
|
1460
1604
|
+ filter_query
|
|
1461
1605
|
+ """
|
|
1462
1606
|
WITH e, edge, """
|
|
@@ -1529,10 +1673,10 @@ async def get_edge_invalidation_candidates(
|
|
|
1529
1673
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
1530
1674
|
query = (
|
|
1531
1675
|
"""
|
|
1532
|
-
|
|
1533
|
-
|
|
1534
|
-
|
|
1535
|
-
|
|
1676
|
+
UNWIND $edges AS edge
|
|
1677
|
+
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
|
1678
|
+
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
|
1679
|
+
"""
|
|
1536
1680
|
+ filter_query
|
|
1537
1681
|
+ """
|
|
1538
1682
|
WITH e, edge
|
|
@@ -1602,10 +1746,10 @@ async def get_edge_invalidation_candidates(
|
|
|
1602
1746
|
|
|
1603
1747
|
query = (
|
|
1604
1748
|
"""
|
|
1605
|
-
|
|
1606
|
-
|
|
1607
|
-
|
|
1608
|
-
|
|
1749
|
+
UNWIND $edges AS edge
|
|
1750
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
|
|
1751
|
+
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
|
|
1752
|
+
"""
|
|
1609
1753
|
+ filter_query
|
|
1610
1754
|
+ """
|
|
1611
1755
|
WITH edge, e, n, m, """
|
|
@@ -1641,10 +1785,10 @@ async def get_edge_invalidation_candidates(
|
|
|
1641
1785
|
else:
|
|
1642
1786
|
query = (
|
|
1643
1787
|
"""
|
|
1644
|
-
|
|
1645
|
-
|
|
1646
|
-
|
|
1647
|
-
|
|
1788
|
+
UNWIND $edges AS edge
|
|
1789
|
+
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
|
1790
|
+
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
|
1791
|
+
"""
|
|
1648
1792
|
+ filter_query
|
|
1649
1793
|
+ """
|
|
1650
1794
|
WITH edge, e, """
|
|
@@ -187,12 +187,25 @@ async def add_nodes_and_edges_bulk_tx(
|
|
|
187
187
|
await tx.run(episodic_edge_query, **edge.model_dump())
|
|
188
188
|
else:
|
|
189
189
|
await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
|
|
190
|
-
await tx.run(
|
|
190
|
+
await tx.run(
|
|
191
|
+
get_entity_node_save_bulk_query(driver.provider, nodes),
|
|
192
|
+
nodes=nodes,
|
|
193
|
+
has_aoss=bool(driver.aoss_client),
|
|
194
|
+
)
|
|
191
195
|
await tx.run(
|
|
192
196
|
get_episodic_edge_save_bulk_query(driver.provider),
|
|
193
197
|
episodic_edges=[edge.model_dump() for edge in episodic_edges],
|
|
194
198
|
)
|
|
195
|
-
await tx.run(
|
|
199
|
+
await tx.run(
|
|
200
|
+
get_entity_edge_save_bulk_query(driver.provider),
|
|
201
|
+
entity_edges=edges,
|
|
202
|
+
has_aoss=bool(driver.aoss_client),
|
|
203
|
+
)
|
|
204
|
+
|
|
205
|
+
if driver.aoss_client:
|
|
206
|
+
driver.save_to_aoss('episodes', episodes)
|
|
207
|
+
driver.save_to_aoss('entities', nodes)
|
|
208
|
+
driver.save_to_aoss('entity_edges', edges)
|
|
196
209
|
|
|
197
210
|
|
|
198
211
|
async def extract_nodes_and_edges_bulk(
|
|
@@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
|
|
|
34
34
|
|
|
35
35
|
|
|
36
36
|
async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
|
|
37
|
-
if driver.
|
|
37
|
+
if driver.aoss_client:
|
|
38
38
|
await driver.create_aoss_indices() # pyright: ignore[reportAttributeAccessIssue]
|
|
39
39
|
return
|
|
40
40
|
if delete_existing:
|
|
@@ -56,7 +56,9 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo
|
|
|
56
56
|
|
|
57
57
|
range_indices: list[LiteralString] = get_range_indices(driver.provider)
|
|
58
58
|
|
|
59
|
-
|
|
59
|
+
# Don't create fulltext indices if OpenSearch is being used
|
|
60
|
+
if not driver.aoss_client:
|
|
61
|
+
fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
|
|
60
62
|
|
|
61
63
|
if driver.provider == GraphProvider.KUZU:
|
|
62
64
|
# Skip creating fulltext indices if they already exist. Need to do this manually
|
|
@@ -149,9 +151,9 @@ async def retrieve_episodes(
|
|
|
149
151
|
|
|
150
152
|
query: LiteralString = (
|
|
151
153
|
"""
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
154
|
+
MATCH (e:Episodic)
|
|
155
|
+
WHERE e.valid_at <= $reference_time
|
|
156
|
+
"""
|
|
155
157
|
+ query_filter
|
|
156
158
|
+ """
|
|
157
159
|
RETURN
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: graphiti-core
|
|
3
|
-
Version: 0.20.4
|
|
3
|
+
Version: 0.21.0rc1
|
|
4
4
|
Summary: A temporal graph building library
|
|
5
5
|
Project-URL: Homepage, https://help.getzep.com/graphiti/graphiti/overview
|
|
6
6
|
Project-URL: Repository, https://github.com/getzep/graphiti
|
|
@@ -47,6 +47,9 @@ Provides-Extra: groq
|
|
|
47
47
|
Requires-Dist: groq>=0.2.0; extra == 'groq'
|
|
48
48
|
Provides-Extra: kuzu
|
|
49
49
|
Requires-Dist: kuzu>=0.11.2; extra == 'kuzu'
|
|
50
|
+
Provides-Extra: neo4j-opensearch
|
|
51
|
+
Requires-Dist: boto3>=1.39.16; extra == 'neo4j-opensearch'
|
|
52
|
+
Requires-Dist: opensearch-py>=3.0.0; extra == 'neo4j-opensearch'
|
|
50
53
|
Provides-Extra: neptune
|
|
51
54
|
Requires-Dist: boto3>=1.39.16; extra == 'neptune'
|
|
52
55
|
Requires-Dist: langchain-aws>=0.2.29; extra == 'neptune'
|