graphiti-core 0.19.0rc3__py3-none-any.whl → 0.20.0__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. The original page linked to further details, which are not included in this copy.
- graphiti_core/driver/driver.py +3 -0
- graphiti_core/driver/falkordb_driver.py +3 -14
- graphiti_core/driver/kuzu_driver.py +175 -0
- graphiti_core/driver/neptune_driver.py +2 -0
- graphiti_core/edges.py +148 -83
- graphiti_core/graph_queries.py +31 -2
- graphiti_core/graphiti.py +4 -1
- graphiti_core/helpers.py +7 -12
- graphiti_core/migrations/neo4j_node_group_labels.py +33 -4
- graphiti_core/models/edges/edge_db_queries.py +121 -42
- graphiti_core/models/nodes/node_db_queries.py +102 -23
- graphiti_core/nodes.py +169 -66
- graphiti_core/search/search.py +13 -3
- graphiti_core/search/search_config.py +4 -0
- graphiti_core/search/search_filters.py +35 -22
- graphiti_core/search/search_utils.py +693 -382
- graphiti_core/utils/bulk_utils.py +50 -18
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +39 -32
- graphiti_core/utils/maintenance/edge_operations.py +19 -8
- graphiti_core/utils/maintenance/graph_data_operations.py +77 -47
- {graphiti_core-0.19.0rc3.dist-info → graphiti_core-0.20.0.dist-info}/METADATA +116 -48
- {graphiti_core-0.19.0rc3.dist-info → graphiti_core-0.20.0.dist-info}/RECORD +25 -24
- {graphiti_core-0.19.0rc3.dist-info → graphiti_core-0.20.0.dist-info}/WHEEL +0 -0
- {graphiti_core-0.19.0rc3.dist-info → graphiti_core-0.20.0.dist-info}/licenses/LICENSE +0 -0
graphiti_core/nodes.py
CHANGED
|
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
+
import json
|
|
17
18
|
import logging
|
|
18
19
|
from abc import ABC, abstractmethod
|
|
19
20
|
from datetime import datetime
|
|
@@ -32,10 +33,10 @@ from graphiti_core.helpers import parse_db_date
|
|
|
32
33
|
from graphiti_core.models.nodes.node_db_queries import (
|
|
33
34
|
COMMUNITY_NODE_RETURN,
|
|
34
35
|
COMMUNITY_NODE_RETURN_NEPTUNE,
|
|
35
|
-
ENTITY_NODE_RETURN,
|
|
36
36
|
EPISODIC_NODE_RETURN,
|
|
37
37
|
EPISODIC_NODE_RETURN_NEPTUNE,
|
|
38
38
|
get_community_node_save_query,
|
|
39
|
+
get_entity_node_return_query,
|
|
39
40
|
get_entity_node_save_query,
|
|
40
41
|
get_episode_node_save_query,
|
|
41
42
|
)
|
|
@@ -95,12 +96,37 @@ class Node(BaseModel, ABC):
|
|
|
95
96
|
case GraphProvider.NEO4J:
|
|
96
97
|
await driver.execute_query(
|
|
97
98
|
"""
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
99
|
+
MATCH (n:Entity|Episodic|Community {uuid: $uuid})
|
|
100
|
+
DETACH DELETE n
|
|
101
|
+
""",
|
|
102
|
+
uuid=self.uuid,
|
|
103
|
+
)
|
|
104
|
+
case GraphProvider.KUZU:
|
|
105
|
+
for label in ['Episodic', 'Community']:
|
|
106
|
+
await driver.execute_query(
|
|
107
|
+
f"""
|
|
108
|
+
MATCH (n:{label} {{uuid: $uuid}})
|
|
109
|
+
DETACH DELETE n
|
|
110
|
+
""",
|
|
111
|
+
uuid=self.uuid,
|
|
112
|
+
)
|
|
113
|
+
# Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
|
|
114
|
+
# Explicitly delete the "edge" nodes first, then the entity node.
|
|
115
|
+
await driver.execute_query(
|
|
116
|
+
"""
|
|
117
|
+
MATCH (n:Entity {uuid: $uuid})-[:RELATES_TO]->(e:RelatesToNode_)
|
|
118
|
+
DETACH DELETE e
|
|
119
|
+
""",
|
|
101
120
|
uuid=self.uuid,
|
|
102
121
|
)
|
|
103
|
-
|
|
122
|
+
await driver.execute_query(
|
|
123
|
+
"""
|
|
124
|
+
MATCH (n:Entity {uuid: $uuid})
|
|
125
|
+
DETACH DELETE n
|
|
126
|
+
""",
|
|
127
|
+
uuid=self.uuid,
|
|
128
|
+
)
|
|
129
|
+
case _: # FalkorDB, Neptune
|
|
104
130
|
for label in ['Entity', 'Episodic', 'Community']:
|
|
105
131
|
await driver.execute_query(
|
|
106
132
|
f"""
|
|
@@ -136,8 +162,32 @@ class Node(BaseModel, ABC):
|
|
|
136
162
|
group_id=group_id,
|
|
137
163
|
batch_size=batch_size,
|
|
138
164
|
)
|
|
139
|
-
|
|
140
|
-
|
|
165
|
+
case GraphProvider.KUZU:
|
|
166
|
+
for label in ['Episodic', 'Community']:
|
|
167
|
+
await driver.execute_query(
|
|
168
|
+
f"""
|
|
169
|
+
MATCH (n:{label} {{group_id: $group_id}})
|
|
170
|
+
DETACH DELETE n
|
|
171
|
+
""",
|
|
172
|
+
group_id=group_id,
|
|
173
|
+
)
|
|
174
|
+
# Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
|
|
175
|
+
# Explicitly delete the "edge" nodes first, then the entity node.
|
|
176
|
+
await driver.execute_query(
|
|
177
|
+
"""
|
|
178
|
+
MATCH (n:Entity {group_id: $group_id})-[:RELATES_TO]->(e:RelatesToNode_)
|
|
179
|
+
DETACH DELETE e
|
|
180
|
+
""",
|
|
181
|
+
group_id=group_id,
|
|
182
|
+
)
|
|
183
|
+
await driver.execute_query(
|
|
184
|
+
"""
|
|
185
|
+
MATCH (n:Entity {group_id: $group_id})
|
|
186
|
+
DETACH DELETE n
|
|
187
|
+
""",
|
|
188
|
+
group_id=group_id,
|
|
189
|
+
)
|
|
190
|
+
case _: # FalkorDB, Neptune
|
|
141
191
|
for label in ['Entity', 'Episodic', 'Community']:
|
|
142
192
|
await driver.execute_query(
|
|
143
193
|
f"""
|
|
@@ -149,30 +199,59 @@ class Node(BaseModel, ABC):
|
|
|
149
199
|
|
|
150
200
|
@classmethod
|
|
151
201
|
async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str], batch_size: int = 100):
|
|
152
|
-
|
|
153
|
-
|
|
202
|
+
match driver.provider:
|
|
203
|
+
case GraphProvider.FALKORDB:
|
|
204
|
+
for label in ['Entity', 'Episodic', 'Community']:
|
|
205
|
+
await driver.execute_query(
|
|
206
|
+
f"""
|
|
207
|
+
MATCH (n:{label})
|
|
208
|
+
WHERE n.uuid IN $uuids
|
|
209
|
+
DETACH DELETE n
|
|
210
|
+
""",
|
|
211
|
+
uuids=uuids,
|
|
212
|
+
)
|
|
213
|
+
case GraphProvider.KUZU:
|
|
214
|
+
for label in ['Episodic', 'Community']:
|
|
215
|
+
await driver.execute_query(
|
|
216
|
+
f"""
|
|
217
|
+
MATCH (n:{label})
|
|
218
|
+
WHERE n.uuid IN $uuids
|
|
219
|
+
DETACH DELETE n
|
|
220
|
+
""",
|
|
221
|
+
uuids=uuids,
|
|
222
|
+
)
|
|
223
|
+
# Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
|
|
224
|
+
# Explicitly delete the "edge" nodes first, then the entity node.
|
|
154
225
|
await driver.execute_query(
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
226
|
+
"""
|
|
227
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)
|
|
228
|
+
WHERE n.uuid IN $uuids
|
|
229
|
+
DETACH DELETE e
|
|
230
|
+
""",
|
|
160
231
|
uuids=uuids,
|
|
161
232
|
)
|
|
162
|
-
|
|
163
|
-
async with driver.session() as session:
|
|
164
|
-
await session.run(
|
|
233
|
+
await driver.execute_query(
|
|
165
234
|
"""
|
|
166
|
-
MATCH (n:Entity
|
|
235
|
+
MATCH (n:Entity)
|
|
167
236
|
WHERE n.uuid IN $uuids
|
|
168
|
-
|
|
169
|
-
WITH n
|
|
170
|
-
DETACH DELETE n
|
|
171
|
-
} IN TRANSACTIONS OF $batch_size ROWS
|
|
237
|
+
DETACH DELETE n
|
|
172
238
|
""",
|
|
173
239
|
uuids=uuids,
|
|
174
|
-
batch_size=batch_size,
|
|
175
240
|
)
|
|
241
|
+
case _: # Neo4J, Neptune
|
|
242
|
+
async with driver.session() as session:
|
|
243
|
+
await session.run(
|
|
244
|
+
"""
|
|
245
|
+
MATCH (n:Entity|Episodic|Community)
|
|
246
|
+
WHERE n.uuid IN $uuids
|
|
247
|
+
CALL {
|
|
248
|
+
WITH n
|
|
249
|
+
DETACH DELETE n
|
|
250
|
+
} IN TRANSACTIONS OF $batch_size ROWS
|
|
251
|
+
""",
|
|
252
|
+
uuids=uuids,
|
|
253
|
+
batch_size=batch_size,
|
|
254
|
+
)
|
|
176
255
|
|
|
177
256
|
@classmethod
|
|
178
257
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
|
|
@@ -207,18 +286,24 @@ class EpisodicNode(Node):
|
|
|
207
286
|
}
|
|
208
287
|
],
|
|
209
288
|
)
|
|
289
|
+
|
|
290
|
+
episode_args = {
|
|
291
|
+
'uuid': self.uuid,
|
|
292
|
+
'name': self.name,
|
|
293
|
+
'group_id': self.group_id,
|
|
294
|
+
'source_description': self.source_description,
|
|
295
|
+
'content': self.content,
|
|
296
|
+
'entity_edges': self.entity_edges,
|
|
297
|
+
'created_at': self.created_at,
|
|
298
|
+
'valid_at': self.valid_at,
|
|
299
|
+
'source': self.source.value,
|
|
300
|
+
}
|
|
301
|
+
|
|
302
|
+
if driver.provider == GraphProvider.NEO4J:
|
|
303
|
+
episode_args['group_label'] = 'Episodic_' + self.group_id.replace('-', '')
|
|
304
|
+
|
|
210
305
|
result = await driver.execute_query(
|
|
211
|
-
get_episode_node_save_query(driver.provider),
|
|
212
|
-
uuid=self.uuid,
|
|
213
|
-
name=self.name,
|
|
214
|
-
group_id=self.group_id,
|
|
215
|
-
group_label='Episodic_' + self.group_id.replace('-', ''),
|
|
216
|
-
source_description=self.source_description,
|
|
217
|
-
content=self.content,
|
|
218
|
-
entity_edges=self.entity_edges,
|
|
219
|
-
created_at=self.created_at,
|
|
220
|
-
valid_at=self.valid_at,
|
|
221
|
-
source=self.source.value,
|
|
306
|
+
get_episode_node_save_query(driver.provider), **episode_args
|
|
222
307
|
)
|
|
223
308
|
|
|
224
309
|
logger.debug(f'Saved Node to Graph: {self.uuid}')
|
|
@@ -376,17 +461,25 @@ class EntityNode(Node):
|
|
|
376
461
|
'summary': self.summary,
|
|
377
462
|
'created_at': self.created_at,
|
|
378
463
|
}
|
|
379
|
-
entity_data.update(self.attributes or {})
|
|
380
464
|
|
|
381
|
-
if driver.provider == GraphProvider.
|
|
382
|
-
|
|
465
|
+
if driver.provider == GraphProvider.KUZU:
|
|
466
|
+
entity_data['attributes'] = json.dumps(self.attributes)
|
|
467
|
+
entity_data['labels'] = list(set(self.labels + ['Entity']))
|
|
468
|
+
result = await driver.execute_query(
|
|
469
|
+
get_entity_node_save_query(driver.provider, labels=''),
|
|
470
|
+
**entity_data,
|
|
471
|
+
)
|
|
472
|
+
else:
|
|
473
|
+
entity_data.update(self.attributes or {})
|
|
474
|
+
labels = ':'.join(self.labels + ['Entity', 'Entity_' + self.group_id.replace('-', '')])
|
|
383
475
|
|
|
384
|
-
|
|
476
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
477
|
+
driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue
|
|
385
478
|
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
479
|
+
result = await driver.execute_query(
|
|
480
|
+
get_entity_node_save_query(driver.provider, labels),
|
|
481
|
+
entity_data=entity_data,
|
|
482
|
+
)
|
|
390
483
|
|
|
391
484
|
logger.debug(f'Saved Node to Graph: {self.uuid}')
|
|
392
485
|
|
|
@@ -399,12 +492,12 @@ class EntityNode(Node):
|
|
|
399
492
|
MATCH (n:Entity {uuid: $uuid})
|
|
400
493
|
RETURN
|
|
401
494
|
"""
|
|
402
|
-
+
|
|
495
|
+
+ get_entity_node_return_query(driver.provider),
|
|
403
496
|
uuid=uuid,
|
|
404
497
|
routing_='r',
|
|
405
498
|
)
|
|
406
499
|
|
|
407
|
-
nodes = [get_entity_node_from_record(record) for record in records]
|
|
500
|
+
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
|
|
408
501
|
|
|
409
502
|
if len(nodes) == 0:
|
|
410
503
|
raise NodeNotFoundError(uuid)
|
|
@@ -419,12 +512,12 @@ class EntityNode(Node):
|
|
|
419
512
|
WHERE n.uuid IN $uuids
|
|
420
513
|
RETURN
|
|
421
514
|
"""
|
|
422
|
-
+
|
|
515
|
+
+ get_entity_node_return_query(driver.provider),
|
|
423
516
|
uuids=uuids,
|
|
424
517
|
routing_='r',
|
|
425
518
|
)
|
|
426
519
|
|
|
427
|
-
nodes = [get_entity_node_from_record(record) for record in records]
|
|
520
|
+
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
|
|
428
521
|
|
|
429
522
|
return nodes
|
|
430
523
|
|
|
@@ -456,7 +549,7 @@ class EntityNode(Node):
|
|
|
456
549
|
+ """
|
|
457
550
|
RETURN
|
|
458
551
|
"""
|
|
459
|
-
+
|
|
552
|
+
+ get_entity_node_return_query(driver.provider)
|
|
460
553
|
+ with_embeddings_query
|
|
461
554
|
+ """
|
|
462
555
|
ORDER BY n.uuid DESC
|
|
@@ -468,7 +561,7 @@ class EntityNode(Node):
|
|
|
468
561
|
routing_='r',
|
|
469
562
|
)
|
|
470
563
|
|
|
471
|
-
nodes = [get_entity_node_from_record(record) for record in records]
|
|
564
|
+
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
|
|
472
565
|
|
|
473
566
|
return nodes
|
|
474
567
|
|
|
@@ -533,7 +626,7 @@ class CommunityNode(Node):
|
|
|
533
626
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
534
627
|
records, _, _ = await driver.execute_query(
|
|
535
628
|
"""
|
|
536
|
-
MATCH (
|
|
629
|
+
MATCH (c:Community {uuid: $uuid})
|
|
537
630
|
RETURN
|
|
538
631
|
"""
|
|
539
632
|
+ (
|
|
@@ -556,8 +649,8 @@ class CommunityNode(Node):
|
|
|
556
649
|
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
557
650
|
records, _, _ = await driver.execute_query(
|
|
558
651
|
"""
|
|
559
|
-
MATCH (
|
|
560
|
-
WHERE
|
|
652
|
+
MATCH (c:Community)
|
|
653
|
+
WHERE c.uuid IN $uuids
|
|
561
654
|
RETURN
|
|
562
655
|
"""
|
|
563
656
|
+ (
|
|
@@ -581,13 +674,13 @@ class CommunityNode(Node):
|
|
|
581
674
|
limit: int | None = None,
|
|
582
675
|
uuid_cursor: str | None = None,
|
|
583
676
|
):
|
|
584
|
-
cursor_query: LiteralString = 'AND
|
|
677
|
+
cursor_query: LiteralString = 'AND c.uuid < $uuid' if uuid_cursor else ''
|
|
585
678
|
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
|
|
586
679
|
|
|
587
680
|
records, _, _ = await driver.execute_query(
|
|
588
681
|
"""
|
|
589
|
-
MATCH (
|
|
590
|
-
WHERE
|
|
682
|
+
MATCH (c:Community)
|
|
683
|
+
WHERE c.group_id IN $group_ids
|
|
591
684
|
"""
|
|
592
685
|
+ cursor_query
|
|
593
686
|
+ """
|
|
@@ -599,7 +692,7 @@ class CommunityNode(Node):
|
|
|
599
692
|
else COMMUNITY_NODE_RETURN
|
|
600
693
|
)
|
|
601
694
|
+ """
|
|
602
|
-
ORDER BY
|
|
695
|
+
ORDER BY c.uuid DESC
|
|
603
696
|
"""
|
|
604
697
|
+ limit_query,
|
|
605
698
|
group_ids=group_ids,
|
|
@@ -636,25 +729,35 @@ def get_episodic_node_from_record(record: Any) -> EpisodicNode:
|
|
|
636
729
|
)
|
|
637
730
|
|
|
638
731
|
|
|
639
|
-
def get_entity_node_from_record(record: Any) -> EntityNode:
|
|
732
|
+
def get_entity_node_from_record(record: Any, provider: GraphProvider) -> EntityNode:
|
|
733
|
+
if provider == GraphProvider.KUZU:
|
|
734
|
+
attributes = json.loads(record['attributes']) if record['attributes'] else {}
|
|
735
|
+
else:
|
|
736
|
+
attributes = record['attributes']
|
|
737
|
+
attributes.pop('uuid', None)
|
|
738
|
+
attributes.pop('name', None)
|
|
739
|
+
attributes.pop('group_id', None)
|
|
740
|
+
attributes.pop('name_embedding', None)
|
|
741
|
+
attributes.pop('summary', None)
|
|
742
|
+
attributes.pop('created_at', None)
|
|
743
|
+
attributes.pop('labels', None)
|
|
744
|
+
|
|
745
|
+
labels = record.get('labels', [])
|
|
746
|
+
group_id = record.get('group_id')
|
|
747
|
+
if 'Entity_' + group_id.replace('-', '') in labels:
|
|
748
|
+
labels.remove('Entity_' + group_id.replace('-', ''))
|
|
749
|
+
|
|
640
750
|
entity_node = EntityNode(
|
|
641
751
|
uuid=record['uuid'],
|
|
642
752
|
name=record['name'],
|
|
643
753
|
name_embedding=record.get('name_embedding'),
|
|
644
|
-
group_id=
|
|
645
|
-
labels=
|
|
754
|
+
group_id=group_id,
|
|
755
|
+
labels=labels,
|
|
646
756
|
created_at=parse_db_date(record['created_at']), # type: ignore
|
|
647
757
|
summary=record['summary'],
|
|
648
|
-
attributes=
|
|
758
|
+
attributes=attributes,
|
|
649
759
|
)
|
|
650
760
|
|
|
651
|
-
entity_node.attributes.pop('uuid', None)
|
|
652
|
-
entity_node.attributes.pop('name', None)
|
|
653
|
-
entity_node.attributes.pop('group_id', None)
|
|
654
|
-
entity_node.attributes.pop('name_embedding', None)
|
|
655
|
-
entity_node.attributes.pop('summary', None)
|
|
656
|
-
entity_node.attributes.pop('created_at', None)
|
|
657
|
-
|
|
658
761
|
return entity_node
|
|
659
762
|
|
|
660
763
|
|
graphiti_core/search/search.py
CHANGED
|
@@ -325,12 +325,20 @@ async def node_search(
|
|
|
325
325
|
search_tasks = []
|
|
326
326
|
if NodeSearchMethod.bm25 in config.search_methods:
|
|
327
327
|
search_tasks.append(
|
|
328
|
-
node_fulltext_search(
|
|
328
|
+
node_fulltext_search(
|
|
329
|
+
driver, query, search_filter, group_ids, 2 * limit, config.use_local_indexes
|
|
330
|
+
)
|
|
329
331
|
)
|
|
330
332
|
if NodeSearchMethod.cosine_similarity in config.search_methods:
|
|
331
333
|
search_tasks.append(
|
|
332
334
|
node_similarity_search(
|
|
333
|
-
driver,
|
|
335
|
+
driver,
|
|
336
|
+
query_vector,
|
|
337
|
+
search_filter,
|
|
338
|
+
group_ids,
|
|
339
|
+
2 * limit,
|
|
340
|
+
config.sim_min_score,
|
|
341
|
+
config.use_local_indexes,
|
|
334
342
|
)
|
|
335
343
|
)
|
|
336
344
|
if NodeSearchMethod.bfs in config.search_methods:
|
|
@@ -426,7 +434,9 @@ async def episode_search(
|
|
|
426
434
|
search_results: list[list[EpisodicNode]] = list(
|
|
427
435
|
await semaphore_gather(
|
|
428
436
|
*[
|
|
429
|
-
episode_fulltext_search(
|
|
437
|
+
episode_fulltext_search(
|
|
438
|
+
driver, query, search_filter, group_ids, 2 * limit, config.use_local_indexes
|
|
439
|
+
),
|
|
430
440
|
]
|
|
431
441
|
)
|
|
432
442
|
)
|
|
graphiti_core/search/search_config.py
CHANGED
|
@@ -24,6 +24,7 @@ from graphiti_core.search.search_utils import (
|
|
|
24
24
|
DEFAULT_MIN_SCORE,
|
|
25
25
|
DEFAULT_MMR_LAMBDA,
|
|
26
26
|
MAX_SEARCH_DEPTH,
|
|
27
|
+
USE_HNSW,
|
|
27
28
|
)
|
|
28
29
|
|
|
29
30
|
DEFAULT_SEARCH_LIMIT = 10
|
|
@@ -91,6 +92,7 @@ class NodeSearchConfig(BaseModel):
|
|
|
91
92
|
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
|
92
93
|
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
|
93
94
|
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
|
95
|
+
use_local_indexes: bool = Field(default=USE_HNSW)
|
|
94
96
|
|
|
95
97
|
|
|
96
98
|
class EpisodeSearchConfig(BaseModel):
|
|
@@ -99,6 +101,7 @@ class EpisodeSearchConfig(BaseModel):
|
|
|
99
101
|
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
|
100
102
|
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
|
101
103
|
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
|
104
|
+
use_local_indexes: bool = Field(default=USE_HNSW)
|
|
102
105
|
|
|
103
106
|
|
|
104
107
|
class CommunitySearchConfig(BaseModel):
|
|
@@ -107,6 +110,7 @@ class CommunitySearchConfig(BaseModel):
|
|
|
107
110
|
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
|
108
111
|
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
|
109
112
|
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
|
113
|
+
use_local_indexes: bool = Field(default=USE_HNSW)
|
|
110
114
|
|
|
111
115
|
|
|
112
116
|
class SearchConfig(BaseModel):
|
|
graphiti_core/search/search_filters.py
CHANGED
|
@@ -20,6 +20,8 @@ from typing import Any
|
|
|
20
20
|
|
|
21
21
|
from pydantic import BaseModel, Field
|
|
22
22
|
|
|
23
|
+
from graphiti_core.driver.driver import GraphProvider
|
|
24
|
+
|
|
23
25
|
|
|
24
26
|
class ComparisonOperator(Enum):
|
|
25
27
|
equals = '='
|
|
@@ -54,16 +56,21 @@ class SearchFilters(BaseModel):
|
|
|
54
56
|
|
|
55
57
|
def node_search_filter_query_constructor(
|
|
56
58
|
filters: SearchFilters,
|
|
57
|
-
|
|
58
|
-
|
|
59
|
+
provider: GraphProvider,
|
|
60
|
+
) -> tuple[list[str], dict[str, Any]]:
|
|
61
|
+
filter_queries: list[str] = []
|
|
59
62
|
filter_params: dict[str, Any] = {}
|
|
60
63
|
|
|
61
64
|
if filters.node_labels is not None:
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
+
if provider == GraphProvider.KUZU:
|
|
66
|
+
node_label_filter = 'list_has_all(n.labels, $labels)'
|
|
67
|
+
filter_params['labels'] = filters.node_labels
|
|
68
|
+
else:
|
|
69
|
+
node_labels = '|'.join(filters.node_labels)
|
|
70
|
+
node_label_filter = 'n:' + node_labels
|
|
71
|
+
filter_queries.append(node_label_filter)
|
|
65
72
|
|
|
66
|
-
return
|
|
73
|
+
return filter_queries, filter_params
|
|
67
74
|
|
|
68
75
|
|
|
69
76
|
def date_filter_query_constructor(
|
|
@@ -81,23 +88,29 @@ def date_filter_query_constructor(
|
|
|
81
88
|
|
|
82
89
|
def edge_search_filter_query_constructor(
|
|
83
90
|
filters: SearchFilters,
|
|
84
|
-
|
|
85
|
-
|
|
91
|
+
provider: GraphProvider,
|
|
92
|
+
) -> tuple[list[str], dict[str, Any]]:
|
|
93
|
+
filter_queries: list[str] = []
|
|
86
94
|
filter_params: dict[str, Any] = {}
|
|
87
95
|
|
|
88
96
|
if filters.edge_types is not None:
|
|
89
97
|
edge_types = filters.edge_types
|
|
90
|
-
|
|
91
|
-
filter_query += edge_types_filter
|
|
98
|
+
filter_queries.append('e.name in $edge_types')
|
|
92
99
|
filter_params['edge_types'] = edge_types
|
|
93
100
|
|
|
94
101
|
if filters.node_labels is not None:
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
102
|
+
if provider == GraphProvider.KUZU:
|
|
103
|
+
node_label_filter = (
|
|
104
|
+
'list_has_all(n.labels, $labels) AND list_has_all(m.labels, $labels)'
|
|
105
|
+
)
|
|
106
|
+
filter_params['labels'] = filters.node_labels
|
|
107
|
+
else:
|
|
108
|
+
node_labels = '|'.join(filters.node_labels)
|
|
109
|
+
node_label_filter = 'n:' + node_labels + ' AND m:' + node_labels
|
|
110
|
+
filter_queries.append(node_label_filter)
|
|
98
111
|
|
|
99
112
|
if filters.valid_at is not None:
|
|
100
|
-
valid_at_filter = '
|
|
113
|
+
valid_at_filter = '('
|
|
101
114
|
for i, or_list in enumerate(filters.valid_at):
|
|
102
115
|
for j, date_filter in enumerate(or_list):
|
|
103
116
|
if date_filter.comparison_operator not in [
|
|
@@ -125,10 +138,10 @@ def edge_search_filter_query_constructor(
|
|
|
125
138
|
else:
|
|
126
139
|
valid_at_filter += ' OR '
|
|
127
140
|
|
|
128
|
-
|
|
141
|
+
filter_queries.append(valid_at_filter)
|
|
129
142
|
|
|
130
143
|
if filters.invalid_at is not None:
|
|
131
|
-
invalid_at_filter = '
|
|
144
|
+
invalid_at_filter = '('
|
|
132
145
|
for i, or_list in enumerate(filters.invalid_at):
|
|
133
146
|
for j, date_filter in enumerate(or_list):
|
|
134
147
|
if date_filter.comparison_operator not in [
|
|
@@ -156,10 +169,10 @@ def edge_search_filter_query_constructor(
|
|
|
156
169
|
else:
|
|
157
170
|
invalid_at_filter += ' OR '
|
|
158
171
|
|
|
159
|
-
|
|
172
|
+
filter_queries.append(invalid_at_filter)
|
|
160
173
|
|
|
161
174
|
if filters.created_at is not None:
|
|
162
|
-
created_at_filter = '
|
|
175
|
+
created_at_filter = '('
|
|
163
176
|
for i, or_list in enumerate(filters.created_at):
|
|
164
177
|
for j, date_filter in enumerate(or_list):
|
|
165
178
|
if date_filter.comparison_operator not in [
|
|
@@ -187,10 +200,10 @@ def edge_search_filter_query_constructor(
|
|
|
187
200
|
else:
|
|
188
201
|
created_at_filter += ' OR '
|
|
189
202
|
|
|
190
|
-
|
|
203
|
+
filter_queries.append(created_at_filter)
|
|
191
204
|
|
|
192
205
|
if filters.expired_at is not None:
|
|
193
|
-
expired_at_filter = '
|
|
206
|
+
expired_at_filter = '('
|
|
194
207
|
for i, or_list in enumerate(filters.expired_at):
|
|
195
208
|
for j, date_filter in enumerate(or_list):
|
|
196
209
|
if date_filter.comparison_operator not in [
|
|
@@ -218,6 +231,6 @@ def edge_search_filter_query_constructor(
|
|
|
218
231
|
else:
|
|
219
232
|
expired_at_filter += ' OR '
|
|
220
233
|
|
|
221
|
-
|
|
234
|
+
filter_queries.append(expired_at_filter)
|
|
222
235
|
|
|
223
|
-
return
|
|
236
|
+
return filter_queries, filter_params
|