graphiti-core 0.17.4__py3-none-any.whl → 0.24.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphiti_core/cross_encoder/gemini_reranker_client.py +1 -1
- graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
- graphiti_core/decorators.py +110 -0
- graphiti_core/driver/driver.py +62 -2
- graphiti_core/driver/falkordb_driver.py +215 -23
- graphiti_core/driver/graph_operations/graph_operations.py +191 -0
- graphiti_core/driver/kuzu_driver.py +182 -0
- graphiti_core/driver/neo4j_driver.py +61 -8
- graphiti_core/driver/neptune_driver.py +305 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +264 -132
- graphiti_core/embedder/azure_openai.py +10 -3
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/graph_queries.py +114 -101
- graphiti_core/graphiti.py +582 -255
- graphiti_core/graphiti_types.py +2 -0
- graphiti_core/helpers.py +21 -14
- graphiti_core/llm_client/anthropic_client.py +142 -52
- graphiti_core/llm_client/azure_openai_client.py +57 -19
- graphiti_core/llm_client/client.py +83 -21
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/gemini_client.py +75 -57
- graphiti_core/llm_client/openai_base_client.py +94 -50
- graphiti_core/llm_client/openai_client.py +28 -8
- graphiti_core/llm_client/openai_generic_client.py +91 -56
- graphiti_core/models/edges/edge_db_queries.py +259 -35
- graphiti_core/models/nodes/node_db_queries.py +311 -32
- graphiti_core/nodes.py +388 -164
- graphiti_core/prompts/dedupe_edges.py +42 -31
- graphiti_core/prompts/dedupe_nodes.py +56 -39
- graphiti_core/prompts/eval.py +4 -4
- graphiti_core/prompts/extract_edges.py +23 -14
- graphiti_core/prompts/extract_nodes.py +73 -32
- graphiti_core/prompts/prompt_helpers.py +39 -0
- graphiti_core/prompts/snippets.py +29 -0
- graphiti_core/prompts/summarize_nodes.py +23 -25
- graphiti_core/search/search.py +154 -74
- graphiti_core/search/search_config.py +39 -4
- graphiti_core/search/search_filters.py +109 -31
- graphiti_core/search/search_helpers.py +5 -6
- graphiti_core/search/search_utils.py +1360 -473
- graphiti_core/tracer.py +193 -0
- graphiti_core/utils/bulk_utils.py +216 -90
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +62 -38
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +286 -126
- graphiti_core/utils/maintenance/graph_data_operations.py +44 -74
- graphiti_core/utils/maintenance/node_operations.py +320 -158
- graphiti_core/utils/maintenance/temporal_operations.py +11 -3
- graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
- graphiti_core/utils/text_utils.py +53 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/METADATA +221 -87
- graphiti_core-0.24.3.dist-info/RECORD +86 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
- graphiti_core-0.17.4.dist-info/RECORD +0 -77
- /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -15,22 +15,26 @@ limitations under the License.
|
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
17
|
import logging
|
|
18
|
-
from
|
|
18
|
+
from collections.abc import Awaitable, Callable
|
|
19
19
|
from time import time
|
|
20
20
|
from typing import Any
|
|
21
|
-
from uuid import uuid4
|
|
22
21
|
|
|
23
|
-
import
|
|
24
|
-
from pydantic import BaseModel, Field
|
|
22
|
+
from pydantic import BaseModel
|
|
25
23
|
|
|
26
24
|
from graphiti_core.graphiti_types import GraphitiClients
|
|
27
25
|
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
|
|
28
26
|
from graphiti_core.llm_client import LLMClient
|
|
29
27
|
from graphiti_core.llm_client.config import ModelSize
|
|
30
|
-
from graphiti_core.nodes import
|
|
28
|
+
from graphiti_core.nodes import (
|
|
29
|
+
EntityNode,
|
|
30
|
+
EpisodeType,
|
|
31
|
+
EpisodicNode,
|
|
32
|
+
create_entity_node_embeddings,
|
|
33
|
+
)
|
|
31
34
|
from graphiti_core.prompts import prompt_library
|
|
32
|
-
from graphiti_core.prompts.dedupe_nodes import NodeResolutions
|
|
35
|
+
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
|
|
33
36
|
from graphiti_core.prompts.extract_nodes import (
|
|
37
|
+
EntitySummary,
|
|
34
38
|
ExtractedEntities,
|
|
35
39
|
ExtractedEntity,
|
|
36
40
|
MissedEntities,
|
|
@@ -40,16 +44,28 @@ from graphiti_core.search.search_config import SearchResults
|
|
|
40
44
|
from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
|
|
41
45
|
from graphiti_core.search.search_filters import SearchFilters
|
|
42
46
|
from graphiti_core.utils.datetime_utils import utc_now
|
|
43
|
-
from graphiti_core.utils.maintenance.
|
|
47
|
+
from graphiti_core.utils.maintenance.dedup_helpers import (
|
|
48
|
+
DedupCandidateIndexes,
|
|
49
|
+
DedupResolutionState,
|
|
50
|
+
_build_candidate_indexes,
|
|
51
|
+
_resolve_with_similarity,
|
|
52
|
+
)
|
|
53
|
+
from graphiti_core.utils.maintenance.edge_operations import (
|
|
54
|
+
filter_existing_duplicate_of_edges,
|
|
55
|
+
)
|
|
56
|
+
from graphiti_core.utils.text_utils import MAX_SUMMARY_CHARS, truncate_at_sentence
|
|
44
57
|
|
|
45
58
|
logger = logging.getLogger(__name__)
|
|
46
59
|
|
|
60
|
+
NodeSummaryFilter = Callable[[EntityNode], Awaitable[bool]]
|
|
61
|
+
|
|
47
62
|
|
|
48
63
|
async def extract_nodes_reflexion(
|
|
49
64
|
llm_client: LLMClient,
|
|
50
65
|
episode: EpisodicNode,
|
|
51
66
|
previous_episodes: list[EpisodicNode],
|
|
52
67
|
node_names: list[str],
|
|
68
|
+
group_id: str | None = None,
|
|
53
69
|
) -> list[str]:
|
|
54
70
|
# Prepare context for LLM
|
|
55
71
|
context = {
|
|
@@ -59,7 +75,10 @@ async def extract_nodes_reflexion(
|
|
|
59
75
|
}
|
|
60
76
|
|
|
61
77
|
llm_response = await llm_client.generate_response(
|
|
62
|
-
prompt_library.extract_nodes.reflexion(context),
|
|
78
|
+
prompt_library.extract_nodes.reflexion(context),
|
|
79
|
+
MissedEntities,
|
|
80
|
+
group_id=group_id,
|
|
81
|
+
prompt_name='extract_nodes.reflexion',
|
|
63
82
|
)
|
|
64
83
|
missed_entities = llm_response.get('missed_entities', [])
|
|
65
84
|
|
|
@@ -70,7 +89,7 @@ async def extract_nodes(
|
|
|
70
89
|
clients: GraphitiClients,
|
|
71
90
|
episode: EpisodicNode,
|
|
72
91
|
previous_episodes: list[EpisodicNode],
|
|
73
|
-
entity_types: dict[str, BaseModel] | None = None,
|
|
92
|
+
entity_types: dict[str, type[BaseModel]] | None = None,
|
|
74
93
|
excluded_entity_types: list[str] | None = None,
|
|
75
94
|
) -> list[EntityNode]:
|
|
76
95
|
start = time()
|
|
@@ -115,20 +134,27 @@ async def extract_nodes(
|
|
|
115
134
|
llm_response = await llm_client.generate_response(
|
|
116
135
|
prompt_library.extract_nodes.extract_message(context),
|
|
117
136
|
response_model=ExtractedEntities,
|
|
137
|
+
group_id=episode.group_id,
|
|
138
|
+
prompt_name='extract_nodes.extract_message',
|
|
118
139
|
)
|
|
119
140
|
elif episode.source == EpisodeType.text:
|
|
120
141
|
llm_response = await llm_client.generate_response(
|
|
121
|
-
prompt_library.extract_nodes.extract_text(context),
|
|
142
|
+
prompt_library.extract_nodes.extract_text(context),
|
|
143
|
+
response_model=ExtractedEntities,
|
|
144
|
+
group_id=episode.group_id,
|
|
145
|
+
prompt_name='extract_nodes.extract_text',
|
|
122
146
|
)
|
|
123
147
|
elif episode.source == EpisodeType.json:
|
|
124
148
|
llm_response = await llm_client.generate_response(
|
|
125
|
-
prompt_library.extract_nodes.extract_json(context),
|
|
149
|
+
prompt_library.extract_nodes.extract_json(context),
|
|
150
|
+
response_model=ExtractedEntities,
|
|
151
|
+
group_id=episode.group_id,
|
|
152
|
+
prompt_name='extract_nodes.extract_json',
|
|
126
153
|
)
|
|
127
154
|
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
]
|
|
155
|
+
response_object = ExtractedEntities(**llm_response)
|
|
156
|
+
|
|
157
|
+
extracted_entities: list[ExtractedEntity] = response_object.extracted_entities
|
|
132
158
|
|
|
133
159
|
reflexion_iterations += 1
|
|
134
160
|
if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
|
|
@@ -137,6 +163,7 @@ async def extract_nodes(
|
|
|
137
163
|
episode,
|
|
138
164
|
previous_episodes,
|
|
139
165
|
[entity.name for entity in extracted_entities],
|
|
166
|
+
episode.group_id,
|
|
140
167
|
)
|
|
141
168
|
|
|
142
169
|
entities_missed = len(missing_entities) != 0
|
|
@@ -151,9 +178,13 @@ async def extract_nodes(
|
|
|
151
178
|
# Convert the extracted data into EntityNode objects
|
|
152
179
|
extracted_nodes = []
|
|
153
180
|
for extracted_entity in filtered_extracted_entities:
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
181
|
+
type_id = extracted_entity.entity_type_id
|
|
182
|
+
if 0 <= type_id < len(entity_types_context):
|
|
183
|
+
entity_type_name = entity_types_context[extracted_entity.entity_type_id].get(
|
|
184
|
+
'entity_type_name'
|
|
185
|
+
)
|
|
186
|
+
else:
|
|
187
|
+
entity_type_name = 'Entity'
|
|
157
188
|
|
|
158
189
|
# Check if this entity type should be excluded
|
|
159
190
|
if excluded_entity_types and entity_type_name in excluded_entity_types:
|
|
@@ -173,20 +204,16 @@ async def extract_nodes(
|
|
|
173
204
|
logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
|
|
174
205
|
|
|
175
206
|
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
|
|
207
|
+
|
|
176
208
|
return extracted_nodes
|
|
177
209
|
|
|
178
210
|
|
|
179
|
-
async def
|
|
211
|
+
async def _collect_candidate_nodes(
|
|
180
212
|
clients: GraphitiClients,
|
|
181
213
|
extracted_nodes: list[EntityNode],
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
existing_nodes_override: list[EntityNode] | None = None,
|
|
186
|
-
) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
|
|
187
|
-
llm_client = clients.llm_client
|
|
188
|
-
driver = clients.driver
|
|
189
|
-
|
|
214
|
+
existing_nodes_override: list[EntityNode] | None,
|
|
215
|
+
) -> list[EntityNode]:
|
|
216
|
+
"""Search per extracted name and return unique candidates with overrides honored in order."""
|
|
190
217
|
search_results: list[SearchResults] = await semaphore_gather(
|
|
191
218
|
*[
|
|
192
219
|
search(
|
|
@@ -200,33 +227,43 @@ async def resolve_extracted_nodes(
|
|
|
200
227
|
]
|
|
201
228
|
)
|
|
202
229
|
|
|
203
|
-
candidate_nodes: list[EntityNode] =
|
|
204
|
-
[node for result in search_results for node in result.nodes]
|
|
205
|
-
if existing_nodes_override is None
|
|
206
|
-
else existing_nodes_override
|
|
207
|
-
)
|
|
230
|
+
candidate_nodes: list[EntityNode] = [node for result in search_results for node in result.nodes]
|
|
208
231
|
|
|
209
|
-
|
|
232
|
+
if existing_nodes_override is not None:
|
|
233
|
+
candidate_nodes.extend(existing_nodes_override)
|
|
210
234
|
|
|
211
|
-
|
|
235
|
+
seen_candidate_uuids: set[str] = set()
|
|
236
|
+
ordered_candidates: list[EntityNode] = []
|
|
237
|
+
for candidate in candidate_nodes:
|
|
238
|
+
if candidate.uuid in seen_candidate_uuids:
|
|
239
|
+
continue
|
|
240
|
+
seen_candidate_uuids.add(candidate.uuid)
|
|
241
|
+
ordered_candidates.append(candidate)
|
|
212
242
|
|
|
213
|
-
|
|
214
|
-
[
|
|
215
|
-
{
|
|
216
|
-
**{
|
|
217
|
-
'idx': i,
|
|
218
|
-
'name': candidate.name,
|
|
219
|
-
'entity_types': candidate.labels,
|
|
220
|
-
},
|
|
221
|
-
**candidate.attributes,
|
|
222
|
-
}
|
|
223
|
-
for i, candidate in enumerate(existing_nodes)
|
|
224
|
-
],
|
|
225
|
-
)
|
|
243
|
+
return ordered_candidates
|
|
226
244
|
|
|
227
|
-
entity_types_dict: dict[str, BaseModel] = entity_types if entity_types is not None else {}
|
|
228
245
|
|
|
229
|
-
|
|
246
|
+
async def _resolve_with_llm(
|
|
247
|
+
llm_client: LLMClient,
|
|
248
|
+
extracted_nodes: list[EntityNode],
|
|
249
|
+
indexes: DedupCandidateIndexes,
|
|
250
|
+
state: DedupResolutionState,
|
|
251
|
+
episode: EpisodicNode | None,
|
|
252
|
+
previous_episodes: list[EpisodicNode] | None,
|
|
253
|
+
entity_types: dict[str, type[BaseModel]] | None,
|
|
254
|
+
) -> None:
|
|
255
|
+
"""Escalate unresolved nodes to the dedupe prompt so the LLM can select or reject duplicates.
|
|
256
|
+
|
|
257
|
+
The guardrails below defensively ignore malformed or duplicate LLM responses so the
|
|
258
|
+
ingestion workflow remains deterministic even when the model misbehaves.
|
|
259
|
+
"""
|
|
260
|
+
if not state.unresolved_indices:
|
|
261
|
+
return
|
|
262
|
+
|
|
263
|
+
entity_types_dict: dict[str, type[BaseModel]] = entity_types if entity_types is not None else {}
|
|
264
|
+
|
|
265
|
+
llm_extracted_nodes = [extracted_nodes[i] for i in state.unresolved_indices]
|
|
266
|
+
|
|
230
267
|
extracted_nodes_context = [
|
|
231
268
|
{
|
|
232
269
|
'id': i,
|
|
@@ -237,60 +274,180 @@ async def resolve_extracted_nodes(
|
|
|
237
274
|
).__doc__
|
|
238
275
|
or 'Default Entity Type',
|
|
239
276
|
}
|
|
240
|
-
for i, node in enumerate(
|
|
277
|
+
for i, node in enumerate(llm_extracted_nodes)
|
|
278
|
+
]
|
|
279
|
+
|
|
280
|
+
sent_ids = [ctx['id'] for ctx in extracted_nodes_context]
|
|
281
|
+
logger.debug(
|
|
282
|
+
'Sending %d entities to LLM for deduplication with IDs 0-%d (actual IDs sent: %s)',
|
|
283
|
+
len(llm_extracted_nodes),
|
|
284
|
+
len(llm_extracted_nodes) - 1,
|
|
285
|
+
sent_ids if len(sent_ids) < 20 else f'{sent_ids[:10]}...{sent_ids[-10:]}',
|
|
286
|
+
)
|
|
287
|
+
if llm_extracted_nodes:
|
|
288
|
+
sample_size = min(3, len(extracted_nodes_context))
|
|
289
|
+
logger.debug(
|
|
290
|
+
'First %d entities: %s',
|
|
291
|
+
sample_size,
|
|
292
|
+
[(ctx['id'], ctx['name']) for ctx in extracted_nodes_context[:sample_size]],
|
|
293
|
+
)
|
|
294
|
+
if len(extracted_nodes_context) > 3:
|
|
295
|
+
logger.debug(
|
|
296
|
+
'Last %d entities: %s',
|
|
297
|
+
sample_size,
|
|
298
|
+
[(ctx['id'], ctx['name']) for ctx in extracted_nodes_context[-sample_size:]],
|
|
299
|
+
)
|
|
300
|
+
|
|
301
|
+
existing_nodes_context = [
|
|
302
|
+
{
|
|
303
|
+
**{
|
|
304
|
+
'idx': i,
|
|
305
|
+
'name': candidate.name,
|
|
306
|
+
'entity_types': candidate.labels,
|
|
307
|
+
},
|
|
308
|
+
**candidate.attributes,
|
|
309
|
+
}
|
|
310
|
+
for i, candidate in enumerate(indexes.existing_nodes)
|
|
241
311
|
]
|
|
242
312
|
|
|
243
313
|
context = {
|
|
244
314
|
'extracted_nodes': extracted_nodes_context,
|
|
245
315
|
'existing_nodes': existing_nodes_context,
|
|
246
316
|
'episode_content': episode.content if episode is not None else '',
|
|
247
|
-
'previous_episodes':
|
|
248
|
-
|
|
249
|
-
|
|
317
|
+
'previous_episodes': (
|
|
318
|
+
[ep.content for ep in previous_episodes] if previous_episodes is not None else []
|
|
319
|
+
),
|
|
250
320
|
}
|
|
251
321
|
|
|
252
322
|
llm_response = await llm_client.generate_response(
|
|
253
323
|
prompt_library.dedupe_nodes.nodes(context),
|
|
254
324
|
response_model=NodeResolutions,
|
|
325
|
+
prompt_name='dedupe_nodes.nodes',
|
|
255
326
|
)
|
|
256
327
|
|
|
257
|
-
node_resolutions: list = llm_response.
|
|
328
|
+
node_resolutions: list[NodeDuplicate] = NodeResolutions(**llm_response).entity_resolutions
|
|
258
329
|
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
node_duplicates: list[tuple[EntityNode, EntityNode]] = []
|
|
262
|
-
for resolution in node_resolutions:
|
|
263
|
-
resolution_id: int = resolution.get('id', -1)
|
|
264
|
-
duplicate_idx: int = resolution.get('duplicate_idx', -1)
|
|
330
|
+
valid_relative_range = range(len(state.unresolved_indices))
|
|
331
|
+
processed_relative_ids: set[int] = set()
|
|
265
332
|
|
|
266
|
-
|
|
333
|
+
received_ids = {r.id for r in node_resolutions}
|
|
334
|
+
expected_ids = set(valid_relative_range)
|
|
335
|
+
missing_ids = expected_ids - received_ids
|
|
336
|
+
extra_ids = received_ids - expected_ids
|
|
337
|
+
|
|
338
|
+
logger.debug(
|
|
339
|
+
'Received %d resolutions for %d entities',
|
|
340
|
+
len(node_resolutions),
|
|
341
|
+
len(state.unresolved_indices),
|
|
342
|
+
)
|
|
267
343
|
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
344
|
+
if missing_ids:
|
|
345
|
+
logger.warning('LLM did not return resolutions for IDs: %s', sorted(missing_ids))
|
|
346
|
+
|
|
347
|
+
if extra_ids:
|
|
348
|
+
logger.warning(
|
|
349
|
+
'LLM returned invalid IDs outside valid range 0-%d: %s (all returned IDs: %s)',
|
|
350
|
+
len(state.unresolved_indices) - 1,
|
|
351
|
+
sorted(extra_ids),
|
|
352
|
+
sorted(received_ids),
|
|
272
353
|
)
|
|
273
354
|
|
|
274
|
-
|
|
355
|
+
for resolution in node_resolutions:
|
|
356
|
+
relative_id: int = resolution.id
|
|
357
|
+
duplicate_idx: int = resolution.duplicate_idx
|
|
358
|
+
|
|
359
|
+
if relative_id not in valid_relative_range:
|
|
360
|
+
logger.warning(
|
|
361
|
+
'Skipping invalid LLM dedupe id %d (valid range: 0-%d, received %d resolutions)',
|
|
362
|
+
relative_id,
|
|
363
|
+
len(state.unresolved_indices) - 1,
|
|
364
|
+
len(node_resolutions),
|
|
365
|
+
)
|
|
366
|
+
continue
|
|
275
367
|
|
|
276
|
-
|
|
277
|
-
|
|
368
|
+
if relative_id in processed_relative_ids:
|
|
369
|
+
logger.warning('Duplicate LLM dedupe id %s received; ignoring.', relative_id)
|
|
370
|
+
continue
|
|
371
|
+
processed_relative_ids.add(relative_id)
|
|
372
|
+
|
|
373
|
+
original_index = state.unresolved_indices[relative_id]
|
|
374
|
+
extracted_node = extracted_nodes[original_index]
|
|
375
|
+
|
|
376
|
+
resolved_node: EntityNode
|
|
377
|
+
if duplicate_idx == -1:
|
|
378
|
+
resolved_node = extracted_node
|
|
379
|
+
elif 0 <= duplicate_idx < len(indexes.existing_nodes):
|
|
380
|
+
resolved_node = indexes.existing_nodes[duplicate_idx]
|
|
381
|
+
else:
|
|
382
|
+
logger.warning(
|
|
383
|
+
'Invalid duplicate_idx %s for extracted node %s; treating as no duplicate.',
|
|
384
|
+
duplicate_idx,
|
|
385
|
+
extracted_node.uuid,
|
|
386
|
+
)
|
|
387
|
+
resolved_node = extracted_node
|
|
278
388
|
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
existing_node = existing_nodes[idx] if idx < len(existing_nodes) else resolved_node
|
|
389
|
+
state.resolved_nodes[original_index] = resolved_node
|
|
390
|
+
state.uuid_map[extracted_node.uuid] = resolved_node.uuid
|
|
391
|
+
if resolved_node.uuid != extracted_node.uuid:
|
|
392
|
+
state.duplicate_pairs.append((extracted_node, resolved_node))
|
|
284
393
|
|
|
285
|
-
node_duplicates.append((extracted_node, existing_node))
|
|
286
394
|
|
|
287
|
-
|
|
395
|
+
async def resolve_extracted_nodes(
|
|
396
|
+
clients: GraphitiClients,
|
|
397
|
+
extracted_nodes: list[EntityNode],
|
|
398
|
+
episode: EpisodicNode | None = None,
|
|
399
|
+
previous_episodes: list[EpisodicNode] | None = None,
|
|
400
|
+
entity_types: dict[str, type[BaseModel]] | None = None,
|
|
401
|
+
existing_nodes_override: list[EntityNode] | None = None,
|
|
402
|
+
) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
|
|
403
|
+
"""Search for existing nodes, resolve deterministic matches, then escalate holdouts to the LLM dedupe prompt."""
|
|
404
|
+
llm_client = clients.llm_client
|
|
405
|
+
driver = clients.driver
|
|
406
|
+
existing_nodes = await _collect_candidate_nodes(
|
|
407
|
+
clients,
|
|
408
|
+
extracted_nodes,
|
|
409
|
+
existing_nodes_override,
|
|
410
|
+
)
|
|
411
|
+
|
|
412
|
+
indexes: DedupCandidateIndexes = _build_candidate_indexes(existing_nodes)
|
|
413
|
+
|
|
414
|
+
state = DedupResolutionState(
|
|
415
|
+
resolved_nodes=[None] * len(extracted_nodes),
|
|
416
|
+
uuid_map={},
|
|
417
|
+
unresolved_indices=[],
|
|
418
|
+
)
|
|
419
|
+
|
|
420
|
+
_resolve_with_similarity(extracted_nodes, indexes, state)
|
|
421
|
+
|
|
422
|
+
await _resolve_with_llm(
|
|
423
|
+
llm_client,
|
|
424
|
+
extracted_nodes,
|
|
425
|
+
indexes,
|
|
426
|
+
state,
|
|
427
|
+
episode,
|
|
428
|
+
previous_episodes,
|
|
429
|
+
entity_types,
|
|
430
|
+
)
|
|
431
|
+
|
|
432
|
+
for idx, node in enumerate(extracted_nodes):
|
|
433
|
+
if state.resolved_nodes[idx] is None:
|
|
434
|
+
state.resolved_nodes[idx] = node
|
|
435
|
+
state.uuid_map[node.uuid] = node.uuid
|
|
436
|
+
|
|
437
|
+
logger.debug(
|
|
438
|
+
'Resolved nodes: %s',
|
|
439
|
+
[(node.name, node.uuid) for node in state.resolved_nodes if node is not None],
|
|
440
|
+
)
|
|
288
441
|
|
|
289
442
|
new_node_duplicates: list[
|
|
290
443
|
tuple[EntityNode, EntityNode]
|
|
291
|
-
] = await filter_existing_duplicate_of_edges(driver,
|
|
444
|
+
] = await filter_existing_duplicate_of_edges(driver, state.duplicate_pairs)
|
|
292
445
|
|
|
293
|
-
return
|
|
446
|
+
return (
|
|
447
|
+
[node for node in state.resolved_nodes if node is not None],
|
|
448
|
+
state.uuid_map,
|
|
449
|
+
new_node_duplicates,
|
|
450
|
+
)
|
|
294
451
|
|
|
295
452
|
|
|
296
453
|
async def extract_attributes_from_nodes(
|
|
@@ -298,7 +455,8 @@ async def extract_attributes_from_nodes(
|
|
|
298
455
|
nodes: list[EntityNode],
|
|
299
456
|
episode: EpisodicNode | None = None,
|
|
300
457
|
previous_episodes: list[EpisodicNode] | None = None,
|
|
301
|
-
entity_types: dict[str, BaseModel] | None = None,
|
|
458
|
+
entity_types: dict[str, type[BaseModel]] | None = None,
|
|
459
|
+
should_summarize_node: NodeSummaryFilter | None = None,
|
|
302
460
|
) -> list[EntityNode]:
|
|
303
461
|
llm_client = clients.llm_client
|
|
304
462
|
embedder = clients.embedder
|
|
@@ -309,9 +467,12 @@ async def extract_attributes_from_nodes(
|
|
|
309
467
|
node,
|
|
310
468
|
episode,
|
|
311
469
|
previous_episodes,
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
470
|
+
(
|
|
471
|
+
entity_types.get(next((item for item in node.labels if item != 'Entity'), ''))
|
|
472
|
+
if entity_types is not None
|
|
473
|
+
else None
|
|
474
|
+
),
|
|
475
|
+
should_summarize_node,
|
|
315
476
|
)
|
|
316
477
|
for node in nodes
|
|
317
478
|
]
|
|
@@ -327,99 +488,100 @@ async def extract_attributes_from_node(
|
|
|
327
488
|
node: EntityNode,
|
|
328
489
|
episode: EpisodicNode | None = None,
|
|
329
490
|
previous_episodes: list[EpisodicNode] | None = None,
|
|
330
|
-
entity_type: BaseModel | None = None,
|
|
491
|
+
entity_type: type[BaseModel] | None = None,
|
|
492
|
+
should_summarize_node: NodeSummaryFilter | None = None,
|
|
331
493
|
) -> EntityNode:
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
'attributes': node.attributes,
|
|
337
|
-
}
|
|
494
|
+
# Extract attributes if entity type is defined and has attributes
|
|
495
|
+
llm_response = await _extract_entity_attributes(
|
|
496
|
+
llm_client, node, episode, previous_episodes, entity_type
|
|
497
|
+
)
|
|
338
498
|
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
description='Summary containing the important information about the entity. Under 250 words',
|
|
344
|
-
),
|
|
345
|
-
)
|
|
346
|
-
}
|
|
499
|
+
# Extract summary if needed
|
|
500
|
+
await _extract_entity_summary(
|
|
501
|
+
llm_client, node, episode, previous_episodes, should_summarize_node
|
|
502
|
+
)
|
|
347
503
|
|
|
348
|
-
|
|
349
|
-
for field_name, field_info in entity_type.model_fields.items():
|
|
350
|
-
attributes_definitions[field_name] = (
|
|
351
|
-
field_info.annotation,
|
|
352
|
-
Field(description=field_info.description),
|
|
353
|
-
)
|
|
504
|
+
node.attributes.update(llm_response)
|
|
354
505
|
|
|
355
|
-
|
|
356
|
-
entity_attributes_model = pydantic.create_model(unique_model_name, **attributes_definitions)
|
|
506
|
+
return node
|
|
357
507
|
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
508
|
+
|
|
509
|
+
async def _extract_entity_attributes(
|
|
510
|
+
llm_client: LLMClient,
|
|
511
|
+
node: EntityNode,
|
|
512
|
+
episode: EpisodicNode | None,
|
|
513
|
+
previous_episodes: list[EpisodicNode] | None,
|
|
514
|
+
entity_type: type[BaseModel] | None,
|
|
515
|
+
) -> dict[str, Any]:
|
|
516
|
+
if entity_type is None or len(entity_type.model_fields) == 0:
|
|
517
|
+
return {}
|
|
518
|
+
|
|
519
|
+
attributes_context = _build_episode_context(
|
|
520
|
+
# should not include summary
|
|
521
|
+
node_data={
|
|
522
|
+
'name': node.name,
|
|
523
|
+
'entity_types': node.labels,
|
|
524
|
+
'attributes': node.attributes,
|
|
525
|
+
},
|
|
526
|
+
episode=episode,
|
|
527
|
+
previous_episodes=previous_episodes,
|
|
528
|
+
)
|
|
365
529
|
|
|
366
530
|
llm_response = await llm_client.generate_response(
|
|
367
|
-
prompt_library.extract_nodes.extract_attributes(
|
|
368
|
-
response_model=
|
|
531
|
+
prompt_library.extract_nodes.extract_attributes(attributes_context),
|
|
532
|
+
response_model=entity_type,
|
|
369
533
|
model_size=ModelSize.small,
|
|
534
|
+
group_id=node.group_id,
|
|
535
|
+
prompt_name='extract_nodes.extract_attributes',
|
|
370
536
|
)
|
|
371
537
|
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
with suppress(KeyError):
|
|
376
|
-
del node_attributes['summary']
|
|
538
|
+
# validate response
|
|
539
|
+
entity_type(**llm_response)
|
|
377
540
|
|
|
378
|
-
|
|
541
|
+
return llm_response
|
|
379
542
|
|
|
380
|
-
return node
|
|
381
543
|
|
|
382
|
-
|
|
383
|
-
async def dedupe_node_list(
|
|
544
|
+
async def _extract_entity_summary(
|
|
384
545
|
llm_client: LLMClient,
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
prompt_library.dedupe_nodes.node_list(context)
|
|
546
|
+
node: EntityNode,
|
|
547
|
+
episode: EpisodicNode | None,
|
|
548
|
+
previous_episodes: list[EpisodicNode] | None,
|
|
549
|
+
should_summarize_node: NodeSummaryFilter | None,
|
|
550
|
+
) -> None:
|
|
551
|
+
if should_summarize_node is not None and not await should_summarize_node(node):
|
|
552
|
+
return
|
|
553
|
+
|
|
554
|
+
summary_context = _build_episode_context(
|
|
555
|
+
node_data={
|
|
556
|
+
'name': node.name,
|
|
557
|
+
'summary': truncate_at_sentence(node.summary, MAX_SUMMARY_CHARS),
|
|
558
|
+
'entity_types': node.labels,
|
|
559
|
+
'attributes': node.attributes,
|
|
560
|
+
},
|
|
561
|
+
episode=episode,
|
|
562
|
+
previous_episodes=previous_episodes,
|
|
403
563
|
)
|
|
404
564
|
|
|
405
|
-
|
|
565
|
+
summary_response = await llm_client.generate_response(
|
|
566
|
+
prompt_library.extract_nodes.extract_summary(summary_context),
|
|
567
|
+
response_model=EntitySummary,
|
|
568
|
+
model_size=ModelSize.small,
|
|
569
|
+
group_id=node.group_id,
|
|
570
|
+
prompt_name='extract_nodes.extract_summary',
|
|
571
|
+
)
|
|
406
572
|
|
|
407
|
-
|
|
408
|
-
logger.debug(f'Deduplicated nodes: {nodes_data} in {(end - start) * 1000} ms')
|
|
409
|
-
|
|
410
|
-
# Get full node data
|
|
411
|
-
unique_nodes = []
|
|
412
|
-
uuid_map: dict[str, str] = {}
|
|
413
|
-
for node_data in nodes_data:
|
|
414
|
-
node_instance: EntityNode | None = node_map.get(node_data['uuids'][0])
|
|
415
|
-
if node_instance is None:
|
|
416
|
-
logger.warning(f'Node {node_data["uuids"][0]} not found in node map')
|
|
417
|
-
continue
|
|
418
|
-
node_instance.summary = node_data['summary']
|
|
419
|
-
unique_nodes.append(node_instance)
|
|
573
|
+
node.summary = truncate_at_sentence(summary_response.get('summary', ''), MAX_SUMMARY_CHARS)
|
|
420
574
|
|
|
421
|
-
for uuid in node_data['uuids'][1:]:
|
|
422
|
-
uuid_value = node_map[node_data['uuids'][0]].uuid
|
|
423
|
-
uuid_map[uuid] = uuid_value
|
|
424
575
|
|
|
425
|
-
|
|
576
|
+
def _build_episode_context(
|
|
577
|
+
node_data: dict[str, Any],
|
|
578
|
+
episode: EpisodicNode | None,
|
|
579
|
+
previous_episodes: list[EpisodicNode] | None,
|
|
580
|
+
) -> dict[str, Any]:
|
|
581
|
+
return {
|
|
582
|
+
'node': node_data,
|
|
583
|
+
'episode_content': episode.content if episode is not None else '',
|
|
584
|
+
'previous_episodes': (
|
|
585
|
+
[ep.content for ep in previous_episodes] if previous_episodes is not None else []
|
|
586
|
+
),
|
|
587
|
+
}
|
|
@@ -43,7 +43,9 @@ async def extract_edge_dates(
|
|
|
43
43
|
'reference_timestamp': current_episode.valid_at.isoformat(),
|
|
44
44
|
}
|
|
45
45
|
llm_response = await llm_client.generate_response(
|
|
46
|
-
prompt_library.extract_edge_dates.v1(context),
|
|
46
|
+
prompt_library.extract_edge_dates.v1(context),
|
|
47
|
+
response_model=EdgeDates,
|
|
48
|
+
prompt_name='extract_edge_dates.v1',
|
|
47
49
|
)
|
|
48
50
|
|
|
49
51
|
valid_at = llm_response.get('valid_at')
|
|
@@ -70,7 +72,9 @@ async def extract_edge_dates(
|
|
|
70
72
|
|
|
71
73
|
|
|
72
74
|
async def get_edge_contradictions(
|
|
73
|
-
llm_client: LLMClient,
|
|
75
|
+
llm_client: LLMClient,
|
|
76
|
+
new_edge: EntityEdge,
|
|
77
|
+
existing_edges: list[EntityEdge],
|
|
74
78
|
) -> list[EntityEdge]:
|
|
75
79
|
start = time()
|
|
76
80
|
|
|
@@ -79,12 +83,16 @@ async def get_edge_contradictions(
|
|
|
79
83
|
{'id': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
|
|
80
84
|
]
|
|
81
85
|
|
|
82
|
-
context = {
|
|
86
|
+
context = {
|
|
87
|
+
'new_edge': new_edge_context,
|
|
88
|
+
'existing_edges': existing_edge_context,
|
|
89
|
+
}
|
|
83
90
|
|
|
84
91
|
llm_response = await llm_client.generate_response(
|
|
85
92
|
prompt_library.invalidate_edges.v2(context),
|
|
86
93
|
response_model=InvalidatedEdges,
|
|
87
94
|
model_size=ModelSize.small,
|
|
95
|
+
prompt_name='invalidate_edges.v2',
|
|
88
96
|
)
|
|
89
97
|
|
|
90
98
|
contradicted_facts: list[int] = llm_response.get('contradicted_facts', [])
|