graphiti-core 0.12.0rc1__py3-none-any.whl → 0.24.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphiti_core/cross_encoder/bge_reranker_client.py +12 -2
- graphiti_core/cross_encoder/gemini_reranker_client.py +161 -0
- graphiti_core/cross_encoder/openai_reranker_client.py +7 -5
- graphiti_core/decorators.py +110 -0
- graphiti_core/driver/__init__.py +19 -0
- graphiti_core/driver/driver.py +124 -0
- graphiti_core/driver/falkordb_driver.py +362 -0
- graphiti_core/driver/graph_operations/graph_operations.py +191 -0
- graphiti_core/driver/kuzu_driver.py +182 -0
- graphiti_core/driver/neo4j_driver.py +117 -0
- graphiti_core/driver/neptune_driver.py +305 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +287 -172
- graphiti_core/embedder/azure_openai.py +71 -0
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/embedder/gemini.py +116 -22
- graphiti_core/embedder/voyage.py +13 -2
- graphiti_core/errors.py +8 -0
- graphiti_core/graph_queries.py +162 -0
- graphiti_core/graphiti.py +705 -193
- graphiti_core/graphiti_types.py +4 -2
- graphiti_core/helpers.py +87 -10
- graphiti_core/llm_client/__init__.py +16 -0
- graphiti_core/llm_client/anthropic_client.py +159 -56
- graphiti_core/llm_client/azure_openai_client.py +115 -0
- graphiti_core/llm_client/client.py +98 -21
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/gemini_client.py +290 -41
- graphiti_core/llm_client/groq_client.py +14 -3
- graphiti_core/llm_client/openai_base_client.py +261 -0
- graphiti_core/llm_client/openai_client.py +56 -132
- graphiti_core/llm_client/openai_generic_client.py +91 -56
- graphiti_core/models/edges/edge_db_queries.py +259 -35
- graphiti_core/models/nodes/node_db_queries.py +311 -32
- graphiti_core/nodes.py +420 -205
- graphiti_core/prompts/dedupe_edges.py +46 -32
- graphiti_core/prompts/dedupe_nodes.py +67 -42
- graphiti_core/prompts/eval.py +4 -4
- graphiti_core/prompts/extract_edges.py +27 -16
- graphiti_core/prompts/extract_nodes.py +74 -31
- graphiti_core/prompts/prompt_helpers.py +39 -0
- graphiti_core/prompts/snippets.py +29 -0
- graphiti_core/prompts/summarize_nodes.py +23 -25
- graphiti_core/search/search.py +158 -82
- graphiti_core/search/search_config.py +39 -4
- graphiti_core/search/search_filters.py +126 -35
- graphiti_core/search/search_helpers.py +5 -6
- graphiti_core/search/search_utils.py +1405 -485
- graphiti_core/telemetry/__init__.py +9 -0
- graphiti_core/telemetry/telemetry.py +117 -0
- graphiti_core/tracer.py +193 -0
- graphiti_core/utils/bulk_utils.py +364 -285
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +67 -49
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +339 -197
- graphiti_core/utils/maintenance/graph_data_operations.py +50 -114
- graphiti_core/utils/maintenance/node_operations.py +319 -238
- graphiti_core/utils/maintenance/temporal_operations.py +11 -3
- graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
- graphiti_core/utils/text_utils.py +53 -0
- graphiti_core-0.24.3.dist-info/METADATA +726 -0
- graphiti_core-0.24.3.dist-info/RECORD +86 -0
- {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
- graphiti_core-0.12.0rc1.dist-info/METADATA +0 -350
- graphiti_core-0.12.0rc1.dist-info/RECORD +0 -66
- /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
- {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info/licenses}/LICENSE +0 -0
graphiti_core/llm_client/client.py

@@ -26,15 +26,34 @@ from pydantic import BaseModel
 from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential

 from ..prompts.models import Message
+from ..tracer import NoOpTracer, Tracer
 from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
 from .errors import RateLimitError

 DEFAULT_TEMPERATURE = 0
 DEFAULT_CACHE_DIR = './llm_cache'

-
-
-
+
+def get_extraction_language_instruction(group_id: str | None = None) -> str:
+    """Returns instruction for language extraction behavior.
+
+    Override this function to customize language extraction:
+    - Return empty string to disable multilingual instructions
+    - Return custom instructions for specific language requirements
+    - Use group_id to provide different instructions per group/partition
+
+    Args:
+        group_id: Optional partition identifier for the graph
+
+    Returns:
+        str: Language instruction to append to system messages
+    """
+    return (
+        '\n\nAny extracted information should be returned in the same language as it was written in. '
+        'Only output non-English text when the user has written full sentences or phrases in that non-English language. '
+        'Otherwise, output English.'
+    )
+

 logger = logging.getLogger(__name__)

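The new `get_extraction_language_instruction` hook is a plain module-level function, so the customization route its docstring describes amounts to rebinding it before prompts are built. A minimal sketch; the Spanish-only policy and the rebinding approach are illustrative assumptions, not part of the package:

```python
# Hypothetical customization sketch for the language-instruction hook above.
import graphiti_core.llm_client.client as llm_client_module


def spanish_only_instruction(group_id: str | None = None) -> str:
    # Example policy: always return extracted information in Spanish.
    return '\n\nDevuelve toda la información extraída en español.'


# Rebind the module attribute so LLMClient.generate_response picks it up.
llm_client_module.get_extraction_language_instruction = spanish_only_instruction

# Note: modules that import the function by name (e.g. gemini_client) bind it at
# import time, so they would need the same rebinding to see the override.
```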
@@ -60,11 +79,16 @@ class LLMClient(ABC):
         self.max_tokens = config.max_tokens
         self.cache_enabled = cache
         self.cache_dir = None
+        self.tracer: Tracer = NoOpTracer()

         # Only create the cache directory if caching is enabled
         if self.cache_enabled:
             self.cache_dir = Cache(DEFAULT_CACHE_DIR)

+    def set_tracer(self, tracer: Tracer) -> None:
+        """Set the tracer for this LLM client."""
+        self.tracer = tracer
+
     def _clean_input(self, input: str) -> str:
         """Clean input string of invalid unicode and control characters.

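The `set_tracer` / `NoOpTracer` wiring above only relies on a small span surface: `start_span()` used as a context manager plus `add_attributes`, `set_status`, and `record_exception`. A duck-typed sketch of a logging tracer built against just those calls (an assumption — the real interface in `graphiti_core/tracer.py` may be stricter):

```python
# Minimal logging tracer sketch matching only the calls visible in this diff.
import logging
from contextlib import contextmanager

logger = logging.getLogger('llm.tracing')


class LoggingSpan:
    def __init__(self, name: str):
        self.name = name

    def add_attributes(self, attributes: dict) -> None:
        logger.info('span %s attributes: %s', self.name, attributes)

    def set_status(self, status: str, message: str) -> None:
        logger.info('span %s status: %s (%s)', self.name, status, message)

    def record_exception(self, exc: Exception) -> None:
        logger.error('span %s exception', self.name, exc_info=exc)


class LoggingTracer:
    @contextmanager
    def start_span(self, name: str):
        # Yield a span object for the duration of the traced operation.
        yield LoggingSpan(name)


# Wiring (assumes `llm_client` is any LLMClient subclass instance):
# llm_client.set_tracer(LoggingTracer())
```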
@@ -132,6 +156,8 @@ class LLMClient(ABC):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
+        prompt_name: str | None = None,
     ) -> dict[str, typing.Any]:
         if max_tokens is None:
             max_tokens = self.max_tokens
@@ -145,25 +171,76 @@ class LLMClient(ABC):
         )

         # Add multilingual extraction instructions
-        messages[0].content +=
-
-        if self.cache_enabled and self.cache_dir is not None:
-            cache_key = self._get_cache_key(messages)
-
-            cached_response = self.cache_dir.get(cache_key)
-            if cached_response is not None:
-                logger.debug(f'Cache hit for {cache_key}')
-                return cached_response
+        messages[0].content += get_extraction_language_instruction(group_id)

         for message in messages:
             message.content = self._clean_input(message.content)

-
-
-
-
-
-
-
-
-
+        # Wrap entire operation in tracing span
+        with self.tracer.start_span('llm.generate') as span:
+            attributes = {
+                'llm.provider': self._get_provider_type(),
+                'model.size': model_size.value,
+                'max_tokens': max_tokens,
+                'cache.enabled': self.cache_enabled,
+            }
+            if prompt_name:
+                attributes['prompt.name'] = prompt_name
+            span.add_attributes(attributes)
+
+            # Check cache first
+            if self.cache_enabled and self.cache_dir is not None:
+                cache_key = self._get_cache_key(messages)
+                cached_response = self.cache_dir.get(cache_key)
+                if cached_response is not None:
+                    logger.debug(f'Cache hit for {cache_key}')
+                    span.add_attributes({'cache.hit': True})
+                    return cached_response
+
+            span.add_attributes({'cache.hit': False})
+
+            # Execute LLM call
+            try:
+                response = await self._generate_response_with_retry(
+                    messages, response_model, max_tokens, model_size
+                )
+            except Exception as e:
+                span.set_status('error', str(e))
+                span.record_exception(e)
+                raise
+
+            # Cache response if enabled
+            if self.cache_enabled and self.cache_dir is not None:
+                cache_key = self._get_cache_key(messages)
+                self.cache_dir.set(cache_key, response)
+
+            return response
+
+    def _get_provider_type(self) -> str:
+        """Get provider type from class name."""
+        class_name = self.__class__.__name__.lower()
+        if 'openai' in class_name:
+            return 'openai'
+        elif 'anthropic' in class_name:
+            return 'anthropic'
+        elif 'gemini' in class_name:
+            return 'gemini'
+        elif 'groq' in class_name:
+            return 'groq'
+        else:
+            return 'unknown'
+
+    def _get_failed_generation_log(self, messages: list[Message], output: str | None) -> str:
+        """
+        Log the full input messages, the raw output (if any), and the exception for debugging failed generations.
+        """
+        log = ''
+        log += f'Input messages: {json.dumps([m.model_dump() for m in messages], indent=2)}\n'
+        if output is not None:
+            if len(output) > 4000:
+                log += f'Raw output: {output[:2000]}... (truncated) ...{output[-2000:]}\n'
+            else:
+                log += f'Raw output: {output}\n'
+        else:
+            log += 'No raw output available'
+        return log
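From the caller's side, the reworked `generate_response` adds two optional keyword arguments: `group_id` (forwarded to the language-instruction hook) and `prompt_name` (attached to the tracing span). A usage sketch; the concrete client instance, tenant id, and prompt label are placeholders:

```python
# Caller-facing sketch of the new generate_response signature.
from pydantic import BaseModel

from graphiti_core.prompts.models import Message


class ExtractedSummary(BaseModel):
    summary: str


async def summarize(client) -> dict:
    # `client` is assumed to be any LLMClient subclass instance.
    messages = [
        Message(role='system', content='You summarize episodes.'),
        Message(role='user', content='Alice met Bob in Paris last week.'),
    ]
    return await client.generate_response(
        messages,
        response_model=ExtractedSummary,
        group_id='tenant-123',          # partition hint for the language hook
        prompt_name='summarize.example',  # hypothetical label for the tracing span
    )
```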
graphiti_core/llm_client/gemini_client.py

@@ -16,20 +16,54 @@ limitations under the License.

 import json
 import logging
+import re
 import typing
+from typing import TYPE_CHECKING, ClassVar

-from google import genai  # type: ignore
-from google.genai import types  # type: ignore
 from pydantic import BaseModel

 from ..prompts.models import Message
-from .client import LLMClient
-from .config import
+from .client import LLMClient, get_extraction_language_instruction
+from .config import LLMConfig, ModelSize
 from .errors import RateLimitError

+if TYPE_CHECKING:
+    from google import genai
+    from google.genai import types
+else:
+    try:
+        from google import genai
+        from google.genai import types
+    except ImportError:
+        # If gemini client is not installed, raise an ImportError
+        raise ImportError(
+            'google-genai is required for GeminiClient. '
+            'Install it with: pip install graphiti-core[google-genai]'
+        ) from None
+
+
 logger = logging.getLogger(__name__)

-DEFAULT_MODEL = 'gemini-2.
+DEFAULT_MODEL = 'gemini-2.5-flash'
+DEFAULT_SMALL_MODEL = 'gemini-2.5-flash-lite'
+
+# Maximum output tokens for different Gemini models
+GEMINI_MODEL_MAX_TOKENS = {
+    # Gemini 2.5 models
+    'gemini-2.5-pro': 65536,
+    'gemini-2.5-flash': 65536,
+    'gemini-2.5-flash-lite': 64000,
+    # Gemini 2.0 models
+    'gemini-2.0-flash': 8192,
+    'gemini-2.0-flash-lite': 8192,
+    # Gemini 1.5 models
+    'gemini-1.5-pro': 8192,
+    'gemini-1.5-flash': 8192,
+    'gemini-1.5-flash-8b': 8192,
+}
+
+# Default max tokens for models not in the mapping
+DEFAULT_GEMINI_MAX_TOKENS = 8192


 class GeminiClient(LLMClient):
@@ -43,27 +77,35 @@ class GeminiClient(LLMClient):
         model (str): The model name to use for generating responses.
         temperature (float): The temperature to use for generating responses.
         max_tokens (int): The maximum number of tokens to generate in a response.
-
+        thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
     Methods:
-        __init__(config: LLMConfig | None = None, cache: bool = False):
-            Initializes the GeminiClient with the provided configuration and
+        __init__(config: LLMConfig | None = None, cache: bool = False, thinking_config: types.ThinkingConfig | None = None):
+            Initializes the GeminiClient with the provided configuration, cache setting, and optional thinking config.

         _generate_response(messages: list[Message]) -> dict[str, typing.Any]:
             Generates a response from the language model based on the provided messages.
     """

+    # Class-level constants
+    MAX_RETRIES: ClassVar[int] = 2
+
     def __init__(
         self,
         config: LLMConfig | None = None,
         cache: bool = False,
-        max_tokens: int =
+        max_tokens: int | None = None,
+        thinking_config: types.ThinkingConfig | None = None,
+        client: 'genai.Client | None' = None,
     ):
         """
-        Initialize the GeminiClient with the provided configuration and
+        Initialize the GeminiClient with the provided configuration, cache setting, and optional thinking config.

         Args:
             config (LLMConfig | None): The configuration for the LLM client, including API key, model, temperature, and max tokens.
             cache (bool): Whether to use caching for responses. Defaults to False.
+            thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
+                Only use with models that support thinking (gemini-2.5+). Defaults to None.
+            client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
         """
         if config is None:
             config = LLMConfig()
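With the widened `__init__`, a `GeminiClient` can now be constructed with a thinking configuration and, for tests, an injected `genai.Client`. A construction sketch; the `thinking_budget` value is only illustrative and the `LLMConfig` keyword names are assumed from its docstring:

```python
# Construction sketch for the new GeminiClient signature shown above.
from google.genai import types

from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.gemini_client import GeminiClient

gemini_client = GeminiClient(
    config=LLMConfig(api_key='...', model='gemini-2.5-flash'),
    # Only pass thinking_config for models that support thinking (gemini-2.5+).
    thinking_config=types.ThinkingConfig(thinking_budget=1024),
)
```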
@@ -71,17 +113,128 @@ class GeminiClient(LLMClient):
         super().__init__(config, cache)

         self.model = config.model
-
-
-            api_key=config.api_key
-
+
+        if client is None:
+            self.client = genai.Client(api_key=config.api_key)
+        else:
+            self.client = client
+
         self.max_tokens = max_tokens
+        self.thinking_config = thinking_config
+
+    def _check_safety_blocks(self, response) -> None:
+        """Check if response was blocked for safety reasons and raise appropriate exceptions."""
+        # Check if the response was blocked for safety reasons
+        if not (hasattr(response, 'candidates') and response.candidates):
+            return
+
+        candidate = response.candidates[0]
+        if not (hasattr(candidate, 'finish_reason') and candidate.finish_reason == 'SAFETY'):
+            return
+
+        # Content was blocked for safety reasons - collect safety details
+        safety_info = []
+        safety_ratings = getattr(candidate, 'safety_ratings', None)
+
+        if safety_ratings:
+            for rating in safety_ratings:
+                if getattr(rating, 'blocked', False):
+                    category = getattr(rating, 'category', 'Unknown')
+                    probability = getattr(rating, 'probability', 'Unknown')
+                    safety_info.append(f'{category}: {probability}')
+
+        safety_details = (
+            ', '.join(safety_info) if safety_info else 'Content blocked for safety reasons'
+        )
+        raise Exception(f'Response blocked by Gemini safety filters: {safety_details}')
+
+    def _check_prompt_blocks(self, response) -> None:
+        """Check if prompt was blocked and raise appropriate exceptions."""
+        prompt_feedback = getattr(response, 'prompt_feedback', None)
+        if not prompt_feedback:
+            return
+
+        block_reason = getattr(prompt_feedback, 'block_reason', None)
+        if block_reason:
+            raise Exception(f'Prompt blocked by Gemini: {block_reason}')
+
+    def _get_model_for_size(self, model_size: ModelSize) -> str:
+        """Get the appropriate model name based on the requested size."""
+        if model_size == ModelSize.small:
+            return self.small_model or DEFAULT_SMALL_MODEL
+        else:
+            return self.model or DEFAULT_MODEL
+
+    def _get_max_tokens_for_model(self, model: str) -> int:
+        """Get the maximum output tokens for a specific Gemini model."""
+        return GEMINI_MODEL_MAX_TOKENS.get(model, DEFAULT_GEMINI_MAX_TOKENS)
+
+    def _resolve_max_tokens(self, requested_max_tokens: int | None, model: str) -> int:
+        """
+        Resolve the maximum output tokens to use based on precedence rules.
+
+        Precedence order (highest to lowest):
+        1. Explicit max_tokens parameter passed to generate_response()
+        2. Instance max_tokens set during client initialization
+        3. Model-specific maximum tokens from GEMINI_MODEL_MAX_TOKENS mapping
+        4. DEFAULT_MAX_TOKENS as final fallback
+
+        Args:
+            requested_max_tokens: The max_tokens parameter passed to generate_response()
+            model: The model name to look up model-specific limits
+
+        Returns:
+            int: The resolved maximum tokens to use
+        """
+        # 1. Use explicit parameter if provided
+        if requested_max_tokens is not None:
+            return requested_max_tokens
+
+        # 2. Use instance max_tokens if set during initialization
+        if self.max_tokens is not None:
+            return self.max_tokens
+
+        # 3. Use model-specific maximum or return DEFAULT_GEMINI_MAX_TOKENS
+        return self._get_max_tokens_for_model(model)
+
+    def salvage_json(self, raw_output: str) -> dict[str, typing.Any] | None:
+        """
+        Attempt to salvage a JSON object if the raw output is truncated.
+
+        This is accomplished by looking for the last closing bracket for an array or object.
+        If found, it will try to load the JSON object from the raw output.
+        If the JSON object is not valid, it will return None.
+
+        Args:
+            raw_output (str): The raw output from the LLM.
+
+        Returns:
+            dict[str, typing.Any]: The salvaged JSON object.
+            None: If no salvage is possible.
+        """
+        if not raw_output:
+            return None
+        # Try to salvage a JSON array
+        array_match = re.search(r'\]\s*$', raw_output)
+        if array_match:
+            try:
+                return json.loads(raw_output[: array_match.end()])
+            except Exception:
+                pass
+        # Try to salvage a JSON object
+        obj_match = re.search(r'\}\s*$', raw_output)
+        if obj_match:
+            try:
+                return json.loads(raw_output[: obj_match.end()])
+            except Exception:
+                pass
+        return None

     async def _generate_response(
         self,
         messages: list[Message],
         response_model: type[BaseModel] | None = None,
-        max_tokens: int =
+        max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
     ) -> dict[str, typing.Any]:
         """
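The two helpers above are deterministic, so their behavior can be illustrated directly. The values below are examples, assuming a default-constructed client with its instance `max_tokens` left unset:

```python
# Illustration of salvage_json and _resolve_max_tokens under the rules above.
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.gemini_client import GeminiClient

gemini_client = GeminiClient(config=LLMConfig(api_key='...'))

# Output truncated before its closing brace cannot be salvaged ...
assert gemini_client.salvage_json('{"name": "Alice", "labels": ["Entity"') is None
# ... but output that ends on a closing brace/bracket is parsed as-is.
assert gemini_client.salvage_json('{"name": "Alice"}  ') == {'name': 'Alice'}

# max_tokens precedence: explicit argument > instance setting > per-model cap.
assert gemini_client._resolve_max_tokens(1024, 'gemini-2.5-flash') == 1024
assert gemini_client._resolve_max_tokens(None, 'gemini-2.5-flash') == 65536  # GEMINI_MODEL_MAX_TOKENS
```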
@@ -90,18 +243,18 @@ class GeminiClient(LLMClient):
         Args:
             messages (list[Message]): A list of messages to send to the language model.
             response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
-            max_tokens (int): The maximum number of tokens to generate in the response.
+            max_tokens (int | None): The maximum number of tokens to generate in the response. If None, uses precedence rules.
+            model_size (ModelSize): The size of the model to use (small or medium).

         Returns:
             dict[str, typing.Any]: The response from the language model.

         Raises:
             RateLimitError: If the API rate limit is exceeded.
-
-            Exception: If there is an error generating the response.
+            Exception: If there is an error generating the response or content is blocked.
         """
         try:
-            gemini_messages:
+            gemini_messages: typing.Any = []
             # If a response model is provided, add schema for structured output
             system_prompt = ''
             if response_model is not None:
@@ -127,44 +280,75 @@ class GeminiClient(LLMClient):
                     types.Content(role=m.role, parts=[types.Part.from_text(text=m.content)])
                 )

+            # Get the appropriate model for the requested size
+            model = self._get_model_for_size(model_size)
+
+            # Resolve max_tokens using precedence rules (see _resolve_max_tokens for details)
+            resolved_max_tokens = self._resolve_max_tokens(max_tokens, model)
+
             # Create generation config
             generation_config = types.GenerateContentConfig(
                 temperature=self.temperature,
-                max_output_tokens=
+                max_output_tokens=resolved_max_tokens,
                 response_mime_type='application/json' if response_model else None,
                 response_schema=response_model if response_model else None,
                 system_instruction=system_prompt,
+                thinking_config=self.thinking_config,
             )

             # Generate content using the simple string approach
             response = await self.client.aio.models.generate_content(
-                model=
-                contents=gemini_messages,
+                model=model,
+                contents=gemini_messages,
                 config=generation_config,
             )

+            # Always capture the raw output for debugging
+            raw_output = getattr(response, 'text', None)
+
+            # Check for safety and prompt blocks
+            self._check_safety_blocks(response)
+            self._check_prompt_blocks(response)
+
             # If this was a structured output request, parse the response into the Pydantic model
             if response_model is not None:
                 try:
-                    if not
+                    if not raw_output:
                         raise ValueError('No response text')

-                    validated_model = response_model.model_validate(json.loads(
+                    validated_model = response_model.model_validate(json.loads(raw_output))

                     # Return as a dictionary for API consistency
                     return validated_model.model_dump()
                 except Exception as e:
+                    if raw_output:
+                        logger.error(
+                            '🦀 LLM generation failed parsing as JSON, will try to salvage.'
+                        )
+                        logger.error(self._get_failed_generation_log(gemini_messages, raw_output))
+                        # Try to salvage
+                        salvaged = self.salvage_json(raw_output)
+                        if salvaged is not None:
+                            logger.warning('Salvaged partial JSON from truncated/malformed output.')
+                            return salvaged
                     raise Exception(f'Failed to parse structured response: {e}') from e

             # Otherwise, return the response text as a dictionary
-            return {'content':
+            return {'content': raw_output}

         except Exception as e:
-            # Check if it's a rate limit error
-
+            # Check if it's a rate limit error based on Gemini API error codes
+            error_message = str(e).lower()
+            if (
+                'rate limit' in error_message
+                or 'quota' in error_message
+                or 'resource_exhausted' in error_message
+                or '429' in str(e)
+            ):
                 raise RateLimitError from e
+
             logger.error(f'Error in generating LLM response: {e}')
-            raise
+            raise Exception from e

     async def generate_response(
         self,
@@ -172,26 +356,91 @@ class GeminiClient(LLMClient):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
+        prompt_name: str | None = None,
     ) -> dict[str, typing.Any]:
         """
-        Generate a response from the Gemini language model.
-        This method overrides the parent class method to provide a direct implementation.
+        Generate a response from the Gemini language model with retry logic and error handling.
+        This method overrides the parent class method to provide a direct implementation with advanced retry logic.

         Args:
             messages (list[Message]): A list of messages to send to the language model.
             response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
-            max_tokens (int): The maximum number of tokens to generate in the response.
+            max_tokens (int | None): The maximum number of tokens to generate in the response.
+            model_size (ModelSize): The size of the model to use (small or medium).
+            group_id (str | None): Optional partition identifier for the graph.
+            prompt_name (str | None): Optional name of the prompt for tracing.

         Returns:
             dict[str, typing.Any]: The response from the language model.
         """
-
-
-
-        #
-
-
-
-
-
-
+        # Add multilingual extraction instructions
+        messages[0].content += get_extraction_language_instruction(group_id)
+
+        # Wrap entire operation in tracing span
+        with self.tracer.start_span('llm.generate') as span:
+            attributes = {
+                'llm.provider': 'gemini',
+                'model.size': model_size.value,
+                'max_tokens': max_tokens or self.max_tokens,
+            }
+            if prompt_name:
+                attributes['prompt.name'] = prompt_name
+            span.add_attributes(attributes)
+
+            retry_count = 0
+            last_error = None
+            last_output = None
+
+            while retry_count < self.MAX_RETRIES:
+                try:
+                    response = await self._generate_response(
+                        messages=messages,
+                        response_model=response_model,
+                        max_tokens=max_tokens,
+                        model_size=model_size,
+                    )
+                    last_output = (
+                        response.get('content')
+                        if isinstance(response, dict) and 'content' in response
+                        else None
+                    )
+                    return response
+                except RateLimitError as e:
+                    # Rate limit errors should not trigger retries (fail fast)
+                    span.set_status('error', str(e))
+                    raise e
+                except Exception as e:
+                    last_error = e
+
+                    # Check if this is a safety block - these typically shouldn't be retried
+                    error_text = str(e) or (str(e.__cause__) if e.__cause__ else '')
+                    if 'safety' in error_text.lower() or 'blocked' in error_text.lower():
+                        logger.warning(f'Content blocked by safety filters: {e}')
+                        span.set_status('error', str(e))
+                        raise Exception(f'Content blocked by safety filters: {e}') from e
+
+                    retry_count += 1
+
+                    # Construct a detailed error message for the LLM
+                    error_context = (
+                        f'The previous response attempt was invalid. '
+                        f'Error type: {e.__class__.__name__}. '
+                        f'Error details: {str(e)}. '
+                        f'Please try again with a valid response, ensuring the output matches '
+                        f'the expected format and constraints.'
+                    )
+
+                    error_message = Message(role='user', content=error_context)
+                    messages.append(error_message)
+                    logger.warning(
+                        f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
+                    )
+
+            # If we exit the loop without returning, all retries are exhausted
+            logger.error('🦀 LLM generation failed and retries are exhausted.')
+            logger.error(self._get_failed_generation_log(messages, last_output))
+            logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {last_error}')
+            span.set_status('error', str(last_error))
+            span.record_exception(last_error) if last_error else None
+            raise last_error or Exception('Max retries exceeded')
graphiti_core/llm_client/groq_client.py

@@ -17,10 +17,21 @@ limitations under the License.
 import json
 import logging
 import typing
+from typing import TYPE_CHECKING

-
-
-from groq
+if TYPE_CHECKING:
+    import groq
+    from groq import AsyncGroq
+    from groq.types.chat import ChatCompletionMessageParam
+else:
+    try:
+        import groq
+        from groq import AsyncGroq
+        from groq.types.chat import ChatCompletionMessageParam
+    except ImportError:
+        raise ImportError(
+            'groq is required for GroqClient. Install it with: pip install graphiti-core[groq]'
+        ) from None
 from pydantic import BaseModel

 from ..prompts.models import Message