graphiti-core 0.18.9__py3-none-any.whl → 0.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/driver/driver.py +4 -0
- graphiti_core/driver/falkordb_driver.py +3 -14
- graphiti_core/driver/kuzu_driver.py +175 -0
- graphiti_core/driver/neptune_driver.py +301 -0
- graphiti_core/edges.py +155 -62
- graphiti_core/graph_queries.py +31 -2
- graphiti_core/graphiti.py +6 -1
- graphiti_core/helpers.py +8 -8
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/openai_base_client.py +12 -2
- graphiti_core/llm_client/openai_client.py +10 -2
- graphiti_core/migrations/__init__.py +0 -0
- graphiti_core/migrations/neo4j_node_group_labels.py +114 -0
- graphiti_core/models/edges/edge_db_queries.py +205 -76
- graphiti_core/models/nodes/node_db_queries.py +253 -74
- graphiti_core/nodes.py +271 -98
- graphiti_core/search/search.py +42 -12
- graphiti_core/search/search_config.py +4 -0
- graphiti_core/search/search_filters.py +35 -22
- graphiti_core/search/search_utils.py +1329 -392
- graphiti_core/utils/bulk_utils.py +50 -15
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +39 -32
- graphiti_core/utils/maintenance/edge_operations.py +47 -13
- graphiti_core/utils/maintenance/graph_data_operations.py +100 -15
- {graphiti_core-0.18.9.dist-info → graphiti_core-0.19.0.dist-info}/METADATA +87 -13
- {graphiti_core-0.18.9.dist-info → graphiti_core-0.19.0.dist-info}/RECORD +29 -25
- {graphiti_core-0.18.9.dist-info → graphiti_core-0.19.0.dist-info}/WHEEL +0 -0
- {graphiti_core-0.18.9.dist-info → graphiti_core-0.19.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
+
import json
|
|
17
18
|
import logging
|
|
18
19
|
import typing
|
|
19
20
|
from datetime import datetime
|
|
@@ -22,20 +23,21 @@ import numpy as np
|
|
|
22
23
|
from pydantic import BaseModel, Field
|
|
23
24
|
from typing_extensions import Any
|
|
24
25
|
|
|
25
|
-
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
|
|
26
|
+
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
|
26
27
|
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
|
|
27
28
|
from graphiti_core.embedder import EmbedderClient
|
|
28
29
|
from graphiti_core.graphiti_types import GraphitiClients
|
|
29
30
|
from graphiti_core.helpers import normalize_l2, semaphore_gather
|
|
30
31
|
from graphiti_core.models.edges.edge_db_queries import (
|
|
31
|
-
EPISODIC_EDGE_SAVE_BULK,
|
|
32
32
|
get_entity_edge_save_bulk_query,
|
|
33
|
+
get_episodic_edge_save_bulk_query,
|
|
33
34
|
)
|
|
34
35
|
from graphiti_core.models.nodes.node_db_queries import (
|
|
35
|
-
EPISODIC_NODE_SAVE_BULK,
|
|
36
36
|
get_entity_node_save_bulk_query,
|
|
37
|
+
get_episode_node_save_bulk_query,
|
|
37
38
|
)
|
|
38
39
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
|
|
40
|
+
from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
|
|
39
41
|
from graphiti_core.utils.maintenance.edge_operations import (
|
|
40
42
|
extract_edges,
|
|
41
43
|
resolve_extracted_edge,
|
|
@@ -116,10 +118,16 @@ async def add_nodes_and_edges_bulk_tx(
|
|
|
116
118
|
episodes = [dict(episode) for episode in episodic_nodes]
|
|
117
119
|
for episode in episodes:
|
|
118
120
|
episode['source'] = str(episode['source'].value)
|
|
119
|
-
|
|
121
|
+
episode.pop('labels', None)
|
|
122
|
+
if driver.provider == GraphProvider.NEO4J:
|
|
123
|
+
episode['group_label'] = 'Episodic_' + episode['group_id'].replace('-', '')
|
|
124
|
+
|
|
125
|
+
nodes = []
|
|
126
|
+
|
|
120
127
|
for node in entity_nodes:
|
|
121
128
|
if node.name_embedding is None:
|
|
122
129
|
await node.generate_name_embedding(embedder)
|
|
130
|
+
|
|
123
131
|
entity_data: dict[str, Any] = {
|
|
124
132
|
'uuid': node.uuid,
|
|
125
133
|
'name': node.name,
|
|
@@ -129,11 +137,19 @@ async def add_nodes_and_edges_bulk_tx(
|
|
|
129
137
|
'created_at': node.created_at,
|
|
130
138
|
}
|
|
131
139
|
|
|
132
|
-
entity_data.update(node.attributes or {})
|
|
133
140
|
entity_data['labels'] = list(set(node.labels + ['Entity']))
|
|
141
|
+
if driver.provider == GraphProvider.KUZU:
|
|
142
|
+
attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
|
|
143
|
+
entity_data['attributes'] = json.dumps(attributes)
|
|
144
|
+
else:
|
|
145
|
+
entity_data.update(node.attributes or {})
|
|
146
|
+
entity_data['labels'] = list(
|
|
147
|
+
set(node.labels + ['Entity', 'Entity_' + node.group_id.replace('-', '')])
|
|
148
|
+
)
|
|
149
|
+
|
|
134
150
|
nodes.append(entity_data)
|
|
135
151
|
|
|
136
|
-
edges
|
|
152
|
+
edges = []
|
|
137
153
|
for edge in entity_edges:
|
|
138
154
|
if edge.fact_embedding is None:
|
|
139
155
|
await edge.generate_embedding(embedder)
|
|
@@ -152,17 +168,36 @@ async def add_nodes_and_edges_bulk_tx(
|
|
|
152
168
|
'invalid_at': edge.invalid_at,
|
|
153
169
|
}
|
|
154
170
|
|
|
155
|
-
|
|
171
|
+
if driver.provider == GraphProvider.KUZU:
|
|
172
|
+
attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
|
|
173
|
+
edge_data['attributes'] = json.dumps(attributes)
|
|
174
|
+
else:
|
|
175
|
+
edge_data.update(edge.attributes or {})
|
|
176
|
+
|
|
156
177
|
edges.append(edge_data)
|
|
157
178
|
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
179
|
+
if driver.provider == GraphProvider.KUZU:
|
|
180
|
+
# FIXME: Kuzu's UNWIND does not currently support STRUCT[] type properly, so we insert the data one by one instead for now.
|
|
181
|
+
episode_query = get_episode_node_save_bulk_query(driver.provider)
|
|
182
|
+
for episode in episodes:
|
|
183
|
+
await tx.run(episode_query, **episode)
|
|
184
|
+
entity_node_query = get_entity_node_save_bulk_query(driver.provider, nodes)
|
|
185
|
+
for node in nodes:
|
|
186
|
+
await tx.run(entity_node_query, **node)
|
|
187
|
+
entity_edge_query = get_entity_edge_save_bulk_query(driver.provider)
|
|
188
|
+
for edge in edges:
|
|
189
|
+
await tx.run(entity_edge_query, **edge)
|
|
190
|
+
episodic_edge_query = get_episodic_edge_save_bulk_query(driver.provider)
|
|
191
|
+
for edge in episodic_edges:
|
|
192
|
+
await tx.run(episodic_edge_query, **edge.model_dump())
|
|
193
|
+
else:
|
|
194
|
+
await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
|
|
195
|
+
await tx.run(get_entity_node_save_bulk_query(driver.provider, nodes), nodes=nodes)
|
|
196
|
+
await tx.run(
|
|
197
|
+
get_episodic_edge_save_bulk_query(driver.provider),
|
|
198
|
+
episodic_edges=[edge.model_dump() for edge in episodic_edges],
|
|
199
|
+
)
|
|
200
|
+
await tx.run(get_entity_edge_save_bulk_query(driver.provider), entity_edges=edges)
|
|
166
201
|
|
|
167
202
|
|
|
168
203
|
async def extract_nodes_and_edges_bulk(
|
|
@@ -40,3 +40,16 @@ def ensure_utc(dt: datetime | None) -> datetime | None:
|
|
|
40
40
|
return dt.astimezone(timezone.utc)
|
|
41
41
|
|
|
42
42
|
return dt
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def convert_datetimes_to_strings(obj):
    """Return a copy of ``obj`` in which every ``datetime`` is an ISO 8601 string.

    Dicts, lists and tuples are walked recursively and rebuilt as a plain
    ``dict``, ``list`` or ``tuple``. Dict keys are kept as they are; only the
    values are converted. Anything that is not one of those containers or a
    ``datetime`` is returned unchanged.
    """
    # Leaf case: the value this helper exists to rewrite.
    if isinstance(obj, datetime):
        return obj.isoformat()

    if isinstance(obj, dict):
        converted = {}
        for key, value in obj.items():
            converted[key] = convert_datetimes_to_strings(value)
        return converted

    if isinstance(obj, (list, tuple)):
        elements = [convert_datetimes_to_strings(element) for element in obj]
        # Keep the sequence flavour of the input: list stays list, tuple stays tuple.
        return elements if isinstance(obj, list) else tuple(elements)

    return obj
|
|
@@ -4,11 +4,12 @@ from collections import defaultdict
|
|
|
4
4
|
|
|
5
5
|
from pydantic import BaseModel
|
|
6
6
|
|
|
7
|
-
from graphiti_core.driver.driver import GraphDriver
|
|
7
|
+
from graphiti_core.driver.driver import GraphDriver, GraphProvider
|
|
8
8
|
from graphiti_core.edges import CommunityEdge
|
|
9
9
|
from graphiti_core.embedder import EmbedderClient
|
|
10
10
|
from graphiti_core.helpers import semaphore_gather
|
|
11
11
|
from graphiti_core.llm_client import LLMClient
|
|
12
|
+
from graphiti_core.models.nodes.node_db_queries import COMMUNITY_NODE_RETURN
|
|
12
13
|
from graphiti_core.nodes import CommunityNode, EntityNode, get_community_node_from_record
|
|
13
14
|
from graphiti_core.prompts import prompt_library
|
|
14
15
|
from graphiti_core.prompts.summarize_nodes import Summary, SummaryDescription
|
|
@@ -33,11 +34,11 @@ async def get_community_clusters(
|
|
|
33
34
|
if group_ids is None:
|
|
34
35
|
group_id_values, _, _ = await driver.execute_query(
|
|
35
36
|
"""
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
37
|
+
MATCH (n:Entity)
|
|
38
|
+
WHERE n.group_id IS NOT NULL
|
|
39
|
+
RETURN
|
|
40
|
+
collect(DISTINCT n.group_id) AS group_ids
|
|
41
|
+
"""
|
|
41
42
|
)
|
|
42
43
|
|
|
43
44
|
group_ids = group_id_values[0]['group_ids'] if group_id_values else []
|
|
@@ -46,14 +47,21 @@ async def get_community_clusters(
|
|
|
46
47
|
projection: dict[str, list[Neighbor]] = {}
|
|
47
48
|
nodes = await EntityNode.get_by_group_ids(driver, [group_id])
|
|
48
49
|
for node in nodes:
|
|
49
|
-
|
|
50
|
+
match_query = """
|
|
51
|
+
MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[e:RELATES_TO]-(m: Entity {group_id: $group_id})
|
|
52
|
+
"""
|
|
53
|
+
if driver.provider == GraphProvider.KUZU:
|
|
54
|
+
match_query = """
|
|
55
|
+
MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(m: Entity {group_id: $group_id})
|
|
50
56
|
"""
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
uuid
|
|
55
|
-
|
|
56
|
-
|
|
57
|
+
records, _, _ = await driver.execute_query(
|
|
58
|
+
match_query
|
|
59
|
+
+ """
|
|
60
|
+
WITH count(e) AS count, m.uuid AS uuid
|
|
61
|
+
RETURN
|
|
62
|
+
uuid,
|
|
63
|
+
count
|
|
64
|
+
""",
|
|
57
65
|
uuid=node.uuid,
|
|
58
66
|
group_id=group_id,
|
|
59
67
|
)
|
|
@@ -235,9 +243,9 @@ async def build_communities(
|
|
|
235
243
|
async def remove_communities(driver: GraphDriver):
    """Delete every ``Community`` node, detaching its relationships first.

    Args:
        driver: Graph driver whose ``execute_query`` runs the delete statement.
    """
    delete_communities_query = """
    MATCH (c:Community)
    DETACH DELETE c
    """
    await driver.execute_query(delete_communities_query)
|
|
242
250
|
|
|
243
251
|
|
|
@@ -247,14 +255,10 @@ async def determine_entity_community(
|
|
|
247
255
|
# Check if the node is already part of a community
|
|
248
256
|
records, _, _ = await driver.execute_query(
|
|
249
257
|
"""
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
c.group_id AS group_id,
|
|
255
|
-
c.created_at AS created_at,
|
|
256
|
-
c.summary AS summary
|
|
257
|
-
""",
|
|
258
|
+
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid})
|
|
259
|
+
RETURN
|
|
260
|
+
"""
|
|
261
|
+
+ COMMUNITY_NODE_RETURN,
|
|
258
262
|
entity_uuid=entity.uuid,
|
|
259
263
|
)
|
|
260
264
|
|
|
@@ -262,16 +266,19 @@ async def determine_entity_community(
|
|
|
262
266
|
return get_community_node_from_record(records[0]), False
|
|
263
267
|
|
|
264
268
|
# If the node has no community, add it to the mode community of surrounding entities
|
|
269
|
+
match_query = """
|
|
270
|
+
MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
|
|
271
|
+
"""
|
|
272
|
+
if driver.provider == GraphProvider.KUZU:
|
|
273
|
+
match_query = """
|
|
274
|
+
MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
|
|
275
|
+
"""
|
|
265
276
|
records, _, _ = await driver.execute_query(
|
|
277
|
+
match_query
|
|
278
|
+
+ """
|
|
279
|
+
RETURN
|
|
266
280
|
"""
|
|
267
|
-
|
|
268
|
-
RETURN
|
|
269
|
-
c.uuid AS uuid,
|
|
270
|
-
c.name AS name,
|
|
271
|
-
c.group_id AS group_id,
|
|
272
|
-
c.created_at AS created_at,
|
|
273
|
-
c.summary AS summary
|
|
274
|
-
""",
|
|
281
|
+
+ COMMUNITY_NODE_RETURN,
|
|
275
282
|
entity_uuid=entity.uuid,
|
|
276
283
|
)
|
|
277
284
|
|
|
@@ -21,7 +21,7 @@ from time import time
|
|
|
21
21
|
from pydantic import BaseModel
|
|
22
22
|
from typing_extensions import LiteralString
|
|
23
23
|
|
|
24
|
-
from graphiti_core.driver.driver import GraphDriver
|
|
24
|
+
from graphiti_core.driver.driver import GraphDriver, GraphProvider
|
|
25
25
|
from graphiti_core.edges import (
|
|
26
26
|
CommunityEdge,
|
|
27
27
|
EntityEdge,
|
|
@@ -504,23 +504,57 @@ async def resolve_extracted_edge(
|
|
|
504
504
|
async def filter_existing_duplicate_of_edges(
|
|
505
505
|
driver: GraphDriver, duplicates_node_tuples: list[tuple[EntityNode, EntityNode]]
|
|
506
506
|
) -> list[tuple[EntityNode, EntityNode]]:
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
|
|
510
|
-
RETURN DISTINCT
|
|
511
|
-
n.uuid AS source_uuid,
|
|
512
|
-
m.uuid AS target_uuid
|
|
513
|
-
"""
|
|
507
|
+
if not duplicates_node_tuples:
|
|
508
|
+
return []
|
|
514
509
|
|
|
515
510
|
duplicate_nodes_map = {
|
|
516
511
|
(source.uuid, target.uuid): (source, target) for source, target in duplicates_node_tuples
|
|
517
512
|
}
|
|
518
513
|
|
|
519
|
-
|
|
520
|
-
query
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
514
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
515
|
+
query: LiteralString = """
|
|
516
|
+
UNWIND $duplicate_node_uuids AS duplicate_tuple
|
|
517
|
+
MATCH (n:Entity {uuid: duplicate_tuple.source})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple.target})
|
|
518
|
+
RETURN DISTINCT
|
|
519
|
+
n.uuid AS source_uuid,
|
|
520
|
+
m.uuid AS target_uuid
|
|
521
|
+
"""
|
|
522
|
+
|
|
523
|
+
duplicate_nodes = [
|
|
524
|
+
{'source': source.uuid, 'target': target.uuid}
|
|
525
|
+
for source, target in duplicates_node_tuples
|
|
526
|
+
]
|
|
527
|
+
|
|
528
|
+
records, _, _ = await driver.execute_query(
|
|
529
|
+
query,
|
|
530
|
+
duplicate_node_uuids=duplicate_nodes,
|
|
531
|
+
routing_='r',
|
|
532
|
+
)
|
|
533
|
+
else:
|
|
534
|
+
if driver.provider == GraphProvider.KUZU:
|
|
535
|
+
query = """
|
|
536
|
+
UNWIND $duplicate_node_uuids AS duplicate
|
|
537
|
+
MATCH (n:Entity {uuid: duplicate.src})-[:RELATES_TO]->(e:RelatesToNode_ {name: 'IS_DUPLICATE_OF'})-[:RELATES_TO]->(m:Entity {uuid: duplicate.dst})
|
|
538
|
+
RETURN DISTINCT
|
|
539
|
+
n.uuid AS source_uuid,
|
|
540
|
+
m.uuid AS target_uuid
|
|
541
|
+
"""
|
|
542
|
+
duplicate_node_uuids = [{'src': src, 'dst': dst} for src, dst in duplicate_nodes_map]
|
|
543
|
+
else:
|
|
544
|
+
query: LiteralString = """
|
|
545
|
+
UNWIND $duplicate_node_uuids AS duplicate_tuple
|
|
546
|
+
MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
|
|
547
|
+
RETURN DISTINCT
|
|
548
|
+
n.uuid AS source_uuid,
|
|
549
|
+
m.uuid AS target_uuid
|
|
550
|
+
"""
|
|
551
|
+
duplicate_node_uuids = list(duplicate_nodes_map.keys())
|
|
552
|
+
|
|
553
|
+
records, _, _ = await driver.execute_query(
|
|
554
|
+
query,
|
|
555
|
+
duplicate_node_uuids=duplicate_node_uuids,
|
|
556
|
+
routing_='r',
|
|
557
|
+
)
|
|
524
558
|
|
|
525
559
|
# Remove duplicates that already have the IS_DUPLICATE_OF edge
|
|
526
560
|
for record in records:
|
|
@@ -19,10 +19,13 @@ from datetime import datetime
|
|
|
19
19
|
|
|
20
20
|
from typing_extensions import LiteralString
|
|
21
21
|
|
|
22
|
-
from graphiti_core.driver.driver import GraphDriver
|
|
22
|
+
from graphiti_core.driver.driver import GraphDriver, GraphProvider
|
|
23
23
|
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
|
|
24
24
|
from graphiti_core.helpers import semaphore_gather
|
|
25
|
-
from graphiti_core.models.nodes.node_db_queries import
|
|
25
|
+
from graphiti_core.models.nodes.node_db_queries import (
|
|
26
|
+
EPISODIC_NODE_RETURN,
|
|
27
|
+
EPISODIC_NODE_RETURN_NEPTUNE,
|
|
28
|
+
)
|
|
26
29
|
from graphiti_core.nodes import EpisodeType, EpisodicNode, get_episodic_node_from_record
|
|
27
30
|
|
|
28
31
|
EPISODE_WINDOW_LEN = 3
|
|
@@ -31,6 +34,9 @@ logger = logging.getLogger(__name__)
|
|
|
31
34
|
|
|
32
35
|
|
|
33
36
|
async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
|
|
37
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
38
|
+
await driver.create_aoss_indices() # pyright: ignore[reportAttributeAccessIssue]
|
|
39
|
+
return
|
|
34
40
|
if delete_existing:
|
|
35
41
|
records, _, _ = await driver.execute_query(
|
|
36
42
|
"""
|
|
@@ -47,10 +53,29 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo
|
|
|
47
53
|
for name in index_names
|
|
48
54
|
]
|
|
49
55
|
)
|
|
56
|
+
|
|
50
57
|
range_indices: list[LiteralString] = get_range_indices(driver.provider)
|
|
51
58
|
|
|
52
59
|
fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
|
|
53
60
|
|
|
61
|
+
if driver.provider == GraphProvider.KUZU:
|
|
62
|
+
# Skip creating fulltext indices if they already exist. Need to do this manually
|
|
63
|
+
# until Kuzu supports `IF NOT EXISTS` for indices.
|
|
64
|
+
result, _, _ = await driver.execute_query('CALL SHOW_INDEXES() RETURN *;')
|
|
65
|
+
if len(result) > 0:
|
|
66
|
+
fulltext_indices = []
|
|
67
|
+
|
|
68
|
+
# Only load the `fts` extension if it's not already loaded, otherwise throw an error.
|
|
69
|
+
result, _, _ = await driver.execute_query('CALL SHOW_LOADED_EXTENSIONS() RETURN *;')
|
|
70
|
+
if len(result) == 0:
|
|
71
|
+
fulltext_indices.insert(
|
|
72
|
+
0,
|
|
73
|
+
"""
|
|
74
|
+
INSTALL fts;
|
|
75
|
+
LOAD fts;
|
|
76
|
+
""",
|
|
77
|
+
)
|
|
78
|
+
|
|
54
79
|
index_queries: list[LiteralString] = range_indices + fulltext_indices
|
|
55
80
|
|
|
56
81
|
await semaphore_gather(
|
|
@@ -70,10 +95,19 @@ async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
|
|
|
70
95
|
await tx.run('MATCH (n) DETACH DELETE n')
|
|
71
96
|
|
|
72
97
|
async def delete_group_ids(tx):
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
98
|
+
labels = ['Entity', 'Episodic', 'Community']
|
|
99
|
+
if driver.provider == GraphProvider.KUZU:
|
|
100
|
+
labels.append('RelatesToNode_')
|
|
101
|
+
|
|
102
|
+
for label in labels:
|
|
103
|
+
await tx.run(
|
|
104
|
+
f"""
|
|
105
|
+
MATCH (n:{label})
|
|
106
|
+
WHERE n.group_id IN $group_ids
|
|
107
|
+
DETACH DELETE n
|
|
108
|
+
""",
|
|
109
|
+
group_ids=group_ids,
|
|
110
|
+
)
|
|
77
111
|
|
|
78
112
|
if group_ids is None:
|
|
79
113
|
await session.execute_write(delete_all)
|
|
@@ -102,22 +136,31 @@ async def retrieve_episodes(
|
|
|
102
136
|
Returns:
|
|
103
137
|
list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes.
|
|
104
138
|
"""
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
139
|
+
|
|
140
|
+
query_params: dict = {}
|
|
141
|
+
query_filter = ''
|
|
142
|
+
if group_ids and len(group_ids) > 0:
|
|
143
|
+
query_filter += '\nAND e.group_id IN $group_ids'
|
|
144
|
+
query_params['group_ids'] = group_ids
|
|
145
|
+
|
|
146
|
+
if source is not None:
|
|
147
|
+
query_filter += '\nAND e.source = $source'
|
|
148
|
+
query_params['source'] = source.name
|
|
109
149
|
|
|
110
150
|
query: LiteralString = (
|
|
111
151
|
"""
|
|
112
152
|
MATCH (e:Episodic)
|
|
113
153
|
WHERE e.valid_at <= $reference_time
|
|
114
154
|
"""
|
|
115
|
-
+
|
|
116
|
-
+ source_filter
|
|
155
|
+
+ query_filter
|
|
117
156
|
+ """
|
|
118
157
|
RETURN
|
|
119
158
|
"""
|
|
120
|
-
+
|
|
159
|
+
+ (
|
|
160
|
+
EPISODIC_NODE_RETURN_NEPTUNE
|
|
161
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
162
|
+
else EPISODIC_NODE_RETURN
|
|
163
|
+
)
|
|
121
164
|
+ """
|
|
122
165
|
ORDER BY e.valid_at DESC
|
|
123
166
|
LIMIT $num_episodes
|
|
@@ -126,10 +169,52 @@ async def retrieve_episodes(
|
|
|
126
169
|
result, _, _ = await driver.execute_query(
|
|
127
170
|
query,
|
|
128
171
|
reference_time=reference_time,
|
|
129
|
-
source=source.name if source is not None else None,
|
|
130
172
|
num_episodes=last_n,
|
|
131
|
-
|
|
173
|
+
**query_params,
|
|
132
174
|
)
|
|
133
175
|
|
|
134
176
|
episodes = [get_episodic_node_from_record(record) for record in result]
|
|
135
177
|
return list(reversed(episodes)) # Return in chronological order
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
async def build_dynamic_indexes(driver: GraphDriver, group_id: str):
    """Ensure the group-specific Neo4j indexes for ``group_id`` exist.

    For a Neo4j driver, four ``CREATE ... IF NOT EXISTS`` statements are issued
    through ``semaphore_gather``: fulltext indexes over the ``Episodic_<id>``,
    ``Entity_<id>`` and ``Community_<id>`` labels, and a cosine vector index
    over ``Entity_<id>``, where ``<id>`` is ``group_id`` with its dashes
    removed. For every other provider the function does nothing.

    Args:
        driver: Graph driver used to run the index statements.
        group_id: Group whose label-specific indexes should be created.
    """
    # Make sure indices exist for this group_id in Neo4j
    if driver.provider != GraphProvider.NEO4J:
        return

    # Labels and index names all embed the group id with its dashes stripped.
    group_suffix = group_id.replace('-', '')
    episodic_label = 'Episodic_' + group_suffix
    entity_label = 'Entity_' + group_suffix
    community_label = 'Community_' + group_suffix

    # The label is spliced into the query text; only the index name is passed
    # as a query parameter.
    episode_fulltext = driver.execute_query(
        """CREATE FULLTEXT INDEX $episode_content IF NOT EXISTS
        FOR (e:"""
        + episodic_label
        + """) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
        episode_content='episode_content_' + group_suffix,
    )
    entity_fulltext = driver.execute_query(
        """CREATE FULLTEXT INDEX $node_name_and_summary IF NOT EXISTS FOR (n:"""
        + entity_label
        + """) ON EACH [n.name, n.summary, n.group_id]""",
        node_name_and_summary='node_name_and_summary_' + group_suffix,
    )
    # NOTE(review): this index name has the same form as the label
    # ('Community_<id>'), unlike the descriptive names of the other three
    # indexes - confirm that is intended before renaming anything.
    community_fulltext = driver.execute_query(
        """CREATE FULLTEXT INDEX $community_name IF NOT EXISTS
        FOR (n:"""
        + community_label
        + """) ON EACH [n.name, n.group_id]""",
        community_name='Community_' + group_suffix,
    )
    # NOTE(review): the vector dimension is fixed at 1024 and the indexed
    # property is `embedding` - confirm both match what is stored on entities.
    entity_vector = driver.execute_query(
        """CREATE VECTOR INDEX $group_entity_vector IF NOT EXISTS
        FOR (n:"""
        + entity_label
        + """)
        ON n.embedding
        OPTIONS { indexConfig: {
        `vector.dimensions`: 1024,
        `vector.similarity_function`: 'cosine'
        }}""",
        group_entity_vector='group_entity_vector_' + group_suffix,
    )

    await semaphore_gather(
        episode_fulltext,
        entity_fulltext,
        community_fulltext,
        entity_vector,
    )
|