graphiti-core 0.11.6rc9__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
- graphiti_core/driver/__init__.py +17 -0
- graphiti_core/driver/driver.py +66 -0
- graphiti_core/driver/falkordb_driver.py +132 -0
- graphiti_core/driver/neo4j_driver.py +61 -0
- graphiti_core/edges.py +66 -40
- graphiti_core/embedder/azure_openai.py +64 -0
- graphiti_core/embedder/gemini.py +14 -3
- graphiti_core/graph_queries.py +149 -0
- graphiti_core/graphiti.py +41 -14
- graphiti_core/graphiti_types.py +2 -2
- graphiti_core/helpers.py +9 -4
- graphiti_core/llm_client/__init__.py +16 -0
- graphiti_core/llm_client/azure_openai_client.py +73 -0
- graphiti_core/llm_client/gemini_client.py +4 -1
- graphiti_core/models/edges/edge_db_queries.py +2 -4
- graphiti_core/nodes.py +31 -31
- graphiti_core/prompts/dedupe_edges.py +52 -1
- graphiti_core/prompts/dedupe_nodes.py +79 -4
- graphiti_core/prompts/extract_edges.py +50 -5
- graphiti_core/prompts/invalidate_edges.py +1 -1
- graphiti_core/search/search.py +6 -10
- graphiti_core/search/search_filters.py +23 -9
- graphiti_core/search/search_utils.py +250 -189
- graphiti_core/utils/bulk_utils.py +38 -11
- graphiti_core/utils/maintenance/community_operations.py +6 -7
- graphiti_core/utils/maintenance/edge_operations.py +149 -19
- graphiti_core/utils/maintenance/graph_data_operations.py +13 -42
- graphiti_core/utils/maintenance/node_operations.py +52 -71
- {graphiti_core-0.11.6rc9.dist-info → graphiti_core-0.12.0.dist-info}/METADATA +14 -5
- {graphiti_core-0.11.6rc9.dist-info → graphiti_core-0.12.0.dist-info}/RECORD +33 -26
- {graphiti_core-0.11.6rc9.dist-info → graphiti_core-0.12.0.dist-info}/LICENSE +0 -0
- {graphiti_core-0.11.6rc9.dist-info → graphiti_core-0.12.0.dist-info}/WHEEL +0 -0
graphiti_core/utils/bulk_utils.py

@@ -20,22 +20,24 @@ from collections import defaultdict
 from datetime import datetime
 from math import ceil
 
-from neo4j import AsyncDriver, AsyncManagedTransaction
 from numpy import dot, sqrt
 from pydantic import BaseModel
 from typing_extensions import Any
 
+from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
 from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
 from graphiti_core.embedder import EmbedderClient
+from graphiti_core.graph_queries import (
+    get_entity_edge_save_bulk_query,
+    get_entity_node_save_bulk_query,
+)
 from graphiti_core.graphiti_types import GraphitiClients
 from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
 from graphiti_core.llm_client import LLMClient
 from graphiti_core.models.edges.edge_db_queries import (
-    ENTITY_EDGE_SAVE_BULK,
     EPISODIC_EDGE_SAVE_BULK,
 )
 from graphiti_core.models.nodes.node_db_queries import (
-    ENTITY_NODE_SAVE_BULK,
     EPISODIC_NODE_SAVE_BULK,
 )
 from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
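Note on the hunk above: the hard neo4j import is replaced by the database-agnostic GraphDriver layer added under graphiti_core/driver/ in this release. A minimal, hedged sketch of what that enables for callers; the class name is inferred from the added file neo4j_driver.py, and the constructor arguments are assumptions, not something this diff shows:

# Hedged sketch: selecting a backend through the new driver layer.
# Neo4jDriver is an assumed class name based on the added neo4j_driver.py;
# the constructor arguments are guesses, not taken from this diff.
from graphiti_core.driver.neo4j_driver import Neo4jDriver

driver = Neo4jDriver(uri='bolt://localhost:7687', user='neo4j', password='password')
# Every helper in the hunks below now accepts this GraphDriver in place of
# neo4j.AsyncDriver.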
@@ -73,7 +75,7 @@ class RawEpisode(BaseModel):
 
 
 async def retrieve_previous_episodes_bulk(
-    driver: AsyncDriver, episodes: list[EpisodicNode]
+    driver: GraphDriver, episodes: list[EpisodicNode]
 ) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
     previous_episodes_list = await semaphore_gather(
         *[
@@ -91,14 +93,15 @@ async def retrieve_previous_episodes_bulk(
 
 
 async def add_nodes_and_edges_bulk(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     episodic_nodes: list[EpisodicNode],
     episodic_edges: list[EpisodicEdge],
     entity_nodes: list[EntityNode],
     entity_edges: list[EntityEdge],
     embedder: EmbedderClient,
 ):
-    async with driver.session(database=DEFAULT_DATABASE) as session:
+    session = driver.session(database=DEFAULT_DATABASE)
+    try:
         await session.execute_write(
             add_nodes_and_edges_bulk_tx,
             episodic_nodes,
@@ -106,16 +109,20 @@ async def add_nodes_and_edges_bulk(
             entity_nodes,
             entity_edges,
             embedder,
+            driver=driver,
         )
+    finally:
+        await session.close()
 
 
 async def add_nodes_and_edges_bulk_tx(
-    tx: AsyncManagedTransaction,
+    tx: GraphDriverSession,
     episodic_nodes: list[EpisodicNode],
     episodic_edges: list[EpisodicEdge],
     entity_nodes: list[EntityNode],
     entity_edges: list[EntityEdge],
     embedder: EmbedderClient,
+    driver: GraphDriver,
 ):
     episodes = [dict(episode) for episode in episodic_nodes]
     for episode in episodes:
@@ -137,16 +144,36 @@ async def add_nodes_and_edges_bulk_tx(
         entity_data['labels'] = list(set(node.labels + ['Entity']))
         nodes.append(entity_data)
 
+    edges: list[dict[str, Any]] = []
     for edge in entity_edges:
         if edge.fact_embedding is None:
             await edge.generate_embedding(embedder)
+        edge_data: dict[str, Any] = {
+            'uuid': edge.uuid,
+            'source_node_uuid': edge.source_node_uuid,
+            'target_node_uuid': edge.target_node_uuid,
+            'name': edge.name,
+            'fact': edge.fact,
+            'fact_embedding': edge.fact_embedding,
+            'group_id': edge.group_id,
+            'episodes': edge.episodes,
+            'created_at': edge.created_at,
+            'expired_at': edge.expired_at,
+            'valid_at': edge.valid_at,
+            'invalid_at': edge.invalid_at,
+        }
+
+        edge_data.update(edge.attributes or {})
+        edges.append(edge_data)
 
     await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
-    await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
+    entity_node_save_bulk = get_entity_node_save_bulk_query(nodes, driver.provider)
+    await tx.run(entity_node_save_bulk, nodes=nodes)
     await tx.run(
         EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
     )
-    await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[edge.model_dump() for edge in entity_edges])
+    entity_edge_save_bulk = get_entity_edge_save_bulk_query(driver.provider)
+    await tx.run(entity_edge_save_bulk, entity_edges=edges)
 
 
 async def extract_nodes_and_edges_bulk(
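For reference, each element of the new edges payload built above is a flat dict, with any custom edge.attributes merged at the top level by edge_data.update(...). An illustrative row, all values invented:

# Illustrative only: one entry of the `edges` list from the hunk above.
# Keys mirror the edge_data dict in the diff; values are made up.
example_edge_row = {
    'uuid': 'edge-uuid',
    'source_node_uuid': 'alice-uuid',
    'target_node_uuid': 'acme-uuid',
    'name': 'WORKS_AT',
    'fact': 'Alice works at Acme',
    'fact_embedding': [0.01, -0.02, 0.03],
    'group_id': '',
    'episodes': ['episode-uuid'],
    'created_at': None,  # datetimes in the real payload
    'expired_at': None,
    'valid_at': None,
    'invalid_at': None,
    'role': 'engineer',  # a custom attribute flattened in via edge_data.update(...)
}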
@@ -193,7 +220,7 @@ async def extract_nodes_and_edges_bulk(
 
 
 async def dedupe_nodes_bulk(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     llm_client: LLMClient,
     extracted_nodes: list[EntityNode],
 ) -> tuple[list[EntityNode], dict[str, str]]:
@@ -229,7 +256,7 @@ async def dedupe_nodes_bulk(
 
 
 async def dedupe_edges_bulk(
-    driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
+    driver: GraphDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
 ) -> list[EntityEdge]:
     # First compress edges
     compressed_edges = await compress_edges(llm_client, extracted_edges)
graphiti_core/utils/maintenance/community_operations.py

@@ -2,9 +2,9 @@ import asyncio
 import logging
 from collections import defaultdict
 
-from neo4j import AsyncDriver
 from pydantic import BaseModel
 
+from graphiti_core.driver.driver import GraphDriver
 from graphiti_core.edges import CommunityEdge
 from graphiti_core.embedder import EmbedderClient
 from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
@@ -26,7 +26,7 @@ class Neighbor(BaseModel):
 
 
 async def get_community_clusters(
-    driver: AsyncDriver, group_ids: list[str] | None
+    driver: GraphDriver, group_ids: list[str] | None
 ) -> list[list[EntityNode]]:
     community_clusters: list[list[EntityNode]] = []
 
@@ -95,7 +95,6 @@ def label_propagation(projection: dict[str, list[Neighbor]]) -> list[list[str]]:
         community_candidates: dict[int, int] = defaultdict(int)
         for neighbor in neighbors:
             community_candidates[community_map[neighbor.node_uuid]] += neighbor.edge_count
-
         community_lst = [
             (count, community) for community, count in community_candidates.items()
         ]
@@ -194,7 +193,7 @@ async def build_community(
 
 
 async def build_communities(
-    driver: AsyncDriver, llm_client: LLMClient, group_ids: list[str] | None
+    driver: GraphDriver, llm_client: LLMClient, group_ids: list[str] | None
 ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
     community_clusters = await get_community_clusters(driver, group_ids)
 
@@ -219,7 +218,7 @@ async def build_communities(
     return community_nodes, community_edges
 
 
-async def remove_communities(driver: AsyncDriver):
+async def remove_communities(driver: GraphDriver):
     await driver.execute_query(
         """
         MATCH (c:Community)
@@ -230,7 +229,7 @@ async def remove_communities(driver: AsyncDriver):
 
 
 async def determine_entity_community(
-    driver: AsyncDriver, entity: EntityNode
+    driver: GraphDriver, entity: EntityNode
 ) -> tuple[CommunityNode | None, bool]:
     # Check if the node is already part of a community
     records, _, _ = await driver.execute_query(
@@ -291,7 +290,7 @@ async def determine_entity_community(
 
 
 async def update_community(
-    driver: AsyncDriver, llm_client: LLMClient, embedder: EmbedderClient, entity: EntityNode
+    driver: GraphDriver, llm_client: LLMClient, embedder: EmbedderClient, entity: EntityNode
 ):
     community, is_new = await determine_entity_community(driver, entity)
 
graphiti_core/utils/maintenance/edge_operations.py

@@ -18,6 +18,8 @@ import logging
 from datetime import datetime
 from time import time
 
+from pydantic import BaseModel
+
 from graphiti_core.edges import (
     CommunityEdge,
     EntityEdge,
@@ -35,9 +37,6 @@ from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
 from graphiti_core.search.search_filters import SearchFilters
 from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
 from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
-from graphiti_core.utils.maintenance.temporal_operations import (
-    get_edge_contradictions,
-)
 
 logger = logging.getLogger(__name__)
 
@@ -86,20 +85,32 @@ async def extract_edges(
     nodes: list[EntityNode],
     previous_episodes: list[EpisodicNode],
     group_id: str = '',
+    edge_types: dict[str, BaseModel] | None = None,
 ) -> list[EntityEdge]:
     start = time()
 
     extract_edges_max_tokens = 16384
     llm_client = clients.llm_client
 
-
+    edge_types_context = (
+        [
+            {
+                'fact_type_name': type_name,
+                'fact_type_description': type_model.__doc__,
+            }
+            for type_name, type_model in edge_types.items()
+        ]
+        if edge_types is not None
+        else []
+    )
 
     # Prepare context for LLM
     context = {
         'episode_content': episode.content,
-        'nodes': [node.name for node in nodes],
+        'nodes': [{'id': idx, 'name': node.name} for idx, node in enumerate(nodes)],
         'previous_episodes': [ep.content for ep in previous_episodes],
         'reference_time': episode.valid_at,
+        'edge_types': edge_types_context,
         'custom_prompt': '',
     }
 
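Since edge_types maps a type name to a Pydantic model and the context reads type_model.__doc__, a caller-defined fact type presumably looks like the following hedged sketch; the class and key names are invented:

from pydantic import BaseModel

class IsEmployeeOf(BaseModel):  # invented example type
    """The source entity is or was employed by the target entity."""

edge_types = {'IS_EMPLOYEE_OF': IsEmployeeOf}
# edge_types_context would then contain:
# [{'fact_type_name': 'IS_EMPLOYEE_OF',
#   'fact_type_description': 'The source entity is or was employed by the target entity.'}]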
@@ -148,6 +159,16 @@ async def extract_edges(
         valid_at_datetime = None
         invalid_at_datetime = None
 
+        source_node_idx = edge_data.get('source_entity_id', -1)
+        target_node_idx = edge_data.get('target_entity_id', -1)
+        if not (-1 < source_node_idx < len(nodes) and -1 < target_node_idx < len(nodes)):
+            logger.warning(
+                f'WARNING: source or target node not filled {edge_data.get("edge_name")}. source_node_uuid: {source_node_idx} and target_node_uuid: {target_node_idx} '
+            )
+            continue
+        source_node_uuid = nodes[source_node_idx].uuid
+        target_node_uuid = nodes[edge_data.get('target_entity_id')].uuid
+
         if valid_at:
             try:
                 valid_at_datetime = ensure_utc(
@@ -164,12 +185,8 @@ async def extract_edges(
         except ValueError as e:
             logger.warning(f'WARNING: Error parsing invalid_at date: {e}. Input: {invalid_at}')
         edge = EntityEdge(
-            source_node_uuid=node_uuids_by_name_map.get(
-                edge_data.get('source_entity_name', ''), ''
-            ),
-            target_node_uuid=node_uuids_by_name_map.get(
-                edge_data.get('target_entity_name', ''), ''
-            ),
+            source_node_uuid=source_node_uuid,
+            target_node_uuid=target_node_uuid,
             name=edge_data.get('relation_type', ''),
             group_id=group_id,
             fact=edge_data.get('fact', ''),
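Node references in the extractor output are now integer positions into nodes rather than names. A hedged illustration of one edge_data record as consumed above; keys follow the .get(...) calls in the hunk, values are invented:

# Illustrative only: shape of one extracted edge record.
edge_data = {
    'source_entity_id': 0,   # index into the `nodes` list passed to extract_edges
    'target_entity_id': 2,
    'relation_type': 'WORKS_AT',
    'fact': 'Alice works at Acme',
    'valid_at': '2024-01-01T00:00:00Z',
    'invalid_at': None,
}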
@@ -236,16 +253,18 @@ async def resolve_extracted_edges(
     clients: GraphitiClients,
     extracted_edges: list[EntityEdge],
     episode: EpisodicNode,
+    entities: list[EntityNode],
+    edge_types: dict[str, BaseModel],
+    edge_type_map: dict[tuple[str, str], list[str]],
 ) -> tuple[list[EntityEdge], list[EntityEdge]]:
     driver = clients.driver
     llm_client = clients.llm_client
     embedder = clients.embedder
-
     await create_entity_edge_embeddings(embedder, extracted_edges)
 
     search_results: tuple[list[list[EntityEdge]], list[list[EntityEdge]]] = await semaphore_gather(
         get_relevant_edges(driver, extracted_edges, SearchFilters()),
-        get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters()),
+        get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters(), 0.2),
     )
 
     related_edges_lists, edge_invalidation_candidates = search_results
@@ -254,15 +273,50 @@ async def resolve_extracted_edges(
         f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
     )
 
+    # Build entity hash table
+    uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}
+
+    # Determine which edge types are relevant for each edge
+    edge_types_lst: list[dict[str, BaseModel]] = []
+    for extracted_edge in extracted_edges:
+        source_node_labels = uuid_entity_map[extracted_edge.source_node_uuid].labels + ['Entity']
+        target_node_labels = uuid_entity_map[extracted_edge.target_node_uuid].labels + ['Entity']
+        label_tuples = [
+            (source_label, target_label)
+            for source_label in source_node_labels
+            for target_label in target_node_labels
+        ]
+
+        extracted_edge_types = {}
+        for label_tuple in label_tuples:
+            type_names = edge_type_map.get(label_tuple, [])
+            for type_name in type_names:
+                type_model = edge_types.get(type_name)
+                if type_model is None:
+                    continue
+
+                extracted_edge_types[type_name] = type_model
+
+        edge_types_lst.append(extracted_edge_types)
+
     # resolve edges with related edges in the graph and find invalidation candidates
     results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
         await semaphore_gather(
             *[
                 resolve_extracted_edge(
-                    llm_client, extracted_edge, related_edges, existing_edges, episode
+                    llm_client,
+                    extracted_edge,
+                    related_edges,
+                    existing_edges,
+                    episode,
+                    extracted_edge_types,
                 )
-                for extracted_edge, related_edges, existing_edges in zip(
-                    extracted_edges, related_edges_lists, edge_invalidation_candidates
+                for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
+                    extracted_edges,
+                    related_edges_lists,
+                    edge_invalidation_candidates,
+                    edge_types_lst,
+                    strict=True,
                 )
             ]
         )
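The lookup above keys edge_type_map by (source_label, target_label) pairs, with 'Entity' always appended as a fallback label. A hedged example of a caller-supplied map; labels and type names are invented:

# Illustrative only: which fact types are allowed between which node labels.
edge_type_map: dict[tuple[str, str], list[str]] = {
    ('Person', 'Company'): ['IS_EMPLOYEE_OF'],
    ('Entity', 'Entity'): ['RELATES_TO'],
}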
@@ -326,10 +380,86 @@ async def resolve_extracted_edge(
     related_edges: list[EntityEdge],
     existing_edges: list[EntityEdge],
     episode: EpisodicNode,
+    edge_types: dict[str, BaseModel] | None = None,
 ) -> tuple[EntityEdge, list[EntityEdge]]:
-    resolved_edge, invalidation_candidates = await semaphore_gather(
-        dedupe_extracted_edge(llm_client, extracted_edge, related_edges),
-        get_edge_contradictions(llm_client, extracted_edge, existing_edges),
+    if len(related_edges) == 0 and len(existing_edges) == 0:
+        return extracted_edge, []
+
+    start = time()
+
+    # Prepare context for LLM
+    related_edges_context = [
+        {'id': edge.uuid, 'fact': edge.fact} for i, edge in enumerate(related_edges)
+    ]
+
+    invalidation_edge_candidates_context = [
+        {'id': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
+    ]
+
+    edge_types_context = (
+        [
+            {
+                'fact_type_id': i,
+                'fact_type_name': type_name,
+                'fact_type_description': type_model.__doc__,
+            }
+            for i, (type_name, type_model) in enumerate(edge_types.items())
+        ]
+        if edge_types is not None
+        else []
+    )
+
+    context = {
+        'existing_edges': related_edges_context,
+        'new_edge': extracted_edge.fact,
+        'edge_invalidation_candidates': invalidation_edge_candidates_context,
+        'edge_types': edge_types_context,
+    }
+
+    llm_response = await llm_client.generate_response(
+        prompt_library.dedupe_edges.resolve_edge(context),
+        response_model=EdgeDuplicate,
+        model_size=ModelSize.small,
+    )
+
+    duplicate_fact_id: int = llm_response.get('duplicate_fact_id', -1)
+
+    resolved_edge = (
+        related_edges[duplicate_fact_id]
+        if 0 <= duplicate_fact_id < len(related_edges)
+        else extracted_edge
+    )
+
+    if duplicate_fact_id >= 0 and episode is not None:
+        resolved_edge.episodes.append(episode.uuid)
+
+    contradicted_facts: list[int] = llm_response.get('contradicted_facts', [])
+
+    invalidation_candidates: list[EntityEdge] = [existing_edges[i] for i in contradicted_facts]
+
+    fact_type: str = str(llm_response.get('fact_type'))
+    if fact_type.upper() != 'DEFAULT' and edge_types is not None:
+        resolved_edge.name = fact_type
+
+        edge_attributes_context = {
+            'episode_content': episode.content,
+            'reference_time': episode.valid_at,
+            'fact': resolved_edge.fact,
+        }
+
+        edge_model = edge_types.get(fact_type)
+
+        edge_attributes_response = await llm_client.generate_response(
+            prompt_library.extract_edges.extract_attributes(edge_attributes_context),
+            response_model=edge_model,  # type: ignore
+            model_size=ModelSize.small,
+        )
+
+        resolved_edge.attributes = edge_attributes_response
+
+    end = time()
+    logger.debug(
+        f'Resolved Edge: {extracted_edge.name} is {resolved_edge.name}, in {(end - start) * 1000} ms'
     )
 
     now = utc_now()
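The llm_response.get(...) calls above imply a structured response with three fields. A hedged sketch of that shape; the real EdgeDuplicate model lives in graphiti_core/prompts/dedupe_edges.py, which this diff page lists but does not display here:

from pydantic import BaseModel, Field

class EdgeDuplicate(BaseModel):  # sketch only; field types inferred from the .get() calls
    duplicate_fact_id: int = Field(..., description='index of the duplicate related edge, or -1')
    contradicted_facts: list[int] = Field(..., description='indices of invalidation candidates contradicted by the new fact')
    fact_type: str = Field(..., description="one of the supplied fact type names, or 'DEFAULT'")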
graphiti_core/utils/maintenance/graph_data_operations.py

@@ -17,9 +17,10 @@ limitations under the License.
 import logging
 from datetime import datetime, timezone
 
-from neo4j import AsyncDriver
 from typing_extensions import LiteralString
 
+from graphiti_core.driver.driver import GraphDriver
+from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
 from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
 from graphiti_core.nodes import EpisodeType, EpisodicNode
 
@@ -28,7 +29,7 @@ EPISODE_WINDOW_LEN = 3
 logger = logging.getLogger(__name__)
 
 
-async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bool = False):
+async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
     if delete_existing:
         records, _, _ = await driver.execute_query(
             """
@@ -47,39 +48,9 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bool = False):
             for name in index_names
         ]
     )
+    range_indices: list[LiteralString] = get_range_indices(driver.provider)
 
-    range_indices: list[LiteralString] = [
-        'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
-        'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
-        'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
-        'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
-        'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
-        'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
-        'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
-        'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
-        'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
-        'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
-        'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
-        'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
-        'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
-        'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
-        'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
-        'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
-        'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
-        'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
-        'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
-    ]
-
-    fulltext_indices: list[LiteralString] = [
-        """CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
-        FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
-        """CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
-        FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
-        """CREATE FULLTEXT INDEX community_name IF NOT EXISTS
-        FOR (n:Community) ON EACH [n.name, n.group_id]""",
-        """CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
-        FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
-    ]
+    fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
 
     index_queries: list[LiteralString] = range_indices + fulltext_indices
 
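Index DDL is now produced per backend via driver.provider. A hedged sketch of how such a dispatch could look; the shipped implementation is in the added graphiti_core/graph_queries.py, which this page does not display, and the provider strings are assumptions:

# Hedged sketch only; names and provider values are assumptions.
def get_range_indices(provider: str) -> list[str]:
    if provider == 'falkordb':
        # Assumption: FalkorDB accepts a simpler CREATE INDEX form.
        return ['CREATE INDEX FOR (n:Entity) ON (n.uuid)']
    # The Neo4j DDL matches the strings removed in the hunk above.
    return ['CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)']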
@@ -94,7 +65,7 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bool = False):
     )
 
 
-async def clear_data(driver: AsyncDriver, group_ids: list[str] | None = None):
+async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
     async with driver.session(database=DEFAULT_DATABASE) as session:
 
         async def delete_all(tx):
@@ -113,7 +84,7 @@ async def clear_data(driver: AsyncDriver, group_ids: list[str] | None = None):
 
 
 async def retrieve_episodes(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     reference_time: datetime,
     last_n: int = EPISODE_WINDOW_LEN,
     group_ids: list[str] | None = None,
@@ -123,7 +94,7 @@ async def retrieve_episodes(
     Retrieve the last n episodic nodes from the graph.
 
     Args:
-        driver (AsyncDriver): The Neo4j driver instance.
+        driver (Driver): The Neo4j driver instance.
         reference_time (datetime): The reference time to filter episodes. Only episodes with a valid_at timestamp
         less than or equal to this reference_time will be retrieved. This allows for
         querying the graph's state at a specific point in time.
@@ -140,8 +111,8 @@ async def retrieve_episodes(
 
     query: LiteralString = (
         """
-            MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
-            """
+        MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
+        """
         + group_id_filter
         + source_filter
         + """
@@ -157,8 +128,7 @@ async def retrieve_episodes(
         LIMIT $num_episodes
     """
     )
-
-    result = await driver.execute_query(
+    result, _, _ = await driver.execute_query(
         query,
         reference_time=reference_time,
         source=source.name if source is not None else None,
@@ -166,6 +136,7 @@ async def retrieve_episodes(
         group_ids=group_ids,
         database_=DEFAULT_DATABASE,
     )
+
     episodes = [
         EpisodicNode(
             content=record['content'],
@@ -179,6 +150,6 @@ async def retrieve_episodes(
             name=record['name'],
             source_description=record['source_description'],
         )
-        for record in result.records
+        for record in result
     ]
     return list(reversed(episodes))  # Return in chronological order