graphiti-core 0.2.2__tar.gz → 0.2.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. The original page linked to further details, but that link was not preserved in this copy.
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/PKG-INFO +2 -3
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/edges.py +39 -32
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/graphiti.py +45 -30
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/nodes.py +40 -39
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/search/search.py +5 -2
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/search/search_utils.py +74 -143
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/utils/bulk_utils.py +31 -3
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/utils/maintenance/edge_operations.py +7 -5
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/utils/maintenance/graph_data_operations.py +17 -8
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/utils/maintenance/node_operations.py +1 -0
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/pyproject.toml +3 -6
- graphiti_core-0.2.2/graphiti_core/utils/utils.py +0 -60
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/LICENSE +0 -0
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/README.md +0 -0
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/__init__.py +0 -0
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/helpers.py +0 -0
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/llm_client/__init__.py +0 -0
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/llm_client/anthropic_client.py +0 -0
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/llm_client/client.py +0 -0
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/llm_client/config.py +0 -0
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/llm_client/groq_client.py +0 -0
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/llm_client/openai_client.py +0 -0
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/llm_client/utils.py +0 -0
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/prompts/__init__.py +0 -0
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/prompts/dedupe_edges.py +0 -0
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/prompts/dedupe_nodes.py +0 -0
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/prompts/extract_edge_dates.py +0 -0
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/prompts/extract_edges.py +0 -0
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/prompts/extract_nodes.py +0 -0
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/prompts/invalidate_edges.py +0 -0
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/prompts/lib.py +0 -0
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/prompts/models.py +0 -0
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/search/__init__.py +0 -0
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/utils/__init__.py +0 -0
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/utils/maintenance/__init__.py +0 -0
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/utils/maintenance/temporal_operations.py +0 -0
- {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/utils/maintenance/utils.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: graphiti-core
|
|
3
|
-
Version: 0.2.2
|
|
3
|
+
Version: 0.2.3
|
|
4
4
|
Summary: A temporal graph building library
|
|
5
5
|
License: Apache-2.0
|
|
6
6
|
Author: Paul Paliychuk
|
|
@@ -12,11 +12,10 @@ Classifier: Programming Language :: Python :: 3.10
|
|
|
12
12
|
Classifier: Programming Language :: Python :: 3.11
|
|
13
13
|
Classifier: Programming Language :: Python :: 3.12
|
|
14
14
|
Requires-Dist: diskcache (>=5.6.3,<6.0.0)
|
|
15
|
-
Requires-Dist: fastapi (>=0.112.0,<0.113.0)
|
|
16
15
|
Requires-Dist: neo4j (>=5.23.0,<6.0.0)
|
|
16
|
+
Requires-Dist: numpy (>=2.1.1,<3.0.0)
|
|
17
17
|
Requires-Dist: openai (>=1.38.0,<2.0.0)
|
|
18
18
|
Requires-Dist: pydantic (>=2.8.2,<3.0.0)
|
|
19
|
-
Requires-Dist: sentence-transformers (>=3.0.1,<4.0.0)
|
|
20
19
|
Requires-Dist: tenacity (<9.0.0)
|
|
21
20
|
Description-Content-Type: text/markdown
|
|
22
21
|
|
|
@@ -18,6 +18,7 @@ import logging
|
|
|
18
18
|
from abc import ABC, abstractmethod
|
|
19
19
|
from datetime import datetime
|
|
20
20
|
from time import time
|
|
21
|
+
from typing import Any
|
|
21
22
|
from uuid import uuid4
|
|
22
23
|
|
|
23
24
|
from neo4j import AsyncDriver
|
|
@@ -32,6 +33,7 @@ logger = logging.getLogger(__name__)
|
|
|
32
33
|
|
|
33
34
|
class Edge(BaseModel, ABC):
|
|
34
35
|
uuid: str = Field(default_factory=lambda: uuid4().hex)
|
|
36
|
+
group_id: str | None = Field(description='partition of the graph')
|
|
35
37
|
source_node_uuid: str
|
|
36
38
|
target_node_uuid: str
|
|
37
39
|
created_at: datetime
|
|
@@ -61,11 +63,12 @@ class EpisodicEdge(Edge):
|
|
|
61
63
|
MATCH (episode:Episodic {uuid: $episode_uuid})
|
|
62
64
|
MATCH (node:Entity {uuid: $entity_uuid})
|
|
63
65
|
MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
|
|
64
|
-
SET r = {uuid: $uuid, created_at: $created_at}
|
|
66
|
+
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
|
65
67
|
RETURN r.uuid AS uuid""",
|
|
66
68
|
episode_uuid=self.source_node_uuid,
|
|
67
69
|
entity_uuid=self.target_node_uuid,
|
|
68
70
|
uuid=self.uuid,
|
|
71
|
+
group_id=self.group_id,
|
|
69
72
|
created_at=self.created_at,
|
|
70
73
|
)
|
|
71
74
|
|
|
@@ -92,7 +95,8 @@ class EpisodicEdge(Edge):
|
|
|
92
95
|
"""
|
|
93
96
|
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
|
|
94
97
|
RETURN
|
|
95
|
-
e.uuid As uuid,
|
|
98
|
+
e.uuid As uuid,
|
|
99
|
+
e.group_id AS group_id,
|
|
96
100
|
n.uuid AS source_node_uuid,
|
|
97
101
|
m.uuid AS target_node_uuid,
|
|
98
102
|
e.created_at AS created_at
|
|
@@ -100,17 +104,7 @@ class EpisodicEdge(Edge):
|
|
|
100
104
|
uuid=uuid,
|
|
101
105
|
)
|
|
102
106
|
|
|
103
|
-
edges
|
|
104
|
-
|
|
105
|
-
for record in records:
|
|
106
|
-
edges.append(
|
|
107
|
-
EpisodicEdge(
|
|
108
|
-
uuid=record['uuid'],
|
|
109
|
-
source_node_uuid=record['source_node_uuid'],
|
|
110
|
-
target_node_uuid=record['target_node_uuid'],
|
|
111
|
-
created_at=record['created_at'].to_native(),
|
|
112
|
-
)
|
|
113
|
-
)
|
|
107
|
+
edges = [get_episodic_edge_from_record(record) for record in records]
|
|
114
108
|
|
|
115
109
|
logger.info(f'Found Edge: {uuid}')
|
|
116
110
|
|
|
@@ -153,7 +147,7 @@ class EntityEdge(Edge):
|
|
|
153
147
|
MATCH (source:Entity {uuid: $source_uuid})
|
|
154
148
|
MATCH (target:Entity {uuid: $target_uuid})
|
|
155
149
|
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
|
|
156
|
-
SET r = {uuid: $uuid, name: $name, fact: $fact, fact_embedding: $fact_embedding,
|
|
150
|
+
SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, fact_embedding: $fact_embedding,
|
|
157
151
|
episodes: $episodes, created_at: $created_at, expired_at: $expired_at,
|
|
158
152
|
valid_at: $valid_at, invalid_at: $invalid_at}
|
|
159
153
|
RETURN r.uuid AS uuid""",
|
|
@@ -161,6 +155,7 @@ class EntityEdge(Edge):
|
|
|
161
155
|
target_uuid=self.target_node_uuid,
|
|
162
156
|
uuid=self.uuid,
|
|
163
157
|
name=self.name,
|
|
158
|
+
group_id=self.group_id,
|
|
164
159
|
fact=self.fact,
|
|
165
160
|
fact_embedding=self.fact_embedding,
|
|
166
161
|
episodes=self.episodes,
|
|
@@ -198,6 +193,7 @@ class EntityEdge(Edge):
|
|
|
198
193
|
m.uuid AS target_node_uuid,
|
|
199
194
|
e.created_at AS created_at,
|
|
200
195
|
e.name AS name,
|
|
196
|
+
e.group_id AS group_id,
|
|
201
197
|
e.fact AS fact,
|
|
202
198
|
e.fact_embedding AS fact_embedding,
|
|
203
199
|
e.episodes AS episodes,
|
|
@@ -208,25 +204,36 @@ class EntityEdge(Edge):
|
|
|
208
204
|
uuid=uuid,
|
|
209
205
|
)
|
|
210
206
|
|
|
211
|
-
edges
|
|
212
|
-
|
|
213
|
-
for record in records:
|
|
214
|
-
edges.append(
|
|
215
|
-
EntityEdge(
|
|
216
|
-
uuid=record['uuid'],
|
|
217
|
-
source_node_uuid=record['source_node_uuid'],
|
|
218
|
-
target_node_uuid=record['target_node_uuid'],
|
|
219
|
-
fact=record['fact'],
|
|
220
|
-
name=record['name'],
|
|
221
|
-
episodes=record['episodes'],
|
|
222
|
-
fact_embedding=record['fact_embedding'],
|
|
223
|
-
created_at=record['created_at'].to_native(),
|
|
224
|
-
expired_at=parse_db_date(record['expired_at']),
|
|
225
|
-
valid_at=parse_db_date(record['valid_at']),
|
|
226
|
-
invalid_at=parse_db_date(record['invalid_at']),
|
|
227
|
-
)
|
|
228
|
-
)
|
|
207
|
+
edges = [get_entity_edge_from_record(record) for record in records]
|
|
229
208
|
|
|
230
209
|
logger.info(f'Found Edge: {uuid}')
|
|
231
210
|
|
|
232
211
|
return edges[0]
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
# Edge helpers
|
|
215
|
+
def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
|
|
216
|
+
return EpisodicEdge(
|
|
217
|
+
uuid=record['uuid'],
|
|
218
|
+
group_id=record['group_id'],
|
|
219
|
+
source_node_uuid=record['source_node_uuid'],
|
|
220
|
+
target_node_uuid=record['target_node_uuid'],
|
|
221
|
+
created_at=record['created_at'].to_native(),
|
|
222
|
+
)
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def get_entity_edge_from_record(record: Any) -> EntityEdge:
|
|
226
|
+
return EntityEdge(
|
|
227
|
+
uuid=record['uuid'],
|
|
228
|
+
source_node_uuid=record['source_node_uuid'],
|
|
229
|
+
target_node_uuid=record['target_node_uuid'],
|
|
230
|
+
fact=record['fact'],
|
|
231
|
+
name=record['name'],
|
|
232
|
+
group_id=record['group_id'],
|
|
233
|
+
episodes=record['episodes'],
|
|
234
|
+
fact_embedding=record['fact_embedding'],
|
|
235
|
+
created_at=record['created_at'].to_native(),
|
|
236
|
+
expired_at=parse_db_date(record['expired_at']),
|
|
237
|
+
valid_at=parse_db_date(record['valid_at']),
|
|
238
|
+
invalid_at=parse_db_date(record['invalid_at']),
|
|
239
|
+
)
|
|
@@ -18,7 +18,6 @@ import asyncio
|
|
|
18
18
|
import logging
|
|
19
19
|
from datetime import datetime
|
|
20
20
|
from time import time
|
|
21
|
-
from typing import Callable
|
|
22
21
|
|
|
23
22
|
from dotenv import load_dotenv
|
|
24
23
|
from neo4j import AsyncGraphDatabase
|
|
@@ -120,7 +119,7 @@ class Graphiti:
|
|
|
120
119
|
|
|
121
120
|
Parameters
|
|
122
121
|
----------
|
|
123
|
-
|
|
122
|
+
self
|
|
124
123
|
|
|
125
124
|
Returns
|
|
126
125
|
-------
|
|
@@ -151,7 +150,7 @@ class Graphiti:
|
|
|
151
150
|
|
|
152
151
|
Parameters
|
|
153
152
|
----------
|
|
154
|
-
|
|
153
|
+
self
|
|
155
154
|
|
|
156
155
|
Returns
|
|
157
156
|
-------
|
|
@@ -178,6 +177,7 @@ class Graphiti:
|
|
|
178
177
|
self,
|
|
179
178
|
reference_time: datetime,
|
|
180
179
|
last_n: int = EPISODE_WINDOW_LEN,
|
|
180
|
+
group_ids: list[str | None] | None = None,
|
|
181
181
|
) -> list[EpisodicNode]:
|
|
182
182
|
"""
|
|
183
183
|
Retrieve the last n episodic nodes from the graph.
|
|
@@ -191,6 +191,8 @@ class Graphiti:
|
|
|
191
191
|
The reference time to retrieve episodes before.
|
|
192
192
|
last_n : int, optional
|
|
193
193
|
The number of episodes to retrieve. Defaults to EPISODE_WINDOW_LEN.
|
|
194
|
+
group_ids : list[str | None], optional
|
|
195
|
+
The group ids to return data from.
|
|
194
196
|
|
|
195
197
|
Returns
|
|
196
198
|
-------
|
|
@@ -202,7 +204,7 @@ class Graphiti:
|
|
|
202
204
|
The actual retrieval is performed by the `retrieve_episodes` function
|
|
203
205
|
from the `graphiti_core.utils` module.
|
|
204
206
|
"""
|
|
205
|
-
return await retrieve_episodes(self.driver, reference_time, last_n)
|
|
207
|
+
return await retrieve_episodes(self.driver, reference_time, last_n, group_ids)
|
|
206
208
|
|
|
207
209
|
async def add_episode(
|
|
208
210
|
self,
|
|
@@ -211,8 +213,8 @@ class Graphiti:
|
|
|
211
213
|
source_description: str,
|
|
212
214
|
reference_time: datetime,
|
|
213
215
|
source: EpisodeType = EpisodeType.message,
|
|
214
|
-
|
|
215
|
-
|
|
216
|
+
group_id: str | None = None,
|
|
217
|
+
uuid: str | None = None,
|
|
216
218
|
):
|
|
217
219
|
"""
|
|
218
220
|
Process an episode and update the graph.
|
|
@@ -232,10 +234,10 @@ class Graphiti:
|
|
|
232
234
|
The reference time for the episode.
|
|
233
235
|
source : EpisodeType, optional
|
|
234
236
|
The type of the episode. Defaults to EpisodeType.message.
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
237
|
+
group_id : str | None
|
|
238
|
+
An id for the graph partition the episode is a part of.
|
|
239
|
+
uuid : str | None
|
|
240
|
+
Optional uuid of the episode.
|
|
239
241
|
|
|
240
242
|
Returns
|
|
241
243
|
-------
|
|
@@ -266,9 +268,12 @@ class Graphiti:
|
|
|
266
268
|
embedder = self.llm_client.get_embedder()
|
|
267
269
|
now = datetime.now()
|
|
268
270
|
|
|
269
|
-
previous_episodes = await self.retrieve_episodes(
|
|
271
|
+
previous_episodes = await self.retrieve_episodes(
|
|
272
|
+
reference_time, last_n=3, group_ids=[group_id]
|
|
273
|
+
)
|
|
270
274
|
episode = EpisodicNode(
|
|
271
275
|
name=name,
|
|
276
|
+
group_id=group_id,
|
|
272
277
|
labels=[],
|
|
273
278
|
source=source,
|
|
274
279
|
content=episode_body,
|
|
@@ -276,6 +281,7 @@ class Graphiti:
|
|
|
276
281
|
created_at=now,
|
|
277
282
|
valid_at=reference_time,
|
|
278
283
|
)
|
|
284
|
+
episode.uuid = uuid if uuid is not None else episode.uuid
|
|
279
285
|
|
|
280
286
|
# Extract entities as nodes
|
|
281
287
|
|
|
@@ -299,7 +305,9 @@ class Graphiti:
|
|
|
299
305
|
|
|
300
306
|
(mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
|
|
301
307
|
resolve_extracted_nodes(self.llm_client, extracted_nodes, existing_nodes_lists),
|
|
302
|
-
extract_edges(
|
|
308
|
+
extract_edges(
|
|
309
|
+
self.llm_client, episode, extracted_nodes, previous_episodes, group_id
|
|
310
|
+
),
|
|
303
311
|
)
|
|
304
312
|
logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
|
|
305
313
|
nodes.extend(mentioned_nodes)
|
|
@@ -388,11 +396,7 @@ class Graphiti:
|
|
|
388
396
|
|
|
389
397
|
logger.info(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')
|
|
390
398
|
|
|
391
|
-
episodic_edges: list[EpisodicEdge] = build_episodic_edges(
|
|
392
|
-
mentioned_nodes,
|
|
393
|
-
episode,
|
|
394
|
-
now,
|
|
395
|
-
)
|
|
399
|
+
episodic_edges: list[EpisodicEdge] = build_episodic_edges(mentioned_nodes, episode, now)
|
|
396
400
|
|
|
397
401
|
logger.info(f'Built episodic edges: {episodic_edges}')
|
|
398
402
|
|
|
@@ -405,18 +409,10 @@ class Graphiti:
|
|
|
405
409
|
end = time()
|
|
406
410
|
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
|
|
407
411
|
|
|
408
|
-
if success_callback:
|
|
409
|
-
await success_callback(episode)
|
|
410
412
|
except Exception as e:
|
|
411
|
-
|
|
412
|
-
await error_callback(episode, e)
|
|
413
|
-
else:
|
|
414
|
-
raise e
|
|
413
|
+
raise e
|
|
415
414
|
|
|
416
|
-
async def add_episode_bulk(
|
|
417
|
-
self,
|
|
418
|
-
bulk_episodes: list[RawEpisode],
|
|
419
|
-
):
|
|
415
|
+
async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str | None):
|
|
420
416
|
"""
|
|
421
417
|
Process multiple episodes in bulk and update the graph.
|
|
422
418
|
|
|
@@ -427,6 +423,8 @@ class Graphiti:
|
|
|
427
423
|
----------
|
|
428
424
|
bulk_episodes : list[RawEpisode]
|
|
429
425
|
A list of RawEpisode objects to be processed and added to the graph.
|
|
426
|
+
group_id : str | None
|
|
427
|
+
An id for the graph partition the episode is a part of.
|
|
430
428
|
|
|
431
429
|
Returns
|
|
432
430
|
-------
|
|
@@ -463,6 +461,7 @@ class Graphiti:
|
|
|
463
461
|
source=episode.source,
|
|
464
462
|
content=episode.content,
|
|
465
463
|
source_description=episode.source_description,
|
|
464
|
+
group_id=group_id,
|
|
466
465
|
created_at=now,
|
|
467
466
|
valid_at=episode.reference_time,
|
|
468
467
|
)
|
|
@@ -527,7 +526,13 @@ class Graphiti:
|
|
|
527
526
|
except Exception as e:
|
|
528
527
|
raise e
|
|
529
528
|
|
|
530
|
-
async def search(
|
|
529
|
+
async def search(
|
|
530
|
+
self,
|
|
531
|
+
query: str,
|
|
532
|
+
center_node_uuid: str | None = None,
|
|
533
|
+
group_ids: list[str | None] | None = None,
|
|
534
|
+
num_results=10,
|
|
535
|
+
):
|
|
531
536
|
"""
|
|
532
537
|
Perform a hybrid search on the knowledge graph.
|
|
533
538
|
|
|
@@ -540,6 +545,8 @@ class Graphiti:
|
|
|
540
545
|
The search query string.
|
|
541
546
|
center_node_uuid: str, optional
|
|
542
547
|
Facts will be reranked based on proximity to this node
|
|
548
|
+
group_ids : list[str | None] | None, optional
|
|
549
|
+
The graph partitions to return data from.
|
|
543
550
|
num_results : int, optional
|
|
544
551
|
The maximum number of results to return. Defaults to 10.
|
|
545
552
|
|
|
@@ -562,6 +569,7 @@ class Graphiti:
|
|
|
562
569
|
num_episodes=0,
|
|
563
570
|
num_edges=num_results,
|
|
564
571
|
num_nodes=0,
|
|
572
|
+
group_ids=group_ids,
|
|
565
573
|
search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
|
|
566
574
|
reranker=reranker,
|
|
567
575
|
)
|
|
@@ -590,7 +598,10 @@ class Graphiti:
|
|
|
590
598
|
)
|
|
591
599
|
|
|
592
600
|
async def get_nodes_by_query(
|
|
593
|
-
self,
|
|
601
|
+
self,
|
|
602
|
+
query: str,
|
|
603
|
+
group_ids: list[str | None] | None = None,
|
|
604
|
+
limit: int = RELEVANT_SCHEMA_LIMIT,
|
|
594
605
|
) -> list[EntityNode]:
|
|
595
606
|
"""
|
|
596
607
|
Retrieve nodes from the graph database based on a text query.
|
|
@@ -602,6 +613,8 @@ class Graphiti:
|
|
|
602
613
|
----------
|
|
603
614
|
query : str
|
|
604
615
|
The text query to search for in the graph.
|
|
616
|
+
group_ids : list[str | None] | None, optional
|
|
617
|
+
The graph partitions to return data from.
|
|
605
618
|
limit : int | None, optional
|
|
606
619
|
The maximum number of results to return per search method.
|
|
607
620
|
If None, a default limit will be applied.
|
|
@@ -626,5 +639,7 @@ class Graphiti:
|
|
|
626
639
|
"""
|
|
627
640
|
embedder = self.llm_client.get_embedder()
|
|
628
641
|
query_embedding = await generate_embedding(embedder, query)
|
|
629
|
-
relevant_nodes = await hybrid_node_search(
|
|
642
|
+
relevant_nodes = await hybrid_node_search(
|
|
643
|
+
[query], [query_embedding], self.driver, group_ids, limit
|
|
644
|
+
)
|
|
630
645
|
return relevant_nodes
|
|
@@ -19,10 +19,10 @@ from abc import ABC, abstractmethod
|
|
|
19
19
|
from datetime import datetime
|
|
20
20
|
from enum import Enum
|
|
21
21
|
from time import time
|
|
22
|
+
from typing import Any
|
|
22
23
|
from uuid import uuid4
|
|
23
24
|
|
|
24
25
|
from neo4j import AsyncDriver
|
|
25
|
-
from openai import OpenAI
|
|
26
26
|
from pydantic import BaseModel, Field
|
|
27
27
|
|
|
28
28
|
from graphiti_core.llm_client.config import EMBEDDING_DIM
|
|
@@ -69,6 +69,7 @@ class EpisodeType(Enum):
|
|
|
69
69
|
class Node(BaseModel, ABC):
|
|
70
70
|
uuid: str = Field(default_factory=lambda: uuid4().hex)
|
|
71
71
|
name: str = Field(description='name of the node')
|
|
72
|
+
group_id: str | None = Field(description='partition of the graph')
|
|
72
73
|
labels: list[str] = Field(default_factory=list)
|
|
73
74
|
created_at: datetime = Field(default_factory=lambda: datetime.now())
|
|
74
75
|
|
|
@@ -106,11 +107,12 @@ class EpisodicNode(Node):
|
|
|
106
107
|
result = await driver.execute_query(
|
|
107
108
|
"""
|
|
108
109
|
MERGE (n:Episodic {uuid: $uuid})
|
|
109
|
-
SET n = {uuid: $uuid, name: $name, source_description: $source_description, source: $source, content: $content,
|
|
110
|
+
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
|
|
110
111
|
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
|
|
111
112
|
RETURN n.uuid AS uuid""",
|
|
112
113
|
uuid=self.uuid,
|
|
113
114
|
name=self.name,
|
|
115
|
+
group_id=self.group_id,
|
|
114
116
|
source_description=self.source_description,
|
|
115
117
|
content=self.content,
|
|
116
118
|
entity_edges=self.entity_edges,
|
|
@@ -141,29 +143,19 @@ class EpisodicNode(Node):
|
|
|
141
143
|
records, _, _ = await driver.execute_query(
|
|
142
144
|
"""
|
|
143
145
|
MATCH (e:Episodic {uuid: $uuid})
|
|
144
|
-
RETURN e.content
|
|
145
|
-
e.created_at
|
|
146
|
-
e.valid_at
|
|
147
|
-
e.uuid
|
|
148
|
-
e.name
|
|
149
|
-
e.
|
|
150
|
-
e.
|
|
146
|
+
RETURN e.content AS content,
|
|
147
|
+
e.created_at AS created_at,
|
|
148
|
+
e.valid_at AS valid_at,
|
|
149
|
+
e.uuid AS uuid,
|
|
150
|
+
e.name AS name,
|
|
151
|
+
e.group_id AS group_id
|
|
152
|
+
e.source_description AS source_description,
|
|
153
|
+
e.source AS source
|
|
151
154
|
""",
|
|
152
155
|
uuid=uuid,
|
|
153
156
|
)
|
|
154
157
|
|
|
155
|
-
episodes = [
|
|
156
|
-
EpisodicNode(
|
|
157
|
-
content=record['content'],
|
|
158
|
-
created_at=record['created_at'].to_native().timestamp(),
|
|
159
|
-
valid_at=(record['valid_at'].to_native()),
|
|
160
|
-
uuid=record['uuid'],
|
|
161
|
-
source=EpisodeType.from_str(record['source']),
|
|
162
|
-
name=record['name'],
|
|
163
|
-
source_description=record['source_description'],
|
|
164
|
-
)
|
|
165
|
-
for record in records
|
|
166
|
-
]
|
|
158
|
+
episodes = [get_episodic_node_from_record(record) for record in records]
|
|
167
159
|
|
|
168
160
|
logger.info(f'Found Node: {uuid}')
|
|
169
161
|
|
|
@@ -174,10 +166,6 @@ class EntityNode(Node):
|
|
|
174
166
|
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
|
|
175
167
|
summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
|
|
176
168
|
|
|
177
|
-
async def update_summary(self, driver: AsyncDriver): ...
|
|
178
|
-
|
|
179
|
-
async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ...
|
|
180
|
-
|
|
181
169
|
async def generate_name_embedding(self, embedder, model='text-embedding-3-small'):
|
|
182
170
|
start = time()
|
|
183
171
|
text = self.name.replace('\n', ' ')
|
|
@@ -192,10 +180,11 @@ class EntityNode(Node):
|
|
|
192
180
|
result = await driver.execute_query(
|
|
193
181
|
"""
|
|
194
182
|
MERGE (n:Entity {uuid: $uuid})
|
|
195
|
-
SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, summary: $summary, created_at: $created_at}
|
|
183
|
+
SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at}
|
|
196
184
|
RETURN n.uuid AS uuid""",
|
|
197
185
|
uuid=self.uuid,
|
|
198
186
|
name=self.name,
|
|
187
|
+
group_id=self.group_id,
|
|
199
188
|
summary=self.summary,
|
|
200
189
|
name_embedding=self.name_embedding,
|
|
201
190
|
created_at=self.created_at,
|
|
@@ -227,25 +216,14 @@ class EntityNode(Node):
|
|
|
227
216
|
n.uuid As uuid,
|
|
228
217
|
n.name AS name,
|
|
229
218
|
n.name_embedding AS name_embedding,
|
|
219
|
+
n.group_id AS group_id
|
|
230
220
|
n.created_at AS created_at,
|
|
231
221
|
n.summary AS summary
|
|
232
222
|
""",
|
|
233
223
|
uuid=uuid,
|
|
234
224
|
)
|
|
235
225
|
|
|
236
|
-
nodes
|
|
237
|
-
|
|
238
|
-
for record in records:
|
|
239
|
-
nodes.append(
|
|
240
|
-
EntityNode(
|
|
241
|
-
uuid=record['uuid'],
|
|
242
|
-
name=record['name'],
|
|
243
|
-
name_embedding=record['name_embedding'],
|
|
244
|
-
labels=['Entity'],
|
|
245
|
-
created_at=record['created_at'].to_native(),
|
|
246
|
-
summary=record['summary'],
|
|
247
|
-
)
|
|
248
|
-
)
|
|
226
|
+
nodes = [get_entity_node_from_record(record) for record in records]
|
|
249
227
|
|
|
250
228
|
logger.info(f'Found Node: {uuid}')
|
|
251
229
|
|
|
@@ -253,3 +231,26 @@ class EntityNode(Node):
|
|
|
253
231
|
|
|
254
232
|
|
|
255
233
|
# Node helpers
|
|
234
|
+
def get_episodic_node_from_record(record: Any) -> EpisodicNode:
|
|
235
|
+
return EpisodicNode(
|
|
236
|
+
content=record['content'],
|
|
237
|
+
created_at=record['created_at'].to_native().timestamp(),
|
|
238
|
+
valid_at=(record['valid_at'].to_native()),
|
|
239
|
+
uuid=record['uuid'],
|
|
240
|
+
group_id=record['group_id'],
|
|
241
|
+
source=EpisodeType.from_str(record['source']),
|
|
242
|
+
name=record['name'],
|
|
243
|
+
source_description=record['source_description'],
|
|
244
|
+
)
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def get_entity_node_from_record(record: Any) -> EntityNode:
|
|
248
|
+
return EntityNode(
|
|
249
|
+
uuid=record['uuid'],
|
|
250
|
+
name=record['name'],
|
|
251
|
+
group_id=record['group_id'],
|
|
252
|
+
name_embedding=record['name_embedding'],
|
|
253
|
+
labels=['Entity'],
|
|
254
|
+
created_at=record['created_at'].to_native(),
|
|
255
|
+
summary=record['summary'],
|
|
256
|
+
)
|
|
@@ -52,6 +52,7 @@ class SearchConfig(BaseModel):
|
|
|
52
52
|
num_edges: int = Field(default=10)
|
|
53
53
|
num_nodes: int = Field(default=10)
|
|
54
54
|
num_episodes: int = EPISODE_WINDOW_LEN
|
|
55
|
+
group_ids: list[str | None] | None
|
|
55
56
|
search_methods: list[SearchMethod]
|
|
56
57
|
reranker: Reranker | None
|
|
57
58
|
|
|
@@ -83,7 +84,9 @@ async def hybrid_search(
|
|
|
83
84
|
nodes.extend(await get_mentioned_nodes(driver, episodes))
|
|
84
85
|
|
|
85
86
|
if SearchMethod.bm25 in config.search_methods:
|
|
86
|
-
text_search = await edge_fulltext_search(
|
|
87
|
+
text_search = await edge_fulltext_search(
|
|
88
|
+
driver, query, None, None, config.group_ids, 2 * config.num_edges
|
|
89
|
+
)
|
|
87
90
|
search_results.append(text_search)
|
|
88
91
|
|
|
89
92
|
if SearchMethod.cosine_similarity in config.search_methods:
|
|
@@ -95,7 +98,7 @@ async def hybrid_search(
|
|
|
95
98
|
)
|
|
96
99
|
|
|
97
100
|
similarity_search = await edge_similarity_search(
|
|
98
|
-
driver, search_vector, None, None, 2 * config.num_edges
|
|
101
|
+
driver, search_vector, None, None, config.group_ids, 2 * config.num_edges
|
|
99
102
|
)
|
|
100
103
|
search_results.append(similarity_search)
|
|
101
104
|
|