graphiti-core 0.17.4__py3-none-any.whl → 0.25.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphiti_core/cross_encoder/gemini_reranker_client.py +1 -1
- graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
- graphiti_core/decorators.py +110 -0
- graphiti_core/driver/driver.py +62 -2
- graphiti_core/driver/falkordb_driver.py +215 -23
- graphiti_core/driver/graph_operations/graph_operations.py +191 -0
- graphiti_core/driver/kuzu_driver.py +182 -0
- graphiti_core/driver/neo4j_driver.py +70 -8
- graphiti_core/driver/neptune_driver.py +305 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +264 -132
- graphiti_core/embedder/azure_openai.py +10 -3
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/graph_queries.py +114 -101
- graphiti_core/graphiti.py +635 -260
- graphiti_core/graphiti_types.py +2 -0
- graphiti_core/helpers.py +37 -15
- graphiti_core/llm_client/anthropic_client.py +142 -52
- graphiti_core/llm_client/azure_openai_client.py +57 -19
- graphiti_core/llm_client/client.py +83 -21
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/gemini_client.py +75 -57
- graphiti_core/llm_client/openai_base_client.py +92 -48
- graphiti_core/llm_client/openai_client.py +39 -9
- graphiti_core/llm_client/openai_generic_client.py +91 -56
- graphiti_core/models/edges/edge_db_queries.py +259 -35
- graphiti_core/models/nodes/node_db_queries.py +311 -32
- graphiti_core/nodes.py +388 -164
- graphiti_core/prompts/dedupe_edges.py +42 -31
- graphiti_core/prompts/dedupe_nodes.py +56 -39
- graphiti_core/prompts/eval.py +4 -4
- graphiti_core/prompts/extract_edges.py +24 -15
- graphiti_core/prompts/extract_nodes.py +76 -35
- graphiti_core/prompts/prompt_helpers.py +39 -0
- graphiti_core/prompts/snippets.py +29 -0
- graphiti_core/prompts/summarize_nodes.py +23 -25
- graphiti_core/search/search.py +154 -74
- graphiti_core/search/search_config.py +39 -4
- graphiti_core/search/search_filters.py +110 -31
- graphiti_core/search/search_helpers.py +5 -6
- graphiti_core/search/search_utils.py +1360 -473
- graphiti_core/tracer.py +193 -0
- graphiti_core/utils/bulk_utils.py +216 -90
- graphiti_core/utils/content_chunking.py +702 -0
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +62 -38
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +306 -156
- graphiti_core/utils/maintenance/graph_data_operations.py +44 -74
- graphiti_core/utils/maintenance/node_operations.py +466 -206
- graphiti_core/utils/maintenance/temporal_operations.py +11 -3
- graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
- graphiti_core/utils/text_utils.py +53 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/METADATA +221 -87
- graphiti_core-0.25.3.dist-info/RECORD +87 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/WHEEL +1 -1
- graphiti_core-0.17.4.dist-info/RECORD +0 -77
- /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/licenses/LICENSE +0 -0
graphiti_core/tracer.py
ADDED
@@ -0,0 +1,193 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from abc import ABC, abstractmethod
from collections.abc import Generator
from contextlib import AbstractContextManager, contextmanager, suppress
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from opentelemetry.trace import Span, StatusCode

try:
    from opentelemetry.trace import Span, StatusCode

    OTEL_AVAILABLE = True
except ImportError:
    OTEL_AVAILABLE = False


class TracerSpan(ABC):
    """Abstract base class for tracer spans."""

    @abstractmethod
    def add_attributes(self, attributes: dict[str, Any]) -> None:
        """Add attributes to the span."""
        pass

    @abstractmethod
    def set_status(self, status: str, description: str | None = None) -> None:
        """Set the status of the span."""
        pass

    @abstractmethod
    def record_exception(self, exception: Exception) -> None:
        """Record an exception in the span."""
        pass


class Tracer(ABC):
    """Abstract base class for tracers."""

    @abstractmethod
    def start_span(self, name: str) -> AbstractContextManager[TracerSpan]:
        """Start a new span with the given name."""
        pass


class NoOpSpan(TracerSpan):
    """No-op span implementation that does nothing."""

    def add_attributes(self, attributes: dict[str, Any]) -> None:
        pass

    def set_status(self, status: str, description: str | None = None) -> None:
        pass

    def record_exception(self, exception: Exception) -> None:
        pass


class NoOpTracer(Tracer):
    """No-op tracer implementation that does nothing."""

    @contextmanager
    def start_span(self, name: str) -> Generator[NoOpSpan, None, None]:
        """Return a no-op span."""
        yield NoOpSpan()


class OpenTelemetrySpan(TracerSpan):
    """Wrapper for OpenTelemetry span."""

    def __init__(self, span: 'Span'):
        self._span = span

    def add_attributes(self, attributes: dict[str, Any]) -> None:
        """Add attributes to the OpenTelemetry span."""
        try:
            # Filter out None values and convert all values to appropriate types
            filtered_attrs = {}
            for key, value in attributes.items():
                if value is not None:
                    # Convert to string if not a primitive type
                    if isinstance(value, str | int | float | bool):
                        filtered_attrs[key] = value
                    else:
                        filtered_attrs[key] = str(value)

            if filtered_attrs:
                self._span.set_attributes(filtered_attrs)
        except Exception:
            # Silently ignore tracing errors
            pass

    def set_status(self, status: str, description: str | None = None) -> None:
        """Set the status of the OpenTelemetry span."""
        try:
            if OTEL_AVAILABLE:
                if status == 'error':
                    self._span.set_status(StatusCode.ERROR, description)
                elif status == 'ok':
                    self._span.set_status(StatusCode.OK, description)
        except Exception:
            # Silently ignore tracing errors
            pass

    def record_exception(self, exception: Exception) -> None:
        """Record an exception in the OpenTelemetry span."""
        with suppress(Exception):
            self._span.record_exception(exception)


class OpenTelemetryTracer(Tracer):
    """Wrapper for OpenTelemetry tracer with configurable span name prefix."""

    def __init__(self, tracer: Any, span_prefix: str = 'graphiti'):
        """
        Initialize the OpenTelemetry tracer wrapper.

        Parameters
        ----------
        tracer : opentelemetry.trace.Tracer
            The OpenTelemetry tracer instance.
        span_prefix : str, optional
            Prefix to prepend to all span names. Defaults to 'graphiti'.
        """
        if not OTEL_AVAILABLE:
            raise ImportError(
                'OpenTelemetry is not installed. Install it with: pip install opentelemetry-api'
            )
        self._tracer = tracer
        self._span_prefix = span_prefix.rstrip('.')

    @contextmanager
    def start_span(self, name: str) -> Generator[OpenTelemetrySpan | NoOpSpan, None, None]:
        """Start a new OpenTelemetry span with the configured prefix."""
        try:
            full_name = f'{self._span_prefix}.{name}'
            with self._tracer.start_as_current_span(full_name) as span:
                yield OpenTelemetrySpan(span)
        except Exception:
            # If tracing fails, yield a no-op span to prevent breaking the operation
            yield NoOpSpan()


def create_tracer(otel_tracer: Any | None = None, span_prefix: str = 'graphiti') -> Tracer:
    """
    Create a tracer instance.

    Parameters
    ----------
    otel_tracer : opentelemetry.trace.Tracer | None, optional
        An OpenTelemetry tracer instance. If None, a no-op tracer is returned.
    span_prefix : str, optional
        Prefix to prepend to all span names. Defaults to 'graphiti'.

    Returns
    -------
    Tracer
        A tracer instance (either OpenTelemetryTracer or NoOpTracer).

    Examples
    --------
    Using with OpenTelemetry:

    >>> from opentelemetry import trace
    >>> otel_tracer = trace.get_tracer(__name__)
    >>> tracer = create_tracer(otel_tracer, span_prefix='myapp.graphiti')

    Using no-op tracer:

    >>> tracer = create_tracer()  # Returns NoOpTracer
    """
    if otel_tracer is None:
        return NoOpTracer()

    if not OTEL_AVAILABLE:
        return NoOpTracer()

    return OpenTelemetryTracer(otel_tracer, span_prefix)
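
The new module above only defines the tracing abstraction; the snippet below is a minimal usage sketch (not part of the diff) showing how the API composes. It uses only names defined in tracer.py; the span name and attribute keys are illustrative, and the NoOpTracer path needs no OpenTelemetry install.

from graphiti_core.tracer import create_tracer

tracer = create_tracer()  # no OpenTelemetry tracer passed, so this returns a NoOpTracer

with tracer.start_span('add_episode') as span:
    span.add_attributes({'group_id': 'demo', 'episode_count': 3})
    try:
        pass  # traced work would run here
    except Exception as exc:
        span.record_exception(exc)
        span.set_status('error', str(exc))
        raise
    else:
        span.set_status('ok')
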
graphiti_core/utils/bulk_utils.py
CHANGED
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+import json
 import logging
 import typing
 from datetime import datetime
@@ -22,22 +23,31 @@ import numpy as np
 from pydantic import BaseModel, Field
 from typing_extensions import Any

-from graphiti_core.driver.driver import
+from graphiti_core.driver.driver import (
+    GraphDriver,
+    GraphDriverSession,
+    GraphProvider,
+)
 from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
 from graphiti_core.embedder import EmbedderClient
-from graphiti_core.graph_queries import (
-    get_entity_edge_save_bulk_query,
-    get_entity_node_save_bulk_query,
-)
 from graphiti_core.graphiti_types import GraphitiClients
 from graphiti_core.helpers import normalize_l2, semaphore_gather
 from graphiti_core.models.edges.edge_db_queries import (
-
+    get_entity_edge_save_bulk_query,
+    get_episodic_edge_save_bulk_query,
 )
 from graphiti_core.models.nodes.node_db_queries import (
-
+    get_entity_node_save_bulk_query,
+    get_episode_node_save_bulk_query,
+)
+from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
+from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
+from graphiti_core.utils.maintenance.dedup_helpers import (
+    DedupResolutionState,
+    _build_candidate_indexes,
+    _normalize_string_exact,
+    _resolve_with_similarity,
 )
-from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
 from graphiti_core.utils.maintenance.edge_operations import (
     extract_edges,
     resolve_extracted_edge,
@@ -56,6 +66,38 @@ logger = logging.getLogger(__name__)
 CHUNK_SIZE = 10


+def _build_directed_uuid_map(pairs: list[tuple[str, str]]) -> dict[str, str]:
+    """Collapse alias -> canonical chains while preserving direction.
+
+    The incoming pairs represent directed mappings discovered during node dedupe. We use a simple
+    union-find with iterative path compression to ensure every source UUID resolves to its ultimate
+    canonical target, even if aliases appear lexicographically smaller than the canonical UUID.
+    """
+
+    parent: dict[str, str] = {}
+
+    def find(uuid: str) -> str:
+        """Directed union-find lookup using iterative path compression."""
+        parent.setdefault(uuid, uuid)
+        root = uuid
+        while parent[root] != root:
+            root = parent[root]
+
+        while parent[uuid] != root:
+            next_uuid = parent[uuid]
+            parent[uuid] = root
+            uuid = next_uuid
+
+        return root
+
+    for source_uuid, target_uuid in pairs:
+        parent.setdefault(source_uuid, source_uuid)
+        parent.setdefault(target_uuid, target_uuid)
+        parent[find(source_uuid)] = find(target_uuid)
+
+    return {uuid: find(uuid) for uuid in parent}
+
+
 class RawEpisode(BaseModel):
     name: str
     uuid: str | None = Field(default=None)
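
A quick worked example (not from the diff) of what the helper added above returns; the UUID strings are toy values. Here 'b' was merged into 'a' and 'c' into 'b', so every alias compresses to the final canonical 'a'.

from graphiti_core.utils.bulk_utils import _build_directed_uuid_map

pairs = [('b', 'a'), ('c', 'b')]  # directed alias -> canonical mappings
uuid_map = _build_directed_uuid_map(pairs)
assert uuid_map == {'a': 'a', 'b': 'a', 'c': 'a'}  # chains are collapsed to the root
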
@@ -118,24 +160,33 @@ async def add_nodes_and_edges_bulk_tx(
     episodes = [dict(episode) for episode in episodic_nodes]
     for episode in episodes:
         episode['source'] = str(episode['source'].value)
-
+        episode.pop('labels', None)
+
+    nodes = []
+
     for node in entity_nodes:
         if node.name_embedding is None:
            await node.generate_name_embedding(embedder)
+
         entity_data: dict[str, Any] = {
             'uuid': node.uuid,
             'name': node.name,
-            'name_embedding': node.name_embedding,
             'group_id': node.group_id,
             'summary': node.summary,
             'created_at': node.created_at,
+            'name_embedding': node.name_embedding,
+            'labels': list(set(node.labels + ['Entity'])),
         }

-
-
+        if driver.provider == GraphProvider.KUZU:
+            attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
+            entity_data['attributes'] = json.dumps(attributes)
+        else:
+            entity_data.update(node.attributes or {})
+
         nodes.append(entity_data)

-    edges
+    edges = []
     for edge in entity_edges:
         if edge.fact_embedding is None:
             await edge.generate_embedding(embedder)
@@ -145,35 +196,68 @@ async def add_nodes_and_edges_bulk_tx(
             'target_node_uuid': edge.target_node_uuid,
             'name': edge.name,
             'fact': edge.fact,
-            'fact_embedding': edge.fact_embedding,
             'group_id': edge.group_id,
             'episodes': edge.episodes,
             'created_at': edge.created_at,
             'expired_at': edge.expired_at,
             'valid_at': edge.valid_at,
             'invalid_at': edge.invalid_at,
+            'fact_embedding': edge.fact_embedding,
         }

-
+        if driver.provider == GraphProvider.KUZU:
+            attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
+            edge_data['attributes'] = json.dumps(attributes)
+        else:
+            edge_data.update(edge.attributes or {})
+
         edges.append(edge_data)

-
-
-
-
-
-
-
-
+    if driver.graph_operations_interface:
+        await driver.graph_operations_interface.episodic_node_save_bulk(None, driver, tx, episodes)
+        await driver.graph_operations_interface.node_save_bulk(None, driver, tx, nodes)
+        await driver.graph_operations_interface.episodic_edge_save_bulk(
+            None, driver, tx, [edge.model_dump() for edge in episodic_edges]
+        )
+        await driver.graph_operations_interface.edge_save_bulk(None, driver, tx, edges)
+
+    elif driver.provider == GraphProvider.KUZU:
+        # FIXME: Kuzu's UNWIND does not currently support STRUCT[] type properly, so we insert the data one by one instead for now.
+        episode_query = get_episode_node_save_bulk_query(driver.provider)
+        for episode in episodes:
+            await tx.run(episode_query, **episode)
+        entity_node_query = get_entity_node_save_bulk_query(driver.provider, nodes)
+        for node in nodes:
+            await tx.run(entity_node_query, **node)
+        entity_edge_query = get_entity_edge_save_bulk_query(driver.provider)
+        for edge in edges:
+            await tx.run(entity_edge_query, **edge)
+        episodic_edge_query = get_episodic_edge_save_bulk_query(driver.provider)
+        for edge in episodic_edges:
+            await tx.run(episodic_edge_query, **edge.model_dump())
+    else:
+        await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
+        await tx.run(
+            get_entity_node_save_bulk_query(driver.provider, nodes),
+            nodes=nodes,
+        )
+        await tx.run(
+            get_episodic_edge_save_bulk_query(driver.provider),
+            episodic_edges=[edge.model_dump() for edge in episodic_edges],
+        )
+        await tx.run(
+            get_entity_edge_save_bulk_query(driver.provider),
+            entity_edges=edges,
+        )


 async def extract_nodes_and_edges_bulk(
     clients: GraphitiClients,
     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
     edge_type_map: dict[tuple[str, str], list[str]],
-    entity_types: dict[str, BaseModel] | None = None,
+    entity_types: dict[str, type[BaseModel]] | None = None,
     excluded_entity_types: list[str] | None = None,
-    edge_types: dict[str, BaseModel] | None = None,
+    edge_types: dict[str, type[BaseModel]] | None = None,
 ) -> tuple[list[list[EntityNode]], list[list[EntityEdge]]]:
     extracted_nodes_bulk: list[list[EntityNode]] = await semaphore_gather(
         *[
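
The provider split in this hunk stores custom attributes differently: Kuzu gets a single JSON-encoded attributes column, while other providers get the attributes flattened onto the node or edge properties. The sketch below (not from the diff) mirrors that branch with toy data; json.dumps(..., default=str) stands in for convert_datetimes_to_strings, whose exact behavior is assumed here.

import json
from datetime import datetime, timezone

attributes = {'role': 'engineer', 'hired_at': datetime(2024, 1, 5, tzinfo=timezone.utc)}
entity_data = {'uuid': 'node-1', 'name': 'Alice'}

provider = 'kuzu'  # stand-in for GraphProvider.KUZU
if provider == 'kuzu':
    # one JSON string column; datetimes are coerced to strings first
    entity_data['attributes'] = json.dumps(attributes, default=str)
else:
    # flatten attributes directly onto the node properties
    entity_data.update(attributes)
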
@@ -204,85 +288,113 @@ async def dedupe_nodes_bulk(
     clients: GraphitiClients,
     extracted_nodes: list[list[EntityNode]],
     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
-    entity_types: dict[str, BaseModel] | None = None,
+    entity_types: dict[str, type[BaseModel]] | None = None,
 ) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
-
-    min_score = 0.8
-
-    # generate embeddings
-    await semaphore_gather(
-        *[create_entity_node_embeddings(embedder, nodes) for nodes in extracted_nodes]
-    )
+    """Resolve entity duplicates across an in-memory batch using a two-pass strategy.

-
-
-
-
-
-
-                continue
-            existing_nodes += nodes_j
-
-        candidates_i: list[EntityNode] = []
-        for node in nodes_i:
-            for existing_node in existing_nodes:
-                # Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
-                # This approach will cast a wider net than BM25, which is ideal for this use case
-                node_words = set(node.name.lower().split())
-                existing_node_words = set(existing_node.name.lower().split())
-                has_overlap = not node_words.isdisjoint(existing_node_words)
-                if has_overlap:
-                    candidates_i.append(existing_node)
-                    continue
-
-                # Check for semantic similarity even if there is no overlap
-                similarity = np.dot(
-                    normalize_l2(node.name_embedding or []),
-                    normalize_l2(existing_node.name_embedding or []),
-                )
-                if similarity >= min_score:
-                    candidates_i.append(existing_node)
-
-        dedupe_tuples.append((nodes_i, candidates_i))
+    1. Run :func:`resolve_extracted_nodes` for every episode in parallel so each batch item is
+       reconciled against the live graph just like the non-batch flow.
+    2. Re-run the deterministic similarity heuristics across the union of resolved nodes to catch
+       duplicates that only co-occur inside this batch, emitting a canonical UUID map that callers
+       can apply to edges and persistence.
+    """

-
-    bulk_node_resolutions: list[
-        tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]
-    ] = await semaphore_gather(
+    first_pass_results = await semaphore_gather(
         *[
             resolve_extracted_nodes(
                 clients,
-
+                nodes,
                 episode_tuples[i][0],
                 episode_tuples[i][1],
                 entity_types,
-                existing_nodes_override=dedupe_tuples[i][1],
             )
-            for i,
+            for i, nodes in enumerate(extracted_nodes)
         ]
     )

-
+    episode_resolutions: list[tuple[str, list[EntityNode]]] = []
+    per_episode_uuid_maps: list[dict[str, str]] = []
     duplicate_pairs: list[tuple[str, str]] = []
-    for _, _, duplicates in bulk_node_resolutions:
-        for duplicate in duplicates:
-            n, m = duplicate
-            duplicate_pairs.append((n.uuid, m.uuid))

-
-
+    for (resolved_nodes, uuid_map, duplicates), (episode, _) in zip(
+        first_pass_results, episode_tuples, strict=True
+    ):
+        episode_resolutions.append((episode.uuid, resolved_nodes))
+        per_episode_uuid_maps.append(uuid_map)
+        duplicate_pairs.extend((source.uuid, target.uuid) for source, target in duplicates)
+
+    canonical_nodes: dict[str, EntityNode] = {}
+    for _, resolved_nodes in episode_resolutions:
+        for node in resolved_nodes:
+            # NOTE: this loop is O(n^2) in the number of nodes inside the batch because we rebuild
+            # the MinHash index for the accumulated canonical pool each time. The LRU-backed
+            # shingle cache keeps the constant factors low for typical batch sizes (≤ CHUNK_SIZE),
+            # but if batches grow significantly we should switch to an incremental index or chunked
+            # processing.
+            if not canonical_nodes:
+                canonical_nodes[node.uuid] = node
+                continue

-
-
-
+            existing_candidates = list(canonical_nodes.values())
+            normalized = _normalize_string_exact(node.name)
+            exact_match = next(
+                (
+                    candidate
+                    for candidate in existing_candidates
+                    if _normalize_string_exact(candidate.name) == normalized
+                ),
+                None,
+            )
+            if exact_match is not None:
+                if exact_match.uuid != node.uuid:
+                    duplicate_pairs.append((node.uuid, exact_match.uuid))
+                continue
+
+            indexes = _build_candidate_indexes(existing_candidates)
+            state = DedupResolutionState(
+                resolved_nodes=[None],
+                uuid_map={},
+                unresolved_indices=[],
+            )
+            _resolve_with_similarity([node], indexes, state)
+
+            resolved = state.resolved_nodes[0]
+            if resolved is None:
+                canonical_nodes[node.uuid] = node
+                continue
+
+            canonical_uuid = resolved.uuid
+            canonical_nodes.setdefault(canonical_uuid, resolved)
+            if canonical_uuid != node.uuid:
+                duplicate_pairs.append((node.uuid, canonical_uuid))
+
+    union_pairs: list[tuple[str, str]] = []
+    for uuid_map in per_episode_uuid_maps:
+        union_pairs.extend(uuid_map.items())
+    union_pairs.extend(duplicate_pairs)
+
+    compressed_map: dict[str, str] = _build_directed_uuid_map(union_pairs)

     nodes_by_episode: dict[str, list[EntityNode]] = {}
-    for
-
+    for episode_uuid, resolved_nodes in episode_resolutions:
+        deduped_nodes: list[EntityNode] = []
+        seen: set[str] = set()
+        for node in resolved_nodes:
+            canonical_uuid = compressed_map.get(node.uuid, node.uuid)
+            if canonical_uuid in seen:
+                continue
+            seen.add(canonical_uuid)
+            canonical_node = canonical_nodes.get(canonical_uuid)
+            if canonical_node is None:
+                logger.error(
+                    'Canonical node %s missing during batch dedupe; falling back to %s',
+                    canonical_uuid,
+                    node.uuid,
+                )
+                canonical_node = node
+            deduped_nodes.append(canonical_node)

-        nodes_by_episode[
-            node_uuid_map[compressed_map.get(node.uuid, node.uuid)] for node in nodes
-        ]
+        nodes_by_episode[episode_uuid] = deduped_nodes

     return nodes_by_episode, compressed_map

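
To make the two-pass flow above concrete, here is a small sketch (not from the diff) of how a first-pass per-episode uuid map and a second-pass duplicate pair compose into one canonical map via _build_directed_uuid_map; the UUID values are toy stand-ins.

from graphiti_core.utils.bulk_utils import _build_directed_uuid_map

per_episode_uuid_maps = [{'tmp-1': 'graph-1'}]  # pass 1: tmp-1 matched an existing graph node
duplicate_pairs = [('tmp-2', 'graph-1')]        # pass 2: tmp-2 duplicates the same node

union_pairs: list[tuple[str, str]] = []
for uuid_map in per_episode_uuid_maps:
    union_pairs.extend(uuid_map.items())
union_pairs.extend(duplicate_pairs)

compressed_map = _build_directed_uuid_map(union_pairs)
assert compressed_map['tmp-1'] == 'graph-1'
assert compressed_map['tmp-2'] == 'graph-1'
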
@@ -292,7 +404,7 @@ async def dedupe_edges_bulk(
     extracted_edges: list[list[EntityEdge]],
     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
     _entities: list[EntityNode],
-    edge_types: dict[str, BaseModel],
+    edge_types: dict[str, type[BaseModel]],
     _edge_type_map: dict[tuple[str, str], list[str]],
 ) -> dict[str, list[EntityEdge]]:
     embedder = clients.embedder
@@ -307,16 +419,23 @@ async def dedupe_edges_bulk(
     dedupe_tuples: list[tuple[EpisodicNode, EntityEdge, list[EntityEdge]]] = []
     for i, edges_i in enumerate(extracted_edges):
         existing_edges: list[EntityEdge] = []
-        for
-            if i == j:
-                continue
+        for edges_j in extracted_edges:
             existing_edges += edges_j

         for edge in edges_i:
             candidates: list[EntityEdge] = []
             for existing_edge in existing_edges:
+                # Skip self-comparison
+                if edge.uuid == existing_edge.uuid:
+                    continue
                 # Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
                 # This approach will cast a wider net than BM25, which is ideal for this use case
+                if (
+                    edge.source_node_uuid != existing_edge.source_node_uuid
+                    or edge.target_node_uuid != existing_edge.target_node_uuid
+                ):
+                    continue
+
                 edge_words = set(edge.fact.lower().split())
                 existing_edge_words = set(existing_edge.fact.lower().split())
                 has_overlap = not edge_words.isdisjoint(existing_edge_words)
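
The candidate filter in this hunk now requires matching endpoints (and skips self-comparison) before the word-overlap check. A standalone sketch of that predicate (not from the diff), with made-up facts and endpoint UUIDs:

def is_dedupe_candidate(
    fact: str, endpoints: tuple[str, str], other_fact: str, other_endpoints: tuple[str, str]
) -> bool:
    # Only edges between the same pair of nodes can be duplicates of each other.
    if endpoints != other_endpoints:
        return False
    # Approximate BM25 with a simple word-overlap test on the two facts.
    return not set(fact.lower().split()).isdisjoint(set(other_fact.lower().split()))

assert is_dedupe_candidate('Alice works at Acme', ('u1', 'u2'), 'Alice is employed by Acme', ('u1', 'u2'))
assert not is_dedupe_candidate('Alice works at Acme', ('u1', 'u2'), 'Bob lives in Paris', ('u1', 'u3'))
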
@@ -339,12 +458,19 @@
     ] = await semaphore_gather(
         *[
             resolve_extracted_edge(
-                clients.llm_client,
+                clients.llm_client,
+                edge,
+                candidates,
+                candidates,
+                episode,
+                edge_types,
+                set(edge_types),
             )
             for episode, edge, candidates in dedupe_tuples
         ]
     )

+    # For now we won't track edge invalidation
     duplicate_pairs: list[tuple[str, str]] = []
     for i, (_, _, duplicates) in enumerate(bulk_edge_resolutions):
         episode, edge, candidates = dedupe_tuples[i]