graphiti-core 0.11.6rc7__py3-none-any.whl → 0.12.0rc1__py3-none-any.whl

This diff shows the published contents of two package versions as they appear in their public registry and is provided for informational purposes only.
@@ -50,6 +50,9 @@ from graphiti_core.search.search_utils import (
     edge_similarity_search,
     episode_fulltext_search,
     episode_mentions_reranker,
+    get_embeddings_for_communities,
+    get_embeddings_for_edges,
+    get_embeddings_for_nodes,
     maximal_marginal_relevance,
     node_bfs_search,
     node_distance_reranker,
@@ -209,26 +212,17 @@ async def edge_search(
 
         reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
     elif config.reranker == EdgeReranker.mmr:
-        await semaphore_gather(
-            *[edge.load_fact_embedding(driver) for result in search_results for edge in result]
+        search_result_uuids_and_vectors = await get_embeddings_for_edges(
+            driver, list(edge_uuid_map.values())
         )
-        search_result_uuids_and_vectors = [
-            (edge.uuid, edge.fact_embedding if edge.fact_embedding is not None else [0.0] * 1024)
-            for result in search_results
-            for edge in result
-        ]
         reranked_uuids = maximal_marginal_relevance(
             query_vector,
             search_result_uuids_and_vectors,
             config.mmr_lambda,
+            reranker_min_score,
         )
     elif config.reranker == EdgeReranker.cross_encoder:
-        search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
-
-        rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
-        rrf_edges = [edge_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
-
-        fact_to_uuid_map = {edge.fact: edge.uuid for edge in rrf_edges}
+        fact_to_uuid_map = {edge.fact: edge.uuid for edge in list(edge_uuid_map.values())[:limit]}
         reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
         reranked_uuids = [
             fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score
@@ -311,30 +305,23 @@ async def node_search(
     if config.reranker == NodeReranker.rrf:
         reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
     elif config.reranker == NodeReranker.mmr:
-        await semaphore_gather(
-            *[node.load_name_embedding(driver) for result in search_results for node in result]
+        search_result_uuids_and_vectors = await get_embeddings_for_nodes(
+            driver, list(node_uuid_map.values())
         )
-        search_result_uuids_and_vectors = [
-            (node.uuid, node.name_embedding if node.name_embedding is not None else [0.0] * 1024)
-            for result in search_results
-            for node in result
-        ]
+
         reranked_uuids = maximal_marginal_relevance(
             query_vector,
             search_result_uuids_and_vectors,
             config.mmr_lambda,
+            reranker_min_score,
         )
     elif config.reranker == NodeReranker.cross_encoder:
-        # use rrf as a preliminary reranker
-        rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
-        rrf_results = [node_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
+        name_to_uuid_map = {node.name: node.uuid for node in list(node_uuid_map.values())}
 
-        summary_to_uuid_map = {node.summary: node.uuid for node in rrf_results}
-
-        reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
+        reranked_node_names = await cross_encoder.rank(query, list(name_to_uuid_map.keys()))
         reranked_uuids = [
-            summary_to_uuid_map[fact]
-            for fact, score in reranked_summaries
+            name_to_uuid_map[name]
+            for name, score in reranked_node_names
             if score >= reranker_min_score
         ]
     elif config.reranker == NodeReranker.episode_mentions:
@@ -437,25 +424,12 @@ async def community_search(
     if config.reranker == CommunityReranker.rrf:
         reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
     elif config.reranker == CommunityReranker.mmr:
-        await semaphore_gather(
-            *[
-                community.load_name_embedding(driver)
-                for result in search_results
-                for community in result
-            ]
+        search_result_uuids_and_vectors = await get_embeddings_for_communities(
+            driver, list(community_uuid_map.values())
         )
-        search_result_uuids_and_vectors = [
-            (
-                community.uuid,
-                community.name_embedding if community.name_embedding is not None else [0.0] * 1024,
-            )
-            for result in search_results
-            for community in result
-        ]
+
         reranked_uuids = maximal_marginal_relevance(
-            query_vector,
-            search_result_uuids_and_vectors,
-            config.mmr_lambda,
+            query_vector, search_result_uuids_and_vectors, config.mmr_lambda, reranker_min_score
         )
     elif config.reranker == CommunityReranker.cross_encoder:
         name_to_uuid_map = {node.name: node.uuid for result in search_results for node in result}
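
Note: across edge_search, node_search, and community_search the MMR branch now fetches every candidate embedding in one read query and passes a {uuid: embedding} mapping, plus the reranker_min_score cutoff, to maximal_marginal_relevance, replacing the per-object load_*_embedding round trips. A minimal sketch of the new call shape for the edge case, using only names from the hunks above:

    candidates = await get_embeddings_for_edges(driver, list(edge_uuid_map.values()))
    reranked_uuids = maximal_marginal_relevance(
        query_vector, candidates, config.mmr_lambda, reranker_min_score
    )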
@@ -21,6 +21,7 @@ from typing import Any
 
 import numpy as np
 from neo4j import AsyncDriver, Query
+from numpy._typing import NDArray
 from typing_extensions import LiteralString
 
 from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
@@ -173,7 +174,8 @@ async def edge_fulltext_search(
             r.episodes AS episodes,
             r.expired_at AS expired_at,
             r.valid_at AS valid_at,
-            r.invalid_at AS invalid_at
+            r.invalid_at AS invalid_at,
+            properties(r) AS attributes
         ORDER BY score DESC LIMIT $limit
         """
     )
@@ -242,7 +244,8 @@ async def edge_similarity_search(
             r.episodes AS episodes,
             r.expired_at AS expired_at,
             r.valid_at AS valid_at,
-            r.invalid_at AS invalid_at
+            r.invalid_at AS invalid_at,
+            properties(r) AS attributes
         ORDER BY score DESC
         LIMIT $limit
         """
@@ -300,7 +303,8 @@ async def edge_bfs_search(
             r.episodes AS episodes,
             r.expired_at AS expired_at,
             r.valid_at AS valid_at,
-            r.invalid_at AS invalid_at
+            r.invalid_at AS invalid_at,
+            properties(r) AS attributes
         LIMIT $limit
         """
     )
@@ -336,10 +340,10 @@ async def node_fulltext_search(
 
     query = (
         """
-    CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
-    YIELD node AS n, score
-    WHERE n:Entity
-    """
+        CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
+        YIELD node AS n, score
+        WHERE n:Entity
+        """
         + filter_query
         + ENTITY_NODE_RETURN
         + """
@@ -770,7 +774,8 @@ async def get_relevant_edges(
                 episodes: e.episodes,
                 expired_at: e.expired_at,
                 valid_at: e.valid_at,
-                invalid_at: e.invalid_at
+                invalid_at: e.invalid_at,
+                attributes: properties(e)
             })[..$limit] AS matches
         """
     )
@@ -836,7 +841,8 @@ async def get_edge_invalidation_candidates(
                 episodes: e.episodes,
                 expired_at: e.expired_at,
                 valid_at: e.valid_at,
-                invalid_at: e.invalid_at
+                invalid_at: e.invalid_at,
+                attributes: properties(e)
             })[..$limit] AS matches
         """
     )
@@ -899,6 +905,7 @@ async def node_distance_reranker(
         node_uuids=filtered_uuids,
         center_uuid=center_node_uuid,
         database_=DEFAULT_DATABASE,
+        routing_='r',
     )
 
     for result in path_results:
@@ -939,6 +946,7 @@ async def episode_mentions_reranker(
         query,
         node_uuids=sorted_uuids,
         database_=DEFAULT_DATABASE,
+        routing_='r',
     )
 
     for result in results:
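
Note: the added routing_='r' argument marks these reranker queries as reads so a clustered Neo4j deployment can route them to a reader instead of the leader; 'r' is the string form of neo4j.RoutingControl.READ. A self-contained sketch, assuming a reachable Neo4j instance at the given URI:

    from neo4j import AsyncGraphDatabase, RoutingControl

    async def count_entities(uri: str, auth: tuple[str, str]) -> int:
        async with AsyncGraphDatabase.driver(uri, auth=auth) as driver:
            # routing_=RoutingControl.READ has the same effect as routing_='r'
            records, _, _ = await driver.execute_query(
                'MATCH (n:Entity) RETURN count(n) AS c',
                routing_=RoutingControl.READ,
            )
            return records[0]['c']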
@@ -952,15 +960,116 @@ async def episode_mentions_reranker(
 
 def maximal_marginal_relevance(
     query_vector: list[float],
-    candidates: list[tuple[str, list[float]]],
+    candidates: dict[str, list[float]],
     mmr_lambda: float = DEFAULT_MMR_LAMBDA,
-):
-    candidates_with_mmr: list[tuple[str, float]] = []
-    for candidate in candidates:
-        max_sim = max([np.dot(normalize_l2(candidate[1]), normalize_l2(c[1])) for c in candidates])
-        mmr = mmr_lambda * np.dot(candidate[1], query_vector) - (1 - mmr_lambda) * max_sim
-        candidates_with_mmr.append((candidate[0], mmr))
+    min_score: float = -2.0,
+) -> list[str]:
+    start = time()
+    query_array = np.array(query_vector)
+    candidate_arrays: dict[str, NDArray] = {}
+    for uuid, embedding in candidates.items():
+        candidate_arrays[uuid] = normalize_l2(embedding)
+
+    uuids: list[str] = list(candidate_arrays.keys())
+
+    similarity_matrix = np.zeros((len(uuids), len(uuids)))
+
+    for i, uuid_1 in enumerate(uuids):
+        for j, uuid_2 in enumerate(uuids[:i]):
+            u = candidate_arrays[uuid_1]
+            v = candidate_arrays[uuid_2]
+            similarity = np.dot(u, v)
+
+            similarity_matrix[i, j] = similarity
+            similarity_matrix[j, i] = similarity
+
+    mmr_scores: dict[str, float] = {}
+    for i, uuid in enumerate(uuids):
+        max_sim = np.max(similarity_matrix[i, :])
+        mmr = mmr_lambda * np.dot(query_array, candidate_arrays[uuid]) + (mmr_lambda - 1) * max_sim
+        mmr_scores[uuid] = mmr
+
+    uuids.sort(reverse=True, key=lambda c: mmr_scores[c])
+
+    end = time()
+    logger.debug(f'Completed MMR reranking in {(end - start) * 1000} ms')
+
+    return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score]
+
+
+async def get_embeddings_for_nodes(
+    driver: AsyncDriver, nodes: list[EntityNode]
+) -> dict[str, list[float]]:
+    query: LiteralString = """MATCH (n:Entity)
+        WHERE n.uuid IN $node_uuids
+        RETURN DISTINCT
+            n.uuid AS uuid,
+            n.name_embedding AS name_embedding
+        """
+
+    results, _, _ = await driver.execute_query(
+        query, node_uuids=[node.uuid for node in nodes], database_=DEFAULT_DATABASE, routing_='r'
+    )
+
+    embeddings_dict: dict[str, list[float]] = {}
+    for result in results:
+        uuid: str = result.get('uuid')
+        embedding: list[float] = result.get('name_embedding')
+        if uuid is not None and embedding is not None:
+            embeddings_dict[uuid] = embedding
 
-    candidates_with_mmr.sort(reverse=True, key=lambda c: c[1])
+    return embeddings_dict
+
+
+async def get_embeddings_for_communities(
+    driver: AsyncDriver, communities: list[CommunityNode]
+) -> dict[str, list[float]]:
+    query: LiteralString = """MATCH (c:Community)
+        WHERE c.uuid IN $community_uuids
+        RETURN DISTINCT
+            c.uuid AS uuid,
+            c.name_embedding AS name_embedding
+        """
+
+    results, _, _ = await driver.execute_query(
+        query,
+        community_uuids=[community.uuid for community in communities],
+        database_=DEFAULT_DATABASE,
+        routing_='r',
+    )
+
+    embeddings_dict: dict[str, list[float]] = {}
+    for result in results:
+        uuid: str = result.get('uuid')
+        embedding: list[float] = result.get('name_embedding')
+        if uuid is not None and embedding is not None:
+            embeddings_dict[uuid] = embedding
+
+    return embeddings_dict
+
+
+async def get_embeddings_for_edges(
+    driver: AsyncDriver, edges: list[EntityEdge]
+) -> dict[str, list[float]]:
+    query: LiteralString = """MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
+        WHERE e.uuid IN $edge_uuids
+        RETURN DISTINCT
+            e.uuid AS uuid,
+            e.fact_embedding AS fact_embedding
+        """
+
+    results, _, _ = await driver.execute_query(
+        query,
+        edge_uuids=[edge.uuid for edge in edges],
+        database_=DEFAULT_DATABASE,
+        routing_='r',
+    )
+
+    embeddings_dict: dict[str, list[float]] = {}
+    for result in results:
+        uuid: str = result.get('uuid')
+        embedding: list[float] = result.get('fact_embedding')
+        if uuid is not None and embedding is not None:
+            embeddings_dict[uuid] = embedding
 
-    return list(set([candidate[0] for candidate in candidates_with_mmr]))
+    return embeddings_dict
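
Note: the rewritten maximal_marginal_relevance takes its candidates as a uuid-keyed dict, precomputes the pairwise similarity matrix over L2-normalized embeddings, scores each candidate as mmr_lambda * dot(query, candidate) + (mmr_lambda - 1) * max_sim, and returns uuids sorted by score and filtered by min_score. For a unit-norm query both terms are bounded by 1 in magnitude, so the default min_score of -2.0 filters nothing. A toy call, assuming 3-d unit embeddings:

    query_vector = [1.0, 0.0, 0.0]
    candidates = {
        'uuid-a': [1.0, 0.0, 0.0],  # redundant with the query
        'uuid-b': [0.0, 1.0, 0.0],  # orthogonal, adds diversity
    }
    ranked = maximal_marginal_relevance(query_vector, candidates, mmr_lambda=0.5)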
@@ -137,16 +137,34 @@ async def add_nodes_and_edges_bulk_tx(
         entity_data['labels'] = list(set(node.labels + ['Entity']))
         nodes.append(entity_data)
 
+    edges: list[dict[str, Any]] = []
     for edge in entity_edges:
         if edge.fact_embedding is None:
             await edge.generate_embedding(embedder)
+        edge_data: dict[str, Any] = {
+            'uuid': edge.uuid,
+            'source_node_uuid': edge.source_node_uuid,
+            'target_node_uuid': edge.target_node_uuid,
+            'name': edge.name,
+            'fact': edge.fact,
+            'fact_embedding': edge.fact_embedding,
+            'group_id': edge.group_id,
+            'episodes': edge.episodes,
+            'created_at': edge.created_at,
+            'expired_at': edge.expired_at,
+            'valid_at': edge.valid_at,
+            'invalid_at': edge.invalid_at,
+        }
+
+        edge_data.update(edge.attributes or {})
+        edges.append(edge_data)
 
     await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
     await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
     await tx.run(
         EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
     )
-    await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[edge.model_dump() for edge in entity_edges])
+    await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=edges)
 
 
 async def extract_nodes_and_edges_bulk(
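
Note: entity edges are now serialized by hand instead of via edge.model_dump() so that any custom entries in edge.attributes are flattened into top-level relationship properties, which is what the new properties(r)/properties(e) projections in the search hunks read back out. A toy illustration of the flattening step, with hypothetical attribute values:

    edge_data = {'uuid': 'edge-1', 'fact': 'Alice works at Acme'}
    edge_data.update({'role': 'engineer', 'since': 2021} or {})
    # edge_data now carries the custom fields as flat keys:
    # {'uuid': 'edge-1', 'fact': 'Alice works at Acme', 'role': 'engineer', 'since': 2021}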
@@ -18,6 +18,8 @@ import logging
 from datetime import datetime
 from time import time
 
+from pydantic import BaseModel
+
 from graphiti_core.edges import (
     CommunityEdge,
     EntityEdge,
@@ -35,9 +37,6 @@ from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
 from graphiti_core.search.search_filters import SearchFilters
 from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
 from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
-from graphiti_core.utils.maintenance.temporal_operations import (
-    get_edge_contradictions,
-)
 
 logger = logging.getLogger(__name__)
 
@@ -86,6 +85,7 @@ async def extract_edges(
     nodes: list[EntityNode],
     previous_episodes: list[EpisodicNode],
     group_id: str = '',
+    edge_types: dict[str, BaseModel] | None = None,
 ) -> list[EntityEdge]:
     start = time()
 
@@ -94,12 +94,25 @@ async def extract_edges(
 
     node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
 
+    edge_types_context = (
+        [
+            {
+                'fact_type_name': type_name,
+                'fact_type_description': type_model.__doc__,
+            }
+            for type_name, type_model in edge_types.items()
+        ]
+        if edge_types is not None
+        else []
+    )
+
     # Prepare context for LLM
     context = {
         'episode_content': episode.content,
         'nodes': [node.name for node in nodes],
         'previous_episodes': [ep.content for ep in previous_episodes],
         'reference_time': episode.valid_at,
+        'edge_types': edge_types_context,
         'custom_prompt': '',
     }
 
@@ -236,6 +249,9 @@ async def resolve_extracted_edges(
     clients: GraphitiClients,
     extracted_edges: list[EntityEdge],
     episode: EpisodicNode,
+    entities: list[EntityNode],
+    edge_types: dict[str, BaseModel],
+    edge_type_map: dict[tuple[str, str], list[str]],
 ) -> tuple[list[EntityEdge], list[EntityEdge]]:
     driver = clients.driver
     llm_client = clients.llm_client
@@ -245,7 +261,7 @@ async def resolve_extracted_edges(
 
     search_results: tuple[list[list[EntityEdge]], list[list[EntityEdge]]] = await semaphore_gather(
         get_relevant_edges(driver, extracted_edges, SearchFilters()),
-        get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters()),
+        get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters(), 0.2),
     )
 
     related_edges_lists, edge_invalidation_candidates = search_results
@@ -254,15 +270,50 @@ async def resolve_extracted_edges(
         f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
     )
 
+    # Build entity hash table
+    uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}
+
+    # Determine which edge types are relevant for each edge
+    edge_types_lst: list[dict[str, BaseModel]] = []
+    for extracted_edge in extracted_edges:
+        source_node_labels = uuid_entity_map[extracted_edge.source_node_uuid].labels
+        target_node_labels = uuid_entity_map[extracted_edge.target_node_uuid].labels
+        label_tuples = [
+            (source_label, target_label)
+            for source_label in source_node_labels
+            for target_label in target_node_labels
+        ]
+
+        extracted_edge_types = {}
+        for label_tuple in label_tuples:
+            type_names = edge_type_map.get(label_tuple, [])
+            for type_name in type_names:
+                type_model = edge_types.get(type_name)
+                if type_model is None:
+                    continue
+
+                extracted_edge_types[type_name] = type_model
+
+        edge_types_lst.append(extracted_edge_types)
+
     # resolve edges with related edges in the graph and find invalidation candidates
     results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
         await semaphore_gather(
             *[
                 resolve_extracted_edge(
-                    llm_client, extracted_edge, related_edges, existing_edges, episode
+                    llm_client,
+                    extracted_edge,
+                    related_edges,
+                    existing_edges,
+                    episode,
+                    extracted_edge_types,
                 )
-                for extracted_edge, related_edges, existing_edges in zip(
-                    extracted_edges, related_edges_lists, edge_invalidation_candidates, strict=True
+                for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
+                    extracted_edges,
+                    related_edges_lists,
+                    edge_invalidation_candidates,
+                    edge_types_lst,
+                    strict=True,
                 )
             ]
         )
@@ -326,10 +377,86 @@ async def resolve_extracted_edge(
     related_edges: list[EntityEdge],
     existing_edges: list[EntityEdge],
     episode: EpisodicNode,
+    edge_types: dict[str, BaseModel] | None = None,
 ) -> tuple[EntityEdge, list[EntityEdge]]:
-    resolved_edge, invalidation_candidates = await semaphore_gather(
-        dedupe_extracted_edge(llm_client, extracted_edge, related_edges, episode),
-        get_edge_contradictions(llm_client, extracted_edge, existing_edges),
+    if len(related_edges) == 0 and len(existing_edges) == 0:
+        return extracted_edge, []
+
+    start = time()
+
+    # Prepare context for LLM
+    related_edges_context = [
+        {'id': edge.uuid, 'fact': edge.fact} for i, edge in enumerate(related_edges)
+    ]
+
+    invalidation_edge_candidates_context = [
+        {'id': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
+    ]
+
+    edge_types_context = (
+        [
+            {
+                'fact_type_id': i,
+                'fact_type_name': type_name,
+                'fact_type_description': type_model.__doc__,
+            }
+            for i, (type_name, type_model) in enumerate(edge_types.items())
+        ]
+        if edge_types is not None
+        else []
+    )
+
+    context = {
+        'existing_edges': related_edges_context,
+        'new_edge': extracted_edge.fact,
+        'edge_invalidation_candidates': invalidation_edge_candidates_context,
+        'edge_types': edge_types_context,
+    }
+
+    llm_response = await llm_client.generate_response(
+        prompt_library.dedupe_edges.resolve_edge(context),
+        response_model=EdgeDuplicate,
+        model_size=ModelSize.small,
+    )
+
+    duplicate_fact_id: int = llm_response.get('duplicate_fact_id', -1)
+
+    resolved_edge = (
+        related_edges[duplicate_fact_id]
+        if 0 <= duplicate_fact_id < len(related_edges)
+        else extracted_edge
+    )
+
+    if duplicate_fact_id >= 0 and episode is not None:
+        resolved_edge.episodes.append(episode.uuid)
+
+    contradicted_facts: list[int] = llm_response.get('contradicted_facts', [])
+
+    invalidation_candidates: list[EntityEdge] = [existing_edges[i] for i in contradicted_facts]
+
+    fact_type: str = str(llm_response.get('fact_type'))
+    if fact_type.upper() != 'DEFAULT' and edge_types is not None:
+        resolved_edge.name = fact_type
+
+        edge_attributes_context = {
+            'message': episode.content,
+            'reference_time': episode.valid_at,
+            'fact': resolved_edge.fact,
+        }
+
+        edge_model = edge_types.get(fact_type)
+
+        edge_attributes_response = await llm_client.generate_response(
+            prompt_library.extract_edges.extract_attributes(edge_attributes_context),
+            response_model=edge_model,  # type: ignore
+            model_size=ModelSize.small,
+        )
+
+        resolved_edge.attributes = edge_attributes_response
+
+    end = time()
+    logger.debug(
+        f'Resolved Edge: {extracted_edge.name} is {resolved_edge.name}, in {(end - start) * 1000} ms'
     )
 
     now = utc_now()
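
Note: resolve_extracted_edge previously issued two parallel LLM calls (dedupe_extracted_edge and get_edge_contradictions); it now makes a single resolve_edge call that returns the duplicate fact id, contradicted facts, and a fact type, and only runs a second small-model call to fill typed attributes when a custom fact type matches. Custom edge types are plain Pydantic models whose docstring becomes the fact_type_description in the prompt context. A minimal sketch of declaring one; the names Employment and role are hypothetical, not from the diff:

    from pydantic import BaseModel

    class Employment(BaseModel):
        """One entity is or was employed by the other."""  # __doc__ -> fact_type_description

        role: str | None = None  # extracted into resolved_edge.attributes

    edge_types = {'Employment': Employment}
    # edge_type_map keys are (source_label, target_label) tuples
    edge_type_map = {('Entity', 'Entity'): ['Employment']}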
@@ -29,7 +29,7 @@ from graphiti_core.llm_client import LLMClient
 from graphiti_core.llm_client.config import ModelSize
 from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
 from graphiti_core.prompts import prompt_library
-from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
+from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
 from graphiti_core.prompts.extract_nodes import (
     ExtractedEntities,
     ExtractedEntity,
@@ -243,28 +243,65 @@ async def resolve_extracted_nodes(
 
     existing_nodes_lists: list[list[EntityNode]] = [result.nodes for result in search_results]
 
-    resolved_nodes: list[EntityNode] = await semaphore_gather(
-        *[
-            resolve_extracted_node(
-                llm_client,
-                extracted_node,
-                existing_nodes,
-                episode,
-                previous_episodes,
-                entity_types.get(
-                    next((item for item in extracted_node.labels if item != 'Entity'), '')
-                )
-                if entity_types is not None
-                else None,
-            )
-            for extracted_node, existing_nodes in zip(
-                extracted_nodes, existing_nodes_lists, strict=True
-            )
-        ]
+    entity_types_dict: dict[str, BaseModel] = entity_types if entity_types is not None else {}
+
+    # Prepare context for LLM
+    extracted_nodes_context = [
+        {
+            'id': i,
+            'name': node.name,
+            'entity_type': node.labels,
+            'entity_type_description': entity_types_dict.get(
+                next((item for item in node.labels if item != 'Entity'), '')
+            ).__doc__
+            or 'Default Entity Type',
+            'duplication_candidates': [
+                {
+                    **{
+                        'idx': j,
+                        'name': candidate.name,
+                        'entity_types': candidate.labels,
+                    },
+                    **candidate.attributes,
+                }
+                for j, candidate in enumerate(existing_nodes_lists[i])
+            ],
+        }
+        for i, node in enumerate(extracted_nodes)
+    ]
+
+    context = {
+        'extracted_nodes': extracted_nodes_context,
+        'episode_content': episode.content if episode is not None else '',
+        'previous_episodes': [ep.content for ep in previous_episodes]
+        if previous_episodes is not None
+        else [],
+    }
+
+    llm_response = await llm_client.generate_response(
+        prompt_library.dedupe_nodes.nodes(context),
+        response_model=NodeResolutions,
     )
 
+    node_resolutions: list = llm_response.get('entity_resolutions', [])
+
+    resolved_nodes: list[EntityNode] = []
     uuid_map: dict[str, str] = {}
-    for extracted_node, resolved_node in zip(extracted_nodes, resolved_nodes, strict=True):
+    for resolution in node_resolutions:
+        resolution_id = resolution.get('id', -1)
+        duplicate_idx = resolution.get('duplicate_idx', -1)
+
+        extracted_node = extracted_nodes[resolution_id]
+
+        resolved_node = (
+            existing_nodes_lists[resolution_id][duplicate_idx]
+            if 0 <= duplicate_idx < len(existing_nodes_lists[resolution_id])
+            else extracted_node
+        )
+
+        resolved_node.name = resolution.get('name')
+
+        resolved_nodes.append(resolved_node)
         uuid_map[extracted_node.uuid] = resolved_node.uuid
 
     logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}')
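
Note: node resolution likewise collapses the per-node resolve_extracted_node calls into one batched dedupe_nodes.nodes prompt whose NodeResolutions response is indexed back into extracted_nodes and existing_nodes_lists. Judging from the .get() accesses above, the parsed response is expected to look roughly like this (illustrative values):

    llm_response = {
        'entity_resolutions': [
            {'id': 0, 'duplicate_idx': 2, 'name': 'Alice Smith'},  # dedupe onto candidate 2
            {'id': 1, 'duplicate_idx': -1, 'name': 'Acme Corp'},   # -1: keep the extracted node
        ]
    }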
@@ -410,6 +447,7 @@ async def extract_attributes_from_node(
     llm_response = await llm_client.generate_response(
         prompt_library.extract_nodes.extract_attributes(summary_context),
         response_model=entity_attributes_model,
+        model_size=ModelSize.small,
     )
 
     node.summary = llm_response.get('summary', node.summary)
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: graphiti-core
-Version: 0.11.6rc7
+Version: 0.12.0rc1
 Summary: A temporal graph building library
 License: Apache-2.0
 Author: Paul Paliychuk