graphiti-core 0.17.4__py3-none-any.whl → 0.25.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. graphiti_core/cross_encoder/gemini_reranker_client.py +1 -1
  2. graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
  3. graphiti_core/decorators.py +110 -0
  4. graphiti_core/driver/driver.py +62 -2
  5. graphiti_core/driver/falkordb_driver.py +215 -23
  6. graphiti_core/driver/graph_operations/graph_operations.py +191 -0
  7. graphiti_core/driver/kuzu_driver.py +182 -0
  8. graphiti_core/driver/neo4j_driver.py +70 -8
  9. graphiti_core/driver/neptune_driver.py +305 -0
  10. graphiti_core/driver/search_interface/search_interface.py +89 -0
  11. graphiti_core/edges.py +264 -132
  12. graphiti_core/embedder/azure_openai.py +10 -3
  13. graphiti_core/embedder/client.py +2 -1
  14. graphiti_core/graph_queries.py +114 -101
  15. graphiti_core/graphiti.py +635 -260
  16. graphiti_core/graphiti_types.py +2 -0
  17. graphiti_core/helpers.py +37 -15
  18. graphiti_core/llm_client/anthropic_client.py +142 -52
  19. graphiti_core/llm_client/azure_openai_client.py +57 -19
  20. graphiti_core/llm_client/client.py +83 -21
  21. graphiti_core/llm_client/config.py +1 -1
  22. graphiti_core/llm_client/gemini_client.py +75 -57
  23. graphiti_core/llm_client/openai_base_client.py +92 -48
  24. graphiti_core/llm_client/openai_client.py +39 -9
  25. graphiti_core/llm_client/openai_generic_client.py +91 -56
  26. graphiti_core/models/edges/edge_db_queries.py +259 -35
  27. graphiti_core/models/nodes/node_db_queries.py +311 -32
  28. graphiti_core/nodes.py +388 -164
  29. graphiti_core/prompts/dedupe_edges.py +42 -31
  30. graphiti_core/prompts/dedupe_nodes.py +56 -39
  31. graphiti_core/prompts/eval.py +4 -4
  32. graphiti_core/prompts/extract_edges.py +24 -15
  33. graphiti_core/prompts/extract_nodes.py +76 -35
  34. graphiti_core/prompts/prompt_helpers.py +39 -0
  35. graphiti_core/prompts/snippets.py +29 -0
  36. graphiti_core/prompts/summarize_nodes.py +23 -25
  37. graphiti_core/search/search.py +154 -74
  38. graphiti_core/search/search_config.py +39 -4
  39. graphiti_core/search/search_filters.py +110 -31
  40. graphiti_core/search/search_helpers.py +5 -6
  41. graphiti_core/search/search_utils.py +1360 -473
  42. graphiti_core/tracer.py +193 -0
  43. graphiti_core/utils/bulk_utils.py +216 -90
  44. graphiti_core/utils/content_chunking.py +702 -0
  45. graphiti_core/utils/datetime_utils.py +13 -0
  46. graphiti_core/utils/maintenance/community_operations.py +62 -38
  47. graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
  48. graphiti_core/utils/maintenance/edge_operations.py +306 -156
  49. graphiti_core/utils/maintenance/graph_data_operations.py +44 -74
  50. graphiti_core/utils/maintenance/node_operations.py +466 -206
  51. graphiti_core/utils/maintenance/temporal_operations.py +11 -3
  52. graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
  53. graphiti_core/utils/text_utils.py +53 -0
  54. {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/METADATA +221 -87
  55. graphiti_core-0.25.3.dist-info/RECORD +87 -0
  56. {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/WHEEL +1 -1
  57. graphiti_core-0.17.4.dist-info/RECORD +0 -77
  58. /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
  59. {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/licenses/LICENSE +0 -0
graphiti_core/utils/maintenance/node_operations.py
@@ -15,22 +15,26 @@ limitations under the License.
  """

  import logging
- from contextlib import suppress
+ from collections.abc import Awaitable, Callable
  from time import time
  from typing import Any
- from uuid import uuid4

- import pydantic
- from pydantic import BaseModel, Field
+ from pydantic import BaseModel

  from graphiti_core.graphiti_types import GraphitiClients
- from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
+ from graphiti_core.helpers import semaphore_gather
  from graphiti_core.llm_client import LLMClient
  from graphiti_core.llm_client.config import ModelSize
- from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
+ from graphiti_core.nodes import (
+     EntityNode,
+     EpisodeType,
+     EpisodicNode,
+     create_entity_node_embeddings,
+ )
  from graphiti_core.prompts import prompt_library
- from graphiti_core.prompts.dedupe_nodes import NodeResolutions
+ from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
  from graphiti_core.prompts.extract_nodes import (
+     EntitySummary,
      ExtractedEntities,
      ExtractedEntity,
      MissedEntities,
@@ -39,17 +43,35 @@ from graphiti_core.search.search import search
  from graphiti_core.search.search_config import SearchResults
  from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
  from graphiti_core.search.search_filters import SearchFilters
+ from graphiti_core.utils.content_chunking import (
+     chunk_json_content,
+     chunk_message_content,
+     chunk_text_content,
+     should_chunk,
+ )
  from graphiti_core.utils.datetime_utils import utc_now
- from graphiti_core.utils.maintenance.edge_operations import filter_existing_duplicate_of_edges
+ from graphiti_core.utils.maintenance.dedup_helpers import (
+     DedupCandidateIndexes,
+     DedupResolutionState,
+     _build_candidate_indexes,
+     _resolve_with_similarity,
+ )
+ from graphiti_core.utils.maintenance.edge_operations import (
+     filter_existing_duplicate_of_edges,
+ )
+ from graphiti_core.utils.text_utils import MAX_SUMMARY_CHARS, truncate_at_sentence

  logger = logging.getLogger(__name__)

+ NodeSummaryFilter = Callable[[EntityNode], Awaitable[bool]]
+

  async def extract_nodes_reflexion(
      llm_client: LLMClient,
      episode: EpisodicNode,
      previous_episodes: list[EpisodicNode],
      node_names: list[str],
+     group_id: str | None = None,
  ) -> list[str]:
      # Prepare context for LLM
      context = {
@@ -59,7 +81,10 @@ async def extract_nodes_reflexion(
      }

      llm_response = await llm_client.generate_response(
-         prompt_library.extract_nodes.reflexion(context), MissedEntities
+         prompt_library.extract_nodes.reflexion(context),
+         MissedEntities,
+         group_id=group_id,
+         prompt_name='extract_nodes.reflexion',
      )
      missed_entities = llm_response.get('missed_entities', [])

@@ -70,26 +95,69 @@ async def extract_nodes(
      clients: GraphitiClients,
      episode: EpisodicNode,
      previous_episodes: list[EpisodicNode],
-     entity_types: dict[str, BaseModel] | None = None,
+     entity_types: dict[str, type[BaseModel]] | None = None,
      excluded_entity_types: list[str] | None = None,
+     custom_extraction_instructions: str | None = None,
  ) -> list[EntityNode]:
+     """Extract entity nodes from an episode with adaptive chunking.
+
+     For high-density content (many entities per token), the content is chunked
+     and processed in parallel to avoid LLM timeouts and truncation issues.
+     """
      start = time()
      llm_client = clients.llm_client
-     llm_response = {}
-     custom_prompt = ''
-     entities_missed = True
-     reflexion_iterations = 0

+     # Build entity types context
+     entity_types_context = _build_entity_types_context(entity_types)
+
+     # Build base context
+     context = {
+         'episode_content': episode.content,
+         'episode_timestamp': episode.valid_at.isoformat(),
+         'previous_episodes': [ep.content for ep in previous_episodes],
+         'custom_extraction_instructions': custom_extraction_instructions or '',
+         'entity_types': entity_types_context,
+         'source_description': episode.source_description,
+     }
+
+     # Check if chunking is needed (based on entity density)
+     if should_chunk(episode.content, episode.source):
+         extracted_entities = await _extract_nodes_chunked(llm_client, episode, context)
+     else:
+         extracted_entities = await _extract_nodes_single(llm_client, episode, context)
+
+     # Filter empty names
+     filtered_entities = [e for e in extracted_entities if e.name.strip()]
+
+     end = time()
+     logger.debug(f'Extracted {len(filtered_entities)} entities in {(end - start) * 1000:.0f} ms')
+
+     # Convert to EntityNode objects
+     extracted_nodes = _create_entity_nodes(
+         filtered_entities, entity_types_context, excluded_entity_types, episode
+     )
+
+     logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
+     return extracted_nodes
+
+
+ def _build_entity_types_context(
+     entity_types: dict[str, type[BaseModel]] | None,
+ ) -> list[dict]:
+     """Build entity types context with ID mappings."""
      entity_types_context = [
          {
              'entity_type_id': 0,
              'entity_type_name': 'Entity',
-             'entity_type_description': 'Default entity classification. Use this entity type if the entity is not one of the other listed types.',
+             'entity_type_description': (
+                 'Default entity classification. Use this entity type '
+                 'if the entity is not one of the other listed types.'
+             ),
          }
      ]

-     entity_types_context += (
-         [
+     if entity_types is not None:
+         entity_types_context += [
              {
                  'entity_type_id': i + 1,
                  'entity_type_name': type_name,
@@ -97,63 +165,126 @@ async def extract_nodes(
              }
              for i, (type_name, type_model) in enumerate(entity_types.items())
          ]
-         if entity_types is not None
-         else []
+
+     return entity_types_context
+
+
+ async def _extract_nodes_single(
+     llm_client: LLMClient,
+     episode: EpisodicNode,
+     context: dict,
+ ) -> list[ExtractedEntity]:
+     """Extract entities using a single LLM call."""
+     llm_response = await _call_extraction_llm(llm_client, episode, context)
+     response_object = ExtractedEntities(**llm_response)
+     return response_object.extracted_entities
+
+
+ async def _extract_nodes_chunked(
+     llm_client: LLMClient,
+     episode: EpisodicNode,
+     context: dict,
+ ) -> list[ExtractedEntity]:
+     """Extract entities from large content using chunking."""
+     # Chunk the content based on episode type
+     if episode.source == EpisodeType.json:
+         chunks = chunk_json_content(episode.content)
+     elif episode.source == EpisodeType.message:
+         chunks = chunk_message_content(episode.content)
+     else:
+         chunks = chunk_text_content(episode.content)
+
+     logger.debug(f'Chunked content into {len(chunks)} chunks for entity extraction')
+
+     # Extract entities from each chunk in parallel
+     chunk_results = await semaphore_gather(
+         *[_extract_from_chunk(llm_client, chunk, context, episode) for chunk in chunks]
      )

-     context = {
-         'episode_content': episode.content,
-         'episode_timestamp': episode.valid_at.isoformat(),
-         'previous_episodes': [ep.content for ep in previous_episodes],
-         'custom_prompt': custom_prompt,
-         'entity_types': entity_types_context,
-         'source_description': episode.source_description,
-     }
+     # Merge and deduplicate entities across chunks
+     merged_entities = _merge_extracted_entities(chunk_results)
+     logger.debug(
+         f'Merged {sum(len(r) for r in chunk_results)} entities into {len(merged_entities)} unique'
+     )

-     while entities_missed and reflexion_iterations <= MAX_REFLEXION_ITERATIONS:
-         if episode.source == EpisodeType.message:
-             llm_response = await llm_client.generate_response(
-                 prompt_library.extract_nodes.extract_message(context),
-                 response_model=ExtractedEntities,
-             )
-         elif episode.source == EpisodeType.text:
-             llm_response = await llm_client.generate_response(
-                 prompt_library.extract_nodes.extract_text(context), response_model=ExtractedEntities
-             )
-         elif episode.source == EpisodeType.json:
-             llm_response = await llm_client.generate_response(
-                 prompt_library.extract_nodes.extract_json(context), response_model=ExtractedEntities
-             )
+     return merged_entities

-         extracted_entities: list[ExtractedEntity] = [
-             ExtractedEntity(**entity_types_context)
-             for entity_types_context in llm_response.get('extracted_entities', [])
-         ]

-         reflexion_iterations += 1
-         if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
-             missing_entities = await extract_nodes_reflexion(
-                 llm_client,
-                 episode,
-                 previous_episodes,
-                 [entity.name for entity in extracted_entities],
-             )
+ async def _extract_from_chunk(
+     llm_client: LLMClient,
+     chunk: str,
+     base_context: dict,
+     episode: EpisodicNode,
+ ) -> list[ExtractedEntity]:
+     """Extract entities from a single chunk."""
+     chunk_context = {**base_context, 'episode_content': chunk}
+     llm_response = await _call_extraction_llm(llm_client, episode, chunk_context)
+     return ExtractedEntities(**llm_response).extracted_entities
+
+
+ async def _call_extraction_llm(
+     llm_client: LLMClient,
+     episode: EpisodicNode,
+     context: dict,
+ ) -> dict:
+     """Call the appropriate extraction prompt based on episode type."""
+     if episode.source == EpisodeType.message:
+         prompt = prompt_library.extract_nodes.extract_message(context)
+         prompt_name = 'extract_nodes.extract_message'
+     elif episode.source == EpisodeType.text:
+         prompt = prompt_library.extract_nodes.extract_text(context)
+         prompt_name = 'extract_nodes.extract_text'
+     elif episode.source == EpisodeType.json:
+         prompt = prompt_library.extract_nodes.extract_json(context)
+         prompt_name = 'extract_nodes.extract_json'
+     else:
+         # Fallback to text extraction
+         prompt = prompt_library.extract_nodes.extract_text(context)
+         prompt_name = 'extract_nodes.extract_text'
+
+     return await llm_client.generate_response(
+         prompt,
+         response_model=ExtractedEntities,
+         group_id=episode.group_id,
+         prompt_name=prompt_name,
+     )

-             entities_missed = len(missing_entities) != 0

-             custom_prompt = 'Make sure that the following entities are extracted: '
-             for entity in missing_entities:
-                 custom_prompt += f'\n{entity},'
+ def _merge_extracted_entities(
+     chunk_results: list[list[ExtractedEntity]],
+ ) -> list[ExtractedEntity]:
+     """Merge entities from multiple chunks, deduplicating by normalized name.

-     filtered_extracted_entities = [entity for entity in extracted_entities if entity.name.strip()]
-     end = time()
-     logger.debug(f'Extracted new nodes: {filtered_extracted_entities} in {(end - start) * 1000} ms')
-     # Convert the extracted data into EntityNode objects
+     When duplicates occur, prefer the first occurrence (maintains ordering).
+     """
+     seen_names: set[str] = set()
+     merged: list[ExtractedEntity] = []
+
+     for entities in chunk_results:
+         for entity in entities:
+             normalized = entity.name.strip().lower()
+             if normalized and normalized not in seen_names:
+                 seen_names.add(normalized)
+                 merged.append(entity)
+
+     return merged
+
+
+ def _create_entity_nodes(
+     extracted_entities: list[ExtractedEntity],
+     entity_types_context: list[dict],
+     excluded_entity_types: list[str] | None,
+     episode: EpisodicNode,
+ ) -> list[EntityNode]:
+     """Convert ExtractedEntity objects to EntityNode objects."""
      extracted_nodes = []
-     for extracted_entity in filtered_extracted_entities:
-         entity_type_name = entity_types_context[extracted_entity.entity_type_id].get(
-             'entity_type_name'
-         )
+
+     for extracted_entity in extracted_entities:
+         type_id = extracted_entity.entity_type_id
+         if 0 <= type_id < len(entity_types_context):
+             entity_type_name = entity_types_context[type_id].get('entity_type_name')
+         else:
+             entity_type_name = 'Entity'

          # Check if this entity type should be excluded
          if excluded_entity_types and entity_type_name in excluded_entity_types:
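
Note on the hunk above: the chunked path is a fan-out/fan-in pattern — split the episode content, extract entities from each chunk concurrently, then merge results with a first-occurrence-wins dedup on the normalized name, exactly as _merge_extracted_entities does. The self-contained sketch below illustrates that pattern only; split_into_chunks and extract_names are hypothetical stand-ins, not graphiti-core APIs (the real code routes through semaphore_gather, the format-aware chunkers in content_chunking.py, and the prompt library).

import asyncio

def split_into_chunks(text: str, max_chars: int = 2000) -> list[str]:
    # Illustrative splitter: fixed-size windows; the real chunkers are
    # format-aware (text/message/JSON) and live in content_chunking.py.
    return [text[i:i + max_chars] for i in range(0, len(text), max_chars)] or ['']

async def extract_names(chunk: str) -> list[str]:
    # Stand-in for the per-chunk LLM call: treat capitalized words as names.
    return [w.strip('.,') for w in chunk.split() if w[:1].isupper()]

def merge_first_wins(chunk_results: list[list[str]]) -> list[str]:
    # Mirror of the dedup rule in _merge_extracted_entities: normalize by
    # lowercased, stripped name and keep the first occurrence.
    seen: set[str] = set()
    merged: list[str] = []
    for names in chunk_results:
        for name in names:
            key = name.strip().lower()
            if key and key not in seen:
                seen.add(key)
                merged.append(name)
    return merged

async def extract_all(text: str) -> list[str]:
    chunks = split_into_chunks(text, max_chars=16)
    results = await asyncio.gather(*(extract_names(c) for c in chunks))
    return merge_first_wins(list(results))

print(asyncio.run(extract_all('Alice met Bob in Paris. alice emailed Bob.')))
# -> ['Alice', 'Bob', 'Paris']: duplicates collapse case-insensitively across chunks
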
@@ -172,21 +303,15 @@ async def extract_nodes(
          extracted_nodes.append(new_node)
          logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')

-     logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
      return extracted_nodes


- async def resolve_extracted_nodes(
+ async def _collect_candidate_nodes(
      clients: GraphitiClients,
      extracted_nodes: list[EntityNode],
-     episode: EpisodicNode | None = None,
-     previous_episodes: list[EpisodicNode] | None = None,
-     entity_types: dict[str, BaseModel] | None = None,
-     existing_nodes_override: list[EntityNode] | None = None,
- ) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
-     llm_client = clients.llm_client
-     driver = clients.driver
-
+     existing_nodes_override: list[EntityNode] | None,
+ ) -> list[EntityNode]:
+     """Search per extracted name and return unique candidates with overrides honored in order."""
      search_results: list[SearchResults] = await semaphore_gather(
          *[
              search(
@@ -200,33 +325,43 @@ async def resolve_extracted_nodes(
          ]
      )

-     candidate_nodes: list[EntityNode] = (
-         [node for result in search_results for node in result.nodes]
-         if existing_nodes_override is None
-         else existing_nodes_override
-     )
+     candidate_nodes: list[EntityNode] = [node for result in search_results for node in result.nodes]

-     existing_nodes_dict: dict[str, EntityNode] = {node.uuid: node for node in candidate_nodes}
+     if existing_nodes_override is not None:
+         candidate_nodes.extend(existing_nodes_override)

-     existing_nodes: list[EntityNode] = list(existing_nodes_dict.values())
+     seen_candidate_uuids: set[str] = set()
+     ordered_candidates: list[EntityNode] = []
+     for candidate in candidate_nodes:
+         if candidate.uuid in seen_candidate_uuids:
+             continue
+         seen_candidate_uuids.add(candidate.uuid)
+         ordered_candidates.append(candidate)

-     existing_nodes_context = (
-         [
-             {
-                 **{
-                     'idx': i,
-                     'name': candidate.name,
-                     'entity_types': candidate.labels,
-                 },
-                 **candidate.attributes,
-             }
-             for i, candidate in enumerate(existing_nodes)
-         ],
-     )
+     return ordered_candidates

-     entity_types_dict: dict[str, BaseModel] = entity_types if entity_types is not None else {}

-     # Prepare context for LLM
+ async def _resolve_with_llm(
+     llm_client: LLMClient,
+     extracted_nodes: list[EntityNode],
+     indexes: DedupCandidateIndexes,
+     state: DedupResolutionState,
+     episode: EpisodicNode | None,
+     previous_episodes: list[EpisodicNode] | None,
+     entity_types: dict[str, type[BaseModel]] | None,
+ ) -> None:
+     """Escalate unresolved nodes to the dedupe prompt so the LLM can select or reject duplicates.
+
+     The guardrails below defensively ignore malformed or duplicate LLM responses so the
+     ingestion workflow remains deterministic even when the model misbehaves.
+     """
+     if not state.unresolved_indices:
+         return
+
+     entity_types_dict: dict[str, type[BaseModel]] = entity_types if entity_types is not None else {}
+
+     llm_extracted_nodes = [extracted_nodes[i] for i in state.unresolved_indices]
+
      extracted_nodes_context = [
          {
              'id': i,
@@ -237,60 +372,180 @@ async def resolve_extracted_nodes(
              ).__doc__
              or 'Default Entity Type',
          }
-         for i, node in enumerate(extracted_nodes)
+         for i, node in enumerate(llm_extracted_nodes)
+     ]
+
+     sent_ids = [ctx['id'] for ctx in extracted_nodes_context]
+     logger.debug(
+         'Sending %d entities to LLM for deduplication with IDs 0-%d (actual IDs sent: %s)',
+         len(llm_extracted_nodes),
+         len(llm_extracted_nodes) - 1,
+         sent_ids if len(sent_ids) < 20 else f'{sent_ids[:10]}...{sent_ids[-10:]}',
+     )
+     if llm_extracted_nodes:
+         sample_size = min(3, len(extracted_nodes_context))
+         logger.debug(
+             'First %d entities: %s',
+             sample_size,
+             [(ctx['id'], ctx['name']) for ctx in extracted_nodes_context[:sample_size]],
+         )
+         if len(extracted_nodes_context) > 3:
+             logger.debug(
+                 'Last %d entities: %s',
+                 sample_size,
+                 [(ctx['id'], ctx['name']) for ctx in extracted_nodes_context[-sample_size:]],
+             )
+
+     existing_nodes_context = [
+         {
+             **{
+                 'idx': i,
+                 'name': candidate.name,
+                 'entity_types': candidate.labels,
+             },
+             **candidate.attributes,
+         }
+         for i, candidate in enumerate(indexes.existing_nodes)
      ]

      context = {
          'extracted_nodes': extracted_nodes_context,
          'existing_nodes': existing_nodes_context,
          'episode_content': episode.content if episode is not None else '',
-         'previous_episodes': [ep.content for ep in previous_episodes]
-         if previous_episodes is not None
-         else [],
+         'previous_episodes': (
+             [ep.content for ep in previous_episodes] if previous_episodes is not None else []
+         ),
      }

      llm_response = await llm_client.generate_response(
          prompt_library.dedupe_nodes.nodes(context),
          response_model=NodeResolutions,
+         prompt_name='dedupe_nodes.nodes',
      )

-     node_resolutions: list = llm_response.get('entity_resolutions', [])
+     node_resolutions: list[NodeDuplicate] = NodeResolutions(**llm_response).entity_resolutions

-     resolved_nodes: list[EntityNode] = []
-     uuid_map: dict[str, str] = {}
-     node_duplicates: list[tuple[EntityNode, EntityNode]] = []
-     for resolution in node_resolutions:
-         resolution_id: int = resolution.get('id', -1)
-         duplicate_idx: int = resolution.get('duplicate_idx', -1)
+     valid_relative_range = range(len(state.unresolved_indices))
+     processed_relative_ids: set[int] = set()
+
+     received_ids = {r.id for r in node_resolutions}
+     expected_ids = set(valid_relative_range)
+     missing_ids = expected_ids - received_ids
+     extra_ids = received_ids - expected_ids

-         extracted_node = extracted_nodes[resolution_id]
+     logger.debug(
+         'Received %d resolutions for %d entities',
+         len(node_resolutions),
+         len(state.unresolved_indices),
+     )

-         resolved_node = (
-             existing_nodes[duplicate_idx]
-             if 0 <= duplicate_idx < len(existing_nodes)
-             else extracted_node
+     if missing_ids:
+         logger.warning('LLM did not return resolutions for IDs: %s', sorted(missing_ids))
+
+     if extra_ids:
+         logger.warning(
+             'LLM returned invalid IDs outside valid range 0-%d: %s (all returned IDs: %s)',
+             len(state.unresolved_indices) - 1,
+             sorted(extra_ids),
+             sorted(received_ids),
          )

-         # resolved_node.name = resolution.get('name')
+     for resolution in node_resolutions:
+         relative_id: int = resolution.id
+         duplicate_idx: int = resolution.duplicate_idx
+
+         if relative_id not in valid_relative_range:
+             logger.warning(
+                 'Skipping invalid LLM dedupe id %d (valid range: 0-%d, received %d resolutions)',
+                 relative_id,
+                 len(state.unresolved_indices) - 1,
+                 len(node_resolutions),
+             )
+             continue
+
+         if relative_id in processed_relative_ids:
+             logger.warning('Duplicate LLM dedupe id %s received; ignoring.', relative_id)
+             continue
+         processed_relative_ids.add(relative_id)
+
+         original_index = state.unresolved_indices[relative_id]
+         extracted_node = extracted_nodes[original_index]
+
+         resolved_node: EntityNode
+         if duplicate_idx == -1:
+             resolved_node = extracted_node
+         elif 0 <= duplicate_idx < len(indexes.existing_nodes):
+             resolved_node = indexes.existing_nodes[duplicate_idx]
+         else:
+             logger.warning(
+                 'Invalid duplicate_idx %s for extracted node %s; treating as no duplicate.',
+                 duplicate_idx,
+                 extracted_node.uuid,
+             )
+             resolved_node = extracted_node
+
+         state.resolved_nodes[original_index] = resolved_node
+         state.uuid_map[extracted_node.uuid] = resolved_node.uuid
+         if resolved_node.uuid != extracted_node.uuid:
+             state.duplicate_pairs.append((extracted_node, resolved_node))
+
+
+ async def resolve_extracted_nodes(
+     clients: GraphitiClients,
+     extracted_nodes: list[EntityNode],
+     episode: EpisodicNode | None = None,
+     previous_episodes: list[EpisodicNode] | None = None,
+     entity_types: dict[str, type[BaseModel]] | None = None,
+     existing_nodes_override: list[EntityNode] | None = None,
+ ) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
+     """Search for existing nodes, resolve deterministic matches, then escalate holdouts to the LLM dedupe prompt."""
+     llm_client = clients.llm_client
+     driver = clients.driver
+     existing_nodes = await _collect_candidate_nodes(
+         clients,
+         extracted_nodes,
+         existing_nodes_override,
+     )

-         resolved_nodes.append(resolved_node)
-         uuid_map[extracted_node.uuid] = resolved_node.uuid
+     indexes: DedupCandidateIndexes = _build_candidate_indexes(existing_nodes)

-         duplicates: list[int] = resolution.get('duplicates', [])
-         if duplicate_idx not in duplicates and duplicate_idx > -1:
-             duplicates.append(duplicate_idx)
-         for idx in duplicates:
-             existing_node = existing_nodes[idx] if idx < len(existing_nodes) else resolved_node
+     state = DedupResolutionState(
+         resolved_nodes=[None] * len(extracted_nodes),
+         uuid_map={},
+         unresolved_indices=[],
+     )

-             node_duplicates.append((extracted_node, existing_node))
+     _resolve_with_similarity(extracted_nodes, indexes, state)

-     logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}')
+     await _resolve_with_llm(
+         llm_client,
+         extracted_nodes,
+         indexes,
+         state,
+         episode,
+         previous_episodes,
+         entity_types,
+     )
+
+     for idx, node in enumerate(extracted_nodes):
+         if state.resolved_nodes[idx] is None:
+             state.resolved_nodes[idx] = node
+             state.uuid_map[node.uuid] = node.uuid
+
+     logger.debug(
+         'Resolved nodes: %s',
+         [(node.name, node.uuid) for node in state.resolved_nodes if node is not None],
+     )

      new_node_duplicates: list[
          tuple[EntityNode, EntityNode]
-     ] = await filter_existing_duplicate_of_edges(driver, node_duplicates)
+     ] = await filter_existing_duplicate_of_edges(driver, state.duplicate_pairs)

-     return resolved_nodes, uuid_map, new_node_duplicates
+     return (
+         [node for node in state.resolved_nodes if node is not None],
+         state.uuid_map,
+         new_node_duplicates,
+     )


  async def extract_attributes_from_nodes(
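
Note on the hunk above: resolution is now two-phase. _resolve_with_similarity fills deterministic matches first, and only the holdout indices go to the LLM, which sees relative IDs 0..K-1 that are mapped back through state.unresolved_indices. The DedupResolutionState and DedupCandidateIndexes classes live in the new dedup_helpers.py, which is not part of this excerpt; the dataclass below is a hypothetical minimal mirror, shown only to make the relative-ID bookkeeping concrete.

from dataclasses import dataclass, field

@dataclass
class Node:
    uuid: str
    name: str

@dataclass
class ResolutionState:
    # Hypothetical mirror of DedupResolutionState: one slot per extracted node.
    resolved: list[Node | None]
    uuid_map: dict[str, str] = field(default_factory=dict)
    unresolved_indices: list[int] = field(default_factory=list)
    duplicate_pairs: list[tuple[Node, Node]] = field(default_factory=list)

def apply_llm_resolution(state: ResolutionState, extracted: list[Node],
                         existing: list[Node], relative_id: int, duplicate_idx: int) -> None:
    # relative_id indexes the unresolved subset sent to the LLM, so it must be
    # translated back to the original extraction index before writing state.
    if relative_id not in range(len(state.unresolved_indices)):
        return  # guardrail: ignore IDs outside the range that was sent
    original_index = state.unresolved_indices[relative_id]
    node = extracted[original_index]
    match = existing[duplicate_idx] if 0 <= duplicate_idx < len(existing) else node
    state.resolved[original_index] = match
    state.uuid_map[node.uuid] = match.uuid
    if match.uuid != node.uuid:
        state.duplicate_pairs.append((node, match))

# Example: three extracted nodes, the middle one left unresolved by the similarity pass.
extracted = [Node('e0', 'Alice'), Node('e1', 'ACME Corp'), Node('e2', 'Bob')]
existing = [Node('x0', 'Acme Corporation')]
state = ResolutionState(resolved=[extracted[0], None, extracted[2]], unresolved_indices=[1])
apply_llm_resolution(state, extracted, existing, relative_id=0, duplicate_idx=0)
print(state.uuid_map)  # {'e1': 'x0'}: extracted 'ACME Corp' resolved to existing 'Acme Corporation'
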
@@ -298,7 +553,8 @@ async def extract_attributes_from_nodes(
      nodes: list[EntityNode],
      episode: EpisodicNode | None = None,
      previous_episodes: list[EpisodicNode] | None = None,
-     entity_types: dict[str, BaseModel] | None = None,
+     entity_types: dict[str, type[BaseModel]] | None = None,
+     should_summarize_node: NodeSummaryFilter | None = None,
  ) -> list[EntityNode]:
      llm_client = clients.llm_client
      embedder = clients.embedder
@@ -309,9 +565,12 @@ async def extract_attributes_from_nodes(
                  node,
                  episode,
                  previous_episodes,
-                 entity_types.get(next((item for item in node.labels if item != 'Entity'), ''))
-                 if entity_types is not None
-                 else None,
+                 (
+                     entity_types.get(next((item for item in node.labels if item != 'Entity'), ''))
+                     if entity_types is not None
+                     else None
+                 ),
+                 should_summarize_node,
              )
              for node in nodes
          ]
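
Note on the hunk above: the new should_summarize_node parameter threads a NodeSummaryFilter (the async EntityNode -> bool alias added at the top of the file) down to the per-node summary step, so callers can skip summary regeneration selectively. A small illustrative callback, assuming only that EntityNode exposes the summary and labels attributes used elsewhere in this diff:

# Hypothetical caller-side filter: only summarize nodes that carry a custom
# label and do not already have a summary.
async def summarize_only_new_typed_nodes(node) -> bool:
    has_custom_label = any(label != 'Entity' for label in node.labels)
    return has_custom_label and not node.summary

# Sketch of how it would be passed through (argument order per the hunks above):
#   await extract_attributes_from_nodes(
#       clients, nodes, episode, previous_episodes, entity_types,
#       should_summarize_node=summarize_only_new_typed_nodes,
#   )
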
@@ -327,99 +586,100 @@ async def extract_attributes_from_node(
      node: EntityNode,
      episode: EpisodicNode | None = None,
      previous_episodes: list[EpisodicNode] | None = None,
-     entity_type: BaseModel | None = None,
+     entity_type: type[BaseModel] | None = None,
+     should_summarize_node: NodeSummaryFilter | None = None,
  ) -> EntityNode:
-     node_context: dict[str, Any] = {
-         'name': node.name,
-         'summary': node.summary,
-         'entity_types': node.labels,
-         'attributes': node.attributes,
-     }
+     # Extract attributes if entity type is defined and has attributes
+     llm_response = await _extract_entity_attributes(
+         llm_client, node, episode, previous_episodes, entity_type
+     )

-     attributes_definitions: dict[str, Any] = {
-         'summary': (
-             str,
-             Field(
-                 description='Summary containing the important information about the entity. Under 250 words',
-             ),
-         )
-     }
+     # Extract summary if needed
+     await _extract_entity_summary(
+         llm_client, node, episode, previous_episodes, should_summarize_node
+     )

-     if entity_type is not None:
-         for field_name, field_info in entity_type.model_fields.items():
-             attributes_definitions[field_name] = (
-                 field_info.annotation,
-                 Field(description=field_info.description),
-             )
+     node.attributes.update(llm_response)

-     unique_model_name = f'EntityAttributes_{uuid4().hex}'
-     entity_attributes_model = pydantic.create_model(unique_model_name, **attributes_definitions)
+     return node

-     summary_context: dict[str, Any] = {
-         'node': node_context,
-         'episode_content': episode.content if episode is not None else '',
-         'previous_episodes': [ep.content for ep in previous_episodes]
-         if previous_episodes is not None
-         else [],
-     }
+
+
+ async def _extract_entity_attributes(
+     llm_client: LLMClient,
+     node: EntityNode,
+     episode: EpisodicNode | None,
+     previous_episodes: list[EpisodicNode] | None,
+     entity_type: type[BaseModel] | None,
+ ) -> dict[str, Any]:
+     if entity_type is None or len(entity_type.model_fields) == 0:
+         return {}
+
+     attributes_context = _build_episode_context(
+         # should not include summary
+         node_data={
+             'name': node.name,
+             'entity_types': node.labels,
+             'attributes': node.attributes,
+         },
+         episode=episode,
+         previous_episodes=previous_episodes,
+     )

      llm_response = await llm_client.generate_response(
-         prompt_library.extract_nodes.extract_attributes(summary_context),
-         response_model=entity_attributes_model,
+         prompt_library.extract_nodes.extract_attributes(attributes_context),
+         response_model=entity_type,
          model_size=ModelSize.small,
+         group_id=node.group_id,
+         prompt_name='extract_nodes.extract_attributes',
      )

-     node.summary = llm_response.get('summary', node.summary)
-     node_attributes = {key: value for key, value in llm_response.items()}
-
-     with suppress(KeyError):
-         del node_attributes['summary']
+     # validate response
+     entity_type(**llm_response)

-     node.attributes.update(node_attributes)
-
-     return node
+     return llm_response


- async def dedupe_node_list(
+ async def _extract_entity_summary(
      llm_client: LLMClient,
-     nodes: list[EntityNode],
- ) -> tuple[list[EntityNode], dict[str, str]]:
-     start = time()
-
-     # build node map
-     node_map = {}
-     for node in nodes:
-         node_map[node.uuid] = node
-
-     # Prepare context for LLM
-     nodes_context = [{'uuid': node.uuid, 'name': node.name, **node.attributes} for node in nodes]
-
-     context = {
-         'nodes': nodes_context,
-     }
-
-     llm_response = await llm_client.generate_response(
-         prompt_library.dedupe_nodes.node_list(context)
+     node: EntityNode,
+     episode: EpisodicNode | None,
+     previous_episodes: list[EpisodicNode] | None,
+     should_summarize_node: NodeSummaryFilter | None,
+ ) -> None:
+     if should_summarize_node is not None and not await should_summarize_node(node):
+         return
+
+     summary_context = _build_episode_context(
+         node_data={
+             'name': node.name,
+             'summary': truncate_at_sentence(node.summary, MAX_SUMMARY_CHARS),
+             'entity_types': node.labels,
+             'attributes': node.attributes,
+         },
+         episode=episode,
+         previous_episodes=previous_episodes,
      )

-     nodes_data = llm_response.get('nodes', [])
+     summary_response = await llm_client.generate_response(
+         prompt_library.extract_nodes.extract_summary(summary_context),
+         response_model=EntitySummary,
+         model_size=ModelSize.small,
+         group_id=node.group_id,
+         prompt_name='extract_nodes.extract_summary',
+     )

-     end = time()
-     logger.debug(f'Deduplicated nodes: {nodes_data} in {(end - start) * 1000} ms')
-
-     # Get full node data
-     unique_nodes = []
-     uuid_map: dict[str, str] = {}
-     for node_data in nodes_data:
-         node_instance: EntityNode | None = node_map.get(node_data['uuids'][0])
-         if node_instance is None:
-             logger.warning(f'Node {node_data["uuids"][0]} not found in node map')
-             continue
-         node_instance.summary = node_data['summary']
-         unique_nodes.append(node_instance)
+     node.summary = truncate_at_sentence(summary_response.get('summary', ''), MAX_SUMMARY_CHARS)

-         for uuid in node_data['uuids'][1:]:
-             uuid_value = node_map[node_data['uuids'][0]].uuid
-             uuid_map[uuid] = uuid_value

-     return unique_nodes, uuid_map
+ def _build_episode_context(
+     node_data: dict[str, Any],
+     episode: EpisodicNode | None,
+     previous_episodes: list[EpisodicNode] | None,
+ ) -> dict[str, Any]:
+     return {
+         'node': node_data,
+         'episode_content': episode.content if episode is not None else '',
+         'previous_episodes': (
+             [ep.content for ep in previous_episodes] if previous_episodes is not None else []
+         ),
+     }