graphiti-core 0.11.4__py3-none-any.whl → 0.11.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

@@ -21,6 +21,7 @@ from typing import Any
 
 import numpy as np
 from neo4j import AsyncDriver, Query
+from numpy._typing import NDArray
 from typing_extensions import LiteralString
 
 from graphiti_core.edges import EntityEdge, get_entity_edge_from_record

@@ -101,7 +102,6 @@ async def get_mentioned_nodes(
             n.uuid As uuid,
             n.group_id AS group_id,
             n.name AS name,
-            n.name_embedding AS name_embedding,
             n.created_at AS created_at,
             n.summary AS summary,
             labels(n) AS labels,
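
This is the first of several hunks that stop returning stored embedding vectors with every record; companion hunks further down add get_embeddings_for_nodes, get_embeddings_for_communities, and get_embeddings_for_edges so vectors are fetched only when a caller actually needs them. A minimal sketch of the on-demand pattern, assuming graphiti-core 0.11.6 and a reachable Neo4j instance; the import path is inferred from the surrounding search-utility hunks and should be verified against the installed package:

    # Sketch: fetch name embeddings separately instead of with every node record.
    import asyncio

    from neo4j import AsyncGraphDatabase

    # Import path assumed from context, not confirmed by this diff.
    from graphiti_core.search.search_utils import get_embeddings_for_nodes


    async def main() -> None:
        driver = AsyncGraphDatabase.driver('bolt://localhost:7687', auth=('neo4j', 'password'))
        nodes: list = []  # EntityNode objects from a prior search would go here
        embeddings = await get_embeddings_for_nodes(driver, nodes)  # {uuid: vector}
        print(len(embeddings))
        await driver.close()


    asyncio.run(main())
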
@@ -128,7 +128,6 @@ async def get_communities_by_nodes(
             c.uuid As uuid,
             c.group_id AS group_id,
             c.name AS name,
-            c.name_embedding AS name_embedding
             c.created_at AS created_at,
             c.summary AS summary
         """,

@@ -172,7 +171,6 @@ async def edge_fulltext_search(
             r.created_at AS created_at,
             r.name AS name,
             r.fact AS fact,
-            r.fact_embedding AS fact_embedding,
             r.episodes AS episodes,
             r.expired_at AS expired_at,
             r.valid_at AS valid_at,

@@ -242,7 +240,6 @@ async def edge_similarity_search(
             r.created_at AS created_at,
             r.name AS name,
             r.fact AS fact,
-            r.fact_embedding AS fact_embedding,
             r.episodes AS episodes,
             r.expired_at AS expired_at,
             r.valid_at AS valid_at,

@@ -301,7 +298,6 @@ async def edge_bfs_search(
             r.created_at AS created_at,
             r.name AS name,
             r.fact AS fact,
-            r.fact_embedding AS fact_embedding,
             r.episodes AS episodes,
             r.expired_at AS expired_at,
             r.valid_at AS valid_at,

@@ -341,10 +337,10 @@ async def node_fulltext_search(
 
     query = (
         """
-        CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
-        YIELD node AS n, score
-        WHERE n:Entity
-        """
+        CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
+        YIELD node AS n, score
+        WHERE n:Entity
+        """
         + filter_query
         + ENTITY_NODE_RETURN
         + """

@@ -510,7 +506,6 @@ async def community_fulltext_search(
             comm.uuid AS uuid,
             comm.group_id AS group_id,
             comm.name AS name,
-            comm.name_embedding AS name_embedding,
             comm.created_at AS created_at,
             comm.summary AS summary
         ORDER BY score DESC

@@ -555,7 +550,6 @@ async def community_similarity_search(
             comm.uuid As uuid,
             comm.group_id AS group_id,
             comm.name AS name,
-            comm.name_embedding AS name_embedding,
             comm.created_at AS created_at,
             comm.summary AS summary
         ORDER BY score DESC

@@ -906,6 +900,7 @@ async def node_distance_reranker(
         node_uuids=filtered_uuids,
         center_uuid=center_node_uuid,
         database_=DEFAULT_DATABASE,
+        routing_='r',
     )
 
     for result in path_results:

@@ -946,6 +941,7 @@ async def episode_mentions_reranker(
         query,
         node_uuids=sorted_uuids,
         database_=DEFAULT_DATABASE,
+        routing_='r',
     )
 
     for result in results:
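
Both reranker hunks above add routing_='r', which marks these queries as read-only so the Neo4j driver can route them to readers in a clustered deployment. The same driver feature in isolation, with placeholder URI and credentials:

    # Sketch: read routing with the neo4j async driver; 'r' is shorthand
    # for neo4j.RoutingControl.READ.
    import asyncio

    from neo4j import AsyncGraphDatabase


    async def main() -> None:
        driver = AsyncGraphDatabase.driver('neo4j://localhost:7687', auth=('neo4j', 'password'))
        records, _, _ = await driver.execute_query(
            'MATCH (n) RETURN count(n) AS n',
            database_='neo4j',
            routing_='r',  # route to a reader when the cluster has one
        )
        print(records[0]['n'])
        await driver.close()


    asyncio.run(main())
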
@@ -959,15 +955,116 @@ async def episode_mentions_reranker(
 
 def maximal_marginal_relevance(
     query_vector: list[float],
-    candidates: list[tuple[str, list[float]]],
+    candidates: dict[str, list[float]],
     mmr_lambda: float = DEFAULT_MMR_LAMBDA,
-):
-    candidates_with_mmr: list[tuple[str, float]] = []
-    for candidate in candidates:
-        max_sim = max([np.dot(normalize_l2(candidate[1]), normalize_l2(c[1])) for c in candidates])
-        mmr = mmr_lambda * np.dot(candidate[1], query_vector) - (1 - mmr_lambda) * max_sim
-        candidates_with_mmr.append((candidate[0], mmr))
+    min_score: float = -2.0,
+) -> list[str]:
+    start = time()
+    query_array = np.array(query_vector)
+    candidate_arrays: dict[str, NDArray] = {}
+    for uuid, embedding in candidates.items():
+        candidate_arrays[uuid] = normalize_l2(embedding)
+
+    uuids: list[str] = list(candidate_arrays.keys())
+
+    similarity_matrix = np.zeros((len(uuids), len(uuids)))
+
+    for i, uuid_1 in enumerate(uuids):
+        for j, uuid_2 in enumerate(uuids[:i]):
+            u = candidate_arrays[uuid_1]
+            v = candidate_arrays[uuid_2]
+            similarity = np.dot(u, v)
+
+            similarity_matrix[i, j] = similarity
+            similarity_matrix[j, i] = similarity
+
+    mmr_scores: dict[str, float] = {}
+    for i, uuid in enumerate(uuids):
+        max_sim = np.max(similarity_matrix[i, :])
+        mmr = mmr_lambda * np.dot(query_array, candidate_arrays[uuid]) + (mmr_lambda - 1) * max_sim
+        mmr_scores[uuid] = mmr
+
+    uuids.sort(reverse=True, key=lambda c: mmr_scores[c])
+
+    end = time()
+    logger.debug(f'Completed MMR reranking in {(end - start) * 1000} ms')
+
+    return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score]
+
+
+async def get_embeddings_for_nodes(
+    driver: AsyncDriver, nodes: list[EntityNode]
+) -> dict[str, list[float]]:
+    query: LiteralString = """MATCH (n:Entity)
+        WHERE n.uuid IN $node_uuids
+        RETURN DISTINCT
+            n.uuid AS uuid,
+            n.name_embedding AS name_embedding
+        """
+
+    results, _, _ = await driver.execute_query(
+        query, node_uuids=[node.uuid for node in nodes], database_=DEFAULT_DATABASE, routing_='r'
+    )
+
+    embeddings_dict: dict[str, list[float]] = {}
+    for result in results:
+        uuid: str = result.get('uuid')
+        embedding: list[float] = result.get('name_embedding')
+        if uuid is not None and embedding is not None:
+            embeddings_dict[uuid] = embedding
 
-    candidates_with_mmr.sort(reverse=True, key=lambda c: c[1])
+    return embeddings_dict
+
+
+async def get_embeddings_for_communities(
+    driver: AsyncDriver, communities: list[CommunityNode]
+) -> dict[str, list[float]]:
+    query: LiteralString = """MATCH (c:Community)
+        WHERE c.uuid IN $community_uuids
+        RETURN DISTINCT
+            c.uuid AS uuid,
+            c.name_embedding AS name_embedding
+        """
+
+    results, _, _ = await driver.execute_query(
+        query,
+        community_uuids=[community.uuid for community in communities],
+        database_=DEFAULT_DATABASE,
+        routing_='r',
+    )
+
+    embeddings_dict: dict[str, list[float]] = {}
+    for result in results:
+        uuid: str = result.get('uuid')
+        embedding: list[float] = result.get('name_embedding')
+        if uuid is not None and embedding is not None:
+            embeddings_dict[uuid] = embedding
+
+    return embeddings_dict
+
+
+async def get_embeddings_for_edges(
+    driver: AsyncDriver, edges: list[EntityEdge]
+) -> dict[str, list[float]]:
+    query: LiteralString = """MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
+        WHERE e.uuid IN $edge_uuids
+        RETURN DISTINCT
+            e.uuid AS uuid,
+            e.fact_embedding AS fact_embedding
+        """
+
+    results, _, _ = await driver.execute_query(
+        query,
+        edge_uuids=[edge.uuid for edge in edges],
+        database_=DEFAULT_DATABASE,
+        routing_='r',
+    )
+
+    embeddings_dict: dict[str, list[float]] = {}
+    for result in results:
+        uuid: str = result.get('uuid')
+        embedding: list[float] = result.get('fact_embedding')
+        if uuid is not None and embedding is not None:
+            embeddings_dict[uuid] = embedding
 
-    return list(set([candidate[0] for candidate in candidates_with_mmr]))
+    return embeddings_dict
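
The rewritten maximal_marginal_relevance scores every candidate once, as mmr_lambda * sim(query, candidate) + (mmr_lambda - 1) * max_sim, where max_sim is the candidate's highest similarity to any other candidate; it then returns UUIDs sorted by score and filtered by min_score, instead of the previous unordered set. A self-contained toy version of the same scoring, with a local normalize_l2 (all names here are local to the sketch):

    # Toy MMR scoring: relevance to the query, penalized by redundancy with
    # the most similar other candidate (floored at 0, matching the zero
    # diagonal of the similarity matrix above).
    import numpy as np


    def normalize_l2(v: list[float]) -> np.ndarray:
        arr = np.array(v, dtype=float)
        norm = np.linalg.norm(arr)
        return arr / norm if norm > 0 else arr


    def mmr_rank(query: list[float], candidates: dict[str, list[float]], lam: float = 0.5) -> list[str]:
        q = np.array(query, dtype=float)
        vecs = {uuid: normalize_l2(emb) for uuid, emb in candidates.items()}
        uuids = list(vecs)
        scores: dict[str, float] = {}
        for uuid in uuids:
            sims = [float(np.dot(vecs[uuid], vecs[other])) for other in uuids if other != uuid]
            max_sim = max(sims + [0.0])
            scores[uuid] = lam * float(np.dot(q, vecs[uuid])) + (lam - 1) * max_sim
        return sorted(uuids, key=scores.__getitem__, reverse=True)


    # 'a' and 'b' are near-duplicates, so the redundancy penalty separates them.
    print(mmr_rank([1.0, 0.0], {'a': [1.0, 0.0], 'b': [0.9, 0.1], 'c': [0.0, 1.0]}))
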
@@ -26,6 +26,7 @@ from pydantic import BaseModel
 from typing_extensions import Any
 
 from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
+from graphiti_core.embedder import EmbedderClient
 from graphiti_core.graphiti_types import GraphitiClients
 from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
 from graphiti_core.llm_client import LLMClient

@@ -95,10 +96,16 @@ async def add_nodes_and_edges_bulk(
     episodic_edges: list[EpisodicEdge],
     entity_nodes: list[EntityNode],
     entity_edges: list[EntityEdge],
+    embedder: EmbedderClient,
 ):
     async with driver.session(database=DEFAULT_DATABASE) as session:
         await session.execute_write(
-            add_nodes_and_edges_bulk_tx, episodic_nodes, episodic_edges, entity_nodes, entity_edges
+            add_nodes_and_edges_bulk_tx,
+            episodic_nodes,
+            episodic_edges,
+            entity_nodes,
+            entity_edges,
+            embedder,
         )
 
 

@@ -108,12 +115,15 @@ async def add_nodes_and_edges_bulk_tx(
     episodic_edges: list[EpisodicEdge],
     entity_nodes: list[EntityNode],
     entity_edges: list[EntityEdge],
+    embedder: EmbedderClient,
 ):
     episodes = [dict(episode) for episode in episodic_nodes]
     for episode in episodes:
         episode['source'] = str(episode['source'].value)
     nodes: list[dict[str, Any]] = []
     for node in entity_nodes:
+        if node.name_embedding is None:
+            await node.generate_name_embedding(embedder)
         entity_data: dict[str, Any] = {
             'uuid': node.uuid,
             'name': node.name,

@@ -127,6 +137,10 @@ async def add_nodes_and_edges_bulk_tx(
         entity_data['labels'] = list(set(node.labels + ['Entity']))
         nodes.append(entity_data)
 
+    for edge in entity_edges:
+        if edge.fact_embedding is None:
+            await edge.generate_embedding(embedder)
+
     await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
     await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
     await tx.run(
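
With these hunks the bulk writer takes the embedder and backfills missing vectors inside the transaction callback, so callers no longer have to pre-embed everything (the extraction hunks below drop exactly those eager embedding calls). The guard is a plain None check before an awaited generate call; a toy illustration with stand-in types (FakeNode and fake_embedder are hypothetical, not the graphiti API):

    # Sketch of the lazy-embedding guard, with stand-in types.
    import asyncio
    from dataclasses import dataclass


    async def fake_embedder(text: str) -> list[float]:
        return [float(len(text))]  # placeholder vector


    @dataclass
    class FakeNode:
        name: str
        name_embedding: list[float] | None = None

        async def generate_name_embedding(self, embedder) -> None:
            self.name_embedding = await embedder(self.name)


    async def save_bulk(nodes: list[FakeNode]) -> None:
        for node in nodes:
            if node.name_embedding is None:  # embed only what is missing
                await node.generate_name_embedding(fake_embedder)


    nodes = [FakeNode('alice'), FakeNode('bob', name_embedding=[0.5])]
    asyncio.run(save_bulk(nodes))
    print([n.name_embedding for n in nodes])  # [[5.0], [0.5]]
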
@@ -239,7 +239,6 @@ async def determine_entity_community(
         RETURN
             c.uuid As uuid,
             c.name AS name,
-            c.name_embedding AS name_embedding,
             c.group_id AS group_id,
             c.created_at AS created_at,
             c.summary AS summary

@@ -258,7 +257,6 @@ async def determine_entity_community(
         RETURN
             c.uuid As uuid,
             c.name AS name,
-            c.name_embedding AS name_embedding,
             c.group_id AS group_id,
             c.created_at AS created_at,
             c.summary AS summary

@@ -35,9 +35,6 @@ from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
 from graphiti_core.search.search_filters import SearchFilters
 from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
 from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
-from graphiti_core.utils.maintenance.temporal_operations import (
-    get_edge_contradictions,
-)
 
 logger = logging.getLogger(__name__)
 

@@ -91,7 +88,6 @@ async def extract_edges(
 
     extract_edges_max_tokens = 16384
     llm_client = clients.llm_client
-    embedder = clients.embedder
 
     node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
 

@@ -184,8 +180,6 @@ async def extract_edges(
             f'Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})'
         )
 
-    await create_entity_edge_embeddings(embedder, edges)
-
     logger.debug(f'Extracted edges: {[(e.name, e.uuid) for e in edges]}')
 
     return edges

@@ -238,13 +232,17 @@ async def dedupe_extracted_edges(
 async def resolve_extracted_edges(
     clients: GraphitiClients,
     extracted_edges: list[EntityEdge],
+    episode: EpisodicNode,
 ) -> tuple[list[EntityEdge], list[EntityEdge]]:
     driver = clients.driver
     llm_client = clients.llm_client
+    embedder = clients.embedder
+
+    await create_entity_edge_embeddings(embedder, extracted_edges)
 
     search_results: tuple[list[list[EntityEdge]], list[list[EntityEdge]]] = await semaphore_gather(
         get_relevant_edges(driver, extracted_edges, SearchFilters()),
-        get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters()),
+        get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters(), 0.2),
     )
 
     related_edges_lists, edge_invalidation_candidates = search_results

@@ -258,10 +256,7 @@ async def resolve_extracted_edges(
     await semaphore_gather(
         *[
             resolve_extracted_edge(
-                llm_client,
-                extracted_edge,
-                related_edges,
-                existing_edges,
+                llm_client, extracted_edge, related_edges, existing_edges, episode
             )
             for extracted_edge, related_edges, existing_edges in zip(
                 extracted_edges, related_edges_lists, edge_invalidation_candidates, strict=True

@@ -281,6 +276,11 @@ async def resolve_extracted_edges(
 
     logger.debug(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')
 
+    await semaphore_gather(
+        create_entity_edge_embeddings(embedder, resolved_edges),
+        create_entity_edge_embeddings(embedder, invalidated_edges),
+    )
+
     return resolved_edges, invalidated_edges
 
 

@@ -322,10 +322,52 @@ async def resolve_extracted_edge(
     extracted_edge: EntityEdge,
     related_edges: list[EntityEdge],
     existing_edges: list[EntityEdge],
+    episode: EpisodicNode | None = None,
 ) -> tuple[EntityEdge, list[EntityEdge]]:
-    resolved_edge, invalidation_candidates = await semaphore_gather(
-        dedupe_extracted_edge(llm_client, extracted_edge, related_edges),
-        get_edge_contradictions(llm_client, extracted_edge, existing_edges),
+    if len(related_edges) == 0 and len(existing_edges) == 0:
+        return extracted_edge, []
+
+    start = time()
+
+    # Prepare context for LLM
+    related_edges_context = [
+        {'id': edge.uuid, 'fact': edge.fact} for i, edge in enumerate(related_edges)
+    ]
+
+    invalidation_edge_candidates_context = [
+        {'id': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
+    ]
+
+    context = {
+        'existing_edges': related_edges_context,
+        'new_edge': extracted_edge.fact,
+        'edge_invalidation_candidates': invalidation_edge_candidates_context,
+    }
+
+    llm_response = await llm_client.generate_response(
+        prompt_library.dedupe_edges.resolve_edge(context),
+        response_model=EdgeDuplicate,
+        model_size=ModelSize.small,
+    )
+
+    duplicate_fact_id: int = llm_response.get('duplicate_fact_id', -1)
+
+    resolved_edge = (
+        related_edges[duplicate_fact_id]
+        if 0 <= duplicate_fact_id < len(related_edges)
+        else extracted_edge
+    )
+
+    if duplicate_fact_id >= 0 and episode is not None:
+        resolved_edge.episodes.append(episode.uuid)
+
+    contradicted_facts: list[int] = llm_response.get('contradicted_facts', [])
+
+    invalidation_candidates: list[EntityEdge] = [existing_edges[i] for i in contradicted_facts]
+
+    end = time()
+    logger.debug(
+        f'Resolved Edge: {extracted_edge.name} is {resolved_edge.name}, in {(end - start) * 1000} ms'
     )
 
     now = utc_now()
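
resolve_extracted_edge now resolves duplicates and contradictions in one small-model call: the response carries a duplicate_fact_id into related_edges and a contradicted_facts index list into existing_edges, replacing the former pair of calls (dedupe_extracted_edge plus the removed get_edge_contradictions). The bounds check on duplicate_fact_id is what keeps an out-of-range model answer from raising; a toy version of that defensive parse, with a fabricated response dict:

    # Sketch: mapping an LLM's integer answers back onto candidate lists.
    related = ['fact A', 'fact B']
    existing = ['old fact 1', 'old fact 2', 'old fact 3']
    llm_response = {'duplicate_fact_id': 7, 'contradicted_facts': [0, 2]}  # fabricated

    duplicate_fact_id = llm_response.get('duplicate_fact_id', -1)
    resolved = (
        related[duplicate_fact_id]
        if 0 <= duplicate_fact_id < len(related)
        else 'keep the extracted edge'  # id 7 is out of range, so no dedupe
    )
    contradicted = [existing[i] for i in llm_response.get('contradicted_facts', [])]
    print(resolved, contradicted)  # keep the extracted edge ['old fact 1', 'old fact 3']
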
@@ -356,7 +398,10 @@
 
 
 async def dedupe_extracted_edge(
-    llm_client: LLMClient, extracted_edge: EntityEdge, related_edges: list[EntityEdge]
+    llm_client: LLMClient,
+    extracted_edge: EntityEdge,
+    related_edges: list[EntityEdge],
+    episode: EpisodicNode | None = None,
 ) -> EntityEdge:
     if len(related_edges) == 0:
         return extracted_edge

@@ -391,6 +436,9 @@
         else extracted_edge
     )
 
+    if duplicate_fact_id >= 0 and episode is not None:
+        edge.episodes.append(episode.uuid)
+
     end = time()
     logger.debug(
         f'Resolved Edge: {extracted_edge.name} is {edge.name}, in {(end - start) * 1000} ms'

@@ -18,6 +18,7 @@ import logging
 from contextlib import suppress
 from time import time
 from typing import Any
+from uuid import uuid4
 
 import pydantic
 from pydantic import BaseModel, Field

@@ -28,14 +29,16 @@ from graphiti_core.llm_client import LLMClient
 from graphiti_core.llm_client.config import ModelSize
 from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
 from graphiti_core.prompts import prompt_library
-from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
+from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
 from graphiti_core.prompts.extract_nodes import (
     ExtractedEntities,
     ExtractedEntity,
     MissedEntities,
 )
+from graphiti_core.search.search import search
+from graphiti_core.search.search_config import SearchResults
+from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
 from graphiti_core.search.search_filters import SearchFilters
-from graphiti_core.search.search_utils import get_relevant_nodes
 from graphiti_core.utils.datetime_utils import utc_now
 
 logger = logging.getLogger(__name__)

@@ -70,7 +73,6 @@ async def extract_nodes(
 ) -> list[EntityNode]:
     start = time()
     llm_client = clients.llm_client
-    embedder = clients.embedder
     llm_response = {}
     custom_prompt = ''
     entities_missed = True

@@ -163,8 +165,6 @@ async def extract_nodes(
         extracted_nodes.append(new_node)
         logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
 
-    await create_entity_node_embeddings(embedder, extracted_nodes)
-
     logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
     return extracted_nodes
 

@@ -227,35 +227,81 @@ async def resolve_extracted_nodes(
     entity_types: dict[str, BaseModel] | None = None,
 ) -> tuple[list[EntityNode], dict[str, str]]:
     llm_client = clients.llm_client
-    driver = clients.driver
 
-    # Find relevant nodes already in the graph
-    existing_nodes_lists: list[list[EntityNode]] = await get_relevant_nodes(
-        driver, extracted_nodes, SearchFilters()
-    )
-
-    resolved_nodes: list[EntityNode] = await semaphore_gather(
+    search_results: list[SearchResults] = await semaphore_gather(
         *[
-            resolve_extracted_node(
-                llm_client,
-                extracted_node,
-                existing_nodes,
-                episode,
-                previous_episodes,
-                entity_types.get(
-                    next((item for item in extracted_node.labels if item != 'Entity'), '')
-                )
-                if entity_types is not None
-                else None,
-            )
-            for extracted_node, existing_nodes in zip(
-                extracted_nodes, existing_nodes_lists, strict=True
+            search(
+                clients=clients,
+                query=node.name,
+                group_ids=[node.group_id],
+                search_filter=SearchFilters(),
+                config=NODE_HYBRID_SEARCH_RRF,
             )
+            for node in extracted_nodes
         ]
     )
 
+    existing_nodes_lists: list[list[EntityNode]] = [result.nodes for result in search_results]
+
+    entity_types_dict: dict[str, BaseModel] = entity_types if entity_types is not None else {}
+
+    # Prepare context for LLM
+    extracted_nodes_context = [
+        {
+            'id': i,
+            'name': node.name,
+            'entity_type': node.labels,
+            'entity_type_description': entity_types_dict.get(
+                next((item for item in node.labels if item != 'Entity'), '')
+            ).__doc__
+            or 'Default Entity Type',
+            'duplication_candidates': [
+                {
+                    **{
+                        'idx': j,
+                        'name': candidate.name,
+                        'entity_types': candidate.labels,
+                    },
+                    **candidate.attributes,
+                }
+                for j, candidate in enumerate(existing_nodes_lists[i])
+            ],
+        }
+        for i, node in enumerate(extracted_nodes)
+    ]
+
+    context = {
+        'extracted_nodes': extracted_nodes_context,
+        'episode_content': episode.content if episode is not None else '',
+        'previous_episodes': [ep.content for ep in previous_episodes]
+        if previous_episodes is not None
+        else [],
+    }
+
+    llm_response = await llm_client.generate_response(
+        prompt_library.dedupe_nodes.nodes(context),
+        response_model=NodeResolutions,
+    )
+
+    node_resolutions: list = llm_response.get('entity_resolutions', [])
+
+    resolved_nodes: list[EntityNode] = []
     uuid_map: dict[str, str] = {}
-    for extracted_node, resolved_node in zip(extracted_nodes, resolved_nodes, strict=True):
+    for resolution in node_resolutions:
+        resolution_id = resolution.get('id', -1)
+        duplicate_idx = resolution.get('duplicate_idx', -1)
+
+        extracted_node = extracted_nodes[resolution_id]
+
+        resolved_node = (
+            existing_nodes_lists[resolution_id][duplicate_idx]
+            if 0 <= duplicate_idx < len(existing_nodes_lists[resolution_id])
+            else extracted_node
+        )
+
+        resolved_node.name = resolution.get('name')
+
+        resolved_nodes.append(resolved_node)
         uuid_map[extracted_node.uuid] = resolved_node.uuid
 
     logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}')
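
Node resolution makes the same move at batch scale: one hybrid search (NODE_HYBRID_SEARCH_RRF) per extracted node gathers duplication candidates, then a single LLM call returns entity_resolutions entries carrying an id, a duplicate_idx, and a name, and the same index-guard idiom maps them back. A toy version of the mapping, with fabricated resolution contents:

    # Sketch: resolving batched LLM node decisions onto candidate lists.
    extracted = ['Alice Smith', 'Bob']
    candidates = [['Alice', 'A. Smith'], []]  # per-node hybrid-search hits
    resolutions = [  # fabricated LLM output
        {'id': 0, 'duplicate_idx': 1, 'name': 'Alice Smith'},
        {'id': 1, 'duplicate_idx': -1, 'name': 'Bob'},
    ]

    for resolution in resolutions:
        rid = resolution.get('id', -1)
        dup = resolution.get('duplicate_idx', -1)
        node = extracted[rid]
        resolved = candidates[rid][dup] if 0 <= dup < len(candidates[rid]) else node
        print(node, '->', resolved)
    # Alice Smith -> A. Smith
    # Bob -> Bob
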
@@ -375,7 +421,7 @@ async def extract_attributes_from_node(
         'summary': (
             str,
             Field(
-                description='Summary containing the important information about the entity. Under 500 words',
+                description='Summary containing the important information about the entity. Under 250 words',
             ),
         )
     }

@@ -387,7 +433,8 @@
             Field(description=field_info.description),
         )
 
-    entity_attributes_model = pydantic.create_model('EntityAttributes', **attributes_definitions)
+    unique_model_name = f'EntityAttributes_{uuid4().hex}'
+    entity_attributes_model = pydantic.create_model(unique_model_name, **attributes_definitions)
 
     summary_context: dict[str, Any] = {
         'node': node_context,
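
Suffixing the dynamic model's name with uuid4().hex makes every pydantic.create_model call produce a distinctly named class; reusing one name such as 'EntityAttributes' for models with different fields can collide wherever schemas are cached or registered by title (a likely motivation, though the diff itself does not state one). The pattern in isolation:

    # Sketch: uniquely named dynamic pydantic models.
    from uuid import uuid4

    import pydantic
    from pydantic import Field

    attributes_definitions = {
        'summary': (str, Field(description='Short summary, under 250 words')),
    }

    unique_model_name = f'EntityAttributes_{uuid4().hex}'
    entity_attributes_model = pydantic.create_model(unique_model_name, **attributes_definitions)
    print(entity_attributes_model.__name__)  # e.g. EntityAttributes_3f2a9c...
    print(entity_attributes_model(summary='hi').summary)
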
@@ -400,15 +447,14 @@
     llm_response = await llm_client.generate_response(
         prompt_library.extract_nodes.extract_attributes(summary_context),
         response_model=entity_attributes_model,
+        model_size=ModelSize.small,
     )
 
     node.summary = llm_response.get('summary', node.summary)
-    node.name = llm_response.get('name', node.name)
     node_attributes = {key: value for key, value in llm_response.items()}
 
     with suppress(KeyError):
         del node_attributes['summary']
-        del node_attributes['name']
 
     node.attributes.update(node_attributes)
 

@@ -427,10 +473,7 @@ async def dedupe_node_list(
         node_map[node.uuid] = node
 
     # Prepare context for LLM
-    nodes_context = [
-        {'uuid': node.uuid, 'name': node.name, 'summary': node.summary}.update(node.attributes)
-        for node in nodes
-    ]
+    nodes_context = [{'uuid': node.uuid, 'name': node.name, **node.attributes} for node in nodes]
 
     context = {
         'nodes': nodes_context,
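
The old nodes_context was a genuine bug: dict.update() mutates in place and returns None, so the comprehension built a list of None values; the replacement merges attributes with ** unpacking instead. The difference in two lines:

    # Sketch: why the old comprehension yielded None.
    node = {'uuid': 'u1', 'name': 'Alice'}
    attributes = {'role': 'engineer'}

    broken = {'uuid': node['uuid'], 'name': node['name']}.update(attributes)
    fixed = {'uuid': node['uuid'], 'name': node['name'], **attributes}

    print(broken)  # None -- update() mutates and returns None
    print(fixed)   # {'uuid': 'u1', 'name': 'Alice', 'role': 'engineer'}
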
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: graphiti-core
-Version: 0.11.4
+Version: 0.11.6
 Summary: A temporal graph building library
 License: Apache-2.0
 Author: Paul Paliychuk

@@ -18,7 +18,6 @@ Provides-Extra: groq
 Requires-Dist: anthropic (>=0.49.0) ; extra == "anthropic"
 Requires-Dist: diskcache (>=5.6.3)
 Requires-Dist: google-genai (>=1.8.0) ; extra == "google-genai"
-Requires-Dist: graph-service (>=1.0.0.7,<2.0.0.0)
 Requires-Dist: groq (>=0.2.0) ; extra == "groq"
 Requires-Dist: neo4j (>=5.23.0)
 Requires-Dist: numpy (>=1.0.0)