synth-ai 0.1.0.dev38__py3-none-any.whl → 0.1.0.dev49__py3-none-any.whl
This diff represents the content of publicly released package versions from one of the supported registries. It is provided for informational purposes only and reflects the changes between versions as they appear in their respective public registries.
- synth_ai/__init__.py +3 -1
- {synth_ai-0.1.0.dev38.dist-info → synth_ai-0.1.0.dev49.dist-info}/METADATA +12 -11
- synth_ai-0.1.0.dev49.dist-info/RECORD +6 -0
- {synth_ai-0.1.0.dev38.dist-info → synth_ai-0.1.0.dev49.dist-info}/WHEEL +1 -1
- synth_ai-0.1.0.dev49.dist-info/top_level.txt +1 -0
- private_tests/try_synth_sdk.py +0 -1
- public_tests/test_agent.py +0 -538
- public_tests/test_all_structured_outputs.py +0 -196
- public_tests/test_anthropic_structured_outputs.py +0 -0
- public_tests/test_deepseek_structured_outputs.py +0 -0
- public_tests/test_deepseek_tools.py +0 -64
- public_tests/test_gemini_output.py +0 -188
- public_tests/test_gemini_structured_outputs.py +0 -106
- public_tests/test_models.py +0 -183
- public_tests/test_openai_structured_outputs.py +0 -106
- public_tests/test_reasoning_effort.py +0 -75
- public_tests/test_reasoning_models.py +0 -92
- public_tests/test_recursive_structured_outputs.py +0 -180
- public_tests/test_structured.py +0 -137
- public_tests/test_structured_outputs.py +0 -109
- public_tests/test_synth_sdk.py +0 -384
- public_tests/test_text.py +0 -160
- public_tests/test_tools.py +0 -319
- synth_ai/zyk/__init__.py +0 -3
- synth_ai/zyk/lms/__init__.py +0 -0
- synth_ai/zyk/lms/caching/__init__.py +0 -0
- synth_ai/zyk/lms/caching/constants.py +0 -1
- synth_ai/zyk/lms/caching/dbs.py +0 -0
- synth_ai/zyk/lms/caching/ephemeral.py +0 -72
- synth_ai/zyk/lms/caching/handler.py +0 -142
- synth_ai/zyk/lms/caching/initialize.py +0 -13
- synth_ai/zyk/lms/caching/persistent.py +0 -83
- synth_ai/zyk/lms/config.py +0 -8
- synth_ai/zyk/lms/core/__init__.py +0 -0
- synth_ai/zyk/lms/core/all.py +0 -47
- synth_ai/zyk/lms/core/exceptions.py +0 -9
- synth_ai/zyk/lms/core/main.py +0 -314
- synth_ai/zyk/lms/core/vendor_clients.py +0 -85
- synth_ai/zyk/lms/cost/__init__.py +0 -0
- synth_ai/zyk/lms/cost/monitor.py +0 -1
- synth_ai/zyk/lms/cost/statefulness.py +0 -1
- synth_ai/zyk/lms/structured_outputs/__init__.py +0 -0
- synth_ai/zyk/lms/structured_outputs/handler.py +0 -442
- synth_ai/zyk/lms/structured_outputs/inject.py +0 -314
- synth_ai/zyk/lms/structured_outputs/rehabilitate.py +0 -187
- synth_ai/zyk/lms/tools/base.py +0 -104
- synth_ai/zyk/lms/vendors/__init__.py +0 -0
- synth_ai/zyk/lms/vendors/base.py +0 -31
- synth_ai/zyk/lms/vendors/constants.py +0 -22
- synth_ai/zyk/lms/vendors/core/__init__.py +0 -0
- synth_ai/zyk/lms/vendors/core/anthropic_api.py +0 -413
- synth_ai/zyk/lms/vendors/core/gemini_api.py +0 -306
- synth_ai/zyk/lms/vendors/core/mistral_api.py +0 -327
- synth_ai/zyk/lms/vendors/core/openai_api.py +0 -185
- synth_ai/zyk/lms/vendors/local/__init__.py +0 -0
- synth_ai/zyk/lms/vendors/local/ollama.py +0 -0
- synth_ai/zyk/lms/vendors/openai_standard.py +0 -375
- synth_ai/zyk/lms/vendors/retries.py +0 -3
- synth_ai/zyk/lms/vendors/supported/__init__.py +0 -0
- synth_ai/zyk/lms/vendors/supported/deepseek.py +0 -73
- synth_ai/zyk/lms/vendors/supported/groq.py +0 -16
- synth_ai/zyk/lms/vendors/supported/ollama.py +0 -14
- synth_ai/zyk/lms/vendors/supported/together.py +0 -11
- synth_ai-0.1.0.dev38.dist-info/RECORD +0 -67
- synth_ai-0.1.0.dev38.dist-info/top_level.txt +0 -4
- tests/test_agent.py +0 -538
- tests/test_recursive_structured_outputs.py +0 -180
- tests/test_structured_outputs.py +0 -100
- {synth_ai-0.1.0.dev38.dist-info → synth_ai-0.1.0.dev49.dist-info}/licenses/LICENSE +0 -0
synth_ai/zyk/lms/core/vendor_clients.py
DELETED
@@ -1,85 +0,0 @@
-import re
-from typing import Any, List, Pattern
-
-from synth_ai.zyk.lms.core.all import (
-    AnthropicClient,
-    DeepSeekClient,
-    GeminiClient,
-    GroqAPI,
-    MistralAPI,
-    # OpenAIClient,
-    OpenAIStructuredOutputClient,
-    TogetherClient,
-)
-
-openai_naming_regexes: List[Pattern] = [
-    re.compile(r"^(ft:)?(o[1,3,4,5](-.*)?|gpt-.*)$"),
-]
-openai_formatting_model_regexes: List[Pattern] = [
-    re.compile(r"^(ft:)?gpt-4o(-.*)?$"),
-]
-anthropic_naming_regexes: List[Pattern] = [
-    re.compile(r"^claude-.*$"),
-]
-gemini_naming_regexes: List[Pattern] = [
-    re.compile(r"^gemini-.*$"),
-    re.compile(r"^gemma[2-9].*$"),
-]
-deepseek_naming_regexes: List[Pattern] = [
-    re.compile(r"^deepseek-.*$"),
-]
-together_naming_regexes: List[Pattern] = [
-    re.compile(r"^.*\/.*$"),
-]
-
-groq_naming_regexes: List[Pattern] = [
-    re.compile(r"^llama-3.3-70b-versatile$"),
-    re.compile(r"^llama-3.1-8b-instant$"),
-    re.compile(r"^qwen-2.5-32b$"),
-    re.compile(r"^deepseek-r1-distill-qwen-32b$"),
-    re.compile(r"^deepseek-r1-distill-llama-70b-specdec$"),
-    re.compile(r"^deepseek-r1-distill-llama-70b$"),
-    re.compile(r"^llama-3.3-70b-specdec$"),
-    re.compile(r"^llama-3.2-1b-preview$"),
-    re.compile(r"^llama-3.2-3b-preview$"),
-    re.compile(r"^llama-3.2-11b-vision-preview$"),
-    re.compile(r"^llama-3.2-90b-vision-preview$"),
-]
-
-mistral_naming_regexes: List[Pattern] = [
-    re.compile(r"^mistral-.*$"),
-]
-
-
-def get_client(
-    model_name: str,
-    with_formatting: bool = False,
-    synth_logging: bool = True,
-) -> Any:
-    # print("With formatting", with_formatting)
-    if any(regex.match(model_name) for regex in openai_naming_regexes):
-        # print("Returning OpenAIStructuredOutputClient")
-        return OpenAIStructuredOutputClient(
-            synth_logging=synth_logging,
-        )
-    elif any(regex.match(model_name) for regex in anthropic_naming_regexes):
-        if with_formatting:
-            client = AnthropicClient()
-            client._hit_api_async_structured_output = OpenAIStructuredOutputClient(
-                synth_logging=synth_logging
-            )._hit_api_async
-            return client
-        else:
-            return AnthropicClient()
-    elif any(regex.match(model_name) for regex in gemini_naming_regexes):
-        return GeminiClient()
-    elif any(regex.match(model_name) for regex in deepseek_naming_regexes):
-        return DeepSeekClient()
-    elif any(regex.match(model_name) for regex in together_naming_regexes):
-        return TogetherClient()
-    elif any(regex.match(model_name) for regex in groq_naming_regexes):
-        return GroqAPI()
-    elif any(regex.match(model_name) for regex in mistral_naming_regexes):
-        return MistralAPI()
-    else:
-        raise ValueError(f"Invalid model name: {model_name}")
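For context on what is being removed: get_client routed a bare model-name string to a vendor client by taking the first regex list that matched, so ordering matters (the catch-all org/model slug pattern for Together sits after the more specific vendors). A minimal runnable sketch of that routing idea follows; vendor labels stand in for the real client classes, so the names are illustrative. Note that the original character class o[1,3,4,5] also matches a literal comma, which the sketch tightens to o[1345].

import re
from typing import List, Pattern, Tuple

# Each vendor owns a list of model-name regexes; first match wins,
# so more specific patterns come before the org/model catch-all.
ROUTES: List[Tuple[str, List[Pattern]]] = [
    ("openai", [re.compile(r"^(ft:)?(o[1345](-.*)?|gpt-.*)$")]),
    ("anthropic", [re.compile(r"^claude-.*$")]),
    ("gemini", [re.compile(r"^gemini-.*$"), re.compile(r"^gemma[2-9].*$")]),
    ("deepseek", [re.compile(r"^deepseek-.*$")]),
    ("together", [re.compile(r"^.*/.*$")]),  # org/model slugs
]

def route(model_name: str) -> str:
    for vendor, patterns in ROUTES:
        if any(p.match(model_name) for p in patterns):
            return vendor
    raise ValueError(f"Invalid model name: {model_name}")

assert route("gpt-4o-mini") == "openai"
assert route("claude-3-5-sonnet-latest") == "anthropic"
assert route("meta-llama/Llama-3.3-70B-Instruct") == "together"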
synth_ai/zyk/lms/cost/__init__.py
File without changes

synth_ai/zyk/lms/cost/monitor.py
DELETED
@@ -1 +0,0 @@
-#TODO
synth_ai/zyk/lms/cost/statefulness.py
DELETED
@@ -1 +0,0 @@
-# Maybe some kind of ephemeral cache
synth_ai/zyk/lms/structured_outputs/__init__.py
File without changes
synth_ai/zyk/lms/structured_outputs/handler.py
DELETED
@@ -1,442 +0,0 @@
-import logging
-import time
-from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, List, Literal, Optional, Union
-
-from pydantic import BaseModel
-
-from synth_ai.zyk.lms.core.exceptions import StructuredOutputCoercionFailureException
-from synth_ai.zyk.lms.structured_outputs.inject import (
-    inject_structured_output_instructions,
-)
-from synth_ai.zyk.lms.structured_outputs.rehabilitate import (
-    fix_errant_forced_async,
-    fix_errant_forced_sync,
-    pull_out_structured_output,
-)
-from synth_ai.zyk.lms.vendors.base import BaseLMResponse, VendorBase
-from synth_ai.zyk.lms.vendors.constants import SPECIAL_BASE_TEMPS
-
-logger = logging.getLogger(__name__)
-
-
-class StructuredHandlerBase(ABC):
-    core_client: VendorBase
-    retry_client: VendorBase
-    handler_params: Dict[str, Any]
-    structured_output_mode: Literal["stringified_json", "forced_json"]
-
-    def __init__(
-        self,
-        core_client: VendorBase,
-        retry_client: VendorBase,
-        handler_params: Optional[Dict[str, Any]] = None,
-        structured_output_mode: Literal[
-            "stringified_json", "forced_json"
-        ] = "stringified_json",
-    ):
-        self.core_client = core_client
-        self.retry_client = retry_client
-        self.handler_params = (
-            handler_params if handler_params is not None else {"retries": 3}
-        )
-        self.structured_output_mode = structured_output_mode
-
-    async def call_async(
-        self,
-        messages: List[Dict[str, Any]],
-        model: str,
-        response_model: BaseModel,
-        temperature: float = 0.0,
-        use_ephemeral_cache_only: bool = False,
-        reasoning_effort: str = "high",
-    ) -> BaseLMResponse:
-        if temperature == 0.0:
-            temperature = SPECIAL_BASE_TEMPS.get(model, 0.0)
-        # print("Calling from base")
-        return await self._process_call_async(
-            messages=messages,
-            model=model,
-            response_model=response_model,
-            api_call_method=self.core_client._hit_api_async_structured_output
-            if (not not response_model and self.structured_output_mode == "forced_json")
-            else self.core_client._hit_api_async,
-            temperature=temperature,
-            use_ephemeral_cache_only=use_ephemeral_cache_only,
-            reasoning_effort=reasoning_effort,
-        )
-
-    def call_sync(
-        self,
-        messages: List[Dict[str, Any]],
-        response_model: BaseModel,
-        model: str,
-        temperature: float = 0.0,
-        use_ephemeral_cache_only: bool = False,
-        reasoning_effort: str = "high",
-    ) -> BaseLMResponse:
-        if temperature == 0.0:
-            temperature = SPECIAL_BASE_TEMPS.get(model, 0.0)
-        return self._process_call_sync(
-            messages=messages,
-            model=model,
-            response_model=response_model,
-            api_call_method=self.core_client._hit_api_sync_structured_output
-            if (not not response_model and self.structured_output_mode == "forced_json")
-            else self.core_client._hit_api_sync,
-            temperature=temperature,
-            use_ephemeral_cache_only=use_ephemeral_cache_only,
-            reasoning_effort=reasoning_effort,
-        )
-
-    @abstractmethod
-    async def _process_call_async(
-        self,
-        messages: List[Dict[str, Any]],
-        model: str,
-        response_model: BaseModel,
-        api_call_method,
-        use_ephemeral_cache_only: bool = False,
-        reasoning_effort: str = "high",
-    ) -> BaseLMResponse:
-        pass
-
-    @abstractmethod
-    def _process_call_sync(
-        self,
-        messages: List[Dict[str, Any]],
-        model: str,
-        response_model: BaseModel,
-        api_call_method,
-        use_ephemeral_cache_only: bool = False,
-        reasoning_effort: str = "high",
-    ) -> BaseLMResponse:
-        pass
-
-
-class StringifiedJSONHandler(StructuredHandlerBase):
-    core_client: VendorBase
-    retry_client: VendorBase
-    handler_params: Dict[str, Any]
-
-    def __init__(
-        self,
-        core_client: VendorBase,
-        retry_client: VendorBase,
-        handler_params: Dict[str, Any] = {"retries": 3},
-    ):
-        super().__init__(
-            core_client,
-            retry_client,
-            handler_params,
-            structured_output_mode="stringified_json",
-        )
-
-    async def _process_call_async(
-        self,
-        messages: List[Dict[str, Any]],
-        model: str,
-        response_model: BaseModel,
-        temperature: float,
-        api_call_method: Callable,
-        use_ephemeral_cache_only: bool = False,
-        reasoning_effort: str = "high",
-    ) -> BaseLMResponse:
-        logger.info(f"Processing structured output call for model: {model}")
-        assert callable(api_call_method), "api_call_method must be a callable"
-        assert (
-            response_model is not None
-        ), "Don't use this handler for unstructured outputs"
-        remaining_retries = self.handler_params.get("retries", 2)
-        previously_failed_error_messages = []
-        structured_output = None
-
-        while remaining_retries > 0:
-            messages_with_json_formatting_instructions = (
-                inject_structured_output_instructions(
-                    messages=messages,
-                    response_model=response_model,
-                    previously_failed_error_messages=previously_failed_error_messages,
-                )
-            )
-            t0 = time.time()
-            raw_text_response_or_cached_hit = await api_call_method(
-                messages=messages_with_json_formatting_instructions,
-                model=model,
-                lm_config={"response_model": None, "temperature": temperature},
-                use_ephemeral_cache_only=use_ephemeral_cache_only,
-                reasoning_effort=reasoning_effort,
-            )
-            logger.debug(f"Time to get response: {time.time() - t0:.2f}s")
-
-            # Check if we got a cached BaseLMResponse
-            assert (
-                type(raw_text_response_or_cached_hit) in [str, BaseLMResponse]
-            ), f"Expected str or BaseLMResponse, got {type(raw_text_response_or_cached_hit)}"
-            if type(raw_text_response_or_cached_hit) == BaseLMResponse:
-                # print("Got cached hit, returning directly")
-                raw_text_response = raw_text_response_or_cached_hit.raw_response
-            else:
-                raw_text_response = raw_text_response_or_cached_hit
-            logger.debug(f"Raw response from model:\n{raw_text_response}")
-
-            # print("Trying to parse structured output")
-            try:
-                structured_output = pull_out_structured_output(
-                    raw_text_response, response_model
-                )
-
-                # print("Successfully parsed structured output on first attempt")
-                break
-            except Exception as e:
-                logger.warning(f"Failed to parse structured output: {str(e)}")
-                try:
-                    # print("Attempting to fix with forced JSON parser")
-                    structured_output = await fix_errant_forced_async(
-                        messages_with_json_formatting_instructions,
-                        raw_text_response,
-                        response_model,
-                        "gpt-4o-mini",
-                    )
-                    assert isinstance(structured_output, BaseModel), "Structured output must be a Pydantic model"
-                    assert not isinstance(structured_output, BaseLMResponse), "Got BaseLMResponse instead of Pydantic model"
-                    # print("Successfully fixed and parsed structured output")
-                    break
-                except Exception as e:
-                    logger.error(f"Failed to fix structured output: {str(e)}")
-                    previously_failed_error_messages.append(
-                        f"Generated attempt and got error. Attempt:\n\n{raw_text_response}\n\nError:\n\n{e}"
-                    )
-                    remaining_retries -= 1
-                    logger.warning(f"Retries remaining: {remaining_retries}")
-
-        if structured_output is None:
-            logger.error("Failed to get structured output after all retries")
-            raise StructuredOutputCoercionFailureException(
-                "Failed to get structured output"
-            )
-        # print("Successfully parsed structured output")
-        # print(structured_output)
-        assert isinstance(structured_output, BaseModel), "Structured output must be a Pydantic model"
-        assert not isinstance(structured_output, BaseLMResponse), "Got BaseLMResponse instead of Pydantic model"
-        return BaseLMResponse(
-            raw_response=raw_text_response,
-            structured_output=structured_output,
-            tool_calls=None,
-        )
-
-    def _process_call_sync(
-        self,
-        messages: List[Dict[str, Any]],
-        model: str,
-        response_model: BaseModel,
-        temperature: float,
-        api_call_method: Callable,
-        use_ephemeral_cache_only: bool = False,
-        reasoning_effort: str = "high",
-    ) -> BaseLMResponse:
-        logger.info(f"Processing structured output call for model: {model}")
-        assert callable(api_call_method), "api_call_method must be a callable"
-        assert (
-            response_model is not None
-        ), "Don't use this handler for unstructured outputs"
-        remaining_retries = self.handler_params.get("retries", 2)
-        previously_failed_error_messages = []
-        structured_output = None
-
-        while remaining_retries > 0:
-            messages_with_json_formatting_instructions = (
-                inject_structured_output_instructions(
-                    messages=messages,
-                    response_model=response_model,
-                    previously_failed_error_messages=previously_failed_error_messages,
-                )
-            )
-            t0 = time.time()
-            raw_text_response_or_cached_hit = api_call_method(
-                messages=messages_with_json_formatting_instructions,
-                model=model,
-                lm_config={"response_model": None, "temperature": temperature},
-                use_ephemeral_cache_only=use_ephemeral_cache_only,
-                reasoning_effort=reasoning_effort,
-            )
-            logger.debug(f"Time to get response: {time.time() - t0:.2f}s")
-
-            # Check if we got a cached BaseLMResponse
-            assert (
-                type(raw_text_response_or_cached_hit) in [str, BaseLMResponse]
-            ), f"Expected str or BaseLMResponse, got {type(raw_text_response_or_cached_hit)}"
-            if type(raw_text_response_or_cached_hit) == BaseLMResponse:
-                logger.info("Got cached hit, returning directly")
-                raw_text_response = raw_text_response_or_cached_hit.raw_response
-            else:
-                raw_text_response = raw_text_response_or_cached_hit
-            logger.debug(f"Raw response from model:\n{raw_text_response}")
-
-            try:
-                structured_output = pull_out_structured_output(
-                    raw_text_response, response_model
-                )
-                # print("Successfully parsed structured output on first attempt")
-                break
-            except Exception as e:
-                logger.warning(f"Failed to parse structured output: {str(e)}")
-                try:
-                    # print("Attempting to fix with forced JSON parser")
-                    structured_output = fix_errant_forced_sync(
-                        raw_text_response, response_model, "gpt-4o-mini"
-                    )
-                    # print("Successfully fixed and parsed structured output")
-                    break
-                except Exception as e:
-                    logger.error(f"Failed to fix structured output: {str(e)}")
-                    previously_failed_error_messages.append(
-                        f"Generated attempt and got error. Attempt:\n\n{raw_text_response}\n\nError:\n\n{e}"
-                    )
-                    remaining_retries -= 1
-                    logger.warning(f"Retries remaining: {remaining_retries}")
-
-        print("Successfully parsed structured output")
-        print(structured_output)
-        if structured_output is None:
-            logger.error("Failed to get structured output after all retries")
-            raise StructuredOutputCoercionFailureException(
-                "Failed to get structured output"
-            )
-        return BaseLMResponse(
-            raw_response=raw_text_response,
-            structured_output=structured_output,
-            tool_calls=None,
-        )
-
-
-class ForcedJSONHandler(StructuredHandlerBase):
-    core_client: VendorBase
-    retry_client: VendorBase
-    handler_params: Dict[str, Any]
-
-    def __init__(
-        self,
-        core_client: VendorBase,
-        retry_client: VendorBase,
-        handler_params: Dict[str, Any] = {},
-        reasoning_effort: str = "high",
-    ):
-        super().__init__(
-            core_client,
-            retry_client,
-            handler_params,
-            structured_output_mode="forced_json",
-        )
-        self.reasoning_effort = reasoning_effort
-
-    async def _process_call_async(
-        self,
-        messages: List[Dict[str, Any]],
-        model: str,
-        response_model: BaseModel,
-        api_call_method: Callable,
-        temperature: float = 0.0,
-        use_ephemeral_cache_only: bool = False,
-        reasoning_effort: str = "high",
-    ) -> BaseLMResponse:
-        # print("Forced JSON")
-        assert (
-            response_model is not None
-        ), "Don't use this handler for unstructured outputs"
-        return await api_call_method(
-            messages=messages,
-            model=model,
-            response_model=response_model,
-            temperature=temperature,
-            use_ephemeral_cache_only=use_ephemeral_cache_only,
-            reasoning_effort=reasoning_effort,
-        )
-
-    def _process_call_sync(
-        self,
-        messages: List[Dict[str, Any]],
-        model: str,
-        response_model: BaseModel,
-        api_call_method: Callable,
-        temperature: float = 0.0,
-        use_ephemeral_cache_only: bool = False,
-        reasoning_effort: str = "high",
-    ) -> BaseLMResponse:
-        assert (
-            response_model is not None
-        ), "Don't use this handler for unstructured outputs"
-        return api_call_method(
-            messages=messages,
-            model=model,
-            response_model=response_model,
-            temperature=temperature,
-            use_ephemeral_cache_only=use_ephemeral_cache_only,
-            reasoning_effort=reasoning_effort,
-        )
-
-
-class StructuredOutputHandler:
-    handler: Union[StringifiedJSONHandler, ForcedJSONHandler]
-    mode: Literal["stringified_json", "forced_json"]
-    handler_params: Dict[str, Any]
-
-    def __init__(
-        self,
-        core_client: VendorBase,
-        retry_client: VendorBase,
-        mode: Literal["stringified_json", "forced_json"],
-        handler_params: Dict[str, Any] = {},
-    ):
-        self.mode = mode
-        if self.mode == "stringified_json":
-            self.handler = StringifiedJSONHandler(
-                core_client, retry_client, handler_params
-            )
-        elif self.mode == "forced_json":
-            # print("Forced JSON")
-            self.handler = ForcedJSONHandler(core_client, retry_client, handler_params)
-        else:
-            raise ValueError(f"Invalid mode: {mode}")
-
-    async def call_async(
-        self,
-        messages: List[Dict[str, Any]],
-        model: str,
-        response_model: BaseModel,
-        use_ephemeral_cache_only: bool = False,
-        lm_config: Dict[str, Any] = {},
-        reasoning_effort: str = "high",
-    ) -> BaseLMResponse:
-        # print("Output handler call async")
-        return await self.handler.call_async(
-            messages=messages,
-            model=model,
-            response_model=response_model,
-            temperature=lm_config.get(
-                "temperature", SPECIAL_BASE_TEMPS.get(model, 0.0)
-            ),
-            use_ephemeral_cache_only=use_ephemeral_cache_only,
-            reasoning_effort=reasoning_effort,
-        )
-
-    def call_sync(
-        self,
-        messages: List[Dict[str, Any]],
-        model: str,
-        response_model: BaseModel,
-        use_ephemeral_cache_only: bool = False,
-        lm_config: Dict[str, Any] = {},
-        reasoning_effort: str = "high",
-    ) -> BaseLMResponse:
-        # print("Output handler call sync")
-        return self.handler.call_sync(
-            messages=messages,
-            model=model,
-            response_model=response_model,
-            temperature=lm_config.get(
-                "temperature", SPECIAL_BASE_TEMPS.get(model, 0.0)
-            ),
-            use_ephemeral_cache_only=use_ephemeral_cache_only,
-            reasoning_effort=reasoning_effort,
-        )
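The deleted handler boils down to two strategies: forced_json delegates to a vendor endpoint that enforces the schema server-side, while stringified_json prompts for JSON, parses it into the Pydantic response_model, and feeds parse errors back into the next attempt. A minimal offline sketch of that retry loop follows, using the pydantic v2 API and a stubbed model call; the names here are illustrative, not the library's.

import json
from typing import Callable, List, Type

from pydantic import BaseModel, ValidationError


class Verdict(BaseModel):
    label: str
    confidence: float


def coerce_with_retries(
    call_model: Callable[[str], str],
    prompt: str,
    response_model: Type[BaseModel],
    retries: int = 3,
) -> BaseModel:
    # Ask for stringified JSON, validate against the model, and carry
    # failed attempts forward as extra instructions -- the same shape
    # as StringifiedJSONHandler's while-loop above.
    failures: List[str] = []
    for _ in range(retries):
        schema = json.dumps(response_model.model_json_schema())
        hints = "\n".join(failures)
        raw = call_model(f"{prompt}\nRespond with JSON matching: {schema}\n{hints}")
        try:
            return response_model.model_validate_json(raw)
        except ValidationError as e:
            failures.append(f"Previous attempt:\n{raw}\nError:\n{e}")
    # cf. StructuredOutputCoercionFailureException in the deleted module
    raise RuntimeError("Failed to get structured output")


# Stubbed model call so the sketch runs without network access.
def fake_model(_prompt: str) -> str:
    return '{"label": "positive", "confidence": 0.93}'

print(coerce_with_retries(fake_model, "Classify: great movie!", Verdict))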