graphiti-core 0.10.4__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/edges.py +32 -57
- graphiti_core/embedder/client.py +3 -0
- graphiti_core/embedder/gemini.py +10 -0
- graphiti_core/embedder/openai.py +6 -0
- graphiti_core/embedder/voyage.py +7 -0
- graphiti_core/graphiti.py +42 -138
- graphiti_core/graphiti_types.py +31 -0
- graphiti_core/helpers.py +6 -1
- graphiti_core/llm_client/anthropic_client.py +4 -1
- graphiti_core/llm_client/client.py +4 -1
- graphiti_core/llm_client/gemini_client.py +4 -1
- graphiti_core/llm_client/openai_client.py +4 -1
- graphiti_core/llm_client/openai_generic_client.py +4 -1
- graphiti_core/models/edges/edge_db_queries.py +1 -1
- graphiti_core/nodes.py +10 -10
- graphiti_core/prompts/dedupe_edges.py +5 -7
- graphiti_core/prompts/dedupe_nodes.py +8 -21
- graphiti_core/prompts/extract_edges.py +61 -26
- graphiti_core/prompts/extract_nodes.py +89 -18
- graphiti_core/prompts/invalidate_edges.py +11 -11
- graphiti_core/search/search.py +13 -5
- graphiti_core/search/search_utils.py +206 -98
- graphiti_core/utils/bulk_utils.py +10 -7
- graphiti_core/utils/maintenance/edge_operations.py +88 -40
- graphiti_core/utils/maintenance/graph_data_operations.py +20 -6
- graphiti_core/utils/maintenance/node_operations.py +216 -223
- graphiti_core/utils/maintenance/temporal_operations.py +4 -11
- {graphiti_core-0.10.4.dist-info → graphiti_core-0.11.0.dist-info}/METADATA +25 -11
- {graphiti_core-0.10.4.dist-info → graphiti_core-0.11.0.dist-info}/RECORD +31 -30
- {graphiti_core-0.10.4.dist-info → graphiti_core-0.11.0.dist-info}/LICENSE +0 -0
- {graphiti_core-0.10.4.dist-info → graphiti_core-0.11.0.dist-info}/WHEEL +0 -0
graphiti_core/edges.py
CHANGED
|
@@ -37,6 +37,21 @@ from graphiti_core.nodes import Node
|
|
|
37
37
|
|
|
38
38
|
logger = logging.getLogger(__name__)
|
|
39
39
|
|
|
40
|
+
ENTITY_EDGE_RETURN: LiteralString = """
|
|
41
|
+
RETURN
|
|
42
|
+
e.uuid AS uuid,
|
|
43
|
+
startNode(e).uuid AS source_node_uuid,
|
|
44
|
+
endNode(e).uuid AS target_node_uuid,
|
|
45
|
+
e.created_at AS created_at,
|
|
46
|
+
e.name AS name,
|
|
47
|
+
e.group_id AS group_id,
|
|
48
|
+
e.fact AS fact,
|
|
49
|
+
e.fact_embedding AS fact_embedding,
|
|
50
|
+
e.episodes AS episodes,
|
|
51
|
+
e.expired_at AS expired_at,
|
|
52
|
+
e.valid_at AS valid_at,
|
|
53
|
+
e.invalid_at AS invalid_at"""
|
|
54
|
+
|
|
40
55
|
|
|
41
56
|
class Edge(BaseModel, ABC):
|
|
42
57
|
uuid: str = Field(default_factory=lambda: str(uuid4()))
|
|
@@ -234,20 +249,8 @@ class EntityEdge(Edge):
|
|
|
234
249
|
records, _, _ = await driver.execute_query(
|
|
235
250
|
"""
|
|
236
251
|
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
n.uuid AS source_node_uuid,
|
|
240
|
-
m.uuid AS target_node_uuid,
|
|
241
|
-
e.created_at AS created_at,
|
|
242
|
-
e.name AS name,
|
|
243
|
-
e.group_id AS group_id,
|
|
244
|
-
e.fact AS fact,
|
|
245
|
-
e.fact_embedding AS fact_embedding,
|
|
246
|
-
e.episodes AS episodes,
|
|
247
|
-
e.expired_at AS expired_at,
|
|
248
|
-
e.valid_at AS valid_at,
|
|
249
|
-
e.invalid_at AS invalid_at
|
|
250
|
-
""",
|
|
252
|
+
"""
|
|
253
|
+
+ ENTITY_EDGE_RETURN,
|
|
251
254
|
uuid=uuid,
|
|
252
255
|
database_=DEFAULT_DATABASE,
|
|
253
256
|
routing_='r',
|
|
@@ -268,20 +271,8 @@ class EntityEdge(Edge):
|
|
|
268
271
|
"""
|
|
269
272
|
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
270
273
|
WHERE e.uuid IN $uuids
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
n.uuid AS source_node_uuid,
|
|
274
|
-
m.uuid AS target_node_uuid,
|
|
275
|
-
e.created_at AS created_at,
|
|
276
|
-
e.name AS name,
|
|
277
|
-
e.group_id AS group_id,
|
|
278
|
-
e.fact AS fact,
|
|
279
|
-
e.fact_embedding AS fact_embedding,
|
|
280
|
-
e.episodes AS episodes,
|
|
281
|
-
e.expired_at AS expired_at,
|
|
282
|
-
e.valid_at AS valid_at,
|
|
283
|
-
e.invalid_at AS invalid_at
|
|
284
|
-
""",
|
|
274
|
+
"""
|
|
275
|
+
+ ENTITY_EDGE_RETURN,
|
|
285
276
|
uuids=uuids,
|
|
286
277
|
database_=DEFAULT_DATABASE,
|
|
287
278
|
routing_='r',
|
|
@@ -308,20 +299,8 @@ class EntityEdge(Edge):
|
|
|
308
299
|
WHERE e.group_id IN $group_ids
|
|
309
300
|
"""
|
|
310
301
|
+ cursor_query
|
|
302
|
+
+ ENTITY_EDGE_RETURN
|
|
311
303
|
+ """
|
|
312
|
-
RETURN
|
|
313
|
-
e.uuid AS uuid,
|
|
314
|
-
n.uuid AS source_node_uuid,
|
|
315
|
-
m.uuid AS target_node_uuid,
|
|
316
|
-
e.created_at AS created_at,
|
|
317
|
-
e.name AS name,
|
|
318
|
-
e.group_id AS group_id,
|
|
319
|
-
e.fact AS fact,
|
|
320
|
-
e.fact_embedding AS fact_embedding,
|
|
321
|
-
e.episodes AS episodes,
|
|
322
|
-
e.expired_at AS expired_at,
|
|
323
|
-
e.valid_at AS valid_at,
|
|
324
|
-
e.invalid_at AS invalid_at
|
|
325
304
|
ORDER BY e.uuid DESC
|
|
326
305
|
"""
|
|
327
306
|
+ limit_query,
|
|
@@ -340,22 +319,12 @@ class EntityEdge(Edge):
|
|
|
340
319
|
|
|
341
320
|
@classmethod
|
|
342
321
|
async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
|
|
343
|
-
query: LiteralString =
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
e.created_at AS created_at,
|
|
350
|
-
e.name AS name,
|
|
351
|
-
e.group_id AS group_id,
|
|
352
|
-
e.fact AS fact,
|
|
353
|
-
e.fact_embedding AS fact_embedding,
|
|
354
|
-
e.episodes AS episodes,
|
|
355
|
-
e.expired_at AS expired_at,
|
|
356
|
-
e.valid_at AS valid_at,
|
|
357
|
-
e.invalid_at AS invalid_at
|
|
358
|
-
"""
|
|
322
|
+
query: LiteralString = (
|
|
323
|
+
"""
|
|
324
|
+
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
|
325
|
+
"""
|
|
326
|
+
+ ENTITY_EDGE_RETURN
|
|
327
|
+
)
|
|
359
328
|
records, _, _ = await driver.execute_query(
|
|
360
329
|
query, node_uuid=node_uuid, database_=DEFAULT_DATABASE, routing_='r'
|
|
361
330
|
)
|
|
@@ -499,3 +468,9 @@ def get_community_edge_from_record(record: Any):
|
|
|
499
468
|
target_node_uuid=record['target_node_uuid'],
|
|
500
469
|
created_at=record['created_at'].to_native(),
|
|
501
470
|
)
|
|
471
|
+
|
|
472
|
+
|
|
473
|
+
async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]):
|
|
474
|
+
fact_embeddings = await embedder.create_batch([edge.fact for edge in edges])
|
|
475
|
+
for edge, fact_embedding in zip(edges, fact_embeddings, strict=True):
|
|
476
|
+
edge.fact_embedding = fact_embedding
|
graphiti_core/embedder/client.py
CHANGED
graphiti_core/embedder/gemini.py
CHANGED
|
@@ -66,3 +66,13 @@ class GeminiEmbedder(EmbedderClient):
|
|
|
66
66
|
)
|
|
67
67
|
|
|
68
68
|
return result.embeddings[0].values
|
|
69
|
+
|
|
70
|
+
async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
|
|
71
|
+
# Generate embeddings
|
|
72
|
+
result = await self.client.aio.models.embed_content(
|
|
73
|
+
model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
|
|
74
|
+
contents=input_data_list,
|
|
75
|
+
config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim),
|
|
76
|
+
)
|
|
77
|
+
|
|
78
|
+
return [embedding.values for embedding in result.embeddings]
|
graphiti_core/embedder/openai.py
CHANGED
|
@@ -58,3 +58,9 @@ class OpenAIEmbedder(EmbedderClient):
|
|
|
58
58
|
input=input_data, model=self.config.embedding_model
|
|
59
59
|
)
|
|
60
60
|
return result.data[0].embedding[: self.config.embedding_dim]
|
|
61
|
+
|
|
62
|
+
async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
|
|
63
|
+
result = await self.client.embeddings.create(
|
|
64
|
+
input=input_data_list, model=self.config.embedding_model
|
|
65
|
+
)
|
|
66
|
+
return [embedding.embedding[: self.config.embedding_dim] for embedding in result.data]
|
graphiti_core/embedder/voyage.py
CHANGED
|
@@ -56,3 +56,10 @@ class VoyageAIEmbedder(EmbedderClient):
|
|
|
56
56
|
|
|
57
57
|
result = await self.client.embed(input_list, model=self.config.embedding_model)
|
|
58
58
|
return [float(x) for x in result.embeddings[0][: self.config.embedding_dim]]
|
|
59
|
+
|
|
60
|
+
async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
|
|
61
|
+
result = await self.client.embed(input_data_list, model=self.config.embedding_model)
|
|
62
|
+
return [
|
|
63
|
+
[float(x) for x in embedding[: self.config.embedding_dim]]
|
|
64
|
+
for embedding in result.embeddings
|
|
65
|
+
]
|
graphiti_core/graphiti.py
CHANGED
|
@@ -27,6 +27,7 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient
|
|
|
27
27
|
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
|
28
28
|
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
|
29
29
|
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
|
30
|
+
from graphiti_core.graphiti_types import GraphitiClients
|
|
30
31
|
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
|
|
31
32
|
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
|
32
33
|
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
|
|
@@ -42,7 +43,6 @@ from graphiti_core.search.search_utils import (
|
|
|
42
43
|
RELEVANT_SCHEMA_LIMIT,
|
|
43
44
|
get_mentioned_nodes,
|
|
44
45
|
get_relevant_edges,
|
|
45
|
-
get_relevant_nodes,
|
|
46
46
|
)
|
|
47
47
|
from graphiti_core.utils.bulk_utils import (
|
|
48
48
|
RawEpisode,
|
|
@@ -72,7 +72,11 @@ from graphiti_core.utils.maintenance.graph_data_operations import (
|
|
|
72
72
|
build_indices_and_constraints,
|
|
73
73
|
retrieve_episodes,
|
|
74
74
|
)
|
|
75
|
-
from graphiti_core.utils.maintenance.node_operations import
|
|
75
|
+
from graphiti_core.utils.maintenance.node_operations import (
|
|
76
|
+
extract_attributes_from_nodes,
|
|
77
|
+
extract_nodes,
|
|
78
|
+
resolve_extracted_nodes,
|
|
79
|
+
)
|
|
76
80
|
from graphiti_core.utils.maintenance.temporal_operations import get_edge_contradictions
|
|
77
81
|
from graphiti_core.utils.ontology_utils.entity_types_utils import validate_entity_types
|
|
78
82
|
|
|
@@ -150,6 +154,13 @@ class Graphiti:
|
|
|
150
154
|
else:
|
|
151
155
|
self.cross_encoder = OpenAIRerankerClient()
|
|
152
156
|
|
|
157
|
+
self.clients = GraphitiClients(
|
|
158
|
+
driver=self.driver,
|
|
159
|
+
llm_client=self.llm_client,
|
|
160
|
+
embedder=self.embedder,
|
|
161
|
+
cross_encoder=self.cross_encoder,
|
|
162
|
+
)
|
|
163
|
+
|
|
153
164
|
async def close(self):
|
|
154
165
|
"""
|
|
155
166
|
Close the connection to the Neo4j database.
|
|
@@ -222,6 +233,7 @@ class Graphiti:
|
|
|
222
233
|
reference_time: datetime,
|
|
223
234
|
last_n: int = EPISODE_WINDOW_LEN,
|
|
224
235
|
group_ids: list[str] | None = None,
|
|
236
|
+
source: EpisodeType | None = None,
|
|
225
237
|
) -> list[EpisodicNode]:
|
|
226
238
|
"""
|
|
227
239
|
Retrieve the last n episodic nodes from the graph.
|
|
@@ -248,7 +260,7 @@ class Graphiti:
|
|
|
248
260
|
The actual retrieval is performed by the `retrieve_episodes` function
|
|
249
261
|
from the `graphiti_core.utils` module.
|
|
250
262
|
"""
|
|
251
|
-
return await retrieve_episodes(self.driver, reference_time, last_n, group_ids)
|
|
263
|
+
return await retrieve_episodes(self.driver, reference_time, last_n, group_ids, source)
|
|
252
264
|
|
|
253
265
|
async def add_episode(
|
|
254
266
|
self,
|
|
@@ -314,15 +326,16 @@ class Graphiti:
|
|
|
314
326
|
"""
|
|
315
327
|
try:
|
|
316
328
|
start = time()
|
|
317
|
-
|
|
318
|
-
entity_edges: list[EntityEdge] = []
|
|
319
329
|
now = utc_now()
|
|
320
330
|
|
|
321
331
|
validate_entity_types(entity_types)
|
|
322
332
|
|
|
323
333
|
previous_episodes = (
|
|
324
334
|
await self.retrieve_episodes(
|
|
325
|
-
reference_time,
|
|
335
|
+
reference_time,
|
|
336
|
+
last_n=RELEVANT_SCHEMA_LIMIT,
|
|
337
|
+
group_ids=[group_id],
|
|
338
|
+
source=source,
|
|
326
339
|
)
|
|
327
340
|
if previous_episode_uuids is None
|
|
328
341
|
else await EpisodicNode.get_by_uuids(self.driver, previous_episode_uuids)
|
|
@@ -346,132 +359,36 @@ class Graphiti:
|
|
|
346
359
|
# Extract entities as nodes
|
|
347
360
|
|
|
348
361
|
extracted_nodes = await extract_nodes(
|
|
349
|
-
self.
|
|
362
|
+
self.clients, episode, previous_episodes, entity_types
|
|
350
363
|
)
|
|
351
|
-
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
|
|
352
|
-
|
|
353
|
-
# Calculate Embeddings
|
|
354
364
|
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
)
|
|
358
|
-
|
|
359
|
-
# Find relevant nodes already in the graph
|
|
360
|
-
existing_nodes_lists: list[list[EntityNode]] = list(
|
|
361
|
-
await semaphore_gather(
|
|
362
|
-
*[
|
|
363
|
-
get_relevant_nodes(self.driver, SearchFilters(), [node])
|
|
364
|
-
for node in extracted_nodes
|
|
365
|
-
]
|
|
366
|
-
)
|
|
367
|
-
)
|
|
368
|
-
|
|
369
|
-
# Resolve extracted nodes with nodes already in the graph and extract facts
|
|
370
|
-
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
|
|
371
|
-
|
|
372
|
-
(mentioned_nodes, uuid_map), extracted_edges = await semaphore_gather(
|
|
365
|
+
# Extract edges and resolve nodes
|
|
366
|
+
(nodes, uuid_map), extracted_edges = await semaphore_gather(
|
|
373
367
|
resolve_extracted_nodes(
|
|
374
|
-
self.
|
|
368
|
+
self.clients,
|
|
375
369
|
extracted_nodes,
|
|
376
|
-
existing_nodes_lists,
|
|
377
370
|
episode,
|
|
378
371
|
previous_episodes,
|
|
379
372
|
entity_types,
|
|
380
373
|
),
|
|
381
|
-
extract_edges(
|
|
382
|
-
self.llm_client, episode, extracted_nodes, previous_episodes, group_id
|
|
383
|
-
),
|
|
384
|
-
)
|
|
385
|
-
logger.debug(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
|
|
386
|
-
nodes = mentioned_nodes
|
|
387
|
-
|
|
388
|
-
extracted_edges_with_resolved_pointers = resolve_edge_pointers(
|
|
389
|
-
extracted_edges, uuid_map
|
|
374
|
+
extract_edges(self.clients, episode, extracted_nodes, previous_episodes, group_id),
|
|
390
375
|
)
|
|
391
376
|
|
|
392
|
-
|
|
393
|
-
await semaphore_gather(
|
|
394
|
-
*[
|
|
395
|
-
edge.generate_embedding(self.embedder)
|
|
396
|
-
for edge in extracted_edges_with_resolved_pointers
|
|
397
|
-
]
|
|
398
|
-
)
|
|
377
|
+
edges = resolve_edge_pointers(extracted_edges, uuid_map)
|
|
399
378
|
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
edge.target_node_uuid,
|
|
409
|
-
RELEVANT_SCHEMA_LIMIT,
|
|
410
|
-
)
|
|
411
|
-
for edge in extracted_edges_with_resolved_pointers
|
|
412
|
-
]
|
|
413
|
-
)
|
|
414
|
-
)
|
|
415
|
-
logger.debug(
|
|
416
|
-
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_list for e in edges_lst]}'
|
|
417
|
-
)
|
|
418
|
-
logger.debug(
|
|
419
|
-
f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges_with_resolved_pointers]}'
|
|
420
|
-
)
|
|
421
|
-
|
|
422
|
-
existing_source_edges_list: list[list[EntityEdge]] = list(
|
|
423
|
-
await semaphore_gather(
|
|
424
|
-
*[
|
|
425
|
-
get_relevant_edges(
|
|
426
|
-
self.driver,
|
|
427
|
-
[edge],
|
|
428
|
-
edge.source_node_uuid,
|
|
429
|
-
None,
|
|
430
|
-
RELEVANT_SCHEMA_LIMIT,
|
|
431
|
-
)
|
|
432
|
-
for edge in extracted_edges_with_resolved_pointers
|
|
433
|
-
]
|
|
434
|
-
)
|
|
435
|
-
)
|
|
436
|
-
|
|
437
|
-
existing_target_edges_list: list[list[EntityEdge]] = list(
|
|
438
|
-
await semaphore_gather(
|
|
439
|
-
*[
|
|
440
|
-
get_relevant_edges(
|
|
441
|
-
self.driver,
|
|
442
|
-
[edge],
|
|
443
|
-
None,
|
|
444
|
-
edge.target_node_uuid,
|
|
445
|
-
RELEVANT_SCHEMA_LIMIT,
|
|
446
|
-
)
|
|
447
|
-
for edge in extracted_edges_with_resolved_pointers
|
|
448
|
-
]
|
|
449
|
-
)
|
|
450
|
-
)
|
|
451
|
-
|
|
452
|
-
existing_edges_list: list[list[EntityEdge]] = [
|
|
453
|
-
source_lst + target_lst
|
|
454
|
-
for source_lst, target_lst in zip(
|
|
455
|
-
existing_source_edges_list, existing_target_edges_list, strict=False
|
|
456
|
-
)
|
|
457
|
-
]
|
|
458
|
-
|
|
459
|
-
resolved_edges, invalidated_edges = await resolve_extracted_edges(
|
|
460
|
-
self.llm_client,
|
|
461
|
-
extracted_edges_with_resolved_pointers,
|
|
462
|
-
related_edges_list,
|
|
463
|
-
existing_edges_list,
|
|
464
|
-
episode,
|
|
465
|
-
previous_episodes,
|
|
379
|
+
(resolved_edges, invalidated_edges), hydrated_nodes = await semaphore_gather(
|
|
380
|
+
resolve_extracted_edges(
|
|
381
|
+
self.clients,
|
|
382
|
+
edges,
|
|
383
|
+
),
|
|
384
|
+
extract_attributes_from_nodes(
|
|
385
|
+
self.clients, nodes, episode, previous_episodes, entity_types
|
|
386
|
+
),
|
|
466
387
|
)
|
|
467
388
|
|
|
468
|
-
entity_edges
|
|
389
|
+
entity_edges = resolved_edges + invalidated_edges
|
|
469
390
|
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
episodic_edges: list[EpisodicEdge] = build_episodic_edges(mentioned_nodes, episode, now)
|
|
473
|
-
|
|
474
|
-
logger.debug(f'Built episodic edges: {episodic_edges}')
|
|
391
|
+
episodic_edges = build_episodic_edges(nodes, episode, now)
|
|
475
392
|
|
|
476
393
|
episode.entity_edges = [edge.uuid for edge in entity_edges]
|
|
477
394
|
|
|
@@ -565,7 +482,7 @@ class Graphiti:
|
|
|
565
482
|
extracted_nodes,
|
|
566
483
|
extracted_edges,
|
|
567
484
|
episodic_edges,
|
|
568
|
-
) = await extract_nodes_and_edges_bulk(self.
|
|
485
|
+
) = await extract_nodes_and_edges_bulk(self.clients, episode_pairs)
|
|
569
486
|
|
|
570
487
|
# Generate embeddings
|
|
571
488
|
await semaphore_gather(
|
|
@@ -684,9 +601,7 @@ class Graphiti:
|
|
|
684
601
|
|
|
685
602
|
edges = (
|
|
686
603
|
await search(
|
|
687
|
-
self.
|
|
688
|
-
self.embedder,
|
|
689
|
-
self.cross_encoder,
|
|
604
|
+
self.clients,
|
|
690
605
|
query,
|
|
691
606
|
group_ids,
|
|
692
607
|
search_config,
|
|
@@ -728,9 +643,7 @@ class Graphiti:
|
|
|
728
643
|
"""
|
|
729
644
|
|
|
730
645
|
return await search(
|
|
731
|
-
self.
|
|
732
|
-
self.embedder,
|
|
733
|
-
self.cross_encoder,
|
|
646
|
+
self.clients,
|
|
734
647
|
query,
|
|
735
648
|
group_ids,
|
|
736
649
|
config,
|
|
@@ -761,26 +674,17 @@ class Graphiti:
|
|
|
761
674
|
await edge.generate_embedding(self.embedder)
|
|
762
675
|
|
|
763
676
|
resolved_nodes, uuid_map = await resolve_extracted_nodes(
|
|
764
|
-
self.
|
|
677
|
+
self.clients,
|
|
765
678
|
[source_node, target_node],
|
|
766
|
-
[
|
|
767
|
-
await get_relevant_nodes(self.driver, SearchFilters(), [source_node]),
|
|
768
|
-
await get_relevant_nodes(self.driver, SearchFilters(), [target_node]),
|
|
769
|
-
],
|
|
770
679
|
)
|
|
771
680
|
|
|
772
681
|
updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
|
|
773
682
|
|
|
774
|
-
related_edges = await get_relevant_edges(
|
|
775
|
-
self.driver,
|
|
776
|
-
[updated_edge],
|
|
777
|
-
source_node_uuid=resolved_nodes[0].uuid,
|
|
778
|
-
target_node_uuid=resolved_nodes[1].uuid,
|
|
779
|
-
)
|
|
683
|
+
related_edges = await get_relevant_edges(self.driver, [updated_edge], SearchFilters(), 0.8)
|
|
780
684
|
|
|
781
|
-
resolved_edge = await dedupe_extracted_edge(self.llm_client, updated_edge, related_edges)
|
|
685
|
+
resolved_edge = await dedupe_extracted_edge(self.llm_client, updated_edge, related_edges[0])
|
|
782
686
|
|
|
783
|
-
contradicting_edges = await get_edge_contradictions(self.llm_client, edge, related_edges)
|
|
687
|
+
contradicting_edges = await get_edge_contradictions(self.llm_client, edge, related_edges[0])
|
|
784
688
|
invalidated_edges = resolve_edge_contradictions(resolved_edge, contradicting_edges)
|
|
785
689
|
|
|
786
690
|
await add_nodes_and_edges_bulk(
|
graphiti_core/graphiti_types.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from neo4j import AsyncDriver
|
|
18
|
+
from pydantic import BaseModel, ConfigDict
|
|
19
|
+
|
|
20
|
+
from graphiti_core.cross_encoder import CrossEncoderClient
|
|
21
|
+
from graphiti_core.embedder import EmbedderClient
|
|
22
|
+
from graphiti_core.llm_client import LLMClient
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class GraphitiClients(BaseModel):
|
|
26
|
+
driver: AsyncDriver
|
|
27
|
+
llm_client: LLMClient
|
|
28
|
+
embedder: EmbedderClient
|
|
29
|
+
cross_encoder: CrossEncoderClient
|
|
30
|
+
|
|
31
|
+
model_config = ConfigDict(arbitrary_types_allowed=True)
|
graphiti_core/helpers.py
CHANGED
|
@@ -22,15 +22,20 @@ from datetime import datetime
|
|
|
22
22
|
import numpy as np
|
|
23
23
|
from dotenv import load_dotenv
|
|
24
24
|
from neo4j import time as neo4j_time
|
|
25
|
+
from typing_extensions import LiteralString
|
|
25
26
|
|
|
26
27
|
load_dotenv()
|
|
27
28
|
|
|
28
29
|
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
|
|
29
30
|
USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
|
|
30
31
|
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
|
|
31
|
-
MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS',
|
|
32
|
+
MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0))
|
|
32
33
|
DEFAULT_PAGE_LIMIT = 20
|
|
33
34
|
|
|
35
|
+
RUNTIME_QUERY: LiteralString = (
|
|
36
|
+
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
|
|
37
|
+
)
|
|
38
|
+
|
|
34
39
|
|
|
35
40
|
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
|
|
36
41
|
return neo_date.to_native() if neo_date else None
|
graphiti_core/llm_client/anthropic_client.py
CHANGED
|
@@ -262,7 +262,7 @@ class AnthropicClient(LLMClient):
|
|
|
262
262
|
self,
|
|
263
263
|
messages: list[Message],
|
|
264
264
|
response_model: type[BaseModel] | None = None,
|
|
265
|
-
max_tokens: int =
|
|
265
|
+
max_tokens: int | None = None,
|
|
266
266
|
) -> dict[str, typing.Any]:
|
|
267
267
|
"""
|
|
268
268
|
Generate a response from the LLM.
|
|
@@ -280,6 +280,9 @@ class AnthropicClient(LLMClient):
|
|
|
280
280
|
RefusalError: If the LLM refuses to respond.
|
|
281
281
|
Exception: If an error occurs during the generation process.
|
|
282
282
|
"""
|
|
283
|
+
if max_tokens is None:
|
|
284
|
+
max_tokens = self.max_tokens
|
|
285
|
+
|
|
283
286
|
retry_count = 0
|
|
284
287
|
max_retries = 2
|
|
285
288
|
last_error: Exception | None = None
|
graphiti_core/llm_client/client.py
CHANGED
|
@@ -127,8 +127,11 @@ class LLMClient(ABC):
|
|
|
127
127
|
self,
|
|
128
128
|
messages: list[Message],
|
|
129
129
|
response_model: type[BaseModel] | None = None,
|
|
130
|
-
max_tokens: int =
|
|
130
|
+
max_tokens: int | None = None,
|
|
131
131
|
) -> dict[str, typing.Any]:
|
|
132
|
+
if max_tokens is None:
|
|
133
|
+
max_tokens = self.max_tokens
|
|
134
|
+
|
|
132
135
|
if response_model is not None:
|
|
133
136
|
serialized_model = json.dumps(response_model.model_json_schema())
|
|
134
137
|
messages[
|
graphiti_core/llm_client/gemini_client.py
CHANGED
|
@@ -166,7 +166,7 @@ class GeminiClient(LLMClient):
|
|
|
166
166
|
self,
|
|
167
167
|
messages: list[Message],
|
|
168
168
|
response_model: type[BaseModel] | None = None,
|
|
169
|
-
max_tokens: int =
|
|
169
|
+
max_tokens: int | None = None,
|
|
170
170
|
) -> dict[str, typing.Any]:
|
|
171
171
|
"""
|
|
172
172
|
Generate a response from the Gemini language model.
|
|
@@ -180,6 +180,9 @@ class GeminiClient(LLMClient):
|
|
|
180
180
|
Returns:
|
|
181
181
|
dict[str, typing.Any]: The response from the language model.
|
|
182
182
|
"""
|
|
183
|
+
if max_tokens is None:
|
|
184
|
+
max_tokens = self.max_tokens
|
|
185
|
+
|
|
183
186
|
# Call the internal _generate_response method
|
|
184
187
|
return await self._generate_response(
|
|
185
188
|
messages=messages, response_model=response_model, max_tokens=max_tokens
|
graphiti_core/llm_client/openai_client.py
CHANGED
|
@@ -131,8 +131,11 @@ class OpenAIClient(LLMClient):
|
|
|
131
131
|
self,
|
|
132
132
|
messages: list[Message],
|
|
133
133
|
response_model: type[BaseModel] | None = None,
|
|
134
|
-
max_tokens: int =
|
|
134
|
+
max_tokens: int | None = None,
|
|
135
135
|
) -> dict[str, typing.Any]:
|
|
136
|
+
if max_tokens is None:
|
|
137
|
+
max_tokens = self.max_tokens
|
|
138
|
+
|
|
136
139
|
retry_count = 0
|
|
137
140
|
last_error = None
|
|
138
141
|
|
graphiti_core/llm_client/openai_generic_client.py
CHANGED
|
@@ -117,8 +117,11 @@ class OpenAIGenericClient(LLMClient):
|
|
|
117
117
|
self,
|
|
118
118
|
messages: list[Message],
|
|
119
119
|
response_model: type[BaseModel] | None = None,
|
|
120
|
-
max_tokens: int =
|
|
120
|
+
max_tokens: int | None = None,
|
|
121
121
|
) -> dict[str, typing.Any]:
|
|
122
|
+
if max_tokens is None:
|
|
123
|
+
max_tokens = self.max_tokens
|
|
124
|
+
|
|
122
125
|
retry_count = 0
|
|
123
126
|
last_error = None
|
|
124
127
|
|
graphiti_core/models/edges/edge_db_queries.py
CHANGED
|
@@ -47,7 +47,7 @@ ENTITY_EDGE_SAVE_BULK = """
|
|
|
47
47
|
SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
|
|
48
48
|
created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at}
|
|
49
49
|
WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding)
|
|
50
|
-
RETURN
|
|
50
|
+
RETURN edge.uuid AS uuid
|
|
51
51
|
"""
|
|
52
52
|
|
|
53
53
|
COMMUNITY_EDGE_SAVE = """
|
graphiti_core/nodes.py
CHANGED
|
@@ -39,8 +39,6 @@ from graphiti_core.utils.datetime_utils import utc_now
|
|
|
39
39
|
logger = logging.getLogger(__name__)
|
|
40
40
|
|
|
41
41
|
ENTITY_NODE_RETURN: LiteralString = """
|
|
42
|
-
OPTIONAL MATCH (e:Episodic)-[r:MENTIONS]->(n)
|
|
43
|
-
WITH n, collect(e.uuid) AS episodes
|
|
44
42
|
RETURN
|
|
45
43
|
n.uuid As uuid,
|
|
46
44
|
n.name AS name,
|
|
@@ -49,8 +47,8 @@ ENTITY_NODE_RETURN: LiteralString = """
|
|
|
49
47
|
n.created_at AS created_at,
|
|
50
48
|
n.summary AS summary,
|
|
51
49
|
labels(n) AS labels,
|
|
52
|
-
properties(n) AS attributes
|
|
53
|
-
|
|
50
|
+
properties(n) AS attributes
|
|
51
|
+
"""
|
|
54
52
|
|
|
55
53
|
|
|
56
54
|
class EpisodeType(Enum):
|
|
@@ -294,9 +292,6 @@ class EpisodicNode(Node):
|
|
|
294
292
|
class EntityNode(Node):
|
|
295
293
|
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
|
|
296
294
|
summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
|
|
297
|
-
episodes: list[str] | None = Field(
|
|
298
|
-
default=None, description='List of episode uuids that mention this node.'
|
|
299
|
-
)
|
|
300
295
|
attributes: dict[str, Any] = Field(
|
|
301
296
|
default={}, description='Additional attributes of the node. Dependent on node labels'
|
|
302
297
|
)
|
|
@@ -337,8 +332,8 @@ class EntityNode(Node):
|
|
|
337
332
|
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
|
338
333
|
query = (
|
|
339
334
|
"""
|
|
340
|
-
|
|
341
|
-
|
|
335
|
+
MATCH (n:Entity {uuid: $uuid})
|
|
336
|
+
"""
|
|
342
337
|
+ ENTITY_NODE_RETURN
|
|
343
338
|
)
|
|
344
339
|
records, _, _ = await driver.execute_query(
|
|
@@ -544,7 +539,6 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
|
|
|
544
539
|
created_at=record['created_at'].to_native(),
|
|
545
540
|
summary=record['summary'],
|
|
546
541
|
attributes=record['attributes'],
|
|
547
|
-
episodes=record['episodes'],
|
|
548
542
|
)
|
|
549
543
|
|
|
550
544
|
entity_node.attributes.pop('uuid', None)
|
|
@@ -566,3 +560,9 @@ def get_community_node_from_record(record: Any) -> CommunityNode:
|
|
|
566
560
|
created_at=record['created_at'].to_native(),
|
|
567
561
|
summary=record['summary'],
|
|
568
562
|
)
|
|
563
|
+
|
|
564
|
+
|
|
565
|
+
async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[EntityNode]):
|
|
566
|
+
name_embeddings = await embedder.create_batch([node.name for node in nodes])
|
|
567
|
+
for node, name_embedding in zip(nodes, name_embeddings, strict=True):
|
|
568
|
+
node.name_embedding = name_embedding
|