graphiti-core 0.2.3__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

graphiti_core/edges.py CHANGED
@@ -24,6 +24,7 @@ from uuid import uuid4
 from neo4j import AsyncDriver
 from pydantic import BaseModel, Field
 
+from graphiti_core.errors import EdgeNotFoundError
 from graphiti_core.helpers import parse_db_date
 from graphiti_core.llm_client.config import EMBEDDING_DIM
 from graphiti_core.nodes import Node
@@ -41,8 +42,18 @@ class Edge(BaseModel, ABC):
     @abstractmethod
     async def save(self, driver: AsyncDriver): ...
 
-    @abstractmethod
-    async def delete(self, driver: AsyncDriver): ...
+    async def delete(self, driver: AsyncDriver):
+        result = await driver.execute_query(
+            """
+        MATCH (n)-[e {uuid: $uuid}]->(m)
+        DELETE e
+        """,
+            uuid=self.uuid,
+        )
+
+        logger.info(f'Deleted Edge: {self.uuid}')
+
+        return result
 
     def __hash__(self):
         return hash(self.uuid)
@@ -76,19 +87,6 @@ class EpisodicEdge(Edge):
 
         return result
 
-    async def delete(self, driver: AsyncDriver):
-        result = await driver.execute_query(
-            """
-        MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
-        DELETE e
-        """,
-            uuid=self.uuid,
-        )
-
-        logger.info(f'Deleted Edge: {self.uuid}')
-
-        return result
-
     @classmethod
     async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
         records, _, _ = await driver.execute_query(
@@ -107,7 +105,8 @@ class EpisodicEdge(Edge):
         edges = [get_episodic_edge_from_record(record) for record in records]
 
         logger.info(f'Found Edge: {uuid}')
-
+        if len(edges) == 0:
+            raise EdgeNotFoundError(uuid)
         return edges[0]
 
 
@@ -169,19 +168,6 @@ class EntityEdge(Edge):
 
         return result
 
-    async def delete(self, driver: AsyncDriver):
-        result = await driver.execute_query(
-            """
-        MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
-        DELETE e
-        """,
-            uuid=self.uuid,
-        )
-
-        logger.info(f'Deleted Edge: {self.uuid}')
-
-        return result
-
     @classmethod
     async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
         records, _, _ = await driver.execute_query(
@@ -206,6 +192,49 @@ class EntityEdge(Edge):
 
         edges = [get_entity_edge_from_record(record) for record in records]
 
+        logger.info(f'Found Edge: {uuid}')
+        if len(edges) == 0:
+            raise EdgeNotFoundError(uuid)
+        return edges[0]
+
+
+class CommunityEdge(Edge):
+    async def save(self, driver: AsyncDriver):
+        result = await driver.execute_query(
+            """
+        MATCH (community:Community {uuid: $community_uuid})
+        MATCH (node:Entity | Community {uuid: $entity_uuid})
+        MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
+        SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
+        RETURN r.uuid AS uuid""",
+            community_uuid=self.source_node_uuid,
+            entity_uuid=self.target_node_uuid,
+            uuid=self.uuid,
+            group_id=self.group_id,
+            created_at=self.created_at,
+        )
+
+        logger.info(f'Saved edge to neo4j: {self.uuid}')
+
+        return result
+
+    @classmethod
+    async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
+        records, _, _ = await driver.execute_query(
+            """
+        MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m:Entity | Community)
+        RETURN
+            e.uuid As uuid,
+            e.group_id AS group_id,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            e.created_at AS created_at
+        """,
+            uuid=uuid,
+        )
+
+        edges = [get_community_edge_from_record(record) for record in records]
+
         logger.info(f'Found Edge: {uuid}')
 
         return edges[0]
@@ -237,3 +266,13 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
         valid_at=parse_db_date(record['valid_at']),
         invalid_at=parse_db_date(record['invalid_at']),
     )
+
+
+def get_community_edge_from_record(record: Any):
+    return CommunityEdge(
+        uuid=record['uuid'],
+        group_id=record['group_id'],
+        source_node_uuid=record['source_node_uuid'],
+        target_node_uuid=record['target_node_uuid'],
+        created_at=record['created_at'].to_native(),
+    )
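The edges.py changes move delete() from the concrete subclasses up to the Edge base class (the Cypher pattern drops its label constraints, so one query covers MENTIONS, RELATES_TO, and the new HAS_MEMBER relationships), and get_by_uuid now raises EdgeNotFoundError on an empty result instead of failing with a bare IndexError from edges[0]. A minimal caller-side sketch, assuming an already-connected neo4j AsyncDriver; the helper function is invented for illustration:

    from neo4j import AsyncDriver

    from graphiti_core.edges import EntityEdge
    from graphiti_core.errors import EdgeNotFoundError

    async def delete_edge_if_present(driver: AsyncDriver, uuid: str) -> None:
        # Hypothetical helper for illustration.
        try:
            edge = await EntityEdge.get_by_uuid(driver, uuid)
        except EdgeNotFoundError:
            return  # 0.2.3 would have surfaced this as an IndexError
        await edge.delete(driver)  # inherited from Edge as of 0.3.1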
graphiti_core/errors.py ADDED
@@ -0,0 +1,43 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+
+class GraphitiError(Exception):
+    """Base exception class for Graphiti Core."""
+
+
+class EdgeNotFoundError(GraphitiError):
+    """Raised when an edge is not found."""
+
+    def __init__(self, uuid: str):
+        self.message = f'edge {uuid} not found'
+        super().__init__(self.message)
+
+
+class NodeNotFoundError(GraphitiError):
+    """Raised when a node is not found."""
+
+    def __init__(self, uuid: str):
+        self.message = f'node {uuid} not found'
+        super().__init__(self.message)
+
+
+class SearchRerankerError(GraphitiError):
+    """Raised when a search reranker fails."""
+
+    def __init__(self, text: str):
+        self.message = text
+        super().__init__(self.message)
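Every exception in the new module derives from GraphitiError, so callers can handle one specific failure or the whole family with a single except clause. A short sketch under that assumption:

    from neo4j import AsyncDriver

    from graphiti_core.edges import EntityEdge
    from graphiti_core.errors import EdgeNotFoundError, GraphitiError

    async def edge_exists(driver: AsyncDriver, uuid: str) -> bool:
        try:
            await EntityEdge.get_by_uuid(driver, uuid)
            return True
        except EdgeNotFoundError:
            return False  # the specific subclass; .message reads 'edge <uuid> not found'
        except GraphitiError as e:
            # base class catches any other library-defined failure
            raise RuntimeError(f'graphiti failure: {e}') from e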
graphiti_core/graphiti.py CHANGED
@@ -24,14 +24,19 @@ from neo4j import AsyncGraphDatabase
 
 from graphiti_core.edges import EntityEdge, EpisodicEdge
 from graphiti_core.llm_client import LLMClient, OpenAIClient
-from graphiti_core.llm_client.utils import generate_embedding
 from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
-from graphiti_core.search.search import Reranker, SearchConfig, SearchMethod, hybrid_search
+from graphiti_core.search.search import SearchConfig, search
+from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
+from graphiti_core.search.search_config_recipes import (
+    EDGE_HYBRID_SEARCH_NODE_DISTANCE,
+    EDGE_HYBRID_SEARCH_RRF,
+    NODE_HYBRID_SEARCH_NODE_DISTANCE,
+    NODE_HYBRID_SEARCH_RRF,
+)
 from graphiti_core.search.search_utils import (
     RELEVANT_SCHEMA_LIMIT,
     get_relevant_edges,
     get_relevant_nodes,
-    hybrid_node_search,
 )
 from graphiti_core.utils import (
     build_episodic_edges,
@@ -46,6 +51,10 @@ from graphiti_core.utils.bulk_utils import (
     resolve_edge_pointers,
     retrieve_previous_episodes_bulk,
 )
+from graphiti_core.utils.maintenance.community_operations import (
+    build_communities,
+    remove_communities,
+)
 from graphiti_core.utils.maintenance.edge_operations import (
     extract_edges,
     resolve_extracted_edges,
@@ -412,7 +421,7 @@ class Graphiti:
         except Exception as e:
             raise e
 
-    async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str | None):
+    async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str | None = None):
         """
         Process multiple episodes in bulk and update the graph.
 
@@ -526,12 +535,25 @@ class Graphiti:
         except Exception as e:
             raise e
 
+    async def build_communities(self):
+        embedder = self.llm_client.get_embedder()
+
+        # Clear existing communities
+        await remove_communities(self.driver)
+
+        community_nodes, community_edges = await build_communities(self.driver, self.llm_client)
+
+        await asyncio.gather(*[node.generate_name_embedding(embedder) for node in community_nodes])
+
+        await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
+        await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
+
     async def search(
         self,
         query: str,
         center_node_uuid: str | None = None,
         group_ids: list[str | None] | None = None,
-        num_results=10,
+        num_results=DEFAULT_SEARCH_LIMIT,
     ):
         """
         Perform a hybrid search on the knowledge graph.
@@ -547,7 +569,7 @@ class Graphiti:
             Facts will be reranked based on proximity to this node
         group_ids : list[str | None] | None, optional
             The graph partitions to return data from.
-        num_results : int, optional
+        limit : int, optional
             The maximum number of results to return. Defaults to 10.
 
         Returns
@@ -564,21 +586,17 @@ class Graphiti:
         The search is performed using the current date and time as the reference
         point for temporal relevance.
         """
-        reranker = Reranker.rrf if center_node_uuid is None else Reranker.node_distance
-        search_config = SearchConfig(
-            num_episodes=0,
-            num_edges=num_results,
-            num_nodes=0,
-            group_ids=group_ids,
-            search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
-            reranker=reranker,
+        search_config = (
+            EDGE_HYBRID_SEARCH_RRF if center_node_uuid is None else EDGE_HYBRID_SEARCH_NODE_DISTANCE
         )
+        search_config.limit = num_results
+
         edges = (
-            await hybrid_search(
+            await search(
                 self.driver,
                 self.llm_client.get_embedder(),
                 query,
-                datetime.now(),
+                group_ids,
                 search_config,
                 center_node_uuid,
             )
@@ -589,19 +607,20 @@ class Graphiti:
     async def _search(
         self,
         query: str,
-        timestamp: datetime,
         config: SearchConfig,
+        group_ids: list[str | None] | None = None,
         center_node_uuid: str | None = None,
-    ):
-        return await hybrid_search(
-            self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
+    ) -> SearchResults:
+        return await search(
+            self.driver, self.llm_client.get_embedder(), query, group_ids, config, center_node_uuid
         )
 
     async def get_nodes_by_query(
         self,
         query: str,
+        center_node_uuid: str | None = None,
         group_ids: list[str | None] | None = None,
-        limit: int = RELEVANT_SCHEMA_LIMIT,
+        limit: int = DEFAULT_SEARCH_LIMIT,
     ) -> list[EntityNode]:
         """
         Retrieve nodes from the graph database based on a text query.
@@ -612,7 +631,9 @@ class Graphiti:
         Parameters
         ----------
         query : str
-            The text query to search for in the graph.
+            The text query to search for in the graph
+        center_node_uuid: str, optional
+            Facts will be reranked based on proximity to this node.
         group_ids : list[str | None] | None, optional
             The graph partitions to return data from.
         limit : int | None, optional
@@ -638,8 +659,12 @@ class Graphiti:
             If not specified, a default limit (defined in the search functions) will be used.
         """
         embedder = self.llm_client.get_embedder()
-        query_embedding = await generate_embedding(embedder, query)
-        relevant_nodes = await hybrid_node_search(
-            [query], [query_embedding], self.driver, group_ids, limit
+        search_config = (
+            NODE_HYBRID_SEARCH_RRF if center_node_uuid is None else NODE_HYBRID_SEARCH_NODE_DISTANCE
         )
-        return relevant_nodes
+        search_config.limit = limit
+
+        nodes = (
+            await search(self.driver, embedder, query, group_ids, search_config, center_node_uuid)
+        ).nodes
+        return nodes
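search() now selects a prebuilt recipe (EDGE_HYBRID_SEARCH_RRF by default, EDGE_HYBRID_SEARCH_NODE_DISTANCE when a center node is supplied) and forwards group_ids where the old signature carried a timestamp; build_communities() is new in this release. A usage sketch, assuming a connected Graphiti instance and that Graphiti is importable from the package root:

    from graphiti_core import Graphiti  # import path assumed

    async def demo(graphiti: Graphiti):
        # New in 0.3.1: cluster the graph and persist Community nodes/edges.
        await graphiti.build_communities()

        # Edge search; RRF reranking because no center node is given.
        edges = await graphiti.search('Who acquired Acme?', num_results=5)

        # Node search, reranked by graph distance to the center node.
        nodes = await graphiti.get_nodes_by_query(
            'Acme', center_node_uuid='<a-node-uuid>', limit=10
        )
        return edges, nodes

One caveat visible in the diff: search_config.limit = num_results mutates the shared recipe constant in place, so concurrent searches requesting different limits can interfere with each other.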
graphiti_core/helpers.py CHANGED
@@ -1,3 +1,19 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 from datetime import datetime
 
 from neo4j import time as neo4j_time
graphiti_core/llm_client/__init__.py CHANGED
@@ -1,5 +1,6 @@
 from .client import LLMClient
 from .config import LLMConfig
+from .errors import RateLimitError
 from .openai_client import OpenAIClient
 
-__all__ = ['LLMClient', 'OpenAIClient', 'LLMConfig']
+__all__ = ['LLMClient', 'OpenAIClient', 'LLMConfig', 'RateLimitError']
graphiti_core/llm_client/anthropic_client.py CHANGED
@@ -18,12 +18,14 @@ import json
 import logging
 import typing
 
+import anthropic
 from anthropic import AsyncAnthropic
 from openai import AsyncOpenAI
 
 from ..prompts.models import Message
 from .client import LLMClient
 from .config import LLMConfig
+from .errors import RateLimitError
 
 logger = logging.getLogger(__name__)
 
@@ -35,7 +37,11 @@ class AnthropicClient(LLMClient):
         if config is None:
             config = LLMConfig()
         super().__init__(config, cache)
-        self.client = AsyncAnthropic(api_key=config.api_key)
+        self.client = AsyncAnthropic(
+            api_key=config.api_key,
+            # we'll use tenacity to retry
+            max_retries=1,
+        )
 
     def get_embedder(self) -> typing.Any:
         openai_client = AsyncOpenAI()
@@ -58,6 +64,8 @@ class AnthropicClient(LLMClient):
             )
 
             return json.loads('{' + result.content[0].text)  # type: ignore
+        except anthropic.RateLimitError as e:
+            raise RateLimitError from e
         except Exception as e:
             logger.error(f'Error in generating LLM response: {e}')
             raise
graphiti_core/llm_client/client.py CHANGED
@@ -22,10 +22,11 @@ from abc import ABC, abstractmethod
 
 import httpx
 from diskcache import Cache
-from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
+from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential
 
 from ..prompts.models import Message
 from .config import LLMConfig
+from .errors import RateLimitError
 
 DEFAULT_TEMPERATURE = 0
 DEFAULT_CACHE_DIR = './llm_cache'
@@ -33,7 +34,10 @@ DEFAULT_CACHE_DIR = './llm_cache'
 logger = logging.getLogger(__name__)
 
 
-def is_server_error(exception):
+def is_server_or_retry_error(exception):
+    if isinstance(exception, RateLimitError):
+        return True
+
     return (
         isinstance(exception, httpx.HTTPStatusError) and 500 <= exception.response.status_code < 600
     )
@@ -56,18 +60,21 @@ class LLMClient(ABC):
         pass
 
     @retry(
-        stop=stop_after_attempt(3),
-        wait=wait_exponential(multiplier=1, min=4, max=10),
-        retry=retry_if_exception(is_server_error),
+        stop=stop_after_attempt(4),
+        wait=wait_random_exponential(multiplier=10, min=5, max=120),
+        retry=retry_if_exception(is_server_or_retry_error),
+        after=lambda retry_state: logger.warning(
+            f'Retrying {retry_state.fn.__name__ if retry_state.fn else "function"} after {retry_state.attempt_number} attempts...'
+        )
+        if retry_state.attempt_number > 1
+        else None,
+        reraise=True,
     )
     async def _generate_response_with_retry(self, messages: list[Message]) -> dict[str, typing.Any]:
         try:
             return await self._generate_response(messages)
-        except httpx.HTTPStatusError as e:
-            if not is_server_error(e):
-                raise Exception(f'LLM request error: {e}') from e
-            else:
-                raise
+        except (httpx.HTTPStatusError, RateLimitError) as e:
+            raise e
 
     @abstractmethod
     async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
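The retry policy widens from three attempts on HTTP 5xx only to four attempts on either HTTP 5xx or a rate-limit error, with randomized exponential backoff capped at 120 seconds, a warning log between attempts, and reraise=True so the last failure surfaces as the original exception rather than tenacity's RetryError. A self-contained sketch of the same tenacity pattern; the always-failing function and the stand-in error class are invented for illustration:

    from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential

    class RateLimitError(Exception):
        """Stand-in for graphiti_core.llm_client.errors.RateLimitError."""

    @retry(
        stop=stop_after_attempt(4),  # one initial call plus up to three retries
        wait=wait_random_exponential(multiplier=10, min=5, max=120),  # jittered 5-120 s backoff
        retry=retry_if_exception(lambda e: isinstance(e, RateLimitError)),
        reraise=True,  # re-raise RateLimitError itself, not tenacity.RetryError
    )
    async def flaky_call() -> str:
        raise RateLimitError('simulated 429')  # always rate-limited in this demo

    # asyncio.run(flaky_call()) sleeps between attempts, then re-raises RateLimitError.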
graphiti_core/llm_client/errors.py ADDED
@@ -0,0 +1,23 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+
+class RateLimitError(Exception):
+    """Exception raised when the rate limit is exceeded."""
+
+    def __init__(self, message='Rate limit exceeded. Please try again later.'):
+        self.message = message
+        super().__init__(self.message)
graphiti_core/llm_client/groq_client.py CHANGED
@@ -18,6 +18,7 @@ import json
 import logging
 import typing
 
+import groq
 from groq import AsyncGroq
 from groq.types.chat import ChatCompletionMessageParam
 from openai import AsyncOpenAI
@@ -25,6 +26,7 @@ from openai import AsyncOpenAI
 from ..prompts.models import Message
 from .client import LLMClient
 from .config import LLMConfig
+from .errors import RateLimitError
 
 logger = logging.getLogger(__name__)
 
@@ -59,6 +61,8 @@ class GroqClient(LLMClient):
             )
             result = response.choices[0].message.content or ''
             return json.loads(result)
+        except groq.RateLimitError as e:
+            raise RateLimitError from e
         except Exception as e:
             logger.error(f'Error in generating LLM response: {e}')
             raise
graphiti_core/llm_client/openai_client.py CHANGED
@@ -18,12 +18,14 @@ import json
 import logging
 import typing
 
+import openai
 from openai import AsyncOpenAI
 from openai.types.chat import ChatCompletionMessageParam
 
 from ..prompts.models import Message
 from .client import LLMClient
 from .config import LLMConfig
+from .errors import RateLimitError
 
 logger = logging.getLogger(__name__)
 
@@ -59,6 +61,8 @@ class OpenAIClient(LLMClient):
             )
             result = response.choices[0].message.content or ''
             return json.loads(result)
+        except openai.RateLimitError as e:
+            raise RateLimitError from e
         except Exception as e:
             logger.error(f'Error in generating LLM response: {e}')
             raise
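Each provider client now translates its SDK's own rate-limit exception (anthropic.RateLimitError, groq.RateLimitError, openai.RateLimitError) into the library-level RateLimitError, the type the retry predicate in client.py treats as retryable, so the retry loop and callers see a single error type regardless of backend. The same normalization pattern, sketched with stand-in classes:

    class ProviderRateLimitError(Exception):
        """Stand-in for an SDK error such as openai.RateLimitError."""

    class RateLimitError(Exception):
        """Stand-in for graphiti_core.llm_client.errors.RateLimitError."""

    async def _generate_response() -> dict:
        try:
            raise ProviderRateLimitError('429 Too Many Requests')  # simulated SDK call
        except ProviderRateLimitError as e:
            # Normalize to the library-level error; the retry predicate keys on this type.
            raise RateLimitError from e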
graphiti_core/llm_client/utils.py CHANGED
@@ -1,3 +1,19 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 import logging
 import typing
 from time import time
@@ -17,6 +33,6 @@ async def generate_embedding(
     embedding = embedding[:EMBEDDING_DIM]
 
     end = time()
-    logger.debug(f'embedded text of length {len(text)} in {end-start} ms')
+    logger.debug(f'embedded text of length {len(text)} in {end - start} ms')
 
     return embedding