synth-ai 0.1.0.dev27__py3-none-any.whl → 0.1.0.dev29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- public_tests/test_agent.py +11 -11
- public_tests/test_all_structured_outputs.py +32 -37
- public_tests/test_anthropic_structured_outputs.py +0 -0
- public_tests/test_deepseek_structured_outputs.py +0 -0
- public_tests/test_deepseek_tools.py +64 -0
- public_tests/test_gemini_structured_outputs.py +106 -0
- public_tests/test_models.py +27 -27
- public_tests/test_openai_structured_outputs.py +106 -0
- public_tests/test_reasoning_models.py +9 -7
- public_tests/test_recursive_structured_outputs.py +30 -30
- public_tests/test_structured.py +137 -0
- public_tests/test_structured_outputs.py +22 -13
- public_tests/test_text.py +160 -0
- public_tests/test_tools.py +300 -0
- synth_ai/__init__.py +1 -4
- synth_ai/zyk/__init__.py +2 -2
- synth_ai/zyk/lms/caching/ephemeral.py +54 -32
- synth_ai/zyk/lms/caching/handler.py +43 -15
- synth_ai/zyk/lms/caching/persistent.py +55 -27
- synth_ai/zyk/lms/core/main.py +26 -14
- synth_ai/zyk/lms/core/vendor_clients.py +1 -1
- synth_ai/zyk/lms/structured_outputs/handler.py +79 -45
- synth_ai/zyk/lms/structured_outputs/rehabilitate.py +3 -2
- synth_ai/zyk/lms/tools/base.py +104 -0
- synth_ai/zyk/lms/vendors/base.py +22 -6
- synth_ai/zyk/lms/vendors/core/anthropic_api.py +130 -95
- synth_ai/zyk/lms/vendors/core/gemini_api.py +153 -34
- synth_ai/zyk/lms/vendors/core/mistral_api.py +160 -54
- synth_ai/zyk/lms/vendors/core/openai_api.py +64 -53
- synth_ai/zyk/lms/vendors/openai_standard.py +197 -41
- synth_ai/zyk/lms/vendors/supported/deepseek.py +55 -0
- {synth_ai-0.1.0.dev27.dist-info → synth_ai-0.1.0.dev29.dist-info}/METADATA +2 -5
- synth_ai-0.1.0.dev29.dist-info/RECORD +65 -0
- public_tests/test_sonnet_thinking.py +0 -178
- synth_ai-0.1.0.dev27.dist-info/RECORD +0 -57
- {synth_ai-0.1.0.dev27.dist-info → synth_ai-0.1.0.dev29.dist-info}/WHEEL +0 -0
- {synth_ai-0.1.0.dev27.dist-info → synth_ai-0.1.0.dev29.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.1.0.dev27.dist-info → synth_ai-0.1.0.dev29.dist-info}/top_level.txt +0 -0
synth_ai/zyk/lms/core/main.py
CHANGED
@@ -10,7 +10,7 @@ from synth_ai.zyk.lms.core.vendor_clients import (
|
|
10
10
|
)
|
11
11
|
from synth_ai.zyk.lms.structured_outputs.handler import StructuredOutputHandler
|
12
12
|
from synth_ai.zyk.lms.vendors.base import VendorBase
|
13
|
-
|
13
|
+
from synth_ai.zyk.lms.tools.base import BaseTool
|
14
14
|
REASONING_MODELS = ["deepseek-reasoner", "o1-mini", "o1-preview", "o1", "o3"]
|
15
15
|
|
16
16
|
|
@@ -120,6 +120,7 @@ class LM:
|
|
120
120
|
images_as_bytes: List[Any] = [],
|
121
121
|
response_model: Optional[BaseModel] = None,
|
122
122
|
use_ephemeral_cache_only: bool = False,
|
123
|
+
tools: Optional[List[BaseTool]] = None,
|
123
124
|
):
|
124
125
|
assert (system_message is None) == (
|
125
126
|
user_message is None
|
@@ -127,15 +128,15 @@ class LM:
|
|
127
128
|
assert (
|
128
129
|
(messages is None) != (system_message is None)
|
129
130
|
), "Must provide either messages or system_message/user_message pair, but not both"
|
130
|
-
|
131
|
+
assert not (response_model and tools), "Cannot provide both response_model and tools"
|
131
132
|
if messages is None:
|
132
133
|
messages = build_messages(
|
133
134
|
system_message, user_message, images_as_bytes, self.model_name
|
134
135
|
)
|
135
|
-
|
136
|
+
result = None
|
136
137
|
if response_model:
|
137
138
|
try:
|
138
|
-
|
139
|
+
result = self.structured_output_handler.call_sync(
|
139
140
|
messages,
|
140
141
|
model=self.model_name,
|
141
142
|
lm_config=self.lm_config,
|
@@ -144,7 +145,7 @@ class LM:
|
|
144
145
|
)
|
145
146
|
except StructuredOutputCoercionFailureException:
|
146
147
|
# print("Falling back to backup handler")
|
147
|
-
|
148
|
+
result = self.backup_structured_output_handler.call_sync(
|
148
149
|
messages,
|
149
150
|
model=self.model_name,
|
150
151
|
lm_config=self.lm_config,
|
@@ -152,12 +153,17 @@ class LM:
|
|
152
153
|
use_ephemeral_cache_only=use_ephemeral_cache_only,
|
153
154
|
)
|
154
155
|
else:
|
155
|
-
|
156
|
+
result = self.client._hit_api_sync(
|
156
157
|
messages=messages,
|
157
158
|
model=self.model_name,
|
158
159
|
lm_config=self.lm_config,
|
159
160
|
use_ephemeral_cache_only=use_ephemeral_cache_only,
|
161
|
+
tools=tools,
|
160
162
|
)
|
163
|
+
assert isinstance(result.raw_response, str), "Raw response must be a string"
|
164
|
+
assert (isinstance(result.structured_output, BaseModel) or result.structured_output is None), "Structured output must be a Pydantic model or None"
|
165
|
+
assert (isinstance(result.tool_calls, list) or result.tool_calls is None), "Tool calls must be a list or None"
|
166
|
+
return result
|
161
167
|
|
162
168
|
async def respond_async(
|
163
169
|
self,
|
@@ -167,6 +173,7 @@ class LM:
|
|
167
173
|
images_as_bytes: List[Any] = [],
|
168
174
|
response_model: Optional[BaseModel] = None,
|
169
175
|
use_ephemeral_cache_only: bool = False,
|
176
|
+
tools: Optional[List[BaseTool]] = None,
|
170
177
|
):
|
171
178
|
# "In respond_async")
|
172
179
|
assert (system_message is None) == (
|
@@ -176,15 +183,16 @@ class LM:
|
|
176
183
|
(messages is None) != (system_message is None)
|
177
184
|
), "Must provide either messages or system_message/user_message pair, but not both"
|
178
185
|
|
186
|
+
assert not (response_model and tools), "Cannot provide both response_model and tools"
|
179
187
|
if messages is None:
|
180
188
|
messages = build_messages(
|
181
189
|
system_message, user_message, images_as_bytes, self.model_name
|
182
190
|
)
|
183
|
-
|
191
|
+
result = None
|
184
192
|
if response_model:
|
185
193
|
try:
|
186
|
-
|
187
|
-
|
194
|
+
print("Trying structured output handler")
|
195
|
+
result = await self.structured_output_handler.call_async(
|
188
196
|
messages,
|
189
197
|
model=self.model_name,
|
190
198
|
lm_config=self.lm_config,
|
@@ -192,8 +200,8 @@ class LM:
|
|
192
200
|
use_ephemeral_cache_only=use_ephemeral_cache_only,
|
193
201
|
)
|
194
202
|
except StructuredOutputCoercionFailureException:
|
195
|
-
|
196
|
-
|
203
|
+
print("Falling back to backup handler")
|
204
|
+
result = await self.backup_structured_output_handler.call_async(
|
197
205
|
messages,
|
198
206
|
model=self.model_name,
|
199
207
|
lm_config=self.lm_config,
|
@@ -201,14 +209,18 @@ class LM:
|
|
201
209
|
use_ephemeral_cache_only=use_ephemeral_cache_only,
|
202
210
|
)
|
203
211
|
else:
|
204
|
-
|
205
|
-
|
212
|
+
print("Calling API no response model")
|
213
|
+
result = await self.client._hit_api_async(
|
206
214
|
messages=messages,
|
207
215
|
model=self.model_name,
|
208
216
|
lm_config=self.lm_config,
|
209
217
|
use_ephemeral_cache_only=use_ephemeral_cache_only,
|
218
|
+
tools=tools,
|
210
219
|
)
|
211
|
-
|
220
|
+
assert isinstance(result.raw_response, str), "Raw response must be a string"
|
221
|
+
assert (isinstance(result.structured_output, BaseModel) or result.structured_output is None), "Structured output must be a Pydantic model or None"
|
222
|
+
assert (isinstance(result.tool_calls, list) or result.tool_calls is None), "Tool calls must be a list or None"
|
223
|
+
return result
|
212
224
|
|
213
225
|
if __name__ == "__main__":
|
214
226
|
import asyncio
|
@@ -1,3 +1,4 @@
|
|
1
|
+
import logging
|
1
2
|
import time
|
2
3
|
from abc import ABC, abstractmethod
|
3
4
|
from typing import Any, Callable, Dict, List, Literal, Optional, Union
|
@@ -11,13 +12,13 @@ from synth_ai.zyk.lms.structured_outputs.inject import (
|
|
11
12
|
from synth_ai.zyk.lms.structured_outputs.rehabilitate import (
|
12
13
|
fix_errant_forced_async,
|
13
14
|
fix_errant_forced_sync,
|
14
|
-
fix_errant_stringified_json_async,
|
15
|
-
fix_errant_stringified_json_sync,
|
16
15
|
pull_out_structured_output,
|
17
16
|
)
|
18
|
-
from synth_ai.zyk.lms.vendors.base import VendorBase
|
17
|
+
from synth_ai.zyk.lms.vendors.base import BaseLMResponse, VendorBase
|
19
18
|
from synth_ai.zyk.lms.vendors.constants import SPECIAL_BASE_TEMPS
|
20
19
|
|
20
|
+
logger = logging.getLogger(__name__)
|
21
|
+
|
21
22
|
|
22
23
|
class StructuredHandlerBase(ABC):
|
23
24
|
core_client: VendorBase
|
@@ -49,7 +50,7 @@ class StructuredHandlerBase(ABC):
|
|
49
50
|
temperature: float = 0.0,
|
50
51
|
use_ephemeral_cache_only: bool = False,
|
51
52
|
reasoning_effort: str = "high",
|
52
|
-
) ->
|
53
|
+
) -> BaseLMResponse:
|
53
54
|
if temperature == 0.0:
|
54
55
|
temperature = SPECIAL_BASE_TEMPS.get(model, 0.0)
|
55
56
|
# print("Calling from base")
|
@@ -73,7 +74,7 @@ class StructuredHandlerBase(ABC):
|
|
73
74
|
temperature: float = 0.0,
|
74
75
|
use_ephemeral_cache_only: bool = False,
|
75
76
|
reasoning_effort: str = "high",
|
76
|
-
) ->
|
77
|
+
) -> BaseLMResponse:
|
77
78
|
if temperature == 0.0:
|
78
79
|
temperature = SPECIAL_BASE_TEMPS.get(model, 0.0)
|
79
80
|
return self._process_call_sync(
|
@@ -97,7 +98,7 @@ class StructuredHandlerBase(ABC):
|
|
97
98
|
api_call_method,
|
98
99
|
use_ephemeral_cache_only: bool = False,
|
99
100
|
reasoning_effort: str = "high",
|
100
|
-
) ->
|
101
|
+
) -> BaseLMResponse:
|
101
102
|
pass
|
102
103
|
|
103
104
|
@abstractmethod
|
@@ -109,7 +110,7 @@ class StructuredHandlerBase(ABC):
|
|
109
110
|
api_call_method,
|
110
111
|
use_ephemeral_cache_only: bool = False,
|
111
112
|
reasoning_effort: str = "high",
|
112
|
-
) ->
|
113
|
+
) -> BaseLMResponse:
|
113
114
|
pass
|
114
115
|
|
115
116
|
|
@@ -140,11 +141,9 @@ class StringifiedJSONHandler(StructuredHandlerBase):
|
|
140
141
|
api_call_method: Callable,
|
141
142
|
use_ephemeral_cache_only: bool = False,
|
142
143
|
reasoning_effort: str = "high",
|
143
|
-
) ->
|
144
|
-
|
145
|
-
assert
|
146
|
-
api_call_method, Callable
|
147
|
-
), "api_call_method must be a callable"
|
144
|
+
) -> BaseLMResponse:
|
145
|
+
logger.info(f"Processing structured output call for model: {model}")
|
146
|
+
assert callable(api_call_method), "api_call_method must be a callable"
|
148
147
|
assert (
|
149
148
|
response_model is not None
|
150
149
|
), "Don't use this handler for unstructured outputs"
|
@@ -168,43 +167,63 @@ class StringifiedJSONHandler(StructuredHandlerBase):
|
|
168
167
|
use_ephemeral_cache_only=use_ephemeral_cache_only,
|
169
168
|
reasoning_effort=reasoning_effort,
|
170
169
|
)
|
171
|
-
|
172
|
-
|
173
|
-
|
170
|
+
logger.debug(f"Time to get response: {time.time() - t0:.2f}s")
|
171
|
+
|
172
|
+
# Check if we got a cached BaseLMResponse
|
173
|
+
assert (
|
174
|
+
type(raw_text_response_or_cached_hit) in [str, BaseLMResponse]
|
175
|
+
), f"Expected str or BaseLMResponse, got {type(raw_text_response_or_cached_hit)}"
|
176
|
+
if type(raw_text_response_or_cached_hit) == BaseLMResponse:
|
177
|
+
print("Got cached hit, returning directly")
|
178
|
+
raw_text_response = raw_text_response_or_cached_hit.raw_response
|
174
179
|
else:
|
175
180
|
raw_text_response = raw_text_response_or_cached_hit
|
181
|
+
logger.debug(f"Raw response from model:\n{raw_text_response}")
|
182
|
+
|
183
|
+
print("Trying to parse structured output")
|
176
184
|
try:
|
177
185
|
structured_output = pull_out_structured_output(
|
178
186
|
raw_text_response, response_model
|
179
187
|
)
|
188
|
+
|
189
|
+
print("Successfully parsed structured output on first attempt")
|
180
190
|
break
|
181
|
-
# except Exception as e:
|
182
|
-
# try:
|
183
|
-
# structured_output = await fix_errant_stringified_json_async(raw_text_response, response_model)
|
184
|
-
# break
|
185
191
|
except Exception as e:
|
192
|
+
logger.warning(f"Failed to parse structured output: {str(e)}")
|
186
193
|
try:
|
187
|
-
|
188
|
-
# print(f"Got error {e}, attempting to fix")
|
194
|
+
print("Attempting to fix with forced JSON parser")
|
189
195
|
structured_output = await fix_errant_forced_async(
|
190
196
|
messages_with_json_formatting_instructions,
|
191
197
|
raw_text_response,
|
192
198
|
response_model,
|
193
199
|
"gpt-4o-mini",
|
194
200
|
)
|
195
|
-
|
196
|
-
|
201
|
+
assert isinstance(structured_output, BaseModel), "Structured output must be a Pydantic model"
|
202
|
+
assert not isinstance(structured_output, BaseLMResponse), "Got BaseLMResponse instead of Pydantic model"
|
203
|
+
print("Successfully fixed and parsed structured output")
|
197
204
|
break
|
198
205
|
except Exception as e:
|
206
|
+
logger.error(f"Failed to fix structured output: {str(e)}")
|
199
207
|
previously_failed_error_messages.append(
|
200
208
|
f"Generated attempt and got error. Attempt:\n\n{raw_text_response}\n\nError:\n\n{e}"
|
201
209
|
)
|
202
210
|
remaining_retries -= 1
|
211
|
+
logger.warning(f"Retries remaining: {remaining_retries}")
|
212
|
+
|
203
213
|
if structured_output is None:
|
214
|
+
logger.error("Failed to get structured output after all retries")
|
204
215
|
raise StructuredOutputCoercionFailureException(
|
205
216
|
"Failed to get structured output"
|
206
217
|
)
|
207
|
-
|
218
|
+
print("Successfully parsed structured output")
|
219
|
+
print(structured_output)
|
220
|
+
assert isinstance(structured_output, BaseModel), "Structured output must be a Pydantic model"
|
221
|
+
assert not isinstance(structured_output, BaseLMResponse),"Got BaseLMResponse instead of Pydantic model"
|
222
|
+
return BaseLMResponse(
|
223
|
+
raw_response=raw_text_response,
|
224
|
+
structured_output=structured_output,
|
225
|
+
tool_calls=None,
|
226
|
+
)
|
208
227
|
|
209
228
|
def _process_call_sync(
|
210
229
|
self,
|
@@ -215,10 +234,9 @@ class StringifiedJSONHandler(StructuredHandlerBase):
|
|
215
234
|
api_call_method: Callable,
|
216
235
|
use_ephemeral_cache_only: bool = False,
|
217
236
|
reasoning_effort: str = "high",
|
218
|
-
) ->
|
219
|
-
|
220
|
-
|
221
|
-
), "api_call_method must be a callable"
|
237
|
+
) -> BaseLMResponse:
|
238
|
+
logger.info(f"Processing structured output call for model: {model}")
|
239
|
+
assert callable(api_call_method), "api_call_method must be a callable"
|
222
240
|
assert (
|
223
241
|
response_model is not None
|
224
242
|
), "Don't use this handler for unstructured outputs"
|
@@ -234,7 +252,7 @@ class StringifiedJSONHandler(StructuredHandlerBase):
|
|
234
252
|
previously_failed_error_messages=previously_failed_error_messages,
|
235
253
|
)
|
236
254
|
)
|
237
|
-
|
255
|
+
t0 = time.time()
|
238
256
|
raw_text_response_or_cached_hit = api_call_method(
|
239
257
|
messages=messages_with_json_formatting_instructions,
|
240
258
|
model=model,
|
@@ -242,39 +260,54 @@ class StringifiedJSONHandler(StructuredHandlerBase):
|
|
242
260
|
use_ephemeral_cache_only=use_ephemeral_cache_only,
|
243
261
|
reasoning_effort=reasoning_effort,
|
244
262
|
)
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
263
|
+
logger.debug(f"Time to get response: {time.time() - t0:.2f}s")
|
264
|
+
|
265
|
+
# Check if we got a cached BaseLMResponse
|
266
|
+
assert (
|
267
|
+
type(raw_text_response_or_cached_hit) in [str, BaseLMResponse]
|
268
|
+
), f"Expected str or BaseLMResponse, got {type(raw_text_response_or_cached_hit)}"
|
269
|
+
if type(raw_text_response_or_cached_hit) == BaseLMResponse:
|
270
|
+
logger.info("Got cached hit, returning directly")
|
271
|
+
raw_text_response = raw_text_response_or_cached_hit.raw_response
|
272
|
+
else:
|
249
273
|
raw_text_response = raw_text_response_or_cached_hit
|
274
|
+
logger.debug(f"Raw response from model:\n{raw_text_response}")
|
275
|
+
|
250
276
|
try:
|
251
277
|
structured_output = pull_out_structured_output(
|
252
278
|
raw_text_response, response_model
|
253
279
|
)
|
280
|
+
print("Successfully parsed structured output on first attempt")
|
254
281
|
break
|
255
|
-
# except Exception:
|
256
|
-
# try:
|
257
|
-
# structured_output = fix_errant_stringified_json_sync(raw_text_response, response_model)
|
258
|
-
# break
|
259
282
|
except Exception as e:
|
283
|
+
logger.warning(f"Failed to parse structured output: {str(e)}")
|
260
284
|
try:
|
261
|
-
|
262
|
-
# print(f"Got error {e}, attempting to fix")
|
285
|
+
print("Attempting to fix with forced JSON parser")
|
263
286
|
structured_output = fix_errant_forced_sync(
|
264
287
|
raw_text_response, response_model, "gpt-4o-mini"
|
265
288
|
)
|
289
|
+
print("Successfully fixed and parsed structured output")
|
266
290
|
break
|
267
|
-
# print(f"Time to fix: {time.time() - t0}")
|
268
291
|
except Exception as e:
|
292
|
+
logger.error(f"Failed to fix structured output: {str(e)}")
|
269
293
|
previously_failed_error_messages.append(
|
270
294
|
f"Generated attempt and got error. Attempt:\n\n{raw_text_response}\n\nError:\n\n{e}"
|
271
295
|
)
|
272
296
|
remaining_retries -= 1
|
297
|
+
logger.warning(f"Retries remaining: {remaining_retries}")
|
298
|
+
|
299
|
+
print("Successfully parsed structured output")
|
300
|
+
print(structured_output)
|
273
301
|
if structured_output is None:
|
302
|
+
logger.error("Failed to get structured output after all retries")
|
274
303
|
raise StructuredOutputCoercionFailureException(
|
275
304
|
"Failed to get structured output"
|
276
305
|
)
|
277
|
-
return
|
306
|
+
return BaseLMResponse(
|
307
|
+
raw_response=raw_text_response,
|
308
|
+
structured_output=structured_output,
|
309
|
+
tool_calls=None,
|
310
|
+
)
|
278
311
|
|
279
312
|
|
280
313
|
class ForcedJSONHandler(StructuredHandlerBase):
|
@@ -296,6 +329,7 @@ class ForcedJSONHandler(StructuredHandlerBase):
|
|
296
329
|
structured_output_mode="forced_json",
|
297
330
|
)
|
298
331
|
self.reasoning_effort = reasoning_effort
|
332
|
+
|
299
333
|
async def _process_call_async(
|
300
334
|
self,
|
301
335
|
messages: List[Dict[str, Any]],
|
@@ -305,7 +339,7 @@ class ForcedJSONHandler(StructuredHandlerBase):
|
|
305
339
|
temperature: float = 0.0,
|
306
340
|
use_ephemeral_cache_only: bool = False,
|
307
341
|
reasoning_effort: str = "high",
|
308
|
-
) ->
|
342
|
+
) -> BaseLMResponse:
|
309
343
|
# print("Forced JSON")
|
310
344
|
assert (
|
311
345
|
response_model is not None
|
@@ -328,7 +362,7 @@ class ForcedJSONHandler(StructuredHandlerBase):
|
|
328
362
|
temperature: float = 0.0,
|
329
363
|
use_ephemeral_cache_only: bool = False,
|
330
364
|
reasoning_effort: str = "high",
|
331
|
-
|
365
|
+
) -> BaseLMResponse:
|
332
366
|
assert (
|
333
367
|
response_model is not None
|
334
368
|
), "Don't use this handler for unstructured outputs"
|
@@ -373,7 +407,7 @@ class StructuredOutputHandler:
|
|
373
407
|
use_ephemeral_cache_only: bool = False,
|
374
408
|
lm_config: Dict[str, Any] = {},
|
375
409
|
reasoning_effort: str = "high",
|
376
|
-
) ->
|
410
|
+
) -> BaseLMResponse:
|
377
411
|
# print("Output handler call async")
|
378
412
|
return await self.handler.call_async(
|
379
413
|
messages=messages,
|
@@ -394,7 +428,7 @@ class StructuredOutputHandler:
|
|
394
428
|
use_ephemeral_cache_only: bool = False,
|
395
429
|
lm_config: Dict[str, Any] = {},
|
396
430
|
reasoning_effort: str = "high",
|
397
|
-
) ->
|
431
|
+
) -> BaseLMResponse:
|
398
432
|
return self.handler.call_sync(
|
399
433
|
messages=messages,
|
400
434
|
model=model,
|
@@ -55,6 +55,7 @@ def pull_out_structured_output(
|
|
55
55
|
raise ValueError(
|
56
56
|
f"Failed to parse response as {response_model}: {e} - {response_prepared}"
|
57
57
|
)
|
58
|
+
assert isinstance(final, BaseModel), "Structured output must be a Pydantic model"
|
58
59
|
return final
|
59
60
|
|
60
61
|
|
@@ -157,7 +158,7 @@ async def fix_errant_forced_async(
|
|
157
158
|
)
|
158
159
|
# print("Fixed response:")
|
159
160
|
# print(fixed_response)
|
160
|
-
return fixed_response
|
161
|
+
return fixed_response.structured_output
|
161
162
|
|
162
163
|
|
163
164
|
def fix_errant_forced_sync(
|
@@ -183,4 +184,4 @@ def fix_errant_forced_sync(
|
|
183
184
|
)
|
184
185
|
# print("Fixed response:")
|
185
186
|
# print(fixed_response)
|
186
|
-
return fixed_response
|
187
|
+
return fixed_response.structured_output
|
@@ -0,0 +1,104 @@
|
|
1
|
+
from typing import Type
|
2
|
+
|
3
|
+
from pydantic import BaseModel
|
4
|
+
|
5
|
+
|
6
|
+
class BaseTool(BaseModel):
|
7
|
+
name: str
|
8
|
+
arguments: Type[BaseModel]
|
9
|
+
description: str = ""
|
10
|
+
strict: bool = True
|
11
|
+
|
12
|
+
def to_openai_tool(self):
|
13
|
+
schema = self.arguments.model_json_schema()
|
14
|
+
schema["additionalProperties"] = False
|
15
|
+
|
16
|
+
return {
|
17
|
+
"type": "function",
|
18
|
+
"function": {
|
19
|
+
"name": self.name,
|
20
|
+
"description": self.description,
|
21
|
+
"parameters": schema,
|
22
|
+
"strict": self.strict,
|
23
|
+
},
|
24
|
+
}
|
25
|
+
|
26
|
+
def to_anthropic_tool(self):
|
27
|
+
schema = self.arguments.model_json_schema()
|
28
|
+
schema["additionalProperties"] = False
|
29
|
+
|
30
|
+
return {
|
31
|
+
"name": self.name,
|
32
|
+
"description": self.description,
|
33
|
+
"input_schema": {
|
34
|
+
"type": "object",
|
35
|
+
"properties": schema["properties"],
|
36
|
+
"required": schema.get("required", []),
|
37
|
+
},
|
38
|
+
}
|
39
|
+
|
40
|
+
def to_mistral_tool(self):
|
41
|
+
schema = self.arguments.model_json_schema()
|
42
|
+
properties = {}
|
43
|
+
for prop_name, prop in schema.get("properties", {}).items():
|
44
|
+
prop_type = prop["type"]
|
45
|
+
if prop_type == "array" and "items" in prop:
|
46
|
+
properties[prop_name] = {
|
47
|
+
"type": "array",
|
48
|
+
"items": prop["items"],
|
49
|
+
"description": prop.get("description", ""),
|
50
|
+
}
|
51
|
+
continue
|
52
|
+
|
53
|
+
properties[prop_name] = {
|
54
|
+
"type": prop_type,
|
55
|
+
"description": prop.get("description", ""),
|
56
|
+
}
|
57
|
+
if "enum" in prop:
|
58
|
+
properties[prop_name]["enum"] = prop["enum"]
|
59
|
+
|
60
|
+
parameters = {
|
61
|
+
"type": "object",
|
62
|
+
"properties": properties,
|
63
|
+
"required": schema.get("required", []),
|
64
|
+
"additionalProperties": False,
|
65
|
+
}
|
66
|
+
return {
|
67
|
+
"type": "function",
|
68
|
+
"function": {
|
69
|
+
"name": self.name,
|
70
|
+
"description": self.description,
|
71
|
+
"parameters": parameters,
|
72
|
+
},
|
73
|
+
}
|
74
|
+
|
75
|
+
def to_gemini_tool(self):
|
76
|
+
schema = self.arguments.model_json_schema()
|
77
|
+
# Convert Pydantic schema types to Gemini schema types
|
78
|
+
properties = {}
|
79
|
+
for name, prop in schema["properties"].items():
|
80
|
+
prop_type = prop.get("type", "string")
|
81
|
+
if prop_type == "array" and "items" in prop:
|
82
|
+
properties[name] = {
|
83
|
+
"type": "array",
|
84
|
+
"items": prop["items"],
|
85
|
+
"description": prop.get("description", ""),
|
86
|
+
}
|
87
|
+
continue
|
88
|
+
|
89
|
+
properties[name] = {
|
90
|
+
"type": prop_type,
|
91
|
+
"description": prop.get("description", ""),
|
92
|
+
}
|
93
|
+
if "enum" in prop:
|
94
|
+
properties[name]["enum"] = prop["enum"]
|
95
|
+
|
96
|
+
return {
|
97
|
+
"name": self.name,
|
98
|
+
"description": self.description,
|
99
|
+
"parameters": {
|
100
|
+
"type": "object",
|
101
|
+
"properties": properties,
|
102
|
+
"required": schema.get("required", []),
|
103
|
+
},
|
104
|
+
}
|
synth_ai/zyk/lms/vendors/base.py
CHANGED
@@ -1,15 +1,31 @@
|
|
1
1
|
from abc import ABC, abstractmethod
|
2
|
-
from typing import
|
2
|
+
from typing import Any, Dict, List, Optional
|
3
|
+
|
3
4
|
from pydantic import BaseModel
|
4
5
|
|
6
|
+
|
7
|
+
class BaseLMResponse(BaseModel):
|
8
|
+
raw_response: str
|
9
|
+
structured_output: Optional[BaseModel] = None
|
10
|
+
tool_calls: Optional[List[Dict]] = None
|
11
|
+
|
12
|
+
|
5
13
|
class VendorBase(ABC):
|
6
14
|
used_for_structured_outputs: bool = False
|
7
15
|
exceptions_to_retry: List[Exception] = []
|
8
|
-
|
16
|
+
|
9
17
|
@abstractmethod
|
10
|
-
async def _hit_api_async(
|
18
|
+
async def _hit_api_async(
|
19
|
+
self,
|
20
|
+
messages: List[Dict[str, Any]],
|
21
|
+
response_model_override: Optional[BaseModel] = None,
|
22
|
+
) -> str:
|
11
23
|
pass
|
12
|
-
|
24
|
+
|
13
25
|
@abstractmethod
|
14
|
-
def _hit_api_sync(
|
15
|
-
|
26
|
+
def _hit_api_sync(
|
27
|
+
self,
|
28
|
+
messages: List[Dict[str, Any]],
|
29
|
+
response_model_override: Optional[BaseModel] = None,
|
30
|
+
) -> str:
|
31
|
+
pass
|