graphiti-core 0.17.4__py3-none-any.whl → 0.24.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58) hide show
  1. graphiti_core/cross_encoder/gemini_reranker_client.py +1 -1
  2. graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
  3. graphiti_core/decorators.py +110 -0
  4. graphiti_core/driver/driver.py +62 -2
  5. graphiti_core/driver/falkordb_driver.py +215 -23
  6. graphiti_core/driver/graph_operations/graph_operations.py +191 -0
  7. graphiti_core/driver/kuzu_driver.py +182 -0
  8. graphiti_core/driver/neo4j_driver.py +61 -8
  9. graphiti_core/driver/neptune_driver.py +305 -0
  10. graphiti_core/driver/search_interface/search_interface.py +89 -0
  11. graphiti_core/edges.py +264 -132
  12. graphiti_core/embedder/azure_openai.py +10 -3
  13. graphiti_core/embedder/client.py +2 -1
  14. graphiti_core/graph_queries.py +114 -101
  15. graphiti_core/graphiti.py +582 -255
  16. graphiti_core/graphiti_types.py +2 -0
  17. graphiti_core/helpers.py +21 -14
  18. graphiti_core/llm_client/anthropic_client.py +142 -52
  19. graphiti_core/llm_client/azure_openai_client.py +57 -19
  20. graphiti_core/llm_client/client.py +83 -21
  21. graphiti_core/llm_client/config.py +1 -1
  22. graphiti_core/llm_client/gemini_client.py +75 -57
  23. graphiti_core/llm_client/openai_base_client.py +94 -50
  24. graphiti_core/llm_client/openai_client.py +28 -8
  25. graphiti_core/llm_client/openai_generic_client.py +91 -56
  26. graphiti_core/models/edges/edge_db_queries.py +259 -35
  27. graphiti_core/models/nodes/node_db_queries.py +311 -32
  28. graphiti_core/nodes.py +388 -164
  29. graphiti_core/prompts/dedupe_edges.py +42 -31
  30. graphiti_core/prompts/dedupe_nodes.py +56 -39
  31. graphiti_core/prompts/eval.py +4 -4
  32. graphiti_core/prompts/extract_edges.py +23 -14
  33. graphiti_core/prompts/extract_nodes.py +73 -32
  34. graphiti_core/prompts/prompt_helpers.py +39 -0
  35. graphiti_core/prompts/snippets.py +29 -0
  36. graphiti_core/prompts/summarize_nodes.py +23 -25
  37. graphiti_core/search/search.py +154 -74
  38. graphiti_core/search/search_config.py +39 -4
  39. graphiti_core/search/search_filters.py +109 -31
  40. graphiti_core/search/search_helpers.py +5 -6
  41. graphiti_core/search/search_utils.py +1360 -473
  42. graphiti_core/tracer.py +193 -0
  43. graphiti_core/utils/bulk_utils.py +216 -90
  44. graphiti_core/utils/datetime_utils.py +13 -0
  45. graphiti_core/utils/maintenance/community_operations.py +62 -38
  46. graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
  47. graphiti_core/utils/maintenance/edge_operations.py +286 -126
  48. graphiti_core/utils/maintenance/graph_data_operations.py +44 -74
  49. graphiti_core/utils/maintenance/node_operations.py +320 -158
  50. graphiti_core/utils/maintenance/temporal_operations.py +11 -3
  51. graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
  52. graphiti_core/utils/text_utils.py +53 -0
  53. {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/METADATA +221 -87
  54. graphiti_core-0.24.3.dist-info/RECORD +86 -0
  55. {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
  56. graphiti_core-0.17.4.dist-info/RECORD +0 -77
  57. /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
  58. {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/licenses/LICENSE +0 -0
@@ -15,22 +15,26 @@ limitations under the License.
15
15
  """
16
16
 
17
17
  import logging
18
- from contextlib import suppress
18
+ from collections.abc import Awaitable, Callable
19
19
  from time import time
20
20
  from typing import Any
21
- from uuid import uuid4
22
21
 
23
- import pydantic
24
- from pydantic import BaseModel, Field
22
+ from pydantic import BaseModel
25
23
 
26
24
  from graphiti_core.graphiti_types import GraphitiClients
27
25
  from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
28
26
  from graphiti_core.llm_client import LLMClient
29
27
  from graphiti_core.llm_client.config import ModelSize
30
- from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
28
+ from graphiti_core.nodes import (
29
+ EntityNode,
30
+ EpisodeType,
31
+ EpisodicNode,
32
+ create_entity_node_embeddings,
33
+ )
31
34
  from graphiti_core.prompts import prompt_library
32
- from graphiti_core.prompts.dedupe_nodes import NodeResolutions
35
+ from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
33
36
  from graphiti_core.prompts.extract_nodes import (
37
+ EntitySummary,
34
38
  ExtractedEntities,
35
39
  ExtractedEntity,
36
40
  MissedEntities,
@@ -40,16 +44,28 @@ from graphiti_core.search.search_config import SearchResults
40
44
  from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
41
45
  from graphiti_core.search.search_filters import SearchFilters
42
46
  from graphiti_core.utils.datetime_utils import utc_now
43
- from graphiti_core.utils.maintenance.edge_operations import filter_existing_duplicate_of_edges
47
+ from graphiti_core.utils.maintenance.dedup_helpers import (
48
+ DedupCandidateIndexes,
49
+ DedupResolutionState,
50
+ _build_candidate_indexes,
51
+ _resolve_with_similarity,
52
+ )
53
+ from graphiti_core.utils.maintenance.edge_operations import (
54
+ filter_existing_duplicate_of_edges,
55
+ )
56
+ from graphiti_core.utils.text_utils import MAX_SUMMARY_CHARS, truncate_at_sentence
44
57
 
45
58
  logger = logging.getLogger(__name__)
46
59
 
60
+ NodeSummaryFilter = Callable[[EntityNode], Awaitable[bool]]
61
+
47
62
 
48
63
  async def extract_nodes_reflexion(
49
64
  llm_client: LLMClient,
50
65
  episode: EpisodicNode,
51
66
  previous_episodes: list[EpisodicNode],
52
67
  node_names: list[str],
68
+ group_id: str | None = None,
53
69
  ) -> list[str]:
54
70
  # Prepare context for LLM
55
71
  context = {
@@ -59,7 +75,10 @@ async def extract_nodes_reflexion(
59
75
  }
60
76
 
61
77
  llm_response = await llm_client.generate_response(
62
- prompt_library.extract_nodes.reflexion(context), MissedEntities
78
+ prompt_library.extract_nodes.reflexion(context),
79
+ MissedEntities,
80
+ group_id=group_id,
81
+ prompt_name='extract_nodes.reflexion',
63
82
  )
64
83
  missed_entities = llm_response.get('missed_entities', [])
65
84
 
@@ -70,7 +89,7 @@ async def extract_nodes(
70
89
  clients: GraphitiClients,
71
90
  episode: EpisodicNode,
72
91
  previous_episodes: list[EpisodicNode],
73
- entity_types: dict[str, BaseModel] | None = None,
92
+ entity_types: dict[str, type[BaseModel]] | None = None,
74
93
  excluded_entity_types: list[str] | None = None,
75
94
  ) -> list[EntityNode]:
76
95
  start = time()
@@ -115,20 +134,27 @@ async def extract_nodes(
115
134
  llm_response = await llm_client.generate_response(
116
135
  prompt_library.extract_nodes.extract_message(context),
117
136
  response_model=ExtractedEntities,
137
+ group_id=episode.group_id,
138
+ prompt_name='extract_nodes.extract_message',
118
139
  )
119
140
  elif episode.source == EpisodeType.text:
120
141
  llm_response = await llm_client.generate_response(
121
- prompt_library.extract_nodes.extract_text(context), response_model=ExtractedEntities
142
+ prompt_library.extract_nodes.extract_text(context),
143
+ response_model=ExtractedEntities,
144
+ group_id=episode.group_id,
145
+ prompt_name='extract_nodes.extract_text',
122
146
  )
123
147
  elif episode.source == EpisodeType.json:
124
148
  llm_response = await llm_client.generate_response(
125
- prompt_library.extract_nodes.extract_json(context), response_model=ExtractedEntities
149
+ prompt_library.extract_nodes.extract_json(context),
150
+ response_model=ExtractedEntities,
151
+ group_id=episode.group_id,
152
+ prompt_name='extract_nodes.extract_json',
126
153
  )
127
154
 
128
- extracted_entities: list[ExtractedEntity] = [
129
- ExtractedEntity(**entity_types_context)
130
- for entity_types_context in llm_response.get('extracted_entities', [])
131
- ]
155
+ response_object = ExtractedEntities(**llm_response)
156
+
157
+ extracted_entities: list[ExtractedEntity] = response_object.extracted_entities
132
158
 
133
159
  reflexion_iterations += 1
134
160
  if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
@@ -137,6 +163,7 @@ async def extract_nodes(
137
163
  episode,
138
164
  previous_episodes,
139
165
  [entity.name for entity in extracted_entities],
166
+ episode.group_id,
140
167
  )
141
168
 
142
169
  entities_missed = len(missing_entities) != 0
@@ -151,9 +178,13 @@ async def extract_nodes(
151
178
  # Convert the extracted data into EntityNode objects
152
179
  extracted_nodes = []
153
180
  for extracted_entity in filtered_extracted_entities:
154
- entity_type_name = entity_types_context[extracted_entity.entity_type_id].get(
155
- 'entity_type_name'
156
- )
181
+ type_id = extracted_entity.entity_type_id
182
+ if 0 <= type_id < len(entity_types_context):
183
+ entity_type_name = entity_types_context[extracted_entity.entity_type_id].get(
184
+ 'entity_type_name'
185
+ )
186
+ else:
187
+ entity_type_name = 'Entity'
157
188
 
158
189
  # Check if this entity type should be excluded
159
190
  if excluded_entity_types and entity_type_name in excluded_entity_types:
@@ -173,20 +204,16 @@ async def extract_nodes(
173
204
  logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
174
205
 
175
206
  logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
207
+
176
208
  return extracted_nodes
177
209
 
178
210
 
179
- async def resolve_extracted_nodes(
211
+ async def _collect_candidate_nodes(
180
212
  clients: GraphitiClients,
181
213
  extracted_nodes: list[EntityNode],
182
- episode: EpisodicNode | None = None,
183
- previous_episodes: list[EpisodicNode] | None = None,
184
- entity_types: dict[str, BaseModel] | None = None,
185
- existing_nodes_override: list[EntityNode] | None = None,
186
- ) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
187
- llm_client = clients.llm_client
188
- driver = clients.driver
189
-
214
+ existing_nodes_override: list[EntityNode] | None,
215
+ ) -> list[EntityNode]:
216
+ """Search per extracted name and return unique candidates with overrides honored in order."""
190
217
  search_results: list[SearchResults] = await semaphore_gather(
191
218
  *[
192
219
  search(
@@ -200,33 +227,43 @@ async def resolve_extracted_nodes(
200
227
  ]
201
228
  )
202
229
 
203
- candidate_nodes: list[EntityNode] = (
204
- [node for result in search_results for node in result.nodes]
205
- if existing_nodes_override is None
206
- else existing_nodes_override
207
- )
230
+ candidate_nodes: list[EntityNode] = [node for result in search_results for node in result.nodes]
208
231
 
209
- existing_nodes_dict: dict[str, EntityNode] = {node.uuid: node for node in candidate_nodes}
232
+ if existing_nodes_override is not None:
233
+ candidate_nodes.extend(existing_nodes_override)
210
234
 
211
- existing_nodes: list[EntityNode] = list(existing_nodes_dict.values())
235
+ seen_candidate_uuids: set[str] = set()
236
+ ordered_candidates: list[EntityNode] = []
237
+ for candidate in candidate_nodes:
238
+ if candidate.uuid in seen_candidate_uuids:
239
+ continue
240
+ seen_candidate_uuids.add(candidate.uuid)
241
+ ordered_candidates.append(candidate)
212
242
 
213
- existing_nodes_context = (
214
- [
215
- {
216
- **{
217
- 'idx': i,
218
- 'name': candidate.name,
219
- 'entity_types': candidate.labels,
220
- },
221
- **candidate.attributes,
222
- }
223
- for i, candidate in enumerate(existing_nodes)
224
- ],
225
- )
243
+ return ordered_candidates
226
244
 
227
- entity_types_dict: dict[str, BaseModel] = entity_types if entity_types is not None else {}
228
245
 
229
- # Prepare context for LLM
246
+ async def _resolve_with_llm(
247
+ llm_client: LLMClient,
248
+ extracted_nodes: list[EntityNode],
249
+ indexes: DedupCandidateIndexes,
250
+ state: DedupResolutionState,
251
+ episode: EpisodicNode | None,
252
+ previous_episodes: list[EpisodicNode] | None,
253
+ entity_types: dict[str, type[BaseModel]] | None,
254
+ ) -> None:
255
+ """Escalate unresolved nodes to the dedupe prompt so the LLM can select or reject duplicates.
256
+
257
+ The guardrails below defensively ignore malformed or duplicate LLM responses so the
258
+ ingestion workflow remains deterministic even when the model misbehaves.
259
+ """
260
+ if not state.unresolved_indices:
261
+ return
262
+
263
+ entity_types_dict: dict[str, type[BaseModel]] = entity_types if entity_types is not None else {}
264
+
265
+ llm_extracted_nodes = [extracted_nodes[i] for i in state.unresolved_indices]
266
+
230
267
  extracted_nodes_context = [
231
268
  {
232
269
  'id': i,
@@ -237,60 +274,180 @@ async def resolve_extracted_nodes(
237
274
  ).__doc__
238
275
  or 'Default Entity Type',
239
276
  }
240
- for i, node in enumerate(extracted_nodes)
277
+ for i, node in enumerate(llm_extracted_nodes)
278
+ ]
279
+
280
+ sent_ids = [ctx['id'] for ctx in extracted_nodes_context]
281
+ logger.debug(
282
+ 'Sending %d entities to LLM for deduplication with IDs 0-%d (actual IDs sent: %s)',
283
+ len(llm_extracted_nodes),
284
+ len(llm_extracted_nodes) - 1,
285
+ sent_ids if len(sent_ids) < 20 else f'{sent_ids[:10]}...{sent_ids[-10:]}',
286
+ )
287
+ if llm_extracted_nodes:
288
+ sample_size = min(3, len(extracted_nodes_context))
289
+ logger.debug(
290
+ 'First %d entities: %s',
291
+ sample_size,
292
+ [(ctx['id'], ctx['name']) for ctx in extracted_nodes_context[:sample_size]],
293
+ )
294
+ if len(extracted_nodes_context) > 3:
295
+ logger.debug(
296
+ 'Last %d entities: %s',
297
+ sample_size,
298
+ [(ctx['id'], ctx['name']) for ctx in extracted_nodes_context[-sample_size:]],
299
+ )
300
+
301
+ existing_nodes_context = [
302
+ {
303
+ **{
304
+ 'idx': i,
305
+ 'name': candidate.name,
306
+ 'entity_types': candidate.labels,
307
+ },
308
+ **candidate.attributes,
309
+ }
310
+ for i, candidate in enumerate(indexes.existing_nodes)
241
311
  ]
242
312
 
243
313
  context = {
244
314
  'extracted_nodes': extracted_nodes_context,
245
315
  'existing_nodes': existing_nodes_context,
246
316
  'episode_content': episode.content if episode is not None else '',
247
- 'previous_episodes': [ep.content for ep in previous_episodes]
248
- if previous_episodes is not None
249
- else [],
317
+ 'previous_episodes': (
318
+ [ep.content for ep in previous_episodes] if previous_episodes is not None else []
319
+ ),
250
320
  }
251
321
 
252
322
  llm_response = await llm_client.generate_response(
253
323
  prompt_library.dedupe_nodes.nodes(context),
254
324
  response_model=NodeResolutions,
325
+ prompt_name='dedupe_nodes.nodes',
255
326
  )
256
327
 
257
- node_resolutions: list = llm_response.get('entity_resolutions', [])
328
+ node_resolutions: list[NodeDuplicate] = NodeResolutions(**llm_response).entity_resolutions
258
329
 
259
- resolved_nodes: list[EntityNode] = []
260
- uuid_map: dict[str, str] = {}
261
- node_duplicates: list[tuple[EntityNode, EntityNode]] = []
262
- for resolution in node_resolutions:
263
- resolution_id: int = resolution.get('id', -1)
264
- duplicate_idx: int = resolution.get('duplicate_idx', -1)
330
+ valid_relative_range = range(len(state.unresolved_indices))
331
+ processed_relative_ids: set[int] = set()
265
332
 
266
- extracted_node = extracted_nodes[resolution_id]
333
+ received_ids = {r.id for r in node_resolutions}
334
+ expected_ids = set(valid_relative_range)
335
+ missing_ids = expected_ids - received_ids
336
+ extra_ids = received_ids - expected_ids
337
+
338
+ logger.debug(
339
+ 'Received %d resolutions for %d entities',
340
+ len(node_resolutions),
341
+ len(state.unresolved_indices),
342
+ )
267
343
 
268
- resolved_node = (
269
- existing_nodes[duplicate_idx]
270
- if 0 <= duplicate_idx < len(existing_nodes)
271
- else extracted_node
344
+ if missing_ids:
345
+ logger.warning('LLM did not return resolutions for IDs: %s', sorted(missing_ids))
346
+
347
+ if extra_ids:
348
+ logger.warning(
349
+ 'LLM returned invalid IDs outside valid range 0-%d: %s (all returned IDs: %s)',
350
+ len(state.unresolved_indices) - 1,
351
+ sorted(extra_ids),
352
+ sorted(received_ids),
272
353
  )
273
354
 
274
- # resolved_node.name = resolution.get('name')
355
+ for resolution in node_resolutions:
356
+ relative_id: int = resolution.id
357
+ duplicate_idx: int = resolution.duplicate_idx
358
+
359
+ if relative_id not in valid_relative_range:
360
+ logger.warning(
361
+ 'Skipping invalid LLM dedupe id %d (valid range: 0-%d, received %d resolutions)',
362
+ relative_id,
363
+ len(state.unresolved_indices) - 1,
364
+ len(node_resolutions),
365
+ )
366
+ continue
275
367
 
276
- resolved_nodes.append(resolved_node)
277
- uuid_map[extracted_node.uuid] = resolved_node.uuid
368
+ if relative_id in processed_relative_ids:
369
+ logger.warning('Duplicate LLM dedupe id %s received; ignoring.', relative_id)
370
+ continue
371
+ processed_relative_ids.add(relative_id)
372
+
373
+ original_index = state.unresolved_indices[relative_id]
374
+ extracted_node = extracted_nodes[original_index]
375
+
376
+ resolved_node: EntityNode
377
+ if duplicate_idx == -1:
378
+ resolved_node = extracted_node
379
+ elif 0 <= duplicate_idx < len(indexes.existing_nodes):
380
+ resolved_node = indexes.existing_nodes[duplicate_idx]
381
+ else:
382
+ logger.warning(
383
+ 'Invalid duplicate_idx %s for extracted node %s; treating as no duplicate.',
384
+ duplicate_idx,
385
+ extracted_node.uuid,
386
+ )
387
+ resolved_node = extracted_node
278
388
 
279
- duplicates: list[int] = resolution.get('duplicates', [])
280
- if duplicate_idx not in duplicates and duplicate_idx > -1:
281
- duplicates.append(duplicate_idx)
282
- for idx in duplicates:
283
- existing_node = existing_nodes[idx] if idx < len(existing_nodes) else resolved_node
389
+ state.resolved_nodes[original_index] = resolved_node
390
+ state.uuid_map[extracted_node.uuid] = resolved_node.uuid
391
+ if resolved_node.uuid != extracted_node.uuid:
392
+ state.duplicate_pairs.append((extracted_node, resolved_node))
284
393
 
285
- node_duplicates.append((extracted_node, existing_node))
286
394
 
287
- logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}')
395
+ async def resolve_extracted_nodes(
396
+ clients: GraphitiClients,
397
+ extracted_nodes: list[EntityNode],
398
+ episode: EpisodicNode | None = None,
399
+ previous_episodes: list[EpisodicNode] | None = None,
400
+ entity_types: dict[str, type[BaseModel]] | None = None,
401
+ existing_nodes_override: list[EntityNode] | None = None,
402
+ ) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
403
+ """Search for existing nodes, resolve deterministic matches, then escalate holdouts to the LLM dedupe prompt."""
404
+ llm_client = clients.llm_client
405
+ driver = clients.driver
406
+ existing_nodes = await _collect_candidate_nodes(
407
+ clients,
408
+ extracted_nodes,
409
+ existing_nodes_override,
410
+ )
411
+
412
+ indexes: DedupCandidateIndexes = _build_candidate_indexes(existing_nodes)
413
+
414
+ state = DedupResolutionState(
415
+ resolved_nodes=[None] * len(extracted_nodes),
416
+ uuid_map={},
417
+ unresolved_indices=[],
418
+ )
419
+
420
+ _resolve_with_similarity(extracted_nodes, indexes, state)
421
+
422
+ await _resolve_with_llm(
423
+ llm_client,
424
+ extracted_nodes,
425
+ indexes,
426
+ state,
427
+ episode,
428
+ previous_episodes,
429
+ entity_types,
430
+ )
431
+
432
+ for idx, node in enumerate(extracted_nodes):
433
+ if state.resolved_nodes[idx] is None:
434
+ state.resolved_nodes[idx] = node
435
+ state.uuid_map[node.uuid] = node.uuid
436
+
437
+ logger.debug(
438
+ 'Resolved nodes: %s',
439
+ [(node.name, node.uuid) for node in state.resolved_nodes if node is not None],
440
+ )
288
441
 
289
442
  new_node_duplicates: list[
290
443
  tuple[EntityNode, EntityNode]
291
- ] = await filter_existing_duplicate_of_edges(driver, node_duplicates)
444
+ ] = await filter_existing_duplicate_of_edges(driver, state.duplicate_pairs)
292
445
 
293
- return resolved_nodes, uuid_map, new_node_duplicates
446
+ return (
447
+ [node for node in state.resolved_nodes if node is not None],
448
+ state.uuid_map,
449
+ new_node_duplicates,
450
+ )
294
451
 
295
452
 
296
453
  async def extract_attributes_from_nodes(
@@ -298,7 +455,8 @@ async def extract_attributes_from_nodes(
298
455
  nodes: list[EntityNode],
299
456
  episode: EpisodicNode | None = None,
300
457
  previous_episodes: list[EpisodicNode] | None = None,
301
- entity_types: dict[str, BaseModel] | None = None,
458
+ entity_types: dict[str, type[BaseModel]] | None = None,
459
+ should_summarize_node: NodeSummaryFilter | None = None,
302
460
  ) -> list[EntityNode]:
303
461
  llm_client = clients.llm_client
304
462
  embedder = clients.embedder
@@ -309,9 +467,12 @@ async def extract_attributes_from_nodes(
309
467
  node,
310
468
  episode,
311
469
  previous_episodes,
312
- entity_types.get(next((item for item in node.labels if item != 'Entity'), ''))
313
- if entity_types is not None
314
- else None,
470
+ (
471
+ entity_types.get(next((item for item in node.labels if item != 'Entity'), ''))
472
+ if entity_types is not None
473
+ else None
474
+ ),
475
+ should_summarize_node,
315
476
  )
316
477
  for node in nodes
317
478
  ]
@@ -327,99 +488,100 @@ async def extract_attributes_from_node(
327
488
  node: EntityNode,
328
489
  episode: EpisodicNode | None = None,
329
490
  previous_episodes: list[EpisodicNode] | None = None,
330
- entity_type: BaseModel | None = None,
491
+ entity_type: type[BaseModel] | None = None,
492
+ should_summarize_node: NodeSummaryFilter | None = None,
331
493
  ) -> EntityNode:
332
- node_context: dict[str, Any] = {
333
- 'name': node.name,
334
- 'summary': node.summary,
335
- 'entity_types': node.labels,
336
- 'attributes': node.attributes,
337
- }
494
+ # Extract attributes if entity type is defined and has attributes
495
+ llm_response = await _extract_entity_attributes(
496
+ llm_client, node, episode, previous_episodes, entity_type
497
+ )
338
498
 
339
- attributes_definitions: dict[str, Any] = {
340
- 'summary': (
341
- str,
342
- Field(
343
- description='Summary containing the important information about the entity. Under 250 words',
344
- ),
345
- )
346
- }
499
+ # Extract summary if needed
500
+ await _extract_entity_summary(
501
+ llm_client, node, episode, previous_episodes, should_summarize_node
502
+ )
347
503
 
348
- if entity_type is not None:
349
- for field_name, field_info in entity_type.model_fields.items():
350
- attributes_definitions[field_name] = (
351
- field_info.annotation,
352
- Field(description=field_info.description),
353
- )
504
+ node.attributes.update(llm_response)
354
505
 
355
- unique_model_name = f'EntityAttributes_{uuid4().hex}'
356
- entity_attributes_model = pydantic.create_model(unique_model_name, **attributes_definitions)
506
+ return node
357
507
 
358
- summary_context: dict[str, Any] = {
359
- 'node': node_context,
360
- 'episode_content': episode.content if episode is not None else '',
361
- 'previous_episodes': [ep.content for ep in previous_episodes]
362
- if previous_episodes is not None
363
- else [],
364
- }
508
+
509
+ async def _extract_entity_attributes(
510
+ llm_client: LLMClient,
511
+ node: EntityNode,
512
+ episode: EpisodicNode | None,
513
+ previous_episodes: list[EpisodicNode] | None,
514
+ entity_type: type[BaseModel] | None,
515
+ ) -> dict[str, Any]:
516
+ if entity_type is None or len(entity_type.model_fields) == 0:
517
+ return {}
518
+
519
+ attributes_context = _build_episode_context(
520
+ # should not include summary
521
+ node_data={
522
+ 'name': node.name,
523
+ 'entity_types': node.labels,
524
+ 'attributes': node.attributes,
525
+ },
526
+ episode=episode,
527
+ previous_episodes=previous_episodes,
528
+ )
365
529
 
366
530
  llm_response = await llm_client.generate_response(
367
- prompt_library.extract_nodes.extract_attributes(summary_context),
368
- response_model=entity_attributes_model,
531
+ prompt_library.extract_nodes.extract_attributes(attributes_context),
532
+ response_model=entity_type,
369
533
  model_size=ModelSize.small,
534
+ group_id=node.group_id,
535
+ prompt_name='extract_nodes.extract_attributes',
370
536
  )
371
537
 
372
- node.summary = llm_response.get('summary', node.summary)
373
- node_attributes = {key: value for key, value in llm_response.items()}
374
-
375
- with suppress(KeyError):
376
- del node_attributes['summary']
538
+ # validate response
539
+ entity_type(**llm_response)
377
540
 
378
- node.attributes.update(node_attributes)
541
+ return llm_response
379
542
 
380
- return node
381
543
 
382
-
383
- async def dedupe_node_list(
544
+ async def _extract_entity_summary(
384
545
  llm_client: LLMClient,
385
- nodes: list[EntityNode],
386
- ) -> tuple[list[EntityNode], dict[str, str]]:
387
- start = time()
388
-
389
- # build node map
390
- node_map = {}
391
- for node in nodes:
392
- node_map[node.uuid] = node
393
-
394
- # Prepare context for LLM
395
- nodes_context = [{'uuid': node.uuid, 'name': node.name, **node.attributes} for node in nodes]
396
-
397
- context = {
398
- 'nodes': nodes_context,
399
- }
400
-
401
- llm_response = await llm_client.generate_response(
402
- prompt_library.dedupe_nodes.node_list(context)
546
+ node: EntityNode,
547
+ episode: EpisodicNode | None,
548
+ previous_episodes: list[EpisodicNode] | None,
549
+ should_summarize_node: NodeSummaryFilter | None,
550
+ ) -> None:
551
+ if should_summarize_node is not None and not await should_summarize_node(node):
552
+ return
553
+
554
+ summary_context = _build_episode_context(
555
+ node_data={
556
+ 'name': node.name,
557
+ 'summary': truncate_at_sentence(node.summary, MAX_SUMMARY_CHARS),
558
+ 'entity_types': node.labels,
559
+ 'attributes': node.attributes,
560
+ },
561
+ episode=episode,
562
+ previous_episodes=previous_episodes,
403
563
  )
404
564
 
405
- nodes_data = llm_response.get('nodes', [])
565
+ summary_response = await llm_client.generate_response(
566
+ prompt_library.extract_nodes.extract_summary(summary_context),
567
+ response_model=EntitySummary,
568
+ model_size=ModelSize.small,
569
+ group_id=node.group_id,
570
+ prompt_name='extract_nodes.extract_summary',
571
+ )
406
572
 
407
- end = time()
408
- logger.debug(f'Deduplicated nodes: {nodes_data} in {(end - start) * 1000} ms')
409
-
410
- # Get full node data
411
- unique_nodes = []
412
- uuid_map: dict[str, str] = {}
413
- for node_data in nodes_data:
414
- node_instance: EntityNode | None = node_map.get(node_data['uuids'][0])
415
- if node_instance is None:
416
- logger.warning(f'Node {node_data["uuids"][0]} not found in node map')
417
- continue
418
- node_instance.summary = node_data['summary']
419
- unique_nodes.append(node_instance)
573
+ node.summary = truncate_at_sentence(summary_response.get('summary', ''), MAX_SUMMARY_CHARS)
420
574
 
421
- for uuid in node_data['uuids'][1:]:
422
- uuid_value = node_map[node_data['uuids'][0]].uuid
423
- uuid_map[uuid] = uuid_value
424
575
 
425
- return unique_nodes, uuid_map
576
+ def _build_episode_context(
577
+ node_data: dict[str, Any],
578
+ episode: EpisodicNode | None,
579
+ previous_episodes: list[EpisodicNode] | None,
580
+ ) -> dict[str, Any]:
581
+ return {
582
+ 'node': node_data,
583
+ 'episode_content': episode.content if episode is not None else '',
584
+ 'previous_episodes': (
585
+ [ep.content for ep in previous_episodes] if previous_episodes is not None else []
586
+ ),
587
+ }
@@ -43,7 +43,9 @@ async def extract_edge_dates(
43
43
  'reference_timestamp': current_episode.valid_at.isoformat(),
44
44
  }
45
45
  llm_response = await llm_client.generate_response(
46
- prompt_library.extract_edge_dates.v1(context), response_model=EdgeDates
46
+ prompt_library.extract_edge_dates.v1(context),
47
+ response_model=EdgeDates,
48
+ prompt_name='extract_edge_dates.v1',
47
49
  )
48
50
 
49
51
  valid_at = llm_response.get('valid_at')
@@ -70,7 +72,9 @@ async def extract_edge_dates(
70
72
 
71
73
 
72
74
  async def get_edge_contradictions(
73
- llm_client: LLMClient, new_edge: EntityEdge, existing_edges: list[EntityEdge]
75
+ llm_client: LLMClient,
76
+ new_edge: EntityEdge,
77
+ existing_edges: list[EntityEdge],
74
78
  ) -> list[EntityEdge]:
75
79
  start = time()
76
80
 
@@ -79,12 +83,16 @@ async def get_edge_contradictions(
79
83
  {'id': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
80
84
  ]
81
85
 
82
- context = {'new_edge': new_edge_context, 'existing_edges': existing_edge_context}
86
+ context = {
87
+ 'new_edge': new_edge_context,
88
+ 'existing_edges': existing_edge_context,
89
+ }
83
90
 
84
91
  llm_response = await llm_client.generate_response(
85
92
  prompt_library.invalidate_edges.v2(context),
86
93
  response_model=InvalidatedEdges,
87
94
  model_size=ModelSize.small,
95
+ prompt_name='invalidate_edges.v2',
88
96
  )
89
97
 
90
98
  contradicted_facts: list[int] = llm_response.get('contradicted_facts', [])