graphiti-core 0.4.3__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. Click here for more details.

Files changed (35):
  1. graphiti_core/cross_encoder/client.py +1 -1
  2. graphiti_core/cross_encoder/openai_reranker_client.py +2 -2
  3. graphiti_core/edges.py +13 -10
  4. graphiti_core/graphiti.py +25 -27
  5. graphiti_core/helpers.py +25 -0
  6. graphiti_core/llm_client/anthropic_client.py +4 -1
  7. graphiti_core/llm_client/client.py +45 -5
  8. graphiti_core/llm_client/errors.py +8 -0
  9. graphiti_core/llm_client/groq_client.py +4 -1
  10. graphiti_core/llm_client/openai_client.py +71 -7
  11. graphiti_core/llm_client/openai_generic_client.py +163 -0
  12. graphiti_core/nodes.py +16 -12
  13. graphiti_core/prompts/dedupe_edges.py +20 -17
  14. graphiti_core/prompts/dedupe_nodes.py +15 -1
  15. graphiti_core/prompts/eval.py +17 -14
  16. graphiti_core/prompts/extract_edge_dates.py +15 -7
  17. graphiti_core/prompts/extract_edges.py +18 -19
  18. graphiti_core/prompts/extract_nodes.py +11 -21
  19. graphiti_core/prompts/invalidate_edges.py +13 -25
  20. graphiti_core/prompts/summarize_nodes.py +17 -16
  21. graphiti_core/search/search.py +5 -5
  22. graphiti_core/search/search_utils.py +54 -13
  23. graphiti_core/utils/__init__.py +0 -15
  24. graphiti_core/utils/bulk_utils.py +22 -15
  25. graphiti_core/utils/datetime_utils.py +42 -0
  26. graphiti_core/utils/maintenance/community_operations.py +13 -9
  27. graphiti_core/utils/maintenance/edge_operations.py +26 -19
  28. graphiti_core/utils/maintenance/graph_data_operations.py +3 -4
  29. graphiti_core/utils/maintenance/node_operations.py +19 -13
  30. graphiti_core/utils/maintenance/temporal_operations.py +16 -7
  31. {graphiti_core-0.4.3.dist-info → graphiti_core-0.5.0.dist-info}/METADATA +1 -1
  32. graphiti_core-0.5.0.dist-info/RECORD +60 -0
  33. graphiti_core-0.4.3.dist-info/RECORD +0 -58
  34. {graphiti_core-0.4.3.dist-info → graphiti_core-0.5.0.dist-info}/LICENSE +0 -0
  35. {graphiti_core-0.4.3.dist-info → graphiti_core-0.5.0.dist-info}/WHEEL +0 -0
graphiti_core/llm_client/openai_generic_client.py ADDED
@@ -0,0 +1,163 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import json
18
+ import logging
19
+ import typing
20
+ from typing import ClassVar
21
+
22
+ import openai
23
+ from openai import AsyncOpenAI
24
+ from openai.types.chat import ChatCompletionMessageParam
25
+ from pydantic import BaseModel
26
+
27
+ from ..prompts.models import Message
28
+ from .client import LLMClient
29
+ from .config import LLMConfig
30
+ from .errors import RateLimitError, RefusalError
31
+
logger: logging.Logger = logging.getLogger(__name__)

# Fallback model name, used whenever the client's configured model is empty/None.
DEFAULT_MODEL: str = 'gpt-4o-mini'
35
+
36
+
class OpenAIGenericClient(LLMClient):
    """
    LLM client for OpenAI-compatible chat-completions endpoints.

    Each request asks the endpoint for a generic JSON object
    (``response_format={'type': 'json_object'}``) and decodes the reply with
    ``json.loads``. When a ``response_model`` is supplied, its JSON schema is
    appended to the last message as plain-text formatting instructions, so the
    endpoint does not need to support schema-constrained output.
    Application-level failures (e.g. an unparsable reply) are retried up to
    ``MAX_RETRIES`` times, with the error fed back to the model.

    The caller's ``messages`` list and its ``Message`` objects are never modified.

    Attributes:
        client (AsyncOpenAI | Any): The async client used to call the API.
        MAX_RETRIES (int): Extra attempts ``generate_response`` makes after an
            application-level failure.

    ``self.model``, ``self.temperature`` and ``self.max_tokens`` are read when
    building requests. They are presumably populated by ``LLMClient.__init__``
    from the config; that code is not visible here.
    """

    # Class-level constants
    MAX_RETRIES: ClassVar[int] = 2

    def __init__(
        self, config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None
    ):
        """
        Initialize the client.

        Args:
            config: LLM configuration (API key, model, base URL, temperature,
                max tokens). A default ``LLMConfig()`` is used when ``None``.
            cache: Must be ``False``; response caching is not supported.
            client: Optional pre-built async client. When ``None``, an
                ``AsyncOpenAI`` client is created from ``config.api_key`` and
                ``config.base_url``.

        Raises:
            NotImplementedError: If ``cache`` is ``True``.
        """
        # Caching was removed to keep the ``generate_response`` override simple.
        if cache:
            raise NotImplementedError('Caching is not implemented for OpenAI')

        if config is None:
            config = LLMConfig()

        super().__init__(config, cache)

        if client is None:
            self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
        else:
            self.client = client

    async def _generate_response(
        self, messages: list[Message], response_model: type[BaseModel] | None = None
    ) -> dict[str, typing.Any]:
        """
        Send a single chat-completion request and decode the reply as JSON.

        Only ``user`` and ``system`` messages are forwarded; messages with any
        other role are skipped. Each message's content is passed through
        ``self._clean_input`` (defined on the base class, not shown here), but
        the cleaned text is NOT written back to the caller's ``Message``.

        Args:
            messages: Conversation to send.
            response_model: Not used here; ``generate_response`` injects the
                schema into the prompt text instead.

        Returns:
            The decoded JSON value of the first choice. It is expected to be an
            object, given ``response_format={'type': 'json_object'}``.

        Raises:
            RateLimitError: If the API reports a rate limit.
            Exception: Any other API or JSON-decoding error, re-raised after logging.
        """
        openai_messages: list[ChatCompletionMessageParam] = []
        for m in messages:
            # Clean into a local so the caller's Message objects stay untouched.
            content = self._clean_input(m.content)
            if m.role == 'user':
                openai_messages.append({'role': 'user', 'content': content})
            elif m.role == 'system':
                openai_messages.append({'role': 'system', 'content': content})
        try:
            response = await self.client.chat.completions.create(
                model=self.model or DEFAULT_MODEL,
                messages=openai_messages,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                response_format={'type': 'json_object'},
            )
            result = response.choices[0].message.content or ''
            return json.loads(result)
        except openai.RateLimitError as e:
            raise RateLimitError from e
        except Exception as e:
            logger.error('Error in generating LLM response: %s', e)
            raise

    async def generate_response(
        self, messages: list[Message], response_model: type[BaseModel] | None = None
    ) -> dict[str, typing.Any]:
        """
        Generate a JSON response, retrying on application-level errors.

        Schema instructions and retry feedback are added to a private copy of
        ``messages``. The caller's list and ``Message`` objects are left unchanged,
        so the same messages can safely be reused across calls.

        Args:
            messages: Conversation to send; the last message receives the schema
                instructions when ``response_model`` is given.
            response_model: Optional Pydantic model whose JSON schema is appended
                to the last message as formatting instructions.

        Returns:
            The decoded JSON response.

        Raises:
            RateLimitError, RefusalError: Propagated immediately, without retrying.
            openai.APITimeoutError, openai.APIConnectionError,
            openai.InternalServerError: Propagated immediately; retries for these
                are left to the OpenAI client.
            Exception: The last application error once ``MAX_RETRIES`` retries
                have been used.
        """
        import copy  # only needed to avoid mutating caller-owned Message objects

        # Private list: retry feedback appended below must not leak to the caller.
        messages = list(messages)
        retry_count = 0
        last_error: Exception | None = None

        if response_model is not None:
            serialized_model = json.dumps(response_model.model_json_schema())
            # Shallow-copy the last message before extending its content, instead
            # of appending to the caller's object in place.
            last_message = copy.copy(messages[-1])
            last_message.content += (
                f'\n\nRespond with a JSON object in the following format:\n\n{serialized_model}'
            )
            messages[-1] = last_message

        while retry_count <= self.MAX_RETRIES:
            try:
                return await self._generate_response(messages, response_model)
            except (RateLimitError, RefusalError):
                # These errors should not trigger retries
                raise
            except (openai.APITimeoutError, openai.APIConnectionError, openai.InternalServerError):
                # Let OpenAI's client handle these retries
                raise
            except Exception as e:
                last_error = e

                # Don't retry if we've hit the max retries
                if retry_count >= self.MAX_RETRIES:
                    logger.error('Max retries (%d) exceeded. Last error: %s', self.MAX_RETRIES, e)
                    raise

                retry_count += 1

                # Feed the failure back to the model so the next attempt can correct it.
                error_context = (
                    f'The previous response attempt was invalid. '
                    f'Error type: {e.__class__.__name__}. '
                    f'Error details: {str(e)}. '
                    f'Please try again with a valid response, ensuring the output matches '
                    f'the expected format and constraints.'
                )

                messages.append(Message(role='user', content=error_context))
                logger.warning(
                    'Retrying after application error (attempt %d/%d): %s',
                    retry_count,
                    self.MAX_RETRIES,
                    e,
                )

        # Unreachable in practice (every iteration returns or raises); kept as a safety net.
        raise last_error or Exception('Max retries exceeded with no specific error')
graphiti_core/nodes.py CHANGED
@@ -16,7 +16,7 @@ limitations under the License.
16
16
 
17
17
  import logging
18
18
  from abc import ABC, abstractmethod
19
- from datetime import datetime, timezone
19
+ from datetime import datetime
20
20
  from enum import Enum
21
21
  from time import time
22
22
  from typing import Any
@@ -28,12 +28,13 @@ from typing_extensions import LiteralString
28
28
 
29
29
  from graphiti_core.embedder import EmbedderClient
30
30
  from graphiti_core.errors import NodeNotFoundError
31
- from graphiti_core.helpers import DEFAULT_DATABASE, DEFAULT_PAGE_LIMIT
31
+ from graphiti_core.helpers import DEFAULT_DATABASE
32
32
  from graphiti_core.models.nodes.node_db_queries import (
33
33
  COMMUNITY_NODE_SAVE,
34
34
  ENTITY_NODE_SAVE,
35
35
  EPISODIC_NODE_SAVE,
36
36
  )
37
+ from graphiti_core.utils.datetime_utils import utc_now
37
38
 
38
39
  logger = logging.getLogger(__name__)
39
40
 
@@ -79,7 +80,7 @@ class Node(BaseModel, ABC):
79
80
  name: str = Field(description='name of the node')
80
81
  group_id: str = Field(description='partition of the graph')
81
82
  labels: list[str] = Field(default_factory=list)
82
- created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
83
+ created_at: datetime = Field(default_factory=lambda: utc_now())
83
84
 
84
85
  @abstractmethod
85
86
  async def save(self, driver: AsyncDriver): ...
@@ -212,10 +213,11 @@ class EpisodicNode(Node):
212
213
  cls,
213
214
  driver: AsyncDriver,
214
215
  group_ids: list[str],
215
- limit: int = DEFAULT_PAGE_LIMIT,
216
+ limit: int | None = None,
216
217
  created_at: datetime | None = None,
217
218
  ):
218
219
  cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
220
+ limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
219
221
 
220
222
  records, _, _ = await driver.execute_query(
221
223
  """
@@ -233,8 +235,8 @@ class EpisodicNode(Node):
233
235
  e.source_description AS source_description,
234
236
  e.source AS source
235
237
  ORDER BY e.uuid DESC
236
- LIMIT $limit
237
- """,
238
+ """
239
+ + limit_query,
238
240
  group_ids=group_ids,
239
241
  created_at=created_at,
240
242
  limit=limit,
@@ -328,10 +330,11 @@ class EntityNode(Node):
328
330
  cls,
329
331
  driver: AsyncDriver,
330
332
  group_ids: list[str],
331
- limit: int = DEFAULT_PAGE_LIMIT,
333
+ limit: int | None = None,
332
334
  created_at: datetime | None = None,
333
335
  ):
334
336
  cursor_query: LiteralString = 'AND n.created_at < $created_at' if created_at else ''
337
+ limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
335
338
 
336
339
  records, _, _ = await driver.execute_query(
337
340
  """
@@ -347,8 +350,8 @@ class EntityNode(Node):
347
350
  n.created_at AS created_at,
348
351
  n.summary AS summary
349
352
  ORDER BY n.uuid DESC
350
- LIMIT $limit
351
- """,
353
+ """
354
+ + limit_query,
352
355
  group_ids=group_ids,
353
356
  created_at=created_at,
354
357
  limit=limit,
@@ -442,10 +445,11 @@ class CommunityNode(Node):
442
445
  cls,
443
446
  driver: AsyncDriver,
444
447
  group_ids: list[str],
445
- limit: int = DEFAULT_PAGE_LIMIT,
448
+ limit: int | None = None,
446
449
  created_at: datetime | None = None,
447
450
  ):
448
451
  cursor_query: LiteralString = 'AND n.created_at < $created_at' if created_at else ''
452
+ limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
449
453
 
450
454
  records, _, _ = await driver.execute_query(
451
455
  """
@@ -461,8 +465,8 @@ class CommunityNode(Node):
461
465
  n.created_at AS created_at,
462
466
  n.summary AS summary
463
467
  ORDER BY n.uuid DESC
464
- LIMIT $limit
465
- """,
468
+ """
469
+ + limit_query,
466
470
  group_ids=group_ids,
467
471
  created_at=created_at,
468
472
  limit=limit,
graphiti_core/prompts/dedupe_edges.py CHANGED
@@ -15,11 +15,30 @@ limitations under the License.
15
15
  """
16
16
 
17
17
  import json
18
- from typing import Any, Protocol, TypedDict
18
+ from typing import Any, Optional, Protocol, TypedDict
19
+
20
+ from pydantic import BaseModel, Field
19
21
 
20
22
  from .models import Message, PromptFunction, PromptVersion
21
23
 
22
24
 
# NOTE: these reply models are documented with comments, not docstrings, because
# pydantic copies a class docstring into the JSON schema's "description".


# Reply schema for the single-edge ``edge`` dedupe prompt.
class EdgeDuplicate(BaseModel):
    # Whether the new edge duplicates one of the existing edges.
    is_duplicate: bool = Field(description='true or false')
    # UUID of the matching existing edge; None when there is no duplicate.
    uuid: Optional[str] = Field(
        default=None,
        description="uuid of the existing edge like '5d643020624c42fa9de13f97b1b3fa39' or null",
    )


# One surviving fact in the reply to the ``edge_list`` dedupe prompt.
class UniqueFact(BaseModel):
    uuid: str = Field(description='unique identifier of the fact')
    fact: str = Field(description='fact of a unique edge')


# Reply schema for the ``edge_list`` dedupe prompt.
class UniqueFacts(BaseModel):
    unique_facts: list[UniqueFact]
40
+
41
+
23
42
  class Prompt(Protocol):
24
43
  edge: PromptVersion
25
44
  edge_list: PromptVersion
@@ -56,12 +75,6 @@ def edge(context: dict[str, Any]) -> list[Message]:
56
75
 
57
76
  Guidelines:
58
77
  1. The facts do not need to be completely identical to be duplicates, they just need to express the same information.
59
-
60
- Respond with a JSON object in the following format:
61
- {{
62
- "is_duplicate": true or false,
63
- "uuid": uuid of the existing edge like "5d643020624c42fa9de13f97b1b3fa39" or null,
64
- }}
65
78
  """,
66
79
  ),
67
80
  ]
@@ -90,16 +103,6 @@ def edge_list(context: dict[str, Any]) -> list[Message]:
90
103
  3. Facts will often discuss the same or similar relation between identical entities
91
104
  4. The final list should have only unique facts. If 3 facts are all duplicates of each other, only one of their
92
105
  facts should be in the response
93
-
94
- Respond with a JSON object in the following format:
95
- {{
96
- "unique_facts": [
97
- {{
98
- "uuid": "unique identifier of the fact",
99
- "fact": "fact of a unique edge"
100
- }}
101
- ]
102
- }}
103
106
  """,
104
107
  ),
105
108
  ]
graphiti_core/prompts/dedupe_nodes.py CHANGED
@@ -15,11 +15,25 @@ limitations under the License.
15
15
  """
16
16
 
17
17
  import json
18
- from typing import Any, Protocol, TypedDict
18
+ from typing import Any, Optional, Protocol, TypedDict
19
+
20
+ from pydantic import BaseModel, Field
19
21
 
20
22
  from .models import Message, PromptFunction, PromptVersion
21
23
 
22
24
 
# Reply schema for the node dedupe prompt.
# NOTE: documented with comments rather than a docstring, because pydantic copies a
# class docstring into the JSON schema's "description".
class NodeDuplicate(BaseModel):
    # Whether the new node duplicates an existing node.
    is_duplicate: bool = Field(description='true or false')
    # UUID of the matching existing node; None when there is no duplicate.
    uuid: Optional[str] = Field(
        default=None,
        description="uuid of the existing node like '5d643020624c42fa9de13f97b1b3fa39' or null",
    )
    # Name to keep for the (possibly merged) node; always required.
    name: str = Field(
        description="Updated name of the new node (use the best name between the new node's name, an existing duplicate name, or a combination of both)",
    )
35
+
36
+
23
37
  class Prompt(Protocol):
24
38
  node: PromptVersion
25
39
  node_list: PromptVersion
graphiti_core/prompts/eval.py CHANGED
@@ -17,9 +17,26 @@ limitations under the License.
17
17
  import json
18
18
  from typing import Any, Protocol, TypedDict
19
19
 
20
+ from pydantic import BaseModel, Field
21
+
20
22
  from .models import Message, PromptFunction, PromptVersion
21
23
 
22
24
 
# NOTE: these reply models use comments, not docstrings, because pydantic copies a
# class docstring into the JSON schema's "description".


# Reply schema for ``query_expansion``.
class QueryExpansion(BaseModel):
    query: str = Field(description='query optimized for database search')


# Reply schema for ``qa_prompt``. The upper-case field name is the reply key;
# presumably callers read it as ``ANSWER``, so do not rename it.
class QAResponse(BaseModel):
    ANSWER: str = Field(description='how Alice would answer the question')


# Reply schema for ``eval_prompt``.
class EvalResponse(BaseModel):
    is_correct: bool = Field(description='boolean if the answer is correct or incorrect')
    reasoning: str = Field(description='why you determined the response was correct or incorrect')
38
+
39
+
23
40
  class Prompt(Protocol):
24
41
  qa_prompt: PromptVersion
25
42
  eval_prompt: PromptVersion
@@ -41,10 +58,6 @@ def query_expansion(context: dict[str, Any]) -> list[Message]:
41
58
  <QUESTION>
42
59
  {json.dumps(context['query'])}
43
60
  </QUESTION>
44
- respond with a JSON object in the following format:
45
- {{
46
- "query": "query optimized for database search"
47
- }}
48
61
  """
49
62
  return [
50
63
  Message(role='system', content=sys_prompt),
@@ -67,10 +80,6 @@ def qa_prompt(context: dict[str, Any]) -> list[Message]:
67
80
  <QUESTION>
68
81
  {context['query']}
69
82
  </QUESTION>
70
- respond with a JSON object in the following format:
71
- {{
72
- "ANSWER": "how Alice would answer the question"
73
- }}
74
83
  """
75
84
  return [
76
85
  Message(role='system', content=sys_prompt),
@@ -96,12 +105,6 @@ def eval_prompt(context: dict[str, Any]) -> list[Message]:
96
105
  <RESPONSE>
97
106
  {context['response']}
98
107
  </RESPONSE>
99
-
100
- respond with a JSON object in the following format:
101
- {{
102
- "is_correct": "boolean if the answer is correct or incorrect"
103
- "reasoning": "why you determined the response was correct or incorrect"
104
- }}
105
108
  """
106
109
  return [
107
110
  Message(role='system', content=sys_prompt),
graphiti_core/prompts/extract_edge_dates.py CHANGED
@@ -14,11 +14,24 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
- from typing import Any, Protocol, TypedDict
17
+ from typing import Any, Optional, Protocol, TypedDict
18
+
19
+ from pydantic import BaseModel, Field
18
20
 
19
21
  from .models import Message, PromptFunction, PromptVersion
20
22
 
21
23
 
# Reply schema for the edge-date extraction prompt (``v1``).
# Both bounds are kept as strings (or None); this model does no datetime parsing.
# NOTE: comments instead of a docstring, because pydantic copies a class docstring
# into the JSON schema's "description".
class EdgeDates(BaseModel):
    valid_at: Optional[str] = Field(
        default=None,
        description='The date and time when the relationship described by the edge fact became true or was established. YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null.',
    )
    invalid_at: Optional[str] = Field(
        default=None,
        description='The date and time when the relationship described by the edge fact stopped being true or ended. YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null.',
    )
33
+
34
+
22
35
  class Prompt(Protocol):
23
36
  v1: PromptVersion
24
37
 
@@ -60,7 +73,7 @@ def v1(context: dict[str, Any]) -> list[Message]:
60
73
  Analyze the conversation and determine if there are dates that are part of the edge fact. Only set dates if they explicitly relate to the formation or alteration of the relationship itself.
61
74
 
62
75
  Guidelines:
63
- 1. Use ISO 8601 format (YYYY-MM-DDTHH:MM:SSZ) for datetimes.
76
+ 1. Use ISO 8601 format (YYYY-MM-DDTHH:MM:SS.SSSSSSZ) for datetimes.
64
77
  2. Use the reference timestamp as the current time when determining the valid_at and invalid_at dates.
65
78
  3. If the fact is written in the present tense, use the Reference Timestamp for the valid_at date
66
79
  4. If no temporal information is found that establishes or changes the relationship, leave the fields as null.
@@ -69,11 +82,6 @@ def v1(context: dict[str, Any]) -> list[Message]:
69
82
  7. If only a date is mentioned without a specific time, use 00:00:00 (midnight) for that date.
70
83
  8. If only year is mentioned, use January 1st of that year at 00:00:00.
71
84
  9. Always include the time zone offset (use Z for UTC if no specific time zone is mentioned).
72
- Respond with a JSON object:
73
- {{
74
- "valid_at": "YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null",
75
- "invalid_at": "YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null",
76
- }}
77
85
  """,
78
86
  ),
79
87
  ]
graphiti_core/prompts/extract_edges.py CHANGED
@@ -17,9 +17,26 @@ limitations under the License.
17
17
  import json
18
18
  from typing import Any, Protocol, TypedDict
19
19
 
20
+ from pydantic import BaseModel, Field
21
+
20
22
  from .models import Message, PromptFunction, PromptVersion
21
23
 
22
24
 
# NOTE: these reply models use comments, not docstrings, because pydantic copies a
# class docstring into the JSON schema's "description".


# One extracted relationship in the reply to the ``edge`` extraction prompt.
class Edge(BaseModel):
    relation_type: str = Field(description='RELATION_TYPE_IN_CAPS')
    source_entity_name: str = Field(description='name of the source entity')
    target_entity_name: str = Field(description='name of the target entity')
    fact: str = Field(description='extracted factual information')


# Reply schema for the ``edge`` extraction prompt.
class ExtractedEdges(BaseModel):
    edges: list[Edge]


# Reply schema for the ``reflexion`` prompt: facts the first pass did not extract.
class MissingFacts(BaseModel):
    missing_facts: list[str] = Field(description="facts that weren't extracted")
38
+
39
+
23
40
  class Prompt(Protocol):
24
41
  edge: PromptVersion
25
42
  reflexion: PromptVersion
@@ -54,25 +71,12 @@ def edge(context: dict[str, Any]) -> list[Message]:
54
71
 
55
72
  Given the above MESSAGES and ENTITIES, extract all facts pertaining to the listed ENTITIES from the CURRENT MESSAGE.
56
73
 
57
-
58
74
  Guidelines:
59
75
  1. Extract facts only between the provided entities.
60
76
  2. Each fact should represent a clear relationship between two DISTINCT nodes.
61
77
  3. The relation_type should be a concise, all-caps description of the fact (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
62
78
  4. Provide a more detailed fact containing all relevant information.
63
79
  5. Consider temporal aspects of relationships when relevant.
64
-
65
- Respond with a JSON object in the following format:
66
- {{
67
- "edges": [
68
- {{
69
- "relation_type": "RELATION_TYPE_IN_CAPS",
70
- "source_entity_name": "name of the source entity",
71
- "target_entity_name": "name of the target entity",
72
- "fact": "extracted factual information",
73
- }}
74
- ]
75
- }}
76
80
  """,
77
81
  ),
78
82
  ]
@@ -98,12 +102,7 @@ def reflexion(context: dict[str, Any]) -> list[Message]:
98
102
  </EXTRACTED FACTS>
99
103
 
100
104
  Given the above MESSAGES, list of EXTRACTED ENTITIES entities, and list of EXTRACTED FACTS;
101
- determine if any facts haven't been extracted:
102
-
103
- Respond with a JSON object in the following format:
104
- {{
105
- "missing_facts": [ "facts that weren't extracted", ...]
106
- }}
105
+ determine if any facts haven't been extracted.
107
106
  """
108
107
  return [
109
108
  Message(role='system', content=sys_prompt),
graphiti_core/prompts/extract_nodes.py CHANGED
@@ -17,9 +17,19 @@ limitations under the License.
17
17
  import json
18
18
  from typing import Any, Protocol, TypedDict
19
19
 
20
+ from pydantic import BaseModel, Field
21
+
20
22
  from .models import Message, PromptFunction, PromptVersion
21
23
 
22
24
 
# NOTE: these reply models use comments, not docstrings, because pydantic copies a
# class docstring into the JSON schema's "description".


# Reply schema for the entity-extraction prompts (message / JSON / text variants).
class ExtractedNodes(BaseModel):
    extracted_node_names: list[str] = Field(description='Name of the extracted entity')


# Reply schema for the entity ``reflexion`` prompt: names the first pass missed.
class MissedEntities(BaseModel):
    missed_entities: list[str] = Field(description="Names of entities that weren't extracted")
31
+
32
+
23
33
  class Prompt(Protocol):
24
34
  extract_message: PromptVersion
25
35
  extract_json: PromptVersion
@@ -56,11 +66,6 @@ Guidelines:
56
66
  4. DO NOT create nodes for temporal information like dates, times or years (these will be added to edges later).
57
67
  5. Be as explicit as possible in your node names, using full names.
58
68
  6. DO NOT extract entities mentioned only in PREVIOUS MESSAGES, those messages are only to provide context.
59
-
60
- Respond with a JSON object in the following format:
61
- {{
62
- "extracted_node_names": ["Name of the extracted entity", ...],
63
- }}
64
69
  """
65
70
  return [
66
71
  Message(role='system', content=sys_prompt),
@@ -87,11 +92,6 @@ Given the above source description and JSON, extract relevant entity nodes from
87
92
  Guidelines:
88
93
  1. Always try to extract an entities that the JSON represents. This will often be something like a "name" or "user field
89
94
  2. Do NOT extract any properties that contain dates
90
-
91
- Respond with a JSON object in the following format:
92
- {{
93
- "extracted_node_names": ["Name of the extracted entity", ...],
94
- }}
95
95
  """
96
96
  return [
97
97
  Message(role='system', content=sys_prompt),
@@ -116,11 +116,6 @@ Guidelines:
116
116
  2. Avoid creating nodes for relationships or actions.
117
117
  3. Avoid creating nodes for temporal information like dates, times or years (these will be added to edges later).
118
118
  4. Be as explicit as possible in your node names, using full names and avoiding abbreviations.
119
-
120
- Respond with a JSON object in the following format:
121
- {{
122
- "extracted_node_names": ["Name of the extracted entity", ...],
123
- }}
124
119
  """
125
120
  return [
126
121
  Message(role='system', content=sys_prompt),
@@ -144,12 +139,7 @@ def reflexion(context: dict[str, Any]) -> list[Message]:
144
139
  </EXTRACTED ENTITIES>
145
140
 
146
141
  Given the above previous messages, current message, and list of extracted entities; determine if any entities haven't been
147
- extracted:
148
-
149
- Respond with a JSON object in the following format:
150
- {{
151
- "missed_entities": [ "name of entity that wasn't extracted", ...]
152
- }}
142
+ extracted.
153
143
  """
154
144
  return [
155
145
  Message(role='system', content=sys_prompt),
graphiti_core/prompts/invalidate_edges.py CHANGED
@@ -16,9 +16,22 @@ limitations under the License.
16
16
 
17
17
  from typing import Any, Protocol, TypedDict
18
18
 
19
+ from pydantic import BaseModel, Field
20
+
19
21
  from .models import Message, PromptFunction, PromptVersion
20
22
 
21
23
 
# NOTE: these reply models use comments, not docstrings, because pydantic copies a
# class docstring into the JSON schema's "description".


# One existing edge the model decided to invalidate, with its rewritten fact.
class InvalidatedEdge(BaseModel):
    uuid: str = Field(description='The UUID of the edge to be invalidated')
    fact: str = Field(description='Updated fact of the edge')


# Reply schema for the invalidation prompts (``v1`` / ``v2``); presumably an empty
# list means nothing should be invalidated — confirm against callers.
class InvalidatedEdges(BaseModel):
    invalidated_edges: list[InvalidatedEdge] = Field(
        description='List of edges that should be invalidated'
    )
33
+
34
+
22
35
  class Prompt(Protocol):
23
36
  v1: PromptVersion
24
37
  v2: PromptVersion
@@ -56,18 +69,6 @@ def v1(context: dict[str, Any]) -> list[Message]:
56
69
  {context['new_edges']}
57
70
 
58
71
  Each edge is formatted as: "UUID | SOURCE_NODE - EDGE_NAME - TARGET_NODE (fact: EDGE_FACT), START_DATE (END_DATE, optional))"
59
-
60
- For each existing edge that should be invalidated, respond with a JSON object in the following format:
61
- {{
62
- "invalidated_edges": [
63
- {{
64
- "edge_uuid": "The UUID of the edge to be invalidated (the part before the | character)",
65
- "fact": "Updated fact of the edge"
66
- }}
67
- ]
68
- }}
69
-
70
- If no relationships need to be invalidated based on these strict criteria, return an empty list for "invalidated_edges".
71
72
  """,
72
73
  ),
73
74
  ]
@@ -89,19 +90,6 @@ def v2(context: dict[str, Any]) -> list[Message]:
89
90
 
90
91
  New Edge:
91
92
  {context['new_edge']}
92
-
93
-
94
- For each existing edge that should be invalidated, respond with a JSON object in the following format:
95
- {{
96
- "invalidated_edges": [
97
- {{
98
- "uuid": "The UUID of the edge to be invalidated",
99
- "fact": "Updated fact of the edge"
100
- }}
101
- ]
102
- }}
103
-
104
- If no relationships need to be invalidated based on these strict criteria, return an empty list for "invalidated_edges".
105
93
  """,
106
94
  ),
107
95
  ]