graphiti-core 0.2.3__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/PKG-INFO +8 -2
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/README.md +6 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/edges.py +68 -29
- graphiti_core-0.3.0/graphiti_core/errors.py +18 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/graphiti.py +18 -1
- graphiti_core-0.3.0/graphiti_core/llm_client/__init__.py +6 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/llm_client/anthropic_client.py +9 -1
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/llm_client/client.py +17 -10
- graphiti_core-0.3.0/graphiti_core/llm_client/errors.py +6 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/llm_client/groq_client.py +4 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/llm_client/openai_client.py +4 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/nodes.py +144 -20
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/prompts/extract_nodes.py +43 -1
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/prompts/lib.py +6 -0
- graphiti_core-0.3.0/graphiti_core/prompts/summarize_nodes.py +79 -0
- graphiti_core-0.3.0/graphiti_core/py.typed +1 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/search/search_utils.py +27 -22
- graphiti_core-0.3.0/graphiti_core/utils/maintenance/community_operations.py +155 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/utils/maintenance/edge_operations.py +20 -2
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/utils/maintenance/graph_data_operations.py +11 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/utils/maintenance/node_operations.py +26 -1
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/pyproject.toml +3 -3
- graphiti_core-0.2.3/graphiti_core/llm_client/__init__.py +0 -5
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/LICENSE +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/__init__.py +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/helpers.py +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/llm_client/config.py +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/llm_client/utils.py +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/prompts/__init__.py +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/prompts/dedupe_edges.py +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/prompts/dedupe_nodes.py +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/prompts/extract_edge_dates.py +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/prompts/extract_edges.py +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/prompts/invalidate_edges.py +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/prompts/models.py +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/search/__init__.py +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/search/search.py +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/utils/__init__.py +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/utils/bulk_utils.py +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/utils/maintenance/__init__.py +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/utils/maintenance/temporal_operations.py +0 -0
- {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/utils/maintenance/utils.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: graphiti-core
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.3.0
|
|
4
4
|
Summary: A temporal graph building library
|
|
5
5
|
License: Apache-2.0
|
|
6
6
|
Author: Paul Paliychuk
|
|
@@ -13,7 +13,7 @@ Classifier: Programming Language :: Python :: 3.11
|
|
|
13
13
|
Classifier: Programming Language :: Python :: 3.12
|
|
14
14
|
Requires-Dist: diskcache (>=5.6.3,<6.0.0)
|
|
15
15
|
Requires-Dist: neo4j (>=5.23.0,<6.0.0)
|
|
16
|
-
Requires-Dist: numpy (>=
|
|
16
|
+
Requires-Dist: numpy (>=1.0.0)
|
|
17
17
|
Requires-Dist: openai (>=1.38.0,<2.0.0)
|
|
18
18
|
Requires-Dist: pydantic (>=2.8.2,<3.0.0)
|
|
19
19
|
Requires-Dist: tenacity (<9.0.0)
|
|
@@ -170,6 +170,12 @@ await graphiti.search('Who was the California Attorney General?', center_node_uu
|
|
|
170
170
|
graphiti.close()
|
|
171
171
|
```
|
|
172
172
|
|
|
173
|
+
## Graph Service
|
|
174
|
+
|
|
175
|
+
The `server` directory contains an API service for interacting with the Graphiti API. It is built using FastAPI.
|
|
176
|
+
|
|
177
|
+
Please see the [server README](./server/README.md) for more information.
|
|
178
|
+
|
|
173
179
|
## Documentation
|
|
174
180
|
|
|
175
181
|
- [Guides and API documentation](https://help.getzep.com/graphiti).
|
|
@@ -149,6 +149,12 @@ await graphiti.search('Who was the California Attorney General?', center_node_uu
|
|
|
149
149
|
graphiti.close()
|
|
150
150
|
```
|
|
151
151
|
|
|
152
|
+
## Graph Service
|
|
153
|
+
|
|
154
|
+
The `server` directory contains an API service for interacting with the Graphiti API. It is built using FastAPI.
|
|
155
|
+
|
|
156
|
+
Please see the [server README](./server/README.md) for more information.
|
|
157
|
+
|
|
152
158
|
## Documentation
|
|
153
159
|
|
|
154
160
|
- [Guides and API documentation](https://help.getzep.com/graphiti).
|
|
@@ -24,6 +24,7 @@ from uuid import uuid4
|
|
|
24
24
|
from neo4j import AsyncDriver
|
|
25
25
|
from pydantic import BaseModel, Field
|
|
26
26
|
|
|
27
|
+
from graphiti_core.errors import EdgeNotFoundError
|
|
27
28
|
from graphiti_core.helpers import parse_db_date
|
|
28
29
|
from graphiti_core.llm_client.config import EMBEDDING_DIM
|
|
29
30
|
from graphiti_core.nodes import Node
|
|
@@ -41,8 +42,18 @@ class Edge(BaseModel, ABC):
|
|
|
41
42
|
@abstractmethod
|
|
42
43
|
async def save(self, driver: AsyncDriver): ...
|
|
43
44
|
|
|
44
|
-
|
|
45
|
-
|
|
45
|
+
async def delete(self, driver: AsyncDriver):
|
|
46
|
+
result = await driver.execute_query(
|
|
47
|
+
"""
|
|
48
|
+
MATCH (n)-[e {uuid: $uuid}]->(m)
|
|
49
|
+
DELETE e
|
|
50
|
+
""",
|
|
51
|
+
uuid=self.uuid,
|
|
52
|
+
)
|
|
53
|
+
|
|
54
|
+
logger.info(f'Deleted Edge: {self.uuid}')
|
|
55
|
+
|
|
56
|
+
return result
|
|
46
57
|
|
|
47
58
|
def __hash__(self):
|
|
48
59
|
return hash(self.uuid)
|
|
@@ -76,19 +87,6 @@ class EpisodicEdge(Edge):
|
|
|
76
87
|
|
|
77
88
|
return result
|
|
78
89
|
|
|
79
|
-
async def delete(self, driver: AsyncDriver):
|
|
80
|
-
result = await driver.execute_query(
|
|
81
|
-
"""
|
|
82
|
-
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
|
|
83
|
-
DELETE e
|
|
84
|
-
""",
|
|
85
|
-
uuid=self.uuid,
|
|
86
|
-
)
|
|
87
|
-
|
|
88
|
-
logger.info(f'Deleted Edge: {self.uuid}')
|
|
89
|
-
|
|
90
|
-
return result
|
|
91
|
-
|
|
92
90
|
@classmethod
|
|
93
91
|
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
|
94
92
|
records, _, _ = await driver.execute_query(
|
|
@@ -107,7 +105,8 @@ class EpisodicEdge(Edge):
|
|
|
107
105
|
edges = [get_episodic_edge_from_record(record) for record in records]
|
|
108
106
|
|
|
109
107
|
logger.info(f'Found Edge: {uuid}')
|
|
110
|
-
|
|
108
|
+
if len(edges) == 0:
|
|
109
|
+
raise EdgeNotFoundError(uuid)
|
|
111
110
|
return edges[0]
|
|
112
111
|
|
|
113
112
|
|
|
@@ -169,19 +168,6 @@ class EntityEdge(Edge):
|
|
|
169
168
|
|
|
170
169
|
return result
|
|
171
170
|
|
|
172
|
-
async def delete(self, driver: AsyncDriver):
|
|
173
|
-
result = await driver.execute_query(
|
|
174
|
-
"""
|
|
175
|
-
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
|
176
|
-
DELETE e
|
|
177
|
-
""",
|
|
178
|
-
uuid=self.uuid,
|
|
179
|
-
)
|
|
180
|
-
|
|
181
|
-
logger.info(f'Deleted Edge: {self.uuid}')
|
|
182
|
-
|
|
183
|
-
return result
|
|
184
|
-
|
|
185
171
|
@classmethod
|
|
186
172
|
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
|
187
173
|
records, _, _ = await driver.execute_query(
|
|
@@ -206,6 +192,49 @@ class EntityEdge(Edge):
|
|
|
206
192
|
|
|
207
193
|
edges = [get_entity_edge_from_record(record) for record in records]
|
|
208
194
|
|
|
195
|
+
logger.info(f'Found Edge: {uuid}')
|
|
196
|
+
if len(edges) == 0:
|
|
197
|
+
raise EdgeNotFoundError(uuid)
|
|
198
|
+
return edges[0]
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
class CommunityEdge(Edge):
|
|
202
|
+
async def save(self, driver: AsyncDriver):
|
|
203
|
+
result = await driver.execute_query(
|
|
204
|
+
"""
|
|
205
|
+
MATCH (community:Community {uuid: $community_uuid})
|
|
206
|
+
MATCH (node:Entity | Community {uuid: $entity_uuid})
|
|
207
|
+
MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
|
|
208
|
+
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
|
209
|
+
RETURN r.uuid AS uuid""",
|
|
210
|
+
community_uuid=self.source_node_uuid,
|
|
211
|
+
entity_uuid=self.target_node_uuid,
|
|
212
|
+
uuid=self.uuid,
|
|
213
|
+
group_id=self.group_id,
|
|
214
|
+
created_at=self.created_at,
|
|
215
|
+
)
|
|
216
|
+
|
|
217
|
+
logger.info(f'Saved edge to neo4j: {self.uuid}')
|
|
218
|
+
|
|
219
|
+
return result
|
|
220
|
+
|
|
221
|
+
@classmethod
|
|
222
|
+
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
|
223
|
+
records, _, _ = await driver.execute_query(
|
|
224
|
+
"""
|
|
225
|
+
MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m:Entity | Community)
|
|
226
|
+
RETURN
|
|
227
|
+
e.uuid As uuid,
|
|
228
|
+
e.group_id AS group_id,
|
|
229
|
+
n.uuid AS source_node_uuid,
|
|
230
|
+
m.uuid AS target_node_uuid,
|
|
231
|
+
e.created_at AS created_at
|
|
232
|
+
""",
|
|
233
|
+
uuid=uuid,
|
|
234
|
+
)
|
|
235
|
+
|
|
236
|
+
edges = [get_community_edge_from_record(record) for record in records]
|
|
237
|
+
|
|
209
238
|
logger.info(f'Found Edge: {uuid}')
|
|
210
239
|
|
|
211
240
|
return edges[0]
|
|
@@ -237,3 +266,13 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
|
|
|
237
266
|
valid_at=parse_db_date(record['valid_at']),
|
|
238
267
|
invalid_at=parse_db_date(record['invalid_at']),
|
|
239
268
|
)
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def get_community_edge_from_record(record: Any):
|
|
272
|
+
return CommunityEdge(
|
|
273
|
+
uuid=record['uuid'],
|
|
274
|
+
group_id=record['group_id'],
|
|
275
|
+
source_node_uuid=record['source_node_uuid'],
|
|
276
|
+
target_node_uuid=record['target_node_uuid'],
|
|
277
|
+
created_at=record['created_at'].to_native(),
|
|
278
|
+
)
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
class GraphitiError(Exception):
|
|
2
|
+
"""Base exception class for Graphiti Core."""
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class EdgeNotFoundError(GraphitiError):
|
|
6
|
+
"""Raised when an edge is not found."""
|
|
7
|
+
|
|
8
|
+
def __init__(self, uuid: str):
|
|
9
|
+
self.message = f'edge {uuid} not found'
|
|
10
|
+
super().__init__(self.message)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class NodeNotFoundError(GraphitiError):
|
|
14
|
+
"""Raised when a node is not found."""
|
|
15
|
+
|
|
16
|
+
def __init__(self, uuid: str):
|
|
17
|
+
self.message = f'node {uuid} not found'
|
|
18
|
+
super().__init__(self.message)
|
|
@@ -46,6 +46,10 @@ from graphiti_core.utils.bulk_utils import (
|
|
|
46
46
|
resolve_edge_pointers,
|
|
47
47
|
retrieve_previous_episodes_bulk,
|
|
48
48
|
)
|
|
49
|
+
from graphiti_core.utils.maintenance.community_operations import (
|
|
50
|
+
build_communities,
|
|
51
|
+
remove_communities,
|
|
52
|
+
)
|
|
49
53
|
from graphiti_core.utils.maintenance.edge_operations import (
|
|
50
54
|
extract_edges,
|
|
51
55
|
resolve_extracted_edges,
|
|
@@ -412,7 +416,7 @@ class Graphiti:
|
|
|
412
416
|
except Exception as e:
|
|
413
417
|
raise e
|
|
414
418
|
|
|
415
|
-
async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str | None):
|
|
419
|
+
async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str | None = None):
|
|
416
420
|
"""
|
|
417
421
|
Process multiple episodes in bulk and update the graph.
|
|
418
422
|
|
|
@@ -526,6 +530,19 @@ class Graphiti:
|
|
|
526
530
|
except Exception as e:
|
|
527
531
|
raise e
|
|
528
532
|
|
|
533
|
+
async def build_communities(self):
|
|
534
|
+
embedder = self.llm_client.get_embedder()
|
|
535
|
+
|
|
536
|
+
# Clear existing communities
|
|
537
|
+
await remove_communities(self.driver)
|
|
538
|
+
|
|
539
|
+
community_nodes, community_edges = await build_communities(self.driver, self.llm_client)
|
|
540
|
+
|
|
541
|
+
await asyncio.gather(*[node.generate_name_embedding(embedder) for node in community_nodes])
|
|
542
|
+
|
|
543
|
+
await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
|
|
544
|
+
await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
|
|
545
|
+
|
|
529
546
|
async def search(
|
|
530
547
|
self,
|
|
531
548
|
query: str,
|
|
@@ -18,12 +18,14 @@ import json
|
|
|
18
18
|
import logging
|
|
19
19
|
import typing
|
|
20
20
|
|
|
21
|
+
import anthropic
|
|
21
22
|
from anthropic import AsyncAnthropic
|
|
22
23
|
from openai import AsyncOpenAI
|
|
23
24
|
|
|
24
25
|
from ..prompts.models import Message
|
|
25
26
|
from .client import LLMClient
|
|
26
27
|
from .config import LLMConfig
|
|
28
|
+
from .errors import RateLimitError
|
|
27
29
|
|
|
28
30
|
logger = logging.getLogger(__name__)
|
|
29
31
|
|
|
@@ -35,7 +37,11 @@ class AnthropicClient(LLMClient):
|
|
|
35
37
|
if config is None:
|
|
36
38
|
config = LLMConfig()
|
|
37
39
|
super().__init__(config, cache)
|
|
38
|
-
self.client = AsyncAnthropic(
|
|
40
|
+
self.client = AsyncAnthropic(
|
|
41
|
+
api_key=config.api_key,
|
|
42
|
+
# we'll use tenacity to retry
|
|
43
|
+
max_retries=1,
|
|
44
|
+
)
|
|
39
45
|
|
|
40
46
|
def get_embedder(self) -> typing.Any:
|
|
41
47
|
openai_client = AsyncOpenAI()
|
|
@@ -58,6 +64,8 @@ class AnthropicClient(LLMClient):
|
|
|
58
64
|
)
|
|
59
65
|
|
|
60
66
|
return json.loads('{' + result.content[0].text) # type: ignore
|
|
67
|
+
except anthropic.RateLimitError as e:
|
|
68
|
+
raise RateLimitError from e
|
|
61
69
|
except Exception as e:
|
|
62
70
|
logger.error(f'Error in generating LLM response: {e}')
|
|
63
71
|
raise
|
|
@@ -22,10 +22,11 @@ from abc import ABC, abstractmethod
|
|
|
22
22
|
|
|
23
23
|
import httpx
|
|
24
24
|
from diskcache import Cache
|
|
25
|
-
from tenacity import retry, retry_if_exception, stop_after_attempt,
|
|
25
|
+
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential
|
|
26
26
|
|
|
27
27
|
from ..prompts.models import Message
|
|
28
28
|
from .config import LLMConfig
|
|
29
|
+
from .errors import RateLimitError
|
|
29
30
|
|
|
30
31
|
DEFAULT_TEMPERATURE = 0
|
|
31
32
|
DEFAULT_CACHE_DIR = './llm_cache'
|
|
@@ -33,7 +34,10 @@ DEFAULT_CACHE_DIR = './llm_cache'
|
|
|
33
34
|
logger = logging.getLogger(__name__)
|
|
34
35
|
|
|
35
36
|
|
|
36
|
-
def
|
|
37
|
+
def is_server_or_retry_error(exception):
|
|
38
|
+
if isinstance(exception, RateLimitError):
|
|
39
|
+
return True
|
|
40
|
+
|
|
37
41
|
return (
|
|
38
42
|
isinstance(exception, httpx.HTTPStatusError) and 500 <= exception.response.status_code < 600
|
|
39
43
|
)
|
|
@@ -56,18 +60,21 @@ class LLMClient(ABC):
|
|
|
56
60
|
pass
|
|
57
61
|
|
|
58
62
|
@retry(
|
|
59
|
-
stop=stop_after_attempt(
|
|
60
|
-
wait=
|
|
61
|
-
retry=retry_if_exception(
|
|
63
|
+
stop=stop_after_attempt(4),
|
|
64
|
+
wait=wait_random_exponential(multiplier=10, min=5, max=120),
|
|
65
|
+
retry=retry_if_exception(is_server_or_retry_error),
|
|
66
|
+
after=lambda retry_state: logger.warning(
|
|
67
|
+
f'Retrying {retry_state.fn.__name__ if retry_state.fn else "function"} after {retry_state.attempt_number} attempts...'
|
|
68
|
+
)
|
|
69
|
+
if retry_state.attempt_number > 1
|
|
70
|
+
else None,
|
|
71
|
+
reraise=True,
|
|
62
72
|
)
|
|
63
73
|
async def _generate_response_with_retry(self, messages: list[Message]) -> dict[str, typing.Any]:
|
|
64
74
|
try:
|
|
65
75
|
return await self._generate_response(messages)
|
|
66
|
-
except httpx.HTTPStatusError as e:
|
|
67
|
-
|
|
68
|
-
raise Exception(f'LLM request error: {e}') from e
|
|
69
|
-
else:
|
|
70
|
-
raise
|
|
76
|
+
except (httpx.HTTPStatusError, RateLimitError) as e:
|
|
77
|
+
raise e
|
|
71
78
|
|
|
72
79
|
@abstractmethod
|
|
73
80
|
async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
|
|
@@ -18,6 +18,7 @@ import json
|
|
|
18
18
|
import logging
|
|
19
19
|
import typing
|
|
20
20
|
|
|
21
|
+
import groq
|
|
21
22
|
from groq import AsyncGroq
|
|
22
23
|
from groq.types.chat import ChatCompletionMessageParam
|
|
23
24
|
from openai import AsyncOpenAI
|
|
@@ -25,6 +26,7 @@ from openai import AsyncOpenAI
|
|
|
25
26
|
from ..prompts.models import Message
|
|
26
27
|
from .client import LLMClient
|
|
27
28
|
from .config import LLMConfig
|
|
29
|
+
from .errors import RateLimitError
|
|
28
30
|
|
|
29
31
|
logger = logging.getLogger(__name__)
|
|
30
32
|
|
|
@@ -59,6 +61,8 @@ class GroqClient(LLMClient):
|
|
|
59
61
|
)
|
|
60
62
|
result = response.choices[0].message.content or ''
|
|
61
63
|
return json.loads(result)
|
|
64
|
+
except groq.RateLimitError as e:
|
|
65
|
+
raise RateLimitError from e
|
|
62
66
|
except Exception as e:
|
|
63
67
|
logger.error(f'Error in generating LLM response: {e}')
|
|
64
68
|
raise
|
|
@@ -18,12 +18,14 @@ import json
|
|
|
18
18
|
import logging
|
|
19
19
|
import typing
|
|
20
20
|
|
|
21
|
+
import openai
|
|
21
22
|
from openai import AsyncOpenAI
|
|
22
23
|
from openai.types.chat import ChatCompletionMessageParam
|
|
23
24
|
|
|
24
25
|
from ..prompts.models import Message
|
|
25
26
|
from .client import LLMClient
|
|
26
27
|
from .config import LLMConfig
|
|
28
|
+
from .errors import RateLimitError
|
|
27
29
|
|
|
28
30
|
logger = logging.getLogger(__name__)
|
|
29
31
|
|
|
@@ -59,6 +61,8 @@ class OpenAIClient(LLMClient):
|
|
|
59
61
|
)
|
|
60
62
|
result = response.choices[0].message.content or ''
|
|
61
63
|
return json.loads(result)
|
|
64
|
+
except openai.RateLimitError as e:
|
|
65
|
+
raise RateLimitError from e
|
|
62
66
|
except Exception as e:
|
|
63
67
|
logger.error(f'Error in generating LLM response: {e}')
|
|
64
68
|
raise
|
|
@@ -25,6 +25,7 @@ from uuid import uuid4
|
|
|
25
25
|
from neo4j import AsyncDriver
|
|
26
26
|
from pydantic import BaseModel, Field
|
|
27
27
|
|
|
28
|
+
from graphiti_core.errors import NodeNotFoundError
|
|
28
29
|
from graphiti_core.llm_client.config import EMBEDDING_DIM
|
|
29
30
|
|
|
30
31
|
logger = logging.getLogger(__name__)
|
|
@@ -76,8 +77,18 @@ class Node(BaseModel, ABC):
|
|
|
76
77
|
@abstractmethod
|
|
77
78
|
async def save(self, driver: AsyncDriver): ...
|
|
78
79
|
|
|
79
|
-
|
|
80
|
-
|
|
80
|
+
async def delete(self, driver: AsyncDriver):
|
|
81
|
+
result = await driver.execute_query(
|
|
82
|
+
"""
|
|
83
|
+
MATCH (n {uuid: $uuid})
|
|
84
|
+
DETACH DELETE n
|
|
85
|
+
""",
|
|
86
|
+
uuid=self.uuid,
|
|
87
|
+
)
|
|
88
|
+
|
|
89
|
+
logger.info(f'Deleted Node: {self.uuid}')
|
|
90
|
+
|
|
91
|
+
return result
|
|
81
92
|
|
|
82
93
|
def __hash__(self):
|
|
83
94
|
return hash(self.uuid)
|
|
@@ -90,6 +101,9 @@ class Node(BaseModel, ABC):
|
|
|
90
101
|
@classmethod
|
|
91
102
|
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ...
|
|
92
103
|
|
|
104
|
+
@classmethod
|
|
105
|
+
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]): ...
|
|
106
|
+
|
|
93
107
|
|
|
94
108
|
class EpisodicNode(Node):
|
|
95
109
|
source: EpisodeType = Field(description='source type')
|
|
@@ -125,24 +139,37 @@ class EpisodicNode(Node):
|
|
|
125
139
|
|
|
126
140
|
return result
|
|
127
141
|
|
|
128
|
-
|
|
129
|
-
|
|
142
|
+
@classmethod
|
|
143
|
+
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
|
144
|
+
records, _, _ = await driver.execute_query(
|
|
130
145
|
"""
|
|
131
|
-
MATCH (
|
|
132
|
-
|
|
146
|
+
MATCH (e:Episodic {uuid: $uuid})
|
|
147
|
+
RETURN e.content AS content,
|
|
148
|
+
e.created_at AS created_at,
|
|
149
|
+
e.valid_at AS valid_at,
|
|
150
|
+
e.uuid AS uuid,
|
|
151
|
+
e.name AS name,
|
|
152
|
+
e.group_id AS group_id,
|
|
153
|
+
e.source_description AS source_description,
|
|
154
|
+
e.source AS source
|
|
133
155
|
""",
|
|
134
|
-
uuid=
|
|
156
|
+
uuid=uuid,
|
|
135
157
|
)
|
|
136
158
|
|
|
137
|
-
|
|
159
|
+
episodes = [get_episodic_node_from_record(record) for record in records]
|
|
138
160
|
|
|
139
|
-
|
|
161
|
+
logger.info(f'Found Node: {uuid}')
|
|
162
|
+
|
|
163
|
+
if len(episodes) == 0:
|
|
164
|
+
raise NodeNotFoundError(uuid)
|
|
165
|
+
|
|
166
|
+
return episodes[0]
|
|
140
167
|
|
|
141
168
|
@classmethod
|
|
142
|
-
async def
|
|
169
|
+
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
|
|
143
170
|
records, _, _ = await driver.execute_query(
|
|
144
171
|
"""
|
|
145
|
-
MATCH (e:Episodic
|
|
172
|
+
MATCH (e:Episodic) WHERE e.uuid IN $uuids
|
|
146
173
|
RETURN e.content AS content,
|
|
147
174
|
e.created_at AS created_at,
|
|
148
175
|
e.valid_at AS valid_at,
|
|
@@ -152,14 +179,14 @@ class EpisodicNode(Node):
|
|
|
152
179
|
e.source_description AS source_description,
|
|
153
180
|
e.source AS source
|
|
154
181
|
""",
|
|
155
|
-
|
|
182
|
+
uuids=uuids,
|
|
156
183
|
)
|
|
157
184
|
|
|
158
185
|
episodes = [get_episodic_node_from_record(record) for record in records]
|
|
159
186
|
|
|
160
|
-
logger.info(f'Found
|
|
187
|
+
logger.info(f'Found Nodes: {uuids}')
|
|
161
188
|
|
|
162
|
-
return episodes
|
|
189
|
+
return episodes
|
|
163
190
|
|
|
164
191
|
|
|
165
192
|
class EntityNode(Node):
|
|
@@ -194,24 +221,88 @@ class EntityNode(Node):
|
|
|
194
221
|
|
|
195
222
|
return result
|
|
196
223
|
|
|
197
|
-
|
|
198
|
-
|
|
224
|
+
@classmethod
|
|
225
|
+
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
|
226
|
+
records, _, _ = await driver.execute_query(
|
|
199
227
|
"""
|
|
200
228
|
MATCH (n:Entity {uuid: $uuid})
|
|
201
|
-
|
|
229
|
+
RETURN
|
|
230
|
+
n.uuid As uuid,
|
|
231
|
+
n.name AS name,
|
|
232
|
+
n.name_embedding AS name_embedding,
|
|
233
|
+
n.group_id AS group_id
|
|
234
|
+
n.created_at AS created_at,
|
|
235
|
+
n.summary AS summary
|
|
236
|
+
""",
|
|
237
|
+
uuid=uuid,
|
|
238
|
+
)
|
|
239
|
+
|
|
240
|
+
nodes = [get_entity_node_from_record(record) for record in records]
|
|
241
|
+
|
|
242
|
+
logger.info(f'Found Node: {uuid}')
|
|
243
|
+
|
|
244
|
+
return nodes[0]
|
|
245
|
+
|
|
246
|
+
@classmethod
|
|
247
|
+
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
|
|
248
|
+
records, _, _ = await driver.execute_query(
|
|
249
|
+
"""
|
|
250
|
+
MATCH (n:Entity) WHERE n.uuid IN $uuids
|
|
251
|
+
RETURN
|
|
252
|
+
n.uuid As uuid,
|
|
253
|
+
n.name AS name,
|
|
254
|
+
n.name_embedding AS name_embedding,
|
|
255
|
+
n.group_id AS group_id,
|
|
256
|
+
n.created_at AS created_at,
|
|
257
|
+
n.summary AS summary
|
|
202
258
|
""",
|
|
259
|
+
uuids=uuids,
|
|
260
|
+
)
|
|
261
|
+
|
|
262
|
+
nodes = [get_entity_node_from_record(record) for record in records]
|
|
263
|
+
|
|
264
|
+
logger.info(f'Found Nodes: {uuids}')
|
|
265
|
+
|
|
266
|
+
return nodes
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
class CommunityNode(Node):
|
|
270
|
+
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
|
|
271
|
+
summary: str = Field(description='region summary of member nodes', default_factory=str)
|
|
272
|
+
|
|
273
|
+
async def save(self, driver: AsyncDriver):
|
|
274
|
+
result = await driver.execute_query(
|
|
275
|
+
"""
|
|
276
|
+
MERGE (n:Community {uuid: $uuid})
|
|
277
|
+
SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at}
|
|
278
|
+
RETURN n.uuid AS uuid""",
|
|
203
279
|
uuid=self.uuid,
|
|
280
|
+
name=self.name,
|
|
281
|
+
group_id=self.group_id,
|
|
282
|
+
summary=self.summary,
|
|
283
|
+
name_embedding=self.name_embedding,
|
|
284
|
+
created_at=self.created_at,
|
|
204
285
|
)
|
|
205
286
|
|
|
206
|
-
logger.info(f'
|
|
287
|
+
logger.info(f'Saved Node to neo4j: {self.uuid}')
|
|
207
288
|
|
|
208
289
|
return result
|
|
209
290
|
|
|
291
|
+
async def generate_name_embedding(self, embedder, model='text-embedding-3-small'):
|
|
292
|
+
start = time()
|
|
293
|
+
text = self.name.replace('\n', ' ')
|
|
294
|
+
embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
|
|
295
|
+
self.name_embedding = embedding[:EMBEDDING_DIM]
|
|
296
|
+
end = time()
|
|
297
|
+
logger.info(f'embedded {text} in {end - start} ms')
|
|
298
|
+
|
|
299
|
+
return embedding
|
|
300
|
+
|
|
210
301
|
@classmethod
|
|
211
302
|
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
|
212
303
|
records, _, _ = await driver.execute_query(
|
|
213
304
|
"""
|
|
214
|
-
MATCH (n:
|
|
305
|
+
MATCH (n:Community {uuid: $uuid})
|
|
215
306
|
RETURN
|
|
216
307
|
n.uuid As uuid,
|
|
217
308
|
n.name AS name,
|
|
@@ -223,12 +314,34 @@ class EntityNode(Node):
|
|
|
223
314
|
uuid=uuid,
|
|
224
315
|
)
|
|
225
316
|
|
|
226
|
-
nodes = [
|
|
317
|
+
nodes = [get_community_node_from_record(record) for record in records]
|
|
227
318
|
|
|
228
319
|
logger.info(f'Found Node: {uuid}')
|
|
229
320
|
|
|
230
321
|
return nodes[0]
|
|
231
322
|
|
|
323
|
+
@classmethod
|
|
324
|
+
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
|
|
325
|
+
records, _, _ = await driver.execute_query(
|
|
326
|
+
"""
|
|
327
|
+
MATCH (n:Community) WHERE n.uuid IN $uuids
|
|
328
|
+
RETURN
|
|
329
|
+
n.uuid As uuid,
|
|
330
|
+
n.name AS name,
|
|
331
|
+
n.name_embedding AS name_embedding,
|
|
332
|
+
n.group_id AS group_id
|
|
333
|
+
n.created_at AS created_at,
|
|
334
|
+
n.summary AS summary
|
|
335
|
+
""",
|
|
336
|
+
uuids=uuids,
|
|
337
|
+
)
|
|
338
|
+
|
|
339
|
+
nodes = [get_community_node_from_record(record) for record in records]
|
|
340
|
+
|
|
341
|
+
logger.info(f'Found Nodes: {uuids}')
|
|
342
|
+
|
|
343
|
+
return nodes
|
|
344
|
+
|
|
232
345
|
|
|
233
346
|
# Node helpers
|
|
234
347
|
def get_episodic_node_from_record(record: Any) -> EpisodicNode:
|
|
@@ -254,3 +367,14 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
|
|
|
254
367
|
created_at=record['created_at'].to_native(),
|
|
255
368
|
summary=record['summary'],
|
|
256
369
|
)
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
def get_community_node_from_record(record: Any) -> CommunityNode:
|
|
373
|
+
return CommunityNode(
|
|
374
|
+
uuid=record['uuid'],
|
|
375
|
+
name=record['name'],
|
|
376
|
+
group_id=record['group_id'],
|
|
377
|
+
name_embedding=record['name_embedding'],
|
|
378
|
+
created_at=record['created_at'].to_native(),
|
|
379
|
+
summary=record['summary'],
|
|
380
|
+
)
|
|
@@ -24,12 +24,14 @@ class Prompt(Protocol):
|
|
|
24
24
|
v1: PromptVersion
|
|
25
25
|
v2: PromptVersion
|
|
26
26
|
extract_json: PromptVersion
|
|
27
|
+
extract_text: PromptVersion
|
|
27
28
|
|
|
28
29
|
|
|
29
30
|
class Versions(TypedDict):
|
|
30
31
|
v1: PromptFunction
|
|
31
32
|
v2: PromptFunction
|
|
32
33
|
extract_json: PromptFunction
|
|
34
|
+
extract_text: PromptFunction
|
|
33
35
|
|
|
34
36
|
|
|
35
37
|
def v1(context: dict[str, Any]) -> list[Message]:
|
|
@@ -144,4 +146,44 @@ Respond with a JSON object in the following format:
|
|
|
144
146
|
]
|
|
145
147
|
|
|
146
148
|
|
|
147
|
-
|
|
149
|
+
def extract_text(context: dict[str, Any]) -> list[Message]:
|
|
150
|
+
sys_prompt = """You are an AI assistant that extracts entity nodes from conversational text. Your primary task is to identify and extract the speaker and other significant entities mentioned in the conversation."""
|
|
151
|
+
|
|
152
|
+
user_prompt = f"""
|
|
153
|
+
Given the following conversation, extract entity nodes from the CURRENT MESSAGE that are explicitly or implicitly mentioned:
|
|
154
|
+
|
|
155
|
+
Conversation:
|
|
156
|
+
{json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)}
|
|
157
|
+
<CURRENT MESSAGE>
|
|
158
|
+
{context["episode_content"]}
|
|
159
|
+
|
|
160
|
+
Guidelines:
|
|
161
|
+
2. Extract significant entities, concepts, or actors mentioned in the conversation.
|
|
162
|
+
3. Provide concise but informative summaries for each extracted node.
|
|
163
|
+
4. Avoid creating nodes for relationships or actions.
|
|
164
|
+
5. Avoid creating nodes for temporal information like dates, times or years (these will be added to edges later).
|
|
165
|
+
6. Be as explicit as possible in your node names, using full names and avoiding abbreviations.
|
|
166
|
+
|
|
167
|
+
Respond with a JSON object in the following format:
|
|
168
|
+
{{
|
|
169
|
+
"extracted_nodes": [
|
|
170
|
+
{{
|
|
171
|
+
"name": "Unique identifier for the node (use the speaker's name for speaker nodes)",
|
|
172
|
+
"labels": ["Entity", "OptionalAdditionalLabel"],
|
|
173
|
+
"summary": "Brief summary of the node's role or significance"
|
|
174
|
+
}}
|
|
175
|
+
]
|
|
176
|
+
}}
|
|
177
|
+
"""
|
|
178
|
+
return [
|
|
179
|
+
Message(role='system', content=sys_prompt),
|
|
180
|
+
Message(role='user', content=user_prompt),
|
|
181
|
+
]
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
versions: Versions = {
|
|
185
|
+
'v1': v1,
|
|
186
|
+
'v2': v2,
|
|
187
|
+
'extract_json': extract_json,
|
|
188
|
+
'extract_text': extract_text,
|
|
189
|
+
}
|
|
@@ -71,6 +71,9 @@ from .invalidate_edges import (
|
|
|
71
71
|
versions as invalidate_edges_versions,
|
|
72
72
|
)
|
|
73
73
|
from .models import Message, PromptFunction
|
|
74
|
+
from .summarize_nodes import Prompt as SummarizeNodesPrompt
|
|
75
|
+
from .summarize_nodes import Versions as SummarizeNodesVersions
|
|
76
|
+
from .summarize_nodes import versions as summarize_nodes_versions
|
|
74
77
|
|
|
75
78
|
|
|
76
79
|
class PromptLibrary(Protocol):
|
|
@@ -80,6 +83,7 @@ class PromptLibrary(Protocol):
|
|
|
80
83
|
dedupe_edges: DedupeEdgesPrompt
|
|
81
84
|
invalidate_edges: InvalidateEdgesPrompt
|
|
82
85
|
extract_edge_dates: ExtractEdgeDatesPrompt
|
|
86
|
+
summarize_nodes: SummarizeNodesPrompt
|
|
83
87
|
|
|
84
88
|
|
|
85
89
|
class PromptLibraryImpl(TypedDict):
|
|
@@ -89,6 +93,7 @@ class PromptLibraryImpl(TypedDict):
|
|
|
89
93
|
dedupe_edges: DedupeEdgesVersions
|
|
90
94
|
invalidate_edges: InvalidateEdgesVersions
|
|
91
95
|
extract_edge_dates: ExtractEdgeDatesVersions
|
|
96
|
+
summarize_nodes: SummarizeNodesVersions
|
|
92
97
|
|
|
93
98
|
|
|
94
99
|
class VersionWrapper:
|
|
@@ -118,5 +123,6 @@ PROMPT_LIBRARY_IMPL: PromptLibraryImpl = {
|
|
|
118
123
|
'dedupe_edges': dedupe_edges_versions,
|
|
119
124
|
'invalidate_edges': invalidate_edges_versions,
|
|
120
125
|
'extract_edge_dates': extract_edge_dates_versions,
|
|
126
|
+
'summarize_nodes': summarize_nodes_versions,
|
|
121
127
|
}
|
|
122
128
|
prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL) # type: ignore[assignment]
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import json
|
|
18
|
+
from typing import Any, Protocol, TypedDict
|
|
19
|
+
|
|
20
|
+
from .models import Message, PromptFunction, PromptVersion
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class Prompt(Protocol):
    """Structural type listing the node-summarization prompts by name."""

    # Prompt used to merge two summaries into one.
    summarize_pair: PromptVersion
    # Prompt used to describe a summary in a single sentence.
    summary_description: PromptVersion
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class Versions(TypedDict):
    """Mapping from prompt name to the function that builds its messages."""

    # Builder for the "merge two summaries" prompt.
    summarize_pair: PromptFunction
    # Builder for the "describe this summary" prompt.
    summary_description: PromptFunction
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def summarize_pair(context: dict[str, Any]) -> list[Message]:
    """Build the chat messages asking the model to merge two summaries.

    ``context['node_summaries']`` is serialized with ``json.dumps`` and embedded
    in the user message. The model is instructed to answer with a JSON object
    holding a single ``summary`` key.
    """
    system_message = Message(
        role='system',
        content='You are a helpful assistant that combines summaries.',
    )

    # Serialize once up front so the template below stays readable.
    serialized_summaries = json.dumps(context['node_summaries'], indent=2)
    user_message = Message(
        role='user',
        content=f"""
        Synthesize the information from the following two summaries into a single succinct summary.

        Summaries:
        {serialized_summaries}

        Respond with a JSON object in the following format:
            {{
                "summary": "Summary containing the important information from both summaries"
            }}
        """,
    )

    return [system_message, user_message]
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def summary_description(context: dict[str, Any]) -> list[Message]:
    """Build the chat messages asking for a one-sentence description of a summary.

    ``context['summary']`` is serialized with ``json.dumps`` and embedded in the
    user message. The model is instructed to answer with a JSON object holding
    a single ``description`` key.
    """
    system_message = Message(
        role='system',
        content='You are a helpful assistant that describes provided contents in a single sentence.',
    )

    # Serialize once up front so the template below stays readable.
    serialized_summary = json.dumps(context['summary'], indent=2)
    user_message = Message(
        role='user',
        content=f"""
        Create a short one sentence description of the summary that explains what kind of information is summarized.

        Summary:
        {serialized_summary}

        Respond with a JSON object in the following format:
            {{
                "description": "One sentence description of the provided summary"
            }}
        """,
    )

    return [system_message, user_message]
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
# Registry of the prompt builders defined in this module, keyed by prompt name.
versions: Versions = {
    'summarize_pair': summarize_pair,
    'summary_description': summary_description,
}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# This file is intentionally left empty to indicate that the package is typed.
|
|
@@ -496,34 +496,39 @@ async def node_distance_reranker(
|
|
|
496
496
|
sorted_uuids = rrf(results)
|
|
497
497
|
scores: dict[str, float] = {}
|
|
498
498
|
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
records, _, _ = await driver.execute_query(
|
|
502
|
-
"""
|
|
499
|
+
# Find the shortest path to center node
|
|
500
|
+
query = Query("""
|
|
503
501
|
MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity)
|
|
504
|
-
MATCH p = SHORTEST 1 (center:Entity)-[:RELATES_TO
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
""",
|
|
508
|
-
edge_uuid=uuid,
|
|
509
|
-
center_uuid=center_node_uuid,
|
|
510
|
-
)
|
|
511
|
-
distance = 0.01
|
|
502
|
+
MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: source.uuid})
|
|
503
|
+
RETURN length(p) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid
|
|
504
|
+
""")
|
|
512
505
|
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
506
|
+
path_results = await asyncio.gather(
|
|
507
|
+
*[
|
|
508
|
+
driver.execute_query(
|
|
509
|
+
query,
|
|
510
|
+
edge_uuid=uuid,
|
|
511
|
+
center_uuid=center_node_uuid,
|
|
512
|
+
)
|
|
513
|
+
for uuid in sorted_uuids
|
|
514
|
+
]
|
|
515
|
+
)
|
|
516
|
+
|
|
517
|
+
for uuid, result in zip(sorted_uuids, path_results):
|
|
518
|
+
records = result[0]
|
|
519
|
+
record = records[0] if len(records) > 0 else None
|
|
520
|
+
distance: float = record['score'] if record is not None else float('inf')
|
|
521
|
+
if record is not None and (
|
|
522
|
+
record['source_uuid'] == center_node_uuid or record['target_uuid'] == center_node_uuid
|
|
523
|
+
):
|
|
524
|
+
distance = 0
|
|
520
525
|
|
|
521
526
|
if uuid in scores:
|
|
522
|
-
scores[uuid] = min(
|
|
527
|
+
scores[uuid] = min(distance, scores[uuid])
|
|
523
528
|
else:
|
|
524
|
-
scores[uuid] =
|
|
529
|
+
scores[uuid] = distance
|
|
525
530
|
|
|
526
531
|
# rerank on shortest distance
|
|
527
|
-
sorted_uuids.sort(
|
|
532
|
+
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
|
528
533
|
|
|
529
534
|
return sorted_uuids
|
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import logging
|
|
3
|
+
from collections import defaultdict
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
|
|
6
|
+
from neo4j import AsyncDriver
|
|
7
|
+
|
|
8
|
+
from graphiti_core.edges import CommunityEdge
|
|
9
|
+
from graphiti_core.llm_client import LLMClient
|
|
10
|
+
from graphiti_core.nodes import CommunityNode, EntityNode
|
|
11
|
+
from graphiti_core.prompts import prompt_library
|
|
12
|
+
from graphiti_core.utils.maintenance.edge_operations import build_community_edges
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
async def build_community_projection(driver: AsyncDriver) -> str:
    """Create the GDS graph projection used for community detection.

    Projects ``Entity`` nodes and their ``RELATES_TO`` relationships
    (undirected, with a COUNT-aggregated ``weight`` property) under the name
    "communities" and returns the graph name reported by the procedure.
    """
    projection_query = """
    CALL gds.graph.project("communities", "Entity",
        {RELATES_TO: {
            type: "RELATES_TO",
            orientation: "UNDIRECTED",
            properties: {weight: {property: "*", aggregation: "COUNT"}}
        }}
    )
    YIELD graphName AS graph, nodeProjection AS nodes, relationshipProjection AS edges
    """
    rows, _, _ = await driver.execute_query(projection_query)

    first_row = rows[0]
    return first_row['graph']
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
async def destroy_projection(driver: AsyncDriver, projection_name: str):
    """Drop the GDS graph projection called ``projection_name``."""
    drop_query = """
    CALL gds.graph.drop($projection_name)
    """
    await driver.execute_query(drop_query, projection_name=projection_name)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
async def get_community_clusters(
    driver: AsyncDriver, projection_name: str
) -> list[list[EntityNode]]:
    """Run Leiden community detection and load the entities of each community.

    Args:
        driver: Neo4j async driver used for the GDS call and the node lookups.
        projection_name: Name of the GDS graph projection to cluster.

    Returns:
        One list of ``EntityNode`` objects per detected community.
    """
    # Pass the projection name as a query parameter instead of hard-coding
    # "communities", so the function honours the projection it was asked to use.
    records, _, _ = await driver.execute_query(
        """
    CALL gds.leiden.stream($projection_name)
    YIELD nodeId, communityId
    RETURN gds.util.asNode(nodeId).uuid AS entity_uuid, communityId
    """,
        projection_name=projection_name,
    )

    # Group entity uuids by the community id Leiden assigned to them.
    community_map: dict[int, list[str]] = defaultdict(list)
    for record in records:
        community_map[record['communityId']].append(record['entity_uuid'])

    # Fetch the member nodes of every community concurrently.
    community_clusters: list[list[EntityNode]] = list(
        await asyncio.gather(
            *[EntityNode.get_by_uuids(driver, cluster) for cluster in community_map.values()]
        )
    )

    return community_clusters
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
async def summarize_pair(llm_client: LLMClient, summary_pair: tuple[str, str]) -> str:
    """Ask the LLM to merge two summaries and return the combined text.

    Returns the ``summary`` field of the LLM response, or an empty string when
    the response has no such field.
    """
    # Wrap each summary in the {'summary': ...} shape the prompt serializes.
    node_summaries = [{'summary': text} for text in summary_pair]
    messages = prompt_library.summarize_nodes.summarize_pair({'node_summaries': node_summaries})

    response = await llm_client.generate_response(messages)
    return response.get('summary', '')
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
async def generate_summary_description(llm_client: LLMClient, summary: str) -> str:
    """Ask the LLM for a one-sentence description of ``summary``.

    Returns the ``description`` field of the LLM response, or an empty string
    when the response has no such field.
    """
    messages = prompt_library.summarize_nodes.summary_description({'summary': summary})

    response = await llm_client.generate_response(messages)
    return response.get('description', '')
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
async def build_community(
    llm_client: LLMClient, community_cluster: list[EntityNode]
) -> tuple[CommunityNode, list[CommunityEdge]]:
    """Create the community node and membership edges for one entity cluster.

    The member summaries are merged pairwise through the LLM, round after
    round, until a single summary remains. That summary is stored on the
    community node, and an LLM-generated one-sentence description of it is
    used as the node name.

    Args:
        llm_client: Client used for the summarization prompts.
        community_cluster: Entities that belong to the community.

    Returns:
        The new ``CommunityNode`` and one ``CommunityEdge`` per member entity.

    Raises:
        ValueError: If ``community_cluster`` is empty.
    """
    # Fail with a clear message rather than an opaque IndexError further down.
    if not community_cluster:
        raise ValueError('community_cluster must contain at least one entity')

    summaries = [entity.summary for entity in community_cluster]
    while len(summaries) > 1:
        # With an odd count, hold one summary back and carry it into the next round.
        odd_one_out: str | None = None
        if len(summaries) % 2 == 1:
            odd_one_out = summaries.pop()

        # Pair the first half with the second half and merge each pair concurrently.
        half = len(summaries) // 2
        new_summaries: list[str] = list(
            await asyncio.gather(
                *[
                    summarize_pair(llm_client, (str(left_summary), str(right_summary)))
                    for left_summary, right_summary in zip(summaries[:half], summaries[half:])
                ]
            )
        )
        if odd_one_out is not None:
            new_summaries.append(odd_one_out)
        summaries = new_summaries

    summary = summaries[0]
    name = await generate_summary_description(llm_client, summary)
    now = datetime.now()
    community_node = CommunityNode(
        name=name,
        # The community takes the group id of the first member entity.
        group_id=community_cluster[0].group_id,
        labels=['Community'],
        created_at=now,
        summary=summary,
    )
    community_edges = build_community_edges(community_cluster, community_node, now)

    logger.info((community_node, community_edges))

    return community_node, community_edges
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
async def build_communities(
    driver: AsyncDriver, llm_client: LLMClient
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
    """Detect communities in the graph and build their nodes and edges.

    Creates a temporary GDS projection, clusters it, builds one community per
    cluster concurrently, and drops the projection again.

    Args:
        driver: Neo4j async driver.
        llm_client: Client used to summarize each community.

    Returns:
        All community nodes, and all community membership edges flattened
        into a single list.
    """
    projection = await build_community_projection(driver)
    try:
        community_clusters = await get_community_clusters(driver, projection)

        communities: list[tuple[CommunityNode, list[CommunityEdge]]] = list(
            await asyncio.gather(
                *[build_community(llm_client, cluster) for cluster in community_clusters]
            )
        )
    finally:
        # Always drop the projection: a leftover one would make the next
        # projection call fail because the graph name is already taken.
        await destroy_projection(driver, projection)

    community_nodes: list[CommunityNode] = [node for node, _ in communities]
    community_edges: list[CommunityEdge] = [
        edge for _, edges in communities for edge in edges
    ]

    return community_nodes, community_edges
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
async def remove_communities(driver: AsyncDriver):
    """Delete every ``Community`` node together with its relationships."""
    delete_query = """
    MATCH (c:Community)
    DETACH DELETE c
    """
    await driver.execute_query(delete_query)
|
{graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/utils/maintenance/edge_operations.py
RENAMED
|
@@ -20,9 +20,9 @@ from datetime import datetime
|
|
|
20
20
|
from time import time
|
|
21
21
|
from typing import List
|
|
22
22
|
|
|
23
|
-
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
|
23
|
+
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
|
|
24
24
|
from graphiti_core.llm_client import LLMClient
|
|
25
|
-
from graphiti_core.nodes import EntityNode, EpisodicNode
|
|
25
|
+
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
|
|
26
26
|
from graphiti_core.prompts import prompt_library
|
|
27
27
|
from graphiti_core.utils.maintenance.temporal_operations import (
|
|
28
28
|
extract_edge_dates,
|
|
@@ -50,6 +50,24 @@ def build_episodic_edges(
|
|
|
50
50
|
return edges
|
|
51
51
|
|
|
52
52
|
|
|
53
|
+
def build_community_edges(
    entity_nodes: List[EntityNode],
    community_node: CommunityNode,
    created_at: datetime,
) -> List[CommunityEdge]:
    """Create one edge from ``community_node`` to each node in ``entity_nodes``.

    Every edge points from the community to the member entity, shares the
    given ``created_at`` timestamp, and carries the community's ``group_id``.
    """
    membership_edges: List[CommunityEdge] = []
    for member in entity_nodes:
        membership_edges.append(
            CommunityEdge(
                source_node_uuid=community_node.uuid,
                target_node_uuid=member.uuid,
                created_at=created_at,
                group_id=community_node.group_id,
            )
        )

    return membership_edges
|
|
69
|
+
|
|
70
|
+
|
|
53
71
|
async def extract_edges(
|
|
54
72
|
llm_client: LLMClient,
|
|
55
73
|
episode: EpisodicNode,
|
{graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/utils/maintenance/graph_data_operations.py
RENAMED
|
@@ -32,8 +32,10 @@ async def build_indices_and_constraints(driver: AsyncDriver):
|
|
|
32
32
|
range_indices: list[LiteralString] = [
|
|
33
33
|
'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
|
|
34
34
|
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
|
|
35
|
+
'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
|
|
35
36
|
'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
|
|
36
37
|
'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
|
|
38
|
+
'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
|
|
37
39
|
'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
|
|
38
40
|
'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
|
|
39
41
|
'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
|
|
@@ -51,6 +53,7 @@ async def build_indices_and_constraints(driver: AsyncDriver):
|
|
|
51
53
|
|
|
52
54
|
fulltext_indices: list[LiteralString] = [
|
|
53
55
|
'CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]',
|
|
56
|
+
'CREATE FULLTEXT INDEX community_name IF NOT EXISTS FOR (n:Community) ON EACH [n.name]',
|
|
54
57
|
'CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact]',
|
|
55
58
|
]
|
|
56
59
|
|
|
@@ -71,6 +74,14 @@ async def build_indices_and_constraints(driver: AsyncDriver):
|
|
|
71
74
|
`vector.similarity_function`: 'cosine'
|
|
72
75
|
}}
|
|
73
76
|
""",
|
|
77
|
+
"""
|
|
78
|
+
CREATE VECTOR INDEX community_name_embedding IF NOT EXISTS
|
|
79
|
+
FOR (n:Community) ON (n.name_embedding)
|
|
80
|
+
OPTIONS {indexConfig: {
|
|
81
|
+
`vector.dimensions`: 1024,
|
|
82
|
+
`vector.similarity_function`: 'cosine'
|
|
83
|
+
}}
|
|
84
|
+
""",
|
|
74
85
|
]
|
|
75
86
|
index_queries: list[LiteralString] = range_indices + fulltext_indices + vector_indices
|
|
76
87
|
|
{graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/utils/maintenance/node_operations.py
RENAMED
|
@@ -48,6 +48,29 @@ async def extract_message_nodes(
|
|
|
48
48
|
return extracted_node_data
|
|
49
49
|
|
|
50
50
|
|
|
51
|
+
async def extract_text_nodes(
    llm_client: LLMClient, episode: EpisodicNode, previous_episodes: list[EpisodicNode]
) -> list[dict[str, Any]]:
    """Ask the LLM to extract nodes from a text episode.

    The prompt context holds the episode content and ISO-formatted timestamp,
    plus the content and timestamp of each previous episode. Returns the
    ``extracted_nodes`` field of the LLM response, or an empty list when the
    response has no such field.
    """
    # Assemble the context in the same order as the keys are listed.
    context: dict[str, Any] = {'episode_content': episode.content}
    context['episode_timestamp'] = episode.valid_at.isoformat()
    context['previous_episodes'] = [
        {'content': ep.content, 'timestamp': ep.valid_at.isoformat()}
        for ep in previous_episodes
    ]

    messages = prompt_library.extract_nodes.extract_text(context)
    response = await llm_client.generate_response(messages)
    return response.get('extracted_nodes', [])
|
|
72
|
+
|
|
73
|
+
|
|
51
74
|
async def extract_json_nodes(
|
|
52
75
|
llm_client: LLMClient,
|
|
53
76
|
episode: EpisodicNode,
|
|
@@ -73,8 +96,10 @@ async def extract_nodes(
|
|
|
73
96
|
) -> list[EntityNode]:
|
|
74
97
|
start = time()
|
|
75
98
|
extracted_node_data: list[dict[str, Any]] = []
|
|
76
|
-
if episode.source
|
|
99
|
+
if episode.source == EpisodeType.message:
|
|
77
100
|
extracted_node_data = await extract_message_nodes(llm_client, episode, previous_episodes)
|
|
101
|
+
elif episode.source == EpisodeType.text:
|
|
102
|
+
extracted_node_data = await extract_text_nodes(llm_client, episode, previous_episodes)
|
|
78
103
|
elif episode.source == EpisodeType.json:
|
|
79
104
|
extracted_node_data = await extract_json_nodes(llm_client, episode)
|
|
80
105
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[tool.poetry]
|
|
2
2
|
name = "graphiti-core"
|
|
3
|
-
version = "0.
|
|
3
|
+
version = "0.3.0"
|
|
4
4
|
description = "A temporal graph building library"
|
|
5
5
|
authors = [
|
|
6
6
|
"Paul Paliychuk <paul@getzep.com>",
|
|
@@ -19,7 +19,7 @@ neo4j = "^5.23.0"
|
|
|
19
19
|
diskcache = "^5.6.3"
|
|
20
20
|
openai = "^1.38.0"
|
|
21
21
|
tenacity = "<9.0.0"
|
|
22
|
-
numpy = "
|
|
22
|
+
numpy = ">=1.0.0"
|
|
23
23
|
|
|
24
24
|
[tool.poetry.dev-dependencies]
|
|
25
25
|
pytest = "^8.3.2"
|
|
@@ -31,7 +31,7 @@ ruff = "^0.6.2"
|
|
|
31
31
|
[tool.poetry.group.dev.dependencies]
|
|
32
32
|
pydantic = "^2.8.2"
|
|
33
33
|
mypy = "^1.11.1"
|
|
34
|
-
groq = ">=0.9,<0.
|
|
34
|
+
groq = ">=0.9,<0.12"
|
|
35
35
|
anthropic = "^0.34.1"
|
|
36
36
|
ipykernel = "^6.29.5"
|
|
37
37
|
jupyterlab = "^4.2.4"
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/utils/maintenance/temporal_operations.py
RENAMED
|
File without changes
|
|
File without changes
|