graphiti-core 0.12.0rc1__py3-none-any.whl → 0.24.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. graphiti_core/cross_encoder/bge_reranker_client.py +12 -2
  2. graphiti_core/cross_encoder/gemini_reranker_client.py +161 -0
  3. graphiti_core/cross_encoder/openai_reranker_client.py +7 -5
  4. graphiti_core/decorators.py +110 -0
  5. graphiti_core/driver/__init__.py +19 -0
  6. graphiti_core/driver/driver.py +124 -0
  7. graphiti_core/driver/falkordb_driver.py +362 -0
  8. graphiti_core/driver/graph_operations/graph_operations.py +191 -0
  9. graphiti_core/driver/kuzu_driver.py +182 -0
  10. graphiti_core/driver/neo4j_driver.py +117 -0
  11. graphiti_core/driver/neptune_driver.py +305 -0
  12. graphiti_core/driver/search_interface/search_interface.py +89 -0
  13. graphiti_core/edges.py +287 -172
  14. graphiti_core/embedder/azure_openai.py +71 -0
  15. graphiti_core/embedder/client.py +2 -1
  16. graphiti_core/embedder/gemini.py +116 -22
  17. graphiti_core/embedder/voyage.py +13 -2
  18. graphiti_core/errors.py +8 -0
  19. graphiti_core/graph_queries.py +162 -0
  20. graphiti_core/graphiti.py +705 -193
  21. graphiti_core/graphiti_types.py +4 -2
  22. graphiti_core/helpers.py +87 -10
  23. graphiti_core/llm_client/__init__.py +16 -0
  24. graphiti_core/llm_client/anthropic_client.py +159 -56
  25. graphiti_core/llm_client/azure_openai_client.py +115 -0
  26. graphiti_core/llm_client/client.py +98 -21
  27. graphiti_core/llm_client/config.py +1 -1
  28. graphiti_core/llm_client/gemini_client.py +290 -41
  29. graphiti_core/llm_client/groq_client.py +14 -3
  30. graphiti_core/llm_client/openai_base_client.py +261 -0
  31. graphiti_core/llm_client/openai_client.py +56 -132
  32. graphiti_core/llm_client/openai_generic_client.py +91 -56
  33. graphiti_core/models/edges/edge_db_queries.py +259 -35
  34. graphiti_core/models/nodes/node_db_queries.py +311 -32
  35. graphiti_core/nodes.py +420 -205
  36. graphiti_core/prompts/dedupe_edges.py +46 -32
  37. graphiti_core/prompts/dedupe_nodes.py +67 -42
  38. graphiti_core/prompts/eval.py +4 -4
  39. graphiti_core/prompts/extract_edges.py +27 -16
  40. graphiti_core/prompts/extract_nodes.py +74 -31
  41. graphiti_core/prompts/prompt_helpers.py +39 -0
  42. graphiti_core/prompts/snippets.py +29 -0
  43. graphiti_core/prompts/summarize_nodes.py +23 -25
  44. graphiti_core/search/search.py +158 -82
  45. graphiti_core/search/search_config.py +39 -4
  46. graphiti_core/search/search_filters.py +126 -35
  47. graphiti_core/search/search_helpers.py +5 -6
  48. graphiti_core/search/search_utils.py +1405 -485
  49. graphiti_core/telemetry/__init__.py +9 -0
  50. graphiti_core/telemetry/telemetry.py +117 -0
  51. graphiti_core/tracer.py +193 -0
  52. graphiti_core/utils/bulk_utils.py +364 -285
  53. graphiti_core/utils/datetime_utils.py +13 -0
  54. graphiti_core/utils/maintenance/community_operations.py +67 -49
  55. graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
  56. graphiti_core/utils/maintenance/edge_operations.py +339 -197
  57. graphiti_core/utils/maintenance/graph_data_operations.py +50 -114
  58. graphiti_core/utils/maintenance/node_operations.py +319 -238
  59. graphiti_core/utils/maintenance/temporal_operations.py +11 -3
  60. graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
  61. graphiti_core/utils/text_utils.py +53 -0
  62. graphiti_core-0.24.3.dist-info/METADATA +726 -0
  63. graphiti_core-0.24.3.dist-info/RECORD +86 -0
  64. {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
  65. graphiti_core-0.12.0rc1.dist-info/METADATA +0 -350
  66. graphiti_core-0.12.0rc1.dist-info/RECORD +0 -66
  67. /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
  68. {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info/licenses}/LICENSE +0 -0
@@ -15,8 +15,18 @@ limitations under the License.
15
15
  """
16
16
 
17
17
  import asyncio
18
-
19
- from sentence_transformers import CrossEncoder
18
+ from typing import TYPE_CHECKING
19
+
20
+ if TYPE_CHECKING:
21
+ from sentence_transformers import CrossEncoder
22
+ else:
23
+ try:
24
+ from sentence_transformers import CrossEncoder
25
+ except ImportError:
26
+ raise ImportError(
27
+ 'sentence-transformers is required for BGERerankerClient. '
28
+ 'Install it with: pip install graphiti-core[sentence-transformers]'
29
+ ) from None
20
30
 
21
31
  from graphiti_core.cross_encoder.client import CrossEncoderClient
22
32
 
@@ -0,0 +1,161 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import logging
18
+ import re
19
+ from typing import TYPE_CHECKING
20
+
21
+ from ..helpers import semaphore_gather
22
+ from ..llm_client import LLMConfig, RateLimitError
23
+ from .client import CrossEncoderClient
24
+
25
+ if TYPE_CHECKING:
26
+ from google import genai
27
+ from google.genai import types
28
+ else:
29
+ try:
30
+ from google import genai
31
+ from google.genai import types
32
+ except ImportError:
33
+ raise ImportError(
34
+ 'google-genai is required for GeminiRerankerClient. '
35
+ 'Install it with: pip install graphiti-core[google-genai]'
36
+ ) from None
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+ DEFAULT_MODEL = 'gemini-2.5-flash-lite'
41
+
42
+
43
class GeminiRerankerClient(CrossEncoderClient):
    """
    Cross-encoder that ranks passages by asking a Gemini model for a relevance score.

    One generation request is issued per passage. The model is asked for an integer
    between 0 and 100, which is normalized to [0, 1] before the passages are sorted.
    """

    def __init__(
        self,
        config: LLMConfig | None = None,
        client: 'genai.Client | None' = None,
    ):
        """
        Store the configuration and the Gemini client used for scoring.

        The Gemini Developer API does not yet support logprobs. Unlike the OpenAI reranker,
        this reranker uses the Gemini API to perform direct relevance scoring of passages.
        Each passage is scored individually on a 0-100 scale.

        Args:
            config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens. A default LLMConfig() is used when omitted.
            client (genai.Client | None): An optional client instance to use. If not provided, a new genai.Client is created from config.api_key.
        """
        self.config = LLMConfig() if config is None else config
        self.client = genai.Client(api_key=self.config.api_key) if client is None else client

    @staticmethod
    def _scoring_prompt(query: str, passage: str) -> str:
        """Return the user prompt asking the model to rate one passage against the query."""
        return f"""Rate how well this passage answers or relates to the query. Use a scale from 0 to 100.

Query: {query}

Passage: {passage}

Provide only a number between 0 and 100 (no explanation, just the number):"""

    @staticmethod
    def _score_from_response(response) -> float:
        """
        Turn one model response into a relevance score in [0, 1].

        Any response that is empty, contains no 1-3 digit number, or fails to parse
        is logged as a warning and scored 0.0 rather than raising.
        """
        try:
            if not (hasattr(response, 'text') and response.text):
                logger.warning('Empty response from Gemini for passage scoring')
                return 0.0

            score_text = response.text.strip()
            # The model may wrap the number in extra words; take the first standalone number.
            score_match = re.search(r'\b(\d{1,3})\b', score_text)
            if score_match is None:
                logger.warning(f'Could not extract numeric score from response: {score_text}')
                return 0.0

            raw_score = float(score_match.group(1))
            # Scale 0-100 down to 0-1 and clamp anything outside that range.
            return max(0.0, min(1.0, raw_score / 100.0))
        except (ValueError, AttributeError) as e:
            logger.warning(f'Error parsing score from Gemini response: {e}')
            return 0.0

    async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
        """
        Rank passages by relevance to the query using direct scoring.

        Each passage is scored individually on a 0-100 scale, then normalized to [0,1].
        With zero or one passage no request is made and every passage gets 1.0.

        Returns:
            (passage, score) pairs sorted by score, highest first.

        Raises:
            RateLimitError: if the raised error's message mentions a rate limit, quota,
                resource exhaustion, or 429.
            Exception: any other failure is logged and re-raised unchanged.
        """
        if len(passages) <= 1:
            return [(passage, 1.0) for passage in passages]

        # One single-turn conversation per passage
        scoring_prompts = [
            [
                types.Content(
                    role='user',
                    parts=[types.Part.from_text(text=self._scoring_prompt(query, passage))],
                ),
            ]
            for passage in passages
        ]

        try:
            # Execute all scoring requests concurrently - O(n) API calls
            responses = await semaphore_gather(
                *(
                    self.client.aio.models.generate_content(
                        model=self.config.model or DEFAULT_MODEL,
                        contents=prompt_messages,  # type: ignore
                        config=types.GenerateContentConfig(
                            system_instruction='You are an expert at rating passage relevance. Respond with only a number from 0-100.',
                            temperature=0.0,
                            max_output_tokens=3,
                        ),
                    )
                    for prompt_messages in scoring_prompts
                )
            )

            ranked = [
                (passage, self._score_from_response(response))
                for passage, response in zip(passages, responses, strict=True)
            ]

            # Highest relevance first
            ranked.sort(key=lambda pair: pair[1], reverse=True)
            return ranked

        except Exception as e:
            # Rate limiting is recognized from the error text
            message = str(e)
            lowered = message.lower()
            rate_limit_markers = ('rate limit', 'quota', 'resource_exhausted')
            if any(marker in lowered for marker in rate_limit_markers) or '429' in message:
                raise RateLimitError from e

            logger.error(f'Error in generating LLM response: {e}')
            raise
@@ -22,7 +22,7 @@ import openai
22
22
  from openai import AsyncAzureOpenAI, AsyncOpenAI
23
23
 
24
24
  from ..helpers import semaphore_gather
25
- from ..llm_client import LLMConfig, RateLimitError
25
+ from ..llm_client import LLMConfig, OpenAIClient, RateLimitError
26
26
  from ..prompts import Message
27
27
  from .client import CrossEncoderClient
28
28
 
@@ -35,7 +35,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
35
35
  def __init__(
36
36
  self,
37
37
  config: LLMConfig | None = None,
38
- client: AsyncOpenAI | AsyncAzureOpenAI | None = None,
38
+ client: AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None = None,
39
39
  ):
40
40
  """
41
41
  Initialize the OpenAIRerankerClient with the provided configuration and client.
@@ -45,7 +45,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
45
45
 
46
46
  Args:
47
47
  config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
48
- client (AsyncOpenAI | AsyncAzureOpenAI | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
48
+ client (AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
49
49
  """
50
50
  if config is None:
51
51
  config = LLMConfig()
@@ -53,6 +53,8 @@ class OpenAIRerankerClient(CrossEncoderClient):
53
53
  self.config = config
54
54
  if client is None:
55
55
  self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
56
+ elif isinstance(client, OpenAIClient):
57
+ self.client = client.client
56
58
  else:
57
59
  self.client = client
58
60
 
@@ -82,7 +84,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
82
84
  responses = await semaphore_gather(
83
85
  *[
84
86
  self.client.chat.completions.create(
85
- model=DEFAULT_MODEL,
87
+ model=self.config.model or DEFAULT_MODEL,
86
88
  messages=openai_messages,
87
89
  temperature=0,
88
90
  max_tokens=1,
@@ -106,7 +108,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
106
108
  if len(top_logprobs) == 0:
107
109
  continue
108
110
  norm_logprobs = np.exp(top_logprobs[0].logprob)
109
- if bool(top_logprobs[0].token):
111
+ if top_logprobs[0].token.strip().split(' ')[0].lower() == 'true':
110
112
  scores.append(norm_logprobs)
111
113
  else:
112
114
  scores.append(1 - norm_logprobs)
@@ -0,0 +1,110 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import functools
18
+ import inspect
19
+ from collections.abc import Awaitable, Callable
20
+ from typing import Any, TypeVar
21
+
22
+ from graphiti_core.driver.driver import GraphProvider
23
+ from graphiti_core.helpers import semaphore_gather
24
+ from graphiti_core.search.search_config import SearchResults
25
+
26
+ F = TypeVar('F', bound=Callable[..., Awaitable[Any]])
27
+
28
+
29
def handle_multiple_group_ids(func: F) -> F:
    """
    Decorator for FalkorDB methods that need to handle multiple group_ids.

    When the decorated method is called on an object whose ``self.clients.driver``
    reports ``GraphProvider.FALKORDB`` and more than one group id is supplied, the
    method is run once per group id (each call gets ``group_ids=[gid]`` and
    ``driver=driver.clone(database=gid)``) and the results are merged:

    * ``SearchResults`` -> ``SearchResults.merge(results)``
    * ``list``          -> one flat list
    * ``tuple``         -> merged component by component (list components are
      flattened, other components are collected into a list)
    * anything else     -> the list of per-group results

    In every other situation the method is called once, unchanged.

    ``group_ids`` may be passed by keyword or positionally. A positional
    ``group_ids`` is replaced in its own slot so that the positional arguments
    after it keep their positions.
    """
    # The signature is fixed, so resolve the parameter position once at decoration time.
    group_ids_func_pos = get_parameter_position(func, 'group_ids')
    # Index into *args: the signature position counts `self`, while *args does not.
    group_ids_pos = group_ids_func_pos - 1 if group_ids_func_pos is not None else None

    @functools.wraps(func)
    async def wrapper(self, *args, **kwargs):
        in_args = group_ids_pos is not None and len(args) > group_ids_pos

        group_ids = kwargs.get('group_ids')
        # If not in kwargs and position exists, get from args
        if group_ids is None and in_args:
            group_ids = args[group_ids_pos]

        # Only handle FalkorDB with multiple group_ids
        if (
            hasattr(self, 'clients')
            and hasattr(self.clients, 'driver')
            and self.clients.driver.provider == GraphProvider.FALKORDB
            and group_ids
            and len(group_ids) > 1
        ):
            driver = self.clients.driver

            async def execute_for_group(gid: str):
                call_args = list(args)
                call_kwargs = {**kwargs, 'driver': driver.clone(database=gid)}
                if in_args:
                    # Overwrite the positional slot in place. Removing it would
                    # shift every later positional argument onto the wrong parameter.
                    call_args[group_ids_pos] = [gid]
                    call_kwargs.pop('group_ids', None)
                else:
                    call_kwargs['group_ids'] = [gid]

                return await func(self, *call_args, **call_kwargs)

            # Execute for each group_id concurrently
            results = await semaphore_gather(
                *[execute_for_group(gid) for gid in group_ids],
                max_coroutines=getattr(self, 'max_coroutines', None),
            )

            # Merge results based on the type of the first result
            first = results[0]
            if isinstance(first, SearchResults):
                return SearchResults.merge(results)
            if isinstance(first, list):
                return [item for result in results for item in result]
            if isinstance(first, tuple):
                # Handle tuple outputs (like build_communities returning (nodes, edges))
                merged_tuple = []
                for i in range(len(first)):
                    component_results = [result[i] for result in results]
                    if isinstance(component_results[0], list):
                        merged_tuple.append(
                            [item for component in component_results for item in component]
                        )
                    else:
                        merged_tuple.append(component_results)
                return tuple(merged_tuple)
            return results

        # Normal execution
        return await func(self, *args, **kwargs)

    return wrapper  # type: ignore
99
+
100
+
101
+ def get_parameter_position(func: Callable, param_name: str) -> int | None:
102
+ """
103
+ Returns the positional index of a parameter in the function signature.
104
+ If the parameter is not found, returns None.
105
+ """
106
+ sig = inspect.signature(func)
107
+ for idx, (name, _param) in enumerate(sig.parameters.items()):
108
+ if name == param_name:
109
+ return idx
110
+ return None
@@ -0,0 +1,19 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ from neo4j import Neo4jDriver
18
+
19
+ __all__ = ['Neo4jDriver']
@@ -0,0 +1,124 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import copy
18
+ import logging
19
+ import os
20
+ from abc import ABC, abstractmethod
21
+ from collections.abc import Coroutine
22
+ from enum import Enum
23
+ from typing import Any
24
+
25
+ from dotenv import load_dotenv
26
+
27
+ from graphiti_core.driver.graph_operations.graph_operations import GraphOperationsInterface
28
+ from graphiti_core.driver.search_interface.search_interface import SearchInterface
29
+
30
logger = logging.getLogger(__name__)

# Not referenced in the visible part of this module.
DEFAULT_SIZE = 10

# Runs before the lookups below; presumably so values from a .env file are
# already in the process environment when they are read.
load_dotenv()

# Index names, each overridable through an environment variable of the same name.
ENTITY_INDEX_NAME = os.getenv('ENTITY_INDEX_NAME', 'entities')
EPISODE_INDEX_NAME = os.getenv('EPISODE_INDEX_NAME', 'episodes')
COMMUNITY_INDEX_NAME = os.getenv('COMMUNITY_INDEX_NAME', 'communities')
ENTITY_EDGE_INDEX_NAME = os.getenv('ENTITY_EDGE_INDEX_NAME', 'entity_edges')
40
+
41
+
42
class GraphProvider(Enum):
    """Closed set of graph database backends, each identified by a lowercase string value."""

    NEO4J = 'neo4j'
    FALKORDB = 'falkordb'
    KUZU = 'kuzu'
    NEPTUNE = 'neptune'
47
+
48
+
49
class GraphDriverSession(ABC):
    """
    Abstract async session for a graph database.

    Usable as an async context manager: entering yields the session itself, and
    each concrete subclass defines what happens on exit. All remaining
    operations are abstract and must be provided by subclasses.
    """

    # Backend identifier. Only annotated here; no value is assigned in this class.
    provider: GraphProvider

    async def __aenter__(self):
        """Enter the async context and return this session."""
        return self

    @abstractmethod
    async def __aexit__(self, exc_type, exc, tb):
        """Leave the async context. Subclasses must override this, even if only with a no-op."""
        # Some backends (e.g. Falkor) need no cleanup, but the method must still exist.
        return None

    @abstractmethod
    async def run(self, query: str, **kwargs: Any) -> Any:
        """Run ``query`` with the given keyword parameters. Must be overridden."""
        raise NotImplementedError()

    @abstractmethod
    async def close(self):
        """Close the session. Must be overridden."""
        raise NotImplementedError()

    @abstractmethod
    async def execute_write(self, func, *args, **kwargs):
        """Execute ``func`` with the given arguments as a write operation. Must be overridden."""
        raise NotImplementedError()
71
+
72
+
73
class GraphDriver(ABC):
    """
    Abstract base for graph database drivers.

    Subclasses must implement query execution, session creation, closing,
    index deletion and index/constraint creation. ``with_database``, ``clone``
    and ``build_fulltext_query`` have base implementations described on each method.
    """

    # Backend identifier. Only annotated here; no value is assigned in this class.
    provider: GraphProvider
    # Neo4j (default) syntax does not require a prefix for fulltext queries
    fulltext_syntax: str = ''
    # Name of the default database. Only annotated here.
    _database: str
    default_group_id: str = ''
    # Optional interfaces, unset by default.
    search_interface: SearchInterface | None = None
    graph_operations_interface: GraphOperationsInterface | None = None

    @abstractmethod
    def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
        """Return a coroutine that executes ``cypher_query_`` with ``kwargs``. Must be overridden."""
        raise NotImplementedError()

    @abstractmethod
    def session(self, database: str | None = None) -> GraphDriverSession:
        """Return a session, optionally for a specific ``database``. Must be overridden."""
        raise NotImplementedError()

    @abstractmethod
    def close(self):
        """Close the driver. Must be overridden."""
        raise NotImplementedError()

    @abstractmethod
    def delete_all_indexes(self) -> Coroutine:
        """Return a coroutine that deletes all indexes. Must be overridden."""
        raise NotImplementedError()

    def with_database(self, database: str) -> 'GraphDriver':
        """
        Returns a shallow copy of this driver with a different default database.
        Reuses the same connection (e.g. FalkorDB, Neo4j).
        """
        # copy.copy is shallow: attributes are shared with the original and
        # only `_database` is rebound on the copy.
        driver_copy = copy.copy(self)
        driver_copy._database = database
        return driver_copy

    @abstractmethod
    async def build_indices_and_constraints(self, delete_existing: bool = False):
        """Create indices and constraints; ``delete_existing`` defaults to False. Must be overridden."""
        raise NotImplementedError()

    def clone(self, database: str) -> 'GraphDriver':
        """
        Clone the driver with a different database or graph name.

        This base implementation ignores ``database`` and returns the same
        instance. Providers that need a distinct driver per database are
        expected to override it.
        """
        return self

    def build_fulltext_query(
        self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128
    ) -> str:
        """
        Specific fulltext query builder for database providers.
        Only implemented by providers that need custom fulltext query building.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError(f'build_fulltext_query not implemented for {self.provider}')