graphiti-core 0.12.0rc1__py3-none-any.whl → 0.24.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphiti_core/cross_encoder/bge_reranker_client.py +12 -2
- graphiti_core/cross_encoder/gemini_reranker_client.py +161 -0
- graphiti_core/cross_encoder/openai_reranker_client.py +7 -5
- graphiti_core/decorators.py +110 -0
- graphiti_core/driver/__init__.py +19 -0
- graphiti_core/driver/driver.py +124 -0
- graphiti_core/driver/falkordb_driver.py +362 -0
- graphiti_core/driver/graph_operations/graph_operations.py +191 -0
- graphiti_core/driver/kuzu_driver.py +182 -0
- graphiti_core/driver/neo4j_driver.py +117 -0
- graphiti_core/driver/neptune_driver.py +305 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +287 -172
- graphiti_core/embedder/azure_openai.py +71 -0
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/embedder/gemini.py +116 -22
- graphiti_core/embedder/voyage.py +13 -2
- graphiti_core/errors.py +8 -0
- graphiti_core/graph_queries.py +162 -0
- graphiti_core/graphiti.py +705 -193
- graphiti_core/graphiti_types.py +4 -2
- graphiti_core/helpers.py +87 -10
- graphiti_core/llm_client/__init__.py +16 -0
- graphiti_core/llm_client/anthropic_client.py +159 -56
- graphiti_core/llm_client/azure_openai_client.py +115 -0
- graphiti_core/llm_client/client.py +98 -21
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/gemini_client.py +290 -41
- graphiti_core/llm_client/groq_client.py +14 -3
- graphiti_core/llm_client/openai_base_client.py +261 -0
- graphiti_core/llm_client/openai_client.py +56 -132
- graphiti_core/llm_client/openai_generic_client.py +91 -56
- graphiti_core/models/edges/edge_db_queries.py +259 -35
- graphiti_core/models/nodes/node_db_queries.py +311 -32
- graphiti_core/nodes.py +420 -205
- graphiti_core/prompts/dedupe_edges.py +46 -32
- graphiti_core/prompts/dedupe_nodes.py +67 -42
- graphiti_core/prompts/eval.py +4 -4
- graphiti_core/prompts/extract_edges.py +27 -16
- graphiti_core/prompts/extract_nodes.py +74 -31
- graphiti_core/prompts/prompt_helpers.py +39 -0
- graphiti_core/prompts/snippets.py +29 -0
- graphiti_core/prompts/summarize_nodes.py +23 -25
- graphiti_core/search/search.py +158 -82
- graphiti_core/search/search_config.py +39 -4
- graphiti_core/search/search_filters.py +126 -35
- graphiti_core/search/search_helpers.py +5 -6
- graphiti_core/search/search_utils.py +1405 -485
- graphiti_core/telemetry/__init__.py +9 -0
- graphiti_core/telemetry/telemetry.py +117 -0
- graphiti_core/tracer.py +193 -0
- graphiti_core/utils/bulk_utils.py +364 -285
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +67 -49
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +339 -197
- graphiti_core/utils/maintenance/graph_data_operations.py +50 -114
- graphiti_core/utils/maintenance/node_operations.py +319 -238
- graphiti_core/utils/maintenance/temporal_operations.py +11 -3
- graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
- graphiti_core/utils/text_utils.py +53 -0
- graphiti_core-0.24.3.dist-info/METADATA +726 -0
- graphiti_core-0.24.3.dist-info/RECORD +86 -0
- {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
- graphiti_core-0.12.0rc1.dist-info/METADATA +0 -350
- graphiti_core-0.12.0rc1.dist-info/RECORD +0 -66
- /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
- {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info/licenses}/LICENSE +0 -0
|
@@ -14,58 +14,93 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
+
import json
|
|
17
18
|
import logging
|
|
18
19
|
import typing
|
|
19
|
-
from collections import defaultdict
|
|
20
20
|
from datetime import datetime
|
|
21
|
-
from math import ceil
|
|
22
21
|
|
|
23
|
-
|
|
24
|
-
from
|
|
25
|
-
from pydantic import BaseModel
|
|
22
|
+
import numpy as np
|
|
23
|
+
from pydantic import BaseModel, Field
|
|
26
24
|
from typing_extensions import Any
|
|
27
25
|
|
|
28
|
-
from graphiti_core.
|
|
26
|
+
from graphiti_core.driver.driver import (
|
|
27
|
+
GraphDriver,
|
|
28
|
+
GraphDriverSession,
|
|
29
|
+
GraphProvider,
|
|
30
|
+
)
|
|
31
|
+
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
|
|
29
32
|
from graphiti_core.embedder import EmbedderClient
|
|
30
33
|
from graphiti_core.graphiti_types import GraphitiClients
|
|
31
|
-
from graphiti_core.helpers import
|
|
32
|
-
from graphiti_core.llm_client import LLMClient
|
|
34
|
+
from graphiti_core.helpers import normalize_l2, semaphore_gather
|
|
33
35
|
from graphiti_core.models.edges.edge_db_queries import (
|
|
34
|
-
|
|
35
|
-
|
|
36
|
+
get_entity_edge_save_bulk_query,
|
|
37
|
+
get_episodic_edge_save_bulk_query,
|
|
36
38
|
)
|
|
37
39
|
from graphiti_core.models.nodes.node_db_queries import (
|
|
38
|
-
|
|
39
|
-
|
|
40
|
+
get_entity_node_save_bulk_query,
|
|
41
|
+
get_episode_node_save_bulk_query,
|
|
40
42
|
)
|
|
41
43
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
|
42
|
-
from graphiti_core.
|
|
43
|
-
from graphiti_core.
|
|
44
|
-
|
|
44
|
+
from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
|
|
45
|
+
from graphiti_core.utils.maintenance.dedup_helpers import (
|
|
46
|
+
DedupResolutionState,
|
|
47
|
+
_build_candidate_indexes,
|
|
48
|
+
_normalize_string_exact,
|
|
49
|
+
_resolve_with_similarity,
|
|
50
|
+
)
|
|
45
51
|
from graphiti_core.utils.maintenance.edge_operations import (
|
|
46
|
-
build_episodic_edges,
|
|
47
|
-
dedupe_edge_list,
|
|
48
|
-
dedupe_extracted_edges,
|
|
49
52
|
extract_edges,
|
|
53
|
+
resolve_extracted_edge,
|
|
50
54
|
)
|
|
51
55
|
from graphiti_core.utils.maintenance.graph_data_operations import (
|
|
52
56
|
EPISODE_WINDOW_LEN,
|
|
53
57
|
retrieve_episodes,
|
|
54
58
|
)
|
|
55
59
|
from graphiti_core.utils.maintenance.node_operations import (
|
|
56
|
-
dedupe_extracted_nodes,
|
|
57
|
-
dedupe_node_list,
|
|
58
60
|
extract_nodes,
|
|
61
|
+
resolve_extracted_nodes,
|
|
59
62
|
)
|
|
60
|
-
from graphiti_core.utils.maintenance.temporal_operations import extract_edge_dates
|
|
61
63
|
|
|
62
64
|
logger = logging.getLogger(__name__)
|
|
63
65
|
|
|
64
66
|
CHUNK_SIZE = 10
|
|
65
67
|
|
|
66
68
|
|
|
69
|
+
def _build_directed_uuid_map(pairs: list[tuple[str, str]]) -> dict[str, str]:
|
|
70
|
+
"""Collapse alias -> canonical chains while preserving direction.
|
|
71
|
+
|
|
72
|
+
The incoming pairs represent directed mappings discovered during node dedupe. We use a simple
|
|
73
|
+
union-find with iterative path compression to ensure every source UUID resolves to its ultimate
|
|
74
|
+
canonical target, even if aliases appear lexicographically smaller than the canonical UUID.
|
|
75
|
+
"""
|
|
76
|
+
|
|
77
|
+
parent: dict[str, str] = {}
|
|
78
|
+
|
|
79
|
+
def find(uuid: str) -> str:
|
|
80
|
+
"""Directed union-find lookup using iterative path compression."""
|
|
81
|
+
parent.setdefault(uuid, uuid)
|
|
82
|
+
root = uuid
|
|
83
|
+
while parent[root] != root:
|
|
84
|
+
root = parent[root]
|
|
85
|
+
|
|
86
|
+
while parent[uuid] != root:
|
|
87
|
+
next_uuid = parent[uuid]
|
|
88
|
+
parent[uuid] = root
|
|
89
|
+
uuid = next_uuid
|
|
90
|
+
|
|
91
|
+
return root
|
|
92
|
+
|
|
93
|
+
for source_uuid, target_uuid in pairs:
|
|
94
|
+
parent.setdefault(source_uuid, source_uuid)
|
|
95
|
+
parent.setdefault(target_uuid, target_uuid)
|
|
96
|
+
parent[find(source_uuid)] = find(target_uuid)
|
|
97
|
+
|
|
98
|
+
return {uuid: find(uuid) for uuid in parent}
|
|
99
|
+
|
|
100
|
+
|
|
67
101
|
class RawEpisode(BaseModel):
|
|
68
102
|
name: str
|
|
103
|
+
uuid: str | None = Field(default=None)
|
|
69
104
|
content: str
|
|
70
105
|
source_description: str
|
|
71
106
|
source: EpisodeType
|
|
@@ -73,7 +108,7 @@ class RawEpisode(BaseModel):
|
|
|
73
108
|
|
|
74
109
|
|
|
75
110
|
async def retrieve_previous_episodes_bulk(
|
|
76
|
-
driver:
|
|
111
|
+
driver: GraphDriver, episodes: list[EpisodicNode]
|
|
77
112
|
) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
|
|
78
113
|
previous_episodes_list = await semaphore_gather(
|
|
79
114
|
*[
|
|
@@ -91,14 +126,15 @@ async def retrieve_previous_episodes_bulk(
|
|
|
91
126
|
|
|
92
127
|
|
|
93
128
|
async def add_nodes_and_edges_bulk(
|
|
94
|
-
driver:
|
|
129
|
+
driver: GraphDriver,
|
|
95
130
|
episodic_nodes: list[EpisodicNode],
|
|
96
131
|
episodic_edges: list[EpisodicEdge],
|
|
97
132
|
entity_nodes: list[EntityNode],
|
|
98
133
|
entity_edges: list[EntityEdge],
|
|
99
134
|
embedder: EmbedderClient,
|
|
100
135
|
):
|
|
101
|
-
|
|
136
|
+
session = driver.session()
|
|
137
|
+
try:
|
|
102
138
|
await session.execute_write(
|
|
103
139
|
add_nodes_and_edges_bulk_tx,
|
|
104
140
|
episodic_nodes,
|
|
@@ -106,38 +142,51 @@ async def add_nodes_and_edges_bulk(
|
|
|
106
142
|
entity_nodes,
|
|
107
143
|
entity_edges,
|
|
108
144
|
embedder,
|
|
145
|
+
driver=driver,
|
|
109
146
|
)
|
|
147
|
+
finally:
|
|
148
|
+
await session.close()
|
|
110
149
|
|
|
111
150
|
|
|
112
151
|
async def add_nodes_and_edges_bulk_tx(
|
|
113
|
-
tx:
|
|
152
|
+
tx: GraphDriverSession,
|
|
114
153
|
episodic_nodes: list[EpisodicNode],
|
|
115
154
|
episodic_edges: list[EpisodicEdge],
|
|
116
155
|
entity_nodes: list[EntityNode],
|
|
117
156
|
entity_edges: list[EntityEdge],
|
|
118
157
|
embedder: EmbedderClient,
|
|
158
|
+
driver: GraphDriver,
|
|
119
159
|
):
|
|
120
160
|
episodes = [dict(episode) for episode in episodic_nodes]
|
|
121
161
|
for episode in episodes:
|
|
122
162
|
episode['source'] = str(episode['source'].value)
|
|
123
|
-
|
|
163
|
+
episode.pop('labels', None)
|
|
164
|
+
|
|
165
|
+
nodes = []
|
|
166
|
+
|
|
124
167
|
for node in entity_nodes:
|
|
125
168
|
if node.name_embedding is None:
|
|
126
169
|
await node.generate_name_embedding(embedder)
|
|
170
|
+
|
|
127
171
|
entity_data: dict[str, Any] = {
|
|
128
172
|
'uuid': node.uuid,
|
|
129
173
|
'name': node.name,
|
|
130
|
-
'name_embedding': node.name_embedding,
|
|
131
174
|
'group_id': node.group_id,
|
|
132
175
|
'summary': node.summary,
|
|
133
176
|
'created_at': node.created_at,
|
|
177
|
+
'name_embedding': node.name_embedding,
|
|
178
|
+
'labels': list(set(node.labels + ['Entity'])),
|
|
134
179
|
}
|
|
135
180
|
|
|
136
|
-
|
|
137
|
-
|
|
181
|
+
if driver.provider == GraphProvider.KUZU:
|
|
182
|
+
attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
|
|
183
|
+
entity_data['attributes'] = json.dumps(attributes)
|
|
184
|
+
else:
|
|
185
|
+
entity_data.update(node.attributes or {})
|
|
186
|
+
|
|
138
187
|
nodes.append(entity_data)
|
|
139
188
|
|
|
140
|
-
edges
|
|
189
|
+
edges = []
|
|
141
190
|
for edge in entity_edges:
|
|
142
191
|
if edge.fact_embedding is None:
|
|
143
192
|
await edge.generate_embedding(embedder)
|
|
@@ -147,253 +196,343 @@ async def add_nodes_and_edges_bulk_tx(
|
|
|
147
196
|
'target_node_uuid': edge.target_node_uuid,
|
|
148
197
|
'name': edge.name,
|
|
149
198
|
'fact': edge.fact,
|
|
150
|
-
'fact_embedding': edge.fact_embedding,
|
|
151
199
|
'group_id': edge.group_id,
|
|
152
200
|
'episodes': edge.episodes,
|
|
153
201
|
'created_at': edge.created_at,
|
|
154
202
|
'expired_at': edge.expired_at,
|
|
155
203
|
'valid_at': edge.valid_at,
|
|
156
204
|
'invalid_at': edge.invalid_at,
|
|
205
|
+
'fact_embedding': edge.fact_embedding,
|
|
157
206
|
}
|
|
158
207
|
|
|
159
|
-
|
|
208
|
+
if driver.provider == GraphProvider.KUZU:
|
|
209
|
+
attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
|
|
210
|
+
edge_data['attributes'] = json.dumps(attributes)
|
|
211
|
+
else:
|
|
212
|
+
edge_data.update(edge.attributes or {})
|
|
213
|
+
|
|
160
214
|
edges.append(edge_data)
|
|
161
215
|
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
216
|
+
if driver.graph_operations_interface:
|
|
217
|
+
await driver.graph_operations_interface.episodic_node_save_bulk(None, driver, tx, episodes)
|
|
218
|
+
await driver.graph_operations_interface.node_save_bulk(None, driver, tx, nodes)
|
|
219
|
+
await driver.graph_operations_interface.episodic_edge_save_bulk(
|
|
220
|
+
None, driver, tx, [edge.model_dump() for edge in episodic_edges]
|
|
221
|
+
)
|
|
222
|
+
await driver.graph_operations_interface.edge_save_bulk(None, driver, tx, edges)
|
|
223
|
+
|
|
224
|
+
elif driver.provider == GraphProvider.KUZU:
|
|
225
|
+
# FIXME: Kuzu's UNWIND does not currently support STRUCT[] type properly, so we insert the data one by one instead for now.
|
|
226
|
+
episode_query = get_episode_node_save_bulk_query(driver.provider)
|
|
227
|
+
for episode in episodes:
|
|
228
|
+
await tx.run(episode_query, **episode)
|
|
229
|
+
entity_node_query = get_entity_node_save_bulk_query(driver.provider, nodes)
|
|
230
|
+
for node in nodes:
|
|
231
|
+
await tx.run(entity_node_query, **node)
|
|
232
|
+
entity_edge_query = get_entity_edge_save_bulk_query(driver.provider)
|
|
233
|
+
for edge in edges:
|
|
234
|
+
await tx.run(entity_edge_query, **edge)
|
|
235
|
+
episodic_edge_query = get_episodic_edge_save_bulk_query(driver.provider)
|
|
236
|
+
for edge in episodic_edges:
|
|
237
|
+
await tx.run(episodic_edge_query, **edge.model_dump())
|
|
238
|
+
else:
|
|
239
|
+
await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
|
|
240
|
+
await tx.run(
|
|
241
|
+
get_entity_node_save_bulk_query(driver.provider, nodes),
|
|
242
|
+
nodes=nodes,
|
|
243
|
+
)
|
|
244
|
+
await tx.run(
|
|
245
|
+
get_episodic_edge_save_bulk_query(driver.provider),
|
|
246
|
+
episodic_edges=[edge.model_dump() for edge in episodic_edges],
|
|
247
|
+
)
|
|
248
|
+
await tx.run(
|
|
249
|
+
get_entity_edge_save_bulk_query(driver.provider),
|
|
250
|
+
entity_edges=edges,
|
|
251
|
+
)
|
|
168
252
|
|
|
169
253
|
|
|
170
254
|
async def extract_nodes_and_edges_bulk(
|
|
171
|
-
clients: GraphitiClients,
|
|
172
|
-
|
|
173
|
-
|
|
255
|
+
clients: GraphitiClients,
|
|
256
|
+
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
|
|
257
|
+
edge_type_map: dict[tuple[str, str], list[str]],
|
|
258
|
+
entity_types: dict[str, type[BaseModel]] | None = None,
|
|
259
|
+
excluded_entity_types: list[str] | None = None,
|
|
260
|
+
edge_types: dict[str, type[BaseModel]] | None = None,
|
|
261
|
+
) -> tuple[list[list[EntityNode]], list[list[EntityEdge]]]:
|
|
262
|
+
extracted_nodes_bulk: list[list[EntityNode]] = await semaphore_gather(
|
|
174
263
|
*[
|
|
175
|
-
extract_nodes(clients, episode, previous_episodes)
|
|
264
|
+
extract_nodes(clients, episode, previous_episodes, entity_types, excluded_entity_types)
|
|
176
265
|
for episode, previous_episodes in episode_tuples
|
|
177
266
|
]
|
|
178
267
|
)
|
|
179
268
|
|
|
180
|
-
|
|
181
|
-
[episode[0] for episode in episode_tuples],
|
|
182
|
-
[episode[1] for episode in episode_tuples],
|
|
183
|
-
)
|
|
184
|
-
|
|
185
|
-
extracted_edges_bulk = await semaphore_gather(
|
|
269
|
+
extracted_edges_bulk: list[list[EntityEdge]] = await semaphore_gather(
|
|
186
270
|
*[
|
|
187
271
|
extract_edges(
|
|
188
272
|
clients,
|
|
189
273
|
episode,
|
|
190
274
|
extracted_nodes_bulk[i],
|
|
191
|
-
|
|
192
|
-
|
|
275
|
+
previous_episodes,
|
|
276
|
+
edge_type_map=edge_type_map,
|
|
277
|
+
group_id=episode.group_id,
|
|
278
|
+
edge_types=edge_types,
|
|
193
279
|
)
|
|
194
|
-
for i, episode in enumerate(
|
|
280
|
+
for i, (episode, previous_episodes) in enumerate(episode_tuples)
|
|
195
281
|
]
|
|
196
282
|
)
|
|
197
283
|
|
|
198
|
-
|
|
199
|
-
for i, episode in enumerate(episodes):
|
|
200
|
-
episodic_edges += build_episodic_edges(extracted_nodes_bulk[i], episode, episode.created_at)
|
|
201
|
-
|
|
202
|
-
nodes: list[EntityNode] = []
|
|
203
|
-
for extracted_nodes in extracted_nodes_bulk:
|
|
204
|
-
nodes += extracted_nodes
|
|
205
|
-
|
|
206
|
-
edges: list[EntityEdge] = []
|
|
207
|
-
for extracted_edges in extracted_edges_bulk:
|
|
208
|
-
edges += extracted_edges
|
|
209
|
-
|
|
210
|
-
return nodes, edges, episodic_edges
|
|
284
|
+
return extracted_nodes_bulk, extracted_edges_bulk
|
|
211
285
|
|
|
212
286
|
|
|
213
287
|
async def dedupe_nodes_bulk(
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
288
|
+
clients: GraphitiClients,
|
|
289
|
+
extracted_nodes: list[list[EntityNode]],
|
|
290
|
+
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
|
|
291
|
+
entity_types: dict[str, type[BaseModel]] | None = None,
|
|
292
|
+
) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
|
|
293
|
+
"""Resolve entity duplicates across an in-memory batch using a two-pass strategy.
|
|
294
|
+
|
|
295
|
+
1. Run :func:`resolve_extracted_nodes` for every episode in parallel so each batch item is
|
|
296
|
+
reconciled against the live graph just like the non-batch flow.
|
|
297
|
+
2. Re-run the deterministic similarity heuristics across the union of resolved nodes to catch
|
|
298
|
+
duplicates that only co-occur inside this batch, emitting a canonical UUID map that callers
|
|
299
|
+
can apply to edges and persistence.
|
|
300
|
+
"""
|
|
301
|
+
|
|
302
|
+
first_pass_results = await semaphore_gather(
|
|
303
|
+
*[
|
|
304
|
+
resolve_extracted_nodes(
|
|
305
|
+
clients,
|
|
306
|
+
nodes,
|
|
307
|
+
episode_tuples[i][0],
|
|
308
|
+
episode_tuples[i][1],
|
|
309
|
+
entity_types,
|
|
310
|
+
)
|
|
311
|
+
for i, nodes in enumerate(extracted_nodes)
|
|
312
|
+
]
|
|
238
313
|
)
|
|
239
314
|
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
315
|
+
episode_resolutions: list[tuple[str, list[EntityNode]]] = []
|
|
316
|
+
per_episode_uuid_maps: list[dict[str, str]] = []
|
|
317
|
+
duplicate_pairs: list[tuple[str, str]] = []
|
|
318
|
+
|
|
319
|
+
for (resolved_nodes, uuid_map, duplicates), (episode, _) in zip(
|
|
320
|
+
first_pass_results, episode_tuples, strict=True
|
|
321
|
+
):
|
|
322
|
+
episode_resolutions.append((episode.uuid, resolved_nodes))
|
|
323
|
+
per_episode_uuid_maps.append(uuid_map)
|
|
324
|
+
duplicate_pairs.extend((source.uuid, target.uuid) for source, target in duplicates)
|
|
325
|
+
|
|
326
|
+
canonical_nodes: dict[str, EntityNode] = {}
|
|
327
|
+
for _, resolved_nodes in episode_resolutions:
|
|
328
|
+
for node in resolved_nodes:
|
|
329
|
+
# NOTE: this loop is O(n^2) in the number of nodes inside the batch because we rebuild
|
|
330
|
+
# the MinHash index for the accumulated canonical pool each time. The LRU-backed
|
|
331
|
+
# shingle cache keeps the constant factors low for typical batch sizes (≤ CHUNK_SIZE),
|
|
332
|
+
# but if batches grow significantly we should switch to an incremental index or chunked
|
|
333
|
+
# processing.
|
|
334
|
+
if not canonical_nodes:
|
|
335
|
+
canonical_nodes[node.uuid] = node
|
|
336
|
+
continue
|
|
337
|
+
|
|
338
|
+
existing_candidates = list(canonical_nodes.values())
|
|
339
|
+
normalized = _normalize_string_exact(node.name)
|
|
340
|
+
exact_match = next(
|
|
341
|
+
(
|
|
342
|
+
candidate
|
|
343
|
+
for candidate in existing_candidates
|
|
344
|
+
if _normalize_string_exact(candidate.name) == normalized
|
|
345
|
+
),
|
|
346
|
+
None,
|
|
347
|
+
)
|
|
348
|
+
if exact_match is not None:
|
|
349
|
+
if exact_match.uuid != node.uuid:
|
|
350
|
+
duplicate_pairs.append((node.uuid, exact_match.uuid))
|
|
351
|
+
continue
|
|
352
|
+
|
|
353
|
+
indexes = _build_candidate_indexes(existing_candidates)
|
|
354
|
+
state = DedupResolutionState(
|
|
355
|
+
resolved_nodes=[None],
|
|
356
|
+
uuid_map={},
|
|
357
|
+
unresolved_indices=[],
|
|
358
|
+
)
|
|
359
|
+
_resolve_with_similarity([node], indexes, state)
|
|
360
|
+
|
|
361
|
+
resolved = state.resolved_nodes[0]
|
|
362
|
+
if resolved is None:
|
|
363
|
+
canonical_nodes[node.uuid] = node
|
|
364
|
+
continue
|
|
365
|
+
|
|
366
|
+
canonical_uuid = resolved.uuid
|
|
367
|
+
canonical_nodes.setdefault(canonical_uuid, resolved)
|
|
368
|
+
if canonical_uuid != node.uuid:
|
|
369
|
+
duplicate_pairs.append((node.uuid, canonical_uuid))
|
|
370
|
+
|
|
371
|
+
union_pairs: list[tuple[str, str]] = []
|
|
372
|
+
for uuid_map in per_episode_uuid_maps:
|
|
373
|
+
union_pairs.extend(uuid_map.items())
|
|
374
|
+
union_pairs.extend(duplicate_pairs)
|
|
375
|
+
|
|
376
|
+
compressed_map: dict[str, str] = _build_directed_uuid_map(union_pairs)
|
|
377
|
+
|
|
378
|
+
nodes_by_episode: dict[str, list[EntityNode]] = {}
|
|
379
|
+
for episode_uuid, resolved_nodes in episode_resolutions:
|
|
380
|
+
deduped_nodes: list[EntityNode] = []
|
|
381
|
+
seen: set[str] = set()
|
|
382
|
+
for node in resolved_nodes:
|
|
383
|
+
canonical_uuid = compressed_map.get(node.uuid, node.uuid)
|
|
384
|
+
if canonical_uuid in seen:
|
|
385
|
+
continue
|
|
386
|
+
seen.add(canonical_uuid)
|
|
387
|
+
canonical_node = canonical_nodes.get(canonical_uuid)
|
|
388
|
+
if canonical_node is None:
|
|
389
|
+
logger.error(
|
|
390
|
+
'Canonical node %s missing during batch dedupe; falling back to %s',
|
|
391
|
+
canonical_uuid,
|
|
392
|
+
node.uuid,
|
|
393
|
+
)
|
|
394
|
+
canonical_node = node
|
|
395
|
+
deduped_nodes.append(canonical_node)
|
|
396
|
+
|
|
397
|
+
nodes_by_episode[episode_uuid] = deduped_nodes
|
|
398
|
+
|
|
399
|
+
return nodes_by_episode, compressed_map
|
|
247
400
|
|
|
248
401
|
|
|
249
402
|
async def dedupe_edges_bulk(
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
)
|
|
403
|
+
clients: GraphitiClients,
|
|
404
|
+
extracted_edges: list[list[EntityEdge]],
|
|
405
|
+
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
|
|
406
|
+
_entities: list[EntityNode],
|
|
407
|
+
edge_types: dict[str, type[BaseModel]],
|
|
408
|
+
_edge_type_map: dict[tuple[str, str], list[str]],
|
|
409
|
+
) -> dict[str, list[EntityEdge]]:
|
|
410
|
+
embedder = clients.embedder
|
|
411
|
+
min_score = 0.6
|
|
412
|
+
|
|
413
|
+
# generate embeddings
|
|
414
|
+
await semaphore_gather(
|
|
415
|
+
*[create_entity_edge_embeddings(embedder, edges) for edges in extracted_edges]
|
|
263
416
|
)
|
|
264
417
|
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
418
|
+
# Find similar results
|
|
419
|
+
dedupe_tuples: list[tuple[EpisodicNode, EntityEdge, list[EntityEdge]]] = []
|
|
420
|
+
for i, edges_i in enumerate(extracted_edges):
|
|
421
|
+
existing_edges: list[EntityEdge] = []
|
|
422
|
+
for edges_j in extracted_edges:
|
|
423
|
+
existing_edges += edges_j
|
|
424
|
+
|
|
425
|
+
for edge in edges_i:
|
|
426
|
+
candidates: list[EntityEdge] = []
|
|
427
|
+
for existing_edge in existing_edges:
|
|
428
|
+
# Skip self-comparison
|
|
429
|
+
if edge.uuid == existing_edge.uuid:
|
|
430
|
+
continue
|
|
431
|
+
# Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
|
|
432
|
+
# This approach will cast a wider net than BM25, which is ideal for this use case
|
|
433
|
+
if (
|
|
434
|
+
edge.source_node_uuid != existing_edge.source_node_uuid
|
|
435
|
+
or edge.target_node_uuid != existing_edge.target_node_uuid
|
|
436
|
+
):
|
|
437
|
+
continue
|
|
438
|
+
|
|
439
|
+
edge_words = set(edge.fact.lower().split())
|
|
440
|
+
existing_edge_words = set(existing_edge.fact.lower().split())
|
|
441
|
+
has_overlap = not edge_words.isdisjoint(existing_edge_words)
|
|
442
|
+
if has_overlap:
|
|
443
|
+
candidates.append(existing_edge)
|
|
444
|
+
continue
|
|
445
|
+
|
|
446
|
+
# Check for semantic similarity even if there is no overlap
|
|
447
|
+
similarity = np.dot(
|
|
448
|
+
normalize_l2(edge.fact_embedding or []),
|
|
449
|
+
normalize_l2(existing_edge.fact_embedding or []),
|
|
450
|
+
)
|
|
451
|
+
if similarity >= min_score:
|
|
452
|
+
candidates.append(existing_edge)
|
|
453
|
+
|
|
454
|
+
dedupe_tuples.append((episode_tuples[i][0], edge, candidates))
|
|
455
|
+
|
|
456
|
+
bulk_edge_resolutions: list[
|
|
457
|
+
tuple[EntityEdge, EntityEdge, list[EntityEdge]]
|
|
458
|
+
] = await semaphore_gather(
|
|
459
|
+
*[
|
|
460
|
+
resolve_extracted_edge(
|
|
461
|
+
clients.llm_client,
|
|
462
|
+
edge,
|
|
463
|
+
candidates,
|
|
464
|
+
candidates,
|
|
465
|
+
episode,
|
|
466
|
+
edge_types,
|
|
467
|
+
set(edge_types),
|
|
468
|
+
)
|
|
469
|
+
for episode, edge, candidates in dedupe_tuples
|
|
470
|
+
]
|
|
272
471
|
)
|
|
273
472
|
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
name_map: dict[str, EntityNode] = {}
|
|
281
|
-
for node in nodes:
|
|
282
|
-
if node.name in name_map:
|
|
283
|
-
uuid_map[node.uuid] = name_map[node.name].uuid
|
|
284
|
-
continue
|
|
285
|
-
|
|
286
|
-
name_map[node.name] = node
|
|
287
|
-
|
|
288
|
-
return [node for node in name_map.values()], uuid_map
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
async def compress_nodes(
|
|
292
|
-
llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
|
|
293
|
-
) -> tuple[list[EntityNode], dict[str, str]]:
|
|
294
|
-
# We want to first compress the nodes by deduplicating nodes across each of the episodes added in bulk
|
|
295
|
-
if len(nodes) == 0:
|
|
296
|
-
return nodes, uuid_map
|
|
473
|
+
# For now we won't track edge invalidation
|
|
474
|
+
duplicate_pairs: list[tuple[str, str]] = []
|
|
475
|
+
for i, (_, _, duplicates) in enumerate(bulk_edge_resolutions):
|
|
476
|
+
episode, edge, candidates = dedupe_tuples[i]
|
|
477
|
+
for duplicate in duplicates:
|
|
478
|
+
duplicate_pairs.append((edge.uuid, duplicate.uuid))
|
|
297
479
|
|
|
298
|
-
#
|
|
299
|
-
|
|
300
|
-
# We want chunk sizes to be at least 10 for optimizing LLM processing time
|
|
301
|
-
chunk_size = max(int(sqrt(len(nodes))), CHUNK_SIZE)
|
|
480
|
+
# Now we compress the duplicate_map, so that 3 -> 2 and 2 -> 1 becomes 3 -> 1 (sorted by uuid)
|
|
481
|
+
compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs)
|
|
302
482
|
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
for i, n in enumerate(nodes)
|
|
307
|
-
for j, m in enumerate(nodes[:i])
|
|
308
|
-
]
|
|
309
|
-
|
|
310
|
-
# We now sort by semantic similarity
|
|
311
|
-
similarity_scores.sort(key=lambda score_tuple: score_tuple[2])
|
|
483
|
+
edge_uuid_map: dict[str, EntityEdge] = {
|
|
484
|
+
edge.uuid: edge for edges in extracted_edges for edge in edges
|
|
485
|
+
}
|
|
312
486
|
|
|
313
|
-
|
|
314
|
-
|
|
487
|
+
edges_by_episode: dict[str, list[EntityEdge]] = {}
|
|
488
|
+
for i, edges in enumerate(extracted_edges):
|
|
489
|
+
episode = episode_tuples[i][0]
|
|
315
490
|
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
# determine if any of the nodes have already been drafted into a chunk
|
|
320
|
-
n = nodes[i]
|
|
321
|
-
m = nodes[j]
|
|
322
|
-
# make sure the shortest chunks get preference
|
|
323
|
-
node_chunks.sort(reverse=True, key=lambda chunk: len(chunk))
|
|
491
|
+
edges_by_episode[episode.uuid] = [
|
|
492
|
+
edge_uuid_map[compressed_map.get(edge.uuid, edge.uuid)] for edge in edges
|
|
493
|
+
]
|
|
324
494
|
|
|
325
|
-
|
|
326
|
-
m_chunk = max([i if m in chunk else -1 for i, chunk in enumerate(node_chunks)])
|
|
495
|
+
return edges_by_episode
|
|
327
496
|
|
|
328
|
-
# both nodes already in a chunk
|
|
329
|
-
if n_chunk > -1 and m_chunk > -1:
|
|
330
|
-
continue
|
|
331
497
|
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
498
|
+
class UnionFind:
|
|
499
|
+
def __init__(self, elements):
|
|
500
|
+
# start each element in its own set
|
|
501
|
+
self.parent = {e: e for e in elements}
|
|
336
502
|
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
503
|
+
def find(self, x):
|
|
504
|
+
# path‐compression
|
|
505
|
+
if self.parent[x] != x:
|
|
506
|
+
self.parent[x] = self.find(self.parent[x])
|
|
507
|
+
return self.parent[x]
|
|
341
508
|
|
|
342
|
-
|
|
509
|
+
def union(self, a, b):
|
|
510
|
+
ra, rb = self.find(a), self.find(b)
|
|
511
|
+
if ra == rb:
|
|
512
|
+
return
|
|
513
|
+
# attach the lexicographically larger root under the smaller
|
|
514
|
+
if ra < rb:
|
|
515
|
+
self.parent[rb] = ra
|
|
343
516
|
else:
|
|
344
|
-
|
|
345
|
-
node_chunks[-1].extend([n, m])
|
|
346
|
-
|
|
347
|
-
results = await semaphore_gather(
|
|
348
|
-
*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks]
|
|
349
|
-
)
|
|
350
|
-
|
|
351
|
-
extended_map = dict(uuid_map)
|
|
352
|
-
compressed_nodes: list[EntityNode] = []
|
|
353
|
-
for node_chunk, uuid_map_chunk in results:
|
|
354
|
-
compressed_nodes += node_chunk
|
|
355
|
-
extended_map.update(uuid_map_chunk)
|
|
356
|
-
|
|
357
|
-
# Check if we have removed all duplicates
|
|
358
|
-
if len(compressed_nodes) == len(nodes):
|
|
359
|
-
compressed_uuid_map = compress_uuid_map(extended_map)
|
|
360
|
-
return compressed_nodes, compressed_uuid_map
|
|
361
|
-
|
|
362
|
-
return await compress_nodes(llm_client, compressed_nodes, extended_map)
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list[EntityEdge]:
|
|
366
|
-
if len(edges) == 0:
|
|
367
|
-
return edges
|
|
368
|
-
# We only want to dedupe edges that are between the same pair of nodes
|
|
369
|
-
# We build a map of the edges based on their source and target nodes.
|
|
370
|
-
edge_chunks = chunk_edges_by_nodes(edges)
|
|
371
|
-
|
|
372
|
-
results = await semaphore_gather(
|
|
373
|
-
*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks]
|
|
374
|
-
)
|
|
375
|
-
|
|
376
|
-
compressed_edges: list[EntityEdge] = []
|
|
377
|
-
for edge_chunk in results:
|
|
378
|
-
compressed_edges += edge_chunk
|
|
379
|
-
|
|
380
|
-
# Check if we have removed all duplicates
|
|
381
|
-
if len(compressed_edges) == len(edges):
|
|
382
|
-
return compressed_edges
|
|
383
|
-
|
|
384
|
-
return await compress_edges(llm_client, compressed_edges)
|
|
517
|
+
self.parent[ra] = rb
|
|
385
518
|
|
|
386
519
|
|
|
387
|
-
def compress_uuid_map(
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
520
|
+
def compress_uuid_map(duplicate_pairs: list[tuple[str, str]]) -> dict[str, str]:
|
|
521
|
+
"""
|
|
522
|
+
all_ids: set of all entity IDs (strings), collected from duplicate_pairs rather than passed in
|
|
523
|
+
duplicate_pairs: iterable of (id1, id2) pairs
|
|
524
|
+
returns: dict mapping each id -> lexicographically smallest id in its duplicate set
|
|
525
|
+
"""
|
|
526
|
+
all_uuids = set()
|
|
527
|
+
for pair in duplicate_pairs:
|
|
528
|
+
all_uuids.add(pair[0])
|
|
529
|
+
all_uuids.add(pair[1])
|
|
394
530
|
|
|
395
|
-
|
|
396
|
-
|
|
531
|
+
uf = UnionFind(all_uuids)
|
|
532
|
+
for a, b in duplicate_pairs:
|
|
533
|
+
uf.union(a, b)
|
|
534
|
+
# ensure full path‐compression before mapping
|
|
535
|
+
return {uuid: uf.find(uuid) for uuid in all_uuids}
|
|
397
536
|
|
|
398
537
|
|
|
399
538
|
E = typing.TypeVar('E', bound=Edge)
|
|
@@ -407,63 +546,3 @@ def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
|
|
|
407
546
|
edge.target_node_uuid = uuid_map.get(target_uuid, target_uuid)
|
|
408
547
|
|
|
409
548
|
return edges
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
async def extract_edge_dates_bulk(
|
|
413
|
-
llm_client: LLMClient,
|
|
414
|
-
extracted_edges: list[EntityEdge],
|
|
415
|
-
episode_pairs: list[tuple[EpisodicNode, list[EpisodicNode]]],
|
|
416
|
-
) -> list[EntityEdge]:
|
|
417
|
-
edges: list[EntityEdge] = []
|
|
418
|
-
# confirm that all of our edges have at least one episode
|
|
419
|
-
for edge in extracted_edges:
|
|
420
|
-
if edge.episodes is not None and len(edge.episodes) > 0:
|
|
421
|
-
edges.append(edge)
|
|
422
|
-
|
|
423
|
-
episode_uuid_map: dict[str, tuple[EpisodicNode, list[EpisodicNode]]] = {
|
|
424
|
-
episode.uuid: (episode, previous_episodes) for episode, previous_episodes in episode_pairs
|
|
425
|
-
}
|
|
426
|
-
|
|
427
|
-
results = await semaphore_gather(
|
|
428
|
-
*[
|
|
429
|
-
extract_edge_dates(
|
|
430
|
-
llm_client,
|
|
431
|
-
edge,
|
|
432
|
-
episode_uuid_map[edge.episodes[0]][0], # type: ignore
|
|
433
|
-
episode_uuid_map[edge.episodes[0]][1], # type: ignore
|
|
434
|
-
)
|
|
435
|
-
for edge in edges
|
|
436
|
-
]
|
|
437
|
-
)
|
|
438
|
-
|
|
439
|
-
for i, result in enumerate(results):
|
|
440
|
-
valid_at = result[0]
|
|
441
|
-
invalid_at = result[1]
|
|
442
|
-
edge = edges[i]
|
|
443
|
-
|
|
444
|
-
edge.valid_at = valid_at
|
|
445
|
-
edge.invalid_at = invalid_at
|
|
446
|
-
if edge.invalid_at:
|
|
447
|
-
edge.expired_at = utc_now()
|
|
448
|
-
|
|
449
|
-
return edges
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
def chunk_edges_by_nodes(edges: list[EntityEdge]) -> list[list[EntityEdge]]:
|
|
453
|
-
# We only want to dedupe edges that are between the same pair of nodes
|
|
454
|
-
# We build a map of the edges based on their source and target nodes.
|
|
455
|
-
edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list)
|
|
456
|
-
for edge in edges:
|
|
457
|
-
# We drop loop edges
|
|
458
|
-
if edge.source_node_uuid == edge.target_node_uuid:
|
|
459
|
-
continue
|
|
460
|
-
|
|
461
|
-
# Keep the order of the two nodes consistent, we want to be direction agnostic during edge resolution
|
|
462
|
-
pointers = [edge.source_node_uuid, edge.target_node_uuid]
|
|
463
|
-
pointers.sort()
|
|
464
|
-
|
|
465
|
-
edge_chunk_map[pointers[0] + pointers[1]].append(edge)
|
|
466
|
-
|
|
467
|
-
edge_chunks = [chunk for chunk in edge_chunk_map.values()]
|
|
468
|
-
|
|
469
|
-
return edge_chunks
|