graphiti-core 0.13.2__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. Click here for more details.

@@ -15,6 +15,7 @@ limitations under the License.
15
15
  """
16
16
 
17
17
  from .client import CrossEncoderClient
18
+ from .gemini_reranker_client import GeminiRerankerClient
18
19
  from .openai_reranker_client import OpenAIRerankerClient
19
20
 
20
- __all__ = ['CrossEncoderClient', 'OpenAIRerankerClient']
21
+ __all__ = ['CrossEncoderClient', 'GeminiRerankerClient', 'OpenAIRerankerClient']
@@ -0,0 +1,146 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import logging
18
+ import re
19
+
20
+ from google import genai # type: ignore
21
+ from google.genai import types # type: ignore
22
+
23
+ from ..helpers import semaphore_gather
24
+ from ..llm_client import LLMConfig, RateLimitError
25
+ from .client import CrossEncoderClient
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+ DEFAULT_MODEL = 'gemini-2.5-flash-lite-preview-06-17'
30
+
31
+
32
class GeminiRerankerClient(CrossEncoderClient):
    """Cross-encoder that asks a Gemini model for a 0-100 relevance score per passage."""

    def __init__(
        self,
        config: LLMConfig | None = None,
        client: genai.Client | None = None,
    ):
        """
        Build a reranker from an LLM configuration and, optionally, an existing client.

        The Gemini Developer API does not yet support logprobs, so (unlike the OpenAI
        reranker) relevance is obtained by prompting the model for a direct score.
        Every passage is scored on its own, on a 0-100 scale.

        Args:
            config (LLMConfig | None): LLM settings (API key, model, ...). A default
                LLMConfig() is used when omitted.
            client (genai.Client | None): Client to reuse. When omitted, a new
                genai.Client is created from the configured API key.
        """
        self.config = LLMConfig() if config is None else config
        self.client = genai.Client(api_key=self.config.api_key) if client is None else client

    @staticmethod
    def _build_contents(query: str, passage: str) -> list:
        """Return the single-turn user message asking for a score of `passage` against `query`."""
        prompt = f"""Rate how well this passage answers or relates to the query. Use a scale from 0 to 100.

Query: {query}

Passage: {passage}

Provide only a number between 0 and 100 (no explanation, just the number):"""
        return [
            types.Content(
                role='user',
                parts=[types.Part.from_text(text=prompt)],
            ),
        ]

    @staticmethod
    def _extract_score(response) -> float:
        """
        Turn one model response into a score in [0, 1].

        Empty responses, responses without a 1-3 digit number, and parsing errors
        all log a warning and yield 0.0.
        """
        try:
            if not (hasattr(response, 'text') and response.text):
                logger.warning('Empty response from Gemini for passage scoring')
                return 0.0

            score_text = response.text.strip()
            # The model may wrap the number in extra text; take the first standalone 1-3 digit run.
            found = re.search(r'\b(\d{1,3})\b', score_text)
            if found is None:
                logger.warning(f'Could not extract numeric score from response: {score_text}')
                return 0.0

            # Scale 0-100 down to 0-1 and clamp anything out of range.
            return max(0.0, min(1.0, float(found.group(1)) / 100.0))
        except (ValueError, AttributeError) as e:
            logger.warning(f'Error parsing score from Gemini response: {e}')
            return 0.0

    async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
        """
        Order passages by their relevance to the query, most relevant first.

        One scoring request is issued per passage (run concurrently through
        semaphore_gather); each 0-100 answer is normalized to [0, 1]. With zero or
        one passage no request is made and each passage gets a score of 1.0.

        Raises:
            RateLimitError: If the failure message looks like a rate-limit / quota error.
            Exception: Any other error is logged and re-raised unchanged.
        """
        if len(passages) <= 1:
            return [(passage, 1.0) for passage in passages]

        scoring_prompts = [self._build_contents(query, passage) for passage in passages]

        try:
            # One request per passage - O(n) API calls, executed concurrently.
            pending = [
                self.client.aio.models.generate_content(
                    model=self.config.model or DEFAULT_MODEL,
                    contents=contents,  # type: ignore
                    config=types.GenerateContentConfig(
                        system_instruction='You are an expert at rating passage relevance. Respond with only a number from 0-100.',
                        temperature=0.0,
                        max_output_tokens=3,
                    ),
                )
                for contents in scoring_prompts
            ]
            responses = await semaphore_gather(*pending)

            scored = [
                (passage, self._extract_score(response))
                for passage, response in zip(passages, responses, strict=True)
            ]
            # Highest relevance first; ties keep their input order.
            return sorted(scored, key=lambda pair: pair[1], reverse=True)

        except Exception as e:
            # Rate limiting is recognised from the error text.
            raw_message = str(e)
            lowered = raw_message.lower()
            rate_limit_markers = ('rate limit', 'quota', 'resource_exhausted')
            if any(marker in lowered for marker in rate_limit_markers) or '429' in raw_message:
                raise RateLimitError from e

            logger.error(f'Error in generating LLM response: {e}')
            raise
@@ -14,4 +14,7 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
- __all__ = ['GraphDriver', 'Neo4jDriver', 'FalkorDriver']
17
+ from falkordb import FalkorDB
18
+ from neo4j import Neo4jDriver
19
+
20
+ __all__ = ['Neo4jDriver', 'FalkorDB']
@@ -15,7 +15,6 @@ limitations under the License.
15
15
  """
16
16
 
17
17
  import logging
18
- from collections.abc import Coroutine
19
18
  from datetime import datetime
20
19
  from typing import Any
21
20
 
@@ -52,11 +51,11 @@ class FalkorDriverSession(GraphDriverSession):
52
51
  if isinstance(query, list):
53
52
  for cypher, params in query:
54
53
  params = convert_datetimes_to_strings(params)
55
- await self.graph.query(str(cypher), params)
54
+ await self.graph.query(str(cypher), params) # type: ignore[reportUnknownArgumentType]
56
55
  else:
57
56
  params = dict(kwargs)
58
57
  params = convert_datetimes_to_strings(params)
59
- await self.graph.query(str(query), params)
58
+ await self.graph.query(str(query), params) # type: ignore[reportUnknownArgumentType]
60
59
  # Assuming `graph.query` is async (ideal); otherwise, wrap in executor
61
60
  return None
62
61
 
@@ -66,22 +65,30 @@ class FalkorDriver(GraphDriver):
66
65
 
67
66
  def __init__(
68
67
  self,
69
- uri: str,
70
- user: str,
71
- password: str,
68
+ host: str = 'localhost',
69
+ port: int = 6379,
70
+ username: str | None = None,
71
+ password: str | None = None,
72
+ falkor_db: FalkorDB | None = None,
72
73
  ):
73
- super().__init__()
74
- uri_parts = uri.split('://', 1)
75
- uri = f'{uri_parts[0]}://{user}:{password}@{uri_parts[1]}'
74
+ """
75
+ Initialize the FalkorDB driver.
76
76
 
77
- self.client = FalkorDB(
78
- host='your-db.falkor.cloud', port=6380, password='your_password', ssl=True
79
- )
77
+ FalkorDB is a multi-tenant graph database.
78
+ To connect, provide the host and port.
79
+ The default parameters assume a local (on-premises) FalkorDB instance.
80
+ """
81
+ super().__init__()
82
+ if falkor_db is not None:
83
+ # If a FalkorDB instance is provided, use it directly
84
+ self.client = falkor_db
85
+ else:
86
+ self.client = FalkorDB(host=host, port=port, username=username, password=password)
80
87
 
81
88
  def _get_graph(self, graph_name: str | None) -> FalkorGraph:
82
- # FalkorDB requires a non-None database name for multi-tenant graphs; the default is "DEFAULT_DATABASE"
89
+ # FalkorDB requires a non-None database name for multi-tenant graphs; the default is DEFAULT_DATABASE
83
90
  if graph_name is None:
84
- graph_name = 'DEFAULT_DATABASE'
91
+ graph_name = DEFAULT_DATABASE
85
92
  return self.client.select_graph(graph_name)
86
93
 
87
94
  async def execute_query(self, cypher_query_, **kwargs: Any):
@@ -92,7 +99,7 @@ class FalkorDriver(GraphDriver):
92
99
  params = convert_datetimes_to_strings(dict(kwargs))
93
100
 
94
101
  try:
95
- result = await graph.query(cypher_query_, params)
102
+ result = await graph.query(cypher_query_, params) # type: ignore[reportUnknownArgumentType]
96
103
  except Exception as e:
97
104
  if 'already indexed' in str(e):
98
105
  # check if index already exists
@@ -102,17 +109,36 @@ class FalkorDriver(GraphDriver):
102
109
  raise
103
110
 
104
111
  # Convert the result header to a list of strings
105
- header = [h[1].decode('utf-8') for h in result.header]
106
- return result.result_set, header, None
112
+ header = [h[1] for h in result.header]
113
+
114
+ # Convert FalkorDB's result format (list of lists) to the format expected by Graphiti (list of dicts)
115
+ records = []
116
+ for row in result.result_set:
117
+ record = {}
118
+ for i, field_name in enumerate(header):
119
+ if i < len(row):
120
+ record[field_name] = row[i]
121
+ else:
122
+ # If there are more fields in header than values in row, set to None
123
+ record[field_name] = None
124
+ records.append(record)
125
+
126
+ return records, header, None
107
127
 
108
128
  def session(self, database: str | None) -> GraphDriverSession:
109
129
  return FalkorDriverSession(self._get_graph(database))
110
130
 
111
131
  async def close(self) -> None:
112
- await self.client.connection.close()
113
-
114
- async def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
115
- return self.execute_query(
132
+ """Close the driver connection."""
133
+ if hasattr(self.client, 'aclose'):
134
+ await self.client.aclose() # type: ignore[reportUnknownMemberType]
135
+ elif hasattr(self.client.connection, 'aclose'):
136
+ await self.client.connection.aclose()
137
+ elif hasattr(self.client.connection, 'close'):
138
+ await self.client.connection.close()
139
+
140
+ async def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> None:
141
+ await self.execute_query(
116
142
  'CALL db.indexes() YIELD name DROP INDEX name',
117
143
  database_=database_,
118
144
  )
@@ -18,7 +18,7 @@ import logging
18
18
  from collections.abc import Coroutine
19
19
  from typing import Any
20
20
 
21
- from neo4j import AsyncGraphDatabase
21
+ from neo4j import AsyncGraphDatabase, EagerResult
22
22
  from typing_extensions import LiteralString
23
23
 
24
24
  from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
@@ -42,7 +42,7 @@ class Neo4jDriver(GraphDriver):
42
42
  auth=(user or '', password or ''),
43
43
  )
44
44
 
45
- async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> Coroutine:
45
+ async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
46
46
  params = kwargs.pop('params', None)
47
47
  result = await self.client.execute_query(cypher_query_, parameters_=params, **kwargs)
48
48
 
@@ -54,7 +54,9 @@ class Neo4jDriver(GraphDriver):
54
54
  async def close(self) -> None:
55
55
  return await self.client.close()
56
56
 
57
- def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
57
+ def delete_all_indexes(
58
+ self, database_: str = DEFAULT_DATABASE
59
+ ) -> Coroutine[Any, Any, EagerResult]:
58
60
  return self.client.execute_query(
59
61
  'CALL db.indexes() YIELD name DROP INDEX name',
60
62
  database_=database_,
@@ -38,7 +38,7 @@ class VoyageAIEmbedder(EmbedderClient):
38
38
  if config is None:
39
39
  config = VoyageAIEmbedderConfig()
40
40
  self.config = config
41
- self.client = voyageai.AsyncClient(api_key=config.api_key)
41
+ self.client = voyageai.AsyncClient(api_key=config.api_key) # type: ignore[reportUnknownMemberType]
42
42
 
43
43
  async def create(
44
44
  self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
graphiti_core/graphiti.py CHANGED
@@ -29,7 +29,12 @@ from graphiti_core.driver.neo4j_driver import Neo4jDriver
29
29
  from graphiti_core.edges import EntityEdge, EpisodicEdge
30
30
  from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
31
31
  from graphiti_core.graphiti_types import GraphitiClients
32
- from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather, validate_group_id
32
+ from graphiti_core.helpers import (
33
+ DEFAULT_DATABASE,
34
+ semaphore_gather,
35
+ validate_excluded_entity_types,
36
+ validate_group_id,
37
+ )
33
38
  from graphiti_core.llm_client import LLMClient, OpenAIClient
34
39
  from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
35
40
  from graphiti_core.search.search import SearchConfig, search
@@ -46,6 +51,7 @@ from graphiti_core.search.search_utils import (
46
51
  get_mentioned_nodes,
47
52
  get_relevant_edges,
48
53
  )
54
+ from graphiti_core.telemetry import capture_event
49
55
  from graphiti_core.utils.bulk_utils import (
50
56
  RawEpisode,
51
57
  add_nodes_and_edges_bulk,
@@ -95,7 +101,7 @@ class AddEpisodeResults(BaseModel):
95
101
  class Graphiti:
96
102
  def __init__(
97
103
  self,
98
- uri: str,
104
+ uri: str | None = None,
99
105
  user: str | None = None,
100
106
  password: str | None = None,
101
107
  llm_client: LLMClient | None = None,
@@ -156,7 +162,12 @@ class Graphiti:
156
162
  Graphiti if you're using the default OpenAIClient.
157
163
  """
158
164
 
159
- self.driver = graph_driver if graph_driver else Neo4jDriver(uri, user, password)
165
+ if graph_driver:
166
+ self.driver = graph_driver
167
+ else:
168
+ if uri is None:
169
+ raise ValueError("uri must be provided when graph_driver is None")
170
+ self.driver = Neo4jDriver(uri, user, password)
160
171
 
161
172
  self.database = DEFAULT_DATABASE
162
173
  self.store_raw_episode_content = store_raw_episode_content
@@ -181,6 +192,61 @@ class Graphiti:
181
192
  cross_encoder=self.cross_encoder,
182
193
  )
183
194
 
195
+ # Capture telemetry event
196
+ self._capture_initialization_telemetry()
197
+
198
def _capture_initialization_telemetry(self):
    """
    Send a 'graphiti_initialized' telemetry event describing the configured providers.

    Provider names are derived from the class names of the LLM client, embedder,
    cross-encoder and graph driver. This is best effort: any exception raised while
    collecting or sending the event is swallowed so it never reaches the caller.
    """
    try:
        # (event property, attribute holding the client to classify), in reporting order
        sources = (
            ('llm_provider', 'llm_client'),
            ('embedder_provider', 'embedder'),
            ('reranker_provider', 'cross_encoder'),
            ('database_provider', 'driver'),
        )
        properties = {}
        for property_name, attribute in sources:
            properties[property_name] = self._get_provider_type(getattr(self, attribute))

        capture_event('graphiti_initialized', properties)
    except Exception:
        # Telemetry is optional; failures are deliberately ignored.
        pass
218
+
219
def _get_provider_type(self, client) -> str:
    """
    Classify a client object by looking for a known provider name in its class name.

    Args:
        client: Any client/driver instance, or None.

    Returns:
        'none' when client is None, the matching provider label when the lowercased
        class name contains a known marker, otherwise 'unknown'.
    """
    if client is None:
        return 'none'

    class_name = client.__class__.__name__.lower()

    # Ordered (substring, label) pairs; the first match wins.
    # 'azure' must be tested before 'openai': a class name containing both
    # (e.g. "AzureOpenAI...") would otherwise always be reported as 'openai',
    # leaving the 'azure' label unreachable for such classes.
    markers = (
        # LLM providers
        ('azure', 'azure'),
        ('openai', 'openai'),
        ('anthropic', 'anthropic'),
        ('crossencoder', 'crossencoder'),
        ('gemini', 'gemini'),
        ('groq', 'groq'),
        # Database providers
        ('neo4j', 'neo4j'),
        ('falkor', 'falkordb'),
        # Embedder providers
        ('voyage', 'voyage'),
    )
    return next((label for marker, label in markers if marker in class_name), 'unknown')
249
+
184
250
  async def close(self):
185
251
  """
186
252
  Close the connection to the Neo4j database.
@@ -293,6 +359,7 @@ class Graphiti:
293
359
  uuid: str | None = None,
294
360
  update_communities: bool = False,
295
361
  entity_types: dict[str, BaseModel] | None = None,
362
+ excluded_entity_types: list[str] | None = None,
296
363
  previous_episode_uuids: list[str] | None = None,
297
364
  edge_types: dict[str, BaseModel] | None = None,
298
365
  edge_type_map: dict[tuple[str, str], list[str]] | None = None,
@@ -321,6 +388,12 @@ class Graphiti:
321
388
  Optional uuid of the episode.
322
389
  update_communities : bool
323
390
  Optional. Whether to update communities with new node information
391
+ entity_types : dict[str, BaseModel] | None
392
+ Optional. Dictionary mapping entity type names to their Pydantic model definitions.
393
+ excluded_entity_types : list[str] | None
394
+ Optional. List of entity type names to exclude from the graph. Entities classified
395
+ into these types will not be added to the graph. Can include 'Entity' to exclude
396
+ the default entity type.
324
397
  previous_episode_uuids : list[str] | None
325
398
  Optional. list of episode uuids to use as the previous episodes. If this is not provided,
326
399
  the most recent episodes by created_at date will be used.
@@ -351,6 +424,7 @@ class Graphiti:
351
424
  now = utc_now()
352
425
 
353
426
  validate_entity_types(entity_types)
427
+ validate_excluded_entity_types(excluded_entity_types, entity_types)
354
428
  validate_group_id(group_id)
355
429
 
356
430
  previous_episodes = (
@@ -389,7 +463,7 @@ class Graphiti:
389
463
  # Extract entities as nodes
390
464
 
391
465
  extracted_nodes = await extract_nodes(
392
- self.clients, episode, previous_episodes, entity_types
466
+ self.clients, episode, previous_episodes, entity_types, excluded_entity_types
393
467
  )
394
468
 
395
469
  # Extract edges and resolve nodes
@@ -534,7 +608,7 @@ class Graphiti:
534
608
  extracted_nodes,
535
609
  extracted_edges,
536
610
  episodic_edges,
537
- ) = await extract_nodes_and_edges_bulk(self.clients, episode_pairs)
611
+ ) = await extract_nodes_and_edges_bulk(self.clients, episode_pairs, None, None)
538
612
 
539
613
  # Generate embeddings
540
614
  await semaphore_gather(
graphiti_core/helpers.py CHANGED
@@ -19,18 +19,20 @@ import os
19
19
  import re
20
20
  from collections.abc import Coroutine
21
21
  from datetime import datetime
22
+ from typing import Any
22
23
 
23
24
  import numpy as np
24
25
  from dotenv import load_dotenv
25
26
  from neo4j import time as neo4j_time
26
27
  from numpy._typing import NDArray
28
+ from pydantic import BaseModel
27
29
  from typing_extensions import LiteralString
28
30
 
29
31
  from graphiti_core.errors import GroupIdValidationError
30
32
 
31
33
  load_dotenv()
32
34
 
33
- DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', 'neo4j')
35
+ DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', 'default_db')
34
36
  USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
35
37
  SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
36
38
  MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0))
@@ -98,7 +100,7 @@ def normalize_l2(embedding: list[float]) -> NDArray:
98
100
  async def semaphore_gather(
99
101
  *coroutines: Coroutine,
100
102
  max_coroutines: int | None = None,
101
- ):
103
+ ) -> list[Any]:
102
104
  semaphore = asyncio.Semaphore(max_coroutines or SEMAPHORE_LIMIT)
103
105
 
104
106
  async def _wrap_coroutine(coroutine):
@@ -132,3 +134,37 @@ def validate_group_id(group_id: str) -> bool:
132
134
  raise GroupIdValidationError(group_id)
133
135
 
134
136
  return True
137
+
138
+
139
def validate_excluded_entity_types(
    excluded_entity_types: list[str] | None, entity_types: dict[str, BaseModel] | None = None
) -> bool:
    """
    Check that every excluded entity type refers to a known type name.

    The known names are the built-in 'Entity' type plus the keys of `entity_types`
    (when given). An empty or missing exclusion list is always valid.

    Args:
        excluded_entity_types: Names of entity types to exclude.
        entity_types: Custom entity types, keyed by type name.

    Returns:
        True when all excluded names are known.

    Raises:
        ValueError: If one or more excluded names are not known; the message lists
            the offending names and the available ones, both sorted.
    """
    if not excluded_entity_types:
        return True

    # 'Entity' is the default type and can always be excluded.
    known_names = {'Entity'}
    if entity_types:
        known_names |= set(entity_types.keys())

    unknown_names = sorted(set(excluded_entity_types).difference(known_names))
    if not unknown_names:
        return True

    raise ValueError(
        f'Invalid excluded entity types: {unknown_names}. Available types: {sorted(known_names)}'
    )