synth-ai 0.1.0.dev38__py3-none-any.whl → 0.1.0.dev49__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in those public registries.
- synth_ai/__init__.py +3 -1
- {synth_ai-0.1.0.dev38.dist-info → synth_ai-0.1.0.dev49.dist-info}/METADATA +12 -11
- synth_ai-0.1.0.dev49.dist-info/RECORD +6 -0
- {synth_ai-0.1.0.dev38.dist-info → synth_ai-0.1.0.dev49.dist-info}/WHEEL +1 -1
- synth_ai-0.1.0.dev49.dist-info/top_level.txt +1 -0
- private_tests/try_synth_sdk.py +0 -1
- public_tests/test_agent.py +0 -538
- public_tests/test_all_structured_outputs.py +0 -196
- public_tests/test_anthropic_structured_outputs.py +0 -0
- public_tests/test_deepseek_structured_outputs.py +0 -0
- public_tests/test_deepseek_tools.py +0 -64
- public_tests/test_gemini_output.py +0 -188
- public_tests/test_gemini_structured_outputs.py +0 -106
- public_tests/test_models.py +0 -183
- public_tests/test_openai_structured_outputs.py +0 -106
- public_tests/test_reasoning_effort.py +0 -75
- public_tests/test_reasoning_models.py +0 -92
- public_tests/test_recursive_structured_outputs.py +0 -180
- public_tests/test_structured.py +0 -137
- public_tests/test_structured_outputs.py +0 -109
- public_tests/test_synth_sdk.py +0 -384
- public_tests/test_text.py +0 -160
- public_tests/test_tools.py +0 -319
- synth_ai/zyk/__init__.py +0 -3
- synth_ai/zyk/lms/__init__.py +0 -0
- synth_ai/zyk/lms/caching/__init__.py +0 -0
- synth_ai/zyk/lms/caching/constants.py +0 -1
- synth_ai/zyk/lms/caching/dbs.py +0 -0
- synth_ai/zyk/lms/caching/ephemeral.py +0 -72
- synth_ai/zyk/lms/caching/handler.py +0 -142
- synth_ai/zyk/lms/caching/initialize.py +0 -13
- synth_ai/zyk/lms/caching/persistent.py +0 -83
- synth_ai/zyk/lms/config.py +0 -8
- synth_ai/zyk/lms/core/__init__.py +0 -0
- synth_ai/zyk/lms/core/all.py +0 -47
- synth_ai/zyk/lms/core/exceptions.py +0 -9
- synth_ai/zyk/lms/core/main.py +0 -314
- synth_ai/zyk/lms/core/vendor_clients.py +0 -85
- synth_ai/zyk/lms/cost/__init__.py +0 -0
- synth_ai/zyk/lms/cost/monitor.py +0 -1
- synth_ai/zyk/lms/cost/statefulness.py +0 -1
- synth_ai/zyk/lms/structured_outputs/__init__.py +0 -0
- synth_ai/zyk/lms/structured_outputs/handler.py +0 -442
- synth_ai/zyk/lms/structured_outputs/inject.py +0 -314
- synth_ai/zyk/lms/structured_outputs/rehabilitate.py +0 -187
- synth_ai/zyk/lms/tools/base.py +0 -104
- synth_ai/zyk/lms/vendors/__init__.py +0 -0
- synth_ai/zyk/lms/vendors/base.py +0 -31
- synth_ai/zyk/lms/vendors/constants.py +0 -22
- synth_ai/zyk/lms/vendors/core/__init__.py +0 -0
- synth_ai/zyk/lms/vendors/core/anthropic_api.py +0 -413
- synth_ai/zyk/lms/vendors/core/gemini_api.py +0 -306
- synth_ai/zyk/lms/vendors/core/mistral_api.py +0 -327
- synth_ai/zyk/lms/vendors/core/openai_api.py +0 -185
- synth_ai/zyk/lms/vendors/local/__init__.py +0 -0
- synth_ai/zyk/lms/vendors/local/ollama.py +0 -0
- synth_ai/zyk/lms/vendors/openai_standard.py +0 -375
- synth_ai/zyk/lms/vendors/retries.py +0 -3
- synth_ai/zyk/lms/vendors/supported/__init__.py +0 -0
- synth_ai/zyk/lms/vendors/supported/deepseek.py +0 -73
- synth_ai/zyk/lms/vendors/supported/groq.py +0 -16
- synth_ai/zyk/lms/vendors/supported/ollama.py +0 -14
- synth_ai/zyk/lms/vendors/supported/together.py +0 -11
- synth_ai-0.1.0.dev38.dist-info/RECORD +0 -67
- synth_ai-0.1.0.dev38.dist-info/top_level.txt +0 -4
- tests/test_agent.py +0 -538
- tests/test_recursive_structured_outputs.py +0 -180
- tests/test_structured_outputs.py +0 -100
- {synth_ai-0.1.0.dev38.dist-info → synth_ai-0.1.0.dev49.dist-info}/licenses/LICENSE +0 -0
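Beyond the metadata updates at the top of the list, the substance of this release is a removal: the entire synth_ai.zyk package (LM core, caching, structured-output handling, and vendor clients) and the bundled private_tests/public_tests/tests suites are dropped from the dev49 wheel. The sketch below is illustrative and not part of the package: a pre-upgrade check that reports whether the removed module paths, taken from the file list above, are still importable in a given environment.

# Illustrative pre-upgrade check (not shipped with synth-ai): report whether the
# module paths removed in 0.1.0.dev49 are still importable in this environment.
import importlib.util

REMOVED_MODULES = [
    "synth_ai.zyk",
    "synth_ai.zyk.lms.core.main",
    "synth_ai.zyk.lms.structured_outputs.handler",
    "synth_ai.zyk.lms.vendors.core.gemini_api",
    "synth_ai.zyk.lms.vendors.core.mistral_api",
]

for name in REMOVED_MODULES:
    try:
        spec = importlib.util.find_spec(name)
    except ModuleNotFoundError:
        spec = None
    print(f"{name}: {'still importable' if spec else 'removed'}")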
--- synth_ai/zyk/lms/vendors/core/gemini_api.py
+++ /dev/null
@@ -1,306 +0,0 @@
-import json
-import logging
-import os
-import warnings
-from typing import Any, Dict, List, Optional, Tuple, Type
-
-import google.generativeai as genai
-from google.api_core.exceptions import ResourceExhausted
-from google.generativeai.types import HarmBlockThreshold, HarmCategory, Tool
-
-from synth_ai.zyk.lms.caching.initialize import (
-    get_cache_handler,
-)
-from synth_ai.zyk.lms.tools.base import BaseTool
-from synth_ai.zyk.lms.vendors.base import BaseLMResponse, VendorBase
-from synth_ai.zyk.lms.vendors.constants import SPECIAL_BASE_TEMPS
-from synth_ai.zyk.lms.vendors.retries import BACKOFF_TOLERANCE, backoff
-
-GEMINI_EXCEPTIONS_TO_RETRY: Tuple[Type[Exception], ...] = (ResourceExhausted,)
-logging.getLogger("google.generativeai").setLevel(logging.ERROR)
-os.environ["GRPC_VERBOSITY"] = "ERROR"
-# Suppress TensorFlow logging
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-warnings.filterwarnings("ignore")
-
-SAFETY_SETTINGS = {
-    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
-    HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
-    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
-    HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
-}
-
-
-class GeminiAPI(VendorBase):
-    used_for_structured_outputs: bool = True
-    exceptions_to_retry: Tuple[Type[Exception], ...] = GEMINI_EXCEPTIONS_TO_RETRY
-
-    def __init__(
-        self,
-        exceptions_to_retry: Tuple[Type[Exception], ...] = GEMINI_EXCEPTIONS_TO_RETRY,
-        used_for_structured_outputs: bool = False,
-    ):
-        self.used_for_structured_outputs = used_for_structured_outputs
-        self.exceptions_to_retry = exceptions_to_retry
-
-    def _convert_messages_to_contents(
-        self, messages: List[Dict[str, Any]]
-    ) -> List[Dict[str, Any]]:
-        contents = []
-        system_instruction = None
-        for message in messages:
-            if message["role"] == "system":
-                system_instruction = (
-                    f"<instructions>\n{message['content']}\n</instructions>"
-                )
-                continue
-            elif system_instruction:
-                text = system_instruction + "\n" + message["content"]
-            else:
-                text = message["content"]
-            contents.append(
-                {
-                    "role": message["role"],
-                    "parts": [{"text": text}],
-                }
-            )
-        return contents
-
-    def _convert_tools_to_gemini_format(self, tools: List[Any]) -> Tool:
-        function_declarations = []
-        for tool in tools:
-            # Try to use to_gemini_tool method if available, otherwise assume it's a dict
-            try:
-                function_declarations.append(tool.to_gemini_tool())
-            except AttributeError:
-                # If tool is a properly formatted dict, use it directly
-                if "name" in tool and "parameters" in tool:
-                    function_declarations.append(tool)
-                else:
-                    raise ValueError(
-                        f"Unsupported tool format. Tools must be BaseTool instances or properly formatted dictionaries."
-                    )
-        return Tool(function_declarations=function_declarations)
-
-    def _convert_args_to_dict(self, args):
-        """
-        Recursively convert Gemini's args objects to Python dictionaries.
-        """
-        # Try to convert dict-like objects
-        try:
-            return {k: self._convert_args_to_dict(v) for k, v in args.items()}
-        except (AttributeError, TypeError):
-            # Try to convert list-like objects
-            try:
-                if isinstance(args, (str, bytes)):
-                    return args
-                return [self._convert_args_to_dict(item) for item in args]
-            except (TypeError, AttributeError):
-                # Base case: primitive value
-                return args
-
-    async def _private_request_async(
-        self,
-        messages: List[Dict],
-        temperature: float = 0,
-        model_name: str = "gemini-1.5-flash",
-        reasoning_effort: str = "high",
-        tools: Optional[List[BaseTool]] = None,
-        lm_config: Optional[Dict[str, Any]] = None,
-    ) -> Tuple[str, Optional[List[Dict]]]:
-        generation_config = {
-            "temperature": temperature,
-        }
-        # Add max_output_tokens if max_tokens is in lm_config
-        if lm_config and "max_tokens" in lm_config:
-            generation_config["max_output_tokens"] = lm_config["max_tokens"]
-
-        tools_config = None
-        if tools:
-            tools_config = self._convert_tools_to_gemini_format(tools)
-
-        # Extract tool_config from lm_config if provided
-        tool_config = (
-            lm_config.get("tool_config")
-            if lm_config
-            else {"function_calling_config": {"mode": "any"}}
-        )
-
-        code_generation_model = genai.GenerativeModel(
-            model_name=model_name,
-            generation_config=generation_config,
-            tools=tools_config if tools_config else None,
-            tool_config=tool_config,
-        )
-
-        contents = self._convert_messages_to_contents(messages)
-        result = await code_generation_model.generate_content_async(
-            contents=contents,
-            safety_settings=SAFETY_SETTINGS,
-        )
-
-        text = result.candidates[0].content.parts[0].text
-        tool_calls = []
-        for part in result.candidates[0].content.parts:
-            if part.function_call:
-                # Convert complex objects to Python dictionaries recursively
-                args_dict = self._convert_args_to_dict(part.function_call.args)
-                # Ensure serializable arguments
-                tool_calls.append(
-                    {
-                        "id": f"call_{len(tool_calls) + 1}",  # Generate unique IDs
-                        "type": "function",
-                        "function": {
-                            "name": part.function_call.name,
-                            "arguments": json.dumps(args_dict),
-                        },
-                    }
-                )
-        return text, tool_calls if tool_calls else None
-
-    def _private_request_sync(
-        self,
-        messages: List[Dict],
-        temperature: float = 0,
-        model_name: str = "gemini-1.5-flash",
-        reasoning_effort: str = "high",
-        tools: Optional[List[BaseTool]] = None,
-        lm_config: Optional[Dict[str, Any]] = None,
-    ) -> Tuple[str, Optional[List[Dict]]]:
-        generation_config = {
-            "temperature": temperature,
-        }
-        # Add max_output_tokens if max_tokens is in lm_config
-        if lm_config and "max_tokens" in lm_config:
-            generation_config["max_output_tokens"] = lm_config["max_tokens"]
-
-        tools_config = None
-        if tools:
-            tools_config = self._convert_tools_to_gemini_format(tools)
-
-        # Extract tool_config from lm_config if provided
-        tool_config = (
-            lm_config.get("tool_config")
-            if lm_config
-            else {"function_calling_config": {"mode": "any"}}
-        )
-
-        code_generation_model = genai.GenerativeModel(
-            model_name=model_name,
-            generation_config=generation_config,
-            tools=tools_config if tools_config else None,
-            tool_config=tool_config,
-        )
-
-        contents = self._convert_messages_to_contents(messages)
-        result = code_generation_model.generate_content(
-            contents=contents,
-            safety_settings=SAFETY_SETTINGS,
-        )
-
-        text = result.candidates[0].content.parts[0].text
-        tool_calls = []
-        for part in result.candidates[0].content.parts:
-            if part.function_call:
-                # Convert complex objects to Python dictionaries recursively
-                args_dict = self._convert_args_to_dict(part.function_call.args)
-                # Ensure serializable arguments
-                tool_calls.append(
-                    {
-                        "id": f"call_{len(tool_calls) + 1}",  # Generate unique IDs
-                        "type": "function",
-                        "function": {
-                            "name": part.function_call.name,
-                            "arguments": json.dumps(args_dict),
-                        },
-                    }
-                )
-        return text, tool_calls if tool_calls else None
-
-    @backoff.on_exception(
-        backoff.expo,
-        exceptions_to_retry,
-        max_tries=BACKOFF_TOLERANCE,
-        on_giveup=lambda e: print(e),
-    )
-    async def _hit_api_async(
-        self,
-        model: str,
-        messages: List[Dict[str, Any]],
-        lm_config: Dict[str, Any],
-        use_ephemeral_cache_only: bool = False,
-        reasoning_effort: str = "high",
-        tools: Optional[List[BaseTool]] = None,
-    ) -> BaseLMResponse:
-        assert (
-            lm_config.get("response_model", None) is None
-        ), "response_model is not supported for standard calls"
-        used_cache_handler = get_cache_handler(use_ephemeral_cache_only)
-        cache_result = used_cache_handler.hit_managed_cache(
-            model, messages, lm_config=lm_config, tools=tools, reasoning_effort=reasoning_effort
-        )
-        if cache_result:
-            return cache_result
-
-        raw_response, tool_calls = await self._private_request_async(
-            messages,
-            temperature=lm_config.get("temperature", SPECIAL_BASE_TEMPS.get(model, 0)),
-            reasoning_effort=reasoning_effort,
-            tools=tools,
-        )
-
-        lm_response = BaseLMResponse(
-            raw_response=raw_response,
-            structured_output=None,
-            tool_calls=tool_calls,
-        )
-
-        used_cache_handler.add_to_managed_cache(
-            model, messages, lm_config=lm_config, output=lm_response, tools=tools, reasoning_effort=reasoning_effort
-        )
-        return lm_response
-
-    @backoff.on_exception(
-        backoff.expo,
-        exceptions_to_retry,
-        max_tries=BACKOFF_TOLERANCE,
-        on_giveup=lambda e: print(e),
-    )
-    def _hit_api_sync(
-        self,
-        model: str,
-        messages: List[Dict[str, Any]],
-        lm_config: Dict[str, Any],
-        use_ephemeral_cache_only: bool = False,
-        reasoning_effort: str = "high",
-        tools: Optional[List[BaseTool]] = None,
-    ) -> BaseLMResponse:
-        assert (
-            lm_config.get("response_model", None) is None
-        ), "response_model is not supported for standard calls"
-        used_cache_handler = get_cache_handler(
-            use_ephemeral_cache_only=use_ephemeral_cache_only
-        )
-        cache_result = used_cache_handler.hit_managed_cache(
-            model, messages, lm_config=lm_config, tools=tools, reasoning_effort=reasoning_effort
-        )
-        if cache_result:
-            return cache_result
-
-        raw_response, tool_calls = self._private_request_sync(
-            messages,
-            temperature=lm_config.get("temperature", SPECIAL_BASE_TEMPS.get(model, 0)),
-            reasoning_effort=reasoning_effort,
-            tools=tools,
-        )
-
-        lm_response = BaseLMResponse(
-            raw_response=raw_response,
-            structured_output=None,
-            tool_calls=tool_calls,
-        )
-
-        used_cache_handler.add_to_managed_cache(
-            model, messages, lm_config=lm_config, output=lm_response, tools=tools, reasoning_effort=reasoning_effort
-        )
-        return lm_response
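For reference, the deleted GeminiAPI client wrapped google.generativeai with response caching, backoff-based retries, and tool-call conversion, exposing _hit_api_sync and _hit_api_async. A minimal usage sketch against the dev38 wheel, based only on the signatures in the hunk above (the messages and config values are illustrative), might look like:

# Illustrative sketch only: assumes the 0.1.0.dev38 wheel, where synth_ai.zyk still exists.
from synth_ai.zyk.lms.vendors.core.gemini_api import GeminiAPI

client = GeminiAPI()
response = client._hit_api_sync(
    model="gemini-1.5-flash",
    messages=[
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Name three moons of Jupiter."},
    ],
    lm_config={"temperature": 0, "max_tokens": 256},
)
print(response.raw_response)  # BaseLMResponse; tool_calls stays None when no tools are passed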
--- synth_ai/zyk/lms/vendors/core/mistral_api.py
+++ /dev/null
@@ -1,327 +0,0 @@
-import json
-import os
-from typing import Any, Dict, List, Optional, Tuple, Type
-
-import pydantic
-from mistralai import Mistral  # use Mistral as both sync and async client
-from pydantic import BaseModel
-
-from synth_ai.zyk.lms.caching.initialize import get_cache_handler
-from synth_ai.zyk.lms.tools.base import BaseTool
-from synth_ai.zyk.lms.vendors.base import BaseLMResponse, VendorBase
-from synth_ai.zyk.lms.vendors.constants import SPECIAL_BASE_TEMPS
-from synth_ai.zyk.lms.vendors.core.openai_api import OpenAIStructuredOutputClient
-
-# Since the mistralai package doesn't expose an exceptions module,
-# we fallback to catching all Exceptions for retry.
-MISTRAL_EXCEPTIONS_TO_RETRY: Tuple[Type[Exception], ...] = (Exception,)
-
-
-class MistralAPI(VendorBase):
-    used_for_structured_outputs: bool = True
-    exceptions_to_retry: Tuple = MISTRAL_EXCEPTIONS_TO_RETRY
-    _openai_fallback: Any
-
-    def __init__(
-        self,
-        exceptions_to_retry: Tuple[Type[Exception], ...] = MISTRAL_EXCEPTIONS_TO_RETRY,
-        used_for_structured_outputs: bool = False,
-    ):
-        self.used_for_structured_outputs = used_for_structured_outputs
-        self.exceptions_to_retry = exceptions_to_retry
-        self._openai_fallback = None
-
-    # @backoff.on_exception(
-    #     backoff.expo,
-    #     MISTRAL_EXCEPTIONS_TO_RETRY,
-    #     max_tries=BACKOFF_TOLERANCE,
-    #     on_giveup=lambda e: print(e),
-    # )
-    async def _hit_api_async(
-        self,
-        model: str,
-        messages: List[Dict[str, Any]],
-        lm_config: Dict[str, Any],
-        response_model: Optional[BaseModel] = None,
-        use_ephemeral_cache_only: bool = False,
-        reasoning_effort: str = "high",
-        tools: Optional[List[BaseTool]] = None,
-    ) -> BaseLMResponse:
-        assert (
-            lm_config.get("response_model", None) is None
-        ), "response_model is not supported for standard calls"
-        assert not (response_model and tools), "Cannot provide both response_model and tools"
-        used_cache_handler = get_cache_handler(use_ephemeral_cache_only)
-        cache_result = used_cache_handler.hit_managed_cache(
-            model, messages, lm_config=lm_config, tools=tools
-        )
-        if cache_result:
-            assert type(cache_result) in [
-                BaseLMResponse,
-                str,
-            ], f"Expected BaseLMResponse or str, got {type(cache_result)}"
-            return (
-                cache_result
-                if type(cache_result) == BaseLMResponse
-                else BaseLMResponse(
-                    raw_response=cache_result, structured_output=None, tool_calls=None
-                )
-            )
-
-        mistral_messages = [
-            {"role": msg["role"], "content": msg["content"]} for msg in messages
-        ]
-        functions = [tool.to_mistral_tool() for tool in tools] if tools else None
-        params = {
-            "model": model,
-            "messages": mistral_messages,
-            "max_tokens": lm_config.get("max_tokens", 4096),
-            "temperature": lm_config.get(
-                "temperature", SPECIAL_BASE_TEMPS.get(model, 0)
-            ),
-            "stream": False,
-            "tool_choice": "auto" if functions else None,
-
-        }
-        if response_model:
-            params["response_format"] = response_model
-        elif tools:
-            params["tools"] = functions
-
-        async with Mistral(api_key=os.getenv("MISTRAL_API_KEY", "")) as client:
-            response = await client.chat.complete_async(**params)
-
-        message = response.choices[0].message
-        try:
-            raw_response = message.content
-        except AttributeError:
-            raw_response = ""
-
-        tool_calls = []
-        try:
-            if message.tool_calls:
-                tool_calls = [
-                    {
-                        "id": call.id,
-                        "type": "function",
-                        "function": {
-                            "name": call.function.name,
-                            "arguments": call.function.arguments,
-                        },
-                    }
-                    for call in message.tool_calls
-                ]
-        except AttributeError:
-            pass
-
-        lm_response = BaseLMResponse(
-            raw_response=raw_response,
-            structured_output=None,
-            tool_calls=tool_calls if tool_calls else None,
-        )
-        used_cache_handler.add_to_managed_cache(
-            model, messages, lm_config=lm_config, output=lm_response, tools=tools
-        )
-        return lm_response
-
-    # @backoff.on_exception(
-    #     backoff.expo,
-    #     MISTRAL_EXCEPTIONS_TO_RETRY,
-    #     max_tries=BACKOFF_TOLERANCE,
-    #     on_giveup=lambda e: print(e),
-    # )
-    def _hit_api_sync(
-        self,
-        model: str,
-        messages: List[Dict[str, Any]],
-        lm_config: Dict[str, Any],
-        response_model: Optional[BaseModel] = None,
-        use_ephemeral_cache_only: bool = False,
-        reasoning_effort: str = "high",
-        tools: Optional[List[BaseTool]] = None,
-    ) -> BaseLMResponse:
-        assert (
-            lm_config.get("response_model", None) is None
-        ), "response_model is not supported for standard calls"
-        assert not (response_model and tools), "Cannot provide both response_model and tools"
-
-        used_cache_handler = get_cache_handler(use_ephemeral_cache_only)
-        cache_result = used_cache_handler.hit_managed_cache(
-            model, messages, lm_config=lm_config, tools=tools
-        )
-        if cache_result:
-            assert type(cache_result) in [
-                BaseLMResponse,
-                str,
-            ], f"Expected BaseLMResponse or str, got {type(cache_result)}"
-            return (
-                cache_result
-                if type(cache_result) == BaseLMResponse
-                else BaseLMResponse(
-                    raw_response=cache_result, structured_output=None, tool_calls=None
-                )
-            )
-
-        mistral_messages = [
-            {"role": msg["role"], "content": msg["content"]} for msg in messages
-        ]
-        functions = [tool.to_mistral_tool() for tool in tools] if tools else None
-
-        params = {
-            "model": model,
-            "messages": mistral_messages,
-            "max_tokens": lm_config.get("max_tokens", 4096),
-            "temperature": lm_config.get(
-                "temperature", SPECIAL_BASE_TEMPS.get(model, 0)
-            ),
-            "stream": False,
-            "tool_choice": "auto" if functions else None,
-            #"tools": functions,
-        }
-        if response_model:
-            params["response_format"] = response_model
-        elif tools:
-            params["tools"] = functions
-
-        with Mistral(api_key=os.getenv("MISTRAL_API_KEY", "")) as client:
-            response = client.chat.complete(**params)
-
-        message = response.choices[0].message
-        try:
-            raw_response = message.content
-        except AttributeError:
-            raw_response = ""
-
-        tool_calls = []
-        try:
-            if message.tool_calls:
-                tool_calls = [
-                    {
-                        "id": call.id,
-                        "type": "function",
-                        "function": {
-                            "name": call.function.name,
-                            "arguments": call.function.arguments,
-                        },
-                    }
-                    for call in message.tool_calls
-                ]
-        except AttributeError:
-            pass
-
-        lm_response = BaseLMResponse(
-            raw_response=raw_response,
-            structured_output=None,
-            tool_calls=tool_calls if tool_calls else None,
-        )
-        used_cache_handler.add_to_managed_cache(
-            model, messages, lm_config=lm_config, output=lm_response, tools=tools
-        )
-        return lm_response
-
-    async def _hit_api_async_structured_output(
-        self,
-        model: str,
-        messages: List[Dict[str, Any]],
-        response_model: BaseModel,
-        temperature: float,
-        use_ephemeral_cache_only: bool = False,
-    ) -> BaseLMResponse:
-        try:
-            mistral_messages = [
-                {"role": msg["role"], "content": msg["content"]} for msg in messages
-            ]
-            async with Mistral(api_key=os.getenv("MISTRAL_API_KEY", "")) as client:
-                response = await client.chat.complete_async(
-                    model=model,
-                    messages=mistral_messages,
-                    max_tokens=4096,
-                    temperature=temperature,
-                    stream=False,
-                )
-            result = response.choices[0].message.content
-            parsed = json.loads(result)
-            lm_response = BaseLMResponse(
-                raw_response="",
-                structured_output=response_model(**parsed),
-                tool_calls=None,
-            )
-            return lm_response
-        except (json.JSONDecodeError, pydantic.ValidationError):
-            if self._openai_fallback is None:
-                self._openai_fallback = OpenAIStructuredOutputClient()
-            return await self._openai_fallback._hit_api_async_structured_output(
-                model="gpt-4o",
-                messages=messages,
-                response_model=response_model,
-                temperature=temperature,
-                use_ephemeral_cache_only=use_ephemeral_cache_only,
-            )
-
-    def _hit_api_sync_structured_output(
-        self,
-        model: str,
-        messages: List[Dict[str, Any]],
-        response_model: BaseModel,
-        temperature: float,
-        use_ephemeral_cache_only: bool = False,
-    ) -> BaseLMResponse:
-        try:
-            mistral_messages = [
-                {"role": msg["role"], "content": msg["content"]} for msg in messages
-            ]
-            with Mistral(api_key=os.getenv("MISTRAL_API_KEY", "")) as client:
-                response = client.chat.complete(
-                    model=model,
-                    messages=mistral_messages,
-                    max_tokens=4096,
-                    temperature=temperature,
-                    stream=False,
-                )
-            result = response.choices[0].message.content
-            parsed = json.loads(result)
-            lm_response = BaseLMResponse(
-                raw_response="",
-                structured_output=response_model(**parsed),
-                tool_calls=None,
-            )
-            return lm_response
-        except (json.JSONDecodeError, pydantic.ValidationError):
-            print("WARNING - Falling back to OpenAI - THIS IS SLOW")
-            if self._openai_fallback is None:
-                self._openai_fallback = OpenAIStructuredOutputClient()
-            return self._openai_fallback._hit_api_sync_structured_output(
-                model="gpt-4o",
-                messages=messages,
-                response_model=response_model,
-                temperature=temperature,
-                use_ephemeral_cache_only=use_ephemeral_cache_only,
-            )
-
-
-if __name__ == "__main__":
-    import asyncio
-
-    from pydantic import BaseModel
-
-    class TestModel(BaseModel):
-        name: str
-
-    client = MistralAPI(used_for_structured_outputs=True, exceptions_to_retry=[])
-    import time
-
-    t = time.time()
-
-    async def run_async():
-        response = await client._hit_api_async_structured_output(
-            model="mistral-large-latest",
-            messages=[{"role": "user", "content": "What is the capital of the moon?"}],
-            response_model=TestModel,
-            temperature=0.0,
-        )
-        print(response)
-        return response
-
-    response = asyncio.run(run_async())
-    t2 = time.time()
-    print(f"Got {len(response.name)} chars in {t2-t} seconds")