graphiti-core 0.2.3__tar.gz → 0.3.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/PKG-INFO +8 -2
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/README.md +6 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/graphiti_core/edges.py +68 -29
- graphiti_core-0.3.1/graphiti_core/errors.py +43 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/graphiti_core/graphiti.py +51 -26
- graphiti_core-0.3.1/graphiti_core/helpers.py +23 -0
- graphiti_core-0.3.1/graphiti_core/llm_client/__init__.py +6 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/graphiti_core/llm_client/anthropic_client.py +9 -1
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/graphiti_core/llm_client/client.py +17 -10
- graphiti_core-0.3.1/graphiti_core/llm_client/errors.py +23 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/graphiti_core/llm_client/groq_client.py +4 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/graphiti_core/llm_client/openai_client.py +4 -0
- graphiti_core-0.3.1/graphiti_core/llm_client/utils.py +38 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/graphiti_core/nodes.py +144 -20
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/graphiti_core/prompts/extract_edge_dates.py +16 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/graphiti_core/prompts/extract_nodes.py +43 -1
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/graphiti_core/prompts/lib.py +6 -0
- graphiti_core-0.3.1/graphiti_core/prompts/summarize_nodes.py +79 -0
- graphiti_core-0.3.1/graphiti_core/py.typed +1 -0
- graphiti_core-0.3.1/graphiti_core/search/search.py +242 -0
- graphiti_core-0.3.1/graphiti_core/search/search_config.py +81 -0
- graphiti_core-0.3.1/graphiti_core/search/search_config_recipes.py +84 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/graphiti_core/search/search_utils.py +259 -152
- graphiti_core-0.3.1/graphiti_core/utils/maintenance/community_operations.py +155 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/graphiti_core/utils/maintenance/edge_operations.py +20 -2
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/graphiti_core/utils/maintenance/graph_data_operations.py +11 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/graphiti_core/utils/maintenance/node_operations.py +26 -1
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/pyproject.toml +3 -3
- graphiti_core-0.2.3/graphiti_core/helpers.py +0 -7
- graphiti_core-0.2.3/graphiti_core/llm_client/__init__.py +0 -5
- graphiti_core-0.2.3/graphiti_core/llm_client/utils.py +0 -22
- graphiti_core-0.2.3/graphiti_core/search/search.py +0 -145
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/LICENSE +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/graphiti_core/__init__.py +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/graphiti_core/llm_client/config.py +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/graphiti_core/prompts/__init__.py +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/graphiti_core/prompts/dedupe_edges.py +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/graphiti_core/prompts/dedupe_nodes.py +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/graphiti_core/prompts/extract_edges.py +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/graphiti_core/prompts/invalidate_edges.py +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/graphiti_core/prompts/models.py +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/graphiti_core/search/__init__.py +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/graphiti_core/utils/__init__.py +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/graphiti_core/utils/bulk_utils.py +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/graphiti_core/utils/maintenance/__init__.py +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/graphiti_core/utils/maintenance/temporal_operations.py +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.1}/graphiti_core/utils/maintenance/utils.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: graphiti-core
|
|
3
|
-
Version: 0.2.3
|
|
3
|
+
Version: 0.3.1
|
|
4
4
|
Summary: A temporal graph building library
|
|
5
5
|
License: Apache-2.0
|
|
6
6
|
Author: Paul Paliychuk
|
|
@@ -13,7 +13,7 @@ Classifier: Programming Language :: Python :: 3.11
|
|
|
13
13
|
Classifier: Programming Language :: Python :: 3.12
|
|
14
14
|
Requires-Dist: diskcache (>=5.6.3,<6.0.0)
|
|
15
15
|
Requires-Dist: neo4j (>=5.23.0,<6.0.0)
|
|
16
|
-
Requires-Dist: numpy (>=
|
|
16
|
+
Requires-Dist: numpy (>=1.0.0)
|
|
17
17
|
Requires-Dist: openai (>=1.38.0,<2.0.0)
|
|
18
18
|
Requires-Dist: pydantic (>=2.8.2,<3.0.0)
|
|
19
19
|
Requires-Dist: tenacity (<9.0.0)
|
|
@@ -170,6 +170,12 @@ await graphiti.search('Who was the California Attorney General?', center_node_uu
|
|
|
170
170
|
graphiti.close()
|
|
171
171
|
```
|
|
172
172
|
|
|
173
|
+
## Graph Service
|
|
174
|
+
|
|
175
|
+
The `server` directory contains an API service for interacting with the Graphiti API. It is built using FastAPI.
|
|
176
|
+
|
|
177
|
+
Please see the [server README](./server/README.md) for more information.
|
|
178
|
+
|
|
173
179
|
## Documentation
|
|
174
180
|
|
|
175
181
|
- [Guides and API documentation](https://help.getzep.com/graphiti).
|
|
@@ -149,6 +149,12 @@ await graphiti.search('Who was the California Attorney General?', center_node_uu
|
|
|
149
149
|
graphiti.close()
|
|
150
150
|
```
|
|
151
151
|
|
|
152
|
+
## Graph Service
|
|
153
|
+
|
|
154
|
+
The `server` directory contains an API service for interacting with the Graphiti API. It is built using FastAPI.
|
|
155
|
+
|
|
156
|
+
Please see the [server README](./server/README.md) for more information.
|
|
157
|
+
|
|
152
158
|
## Documentation
|
|
153
159
|
|
|
154
160
|
- [Guides and API documentation](https://help.getzep.com/graphiti).
|
|
@@ -24,6 +24,7 @@ from uuid import uuid4
|
|
|
24
24
|
from neo4j import AsyncDriver
|
|
25
25
|
from pydantic import BaseModel, Field
|
|
26
26
|
|
|
27
|
+
from graphiti_core.errors import EdgeNotFoundError
|
|
27
28
|
from graphiti_core.helpers import parse_db_date
|
|
28
29
|
from graphiti_core.llm_client.config import EMBEDDING_DIM
|
|
29
30
|
from graphiti_core.nodes import Node
|
|
@@ -41,8 +42,18 @@ class Edge(BaseModel, ABC):
|
|
|
41
42
|
@abstractmethod
|
|
42
43
|
async def save(self, driver: AsyncDriver): ...
|
|
43
44
|
|
|
44
|
-
|
|
45
|
-
|
|
45
|
+
async def delete(self, driver: AsyncDriver):
|
|
46
|
+
result = await driver.execute_query(
|
|
47
|
+
"""
|
|
48
|
+
MATCH (n)-[e {uuid: $uuid}]->(m)
|
|
49
|
+
DELETE e
|
|
50
|
+
""",
|
|
51
|
+
uuid=self.uuid,
|
|
52
|
+
)
|
|
53
|
+
|
|
54
|
+
logger.info(f'Deleted Edge: {self.uuid}')
|
|
55
|
+
|
|
56
|
+
return result
|
|
46
57
|
|
|
47
58
|
def __hash__(self):
|
|
48
59
|
return hash(self.uuid)
|
|
@@ -76,19 +87,6 @@ class EpisodicEdge(Edge):
|
|
|
76
87
|
|
|
77
88
|
return result
|
|
78
89
|
|
|
79
|
-
async def delete(self, driver: AsyncDriver):
|
|
80
|
-
result = await driver.execute_query(
|
|
81
|
-
"""
|
|
82
|
-
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
|
|
83
|
-
DELETE e
|
|
84
|
-
""",
|
|
85
|
-
uuid=self.uuid,
|
|
86
|
-
)
|
|
87
|
-
|
|
88
|
-
logger.info(f'Deleted Edge: {self.uuid}')
|
|
89
|
-
|
|
90
|
-
return result
|
|
91
|
-
|
|
92
90
|
@classmethod
|
|
93
91
|
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
|
94
92
|
records, _, _ = await driver.execute_query(
|
|
@@ -107,7 +105,8 @@ class EpisodicEdge(Edge):
|
|
|
107
105
|
edges = [get_episodic_edge_from_record(record) for record in records]
|
|
108
106
|
|
|
109
107
|
logger.info(f'Found Edge: {uuid}')
|
|
110
|
-
|
|
108
|
+
if len(edges) == 0:
|
|
109
|
+
raise EdgeNotFoundError(uuid)
|
|
111
110
|
return edges[0]
|
|
112
111
|
|
|
113
112
|
|
|
@@ -169,19 +168,6 @@ class EntityEdge(Edge):
|
|
|
169
168
|
|
|
170
169
|
return result
|
|
171
170
|
|
|
172
|
-
async def delete(self, driver: AsyncDriver):
|
|
173
|
-
result = await driver.execute_query(
|
|
174
|
-
"""
|
|
175
|
-
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
|
176
|
-
DELETE e
|
|
177
|
-
""",
|
|
178
|
-
uuid=self.uuid,
|
|
179
|
-
)
|
|
180
|
-
|
|
181
|
-
logger.info(f'Deleted Edge: {self.uuid}')
|
|
182
|
-
|
|
183
|
-
return result
|
|
184
|
-
|
|
185
171
|
@classmethod
|
|
186
172
|
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
|
187
173
|
records, _, _ = await driver.execute_query(
|
|
@@ -206,6 +192,49 @@ class EntityEdge(Edge):
|
|
|
206
192
|
|
|
207
193
|
edges = [get_entity_edge_from_record(record) for record in records]
|
|
208
194
|
|
|
195
|
+
logger.info(f'Found Edge: {uuid}')
|
|
196
|
+
if len(edges) == 0:
|
|
197
|
+
raise EdgeNotFoundError(uuid)
|
|
198
|
+
return edges[0]
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
class CommunityEdge(Edge):
|
|
202
|
+
async def save(self, driver: AsyncDriver):
|
|
203
|
+
result = await driver.execute_query(
|
|
204
|
+
"""
|
|
205
|
+
MATCH (community:Community {uuid: $community_uuid})
|
|
206
|
+
MATCH (node:Entity | Community {uuid: $entity_uuid})
|
|
207
|
+
MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
|
|
208
|
+
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
|
209
|
+
RETURN r.uuid AS uuid""",
|
|
210
|
+
community_uuid=self.source_node_uuid,
|
|
211
|
+
entity_uuid=self.target_node_uuid,
|
|
212
|
+
uuid=self.uuid,
|
|
213
|
+
group_id=self.group_id,
|
|
214
|
+
created_at=self.created_at,
|
|
215
|
+
)
|
|
216
|
+
|
|
217
|
+
logger.info(f'Saved edge to neo4j: {self.uuid}')
|
|
218
|
+
|
|
219
|
+
return result
|
|
220
|
+
|
|
221
|
+
@classmethod
|
|
222
|
+
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
|
223
|
+
records, _, _ = await driver.execute_query(
|
|
224
|
+
"""
|
|
225
|
+
MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m:Entity | Community)
|
|
226
|
+
RETURN
|
|
227
|
+
e.uuid As uuid,
|
|
228
|
+
e.group_id AS group_id,
|
|
229
|
+
n.uuid AS source_node_uuid,
|
|
230
|
+
m.uuid AS target_node_uuid,
|
|
231
|
+
e.created_at AS created_at
|
|
232
|
+
""",
|
|
233
|
+
uuid=uuid,
|
|
234
|
+
)
|
|
235
|
+
|
|
236
|
+
edges = [get_community_edge_from_record(record) for record in records]
|
|
237
|
+
|
|
209
238
|
logger.info(f'Found Edge: {uuid}')
|
|
210
239
|
|
|
211
240
|
return edges[0]
|
|
@@ -237,3 +266,13 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
|
|
|
237
266
|
valid_at=parse_db_date(record['valid_at']),
|
|
238
267
|
invalid_at=parse_db_date(record['invalid_at']),
|
|
239
268
|
)
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def get_community_edge_from_record(record: Any):
|
|
272
|
+
return CommunityEdge(
|
|
273
|
+
uuid=record['uuid'],
|
|
274
|
+
group_id=record['group_id'],
|
|
275
|
+
source_node_uuid=record['source_node_uuid'],
|
|
276
|
+
target_node_uuid=record['target_node_uuid'],
|
|
277
|
+
created_at=record['created_at'].to_native(),
|
|
278
|
+
)
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class GraphitiError(Exception):
|
|
19
|
+
"""Base exception class for Graphiti Core."""
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class EdgeNotFoundError(GraphitiError):
|
|
23
|
+
"""Raised when an edge is not found."""
|
|
24
|
+
|
|
25
|
+
def __init__(self, uuid: str):
|
|
26
|
+
self.message = f'edge {uuid} not found'
|
|
27
|
+
super().__init__(self.message)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class NodeNotFoundError(GraphitiError):
|
|
31
|
+
"""Raised when a node is not found."""
|
|
32
|
+
|
|
33
|
+
def __init__(self, uuid: str):
|
|
34
|
+
self.message = f'node {uuid} not found'
|
|
35
|
+
super().__init__(self.message)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class SearchRerankerError(GraphitiError):
|
|
39
|
+
"""Raised when a node is not found."""
|
|
40
|
+
|
|
41
|
+
def __init__(self, text: str):
|
|
42
|
+
self.message = text
|
|
43
|
+
super().__init__(self.message)
|
|
@@ -24,14 +24,19 @@ from neo4j import AsyncGraphDatabase
|
|
|
24
24
|
|
|
25
25
|
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
|
26
26
|
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
|
27
|
-
from graphiti_core.llm_client.utils import generate_embedding
|
|
28
27
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
|
29
|
-
from graphiti_core.search.search import
|
|
28
|
+
from graphiti_core.search.search import SearchConfig, search
|
|
29
|
+
from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
|
|
30
|
+
from graphiti_core.search.search_config_recipes import (
|
|
31
|
+
EDGE_HYBRID_SEARCH_NODE_DISTANCE,
|
|
32
|
+
EDGE_HYBRID_SEARCH_RRF,
|
|
33
|
+
NODE_HYBRID_SEARCH_NODE_DISTANCE,
|
|
34
|
+
NODE_HYBRID_SEARCH_RRF,
|
|
35
|
+
)
|
|
30
36
|
from graphiti_core.search.search_utils import (
|
|
31
37
|
RELEVANT_SCHEMA_LIMIT,
|
|
32
38
|
get_relevant_edges,
|
|
33
39
|
get_relevant_nodes,
|
|
34
|
-
hybrid_node_search,
|
|
35
40
|
)
|
|
36
41
|
from graphiti_core.utils import (
|
|
37
42
|
build_episodic_edges,
|
|
@@ -46,6 +51,10 @@ from graphiti_core.utils.bulk_utils import (
|
|
|
46
51
|
resolve_edge_pointers,
|
|
47
52
|
retrieve_previous_episodes_bulk,
|
|
48
53
|
)
|
|
54
|
+
from graphiti_core.utils.maintenance.community_operations import (
|
|
55
|
+
build_communities,
|
|
56
|
+
remove_communities,
|
|
57
|
+
)
|
|
49
58
|
from graphiti_core.utils.maintenance.edge_operations import (
|
|
50
59
|
extract_edges,
|
|
51
60
|
resolve_extracted_edges,
|
|
@@ -412,7 +421,7 @@ class Graphiti:
|
|
|
412
421
|
except Exception as e:
|
|
413
422
|
raise e
|
|
414
423
|
|
|
415
|
-
async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str | None):
|
|
424
|
+
async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str | None = None):
|
|
416
425
|
"""
|
|
417
426
|
Process multiple episodes in bulk and update the graph.
|
|
418
427
|
|
|
@@ -526,12 +535,25 @@ class Graphiti:
|
|
|
526
535
|
except Exception as e:
|
|
527
536
|
raise e
|
|
528
537
|
|
|
538
|
+
async def build_communities(self):
|
|
539
|
+
embedder = self.llm_client.get_embedder()
|
|
540
|
+
|
|
541
|
+
# Clear existing communities
|
|
542
|
+
await remove_communities(self.driver)
|
|
543
|
+
|
|
544
|
+
community_nodes, community_edges = await build_communities(self.driver, self.llm_client)
|
|
545
|
+
|
|
546
|
+
await asyncio.gather(*[node.generate_name_embedding(embedder) for node in community_nodes])
|
|
547
|
+
|
|
548
|
+
await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
|
|
549
|
+
await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
|
|
550
|
+
|
|
529
551
|
async def search(
|
|
530
552
|
self,
|
|
531
553
|
query: str,
|
|
532
554
|
center_node_uuid: str | None = None,
|
|
533
555
|
group_ids: list[str | None] | None = None,
|
|
534
|
-
num_results=10,
|
|
556
|
+
num_results=DEFAULT_SEARCH_LIMIT,
|
|
535
557
|
):
|
|
536
558
|
"""
|
|
537
559
|
Perform a hybrid search on the knowledge graph.
|
|
@@ -547,7 +569,7 @@ class Graphiti:
|
|
|
547
569
|
Facts will be reranked based on proximity to this node
|
|
548
570
|
group_ids : list[str | None] | None, optional
|
|
549
571
|
The graph partitions to return data from.
|
|
550
|
-
|
|
572
|
+
limit : int, optional
|
|
551
573
|
The maximum number of results to return. Defaults to 10.
|
|
552
574
|
|
|
553
575
|
Returns
|
|
@@ -564,21 +586,17 @@ class Graphiti:
|
|
|
564
586
|
The search is performed using the current date and time as the reference
|
|
565
587
|
point for temporal relevance.
|
|
566
588
|
"""
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
num_episodes=0,
|
|
570
|
-
num_edges=num_results,
|
|
571
|
-
num_nodes=0,
|
|
572
|
-
group_ids=group_ids,
|
|
573
|
-
search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
|
|
574
|
-
reranker=reranker,
|
|
589
|
+
search_config = (
|
|
590
|
+
EDGE_HYBRID_SEARCH_RRF if center_node_uuid is None else EDGE_HYBRID_SEARCH_NODE_DISTANCE
|
|
575
591
|
)
|
|
592
|
+
search_config.limit = num_results
|
|
593
|
+
|
|
576
594
|
edges = (
|
|
577
|
-
await
|
|
595
|
+
await search(
|
|
578
596
|
self.driver,
|
|
579
597
|
self.llm_client.get_embedder(),
|
|
580
598
|
query,
|
|
581
|
-
|
|
599
|
+
group_ids,
|
|
582
600
|
search_config,
|
|
583
601
|
center_node_uuid,
|
|
584
602
|
)
|
|
@@ -589,19 +607,20 @@ class Graphiti:
|
|
|
589
607
|
async def _search(
|
|
590
608
|
self,
|
|
591
609
|
query: str,
|
|
592
|
-
timestamp: datetime,
|
|
593
610
|
config: SearchConfig,
|
|
611
|
+
group_ids: list[str | None] | None = None,
|
|
594
612
|
center_node_uuid: str | None = None,
|
|
595
|
-
):
|
|
596
|
-
return await
|
|
597
|
-
self.driver, self.llm_client.get_embedder(), query,
|
|
613
|
+
) -> SearchResults:
|
|
614
|
+
return await search(
|
|
615
|
+
self.driver, self.llm_client.get_embedder(), query, group_ids, config, center_node_uuid
|
|
598
616
|
)
|
|
599
617
|
|
|
600
618
|
async def get_nodes_by_query(
|
|
601
619
|
self,
|
|
602
620
|
query: str,
|
|
621
|
+
center_node_uuid: str | None = None,
|
|
603
622
|
group_ids: list[str | None] | None = None,
|
|
604
|
-
limit: int =
|
|
623
|
+
limit: int = DEFAULT_SEARCH_LIMIT,
|
|
605
624
|
) -> list[EntityNode]:
|
|
606
625
|
"""
|
|
607
626
|
Retrieve nodes from the graph database based on a text query.
|
|
@@ -612,7 +631,9 @@ class Graphiti:
|
|
|
612
631
|
Parameters
|
|
613
632
|
----------
|
|
614
633
|
query : str
|
|
615
|
-
The text query to search for in the graph
|
|
634
|
+
The text query to search for in the graph
|
|
635
|
+
center_node_uuid: str, optional
|
|
636
|
+
Facts will be reranked based on proximity to this node.
|
|
616
637
|
group_ids : list[str | None] | None, optional
|
|
617
638
|
The graph partitions to return data from.
|
|
618
639
|
limit : int | None, optional
|
|
@@ -638,8 +659,12 @@ class Graphiti:
|
|
|
638
659
|
If not specified, a default limit (defined in the search functions) will be used.
|
|
639
660
|
"""
|
|
640
661
|
embedder = self.llm_client.get_embedder()
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
[query], [query_embedding], self.driver, group_ids, limit
|
|
662
|
+
search_config = (
|
|
663
|
+
NODE_HYBRID_SEARCH_RRF if center_node_uuid is None else NODE_HYBRID_SEARCH_NODE_DISTANCE
|
|
644
664
|
)
|
|
645
|
-
|
|
665
|
+
search_config.limit = limit
|
|
666
|
+
|
|
667
|
+
nodes = (
|
|
668
|
+
await search(self.driver, embedder, query, group_ids, search_config, center_node_uuid)
|
|
669
|
+
).nodes
|
|
670
|
+
return nodes
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from datetime import datetime
|
|
18
|
+
|
|
19
|
+
from neo4j import time as neo4j_time
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
|
|
23
|
+
return neo_date.to_native() if neo_date else None
|
|
@@ -18,12 +18,14 @@ import json
|
|
|
18
18
|
import logging
|
|
19
19
|
import typing
|
|
20
20
|
|
|
21
|
+
import anthropic
|
|
21
22
|
from anthropic import AsyncAnthropic
|
|
22
23
|
from openai import AsyncOpenAI
|
|
23
24
|
|
|
24
25
|
from ..prompts.models import Message
|
|
25
26
|
from .client import LLMClient
|
|
26
27
|
from .config import LLMConfig
|
|
28
|
+
from .errors import RateLimitError
|
|
27
29
|
|
|
28
30
|
logger = logging.getLogger(__name__)
|
|
29
31
|
|
|
@@ -35,7 +37,11 @@ class AnthropicClient(LLMClient):
|
|
|
35
37
|
if config is None:
|
|
36
38
|
config = LLMConfig()
|
|
37
39
|
super().__init__(config, cache)
|
|
38
|
-
self.client = AsyncAnthropic(
|
|
40
|
+
self.client = AsyncAnthropic(
|
|
41
|
+
api_key=config.api_key,
|
|
42
|
+
# we'll use tenacity to retry
|
|
43
|
+
max_retries=1,
|
|
44
|
+
)
|
|
39
45
|
|
|
40
46
|
def get_embedder(self) -> typing.Any:
|
|
41
47
|
openai_client = AsyncOpenAI()
|
|
@@ -58,6 +64,8 @@ class AnthropicClient(LLMClient):
|
|
|
58
64
|
)
|
|
59
65
|
|
|
60
66
|
return json.loads('{' + result.content[0].text) # type: ignore
|
|
67
|
+
except anthropic.RateLimitError as e:
|
|
68
|
+
raise RateLimitError from e
|
|
61
69
|
except Exception as e:
|
|
62
70
|
logger.error(f'Error in generating LLM response: {e}')
|
|
63
71
|
raise
|
|
@@ -22,10 +22,11 @@ from abc import ABC, abstractmethod
|
|
|
22
22
|
|
|
23
23
|
import httpx
|
|
24
24
|
from diskcache import Cache
|
|
25
|
-
from tenacity import retry, retry_if_exception, stop_after_attempt,
|
|
25
|
+
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential
|
|
26
26
|
|
|
27
27
|
from ..prompts.models import Message
|
|
28
28
|
from .config import LLMConfig
|
|
29
|
+
from .errors import RateLimitError
|
|
29
30
|
|
|
30
31
|
DEFAULT_TEMPERATURE = 0
|
|
31
32
|
DEFAULT_CACHE_DIR = './llm_cache'
|
|
@@ -33,7 +34,10 @@ DEFAULT_CACHE_DIR = './llm_cache'
|
|
|
33
34
|
logger = logging.getLogger(__name__)
|
|
34
35
|
|
|
35
36
|
|
|
36
|
-
def
|
|
37
|
+
def is_server_or_retry_error(exception):
|
|
38
|
+
if isinstance(exception, RateLimitError):
|
|
39
|
+
return True
|
|
40
|
+
|
|
37
41
|
return (
|
|
38
42
|
isinstance(exception, httpx.HTTPStatusError) and 500 <= exception.response.status_code < 600
|
|
39
43
|
)
|
|
@@ -56,18 +60,21 @@ class LLMClient(ABC):
|
|
|
56
60
|
pass
|
|
57
61
|
|
|
58
62
|
@retry(
|
|
59
|
-
stop=stop_after_attempt(
|
|
60
|
-
wait=
|
|
61
|
-
retry=retry_if_exception(
|
|
63
|
+
stop=stop_after_attempt(4),
|
|
64
|
+
wait=wait_random_exponential(multiplier=10, min=5, max=120),
|
|
65
|
+
retry=retry_if_exception(is_server_or_retry_error),
|
|
66
|
+
after=lambda retry_state: logger.warning(
|
|
67
|
+
f'Retrying {retry_state.fn.__name__ if retry_state.fn else "function"} after {retry_state.attempt_number} attempts...'
|
|
68
|
+
)
|
|
69
|
+
if retry_state.attempt_number > 1
|
|
70
|
+
else None,
|
|
71
|
+
reraise=True,
|
|
62
72
|
)
|
|
63
73
|
async def _generate_response_with_retry(self, messages: list[Message]) -> dict[str, typing.Any]:
|
|
64
74
|
try:
|
|
65
75
|
return await self._generate_response(messages)
|
|
66
|
-
except httpx.HTTPStatusError as e:
|
|
67
|
-
|
|
68
|
-
raise Exception(f'LLM request error: {e}') from e
|
|
69
|
-
else:
|
|
70
|
-
raise
|
|
76
|
+
except (httpx.HTTPStatusError, RateLimitError) as e:
|
|
77
|
+
raise e
|
|
71
78
|
|
|
72
79
|
@abstractmethod
|
|
73
80
|
async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class RateLimitError(Exception):
|
|
19
|
+
"""Exception raised when the rate limit is exceeded."""
|
|
20
|
+
|
|
21
|
+
def __init__(self, message='Rate limit exceeded. Please try again later.'):
|
|
22
|
+
self.message = message
|
|
23
|
+
super().__init__(self.message)
|
|
@@ -18,6 +18,7 @@ import json
|
|
|
18
18
|
import logging
|
|
19
19
|
import typing
|
|
20
20
|
|
|
21
|
+
import groq
|
|
21
22
|
from groq import AsyncGroq
|
|
22
23
|
from groq.types.chat import ChatCompletionMessageParam
|
|
23
24
|
from openai import AsyncOpenAI
|
|
@@ -25,6 +26,7 @@ from openai import AsyncOpenAI
|
|
|
25
26
|
from ..prompts.models import Message
|
|
26
27
|
from .client import LLMClient
|
|
27
28
|
from .config import LLMConfig
|
|
29
|
+
from .errors import RateLimitError
|
|
28
30
|
|
|
29
31
|
logger = logging.getLogger(__name__)
|
|
30
32
|
|
|
@@ -59,6 +61,8 @@ class GroqClient(LLMClient):
|
|
|
59
61
|
)
|
|
60
62
|
result = response.choices[0].message.content or ''
|
|
61
63
|
return json.loads(result)
|
|
64
|
+
except groq.RateLimitError as e:
|
|
65
|
+
raise RateLimitError from e
|
|
62
66
|
except Exception as e:
|
|
63
67
|
logger.error(f'Error in generating LLM response: {e}')
|
|
64
68
|
raise
|
|
@@ -18,12 +18,14 @@ import json
|
|
|
18
18
|
import logging
|
|
19
19
|
import typing
|
|
20
20
|
|
|
21
|
+
import openai
|
|
21
22
|
from openai import AsyncOpenAI
|
|
22
23
|
from openai.types.chat import ChatCompletionMessageParam
|
|
23
24
|
|
|
24
25
|
from ..prompts.models import Message
|
|
25
26
|
from .client import LLMClient
|
|
26
27
|
from .config import LLMConfig
|
|
28
|
+
from .errors import RateLimitError
|
|
27
29
|
|
|
28
30
|
logger = logging.getLogger(__name__)
|
|
29
31
|
|
|
@@ -59,6 +61,8 @@ class OpenAIClient(LLMClient):
|
|
|
59
61
|
)
|
|
60
62
|
result = response.choices[0].message.content or ''
|
|
61
63
|
return json.loads(result)
|
|
64
|
+
except openai.RateLimitError as e:
|
|
65
|
+
raise RateLimitError from e
|
|
62
66
|
except Exception as e:
|
|
63
67
|
logger.error(f'Error in generating LLM response: {e}')
|
|
64
68
|
raise
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
import typing
|
|
19
|
+
from time import time
|
|
20
|
+
|
|
21
|
+
from graphiti_core.llm_client.config import EMBEDDING_DIM
|
|
22
|
+
|
|
23
|
+
logger = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
async def generate_embedding(
|
|
27
|
+
embedder: typing.Any, text: str, model: str = 'text-embedding-3-small'
|
|
28
|
+
):
|
|
29
|
+
start = time()
|
|
30
|
+
|
|
31
|
+
text = text.replace('\n', ' ')
|
|
32
|
+
embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
|
|
33
|
+
embedding = embedding[:EMBEDDING_DIM]
|
|
34
|
+
|
|
35
|
+
end = time()
|
|
36
|
+
logger.debug(f'embedded text of length {len(text)} in {end - start} ms')
|
|
37
|
+
|
|
38
|
+
return embedding
|