graphiti-core 0.21.0rc13__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic; consult the registry's advisory page for more details.

Files changed (41) hide show
  1. graphiti_core/driver/driver.py +4 -211
  2. graphiti_core/driver/falkordb_driver.py +31 -3
  3. graphiti_core/driver/graph_operations/graph_operations.py +195 -0
  4. graphiti_core/driver/neo4j_driver.py +0 -49
  5. graphiti_core/driver/neptune_driver.py +43 -26
  6. graphiti_core/driver/search_interface/__init__.py +0 -0
  7. graphiti_core/driver/search_interface/search_interface.py +89 -0
  8. graphiti_core/edges.py +11 -34
  9. graphiti_core/graphiti.py +459 -326
  10. graphiti_core/graphiti_types.py +2 -0
  11. graphiti_core/llm_client/anthropic_client.py +64 -45
  12. graphiti_core/llm_client/client.py +67 -19
  13. graphiti_core/llm_client/gemini_client.py +73 -54
  14. graphiti_core/llm_client/openai_base_client.py +65 -43
  15. graphiti_core/llm_client/openai_generic_client.py +65 -43
  16. graphiti_core/models/edges/edge_db_queries.py +1 -0
  17. graphiti_core/models/nodes/node_db_queries.py +1 -0
  18. graphiti_core/nodes.py +26 -99
  19. graphiti_core/prompts/dedupe_edges.py +4 -4
  20. graphiti_core/prompts/dedupe_nodes.py +10 -10
  21. graphiti_core/prompts/extract_edges.py +4 -4
  22. graphiti_core/prompts/extract_nodes.py +26 -28
  23. graphiti_core/prompts/prompt_helpers.py +18 -2
  24. graphiti_core/prompts/snippets.py +29 -0
  25. graphiti_core/prompts/summarize_nodes.py +22 -24
  26. graphiti_core/search/search_filters.py +0 -38
  27. graphiti_core/search/search_helpers.py +4 -4
  28. graphiti_core/search/search_utils.py +84 -220
  29. graphiti_core/tracer.py +193 -0
  30. graphiti_core/utils/bulk_utils.py +16 -28
  31. graphiti_core/utils/maintenance/community_operations.py +4 -1
  32. graphiti_core/utils/maintenance/edge_operations.py +26 -15
  33. graphiti_core/utils/maintenance/graph_data_operations.py +6 -25
  34. graphiti_core/utils/maintenance/node_operations.py +98 -51
  35. graphiti_core/utils/maintenance/temporal_operations.py +4 -1
  36. graphiti_core/utils/text_utils.py +53 -0
  37. {graphiti_core-0.21.0rc13.dist-info → graphiti_core-0.22.0.dist-info}/METADATA +7 -3
  38. {graphiti_core-0.21.0rc13.dist-info → graphiti_core-0.22.0.dist-info}/RECORD +41 -35
  39. /graphiti_core/{utils/maintenance/utils.py → driver/graph_operations/__init__.py} +0 -0
  40. {graphiti_core-0.21.0rc13.dist-info → graphiti_core-0.22.0.dist-info}/WHEEL +0 -0
  41. {graphiti_core-0.21.0rc13.dist-info → graphiti_core-0.22.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,193 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ from abc import ABC, abstractmethod
18
+ from collections.abc import Generator
19
+ from contextlib import AbstractContextManager, contextmanager, suppress
20
+ from typing import TYPE_CHECKING, Any
21
+
22
+ if TYPE_CHECKING:
23
+ from opentelemetry.trace import Span, StatusCode
24
+
25
+ try:
26
+ from opentelemetry.trace import Span, StatusCode
27
+
28
+ OTEL_AVAILABLE = True
29
+ except ImportError:
30
+ OTEL_AVAILABLE = False
31
+
32
+
33
+ class TracerSpan(ABC):
34
+ """Abstract base class for tracer spans."""
35
+
36
+ @abstractmethod
37
+ def add_attributes(self, attributes: dict[str, Any]) -> None:
38
+ """Add attributes to the span."""
39
+ pass
40
+
41
+ @abstractmethod
42
+ def set_status(self, status: str, description: str | None = None) -> None:
43
+ """Set the status of the span."""
44
+ pass
45
+
46
+ @abstractmethod
47
+ def record_exception(self, exception: Exception) -> None:
48
+ """Record an exception in the span."""
49
+ pass
50
+
51
+
52
class Tracer(ABC):
    """Interface for tracer implementations.

    A tracer hands out spans as context managers so callers can write
    ``with tracer.start_span('op') as span: ...``.
    """

    @abstractmethod
    def start_span(self, name: str) -> AbstractContextManager[TracerSpan]:
        """Open a new span named ``name``, returned as a context manager."""
        ...
59
+
60
+
61
class NoOpSpan(TracerSpan):
    """Span that silently discards everything (used when tracing is disabled)."""

    def add_attributes(self, attributes: dict[str, Any]) -> None:
        """Ignore the attributes."""
        return None

    def set_status(self, status: str, description: str | None = None) -> None:
        """Ignore the status."""
        return None

    def record_exception(self, exception: Exception) -> None:
        """Ignore the exception."""
        return None
72
+
73
+
74
class NoOpTracer(Tracer):
    """Tracer that produces no-op spans (used when tracing is disabled)."""

    @contextmanager
    def start_span(self, name: str) -> Generator[NoOpSpan, None, None]:
        """Yield a span that discards all data; ``name`` is ignored."""
        yield NoOpSpan()
81
+
82
+
83
class OpenTelemetrySpan(TracerSpan):
    """Wrapper for an OpenTelemetry span.

    Every operation suppresses its own exceptions so that tracing problems
    never break the traced operation. All three methods use
    ``contextlib.suppress`` for this (previously ``add_attributes`` and
    ``set_status`` used ``try/except: pass`` while ``record_exception`` used
    ``suppress`` — now consistent).
    """

    def __init__(self, span: 'Span'):
        # The underlying opentelemetry.trace.Span instance.
        self._span = span

    def add_attributes(self, attributes: dict[str, Any]) -> None:
        """Add attributes to the OpenTelemetry span.

        ``None`` values are dropped; non-primitive values are stringified,
        since OTel attribute values must be str/int/float/bool.
        """
        with suppress(Exception):  # silently ignore tracing errors
            filtered_attrs: dict[str, Any] = {}
            for key, value in attributes.items():
                if value is None:
                    continue
                # Keep primitives as-is; convert everything else to string.
                if isinstance(value, str | int | float | bool):
                    filtered_attrs[key] = value
                else:
                    filtered_attrs[key] = str(value)

            if filtered_attrs:
                self._span.set_attributes(filtered_attrs)

    def set_status(self, status: str, description: str | None = None) -> None:
        """Set the status of the OpenTelemetry span.

        Only ``'error'`` and ``'ok'`` are mapped to OTel status codes; any
        other value is ignored, as is the call when OTel is unavailable.
        """
        with suppress(Exception):  # silently ignore tracing errors
            if OTEL_AVAILABLE:
                if status == 'error':
                    self._span.set_status(StatusCode.ERROR, description)
                elif status == 'ok':
                    self._span.set_status(StatusCode.OK, description)

    def record_exception(self, exception: Exception) -> None:
        """Record an exception event on the OpenTelemetry span."""
        with suppress(Exception):
            self._span.record_exception(exception)
124
+
125
+
126
class OpenTelemetryTracer(Tracer):
    """Wrapper for OpenTelemetry tracer with configurable span name prefix."""

    def __init__(self, tracer: Any, span_prefix: str = 'graphiti'):
        """
        Initialize the OpenTelemetry tracer wrapper.

        Parameters
        ----------
        tracer : opentelemetry.trace.Tracer
            The OpenTelemetry tracer instance.
        span_prefix : str, optional
            Prefix to prepend to all span names. Defaults to 'graphiti'.
            A trailing '.' is stripped to avoid doubled dots in span names.

        Raises
        ------
        ImportError
            If the opentelemetry-api package is not installed.
        """
        if not OTEL_AVAILABLE:
            raise ImportError(
                'OpenTelemetry is not installed. Install it with: pip install opentelemetry-api'
            )
        self._tracer = tracer
        self._span_prefix = span_prefix.rstrip('.')

    @contextmanager
    def start_span(self, name: str) -> Generator[OpenTelemetrySpan | NoOpSpan, None, None]:
        """Start a new OpenTelemetry span with the configured prefix.

        Only failures of the tracing backend itself are swallowed (a no-op
        span is yielded so the operation still runs). Exceptions raised by
        the traced operation propagate unchanged. Previously the whole
        ``with ... yield`` was inside the ``try``: an exception from the
        caller's block was thrown into the generator, caught here, and a
        second ``yield`` followed — which made @contextmanager raise
        RuntimeError('generator didn't stop after throw()') and lose the
        caller's original exception.
        """
        try:
            span_cm = self._tracer.start_as_current_span(f'{self._span_prefix}.{name}')
            span = span_cm.__enter__()
        except Exception:
            # Tracing failed to start; let the operation proceed untraced.
            yield NoOpSpan()
            return

        try:
            yield OpenTelemetrySpan(span)
        except BaseException as exc:
            # Close the span with the error, then re-raise the caller's exception.
            with suppress(Exception):
                span_cm.__exit__(type(exc), exc, exc.__traceback__)
            raise
        else:
            with suppress(Exception):
                span_cm.__exit__(None, None, None)
157
+
158
+
159
def create_tracer(otel_tracer: Any | None = None, span_prefix: str = 'graphiti') -> Tracer:
    """
    Build the appropriate tracer for the given configuration.

    Parameters
    ----------
    otel_tracer : opentelemetry.trace.Tracer | None, optional
        An OpenTelemetry tracer instance. If None, a no-op tracer is returned.
    span_prefix : str, optional
        Prefix to prepend to all span names. Defaults to 'graphiti'.

    Returns
    -------
    Tracer
        An OpenTelemetryTracer when a tracer is supplied and the
        opentelemetry-api package is importable; otherwise a NoOpTracer.

    Examples
    --------
    Using with OpenTelemetry:

    >>> from opentelemetry import trace
    >>> otel_tracer = trace.get_tracer(__name__)
    >>> tracer = create_tracer(otel_tracer, span_prefix='myapp.graphiti')

    Using no-op tracer:

    >>> tracer = create_tracer()  # Returns NoOpTracer
    """
    tracing_unavailable = otel_tracer is None or not OTEL_AVAILABLE
    return NoOpTracer() if tracing_unavailable else OpenTelemetryTracer(otel_tracer, span_prefix)
@@ -24,9 +24,6 @@ from pydantic import BaseModel, Field
24
24
  from typing_extensions import Any
25
25
 
26
26
  from graphiti_core.driver.driver import (
27
- ENTITY_EDGE_INDEX_NAME,
28
- ENTITY_INDEX_NAME,
29
- EPISODE_INDEX_NAME,
30
27
  GraphDriver,
31
28
  GraphDriverSession,
32
29
  GraphProvider,
@@ -177,12 +174,10 @@ async def add_nodes_and_edges_bulk_tx(
177
174
  'group_id': node.group_id,
178
175
  'summary': node.summary,
179
176
  'created_at': node.created_at,
177
+ 'name_embedding': node.name_embedding,
178
+ 'labels': list(set(node.labels + ['Entity'])),
180
179
  }
181
180
 
182
- if not bool(driver.aoss_client):
183
- entity_data['name_embedding'] = node.name_embedding
184
-
185
- entity_data['labels'] = list(set(node.labels + ['Entity']))
186
181
  if driver.provider == GraphProvider.KUZU:
187
182
  attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
188
183
  entity_data['attributes'] = json.dumps(attributes)
@@ -207,11 +202,9 @@ async def add_nodes_and_edges_bulk_tx(
207
202
  'expired_at': edge.expired_at,
208
203
  'valid_at': edge.valid_at,
209
204
  'invalid_at': edge.invalid_at,
205
+ 'fact_embedding': edge.fact_embedding,
210
206
  }
211
207
 
212
- if not bool(driver.aoss_client):
213
- edge_data['fact_embedding'] = edge.fact_embedding
214
-
215
208
  if driver.provider == GraphProvider.KUZU:
216
209
  attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
217
210
  edge_data['attributes'] = json.dumps(attributes)
@@ -220,7 +213,17 @@ async def add_nodes_and_edges_bulk_tx(
220
213
 
221
214
  edges.append(edge_data)
222
215
 
223
- if driver.provider == GraphProvider.KUZU:
216
+ if driver.graph_operations_interface:
217
+ await driver.graph_operations_interface.episodic_node_save_bulk(
218
+ None, driver, tx, episodic_nodes
219
+ )
220
+ await driver.graph_operations_interface.node_save_bulk(None, driver, tx, nodes)
221
+ await driver.graph_operations_interface.episodic_edge_save_bulk(
222
+ None, driver, tx, episodic_edges
223
+ )
224
+ await driver.graph_operations_interface.edge_save_bulk(None, driver, tx, edges)
225
+
226
+ elif driver.provider == GraphProvider.KUZU:
224
227
  # FIXME: Kuzu's UNWIND does not currently support STRUCT[] type properly, so we insert the data one by one instead for now.
225
228
  episode_query = get_episode_node_save_bulk_query(driver.provider)
226
229
  for episode in episodes:
@@ -237,9 +240,7 @@ async def add_nodes_and_edges_bulk_tx(
237
240
  else:
238
241
  await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
239
242
  await tx.run(
240
- get_entity_node_save_bulk_query(
241
- driver.provider, nodes, has_aoss=bool(driver.aoss_client)
242
- ),
243
+ get_entity_node_save_bulk_query(driver.provider, nodes),
243
244
  nodes=nodes,
244
245
  )
245
246
  await tx.run(
@@ -247,23 +248,10 @@ async def add_nodes_and_edges_bulk_tx(
247
248
  episodic_edges=[edge.model_dump() for edge in episodic_edges],
248
249
  )
249
250
  await tx.run(
250
- get_entity_edge_save_bulk_query(driver.provider, has_aoss=bool(driver.aoss_client)),
251
+ get_entity_edge_save_bulk_query(driver.provider),
251
252
  entity_edges=edges,
252
253
  )
253
254
 
254
- if bool(driver.aoss_client):
255
- for node_data, entity_node in zip(nodes, entity_nodes, strict=True):
256
- if node_data.get('uuid') == entity_node.uuid:
257
- node_data['name_embedding'] = entity_node.name_embedding
258
-
259
- for edge_data, entity_edge in zip(edges, entity_edges, strict=True):
260
- if edge_data.get('uuid') == entity_edge.uuid:
261
- edge_data['fact_embedding'] = entity_edge.fact_embedding
262
-
263
- await driver.save_to_aoss(EPISODE_INDEX_NAME, episodes)
264
- await driver.save_to_aoss(ENTITY_INDEX_NAME, nodes)
265
- await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges)
266
-
267
255
 
268
256
  async def extract_nodes_and_edges_bulk(
269
257
  clients: GraphitiClients,
@@ -138,7 +138,9 @@ async def summarize_pair(llm_client: LLMClient, summary_pair: tuple[str, str]) -
138
138
  }
139
139
 
140
140
  llm_response = await llm_client.generate_response(
141
- prompt_library.summarize_nodes.summarize_pair(context), response_model=Summary
141
+ prompt_library.summarize_nodes.summarize_pair(context),
142
+ response_model=Summary,
143
+ prompt_name='summarize_nodes.summarize_pair',
142
144
  )
143
145
 
144
146
  pair_summary = llm_response.get('summary', '')
@@ -154,6 +156,7 @@ async def generate_summary_description(llm_client: LLMClient, summary: str) -> s
154
156
  llm_response = await llm_client.generate_response(
155
157
  prompt_library.summarize_nodes.summary_description(context),
156
158
  response_model=SummaryDescription,
159
+ prompt_name='summarize_nodes.summary_description',
157
160
  )
158
161
 
159
162
  description = llm_response.get('description', '')
@@ -139,6 +139,8 @@ async def extract_edges(
139
139
  prompt_library.extract_edges.edge(context),
140
140
  response_model=ExtractedEdges,
141
141
  max_tokens=extract_edges_max_tokens,
142
+ group_id=group_id,
143
+ prompt_name='extract_edges.edge',
142
144
  )
143
145
  edges_data = ExtractedEdges(**llm_response).edges
144
146
 
@@ -150,6 +152,8 @@ async def extract_edges(
150
152
  prompt_library.extract_edges.reflexion(context),
151
153
  response_model=MissingFacts,
152
154
  max_tokens=extract_edges_max_tokens,
155
+ group_id=group_id,
156
+ prompt_name='extract_edges.reflexion',
153
157
  )
154
158
 
155
159
  missing_facts = reflexion_response.get('missing_facts', [])
@@ -409,21 +413,26 @@ def resolve_edge_contradictions(
409
413
  invalidated_edges: list[EntityEdge] = []
410
414
  for edge in invalidation_candidates:
411
415
  # (Edge invalid before new edge becomes valid) or (new edge invalid before edge becomes valid)
416
+ edge_invalid_at_utc = ensure_utc(edge.invalid_at)
417
+ resolved_edge_valid_at_utc = ensure_utc(resolved_edge.valid_at)
418
+ edge_valid_at_utc = ensure_utc(edge.valid_at)
419
+ resolved_edge_invalid_at_utc = ensure_utc(resolved_edge.invalid_at)
420
+
412
421
  if (
413
- edge.invalid_at is not None
414
- and resolved_edge.valid_at is not None
415
- and edge.invalid_at <= resolved_edge.valid_at
422
+ edge_invalid_at_utc is not None
423
+ and resolved_edge_valid_at_utc is not None
424
+ and edge_invalid_at_utc <= resolved_edge_valid_at_utc
416
425
  ) or (
417
- edge.valid_at is not None
418
- and resolved_edge.invalid_at is not None
419
- and resolved_edge.invalid_at <= edge.valid_at
426
+ edge_valid_at_utc is not None
427
+ and resolved_edge_invalid_at_utc is not None
428
+ and resolved_edge_invalid_at_utc <= edge_valid_at_utc
420
429
  ):
421
430
  continue
422
431
  # New edge invalidates edge
423
432
  elif (
424
- edge.valid_at is not None
425
- and resolved_edge.valid_at is not None
426
- and edge.valid_at < resolved_edge.valid_at
433
+ edge_valid_at_utc is not None
434
+ and resolved_edge_valid_at_utc is not None
435
+ and edge_valid_at_utc < resolved_edge_valid_at_utc
427
436
  ):
428
437
  edge.invalid_at = resolved_edge.valid_at
429
438
  edge.expired_at = edge.expired_at if edge.expired_at is not None else utc_now()
@@ -524,6 +533,7 @@ async def resolve_extracted_edge(
524
533
  prompt_library.dedupe_edges.resolve_edge(context),
525
534
  response_model=EdgeDuplicate,
526
535
  model_size=ModelSize.small,
536
+ prompt_name='dedupe_edges.resolve_edge',
527
537
  )
528
538
  response_object = EdgeDuplicate(**llm_response)
529
539
  duplicate_facts = response_object.duplicate_facts
@@ -587,6 +597,7 @@ async def resolve_extracted_edge(
587
597
  prompt_library.extract_edges.extract_attributes(edge_attributes_context),
588
598
  response_model=edge_model, # type: ignore
589
599
  model_size=ModelSize.small,
600
+ prompt_name='extract_edges.extract_attributes',
590
601
  )
591
602
 
592
603
  resolved_edge.attributes = edge_attributes_response
@@ -613,14 +624,14 @@ async def resolve_extracted_edge(
613
624
 
614
625
  # Determine if the new_edge needs to be expired
615
626
  if resolved_edge.expired_at is None:
616
- invalidation_candidates.sort(key=lambda c: (c.valid_at is None, c.valid_at))
627
+ invalidation_candidates.sort(key=lambda c: (c.valid_at is None, ensure_utc(c.valid_at)))
617
628
  for candidate in invalidation_candidates:
629
+ candidate_valid_at_utc = ensure_utc(candidate.valid_at)
630
+ resolved_edge_valid_at_utc = ensure_utc(resolved_edge.valid_at)
618
631
  if (
619
- candidate.valid_at
620
- and resolved_edge.valid_at
621
- and candidate.valid_at.tzinfo
622
- and resolved_edge.valid_at.tzinfo
623
- and candidate.valid_at > resolved_edge.valid_at
632
+ candidate_valid_at_utc is not None
633
+ and resolved_edge_valid_at_utc is not None
634
+ and candidate_valid_at_utc > resolved_edge_valid_at_utc
624
635
  ):
625
636
  # Expire new edge since we have information about more recent events
626
637
  resolved_edge.invalid_at = candidate.valid_at
@@ -34,30 +34,13 @@ logger = logging.getLogger(__name__)
34
34
 
35
35
 
36
36
  async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
37
- if driver.aoss_client:
38
- await driver.create_aoss_indices() # pyright: ignore[reportAttributeAccessIssue]
39
- return
40
37
  if delete_existing:
41
- records, _, _ = await driver.execute_query(
42
- """
43
- SHOW INDEXES YIELD name
44
- """,
45
- )
46
- index_names = [record['name'] for record in records]
47
- await semaphore_gather(
48
- *[
49
- driver.execute_query(
50
- """DROP INDEX $name""",
51
- name=name,
52
- )
53
- for name in index_names
54
- ]
55
- )
38
+ await driver.delete_all_indexes()
56
39
 
57
40
  range_indices: list[LiteralString] = get_range_indices(driver.provider)
58
41
 
59
- # Don't create fulltext indices if OpenSearch is being used
60
- if not driver.aoss_client:
42
+ # Don't create fulltext indices if search_interface is being used
43
+ if not driver.search_interface:
61
44
  fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
62
45
 
63
46
  if driver.provider == GraphProvider.KUZU:
@@ -95,8 +78,6 @@ async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
95
78
 
96
79
  async def delete_all(tx):
97
80
  await tx.run('MATCH (n) DETACH DELETE n')
98
- if driver.aoss_client:
99
- await driver.clear_aoss_indices()
100
81
 
101
82
  async def delete_group_ids(tx):
102
83
  labels = ['Entity', 'Episodic', 'Community']
@@ -153,9 +134,9 @@ async def retrieve_episodes(
153
134
 
154
135
  query: LiteralString = (
155
136
  """
156
- MATCH (e:Episodic)
157
- WHERE e.valid_at <= $reference_time
158
- """
137
+ MATCH (e:Episodic)
138
+ WHERE e.valid_at <= $reference_time
139
+ """
159
140
  + query_filter
160
141
  + """
161
142
  RETURN
@@ -53,6 +53,7 @@ from graphiti_core.utils.maintenance.dedup_helpers import (
53
53
  from graphiti_core.utils.maintenance.edge_operations import (
54
54
  filter_existing_duplicate_of_edges,
55
55
  )
56
+ from graphiti_core.utils.text_utils import MAX_SUMMARY_CHARS, truncate_at_sentence
56
57
 
57
58
  logger = logging.getLogger(__name__)
58
59
 
@@ -64,6 +65,7 @@ async def extract_nodes_reflexion(
64
65
  episode: EpisodicNode,
65
66
  previous_episodes: list[EpisodicNode],
66
67
  node_names: list[str],
68
+ group_id: str | None = None,
67
69
  ) -> list[str]:
68
70
  # Prepare context for LLM
69
71
  context = {
@@ -73,7 +75,10 @@ async def extract_nodes_reflexion(
73
75
  }
74
76
 
75
77
  llm_response = await llm_client.generate_response(
76
- prompt_library.extract_nodes.reflexion(context), MissedEntities
78
+ prompt_library.extract_nodes.reflexion(context),
79
+ MissedEntities,
80
+ group_id=group_id,
81
+ prompt_name='extract_nodes.reflexion',
77
82
  )
78
83
  missed_entities = llm_response.get('missed_entities', [])
79
84
 
@@ -129,16 +134,22 @@ async def extract_nodes(
129
134
  llm_response = await llm_client.generate_response(
130
135
  prompt_library.extract_nodes.extract_message(context),
131
136
  response_model=ExtractedEntities,
137
+ group_id=episode.group_id,
138
+ prompt_name='extract_nodes.extract_message',
132
139
  )
133
140
  elif episode.source == EpisodeType.text:
134
141
  llm_response = await llm_client.generate_response(
135
142
  prompt_library.extract_nodes.extract_text(context),
136
143
  response_model=ExtractedEntities,
144
+ group_id=episode.group_id,
145
+ prompt_name='extract_nodes.extract_text',
137
146
  )
138
147
  elif episode.source == EpisodeType.json:
139
148
  llm_response = await llm_client.generate_response(
140
149
  prompt_library.extract_nodes.extract_json(context),
141
150
  response_model=ExtractedEntities,
151
+ group_id=episode.group_id,
152
+ prompt_name='extract_nodes.extract_json',
142
153
  )
143
154
 
144
155
  response_object = ExtractedEntities(**llm_response)
@@ -152,6 +163,7 @@ async def extract_nodes(
152
163
  episode,
153
164
  previous_episodes,
154
165
  [entity.name for entity in extracted_entities],
166
+ episode.group_id,
155
167
  )
156
168
 
157
169
  entities_missed = len(missing_entities) != 0
@@ -310,6 +322,7 @@ async def _resolve_with_llm(
310
322
  llm_response = await llm_client.generate_response(
311
323
  prompt_library.dedupe_nodes.nodes(context),
312
324
  response_model=NodeResolutions,
325
+ prompt_name='dedupe_nodes.nodes',
313
326
  )
314
327
 
315
328
  node_resolutions: list[NodeDuplicate] = NodeResolutions(**llm_response).entity_resolutions
@@ -478,63 +491,97 @@ async def extract_attributes_from_node(
478
491
  entity_type: type[BaseModel] | None = None,
479
492
  should_summarize_node: NodeSummaryFilter | None = None,
480
493
  ) -> EntityNode:
481
- node_context: dict[str, Any] = {
482
- 'name': node.name,
483
- 'summary': node.summary,
484
- 'entity_types': node.labels,
485
- 'attributes': node.attributes,
486
- }
494
+ # Extract attributes if entity type is defined and has attributes
495
+ llm_response = await _extract_entity_attributes(
496
+ llm_client, node, episode, previous_episodes, entity_type
497
+ )
487
498
 
488
- attributes_context: dict[str, Any] = {
489
- 'node': node_context,
490
- 'episode_content': episode.content if episode is not None else '',
491
- 'previous_episodes': (
492
- [ep.content for ep in previous_episodes] if previous_episodes is not None else []
493
- ),
494
- }
499
+ # Extract summary if needed
500
+ await _extract_entity_summary(
501
+ llm_client, node, episode, previous_episodes, should_summarize_node
502
+ )
495
503
 
496
- summary_context: dict[str, Any] = {
497
- 'node': node_context,
498
- 'episode_content': episode.content if episode is not None else '',
499
- 'previous_episodes': (
500
- [ep.content for ep in previous_episodes] if previous_episodes is not None else []
501
- ),
502
- }
504
+ node.attributes.update(llm_response)
505
+
506
+ return node
503
507
 
504
- has_entity_attributes: bool = bool(
505
- entity_type is not None and len(entity_type.model_fields) != 0
508
+
509
+ async def _extract_entity_attributes(
510
+ llm_client: LLMClient,
511
+ node: EntityNode,
512
+ episode: EpisodicNode | None,
513
+ previous_episodes: list[EpisodicNode] | None,
514
+ entity_type: type[BaseModel] | None,
515
+ ) -> dict[str, Any]:
516
+ if entity_type is None or len(entity_type.model_fields) == 0:
517
+ return {}
518
+
519
+ attributes_context = _build_episode_context(
520
+ # should not include summary
521
+ node_data={
522
+ 'name': node.name,
523
+ 'entity_types': node.labels,
524
+ 'attributes': node.attributes,
525
+ },
526
+ episode=episode,
527
+ previous_episodes=previous_episodes,
506
528
  )
507
529
 
508
- llm_response = (
509
- (
510
- await llm_client.generate_response(
511
- prompt_library.extract_nodes.extract_attributes(attributes_context),
512
- response_model=entity_type,
513
- model_size=ModelSize.small,
514
- )
515
- )
516
- if has_entity_attributes
517
- else {}
530
+ llm_response = await llm_client.generate_response(
531
+ prompt_library.extract_nodes.extract_attributes(attributes_context),
532
+ response_model=entity_type,
533
+ model_size=ModelSize.small,
534
+ group_id=node.group_id,
535
+ prompt_name='extract_nodes.extract_attributes',
518
536
  )
519
537
 
520
- # Determine if summary should be generated
521
- generate_summary = True
522
- if should_summarize_node is not None:
523
- generate_summary = await should_summarize_node(node)
524
-
525
- # Conditionally generate summary
526
- if generate_summary:
527
- summary_response = await llm_client.generate_response(
528
- prompt_library.extract_nodes.extract_summary(summary_context),
529
- response_model=EntitySummary,
530
- model_size=ModelSize.small,
531
- )
532
- node.summary = summary_response.get('summary', '')
538
+ # validate response
539
+ entity_type(**llm_response)
533
540
 
534
- if has_entity_attributes and entity_type is not None:
535
- entity_type(**llm_response)
536
- node_attributes = {key: value for key, value in llm_response.items()}
541
+ return llm_response
537
542
 
538
- node.attributes.update(node_attributes)
539
543
 
540
- return node
544
+ async def _extract_entity_summary(
545
+ llm_client: LLMClient,
546
+ node: EntityNode,
547
+ episode: EpisodicNode | None,
548
+ previous_episodes: list[EpisodicNode] | None,
549
+ should_summarize_node: NodeSummaryFilter | None,
550
+ ) -> None:
551
+ if should_summarize_node is not None and not await should_summarize_node(node):
552
+ return
553
+
554
+ summary_context = _build_episode_context(
555
+ node_data={
556
+ 'name': node.name,
557
+ 'summary': truncate_at_sentence(node.summary, MAX_SUMMARY_CHARS),
558
+ 'entity_types': node.labels,
559
+ 'attributes': node.attributes,
560
+ },
561
+ episode=episode,
562
+ previous_episodes=previous_episodes,
563
+ )
564
+
565
+ summary_response = await llm_client.generate_response(
566
+ prompt_library.extract_nodes.extract_summary(summary_context),
567
+ response_model=EntitySummary,
568
+ model_size=ModelSize.small,
569
+ group_id=node.group_id,
570
+ prompt_name='extract_nodes.extract_summary',
571
+ )
572
+
573
+ node.summary = truncate_at_sentence(summary_response.get('summary', ''), MAX_SUMMARY_CHARS)
574
+
575
+
576
+ def _build_episode_context(
577
+ node_data: dict[str, Any],
578
+ episode: EpisodicNode | None,
579
+ previous_episodes: list[EpisodicNode] | None,
580
+ ) -> dict[str, Any]:
581
+ return {
582
+ 'node': node_data,
583
+ 'episode_content': episode.content if episode is not None else '',
584
+ 'previous_episodes': (
585
+ [ep.content for ep in previous_episodes] if previous_episodes is not None else []
586
+ ),
587
+ }
@@ -43,7 +43,9 @@ async def extract_edge_dates(
43
43
  'reference_timestamp': current_episode.valid_at.isoformat(),
44
44
  }
45
45
  llm_response = await llm_client.generate_response(
46
- prompt_library.extract_edge_dates.v1(context), response_model=EdgeDates
46
+ prompt_library.extract_edge_dates.v1(context),
47
+ response_model=EdgeDates,
48
+ prompt_name='extract_edge_dates.v1',
47
49
  )
48
50
 
49
51
  valid_at = llm_response.get('valid_at')
@@ -90,6 +92,7 @@ async def get_edge_contradictions(
90
92
  prompt_library.invalidate_edges.v2(context),
91
93
  response_model=InvalidatedEdges,
92
94
  model_size=ModelSize.small,
95
+ prompt_name='invalidate_edges.v2',
93
96
  )
94
97
 
95
98
  contradicted_facts: list[int] = llm_response.get('contradicted_facts', [])