synth-ai 0.1.0.dev50__py3-none-any.whl → 0.1.0.dev52__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/__init__.py +1 -1
- synth_ai/zyk/__init__.py +3 -0
- synth_ai/zyk/lms/__init__.py +0 -0
- synth_ai/zyk/lms/caching/__init__.py +0 -0
- synth_ai/zyk/lms/caching/constants.py +1 -0
- synth_ai/zyk/lms/caching/dbs.py +0 -0
- synth_ai/zyk/lms/caching/ephemeral.py +72 -0
- synth_ai/zyk/lms/caching/handler.py +137 -0
- synth_ai/zyk/lms/caching/initialize.py +13 -0
- synth_ai/zyk/lms/caching/persistent.py +83 -0
- synth_ai/zyk/lms/config.py +10 -0
- synth_ai/zyk/lms/constants.py +22 -0
- synth_ai/zyk/lms/core/__init__.py +0 -0
- synth_ai/zyk/lms/core/all.py +47 -0
- synth_ai/zyk/lms/core/exceptions.py +9 -0
- synth_ai/zyk/lms/core/main.py +268 -0
- synth_ai/zyk/lms/core/vendor_clients.py +85 -0
- synth_ai/zyk/lms/cost/__init__.py +0 -0
- synth_ai/zyk/lms/cost/monitor.py +1 -0
- synth_ai/zyk/lms/cost/statefulness.py +1 -0
- synth_ai/zyk/lms/structured_outputs/__init__.py +0 -0
- synth_ai/zyk/lms/structured_outputs/handler.py +441 -0
- synth_ai/zyk/lms/structured_outputs/inject.py +314 -0
- synth_ai/zyk/lms/structured_outputs/rehabilitate.py +187 -0
- synth_ai/zyk/lms/tools/base.py +118 -0
- synth_ai/zyk/lms/vendors/__init__.py +0 -0
- synth_ai/zyk/lms/vendors/base.py +31 -0
- synth_ai/zyk/lms/vendors/core/__init__.py +0 -0
- synth_ai/zyk/lms/vendors/core/anthropic_api.py +365 -0
- synth_ai/zyk/lms/vendors/core/gemini_api.py +282 -0
- synth_ai/zyk/lms/vendors/core/mistral_api.py +331 -0
- synth_ai/zyk/lms/vendors/core/openai_api.py +187 -0
- synth_ai/zyk/lms/vendors/local/__init__.py +0 -0
- synth_ai/zyk/lms/vendors/local/ollama.py +0 -0
- synth_ai/zyk/lms/vendors/openai_standard.py +345 -0
- synth_ai/zyk/lms/vendors/retries.py +3 -0
- synth_ai/zyk/lms/vendors/supported/__init__.py +0 -0
- synth_ai/zyk/lms/vendors/supported/deepseek.py +73 -0
- synth_ai/zyk/lms/vendors/supported/groq.py +16 -0
- synth_ai/zyk/lms/vendors/supported/ollama.py +14 -0
- synth_ai/zyk/lms/vendors/supported/together.py +11 -0
- {synth_ai-0.1.0.dev50.dist-info → synth_ai-0.1.0.dev52.dist-info}/METADATA +2 -1
- synth_ai-0.1.0.dev52.dist-info/RECORD +46 -0
- synth_ai-0.1.0.dev50.dist-info/RECORD +0 -6
- {synth_ai-0.1.0.dev50.dist-info → synth_ai-0.1.0.dev52.dist-info}/WHEEL +0 -0
- {synth_ai-0.1.0.dev50.dist-info → synth_ai-0.1.0.dev52.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.1.0.dev50.dist-info → synth_ai-0.1.0.dev52.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ synth_ai/zyk/lms/vendors/core/mistral_api.py
@@ -0,0 +1,331 @@
+import json
+import os
+from typing import Any, Dict, List, Optional, Tuple, Type
+
+import pydantic
+from mistralai import Mistral  # use Mistral as both sync and async client
+from pydantic import BaseModel
+
+from synth_ai.zyk.lms.caching.initialize import get_cache_handler
+from synth_ai.zyk.lms.tools.base import BaseTool
+from synth_ai.zyk.lms.vendors.base import BaseLMResponse, VendorBase
+from synth_ai.zyk.lms.constants import SPECIAL_BASE_TEMPS
+from synth_ai.zyk.lms.vendors.core.openai_api import OpenAIStructuredOutputClient
+
+# Since the mistralai package doesn't expose an exceptions module,
+# we fallback to catching all Exceptions for retry.
+MISTRAL_EXCEPTIONS_TO_RETRY: Tuple[Type[Exception], ...] = (Exception,)
+
+
+class MistralAPI(VendorBase):
+    used_for_structured_outputs: bool = True
+    exceptions_to_retry: Tuple = MISTRAL_EXCEPTIONS_TO_RETRY
+    _openai_fallback: Any
+
+    def __init__(
+        self,
+        exceptions_to_retry: Tuple[Type[Exception], ...] = MISTRAL_EXCEPTIONS_TO_RETRY,
+        used_for_structured_outputs: bool = False,
+    ):
+        self.used_for_structured_outputs = used_for_structured_outputs
+        self.exceptions_to_retry = exceptions_to_retry
+        self._openai_fallback = None
+
+    # @backoff.on_exception(
+    #     backoff.expo,
+    #     MISTRAL_EXCEPTIONS_TO_RETRY,
+    #     max_tries=BACKOFF_TOLERANCE,
+    #     on_giveup=lambda e: print(e),
+    # )
+    async def _hit_api_async(
+        self,
+        model: str,
+        messages: List[Dict[str, Any]],
+        lm_config: Dict[str, Any],
+        response_model: Optional[BaseModel] = None,
+        use_ephemeral_cache_only: bool = False,
+        reasoning_effort: str = "high",
+        tools: Optional[List[BaseTool]] = None,
+    ) -> BaseLMResponse:
+        assert (
+            lm_config.get("response_model", None) is None
+        ), "response_model is not supported for standard calls"
+        assert not (response_model and tools), "Cannot provide both response_model and tools"
+        used_cache_handler = get_cache_handler(use_ephemeral_cache_only)
+        lm_config["reasoning_effort"] = reasoning_effort
+        cache_result = used_cache_handler.hit_managed_cache(
+            model, messages, lm_config=lm_config, tools=tools
+        )
+        if cache_result:
+            assert type(cache_result) in [
+                BaseLMResponse,
+                str,
+            ], f"Expected BaseLMResponse or str, got {type(cache_result)}"
+            return (
+                cache_result
+                if type(cache_result) == BaseLMResponse
+                else BaseLMResponse(
+                    raw_response=cache_result, structured_output=None, tool_calls=None
+                )
+            )
+
+        mistral_messages = [
+            {"role": msg["role"], "content": msg["content"]} for msg in messages
+        ]
+        functions = [tool.to_mistral_tool() for tool in tools] if tools else None
+        params = {
+            "model": model,
+            "messages": mistral_messages,
+            "max_tokens": lm_config.get("max_tokens", 4096),
+            "temperature": lm_config.get(
+                "temperature", SPECIAL_BASE_TEMPS.get(model, 0)
+            ),
+            "stream": False,
+            "tool_choice": "auto" if functions else None,
+
+        }
+        if response_model:
+            params["response_format"] = response_model
+        elif tools:
+            params["tools"] = functions
+
+        async with Mistral(api_key=os.getenv("MISTRAL_API_KEY", "")) as client:
+            response = await client.chat.complete_async(**params)
+
+        message = response.choices[0].message
+        try:
+            raw_response = message.content
+        except AttributeError:
+            raw_response = ""
+
+        tool_calls = []
+        try:
+            if message.tool_calls:
+                tool_calls = [
+                    {
+                        "id": call.id,
+                        "type": "function",
+                        "function": {
+                            "name": call.function.name,
+                            "arguments": call.function.arguments,
+                        },
+                    }
+                    for call in message.tool_calls
+                ]
+        except AttributeError:
+            pass
+
+        lm_response = BaseLMResponse(
+            raw_response=raw_response,
+            structured_output=None,
+            tool_calls=tool_calls if tool_calls else None,
+        )
+        lm_config["reasoning_effort"] = reasoning_effort
+        used_cache_handler.add_to_managed_cache(
+            model, messages, lm_config=lm_config, output=lm_response, tools=tools
+        )
+        return lm_response
+
+    # @backoff.on_exception(
+    #     backoff.expo,
+    #     MISTRAL_EXCEPTIONS_TO_RETRY,
+    #     max_tries=BACKOFF_TOLERANCE,
+    #     on_giveup=lambda e: print(e),
+    # )
+    def _hit_api_sync(
+        self,
+        model: str,
+        messages: List[Dict[str, Any]],
+        lm_config: Dict[str, Any],
+        response_model: Optional[BaseModel] = None,
+        use_ephemeral_cache_only: bool = False,
+        reasoning_effort: str = "high",
+        tools: Optional[List[BaseTool]] = None,
+    ) -> BaseLMResponse:
+        assert (
+            lm_config.get("response_model", None) is None
+        ), "response_model is not supported for standard calls"
+        assert not (response_model and tools), "Cannot provide both response_model and tools"
+
+        used_cache_handler = get_cache_handler(use_ephemeral_cache_only)
+        lm_config["reasoning_effort"] = reasoning_effort
+        cache_result = used_cache_handler.hit_managed_cache(
+            model, messages, lm_config=lm_config, tools=tools
+        )
+        if cache_result:
+            assert type(cache_result) in [
+                BaseLMResponse,
+                str,
+            ], f"Expected BaseLMResponse or str, got {type(cache_result)}"
+            return (
+                cache_result
+                if type(cache_result) == BaseLMResponse
+                else BaseLMResponse(
+                    raw_response=cache_result, structured_output=None, tool_calls=None
+                )
+            )
+
+        mistral_messages = [
+            {"role": msg["role"], "content": msg["content"]} for msg in messages
+        ]
+        functions = [tool.to_mistral_tool() for tool in tools] if tools else None
+
+        params = {
+            "model": model,
+            "messages": mistral_messages,
+            "max_tokens": lm_config.get("max_tokens", 4096),
+            "temperature": lm_config.get(
+                "temperature", SPECIAL_BASE_TEMPS.get(model, 0)
+            ),
+            "stream": False,
+            "tool_choice": "auto" if functions else None,
+            #"tools": functions,
+        }
+        if response_model:
+            params["response_format"] = response_model
+        elif tools:
+            params["tools"] = functions
+
+        with Mistral(api_key=os.getenv("MISTRAL_API_KEY", "")) as client:
+            response = client.chat.complete(**params)
+
+        message = response.choices[0].message
+        try:
+            raw_response = message.content
+        except AttributeError:
+            raw_response = ""
+
+        tool_calls = []
+        try:
+            if message.tool_calls:
+                tool_calls = [
+                    {
+                        "id": call.id,
+                        "type": "function",
+                        "function": {
+                            "name": call.function.name,
+                            "arguments": call.function.arguments,
+                        },
+                    }
+                    for call in message.tool_calls
+                ]
+        except AttributeError:
+            pass
+
+        lm_response = BaseLMResponse(
+            raw_response=raw_response,
+            structured_output=None,
+            tool_calls=tool_calls if tool_calls else None,
+        )
+        lm_config["reasoning_effort"] = reasoning_effort
+        used_cache_handler.add_to_managed_cache(
+            model, messages, lm_config=lm_config, output=lm_response, tools=tools
+        )
+        return lm_response
+
+    async def _hit_api_async_structured_output(
+        self,
+        model: str,
+        messages: List[Dict[str, Any]],
+        response_model: BaseModel,
+        temperature: float,
+        use_ephemeral_cache_only: bool = False,
+    ) -> BaseLMResponse:
+        try:
+            mistral_messages = [
+                {"role": msg["role"], "content": msg["content"]} for msg in messages
+            ]
+            async with Mistral(api_key=os.getenv("MISTRAL_API_KEY", "")) as client:
+                response = await client.chat.complete_async(
+                    model=model,
+                    messages=mistral_messages,
+                    max_tokens=4096,
+                    temperature=temperature,
+                    stream=False,
+                )
+            result = response.choices[0].message.content
+            parsed = json.loads(result)
+            lm_response = BaseLMResponse(
+                raw_response="",
+                structured_output=response_model(**parsed),
+                tool_calls=None,
+            )
+            return lm_response
+        except (json.JSONDecodeError, pydantic.ValidationError):
+            if self._openai_fallback is None:
+                self._openai_fallback = OpenAIStructuredOutputClient()
+            return await self._openai_fallback._hit_api_async_structured_output(
+                model="gpt-4o",
+                messages=messages,
+                response_model=response_model,
+                temperature=temperature,
+                use_ephemeral_cache_only=use_ephemeral_cache_only,
+            )
+
+    def _hit_api_sync_structured_output(
+        self,
+        model: str,
+        messages: List[Dict[str, Any]],
+        response_model: BaseModel,
+        temperature: float,
+        use_ephemeral_cache_only: bool = False,
+    ) -> BaseLMResponse:
+        try:
+            mistral_messages = [
+                {"role": msg["role"], "content": msg["content"]} for msg in messages
+            ]
+            with Mistral(api_key=os.getenv("MISTRAL_API_KEY", "")) as client:
+                response = client.chat.complete(
+                    model=model,
+                    messages=mistral_messages,
+                    max_tokens=4096,
+                    temperature=temperature,
+                    stream=False,
+                )
+            result = response.choices[0].message.content
+            parsed = json.loads(result)
+            lm_response = BaseLMResponse(
+                raw_response="",
+                structured_output=response_model(**parsed),
+                tool_calls=None,
+            )
+            return lm_response
+        except (json.JSONDecodeError, pydantic.ValidationError):
+            print("WARNING - Falling back to OpenAI - THIS IS SLOW")
+            if self._openai_fallback is None:
+                self._openai_fallback = OpenAIStructuredOutputClient()
+            return self._openai_fallback._hit_api_sync_structured_output(
+                model="gpt-4o",
+                messages=messages,
+                response_model=response_model,
+                temperature=temperature,
+                use_ephemeral_cache_only=use_ephemeral_cache_only,
+            )
+
+
+if __name__ == "__main__":
+    import asyncio
+
+    from pydantic import BaseModel
+
+    class TestModel(BaseModel):
+        name: str
+
+    client = MistralAPI(used_for_structured_outputs=True, exceptions_to_retry=[])
+    import time
+
+    t = time.time()
+
+    async def run_async():
+        response = await client._hit_api_async_structured_output(
+            model="mistral-large-latest",
+            messages=[{"role": "user", "content": "What is the capital of the moon?"}],
+            response_model=TestModel,
+            temperature=0.0,
+        )
+        print(response)
+        return response
+
+    response = asyncio.run(run_async())
+    t2 = time.time()
+    print(f"Got {len(response.name)} chars in {t2-t} seconds")
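The structured-output path in this new Mistral client parses the model's reply with json.loads plus the supplied pydantic model, and retries through OpenAIStructuredOutputClient (gpt-4o) when decoding or validation fails. A minimal usage sketch, not part of the diff, assuming synth-ai 0.1.0.dev52 is installed, MISTRAL_API_KEY is exported, and a hypothetical Capital schema stands in for a real response model (mirroring the __main__ block above):

# Usage sketch (not part of the diff): assumes synth-ai 0.1.0.dev52 is installed
# and MISTRAL_API_KEY is exported; Capital is a hypothetical response schema.
from pydantic import BaseModel

from synth_ai.zyk.lms.vendors.core.mistral_api import MistralAPI


class Capital(BaseModel):
    name: str


client = MistralAPI()
# Returns a BaseLMResponse; structured_output holds the parsed Capital instance.
# On JSONDecodeError / ValidationError the call falls back to gpt-4o.
result = client._hit_api_sync_structured_output(
    model="mistral-large-latest",
    messages=[{"role": "user", "content": "What is the capital of the moon?"}],
    response_model=Capital,
    temperature=0.0,
)
print(result.structured_output)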
--- /dev/null
+++ synth_ai/zyk/lms/vendors/core/openai_api.py
@@ -0,0 +1,187 @@
+import json
+from typing import Any, Dict, List, Optional, Tuple, Type
+
+import openai
+import pydantic_core
+
+# from openai import AsyncOpenAI, OpenAI
+from pydantic import BaseModel
+
+from synth_ai.zyk.lms.caching.initialize import get_cache_handler
+from synth_ai.zyk.lms.tools.base import BaseTool
+from synth_ai.zyk.lms.vendors.base import BaseLMResponse
+from synth_ai.zyk.lms.constants import SPECIAL_BASE_TEMPS, OPENAI_REASONING_MODELS
+from synth_ai.zyk.lms.vendors.openai_standard import OpenAIStandard
+
+OPENAI_EXCEPTIONS_TO_RETRY: Tuple[Type[Exception], ...] = (
+    pydantic_core._pydantic_core.ValidationError,
+    openai.OpenAIError,
+    openai.APIConnectionError,
+    openai.RateLimitError,
+    openai.APIError,
+    openai.Timeout,
+    openai.InternalServerError,
+    openai.APIConnectionError,
+)
+
+
+class OpenAIStructuredOutputClient(OpenAIStandard):
+    def __init__(self, synth_logging: bool = True):
+        if synth_logging:
+            # print("Using synth logging - OpenAIStructuredOutputClient")
+            from synth_sdk import AsyncOpenAI, OpenAI
+        else:
+            # print("Not using synth logging - OpenAIStructuredOutputClient")
+            from openai import AsyncOpenAI, OpenAI
+
+        super().__init__(
+            used_for_structured_outputs=True,
+            exceptions_to_retry=OPENAI_EXCEPTIONS_TO_RETRY,
+            sync_client=OpenAI(),
+            async_client=AsyncOpenAI(),
+        )
+
+    async def _hit_api_async_structured_output(
+        self,
+        model: str,
+        messages: List[Dict[str, Any]],
+        response_model: BaseModel,
+        temperature: float,
+        use_ephemeral_cache_only: bool = False,
+        tools: Optional[List[BaseTool]] = None,
+        reasoning_effort: str = "high",
+    ) -> str:
+        if tools:
+            raise ValueError("Tools are not supported for async structured output")
+        # "Hit client")
+        lm_config = {"temperature": temperature, "response_model": response_model, "reasoning_effort": reasoning_effort}
+        used_cache_handler = get_cache_handler(
+            use_ephemeral_cache_only=use_ephemeral_cache_only
+        )
+        cache_result = used_cache_handler.hit_managed_cache(
+            model, messages, lm_config=lm_config
+        )
+        if cache_result:
+            # print("Hit cache")
+            assert type(cache_result) in [
+                dict,
+                BaseLMResponse,
+            ], f"Expected dict or BaseLMResponse, got {type(cache_result)}"
+            return (
+                cache_result["response"] if type(cache_result) == dict else cache_result
+            )
+        if model in OPENAI_REASONING_MODELS:
+            output = await self.async_client.beta.chat.completions.parse(
+                model=model,
+                messages=messages,
+                temperature=lm_config.get(
+                    "temperature", SPECIAL_BASE_TEMPS.get(model, 0)
+                ),
+                response_format=response_model,
+                reasoning_effort=reasoning_effort,
+            )
+        else:
+            output = await self.async_client.beta.chat.completions.parse(
+                model=model,
+                messages=messages,
+                response_format=response_model,
+            )
+        # "Output", output)
+        api_result = response_model(**json.loads(output.choices[0].message.content))
+        lm_response = BaseLMResponse(
+            raw_response="",
+            structured_output=api_result,
+            tool_calls=None,
+        )
+        used_cache_handler.add_to_managed_cache(
+            model, messages, lm_config, output=lm_response
+        )
+        return lm_response
+
+    def _hit_api_sync_structured_output(
+        self,
+        model: str,
+        messages: List[Dict[str, Any]],
+        response_model: BaseModel,
+        temperature: float,
+        use_ephemeral_cache_only: bool = False,
+        tools: Optional[List[BaseTool]] = None,
+        reasoning_effort: str = "high",
+    ) -> str:
+        if tools:
+            raise ValueError("Tools are not supported for sync structured output")
+        lm_config = {"temperature": temperature, "response_model": response_model, "reasoning_effort": reasoning_effort}
+        used_cache_handler = get_cache_handler(
+            use_ephemeral_cache_only=use_ephemeral_cache_only
+        )
+        cache_result = used_cache_handler.hit_managed_cache(
+            model, messages, lm_config=lm_config
+        )
+        if cache_result:
+            assert type(cache_result) in [
+                dict,
+                BaseLMResponse,
+            ], f"Expected dict or BaseLMResponse, got {type(cache_result)}"
+            return (
+                cache_result["response"] if type(cache_result) == dict else cache_result
+            )
+        if model in OPENAI_REASONING_MODELS:
+            output = self.sync_client.beta.chat.completions.parse(
+                model=model,
+                messages=messages,
+                temperature=lm_config.get(
+                    "temperature", SPECIAL_BASE_TEMPS.get(model, 0)
+                ),
+                response_format=response_model,
+                reasoning_effort=reasoning_effort,
+            )
+        else:
+            output = self.sync_client.beta.chat.completions.parse(
+                model=model,
+                messages=messages,
+                response_format=response_model,
+            )
+        api_result = response_model(**json.loads(output.choices[0].message.content))
+
+        lm_response = BaseLMResponse(
+            raw_response="",
+            structured_output=api_result,
+            tool_calls=None,
+        )
+        used_cache_handler.add_to_managed_cache(
+            model, messages, lm_config=lm_config, output=lm_response
+        )
+        return lm_response
+
+
+class OpenAIPrivate(OpenAIStandard):
+    def __init__(self, synth_logging: bool = True):
+        if synth_logging:
+            # print("Using synth logging - OpenAIPrivate")
+            from synth_sdk import AsyncOpenAI, OpenAI
+        else:
+            # print("Not using synth logging - OpenAIPrivate")
+            from openai import AsyncOpenAI, OpenAI
+
+        self.sync_client = OpenAI()
+        self.async_client = AsyncOpenAI()
+
+
+if __name__ == "__main__":
+    client = OpenAIStructuredOutputClient(
+        sync_client=openai.OpenAI(),
+        async_client=openai.AsyncOpenAI(),
+        used_for_structured_outputs=True,
+        exceptions_to_retry=[],
+    )
+
+    class TestModel(BaseModel):
+        name: str
+
+    sync_model_response = client._hit_api_sync_structured_output(
+        model="gpt-4o-mini-2024-07-18",
+        messages=[{"role": "user", "content": " What is the capital of the moon?"}],
+        response_model=TestModel,
+        temperature=0.0,
+    )
+    # print(sync_model_response)
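Similarly, the OpenAIStructuredOutputClient introduced in openai_api.py routes structured output through beta.chat.completions.parse and the shared cache handler. A minimal sketch, not part of the diff, assuming OPENAI_API_KEY is configured; note that the __main__ block in the diff passes constructor arguments that the __init__ shown above does not accept, so this sketch uses the synth_logging flag instead:

# Usage sketch (not part of the diff): assumes OPENAI_API_KEY is configured.
from pydantic import BaseModel

from synth_ai.zyk.lms.vendors.core.openai_api import OpenAIStructuredOutputClient


class TestModel(BaseModel):
    name: str


# synth_logging=False uses the plain openai clients instead of the synth_sdk wrappers.
client = OpenAIStructuredOutputClient(synth_logging=False)
response = client._hit_api_sync_structured_output(
    model="gpt-4o-mini-2024-07-18",
    messages=[{"role": "user", "content": "What is the capital of the moon?"}],
    response_model=TestModel,
    temperature=0.0,
)
print(response.structured_output)  # parsed TestModel instance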