graphiti-core 0.4.2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/cross_encoder/bge_reranker_client.py +1 -2
- graphiti_core/cross_encoder/client.py +3 -4
- graphiti_core/edges.py +51 -5
- graphiti_core/embedder/client.py +3 -3
- graphiti_core/embedder/openai.py +2 -2
- graphiti_core/embedder/voyage.py +3 -3
- graphiti_core/graphiti.py +14 -10
- graphiti_core/helpers.py +1 -0
- graphiti_core/llm_client/anthropic_client.py +4 -1
- graphiti_core/llm_client/client.py +20 -5
- graphiti_core/llm_client/errors.py +8 -0
- graphiti_core/llm_client/groq_client.py +4 -1
- graphiti_core/llm_client/openai_client.py +29 -7
- graphiti_core/nodes.py +50 -4
- graphiti_core/prompts/dedupe_edges.py +20 -17
- graphiti_core/prompts/dedupe_nodes.py +15 -1
- graphiti_core/prompts/eval.py +17 -14
- graphiti_core/prompts/extract_edge_dates.py +15 -7
- graphiti_core/prompts/extract_edges.py +18 -19
- graphiti_core/prompts/extract_nodes.py +11 -21
- graphiti_core/prompts/invalidate_edges.py +13 -25
- graphiti_core/prompts/lib.py +5 -1
- graphiti_core/prompts/prompt_helpers.py +1 -0
- graphiti_core/prompts/summarize_nodes.py +12 -16
- graphiti_core/search/search_utils.py +1 -1
- graphiti_core/utils/maintenance/community_operations.py +4 -2
- graphiti_core/utils/maintenance/edge_operations.py +14 -11
- graphiti_core/utils/maintenance/node_operations.py +14 -7
- graphiti_core/utils/maintenance/temporal_operations.py +9 -4
- {graphiti_core-0.4.2.dist-info → graphiti_core-0.5.0rc1.dist-info}/METADATA +1 -1
- graphiti_core-0.5.0rc1.dist-info/RECORD +58 -0
- graphiti_core-0.4.2.dist-info/RECORD +0 -57
- {graphiti_core-0.4.2.dist-info → graphiti_core-0.5.0rc1.dist-info}/LICENSE +0 -0
- {graphiti_core-0.4.2.dist-info → graphiti_core-0.5.0rc1.dist-info}/WHEEL +0 -0
|
@@ -15,7 +15,6 @@ limitations under the License.
|
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
17
|
import asyncio
|
|
18
|
-
from typing import List, Tuple
|
|
19
18
|
|
|
20
19
|
from sentence_transformers import CrossEncoder
|
|
21
20
|
|
|
@@ -26,7 +25,7 @@ class BGERerankerClient(CrossEncoderClient):
|
|
|
26
25
|
def __init__(self):
|
|
27
26
|
self.model = CrossEncoder('BAAI/bge-reranker-v2-m3')
|
|
28
27
|
|
|
29
|
-
async def rank(self, query: str, passages:
|
|
28
|
+
async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
|
|
30
29
|
if not passages:
|
|
31
30
|
return []
|
|
32
31
|
|
|
@@ -15,7 +15,6 @@ limitations under the License.
|
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
17
|
from abc import ABC, abstractmethod
|
|
18
|
-
from typing import List, Tuple
|
|
19
18
|
|
|
20
19
|
|
|
21
20
|
class CrossEncoderClient(ABC):
|
|
@@ -26,16 +25,16 @@ class CrossEncoderClient(ABC):
|
|
|
26
25
|
"""
|
|
27
26
|
|
|
28
27
|
@abstractmethod
|
|
29
|
-
async def rank(self, query: str, passages:
|
|
28
|
+
async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
|
|
30
29
|
"""
|
|
31
30
|
Rank the given passages based on their relevance to the query.
|
|
32
31
|
|
|
33
32
|
Args:
|
|
34
33
|
query (str): The query string.
|
|
35
|
-
passages (
|
|
34
|
+
passages (list[str]): A list of passages to rank.
|
|
36
35
|
|
|
37
36
|
Returns:
|
|
38
|
-
|
|
37
|
+
list[tuple[str, float]]: A list of tuples containing the passage and its score,
|
|
39
38
|
sorted in descending order of relevance.
|
|
40
39
|
"""
|
|
41
40
|
pass
|
graphiti_core/edges.py
CHANGED
|
@@ -23,10 +23,11 @@ from uuid import uuid4
|
|
|
23
23
|
|
|
24
24
|
from neo4j import AsyncDriver
|
|
25
25
|
from pydantic import BaseModel, Field
|
|
26
|
+
from typing_extensions import LiteralString
|
|
26
27
|
|
|
27
28
|
from graphiti_core.embedder import EmbedderClient
|
|
28
29
|
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
|
29
|
-
from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
|
|
30
|
+
from graphiti_core.helpers import DEFAULT_DATABASE, DEFAULT_PAGE_LIMIT, parse_db_date
|
|
30
31
|
from graphiti_core.models.edges.edge_db_queries import (
|
|
31
32
|
COMMUNITY_EDGE_SAVE,
|
|
32
33
|
ENTITY_EDGE_SAVE,
|
|
@@ -50,7 +51,7 @@ class Edge(BaseModel, ABC):
|
|
|
50
51
|
async def delete(self, driver: AsyncDriver):
|
|
51
52
|
result = await driver.execute_query(
|
|
52
53
|
"""
|
|
53
|
-
MATCH (n)-[e {uuid: $uuid}]->(m)
|
|
54
|
+
MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
|
|
54
55
|
DELETE e
|
|
55
56
|
""",
|
|
56
57
|
uuid=self.uuid,
|
|
@@ -137,19 +138,34 @@ class EpisodicEdge(Edge):
|
|
|
137
138
|
return edges
|
|
138
139
|
|
|
139
140
|
@classmethod
|
|
140
|
-
async def get_by_group_ids(
|
|
141
|
+
async def get_by_group_ids(
|
|
142
|
+
cls,
|
|
143
|
+
driver: AsyncDriver,
|
|
144
|
+
group_ids: list[str],
|
|
145
|
+
limit: int = DEFAULT_PAGE_LIMIT,
|
|
146
|
+
created_at: datetime | None = None,
|
|
147
|
+
):
|
|
148
|
+
cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
|
|
149
|
+
|
|
141
150
|
records, _, _ = await driver.execute_query(
|
|
142
151
|
"""
|
|
143
152
|
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
|
|
144
153
|
WHERE e.group_id IN $group_ids
|
|
154
|
+
"""
|
|
155
|
+
+ cursor_query
|
|
156
|
+
+ """
|
|
145
157
|
RETURN
|
|
146
158
|
e.uuid As uuid,
|
|
147
159
|
e.group_id AS group_id,
|
|
148
160
|
n.uuid AS source_node_uuid,
|
|
149
161
|
m.uuid AS target_node_uuid,
|
|
150
162
|
e.created_at AS created_at
|
|
163
|
+
ORDER BY e.uuid DESC
|
|
164
|
+
LIMIT $limit
|
|
151
165
|
""",
|
|
152
166
|
group_ids=group_ids,
|
|
167
|
+
created_at=created_at,
|
|
168
|
+
limit=limit,
|
|
153
169
|
database_=DEFAULT_DATABASE,
|
|
154
170
|
routing_='r',
|
|
155
171
|
)
|
|
@@ -274,11 +290,22 @@ class EntityEdge(Edge):
|
|
|
274
290
|
return edges
|
|
275
291
|
|
|
276
292
|
@classmethod
|
|
277
|
-
async def get_by_group_ids(
|
|
293
|
+
async def get_by_group_ids(
|
|
294
|
+
cls,
|
|
295
|
+
driver: AsyncDriver,
|
|
296
|
+
group_ids: list[str],
|
|
297
|
+
limit: int = DEFAULT_PAGE_LIMIT,
|
|
298
|
+
created_at: datetime | None = None,
|
|
299
|
+
):
|
|
300
|
+
cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
|
|
301
|
+
|
|
278
302
|
records, _, _ = await driver.execute_query(
|
|
279
303
|
"""
|
|
280
304
|
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
281
305
|
WHERE e.group_id IN $group_ids
|
|
306
|
+
"""
|
|
307
|
+
+ cursor_query
|
|
308
|
+
+ """
|
|
282
309
|
RETURN
|
|
283
310
|
e.uuid AS uuid,
|
|
284
311
|
n.uuid AS source_node_uuid,
|
|
@@ -292,8 +319,12 @@ class EntityEdge(Edge):
|
|
|
292
319
|
e.expired_at AS expired_at,
|
|
293
320
|
e.valid_at AS valid_at,
|
|
294
321
|
e.invalid_at AS invalid_at
|
|
322
|
+
ORDER BY e.uuid DESC
|
|
323
|
+
LIMIT $limit
|
|
295
324
|
""",
|
|
296
325
|
group_ids=group_ids,
|
|
326
|
+
created_at=created_at,
|
|
327
|
+
limit=limit,
|
|
297
328
|
database_=DEFAULT_DATABASE,
|
|
298
329
|
routing_='r',
|
|
299
330
|
)
|
|
@@ -365,19 +396,34 @@ class CommunityEdge(Edge):
|
|
|
365
396
|
return edges
|
|
366
397
|
|
|
367
398
|
@classmethod
|
|
368
|
-
async def get_by_group_ids(
|
|
399
|
+
async def get_by_group_ids(
|
|
400
|
+
cls,
|
|
401
|
+
driver: AsyncDriver,
|
|
402
|
+
group_ids: list[str],
|
|
403
|
+
limit: int = DEFAULT_PAGE_LIMIT,
|
|
404
|
+
created_at: datetime | None = None,
|
|
405
|
+
):
|
|
406
|
+
cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
|
|
407
|
+
|
|
369
408
|
records, _, _ = await driver.execute_query(
|
|
370
409
|
"""
|
|
371
410
|
MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
|
|
372
411
|
WHERE e.group_id IN $group_ids
|
|
412
|
+
"""
|
|
413
|
+
+ cursor_query
|
|
414
|
+
+ """
|
|
373
415
|
RETURN
|
|
374
416
|
e.uuid As uuid,
|
|
375
417
|
e.group_id AS group_id,
|
|
376
418
|
n.uuid AS source_node_uuid,
|
|
377
419
|
m.uuid AS target_node_uuid,
|
|
378
420
|
e.created_at AS created_at
|
|
421
|
+
ORDER BY e.uuid DESC
|
|
422
|
+
LIMIT $limit
|
|
379
423
|
""",
|
|
380
424
|
group_ids=group_ids,
|
|
425
|
+
created_at=created_at,
|
|
426
|
+
limit=limit,
|
|
381
427
|
database_=DEFAULT_DATABASE,
|
|
382
428
|
routing_='r',
|
|
383
429
|
)
|
graphiti_core/embedder/client.py
CHANGED
|
@@ -15,7 +15,7 @@ limitations under the License.
|
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
17
|
from abc import ABC, abstractmethod
|
|
18
|
-
from
|
|
18
|
+
from collections.abc import Iterable
|
|
19
19
|
|
|
20
20
|
from pydantic import BaseModel, Field
|
|
21
21
|
|
|
@@ -23,12 +23,12 @@ EMBEDDING_DIM = 1024
|
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
class EmbedderConfig(BaseModel):
|
|
26
|
-
embedding_dim:
|
|
26
|
+
embedding_dim: int = Field(default=EMBEDDING_DIM, frozen=True)
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
class EmbedderClient(ABC):
|
|
30
30
|
@abstractmethod
|
|
31
31
|
async def create(
|
|
32
|
-
self, input_data: str |
|
|
32
|
+
self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
|
|
33
33
|
) -> list[float]:
|
|
34
34
|
pass
|
graphiti_core/embedder/openai.py
CHANGED
|
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
-
from
|
|
17
|
+
from collections.abc import Iterable
|
|
18
18
|
|
|
19
19
|
from openai import AsyncOpenAI
|
|
20
20
|
from openai.types import EmbeddingModel
|
|
@@ -42,7 +42,7 @@ class OpenAIEmbedder(EmbedderClient):
|
|
|
42
42
|
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
|
43
43
|
|
|
44
44
|
async def create(
|
|
45
|
-
self, input_data: str |
|
|
45
|
+
self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
|
|
46
46
|
) -> list[float]:
|
|
47
47
|
result = await self.client.embeddings.create(
|
|
48
48
|
input=input_data, model=self.config.embedding_model
|
graphiti_core/embedder/voyage.py
CHANGED
|
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
-
from
|
|
17
|
+
from collections.abc import Iterable
|
|
18
18
|
|
|
19
19
|
import voyageai # type: ignore
|
|
20
20
|
from pydantic import Field
|
|
@@ -41,11 +41,11 @@ class VoyageAIEmbedder(EmbedderClient):
|
|
|
41
41
|
self.client = voyageai.AsyncClient(api_key=config.api_key)
|
|
42
42
|
|
|
43
43
|
async def create(
|
|
44
|
-
self, input_data: str |
|
|
44
|
+
self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
|
|
45
45
|
) -> list[float]:
|
|
46
46
|
if isinstance(input_data, str):
|
|
47
47
|
input_list = [input_data]
|
|
48
|
-
elif isinstance(input_data,
|
|
48
|
+
elif isinstance(input_data, list):
|
|
49
49
|
input_list = [str(i) for i in input_data if i]
|
|
50
50
|
else:
|
|
51
51
|
input_list = [str(i) for i in input_data if i is not None]
|
graphiti_core/graphiti.py
CHANGED
|
@@ -318,17 +318,21 @@ class Graphiti:
|
|
|
318
318
|
previous_episodes = await self.retrieve_episodes(
|
|
319
319
|
reference_time, last_n=RELEVANT_SCHEMA_LIMIT, group_ids=[group_id]
|
|
320
320
|
)
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
321
|
+
|
|
322
|
+
episode = (
|
|
323
|
+
await EpisodicNode.get_by_uuid(self.driver, uuid)
|
|
324
|
+
if uuid is not None
|
|
325
|
+
else EpisodicNode(
|
|
326
|
+
name=name,
|
|
327
|
+
group_id=group_id,
|
|
328
|
+
labels=[],
|
|
329
|
+
source=source,
|
|
330
|
+
content=episode_body,
|
|
331
|
+
source_description=source_description,
|
|
332
|
+
created_at=now,
|
|
333
|
+
valid_at=reference_time,
|
|
334
|
+
)
|
|
330
335
|
)
|
|
331
|
-
episode.uuid = uuid if uuid is not None else episode.uuid
|
|
332
336
|
|
|
333
337
|
# Extract entities as nodes
|
|
334
338
|
|
graphiti_core/helpers.py
CHANGED
|
@@ -26,6 +26,7 @@ load_dotenv()
|
|
|
26
26
|
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
|
|
27
27
|
USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
|
|
28
28
|
MAX_REFLEXION_ITERATIONS = 2
|
|
29
|
+
DEFAULT_PAGE_LIMIT = 20
|
|
29
30
|
|
|
30
31
|
|
|
31
32
|
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
|
|
@@ -20,6 +20,7 @@ import typing
|
|
|
20
20
|
|
|
21
21
|
import anthropic
|
|
22
22
|
from anthropic import AsyncAnthropic
|
|
23
|
+
from pydantic import BaseModel
|
|
23
24
|
|
|
24
25
|
from ..prompts.models import Message
|
|
25
26
|
from .client import LLMClient
|
|
@@ -46,7 +47,9 @@ class AnthropicClient(LLMClient):
|
|
|
46
47
|
max_retries=1,
|
|
47
48
|
)
|
|
48
49
|
|
|
49
|
-
async def _generate_response(
|
|
50
|
+
async def _generate_response(
|
|
51
|
+
self, messages: list[Message], response_model: type[BaseModel] | None = None
|
|
52
|
+
) -> dict[str, typing.Any]:
|
|
50
53
|
system_message = messages[0]
|
|
51
54
|
user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]] + [
|
|
52
55
|
{'role': 'assistant', 'content': '{'}
|
|
@@ -22,6 +22,7 @@ from abc import ABC, abstractmethod
|
|
|
22
22
|
|
|
23
23
|
import httpx
|
|
24
24
|
from diskcache import Cache
|
|
25
|
+
from pydantic import BaseModel
|
|
25
26
|
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential
|
|
26
27
|
|
|
27
28
|
from ..prompts.models import Message
|
|
@@ -66,14 +67,18 @@ class LLMClient(ABC):
|
|
|
66
67
|
else None,
|
|
67
68
|
reraise=True,
|
|
68
69
|
)
|
|
69
|
-
async def _generate_response_with_retry(
|
|
70
|
+
async def _generate_response_with_retry(
|
|
71
|
+
self, messages: list[Message], response_model: type[BaseModel] | None = None
|
|
72
|
+
) -> dict[str, typing.Any]:
|
|
70
73
|
try:
|
|
71
|
-
return await self._generate_response(messages)
|
|
74
|
+
return await self._generate_response(messages, response_model)
|
|
72
75
|
except (httpx.HTTPStatusError, RateLimitError) as e:
|
|
73
76
|
raise e
|
|
74
77
|
|
|
75
78
|
@abstractmethod
|
|
76
|
-
async def _generate_response(
|
|
79
|
+
async def _generate_response(
|
|
80
|
+
self, messages: list[Message], response_model: type[BaseModel] | None = None
|
|
81
|
+
) -> dict[str, typing.Any]:
|
|
77
82
|
pass
|
|
78
83
|
|
|
79
84
|
def _get_cache_key(self, messages: list[Message]) -> str:
|
|
@@ -82,7 +87,17 @@ class LLMClient(ABC):
|
|
|
82
87
|
key_str = f'{self.model}:{message_str}'
|
|
83
88
|
return hashlib.md5(key_str.encode()).hexdigest()
|
|
84
89
|
|
|
85
|
-
async def generate_response(
|
|
90
|
+
async def generate_response(
|
|
91
|
+
self, messages: list[Message], response_model: type[BaseModel] | None = None
|
|
92
|
+
) -> dict[str, typing.Any]:
|
|
93
|
+
if response_model is not None:
|
|
94
|
+
serialized_model = json.dumps(response_model.model_json_schema())
|
|
95
|
+
messages[
|
|
96
|
+
-1
|
|
97
|
+
].content += (
|
|
98
|
+
f'\n\nRespond with a JSON object in the following format:\n\n{serialized_model}'
|
|
99
|
+
)
|
|
100
|
+
|
|
86
101
|
if self.cache_enabled:
|
|
87
102
|
cache_key = self._get_cache_key(messages)
|
|
88
103
|
|
|
@@ -91,7 +106,7 @@ class LLMClient(ABC):
|
|
|
91
106
|
logger.debug(f'Cache hit for {cache_key}')
|
|
92
107
|
return cached_response
|
|
93
108
|
|
|
94
|
-
response = await self._generate_response_with_retry(messages)
|
|
109
|
+
response = await self._generate_response_with_retry(messages, response_model)
|
|
95
110
|
|
|
96
111
|
if self.cache_enabled:
|
|
97
112
|
self.cache_dir.set(cache_key, response)
|
|
@@ -21,3 +21,11 @@ class RateLimitError(Exception):
|
|
|
21
21
|
def __init__(self, message='Rate limit exceeded. Please try again later.'):
|
|
22
22
|
self.message = message
|
|
23
23
|
super().__init__(self.message)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class RefusalError(Exception):
|
|
27
|
+
"""Exception raised when the LLM refuses to generate a response."""
|
|
28
|
+
|
|
29
|
+
def __init__(self, message: str):
|
|
30
|
+
self.message = message
|
|
31
|
+
super().__init__(self.message)
|
|
@@ -21,6 +21,7 @@ import typing
|
|
|
21
21
|
import groq
|
|
22
22
|
from groq import AsyncGroq
|
|
23
23
|
from groq.types.chat import ChatCompletionMessageParam
|
|
24
|
+
from pydantic import BaseModel
|
|
24
25
|
|
|
25
26
|
from ..prompts.models import Message
|
|
26
27
|
from .client import LLMClient
|
|
@@ -43,7 +44,9 @@ class GroqClient(LLMClient):
|
|
|
43
44
|
|
|
44
45
|
self.client = AsyncGroq(api_key=config.api_key)
|
|
45
46
|
|
|
46
|
-
async def _generate_response(
|
|
47
|
+
async def _generate_response(
|
|
48
|
+
self, messages: list[Message], response_model: type[BaseModel] | None = None
|
|
49
|
+
) -> dict[str, typing.Any]:
|
|
47
50
|
msgs: list[ChatCompletionMessageParam] = []
|
|
48
51
|
for m in messages:
|
|
49
52
|
if m.role == 'user':
|
|
@@ -14,18 +14,18 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
-
import json
|
|
18
17
|
import logging
|
|
19
18
|
import typing
|
|
20
19
|
|
|
21
20
|
import openai
|
|
22
21
|
from openai import AsyncOpenAI
|
|
23
22
|
from openai.types.chat import ChatCompletionMessageParam
|
|
23
|
+
from pydantic import BaseModel
|
|
24
24
|
|
|
25
25
|
from ..prompts.models import Message
|
|
26
26
|
from .client import LLMClient
|
|
27
27
|
from .config import LLMConfig
|
|
28
|
-
from .errors import RateLimitError
|
|
28
|
+
from .errors import RateLimitError, RefusalError
|
|
29
29
|
|
|
30
30
|
logger = logging.getLogger(__name__)
|
|
31
31
|
|
|
@@ -65,6 +65,10 @@ class OpenAIClient(LLMClient):
|
|
|
65
65
|
client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
|
|
66
66
|
|
|
67
67
|
"""
|
|
68
|
+
# removed caching to simplify the `generate_response` override
|
|
69
|
+
if cache:
|
|
70
|
+
raise NotImplementedError('Caching is not implemented for OpenAI')
|
|
71
|
+
|
|
68
72
|
if config is None:
|
|
69
73
|
config = LLMConfig()
|
|
70
74
|
|
|
@@ -75,7 +79,9 @@ class OpenAIClient(LLMClient):
|
|
|
75
79
|
else:
|
|
76
80
|
self.client = client
|
|
77
81
|
|
|
78
|
-
async def _generate_response(
|
|
82
|
+
async def _generate_response(
|
|
83
|
+
self, messages: list[Message], response_model: type[BaseModel] | None = None
|
|
84
|
+
) -> dict[str, typing.Any]:
|
|
79
85
|
openai_messages: list[ChatCompletionMessageParam] = []
|
|
80
86
|
for m in messages:
|
|
81
87
|
if m.role == 'user':
|
|
@@ -83,17 +89,33 @@ class OpenAIClient(LLMClient):
|
|
|
83
89
|
elif m.role == 'system':
|
|
84
90
|
openai_messages.append({'role': 'system', 'content': m.content})
|
|
85
91
|
try:
|
|
86
|
-
response = await self.client.chat.completions.
|
|
92
|
+
response = await self.client.beta.chat.completions.parse(
|
|
87
93
|
model=self.model or DEFAULT_MODEL,
|
|
88
94
|
messages=openai_messages,
|
|
89
95
|
temperature=self.temperature,
|
|
90
96
|
max_tokens=self.max_tokens,
|
|
91
|
-
response_format=
|
|
97
|
+
response_format=response_model, # type: ignore
|
|
92
98
|
)
|
|
93
|
-
|
|
94
|
-
|
|
99
|
+
|
|
100
|
+
response_object = response.choices[0].message
|
|
101
|
+
|
|
102
|
+
if response_object.parsed:
|
|
103
|
+
return response_object.parsed.model_dump()
|
|
104
|
+
elif response_object.refusal:
|
|
105
|
+
raise RefusalError(response_object.refusal)
|
|
106
|
+
else:
|
|
107
|
+
raise Exception('No response from LLM')
|
|
108
|
+
except openai.LengthFinishReasonError as e:
|
|
109
|
+
raise Exception(f'Output length exceeded max tokens {self.max_tokens}: {e}') from e
|
|
95
110
|
except openai.RateLimitError as e:
|
|
96
111
|
raise RateLimitError from e
|
|
97
112
|
except Exception as e:
|
|
98
113
|
logger.error(f'Error in generating LLM response: {e}')
|
|
99
114
|
raise
|
|
115
|
+
|
|
116
|
+
async def generate_response(
|
|
117
|
+
self, messages: list[Message], response_model: type[BaseModel] | None = None
|
|
118
|
+
) -> dict[str, typing.Any]:
|
|
119
|
+
response = await self._generate_response(messages, response_model)
|
|
120
|
+
|
|
121
|
+
return response
|
graphiti_core/nodes.py
CHANGED
|
@@ -24,10 +24,11 @@ from uuid import uuid4
|
|
|
24
24
|
|
|
25
25
|
from neo4j import AsyncDriver
|
|
26
26
|
from pydantic import BaseModel, Field
|
|
27
|
+
from typing_extensions import LiteralString
|
|
27
28
|
|
|
28
29
|
from graphiti_core.embedder import EmbedderClient
|
|
29
30
|
from graphiti_core.errors import NodeNotFoundError
|
|
30
|
-
from graphiti_core.helpers import DEFAULT_DATABASE
|
|
31
|
+
from graphiti_core.helpers import DEFAULT_DATABASE, DEFAULT_PAGE_LIMIT
|
|
31
32
|
from graphiti_core.models.nodes.node_db_queries import (
|
|
32
33
|
COMMUNITY_NODE_SAVE,
|
|
33
34
|
ENTITY_NODE_SAVE,
|
|
@@ -207,10 +208,21 @@ class EpisodicNode(Node):
|
|
|
207
208
|
return episodes
|
|
208
209
|
|
|
209
210
|
@classmethod
|
|
210
|
-
async def get_by_group_ids(
|
|
211
|
+
async def get_by_group_ids(
|
|
212
|
+
cls,
|
|
213
|
+
driver: AsyncDriver,
|
|
214
|
+
group_ids: list[str],
|
|
215
|
+
limit: int = DEFAULT_PAGE_LIMIT,
|
|
216
|
+
created_at: datetime | None = None,
|
|
217
|
+
):
|
|
218
|
+
cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
|
|
219
|
+
|
|
211
220
|
records, _, _ = await driver.execute_query(
|
|
212
221
|
"""
|
|
213
222
|
MATCH (e:Episodic) WHERE e.group_id IN $group_ids
|
|
223
|
+
"""
|
|
224
|
+
+ cursor_query
|
|
225
|
+
+ """
|
|
214
226
|
RETURN DISTINCT
|
|
215
227
|
e.content AS content,
|
|
216
228
|
e.created_at AS created_at,
|
|
@@ -220,8 +232,12 @@ class EpisodicNode(Node):
|
|
|
220
232
|
e.group_id AS group_id,
|
|
221
233
|
e.source_description AS source_description,
|
|
222
234
|
e.source AS source
|
|
235
|
+
ORDER BY e.uuid DESC
|
|
236
|
+
LIMIT $limit
|
|
223
237
|
""",
|
|
224
238
|
group_ids=group_ids,
|
|
239
|
+
created_at=created_at,
|
|
240
|
+
limit=limit,
|
|
225
241
|
database_=DEFAULT_DATABASE,
|
|
226
242
|
routing_='r',
|
|
227
243
|
)
|
|
@@ -308,10 +324,21 @@ class EntityNode(Node):
|
|
|
308
324
|
return nodes
|
|
309
325
|
|
|
310
326
|
@classmethod
|
|
311
|
-
async def get_by_group_ids(
|
|
327
|
+
async def get_by_group_ids(
|
|
328
|
+
cls,
|
|
329
|
+
driver: AsyncDriver,
|
|
330
|
+
group_ids: list[str],
|
|
331
|
+
limit: int = DEFAULT_PAGE_LIMIT,
|
|
332
|
+
created_at: datetime | None = None,
|
|
333
|
+
):
|
|
334
|
+
cursor_query: LiteralString = 'AND n.created_at < $created_at' if created_at else ''
|
|
335
|
+
|
|
312
336
|
records, _, _ = await driver.execute_query(
|
|
313
337
|
"""
|
|
314
338
|
MATCH (n:Entity) WHERE n.group_id IN $group_ids
|
|
339
|
+
"""
|
|
340
|
+
+ cursor_query
|
|
341
|
+
+ """
|
|
315
342
|
RETURN
|
|
316
343
|
n.uuid As uuid,
|
|
317
344
|
n.name AS name,
|
|
@@ -319,8 +346,12 @@ class EntityNode(Node):
|
|
|
319
346
|
n.group_id AS group_id,
|
|
320
347
|
n.created_at AS created_at,
|
|
321
348
|
n.summary AS summary
|
|
349
|
+
ORDER BY n.uuid DESC
|
|
350
|
+
LIMIT $limit
|
|
322
351
|
""",
|
|
323
352
|
group_ids=group_ids,
|
|
353
|
+
created_at=created_at,
|
|
354
|
+
limit=limit,
|
|
324
355
|
database_=DEFAULT_DATABASE,
|
|
325
356
|
routing_='r',
|
|
326
357
|
)
|
|
@@ -407,10 +438,21 @@ class CommunityNode(Node):
|
|
|
407
438
|
return communities
|
|
408
439
|
|
|
409
440
|
@classmethod
|
|
410
|
-
async def get_by_group_ids(
|
|
441
|
+
async def get_by_group_ids(
|
|
442
|
+
cls,
|
|
443
|
+
driver: AsyncDriver,
|
|
444
|
+
group_ids: list[str],
|
|
445
|
+
limit: int = DEFAULT_PAGE_LIMIT,
|
|
446
|
+
created_at: datetime | None = None,
|
|
447
|
+
):
|
|
448
|
+
cursor_query: LiteralString = 'AND n.created_at < $created_at' if created_at else ''
|
|
449
|
+
|
|
411
450
|
records, _, _ = await driver.execute_query(
|
|
412
451
|
"""
|
|
413
452
|
MATCH (n:Community) WHERE n.group_id IN $group_ids
|
|
453
|
+
"""
|
|
454
|
+
+ cursor_query
|
|
455
|
+
+ """
|
|
414
456
|
RETURN
|
|
415
457
|
n.uuid As uuid,
|
|
416
458
|
n.name AS name,
|
|
@@ -418,8 +460,12 @@ class CommunityNode(Node):
|
|
|
418
460
|
n.group_id AS group_id,
|
|
419
461
|
n.created_at AS created_at,
|
|
420
462
|
n.summary AS summary
|
|
463
|
+
ORDER BY n.uuid DESC
|
|
464
|
+
LIMIT $limit
|
|
421
465
|
""",
|
|
422
466
|
group_ids=group_ids,
|
|
467
|
+
created_at=created_at,
|
|
468
|
+
limit=limit,
|
|
423
469
|
database_=DEFAULT_DATABASE,
|
|
424
470
|
routing_='r',
|
|
425
471
|
)
|
|
@@ -15,11 +15,30 @@ limitations under the License.
|
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
17
|
import json
|
|
18
|
-
from typing import Any, Protocol, TypedDict
|
|
18
|
+
from typing import Any, Optional, Protocol, TypedDict
|
|
19
|
+
|
|
20
|
+
from pydantic import BaseModel, Field
|
|
19
21
|
|
|
20
22
|
from .models import Message, PromptFunction, PromptVersion
|
|
21
23
|
|
|
22
24
|
|
|
25
|
+
class EdgeDuplicate(BaseModel):
|
|
26
|
+
is_duplicate: bool = Field(..., description='true or false')
|
|
27
|
+
uuid: Optional[str] = Field(
|
|
28
|
+
None,
|
|
29
|
+
description="uuid of the existing edge like '5d643020624c42fa9de13f97b1b3fa39' or null",
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class UniqueFact(BaseModel):
|
|
34
|
+
uuid: str = Field(..., description='unique identifier of the fact')
|
|
35
|
+
fact: str = Field(..., description='fact of a unique edge')
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class UniqueFacts(BaseModel):
|
|
39
|
+
unique_facts: list[UniqueFact]
|
|
40
|
+
|
|
41
|
+
|
|
23
42
|
class Prompt(Protocol):
|
|
24
43
|
edge: PromptVersion
|
|
25
44
|
edge_list: PromptVersion
|
|
@@ -56,12 +75,6 @@ def edge(context: dict[str, Any]) -> list[Message]:
|
|
|
56
75
|
|
|
57
76
|
Guidelines:
|
|
58
77
|
1. The facts do not need to be completely identical to be duplicates, they just need to express the same information.
|
|
59
|
-
|
|
60
|
-
Respond with a JSON object in the following format:
|
|
61
|
-
{{
|
|
62
|
-
"is_duplicate": true or false,
|
|
63
|
-
"uuid": uuid of the existing edge like "5d643020624c42fa9de13f97b1b3fa39" or null,
|
|
64
|
-
}}
|
|
65
78
|
""",
|
|
66
79
|
),
|
|
67
80
|
]
|
|
@@ -90,16 +103,6 @@ def edge_list(context: dict[str, Any]) -> list[Message]:
|
|
|
90
103
|
3. Facts will often discuss the same or similar relation between identical entities
|
|
91
104
|
4. The final list should have only unique facts. If 3 facts are all duplicates of each other, only one of their
|
|
92
105
|
facts should be in the response
|
|
93
|
-
|
|
94
|
-
Respond with a JSON object in the following format:
|
|
95
|
-
{{
|
|
96
|
-
"unique_facts": [
|
|
97
|
-
{{
|
|
98
|
-
"uuid": "unique identifier of the fact",
|
|
99
|
-
"fact": "fact of a unique edge"
|
|
100
|
-
}}
|
|
101
|
-
]
|
|
102
|
-
}}
|
|
103
106
|
""",
|
|
104
107
|
),
|
|
105
108
|
]
|
|
@@ -15,11 +15,25 @@ limitations under the License.
|
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
17
|
import json
|
|
18
|
-
from typing import Any, Protocol, TypedDict
|
|
18
|
+
from typing import Any, Optional, Protocol, TypedDict
|
|
19
|
+
|
|
20
|
+
from pydantic import BaseModel, Field
|
|
19
21
|
|
|
20
22
|
from .models import Message, PromptFunction, PromptVersion
|
|
21
23
|
|
|
22
24
|
|
|
25
|
+
class NodeDuplicate(BaseModel):
|
|
26
|
+
is_duplicate: bool = Field(..., description='true or false')
|
|
27
|
+
uuid: Optional[str] = Field(
|
|
28
|
+
None,
|
|
29
|
+
description="uuid of the existing node like '5d643020624c42fa9de13f97b1b3fa39' or null",
|
|
30
|
+
)
|
|
31
|
+
name: str = Field(
|
|
32
|
+
...,
|
|
33
|
+
description="Updated name of the new node (use the best name between the new node's name, an existing duplicate name, or a combination of both)",
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
|
|
23
37
|
class Prompt(Protocol):
|
|
24
38
|
node: PromptVersion
|
|
25
39
|
node_list: PromptVersion
|