graphiti-core 0.12.0rc1__py3-none-any.whl → 0.24.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. graphiti_core/cross_encoder/bge_reranker_client.py +12 -2
  2. graphiti_core/cross_encoder/gemini_reranker_client.py +161 -0
  3. graphiti_core/cross_encoder/openai_reranker_client.py +7 -5
  4. graphiti_core/decorators.py +110 -0
  5. graphiti_core/driver/__init__.py +19 -0
  6. graphiti_core/driver/driver.py +124 -0
  7. graphiti_core/driver/falkordb_driver.py +362 -0
  8. graphiti_core/driver/graph_operations/graph_operations.py +191 -0
  9. graphiti_core/driver/kuzu_driver.py +182 -0
  10. graphiti_core/driver/neo4j_driver.py +117 -0
  11. graphiti_core/driver/neptune_driver.py +305 -0
  12. graphiti_core/driver/search_interface/search_interface.py +89 -0
  13. graphiti_core/edges.py +287 -172
  14. graphiti_core/embedder/azure_openai.py +71 -0
  15. graphiti_core/embedder/client.py +2 -1
  16. graphiti_core/embedder/gemini.py +116 -22
  17. graphiti_core/embedder/voyage.py +13 -2
  18. graphiti_core/errors.py +8 -0
  19. graphiti_core/graph_queries.py +162 -0
  20. graphiti_core/graphiti.py +705 -193
  21. graphiti_core/graphiti_types.py +4 -2
  22. graphiti_core/helpers.py +87 -10
  23. graphiti_core/llm_client/__init__.py +16 -0
  24. graphiti_core/llm_client/anthropic_client.py +159 -56
  25. graphiti_core/llm_client/azure_openai_client.py +115 -0
  26. graphiti_core/llm_client/client.py +98 -21
  27. graphiti_core/llm_client/config.py +1 -1
  28. graphiti_core/llm_client/gemini_client.py +290 -41
  29. graphiti_core/llm_client/groq_client.py +14 -3
  30. graphiti_core/llm_client/openai_base_client.py +261 -0
  31. graphiti_core/llm_client/openai_client.py +56 -132
  32. graphiti_core/llm_client/openai_generic_client.py +91 -56
  33. graphiti_core/models/edges/edge_db_queries.py +259 -35
  34. graphiti_core/models/nodes/node_db_queries.py +311 -32
  35. graphiti_core/nodes.py +420 -205
  36. graphiti_core/prompts/dedupe_edges.py +46 -32
  37. graphiti_core/prompts/dedupe_nodes.py +67 -42
  38. graphiti_core/prompts/eval.py +4 -4
  39. graphiti_core/prompts/extract_edges.py +27 -16
  40. graphiti_core/prompts/extract_nodes.py +74 -31
  41. graphiti_core/prompts/prompt_helpers.py +39 -0
  42. graphiti_core/prompts/snippets.py +29 -0
  43. graphiti_core/prompts/summarize_nodes.py +23 -25
  44. graphiti_core/search/search.py +158 -82
  45. graphiti_core/search/search_config.py +39 -4
  46. graphiti_core/search/search_filters.py +126 -35
  47. graphiti_core/search/search_helpers.py +5 -6
  48. graphiti_core/search/search_utils.py +1405 -485
  49. graphiti_core/telemetry/__init__.py +9 -0
  50. graphiti_core/telemetry/telemetry.py +117 -0
  51. graphiti_core/tracer.py +193 -0
  52. graphiti_core/utils/bulk_utils.py +364 -285
  53. graphiti_core/utils/datetime_utils.py +13 -0
  54. graphiti_core/utils/maintenance/community_operations.py +67 -49
  55. graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
  56. graphiti_core/utils/maintenance/edge_operations.py +339 -197
  57. graphiti_core/utils/maintenance/graph_data_operations.py +50 -114
  58. graphiti_core/utils/maintenance/node_operations.py +319 -238
  59. graphiti_core/utils/maintenance/temporal_operations.py +11 -3
  60. graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
  61. graphiti_core/utils/text_utils.py +53 -0
  62. graphiti_core-0.24.3.dist-info/METADATA +726 -0
  63. graphiti_core-0.24.3.dist-info/RECORD +86 -0
  64. {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
  65. graphiti_core-0.12.0rc1.dist-info/METADATA +0 -350
  66. graphiti_core-0.12.0rc1.dist-info/RECORD +0 -66
  67. /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
  68. {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info/licenses}/LICENSE +0 -0
graphiti_core/graphiti_types.py CHANGED
@@ -14,18 +14,20 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-from neo4j import AsyncDriver
 from pydantic import BaseModel, ConfigDict
 
 from graphiti_core.cross_encoder import CrossEncoderClient
+from graphiti_core.driver.driver import GraphDriver
 from graphiti_core.embedder import EmbedderClient
 from graphiti_core.llm_client import LLMClient
+from graphiti_core.tracer import Tracer
 
 
 class GraphitiClients(BaseModel):
-    driver: AsyncDriver
+    driver: GraphDriver
     llm_client: LLMClient
    embedder: EmbedderClient
     cross_encoder: CrossEncoderClient
+    tracer: Tracer
 
     model_config = ConfigDict(arbitrary_types_allowed=True)
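GraphitiClients is now typed against the database-agnostic GraphDriver abstraction and carries a Tracer. Because GraphDriver, EmbedderClient, and the other fields are not pydantic models, the container relies on ConfigDict(arbitrary_types_allowed=True). A self-contained analogue of that pattern (FakeDriver and Clients are illustrative stand-ins, not graphiti-core classes):

from pydantic import BaseModel, ConfigDict


class FakeDriver:  # stand-in for a GraphDriver implementation
    provider = 'neo4j'


class Clients(BaseModel):
    # Without arbitrary_types_allowed, pydantic would reject the FakeDriver annotation.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    driver: FakeDriver


clients = Clients(driver=FakeDriver())
print(clients.driver.provider)  # -> neo4j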
graphiti_core/helpers.py CHANGED
@@ -16,30 +16,47 @@ limitations under the License.
 
 import asyncio
 import os
+import re
 from collections.abc import Coroutine
 from datetime import datetime
+from typing import Any
 
 import numpy as np
 from dotenv import load_dotenv
 from neo4j import time as neo4j_time
 from numpy._typing import NDArray
-from typing_extensions import LiteralString
+from pydantic import BaseModel
+
+from graphiti_core.driver.driver import GraphProvider
+from graphiti_core.errors import GroupIdValidationError
 
 load_dotenv()
 
-DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
 USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
 SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
 MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0))
 DEFAULT_PAGE_LIMIT = 20
 
-RUNTIME_QUERY: LiteralString = (
-    'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
-)
 
+def parse_db_date(input_date: neo4j_time.DateTime | str | None) -> datetime | None:
+    if isinstance(input_date, neo4j_time.DateTime):
+        return input_date.to_native()
+
+    if isinstance(input_date, str):
+        return datetime.fromisoformat(input_date)
+
+    return input_date
 
-def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
-    return neo_date.to_native() if neo_date else None
+
+def get_default_group_id(provider: GraphProvider) -> str:
+    """
+    This function differentiates the default group id based on the database type.
+    For most databases, the default group id is an empty string, while there are database types that require a specific default group id.
+    """
+    if provider == GraphProvider.FALKORDB:
+        return '\\_'
+    else:
+        return ''
 
 
 def lucene_sanitize(query: str) -> str:
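The reworked parse_db_date now accepts ISO-8601 strings (as returned by non-Neo4j backends) alongside neo4j DateTime values, and get_default_group_id picks a provider-specific default. A small usage sketch; GraphProvider.NEO4J is assumed to be a member of the enum, while FALKORDB is shown in the hunk above:

from datetime import datetime

from graphiti_core.driver.driver import GraphProvider
from graphiti_core.helpers import get_default_group_id, parse_db_date

# ISO-8601 strings now parse directly; None passes through unchanged.
assert parse_db_date('2024-06-01T12:00:00+00:00') == datetime.fromisoformat('2024-06-01T12:00:00+00:00')
assert parse_db_date(None) is None

# FalkorDB cannot use an empty-string group id, so it gets an escaped placeholder.
assert get_default_group_id(GraphProvider.FALKORDB) == '\\_'
assert get_default_group_id(GraphProvider.NEO4J) == ''  # assumed enum member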
@@ -88,12 +105,72 @@ def normalize_l2(embedding: list[float]) -> NDArray:
 # Use this instead of asyncio.gather() to bound coroutines
 async def semaphore_gather(
     *coroutines: Coroutine,
-    max_coroutines: int = SEMAPHORE_LIMIT,
-):
-    semaphore = asyncio.Semaphore(max_coroutines)
+    max_coroutines: int | None = None,
+) -> list[Any]:
+    semaphore = asyncio.Semaphore(max_coroutines or SEMAPHORE_LIMIT)
 
     async def _wrap_coroutine(coroutine):
         async with semaphore:
             return await coroutine
 
     return await asyncio.gather(*(_wrap_coroutine(coroutine) for coroutine in coroutines))
+
+
+def validate_group_id(group_id: str | None) -> bool:
+    """
+    Validate that a group_id contains only ASCII alphanumeric characters, dashes, and underscores.
+
+    Args:
+        group_id: The group_id to validate
+
+    Returns:
+        True if valid, False otherwise
+
+    Raises:
+        GroupIdValidationError: If group_id contains invalid characters
+    """
+
+    # Allow empty string (default case)
+    if not group_id:
+        return True
+
+    # Check if string contains only ASCII alphanumeric characters, dashes, or underscores
+    # Pattern matches: letters (a-z, A-Z), digits (0-9), hyphens (-), and underscores (_)
+    if not re.match(r'^[a-zA-Z0-9_-]+$', group_id):
+        raise GroupIdValidationError(group_id)
+
+    return True
+
+
+def validate_excluded_entity_types(
+    excluded_entity_types: list[str] | None, entity_types: dict[str, type[BaseModel]] | None = None
+) -> bool:
+    """
+    Validate that excluded entity types are valid type names.
+
+    Args:
+        excluded_entity_types: List of entity type names to exclude
+        entity_types: Dictionary of available custom entity types
+
+    Returns:
+        True if valid
+
+    Raises:
+        ValueError: If any excluded type names are invalid
+    """
+    if not excluded_entity_types:
+        return True
+
+    # Build set of available type names
+    available_types = {'Entity'}  # Default type is always available
+    if entity_types:
+        available_types.update(entity_types.keys())
+
+    # Check for invalid type names
+    invalid_types = set(excluded_entity_types) - available_types
+    if invalid_types:
+        raise ValueError(
+            f'Invalid excluded entity types: {sorted(invalid_types)}. Available types: {sorted(available_types)}'
+        )
+
+    return True
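semaphore_gather now resolves its concurrency bound at call time (falling back to SEMAPHORE_LIMIT), and the new validators raise instead of silently accepting bad input. A self-contained sketch of how these helpers are called, based only on the definitions in this hunk:

import asyncio

from graphiti_core.errors import GroupIdValidationError
from graphiti_core.helpers import semaphore_gather, validate_group_id


async def work(i: int) -> int:
    await asyncio.sleep(0)
    return i * 2


async def main() -> None:
    # At most 5 coroutines run concurrently; results keep input order.
    results = await semaphore_gather(*(work(i) for i in range(10)), max_coroutines=5)
    print(results)  # [0, 2, 4, ..., 18]

    validate_group_id('tenant-42_prod')  # returns True
    try:
        validate_group_id('bad group id!')  # spaces and '!' are rejected
    except GroupIdValidationError as exc:
        print(exc)


asyncio.run(main())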
graphiti_core/llm_client/__init__.py CHANGED
@@ -1,3 +1,19 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 from .client import LLMClient
 from .config import LLMConfig
 from .errors import RateLimitError
graphiti_core/llm_client/anthropic_client.py CHANGED
@@ -19,11 +19,8 @@ import logging
 import os
 import typing
 from json import JSONDecodeError
-from typing import Literal
+from typing import TYPE_CHECKING, Literal
 
-import anthropic
-from anthropic import AsyncAnthropic
-from anthropic.types import MessageParam, ToolChoiceParam, ToolUnionParam
 from pydantic import BaseModel, ValidationError
 
 from ..prompts.models import Message
@@ -31,9 +28,28 @@ from .client import LLMClient
 from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
 from .errors import RateLimitError, RefusalError
 
+if TYPE_CHECKING:
+    import anthropic
+    from anthropic import AsyncAnthropic
+    from anthropic.types import MessageParam, ToolChoiceParam, ToolUnionParam
+else:
+    try:
+        import anthropic
+        from anthropic import AsyncAnthropic
+        from anthropic.types import MessageParam, ToolChoiceParam, ToolUnionParam
+    except ImportError:
+        raise ImportError(
+            'anthropic is required for AnthropicClient. '
+            'Install it with: pip install graphiti-core[anthropic]'
+        ) from None
+
+
 logger = logging.getLogger(__name__)
 
 AnthropicModel = Literal[
+    'claude-sonnet-4-5-latest',
+    'claude-sonnet-4-5-20250929',
+    'claude-haiku-4-5-latest',
     'claude-3-7-sonnet-latest',
     'claude-3-7-sonnet-20250219',
     'claude-3-5-haiku-latest',
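The anthropic SDK is now imported lazily: type checkers always see the real types via TYPE_CHECKING, while at runtime a missing package fails with an actionable message only when this module is imported. A minimal sketch of the same optional-dependency pattern applied to a hypothetical package (some_optional_sdk and the extra name are illustrative, not real graphiti-core dependencies):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Static analysis always resolves the real types.
    from some_optional_sdk import Client
else:
    try:
        from some_optional_sdk import Client
    except ImportError:
        # Surface an install hint instead of a bare ModuleNotFoundError.
        raise ImportError(
            'some_optional_sdk is required for this client. '
            'Install it with: pip install mypackage[some_extra]'
        ) from None

This mirrors how `pip install graphiti-core[anthropic]` gates AnthropicClient without making anthropic a hard dependency of the base package.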
@@ -49,7 +65,39 @@ AnthropicModel = Literal[
     'claude-2.0',
 ]
 
-DEFAULT_MODEL: AnthropicModel = 'claude-3-7-sonnet-latest'
+DEFAULT_MODEL: AnthropicModel = 'claude-haiku-4-5-latest'
+
+# Maximum output tokens for different Anthropic models
+# Based on official Anthropic documentation (as of 2025)
+# Note: These represent standard limits without beta headers.
+# Some models support higher limits with additional configuration (e.g., Claude 3.7 supports
+# 128K with 'anthropic-beta: output-128k-2025-02-19' header, but this is not currently implemented).
+ANTHROPIC_MODEL_MAX_TOKENS = {
+    # Claude 4.5 models - 64K tokens
+    'claude-sonnet-4-5-latest': 65536,
+    'claude-sonnet-4-5-20250929': 65536,
+    'claude-haiku-4-5-latest': 65536,
+    # Claude 3.7 models - standard 64K tokens
+    'claude-3-7-sonnet-latest': 65536,
+    'claude-3-7-sonnet-20250219': 65536,
+    # Claude 3.5 models
+    'claude-3-5-haiku-latest': 8192,
+    'claude-3-5-haiku-20241022': 8192,
+    'claude-3-5-sonnet-latest': 8192,
+    'claude-3-5-sonnet-20241022': 8192,
+    'claude-3-5-sonnet-20240620': 8192,
+    # Claude 3 models - 4K tokens
+    'claude-3-opus-latest': 4096,
+    'claude-3-opus-20240229': 4096,
+    'claude-3-sonnet-20240229': 4096,
+    'claude-3-haiku-20240307': 4096,
+    # Claude 2 models - 4K tokens
+    'claude-2.1': 4096,
+    'claude-2.0': 4096,
+}
+
+# Default max tokens for models not in the mapping
+DEFAULT_ANTHROPIC_MAX_TOKENS = 8192
 
 
 class AnthropicClient(LLMClient):
@@ -164,6 +212,45 @@ class AnthropicClient(LLMClient):
         tool_choice_cast = typing.cast(ToolChoiceParam, tool_choice)
         return tool_list_cast, tool_choice_cast
 
+    def _get_max_tokens_for_model(self, model: str) -> int:
+        """Get the maximum output tokens for a specific Anthropic model.
+
+        Args:
+            model: The model name to look up
+
+        Returns:
+            int: The maximum output tokens for the model
+        """
+        return ANTHROPIC_MODEL_MAX_TOKENS.get(model, DEFAULT_ANTHROPIC_MAX_TOKENS)
+
+    def _resolve_max_tokens(self, requested_max_tokens: int | None, model: str) -> int:
+        """
+        Resolve the maximum output tokens to use based on precedence rules.
+
+        Precedence order (highest to lowest):
+        1. Explicit max_tokens parameter passed to generate_response()
+        2. Instance max_tokens set during client initialization
+        3. Model-specific maximum tokens from ANTHROPIC_MODEL_MAX_TOKENS mapping
+        4. DEFAULT_ANTHROPIC_MAX_TOKENS as final fallback
+
+        Args:
+            requested_max_tokens: The max_tokens parameter passed to generate_response()
+            model: The model name to look up model-specific limits
+
+        Returns:
+            int: The resolved maximum tokens to use
+        """
+        # 1. Use explicit parameter if provided
+        if requested_max_tokens is not None:
+            return requested_max_tokens
+
+        # 2. Use instance max_tokens if set during initialization
+        if self.max_tokens is not None:
+            return self.max_tokens
+
+        # 3. Use model-specific maximum or return DEFAULT_ANTHROPIC_MAX_TOKENS
+        return self._get_max_tokens_for_model(model)
+
     async def _generate_response(
         self,
         messages: list[Message],
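The precedence rules above can be checked with a short walk-through. This is a stand-alone restatement of the logic in _resolve_max_tokens, written for illustration only (it duplicates the mapping rather than importing the private method):

# Stand-alone restatement of the precedence logic shown above.
ANTHROPIC_MODEL_MAX_TOKENS = {'claude-haiku-4-5-latest': 65536, 'claude-3-5-haiku-latest': 8192}
DEFAULT_ANTHROPIC_MAX_TOKENS = 8192


def resolve_max_tokens(requested: int | None, instance_default: int | None, model: str) -> int:
    if requested is not None:          # 1. explicit per-call value wins
        return requested
    if instance_default is not None:   # 2. then the value set at client construction
        return instance_default
    # 3./4. finally the per-model cap, or the global fallback
    return ANTHROPIC_MODEL_MAX_TOKENS.get(model, DEFAULT_ANTHROPIC_MAX_TOKENS)


assert resolve_max_tokens(1024, 4096, 'claude-haiku-4-5-latest') == 1024
assert resolve_max_tokens(None, 4096, 'claude-haiku-4-5-latest') == 4096
assert resolve_max_tokens(None, None, 'claude-haiku-4-5-latest') == 65536
assert resolve_max_tokens(None, None, 'claude-unknown-model') == 8192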
@@ -191,12 +278,9 @@ class AnthropicClient(LLMClient):
         user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]]
         user_messages_cast = typing.cast(list[MessageParam], user_messages)
 
-        # TODO: Replace hacky min finding solution after fixing hardcoded EXTRACT_EDGES_MAX_TOKENS = 16384 in
-        # edge_operations.py. Throws errors with cheaper models that lower max_tokens.
-        max_creation_tokens: int = min(
-            max_tokens if max_tokens is not None else self.config.max_tokens,
-            DEFAULT_MAX_TOKENS,
-        )
+        # Resolve max_tokens dynamically based on the model's capabilities
+        # This allows different models to use their full output capacity
+        max_creation_tokens: int = self._resolve_max_tokens(max_tokens, self.model)
 
         try:
             # Create the appropriate tool based on whether response_model is provided
@@ -252,6 +336,8 @@ class AnthropicClient(LLMClient):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
+        prompt_name: str | None = None,
     ) -> dict[str, typing.Any]:
         """
         Generate a response from the LLM.
@@ -272,55 +358,72 @@
         if max_tokens is None:
             max_tokens = self.max_tokens
 
-        retry_count = 0
-        max_retries = 2
-        last_error: Exception | None = None
-
-        while retry_count <= max_retries:
-            try:
-                response = await self._generate_response(
-                    messages, response_model, max_tokens, model_size
-                )
-
-                # If we have a response_model, attempt to validate the response
-                if response_model is not None:
-                    # Validate the response against the response_model
-                    model_instance = response_model(**response)
-                    return model_instance.model_dump()
-
-                # If no validation needed, return the response
-                return response
-
-            except (RateLimitError, RefusalError):
-                # These errors should not trigger retries
-                raise
-            except Exception as e:
-                last_error = e
-
-                if retry_count >= max_retries:
-                    if isinstance(e, ValidationError):
-                        logger.error(
-                            f'Validation error after {retry_count}/{max_retries} attempts: {e}'
-                        )
-                    else:
-                        logger.error(f'Max retries ({max_retries}) exceeded. Last error: {e}')
-                    raise e
-
-                if isinstance(e, ValidationError):
-                    response_model_cast = typing.cast(type[BaseModel], response_model)
-                    error_context = f'The previous response was invalid. Please provide a valid {response_model_cast.__name__} object. Error: {e}'
-                else:
-                    error_context = (
-                        f'The previous response attempt was invalid. '
-                        f'Error type: {e.__class__.__name__}. '
-                        f'Error details: {str(e)}. '
-                        f'Please try again with a valid response.'
-                    )
-
-                # Common retry logic
-                retry_count += 1
-                messages.append(Message(role='user', content=error_context))
-                logger.warning(f'Retrying after error (attempt {retry_count}/{max_retries}): {e}')
-
-        # If we somehow get here, raise the last error
-        raise last_error or Exception('Max retries exceeded with no specific error')
+        # Wrap entire operation in tracing span
+        with self.tracer.start_span('llm.generate') as span:
+            attributes = {
+                'llm.provider': 'anthropic',
+                'model.size': model_size.value,
+                'max_tokens': max_tokens,
+            }
+            if prompt_name:
+                attributes['prompt.name'] = prompt_name
+            span.add_attributes(attributes)
+
+            retry_count = 0
+            max_retries = 2
+            last_error: Exception | None = None
+
+            while retry_count <= max_retries:
+                try:
+                    response = await self._generate_response(
+                        messages, response_model, max_tokens, model_size
+                    )
+
+                    # If we have a response_model, attempt to validate the response
+                    if response_model is not None:
+                        # Validate the response against the response_model
+                        model_instance = response_model(**response)
+                        return model_instance.model_dump()
+
+                    # If no validation needed, return the response
+                    return response
+
+                except (RateLimitError, RefusalError):
+                    # These errors should not trigger retries
+                    span.set_status('error', str(last_error))
+                    raise
+                except Exception as e:
+                    last_error = e
+
+                    if retry_count >= max_retries:
+                        if isinstance(e, ValidationError):
+                            logger.error(
+                                f'Validation error after {retry_count}/{max_retries} attempts: {e}'
+                            )
+                        else:
+                            logger.error(f'Max retries ({max_retries}) exceeded. Last error: {e}')
+                        span.set_status('error', str(e))
+                        span.record_exception(e)
+                        raise e
+
+                    if isinstance(e, ValidationError):
+                        response_model_cast = typing.cast(type[BaseModel], response_model)
+                        error_context = f'The previous response was invalid. Please provide a valid {response_model_cast.__name__} object. Error: {e}'
+                    else:
+                        error_context = (
+                            f'The previous response attempt was invalid. '
+                            f'Error type: {e.__class__.__name__}. '
+                            f'Error details: {str(e)}. '
+                            f'Please try again with a valid response.'
+                        )
+
+                    # Common retry logic
+                    retry_count += 1
+                    messages.append(Message(role='user', content=error_context))
+                    logger.warning(
+                        f'Retrying after error (attempt {retry_count}/{max_retries}): {e}'
+                    )
+
+            # If we somehow get here, raise the last error
+            span.set_status('error', str(last_error))
+            raise last_error or Exception('Max retries exceeded with no specific error')
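generate_response is now wrapped in a tracing span, and the calls in this hunk imply the Tracer contract: start_span() is a context manager yielding a span with add_attributes(), set_status(), and record_exception(). A minimal no-op sketch consistent with those calls; the real interface lives in the new graphiti_core/tracer.py and may differ in detail:

from contextlib import contextmanager
from typing import Any, Iterator


class NoOpSpan:
    """Span stub exposing only the methods used in the hunk above."""

    def add_attributes(self, attributes: dict[str, Any]) -> None:
        pass

    def set_status(self, status: str, description: str | None = None) -> None:
        pass

    def record_exception(self, exc: BaseException) -> None:
        pass


class NoOpTracer:
    @contextmanager
    def start_span(self, name: str) -> Iterator[NoOpSpan]:
        yield NoOpSpan()


# Usage mirroring the instrumented retry loop:
tracer = NoOpTracer()
with tracer.start_span('llm.generate') as span:
    span.add_attributes({'llm.provider': 'anthropic', 'max_tokens': 1024})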
graphiti_core/llm_client/azure_openai_client.py ADDED
@@ -0,0 +1,115 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import logging
+from typing import ClassVar
+
+from openai import AsyncAzureOpenAI, AsyncOpenAI
+from openai.types.chat import ChatCompletionMessageParam
+from pydantic import BaseModel
+
+from .config import DEFAULT_MAX_TOKENS, LLMConfig
+from .openai_base_client import BaseOpenAIClient
+
+logger = logging.getLogger(__name__)
+
+
+class AzureOpenAILLMClient(BaseOpenAIClient):
+    """Wrapper class for Azure OpenAI that implements the LLMClient interface.
+
+    Supports both AsyncAzureOpenAI and AsyncOpenAI (with Azure v1 API endpoint).
+    """
+
+    # Class-level constants
+    MAX_RETRIES: ClassVar[int] = 2
+
+    def __init__(
+        self,
+        azure_client: AsyncAzureOpenAI | AsyncOpenAI,
+        config: LLMConfig | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
+        reasoning: str | None = None,
+        verbosity: str | None = None,
+    ):
+        super().__init__(
+            config,
+            cache=False,
+            max_tokens=max_tokens,
+            reasoning=reasoning,
+            verbosity=verbosity,
+        )
+        self.client = azure_client
+
+    async def _create_structured_completion(
+        self,
+        model: str,
+        messages: list[ChatCompletionMessageParam],
+        temperature: float | None,
+        max_tokens: int,
+        response_model: type[BaseModel],
+        reasoning: str | None,
+        verbosity: str | None,
+    ):
+        """Create a structured completion using Azure OpenAI's responses.parse API."""
+        supports_reasoning = self._supports_reasoning_features(model)
+        request_kwargs = {
+            'model': model,
+            'input': messages,
+            'max_output_tokens': max_tokens,
+            'text_format': response_model,  # type: ignore
+        }
+
+        temperature_value = temperature if not supports_reasoning else None
+        if temperature_value is not None:
+            request_kwargs['temperature'] = temperature_value
+
+        if supports_reasoning and reasoning:
+            request_kwargs['reasoning'] = {'effort': reasoning}  # type: ignore
+
+        if supports_reasoning and verbosity:
+            request_kwargs['text'] = {'verbosity': verbosity}  # type: ignore
+
+        return await self.client.responses.parse(**request_kwargs)
+
+    async def _create_completion(
+        self,
+        model: str,
+        messages: list[ChatCompletionMessageParam],
+        temperature: float | None,
+        max_tokens: int,
+        response_model: type[BaseModel] | None = None,
+    ):
+        """Create a regular completion with JSON format using Azure OpenAI."""
+        supports_reasoning = self._supports_reasoning_features(model)
+
+        request_kwargs = {
+            'model': model,
+            'messages': messages,
+            'max_tokens': max_tokens,
+            'response_format': {'type': 'json_object'},
+        }
+
+        temperature_value = temperature if not supports_reasoning else None
+        if temperature_value is not None:
+            request_kwargs['temperature'] = temperature_value
+
+        return await self.client.chat.completions.create(**request_kwargs)
+
+    @staticmethod
+    def _supports_reasoning_features(model: str) -> bool:
+        """Return True when the Azure model supports reasoning/verbosity options."""
+        reasoning_prefixes = ('o1', 'o3', 'gpt-5')
+        return model.startswith(reasoning_prefixes)
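A usage sketch for the new Azure wrapper, assuming Azure OpenAI credentials are available as environment variables. The AsyncAzureOpenAI constructor arguments come from the openai SDK; the endpoint, API version, and deployment name below are placeholders, and the LLMConfig model field is assumed to carry the Azure deployment name:

import os

from openai import AsyncAzureOpenAI

from graphiti_core.llm_client import LLMConfig
from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient

# Build the underlying Azure SDK client (placeholder endpoint and API version).
azure_client = AsyncAzureOpenAI(
    api_key=os.environ['AZURE_OPENAI_API_KEY'],
    api_version='2024-10-21',
    azure_endpoint='https://my-resource.openai.azure.com',
)

# Wrap it so graphiti-core can use it anywhere an LLMClient is expected.
llm_client = AzureOpenAILLMClient(
    azure_client=azure_client,
    config=LLMConfig(model='my-gpt-deployment'),  # placeholder deployment name
)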