graphiti-core 0.3.2__tar.gz → 0.3.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/PKG-INFO +1 -1
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/edges.py +78 -3
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/graphiti.py +31 -1
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/nodes.py +3 -2
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/search/search.py +7 -1
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/search/search_config.py +2 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/search/search_config_recipes.py +16 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/search/search_utils.py +57 -1
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/utils/maintenance/community_operations.py +85 -1
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/utils/maintenance/edge_operations.py +2 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/pyproject.toml +3 -3
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/LICENSE +0 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/README.md +0 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/__init__.py +0 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/errors.py +0 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/helpers.py +0 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/llm_client/__init__.py +0 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/llm_client/anthropic_client.py +0 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/llm_client/client.py +0 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/llm_client/config.py +0 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/llm_client/errors.py +0 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/llm_client/groq_client.py +0 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/llm_client/openai_client.py +0 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/llm_client/utils.py +0 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/prompts/__init__.py +0 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/prompts/dedupe_edges.py +0 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/prompts/dedupe_nodes.py +0 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/prompts/extract_edge_dates.py +0 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/prompts/extract_edges.py +0 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/prompts/extract_nodes.py +0 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/prompts/invalidate_edges.py +0 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/prompts/lib.py +0 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/prompts/models.py +0 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/prompts/summarize_nodes.py +0 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/py.typed +0 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/search/__init__.py +0 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/utils/__init__.py +0 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/utils/bulk_utils.py +0 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/utils/maintenance/__init__.py +0 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/utils/maintenance/graph_data_operations.py +0 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/utils/maintenance/node_operations.py +0 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/utils/maintenance/temporal_operations.py +0 -0
- {graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/utils/maintenance/utils.py +0 -0
|
@@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)
|
|
|
33
33
|
|
|
34
34
|
|
|
35
35
|
class Edge(BaseModel, ABC):
|
|
36
|
-
uuid: str = Field(default_factory=lambda: uuid4()
|
|
36
|
+
uuid: str = Field(default_factory=lambda: str(uuid4()))
|
|
37
37
|
group_id: str | None = Field(description='partition of the graph')
|
|
38
38
|
source_node_uuid: str
|
|
39
39
|
target_node_uuid: str
|
|
@@ -109,13 +109,36 @@ class EpisodicEdge(Edge):
|
|
|
109
109
|
raise EdgeNotFoundError(uuid)
|
|
110
110
|
return edges[0]
|
|
111
111
|
|
|
112
|
+
@classmethod
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]) -> list['EpisodicEdge']:
    """Fetch several episodic (Episodic-[MENTIONS]->Entity) edges in one query.

    Parameters
    ----------
    driver : AsyncDriver
        Driver used to run the Cypher query.
    uuids : list[str]
        Uuids of the MENTIONS edges to load.

    Returns
    -------
    list[EpisodicEdge]
        The matching edges; an empty list when ``uuids`` is empty.

    Raises
    ------
    EdgeNotFoundError
        If ``uuids`` is non-empty but no edge matches any of them.
    """
    # With no ids, the not-found branch below would evaluate ``uuids[0]``
    # and raise IndexError. There is nothing to look up, so return early.
    if not uuids:
        return []

    records, _, _ = await driver.execute_query(
        """
        MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
        WHERE e.uuid IN $uuids
        RETURN
            e.uuid As uuid,
            e.group_id AS group_id,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            e.created_at AS created_at
        """,
        uuids=uuids,
    )

    edges = [get_episodic_edge_from_record(record) for record in records]

    logger.info(f'Found Edges: {uuids}')
    if not edges:
        raise EdgeNotFoundError(uuids[0])
    return edges
|
|
134
|
+
|
|
112
135
|
|
|
113
136
|
class EntityEdge(Edge):
|
|
114
137
|
name: str = Field(description='name of the edge, relation name')
|
|
115
138
|
fact: str = Field(description='fact representing the edge and nodes that it connects')
|
|
116
139
|
fact_embedding: list[float] | None = Field(default=None, description='embedding of the fact')
|
|
117
|
-
episodes: list[str]
|
|
118
|
-
default=
|
|
140
|
+
episodes: list[str] = Field(
|
|
141
|
+
default=[],
|
|
119
142
|
description='list of episode ids that reference these entity edges',
|
|
120
143
|
)
|
|
121
144
|
expired_at: datetime | None = Field(
|
|
@@ -197,6 +220,36 @@ class EntityEdge(Edge):
|
|
|
197
220
|
raise EdgeNotFoundError(uuid)
|
|
198
221
|
return edges[0]
|
|
199
222
|
|
|
223
|
+
@classmethod
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]) -> list['EntityEdge']:
    """Fetch several entity (Entity-[RELATES_TO]->Entity) edges in one query.

    Parameters
    ----------
    driver : AsyncDriver
        Driver used to run the Cypher query.
    uuids : list[str]
        Uuids of the RELATES_TO edges to load.

    Returns
    -------
    list[EntityEdge]
        The matching edges; an empty list when ``uuids`` is empty.

    Raises
    ------
    EdgeNotFoundError
        If ``uuids`` is non-empty but no edge matches any of them.
    """
    # An episode that produced no facts has an empty id list. Without this
    # guard, the not-found branch below would evaluate ``uuids[0]`` and raise
    # IndexError.
    if not uuids:
        return []

    records, _, _ = await driver.execute_query(
        """
        MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
        WHERE e.uuid IN $uuids
        RETURN
            e.uuid AS uuid,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            e.created_at AS created_at,
            e.name AS name,
            e.group_id AS group_id,
            e.fact AS fact,
            e.fact_embedding AS fact_embedding,
            e.episodes AS episodes,
            e.expired_at AS expired_at,
            e.valid_at AS valid_at,
            e.invalid_at AS invalid_at
        """,
        uuids=uuids,
    )

    edges = [get_entity_edge_from_record(record) for record in records]

    logger.info(f'Found Edges: {uuids}')
    if not edges:
        raise EdgeNotFoundError(uuids[0])
    return edges
|
|
252
|
+
|
|
200
253
|
|
|
201
254
|
class CommunityEdge(Edge):
|
|
202
255
|
async def save(self, driver: AsyncDriver):
|
|
@@ -239,6 +292,28 @@ class CommunityEdge(Edge):
|
|
|
239
292
|
|
|
240
293
|
return edges[0]
|
|
241
294
|
|
|
295
|
+
@classmethod
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
    """Load the community HAS_MEMBER edges whose uuid appears in ``uuids``.

    The edge target may be an Entity or a Community node. Nothing is raised
    when no edge matches; the (possibly empty) list of matches is returned.
    """
    cypher = """
        MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
        WHERE e.uuid IN $uuids
        RETURN
            e.uuid As uuid,
            e.group_id AS group_id,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            e.created_at AS created_at
        """
    rows, _, _ = await driver.execute_query(cypher, uuids=uuids)

    found = list(map(get_community_edge_from_record, rows))

    logger.info(f'Found Edges: {uuids}')

    return found
|
|
316
|
+
|
|
242
317
|
|
|
243
318
|
# Edge helpers
|
|
244
319
|
def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
|
|
@@ -35,6 +35,8 @@ from graphiti_core.search.search_config_recipes import (
|
|
|
35
35
|
)
|
|
36
36
|
from graphiti_core.search.search_utils import (
|
|
37
37
|
RELEVANT_SCHEMA_LIMIT,
|
|
38
|
+
get_communities_by_nodes,
|
|
39
|
+
get_mentioned_nodes,
|
|
38
40
|
get_relevant_edges,
|
|
39
41
|
get_relevant_nodes,
|
|
40
42
|
)
|
|
@@ -54,6 +56,7 @@ from graphiti_core.utils.bulk_utils import (
|
|
|
54
56
|
from graphiti_core.utils.maintenance.community_operations import (
|
|
55
57
|
build_communities,
|
|
56
58
|
remove_communities,
|
|
59
|
+
update_community,
|
|
57
60
|
)
|
|
58
61
|
from graphiti_core.utils.maintenance.edge_operations import (
|
|
59
62
|
extract_edges,
|
|
@@ -224,6 +227,7 @@ class Graphiti:
|
|
|
224
227
|
source: EpisodeType = EpisodeType.message,
|
|
225
228
|
group_id: str | None = None,
|
|
226
229
|
uuid: str | None = None,
|
|
230
|
+
update_communities: bool = False,
|
|
227
231
|
):
|
|
228
232
|
"""
|
|
229
233
|
Process an episode and update the graph.
|
|
@@ -409,12 +413,22 @@ class Graphiti:
|
|
|
409
413
|
|
|
410
414
|
logger.info(f'Built episodic edges: {episodic_edges}')
|
|
411
415
|
|
|
416
|
+
episode.entity_edges = [edge.uuid for edge in entity_edges]
|
|
417
|
+
|
|
412
418
|
# Future optimization would be using batch operations to save nodes and edges
|
|
413
419
|
await episode.save(self.driver)
|
|
414
420
|
await asyncio.gather(*[node.save(self.driver) for node in nodes])
|
|
415
421
|
await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
|
|
416
422
|
await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges])
|
|
417
423
|
|
|
424
|
+
# Update any communities
|
|
425
|
+
if update_communities:
|
|
426
|
+
await asyncio.gather(
|
|
427
|
+
*[
|
|
428
|
+
update_community(self.driver, self.llm_client, embedder, node)
|
|
429
|
+
for node in nodes
|
|
430
|
+
]
|
|
431
|
+
)
|
|
418
432
|
end = time()
|
|
419
433
|
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
|
|
420
434
|
|
|
@@ -569,7 +583,7 @@ class Graphiti:
|
|
|
569
583
|
Facts will be reranked based on proximity to this node
|
|
570
584
|
group_ids : list[str | None] | None, optional
|
|
571
585
|
The graph partitions to return data from.
|
|
572
|
-
|
|
586
|
+
num_results : int, optional
|
|
573
587
|
The maximum number of results to return. Defaults to 10.
|
|
574
588
|
|
|
575
589
|
Returns
|
|
@@ -668,3 +682,19 @@ class Graphiti:
|
|
|
668
682
|
await search(self.driver, embedder, query, group_ids, search_config, center_node_uuid)
|
|
669
683
|
).nodes
|
|
670
684
|
return nodes
|
|
685
|
+
|
|
686
|
+
|
|
687
|
+
async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
    """Collect the graph content referenced by the given episodes.

    Parameters
    ----------
    episode_uuids : list[str]
        Uuids of the episodic nodes to inspect.

    Returns
    -------
    SearchResults
        - ``edges``: the entity edges listed in each episode's
          ``entity_edges``, de-duplicated by uuid in first-seen order.
        - ``nodes``: the entity nodes mentioned by the episodes.
        - ``communities``: the communities those nodes belong to.
    """
    episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)

    # Episodes that yielded no facts carry an empty (or missing) entity_edges
    # list. Looking those up raises instead of returning nothing, so skip them.
    edges_list = await asyncio.gather(
        *[
            EntityEdge.get_by_uuids(self.driver, episode.entity_edges)
            for episode in episodes
            if episode.entity_edges
        ]
    )

    # The same fact can be referenced by several episodes; keep one copy of
    # each edge, preserving the order in which it was first seen.
    edges_by_uuid: dict[str, EntityEdge] = {}
    for lst in edges_list:
        for edge in lst:
            edges_by_uuid.setdefault(edge.uuid, edge)
    edges: list[EntityEdge] = list(edges_by_uuid.values())

    nodes = await get_mentioned_nodes(self.driver, episodes)

    communities = await get_communities_by_nodes(self.driver, nodes)

    return SearchResults(edges=edges, nodes=nodes, communities=communities)
|
|
@@ -68,7 +68,7 @@ class EpisodeType(Enum):
|
|
|
68
68
|
|
|
69
69
|
|
|
70
70
|
class Node(BaseModel, ABC):
|
|
71
|
-
uuid: str = Field(default_factory=lambda: uuid4()
|
|
71
|
+
uuid: str = Field(default_factory=lambda: str(uuid4()))
|
|
72
72
|
name: str = Field(description='name of the node')
|
|
73
73
|
group_id: str | None = Field(description='partition of the graph')
|
|
74
74
|
labels: list[str] = Field(default_factory=list)
|
|
@@ -170,7 +170,8 @@ class EpisodicNode(Node):
|
|
|
170
170
|
records, _, _ = await driver.execute_query(
|
|
171
171
|
"""
|
|
172
172
|
MATCH (e:Episodic) WHERE e.uuid IN $uuids
|
|
173
|
-
RETURN
|
|
173
|
+
RETURN DISTINCT
|
|
174
|
+
e.content AS content,
|
|
174
175
|
e.created_at AS created_at,
|
|
175
176
|
e.valid_at AS valid_at,
|
|
176
177
|
e.uuid AS uuid,
|
|
@@ -42,6 +42,7 @@ from graphiti_core.search.search_utils import (
|
|
|
42
42
|
community_similarity_search,
|
|
43
43
|
edge_fulltext_search,
|
|
44
44
|
edge_similarity_search,
|
|
45
|
+
episode_mentions_reranker,
|
|
45
46
|
node_distance_reranker,
|
|
46
47
|
node_fulltext_search,
|
|
47
48
|
node_similarity_search,
|
|
@@ -131,7 +132,7 @@ async def edge_search(
|
|
|
131
132
|
edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
|
|
132
133
|
|
|
133
134
|
reranked_uuids: list[str] = []
|
|
134
|
-
if config.reranker == EdgeReranker.rrf:
|
|
135
|
+
if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions:
|
|
135
136
|
search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
|
|
136
137
|
|
|
137
138
|
reranked_uuids = rrf(search_result_uuids)
|
|
@@ -150,6 +151,9 @@ async def edge_search(
|
|
|
150
151
|
|
|
151
152
|
reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
|
|
152
153
|
|
|
154
|
+
if config.reranker == EdgeReranker.episode_mentions:
|
|
155
|
+
reranked_edges.sort(reverse=True, key=lambda edge: len(edge.episodes))
|
|
156
|
+
|
|
153
157
|
return reranked_edges
|
|
154
158
|
|
|
155
159
|
|
|
@@ -189,6 +193,8 @@ async def node_search(
|
|
|
189
193
|
reranked_uuids: list[str] = []
|
|
190
194
|
if config.reranker == NodeReranker.rrf:
|
|
191
195
|
reranked_uuids = rrf(search_result_uuids)
|
|
196
|
+
elif config.reranker == NodeReranker.episode_mentions:
|
|
197
|
+
reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids)
|
|
192
198
|
elif config.reranker == NodeReranker.node_distance:
|
|
193
199
|
if center_node_uuid is None:
|
|
194
200
|
raise SearchRerankerError('No center node provided for Node Distance reranker')
|
|
@@ -42,11 +42,13 @@ class CommunitySearchMethod(Enum):
|
|
|
42
42
|
class EdgeReranker(Enum):
    """Reranking strategies available for edge search results."""

    # Reciprocal rank fusion of the per-search-method result lists.
    rrf = 'reciprocal_rank_fusion'
    # Rank by proximity to a center node.
    node_distance = 'node_distance'
    # RRF, then order by how many episodes reference each edge.
    episode_mentions = 'episode_mentions'
|
|
45
46
|
|
|
46
47
|
|
|
47
48
|
class NodeReranker(Enum):
    """Reranking strategies available for node search results."""

    # Reciprocal rank fusion of the per-search-method result lists.
    rrf = 'reciprocal_rank_fusion'
    # Rank by distance to a center node (a center node uuid must be supplied).
    node_distance = 'node_distance'
    # RRF, then reorder by the number of episodes that mention each node.
    episode_mentions = 'episode_mentions'
|
|
50
52
|
|
|
51
53
|
|
|
52
54
|
class CommunityReranker(Enum):
|
|
@@ -59,6 +59,14 @@ EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
|
|
|
59
59
|
)
|
|
60
60
|
)
|
|
61
61
|
|
|
62
|
+
# Hybrid edge search (BM25 + cosine similarity). Results are fused with RRF
# and then ordered by the number of episodes that reference each edge.
EDGE_HYBRID_SEARCH_EPISODE_MENTIONS = SearchConfig(
    edge_config=EdgeSearchConfig(
        reranker=EdgeReranker.episode_mentions,
        search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
    ),
)
|
|
69
|
+
|
|
62
70
|
# performs a hybrid search over nodes with rrf reranking
|
|
63
71
|
NODE_HYBRID_SEARCH_RRF = SearchConfig(
|
|
64
72
|
node_config=NodeSearchConfig(
|
|
@@ -75,6 +83,14 @@ NODE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
|
|
|
75
83
|
)
|
|
76
84
|
)
|
|
77
85
|
|
|
86
|
+
# Hybrid node search (BM25 + cosine similarity). Results are fused with RRF
# and then reordered by how many episodes mention each node.
NODE_HYBRID_SEARCH_EPISODE_MENTIONS = SearchConfig(
    node_config=NodeSearchConfig(
        reranker=NodeReranker.episode_mentions,
        search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
    ),
)
|
|
93
|
+
|
|
78
94
|
# performs a hybrid search over communities with rrf reranking
|
|
79
95
|
COMMUNITY_HYBRID_SEARCH_RRF = SearchConfig(
|
|
80
96
|
community_config=CommunitySearchConfig(
|
|
@@ -36,7 +36,9 @@ logger = logging.getLogger(__name__)
|
|
|
36
36
|
RELEVANT_SCHEMA_LIMIT = 3
|
|
37
37
|
|
|
38
38
|
|
|
39
|
-
async def get_mentioned_nodes(
|
|
39
|
+
async def get_mentioned_nodes(
|
|
40
|
+
driver: AsyncDriver, episodes: list[EpisodicNode]
|
|
41
|
+
) -> list[EntityNode]:
|
|
40
42
|
episode_uuids = [episode.uuid for episode in episodes]
|
|
41
43
|
records, _, _ = await driver.execute_query(
|
|
42
44
|
"""
|
|
@@ -57,6 +59,29 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode])
|
|
|
57
59
|
return nodes
|
|
58
60
|
|
|
59
61
|
|
|
62
|
+
async def get_communities_by_nodes(
    driver: AsyncDriver, nodes: list[EntityNode]
) -> list[CommunityNode]:
    """Return the distinct communities that have any of ``nodes`` as a member.

    Parameters
    ----------
    driver : AsyncDriver
        Driver used to run the Cypher query.
    nodes : list[EntityNode]
        Entity nodes whose communities (via HAS_MEMBER) should be returned.

    Returns
    -------
    list[CommunityNode]
        One entry per matching community; empty when ``nodes`` is empty or
        none of the nodes belongs to a community.
    """
    node_uuids = [node.uuid for node in nodes]
    # Nothing to match against; skip the database round-trip.
    if not node_uuids:
        return []

    # The original projection was missing the comma after ``name_embedding``,
    # which made the whole query a Cypher syntax error.
    records, _, _ = await driver.execute_query(
        """
        MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
        RETURN DISTINCT
            c.uuid As uuid,
            c.group_id AS group_id,
            c.name AS name,
            c.name_embedding AS name_embedding,
            c.created_at AS created_at,
            c.summary AS summary
        """,
        uuids=node_uuids,
    )

    return [get_community_node_from_record(record) for record in records]
|
|
83
|
+
|
|
84
|
+
|
|
60
85
|
async def edge_fulltext_search(
|
|
61
86
|
driver: AsyncDriver,
|
|
62
87
|
query: str,
|
|
@@ -634,3 +659,34 @@ async def node_distance_reranker(
|
|
|
634
659
|
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
|
635
660
|
|
|
636
661
|
return sorted_uuids
|
|
662
|
+
|
|
663
|
+
|
|
664
|
+
async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[str]]) -> list[str]:
    """Rerank node search results by how many episodes mention each node.

    Parameters
    ----------
    driver : AsyncDriver
        Driver used to count Episodic-[MENTIONS]->Entity relationships.
    node_uuids : list[list[str]]
        One ranked list of node uuids per search method. The lists are fused
        with reciprocal rank fusion before reranking.

    Returns
    -------
    list[str]
        The fused uuids, most-mentioned first. Nodes with equal mention
        counts keep their RRF order.
    """
    # RRF gives a preliminary order that also serves as the tie-breaker.
    sorted_uuids = rrf(node_uuids)
    if not sorted_uuids:
        return sorted_uuids

    # Count mentions for every candidate in one round-trip instead of issuing
    # one query per node.
    query = Query("""
        UNWIND $node_uuids AS node_uuid
        MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
        RETURN node_uuid AS uuid, count(*) AS score
        """)

    records, _, _ = await driver.execute_query(query, node_uuids=sorted_uuids)

    # Nodes that no episode mentions produce no row, so they default to 0.
    scores: dict[str, float] = {uuid: 0 for uuid in sorted_uuids}
    for record in records:
        scores[record['uuid']] = record['score']

    # Most-mentioned first, matching the edge-side episode_mentions reranker
    # (which sorts by len(edge.episodes) descending). list.sort is stable
    # even with reverse=True, so ties keep their RRF order.
    sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid], reverse=True)

    return sorted_uuids
|
{graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/utils/maintenance/community_operations.py
RENAMED
|
@@ -7,7 +7,7 @@ from neo4j import AsyncDriver
|
|
|
7
7
|
|
|
8
8
|
from graphiti_core.edges import CommunityEdge
|
|
9
9
|
from graphiti_core.llm_client import LLMClient
|
|
10
|
-
from graphiti_core.nodes import CommunityNode, EntityNode
|
|
10
|
+
from graphiti_core.nodes import CommunityNode, EntityNode, get_community_node_from_record
|
|
11
11
|
from graphiti_core.prompts import prompt_library
|
|
12
12
|
from graphiti_core.utils.maintenance.edge_operations import build_community_edges
|
|
13
13
|
|
|
@@ -153,3 +153,87 @@ async def remove_communities(driver: AsyncDriver):
|
|
|
153
153
|
MATCH (c:Community)
|
|
154
154
|
DETACH DELETE c
|
|
155
155
|
""")
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
async def determine_entity_community(
    driver: AsyncDriver, entity: EntityNode
) -> tuple[CommunityNode | None, bool]:
    """Choose the community that ``entity`` belongs to, or should join.

    Returns a ``(community, is_new)`` pair:

    * ``(existing, False)`` if the entity is already a HAS_MEMBER of some
      community (the first one the query returns).
    * ``(mode, True)`` otherwise, where ``mode`` is the community that occurs
      most often across HAS_MEMBER paths to the entity's RELATES_TO
      neighbours. Ties go to the community seen first.
    * ``(None, False)`` when no candidate community is found.
    """
    # Step 1: is the entity already a member of a community?
    membership, _, _ = await driver.execute_query(
        """
        MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid})
        RETURN
            c.uuid As uuid,
            c.name AS name,
            c.name_embedding AS name_embedding,
            c.group_id AS group_id,
            c.created_at AS created_at,
            c.summary AS summary
        """,
        entity_uuid=entity.uuid,
    )
    if membership:
        return get_community_node_from_record(membership[0]), False

    # Step 2: gather the communities of entities related to this one.
    neighbour_rows, _, _ = await driver.execute_query(
        """
        MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
        RETURN
            c.uuid As uuid,
            c.name AS name,
            c.name_embedding AS name_embedding,
            c.group_id AS group_id,
            c.created_at AS created_at,
            c.summary AS summary
        """,
        entity_uuid=entity.uuid,
    )
    candidates: list[CommunityNode] = [
        get_community_node_from_record(row) for row in neighbour_rows
    ]
    if not candidates:
        return None, False

    # Tally occurrences per community uuid. The dict keeps insertion order and
    # max() returns the first maximal key, so ties go to the earliest seen.
    tally: dict[str, int] = {}
    for candidate in candidates:
        tally[candidate.uuid] = tally.get(candidate.uuid, 0) + 1
    winner = max(tally, key=tally.__getitem__)

    return next(c for c in candidates if c.uuid == winner), True
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
async def update_community(
    driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode
):
    """Merge ``entity`` into its community's summary, name and embedding.

    Does nothing when no community can be determined for the entity. If the
    entity is joining the community for the first time, a HAS_MEMBER edge is
    built and saved before the refreshed community node is saved.

    NOTE(review): this reads the community, rewrites its summary and saves it.
    Concurrent calls for entities in the same community (e.g. via
    asyncio.gather) can overwrite each other's summary. Confirm this is
    acceptable.
    """
    community, is_new = await determine_entity_community(driver, entity)

    if community is not None:
        # Fold the entity's summary into the community's, then derive a short
        # name from the merged text.
        merged_summary = await summarize_pair(llm_client, (entity.summary, community.summary))
        merged_name = await generate_summary_description(llm_client, merged_summary)

        community.summary = merged_summary
        community.name = merged_name

        if is_new:
            membership_edges = build_community_edges([entity], community, datetime.now())
            await membership_edges[0].save(driver)

        await community.generate_name_embedding(embedder)

        await community.save(driver)
|
{graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/utils/maintenance/edge_operations.py
RENAMED
|
@@ -163,6 +163,8 @@ async def dedupe_extracted_edges(
|
|
|
163
163
|
if edge.uuid in duplicate_uuid_map:
|
|
164
164
|
existing_uuid = duplicate_uuid_map[edge.uuid]
|
|
165
165
|
existing_edge = edge_map[existing_uuid]
|
|
166
|
+
# Add current episode to the episodes list
|
|
167
|
+
existing_edge.episodes += edge.episodes
|
|
166
168
|
edges.append(existing_edge)
|
|
167
169
|
else:
|
|
168
170
|
edges.append(edge)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[tool.poetry]
|
|
2
2
|
name = "graphiti-core"
|
|
3
|
-
version = "0.3.2"
|
|
3
|
+
version = "0.3.3"
|
|
4
4
|
description = "A temporal graph building library"
|
|
5
5
|
authors = [
|
|
6
6
|
"Paul Paliychuk <paul@getzep.com>",
|
|
@@ -22,11 +22,11 @@ tenacity = "<9.0.0"
|
|
|
22
22
|
numpy = ">=1.0.0"
|
|
23
23
|
|
|
24
24
|
[tool.poetry.dev-dependencies]
|
|
25
|
-
pytest = "^8.3.
|
|
25
|
+
pytest = "^8.3.3"
|
|
26
26
|
python-dotenv = "^1.0.1"
|
|
27
27
|
pytest-asyncio = "^0.24.0"
|
|
28
28
|
pytest-xdist = "^3.6.1"
|
|
29
|
-
ruff = "^0.6.
|
|
29
|
+
ruff = "^0.6.5"
|
|
30
30
|
|
|
31
31
|
[tool.poetry.group.dev.dependencies]
|
|
32
32
|
pydantic = "^2.8.2"
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/utils/maintenance/graph_data_operations.py
RENAMED
|
File without changes
|
{graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/utils/maintenance/node_operations.py
RENAMED
|
File without changes
|
{graphiti_core-0.3.2 → graphiti_core-0.3.3}/graphiti_core/utils/maintenance/temporal_operations.py
RENAMED
|
File without changes
|
|
File without changes
|