graphiti-core 0.17.4__py3-none-any.whl → 0.25.3__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
- graphiti_core/cross_encoder/gemini_reranker_client.py +1 -1
- graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
- graphiti_core/decorators.py +110 -0
- graphiti_core/driver/driver.py +62 -2
- graphiti_core/driver/falkordb_driver.py +215 -23
- graphiti_core/driver/graph_operations/graph_operations.py +191 -0
- graphiti_core/driver/kuzu_driver.py +182 -0
- graphiti_core/driver/neo4j_driver.py +70 -8
- graphiti_core/driver/neptune_driver.py +305 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +264 -132
- graphiti_core/embedder/azure_openai.py +10 -3
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/graph_queries.py +114 -101
- graphiti_core/graphiti.py +635 -260
- graphiti_core/graphiti_types.py +2 -0
- graphiti_core/helpers.py +37 -15
- graphiti_core/llm_client/anthropic_client.py +142 -52
- graphiti_core/llm_client/azure_openai_client.py +57 -19
- graphiti_core/llm_client/client.py +83 -21
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/gemini_client.py +75 -57
- graphiti_core/llm_client/openai_base_client.py +92 -48
- graphiti_core/llm_client/openai_client.py +39 -9
- graphiti_core/llm_client/openai_generic_client.py +91 -56
- graphiti_core/models/edges/edge_db_queries.py +259 -35
- graphiti_core/models/nodes/node_db_queries.py +311 -32
- graphiti_core/nodes.py +388 -164
- graphiti_core/prompts/dedupe_edges.py +42 -31
- graphiti_core/prompts/dedupe_nodes.py +56 -39
- graphiti_core/prompts/eval.py +4 -4
- graphiti_core/prompts/extract_edges.py +24 -15
- graphiti_core/prompts/extract_nodes.py +76 -35
- graphiti_core/prompts/prompt_helpers.py +39 -0
- graphiti_core/prompts/snippets.py +29 -0
- graphiti_core/prompts/summarize_nodes.py +23 -25
- graphiti_core/search/search.py +154 -74
- graphiti_core/search/search_config.py +39 -4
- graphiti_core/search/search_filters.py +110 -31
- graphiti_core/search/search_helpers.py +5 -6
- graphiti_core/search/search_utils.py +1360 -473
- graphiti_core/tracer.py +193 -0
- graphiti_core/utils/bulk_utils.py +216 -90
- graphiti_core/utils/content_chunking.py +702 -0
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +62 -38
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +306 -156
- graphiti_core/utils/maintenance/graph_data_operations.py +44 -74
- graphiti_core/utils/maintenance/node_operations.py +466 -206
- graphiti_core/utils/maintenance/temporal_operations.py +11 -3
- graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
- graphiti_core/utils/text_utils.py +53 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/METADATA +221 -87
- graphiti_core-0.25.3.dist-info/RECORD +87 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/WHEEL +1 -1
- graphiti_core-0.17.4.dist-info/RECORD +0 -77
- /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/licenses/LICENSE +0 -0
graphiti_core/graphiti_types.py
CHANGED
```diff
@@ -20,6 +20,7 @@ from graphiti_core.cross_encoder import CrossEncoderClient
 from graphiti_core.driver.driver import GraphDriver
 from graphiti_core.embedder import EmbedderClient
 from graphiti_core.llm_client import LLMClient
+from graphiti_core.tracer import Tracer
 
 
 class GraphitiClients(BaseModel):
@@ -27,5 +28,6 @@ class GraphitiClients(BaseModel):
     llm_client: LLMClient
     embedder: EmbedderClient
     cross_encoder: CrossEncoderClient
+    tracer: Tracer
 
     model_config = ConfigDict(arbitrary_types_allowed=True)
```
graphiti_core/helpers.py
CHANGED
```diff
@@ -26,30 +26,52 @@ from dotenv import load_dotenv
 from neo4j import time as neo4j_time
 from numpy._typing import NDArray
 from pydantic import BaseModel
-from typing_extensions import LiteralString
 
+from graphiti_core.driver.driver import GraphProvider
 from graphiti_core.errors import GroupIdValidationError
 
 load_dotenv()
 
 USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
 SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
-MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0))
 DEFAULT_PAGE_LIMIT = 20
 
-)
+# Content chunking configuration for entity extraction
+# Density-based chunking: only chunk high-density content (many entities per token)
+# This targets the failure case (large entity-dense inputs) while preserving
+# context for prose/narrative content
+CHUNK_TOKEN_SIZE = int(os.getenv('CHUNK_TOKEN_SIZE', 3000))
+CHUNK_OVERLAP_TOKENS = int(os.getenv('CHUNK_OVERLAP_TOKENS', 200))
+# Minimum tokens before considering chunking - short content processes fine regardless of density
+CHUNK_MIN_TOKENS = int(os.getenv('CHUNK_MIN_TOKENS', 1000))
+# Entity density threshold: chunk if estimated density > this value
+# For JSON: elements per 1000 tokens > threshold * 1000 (e.g., 0.15 = 150 elements/1000 tokens)
+# For Text: capitalized words per 1000 tokens > threshold * 500 (e.g., 0.15 = 75 caps/1000 tokens)
+# Higher values = more conservative (less chunking), targets P95+ density cases
+# Examples that trigger chunking at 0.15: AWS cost data (12mo), bulk data imports, entity-dense JSON
+# Examples that DON'T chunk at 0.15: meeting transcripts, news articles, documentation
+CHUNK_DENSITY_THRESHOLD = float(os.getenv('CHUNK_DENSITY_THRESHOLD', 0.15))
 
 
-def parse_db_date(
+def parse_db_date(input_date: neo4j_time.DateTime | str | None) -> datetime | None:
+    if isinstance(input_date, neo4j_time.DateTime):
+        return input_date.to_native()
+
+    if isinstance(input_date, str):
+        return datetime.fromisoformat(input_date)
+
+    return input_date
+
+
+def get_default_group_id(provider: GraphProvider) -> str:
+    """
+    This function differentiates the default group id based on the database type.
+    For most databases, the default group id is an empty string, while there are database types that require a specific default group id.
+    """
+    if provider == GraphProvider.FALKORDB:
+        return '\\_'
+    else:
+        return ''
 
 
 def lucene_sanitize(query: str) -> str:
@@ -109,7 +131,7 @@ async def semaphore_gather(
     return await asyncio.gather(*(_wrap_coroutine(coroutine) for coroutine in coroutines))
 
 
-def validate_group_id(group_id: str) -> bool:
+def validate_group_id(group_id: str | None) -> bool:
     """
     Validate that a group_id contains only ASCII alphanumeric characters, dashes, and underscores.
 
@@ -136,7 +158,7 @@ def validate_group_id(group_id: str) -> bool:
 
 
 def validate_excluded_entity_types(
-    excluded_entity_types: list[str] | None, entity_types: dict[str, BaseModel] | None = None
+    excluded_entity_types: list[str] | None, entity_types: dict[str, type[BaseModel]] | None = None
 ) -> bool:
     """
     Validate that excluded entity types are valid type names.
```
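To make the density heuristic documented in these new settings concrete, here is a small self-contained sketch of the text-side decision. The `should_chunk_text` helper and the 4-characters-per-token estimate are illustrative assumptions for this example, not the actual `graphiti_core/utils/content_chunking.py` implementation; the thresholds mirror the defaults above.

```python
# Illustrative sketch only: mirrors the documented defaults, not the real
# graphiti_core.utils.content_chunking logic.
CHUNK_MIN_TOKENS = 1000
CHUNK_DENSITY_THRESHOLD = 0.15


def should_chunk_text(text: str) -> bool:
    # Crude token estimate (~4 characters per token) -- an assumption for this sketch.
    est_tokens = max(len(text) // 4, 1)
    if est_tokens < CHUNK_MIN_TOKENS:
        return False  # short content is processed in one shot regardless of density

    # Text heuristic from the comments above: chunk when capitalized words per
    # 1000 tokens exceed threshold * 500 (0.15 -> 75 capitalized words / 1000 tokens).
    capitalized = sum(1 for word in text.split() if word[:1].isupper())
    caps_per_1000 = capitalized / est_tokens * 1000
    return caps_per_1000 > CHUNK_DENSITY_THRESHOLD * 500


print(should_chunk_text('Alice met Bob. ' * 2000))  # entity-dense -> True
print(should_chunk_text('the meeting went well and nothing notable happened. ' * 500))  # prose -> False
```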
graphiti_core/llm_client/anthropic_client.py
CHANGED

```diff
@@ -47,6 +47,9 @@ else:
 logger = logging.getLogger(__name__)
 
 AnthropicModel = Literal[
+    'claude-sonnet-4-5-latest',
+    'claude-sonnet-4-5-20250929',
+    'claude-haiku-4-5-latest',
     'claude-3-7-sonnet-latest',
     'claude-3-7-sonnet-20250219',
     'claude-3-5-haiku-latest',
@@ -62,7 +65,39 @@ AnthropicModel = Literal[
     'claude-2.0',
 ]
 
-DEFAULT_MODEL: AnthropicModel = 'claude-
+DEFAULT_MODEL: AnthropicModel = 'claude-haiku-4-5-latest'
+
+# Maximum output tokens for different Anthropic models
+# Based on official Anthropic documentation (as of 2025)
+# Note: These represent standard limits without beta headers.
+# Some models support higher limits with additional configuration (e.g., Claude 3.7 supports
+# 128K with 'anthropic-beta: output-128k-2025-02-19' header, but this is not currently implemented).
+ANTHROPIC_MODEL_MAX_TOKENS = {
+    # Claude 4.5 models - 64K tokens
+    'claude-sonnet-4-5-latest': 65536,
+    'claude-sonnet-4-5-20250929': 65536,
+    'claude-haiku-4-5-latest': 65536,
+    # Claude 3.7 models - standard 64K tokens
+    'claude-3-7-sonnet-latest': 65536,
+    'claude-3-7-sonnet-20250219': 65536,
+    # Claude 3.5 models
+    'claude-3-5-haiku-latest': 8192,
+    'claude-3-5-haiku-20241022': 8192,
+    'claude-3-5-sonnet-latest': 8192,
+    'claude-3-5-sonnet-20241022': 8192,
+    'claude-3-5-sonnet-20240620': 8192,
+    # Claude 3 models - 4K tokens
+    'claude-3-opus-latest': 4096,
+    'claude-3-opus-20240229': 4096,
+    'claude-3-sonnet-20240229': 4096,
+    'claude-3-haiku-20240307': 4096,
+    # Claude 2 models - 4K tokens
+    'claude-2.1': 4096,
+    'claude-2.0': 4096,
+}
+
+# Default max tokens for models not in the mapping
+DEFAULT_ANTHROPIC_MAX_TOKENS = 8192
 
 
 class AnthropicClient(LLMClient):
@@ -177,6 +212,45 @@ class AnthropicClient(LLMClient):
         tool_choice_cast = typing.cast(ToolChoiceParam, tool_choice)
         return tool_list_cast, tool_choice_cast
 
+    def _get_max_tokens_for_model(self, model: str) -> int:
+        """Get the maximum output tokens for a specific Anthropic model.
+
+        Args:
+            model: The model name to look up
+
+        Returns:
+            int: The maximum output tokens for the model
+        """
+        return ANTHROPIC_MODEL_MAX_TOKENS.get(model, DEFAULT_ANTHROPIC_MAX_TOKENS)
+
+    def _resolve_max_tokens(self, requested_max_tokens: int | None, model: str) -> int:
+        """
+        Resolve the maximum output tokens to use based on precedence rules.
+
+        Precedence order (highest to lowest):
+        1. Explicit max_tokens parameter passed to generate_response()
+        2. Instance max_tokens set during client initialization
+        3. Model-specific maximum tokens from ANTHROPIC_MODEL_MAX_TOKENS mapping
+        4. DEFAULT_ANTHROPIC_MAX_TOKENS as final fallback
+
+        Args:
+            requested_max_tokens: The max_tokens parameter passed to generate_response()
+            model: The model name to look up model-specific limits
+
+        Returns:
+            int: The resolved maximum tokens to use
+        """
+        # 1. Use explicit parameter if provided
+        if requested_max_tokens is not None:
+            return requested_max_tokens
+
+        # 2. Use instance max_tokens if set during initialization
+        if self.max_tokens is not None:
+            return self.max_tokens
+
+        # 3. Use model-specific maximum or return DEFAULT_ANTHROPIC_MAX_TOKENS
+        return self._get_max_tokens_for_model(model)
+
     async def _generate_response(
         self,
         messages: list[Message],
@@ -204,12 +278,9 @@ class AnthropicClient(LLMClient):
         user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]]
         user_messages_cast = typing.cast(list[MessageParam], user_messages)
 
-        #
-        #
-        max_creation_tokens: int =
-            max_tokens if max_tokens is not None else self.config.max_tokens,
-            DEFAULT_MAX_TOKENS,
-        )
+        # Resolve max_tokens dynamically based on the model's capabilities
+        # This allows different models to use their full output capacity
+        max_creation_tokens: int = self._resolve_max_tokens(max_tokens, self.model)
 
         try:
             # Create the appropriate tool based on whether response_model is provided
@@ -265,6 +336,8 @@ class AnthropicClient(LLMClient):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
+        prompt_name: str | None = None,
     ) -> dict[str, typing.Any]:
         """
         Generate a response from the LLM.
@@ -285,55 +358,72 @@ class AnthropicClient(LLMClient):
         if max_tokens is None:
            max_tokens = self.max_tokens
 
+        # Wrap entire operation in tracing span
+        with self.tracer.start_span('llm.generate') as span:
+            attributes = {
+                'llm.provider': 'anthropic',
+                'model.size': model_size.value,
+                'max_tokens': max_tokens,
+            }
+            if prompt_name:
+                attributes['prompt.name'] = prompt_name
+            span.add_attributes(attributes)
+
+            retry_count = 0
+            max_retries = 2
+            last_error: Exception | None = None
+
+            while retry_count <= max_retries:
+                try:
+                    response = await self._generate_response(
+                        messages, response_model, max_tokens, model_size
+                    )
 
+                    # If we have a response_model, attempt to validate the response
+                    if response_model is not None:
+                        # Validate the response against the response_model
+                        model_instance = response_model(**response)
+                        return model_instance.model_dump()
+
+                    # If no validation needed, return the response
+                    return response
+
+                except (RateLimitError, RefusalError):
+                    # These errors should not trigger retries
+                    span.set_status('error', str(last_error))
+                    raise
+                except Exception as e:
+                    last_error = e
+
+                    if retry_count >= max_retries:
+                        if isinstance(e, ValidationError):
+                            logger.error(
+                                f'Validation error after {retry_count}/{max_retries} attempts: {e}'
+                            )
+                        else:
+                            logger.error(f'Max retries ({max_retries}) exceeded. Last error: {e}')
+                        span.set_status('error', str(e))
+                        span.record_exception(e)
+                        raise e
 
-            if retry_count >= max_retries:
                     if isinstance(e, ValidationError):
+                        response_model_cast = typing.cast(type[BaseModel], response_model)
+                        error_context = f'The previous response was invalid. Please provide a valid {response_model_cast.__name__} object. Error: {e}'
                     else:
+                        error_context = (
+                            f'The previous response attempt was invalid. '
+                            f'Error type: {e.__class__.__name__}. '
+                            f'Error details: {str(e)}. '
+                            f'Please try again with a valid response.'
+                        )
 
-                        f'The previous response attempt was invalid. '
-                        f'Error type: {e.__class__.__name__}. '
-                        f'Error details: {str(e)}. '
-                        f'Please try again with a valid response.'
+                    # Common retry logic
+                    retry_count += 1
+                    messages.append(Message(role='user', content=error_context))
+                    logger.warning(
+                        f'Retrying after error (attempt {retry_count}/{max_retries}): {e}'
                     )
 
-                logger.warning(f'Retrying after error (attempt {retry_count}/{max_retries}): {e}')
-
-        # If we somehow get here, raise the last error
-        raise last_error or Exception('Max retries exceeded with no specific error')
+            # If we somehow get here, raise the last error
+            span.set_status('error', str(last_error))
+            raise last_error or Exception('Max retries exceeded with no specific error')
```
|
|
|
17
17
|
import logging
|
|
18
18
|
from typing import ClassVar
|
|
19
19
|
|
|
20
|
-
from openai import AsyncAzureOpenAI
|
|
20
|
+
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
|
21
21
|
from openai.types.chat import ChatCompletionMessageParam
|
|
22
22
|
from pydantic import BaseModel
|
|
23
23
|
|
|
@@ -28,18 +28,29 @@ logger = logging.getLogger(__name__)
|
|
|
28
28
|
|
|
29
29
|
|
|
30
30
|
class AzureOpenAILLMClient(BaseOpenAIClient):
|
|
31
|
-
"""Wrapper class for
|
|
31
|
+
"""Wrapper class for Azure OpenAI that implements the LLMClient interface.
|
|
32
|
+
|
|
33
|
+
Supports both AsyncAzureOpenAI and AsyncOpenAI (with Azure v1 API endpoint).
|
|
34
|
+
"""
|
|
32
35
|
|
|
33
36
|
# Class-level constants
|
|
34
37
|
MAX_RETRIES: ClassVar[int] = 2
|
|
35
38
|
|
|
36
39
|
def __init__(
|
|
37
40
|
self,
|
|
38
|
-
azure_client: AsyncAzureOpenAI,
|
|
41
|
+
azure_client: AsyncAzureOpenAI | AsyncOpenAI,
|
|
39
42
|
config: LLMConfig | None = None,
|
|
40
43
|
max_tokens: int = DEFAULT_MAX_TOKENS,
|
|
44
|
+
reasoning: str | None = None,
|
|
45
|
+
verbosity: str | None = None,
|
|
41
46
|
):
|
|
42
|
-
super().__init__(
|
|
47
|
+
super().__init__(
|
|
48
|
+
config,
|
|
49
|
+
cache=False,
|
|
50
|
+
max_tokens=max_tokens,
|
|
51
|
+
reasoning=reasoning,
|
|
52
|
+
verbosity=verbosity,
|
|
53
|
+
)
|
|
43
54
|
self.client = azure_client
|
|
44
55
|
|
|
45
56
|
async def _create_structured_completion(
|
|
@@ -49,15 +60,29 @@ class AzureOpenAILLMClient(BaseOpenAIClient):
|
|
|
49
60
|
temperature: float | None,
|
|
50
61
|
max_tokens: int,
|
|
51
62
|
response_model: type[BaseModel],
|
|
63
|
+
reasoning: str | None,
|
|
64
|
+
verbosity: str | None,
|
|
52
65
|
):
|
|
53
|
-
"""Create a structured completion using Azure OpenAI's
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
max_tokens
|
|
59
|
-
|
|
60
|
-
|
|
66
|
+
"""Create a structured completion using Azure OpenAI's responses.parse API."""
|
|
67
|
+
supports_reasoning = self._supports_reasoning_features(model)
|
|
68
|
+
request_kwargs = {
|
|
69
|
+
'model': model,
|
|
70
|
+
'input': messages,
|
|
71
|
+
'max_output_tokens': max_tokens,
|
|
72
|
+
'text_format': response_model, # type: ignore
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
temperature_value = temperature if not supports_reasoning else None
|
|
76
|
+
if temperature_value is not None:
|
|
77
|
+
request_kwargs['temperature'] = temperature_value
|
|
78
|
+
|
|
79
|
+
if supports_reasoning and reasoning:
|
|
80
|
+
request_kwargs['reasoning'] = {'effort': reasoning} # type: ignore
|
|
81
|
+
|
|
82
|
+
if supports_reasoning and verbosity:
|
|
83
|
+
request_kwargs['text'] = {'verbosity': verbosity} # type: ignore
|
|
84
|
+
|
|
85
|
+
return await self.client.responses.parse(**request_kwargs)
|
|
61
86
|
|
|
62
87
|
async def _create_completion(
|
|
63
88
|
self,
|
|
@@ -68,10 +93,23 @@ class AzureOpenAILLMClient(BaseOpenAIClient):
|
|
|
68
93
|
response_model: type[BaseModel] | None = None,
|
|
69
94
|
):
|
|
70
95
|
"""Create a regular completion with JSON format using Azure OpenAI."""
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
96
|
+
supports_reasoning = self._supports_reasoning_features(model)
|
|
97
|
+
|
|
98
|
+
request_kwargs = {
|
|
99
|
+
'model': model,
|
|
100
|
+
'messages': messages,
|
|
101
|
+
'max_tokens': max_tokens,
|
|
102
|
+
'response_format': {'type': 'json_object'},
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
temperature_value = temperature if not supports_reasoning else None
|
|
106
|
+
if temperature_value is not None:
|
|
107
|
+
request_kwargs['temperature'] = temperature_value
|
|
108
|
+
|
|
109
|
+
return await self.client.chat.completions.create(**request_kwargs)
|
|
110
|
+
|
|
111
|
+
@staticmethod
|
|
112
|
+
def _supports_reasoning_features(model: str) -> bool:
|
|
113
|
+
"""Return True when the Azure model supports reasoning/verbosity options."""
|
|
114
|
+
reasoning_prefixes = ('o1', 'o3', 'gpt-5')
|
|
115
|
+
return model.startswith(reasoning_prefixes)
|
|
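A minimal sketch of the prefix-based capability gate added here, showing how request parameters diverge for reasoning-capable deployments. The prefixes come from the diff; the surrounding demo function and its parameter names are illustrative assumptions.

```python
# Sketch: prefix check as in _supports_reasoning_features, plus a toy kwargs
# builder to show which options apply to which model family.
REASONING_PREFIXES = ('o1', 'o3', 'gpt-5')


def supports_reasoning_features(model: str) -> bool:
    return model.startswith(REASONING_PREFIXES)


def build_demo_kwargs(model: str, temperature: float | None, reasoning: str | None) -> dict:
    kwargs: dict = {'model': model}
    if not supports_reasoning_features(model):
        # Non-reasoning deployments keep the classic temperature control.
        if temperature is not None:
            kwargs['temperature'] = temperature
    elif reasoning:
        # Reasoning deployments drop temperature and take an effort hint instead.
        kwargs['reasoning'] = {'effort': reasoning}
    return kwargs


print(build_demo_kwargs('gpt-4o', 0.0, None))       # {'model': 'gpt-4o', 'temperature': 0.0}
print(build_demo_kwargs('gpt-5-mini', 0.0, 'low'))  # {'model': 'gpt-5-mini', 'reasoning': {'effort': 'low'}}
```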
graphiti_core/llm_client/client.py
CHANGED

```diff
@@ -26,15 +26,34 @@ from pydantic import BaseModel
 from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential
 
 from ..prompts.models import Message
+from ..tracer import NoOpTracer, Tracer
 from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
 from .errors import RateLimitError
 
 DEFAULT_TEMPERATURE = 0
 DEFAULT_CACHE_DIR = './llm_cache'
 
+
+def get_extraction_language_instruction(group_id: str | None = None) -> str:
+    """Returns instruction for language extraction behavior.
+
+    Override this function to customize language extraction:
+    - Return empty string to disable multilingual instructions
+    - Return custom instructions for specific language requirements
+    - Use group_id to provide different instructions per group/partition
+
+    Args:
+        group_id: Optional partition identifier for the graph
+
+    Returns:
+        str: Language instruction to append to system messages
+    """
+    return (
+        '\n\nAny extracted information should be returned in the same language as it was written in. '
+        'Only output non-English text when the user has written full sentences or phrases in that non-English language. '
+        'Otherwise, output English.'
+    )
+
 
 logger = logging.getLogger(__name__)
 
@@ -60,11 +79,16 @@ class LLMClient(ABC):
         self.max_tokens = config.max_tokens
         self.cache_enabled = cache
         self.cache_dir = None
+        self.tracer: Tracer = NoOpTracer()
 
         # Only create the cache directory if caching is enabled
         if self.cache_enabled:
             self.cache_dir = Cache(DEFAULT_CACHE_DIR)
 
+    def set_tracer(self, tracer: Tracer) -> None:
+        """Set the tracer for this LLM client."""
+        self.tracer = tracer
+
     def _clean_input(self, input: str) -> str:
         """Clean input string of invalid unicode and control characters.
 
@@ -132,6 +156,8 @@ class LLMClient(ABC):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
+        prompt_name: str | None = None,
     ) -> dict[str, typing.Any]:
         if max_tokens is None:
             max_tokens = self.max_tokens
@@ -145,28 +171,64 @@ class LLMClient(ABC):
         )
 
         # Add multilingual extraction instructions
-        messages[0].content +=
-        if self.cache_enabled and self.cache_dir is not None:
-            cache_key = self._get_cache_key(messages)
-            cached_response = self.cache_dir.get(cache_key)
-            if cached_response is not None:
-                logger.debug(f'Cache hit for {cache_key}')
-                return cached_response
+        messages[0].content += get_extraction_language_instruction(group_id)
 
         for message in messages:
             message.content = self._clean_input(message.content)
 
+        # Wrap entire operation in tracing span
+        with self.tracer.start_span('llm.generate') as span:
+            attributes = {
+                'llm.provider': self._get_provider_type(),
+                'model.size': model_size.value,
+                'max_tokens': max_tokens,
+                'cache.enabled': self.cache_enabled,
+            }
+            if prompt_name:
+                attributes['prompt.name'] = prompt_name
+            span.add_attributes(attributes)
+
+            # Check cache first
+            if self.cache_enabled and self.cache_dir is not None:
+                cache_key = self._get_cache_key(messages)
+                cached_response = self.cache_dir.get(cache_key)
+                if cached_response is not None:
+                    logger.debug(f'Cache hit for {cache_key}')
+                    span.add_attributes({'cache.hit': True})
+                    return cached_response
+
+            span.add_attributes({'cache.hit': False})
+
+            # Execute LLM call
+            try:
+                response = await self._generate_response_with_retry(
+                    messages, response_model, max_tokens, model_size
+                )
+            except Exception as e:
+                span.set_status('error', str(e))
+                span.record_exception(e)
+                raise
+
+            # Cache response if enabled
+            if self.cache_enabled and self.cache_dir is not None:
+                cache_key = self._get_cache_key(messages)
+                self.cache_dir.set(cache_key, response)
+
+            return response
+
+    def _get_provider_type(self) -> str:
+        """Get provider type from class name."""
+        class_name = self.__class__.__name__.lower()
+        if 'openai' in class_name:
+            return 'openai'
+        elif 'anthropic' in class_name:
+            return 'anthropic'
+        elif 'gemini' in class_name:
+            return 'gemini'
+        elif 'groq' in class_name:
+            return 'groq'
+        else:
+            return 'unknown'
 
     def _get_failed_generation_log(self, messages: list[Message], output: str | None) -> str:
         """
```
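The tracer interface itself is not shown in this diff, only how spans are consumed. Assuming the usage above (`start_span` as a context manager plus `add_attributes`, `set_status`, and `record_exception`), a hypothetical logging-backed tracer might look like the sketch below; the real contract lives in the new `graphiti_core/tracer.py`.

```python
# Hypothetical logging-backed tracer, inferred from how spans are used in the
# diff above. Not the actual graphiti_core.tracer implementation.
import logging
from contextlib import contextmanager

logger = logging.getLogger(__name__)


class LoggingSpan:
    def __init__(self, name: str):
        self.name = name
        self.attributes: dict = {}

    def add_attributes(self, attributes: dict) -> None:
        self.attributes.update(attributes)

    def set_status(self, status: str, description: str | None = None) -> None:
        logger.info('span %s status=%s %s', self.name, status, description or '')

    def record_exception(self, exc: Exception) -> None:
        logger.error('span %s exception', self.name, exc_info=exc)


class LoggingTracer:
    @contextmanager
    def start_span(self, name: str):
        # Yield a span object; log collected attributes when the span closes.
        span = LoggingSpan(name)
        try:
            yield span
        finally:
            logger.info('span %s attributes=%s', name, span.attributes)


# With LLMClient.set_tracer() from this diff, such a tracer could be attached:
# llm_client.set_tracer(LoggingTracer())
```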