graphiti-core 0.4.2-py3-none-any.whl → 0.5.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic.
- graphiti_core/cross_encoder/bge_reranker_client.py +1 -2
- graphiti_core/cross_encoder/client.py +3 -4
- graphiti_core/cross_encoder/openai_reranker_client.py +2 -2
- graphiti_core/edges.py +56 -7
- graphiti_core/embedder/client.py +3 -3
- graphiti_core/embedder/openai.py +2 -2
- graphiti_core/embedder/voyage.py +3 -3
- graphiti_core/graphiti.py +39 -37
- graphiti_core/helpers.py +26 -0
- graphiti_core/llm_client/anthropic_client.py +4 -1
- graphiti_core/llm_client/client.py +45 -5
- graphiti_core/llm_client/errors.py +8 -0
- graphiti_core/llm_client/groq_client.py +4 -1
- graphiti_core/llm_client/openai_client.py +71 -7
- graphiti_core/llm_client/openai_generic_client.py +163 -0
- graphiti_core/nodes.py +58 -8
- graphiti_core/prompts/dedupe_edges.py +20 -17
- graphiti_core/prompts/dedupe_nodes.py +15 -1
- graphiti_core/prompts/eval.py +17 -14
- graphiti_core/prompts/extract_edge_dates.py +15 -7
- graphiti_core/prompts/extract_edges.py +18 -19
- graphiti_core/prompts/extract_nodes.py +11 -21
- graphiti_core/prompts/invalidate_edges.py +13 -25
- graphiti_core/prompts/lib.py +5 -1
- graphiti_core/prompts/prompt_helpers.py +1 -0
- graphiti_core/prompts/summarize_nodes.py +17 -16
- graphiti_core/search/search.py +5 -5
- graphiti_core/search/search_utils.py +55 -14
- graphiti_core/utils/__init__.py +0 -15
- graphiti_core/utils/bulk_utils.py +22 -15
- graphiti_core/utils/datetime_utils.py +42 -0
- graphiti_core/utils/maintenance/community_operations.py +13 -9
- graphiti_core/utils/maintenance/edge_operations.py +32 -26
- graphiti_core/utils/maintenance/graph_data_operations.py +3 -4
- graphiti_core/utils/maintenance/node_operations.py +19 -13
- graphiti_core/utils/maintenance/temporal_operations.py +17 -9
- {graphiti_core-0.4.2.dist-info → graphiti_core-0.5.0.dist-info}/METADATA +1 -1
- graphiti_core-0.5.0.dist-info/RECORD +60 -0
- graphiti_core-0.4.2.dist-info/RECORD +0 -57
- {graphiti_core-0.4.2.dist-info → graphiti_core-0.5.0.dist-info}/LICENSE +0 -0
- {graphiti_core-0.4.2.dist-info → graphiti_core-0.5.0.dist-info}/WHEEL +0 -0
graphiti_core/llm_client/groq_client.py
CHANGED

@@ -21,6 +21,7 @@ import typing
 import groq
 from groq import AsyncGroq
 from groq.types.chat import ChatCompletionMessageParam
+from pydantic import BaseModel
 
 from ..prompts.models import Message
 from .client import LLMClient

@@ -43,7 +44,9 @@ class GroqClient(LLMClient):
 
         self.client = AsyncGroq(api_key=config.api_key)
 
-    async def _generate_response(
+    async def _generate_response(
+        self, messages: list[Message], response_model: type[BaseModel] | None = None
+    ) -> dict[str, typing.Any]:
         msgs: list[ChatCompletionMessageParam] = []
         for m in messages:
             if m.role == 'user':
graphiti_core/llm_client/openai_client.py
CHANGED

@@ -14,18 +14,19 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-import json
 import logging
 import typing
+from typing import ClassVar
 
 import openai
 from openai import AsyncOpenAI
 from openai.types.chat import ChatCompletionMessageParam
+from pydantic import BaseModel
 
 from ..prompts.models import Message
 from .client import LLMClient
 from .config import LLMConfig
-from .errors import RateLimitError
+from .errors import RateLimitError, RefusalError
 
 logger = logging.getLogger(__name__)
 

@@ -53,6 +54,9 @@ class OpenAIClient(LLMClient):
             Generates a response from the language model based on the provided messages.
     """
 
+    # Class-level constants
+    MAX_RETRIES: ClassVar[int] = 2
+
     def __init__(
         self, config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None
     ):

@@ -65,6 +69,10 @@ class OpenAIClient(LLMClient):
             client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
 
         """
+        # removed caching to simplify the `generate_response` override
+        if cache:
+            raise NotImplementedError('Caching is not implemented for OpenAI')
+
         if config is None:
             config = LLMConfig()
 

@@ -75,25 +83,81 @@ class OpenAIClient(LLMClient):
         else:
             self.client = client
 
-    async def _generate_response(
+    async def _generate_response(
+        self, messages: list[Message], response_model: type[BaseModel] | None = None
+    ) -> dict[str, typing.Any]:
         openai_messages: list[ChatCompletionMessageParam] = []
         for m in messages:
+            m.content = self._clean_input(m.content)
             if m.role == 'user':
                 openai_messages.append({'role': 'user', 'content': m.content})
             elif m.role == 'system':
                 openai_messages.append({'role': 'system', 'content': m.content})
         try:
-            response = await self.client.chat.completions.
+            response = await self.client.beta.chat.completions.parse(
                 model=self.model or DEFAULT_MODEL,
                 messages=openai_messages,
                 temperature=self.temperature,
                 max_tokens=self.max_tokens,
-                response_format=
+                response_format=response_model,  # type: ignore
             )
-
-
+
+            response_object = response.choices[0].message
+
+            if response_object.parsed:
+                return response_object.parsed.model_dump()
+            elif response_object.refusal:
+                raise RefusalError(response_object.refusal)
+            else:
+                raise Exception(f'Invalid response from LLM: {response_object.model_dump()}')
+        except openai.LengthFinishReasonError as e:
+            raise Exception(f'Output length exceeded max tokens {self.max_tokens}: {e}') from e
         except openai.RateLimitError as e:
             raise RateLimitError from e
         except Exception as e:
             logger.error(f'Error in generating LLM response: {e}')
             raise
+
+    async def generate_response(
+        self, messages: list[Message], response_model: type[BaseModel] | None = None
+    ) -> dict[str, typing.Any]:
+        retry_count = 0
+        last_error = None
+
+        while retry_count <= self.MAX_RETRIES:
+            try:
+                response = await self._generate_response(messages, response_model)
+                return response
+            except (RateLimitError, RefusalError):
+                # These errors should not trigger retries
+                raise
+            except (openai.APITimeoutError, openai.APIConnectionError, openai.InternalServerError):
+                # Let OpenAI's client handle these retries
+                raise
+            except Exception as e:
+                last_error = e
+
+                # Don't retry if we've hit the max retries
+                if retry_count >= self.MAX_RETRIES:
+                    logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
+                    raise
+
+                retry_count += 1
+
+                # Construct a detailed error message for the LLM
+                error_context = (
+                    f'The previous response attempt was invalid. '
+                    f'Error type: {e.__class__.__name__}. '
+                    f'Error details: {str(e)}. '
+                    f'Please try again with a valid response, ensuring the output matches '
+                    f'the expected format and constraints.'
+                )
+
+                error_message = Message(role='user', content=error_context)
+                messages.append(error_message)
+                logger.warning(
+                    f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
+                )
+
+        # If we somehow get here, raise the last error
+        raise last_error or Exception('Max retries exceeded with no specific error')
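For orientation, a minimal usage sketch of the structured-output path added above. It is not part of the diff: the Summary model, the api_key value, and the asyncio wrapper are illustrative assumptions, while OpenAIClient, LLMConfig, Message, and generate_response(messages, response_model=...) come from the changed files.

# Sketch only (assumes graphiti-core 0.5.0 and a valid OpenAI key).
import asyncio

from pydantic import BaseModel, Field

from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.openai_client import OpenAIClient
from graphiti_core.prompts.models import Message


class Summary(BaseModel):  # hypothetical response model, not defined in graphiti_core
    summary: str = Field(..., description='one-sentence summary')


async def main() -> None:
    client = OpenAIClient(config=LLMConfig(api_key='sk-...'))
    messages = [
        Message(role='system', content='You summarize text.'),
        Message(role='user', content='Summarize: Graphiti builds temporally-aware knowledge graphs.'),
    ]
    # response_model is forwarded to beta.chat.completions.parse();
    # RateLimitError and RefusalError propagate immediately, other failures retry up to MAX_RETRIES.
    result = await client.generate_response(messages, response_model=Summary)
    print(result['summary'])


asyncio.run(main())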
graphiti_core/llm_client/openai_generic_client.py
ADDED

@@ -0,0 +1,163 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import json
+import logging
+import typing
+from typing import ClassVar
+
+import openai
+from openai import AsyncOpenAI
+from openai.types.chat import ChatCompletionMessageParam
+from pydantic import BaseModel
+
+from ..prompts.models import Message
+from .client import LLMClient
+from .config import LLMConfig
+from .errors import RateLimitError, RefusalError
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_MODEL = 'gpt-4o-mini'
+
+
+class OpenAIGenericClient(LLMClient):
+    """
+    OpenAIClient is a client class for interacting with OpenAI's language models.
+
+    This class extends the LLMClient and provides methods to initialize the client,
+    get an embedder, and generate responses from the language model.
+
+    Attributes:
+        client (AsyncOpenAI): The OpenAI client used to interact with the API.
+        model (str): The model name to use for generating responses.
+        temperature (float): The temperature to use for generating responses.
+        max_tokens (int): The maximum number of tokens to generate in a response.
+
+    Methods:
+        __init__(config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None):
+            Initializes the OpenAIClient with the provided configuration, cache setting, and client.
+
+        _generate_response(messages: list[Message]) -> dict[str, typing.Any]:
+            Generates a response from the language model based on the provided messages.
+    """
+
+    # Class-level constants
+    MAX_RETRIES: ClassVar[int] = 2
+
+    def __init__(
+        self, config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None
+    ):
+        """
+        Initialize the OpenAIClient with the provided configuration, cache setting, and client.
+
+        Args:
+            config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
+            cache (bool): Whether to use caching for responses. Defaults to False.
+            client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
+
+        """
+        # removed caching to simplify the `generate_response` override
+        if cache:
+            raise NotImplementedError('Caching is not implemented for OpenAI')
+
+        if config is None:
+            config = LLMConfig()
+
+        super().__init__(config, cache)
+
+        if client is None:
+            self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
+        else:
+            self.client = client
+
+    async def _generate_response(
+        self, messages: list[Message], response_model: type[BaseModel] | None = None
+    ) -> dict[str, typing.Any]:
+        openai_messages: list[ChatCompletionMessageParam] = []
+        for m in messages:
+            m.content = self._clean_input(m.content)
+            if m.role == 'user':
+                openai_messages.append({'role': 'user', 'content': m.content})
+            elif m.role == 'system':
+                openai_messages.append({'role': 'system', 'content': m.content})
+        try:
+            response = await self.client.chat.completions.create(
+                model=self.model or DEFAULT_MODEL,
+                messages=openai_messages,
+                temperature=self.temperature,
+                max_tokens=self.max_tokens,
+                response_format={'type': 'json_object'},
+            )
+            result = response.choices[0].message.content or ''
+            return json.loads(result)
+        except openai.RateLimitError as e:
+            raise RateLimitError from e
+        except Exception as e:
+            logger.error(f'Error in generating LLM response: {e}')
+            raise
+
+    async def generate_response(
+        self, messages: list[Message], response_model: type[BaseModel] | None = None
+    ) -> dict[str, typing.Any]:
+        retry_count = 0
+        last_error = None
+
+        if response_model is not None:
+            serialized_model = json.dumps(response_model.model_json_schema())
+            messages[
+                -1
+            ].content += (
+                f'\n\nRespond with a JSON object in the following format:\n\n{serialized_model}'
+            )
+
+        while retry_count <= self.MAX_RETRIES:
+            try:
+                response = await self._generate_response(messages, response_model)
+                return response
+            except (RateLimitError, RefusalError):
+                # These errors should not trigger retries
+                raise
+            except (openai.APITimeoutError, openai.APIConnectionError, openai.InternalServerError):
+                # Let OpenAI's client handle these retries
+                raise
+            except Exception as e:
+                last_error = e
+
+                # Don't retry if we've hit the max retries
+                if retry_count >= self.MAX_RETRIES:
+                    logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
+                    raise
+
+                retry_count += 1
+
+                # Construct a detailed error message for the LLM
+                error_context = (
+                    f'The previous response attempt was invalid. '
+                    f'Error type: {e.__class__.__name__}. '
+                    f'Error details: {str(e)}. '
+                    f'Please try again with a valid response, ensuring the output matches '
+                    f'the expected format and constraints.'
+                )
+
+                error_message = Message(role='user', content=error_context)
+                messages.append(error_message)
+                logger.warning(
+                    f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
+                )
+
+        # If we somehow get here, raise the last error
+        raise last_error or Exception('Max retries exceeded with no specific error')
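The generic client above targets OpenAI-compatible endpoints without native structured outputs: it stays on plain JSON mode and appends the Pydantic schema of response_model to the last message. A small sketch of that string handling, assuming a made-up ExtractedName model (only model_json_schema() and the appended instruction mirror the code above):

# Sketch: how OpenAIGenericClient.generate_response folds a response model into the prompt.
import json

from pydantic import BaseModel, Field


class ExtractedName(BaseModel):  # illustrative model, not part of graphiti_core
    name: str = Field(..., description='name mentioned in the text')


prompt = 'Which person is mentioned? Text: "Alice joined Zep."'
schema = json.dumps(ExtractedName.model_json_schema())
# Same suffix that generate_response appends to messages[-1].content:
prompt += f'\n\nRespond with a JSON object in the following format:\n\n{schema}'
print(prompt)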
graphiti_core/nodes.py
CHANGED
@@ -16,7 +16,7 @@ limitations under the License.
 
 import logging
 from abc import ABC, abstractmethod
-from datetime import datetime
+from datetime import datetime
 from enum import Enum
 from time import time
 from typing import Any

@@ -24,6 +24,7 @@ from uuid import uuid4
 
 from neo4j import AsyncDriver
 from pydantic import BaseModel, Field
+from typing_extensions import LiteralString
 
 from graphiti_core.embedder import EmbedderClient
 from graphiti_core.errors import NodeNotFoundError

@@ -33,6 +34,7 @@ from graphiti_core.models.nodes.node_db_queries import (
     ENTITY_NODE_SAVE,
     EPISODIC_NODE_SAVE,
 )
+from graphiti_core.utils.datetime_utils import utc_now
 
 logger = logging.getLogger(__name__)
 

@@ -78,7 +80,7 @@ class Node(BaseModel, ABC):
     name: str = Field(description='name of the node')
     group_id: str = Field(description='partition of the graph')
     labels: list[str] = Field(default_factory=list)
-    created_at: datetime = Field(default_factory=lambda:
+    created_at: datetime = Field(default_factory=lambda: utc_now())
 
     @abstractmethod
     async def save(self, driver: AsyncDriver): ...

@@ -207,10 +209,22 @@ class EpisodicNode(Node):
         return episodes
 
     @classmethod
-    async def get_by_group_ids(
+    async def get_by_group_ids(
+        cls,
+        driver: AsyncDriver,
+        group_ids: list[str],
+        limit: int | None = None,
+        created_at: datetime | None = None,
+    ):
+        cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
+        limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
+
         records, _, _ = await driver.execute_query(
             """
             MATCH (e:Episodic) WHERE e.group_id IN $group_ids
+            """
+            + cursor_query
+            + """
             RETURN DISTINCT
             e.content AS content,
             e.created_at AS created_at,

@@ -220,8 +234,12 @@ class EpisodicNode(Node):
             e.group_id AS group_id,
             e.source_description AS source_description,
             e.source AS source
-
+            ORDER BY e.uuid DESC
+            """
+            + limit_query,
             group_ids=group_ids,
+            created_at=created_at,
+            limit=limit,
             database_=DEFAULT_DATABASE,
             routing_='r',
         )

@@ -308,10 +326,22 @@ class EntityNode(Node):
         return nodes
 
     @classmethod
-    async def get_by_group_ids(
+    async def get_by_group_ids(
+        cls,
+        driver: AsyncDriver,
+        group_ids: list[str],
+        limit: int | None = None,
+        created_at: datetime | None = None,
+    ):
+        cursor_query: LiteralString = 'AND n.created_at < $created_at' if created_at else ''
+        limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
+
         records, _, _ = await driver.execute_query(
             """
             MATCH (n:Entity) WHERE n.group_id IN $group_ids
+            """
+            + cursor_query
+            + """
             RETURN
             n.uuid As uuid,
             n.name AS name,

@@ -319,8 +349,12 @@ class EntityNode(Node):
             n.group_id AS group_id,
             n.created_at AS created_at,
             n.summary AS summary
-
+            ORDER BY n.uuid DESC
+            """
+            + limit_query,
             group_ids=group_ids,
+            created_at=created_at,
+            limit=limit,
             database_=DEFAULT_DATABASE,
             routing_='r',
         )

@@ -407,10 +441,22 @@ class CommunityNode(Node):
         return communities
 
     @classmethod
-    async def get_by_group_ids(
+    async def get_by_group_ids(
+        cls,
+        driver: AsyncDriver,
+        group_ids: list[str],
+        limit: int | None = None,
+        created_at: datetime | None = None,
+    ):
+        cursor_query: LiteralString = 'AND n.created_at < $created_at' if created_at else ''
+        limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
+
         records, _, _ = await driver.execute_query(
             """
             MATCH (n:Community) WHERE n.group_id IN $group_ids
+            """
+            + cursor_query
+            + """
             RETURN
             n.uuid As uuid,
             n.name AS name,

@@ -418,8 +464,12 @@ class CommunityNode(Node):
             n.group_id AS group_id,
             n.created_at AS created_at,
             n.summary AS summary
-
+            ORDER BY n.uuid DESC
+            """
+            + limit_query,
             group_ids=group_ids,
+            created_at=created_at,
+            limit=limit,
             database_=DEFAULT_DATABASE,
             routing_='r',
         )
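The new limit and created_at parameters make get_by_group_ids pageable: results are ordered by uuid and, when created_at is given, restricted to rows created before that timestamp. A hedged sketch of fetching one page of entities (connection details, group id, and page size are placeholders):

# Sketch only: assumes a reachable Neo4j instance and existing Entity nodes.
from datetime import datetime, timezone

from neo4j import AsyncGraphDatabase

from graphiti_core.nodes import EntityNode


async def first_entity_page(uri: str, user: str, password: str):
    driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
    try:
        # limit and created_at are the optional parameters added in 0.5.0
        return await EntityNode.get_by_group_ids(
            driver,
            group_ids=['my-group'],                 # placeholder group id
            limit=25,                               # placeholder page size
            created_at=datetime.now(timezone.utc),  # only rows created before this time
        )
    finally:
        await driver.close()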
graphiti_core/prompts/dedupe_edges.py
CHANGED

@@ -15,11 +15,30 @@ limitations under the License.
 """
 
 import json
-from typing import Any, Protocol, TypedDict
+from typing import Any, Optional, Protocol, TypedDict
+
+from pydantic import BaseModel, Field
 
 from .models import Message, PromptFunction, PromptVersion
 
 
+class EdgeDuplicate(BaseModel):
+    is_duplicate: bool = Field(..., description='true or false')
+    uuid: Optional[str] = Field(
+        None,
+        description="uuid of the existing edge like '5d643020624c42fa9de13f97b1b3fa39' or null",
+    )
+
+
+class UniqueFact(BaseModel):
+    uuid: str = Field(..., description='unique identifier of the fact')
+    fact: str = Field(..., description='fact of a unique edge')
+
+
+class UniqueFacts(BaseModel):
+    unique_facts: list[UniqueFact]
+
+
 class Prompt(Protocol):
     edge: PromptVersion
     edge_list: PromptVersion

@@ -56,12 +75,6 @@ def edge(context: dict[str, Any]) -> list[Message]:
 
         Guidelines:
         1. The facts do not need to be completely identical to be duplicates, they just need to express the same information.
-
-        Respond with a JSON object in the following format:
-        {{
-            "is_duplicate": true or false,
-            "uuid": uuid of the existing edge like "5d643020624c42fa9de13f97b1b3fa39" or null,
-        }}
         """,
         ),
     ]

@@ -90,16 +103,6 @@ def edge_list(context: dict[str, Any]) -> list[Message]:
         3. Facts will often discuss the same or similar relation between identical entities
         4. The final list should have only unique facts. If 3 facts are all duplicates of each other, only one of their
         facts should be in the response
-
-        Respond with a JSON object in the following format:
-        {{
-            "unique_facts": [
-                {{
-                    "uuid": "unique identifier of the fact",
-                    "fact": "fact of a unique edge"
-                }}
-            ]
-        }}
         """,
         ),
     ]
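These Pydantic models take over from the inline "Respond with a JSON object" blocks removed from the prompt bodies: the expected shape is now expressed as a response_model rather than prompt text. A brief sketch of validating an LLM reply against EdgeDuplicate (the raw_response dict is fabricated for illustration):

# Sketch: validating a generate_response() result with the new prompt model.
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate

raw_response = {'is_duplicate': True, 'uuid': '5d643020624c42fa9de13f97b1b3fa39'}  # made-up reply
duplicate = EdgeDuplicate(**raw_response)
if duplicate.is_duplicate and duplicate.uuid:
    print(f'Fact already covered by edge {duplicate.uuid}')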
graphiti_core/prompts/dedupe_nodes.py
CHANGED

@@ -15,11 +15,25 @@ limitations under the License.
 """
 
 import json
-from typing import Any, Protocol, TypedDict
+from typing import Any, Optional, Protocol, TypedDict
+
+from pydantic import BaseModel, Field
 
 from .models import Message, PromptFunction, PromptVersion
 
 
+class NodeDuplicate(BaseModel):
+    is_duplicate: bool = Field(..., description='true or false')
+    uuid: Optional[str] = Field(
+        None,
+        description="uuid of the existing node like '5d643020624c42fa9de13f97b1b3fa39' or null",
+    )
+    name: str = Field(
+        ...,
+        description="Updated name of the new node (use the best name between the new node's name, an existing duplicate name, or a combination of both)",
+    )
+
+
 class Prompt(Protocol):
     node: PromptVersion
     node_list: PromptVersion
graphiti_core/prompts/eval.py
CHANGED
@@ -17,9 +17,26 @@ limitations under the License.
 import json
 from typing import Any, Protocol, TypedDict
 
+from pydantic import BaseModel, Field
+
 from .models import Message, PromptFunction, PromptVersion
 
 
+class QueryExpansion(BaseModel):
+    query: str = Field(..., description='query optimized for database search')
+
+
+class QAResponse(BaseModel):
+    ANSWER: str = Field(..., description='how Alice would answer the question')
+
+
+class EvalResponse(BaseModel):
+    is_correct: bool = Field(..., description='boolean if the answer is correct or incorrect')
+    reasoning: str = Field(
+        ..., description='why you determined the response was correct or incorrect'
+    )
+
+
 class Prompt(Protocol):
     qa_prompt: PromptVersion
     eval_prompt: PromptVersion

@@ -41,10 +58,6 @@ def query_expansion(context: dict[str, Any]) -> list[Message]:
         <QUESTION>
         {json.dumps(context['query'])}
         </QUESTION>
-        respond with a JSON object in the following format:
-        {{
-            "query": "query optimized for database search"
-        }}
         """
     return [
         Message(role='system', content=sys_prompt),

@@ -67,10 +80,6 @@ def qa_prompt(context: dict[str, Any]) -> list[Message]:
         <QUESTION>
         {context['query']}
         </QUESTION>
-        respond with a JSON object in the following format:
-        {{
-            "ANSWER": "how Alice would answer the question"
-        }}
         """
     return [
         Message(role='system', content=sys_prompt),

@@ -96,12 +105,6 @@ def eval_prompt(context: dict[str, Any]) -> list[Message]:
         <RESPONSE>
         {context['response']}
         </RESPONSE>
-
-        respond with a JSON object in the following format:
-        {{
-            "is_correct": "boolean if the answer is correct or incorrect"
-            "reasoning": "why you determined the response was correct or incorrect"
-        }}
         """
     return [
         Message(role='system', content=sys_prompt),
graphiti_core/prompts/extract_edge_dates.py
CHANGED

@@ -14,11 +14,24 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-from typing import Any, Protocol, TypedDict
+from typing import Any, Optional, Protocol, TypedDict
+
+from pydantic import BaseModel, Field
 
 from .models import Message, PromptFunction, PromptVersion
 
 
+class EdgeDates(BaseModel):
+    valid_at: Optional[str] = Field(
+        None,
+        description='The date and time when the relationship described by the edge fact became true or was established. YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null.',
+    )
+    invalid_at: Optional[str] = Field(
+        None,
+        description='The date and time when the relationship described by the edge fact stopped being true or ended. YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null.',
+    )
+
+
 class Prompt(Protocol):
     v1: PromptVersion
 

@@ -60,7 +73,7 @@ def v1(context: dict[str, Any]) -> list[Message]:
         Analyze the conversation and determine if there are dates that are part of the edge fact. Only set dates if they explicitly relate to the formation or alteration of the relationship itself.
 
         Guidelines:
-        1. Use ISO 8601 format (YYYY-MM-DDTHH:MM:
+        1. Use ISO 8601 format (YYYY-MM-DDTHH:MM:SS.SSSSSSZ) for datetimes.
         2. Use the reference timestamp as the current time when determining the valid_at and invalid_at dates.
         3. If the fact is written in the present tense, use the Reference Timestamp for the valid_at date
         4. If no temporal information is found that establishes or changes the relationship, leave the fields as null.

@@ -69,11 +82,6 @@ def v1(context: dict[str, Any]) -> list[Message]:
         7. If only a date is mentioned without a specific time, use 00:00:00 (midnight) for that date.
         8. If only year is mentioned, use January 1st of that year at 00:00:00.
         9. Always include the time zone offset (use Z for UTC if no specific time zone is mentioned).
-        Respond with a JSON object:
-        {{
-            "valid_at": "YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null",
-            "invalid_at": "YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null",
-        }}
         """,
         ),
     ]