graphiti-core 0.12.0rc1__py3-none-any.whl → 0.24.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphiti_core/cross_encoder/bge_reranker_client.py +12 -2
- graphiti_core/cross_encoder/gemini_reranker_client.py +161 -0
- graphiti_core/cross_encoder/openai_reranker_client.py +7 -5
- graphiti_core/decorators.py +110 -0
- graphiti_core/driver/__init__.py +19 -0
- graphiti_core/driver/driver.py +124 -0
- graphiti_core/driver/falkordb_driver.py +362 -0
- graphiti_core/driver/graph_operations/graph_operations.py +191 -0
- graphiti_core/driver/kuzu_driver.py +182 -0
- graphiti_core/driver/neo4j_driver.py +117 -0
- graphiti_core/driver/neptune_driver.py +305 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +287 -172
- graphiti_core/embedder/azure_openai.py +71 -0
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/embedder/gemini.py +116 -22
- graphiti_core/embedder/voyage.py +13 -2
- graphiti_core/errors.py +8 -0
- graphiti_core/graph_queries.py +162 -0
- graphiti_core/graphiti.py +705 -193
- graphiti_core/graphiti_types.py +4 -2
- graphiti_core/helpers.py +87 -10
- graphiti_core/llm_client/__init__.py +16 -0
- graphiti_core/llm_client/anthropic_client.py +159 -56
- graphiti_core/llm_client/azure_openai_client.py +115 -0
- graphiti_core/llm_client/client.py +98 -21
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/gemini_client.py +290 -41
- graphiti_core/llm_client/groq_client.py +14 -3
- graphiti_core/llm_client/openai_base_client.py +261 -0
- graphiti_core/llm_client/openai_client.py +56 -132
- graphiti_core/llm_client/openai_generic_client.py +91 -56
- graphiti_core/models/edges/edge_db_queries.py +259 -35
- graphiti_core/models/nodes/node_db_queries.py +311 -32
- graphiti_core/nodes.py +420 -205
- graphiti_core/prompts/dedupe_edges.py +46 -32
- graphiti_core/prompts/dedupe_nodes.py +67 -42
- graphiti_core/prompts/eval.py +4 -4
- graphiti_core/prompts/extract_edges.py +27 -16
- graphiti_core/prompts/extract_nodes.py +74 -31
- graphiti_core/prompts/prompt_helpers.py +39 -0
- graphiti_core/prompts/snippets.py +29 -0
- graphiti_core/prompts/summarize_nodes.py +23 -25
- graphiti_core/search/search.py +158 -82
- graphiti_core/search/search_config.py +39 -4
- graphiti_core/search/search_filters.py +126 -35
- graphiti_core/search/search_helpers.py +5 -6
- graphiti_core/search/search_utils.py +1405 -485
- graphiti_core/telemetry/__init__.py +9 -0
- graphiti_core/telemetry/telemetry.py +117 -0
- graphiti_core/tracer.py +193 -0
- graphiti_core/utils/bulk_utils.py +364 -285
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +67 -49
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +339 -197
- graphiti_core/utils/maintenance/graph_data_operations.py +50 -114
- graphiti_core/utils/maintenance/node_operations.py +319 -238
- graphiti_core/utils/maintenance/temporal_operations.py +11 -3
- graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
- graphiti_core/utils/text_utils.py +53 -0
- graphiti_core-0.24.3.dist-info/METADATA +726 -0
- graphiti_core-0.24.3.dist-info/RECORD +86 -0
- {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
- graphiti_core-0.12.0rc1.dist-info/METADATA +0 -350
- graphiti_core-0.12.0rc1.dist-info/RECORD +0 -66
- /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
- {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info/licenses}/LICENSE +0 -0
|
@@ -15,8 +15,18 @@ limitations under the License.
|
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
17
|
import asyncio
|
|
18
|
-
|
|
19
|
-
|
|
18
|
+
from typing import TYPE_CHECKING
|
|
19
|
+
|
|
20
|
+
if TYPE_CHECKING:
|
|
21
|
+
from sentence_transformers import CrossEncoder
|
|
22
|
+
else:
|
|
23
|
+
try:
|
|
24
|
+
from sentence_transformers import CrossEncoder
|
|
25
|
+
except ImportError:
|
|
26
|
+
raise ImportError(
|
|
27
|
+
'sentence-transformers is required for BGERerankerClient. '
|
|
28
|
+
'Install it with: pip install graphiti-core[sentence-transformers]'
|
|
29
|
+
) from None
|
|
20
30
|
|
|
21
31
|
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
|
22
32
|
|
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
import re
|
|
19
|
+
from typing import TYPE_CHECKING
|
|
20
|
+
|
|
21
|
+
from ..helpers import semaphore_gather
|
|
22
|
+
from ..llm_client import LLMConfig, RateLimitError
|
|
23
|
+
from .client import CrossEncoderClient
|
|
24
|
+
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
from google import genai
|
|
27
|
+
from google.genai import types
|
|
28
|
+
else:
|
|
29
|
+
try:
|
|
30
|
+
from google import genai
|
|
31
|
+
from google.genai import types
|
|
32
|
+
except ImportError:
|
|
33
|
+
raise ImportError(
|
|
34
|
+
'google-genai is required for GeminiRerankerClient. '
|
|
35
|
+
'Install it with: pip install graphiti-core[google-genai]'
|
|
36
|
+
) from None
|
|
37
|
+
|
|
38
|
+
logger = logging.getLogger(__name__)

# Model used when the supplied LLMConfig does not name one.
DEFAULT_MODEL = 'gemini-2.5-flash-lite'


class GeminiRerankerClient(CrossEncoderClient):
    """
    Reranker that asks a Google Gemini model to grade each passage on its own.

    The Gemini Developer API does not expose logprobs. Instead of deriving relevance from
    token log-probabilities, every passage is sent in a separate request. The model replies
    with a 0-100 relevance grade, which is then mapped onto [0, 1].
    """

    def __init__(
        self,
        config: LLMConfig | None = None,
        client: 'genai.Client | None' = None,
    ):
        """
        Set up the reranker.

        Args:
            config (LLMConfig | None): LLM settings (API key, model, ...). A default
                ``LLMConfig()`` is used when omitted.
            client (genai.Client | None): Pre-built Gemini client. When omitted, one is created
                from ``config.api_key``. Requests go through its ``.aio`` async interface.
        """
        self.config = config if config is not None else LLMConfig()
        self.client = client if client is not None else genai.Client(api_key=self.config.api_key)

    @staticmethod
    def _build_prompt(query: str, passage: str) -> str:
        """Return the single-passage grading prompt sent to Gemini."""
        return f"""Rate how well this passage answers or relates to the query. Use a scale from 0 to 100.

Query: {query}

Passage: {passage}

Provide only a number between 0 and 100 (no explanation, just the number):"""

    def _request_score(self, contents: list):
        """Create (without awaiting) the Gemini generate_content call for one prompt."""
        return self.client.aio.models.generate_content(
            model=self.config.model or DEFAULT_MODEL,
            contents=contents,  # type: ignore
            config=types.GenerateContentConfig(
                system_instruction='You are an expert at rating passage relevance. Respond with only a number from 0-100.',
                temperature=0.0,
                max_output_tokens=3,
            ),
        )

    @staticmethod
    def _parse_score(response) -> float:
        """Turn one Gemini reply into a score in [0, 1]; any unusable reply scores 0.0."""
        try:
            if not (hasattr(response, 'text') and response.text):
                logger.warning('Empty response from Gemini for passage scoring')
                return 0.0
            score_text = response.text.strip()
            # The model may wrap the number in extra text; take the first 1-3 digit number.
            match = re.search(r'\b(\d{1,3})\b', score_text)
            if match is None:
                logger.warning(f'Could not extract numeric score from response: {score_text}')
                return 0.0
            # Scale the 0-100 grade to [0, 1], clamping out-of-range values.
            return max(0.0, min(1.0, float(match.group(1)) / 100.0))
        except (ValueError, AttributeError) as e:
            logger.warning(f'Error parsing score from Gemini response: {e}')
            return 0.0

    @staticmethod
    def _is_rate_limit_error(error: Exception) -> bool:
        """Heuristically detect Gemini rate-limit / quota failures from the error text."""
        lowered = str(error).lower()
        if any(marker in lowered for marker in ('rate limit', 'quota', 'resource_exhausted')):
            return True
        return '429' in str(error)

    async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
        """
        Order ``passages`` by how relevant Gemini judges them to be for ``query``.

        Each passage is graded 0-100 in its own concurrent request, then scaled to [0, 1].
        Passages whose reply cannot be parsed score 0.0. With zero or one passage, no request
        is made and each passage scores 1.0.

        Returns:
            (passage, score) pairs sorted from most to least relevant.

        Raises:
            RateLimitError: if the failure text looks like a Gemini rate-limit/quota error.
            Exception: any other failure is logged and re-raised unchanged.
        """
        if len(passages) <= 1:
            return [(passage, 1.0) for passage in passages]

        # Build every request payload up front: a single user message per passage.
        contents_per_passage = [
            [
                types.Content(
                    role='user',
                    parts=[types.Part.from_text(text=self._build_prompt(query, passage))],
                ),
            ]
            for passage in passages
        ]

        try:
            # One request per passage, dispatched concurrently via semaphore_gather.
            responses = await semaphore_gather(
                *[self._request_score(contents) for contents in contents_per_passage]
            )
            scored = [
                (passage, self._parse_score(response))
                for passage, response in zip(passages, responses, strict=True)
            ]
            scored.sort(key=lambda pair: pair[1], reverse=True)
            return scored
        except Exception as e:
            if self._is_rate_limit_error(e):
                raise RateLimitError from e
            logger.error(f'Error in generating LLM response: {e}')
            raise
|
|
@@ -22,7 +22,7 @@ import openai
|
|
|
22
22
|
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
|
23
23
|
|
|
24
24
|
from ..helpers import semaphore_gather
|
|
25
|
-
from ..llm_client import LLMConfig, RateLimitError
|
|
25
|
+
from ..llm_client import LLMConfig, OpenAIClient, RateLimitError
|
|
26
26
|
from ..prompts import Message
|
|
27
27
|
from .client import CrossEncoderClient
|
|
28
28
|
|
|
@@ -35,7 +35,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
|
|
|
35
35
|
def __init__(
|
|
36
36
|
self,
|
|
37
37
|
config: LLMConfig | None = None,
|
|
38
|
-
client: AsyncOpenAI | AsyncAzureOpenAI | None = None,
|
|
38
|
+
client: AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None = None,
|
|
39
39
|
):
|
|
40
40
|
"""
|
|
41
41
|
Initialize the OpenAIRerankerClient with the provided configuration and client.
|
|
@@ -45,7 +45,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
|
|
|
45
45
|
|
|
46
46
|
Args:
|
|
47
47
|
config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
|
|
48
|
-
client (AsyncOpenAI | AsyncAzureOpenAI | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
|
|
48
|
+
client (AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
|
|
49
49
|
"""
|
|
50
50
|
if config is None:
|
|
51
51
|
config = LLMConfig()
|
|
@@ -53,6 +53,8 @@ class OpenAIRerankerClient(CrossEncoderClient):
|
|
|
53
53
|
self.config = config
|
|
54
54
|
if client is None:
|
|
55
55
|
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
|
56
|
+
elif isinstance(client, OpenAIClient):
|
|
57
|
+
self.client = client.client
|
|
56
58
|
else:
|
|
57
59
|
self.client = client
|
|
58
60
|
|
|
@@ -82,7 +84,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
|
|
|
82
84
|
responses = await semaphore_gather(
|
|
83
85
|
*[
|
|
84
86
|
self.client.chat.completions.create(
|
|
85
|
-
model=DEFAULT_MODEL,
|
|
87
|
+
model=self.config.model or DEFAULT_MODEL,
|
|
86
88
|
messages=openai_messages,
|
|
87
89
|
temperature=0,
|
|
88
90
|
max_tokens=1,
|
|
@@ -106,7 +108,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
|
|
|
106
108
|
if len(top_logprobs) == 0:
|
|
107
109
|
continue
|
|
108
110
|
norm_logprobs = np.exp(top_logprobs[0].logprob)
|
|
109
|
-
if
|
|
111
|
+
if top_logprobs[0].token.strip().split(' ')[0].lower() == 'true':
|
|
110
112
|
scores.append(norm_logprobs)
|
|
111
113
|
else:
|
|
112
114
|
scores.append(1 - norm_logprobs)
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import functools
|
|
18
|
+
import inspect
|
|
19
|
+
from collections.abc import Awaitable, Callable
|
|
20
|
+
from typing import Any, TypeVar
|
|
21
|
+
|
|
22
|
+
from graphiti_core.driver.driver import GraphProvider
|
|
23
|
+
from graphiti_core.helpers import semaphore_gather
|
|
24
|
+
from graphiti_core.search.search_config import SearchResults
|
|
25
|
+
|
|
26
|
+
F = TypeVar('F', bound=Callable[..., Awaitable[Any]])
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def handle_multiple_group_ids(func: F) -> F:
|
|
30
|
+
"""
|
|
31
|
+
Decorator for FalkorDB methods that need to handle multiple group_ids.
|
|
32
|
+
Runs the function for each group_id separately and merges results.
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
@functools.wraps(func)
|
|
36
|
+
async def wrapper(self, *args, **kwargs):
|
|
37
|
+
group_ids_func_pos = get_parameter_position(func, 'group_ids')
|
|
38
|
+
group_ids_pos = (
|
|
39
|
+
group_ids_func_pos - 1 if group_ids_func_pos is not None else None
|
|
40
|
+
) # Adjust for zero-based index
|
|
41
|
+
group_ids = kwargs.get('group_ids')
|
|
42
|
+
|
|
43
|
+
# If not in kwargs and position exists, get from args
|
|
44
|
+
if group_ids is None and group_ids_pos is not None and len(args) > group_ids_pos:
|
|
45
|
+
group_ids = args[group_ids_pos]
|
|
46
|
+
|
|
47
|
+
# Only handle FalkorDB with multiple group_ids
|
|
48
|
+
if (
|
|
49
|
+
hasattr(self, 'clients')
|
|
50
|
+
and hasattr(self.clients, 'driver')
|
|
51
|
+
and self.clients.driver.provider == GraphProvider.FALKORDB
|
|
52
|
+
and group_ids
|
|
53
|
+
and len(group_ids) > 1
|
|
54
|
+
):
|
|
55
|
+
# Execute for each group_id concurrently
|
|
56
|
+
driver = self.clients.driver
|
|
57
|
+
|
|
58
|
+
async def execute_for_group(gid: str):
|
|
59
|
+
# Remove group_ids from args if it was passed positionally
|
|
60
|
+
filtered_args = list(args)
|
|
61
|
+
if group_ids_pos is not None and len(args) > group_ids_pos:
|
|
62
|
+
filtered_args.pop(group_ids_pos)
|
|
63
|
+
|
|
64
|
+
return await func(
|
|
65
|
+
self,
|
|
66
|
+
*filtered_args,
|
|
67
|
+
**{**kwargs, 'group_ids': [gid], 'driver': driver.clone(database=gid)},
|
|
68
|
+
)
|
|
69
|
+
|
|
70
|
+
results = await semaphore_gather(
|
|
71
|
+
*[execute_for_group(gid) for gid in group_ids],
|
|
72
|
+
max_coroutines=getattr(self, 'max_coroutines', None),
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
# Merge results based on type
|
|
76
|
+
if isinstance(results[0], SearchResults):
|
|
77
|
+
return SearchResults.merge(results)
|
|
78
|
+
elif isinstance(results[0], list):
|
|
79
|
+
return [item for result in results for item in result]
|
|
80
|
+
elif isinstance(results[0], tuple):
|
|
81
|
+
# Handle tuple outputs (like build_communities returning (nodes, edges))
|
|
82
|
+
merged_tuple = []
|
|
83
|
+
for i in range(len(results[0])):
|
|
84
|
+
component_results = [result[i] for result in results]
|
|
85
|
+
if isinstance(component_results[0], list):
|
|
86
|
+
merged_tuple.append(
|
|
87
|
+
[item for component in component_results for item in component]
|
|
88
|
+
)
|
|
89
|
+
else:
|
|
90
|
+
merged_tuple.append(component_results)
|
|
91
|
+
return tuple(merged_tuple)
|
|
92
|
+
else:
|
|
93
|
+
return results
|
|
94
|
+
|
|
95
|
+
# Normal execution
|
|
96
|
+
return await func(self, *args, **kwargs)
|
|
97
|
+
|
|
98
|
+
return wrapper # type: ignore
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def get_parameter_position(func: Callable, param_name: str) -> int | None:
|
|
102
|
+
"""
|
|
103
|
+
Returns the positional index of a parameter in the function signature.
|
|
104
|
+
If the parameter is not found, returns None.
|
|
105
|
+
"""
|
|
106
|
+
sig = inspect.signature(func)
|
|
107
|
+
for idx, (name, _param) in enumerate(sig.parameters.items()):
|
|
108
|
+
if name == param_name:
|
|
109
|
+
return idx
|
|
110
|
+
return None
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from neo4j import Neo4jDriver
|
|
18
|
+
|
|
19
|
+
__all__ = ['Neo4jDriver']
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import copy
|
|
18
|
+
import logging
|
|
19
|
+
import os
|
|
20
|
+
from abc import ABC, abstractmethod
|
|
21
|
+
from collections.abc import Coroutine
|
|
22
|
+
from enum import Enum
|
|
23
|
+
from typing import Any
|
|
24
|
+
|
|
25
|
+
from dotenv import load_dotenv
|
|
26
|
+
|
|
27
|
+
from graphiti_core.driver.graph_operations.graph_operations import GraphOperationsInterface
|
|
28
|
+
from graphiti_core.driver.search_interface.search_interface import SearchInterface
|
|
29
|
+
|
|
30
|
+
logger = logging.getLogger(__name__)

# Shared default size constant (presumably a default result limit; confirm at call sites).
DEFAULT_SIZE = 10

# Pull settings from a local .env file (if present) into the process environment, so the
# index-name overrides below can be configured there.
load_dotenv()

# Index names. Each one can be overridden by the environment variable of the same name.
ENTITY_INDEX_NAME = os.getenv('ENTITY_INDEX_NAME', 'entities')
EPISODE_INDEX_NAME = os.getenv('EPISODE_INDEX_NAME', 'episodes')
COMMUNITY_INDEX_NAME = os.getenv('COMMUNITY_INDEX_NAME', 'communities')
ENTITY_EDGE_INDEX_NAME = os.getenv('ENTITY_EDGE_INDEX_NAME', 'entity_edges')
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class GraphProvider(Enum):
    """Graph database backends known to the driver layer, keyed by lowercase identifier."""

    NEO4J = 'neo4j'  # Neo4j
    FALKORDB = 'falkordb'  # FalkorDB
    KUZU = 'kuzu'  # Kuzu
    NEPTUNE = 'neptune'  # Amazon Neptune
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class GraphDriverSession(ABC):
    """
    Abstract asynchronous session against a graph database backend.

    Usable with ``async with``: entering yields the session itself, and concrete subclasses
    decide what happens on exit. Subclasses must implement ``__aexit__``, ``run``,
    ``close`` and ``execute_write``.
    """

    # Backend identifier; declared here but not assigned in the base class.
    provider: GraphProvider

    async def __aenter__(self):
        return self

    @abstractmethod
    async def __aexit__(self, exc_type, exc, tb):
        """Leave the ``async with`` block. The base version performs no cleanup."""

    @abstractmethod
    async def run(self, query: str, **kwargs: Any) -> Any:
        """Execute ``query``; keyword arguments and the result are backend-specific."""
        raise NotImplementedError

    @abstractmethod
    async def close(self):
        """Close the session."""
        raise NotImplementedError

    @abstractmethod
    async def execute_write(self, func, *args, **kwargs):
        """Execute ``func`` (with ``args``/``kwargs``) as a write; details are backend-specific."""
        raise NotImplementedError
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class GraphDriver(ABC):
    """
    Abstract base class for graph database drivers.

    Concrete drivers must provide query execution, sessions, shutdown, index deletion and
    index/constraint creation. Provider-specific fulltext query building and per-database
    cloning are optional overrides.
    """

    # Backend identifier; declared here but not assigned in the base class.
    provider: GraphProvider
    # Prefix for fulltext queries. Neo4j (the default) syntax needs none.
    fulltext_syntax: str = ''
    # Default database this driver targets (set by subclasses / with_database).
    _database: str
    default_group_id: str = ''
    # Optional provider-specific interfaces; None unless a subclass supplies them.
    search_interface: SearchInterface | None = None
    graph_operations_interface: GraphOperationsInterface | None = None

    @abstractmethod
    def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
        """Return a coroutine that runs the given Cypher query."""
        raise NotImplementedError

    @abstractmethod
    def session(self, database: str | None = None) -> GraphDriverSession:
        """Open a session, optionally against ``database``."""
        raise NotImplementedError

    @abstractmethod
    def close(self):
        """Shut the driver down."""
        raise NotImplementedError

    @abstractmethod
    def delete_all_indexes(self) -> Coroutine:
        """Return a coroutine that drops every index."""
        raise NotImplementedError

    def with_database(self, database: str) -> 'GraphDriver':
        """
        Return a shallow copy of this driver whose default database is ``database``.

        Because the copy is shallow, it shares every attribute object (for example the
        underlying connection) with this driver; only ``_database`` differs.
        """
        duplicate = copy.copy(self)
        duplicate._database = database
        return duplicate

    @abstractmethod
    async def build_indices_and_constraints(self, delete_existing: bool = False):
        """Create indices and constraints, optionally deleting existing ones first."""
        raise NotImplementedError

    def clone(self, database: str) -> 'GraphDriver':
        """
        Return a driver for ``database``.

        The base implementation ignores ``database`` and returns this same instance.
        Override it in subclasses that need a distinct handle per database or graph.
        """
        return self

    def build_fulltext_query(
        self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128
    ) -> str:
        """
        Build a provider-specific fulltext query string.

        Only providers that need custom fulltext query building implement this; the base
        version always raises NotImplementedError.
        """
        raise NotImplementedError(f'build_fulltext_query not implemented for {self.provider}')
|