graphiti-core 0.15.1__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphiti_core/cross_encoder/gemini_reranker_client.py +4 -1
- graphiti_core/driver/driver.py +2 -4
- graphiti_core/driver/falkordb_driver.py +9 -7
- graphiti_core/driver/neo4j_driver.py +14 -13
- graphiti_core/edges.py +3 -20
- graphiti_core/embedder/gemini.py +17 -5
- graphiti_core/graphiti.py +107 -57
- graphiti_core/helpers.py +0 -1
- graphiti_core/llm_client/gemini_client.py +8 -5
- graphiti_core/nodes.py +3 -22
- graphiti_core/prompts/dedupe_edges.py +5 -4
- graphiti_core/prompts/dedupe_nodes.py +3 -3
- graphiti_core/search/search_utils.py +1 -20
- graphiti_core/utils/bulk_utils.py +212 -256
- graphiti_core/utils/maintenance/community_operations.py +1 -6
- graphiti_core/utils/maintenance/edge_operations.py +35 -122
- graphiti_core/utils/maintenance/graph_data_operations.py +2 -6
- graphiti_core/utils/maintenance/node_operations.py +11 -58
- {graphiti_core-0.15.1.dist-info → graphiti_core-0.17.0.dist-info}/METADATA +19 -1
- {graphiti_core-0.15.1.dist-info → graphiti_core-0.17.0.dist-info}/RECORD +22 -22
- {graphiti_core-0.15.1.dist-info → graphiti_core-0.17.0.dist-info}/WHEEL +0 -0
- {graphiti_core-0.15.1.dist-info → graphiti_core-0.17.0.dist-info}/licenses/LICENSE +0 -0
graphiti_core/utils/bulk_utils.py

@@ -16,50 +16,40 @@ limitations under the License.
 
 import logging
 import typing
-from collections import defaultdict
 from datetime import datetime
-from math import ceil
 
-
-from pydantic import BaseModel
+import numpy as np
+from pydantic import BaseModel, Field
 from typing_extensions import Any
 
 from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
-from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
+from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
 from graphiti_core.embedder import EmbedderClient
 from graphiti_core.graph_queries import (
     get_entity_edge_save_bulk_query,
     get_entity_node_save_bulk_query,
 )
 from graphiti_core.graphiti_types import GraphitiClients
-from graphiti_core.helpers import
-from graphiti_core.llm_client import LLMClient
+from graphiti_core.helpers import normalize_l2, semaphore_gather
 from graphiti_core.models.edges.edge_db_queries import (
     EPISODIC_EDGE_SAVE_BULK,
 )
 from graphiti_core.models.nodes.node_db_queries import (
     EPISODIC_NODE_SAVE_BULK,
 )
-from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
-from graphiti_core.search.search_filters import SearchFilters
-from graphiti_core.search.search_utils import get_relevant_edges, get_relevant_nodes
-from graphiti_core.utils.datetime_utils import utc_now
+from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
 from graphiti_core.utils.maintenance.edge_operations import (
-    build_episodic_edges,
-    dedupe_edge_list,
-    dedupe_extracted_edges,
     extract_edges,
+    resolve_extracted_edge,
 )
 from graphiti_core.utils.maintenance.graph_data_operations import (
     EPISODE_WINDOW_LEN,
     retrieve_episodes,
 )
 from graphiti_core.utils.maintenance.node_operations import (
-    dedupe_extracted_nodes,
-    dedupe_node_list,
     extract_nodes,
+    resolve_extracted_nodes,
 )
-from graphiti_core.utils.maintenance.temporal_operations import extract_edge_dates
 
 logger = logging.getLogger(__name__)
 
@@ -68,6 +58,7 @@ CHUNK_SIZE = 10
 
 class RawEpisode(BaseModel):
     name: str
+    uuid: str | None = Field(default=None)
     content: str
     source_description: str
     source: EpisodeType
@@ -100,7 +91,7 @@ async def add_nodes_and_edges_bulk(
     entity_edges: list[EntityEdge],
     embedder: EmbedderClient,
 ):
-    session = driver.session(
+    session = driver.session()
     try:
         await session.execute_write(
             add_nodes_and_edges_bulk_tx,
@@ -179,233 +170,258 @@ async def add_nodes_and_edges_bulk_tx(
 async def extract_nodes_and_edges_bulk(
     clients: GraphitiClients,
     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
+    edge_type_map: dict[tuple[str, str], list[str]],
     entity_types: dict[str, BaseModel] | None = None,
     excluded_entity_types: list[str] | None = None,
-
-
+    edge_types: dict[str, BaseModel] | None = None,
+) -> tuple[list[list[EntityNode]], list[list[EntityEdge]]]:
+    extracted_nodes_bulk: list[list[EntityNode]] = await semaphore_gather(
         *[
             extract_nodes(clients, episode, previous_episodes, entity_types, excluded_entity_types)
             for episode, previous_episodes in episode_tuples
         ]
     )
 
-
-        [episode[0] for episode in episode_tuples],
-        [episode[1] for episode in episode_tuples],
-    )
-
-    extracted_edges_bulk = await semaphore_gather(
+    extracted_edges_bulk: list[list[EntityEdge]] = await semaphore_gather(
         *[
             extract_edges(
                 clients,
                 episode,
                 extracted_nodes_bulk[i],
-
-
-                episode.group_id,
+                previous_episodes,
+                edge_type_map=edge_type_map,
+                group_id=episode.group_id,
+                edge_types=edge_types,
             )
-            for i, episode in enumerate(
+            for i, (episode, previous_episodes) in enumerate(episode_tuples)
         ]
     )
 
-
-    for i, episode in enumerate(episodes):
-        episodic_edges += build_episodic_edges(extracted_nodes_bulk[i], episode, episode.created_at)
-
-    nodes: list[EntityNode] = []
-    for extracted_nodes in extracted_nodes_bulk:
-        nodes += extracted_nodes
-
-    edges: list[EntityEdge] = []
-    for extracted_edges in extracted_edges_bulk:
-        edges += extracted_edges
-
-    return nodes, edges, episodic_edges
+    return extracted_nodes_bulk, extracted_edges_bulk
 
 
 async def dedupe_nodes_bulk(
-
-
-
-
-
-
-
-    compressed_nodes, compressed_map = await compress_nodes(llm_client, nodes, uuid_map)
-
-    node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
-
-    existing_nodes_chunks: list[list[EntityNode]] = list(
-        await semaphore_gather(
-            *[get_relevant_nodes(driver, node_chunk, SearchFilters()) for node_chunk in node_chunks]
-        )
-    )
-
-    results: list[tuple[list[EntityNode], dict[str, str]]] = list(
-        await semaphore_gather(
-            *[
-                dedupe_extracted_nodes(llm_client, node_chunk, existing_nodes_chunks[i])
-                for i, node_chunk in enumerate(node_chunks)
-            ]
-        )
-    )
-
-    final_nodes: list[EntityNode] = []
-    for result in results:
-        final_nodes.extend(result[0])
-        partial_uuid_map = result[1]
-        compressed_map.update(partial_uuid_map)
-
-    return final_nodes, compressed_map
-
-
-async def dedupe_edges_bulk(
-    driver: GraphDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
-) -> list[EntityEdge]:
-    # First compress edges
-    compressed_edges = await compress_edges(llm_client, extracted_edges)
-
-    edge_chunks = [
-        compressed_edges[i : i + CHUNK_SIZE] for i in range(0, len(compressed_edges), CHUNK_SIZE)
-    ]
+    clients: GraphitiClients,
+    extracted_nodes: list[list[EntityNode]],
+    episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
+    entity_types: dict[str, BaseModel] | None = None,
+) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
+    embedder = clients.embedder
+    min_score = 0.8
 
-
-
-
-    )
+    # generate embeddings
+    await semaphore_gather(
+        *[create_entity_node_embeddings(embedder, nodes) for nodes in extracted_nodes]
     )
 
-
-
-
-
-
-
-
+    # Find similar results
+    dedupe_tuples: list[tuple[list[EntityNode], list[EntityNode]]] = []
+    for i, nodes_i in enumerate(extracted_nodes):
+        existing_nodes: list[EntityNode] = []
+        for j, nodes_j in enumerate(extracted_nodes):
+            if i == j:
+                continue
+            existing_nodes += nodes_j
+
+        candidates_i: list[EntityNode] = []
+        for node in nodes_i:
+            for existing_node in existing_nodes:
+                # Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
+                # This approach will cast a wider net than BM25, which is ideal for this use case
+                node_words = set(node.name.lower().split())
+                existing_node_words = set(existing_node.name.lower().split())
+                has_overlap = not node_words.isdisjoint(existing_node_words)
+                if has_overlap:
+                    candidates_i.append(existing_node)
+                    continue
+
+                # Check for semantic similarity even if there is no overlap
+                similarity = np.dot(
+                    normalize_l2(node.name_embedding or []),
+                    normalize_l2(existing_node.name_embedding or []),
+                )
+                if similarity >= min_score:
+                    candidates_i.append(existing_node)
+
+        dedupe_tuples.append((nodes_i, candidates_i))
+
+    # Determine Node Resolutions
+    bulk_node_resolutions: list[
+        tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]
+    ] = await semaphore_gather(
+        *[
+            resolve_extracted_nodes(
+                clients,
+                dedupe_tuple[0],
+                episode_tuples[i][0],
+                episode_tuples[i][1],
+                entity_types,
+                existing_nodes_override=dedupe_tuples[i][1],
+            )
+            for i, dedupe_tuple in enumerate(dedupe_tuples)
+        ]
     )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    # Collect all duplicate pairs sorted by uuid
+    duplicate_pairs: list[tuple[EntityNode, EntityNode]] = []
+    for _, _, duplicates in bulk_node_resolutions:
+        for duplicate in duplicates:
+            n, m = duplicate
+            if n.uuid < m.uuid:
+                duplicate_pairs.append((n, m))
+            else:
+                duplicate_pairs.append((m, n))
+
+    # Build full deduplication map
+    duplicate_map: dict[str, str] = {}
+    for value, key in duplicate_pairs:
+        if key.uuid in duplicate_map:
+            existing_value = duplicate_map[key.uuid]
+            duplicate_map[key.uuid] = value.uuid if value.uuid < existing_value else existing_value
+        else:
+            duplicate_map[key.uuid] = value.uuid
 
-
-
-) -> tuple[list[EntityNode], dict[str, str]]:
-    # We want to first compress the nodes by deduplicating nodes across each of the episodes added in bulk
-    if len(nodes) == 0:
-        return nodes, uuid_map
+    # Now we compress the duplicate_map, so that 3 -> 2 and 2 -> becomes 3 -> 1 (sorted by uuid)
+    compressed_map: dict[str, str] = compress_uuid_map(duplicate_map)
 
-
-
-
-    chunk_size = max(int(sqrt(len(nodes))), CHUNK_SIZE)
+    node_uuid_map: dict[str, EntityNode] = {
+        node.uuid: node for nodes in extracted_nodes for node in nodes
+    }
 
-
-
-
-        for i, n in enumerate(nodes)
-        for j, m in enumerate(nodes[:i])
-    ]
+    nodes_by_episode: dict[str, list[EntityNode]] = {}
+    for i, nodes in enumerate(extracted_nodes):
+        episode = episode_tuples[i][0]
 
-
-
+        nodes_by_episode[episode.uuid] = [
+            node_uuid_map[compressed_map.get(node.uuid, node.uuid)] for node in nodes
+        ]
 
-
-    node_chunks: list[list[EntityNode]] = [[] for _ in range(ceil(len(nodes) / chunk_size))]
+    return nodes_by_episode, compressed_map
 
-    # Draft the most similar nodes into the same chunk
-    while len(similarity_scores) > 0:
-        i, j, _ = similarity_scores.pop()
-        # determine if any of the nodes have already been drafted into a chunk
-        n = nodes[i]
-        m = nodes[j]
-        # make sure the shortest chunks get preference
-        node_chunks.sort(reverse=True, key=lambda chunk: len(chunk))
 
-
-
+async def dedupe_edges_bulk(
+    clients: GraphitiClients,
+    extracted_edges: list[list[EntityEdge]],
+    episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
+    _entities: list[EntityNode],
+    edge_types: dict[str, BaseModel],
+    _edge_type_map: dict[tuple[str, str], list[str]],
+) -> dict[str, list[EntityEdge]]:
+    embedder = clients.embedder
+    min_score = 0.6
+
+    # generate embeddings
+    await semaphore_gather(
+        *[create_entity_edge_embeddings(embedder, edges) for edges in extracted_edges]
+    )
 
-
-
-
+    # Find similar results
+    dedupe_tuples: list[tuple[EpisodicNode, EntityEdge, list[EntityEdge]]] = []
+    for i, edges_i in enumerate(extracted_edges):
+        existing_edges: list[EntityEdge] = []
+        for j, edges_j in enumerate(extracted_edges):
+            if i == j:
+                continue
+            existing_edges += edges_j
+
+        for edge in edges_i:
+            candidates: list[EntityEdge] = []
+            for existing_edge in existing_edges:
+                # Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
+                # This approach will cast a wider net than BM25, which is ideal for this use case
+                edge_words = set(edge.fact.lower().split())
+                existing_edge_words = set(existing_edge.fact.lower().split())
+                has_overlap = not edge_words.isdisjoint(existing_edge_words)
+                if has_overlap:
+                    candidates.append(existing_edge)
+                    continue
+
+                # Check for semantic similarity even if there is no overlap
+                similarity = np.dot(
+                    normalize_l2(edge.fact_embedding or []),
+                    normalize_l2(existing_edge.fact_embedding or []),
+                )
+                if similarity >= min_score:
+                    candidates.append(existing_edge)
+
+            dedupe_tuples.append((episode_tuples[i][0], edge, candidates))
+
+    bulk_edge_resolutions: list[
+        tuple[EntityEdge, EntityEdge, list[EntityEdge]]
+    ] = await semaphore_gather(
+        *[
+            resolve_extracted_edge(
+                clients.llm_client, edge, candidates, candidates, episode, edge_types
+            )
+            for episode, edge, candidates in dedupe_tuples
+        ]
+    )
 
-
-
-
-
+    duplicate_pairs: list[tuple[EntityEdge, EntityEdge]] = []
+    for i, (_, _, duplicates) in enumerate(bulk_edge_resolutions):
+        episode, edge, candidates = dedupe_tuples[i]
+        for duplicate in duplicates:
+            if edge.uuid < duplicate.uuid:
+                duplicate_pairs.append((edge, duplicate))
+            else:
+                duplicate_pairs.append((duplicate, edge))
+
+    # Build full deduplication map
+    duplicate_map: dict[str, str] = {}
+    for value, key in duplicate_pairs:
+        if key.uuid in duplicate_map:
+            existing_value = duplicate_map[key.uuid]
+            duplicate_map[key.uuid] = value.uuid if value.uuid < existing_value else existing_value
+        else:
+            duplicate_map[key.uuid] = value.uuid
 
-
-
-        # put n in the same chunk as m
-        node_chunks[m_chunk].append(n)
+    # Now we compress the duplicate_map, so that 3 -> 2 and 2 -> becomes 3 -> 1 (sorted by uuid)
+    compressed_map: dict[str, str] = compress_uuid_map(duplicate_map)
 
-
-
-
-            node_chunks[-1].extend([n, m])
+    edge_uuid_map: dict[str, EntityEdge] = {
+        edge.uuid: edge for edges in extracted_edges for edge in edges
+    }
 
-
-
-
+    edges_by_episode: dict[str, list[EntityEdge]] = {}
+    for i, edges in enumerate(extracted_edges):
+        episode = episode_tuples[i][0]
 
-
-
-
-        compressed_nodes += node_chunk
-        extended_map.update(uuid_map_chunk)
+        edges_by_episode[episode.uuid] = [
+            edge_uuid_map[compressed_map.get(edge.uuid, edge.uuid)] for edge in edges
+        ]
 
-
-    if len(compressed_nodes) == len(nodes):
-        compressed_uuid_map = compress_uuid_map(extended_map)
-        return compressed_nodes, compressed_uuid_map
+    return edges_by_episode
 
-    return await compress_nodes(llm_client, compressed_nodes, extended_map)
 
+def compress_uuid_map(uuid_map: dict[str, str]) -> dict[str, str]:
+    compressed_map = {}
 
-
-
-
-
-    # We build a map of the edges based on their source and target nodes.
-    edge_chunks = chunk_edges_by_nodes(edges)
+    def find_min_uuid(start: str) -> str:
+        path = []
+        visited = set()
+        curr = start
 
-
-
-
+        while curr in uuid_map and curr not in visited:
+            visited.add(curr)
+            path.append(curr)
+            curr = uuid_map[curr]
 
-
-
-        compressed_edges += edge_chunk
+        # Also include the last resolved value (could be outside the map)
+        path.append(curr)
 
-
-
-    return compressed_edges
+        # Resolve to lex smallest UUID in the path
+        min_uuid = min(path)
 
-
+        # Assign all UUIDs in the path to the min_uuid
+        for node in path:
+            compressed_map[node] = min_uuid
 
+        return min_uuid
 
-
-
-
-    for key, uuid in uuid_map.items():
-        curr_value = uuid
-        while curr_value in uuid_map:
-            curr_value = uuid_map[curr_value]
+    for key in uuid_map:
+        if key not in compressed_map:
+            find_min_uuid(key)
 
-        compressed_map[key] = curr_value
     return compressed_map
 
 
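Note: the following is an illustrative sketch, not part of the released diff. It shows how the new in-memory candidate check in the bulk dedupe path (word overlap first, cosine similarity of L2-normalized embeddings as a fallback) and the new compress_uuid_map path compression behave, based on the added code in the hunk above. The normalize_l2 helper is re-implemented here as a stand-in for graphiti_core.helpers.normalize_l2, and the names, embeddings, and UUIDs are invented for the example.

import numpy as np


def normalize_l2(v: list[float]) -> np.ndarray:
    # Stand-in for graphiti_core.helpers.normalize_l2 (assumed behavior: L2-normalize a vector).
    arr = np.array(v, dtype=float)
    norm = np.linalg.norm(arr)
    return arr if norm == 0 else arr / norm


# Candidate check used by dedupe_nodes_bulk: word overlap first, cosine similarity as a fallback.
name_a, name_b = 'Alice Smith', 'Dr. Smith'       # invented example names
emb_a, emb_b = [0.9, 0.1, 0.0], [0.8, 0.2, 0.1]   # invented example embeddings
overlap = not set(name_a.lower().split()).isdisjoint(set(name_b.lower().split()))
similarity = np.dot(normalize_l2(emb_a), normalize_l2(emb_b))
is_candidate = overlap or similarity >= 0.8       # 0.8 is the node min_score in the diff
print(is_candidate)                               # True: the names share the word 'smith'


def compress_uuid_map(uuid_map: dict[str, str]) -> dict[str, str]:
    # Same logic as the compress_uuid_map added above: every chain of duplicate
    # mappings collapses to the lexicographically smallest UUID it touches.
    compressed_map: dict[str, str] = {}

    def find_min_uuid(start: str) -> str:
        path, visited, curr = [], set(), start
        while curr in uuid_map and curr not in visited:
            visited.add(curr)
            path.append(curr)
            curr = uuid_map[curr]
        path.append(curr)             # include the final resolved value
        min_uuid = min(path)
        for node in path:
            compressed_map[node] = min_uuid
        return min_uuid

    for key in uuid_map:
        if key not in compressed_map:
            find_min_uuid(key)
    return compressed_map


print(compress_uuid_map({'c': 'b', 'b': 'a'}))    # {'c': 'a', 'b': 'a', 'a': 'a'}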
@@ -420,63 +436,3 @@ def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
         edge.target_node_uuid = uuid_map.get(target_uuid, target_uuid)
 
     return edges
-
-
-async def extract_edge_dates_bulk(
-    llm_client: LLMClient,
-    extracted_edges: list[EntityEdge],
-    episode_pairs: list[tuple[EpisodicNode, list[EpisodicNode]]],
-) -> list[EntityEdge]:
-    edges: list[EntityEdge] = []
-    # confirm that all of our edges have at least one episode
-    for edge in extracted_edges:
-        if edge.episodes is not None and len(edge.episodes) > 0:
-            edges.append(edge)
-
-    episode_uuid_map: dict[str, tuple[EpisodicNode, list[EpisodicNode]]] = {
-        episode.uuid: (episode, previous_episodes) for episode, previous_episodes in episode_pairs
-    }
-
-    results = await semaphore_gather(
-        *[
-            extract_edge_dates(
-                llm_client,
-                edge,
-                episode_uuid_map[edge.episodes[0]][0],  # type: ignore
-                episode_uuid_map[edge.episodes[0]][1],  # type: ignore
-            )
-            for edge in edges
-        ]
-    )
-
-    for i, result in enumerate(results):
-        valid_at = result[0]
-        invalid_at = result[1]
-        edge = edges[i]
-
-        edge.valid_at = valid_at
-        edge.invalid_at = invalid_at
-        if edge.invalid_at:
-            edge.expired_at = utc_now()
-
-    return edges
-
-
-def chunk_edges_by_nodes(edges: list[EntityEdge]) -> list[list[EntityEdge]]:
-    # We only want to dedupe edges that are between the same pair of nodes
-    # We build a map of the edges based on their source and target nodes.
-    edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list)
-    for edge in edges:
-        # We drop loop edges
-        if edge.source_node_uuid == edge.target_node_uuid:
-            continue
-
-        # Keep the order of the two nodes consistent, we want to be direction agnostic during edge resolution
-        pointers = [edge.source_node_uuid, edge.target_node_uuid]
-        pointers.sort()
-
-        edge_chunk_map[pointers[0] + pointers[1]].append(edge)
-
-    edge_chunks = [chunk for chunk in edge_chunk_map.values()]
-
-    return edge_chunks
graphiti_core/utils/maintenance/community_operations.py

@@ -7,7 +7,7 @@ from pydantic import BaseModel
 from graphiti_core.driver.driver import GraphDriver
 from graphiti_core.edges import CommunityEdge
 from graphiti_core.embedder import EmbedderClient
-from graphiti_core.helpers import
+from graphiti_core.helpers import semaphore_gather
 from graphiti_core.llm_client import LLMClient
 from graphiti_core.nodes import CommunityNode, EntityNode, get_community_node_from_record
 from graphiti_core.prompts import prompt_library
@@ -37,7 +37,6 @@ async def get_community_clusters(
         RETURN
             collect(DISTINCT n.group_id) AS group_ids
         """,
-        database_=DEFAULT_DATABASE,
     )
 
     group_ids = group_id_values[0]['group_ids'] if group_id_values else []
@@ -56,7 +55,6 @@ async def get_community_clusters(
         """,
         uuid=node.uuid,
         group_id=group_id,
-        database_=DEFAULT_DATABASE,
     )
 
     projection[node.uuid] = [
@@ -224,7 +222,6 @@ async def remove_communities(driver: GraphDriver):
         MATCH (c:Community)
         DETACH DELETE c
         """,
-        database_=DEFAULT_DATABASE,
     )
 
 
@@ -243,7 +240,6 @@ async def determine_entity_community(
         c.summary AS summary
         """,
         entity_uuid=entity.uuid,
-        database_=DEFAULT_DATABASE,
     )
 
     if len(records) > 0:
@@ -261,7 +257,6 @@ async def determine_entity_community(
         c.summary AS summary
         """,
         entity_uuid=entity.uuid,
-        database_=DEFAULT_DATABASE,
     )
 
     communities: list[CommunityNode] = [