graphiti-core 0.3.15.tar.gz → 0.3.17.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of graphiti-core might be problematic.
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/PKG-INFO +1 -2
- graphiti_core-0.3.17/graphiti_core/cross_encoder/bge_reranker_client.py +45 -0
- graphiti_core-0.3.17/graphiti_core/cross_encoder/client.py +41 -0
- graphiti_core-0.3.17/graphiti_core/cross_encoder/openai_reranker_client.py +113 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/edges.py +13 -13
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/graphiti.py +28 -3
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/nodes.py +13 -13
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/search/search.py +43 -15
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/search/search_config.py +13 -1
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/search/search_config_recipes.py +27 -1
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/search/search_utils.py +188 -113
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/utils/maintenance/community_operations.py +5 -5
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/utils/maintenance/graph_data_operations.py +1 -1
- graphiti_core-0.3.17/graphiti_core/utils/maintenance/utils.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/pyproject.toml +4 -2
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/LICENSE +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/README.md +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/__init__.py +0 -0
- {graphiti_core-0.3.15/graphiti_core/models → graphiti_core-0.3.17/graphiti_core/cross_encoder}/__init__.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/embedder/__init__.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/embedder/client.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/embedder/openai.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/embedder/voyage.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/errors.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/helpers.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/llm_client/__init__.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/llm_client/anthropic_client.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/llm_client/client.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/llm_client/config.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/llm_client/errors.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/llm_client/groq_client.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/llm_client/openai_client.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/llm_client/utils.py +0 -0
- {graphiti_core-0.3.15/graphiti_core/models/edges → graphiti_core-0.3.17/graphiti_core/models}/__init__.py +0 -0
- {graphiti_core-0.3.15/graphiti_core/models/nodes → graphiti_core-0.3.17/graphiti_core/models/edges}/__init__.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/models/edges/edge_db_queries.py +0 -0
- {graphiti_core-0.3.15/graphiti_core/search → graphiti_core-0.3.17/graphiti_core/models/nodes}/__init__.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/models/nodes/node_db_queries.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/prompts/__init__.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/prompts/dedupe_edges.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/prompts/dedupe_nodes.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/prompts/eval.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/prompts/extract_edge_dates.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/prompts/extract_edges.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/prompts/extract_nodes.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/prompts/invalidate_edges.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/prompts/lib.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/prompts/models.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/prompts/summarize_nodes.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/py.typed +0 -0
- /graphiti_core-0.3.15/graphiti_core/utils/maintenance/utils.py → /graphiti_core-0.3.17/graphiti_core/search/__init__.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/utils/__init__.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/utils/bulk_utils.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/utils/maintenance/__init__.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/utils/maintenance/edge_operations.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/utils/maintenance/node_operations.py +0 -0
- {graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/utils/maintenance/temporal_operations.py +0 -0
{graphiti_core-0.3.15 → graphiti_core-0.3.17}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: graphiti-core
-Version: 0.3.15
+Version: 0.3.17
 Summary: A temporal graph building library
 License: Apache-2.0
 Author: Paul Paliychuk
@@ -17,7 +17,6 @@ Requires-Dist: numpy (>=1.0.0)
 Requires-Dist: openai (>=1.50.2,<2.0.0)
 Requires-Dist: pydantic (>=2.8.2,<3.0.0)
 Requires-Dist: tenacity (<9.0.0)
-Requires-Dist: voyageai (>=0.2.3,<0.3.0)
 Description-Content-Type: text/markdown

 <div align="center">
graphiti_core-0.3.17/graphiti_core/cross_encoder/bge_reranker_client.py

@@ -0,0 +1,45 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import asyncio
+from typing import List, Tuple
+
+from sentence_transformers import CrossEncoder
+
+from graphiti_core.cross_encoder.client import CrossEncoderClient
+
+
+class BGERerankerClient(CrossEncoderClient):
+    def __init__(self):
+        self.model = CrossEncoder('BAAI/bge-reranker-v2-m3')
+
+    async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]:
+        if not passages:
+            return []
+
+        input_pairs = [[query, passage] for passage in passages]
+
+        # Run the synchronous predict method in an executor
+        loop = asyncio.get_running_loop()
+        scores = await loop.run_in_executor(None, self.model.predict, input_pairs)
+
+        ranked_passages = sorted(
+            [(passage, float(score)) for passage, score in zip(passages, scores)],
+            key=lambda x: x[1],
+            reverse=True,
+        )
+
+        return ranked_passages
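A minimal sketch of exercising the new local reranker on its own, assuming sentence-transformers is installed and the BAAI/bge-reranker-v2-m3 weights can be downloaded; the query and passages below are illustrative only:

import asyncio

from graphiti_core.cross_encoder.bge_reranker_client import BGERerankerClient


async def main():
    # Instantiating the client loads the BAAI/bge-reranker-v2-m3 cross-encoder
    # (weights are downloaded on first use).
    reranker = BGERerankerClient()

    # Hypothetical query and candidate passages.
    ranked = await reranker.rank(
        query='What is graphiti-core?',
        passages=[
            'graphiti-core is a temporal graph building library.',
            'The weather in Berlin is mild in September.',
        ],
    )
    for passage, score in ranked:
        print(f'{score:.4f}  {passage}')


asyncio.run(main())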
graphiti_core-0.3.17/graphiti_core/cross_encoder/client.py

@@ -0,0 +1,41 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from abc import ABC, abstractmethod
+from typing import List, Tuple
+
+
+class CrossEncoderClient(ABC):
+    """
+    CrossEncoderClient is an abstract base class that defines the interface
+    for cross-encoder models used for ranking passages based on their relevance to a query.
+    It allows for different implementations of cross-encoder models to be used interchangeably.
+    """
+
+    @abstractmethod
+    async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]:
+        """
+        Rank the given passages based on their relevance to the query.
+
+        Args:
+            query (str): The query string.
+            passages (List[str]): A list of passages to rank.
+
+        Returns:
+            List[Tuple[str, float]]: A list of tuples containing the passage and its score,
+                sorted in descending order of relevance.
+        """
+        pass
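Any reranker that satisfies this interface can be swapped into the search pipeline. As a hypothetical illustration (not part of the package), a dependency-free implementation could be as simple as keyword overlap:

from typing import List, Tuple

from graphiti_core.cross_encoder.client import CrossEncoderClient


class KeywordOverlapReranker(CrossEncoderClient):
    """Toy reranker: scores each passage by how many query tokens it contains."""

    async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]:
        query_tokens = set(query.lower().split())
        scored = [
            (passage, float(len(query_tokens & set(passage.lower().split()))))
            for passage in passages
        ]
        return sorted(scored, key=lambda x: x[1], reverse=True)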
graphiti_core-0.3.17/graphiti_core/cross_encoder/openai_reranker_client.py

@@ -0,0 +1,113 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import asyncio
+import logging
+from typing import Any
+
+import openai
+from openai import AsyncOpenAI
+from pydantic import BaseModel
+
+from ..llm_client import LLMConfig, RateLimitError
+from ..prompts import Message
+from .client import CrossEncoderClient
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_MODEL = 'gpt-4o-mini'
+
+
+class BooleanClassifier(BaseModel):
+    isTrue: bool
+
+
+class OpenAIRerankerClient(CrossEncoderClient):
+    def __init__(self, config: LLMConfig | None = None):
+        """
+        Initialize the OpenAIClient with the provided configuration, cache setting, and client.
+
+        Args:
+            config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
+            cache (bool): Whether to use caching for responses. Defaults to False.
+            client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
+
+        """
+        if config is None:
+            config = LLMConfig()
+
+        self.config = config
+        self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
+
+    async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
+        openai_messages_list: Any = [
+            [
+                Message(
+                    role='system',
+                    content='You are an expert tasked with determining whether the passage is relevant to the query',
+                ),
+                Message(
+                    role='user',
+                    content=f"""
+                    Respond with "True" if PASSAGE is relevant to QUERY and "False" otherwise.
+                    <PASSAGE>
+                    {query}
+                    </PASSAGE>
+                    {passage}
+                    <QUERY>
+                    </QUERY>
+                    """,
+                ),
+            ]
+            for passage in passages
+        ]
+        try:
+            responses = await asyncio.gather(
+                *[
+                    self.client.chat.completions.create(
+                        model=DEFAULT_MODEL,
+                        messages=openai_messages,
+                        temperature=0,
+                        max_tokens=1,
+                        logit_bias={'6432': 1, '7983': 1},
+                        logprobs=True,
+                        top_logprobs=2,
+                    )
+                    for openai_messages in openai_messages_list
+                ]
+            )
+
+            responses_top_logprobs = [
+                response.choices[0].logprobs.content[0].top_logprobs
+                if response.choices[0].logprobs is not None
+                and response.choices[0].logprobs.content is not None
+                else []
+                for response in responses
+            ]
+            scores: list[float] = []
+            for top_logprobs in responses_top_logprobs:
+                for logprob in top_logprobs:
+                    if bool(logprob.token):
+                        scores.append(logprob.logprob)
+
+            results = [(passage, score) for passage, score in zip(passages, scores)]
+            results.sort(reverse=True, key=lambda x: x[1])
+            return results
+        except openai.RateLimitError as e:
+            raise RateLimitError from e
+        except Exception as e:
+            logger.error(f'Error in generating LLM response: {e}')
+            raise
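The OpenAI reranker turns relevance into a single-token True/False classification and reads the token log-probabilities (logprobs plus a logit_bias toward the True/False tokens) rather than the generated text, so each passage gets a score from a one-token completion. A sketch of configuring it against a custom OpenAI-compatible endpoint; the API key, base URL, query, and passages are placeholders:

import asyncio

from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.llm_client import LLMConfig


async def main():
    # Placeholder credentials; omit the config entirely to fall back to LLMConfig() defaults.
    reranker = OpenAIRerankerClient(
        config=LLMConfig(api_key='sk-...', base_url='https://api.openai.com/v1')
    )
    ranked = await reranker.rank(
        query='What does graphiti-core do?',
        passages=['A temporal graph building library.', 'A recipe for sourdough bread.'],
    )
    print(ranked)


asyncio.run(main())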
{graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/edges.py

@@ -54,7 +54,7 @@ class Edge(BaseModel, ABC):
             DELETE e
             """,
             uuid=self.uuid,
-
+            database_=DEFAULT_DATABASE,
         )

         logger.debug(f'Deleted Edge: {self.uuid}')
@@ -82,7 +82,7 @@ class EpisodicEdge(Edge):
             uuid=self.uuid,
             group_id=self.group_id,
             created_at=self.created_at,
-
+            database_=DEFAULT_DATABASE,
         )

         logger.debug(f'Saved edge to neo4j: {self.uuid}')
@@ -102,7 +102,7 @@ class EpisodicEdge(Edge):
             e.created_at AS created_at
             """,
             uuid=uuid,
-
+            database_=DEFAULT_DATABASE,
         )

         edges = [get_episodic_edge_from_record(record) for record in records]
@@ -125,7 +125,7 @@ class EpisodicEdge(Edge):
             e.created_at AS created_at
             """,
             uuids=uuids,
-
+            database_=DEFAULT_DATABASE,
         )

         edges = [get_episodic_edge_from_record(record) for record in records]
@@ -148,7 +148,7 @@ class EpisodicEdge(Edge):
             e.created_at AS created_at
             """,
             group_ids=group_ids,
-
+            database_=DEFAULT_DATABASE,
         )

         edges = [get_episodic_edge_from_record(record) for record in records]
@@ -202,7 +202,7 @@ class EntityEdge(Edge):
             expired_at=self.expired_at,
             valid_at=self.valid_at,
             invalid_at=self.invalid_at,
-
+            database_=DEFAULT_DATABASE,
         )

         logger.debug(f'Saved edge to neo4j: {self.uuid}')
@@ -229,7 +229,7 @@ class EntityEdge(Edge):
             e.invalid_at AS invalid_at
             """,
             uuid=uuid,
-
+            database_=DEFAULT_DATABASE,
         )

         edges = [get_entity_edge_from_record(record) for record in records]
@@ -259,7 +259,7 @@ class EntityEdge(Edge):
             e.invalid_at AS invalid_at
             """,
             uuids=uuids,
-
+            database_=DEFAULT_DATABASE,
         )

         edges = [get_entity_edge_from_record(record) for record in records]
@@ -289,7 +289,7 @@ class EntityEdge(Edge):
             e.invalid_at AS invalid_at
             """,
             group_ids=group_ids,
-
+            database_=DEFAULT_DATABASE,
         )

         edges = [get_entity_edge_from_record(record) for record in records]
@@ -308,7 +308,7 @@ class CommunityEdge(Edge):
             uuid=self.uuid,
             group_id=self.group_id,
             created_at=self.created_at,
-
+            database_=DEFAULT_DATABASE,
         )

         logger.debug(f'Saved edge to neo4j: {self.uuid}')
@@ -328,7 +328,7 @@ class CommunityEdge(Edge):
             e.created_at AS created_at
             """,
             uuid=uuid,
-
+            database_=DEFAULT_DATABASE,
         )

         edges = [get_community_edge_from_record(record) for record in records]
@@ -349,7 +349,7 @@ class CommunityEdge(Edge):
             e.created_at AS created_at
             """,
             uuids=uuids,
-
+            database_=DEFAULT_DATABASE,
         )

         edges = [get_community_edge_from_record(record) for record in records]
@@ -370,7 +370,7 @@ class CommunityEdge(Edge):
             e.created_at AS created_at
             """,
             group_ids=group_ids,
-
+            database_=DEFAULT_DATABASE,
        )

         edges = [get_community_edge_from_record(record) for record in records]
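Each of these hunks threads an explicit target database into the driver call. The mechanism is the database_ keyword of the Neo4j Python driver's execute_query; a standalone sketch of the same pattern, using a hypothetical query and a local constant standing in for graphiti_core.helpers.DEFAULT_DATABASE:

import asyncio

from neo4j import AsyncGraphDatabase

DEFAULT_DATABASE = 'neo4j'  # assumption: stand-in for graphiti_core.helpers.DEFAULT_DATABASE


async def count_edges(uri: str, user: str, password: str) -> int:
    driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
    try:
        # database_ routes the query to an explicit database instead of the server default.
        records, _, _ = await driver.execute_query(
            'MATCH ()-[e]->() RETURN count(e) AS c',
            database_=DEFAULT_DATABASE,
        )
        return records[0]['c']
    finally:
        await driver.close()


# asyncio.run(count_edges('bolt://localhost:7687', 'neo4j', 'password'))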
{graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/graphiti.py

@@ -23,8 +23,11 @@ from dotenv import load_dotenv
 from neo4j import AsyncGraphDatabase
 from pydantic import BaseModel

+from graphiti_core.cross_encoder.client import CrossEncoderClient
+from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
 from graphiti_core.edges import EntityEdge, EpisodicEdge
 from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
+from graphiti_core.helpers import DEFAULT_DATABASE
 from graphiti_core.llm_client import LLMClient, OpenAIClient
 from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
 from graphiti_core.search.search import SearchConfig, search
@@ -92,6 +95,7 @@ class Graphiti:
         password: str,
         llm_client: LLMClient | None = None,
         embedder: EmbedderClient | None = None,
+        cross_encoder: CrossEncoderClient | None = None,
         store_raw_episode_content: bool = True,
     ):
         """
@@ -131,7 +135,7 @@ class Graphiti:
         Graphiti if you're using the default OpenAIClient.
         """
         self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
-        self.database =
+        self.database = DEFAULT_DATABASE
         self.store_raw_episode_content = store_raw_episode_content
         if llm_client:
             self.llm_client = llm_client
@@ -141,6 +145,10 @@ class Graphiti:
             self.embedder = embedder
         else:
             self.embedder = OpenAIEmbedder()
+        if cross_encoder:
+            self.cross_encoder = cross_encoder
+        else:
+            self.cross_encoder = OpenAIRerankerClient()

     async def close(self):
         """
@@ -648,6 +656,7 @@ class Graphiti:
             await search(
                 self.driver,
                 self.embedder,
+                self.cross_encoder,
                 query,
                 group_ids,
                 search_config,
@@ -663,8 +672,18 @@ class Graphiti:
         config: SearchConfig,
         group_ids: list[str] | None = None,
         center_node_uuid: str | None = None,
+        bfs_origin_node_uuids: list[str] | None = None,
     ) -> SearchResults:
-        return await search(
+        return await search(
+            self.driver,
+            self.embedder,
+            self.cross_encoder,
+            query,
+            group_ids,
+            config,
+            center_node_uuid,
+            bfs_origin_node_uuids,
+        )

     async def get_nodes_by_query(
         self,
@@ -716,7 +735,13 @@ class Graphiti:

         nodes = (
             await search(
-                self.driver,
+                self.driver,
+                self.embedder,
+                self.cross_encoder,
+                query,
+                group_ids,
+                search_config,
+                center_node_uuid,
             )
         ).nodes
         return nodes
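The net effect for callers: Graphiti gains a cross_encoder argument and falls back to OpenAIRerankerClient when none is supplied. A brief sketch of wiring in the local BGE reranker instead; the Neo4j connection details are placeholders:

from graphiti_core.cross_encoder.bge_reranker_client import BGERerankerClient
from graphiti_core.graphiti import Graphiti

# Placeholder Neo4j connection details.
graphiti = Graphiti(
    'bolt://localhost:7687',
    'neo4j',
    'password',
    cross_encoder=BGERerankerClient(),  # omit to use the default OpenAI-based reranker
)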
{graphiti_core-0.3.15 → graphiti_core-0.3.17}/graphiti_core/nodes.py

@@ -90,7 +90,7 @@ class Node(BaseModel, ABC):
             DETACH DELETE n
             """,
             uuid=self.uuid,
-
+            database_=DEFAULT_DATABASE,
         )

         logger.debug(f'Deleted Node: {self.uuid}')
@@ -136,7 +136,7 @@ class EpisodicNode(Node):
             created_at=self.created_at,
             valid_at=self.valid_at,
             source=self.source.value,
-
+            database_=DEFAULT_DATABASE,
         )

         logger.debug(f'Saved Node to neo4j: {self.uuid}')
@@ -158,7 +158,7 @@ class EpisodicNode(Node):
             e.source AS source
             """,
             uuid=uuid,
-
+            database_=DEFAULT_DATABASE,
         )

         episodes = [get_episodic_node_from_record(record) for record in records]
@@ -184,7 +184,7 @@ class EpisodicNode(Node):
             e.source AS source
             """,
             uuids=uuids,
-
+            database_=DEFAULT_DATABASE,
         )

         episodes = [get_episodic_node_from_record(record) for record in records]
@@ -207,7 +207,7 @@ class EpisodicNode(Node):
             e.source AS source
             """,
             group_ids=group_ids,
-
+            database_=DEFAULT_DATABASE,
         )

         episodes = [get_episodic_node_from_record(record) for record in records]
@@ -237,7 +237,7 @@ class EntityNode(Node):
             summary=self.summary,
             name_embedding=self.name_embedding,
             created_at=self.created_at,
-
+            database_=DEFAULT_DATABASE,
         )

         logger.debug(f'Saved Node to neo4j: {self.uuid}')
@@ -258,7 +258,7 @@ class EntityNode(Node):
             n.summary AS summary
             """,
             uuid=uuid,
-
+            database_=DEFAULT_DATABASE,
         )

         nodes = [get_entity_node_from_record(record) for record in records]
@@ -282,7 +282,7 @@ class EntityNode(Node):
             n.summary AS summary
             """,
             uuids=uuids,
-
+            database_=DEFAULT_DATABASE,
         )

         nodes = [get_entity_node_from_record(record) for record in records]
@@ -303,7 +303,7 @@ class EntityNode(Node):
             n.summary AS summary
             """,
             group_ids=group_ids,
-
+            database_=DEFAULT_DATABASE,
         )

         nodes = [get_entity_node_from_record(record) for record in records]
@@ -324,7 +324,7 @@ class CommunityNode(Node):
             summary=self.summary,
             name_embedding=self.name_embedding,
             created_at=self.created_at,
-
+            database_=DEFAULT_DATABASE,
         )

         logger.debug(f'Saved Node to neo4j: {self.uuid}')
@@ -354,7 +354,7 @@ class CommunityNode(Node):
             n.summary AS summary
             """,
             uuid=uuid,
-
+            database_=DEFAULT_DATABASE,
         )

         nodes = [get_community_node_from_record(record) for record in records]
@@ -378,7 +378,7 @@ class CommunityNode(Node):
             n.summary AS summary
             """,
             uuids=uuids,
-
+            database_=DEFAULT_DATABASE,
         )

         communities = [get_community_node_from_record(record) for record in records]
@@ -399,7 +399,7 @@ class CommunityNode(Node):
             n.summary AS summary
             """,
             group_ids=group_ids,
-
+            database_=DEFAULT_DATABASE,
         )

         communities = [get_community_node_from_record(record) for record in records]