graphiti-core 0.12.0rc1__py3-none-any.whl → 0.24.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. graphiti_core/cross_encoder/bge_reranker_client.py +12 -2
  2. graphiti_core/cross_encoder/gemini_reranker_client.py +161 -0
  3. graphiti_core/cross_encoder/openai_reranker_client.py +7 -5
  4. graphiti_core/decorators.py +110 -0
  5. graphiti_core/driver/__init__.py +19 -0
  6. graphiti_core/driver/driver.py +124 -0
  7. graphiti_core/driver/falkordb_driver.py +362 -0
  8. graphiti_core/driver/graph_operations/graph_operations.py +191 -0
  9. graphiti_core/driver/kuzu_driver.py +182 -0
  10. graphiti_core/driver/neo4j_driver.py +117 -0
  11. graphiti_core/driver/neptune_driver.py +305 -0
  12. graphiti_core/driver/search_interface/search_interface.py +89 -0
  13. graphiti_core/edges.py +287 -172
  14. graphiti_core/embedder/azure_openai.py +71 -0
  15. graphiti_core/embedder/client.py +2 -1
  16. graphiti_core/embedder/gemini.py +116 -22
  17. graphiti_core/embedder/voyage.py +13 -2
  18. graphiti_core/errors.py +8 -0
  19. graphiti_core/graph_queries.py +162 -0
  20. graphiti_core/graphiti.py +705 -193
  21. graphiti_core/graphiti_types.py +4 -2
  22. graphiti_core/helpers.py +87 -10
  23. graphiti_core/llm_client/__init__.py +16 -0
  24. graphiti_core/llm_client/anthropic_client.py +159 -56
  25. graphiti_core/llm_client/azure_openai_client.py +115 -0
  26. graphiti_core/llm_client/client.py +98 -21
  27. graphiti_core/llm_client/config.py +1 -1
  28. graphiti_core/llm_client/gemini_client.py +290 -41
  29. graphiti_core/llm_client/groq_client.py +14 -3
  30. graphiti_core/llm_client/openai_base_client.py +261 -0
  31. graphiti_core/llm_client/openai_client.py +56 -132
  32. graphiti_core/llm_client/openai_generic_client.py +91 -56
  33. graphiti_core/models/edges/edge_db_queries.py +259 -35
  34. graphiti_core/models/nodes/node_db_queries.py +311 -32
  35. graphiti_core/nodes.py +420 -205
  36. graphiti_core/prompts/dedupe_edges.py +46 -32
  37. graphiti_core/prompts/dedupe_nodes.py +67 -42
  38. graphiti_core/prompts/eval.py +4 -4
  39. graphiti_core/prompts/extract_edges.py +27 -16
  40. graphiti_core/prompts/extract_nodes.py +74 -31
  41. graphiti_core/prompts/prompt_helpers.py +39 -0
  42. graphiti_core/prompts/snippets.py +29 -0
  43. graphiti_core/prompts/summarize_nodes.py +23 -25
  44. graphiti_core/search/search.py +158 -82
  45. graphiti_core/search/search_config.py +39 -4
  46. graphiti_core/search/search_filters.py +126 -35
  47. graphiti_core/search/search_helpers.py +5 -6
  48. graphiti_core/search/search_utils.py +1405 -485
  49. graphiti_core/telemetry/__init__.py +9 -0
  50. graphiti_core/telemetry/telemetry.py +117 -0
  51. graphiti_core/tracer.py +193 -0
  52. graphiti_core/utils/bulk_utils.py +364 -285
  53. graphiti_core/utils/datetime_utils.py +13 -0
  54. graphiti_core/utils/maintenance/community_operations.py +67 -49
  55. graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
  56. graphiti_core/utils/maintenance/edge_operations.py +339 -197
  57. graphiti_core/utils/maintenance/graph_data_operations.py +50 -114
  58. graphiti_core/utils/maintenance/node_operations.py +319 -238
  59. graphiti_core/utils/maintenance/temporal_operations.py +11 -3
  60. graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
  61. graphiti_core/utils/text_utils.py +53 -0
  62. graphiti_core-0.24.3.dist-info/METADATA +726 -0
  63. graphiti_core-0.24.3.dist-info/RECORD +86 -0
  64. {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
  65. graphiti_core-0.12.0rc1.dist-info/METADATA +0 -350
  66. graphiti_core-0.12.0rc1.dist-info/RECORD +0 -66
  67. /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
  68. {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info/licenses}/LICENSE +0 -0
graphiti_core/utils/maintenance/node_operations.py

@@ -15,22 +15,26 @@ limitations under the License.
 """
 
 import logging
-from contextlib import suppress
+from collections.abc import Awaitable, Callable
 from time import time
 from typing import Any
-from uuid import uuid4
 
-import pydantic
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
 
 from graphiti_core.graphiti_types import GraphitiClients
 from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
 from graphiti_core.llm_client import LLMClient
 from graphiti_core.llm_client.config import ModelSize
-from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
+from graphiti_core.nodes import (
+    EntityNode,
+    EpisodeType,
+    EpisodicNode,
+    create_entity_node_embeddings,
+)
 from graphiti_core.prompts import prompt_library
 from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
 from graphiti_core.prompts.extract_nodes import (
+    EntitySummary,
     ExtractedEntities,
     ExtractedEntity,
     MissedEntities,
@@ -40,15 +44,28 @@ from graphiti_core.search.search_config import SearchResults
 from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
 from graphiti_core.search.search_filters import SearchFilters
 from graphiti_core.utils.datetime_utils import utc_now
+from graphiti_core.utils.maintenance.dedup_helpers import (
+    DedupCandidateIndexes,
+    DedupResolutionState,
+    _build_candidate_indexes,
+    _resolve_with_similarity,
+)
+from graphiti_core.utils.maintenance.edge_operations import (
+    filter_existing_duplicate_of_edges,
+)
+from graphiti_core.utils.text_utils import MAX_SUMMARY_CHARS, truncate_at_sentence
 
 logger = logging.getLogger(__name__)
 
+NodeSummaryFilter = Callable[[EntityNode], Awaitable[bool]]
+
 
 async def extract_nodes_reflexion(
     llm_client: LLMClient,
     episode: EpisodicNode,
     previous_episodes: list[EpisodicNode],
     node_names: list[str],
+    group_id: str | None = None,
 ) -> list[str]:
     # Prepare context for LLM
     context = {
@@ -58,7 +75,10 @@ async def extract_nodes_reflexion(
     }
 
     llm_response = await llm_client.generate_response(
-        prompt_library.extract_nodes.reflexion(context), MissedEntities
+        prompt_library.extract_nodes.reflexion(context),
+        MissedEntities,
+        group_id=group_id,
+        prompt_name='extract_nodes.reflexion',
     )
     missed_entities = llm_response.get('missed_entities', [])
 
@@ -69,7 +89,8 @@ async def extract_nodes(
     clients: GraphitiClients,
     episode: EpisodicNode,
     previous_episodes: list[EpisodicNode],
-    entity_types: dict[str, BaseModel] | None = None,
+    entity_types: dict[str, type[BaseModel]] | None = None,
+    excluded_entity_types: list[str] | None = None,
 ) -> list[EntityNode]:
     start = time()
     llm_client = clients.llm_client
@@ -113,20 +134,27 @@ async def extract_nodes(
             llm_response = await llm_client.generate_response(
                 prompt_library.extract_nodes.extract_message(context),
                 response_model=ExtractedEntities,
+                group_id=episode.group_id,
+                prompt_name='extract_nodes.extract_message',
             )
         elif episode.source == EpisodeType.text:
             llm_response = await llm_client.generate_response(
-                prompt_library.extract_nodes.extract_text(context), response_model=ExtractedEntities
+                prompt_library.extract_nodes.extract_text(context),
+                response_model=ExtractedEntities,
+                group_id=episode.group_id,
+                prompt_name='extract_nodes.extract_text',
             )
         elif episode.source == EpisodeType.json:
             llm_response = await llm_client.generate_response(
-                prompt_library.extract_nodes.extract_json(context), response_model=ExtractedEntities
+                prompt_library.extract_nodes.extract_json(context),
+                response_model=ExtractedEntities,
+                group_id=episode.group_id,
+                prompt_name='extract_nodes.extract_json',
             )
 
-        extracted_entities: list[ExtractedEntity] = [
-            ExtractedEntity(**entity_types_context)
-            for entity_types_context in llm_response.get('extracted_entities', [])
-        ]
+        response_object = ExtractedEntities(**llm_response)
+
+        extracted_entities: list[ExtractedEntity] = response_object.extracted_entities
 
         reflexion_iterations += 1
         if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
@@ -135,6 +163,7 @@ async def extract_nodes(
                 episode,
                 previous_episodes,
                 [entity.name for entity in extracted_entities],
+                episode.group_id,
             )
 
             entities_missed = len(missing_entities) != 0
@@ -149,9 +178,18 @@ async def extract_nodes(
     # Convert the extracted data into EntityNode objects
     extracted_nodes = []
     for extracted_entity in filtered_extracted_entities:
-        entity_type_name = entity_types_context[extracted_entity.entity_type_id].get(
-            'entity_type_name'
-        )
+        type_id = extracted_entity.entity_type_id
+        if 0 <= type_id < len(entity_types_context):
+            entity_type_name = entity_types_context[extracted_entity.entity_type_id].get(
+                'entity_type_name'
+            )
+        else:
+            entity_type_name = 'Entity'
+
+        # Check if this entity type should be excluded
+        if excluded_entity_types and entity_type_name in excluded_entity_types:
+            logger.debug(f'Excluding entity "{extracted_entity.name}" of type "{entity_type_name}"')
+            continue
 
         labels: list[str] = list({'Entity', str(entity_type_name)})
 
@@ -166,68 +204,16 @@ async def extract_nodes(
         logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
 
     logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
-    return extracted_nodes
-
-
-async def dedupe_extracted_nodes(
-    llm_client: LLMClient,
-    extracted_nodes: list[EntityNode],
-    existing_nodes: list[EntityNode],
-) -> tuple[list[EntityNode], dict[str, str]]:
-    start = time()
-
-    # build existing node map
-    node_map: dict[str, EntityNode] = {}
-    for node in existing_nodes:
-        node_map[node.uuid] = node
-
-    # Prepare context for LLM
-    existing_nodes_context = [
-        {'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in existing_nodes
-    ]
-
-    extracted_nodes_context = [
-        {'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in extracted_nodes
-    ]
-
-    context = {
-        'existing_nodes': existing_nodes_context,
-        'extracted_nodes': extracted_nodes_context,
-    }
-
-    llm_response = await llm_client.generate_response(prompt_library.dedupe_nodes.node(context))
-
-    duplicate_data = llm_response.get('duplicates', [])
-
-    end = time()
-    logger.debug(f'Deduplicated nodes: {duplicate_data} in {(end - start) * 1000} ms')
-
-    uuid_map: dict[str, str] = {}
-    for duplicate in duplicate_data:
-        uuid_value = duplicate['duplicate_of']
-        uuid_map[duplicate['uuid']] = uuid_value
-
-    nodes: list[EntityNode] = []
-    for node in extracted_nodes:
-        if node.uuid in uuid_map:
-            existing_uuid = uuid_map[node.uuid]
-            existing_node = node_map[existing_uuid]
-            nodes.append(existing_node)
-        else:
-            nodes.append(node)
 
-    return nodes, uuid_map
+    return extracted_nodes
 
 
-async def resolve_extracted_nodes(
+async def _collect_candidate_nodes(
     clients: GraphitiClients,
     extracted_nodes: list[EntityNode],
-    episode: EpisodicNode | None = None,
-    previous_episodes: list[EpisodicNode] | None = None,
-    entity_types: dict[str, BaseModel] | None = None,
-) -> tuple[list[EntityNode], dict[str, str]]:
-    llm_client = clients.llm_client
-
+    existing_nodes_override: list[EntityNode] | None,
+) -> list[EntityNode]:
+    """Search per extracted name and return unique candidates with overrides honored in order."""
     search_results: list[SearchResults] = await semaphore_gather(
         *[
            search(
@@ -241,11 +227,43 @@ async def resolve_extracted_nodes(
         ]
     )
 
-    existing_nodes_lists: list[list[EntityNode]] = [result.nodes for result in search_results]
+    candidate_nodes: list[EntityNode] = [node for result in search_results for node in result.nodes]
 
-    entity_types_dict: dict[str, BaseModel] = entity_types if entity_types is not None else {}
+    if existing_nodes_override is not None:
+        candidate_nodes.extend(existing_nodes_override)
+
+    seen_candidate_uuids: set[str] = set()
+    ordered_candidates: list[EntityNode] = []
+    for candidate in candidate_nodes:
+        if candidate.uuid in seen_candidate_uuids:
+            continue
+        seen_candidate_uuids.add(candidate.uuid)
+        ordered_candidates.append(candidate)
+
+    return ordered_candidates
+
+
+async def _resolve_with_llm(
+    llm_client: LLMClient,
+    extracted_nodes: list[EntityNode],
+    indexes: DedupCandidateIndexes,
+    state: DedupResolutionState,
+    episode: EpisodicNode | None,
+    previous_episodes: list[EpisodicNode] | None,
+    entity_types: dict[str, type[BaseModel]] | None,
+) -> None:
+    """Escalate unresolved nodes to the dedupe prompt so the LLM can select or reject duplicates.
+
+    The guardrails below defensively ignore malformed or duplicate LLM responses so the
+    ingestion workflow remains deterministic even when the model misbehaves.
+    """
+    if not state.unresolved_indices:
+        return
+
+    entity_types_dict: dict[str, type[BaseModel]] = entity_types if entity_types is not None else {}
+
+    llm_extracted_nodes = [extracted_nodes[i] for i in state.unresolved_indices]
 
-    # Prepare context for LLM
     extracted_nodes_context = [
         {
             'id': i,
@@ -255,122 +273,181 @@ async def resolve_extracted_nodes(
                 next((item for item in node.labels if item != 'Entity'), '')
             ).__doc__
             or 'Default Entity Type',
-            'duplication_candidates': [
-                {
-                    **{
-                        'idx': j,
-                        'name': candidate.name,
-                        'entity_types': candidate.labels,
-                    },
-                    **candidate.attributes,
-                }
-                for j, candidate in enumerate(existing_nodes_lists[i])
-            ],
         }
-        for i, node in enumerate(extracted_nodes)
+        for i, node in enumerate(llm_extracted_nodes)
+    ]
+
+    sent_ids = [ctx['id'] for ctx in extracted_nodes_context]
+    logger.debug(
+        'Sending %d entities to LLM for deduplication with IDs 0-%d (actual IDs sent: %s)',
+        len(llm_extracted_nodes),
+        len(llm_extracted_nodes) - 1,
+        sent_ids if len(sent_ids) < 20 else f'{sent_ids[:10]}...{sent_ids[-10:]}',
+    )
+    if llm_extracted_nodes:
+        sample_size = min(3, len(extracted_nodes_context))
+        logger.debug(
+            'First %d entities: %s',
+            sample_size,
+            [(ctx['id'], ctx['name']) for ctx in extracted_nodes_context[:sample_size]],
+        )
+        if len(extracted_nodes_context) > 3:
+            logger.debug(
+                'Last %d entities: %s',
+                sample_size,
+                [(ctx['id'], ctx['name']) for ctx in extracted_nodes_context[-sample_size:]],
+            )
+
+    existing_nodes_context = [
+        {
+            **{
+                'idx': i,
+                'name': candidate.name,
+                'entity_types': candidate.labels,
+            },
+            **candidate.attributes,
+        }
+        for i, candidate in enumerate(indexes.existing_nodes)
     ]
 
     context = {
         'extracted_nodes': extracted_nodes_context,
+        'existing_nodes': existing_nodes_context,
         'episode_content': episode.content if episode is not None else '',
-        'previous_episodes': [ep.content for ep in previous_episodes]
-        if previous_episodes is not None
-        else [],
+        'previous_episodes': (
+            [ep.content for ep in previous_episodes] if previous_episodes is not None else []
+        ),
     }
 
     llm_response = await llm_client.generate_response(
         prompt_library.dedupe_nodes.nodes(context),
         response_model=NodeResolutions,
+        prompt_name='dedupe_nodes.nodes',
     )
 
-    node_resolutions: list = llm_response.get('entity_resolutions', [])
+    node_resolutions: list[NodeDuplicate] = NodeResolutions(**llm_response).entity_resolutions
 
-    resolved_nodes: list[EntityNode] = []
-    uuid_map: dict[str, str] = {}
-    for resolution in node_resolutions:
-        resolution_id = resolution.get('id', -1)
-        duplicate_idx = resolution.get('duplicate_idx', -1)
+    valid_relative_range = range(len(state.unresolved_indices))
+    processed_relative_ids: set[int] = set()
+
+    received_ids = {r.id for r in node_resolutions}
+    expected_ids = set(valid_relative_range)
+    missing_ids = expected_ids - received_ids
+    extra_ids = received_ids - expected_ids
+
+    logger.debug(
+        'Received %d resolutions for %d entities',
+        len(node_resolutions),
+        len(state.unresolved_indices),
+    )
 
-        extracted_node = extracted_nodes[resolution_id]
+    if missing_ids:
+        logger.warning('LLM did not return resolutions for IDs: %s', sorted(missing_ids))
 
-        resolved_node = (
-            existing_nodes_lists[resolution_id][duplicate_idx]
-            if 0 <= duplicate_idx < len(existing_nodes_lists[resolution_id])
-            else extracted_node
+    if extra_ids:
+        logger.warning(
+            'LLM returned invalid IDs outside valid range 0-%d: %s (all returned IDs: %s)',
+            len(state.unresolved_indices) - 1,
+            sorted(extra_ids),
+            sorted(received_ids),
        )
 
-        resolved_node.name = resolution.get('name')
+    for resolution in node_resolutions:
+        relative_id: int = resolution.id
+        duplicate_idx: int = resolution.duplicate_idx
+
+        if relative_id not in valid_relative_range:
+            logger.warning(
+                'Skipping invalid LLM dedupe id %d (valid range: 0-%d, received %d resolutions)',
+                relative_id,
+                len(state.unresolved_indices) - 1,
+                len(node_resolutions),
+            )
+            continue
 
-        resolved_nodes.append(resolved_node)
-        uuid_map[extracted_node.uuid] = resolved_node.uuid
+        if relative_id in processed_relative_ids:
+            logger.warning('Duplicate LLM dedupe id %s received; ignoring.', relative_id)
+            continue
+        processed_relative_ids.add(relative_id)
 
-    logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}')
+        original_index = state.unresolved_indices[relative_id]
+        extracted_node = extracted_nodes[original_index]
 
-    return resolved_nodes, uuid_map
+        resolved_node: EntityNode
+        if duplicate_idx == -1:
+            resolved_node = extracted_node
+        elif 0 <= duplicate_idx < len(indexes.existing_nodes):
+            resolved_node = indexes.existing_nodes[duplicate_idx]
+        else:
+            logger.warning(
+                'Invalid duplicate_idx %s for extracted node %s; treating as no duplicate.',
+                duplicate_idx,
+                extracted_node.uuid,
+            )
+            resolved_node = extracted_node
 
+        state.resolved_nodes[original_index] = resolved_node
+        state.uuid_map[extracted_node.uuid] = resolved_node.uuid
+        if resolved_node.uuid != extracted_node.uuid:
+            state.duplicate_pairs.append((extracted_node, resolved_node))
 
-async def resolve_extracted_node(
-    llm_client: LLMClient,
-    extracted_node: EntityNode,
-    existing_nodes: list[EntityNode],
+
+async def resolve_extracted_nodes(
+    clients: GraphitiClients,
+    extracted_nodes: list[EntityNode],
     episode: EpisodicNode | None = None,
     previous_episodes: list[EpisodicNode] | None = None,
-    entity_type: BaseModel | None = None,
-) -> EntityNode:
-    start = time()
-    if len(existing_nodes) == 0:
-        return extracted_node
-
-    # Prepare context for LLM
-    existing_nodes_context = [
-        {
-            **{
-                'id': i,
-                'name': node.name,
-                'entity_types': node.labels,
-            },
-            **node.attributes,
-        }
-        for i, node in enumerate(existing_nodes)
-    ]
-
-    extracted_node_context = {
-        'name': extracted_node.name,
-        'entity_type': entity_type.__name__ if entity_type is not None else 'Entity',  # type: ignore
-    }
+    entity_types: dict[str, type[BaseModel]] | None = None,
+    existing_nodes_override: list[EntityNode] | None = None,
+) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
+    """Search for existing nodes, resolve deterministic matches, then escalate holdouts to the LLM dedupe prompt."""
+    llm_client = clients.llm_client
+    driver = clients.driver
+    existing_nodes = await _collect_candidate_nodes(
+        clients,
+        extracted_nodes,
+        existing_nodes_override,
+    )
 
-    context = {
-        'existing_nodes': existing_nodes_context,
-        'extracted_node': extracted_node_context,
-        'entity_type_description': entity_type.__doc__
-        if entity_type is not None
-        else 'Default Entity Type',
-        'episode_content': episode.content if episode is not None else '',
-        'previous_episodes': [ep.content for ep in previous_episodes]
-        if previous_episodes is not None
-        else [],
-    }
+    indexes: DedupCandidateIndexes = _build_candidate_indexes(existing_nodes)
 
-    llm_response = await llm_client.generate_response(
-        prompt_library.dedupe_nodes.node(context),
-        response_model=NodeDuplicate,
-        model_size=ModelSize.small,
+    state = DedupResolutionState(
+        resolved_nodes=[None] * len(extracted_nodes),
+        uuid_map={},
+        unresolved_indices=[],
    )
 
-    duplicate_id: int = llm_response.get('duplicate_node_id', -1)
+    _resolve_with_similarity(extracted_nodes, indexes, state)
 
-    node = (
-        existing_nodes[duplicate_id] if 0 <= duplicate_id < len(existing_nodes) else extracted_node
+    await _resolve_with_llm(
+        llm_client,
+        extracted_nodes,
+        indexes,
+        state,
+        episode,
+        previous_episodes,
+        entity_types,
    )
 
-    node.name = llm_response.get('name', '')
+    for idx, node in enumerate(extracted_nodes):
+        if state.resolved_nodes[idx] is None:
+            state.resolved_nodes[idx] = node
+            state.uuid_map[node.uuid] = node.uuid
 
-    end = time()
     logger.debug(
-        f'Resolved node: {extracted_node.name} is {node.name}, in {(end - start) * 1000} ms'
+        'Resolved nodes: %s',
+        [(node.name, node.uuid) for node in state.resolved_nodes if node is not None],
    )
 
-    return node
+    new_node_duplicates: list[
+        tuple[EntityNode, EntityNode]
+    ] = await filter_existing_duplicate_of_edges(driver, state.duplicate_pairs)
+
+    return (
+        [node for node in state.resolved_nodes if node is not None],
+        state.uuid_map,
+        new_node_duplicates,
+    )
 
 
 async def extract_attributes_from_nodes(
@@ -378,11 +455,11 @@ async def extract_attributes_from_nodes(
     nodes: list[EntityNode],
     episode: EpisodicNode | None = None,
     previous_episodes: list[EpisodicNode] | None = None,
-    entity_types: dict[str, BaseModel] | None = None,
+    entity_types: dict[str, type[BaseModel]] | None = None,
+    should_summarize_node: NodeSummaryFilter | None = None,
 ) -> list[EntityNode]:
     llm_client = clients.llm_client
     embedder = clients.embedder
-
     updated_nodes: list[EntityNode] = await semaphore_gather(
         *[
             extract_attributes_from_node(
@@ -390,9 +467,12 @@ async def extract_attributes_from_nodes(
                 node,
                 episode,
                 previous_episodes,
-                entity_types.get(next((item for item in node.labels if item != 'Entity'), ''))
-                if entity_types is not None
-                else None,
+                (
+                    entity_types.get(next((item for item in node.labels if item != 'Entity'), ''))
+                    if entity_types is not None
+                    else None
+                ),
+                should_summarize_node,
             )
             for node in nodes
         ]
@@ -408,99 +488,100 @@ async def extract_attributes_from_node(
     node: EntityNode,
     episode: EpisodicNode | None = None,
     previous_episodes: list[EpisodicNode] | None = None,
-    entity_type: BaseModel | None = None,
+    entity_type: type[BaseModel] | None = None,
+    should_summarize_node: NodeSummaryFilter | None = None,
 ) -> EntityNode:
-    node_context: dict[str, Any] = {
-        'name': node.name,
-        'summary': node.summary,
-        'entity_types': node.labels,
-        'attributes': node.attributes,
-    }
+    # Extract attributes if entity type is defined and has attributes
+    llm_response = await _extract_entity_attributes(
+        llm_client, node, episode, previous_episodes, entity_type
+    )
 
-    attributes_definitions: dict[str, Any] = {
-        'summary': (
-            str,
-            Field(
-                description='Summary containing the important information about the entity. Under 250 words',
-            ),
-        )
-    }
+    # Extract summary if needed
+    await _extract_entity_summary(
+        llm_client, node, episode, previous_episodes, should_summarize_node
+    )
 
-    if entity_type is not None:
-        for field_name, field_info in entity_type.model_fields.items():
-            attributes_definitions[field_name] = (
-                field_info.annotation,
-                Field(description=field_info.description),
-            )
+    node.attributes.update(llm_response)
 
-    unique_model_name = f'EntityAttributes_{uuid4().hex}'
-    entity_attributes_model = pydantic.create_model(unique_model_name, **attributes_definitions)
+    return node
 
-    summary_context: dict[str, Any] = {
-        'node': node_context,
-        'episode_content': episode.content if episode is not None else '',
-        'previous_episodes': [ep.content for ep in previous_episodes]
-        if previous_episodes is not None
-        else [],
-    }
+
+async def _extract_entity_attributes(
+    llm_client: LLMClient,
+    node: EntityNode,
+    episode: EpisodicNode | None,
+    previous_episodes: list[EpisodicNode] | None,
+    entity_type: type[BaseModel] | None,
+) -> dict[str, Any]:
+    if entity_type is None or len(entity_type.model_fields) == 0:
+        return {}
+
+    attributes_context = _build_episode_context(
+        # should not include summary
+        node_data={
+            'name': node.name,
+            'entity_types': node.labels,
+            'attributes': node.attributes,
+        },
+        episode=episode,
+        previous_episodes=previous_episodes,
+    )
 
     llm_response = await llm_client.generate_response(
-        prompt_library.extract_nodes.extract_attributes(summary_context),
-        response_model=entity_attributes_model,
+        prompt_library.extract_nodes.extract_attributes(attributes_context),
+        response_model=entity_type,
         model_size=ModelSize.small,
+        group_id=node.group_id,
+        prompt_name='extract_nodes.extract_attributes',
     )
 
-    node.summary = llm_response.get('summary', node.summary)
-    node_attributes = {key: value for key, value in llm_response.items()}
-
-    with suppress(KeyError):
-        del node_attributes['summary']
+    # validate response
+    entity_type(**llm_response)
 
-    node.attributes.update(node_attributes)
-
-    return node
+    return llm_response
 
 
-async def dedupe_node_list(
+async def _extract_entity_summary(
     llm_client: LLMClient,
-    nodes: list[EntityNode],
-) -> tuple[list[EntityNode], dict[str, str]]:
-    start = time()
-
-    # build node map
-    node_map = {}
-    for node in nodes:
-        node_map[node.uuid] = node
-
-    # Prepare context for LLM
-    nodes_context = [{'uuid': node.uuid, 'name': node.name, **node.attributes} for node in nodes]
-
-    context = {
-        'nodes': nodes_context,
-    }
-
-    llm_response = await llm_client.generate_response(
-        prompt_library.dedupe_nodes.node_list(context)
+    node: EntityNode,
+    episode: EpisodicNode | None,
+    previous_episodes: list[EpisodicNode] | None,
+    should_summarize_node: NodeSummaryFilter | None,
+) -> None:
+    if should_summarize_node is not None and not await should_summarize_node(node):
+        return
+
+    summary_context = _build_episode_context(
+        node_data={
+            'name': node.name,
+            'summary': truncate_at_sentence(node.summary, MAX_SUMMARY_CHARS),
+            'entity_types': node.labels,
+            'attributes': node.attributes,
+        },
+        episode=episode,
+        previous_episodes=previous_episodes,
    )
 
-    nodes_data = llm_response.get('nodes', [])
+    summary_response = await llm_client.generate_response(
+        prompt_library.extract_nodes.extract_summary(summary_context),
+        response_model=EntitySummary,
+        model_size=ModelSize.small,
+        group_id=node.group_id,
+        prompt_name='extract_nodes.extract_summary',
+    )
 
-    end = time()
-    logger.debug(f'Deduplicated nodes: {nodes_data} in {(end - start) * 1000} ms')
-
-    # Get full node data
-    unique_nodes = []
-    uuid_map: dict[str, str] = {}
-    for node_data in nodes_data:
-        node_instance: EntityNode | None = node_map.get(node_data['uuids'][0])
-        if node_instance is None:
-            logger.warning(f'Node {node_data["uuids"][0]} not found in node map')
-            continue
-        node_instance.summary = node_data['summary']
-        unique_nodes.append(node_instance)
+    node.summary = truncate_at_sentence(summary_response.get('summary', ''), MAX_SUMMARY_CHARS)
 
-        for uuid in node_data['uuids'][1:]:
-            uuid_value = node_map[node_data['uuids'][0]].uuid
-            uuid_map[uuid] = uuid_value
 
-    return unique_nodes, uuid_map
+def _build_episode_context(
+    node_data: dict[str, Any],
+    episode: EpisodicNode | None,
+    previous_episodes: list[EpisodicNode] | None,
+) -> dict[str, Any]:
+    return {
+        'node': node_data,
+        'episode_content': episode.content if episode is not None else '',
+        'previous_episodes': (
+            [ep.content for ep in previous_episodes] if previous_episodes is not None else []
+        ),
    }
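
Note on usage (not part of the diff): taken together, the hunks above retire the old single-prompt helpers (dedupe_extracted_nodes, resolve_extracted_node, dedupe_node_list) in favor of a two-stage pipeline, where _resolve_with_similarity settles deterministic matches first and only unresolved holdouts are escalated to the LLM via _resolve_with_llm. The sketch below is a minimal illustration of how a caller might drive the 0.24.3 signatures; it assumes an already-configured GraphitiClients, episode objects from an earlier ingestion step, and a made-up entity type name ('Preference').

from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.utils.maintenance.node_operations import (
    extract_attributes_from_nodes,
    extract_nodes,
    resolve_extracted_nodes,
)


async def summarize_only_new_nodes(node: EntityNode) -> bool:
    # NodeSummaryFilter: awaitable predicate; returning False skips the
    # summary prompt for this node (a hypothetical policy for illustration).
    return node.summary == ''


async def ingest(
    clients: GraphitiClients,
    episode: EpisodicNode,
    previous_episodes: list[EpisodicNode],
) -> list[EntityNode]:
    # excluded_entity_types is new in this version: matching extractions are
    # logged at debug level and dropped instead of becoming nodes.
    extracted = await extract_nodes(
        clients,
        episode,
        previous_episodes,
        excluded_entity_types=['Preference'],  # hypothetical type name
    )

    # resolve_extracted_nodes now returns a 3-tuple; the third element holds
    # (extracted, canonical) duplicate pairs that
    # filter_existing_duplicate_of_edges did not find already recorded.
    resolved, uuid_map, new_duplicate_pairs = await resolve_extracted_nodes(
        clients,
        extracted,
        episode=episode,
        previous_episodes=previous_episodes,
    )

    # should_summarize_node is also new; omitting it keeps the old behavior
    # of generating a summary for every node.
    return await extract_attributes_from_nodes(
        clients,
        resolved,
        episode=episode,
        previous_episodes=previous_episodes,
        should_summarize_node=summarize_only_new_nodes,
    )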