graphiti-core 0.12.0rc5__py3-none-any.whl → 0.12.2__py3-none-any.whl


@@ -0,0 +1,149 @@
+ """
+ Database query utilities for different graph database backends.
+
+ This module provides database-agnostic query generation for Neo4j and FalkorDB,
+ supporting index creation, fulltext search, and bulk operations.
+ """
+
+ from typing import Any
+
+ from typing_extensions import LiteralString
+
+ from graphiti_core.models.edges.edge_db_queries import (
+     ENTITY_EDGE_SAVE_BULK,
+ )
+ from graphiti_core.models.nodes.node_db_queries import (
+     ENTITY_NODE_SAVE_BULK,
+ )
+
+ # Mapping from Neo4j fulltext index names to FalkorDB node labels
+ NEO4J_TO_FALKORDB_MAPPING = {
+     'node_name_and_summary': 'Entity',
+     'community_name': 'Community',
+     'episode_content': 'Episodic',
+     'edge_name_and_fact': 'RELATES_TO',
+ }
+
+
+ def get_range_indices(db_type: str = 'neo4j') -> list[LiteralString]:
+     if db_type == 'falkordb':
+         return [
+             # Entity node
+             'CREATE INDEX FOR (n:Entity) ON (n.uuid, n.group_id, n.name, n.created_at)',
+             # Episodic node
+             'CREATE INDEX FOR (n:Episodic) ON (n.uuid, n.group_id, n.created_at, n.valid_at)',
+             # Community node
+             'CREATE INDEX FOR (n:Community) ON (n.uuid)',
+             # RELATES_TO edge
+             'CREATE INDEX FOR ()-[e:RELATES_TO]-() ON (e.uuid, e.group_id, e.name, e.created_at, e.expired_at, e.valid_at, e.invalid_at)',
+             # MENTIONS edge
+             'CREATE INDEX FOR ()-[e:MENTIONS]-() ON (e.uuid, e.group_id)',
+             # HAS_MEMBER edge
+             'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
+         ]
+     else:
+         return [
+             'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
+             'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
+             'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
+             'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
+             'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
+             'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
+             'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
+             'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
+             'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
+             'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
+             'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
+             'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
+             'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
+             'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
+             'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
+             'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
+             'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
+             'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
+             'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
+         ]
+
+
+ def get_fulltext_indices(db_type: str = 'neo4j') -> list[LiteralString]:
+     if db_type == 'falkordb':
+         return [
+             """CREATE FULLTEXT INDEX FOR (e:Episodic) ON (e.content, e.source, e.source_description, e.group_id)""",
+             """CREATE FULLTEXT INDEX FOR (n:Entity) ON (n.name, n.summary, n.group_id)""",
+             """CREATE FULLTEXT INDEX FOR (n:Community) ON (n.name, n.group_id)""",
+             """CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
+         ]
+     else:
+         return [
+             """CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
+             FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
+             """CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
+             FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
+             """CREATE FULLTEXT INDEX community_name IF NOT EXISTS
+             FOR (n:Community) ON EACH [n.name, n.group_id]""",
+             """CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
+             FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
+         ]
+
+
+ def get_nodes_query(db_type: str = 'neo4j', name: str = '', query: str | None = None) -> str:
+     if db_type == 'falkordb':
+         label = NEO4J_TO_FALKORDB_MAPPING[name]
+         return f"CALL db.idx.fulltext.queryNodes('{label}', {query})"
+     else:
+         return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
+
+
+ def get_vector_cosine_func_query(vec1, vec2, db_type: str = 'neo4j') -> str:
+     if db_type == 'falkordb':
+         # FalkorDB uses a different syntax for regular cosine similarity and Neo4j uses normalized cosine similarity
+         return f'(2 - vec.cosineDistance({vec1}, vecf32({vec2})))/2'
+     else:
+         return f'vector.similarity.cosine({vec1}, {vec2})'
+
+
+ def get_relationships_query(name: str, db_type: str = 'neo4j') -> str:
+     if db_type == 'falkordb':
+         label = NEO4J_TO_FALKORDB_MAPPING[name]
+         return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)"
+     else:
+         return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
+
+
+ def get_entity_node_save_bulk_query(nodes, db_type: str = 'neo4j') -> str | Any:
+     if db_type == 'falkordb':
+         queries = []
+         for node in nodes:
+             for label in node['labels']:
+                 queries.append(
+                     (
+                         f"""
+                         UNWIND $nodes AS node
+                         MERGE (n:Entity {{uuid: node.uuid}})
+                         SET n:{label}
+                         SET n = node
+                         WITH n, node
+                         SET n.name_embedding = vecf32(node.name_embedding)
+                         RETURN n.uuid AS uuid
+                         """,
+                         {'nodes': [node]},
+                     )
+                 )
+         return queries
+     else:
+         return ENTITY_NODE_SAVE_BULK
+
+
+ def get_entity_edge_save_bulk_query(db_type: str = 'neo4j') -> str:
+     if db_type == 'falkordb':
+         return """
+         UNWIND $entity_edges AS edge
+         MATCH (source:Entity {uuid: edge.source_node_uuid})
+         MATCH (target:Entity {uuid: edge.target_node_uuid})
+         MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
+         SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
+         created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at, fact_embedding: vecf32(edge.fact_embedding)}
+         WITH r, edge
+         RETURN edge.uuid AS uuid"""
+     else:
+         return ENTITY_EDGE_SAVE_BULK
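
To make the backend switch above concrete, here is a brief usage sketch (not part of the diff). The import path graphiti_core.graph_queries is an assumption based on the package layout; it is not shown in this hunk.

# Usage sketch: the same call site can serve either backend by switching db_type.
from graphiti_core.graph_queries import get_fulltext_indices, get_nodes_query  # assumed module path

neo4j_index_queries = get_fulltext_indices('neo4j')        # named indices with IF NOT EXISTS
falkordb_index_queries = get_fulltext_indices('falkordb')  # label-based indices

# Fulltext node search is dispatched through the name-to-label mapping:
print(get_nodes_query('neo4j', name='node_name_and_summary', query='$query'))
# -> CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
print(get_nodes_query('falkordb', name='node_name_and_summary', query='$query'))
# -> CALL db.idx.fulltext.queryNodes('Entity', $query)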
graphiti_core/graphiti.py CHANGED
@@ -19,12 +19,13 @@ from datetime import datetime
  from time import time
 
  from dotenv import load_dotenv
- from neo4j import AsyncGraphDatabase
  from pydantic import BaseModel
  from typing_extensions import LiteralString
 
  from graphiti_core.cross_encoder.client import CrossEncoderClient
  from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
+ from graphiti_core.driver.driver import GraphDriver
+ from graphiti_core.driver.neo4j_driver import Neo4jDriver
  from graphiti_core.edges import EntityEdge, EpisodicEdge
  from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
  from graphiti_core.graphiti_types import GraphitiClients
@@ -62,6 +63,7 @@ from graphiti_core.utils.maintenance.community_operations import (
      update_community,
  )
  from graphiti_core.utils.maintenance.edge_operations import (
+     build_duplicate_of_edges,
      build_episodic_edges,
      extract_edges,
      resolve_extracted_edge,
@@ -94,12 +96,13 @@ class Graphiti:
      def __init__(
          self,
          uri: str,
-         user: str,
-         password: str,
+         user: str | None = None,
+         password: str | None = None,
          llm_client: LLMClient | None = None,
          embedder: EmbedderClient | None = None,
          cross_encoder: CrossEncoderClient | None = None,
          store_raw_episode_content: bool = True,
+         graph_driver: GraphDriver | None = None,
      ):
          """
          Initialize a Graphiti instance.
@@ -137,7 +140,9 @@ class Graphiti:
          Make sure to set the OPENAI_API_KEY environment variable before initializing
          Graphiti if you're using the default OpenAIClient.
          """
-         self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
+
+         self.driver = graph_driver if graph_driver else Neo4jDriver(uri, user, password)
+
          self.database = DEFAULT_DATABASE
          self.store_raw_episode_content = store_raw_episode_content
          if llm_client:
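
The constructor change above means callers can either keep passing Neo4j credentials or inject a driver themselves. A minimal sketch based only on the signature shown in this hunk; the connection values are placeholders.

from graphiti_core import Graphiti
from graphiti_core.driver.neo4j_driver import Neo4jDriver

# Existing behaviour: credentials are forwarded to a Neo4jDriver internally.
graphiti = Graphiti('bolt://localhost:7687', 'neo4j', 'password')

# New behaviour: supply any GraphDriver implementation directly. The uri argument
# is still positional, but it is ignored when graph_driver is provided.
driver = Neo4jDriver('bolt://localhost:7687', 'neo4j', 'password')
graphiti = Graphiti('bolt://localhost:7687', graph_driver=driver)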
@@ -371,7 +376,7 @@
          )
 
          # Extract edges and resolve nodes
-         (nodes, uuid_map), extracted_edges = await semaphore_gather(
+         (nodes, uuid_map, node_duplicates), extracted_edges = await semaphore_gather(
              resolve_extracted_nodes(
                  self.clients,
                  extracted_nodes,
@@ -380,7 +385,13 @@
                  entity_types,
              ),
              extract_edges(
-                 self.clients, episode, extracted_nodes, previous_episodes, group_id, edge_types
+                 self.clients,
+                 episode,
+                 extracted_nodes,
+                 previous_episodes,
+                 edge_type_map or edge_type_map_default,
+                 group_id,
+                 edge_types,
              ),
          )
 
@@ -400,7 +411,9 @@
              ),
          )
 
-         entity_edges = resolved_edges + invalidated_edges
+         duplicate_of_edges = build_duplicate_of_edges(episode, now, node_duplicates)
+
+         entity_edges = resolved_edges + invalidated_edges + duplicate_of_edges
 
          episodic_edges = build_episodic_edges(nodes, episode, now)
 
@@ -687,7 +700,7 @@
          if edge.fact_embedding is None:
              await edge.generate_embedding(self.embedder)
 
-         resolved_nodes, uuid_map = await resolve_extracted_nodes(
+         resolved_nodes, uuid_map, _ = await resolve_extracted_nodes(
              self.clients,
              [source_node, target_node],
          )
@@ -14,16 +14,16 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """
 
- from neo4j import AsyncDriver
  from pydantic import BaseModel, ConfigDict
 
  from graphiti_core.cross_encoder import CrossEncoderClient
+ from graphiti_core.driver.driver import GraphDriver
  from graphiti_core.embedder import EmbedderClient
  from graphiti_core.llm_client import LLMClient
 
 
  class GraphitiClients(BaseModel):
-     driver: AsyncDriver
+     driver: GraphDriver
      llm_client: LLMClient
      embedder: EmbedderClient
      cross_encoder: CrossEncoderClient
graphiti_core/helpers.py CHANGED
@@ -27,7 +27,7 @@ from typing_extensions import LiteralString
 
  load_dotenv()
 
- DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
+ DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', 'neo4j')
  USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
  SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
  MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0))
@@ -38,8 +38,14 @@ RUNTIME_QUERY: LiteralString = (
  )
 
 
- def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
-     return neo_date.to_native() if neo_date else None
+ def parse_db_date(neo_date: neo4j_time.DateTime | str | None) -> datetime | None:
+     return (
+         neo_date.to_native()
+         if isinstance(neo_date, neo4j_time.DateTime)
+         else datetime.fromisoformat(neo_date)
+         if neo_date
+         else None
+     )
 
 
  def lucene_sanitize(query: str) -> str:
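
A short sketch of the widened parse_db_date behaviour: Neo4j DateTime values are still converted via to_native(), plain ISO-8601 strings (for example, timestamps from backends that do not use the Neo4j temporal types) are parsed with datetime.fromisoformat, and None passes through unchanged.

from datetime import datetime
from neo4j import time as neo4j_time

from graphiti_core.helpers import parse_db_date

assert parse_db_date(None) is None
assert parse_db_date('2024-01-02T03:04:05') == datetime(2024, 1, 2, 3, 4, 5)
assert parse_db_date(neo4j_time.DateTime(2024, 1, 2, 3, 4, 5)) == datetime(2024, 1, 2, 3, 4, 5)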
@@ -1,3 +1,19 @@
+ """
+ Copyright 2024, Zep Software, Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
  from .client import LLMClient
  from .config import LLMConfig
  from .errors import RateLimitError
@@ -0,0 +1,73 @@
+ """
+ Copyright 2024, Zep Software, Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ import json
+ import logging
+ from typing import Any
+
+ from openai import AsyncAzureOpenAI
+ from openai.types.chat import ChatCompletionMessageParam
+ from pydantic import BaseModel
+
+ from ..prompts.models import Message
+ from .client import LLMClient
+ from .config import LLMConfig, ModelSize
+
+ logger = logging.getLogger(__name__)
+
+
+ class AzureOpenAILLMClient(LLMClient):
+     """Wrapper class for AsyncAzureOpenAI that implements the LLMClient interface."""
+
+     def __init__(self, azure_client: AsyncAzureOpenAI, config: LLMConfig | None = None):
+         super().__init__(config, cache=False)
+         self.azure_client = azure_client
+
+     async def _generate_response(
+         self,
+         messages: list[Message],
+         response_model: type[BaseModel] | None = None,
+         max_tokens: int = 1024,
+         model_size: ModelSize = ModelSize.medium,
+     ) -> dict[str, Any]:
+         """Generate response using Azure OpenAI client."""
+         # Convert messages to OpenAI format
+         openai_messages: list[ChatCompletionMessageParam] = []
+         for message in messages:
+             message.content = self._clean_input(message.content)
+             if message.role == 'user':
+                 openai_messages.append({'role': 'user', 'content': message.content})
+             elif message.role == 'system':
+                 openai_messages.append({'role': 'system', 'content': message.content})
+
+         # Ensure model is a string
+         model_name = self.model if self.model else 'gpt-4o-mini'
+
+         try:
+             response = await self.azure_client.chat.completions.create(
+                 model=model_name,
+                 messages=openai_messages,
+                 temperature=float(self.temperature) if self.temperature is not None else 0.7,
+                 max_tokens=max_tokens,
+                 response_format={'type': 'json_object'},
+             )
+             result = response.choices[0].message.content or '{}'
+
+             # Parse JSON response
+             return json.loads(result)
+         except Exception as e:
+             logger.error(f'Error in Azure OpenAI LLM response: {e}')
+             raise
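
For context, a hedged construction sketch for the new wrapper (not part of the diff). The endpoint, API version, and deployment name below are placeholders, and the module path is an assumption rather than something shown in this hunk.

from openai import AsyncAzureOpenAI

from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient  # assumed module path
from graphiti_core.llm_client.config import LLMConfig

# Placeholder Azure OpenAI settings; substitute your own resource values.
azure_client = AsyncAzureOpenAI(
    api_key='<azure-api-key>',
    api_version='2024-06-01',
    azure_endpoint='https://<your-resource>.openai.azure.com',
)

# The wrapper disables response caching and falls back to 'gpt-4o-mini' when no
# model is configured, as seen in _generate_response above.
llm_client = AzureOpenAILLMClient(azure_client, config=LLMConfig(model='my-gpt-4o-deployment'))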
graphiti_core/nodes.py CHANGED
@@ -22,13 +22,13 @@ from time import time
  from typing import Any
  from uuid import uuid4
 
- from neo4j import AsyncDriver
  from pydantic import BaseModel, Field
  from typing_extensions import LiteralString
 
+ from graphiti_core.driver.driver import GraphDriver
  from graphiti_core.embedder import EmbedderClient
  from graphiti_core.errors import NodeNotFoundError
- from graphiti_core.helpers import DEFAULT_DATABASE
+ from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
  from graphiti_core.models.nodes.node_db_queries import (
      COMMUNITY_NODE_SAVE,
      ENTITY_NODE_SAVE,
@@ -94,9 +94,9 @@ class Node(BaseModel, ABC):
      created_at: datetime = Field(default_factory=lambda: utc_now())
 
      @abstractmethod
-     async def save(self, driver: AsyncDriver): ...
+     async def save(self, driver: GraphDriver): ...
 
-     async def delete(self, driver: AsyncDriver):
+     async def delete(self, driver: GraphDriver):
          result = await driver.execute_query(
              """
              MATCH (n:Entity|Episodic|Community {uuid: $uuid})
@@ -119,7 +119,7 @@ class Node(BaseModel, ABC):
          return False
 
      @classmethod
-     async def delete_by_group_id(cls, driver: AsyncDriver, group_id: str):
+     async def delete_by_group_id(cls, driver: GraphDriver, group_id: str):
          await driver.execute_query(
              """
              MATCH (n:Entity|Episodic|Community {group_id: $group_id})
@@ -132,10 +132,10 @@ class Node(BaseModel, ABC):
          return 'SUCCESS'
 
      @classmethod
-     async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ...
+     async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
 
      @classmethod
-     async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]): ...
+     async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): ...
 
 
  class EpisodicNode(Node):
@@ -150,7 +150,7 @@ class EpisodicNode(Node):
          default_factory=list,
      )
 
-     async def save(self, driver: AsyncDriver):
+     async def save(self, driver: GraphDriver):
          result = await driver.execute_query(
              EPISODIC_NODE_SAVE,
              uuid=self.uuid,
@@ -165,12 +165,12 @@ class EpisodicNode(Node):
              database_=DEFAULT_DATABASE,
          )
 
-         logger.debug(f'Saved Node to neo4j: {self.uuid}')
+         logger.debug(f'Saved Node to Graph: {self.uuid}')
 
          return result
 
      @classmethod
-     async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
+     async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
          records, _, _ = await driver.execute_query(
              """
              MATCH (e:Episodic {uuid: $uuid})
@@ -197,7 +197,7 @@ class EpisodicNode(Node):
          return episodes[0]
 
      @classmethod
-     async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
+     async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
          records, _, _ = await driver.execute_query(
              """
              MATCH (e:Episodic) WHERE e.uuid IN $uuids
@@ -224,7 +224,7 @@ class EpisodicNode(Node):
      @classmethod
      async def get_by_group_ids(
          cls,
-         driver: AsyncDriver,
+         driver: GraphDriver,
          group_ids: list[str],
          limit: int | None = None,
          uuid_cursor: str | None = None,
@@ -263,7 +263,7 @@ class EpisodicNode(Node):
          return episodes
 
      @classmethod
-     async def get_by_entity_node_uuid(cls, driver: AsyncDriver, entity_node_uuid: str):
+     async def get_by_entity_node_uuid(cls, driver: GraphDriver, entity_node_uuid: str):
          records, _, _ = await driver.execute_query(
              """
              MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
@@ -304,7 +304,7 @@ class EntityNode(Node):
 
          return self.name_embedding
 
-     async def load_name_embedding(self, driver: AsyncDriver):
+     async def load_name_embedding(self, driver: GraphDriver):
          query: LiteralString = """
              MATCH (n:Entity {uuid: $uuid})
              RETURN n.name_embedding AS name_embedding
@@ -318,7 +318,7 @@ class EntityNode(Node):
 
          self.name_embedding = records[0]['name_embedding']
 
-     async def save(self, driver: AsyncDriver):
+     async def save(self, driver: GraphDriver):
          entity_data: dict[str, Any] = {
              'uuid': self.uuid,
              'name': self.name,
@@ -337,16 +337,16 @@ class EntityNode(Node):
              database_=DEFAULT_DATABASE,
          )
 
-         logger.debug(f'Saved Node to neo4j: {self.uuid}')
+         logger.debug(f'Saved Node to Graph: {self.uuid}')
 
          return result
 
      @classmethod
-     async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
+     async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
          query = (
              """
-             MATCH (n:Entity {uuid: $uuid})
-             """
+         MATCH (n:Entity {uuid: $uuid})
+         """
              + ENTITY_NODE_RETURN
          )
          records, _, _ = await driver.execute_query(
@@ -364,7 +364,7 @@ class EntityNode(Node):
          return nodes[0]
 
      @classmethod
-     async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
+     async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
          records, _, _ = await driver.execute_query(
              """
              MATCH (n:Entity) WHERE n.uuid IN $uuids
@@ -382,7 +382,7 @@ class EntityNode(Node):
      @classmethod
      async def get_by_group_ids(
          cls,
-         driver: AsyncDriver,
+         driver: GraphDriver,
          group_ids: list[str],
          limit: int | None = None,
          uuid_cursor: str | None = None,
@@ -416,7 +416,7 @@ class CommunityNode(Node):
      name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
      summary: str = Field(description='region summary of member nodes', default_factory=str)
 
-     async def save(self, driver: AsyncDriver):
+     async def save(self, driver: GraphDriver):
          result = await driver.execute_query(
              COMMUNITY_NODE_SAVE,
              uuid=self.uuid,
@@ -428,7 +428,7 @@ class CommunityNode(Node):
              database_=DEFAULT_DATABASE,
          )
 
-         logger.debug(f'Saved Node to neo4j: {self.uuid}')
+         logger.debug(f'Saved Node to Graph: {self.uuid}')
 
          return result
 
@@ -441,7 +441,7 @@ class CommunityNode(Node):
 
          return self.name_embedding
 
-     async def load_name_embedding(self, driver: AsyncDriver):
+     async def load_name_embedding(self, driver: GraphDriver):
          query: LiteralString = """
              MATCH (c:Community {uuid: $uuid})
              RETURN c.name_embedding AS name_embedding
@@ -456,7 +456,7 @@ class CommunityNode(Node):
          self.name_embedding = records[0]['name_embedding']
 
      @classmethod
-     async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
+     async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
          records, _, _ = await driver.execute_query(
              """
              MATCH (n:Community {uuid: $uuid})
@@ -480,7 +480,7 @@ class CommunityNode(Node):
          return nodes[0]
 
      @classmethod
-     async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
+     async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
          records, _, _ = await driver.execute_query(
              """
              MATCH (n:Community) WHERE n.uuid IN $uuids
@@ -503,7 +503,7 @@ class CommunityNode(Node):
      @classmethod
      async def get_by_group_ids(
          cls,
-         driver: AsyncDriver,
+         driver: GraphDriver,
          group_ids: list[str],
          limit: int | None = None,
          uuid_cursor: str | None = None,
@@ -542,8 +542,8 @@ class CommunityNode(Node):
  def get_episodic_node_from_record(record: Any) -> EpisodicNode:
      return EpisodicNode(
          content=record['content'],
-         created_at=record['created_at'].to_native().timestamp(),
-         valid_at=(record['valid_at'].to_native()),
+         created_at=parse_db_date(record['created_at']),  # type: ignore
+         valid_at=parse_db_date(record['valid_at']),  # type: ignore
          uuid=record['uuid'],
          group_id=record['group_id'],
          source=EpisodeType.from_str(record['source']),
@@ -559,7 +559,7 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
          name=record['name'],
          group_id=record['group_id'],
          labels=record['labels'],
-         created_at=record['created_at'].to_native(),
+         created_at=parse_db_date(record['created_at']),  # type: ignore
          summary=record['summary'],
          attributes=record['attributes'],
      )
@@ -580,7 +580,7 @@ def get_community_node_from_record(record: Any) -> CommunityNode:
          name=record['name'],
          group_id=record['group_id'],
          name_embedding=record['name_embedding'],
-         created_at=parse_db_date(record['created_at']),  # type: ignore
          summary=record['summary'],
      )
 
@@ -26,12 +26,16 @@ class NodeDuplicate(BaseModel):
      id: int = Field(..., description='integer id of the entity')
      duplicate_idx: int = Field(
          ...,
-         description='idx of the duplicate node. If no duplicate nodes are found, default to -1.',
+         description='idx of the duplicate entity. If no duplicate entities are found, default to -1.',
      )
      name: str = Field(
          ...,
          description='Name of the entity. Should be the most complete and descriptive name possible.',
      )
+     additional_duplicates: list[int] = Field(
+         ...,
+         description='idx of additional duplicate entities. Use this list if the entity has multiple duplicates among existing entities.',
+     )
 
 
  class NodeResolutions(BaseModel):
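
To illustrate the new field, a minimal sketch of the shape the deduplication model now allows (the ids are illustrative and the import path for NodeDuplicate is an assumption, not shown in this hunk):

from graphiti_core.prompts.dedupe_nodes import NodeDuplicate  # assumed module path

resolution = NodeDuplicate(
    id=3,                           # id of the extracted entity being resolved
    duplicate_idx=7,                # primary duplicate among existing entities (-1 if none)
    name='Acme Corporation',
    additional_duplicates=[9, 12],  # further duplicates, when more than one match exists
)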
@@ -97,6 +97,8 @@ Only extract facts that:
  - The FACT TYPES provide a list of the most important types of facts, make sure to extract facts of these types
  - The FACT TYPES are not an exhaustive list, extract all facts from the message even if they do not fit into one
  of the FACT TYPES
+ - The FACT TYPES each contain their fact_type_signature which represents the entity types which that fact_type is defined for.
+ A Type of Entity in the signature represents any extracted entity (it is a generic universal type for all entities).
 
  You may use information from the PREVIOUS MESSAGES only to disambiguate references or support continuity.
 
@@ -90,6 +90,8 @@ def extract_message(context: dict[str, Any]) -> list[Message]:
  Instructions:
 
  You are given a conversation context and a CURRENT MESSAGE. Your task is to extract **entity nodes** mentioned **explicitly or implicitly** in the CURRENT MESSAGE.
+ Pronoun references such as he/she/they or this/that/those should be disambiguated to the names of the
+ reference entities.
 
  1. **Speaker Extraction**: Always extract the speaker (the part before the colon `:` in each dialogue line) as the first entity node.
      - If the speaker is mentioned again in the message, treat both mentions as a **single entity**.