graphiti-core 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

@@ -96,14 +96,18 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):


 async def edge_similarity_search(
-    search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
+    driver: AsyncDriver,
+    search_vector: list[float],
+    limit: int = RELEVANT_SCHEMA_LIMIT,
+    source_node_uuid: str = '*',
+    target_node_uuid: str = '*',
 ) -> list[EntityEdge]:
     # vector similarity search over embedded facts
     records, _, _ = await driver.execute_query(
         """
         CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
-        YIELD relationship AS r, score
-        MATCH (n)-[r:RELATES_TO]->(m)
+        YIELD relationship AS rel, score
+        MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
         RETURN
             r.uuid AS uuid,
             n.uuid AS source_node_uuid,
@@ -119,6 +123,8 @@ async def edge_similarity_search(
         ORDER BY score DESC
         """,
         search_vector=search_vector,
+        source_uuid=source_node_uuid,
+        target_uuid=target_node_uuid,
         limit=limit,
     )

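edge_similarity_search (and edge_fulltext_search below) now takes the AsyncDriver as its first argument and accepts optional source_node_uuid / target_node_uuid filters that default to '*'. A hypothetical call site illustrating the new parameter order; the embedding variable and UUID strings are placeholders, not values from the diff:

edges = await edge_similarity_search(
    driver,                         # AsyncDriver is now the first positional argument
    fact_embedding,                 # list[float] embedding of the fact to search for
    RELEVANT_SCHEMA_LIMIT,          # cap on results returned by the vector index
    source_node_uuid='source-entity-uuid',
    target_node_uuid='target-entity-uuid',
)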
@@ -214,7 +220,11 @@ async def entity_fulltext_search(


 async def edge_fulltext_search(
-    query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
+    driver: AsyncDriver,
+    query: str,
+    limit=RELEVANT_SCHEMA_LIMIT,
+    source_node_uuid: str = '*',
+    target_node_uuid: str = '*',
 ) -> list[EntityEdge]:
     # fulltext search over facts
     fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
@@ -222,8 +232,8 @@ async def edge_fulltext_search(
     records, _, _ = await driver.execute_query(
         """
         CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
-        YIELD relationship AS r, score
-        MATCH (n:Entity)-[r]->(m:Entity)
+        YIELD relationship AS rel, score
+        MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
         RETURN
             r.uuid AS uuid,
             n.uuid AS source_node_uuid,
@@ -239,6 +249,8 @@ async def edge_fulltext_search(
         ORDER BY score DESC LIMIT $limit
         """,
         query=fuzzy_query,
+        source_uuid=source_node_uuid,
+        target_uuid=target_node_uuid,
         limit=limit,
     )

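edge_fulltext_search builds its Lucene query by stripping punctuation from the input and appending '~' to request fuzzy matching. A quick worked example of that transformation on a placeholder query string:

import re

query = 'Paris is the capital of France!'
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
# fuzzy_query == 'Paris is the capital of France~'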
@@ -268,13 +280,13 @@ async def hybrid_node_search(
     queries: list[str],
     embeddings: list[list[float]],
     driver: AsyncDriver,
-    limit: int | None = None,
+    limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityNode]:
     """
     Perform a hybrid search for nodes using both text queries and embeddings.

     This method combines fulltext search and vector similarity search to find
-    relevant nodes in the graph database.
+    relevant nodes in the graph database. It uses an rrf reranker.

     Parameters
     ----------
@@ -307,27 +319,25 @@ async def hybrid_node_search(
     """

     start = time()
-    relevant_nodes: list[EntityNode] = []
-    relevant_node_uuids = set()

-    results = await asyncio.gather(
-        *[entity_fulltext_search(q, driver, 2 * (limit or RELEVANT_SCHEMA_LIMIT)) for q in queries],
-        *[
-            entity_similarity_search(e, driver, 2 * (limit or RELEVANT_SCHEMA_LIMIT))
-            for e in embeddings
-        ],
+    results: list[list[EntityNode]] = list(
+        await asyncio.gather(
+            *[entity_fulltext_search(q, driver, 2 * limit) for q in queries],
+            *[entity_similarity_search(e, driver, 2 * limit) for e in embeddings],
+        )
     )

-    for result in results:
-        for node in result:
-            if node.uuid in relevant_node_uuids:
-                continue
+    node_uuid_map: dict[str, EntityNode] = {
+        node.uuid: node for result in results for node in result
+    }
+    result_uuids = [[node.uuid for node in result] for result in results]

-            relevant_node_uuids.add(node.uuid)
-            relevant_nodes.append(node)
+    ranked_uuids = rrf(result_uuids)
+
+    relevant_nodes: list[EntityNode] = [node_uuid_map[uuid] for uuid in ranked_uuids]

     end = time()
-    logger.info(f'Found relevant nodes: {relevant_node_uuids} in {(end - start) * 1000} ms')
+    logger.info(f'Found relevant nodes: {ranked_uuids} in {(end - start) * 1000} ms')
     return relevant_nodes


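hybrid_node_search now merges the fulltext and similarity result lists with reciprocal rank fusion rather than simple de-duplication. The rrf helper itself does not appear in this diff; a minimal sketch of reciprocal rank fusion over ranked UUID lists, assuming the conventional 1 / (k + rank) scoring (the helper shipped in the package may differ):

from collections import defaultdict

def rrf(results: list[list[str]], rank_const: int = 1) -> list[str]:
    # Each result list contributes 1 / (rank_const + rank) for every uuid it ranks.
    scores: dict[str, float] = defaultdict(float)
    for result in results:
        for rank, uuid in enumerate(result):
            scores[uuid] += 1 / (rank + rank_const)
    # Return uuids ordered by combined score, best first.
    return sorted(scores, key=lambda uuid: scores[uuid], reverse=True)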
@@ -371,6 +381,9 @@ async def get_relevant_nodes(
 async def get_relevant_edges(
     edges: list[EntityEdge],
     driver: AsyncDriver,
+    limit: int = RELEVANT_SCHEMA_LIMIT,
+    source_node_uuid: str = '*',
+    target_node_uuid: str = '*',
 ) -> list[EntityEdge]:
     start = time()
     relevant_edges: list[EntityEdge] = []
@@ -378,11 +391,16 @@ async def get_relevant_edges(

     results = await asyncio.gather(
         *[
-            edge_similarity_search(edge.fact_embedding, driver)
+            edge_similarity_search(
+                driver, edge.fact_embedding, limit, source_node_uuid, target_node_uuid
+            )
             for edge in edges
             if edge.fact_embedding is not None
         ],
-        *[edge_fulltext_search(edge.fact, driver) for edge in edges],
+        *[
+            edge_fulltext_search(driver, edge.fact, limit, source_node_uuid, target_node_uuid)
+            for edge in edges
+        ],
     )

     for result in results:
@@ -426,7 +444,7 @@ async def node_distance_reranker(
     records, _, _ = await driver.execute_query(
         """
         MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity)
-        MATCH p = SHORTEST 1 (center:Entity)-[:RELATES_TO]-+(n:Entity)
+        MATCH p = SHORTEST 1 (center:Entity)-[:RELATES_TO*1..10]->(n:Entity)
         WHERE center.uuid = $center_uuid AND n.uuid IN [source.uuid, target.uuid]
         RETURN min(length(p)) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid
         """,
@@ -15,11 +15,13 @@ limitations under the License.
 """

 import asyncio
+import logging
 import typing
 from datetime import datetime
+from math import ceil

 from neo4j import AsyncDriver
-from numpy import dot
+from numpy import dot, sqrt
 from pydantic import BaseModel

 from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
@@ -39,8 +41,12 @@ from graphiti_core.utils.maintenance.node_operations import (
     dedupe_node_list,
     extract_nodes,
 )
+from graphiti_core.utils.maintenance.temporal_operations import extract_edge_dates
+from graphiti_core.utils.utils import chunk_edges_by_nodes

-CHUNK_SIZE = 15
+logger = logging.getLogger(__name__)
+
+CHUNK_SIZE = 10


 class RawEpisode(BaseModel):
@@ -114,27 +120,58 @@ async def dedupe_nodes_bulk(

     compressed_nodes, compressed_map = await compress_nodes(llm_client, nodes, uuid_map)

-    existing_nodes = await get_relevant_nodes(compressed_nodes, driver)
+    node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]

-    nodes, partial_uuid_map, _ = await dedupe_extracted_nodes(
-        llm_client, compressed_nodes, existing_nodes
+    existing_nodes_chunks: list[list[EntityNode]] = list(
+        await asyncio.gather(
+            *[get_relevant_nodes(node_chunk, driver) for node_chunk in node_chunks]
+        )
     )

-    compressed_map.update(partial_uuid_map)
+    results: list[tuple[list[EntityNode], dict[str, str]]] = list(
+        await asyncio.gather(
+            *[
+                dedupe_extracted_nodes(llm_client, node_chunk, existing_nodes_chunks[i])
+                for i, node_chunk in enumerate(node_chunks)
+            ]
+        )
+    )

-    return nodes, compressed_map
+    final_nodes: list[EntityNode] = []
+    for result in results:
+        final_nodes.extend(result[0])
+        partial_uuid_map = result[1]
+        compressed_map.update(partial_uuid_map)
+
+    return final_nodes, compressed_map


 async def dedupe_edges_bulk(
     driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
 ) -> list[EntityEdge]:
-    # Compress edges
+    # First compress edges
     compressed_edges = await compress_edges(llm_client, extracted_edges)

-    existing_edges = await get_relevant_edges(compressed_edges, driver)
+    edge_chunks = [
+        compressed_edges[i : i + CHUNK_SIZE] for i in range(0, len(compressed_edges), CHUNK_SIZE)
+    ]

-    edges = await dedupe_extracted_edges(llm_client, compressed_edges, existing_edges)
+    relevant_edges_chunks: list[list[EntityEdge]] = list(
+        await asyncio.gather(
+            *[get_relevant_edges(edge_chunk, driver) for edge_chunk in edge_chunks]
+        )
+    )
+
+    resolved_edge_chunks: list[list[EntityEdge]] = list(
+        await asyncio.gather(
+            *[
+                dedupe_extracted_edges(llm_client, edge_chunk, relevant_edges_chunks[i])
+                for i, edge_chunk in enumerate(edge_chunks)
+            ]
+        )
+    )

+    edges = [edge for edge_chunk in resolved_edge_chunks for edge in edge_chunk]
     return edges


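Both bulk dedupe paths now follow the same shape: slice the input into CHUNK_SIZE pieces, run one coroutine per slice with asyncio.gather, and flatten the per-chunk results. A standalone sketch of that pattern, with a hypothetical worker callable standing in for the LLM and database calls:

import asyncio
from typing import Awaitable, Callable, TypeVar

T = TypeVar('T')

async def process_in_chunks(
    items: list[T],
    worker: Callable[[list[T]], Awaitable[list[T]]],
    chunk_size: int = 10,
) -> list[T]:
    # Split the input into fixed-size chunks so each downstream call stays small.
    chunks = [items[i : i + chunk_size] for i in range(0, len(items), chunk_size)]
    # Fan out one coroutine per chunk, then flatten the per-chunk results.
    results = await asyncio.gather(*[worker(chunk) for chunk in chunks])
    return [item for chunk_result in results for item in chunk_result]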
@@ -154,13 +191,58 @@ def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str
 async def compress_nodes(
     llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
 ) -> tuple[list[EntityNode], dict[str, str]]:
+    # We want to first compress the nodes by deduplicating nodes across each of the episodes added in bulk
     if len(nodes) == 0:
         return nodes, uuid_map

-    anchor = nodes[0]
-    nodes.sort(key=lambda node: dot(anchor.name_embedding or [], node.name_embedding or []))
+    # Our approach involves us deduplicating chunks of nodes in parallel.
+    # We want n chunks of size n so that n ** 2 == len(nodes).
+    # We want chunk sizes to be at least 10 for optimizing LLM processing time
+    chunk_size = max(int(sqrt(len(nodes))), CHUNK_SIZE)

-    node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
+    # First calculate similarity scores between nodes
+    similarity_scores: list[tuple[int, int, float]] = [
+        (i, j, dot(n.name_embedding or [], m.name_embedding or []))
+        for i, n in enumerate(nodes)
+        for j, m in enumerate(nodes[:i])
+    ]
+
+    # We now sort by semantic similarity
+    similarity_scores.sort(key=lambda score_tuple: score_tuple[2])
+
+    # initialize our chunks based on chunk size
+    node_chunks: list[list[EntityNode]] = [[] for _ in range(ceil(len(nodes) / chunk_size))]
+
+    # Draft the most similar nodes into the same chunk
+    while len(similarity_scores) > 0:
+        i, j, _ = similarity_scores.pop()
+        # determine if any of the nodes have already been drafted into a chunk
+        n = nodes[i]
+        m = nodes[j]
+        # make sure the shortest chunks get preference
+        node_chunks.sort(reverse=True, key=lambda chunk: len(chunk))
+
+        n_chunk = max([i if n in chunk else -1 for i, chunk in enumerate(node_chunks)])
+        m_chunk = max([i if m in chunk else -1 for i, chunk in enumerate(node_chunks)])
+
+        # both nodes already in a chunk
+        if n_chunk > -1 and m_chunk > -1:
+            continue
+
+        # n has a chunk and that chunk is not full
+        elif n_chunk > -1 and len(node_chunks[n_chunk]) < chunk_size:
+            # put m in the same chunk as n
+            node_chunks[n_chunk].append(m)
+
+        # m has a chunk and that chunk is not full
+        elif m_chunk > -1 and len(node_chunks[m_chunk]) < chunk_size:
+            # put n in the same chunk as m
+            node_chunks[m_chunk].append(n)
+
+        # neither node has a chunk or the chunk is full
+        else:
+            # add both nodes to the shortest chunk
+            node_chunks[-1].extend([n, m])

     results = await asyncio.gather(*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks])

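The sizing logic above targets roughly sqrt(n) chunks of roughly sqrt(n) nodes each, with CHUNK_SIZE as a floor so each LLM call still sees a reasonably large batch. A quick worked example of how chunk_size and the chunk count fall out for two hypothetical node counts:

from math import ceil, sqrt

CHUNK_SIZE = 10

# 100 nodes: sqrt(100) == 10, so chunk_size = max(10, 10) = 10 and ceil(100 / 10) = 10 chunks.
# 25 nodes:  sqrt(25) == 5, so the CHUNK_SIZE floor wins: chunk_size = 10 and ceil(25 / 10) = 3 chunks.
for n in (100, 25):
    chunk_size = max(int(sqrt(n)), CHUNK_SIZE)
    print(n, chunk_size, ceil(n / chunk_size))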
@@ -181,13 +263,9 @@ async def compress_nodes(
 async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list[EntityEdge]:
     if len(edges) == 0:
         return edges
-
-    anchor = edges[0]
-    edges.sort(
-        key=lambda embedding: dot(anchor.fact_embedding or [], embedding.fact_embedding or [])
-    )
-
-    edge_chunks = [edges[i : i + CHUNK_SIZE] for i in range(0, len(edges), CHUNK_SIZE)]
+    # We only want to dedupe edges that are between the same pair of nodes
+    # We build a map of the edges based on their source and target nodes.
+    edge_chunks = chunk_edges_by_nodes(edges)

     results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks])

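chunk_edges_by_nodes is imported from graphiti_core.utils.utils and does not appear in this diff. Per the comment above it, it groups edges by the pair of nodes they connect so that only directly comparable facts are deduplicated together. A hedged sketch of what such a grouping could look like (the real helper's signature and behavior may differ):

from collections import defaultdict

from graphiti_core.edges import EntityEdge

def chunk_edges_by_nodes(edges: list[EntityEdge]) -> list[list[EntityEdge]]:
    # Bucket edges by their (source, target) node pair; each bucket becomes one dedupe chunk.
    buckets: dict[tuple[str, str], list[EntityEdge]] = defaultdict(list)
    for edge in edges:
        buckets[(edge.source_node_uuid, edge.target_node_uuid)].append(edge)
    return list(buckets.values())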
@@ -225,3 +303,43 @@ def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
         edge.target_node_uuid = uuid_map.get(target_uuid, target_uuid)

     return edges
+
+
+async def extract_edge_dates_bulk(
+    llm_client: LLMClient,
+    extracted_edges: list[EntityEdge],
+    episode_pairs: list[tuple[EpisodicNode, list[EpisodicNode]]],
+) -> list[EntityEdge]:
+    edges: list[EntityEdge] = []
+    # confirm that all of our edges have at least one episode
+    for edge in extracted_edges:
+        if edge.episodes is not None and len(edge.episodes) > 0:
+            edges.append(edge)
+
+    episode_uuid_map: dict[str, tuple[EpisodicNode, list[EpisodicNode]]] = {
+        episode.uuid: (episode, previous_episodes) for episode, previous_episodes in episode_pairs
+    }
+
+    results = await asyncio.gather(
+        *[
+            extract_edge_dates(
+                llm_client,
+                edge,
+                episode_uuid_map[edge.episodes[0]][0],  # type: ignore
+                episode_uuid_map[edge.episodes[0]][1],  # type: ignore
+            )
+            for edge in edges
+        ]
+    )
+
+    for i, result in enumerate(results):
+        valid_at = result[0]
+        invalid_at = result[1]
+        edge = edges[i]
+
+        edge.valid_at = valid_at
+        edge.invalid_at = invalid_at
+        if edge.invalid_at:
+            edge.expired_at = datetime.now()
+
+    return edges
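extract_edge_dates_bulk looks up each edge's first episode UUID in episode_pairs, so callers need to supply one (episode, previous_episodes) tuple for every episode that produced edges. A hypothetical call shape; the three-episode look-back window and the variable names are assumptions for illustration, not taken from the package:

# `episodes` is assumed to be the ordered list of EpisodicNode objects that were ingested
# and that produced `extracted_edges`.
episode_pairs = [
    (episode, episodes[max(0, i - 3) : i])  # each episode paired with a short window of prior ones
    for i, episode in enumerate(episodes)
]
dated_edges = await extract_edge_dates_bulk(llm_client, extracted_edges, episode_pairs)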
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+import asyncio
 import logging
 from datetime import datetime
 from time import time
@@ -70,7 +71,6 @@ async def extract_edges(
     }

     llm_response = await llm_client.generate_response(prompt_library.extract_edges.v2(context))
-    print(llm_response)
     edges_data = llm_response.get('edges', [])

     end = time()
@@ -110,8 +110,8 @@ async def dedupe_extracted_edges(
     existing_edges: list[EntityEdge],
 ) -> list[EntityEdge]:
     # Create edge map
-    edge_map = {}
-    for edge in extracted_edges:
+    edge_map: dict[str, EntityEdge] = {}
+    for edge in existing_edges:
         edge_map[edge.uuid] = edge

     # Prepare context for LLM
@@ -125,18 +125,85 @@ async def dedupe_extracted_edges(
     }

     llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v1(context))
-    unique_edge_data = llm_response.get('unique_facts', [])
-    logger.info(f'Extracted unique edges: {unique_edge_data}')
+    duplicate_data = llm_response.get('duplicates', [])
+    logger.info(f'Extracted unique edges: {duplicate_data}')
+
+    duplicate_uuid_map: dict[str, str] = {}
+    for duplicate in duplicate_data:
+        uuid_value = duplicate['duplicate_of']
+        duplicate_uuid_map[duplicate['uuid']] = uuid_value

     # Get full edge data
-    edges = []
-    for unique_edge in unique_edge_data:
-        edge = edge_map[unique_edge['uuid']]
-        edges.append(edge)
+    edges: list[EntityEdge] = []
+    for edge in extracted_edges:
+        if edge.uuid in duplicate_uuid_map:
+            existing_uuid = duplicate_uuid_map[edge.uuid]
+            existing_edge = edge_map[existing_uuid]
+            edges.append(existing_edge)
+        else:
+            edges.append(edge)

     return edges


+async def resolve_extracted_edges(
+    llm_client: LLMClient,
+    extracted_edges: list[EntityEdge],
+    existing_edges_lists: list[list[EntityEdge]],
+) -> list[EntityEdge]:
+    resolved_edges: list[EntityEdge] = list(
+        await asyncio.gather(
+            *[
+                resolve_extracted_edge(llm_client, extracted_edge, existing_edges)
+                for extracted_edge, existing_edges in zip(extracted_edges, existing_edges_lists)
+            ]
+        )
+    )
+
+    return resolved_edges
+
+
+async def resolve_extracted_edge(
+    llm_client: LLMClient, extracted_edge: EntityEdge, existing_edges: list[EntityEdge]
+) -> EntityEdge:
+    start = time()
+
+    # Prepare context for LLM
+    existing_edges_context = [
+        {'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in existing_edges
+    ]
+
+    extracted_edge_context = {
+        'uuid': extracted_edge.uuid,
+        'name': extracted_edge.name,
+        'fact': extracted_edge.fact,
+    }
+
+    context = {
+        'existing_edges': existing_edges_context,
+        'extracted_edges': extracted_edge_context,
+    }
+
+    llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v3(context))
+
+    is_duplicate: bool = llm_response.get('is_duplicate', False)
+    uuid: str | None = llm_response.get('uuid', None)
+
+    edge = extracted_edge
+    if is_duplicate:
+        for existing_edge in existing_edges:
+            if existing_edge.uuid != uuid:
+                continue
+            edge = existing_edge
+
+    end = time()
+    logger.info(
+        f'Resolved node: {extracted_edge.name} is {edge.name}, in {(end - start) * 1000} ms'
+    )
+
+    return edge
+
+
 async def dedupe_edge_list(
     llm_client: LLMClient,
     edges: list[EntityEdge],
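Both the rewritten dedupe path and the new resolve_extracted_edge key their behavior off the shape of the LLM response. Minimal sketches of the response structures the code above reads via .get(...); the field values are placeholders inferred from the code, not taken from the prompt definitions:

# dedupe_edges.v1: a 'duplicates' list mapping extracted edge UUIDs to existing edge UUIDs.
llm_response_v1 = {
    'duplicates': [
        {'uuid': 'extracted-edge-uuid', 'duplicate_of': 'existing-edge-uuid'},
    ]
}

# dedupe_edges.v3: a per-edge verdict consumed by resolve_extracted_edge.
llm_response_v3 = {
    'is_duplicate': True,
    'uuid': 'existing-edge-uuid',  # UUID of the existing edge the extracted edge duplicates
}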