graphiti-core 0.12.0rc1__py3-none-any.whl → 0.24.3__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as published to their public registry. It is provided for informational purposes only.
Files changed (68)
  1. graphiti_core/cross_encoder/bge_reranker_client.py +12 -2
  2. graphiti_core/cross_encoder/gemini_reranker_client.py +161 -0
  3. graphiti_core/cross_encoder/openai_reranker_client.py +7 -5
  4. graphiti_core/decorators.py +110 -0
  5. graphiti_core/driver/__init__.py +19 -0
  6. graphiti_core/driver/driver.py +124 -0
  7. graphiti_core/driver/falkordb_driver.py +362 -0
  8. graphiti_core/driver/graph_operations/graph_operations.py +191 -0
  9. graphiti_core/driver/kuzu_driver.py +182 -0
  10. graphiti_core/driver/neo4j_driver.py +117 -0
  11. graphiti_core/driver/neptune_driver.py +305 -0
  12. graphiti_core/driver/search_interface/search_interface.py +89 -0
  13. graphiti_core/edges.py +287 -172
  14. graphiti_core/embedder/azure_openai.py +71 -0
  15. graphiti_core/embedder/client.py +2 -1
  16. graphiti_core/embedder/gemini.py +116 -22
  17. graphiti_core/embedder/voyage.py +13 -2
  18. graphiti_core/errors.py +8 -0
  19. graphiti_core/graph_queries.py +162 -0
  20. graphiti_core/graphiti.py +705 -193
  21. graphiti_core/graphiti_types.py +4 -2
  22. graphiti_core/helpers.py +87 -10
  23. graphiti_core/llm_client/__init__.py +16 -0
  24. graphiti_core/llm_client/anthropic_client.py +159 -56
  25. graphiti_core/llm_client/azure_openai_client.py +115 -0
  26. graphiti_core/llm_client/client.py +98 -21
  27. graphiti_core/llm_client/config.py +1 -1
  28. graphiti_core/llm_client/gemini_client.py +290 -41
  29. graphiti_core/llm_client/groq_client.py +14 -3
  30. graphiti_core/llm_client/openai_base_client.py +261 -0
  31. graphiti_core/llm_client/openai_client.py +56 -132
  32. graphiti_core/llm_client/openai_generic_client.py +91 -56
  33. graphiti_core/models/edges/edge_db_queries.py +259 -35
  34. graphiti_core/models/nodes/node_db_queries.py +311 -32
  35. graphiti_core/nodes.py +420 -205
  36. graphiti_core/prompts/dedupe_edges.py +46 -32
  37. graphiti_core/prompts/dedupe_nodes.py +67 -42
  38. graphiti_core/prompts/eval.py +4 -4
  39. graphiti_core/prompts/extract_edges.py +27 -16
  40. graphiti_core/prompts/extract_nodes.py +74 -31
  41. graphiti_core/prompts/prompt_helpers.py +39 -0
  42. graphiti_core/prompts/snippets.py +29 -0
  43. graphiti_core/prompts/summarize_nodes.py +23 -25
  44. graphiti_core/search/search.py +158 -82
  45. graphiti_core/search/search_config.py +39 -4
  46. graphiti_core/search/search_filters.py +126 -35
  47. graphiti_core/search/search_helpers.py +5 -6
  48. graphiti_core/search/search_utils.py +1405 -485
  49. graphiti_core/telemetry/__init__.py +9 -0
  50. graphiti_core/telemetry/telemetry.py +117 -0
  51. graphiti_core/tracer.py +193 -0
  52. graphiti_core/utils/bulk_utils.py +364 -285
  53. graphiti_core/utils/datetime_utils.py +13 -0
  54. graphiti_core/utils/maintenance/community_operations.py +67 -49
  55. graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
  56. graphiti_core/utils/maintenance/edge_operations.py +339 -197
  57. graphiti_core/utils/maintenance/graph_data_operations.py +50 -114
  58. graphiti_core/utils/maintenance/node_operations.py +319 -238
  59. graphiti_core/utils/maintenance/temporal_operations.py +11 -3
  60. graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
  61. graphiti_core/utils/text_utils.py +53 -0
  62. graphiti_core-0.24.3.dist-info/METADATA +726 -0
  63. graphiti_core-0.24.3.dist-info/RECORD +86 -0
  64. {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
  65. graphiti_core-0.12.0rc1.dist-info/METADATA +0 -350
  66. graphiti_core-0.12.0rc1.dist-info/RECORD +0 -66
  67. /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
  68. {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info/licenses}/LICENSE +0 -0
graphiti_core/llm_client/client.py
@@ -26,15 +26,34 @@ from pydantic import BaseModel
 from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential
 
 from ..prompts.models import Message
+from ..tracer import NoOpTracer, Tracer
 from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
 from .errors import RateLimitError
 
 DEFAULT_TEMPERATURE = 0
 DEFAULT_CACHE_DIR = './llm_cache'
 
-MULTILINGUAL_EXTRACTION_RESPONSES = (
-    '\n\nAny extracted information should be returned in the same language as it was written in.'
-)
+
+def get_extraction_language_instruction(group_id: str | None = None) -> str:
+    """Returns instruction for language extraction behavior.
+
+    Override this function to customize language extraction:
+    - Return empty string to disable multilingual instructions
+    - Return custom instructions for specific language requirements
+    - Use group_id to provide different instructions per group/partition
+
+    Args:
+        group_id: Optional partition identifier for the graph
+
+    Returns:
+        str: Language instruction to append to system messages
+    """
+    return (
+        '\n\nAny extracted information should be returned in the same language as it was written in. '
+        'Only output non-English text when the user has written full sentences or phrases in that non-English language. '
+        'Otherwise, output English.'
+    )
+
 
 logger = logging.getLogger(__name__)
 
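Per the docstring above, get_extraction_language_instruction is meant to be overridden by the application. A minimal sketch of one way to do that by patching the module-level name (the Spanish-only behaviour is invented for illustration; modules that import the function directly, as gemini_client.py does later in this diff, keep their own reference and would need the same patch):

    # Sketch: swap the default language instruction for an invented Spanish-only variant.
    import graphiti_core.llm_client.client as llm_client

    def spanish_only_instruction(group_id: str | None = None) -> str:
        # Hypothetical override; ignores group_id and always requests Spanish output.
        return '\n\nReturn all extracted information in Spanish.'

    llm_client.get_extraction_language_instruction = spanish_only_instruction
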
@@ -60,11 +79,16 @@ class LLMClient(ABC):
         self.max_tokens = config.max_tokens
         self.cache_enabled = cache
         self.cache_dir = None
+        self.tracer: Tracer = NoOpTracer()
 
         # Only create the cache directory if caching is enabled
         if self.cache_enabled:
             self.cache_dir = Cache(DEFAULT_CACHE_DIR)
 
+    def set_tracer(self, tracer: Tracer) -> None:
+        """Set the tracer for this LLM client."""
+        self.tracer = tracer
+
     def _clean_input(self, input: str) -> str:
         """Clean input string of invalid unicode and control characters.
 
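The tracer defaults to a no-op and is injected via set_tracer. Judging only from the calls made in this diff (start_span used as a context manager, plus add_attributes, set_status, record_exception), a custom tracer can be roughly as small as the sketch below; the actual Tracer interface in graphiti_core/tracer.py may require more than this:

    # Minimal tracer sketch; the method names mirror only the calls visible in this diff.
    from contextlib import contextmanager
    from typing import Any, Iterator

    class PrintSpan:
        def add_attributes(self, attributes: dict[str, Any]) -> None:
            print(f'attrs: {attributes}')

        def set_status(self, status: str, message: str) -> None:
            print(f'status: {status} ({message})')

        def record_exception(self, exc: BaseException) -> None:
            print(f'exception: {exc!r}')

    class PrintTracer:
        @contextmanager
        def start_span(self, name: str) -> Iterator[PrintSpan]:
            print(f'span start: {name}')
            try:
                yield PrintSpan()
            finally:
                print(f'span end: {name}')

    # Usage sketch: llm_client.set_tracer(PrintTracer())
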
@@ -132,6 +156,8 @@ class LLMClient(ABC):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
+        prompt_name: str | None = None,
     ) -> dict[str, typing.Any]:
         if max_tokens is None:
             max_tokens = self.max_tokens
@@ -145,25 +171,76 @@ class LLMClient(ABC):
             )
 
         # Add multilingual extraction instructions
-        messages[0].content += MULTILINGUAL_EXTRACTION_RESPONSES
-
-        if self.cache_enabled and self.cache_dir is not None:
-            cache_key = self._get_cache_key(messages)
-
-            cached_response = self.cache_dir.get(cache_key)
-            if cached_response is not None:
-                logger.debug(f'Cache hit for {cache_key}')
-                return cached_response
+        messages[0].content += get_extraction_language_instruction(group_id)
 
         for message in messages:
             message.content = self._clean_input(message.content)
 
-        response = await self._generate_response_with_retry(
-            messages, response_model, max_tokens, model_size
-        )
-
-        if self.cache_enabled and self.cache_dir is not None:
-            cache_key = self._get_cache_key(messages)
-            self.cache_dir.set(cache_key, response)
-
-        return response
+        # Wrap entire operation in tracing span
+        with self.tracer.start_span('llm.generate') as span:
+            attributes = {
+                'llm.provider': self._get_provider_type(),
+                'model.size': model_size.value,
+                'max_tokens': max_tokens,
+                'cache.enabled': self.cache_enabled,
+            }
+            if prompt_name:
+                attributes['prompt.name'] = prompt_name
+            span.add_attributes(attributes)
+
+            # Check cache first
+            if self.cache_enabled and self.cache_dir is not None:
+                cache_key = self._get_cache_key(messages)
+                cached_response = self.cache_dir.get(cache_key)
+                if cached_response is not None:
+                    logger.debug(f'Cache hit for {cache_key}')
+                    span.add_attributes({'cache.hit': True})
+                    return cached_response
+
+            span.add_attributes({'cache.hit': False})
+
+            # Execute LLM call
+            try:
+                response = await self._generate_response_with_retry(
+                    messages, response_model, max_tokens, model_size
+                )
+            except Exception as e:
+                span.set_status('error', str(e))
+                span.record_exception(e)
+                raise
+
+            # Cache response if enabled
+            if self.cache_enabled and self.cache_dir is not None:
+                cache_key = self._get_cache_key(messages)
+                self.cache_dir.set(cache_key, response)
+
+            return response
+
+    def _get_provider_type(self) -> str:
+        """Get provider type from class name."""
+        class_name = self.__class__.__name__.lower()
+        if 'openai' in class_name:
+            return 'openai'
+        elif 'anthropic' in class_name:
+            return 'anthropic'
+        elif 'gemini' in class_name:
+            return 'gemini'
+        elif 'groq' in class_name:
+            return 'groq'
+        else:
+            return 'unknown'
+
+    def _get_failed_generation_log(self, messages: list[Message], output: str | None) -> str:
+        """
+        Log the full input messages, the raw output (if any), and the exception for debugging failed generations.
+        """
+        log = ''
+        log += f'Input messages: {json.dumps([m.model_dump() for m in messages], indent=2)}\n'
+        if output is not None:
+            if len(output) > 4000:
+                log += f'Raw output: {output[:2000]}... (truncated) ...{output[-2000:]}\n'
+            else:
+                log += f'Raw output: {output}\n'
+        else:
+            log += 'No raw output available'
+        return log
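For reference, the provider label recorded on the tracing span comes from a substring match on the client class name, falling back to 'unknown'. A standalone restatement of that matching, with illustrative class-name strings:

    # Restates the substring matching used by _get_provider_type above.
    def provider_for(class_name: str) -> str:
        name = class_name.lower()
        for needle in ('openai', 'anthropic', 'gemini', 'groq'):
            if needle in name:
                return needle
        return 'unknown'

    assert provider_for('OpenAIGenericClient') == 'openai'
    assert provider_for('GeminiClient') == 'gemini'
    assert provider_for('SomeLocalClient') == 'unknown'  # hypothetical subclass name
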
graphiti_core/llm_client/config.py
@@ -17,7 +17,7 @@ limitations under the License.
 from enum import Enum
 
 DEFAULT_MAX_TOKENS = 8192
-DEFAULT_TEMPERATURE = 0
+DEFAULT_TEMPERATURE = 1
 
 
 class ModelSize(Enum):
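DEFAULT_TEMPERATURE moves from 0 to 1, so sampling rather than near-deterministic output becomes the package default. A caller that depended on the old behaviour can pin it per client; a minimal sketch, assuming LLMConfig accepts api_key and temperature keyword arguments as the docstrings elsewhere in this diff suggest:

    from graphiti_core.llm_client.config import LLMConfig
    from graphiti_core.llm_client.gemini_client import GeminiClient

    # temperature=0 restores the previous default; the api_key value is a placeholder.
    client = GeminiClient(config=LLMConfig(api_key='...', temperature=0))
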
graphiti_core/llm_client/gemini_client.py
@@ -16,20 +16,54 @@ limitations under the License.
 
 import json
 import logging
+import re
 import typing
+from typing import TYPE_CHECKING, ClassVar
 
-from google import genai  # type: ignore
-from google.genai import types  # type: ignore
 from pydantic import BaseModel
 
 from ..prompts.models import Message
-from .client import LLMClient
-from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
+from .client import LLMClient, get_extraction_language_instruction
+from .config import LLMConfig, ModelSize
 from .errors import RateLimitError
 
+if TYPE_CHECKING:
+    from google import genai
+    from google.genai import types
+else:
+    try:
+        from google import genai
+        from google.genai import types
+    except ImportError:
+        # If gemini client is not installed, raise an ImportError
+        raise ImportError(
+            'google-genai is required for GeminiClient. '
+            'Install it with: pip install graphiti-core[google-genai]'
+        ) from None
+
+
 logger = logging.getLogger(__name__)
 
-DEFAULT_MODEL = 'gemini-2.0-flash'
+DEFAULT_MODEL = 'gemini-2.5-flash'
+DEFAULT_SMALL_MODEL = 'gemini-2.5-flash-lite'
+
+# Maximum output tokens for different Gemini models
+GEMINI_MODEL_MAX_TOKENS = {
+    # Gemini 2.5 models
+    'gemini-2.5-pro': 65536,
+    'gemini-2.5-flash': 65536,
+    'gemini-2.5-flash-lite': 64000,
+    # Gemini 2.0 models
+    'gemini-2.0-flash': 8192,
+    'gemini-2.0-flash-lite': 8192,
+    # Gemini 1.5 models
+    'gemini-1.5-pro': 8192,
+    'gemini-1.5-flash': 8192,
+    'gemini-1.5-flash-8b': 8192,
+}
+
+# Default max tokens for models not in the mapping
+DEFAULT_GEMINI_MAX_TOKENS = 8192
 
 
 class GeminiClient(LLMClient):
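The mapping above caps output tokens per model, with unlisted models falling back to DEFAULT_GEMINI_MAX_TOKENS. A quick check of that lookup (the model name in the last line is hypothetical, and importing the module assumes the google-genai extra is installed, since the guarded import above now requires it):

    from graphiti_core.llm_client.gemini_client import (
        DEFAULT_GEMINI_MAX_TOKENS,
        GEMINI_MODEL_MAX_TOKENS,
    )

    # Listed model: its specific cap applies.
    assert GEMINI_MODEL_MAX_TOKENS.get('gemini-2.5-pro', DEFAULT_GEMINI_MAX_TOKENS) == 65536
    # Unlisted, hypothetical model: the 8192 default applies.
    assert GEMINI_MODEL_MAX_TOKENS.get('gemini-3.0-hypothetical', DEFAULT_GEMINI_MAX_TOKENS) == 8192
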
@@ -43,27 +77,35 @@ class GeminiClient(LLMClient):
         model (str): The model name to use for generating responses.
         temperature (float): The temperature to use for generating responses.
         max_tokens (int): The maximum number of tokens to generate in a response.
-
+        thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
     Methods:
-        __init__(config: LLMConfig | None = None, cache: bool = False):
-            Initializes the GeminiClient with the provided configuration and cache setting.
+        __init__(config: LLMConfig | None = None, cache: bool = False, thinking_config: types.ThinkingConfig | None = None):
+            Initializes the GeminiClient with the provided configuration, cache setting, and optional thinking config.
 
         _generate_response(messages: list[Message]) -> dict[str, typing.Any]:
             Generates a response from the language model based on the provided messages.
     """
 
+    # Class-level constants
+    MAX_RETRIES: ClassVar[int] = 2
+
     def __init__(
         self,
         config: LLMConfig | None = None,
         cache: bool = False,
-        max_tokens: int = DEFAULT_MAX_TOKENS,
+        max_tokens: int | None = None,
+        thinking_config: types.ThinkingConfig | None = None,
+        client: 'genai.Client | None' = None,
     ):
         """
-        Initialize the GeminiClient with the provided configuration and cache setting.
+        Initialize the GeminiClient with the provided configuration, cache setting, and optional thinking config.
 
         Args:
             config (LLMConfig | None): The configuration for the LLM client, including API key, model, temperature, and max tokens.
             cache (bool): Whether to use caching for responses. Defaults to False.
+            thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
+                Only use with models that support thinking (gemini-2.5+). Defaults to None.
+            client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
         """
         if config is None:
             config = LLMConfig()
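Putting the new constructor arguments together; a minimal sketch, assuming types.ThinkingConfig exposes a thinking_budget field in google-genai (that field name is this note's assumption, not something stated in the diff) and using a placeholder API key:

    from google.genai import types

    from graphiti_core.llm_client.config import LLMConfig
    from graphiti_core.llm_client.gemini_client import GeminiClient

    client = GeminiClient(
        config=LLMConfig(api_key='...', model='gemini-2.5-flash'),
        # Only meaningful on models that support thinking (gemini-2.5+), per the docstring above.
        thinking_config=types.ThinkingConfig(thinking_budget=1024),
    )
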
@@ -71,17 +113,128 @@ class GeminiClient(LLMClient):
         super().__init__(config, cache)
 
         self.model = config.model
-        # Configure the Gemini API
-        self.client = genai.Client(
-            api_key=config.api_key,
-        )
+
+        if client is None:
+            self.client = genai.Client(api_key=config.api_key)
+        else:
+            self.client = client
+
         self.max_tokens = max_tokens
+        self.thinking_config = thinking_config
+
+    def _check_safety_blocks(self, response) -> None:
+        """Check if response was blocked for safety reasons and raise appropriate exceptions."""
+        # Check if the response was blocked for safety reasons
+        if not (hasattr(response, 'candidates') and response.candidates):
+            return
+
+        candidate = response.candidates[0]
+        if not (hasattr(candidate, 'finish_reason') and candidate.finish_reason == 'SAFETY'):
+            return
+
+        # Content was blocked for safety reasons - collect safety details
+        safety_info = []
+        safety_ratings = getattr(candidate, 'safety_ratings', None)
+
+        if safety_ratings:
+            for rating in safety_ratings:
+                if getattr(rating, 'blocked', False):
+                    category = getattr(rating, 'category', 'Unknown')
+                    probability = getattr(rating, 'probability', 'Unknown')
+                    safety_info.append(f'{category}: {probability}')
+
+        safety_details = (
+            ', '.join(safety_info) if safety_info else 'Content blocked for safety reasons'
+        )
+        raise Exception(f'Response blocked by Gemini safety filters: {safety_details}')
+
+    def _check_prompt_blocks(self, response) -> None:
+        """Check if prompt was blocked and raise appropriate exceptions."""
+        prompt_feedback = getattr(response, 'prompt_feedback', None)
+        if not prompt_feedback:
+            return
+
+        block_reason = getattr(prompt_feedback, 'block_reason', None)
+        if block_reason:
+            raise Exception(f'Prompt blocked by Gemini: {block_reason}')
+
+    def _get_model_for_size(self, model_size: ModelSize) -> str:
+        """Get the appropriate model name based on the requested size."""
+        if model_size == ModelSize.small:
+            return self.small_model or DEFAULT_SMALL_MODEL
+        else:
+            return self.model or DEFAULT_MODEL
+
+    def _get_max_tokens_for_model(self, model: str) -> int:
+        """Get the maximum output tokens for a specific Gemini model."""
+        return GEMINI_MODEL_MAX_TOKENS.get(model, DEFAULT_GEMINI_MAX_TOKENS)
+
+    def _resolve_max_tokens(self, requested_max_tokens: int | None, model: str) -> int:
+        """
+        Resolve the maximum output tokens to use based on precedence rules.
+
+        Precedence order (highest to lowest):
+        1. Explicit max_tokens parameter passed to generate_response()
+        2. Instance max_tokens set during client initialization
+        3. Model-specific maximum tokens from GEMINI_MODEL_MAX_TOKENS mapping
+        4. DEFAULT_MAX_TOKENS as final fallback
+
+        Args:
+            requested_max_tokens: The max_tokens parameter passed to generate_response()
+            model: The model name to look up model-specific limits
+
+        Returns:
+            int: The resolved maximum tokens to use
+        """
+        # 1. Use explicit parameter if provided
+        if requested_max_tokens is not None:
+            return requested_max_tokens
+
+        # 2. Use instance max_tokens if set during initialization
+        if self.max_tokens is not None:
+            return self.max_tokens
+
+        # 3. Use model-specific maximum or return DEFAULT_GEMINI_MAX_TOKENS
+        return self._get_max_tokens_for_model(model)
+
+    def salvage_json(self, raw_output: str) -> dict[str, typing.Any] | None:
+        """
+        Attempt to salvage a JSON object if the raw output is truncated.
+
+        This is accomplished by looking for the last closing bracket for an array or object.
+        If found, it will try to load the JSON object from the raw output.
+        If the JSON object is not valid, it will return None.
+
+        Args:
+            raw_output (str): The raw output from the LLM.
+
+        Returns:
+            dict[str, typing.Any]: The salvaged JSON object.
+            None: If no salvage is possible.
+        """
+        if not raw_output:
+            return None
+        # Try to salvage a JSON array
+        array_match = re.search(r'\]\s*$', raw_output)
+        if array_match:
+            try:
+                return json.loads(raw_output[: array_match.end()])
+            except Exception:
+                pass
+        # Try to salvage a JSON object
+        obj_match = re.search(r'\}\s*$', raw_output)
+        if obj_match:
+            try:
+                return json.loads(raw_output[: obj_match.end()])
+            except Exception:
+                pass
+        return None
 
     async def _generate_response(
         self,
         messages: list[Message],
         response_model: type[BaseModel] | None = None,
-        max_tokens: int = DEFAULT_MAX_TOKENS,
+        max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
     ) -> dict[str, typing.Any]:
         """
@@ -90,18 +243,18 @@ class GeminiClient(LLMClient):
         Args:
             messages (list[Message]): A list of messages to send to the language model.
             response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
-            max_tokens (int): The maximum number of tokens to generate in the response.
+            max_tokens (int | None): The maximum number of tokens to generate in the response. If None, uses precedence rules.
+            model_size (ModelSize): The size of the model to use (small or medium).
 
         Returns:
             dict[str, typing.Any]: The response from the language model.
 
         Raises:
             RateLimitError: If the API rate limit is exceeded.
-            RefusalError: If the content is blocked by the model.
-            Exception: If there is an error generating the response.
+            Exception: If there is an error generating the response or content is blocked.
         """
         try:
-            gemini_messages: list[types.Content] = []
+            gemini_messages: typing.Any = []
             # If a response model is provided, add schema for structured output
             system_prompt = ''
             if response_model is not None:
@@ -127,44 +280,75 @@ class GeminiClient(LLMClient):
                     types.Content(role=m.role, parts=[types.Part.from_text(text=m.content)])
                 )
 
+            # Get the appropriate model for the requested size
+            model = self._get_model_for_size(model_size)
+
+            # Resolve max_tokens using precedence rules (see _resolve_max_tokens for details)
+            resolved_max_tokens = self._resolve_max_tokens(max_tokens, model)
+
             # Create generation config
             generation_config = types.GenerateContentConfig(
                 temperature=self.temperature,
-                max_output_tokens=max_tokens or self.max_tokens,
+                max_output_tokens=resolved_max_tokens,
                 response_mime_type='application/json' if response_model else None,
                 response_schema=response_model if response_model else None,
                 system_instruction=system_prompt,
+                thinking_config=self.thinking_config,
             )
 
             # Generate content using the simple string approach
             response = await self.client.aio.models.generate_content(
-                model=self.model or DEFAULT_MODEL,
-                contents=gemini_messages,  # type: ignore[arg-type] # mypy fails on broad union type
+                model=model,
+                contents=gemini_messages,
                 config=generation_config,
             )
 
+            # Always capture the raw output for debugging
+            raw_output = getattr(response, 'text', None)
+
+            # Check for safety and prompt blocks
+            self._check_safety_blocks(response)
+            self._check_prompt_blocks(response)
+
             # If this was a structured output request, parse the response into the Pydantic model
             if response_model is not None:
                 try:
-                    if not response.text:
+                    if not raw_output:
                         raise ValueError('No response text')
 
-                    validated_model = response_model.model_validate(json.loads(response.text))
+                    validated_model = response_model.model_validate(json.loads(raw_output))
 
                     # Return as a dictionary for API consistency
                     return validated_model.model_dump()
                 except Exception as e:
+                    if raw_output:
+                        logger.error(
+                            '🦀 LLM generation failed parsing as JSON, will try to salvage.'
+                        )
+                        logger.error(self._get_failed_generation_log(gemini_messages, raw_output))
+                        # Try to salvage
+                        salvaged = self.salvage_json(raw_output)
+                        if salvaged is not None:
+                            logger.warning('Salvaged partial JSON from truncated/malformed output.')
+                            return salvaged
                     raise Exception(f'Failed to parse structured response: {e}') from e
 
             # Otherwise, return the response text as a dictionary
-            return {'content': response.text}
+            return {'content': raw_output}
 
         except Exception as e:
-            # Check if it's a rate limit error
-            if 'rate limit' in str(e).lower() or 'quota' in str(e).lower():
+            # Check if it's a rate limit error based on Gemini API error codes
+            error_message = str(e).lower()
+            if (
+                'rate limit' in error_message
+                or 'quota' in error_message
+                or 'resource_exhausted' in error_message
+                or '429' in str(e)
+            ):
                 raise RateLimitError from e
+
             logger.error(f'Error in generating LLM response: {e}')
-            raise
+            raise Exception from e
 
     async def generate_response(
         self,
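The except branch above only reaches a successful salvage when the raw output still ends with a closing bracket or brace; anything cut off mid-token yields None and the original parsing error is re-raised. A standalone restatement of the salvage logic with made-up outputs, kept outside the class so it can be run directly:

    import json
    import re

    def salvage(raw: str) -> object | None:
        # Same idea as GeminiClient.salvage_json: try the array pattern first, then the object pattern.
        for pattern in (r'\]\s*$', r'\}\s*$'):
            match = re.search(pattern, raw)
            if match:
                try:
                    return json.loads(raw[: match.end()])
                except Exception:
                    continue
        return None

    assert salvage('{"entities": []}\n') == {'entities': []}  # parses; returned without validation
    assert salvage('{"entities": [{"name": "Al') is None      # cut mid-token, nothing to salvage
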
@@ -172,26 +356,91 @@ class GeminiClient(LLMClient):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
+        prompt_name: str | None = None,
     ) -> dict[str, typing.Any]:
         """
-        Generate a response from the Gemini language model.
-        This method overrides the parent class method to provide a direct implementation.
+        Generate a response from the Gemini language model with retry logic and error handling.
+        This method overrides the parent class method to provide a direct implementation with advanced retry logic.
 
         Args:
             messages (list[Message]): A list of messages to send to the language model.
             response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
-            max_tokens (int): The maximum number of tokens to generate in the response.
+            max_tokens (int | None): The maximum number of tokens to generate in the response.
+            model_size (ModelSize): The size of the model to use (small or medium).
+            group_id (str | None): Optional partition identifier for the graph.
+            prompt_name (str | None): Optional name of the prompt for tracing.
 
         Returns:
             dict[str, typing.Any]: The response from the language model.
         """
-        if max_tokens is None:
-            max_tokens = self.max_tokens
-
-        # Call the internal _generate_response method
-        return await self._generate_response(
-            messages=messages,
-            response_model=response_model,
-            max_tokens=max_tokens,
-            model_size=model_size,
-        )
+        # Add multilingual extraction instructions
+        messages[0].content += get_extraction_language_instruction(group_id)
+
+        # Wrap entire operation in tracing span
+        with self.tracer.start_span('llm.generate') as span:
+            attributes = {
+                'llm.provider': 'gemini',
+                'model.size': model_size.value,
+                'max_tokens': max_tokens or self.max_tokens,
+            }
+            if prompt_name:
+                attributes['prompt.name'] = prompt_name
+            span.add_attributes(attributes)
+
+            retry_count = 0
+            last_error = None
+            last_output = None
+
+            while retry_count < self.MAX_RETRIES:
+                try:
+                    response = await self._generate_response(
+                        messages=messages,
+                        response_model=response_model,
+                        max_tokens=max_tokens,
+                        model_size=model_size,
+                    )
+                    last_output = (
+                        response.get('content')
+                        if isinstance(response, dict) and 'content' in response
+                        else None
+                    )
+                    return response
+                except RateLimitError as e:
+                    # Rate limit errors should not trigger retries (fail fast)
+                    span.set_status('error', str(e))
+                    raise e
+                except Exception as e:
+                    last_error = e
+
+                    # Check if this is a safety block - these typically shouldn't be retried
+                    error_text = str(e) or (str(e.__cause__) if e.__cause__ else '')
+                    if 'safety' in error_text.lower() or 'blocked' in error_text.lower():
+                        logger.warning(f'Content blocked by safety filters: {e}')
+                        span.set_status('error', str(e))
+                        raise Exception(f'Content blocked by safety filters: {e}') from e
+
+                    retry_count += 1
+
+                    # Construct a detailed error message for the LLM
+                    error_context = (
+                        f'The previous response attempt was invalid. '
+                        f'Error type: {e.__class__.__name__}. '
+                        f'Error details: {str(e)}. '
+                        f'Please try again with a valid response, ensuring the output matches '
+                        f'the expected format and constraints.'
+                    )
+
+                    error_message = Message(role='user', content=error_context)
+                    messages.append(error_message)
+                    logger.warning(
+                        f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
+                    )
+
+            # If we exit the loop without returning, all retries are exhausted
+            logger.error('🦀 LLM generation failed and retries are exhausted.')
+            logger.error(self._get_failed_generation_log(messages, last_output))
+            logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {last_error}')
+            span.set_status('error', str(last_error))
+            span.record_exception(last_error) if last_error else None
+            raise last_error or Exception('Max retries exceeded')
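A usage sketch for the retry-wrapped entry point, driving the client directly rather than through the higher-level Graphiti API; the Pydantic model, prompt text, and API key are invented for illustration, and the call shape follows the signature above:

    import asyncio

    from pydantic import BaseModel

    from graphiti_core.llm_client.config import LLMConfig, ModelSize
    from graphiti_core.llm_client.gemini_client import GeminiClient
    from graphiti_core.prompts.models import Message

    class ExtractedNames(BaseModel):  # invented response model for illustration
        names: list[str]

    async def main() -> None:
        client = GeminiClient(config=LLMConfig(api_key='...'))
        result = await client.generate_response(
            messages=[
                Message(role='system', content='Extract person names from the text.'),
                Message(role='user', content='Alice met Bob in Paris.'),
            ],
            response_model=ExtractedNames,
            model_size=ModelSize.small,   # routes to DEFAULT_SMALL_MODEL unless small_model is configured
            prompt_name='extract_names',  # surfaces as prompt.name on the tracing span
        )
        print(result)  # dict produced by ExtractedNames.model_dump()

    asyncio.run(main())
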
graphiti_core/llm_client/groq_client.py
@@ -17,10 +17,21 @@ limitations under the License.
 import json
 import logging
 import typing
+from typing import TYPE_CHECKING
 
-import groq
-from groq import AsyncGroq
-from groq.types.chat import ChatCompletionMessageParam
+if TYPE_CHECKING:
+    import groq
+    from groq import AsyncGroq
+    from groq.types.chat import ChatCompletionMessageParam
+else:
+    try:
+        import groq
+        from groq import AsyncGroq
+        from groq.types.chat import ChatCompletionMessageParam
+    except ImportError:
+        raise ImportError(
+            'groq is required for GroqClient. Install it with: pip install graphiti-core[groq]'
+        ) from None
 from pydantic import BaseModel
 
 from ..prompts.models import Message
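This is the same optional-dependency pattern as the Gemini hunk: type checkers always see the imports, while a runtime without the extra fails at import time with an actionable message. A hedged sketch of how a caller might degrade gracefully (falling back to OpenAIClient is this note's assumption, not something the package prescribes):

    # Sketch: prefer GroqClient when the groq extra is installed, otherwise fall back.
    try:
        from graphiti_core.llm_client.groq_client import GroqClient as PreferredClient
    except ImportError:
        # Raised by the guarded import above when graphiti-core[groq] is absent.
        from graphiti_core.llm_client.openai_client import OpenAIClient as PreferredClient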