graphiti-core 0.11.6rc7__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
- graphiti_core/driver/__init__.py +17 -0
- graphiti_core/driver/driver.py +66 -0
- graphiti_core/driver/falkordb_driver.py +132 -0
- graphiti_core/driver/neo4j_driver.py +61 -0
- graphiti_core/edges.py +66 -40
- graphiti_core/embedder/azure_openai.py +64 -0
- graphiti_core/embedder/gemini.py +14 -3
- graphiti_core/graph_queries.py +149 -0
- graphiti_core/graphiti.py +41 -14
- graphiti_core/graphiti_types.py +2 -2
- graphiti_core/helpers.py +17 -30
- graphiti_core/llm_client/__init__.py +16 -0
- graphiti_core/llm_client/azure_openai_client.py +73 -0
- graphiti_core/llm_client/gemini_client.py +4 -1
- graphiti_core/models/edges/edge_db_queries.py +2 -4
- graphiti_core/nodes.py +31 -31
- graphiti_core/prompts/dedupe_edges.py +52 -1
- graphiti_core/prompts/dedupe_nodes.py +79 -4
- graphiti_core/prompts/extract_edges.py +50 -5
- graphiti_core/prompts/invalidate_edges.py +1 -1
- graphiti_core/search/search.py +25 -55
- graphiti_core/search/search_filters.py +23 -9
- graphiti_core/search/search_utils.py +360 -195
- graphiti_core/utils/bulk_utils.py +38 -11
- graphiti_core/utils/maintenance/community_operations.py +6 -7
- graphiti_core/utils/maintenance/edge_operations.py +149 -19
- graphiti_core/utils/maintenance/graph_data_operations.py +13 -42
- graphiti_core/utils/maintenance/node_operations.py +52 -71
- {graphiti_core-0.11.6rc7.dist-info → graphiti_core-0.12.0.dist-info}/METADATA +14 -5
- {graphiti_core-0.11.6rc7.dist-info → graphiti_core-0.12.0.dist-info}/RECORD +33 -26
- {graphiti_core-0.11.6rc7.dist-info → graphiti_core-0.12.0.dist-info}/LICENSE +0 -0
- {graphiti_core-0.11.6rc7.dist-info → graphiti_core-0.12.0.dist-info}/WHEEL +0 -0
graphiti_core/nodes.py
CHANGED
|
@@ -22,13 +22,13 @@ from time import time
|
|
|
22
22
|
from typing import Any
|
|
23
23
|
from uuid import uuid4
|
|
24
24
|
|
|
25
|
-
from neo4j import AsyncDriver
|
|
26
25
|
from pydantic import BaseModel, Field
|
|
27
26
|
from typing_extensions import LiteralString
|
|
28
27
|
|
|
28
|
+
from graphiti_core.driver.driver import GraphDriver
|
|
29
29
|
from graphiti_core.embedder import EmbedderClient
|
|
30
30
|
from graphiti_core.errors import NodeNotFoundError
|
|
31
|
-
from graphiti_core.helpers import DEFAULT_DATABASE
|
|
31
|
+
from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
|
|
32
32
|
from graphiti_core.models.nodes.node_db_queries import (
|
|
33
33
|
COMMUNITY_NODE_SAVE,
|
|
34
34
|
ENTITY_NODE_SAVE,
|
|
@@ -94,9 +94,9 @@ class Node(BaseModel, ABC):
|
|
|
94
94
|
created_at: datetime = Field(default_factory=lambda: utc_now())
|
|
95
95
|
|
|
96
96
|
@abstractmethod
|
|
97
|
-
async def save(self, driver:
|
|
97
|
+
async def save(self, driver: GraphDriver): ...
|
|
98
98
|
|
|
99
|
-
async def delete(self, driver:
|
|
99
|
+
async def delete(self, driver: GraphDriver):
|
|
100
100
|
result = await driver.execute_query(
|
|
101
101
|
"""
|
|
102
102
|
MATCH (n:Entity|Episodic|Community {uuid: $uuid})
|
|
@@ -119,7 +119,7 @@ class Node(BaseModel, ABC):
|
|
|
119
119
|
return False
|
|
120
120
|
|
|
121
121
|
@classmethod
|
|
122
|
-
async def delete_by_group_id(cls, driver:
|
|
122
|
+
async def delete_by_group_id(cls, driver: GraphDriver, group_id: str):
|
|
123
123
|
await driver.execute_query(
|
|
124
124
|
"""
|
|
125
125
|
MATCH (n:Entity|Episodic|Community {group_id: $group_id})
|
|
@@ -132,10 +132,10 @@ class Node(BaseModel, ABC):
|
|
|
132
132
|
return 'SUCCESS'
|
|
133
133
|
|
|
134
134
|
@classmethod
|
|
135
|
-
async def get_by_uuid(cls, driver:
|
|
135
|
+
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
|
|
136
136
|
|
|
137
137
|
@classmethod
|
|
138
|
-
async def get_by_uuids(cls, driver:
|
|
138
|
+
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): ...
|
|
139
139
|
|
|
140
140
|
|
|
141
141
|
class EpisodicNode(Node):
|
|
@@ -150,7 +150,7 @@ class EpisodicNode(Node):
|
|
|
150
150
|
default_factory=list,
|
|
151
151
|
)
|
|
152
152
|
|
|
153
|
-
async def save(self, driver:
|
|
153
|
+
async def save(self, driver: GraphDriver):
|
|
154
154
|
result = await driver.execute_query(
|
|
155
155
|
EPISODIC_NODE_SAVE,
|
|
156
156
|
uuid=self.uuid,
|
|
@@ -165,12 +165,12 @@ class EpisodicNode(Node):
|
|
|
165
165
|
database_=DEFAULT_DATABASE,
|
|
166
166
|
)
|
|
167
167
|
|
|
168
|
-
logger.debug(f'Saved Node to
|
|
168
|
+
logger.debug(f'Saved Node to Graph: {self.uuid}')
|
|
169
169
|
|
|
170
170
|
return result
|
|
171
171
|
|
|
172
172
|
@classmethod
|
|
173
|
-
async def get_by_uuid(cls, driver:
|
|
173
|
+
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
174
174
|
records, _, _ = await driver.execute_query(
|
|
175
175
|
"""
|
|
176
176
|
MATCH (e:Episodic {uuid: $uuid})
|
|
@@ -197,7 +197,7 @@ class EpisodicNode(Node):
|
|
|
197
197
|
return episodes[0]
|
|
198
198
|
|
|
199
199
|
@classmethod
|
|
200
|
-
async def get_by_uuids(cls, driver:
|
|
200
|
+
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
201
201
|
records, _, _ = await driver.execute_query(
|
|
202
202
|
"""
|
|
203
203
|
MATCH (e:Episodic) WHERE e.uuid IN $uuids
|
|
@@ -224,7 +224,7 @@ class EpisodicNode(Node):
|
|
|
224
224
|
@classmethod
|
|
225
225
|
async def get_by_group_ids(
|
|
226
226
|
cls,
|
|
227
|
-
driver:
|
|
227
|
+
driver: GraphDriver,
|
|
228
228
|
group_ids: list[str],
|
|
229
229
|
limit: int | None = None,
|
|
230
230
|
uuid_cursor: str | None = None,
|
|
@@ -263,7 +263,7 @@ class EpisodicNode(Node):
|
|
|
263
263
|
return episodes
|
|
264
264
|
|
|
265
265
|
@classmethod
|
|
266
|
-
async def get_by_entity_node_uuid(cls, driver:
|
|
266
|
+
async def get_by_entity_node_uuid(cls, driver: GraphDriver, entity_node_uuid: str):
|
|
267
267
|
records, _, _ = await driver.execute_query(
|
|
268
268
|
"""
|
|
269
269
|
MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
|
|
@@ -304,7 +304,7 @@ class EntityNode(Node):
|
|
|
304
304
|
|
|
305
305
|
return self.name_embedding
|
|
306
306
|
|
|
307
|
-
async def load_name_embedding(self, driver:
|
|
307
|
+
async def load_name_embedding(self, driver: GraphDriver):
|
|
308
308
|
query: LiteralString = """
|
|
309
309
|
MATCH (n:Entity {uuid: $uuid})
|
|
310
310
|
RETURN n.name_embedding AS name_embedding
|
|
@@ -318,7 +318,7 @@ class EntityNode(Node):
|
|
|
318
318
|
|
|
319
319
|
self.name_embedding = records[0]['name_embedding']
|
|
320
320
|
|
|
321
|
-
async def save(self, driver:
|
|
321
|
+
async def save(self, driver: GraphDriver):
|
|
322
322
|
entity_data: dict[str, Any] = {
|
|
323
323
|
'uuid': self.uuid,
|
|
324
324
|
'name': self.name,
|
|
@@ -337,16 +337,16 @@ class EntityNode(Node):
|
|
|
337
337
|
database_=DEFAULT_DATABASE,
|
|
338
338
|
)
|
|
339
339
|
|
|
340
|
-
logger.debug(f'Saved Node to
|
|
340
|
+
logger.debug(f'Saved Node to Graph: {self.uuid}')
|
|
341
341
|
|
|
342
342
|
return result
|
|
343
343
|
|
|
344
344
|
@classmethod
|
|
345
|
-
async def get_by_uuid(cls, driver:
|
|
345
|
+
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
346
346
|
query = (
|
|
347
347
|
"""
|
|
348
|
-
|
|
349
|
-
|
|
348
|
+
MATCH (n:Entity {uuid: $uuid})
|
|
349
|
+
"""
|
|
350
350
|
+ ENTITY_NODE_RETURN
|
|
351
351
|
)
|
|
352
352
|
records, _, _ = await driver.execute_query(
|
|
@@ -364,7 +364,7 @@ class EntityNode(Node):
|
|
|
364
364
|
return nodes[0]
|
|
365
365
|
|
|
366
366
|
@classmethod
|
|
367
|
-
async def get_by_uuids(cls, driver:
|
|
367
|
+
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
368
368
|
records, _, _ = await driver.execute_query(
|
|
369
369
|
"""
|
|
370
370
|
MATCH (n:Entity) WHERE n.uuid IN $uuids
|
|
@@ -382,7 +382,7 @@ class EntityNode(Node):
|
|
|
382
382
|
@classmethod
|
|
383
383
|
async def get_by_group_ids(
|
|
384
384
|
cls,
|
|
385
|
-
driver:
|
|
385
|
+
driver: GraphDriver,
|
|
386
386
|
group_ids: list[str],
|
|
387
387
|
limit: int | None = None,
|
|
388
388
|
uuid_cursor: str | None = None,
|
|
@@ -416,7 +416,7 @@ class CommunityNode(Node):
|
|
|
416
416
|
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
|
|
417
417
|
summary: str = Field(description='region summary of member nodes', default_factory=str)
|
|
418
418
|
|
|
419
|
-
async def save(self, driver:
|
|
419
|
+
async def save(self, driver: GraphDriver):
|
|
420
420
|
result = await driver.execute_query(
|
|
421
421
|
COMMUNITY_NODE_SAVE,
|
|
422
422
|
uuid=self.uuid,
|
|
@@ -428,7 +428,7 @@ class CommunityNode(Node):
|
|
|
428
428
|
database_=DEFAULT_DATABASE,
|
|
429
429
|
)
|
|
430
430
|
|
|
431
|
-
logger.debug(f'Saved Node to
|
|
431
|
+
logger.debug(f'Saved Node to Graph: {self.uuid}')
|
|
432
432
|
|
|
433
433
|
return result
|
|
434
434
|
|
|
@@ -441,7 +441,7 @@ class CommunityNode(Node):
|
|
|
441
441
|
|
|
442
442
|
return self.name_embedding
|
|
443
443
|
|
|
444
|
-
async def load_name_embedding(self, driver:
|
|
444
|
+
async def load_name_embedding(self, driver: GraphDriver):
|
|
445
445
|
query: LiteralString = """
|
|
446
446
|
MATCH (c:Community {uuid: $uuid})
|
|
447
447
|
RETURN c.name_embedding AS name_embedding
|
|
@@ -456,7 +456,7 @@ class CommunityNode(Node):
|
|
|
456
456
|
self.name_embedding = records[0]['name_embedding']
|
|
457
457
|
|
|
458
458
|
@classmethod
|
|
459
|
-
async def get_by_uuid(cls, driver:
|
|
459
|
+
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
460
460
|
records, _, _ = await driver.execute_query(
|
|
461
461
|
"""
|
|
462
462
|
MATCH (n:Community {uuid: $uuid})
|
|
@@ -480,7 +480,7 @@ class CommunityNode(Node):
|
|
|
480
480
|
return nodes[0]
|
|
481
481
|
|
|
482
482
|
@classmethod
|
|
483
|
-
async def get_by_uuids(cls, driver:
|
|
483
|
+
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
484
484
|
records, _, _ = await driver.execute_query(
|
|
485
485
|
"""
|
|
486
486
|
MATCH (n:Community) WHERE n.uuid IN $uuids
|
|
@@ -503,7 +503,7 @@ class CommunityNode(Node):
|
|
|
503
503
|
@classmethod
|
|
504
504
|
async def get_by_group_ids(
|
|
505
505
|
cls,
|
|
506
|
-
driver:
|
|
506
|
+
driver: GraphDriver,
|
|
507
507
|
group_ids: list[str],
|
|
508
508
|
limit: int | None = None,
|
|
509
509
|
uuid_cursor: str | None = None,
|
|
@@ -542,8 +542,8 @@ class CommunityNode(Node):
|
|
|
542
542
|
def get_episodic_node_from_record(record: Any) -> EpisodicNode:
|
|
543
543
|
return EpisodicNode(
|
|
544
544
|
content=record['content'],
|
|
545
|
-
created_at=record['created_at']
|
|
546
|
-
valid_at=(record['valid_at']
|
|
545
|
+
created_at=parse_db_date(record['created_at']), # type: ignore
|
|
546
|
+
valid_at=parse_db_date(record['valid_at']), # type: ignore
|
|
547
547
|
uuid=record['uuid'],
|
|
548
548
|
group_id=record['group_id'],
|
|
549
549
|
source=EpisodeType.from_str(record['source']),
|
|
@@ -559,7 +559,7 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
|
|
|
559
559
|
name=record['name'],
|
|
560
560
|
group_id=record['group_id'],
|
|
561
561
|
labels=record['labels'],
|
|
562
|
-
created_at=record['created_at']
|
|
562
|
+
created_at=parse_db_date(record['created_at']), # type: ignore
|
|
563
563
|
summary=record['summary'],
|
|
564
564
|
attributes=record['attributes'],
|
|
565
565
|
)
|
|
@@ -580,7 +580,7 @@ def get_community_node_from_record(record: Any) -> CommunityNode:
|
|
|
580
580
|
name=record['name'],
|
|
581
581
|
group_id=record['group_id'],
|
|
582
582
|
name_embedding=record['name_embedding'],
|
|
583
|
-
created_at=record['created_at']
|
|
583
|
+
created_at=parse_db_date(record['created_at']), # type: ignore
|
|
584
584
|
summary=record['summary'],
|
|
585
585
|
)
|
|
586
586
|
|
|
@@ -27,6 +27,11 @@ class EdgeDuplicate(BaseModel):
|
|
|
27
27
|
...,
|
|
28
28
|
description='id of the duplicate fact. If no duplicate facts are found, default to -1.',
|
|
29
29
|
)
|
|
30
|
+
contradicted_facts: list[int] = Field(
|
|
31
|
+
...,
|
|
32
|
+
description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.',
|
|
33
|
+
)
|
|
34
|
+
fact_type: str = Field(..., description='One of the provided fact types or DEFAULT')
|
|
30
35
|
|
|
31
36
|
|
|
32
37
|
class UniqueFact(BaseModel):
|
|
@@ -41,11 +46,13 @@ class UniqueFacts(BaseModel):
|
|
|
41
46
|
class Prompt(Protocol):
|
|
42
47
|
edge: PromptVersion
|
|
43
48
|
edge_list: PromptVersion
|
|
49
|
+
resolve_edge: PromptVersion
|
|
44
50
|
|
|
45
51
|
|
|
46
52
|
class Versions(TypedDict):
|
|
47
53
|
edge: PromptFunction
|
|
48
54
|
edge_list: PromptFunction
|
|
55
|
+
resolve_edge: PromptFunction
|
|
49
56
|
|
|
50
57
|
|
|
51
58
|
def edge(context: dict[str, Any]) -> list[Message]:
|
|
@@ -106,4 +113,48 @@ def edge_list(context: dict[str, Any]) -> list[Message]:
|
|
|
106
113
|
]
|
|
107
114
|
|
|
108
115
|
|
|
109
|
-
|
|
116
|
+
def resolve_edge(context: dict[str, Any]) -> list[Message]:
|
|
117
|
+
return [
|
|
118
|
+
Message(
|
|
119
|
+
role='system',
|
|
120
|
+
content='You are a helpful assistant that de-duplicates facts from fact lists and determines which existing '
|
|
121
|
+
'facts are contradicted by the new fact.',
|
|
122
|
+
),
|
|
123
|
+
Message(
|
|
124
|
+
role='user',
|
|
125
|
+
content=f"""
|
|
126
|
+
<NEW FACT>
|
|
127
|
+
{context['new_edge']}
|
|
128
|
+
</NEW FACT>
|
|
129
|
+
|
|
130
|
+
<EXISTING FACTS>
|
|
131
|
+
{context['existing_edges']}
|
|
132
|
+
</EXISTING FACTS>
|
|
133
|
+
<FACT INVALIDATION CANDIDATES>
|
|
134
|
+
{context['edge_invalidation_candidates']}
|
|
135
|
+
</FACT INVALIDATION CANDIDATES>
|
|
136
|
+
|
|
137
|
+
<FACT TYPES>
|
|
138
|
+
{context['edge_types']}
|
|
139
|
+
</FACT TYPES>
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
Task:
|
|
143
|
+
If the NEW FACT represents the same factual information as any fact in EXISTING FACTS, return the idx of the duplicate fact.
|
|
144
|
+
If the NEW FACT is not a duplicate of any of the EXISTING FACTS, return -1.
|
|
145
|
+
|
|
146
|
+
Given the predefined FACT TYPES, determine if the NEW FACT should be classified as one of these types.
|
|
147
|
+
Return the fact type as fact_type or DEFAULT if NEW FACT is not one of the FACT TYPES.
|
|
148
|
+
|
|
149
|
+
Based on the provided FACT INVALIDATION CANDIDATES and NEW FACT, determine which existing facts the new fact contradicts.
|
|
150
|
+
Return a list containing all idx's of the facts that are contradicted by the NEW FACT.
|
|
151
|
+
If there are no contradicted facts, return an empty list.
|
|
152
|
+
|
|
153
|
+
Guidelines:
|
|
154
|
+
1. The facts do not need to be completely identical to be duplicates, they just need to express the same information.
|
|
155
|
+
""",
|
|
156
|
+
),
|
|
157
|
+
]
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
versions: Versions = {'edge': edge, 'edge_list': edge_list, 'resolve_edge': resolve_edge}
|
|
@@ -23,21 +23,31 @@ from .models import Message, PromptFunction, PromptVersion
|
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
class NodeDuplicate(BaseModel):
|
|
26
|
-
|
|
26
|
+
id: int = Field(..., description='integer id of the entity')
|
|
27
|
+
duplicate_idx: int = Field(
|
|
27
28
|
...,
|
|
28
|
-
description='
|
|
29
|
+
description='idx of the duplicate node. If no duplicate nodes are found, default to -1.',
|
|
29
30
|
)
|
|
30
|
-
name: str = Field(
|
|
31
|
+
name: str = Field(
|
|
32
|
+
...,
|
|
33
|
+
description='Name of the entity. Should be the most complete and descriptive name possible.',
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class NodeResolutions(BaseModel):
|
|
38
|
+
entity_resolutions: list[NodeDuplicate] = Field(..., description='List of resolved nodes')
|
|
31
39
|
|
|
32
40
|
|
|
33
41
|
class Prompt(Protocol):
|
|
34
42
|
node: PromptVersion
|
|
35
43
|
node_list: PromptVersion
|
|
44
|
+
nodes: PromptVersion
|
|
36
45
|
|
|
37
46
|
|
|
38
47
|
class Versions(TypedDict):
|
|
39
48
|
node: PromptFunction
|
|
40
49
|
node_list: PromptFunction
|
|
50
|
+
nodes: PromptFunction
|
|
41
51
|
|
|
42
52
|
|
|
43
53
|
def node(context: dict[str, Any]) -> list[Message]:
|
|
@@ -89,6 +99,71 @@ def node(context: dict[str, Any]) -> list[Message]:
|
|
|
89
99
|
]
|
|
90
100
|
|
|
91
101
|
|
|
102
|
+
def nodes(context: dict[str, Any]) -> list[Message]:
|
|
103
|
+
return [
|
|
104
|
+
Message(
|
|
105
|
+
role='system',
|
|
106
|
+
content='You are a helpful assistant that determines whether or not ENTITIES extracted from a conversation are duplicates'
|
|
107
|
+
'of existing entities.',
|
|
108
|
+
),
|
|
109
|
+
Message(
|
|
110
|
+
role='user',
|
|
111
|
+
content=f"""
|
|
112
|
+
<PREVIOUS MESSAGES>
|
|
113
|
+
{json.dumps([ep for ep in context['previous_episodes']], indent=2)}
|
|
114
|
+
</PREVIOUS MESSAGES>
|
|
115
|
+
<CURRENT MESSAGE>
|
|
116
|
+
{context['episode_content']}
|
|
117
|
+
</CURRENT MESSAGE>
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
Each of the following ENTITIES were extracted from the CURRENT MESSAGE.
|
|
121
|
+
Each entity in ENTITIES is represented as a JSON object with the following structure:
|
|
122
|
+
{{
|
|
123
|
+
id: integer id of the entity,
|
|
124
|
+
name: "name of the entity",
|
|
125
|
+
entity_type: "ontological classification of the entity",
|
|
126
|
+
entity_type_description: "Description of what the entity type represents",
|
|
127
|
+
duplication_candidates: [
|
|
128
|
+
{{
|
|
129
|
+
idx: integer index of the candidate entity,
|
|
130
|
+
name: "name of the candidate entity",
|
|
131
|
+
entity_type: "ontological classification of the candidate entity",
|
|
132
|
+
...<additional attributes>
|
|
133
|
+
}}
|
|
134
|
+
]
|
|
135
|
+
}}
|
|
136
|
+
|
|
137
|
+
<ENTITIES>
|
|
138
|
+
{json.dumps(context['extracted_nodes'], indent=2)}
|
|
139
|
+
</ENTITIES>
|
|
140
|
+
|
|
141
|
+
<EXISTING ENTITIES>
|
|
142
|
+
{json.dumps(context['existing_nodes'], indent=2)}
|
|
143
|
+
</EXISTING ENTITIES>
|
|
144
|
+
|
|
145
|
+
For each of the above ENTITIES, determine if the entity is a duplicate of any of the EXISTING ENTITIES.
|
|
146
|
+
|
|
147
|
+
Entities should only be considered duplicates if they refer to the *same real-world object or concept*.
|
|
148
|
+
|
|
149
|
+
Do NOT mark entities as duplicates if:
|
|
150
|
+
- They are related but distinct.
|
|
151
|
+
- They have similar names or purposes but refer to separate instances or concepts.
|
|
152
|
+
|
|
153
|
+
Task:
|
|
154
|
+
Your response will be a list called entity_resolutions which contains one entry for each entity.
|
|
155
|
+
|
|
156
|
+
For each entity, return the id of the entity as id, the name of the entity as name, and the duplicate_idx
|
|
157
|
+
as an integer.
|
|
158
|
+
|
|
159
|
+
- If an entity is a duplicate of one of the EXISTING ENTITIES, return the idx of the candidate it is a
|
|
160
|
+
duplicate of.
|
|
161
|
+
- If an entity is not a duplicate of one of the EXISTING ENTITIES, return the -1 as the duplication_idx
|
|
162
|
+
""",
|
|
163
|
+
),
|
|
164
|
+
]
|
|
165
|
+
|
|
166
|
+
|
|
92
167
|
def node_list(context: dict[str, Any]) -> list[Message]:
|
|
93
168
|
return [
|
|
94
169
|
Message(
|
|
@@ -126,4 +201,4 @@ def node_list(context: dict[str, Any]) -> list[Message]:
|
|
|
126
201
|
]
|
|
127
202
|
|
|
128
203
|
|
|
129
|
-
versions: Versions = {'node': node, 'node_list': node_list}
|
|
204
|
+
versions: Versions = {'node': node, 'node_list': node_list, 'nodes': nodes}
|
|
@@ -24,8 +24,8 @@ from .models import Message, PromptFunction, PromptVersion
|
|
|
24
24
|
|
|
25
25
|
class Edge(BaseModel):
|
|
26
26
|
relation_type: str = Field(..., description='FACT_PREDICATE_IN_SCREAMING_SNAKE_CASE')
|
|
27
|
-
|
|
28
|
-
|
|
27
|
+
source_entity_id: int = Field(..., description='The id of the source entity of the fact.')
|
|
28
|
+
target_entity_id: int = Field(..., description='The id of the target entity of the fact.')
|
|
29
29
|
fact: str = Field(..., description='')
|
|
30
30
|
valid_at: str | None = Field(
|
|
31
31
|
None,
|
|
@@ -48,11 +48,13 @@ class MissingFacts(BaseModel):
|
|
|
48
48
|
class Prompt(Protocol):
|
|
49
49
|
edge: PromptVersion
|
|
50
50
|
reflexion: PromptVersion
|
|
51
|
+
extract_attributes: PromptVersion
|
|
51
52
|
|
|
52
53
|
|
|
53
54
|
class Versions(TypedDict):
|
|
54
55
|
edge: PromptFunction
|
|
55
56
|
reflexion: PromptFunction
|
|
57
|
+
extract_attributes: PromptFunction
|
|
56
58
|
|
|
57
59
|
|
|
58
60
|
def edge(context: dict[str, Any]) -> list[Message]:
|
|
@@ -75,19 +77,26 @@ def edge(context: dict[str, Any]) -> list[Message]:
|
|
|
75
77
|
</CURRENT_MESSAGE>
|
|
76
78
|
|
|
77
79
|
<ENTITIES>
|
|
78
|
-
{context['nodes']}
|
|
80
|
+
{context['nodes']}
|
|
79
81
|
</ENTITIES>
|
|
80
82
|
|
|
81
83
|
<REFERENCE_TIME>
|
|
82
84
|
{context['reference_time']} # ISO 8601 (UTC); used to resolve relative time mentions
|
|
83
85
|
</REFERENCE_TIME>
|
|
84
86
|
|
|
87
|
+
<FACT TYPES>
|
|
88
|
+
{context['edge_types']}
|
|
89
|
+
</FACT TYPES>
|
|
90
|
+
|
|
85
91
|
# TASK
|
|
86
92
|
Extract all factual relationships between the given ENTITIES based on the CURRENT MESSAGE.
|
|
87
93
|
Only extract facts that:
|
|
88
94
|
- involve two DISTINCT ENTITIES from the ENTITIES list,
|
|
89
95
|
- are clearly stated or unambiguously implied in the CURRENT MESSAGE,
|
|
90
|
-
|
|
96
|
+
and can be represented as edges in a knowledge graph.
|
|
97
|
+
- The FACT TYPES provide a list of the most important types of facts, make sure to extract facts of these types
|
|
98
|
+
- The FACT TYPES are not an exhaustive list, extract all facts from the message even if they do not fit into one
|
|
99
|
+
of the FACT TYPES
|
|
91
100
|
|
|
92
101
|
You may use information from the PREVIOUS MESSAGES only to disambiguate references or support continuity.
|
|
93
102
|
|
|
@@ -145,4 +154,40 @@ determine if any facts haven't been extracted.
|
|
|
145
154
|
]
|
|
146
155
|
|
|
147
156
|
|
|
148
|
-
|
|
157
|
+
def extract_attributes(context: dict[str, Any]) -> list[Message]:
|
|
158
|
+
return [
|
|
159
|
+
Message(
|
|
160
|
+
role='system',
|
|
161
|
+
content='You are a helpful assistant that extracts fact properties from the provided text.',
|
|
162
|
+
),
|
|
163
|
+
Message(
|
|
164
|
+
role='user',
|
|
165
|
+
content=f"""
|
|
166
|
+
|
|
167
|
+
<MESSAGE>
|
|
168
|
+
{json.dumps(context['episode_content'], indent=2)}
|
|
169
|
+
</MESSAGE>
|
|
170
|
+
<REFERENCE TIME>
|
|
171
|
+
{context['reference_time']}
|
|
172
|
+
</REFERENCE TIME>
|
|
173
|
+
|
|
174
|
+
Given the above MESSAGE, its REFERENCE TIME, and the following FACT, update any of its attributes based on the information provided
|
|
175
|
+
in MESSAGE. Use the provided attribute descriptions to better understand how each attribute should be determined.
|
|
176
|
+
|
|
177
|
+
Guidelines:
|
|
178
|
+
1. Do not hallucinate entity property values if they cannot be found in the current context.
|
|
179
|
+
2. Only use the provided MESSAGES and FACT to set attribute values.
|
|
180
|
+
|
|
181
|
+
<FACT>
|
|
182
|
+
{context['fact']}
|
|
183
|
+
</FACT>
|
|
184
|
+
""",
|
|
185
|
+
),
|
|
186
|
+
]
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
versions: Versions = {
|
|
190
|
+
'edge': edge,
|
|
191
|
+
'reflexion': reflexion,
|
|
192
|
+
'extract_attributes': extract_attributes,
|
|
193
|
+
}
|
|
@@ -24,7 +24,7 @@ from .models import Message, PromptFunction, PromptVersion
|
|
|
24
24
|
class InvalidatedEdges(BaseModel):
|
|
25
25
|
contradicted_facts: list[int] = Field(
|
|
26
26
|
...,
|
|
27
|
-
description='List of ids of facts that be
|
|
27
|
+
description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.',
|
|
28
28
|
)
|
|
29
29
|
|
|
30
30
|
|