graphiti-core 0.17.4__py3-none-any.whl → 0.25.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphiti_core/cross_encoder/gemini_reranker_client.py +1 -1
- graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
- graphiti_core/decorators.py +110 -0
- graphiti_core/driver/driver.py +62 -2
- graphiti_core/driver/falkordb_driver.py +215 -23
- graphiti_core/driver/graph_operations/graph_operations.py +191 -0
- graphiti_core/driver/kuzu_driver.py +182 -0
- graphiti_core/driver/neo4j_driver.py +70 -8
- graphiti_core/driver/neptune_driver.py +305 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +264 -132
- graphiti_core/embedder/azure_openai.py +10 -3
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/graph_queries.py +114 -101
- graphiti_core/graphiti.py +635 -260
- graphiti_core/graphiti_types.py +2 -0
- graphiti_core/helpers.py +37 -15
- graphiti_core/llm_client/anthropic_client.py +142 -52
- graphiti_core/llm_client/azure_openai_client.py +57 -19
- graphiti_core/llm_client/client.py +83 -21
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/gemini_client.py +75 -57
- graphiti_core/llm_client/openai_base_client.py +92 -48
- graphiti_core/llm_client/openai_client.py +39 -9
- graphiti_core/llm_client/openai_generic_client.py +91 -56
- graphiti_core/models/edges/edge_db_queries.py +259 -35
- graphiti_core/models/nodes/node_db_queries.py +311 -32
- graphiti_core/nodes.py +388 -164
- graphiti_core/prompts/dedupe_edges.py +42 -31
- graphiti_core/prompts/dedupe_nodes.py +56 -39
- graphiti_core/prompts/eval.py +4 -4
- graphiti_core/prompts/extract_edges.py +24 -15
- graphiti_core/prompts/extract_nodes.py +76 -35
- graphiti_core/prompts/prompt_helpers.py +39 -0
- graphiti_core/prompts/snippets.py +29 -0
- graphiti_core/prompts/summarize_nodes.py +23 -25
- graphiti_core/search/search.py +154 -74
- graphiti_core/search/search_config.py +39 -4
- graphiti_core/search/search_filters.py +110 -31
- graphiti_core/search/search_helpers.py +5 -6
- graphiti_core/search/search_utils.py +1360 -473
- graphiti_core/tracer.py +193 -0
- graphiti_core/utils/bulk_utils.py +216 -90
- graphiti_core/utils/content_chunking.py +702 -0
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +62 -38
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +306 -156
- graphiti_core/utils/maintenance/graph_data_operations.py +44 -74
- graphiti_core/utils/maintenance/node_operations.py +466 -206
- graphiti_core/utils/maintenance/temporal_operations.py +11 -3
- graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
- graphiti_core/utils/text_utils.py +53 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/METADATA +221 -87
- graphiti_core-0.25.3.dist-info/RECORD +87 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/WHEEL +1 -1
- graphiti_core-0.17.4.dist-info/RECORD +0 -77
- /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -84,7 +84,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
|
|
|
84
84
|
responses = await semaphore_gather(
|
|
85
85
|
*[
|
|
86
86
|
self.client.chat.completions.create(
|
|
87
|
-
model=DEFAULT_MODEL,
|
|
87
|
+
model=self.config.model or DEFAULT_MODEL,
|
|
88
88
|
messages=openai_messages,
|
|
89
89
|
temperature=0,
|
|
90
90
|
max_tokens=1,
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import functools
|
|
18
|
+
import inspect
|
|
19
|
+
from collections.abc import Awaitable, Callable
|
|
20
|
+
from typing import Any, TypeVar
|
|
21
|
+
|
|
22
|
+
from graphiti_core.driver.driver import GraphProvider
|
|
23
|
+
from graphiti_core.helpers import semaphore_gather
|
|
24
|
+
from graphiti_core.search.search_config import SearchResults
|
|
25
|
+
|
|
26
|
+
F = TypeVar('F', bound=Callable[..., Awaitable[Any]])
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def handle_multiple_group_ids(func: F) -> F:
|
|
30
|
+
"""
|
|
31
|
+
Decorator for FalkorDB methods that need to handle multiple group_ids.
|
|
32
|
+
Runs the function for each group_id separately and merges results.
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
@functools.wraps(func)
|
|
36
|
+
async def wrapper(self, *args, **kwargs):
|
|
37
|
+
group_ids_func_pos = get_parameter_position(func, 'group_ids')
|
|
38
|
+
group_ids_pos = (
|
|
39
|
+
group_ids_func_pos - 1 if group_ids_func_pos is not None else None
|
|
40
|
+
) # Adjust for zero-based index
|
|
41
|
+
group_ids = kwargs.get('group_ids')
|
|
42
|
+
|
|
43
|
+
# If not in kwargs and position exists, get from args
|
|
44
|
+
if group_ids is None and group_ids_pos is not None and len(args) > group_ids_pos:
|
|
45
|
+
group_ids = args[group_ids_pos]
|
|
46
|
+
|
|
47
|
+
# Only handle FalkorDB with multiple group_ids
|
|
48
|
+
if (
|
|
49
|
+
hasattr(self, 'clients')
|
|
50
|
+
and hasattr(self.clients, 'driver')
|
|
51
|
+
and self.clients.driver.provider == GraphProvider.FALKORDB
|
|
52
|
+
and group_ids
|
|
53
|
+
and len(group_ids) > 1
|
|
54
|
+
):
|
|
55
|
+
# Execute for each group_id concurrently
|
|
56
|
+
driver = self.clients.driver
|
|
57
|
+
|
|
58
|
+
async def execute_for_group(gid: str):
|
|
59
|
+
# Remove group_ids from args if it was passed positionally
|
|
60
|
+
filtered_args = list(args)
|
|
61
|
+
if group_ids_pos is not None and len(args) > group_ids_pos:
|
|
62
|
+
filtered_args.pop(group_ids_pos)
|
|
63
|
+
|
|
64
|
+
return await func(
|
|
65
|
+
self,
|
|
66
|
+
*filtered_args,
|
|
67
|
+
**{**kwargs, 'group_ids': [gid], 'driver': driver.clone(database=gid)},
|
|
68
|
+
)
|
|
69
|
+
|
|
70
|
+
results = await semaphore_gather(
|
|
71
|
+
*[execute_for_group(gid) for gid in group_ids],
|
|
72
|
+
max_coroutines=getattr(self, 'max_coroutines', None),
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
# Merge results based on type
|
|
76
|
+
if isinstance(results[0], SearchResults):
|
|
77
|
+
return SearchResults.merge(results)
|
|
78
|
+
elif isinstance(results[0], list):
|
|
79
|
+
return [item for result in results for item in result]
|
|
80
|
+
elif isinstance(results[0], tuple):
|
|
81
|
+
# Handle tuple outputs (like build_communities returning (nodes, edges))
|
|
82
|
+
merged_tuple = []
|
|
83
|
+
for i in range(len(results[0])):
|
|
84
|
+
component_results = [result[i] for result in results]
|
|
85
|
+
if isinstance(component_results[0], list):
|
|
86
|
+
merged_tuple.append(
|
|
87
|
+
[item for component in component_results for item in component]
|
|
88
|
+
)
|
|
89
|
+
else:
|
|
90
|
+
merged_tuple.append(component_results)
|
|
91
|
+
return tuple(merged_tuple)
|
|
92
|
+
else:
|
|
93
|
+
return results
|
|
94
|
+
|
|
95
|
+
# Normal execution
|
|
96
|
+
return await func(self, *args, **kwargs)
|
|
97
|
+
|
|
98
|
+
return wrapper # type: ignore
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def get_parameter_position(func: Callable, param_name: str) -> int | None:
|
|
102
|
+
"""
|
|
103
|
+
Returns the positional index of a parameter in the function signature.
|
|
104
|
+
If the parameter is not found, returns None.
|
|
105
|
+
"""
|
|
106
|
+
sig = inspect.signature(func)
|
|
107
|
+
for idx, (name, _param) in enumerate(sig.parameters.items()):
|
|
108
|
+
if name == param_name:
|
|
109
|
+
return idx
|
|
110
|
+
return None
|
graphiti_core/driver/driver.py
CHANGED
|
@@ -14,15 +14,41 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
+
import copy
|
|
17
18
|
import logging
|
|
19
|
+
import os
|
|
18
20
|
from abc import ABC, abstractmethod
|
|
19
21
|
from collections.abc import Coroutine
|
|
22
|
+
from enum import Enum
|
|
20
23
|
from typing import Any
|
|
21
24
|
|
|
25
|
+
from dotenv import load_dotenv
|
|
26
|
+
|
|
27
|
+
from graphiti_core.driver.graph_operations.graph_operations import GraphOperationsInterface
|
|
28
|
+
from graphiti_core.driver.search_interface.search_interface import SearchInterface
|
|
29
|
+
|
|
22
30
|
logger = logging.getLogger(__name__)
|
|
23
31
|
|
|
32
|
+
DEFAULT_SIZE = 10
|
|
33
|
+
|
|
34
|
+
load_dotenv()
|
|
35
|
+
|
|
36
|
+
ENTITY_INDEX_NAME = os.environ.get('ENTITY_INDEX_NAME', 'entities')
|
|
37
|
+
EPISODE_INDEX_NAME = os.environ.get('EPISODE_INDEX_NAME', 'episodes')
|
|
38
|
+
COMMUNITY_INDEX_NAME = os.environ.get('COMMUNITY_INDEX_NAME', 'communities')
|
|
39
|
+
ENTITY_EDGE_INDEX_NAME = os.environ.get('ENTITY_EDGE_INDEX_NAME', 'entity_edges')
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class GraphProvider(Enum):
|
|
43
|
+
NEO4J = 'neo4j'
|
|
44
|
+
FALKORDB = 'falkordb'
|
|
45
|
+
KUZU = 'kuzu'
|
|
46
|
+
NEPTUNE = 'neptune'
|
|
47
|
+
|
|
24
48
|
|
|
25
49
|
class GraphDriverSession(ABC):
|
|
50
|
+
provider: GraphProvider
|
|
51
|
+
|
|
26
52
|
async def __aenter__(self):
|
|
27
53
|
return self
|
|
28
54
|
|
|
@@ -45,7 +71,14 @@ class GraphDriverSession(ABC):
|
|
|
45
71
|
|
|
46
72
|
|
|
47
73
|
class GraphDriver(ABC):
|
|
48
|
-
provider:
|
|
74
|
+
provider: GraphProvider
|
|
75
|
+
fulltext_syntax: str = (
|
|
76
|
+
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
|
|
77
|
+
)
|
|
78
|
+
_database: str
|
|
79
|
+
default_group_id: str = ''
|
|
80
|
+
search_interface: SearchInterface | None = None
|
|
81
|
+
graph_operations_interface: GraphOperationsInterface | None = None
|
|
49
82
|
|
|
50
83
|
@abstractmethod
|
|
51
84
|
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
|
|
@@ -60,5 +93,32 @@ class GraphDriver(ABC):
|
|
|
60
93
|
raise NotImplementedError()
|
|
61
94
|
|
|
62
95
|
@abstractmethod
|
|
63
|
-
def delete_all_indexes(self
|
|
96
|
+
def delete_all_indexes(self) -> Coroutine:
|
|
64
97
|
raise NotImplementedError()
|
|
98
|
+
|
|
99
|
+
def with_database(self, database: str) -> 'GraphDriver':
|
|
100
|
+
"""
|
|
101
|
+
Returns a shallow copy of this driver with a different default database.
|
|
102
|
+
Reuses the same connection (e.g. FalkorDB, Neo4j).
|
|
103
|
+
"""
|
|
104
|
+
cloned = copy.copy(self)
|
|
105
|
+
cloned._database = database
|
|
106
|
+
|
|
107
|
+
return cloned
|
|
108
|
+
|
|
109
|
+
@abstractmethod
|
|
110
|
+
async def build_indices_and_constraints(self, delete_existing: bool = False):
|
|
111
|
+
raise NotImplementedError()
|
|
112
|
+
|
|
113
|
+
def clone(self, database: str) -> 'GraphDriver':
|
|
114
|
+
"""Clone the driver with a different database or graph name."""
|
|
115
|
+
return self
|
|
116
|
+
|
|
117
|
+
def build_fulltext_query(
|
|
118
|
+
self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128
|
|
119
|
+
) -> str:
|
|
120
|
+
"""
|
|
121
|
+
Specific fulltext query builder for database providers.
|
|
122
|
+
Only implemented by providers that need custom fulltext query building.
|
|
123
|
+
"""
|
|
124
|
+
raise NotImplementedError(f'build_fulltext_query not implemented for {self.provider}')
|
|
@@ -14,8 +14,9 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
+
import asyncio
|
|
18
|
+
import datetime
|
|
17
19
|
import logging
|
|
18
|
-
from datetime import datetime
|
|
19
20
|
from typing import TYPE_CHECKING, Any
|
|
20
21
|
|
|
21
22
|
if TYPE_CHECKING:
|
|
@@ -32,12 +33,52 @@ else:
|
|
|
32
33
|
'Install it with: pip install graphiti-core[falkordb]'
|
|
33
34
|
) from None
|
|
34
35
|
|
|
35
|
-
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
|
|
36
|
+
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
|
37
|
+
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
|
|
38
|
+
from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
|
|
36
39
|
|
|
37
40
|
logger = logging.getLogger(__name__)
|
|
38
41
|
|
|
42
|
+
STOPWORDS = [
|
|
43
|
+
'a',
|
|
44
|
+
'is',
|
|
45
|
+
'the',
|
|
46
|
+
'an',
|
|
47
|
+
'and',
|
|
48
|
+
'are',
|
|
49
|
+
'as',
|
|
50
|
+
'at',
|
|
51
|
+
'be',
|
|
52
|
+
'but',
|
|
53
|
+
'by',
|
|
54
|
+
'for',
|
|
55
|
+
'if',
|
|
56
|
+
'in',
|
|
57
|
+
'into',
|
|
58
|
+
'it',
|
|
59
|
+
'no',
|
|
60
|
+
'not',
|
|
61
|
+
'of',
|
|
62
|
+
'on',
|
|
63
|
+
'or',
|
|
64
|
+
'such',
|
|
65
|
+
'that',
|
|
66
|
+
'their',
|
|
67
|
+
'then',
|
|
68
|
+
'there',
|
|
69
|
+
'these',
|
|
70
|
+
'they',
|
|
71
|
+
'this',
|
|
72
|
+
'to',
|
|
73
|
+
'was',
|
|
74
|
+
'will',
|
|
75
|
+
'with',
|
|
76
|
+
]
|
|
77
|
+
|
|
39
78
|
|
|
40
79
|
class FalkorDriverSession(GraphDriverSession):
|
|
80
|
+
provider = GraphProvider.FALKORDB
|
|
81
|
+
|
|
41
82
|
def __init__(self, graph: FalkorGraph):
|
|
42
83
|
self.graph = graph
|
|
43
84
|
|
|
@@ -71,7 +112,10 @@ class FalkorDriverSession(GraphDriverSession):
|
|
|
71
112
|
|
|
72
113
|
|
|
73
114
|
class FalkorDriver(GraphDriver):
|
|
74
|
-
provider
|
|
115
|
+
provider = GraphProvider.FALKORDB
|
|
116
|
+
default_group_id: str = '\\_'
|
|
117
|
+
fulltext_syntax: str = '@' # FalkorDB uses a redisearch-like syntax for fulltext queries
|
|
118
|
+
aoss_client: None = None
|
|
75
119
|
|
|
76
120
|
def __init__(
|
|
77
121
|
self,
|
|
@@ -88,14 +132,32 @@ class FalkorDriver(GraphDriver):
|
|
|
88
132
|
FalkorDB is a multi-tenant graph database.
|
|
89
133
|
To connect, provide the host and port.
|
|
90
134
|
The default parameters assume a local (on-premises) FalkorDB instance.
|
|
135
|
+
|
|
136
|
+
Args:
|
|
137
|
+
host (str): The host where FalkorDB is running.
|
|
138
|
+
port (int): The port on which FalkorDB is listening.
|
|
139
|
+
username (str | None): The username for authentication (if required).
|
|
140
|
+
password (str | None): The password for authentication (if required).
|
|
141
|
+
falkor_db (FalkorDB | None): An existing FalkorDB instance to use instead of creating a new one.
|
|
142
|
+
database (str): The name of the database to connect to. Defaults to 'default_db'.
|
|
91
143
|
"""
|
|
92
144
|
super().__init__()
|
|
145
|
+
self._database = database
|
|
93
146
|
if falkor_db is not None:
|
|
94
147
|
# If a FalkorDB instance is provided, use it directly
|
|
95
148
|
self.client = falkor_db
|
|
96
149
|
else:
|
|
97
150
|
self.client = FalkorDB(host=host, port=port, username=username, password=password)
|
|
98
|
-
|
|
151
|
+
|
|
152
|
+
# Schedule the indices and constraints to be built
|
|
153
|
+
try:
|
|
154
|
+
# Try to get the current event loop
|
|
155
|
+
loop = asyncio.get_running_loop()
|
|
156
|
+
# Schedule the build_indices_and_constraints to run
|
|
157
|
+
loop.create_task(self.build_indices_and_constraints())
|
|
158
|
+
except RuntimeError:
|
|
159
|
+
# No event loop running, this will be handled later
|
|
160
|
+
pass
|
|
99
161
|
|
|
100
162
|
def _get_graph(self, graph_name: str | None) -> FalkorGraph:
|
|
101
163
|
# FalkorDB requires a non-None database name for multi-tenant graphs; the default is "default_db"
|
|
@@ -104,8 +166,7 @@ class FalkorDriver(GraphDriver):
|
|
|
104
166
|
return self.client.select_graph(graph_name)
|
|
105
167
|
|
|
106
168
|
async def execute_query(self, cypher_query_, **kwargs: Any):
|
|
107
|
-
|
|
108
|
-
graph = self._get_graph(graph_name)
|
|
169
|
+
graph = self._get_graph(self._database)
|
|
109
170
|
|
|
110
171
|
# Convert datetime objects to ISO strings (FalkorDB does not support datetime objects directly)
|
|
111
172
|
params = convert_datetimes_to_strings(dict(kwargs))
|
|
@@ -117,7 +178,7 @@ class FalkorDriver(GraphDriver):
|
|
|
117
178
|
# check if index already exists
|
|
118
179
|
logger.info(f'Index already exists: {e}')
|
|
119
180
|
return None
|
|
120
|
-
logger.error(f'Error executing FalkorDB query: {e}')
|
|
181
|
+
logger.error(f'Error executing FalkorDB query: {e}\n{cypher_query_}\n{params}')
|
|
121
182
|
raise
|
|
122
183
|
|
|
123
184
|
# Convert the result header to a list of strings
|
|
@@ -149,22 +210,153 @@ class FalkorDriver(GraphDriver):
|
|
|
149
210
|
elif hasattr(self.client.connection, 'close'):
|
|
150
211
|
await self.client.connection.close()
|
|
151
212
|
|
|
152
|
-
async def delete_all_indexes(self
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
213
|
+
async def delete_all_indexes(self) -> None:
|
|
214
|
+
result = await self.execute_query('CALL db.indexes()')
|
|
215
|
+
if not result:
|
|
216
|
+
return
|
|
217
|
+
|
|
218
|
+
records, _, _ = result
|
|
219
|
+
drop_tasks = []
|
|
220
|
+
|
|
221
|
+
for record in records:
|
|
222
|
+
label = record['label']
|
|
223
|
+
entity_type = record['entitytype']
|
|
224
|
+
|
|
225
|
+
for field_name, index_type in record['types'].items():
|
|
226
|
+
if 'RANGE' in index_type:
|
|
227
|
+
drop_tasks.append(self.execute_query(f'DROP INDEX ON :{label}({field_name})'))
|
|
228
|
+
elif 'FULLTEXT' in index_type:
|
|
229
|
+
if entity_type == 'NODE':
|
|
230
|
+
drop_tasks.append(
|
|
231
|
+
self.execute_query(
|
|
232
|
+
f'DROP FULLTEXT INDEX FOR (n:{label}) ON (n.{field_name})'
|
|
233
|
+
)
|
|
234
|
+
)
|
|
235
|
+
elif entity_type == 'RELATIONSHIP':
|
|
236
|
+
drop_tasks.append(
|
|
237
|
+
self.execute_query(
|
|
238
|
+
f'DROP FULLTEXT INDEX FOR ()-[e:{label}]-() ON (e.{field_name})'
|
|
239
|
+
)
|
|
240
|
+
)
|
|
241
|
+
|
|
242
|
+
if drop_tasks:
|
|
243
|
+
await asyncio.gather(*drop_tasks)
|
|
244
|
+
|
|
245
|
+
async def build_indices_and_constraints(self, delete_existing=False):
|
|
246
|
+
if delete_existing:
|
|
247
|
+
await self.delete_all_indexes()
|
|
248
|
+
index_queries = get_range_indices(self.provider) + get_fulltext_indices(self.provider)
|
|
249
|
+
for query in index_queries:
|
|
250
|
+
await self.execute_query(query)
|
|
251
|
+
|
|
252
|
+
def clone(self, database: str) -> 'GraphDriver':
|
|
253
|
+
"""
|
|
254
|
+
Returns a shallow copy of this driver with a different default database.
|
|
255
|
+
Reuses the same connection (e.g. FalkorDB, Neo4j).
|
|
256
|
+
"""
|
|
257
|
+
if database == self._database:
|
|
258
|
+
cloned = self
|
|
259
|
+
elif database == self.default_group_id:
|
|
260
|
+
cloned = FalkorDriver(falkor_db=self.client)
|
|
261
|
+
else:
|
|
262
|
+
# Create a new instance of FalkorDriver with the same connection but a different database
|
|
263
|
+
cloned = FalkorDriver(falkor_db=self.client, database=database)
|
|
264
|
+
|
|
265
|
+
return cloned
|
|
266
|
+
|
|
267
|
+
async def health_check(self) -> None:
|
|
268
|
+
"""Check FalkorDB connectivity by running a simple query."""
|
|
269
|
+
try:
|
|
270
|
+
await self.execute_query('MATCH (n) RETURN 1 LIMIT 1')
|
|
271
|
+
return None
|
|
272
|
+
except Exception as e:
|
|
273
|
+
print(f'FalkorDB health check failed: {e}')
|
|
274
|
+
raise
|
|
275
|
+
|
|
276
|
+
@staticmethod
|
|
277
|
+
def convert_datetimes_to_strings(obj):
|
|
278
|
+
if isinstance(obj, dict):
|
|
279
|
+
return {k: FalkorDriver.convert_datetimes_to_strings(v) for k, v in obj.items()}
|
|
280
|
+
elif isinstance(obj, list):
|
|
281
|
+
return [FalkorDriver.convert_datetimes_to_strings(item) for item in obj]
|
|
282
|
+
elif isinstance(obj, tuple):
|
|
283
|
+
return tuple(FalkorDriver.convert_datetimes_to_strings(item) for item in obj)
|
|
284
|
+
elif isinstance(obj, datetime):
|
|
285
|
+
return obj.isoformat()
|
|
286
|
+
else:
|
|
287
|
+
return obj
|
|
288
|
+
|
|
289
|
+
def sanitize(self, query: str) -> str:
|
|
290
|
+
"""
|
|
291
|
+
Replace FalkorDB special characters with whitespace.
|
|
292
|
+
Based on FalkorDB tokenization rules: ,.<>{}[]"':;!@#$%^&*()-+=~
|
|
293
|
+
"""
|
|
294
|
+
# FalkorDB separator characters that break text into tokens
|
|
295
|
+
separator_map = str.maketrans(
|
|
296
|
+
{
|
|
297
|
+
',': ' ',
|
|
298
|
+
'.': ' ',
|
|
299
|
+
'<': ' ',
|
|
300
|
+
'>': ' ',
|
|
301
|
+
'{': ' ',
|
|
302
|
+
'}': ' ',
|
|
303
|
+
'[': ' ',
|
|
304
|
+
']': ' ',
|
|
305
|
+
'"': ' ',
|
|
306
|
+
"'": ' ',
|
|
307
|
+
':': ' ',
|
|
308
|
+
';': ' ',
|
|
309
|
+
'!': ' ',
|
|
310
|
+
'@': ' ',
|
|
311
|
+
'#': ' ',
|
|
312
|
+
'$': ' ',
|
|
313
|
+
'%': ' ',
|
|
314
|
+
'^': ' ',
|
|
315
|
+
'&': ' ',
|
|
316
|
+
'*': ' ',
|
|
317
|
+
'(': ' ',
|
|
318
|
+
')': ' ',
|
|
319
|
+
'-': ' ',
|
|
320
|
+
'+': ' ',
|
|
321
|
+
'=': ' ',
|
|
322
|
+
'~': ' ',
|
|
323
|
+
'?': ' ',
|
|
324
|
+
}
|
|
157
325
|
)
|
|
326
|
+
sanitized = query.translate(separator_map)
|
|
327
|
+
# Clean up multiple spaces
|
|
328
|
+
sanitized = ' '.join(sanitized.split())
|
|
329
|
+
return sanitized
|
|
330
|
+
|
|
331
|
+
def build_fulltext_query(
|
|
332
|
+
self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128
|
|
333
|
+
) -> str:
|
|
334
|
+
"""
|
|
335
|
+
Build a fulltext query string for FalkorDB using RedisSearch syntax.
|
|
336
|
+
FalkorDB uses RedisSearch-like syntax where:
|
|
337
|
+
- Field queries use @ prefix: @field:value
|
|
338
|
+
- Multiple values for same field: (@field:value1|value2)
|
|
339
|
+
- Text search doesn't need @ prefix for content fields
|
|
340
|
+
- AND is implicit with space: (@group_id:value) (text)
|
|
341
|
+
- OR uses pipe within parentheses: (@group_id:value1|value2)
|
|
342
|
+
"""
|
|
343
|
+
if group_ids is None or len(group_ids) == 0:
|
|
344
|
+
group_filter = ''
|
|
345
|
+
else:
|
|
346
|
+
group_values = '|'.join(group_ids)
|
|
347
|
+
group_filter = f'(@group_id:{group_values})'
|
|
348
|
+
|
|
349
|
+
sanitized_query = self.sanitize(query)
|
|
350
|
+
|
|
351
|
+
# Remove stopwords from the sanitized query
|
|
352
|
+
query_words = sanitized_query.split()
|
|
353
|
+
filtered_words = [word for word in query_words if word.lower() not in STOPWORDS]
|
|
354
|
+
sanitized_query = ' | '.join(filtered_words)
|
|
355
|
+
|
|
356
|
+
# If the query is too long return no query
|
|
357
|
+
if len(sanitized_query.split(' ')) + len(group_ids or '') >= max_query_length:
|
|
358
|
+
return ''
|
|
158
359
|
|
|
360
|
+
full_query = group_filter + ' (' + sanitized_query + ')'
|
|
159
361
|
|
|
160
|
-
|
|
161
|
-
if isinstance(obj, dict):
|
|
162
|
-
return {k: convert_datetimes_to_strings(v) for k, v in obj.items()}
|
|
163
|
-
elif isinstance(obj, list):
|
|
164
|
-
return [convert_datetimes_to_strings(item) for item in obj]
|
|
165
|
-
elif isinstance(obj, tuple):
|
|
166
|
-
return tuple(convert_datetimes_to_strings(item) for item in obj)
|
|
167
|
-
elif isinstance(obj, datetime):
|
|
168
|
-
return obj.isoformat()
|
|
169
|
-
else:
|
|
170
|
-
return obj
|
|
362
|
+
return full_query
|