synth-ai 0.1.0.dev28__py3-none-any.whl → 0.1.0.dev30__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that public registry.
- public_tests/test_agent.py +11 -11
- public_tests/test_all_structured_outputs.py +32 -37
- public_tests/test_anthropic_structured_outputs.py +0 -0
- public_tests/test_deepseek_structured_outputs.py +0 -0
- public_tests/test_deepseek_tools.py +64 -0
- public_tests/test_gemini_structured_outputs.py +106 -0
- public_tests/test_models.py +27 -27
- public_tests/test_openai_structured_outputs.py +106 -0
- public_tests/test_reasoning_models.py +9 -7
- public_tests/test_recursive_structured_outputs.py +30 -30
- public_tests/test_structured.py +137 -0
- public_tests/test_structured_outputs.py +22 -13
- public_tests/test_text.py +160 -0
- public_tests/test_tools.py +300 -0
- synth_ai/__init__.py +1 -4
- synth_ai/zyk/__init__.py +2 -2
- synth_ai/zyk/lms/caching/ephemeral.py +54 -32
- synth_ai/zyk/lms/caching/handler.py +43 -15
- synth_ai/zyk/lms/caching/persistent.py +55 -27
- synth_ai/zyk/lms/core/main.py +29 -16
- synth_ai/zyk/lms/core/vendor_clients.py +1 -1
- synth_ai/zyk/lms/structured_outputs/handler.py +79 -45
- synth_ai/zyk/lms/structured_outputs/rehabilitate.py +3 -2
- synth_ai/zyk/lms/tools/base.py +104 -0
- synth_ai/zyk/lms/vendors/base.py +22 -6
- synth_ai/zyk/lms/vendors/core/anthropic_api.py +130 -95
- synth_ai/zyk/lms/vendors/core/gemini_api.py +153 -34
- synth_ai/zyk/lms/vendors/core/mistral_api.py +160 -54
- synth_ai/zyk/lms/vendors/core/openai_api.py +64 -53
- synth_ai/zyk/lms/vendors/openai_standard.py +197 -41
- synth_ai/zyk/lms/vendors/supported/deepseek.py +55 -0
- {synth_ai-0.1.0.dev28.dist-info → synth_ai-0.1.0.dev30.dist-info}/METADATA +2 -5
- synth_ai-0.1.0.dev30.dist-info/RECORD +65 -0
- public_tests/test_sonnet_thinking.py +0 -217
- synth_ai-0.1.0.dev28.dist-info/RECORD +0 -57
- {synth_ai-0.1.0.dev28.dist-info → synth_ai-0.1.0.dev30.dist-info}/WHEEL +0 -0
- {synth_ai-0.1.0.dev28.dist-info → synth_ai-0.1.0.dev30.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.1.0.dev28.dist-info → synth_ai-0.1.0.dev30.dist-info}/top_level.txt +0 -0
synth_ai/zyk/lms/vendors/core/mistral_api.py

```diff
@@ -1,16 +1,16 @@
 import json
 import os
-from typing import Any, Dict, List, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type

 import pydantic
 from mistralai import Mistral  # use Mistral as both sync and async client
 from pydantic import BaseModel

 from synth_ai.zyk.lms.caching.initialize import get_cache_handler
-from synth_ai.zyk.lms.
+from synth_ai.zyk.lms.tools.base import BaseTool
+from synth_ai.zyk.lms.vendors.base import BaseLMResponse, VendorBase
 from synth_ai.zyk.lms.vendors.constants import SPECIAL_BASE_TEMPS
 from synth_ai.zyk.lms.vendors.core.openai_api import OpenAIStructuredOutputClient
-from synth_ai.zyk.lms.vendors.retries import BACKOFF_TOLERANCE, backoff

 # Since the mistralai package doesn't expose an exceptions module,
 # we fallback to catching all Exceptions for retry.
@@ -31,97 +31,193 @@ class MistralAPI(VendorBase):
         self.exceptions_to_retry = exceptions_to_retry
         self._openai_fallback = None

-    @backoff.on_exception(
-
-
-
-
-    )
+    # @backoff.on_exception(
+    #     backoff.expo,
+    #     MISTRAL_EXCEPTIONS_TO_RETRY,
+    #     max_tries=BACKOFF_TOLERANCE,
+    #     on_giveup=lambda e: print(e),
+    # )
     async def _hit_api_async(
         self,
         model: str,
         messages: List[Dict[str, Any]],
         lm_config: Dict[str, Any],
+        response_model: Optional[BaseModel] = None,
         use_ephemeral_cache_only: bool = False,
-
+        reasoning_effort: str = "high",
+        tools: Optional[List[BaseTool]] = None,
+    ) -> BaseLMResponse:
         assert (
             lm_config.get("response_model", None) is None
         ), "response_model is not supported for standard calls"
+        assert not (response_model and tools), "Cannot provide both response_model and tools"
         used_cache_handler = get_cache_handler(use_ephemeral_cache_only)
         cache_result = used_cache_handler.hit_managed_cache(
-            model, messages, lm_config=lm_config
+            model, messages, lm_config=lm_config, tools=tools
         )
         if cache_result:
+            assert type(cache_result) in [
+                BaseLMResponse,
+                str,
+            ], f"Expected BaseLMResponse or str, got {type(cache_result)}"
             return (
-                cache_result
-                if
-                else
+                cache_result
+                if type(cache_result) == BaseLMResponse
+                else BaseLMResponse(
+                    raw_response=cache_result, structured_output=None, tool_calls=None
+                )
             )

         mistral_messages = [
             {"role": msg["role"], "content": msg["content"]} for msg in messages
         ]
+        functions = [tool.to_mistral_tool() for tool in tools] if tools else None
+        params = {
+            "model": model,
+            "messages": mistral_messages,
+            "max_tokens": lm_config.get("max_tokens", 4096),
+            "temperature": lm_config.get(
+                "temperature", SPECIAL_BASE_TEMPS.get(model, 0)
+            ),
+            "stream": False,
+            "tool_choice": "auto" if functions else None,
+
+        }
+        if response_model:
+            params["response_format"] = response_model
+        elif tools:
+            params["tools"] = functions
+
         async with Mistral(api_key=os.getenv("MISTRAL_API_KEY", "")) as client:
-            response = await client.chat.complete_async(
-
-
-
-
-
-
-
-
-
+            response = await client.chat.complete_async(**params)
+
+            message = response.choices[0].message
+            try:
+                raw_response = message.content
+            except AttributeError:
+                raw_response = ""
+
+            tool_calls = []
+            try:
+                if message.tool_calls:
+                    tool_calls = [
+                        {
+                            "id": call.id,
+                            "type": "function",
+                            "function": {
+                                "name": call.function.name,
+                                "arguments": call.function.arguments,
+                            },
+                        }
+                        for call in message.tool_calls
+                    ]
+            except AttributeError:
+                pass
+
+            lm_response = BaseLMResponse(
+                raw_response=raw_response,
+                structured_output=None,
+                tool_calls=tool_calls if tool_calls else None,
+            )
         used_cache_handler.add_to_managed_cache(
-            model, messages, lm_config=lm_config, output=
+            model, messages, lm_config=lm_config, output=lm_response, tools=tools
         )
-        return
-
-    @backoff.on_exception(
-
-
-
-
-    )
+        return lm_response
+
+    # @backoff.on_exception(
+    #     backoff.expo,
+    #     MISTRAL_EXCEPTIONS_TO_RETRY,
+    #     max_tries=BACKOFF_TOLERANCE,
+    #     on_giveup=lambda e: print(e),
+    # )
     def _hit_api_sync(
         self,
         model: str,
         messages: List[Dict[str, Any]],
         lm_config: Dict[str, Any],
+        response_model: Optional[BaseModel] = None,
         use_ephemeral_cache_only: bool = False,
-
+        reasoning_effort: str = "high",
+        tools: Optional[List[BaseTool]] = None,
+    ) -> BaseLMResponse:
         assert (
             lm_config.get("response_model", None) is None
         ), "response_model is not supported for standard calls"
+        assert not (response_model and tools), "Cannot provide both response_model and tools"
+
         used_cache_handler = get_cache_handler(use_ephemeral_cache_only)
         cache_result = used_cache_handler.hit_managed_cache(
-            model, messages, lm_config=lm_config
+            model, messages, lm_config=lm_config, tools=tools
        )
         if cache_result:
+            assert type(cache_result) in [
+                BaseLMResponse,
+                str,
+            ], f"Expected BaseLMResponse or str, got {type(cache_result)}"
             return (
-                cache_result
-                if
-                else
+                cache_result
+                if type(cache_result) == BaseLMResponse
+                else BaseLMResponse(
+                    raw_response=cache_result, structured_output=None, tool_calls=None
+                )
             )

         mistral_messages = [
             {"role": msg["role"], "content": msg["content"]} for msg in messages
         ]
+        functions = [tool.to_mistral_tool() for tool in tools] if tools else None
+
+        params = {
+            "model": model,
+            "messages": mistral_messages,
+            "max_tokens": lm_config.get("max_tokens", 4096),
+            "temperature": lm_config.get(
+                "temperature", SPECIAL_BASE_TEMPS.get(model, 0)
+            ),
+            "stream": False,
+            "tool_choice": "auto" if functions else None,
+            #"tools": functions,
+        }
+        if response_model:
+            params["response_format"] = response_model
+        elif tools:
+            params["tools"] = functions
+
         with Mistral(api_key=os.getenv("MISTRAL_API_KEY", "")) as client:
-            response = client.chat.complete(
-
-
-
-
-
-
-
-
-
+            response = client.chat.complete(**params)
+
+            message = response.choices[0].message
+            try:
+                raw_response = message.content
+            except AttributeError:
+                raw_response = ""
+
+            tool_calls = []
+            try:
+                if message.tool_calls:
+                    tool_calls = [
+                        {
+                            "id": call.id,
+                            "type": "function",
+                            "function": {
+                                "name": call.function.name,
+                                "arguments": call.function.arguments,
+                            },
+                        }
+                        for call in message.tool_calls
+                    ]
+            except AttributeError:
+                pass
+
+            lm_response = BaseLMResponse(
+                raw_response=raw_response,
+                structured_output=None,
+                tool_calls=tool_calls if tool_calls else None,
+            )
         used_cache_handler.add_to_managed_cache(
-            model, messages, lm_config=lm_config, output=
+            model, messages, lm_config=lm_config, output=lm_response, tools=tools
         )
-        return
+        return lm_response

     async def _hit_api_async_structured_output(
         self,
@@ -130,7 +226,7 @@ class MistralAPI(VendorBase):
         response_model: BaseModel,
         temperature: float,
         use_ephemeral_cache_only: bool = False,
-    ) ->
+    ) -> BaseLMResponse:
         try:
             mistral_messages = [
                 {"role": msg["role"], "content": msg["content"]} for msg in messages
@@ -145,7 +241,12 @@ class MistralAPI(VendorBase):
             )
             result = response.choices[0].message.content
             parsed = json.loads(result)
-
+            lm_response = BaseLMResponse(
+                raw_response="",
+                structured_output=response_model(**parsed),
+                tool_calls=None,
+            )
+            return lm_response
         except (json.JSONDecodeError, pydantic.ValidationError):
             if self._openai_fallback is None:
                 self._openai_fallback = OpenAIStructuredOutputClient()
@@ -164,7 +265,7 @@ class MistralAPI(VendorBase):
         response_model: BaseModel,
         temperature: float,
         use_ephemeral_cache_only: bool = False,
-    ) ->
+    ) -> BaseLMResponse:
         try:
             mistral_messages = [
                 {"role": msg["role"], "content": msg["content"]} for msg in messages
@@ -179,7 +280,12 @@ class MistralAPI(VendorBase):
             )
             result = response.choices[0].message.content
             parsed = json.loads(result)
-
+            lm_response = BaseLMResponse(
+                raw_response="",
+                structured_output=response_model(**parsed),
+                tool_calls=None,
+            )
+            return lm_response
         except (json.JSONDecodeError, pydantic.ValidationError):
             print("WARNING - Falling back to OpenAI - THIS IS SLOW")
             if self._openai_fallback is None:
```
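The hunks above add an optional tool-calling path to both MistralAPI._hit_api_async and _hit_api_sync: tools are converted with to_mistral_tool(), forwarded through the request params, and any tool calls on the response are surfaced on BaseLMResponse.tool_calls. A minimal sketch of how that path might be exercised; the new synth_ai/zyk/lms/tools/base.py is added in this release but not shown in this view, so the BaseTool fields, the MistralAPI constructor arguments, and the model name below are all assumptions:

```python
# Sketch only: BaseTool fields, the MistralAPI constructor arguments, and the
# model name are assumptions; they are not shown in this diff.
from pydantic import BaseModel

from synth_ai.zyk.lms.tools.base import BaseTool
from synth_ai.zyk.lms.vendors.core.mistral_api import MistralAPI


class WeatherArgs(BaseModel):
    city: str


weather_tool = BaseTool(  # assumed constructor signature
    name="get_weather",
    description="Look up the current weather for a city.",
    arguments=WeatherArgs,
)

client = MistralAPI(used_for_structured_outputs=False, exceptions_to_retry=[])  # assumed args
response = client._hit_api_sync(
    model="mistral-large-latest",
    messages=[{"role": "user", "content": "What is the weather in Paris?"}],
    lm_config={"temperature": 0},
    tools=[weather_tool],
)

# The new return type wraps both the plain text and any tool calls.
print(response.raw_response)
if response.tool_calls:
    for call in response.tool_calls:
        print(call["function"]["name"], call["function"]["arguments"])
```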
synth_ai/zyk/lms/vendors/core/openai_api.py

```diff
@@ -1,5 +1,5 @@
 import json
-from typing import Any, Dict, List, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type

 import openai
 import pydantic_core
@@ -8,6 +8,8 @@ import pydantic_core
 from pydantic import BaseModel

 from synth_ai.zyk.lms.caching.initialize import get_cache_handler
+from synth_ai.zyk.lms.tools.base import BaseTool
+from synth_ai.zyk.lms.vendors.base import BaseLMResponse
 from synth_ai.zyk.lms.vendors.constants import SPECIAL_BASE_TEMPS
 from synth_ai.zyk.lms.vendors.openai_standard import OpenAIStandard

@@ -46,8 +48,11 @@ class OpenAIStructuredOutputClient(OpenAIStandard):
         response_model: BaseModel,
         temperature: float,
         use_ephemeral_cache_only: bool = False,
+        tools: Optional[List[BaseTool]] = None,
         reasoning_effort: str = "high",
     ) -> str:
+        if tools:
+            raise ValueError("Tools are not supported for async structured output")
         # "Hit client")
         lm_config = {"temperature": temperature, "response_model": response_model}
         used_cache_handler = get_cache_handler(
@@ -58,38 +63,40 @@ class OpenAIStructuredOutputClient(OpenAIStandard):
         )
         if cache_result:
             # print("Hit cache")
+            assert type(cache_result) in [
+                dict,
+                BaseLMResponse,
+            ], f"Expected dict or BaseLMResponse, got {type(cache_result)}"
             return (
-                cache_result["response"]
-                if isinstance(cache_result, dict)
-                else cache_result
+                cache_result["response"] if type(cache_result) == dict else cache_result
             )
-
-
-
-
-
-
-
-
-
-
-
-
+        if model in ["o3-mini", "o3", "o1-mini", "o1"]:
+            output = await self.async_client.beta.chat.completions.parse(
+                model=model,
+                messages=messages,
+                temperature=lm_config.get(
+                    "temperature", SPECIAL_BASE_TEMPS.get(model, 0)
+                ),
+                response_format=response_model,
+                reasoning_effort=reasoning_effort,
+            )
+        else:
+            output = await self.async_client.beta.chat.completions.parse(
+                model=model,
+                messages=messages,
+                response_format=response_model,
             )
-
-        # Add reasoning_effort only for o3-mini
-        if "o3-mini" in model:
-            #print("Reasoning effort:", reasoning_effort)
-            api_params["reasoning_effort"] = reasoning_effort
-
-        output = await self.async_client.beta.chat.completions.parse(**api_params)
-
         # "Output", output)
         api_result = response_model(**json.loads(output.choices[0].message.content))
+        lm_response = BaseLMResponse(
+            raw_response="",
+            structured_output=api_result,
+            tool_calls=None,
+        )
         used_cache_handler.add_to_managed_cache(
-            model, messages, lm_config, output=
+            model, messages, lm_config, output=lm_response
         )
-        return
+        return lm_response

     def _hit_api_sync_structured_output(
         self,
@@ -98,8 +105,11 @@ class OpenAIStructuredOutputClient(OpenAIStandard):
         response_model: BaseModel,
         temperature: float,
         use_ephemeral_cache_only: bool = False,
+        tools: Optional[List[BaseTool]] = None,
         reasoning_effort: str = "high",
     ) -> str:
+        if tools:
+            raise ValueError("Tools are not supported for sync structured output")
         lm_config = {"temperature": temperature, "response_model": response_model}
         used_cache_handler = get_cache_handler(
             use_ephemeral_cache_only=use_ephemeral_cache_only
@@ -108,39 +118,40 @@ class OpenAIStructuredOutputClient(OpenAIStandard):
             model, messages, lm_config=lm_config
         )
         if cache_result:
+            assert type(cache_result) in [
+                dict,
+                BaseLMResponse,
+            ], f"Expected dict or BaseLMResponse, got {type(cache_result)}"
             return (
-                cache_result["response"]
-                if isinstance(cache_result, dict)
-                else cache_result
+                cache_result["response"] if type(cache_result) == dict else cache_result
             )
-
-
-
-
-
-
-
-
-
-
-
-
+        if model in ["o3-mini", "o3", "o1-mini", "o1"]:
+            output = self.sync_client.beta.chat.completions.parse(
+                model=model,
+                messages=messages,
+                temperature=lm_config.get(
+                    "temperature", SPECIAL_BASE_TEMPS.get(model, 0)
+                ),
+                response_format=response_model,
+                reasoning_effort=reasoning_effort,
+            )
+        else:
+            output = self.sync_client.beta.chat.completions.parse(
+                model=model,
+                messages=messages,
+                response_format=response_model,
             )
-
-        # Add reasoning_effort only for o3-mini
-        if model in ["o3-mini"]:
-            api_params["reasoning_effort"] = reasoning_effort
-
-        output = self.sync_client.beta.chat.completions.parse(**api_params)
-
         api_result = response_model(**json.loads(output.choices[0].message.content))
+
+        lm_response = BaseLMResponse(
+            raw_response="",
+            structured_output=api_result,
+            tool_calls=None,
+        )
         used_cache_handler.add_to_managed_cache(
-            model,
-            messages,
-            lm_config=lm_config,
-            output=output.choices[0].message.content,
+            model, messages, lm_config=lm_config, output=lm_response
         )
-        return
+        return lm_response


 class OpenAIPrivate(OpenAIStandard):
```
|