graphiti-core 0.11.6rc9__tar.gz → 0.12.0rc1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/PKG-INFO +1 -1
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/edges.py +42 -16
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/embedder/gemini.py +14 -3
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/graphiti.py +33 -10
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/helpers.py +0 -1
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/gemini_client.py +4 -1
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/models/edges/edge_db_queries.py +2 -4
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/dedupe_edges.py +52 -1
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/dedupe_nodes.py +75 -4
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/extract_edges.py +46 -2
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/invalidate_edges.py +1 -1
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/search/search_utils.py +14 -9
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/utils/bulk_utils.py +19 -1
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/utils/maintenance/edge_operations.py +137 -10
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/utils/maintenance/node_operations.py +58 -20
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/pyproject.toml +1 -1
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/LICENSE +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/README.md +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/__init__.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/cross_encoder/__init__.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/cross_encoder/bge_reranker_client.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/cross_encoder/client.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/cross_encoder/openai_reranker_client.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/embedder/__init__.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/embedder/client.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/embedder/openai.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/embedder/voyage.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/errors.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/graphiti_types.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/__init__.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/anthropic_client.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/client.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/config.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/errors.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/groq_client.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/openai_client.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/openai_generic_client.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/utils.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/models/__init__.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/models/edges/__init__.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/models/nodes/__init__.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/models/nodes/node_db_queries.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/nodes.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/__init__.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/eval.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/extract_edge_dates.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/extract_nodes.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/lib.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/models.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/prompt_helpers.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/summarize_nodes.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/py.typed +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/search/__init__.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/search/search.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/search/search_config.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/search/search_config_recipes.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/search/search_filters.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/search/search_helpers.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/utils/__init__.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/utils/datetime_utils.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/utils/maintenance/__init__.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/utils/maintenance/community_operations.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/utils/maintenance/graph_data_operations.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/utils/maintenance/temporal_operations.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/utils/maintenance/utils.py +0 -0
- {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/utils/ontology_utils/entity_types_utils.py +0 -0
|
@@ -49,7 +49,9 @@ ENTITY_EDGE_RETURN: LiteralString = """
|
|
|
49
49
|
e.episodes AS episodes,
|
|
50
50
|
e.expired_at AS expired_at,
|
|
51
51
|
e.valid_at AS valid_at,
|
|
52
|
-
e.invalid_at AS invalid_at
|
|
52
|
+
e.invalid_at AS invalid_at,
|
|
53
|
+
properties(e) AS attributes
|
|
54
|
+
"""
|
|
53
55
|
|
|
54
56
|
|
|
55
57
|
class Edge(BaseModel, ABC):
|
|
@@ -209,6 +211,9 @@ class EntityEdge(Edge):
|
|
|
209
211
|
invalid_at: datetime | None = Field(
|
|
210
212
|
default=None, description='datetime of when the fact stopped being true'
|
|
211
213
|
)
|
|
214
|
+
attributes: dict[str, Any] = Field(
|
|
215
|
+
default={}, description='Additional attributes of the edge. Dependent on edge name'
|
|
216
|
+
)
|
|
212
217
|
|
|
213
218
|
async def generate_embedding(self, embedder: EmbedderClient):
|
|
214
219
|
start = time()
|
|
@@ -236,20 +241,26 @@ class EntityEdge(Edge):
|
|
|
236
241
|
self.fact_embedding = records[0]['fact_embedding']
|
|
237
242
|
|
|
238
243
|
async def save(self, driver: AsyncDriver):
|
|
244
|
+
edge_data: dict[str, Any] = {
|
|
245
|
+
'source_uuid': self.source_node_uuid,
|
|
246
|
+
'target_uuid': self.target_node_uuid,
|
|
247
|
+
'uuid': self.uuid,
|
|
248
|
+
'name': self.name,
|
|
249
|
+
'group_id': self.group_id,
|
|
250
|
+
'fact': self.fact,
|
|
251
|
+
'fact_embedding': self.fact_embedding,
|
|
252
|
+
'episodes': self.episodes,
|
|
253
|
+
'created_at': self.created_at,
|
|
254
|
+
'expired_at': self.expired_at,
|
|
255
|
+
'valid_at': self.valid_at,
|
|
256
|
+
'invalid_at': self.invalid_at,
|
|
257
|
+
}
|
|
258
|
+
|
|
259
|
+
edge_data.update(self.attributes or {})
|
|
260
|
+
|
|
239
261
|
result = await driver.execute_query(
|
|
240
262
|
ENTITY_EDGE_SAVE,
|
|
241
|
-
|
|
242
|
-
target_uuid=self.target_node_uuid,
|
|
243
|
-
uuid=self.uuid,
|
|
244
|
-
name=self.name,
|
|
245
|
-
group_id=self.group_id,
|
|
246
|
-
fact=self.fact,
|
|
247
|
-
fact_embedding=self.fact_embedding,
|
|
248
|
-
episodes=self.episodes,
|
|
249
|
-
created_at=self.created_at,
|
|
250
|
-
expired_at=self.expired_at,
|
|
251
|
-
valid_at=self.valid_at,
|
|
252
|
-
invalid_at=self.invalid_at,
|
|
263
|
+
edge_data=edge_data,
|
|
253
264
|
database_=DEFAULT_DATABASE,
|
|
254
265
|
)
|
|
255
266
|
|
|
@@ -334,8 +345,8 @@ class EntityEdge(Edge):
|
|
|
334
345
|
async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
|
|
335
346
|
query: LiteralString = (
|
|
336
347
|
"""
|
|
337
|
-
|
|
338
|
-
|
|
348
|
+
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
|
349
|
+
"""
|
|
339
350
|
+ ENTITY_EDGE_RETURN
|
|
340
351
|
)
|
|
341
352
|
records, _, _ = await driver.execute_query(
|
|
@@ -457,7 +468,7 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
|
|
|
457
468
|
|
|
458
469
|
|
|
459
470
|
def get_entity_edge_from_record(record: Any) -> EntityEdge:
|
|
460
|
-
|
|
471
|
+
edge = EntityEdge(
|
|
461
472
|
uuid=record['uuid'],
|
|
462
473
|
source_node_uuid=record['source_node_uuid'],
|
|
463
474
|
target_node_uuid=record['target_node_uuid'],
|
|
@@ -469,8 +480,23 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
|
|
|
469
480
|
expired_at=parse_db_date(record['expired_at']),
|
|
470
481
|
valid_at=parse_db_date(record['valid_at']),
|
|
471
482
|
invalid_at=parse_db_date(record['invalid_at']),
|
|
483
|
+
attributes=record['attributes'],
|
|
472
484
|
)
|
|
473
485
|
|
|
486
|
+
edge.attributes.pop('uuid', None)
|
|
487
|
+
edge.attributes.pop('source_node_uuid', None)
|
|
488
|
+
edge.attributes.pop('target_node_uuid', None)
|
|
489
|
+
edge.attributes.pop('fact', None)
|
|
490
|
+
edge.attributes.pop('name', None)
|
|
491
|
+
edge.attributes.pop('group_id', None)
|
|
492
|
+
edge.attributes.pop('episodes', None)
|
|
493
|
+
edge.attributes.pop('created_at', None)
|
|
494
|
+
edge.attributes.pop('expired_at', None)
|
|
495
|
+
edge.attributes.pop('valid_at', None)
|
|
496
|
+
edge.attributes.pop('invalid_at', None)
|
|
497
|
+
|
|
498
|
+
return edge
|
|
499
|
+
|
|
474
500
|
|
|
475
501
|
def get_community_edge_from_record(record: Any):
|
|
476
502
|
return CommunityEdge(
|
|
@@ -61,18 +61,29 @@ class GeminiEmbedder(EmbedderClient):
|
|
|
61
61
|
# Generate embeddings
|
|
62
62
|
result = await self.client.aio.models.embed_content(
|
|
63
63
|
model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
|
|
64
|
-
contents=[input_data],
|
|
64
|
+
contents=[input_data], # type: ignore[arg-type] # mypy fails on broad union type
|
|
65
65
|
config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim),
|
|
66
66
|
)
|
|
67
67
|
|
|
68
|
+
if not result.embeddings or len(result.embeddings) == 0 or not result.embeddings[0].values:
|
|
69
|
+
raise ValueError('No embeddings returned from Gemini API in create()')
|
|
70
|
+
|
|
68
71
|
return result.embeddings[0].values
|
|
69
72
|
|
|
70
73
|
async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
|
|
71
74
|
# Generate embeddings
|
|
72
75
|
result = await self.client.aio.models.embed_content(
|
|
73
76
|
model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
|
|
74
|
-
contents=input_data_list,
|
|
77
|
+
contents=input_data_list, # type: ignore[arg-type] # mypy fails on broad union type
|
|
75
78
|
config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim),
|
|
76
79
|
)
|
|
77
80
|
|
|
78
|
-
|
|
81
|
+
if not result.embeddings or len(result.embeddings) == 0:
|
|
82
|
+
raise Exception('No embeddings returned')
|
|
83
|
+
|
|
84
|
+
embeddings = []
|
|
85
|
+
for embedding in result.embeddings:
|
|
86
|
+
if not embedding.values:
|
|
87
|
+
raise ValueError('Empty embedding values returned')
|
|
88
|
+
embeddings.append(embedding.values)
|
|
89
|
+
return embeddings
|
|
@@ -41,6 +41,7 @@ from graphiti_core.search.search_config_recipes import (
|
|
|
41
41
|
from graphiti_core.search.search_filters import SearchFilters
|
|
42
42
|
from graphiti_core.search.search_utils import (
|
|
43
43
|
RELEVANT_SCHEMA_LIMIT,
|
|
44
|
+
get_edge_invalidation_candidates,
|
|
44
45
|
get_mentioned_nodes,
|
|
45
46
|
get_relevant_edges,
|
|
46
47
|
)
|
|
@@ -62,9 +63,8 @@ from graphiti_core.utils.maintenance.community_operations import (
|
|
|
62
63
|
)
|
|
63
64
|
from graphiti_core.utils.maintenance.edge_operations import (
|
|
64
65
|
build_episodic_edges,
|
|
65
|
-
dedupe_extracted_edge,
|
|
66
66
|
extract_edges,
|
|
67
|
-
|
|
67
|
+
resolve_extracted_edge,
|
|
68
68
|
resolve_extracted_edges,
|
|
69
69
|
)
|
|
70
70
|
from graphiti_core.utils.maintenance.graph_data_operations import (
|
|
@@ -77,7 +77,6 @@ from graphiti_core.utils.maintenance.node_operations import (
|
|
|
77
77
|
extract_nodes,
|
|
78
78
|
resolve_extracted_nodes,
|
|
79
79
|
)
|
|
80
|
-
from graphiti_core.utils.maintenance.temporal_operations import get_edge_contradictions
|
|
81
80
|
from graphiti_core.utils.ontology_utils.entity_types_utils import validate_entity_types
|
|
82
81
|
|
|
83
82
|
logger = logging.getLogger(__name__)
|
|
@@ -274,6 +273,8 @@ class Graphiti:
|
|
|
274
273
|
update_communities: bool = False,
|
|
275
274
|
entity_types: dict[str, BaseModel] | None = None,
|
|
276
275
|
previous_episode_uuids: list[str] | None = None,
|
|
276
|
+
edge_types: dict[str, BaseModel] | None = None,
|
|
277
|
+
edge_type_map: dict[tuple[str, str], list[str]] | None = None,
|
|
277
278
|
) -> AddEpisodeResults:
|
|
278
279
|
"""
|
|
279
280
|
Process an episode and update the graph.
|
|
@@ -356,6 +357,13 @@ class Graphiti:
|
|
|
356
357
|
)
|
|
357
358
|
)
|
|
358
359
|
|
|
360
|
+
# Create default edge type map
|
|
361
|
+
edge_type_map_default = (
|
|
362
|
+
{('Entity', 'Entity'): list(edge_types.keys())}
|
|
363
|
+
if edge_types is not None
|
|
364
|
+
else {('Entity', 'Entity'): []}
|
|
365
|
+
)
|
|
366
|
+
|
|
359
367
|
# Extract entities as nodes
|
|
360
368
|
|
|
361
369
|
extracted_nodes = await extract_nodes(
|
|
@@ -371,7 +379,9 @@ class Graphiti:
|
|
|
371
379
|
previous_episodes,
|
|
372
380
|
entity_types,
|
|
373
381
|
),
|
|
374
|
-
extract_edges(
|
|
382
|
+
extract_edges(
|
|
383
|
+
self.clients, episode, extracted_nodes, previous_episodes, group_id, edge_types
|
|
384
|
+
),
|
|
375
385
|
)
|
|
376
386
|
|
|
377
387
|
edges = resolve_edge_pointers(extracted_edges, uuid_map)
|
|
@@ -381,6 +391,9 @@ class Graphiti:
|
|
|
381
391
|
self.clients,
|
|
382
392
|
edges,
|
|
383
393
|
episode,
|
|
394
|
+
nodes,
|
|
395
|
+
edge_types or {},
|
|
396
|
+
edge_type_map or edge_type_map_default,
|
|
384
397
|
),
|
|
385
398
|
extract_attributes_from_nodes(
|
|
386
399
|
self.clients, nodes, episode, previous_episodes, entity_types
|
|
@@ -681,17 +694,27 @@ class Graphiti:
|
|
|
681
694
|
|
|
682
695
|
updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
|
|
683
696
|
|
|
684
|
-
related_edges = await get_relevant_edges(self.driver, [updated_edge], SearchFilters()
|
|
697
|
+
related_edges = (await get_relevant_edges(self.driver, [updated_edge], SearchFilters()))[0]
|
|
698
|
+
existing_edges = (
|
|
699
|
+
await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters())
|
|
700
|
+
)[0]
|
|
685
701
|
|
|
686
|
-
resolved_edge = await
|
|
702
|
+
resolved_edge, invalidated_edges = await resolve_extracted_edge(
|
|
687
703
|
self.llm_client,
|
|
688
704
|
updated_edge,
|
|
689
|
-
related_edges
|
|
705
|
+
related_edges,
|
|
706
|
+
existing_edges,
|
|
707
|
+
EpisodicNode(
|
|
708
|
+
name='',
|
|
709
|
+
source=EpisodeType.text,
|
|
710
|
+
source_description='',
|
|
711
|
+
content='',
|
|
712
|
+
valid_at=edge.valid_at or utc_now(),
|
|
713
|
+
entity_edges=[],
|
|
714
|
+
group_id=edge.group_id,
|
|
715
|
+
),
|
|
690
716
|
)
|
|
691
717
|
|
|
692
|
-
contradicting_edges = await get_edge_contradictions(self.llm_client, edge, related_edges[0])
|
|
693
|
-
invalidated_edges = resolve_edge_contradictions(resolved_edge, contradicting_edges)
|
|
694
|
-
|
|
695
718
|
await add_nodes_and_edges_bulk(
|
|
696
719
|
self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges, self.embedder
|
|
697
720
|
)
|
{graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/gemini_client.py
RENAMED
|
@@ -139,13 +139,16 @@ class GeminiClient(LLMClient):
|
|
|
139
139
|
# Generate content using the simple string approach
|
|
140
140
|
response = await self.client.aio.models.generate_content(
|
|
141
141
|
model=self.model or DEFAULT_MODEL,
|
|
142
|
-
contents=gemini_messages,
|
|
142
|
+
contents=gemini_messages, # type: ignore[arg-type] # mypy fails on broad union type
|
|
143
143
|
config=generation_config,
|
|
144
144
|
)
|
|
145
145
|
|
|
146
146
|
# If this was a structured output request, parse the response into the Pydantic model
|
|
147
147
|
if response_model is not None:
|
|
148
148
|
try:
|
|
149
|
+
if not response.text:
|
|
150
|
+
raise ValueError('No response text')
|
|
151
|
+
|
|
149
152
|
validated_model = response_model.model_validate(json.loads(response.text))
|
|
150
153
|
|
|
151
154
|
# Return as a dictionary for API consistency
|
{graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/models/edges/edge_db_queries.py
RENAMED
|
@@ -34,8 +34,7 @@ ENTITY_EDGE_SAVE = """
|
|
|
34
34
|
MATCH (source:Entity {uuid: $source_uuid})
|
|
35
35
|
MATCH (target:Entity {uuid: $target_uuid})
|
|
36
36
|
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
|
|
37
|
-
SET r =
|
|
38
|
-
created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at}
|
|
37
|
+
SET r = $edge_data
|
|
39
38
|
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
|
|
40
39
|
RETURN r.uuid AS uuid"""
|
|
41
40
|
|
|
@@ -44,8 +43,7 @@ ENTITY_EDGE_SAVE_BULK = """
|
|
|
44
43
|
MATCH (source:Entity {uuid: edge.source_node_uuid})
|
|
45
44
|
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
|
46
45
|
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
|
|
47
|
-
SET r =
|
|
48
|
-
created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at}
|
|
46
|
+
SET r = edge
|
|
49
47
|
WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding)
|
|
50
48
|
RETURN edge.uuid AS uuid
|
|
51
49
|
"""
|
|
@@ -27,6 +27,11 @@ class EdgeDuplicate(BaseModel):
|
|
|
27
27
|
...,
|
|
28
28
|
description='id of the duplicate fact. If no duplicate facts are found, default to -1.',
|
|
29
29
|
)
|
|
30
|
+
contradicted_facts: list[int] = Field(
|
|
31
|
+
...,
|
|
32
|
+
description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.',
|
|
33
|
+
)
|
|
34
|
+
fact_type: str = Field(..., description='One of the provided fact types or DEFAULT')
|
|
30
35
|
|
|
31
36
|
|
|
32
37
|
class UniqueFact(BaseModel):
|
|
@@ -41,11 +46,13 @@ class UniqueFacts(BaseModel):
|
|
|
41
46
|
class Prompt(Protocol):
|
|
42
47
|
edge: PromptVersion
|
|
43
48
|
edge_list: PromptVersion
|
|
49
|
+
resolve_edge: PromptVersion
|
|
44
50
|
|
|
45
51
|
|
|
46
52
|
class Versions(TypedDict):
|
|
47
53
|
edge: PromptFunction
|
|
48
54
|
edge_list: PromptFunction
|
|
55
|
+
resolve_edge: PromptFunction
|
|
49
56
|
|
|
50
57
|
|
|
51
58
|
def edge(context: dict[str, Any]) -> list[Message]:
|
|
@@ -106,4 +113,48 @@ def edge_list(context: dict[str, Any]) -> list[Message]:
|
|
|
106
113
|
]
|
|
107
114
|
|
|
108
115
|
|
|
109
|
-
|
|
116
|
+
def resolve_edge(context: dict[str, Any]) -> list[Message]:
|
|
117
|
+
return [
|
|
118
|
+
Message(
|
|
119
|
+
role='system',
|
|
120
|
+
content='You are a helpful assistant that de-duplicates facts from fact lists and determines which existing '
|
|
121
|
+
'facts are contradicted by the new fact.',
|
|
122
|
+
),
|
|
123
|
+
Message(
|
|
124
|
+
role='user',
|
|
125
|
+
content=f"""
|
|
126
|
+
<NEW FACT>
|
|
127
|
+
{context['new_edge']}
|
|
128
|
+
</NEW FACT>
|
|
129
|
+
|
|
130
|
+
<EXISTING FACTS>
|
|
131
|
+
{context['existing_edges']}
|
|
132
|
+
</EXISTING FACTS>
|
|
133
|
+
<FACT INVALIDATION CANDIDATES>
|
|
134
|
+
{context['edge_invalidation_candidates']}
|
|
135
|
+
</FACT INVALIDATION CANDIDATES>
|
|
136
|
+
|
|
137
|
+
<FACT TYPES>
|
|
138
|
+
{context['edge_types']}
|
|
139
|
+
</FACT TYPES>
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
Task:
|
|
143
|
+
If the NEW FACT represents the same factual information as any fact in EXISTING FACTS, return the idx of the duplicate fact.
|
|
144
|
+
If the NEW FACT is not a duplicate of any of the EXISTING FACTS, return -1.
|
|
145
|
+
|
|
146
|
+
Given the predefined FACT TYPES, determine if the NEW FACT should be classified as one of these types.
|
|
147
|
+
Return the fact type as fact_type or DEFAULT if NEW FACT is not one of the FACT TYPES.
|
|
148
|
+
|
|
149
|
+
Based on the provided FACT INVALIDATION CANDIDATES and NEW FACT, determine which existing facts the new fact contradicts.
|
|
150
|
+
Return a list containing all idx's of the facts that are contradicted by the NEW FACT.
|
|
151
|
+
If there are no contradicted facts, return an empty list.
|
|
152
|
+
|
|
153
|
+
Guidelines:
|
|
154
|
+
1. The facts do not need to be completely identical to be duplicates, they just need to express the same information.
|
|
155
|
+
""",
|
|
156
|
+
),
|
|
157
|
+
]
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
versions: Versions = {'edge': edge, 'edge_list': edge_list, 'resolve_edge': resolve_edge}
|
|
@@ -23,21 +23,31 @@ from .models import Message, PromptFunction, PromptVersion
|
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
class NodeDuplicate(BaseModel):
|
|
26
|
-
|
|
26
|
+
id: int = Field(..., description='integer id of the entity')
|
|
27
|
+
duplicate_idx: int = Field(
|
|
27
28
|
...,
|
|
28
|
-
description='
|
|
29
|
+
description='idx of the duplicate node. If no duplicate nodes are found, default to -1.',
|
|
29
30
|
)
|
|
30
|
-
name: str = Field(
|
|
31
|
+
name: str = Field(
|
|
32
|
+
...,
|
|
33
|
+
description='Name of the entity. Should be the most complete and descriptive name possible.',
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class NodeResolutions(BaseModel):
|
|
38
|
+
entity_resolutions: list[NodeDuplicate] = Field(..., description='List of resolved nodes')
|
|
31
39
|
|
|
32
40
|
|
|
33
41
|
class Prompt(Protocol):
|
|
34
42
|
node: PromptVersion
|
|
35
43
|
node_list: PromptVersion
|
|
44
|
+
nodes: PromptVersion
|
|
36
45
|
|
|
37
46
|
|
|
38
47
|
class Versions(TypedDict):
|
|
39
48
|
node: PromptFunction
|
|
40
49
|
node_list: PromptFunction
|
|
50
|
+
nodes: PromptFunction
|
|
41
51
|
|
|
42
52
|
|
|
43
53
|
def node(context: dict[str, Any]) -> list[Message]:
|
|
@@ -89,6 +99,67 @@ def node(context: dict[str, Any]) -> list[Message]:
|
|
|
89
99
|
]
|
|
90
100
|
|
|
91
101
|
|
|
102
|
+
def nodes(context: dict[str, Any]) -> list[Message]:
|
|
103
|
+
return [
|
|
104
|
+
Message(
|
|
105
|
+
role='system',
|
|
106
|
+
content='You are a helpful assistant that determines whether or not ENTITIES extracted from a conversation are duplicates'
|
|
107
|
+
'of existing entities.',
|
|
108
|
+
),
|
|
109
|
+
Message(
|
|
110
|
+
role='user',
|
|
111
|
+
content=f"""
|
|
112
|
+
<PREVIOUS MESSAGES>
|
|
113
|
+
{json.dumps([ep for ep in context['previous_episodes']], indent=2)}
|
|
114
|
+
</PREVIOUS MESSAGES>
|
|
115
|
+
<CURRENT MESSAGE>
|
|
116
|
+
{context['episode_content']}
|
|
117
|
+
</CURRENT MESSAGE>
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
Each of the following ENTITIES were extracted from the CURRENT MESSAGE.
|
|
121
|
+
Each entity in ENTITIES is represented as a JSON object with the following structure:
|
|
122
|
+
{{
|
|
123
|
+
id: integer id of the entity,
|
|
124
|
+
name: "name of the entity",
|
|
125
|
+
entity_type: "ontological classification of the entity",
|
|
126
|
+
entity_type_description: "Description of what the entity type represents",
|
|
127
|
+
duplication_candidates: [
|
|
128
|
+
{{
|
|
129
|
+
idx: integer index of the candidate entity,
|
|
130
|
+
name: "name of the candidate entity",
|
|
131
|
+
entity_type: "ontological classification of the candidate entity",
|
|
132
|
+
...<additional attributes>
|
|
133
|
+
}}
|
|
134
|
+
]
|
|
135
|
+
}}
|
|
136
|
+
|
|
137
|
+
<ENTITIES>
|
|
138
|
+
{json.dumps(context['extracted_nodes'], indent=2)}
|
|
139
|
+
</ENTITIES>
|
|
140
|
+
|
|
141
|
+
For each of the above ENTITIES, determine if the entity is a duplicate of any of its duplication candidates.
|
|
142
|
+
|
|
143
|
+
Entities should only be considered duplicates if they refer to the *same real-world object or concept*.
|
|
144
|
+
|
|
145
|
+
Do NOT mark entities as duplicates if:
|
|
146
|
+
- They are related but distinct.
|
|
147
|
+
- They have similar names or purposes but refer to separate instances or concepts.
|
|
148
|
+
|
|
149
|
+
Task:
|
|
150
|
+
Your response will be a list called entity_resolutions which contains one entry for each entity.
|
|
151
|
+
|
|
152
|
+
For each entity, return the id of the entity as id, the name of the entity as name, and the duplicate_idx
|
|
153
|
+
as an integer.
|
|
154
|
+
|
|
155
|
+
- If an entity is a duplicate of one of its duplication_candidates, return the idx of the candidate it is a
|
|
156
|
+
duplicate of.
|
|
157
|
+
- If an entity is not a duplicate of one of its duplication candidates, return the -1 as the duplication_idx
|
|
158
|
+
""",
|
|
159
|
+
),
|
|
160
|
+
]
|
|
161
|
+
|
|
162
|
+
|
|
92
163
|
def node_list(context: dict[str, Any]) -> list[Message]:
|
|
93
164
|
return [
|
|
94
165
|
Message(
|
|
@@ -126,4 +197,4 @@ def node_list(context: dict[str, Any]) -> list[Message]:
|
|
|
126
197
|
]
|
|
127
198
|
|
|
128
199
|
|
|
129
|
-
versions: Versions = {'node': node, 'node_list': node_list}
|
|
200
|
+
versions: Versions = {'node': node, 'node_list': node_list, 'nodes': nodes}
|
|
@@ -48,11 +48,13 @@ class MissingFacts(BaseModel):
|
|
|
48
48
|
class Prompt(Protocol):
|
|
49
49
|
edge: PromptVersion
|
|
50
50
|
reflexion: PromptVersion
|
|
51
|
+
extract_attributes: PromptVersion
|
|
51
52
|
|
|
52
53
|
|
|
53
54
|
class Versions(TypedDict):
|
|
54
55
|
edge: PromptFunction
|
|
55
56
|
reflexion: PromptFunction
|
|
57
|
+
extract_attributes: PromptFunction
|
|
56
58
|
|
|
57
59
|
|
|
58
60
|
def edge(context: dict[str, Any]) -> list[Message]:
|
|
@@ -82,12 +84,18 @@ def edge(context: dict[str, Any]) -> list[Message]:
|
|
|
82
84
|
{context['reference_time']} # ISO 8601 (UTC); used to resolve relative time mentions
|
|
83
85
|
</REFERENCE_TIME>
|
|
84
86
|
|
|
87
|
+
<FACT TYPES>
|
|
88
|
+
{context['edge_types']}
|
|
89
|
+
</FACT TYPES>
|
|
90
|
+
|
|
85
91
|
# TASK
|
|
86
92
|
Extract all factual relationships between the given ENTITIES based on the CURRENT MESSAGE.
|
|
87
93
|
Only extract facts that:
|
|
88
94
|
- involve two DISTINCT ENTITIES from the ENTITIES list,
|
|
89
95
|
- are clearly stated or unambiguously implied in the CURRENT MESSAGE,
|
|
90
|
-
|
|
96
|
+
and can be represented as edges in a knowledge graph.
|
|
97
|
+
- The FACT TYPES provide a list of the most important types of facts, make sure to extract any facts that
|
|
98
|
+
could be classified into one of the provided fact types
|
|
91
99
|
|
|
92
100
|
You may use information from the PREVIOUS MESSAGES only to disambiguate references or support continuity.
|
|
93
101
|
|
|
@@ -145,4 +153,40 @@ determine if any facts haven't been extracted.
|
|
|
145
153
|
]
|
|
146
154
|
|
|
147
155
|
|
|
148
|
-
|
|
156
|
+
def extract_attributes(context: dict[str, Any]) -> list[Message]:
|
|
157
|
+
return [
|
|
158
|
+
Message(
|
|
159
|
+
role='system',
|
|
160
|
+
content='You are a helpful assistant that extracts fact properties from the provided text.',
|
|
161
|
+
),
|
|
162
|
+
Message(
|
|
163
|
+
role='user',
|
|
164
|
+
content=f"""
|
|
165
|
+
|
|
166
|
+
<MESSAGE>
|
|
167
|
+
{json.dumps(context['episode_content'], indent=2)}
|
|
168
|
+
</MESSAGE>
|
|
169
|
+
<REFERENCE TIME>
|
|
170
|
+
{context['reference_time']}
|
|
171
|
+
</REFERENCE TIME>
|
|
172
|
+
|
|
173
|
+
Given the above MESSAGE, its REFERENCE TIME, and the following FACT, update any of its attributes based on the information provided
|
|
174
|
+
in MESSAGE. Use the provided attribute descriptions to better understand how each attribute should be determined.
|
|
175
|
+
|
|
176
|
+
Guidelines:
|
|
177
|
+
1. Do not hallucinate entity property values if they cannot be found in the current context.
|
|
178
|
+
2. Only use the provided MESSAGES and FACT to set attribute values.
|
|
179
|
+
|
|
180
|
+
<FACT>
|
|
181
|
+
{context['fact']}
|
|
182
|
+
</FACT>
|
|
183
|
+
""",
|
|
184
|
+
),
|
|
185
|
+
]
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
versions: Versions = {
|
|
189
|
+
'edge': edge,
|
|
190
|
+
'reflexion': reflexion,
|
|
191
|
+
'extract_attributes': extract_attributes,
|
|
192
|
+
}
|
{graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/invalidate_edges.py
RENAMED
|
@@ -24,7 +24,7 @@ from .models import Message, PromptFunction, PromptVersion
|
|
|
24
24
|
class InvalidatedEdges(BaseModel):
|
|
25
25
|
contradicted_facts: list[int] = Field(
|
|
26
26
|
...,
|
|
27
|
-
description='List of ids of facts that be
|
|
27
|
+
description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.',
|
|
28
28
|
)
|
|
29
29
|
|
|
30
30
|
|
|
@@ -174,7 +174,8 @@ async def edge_fulltext_search(
|
|
|
174
174
|
r.episodes AS episodes,
|
|
175
175
|
r.expired_at AS expired_at,
|
|
176
176
|
r.valid_at AS valid_at,
|
|
177
|
-
r.invalid_at AS invalid_at
|
|
177
|
+
r.invalid_at AS invalid_at,
|
|
178
|
+
properties(r) AS attributes
|
|
178
179
|
ORDER BY score DESC LIMIT $limit
|
|
179
180
|
"""
|
|
180
181
|
)
|
|
@@ -243,7 +244,8 @@ async def edge_similarity_search(
|
|
|
243
244
|
r.episodes AS episodes,
|
|
244
245
|
r.expired_at AS expired_at,
|
|
245
246
|
r.valid_at AS valid_at,
|
|
246
|
-
r.invalid_at AS invalid_at
|
|
247
|
+
r.invalid_at AS invalid_at,
|
|
248
|
+
properties(r) AS attributes
|
|
247
249
|
ORDER BY score DESC
|
|
248
250
|
LIMIT $limit
|
|
249
251
|
"""
|
|
@@ -301,7 +303,8 @@ async def edge_bfs_search(
|
|
|
301
303
|
r.episodes AS episodes,
|
|
302
304
|
r.expired_at AS expired_at,
|
|
303
305
|
r.valid_at AS valid_at,
|
|
304
|
-
r.invalid_at AS invalid_at
|
|
306
|
+
r.invalid_at AS invalid_at,
|
|
307
|
+
properties(r) AS attributes
|
|
305
308
|
LIMIT $limit
|
|
306
309
|
"""
|
|
307
310
|
)
|
|
@@ -337,10 +340,10 @@ async def node_fulltext_search(
|
|
|
337
340
|
|
|
338
341
|
query = (
|
|
339
342
|
"""
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
343
|
+
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
|
|
344
|
+
YIELD node AS n, score
|
|
345
|
+
WHERE n:Entity
|
|
346
|
+
"""
|
|
344
347
|
+ filter_query
|
|
345
348
|
+ ENTITY_NODE_RETURN
|
|
346
349
|
+ """
|
|
@@ -771,7 +774,8 @@ async def get_relevant_edges(
|
|
|
771
774
|
episodes: e.episodes,
|
|
772
775
|
expired_at: e.expired_at,
|
|
773
776
|
valid_at: e.valid_at,
|
|
774
|
-
invalid_at: e.invalid_at
|
|
777
|
+
invalid_at: e.invalid_at,
|
|
778
|
+
attributes: properties(e)
|
|
775
779
|
})[..$limit] AS matches
|
|
776
780
|
"""
|
|
777
781
|
)
|
|
@@ -837,7 +841,8 @@ async def get_edge_invalidation_candidates(
|
|
|
837
841
|
episodes: e.episodes,
|
|
838
842
|
expired_at: e.expired_at,
|
|
839
843
|
valid_at: e.valid_at,
|
|
840
|
-
invalid_at: e.invalid_at
|
|
844
|
+
invalid_at: e.invalid_at,
|
|
845
|
+
attributes: properties(e)
|
|
841
846
|
})[..$limit] AS matches
|
|
842
847
|
"""
|
|
843
848
|
)
|
|
@@ -137,16 +137,34 @@ async def add_nodes_and_edges_bulk_tx(
|
|
|
137
137
|
entity_data['labels'] = list(set(node.labels + ['Entity']))
|
|
138
138
|
nodes.append(entity_data)
|
|
139
139
|
|
|
140
|
+
edges: list[dict[str, Any]] = []
|
|
140
141
|
for edge in entity_edges:
|
|
141
142
|
if edge.fact_embedding is None:
|
|
142
143
|
await edge.generate_embedding(embedder)
|
|
144
|
+
edge_data: dict[str, Any] = {
|
|
145
|
+
'uuid': edge.uuid,
|
|
146
|
+
'source_node_uuid': edge.source_node_uuid,
|
|
147
|
+
'target_node_uuid': edge.target_node_uuid,
|
|
148
|
+
'name': edge.name,
|
|
149
|
+
'fact': edge.fact,
|
|
150
|
+
'fact_embedding': edge.fact_embedding,
|
|
151
|
+
'group_id': edge.group_id,
|
|
152
|
+
'episodes': edge.episodes,
|
|
153
|
+
'created_at': edge.created_at,
|
|
154
|
+
'expired_at': edge.expired_at,
|
|
155
|
+
'valid_at': edge.valid_at,
|
|
156
|
+
'invalid_at': edge.invalid_at,
|
|
157
|
+
}
|
|
158
|
+
|
|
159
|
+
edge_data.update(edge.attributes or {})
|
|
160
|
+
edges.append(edge_data)
|
|
143
161
|
|
|
144
162
|
await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
|
|
145
163
|
await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
|
|
146
164
|
await tx.run(
|
|
147
165
|
EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
|
|
148
166
|
)
|
|
149
|
-
await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=
|
|
167
|
+
await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=edges)
|
|
150
168
|
|
|
151
169
|
|
|
152
170
|
async def extract_nodes_and_edges_bulk(
|
|
@@ -18,6 +18,8 @@ import logging
|
|
|
18
18
|
from datetime import datetime
|
|
19
19
|
from time import time
|
|
20
20
|
|
|
21
|
+
from pydantic import BaseModel
|
|
22
|
+
|
|
21
23
|
from graphiti_core.edges import (
|
|
22
24
|
CommunityEdge,
|
|
23
25
|
EntityEdge,
|
|
@@ -35,9 +37,6 @@ from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
|
|
|
35
37
|
from graphiti_core.search.search_filters import SearchFilters
|
|
36
38
|
from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
|
|
37
39
|
from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
|
|
38
|
-
from graphiti_core.utils.maintenance.temporal_operations import (
|
|
39
|
-
get_edge_contradictions,
|
|
40
|
-
)
|
|
41
40
|
|
|
42
41
|
logger = logging.getLogger(__name__)
|
|
43
42
|
|
|
@@ -86,6 +85,7 @@ async def extract_edges(
|
|
|
86
85
|
nodes: list[EntityNode],
|
|
87
86
|
previous_episodes: list[EpisodicNode],
|
|
88
87
|
group_id: str = '',
|
|
88
|
+
edge_types: dict[str, BaseModel] | None = None,
|
|
89
89
|
) -> list[EntityEdge]:
|
|
90
90
|
start = time()
|
|
91
91
|
|
|
@@ -94,12 +94,25 @@ async def extract_edges(
|
|
|
94
94
|
|
|
95
95
|
node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
|
|
96
96
|
|
|
97
|
+
edge_types_context = (
|
|
98
|
+
[
|
|
99
|
+
{
|
|
100
|
+
'fact_type_name': type_name,
|
|
101
|
+
'fact_type_description': type_model.__doc__,
|
|
102
|
+
}
|
|
103
|
+
for type_name, type_model in edge_types.items()
|
|
104
|
+
]
|
|
105
|
+
if edge_types is not None
|
|
106
|
+
else []
|
|
107
|
+
)
|
|
108
|
+
|
|
97
109
|
# Prepare context for LLM
|
|
98
110
|
context = {
|
|
99
111
|
'episode_content': episode.content,
|
|
100
112
|
'nodes': [node.name for node in nodes],
|
|
101
113
|
'previous_episodes': [ep.content for ep in previous_episodes],
|
|
102
114
|
'reference_time': episode.valid_at,
|
|
115
|
+
'edge_types': edge_types_context,
|
|
103
116
|
'custom_prompt': '',
|
|
104
117
|
}
|
|
105
118
|
|
|
@@ -236,6 +249,9 @@ async def resolve_extracted_edges(
|
|
|
236
249
|
clients: GraphitiClients,
|
|
237
250
|
extracted_edges: list[EntityEdge],
|
|
238
251
|
episode: EpisodicNode,
|
|
252
|
+
entities: list[EntityNode],
|
|
253
|
+
edge_types: dict[str, BaseModel],
|
|
254
|
+
edge_type_map: dict[tuple[str, str], list[str]],
|
|
239
255
|
) -> tuple[list[EntityEdge], list[EntityEdge]]:
|
|
240
256
|
driver = clients.driver
|
|
241
257
|
llm_client = clients.llm_client
|
|
@@ -245,7 +261,7 @@ async def resolve_extracted_edges(
|
|
|
245
261
|
|
|
246
262
|
search_results: tuple[list[list[EntityEdge]], list[list[EntityEdge]]] = await semaphore_gather(
|
|
247
263
|
get_relevant_edges(driver, extracted_edges, SearchFilters()),
|
|
248
|
-
get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters()),
|
|
264
|
+
get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters(), 0.2),
|
|
249
265
|
)
|
|
250
266
|
|
|
251
267
|
related_edges_lists, edge_invalidation_candidates = search_results
|
|
@@ -254,15 +270,50 @@ async def resolve_extracted_edges(
|
|
|
254
270
|
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
|
|
255
271
|
)
|
|
256
272
|
|
|
273
|
+
# Build entity hash table
|
|
274
|
+
uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}
|
|
275
|
+
|
|
276
|
+
# Determine which edge types are relevant for each edge
|
|
277
|
+
edge_types_lst: list[dict[str, BaseModel]] = []
|
|
278
|
+
for extracted_edge in extracted_edges:
|
|
279
|
+
source_node_labels = uuid_entity_map[extracted_edge.source_node_uuid].labels
|
|
280
|
+
target_node_labels = uuid_entity_map[extracted_edge.target_node_uuid].labels
|
|
281
|
+
label_tuples = [
|
|
282
|
+
(source_label, target_label)
|
|
283
|
+
for source_label in source_node_labels
|
|
284
|
+
for target_label in target_node_labels
|
|
285
|
+
]
|
|
286
|
+
|
|
287
|
+
extracted_edge_types = {}
|
|
288
|
+
for label_tuple in label_tuples:
|
|
289
|
+
type_names = edge_type_map.get(label_tuple, [])
|
|
290
|
+
for type_name in type_names:
|
|
291
|
+
type_model = edge_types.get(type_name)
|
|
292
|
+
if type_model is None:
|
|
293
|
+
continue
|
|
294
|
+
|
|
295
|
+
extracted_edge_types[type_name] = type_model
|
|
296
|
+
|
|
297
|
+
edge_types_lst.append(extracted_edge_types)
|
|
298
|
+
|
|
257
299
|
# resolve edges with related edges in the graph and find invalidation candidates
|
|
258
300
|
results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
|
|
259
301
|
await semaphore_gather(
|
|
260
302
|
*[
|
|
261
303
|
resolve_extracted_edge(
|
|
262
|
-
llm_client,
|
|
304
|
+
llm_client,
|
|
305
|
+
extracted_edge,
|
|
306
|
+
related_edges,
|
|
307
|
+
existing_edges,
|
|
308
|
+
episode,
|
|
309
|
+
extracted_edge_types,
|
|
263
310
|
)
|
|
264
|
-
for extracted_edge, related_edges, existing_edges in zip(
|
|
265
|
-
extracted_edges,
|
|
311
|
+
for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
|
|
312
|
+
extracted_edges,
|
|
313
|
+
related_edges_lists,
|
|
314
|
+
edge_invalidation_candidates,
|
|
315
|
+
edge_types_lst,
|
|
316
|
+
strict=True,
|
|
266
317
|
)
|
|
267
318
|
]
|
|
268
319
|
)
|
|
@@ -326,10 +377,86 @@ async def resolve_extracted_edge(
|
|
|
326
377
|
related_edges: list[EntityEdge],
|
|
327
378
|
existing_edges: list[EntityEdge],
|
|
328
379
|
episode: EpisodicNode,
|
|
380
|
+
edge_types: dict[str, BaseModel] | None = None,
|
|
329
381
|
) -> tuple[EntityEdge, list[EntityEdge]]:
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
382
|
+
if len(related_edges) == 0 and len(existing_edges) == 0:
|
|
383
|
+
return extracted_edge, []
|
|
384
|
+
|
|
385
|
+
start = time()
|
|
386
|
+
|
|
387
|
+
# Prepare context for LLM
|
|
388
|
+
related_edges_context = [
|
|
389
|
+
{'id': edge.uuid, 'fact': edge.fact} for i, edge in enumerate(related_edges)
|
|
390
|
+
]
|
|
391
|
+
|
|
392
|
+
invalidation_edge_candidates_context = [
|
|
393
|
+
{'id': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
|
|
394
|
+
]
|
|
395
|
+
|
|
396
|
+
edge_types_context = (
|
|
397
|
+
[
|
|
398
|
+
{
|
|
399
|
+
'fact_type_id': i,
|
|
400
|
+
'fact_type_name': type_name,
|
|
401
|
+
'fact_type_description': type_model.__doc__,
|
|
402
|
+
}
|
|
403
|
+
for i, (type_name, type_model) in enumerate(edge_types.items())
|
|
404
|
+
]
|
|
405
|
+
if edge_types is not None
|
|
406
|
+
else []
|
|
407
|
+
)
|
|
408
|
+
|
|
409
|
+
context = {
|
|
410
|
+
'existing_edges': related_edges_context,
|
|
411
|
+
'new_edge': extracted_edge.fact,
|
|
412
|
+
'edge_invalidation_candidates': invalidation_edge_candidates_context,
|
|
413
|
+
'edge_types': edge_types_context,
|
|
414
|
+
}
|
|
415
|
+
|
|
416
|
+
llm_response = await llm_client.generate_response(
|
|
417
|
+
prompt_library.dedupe_edges.resolve_edge(context),
|
|
418
|
+
response_model=EdgeDuplicate,
|
|
419
|
+
model_size=ModelSize.small,
|
|
420
|
+
)
|
|
421
|
+
|
|
422
|
+
duplicate_fact_id: int = llm_response.get('duplicate_fact_id', -1)
|
|
423
|
+
|
|
424
|
+
resolved_edge = (
|
|
425
|
+
related_edges[duplicate_fact_id]
|
|
426
|
+
if 0 <= duplicate_fact_id < len(related_edges)
|
|
427
|
+
else extracted_edge
|
|
428
|
+
)
|
|
429
|
+
|
|
430
|
+
if duplicate_fact_id >= 0 and episode is not None:
|
|
431
|
+
resolved_edge.episodes.append(episode.uuid)
|
|
432
|
+
|
|
433
|
+
contradicted_facts: list[int] = llm_response.get('contradicted_facts', [])
|
|
434
|
+
|
|
435
|
+
invalidation_candidates: list[EntityEdge] = [existing_edges[i] for i in contradicted_facts]
|
|
436
|
+
|
|
437
|
+
fact_type: str = str(llm_response.get('fact_type'))
|
|
438
|
+
if fact_type.upper() != 'DEFAULT' and edge_types is not None:
|
|
439
|
+
resolved_edge.name = fact_type
|
|
440
|
+
|
|
441
|
+
edge_attributes_context = {
|
|
442
|
+
'message': episode.content,
|
|
443
|
+
'reference_time': episode.valid_at,
|
|
444
|
+
'fact': resolved_edge.fact,
|
|
445
|
+
}
|
|
446
|
+
|
|
447
|
+
edge_model = edge_types.get(fact_type)
|
|
448
|
+
|
|
449
|
+
edge_attributes_response = await llm_client.generate_response(
|
|
450
|
+
prompt_library.extract_edges.extract_attributes(edge_attributes_context),
|
|
451
|
+
response_model=edge_model, # type: ignore
|
|
452
|
+
model_size=ModelSize.small,
|
|
453
|
+
)
|
|
454
|
+
|
|
455
|
+
resolved_edge.attributes = edge_attributes_response
|
|
456
|
+
|
|
457
|
+
end = time()
|
|
458
|
+
logger.debug(
|
|
459
|
+
f'Resolved Edge: {extracted_edge.name} is {resolved_edge.name}, in {(end - start) * 1000} ms'
|
|
333
460
|
)
|
|
334
461
|
|
|
335
462
|
now = utc_now()
|
|
@@ -29,7 +29,7 @@ from graphiti_core.llm_client import LLMClient
|
|
|
29
29
|
from graphiti_core.llm_client.config import ModelSize
|
|
30
30
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
|
|
31
31
|
from graphiti_core.prompts import prompt_library
|
|
32
|
-
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
|
|
32
|
+
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
|
|
33
33
|
from graphiti_core.prompts.extract_nodes import (
|
|
34
34
|
ExtractedEntities,
|
|
35
35
|
ExtractedEntity,
|
|
@@ -243,28 +243,65 @@ async def resolve_extracted_nodes(
|
|
|
243
243
|
|
|
244
244
|
existing_nodes_lists: list[list[EntityNode]] = [result.nodes for result in search_results]
|
|
245
245
|
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
246
|
+
entity_types_dict: dict[str, BaseModel] = entity_types if entity_types is not None else {}
|
|
247
|
+
|
|
248
|
+
# Prepare context for LLM
|
|
249
|
+
extracted_nodes_context = [
|
|
250
|
+
{
|
|
251
|
+
'id': i,
|
|
252
|
+
'name': node.name,
|
|
253
|
+
'entity_type': node.labels,
|
|
254
|
+
'entity_type_description': entity_types_dict.get(
|
|
255
|
+
next((item for item in node.labels if item != 'Entity'), '')
|
|
256
|
+
).__doc__
|
|
257
|
+
or 'Default Entity Type',
|
|
258
|
+
'duplication_candidates': [
|
|
259
|
+
{
|
|
260
|
+
**{
|
|
261
|
+
'idx': j,
|
|
262
|
+
'name': candidate.name,
|
|
263
|
+
'entity_types': candidate.labels,
|
|
264
|
+
},
|
|
265
|
+
**candidate.attributes,
|
|
266
|
+
}
|
|
267
|
+
for j, candidate in enumerate(existing_nodes_lists[i])
|
|
268
|
+
],
|
|
269
|
+
}
|
|
270
|
+
for i, node in enumerate(extracted_nodes)
|
|
271
|
+
]
|
|
272
|
+
|
|
273
|
+
context = {
|
|
274
|
+
'extracted_nodes': extracted_nodes_context,
|
|
275
|
+
'episode_content': episode.content if episode is not None else '',
|
|
276
|
+
'previous_episodes': [ep.content for ep in previous_episodes]
|
|
277
|
+
if previous_episodes is not None
|
|
278
|
+
else [],
|
|
279
|
+
}
|
|
280
|
+
|
|
281
|
+
llm_response = await llm_client.generate_response(
|
|
282
|
+
prompt_library.dedupe_nodes.nodes(context),
|
|
283
|
+
response_model=NodeResolutions,
|
|
264
284
|
)
|
|
265
285
|
|
|
286
|
+
node_resolutions: list = llm_response.get('entity_resolutions', [])
|
|
287
|
+
|
|
288
|
+
resolved_nodes: list[EntityNode] = []
|
|
266
289
|
uuid_map: dict[str, str] = {}
|
|
267
|
-
for
|
|
290
|
+
for resolution in node_resolutions:
|
|
291
|
+
resolution_id = resolution.get('id', -1)
|
|
292
|
+
duplicate_idx = resolution.get('duplicate_idx', -1)
|
|
293
|
+
|
|
294
|
+
extracted_node = extracted_nodes[resolution_id]
|
|
295
|
+
|
|
296
|
+
resolved_node = (
|
|
297
|
+
existing_nodes_lists[resolution_id][duplicate_idx]
|
|
298
|
+
if 0 <= duplicate_idx < len(existing_nodes_lists[resolution_id])
|
|
299
|
+
else extracted_node
|
|
300
|
+
)
|
|
301
|
+
|
|
302
|
+
resolved_node.name = resolution.get('name')
|
|
303
|
+
|
|
304
|
+
resolved_nodes.append(resolved_node)
|
|
268
305
|
uuid_map[extracted_node.uuid] = resolved_node.uuid
|
|
269
306
|
|
|
270
307
|
logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}')
|
|
@@ -410,6 +447,7 @@ async def extract_attributes_from_node(
|
|
|
410
447
|
llm_response = await llm_client.generate_response(
|
|
411
448
|
prompt_library.extract_nodes.extract_attributes(summary_context),
|
|
412
449
|
response_model=entity_attributes_model,
|
|
450
|
+
model_size=ModelSize.small,
|
|
413
451
|
)
|
|
414
452
|
|
|
415
453
|
node.summary = llm_response.get('summary', node.summary)
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "graphiti-core"
|
|
3
3
|
description = "A temporal graph building library"
|
|
4
|
-
version = "0.
|
|
4
|
+
version = "0.12.0pre1"
|
|
5
5
|
authors = [
|
|
6
6
|
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
|
|
7
7
|
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/anthropic_client.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/openai_client.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/models/nodes/node_db_queries.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/extract_edge_dates.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/summarize_nodes.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/search/search_config_recipes.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/utils/maintenance/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/utils/maintenance/utils.py
RENAMED
|
File without changes
|
|
File without changes
|