graphiti-core 0.13.2__py3-none-any.whl → 0.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic.
- graphiti_core/cross_encoder/__init__.py +2 -1
- graphiti_core/cross_encoder/gemini_reranker_client.py +146 -0
- graphiti_core/driver/__init__.py +4 -1
- graphiti_core/driver/falkordb_driver.py +47 -21
- graphiti_core/driver/neo4j_driver.py +5 -3
- graphiti_core/embedder/voyage.py +1 -1
- graphiti_core/graphiti.py +79 -5
- graphiti_core/helpers.py +38 -2
- graphiti_core/llm_client/gemini_client.py +135 -23
- graphiti_core/nodes.py +12 -2
- graphiti_core/search/search_filters.py +4 -5
- graphiti_core/search/search_utils.py +2 -8
- graphiti_core/telemetry/__init__.py +9 -0
- graphiti_core/telemetry/telemetry.py +117 -0
- graphiti_core/utils/bulk_utils.py +5 -2
- graphiti_core/utils/maintenance/community_operations.py +1 -1
- graphiti_core/utils/maintenance/edge_operations.py +1 -1
- graphiti_core/utils/maintenance/graph_data_operations.py +3 -5
- graphiti_core/utils/maintenance/node_operations.py +6 -0
- {graphiti_core-0.13.2.dist-info → graphiti_core-0.15.0.dist-info}/METADATA +167 -52
- {graphiti_core-0.13.2.dist-info → graphiti_core-0.15.0.dist-info}/RECORD +28 -25
- {graphiti_core-0.13.2.dist-info → graphiti_core-0.15.0.dist-info}/WHEEL +1 -1
- {graphiti_core-0.13.2.dist-info → graphiti_core-0.15.0.dist-info/licenses}/LICENSE +0 -0
graphiti_core/cross_encoder/__init__.py
CHANGED
@@ -15,6 +15,7 @@ limitations under the License.
 """
 
 from .client import CrossEncoderClient
+from .gemini_reranker_client import GeminiRerankerClient
 from .openai_reranker_client import OpenAIRerankerClient
 
-__all__ = ['CrossEncoderClient', 'OpenAIRerankerClient']
+__all__ = ['CrossEncoderClient', 'GeminiRerankerClient', 'OpenAIRerankerClient']

graphiti_core/cross_encoder/gemini_reranker_client.py
ADDED
@@ -0,0 +1,146 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import logging
+import re
+
+from google import genai  # type: ignore
+from google.genai import types  # type: ignore
+
+from ..helpers import semaphore_gather
+from ..llm_client import LLMConfig, RateLimitError
+from .client import CrossEncoderClient
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_MODEL = 'gemini-2.5-flash-lite-preview-06-17'
+
+
+class GeminiRerankerClient(CrossEncoderClient):
+    def __init__(
+        self,
+        config: LLMConfig | None = None,
+        client: genai.Client | None = None,
+    ):
+        """
+        Initialize the GeminiRerankerClient with the provided configuration and client.
+
+        The Gemini Developer API does not yet support logprobs. Unlike the OpenAI reranker,
+        this reranker uses the Gemini API to perform direct relevance scoring of passages.
+        Each passage is scored individually on a 0-100 scale.
+
+        Args:
+            config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
+            client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
+        """
+        if config is None:
+            config = LLMConfig()
+
+        self.config = config
+        if client is None:
+            self.client = genai.Client(api_key=config.api_key)
+        else:
+            self.client = client
+
+    async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
+        """
+        Rank passages based on their relevance to the query using direct scoring.
+
+        Each passage is scored individually on a 0-100 scale, then normalized to [0,1].
+        """
+        if len(passages) <= 1:
+            return [(passage, 1.0) for passage in passages]
+
+        # Generate scoring prompts for each passage
+        scoring_prompts = []
+        for passage in passages:
+            prompt = f"""Rate how well this passage answers or relates to the query. Use a scale from 0 to 100.
+
+Query: {query}
+
+Passage: {passage}
+
+Provide only a number between 0 and 100 (no explanation, just the number):"""
+
+            scoring_prompts.append(
+                [
+                    types.Content(
+                        role='user',
+                        parts=[types.Part.from_text(text=prompt)],
+                    ),
+                ]
+            )
+
+        try:
+            # Execute all scoring requests concurrently - O(n) API calls
+            responses = await semaphore_gather(
+                *[
+                    self.client.aio.models.generate_content(
+                        model=self.config.model or DEFAULT_MODEL,
+                        contents=prompt_messages,  # type: ignore
+                        config=types.GenerateContentConfig(
+                            system_instruction='You are an expert at rating passage relevance. Respond with only a number from 0-100.',
+                            temperature=0.0,
+                            max_output_tokens=3,
+                        ),
+                    )
+                    for prompt_messages in scoring_prompts
+                ]
+            )
+
+            # Extract scores and create results
+            results = []
+            for passage, response in zip(passages, responses, strict=True):
+                try:
+                    if hasattr(response, 'text') and response.text:
+                        # Extract numeric score from response
+                        score_text = response.text.strip()
+                        # Handle cases where model might return non-numeric text
+                        score_match = re.search(r'\b(\d{1,3})\b', score_text)
+                        if score_match:
+                            score = float(score_match.group(1))
+                            # Normalize to [0, 1] range and clamp to valid range
+                            normalized_score = max(0.0, min(1.0, score / 100.0))
+                            results.append((passage, normalized_score))
+                        else:
+                            logger.warning(
+                                f'Could not extract numeric score from response: {score_text}'
+                            )
+                            results.append((passage, 0.0))
+                    else:
+                        logger.warning('Empty response from Gemini for passage scoring')
+                        results.append((passage, 0.0))
+                except (ValueError, AttributeError) as e:
+                    logger.warning(f'Error parsing score from Gemini response: {e}')
+                    results.append((passage, 0.0))
+
+            # Sort by score in descending order (highest relevance first)
+            results.sort(reverse=True, key=lambda x: x[1])
+            return results
+
+        except Exception as e:
+            # Check if it's a rate limit error based on Gemini API error codes
+            error_message = str(e).lower()
+            if (
+                'rate limit' in error_message
+                or 'quota' in error_message
+                or 'resource_exhausted' in error_message
+                or '429' in str(e)
+            ):
+                raise RateLimitError from e
+
+            logger.error(f'Error in generating LLM response: {e}')
+            raise
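For orientation, here is a hedged usage sketch of the new reranker. The API key, query, and passages are placeholders, it assumes the optional `google-genai` dependency is installed, and it only exercises the constructor and `rank` method shown in the diff above.

```python
# Illustrative only: exercises GeminiRerankerClient as introduced in 0.15.0.
import asyncio

from graphiti_core.cross_encoder import GeminiRerankerClient
from graphiti_core.llm_client import LLMConfig


async def main() -> None:
    # One Gemini call per passage; scores are normalized to [0, 1] and sorted descending.
    reranker = GeminiRerankerClient(config=LLMConfig(api_key='YOUR_GEMINI_API_KEY'))
    ranked = await reranker.rank(
        query='What graph databases does Graphiti support?',
        passages=[
            'Graphiti builds temporally aware knowledge graphs.',
            'Graphiti supports Neo4j and, as of this release, FalkorDB as graph backends.',
        ],
    )
    for passage, score in ranked:
        print(f'{score:.2f}  {passage}')


asyncio.run(main())
```

Because each passage is scored with its own request, `rank` issues O(n) API calls; concurrency is capped by `semaphore_gather`, which honors the `SEMAPHORE_LIMIT` setting from `graphiti_core.helpers`.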

graphiti_core/driver/__init__.py
CHANGED
@@ -14,4 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-
+from falkordb import FalkorDB
+from neo4j import Neo4jDriver
+
+__all__ = ['Neo4jDriver', 'FalkorDB']

graphiti_core/driver/falkordb_driver.py
CHANGED
@@ -15,7 +15,6 @@ limitations under the License.
 """
 
 import logging
-from collections.abc import Coroutine
 from datetime import datetime
 from typing import Any
 
@@ -52,11 +51,11 @@ class FalkorDriverSession(GraphDriverSession):
         if isinstance(query, list):
             for cypher, params in query:
                 params = convert_datetimes_to_strings(params)
-                await self.graph.query(str(cypher), params)
+                await self.graph.query(str(cypher), params)  # type: ignore[reportUnknownArgumentType]
         else:
             params = dict(kwargs)
             params = convert_datetimes_to_strings(params)
-            await self.graph.query(str(query), params)
+            await self.graph.query(str(query), params)  # type: ignore[reportUnknownArgumentType]
         # Assuming `graph.query` is async (ideal); otherwise, wrap in executor
         return None
@@ -66,22 +65,30 @@ class FalkorDriver(GraphDriver):
 
     def __init__(
         self,
-
-
-
+        host: str = 'localhost',
+        port: int = 6379,
+        username: str | None = None,
+        password: str | None = None,
+        falkor_db: FalkorDB | None = None,
     ):
-
-
-        uri = f'{uri_parts[0]}://{user}:{password}@{uri_parts[1]}'
+        """
+        Initialize the FalkorDB driver.
 
-
-
-        )
+        FalkorDB is a multi-tenant graph database.
+        To connect, provide the host and port.
+        The default parameters assume a local (on-premises) FalkorDB instance.
+        """
+        super().__init__()
+        if falkor_db is not None:
+            # If a FalkorDB instance is provided, use it directly
+            self.client = falkor_db
+        else:
+            self.client = FalkorDB(host=host, port=port, username=username, password=password)
 
     def _get_graph(self, graph_name: str | None) -> FalkorGraph:
-        # FalkorDB requires a non-None database name for multi-tenant graphs; the default is
+        # FalkorDB requires a non-None database name for multi-tenant graphs; the default is DEFAULT_DATABASE
         if graph_name is None:
-            graph_name =
+            graph_name = DEFAULT_DATABASE
         return self.client.select_graph(graph_name)
 
     async def execute_query(self, cypher_query_, **kwargs: Any):
@@ -92,7 +99,7 @@ class FalkorDriver(GraphDriver):
         params = convert_datetimes_to_strings(dict(kwargs))
 
         try:
-            result = await graph.query(cypher_query_, params)
+            result = await graph.query(cypher_query_, params)  # type: ignore[reportUnknownArgumentType]
         except Exception as e:
             if 'already indexed' in str(e):
                 # check if index already exists
@@ -102,17 +109,36 @@ class FalkorDriver(GraphDriver):
            raise
 
        # Convert the result header to a list of strings
-        header = [h[1]
-
+        header = [h[1] for h in result.header]
+
+        # Convert FalkorDB's result format (list of lists) to the format expected by Graphiti (list of dicts)
+        records = []
+        for row in result.result_set:
+            record = {}
+            for i, field_name in enumerate(header):
+                if i < len(row):
+                    record[field_name] = row[i]
+                else:
+                    # If there are more fields in header than values in row, set to None
+                    record[field_name] = None
+            records.append(record)
+
+        return records, header, None
 
     def session(self, database: str | None) -> GraphDriverSession:
         return FalkorDriverSession(self._get_graph(database))
 
     async def close(self) -> None:
-
-
-
-
+        """Close the driver connection."""
+        if hasattr(self.client, 'aclose'):
+            await self.client.aclose()  # type: ignore[reportUnknownMemberType]
+        elif hasattr(self.client.connection, 'aclose'):
+            await self.client.connection.aclose()
+        elif hasattr(self.client.connection, 'close'):
+            await self.client.connection.close()
+
+    async def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> None:
+        await self.execute_query(
             'CALL db.indexes() YIELD name DROP INDEX name',
             database_=database_,
         )
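The reworked `FalkorDriver.__init__` replaces the old URI string with explicit connection parameters and also accepts a pre-built client. A short sketch of the two construction paths follows; the hostnames and credentials are made-up placeholders.

```python
# Illustrative only: the two ways to construct the 0.15.0 FalkorDriver.
from falkordb import FalkorDB

from graphiti_core.driver.falkordb_driver import FalkorDriver

# Path 1: let the driver open its own connection (defaults target a local instance).
local_driver = FalkorDriver(host='localhost', port=6379)

# Path 2: hand over an existing FalkorDB client the caller has already configured.
falkor_db = FalkorDB(host='falkordb.example.com', port=6379, username='user', password='secret')
hosted_driver = FalkorDriver(falkor_db=falkor_db)
```

Note that `execute_query` now reshapes FalkorDB's list-of-lists result set into a list of dicts keyed by the header names and returns `(records, header, None)`, the format the rest of Graphiti expects.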

graphiti_core/driver/neo4j_driver.py
CHANGED
@@ -18,7 +18,7 @@ import logging
 from collections.abc import Coroutine
 from typing import Any
 
-from neo4j import AsyncGraphDatabase
+from neo4j import AsyncGraphDatabase, EagerResult
 from typing_extensions import LiteralString
 
 from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
@@ -42,7 +42,7 @@ class Neo4jDriver(GraphDriver):
             auth=(user or '', password or ''),
         )
 
-    async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) ->
+    async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
         params = kwargs.pop('params', None)
         result = await self.client.execute_query(cypher_query_, parameters_=params, **kwargs)
 
@@ -54,7 +54,9 @@ class Neo4jDriver(GraphDriver):
     async def close(self) -> None:
         return await self.client.close()
 
-    def delete_all_indexes(
+    def delete_all_indexes(
+        self, database_: str = DEFAULT_DATABASE
+    ) -> Coroutine[Any, Any, EagerResult]:
         return self.client.execute_query(
             'CALL db.indexes() YIELD name DROP INDEX name',
             database_=database_,
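The Neo4j driver changes are annotation-only, but they pin down the return types: `execute_query` is declared to resolve to a neo4j `EagerResult`, and `delete_all_indexes` to an un-awaited coroutine producing one. A small consumption sketch with placeholder connection details, assuming `execute_query` passes the `EagerResult` through as the new annotation indicates:

```python
# Illustrative only: consuming the now-typed Neo4jDriver results.
import asyncio

from neo4j import EagerResult

from graphiti_core.driver.neo4j_driver import Neo4jDriver


async def main() -> None:
    driver = Neo4jDriver('bolt://localhost:7687', 'neo4j', 'password')
    result: EagerResult = await driver.execute_query('RETURN 1 AS n')
    # EagerResult is a named tuple of (records, summary, keys).
    print(result.keys, [record.data() for record in result.records])
    await driver.close()


asyncio.run(main())
```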

graphiti_core/embedder/voyage.py
CHANGED
@@ -38,7 +38,7 @@ class VoyageAIEmbedder(EmbedderClient):
         if config is None:
             config = VoyageAIEmbedderConfig()
         self.config = config
-        self.client = voyageai.AsyncClient(api_key=config.api_key)
+        self.client = voyageai.AsyncClient(api_key=config.api_key)  # type: ignore[reportUnknownMemberType]
 
     async def create(
         self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]

graphiti_core/graphiti.py
CHANGED
@@ -29,7 +29,12 @@ from graphiti_core.driver.neo4j_driver import Neo4jDriver
 from graphiti_core.edges import EntityEdge, EpisodicEdge
 from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
 from graphiti_core.graphiti_types import GraphitiClients
-from graphiti_core.helpers import
+from graphiti_core.helpers import (
+    DEFAULT_DATABASE,
+    semaphore_gather,
+    validate_excluded_entity_types,
+    validate_group_id,
+)
 from graphiti_core.llm_client import LLMClient, OpenAIClient
 from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
 from graphiti_core.search.search import SearchConfig, search
@@ -46,6 +51,7 @@ from graphiti_core.search.search_utils import (
     get_mentioned_nodes,
     get_relevant_edges,
 )
+from graphiti_core.telemetry import capture_event
 from graphiti_core.utils.bulk_utils import (
     RawEpisode,
     add_nodes_and_edges_bulk,
@@ -95,7 +101,7 @@ class AddEpisodeResults(BaseModel):
 class Graphiti:
     def __init__(
         self,
-        uri: str,
+        uri: str | None = None,
         user: str | None = None,
         password: str | None = None,
         llm_client: LLMClient | None = None,
@@ -156,7 +162,12 @@ class Graphiti:
             Graphiti if you're using the default OpenAIClient.
         """
 
-
+        if graph_driver:
+            self.driver = graph_driver
+        else:
+            if uri is None:
+                raise ValueError("uri must be provided when graph_driver is None")
+            self.driver = Neo4jDriver(uri, user, password)
 
         self.database = DEFAULT_DATABASE
         self.store_raw_episode_content = store_raw_episode_content
@@ -181,6 +192,61 @@ class Graphiti:
             cross_encoder=self.cross_encoder,
         )
 
+        # Capture telemetry event
+        self._capture_initialization_telemetry()
+
+    def _capture_initialization_telemetry(self):
+        """Capture telemetry event for Graphiti initialization."""
+        try:
+            # Detect provider types from class names
+            llm_provider = self._get_provider_type(self.llm_client)
+            embedder_provider = self._get_provider_type(self.embedder)
+            reranker_provider = self._get_provider_type(self.cross_encoder)
+            database_provider = self._get_provider_type(self.driver)
+
+            properties = {
+                'llm_provider': llm_provider,
+                'embedder_provider': embedder_provider,
+                'reranker_provider': reranker_provider,
+                'database_provider': database_provider,
+            }
+
+            capture_event('graphiti_initialized', properties)
+        except Exception:
+            # Silently handle telemetry errors
+            pass
+
+    def _get_provider_type(self, client) -> str:
+        """Get provider type from client class name."""
+        if client is None:
+            return 'none'
+
+        class_name = client.__class__.__name__.lower()
+
+        # LLM providers
+        if 'openai' in class_name:
+            return 'openai'
+        elif 'azure' in class_name:
+            return 'azure'
+        elif 'anthropic' in class_name:
+            return 'anthropic'
+        elif 'crossencoder' in class_name:
+            return 'crossencoder'
+        elif 'gemini' in class_name:
+            return 'gemini'
+        elif 'groq' in class_name:
+            return 'groq'
+        # Database providers
+        elif 'neo4j' in class_name:
+            return 'neo4j'
+        elif 'falkor' in class_name:
+            return 'falkordb'
+        # Embedder providers
+        elif 'voyage' in class_name:
+            return 'voyage'
+        else:
+            return 'unknown'
+
     async def close(self):
         """
         Close the connection to the Neo4j database.
@@ -293,6 +359,7 @@
         uuid: str | None = None,
         update_communities: bool = False,
         entity_types: dict[str, BaseModel] | None = None,
+        excluded_entity_types: list[str] | None = None,
         previous_episode_uuids: list[str] | None = None,
         edge_types: dict[str, BaseModel] | None = None,
         edge_type_map: dict[tuple[str, str], list[str]] | None = None,
@@ -321,6 +388,12 @@
             Optional uuid of the episode.
         update_communities : bool
             Optional. Whether to update communities with new node information
+        entity_types : dict[str, BaseModel] | None
+            Optional. Dictionary mapping entity type names to their Pydantic model definitions.
+        excluded_entity_types : list[str] | None
+            Optional. List of entity type names to exclude from the graph. Entities classified
+            into these types will not be added to the graph. Can include 'Entity' to exclude
+            the default entity type.
         previous_episode_uuids : list[str] | None
             Optional. list of episode uuids to use as the previous episodes. If this is not provided,
             the most recent episodes by created_at date will be used.
@@ -351,6 +424,7 @@
         now = utc_now()
 
         validate_entity_types(entity_types)
+        validate_excluded_entity_types(excluded_entity_types, entity_types)
         validate_group_id(group_id)
 
         previous_episodes = (
@@ -389,7 +463,7 @@
             # Extract entities as nodes
 
             extracted_nodes = await extract_nodes(
-                self.clients, episode, previous_episodes, entity_types
+                self.clients, episode, previous_episodes, entity_types, excluded_entity_types
             )
 
             # Extract edges and resolve nodes
@@ -534,7 +608,7 @@
                 extracted_nodes,
                 extracted_edges,
                 episodic_edges,
-            ) = await extract_nodes_and_edges_bulk(self.clients, episode_pairs)
+            ) = await extract_nodes_and_edges_bulk(self.clients, episode_pairs, None, None)
 
             # Generate embeddings
             await semaphore_gather(
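Pulling the `graphiti.py` changes together, the sketch below shows the constructor with an explicit Neo4j URI (now optional when a `graph_driver` is injected instead) and an `add_episode` call using the newly documented `entity_types` plus the new `excluded_entity_types` filter. The connection details, the `Person` model, and the episode text are illustrative, and the default clients assume `OPENAI_API_KEY` is set.

```python
# Illustrative only: new 0.15.0 options on Graphiti() and add_episode().
import asyncio
from datetime import datetime, timezone

from pydantic import BaseModel

from graphiti_core import Graphiti
from graphiti_core.nodes import EpisodeType


class Person(BaseModel):
    """A human mentioned in an episode."""


async def main() -> None:
    # uri is now optional: pass graph_driver=<driver instance> instead, in which
    # case uri/user/password may be omitted.
    graphiti = Graphiti(uri='bolt://localhost:7687', user='neo4j', password='password')

    await graphiti.add_episode(
        name='intro',
        episode_body='Alice met Bob at the conference in Berlin.',
        source=EpisodeType.text,
        source_description='demo text',
        reference_time=datetime.now(timezone.utc),
        entity_types={'Person': Person},
        # New in 0.15.0: entities classified into these types are dropped;
        # 'Entity' excludes anything that falls back to the default type.
        excluded_entity_types=['Entity'],
    )
    await graphiti.close()


asyncio.run(main())
```

On construction, Graphiti now also fires a single `graphiti_initialized` telemetry event whose properties are just the provider names inferred from class names (LLM, embedder, reranker, database); any telemetry failure is swallowed silently.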

graphiti_core/helpers.py
CHANGED
@@ -19,18 +19,20 @@ import os
 import re
 from collections.abc import Coroutine
 from datetime import datetime
+from typing import Any
 
 import numpy as np
 from dotenv import load_dotenv
 from neo4j import time as neo4j_time
 from numpy._typing import NDArray
+from pydantic import BaseModel
 from typing_extensions import LiteralString
 
 from graphiti_core.errors import GroupIdValidationError
 
 load_dotenv()
 
-DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', '
+DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', 'default_db')
 USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
 SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
 MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0))
@@ -98,7 +100,7 @@ def normalize_l2(embedding: list[float]) -> NDArray:
 async def semaphore_gather(
     *coroutines: Coroutine,
     max_coroutines: int | None = None,
-):
+) -> list[Any]:
     semaphore = asyncio.Semaphore(max_coroutines or SEMAPHORE_LIMIT)
 
     async def _wrap_coroutine(coroutine):
@@ -132,3 +134,37 @@ def validate_group_id(group_id: str) -> bool:
         raise GroupIdValidationError(group_id)
 
     return True
+
+
+def validate_excluded_entity_types(
+    excluded_entity_types: list[str] | None, entity_types: dict[str, BaseModel] | None = None
+) -> bool:
+    """
+    Validate that excluded entity types are valid type names.
+
+    Args:
+        excluded_entity_types: List of entity type names to exclude
+        entity_types: Dictionary of available custom entity types
+
+    Returns:
+        True if valid
+
+    Raises:
+        ValueError: If any excluded type names are invalid
+    """
+    if not excluded_entity_types:
+        return True
+
+    # Build set of available type names
+    available_types = {'Entity'}  # Default type is always available
+    if entity_types:
+        available_types.update(entity_types.keys())
+
+    # Check for invalid type names
+    invalid_types = set(excluded_entity_types) - available_types
+    if invalid_types:
+        raise ValueError(
+            f'Invalid excluded entity types: {sorted(invalid_types)}. Available types: {sorted(available_types)}'
+        )
+
+    return True