graphiti-core 0.4.3__tar.gz → 0.5.0rc2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (58)
  1. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/PKG-INFO +1 -1
  2. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/cross_encoder/client.py +1 -1
  3. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/llm_client/anthropic_client.py +4 -1
  4. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/llm_client/client.py +20 -5
  5. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/llm_client/errors.py +8 -0
  6. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/llm_client/groq_client.py +4 -1
  7. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/llm_client/openai_client.py +29 -7
  8. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/prompts/dedupe_edges.py +20 -17
  9. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/prompts/dedupe_nodes.py +15 -1
  10. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/prompts/eval.py +17 -14
  11. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/prompts/extract_edge_dates.py +15 -7
  12. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/prompts/extract_edges.py +18 -19
  13. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/prompts/extract_nodes.py +11 -21
  14. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/prompts/invalidate_edges.py +13 -25
  15. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/prompts/summarize_nodes.py +17 -16
  16. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/utils/maintenance/community_operations.py +4 -2
  17. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/utils/maintenance/edge_operations.py +8 -4
  18. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/utils/maintenance/node_operations.py +14 -7
  19. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/utils/maintenance/temporal_operations.py +8 -2
  20. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/pyproject.toml +1 -1
  21. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/LICENSE +0 -0
  22. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/README.md +0 -0
  23. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/__init__.py +0 -0
  24. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/cross_encoder/__init__.py +0 -0
  25. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/cross_encoder/bge_reranker_client.py +0 -0
  26. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/cross_encoder/openai_reranker_client.py +0 -0
  27. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/edges.py +0 -0
  28. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/embedder/__init__.py +0 -0
  29. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/embedder/client.py +0 -0
  30. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/embedder/openai.py +0 -0
  31. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/embedder/voyage.py +0 -0
  32. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/errors.py +0 -0
  33. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/graphiti.py +0 -0
  34. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/helpers.py +0 -0
  35. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/llm_client/__init__.py +0 -0
  36. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/llm_client/config.py +0 -0
  37. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/llm_client/utils.py +0 -0
  38. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/models/__init__.py +0 -0
  39. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/models/edges/__init__.py +0 -0
  40. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/models/edges/edge_db_queries.py +0 -0
  41. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/models/nodes/__init__.py +0 -0
  42. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/models/nodes/node_db_queries.py +0 -0
  43. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/nodes.py +0 -0
  44. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/prompts/__init__.py +0 -0
  45. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/prompts/lib.py +0 -0
  46. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/prompts/models.py +0 -0
  47. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/prompts/prompt_helpers.py +0 -0
  48. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/py.typed +0 -0
  49. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/search/__init__.py +0 -0
  50. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/search/search.py +0 -0
  51. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/search/search_config.py +0 -0
  52. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/search/search_config_recipes.py +0 -0
  53. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/search/search_utils.py +0 -0
  54. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/utils/__init__.py +0 -0
  55. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/utils/bulk_utils.py +0 -0
  56. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/utils/maintenance/__init__.py +0 -0
  57. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/utils/maintenance/graph_data_operations.py +0 -0
  58. {graphiti_core-0.4.3 → graphiti_core-0.5.0rc2}/graphiti_core/utils/maintenance/utils.py +0 -0

PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: graphiti-core
- Version: 0.4.3
+ Version: 0.5.0rc2
  Summary: A temporal graph building library
  License: Apache-2.0
  Author: Paul Paliychuk

graphiti_core/cross_encoder/client.py

@@ -34,7 +34,7 @@ class CrossEncoderClient(ABC):
  passages (list[str]): A list of passages to rank.

  Returns:
- List[tuple[str, float]]: A list of tuples containing the passage and its score,
+ list[tuple[str, float]]: A list of tuples containing the passage and its score,
  sorted in descending order of relevance.
  """
  pass

graphiti_core/llm_client/anthropic_client.py

@@ -20,6 +20,7 @@ import typing

  import anthropic
  from anthropic import AsyncAnthropic
+ from pydantic import BaseModel

  from ..prompts.models import Message
  from .client import LLMClient

@@ -46,7 +47,9 @@ class AnthropicClient(LLMClient):
  max_retries=1,
  )

- async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
+ async def _generate_response(
+ self, messages: list[Message], response_model: type[BaseModel] | None = None
+ ) -> dict[str, typing.Any]:
  system_message = messages[0]
  user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]] + [
  {'role': 'assistant', 'content': '{'}

graphiti_core/llm_client/client.py

@@ -22,6 +22,7 @@ from abc import ABC, abstractmethod

  import httpx
  from diskcache import Cache
+ from pydantic import BaseModel
  from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential

  from ..prompts.models import Message

@@ -66,14 +67,18 @@ class LLMClient(ABC):
  else None,
  reraise=True,
  )
- async def _generate_response_with_retry(self, messages: list[Message]) -> dict[str, typing.Any]:
+ async def _generate_response_with_retry(
+ self, messages: list[Message], response_model: type[BaseModel] | None = None
+ ) -> dict[str, typing.Any]:
  try:
- return await self._generate_response(messages)
+ return await self._generate_response(messages, response_model)
  except (httpx.HTTPStatusError, RateLimitError) as e:
  raise e

  @abstractmethod
- async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
+ async def _generate_response(
+ self, messages: list[Message], response_model: type[BaseModel] | None = None
+ ) -> dict[str, typing.Any]:
  pass

  def _get_cache_key(self, messages: list[Message]) -> str:

@@ -82,7 +87,17 @@ class LLMClient(ABC):
  key_str = f'{self.model}:{message_str}'
  return hashlib.md5(key_str.encode()).hexdigest()

- async def generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
+ async def generate_response(
+ self, messages: list[Message], response_model: type[BaseModel] | None = None
+ ) -> dict[str, typing.Any]:
+ if response_model is not None:
+ serialized_model = json.dumps(response_model.model_json_schema())
+ messages[
+ -1
+ ].content += (
+ f'\n\nRespond with a JSON object in the following format:\n\n{serialized_model}'
+ )
+
  if self.cache_enabled:
  cache_key = self._get_cache_key(messages)

@@ -91,7 +106,7 @@
  logger.debug(f'Cache hit for {cache_key}')
  return cached_response

- response = await self._generate_response_with_retry(messages)
+ response = await self._generate_response_with_retry(messages, response_model)

  if self.cache_enabled:
  self.cache_dir.set(cache_key, response)
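
The base class now injects the Pydantic model's JSON schema into the final message whenever a response_model is passed. A minimal sketch of that injection in isolation; the Greeting model below is hypothetical and only illustrates the pattern shown in the hunk above:

import json

from pydantic import BaseModel, Field


class Greeting(BaseModel):
    text: str = Field(..., description='a short greeting')


# What generate_response appends to the last message when response_model is given:
serialized_model = json.dumps(Greeting.model_json_schema())
instruction = f'\n\nRespond with a JSON object in the following format:\n\n{serialized_model}'
print(instruction)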

graphiti_core/llm_client/errors.py

@@ -21,3 +21,11 @@ class RateLimitError(Exception):
  def __init__(self, message='Rate limit exceeded. Please try again later.'):
  self.message = message
  super().__init__(self.message)
+
+
+ class RefusalError(Exception):
+ """Exception raised when the LLM refuses to generate a response."""
+
+ def __init__(self, message: str):
+ self.message = message
+ super().__init__(self.message)
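
With the new exception, callers can separate a refusal from a transient rate limit. A sketch of how downstream code might handle the two cases; the safe_generate helper is hypothetical, only the exception classes and generate_response signature come from this diff:

from graphiti_core.llm_client.errors import RateLimitError, RefusalError


async def safe_generate(client, messages, response_model=None):
    try:
        return await client.generate_response(messages, response_model)
    except RefusalError as exc:
        # The model declined to answer; retrying the same prompt is unlikely to help.
        return {'refusal': exc.message}
    except RateLimitError:
        # Transient; let the caller back off and retry.
        raise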

graphiti_core/llm_client/groq_client.py

@@ -21,6 +21,7 @@ import typing
  import groq
  from groq import AsyncGroq
  from groq.types.chat import ChatCompletionMessageParam
+ from pydantic import BaseModel

  from ..prompts.models import Message
  from .client import LLMClient

@@ -43,7 +44,9 @@ class GroqClient(LLMClient):

  self.client = AsyncGroq(api_key=config.api_key)

- async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
+ async def _generate_response(
+ self, messages: list[Message], response_model: type[BaseModel] | None = None
+ ) -> dict[str, typing.Any]:
  msgs: list[ChatCompletionMessageParam] = []
  for m in messages:
  if m.role == 'user':

graphiti_core/llm_client/openai_client.py

@@ -14,18 +14,18 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """

- import json
  import logging
  import typing

  import openai
  from openai import AsyncOpenAI
  from openai.types.chat import ChatCompletionMessageParam
+ from pydantic import BaseModel

  from ..prompts.models import Message
  from .client import LLMClient
  from .config import LLMConfig
- from .errors import RateLimitError
+ from .errors import RateLimitError, RefusalError

  logger = logging.getLogger(__name__)

@@ -65,6 +65,10 @@ class OpenAIClient(LLMClient):
  client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.

  """
+ # removed caching to simplify the `generate_response` override
+ if cache:
+ raise NotImplementedError('Caching is not implemented for OpenAI')
+
  if config is None:
  config = LLMConfig()

@@ -75,7 +79,9 @@
  else:
  self.client = client

- async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
+ async def _generate_response(
+ self, messages: list[Message], response_model: type[BaseModel] | None = None
+ ) -> dict[str, typing.Any]:
  openai_messages: list[ChatCompletionMessageParam] = []
  for m in messages:
  if m.role == 'user':

@@ -83,17 +89,33 @@
  elif m.role == 'system':
  openai_messages.append({'role': 'system', 'content': m.content})
  try:
- response = await self.client.chat.completions.create(
+ response = await self.client.beta.chat.completions.parse(
  model=self.model or DEFAULT_MODEL,
  messages=openai_messages,
  temperature=self.temperature,
  max_tokens=self.max_tokens,
- response_format={'type': 'json_object'},
+ response_format=response_model, # type: ignore
  )
- result = response.choices[0].message.content or ''
- return json.loads(result)
+
+ response_object = response.choices[0].message
+
+ if response_object.parsed:
+ return response_object.parsed.model_dump()
+ elif response_object.refusal:
+ raise RefusalError(response_object.refusal)
+ else:
+ raise Exception('No response from LLM')
+ except openai.LengthFinishReasonError as e:
+ raise Exception(f'Output length exceeded max tokens {self.max_tokens}: {e}') from e
  except openai.RateLimitError as e:
  raise RateLimitError from e
  except Exception as e:
  logger.error(f'Error in generating LLM response: {e}')
  raise
+
+ async def generate_response(
+ self, messages: list[Message], response_model: type[BaseModel] | None = None
+ ) -> dict[str, typing.Any]:
+ response = await self._generate_response(messages, response_model)
+
+ return response
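
For reference, the structured-output call the client now relies on looks roughly like this when used standalone. This is a sketch, not code from the package: the model name and prompt are illustrative, and an OPENAI_API_KEY is assumed in the environment.

from openai import AsyncOpenAI
from pydantic import BaseModel


class Answer(BaseModel):
    text: str


async def ask() -> str:
    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
    response = await client.beta.chat.completions.parse(
        model='gpt-4o-mini',  # illustrative model name
        messages=[{'role': 'user', 'content': 'Reply with a short greeting.'}],
        response_format=Answer,
    )
    message = response.choices[0].message
    if message.parsed:
        return message.parsed.text
    raise RuntimeError(message.refusal or 'No response from LLM')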

graphiti_core/prompts/dedupe_edges.py

@@ -15,11 +15,30 @@ limitations under the License.
  """

  import json
- from typing import Any, Protocol, TypedDict
+ from typing import Any, Optional, Protocol, TypedDict
+
+ from pydantic import BaseModel, Field

  from .models import Message, PromptFunction, PromptVersion


+ class EdgeDuplicate(BaseModel):
+ is_duplicate: bool = Field(..., description='true or false')
+ uuid: Optional[str] = Field(
+ None,
+ description="uuid of the existing edge like '5d643020624c42fa9de13f97b1b3fa39' or null",
+ )
+
+
+ class UniqueFact(BaseModel):
+ uuid: str = Field(..., description='unique identifier of the fact')
+ fact: str = Field(..., description='fact of a unique edge')
+
+
+ class UniqueFacts(BaseModel):
+ unique_facts: list[UniqueFact]
+
+
  class Prompt(Protocol):
  edge: PromptVersion
  edge_list: PromptVersion

@@ -56,12 +75,6 @@ def edge(context: dict[str, Any]) -> list[Message]:

  Guidelines:
  1. The facts do not need to be completely identical to be duplicates, they just need to express the same information.
-
- Respond with a JSON object in the following format:
- {{
- "is_duplicate": true or false,
- "uuid": uuid of the existing edge like "5d643020624c42fa9de13f97b1b3fa39" or null,
- }}
  """,
  ),
  ]

@@ -90,16 +103,6 @@ def edge_list(context: dict[str, Any]) -> list[Message]:
  3. Facts will often discuss the same or similar relation between identical entities
  4. The final list should have only unique facts. If 3 facts are all duplicates of each other, only one of their
  facts should be in the response
-
- Respond with a JSON object in the following format:
- {{
- "unique_facts": [
- {{
- "uuid": "unique identifier of the fact",
- "fact": "fact of a unique edge"
- }}
- ]
- }}
  """,
  ),
  ]

graphiti_core/prompts/dedupe_nodes.py

@@ -15,11 +15,25 @@ limitations under the License.
  """

  import json
- from typing import Any, Protocol, TypedDict
+ from typing import Any, Optional, Protocol, TypedDict
+
+ from pydantic import BaseModel, Field

  from .models import Message, PromptFunction, PromptVersion


+ class NodeDuplicate(BaseModel):
+ is_duplicate: bool = Field(..., description='true or false')
+ uuid: Optional[str] = Field(
+ None,
+ description="uuid of the existing node like '5d643020624c42fa9de13f97b1b3fa39' or null",
+ )
+ name: str = Field(
+ ...,
+ description="Updated name of the new node (use the best name between the new node's name, an existing duplicate name, or a combination of both)",
+ )
+
+
  class Prompt(Protocol):
  node: PromptVersion
  node_list: PromptVersion

graphiti_core/prompts/eval.py

@@ -17,9 +17,26 @@ limitations under the License.
  import json
  from typing import Any, Protocol, TypedDict

+ from pydantic import BaseModel, Field
+
  from .models import Message, PromptFunction, PromptVersion


+ class QueryExpansion(BaseModel):
+ query: str = Field(..., description='query optimized for database search')
+
+
+ class QAResponse(BaseModel):
+ ANSWER: str = Field(..., description='how Alice would answer the question')
+
+
+ class EvalResponse(BaseModel):
+ is_correct: bool = Field(..., description='boolean if the answer is correct or incorrect')
+ reasoning: str = Field(
+ ..., description='why you determined the response was correct or incorrect'
+ )
+
+
  class Prompt(Protocol):
  qa_prompt: PromptVersion
  eval_prompt: PromptVersion

@@ -41,10 +58,6 @@ def query_expansion(context: dict[str, Any]) -> list[Message]:
  <QUESTION>
  {json.dumps(context['query'])}
  </QUESTION>
- respond with a JSON object in the following format:
- {{
- "query": "query optimized for database search"
- }}
  """
  return [
  Message(role='system', content=sys_prompt),

@@ -67,10 +80,6 @@ def qa_prompt(context: dict[str, Any]) -> list[Message]:
  <QUESTION>
  {context['query']}
  </QUESTION>
- respond with a JSON object in the following format:
- {{
- "ANSWER": "how Alice would answer the question"
- }}
  """
  return [
  Message(role='system', content=sys_prompt),

@@ -96,12 +105,6 @@ def eval_prompt(context: dict[str, Any]) -> list[Message]:
  <RESPONSE>
  {context['response']}
  </RESPONSE>
-
- respond with a JSON object in the following format:
- {{
- "is_correct": "boolean if the answer is correct or incorrect"
- "reasoning": "why you determined the response was correct or incorrect"
- }}
  """
  return [
  Message(role='system', content=sys_prompt),

graphiti_core/prompts/extract_edge_dates.py

@@ -14,11 +14,24 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """

- from typing import Any, Protocol, TypedDict
+ from typing import Any, Optional, Protocol, TypedDict
+
+ from pydantic import BaseModel, Field

  from .models import Message, PromptFunction, PromptVersion


+ class EdgeDates(BaseModel):
+ valid_at: Optional[str] = Field(
+ None,
+ description='The date and time when the relationship described by the edge fact became true or was established. YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null.',
+ )
+ invalid_at: Optional[str] = Field(
+ None,
+ description='The date and time when the relationship described by the edge fact stopped being true or ended. YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null.',
+ )
+
+
  class Prompt(Protocol):
  v1: PromptVersion

@@ -60,7 +73,7 @@ def v1(context: dict[str, Any]) -> list[Message]:
  Analyze the conversation and determine if there are dates that are part of the edge fact. Only set dates if they explicitly relate to the formation or alteration of the relationship itself.

  Guidelines:
- 1. Use ISO 8601 format (YYYY-MM-DDTHH:MM:SSZ) for datetimes.
+ 1. Use ISO 8601 format (YYYY-MM-DDTHH:MM:SS.SSSSSSZ) for datetimes.
  2. Use the reference timestamp as the current time when determining the valid_at and invalid_at dates.
  3. If the fact is written in the present tense, use the Reference Timestamp for the valid_at date
  4. If no temporal information is found that establishes or changes the relationship, leave the fields as null.

@@ -69,11 +82,6 @@ def v1(context: dict[str, Any]) -> list[Message]:
  7. If only a date is mentioned without a specific time, use 00:00:00 (midnight) for that date.
  8. If only year is mentioned, use January 1st of that year at 00:00:00.
  9. Always include the time zone offset (use Z for UTC if no specific time zone is mentioned).
- Respond with a JSON object:
- {{
- "valid_at": "YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null",
- "invalid_at": "YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null",
- }}
  """,
  ),
  ]
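
EdgeDates keeps both timestamps as strings in the prompt's documented ISO 8601 form (YYYY-MM-DDTHH:MM:SS.SSSSSSZ). A small sketch of turning such a value back into an aware datetime; the sample value is illustrative, not from the package:

from datetime import datetime, timezone

valid_at = '2024-08-26T00:00:00.000000Z'  # illustrative value in the documented format
# datetime.fromisoformat does not accept a trailing 'Z' on older Pythons, so map it first.
parsed = datetime.fromisoformat(valid_at.replace('Z', '+00:00'))
assert parsed.tzinfo == timezone.utc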

graphiti_core/prompts/extract_edges.py

@@ -17,9 +17,26 @@ limitations under the License.
  import json
  from typing import Any, Protocol, TypedDict

+ from pydantic import BaseModel, Field
+
  from .models import Message, PromptFunction, PromptVersion


+ class Edge(BaseModel):
+ relation_type: str = Field(..., description='RELATION_TYPE_IN_CAPS')
+ source_entity_name: str = Field(..., description='name of the source entity')
+ target_entity_name: str = Field(..., description='name of the target entity')
+ fact: str = Field(..., description='extracted factual information')
+
+
+ class ExtractedEdges(BaseModel):
+ edges: list[Edge]
+
+
+ class MissingFacts(BaseModel):
+ missing_facts: list[str] = Field(..., description="facts that weren't extracted")
+
+
  class Prompt(Protocol):
  edge: PromptVersion
  reflexion: PromptVersion

@@ -54,25 +71,12 @@ def edge(context: dict[str, Any]) -> list[Message]:

  Given the above MESSAGES and ENTITIES, extract all facts pertaining to the listed ENTITIES from the CURRENT MESSAGE.

-
  Guidelines:
  1. Extract facts only between the provided entities.
  2. Each fact should represent a clear relationship between two DISTINCT nodes.
  3. The relation_type should be a concise, all-caps description of the fact (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
  4. Provide a more detailed fact containing all relevant information.
  5. Consider temporal aspects of relationships when relevant.
-
- Respond with a JSON object in the following format:
- {{
- "edges": [
- {{
- "relation_type": "RELATION_TYPE_IN_CAPS",
- "source_entity_name": "name of the source entity",
- "target_entity_name": "name of the target entity",
- "fact": "extracted factual information",
- }}
- ]
- }}
  """,
  ),
  ]

@@ -98,12 +102,7 @@ def reflexion(context: dict[str, Any]) -> list[Message]:
  </EXTRACTED FACTS>

  Given the above MESSAGES, list of EXTRACTED ENTITIES entities, and list of EXTRACTED FACTS;
- determine if any facts haven't been extracted:
-
- Respond with a JSON object in the following format:
- {{
- "missing_facts": [ "facts that weren't extracted", ...]
- }}
+ determine if any facts haven't been extracted.
  """
  return [
  Message(role='system', content=sys_prompt),
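
Because the expected output shape now lives in Pydantic models instead of inline JSON templates, the same models can validate what the LLM returns. A sketch with fabricated sample data; only ExtractedEdges and its fields come from the diff above:

from graphiti_core.prompts.extract_edges import ExtractedEdges

llm_response = {  # illustrative payload shaped like the models above
    'edges': [
        {
            'relation_type': 'WORKS_FOR',
            'source_entity_name': 'Alice',
            'target_entity_name': 'Acme Corp',
            'fact': 'Alice works for Acme Corp',
        }
    ]
}
edges = ExtractedEdges.model_validate(llm_response).edges
print(edges[0].relation_type)  # WORKS_FOR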

graphiti_core/prompts/extract_nodes.py

@@ -17,9 +17,19 @@ limitations under the License.
  import json
  from typing import Any, Protocol, TypedDict

+ from pydantic import BaseModel, Field
+
  from .models import Message, PromptFunction, PromptVersion


+ class ExtractedNodes(BaseModel):
+ extracted_node_names: list[str] = Field(..., description='Name of the extracted entity')
+
+
+ class MissedEntities(BaseModel):
+ missed_entities: list[str] = Field(..., description="Names of entities that weren't extracted")
+
+
  class Prompt(Protocol):
  extract_message: PromptVersion
  extract_json: PromptVersion

@@ -56,11 +66,6 @@ Guidelines:
  4. DO NOT create nodes for temporal information like dates, times or years (these will be added to edges later).
  5. Be as explicit as possible in your node names, using full names.
  6. DO NOT extract entities mentioned only in PREVIOUS MESSAGES, those messages are only to provide context.
-
- Respond with a JSON object in the following format:
- {{
- "extracted_node_names": ["Name of the extracted entity", ...],
- }}
  """
  return [
  Message(role='system', content=sys_prompt),

@@ -87,11 +92,6 @@ Given the above source description and JSON, extract relevant entity nodes from
  Guidelines:
  1. Always try to extract an entities that the JSON represents. This will often be something like a "name" or "user field
  2. Do NOT extract any properties that contain dates
-
- Respond with a JSON object in the following format:
- {{
- "extracted_node_names": ["Name of the extracted entity", ...],
- }}
  """
  return [
  Message(role='system', content=sys_prompt),

@@ -116,11 +116,6 @@ Guidelines:
  2. Avoid creating nodes for relationships or actions.
  3. Avoid creating nodes for temporal information like dates, times or years (these will be added to edges later).
  4. Be as explicit as possible in your node names, using full names and avoiding abbreviations.
-
- Respond with a JSON object in the following format:
- {{
- "extracted_node_names": ["Name of the extracted entity", ...],
- }}
  """
  return [
  Message(role='system', content=sys_prompt),

@@ -144,12 +139,7 @@ def reflexion(context: dict[str, Any]) -> list[Message]:
  </EXTRACTED ENTITIES>

  Given the above previous messages, current message, and list of extracted entities; determine if any entities haven't been
- extracted:
-
- Respond with a JSON object in the following format:
- {{
- "missed_entities": [ "name of entity that wasn't extracted", ...]
- }}
+ extracted.
  """
  return [
  Message(role='system', content=sys_prompt),

graphiti_core/prompts/invalidate_edges.py

@@ -16,9 +16,22 @@ limitations under the License.

  from typing import Any, Protocol, TypedDict

+ from pydantic import BaseModel, Field
+
  from .models import Message, PromptFunction, PromptVersion


+ class InvalidatedEdge(BaseModel):
+ uuid: str = Field(..., description='The UUID of the edge to be invalidated')
+ fact: str = Field(..., description='Updated fact of the edge')
+
+
+ class InvalidatedEdges(BaseModel):
+ invalidated_edges: list[InvalidatedEdge] = Field(
+ ..., description='List of edges that should be invalidated'
+ )
+
+
  class Prompt(Protocol):
  v1: PromptVersion
  v2: PromptVersion

@@ -56,18 +69,6 @@ def v1(context: dict[str, Any]) -> list[Message]:
  {context['new_edges']}

  Each edge is formatted as: "UUID | SOURCE_NODE - EDGE_NAME - TARGET_NODE (fact: EDGE_FACT), START_DATE (END_DATE, optional))"
-
- For each existing edge that should be invalidated, respond with a JSON object in the following format:
- {{
- "invalidated_edges": [
- {{
- "edge_uuid": "The UUID of the edge to be invalidated (the part before the | character)",
- "fact": "Updated fact of the edge"
- }}
- ]
- }}
-
- If no relationships need to be invalidated based on these strict criteria, return an empty list for "invalidated_edges".
  """,
  ),
  ]

@@ -89,19 +90,6 @@ def v2(context: dict[str, Any]) -> list[Message]:

  New Edge:
  {context['new_edge']}
-
-
- For each existing edge that should be invalidated, respond with a JSON object in the following format:
- {{
- "invalidated_edges": [
- {{
- "uuid": "The UUID of the edge to be invalidated",
- "fact": "Updated fact of the edge"
- }}
- ]
- }}
-
- If no relationships need to be invalidated based on these strict criteria, return an empty list for "invalidated_edges".
  """,
  ),
  ]

graphiti_core/prompts/summarize_nodes.py

@@ -17,9 +17,21 @@ limitations under the License.
  import json
  from typing import Any, Protocol, TypedDict

+ from pydantic import BaseModel, Field
+
  from .models import Message, PromptFunction, PromptVersion


+ class Summary(BaseModel):
+ summary: str = Field(
+ ..., description='Summary containing the important information from both summaries'
+ )
+
+
+ class SummaryDescription(BaseModel):
+ description: str = Field(..., description='One sentence description of the provided summary')
+
+
  class Prompt(Protocol):
  summarize_pair: PromptVersion
  summarize_context: PromptVersion

@@ -42,14 +54,11 @@ def summarize_pair(context: dict[str, Any]) -> list[Message]:
  role='user',
  content=f"""
  Synthesize the information from the following two summaries into a single succinct summary.
+
+ Summaries must be under 500 words.

  Summaries:
  {json.dumps(context['node_summaries'], indent=2)}
-
- Respond with a JSON object in the following format:
- {{
- "summary": "Summary containing the important information from both summaries"
- }}
  """,
  ),
  ]

@@ -74,15 +83,11 @@ def summarize_context(context: dict[str, Any]) -> list[Message]:
  information from the provided MESSAGES. Your summary should also only contain information relevant to the
  provided ENTITY.

+ Summaries must be under 500 words.
+
  <ENTITY>
  {context['node_name']}
  </ENTITY>
-
-
- Respond with a JSON object in the following format:
- {{
- "summary": "Entity summary"
- }}
  """,
  ),
  ]

@@ -98,14 +103,10 @@ def summary_description(context: dict[str, Any]) -> list[Message]:
  role='user',
  content=f"""
  Create a short one sentence description of the summary that explains what kind of information is summarized.
+ Summaries must be under 500 words.

  Summary:
  {json.dumps(context['summary'], indent=2)}
-
- Respond with a JSON object in the following format:
- {{
- "description": "One sentence description of the provided summary"
- }}
  """
  ),
  ]

graphiti_core/utils/maintenance/community_operations.py

@@ -16,6 +16,7 @@ from graphiti_core.nodes import (
  get_community_node_from_record,
  )
  from graphiti_core.prompts import prompt_library
+ from graphiti_core.prompts.summarize_nodes import Summary, SummaryDescription
  from graphiti_core.utils.maintenance.edge_operations import build_community_edges

  MAX_COMMUNITY_BUILD_CONCURRENCY = 10

@@ -131,7 +132,7 @@ async def summarize_pair(llm_client: LLMClient, summary_pair: tuple[str, str]) -
  context = {'node_summaries': [{'summary': summary} for summary in summary_pair]}

  llm_response = await llm_client.generate_response(
- prompt_library.summarize_nodes.summarize_pair(context)
+ prompt_library.summarize_nodes.summarize_pair(context), response_model=Summary
  )

  pair_summary = llm_response.get('summary', '')

@@ -143,7 +144,8 @@ async def generate_summary_description(llm_client: LLMClient, summary: str) -> s
  context = {'summary': summary}

  llm_response = await llm_client.generate_response(
- prompt_library.summarize_nodes.summary_description(context)
+ prompt_library.summarize_nodes.summary_description(context),
+ response_model=SummaryDescription,
  )

  description = llm_response.get('description', '')

graphiti_core/utils/maintenance/edge_operations.py

@@ -24,6 +24,8 @@ from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS
  from graphiti_core.llm_client import LLMClient
  from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
  from graphiti_core.prompts import prompt_library
+ from graphiti_core.prompts.dedupe_edges import EdgeDuplicate, UniqueFacts
+ from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
  from graphiti_core.utils.maintenance.temporal_operations import (
  extract_edge_dates,
  get_edge_contradictions,

@@ -91,7 +93,7 @@ async def extract_edges(
  reflexion_iterations = 0
  while facts_missed and reflexion_iterations < MAX_REFLEXION_ITERATIONS:
  llm_response = await llm_client.generate_response(
- prompt_library.extract_edges.edge(context)
+ prompt_library.extract_edges.edge(context), response_model=ExtractedEdges
  )
  edges_data = llm_response.get('edges', [])

@@ -100,7 +102,7 @@
  reflexion_iterations += 1
  if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
  reflexion_response = await llm_client.generate_response(
- prompt_library.extract_edges.reflexion(context)
+ prompt_library.extract_edges.reflexion(context), response_model=MissingFacts
  )

  missing_facts = reflexion_response.get('missing_facts', [])

@@ -317,7 +319,9 @@ async def dedupe_extracted_edge(
  'extracted_edges': extracted_edge_context,
  }

- llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.edge(context))
+ llm_response = await llm_client.generate_response(
+ prompt_library.dedupe_edges.edge(context), response_model=EdgeDuplicate
+ )

  is_duplicate: bool = llm_response.get('is_duplicate', False)
  uuid: str | None = llm_response.get('uuid', None)

@@ -352,7 +356,7 @@ async def dedupe_edge_list(
  context = {'edges': [{'uuid': edge.uuid, 'fact': edge.fact} for edge in edges]}

  llm_response = await llm_client.generate_response(
- prompt_library.dedupe_edges.edge_list(context)
+ prompt_library.dedupe_edges.edge_list(context), response_model=UniqueFacts
  )
  unique_edges_data = llm_response.get('unique_facts', [])

graphiti_core/utils/maintenance/node_operations.py

@@ -23,6 +23,9 @@ from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS
  from graphiti_core.llm_client import LLMClient
  from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
  from graphiti_core.prompts import prompt_library
+ from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
+ from graphiti_core.prompts.extract_nodes import ExtractedNodes, MissedEntities
+ from graphiti_core.prompts.summarize_nodes import Summary

  logger = logging.getLogger(__name__)

@@ -42,7 +45,7 @@ async def extract_message_nodes(
  }

  llm_response = await llm_client.generate_response(
- prompt_library.extract_nodes.extract_message(context)
+ prompt_library.extract_nodes.extract_message(context), response_model=ExtractedNodes
  )
  extracted_node_names = llm_response.get('extracted_node_names', [])
  return extracted_node_names

@@ -63,7 +66,7 @@ async def extract_text_nodes(
  }

  llm_response = await llm_client.generate_response(
- prompt_library.extract_nodes.extract_text(context)
+ prompt_library.extract_nodes.extract_text(context), ExtractedNodes
  )
  extracted_node_names = llm_response.get('extracted_node_names', [])
  return extracted_node_names

@@ -81,7 +84,7 @@ async def extract_json_nodes(
  }

  llm_response = await llm_client.generate_response(
- prompt_library.extract_nodes.extract_json(context)
+ prompt_library.extract_nodes.extract_json(context), ExtractedNodes
  )
  extracted_node_names = llm_response.get('extracted_node_names', [])
  return extracted_node_names

@@ -101,7 +104,7 @@ async def extract_nodes_reflexion(
  }

  llm_response = await llm_client.generate_response(
- prompt_library.extract_nodes.reflexion(context)
+ prompt_library.extract_nodes.reflexion(context), MissedEntities
  )
  missed_entities = llm_response.get('missed_entities', [])

@@ -273,9 +276,12 @@ async def resolve_extracted_node(
  }

  llm_response, node_summary_response = await asyncio.gather(
- llm_client.generate_response(prompt_library.dedupe_nodes.node(context)),
  llm_client.generate_response(
- prompt_library.summarize_nodes.summarize_context(summary_context)
+ prompt_library.dedupe_nodes.node(context), response_model=NodeDuplicate
+ ),
+ llm_client.generate_response(
+ prompt_library.summarize_nodes.summarize_context(summary_context),
+ response_model=Summary,
  ),
  )

@@ -294,7 +300,8 @@
  summary_response = await llm_client.generate_response(
  prompt_library.summarize_nodes.summarize_pair(
  {'node_summaries': [extracted_node.summary, existing_node.summary]}
- )
+ ),
+ response_model=Summary,
  )
  node = existing_node
  node.name = name
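
Callers like the functions above keep reading plain dicts because the OpenAI client converts the parsed Pydantic object back with model_dump() before returning it. A small sketch of that round trip; the values are illustrative, only NodeDuplicate and the .get(...) pattern come from the diff:

from graphiti_core.prompts.dedupe_nodes import NodeDuplicate

parsed = NodeDuplicate(
    is_duplicate=True,
    uuid='5d643020624c42fa9de13f97b1b3fa39',
    name='Alice Smith',
)
llm_response = parsed.model_dump()  # shaped like what _generate_response hands back
assert llm_response.get('is_duplicate', False) is True
assert llm_response.get('uuid', None) == '5d643020624c42fa9de13f97b1b3fa39'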

graphiti_core/utils/maintenance/temporal_operations.py

@@ -22,6 +22,8 @@ from graphiti_core.edges import EntityEdge
  from graphiti_core.llm_client import LLMClient
  from graphiti_core.nodes import EpisodicNode
  from graphiti_core.prompts import prompt_library
+ from graphiti_core.prompts.extract_edge_dates import EdgeDates
+ from graphiti_core.prompts.invalidate_edges import InvalidatedEdges

  logger = logging.getLogger(__name__)

@@ -38,7 +40,9 @@ async def extract_edge_dates(
  'previous_episodes': [ep.content for ep in previous_episodes],
  'reference_timestamp': current_episode.valid_at.isoformat(),
  }
- llm_response = await llm_client.generate_response(prompt_library.extract_edge_dates.v1(context))
+ llm_response = await llm_client.generate_response(
+ prompt_library.extract_edge_dates.v1(context), response_model=EdgeDates
+ )

  valid_at = llm_response.get('valid_at')
  invalid_at = llm_response.get('invalid_at')

@@ -75,7 +79,9 @@ async def get_edge_contradictions(

  context = {'new_edge': new_edge_context, 'existing_edges': existing_edge_context}

- llm_response = await llm_client.generate_response(prompt_library.invalidate_edges.v2(context))
+ llm_response = await llm_client.generate_response(
+ prompt_library.invalidate_edges.v2(context), response_model=InvalidatedEdges
+ )

  contradicted_edge_data = llm_response.get('invalidated_edges', [])

pyproject.toml

@@ -1,6 +1,6 @@
  [tool.poetry]
  name = "graphiti-core"
- version = "0.4.3"
+ version = "0.5.0pre2"
  description = "A temporal graph building library"
  authors = [
  "Paul Paliychuk <paul@getzep.com>",
  "Paul Paliychuk <paul@getzep.com>",
File without changes