graphiti-core 0.12.0rc5__py3-none-any.whl → 0.12.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic; see the package registry's advisory for this release for more details.

@@ -0,0 +1,149 @@
1
+ """
2
+ Database query utilities for different graph database backends.
3
+
4
+ This module provides database-agnostic query generation for Neo4j and FalkorDB,
5
+ supporting index creation, fulltext search, and bulk operations.
6
+ """
7
+
8
+ from typing import Any
9
+
10
+ from typing_extensions import LiteralString
11
+
12
+ from graphiti_core.models.edges.edge_db_queries import (
13
+ ENTITY_EDGE_SAVE_BULK,
14
+ )
15
+ from graphiti_core.models.nodes.node_db_queries import (
16
+ ENTITY_NODE_SAVE_BULK,
17
+ )
18
+
19
# Neo4j addresses a fulltext index by the index *name*, whereas FalkorDB addresses it
# by the node label / relationship type the index was created on. This table
# translates each Neo4j fulltext index name into the FalkorDB label to query instead.
NEO4J_TO_FALKORDB_MAPPING: dict[str, str] = dict(
    node_name_and_summary='Entity',
    community_name='Community',
    episode_content='Episodic',
    edge_name_and_fact='RELATES_TO',  # a relationship type, not a node label
)
26
+
27
+
28
+ def get_range_indices(db_type: str = 'neo4j') -> list[LiteralString]:
29
+ if db_type == 'falkordb':
30
+ return [
31
+ # Entity node
32
+ 'CREATE INDEX FOR (n:Entity) ON (n.uuid, n.group_id, n.name, n.created_at)',
33
+ # Episodic node
34
+ 'CREATE INDEX FOR (n:Episodic) ON (n.uuid, n.group_id, n.created_at, n.valid_at)',
35
+ # Community node
36
+ 'CREATE INDEX FOR (n:Community) ON (n.uuid)',
37
+ # RELATES_TO edge
38
+ 'CREATE INDEX FOR ()-[e:RELATES_TO]-() ON (e.uuid, e.group_id, e.name, e.created_at, e.expired_at, e.valid_at, e.invalid_at)',
39
+ # MENTIONS edge
40
+ 'CREATE INDEX FOR ()-[e:MENTIONS]-() ON (e.uuid, e.group_id)',
41
+ # HAS_MEMBER edge
42
+ 'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
43
+ ]
44
+ else:
45
+ return [
46
+ 'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
47
+ 'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
48
+ 'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
49
+ 'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
50
+ 'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
51
+ 'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
52
+ 'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
53
+ 'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
54
+ 'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
55
+ 'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
56
+ 'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
57
+ 'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
58
+ 'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
59
+ 'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
60
+ 'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
61
+ 'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
62
+ 'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
63
+ 'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
64
+ 'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
65
+ ]
66
+
67
+
68
+ def get_fulltext_indices(db_type: str = 'neo4j') -> list[LiteralString]:
69
+ if db_type == 'falkordb':
70
+ return [
71
+ """CREATE FULLTEXT INDEX FOR (e:Episodic) ON (e.content, e.source, e.source_description, e.group_id)""",
72
+ """CREATE FULLTEXT INDEX FOR (n:Entity) ON (n.name, n.summary, n.group_id)""",
73
+ """CREATE FULLTEXT INDEX FOR (n:Community) ON (n.name, n.group_id)""",
74
+ """CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
75
+ ]
76
+ else:
77
+ return [
78
+ """CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
79
+ FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
80
+ """CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
81
+ FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
82
+ """CREATE FULLTEXT INDEX community_name IF NOT EXISTS
83
+ FOR (n:Community) ON EACH [n.name, n.group_id]""",
84
+ """CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
85
+ FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
86
+ ]
87
+
88
+
89
+ def get_nodes_query(db_type: str = 'neo4j', name: str = '', query: str | None = None) -> str:
90
+ if db_type == 'falkordb':
91
+ label = NEO4J_TO_FALKORDB_MAPPING[name]
92
+ return f"CALL db.idx.fulltext.queryNodes('{label}', {query})"
93
+ else:
94
+ return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
95
+
96
+
97
def get_vector_cosine_func_query(vec1, vec2, db_type: str = 'neo4j') -> str:
    """Return a Cypher expression computing the cosine similarity of two vectors.

    Neo4j's ``vector.similarity.cosine`` yields a similarity normalised onto
    [0, 1], i.e. ``(1 + cos) / 2``. For FalkorDB the expression wraps
    ``vec.cosineDistance``; assuming that returns ``1 - cos``, the value
    ``(2 - distance) / 2`` equals ``(1 + cos) / 2``. That keeps scores on the same
    scale for both backends.

    Args:
        vec1: Cypher expression for the first vector, inserted verbatim.
        vec2: Cypher expression for the second vector, inserted verbatim. On
            FalkorDB it is wrapped in ``vecf32(...)``.
        db_type: ``'falkordb'`` selects FalkorDB syntax; anything else, Neo4j.

    Returns:
        The Cypher expression as a string.
    """
    if db_type != 'falkordb':
        return f'vector.similarity.cosine({vec1}, {vec2})'

    # NOTE(review): only vec2 is converted with vecf32 — presumably vec1 refers to
    # an already-stored vector property; confirm against the callers.
    distance = f'vec.cosineDistance({vec1}, vecf32({vec2}))'
    return f'(2 - {distance})/2'
103
+
104
+
105
def get_relationships_query(name: str, db_type: str = 'neo4j') -> str:
    """Build the ``CALL`` clause that runs a fulltext search over relationships.

    The search text is always taken from the ``$query`` parameter.

    Args:
        name: Neo4j fulltext index name. For FalkorDB it must be a key of
            ``NEO4J_TO_FALKORDB_MAPPING``, which supplies the relationship type.
        db_type: ``'falkordb'`` selects FalkorDB syntax; anything else, Neo4j.

    Returns:
        The ``CALL`` clause. The Neo4j form also references ``$limit``.

    Raises:
        KeyError: If ``db_type`` is ``'falkordb'`` and ``name`` has no mapping.
    """
    if db_type != 'falkordb':
        return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'

    falkor_label = NEO4J_TO_FALKORDB_MAPPING[name]
    return f"CALL db.idx.fulltext.queryRelationships('{falkor_label}', $query)"
111
+
112
+
113
+ def get_entity_node_save_bulk_query(nodes, db_type: str = 'neo4j') -> str | Any:
114
+ if db_type == 'falkordb':
115
+ queries = []
116
+ for node in nodes:
117
+ for label in node['labels']:
118
+ queries.append(
119
+ (
120
+ f"""
121
+ UNWIND $nodes AS node
122
+ MERGE (n:Entity {{uuid: node.uuid}})
123
+ SET n:{label}
124
+ SET n = node
125
+ WITH n, node
126
+ SET n.name_embedding = vecf32(node.name_embedding)
127
+ RETURN n.uuid AS uuid
128
+ """,
129
+ {'nodes': [node]},
130
+ )
131
+ )
132
+ return queries
133
+ else:
134
+ return ENTITY_NODE_SAVE_BULK
135
+
136
+
137
def get_entity_edge_save_bulk_query(db_type: str = 'neo4j') -> str:
    """Return the Cypher statement used to bulk-save entity (RELATES_TO) edges.

    For Neo4j, the shared ``ENTITY_EDGE_SAVE_BULK`` statement is returned.

    For FalkorDB, the statement unwinds the ``$entity_edges`` list parameter.
    Each edge must carry ``source_node_uuid`` and ``target_node_uuid``. For every
    edge the statement:

    * MATCHes both endpoint Entity nodes. If either endpoint is missing, MATCH
      yields no row and that edge is skipped.
    * MERGEs the relationship on ``uuid``.
    * Replaces its properties with a fixed set, storing ``fact_embedding`` via
      ``vecf32``.

    Args:
        db_type: ``'falkordb'`` selects FalkorDB syntax; anything else, Neo4j.

    Returns:
        The Cypher statement as a string.
    """
    if db_type != 'falkordb':
        return ENTITY_EDGE_SAVE_BULK

    # NOTE(review): SET r = {...} drops any property not listed here; confirm this
    # matches what the Neo4j statement persists.
    return """
        UNWIND $entity_edges AS edge
        MATCH (source:Entity {uuid: edge.source_node_uuid})
        MATCH (target:Entity {uuid: edge.target_node_uuid})
        MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
        SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
        created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at, fact_embedding: vecf32(edge.fact_embedding)}
        WITH r, edge
        RETURN edge.uuid AS uuid"""
graphiti_core/graphiti.py CHANGED
@@ -19,12 +19,13 @@ from datetime import datetime
19
19
  from time import time
20
20
 
21
21
  from dotenv import load_dotenv
22
- from neo4j import AsyncGraphDatabase
23
22
  from pydantic import BaseModel
24
23
  from typing_extensions import LiteralString
25
24
 
26
25
  from graphiti_core.cross_encoder.client import CrossEncoderClient
27
26
  from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
27
+ from graphiti_core.driver.driver import GraphDriver
28
+ from graphiti_core.driver.neo4j_driver import Neo4jDriver
28
29
  from graphiti_core.edges import EntityEdge, EpisodicEdge
29
30
  from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
30
31
  from graphiti_core.graphiti_types import GraphitiClients
@@ -94,12 +95,13 @@ class Graphiti:
94
95
  def __init__(
95
96
  self,
96
97
  uri: str,
97
- user: str,
98
- password: str,
98
+ user: str | None = None,
99
+ password: str | None = None,
99
100
  llm_client: LLMClient | None = None,
100
101
  embedder: EmbedderClient | None = None,
101
102
  cross_encoder: CrossEncoderClient | None = None,
102
103
  store_raw_episode_content: bool = True,
104
+ graph_driver: GraphDriver | None = None,
103
105
  ):
104
106
  """
105
107
  Initialize a Graphiti instance.
@@ -137,7 +139,9 @@ class Graphiti:
137
139
  Make sure to set the OPENAI_API_KEY environment variable before initializing
138
140
  Graphiti if you're using the default OpenAIClient.
139
141
  """
140
- self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
142
+
143
+ self.driver = graph_driver if graph_driver else Neo4jDriver(uri, user, password)
144
+
141
145
  self.database = DEFAULT_DATABASE
142
146
  self.store_raw_episode_content = store_raw_episode_content
143
147
  if llm_client:
@@ -14,16 +14,16 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
- from neo4j import AsyncDriver
18
17
  from pydantic import BaseModel, ConfigDict
19
18
 
20
19
  from graphiti_core.cross_encoder import CrossEncoderClient
20
+ from graphiti_core.driver.driver import GraphDriver
21
21
  from graphiti_core.embedder import EmbedderClient
22
22
  from graphiti_core.llm_client import LLMClient
23
23
 
24
24
 
25
25
  class GraphitiClients(BaseModel):
26
- driver: AsyncDriver
26
+ driver: GraphDriver
27
27
  llm_client: LLMClient
28
28
  embedder: EmbedderClient
29
29
  cross_encoder: CrossEncoderClient
graphiti_core/helpers.py CHANGED
@@ -27,7 +27,7 @@ from typing_extensions import LiteralString
27
27
 
28
28
  load_dotenv()
29
29
 
30
- DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
30
+ DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', 'neo4j')
31
31
  USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
32
32
  SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
33
33
  MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0))
@@ -38,8 +38,14 @@ RUNTIME_QUERY: LiteralString = (
38
38
  )
39
39
 
40
40
 
41
- def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
42
- return neo_date.to_native() if neo_date else None
41
def parse_db_date(neo_date: neo4j_time.DateTime | datetime | str | None) -> datetime | None:
    """Normalise a date value read from the graph database to a ``datetime``.

    The input is handled as follows:

    * ``neo4j.time.DateTime`` values are converted with ``to_native()``.
    * Values that are already ``datetime`` objects are returned unchanged.
    * Strings are parsed as ISO-8601 with ``datetime.fromisoformat``.
    * Falsy values (``None``, ``''``) yield ``None``.

    Args:
        neo_date: The raw value returned by the database driver.

    Returns:
        The corresponding ``datetime``, or ``None`` when no value is present.

    Raises:
        ValueError: If a string is not a valid ISO-8601 timestamp.
        TypeError: If the value is of an unsupported type.
    """
    if isinstance(neo_date, neo4j_time.DateTime):
        return neo_date.to_native()
    if isinstance(neo_date, datetime):
        return neo_date
    if not neo_date:
        return None
    # datetime.fromisoformat only accepts a trailing 'Z' from Python 3.11 on;
    # rewrite it as an explicit UTC offset so such timestamps also parse earlier.
    if isinstance(neo_date, str) and neo_date.endswith('Z'):
        neo_date = neo_date[:-1] + '+00:00'
    return datetime.fromisoformat(neo_date)
43
49
 
44
50
 
45
51
  def lucene_sanitize(query: str) -> str:
@@ -1,3 +1,19 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
1
17
  from .client import LLMClient
2
18
  from .config import LLMConfig
3
19
  from .errors import RateLimitError
@@ -0,0 +1,73 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import json
18
+ import logging
19
+ from typing import Any
20
+
21
+ from openai import AsyncAzureOpenAI
22
+ from openai.types.chat import ChatCompletionMessageParam
23
+ from pydantic import BaseModel
24
+
25
+ from ..prompts.models import Message
26
+ from .client import LLMClient
27
+ from .config import LLMConfig, ModelSize
28
+
29
logger = logging.getLogger(__name__)


class AzureOpenAILLMClient(LLMClient):
    """``LLMClient`` implementation backed by an ``AsyncAzureOpenAI`` client.

    Completions are requested in JSON mode and returned as parsed dicts. The base
    class is initialised with ``cache=False`` (presumably disabling response
    caching for this client).
    """

    def __init__(self, azure_client: AsyncAzureOpenAI, config: LLMConfig | None = None):
        """Create the client.

        Args:
            azure_client: Pre-configured async Azure OpenAI client used for requests.
            config: Optional LLM configuration, forwarded to ``LLMClient``.
        """
        super().__init__(config, cache=False)
        self.azure_client = azure_client

    async def _generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int = 1024,
        model_size: ModelSize = ModelSize.medium,
    ) -> dict[str, Any]:
        """Send ``messages`` to Azure OpenAI and return the parsed JSON reply.

        Args:
            messages: Conversation to send. Only ``'user'`` and ``'system'`` messages
                are forwarded; messages with any other role are skipped with a warning.
            response_model: Accepted for interface compatibility; not used here.
            max_tokens: Upper bound on completion tokens.
            model_size: Accepted for interface compatibility; not used here.

        Returns:
            The completion content parsed as JSON, or ``{}`` when the completion
            has no content.

        Raises:
            Exception: Errors from the API call or from JSON parsing are logged and
                re-raised unchanged.
        """
        openai_messages: list[ChatCompletionMessageParam] = []
        for message in messages:
            # Clean into a local rather than assigning back to message.content, so
            # the caller's Message objects are not mutated as a side effect.
            content = self._clean_input(message.content)
            if message.role == 'user':
                openai_messages.append({'role': 'user', 'content': content})
            elif message.role == 'system':
                openai_messages.append({'role': 'system', 'content': content})
            else:
                logger.warning(
                    'Skipping message with unsupported role %r for Azure OpenAI', message.role
                )

        # Fall back to a default model name when none is configured.
        model_name = self.model if self.model else 'gpt-4o-mini'

        try:
            response = await self.azure_client.chat.completions.create(
                model=model_name,
                messages=openai_messages,
                temperature=float(self.temperature) if self.temperature is not None else 0.7,
                max_tokens=max_tokens,
                response_format={'type': 'json_object'},
            )
            result = response.choices[0].message.content or '{}'

            # JSON mode was requested, so the content is expected to be a JSON object.
            return json.loads(result)
        except Exception as e:
            logger.error(f'Error in Azure OpenAI LLM response: {e}')
            raise
graphiti_core/nodes.py CHANGED
@@ -22,13 +22,13 @@ from time import time
22
22
  from typing import Any
23
23
  from uuid import uuid4
24
24
 
25
- from neo4j import AsyncDriver
26
25
  from pydantic import BaseModel, Field
27
26
  from typing_extensions import LiteralString
28
27
 
28
+ from graphiti_core.driver.driver import GraphDriver
29
29
  from graphiti_core.embedder import EmbedderClient
30
30
  from graphiti_core.errors import NodeNotFoundError
31
- from graphiti_core.helpers import DEFAULT_DATABASE
31
+ from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
32
32
  from graphiti_core.models.nodes.node_db_queries import (
33
33
  COMMUNITY_NODE_SAVE,
34
34
  ENTITY_NODE_SAVE,
@@ -94,9 +94,9 @@ class Node(BaseModel, ABC):
94
94
  created_at: datetime = Field(default_factory=lambda: utc_now())
95
95
 
96
96
  @abstractmethod
97
- async def save(self, driver: AsyncDriver): ...
97
+ async def save(self, driver: GraphDriver): ...
98
98
 
99
- async def delete(self, driver: AsyncDriver):
99
+ async def delete(self, driver: GraphDriver):
100
100
  result = await driver.execute_query(
101
101
  """
102
102
  MATCH (n:Entity|Episodic|Community {uuid: $uuid})
@@ -119,7 +119,7 @@ class Node(BaseModel, ABC):
119
119
  return False
120
120
 
121
121
  @classmethod
122
- async def delete_by_group_id(cls, driver: AsyncDriver, group_id: str):
122
+ async def delete_by_group_id(cls, driver: GraphDriver, group_id: str):
123
123
  await driver.execute_query(
124
124
  """
125
125
  MATCH (n:Entity|Episodic|Community {group_id: $group_id})
@@ -132,10 +132,10 @@ class Node(BaseModel, ABC):
132
132
  return 'SUCCESS'
133
133
 
134
134
  @classmethod
135
- async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ...
135
+ async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
136
136
 
137
137
  @classmethod
138
- async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]): ...
138
+ async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): ...
139
139
 
140
140
 
141
141
  class EpisodicNode(Node):
@@ -150,7 +150,7 @@ class EpisodicNode(Node):
150
150
  default_factory=list,
151
151
  )
152
152
 
153
- async def save(self, driver: AsyncDriver):
153
+ async def save(self, driver: GraphDriver):
154
154
  result = await driver.execute_query(
155
155
  EPISODIC_NODE_SAVE,
156
156
  uuid=self.uuid,
@@ -165,12 +165,12 @@ class EpisodicNode(Node):
165
165
  database_=DEFAULT_DATABASE,
166
166
  )
167
167
 
168
- logger.debug(f'Saved Node to neo4j: {self.uuid}')
168
+ logger.debug(f'Saved Node to Graph: {self.uuid}')
169
169
 
170
170
  return result
171
171
 
172
172
  @classmethod
173
- async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
173
+ async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
174
174
  records, _, _ = await driver.execute_query(
175
175
  """
176
176
  MATCH (e:Episodic {uuid: $uuid})
@@ -197,7 +197,7 @@ class EpisodicNode(Node):
197
197
  return episodes[0]
198
198
 
199
199
  @classmethod
200
- async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
200
+ async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
201
201
  records, _, _ = await driver.execute_query(
202
202
  """
203
203
  MATCH (e:Episodic) WHERE e.uuid IN $uuids
@@ -224,7 +224,7 @@ class EpisodicNode(Node):
224
224
  @classmethod
225
225
  async def get_by_group_ids(
226
226
  cls,
227
- driver: AsyncDriver,
227
+ driver: GraphDriver,
228
228
  group_ids: list[str],
229
229
  limit: int | None = None,
230
230
  uuid_cursor: str | None = None,
@@ -263,7 +263,7 @@ class EpisodicNode(Node):
263
263
  return episodes
264
264
 
265
265
  @classmethod
266
- async def get_by_entity_node_uuid(cls, driver: AsyncDriver, entity_node_uuid: str):
266
+ async def get_by_entity_node_uuid(cls, driver: GraphDriver, entity_node_uuid: str):
267
267
  records, _, _ = await driver.execute_query(
268
268
  """
269
269
  MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
@@ -304,7 +304,7 @@ class EntityNode(Node):
304
304
 
305
305
  return self.name_embedding
306
306
 
307
- async def load_name_embedding(self, driver: AsyncDriver):
307
+ async def load_name_embedding(self, driver: GraphDriver):
308
308
  query: LiteralString = """
309
309
  MATCH (n:Entity {uuid: $uuid})
310
310
  RETURN n.name_embedding AS name_embedding
@@ -318,7 +318,7 @@ class EntityNode(Node):
318
318
 
319
319
  self.name_embedding = records[0]['name_embedding']
320
320
 
321
- async def save(self, driver: AsyncDriver):
321
+ async def save(self, driver: GraphDriver):
322
322
  entity_data: dict[str, Any] = {
323
323
  'uuid': self.uuid,
324
324
  'name': self.name,
@@ -337,16 +337,16 @@ class EntityNode(Node):
337
337
  database_=DEFAULT_DATABASE,
338
338
  )
339
339
 
340
- logger.debug(f'Saved Node to neo4j: {self.uuid}')
340
+ logger.debug(f'Saved Node to Graph: {self.uuid}')
341
341
 
342
342
  return result
343
343
 
344
344
  @classmethod
345
- async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
345
+ async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
346
346
  query = (
347
347
  """
348
- MATCH (n:Entity {uuid: $uuid})
349
- """
348
+ MATCH (n:Entity {uuid: $uuid})
349
+ """
350
350
  + ENTITY_NODE_RETURN
351
351
  )
352
352
  records, _, _ = await driver.execute_query(
@@ -364,7 +364,7 @@ class EntityNode(Node):
364
364
  return nodes[0]
365
365
 
366
366
  @classmethod
367
- async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
367
+ async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
368
368
  records, _, _ = await driver.execute_query(
369
369
  """
370
370
  MATCH (n:Entity) WHERE n.uuid IN $uuids
@@ -382,7 +382,7 @@ class EntityNode(Node):
382
382
  @classmethod
383
383
  async def get_by_group_ids(
384
384
  cls,
385
- driver: AsyncDriver,
385
+ driver: GraphDriver,
386
386
  group_ids: list[str],
387
387
  limit: int | None = None,
388
388
  uuid_cursor: str | None = None,
@@ -416,7 +416,7 @@ class CommunityNode(Node):
416
416
  name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
417
417
  summary: str = Field(description='region summary of member nodes', default_factory=str)
418
418
 
419
- async def save(self, driver: AsyncDriver):
419
+ async def save(self, driver: GraphDriver):
420
420
  result = await driver.execute_query(
421
421
  COMMUNITY_NODE_SAVE,
422
422
  uuid=self.uuid,
@@ -428,7 +428,7 @@ class CommunityNode(Node):
428
428
  database_=DEFAULT_DATABASE,
429
429
  )
430
430
 
431
- logger.debug(f'Saved Node to neo4j: {self.uuid}')
431
+ logger.debug(f'Saved Node to Graph: {self.uuid}')
432
432
 
433
433
  return result
434
434
 
@@ -441,7 +441,7 @@ class CommunityNode(Node):
441
441
 
442
442
  return self.name_embedding
443
443
 
444
- async def load_name_embedding(self, driver: AsyncDriver):
444
+ async def load_name_embedding(self, driver: GraphDriver):
445
445
  query: LiteralString = """
446
446
  MATCH (c:Community {uuid: $uuid})
447
447
  RETURN c.name_embedding AS name_embedding
@@ -456,7 +456,7 @@ class CommunityNode(Node):
456
456
  self.name_embedding = records[0]['name_embedding']
457
457
 
458
458
  @classmethod
459
- async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
459
+ async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
460
460
  records, _, _ = await driver.execute_query(
461
461
  """
462
462
  MATCH (n:Community {uuid: $uuid})
@@ -480,7 +480,7 @@ class CommunityNode(Node):
480
480
  return nodes[0]
481
481
 
482
482
  @classmethod
483
- async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
483
+ async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
484
484
  records, _, _ = await driver.execute_query(
485
485
  """
486
486
  MATCH (n:Community) WHERE n.uuid IN $uuids
@@ -503,7 +503,7 @@ class CommunityNode(Node):
503
503
  @classmethod
504
504
  async def get_by_group_ids(
505
505
  cls,
506
- driver: AsyncDriver,
506
+ driver: GraphDriver,
507
507
  group_ids: list[str],
508
508
  limit: int | None = None,
509
509
  uuid_cursor: str | None = None,
@@ -542,8 +542,8 @@ class CommunityNode(Node):
542
542
  def get_episodic_node_from_record(record: Any) -> EpisodicNode:
543
543
  return EpisodicNode(
544
544
  content=record['content'],
545
- created_at=record['created_at'].to_native().timestamp(),
546
- valid_at=(record['valid_at'].to_native()),
545
+ created_at=parse_db_date(record['created_at']), # type: ignore
546
+ valid_at=parse_db_date(record['valid_at']), # type: ignore
547
547
  uuid=record['uuid'],
548
548
  group_id=record['group_id'],
549
549
  source=EpisodeType.from_str(record['source']),
@@ -559,7 +559,7 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
559
559
  name=record['name'],
560
560
  group_id=record['group_id'],
561
561
  labels=record['labels'],
562
- created_at=record['created_at'].to_native(),
562
+ created_at=parse_db_date(record['created_at']), # type: ignore
563
563
  summary=record['summary'],
564
564
  attributes=record['attributes'],
565
565
  )
@@ -580,7 +580,7 @@ def get_community_node_from_record(record: Any) -> CommunityNode:
580
580
  name=record['name'],
581
581
  group_id=record['group_id'],
582
582
  name_embedding=record['name_embedding'],
583
- created_at=record['created_at'].to_native(),
583
+ created_at=parse_db_date(record['created_at']), # type: ignore
584
584
  summary=record['summary'],
585
585
  )
586
586
 
@@ -18,9 +18,8 @@ import logging
18
18
  from collections import defaultdict
19
19
  from time import time
20
20
 
21
- from neo4j import AsyncDriver
22
-
23
21
  from graphiti_core.cross_encoder.client import CrossEncoderClient
22
+ from graphiti_core.driver.driver import GraphDriver
24
23
  from graphiti_core.edges import EntityEdge
25
24
  from graphiti_core.errors import SearchRerankerError
26
25
  from graphiti_core.graphiti_types import GraphitiClients
@@ -94,7 +93,7 @@ async def search(
94
93
  )
95
94
 
96
95
  # if group_ids is empty, set it to None
97
- group_ids = group_ids if group_ids else None
96
+ group_ids = group_ids if group_ids and group_ids != [''] else None
98
97
  edges, nodes, episodes, communities = await semaphore_gather(
99
98
  edge_search(
100
99
  driver,
@@ -160,7 +159,7 @@ async def search(
160
159
 
161
160
 
162
161
  async def edge_search(
163
- driver: AsyncDriver,
162
+ driver: GraphDriver,
164
163
  cross_encoder: CrossEncoderClient,
165
164
  query: str,
166
165
  query_vector: list[float],
@@ -174,7 +173,6 @@ async def edge_search(
174
173
  ) -> list[EntityEdge]:
175
174
  if config is None:
176
175
  return []
177
-
178
176
  search_results: list[list[EntityEdge]] = list(
179
177
  await semaphore_gather(
180
178
  *[
@@ -261,7 +259,7 @@ async def edge_search(
261
259
 
262
260
 
263
261
  async def node_search(
264
- driver: AsyncDriver,
262
+ driver: GraphDriver,
265
263
  cross_encoder: CrossEncoderClient,
266
264
  query: str,
267
265
  query_vector: list[float],
@@ -275,7 +273,6 @@ async def node_search(
275
273
  ) -> list[EntityNode]:
276
274
  if config is None:
277
275
  return []
278
-
279
276
  search_results: list[list[EntityNode]] = list(
280
277
  await semaphore_gather(
281
278
  *[
@@ -344,7 +341,7 @@ async def node_search(
344
341
 
345
342
 
346
343
  async def episode_search(
347
- driver: AsyncDriver,
344
+ driver: GraphDriver,
348
345
  cross_encoder: CrossEncoderClient,
349
346
  query: str,
350
347
  _query_vector: list[float],
@@ -356,7 +353,6 @@ async def episode_search(
356
353
  ) -> list[EpisodicNode]:
357
354
  if config is None:
358
355
  return []
359
-
360
356
  search_results: list[list[EpisodicNode]] = list(
361
357
  await semaphore_gather(
362
358
  *[
@@ -392,7 +388,7 @@ async def episode_search(
392
388
 
393
389
 
394
390
  async def community_search(
395
- driver: AsyncDriver,
391
+ driver: GraphDriver,
396
392
  cross_encoder: CrossEncoderClient,
397
393
  query: str,
398
394
  query_vector: list[float],