graphiti-core 0.12.0rc1__py3-none-any.whl → 0.24.3__py3-none-any.whl
This diff compares publicly released versions of the package as published to a supported registry; it is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.
- graphiti_core/cross_encoder/bge_reranker_client.py +12 -2
- graphiti_core/cross_encoder/gemini_reranker_client.py +161 -0
- graphiti_core/cross_encoder/openai_reranker_client.py +7 -5
- graphiti_core/decorators.py +110 -0
- graphiti_core/driver/__init__.py +19 -0
- graphiti_core/driver/driver.py +124 -0
- graphiti_core/driver/falkordb_driver.py +362 -0
- graphiti_core/driver/graph_operations/graph_operations.py +191 -0
- graphiti_core/driver/kuzu_driver.py +182 -0
- graphiti_core/driver/neo4j_driver.py +117 -0
- graphiti_core/driver/neptune_driver.py +305 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +287 -172
- graphiti_core/embedder/azure_openai.py +71 -0
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/embedder/gemini.py +116 -22
- graphiti_core/embedder/voyage.py +13 -2
- graphiti_core/errors.py +8 -0
- graphiti_core/graph_queries.py +162 -0
- graphiti_core/graphiti.py +705 -193
- graphiti_core/graphiti_types.py +4 -2
- graphiti_core/helpers.py +87 -10
- graphiti_core/llm_client/__init__.py +16 -0
- graphiti_core/llm_client/anthropic_client.py +159 -56
- graphiti_core/llm_client/azure_openai_client.py +115 -0
- graphiti_core/llm_client/client.py +98 -21
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/gemini_client.py +290 -41
- graphiti_core/llm_client/groq_client.py +14 -3
- graphiti_core/llm_client/openai_base_client.py +261 -0
- graphiti_core/llm_client/openai_client.py +56 -132
- graphiti_core/llm_client/openai_generic_client.py +91 -56
- graphiti_core/models/edges/edge_db_queries.py +259 -35
- graphiti_core/models/nodes/node_db_queries.py +311 -32
- graphiti_core/nodes.py +420 -205
- graphiti_core/prompts/dedupe_edges.py +46 -32
- graphiti_core/prompts/dedupe_nodes.py +67 -42
- graphiti_core/prompts/eval.py +4 -4
- graphiti_core/prompts/extract_edges.py +27 -16
- graphiti_core/prompts/extract_nodes.py +74 -31
- graphiti_core/prompts/prompt_helpers.py +39 -0
- graphiti_core/prompts/snippets.py +29 -0
- graphiti_core/prompts/summarize_nodes.py +23 -25
- graphiti_core/search/search.py +158 -82
- graphiti_core/search/search_config.py +39 -4
- graphiti_core/search/search_filters.py +126 -35
- graphiti_core/search/search_helpers.py +5 -6
- graphiti_core/search/search_utils.py +1405 -485
- graphiti_core/telemetry/__init__.py +9 -0
- graphiti_core/telemetry/telemetry.py +117 -0
- graphiti_core/tracer.py +193 -0
- graphiti_core/utils/bulk_utils.py +364 -285
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +67 -49
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +339 -197
- graphiti_core/utils/maintenance/graph_data_operations.py +50 -114
- graphiti_core/utils/maintenance/node_operations.py +319 -238
- graphiti_core/utils/maintenance/temporal_operations.py +11 -3
- graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
- graphiti_core/utils/text_utils.py +53 -0
- graphiti_core-0.24.3.dist-info/METADATA +726 -0
- graphiti_core-0.24.3.dist-info/RECORD +86 -0
- {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
- graphiti_core-0.12.0rc1.dist-info/METADATA +0 -350
- graphiti_core-0.12.0rc1.dist-info/RECORD +0 -66
- /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
- {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info/licenses}/LICENSE +0 -0
graphiti_core/utils/maintenance/node_operations.py +319 -238

(The side-by-side view of this file was flattened during extraction; it is reconstructed below as a unified diff. `[…]` marks old-side text truncated by the rendering, and bracketed notes stand in for removed lines whose content was not preserved.)

```diff
@@ -15,22 +15,26 @@ limitations under the License.
 """
 
 import logging
-from […]
+from collections.abc import Awaitable, Callable
 from time import time
 from typing import Any
-from uuid import uuid4
 
-import […]
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
 
 from graphiti_core.graphiti_types import GraphitiClients
 from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
 from graphiti_core.llm_client import LLMClient
 from graphiti_core.llm_client.config import ModelSize
-from graphiti_core.nodes import […]
+from graphiti_core.nodes import (
+    EntityNode,
+    EpisodeType,
+    EpisodicNode,
+    create_entity_node_embeddings,
+)
 from graphiti_core.prompts import prompt_library
 from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
 from graphiti_core.prompts.extract_nodes import (
+    EntitySummary,
     ExtractedEntities,
     ExtractedEntity,
     MissedEntities,
@@ -40,15 +44,28 @@ from graphiti_core.search.search_config import SearchResults
 from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
 from graphiti_core.search.search_filters import SearchFilters
 from graphiti_core.utils.datetime_utils import utc_now
+from graphiti_core.utils.maintenance.dedup_helpers import (
+    DedupCandidateIndexes,
+    DedupResolutionState,
+    _build_candidate_indexes,
+    _resolve_with_similarity,
+)
+from graphiti_core.utils.maintenance.edge_operations import (
+    filter_existing_duplicate_of_edges,
+)
+from graphiti_core.utils.text_utils import MAX_SUMMARY_CHARS, truncate_at_sentence
 
 logger = logging.getLogger(__name__)
 
+NodeSummaryFilter = Callable[[EntityNode], Awaitable[bool]]
+
 
 async def extract_nodes_reflexion(
     llm_client: LLMClient,
     episode: EpisodicNode,
     previous_episodes: list[EpisodicNode],
     node_names: list[str],
+    group_id: str | None = None,
 ) -> list[str]:
     # Prepare context for LLM
     context = {
@@ -58,7 +75,10 @@ async def extract_nodes_reflexion(
     }
 
     llm_response = await llm_client.generate_response(
-        prompt_library.extract_nodes.reflexion(context), […]
+        prompt_library.extract_nodes.reflexion(context),
+        MissedEntities,
+        group_id=group_id,
+        prompt_name='extract_nodes.reflexion',
     )
     missed_entities = llm_response.get('missed_entities', [])
 
@@ -69,7 +89,8 @@ async def extract_nodes(
     clients: GraphitiClients,
     episode: EpisodicNode,
     previous_episodes: list[EpisodicNode],
-    entity_types: dict[str, BaseModel] | None = None,
+    entity_types: dict[str, type[BaseModel]] | None = None,
+    excluded_entity_types: list[str] | None = None,
 ) -> list[EntityNode]:
     start = time()
     llm_client = clients.llm_client
@@ -113,20 +134,27 @@ async def extract_nodes(
         llm_response = await llm_client.generate_response(
             prompt_library.extract_nodes.extract_message(context),
             response_model=ExtractedEntities,
+            group_id=episode.group_id,
+            prompt_name='extract_nodes.extract_message',
         )
     elif episode.source == EpisodeType.text:
         llm_response = await llm_client.generate_response(
-            prompt_library.extract_nodes.extract_text(context), […]
+            prompt_library.extract_nodes.extract_text(context),
+            response_model=ExtractedEntities,
+            group_id=episode.group_id,
+            prompt_name='extract_nodes.extract_text',
        )
     elif episode.source == EpisodeType.json:
         llm_response = await llm_client.generate_response(
-            prompt_library.extract_nodes.extract_json(context), […]
+            prompt_library.extract_nodes.extract_json(context),
+            response_model=ExtractedEntities,
+            group_id=episode.group_id,
+            prompt_name='extract_nodes.extract_json',
         )
 
-[… 3 lines not captured …]
-        ]
+        response_object = ExtractedEntities(**llm_response)
+
+        extracted_entities: list[ExtractedEntity] = response_object.extracted_entities
 
         reflexion_iterations += 1
         if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
@@ -135,6 +163,7 @@ async def extract_nodes(
                 episode,
                 previous_episodes,
                 [entity.name for entity in extracted_entities],
+                episode.group_id,
             )
 
             entities_missed = len(missing_entities) != 0
@@ -149,9 +178,18 @@ async def extract_nodes(
     # Convert the extracted data into EntityNode objects
     extracted_nodes = []
     for extracted_entity in filtered_extracted_entities:
-[… 3 lines not captured …]
+        type_id = extracted_entity.entity_type_id
+        if 0 <= type_id < len(entity_types_context):
+            entity_type_name = entity_types_context[extracted_entity.entity_type_id].get(
+                'entity_type_name'
+            )
+        else:
+            entity_type_name = 'Entity'
+
+        # Check if this entity type should be excluded
+        if excluded_entity_types and entity_type_name in excluded_entity_types:
+            logger.debug(f'Excluding entity "{extracted_entity.name}" of type "{entity_type_name}"')
+            continue
 
         labels: list[str] = list({'Entity', str(entity_type_name)})
 
@@ -166,68 +204,16 @@ async def extract_nodes(
     logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
 
     logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
-    return extracted_nodes
-
-
-async def dedupe_extracted_nodes(
-    llm_client: LLMClient,
-    extracted_nodes: list[EntityNode],
-    existing_nodes: list[EntityNode],
-) -> tuple[list[EntityNode], dict[str, str]]:
-    start = time()
-
-    # build existing node map
-    node_map: dict[str, EntityNode] = {}
-    for node in existing_nodes:
-        node_map[node.uuid] = node
-
-    # Prepare context for LLM
-    existing_nodes_context = [
-        {'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in existing_nodes
-    ]
-
-    extracted_nodes_context = [
-        {'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in extracted_nodes
-    ]
-
-    context = {
-        'existing_nodes': existing_nodes_context,
-        'extracted_nodes': extracted_nodes_context,
-    }
-
-    llm_response = await llm_client.generate_response(prompt_library.dedupe_nodes.node(context))
-
-    duplicate_data = llm_response.get('duplicates', [])
-
-    end = time()
-    logger.debug(f'Deduplicated nodes: {duplicate_data} in {(end - start) * 1000} ms')
-
-    uuid_map: dict[str, str] = {}
-    for duplicate in duplicate_data:
-        uuid_value = duplicate['duplicate_of']
-        uuid_map[duplicate['uuid']] = uuid_value
-
-    nodes: list[EntityNode] = []
-    for node in extracted_nodes:
-        if node.uuid in uuid_map:
-            existing_uuid = uuid_map[node.uuid]
-            existing_node = node_map[existing_uuid]
-            nodes.append(existing_node)
-        else:
-            nodes.append(node)
 
-    return […]
+    return extracted_nodes
 
 
-async def […]
+async def _collect_candidate_nodes(
     clients: GraphitiClients,
     extracted_nodes: list[EntityNode],
-[… 3 lines not captured …]
-) -> tuple[list[EntityNode], dict[str, str]]:
-    llm_client = clients.llm_client
-
+    existing_nodes_override: list[EntityNode] | None,
+) -> list[EntityNode]:
+    """Search per extracted name and return unique candidates with overrides honored in order."""
     search_results: list[SearchResults] = await semaphore_gather(
         *[
             search(
@@ -241,11 +227,43 @@ async def resolve_extracted_nodes(
         ]
     )
 
-[… 1 line not captured …]
+    candidate_nodes: list[EntityNode] = [node for result in search_results for node in result.nodes]
 
-[… 1 line not captured …]
+    if existing_nodes_override is not None:
+        candidate_nodes.extend(existing_nodes_override)
+
+    seen_candidate_uuids: set[str] = set()
+    ordered_candidates: list[EntityNode] = []
+    for candidate in candidate_nodes:
+        if candidate.uuid in seen_candidate_uuids:
+            continue
+        seen_candidate_uuids.add(candidate.uuid)
+        ordered_candidates.append(candidate)
+
+    return ordered_candidates
+
+
+async def _resolve_with_llm(
+    llm_client: LLMClient,
+    extracted_nodes: list[EntityNode],
+    indexes: DedupCandidateIndexes,
+    state: DedupResolutionState,
+    episode: EpisodicNode | None,
+    previous_episodes: list[EpisodicNode] | None,
+    entity_types: dict[str, type[BaseModel]] | None,
+) -> None:
+    """Escalate unresolved nodes to the dedupe prompt so the LLM can select or reject duplicates.
+
+    The guardrails below defensively ignore malformed or duplicate LLM responses so the
+    ingestion workflow remains deterministic even when the model misbehaves.
+    """
+    if not state.unresolved_indices:
+        return
+
+    entity_types_dict: dict[str, type[BaseModel]] = entity_types if entity_types is not None else {}
+
+    llm_extracted_nodes = [extracted_nodes[i] for i in state.unresolved_indices]
 
-    # Prepare context for LLM
     extracted_nodes_context = [
         {
             'id': i,
@@ -255,122 +273,181 @@ async def resolve_extracted_nodes(
                 next((item for item in node.labels if item != 'Entity'), '')
             ).__doc__
             or 'Default Entity Type',
-            'duplication_candidates': [
-                {
-                    **{
-                        'idx': j,
-                        'name': candidate.name,
-                        'entity_types': candidate.labels,
-                    },
-                    **candidate.attributes,
-                }
-                for j, candidate in enumerate(existing_nodes_lists[i])
-            ],
         }
-        for i, node in enumerate( […]
+        for i, node in enumerate(llm_extracted_nodes)
+    ]
+
+    sent_ids = [ctx['id'] for ctx in extracted_nodes_context]
+    logger.debug(
+        'Sending %d entities to LLM for deduplication with IDs 0-%d (actual IDs sent: %s)',
+        len(llm_extracted_nodes),
+        len(llm_extracted_nodes) - 1,
+        sent_ids if len(sent_ids) < 20 else f'{sent_ids[:10]}...{sent_ids[-10:]}',
+    )
+    if llm_extracted_nodes:
+        sample_size = min(3, len(extracted_nodes_context))
+        logger.debug(
+            'First %d entities: %s',
+            sample_size,
+            [(ctx['id'], ctx['name']) for ctx in extracted_nodes_context[:sample_size]],
+        )
+        if len(extracted_nodes_context) > 3:
+            logger.debug(
+                'Last %d entities: %s',
+                sample_size,
+                [(ctx['id'], ctx['name']) for ctx in extracted_nodes_context[-sample_size:]],
+            )
+
+    existing_nodes_context = [
+        {
+            **{
+                'idx': i,
+                'name': candidate.name,
+                'entity_types': candidate.labels,
+            },
+            **candidate.attributes,
+        }
+        for i, candidate in enumerate(indexes.existing_nodes)
     ]
 
     context = {
         'extracted_nodes': extracted_nodes_context,
+        'existing_nodes': existing_nodes_context,
         'episode_content': episode.content if episode is not None else '',
-        'previous_episodes': […]
-[… 2 lines not captured …]
+        'previous_episodes': (
+            [ep.content for ep in previous_episodes] if previous_episodes is not None else []
+        ),
     }
 
     llm_response = await llm_client.generate_response(
         prompt_library.dedupe_nodes.nodes(context),
         response_model=NodeResolutions,
+        prompt_name='dedupe_nodes.nodes',
     )
 
-    node_resolutions: list = llm_response. […]
+    node_resolutions: list[NodeDuplicate] = NodeResolutions(**llm_response).entity_resolutions
 
-[… 5 lines not captured …]
+    valid_relative_range = range(len(state.unresolved_indices))
+    processed_relative_ids: set[int] = set()
+
+    received_ids = {r.id for r in node_resolutions}
+    expected_ids = set(valid_relative_range)
+    missing_ids = expected_ids - received_ids
+    extra_ids = received_ids - expected_ids
+
+    logger.debug(
+        'Received %d resolutions for %d entities',
+        len(node_resolutions),
+        len(state.unresolved_indices),
+    )
 
-[… 1 line not captured …]
+    if missing_ids:
+        logger.warning('LLM did not return resolutions for IDs: %s', sorted(missing_ids))
 
-[… 4 lines not captured …]
+    if extra_ids:
+        logger.warning(
+            'LLM returned invalid IDs outside valid range 0-%d: %s (all returned IDs: %s)',
+            len(state.unresolved_indices) - 1,
+            sorted(extra_ids),
+            sorted(received_ids),
         )
 
-[… 1 line not captured …]
+    for resolution in node_resolutions:
+        relative_id: int = resolution.id
+        duplicate_idx: int = resolution.duplicate_idx
+
+        if relative_id not in valid_relative_range:
+            logger.warning(
+                'Skipping invalid LLM dedupe id %d (valid range: 0-%d, received %d resolutions)',
+                relative_id,
+                len(state.unresolved_indices) - 1,
+                len(node_resolutions),
+            )
+            continue
 
-[… 2 lines not captured …]
+        if relative_id in processed_relative_ids:
+            logger.warning('Duplicate LLM dedupe id %s received; ignoring.', relative_id)
+            continue
+        processed_relative_ids.add(relative_id)
 
-[… 1 line not captured …]
+        original_index = state.unresolved_indices[relative_id]
+        extracted_node = extracted_nodes[original_index]
 
-[… 1 line not captured …]
+        resolved_node: EntityNode
+        if duplicate_idx == -1:
+            resolved_node = extracted_node
+        elif 0 <= duplicate_idx < len(indexes.existing_nodes):
+            resolved_node = indexes.existing_nodes[duplicate_idx]
+        else:
+            logger.warning(
+                'Invalid duplicate_idx %s for extracted node %s; treating as no duplicate.',
+                duplicate_idx,
+                extracted_node.uuid,
+            )
+            resolved_node = extracted_node
 
+        state.resolved_nodes[original_index] = resolved_node
+        state.uuid_map[extracted_node.uuid] = resolved_node.uuid
+        if resolved_node.uuid != extracted_node.uuid:
+            state.duplicate_pairs.append((extracted_node, resolved_node))
 
-[… 4 lines not captured …]
+
+async def resolve_extracted_nodes(
+    clients: GraphitiClients,
+    extracted_nodes: list[EntityNode],
     episode: EpisodicNode | None = None,
     previous_episodes: list[EpisodicNode] | None = None,
-[… 11 lines not captured …]
-            'name': node.name,
-            'entity_types': node.labels,
-            },
-            **node.attributes,
-        }
-        for i, node in enumerate(existing_nodes)
-    ]
-
-    extracted_node_context = {
-        'name': extracted_node.name,
-        'entity_type': entity_type.__name__ if entity_type is not None else 'Entity',  # type: ignore
-    }
+    entity_types: dict[str, type[BaseModel]] | None = None,
+    existing_nodes_override: list[EntityNode] | None = None,
+) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
+    """Search for existing nodes, resolve deterministic matches, then escalate holdouts to the LLM dedupe prompt."""
+    llm_client = clients.llm_client
+    driver = clients.driver
+    existing_nodes = await _collect_candidate_nodes(
+        clients,
+        extracted_nodes,
+        existing_nodes_override,
+    )
 
-[… 1 line not captured …]
-        'existing_nodes': existing_nodes_context,
-        'extracted_node': extracted_node_context,
-        'entity_type_description': entity_type.__doc__
-        if entity_type is not None
-        else 'Default Entity Type',
-        'episode_content': episode.content if episode is not None else '',
-        'previous_episodes': [ep.content for ep in previous_episodes]
-        if previous_episodes is not None
-        else [],
-    }
+    indexes: DedupCandidateIndexes = _build_candidate_indexes(existing_nodes)
 
-[… 4 lines not captured …]
+    state = DedupResolutionState(
+        resolved_nodes=[None] * len(extracted_nodes),
+        uuid_map={},
+        unresolved_indices=[],
     )
 
-[… 1 line not captured …]
+    _resolve_with_similarity(extracted_nodes, indexes, state)
 
-[… 2 lines not captured …]
+    await _resolve_with_llm(
+        llm_client,
+        extracted_nodes,
+        indexes,
+        state,
+        episode,
+        previous_episodes,
+        entity_types,
     )
 
-    node […]
+    for idx, node in enumerate(extracted_nodes):
+        if state.resolved_nodes[idx] is None:
+            state.resolved_nodes[idx] = node
+            state.uuid_map[node.uuid] = node.uuid
 
-    end = time()
     logger.debug(
-[… 1 line not captured …]
+        'Resolved nodes: %s',
+        [(node.name, node.uuid) for node in state.resolved_nodes if node is not None],
     )
 
-[… 1 line not captured …]
+    new_node_duplicates: list[
+        tuple[EntityNode, EntityNode]
+    ] = await filter_existing_duplicate_of_edges(driver, state.duplicate_pairs)
+
+    return (
+        [node for node in state.resolved_nodes if node is not None],
+        state.uuid_map,
+        new_node_duplicates,
+    )
 
 
 async def extract_attributes_from_nodes(
@@ -378,11 +455,11 @@ async def extract_attributes_from_nodes(
     nodes: list[EntityNode],
     episode: EpisodicNode | None = None,
     previous_episodes: list[EpisodicNode] | None = None,
-    entity_types: dict[str, BaseModel] | None = None,
+    entity_types: dict[str, type[BaseModel]] | None = None,
+    should_summarize_node: NodeSummaryFilter | None = None,
 ) -> list[EntityNode]:
     llm_client = clients.llm_client
     embedder = clients.embedder
-
     updated_nodes: list[EntityNode] = await semaphore_gather(
         *[
             extract_attributes_from_node(
@@ -390,9 +467,12 @@ async def extract_attributes_from_nodes(
                 node,
                 episode,
                 previous_episodes,
-[… 3 lines not captured …]
+                (
+                    entity_types.get(next((item for item in node.labels if item != 'Entity'), ''))
+                    if entity_types is not None
+                    else None
+                ),
+                should_summarize_node,
             )
             for node in nodes
         ]
@@ -408,99 +488,100 @@ async def extract_attributes_from_node(
     node: EntityNode,
     episode: EpisodicNode | None = None,
     previous_episodes: list[EpisodicNode] | None = None,
-    entity_type: BaseModel | None = None,
+    entity_type: type[BaseModel] | None = None,
+    should_summarize_node: NodeSummaryFilter | None = None,
 ) -> EntityNode:
-[… 4 lines not captured …]
-        'attributes': node.attributes,
-    }
+    # Extract attributes if entity type is defined and has attributes
+    llm_response = await _extract_entity_attributes(
+        llm_client, node, episode, previous_episodes, entity_type
+    )
 
-[… 4 lines not captured …]
-                description='Summary containing the important information about the entity. Under 250 words',
-            ),
-        )
-    }
+    # Extract summary if needed
+    await _extract_entity_summary(
+        llm_client, node, episode, previous_episodes, should_summarize_node
+    )
 
-[… 1 line not captured …]
-        for field_name, field_info in entity_type.model_fields.items():
-            attributes_definitions[field_name] = (
-                field_info.annotation,
-                Field(description=field_info.description),
-            )
+    node.attributes.update(llm_response)
 
-[… 1 line not captured …]
-    entity_attributes_model = pydantic.create_model(unique_model_name, **attributes_definitions)
+    return node
 
-[… 7 lines not captured …]
+
+async def _extract_entity_attributes(
+    llm_client: LLMClient,
+    node: EntityNode,
+    episode: EpisodicNode | None,
+    previous_episodes: list[EpisodicNode] | None,
+    entity_type: type[BaseModel] | None,
+) -> dict[str, Any]:
+    if entity_type is None or len(entity_type.model_fields) == 0:
+        return {}
+
+    attributes_context = _build_episode_context(
+        # should not include summary
+        node_data={
+            'name': node.name,
+            'entity_types': node.labels,
+            'attributes': node.attributes,
+        },
+        episode=episode,
+        previous_episodes=previous_episodes,
+    )
 
     llm_response = await llm_client.generate_response(
-        prompt_library.extract_nodes.extract_attributes( […]
-        response_model= […]
+        prompt_library.extract_nodes.extract_attributes(attributes_context),
+        response_model=entity_type,
         model_size=ModelSize.small,
+        group_id=node.group_id,
+        prompt_name='extract_nodes.extract_attributes',
     )
 
-[… 3 lines not captured …]
-    with suppress(KeyError):
-        del node_attributes['summary']
+    # validate response
+    entity_type(**llm_response)
 
-[… 2 lines not captured …]
-    return node
+    return llm_response
 
 
-async def […]
+async def _extract_entity_summary(
     llm_client: LLMClient,
-[… 17 lines not captured …]
-        prompt_library.dedupe_nodes.node_list(context)
+    node: EntityNode,
+    episode: EpisodicNode | None,
+    previous_episodes: list[EpisodicNode] | None,
+    should_summarize_node: NodeSummaryFilter | None,
+) -> None:
+    if should_summarize_node is not None and not await should_summarize_node(node):
+        return
+
+    summary_context = _build_episode_context(
+        node_data={
+            'name': node.name,
+            'summary': truncate_at_sentence(node.summary, MAX_SUMMARY_CHARS),
+            'entity_types': node.labels,
+            'attributes': node.attributes,
+        },
+        episode=episode,
+        previous_episodes=previous_episodes,
    )
 
-[… 1 line not captured …]
+    summary_response = await llm_client.generate_response(
+        prompt_library.extract_nodes.extract_summary(summary_context),
+        response_model=EntitySummary,
+        model_size=ModelSize.small,
+        group_id=node.group_id,
+        prompt_name='extract_nodes.extract_summary',
+    )
 
-[… 1 line not captured …]
-    logger.debug(f'Deduplicated nodes: {nodes_data} in {(end - start) * 1000} ms')
-[… 1 line not captured …]
-    # Get full node data
-    unique_nodes = []
-    uuid_map: dict[str, str] = {}
-    for node_data in nodes_data:
-        node_instance: EntityNode | None = node_map.get(node_data['uuids'][0])
-        if node_instance is None:
-            logger.warning(f'Node {node_data["uuids"][0]} not found in node map')
-            continue
-        node_instance.summary = node_data['summary']
-        unique_nodes.append(node_instance)
+    node.summary = truncate_at_sentence(summary_response.get('summary', ''), MAX_SUMMARY_CHARS)
 
-        for uuid in node_data['uuids'][1:]:
-            uuid_value = node_map[node_data['uuids'][0]].uuid
-            uuid_map[uuid] = uuid_value
 
-[… 1 line not captured …]
+def _build_episode_context(
+    node_data: dict[str, Any],
+    episode: EpisodicNode | None,
+    previous_episodes: list[EpisodicNode] | None,
+) -> dict[str, Any]:
+    return {
+        'node': node_data,
+        'episode_content': episode.content if episode is not None else '',
+        'previous_episodes': (
+            [ep.content for ep in previous_episodes] if previous_episodes is not None else []
+        ),
    }
```