graphiti-core 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/edges.py +39 -32
- graphiti_core/graphiti.py +45 -30
- graphiti_core/nodes.py +40 -39
- graphiti_core/search/search.py +5 -2
- graphiti_core/search/search_utils.py +74 -143
- graphiti_core/utils/bulk_utils.py +31 -3
- graphiti_core/utils/maintenance/edge_operations.py +7 -5
- graphiti_core/utils/maintenance/graph_data_operations.py +17 -8
- graphiti_core/utils/maintenance/node_operations.py +1 -0
- {graphiti_core-0.2.2.dist-info → graphiti_core-0.2.3.dist-info}/METADATA +2 -3
- {graphiti_core-0.2.2.dist-info → graphiti_core-0.2.3.dist-info}/RECORD +13 -14
- graphiti_core/utils/utils.py +0 -60
- {graphiti_core-0.2.2.dist-info → graphiti_core-0.2.3.dist-info}/LICENSE +0 -0
- {graphiti_core-0.2.2.dist-info → graphiti_core-0.2.3.dist-info}/WHEEL +0 -0
graphiti_core/edges.py
CHANGED
|
@@ -18,6 +18,7 @@ import logging
|
|
|
18
18
|
from abc import ABC, abstractmethod
|
|
19
19
|
from datetime import datetime
|
|
20
20
|
from time import time
|
|
21
|
+
from typing import Any
|
|
21
22
|
from uuid import uuid4
|
|
22
23
|
|
|
23
24
|
from neo4j import AsyncDriver
|
|
@@ -32,6 +33,7 @@ logger = logging.getLogger(__name__)
|
|
|
32
33
|
|
|
33
34
|
class Edge(BaseModel, ABC):
|
|
34
35
|
uuid: str = Field(default_factory=lambda: uuid4().hex)
|
|
36
|
+
group_id: str | None = Field(description='partition of the graph')
|
|
35
37
|
source_node_uuid: str
|
|
36
38
|
target_node_uuid: str
|
|
37
39
|
created_at: datetime
|
|
@@ -61,11 +63,12 @@ class EpisodicEdge(Edge):
|
|
|
61
63
|
MATCH (episode:Episodic {uuid: $episode_uuid})
|
|
62
64
|
MATCH (node:Entity {uuid: $entity_uuid})
|
|
63
65
|
MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
|
|
64
|
-
SET r = {uuid: $uuid, created_at: $created_at}
|
|
66
|
+
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
|
65
67
|
RETURN r.uuid AS uuid""",
|
|
66
68
|
episode_uuid=self.source_node_uuid,
|
|
67
69
|
entity_uuid=self.target_node_uuid,
|
|
68
70
|
uuid=self.uuid,
|
|
71
|
+
group_id=self.group_id,
|
|
69
72
|
created_at=self.created_at,
|
|
70
73
|
)
|
|
71
74
|
|
|
@@ -92,7 +95,8 @@ class EpisodicEdge(Edge):
|
|
|
92
95
|
"""
|
|
93
96
|
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
|
|
94
97
|
RETURN
|
|
95
|
-
e.uuid As uuid,
|
|
98
|
+
e.uuid As uuid,
|
|
99
|
+
e.group_id AS group_id,
|
|
96
100
|
n.uuid AS source_node_uuid,
|
|
97
101
|
m.uuid AS target_node_uuid,
|
|
98
102
|
e.created_at AS created_at
|
|
@@ -100,17 +104,7 @@ class EpisodicEdge(Edge):
|
|
|
100
104
|
uuid=uuid,
|
|
101
105
|
)
|
|
102
106
|
|
|
103
|
-
edges
|
|
104
|
-
|
|
105
|
-
for record in records:
|
|
106
|
-
edges.append(
|
|
107
|
-
EpisodicEdge(
|
|
108
|
-
uuid=record['uuid'],
|
|
109
|
-
source_node_uuid=record['source_node_uuid'],
|
|
110
|
-
target_node_uuid=record['target_node_uuid'],
|
|
111
|
-
created_at=record['created_at'].to_native(),
|
|
112
|
-
)
|
|
113
|
-
)
|
|
107
|
+
edges = [get_episodic_edge_from_record(record) for record in records]
|
|
114
108
|
|
|
115
109
|
logger.info(f'Found Edge: {uuid}')
|
|
116
110
|
|
|
@@ -153,7 +147,7 @@ class EntityEdge(Edge):
|
|
|
153
147
|
MATCH (source:Entity {uuid: $source_uuid})
|
|
154
148
|
MATCH (target:Entity {uuid: $target_uuid})
|
|
155
149
|
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
|
|
156
|
-
SET r = {uuid: $uuid, name: $name, fact: $fact, fact_embedding: $fact_embedding,
|
|
150
|
+
SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, fact_embedding: $fact_embedding,
|
|
157
151
|
episodes: $episodes, created_at: $created_at, expired_at: $expired_at,
|
|
158
152
|
valid_at: $valid_at, invalid_at: $invalid_at}
|
|
159
153
|
RETURN r.uuid AS uuid""",
|
|
@@ -161,6 +155,7 @@ class EntityEdge(Edge):
|
|
|
161
155
|
target_uuid=self.target_node_uuid,
|
|
162
156
|
uuid=self.uuid,
|
|
163
157
|
name=self.name,
|
|
158
|
+
group_id=self.group_id,
|
|
164
159
|
fact=self.fact,
|
|
165
160
|
fact_embedding=self.fact_embedding,
|
|
166
161
|
episodes=self.episodes,
|
|
@@ -198,6 +193,7 @@ class EntityEdge(Edge):
|
|
|
198
193
|
m.uuid AS target_node_uuid,
|
|
199
194
|
e.created_at AS created_at,
|
|
200
195
|
e.name AS name,
|
|
196
|
+
e.group_id AS group_id,
|
|
201
197
|
e.fact AS fact,
|
|
202
198
|
e.fact_embedding AS fact_embedding,
|
|
203
199
|
e.episodes AS episodes,
|
|
@@ -208,25 +204,36 @@ class EntityEdge(Edge):
|
|
|
208
204
|
uuid=uuid,
|
|
209
205
|
)
|
|
210
206
|
|
|
211
|
-
edges
|
|
212
|
-
|
|
213
|
-
for record in records:
|
|
214
|
-
edges.append(
|
|
215
|
-
EntityEdge(
|
|
216
|
-
uuid=record['uuid'],
|
|
217
|
-
source_node_uuid=record['source_node_uuid'],
|
|
218
|
-
target_node_uuid=record['target_node_uuid'],
|
|
219
|
-
fact=record['fact'],
|
|
220
|
-
name=record['name'],
|
|
221
|
-
episodes=record['episodes'],
|
|
222
|
-
fact_embedding=record['fact_embedding'],
|
|
223
|
-
created_at=record['created_at'].to_native(),
|
|
224
|
-
expired_at=parse_db_date(record['expired_at']),
|
|
225
|
-
valid_at=parse_db_date(record['valid_at']),
|
|
226
|
-
invalid_at=parse_db_date(record['invalid_at']),
|
|
227
|
-
)
|
|
228
|
-
)
|
|
207
|
+
edges = [get_entity_edge_from_record(record) for record in records]
|
|
229
208
|
|
|
230
209
|
logger.info(f'Found Edge: {uuid}')
|
|
231
210
|
|
|
232
211
|
return edges[0]
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
# Edge helpers
|
|
215
|
+
def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
|
|
216
|
+
return EpisodicEdge(
|
|
217
|
+
uuid=record['uuid'],
|
|
218
|
+
group_id=record['group_id'],
|
|
219
|
+
source_node_uuid=record['source_node_uuid'],
|
|
220
|
+
target_node_uuid=record['target_node_uuid'],
|
|
221
|
+
created_at=record['created_at'].to_native(),
|
|
222
|
+
)
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def get_entity_edge_from_record(record: Any) -> EntityEdge:
|
|
226
|
+
return EntityEdge(
|
|
227
|
+
uuid=record['uuid'],
|
|
228
|
+
source_node_uuid=record['source_node_uuid'],
|
|
229
|
+
target_node_uuid=record['target_node_uuid'],
|
|
230
|
+
fact=record['fact'],
|
|
231
|
+
name=record['name'],
|
|
232
|
+
group_id=record['group_id'],
|
|
233
|
+
episodes=record['episodes'],
|
|
234
|
+
fact_embedding=record['fact_embedding'],
|
|
235
|
+
created_at=record['created_at'].to_native(),
|
|
236
|
+
expired_at=parse_db_date(record['expired_at']),
|
|
237
|
+
valid_at=parse_db_date(record['valid_at']),
|
|
238
|
+
invalid_at=parse_db_date(record['invalid_at']),
|
|
239
|
+
)
|
graphiti_core/graphiti.py
CHANGED
|
@@ -18,7 +18,6 @@ import asyncio
|
|
|
18
18
|
import logging
|
|
19
19
|
from datetime import datetime
|
|
20
20
|
from time import time
|
|
21
|
-
from typing import Callable
|
|
22
21
|
|
|
23
22
|
from dotenv import load_dotenv
|
|
24
23
|
from neo4j import AsyncGraphDatabase
|
|
@@ -120,7 +119,7 @@ class Graphiti:
|
|
|
120
119
|
|
|
121
120
|
Parameters
|
|
122
121
|
----------
|
|
123
|
-
|
|
122
|
+
self
|
|
124
123
|
|
|
125
124
|
Returns
|
|
126
125
|
-------
|
|
@@ -151,7 +150,7 @@ class Graphiti:
|
|
|
151
150
|
|
|
152
151
|
Parameters
|
|
153
152
|
----------
|
|
154
|
-
|
|
153
|
+
self
|
|
155
154
|
|
|
156
155
|
Returns
|
|
157
156
|
-------
|
|
@@ -178,6 +177,7 @@ class Graphiti:
|
|
|
178
177
|
self,
|
|
179
178
|
reference_time: datetime,
|
|
180
179
|
last_n: int = EPISODE_WINDOW_LEN,
|
|
180
|
+
group_ids: list[str | None] | None = None,
|
|
181
181
|
) -> list[EpisodicNode]:
|
|
182
182
|
"""
|
|
183
183
|
Retrieve the last n episodic nodes from the graph.
|
|
@@ -191,6 +191,8 @@ class Graphiti:
|
|
|
191
191
|
The reference time to retrieve episodes before.
|
|
192
192
|
last_n : int, optional
|
|
193
193
|
The number of episodes to retrieve. Defaults to EPISODE_WINDOW_LEN.
|
|
194
|
+
group_ids : list[str | None], optional
|
|
195
|
+
The group ids to return data from.
|
|
194
196
|
|
|
195
197
|
Returns
|
|
196
198
|
-------
|
|
@@ -202,7 +204,7 @@ class Graphiti:
|
|
|
202
204
|
The actual retrieval is performed by the `retrieve_episodes` function
|
|
203
205
|
from the `graphiti_core.utils` module.
|
|
204
206
|
"""
|
|
205
|
-
return await retrieve_episodes(self.driver, reference_time, last_n)
|
|
207
|
+
return await retrieve_episodes(self.driver, reference_time, last_n, group_ids)
|
|
206
208
|
|
|
207
209
|
async def add_episode(
|
|
208
210
|
self,
|
|
@@ -211,8 +213,8 @@ class Graphiti:
|
|
|
211
213
|
source_description: str,
|
|
212
214
|
reference_time: datetime,
|
|
213
215
|
source: EpisodeType = EpisodeType.message,
|
|
214
|
-
|
|
215
|
-
|
|
216
|
+
group_id: str | None = None,
|
|
217
|
+
uuid: str | None = None,
|
|
216
218
|
):
|
|
217
219
|
"""
|
|
218
220
|
Process an episode and update the graph.
|
|
@@ -232,10 +234,10 @@ class Graphiti:
|
|
|
232
234
|
The reference time for the episode.
|
|
233
235
|
source : EpisodeType, optional
|
|
234
236
|
The type of the episode. Defaults to EpisodeType.message.
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
237
|
+
group_id : str | None
|
|
238
|
+
An id for the graph partition the episode is a part of.
|
|
239
|
+
uuid : str | None
|
|
240
|
+
Optional uuid of the episode.
|
|
239
241
|
|
|
240
242
|
Returns
|
|
241
243
|
-------
|
|
@@ -266,9 +268,12 @@ class Graphiti:
|
|
|
266
268
|
embedder = self.llm_client.get_embedder()
|
|
267
269
|
now = datetime.now()
|
|
268
270
|
|
|
269
|
-
previous_episodes = await self.retrieve_episodes(
|
|
271
|
+
previous_episodes = await self.retrieve_episodes(
|
|
272
|
+
reference_time, last_n=3, group_ids=[group_id]
|
|
273
|
+
)
|
|
270
274
|
episode = EpisodicNode(
|
|
271
275
|
name=name,
|
|
276
|
+
group_id=group_id,
|
|
272
277
|
labels=[],
|
|
273
278
|
source=source,
|
|
274
279
|
content=episode_body,
|
|
@@ -276,6 +281,7 @@ class Graphiti:
|
|
|
276
281
|
created_at=now,
|
|
277
282
|
valid_at=reference_time,
|
|
278
283
|
)
|
|
284
|
+
episode.uuid = uuid if uuid is not None else episode.uuid
|
|
279
285
|
|
|
280
286
|
# Extract entities as nodes
|
|
281
287
|
|
|
@@ -299,7 +305,9 @@ class Graphiti:
|
|
|
299
305
|
|
|
300
306
|
(mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
|
|
301
307
|
resolve_extracted_nodes(self.llm_client, extracted_nodes, existing_nodes_lists),
|
|
302
|
-
extract_edges(
|
|
308
|
+
extract_edges(
|
|
309
|
+
self.llm_client, episode, extracted_nodes, previous_episodes, group_id
|
|
310
|
+
),
|
|
303
311
|
)
|
|
304
312
|
logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
|
|
305
313
|
nodes.extend(mentioned_nodes)
|
|
@@ -388,11 +396,7 @@ class Graphiti:
|
|
|
388
396
|
|
|
389
397
|
logger.info(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')
|
|
390
398
|
|
|
391
|
-
episodic_edges: list[EpisodicEdge] = build_episodic_edges(
|
|
392
|
-
mentioned_nodes,
|
|
393
|
-
episode,
|
|
394
|
-
now,
|
|
395
|
-
)
|
|
399
|
+
episodic_edges: list[EpisodicEdge] = build_episodic_edges(mentioned_nodes, episode, now)
|
|
396
400
|
|
|
397
401
|
logger.info(f'Built episodic edges: {episodic_edges}')
|
|
398
402
|
|
|
@@ -405,18 +409,10 @@ class Graphiti:
|
|
|
405
409
|
end = time()
|
|
406
410
|
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
|
|
407
411
|
|
|
408
|
-
if success_callback:
|
|
409
|
-
await success_callback(episode)
|
|
410
412
|
except Exception as e:
|
|
411
|
-
|
|
412
|
-
await error_callback(episode, e)
|
|
413
|
-
else:
|
|
414
|
-
raise e
|
|
413
|
+
raise e
|
|
415
414
|
|
|
416
|
-
async def add_episode_bulk(
|
|
417
|
-
self,
|
|
418
|
-
bulk_episodes: list[RawEpisode],
|
|
419
|
-
):
|
|
415
|
+
async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str | None):
|
|
420
416
|
"""
|
|
421
417
|
Process multiple episodes in bulk and update the graph.
|
|
422
418
|
|
|
@@ -427,6 +423,8 @@ class Graphiti:
|
|
|
427
423
|
----------
|
|
428
424
|
bulk_episodes : list[RawEpisode]
|
|
429
425
|
A list of RawEpisode objects to be processed and added to the graph.
|
|
426
|
+
group_id : str | None
|
|
427
|
+
An id for the graph partition the episode is a part of.
|
|
430
428
|
|
|
431
429
|
Returns
|
|
432
430
|
-------
|
|
@@ -463,6 +461,7 @@ class Graphiti:
|
|
|
463
461
|
source=episode.source,
|
|
464
462
|
content=episode.content,
|
|
465
463
|
source_description=episode.source_description,
|
|
464
|
+
group_id=group_id,
|
|
466
465
|
created_at=now,
|
|
467
466
|
valid_at=episode.reference_time,
|
|
468
467
|
)
|
|
@@ -527,7 +526,13 @@ class Graphiti:
|
|
|
527
526
|
except Exception as e:
|
|
528
527
|
raise e
|
|
529
528
|
|
|
530
|
-
async def search(
|
|
529
|
+
async def search(
|
|
530
|
+
self,
|
|
531
|
+
query: str,
|
|
532
|
+
center_node_uuid: str | None = None,
|
|
533
|
+
group_ids: list[str | None] | None = None,
|
|
534
|
+
num_results=10,
|
|
535
|
+
):
|
|
531
536
|
"""
|
|
532
537
|
Perform a hybrid search on the knowledge graph.
|
|
533
538
|
|
|
@@ -540,6 +545,8 @@ class Graphiti:
|
|
|
540
545
|
The search query string.
|
|
541
546
|
center_node_uuid: str, optional
|
|
542
547
|
Facts will be reranked based on proximity to this node
|
|
548
|
+
group_ids : list[str | None] | None, optional
|
|
549
|
+
The graph partitions to return data from.
|
|
543
550
|
num_results : int, optional
|
|
544
551
|
The maximum number of results to return. Defaults to 10.
|
|
545
552
|
|
|
@@ -562,6 +569,7 @@ class Graphiti:
|
|
|
562
569
|
num_episodes=0,
|
|
563
570
|
num_edges=num_results,
|
|
564
571
|
num_nodes=0,
|
|
572
|
+
group_ids=group_ids,
|
|
565
573
|
search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
|
|
566
574
|
reranker=reranker,
|
|
567
575
|
)
|
|
@@ -590,7 +598,10 @@ class Graphiti:
|
|
|
590
598
|
)
|
|
591
599
|
|
|
592
600
|
async def get_nodes_by_query(
|
|
593
|
-
self,
|
|
601
|
+
self,
|
|
602
|
+
query: str,
|
|
603
|
+
group_ids: list[str | None] | None = None,
|
|
604
|
+
limit: int = RELEVANT_SCHEMA_LIMIT,
|
|
594
605
|
) -> list[EntityNode]:
|
|
595
606
|
"""
|
|
596
607
|
Retrieve nodes from the graph database based on a text query.
|
|
@@ -602,6 +613,8 @@ class Graphiti:
|
|
|
602
613
|
----------
|
|
603
614
|
query : str
|
|
604
615
|
The text query to search for in the graph.
|
|
616
|
+
group_ids : list[str | None] | None, optional
|
|
617
|
+
The graph partitions to return data from.
|
|
605
618
|
limit : int | None, optional
|
|
606
619
|
The maximum number of results to return per search method.
|
|
607
620
|
If None, a default limit will be applied.
|
|
@@ -626,5 +639,7 @@ class Graphiti:
|
|
|
626
639
|
"""
|
|
627
640
|
embedder = self.llm_client.get_embedder()
|
|
628
641
|
query_embedding = await generate_embedding(embedder, query)
|
|
629
|
-
relevant_nodes = await hybrid_node_search(
|
|
642
|
+
relevant_nodes = await hybrid_node_search(
|
|
643
|
+
[query], [query_embedding], self.driver, group_ids, limit
|
|
644
|
+
)
|
|
630
645
|
return relevant_nodes
|
graphiti_core/nodes.py
CHANGED
|
@@ -19,10 +19,10 @@ from abc import ABC, abstractmethod
|
|
|
19
19
|
from datetime import datetime
|
|
20
20
|
from enum import Enum
|
|
21
21
|
from time import time
|
|
22
|
+
from typing import Any
|
|
22
23
|
from uuid import uuid4
|
|
23
24
|
|
|
24
25
|
from neo4j import AsyncDriver
|
|
25
|
-
from openai import OpenAI
|
|
26
26
|
from pydantic import BaseModel, Field
|
|
27
27
|
|
|
28
28
|
from graphiti_core.llm_client.config import EMBEDDING_DIM
|
|
@@ -69,6 +69,7 @@ class EpisodeType(Enum):
|
|
|
69
69
|
class Node(BaseModel, ABC):
|
|
70
70
|
uuid: str = Field(default_factory=lambda: uuid4().hex)
|
|
71
71
|
name: str = Field(description='name of the node')
|
|
72
|
+
group_id: str | None = Field(description='partition of the graph')
|
|
72
73
|
labels: list[str] = Field(default_factory=list)
|
|
73
74
|
created_at: datetime = Field(default_factory=lambda: datetime.now())
|
|
74
75
|
|
|
@@ -106,11 +107,12 @@ class EpisodicNode(Node):
|
|
|
106
107
|
result = await driver.execute_query(
|
|
107
108
|
"""
|
|
108
109
|
MERGE (n:Episodic {uuid: $uuid})
|
|
109
|
-
SET n = {uuid: $uuid, name: $name, source_description: $source_description, source: $source, content: $content,
|
|
110
|
+
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
|
|
110
111
|
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
|
|
111
112
|
RETURN n.uuid AS uuid""",
|
|
112
113
|
uuid=self.uuid,
|
|
113
114
|
name=self.name,
|
|
115
|
+
group_id=self.group_id,
|
|
114
116
|
source_description=self.source_description,
|
|
115
117
|
content=self.content,
|
|
116
118
|
entity_edges=self.entity_edges,
|
|
@@ -141,29 +143,19 @@ class EpisodicNode(Node):
|
|
|
141
143
|
records, _, _ = await driver.execute_query(
|
|
142
144
|
"""
|
|
143
145
|
MATCH (e:Episodic {uuid: $uuid})
|
|
144
|
-
RETURN e.content
|
|
145
|
-
e.created_at
|
|
146
|
-
e.valid_at
|
|
147
|
-
e.uuid
|
|
148
|
-
e.name
|
|
149
|
-
e.
|
|
150
|
-
e.
|
|
146
|
+
RETURN e.content AS content,
|
|
147
|
+
e.created_at AS created_at,
|
|
148
|
+
e.valid_at AS valid_at,
|
|
149
|
+
e.uuid AS uuid,
|
|
150
|
+
e.name AS name,
|
|
151
|
+
e.group_id AS group_id
|
|
152
|
+
e.source_description AS source_description,
|
|
153
|
+
e.source AS source
|
|
151
154
|
""",
|
|
152
155
|
uuid=uuid,
|
|
153
156
|
)
|
|
154
157
|
|
|
155
|
-
episodes = [
|
|
156
|
-
EpisodicNode(
|
|
157
|
-
content=record['content'],
|
|
158
|
-
created_at=record['created_at'].to_native().timestamp(),
|
|
159
|
-
valid_at=(record['valid_at'].to_native()),
|
|
160
|
-
uuid=record['uuid'],
|
|
161
|
-
source=EpisodeType.from_str(record['source']),
|
|
162
|
-
name=record['name'],
|
|
163
|
-
source_description=record['source_description'],
|
|
164
|
-
)
|
|
165
|
-
for record in records
|
|
166
|
-
]
|
|
158
|
+
episodes = [get_episodic_node_from_record(record) for record in records]
|
|
167
159
|
|
|
168
160
|
logger.info(f'Found Node: {uuid}')
|
|
169
161
|
|
|
@@ -174,10 +166,6 @@ class EntityNode(Node):
|
|
|
174
166
|
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
|
|
175
167
|
summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
|
|
176
168
|
|
|
177
|
-
async def update_summary(self, driver: AsyncDriver): ...
|
|
178
|
-
|
|
179
|
-
async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ...
|
|
180
|
-
|
|
181
169
|
async def generate_name_embedding(self, embedder, model='text-embedding-3-small'):
|
|
182
170
|
start = time()
|
|
183
171
|
text = self.name.replace('\n', ' ')
|
|
@@ -192,10 +180,11 @@ class EntityNode(Node):
|
|
|
192
180
|
result = await driver.execute_query(
|
|
193
181
|
"""
|
|
194
182
|
MERGE (n:Entity {uuid: $uuid})
|
|
195
|
-
SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, summary: $summary, created_at: $created_at}
|
|
183
|
+
SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at}
|
|
196
184
|
RETURN n.uuid AS uuid""",
|
|
197
185
|
uuid=self.uuid,
|
|
198
186
|
name=self.name,
|
|
187
|
+
group_id=self.group_id,
|
|
199
188
|
summary=self.summary,
|
|
200
189
|
name_embedding=self.name_embedding,
|
|
201
190
|
created_at=self.created_at,
|
|
@@ -227,25 +216,14 @@ class EntityNode(Node):
|
|
|
227
216
|
n.uuid As uuid,
|
|
228
217
|
n.name AS name,
|
|
229
218
|
n.name_embedding AS name_embedding,
|
|
219
|
+
n.group_id AS group_id
|
|
230
220
|
n.created_at AS created_at,
|
|
231
221
|
n.summary AS summary
|
|
232
222
|
""",
|
|
233
223
|
uuid=uuid,
|
|
234
224
|
)
|
|
235
225
|
|
|
236
|
-
nodes
|
|
237
|
-
|
|
238
|
-
for record in records:
|
|
239
|
-
nodes.append(
|
|
240
|
-
EntityNode(
|
|
241
|
-
uuid=record['uuid'],
|
|
242
|
-
name=record['name'],
|
|
243
|
-
name_embedding=record['name_embedding'],
|
|
244
|
-
labels=['Entity'],
|
|
245
|
-
created_at=record['created_at'].to_native(),
|
|
246
|
-
summary=record['summary'],
|
|
247
|
-
)
|
|
248
|
-
)
|
|
226
|
+
nodes = [get_entity_node_from_record(record) for record in records]
|
|
249
227
|
|
|
250
228
|
logger.info(f'Found Node: {uuid}')
|
|
251
229
|
|
|
@@ -253,3 +231,26 @@ class EntityNode(Node):
|
|
|
253
231
|
|
|
254
232
|
|
|
255
233
|
# Node helpers
|
|
234
|
+
def get_episodic_node_from_record(record: Any) -> EpisodicNode:
|
|
235
|
+
return EpisodicNode(
|
|
236
|
+
content=record['content'],
|
|
237
|
+
created_at=record['created_at'].to_native().timestamp(),
|
|
238
|
+
valid_at=(record['valid_at'].to_native()),
|
|
239
|
+
uuid=record['uuid'],
|
|
240
|
+
group_id=record['group_id'],
|
|
241
|
+
source=EpisodeType.from_str(record['source']),
|
|
242
|
+
name=record['name'],
|
|
243
|
+
source_description=record['source_description'],
|
|
244
|
+
)
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def get_entity_node_from_record(record: Any) -> EntityNode:
|
|
248
|
+
return EntityNode(
|
|
249
|
+
uuid=record['uuid'],
|
|
250
|
+
name=record['name'],
|
|
251
|
+
group_id=record['group_id'],
|
|
252
|
+
name_embedding=record['name_embedding'],
|
|
253
|
+
labels=['Entity'],
|
|
254
|
+
created_at=record['created_at'].to_native(),
|
|
255
|
+
summary=record['summary'],
|
|
256
|
+
)
|
graphiti_core/search/search.py
CHANGED
|
@@ -52,6 +52,7 @@ class SearchConfig(BaseModel):
|
|
|
52
52
|
num_edges: int = Field(default=10)
|
|
53
53
|
num_nodes: int = Field(default=10)
|
|
54
54
|
num_episodes: int = EPISODE_WINDOW_LEN
|
|
55
|
+
group_ids: list[str | None] | None
|
|
55
56
|
search_methods: list[SearchMethod]
|
|
56
57
|
reranker: Reranker | None
|
|
57
58
|
|
|
@@ -83,7 +84,9 @@ async def hybrid_search(
|
|
|
83
84
|
nodes.extend(await get_mentioned_nodes(driver, episodes))
|
|
84
85
|
|
|
85
86
|
if SearchMethod.bm25 in config.search_methods:
|
|
86
|
-
text_search = await edge_fulltext_search(
|
|
87
|
+
text_search = await edge_fulltext_search(
|
|
88
|
+
driver, query, None, None, config.group_ids, 2 * config.num_edges
|
|
89
|
+
)
|
|
87
90
|
search_results.append(text_search)
|
|
88
91
|
|
|
89
92
|
if SearchMethod.cosine_similarity in config.search_methods:
|
|
@@ -95,7 +98,7 @@ async def hybrid_search(
|
|
|
95
98
|
)
|
|
96
99
|
|
|
97
100
|
similarity_search = await edge_similarity_search(
|
|
98
|
-
driver, search_vector, None, None, 2 * config.num_edges
|
|
101
|
+
driver, search_vector, None, None, config.group_ids, 2 * config.num_edges
|
|
99
102
|
)
|
|
100
103
|
search_results.append(similarity_search)
|
|
101
104
|
|
|
@@ -3,13 +3,11 @@ import logging
|
|
|
3
3
|
import re
|
|
4
4
|
from collections import defaultdict
|
|
5
5
|
from time import time
|
|
6
|
-
from typing import Any
|
|
7
6
|
|
|
8
7
|
from neo4j import AsyncDriver, Query
|
|
9
8
|
|
|
10
|
-
from graphiti_core.edges import EntityEdge
|
|
11
|
-
from graphiti_core.
|
|
12
|
-
from graphiti_core.nodes import EntityNode, EpisodicNode
|
|
9
|
+
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
|
10
|
+
from graphiti_core.nodes import EntityNode, EpisodicNode, get_entity_node_from_record
|
|
13
11
|
|
|
14
12
|
logger = logging.getLogger(__name__)
|
|
15
13
|
|
|
@@ -23,6 +21,7 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode])
|
|
|
23
21
|
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
|
|
24
22
|
RETURN DISTINCT
|
|
25
23
|
n.uuid As uuid,
|
|
24
|
+
n.group_id AS group_id,
|
|
26
25
|
n.name AS name,
|
|
27
26
|
n.name_embedding AS name_embedding
|
|
28
27
|
n.created_at AS created_at,
|
|
@@ -31,86 +30,29 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode])
|
|
|
31
30
|
uuids=episode_uuids,
|
|
32
31
|
)
|
|
33
32
|
|
|
34
|
-
nodes
|
|
35
|
-
|
|
36
|
-
for record in records:
|
|
37
|
-
nodes.append(
|
|
38
|
-
EntityNode(
|
|
39
|
-
uuid=record['uuid'],
|
|
40
|
-
name=record['name'],
|
|
41
|
-
name_embedding=record['name_embedding'],
|
|
42
|
-
labels=['Entity'],
|
|
43
|
-
created_at=record['created_at'].to_native(),
|
|
44
|
-
summary=record['summary'],
|
|
45
|
-
)
|
|
46
|
-
)
|
|
33
|
+
nodes = [get_entity_node_from_record(record) for record in records]
|
|
47
34
|
|
|
48
35
|
return nodes
|
|
49
36
|
|
|
50
37
|
|
|
51
|
-
async def bfs(node_ids: list[str], driver: AsyncDriver):
|
|
52
|
-
records, _, _ = await driver.execute_query(
|
|
53
|
-
"""
|
|
54
|
-
MATCH (n WHERE n.uuid in $node_ids)-[r]->(m)
|
|
55
|
-
RETURN DISTINCT
|
|
56
|
-
n.uuid AS source_node_uuid,
|
|
57
|
-
n.name AS source_name,
|
|
58
|
-
n.summary AS source_summary,
|
|
59
|
-
m.uuid AS target_node_uuid,
|
|
60
|
-
m.name AS target_name,
|
|
61
|
-
m.summary AS target_summary,
|
|
62
|
-
r.uuid AS uuid,
|
|
63
|
-
r.created_at AS created_at,
|
|
64
|
-
r.name AS name,
|
|
65
|
-
r.fact AS fact,
|
|
66
|
-
r.fact_embedding AS fact_embedding,
|
|
67
|
-
r.episodes AS episodes,
|
|
68
|
-
r.expired_at AS expired_at,
|
|
69
|
-
r.valid_at AS valid_at,
|
|
70
|
-
r.invalid_at AS invalid_at
|
|
71
|
-
|
|
72
|
-
""",
|
|
73
|
-
node_ids=node_ids,
|
|
74
|
-
)
|
|
75
|
-
|
|
76
|
-
context: dict[str, Any] = {}
|
|
77
|
-
|
|
78
|
-
for record in records:
|
|
79
|
-
n_uuid = record['source_node_uuid']
|
|
80
|
-
if n_uuid in context:
|
|
81
|
-
context[n_uuid]['facts'].append(record['fact'])
|
|
82
|
-
else:
|
|
83
|
-
context[n_uuid] = {
|
|
84
|
-
'name': record['source_name'],
|
|
85
|
-
'summary': record['source_summary'],
|
|
86
|
-
'facts': [record['fact']],
|
|
87
|
-
}
|
|
88
|
-
|
|
89
|
-
m_uuid = record['target_node_uuid']
|
|
90
|
-
if m_uuid not in context:
|
|
91
|
-
context[m_uuid] = {
|
|
92
|
-
'name': record['target_name'],
|
|
93
|
-
'summary': record['target_summary'],
|
|
94
|
-
'facts': [],
|
|
95
|
-
}
|
|
96
|
-
logger.info(f'bfs search returned context: {context}')
|
|
97
|
-
return context
|
|
98
|
-
|
|
99
|
-
|
|
100
38
|
async def edge_similarity_search(
|
|
101
39
|
driver: AsyncDriver,
|
|
102
40
|
search_vector: list[float],
|
|
103
41
|
source_node_uuid: str | None,
|
|
104
42
|
target_node_uuid: str | None,
|
|
43
|
+
group_ids: list[str | None] | None = None,
|
|
105
44
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
|
106
45
|
) -> list[EntityEdge]:
|
|
46
|
+
group_ids = group_ids if group_ids is not None else [None]
|
|
107
47
|
# vector similarity search over embedded facts
|
|
108
48
|
query = Query("""
|
|
109
49
|
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
|
110
50
|
YIELD relationship AS rel, score
|
|
111
51
|
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
|
52
|
+
WHERE r.group_id IN $group_ids
|
|
112
53
|
RETURN
|
|
113
54
|
r.uuid AS uuid,
|
|
55
|
+
r.group_id AS group_id,
|
|
114
56
|
n.uuid AS source_node_uuid,
|
|
115
57
|
m.uuid AS target_node_uuid,
|
|
116
58
|
r.created_at AS created_at,
|
|
@@ -129,8 +71,10 @@ async def edge_similarity_search(
|
|
|
129
71
|
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
|
130
72
|
YIELD relationship AS rel, score
|
|
131
73
|
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
|
|
74
|
+
WHERE r.group_id IN $group_ids
|
|
132
75
|
RETURN
|
|
133
76
|
r.uuid AS uuid,
|
|
77
|
+
r.group_id AS group_id,
|
|
134
78
|
n.uuid AS source_node_uuid,
|
|
135
79
|
m.uuid AS target_node_uuid,
|
|
136
80
|
r.created_at AS created_at,
|
|
@@ -148,8 +92,10 @@ async def edge_similarity_search(
|
|
|
148
92
|
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
|
149
93
|
YIELD relationship AS rel, score
|
|
150
94
|
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
|
95
|
+
WHERE r.group_id IN $group_ids
|
|
151
96
|
RETURN
|
|
152
97
|
r.uuid AS uuid,
|
|
98
|
+
r.group_id AS group_id,
|
|
153
99
|
n.uuid AS source_node_uuid,
|
|
154
100
|
m.uuid AS target_node_uuid,
|
|
155
101
|
r.created_at AS created_at,
|
|
@@ -167,8 +113,10 @@ async def edge_similarity_search(
|
|
|
167
113
|
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
|
168
114
|
YIELD relationship AS rel, score
|
|
169
115
|
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
|
|
116
|
+
WHERE r.group_id IN $group_ids
|
|
170
117
|
RETURN
|
|
171
118
|
r.uuid AS uuid,
|
|
119
|
+
r.group_id AS group_id,
|
|
172
120
|
n.uuid AS source_node_uuid,
|
|
173
121
|
m.uuid AS target_node_uuid,
|
|
174
122
|
r.created_at AS created_at,
|
|
@@ -187,41 +135,32 @@ async def edge_similarity_search(
|
|
|
187
135
|
search_vector=search_vector,
|
|
188
136
|
source_uuid=source_node_uuid,
|
|
189
137
|
target_uuid=target_node_uuid,
|
|
138
|
+
group_ids=group_ids,
|
|
190
139
|
limit=limit,
|
|
191
140
|
)
|
|
192
141
|
|
|
193
|
-
edges
|
|
194
|
-
|
|
195
|
-
for record in records:
|
|
196
|
-
edge = EntityEdge(
|
|
197
|
-
uuid=record['uuid'],
|
|
198
|
-
source_node_uuid=record['source_node_uuid'],
|
|
199
|
-
target_node_uuid=record['target_node_uuid'],
|
|
200
|
-
fact=record['fact'],
|
|
201
|
-
name=record['name'],
|
|
202
|
-
episodes=record['episodes'],
|
|
203
|
-
fact_embedding=record['fact_embedding'],
|
|
204
|
-
created_at=record['created_at'].to_native(),
|
|
205
|
-
expired_at=parse_db_date(record['expired_at']),
|
|
206
|
-
valid_at=parse_db_date(record['valid_at']),
|
|
207
|
-
invalid_at=parse_db_date(record['invalid_at']),
|
|
208
|
-
)
|
|
209
|
-
|
|
210
|
-
edges.append(edge)
|
|
142
|
+
edges = [get_entity_edge_from_record(record) for record in records]
|
|
211
143
|
|
|
212
144
|
return edges
|
|
213
145
|
|
|
214
146
|
|
|
215
147
|
async def entity_similarity_search(
|
|
216
|
-
search_vector: list[float],
|
|
148
|
+
search_vector: list[float],
|
|
149
|
+
driver: AsyncDriver,
|
|
150
|
+
group_ids: list[str | None] | None = None,
|
|
151
|
+
limit=RELEVANT_SCHEMA_LIMIT,
|
|
217
152
|
) -> list[EntityNode]:
|
|
153
|
+
group_ids = group_ids if group_ids is not None else [None]
|
|
154
|
+
|
|
218
155
|
# vector similarity search over entity names
|
|
219
156
|
records, _, _ = await driver.execute_query(
|
|
220
157
|
"""
|
|
221
158
|
CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
|
|
222
159
|
YIELD node AS n, score
|
|
160
|
+
MATCH (n WHERE n.group_id IN $group_ids)
|
|
223
161
|
RETURN
|
|
224
|
-
n.uuid As uuid,
|
|
162
|
+
n.uuid As uuid,
|
|
163
|
+
n.group_id AS group_id,
|
|
225
164
|
n.name AS name,
|
|
226
165
|
n.name_embedding AS name_embedding,
|
|
227
166
|
n.created_at AS created_at,
|
|
@@ -229,58 +168,44 @@ async def entity_similarity_search(
|
|
|
229
168
|
ORDER BY score DESC
|
|
230
169
|
""",
|
|
231
170
|
search_vector=search_vector,
|
|
171
|
+
group_ids=group_ids,
|
|
232
172
|
limit=limit,
|
|
233
173
|
)
|
|
234
|
-
nodes
|
|
235
|
-
|
|
236
|
-
for record in records:
|
|
237
|
-
nodes.append(
|
|
238
|
-
EntityNode(
|
|
239
|
-
uuid=record['uuid'],
|
|
240
|
-
name=record['name'],
|
|
241
|
-
name_embedding=record['name_embedding'],
|
|
242
|
-
labels=['Entity'],
|
|
243
|
-
created_at=record['created_at'].to_native(),
|
|
244
|
-
summary=record['summary'],
|
|
245
|
-
)
|
|
246
|
-
)
|
|
174
|
+
nodes = [get_entity_node_from_record(record) for record in records]
|
|
247
175
|
|
|
248
176
|
return nodes
|
|
249
177
|
|
|
250
178
|
|
|
251
179
|
async def entity_fulltext_search(
|
|
252
|
-
query: str,
|
|
180
|
+
query: str,
|
|
181
|
+
driver: AsyncDriver,
|
|
182
|
+
group_ids: list[str | None] | None = None,
|
|
183
|
+
limit=RELEVANT_SCHEMA_LIMIT,
|
|
253
184
|
) -> list[EntityNode]:
|
|
185
|
+
group_ids = group_ids if group_ids is not None else [None]
|
|
186
|
+
|
|
254
187
|
# BM25 search to get top nodes
|
|
255
188
|
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
|
|
256
189
|
records, _, _ = await driver.execute_query(
|
|
257
190
|
"""
|
|
258
|
-
CALL db.index.fulltext.queryNodes("name_and_summary", $query)
|
|
191
|
+
CALL db.index.fulltext.queryNodes("name_and_summary", $query)
|
|
192
|
+
YIELD node AS n, score
|
|
193
|
+
MATCH (n WHERE n.group_id in $group_ids)
|
|
259
194
|
RETURN
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
195
|
+
n.uuid AS uuid,
|
|
196
|
+
n.group_id AS group_id,
|
|
197
|
+
n.name AS name,
|
|
198
|
+
n.name_embedding AS name_embedding,
|
|
199
|
+
n.created_at AS created_at,
|
|
200
|
+
n.summary AS summary
|
|
265
201
|
ORDER BY score DESC
|
|
266
202
|
LIMIT $limit
|
|
267
203
|
""",
|
|
268
204
|
query=fuzzy_query,
|
|
205
|
+
group_ids=group_ids,
|
|
269
206
|
limit=limit,
|
|
270
207
|
)
|
|
271
|
-
nodes
|
|
272
|
-
|
|
273
|
-
for record in records:
|
|
274
|
-
nodes.append(
|
|
275
|
-
EntityNode(
|
|
276
|
-
uuid=record['uuid'],
|
|
277
|
-
name=record['name'],
|
|
278
|
-
name_embedding=record['name_embedding'],
|
|
279
|
-
labels=['Entity'],
|
|
280
|
-
created_at=record['created_at'].to_native(),
|
|
281
|
-
summary=record['summary'],
|
|
282
|
-
)
|
|
283
|
-
)
|
|
208
|
+
nodes = [get_entity_node_from_record(record) for record in records]
|
|
284
209
|
|
|
285
210
|
return nodes
|
|
286
211
|
|
|
@@ -290,15 +215,20 @@ async def edge_fulltext_search(
|
|
|
290
215
|
query: str,
|
|
291
216
|
source_node_uuid: str | None,
|
|
292
217
|
target_node_uuid: str | None,
|
|
218
|
+
group_ids: list[str | None] | None = None,
|
|
293
219
|
limit=RELEVANT_SCHEMA_LIMIT,
|
|
294
220
|
) -> list[EntityEdge]:
|
|
221
|
+
group_ids = group_ids if group_ids is not None else [None]
|
|
222
|
+
|
|
295
223
|
# fulltext search over facts
|
|
296
224
|
cypher_query = Query("""
|
|
297
225
|
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
|
|
298
226
|
YIELD relationship AS rel, score
|
|
299
227
|
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
|
228
|
+
WHERE r.group_id IN $group_ids
|
|
300
229
|
RETURN
|
|
301
230
|
r.uuid AS uuid,
|
|
231
|
+
r.group_id AS group_id,
|
|
302
232
|
n.uuid AS source_node_uuid,
|
|
303
233
|
m.uuid AS target_node_uuid,
|
|
304
234
|
r.created_at AS created_at,
|
|
@@ -317,8 +247,10 @@ async def edge_fulltext_search(
|
|
|
317
247
|
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
|
|
318
248
|
YIELD relationship AS rel, score
|
|
319
249
|
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
|
|
250
|
+
WHERE r.group_id IN $group_ids
|
|
320
251
|
RETURN
|
|
321
252
|
r.uuid AS uuid,
|
|
253
|
+
r.group_id AS group_id,
|
|
322
254
|
n.uuid AS source_node_uuid,
|
|
323
255
|
m.uuid AS target_node_uuid,
|
|
324
256
|
r.created_at AS created_at,
|
|
@@ -335,9 +267,11 @@ async def edge_fulltext_search(
|
|
|
335
267
|
cypher_query = Query("""
|
|
336
268
|
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
|
|
337
269
|
YIELD relationship AS rel, score
|
|
338
|
-
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
|
270
|
+
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
|
271
|
+
WHERE r.group_id IN $group_ids
|
|
339
272
|
RETURN
|
|
340
273
|
r.uuid AS uuid,
|
|
274
|
+
r.group_id AS group_id,
|
|
341
275
|
n.uuid AS source_node_uuid,
|
|
342
276
|
m.uuid AS target_node_uuid,
|
|
343
277
|
r.created_at AS created_at,
|
|
@@ -354,9 +288,11 @@ async def edge_fulltext_search(
|
|
|
354
288
|
cypher_query = Query("""
|
|
355
289
|
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
|
|
356
290
|
YIELD relationship AS rel, score
|
|
357
|
-
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
|
|
291
|
+
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
|
|
292
|
+
WHERE r.group_id IN $group_ids
|
|
358
293
|
RETURN
|
|
359
294
|
r.uuid AS uuid,
|
|
295
|
+
r.group_id AS group_id,
|
|
360
296
|
n.uuid AS source_node_uuid,
|
|
361
297
|
m.uuid AS target_node_uuid,
|
|
362
298
|
r.created_at AS created_at,
|
|
@@ -377,27 +313,11 @@ async def edge_fulltext_search(
|
|
|
377
313
|
query=fuzzy_query,
|
|
378
314
|
source_uuid=source_node_uuid,
|
|
379
315
|
target_uuid=target_node_uuid,
|
|
316
|
+
group_ids=group_ids,
|
|
380
317
|
limit=limit,
|
|
381
318
|
)
|
|
382
319
|
|
|
383
|
-
edges
|
|
384
|
-
|
|
385
|
-
for record in records:
|
|
386
|
-
edge = EntityEdge(
|
|
387
|
-
uuid=record['uuid'],
|
|
388
|
-
source_node_uuid=record['source_node_uuid'],
|
|
389
|
-
target_node_uuid=record['target_node_uuid'],
|
|
390
|
-
fact=record['fact'],
|
|
391
|
-
name=record['name'],
|
|
392
|
-
episodes=record['episodes'],
|
|
393
|
-
fact_embedding=record['fact_embedding'],
|
|
394
|
-
created_at=record['created_at'].to_native(),
|
|
395
|
-
expired_at=parse_db_date(record['expired_at']),
|
|
396
|
-
valid_at=parse_db_date(record['valid_at']),
|
|
397
|
-
invalid_at=parse_db_date(record['invalid_at']),
|
|
398
|
-
)
|
|
399
|
-
|
|
400
|
-
edges.append(edge)
|
|
320
|
+
edges = [get_entity_edge_from_record(record) for record in records]
|
|
401
321
|
|
|
402
322
|
return edges
|
|
403
323
|
|
|
@@ -406,6 +326,7 @@ async def hybrid_node_search(
|
|
|
406
326
|
queries: list[str],
|
|
407
327
|
embeddings: list[list[float]],
|
|
408
328
|
driver: AsyncDriver,
|
|
329
|
+
group_ids: list[str | None] | None = None,
|
|
409
330
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
|
410
331
|
) -> list[EntityNode]:
|
|
411
332
|
"""
|
|
@@ -422,6 +343,8 @@ async def hybrid_node_search(
|
|
|
422
343
|
A list of embedding vectors corresponding to the queries. If empty only fulltext search is performed.
|
|
423
344
|
driver : AsyncDriver
|
|
424
345
|
The Neo4j driver instance for database operations.
|
|
346
|
+
group_ids : list[str] | None, optional
|
|
347
|
+
The list of group ids to retrieve nodes from.
|
|
425
348
|
limit : int | None, optional
|
|
426
349
|
The maximum number of results to return per search method. If None, a default limit will be applied.
|
|
427
350
|
|
|
@@ -448,8 +371,8 @@ async def hybrid_node_search(
|
|
|
448
371
|
|
|
449
372
|
results: list[list[EntityNode]] = list(
|
|
450
373
|
await asyncio.gather(
|
|
451
|
-
*[entity_fulltext_search(q, driver, 2 * limit) for q in queries],
|
|
452
|
-
*[entity_similarity_search(e, driver, 2 * limit) for e in embeddings],
|
|
374
|
+
*[entity_fulltext_search(q, driver, group_ids, 2 * limit) for q in queries],
|
|
375
|
+
*[entity_similarity_search(e, driver, group_ids, 2 * limit) for e in embeddings],
|
|
453
376
|
)
|
|
454
377
|
)
|
|
455
378
|
|
|
@@ -500,6 +423,7 @@ async def get_relevant_nodes(
|
|
|
500
423
|
[node.name for node in nodes],
|
|
501
424
|
[node.name_embedding for node in nodes if node.name_embedding is not None],
|
|
502
425
|
driver,
|
|
426
|
+
[node.group_id for node in nodes],
|
|
503
427
|
)
|
|
504
428
|
return relevant_nodes
|
|
505
429
|
|
|
@@ -518,13 +442,20 @@ async def get_relevant_edges(
|
|
|
518
442
|
results = await asyncio.gather(
|
|
519
443
|
*[
|
|
520
444
|
edge_similarity_search(
|
|
521
|
-
driver,
|
|
445
|
+
driver,
|
|
446
|
+
edge.fact_embedding,
|
|
447
|
+
source_node_uuid,
|
|
448
|
+
target_node_uuid,
|
|
449
|
+
[edge.group_id],
|
|
450
|
+
limit,
|
|
522
451
|
)
|
|
523
452
|
for edge in edges
|
|
524
453
|
if edge.fact_embedding is not None
|
|
525
454
|
],
|
|
526
455
|
*[
|
|
527
|
-
edge_fulltext_search(
|
|
456
|
+
edge_fulltext_search(
|
|
457
|
+
driver, edge.fact, source_node_uuid, target_node_uuid, [edge.group_id], limit
|
|
458
|
+
)
|
|
528
459
|
for edge in edges
|
|
529
460
|
],
|
|
530
461
|
)
|
|
@@ -17,6 +17,7 @@ limitations under the License.
|
|
|
17
17
|
import asyncio
|
|
18
18
|
import logging
|
|
19
19
|
import typing
|
|
20
|
+
from collections import defaultdict
|
|
20
21
|
from datetime import datetime
|
|
21
22
|
from math import ceil
|
|
22
23
|
|
|
@@ -42,7 +43,6 @@ from graphiti_core.utils.maintenance.node_operations import (
|
|
|
42
43
|
extract_nodes,
|
|
43
44
|
)
|
|
44
45
|
from graphiti_core.utils.maintenance.temporal_operations import extract_edge_dates
|
|
45
|
-
from graphiti_core.utils.utils import chunk_edges_by_nodes
|
|
46
46
|
|
|
47
47
|
logger = logging.getLogger(__name__)
|
|
48
48
|
|
|
@@ -62,7 +62,9 @@ async def retrieve_previous_episodes_bulk(
|
|
|
62
62
|
) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
|
|
63
63
|
previous_episodes_list = await asyncio.gather(
|
|
64
64
|
*[
|
|
65
|
-
retrieve_episodes(
|
|
65
|
+
retrieve_episodes(
|
|
66
|
+
driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN, group_ids=[episode.group_id]
|
|
67
|
+
)
|
|
66
68
|
for episode in episodes
|
|
67
69
|
]
|
|
68
70
|
)
|
|
@@ -90,7 +92,13 @@ async def extract_nodes_and_edges_bulk(
|
|
|
90
92
|
|
|
91
93
|
extracted_edges_bulk = await asyncio.gather(
|
|
92
94
|
*[
|
|
93
|
-
extract_edges(
|
|
95
|
+
extract_edges(
|
|
96
|
+
llm_client,
|
|
97
|
+
episode,
|
|
98
|
+
extracted_nodes_bulk[i],
|
|
99
|
+
previous_episodes_list[i],
|
|
100
|
+
episode.group_id,
|
|
101
|
+
)
|
|
94
102
|
for i, episode in enumerate(episodes)
|
|
95
103
|
]
|
|
96
104
|
)
|
|
@@ -343,3 +351,23 @@ async def extract_edge_dates_bulk(
|
|
|
343
351
|
edge.expired_at = datetime.now()
|
|
344
352
|
|
|
345
353
|
return edges
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def chunk_edges_by_nodes(edges: list[EntityEdge]) -> list[list[EntityEdge]]:
|
|
357
|
+
# We only want to dedupe edges that are between the same pair of nodes
|
|
358
|
+
# We build a map of the edges based on their source and target nodes.
|
|
359
|
+
edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list)
|
|
360
|
+
for edge in edges:
|
|
361
|
+
# We drop loop edges
|
|
362
|
+
if edge.source_node_uuid == edge.target_node_uuid:
|
|
363
|
+
continue
|
|
364
|
+
|
|
365
|
+
# Keep the order of the two nodes consistent, we want to be direction agnostic during edge resolution
|
|
366
|
+
pointers = [edge.source_node_uuid, edge.target_node_uuid]
|
|
367
|
+
pointers.sort()
|
|
368
|
+
|
|
369
|
+
edge_chunk_map[pointers[0] + pointers[1]].append(edge)
|
|
370
|
+
|
|
371
|
+
edge_chunks = [chunk for chunk in edge_chunk_map.values()]
|
|
372
|
+
|
|
373
|
+
return edge_chunks
|
|
@@ -37,15 +37,15 @@ def build_episodic_edges(
|
|
|
37
37
|
episode: EpisodicNode,
|
|
38
38
|
created_at: datetime,
|
|
39
39
|
) -> List[EpisodicEdge]:
|
|
40
|
-
edges: List[EpisodicEdge] = [
|
|
41
|
-
|
|
42
|
-
for node in entity_nodes:
|
|
43
|
-
edge = EpisodicEdge(
|
|
40
|
+
edges: List[EpisodicEdge] = [
|
|
41
|
+
EpisodicEdge(
|
|
44
42
|
source_node_uuid=episode.uuid,
|
|
45
43
|
target_node_uuid=node.uuid,
|
|
46
44
|
created_at=created_at,
|
|
45
|
+
group_id=episode.group_id,
|
|
47
46
|
)
|
|
48
|
-
|
|
47
|
+
for node in entity_nodes
|
|
48
|
+
]
|
|
49
49
|
|
|
50
50
|
return edges
|
|
51
51
|
|
|
@@ -55,6 +55,7 @@ async def extract_edges(
|
|
|
55
55
|
episode: EpisodicNode,
|
|
56
56
|
nodes: list[EntityNode],
|
|
57
57
|
previous_episodes: list[EpisodicNode],
|
|
58
|
+
group_id: str | None,
|
|
58
59
|
) -> list[EntityEdge]:
|
|
59
60
|
start = time()
|
|
60
61
|
|
|
@@ -88,6 +89,7 @@ async def extract_edges(
|
|
|
88
89
|
source_node_uuid=edge_data['source_node_uuid'],
|
|
89
90
|
target_node_uuid=edge_data['target_node_uuid'],
|
|
90
91
|
name=edge_data['relation_type'],
|
|
92
|
+
group_id=group_id,
|
|
91
93
|
fact=edge_data['fact'],
|
|
92
94
|
episodes=[episode.uuid],
|
|
93
95
|
created_at=datetime.now(),
|
|
@@ -34,6 +34,10 @@ async def build_indices_and_constraints(driver: AsyncDriver):
|
|
|
34
34
|
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
|
|
35
35
|
'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
|
|
36
36
|
'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
|
|
37
|
+
'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
|
|
38
|
+
'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
|
|
39
|
+
'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
|
|
40
|
+
'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
|
|
37
41
|
'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
|
|
38
42
|
'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
|
|
39
43
|
'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
|
|
@@ -86,6 +90,7 @@ async def retrieve_episodes(
|
|
|
86
90
|
driver: AsyncDriver,
|
|
87
91
|
reference_time: datetime,
|
|
88
92
|
last_n: int = EPISODE_WINDOW_LEN,
|
|
93
|
+
group_ids: list[str | None] | None = None,
|
|
89
94
|
) -> list[EpisodicNode]:
|
|
90
95
|
"""
|
|
91
96
|
Retrieve the last n episodic nodes from the graph.
|
|
@@ -96,25 +101,28 @@ async def retrieve_episodes(
|
|
|
96
101
|
less than or equal to this reference_time will be retrieved. This allows for
|
|
97
102
|
querying the graph's state at a specific point in time.
|
|
98
103
|
last_n (int, optional): The number of most recent episodes to retrieve, relative to the reference_time.
|
|
104
|
+
group_ids (list[str], optional): The list of group ids to return data from.
|
|
99
105
|
|
|
100
106
|
Returns:
|
|
101
107
|
list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes.
|
|
102
108
|
"""
|
|
103
109
|
result = await driver.execute_query(
|
|
104
110
|
"""
|
|
105
|
-
MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
|
|
106
|
-
RETURN e.content
|
|
107
|
-
e.created_at
|
|
108
|
-
e.valid_at
|
|
109
|
-
e.uuid
|
|
110
|
-
e.
|
|
111
|
-
e.
|
|
112
|
-
e.
|
|
111
|
+
MATCH (e:Episodic) WHERE e.valid_at <= $reference_time AND e.group_id in $group_ids
|
|
112
|
+
RETURN e.content AS content,
|
|
113
|
+
e.created_at AS created_at,
|
|
114
|
+
e.valid_at AS valid_at,
|
|
115
|
+
e.uuid AS uuid,
|
|
116
|
+
e.group_id AS group_id,
|
|
117
|
+
e.name AS name,
|
|
118
|
+
e.source_description AS source_description,
|
|
119
|
+
e.source AS source
|
|
113
120
|
ORDER BY e.created_at DESC
|
|
114
121
|
LIMIT $num_episodes
|
|
115
122
|
""",
|
|
116
123
|
reference_time=reference_time,
|
|
117
124
|
num_episodes=last_n,
|
|
125
|
+
group_ids=group_ids,
|
|
118
126
|
)
|
|
119
127
|
episodes = [
|
|
120
128
|
EpisodicNode(
|
|
@@ -124,6 +132,7 @@ async def retrieve_episodes(
|
|
|
124
132
|
),
|
|
125
133
|
valid_at=(record['valid_at'].to_native()),
|
|
126
134
|
uuid=record['uuid'],
|
|
135
|
+
group_id=record['group_id'],
|
|
127
136
|
source=EpisodeType.from_str(record['source']),
|
|
128
137
|
name=record['name'],
|
|
129
138
|
source_description=record['source_description'],
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: graphiti-core
|
|
3
|
-
Version: 0.2.
|
|
3
|
+
Version: 0.2.3
|
|
4
4
|
Summary: A temporal graph building library
|
|
5
5
|
License: Apache-2.0
|
|
6
6
|
Author: Paul Paliychuk
|
|
@@ -12,11 +12,10 @@ Classifier: Programming Language :: Python :: 3.10
|
|
|
12
12
|
Classifier: Programming Language :: Python :: 3.11
|
|
13
13
|
Classifier: Programming Language :: Python :: 3.12
|
|
14
14
|
Requires-Dist: diskcache (>=5.6.3,<6.0.0)
|
|
15
|
-
Requires-Dist: fastapi (>=0.112.0,<0.113.0)
|
|
16
15
|
Requires-Dist: neo4j (>=5.23.0,<6.0.0)
|
|
16
|
+
Requires-Dist: numpy (>=2.1.1,<3.0.0)
|
|
17
17
|
Requires-Dist: openai (>=1.38.0,<2.0.0)
|
|
18
18
|
Requires-Dist: pydantic (>=2.8.2,<3.0.0)
|
|
19
|
-
Requires-Dist: sentence-transformers (>=3.0.1,<4.0.0)
|
|
20
19
|
Requires-Dist: tenacity (<9.0.0)
|
|
21
20
|
Description-Content-Type: text/markdown
|
|
22
21
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
graphiti_core/__init__.py,sha256=e5SWFkRiaUwfprYIeIgVIh7JDedNiloZvd3roU-0aDY,55
|
|
2
|
-
graphiti_core/edges.py,sha256=
|
|
3
|
-
graphiti_core/graphiti.py,sha256=
|
|
2
|
+
graphiti_core/edges.py,sha256=oy_tK9YWE7_g4aQMGutymVdreiC-SsWP6ZtayEYGCFQ,7700
|
|
3
|
+
graphiti_core/graphiti.py,sha256=tUEtyBb8hQXTn_eMmVSsFVBV7AKWE22SPQihCMZtcZU,23647
|
|
4
4
|
graphiti_core/helpers.py,sha256=EAeC3RrcecjiTGN2vxergN5RHTy2_jhFXA5PQVT3toU,200
|
|
5
5
|
graphiti_core/llm_client/__init__.py,sha256=f4OSk82jJ70wZ2HOuQu6-RQWkkf7HIB0FCT6xOuxZkQ,154
|
|
6
6
|
graphiti_core/llm_client/anthropic_client.py,sha256=C8lOLm7in_eNfOP7s8gjMM0Y99-TzKWlGaPuVGceX68,2180
|
|
@@ -9,7 +9,7 @@ graphiti_core/llm_client/config.py,sha256=d1oZ9tt7QBQlbph7v-0HjItb6otK9_-IwF8kkR
|
|
|
9
9
|
graphiti_core/llm_client/groq_client.py,sha256=qscr5-190wBTUCBL31EAjQTLytK9AF75-y9GsVRvGJU,2206
|
|
10
10
|
graphiti_core/llm_client/openai_client.py,sha256=Bkrp_mKzAxK6kgPzv1UtVUgr1ZvvJhE2H39hgAwWrsI,2211
|
|
11
11
|
graphiti_core/llm_client/utils.py,sha256=H8-Kwa5SyvIYDNIas8O4bHJ6jsOL49li44VoDEMyauY,555
|
|
12
|
-
graphiti_core/nodes.py,sha256=
|
|
12
|
+
graphiti_core/nodes.py,sha256=RZnIKyu9ZzWVlbodae3Rkzlg00fQIqp5o3iGB4Ffm-M,8140
|
|
13
13
|
graphiti_core/prompts/__init__.py,sha256=EA-x9xUki9l8wnu2l8ek_oNf75-do5tq5hVq7Zbv8Kw,101
|
|
14
14
|
graphiti_core/prompts/dedupe_edges.py,sha256=DUNHdIudj50FAjkla4nc68tSFSD2yjmYHBw-Bb7ph20,6529
|
|
15
15
|
graphiti_core/prompts/dedupe_nodes.py,sha256=BZ9S-PB9SSGjc5Oo8ivdgA6rZx3OGOFhKtwrBlQ0bm0,7269
|
|
@@ -20,18 +20,17 @@ graphiti_core/prompts/invalidate_edges.py,sha256=8SHt3iPTdmqk8A52LxgdMtI39w4USKq
|
|
|
20
20
|
graphiti_core/prompts/lib.py,sha256=RR8f8DQfioUK5bJonMzn02pKLxJlaENv1VocpvRJ488,3532
|
|
21
21
|
graphiti_core/prompts/models.py,sha256=cvx_Bv5RMFUD_5IUawYrbpOKLPHogai7_bm7YXrSz84,867
|
|
22
22
|
graphiti_core/search/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
23
|
-
graphiti_core/search/search.py,sha256=
|
|
24
|
-
graphiti_core/search/search_utils.py,sha256=
|
|
23
|
+
graphiti_core/search/search.py,sha256=cr1-syRlRdijnLtbuQYWy_2G1CtAeIaz6BQ2kl_6FrY,4535
|
|
24
|
+
graphiti_core/search/search_utils.py,sha256=YeJ-M67HXPQySruwZmau3jvilFlcwf8OwfuflnSdf1Q,19355
|
|
25
25
|
graphiti_core/utils/__init__.py,sha256=cJAcMnBZdHBQmWrZdU1PQ1YmaL75bhVUkyVpIPuOyns,260
|
|
26
|
-
graphiti_core/utils/bulk_utils.py,sha256=
|
|
26
|
+
graphiti_core/utils/bulk_utils.py,sha256=JtoYTZPCigPa3n2E43Oe7QhFZRTA_QKNGy1jVgklHag,12614
|
|
27
27
|
graphiti_core/utils/maintenance/__init__.py,sha256=4b9sfxqyFZMLwxxS2lnQ6_wBr3xrJRIqfAWOidK8EK0,388
|
|
28
|
-
graphiti_core/utils/maintenance/edge_operations.py,sha256=
|
|
29
|
-
graphiti_core/utils/maintenance/graph_data_operations.py,sha256
|
|
30
|
-
graphiti_core/utils/maintenance/node_operations.py,sha256=
|
|
28
|
+
graphiti_core/utils/maintenance/edge_operations.py,sha256=Xq60YlOGQKzD5qN6eahUMOiLQJiBaDNOeIiGkS8EdB0,10855
|
|
29
|
+
graphiti_core/utils/maintenance/graph_data_operations.py,sha256=-A4fPYtXIjoBBX6IDPoaU9pDcSjZGeRbRPj23W1C-l4,5951
|
|
30
|
+
graphiti_core/utils/maintenance/node_operations.py,sha256=ecBOp_reQynENFN0M69IzRPgEuBYOuPpDBwFZq5e-I4,7995
|
|
31
31
|
graphiti_core/utils/maintenance/temporal_operations.py,sha256=BzfGDm96w4HcUEsaWTHUBt5S8dNmDQL1eX6AuBL-XFM,8135
|
|
32
32
|
graphiti_core/utils/maintenance/utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
33
|
-
graphiti_core/
|
|
34
|
-
graphiti_core-0.2.
|
|
35
|
-
graphiti_core-0.2.
|
|
36
|
-
graphiti_core-0.2.
|
|
37
|
-
graphiti_core-0.2.2.dist-info/RECORD,,
|
|
33
|
+
graphiti_core-0.2.3.dist-info/LICENSE,sha256=KCUwCyDXuVEgmDWkozHyniRyWjnWUWjkuDHfU6o3JlA,11325
|
|
34
|
+
graphiti_core-0.2.3.dist-info/METADATA,sha256=o81BUoLGtzm0AhnO9MBW8yeG1_UPFN6m-0PmnFOJKis,9124
|
|
35
|
+
graphiti_core-0.2.3.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
36
|
+
graphiti_core-0.2.3.dist-info/RECORD,,
|
graphiti_core/utils/utils.py
DELETED
|
@@ -1,60 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Copyright 2024, Zep Software, Inc.
|
|
3
|
-
|
|
4
|
-
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
-
you may not use this file except in compliance with the License.
|
|
6
|
-
You may obtain a copy of the License at
|
|
7
|
-
|
|
8
|
-
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
-
|
|
10
|
-
Unless required by applicable law or agreed to in writing, software
|
|
11
|
-
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
-
See the License for the specific language governing permissions and
|
|
14
|
-
limitations under the License.
|
|
15
|
-
"""
|
|
16
|
-
|
|
17
|
-
import logging
|
|
18
|
-
from collections import defaultdict
|
|
19
|
-
|
|
20
|
-
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
|
21
|
-
from graphiti_core.nodes import EntityNode, EpisodicNode
|
|
22
|
-
|
|
23
|
-
logger = logging.getLogger(__name__)
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
def build_episodic_edges(
|
|
27
|
-
entity_nodes: list[EntityNode], episode: EpisodicNode
|
|
28
|
-
) -> list[EpisodicEdge]:
|
|
29
|
-
edges: list[EpisodicEdge] = []
|
|
30
|
-
|
|
31
|
-
for node in entity_nodes:
|
|
32
|
-
edges.append(
|
|
33
|
-
EpisodicEdge(
|
|
34
|
-
source_node_uuid=episode.uuid,
|
|
35
|
-
target_node_uuid=node.uuid,
|
|
36
|
-
created_at=episode.created_at,
|
|
37
|
-
)
|
|
38
|
-
)
|
|
39
|
-
|
|
40
|
-
return edges
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
def chunk_edges_by_nodes(edges: list[EntityEdge]) -> list[list[EntityEdge]]:
|
|
44
|
-
# We only want to dedupe edges that are between the same pair of nodes
|
|
45
|
-
# We build a map of the edges based on their source and target nodes.
|
|
46
|
-
edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list)
|
|
47
|
-
for edge in edges:
|
|
48
|
-
# We drop loop edges
|
|
49
|
-
if edge.source_node_uuid == edge.target_node_uuid:
|
|
50
|
-
continue
|
|
51
|
-
|
|
52
|
-
# Keep the order of the two nodes consistent, we want to be direction agnostic during edge resolution
|
|
53
|
-
pointers = [edge.source_node_uuid, edge.target_node_uuid]
|
|
54
|
-
pointers.sort()
|
|
55
|
-
|
|
56
|
-
edge_chunk_map[pointers[0] + pointers[1]].append(edge)
|
|
57
|
-
|
|
58
|
-
edge_chunks = [chunk for chunk in edge_chunk_map.values()]
|
|
59
|
-
|
|
60
|
-
return edge_chunks
|
|
File without changes
|
|
File without changes
|