graphiti-core 0.4.2__tar.gz → 0.5.0rc1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/PKG-INFO +1 -1
  2. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/cross_encoder/bge_reranker_client.py +1 -2
  3. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/cross_encoder/client.py +3 -4
  4. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/edges.py +51 -5
  5. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/embedder/client.py +3 -3
  6. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/embedder/openai.py +2 -2
  7. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/embedder/voyage.py +3 -3
  8. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/graphiti.py +14 -10
  9. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/helpers.py +1 -0
  10. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/llm_client/anthropic_client.py +4 -1
  11. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/llm_client/client.py +20 -5
  12. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/llm_client/errors.py +8 -0
  13. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/llm_client/groq_client.py +4 -1
  14. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/llm_client/openai_client.py +29 -7
  15. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/nodes.py +50 -4
  16. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/prompts/dedupe_edges.py +20 -17
  17. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/prompts/dedupe_nodes.py +15 -1
  18. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/prompts/eval.py +17 -14
  19. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/prompts/extract_edge_dates.py +15 -7
  20. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/prompts/extract_edges.py +18 -19
  21. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/prompts/extract_nodes.py +11 -21
  22. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/prompts/invalidate_edges.py +13 -25
  23. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/prompts/lib.py +5 -1
  24. graphiti_core-0.5.0rc1/graphiti_core/prompts/prompt_helpers.py +1 -0
  25. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/prompts/summarize_nodes.py +12 -16
  26. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/search/search_utils.py +1 -1
  27. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/utils/maintenance/community_operations.py +4 -2
  28. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/utils/maintenance/edge_operations.py +14 -11
  29. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/utils/maintenance/node_operations.py +14 -7
  30. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/utils/maintenance/temporal_operations.py +9 -4
  31. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/pyproject.toml +1 -1
  32. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/LICENSE +0 -0
  33. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/README.md +0 -0
  34. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/__init__.py +0 -0
  35. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/cross_encoder/__init__.py +0 -0
  36. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/cross_encoder/openai_reranker_client.py +0 -0
  37. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/embedder/__init__.py +0 -0
  38. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/errors.py +0 -0
  39. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/llm_client/__init__.py +0 -0
  40. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/llm_client/config.py +0 -0
  41. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/llm_client/utils.py +0 -0
  42. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/models/__init__.py +0 -0
  43. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/models/edges/__init__.py +0 -0
  44. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/models/edges/edge_db_queries.py +0 -0
  45. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/models/nodes/__init__.py +0 -0
  46. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/models/nodes/node_db_queries.py +0 -0
  47. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/prompts/__init__.py +0 -0
  48. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/prompts/models.py +0 -0
  49. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/py.typed +0 -0
  50. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/search/__init__.py +0 -0
  51. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/search/search.py +0 -0
  52. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/search/search_config.py +0 -0
  53. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/search/search_config_recipes.py +0 -0
  54. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/utils/__init__.py +0 -0
  55. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/utils/bulk_utils.py +0 -0
  56. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/utils/maintenance/__init__.py +0 -0
  57. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/utils/maintenance/graph_data_operations.py +0 -0
  58. {graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/utils/maintenance/utils.py +0 -0
{graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: graphiti-core
-Version: 0.4.2
+Version: 0.5.0rc1
 Summary: A temporal graph building library
 License: Apache-2.0
 Author: Paul Paliychuk
{graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/cross_encoder/bge_reranker_client.py
@@ -15,7 +15,6 @@ limitations under the License.
 """
 
 import asyncio
-from typing import List, Tuple
 
 from sentence_transformers import CrossEncoder
 
@@ -26,7 +25,7 @@ class BGERerankerClient(CrossEncoderClient):
     def __init__(self):
         self.model = CrossEncoder('BAAI/bge-reranker-v2-m3')
 
-    async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]:
+    async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
         if not passages:
             return []
 
{graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/cross_encoder/client.py
@@ -15,7 +15,6 @@ limitations under the License.
 """
 
 from abc import ABC, abstractmethod
-from typing import List, Tuple
 
 
 class CrossEncoderClient(ABC):
@@ -26,16 +25,16 @@ class CrossEncoderClient(ABC):
     """
 
     @abstractmethod
-    async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]:
+    async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
         """
         Rank the given passages based on their relevance to the query.
 
         Args:
             query (str): The query string.
-            passages (List[str]): A list of passages to rank.
+            passages (list[str]): A list of passages to rank.
 
         Returns:
-            List[Tuple[str, float]]: A list of tuples containing the passage and its score,
+            list[tuple[str, float]]: A list of tuples containing the passage and its score,
             sorted in descending order of relevance.
         """
         pass
{graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/edges.py
@@ -23,10 +23,11 @@ from uuid import uuid4
 
 from neo4j import AsyncDriver
 from pydantic import BaseModel, Field
+from typing_extensions import LiteralString
 
 from graphiti_core.embedder import EmbedderClient
 from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
-from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
+from graphiti_core.helpers import DEFAULT_DATABASE, DEFAULT_PAGE_LIMIT, parse_db_date
 from graphiti_core.models.edges.edge_db_queries import (
     COMMUNITY_EDGE_SAVE,
     ENTITY_EDGE_SAVE,
@@ -50,7 +51,7 @@ class Edge(BaseModel, ABC):
     async def delete(self, driver: AsyncDriver):
         result = await driver.execute_query(
             """
-        MATCH (n)-[e {uuid: $uuid}]->(m)
+        MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
         DELETE e
         """,
             uuid=self.uuid,
@@ -137,19 +138,34 @@ class EpisodicEdge(Edge):
         return edges
 
     @classmethod
-    async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
+    async def get_by_group_ids(
+        cls,
+        driver: AsyncDriver,
+        group_ids: list[str],
+        limit: int = DEFAULT_PAGE_LIMIT,
+        created_at: datetime | None = None,
+    ):
+        cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
+
         records, _, _ = await driver.execute_query(
             """
         MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
         WHERE e.group_id IN $group_ids
+        """
+            + cursor_query
+            + """
         RETURN
             e.uuid As uuid,
             e.group_id AS group_id,
             n.uuid AS source_node_uuid,
             m.uuid AS target_node_uuid,
             e.created_at AS created_at
+        ORDER BY e.uuid DESC
+        LIMIT $limit
         """,
             group_ids=group_ids,
+            created_at=created_at,
+            limit=limit,
             database_=DEFAULT_DATABASE,
             routing_='r',
         )
@@ -274,11 +290,22 @@ class EntityEdge(Edge):
         return edges
 
     @classmethod
-    async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
+    async def get_by_group_ids(
+        cls,
+        driver: AsyncDriver,
+        group_ids: list[str],
+        limit: int = DEFAULT_PAGE_LIMIT,
+        created_at: datetime | None = None,
+    ):
+        cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
+
         records, _, _ = await driver.execute_query(
             """
         MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
         WHERE e.group_id IN $group_ids
+        """
+            + cursor_query
+            + """
         RETURN
             e.uuid AS uuid,
             n.uuid AS source_node_uuid,
@@ -292,8 +319,12 @@ class EntityEdge(Edge):
             e.expired_at AS expired_at,
             e.valid_at AS valid_at,
             e.invalid_at AS invalid_at
+        ORDER BY e.uuid DESC
+        LIMIT $limit
         """,
             group_ids=group_ids,
+            created_at=created_at,
+            limit=limit,
             database_=DEFAULT_DATABASE,
             routing_='r',
         )
@@ -365,19 +396,34 @@ class CommunityEdge(Edge):
         return edges
 
     @classmethod
-    async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
+    async def get_by_group_ids(
+        cls,
+        driver: AsyncDriver,
+        group_ids: list[str],
+        limit: int = DEFAULT_PAGE_LIMIT,
+        created_at: datetime | None = None,
+    ):
+        cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
+
         records, _, _ = await driver.execute_query(
             """
         MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
         WHERE e.group_id IN $group_ids
+        """
+            + cursor_query
+            + """
         RETURN
             e.uuid As uuid,
             e.group_id AS group_id,
             n.uuid AS source_node_uuid,
             m.uuid AS target_node_uuid,
             e.created_at AS created_at
+        ORDER BY e.uuid DESC
+        LIMIT $limit
         """,
             group_ids=group_ids,
+            created_at=created_at,
+            limit=limit,
             database_=DEFAULT_DATABASE,
             routing_='r',
         )
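
Note: the edges.py hunks above add cursor-based pagination to every get_by_group_ids classmethod. A minimal usage sketch follows; the Neo4j URI, credentials, and group id are placeholders, not taken from this diff.

import asyncio

from neo4j import AsyncGraphDatabase

from graphiti_core.edges import EntityEdge


async def main() -> None:
    driver = AsyncGraphDatabase.driver('bolt://localhost:7687', auth=('neo4j', 'password'))
    # First call returns at most DEFAULT_PAGE_LIMIT (20) edges for the group.
    page = await EntityEdge.get_by_group_ids(driver, group_ids=['my-group'])
    while page:
        for edge in page:
            print(edge.uuid, edge.created_at)
        # Cursor step: the next call only returns edges whose created_at is
        # strictly before the oldest one seen so far.
        oldest = min(e.created_at for e in page)
        page = await EntityEdge.get_by_group_ids(
            driver, group_ids=['my-group'], created_at=oldest
        )
    await driver.close()


asyncio.run(main())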
{graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/embedder/client.py
@@ -15,7 +15,7 @@ limitations under the License.
 """
 
 from abc import ABC, abstractmethod
-from typing import Iterable, List, Literal
+from collections.abc import Iterable
 
 from pydantic import BaseModel, Field
 
@@ -23,12 +23,12 @@ EMBEDDING_DIM = 1024
 
 
 class EmbedderConfig(BaseModel):
-    embedding_dim: Literal[1024] = Field(default=EMBEDDING_DIM, frozen=True)
+    embedding_dim: int = Field(default=EMBEDDING_DIM, frozen=True)
 
 
 class EmbedderClient(ABC):
     @abstractmethod
     async def create(
-        self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
+        self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
     ) -> list[float]:
         pass
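
Note: embedding_dim is widened from Literal[1024] to int, so non-default dimensions are now accepted. A one-line sketch (frozen=True still forbids mutation after construction, but the value can be supplied at init):

from graphiti_core.embedder.client import EmbedderConfig

# Previously Literal[1024] rejected anything but the default dimension.
config = EmbedderConfig(embedding_dim=768)
assert config.embedding_dim == 768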
{graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/embedder/openai.py
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-from typing import Iterable, List
+from collections.abc import Iterable
 
 from openai import AsyncOpenAI
 from openai.types import EmbeddingModel
@@ -42,7 +42,7 @@ class OpenAIEmbedder(EmbedderClient):
         self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
 
     async def create(
-        self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
+        self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
     ) -> list[float]:
         result = await self.client.embeddings.create(
             input=input_data, model=self.config.embedding_model
{graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/embedder/voyage.py
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-from typing import Iterable, List
+from collections.abc import Iterable
 
 import voyageai  # type: ignore
 from pydantic import Field
@@ -41,11 +41,11 @@ class VoyageAIEmbedder(EmbedderClient):
         self.client = voyageai.AsyncClient(api_key=config.api_key)
 
     async def create(
-        self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
+        self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
     ) -> list[float]:
         if isinstance(input_data, str):
             input_list = [input_data]
-        elif isinstance(input_data, List):
+        elif isinstance(input_data, list):
             input_list = [str(i) for i in input_data if i]
         else:
             input_list = [str(i) for i in input_data if i is not None]
{graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/graphiti.py
@@ -318,17 +318,21 @@ class Graphiti:
             previous_episodes = await self.retrieve_episodes(
                 reference_time, last_n=RELEVANT_SCHEMA_LIMIT, group_ids=[group_id]
             )
-            episode = EpisodicNode(
-                name=name,
-                group_id=group_id,
-                labels=[],
-                source=source,
-                content=episode_body,
-                source_description=source_description,
-                created_at=now,
-                valid_at=reference_time,
+
+            episode = (
+                await EpisodicNode.get_by_uuid(self.driver, uuid)
+                if uuid is not None
+                else EpisodicNode(
+                    name=name,
+                    group_id=group_id,
+                    labels=[],
+                    source=source,
+                    content=episode_body,
+                    source_description=source_description,
+                    created_at=now,
+                    valid_at=reference_time,
+                )
             )
-            episode.uuid = uuid if uuid is not None else episode.uuid
 
             # Extract entities as nodes
 
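Note: add_episode now loads an existing EpisodicNode when a uuid is supplied, instead of constructing a new node and overwriting its uuid afterwards. A caller-side sketch; the parameter names follow the fields used in the hunk, but the full add_episode signature is assumed rather than shown here.

from datetime import datetime, timezone

from graphiti_core import Graphiti  # assumed top-level export


async def reprocess(graphiti: Graphiti, episode_uuid: str) -> None:
    # With uuid set, the constructor arguments are bypassed and the stored
    # episode is fetched via EpisodicNode.get_by_uuid before processing.
    await graphiti.add_episode(
        name='support-ticket-42',
        episode_body='Customer reports login failures after the 2.3 upgrade.',
        source_description='support ticket',
        reference_time=datetime.now(timezone.utc),
        uuid=episode_uuid,
    )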
{graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/helpers.py
@@ -26,6 +26,7 @@ load_dotenv()
 DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
 USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
 MAX_REFLEXION_ITERATIONS = 2
+DEFAULT_PAGE_LIMIT = 20
 
 
 def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
{graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/llm_client/anthropic_client.py
@@ -20,6 +20,7 @@ import typing
 
 import anthropic
 from anthropic import AsyncAnthropic
+from pydantic import BaseModel
 
 from ..prompts.models import Message
 from .client import LLMClient
@@ -46,7 +47,9 @@ class AnthropicClient(LLMClient):
             max_retries=1,
         )
 
-    async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
+    async def _generate_response(
+        self, messages: list[Message], response_model: type[BaseModel] | None = None
+    ) -> dict[str, typing.Any]:
         system_message = messages[0]
         user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]] + [
             {'role': 'assistant', 'content': '{'}
{graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/llm_client/client.py
@@ -22,6 +22,7 @@ from abc import ABC, abstractmethod
 
 import httpx
 from diskcache import Cache
+from pydantic import BaseModel
 from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential
 
 from ..prompts.models import Message
@@ -66,14 +67,18 @@ class LLMClient(ABC):
         else None,
         reraise=True,
     )
-    async def _generate_response_with_retry(self, messages: list[Message]) -> dict[str, typing.Any]:
+    async def _generate_response_with_retry(
+        self, messages: list[Message], response_model: type[BaseModel] | None = None
+    ) -> dict[str, typing.Any]:
         try:
-            return await self._generate_response(messages)
+            return await self._generate_response(messages, response_model)
         except (httpx.HTTPStatusError, RateLimitError) as e:
             raise e
 
     @abstractmethod
-    async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
+    async def _generate_response(
+        self, messages: list[Message], response_model: type[BaseModel] | None = None
+    ) -> dict[str, typing.Any]:
         pass
 
     def _get_cache_key(self, messages: list[Message]) -> str:
@@ -82,7 +87,17 @@ class LLMClient(ABC):
         key_str = f'{self.model}:{message_str}'
         return hashlib.md5(key_str.encode()).hexdigest()
 
-    async def generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
+    async def generate_response(
+        self, messages: list[Message], response_model: type[BaseModel] | None = None
+    ) -> dict[str, typing.Any]:
+        if response_model is not None:
+            serialized_model = json.dumps(response_model.model_json_schema())
+            messages[
+                -1
+            ].content += (
+                f'\n\nRespond with a JSON object in the following format:\n\n{serialized_model}'
+            )
+
         if self.cache_enabled:
             cache_key = self._get_cache_key(messages)
 
@@ -91,7 +106,7 @@ class LLMClient(ABC):
             logger.debug(f'Cache hit for {cache_key}')
             return cached_response
 
-        response = await self._generate_response_with_retry(messages)
+        response = await self._generate_response_with_retry(messages, response_model)
 
         if self.cache_enabled:
             self.cache_dir.set(cache_key, response)
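
Note: generate_response now accepts an optional pydantic response_model and appends its JSON schema to the final message before dispatch. A sketch of what that rewriting produces; the Summary model and message contents are illustrative.

from pydantic import BaseModel

from graphiti_core.prompts.models import Message


class Summary(BaseModel):
    # Illustrative response model; any pydantic model can be passed.
    title: str
    bullet_points: list[str]


messages = [
    Message(role='system', content='You summarize text.'),
    Message(role='user', content='Summarize: the quarterly report ...'),
]

# Inside generate_response the last message is rewritten in place:
#   messages[-1].content += '\n\nRespond with a JSON object in the following '
#                           'format:\n\n' + json.dumps(Summary.model_json_schema())
# before _generate_response_with_retry(messages, Summary) is awaited, so every
# concrete client sees both the schema-bearing prompt and the model itself.
# response = await client.generate_response(messages, response_model=Summary)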
{graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/llm_client/errors.py
@@ -21,3 +21,11 @@ class RateLimitError(Exception):
     def __init__(self, message='Rate limit exceeded. Please try again later.'):
         self.message = message
         super().__init__(self.message)
+
+
+class RefusalError(Exception):
+    """Exception raised when the LLM refuses to generate a response."""
+
+    def __init__(self, message: str):
+        self.message = message
+        super().__init__(self.message)
{graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/llm_client/groq_client.py
@@ -21,6 +21,7 @@ import typing
 import groq
 from groq import AsyncGroq
 from groq.types.chat import ChatCompletionMessageParam
+from pydantic import BaseModel
 
 from ..prompts.models import Message
 from .client import LLMClient
@@ -43,7 +44,9 @@ class GroqClient(LLMClient):
 
         self.client = AsyncGroq(api_key=config.api_key)
 
-    async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
+    async def _generate_response(
+        self, messages: list[Message], response_model: type[BaseModel] | None = None
+    ) -> dict[str, typing.Any]:
         msgs: list[ChatCompletionMessageParam] = []
         for m in messages:
             if m.role == 'user':
{graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/llm_client/openai_client.py
@@ -14,18 +14,18 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-import json
 import logging
 import typing
 
 import openai
 from openai import AsyncOpenAI
 from openai.types.chat import ChatCompletionMessageParam
+from pydantic import BaseModel
 
 from ..prompts.models import Message
 from .client import LLMClient
 from .config import LLMConfig
-from .errors import RateLimitError
+from .errors import RateLimitError, RefusalError
 
 logger = logging.getLogger(__name__)
 
@@ -65,6 +65,10 @@ class OpenAIClient(LLMClient):
             client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
 
         """
+        # removed caching to simplify the `generate_response` override
+        if cache:
+            raise NotImplementedError('Caching is not implemented for OpenAI')
+
         if config is None:
             config = LLMConfig()
@@ -75,7 +79,9 @@
         else:
             self.client = client
 
-    async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
+    async def _generate_response(
+        self, messages: list[Message], response_model: type[BaseModel] | None = None
+    ) -> dict[str, typing.Any]:
         openai_messages: list[ChatCompletionMessageParam] = []
         for m in messages:
             if m.role == 'user':
@@ -83,17 +89,33 @@
             elif m.role == 'system':
                 openai_messages.append({'role': 'system', 'content': m.content})
         try:
-            response = await self.client.chat.completions.create(
+            response = await self.client.beta.chat.completions.parse(
                 model=self.model or DEFAULT_MODEL,
                 messages=openai_messages,
                 temperature=self.temperature,
                 max_tokens=self.max_tokens,
-                response_format={'type': 'json_object'},
+                response_format=response_model,  # type: ignore
             )
-            result = response.choices[0].message.content or ''
-            return json.loads(result)
+
+            response_object = response.choices[0].message
+
+            if response_object.parsed:
+                return response_object.parsed.model_dump()
+            elif response_object.refusal:
+                raise RefusalError(response_object.refusal)
+            else:
+                raise Exception('No response from LLM')
+        except openai.LengthFinishReasonError as e:
+            raise Exception(f'Output length exceeded max tokens {self.max_tokens}: {e}') from e
         except openai.RateLimitError as e:
             raise RateLimitError from e
         except Exception as e:
             logger.error(f'Error in generating LLM response: {e}')
             raise
+
+    async def generate_response(
+        self, messages: list[Message], response_model: type[BaseModel] | None = None
+    ) -> dict[str, typing.Any]:
+        response = await self._generate_response(messages, response_model)
+
+        return response
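
Note: the OpenAI client now routes through the structured-outputs endpoint (beta.chat.completions.parse), returns parsed.model_dump(), and converts refusals into RefusalError; its generate_response override also means the base class's schema-appending and caching are skipped for this client. A hedged caller-side sketch:

from pydantic import BaseModel

from graphiti_core.llm_client.errors import RefusalError
from graphiti_core.llm_client.openai_client import OpenAIClient
from graphiti_core.prompts.models import Message


class Extraction(BaseModel):
    # Illustrative schema for the structured-output call.
    entities: list[str]


async def extract(client: OpenAIClient) -> list[str]:
    messages = [
        Message(role='system', content='Extract entity names from the text.'),
        Message(role='user', content='Alice met Bob in Paris.'),
    ]
    try:
        # The client returns response_object.parsed.model_dump(), i.e. a dict
        # already validated against Extraction.
        result = await client.generate_response(messages, response_model=Extraction)
        return result['entities']
    except RefusalError as e:
        # Raised when the model answers with a refusal instead of parsed output.
        print(f'Model refused: {e.message}')
        return []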
{graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/nodes.py
@@ -24,10 +24,11 @@ from uuid import uuid4
 
 from neo4j import AsyncDriver
 from pydantic import BaseModel, Field
+from typing_extensions import LiteralString
 
 from graphiti_core.embedder import EmbedderClient
 from graphiti_core.errors import NodeNotFoundError
-from graphiti_core.helpers import DEFAULT_DATABASE
+from graphiti_core.helpers import DEFAULT_DATABASE, DEFAULT_PAGE_LIMIT
 from graphiti_core.models.nodes.node_db_queries import (
     COMMUNITY_NODE_SAVE,
     ENTITY_NODE_SAVE,
@@ -207,10 +208,21 @@ class EpisodicNode(Node):
         return episodes
 
     @classmethod
-    async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
+    async def get_by_group_ids(
+        cls,
+        driver: AsyncDriver,
+        group_ids: list[str],
+        limit: int = DEFAULT_PAGE_LIMIT,
+        created_at: datetime | None = None,
+    ):
+        cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
+
         records, _, _ = await driver.execute_query(
             """
         MATCH (e:Episodic) WHERE e.group_id IN $group_ids
+        """
+            + cursor_query
+            + """
         RETURN DISTINCT
             e.content AS content,
             e.created_at AS created_at,
@@ -220,8 +232,12 @@ class EpisodicNode(Node):
             e.group_id AS group_id,
             e.source_description AS source_description,
             e.source AS source
+        ORDER BY e.uuid DESC
+        LIMIT $limit
         """,
             group_ids=group_ids,
+            created_at=created_at,
+            limit=limit,
             database_=DEFAULT_DATABASE,
             routing_='r',
         )
@@ -308,10 +324,21 @@ class EntityNode(Node):
         return nodes
 
     @classmethod
-    async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
+    async def get_by_group_ids(
+        cls,
+        driver: AsyncDriver,
+        group_ids: list[str],
+        limit: int = DEFAULT_PAGE_LIMIT,
+        created_at: datetime | None = None,
+    ):
+        cursor_query: LiteralString = 'AND n.created_at < $created_at' if created_at else ''
+
         records, _, _ = await driver.execute_query(
             """
         MATCH (n:Entity) WHERE n.group_id IN $group_ids
+        """
+            + cursor_query
+            + """
         RETURN
             n.uuid As uuid,
             n.name AS name,
@@ -319,8 +346,12 @@ class EntityNode(Node):
             n.group_id AS group_id,
             n.created_at AS created_at,
             n.summary AS summary
+        ORDER BY n.uuid DESC
+        LIMIT $limit
         """,
             group_ids=group_ids,
+            created_at=created_at,
+            limit=limit,
             database_=DEFAULT_DATABASE,
             routing_='r',
         )
@@ -407,10 +438,21 @@ class CommunityNode(Node):
         return communities
 
     @classmethod
-    async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
+    async def get_by_group_ids(
+        cls,
+        driver: AsyncDriver,
+        group_ids: list[str],
+        limit: int = DEFAULT_PAGE_LIMIT,
+        created_at: datetime | None = None,
+    ):
+        cursor_query: LiteralString = 'AND n.created_at < $created_at' if created_at else ''
+
         records, _, _ = await driver.execute_query(
             """
         MATCH (n:Community) WHERE n.group_id IN $group_ids
+        """
+            + cursor_query
+            + """
         RETURN
             n.uuid As uuid,
             n.name AS name,
@@ -418,8 +460,12 @@ class CommunityNode(Node):
             n.group_id AS group_id,
             n.created_at AS created_at,
             n.summary AS summary
+        ORDER BY n.uuid DESC
+        LIMIT $limit
         """,
             group_ids=group_ids,
+            created_at=created_at,
+            limit=limit,
             database_=DEFAULT_DATABASE,
             routing_='r',
         )
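
Note: nodes.py gains the same limit/created_at pagination as edges.py. A sketch that drains a group as an async generator, passing the smallest created_at seen so far as the cursor:

from collections.abc import AsyncIterator

from neo4j import AsyncDriver

from graphiti_core.nodes import EntityNode


async def iter_entity_nodes(driver: AsyncDriver, group_id: str) -> AsyncIterator[EntityNode]:
    # Each page is capped at DEFAULT_PAGE_LIMIT (20) unless a limit is passed.
    cursor = None
    while True:
        page = await EntityNode.get_by_group_ids(driver, [group_id], created_at=cursor)
        if not page:
            return
        for node in page:
            yield node
        # The cursor predicate is strictly less-than, so the next page starts
        # below the oldest created_at in this one.
        cursor = min(n.created_at for n in page)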
{graphiti_core-0.4.2 → graphiti_core-0.5.0rc1}/graphiti_core/prompts/dedupe_edges.py
@@ -15,11 +15,30 @@ limitations under the License.
 """
 
 import json
-from typing import Any, Protocol, TypedDict
+from typing import Any, Optional, Protocol, TypedDict
+
+from pydantic import BaseModel, Field
 
 from .models import Message, PromptFunction, PromptVersion
 
 
+class EdgeDuplicate(BaseModel):
+    is_duplicate: bool = Field(..., description='true or false')
+    uuid: Optional[str] = Field(
+        None,
+        description="uuid of the existing edge like '5d643020624c42fa9de13f97b1b3fa39' or null",
+    )
+
+
+class UniqueFact(BaseModel):
+    uuid: str = Field(..., description='unique identifier of the fact')
+    fact: str = Field(..., description='fact of a unique edge')
+
+
+class UniqueFacts(BaseModel):
+    unique_facts: list[UniqueFact]
+
+
 class Prompt(Protocol):
     edge: PromptVersion
     edge_list: PromptVersion
@@ -56,12 +75,6 @@ def edge(context: dict[str, Any]) -> list[Message]:
 
         Guidelines:
         1. The facts do not need to be completely identical to be duplicates, they just need to express the same information.
-
-        Respond with a JSON object in the following format:
-        {{
-            "is_duplicate": true or false,
-            "uuid": uuid of the existing edge like "5d643020624c42fa9de13f97b1b3fa39" or null,
-        }}
         """,
         ),
     ]
@@ -90,16 +103,6 @@ def edge_list(context: dict[str, Any]) -> list[Message]:
         3. Facts will often discuss the same or similar relation between identical entities
         4. The final list should have only unique facts. If 3 facts are all duplicates of each other, only one of their
         facts should be in the response
-
-        Respond with a JSON object in the following format:
-        {{
-            "unique_facts": [
-                {{
-                    "uuid": "unique identifier of the fact",
-                    "fact": "fact of a unique edge"
-                }}
-            ]
-        }}
         """,
         ),
     ]
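
Note: the inline "Respond with a JSON object..." templates are removed because the schemas now live in pydantic models (EdgeDuplicate, UniqueFact, UniqueFacts) that callers pass via the new response_model parameter. A sketch of the intended pairing; prompt_library as the entry point is assumed from prompts/lib.py, whose changed lines are not shown in this diff.

from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
from graphiti_core.prompts.lib import prompt_library  # assumed entry point

# Build the dedupe messages as before; the context keys the edge template
# interpolates are not shown in this diff, so the dict below is a stand-in.
context: dict = {}
# messages = prompt_library.dedupe_edges.edge(context)

# The schema now travels with the request instead of being pasted into the
# prompt text:
# response = await llm_client.generate_response(messages, response_model=EdgeDuplicate)
# response['is_duplicate'] -> bool; response['uuid'] -> str | None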