graphiti-core 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
This version of graphiti-core has been flagged as a potentially problematic release.
- graphiti_core/__init__.py +3 -0
- graphiti_core/edges.py +232 -0
- graphiti_core/graphiti.py +618 -0
- graphiti_core/helpers.py +7 -0
- graphiti_core/llm_client/__init__.py +5 -0
- graphiti_core/llm_client/anthropic_client.py +63 -0
- graphiti_core/llm_client/client.py +96 -0
- graphiti_core/llm_client/config.py +58 -0
- graphiti_core/llm_client/groq_client.py +64 -0
- graphiti_core/llm_client/openai_client.py +65 -0
- graphiti_core/llm_client/utils.py +22 -0
- graphiti_core/nodes.py +250 -0
- graphiti_core/prompts/__init__.py +4 -0
- graphiti_core/prompts/dedupe_edges.py +154 -0
- graphiti_core/prompts/dedupe_nodes.py +151 -0
- graphiti_core/prompts/extract_edge_dates.py +60 -0
- graphiti_core/prompts/extract_edges.py +138 -0
- graphiti_core/prompts/extract_nodes.py +145 -0
- graphiti_core/prompts/invalidate_edges.py +74 -0
- graphiti_core/prompts/lib.py +122 -0
- graphiti_core/prompts/models.py +31 -0
- graphiti_core/search/__init__.py +0 -0
- graphiti_core/search/search.py +142 -0
- graphiti_core/search/search_utils.py +454 -0
- graphiti_core/utils/__init__.py +15 -0
- graphiti_core/utils/bulk_utils.py +227 -0
- graphiti_core/utils/maintenance/__init__.py +16 -0
- graphiti_core/utils/maintenance/edge_operations.py +170 -0
- graphiti_core/utils/maintenance/graph_data_operations.py +133 -0
- graphiti_core/utils/maintenance/node_operations.py +199 -0
- graphiti_core/utils/maintenance/temporal_operations.py +184 -0
- graphiti_core/utils/maintenance/utils.py +0 -0
- graphiti_core/utils/utils.py +39 -0
- graphiti_core-0.1.0.dist-info/LICENSE +201 -0
- graphiti_core-0.1.0.dist-info/METADATA +199 -0
- graphiti_core-0.1.0.dist-info/RECORD +37 -0
- graphiti_core-0.1.0.dist-info/WHEEL +4 -0

graphiti_core/utils/maintenance/edge_operations.py
@@ -0,0 +1,170 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging
from datetime import datetime
from time import time
from typing import List

from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.prompts import prompt_library

logger = logging.getLogger(__name__)


def build_episodic_edges(
    entity_nodes: List[EntityNode],
    episode: EpisodicNode,
    created_at: datetime,
) -> List[EpisodicEdge]:
    edges: List[EpisodicEdge] = []

    for node in entity_nodes:
        edge = EpisodicEdge(
            source_node_uuid=episode.uuid,
            target_node_uuid=node.uuid,
            created_at=created_at,
        )
        edges.append(edge)

    return edges


async def extract_edges(
    llm_client: LLMClient,
    episode: EpisodicNode,
    nodes: list[EntityNode],
    previous_episodes: list[EpisodicNode],
) -> list[EntityEdge]:
    start = time()

    # Prepare context for LLM
    context = {
        'episode_content': episode.content,
        'episode_timestamp': (episode.valid_at.isoformat() if episode.valid_at else None),
        'nodes': [
            {'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in nodes
        ],
        'previous_episodes': [
            {
                'content': ep.content,
                'timestamp': ep.valid_at.isoformat() if ep.valid_at else None,
            }
            for ep in previous_episodes
        ],
    }

    llm_response = await llm_client.generate_response(prompt_library.extract_edges.v2(context))
    print(llm_response)
    edges_data = llm_response.get('edges', [])

    end = time()
    logger.info(f'Extracted new edges: {edges_data} in {(end - start) * 1000} ms')

    # Convert the extracted data into EntityEdge objects
    edges = []
    for edge_data in edges_data:
        if edge_data['target_node_uuid'] and edge_data['source_node_uuid']:
            edge = EntityEdge(
                source_node_uuid=edge_data['source_node_uuid'],
                target_node_uuid=edge_data['target_node_uuid'],
                name=edge_data['relation_type'],
                fact=edge_data['fact'],
                episodes=[episode.uuid],
                created_at=datetime.now(),
                valid_at=None,
                invalid_at=None,
            )
            edges.append(edge)
            logger.info(
                f'Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})'
            )

    return edges


def create_edge_identifier(
    source_node: EntityNode, edge: EntityEdge, target_node: EntityNode
) -> str:
    return f'{source_node.name}-{edge.name}-{target_node.name}'


async def dedupe_extracted_edges(
    llm_client: LLMClient,
    extracted_edges: list[EntityEdge],
    existing_edges: list[EntityEdge],
) -> list[EntityEdge]:
    # Create edge map
    edge_map = {}
    for edge in extracted_edges:
        edge_map[edge.uuid] = edge

    # Prepare context for LLM
    context = {
        'extracted_edges': [
            {'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in extracted_edges
        ],
        'existing_edges': [
            {'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in existing_edges
        ],
    }

    llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v1(context))
    unique_edge_data = llm_response.get('unique_facts', [])
    logger.info(f'Extracted unique edges: {unique_edge_data}')

    # Get full edge data
    edges = []
    for unique_edge in unique_edge_data:
        edge = edge_map[unique_edge['uuid']]
        edges.append(edge)

    return edges


async def dedupe_edge_list(
    llm_client: LLMClient,
    edges: list[EntityEdge],
) -> list[EntityEdge]:
    start = time()

    # Create edge map
    edge_map = {}
    for edge in edges:
        edge_map[edge.uuid] = edge

    # Prepare context for LLM
    context = {'edges': [{'uuid': edge.uuid, 'fact': edge.fact} for edge in edges]}

    llm_response = await llm_client.generate_response(
        prompt_library.dedupe_edges.edge_list(context)
    )
    unique_edges_data = llm_response.get('unique_facts', [])

    end = time()
    logger.info(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ')

    # Get full edge data
    unique_edges = []
    for edge_data in unique_edges_data:
        uuid = edge_data['uuid']
        edge = edge_map[uuid]
        edge.fact = edge_data['fact']
        unique_edges.append(edge)

    return unique_edges
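
Taken together, edge handling in this release is a multi-step pass: build_episodic_edges links every extracted entity back to its source episode, extract_edges asks the LLM for entity-to-entity relationships, and the two dedupe helpers collapse repeated facts. Below is a minimal usage sketch, not package code: the ingest_episode wrapper and its arguments are illustrative assumptions about how a caller might wire these together.

from datetime import datetime, timezone

from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.utils.maintenance.edge_operations import (
    build_episodic_edges,
    extract_edges,
)


async def ingest_episode(
    llm_client: LLMClient, episode: EpisodicNode, nodes: list[EntityNode]
):
    # Hypothetical wrapper: the LLM proposes relationship edges between
    # already-extracted entity nodes...
    entity_edges = await extract_edges(llm_client, episode, nodes, previous_episodes=[])
    # ...and every mentioned entity also gets an episodic edge back to the episode.
    episodic_edges = build_episodic_edges(nodes, episode, created_at=datetime.now(timezone.utc))
    return entity_edges, episodic_edges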

graphiti_core/utils/maintenance/graph_data_operations.py
@@ -0,0 +1,133 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import asyncio
import logging
from datetime import datetime, timezone

from neo4j import AsyncDriver
from typing_extensions import LiteralString

from graphiti_core.nodes import EpisodeType, EpisodicNode

EPISODE_WINDOW_LEN = 3

logger = logging.getLogger(__name__)


async def build_indices_and_constraints(driver: AsyncDriver):
    range_indices: list[LiteralString] = [
        'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
        'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
        'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
        'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
        'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
        'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
        'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
        'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
        'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
        'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
        'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
        'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
        'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
    ]

    fulltext_indices: list[LiteralString] = [
        'CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]',
        'CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact]',
    ]

    vector_indices: list[LiteralString] = [
        """
        CREATE VECTOR INDEX fact_embedding IF NOT EXISTS
        FOR ()-[r:RELATES_TO]-() ON (r.fact_embedding)
        OPTIONS {indexConfig: {
            `vector.dimensions`: 1024,
            `vector.similarity_function`: 'cosine'
        }}
        """,
        """
        CREATE VECTOR INDEX name_embedding IF NOT EXISTS
        FOR (n:Entity) ON (n.name_embedding)
        OPTIONS {indexConfig: {
            `vector.dimensions`: 1024,
            `vector.similarity_function`: 'cosine'
        }}
        """,
    ]
    index_queries: list[LiteralString] = range_indices + fulltext_indices + vector_indices

    await asyncio.gather(*[driver.execute_query(query) for query in index_queries])


async def clear_data(driver: AsyncDriver):
    async with driver.session() as session:

        async def delete_all(tx):
            await tx.run('MATCH (n) DETACH DELETE n')

        await session.execute_write(delete_all)


async def retrieve_episodes(
    driver: AsyncDriver,
    reference_time: datetime,
    last_n: int = EPISODE_WINDOW_LEN,
) -> list[EpisodicNode]:
    """
    Retrieve the last n episodic nodes from the graph.

    Args:
        driver (AsyncDriver): The Neo4j driver instance.
        reference_time (datetime): The reference time to filter episodes. Only episodes with a valid_at timestamp
                                   less than or equal to this reference_time will be retrieved. This allows for
                                   querying the graph's state at a specific point in time.
        last_n (int, optional): The number of most recent episodes to retrieve, relative to the reference_time.

    Returns:
        list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes.
    """
    result = await driver.execute_query(
        """
        MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
        RETURN e.content as content,
            e.created_at as created_at,
            e.valid_at as valid_at,
            e.uuid as uuid,
            e.name as name,
            e.source_description as source_description,
            e.source as source
        ORDER BY e.created_at DESC
        LIMIT $num_episodes
        """,
        reference_time=reference_time,
        num_episodes=last_n,
    )
    episodes = [
        EpisodicNode(
            content=record['content'],
            created_at=datetime.fromtimestamp(
                record['created_at'].to_native().timestamp(), timezone.utc
            ),
            valid_at=(record['valid_at'].to_native()),
            uuid=record['uuid'],
            source=EpisodeType.from_str(record['source']),
            name=record['name'],
            source_description=record['source_description'],
        )
        for record in result.records
    ]
    return list(reversed(episodes))  # Return in chronological order
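
The intended call order here appears to be: build indices once at startup, then use retrieve_episodes to fetch the recent-episode window (EPISODE_WINDOW_LEN defaults to 3) that the extraction prompts consume as context. A short sketch under stated assumptions, with a local Neo4j instance and placeholder credentials:

import asyncio
from datetime import datetime, timezone

from neo4j import AsyncGraphDatabase

from graphiti_core.utils.maintenance.graph_data_operations import (
    build_indices_and_constraints,
    retrieve_episodes,
)


async def main():
    # Connection details are placeholders for illustration only.
    driver = AsyncGraphDatabase.driver('bolt://localhost:7687', auth=('neo4j', 'password'))
    try:
        await build_indices_and_constraints(driver)
        # The last three episodes valid at or before now, in chronological order.
        episodes = await retrieve_episodes(driver, reference_time=datetime.now(timezone.utc))
        for ep in episodes:
            print(ep.name, ep.valid_at)
    finally:
        await driver.close()


asyncio.run(main())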

graphiti_core/utils/maintenance/node_operations.py
@@ -0,0 +1,199 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging
from datetime import datetime
from time import time
from typing import Any

from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.prompts import prompt_library

logger = logging.getLogger(__name__)


async def extract_message_nodes(
    llm_client: LLMClient, episode: EpisodicNode, previous_episodes: list[EpisodicNode]
) -> list[dict[str, Any]]:
    # Prepare context for LLM
    context = {
        'episode_content': episode.content,
        'episode_timestamp': episode.valid_at.isoformat(),
        'previous_episodes': [
            {
                'content': ep.content,
                'timestamp': ep.valid_at.isoformat(),
            }
            for ep in previous_episodes
        ],
    }

    llm_response = await llm_client.generate_response(prompt_library.extract_nodes.v2(context))
    extracted_node_data = llm_response.get('extracted_nodes', [])
    return extracted_node_data


async def extract_json_nodes(
    llm_client: LLMClient,
    episode: EpisodicNode,
) -> list[dict[str, Any]]:
    # Prepare context for LLM
    context = {
        'episode_content': episode.content,
        'episode_timestamp': episode.valid_at.isoformat(),
        'source_description': episode.source_description,
    }

    llm_response = await llm_client.generate_response(
        prompt_library.extract_nodes.extract_json(context)
    )
    extracted_node_data = llm_response.get('extracted_nodes', [])
    return extracted_node_data


async def extract_nodes(
    llm_client: LLMClient,
    episode: EpisodicNode,
    previous_episodes: list[EpisodicNode],
) -> list[EntityNode]:
    start = time()
    extracted_node_data: list[dict[str, Any]] = []
    if episode.source in [EpisodeType.message, EpisodeType.text]:
        extracted_node_data = await extract_message_nodes(llm_client, episode, previous_episodes)
    elif episode.source == EpisodeType.json:
        extracted_node_data = await extract_json_nodes(llm_client, episode)

    end = time()
    logger.info(f'Extracted new nodes: {extracted_node_data} in {(end - start) * 1000} ms')
    # Convert the extracted data into EntityNode objects
    new_nodes = []
    for node_data in extracted_node_data:
        new_node = EntityNode(
            name=node_data['name'],
            labels=node_data['labels'],
            summary=node_data['summary'],
            created_at=datetime.now(),
        )
        new_nodes.append(new_node)
        logger.info(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')

    return new_nodes


async def dedupe_extracted_nodes(
    llm_client: LLMClient,
    extracted_nodes: list[EntityNode],
    existing_nodes: list[EntityNode],
) -> tuple[list[EntityNode], dict[str, str], list[EntityNode]]:
    start = time()

    # build existing node map
    node_map: dict[str, EntityNode] = {}
    for node in existing_nodes:
        node_map[node.name] = node

    # Temp hack
    new_nodes_map: dict[str, EntityNode] = {}
    for node in extracted_nodes:
        new_nodes_map[node.name] = node

    # Prepare context for LLM
    existing_nodes_context = [
        {'name': node.name, 'summary': node.summary} for node in existing_nodes
    ]

    extracted_nodes_context = [
        {'name': node.name, 'summary': node.summary} for node in extracted_nodes
    ]

    context = {
        'existing_nodes': existing_nodes_context,
        'extracted_nodes': extracted_nodes_context,
    }

    llm_response = await llm_client.generate_response(prompt_library.dedupe_nodes.v2(context))

    duplicate_data = llm_response.get('duplicates', [])

    end = time()
    logger.info(f'Deduplicated nodes: {duplicate_data} in {(end - start) * 1000} ms')

    uuid_map: dict[str, str] = {}
    for duplicate in duplicate_data:
        uuid = new_nodes_map[duplicate['name']].uuid
        uuid_value = node_map[duplicate['duplicate_of']].uuid
        uuid_map[uuid] = uuid_value

    nodes: list[EntityNode] = []
    brand_new_nodes: list[EntityNode] = []
    for node in extracted_nodes:
        if node.uuid in uuid_map:
            existing_uuid = uuid_map[node.uuid]
            # TODO(Preston): This is a bit of a hack I implemented because we were getting incorrect uuids for existing nodes,
            # can you revisit the node dedup function and make it somewhat cleaner and add more comments/tests please?
            # find an existing node by the uuid from the nodes_map (each key is name, so we need to iterate by uuid value)
            existing_node = next((v for k, v in node_map.items() if v.uuid == existing_uuid), None)
            if existing_node:
                nodes.append(existing_node)

            continue
        brand_new_nodes.append(node)
        nodes.append(node)

    return nodes, uuid_map, brand_new_nodes


async def dedupe_node_list(
    llm_client: LLMClient,
    nodes: list[EntityNode],
) -> tuple[list[EntityNode], dict[str, str]]:
    start = time()

    # build node map
    node_map = {}
    for node in nodes:
        node_map[node.name] = node

    # Prepare context for LLM
    nodes_context = [{'name': node.name, 'summary': node.summary} for node in nodes]

    context = {
        'nodes': nodes_context,
    }

    llm_response = await llm_client.generate_response(
        prompt_library.dedupe_nodes.node_list(context)
    )

    nodes_data = llm_response.get('nodes', [])

    end = time()
    logger.info(f'Deduplicated nodes: {nodes_data} in {(end - start) * 1000} ms')

    # Get full node data
    unique_nodes = []
    uuid_map: dict[str, str] = {}
    for node_data in nodes_data:
        node = node_map[node_data['names'][0]]
        unique_nodes.append(node)

        for name in node_data['names'][1:]:
            uuid = node_map[name].uuid
            uuid_value = node_map[node_data['names'][0]].uuid
            uuid_map[uuid] = uuid_value

    return unique_nodes, uuid_map
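
Note the contract of the uuid_map these dedupe functions return: it maps each duplicate node's uuid to the uuid of the node that survives deduplication, and callers are expected to re-point any edges that still reference a duplicate. A sketch of such a helper, hypothetical and not part of this release, under that assumption:

from graphiti_core.edges import EntityEdge


def resolve_edge_pointers(edges: list[EntityEdge], uuid_map: dict[str, str]) -> list[EntityEdge]:
    # Hypothetical helper: re-point edge endpoints from duplicate-node uuids
    # to their canonical uuids; uuids not in the map are left untouched.
    for edge in edges:
        edge.source_node_uuid = uuid_map.get(edge.source_node_uuid, edge.source_node_uuid)
        edge.target_node_uuid = uuid_map.get(edge.target_node_uuid, edge.target_node_uuid)
    return edges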

graphiti_core/utils/maintenance/temporal_operations.py
@@ -0,0 +1,184 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging
from datetime import datetime
from typing import List

from graphiti_core.edges import EntityEdge
from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.prompts import prompt_library

logger = logging.getLogger(__name__)

NodeEdgeNodeTriplet = tuple[EntityNode, EntityEdge, EntityNode]


def extract_node_and_edge_triplets(
    edges: list[EntityEdge], nodes: list[EntityNode]
) -> list[NodeEdgeNodeTriplet]:
    return [extract_node_edge_node_triplet(edge, nodes) for edge in edges]


def extract_node_edge_node_triplet(
    edge: EntityEdge, nodes: list[EntityNode]
) -> NodeEdgeNodeTriplet:
    source_node = next((node for node in nodes if node.uuid == edge.source_node_uuid), None)
    target_node = next((node for node in nodes if node.uuid == edge.target_node_uuid), None)
    if not source_node or not target_node:
        raise ValueError(f'Source or target node not found for edge {edge.uuid}')
    return (source_node, edge, target_node)


def prepare_edges_for_invalidation(
    existing_edges: list[EntityEdge],
    new_edges: list[EntityEdge],
    nodes: list[EntityNode],
) -> tuple[list[NodeEdgeNodeTriplet], list[NodeEdgeNodeTriplet]]:
    existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet] = []
    new_edges_with_nodes: list[NodeEdgeNodeTriplet] = []

    for edge_list, result_list in [
        (existing_edges, existing_edges_pending_invalidation),
        (new_edges, new_edges_with_nodes),
    ]:
        for edge in edge_list:
            source_node = next((node for node in nodes if node.uuid == edge.source_node_uuid), None)
            target_node = next((node for node in nodes if node.uuid == edge.target_node_uuid), None)

            if source_node and target_node:
                result_list.append((source_node, edge, target_node))

    return existing_edges_pending_invalidation, new_edges_with_nodes


async def invalidate_edges(
    llm_client: LLMClient,
    existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet],
    new_edges: list[NodeEdgeNodeTriplet],
    current_episode: EpisodicNode,
    previous_episodes: list[EpisodicNode],
) -> list[EntityEdge]:
    invalidated_edges = []  # TODO: this is not yet used?

    context = prepare_invalidation_context(
        existing_edges_pending_invalidation,
        new_edges,
        current_episode,
        previous_episodes,
    )
    llm_response = await llm_client.generate_response(prompt_library.invalidate_edges.v1(context))

    edges_to_invalidate = llm_response.get('invalidated_edges', [])
    invalidated_edges = process_edge_invalidation_llm_response(
        edges_to_invalidate, existing_edges_pending_invalidation
    )

    return invalidated_edges


def extract_date_strings_from_edge(edge: EntityEdge) -> str:
    start = edge.valid_at
    end = edge.invalid_at
    date_string = f'Start Date: {start.isoformat()}' if start else ''
    if end:
        date_string += f' (End Date: {end.isoformat()})'
    return date_string


def prepare_invalidation_context(
    existing_edges: list[NodeEdgeNodeTriplet],
    new_edges: list[NodeEdgeNodeTriplet],
    current_episode: EpisodicNode,
    previous_episodes: list[EpisodicNode],
) -> dict:
    return {
        'existing_edges': [
            f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) {extract_date_strings_from_edge(edge)}'
            for source_node, edge, target_node in sorted(
                existing_edges, key=lambda x: (x[1].created_at), reverse=True
            )
        ],
        'new_edges': [
            f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) {extract_date_strings_from_edge(edge)}'
            for source_node, edge, target_node in sorted(
                new_edges, key=lambda x: (x[1].created_at), reverse=True
            )
        ],
        'current_episode': current_episode.content,
        'previous_episodes': [episode.content for episode in previous_episodes],
    }


def process_edge_invalidation_llm_response(
    edges_to_invalidate: List[dict], existing_edges: List[NodeEdgeNodeTriplet]
) -> List[EntityEdge]:
    invalidated_edges = []
    for edge_to_invalidate in edges_to_invalidate:
        edge_uuid = edge_to_invalidate['edge_uuid']
        edge_to_update = next(
            (edge for _, edge, _ in existing_edges if edge.uuid == edge_uuid),
            None,
        )
        if edge_to_update:
            edge_to_update.expired_at = datetime.now()
            edge_to_update.fact = edge_to_invalidate['fact']
            invalidated_edges.append(edge_to_update)
            logger.info(
                f"Invalidated edge: {edge_to_update.name} (UUID: {edge_to_update.uuid}). Updated Fact: {edge_to_invalidate['fact']}"
            )
    return invalidated_edges


async def extract_edge_dates(
    llm_client: LLMClient,
    edge: EntityEdge,
    reference_time: datetime,
    current_episode: EpisodicNode,
    previous_episodes: List[EpisodicNode],
) -> tuple[datetime | None, datetime | None, str]:
    context = {
        'edge_name': edge.name,
        'edge_fact': edge.fact,
        'current_episode': current_episode.content,
        'previous_episodes': [ep.content for ep in previous_episodes],
        'reference_timestamp': reference_time.isoformat(),
    }
    llm_response = await llm_client.generate_response(prompt_library.extract_edge_dates.v1(context))

    valid_at = llm_response.get('valid_at')
    invalid_at = llm_response.get('invalid_at')
    explanation = llm_response.get('explanation', '')

    valid_at_datetime = None
    invalid_at_datetime = None

    if valid_at and valid_at != '':
        try:
            valid_at_datetime = datetime.fromisoformat(valid_at.replace('Z', '+00:00'))
        except ValueError as e:
            logger.error(f'Error parsing valid_at date: {e}. Input: {valid_at}')

    if invalid_at and invalid_at != '':
        try:
            invalid_at_datetime = datetime.fromisoformat(invalid_at.replace('Z', '+00:00'))
        except ValueError as e:
            logger.error(f'Error parsing invalid_at date: {e}. Input: {invalid_at}')

    logger.info(f'Edge date extraction explanation: {explanation}')

    return valid_at_datetime, invalid_at_datetime, explanation
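
These helpers suggest a temporal pass over each new episode: extract_edge_dates stamps valid_at/invalid_at on each freshly extracted edge, then invalidate_edges lets the LLM expire existing edges that the new facts contradict. A hedged sketch of that flow follows; the temporal_pass wrapper is an assumption, and all inputs are taken to come from the extraction steps above.

from datetime import datetime, timezone

from graphiti_core.edges import EntityEdge
from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.utils.maintenance.temporal_operations import (
    extract_edge_dates,
    invalidate_edges,
    prepare_edges_for_invalidation,
)


async def temporal_pass(
    llm_client: LLMClient,
    existing_edges: list[EntityEdge],
    new_edges: list[EntityEdge],
    nodes: list[EntityNode],
    episode: EpisodicNode,
    previous_episodes: list[EpisodicNode],
) -> list[EntityEdge]:
    # Date each new edge from episode context (hypothetical orchestration).
    for edge in new_edges:
        valid_at, invalid_at, _ = await extract_edge_dates(
            llm_client, edge, datetime.now(timezone.utc), episode, previous_episodes
        )
        edge.valid_at = valid_at
        edge.invalid_at = invalid_at

    # Pair every edge with its endpoint nodes, then ask the LLM which
    # existing edges the new facts invalidate.
    pending, new_with_nodes = prepare_edges_for_invalidation(existing_edges, new_edges, nodes)
    return await invalidate_edges(llm_client, pending, new_with_nodes, episode, previous_episodes)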