graphiti-core 0.21.0rc12__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. Click here for more details.

Files changed (41)
  1. graphiti_core/driver/driver.py +4 -211
  2. graphiti_core/driver/falkordb_driver.py +31 -3
  3. graphiti_core/driver/graph_operations/graph_operations.py +195 -0
  4. graphiti_core/driver/neo4j_driver.py +0 -49
  5. graphiti_core/driver/neptune_driver.py +43 -26
  6. graphiti_core/driver/search_interface/__init__.py +0 -0
  7. graphiti_core/driver/search_interface/search_interface.py +89 -0
  8. graphiti_core/edges.py +11 -34
  9. graphiti_core/graphiti.py +459 -326
  10. graphiti_core/graphiti_types.py +2 -0
  11. graphiti_core/llm_client/anthropic_client.py +64 -45
  12. graphiti_core/llm_client/client.py +67 -19
  13. graphiti_core/llm_client/gemini_client.py +73 -54
  14. graphiti_core/llm_client/openai_base_client.py +65 -43
  15. graphiti_core/llm_client/openai_generic_client.py +65 -43
  16. graphiti_core/models/edges/edge_db_queries.py +1 -0
  17. graphiti_core/models/nodes/node_db_queries.py +1 -0
  18. graphiti_core/nodes.py +26 -99
  19. graphiti_core/prompts/dedupe_edges.py +4 -4
  20. graphiti_core/prompts/dedupe_nodes.py +10 -10
  21. graphiti_core/prompts/extract_edges.py +4 -4
  22. graphiti_core/prompts/extract_nodes.py +26 -28
  23. graphiti_core/prompts/prompt_helpers.py +18 -2
  24. graphiti_core/prompts/snippets.py +29 -0
  25. graphiti_core/prompts/summarize_nodes.py +22 -24
  26. graphiti_core/search/search_filters.py +0 -38
  27. graphiti_core/search/search_helpers.py +4 -4
  28. graphiti_core/search/search_utils.py +84 -220
  29. graphiti_core/tracer.py +193 -0
  30. graphiti_core/utils/bulk_utils.py +16 -28
  31. graphiti_core/utils/maintenance/community_operations.py +4 -1
  32. graphiti_core/utils/maintenance/edge_operations.py +30 -15
  33. graphiti_core/utils/maintenance/graph_data_operations.py +6 -25
  34. graphiti_core/utils/maintenance/node_operations.py +99 -51
  35. graphiti_core/utils/maintenance/temporal_operations.py +4 -1
  36. graphiti_core/utils/text_utils.py +53 -0
  37. {graphiti_core-0.21.0rc12.dist-info → graphiti_core-0.22.0.dist-info}/METADATA +7 -3
  38. {graphiti_core-0.21.0rc12.dist-info → graphiti_core-0.22.0.dist-info}/RECORD +41 -35
  39. /graphiti_core/{utils/maintenance/utils.py → driver/graph_operations/__init__.py} +0 -0
  40. {graphiti_core-0.21.0rc12.dist-info → graphiti_core-0.22.0.dist-info}/WHEEL +0 -0
  41. {graphiti_core-0.21.0rc12.dist-info → graphiti_core-0.22.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,193 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ from abc import ABC, abstractmethod
18
+ from collections.abc import Generator
19
+ from contextlib import AbstractContextManager, contextmanager, suppress
20
+ from typing import TYPE_CHECKING, Any
21
+
22
+ if TYPE_CHECKING:
23
+ from opentelemetry.trace import Span, StatusCode
24
+
25
+ try:
26
+ from opentelemetry.trace import Span, StatusCode
27
+
28
+ OTEL_AVAILABLE = True
29
+ except ImportError:
30
+ OTEL_AVAILABLE = False
31
+
32
+
33
+ class TracerSpan(ABC):
34
+ """Abstract base class for tracer spans."""
35
+
36
+ @abstractmethod
37
+ def add_attributes(self, attributes: dict[str, Any]) -> None:
38
+ """Add attributes to the span."""
39
+ pass
40
+
41
+ @abstractmethod
42
+ def set_status(self, status: str, description: str | None = None) -> None:
43
+ """Set the status of the span."""
44
+ pass
45
+
46
+ @abstractmethod
47
+ def record_exception(self, exception: Exception) -> None:
48
+ """Record an exception in the span."""
49
+ pass
50
+
51
+
52
class Tracer(ABC):
    """Interface for objects that can open tracing spans."""

    @abstractmethod
    def start_span(self, name: str) -> AbstractContextManager[TracerSpan]:
        """Return a context manager that provides a span for ``name``."""
59
+
60
+
61
class NoOpSpan(TracerSpan):
    """Span implementation whose methods all accept their arguments and do nothing."""

    def add_attributes(self, attributes: dict[str, Any]) -> None:
        """Discard the supplied attributes."""
        return None

    def set_status(self, status: str, description: str | None = None) -> None:
        """Discard the supplied status and description."""
        return None

    def record_exception(self, exception: Exception) -> None:
        """Discard the supplied exception."""
        return None
72
+
73
+
74
class NoOpTracer(Tracer):
    """Tracer whose spans are all ``NoOpSpan`` instances."""

    @contextmanager
    def start_span(self, name: str) -> Generator[NoOpSpan, None, None]:
        """Yield a fresh ``NoOpSpan``; ``name`` is not used."""
        placeholder_span = NoOpSpan()
        yield placeholder_span
81
+
82
+
83
class OpenTelemetrySpan(TracerSpan):
    """Adapter exposing an OpenTelemetry span through the ``TracerSpan`` interface.

    Every method is best-effort: any ``Exception`` raised while preparing data
    for, or calling into, the wrapped span is suppressed so that a tracing
    problem never propagates to the caller.
    """

    def __init__(self, span: 'Span'):
        """Store ``span``, the OpenTelemetry span that calls are forwarded to."""
        self._span = span

    def add_attributes(self, attributes: dict[str, Any]) -> None:
        """Set ``attributes`` on the wrapped span.

        Entries whose value is ``None`` are dropped. ``str``, ``int``,
        ``float`` and ``bool`` values are passed through unchanged; any other
        value is converted with ``str()``. Nothing is sent to the span when no
        entries remain.
        """
        with suppress(Exception):
            filtered_attrs = {
                key: value if isinstance(value, str | int | float | bool) else str(value)
                for key, value in attributes.items()
                if value is not None
            }
            if filtered_attrs:
                self._span.set_attributes(filtered_attrs)

    def set_status(self, status: str, description: str | None = None) -> None:
        """Set the status of the wrapped span.

        ``status`` is ``'error'`` or ``'ok'``; any other value is ignored, as
        is every call when OpenTelemetry is not installed. ``description`` is
        forwarded only for ``'error'``.
        """
        with suppress(Exception):
            if not OTEL_AVAILABLE:
                return
            if status == 'error':
                self._span.set_status(StatusCode.ERROR, description)
            elif status == 'ok':
                # OpenTelemetry accepts a description only together with
                # StatusCode.ERROR; passing one with OK just triggers a warning
                # and the description is discarded, so it is not forwarded.
                self._span.set_status(StatusCode.OK)

    def record_exception(self, exception: Exception) -> None:
        """Record ``exception`` on the wrapped span."""
        with suppress(Exception):
            self._span.record_exception(exception)
124
+
125
+
126
class OpenTelemetryTracer(Tracer):
    """Wrapper for an OpenTelemetry tracer with a configurable span name prefix."""

    def __init__(self, tracer: Any, span_prefix: str = 'graphiti'):
        """
        Initialize the OpenTelemetry tracer wrapper.

        Parameters
        ----------
        tracer : opentelemetry.trace.Tracer
            The OpenTelemetry tracer instance.
        span_prefix : str, optional
            Prefix to prepend to all span names. Defaults to 'graphiti'.
            Trailing dots are stripped.

        Raises
        ------
        ImportError
            If OpenTelemetry is not installed.
        """
        if not OTEL_AVAILABLE:
            raise ImportError(
                'OpenTelemetry is not installed. Install it with: pip install opentelemetry-api'
            )
        self._tracer = tracer
        self._span_prefix = span_prefix.rstrip('.')

    @contextmanager
    def start_span(self, name: str) -> Generator[OpenTelemetrySpan | NoOpSpan, None, None]:
        """
        Start a new OpenTelemetry span named ``<prefix>.<name>``.

        If the span cannot be started, a ``NoOpSpan`` is yielded instead so the
        traced operation still runs. Exceptions raised inside the caller's
        ``with`` block are passed to the OpenTelemetry span's context manager
        and then propagate to the caller unchanged.
        """
        # Function-scope import: ExitStack is not among the module's imports.
        from contextlib import ExitStack

        # Avoid a leading dot in the span name when the prefix is empty.
        full_name = f'{self._span_prefix}.{name}' if self._span_prefix else name

        with ExitStack() as stack:
            # Only span *creation* is guarded. The yield must stay outside the
            # try/except: a generator-based context manager may yield exactly
            # once, and catching an exception thrown in from the caller's body
            # and yielding again makes contextlib raise
            # "RuntimeError: generator didn't stop after throw()", which hides
            # the caller's real error.
            try:
                otel_span = stack.enter_context(self._tracer.start_as_current_span(full_name))
            except Exception:
                wrapped_span: OpenTelemetrySpan | NoOpSpan = NoOpSpan()
            else:
                wrapped_span = OpenTelemetrySpan(otel_span)

            yield wrapped_span
157
+
158
+
159
def create_tracer(otel_tracer: Any | None = None, span_prefix: str = 'graphiti') -> Tracer:
    """
    Build the tracer to use for the given configuration.

    Parameters
    ----------
    otel_tracer : opentelemetry.trace.Tracer | None, optional
        An OpenTelemetry tracer instance. If None, a no-op tracer is returned.
    span_prefix : str, optional
        Prefix to prepend to all span names. Defaults to 'graphiti'.

    Returns
    -------
    Tracer
        An OpenTelemetryTracer wrapping ``otel_tracer`` when a tracer was
        supplied and OpenTelemetry is importable; otherwise a NoOpTracer.

    Examples
    --------
    Using with OpenTelemetry:

    >>> from opentelemetry import trace
    >>> otel_tracer = trace.get_tracer(__name__)
    >>> tracer = create_tracer(otel_tracer, span_prefix='myapp.graphiti')

    Using no-op tracer:

    >>> tracer = create_tracer()  # Returns NoOpTracer
    """
    tracing_enabled = otel_tracer is not None and OTEL_AVAILABLE
    if not tracing_enabled:
        return NoOpTracer()

    return OpenTelemetryTracer(otel_tracer, span_prefix)
@@ -24,9 +24,6 @@ from pydantic import BaseModel, Field
24
24
  from typing_extensions import Any
25
25
 
26
26
  from graphiti_core.driver.driver import (
27
- ENTITY_EDGE_INDEX_NAME,
28
- ENTITY_INDEX_NAME,
29
- EPISODE_INDEX_NAME,
30
27
  GraphDriver,
31
28
  GraphDriverSession,
32
29
  GraphProvider,
@@ -177,12 +174,10 @@ async def add_nodes_and_edges_bulk_tx(
177
174
  'group_id': node.group_id,
178
175
  'summary': node.summary,
179
176
  'created_at': node.created_at,
177
+ 'name_embedding': node.name_embedding,
178
+ 'labels': list(set(node.labels + ['Entity'])),
180
179
  }
181
180
 
182
- if not bool(driver.aoss_client):
183
- entity_data['name_embedding'] = node.name_embedding
184
-
185
- entity_data['labels'] = list(set(node.labels + ['Entity']))
186
181
  if driver.provider == GraphProvider.KUZU:
187
182
  attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
188
183
  entity_data['attributes'] = json.dumps(attributes)
@@ -207,11 +202,9 @@ async def add_nodes_and_edges_bulk_tx(
207
202
  'expired_at': edge.expired_at,
208
203
  'valid_at': edge.valid_at,
209
204
  'invalid_at': edge.invalid_at,
205
+ 'fact_embedding': edge.fact_embedding,
210
206
  }
211
207
 
212
- if not bool(driver.aoss_client):
213
- edge_data['fact_embedding'] = edge.fact_embedding
214
-
215
208
  if driver.provider == GraphProvider.KUZU:
216
209
  attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
217
210
  edge_data['attributes'] = json.dumps(attributes)
@@ -220,7 +213,17 @@ async def add_nodes_and_edges_bulk_tx(
220
213
 
221
214
  edges.append(edge_data)
222
215
 
223
- if driver.provider == GraphProvider.KUZU:
216
+ if driver.graph_operations_interface:
217
+ await driver.graph_operations_interface.episodic_node_save_bulk(
218
+ None, driver, tx, episodic_nodes
219
+ )
220
+ await driver.graph_operations_interface.node_save_bulk(None, driver, tx, nodes)
221
+ await driver.graph_operations_interface.episodic_edge_save_bulk(
222
+ None, driver, tx, episodic_edges
223
+ )
224
+ await driver.graph_operations_interface.edge_save_bulk(None, driver, tx, edges)
225
+
226
+ elif driver.provider == GraphProvider.KUZU:
224
227
  # FIXME: Kuzu's UNWIND does not currently support STRUCT[] type properly, so we insert the data one by one instead for now.
225
228
  episode_query = get_episode_node_save_bulk_query(driver.provider)
226
229
  for episode in episodes:
@@ -237,9 +240,7 @@ async def add_nodes_and_edges_bulk_tx(
237
240
  else:
238
241
  await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
239
242
  await tx.run(
240
- get_entity_node_save_bulk_query(
241
- driver.provider, nodes, has_aoss=bool(driver.aoss_client)
242
- ),
243
+ get_entity_node_save_bulk_query(driver.provider, nodes),
243
244
  nodes=nodes,
244
245
  )
245
246
  await tx.run(
@@ -247,23 +248,10 @@ async def add_nodes_and_edges_bulk_tx(
247
248
  episodic_edges=[edge.model_dump() for edge in episodic_edges],
248
249
  )
249
250
  await tx.run(
250
- get_entity_edge_save_bulk_query(driver.provider, has_aoss=bool(driver.aoss_client)),
251
+ get_entity_edge_save_bulk_query(driver.provider),
251
252
  entity_edges=edges,
252
253
  )
253
254
 
254
- if bool(driver.aoss_client):
255
- for node_data, entity_node in zip(nodes, entity_nodes, strict=True):
256
- if node_data.get('uuid') == entity_node.uuid:
257
- node_data['name_embedding'] = entity_node.name_embedding
258
-
259
- for edge_data, entity_edge in zip(edges, entity_edges, strict=True):
260
- if edge_data.get('uuid') == entity_edge.uuid:
261
- edge_data['fact_embedding'] = entity_edge.fact_embedding
262
-
263
- await driver.save_to_aoss(EPISODE_INDEX_NAME, episodes)
264
- await driver.save_to_aoss(ENTITY_INDEX_NAME, nodes)
265
- await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges)
266
-
267
255
 
268
256
  async def extract_nodes_and_edges_bulk(
269
257
  clients: GraphitiClients,
@@ -138,7 +138,9 @@ async def summarize_pair(llm_client: LLMClient, summary_pair: tuple[str, str]) -
138
138
  }
139
139
 
140
140
  llm_response = await llm_client.generate_response(
141
- prompt_library.summarize_nodes.summarize_pair(context), response_model=Summary
141
+ prompt_library.summarize_nodes.summarize_pair(context),
142
+ response_model=Summary,
143
+ prompt_name='summarize_nodes.summarize_pair',
142
144
  )
143
145
 
144
146
  pair_summary = llm_response.get('summary', '')
@@ -154,6 +156,7 @@ async def generate_summary_description(llm_client: LLMClient, summary: str) -> s
154
156
  llm_response = await llm_client.generate_response(
155
157
  prompt_library.summarize_nodes.summary_description(context),
156
158
  response_model=SummaryDescription,
159
+ prompt_name='summarize_nodes.summary_description',
157
160
  )
158
161
 
159
162
  description = llm_response.get('description', '')
@@ -139,6 +139,8 @@ async def extract_edges(
139
139
  prompt_library.extract_edges.edge(context),
140
140
  response_model=ExtractedEdges,
141
141
  max_tokens=extract_edges_max_tokens,
142
+ group_id=group_id,
143
+ prompt_name='extract_edges.edge',
142
144
  )
143
145
  edges_data = ExtractedEdges(**llm_response).edges
144
146
 
@@ -150,6 +152,8 @@ async def extract_edges(
150
152
  prompt_library.extract_edges.reflexion(context),
151
153
  response_model=MissingFacts,
152
154
  max_tokens=extract_edges_max_tokens,
155
+ group_id=group_id,
156
+ prompt_name='extract_edges.reflexion',
153
157
  )
154
158
 
155
159
  missing_facts = reflexion_response.get('missing_facts', [])
@@ -177,6 +181,10 @@ async def extract_edges(
177
181
  valid_at_datetime = None
178
182
  invalid_at_datetime = None
179
183
 
184
+ # Filter out empty edges
185
+ if not edge_data.fact.strip():
186
+ continue
187
+
180
188
  source_node_idx = edge_data.source_entity_id
181
189
  target_node_idx = edge_data.target_entity_id
182
190
 
@@ -405,21 +413,26 @@ def resolve_edge_contradictions(
405
413
  invalidated_edges: list[EntityEdge] = []
406
414
  for edge in invalidation_candidates:
407
415
  # (Edge invalid before new edge becomes valid) or (new edge invalid before edge becomes valid)
416
+ edge_invalid_at_utc = ensure_utc(edge.invalid_at)
417
+ resolved_edge_valid_at_utc = ensure_utc(resolved_edge.valid_at)
418
+ edge_valid_at_utc = ensure_utc(edge.valid_at)
419
+ resolved_edge_invalid_at_utc = ensure_utc(resolved_edge.invalid_at)
420
+
408
421
  if (
409
- edge.invalid_at is not None
410
- and resolved_edge.valid_at is not None
411
- and edge.invalid_at <= resolved_edge.valid_at
422
+ edge_invalid_at_utc is not None
423
+ and resolved_edge_valid_at_utc is not None
424
+ and edge_invalid_at_utc <= resolved_edge_valid_at_utc
412
425
  ) or (
413
- edge.valid_at is not None
414
- and resolved_edge.invalid_at is not None
415
- and resolved_edge.invalid_at <= edge.valid_at
426
+ edge_valid_at_utc is not None
427
+ and resolved_edge_invalid_at_utc is not None
428
+ and resolved_edge_invalid_at_utc <= edge_valid_at_utc
416
429
  ):
417
430
  continue
418
431
  # New edge invalidates edge
419
432
  elif (
420
- edge.valid_at is not None
421
- and resolved_edge.valid_at is not None
422
- and edge.valid_at < resolved_edge.valid_at
433
+ edge_valid_at_utc is not None
434
+ and resolved_edge_valid_at_utc is not None
435
+ and edge_valid_at_utc < resolved_edge_valid_at_utc
423
436
  ):
424
437
  edge.invalid_at = resolved_edge.valid_at
425
438
  edge.expired_at = edge.expired_at if edge.expired_at is not None else utc_now()
@@ -520,6 +533,7 @@ async def resolve_extracted_edge(
520
533
  prompt_library.dedupe_edges.resolve_edge(context),
521
534
  response_model=EdgeDuplicate,
522
535
  model_size=ModelSize.small,
536
+ prompt_name='dedupe_edges.resolve_edge',
523
537
  )
524
538
  response_object = EdgeDuplicate(**llm_response)
525
539
  duplicate_facts = response_object.duplicate_facts
@@ -583,6 +597,7 @@ async def resolve_extracted_edge(
583
597
  prompt_library.extract_edges.extract_attributes(edge_attributes_context),
584
598
  response_model=edge_model, # type: ignore
585
599
  model_size=ModelSize.small,
600
+ prompt_name='extract_edges.extract_attributes',
586
601
  )
587
602
 
588
603
  resolved_edge.attributes = edge_attributes_response
@@ -609,14 +624,14 @@ async def resolve_extracted_edge(
609
624
 
610
625
  # Determine if the new_edge needs to be expired
611
626
  if resolved_edge.expired_at is None:
612
- invalidation_candidates.sort(key=lambda c: (c.valid_at is None, c.valid_at))
627
+ invalidation_candidates.sort(key=lambda c: (c.valid_at is None, ensure_utc(c.valid_at)))
613
628
  for candidate in invalidation_candidates:
629
+ candidate_valid_at_utc = ensure_utc(candidate.valid_at)
630
+ resolved_edge_valid_at_utc = ensure_utc(resolved_edge.valid_at)
614
631
  if (
615
- candidate.valid_at
616
- and resolved_edge.valid_at
617
- and candidate.valid_at.tzinfo
618
- and resolved_edge.valid_at.tzinfo
619
- and candidate.valid_at > resolved_edge.valid_at
632
+ candidate_valid_at_utc is not None
633
+ and resolved_edge_valid_at_utc is not None
634
+ and candidate_valid_at_utc > resolved_edge_valid_at_utc
620
635
  ):
621
636
  # Expire new edge since we have information about more recent events
622
637
  resolved_edge.invalid_at = candidate.valid_at
@@ -34,30 +34,13 @@ logger = logging.getLogger(__name__)
34
34
 
35
35
 
36
36
  async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
37
- if driver.aoss_client:
38
- await driver.create_aoss_indices() # pyright: ignore[reportAttributeAccessIssue]
39
- return
40
37
  if delete_existing:
41
- records, _, _ = await driver.execute_query(
42
- """
43
- SHOW INDEXES YIELD name
44
- """,
45
- )
46
- index_names = [record['name'] for record in records]
47
- await semaphore_gather(
48
- *[
49
- driver.execute_query(
50
- """DROP INDEX $name""",
51
- name=name,
52
- )
53
- for name in index_names
54
- ]
55
- )
38
+ await driver.delete_all_indexes()
56
39
 
57
40
  range_indices: list[LiteralString] = get_range_indices(driver.provider)
58
41
 
59
- # Don't create fulltext indices if OpenSearch is being used
60
- if not driver.aoss_client:
42
+ # Don't create fulltext indices if search_interface is being used
43
+ if not driver.search_interface:
61
44
  fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
62
45
 
63
46
  if driver.provider == GraphProvider.KUZU:
@@ -95,8 +78,6 @@ async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
95
78
 
96
79
  async def delete_all(tx):
97
80
  await tx.run('MATCH (n) DETACH DELETE n')
98
- if driver.aoss_client:
99
- await driver.clear_aoss_indices()
100
81
 
101
82
  async def delete_group_ids(tx):
102
83
  labels = ['Entity', 'Episodic', 'Community']
@@ -153,9 +134,9 @@ async def retrieve_episodes(
153
134
 
154
135
  query: LiteralString = (
155
136
  """
156
- MATCH (e:Episodic)
157
- WHERE e.valid_at <= $reference_time
158
- """
137
+ MATCH (e:Episodic)
138
+ WHERE e.valid_at <= $reference_time
139
+ """
159
140
  + query_filter
160
141
  + """
161
142
  RETURN
@@ -53,6 +53,7 @@ from graphiti_core.utils.maintenance.dedup_helpers import (
53
53
  from graphiti_core.utils.maintenance.edge_operations import (
54
54
  filter_existing_duplicate_of_edges,
55
55
  )
56
+ from graphiti_core.utils.text_utils import MAX_SUMMARY_CHARS, truncate_at_sentence
56
57
 
57
58
  logger = logging.getLogger(__name__)
58
59
 
@@ -64,6 +65,7 @@ async def extract_nodes_reflexion(
64
65
  episode: EpisodicNode,
65
66
  previous_episodes: list[EpisodicNode],
66
67
  node_names: list[str],
68
+ group_id: str | None = None,
67
69
  ) -> list[str]:
68
70
  # Prepare context for LLM
69
71
  context = {
@@ -73,7 +75,10 @@ async def extract_nodes_reflexion(
73
75
  }
74
76
 
75
77
  llm_response = await llm_client.generate_response(
76
- prompt_library.extract_nodes.reflexion(context), MissedEntities
78
+ prompt_library.extract_nodes.reflexion(context),
79
+ MissedEntities,
80
+ group_id=group_id,
81
+ prompt_name='extract_nodes.reflexion',
77
82
  )
78
83
  missed_entities = llm_response.get('missed_entities', [])
79
84
 
@@ -129,16 +134,22 @@ async def extract_nodes(
129
134
  llm_response = await llm_client.generate_response(
130
135
  prompt_library.extract_nodes.extract_message(context),
131
136
  response_model=ExtractedEntities,
137
+ group_id=episode.group_id,
138
+ prompt_name='extract_nodes.extract_message',
132
139
  )
133
140
  elif episode.source == EpisodeType.text:
134
141
  llm_response = await llm_client.generate_response(
135
142
  prompt_library.extract_nodes.extract_text(context),
136
143
  response_model=ExtractedEntities,
144
+ group_id=episode.group_id,
145
+ prompt_name='extract_nodes.extract_text',
137
146
  )
138
147
  elif episode.source == EpisodeType.json:
139
148
  llm_response = await llm_client.generate_response(
140
149
  prompt_library.extract_nodes.extract_json(context),
141
150
  response_model=ExtractedEntities,
151
+ group_id=episode.group_id,
152
+ prompt_name='extract_nodes.extract_json',
142
153
  )
143
154
 
144
155
  response_object = ExtractedEntities(**llm_response)
@@ -152,6 +163,7 @@ async def extract_nodes(
152
163
  episode,
153
164
  previous_episodes,
154
165
  [entity.name for entity in extracted_entities],
166
+ episode.group_id,
155
167
  )
156
168
 
157
169
  entities_missed = len(missing_entities) != 0
@@ -192,6 +204,7 @@ async def extract_nodes(
192
204
  logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
193
205
 
194
206
  logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
207
+
195
208
  return extracted_nodes
196
209
 
197
210
 
@@ -309,6 +322,7 @@ async def _resolve_with_llm(
309
322
  llm_response = await llm_client.generate_response(
310
323
  prompt_library.dedupe_nodes.nodes(context),
311
324
  response_model=NodeResolutions,
325
+ prompt_name='dedupe_nodes.nodes',
312
326
  )
313
327
 
314
328
  node_resolutions: list[NodeDuplicate] = NodeResolutions(**llm_response).entity_resolutions
@@ -477,63 +491,97 @@ async def extract_attributes_from_node(
477
491
  entity_type: type[BaseModel] | None = None,
478
492
  should_summarize_node: NodeSummaryFilter | None = None,
479
493
  ) -> EntityNode:
480
- node_context: dict[str, Any] = {
481
- 'name': node.name,
482
- 'summary': node.summary,
483
- 'entity_types': node.labels,
484
- 'attributes': node.attributes,
485
- }
494
+ # Extract attributes if entity type is defined and has attributes
495
+ llm_response = await _extract_entity_attributes(
496
+ llm_client, node, episode, previous_episodes, entity_type
497
+ )
486
498
 
487
- attributes_context: dict[str, Any] = {
488
- 'node': node_context,
489
- 'episode_content': episode.content if episode is not None else '',
490
- 'previous_episodes': (
491
- [ep.content for ep in previous_episodes] if previous_episodes is not None else []
492
- ),
493
- }
499
+ # Extract summary if needed
500
+ await _extract_entity_summary(
501
+ llm_client, node, episode, previous_episodes, should_summarize_node
502
+ )
503
+
504
+ node.attributes.update(llm_response)
505
+
506
+ return node
494
507
 
495
- summary_context: dict[str, Any] = {
496
- 'node': node_context,
497
- 'episode_content': episode.content if episode is not None else '',
498
- 'previous_episodes': (
499
- [ep.content for ep in previous_episodes] if previous_episodes is not None else []
500
- ),
501
- }
502
508
 
503
- has_entity_attributes: bool = bool(
504
- entity_type is not None and len(entity_type.model_fields) != 0
509
+ async def _extract_entity_attributes(
510
+ llm_client: LLMClient,
511
+ node: EntityNode,
512
+ episode: EpisodicNode | None,
513
+ previous_episodes: list[EpisodicNode] | None,
514
+ entity_type: type[BaseModel] | None,
515
+ ) -> dict[str, Any]:
516
+ if entity_type is None or len(entity_type.model_fields) == 0:
517
+ return {}
518
+
519
+ attributes_context = _build_episode_context(
520
+ # should not include summary
521
+ node_data={
522
+ 'name': node.name,
523
+ 'entity_types': node.labels,
524
+ 'attributes': node.attributes,
525
+ },
526
+ episode=episode,
527
+ previous_episodes=previous_episodes,
505
528
  )
506
529
 
507
- llm_response = (
508
- (
509
- await llm_client.generate_response(
510
- prompt_library.extract_nodes.extract_attributes(attributes_context),
511
- response_model=entity_type,
512
- model_size=ModelSize.small,
513
- )
514
- )
515
- if has_entity_attributes
516
- else {}
530
+ llm_response = await llm_client.generate_response(
531
+ prompt_library.extract_nodes.extract_attributes(attributes_context),
532
+ response_model=entity_type,
533
+ model_size=ModelSize.small,
534
+ group_id=node.group_id,
535
+ prompt_name='extract_nodes.extract_attributes',
517
536
  )
518
537
 
519
- # Determine if summary should be generated
520
- generate_summary = True
521
- if should_summarize_node is not None:
522
- generate_summary = await should_summarize_node(node)
523
-
524
- # Conditionally generate summary
525
- if generate_summary:
526
- summary_response = await llm_client.generate_response(
527
- prompt_library.extract_nodes.extract_summary(summary_context),
528
- response_model=EntitySummary,
529
- model_size=ModelSize.small,
530
- )
531
- node.summary = summary_response.get('summary', '')
538
+ # validate response
539
+ entity_type(**llm_response)
532
540
 
533
- if has_entity_attributes and entity_type is not None:
534
- entity_type(**llm_response)
535
- node_attributes = {key: value for key, value in llm_response.items()}
541
+ return llm_response
536
542
 
537
- node.attributes.update(node_attributes)
538
543
 
539
- return node
544
async def _extract_entity_summary(
    llm_client: LLMClient,
    node: EntityNode,
    episode: EpisodicNode | None,
    previous_episodes: list[EpisodicNode] | None,
    should_summarize_node: NodeSummaryFilter | None,
) -> None:
    """Regenerate ``node.summary`` in place via the LLM client.

    When ``should_summarize_node`` is given and evaluates falsy for ``node``,
    the node is left untouched. Otherwise the ``extract_summary`` prompt is
    built from the node's name, current summary, labels and attributes plus
    the episode contents, and the ``'summary'`` field of the response (empty
    string if absent) is stored on the node. Both the summary sent to the
    model and the one stored are passed through ``truncate_at_sentence`` with
    ``MAX_SUMMARY_CHARS``.
    """
    if should_summarize_node is not None:
        wants_summary = await should_summarize_node(node)
        if not wants_summary:
            return

    node_data = {
        'name': node.name,
        'summary': truncate_at_sentence(node.summary, MAX_SUMMARY_CHARS),
        'entity_types': node.labels,
        'attributes': node.attributes,
    }
    context = _build_episode_context(
        node_data=node_data,
        episode=episode,
        previous_episodes=previous_episodes,
    )

    prompt = prompt_library.extract_nodes.extract_summary(context)
    response = await llm_client.generate_response(
        prompt,
        response_model=EntitySummary,
        model_size=ModelSize.small,
        group_id=node.group_id,
        prompt_name='extract_nodes.extract_summary',
    )

    raw_summary = response.get('summary', '')
    node.summary = truncate_at_sentence(raw_summary, MAX_SUMMARY_CHARS)
574
+
575
+
576
def _build_episode_context(
    node_data: dict[str, Any],
    episode: EpisodicNode | None,
    previous_episodes: list[EpisodicNode] | None,
) -> dict[str, Any]:
    """Assemble the prompt context dictionary for a node and its episodes.

    Returns a dict with three keys: ``'node'`` (``node_data`` as given),
    ``'episode_content'`` (the episode's ``content``, or ``''`` when
    ``episode`` is None) and ``'previous_episodes'`` (the ``content`` of each
    previous episode, or ``[]`` when ``previous_episodes`` is None).
    """
    episode_content = '' if episode is None else episode.content

    if previous_episodes is None:
        prior_contents = []
    else:
        prior_contents = [ep.content for ep in previous_episodes]

    return {
        'node': node_data,
        'episode_content': episode_content,
        'previous_episodes': prior_contents,
    }