graphiti-core 0.12.0rc1__py3-none-any.whl → 0.24.3__py3-none-any.whl
This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- graphiti_core/cross_encoder/bge_reranker_client.py +12 -2
- graphiti_core/cross_encoder/gemini_reranker_client.py +161 -0
- graphiti_core/cross_encoder/openai_reranker_client.py +7 -5
- graphiti_core/decorators.py +110 -0
- graphiti_core/driver/__init__.py +19 -0
- graphiti_core/driver/driver.py +124 -0
- graphiti_core/driver/falkordb_driver.py +362 -0
- graphiti_core/driver/graph_operations/graph_operations.py +191 -0
- graphiti_core/driver/kuzu_driver.py +182 -0
- graphiti_core/driver/neo4j_driver.py +117 -0
- graphiti_core/driver/neptune_driver.py +305 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +287 -172
- graphiti_core/embedder/azure_openai.py +71 -0
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/embedder/gemini.py +116 -22
- graphiti_core/embedder/voyage.py +13 -2
- graphiti_core/errors.py +8 -0
- graphiti_core/graph_queries.py +162 -0
- graphiti_core/graphiti.py +705 -193
- graphiti_core/graphiti_types.py +4 -2
- graphiti_core/helpers.py +87 -10
- graphiti_core/llm_client/__init__.py +16 -0
- graphiti_core/llm_client/anthropic_client.py +159 -56
- graphiti_core/llm_client/azure_openai_client.py +115 -0
- graphiti_core/llm_client/client.py +98 -21
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/gemini_client.py +290 -41
- graphiti_core/llm_client/groq_client.py +14 -3
- graphiti_core/llm_client/openai_base_client.py +261 -0
- graphiti_core/llm_client/openai_client.py +56 -132
- graphiti_core/llm_client/openai_generic_client.py +91 -56
- graphiti_core/models/edges/edge_db_queries.py +259 -35
- graphiti_core/models/nodes/node_db_queries.py +311 -32
- graphiti_core/nodes.py +420 -205
- graphiti_core/prompts/dedupe_edges.py +46 -32
- graphiti_core/prompts/dedupe_nodes.py +67 -42
- graphiti_core/prompts/eval.py +4 -4
- graphiti_core/prompts/extract_edges.py +27 -16
- graphiti_core/prompts/extract_nodes.py +74 -31
- graphiti_core/prompts/prompt_helpers.py +39 -0
- graphiti_core/prompts/snippets.py +29 -0
- graphiti_core/prompts/summarize_nodes.py +23 -25
- graphiti_core/search/search.py +158 -82
- graphiti_core/search/search_config.py +39 -4
- graphiti_core/search/search_filters.py +126 -35
- graphiti_core/search/search_helpers.py +5 -6
- graphiti_core/search/search_utils.py +1405 -485
- graphiti_core/telemetry/__init__.py +9 -0
- graphiti_core/telemetry/telemetry.py +117 -0
- graphiti_core/tracer.py +193 -0
- graphiti_core/utils/bulk_utils.py +364 -285
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +67 -49
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +339 -197
- graphiti_core/utils/maintenance/graph_data_operations.py +50 -114
- graphiti_core/utils/maintenance/node_operations.py +319 -238
- graphiti_core/utils/maintenance/temporal_operations.py +11 -3
- graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
- graphiti_core/utils/text_utils.py +53 -0
- graphiti_core-0.24.3.dist-info/METADATA +726 -0
- graphiti_core-0.24.3.dist-info/RECORD +86 -0
- {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
- graphiti_core-0.12.0rc1.dist-info/METADATA +0 -350
- graphiti_core-0.12.0rc1.dist-info/RECORD +0 -66
- /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
- {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info/licenses}/LICENSE +0 -0
graphiti_core/graphiti_types.py
CHANGED

@@ -14,18 +14,20 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-from neo4j import AsyncDriver
 from pydantic import BaseModel, ConfigDict

 from graphiti_core.cross_encoder import CrossEncoderClient
+from graphiti_core.driver.driver import GraphDriver
 from graphiti_core.embedder import EmbedderClient
 from graphiti_core.llm_client import LLMClient
+from graphiti_core.tracer import Tracer


 class GraphitiClients(BaseModel):
-    driver: AsyncDriver
+    driver: GraphDriver
     llm_client: LLMClient
     embedder: EmbedderClient
     cross_encoder: CrossEncoderClient
+    tracer: Tracer

     model_config = ConfigDict(arbitrary_types_allowed=True)
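For illustration, a minimal sketch (not part of the diff) of the updated GraphitiClients shape; the None values are placeholders for real driver, LLM, embedder, cross-encoder, and tracer instances, and model_construct() is used only so the snippet runs without building them:

    from graphiti_core.graphiti_types import GraphitiClients

    # Illustration only: the model is now typed against GraphDriver and gains a
    # required tracer field. model_construct() bypasses validation, so None
    # placeholders stand in for real client instances here.
    clients = GraphitiClients.model_construct(
        driver=None,         # GraphDriver (Neo4j, FalkorDB, Kuzu, or Neptune driver)
        llm_client=None,     # LLMClient
        embedder=None,       # EmbedderClient
        cross_encoder=None,  # CrossEncoderClient
        tracer=None,         # Tracer (new in this release)
    )
    print(list(GraphitiClients.model_fields))
    # ['driver', 'llm_client', 'embedder', 'cross_encoder', 'tracer']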
graphiti_core/helpers.py
CHANGED

@@ -16,30 +16,47 @@ limitations under the License.

 import asyncio
 import os
+import re
 from collections.abc import Coroutine
 from datetime import datetime
+from typing import Any

 import numpy as np
 from dotenv import load_dotenv
 from neo4j import time as neo4j_time
 from numpy._typing import NDArray
-from
+from pydantic import BaseModel
+
+from graphiti_core.driver.driver import GraphProvider
+from graphiti_core.errors import GroupIdValidationError

 load_dotenv()

-DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
 USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
 SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
 MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0))
 DEFAULT_PAGE_LIMIT = 20

-RUNTIME_QUERY: LiteralString = (
-    'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
-)

+def parse_db_date(input_date: neo4j_time.DateTime | str | None) -> datetime | None:
+    if isinstance(input_date, neo4j_time.DateTime):
+        return input_date.to_native()
+
+    if isinstance(input_date, str):
+        return datetime.fromisoformat(input_date)
+
+    return input_date

-
-
+
+def get_default_group_id(provider: GraphProvider) -> str:
+    """
+    This function differentiates the default group id based on the database type.
+    For most databases, the default group id is an empty string, while there are database types that require a specific default group id.
+    """
+    if provider == GraphProvider.FALKORDB:
+        return '\\_'
+    else:
+        return ''


 def lucene_sanitize(query: str) -> str:

@@ -88,12 +105,72 @@ def normalize_l2(embedding: list[float]) -> NDArray:
 # Use this instead of asyncio.gather() to bound coroutines
 async def semaphore_gather(
     *coroutines: Coroutine,
-    max_coroutines: int =
-):
-    semaphore = asyncio.Semaphore(max_coroutines)
+    max_coroutines: int | None = None,
+) -> list[Any]:
+    semaphore = asyncio.Semaphore(max_coroutines or SEMAPHORE_LIMIT)

     async def _wrap_coroutine(coroutine):
         async with semaphore:
             return await coroutine

     return await asyncio.gather(*(_wrap_coroutine(coroutine) for coroutine in coroutines))
+
+
+def validate_group_id(group_id: str | None) -> bool:
+    """
+    Validate that a group_id contains only ASCII alphanumeric characters, dashes, and underscores.
+
+    Args:
+        group_id: The group_id to validate
+
+    Returns:
+        True if valid, False otherwise
+
+    Raises:
+        GroupIdValidationError: If group_id contains invalid characters
+    """
+
+    # Allow empty string (default case)
+    if not group_id:
+        return True
+
+    # Check if string contains only ASCII alphanumeric characters, dashes, or underscores
+    # Pattern matches: letters (a-z, A-Z), digits (0-9), hyphens (-), and underscores (_)
+    if not re.match(r'^[a-zA-Z0-9_-]+$', group_id):
+        raise GroupIdValidationError(group_id)
+
+    return True
+
+
+def validate_excluded_entity_types(
+    excluded_entity_types: list[str] | None, entity_types: dict[str, type[BaseModel]] | None = None
+) -> bool:
+    """
+    Validate that excluded entity types are valid type names.
+
+    Args:
+        excluded_entity_types: List of entity type names to exclude
+        entity_types: Dictionary of available custom entity types
+
+    Returns:
+        True if valid
+
+    Raises:
+        ValueError: If any excluded type names are invalid
+    """
+    if not excluded_entity_types:
+        return True
+
+    # Build set of available type names
+    available_types = {'Entity'}  # Default type is always available
+    if entity_types:
+        available_types.update(entity_types.keys())
+
+    # Check for invalid type names
+    invalid_types = set(excluded_entity_types) - available_types
+    if invalid_types:
+        raise ValueError(
+            f'Invalid excluded entity types: {sorted(invalid_types)}. Available types: {sorted(available_types)}'
+        )
+
+    return True
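For illustration, a minimal self-contained sketch (not part of the diff) of how the new helpers behave. The stand-alone functions below mirror the added validate_group_id and semaphore_gather logic so the snippet runs without graphiti-core installed, and the stub GroupIdValidationError is an assumption standing in for graphiti_core.errors.GroupIdValidationError:

    import asyncio
    import re


    class GroupIdValidationError(ValueError):
        # Stand-in for graphiti_core.errors.GroupIdValidationError (assumed shape).
        def __init__(self, group_id: str):
            super().__init__(f'group_id "{group_id}" must contain only alphanumerics, dashes, or underscores')


    def validate_group_id(group_id: str | None) -> bool:
        # Mirrors the helper added above: empty/None group ids are allowed,
        # otherwise only ASCII alphanumerics, '-' and '_' pass.
        if not group_id:
            return True
        if not re.match(r'^[a-zA-Z0-9_-]+$', group_id):
            raise GroupIdValidationError(group_id)
        return True


    async def semaphore_gather(*coroutines, max_coroutines: int | None = None):
        # Mirrors the new signature: the concurrency cap falls back to SEMAPHORE_LIMIT
        # (20 unless overridden via the environment) when max_coroutines is None.
        semaphore = asyncio.Semaphore(max_coroutines or 20)

        async def _wrap(coro):
            async with semaphore:
                return await coro

        return await asyncio.gather(*(_wrap(c) for c in coroutines))


    async def main() -> None:
        validate_group_id('tenant-42')    # passes
        validate_group_id('')             # empty group id is allowed
        # validate_group_id('tenant 42')  # would raise GroupIdValidationError

        async def square(i: int) -> int:
            await asyncio.sleep(0)
            return i * i

        results = await semaphore_gather(*(square(i) for i in range(5)), max_coroutines=2)
        print(results)  # [0, 1, 4, 9, 16]


    asyncio.run(main())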
graphiti_core/llm_client/__init__.py
CHANGED

@@ -1,3 +1,19 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 from .client import LLMClient
 from .config import LLMConfig
 from .errors import RateLimitError
graphiti_core/llm_client/anthropic_client.py
CHANGED

@@ -19,11 +19,8 @@ import logging
 import os
 import typing
 from json import JSONDecodeError
-from typing import Literal
+from typing import TYPE_CHECKING, Literal

-import anthropic
-from anthropic import AsyncAnthropic
-from anthropic.types import MessageParam, ToolChoiceParam, ToolUnionParam
 from pydantic import BaseModel, ValidationError

 from ..prompts.models import Message

@@ -31,9 +28,28 @@ from .client import LLMClient
 from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
 from .errors import RateLimitError, RefusalError

+if TYPE_CHECKING:
+    import anthropic
+    from anthropic import AsyncAnthropic
+    from anthropic.types import MessageParam, ToolChoiceParam, ToolUnionParam
+else:
+    try:
+        import anthropic
+        from anthropic import AsyncAnthropic
+        from anthropic.types import MessageParam, ToolChoiceParam, ToolUnionParam
+    except ImportError:
+        raise ImportError(
+            'anthropic is required for AnthropicClient. '
+            'Install it with: pip install graphiti-core[anthropic]'
+        ) from None
+
+
 logger = logging.getLogger(__name__)

 AnthropicModel = Literal[
+    'claude-sonnet-4-5-latest',
+    'claude-sonnet-4-5-20250929',
+    'claude-haiku-4-5-latest',
     'claude-3-7-sonnet-latest',
     'claude-3-7-sonnet-20250219',
     'claude-3-5-haiku-latest',

@@ -49,7 +65,39 @@ AnthropicModel = Literal[
     'claude-2.0',
 ]

-DEFAULT_MODEL: AnthropicModel = 'claude-
+DEFAULT_MODEL: AnthropicModel = 'claude-haiku-4-5-latest'
+
+# Maximum output tokens for different Anthropic models
+# Based on official Anthropic documentation (as of 2025)
+# Note: These represent standard limits without beta headers.
+# Some models support higher limits with additional configuration (e.g., Claude 3.7 supports
+# 128K with 'anthropic-beta: output-128k-2025-02-19' header, but this is not currently implemented).
+ANTHROPIC_MODEL_MAX_TOKENS = {
+    # Claude 4.5 models - 64K tokens
+    'claude-sonnet-4-5-latest': 65536,
+    'claude-sonnet-4-5-20250929': 65536,
+    'claude-haiku-4-5-latest': 65536,
+    # Claude 3.7 models - standard 64K tokens
+    'claude-3-7-sonnet-latest': 65536,
+    'claude-3-7-sonnet-20250219': 65536,
+    # Claude 3.5 models
+    'claude-3-5-haiku-latest': 8192,
+    'claude-3-5-haiku-20241022': 8192,
+    'claude-3-5-sonnet-latest': 8192,
+    'claude-3-5-sonnet-20241022': 8192,
+    'claude-3-5-sonnet-20240620': 8192,
+    # Claude 3 models - 4K tokens
+    'claude-3-opus-latest': 4096,
+    'claude-3-opus-20240229': 4096,
+    'claude-3-sonnet-20240229': 4096,
+    'claude-3-haiku-20240307': 4096,
+    # Claude 2 models - 4K tokens
+    'claude-2.1': 4096,
+    'claude-2.0': 4096,
+}
+
+# Default max tokens for models not in the mapping
+DEFAULT_ANTHROPIC_MAX_TOKENS = 8192


 class AnthropicClient(LLMClient):

@@ -164,6 +212,45 @@ class AnthropicClient(LLMClient):
         tool_choice_cast = typing.cast(ToolChoiceParam, tool_choice)
         return tool_list_cast, tool_choice_cast

+    def _get_max_tokens_for_model(self, model: str) -> int:
+        """Get the maximum output tokens for a specific Anthropic model.
+
+        Args:
+            model: The model name to look up
+
+        Returns:
+            int: The maximum output tokens for the model
+        """
+        return ANTHROPIC_MODEL_MAX_TOKENS.get(model, DEFAULT_ANTHROPIC_MAX_TOKENS)
+
+    def _resolve_max_tokens(self, requested_max_tokens: int | None, model: str) -> int:
+        """
+        Resolve the maximum output tokens to use based on precedence rules.
+
+        Precedence order (highest to lowest):
+        1. Explicit max_tokens parameter passed to generate_response()
+        2. Instance max_tokens set during client initialization
+        3. Model-specific maximum tokens from ANTHROPIC_MODEL_MAX_TOKENS mapping
+        4. DEFAULT_ANTHROPIC_MAX_TOKENS as final fallback
+
+        Args:
+            requested_max_tokens: The max_tokens parameter passed to generate_response()
+            model: The model name to look up model-specific limits
+
+        Returns:
+            int: The resolved maximum tokens to use
+        """
+        # 1. Use explicit parameter if provided
+        if requested_max_tokens is not None:
+            return requested_max_tokens
+
+        # 2. Use instance max_tokens if set during initialization
+        if self.max_tokens is not None:
+            return self.max_tokens
+
+        # 3. Use model-specific maximum or return DEFAULT_ANTHROPIC_MAX_TOKENS
+        return self._get_max_tokens_for_model(model)
+
     async def _generate_response(
         self,
         messages: list[Message],

@@ -191,12 +278,9 @@ class AnthropicClient(LLMClient):
         user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]]
         user_messages_cast = typing.cast(list[MessageParam], user_messages)

-        #
-        #
-        max_creation_tokens: int =
-            max_tokens if max_tokens is not None else self.config.max_tokens,
-            DEFAULT_MAX_TOKENS,
-        )
+        # Resolve max_tokens dynamically based on the model's capabilities
+        # This allows different models to use their full output capacity
+        max_creation_tokens: int = self._resolve_max_tokens(max_tokens, self.model)

         try:
             # Create the appropriate tool based on whether response_model is provided

@@ -252,6 +336,8 @@ class AnthropicClient(LLMClient):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
+        prompt_name: str | None = None,
     ) -> dict[str, typing.Any]:
         """
         Generate a response from the LLM.

@@ -272,55 +358,72 @@ class AnthropicClient(LLMClient):
         if max_tokens is None:
             max_tokens = self.max_tokens

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # Wrap entire operation in tracing span
+        with self.tracer.start_span('llm.generate') as span:
+            attributes = {
+                'llm.provider': 'anthropic',
+                'model.size': model_size.value,
+                'max_tokens': max_tokens,
+            }
+            if prompt_name:
+                attributes['prompt.name'] = prompt_name
+            span.add_attributes(attributes)
+
+            retry_count = 0
+            max_retries = 2
+            last_error: Exception | None = None
+
+            while retry_count <= max_retries:
+                try:
+                    response = await self._generate_response(
+                        messages, response_model, max_tokens, model_size
+                    )

-
-
-
-
-
+                    # If we have a response_model, attempt to validate the response
+                    if response_model is not None:
+                        # Validate the response against the response_model
+                        model_instance = response_model(**response)
+                        return model_instance.model_dump()
+
+                    # If no validation needed, return the response
+                    return response
+
+                except (RateLimitError, RefusalError):
+                    # These errors should not trigger retries
+                    span.set_status('error', str(last_error))
+                    raise
+                except Exception as e:
+                    last_error = e
+
+                    if retry_count >= max_retries:
+                        if isinstance(e, ValidationError):
+                            logger.error(
+                                f'Validation error after {retry_count}/{max_retries} attempts: {e}'
+                            )
+                        else:
+                            logger.error(f'Max retries ({max_retries}) exceeded. Last error: {e}')
+                        span.set_status('error', str(e))
+                        span.record_exception(e)
+                        raise e

-                    if retry_count >= max_retries:
                     if isinstance(e, ValidationError):
-
-
-                        )
+                        response_model_cast = typing.cast(type[BaseModel], response_model)
+                        error_context = f'The previous response was invalid. Please provide a valid {response_model_cast.__name__} object. Error: {e}'
                     else:
-
-
+                        error_context = (
+                            f'The previous response attempt was invalid. '
+                            f'Error type: {e.__class__.__name__}. '
+                            f'Error details: {str(e)}. '
+                            f'Please try again with a valid response.'
+                        )

-
-
-
-
-
-                            f'The previous response attempt was invalid. '
-                            f'Error type: {e.__class__.__name__}. '
-                            f'Error details: {str(e)}. '
-                            f'Please try again with a valid response.'
+                    # Common retry logic
+                    retry_count += 1
+                    messages.append(Message(role='user', content=error_context))
+                    logger.warning(
+                        f'Retrying after error (attempt {retry_count}/{max_retries}): {e}'
                     )

-
-
-
-        logger.warning(f'Retrying after error (attempt {retry_count}/{max_retries}): {e}')
-
-        # If we somehow get here, raise the last error
-        raise last_error or Exception('Max retries exceeded with no specific error')
+            # If we somehow get here, raise the last error
+            span.set_status('error', str(last_error))
+            raise last_error or Exception('Max retries exceeded with no specific error')
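For illustration, a short stand-alone sketch (not part of the diff) of the new max-token resolution order; the function below mirrors AnthropicClient._resolve_max_tokens with a trimmed copy of the lookup table so it runs on its own:

    # Trimmed copy of the new lookup table; the full mapping lives in anthropic_client.py.
    ANTHROPIC_MODEL_MAX_TOKENS = {
        'claude-haiku-4-5-latest': 65536,
        'claude-3-5-sonnet-latest': 8192,
        'claude-3-opus-latest': 4096,
    }
    DEFAULT_ANTHROPIC_MAX_TOKENS = 8192


    def resolve_max_tokens(requested: int | None, instance_max: int | None, model: str) -> int:
        # Precedence mirrors AnthropicClient._resolve_max_tokens:
        # explicit argument > instance setting > per-model limit > default fallback.
        if requested is not None:
            return requested
        if instance_max is not None:
            return instance_max
        return ANTHROPIC_MODEL_MAX_TOKENS.get(model, DEFAULT_ANTHROPIC_MAX_TOKENS)


    print(resolve_max_tokens(None, None, 'claude-haiku-4-5-latest'))  # 65536 (model limit)
    print(resolve_max_tokens(None, 4000, 'claude-haiku-4-5-latest'))  # 4000 (instance setting)
    print(resolve_max_tokens(1024, 4000, 'unknown-model'))            # 1024 (explicit argument)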
graphiti_core/llm_client/azure_openai_client.py
ADDED

@@ -0,0 +1,115 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import logging
+from typing import ClassVar
+
+from openai import AsyncAzureOpenAI, AsyncOpenAI
+from openai.types.chat import ChatCompletionMessageParam
+from pydantic import BaseModel
+
+from .config import DEFAULT_MAX_TOKENS, LLMConfig
+from .openai_base_client import BaseOpenAIClient
+
+logger = logging.getLogger(__name__)
+
+
+class AzureOpenAILLMClient(BaseOpenAIClient):
+    """Wrapper class for Azure OpenAI that implements the LLMClient interface.
+
+    Supports both AsyncAzureOpenAI and AsyncOpenAI (with Azure v1 API endpoint).
+    """
+
+    # Class-level constants
+    MAX_RETRIES: ClassVar[int] = 2
+
+    def __init__(
+        self,
+        azure_client: AsyncAzureOpenAI | AsyncOpenAI,
+        config: LLMConfig | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
+        reasoning: str | None = None,
+        verbosity: str | None = None,
+    ):
+        super().__init__(
+            config,
+            cache=False,
+            max_tokens=max_tokens,
+            reasoning=reasoning,
+            verbosity=verbosity,
+        )
+        self.client = azure_client
+
+    async def _create_structured_completion(
+        self,
+        model: str,
+        messages: list[ChatCompletionMessageParam],
+        temperature: float | None,
+        max_tokens: int,
+        response_model: type[BaseModel],
+        reasoning: str | None,
+        verbosity: str | None,
+    ):
+        """Create a structured completion using Azure OpenAI's responses.parse API."""
+        supports_reasoning = self._supports_reasoning_features(model)
+        request_kwargs = {
+            'model': model,
+            'input': messages,
+            'max_output_tokens': max_tokens,
+            'text_format': response_model,  # type: ignore
+        }
+
+        temperature_value = temperature if not supports_reasoning else None
+        if temperature_value is not None:
+            request_kwargs['temperature'] = temperature_value
+
+        if supports_reasoning and reasoning:
+            request_kwargs['reasoning'] = {'effort': reasoning}  # type: ignore
+
+        if supports_reasoning and verbosity:
+            request_kwargs['text'] = {'verbosity': verbosity}  # type: ignore
+
+        return await self.client.responses.parse(**request_kwargs)
+
+    async def _create_completion(
+        self,
+        model: str,
+        messages: list[ChatCompletionMessageParam],
+        temperature: float | None,
+        max_tokens: int,
+        response_model: type[BaseModel] | None = None,
+    ):
+        """Create a regular completion with JSON format using Azure OpenAI."""
+        supports_reasoning = self._supports_reasoning_features(model)
+
+        request_kwargs = {
+            'model': model,
+            'messages': messages,
+            'max_tokens': max_tokens,
+            'response_format': {'type': 'json_object'},
+        }
+
+        temperature_value = temperature if not supports_reasoning else None
+        if temperature_value is not None:
+            request_kwargs['temperature'] = temperature_value
+
+        return await self.client.chat.completions.create(**request_kwargs)
+
+    @staticmethod
+    def _supports_reasoning_features(model: str) -> bool:
+        """Return True when the Azure model supports reasoning/verbosity options."""
+        reasoning_prefixes = ('o1', 'o3', 'gpt-5')
+        return model.startswith(reasoning_prefixes)