graphiti-core 0.11.5__tar.gz → 0.11.6__tar.gz

This diff compares the contents of two publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. Click here for more details.

Files changed (67)
  1. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/PKG-INFO +1 -2
  2. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/edges.py +18 -4
  3. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/graphiti.py +12 -10
  4. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/helpers.py +8 -10
  5. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/llm_client/anthropic_client.py +0 -13
  6. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/nodes.py +30 -7
  7. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/prompts/dedupe_edges.py +44 -1
  8. graphiti_core-0.11.6/graphiti_core/prompts/dedupe_nodes.py +200 -0
  9. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/prompts/extract_nodes.py +1 -1
  10. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/prompts/invalidate_edges.py +1 -1
  11. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/prompts/summarize_nodes.py +4 -4
  12. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/search/search.py +25 -42
  13. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/search/search_utils.py +117 -20
  14. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/utils/bulk_utils.py +15 -1
  15. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/utils/maintenance/community_operations.py +0 -2
  16. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/utils/maintenance/edge_operations.py +63 -15
  17. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/utils/maintenance/node_operations.py +63 -33
  18. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/pyproject.toml +1 -2
  19. graphiti_core-0.11.5/graphiti_core/prompts/dedupe_nodes.py +0 -122
  20. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/LICENSE +0 -0
  21. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/README.md +0 -0
  22. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/__init__.py +0 -0
  23. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/cross_encoder/__init__.py +0 -0
  24. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/cross_encoder/bge_reranker_client.py +0 -0
  25. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/cross_encoder/client.py +0 -0
  26. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/cross_encoder/openai_reranker_client.py +0 -0
  27. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/embedder/__init__.py +0 -0
  28. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/embedder/client.py +0 -0
  29. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/embedder/gemini.py +0 -0
  30. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/embedder/openai.py +0 -0
  31. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/embedder/voyage.py +0 -0
  32. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/errors.py +0 -0
  33. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/graphiti_types.py +0 -0
  34. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/llm_client/__init__.py +0 -0
  35. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/llm_client/client.py +0 -0
  36. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/llm_client/config.py +0 -0
  37. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/llm_client/errors.py +0 -0
  38. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/llm_client/gemini_client.py +0 -0
  39. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/llm_client/groq_client.py +0 -0
  40. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/llm_client/openai_client.py +0 -0
  41. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/llm_client/openai_generic_client.py +0 -0
  42. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/llm_client/utils.py +0 -0
  43. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/models/__init__.py +0 -0
  44. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/models/edges/__init__.py +0 -0
  45. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/models/edges/edge_db_queries.py +0 -0
  46. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/models/nodes/__init__.py +0 -0
  47. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/models/nodes/node_db_queries.py +0 -0
  48. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/prompts/__init__.py +0 -0
  49. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/prompts/eval.py +0 -0
  50. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/prompts/extract_edge_dates.py +0 -0
  51. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/prompts/extract_edges.py +0 -0
  52. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/prompts/lib.py +0 -0
  53. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/prompts/models.py +0 -0
  54. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/prompts/prompt_helpers.py +0 -0
  55. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/py.typed +0 -0
  56. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/search/__init__.py +0 -0
  57. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/search/search_config.py +0 -0
  58. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/search/search_config_recipes.py +0 -0
  59. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/search/search_filters.py +0 -0
  60. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/search/search_helpers.py +0 -0
  61. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/utils/__init__.py +0 -0
  62. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/utils/datetime_utils.py +0 -0
  63. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/utils/maintenance/__init__.py +0 -0
  64. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/utils/maintenance/graph_data_operations.py +0 -0
  65. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/utils/maintenance/temporal_operations.py +0 -0
  66. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/utils/maintenance/utils.py +0 -0
  67. {graphiti_core-0.11.5 → graphiti_core-0.11.6}/graphiti_core/utils/ontology_utils/entity_types_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: graphiti-core
3
- Version: 0.11.5
3
+ Version: 0.11.6
4
4
  Summary: A temporal graph building library
5
5
  License: Apache-2.0
6
6
  Author: Paul Paliychuk
@@ -18,7 +18,6 @@ Provides-Extra: groq
18
18
  Requires-Dist: anthropic (>=0.49.0) ; extra == "anthropic"
19
19
  Requires-Dist: diskcache (>=5.6.3)
20
20
  Requires-Dist: google-genai (>=1.8.0) ; extra == "google-genai"
21
- Requires-Dist: graph-service (>=1.0.0.7,<2.0.0.0)
22
21
  Requires-Dist: groq (>=0.2.0) ; extra == "groq"
23
22
  Requires-Dist: neo4j (>=5.23.0)
24
23
  Requires-Dist: numpy (>=1.0.0)
@@ -46,7 +46,6 @@ ENTITY_EDGE_RETURN: LiteralString = """
46
46
  e.name AS name,
47
47
  e.group_id AS group_id,
48
48
  e.fact AS fact,
49
- e.fact_embedding AS fact_embedding,
50
49
  e.episodes AS episodes,
51
50
  e.expired_at AS expired_at,
52
51
  e.valid_at AS valid_at,
@@ -222,6 +221,20 @@ class EntityEdge(Edge):
222
221
 
223
222
  return self.fact_embedding
224
223
 
224
+ async def load_fact_embedding(self, driver: AsyncDriver):
225
+ query: LiteralString = """
226
+ MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
227
+ RETURN e.fact_embedding AS fact_embedding
228
+ """
229
+ records, _, _ = await driver.execute_query(
230
+ query, uuid=self.uuid, database_=DEFAULT_DATABASE, routing_='r'
231
+ )
232
+
233
+ if len(records) == 0:
234
+ raise EdgeNotFoundError(self.uuid)
235
+
236
+ self.fact_embedding = records[0]['fact_embedding']
237
+
225
238
  async def save(self, driver: AsyncDriver):
226
239
  result = await driver.execute_query(
227
240
  ENTITY_EDGE_SAVE,
@@ -321,8 +334,8 @@ class EntityEdge(Edge):
321
334
  async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
322
335
  query: LiteralString = (
323
336
  """
324
- MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
325
- """
337
+ MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
338
+ """
326
339
  + ENTITY_EDGE_RETURN
327
340
  )
328
341
  records, _, _ = await driver.execute_query(
@@ -452,7 +465,6 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
452
465
  name=record['name'],
453
466
  group_id=record['group_id'],
454
467
  episodes=record['episodes'],
455
- fact_embedding=record['fact_embedding'],
456
468
  created_at=record['created_at'].to_native(),
457
469
  expired_at=parse_db_date(record['expired_at']),
458
470
  valid_at=parse_db_date(record['valid_at']),
@@ -471,6 +483,8 @@ def get_community_edge_from_record(record: Any):
471
483
 
472
484
 
473
485
  async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]):
486
+ if len(edges) == 0:
487
+ return
474
488
  fact_embeddings = await embedder.create_batch([edge.fact for edge in edges])
475
489
  for edge, fact_embedding in zip(edges, fact_embeddings, strict=True):
476
490
  edge.fact_embedding = fact_embedding
@@ -41,6 +41,7 @@ from graphiti_core.search.search_config_recipes import (
41
41
  from graphiti_core.search.search_filters import SearchFilters
42
42
  from graphiti_core.search.search_utils import (
43
43
  RELEVANT_SCHEMA_LIMIT,
44
+ get_edge_invalidation_candidates,
44
45
  get_mentioned_nodes,
45
46
  get_relevant_edges,
46
47
  )
@@ -62,9 +63,8 @@ from graphiti_core.utils.maintenance.community_operations import (
62
63
  )
63
64
  from graphiti_core.utils.maintenance.edge_operations import (
64
65
  build_episodic_edges,
65
- dedupe_extracted_edge,
66
66
  extract_edges,
67
- resolve_edge_contradictions,
67
+ resolve_extracted_edge,
68
68
  resolve_extracted_edges,
69
69
  )
70
70
  from graphiti_core.utils.maintenance.graph_data_operations import (
@@ -77,7 +77,6 @@ from graphiti_core.utils.maintenance.node_operations import (
77
77
  extract_nodes,
78
78
  resolve_extracted_nodes,
79
79
  )
80
- from graphiti_core.utils.maintenance.temporal_operations import get_edge_contradictions
81
80
  from graphiti_core.utils.ontology_utils.entity_types_utils import validate_entity_types
82
81
 
83
82
  logger = logging.getLogger(__name__)
@@ -380,6 +379,7 @@ class Graphiti:
380
379
  resolve_extracted_edges(
381
380
  self.clients,
382
381
  edges,
382
+ episode,
383
383
  ),
384
384
  extract_attributes_from_nodes(
385
385
  self.clients, nodes, episode, previous_episodes, entity_types
@@ -396,7 +396,7 @@ class Graphiti:
396
396
  episode.content = ''
397
397
 
398
398
  await add_nodes_and_edges_bulk(
399
- self.driver, [episode], episodic_edges, hydrated_nodes, entity_edges
399
+ self.driver, [episode], episodic_edges, hydrated_nodes, entity_edges, self.embedder
400
400
  )
401
401
 
402
402
  # Update any communities
@@ -680,15 +680,17 @@ class Graphiti:
680
680
 
681
681
  updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
682
682
 
683
- related_edges = await get_relevant_edges(self.driver, [updated_edge], SearchFilters(), 0.8)
684
-
685
- resolved_edge = await dedupe_extracted_edge(self.llm_client, updated_edge, related_edges[0])
683
+ related_edges = (await get_relevant_edges(self.driver, [updated_edge], SearchFilters()))[0]
684
+ existing_edges = (
685
+ await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters())
686
+ )[0]
686
687
 
687
- contradicting_edges = await get_edge_contradictions(self.llm_client, edge, related_edges[0])
688
- invalidated_edges = resolve_edge_contradictions(resolved_edge, contradicting_edges)
688
+ resolved_edge, invalidated_edges = await resolve_extracted_edge(
689
+ self.llm_client, updated_edge, related_edges, existing_edges
690
+ )
689
691
 
690
692
  await add_nodes_and_edges_bulk(
691
- self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges
693
+ self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges, self.embedder
692
694
  )
693
695
 
694
696
  async def remove_episode(self, episode_uuid: str):
@@ -22,6 +22,7 @@ from datetime import datetime
22
22
  import numpy as np
23
23
  from dotenv import load_dotenv
24
24
  from neo4j import time as neo4j_time
25
+ from numpy._typing import NDArray
25
26
  from typing_extensions import LiteralString
26
27
 
27
28
  load_dotenv()
@@ -78,20 +79,17 @@ def lucene_sanitize(query: str) -> str:
78
79
  return sanitized
79
80
 
80
81
 
81
- def normalize_l2(embedding: list[float]):
82
+ def normalize_l2(embedding: list[float]) -> NDArray:
82
83
  embedding_array = np.array(embedding)
83
- if embedding_array.ndim == 1:
84
- norm = np.linalg.norm(embedding_array)
85
- if norm == 0:
86
- return [0.0] * len(embedding)
87
- return (embedding_array / norm).tolist()
88
- else:
89
- norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)
90
- return (np.where(norm == 0, embedding_array, embedding_array / norm)).tolist()
84
+ norm = np.linalg.norm(embedding_array, 2, axis=0, keepdims=True)
85
+ return np.where(norm == 0, embedding_array, embedding_array / norm)
91
86
 
92
87
 
93
88
  # Use this instead of asyncio.gather() to bound coroutines
94
- async def semaphore_gather(*coroutines: Coroutine, max_coroutines: int = SEMAPHORE_LIMIT):
89
+ async def semaphore_gather(
90
+ *coroutines: Coroutine,
91
+ max_coroutines: int = SEMAPHORE_LIMIT,
92
+ ):
95
93
  semaphore = asyncio.Semaphore(max_coroutines)
96
94
 
97
95
  async def _wrap_coroutine(coroutine):
@@ -139,15 +139,11 @@ class AnthropicClient(LLMClient):
139
139
  A list containing a single tool definition for use with the Anthropic API.
140
140
  """
141
141
  if response_model is not None:
142
- # temporary debug log
143
- logger.info(f'Creating tool for response_model: {response_model}')
144
142
  # Use the response_model to define the tool
145
143
  model_schema = response_model.model_json_schema()
146
144
  tool_name = response_model.__name__
147
145
  description = model_schema.get('description', f'Extract {tool_name} information')
148
146
  else:
149
- # temporary debug log
150
- logger.info('Creating generic JSON output tool')
151
147
  # Create a generic JSON output tool
152
148
  tool_name = 'generic_json_output'
153
149
  description = 'Output data in JSON format'
@@ -205,8 +201,6 @@ class AnthropicClient(LLMClient):
205
201
  try:
206
202
  # Create the appropriate tool based on whether response_model is provided
207
203
  tools, tool_choice = self._create_tool(response_model)
208
- # temporary debug log
209
- logger.info(f'using model: {self.model} with max_tokens: {self.max_tokens}')
210
204
  result = await self.client.messages.create(
211
205
  system=system_message.content,
212
206
  max_tokens=max_creation_tokens,
@@ -227,13 +221,6 @@ class AnthropicClient(LLMClient):
227
221
  return tool_args
228
222
 
229
223
  # If we didn't get a proper tool_use response, try to extract from text
230
- # logger.debug(
231
- # f'Did not get a tool_use response, trying to extract json from text. Result: {result.content}'
232
- # )
233
- # temporary debug log
234
- logger.info(
235
- f'Did not get a tool_use response, trying to extract json from text. Result: {result.content}'
236
- )
237
224
  for content_item in result.content:
238
225
  if content_item.type == 'text':
239
226
  return self._extract_json_from_text(content_item.text)
@@ -42,7 +42,6 @@ ENTITY_NODE_RETURN: LiteralString = """
42
42
  RETURN
43
43
  n.uuid As uuid,
44
44
  n.name AS name,
45
- n.name_embedding AS name_embedding,
46
45
  n.group_id AS group_id,
47
46
  n.created_at AS created_at,
48
47
  n.summary AS summary,
@@ -305,6 +304,20 @@ class EntityNode(Node):
305
304
 
306
305
  return self.name_embedding
307
306
 
307
+ async def load_name_embedding(self, driver: AsyncDriver):
308
+ query: LiteralString = """
309
+ MATCH (n:Entity {uuid: $uuid})
310
+ RETURN n.name_embedding AS name_embedding
311
+ """
312
+ records, _, _ = await driver.execute_query(
313
+ query, uuid=self.uuid, database_=DEFAULT_DATABASE, routing_='r'
314
+ )
315
+
316
+ if len(records) == 0:
317
+ raise NodeNotFoundError(self.uuid)
318
+
319
+ self.name_embedding = records[0]['name_embedding']
320
+
308
321
  async def save(self, driver: AsyncDriver):
309
322
  entity_data: dict[str, Any] = {
310
323
  'uuid': self.uuid,
@@ -332,8 +345,8 @@ class EntityNode(Node):
332
345
  async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
333
346
  query = (
334
347
  """
335
- MATCH (n:Entity {uuid: $uuid})
336
- """
348
+ MATCH (n:Entity {uuid: $uuid})
349
+ """
337
350
  + ENTITY_NODE_RETURN
338
351
  )
339
352
  records, _, _ = await driver.execute_query(
@@ -428,6 +441,20 @@ class CommunityNode(Node):
428
441
 
429
442
  return self.name_embedding
430
443
 
444
+ async def load_name_embedding(self, driver: AsyncDriver):
445
+ query: LiteralString = """
446
+ MATCH (c:Community {uuid: $uuid})
447
+ RETURN c.name_embedding AS name_embedding
448
+ """
449
+ records, _, _ = await driver.execute_query(
450
+ query, uuid=self.uuid, database_=DEFAULT_DATABASE, routing_='r'
451
+ )
452
+
453
+ if len(records) == 0:
454
+ raise NodeNotFoundError(self.uuid)
455
+
456
+ self.name_embedding = records[0]['name_embedding']
457
+
431
458
  @classmethod
432
459
  async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
433
460
  records, _, _ = await driver.execute_query(
@@ -436,7 +463,6 @@ class CommunityNode(Node):
436
463
  RETURN
437
464
  n.uuid As uuid,
438
465
  n.name AS name,
439
- n.name_embedding AS name_embedding,
440
466
  n.group_id AS group_id,
441
467
  n.created_at AS created_at,
442
468
  n.summary AS summary
@@ -461,7 +487,6 @@ class CommunityNode(Node):
461
487
  RETURN
462
488
  n.uuid As uuid,
463
489
  n.name AS name,
464
- n.name_embedding AS name_embedding,
465
490
  n.group_id AS group_id,
466
491
  n.created_at AS created_at,
467
492
  n.summary AS summary
@@ -495,7 +520,6 @@ class CommunityNode(Node):
495
520
  RETURN
496
521
  n.uuid As uuid,
497
522
  n.name AS name,
498
- n.name_embedding AS name_embedding,
499
523
  n.group_id AS group_id,
500
524
  n.created_at AS created_at,
501
525
  n.summary AS summary
@@ -534,7 +558,6 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
534
558
  uuid=record['uuid'],
535
559
  name=record['name'],
536
560
  group_id=record['group_id'],
537
- name_embedding=record['name_embedding'],
538
561
  labels=record['labels'],
539
562
  created_at=record['created_at'].to_native(),
540
563
  summary=record['summary'],
@@ -27,6 +27,10 @@ class EdgeDuplicate(BaseModel):
27
27
  ...,
28
28
  description='id of the duplicate fact. If no duplicate facts are found, default to -1.',
29
29
  )
30
+ contradicted_facts: list[int] = Field(
31
+ ...,
32
+ description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.',
33
+ )
30
34
 
31
35
 
32
36
  class UniqueFact(BaseModel):
@@ -41,11 +45,13 @@ class UniqueFacts(BaseModel):
41
45
  class Prompt(Protocol):
42
46
  edge: PromptVersion
43
47
  edge_list: PromptVersion
48
+ resolve_edge: PromptVersion
44
49
 
45
50
 
46
51
  class Versions(TypedDict):
47
52
  edge: PromptFunction
48
53
  edge_list: PromptFunction
54
+ resolve_edge: PromptFunction
49
55
 
50
56
 
51
57
  def edge(context: dict[str, Any]) -> list[Message]:
@@ -106,4 +112,41 @@ def edge_list(context: dict[str, Any]) -> list[Message]:
106
112
  ]
107
113
 
108
114
 
109
- versions: Versions = {'edge': edge, 'edge_list': edge_list}
115
+ def resolve_edge(context: dict[str, Any]) -> list[Message]:
116
+ return [
117
+ Message(
118
+ role='system',
119
+ content='You are a helpful assistant that de-duplicates facts from fact lists and determines which existing '
120
+ 'facts are contradicted by the new fact.',
121
+ ),
122
+ Message(
123
+ role='user',
124
+ content=f"""
125
+ <NEW FACT>
126
+ {context['new_edge']}
127
+ </NEW FACT>
128
+
129
+ <EXISTING FACTS>
130
+ {context['existing_edges']}
131
+ </EXISTING FACTS>
132
+ <FACT INVALIDATION CANDIDATES>
133
+ {context['edge_invalidation_candidates']}
134
+ </FACT INVALIDATION CANDIDATES>
135
+
136
+
137
+ Task:
138
+ If the NEW FACT represents the same factual information as any fact in EXISTING FACTS, return the idx of the duplicate fact.
139
+ If the NEW FACT is not a duplicate of any of the EXISTING FACTS, return -1.
140
+
141
+ Based on the provided FACT INVALIDATION CANDIDATES and NEW FACT, determine which existing facts the new fact contradicts.
142
+ Return a list containing all idx's of the facts that are contradicted by the NEW FACT.
143
+ If there are no contradicted facts, return an empty list.
144
+
145
+ Guidelines:
146
+ 1. The facts do not need to be completely identical to be duplicates, they just need to express the same information.
147
+ """,
148
+ ),
149
+ ]
150
+
151
+
152
+ versions: Versions = {'edge': edge, 'edge_list': edge_list, 'resolve_edge': resolve_edge}
@@ -0,0 +1,200 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import json
18
+ from typing import Any, Protocol, TypedDict
19
+
20
+ from pydantic import BaseModel, Field
21
+
22
+ from .models import Message, PromptFunction, PromptVersion
23
+
24
+
25
+ class NodeDuplicate(BaseModel):
26
+ id: int = Field(..., description='integer id of the entity')
27
+ duplicate_idx: int = Field(
28
+ ...,
29
+ description='idx of the duplicate node. If no duplicate nodes are found, default to -1.',
30
+ )
31
+ name: str = Field(
32
+ ...,
33
+ description='Name of the entity. Should be the most complete and descriptive name possible.',
34
+ )
35
+
36
+
37
+ class NodeResolutions(BaseModel):
38
+ entity_resolutions: list[NodeDuplicate] = Field(..., description='List of resolved nodes')
39
+
40
+
41
+ class Prompt(Protocol):
42
+ node: PromptVersion
43
+ node_list: PromptVersion
44
+ nodes: PromptVersion
45
+
46
+
47
+ class Versions(TypedDict):
48
+ node: PromptFunction
49
+ node_list: PromptFunction
50
+ nodes: PromptFunction
51
+
52
+
53
+ def node(context: dict[str, Any]) -> list[Message]:
54
+ return [
55
+ Message(
56
+ role='system',
57
+ content='You are a helpful assistant that determines whether or not a NEW ENTITY is a duplicate of any EXISTING ENTITIES.',
58
+ ),
59
+ Message(
60
+ role='user',
61
+ content=f"""
62
+ <PREVIOUS MESSAGES>
63
+ {json.dumps([ep for ep in context['previous_episodes']], indent=2)}
64
+ </PREVIOUS MESSAGES>
65
+ <CURRENT MESSAGE>
66
+ {context['episode_content']}
67
+ </CURRENT MESSAGE>
68
+ <NEW ENTITY>
69
+ {json.dumps(context['extracted_node'], indent=2)}
70
+ </NEW ENTITY>
71
+ <ENTITY TYPE DESCRIPTION>
72
+ {json.dumps(context['entity_type_description'], indent=2)}
73
+ </ENTITY TYPE DESCRIPTION>
74
+
75
+ <EXISTING ENTITIES>
76
+ {json.dumps(context['existing_nodes'], indent=2)}
77
+ </EXISTING ENTITIES>
78
+
79
+ Given the above EXISTING ENTITIES and their attributes, MESSAGE, and PREVIOUS MESSAGES; Determine if the NEW ENTITY extracted from the conversation
80
+ is a duplicate entity of one of the EXISTING ENTITIES.
81
+
82
+ Entities should only be considered duplicates if they refer to the *same real-world object or concept*.
83
+
84
+ Do NOT mark entities as duplicates if:
85
+ - They are related but distinct.
86
+ - They have similar names or purposes but refer to separate instances or concepts.
87
+
88
+ Task:
89
+ If the NEW ENTITY represents a duplicate entity of any entity in EXISTING ENTITIES, set duplicate_entity_id to the
90
+ id of the EXISTING ENTITY that is the duplicate.
91
+
92
+ If the NEW ENTITY is not a duplicate of any of the EXISTING ENTITIES,
93
+ duplicate_entity_id should be set to -1.
94
+
95
+ Also return the name that best describes the NEW ENTITY (whether it is the name of the NEW ENTITY, a node it
96
+ is a duplicate of, or a combination of the two).
97
+ """,
98
+ ),
99
+ ]
100
+
101
+
102
+ def nodes(context: dict[str, Any]) -> list[Message]:
103
+ return [
104
+ Message(
105
+ role='system',
106
+ content='You are a helpful assistant that determines whether or not ENTITIES extracted from a conversation are duplicates'
107
+ 'of existing entities.',
108
+ ),
109
+ Message(
110
+ role='user',
111
+ content=f"""
112
+ <PREVIOUS MESSAGES>
113
+ {json.dumps([ep for ep in context['previous_episodes']], indent=2)}
114
+ </PREVIOUS MESSAGES>
115
+ <CURRENT MESSAGE>
116
+ {context['episode_content']}
117
+ </CURRENT MESSAGE>
118
+
119
+
120
+ Each of the following ENTITIES were extracted from the CURRENT MESSAGE.
121
+ Each entity in ENTITIES is represented as a JSON object with the following structure:
122
+ {{
123
+ id: integer id of the entity,
124
+ name: "name of the entity",
125
+ entity_type: "ontological classification of the entity",
126
+ entity_type_description: "Description of what the entity type represents",
127
+ duplication_candidates: [
128
+ {{
129
+ idx: integer index of the candidate entity,
130
+ name: "name of the candidate entity",
131
+ entity_type: "ontological classification of the candidate entity",
132
+ ...<additional attributes>
133
+ }}
134
+ ]
135
+ }}
136
+
137
+ <ENTITIES>
138
+ {json.dumps(context['extracted_nodes'], indent=2)}
139
+ </ENTITIES>
140
+
141
+ For each of the above ENTITIES, determine if the entity is a duplicate of any of its duplication candidates.
142
+
143
+ Entities should only be considered duplicates if they refer to the *same real-world object or concept*.
144
+
145
+ Do NOT mark entities as duplicates if:
146
+ - They are related but distinct.
147
+ - They have similar names or purposes but refer to separate instances or concepts.
148
+
149
+ Task:
150
+ Your response will be a list called entity_resolutions which contains one entry for each entity.
151
+
152
+ For each entity, return the id of the entity as id, the name of the entity as name, and the duplicate_idx
153
+ as an integer.
154
+
155
+ - If an entity is a duplicate of one of its duplication_candidates, return the idx of the candidate it is a
156
+ duplicate of.
157
+ - If an entity is not a duplicate of one of its duplication candidates, return the -1 as the duplication_idx
158
+ """,
159
+ ),
160
+ ]
161
+
162
+
163
+ def node_list(context: dict[str, Any]) -> list[Message]:
164
+ return [
165
+ Message(
166
+ role='system',
167
+ content='You are a helpful assistant that de-duplicates nodes from node lists.',
168
+ ),
169
+ Message(
170
+ role='user',
171
+ content=f"""
172
+ Given the following context, deduplicate a list of nodes:
173
+
174
+ Nodes:
175
+ {json.dumps(context['nodes'], indent=2)}
176
+
177
+ Task:
178
+ 1. Group nodes together such that all duplicate nodes are in the same list of uuids
179
+ 2. All duplicate uuids should be grouped together in the same list
180
+ 3. Also return a new summary that synthesizes the summary into a new short summary
181
+
182
+ Guidelines:
183
+ 1. Each uuid from the list of nodes should appear EXACTLY once in your response
184
+ 2. If a node has no duplicates, it should appear in the response in a list of only one uuid
185
+
186
+ Respond with a JSON object in the following format:
187
+ {{
188
+ "nodes": [
189
+ {{
190
+ "uuids": ["5d643020624c42fa9de13f97b1b3fa39", "node that is a duplicate of 5d643020624c42fa9de13f97b1b3fa39"],
191
+ "summary": "Brief summary of the node summaries that appear in the list of names."
192
+ }}
193
+ ]
194
+ }}
195
+ """,
196
+ ),
197
+ ]
198
+
199
+
200
+ versions: Versions = {'node': node, 'node_list': node_list, 'nodes': nodes}
@@ -256,7 +256,7 @@ def extract_attributes(context: dict[str, Any]) -> list[Message]:
256
256
  1. Do not hallucinate entity property values if they cannot be found in the current context.
257
257
  2. Only use the provided MESSAGES and ENTITY to set attribute values.
258
258
  3. The summary attribute represents a summary of the ENTITY, and should be updated with new information about the Entity from the MESSAGES.
259
- Summaries must be no longer than 500 words.
259
+ Summaries must be no longer than 250 words.
260
260
 
261
261
  <ENTITY>
262
262
  {context['node']}
@@ -24,7 +24,7 @@ from .models import Message, PromptFunction, PromptVersion
24
24
  class InvalidatedEdges(BaseModel):
25
25
  contradicted_facts: list[int] = Field(
26
26
  ...,
27
- description='List of ids of facts that be should invalidated. If no facts should be invalidated, the list should be empty.',
27
+ description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.',
28
28
  )
29
29
 
30
30
 
@@ -25,7 +25,7 @@ from .models import Message, PromptFunction, PromptVersion
25
25
  class Summary(BaseModel):
26
26
  summary: str = Field(
27
27
  ...,
28
- description='Summary containing the important information about the entity. Under 500 words',
28
+ description='Summary containing the important information about the entity. Under 250 words',
29
29
  )
30
30
 
31
31
 
@@ -56,7 +56,7 @@ def summarize_pair(context: dict[str, Any]) -> list[Message]:
56
56
  content=f"""
57
57
  Synthesize the information from the following two summaries into a single succinct summary.
58
58
 
59
- Summaries must be under 500 words.
59
+ Summaries must be under 250 words.
60
60
 
61
61
  Summaries:
62
62
  {json.dumps(context['node_summaries'], indent=2)}
@@ -82,7 +82,7 @@ def summarize_context(context: dict[str, Any]) -> list[Message]:
82
82
 
83
83
  Given the above MESSAGES and the following ENTITY name, create a summary for the ENTITY. Your summary must only use
84
84
  information from the provided MESSAGES. Your summary should also only contain information relevant to the
85
- provided ENTITY. Summaries must be under 500 words.
85
+ provided ENTITY. Summaries must be under 250 words.
86
86
 
87
87
  In addition, extract any values for the provided entity properties based on their descriptions.
88
88
  If the value of the entity property cannot be found in the current context, set the value of the property to the Python value None.
@@ -117,7 +117,7 @@ def summary_description(context: dict[str, Any]) -> list[Message]:
117
117
  role='user',
118
118
  content=f"""
119
119
  Create a short one sentence description of the summary that explains what kind of information is summarized.
120
- Summaries must be under 500 words.
120
+ Summaries must be under 250 words.
121
121
 
122
122
  Summary:
123
123
  {json.dumps(context['summary'], indent=2)}