graphiti-core 0.1.0-py3-none-any.whl → 0.2.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of graphiti-core has been flagged as possibly problematic.

@@ -1,11 +1,11 @@
 import asyncio
 import logging
 import re
-import typing
 from collections import defaultdict
 from time import time
+from typing import Any

-from neo4j import AsyncDriver
+from neo4j import AsyncDriver, Query

 from graphiti_core.edges import EntityEdge
 from graphiti_core.helpers import parse_db_date
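The new Query import supports the rewritten search functions below, which build their Cypher as neo4j Query objects instead of bare strings. As a minimal sketch (not package code, and assuming a recent 5.x Python driver where Query(text, metadata=None, timeout=None) is accepted by execute_query), a Query wraps the statement text and can carry per-query settings such as a timeout:

import asyncio

from neo4j import AsyncGraphDatabase, Query

async def ping(uri: str, auth: tuple[str, str]) -> None:
    async with AsyncGraphDatabase.driver(uri, auth=auth) as driver:
        q = Query('RETURN 1 AS x', timeout=5.0)  # timeout is optional
        # execute_query returns (records, summary, keys), as used throughout this diff
        records, _, _ = await driver.execute_query(q)
        print(records[0]['x'])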
@@ -66,12 +66,12 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):
             r.expired_at AS expired_at,
             r.valid_at AS valid_at,
             r.invalid_at AS invalid_at
-
+
         """,
         node_ids=node_ids,
     )

-    context: dict[str, typing.Any] = {}
+    context: dict[str, Any] = {}

     for record in records:
         n_uuid = record['source_node_uuid']
@@ -96,14 +96,17 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):


 async def edge_similarity_search(
-    search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
+    driver: AsyncDriver,
+    search_vector: list[float],
+    source_node_uuid: str | None,
+    target_node_uuid: str | None,
+    limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityEdge]:
     # vector similarity search over embedded facts
-    records, _, _ = await driver.execute_query(
-        """
+    query = Query("""
         CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
-        YIELD relationship AS r, score
-        MATCH (n)-[r:RELATES_TO]->(m)
+        YIELD relationship AS rel, score
+        MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
         RETURN
             r.uuid AS uuid,
             n.uuid AS source_node_uuid,
@@ -117,8 +120,71 @@ async def edge_similarity_search(
             r.valid_at AS valid_at,
             r.invalid_at AS invalid_at
         ORDER BY score DESC
-        """,
+        """)
+
+    if source_node_uuid is None and target_node_uuid is None:
+        query = Query("""
+        CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
+        YIELD relationship AS rel, score
+        MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
+        RETURN
+            r.uuid AS uuid,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.fact_embedding AS fact_embedding,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at
+        ORDER BY score DESC
+        """)
+    elif source_node_uuid is None:
+        query = Query("""
+        CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
+        YIELD relationship AS rel, score
+        MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
+        RETURN
+            r.uuid AS uuid,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.fact_embedding AS fact_embedding,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at
+        ORDER BY score DESC
+        """)
+    elif target_node_uuid is None:
+        query = Query("""
+        CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
+        YIELD relationship AS rel, score
+        MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
+        RETURN
+            r.uuid AS uuid,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.fact_embedding AS fact_embedding,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at
+        ORDER BY score DESC
+        """)
+
+    records, _, _ = await driver.execute_query(
+        query,
         search_vector=search_vector,
+        source_uuid=source_node_uuid,
+        target_uuid=target_node_uuid,
         limit=limit,
     )

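The four near-identical Query variants above differ only in their MATCH pattern, depending on which of the optional source/target UUID filters are provided. A hypothetical, more compact way to express the same branching (not the package's code; build_match is an illustrative helper) would compose just the MATCH clause and keep the shared CALL/RETURN/ORDER BY text in one place:

# Hypothetical helper: derive the MATCH pattern from the optional UUID filters.
def build_match(source_uuid: str | None, target_uuid: str | None) -> str:
    n = '(n:Entity {uuid: $source_uuid})' if source_uuid is not None else '(n:Entity)'
    m = '(m:Entity {uuid: $target_uuid})' if target_uuid is not None else '(m:Entity)'
    return f'MATCH {n}-[r {{uuid: rel.uuid}}]-{m}'

print(build_match(None, 'abc'))
# MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})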
@@ -145,7 +211,7 @@ async def edge_similarity_search(


 async def entity_similarity_search(
-    search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
+    search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
 ) -> list[EntityNode]:
     # vector similarity search over entity names
     records, _, _ = await driver.execute_query(
@@ -155,6 +221,7 @@ async def entity_similarity_search(
         RETURN
             n.uuid As uuid,
             n.name AS name,
+            n.name_embeddings AS name_embedding,
             n.created_at AS created_at,
             n.summary AS summary
         ORDER BY score DESC
@@ -169,6 +236,7 @@ async def entity_similarity_search(
             EntityNode(
                 uuid=record['uuid'],
                 name=record['name'],
+                name_embedding=record['name_embedding'],
                 labels=['Entity'],
                 created_at=record['created_at'].to_native(),
                 summary=record['summary'],
@@ -179,7 +247,7 @@ async def entity_similarity_search(


 async def entity_fulltext_search(
-    query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
+    query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
 ) -> list[EntityNode]:
     # BM25 search to get top nodes
     fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
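For reference, the fuzzy_query line above strips punctuation and appends Lucene's `~` fuzzy-match operator before the string reaches the fulltext index. A small standalone illustration:

import re

query = "Alice's e-mail, please?"
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
print(fuzzy_query)  # -> 'Alices email please~'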
@@ -187,8 +255,9 @@ async def entity_fulltext_search(
         """
         CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score
         RETURN
-            node.uuid As uuid,
+            node.uuid AS uuid,
             node.name AS name,
+            node.name_embeddings AS name_embedding,
             node.created_at AS created_at,
             node.summary AS summary
         ORDER BY score DESC
@@ -204,6 +273,7 @@ async def entity_fulltext_search(
             EntityNode(
                 uuid=record['uuid'],
                 name=record['name'],
+                name_embedding=record['name_embedding'],
                 labels=['Entity'],
                 created_at=record['created_at'].to_native(),
                 summary=record['summary'],
@@ -214,17 +284,18 @@ async def entity_fulltext_search(


 async def edge_fulltext_search(
-    query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
+    driver: AsyncDriver,
+    query: str,
+    source_node_uuid: str | None,
+    target_node_uuid: str | None,
+    limit=RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityEdge]:
     # fulltext search over facts
-    fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
-
-    records, _, _ = await driver.execute_query(
-        """
-        CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
-        YIELD relationship AS r, score
-        MATCH (n:Entity)-[r]->(m:Entity)
-        RETURN
+    cypher_query = Query("""
+        CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
+        YIELD relationship AS rel, score
+        MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
+        RETURN
             r.uuid AS uuid,
             n.uuid AS source_node_uuid,
             m.uuid AS target_node_uuid,
@@ -237,8 +308,73 @@ async def edge_fulltext_search(
             r.valid_at AS valid_at,
             r.invalid_at AS invalid_at
         ORDER BY score DESC LIMIT $limit
-        """,
+        """)
+
+    if source_node_uuid is None and target_node_uuid is None:
+        cypher_query = Query("""
+        CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
+        YIELD relationship AS rel, score
+        MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
+        RETURN
+            r.uuid AS uuid,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.fact_embedding AS fact_embedding,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at
+        ORDER BY score DESC LIMIT $limit
+        """)
+    elif source_node_uuid is None:
+        cypher_query = Query("""
+        CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
+        YIELD relationship AS rel, score
+        MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
+        RETURN
+            r.uuid AS uuid,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.fact_embedding AS fact_embedding,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at
+        ORDER BY score DESC LIMIT $limit
+        """)
+    elif target_node_uuid is None:
+        cypher_query = Query("""
+        CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
+        YIELD relationship AS rel, score
+        MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
+        RETURN
+            r.uuid AS uuid,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.fact_embedding AS fact_embedding,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at
+        ORDER BY score DESC LIMIT $limit
+        """)
+
+    fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
+
+    records, _, _ = await driver.execute_query(
+        cypher_query,
         query=fuzzy_query,
+        source_uuid=source_node_uuid,
+        target_uuid=target_node_uuid,
         limit=limit,
     )

@@ -265,16 +401,16 @@ async def edge_fulltext_search(


 async def hybrid_node_search(
-    queries: list[str],
-    embeddings: list[list[float]],
-    driver: AsyncDriver,
-    limit: int | None = None,
+    queries: list[str],
+    embeddings: list[list[float]],
+    driver: AsyncDriver,
+    limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityNode]:
     """
     Perform a hybrid search for nodes using both text queries and embeddings.

     This method combines fulltext search and vector similarity search to find
-    relevant nodes in the graph database.
+    relevant nodes in the graph database. It uses a rrf reranker.

     Parameters
     ----------
@@ -307,33 +443,31 @@ async def hybrid_node_search(
     """

     start = time()
-    relevant_nodes: list[EntityNode] = []
-    relevant_node_uuids = set()

-    results = await asyncio.gather(
-        *[entity_fulltext_search(q, driver, 2 * (limit or RELEVANT_SCHEMA_LIMIT)) for q in queries],
-        *[
-            entity_similarity_search(e, driver, 2 * (limit or RELEVANT_SCHEMA_LIMIT))
-            for e in embeddings
-        ],
+    results: list[list[EntityNode]] = list(
+        await asyncio.gather(
+            *[entity_fulltext_search(q, driver, 2 * limit) for q in queries],
+            *[entity_similarity_search(e, driver, 2 * limit) for e in embeddings],
+        )
     )

-    for result in results:
-        for node in result:
-            if node.uuid in relevant_node_uuids:
-                continue
+    node_uuid_map: dict[str, EntityNode] = {
+        node.uuid: node for result in results for node in result
+    }
+    result_uuids = [[node.uuid for node in result] for result in results]
+
+    ranked_uuids = rrf(result_uuids)

-            relevant_node_uuids.add(node.uuid)
-            relevant_nodes.append(node)
+    relevant_nodes: list[EntityNode] = [node_uuid_map[uuid] for uuid in ranked_uuids]

     end = time()
-    logger.info(f'Found relevant nodes: {relevant_node_uuids} in {(end - start) * 1000} ms')
+    logger.info(f'Found relevant nodes: {ranked_uuids} in {(end - start) * 1000} ms')
     return relevant_nodes


 async def get_relevant_nodes(
-    nodes: list[EntityNode],
-    driver: AsyncDriver,
+    nodes: list[EntityNode],
+    driver: AsyncDriver,
 ) -> list[EntityNode]:
     """
     Retrieve relevant nodes based on the provided list of EntityNodes.
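hybrid_node_search now delegates deduplication and ordering to the rrf reranker (reciprocal rank fusion). Its body is not shown in this hunk, but a minimal sketch consistent with the rrf(results: list[list[str]], rank_const=1) signature that appears further down would be:

from collections import defaultdict

# Minimal reciprocal-rank-fusion sketch (the package's actual body may differ):
# each result list votes for its uuids with weight 1 / (rank + rank_const),
# and uuids are returned ordered by their fused score.
def rrf(results: list[list[str]], rank_const=1) -> list[str]:
    scores: dict[str, float] = defaultdict(float)
    for result in results:
        for rank, uuid in enumerate(result):
            scores[uuid] += 1 / (rank + rank_const)
    return [uuid for uuid, _ in sorted(scores.items(), key=lambda kv: kv[1], reverse=True)]

Because each uuid appears exactly once in the fused ranking, the old manual relevant_node_uuids deduplication set becomes unnecessary.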
@@ -369,8 +503,11 @@ async def get_relevant_nodes(


 async def get_relevant_edges(
-    edges: list[EntityEdge],
-    driver: AsyncDriver,
+    driver: AsyncDriver,
+    edges: list[EntityEdge],
+    source_node_uuid: str | None,
+    target_node_uuid: str | None,
+    limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityEdge]:
     start = time()
     relevant_edges: list[EntityEdge] = []
@@ -378,11 +515,16 @@ async def get_relevant_edges(

     results = await asyncio.gather(
         *[
-            edge_similarity_search(edge.fact_embedding, driver)
+            edge_similarity_search(
+                driver, edge.fact_embedding, source_node_uuid, target_node_uuid, limit
+            )
             for edge in edges
             if edge.fact_embedding is not None
         ],
-        *[edge_fulltext_search(edge.fact, driver) for edge in edges],
+        *[
+            edge_fulltext_search(driver, edge.fact, source_node_uuid, target_node_uuid, limit)
+            for edge in edges
+        ],
     )

     for result in results:
@@ -415,18 +557,18 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:


 async def node_distance_reranker(
-    driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
+    driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
 ) -> list[str]:
     # use rrf as a preliminary ranker
     sorted_uuids = rrf(results)
     scores: dict[str, float] = {}

     for uuid in sorted_uuids:
-        # Find shortest path to center node
+        # Find the shortest path to center node
        records, _, _ = await driver.execute_query(
            """
            MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity)
-           MATCH p = SHORTEST 1 (center:Entity)-[:RELATES_TO]-+(n:Entity)
+           MATCH p = SHORTEST 1 (center:Entity)-[:RELATES_TO*1..10]->(n:Entity)
            WHERE center.uuid = $center_uuid AND n.uuid IN [source.uuid, target.uuid]
            RETURN min(length(p)) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid
            """,
@@ -437,8 +579,8 @@ async def node_distance_reranker(

        for record in records:
            if (
-               record['source_uuid'] == center_node_uuid
-               or record['target_uuid'] == center_node_uuid
+               record['source_uuid'] == center_node_uuid
+               or record['target_uuid'] == center_node_uuid
            ):
                continue
            distance = record['score']
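The reranker skips records where one endpoint is the center node itself and keeps the shortest-path length as the edge's distance. The rest of the function falls outside these hunks, but a hedged sketch of the final ordering step, assuming scores maps edge uuid to distance, is:

# Sketch only: order edge uuids by graph distance to the center node,
# leaving edges with no recorded distance at the end.
def rank_by_distance(sorted_uuids: list[str], scores: dict[str, float]) -> list[str]:
    return sorted(sorted_uuids, key=lambda uuid: scores.get(uuid, float('inf')))

The hunks that follow are from a second module in the wheel, the bulk-ingestion utilities that define dedupe_nodes_bulk and related helpers, rather than the search module above.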
@@ -15,11 +15,13 @@ limitations under the License.
 """

 import asyncio
+import logging
 import typing
 from datetime import datetime
+from math import ceil

 from neo4j import AsyncDriver
-from numpy import dot
+from numpy import dot, sqrt
 from pydantic import BaseModel

 from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
@@ -39,8 +41,12 @@ from graphiti_core.utils.maintenance.node_operations import (
     dedupe_node_list,
     extract_nodes,
 )
+from graphiti_core.utils.maintenance.temporal_operations import extract_edge_dates
+from graphiti_core.utils.utils import chunk_edges_by_nodes

-CHUNK_SIZE = 15
+logger = logging.getLogger(__name__)
+
+CHUNK_SIZE = 10


 class RawEpisode(BaseModel):
@@ -114,27 +120,58 @@ async def dedupe_nodes_bulk(

     compressed_nodes, compressed_map = await compress_nodes(llm_client, nodes, uuid_map)

-    existing_nodes = await get_relevant_nodes(compressed_nodes, driver)
+    node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]

-    nodes, partial_uuid_map, _ = await dedupe_extracted_nodes(
-        llm_client, compressed_nodes, existing_nodes
+    existing_nodes_chunks: list[list[EntityNode]] = list(
+        await asyncio.gather(
+            *[get_relevant_nodes(node_chunk, driver) for node_chunk in node_chunks]
+        )
     )

-    compressed_map.update(partial_uuid_map)
+    results: list[tuple[list[EntityNode], dict[str, str]]] = list(
+        await asyncio.gather(
+            *[
+                dedupe_extracted_nodes(llm_client, node_chunk, existing_nodes_chunks[i])
+                for i, node_chunk in enumerate(node_chunks)
+            ]
+        )
+    )

-    return nodes, compressed_map
+    final_nodes: list[EntityNode] = []
+    for result in results:
+        final_nodes.extend(result[0])
+        partial_uuid_map = result[1]
+        compressed_map.update(partial_uuid_map)
+
+    return final_nodes, compressed_map


 async def dedupe_edges_bulk(
     driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
 ) -> list[EntityEdge]:
-    # Compress edges
+    # First compress edges
     compressed_edges = await compress_edges(llm_client, extracted_edges)

-    existing_edges = await get_relevant_edges(compressed_edges, driver)
+    edge_chunks = [
+        compressed_edges[i : i + CHUNK_SIZE] for i in range(0, len(compressed_edges), CHUNK_SIZE)
+    ]

-    edges = await dedupe_extracted_edges(llm_client, compressed_edges, existing_edges)
+    relevant_edges_chunks: list[list[EntityEdge]] = list(
+        await asyncio.gather(
+            *[get_relevant_edges(driver, edge_chunk, None, None) for edge_chunk in edge_chunks]
+        )
+    )
+
+    resolved_edge_chunks: list[list[EntityEdge]] = list(
+        await asyncio.gather(
+            *[
+                dedupe_extracted_edges(llm_client, edge_chunk, relevant_edges_chunks[i])
+                for i, edge_chunk in enumerate(edge_chunks)
+            ]
+        )
+    )

+    edges = [edge for edge_chunk in resolved_edge_chunks for edge in edge_chunk]
     return edges

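Both bulk functions now follow the same pattern: slice the input into CHUNK_SIZE batches and fan the per-chunk work out with asyncio.gather so the LLM and database calls run concurrently. A self-contained toy version of the pattern (hypothetical names, not package code):

import asyncio

CHUNK_SIZE = 10

async def process_chunk(chunk: list[str]) -> list[str]:
    await asyncio.sleep(0)  # stand-in for an LLM or database round trip
    return [item.upper() for item in chunk]

async def process_all(items: list[str]) -> list[str]:
    chunks = [items[i : i + CHUNK_SIZE] for i in range(0, len(items), CHUNK_SIZE)]
    # one task per chunk; gather preserves chunk order
    results = await asyncio.gather(*[process_chunk(chunk) for chunk in chunks])
    return [item for chunk_result in results for item in chunk_result]

print(asyncio.run(process_all([f'item-{i}' for i in range(25)]))[:3])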
@@ -154,13 +191,58 @@ def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str, str]]:
 async def compress_nodes(
     llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
 ) -> tuple[list[EntityNode], dict[str, str]]:
+    # We want to first compress the nodes by deduplicating nodes across each of the episodes added in bulk
     if len(nodes) == 0:
         return nodes, uuid_map

-    anchor = nodes[0]
-    nodes.sort(key=lambda node: dot(anchor.name_embedding or [], node.name_embedding or []))
+    # Our approach involves us deduplicating chunks of nodes in parallel.
+    # We want n chunks of size n so that n ** 2 == len(nodes).
+    # We want chunk sizes to be at least 10 for optimizing LLM processing time
+    chunk_size = max(int(sqrt(len(nodes))), CHUNK_SIZE)

-    node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
+    # First calculate similarity scores between nodes
+    similarity_scores: list[tuple[int, int, float]] = [
+        (i, j, dot(n.name_embedding or [], m.name_embedding or []))
+        for i, n in enumerate(nodes)
+        for j, m in enumerate(nodes[:i])
+    ]
+
+    # We now sort by semantic similarity
+    similarity_scores.sort(key=lambda score_tuple: score_tuple[2])
+
+    # initialize our chunks based on chunk size
+    node_chunks: list[list[EntityNode]] = [[] for _ in range(ceil(len(nodes) / chunk_size))]
+
+    # Draft the most similar nodes into the same chunk
+    while len(similarity_scores) > 0:
+        i, j, _ = similarity_scores.pop()
+        # determine if any of the nodes have already been drafted into a chunk
+        n = nodes[i]
+        m = nodes[j]
+        # make sure the shortest chunks get preference
+        node_chunks.sort(reverse=True, key=lambda chunk: len(chunk))
+
+        n_chunk = max([i if n in chunk else -1 for i, chunk in enumerate(node_chunks)])
+        m_chunk = max([i if m in chunk else -1 for i, chunk in enumerate(node_chunks)])
+
+        # both nodes already in a chunk
+        if n_chunk > -1 and m_chunk > -1:
+            continue
+
+        # n has a chunk and that chunk is not full
+        elif n_chunk > -1 and len(node_chunks[n_chunk]) < chunk_size:
+            # put m in the same chunk as n
+            node_chunks[n_chunk].append(m)
+
+        # m has a chunk and that chunk is not full
+        elif m_chunk > -1 and len(node_chunks[m_chunk]) < chunk_size:
+            # put n in the same chunk as m
+            node_chunks[m_chunk].append(n)
+
+        # neither node has a chunk or the chunk is full
+        else:
+            # add both nodes to the shortest chunk
+            node_chunks[-1].extend([n, m])

     results = await asyncio.gather(*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks])

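The drafting loop above greedily pulls the most similar pair off the sorted score list and tries to place both nodes in the same chunk. A standalone toy version (integers standing in for EntityNodes; all names hypothetical) that mirrors the same placement rules:

from math import ceil

# Toy mirror of the drafting rules above; sims holds (i, j, similarity) triples.
def draft_chunks(items: list[int], sims: list[tuple[int, int, float]], chunk_size: int) -> list[list[int]]:
    sims = sorted(sims, key=lambda t: t[2])  # ascending, so pop() yields the most similar pair
    chunks: list[list[int]] = [[] for _ in range(ceil(len(items) / chunk_size))]
    while sims:
        i, j, _ = sims.pop()
        n, m = items[i], items[j]
        chunks.sort(reverse=True, key=len)  # longest first, so chunks[-1] is the shortest
        n_at = max(k if n in c else -1 for k, c in enumerate(chunks))
        m_at = max(k if m in c else -1 for k, c in enumerate(chunks))
        if n_at > -1 and m_at > -1:
            continue  # both already placed
        elif n_at > -1 and len(chunks[n_at]) < chunk_size:
            chunks[n_at].append(m)  # m joins n's chunk
        elif m_at > -1 and len(chunks[m_at]) < chunk_size:
            chunks[m_at].append(n)  # n joins m's chunk
        else:
            chunks[-1].extend([n, m])  # start the pair in the shortest chunk
    return chunks

Note that sorting node_chunks in reverse by length puts the longest chunks first, so the shortest chunk is always at the end when a brand-new pair is placed.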
@@ -181,13 +263,9 @@ async def compress_nodes(
 async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list[EntityEdge]:
     if len(edges) == 0:
         return edges
-
-    anchor = edges[0]
-    edges.sort(
-        key=lambda embedding: dot(anchor.fact_embedding or [], embedding.fact_embedding or [])
-    )
-
-    edge_chunks = [edges[i : i + CHUNK_SIZE] for i in range(0, len(edges), CHUNK_SIZE)]
+    # We only want to dedupe edges that are between the same pair of nodes
+    # We build a map of the edges based on their source and target nodes.
+    edge_chunks = chunk_edges_by_nodes(edges)

     results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks])

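chunk_edges_by_nodes is imported from graphiti_core.utils.utils and its body is not part of this diff. Going by the comment above, a plausible reconstruction (hypothetical; the real helper may differ) groups edges by their source/target pair:

from collections import defaultdict

from graphiti_core.edges import EntityEdge

# Hypothetical sketch of chunk_edges_by_nodes: bucket edges by node pair so that
# only edges between the same two nodes are deduplicated against each other.
def chunk_edges_by_nodes(edges: list[EntityEdge]) -> list[list[EntityEdge]]:
    buckets: dict[tuple[str, str], list[EntityEdge]] = defaultdict(list)
    for edge in edges:
        buckets[(edge.source_node_uuid, edge.target_node_uuid)].append(edge)
    return list(buckets.values())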
@@ -225,3 +303,43 @@ def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
         edge.target_node_uuid = uuid_map.get(target_uuid, target_uuid)

     return edges
+
+
+async def extract_edge_dates_bulk(
+    llm_client: LLMClient,
+    extracted_edges: list[EntityEdge],
+    episode_pairs: list[tuple[EpisodicNode, list[EpisodicNode]]],
+) -> list[EntityEdge]:
+    edges: list[EntityEdge] = []
+    # confirm that all of our edges have at least one episode
+    for edge in extracted_edges:
+        if edge.episodes is not None and len(edge.episodes) > 0:
+            edges.append(edge)
+
+    episode_uuid_map: dict[str, tuple[EpisodicNode, list[EpisodicNode]]] = {
+        episode.uuid: (episode, previous_episodes) for episode, previous_episodes in episode_pairs
+    }
+
+    results = await asyncio.gather(
+        *[
+            extract_edge_dates(
+                llm_client,
+                edge,
+                episode_uuid_map[edge.episodes[0]][0],  # type: ignore
+                episode_uuid_map[edge.episodes[0]][1],  # type: ignore
+            )
+            for edge in edges
+        ]
+    )
+
+    for i, result in enumerate(results):
+        valid_at = result[0]
+        invalid_at = result[1]
+        edge = edges[i]
+
+        edge.valid_at = valid_at
+        edge.invalid_at = invalid_at
+        if edge.invalid_at:
+            edge.expired_at = datetime.now()
+
+    return edges
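A hedged example of how the new extract_edge_dates_bulk might be called; episodes, histories, llm_client, and extracted_edges are placeholders for values the ingestion pipeline already holds, not names from this diff:

# Inside an async ingestion routine: pair each episode with the episodes that
# preceded it, then resolve valid_at/invalid_at for every extracted edge.
episode_pairs = [(episode, previous) for episode, previous in zip(episodes, histories)]
dated_edges = await extract_edge_dates_bulk(llm_client, extracted_edges, episode_pairs)

As the loop above shows, any edge that comes back with an invalid_at date is also stamped with expired_at = datetime.now().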