graphiti_core-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphiti_core/__init__.py +3 -0
- graphiti_core/edges.py +232 -0
- graphiti_core/graphiti.py +618 -0
- graphiti_core/helpers.py +7 -0
- graphiti_core/llm_client/__init__.py +5 -0
- graphiti_core/llm_client/anthropic_client.py +63 -0
- graphiti_core/llm_client/client.py +96 -0
- graphiti_core/llm_client/config.py +58 -0
- graphiti_core/llm_client/groq_client.py +64 -0
- graphiti_core/llm_client/openai_client.py +65 -0
- graphiti_core/llm_client/utils.py +22 -0
- graphiti_core/nodes.py +250 -0
- graphiti_core/prompts/__init__.py +4 -0
- graphiti_core/prompts/dedupe_edges.py +154 -0
- graphiti_core/prompts/dedupe_nodes.py +151 -0
- graphiti_core/prompts/extract_edge_dates.py +60 -0
- graphiti_core/prompts/extract_edges.py +138 -0
- graphiti_core/prompts/extract_nodes.py +145 -0
- graphiti_core/prompts/invalidate_edges.py +74 -0
- graphiti_core/prompts/lib.py +122 -0
- graphiti_core/prompts/models.py +31 -0
- graphiti_core/search/__init__.py +0 -0
- graphiti_core/search/search.py +142 -0
- graphiti_core/search/search_utils.py +454 -0
- graphiti_core/utils/__init__.py +15 -0
- graphiti_core/utils/bulk_utils.py +227 -0
- graphiti_core/utils/maintenance/__init__.py +16 -0
- graphiti_core/utils/maintenance/edge_operations.py +170 -0
- graphiti_core/utils/maintenance/graph_data_operations.py +133 -0
- graphiti_core/utils/maintenance/node_operations.py +199 -0
- graphiti_core/utils/maintenance/temporal_operations.py +184 -0
- graphiti_core/utils/maintenance/utils.py +0 -0
- graphiti_core/utils/utils.py +39 -0
- graphiti_core-0.1.0.dist-info/LICENSE +201 -0
- graphiti_core-0.1.0.dist-info/METADATA +199 -0
- graphiti_core-0.1.0.dist-info/RECORD +37 -0
- graphiti_core-0.1.0.dist-info/WHEEL +4 -0
--- /dev/null
+++ graphiti_core/search/search_utils.py
@@ -0,0 +1,454 @@
+import asyncio
+import logging
+import re
+import typing
+from collections import defaultdict
+from time import time
+
+from neo4j import AsyncDriver
+
+from graphiti_core.edges import EntityEdge
+from graphiti_core.helpers import parse_db_date
+from graphiti_core.nodes import EntityNode, EpisodicNode
+
+logger = logging.getLogger(__name__)
+
+RELEVANT_SCHEMA_LIMIT = 3
+
+
+async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]):
+    episode_uuids = [episode.uuid for episode in episodes]
+    records, _, _ = await driver.execute_query(
+        """
+        MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
+        RETURN DISTINCT
+            n.uuid As uuid,
+            n.name AS name,
+            n.created_at AS created_at,
+            n.summary AS summary
+        """,
+        uuids=episode_uuids,
+    )
+
+    nodes: list[EntityNode] = []
+
+    for record in records:
+        nodes.append(
+            EntityNode(
+                uuid=record['uuid'],
+                name=record['name'],
+                labels=['Entity'],
+                created_at=record['created_at'].to_native(),
+                summary=record['summary'],
+            )
+        )
+
+    return nodes
+
+
+async def bfs(node_ids: list[str], driver: AsyncDriver):
+    records, _, _ = await driver.execute_query(
+        """
+        MATCH (n WHERE n.uuid in $node_ids)-[r]->(m)
+        RETURN DISTINCT
+            n.uuid AS source_node_uuid,
+            n.name AS source_name,
+            n.summary AS source_summary,
+            m.uuid AS target_node_uuid,
+            m.name AS target_name,
+            m.summary AS target_summary,
+            r.uuid AS uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.fact_embedding AS fact_embedding,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at
+
+        """,
+        node_ids=node_ids,
+    )
+
+    context: dict[str, typing.Any] = {}
+
+    for record in records:
+        n_uuid = record['source_node_uuid']
+        if n_uuid in context:
+            context[n_uuid]['facts'].append(record['fact'])
+        else:
+            context[n_uuid] = {
+                'name': record['source_name'],
+                'summary': record['source_summary'],
+                'facts': [record['fact']],
+            }
+
+        m_uuid = record['target_node_uuid']
+        if m_uuid not in context:
+            context[m_uuid] = {
+                'name': record['target_name'],
+                'summary': record['target_summary'],
+                'facts': [],
+            }
+    logger.info(f'bfs search returned context: {context}')
+    return context
+
+
+async def edge_similarity_search(
+    search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
+) -> list[EntityEdge]:
+    # vector similarity search over embedded facts
+    records, _, _ = await driver.execute_query(
+        """
+        CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
+        YIELD relationship AS r, score
+        MATCH (n)-[r:RELATES_TO]->(m)
+        RETURN
+            r.uuid AS uuid,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.fact_embedding AS fact_embedding,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at
+        ORDER BY score DESC
+        """,
+        search_vector=search_vector,
+        limit=limit,
+    )
+
+    edges: list[EntityEdge] = []
+
+    for record in records:
+        edge = EntityEdge(
+            uuid=record['uuid'],
+            source_node_uuid=record['source_node_uuid'],
+            target_node_uuid=record['target_node_uuid'],
+            fact=record['fact'],
+            name=record['name'],
+            episodes=record['episodes'],
+            fact_embedding=record['fact_embedding'],
+            created_at=record['created_at'].to_native(),
+            expired_at=parse_db_date(record['expired_at']),
+            valid_at=parse_db_date(record['valid_at']),
+            invalid_at=parse_db_date(record['invalid_at']),
+        )
+
+        edges.append(edge)
+
+    return edges
+
+
+async def entity_similarity_search(
+    search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
+) -> list[EntityNode]:
+    # vector similarity search over entity names
+    records, _, _ = await driver.execute_query(
+        """
+        CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
+        YIELD node AS n, score
+        RETURN
+            n.uuid As uuid,
+            n.name AS name,
+            n.created_at AS created_at,
+            n.summary AS summary
+        ORDER BY score DESC
+        """,
+        search_vector=search_vector,
+        limit=limit,
+    )
+    nodes: list[EntityNode] = []
+
+    for record in records:
+        nodes.append(
+            EntityNode(
+                uuid=record['uuid'],
+                name=record['name'],
+                labels=['Entity'],
+                created_at=record['created_at'].to_native(),
+                summary=record['summary'],
+            )
+        )
+
+    return nodes
+
+
+async def entity_fulltext_search(
+    query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
+) -> list[EntityNode]:
+    # BM25 search to get top nodes
+    fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
+    records, _, _ = await driver.execute_query(
+        """
+        CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score
+        RETURN
+            node.uuid As uuid,
+            node.name AS name,
+            node.created_at AS created_at,
+            node.summary AS summary
+        ORDER BY score DESC
+        LIMIT $limit
+        """,
+        query=fuzzy_query,
+        limit=limit,
+    )
+    nodes: list[EntityNode] = []
+
+    for record in records:
+        nodes.append(
+            EntityNode(
+                uuid=record['uuid'],
+                name=record['name'],
+                labels=['Entity'],
+                created_at=record['created_at'].to_native(),
+                summary=record['summary'],
+            )
+        )
+
+    return nodes
+
+
+async def edge_fulltext_search(
+    query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
+) -> list[EntityEdge]:
+    # fulltext search over facts
+    fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
+
+    records, _, _ = await driver.execute_query(
+        """
+        CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
+        YIELD relationship AS r, score
+        MATCH (n:Entity)-[r]->(m:Entity)
+        RETURN
+            r.uuid AS uuid,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.fact_embedding AS fact_embedding,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at
+        ORDER BY score DESC LIMIT $limit
+        """,
+        query=fuzzy_query,
+        limit=limit,
+    )
+
+    edges: list[EntityEdge] = []
+
+    for record in records:
+        edge = EntityEdge(
+            uuid=record['uuid'],
+            source_node_uuid=record['source_node_uuid'],
+            target_node_uuid=record['target_node_uuid'],
+            fact=record['fact'],
+            name=record['name'],
+            episodes=record['episodes'],
+            fact_embedding=record['fact_embedding'],
+            created_at=record['created_at'].to_native(),
+            expired_at=parse_db_date(record['expired_at']),
+            valid_at=parse_db_date(record['valid_at']),
+            invalid_at=parse_db_date(record['invalid_at']),
+        )
+
+        edges.append(edge)
+
+    return edges
+
+
+async def hybrid_node_search(
+    queries: list[str],
+    embeddings: list[list[float]],
+    driver: AsyncDriver,
+    limit: int | None = None,
+) -> list[EntityNode]:
+    """
+    Perform a hybrid search for nodes using both text queries and embeddings.
+
+    This method combines fulltext search and vector similarity search to find
+    relevant nodes in the graph database.
+
+    Parameters
+    ----------
+    queries : list[str]
+        A list of text queries to search for.
+    embeddings : list[list[float]]
+        A list of embedding vectors corresponding to the queries. If empty only fulltext search is performed.
+    driver : AsyncDriver
+        The Neo4j driver instance for database operations.
+    limit : int | None, optional
+        The maximum number of results to return per search method. If None, a default limit will be applied.
+
+    Returns
+    -------
+    list[EntityNode]
+        A list of unique EntityNode objects that match the search criteria.
+
+    Notes
+    -----
+    This method performs the following steps:
+    1. Executes fulltext searches for each query.
+    2. Executes vector similarity searches for each embedding.
+    3. Combines and deduplicates the results from both search types.
+    4. Logs the performance metrics of the search operation.
+
+    The search results are deduplicated based on the node UUIDs to ensure
+    uniqueness in the returned list. The 'limit' parameter is applied to each
+    individual search method before deduplication. If not specified, a default
+    limit (defined in the individual search functions) will be used.
+    """
+
+    start = time()
+    relevant_nodes: list[EntityNode] = []
+    relevant_node_uuids = set()
+
+    results = await asyncio.gather(
+        *[entity_fulltext_search(q, driver, 2 * (limit or RELEVANT_SCHEMA_LIMIT)) for q in queries],
+        *[
+            entity_similarity_search(e, driver, 2 * (limit or RELEVANT_SCHEMA_LIMIT))
+            for e in embeddings
+        ],
+    )
+
+    for result in results:
+        for node in result:
+            if node.uuid in relevant_node_uuids:
+                continue
+
+            relevant_node_uuids.add(node.uuid)
+            relevant_nodes.append(node)
+
+    end = time()
+    logger.info(f'Found relevant nodes: {relevant_node_uuids} in {(end - start) * 1000} ms')
+    return relevant_nodes
+
+
+async def get_relevant_nodes(
+    nodes: list[EntityNode],
+    driver: AsyncDriver,
+) -> list[EntityNode]:
+    """
+    Retrieve relevant nodes based on the provided list of EntityNodes.
+
+    This method performs a hybrid search using both the names and embeddings
+    of the input nodes to find relevant nodes in the graph database.
+
+    Parameters
+    ----------
+    nodes : list[EntityNode]
+        A list of EntityNode objects to use as the basis for the search.
+    driver : AsyncDriver
+        The Neo4j driver instance for database operations.
+
+    Returns
+    -------
+    list[EntityNode]
+        A list of EntityNode objects that are deemed relevant based on the input nodes.
+
+    Notes
+    -----
+    This method uses the hybrid_node_search function to perform the search,
+    which combines fulltext search and vector similarity search.
+    It extracts the names and name embeddings (if available) from the input nodes
+    to use as search criteria.
+    """
+    relevant_nodes = await hybrid_node_search(
+        [node.name for node in nodes],
+        [node.name_embedding for node in nodes if node.name_embedding is not None],
+        driver,
+    )
+    return relevant_nodes
+
+
+async def get_relevant_edges(
+    edges: list[EntityEdge],
+    driver: AsyncDriver,
+) -> list[EntityEdge]:
+    start = time()
+    relevant_edges: list[EntityEdge] = []
+    relevant_edge_uuids = set()
+
+    results = await asyncio.gather(
+        *[
+            edge_similarity_search(edge.fact_embedding, driver)
+            for edge in edges
+            if edge.fact_embedding is not None
+        ],
+        *[edge_fulltext_search(edge.fact, driver) for edge in edges],
+    )
+
+    for result in results:
+        for edge in result:
+            if edge.uuid in relevant_edge_uuids:
+                continue
+
+            relevant_edge_uuids.add(edge.uuid)
+            relevant_edges.append(edge)
+
+    end = time()
+    logger.info(f'Found relevant edges: {relevant_edge_uuids} in {(end - start) * 1000} ms')
+
+    return relevant_edges
+
+
+# takes in a list of rankings of uuids
+def rrf(results: list[list[str]], rank_const=1) -> list[str]:
+    scores: dict[str, float] = defaultdict(float)
+    for result in results:
+        for i, uuid in enumerate(result):
+            scores[uuid] += 1 / (i + rank_const)
+
+    scored_uuids = [term for term in scores.items()]
+    scored_uuids.sort(reverse=True, key=lambda term: term[1])
+
+    sorted_uuids = [term[0] for term in scored_uuids]
+
+    return sorted_uuids
+
+
+async def node_distance_reranker(
+    driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
+) -> list[str]:
+    # use rrf as a preliminary ranker
+    sorted_uuids = rrf(results)
+    scores: dict[str, float] = {}
+
+    for uuid in sorted_uuids:
+        # Find shortest path to center node
+        records, _, _ = await driver.execute_query(
+            """
+            MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity)
+            MATCH p = SHORTEST 1 (center:Entity)-[:RELATES_TO]-+(n:Entity)
+            WHERE center.uuid = $center_uuid AND n.uuid IN [source.uuid, target.uuid]
+            RETURN min(length(p)) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid
+            """,
+            edge_uuid=uuid,
+            center_uuid=center_node_uuid,
+        )
+        distance = 0.01
+
+        for record in records:
+            if (
+                record['source_uuid'] == center_node_uuid
+                or record['target_uuid'] == center_node_uuid
+            ):
+                continue
+            distance = record['score']
+
+        if uuid in scores:
+            scores[uuid] = min(1 / distance, scores[uuid])
+        else:
+            scores[uuid] = 1 / distance
+
+    # rerank on shortest distance
+    sorted_uuids.sort(reverse=True, key=lambda cur_uuid: scores[cur_uuid])
+
+    return sorted_uuids
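
The `rrf` helper above (reciprocal rank fusion) is easiest to follow with concrete inputs. Below is a minimal worked example, assuming graphiti-core 0.1.0 is installed; the "uuid" strings are made-up placeholders.

```python
# Worked example of the rrf helper from search_utils.py above.
# Assumes graphiti-core 0.1.0 is installed; the "uuids" are placeholders.
from graphiti_core.search.search_utils import rrf

# Two rankings that disagree: e.g. fulltext prefers 'a', similarity prefers 'b'.
fulltext_ranking = ['a', 'b', 'c']
similarity_ranking = ['b', 'a', 'd']

# With the default rank_const=1, position i contributes 1 / (i + 1):
#   a: 1/1 + 1/2 = 1.5    b: 1/2 + 1/1 = 1.5
#   c: 1/3                d: 1/3
# Python's sort is stable, so ties keep first-seen order.
print(rrf([fulltext_ranking, similarity_ranking]))  # ['a', 'b', 'c', 'd']
```
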
--- /dev/null
+++ graphiti_core/utils/__init__.py
@@ -0,0 +1,15 @@
+from .maintenance import (
+    build_episodic_edges,
+    clear_data,
+    extract_edges,
+    extract_nodes,
+    retrieve_episodes,
+)
+
+__all__ = [
+    'extract_edges',
+    'build_episodic_edges',
+    'extract_nodes',
+    'clear_data',
+    'retrieve_episodes',
+]
--- /dev/null
+++ graphiti_core/utils/bulk_utils.py
@@ -0,0 +1,227 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import asyncio
+import typing
+from datetime import datetime
+
+from neo4j import AsyncDriver
+from numpy import dot
+from pydantic import BaseModel
+
+from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
+from graphiti_core.llm_client import LLMClient
+from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
+from graphiti_core.search.search_utils import get_relevant_edges, get_relevant_nodes
+from graphiti_core.utils import retrieve_episodes
+from graphiti_core.utils.maintenance.edge_operations import (
+    build_episodic_edges,
+    dedupe_edge_list,
+    dedupe_extracted_edges,
+    extract_edges,
+)
+from graphiti_core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
+from graphiti_core.utils.maintenance.node_operations import (
+    dedupe_extracted_nodes,
+    dedupe_node_list,
+    extract_nodes,
+)
+
+CHUNK_SIZE = 15
+
+
+class RawEpisode(BaseModel):
+    name: str
+    content: str
+    source_description: str
+    source: EpisodeType
+    reference_time: datetime
+
+
+async def retrieve_previous_episodes_bulk(
+    driver: AsyncDriver, episodes: list[EpisodicNode]
+) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
+    previous_episodes_list = await asyncio.gather(
+        *[
+            retrieve_episodes(driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN)
+            for episode in episodes
+        ]
+    )
+    episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]] = [
+        (episode, previous_episodes_list[i]) for i, episode in enumerate(episodes)
+    ]
+
+    return episode_tuples
+
+
+async def extract_nodes_and_edges_bulk(
+    llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
+) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
+    extracted_nodes_bulk = await asyncio.gather(
+        *[
+            extract_nodes(llm_client, episode, previous_episodes)
+            for episode, previous_episodes in episode_tuples
+        ]
+    )
+
+    episodes, previous_episodes_list = (
+        [episode[0] for episode in episode_tuples],
+        [episode[1] for episode in episode_tuples],
+    )
+
+    extracted_edges_bulk = await asyncio.gather(
+        *[
+            extract_edges(llm_client, episode, extracted_nodes_bulk[i], previous_episodes_list[i])
+            for i, episode in enumerate(episodes)
+        ]
+    )
+
+    episodic_edges: list[EpisodicEdge] = []
+    for i, episode in enumerate(episodes):
+        episodic_edges += build_episodic_edges(extracted_nodes_bulk[i], episode, episode.created_at)
+
+    nodes: list[EntityNode] = []
+    for extracted_nodes in extracted_nodes_bulk:
+        nodes += extracted_nodes
+
+    edges: list[EntityEdge] = []
+    for extracted_edges in extracted_edges_bulk:
+        edges += extracted_edges
+
+    return nodes, edges, episodic_edges
+
+
+async def dedupe_nodes_bulk(
+    driver: AsyncDriver,
+    llm_client: LLMClient,
+    extracted_nodes: list[EntityNode],
+) -> tuple[list[EntityNode], dict[str, str]]:
+    # Compress nodes
+    nodes, uuid_map = node_name_match(extracted_nodes)
+
+    compressed_nodes, compressed_map = await compress_nodes(llm_client, nodes, uuid_map)
+
+    existing_nodes = await get_relevant_nodes(compressed_nodes, driver)
+
+    nodes, partial_uuid_map, _ = await dedupe_extracted_nodes(
+        llm_client, compressed_nodes, existing_nodes
+    )
+
+    compressed_map.update(partial_uuid_map)
+
+    return nodes, compressed_map
+
+
+async def dedupe_edges_bulk(
+    driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
+) -> list[EntityEdge]:
+    # Compress edges
+    compressed_edges = await compress_edges(llm_client, extracted_edges)
+
+    existing_edges = await get_relevant_edges(compressed_edges, driver)
+
+    edges = await dedupe_extracted_edges(llm_client, compressed_edges, existing_edges)
+
+    return edges
+
+
+def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str, str]]:
+    uuid_map: dict[str, str] = {}
+    name_map: dict[str, EntityNode] = {}
+    for node in nodes:
+        if node.name in name_map:
+            uuid_map[node.uuid] = name_map[node.name].uuid
+            continue
+
+        name_map[node.name] = node
+
+    return [node for node in name_map.values()], uuid_map
+
+
+async def compress_nodes(
+    llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
+) -> tuple[list[EntityNode], dict[str, str]]:
+    if len(nodes) == 0:
+        return nodes, uuid_map
+
+    anchor = nodes[0]
+    nodes.sort(key=lambda node: dot(anchor.name_embedding or [], node.name_embedding or []))
+
+    node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
+
+    results = await asyncio.gather(*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks])
+
+    extended_map = dict(uuid_map)
+    compressed_nodes: list[EntityNode] = []
+    for node_chunk, uuid_map_chunk in results:
+        compressed_nodes += node_chunk
+        extended_map.update(uuid_map_chunk)
+
+    # Check if we have removed all duplicates
+    if len(compressed_nodes) == len(nodes):
+        compressed_uuid_map = compress_uuid_map(extended_map)
+        return compressed_nodes, compressed_uuid_map
+
+    return await compress_nodes(llm_client, compressed_nodes, extended_map)
+
+
+async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list[EntityEdge]:
+    if len(edges) == 0:
+        return edges
+
+    anchor = edges[0]
+    edges.sort(
+        key=lambda embedding: dot(anchor.fact_embedding or [], embedding.fact_embedding or [])
+    )
+
+    edge_chunks = [edges[i : i + CHUNK_SIZE] for i in range(0, len(edges), CHUNK_SIZE)]
+
+    results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks])
+
+    compressed_edges: list[EntityEdge] = []
+    for edge_chunk in results:
+        compressed_edges += edge_chunk
+
+    # Check if we have removed all duplicates
+    if len(compressed_edges) == len(edges):
+        return compressed_edges
+
+    return await compress_edges(llm_client, compressed_edges)
+
+
+def compress_uuid_map(uuid_map: dict[str, str]) -> dict[str, str]:
+    # make sure all uuid values aren't mapped to other uuids
+    compressed_map = {}
+    for key, uuid in uuid_map.items():
+        curr_value = uuid
+        while curr_value in uuid_map:
+            curr_value = uuid_map[curr_value]
+
+        compressed_map[key] = curr_value
+    return compressed_map
+
+
+E = typing.TypeVar('E', bound=Edge)
+
+
+def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
+    for edge in edges:
+        source_uuid = edge.source_node_uuid
+        target_uuid = edge.target_node_uuid
+        edge.source_node_uuid = uuid_map.get(source_uuid, source_uuid)
+        edge.target_node_uuid = uuid_map.get(target_uuid, target_uuid)
+
+    return edges
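
A quick illustration of `compress_uuid_map` from the hunk above: successive deduplication passes can produce chained mappings (c merged into b in one pass, b into a in another), and the helper flattens every key to its terminal uuid. A minimal sketch, assuming graphiti-core 0.1.0 is installed and using placeholder strings for uuids.

```python
# Flattening chained duplicate mappings with compress_uuid_map
# (from graphiti_core/utils/bulk_utils.py above). Placeholder uuids.
from graphiti_core.utils.bulk_utils import compress_uuid_map

chained = {'c': 'b', 'b': 'a'}  # c was merged into b, then b into a
print(compress_uuid_map(chained))  # {'c': 'a', 'b': 'a'}
```

Note that the while loop in `compress_uuid_map` walks each chain to its end, so the map is implicitly assumed to be acyclic; a cyclic mapping would never terminate.
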
--- /dev/null
+++ graphiti_core/utils/maintenance/__init__.py
@@ -0,0 +1,16 @@
+from .edge_operations import build_episodic_edges, extract_edges
+from .graph_data_operations import (
+    clear_data,
+    retrieve_episodes,
+)
+from .node_operations import extract_nodes
+from .temporal_operations import invalidate_edges
+
+__all__ = [
+    'extract_edges',
+    'build_episodic_edges',
+    'extract_nodes',
+    'clear_data',
+    'retrieve_episodes',
+    'invalidate_edges',
+]
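
Taken together, the bulk helpers in bulk_utils.py suggest an ingestion pipeline: retrieve context windows, extract with the LLM, deduplicate nodes, repoint edges at the surviving node uuids, then deduplicate edges. The sketch below wires those functions up under that reading; it is not taken from the package itself, and `driver`, `llm_client`, and `episodes` are assumed to be an already-constructed neo4j AsyncDriver, a graphiti LLMClient, and a list of EpisodicNode objects.

```python
# A hedged sketch of how the bulk helpers above might fit together.
# Not the package's own pipeline; signatures are taken from the diff.
from graphiti_core.utils.bulk_utils import (
    dedupe_edges_bulk,
    dedupe_nodes_bulk,
    extract_nodes_and_edges_bulk,
    resolve_edge_pointers,
    retrieve_previous_episodes_bulk,
)


async def ingest_bulk(driver, llm_client, episodes):
    # Pair each episode with its preceding context window.
    episode_tuples = await retrieve_previous_episodes_bulk(driver, episodes)

    # LLM extraction of entity nodes, entity edges, and episodic edges.
    nodes, edges, episodic_edges = await extract_nodes_and_edges_bulk(
        llm_client, episode_tuples
    )

    # Deduplicate nodes, then repoint edges at the surviving node uuids.
    nodes, uuid_map = await dedupe_nodes_bulk(driver, llm_client, nodes)
    edges = resolve_edge_pointers(edges, uuid_map)
    episodic_edges = resolve_edge_pointers(episodic_edges, uuid_map)

    # Deduplicate edges against themselves and the existing graph.
    edges = await dedupe_edges_bulk(driver, llm_client, edges)

    return nodes, edges, episodic_edges
```
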