graphiti-core 0.17.4__py3-none-any.whl → 0.25.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. graphiti_core/cross_encoder/gemini_reranker_client.py +1 -1
  2. graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
  3. graphiti_core/decorators.py +110 -0
  4. graphiti_core/driver/driver.py +62 -2
  5. graphiti_core/driver/falkordb_driver.py +215 -23
  6. graphiti_core/driver/graph_operations/graph_operations.py +191 -0
  7. graphiti_core/driver/kuzu_driver.py +182 -0
  8. graphiti_core/driver/neo4j_driver.py +70 -8
  9. graphiti_core/driver/neptune_driver.py +305 -0
  10. graphiti_core/driver/search_interface/search_interface.py +89 -0
  11. graphiti_core/edges.py +264 -132
  12. graphiti_core/embedder/azure_openai.py +10 -3
  13. graphiti_core/embedder/client.py +2 -1
  14. graphiti_core/graph_queries.py +114 -101
  15. graphiti_core/graphiti.py +635 -260
  16. graphiti_core/graphiti_types.py +2 -0
  17. graphiti_core/helpers.py +37 -15
  18. graphiti_core/llm_client/anthropic_client.py +142 -52
  19. graphiti_core/llm_client/azure_openai_client.py +57 -19
  20. graphiti_core/llm_client/client.py +83 -21
  21. graphiti_core/llm_client/config.py +1 -1
  22. graphiti_core/llm_client/gemini_client.py +75 -57
  23. graphiti_core/llm_client/openai_base_client.py +92 -48
  24. graphiti_core/llm_client/openai_client.py +39 -9
  25. graphiti_core/llm_client/openai_generic_client.py +91 -56
  26. graphiti_core/models/edges/edge_db_queries.py +259 -35
  27. graphiti_core/models/nodes/node_db_queries.py +311 -32
  28. graphiti_core/nodes.py +388 -164
  29. graphiti_core/prompts/dedupe_edges.py +42 -31
  30. graphiti_core/prompts/dedupe_nodes.py +56 -39
  31. graphiti_core/prompts/eval.py +4 -4
  32. graphiti_core/prompts/extract_edges.py +24 -15
  33. graphiti_core/prompts/extract_nodes.py +76 -35
  34. graphiti_core/prompts/prompt_helpers.py +39 -0
  35. graphiti_core/prompts/snippets.py +29 -0
  36. graphiti_core/prompts/summarize_nodes.py +23 -25
  37. graphiti_core/search/search.py +154 -74
  38. graphiti_core/search/search_config.py +39 -4
  39. graphiti_core/search/search_filters.py +110 -31
  40. graphiti_core/search/search_helpers.py +5 -6
  41. graphiti_core/search/search_utils.py +1360 -473
  42. graphiti_core/tracer.py +193 -0
  43. graphiti_core/utils/bulk_utils.py +216 -90
  44. graphiti_core/utils/content_chunking.py +702 -0
  45. graphiti_core/utils/datetime_utils.py +13 -0
  46. graphiti_core/utils/maintenance/community_operations.py +62 -38
  47. graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
  48. graphiti_core/utils/maintenance/edge_operations.py +306 -156
  49. graphiti_core/utils/maintenance/graph_data_operations.py +44 -74
  50. graphiti_core/utils/maintenance/node_operations.py +466 -206
  51. graphiti_core/utils/maintenance/temporal_operations.py +11 -3
  52. graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
  53. graphiti_core/utils/text_utils.py +53 -0
  54. {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/METADATA +221 -87
  55. graphiti_core-0.25.3.dist-info/RECORD +87 -0
  56. {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/WHEEL +1 -1
  57. graphiti_core-0.17.4.dist-info/RECORD +0 -77
  58. /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
  59. {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/licenses/LICENSE +0 -0
graphiti_core/graphiti_types.py CHANGED
@@ -20,6 +20,7 @@ from graphiti_core.cross_encoder import CrossEncoderClient
  from graphiti_core.driver.driver import GraphDriver
  from graphiti_core.embedder import EmbedderClient
  from graphiti_core.llm_client import LLMClient
+ from graphiti_core.tracer import Tracer
 
 
  class GraphitiClients(BaseModel):
@@ -27,5 +28,6 @@ class GraphitiClients(BaseModel):
  llm_client: LLMClient
  embedder: EmbedderClient
  cross_encoder: CrossEncoderClient
+ tracer: Tracer
 
  model_config = ConfigDict(arbitrary_types_allowed=True)
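Reviewer note: GraphitiClients gains a required tracer field. A minimal sketch of constructing it directly, assuming a driver field alongside the ones shown in the hunk and that the other clients are already built (the actual wiring happens inside graphiti.py and may differ):

# Sketch only; NoOpTracer is the library's no-op default per the client.py changes below.
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.tracer import NoOpTracer

clients = GraphitiClients(
    driver=driver,                # existing GraphDriver (assumed field)
    llm_client=llm_client,        # existing LLMClient
    embedder=embedder,            # existing EmbedderClient
    cross_encoder=cross_encoder,  # existing CrossEncoderClient
    tracer=NoOpTracer(),          # new field in 0.25.x
)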
graphiti_core/helpers.py CHANGED
@@ -26,30 +26,52 @@ from dotenv import load_dotenv
  from neo4j import time as neo4j_time
  from numpy._typing import NDArray
  from pydantic import BaseModel
- from typing_extensions import LiteralString
 
+ from graphiti_core.driver.driver import GraphProvider
  from graphiti_core.errors import GroupIdValidationError
 
  load_dotenv()
 
  USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
  SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
- MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0))
  DEFAULT_PAGE_LIMIT = 20
 
- RUNTIME_QUERY: LiteralString = (
- 'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
- )
+ # Content chunking configuration for entity extraction
+ # Density-based chunking: only chunk high-density content (many entities per token)
+ # This targets the failure case (large entity-dense inputs) while preserving
+ # context for prose/narrative content
+ CHUNK_TOKEN_SIZE = int(os.getenv('CHUNK_TOKEN_SIZE', 3000))
+ CHUNK_OVERLAP_TOKENS = int(os.getenv('CHUNK_OVERLAP_TOKENS', 200))
+ # Minimum tokens before considering chunking - short content processes fine regardless of density
+ CHUNK_MIN_TOKENS = int(os.getenv('CHUNK_MIN_TOKENS', 1000))
+ # Entity density threshold: chunk if estimated density > this value
+ # For JSON: elements per 1000 tokens > threshold * 1000 (e.g., 0.15 = 150 elements/1000 tokens)
+ # For Text: capitalized words per 1000 tokens > threshold * 500 (e.g., 0.15 = 75 caps/1000 tokens)
+ # Higher values = more conservative (less chunking), targets P95+ density cases
+ # Examples that trigger chunking at 0.15: AWS cost data (12mo), bulk data imports, entity-dense JSON
+ # Examples that DON'T chunk at 0.15: meeting transcripts, news articles, documentation
+ CHUNK_DENSITY_THRESHOLD = float(os.getenv('CHUNK_DENSITY_THRESHOLD', 0.15))
 
 
- def parse_db_date(neo_date: neo4j_time.DateTime | str | None) -> datetime | None:
- return (
- neo_date.to_native()
- if isinstance(neo_date, neo4j_time.DateTime)
- else datetime.fromisoformat(neo_date)
- if neo_date
- else None
- )
+ def parse_db_date(input_date: neo4j_time.DateTime | str | None) -> datetime | None:
+ if isinstance(input_date, neo4j_time.DateTime):
+ return input_date.to_native()
+
+ if isinstance(input_date, str):
+ return datetime.fromisoformat(input_date)
+
+ return input_date
+
+
+ def get_default_group_id(provider: GraphProvider) -> str:
+ """
+ This function differentiates the default group id based on the database type.
+ For most databases, the default group id is an empty string, while there are database types that require a specific default group id.
+ """
+ if provider == GraphProvider.FALKORDB:
+ return '\\_'
+ else:
+ return ''
 
 
  def lucene_sanitize(query: str) -> str:
@@ -109,7 +131,7 @@ async def semaphore_gather(
  return await asyncio.gather(*(_wrap_coroutine(coroutine) for coroutine in coroutines))
 
 
- def validate_group_id(group_id: str) -> bool:
+ def validate_group_id(group_id: str | None) -> bool:
  """
  Validate that a group_id contains only ASCII alphanumeric characters, dashes, and underscores.
 
@@ -136,7 +158,7 @@ def validate_group_id(group_id: str) -> bool:
 
 
  def validate_excluded_entity_types(
- excluded_entity_types: list[str] | None, entity_types: dict[str, BaseModel] | None = None
+ excluded_entity_types: list[str] | None, entity_types: dict[str, type[BaseModel]] | None = None
  ) -> bool:
  """
  Validate that excluded entity types are valid type names.
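The chunking comments above amount to a decision rule. A rough sketch of the JSON branch with hypothetical helper names (the shipped logic lives in graphiti_core/utils/content_chunking.py and may differ in detail):

# Illustration of the documented thresholds, not the shipped implementation.
CHUNK_MIN_TOKENS = 1000
CHUNK_DENSITY_THRESHOLD = 0.15

def should_chunk_json(num_elements: int, num_tokens: int) -> bool:
    # Short content is never chunked, regardless of density.
    if num_tokens < CHUNK_MIN_TOKENS:
        return False
    # JSON rule from the comments: elements per 1000 tokens > threshold * 1000.
    elements_per_1000_tokens = num_elements / num_tokens * 1000
    return elements_per_1000_tokens > CHUNK_DENSITY_THRESHOLD * 1000

print(should_chunk_json(800, 4000))  # 200 elements/1000 tokens > 150 -> True (entity-dense JSON)
print(should_chunk_json(100, 4000))  # 25 elements/1000 tokens -> False (prose-like content)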
graphiti_core/llm_client/anthropic_client.py CHANGED
@@ -47,6 +47,9 @@ else:
  logger = logging.getLogger(__name__)
 
  AnthropicModel = Literal[
+ 'claude-sonnet-4-5-latest',
+ 'claude-sonnet-4-5-20250929',
+ 'claude-haiku-4-5-latest',
  'claude-3-7-sonnet-latest',
  'claude-3-7-sonnet-20250219',
  'claude-3-5-haiku-latest',
@@ -62,7 +65,39 @@ AnthropicModel = Literal[
  'claude-2.0',
  ]
 
- DEFAULT_MODEL: AnthropicModel = 'claude-3-7-sonnet-latest'
+ DEFAULT_MODEL: AnthropicModel = 'claude-haiku-4-5-latest'
+
+ # Maximum output tokens for different Anthropic models
+ # Based on official Anthropic documentation (as of 2025)
+ # Note: These represent standard limits without beta headers.
+ # Some models support higher limits with additional configuration (e.g., Claude 3.7 supports
+ # 128K with 'anthropic-beta: output-128k-2025-02-19' header, but this is not currently implemented).
+ ANTHROPIC_MODEL_MAX_TOKENS = {
+ # Claude 4.5 models - 64K tokens
+ 'claude-sonnet-4-5-latest': 65536,
+ 'claude-sonnet-4-5-20250929': 65536,
+ 'claude-haiku-4-5-latest': 65536,
+ # Claude 3.7 models - standard 64K tokens
+ 'claude-3-7-sonnet-latest': 65536,
+ 'claude-3-7-sonnet-20250219': 65536,
+ # Claude 3.5 models
+ 'claude-3-5-haiku-latest': 8192,
+ 'claude-3-5-haiku-20241022': 8192,
+ 'claude-3-5-sonnet-latest': 8192,
+ 'claude-3-5-sonnet-20241022': 8192,
+ 'claude-3-5-sonnet-20240620': 8192,
+ # Claude 3 models - 4K tokens
+ 'claude-3-opus-latest': 4096,
+ 'claude-3-opus-20240229': 4096,
+ 'claude-3-sonnet-20240229': 4096,
+ 'claude-3-haiku-20240307': 4096,
+ # Claude 2 models - 4K tokens
+ 'claude-2.1': 4096,
+ 'claude-2.0': 4096,
+ }
+
+ # Default max tokens for models not in the mapping
+ DEFAULT_ANTHROPIC_MAX_TOKENS = 8192
 
 
  class AnthropicClient(LLMClient):
@@ -177,6 +212,45 @@ class AnthropicClient(LLMClient):
  tool_choice_cast = typing.cast(ToolChoiceParam, tool_choice)
  return tool_list_cast, tool_choice_cast
 
+ def _get_max_tokens_for_model(self, model: str) -> int:
+ """Get the maximum output tokens for a specific Anthropic model.
+
+ Args:
+ model: The model name to look up
+
+ Returns:
+ int: The maximum output tokens for the model
+ """
+ return ANTHROPIC_MODEL_MAX_TOKENS.get(model, DEFAULT_ANTHROPIC_MAX_TOKENS)
+
+ def _resolve_max_tokens(self, requested_max_tokens: int | None, model: str) -> int:
+ """
+ Resolve the maximum output tokens to use based on precedence rules.
+
+ Precedence order (highest to lowest):
+ 1. Explicit max_tokens parameter passed to generate_response()
+ 2. Instance max_tokens set during client initialization
+ 3. Model-specific maximum tokens from ANTHROPIC_MODEL_MAX_TOKENS mapping
+ 4. DEFAULT_ANTHROPIC_MAX_TOKENS as final fallback
+
+ Args:
+ requested_max_tokens: The max_tokens parameter passed to generate_response()
+ model: The model name to look up model-specific limits
+
+ Returns:
+ int: The resolved maximum tokens to use
+ """
+ # 1. Use explicit parameter if provided
+ if requested_max_tokens is not None:
+ return requested_max_tokens
+
+ # 2. Use instance max_tokens if set during initialization
+ if self.max_tokens is not None:
+ return self.max_tokens
+
+ # 3. Use model-specific maximum or return DEFAULT_ANTHROPIC_MAX_TOKENS
+ return self._get_max_tokens_for_model(model)
+
  async def _generate_response(
  self,
  messages: list[Message],
@@ -204,12 +278,9 @@ class AnthropicClient(LLMClient):
  user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]]
  user_messages_cast = typing.cast(list[MessageParam], user_messages)
 
- # TODO: Replace hacky min finding solution after fixing hardcoded EXTRACT_EDGES_MAX_TOKENS = 16384 in
- # edge_operations.py. Throws errors with cheaper models that lower max_tokens.
- max_creation_tokens: int = min(
- max_tokens if max_tokens is not None else self.config.max_tokens,
- DEFAULT_MAX_TOKENS,
- )
+ # Resolve max_tokens dynamically based on the model's capabilities
+ # This allows different models to use their full output capacity
+ max_creation_tokens: int = self._resolve_max_tokens(max_tokens, self.model)
 
  try:
  # Create the appropriate tool based on whether response_model is provided
@@ -265,6 +336,8 @@ class AnthropicClient(LLMClient):
  response_model: type[BaseModel] | None = None,
  max_tokens: int | None = None,
  model_size: ModelSize = ModelSize.medium,
+ group_id: str | None = None,
+ prompt_name: str | None = None,
  ) -> dict[str, typing.Any]:
  """
  Generate a response from the LLM.
@@ -285,55 +358,72 @@ class AnthropicClient(LLMClient):
  if max_tokens is None:
  max_tokens = self.max_tokens
 
- retry_count = 0
- max_retries = 2
- last_error: Exception | None = None
-
- while retry_count <= max_retries:
- try:
- response = await self._generate_response(
- messages, response_model, max_tokens, model_size
- )
-
- # If we have a response_model, attempt to validate the response
- if response_model is not None:
- # Validate the response against the response_model
- model_instance = response_model(**response)
- return model_instance.model_dump()
-
- # If no validation needed, return the response
- return response
+ # Wrap entire operation in tracing span
+ with self.tracer.start_span('llm.generate') as span:
+ attributes = {
+ 'llm.provider': 'anthropic',
+ 'model.size': model_size.value,
+ 'max_tokens': max_tokens,
+ }
+ if prompt_name:
+ attributes['prompt.name'] = prompt_name
+ span.add_attributes(attributes)
+
+ retry_count = 0
+ max_retries = 2
+ last_error: Exception | None = None
+
+ while retry_count <= max_retries:
+ try:
+ response = await self._generate_response(
+ messages, response_model, max_tokens, model_size
+ )
 
- except (RateLimitError, RefusalError):
- # These errors should not trigger retries
- raise
- except Exception as e:
- last_error = e
+ # If we have a response_model, attempt to validate the response
+ if response_model is not None:
+ # Validate the response against the response_model
+ model_instance = response_model(**response)
+ return model_instance.model_dump()
+
+ # If no validation needed, return the response
+ return response
+
+ except (RateLimitError, RefusalError):
+ # These errors should not trigger retries
+ span.set_status('error', str(last_error))
+ raise
+ except Exception as e:
+ last_error = e
+
+ if retry_count >= max_retries:
+ if isinstance(e, ValidationError):
+ logger.error(
+ f'Validation error after {retry_count}/{max_retries} attempts: {e}'
+ )
+ else:
+ logger.error(f'Max retries ({max_retries}) exceeded. Last error: {e}')
+ span.set_status('error', str(e))
+ span.record_exception(e)
+ raise e
 
- if retry_count >= max_retries:
  if isinstance(e, ValidationError):
- logger.error(
- f'Validation error after {retry_count}/{max_retries} attempts: {e}'
- )
+ response_model_cast = typing.cast(type[BaseModel], response_model)
+ error_context = f'The previous response was invalid. Please provide a valid {response_model_cast.__name__} object. Error: {e}'
  else:
- logger.error(f'Max retries ({max_retries}) exceeded. Last error: {e}')
- raise e
+ error_context = (
+ f'The previous response attempt was invalid. '
+ f'Error type: {e.__class__.__name__}. '
+ f'Error details: {str(e)}. '
+ f'Please try again with a valid response.'
+ )
 
- if isinstance(e, ValidationError):
- response_model_cast = typing.cast(type[BaseModel], response_model)
- error_context = f'The previous response was invalid. Please provide a valid {response_model_cast.__name__} object. Error: {e}'
- else:
- error_context = (
- f'The previous response attempt was invalid. '
- f'Error type: {e.__class__.__name__}. '
- f'Error details: {str(e)}. '
- f'Please try again with a valid response.'
+ # Common retry logic
+ retry_count += 1
+ messages.append(Message(role='user', content=error_context))
+ logger.warning(
+ f'Retrying after error (attempt {retry_count}/{max_retries}): {e}'
  )
 
- # Common retry logic
- retry_count += 1
- messages.append(Message(role='user', content=error_context))
- logger.warning(f'Retrying after error (attempt {retry_count}/{max_retries}): {e}')
-
- # If we somehow get here, raise the last error
- raise last_error or Exception('Max retries exceeded with no specific error')
+ # If we somehow get here, raise the last error
+ span.set_status('error', str(last_error))
+ raise last_error or Exception('Max retries exceeded with no specific error')
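The new _resolve_max_tokens precedence is easy to check in isolation. A standalone sketch mirroring the documented order (explicit argument, then instance setting, then the per-model table, then the default); illustrative rather than the client's own code path:

# Names copied from the diff; table trimmed for brevity.
ANTHROPIC_MODEL_MAX_TOKENS = {'claude-haiku-4-5-latest': 65536, 'claude-3-5-haiku-latest': 8192}
DEFAULT_ANTHROPIC_MAX_TOKENS = 8192

def resolve_max_tokens(requested: int | None, instance_max: int | None, model: str) -> int:
    if requested is not None:      # 1. explicit generate_response() argument
        return requested
    if instance_max is not None:   # 2. value set at client construction
        return instance_max
    # 3./4. model-specific limit, falling back to the default
    return ANTHROPIC_MODEL_MAX_TOKENS.get(model, DEFAULT_ANTHROPIC_MAX_TOKENS)

print(resolve_max_tokens(None, None, 'claude-haiku-4-5-latest'))  # 65536
print(resolve_max_tokens(4096, None, 'claude-haiku-4-5-latest'))  # 4096
print(resolve_max_tokens(None, None, 'some-unlisted-model'))      # 8192 fallback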
graphiti_core/llm_client/azure_openai_client.py CHANGED
@@ -17,7 +17,7 @@ limitations under the License.
  import logging
  from typing import ClassVar
 
- from openai import AsyncAzureOpenAI
+ from openai import AsyncAzureOpenAI, AsyncOpenAI
  from openai.types.chat import ChatCompletionMessageParam
  from pydantic import BaseModel
 
@@ -28,18 +28,29 @@ logger = logging.getLogger(__name__)
 
 
  class AzureOpenAILLMClient(BaseOpenAIClient):
- """Wrapper class for AsyncAzureOpenAI that implements the LLMClient interface."""
+ """Wrapper class for Azure OpenAI that implements the LLMClient interface.
+
+ Supports both AsyncAzureOpenAI and AsyncOpenAI (with Azure v1 API endpoint).
+ """
 
  # Class-level constants
  MAX_RETRIES: ClassVar[int] = 2
 
  def __init__(
  self,
- azure_client: AsyncAzureOpenAI,
+ azure_client: AsyncAzureOpenAI | AsyncOpenAI,
  config: LLMConfig | None = None,
  max_tokens: int = DEFAULT_MAX_TOKENS,
+ reasoning: str | None = None,
+ verbosity: str | None = None,
  ):
- super().__init__(config, cache=False, max_tokens=max_tokens)
+ super().__init__(
+ config,
+ cache=False,
+ max_tokens=max_tokens,
+ reasoning=reasoning,
+ verbosity=verbosity,
+ )
  self.client = azure_client
 
  async def _create_structured_completion(
@@ -49,15 +60,29 @@ class AzureOpenAILLMClient(BaseOpenAIClient):
  temperature: float | None,
  max_tokens: int,
  response_model: type[BaseModel],
+ reasoning: str | None,
+ verbosity: str | None,
  ):
- """Create a structured completion using Azure OpenAI's beta parse API."""
- return await self.client.beta.chat.completions.parse(
- model=model,
- messages=messages,
- temperature=temperature,
- max_tokens=max_tokens,
- response_format=response_model, # type: ignore
- )
+ """Create a structured completion using Azure OpenAI's responses.parse API."""
+ supports_reasoning = self._supports_reasoning_features(model)
+ request_kwargs = {
+ 'model': model,
+ 'input': messages,
+ 'max_output_tokens': max_tokens,
+ 'text_format': response_model, # type: ignore
+ }
+
+ temperature_value = temperature if not supports_reasoning else None
+ if temperature_value is not None:
+ request_kwargs['temperature'] = temperature_value
+
+ if supports_reasoning and reasoning:
+ request_kwargs['reasoning'] = {'effort': reasoning} # type: ignore
+
+ if supports_reasoning and verbosity:
+ request_kwargs['text'] = {'verbosity': verbosity} # type: ignore
+
+ return await self.client.responses.parse(**request_kwargs)
 
  async def _create_completion(
  self,
@@ -68,10 +93,23 @@ class AzureOpenAILLMClient(BaseOpenAIClient):
  response_model: type[BaseModel] | None = None,
  ):
  """Create a regular completion with JSON format using Azure OpenAI."""
- return await self.client.chat.completions.create(
- model=model,
- messages=messages,
- temperature=temperature,
- max_tokens=max_tokens,
- response_format={'type': 'json_object'},
- )
+ supports_reasoning = self._supports_reasoning_features(model)
+
+ request_kwargs = {
+ 'model': model,
+ 'messages': messages,
+ 'max_tokens': max_tokens,
+ 'response_format': {'type': 'json_object'},
+ }
+
+ temperature_value = temperature if not supports_reasoning else None
+ if temperature_value is not None:
+ request_kwargs['temperature'] = temperature_value
+
+ return await self.client.chat.completions.create(**request_kwargs)
+
+ @staticmethod
+ def _supports_reasoning_features(model: str) -> bool:
+ """Return True when the Azure model supports reasoning/verbosity options."""
+ reasoning_prefixes = ('o1', 'o3', 'gpt-5')
+ return model.startswith(reasoning_prefixes)
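The _supports_reasoning_features gate decides whether temperature is dropped and reasoning/verbosity options are attached. A condensed, illustrative sketch of its effect on request kwargs (the real methods build these dicts inside the client):

def supports_reasoning_features(model: str) -> bool:
    return model.startswith(('o1', 'o3', 'gpt-5'))

def build_kwargs(model: str, temperature: float | None, reasoning: str | None) -> dict:
    kwargs: dict = {'model': model}
    if not supports_reasoning_features(model):
        if temperature is not None:
            kwargs['temperature'] = temperature      # classic models keep temperature
    elif reasoning:
        kwargs['reasoning'] = {'effort': reasoning}  # reasoning models take an effort hint instead
    return kwargs

print(build_kwargs('gpt-4o', 0.0, 'minimal'))      # {'model': 'gpt-4o', 'temperature': 0.0}
print(build_kwargs('gpt-5-mini', 0.0, 'minimal'))  # {'model': 'gpt-5-mini', 'reasoning': {'effort': 'minimal'}}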
graphiti_core/llm_client/client.py CHANGED
@@ -26,15 +26,34 @@ from pydantic import BaseModel
  from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential
 
  from ..prompts.models import Message
+ from ..tracer import NoOpTracer, Tracer
  from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
  from .errors import RateLimitError
 
  DEFAULT_TEMPERATURE = 0
  DEFAULT_CACHE_DIR = './llm_cache'
 
- MULTILINGUAL_EXTRACTION_RESPONSES = (
- '\n\nAny extracted information should be returned in the same language as it was written in.'
- )
+
+ def get_extraction_language_instruction(group_id: str | None = None) -> str:
+ """Returns instruction for language extraction behavior.
+
+ Override this function to customize language extraction:
+ - Return empty string to disable multilingual instructions
+ - Return custom instructions for specific language requirements
+ - Use group_id to provide different instructions per group/partition
+
+ Args:
+ group_id: Optional partition identifier for the graph
+
+ Returns:
+ str: Language instruction to append to system messages
+ """
+ return (
+ '\n\nAny extracted information should be returned in the same language as it was written in. '
+ 'Only output non-English text when the user has written full sentences or phrases in that non-English language. '
+ 'Otherwise, output English.'
+ )
+
 
  logger = logging.getLogger(__name__)
 
@@ -60,11 +79,16 @@ class LLMClient(ABC):
  self.max_tokens = config.max_tokens
  self.cache_enabled = cache
  self.cache_dir = None
+ self.tracer: Tracer = NoOpTracer()
 
  # Only create the cache directory if caching is enabled
  if self.cache_enabled:
  self.cache_dir = Cache(DEFAULT_CACHE_DIR)
 
+ def set_tracer(self, tracer: Tracer) -> None:
+ """Set the tracer for this LLM client."""
+ self.tracer = tracer
+
  def _clean_input(self, input: str) -> str:
  """Clean input string of invalid unicode and control characters.
 
@@ -132,6 +156,8 @@ class LLMClient(ABC):
  response_model: type[BaseModel] | None = None,
  max_tokens: int | None = None,
  model_size: ModelSize = ModelSize.medium,
+ group_id: str | None = None,
+ prompt_name: str | None = None,
  ) -> dict[str, typing.Any]:
  if max_tokens is None:
  max_tokens = self.max_tokens
@@ -145,28 +171,64 @@ class LLMClient(ABC):
  )
 
  # Add multilingual extraction instructions
- messages[0].content += MULTILINGUAL_EXTRACTION_RESPONSES
-
- if self.cache_enabled and self.cache_dir is not None:
- cache_key = self._get_cache_key(messages)
-
- cached_response = self.cache_dir.get(cache_key)
- if cached_response is not None:
- logger.debug(f'Cache hit for {cache_key}')
- return cached_response
+ messages[0].content += get_extraction_language_instruction(group_id)
 
  for message in messages:
  message.content = self._clean_input(message.content)
 
- response = await self._generate_response_with_retry(
- messages, response_model, max_tokens, model_size
- )
-
- if self.cache_enabled and self.cache_dir is not None:
- cache_key = self._get_cache_key(messages)
- self.cache_dir.set(cache_key, response)
-
- return response
+ # Wrap entire operation in tracing span
+ with self.tracer.start_span('llm.generate') as span:
+ attributes = {
+ 'llm.provider': self._get_provider_type(),
+ 'model.size': model_size.value,
+ 'max_tokens': max_tokens,
+ 'cache.enabled': self.cache_enabled,
+ }
+ if prompt_name:
+ attributes['prompt.name'] = prompt_name
+ span.add_attributes(attributes)
+
+ # Check cache first
+ if self.cache_enabled and self.cache_dir is not None:
+ cache_key = self._get_cache_key(messages)
+ cached_response = self.cache_dir.get(cache_key)
+ if cached_response is not None:
+ logger.debug(f'Cache hit for {cache_key}')
+ span.add_attributes({'cache.hit': True})
+ return cached_response
+
+ span.add_attributes({'cache.hit': False})
+
+ # Execute LLM call
+ try:
+ response = await self._generate_response_with_retry(
+ messages, response_model, max_tokens, model_size
+ )
+ except Exception as e:
+ span.set_status('error', str(e))
+ span.record_exception(e)
+ raise
+
+ # Cache response if enabled
+ if self.cache_enabled and self.cache_dir is not None:
+ cache_key = self._get_cache_key(messages)
+ self.cache_dir.set(cache_key, response)
+
+ return response
+
+ def _get_provider_type(self) -> str:
+ """Get provider type from class name."""
+ class_name = self.__class__.__name__.lower()
+ if 'openai' in class_name:
+ return 'openai'
+ elif 'anthropic' in class_name:
+ return 'anthropic'
+ elif 'gemini' in class_name:
+ return 'gemini'
+ elif 'groq' in class_name:
+ return 'groq'
+ else:
+ return 'unknown'
 
  def _get_failed_generation_log(self, messages: list[Message], output: str | None) -> str:
  """
graphiti_core/llm_client/config.py CHANGED
@@ -17,7 +17,7 @@ limitations under the License.
  import logging
  from enum import Enum
  DEFAULT_MAX_TOKENS = 8192
- DEFAULT_TEMPERATURE = 0
+ DEFAULT_TEMPERATURE = 1
 
 
  class ModelSize(Enum):