graphiti-core 0.11.6rc9__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. See the package's registry page for more details.

Files changed (33):
  1. graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
  2. graphiti_core/driver/__init__.py +17 -0
  3. graphiti_core/driver/driver.py +66 -0
  4. graphiti_core/driver/falkordb_driver.py +132 -0
  5. graphiti_core/driver/neo4j_driver.py +61 -0
  6. graphiti_core/edges.py +66 -40
  7. graphiti_core/embedder/azure_openai.py +64 -0
  8. graphiti_core/embedder/gemini.py +14 -3
  9. graphiti_core/graph_queries.py +149 -0
  10. graphiti_core/graphiti.py +41 -14
  11. graphiti_core/graphiti_types.py +2 -2
  12. graphiti_core/helpers.py +9 -4
  13. graphiti_core/llm_client/__init__.py +16 -0
  14. graphiti_core/llm_client/azure_openai_client.py +73 -0
  15. graphiti_core/llm_client/gemini_client.py +4 -1
  16. graphiti_core/models/edges/edge_db_queries.py +2 -4
  17. graphiti_core/nodes.py +31 -31
  18. graphiti_core/prompts/dedupe_edges.py +52 -1
  19. graphiti_core/prompts/dedupe_nodes.py +79 -4
  20. graphiti_core/prompts/extract_edges.py +50 -5
  21. graphiti_core/prompts/invalidate_edges.py +1 -1
  22. graphiti_core/search/search.py +6 -10
  23. graphiti_core/search/search_filters.py +23 -9
  24. graphiti_core/search/search_utils.py +250 -189
  25. graphiti_core/utils/bulk_utils.py +38 -11
  26. graphiti_core/utils/maintenance/community_operations.py +6 -7
  27. graphiti_core/utils/maintenance/edge_operations.py +149 -19
  28. graphiti_core/utils/maintenance/graph_data_operations.py +13 -42
  29. graphiti_core/utils/maintenance/node_operations.py +52 -71
  30. {graphiti_core-0.11.6rc9.dist-info → graphiti_core-0.12.0.dist-info}/METADATA +14 -5
  31. {graphiti_core-0.11.6rc9.dist-info → graphiti_core-0.12.0.dist-info}/RECORD +33 -26
  32. {graphiti_core-0.11.6rc9.dist-info → graphiti_core-0.12.0.dist-info}/LICENSE +0 -0
  33. {graphiti_core-0.11.6rc9.dist-info → graphiti_core-0.12.0.dist-info}/WHEEL +0 -0
graphiti_core/nodes.py CHANGED
@@ -22,13 +22,13 @@ from time import time
22
22
  from typing import Any
23
23
  from uuid import uuid4
24
24
 
25
- from neo4j import AsyncDriver
26
25
  from pydantic import BaseModel, Field
27
26
  from typing_extensions import LiteralString
28
27
 
28
+ from graphiti_core.driver.driver import GraphDriver
29
29
  from graphiti_core.embedder import EmbedderClient
30
30
  from graphiti_core.errors import NodeNotFoundError
31
- from graphiti_core.helpers import DEFAULT_DATABASE
31
+ from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
32
32
  from graphiti_core.models.nodes.node_db_queries import (
33
33
  COMMUNITY_NODE_SAVE,
34
34
  ENTITY_NODE_SAVE,
@@ -94,9 +94,9 @@ class Node(BaseModel, ABC):
94
94
  created_at: datetime = Field(default_factory=lambda: utc_now())
95
95
 
96
96
  @abstractmethod
97
- async def save(self, driver: AsyncDriver): ...
97
+ async def save(self, driver: GraphDriver): ...
98
98
 
99
- async def delete(self, driver: AsyncDriver):
99
+ async def delete(self, driver: GraphDriver):
100
100
  result = await driver.execute_query(
101
101
  """
102
102
  MATCH (n:Entity|Episodic|Community {uuid: $uuid})
@@ -119,7 +119,7 @@ class Node(BaseModel, ABC):
119
119
  return False
120
120
 
121
121
  @classmethod
122
- async def delete_by_group_id(cls, driver: AsyncDriver, group_id: str):
122
+ async def delete_by_group_id(cls, driver: GraphDriver, group_id: str):
123
123
  await driver.execute_query(
124
124
  """
125
125
  MATCH (n:Entity|Episodic|Community {group_id: $group_id})
@@ -132,10 +132,10 @@ class Node(BaseModel, ABC):
132
132
  return 'SUCCESS'
133
133
 
134
134
  @classmethod
135
- async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ...
135
+ async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
136
136
 
137
137
  @classmethod
138
- async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]): ...
138
+ async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): ...
139
139
 
140
140
 
141
141
  class EpisodicNode(Node):
@@ -150,7 +150,7 @@ class EpisodicNode(Node):
150
150
  default_factory=list,
151
151
  )
152
152
 
153
- async def save(self, driver: AsyncDriver):
153
+ async def save(self, driver: GraphDriver):
154
154
  result = await driver.execute_query(
155
155
  EPISODIC_NODE_SAVE,
156
156
  uuid=self.uuid,
@@ -165,12 +165,12 @@ class EpisodicNode(Node):
165
165
  database_=DEFAULT_DATABASE,
166
166
  )
167
167
 
168
- logger.debug(f'Saved Node to neo4j: {self.uuid}')
168
+ logger.debug(f'Saved Node to Graph: {self.uuid}')
169
169
 
170
170
  return result
171
171
 
172
172
  @classmethod
173
- async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
173
+ async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
174
174
  records, _, _ = await driver.execute_query(
175
175
  """
176
176
  MATCH (e:Episodic {uuid: $uuid})
@@ -197,7 +197,7 @@ class EpisodicNode(Node):
197
197
  return episodes[0]
198
198
 
199
199
  @classmethod
200
- async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
200
+ async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
201
201
  records, _, _ = await driver.execute_query(
202
202
  """
203
203
  MATCH (e:Episodic) WHERE e.uuid IN $uuids
@@ -224,7 +224,7 @@ class EpisodicNode(Node):
224
224
  @classmethod
225
225
  async def get_by_group_ids(
226
226
  cls,
227
- driver: AsyncDriver,
227
+ driver: GraphDriver,
228
228
  group_ids: list[str],
229
229
  limit: int | None = None,
230
230
  uuid_cursor: str | None = None,
@@ -263,7 +263,7 @@ class EpisodicNode(Node):
263
263
  return episodes
264
264
 
265
265
  @classmethod
266
- async def get_by_entity_node_uuid(cls, driver: AsyncDriver, entity_node_uuid: str):
266
+ async def get_by_entity_node_uuid(cls, driver: GraphDriver, entity_node_uuid: str):
267
267
  records, _, _ = await driver.execute_query(
268
268
  """
269
269
  MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
@@ -304,7 +304,7 @@ class EntityNode(Node):
304
304
 
305
305
  return self.name_embedding
306
306
 
307
- async def load_name_embedding(self, driver: AsyncDriver):
307
+ async def load_name_embedding(self, driver: GraphDriver):
308
308
  query: LiteralString = """
309
309
  MATCH (n:Entity {uuid: $uuid})
310
310
  RETURN n.name_embedding AS name_embedding
@@ -318,7 +318,7 @@ class EntityNode(Node):
318
318
 
319
319
  self.name_embedding = records[0]['name_embedding']
320
320
 
321
- async def save(self, driver: AsyncDriver):
321
+ async def save(self, driver: GraphDriver):
322
322
  entity_data: dict[str, Any] = {
323
323
  'uuid': self.uuid,
324
324
  'name': self.name,
@@ -337,16 +337,16 @@ class EntityNode(Node):
337
337
  database_=DEFAULT_DATABASE,
338
338
  )
339
339
 
340
- logger.debug(f'Saved Node to neo4j: {self.uuid}')
340
+ logger.debug(f'Saved Node to Graph: {self.uuid}')
341
341
 
342
342
  return result
343
343
 
344
344
  @classmethod
345
- async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
345
+ async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
346
346
  query = (
347
347
  """
348
- MATCH (n:Entity {uuid: $uuid})
349
- """
348
+ MATCH (n:Entity {uuid: $uuid})
349
+ """
350
350
  + ENTITY_NODE_RETURN
351
351
  )
352
352
  records, _, _ = await driver.execute_query(
@@ -364,7 +364,7 @@ class EntityNode(Node):
364
364
  return nodes[0]
365
365
 
366
366
  @classmethod
367
- async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
367
+ async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
368
368
  records, _, _ = await driver.execute_query(
369
369
  """
370
370
  MATCH (n:Entity) WHERE n.uuid IN $uuids
@@ -382,7 +382,7 @@ class EntityNode(Node):
382
382
  @classmethod
383
383
  async def get_by_group_ids(
384
384
  cls,
385
- driver: AsyncDriver,
385
+ driver: GraphDriver,
386
386
  group_ids: list[str],
387
387
  limit: int | None = None,
388
388
  uuid_cursor: str | None = None,
@@ -416,7 +416,7 @@ class CommunityNode(Node):
416
416
  name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
417
417
  summary: str = Field(description='region summary of member nodes', default_factory=str)
418
418
 
419
- async def save(self, driver: AsyncDriver):
419
+ async def save(self, driver: GraphDriver):
420
420
  result = await driver.execute_query(
421
421
  COMMUNITY_NODE_SAVE,
422
422
  uuid=self.uuid,
@@ -428,7 +428,7 @@ class CommunityNode(Node):
428
428
  database_=DEFAULT_DATABASE,
429
429
  )
430
430
 
431
- logger.debug(f'Saved Node to neo4j: {self.uuid}')
431
+ logger.debug(f'Saved Node to Graph: {self.uuid}')
432
432
 
433
433
  return result
434
434
 
@@ -441,7 +441,7 @@ class CommunityNode(Node):
441
441
 
442
442
  return self.name_embedding
443
443
 
444
- async def load_name_embedding(self, driver: AsyncDriver):
444
+ async def load_name_embedding(self, driver: GraphDriver):
445
445
  query: LiteralString = """
446
446
  MATCH (c:Community {uuid: $uuid})
447
447
  RETURN c.name_embedding AS name_embedding
@@ -456,7 +456,7 @@ class CommunityNode(Node):
456
456
  self.name_embedding = records[0]['name_embedding']
457
457
 
458
458
  @classmethod
459
- async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
459
+ async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
460
460
  records, _, _ = await driver.execute_query(
461
461
  """
462
462
  MATCH (n:Community {uuid: $uuid})
@@ -480,7 +480,7 @@ class CommunityNode(Node):
480
480
  return nodes[0]
481
481
 
482
482
  @classmethod
483
- async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
483
+ async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
484
484
  records, _, _ = await driver.execute_query(
485
485
  """
486
486
  MATCH (n:Community) WHERE n.uuid IN $uuids
@@ -503,7 +503,7 @@ class CommunityNode(Node):
503
503
  @classmethod
504
504
  async def get_by_group_ids(
505
505
  cls,
506
- driver: AsyncDriver,
506
+ driver: GraphDriver,
507
507
  group_ids: list[str],
508
508
  limit: int | None = None,
509
509
  uuid_cursor: str | None = None,
@@ -542,8 +542,8 @@ class CommunityNode(Node):
542
542
  def get_episodic_node_from_record(record: Any) -> EpisodicNode:
543
543
  return EpisodicNode(
544
544
  content=record['content'],
545
- created_at=record['created_at'].to_native().timestamp(),
546
- valid_at=(record['valid_at'].to_native()),
545
+ created_at=parse_db_date(record['created_at']), # type: ignore
546
+ valid_at=parse_db_date(record['valid_at']), # type: ignore
547
547
  uuid=record['uuid'],
548
548
  group_id=record['group_id'],
549
549
  source=EpisodeType.from_str(record['source']),
@@ -559,7 +559,7 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
559
559
  name=record['name'],
560
560
  group_id=record['group_id'],
561
561
  labels=record['labels'],
562
- created_at=record['created_at'].to_native(),
562
+ created_at=parse_db_date(record['created_at']), # type: ignore
563
563
  summary=record['summary'],
564
564
  attributes=record['attributes'],
565
565
  )
@@ -580,7 +580,7 @@ def get_community_node_from_record(record: Any) -> CommunityNode:
580
580
  name=record['name'],
581
581
  group_id=record['group_id'],
582
582
  name_embedding=record['name_embedding'],
583
- created_at=record['created_at'].to_native(),
583
+ created_at=parse_db_date(record['created_at']), # type: ignore
584
584
  summary=record['summary'],
585
585
  )
586
586
 
graphiti_core/prompts/dedupe_edges.py CHANGED
@@ -27,6 +27,11 @@ class EdgeDuplicate(BaseModel):
27
27
  ...,
28
28
  description='id of the duplicate fact. If no duplicate facts are found, default to -1.',
29
29
  )
30
+ contradicted_facts: list[int] = Field(
31
+ ...,
32
+ description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.',
33
+ )
34
+ fact_type: str = Field(..., description='One of the provided fact types or DEFAULT')
30
35
 
31
36
 
32
37
  class UniqueFact(BaseModel):
@@ -41,11 +46,13 @@ class UniqueFacts(BaseModel):
41
46
  class Prompt(Protocol):
42
47
  edge: PromptVersion
43
48
  edge_list: PromptVersion
49
+ resolve_edge: PromptVersion
44
50
 
45
51
 
46
52
  class Versions(TypedDict):
47
53
  edge: PromptFunction
48
54
  edge_list: PromptFunction
55
+ resolve_edge: PromptFunction
49
56
 
50
57
 
51
58
  def edge(context: dict[str, Any]) -> list[Message]:
@@ -106,4 +113,48 @@ def edge_list(context: dict[str, Any]) -> list[Message]:
106
113
  ]
107
114
 
108
115
 
109
- versions: Versions = {'edge': edge, 'edge_list': edge_list}
116
+ def resolve_edge(context: dict[str, Any]) -> list[Message]:
117
+ return [
118
+ Message(
119
+ role='system',
120
+ content='You are a helpful assistant that de-duplicates facts from fact lists and determines which existing '
121
+ 'facts are contradicted by the new fact.',
122
+ ),
123
+ Message(
124
+ role='user',
125
+ content=f"""
126
+ <NEW FACT>
127
+ {context['new_edge']}
128
+ </NEW FACT>
129
+
130
+ <EXISTING FACTS>
131
+ {context['existing_edges']}
132
+ </EXISTING FACTS>
133
+ <FACT INVALIDATION CANDIDATES>
134
+ {context['edge_invalidation_candidates']}
135
+ </FACT INVALIDATION CANDIDATES>
136
+
137
+ <FACT TYPES>
138
+ {context['edge_types']}
139
+ </FACT TYPES>
140
+
141
+
142
+ Task:
143
+ If the NEW FACT represents the same factual information as any fact in EXISTING FACTS, return the idx of the duplicate fact.
144
+ If the NEW FACT is not a duplicate of any of the EXISTING FACTS, return -1.
145
+
146
+ Given the predefined FACT TYPES, determine if the NEW FACT should be classified as one of these types.
147
+ Return the fact type as fact_type or DEFAULT if NEW FACT is not one of the FACT TYPES.
148
+
149
+ Based on the provided FACT INVALIDATION CANDIDATES and NEW FACT, determine which existing facts the new fact contradicts.
150
+ Return a list containing all idx's of the facts that are contradicted by the NEW FACT.
151
+ If there are no contradicted facts, return an empty list.
152
+
153
+ Guidelines:
154
+ 1. The facts do not need to be completely identical to be duplicates, they just need to express the same information.
155
+ """,
156
+ ),
157
+ ]
158
+
159
+
160
+ versions: Versions = {'edge': edge, 'edge_list': edge_list, 'resolve_edge': resolve_edge}
graphiti_core/prompts/dedupe_nodes.py CHANGED
@@ -23,21 +23,31 @@ from .models import Message, PromptFunction, PromptVersion
23
23
 
24
24
 
25
25
  class NodeDuplicate(BaseModel):
26
- duplicate_node_id: int = Field(
26
+ id: int = Field(..., description='integer id of the entity')
27
+ duplicate_idx: int = Field(
27
28
  ...,
28
- description='id of the duplicate node. If no duplicate nodes are found, default to -1.',
29
+ description='idx of the duplicate node. If no duplicate nodes are found, default to -1.',
29
30
  )
30
- name: str = Field(..., description='Name of the entity.')
31
+ name: str = Field(
32
+ ...,
33
+ description='Name of the entity. Should be the most complete and descriptive name possible.',
34
+ )
35
+
36
+
37
+ class NodeResolutions(BaseModel):
38
+ entity_resolutions: list[NodeDuplicate] = Field(..., description='List of resolved nodes')
31
39
 
32
40
 
33
41
  class Prompt(Protocol):
34
42
  node: PromptVersion
35
43
  node_list: PromptVersion
44
+ nodes: PromptVersion
36
45
 
37
46
 
38
47
  class Versions(TypedDict):
39
48
  node: PromptFunction
40
49
  node_list: PromptFunction
50
+ nodes: PromptFunction
41
51
 
42
52
 
43
53
  def node(context: dict[str, Any]) -> list[Message]:
@@ -89,6 +99,71 @@ def node(context: dict[str, Any]) -> list[Message]:
89
99
  ]
90
100
 
91
101
 
102
+ def nodes(context: dict[str, Any]) -> list[Message]:
103
+ return [
104
+ Message(
105
+ role='system',
106
+ content='You are a helpful assistant that determines whether or not ENTITIES extracted from a conversation are duplicates'
107
+ 'of existing entities.',
108
+ ),
109
+ Message(
110
+ role='user',
111
+ content=f"""
112
+ <PREVIOUS MESSAGES>
113
+ {json.dumps([ep for ep in context['previous_episodes']], indent=2)}
114
+ </PREVIOUS MESSAGES>
115
+ <CURRENT MESSAGE>
116
+ {context['episode_content']}
117
+ </CURRENT MESSAGE>
118
+
119
+
120
+ Each of the following ENTITIES were extracted from the CURRENT MESSAGE.
121
+ Each entity in ENTITIES is represented as a JSON object with the following structure:
122
+ {{
123
+ id: integer id of the entity,
124
+ name: "name of the entity",
125
+ entity_type: "ontological classification of the entity",
126
+ entity_type_description: "Description of what the entity type represents",
127
+ duplication_candidates: [
128
+ {{
129
+ idx: integer index of the candidate entity,
130
+ name: "name of the candidate entity",
131
+ entity_type: "ontological classification of the candidate entity",
132
+ ...<additional attributes>
133
+ }}
134
+ ]
135
+ }}
136
+
137
+ <ENTITIES>
138
+ {json.dumps(context['extracted_nodes'], indent=2)}
139
+ </ENTITIES>
140
+
141
+ <EXISTING ENTITIES>
142
+ {json.dumps(context['existing_nodes'], indent=2)}
143
+ </EXISTING ENTITIES>
144
+
145
+ For each of the above ENTITIES, determine if the entity is a duplicate of any of the EXISTING ENTITIES.
146
+
147
+ Entities should only be considered duplicates if they refer to the *same real-world object or concept*.
148
+
149
+ Do NOT mark entities as duplicates if:
150
+ - They are related but distinct.
151
+ - They have similar names or purposes but refer to separate instances or concepts.
152
+
153
+ Task:
154
+ Your response will be a list called entity_resolutions which contains one entry for each entity.
155
+
156
+ For each entity, return the id of the entity as id, the name of the entity as name, and the duplicate_idx
157
+ as an integer.
158
+
159
+ - If an entity is a duplicate of one of the EXISTING ENTITIES, return the idx of the candidate it is a
160
+ duplicate of.
161
+ - If an entity is not a duplicate of one of the EXISTING ENTITIES, return the -1 as the duplication_idx
162
+ """,
163
+ ),
164
+ ]
165
+
166
+
92
167
  def node_list(context: dict[str, Any]) -> list[Message]:
93
168
  return [
94
169
  Message(
@@ -126,4 +201,4 @@ def node_list(context: dict[str, Any]) -> list[Message]:
126
201
  ]
127
202
 
128
203
 
129
- versions: Versions = {'node': node, 'node_list': node_list}
204
+ versions: Versions = {'node': node, 'node_list': node_list, 'nodes': nodes}
graphiti_core/prompts/extract_edges.py CHANGED
@@ -24,8 +24,8 @@ from .models import Message, PromptFunction, PromptVersion
24
24
 
25
25
  class Edge(BaseModel):
26
26
  relation_type: str = Field(..., description='FACT_PREDICATE_IN_SCREAMING_SNAKE_CASE')
27
- source_entity_name: str = Field(..., description='The name of the source entity of the fact.')
28
- target_entity_name: str = Field(..., description='The name of the target entity of the fact.')
27
+ source_entity_id: int = Field(..., description='The id of the source entity of the fact.')
28
+ target_entity_id: int = Field(..., description='The id of the target entity of the fact.')
29
29
  fact: str = Field(..., description='')
30
30
  valid_at: str | None = Field(
31
31
  None,
@@ -48,11 +48,13 @@ class MissingFacts(BaseModel):
48
48
  class Prompt(Protocol):
49
49
  edge: PromptVersion
50
50
  reflexion: PromptVersion
51
+ extract_attributes: PromptVersion
51
52
 
52
53
 
53
54
  class Versions(TypedDict):
54
55
  edge: PromptFunction
55
56
  reflexion: PromptFunction
57
+ extract_attributes: PromptFunction
56
58
 
57
59
 
58
60
  def edge(context: dict[str, Any]) -> list[Message]:
@@ -75,19 +77,26 @@ def edge(context: dict[str, Any]) -> list[Message]:
75
77
  </CURRENT_MESSAGE>
76
78
 
77
79
  <ENTITIES>
78
- {context['nodes']} # Each has: id, label (e.g., Person, Org), name, aliases
80
+ {context['nodes']}
79
81
  </ENTITIES>
80
82
 
81
83
  <REFERENCE_TIME>
82
84
  {context['reference_time']} # ISO 8601 (UTC); used to resolve relative time mentions
83
85
  </REFERENCE_TIME>
84
86
 
87
+ <FACT TYPES>
88
+ {context['edge_types']}
89
+ </FACT TYPES>
90
+
85
91
  # TASK
86
92
  Extract all factual relationships between the given ENTITIES based on the CURRENT MESSAGE.
87
93
  Only extract facts that:
88
94
  - involve two DISTINCT ENTITIES from the ENTITIES list,
89
95
  - are clearly stated or unambiguously implied in the CURRENT MESSAGE,
90
- - and can be represented as edges in a knowledge graph.
96
+ and can be represented as edges in a knowledge graph.
97
+ - The FACT TYPES provide a list of the most important types of facts, make sure to extract facts of these types
98
+ - The FACT TYPES are not an exhaustive list, extract all facts from the message even if they do not fit into one
99
+ of the FACT TYPES
91
100
 
92
101
  You may use information from the PREVIOUS MESSAGES only to disambiguate references or support continuity.
93
102
 
@@ -145,4 +154,40 @@ determine if any facts haven't been extracted.
145
154
  ]
146
155
 
147
156
 
148
- versions: Versions = {'edge': edge, 'reflexion': reflexion}
157
+ def extract_attributes(context: dict[str, Any]) -> list[Message]:
158
+ return [
159
+ Message(
160
+ role='system',
161
+ content='You are a helpful assistant that extracts fact properties from the provided text.',
162
+ ),
163
+ Message(
164
+ role='user',
165
+ content=f"""
166
+
167
+ <MESSAGE>
168
+ {json.dumps(context['episode_content'], indent=2)}
169
+ </MESSAGE>
170
+ <REFERENCE TIME>
171
+ {context['reference_time']}
172
+ </REFERENCE TIME>
173
+
174
+ Given the above MESSAGE, its REFERENCE TIME, and the following FACT, update any of its attributes based on the information provided
175
+ in MESSAGE. Use the provided attribute descriptions to better understand how each attribute should be determined.
176
+
177
+ Guidelines:
178
+ 1. Do not hallucinate entity property values if they cannot be found in the current context.
179
+ 2. Only use the provided MESSAGES and FACT to set attribute values.
180
+
181
+ <FACT>
182
+ {context['fact']}
183
+ </FACT>
184
+ """,
185
+ ),
186
+ ]
187
+
188
+
189
+ versions: Versions = {
190
+ 'edge': edge,
191
+ 'reflexion': reflexion,
192
+ 'extract_attributes': extract_attributes,
193
+ }
graphiti_core/prompts/invalidate_edges.py CHANGED
@@ -24,7 +24,7 @@ from .models import Message, PromptFunction, PromptVersion
24
24
  class InvalidatedEdges(BaseModel):
25
25
  contradicted_facts: list[int] = Field(
26
26
  ...,
27
- description='List of ids of facts that be should invalidated. If no facts should be invalidated, the list should be empty.',
27
+ description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.',
28
28
  )
29
29
 
30
30
 
graphiti_core/search/search.py CHANGED
@@ -18,9 +18,8 @@ import logging
18
18
  from collections import defaultdict
19
19
  from time import time
20
20
 
21
- from neo4j import AsyncDriver
22
-
23
21
  from graphiti_core.cross_encoder.client import CrossEncoderClient
22
+ from graphiti_core.driver.driver import GraphDriver
24
23
  from graphiti_core.edges import EntityEdge
25
24
  from graphiti_core.errors import SearchRerankerError
26
25
  from graphiti_core.graphiti_types import GraphitiClients
@@ -94,7 +93,7 @@ async def search(
94
93
  )
95
94
 
96
95
  # if group_ids is empty, set it to None
97
- group_ids = group_ids if group_ids else None
96
+ group_ids = group_ids if group_ids and group_ids != [''] else None
98
97
  edges, nodes, episodes, communities = await semaphore_gather(
99
98
  edge_search(
100
99
  driver,
@@ -160,7 +159,7 @@ async def search(
160
159
 
161
160
 
162
161
  async def edge_search(
163
- driver: AsyncDriver,
162
+ driver: GraphDriver,
164
163
  cross_encoder: CrossEncoderClient,
165
164
  query: str,
166
165
  query_vector: list[float],
@@ -174,7 +173,6 @@ async def edge_search(
174
173
  ) -> list[EntityEdge]:
175
174
  if config is None:
176
175
  return []
177
-
178
176
  search_results: list[list[EntityEdge]] = list(
179
177
  await semaphore_gather(
180
178
  *[
@@ -261,7 +259,7 @@ async def edge_search(
261
259
 
262
260
 
263
261
  async def node_search(
264
- driver: AsyncDriver,
262
+ driver: GraphDriver,
265
263
  cross_encoder: CrossEncoderClient,
266
264
  query: str,
267
265
  query_vector: list[float],
@@ -275,7 +273,6 @@ async def node_search(
275
273
  ) -> list[EntityNode]:
276
274
  if config is None:
277
275
  return []
278
-
279
276
  search_results: list[list[EntityNode]] = list(
280
277
  await semaphore_gather(
281
278
  *[
@@ -344,7 +341,7 @@ async def node_search(
344
341
 
345
342
 
346
343
  async def episode_search(
347
- driver: AsyncDriver,
344
+ driver: GraphDriver,
348
345
  cross_encoder: CrossEncoderClient,
349
346
  query: str,
350
347
  _query_vector: list[float],
@@ -356,7 +353,6 @@ async def episode_search(
356
353
  ) -> list[EpisodicNode]:
357
354
  if config is None:
358
355
  return []
359
-
360
356
  search_results: list[list[EpisodicNode]] = list(
361
357
  await semaphore_gather(
362
358
  *[
@@ -392,7 +388,7 @@ async def episode_search(
392
388
 
393
389
 
394
390
  async def community_search(
395
- driver: AsyncDriver,
391
+ driver: GraphDriver,
396
392
  cross_encoder: CrossEncoderClient,
397
393
  query: str,
398
394
  query_vector: list[float],