graphiti-core 0.11.6rc9__tar.gz → 0.12.0rc1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. See the registry's advisory for this release for more details.

Files changed (66)
  1. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/PKG-INFO +1 -1
  2. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/edges.py +42 -16
  3. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/embedder/gemini.py +14 -3
  4. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/graphiti.py +33 -10
  5. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/helpers.py +0 -1
  6. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/gemini_client.py +4 -1
  7. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/models/edges/edge_db_queries.py +2 -4
  8. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/dedupe_edges.py +52 -1
  9. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/dedupe_nodes.py +75 -4
  10. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/extract_edges.py +46 -2
  11. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/invalidate_edges.py +1 -1
  12. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/search/search_utils.py +14 -9
  13. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/utils/bulk_utils.py +19 -1
  14. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/utils/maintenance/edge_operations.py +137 -10
  15. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/utils/maintenance/node_operations.py +58 -20
  16. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/pyproject.toml +1 -1
  17. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/LICENSE +0 -0
  18. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/README.md +0 -0
  19. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/__init__.py +0 -0
  20. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/cross_encoder/__init__.py +0 -0
  21. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/cross_encoder/bge_reranker_client.py +0 -0
  22. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/cross_encoder/client.py +0 -0
  23. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/cross_encoder/openai_reranker_client.py +0 -0
  24. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/embedder/__init__.py +0 -0
  25. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/embedder/client.py +0 -0
  26. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/embedder/openai.py +0 -0
  27. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/embedder/voyage.py +0 -0
  28. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/errors.py +0 -0
  29. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/graphiti_types.py +0 -0
  30. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/__init__.py +0 -0
  31. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/anthropic_client.py +0 -0
  32. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/client.py +0 -0
  33. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/config.py +0 -0
  34. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/errors.py +0 -0
  35. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/groq_client.py +0 -0
  36. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/openai_client.py +0 -0
  37. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/openai_generic_client.py +0 -0
  38. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/llm_client/utils.py +0 -0
  39. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/models/__init__.py +0 -0
  40. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/models/edges/__init__.py +0 -0
  41. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/models/nodes/__init__.py +0 -0
  42. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/models/nodes/node_db_queries.py +0 -0
  43. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/nodes.py +0 -0
  44. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/__init__.py +0 -0
  45. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/eval.py +0 -0
  46. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/extract_edge_dates.py +0 -0
  47. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/extract_nodes.py +0 -0
  48. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/lib.py +0 -0
  49. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/models.py +0 -0
  50. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/prompt_helpers.py +0 -0
  51. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/prompts/summarize_nodes.py +0 -0
  52. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/py.typed +0 -0
  53. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/search/__init__.py +0 -0
  54. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/search/search.py +0 -0
  55. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/search/search_config.py +0 -0
  56. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/search/search_config_recipes.py +0 -0
  57. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/search/search_filters.py +0 -0
  58. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/search/search_helpers.py +0 -0
  59. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/utils/__init__.py +0 -0
  60. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/utils/datetime_utils.py +0 -0
  61. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/utils/maintenance/__init__.py +0 -0
  62. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/utils/maintenance/community_operations.py +0 -0
  63. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/utils/maintenance/graph_data_operations.py +0 -0
  64. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/utils/maintenance/temporal_operations.py +0 -0
  65. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/utils/maintenance/utils.py +0 -0
  66. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0rc1}/graphiti_core/utils/ontology_utils/entity_types_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: graphiti-core
3
- Version: 0.11.6rc9
3
+ Version: 0.12.0rc1
4
4
  Summary: A temporal graph building library
5
5
  License: Apache-2.0
6
6
  Author: Paul Paliychuk
@@ -49,7 +49,9 @@ ENTITY_EDGE_RETURN: LiteralString = """
49
49
  e.episodes AS episodes,
50
50
  e.expired_at AS expired_at,
51
51
  e.valid_at AS valid_at,
52
- e.invalid_at AS invalid_at"""
52
+ e.invalid_at AS invalid_at,
53
+ properties(e) AS attributes
54
+ """
53
55
 
54
56
 
55
57
  class Edge(BaseModel, ABC):
@@ -209,6 +211,9 @@ class EntityEdge(Edge):
209
211
  invalid_at: datetime | None = Field(
210
212
  default=None, description='datetime of when the fact stopped being true'
211
213
  )
214
+ attributes: dict[str, Any] = Field(
215
+ default={}, description='Additional attributes of the edge. Dependent on edge name'
216
+ )
212
217
 
213
218
  async def generate_embedding(self, embedder: EmbedderClient):
214
219
  start = time()
@@ -236,20 +241,26 @@ class EntityEdge(Edge):
236
241
  self.fact_embedding = records[0]['fact_embedding']
237
242
 
238
243
  async def save(self, driver: AsyncDriver):
244
+ edge_data: dict[str, Any] = {
245
+ 'source_uuid': self.source_node_uuid,
246
+ 'target_uuid': self.target_node_uuid,
247
+ 'uuid': self.uuid,
248
+ 'name': self.name,
249
+ 'group_id': self.group_id,
250
+ 'fact': self.fact,
251
+ 'fact_embedding': self.fact_embedding,
252
+ 'episodes': self.episodes,
253
+ 'created_at': self.created_at,
254
+ 'expired_at': self.expired_at,
255
+ 'valid_at': self.valid_at,
256
+ 'invalid_at': self.invalid_at,
257
+ }
258
+
259
+ edge_data.update(self.attributes or {})
260
+
239
261
  result = await driver.execute_query(
240
262
  ENTITY_EDGE_SAVE,
241
- source_uuid=self.source_node_uuid,
242
- target_uuid=self.target_node_uuid,
243
- uuid=self.uuid,
244
- name=self.name,
245
- group_id=self.group_id,
246
- fact=self.fact,
247
- fact_embedding=self.fact_embedding,
248
- episodes=self.episodes,
249
- created_at=self.created_at,
250
- expired_at=self.expired_at,
251
- valid_at=self.valid_at,
252
- invalid_at=self.invalid_at,
263
+ edge_data=edge_data,
253
264
  database_=DEFAULT_DATABASE,
254
265
  )
255
266
 
@@ -334,8 +345,8 @@ class EntityEdge(Edge):
334
345
  async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
335
346
  query: LiteralString = (
336
347
  """
337
- MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
338
- """
348
+ MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
349
+ """
339
350
  + ENTITY_EDGE_RETURN
340
351
  )
341
352
  records, _, _ = await driver.execute_query(
@@ -457,7 +468,7 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
457
468
 
458
469
 
459
470
  def get_entity_edge_from_record(record: Any) -> EntityEdge:
460
- return EntityEdge(
471
+ edge = EntityEdge(
461
472
  uuid=record['uuid'],
462
473
  source_node_uuid=record['source_node_uuid'],
463
474
  target_node_uuid=record['target_node_uuid'],
@@ -469,8 +480,23 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
469
480
  expired_at=parse_db_date(record['expired_at']),
470
481
  valid_at=parse_db_date(record['valid_at']),
471
482
  invalid_at=parse_db_date(record['invalid_at']),
483
+ attributes=record['attributes'],
472
484
  )
473
485
 
486
+ edge.attributes.pop('uuid', None)
487
+ edge.attributes.pop('source_node_uuid', None)
488
+ edge.attributes.pop('target_node_uuid', None)
489
+ edge.attributes.pop('fact', None)
490
+ edge.attributes.pop('name', None)
491
+ edge.attributes.pop('group_id', None)
492
+ edge.attributes.pop('episodes', None)
493
+ edge.attributes.pop('created_at', None)
494
+ edge.attributes.pop('expired_at', None)
495
+ edge.attributes.pop('valid_at', None)
496
+ edge.attributes.pop('invalid_at', None)
497
+
498
+ return edge
499
+
474
500
 
475
501
  def get_community_edge_from_record(record: Any):
476
502
  return CommunityEdge(
@@ -61,18 +61,29 @@ class GeminiEmbedder(EmbedderClient):
61
61
  # Generate embeddings
62
62
  result = await self.client.aio.models.embed_content(
63
63
  model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
64
- contents=[input_data],
64
+ contents=[input_data], # type: ignore[arg-type] # mypy fails on broad union type
65
65
  config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim),
66
66
  )
67
67
 
68
+ if not result.embeddings or len(result.embeddings) == 0 or not result.embeddings[0].values:
69
+ raise ValueError('No embeddings returned from Gemini API in create()')
70
+
68
71
  return result.embeddings[0].values
69
72
 
70
73
  async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
71
74
  # Generate embeddings
72
75
  result = await self.client.aio.models.embed_content(
73
76
  model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
74
- contents=input_data_list,
77
+ contents=input_data_list, # type: ignore[arg-type] # mypy fails on broad union type
75
78
  config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim),
76
79
  )
77
80
 
78
- return [embedding.values for embedding in result.embeddings]
81
+ if not result.embeddings or len(result.embeddings) == 0:
82
+ raise Exception('No embeddings returned')
83
+
84
+ embeddings = []
85
+ for embedding in result.embeddings:
86
+ if not embedding.values:
87
+ raise ValueError('Empty embedding values returned')
88
+ embeddings.append(embedding.values)
89
+ return embeddings
@@ -41,6 +41,7 @@ from graphiti_core.search.search_config_recipes import (
41
41
  from graphiti_core.search.search_filters import SearchFilters
42
42
  from graphiti_core.search.search_utils import (
43
43
  RELEVANT_SCHEMA_LIMIT,
44
+ get_edge_invalidation_candidates,
44
45
  get_mentioned_nodes,
45
46
  get_relevant_edges,
46
47
  )
@@ -62,9 +63,8 @@ from graphiti_core.utils.maintenance.community_operations import (
62
63
  )
63
64
  from graphiti_core.utils.maintenance.edge_operations import (
64
65
  build_episodic_edges,
65
- dedupe_extracted_edge,
66
66
  extract_edges,
67
- resolve_edge_contradictions,
67
+ resolve_extracted_edge,
68
68
  resolve_extracted_edges,
69
69
  )
70
70
  from graphiti_core.utils.maintenance.graph_data_operations import (
@@ -77,7 +77,6 @@ from graphiti_core.utils.maintenance.node_operations import (
77
77
  extract_nodes,
78
78
  resolve_extracted_nodes,
79
79
  )
80
- from graphiti_core.utils.maintenance.temporal_operations import get_edge_contradictions
81
80
  from graphiti_core.utils.ontology_utils.entity_types_utils import validate_entity_types
82
81
 
83
82
  logger = logging.getLogger(__name__)
@@ -274,6 +273,8 @@ class Graphiti:
274
273
  update_communities: bool = False,
275
274
  entity_types: dict[str, BaseModel] | None = None,
276
275
  previous_episode_uuids: list[str] | None = None,
276
+ edge_types: dict[str, BaseModel] | None = None,
277
+ edge_type_map: dict[tuple[str, str], list[str]] | None = None,
277
278
  ) -> AddEpisodeResults:
278
279
  """
279
280
  Process an episode and update the graph.
@@ -356,6 +357,13 @@ class Graphiti:
356
357
  )
357
358
  )
358
359
 
360
+ # Create default edge type map
361
+ edge_type_map_default = (
362
+ {('Entity', 'Entity'): list(edge_types.keys())}
363
+ if edge_types is not None
364
+ else {('Entity', 'Entity'): []}
365
+ )
366
+
359
367
  # Extract entities as nodes
360
368
 
361
369
  extracted_nodes = await extract_nodes(
@@ -371,7 +379,9 @@ class Graphiti:
371
379
  previous_episodes,
372
380
  entity_types,
373
381
  ),
374
- extract_edges(self.clients, episode, extracted_nodes, previous_episodes, group_id),
382
+ extract_edges(
383
+ self.clients, episode, extracted_nodes, previous_episodes, group_id, edge_types
384
+ ),
375
385
  )
376
386
 
377
387
  edges = resolve_edge_pointers(extracted_edges, uuid_map)
@@ -381,6 +391,9 @@ class Graphiti:
381
391
  self.clients,
382
392
  edges,
383
393
  episode,
394
+ nodes,
395
+ edge_types or {},
396
+ edge_type_map or edge_type_map_default,
384
397
  ),
385
398
  extract_attributes_from_nodes(
386
399
  self.clients, nodes, episode, previous_episodes, entity_types
@@ -681,17 +694,27 @@ class Graphiti:
681
694
 
682
695
  updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
683
696
 
684
- related_edges = await get_relevant_edges(self.driver, [updated_edge], SearchFilters(), 0.8)
697
+ related_edges = (await get_relevant_edges(self.driver, [updated_edge], SearchFilters()))[0]
698
+ existing_edges = (
699
+ await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters())
700
+ )[0]
685
701
 
686
- resolved_edge = await dedupe_extracted_edge(
702
+ resolved_edge, invalidated_edges = await resolve_extracted_edge(
687
703
  self.llm_client,
688
704
  updated_edge,
689
- related_edges[0],
705
+ related_edges,
706
+ existing_edges,
707
+ EpisodicNode(
708
+ name='',
709
+ source=EpisodeType.text,
710
+ source_description='',
711
+ content='',
712
+ valid_at=edge.valid_at or utc_now(),
713
+ entity_edges=[],
714
+ group_id=edge.group_id,
715
+ ),
690
716
  )
691
717
 
692
- contradicting_edges = await get_edge_contradictions(self.llm_client, edge, related_edges[0])
693
- invalidated_edges = resolve_edge_contradictions(resolved_edge, contradicting_edges)
694
-
695
718
  await add_nodes_and_edges_bulk(
696
719
  self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges, self.embedder
697
720
  )
@@ -18,7 +18,6 @@ import asyncio
18
18
  import os
19
19
  from collections.abc import Coroutine
20
20
  from datetime import datetime
21
- from typing import Any
22
21
 
23
22
  import numpy as np
24
23
  from dotenv import load_dotenv
@@ -139,13 +139,16 @@ class GeminiClient(LLMClient):
139
139
  # Generate content using the simple string approach
140
140
  response = await self.client.aio.models.generate_content(
141
141
  model=self.model or DEFAULT_MODEL,
142
- contents=gemini_messages,
142
+ contents=gemini_messages, # type: ignore[arg-type] # mypy fails on broad union type
143
143
  config=generation_config,
144
144
  )
145
145
 
146
146
  # If this was a structured output request, parse the response into the Pydantic model
147
147
  if response_model is not None:
148
148
  try:
149
+ if not response.text:
150
+ raise ValueError('No response text')
151
+
149
152
  validated_model = response_model.model_validate(json.loads(response.text))
150
153
 
151
154
  # Return as a dictionary for API consistency
@@ -34,8 +34,7 @@ ENTITY_EDGE_SAVE = """
34
34
  MATCH (source:Entity {uuid: $source_uuid})
35
35
  MATCH (target:Entity {uuid: $target_uuid})
36
36
  MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
37
- SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, episodes: $episodes,
38
- created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at}
37
+ SET r = $edge_data
39
38
  WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
40
39
  RETURN r.uuid AS uuid"""
41
40
 
@@ -44,8 +43,7 @@ ENTITY_EDGE_SAVE_BULK = """
44
43
  MATCH (source:Entity {uuid: edge.source_node_uuid})
45
44
  MATCH (target:Entity {uuid: edge.target_node_uuid})
46
45
  MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
47
- SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
48
- created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at}
46
+ SET r = edge
49
47
  WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding)
50
48
  RETURN edge.uuid AS uuid
51
49
  """
@@ -27,6 +27,11 @@ class EdgeDuplicate(BaseModel):
27
27
  ...,
28
28
  description='id of the duplicate fact. If no duplicate facts are found, default to -1.',
29
29
  )
30
+ contradicted_facts: list[int] = Field(
31
+ ...,
32
+ description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.',
33
+ )
34
+ fact_type: str = Field(..., description='One of the provided fact types or DEFAULT')
30
35
 
31
36
 
32
37
  class UniqueFact(BaseModel):
@@ -41,11 +46,13 @@ class UniqueFacts(BaseModel):
41
46
  class Prompt(Protocol):
42
47
  edge: PromptVersion
43
48
  edge_list: PromptVersion
49
+ resolve_edge: PromptVersion
44
50
 
45
51
 
46
52
  class Versions(TypedDict):
47
53
  edge: PromptFunction
48
54
  edge_list: PromptFunction
55
+ resolve_edge: PromptFunction
49
56
 
50
57
 
51
58
  def edge(context: dict[str, Any]) -> list[Message]:
@@ -106,4 +113,48 @@ def edge_list(context: dict[str, Any]) -> list[Message]:
106
113
  ]
107
114
 
108
115
 
109
- versions: Versions = {'edge': edge, 'edge_list': edge_list}
116
+ def resolve_edge(context: dict[str, Any]) -> list[Message]:
117
+ return [
118
+ Message(
119
+ role='system',
120
+ content='You are a helpful assistant that de-duplicates facts from fact lists and determines which existing '
121
+ 'facts are contradicted by the new fact.',
122
+ ),
123
+ Message(
124
+ role='user',
125
+ content=f"""
126
+ <NEW FACT>
127
+ {context['new_edge']}
128
+ </NEW FACT>
129
+
130
+ <EXISTING FACTS>
131
+ {context['existing_edges']}
132
+ </EXISTING FACTS>
133
+ <FACT INVALIDATION CANDIDATES>
134
+ {context['edge_invalidation_candidates']}
135
+ </FACT INVALIDATION CANDIDATES>
136
+
137
+ <FACT TYPES>
138
+ {context['edge_types']}
139
+ </FACT TYPES>
140
+
141
+
142
+ Task:
143
+ If the NEW FACT represents the same factual information as any fact in EXISTING FACTS, return the idx of the duplicate fact.
144
+ If the NEW FACT is not a duplicate of any of the EXISTING FACTS, return -1.
145
+
146
+ Given the predefined FACT TYPES, determine if the NEW FACT should be classified as one of these types.
147
+ Return the fact type as fact_type or DEFAULT if NEW FACT is not one of the FACT TYPES.
148
+
149
+ Based on the provided FACT INVALIDATION CANDIDATES and NEW FACT, determine which existing facts the new fact contradicts.
150
+ Return a list containing all idx's of the facts that are contradicted by the NEW FACT.
151
+ If there are no contradicted facts, return an empty list.
152
+
153
+ Guidelines:
154
+ 1. The facts do not need to be completely identical to be duplicates, they just need to express the same information.
155
+ """,
156
+ ),
157
+ ]
158
+
159
+
160
+ versions: Versions = {'edge': edge, 'edge_list': edge_list, 'resolve_edge': resolve_edge}
@@ -23,21 +23,31 @@ from .models import Message, PromptFunction, PromptVersion
23
23
 
24
24
 
25
25
  class NodeDuplicate(BaseModel):
26
- duplicate_node_id: int = Field(
26
+ id: int = Field(..., description='integer id of the entity')
27
+ duplicate_idx: int = Field(
27
28
  ...,
28
- description='id of the duplicate node. If no duplicate nodes are found, default to -1.',
29
+ description='idx of the duplicate node. If no duplicate nodes are found, default to -1.',
29
30
  )
30
- name: str = Field(..., description='Name of the entity.')
31
+ name: str = Field(
32
+ ...,
33
+ description='Name of the entity. Should be the most complete and descriptive name possible.',
34
+ )
35
+
36
+
37
+ class NodeResolutions(BaseModel):
38
+ entity_resolutions: list[NodeDuplicate] = Field(..., description='List of resolved nodes')
31
39
 
32
40
 
33
41
  class Prompt(Protocol):
34
42
  node: PromptVersion
35
43
  node_list: PromptVersion
44
+ nodes: PromptVersion
36
45
 
37
46
 
38
47
  class Versions(TypedDict):
39
48
  node: PromptFunction
40
49
  node_list: PromptFunction
50
+ nodes: PromptFunction
41
51
 
42
52
 
43
53
  def node(context: dict[str, Any]) -> list[Message]:
@@ -89,6 +99,67 @@ def node(context: dict[str, Any]) -> list[Message]:
89
99
  ]
90
100
 
91
101
 
102
+ def nodes(context: dict[str, Any]) -> list[Message]:
103
+ return [
104
+ Message(
105
+ role='system',
106
+ content='You are a helpful assistant that determines whether or not ENTITIES extracted from a conversation are duplicates'
107
+ 'of existing entities.',
108
+ ),
109
+ Message(
110
+ role='user',
111
+ content=f"""
112
+ <PREVIOUS MESSAGES>
113
+ {json.dumps([ep for ep in context['previous_episodes']], indent=2)}
114
+ </PREVIOUS MESSAGES>
115
+ <CURRENT MESSAGE>
116
+ {context['episode_content']}
117
+ </CURRENT MESSAGE>
118
+
119
+
120
+ Each of the following ENTITIES were extracted from the CURRENT MESSAGE.
121
+ Each entity in ENTITIES is represented as a JSON object with the following structure:
122
+ {{
123
+ id: integer id of the entity,
124
+ name: "name of the entity",
125
+ entity_type: "ontological classification of the entity",
126
+ entity_type_description: "Description of what the entity type represents",
127
+ duplication_candidates: [
128
+ {{
129
+ idx: integer index of the candidate entity,
130
+ name: "name of the candidate entity",
131
+ entity_type: "ontological classification of the candidate entity",
132
+ ...<additional attributes>
133
+ }}
134
+ ]
135
+ }}
136
+
137
+ <ENTITIES>
138
+ {json.dumps(context['extracted_nodes'], indent=2)}
139
+ </ENTITIES>
140
+
141
+ For each of the above ENTITIES, determine if the entity is a duplicate of any of its duplication candidates.
142
+
143
+ Entities should only be considered duplicates if they refer to the *same real-world object or concept*.
144
+
145
+ Do NOT mark entities as duplicates if:
146
+ - They are related but distinct.
147
+ - They have similar names or purposes but refer to separate instances or concepts.
148
+
149
+ Task:
150
+ Your response will be a list called entity_resolutions which contains one entry for each entity.
151
+
152
+ For each entity, return the id of the entity as id, the name of the entity as name, and the duplicate_idx
153
+ as an integer.
154
+
155
+ - If an entity is a duplicate of one of its duplication_candidates, return the idx of the candidate it is a
156
+ duplicate of.
157
+ - If an entity is not a duplicate of one of its duplication candidates, return the -1 as the duplication_idx
158
+ """,
159
+ ),
160
+ ]
161
+
162
+
92
163
  def node_list(context: dict[str, Any]) -> list[Message]:
93
164
  return [
94
165
  Message(
@@ -126,4 +197,4 @@ def node_list(context: dict[str, Any]) -> list[Message]:
126
197
  ]
127
198
 
128
199
 
129
- versions: Versions = {'node': node, 'node_list': node_list}
200
+ versions: Versions = {'node': node, 'node_list': node_list, 'nodes': nodes}
@@ -48,11 +48,13 @@ class MissingFacts(BaseModel):
48
48
  class Prompt(Protocol):
49
49
  edge: PromptVersion
50
50
  reflexion: PromptVersion
51
+ extract_attributes: PromptVersion
51
52
 
52
53
 
53
54
  class Versions(TypedDict):
54
55
  edge: PromptFunction
55
56
  reflexion: PromptFunction
57
+ extract_attributes: PromptFunction
56
58
 
57
59
 
58
60
  def edge(context: dict[str, Any]) -> list[Message]:
@@ -82,12 +84,18 @@ def edge(context: dict[str, Any]) -> list[Message]:
82
84
  {context['reference_time']} # ISO 8601 (UTC); used to resolve relative time mentions
83
85
  </REFERENCE_TIME>
84
86
 
87
+ <FACT TYPES>
88
+ {context['edge_types']}
89
+ </FACT TYPES>
90
+
85
91
  # TASK
86
92
  Extract all factual relationships between the given ENTITIES based on the CURRENT MESSAGE.
87
93
  Only extract facts that:
88
94
  - involve two DISTINCT ENTITIES from the ENTITIES list,
89
95
  - are clearly stated or unambiguously implied in the CURRENT MESSAGE,
90
- - and can be represented as edges in a knowledge graph.
96
+ and can be represented as edges in a knowledge graph.
97
+ - The FACT TYPES provide a list of the most important types of facts, make sure to extract any facts that
98
+ could be classified into one of the provided fact types
91
99
 
92
100
  You may use information from the PREVIOUS MESSAGES only to disambiguate references or support continuity.
93
101
 
@@ -145,4 +153,40 @@ determine if any facts haven't been extracted.
145
153
  ]
146
154
 
147
155
 
148
- versions: Versions = {'edge': edge, 'reflexion': reflexion}
156
+ def extract_attributes(context: dict[str, Any]) -> list[Message]:
157
+ return [
158
+ Message(
159
+ role='system',
160
+ content='You are a helpful assistant that extracts fact properties from the provided text.',
161
+ ),
162
+ Message(
163
+ role='user',
164
+ content=f"""
165
+
166
+ <MESSAGE>
167
+ {json.dumps(context['episode_content'], indent=2)}
168
+ </MESSAGE>
169
+ <REFERENCE TIME>
170
+ {context['reference_time']}
171
+ </REFERENCE TIME>
172
+
173
+ Given the above MESSAGE, its REFERENCE TIME, and the following FACT, update any of its attributes based on the information provided
174
+ in MESSAGE. Use the provided attribute descriptions to better understand how each attribute should be determined.
175
+
176
+ Guidelines:
177
+ 1. Do not hallucinate entity property values if they cannot be found in the current context.
178
+ 2. Only use the provided MESSAGES and FACT to set attribute values.
179
+
180
+ <FACT>
181
+ {context['fact']}
182
+ </FACT>
183
+ """,
184
+ ),
185
+ ]
186
+
187
+
188
+ versions: Versions = {
189
+ 'edge': edge,
190
+ 'reflexion': reflexion,
191
+ 'extract_attributes': extract_attributes,
192
+ }
@@ -24,7 +24,7 @@ from .models import Message, PromptFunction, PromptVersion
24
24
  class InvalidatedEdges(BaseModel):
25
25
  contradicted_facts: list[int] = Field(
26
26
  ...,
27
- description='List of ids of facts that be should invalidated. If no facts should be invalidated, the list should be empty.',
27
+ description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.',
28
28
  )
29
29
 
30
30
 
@@ -174,7 +174,8 @@ async def edge_fulltext_search(
174
174
  r.episodes AS episodes,
175
175
  r.expired_at AS expired_at,
176
176
  r.valid_at AS valid_at,
177
- r.invalid_at AS invalid_at
177
+ r.invalid_at AS invalid_at,
178
+ properties(r) AS attributes
178
179
  ORDER BY score DESC LIMIT $limit
179
180
  """
180
181
  )
@@ -243,7 +244,8 @@ async def edge_similarity_search(
243
244
  r.episodes AS episodes,
244
245
  r.expired_at AS expired_at,
245
246
  r.valid_at AS valid_at,
246
- r.invalid_at AS invalid_at
247
+ r.invalid_at AS invalid_at,
248
+ properties(r) AS attributes
247
249
  ORDER BY score DESC
248
250
  LIMIT $limit
249
251
  """
@@ -301,7 +303,8 @@ async def edge_bfs_search(
301
303
  r.episodes AS episodes,
302
304
  r.expired_at AS expired_at,
303
305
  r.valid_at AS valid_at,
304
- r.invalid_at AS invalid_at
306
+ r.invalid_at AS invalid_at,
307
+ properties(r) AS attributes
305
308
  LIMIT $limit
306
309
  """
307
310
  )
@@ -337,10 +340,10 @@ async def node_fulltext_search(
337
340
 
338
341
  query = (
339
342
  """
340
- CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
341
- YIELD node AS n, score
342
- WHERE n:Entity
343
- """
343
+ CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
344
+ YIELD node AS n, score
345
+ WHERE n:Entity
346
+ """
344
347
  + filter_query
345
348
  + ENTITY_NODE_RETURN
346
349
  + """
@@ -771,7 +774,8 @@ async def get_relevant_edges(
771
774
  episodes: e.episodes,
772
775
  expired_at: e.expired_at,
773
776
  valid_at: e.valid_at,
774
- invalid_at: e.invalid_at
777
+ invalid_at: e.invalid_at,
778
+ attributes: properties(e)
775
779
  })[..$limit] AS matches
776
780
  """
777
781
  )
@@ -837,7 +841,8 @@ async def get_edge_invalidation_candidates(
837
841
  episodes: e.episodes,
838
842
  expired_at: e.expired_at,
839
843
  valid_at: e.valid_at,
840
- invalid_at: e.invalid_at
844
+ invalid_at: e.invalid_at,
845
+ attributes: properties(e)
841
846
  })[..$limit] AS matches
842
847
  """
843
848
  )
@@ -137,16 +137,34 @@ async def add_nodes_and_edges_bulk_tx(
137
137
  entity_data['labels'] = list(set(node.labels + ['Entity']))
138
138
  nodes.append(entity_data)
139
139
 
140
+ edges: list[dict[str, Any]] = []
140
141
  for edge in entity_edges:
141
142
  if edge.fact_embedding is None:
142
143
  await edge.generate_embedding(embedder)
144
+ edge_data: dict[str, Any] = {
145
+ 'uuid': edge.uuid,
146
+ 'source_node_uuid': edge.source_node_uuid,
147
+ 'target_node_uuid': edge.target_node_uuid,
148
+ 'name': edge.name,
149
+ 'fact': edge.fact,
150
+ 'fact_embedding': edge.fact_embedding,
151
+ 'group_id': edge.group_id,
152
+ 'episodes': edge.episodes,
153
+ 'created_at': edge.created_at,
154
+ 'expired_at': edge.expired_at,
155
+ 'valid_at': edge.valid_at,
156
+ 'invalid_at': edge.invalid_at,
157
+ }
158
+
159
+ edge_data.update(edge.attributes or {})
160
+ edges.append(edge_data)
143
161
 
144
162
  await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
145
163
  await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
146
164
  await tx.run(
147
165
  EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
148
166
  )
149
- await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[edge.model_dump() for edge in entity_edges])
167
+ await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=edges)
150
168
 
151
169
 
152
170
  async def extract_nodes_and_edges_bulk(
@@ -18,6 +18,8 @@ import logging
18
18
  from datetime import datetime
19
19
  from time import time
20
20
 
21
+ from pydantic import BaseModel
22
+
21
23
  from graphiti_core.edges import (
22
24
  CommunityEdge,
23
25
  EntityEdge,
@@ -35,9 +37,6 @@ from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
35
37
  from graphiti_core.search.search_filters import SearchFilters
36
38
  from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
37
39
  from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
38
- from graphiti_core.utils.maintenance.temporal_operations import (
39
- get_edge_contradictions,
40
- )
41
40
 
42
41
  logger = logging.getLogger(__name__)
43
42
 
@@ -86,6 +85,7 @@ async def extract_edges(
86
85
  nodes: list[EntityNode],
87
86
  previous_episodes: list[EpisodicNode],
88
87
  group_id: str = '',
88
+ edge_types: dict[str, BaseModel] | None = None,
89
89
  ) -> list[EntityEdge]:
90
90
  start = time()
91
91
 
@@ -94,12 +94,25 @@ async def extract_edges(
94
94
 
95
95
  node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
96
96
 
97
+ edge_types_context = (
98
+ [
99
+ {
100
+ 'fact_type_name': type_name,
101
+ 'fact_type_description': type_model.__doc__,
102
+ }
103
+ for type_name, type_model in edge_types.items()
104
+ ]
105
+ if edge_types is not None
106
+ else []
107
+ )
108
+
97
109
  # Prepare context for LLM
98
110
  context = {
99
111
  'episode_content': episode.content,
100
112
  'nodes': [node.name for node in nodes],
101
113
  'previous_episodes': [ep.content for ep in previous_episodes],
102
114
  'reference_time': episode.valid_at,
115
+ 'edge_types': edge_types_context,
103
116
  'custom_prompt': '',
104
117
  }
105
118
 
@@ -236,6 +249,9 @@ async def resolve_extracted_edges(
236
249
  clients: GraphitiClients,
237
250
  extracted_edges: list[EntityEdge],
238
251
  episode: EpisodicNode,
252
+ entities: list[EntityNode],
253
+ edge_types: dict[str, BaseModel],
254
+ edge_type_map: dict[tuple[str, str], list[str]],
239
255
  ) -> tuple[list[EntityEdge], list[EntityEdge]]:
240
256
  driver = clients.driver
241
257
  llm_client = clients.llm_client
@@ -245,7 +261,7 @@ async def resolve_extracted_edges(
245
261
 
246
262
  search_results: tuple[list[list[EntityEdge]], list[list[EntityEdge]]] = await semaphore_gather(
247
263
  get_relevant_edges(driver, extracted_edges, SearchFilters()),
248
- get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters()),
264
+ get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters(), 0.2),
249
265
  )
250
266
 
251
267
  related_edges_lists, edge_invalidation_candidates = search_results
@@ -254,15 +270,50 @@ async def resolve_extracted_edges(
254
270
  f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
255
271
  )
256
272
 
273
+ # Build entity hash table
274
+ uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}
275
+
276
+ # Determine which edge types are relevant for each edge
277
+ edge_types_lst: list[dict[str, BaseModel]] = []
278
+ for extracted_edge in extracted_edges:
279
+ source_node_labels = uuid_entity_map[extracted_edge.source_node_uuid].labels
280
+ target_node_labels = uuid_entity_map[extracted_edge.target_node_uuid].labels
281
+ label_tuples = [
282
+ (source_label, target_label)
283
+ for source_label in source_node_labels
284
+ for target_label in target_node_labels
285
+ ]
286
+
287
+ extracted_edge_types = {}
288
+ for label_tuple in label_tuples:
289
+ type_names = edge_type_map.get(label_tuple, [])
290
+ for type_name in type_names:
291
+ type_model = edge_types.get(type_name)
292
+ if type_model is None:
293
+ continue
294
+
295
+ extracted_edge_types[type_name] = type_model
296
+
297
+ edge_types_lst.append(extracted_edge_types)
298
+
257
299
  # resolve edges with related edges in the graph and find invalidation candidates
258
300
  results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
259
301
  await semaphore_gather(
260
302
  *[
261
303
  resolve_extracted_edge(
262
- llm_client, extracted_edge, related_edges, existing_edges, episode
304
+ llm_client,
305
+ extracted_edge,
306
+ related_edges,
307
+ existing_edges,
308
+ episode,
309
+ extracted_edge_types,
263
310
  )
264
- for extracted_edge, related_edges, existing_edges in zip(
265
- extracted_edges, related_edges_lists, edge_invalidation_candidates, strict=True
311
+ for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
312
+ extracted_edges,
313
+ related_edges_lists,
314
+ edge_invalidation_candidates,
315
+ edge_types_lst,
316
+ strict=True,
266
317
  )
267
318
  ]
268
319
  )
@@ -326,10 +377,86 @@ async def resolve_extracted_edge(
326
377
  related_edges: list[EntityEdge],
327
378
  existing_edges: list[EntityEdge],
328
379
  episode: EpisodicNode,
380
+ edge_types: dict[str, BaseModel] | None = None,
329
381
  ) -> tuple[EntityEdge, list[EntityEdge]]:
330
- resolved_edge, invalidation_candidates = await semaphore_gather(
331
- dedupe_extracted_edge(llm_client, extracted_edge, related_edges, episode),
332
- get_edge_contradictions(llm_client, extracted_edge, existing_edges),
382
+ if len(related_edges) == 0 and len(existing_edges) == 0:
383
+ return extracted_edge, []
384
+
385
+ start = time()
386
+
387
+ # Prepare context for LLM
388
+ related_edges_context = [
389
+ {'id': edge.uuid, 'fact': edge.fact} for i, edge in enumerate(related_edges)
390
+ ]
391
+
392
+ invalidation_edge_candidates_context = [
393
+ {'id': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
394
+ ]
395
+
396
+ edge_types_context = (
397
+ [
398
+ {
399
+ 'fact_type_id': i,
400
+ 'fact_type_name': type_name,
401
+ 'fact_type_description': type_model.__doc__,
402
+ }
403
+ for i, (type_name, type_model) in enumerate(edge_types.items())
404
+ ]
405
+ if edge_types is not None
406
+ else []
407
+ )
408
+
409
+ context = {
410
+ 'existing_edges': related_edges_context,
411
+ 'new_edge': extracted_edge.fact,
412
+ 'edge_invalidation_candidates': invalidation_edge_candidates_context,
413
+ 'edge_types': edge_types_context,
414
+ }
415
+
416
+ llm_response = await llm_client.generate_response(
417
+ prompt_library.dedupe_edges.resolve_edge(context),
418
+ response_model=EdgeDuplicate,
419
+ model_size=ModelSize.small,
420
+ )
421
+
422
+ duplicate_fact_id: int = llm_response.get('duplicate_fact_id', -1)
423
+
424
+ resolved_edge = (
425
+ related_edges[duplicate_fact_id]
426
+ if 0 <= duplicate_fact_id < len(related_edges)
427
+ else extracted_edge
428
+ )
429
+
430
+ if duplicate_fact_id >= 0 and episode is not None:
431
+ resolved_edge.episodes.append(episode.uuid)
432
+
433
+ contradicted_facts: list[int] = llm_response.get('contradicted_facts', [])
434
+
435
+ invalidation_candidates: list[EntityEdge] = [existing_edges[i] for i in contradicted_facts]
436
+
437
+ fact_type: str = str(llm_response.get('fact_type'))
438
+ if fact_type.upper() != 'DEFAULT' and edge_types is not None:
439
+ resolved_edge.name = fact_type
440
+
441
+ edge_attributes_context = {
442
+ 'message': episode.content,
443
+ 'reference_time': episode.valid_at,
444
+ 'fact': resolved_edge.fact,
445
+ }
446
+
447
+ edge_model = edge_types.get(fact_type)
448
+
449
+ edge_attributes_response = await llm_client.generate_response(
450
+ prompt_library.extract_edges.extract_attributes(edge_attributes_context),
451
+ response_model=edge_model, # type: ignore
452
+ model_size=ModelSize.small,
453
+ )
454
+
455
+ resolved_edge.attributes = edge_attributes_response
456
+
457
+ end = time()
458
+ logger.debug(
459
+ f'Resolved Edge: {extracted_edge.name} is {resolved_edge.name}, in {(end - start) * 1000} ms'
333
460
  )
334
461
 
335
462
  now = utc_now()
@@ -29,7 +29,7 @@ from graphiti_core.llm_client import LLMClient
29
29
  from graphiti_core.llm_client.config import ModelSize
30
30
  from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
31
31
  from graphiti_core.prompts import prompt_library
32
- from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
32
+ from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
33
33
  from graphiti_core.prompts.extract_nodes import (
34
34
  ExtractedEntities,
35
35
  ExtractedEntity,
@@ -243,28 +243,65 @@ async def resolve_extracted_nodes(
243
243
 
244
244
  existing_nodes_lists: list[list[EntityNode]] = [result.nodes for result in search_results]
245
245
 
246
- resolved_nodes: list[EntityNode] = await semaphore_gather(
247
- *[
248
- resolve_extracted_node(
249
- llm_client,
250
- extracted_node,
251
- existing_nodes,
252
- episode,
253
- previous_episodes,
254
- entity_types.get(
255
- next((item for item in extracted_node.labels if item != 'Entity'), '')
256
- )
257
- if entity_types is not None
258
- else None,
259
- )
260
- for extracted_node, existing_nodes in zip(
261
- extracted_nodes, existing_nodes_lists, strict=True
262
- )
263
- ]
246
+ entity_types_dict: dict[str, BaseModel] = entity_types if entity_types is not None else {}
247
+
248
+ # Prepare context for LLM
249
+ extracted_nodes_context = [
250
+ {
251
+ 'id': i,
252
+ 'name': node.name,
253
+ 'entity_type': node.labels,
254
+ 'entity_type_description': entity_types_dict.get(
255
+ next((item for item in node.labels if item != 'Entity'), '')
256
+ ).__doc__
257
+ or 'Default Entity Type',
258
+ 'duplication_candidates': [
259
+ {
260
+ **{
261
+ 'idx': j,
262
+ 'name': candidate.name,
263
+ 'entity_types': candidate.labels,
264
+ },
265
+ **candidate.attributes,
266
+ }
267
+ for j, candidate in enumerate(existing_nodes_lists[i])
268
+ ],
269
+ }
270
+ for i, node in enumerate(extracted_nodes)
271
+ ]
272
+
273
+ context = {
274
+ 'extracted_nodes': extracted_nodes_context,
275
+ 'episode_content': episode.content if episode is not None else '',
276
+ 'previous_episodes': [ep.content for ep in previous_episodes]
277
+ if previous_episodes is not None
278
+ else [],
279
+ }
280
+
281
+ llm_response = await llm_client.generate_response(
282
+ prompt_library.dedupe_nodes.nodes(context),
283
+ response_model=NodeResolutions,
264
284
  )
265
285
 
286
+ node_resolutions: list = llm_response.get('entity_resolutions', [])
287
+
288
+ resolved_nodes: list[EntityNode] = []
266
289
  uuid_map: dict[str, str] = {}
267
- for extracted_node, resolved_node in zip(extracted_nodes, resolved_nodes, strict=True):
290
+ for resolution in node_resolutions:
291
+ resolution_id = resolution.get('id', -1)
292
+ duplicate_idx = resolution.get('duplicate_idx', -1)
293
+
294
+ extracted_node = extracted_nodes[resolution_id]
295
+
296
+ resolved_node = (
297
+ existing_nodes_lists[resolution_id][duplicate_idx]
298
+ if 0 <= duplicate_idx < len(existing_nodes_lists[resolution_id])
299
+ else extracted_node
300
+ )
301
+
302
+ resolved_node.name = resolution.get('name')
303
+
304
+ resolved_nodes.append(resolved_node)
268
305
  uuid_map[extracted_node.uuid] = resolved_node.uuid
269
306
 
270
307
  logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}')
@@ -410,6 +447,7 @@ async def extract_attributes_from_node(
410
447
  llm_response = await llm_client.generate_response(
411
448
  prompt_library.extract_nodes.extract_attributes(summary_context),
412
449
  response_model=entity_attributes_model,
450
+ model_size=ModelSize.small,
413
451
  )
414
452
 
415
453
  node.summary = llm_response.get('summary', node.summary)
@@ -1,7 +1,7 @@
1
1
  [project]
2
2
  name = "graphiti-core"
3
3
  description = "A temporal graph building library"
4
- version = "0.11.6pre9"
4
+ version = "0.12.0pre1"
5
5
  authors = [
6
6
  { "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
7
7
  { "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },