graphiti-core 0.11.6rc7__tar.gz → 0.12.0rc1__tar.gz

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. Click here for more details.

Files changed (66)
  1. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/PKG-INFO +1 -1
  2. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/edges.py +42 -16
  3. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/embedder/gemini.py +14 -3
  4. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/graphiti.py +33 -10
  5. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/helpers.py +8 -27
  6. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/gemini_client.py +4 -1
  7. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/models/edges/edge_db_queries.py +2 -4
  8. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/dedupe_edges.py +52 -1
  9. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/dedupe_nodes.py +75 -4
  10. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/extract_edges.py +46 -2
  11. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/invalidate_edges.py +1 -1
  12. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/search/search.py +19 -45
  13. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/search/search_utils.py +127 -18
  14. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/utils/bulk_utils.py +19 -1
  15. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/utils/maintenance/edge_operations.py +137 -10
  16. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/utils/maintenance/node_operations.py +58 -20
  17. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/pyproject.toml +1 -1
  18. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/LICENSE +0 -0
  19. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/README.md +0 -0
  20. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/__init__.py +0 -0
  21. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/cross_encoder/__init__.py +0 -0
  22. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/cross_encoder/bge_reranker_client.py +0 -0
  23. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/cross_encoder/client.py +0 -0
  24. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/cross_encoder/openai_reranker_client.py +0 -0
  25. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/embedder/__init__.py +0 -0
  26. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/embedder/client.py +0 -0
  27. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/embedder/openai.py +0 -0
  28. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/embedder/voyage.py +0 -0
  29. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/errors.py +0 -0
  30. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/graphiti_types.py +0 -0
  31. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/__init__.py +0 -0
  32. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/anthropic_client.py +0 -0
  33. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/client.py +0 -0
  34. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/config.py +0 -0
  35. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/errors.py +0 -0
  36. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/groq_client.py +0 -0
  37. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/openai_client.py +0 -0
  38. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/openai_generic_client.py +0 -0
  39. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/utils.py +0 -0
  40. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/models/__init__.py +0 -0
  41. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/models/edges/__init__.py +0 -0
  42. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/models/nodes/__init__.py +0 -0
  43. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/models/nodes/node_db_queries.py +0 -0
  44. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/nodes.py +0 -0
  45. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/__init__.py +0 -0
  46. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/eval.py +0 -0
  47. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/extract_edge_dates.py +0 -0
  48. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/extract_nodes.py +0 -0
  49. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/lib.py +0 -0
  50. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/models.py +0 -0
  51. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/prompt_helpers.py +0 -0
  52. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/summarize_nodes.py +0 -0
  53. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/py.typed +0 -0
  54. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/search/__init__.py +0 -0
  55. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/search/search_config.py +0 -0
  56. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/search/search_config_recipes.py +0 -0
  57. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/search/search_filters.py +0 -0
  58. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/search/search_helpers.py +0 -0
  59. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/utils/__init__.py +0 -0
  60. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/utils/datetime_utils.py +0 -0
  61. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/utils/maintenance/__init__.py +0 -0
  62. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/utils/maintenance/community_operations.py +0 -0
  63. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/utils/maintenance/graph_data_operations.py +0 -0
  64. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/utils/maintenance/temporal_operations.py +0 -0
  65. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/utils/maintenance/utils.py +0 -0
  66. {graphiti_core-0.11.6rc7 → graphiti_core-0.12.0rc1}/graphiti_core/utils/ontology_utils/entity_types_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: graphiti-core
3
- Version: 0.11.6rc7
3
+ Version: 0.12.0rc1
4
4
  Summary: A temporal graph building library
5
5
  License: Apache-2.0
6
6
  Author: Paul Paliychuk
@@ -49,7 +49,9 @@ ENTITY_EDGE_RETURN: LiteralString = """
49
49
  e.episodes AS episodes,
50
50
  e.expired_at AS expired_at,
51
51
  e.valid_at AS valid_at,
52
- e.invalid_at AS invalid_at"""
52
+ e.invalid_at AS invalid_at,
53
+ properties(e) AS attributes
54
+ """
53
55
 
54
56
 
55
57
  class Edge(BaseModel, ABC):
@@ -209,6 +211,9 @@ class EntityEdge(Edge):
209
211
  invalid_at: datetime | None = Field(
210
212
  default=None, description='datetime of when the fact stopped being true'
211
213
  )
214
+ attributes: dict[str, Any] = Field(
215
+ default={}, description='Additional attributes of the edge. Dependent on edge name'
216
+ )
212
217
 
213
218
  async def generate_embedding(self, embedder: EmbedderClient):
214
219
  start = time()
@@ -236,20 +241,26 @@ class EntityEdge(Edge):
236
241
  self.fact_embedding = records[0]['fact_embedding']
237
242
 
238
243
  async def save(self, driver: AsyncDriver):
244
+ edge_data: dict[str, Any] = {
245
+ 'source_uuid': self.source_node_uuid,
246
+ 'target_uuid': self.target_node_uuid,
247
+ 'uuid': self.uuid,
248
+ 'name': self.name,
249
+ 'group_id': self.group_id,
250
+ 'fact': self.fact,
251
+ 'fact_embedding': self.fact_embedding,
252
+ 'episodes': self.episodes,
253
+ 'created_at': self.created_at,
254
+ 'expired_at': self.expired_at,
255
+ 'valid_at': self.valid_at,
256
+ 'invalid_at': self.invalid_at,
257
+ }
258
+
259
+ edge_data.update(self.attributes or {})
260
+
239
261
  result = await driver.execute_query(
240
262
  ENTITY_EDGE_SAVE,
241
- source_uuid=self.source_node_uuid,
242
- target_uuid=self.target_node_uuid,
243
- uuid=self.uuid,
244
- name=self.name,
245
- group_id=self.group_id,
246
- fact=self.fact,
247
- fact_embedding=self.fact_embedding,
248
- episodes=self.episodes,
249
- created_at=self.created_at,
250
- expired_at=self.expired_at,
251
- valid_at=self.valid_at,
252
- invalid_at=self.invalid_at,
263
+ edge_data=edge_data,
253
264
  database_=DEFAULT_DATABASE,
254
265
  )
255
266
 
@@ -334,8 +345,8 @@ class EntityEdge(Edge):
334
345
  async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
335
346
  query: LiteralString = (
336
347
  """
337
- MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
338
- """
348
+ MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
349
+ """
339
350
  + ENTITY_EDGE_RETURN
340
351
  )
341
352
  records, _, _ = await driver.execute_query(
@@ -457,7 +468,7 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
457
468
 
458
469
 
459
470
  def get_entity_edge_from_record(record: Any) -> EntityEdge:
460
- return EntityEdge(
471
+ edge = EntityEdge(
461
472
  uuid=record['uuid'],
462
473
  source_node_uuid=record['source_node_uuid'],
463
474
  target_node_uuid=record['target_node_uuid'],
@@ -469,8 +480,23 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
469
480
  expired_at=parse_db_date(record['expired_at']),
470
481
  valid_at=parse_db_date(record['valid_at']),
471
482
  invalid_at=parse_db_date(record['invalid_at']),
483
+ attributes=record['attributes'],
472
484
  )
473
485
 
486
+ edge.attributes.pop('uuid', None)
487
+ edge.attributes.pop('source_node_uuid', None)
488
+ edge.attributes.pop('target_node_uuid', None)
489
+ edge.attributes.pop('fact', None)
490
+ edge.attributes.pop('name', None)
491
+ edge.attributes.pop('group_id', None)
492
+ edge.attributes.pop('episodes', None)
493
+ edge.attributes.pop('created_at', None)
494
+ edge.attributes.pop('expired_at', None)
495
+ edge.attributes.pop('valid_at', None)
496
+ edge.attributes.pop('invalid_at', None)
497
+
498
+ return edge
499
+
474
500
 
475
501
  def get_community_edge_from_record(record: Any):
476
502
  return CommunityEdge(
@@ -61,18 +61,29 @@ class GeminiEmbedder(EmbedderClient):
61
61
  # Generate embeddings
62
62
  result = await self.client.aio.models.embed_content(
63
63
  model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
64
- contents=[input_data],
64
+ contents=[input_data], # type: ignore[arg-type] # mypy fails on broad union type
65
65
  config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim),
66
66
  )
67
67
 
68
+ if not result.embeddings or len(result.embeddings) == 0 or not result.embeddings[0].values:
69
+ raise ValueError('No embeddings returned from Gemini API in create()')
70
+
68
71
  return result.embeddings[0].values
69
72
 
70
73
  async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
71
74
  # Generate embeddings
72
75
  result = await self.client.aio.models.embed_content(
73
76
  model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
74
- contents=input_data_list,
77
+ contents=input_data_list, # type: ignore[arg-type] # mypy fails on broad union type
75
78
  config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim),
76
79
  )
77
80
 
78
- return [embedding.values for embedding in result.embeddings]
81
+ if not result.embeddings or len(result.embeddings) == 0:
82
+ raise Exception('No embeddings returned')
83
+
84
+ embeddings = []
85
+ for embedding in result.embeddings:
86
+ if not embedding.values:
87
+ raise ValueError('Empty embedding values returned')
88
+ embeddings.append(embedding.values)
89
+ return embeddings
@@ -41,6 +41,7 @@ from graphiti_core.search.search_config_recipes import (
41
41
  from graphiti_core.search.search_filters import SearchFilters
42
42
  from graphiti_core.search.search_utils import (
43
43
  RELEVANT_SCHEMA_LIMIT,
44
+ get_edge_invalidation_candidates,
44
45
  get_mentioned_nodes,
45
46
  get_relevant_edges,
46
47
  )
@@ -62,9 +63,8 @@ from graphiti_core.utils.maintenance.community_operations import (
62
63
  )
63
64
  from graphiti_core.utils.maintenance.edge_operations import (
64
65
  build_episodic_edges,
65
- dedupe_extracted_edge,
66
66
  extract_edges,
67
- resolve_edge_contradictions,
67
+ resolve_extracted_edge,
68
68
  resolve_extracted_edges,
69
69
  )
70
70
  from graphiti_core.utils.maintenance.graph_data_operations import (
@@ -77,7 +77,6 @@ from graphiti_core.utils.maintenance.node_operations import (
77
77
  extract_nodes,
78
78
  resolve_extracted_nodes,
79
79
  )
80
- from graphiti_core.utils.maintenance.temporal_operations import get_edge_contradictions
81
80
  from graphiti_core.utils.ontology_utils.entity_types_utils import validate_entity_types
82
81
 
83
82
  logger = logging.getLogger(__name__)
@@ -274,6 +273,8 @@ class Graphiti:
274
273
  update_communities: bool = False,
275
274
  entity_types: dict[str, BaseModel] | None = None,
276
275
  previous_episode_uuids: list[str] | None = None,
276
+ edge_types: dict[str, BaseModel] | None = None,
277
+ edge_type_map: dict[tuple[str, str], list[str]] | None = None,
277
278
  ) -> AddEpisodeResults:
278
279
  """
279
280
  Process an episode and update the graph.
@@ -356,6 +357,13 @@ class Graphiti:
356
357
  )
357
358
  )
358
359
 
360
+ # Create default edge type map
361
+ edge_type_map_default = (
362
+ {('Entity', 'Entity'): list(edge_types.keys())}
363
+ if edge_types is not None
364
+ else {('Entity', 'Entity'): []}
365
+ )
366
+
359
367
  # Extract entities as nodes
360
368
 
361
369
  extracted_nodes = await extract_nodes(
@@ -371,7 +379,9 @@ class Graphiti:
371
379
  previous_episodes,
372
380
  entity_types,
373
381
  ),
374
- extract_edges(self.clients, episode, extracted_nodes, previous_episodes, group_id),
382
+ extract_edges(
383
+ self.clients, episode, extracted_nodes, previous_episodes, group_id, edge_types
384
+ ),
375
385
  )
376
386
 
377
387
  edges = resolve_edge_pointers(extracted_edges, uuid_map)
@@ -381,6 +391,9 @@ class Graphiti:
381
391
  self.clients,
382
392
  edges,
383
393
  episode,
394
+ nodes,
395
+ edge_types or {},
396
+ edge_type_map or edge_type_map_default,
384
397
  ),
385
398
  extract_attributes_from_nodes(
386
399
  self.clients, nodes, episode, previous_episodes, entity_types
@@ -681,17 +694,27 @@ class Graphiti:
681
694
 
682
695
  updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
683
696
 
684
- related_edges = await get_relevant_edges(self.driver, [updated_edge], SearchFilters(), 0.8)
697
+ related_edges = (await get_relevant_edges(self.driver, [updated_edge], SearchFilters()))[0]
698
+ existing_edges = (
699
+ await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters())
700
+ )[0]
685
701
 
686
- resolved_edge = await dedupe_extracted_edge(
702
+ resolved_edge, invalidated_edges = await resolve_extracted_edge(
687
703
  self.llm_client,
688
704
  updated_edge,
689
- related_edges[0],
705
+ related_edges,
706
+ existing_edges,
707
+ EpisodicNode(
708
+ name='',
709
+ source=EpisodeType.text,
710
+ source_description='',
711
+ content='',
712
+ valid_at=edge.valid_at or utc_now(),
713
+ entity_edges=[],
714
+ group_id=edge.group_id,
715
+ ),
690
716
  )
691
717
 
692
- contradicting_edges = await get_edge_contradictions(self.llm_client, edge, related_edges[0])
693
- invalidated_edges = resolve_edge_contradictions(resolved_edge, contradicting_edges)
694
-
695
718
  await add_nodes_and_edges_bulk(
696
719
  self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges, self.embedder
697
720
  )
@@ -18,11 +18,11 @@ import asyncio
18
18
  import os
19
19
  from collections.abc import Coroutine
20
20
  from datetime import datetime
21
- from typing import Any
22
21
 
23
22
  import numpy as np
24
23
  from dotenv import load_dotenv
25
24
  from neo4j import time as neo4j_time
25
+ from numpy._typing import NDArray
26
26
  from typing_extensions import LiteralString
27
27
 
28
28
  load_dotenv()
@@ -79,16 +79,10 @@ def lucene_sanitize(query: str) -> str:
79
79
  return sanitized
80
80
 
81
81
 
82
- def normalize_l2(embedding: list[float]):
82
+ def normalize_l2(embedding: list[float]) -> NDArray:
83
83
  embedding_array = np.array(embedding)
84
- if embedding_array.ndim == 1:
85
- norm = np.linalg.norm(embedding_array)
86
- if norm == 0:
87
- return [0.0] * len(embedding)
88
- return (embedding_array / norm).tolist()
89
- else:
90
- norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)
91
- return (np.where(norm == 0, embedding_array, embedding_array / norm)).tolist()
84
+ norm = np.linalg.norm(embedding_array, 2, axis=0, keepdims=True)
85
+ return np.where(norm == 0, embedding_array, embedding_array / norm)
92
86
 
93
87
 
94
88
  # Use this instead of asyncio.gather() to bound coroutines
@@ -98,21 +92,8 @@ async def semaphore_gather(
98
92
  ):
99
93
  semaphore = asyncio.Semaphore(max_coroutines)
100
94
 
101
- async def _wrap(coro: Coroutine) -> Any:
95
+ async def _wrap_coroutine(coroutine):
102
96
  async with semaphore:
103
- return await coro
104
-
105
- results = []
106
- batch = []
107
- for coroutine in coroutines:
108
- batch.append(_wrap(coroutine))
109
- # once we hit max_coroutines, gather and clear the batch
110
- if len(batch) >= max_coroutines:
111
- results.extend(await asyncio.gather(*batch))
112
- batch.clear()
113
-
114
- # gather any remaining coroutines in the final batch
115
- if batch:
116
- results.extend(await asyncio.gather(*batch))
117
-
118
- return results
97
+ return await coroutine
98
+
99
+ return await asyncio.gather(*(_wrap_coroutine(coroutine) for coroutine in coroutines))
@@ -139,13 +139,16 @@ class GeminiClient(LLMClient):
139
139
  # Generate content using the simple string approach
140
140
  response = await self.client.aio.models.generate_content(
141
141
  model=self.model or DEFAULT_MODEL,
142
- contents=gemini_messages,
142
+ contents=gemini_messages, # type: ignore[arg-type] # mypy fails on broad union type
143
143
  config=generation_config,
144
144
  )
145
145
 
146
146
  # If this was a structured output request, parse the response into the Pydantic model
147
147
  if response_model is not None:
148
148
  try:
149
+ if not response.text:
150
+ raise ValueError('No response text')
151
+
149
152
  validated_model = response_model.model_validate(json.loads(response.text))
150
153
 
151
154
  # Return as a dictionary for API consistency
@@ -34,8 +34,7 @@ ENTITY_EDGE_SAVE = """
34
34
  MATCH (source:Entity {uuid: $source_uuid})
35
35
  MATCH (target:Entity {uuid: $target_uuid})
36
36
  MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
37
- SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, episodes: $episodes,
38
- created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at}
37
+ SET r = $edge_data
39
38
  WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
40
39
  RETURN r.uuid AS uuid"""
41
40
 
@@ -44,8 +43,7 @@ ENTITY_EDGE_SAVE_BULK = """
44
43
  MATCH (source:Entity {uuid: edge.source_node_uuid})
45
44
  MATCH (target:Entity {uuid: edge.target_node_uuid})
46
45
  MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
47
- SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
48
- created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at}
46
+ SET r = edge
49
47
  WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding)
50
48
  RETURN edge.uuid AS uuid
51
49
  """
@@ -27,6 +27,11 @@ class EdgeDuplicate(BaseModel):
27
27
  ...,
28
28
  description='id of the duplicate fact. If no duplicate facts are found, default to -1.',
29
29
  )
30
+ contradicted_facts: list[int] = Field(
31
+ ...,
32
+ description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.',
33
+ )
34
+ fact_type: str = Field(..., description='One of the provided fact types or DEFAULT')
30
35
 
31
36
 
32
37
  class UniqueFact(BaseModel):
@@ -41,11 +46,13 @@ class UniqueFacts(BaseModel):
41
46
  class Prompt(Protocol):
42
47
  edge: PromptVersion
43
48
  edge_list: PromptVersion
49
+ resolve_edge: PromptVersion
44
50
 
45
51
 
46
52
  class Versions(TypedDict):
47
53
  edge: PromptFunction
48
54
  edge_list: PromptFunction
55
+ resolve_edge: PromptFunction
49
56
 
50
57
 
51
58
  def edge(context: dict[str, Any]) -> list[Message]:
@@ -106,4 +113,48 @@ def edge_list(context: dict[str, Any]) -> list[Message]:
106
113
  ]
107
114
 
108
115
 
109
- versions: Versions = {'edge': edge, 'edge_list': edge_list}
116
+ def resolve_edge(context: dict[str, Any]) -> list[Message]:
117
+ return [
118
+ Message(
119
+ role='system',
120
+ content='You are a helpful assistant that de-duplicates facts from fact lists and determines which existing '
121
+ 'facts are contradicted by the new fact.',
122
+ ),
123
+ Message(
124
+ role='user',
125
+ content=f"""
126
+ <NEW FACT>
127
+ {context['new_edge']}
128
+ </NEW FACT>
129
+
130
+ <EXISTING FACTS>
131
+ {context['existing_edges']}
132
+ </EXISTING FACTS>
133
+ <FACT INVALIDATION CANDIDATES>
134
+ {context['edge_invalidation_candidates']}
135
+ </FACT INVALIDATION CANDIDATES>
136
+
137
+ <FACT TYPES>
138
+ {context['edge_types']}
139
+ </FACT TYPES>
140
+
141
+
142
+ Task:
143
+ If the NEW FACT represents the same factual information as any fact in EXISTING FACTS, return the idx of the duplicate fact.
144
+ If the NEW FACT is not a duplicate of any of the EXISTING FACTS, return -1.
145
+
146
+ Given the predefined FACT TYPES, determine if the NEW FACT should be classified as one of these types.
147
+ Return the fact type as fact_type or DEFAULT if NEW FACT is not one of the FACT TYPES.
148
+
149
+ Based on the provided FACT INVALIDATION CANDIDATES and NEW FACT, determine which existing facts the new fact contradicts.
150
+ Return a list containing all idx's of the facts that are contradicted by the NEW FACT.
151
+ If there are no contradicted facts, return an empty list.
152
+
153
+ Guidelines:
154
+ 1. The facts do not need to be completely identical to be duplicates, they just need to express the same information.
155
+ """,
156
+ ),
157
+ ]
158
+
159
+
160
+ versions: Versions = {'edge': edge, 'edge_list': edge_list, 'resolve_edge': resolve_edge}
@@ -23,21 +23,31 @@ from .models import Message, PromptFunction, PromptVersion
23
23
 
24
24
 
25
25
  class NodeDuplicate(BaseModel):
26
- duplicate_node_id: int = Field(
26
+ id: int = Field(..., description='integer id of the entity')
27
+ duplicate_idx: int = Field(
27
28
  ...,
28
- description='id of the duplicate node. If no duplicate nodes are found, default to -1.',
29
+ description='idx of the duplicate node. If no duplicate nodes are found, default to -1.',
29
30
  )
30
- name: str = Field(..., description='Name of the entity.')
31
+ name: str = Field(
32
+ ...,
33
+ description='Name of the entity. Should be the most complete and descriptive name possible.',
34
+ )
35
+
36
+
37
+ class NodeResolutions(BaseModel):
38
+ entity_resolutions: list[NodeDuplicate] = Field(..., description='List of resolved nodes')
31
39
 
32
40
 
33
41
  class Prompt(Protocol):
34
42
  node: PromptVersion
35
43
  node_list: PromptVersion
44
+ nodes: PromptVersion
36
45
 
37
46
 
38
47
  class Versions(TypedDict):
39
48
  node: PromptFunction
40
49
  node_list: PromptFunction
50
+ nodes: PromptFunction
41
51
 
42
52
 
43
53
  def node(context: dict[str, Any]) -> list[Message]:
@@ -89,6 +99,67 @@ def node(context: dict[str, Any]) -> list[Message]:
89
99
  ]
90
100
 
91
101
 
102
+ def nodes(context: dict[str, Any]) -> list[Message]:
103
+ return [
104
+ Message(
105
+ role='system',
106
+ content='You are a helpful assistant that determines whether or not ENTITIES extracted from a conversation are duplicates'
107
+ 'of existing entities.',
108
+ ),
109
+ Message(
110
+ role='user',
111
+ content=f"""
112
+ <PREVIOUS MESSAGES>
113
+ {json.dumps([ep for ep in context['previous_episodes']], indent=2)}
114
+ </PREVIOUS MESSAGES>
115
+ <CURRENT MESSAGE>
116
+ {context['episode_content']}
117
+ </CURRENT MESSAGE>
118
+
119
+
120
+ Each of the following ENTITIES were extracted from the CURRENT MESSAGE.
121
+ Each entity in ENTITIES is represented as a JSON object with the following structure:
122
+ {{
123
+ id: integer id of the entity,
124
+ name: "name of the entity",
125
+ entity_type: "ontological classification of the entity",
126
+ entity_type_description: "Description of what the entity type represents",
127
+ duplication_candidates: [
128
+ {{
129
+ idx: integer index of the candidate entity,
130
+ name: "name of the candidate entity",
131
+ entity_type: "ontological classification of the candidate entity",
132
+ ...<additional attributes>
133
+ }}
134
+ ]
135
+ }}
136
+
137
+ <ENTITIES>
138
+ {json.dumps(context['extracted_nodes'], indent=2)}
139
+ </ENTITIES>
140
+
141
+ For each of the above ENTITIES, determine if the entity is a duplicate of any of its duplication candidates.
142
+
143
+ Entities should only be considered duplicates if they refer to the *same real-world object or concept*.
144
+
145
+ Do NOT mark entities as duplicates if:
146
+ - They are related but distinct.
147
+ - They have similar names or purposes but refer to separate instances or concepts.
148
+
149
+ Task:
150
+ Your response will be a list called entity_resolutions which contains one entry for each entity.
151
+
152
+ For each entity, return the id of the entity as id, the name of the entity as name, and the duplicate_idx
153
+ as an integer.
154
+
155
+ - If an entity is a duplicate of one of its duplication_candidates, return the idx of the candidate it is a
156
+ duplicate of.
157
+ - If an entity is not a duplicate of one of its duplication candidates, return the -1 as the duplication_idx
158
+ """,
159
+ ),
160
+ ]
161
+
162
+
92
163
  def node_list(context: dict[str, Any]) -> list[Message]:
93
164
  return [
94
165
  Message(
@@ -126,4 +197,4 @@ def node_list(context: dict[str, Any]) -> list[Message]:
126
197
  ]
127
198
 
128
199
 
129
- versions: Versions = {'node': node, 'node_list': node_list}
200
+ versions: Versions = {'node': node, 'node_list': node_list, 'nodes': nodes}
@@ -48,11 +48,13 @@ class MissingFacts(BaseModel):
48
48
  class Prompt(Protocol):
49
49
  edge: PromptVersion
50
50
  reflexion: PromptVersion
51
+ extract_attributes: PromptVersion
51
52
 
52
53
 
53
54
  class Versions(TypedDict):
54
55
  edge: PromptFunction
55
56
  reflexion: PromptFunction
57
+ extract_attributes: PromptFunction
56
58
 
57
59
 
58
60
  def edge(context: dict[str, Any]) -> list[Message]:
@@ -82,12 +84,18 @@ def edge(context: dict[str, Any]) -> list[Message]:
82
84
  {context['reference_time']} # ISO 8601 (UTC); used to resolve relative time mentions
83
85
  </REFERENCE_TIME>
84
86
 
87
+ <FACT TYPES>
88
+ {context['edge_types']}
89
+ </FACT TYPES>
90
+
85
91
  # TASK
86
92
  Extract all factual relationships between the given ENTITIES based on the CURRENT MESSAGE.
87
93
  Only extract facts that:
88
94
  - involve two DISTINCT ENTITIES from the ENTITIES list,
89
95
  - are clearly stated or unambiguously implied in the CURRENT MESSAGE,
90
- - and can be represented as edges in a knowledge graph.
96
+ and can be represented as edges in a knowledge graph.
97
+ - The FACT TYPES provide a list of the most important types of facts, make sure to extract any facts that
98
+ could be classified into one of the provided fact types
91
99
 
92
100
  You may use information from the PREVIOUS MESSAGES only to disambiguate references or support continuity.
93
101
 
@@ -145,4 +153,40 @@ determine if any facts haven't been extracted.
145
153
  ]
146
154
 
147
155
 
148
- versions: Versions = {'edge': edge, 'reflexion': reflexion}
156
+ def extract_attributes(context: dict[str, Any]) -> list[Message]:
157
+ return [
158
+ Message(
159
+ role='system',
160
+ content='You are a helpful assistant that extracts fact properties from the provided text.',
161
+ ),
162
+ Message(
163
+ role='user',
164
+ content=f"""
165
+
166
+ <MESSAGE>
167
+ {json.dumps(context['episode_content'], indent=2)}
168
+ </MESSAGE>
169
+ <REFERENCE TIME>
170
+ {context['reference_time']}
171
+ </REFERENCE TIME>
172
+
173
+ Given the above MESSAGE, its REFERENCE TIME, and the following FACT, update any of its attributes based on the information provided
174
+ in MESSAGE. Use the provided attribute descriptions to better understand how each attribute should be determined.
175
+
176
+ Guidelines:
177
+ 1. Do not hallucinate entity property values if they cannot be found in the current context.
178
+ 2. Only use the provided MESSAGES and FACT to set attribute values.
179
+
180
+ <FACT>
181
+ {context['fact']}
182
+ </FACT>
183
+ """,
184
+ ),
185
+ ]
186
+
187
+
188
+ versions: Versions = {
189
+ 'edge': edge,
190
+ 'reflexion': reflexion,
191
+ 'extract_attributes': extract_attributes,
192
+ }
@@ -24,7 +24,7 @@ from .models import Message, PromptFunction, PromptVersion
24
24
  class InvalidatedEdges(BaseModel):
25
25
  contradicted_facts: list[int] = Field(
26
26
  ...,
27
- description='List of ids of facts that be should invalidated. If no facts should be invalidated, the list should be empty.',
27
+ description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.',
28
28
  )
29
29
 
30
30