graphiti-core 0.10.5__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic; see the registry page for this release for more details.

graphiti_core/edges.py CHANGED
@@ -37,6 +37,21 @@ from graphiti_core.nodes import Node
37
37
 
38
38
  logger = logging.getLogger(__name__)
39
39
 
40
+ ENTITY_EDGE_RETURN: LiteralString = """
41
+ RETURN
42
+ e.uuid AS uuid,
43
+ startNode(e).uuid AS source_node_uuid,
44
+ endNode(e).uuid AS target_node_uuid,
45
+ e.created_at AS created_at,
46
+ e.name AS name,
47
+ e.group_id AS group_id,
48
+ e.fact AS fact,
49
+ e.fact_embedding AS fact_embedding,
50
+ e.episodes AS episodes,
51
+ e.expired_at AS expired_at,
52
+ e.valid_at AS valid_at,
53
+ e.invalid_at AS invalid_at"""
54
+
40
55
 
41
56
  class Edge(BaseModel, ABC):
42
57
  uuid: str = Field(default_factory=lambda: str(uuid4()))
@@ -234,20 +249,8 @@ class EntityEdge(Edge):
234
249
  records, _, _ = await driver.execute_query(
235
250
  """
236
251
  MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
237
- RETURN
238
- e.uuid AS uuid,
239
- n.uuid AS source_node_uuid,
240
- m.uuid AS target_node_uuid,
241
- e.created_at AS created_at,
242
- e.name AS name,
243
- e.group_id AS group_id,
244
- e.fact AS fact,
245
- e.fact_embedding AS fact_embedding,
246
- e.episodes AS episodes,
247
- e.expired_at AS expired_at,
248
- e.valid_at AS valid_at,
249
- e.invalid_at AS invalid_at
250
- """,
252
+ """
253
+ + ENTITY_EDGE_RETURN,
251
254
  uuid=uuid,
252
255
  database_=DEFAULT_DATABASE,
253
256
  routing_='r',
@@ -268,20 +271,8 @@ class EntityEdge(Edge):
268
271
  """
269
272
  MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
270
273
  WHERE e.uuid IN $uuids
271
- RETURN
272
- e.uuid AS uuid,
273
- n.uuid AS source_node_uuid,
274
- m.uuid AS target_node_uuid,
275
- e.created_at AS created_at,
276
- e.name AS name,
277
- e.group_id AS group_id,
278
- e.fact AS fact,
279
- e.fact_embedding AS fact_embedding,
280
- e.episodes AS episodes,
281
- e.expired_at AS expired_at,
282
- e.valid_at AS valid_at,
283
- e.invalid_at AS invalid_at
284
- """,
274
+ """
275
+ + ENTITY_EDGE_RETURN,
285
276
  uuids=uuids,
286
277
  database_=DEFAULT_DATABASE,
287
278
  routing_='r',
@@ -308,20 +299,8 @@ class EntityEdge(Edge):
308
299
  WHERE e.group_id IN $group_ids
309
300
  """
310
301
  + cursor_query
302
+ + ENTITY_EDGE_RETURN
311
303
  + """
312
- RETURN
313
- e.uuid AS uuid,
314
- n.uuid AS source_node_uuid,
315
- m.uuid AS target_node_uuid,
316
- e.created_at AS created_at,
317
- e.name AS name,
318
- e.group_id AS group_id,
319
- e.fact AS fact,
320
- e.fact_embedding AS fact_embedding,
321
- e.episodes AS episodes,
322
- e.expired_at AS expired_at,
323
- e.valid_at AS valid_at,
324
- e.invalid_at AS invalid_at
325
304
  ORDER BY e.uuid DESC
326
305
  """
327
306
  + limit_query,
@@ -340,22 +319,12 @@ class EntityEdge(Edge):
340
319
 
341
320
  @classmethod
342
321
  async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
343
- query: LiteralString = """
344
- MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
345
- RETURN DISTINCT
346
- e.uuid AS uuid,
347
- n.uuid AS source_node_uuid,
348
- m.uuid AS target_node_uuid,
349
- e.created_at AS created_at,
350
- e.name AS name,
351
- e.group_id AS group_id,
352
- e.fact AS fact,
353
- e.fact_embedding AS fact_embedding,
354
- e.episodes AS episodes,
355
- e.expired_at AS expired_at,
356
- e.valid_at AS valid_at,
357
- e.invalid_at AS invalid_at
358
- """
322
+ query: LiteralString = (
323
+ """
324
+ MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
325
+ """
326
+ + ENTITY_EDGE_RETURN
327
+ )
359
328
  records, _, _ = await driver.execute_query(
360
329
  query, node_uuid=node_uuid, database_=DEFAULT_DATABASE, routing_='r'
361
330
  )
@@ -499,3 +468,9 @@ def get_community_edge_from_record(record: Any):
499
468
  target_node_uuid=record['target_node_uuid'],
500
469
  created_at=record['created_at'].to_native(),
501
470
  )
471
+
472
+
473
+ async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]):
474
+ fact_embeddings = await embedder.create_batch([edge.fact for edge in edges])
475
+ for edge, fact_embedding in zip(edges, fact_embeddings, strict=True):
476
+ edge.fact_embedding = fact_embedding
@@ -32,3 +32,6 @@ class EmbedderClient(ABC):
32
32
  self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
33
33
  ) -> list[float]:
34
34
  pass
35
+
36
+ async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
37
+ raise NotImplementedError()
@@ -66,3 +66,13 @@ class GeminiEmbedder(EmbedderClient):
66
66
  )
67
67
 
68
68
  return result.embeddings[0].values
69
+
70
+ async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
71
+ # Generate embeddings
72
+ result = await self.client.aio.models.embed_content(
73
+ model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
74
+ contents=input_data_list,
75
+ config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim),
76
+ )
77
+
78
+ return [embedding.values for embedding in result.embeddings]
@@ -58,3 +58,9 @@ class OpenAIEmbedder(EmbedderClient):
58
58
  input=input_data, model=self.config.embedding_model
59
59
  )
60
60
  return result.data[0].embedding[: self.config.embedding_dim]
61
+
62
+ async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
63
+ result = await self.client.embeddings.create(
64
+ input=input_data_list, model=self.config.embedding_model
65
+ )
66
+ return [embedding.embedding[: self.config.embedding_dim] for embedding in result.data]
@@ -56,3 +56,10 @@ class VoyageAIEmbedder(EmbedderClient):
56
56
 
57
57
  result = await self.client.embed(input_list, model=self.config.embedding_model)
58
58
  return [float(x) for x in result.embeddings[0][: self.config.embedding_dim]]
59
+
60
+ async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
61
+ result = await self.client.embed(input_data_list, model=self.config.embedding_model)
62
+ return [
63
+ [float(x) for x in embedding[: self.config.embedding_dim]]
64
+ for embedding in result.embeddings
65
+ ]
graphiti_core/graphiti.py CHANGED
@@ -27,6 +27,7 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient
27
27
  from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
28
28
  from graphiti_core.edges import EntityEdge, EpisodicEdge
29
29
  from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
30
+ from graphiti_core.graphiti_types import GraphitiClients
30
31
  from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
31
32
  from graphiti_core.llm_client import LLMClient, OpenAIClient
32
33
  from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
@@ -42,7 +43,6 @@ from graphiti_core.search.search_utils import (
42
43
  RELEVANT_SCHEMA_LIMIT,
43
44
  get_mentioned_nodes,
44
45
  get_relevant_edges,
45
- get_relevant_nodes,
46
46
  )
47
47
  from graphiti_core.utils.bulk_utils import (
48
48
  RawEpisode,
@@ -72,7 +72,11 @@ from graphiti_core.utils.maintenance.graph_data_operations import (
72
72
  build_indices_and_constraints,
73
73
  retrieve_episodes,
74
74
  )
75
- from graphiti_core.utils.maintenance.node_operations import extract_nodes, resolve_extracted_nodes
75
+ from graphiti_core.utils.maintenance.node_operations import (
76
+ extract_attributes_from_nodes,
77
+ extract_nodes,
78
+ resolve_extracted_nodes,
79
+ )
76
80
  from graphiti_core.utils.maintenance.temporal_operations import get_edge_contradictions
77
81
  from graphiti_core.utils.ontology_utils.entity_types_utils import validate_entity_types
78
82
 
@@ -150,6 +154,13 @@ class Graphiti:
150
154
  else:
151
155
  self.cross_encoder = OpenAIRerankerClient()
152
156
 
157
+ self.clients = GraphitiClients(
158
+ driver=self.driver,
159
+ llm_client=self.llm_client,
160
+ embedder=self.embedder,
161
+ cross_encoder=self.cross_encoder,
162
+ )
163
+
153
164
  async def close(self):
154
165
  """
155
166
  Close the connection to the Neo4j database.
@@ -222,6 +233,7 @@ class Graphiti:
222
233
  reference_time: datetime,
223
234
  last_n: int = EPISODE_WINDOW_LEN,
224
235
  group_ids: list[str] | None = None,
236
+ source: EpisodeType | None = None,
225
237
  ) -> list[EpisodicNode]:
226
238
  """
227
239
  Retrieve the last n episodic nodes from the graph.
@@ -248,7 +260,7 @@ class Graphiti:
248
260
  The actual retrieval is performed by the `retrieve_episodes` function
249
261
  from the `graphiti_core.utils` module.
250
262
  """
251
- return await retrieve_episodes(self.driver, reference_time, last_n, group_ids)
263
+ return await retrieve_episodes(self.driver, reference_time, last_n, group_ids, source)
252
264
 
253
265
  async def add_episode(
254
266
  self,
@@ -314,15 +326,16 @@ class Graphiti:
314
326
  """
315
327
  try:
316
328
  start = time()
317
-
318
- entity_edges: list[EntityEdge] = []
319
329
  now = utc_now()
320
330
 
321
331
  validate_entity_types(entity_types)
322
332
 
323
333
  previous_episodes = (
324
334
  await self.retrieve_episodes(
325
- reference_time, last_n=RELEVANT_SCHEMA_LIMIT, group_ids=[group_id]
335
+ reference_time,
336
+ last_n=RELEVANT_SCHEMA_LIMIT,
337
+ group_ids=[group_id],
338
+ source=source,
326
339
  )
327
340
  if previous_episode_uuids is None
328
341
  else await EpisodicNode.get_by_uuids(self.driver, previous_episode_uuids)
@@ -346,132 +359,36 @@ class Graphiti:
346
359
  # Extract entities as nodes
347
360
 
348
361
  extracted_nodes = await extract_nodes(
349
- self.llm_client, episode, previous_episodes, entity_types
362
+ self.clients, episode, previous_episodes, entity_types
350
363
  )
351
- logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
352
-
353
- # Calculate Embeddings
354
364
 
355
- await semaphore_gather(
356
- *[node.generate_name_embedding(self.embedder) for node in extracted_nodes]
357
- )
358
-
359
- # Find relevant nodes already in the graph
360
- existing_nodes_lists: list[list[EntityNode]] = list(
361
- await semaphore_gather(
362
- *[
363
- get_relevant_nodes(self.driver, SearchFilters(), [node])
364
- for node in extracted_nodes
365
- ]
366
- )
367
- )
368
-
369
- # Resolve extracted nodes with nodes already in the graph and extract facts
370
- logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
371
-
372
- (mentioned_nodes, uuid_map), extracted_edges = await semaphore_gather(
365
+ # Extract edges and resolve nodes
366
+ (nodes, uuid_map), extracted_edges = await semaphore_gather(
373
367
  resolve_extracted_nodes(
374
- self.llm_client,
368
+ self.clients,
375
369
  extracted_nodes,
376
- existing_nodes_lists,
377
370
  episode,
378
371
  previous_episodes,
379
372
  entity_types,
380
373
  ),
381
- extract_edges(
382
- self.llm_client, episode, extracted_nodes, previous_episodes, group_id
383
- ),
384
- )
385
- logger.debug(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
386
- nodes = mentioned_nodes
387
-
388
- extracted_edges_with_resolved_pointers = resolve_edge_pointers(
389
- extracted_edges, uuid_map
374
+ extract_edges(self.clients, episode, extracted_nodes, previous_episodes, group_id),
390
375
  )
391
376
 
392
- # calculate embeddings
393
- await semaphore_gather(
394
- *[
395
- edge.generate_embedding(self.embedder)
396
- for edge in extracted_edges_with_resolved_pointers
397
- ]
398
- )
377
+ edges = resolve_edge_pointers(extracted_edges, uuid_map)
399
378
 
400
- # Resolve extracted edges with related edges already in the graph
401
- related_edges_list: list[list[EntityEdge]] = list(
402
- await semaphore_gather(
403
- *[
404
- get_relevant_edges(
405
- self.driver,
406
- [edge],
407
- edge.source_node_uuid,
408
- edge.target_node_uuid,
409
- RELEVANT_SCHEMA_LIMIT,
410
- )
411
- for edge in extracted_edges_with_resolved_pointers
412
- ]
413
- )
414
- )
415
- logger.debug(
416
- f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_list for e in edges_lst]}'
417
- )
418
- logger.debug(
419
- f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges_with_resolved_pointers]}'
420
- )
421
-
422
- existing_source_edges_list: list[list[EntityEdge]] = list(
423
- await semaphore_gather(
424
- *[
425
- get_relevant_edges(
426
- self.driver,
427
- [edge],
428
- edge.source_node_uuid,
429
- None,
430
- RELEVANT_SCHEMA_LIMIT,
431
- )
432
- for edge in extracted_edges_with_resolved_pointers
433
- ]
434
- )
435
- )
436
-
437
- existing_target_edges_list: list[list[EntityEdge]] = list(
438
- await semaphore_gather(
439
- *[
440
- get_relevant_edges(
441
- self.driver,
442
- [edge],
443
- None,
444
- edge.target_node_uuid,
445
- RELEVANT_SCHEMA_LIMIT,
446
- )
447
- for edge in extracted_edges_with_resolved_pointers
448
- ]
449
- )
450
- )
451
-
452
- existing_edges_list: list[list[EntityEdge]] = [
453
- source_lst + target_lst
454
- for source_lst, target_lst in zip(
455
- existing_source_edges_list, existing_target_edges_list, strict=False
456
- )
457
- ]
458
-
459
- resolved_edges, invalidated_edges = await resolve_extracted_edges(
460
- self.llm_client,
461
- extracted_edges_with_resolved_pointers,
462
- related_edges_list,
463
- existing_edges_list,
464
- episode,
465
- previous_episodes,
379
+ (resolved_edges, invalidated_edges), hydrated_nodes = await semaphore_gather(
380
+ resolve_extracted_edges(
381
+ self.clients,
382
+ edges,
383
+ ),
384
+ extract_attributes_from_nodes(
385
+ self.clients, nodes, episode, previous_episodes, entity_types
386
+ ),
466
387
  )
467
388
 
468
- entity_edges.extend(resolved_edges + invalidated_edges)
389
+ entity_edges = resolved_edges + invalidated_edges
469
390
 
470
- logger.debug(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')
471
-
472
- episodic_edges: list[EpisodicEdge] = build_episodic_edges(mentioned_nodes, episode, now)
473
-
474
- logger.debug(f'Built episodic edges: {episodic_edges}')
391
+ episodic_edges = build_episodic_edges(nodes, episode, now)
475
392
 
476
393
  episode.entity_edges = [edge.uuid for edge in entity_edges]
477
394
 
@@ -565,7 +482,7 @@ class Graphiti:
565
482
  extracted_nodes,
566
483
  extracted_edges,
567
484
  episodic_edges,
568
- ) = await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs)
485
+ ) = await extract_nodes_and_edges_bulk(self.clients, episode_pairs)
569
486
 
570
487
  # Generate embeddings
571
488
  await semaphore_gather(
@@ -684,9 +601,7 @@ class Graphiti:
684
601
 
685
602
  edges = (
686
603
  await search(
687
- self.driver,
688
- self.embedder,
689
- self.cross_encoder,
604
+ self.clients,
690
605
  query,
691
606
  group_ids,
692
607
  search_config,
@@ -728,9 +643,7 @@ class Graphiti:
728
643
  """
729
644
 
730
645
  return await search(
731
- self.driver,
732
- self.embedder,
733
- self.cross_encoder,
646
+ self.clients,
734
647
  query,
735
648
  group_ids,
736
649
  config,
@@ -761,26 +674,17 @@ class Graphiti:
761
674
  await edge.generate_embedding(self.embedder)
762
675
 
763
676
  resolved_nodes, uuid_map = await resolve_extracted_nodes(
764
- self.llm_client,
677
+ self.clients,
765
678
  [source_node, target_node],
766
- [
767
- await get_relevant_nodes(self.driver, SearchFilters(), [source_node]),
768
- await get_relevant_nodes(self.driver, SearchFilters(), [target_node]),
769
- ],
770
679
  )
771
680
 
772
681
  updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
773
682
 
774
- related_edges = await get_relevant_edges(
775
- self.driver,
776
- [updated_edge],
777
- source_node_uuid=resolved_nodes[0].uuid,
778
- target_node_uuid=resolved_nodes[1].uuid,
779
- )
683
+ related_edges = await get_relevant_edges(self.driver, [updated_edge], SearchFilters(), 0.8)
780
684
 
781
- resolved_edge = await dedupe_extracted_edge(self.llm_client, updated_edge, related_edges)
685
+ resolved_edge = await dedupe_extracted_edge(self.llm_client, updated_edge, related_edges[0])
782
686
 
783
- contradicting_edges = await get_edge_contradictions(self.llm_client, edge, related_edges)
687
+ contradicting_edges = await get_edge_contradictions(self.llm_client, edge, related_edges[0])
784
688
  invalidated_edges = resolve_edge_contradictions(resolved_edge, contradicting_edges)
785
689
 
786
690
  await add_nodes_and_edges_bulk(
graphiti_core/graphiti_types.py ADDED
@@ -0,0 +1,31 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ from neo4j import AsyncDriver
18
+ from pydantic import BaseModel, ConfigDict
19
+
20
+ from graphiti_core.cross_encoder import CrossEncoderClient
21
+ from graphiti_core.embedder import EmbedderClient
22
+ from graphiti_core.llm_client import LLMClient
23
+
24
+
25
+ class GraphitiClients(BaseModel):
26
+ driver: AsyncDriver
27
+ llm_client: LLMClient
28
+ embedder: EmbedderClient
29
+ cross_encoder: CrossEncoderClient
30
+
31
+ model_config = ConfigDict(arbitrary_types_allowed=True)
graphiti_core/helpers.py CHANGED
@@ -22,15 +22,20 @@ from datetime import datetime
22
22
  import numpy as np
23
23
  from dotenv import load_dotenv
24
24
  from neo4j import time as neo4j_time
25
+ from typing_extensions import LiteralString
25
26
 
26
27
  load_dotenv()
27
28
 
28
29
  DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
29
30
  USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
30
31
  SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
31
- MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 2))
32
+ MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0))
32
33
  DEFAULT_PAGE_LIMIT = 20
33
34
 
35
+ RUNTIME_QUERY: LiteralString = (
36
+ 'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
37
+ )
38
+
34
39
 
35
40
  def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
36
41
  return neo_date.to_native() if neo_date else None
@@ -47,7 +47,7 @@ ENTITY_EDGE_SAVE_BULK = """
47
47
  SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
48
48
  created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at}
49
49
  WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding)
50
- RETURN r.uuid AS uuid
50
+ RETURN edge.uuid AS uuid
51
51
  """
52
52
 
53
53
  COMMUNITY_EDGE_SAVE = """
graphiti_core/nodes.py CHANGED
@@ -332,8 +332,8 @@ class EntityNode(Node):
332
332
  async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
333
333
  query = (
334
334
  """
335
- MATCH (n:Entity {uuid: $uuid})
336
- """
335
+ MATCH (n:Entity {uuid: $uuid})
336
+ """
337
337
  + ENTITY_NODE_RETURN
338
338
  )
339
339
  records, _, _ = await driver.execute_query(
@@ -560,3 +560,9 @@ def get_community_node_from_record(record: Any) -> CommunityNode:
560
560
  created_at=record['created_at'].to_native(),
561
561
  summary=record['summary'],
562
562
  )
563
+
564
+
565
+ async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[EntityNode]):
566
+ name_embeddings = await embedder.create_batch([node.name for node in nodes])
567
+ for node, name_embedding in zip(nodes, name_embeddings, strict=True):
568
+ node.name_embedding = name_embedding
@@ -23,10 +23,9 @@ from .models import Message, PromptFunction, PromptVersion
23
23
 
24
24
 
25
25
  class EdgeDuplicate(BaseModel):
26
- is_duplicate: bool = Field(..., description='true or false')
27
- uuid: str | None = Field(
28
- None,
29
- description="uuid of the existing edge like '5d643020624c42fa9de13f97b1b3fa39' or null",
26
+ duplicate_fact_id: int = Field(
27
+ ...,
28
+ description='id of the duplicate fact. If no duplicate facts are found, default to -1.',
30
29
  )
31
30
 
32
31
 
@@ -69,9 +68,8 @@ def edge(context: dict[str, Any]) -> list[Message]:
69
68
  </NEW EDGE>
70
69
 
71
70
  Task:
72
- 1. If the New Edges represents the same factual information as any edge in Existing Edges, return 'is_duplicate: true' in the
73
- response. Otherwise, return 'is_duplicate: false'
74
- 2. If is_duplicate is true, also return the uuid of the existing edge in the response
71
+ If the New Edges represents the same factual information as any edge in Existing Edges, return the id of the duplicate fact.
72
+ If the NEW EDGE is not a duplicate of any of the EXISTING EDGES, return -1.
75
73
 
76
74
  Guidelines:
77
75
  1. The facts do not need to be completely identical to be duplicates, they just need to express the same information.
@@ -23,14 +23,9 @@ from .models import Message, PromptFunction, PromptVersion
23
23
 
24
24
 
25
25
  class NodeDuplicate(BaseModel):
26
- is_duplicate: bool = Field(..., description='true or false')
27
- uuid: str | None = Field(
28
- None,
29
- description="uuid of the existing node like '5d643020624c42fa9de13f97b1b3fa39' or null",
30
- )
31
- name: str = Field(
26
+ duplicate_node_id: int = Field(
32
27
  ...,
33
- description="Updated name of the new node (use the best name between the new node's name, an existing duplicate name, or a combination of both)",
28
+ description='id of the duplicate node. If no duplicate nodes are found, default to -1.',
34
29
  )
35
30
 
36
31
 
@@ -64,28 +59,20 @@ def node(context: dict[str, Any]) -> list[Message]:
64
59
  {json.dumps(context['existing_nodes'], indent=2)}
65
60
  </EXISTING NODES>
66
61
 
67
- Given the above EXISTING NODES and their attributes, MESSAGE, and PREVIOUS MESSAGES. Determine if the NEW NODE extracted from the conversation
62
+ Given the above EXISTING NODES and their attributes, MESSAGE, and PREVIOUS MESSAGES; Determine if the NEW NODE extracted from the conversation
68
63
  is a duplicate entity of one of the EXISTING NODES.
69
64
 
70
65
  <NEW NODE>
71
- {json.dumps(context['extracted_nodes'], indent=2)}
66
+ {json.dumps(context['extracted_node'], indent=2)}
72
67
  </NEW NODE>
73
68
  Task:
74
- 1. If the New Node represents the same entity as any node in Existing Nodes, return 'is_duplicate: true' in the
75
- response. Otherwise, return 'is_duplicate: false'
76
- 2. If is_duplicate is true, also return the uuid of the existing node in the response
77
- 3. If is_duplicate is true, return a name for the node that is the most complete full name.
69
+ If the NEW NODE is a duplicate of any node in EXISTING NODES, set duplicate_node_id to the
70
+ id of the EXISTING NODE that is the duplicate. If the NEW NODE is not a duplicate of any of the EXISTING NODES,
71
+ duplicate_node_id should be set to -1.
78
72
 
79
73
  Guidelines:
80
- 1. Use both the name and summary of nodes to determine if the entities are duplicates,
74
+ 1. Use the name, summary, and attributes of nodes to determine if the entities are duplicates,
81
75
  duplicate nodes may have different names
82
-
83
- Respond with a JSON object in the following format:
84
- {{
85
- "is_duplicate": true or false,
86
- "uuid": "uuid of the existing node like 5d643020624c42fa9de13f97b1b3fa39 or null",
87
- "name": "Updated name of the new node (use the best name between the new node's name, an existing duplicate name, or a combination of both)"
88
- }}
89
76
  """,
90
77
  ),
91
78
  ]