graphiti-core 0.13.2__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

@@ -17,19 +17,21 @@ limitations under the License.
  import json
  import logging
  import typing
+ from typing import ClassVar

  from google import genai # type: ignore
  from google.genai import types # type: ignore
  from pydantic import BaseModel

  from ..prompts.models import Message
- from .client import LLMClient
+ from .client import MULTILINGUAL_EXTRACTION_RESPONSES, LLMClient
  from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
  from .errors import RateLimitError

  logger = logging.getLogger(__name__)

- DEFAULT_MODEL = 'gemini-2.0-flash'
+ DEFAULT_MODEL = 'gemini-2.5-flash'
+ DEFAULT_SMALL_MODEL = 'models/gemini-2.5-flash-lite-preview-06-17'


  class GeminiClient(LLMClient):
@@ -43,27 +45,34 @@ class GeminiClient(LLMClient):
          model (str): The model name to use for generating responses.
          temperature (float): The temperature to use for generating responses.
          max_tokens (int): The maximum number of tokens to generate in a response.
-
+         thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
      Methods:
-         __init__(config: LLMConfig | None = None, cache: bool = False):
-             Initializes the GeminiClient with the provided configuration and cache setting.
+         __init__(config: LLMConfig | None = None, cache: bool = False, thinking_config: types.ThinkingConfig | None = None):
+             Initializes the GeminiClient with the provided configuration, cache setting, and optional thinking config.

          _generate_response(messages: list[Message]) -> dict[str, typing.Any]:
              Generates a response from the language model based on the provided messages.
      """

+     # Class-level constants
+     MAX_RETRIES: ClassVar[int] = 2
+
      def __init__(
          self,
          config: LLMConfig | None = None,
          cache: bool = False,
          max_tokens: int = DEFAULT_MAX_TOKENS,
+         thinking_config: types.ThinkingConfig | None = None,
      ):
          """
-         Initialize the GeminiClient with the provided configuration and cache setting.
+         Initialize the GeminiClient with the provided configuration, cache setting, and optional thinking config.

          Args:
              config (LLMConfig | None): The configuration for the LLM client, including API key, model, temperature, and max tokens.
              cache (bool): Whether to use caching for responses. Defaults to False.
+             thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
+                 Only use with models that support thinking (gemini-2.5+). Defaults to None.
+
          """
          if config is None:
              config = LLMConfig()
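
For context on how the new constructor parameter is meant to be used, here is a minimal construction sketch (not part of this diff). It assumes the google-genai SDK exposes types.ThinkingConfig with a thinking_budget field, and that the import paths graphiti_core.llm_client.config and graphiti_core.llm_client.gemini_client match this package's layout.

    # Hypothetical usage sketch; per the docstring above, thinking_config should
    # only be passed for models that support thinking (gemini-2.5+).
    from google.genai import types

    from graphiti_core.llm_client.config import LLMConfig
    from graphiti_core.llm_client.gemini_client import GeminiClient

    config = LLMConfig(api_key='your-google-api-key', model='gemini-2.5-flash')
    client = GeminiClient(
        config=config,
        thinking_config=types.ThinkingConfig(thinking_budget=1024),
    )
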
@@ -76,6 +85,50 @@ class GeminiClient(LLMClient):
              api_key=config.api_key,
          )
          self.max_tokens = max_tokens
+         self.thinking_config = thinking_config
+
+     def _check_safety_blocks(self, response) -> None:
+         """Check if response was blocked for safety reasons and raise appropriate exceptions."""
+         # Check if the response was blocked for safety reasons
+         if not (hasattr(response, 'candidates') and response.candidates):
+             return
+
+         candidate = response.candidates[0]
+         if not (hasattr(candidate, 'finish_reason') and candidate.finish_reason == 'SAFETY'):
+             return
+
+         # Content was blocked for safety reasons - collect safety details
+         safety_info = []
+         safety_ratings = getattr(candidate, 'safety_ratings', None)
+
+         if safety_ratings:
+             for rating in safety_ratings:
+                 if getattr(rating, 'blocked', False):
+                     category = getattr(rating, 'category', 'Unknown')
+                     probability = getattr(rating, 'probability', 'Unknown')
+                     safety_info.append(f'{category}: {probability}')
+
+         safety_details = (
+             ', '.join(safety_info) if safety_info else 'Content blocked for safety reasons'
+         )
+         raise Exception(f'Response blocked by Gemini safety filters: {safety_details}')
+
+     def _check_prompt_blocks(self, response) -> None:
+         """Check if prompt was blocked and raise appropriate exceptions."""
+         prompt_feedback = getattr(response, 'prompt_feedback', None)
+         if not prompt_feedback:
+             return
+
+         block_reason = getattr(prompt_feedback, 'block_reason', None)
+         if block_reason:
+             raise Exception(f'Prompt blocked by Gemini: {block_reason}')
+
+     def _get_model_for_size(self, model_size: ModelSize) -> str:
+         """Get the appropriate model name based on the requested size."""
+         if model_size == ModelSize.small:
+             return self.small_model or DEFAULT_SMALL_MODEL
+         else:
+             return self.model or DEFAULT_MODEL

      async def _generate_response(
          self,
@@ -91,17 +144,17 @@ class GeminiClient(LLMClient):
              messages (list[Message]): A list of messages to send to the language model.
              response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
              max_tokens (int): The maximum number of tokens to generate in the response.
+             model_size (ModelSize): The size of the model to use (small or medium).

          Returns:
              dict[str, typing.Any]: The response from the language model.

          Raises:
              RateLimitError: If the API rate limit is exceeded.
-             RefusalError: If the content is blocked by the model.
-             Exception: If there is an error generating the response.
+             Exception: If there is an error generating the response or content is blocked.
          """
          try:
-             gemini_messages: list[types.Content] = []
+             gemini_messages: typing.Any = []
              # If a response model is provided, add schema for structured output
              system_prompt = ''
              if response_model is not None:
@@ -127,6 +180,9 @@ class GeminiClient(LLMClient):
                      types.Content(role=m.role, parts=[types.Part.from_text(text=m.content)])
                  )

+             # Get the appropriate model for the requested size
+             model = self._get_model_for_size(model_size)
+
              # Create generation config
              generation_config = types.GenerateContentConfig(
                  temperature=self.temperature,
@@ -134,15 +190,20 @@ class GeminiClient(LLMClient):
                  response_mime_type='application/json' if response_model else None,
                  response_schema=response_model if response_model else None,
                  system_instruction=system_prompt,
+                 thinking_config=self.thinking_config,
              )

              # Generate content using the simple string approach
              response = await self.client.aio.models.generate_content(
-                 model=self.model or DEFAULT_MODEL,
-                 contents=gemini_messages, # type: ignore[arg-type] # mypy fails on broad union type
+                 model=model,
+                 contents=gemini_messages,
                  config=generation_config,
              )

+             # Check for safety and prompt blocks
+             self._check_safety_blocks(response)
+             self._check_prompt_blocks(response)
+
              # If this was a structured output request, parse the response into the Pydantic model
              if response_model is not None:
                  try:
@@ -160,9 +221,16 @@ class GeminiClient(LLMClient):
              return {'content': response.text}

          except Exception as e:
-             # Check if it's a rate limit error
-             if 'rate limit' in str(e).lower() or 'quota' in str(e).lower():
+             # Check if it's a rate limit error based on Gemini API error codes
+             error_message = str(e).lower()
+             if (
+                 'rate limit' in error_message
+                 or 'quota' in error_message
+                 or 'resource_exhausted' in error_message
+                 or '429' in str(e)
+             ):
                  raise RateLimitError from e
+
              logger.error(f'Error in generating LLM response: {e}')
              raise

@@ -174,13 +242,14 @@ class GeminiClient(LLMClient):
          model_size: ModelSize = ModelSize.medium,
      ) -> dict[str, typing.Any]:
          """
-         Generate a response from the Gemini language model.
-         This method overrides the parent class method to provide a direct implementation.
+         Generate a response from the Gemini language model with retry logic and error handling.
+         This method overrides the parent class method to provide a direct implementation with advanced retry logic.

          Args:
              messages (list[Message]): A list of messages to send to the language model.
              response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
-             max_tokens (int): The maximum number of tokens to generate in the response.
+             max_tokens (int | None): The maximum number of tokens to generate in the response.
+             model_size (ModelSize): The size of the model to use (small or medium).

          Returns:
              dict[str, typing.Any]: The response from the language model.
@@ -188,10 +257,53 @@ class GeminiClient(LLMClient):
          if max_tokens is None:
              max_tokens = self.max_tokens

-         # Call the internal _generate_response method
-         return await self._generate_response(
-             messages=messages,
-             response_model=response_model,
-             max_tokens=max_tokens,
-             model_size=model_size,
-         )
+         retry_count = 0
+         last_error = None
+
+         # Add multilingual extraction instructions
+         messages[0].content += MULTILINGUAL_EXTRACTION_RESPONSES
+
+         while retry_count <= self.MAX_RETRIES:
+             try:
+                 response = await self._generate_response(
+                     messages=messages,
+                     response_model=response_model,
+                     max_tokens=max_tokens,
+                     model_size=model_size,
+                 )
+                 return response
+             except RateLimitError:
+                 # Rate limit errors should not trigger retries (fail fast)
+                 raise
+             except Exception as e:
+                 last_error = e
+
+                 # Check if this is a safety block - these typically shouldn't be retried
+                 if 'safety' in str(e).lower() or 'blocked' in str(e).lower():
+                     logger.warning(f'Content blocked by safety filters: {e}')
+                     raise
+
+                 # Don't retry if we've hit the max retries
+                 if retry_count >= self.MAX_RETRIES:
+                     logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
+                     raise
+
+                 retry_count += 1
+
+                 # Construct a detailed error message for the LLM
+                 error_context = (
+                     f'The previous response attempt was invalid. '
+                     f'Error type: {e.__class__.__name__}. '
+                     f'Error details: {str(e)}. '
+                     f'Please try again with a valid response, ensuring the output matches '
+                     f'the expected format and constraints.'
+                 )
+
+                 error_message = Message(role='user', content=error_context)
+                 messages.append(error_message)
+                 logger.warning(
+                     f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
+                 )
+
+         # If we somehow get here, raise the last error
+         raise last_error or Exception('Max retries exceeded with no specific error')
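
Taken together, the override now fails fast on RateLimitError and retries other failures up to MAX_RETRIES, appending a corrective user message before each retry. A hedged call sketch follows, assuming the client built in the earlier example and that this override is the public generate_response method inherited from LLMClient; the Summary schema and message text are made-up examples.

    # Illustrative only; not taken from this package's own call sites.
    import asyncio

    from pydantic import BaseModel

    from graphiti_core.prompts.models import Message


    class Summary(BaseModel):
        text: str


    async def main() -> None:
        messages = [
            Message(role='system', content='You summarize text.'),
            Message(role='user', content='Summarize: Graphiti builds temporal knowledge graphs.'),
        ]
        # RateLimitError propagates immediately; schema or validation failures are
        # retried with error feedback appended to the conversation.
        result = await client.generate_response(messages, response_model=Summary)
        print(result)


    asyncio.run(main())
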
graphiti_core/nodes.py CHANGED
@@ -540,10 +540,18 @@ class CommunityNode(Node):

  # Node helpers
  def get_episodic_node_from_record(record: Any) -> EpisodicNode:
+     created_at = parse_db_date(record['created_at'])
+     valid_at = parse_db_date(record['valid_at'])
+
+     if created_at is None:
+         raise ValueError(f"created_at cannot be None for episode {record.get('uuid', 'unknown')}")
+     if valid_at is None:
+         raise ValueError(f"valid_at cannot be None for episode {record.get('uuid', 'unknown')}")
+
      return EpisodicNode(
          content=record['content'],
-         created_at=parse_db_date(record['created_at']), # type: ignore
-         valid_at=parse_db_date(record['valid_at']), # type: ignore
+         created_at=created_at,
+         valid_at=valid_at,
          uuid=record['uuid'],
          group_id=record['group_id'],
          source=EpisodeType.from_str(record['source']),
@@ -586,6 +594,8 @@ def get_community_node_from_record(record: Any) -> CommunityNode:


  async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[EntityNode]):
+     if not nodes: # Handle empty list case
+         return
      name_embeddings = await embedder.create_batch([node.name for node in nodes])
      for node, name_embedding in zip(nodes, name_embeddings, strict=True):
          node.name_embedding = name_embedding
@@ -19,7 +19,6 @@ from enum import Enum
  from typing import Any

  from pydantic import BaseModel, Field
- from typing_extensions import LiteralString


  class ComparisonOperator(Enum):
@@ -53,8 +52,8 @@ class SearchFilters(BaseModel):

  def node_search_filter_query_constructor(
      filters: SearchFilters,
- ) -> tuple[LiteralString, dict[str, Any]]:
-     filter_query: LiteralString = ''
+ ) -> tuple[str, dict[str, Any]]:
+     filter_query: str = ''
      filter_params: dict[str, Any] = {}

      if filters.node_labels is not None:
@@ -67,8 +66,8 @@ def node_search_filter_query_constructor(

  def edge_search_filter_query_constructor(
      filters: SearchFilters,
- ) -> tuple[LiteralString, dict[str, Any]]:
-     filter_query: LiteralString = ''
+ ) -> tuple[str, dict[str, Any]]:
+     filter_query: str = ''
      filter_params: dict[str, Any] = {}

      if filters.edge_types is not None:
@@ -67,7 +67,7 @@ def fulltext_query(query: str, group_ids: list[str] | None = None):
      )
      group_ids_filter = ''
      for f in group_ids_filter_list:
-         group_ids_filter += f if not group_ids_filter else f'OR {f}'
+         group_ids_filter += f if not group_ids_filter else f' OR {f}'

      group_ids_filter += ' AND ' if group_ids_filter else ''

@@ -278,9 +278,6 @@ async def edge_similarity_search(
          routing_='r',
      )

-     if driver.provider == 'falkordb':
-         records = [dict(zip(header, row, strict=True)) for row in records]
-
      edges = [get_entity_edge_from_record(record) for record in records]

      return edges
@@ -377,8 +374,6 @@ async def node_fulltext_search(
          database_=DEFAULT_DATABASE,
          routing_='r',
      )
-     if driver.provider == 'falkordb':
-         records = [dict(zip(header, row, strict=True)) for row in records]

      nodes = [get_entity_node_from_record(record) for record in records]

@@ -433,8 +428,7 @@ async def node_similarity_search(
          database_=DEFAULT_DATABASE,
          routing_='r',
      )
-     if driver.provider == 'falkordb':
-         records = [dict(zip(header, row, strict=True)) for row in records]
+
      nodes = [get_entity_node_from_record(record) for record in records]

      return nodes
@@ -0,0 +1,9 @@
+ """
+ Telemetry module for Graphiti.
+
+ This module provides anonymous usage analytics to help improve Graphiti.
+ """
+
+ from .telemetry import capture_event, is_telemetry_enabled
+
+ __all__ = ['capture_event', 'is_telemetry_enabled']
@@ -0,0 +1,117 @@
+ """
+ Telemetry client for Graphiti.
+
+ Collects anonymous usage statistics to help improve the product.
+ """
+
+ import contextlib
+ import os
+ import platform
+ import sys
+ import uuid
+ from pathlib import Path
+ from typing import Any
+
+ # PostHog configuration
+ # Note: This is a public API key intended for client-side use and safe to commit
+ # PostHog public keys are designed to be exposed in client applications
+ POSTHOG_API_KEY = 'phc_UG6EcfDbuXz92neb3rMlQFDY0csxgMqRcIPWESqnSmo'
+ POSTHOG_HOST = 'https://us.i.posthog.com'
+
+ # Environment variable to control telemetry
+ TELEMETRY_ENV_VAR = 'GRAPHITI_TELEMETRY_ENABLED'
+
+ # Cache directory for anonymous ID
+ CACHE_DIR = Path.home() / '.cache' / 'graphiti'
+ ANON_ID_FILE = CACHE_DIR / 'telemetry_anon_id'
+
+
+ def is_telemetry_enabled() -> bool:
+     """Check if telemetry is enabled."""
+     # Disable during pytest runs
+     if 'pytest' in sys.modules:
+         return False
+
+     # Check environment variable (default: enabled)
+     env_value = os.environ.get(TELEMETRY_ENV_VAR, 'true').lower()
+     return env_value in ('true', '1', 'yes', 'on')
+
+
+ def get_anonymous_id() -> str:
+     """Get or create anonymous user ID."""
+     try:
+         # Create cache directory if it doesn't exist
+         CACHE_DIR.mkdir(parents=True, exist_ok=True)
+
+         # Try to read existing ID
+         if ANON_ID_FILE.exists():
+             try:
+                 return ANON_ID_FILE.read_text().strip()
+             except Exception:
+                 pass
+
+         # Generate new ID
+         anon_id = str(uuid.uuid4())
+
+         # Save to file
+         with contextlib.suppress(Exception):
+             ANON_ID_FILE.write_text(anon_id)
+
+         return anon_id
+     except Exception:
+         return 'UNKNOWN'
+
+
+ def get_graphiti_version() -> str:
+     """Get Graphiti version."""
+     try:
+         # Try to get version from package metadata
+         import importlib.metadata
+
+         return importlib.metadata.version('graphiti-core')
+     except Exception:
+         return 'unknown'
+
+
+ def initialize_posthog():
+     """Initialize PostHog client."""
+     try:
+         import posthog
+
+         posthog.api_key = POSTHOG_API_KEY
+         posthog.host = POSTHOG_HOST
+         return posthog
+     except ImportError:
+         # PostHog not installed, silently disable telemetry
+         return None
+     except Exception:
+         # Any other error, silently disable telemetry
+         return None
+
+
+ def capture_event(event_name: str, properties: dict[str, Any] | None = None) -> None:
+     """Capture a telemetry event."""
+     if not is_telemetry_enabled():
+         return
+
+     try:
+         posthog_client = initialize_posthog()
+         if posthog_client is None:
+             return
+
+         # Get anonymous ID
+         user_id = get_anonymous_id()
+
+         # Prepare event properties
+         event_properties = {
+             '$process_person_profile': False,
+             'graphiti_version': get_graphiti_version(),
+             'architecture': platform.machine(),
+             **(properties or {}),
+         }
+
+         # Capture the event
+         posthog_client.capture(distinct_id=user_id, event=event_name, properties=event_properties)
+     except Exception:
+         # Silently handle all telemetry errors to avoid disrupting the main application
+         pass
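
Since the new telemetry module is on by default outside pytest, the practical knob is the GRAPHITI_TELEMETRY_ENABLED environment variable read above. A small opt-out sketch, assuming the new module lands at graphiti_core.telemetry (the package path is not spelled out in this diff):

    # Any value outside ('true', '1', 'yes', 'on') disables telemetry; the same
    # effect is available from a shell via `export GRAPHITI_TELEMETRY_ENABLED=false`.
    import os

    os.environ['GRAPHITI_TELEMETRY_ENABLED'] = 'false'

    from graphiti_core.telemetry import capture_event, is_telemetry_enabled

    assert is_telemetry_enabled() is False
    capture_event('example_event', {'source': 'docs'})  # silently no-ops
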
@@ -177,11 +177,14 @@ async def add_nodes_and_edges_bulk_tx(


  async def extract_nodes_and_edges_bulk(
-     clients: GraphitiClients, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
+     clients: GraphitiClients,
+     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
+     entity_types: dict[str, BaseModel] | None = None,
+     excluded_entity_types: list[str] | None = None,
  ) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
      extracted_nodes_bulk = await semaphore_gather(
          *[
-             extract_nodes(clients, episode, previous_episodes)
+             extract_nodes(clients, episode, previous_episodes, entity_types, excluded_entity_types)
              for episode, previous_episodes in episode_tuples
          ]
      )
@@ -40,7 +40,7 @@ async def get_community_clusters(
          database_=DEFAULT_DATABASE,
      )

-     group_ids = group_id_values[0]['group_ids']
+     group_ids = group_id_values[0]['group_ids'] if group_id_values else []

      for group_id in group_ids:
          projection: dict[str, list[Neighbor]] = {}
@@ -297,7 +297,7 @@ async def resolve_extracted_edges(
      embedder = clients.embedder
      await create_entity_edge_embeddings(embedder, extracted_edges)

-     search_results: tuple[list[list[EntityEdge]], list[list[EntityEdge]]] = await semaphore_gather(
+     search_results = await semaphore_gather(
          get_relevant_edges(driver, extracted_edges, SearchFilters()),
          get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters(), 0.2),
      )
@@ -21,7 +21,7 @@ from typing_extensions import LiteralString

  from graphiti_core.driver.driver import GraphDriver
  from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
- from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
+ from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date, semaphore_gather
  from graphiti_core.nodes import EpisodeType, EpisodicNode

  EPISODE_WINDOW_LEN = 3
@@ -140,10 +140,8 @@ async def retrieve_episodes(
      episodes = [
          EpisodicNode(
              content=record['content'],
-             created_at=datetime.fromtimestamp(
-                 record['created_at'].to_native().timestamp(), timezone.utc
-             ),
-             valid_at=(record['valid_at'].to_native()),
+             created_at=parse_db_date(record['created_at']) or datetime.min.replace(tzinfo=timezone.utc),
+             valid_at=parse_db_date(record['valid_at']) or datetime.min.replace(tzinfo=timezone.utc),
              uuid=record['uuid'],
              group_id=record['group_id'],
              source=EpisodeType.from_str(record['source']),
@@ -71,6 +71,7 @@ async def extract_nodes(
      episode: EpisodicNode,
      previous_episodes: list[EpisodicNode],
      entity_types: dict[str, BaseModel] | None = None,
+     excluded_entity_types: list[str] | None = None,
  ) -> list[EntityNode]:
      start = time()
      llm_client = clients.llm_client
@@ -154,6 +155,11 @@ async def extract_nodes(
              'entity_type_name'
          )

+         # Check if this entity type should be excluded
+         if excluded_entity_types and entity_type_name in excluded_entity_types:
+             logger.debug(f'Excluding entity "{extracted_entity.name}" of type "{entity_type_name}"')
+             continue
+
          labels: list[str] = list({'Entity', str(entity_type_name)})

          new_node = EntityNode(
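
The exclusion rule itself is easy to reason about in isolation; here is a simplified stand-alone illustration of the same check, using hypothetical entity dicts rather than the library's own extraction loop:

    # Entities whose type appears in excluded_entity_types are skipped.
    extracted = [
        {'name': 'Alice', 'entity_type_name': 'Person'},
        {'name': 'Acme Corp', 'entity_type_name': 'Organization'},
    ]
    excluded_entity_types = ['Organization']

    kept = [
        e
        for e in extracted
        if not (excluded_entity_types and e['entity_type_name'] in excluded_entity_types)
    ]
    print([e['name'] for e in kept])  # ['Alice']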