graphiti-core 0.21.0rc12__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/driver/driver.py +4 -211
- graphiti_core/driver/falkordb_driver.py +31 -3
- graphiti_core/driver/graph_operations/graph_operations.py +195 -0
- graphiti_core/driver/neo4j_driver.py +0 -49
- graphiti_core/driver/neptune_driver.py +43 -26
- graphiti_core/driver/search_interface/__init__.py +0 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +11 -34
- graphiti_core/graphiti.py +459 -326
- graphiti_core/graphiti_types.py +2 -0
- graphiti_core/llm_client/anthropic_client.py +64 -45
- graphiti_core/llm_client/client.py +67 -19
- graphiti_core/llm_client/gemini_client.py +73 -54
- graphiti_core/llm_client/openai_base_client.py +65 -43
- graphiti_core/llm_client/openai_generic_client.py +65 -43
- graphiti_core/models/edges/edge_db_queries.py +1 -0
- graphiti_core/models/nodes/node_db_queries.py +1 -0
- graphiti_core/nodes.py +26 -99
- graphiti_core/prompts/dedupe_edges.py +4 -4
- graphiti_core/prompts/dedupe_nodes.py +10 -10
- graphiti_core/prompts/extract_edges.py +4 -4
- graphiti_core/prompts/extract_nodes.py +26 -28
- graphiti_core/prompts/prompt_helpers.py +18 -2
- graphiti_core/prompts/snippets.py +29 -0
- graphiti_core/prompts/summarize_nodes.py +22 -24
- graphiti_core/search/search_filters.py +0 -38
- graphiti_core/search/search_helpers.py +4 -4
- graphiti_core/search/search_utils.py +84 -220
- graphiti_core/tracer.py +193 -0
- graphiti_core/utils/bulk_utils.py +16 -28
- graphiti_core/utils/maintenance/community_operations.py +4 -1
- graphiti_core/utils/maintenance/edge_operations.py +30 -15
- graphiti_core/utils/maintenance/graph_data_operations.py +6 -25
- graphiti_core/utils/maintenance/node_operations.py +99 -51
- graphiti_core/utils/maintenance/temporal_operations.py +4 -1
- graphiti_core/utils/text_utils.py +53 -0
- {graphiti_core-0.21.0rc12.dist-info → graphiti_core-0.22.0.dist-info}/METADATA +7 -3
- {graphiti_core-0.21.0rc12.dist-info → graphiti_core-0.22.0.dist-info}/RECORD +41 -35
- /graphiti_core/{utils/maintenance/utils.py → driver/graph_operations/__init__.py} +0 -0
- {graphiti_core-0.21.0rc12.dist-info → graphiti_core-0.22.0.dist-info}/WHEEL +0 -0
- {graphiti_core-0.21.0rc12.dist-info → graphiti_core-0.22.0.dist-info}/licenses/LICENSE +0 -0
graphiti_core/tracer.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from abc import ABC, abstractmethod
|
|
18
|
+
from collections.abc import Generator
|
|
19
|
+
from contextlib import AbstractContextManager, contextmanager, suppress
|
|
20
|
+
from typing import TYPE_CHECKING, Any
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from opentelemetry.trace import Span, StatusCode
|
|
24
|
+
|
|
25
|
+
try:
|
|
26
|
+
from opentelemetry.trace import Span, StatusCode
|
|
27
|
+
|
|
28
|
+
OTEL_AVAILABLE = True
|
|
29
|
+
except ImportError:
|
|
30
|
+
OTEL_AVAILABLE = False
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class TracerSpan(ABC):
|
|
34
|
+
"""Abstract base class for tracer spans."""
|
|
35
|
+
|
|
36
|
+
@abstractmethod
|
|
37
|
+
def add_attributes(self, attributes: dict[str, Any]) -> None:
|
|
38
|
+
"""Add attributes to the span."""
|
|
39
|
+
pass
|
|
40
|
+
|
|
41
|
+
@abstractmethod
|
|
42
|
+
def set_status(self, status: str, description: str | None = None) -> None:
|
|
43
|
+
"""Set the status of the span."""
|
|
44
|
+
pass
|
|
45
|
+
|
|
46
|
+
@abstractmethod
|
|
47
|
+
def record_exception(self, exception: Exception) -> None:
|
|
48
|
+
"""Record an exception in the span."""
|
|
49
|
+
pass
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class Tracer(ABC):
    """Interface for objects that open named tracing spans."""

    @abstractmethod
    def start_span(self, name: str) -> AbstractContextManager[TracerSpan]:
        """Open a span called ``name``; entering the returned context manager yields the span."""
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class NoOpSpan(TracerSpan):
    """Span that discards every call; handed out when tracing is disabled or could not start."""

    def add_attributes(self, attributes: dict[str, Any]) -> None:
        """Discard ``attributes``."""

    def set_status(self, status: str, description: str | None = None) -> None:
        """Discard the status update."""

    def record_exception(self, exception: Exception) -> None:
        """Discard ``exception``."""
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class NoOpTracer(Tracer):
    """Tracer whose spans do nothing; used when no OpenTelemetry tracer is configured."""

    @contextmanager
    def start_span(self, name: str) -> Generator[NoOpSpan, None, None]:
        """Yield a fresh NoOpSpan; ``name`` is accepted for interface compatibility and ignored."""
        inert_span = NoOpSpan()
        yield inert_span
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class OpenTelemetrySpan(TracerSpan):
    """
    Adapter exposing an OpenTelemetry span through the TracerSpan interface.

    Every method is best-effort: exceptions raised by the underlying span are
    suppressed so that tracing problems never interrupt the traced operation.
    """

    def __init__(self, span: 'Span'):
        self._span = span

    def add_attributes(self, attributes: dict[str, Any]) -> None:
        """
        Add attributes to the OpenTelemetry span.

        ``None`` values are dropped. Values that are not ``str``/``int``/``float``/``bool``
        are converted with ``str()`` before being set. Nothing is sent when no
        attributes remain after filtering.
        """
        with suppress(Exception):
            filtered_attrs = {
                key: value if isinstance(value, str | int | float | bool) else str(value)
                for key, value in attributes.items()
                if value is not None
            }
            if filtered_attrs:
                self._span.set_attributes(filtered_attrs)

    def set_status(self, status: str, description: str | None = None) -> None:
        """
        Set the status of the OpenTelemetry span.

        ``'error'`` maps to ``StatusCode.ERROR`` and ``'ok'`` to ``StatusCode.OK``;
        any other value is ignored. ``description`` is only forwarded for errors,
        because OpenTelemetry ignores (and warns about) descriptions on non-ERROR
        statuses.
        """
        if not OTEL_AVAILABLE:
            return
        with suppress(Exception):
            if status == 'error':
                self._span.set_status(StatusCode.ERROR, description)
            elif status == 'ok':
                self._span.set_status(StatusCode.OK)

    def record_exception(self, exception: Exception) -> None:
        """Record an exception in the OpenTelemetry span."""
        with suppress(Exception):
            self._span.record_exception(exception)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class OpenTelemetryTracer(Tracer):
    """Wrapper for OpenTelemetry tracer with configurable span name prefix."""

    def __init__(self, tracer: Any, span_prefix: str = 'graphiti'):
        """
        Initialize the OpenTelemetry tracer wrapper.

        Parameters
        ----------
        tracer : opentelemetry.trace.Tracer
            The OpenTelemetry tracer instance.
        span_prefix : str, optional
            Prefix to prepend to all span names (trailing dots are stripped).
            Defaults to 'graphiti'. An empty prefix leaves span names unprefixed.

        Raises
        ------
        ImportError
            If the OpenTelemetry API could not be imported.
        """
        if not OTEL_AVAILABLE:
            raise ImportError(
                'OpenTelemetry is not installed. Install it with: pip install opentelemetry-api'
            )
        self._tracer = tracer
        self._span_prefix = span_prefix.rstrip('.')

    @contextmanager
    def start_span(self, name: str) -> Generator[OpenTelemetrySpan | NoOpSpan, None, None]:
        """
        Start a new OpenTelemetry span named ``<span_prefix>.<name>``.

        Tracing is best-effort. If the span cannot be started, a NoOpSpan is
        yielded instead, and errors raised while ending the span are ignored.

        Exceptions raised by the caller's ``with`` body are never swallowed or
        replaced. They are passed to the OpenTelemetry context manager, so it can
        record them on the span, and then re-raised unchanged.
        """
        full_name = f'{self._span_prefix}.{name}' if self._span_prefix else name

        # Only span *creation* is guarded here. Wrapping the yield itself in an
        # ``except`` that yields again makes contextlib raise
        # RuntimeError("generator didn't stop after throw()") and masks the
        # caller's real exception.
        span_cm: Any = None
        otel_span: Any = None
        with suppress(Exception):
            candidate_cm = self._tracer.start_as_current_span(full_name)
            otel_span = candidate_cm.__enter__()
            span_cm = candidate_cm

        if span_cm is None:
            # Tracing could not be started; let the operation run untraced.
            yield NoOpSpan()
            return

        try:
            yield OpenTelemetrySpan(otel_span)
        except BaseException as exc:
            # BaseException (e.g. asyncio.CancelledError) must also end the span
            # and restore the previous context before propagating.
            with suppress(Exception):
                span_cm.__exit__(type(exc), exc, exc.__traceback__)
            raise
        else:
            with suppress(Exception):
                span_cm.__exit__(None, None, None)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def create_tracer(otel_tracer: Any | None = None, span_prefix: str = 'graphiti') -> Tracer:
    """
    Build the tracer to use for instrumentation.

    Parameters
    ----------
    otel_tracer : opentelemetry.trace.Tracer | None, optional
        OpenTelemetry tracer to wrap. When it is None, or when OpenTelemetry
        could not be imported, a no-op tracer is returned instead.
    span_prefix : str, optional
        Prefix prepended to every span name. Defaults to 'graphiti'.

    Returns
    -------
    Tracer
        An OpenTelemetryTracer wrapping ``otel_tracer``, otherwise a NoOpTracer.

    Examples
    --------
    Using with OpenTelemetry:

    >>> from opentelemetry import trace
    >>> otel_tracer = trace.get_tracer(__name__)
    >>> tracer = create_tracer(otel_tracer, span_prefix='myapp.graphiti')

    Using no-op tracer:

    >>> tracer = create_tracer()  # Returns NoOpTracer
    """
    tracing_enabled = otel_tracer is not None and OTEL_AVAILABLE
    if not tracing_enabled:
        return NoOpTracer()
    return OpenTelemetryTracer(otel_tracer, span_prefix)
|
|
@@ -24,9 +24,6 @@ from pydantic import BaseModel, Field
|
|
|
24
24
|
from typing_extensions import Any
|
|
25
25
|
|
|
26
26
|
from graphiti_core.driver.driver import (
|
|
27
|
-
ENTITY_EDGE_INDEX_NAME,
|
|
28
|
-
ENTITY_INDEX_NAME,
|
|
29
|
-
EPISODE_INDEX_NAME,
|
|
30
27
|
GraphDriver,
|
|
31
28
|
GraphDriverSession,
|
|
32
29
|
GraphProvider,
|
|
@@ -177,12 +174,10 @@ async def add_nodes_and_edges_bulk_tx(
|
|
|
177
174
|
'group_id': node.group_id,
|
|
178
175
|
'summary': node.summary,
|
|
179
176
|
'created_at': node.created_at,
|
|
177
|
+
'name_embedding': node.name_embedding,
|
|
178
|
+
'labels': list(set(node.labels + ['Entity'])),
|
|
180
179
|
}
|
|
181
180
|
|
|
182
|
-
if not bool(driver.aoss_client):
|
|
183
|
-
entity_data['name_embedding'] = node.name_embedding
|
|
184
|
-
|
|
185
|
-
entity_data['labels'] = list(set(node.labels + ['Entity']))
|
|
186
181
|
if driver.provider == GraphProvider.KUZU:
|
|
187
182
|
attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
|
|
188
183
|
entity_data['attributes'] = json.dumps(attributes)
|
|
@@ -207,11 +202,9 @@ async def add_nodes_and_edges_bulk_tx(
|
|
|
207
202
|
'expired_at': edge.expired_at,
|
|
208
203
|
'valid_at': edge.valid_at,
|
|
209
204
|
'invalid_at': edge.invalid_at,
|
|
205
|
+
'fact_embedding': edge.fact_embedding,
|
|
210
206
|
}
|
|
211
207
|
|
|
212
|
-
if not bool(driver.aoss_client):
|
|
213
|
-
edge_data['fact_embedding'] = edge.fact_embedding
|
|
214
|
-
|
|
215
208
|
if driver.provider == GraphProvider.KUZU:
|
|
216
209
|
attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
|
|
217
210
|
edge_data['attributes'] = json.dumps(attributes)
|
|
@@ -220,7 +213,17 @@ async def add_nodes_and_edges_bulk_tx(
|
|
|
220
213
|
|
|
221
214
|
edges.append(edge_data)
|
|
222
215
|
|
|
223
|
-
if driver.
|
|
216
|
+
if driver.graph_operations_interface:
|
|
217
|
+
await driver.graph_operations_interface.episodic_node_save_bulk(
|
|
218
|
+
None, driver, tx, episodic_nodes
|
|
219
|
+
)
|
|
220
|
+
await driver.graph_operations_interface.node_save_bulk(None, driver, tx, nodes)
|
|
221
|
+
await driver.graph_operations_interface.episodic_edge_save_bulk(
|
|
222
|
+
None, driver, tx, episodic_edges
|
|
223
|
+
)
|
|
224
|
+
await driver.graph_operations_interface.edge_save_bulk(None, driver, tx, edges)
|
|
225
|
+
|
|
226
|
+
elif driver.provider == GraphProvider.KUZU:
|
|
224
227
|
# FIXME: Kuzu's UNWIND does not currently support STRUCT[] type properly, so we insert the data one by one instead for now.
|
|
225
228
|
episode_query = get_episode_node_save_bulk_query(driver.provider)
|
|
226
229
|
for episode in episodes:
|
|
@@ -237,9 +240,7 @@ async def add_nodes_and_edges_bulk_tx(
|
|
|
237
240
|
else:
|
|
238
241
|
await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
|
|
239
242
|
await tx.run(
|
|
240
|
-
get_entity_node_save_bulk_query(
|
|
241
|
-
driver.provider, nodes, has_aoss=bool(driver.aoss_client)
|
|
242
|
-
),
|
|
243
|
+
get_entity_node_save_bulk_query(driver.provider, nodes),
|
|
243
244
|
nodes=nodes,
|
|
244
245
|
)
|
|
245
246
|
await tx.run(
|
|
@@ -247,23 +248,10 @@ async def add_nodes_and_edges_bulk_tx(
|
|
|
247
248
|
episodic_edges=[edge.model_dump() for edge in episodic_edges],
|
|
248
249
|
)
|
|
249
250
|
await tx.run(
|
|
250
|
-
get_entity_edge_save_bulk_query(driver.provider
|
|
251
|
+
get_entity_edge_save_bulk_query(driver.provider),
|
|
251
252
|
entity_edges=edges,
|
|
252
253
|
)
|
|
253
254
|
|
|
254
|
-
if bool(driver.aoss_client):
|
|
255
|
-
for node_data, entity_node in zip(nodes, entity_nodes, strict=True):
|
|
256
|
-
if node_data.get('uuid') == entity_node.uuid:
|
|
257
|
-
node_data['name_embedding'] = entity_node.name_embedding
|
|
258
|
-
|
|
259
|
-
for edge_data, entity_edge in zip(edges, entity_edges, strict=True):
|
|
260
|
-
if edge_data.get('uuid') == entity_edge.uuid:
|
|
261
|
-
edge_data['fact_embedding'] = entity_edge.fact_embedding
|
|
262
|
-
|
|
263
|
-
await driver.save_to_aoss(EPISODE_INDEX_NAME, episodes)
|
|
264
|
-
await driver.save_to_aoss(ENTITY_INDEX_NAME, nodes)
|
|
265
|
-
await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges)
|
|
266
|
-
|
|
267
255
|
|
|
268
256
|
async def extract_nodes_and_edges_bulk(
|
|
269
257
|
clients: GraphitiClients,
|
|
@@ -138,7 +138,9 @@ async def summarize_pair(llm_client: LLMClient, summary_pair: tuple[str, str]) -
|
|
|
138
138
|
}
|
|
139
139
|
|
|
140
140
|
llm_response = await llm_client.generate_response(
|
|
141
|
-
prompt_library.summarize_nodes.summarize_pair(context),
|
|
141
|
+
prompt_library.summarize_nodes.summarize_pair(context),
|
|
142
|
+
response_model=Summary,
|
|
143
|
+
prompt_name='summarize_nodes.summarize_pair',
|
|
142
144
|
)
|
|
143
145
|
|
|
144
146
|
pair_summary = llm_response.get('summary', '')
|
|
@@ -154,6 +156,7 @@ async def generate_summary_description(llm_client: LLMClient, summary: str) -> s
|
|
|
154
156
|
llm_response = await llm_client.generate_response(
|
|
155
157
|
prompt_library.summarize_nodes.summary_description(context),
|
|
156
158
|
response_model=SummaryDescription,
|
|
159
|
+
prompt_name='summarize_nodes.summary_description',
|
|
157
160
|
)
|
|
158
161
|
|
|
159
162
|
description = llm_response.get('description', '')
|
|
@@ -139,6 +139,8 @@ async def extract_edges(
|
|
|
139
139
|
prompt_library.extract_edges.edge(context),
|
|
140
140
|
response_model=ExtractedEdges,
|
|
141
141
|
max_tokens=extract_edges_max_tokens,
|
|
142
|
+
group_id=group_id,
|
|
143
|
+
prompt_name='extract_edges.edge',
|
|
142
144
|
)
|
|
143
145
|
edges_data = ExtractedEdges(**llm_response).edges
|
|
144
146
|
|
|
@@ -150,6 +152,8 @@ async def extract_edges(
|
|
|
150
152
|
prompt_library.extract_edges.reflexion(context),
|
|
151
153
|
response_model=MissingFacts,
|
|
152
154
|
max_tokens=extract_edges_max_tokens,
|
|
155
|
+
group_id=group_id,
|
|
156
|
+
prompt_name='extract_edges.reflexion',
|
|
153
157
|
)
|
|
154
158
|
|
|
155
159
|
missing_facts = reflexion_response.get('missing_facts', [])
|
|
@@ -177,6 +181,10 @@ async def extract_edges(
|
|
|
177
181
|
valid_at_datetime = None
|
|
178
182
|
invalid_at_datetime = None
|
|
179
183
|
|
|
184
|
+
# Filter out empty edges
|
|
185
|
+
if not edge_data.fact.strip():
|
|
186
|
+
continue
|
|
187
|
+
|
|
180
188
|
source_node_idx = edge_data.source_entity_id
|
|
181
189
|
target_node_idx = edge_data.target_entity_id
|
|
182
190
|
|
|
@@ -405,21 +413,26 @@ def resolve_edge_contradictions(
|
|
|
405
413
|
invalidated_edges: list[EntityEdge] = []
|
|
406
414
|
for edge in invalidation_candidates:
|
|
407
415
|
# (Edge invalid before new edge becomes valid) or (new edge invalid before edge becomes valid)
|
|
416
|
+
edge_invalid_at_utc = ensure_utc(edge.invalid_at)
|
|
417
|
+
resolved_edge_valid_at_utc = ensure_utc(resolved_edge.valid_at)
|
|
418
|
+
edge_valid_at_utc = ensure_utc(edge.valid_at)
|
|
419
|
+
resolved_edge_invalid_at_utc = ensure_utc(resolved_edge.invalid_at)
|
|
420
|
+
|
|
408
421
|
if (
|
|
409
|
-
|
|
410
|
-
and
|
|
411
|
-
and
|
|
422
|
+
edge_invalid_at_utc is not None
|
|
423
|
+
and resolved_edge_valid_at_utc is not None
|
|
424
|
+
and edge_invalid_at_utc <= resolved_edge_valid_at_utc
|
|
412
425
|
) or (
|
|
413
|
-
|
|
414
|
-
and
|
|
415
|
-
and
|
|
426
|
+
edge_valid_at_utc is not None
|
|
427
|
+
and resolved_edge_invalid_at_utc is not None
|
|
428
|
+
and resolved_edge_invalid_at_utc <= edge_valid_at_utc
|
|
416
429
|
):
|
|
417
430
|
continue
|
|
418
431
|
# New edge invalidates edge
|
|
419
432
|
elif (
|
|
420
|
-
|
|
421
|
-
and
|
|
422
|
-
and
|
|
433
|
+
edge_valid_at_utc is not None
|
|
434
|
+
and resolved_edge_valid_at_utc is not None
|
|
435
|
+
and edge_valid_at_utc < resolved_edge_valid_at_utc
|
|
423
436
|
):
|
|
424
437
|
edge.invalid_at = resolved_edge.valid_at
|
|
425
438
|
edge.expired_at = edge.expired_at if edge.expired_at is not None else utc_now()
|
|
@@ -520,6 +533,7 @@ async def resolve_extracted_edge(
|
|
|
520
533
|
prompt_library.dedupe_edges.resolve_edge(context),
|
|
521
534
|
response_model=EdgeDuplicate,
|
|
522
535
|
model_size=ModelSize.small,
|
|
536
|
+
prompt_name='dedupe_edges.resolve_edge',
|
|
523
537
|
)
|
|
524
538
|
response_object = EdgeDuplicate(**llm_response)
|
|
525
539
|
duplicate_facts = response_object.duplicate_facts
|
|
@@ -583,6 +597,7 @@ async def resolve_extracted_edge(
|
|
|
583
597
|
prompt_library.extract_edges.extract_attributes(edge_attributes_context),
|
|
584
598
|
response_model=edge_model, # type: ignore
|
|
585
599
|
model_size=ModelSize.small,
|
|
600
|
+
prompt_name='extract_edges.extract_attributes',
|
|
586
601
|
)
|
|
587
602
|
|
|
588
603
|
resolved_edge.attributes = edge_attributes_response
|
|
@@ -609,14 +624,14 @@ async def resolve_extracted_edge(
|
|
|
609
624
|
|
|
610
625
|
# Determine if the new_edge needs to be expired
|
|
611
626
|
if resolved_edge.expired_at is None:
|
|
612
|
-
invalidation_candidates.sort(key=lambda c: (c.valid_at is None, c.valid_at))
|
|
627
|
+
invalidation_candidates.sort(key=lambda c: (c.valid_at is None, ensure_utc(c.valid_at)))
|
|
613
628
|
for candidate in invalidation_candidates:
|
|
629
|
+
candidate_valid_at_utc = ensure_utc(candidate.valid_at)
|
|
630
|
+
resolved_edge_valid_at_utc = ensure_utc(resolved_edge.valid_at)
|
|
614
631
|
if (
|
|
615
|
-
|
|
616
|
-
and
|
|
617
|
-
and
|
|
618
|
-
and resolved_edge.valid_at.tzinfo
|
|
619
|
-
and candidate.valid_at > resolved_edge.valid_at
|
|
632
|
+
candidate_valid_at_utc is not None
|
|
633
|
+
and resolved_edge_valid_at_utc is not None
|
|
634
|
+
and candidate_valid_at_utc > resolved_edge_valid_at_utc
|
|
620
635
|
):
|
|
621
636
|
# Expire new edge since we have information about more recent events
|
|
622
637
|
resolved_edge.invalid_at = candidate.valid_at
|
|
@@ -34,30 +34,13 @@ logger = logging.getLogger(__name__)
|
|
|
34
34
|
|
|
35
35
|
|
|
36
36
|
async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
|
|
37
|
-
if driver.aoss_client:
|
|
38
|
-
await driver.create_aoss_indices() # pyright: ignore[reportAttributeAccessIssue]
|
|
39
|
-
return
|
|
40
37
|
if delete_existing:
|
|
41
|
-
|
|
42
|
-
"""
|
|
43
|
-
SHOW INDEXES YIELD name
|
|
44
|
-
""",
|
|
45
|
-
)
|
|
46
|
-
index_names = [record['name'] for record in records]
|
|
47
|
-
await semaphore_gather(
|
|
48
|
-
*[
|
|
49
|
-
driver.execute_query(
|
|
50
|
-
"""DROP INDEX $name""",
|
|
51
|
-
name=name,
|
|
52
|
-
)
|
|
53
|
-
for name in index_names
|
|
54
|
-
]
|
|
55
|
-
)
|
|
38
|
+
await driver.delete_all_indexes()
|
|
56
39
|
|
|
57
40
|
range_indices: list[LiteralString] = get_range_indices(driver.provider)
|
|
58
41
|
|
|
59
|
-
# Don't create fulltext indices if
|
|
60
|
-
if not driver.
|
|
42
|
+
# Don't create fulltext indices if search_interface is being used
|
|
43
|
+
if not driver.search_interface:
|
|
61
44
|
fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
|
|
62
45
|
|
|
63
46
|
if driver.provider == GraphProvider.KUZU:
|
|
@@ -95,8 +78,6 @@ async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
|
|
|
95
78
|
|
|
96
79
|
async def delete_all(tx):
|
|
97
80
|
await tx.run('MATCH (n) DETACH DELETE n')
|
|
98
|
-
if driver.aoss_client:
|
|
99
|
-
await driver.clear_aoss_indices()
|
|
100
81
|
|
|
101
82
|
async def delete_group_ids(tx):
|
|
102
83
|
labels = ['Entity', 'Episodic', 'Community']
|
|
@@ -153,9 +134,9 @@ async def retrieve_episodes(
|
|
|
153
134
|
|
|
154
135
|
query: LiteralString = (
|
|
155
136
|
"""
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
137
|
+
MATCH (e:Episodic)
|
|
138
|
+
WHERE e.valid_at <= $reference_time
|
|
139
|
+
"""
|
|
159
140
|
+ query_filter
|
|
160
141
|
+ """
|
|
161
142
|
RETURN
|
|
@@ -53,6 +53,7 @@ from graphiti_core.utils.maintenance.dedup_helpers import (
|
|
|
53
53
|
from graphiti_core.utils.maintenance.edge_operations import (
|
|
54
54
|
filter_existing_duplicate_of_edges,
|
|
55
55
|
)
|
|
56
|
+
from graphiti_core.utils.text_utils import MAX_SUMMARY_CHARS, truncate_at_sentence
|
|
56
57
|
|
|
57
58
|
logger = logging.getLogger(__name__)
|
|
58
59
|
|
|
@@ -64,6 +65,7 @@ async def extract_nodes_reflexion(
|
|
|
64
65
|
episode: EpisodicNode,
|
|
65
66
|
previous_episodes: list[EpisodicNode],
|
|
66
67
|
node_names: list[str],
|
|
68
|
+
group_id: str | None = None,
|
|
67
69
|
) -> list[str]:
|
|
68
70
|
# Prepare context for LLM
|
|
69
71
|
context = {
|
|
@@ -73,7 +75,10 @@ async def extract_nodes_reflexion(
|
|
|
73
75
|
}
|
|
74
76
|
|
|
75
77
|
llm_response = await llm_client.generate_response(
|
|
76
|
-
prompt_library.extract_nodes.reflexion(context),
|
|
78
|
+
prompt_library.extract_nodes.reflexion(context),
|
|
79
|
+
MissedEntities,
|
|
80
|
+
group_id=group_id,
|
|
81
|
+
prompt_name='extract_nodes.reflexion',
|
|
77
82
|
)
|
|
78
83
|
missed_entities = llm_response.get('missed_entities', [])
|
|
79
84
|
|
|
@@ -129,16 +134,22 @@ async def extract_nodes(
|
|
|
129
134
|
llm_response = await llm_client.generate_response(
|
|
130
135
|
prompt_library.extract_nodes.extract_message(context),
|
|
131
136
|
response_model=ExtractedEntities,
|
|
137
|
+
group_id=episode.group_id,
|
|
138
|
+
prompt_name='extract_nodes.extract_message',
|
|
132
139
|
)
|
|
133
140
|
elif episode.source == EpisodeType.text:
|
|
134
141
|
llm_response = await llm_client.generate_response(
|
|
135
142
|
prompt_library.extract_nodes.extract_text(context),
|
|
136
143
|
response_model=ExtractedEntities,
|
|
144
|
+
group_id=episode.group_id,
|
|
145
|
+
prompt_name='extract_nodes.extract_text',
|
|
137
146
|
)
|
|
138
147
|
elif episode.source == EpisodeType.json:
|
|
139
148
|
llm_response = await llm_client.generate_response(
|
|
140
149
|
prompt_library.extract_nodes.extract_json(context),
|
|
141
150
|
response_model=ExtractedEntities,
|
|
151
|
+
group_id=episode.group_id,
|
|
152
|
+
prompt_name='extract_nodes.extract_json',
|
|
142
153
|
)
|
|
143
154
|
|
|
144
155
|
response_object = ExtractedEntities(**llm_response)
|
|
@@ -152,6 +163,7 @@ async def extract_nodes(
|
|
|
152
163
|
episode,
|
|
153
164
|
previous_episodes,
|
|
154
165
|
[entity.name for entity in extracted_entities],
|
|
166
|
+
episode.group_id,
|
|
155
167
|
)
|
|
156
168
|
|
|
157
169
|
entities_missed = len(missing_entities) != 0
|
|
@@ -192,6 +204,7 @@ async def extract_nodes(
|
|
|
192
204
|
logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
|
|
193
205
|
|
|
194
206
|
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
|
|
207
|
+
|
|
195
208
|
return extracted_nodes
|
|
196
209
|
|
|
197
210
|
|
|
@@ -309,6 +322,7 @@ async def _resolve_with_llm(
|
|
|
309
322
|
llm_response = await llm_client.generate_response(
|
|
310
323
|
prompt_library.dedupe_nodes.nodes(context),
|
|
311
324
|
response_model=NodeResolutions,
|
|
325
|
+
prompt_name='dedupe_nodes.nodes',
|
|
312
326
|
)
|
|
313
327
|
|
|
314
328
|
node_resolutions: list[NodeDuplicate] = NodeResolutions(**llm_response).entity_resolutions
|
|
@@ -477,63 +491,97 @@ async def extract_attributes_from_node(
|
|
|
477
491
|
entity_type: type[BaseModel] | None = None,
|
|
478
492
|
should_summarize_node: NodeSummaryFilter | None = None,
|
|
479
493
|
) -> EntityNode:
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
'attributes': node.attributes,
|
|
485
|
-
}
|
|
494
|
+
# Extract attributes if entity type is defined and has attributes
|
|
495
|
+
llm_response = await _extract_entity_attributes(
|
|
496
|
+
llm_client, node, episode, previous_episodes, entity_type
|
|
497
|
+
)
|
|
486
498
|
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
499
|
+
# Extract summary if needed
|
|
500
|
+
await _extract_entity_summary(
|
|
501
|
+
llm_client, node, episode, previous_episodes, should_summarize_node
|
|
502
|
+
)
|
|
503
|
+
|
|
504
|
+
node.attributes.update(llm_response)
|
|
505
|
+
|
|
506
|
+
return node
|
|
494
507
|
|
|
495
|
-
summary_context: dict[str, Any] = {
|
|
496
|
-
'node': node_context,
|
|
497
|
-
'episode_content': episode.content if episode is not None else '',
|
|
498
|
-
'previous_episodes': (
|
|
499
|
-
[ep.content for ep in previous_episodes] if previous_episodes is not None else []
|
|
500
|
-
),
|
|
501
|
-
}
|
|
502
508
|
|
|
503
|
-
|
|
504
|
-
|
|
509
|
+
async def _extract_entity_attributes(
|
|
510
|
+
llm_client: LLMClient,
|
|
511
|
+
node: EntityNode,
|
|
512
|
+
episode: EpisodicNode | None,
|
|
513
|
+
previous_episodes: list[EpisodicNode] | None,
|
|
514
|
+
entity_type: type[BaseModel] | None,
|
|
515
|
+
) -> dict[str, Any]:
|
|
516
|
+
if entity_type is None or len(entity_type.model_fields) == 0:
|
|
517
|
+
return {}
|
|
518
|
+
|
|
519
|
+
attributes_context = _build_episode_context(
|
|
520
|
+
# should not include summary
|
|
521
|
+
node_data={
|
|
522
|
+
'name': node.name,
|
|
523
|
+
'entity_types': node.labels,
|
|
524
|
+
'attributes': node.attributes,
|
|
525
|
+
},
|
|
526
|
+
episode=episode,
|
|
527
|
+
previous_episodes=previous_episodes,
|
|
505
528
|
)
|
|
506
529
|
|
|
507
|
-
llm_response = (
|
|
508
|
-
(
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
)
|
|
514
|
-
)
|
|
515
|
-
if has_entity_attributes
|
|
516
|
-
else {}
|
|
530
|
+
llm_response = await llm_client.generate_response(
|
|
531
|
+
prompt_library.extract_nodes.extract_attributes(attributes_context),
|
|
532
|
+
response_model=entity_type,
|
|
533
|
+
model_size=ModelSize.small,
|
|
534
|
+
group_id=node.group_id,
|
|
535
|
+
prompt_name='extract_nodes.extract_attributes',
|
|
517
536
|
)
|
|
518
537
|
|
|
519
|
-
#
|
|
520
|
-
|
|
521
|
-
if should_summarize_node is not None:
|
|
522
|
-
generate_summary = await should_summarize_node(node)
|
|
523
|
-
|
|
524
|
-
# Conditionally generate summary
|
|
525
|
-
if generate_summary:
|
|
526
|
-
summary_response = await llm_client.generate_response(
|
|
527
|
-
prompt_library.extract_nodes.extract_summary(summary_context),
|
|
528
|
-
response_model=EntitySummary,
|
|
529
|
-
model_size=ModelSize.small,
|
|
530
|
-
)
|
|
531
|
-
node.summary = summary_response.get('summary', '')
|
|
538
|
+
# validate response
|
|
539
|
+
entity_type(**llm_response)
|
|
532
540
|
|
|
533
|
-
|
|
534
|
-
entity_type(**llm_response)
|
|
535
|
-
node_attributes = {key: value for key, value in llm_response.items()}
|
|
541
|
+
return llm_response
|
|
536
542
|
|
|
537
|
-
node.attributes.update(node_attributes)
|
|
538
543
|
|
|
539
|
-
|
|
544
|
+
async def _extract_entity_summary(
    llm_client: LLMClient,
    node: EntityNode,
    episode: EpisodicNode | None,
    previous_episodes: list[EpisodicNode] | None,
    should_summarize_node: NodeSummaryFilter | None,
) -> None:
    """
    Regenerate ``node.summary`` in place from the episode context.

    When ``should_summarize_node`` is given and resolves to a falsy value for
    ``node``, the node is left untouched. Otherwise the LLM is asked for a new
    summary. Both the existing summary sent as context and the returned summary
    are truncated at a sentence boundary to MAX_SUMMARY_CHARS.
    """
    if should_summarize_node is not None:
        wants_summary = await should_summarize_node(node)
        if not wants_summary:
            return

    node_payload = {
        'name': node.name,
        'summary': truncate_at_sentence(node.summary, MAX_SUMMARY_CHARS),
        'entity_types': node.labels,
        'attributes': node.attributes,
    }
    prompt_context = _build_episode_context(
        node_data=node_payload,
        episode=episode,
        previous_episodes=previous_episodes,
    )

    response = await llm_client.generate_response(
        prompt_library.extract_nodes.extract_summary(prompt_context),
        response_model=EntitySummary,
        model_size=ModelSize.small,
        group_id=node.group_id,
        prompt_name='extract_nodes.extract_summary',
    )

    raw_summary = response.get('summary', '')
    node.summary = truncate_at_sentence(raw_summary, MAX_SUMMARY_CHARS)
|
|
574
|
+
|
|
575
|
+
|
|
576
|
+
def _build_episode_context(
    node_data: dict[str, Any],
    episode: EpisodicNode | None,
    previous_episodes: list[EpisodicNode] | None,
) -> dict[str, Any]:
    """
    Assemble the prompt context shared by the attribute and summary extractors.

    Missing inputs degrade gracefully: no episode yields an empty
    ``episode_content`` string, and no previous episodes yields an empty list.
    """
    episode_content = '' if episode is None else episode.content
    prior_contents = [] if previous_episodes is None else [ep.content for ep in previous_episodes]
    return {
        'node': node_data,
        'episode_content': episode_content,
        'previous_episodes': prior_contents,
    }
|