graphiti-core 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/edges.py +39 -32
- graphiti_core/graphiti.py +100 -93
- graphiti_core/nodes.py +45 -39
- graphiti_core/prompts/dedupe_edges.py +1 -1
- graphiti_core/prompts/invalidate_edges.py +37 -1
- graphiti_core/search/search.py +5 -2
- graphiti_core/search/search_utils.py +101 -168
- graphiti_core/utils/bulk_utils.py +31 -3
- graphiti_core/utils/maintenance/edge_operations.py +104 -16
- graphiti_core/utils/maintenance/graph_data_operations.py +17 -8
- graphiti_core/utils/maintenance/node_operations.py +1 -0
- graphiti_core/utils/maintenance/temporal_operations.py +34 -0
- {graphiti_core-0.2.1.dist-info → graphiti_core-0.2.3.dist-info}/METADATA +3 -4
- {graphiti_core-0.2.1.dist-info → graphiti_core-0.2.3.dist-info}/RECORD +16 -17
- graphiti_core/utils/utils.py +0 -60
- {graphiti_core-0.2.1.dist-info → graphiti_core-0.2.3.dist-info}/LICENSE +0 -0
- {graphiti_core-0.2.1.dist-info → graphiti_core-0.2.3.dist-info}/WHEEL +0 -0
graphiti_core/edges.py
CHANGED
|
@@ -18,6 +18,7 @@ import logging
|
|
|
18
18
|
from abc import ABC, abstractmethod
|
|
19
19
|
from datetime import datetime
|
|
20
20
|
from time import time
|
|
21
|
+
from typing import Any
|
|
21
22
|
from uuid import uuid4
|
|
22
23
|
|
|
23
24
|
from neo4j import AsyncDriver
|
|
@@ -32,6 +33,7 @@ logger = logging.getLogger(__name__)
|
|
|
32
33
|
|
|
33
34
|
class Edge(BaseModel, ABC):
|
|
34
35
|
uuid: str = Field(default_factory=lambda: uuid4().hex)
|
|
36
|
+
group_id: str | None = Field(description='partition of the graph')
|
|
35
37
|
source_node_uuid: str
|
|
36
38
|
target_node_uuid: str
|
|
37
39
|
created_at: datetime
|
|
@@ -61,11 +63,12 @@ class EpisodicEdge(Edge):
|
|
|
61
63
|
MATCH (episode:Episodic {uuid: $episode_uuid})
|
|
62
64
|
MATCH (node:Entity {uuid: $entity_uuid})
|
|
63
65
|
MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
|
|
64
|
-
SET r = {uuid: $uuid, created_at: $created_at}
|
|
66
|
+
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
|
65
67
|
RETURN r.uuid AS uuid""",
|
|
66
68
|
episode_uuid=self.source_node_uuid,
|
|
67
69
|
entity_uuid=self.target_node_uuid,
|
|
68
70
|
uuid=self.uuid,
|
|
71
|
+
group_id=self.group_id,
|
|
69
72
|
created_at=self.created_at,
|
|
70
73
|
)
|
|
71
74
|
|
|
@@ -92,7 +95,8 @@ class EpisodicEdge(Edge):
|
|
|
92
95
|
"""
|
|
93
96
|
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
|
|
94
97
|
RETURN
|
|
95
|
-
e.uuid As uuid,
|
|
98
|
+
e.uuid As uuid,
|
|
99
|
+
e.group_id AS group_id,
|
|
96
100
|
n.uuid AS source_node_uuid,
|
|
97
101
|
m.uuid AS target_node_uuid,
|
|
98
102
|
e.created_at AS created_at
|
|
@@ -100,17 +104,7 @@ class EpisodicEdge(Edge):
|
|
|
100
104
|
uuid=uuid,
|
|
101
105
|
)
|
|
102
106
|
|
|
103
|
-
edges
|
|
104
|
-
|
|
105
|
-
for record in records:
|
|
106
|
-
edges.append(
|
|
107
|
-
EpisodicEdge(
|
|
108
|
-
uuid=record['uuid'],
|
|
109
|
-
source_node_uuid=record['source_node_uuid'],
|
|
110
|
-
target_node_uuid=record['target_node_uuid'],
|
|
111
|
-
created_at=record['created_at'].to_native(),
|
|
112
|
-
)
|
|
113
|
-
)
|
|
107
|
+
edges = [get_episodic_edge_from_record(record) for record in records]
|
|
114
108
|
|
|
115
109
|
logger.info(f'Found Edge: {uuid}')
|
|
116
110
|
|
|
@@ -153,7 +147,7 @@ class EntityEdge(Edge):
|
|
|
153
147
|
MATCH (source:Entity {uuid: $source_uuid})
|
|
154
148
|
MATCH (target:Entity {uuid: $target_uuid})
|
|
155
149
|
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
|
|
156
|
-
SET r = {uuid: $uuid, name: $name, fact: $fact, fact_embedding: $fact_embedding,
|
|
150
|
+
SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, fact_embedding: $fact_embedding,
|
|
157
151
|
episodes: $episodes, created_at: $created_at, expired_at: $expired_at,
|
|
158
152
|
valid_at: $valid_at, invalid_at: $invalid_at}
|
|
159
153
|
RETURN r.uuid AS uuid""",
|
|
@@ -161,6 +155,7 @@ class EntityEdge(Edge):
|
|
|
161
155
|
target_uuid=self.target_node_uuid,
|
|
162
156
|
uuid=self.uuid,
|
|
163
157
|
name=self.name,
|
|
158
|
+
group_id=self.group_id,
|
|
164
159
|
fact=self.fact,
|
|
165
160
|
fact_embedding=self.fact_embedding,
|
|
166
161
|
episodes=self.episodes,
|
|
@@ -198,6 +193,7 @@ class EntityEdge(Edge):
|
|
|
198
193
|
m.uuid AS target_node_uuid,
|
|
199
194
|
e.created_at AS created_at,
|
|
200
195
|
e.name AS name,
|
|
196
|
+
e.group_id AS group_id,
|
|
201
197
|
e.fact AS fact,
|
|
202
198
|
e.fact_embedding AS fact_embedding,
|
|
203
199
|
e.episodes AS episodes,
|
|
@@ -208,25 +204,36 @@ class EntityEdge(Edge):
|
|
|
208
204
|
uuid=uuid,
|
|
209
205
|
)
|
|
210
206
|
|
|
211
|
-
edges
|
|
212
|
-
|
|
213
|
-
for record in records:
|
|
214
|
-
edges.append(
|
|
215
|
-
EntityEdge(
|
|
216
|
-
uuid=record['uuid'],
|
|
217
|
-
source_node_uuid=record['source_node_uuid'],
|
|
218
|
-
target_node_uuid=record['target_node_uuid'],
|
|
219
|
-
fact=record['fact'],
|
|
220
|
-
name=record['name'],
|
|
221
|
-
episodes=record['episodes'],
|
|
222
|
-
fact_embedding=record['fact_embedding'],
|
|
223
|
-
created_at=record['created_at'].to_native(),
|
|
224
|
-
expired_at=parse_db_date(record['expired_at']),
|
|
225
|
-
valid_at=parse_db_date(record['valid_at']),
|
|
226
|
-
invalid_at=parse_db_date(record['invalid_at']),
|
|
227
|
-
)
|
|
228
|
-
)
|
|
207
|
+
edges = [get_entity_edge_from_record(record) for record in records]
|
|
229
208
|
|
|
230
209
|
logger.info(f'Found Edge: {uuid}')
|
|
231
210
|
|
|
232
211
|
return edges[0]
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
# Edge helpers
|
|
215
|
+
def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
|
|
216
|
+
return EpisodicEdge(
|
|
217
|
+
uuid=record['uuid'],
|
|
218
|
+
group_id=record['group_id'],
|
|
219
|
+
source_node_uuid=record['source_node_uuid'],
|
|
220
|
+
target_node_uuid=record['target_node_uuid'],
|
|
221
|
+
created_at=record['created_at'].to_native(),
|
|
222
|
+
)
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def get_entity_edge_from_record(record: Any) -> EntityEdge:
|
|
226
|
+
return EntityEdge(
|
|
227
|
+
uuid=record['uuid'],
|
|
228
|
+
source_node_uuid=record['source_node_uuid'],
|
|
229
|
+
target_node_uuid=record['target_node_uuid'],
|
|
230
|
+
fact=record['fact'],
|
|
231
|
+
name=record['name'],
|
|
232
|
+
group_id=record['group_id'],
|
|
233
|
+
episodes=record['episodes'],
|
|
234
|
+
fact_embedding=record['fact_embedding'],
|
|
235
|
+
created_at=record['created_at'].to_native(),
|
|
236
|
+
expired_at=parse_db_date(record['expired_at']),
|
|
237
|
+
valid_at=parse_db_date(record['valid_at']),
|
|
238
|
+
invalid_at=parse_db_date(record['invalid_at']),
|
|
239
|
+
)
|
graphiti_core/graphiti.py
CHANGED
|
@@ -18,7 +18,6 @@ import asyncio
|
|
|
18
18
|
import logging
|
|
19
19
|
from datetime import datetime
|
|
20
20
|
from time import time
|
|
21
|
-
from typing import Callable
|
|
22
21
|
|
|
23
22
|
from dotenv import load_dotenv
|
|
24
23
|
from neo4j import AsyncGraphDatabase
|
|
@@ -59,11 +58,6 @@ from graphiti_core.utils.maintenance.node_operations import (
|
|
|
59
58
|
extract_nodes,
|
|
60
59
|
resolve_extracted_nodes,
|
|
61
60
|
)
|
|
62
|
-
from graphiti_core.utils.maintenance.temporal_operations import (
|
|
63
|
-
extract_edge_dates,
|
|
64
|
-
invalidate_edges,
|
|
65
|
-
prepare_edges_for_invalidation,
|
|
66
|
-
)
|
|
67
61
|
|
|
68
62
|
logger = logging.getLogger(__name__)
|
|
69
63
|
|
|
@@ -125,7 +119,7 @@ class Graphiti:
|
|
|
125
119
|
|
|
126
120
|
Parameters
|
|
127
121
|
----------
|
|
128
|
-
|
|
122
|
+
self
|
|
129
123
|
|
|
130
124
|
Returns
|
|
131
125
|
-------
|
|
@@ -156,7 +150,7 @@ class Graphiti:
|
|
|
156
150
|
|
|
157
151
|
Parameters
|
|
158
152
|
----------
|
|
159
|
-
|
|
153
|
+
self
|
|
160
154
|
|
|
161
155
|
Returns
|
|
162
156
|
-------
|
|
@@ -183,6 +177,7 @@ class Graphiti:
|
|
|
183
177
|
self,
|
|
184
178
|
reference_time: datetime,
|
|
185
179
|
last_n: int = EPISODE_WINDOW_LEN,
|
|
180
|
+
group_ids: list[str | None] | None = None,
|
|
186
181
|
) -> list[EpisodicNode]:
|
|
187
182
|
"""
|
|
188
183
|
Retrieve the last n episodic nodes from the graph.
|
|
@@ -196,6 +191,8 @@ class Graphiti:
|
|
|
196
191
|
The reference time to retrieve episodes before.
|
|
197
192
|
last_n : int, optional
|
|
198
193
|
The number of episodes to retrieve. Defaults to EPISODE_WINDOW_LEN.
|
|
194
|
+
group_ids : list[str | None], optional
|
|
195
|
+
The group ids to return data from.
|
|
199
196
|
|
|
200
197
|
Returns
|
|
201
198
|
-------
|
|
@@ -207,7 +204,7 @@ class Graphiti:
|
|
|
207
204
|
The actual retrieval is performed by the `retrieve_episodes` function
|
|
208
205
|
from the `graphiti_core.utils` module.
|
|
209
206
|
"""
|
|
210
|
-
return await retrieve_episodes(self.driver, reference_time, last_n)
|
|
207
|
+
return await retrieve_episodes(self.driver, reference_time, last_n, group_ids)
|
|
211
208
|
|
|
212
209
|
async def add_episode(
|
|
213
210
|
self,
|
|
@@ -216,8 +213,8 @@ class Graphiti:
|
|
|
216
213
|
source_description: str,
|
|
217
214
|
reference_time: datetime,
|
|
218
215
|
source: EpisodeType = EpisodeType.message,
|
|
219
|
-
|
|
220
|
-
|
|
216
|
+
group_id: str | None = None,
|
|
217
|
+
uuid: str | None = None,
|
|
221
218
|
):
|
|
222
219
|
"""
|
|
223
220
|
Process an episode and update the graph.
|
|
@@ -237,10 +234,10 @@ class Graphiti:
|
|
|
237
234
|
The reference time for the episode.
|
|
238
235
|
source : EpisodeType, optional
|
|
239
236
|
The type of the episode. Defaults to EpisodeType.message.
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
237
|
+
group_id : str | None
|
|
238
|
+
An id for the graph partition the episode is a part of.
|
|
239
|
+
uuid : str | None
|
|
240
|
+
Optional uuid of the episode.
|
|
244
241
|
|
|
245
242
|
Returns
|
|
246
243
|
-------
|
|
@@ -271,9 +268,12 @@ class Graphiti:
|
|
|
271
268
|
embedder = self.llm_client.get_embedder()
|
|
272
269
|
now = datetime.now()
|
|
273
270
|
|
|
274
|
-
previous_episodes = await self.retrieve_episodes(
|
|
271
|
+
previous_episodes = await self.retrieve_episodes(
|
|
272
|
+
reference_time, last_n=3, group_ids=[group_id]
|
|
273
|
+
)
|
|
275
274
|
episode = EpisodicNode(
|
|
276
275
|
name=name,
|
|
276
|
+
group_id=group_id,
|
|
277
277
|
labels=[],
|
|
278
278
|
source=source,
|
|
279
279
|
content=episode_body,
|
|
@@ -281,6 +281,7 @@ class Graphiti:
|
|
|
281
281
|
created_at=now,
|
|
282
282
|
valid_at=reference_time,
|
|
283
283
|
)
|
|
284
|
+
episode.uuid = uuid if uuid is not None else episode.uuid
|
|
284
285
|
|
|
285
286
|
# Extract entities as nodes
|
|
286
287
|
|
|
@@ -293,7 +294,7 @@ class Graphiti:
|
|
|
293
294
|
*[node.generate_name_embedding(embedder) for node in extracted_nodes]
|
|
294
295
|
)
|
|
295
296
|
|
|
296
|
-
# Resolve extracted nodes with nodes already in the graph
|
|
297
|
+
# Resolve extracted nodes with nodes already in the graph and extract facts
|
|
297
298
|
existing_nodes_lists: list[list[EntityNode]] = list(
|
|
298
299
|
await asyncio.gather(
|
|
299
300
|
*[get_relevant_nodes([node], self.driver) for node in extracted_nodes]
|
|
@@ -302,22 +303,29 @@ class Graphiti:
|
|
|
302
303
|
|
|
303
304
|
logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
|
|
304
305
|
|
|
305
|
-
mentioned_nodes,
|
|
306
|
-
self.llm_client, extracted_nodes, existing_nodes_lists
|
|
306
|
+
(mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
|
|
307
|
+
resolve_extracted_nodes(self.llm_client, extracted_nodes, existing_nodes_lists),
|
|
308
|
+
extract_edges(
|
|
309
|
+
self.llm_client, episode, extracted_nodes, previous_episodes, group_id
|
|
310
|
+
),
|
|
307
311
|
)
|
|
308
312
|
logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
|
|
309
313
|
nodes.extend(mentioned_nodes)
|
|
310
314
|
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
self.llm_client, episode, mentioned_nodes, previous_episodes
|
|
315
|
+
extracted_edges_with_resolved_pointers = resolve_edge_pointers(
|
|
316
|
+
extracted_edges, uuid_map
|
|
314
317
|
)
|
|
315
318
|
|
|
316
319
|
# calculate embeddings
|
|
317
|
-
await asyncio.gather(
|
|
320
|
+
await asyncio.gather(
|
|
321
|
+
*[
|
|
322
|
+
edge.generate_embedding(embedder)
|
|
323
|
+
for edge in extracted_edges_with_resolved_pointers
|
|
324
|
+
]
|
|
325
|
+
)
|
|
318
326
|
|
|
319
|
-
# Resolve extracted edges with edges already in the graph
|
|
320
|
-
|
|
327
|
+
# Resolve extracted edges with related edges already in the graph
|
|
328
|
+
related_edges_list: list[list[EntityEdge]] = list(
|
|
321
329
|
await asyncio.gather(
|
|
322
330
|
*[
|
|
323
331
|
get_relevant_edges(
|
|
@@ -327,80 +335,68 @@ class Graphiti:
|
|
|
327
335
|
edge.target_node_uuid,
|
|
328
336
|
RELEVANT_SCHEMA_LIMIT,
|
|
329
337
|
)
|
|
330
|
-
for edge in
|
|
338
|
+
for edge in extracted_edges_with_resolved_pointers
|
|
331
339
|
]
|
|
332
340
|
)
|
|
333
341
|
)
|
|
334
342
|
logger.info(
|
|
335
|
-
f'
|
|
343
|
+
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_list for e in edges_lst]}'
|
|
336
344
|
)
|
|
337
|
-
logger.info(
|
|
338
|
-
|
|
339
|
-
deduped_edges: list[EntityEdge] = await resolve_extracted_edges(
|
|
340
|
-
self.llm_client, extracted_edges, existing_edges_list
|
|
345
|
+
logger.info(
|
|
346
|
+
f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges_with_resolved_pointers]}'
|
|
341
347
|
)
|
|
342
348
|
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
349
|
+
existing_source_edges_list: list[list[EntityEdge]] = list(
|
|
350
|
+
await asyncio.gather(
|
|
351
|
+
*[
|
|
352
|
+
get_relevant_edges(
|
|
353
|
+
self.driver,
|
|
354
|
+
[edge],
|
|
355
|
+
edge.source_node_uuid,
|
|
356
|
+
None,
|
|
357
|
+
RELEVANT_SCHEMA_LIMIT,
|
|
358
|
+
)
|
|
359
|
+
for edge in extracted_edges_with_resolved_pointers
|
|
360
|
+
]
|
|
361
|
+
)
|
|
354
362
|
)
|
|
355
363
|
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
364
|
+
existing_target_edges_list: list[list[EntityEdge]] = list(
|
|
365
|
+
await asyncio.gather(
|
|
366
|
+
*[
|
|
367
|
+
get_relevant_edges(
|
|
368
|
+
self.driver,
|
|
369
|
+
[edge],
|
|
370
|
+
None,
|
|
371
|
+
edge.target_node_uuid,
|
|
372
|
+
RELEVANT_SCHEMA_LIMIT,
|
|
373
|
+
)
|
|
374
|
+
for edge in extracted_edges_with_resolved_pointers
|
|
375
|
+
]
|
|
376
|
+
)
|
|
377
|
+
)
|
|
366
378
|
|
|
367
|
-
|
|
368
|
-
|
|
379
|
+
existing_edges_list: list[list[EntityEdge]] = [
|
|
380
|
+
source_lst + target_lst
|
|
381
|
+
for source_lst, target_lst in zip(
|
|
382
|
+
existing_source_edges_list, existing_target_edges_list
|
|
383
|
+
)
|
|
369
384
|
]
|
|
370
385
|
|
|
371
|
-
(
|
|
372
|
-
old_edges_with_nodes_pending_invalidation,
|
|
373
|
-
new_edges_with_nodes,
|
|
374
|
-
) = prepare_edges_for_invalidation(
|
|
375
|
-
existing_edges=existing_edges, new_edges=deduped_edges, nodes=nodes
|
|
376
|
-
)
|
|
377
|
-
|
|
378
|
-
invalidated_edges = await invalidate_edges(
|
|
386
|
+
resolved_edges, invalidated_edges = await resolve_extracted_edges(
|
|
379
387
|
self.llm_client,
|
|
380
|
-
|
|
381
|
-
|
|
388
|
+
extracted_edges_with_resolved_pointers,
|
|
389
|
+
related_edges_list,
|
|
390
|
+
existing_edges_list,
|
|
382
391
|
episode,
|
|
383
392
|
previous_episodes,
|
|
384
393
|
)
|
|
385
394
|
|
|
386
|
-
|
|
387
|
-
for existing_edge in existing_edges:
|
|
388
|
-
if existing_edge.uuid == edge.uuid:
|
|
389
|
-
existing_edge.expired_at = edge.expired_at
|
|
390
|
-
for deduped_edge in deduped_edges:
|
|
391
|
-
if deduped_edge.uuid == edge.uuid:
|
|
392
|
-
deduped_edge.expired_at = edge.expired_at
|
|
393
|
-
logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}')
|
|
395
|
+
entity_edges.extend(resolved_edges + invalidated_edges)
|
|
394
396
|
|
|
395
|
-
|
|
397
|
+
logger.info(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')
|
|
396
398
|
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
episodic_edges: list[EpisodicEdge] = build_episodic_edges(
|
|
400
|
-
mentioned_nodes,
|
|
401
|
-
episode,
|
|
402
|
-
now,
|
|
403
|
-
)
|
|
399
|
+
episodic_edges: list[EpisodicEdge] = build_episodic_edges(mentioned_nodes, episode, now)
|
|
404
400
|
|
|
405
401
|
logger.info(f'Built episodic edges: {episodic_edges}')
|
|
406
402
|
|
|
@@ -413,18 +409,10 @@ class Graphiti:
|
|
|
413
409
|
end = time()
|
|
414
410
|
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
|
|
415
411
|
|
|
416
|
-
if success_callback:
|
|
417
|
-
await success_callback(episode)
|
|
418
412
|
except Exception as e:
|
|
419
|
-
|
|
420
|
-
await error_callback(episode, e)
|
|
421
|
-
else:
|
|
422
|
-
raise e
|
|
413
|
+
raise e
|
|
423
414
|
|
|
424
|
-
async def add_episode_bulk(
|
|
425
|
-
self,
|
|
426
|
-
bulk_episodes: list[RawEpisode],
|
|
427
|
-
):
|
|
415
|
+
async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str | None):
|
|
428
416
|
"""
|
|
429
417
|
Process multiple episodes in bulk and update the graph.
|
|
430
418
|
|
|
@@ -435,6 +423,8 @@ class Graphiti:
|
|
|
435
423
|
----------
|
|
436
424
|
bulk_episodes : list[RawEpisode]
|
|
437
425
|
A list of RawEpisode objects to be processed and added to the graph.
|
|
426
|
+
group_id : str | None
|
|
427
|
+
An id for the graph partition the episode is a part of.
|
|
438
428
|
|
|
439
429
|
Returns
|
|
440
430
|
-------
|
|
@@ -471,6 +461,7 @@ class Graphiti:
|
|
|
471
461
|
source=episode.source,
|
|
472
462
|
content=episode.content,
|
|
473
463
|
source_description=episode.source_description,
|
|
464
|
+
group_id=group_id,
|
|
474
465
|
created_at=now,
|
|
475
466
|
valid_at=episode.reference_time,
|
|
476
467
|
)
|
|
@@ -535,7 +526,13 @@ class Graphiti:
|
|
|
535
526
|
except Exception as e:
|
|
536
527
|
raise e
|
|
537
528
|
|
|
538
|
-
async def search(
|
|
529
|
+
async def search(
|
|
530
|
+
self,
|
|
531
|
+
query: str,
|
|
532
|
+
center_node_uuid: str | None = None,
|
|
533
|
+
group_ids: list[str | None] | None = None,
|
|
534
|
+
num_results=10,
|
|
535
|
+
):
|
|
539
536
|
"""
|
|
540
537
|
Perform a hybrid search on the knowledge graph.
|
|
541
538
|
|
|
@@ -548,6 +545,8 @@ class Graphiti:
|
|
|
548
545
|
The search query string.
|
|
549
546
|
center_node_uuid: str, optional
|
|
550
547
|
Facts will be reranked based on proximity to this node
|
|
548
|
+
group_ids : list[str | None] | None, optional
|
|
549
|
+
The graph partitions to return data from.
|
|
551
550
|
num_results : int, optional
|
|
552
551
|
The maximum number of results to return. Defaults to 10.
|
|
553
552
|
|
|
@@ -570,6 +569,7 @@ class Graphiti:
|
|
|
570
569
|
num_episodes=0,
|
|
571
570
|
num_edges=num_results,
|
|
572
571
|
num_nodes=0,
|
|
572
|
+
group_ids=group_ids,
|
|
573
573
|
search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
|
|
574
574
|
reranker=reranker,
|
|
575
575
|
)
|
|
@@ -598,7 +598,10 @@ class Graphiti:
|
|
|
598
598
|
)
|
|
599
599
|
|
|
600
600
|
async def get_nodes_by_query(
|
|
601
|
-
self,
|
|
601
|
+
self,
|
|
602
|
+
query: str,
|
|
603
|
+
group_ids: list[str | None] | None = None,
|
|
604
|
+
limit: int = RELEVANT_SCHEMA_LIMIT,
|
|
602
605
|
) -> list[EntityNode]:
|
|
603
606
|
"""
|
|
604
607
|
Retrieve nodes from the graph database based on a text query.
|
|
@@ -610,6 +613,8 @@ class Graphiti:
|
|
|
610
613
|
----------
|
|
611
614
|
query : str
|
|
612
615
|
The text query to search for in the graph.
|
|
616
|
+
group_ids : list[str | None] | None, optional
|
|
617
|
+
The graph partitions to return data from.
|
|
613
618
|
limit : int | None, optional
|
|
614
619
|
The maximum number of results to return per search method.
|
|
615
620
|
If None, a default limit will be applied.
|
|
@@ -634,5 +639,7 @@ class Graphiti:
|
|
|
634
639
|
"""
|
|
635
640
|
embedder = self.llm_client.get_embedder()
|
|
636
641
|
query_embedding = await generate_embedding(embedder, query)
|
|
637
|
-
relevant_nodes = await hybrid_node_search(
|
|
642
|
+
relevant_nodes = await hybrid_node_search(
|
|
643
|
+
[query], [query_embedding], self.driver, group_ids, limit
|
|
644
|
+
)
|
|
638
645
|
return relevant_nodes
|
graphiti_core/nodes.py
CHANGED
|
@@ -19,10 +19,10 @@ from abc import ABC, abstractmethod
|
|
|
19
19
|
from datetime import datetime
|
|
20
20
|
from enum import Enum
|
|
21
21
|
from time import time
|
|
22
|
+
from typing import Any
|
|
22
23
|
from uuid import uuid4
|
|
23
24
|
|
|
24
25
|
from neo4j import AsyncDriver
|
|
25
|
-
from openai import OpenAI
|
|
26
26
|
from pydantic import BaseModel, Field
|
|
27
27
|
|
|
28
28
|
from graphiti_core.llm_client.config import EMBEDDING_DIM
|
|
@@ -69,6 +69,7 @@ class EpisodeType(Enum):
|
|
|
69
69
|
class Node(BaseModel, ABC):
|
|
70
70
|
uuid: str = Field(default_factory=lambda: uuid4().hex)
|
|
71
71
|
name: str = Field(description='name of the node')
|
|
72
|
+
group_id: str | None = Field(description='partition of the graph')
|
|
72
73
|
labels: list[str] = Field(default_factory=list)
|
|
73
74
|
created_at: datetime = Field(default_factory=lambda: datetime.now())
|
|
74
75
|
|
|
@@ -106,11 +107,12 @@ class EpisodicNode(Node):
|
|
|
106
107
|
result = await driver.execute_query(
|
|
107
108
|
"""
|
|
108
109
|
MERGE (n:Episodic {uuid: $uuid})
|
|
109
|
-
SET n = {uuid: $uuid, name: $name, source_description: $source_description, source: $source, content: $content,
|
|
110
|
+
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
|
|
110
111
|
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
|
|
111
112
|
RETURN n.uuid AS uuid""",
|
|
112
113
|
uuid=self.uuid,
|
|
113
114
|
name=self.name,
|
|
115
|
+
group_id=self.group_id,
|
|
114
116
|
source_description=self.source_description,
|
|
115
117
|
content=self.content,
|
|
116
118
|
entity_edges=self.entity_edges,
|
|
@@ -141,29 +143,19 @@ class EpisodicNode(Node):
|
|
|
141
143
|
records, _, _ = await driver.execute_query(
|
|
142
144
|
"""
|
|
143
145
|
MATCH (e:Episodic {uuid: $uuid})
|
|
144
|
-
RETURN e.content
|
|
145
|
-
e.created_at
|
|
146
|
-
e.valid_at
|
|
147
|
-
e.uuid
|
|
148
|
-
e.name
|
|
149
|
-
e.
|
|
150
|
-
e.
|
|
146
|
+
RETURN e.content AS content,
|
|
147
|
+
e.created_at AS created_at,
|
|
148
|
+
e.valid_at AS valid_at,
|
|
149
|
+
e.uuid AS uuid,
|
|
150
|
+
e.name AS name,
|
|
151
|
+
e.group_id AS group_id
|
|
152
|
+
e.source_description AS source_description,
|
|
153
|
+
e.source AS source
|
|
151
154
|
""",
|
|
152
155
|
uuid=uuid,
|
|
153
156
|
)
|
|
154
157
|
|
|
155
|
-
episodes = [
|
|
156
|
-
EpisodicNode(
|
|
157
|
-
content=record['content'],
|
|
158
|
-
created_at=record['created_at'].to_native().timestamp(),
|
|
159
|
-
valid_at=(record['valid_at'].to_native()),
|
|
160
|
-
uuid=record['uuid'],
|
|
161
|
-
source=EpisodeType.from_str(record['source']),
|
|
162
|
-
name=record['name'],
|
|
163
|
-
source_description=record['source_description'],
|
|
164
|
-
)
|
|
165
|
-
for record in records
|
|
166
|
-
]
|
|
158
|
+
episodes = [get_episodic_node_from_record(record) for record in records]
|
|
167
159
|
|
|
168
160
|
logger.info(f'Found Node: {uuid}')
|
|
169
161
|
|
|
@@ -174,10 +166,6 @@ class EntityNode(Node):
|
|
|
174
166
|
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
|
|
175
167
|
summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
|
|
176
168
|
|
|
177
|
-
async def update_summary(self, driver: AsyncDriver): ...
|
|
178
|
-
|
|
179
|
-
async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ...
|
|
180
|
-
|
|
181
169
|
async def generate_name_embedding(self, embedder, model='text-embedding-3-small'):
|
|
182
170
|
start = time()
|
|
183
171
|
text = self.name.replace('\n', ' ')
|
|
@@ -192,10 +180,11 @@ class EntityNode(Node):
|
|
|
192
180
|
result = await driver.execute_query(
|
|
193
181
|
"""
|
|
194
182
|
MERGE (n:Entity {uuid: $uuid})
|
|
195
|
-
SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, summary: $summary, created_at: $created_at}
|
|
183
|
+
SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at}
|
|
196
184
|
RETURN n.uuid AS uuid""",
|
|
197
185
|
uuid=self.uuid,
|
|
198
186
|
name=self.name,
|
|
187
|
+
group_id=self.group_id,
|
|
199
188
|
summary=self.summary,
|
|
200
189
|
name_embedding=self.name_embedding,
|
|
201
190
|
created_at=self.created_at,
|
|
@@ -225,26 +214,43 @@ class EntityNode(Node):
|
|
|
225
214
|
MATCH (n:Entity {uuid: $uuid})
|
|
226
215
|
RETURN
|
|
227
216
|
n.uuid As uuid,
|
|
228
|
-
n.name AS name,
|
|
217
|
+
n.name AS name,
|
|
218
|
+
n.name_embedding AS name_embedding,
|
|
219
|
+
n.group_id AS group_id
|
|
229
220
|
n.created_at AS created_at,
|
|
230
221
|
n.summary AS summary
|
|
231
222
|
""",
|
|
232
223
|
uuid=uuid,
|
|
233
224
|
)
|
|
234
225
|
|
|
235
|
-
nodes
|
|
236
|
-
|
|
237
|
-
for record in records:
|
|
238
|
-
nodes.append(
|
|
239
|
-
EntityNode(
|
|
240
|
-
uuid=record['uuid'],
|
|
241
|
-
name=record['name'],
|
|
242
|
-
labels=['Entity'],
|
|
243
|
-
created_at=record['created_at'].to_native(),
|
|
244
|
-
summary=record['summary'],
|
|
245
|
-
)
|
|
246
|
-
)
|
|
226
|
+
nodes = [get_entity_node_from_record(record) for record in records]
|
|
247
227
|
|
|
248
228
|
logger.info(f'Found Node: {uuid}')
|
|
249
229
|
|
|
250
230
|
return nodes[0]
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
# Node helpers
|
|
234
|
+
def get_episodic_node_from_record(record: Any) -> EpisodicNode:
|
|
235
|
+
return EpisodicNode(
|
|
236
|
+
content=record['content'],
|
|
237
|
+
created_at=record['created_at'].to_native().timestamp(),
|
|
238
|
+
valid_at=(record['valid_at'].to_native()),
|
|
239
|
+
uuid=record['uuid'],
|
|
240
|
+
group_id=record['group_id'],
|
|
241
|
+
source=EpisodeType.from_str(record['source']),
|
|
242
|
+
name=record['name'],
|
|
243
|
+
source_description=record['source_description'],
|
|
244
|
+
)
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def get_entity_node_from_record(record: Any) -> EntityNode:
|
|
248
|
+
return EntityNode(
|
|
249
|
+
uuid=record['uuid'],
|
|
250
|
+
name=record['name'],
|
|
251
|
+
group_id=record['group_id'],
|
|
252
|
+
name_embedding=record['name_embedding'],
|
|
253
|
+
labels=['Entity'],
|
|
254
|
+
created_at=record['created_at'].to_native(),
|
|
255
|
+
summary=record['summary'],
|
|
256
|
+
)
|
|
@@ -129,7 +129,7 @@ def v3(context: dict[str, Any]) -> list[Message]:
|
|
|
129
129
|
Given the following context, determine whether the New Edge represents any of the edges in the list of Existing Edges.
|
|
130
130
|
|
|
131
131
|
Existing Edges:
|
|
132
|
-
{json.dumps(context['
|
|
132
|
+
{json.dumps(context['related_edges'], indent=2)}
|
|
133
133
|
|
|
134
134
|
New Edge:
|
|
135
135
|
{json.dumps(context['extracted_edges'], indent=2)}
|