synth-ai 0.1.0.dev38__py3-none-any.whl → 0.1.0.dev49__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. synth_ai/__init__.py +3 -1
  2. {synth_ai-0.1.0.dev38.dist-info → synth_ai-0.1.0.dev49.dist-info}/METADATA +12 -11
  3. synth_ai-0.1.0.dev49.dist-info/RECORD +6 -0
  4. {synth_ai-0.1.0.dev38.dist-info → synth_ai-0.1.0.dev49.dist-info}/WHEEL +1 -1
  5. synth_ai-0.1.0.dev49.dist-info/top_level.txt +1 -0
  6. private_tests/try_synth_sdk.py +0 -1
  7. public_tests/test_agent.py +0 -538
  8. public_tests/test_all_structured_outputs.py +0 -196
  9. public_tests/test_anthropic_structured_outputs.py +0 -0
  10. public_tests/test_deepseek_structured_outputs.py +0 -0
  11. public_tests/test_deepseek_tools.py +0 -64
  12. public_tests/test_gemini_output.py +0 -188
  13. public_tests/test_gemini_structured_outputs.py +0 -106
  14. public_tests/test_models.py +0 -183
  15. public_tests/test_openai_structured_outputs.py +0 -106
  16. public_tests/test_reasoning_effort.py +0 -75
  17. public_tests/test_reasoning_models.py +0 -92
  18. public_tests/test_recursive_structured_outputs.py +0 -180
  19. public_tests/test_structured.py +0 -137
  20. public_tests/test_structured_outputs.py +0 -109
  21. public_tests/test_synth_sdk.py +0 -384
  22. public_tests/test_text.py +0 -160
  23. public_tests/test_tools.py +0 -319
  24. synth_ai/zyk/__init__.py +0 -3
  25. synth_ai/zyk/lms/__init__.py +0 -0
  26. synth_ai/zyk/lms/caching/__init__.py +0 -0
  27. synth_ai/zyk/lms/caching/constants.py +0 -1
  28. synth_ai/zyk/lms/caching/dbs.py +0 -0
  29. synth_ai/zyk/lms/caching/ephemeral.py +0 -72
  30. synth_ai/zyk/lms/caching/handler.py +0 -142
  31. synth_ai/zyk/lms/caching/initialize.py +0 -13
  32. synth_ai/zyk/lms/caching/persistent.py +0 -83
  33. synth_ai/zyk/lms/config.py +0 -8
  34. synth_ai/zyk/lms/core/__init__.py +0 -0
  35. synth_ai/zyk/lms/core/all.py +0 -47
  36. synth_ai/zyk/lms/core/exceptions.py +0 -9
  37. synth_ai/zyk/lms/core/main.py +0 -314
  38. synth_ai/zyk/lms/core/vendor_clients.py +0 -85
  39. synth_ai/zyk/lms/cost/__init__.py +0 -0
  40. synth_ai/zyk/lms/cost/monitor.py +0 -1
  41. synth_ai/zyk/lms/cost/statefulness.py +0 -1
  42. synth_ai/zyk/lms/structured_outputs/__init__.py +0 -0
  43. synth_ai/zyk/lms/structured_outputs/handler.py +0 -442
  44. synth_ai/zyk/lms/structured_outputs/inject.py +0 -314
  45. synth_ai/zyk/lms/structured_outputs/rehabilitate.py +0 -187
  46. synth_ai/zyk/lms/tools/base.py +0 -104
  47. synth_ai/zyk/lms/vendors/__init__.py +0 -0
  48. synth_ai/zyk/lms/vendors/base.py +0 -31
  49. synth_ai/zyk/lms/vendors/constants.py +0 -22
  50. synth_ai/zyk/lms/vendors/core/__init__.py +0 -0
  51. synth_ai/zyk/lms/vendors/core/anthropic_api.py +0 -413
  52. synth_ai/zyk/lms/vendors/core/gemini_api.py +0 -306
  53. synth_ai/zyk/lms/vendors/core/mistral_api.py +0 -327
  54. synth_ai/zyk/lms/vendors/core/openai_api.py +0 -185
  55. synth_ai/zyk/lms/vendors/local/__init__.py +0 -0
  56. synth_ai/zyk/lms/vendors/local/ollama.py +0 -0
  57. synth_ai/zyk/lms/vendors/openai_standard.py +0 -375
  58. synth_ai/zyk/lms/vendors/retries.py +0 -3
  59. synth_ai/zyk/lms/vendors/supported/__init__.py +0 -0
  60. synth_ai/zyk/lms/vendors/supported/deepseek.py +0 -73
  61. synth_ai/zyk/lms/vendors/supported/groq.py +0 -16
  62. synth_ai/zyk/lms/vendors/supported/ollama.py +0 -14
  63. synth_ai/zyk/lms/vendors/supported/together.py +0 -11
  64. synth_ai-0.1.0.dev38.dist-info/RECORD +0 -67
  65. synth_ai-0.1.0.dev38.dist-info/top_level.txt +0 -4
  66. tests/test_agent.py +0 -538
  67. tests/test_recursive_structured_outputs.py +0 -180
  68. tests/test_structured_outputs.py +0 -100
  69. {synth_ai-0.1.0.dev38.dist-info → synth_ai-0.1.0.dev49.dist-info}/licenses/LICENSE +0 -0
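
Taken together, entries 24-63 remove the entire synth_ai.zyk namespace (caching, structured outputs, tools, and the per-vendor clients) from the wheel, along with the public_tests/ and tests/ suites that exercised it. As a hedged illustration only, code written against 0.1.0.dev38 that imported the deleted vendor clients directly, e.g.

# Import paths as they existed in 0.1.0.dev38; both modules are gone in 0.1.0.dev49.
from synth_ai.zyk.lms.vendors.core.gemini_api import GeminiAPI
from synth_ai.zyk.lms.vendors.core.mistral_api import MistralAPI

would need to stay pinned to the older release or migrate to whatever replaces these modules; this diff alone does not show a replacement. The two deleted files below are reproduced in full.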
--- a/synth_ai/zyk/lms/vendors/core/gemini_api.py
+++ /dev/null
@@ -1,306 +0,0 @@
- import json
- import logging
- import os
- import warnings
- from typing import Any, Dict, List, Optional, Tuple, Type
-
- import google.generativeai as genai
- from google.api_core.exceptions import ResourceExhausted
- from google.generativeai.types import HarmBlockThreshold, HarmCategory, Tool
-
- from synth_ai.zyk.lms.caching.initialize import (
-     get_cache_handler,
- )
- from synth_ai.zyk.lms.tools.base import BaseTool
- from synth_ai.zyk.lms.vendors.base import BaseLMResponse, VendorBase
- from synth_ai.zyk.lms.vendors.constants import SPECIAL_BASE_TEMPS
- from synth_ai.zyk.lms.vendors.retries import BACKOFF_TOLERANCE, backoff
-
- GEMINI_EXCEPTIONS_TO_RETRY: Tuple[Type[Exception], ...] = (ResourceExhausted,)
- logging.getLogger("google.generativeai").setLevel(logging.ERROR)
- os.environ["GRPC_VERBOSITY"] = "ERROR"
- # Suppress TensorFlow logging
- os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
- warnings.filterwarnings("ignore")
-
- SAFETY_SETTINGS = {
-     HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
-     HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
-     HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
-     HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
- }
-
-
- class GeminiAPI(VendorBase):
-     used_for_structured_outputs: bool = True
-     exceptions_to_retry: Tuple[Type[Exception], ...] = GEMINI_EXCEPTIONS_TO_RETRY
-
-     def __init__(
-         self,
-         exceptions_to_retry: Tuple[Type[Exception], ...] = GEMINI_EXCEPTIONS_TO_RETRY,
-         used_for_structured_outputs: bool = False,
-     ):
-         self.used_for_structured_outputs = used_for_structured_outputs
-         self.exceptions_to_retry = exceptions_to_retry
-
-     def _convert_messages_to_contents(
-         self, messages: List[Dict[str, Any]]
-     ) -> List[Dict[str, Any]]:
-         contents = []
-         system_instruction = None
-         for message in messages:
-             if message["role"] == "system":
-                 system_instruction = (
-                     f"<instructions>\n{message['content']}\n</instructions>"
-                 )
-                 continue
-             elif system_instruction:
-                 text = system_instruction + "\n" + message["content"]
-             else:
-                 text = message["content"]
-             contents.append(
-                 {
-                     "role": message["role"],
-                     "parts": [{"text": text}],
-                 }
-             )
-         return contents
-
-     def _convert_tools_to_gemini_format(self, tools: List[Any]) -> Tool:
-         function_declarations = []
-         for tool in tools:
-             # Try to use to_gemini_tool method if available, otherwise assume it's a dict
-             try:
-                 function_declarations.append(tool.to_gemini_tool())
-             except AttributeError:
-                 # If tool is a properly formatted dict, use it directly
-                 if "name" in tool and "parameters" in tool:
-                     function_declarations.append(tool)
-                 else:
-                     raise ValueError(
-                         f"Unsupported tool format. Tools must be BaseTool instances or properly formatted dictionaries."
-                     )
-         return Tool(function_declarations=function_declarations)
-
-     def _convert_args_to_dict(self, args):
-         """
-         Recursively convert Gemini's args objects to Python dictionaries.
-         """
-         # Try to convert dict-like objects
-         try:
-             return {k: self._convert_args_to_dict(v) for k, v in args.items()}
-         except (AttributeError, TypeError):
-             # Try to convert list-like objects
-             try:
-                 if isinstance(args, (str, bytes)):
-                     return args
-                 return [self._convert_args_to_dict(item) for item in args]
-             except (TypeError, AttributeError):
-                 # Base case: primitive value
-                 return args
-
-     async def _private_request_async(
-         self,
-         messages: List[Dict],
-         temperature: float = 0,
-         model_name: str = "gemini-1.5-flash",
-         reasoning_effort: str = "high",
-         tools: Optional[List[BaseTool]] = None,
-         lm_config: Optional[Dict[str, Any]] = None,
-     ) -> Tuple[str, Optional[List[Dict]]]:
-         generation_config = {
-             "temperature": temperature,
-         }
-         # Add max_output_tokens if max_tokens is in lm_config
-         if lm_config and "max_tokens" in lm_config:
-             generation_config["max_output_tokens"] = lm_config["max_tokens"]
-
-         tools_config = None
-         if tools:
-             tools_config = self._convert_tools_to_gemini_format(tools)
-
-         # Extract tool_config from lm_config if provided
-         tool_config = (
-             lm_config.get("tool_config")
-             if lm_config
-             else {"function_calling_config": {"mode": "any"}}
-         )
-
-         code_generation_model = genai.GenerativeModel(
-             model_name=model_name,
-             generation_config=generation_config,
-             tools=tools_config if tools_config else None,
-             tool_config=tool_config,
-         )
-
-         contents = self._convert_messages_to_contents(messages)
-         result = await code_generation_model.generate_content_async(
-             contents=contents,
-             safety_settings=SAFETY_SETTINGS,
-         )
-
-         text = result.candidates[0].content.parts[0].text
-         tool_calls = []
-         for part in result.candidates[0].content.parts:
-             if part.function_call:
-                 # Convert complex objects to Python dictionaries recursively
-                 args_dict = self._convert_args_to_dict(part.function_call.args)
-                 # Ensure serializable arguments
-                 tool_calls.append(
-                     {
-                         "id": f"call_{len(tool_calls) + 1}", # Generate unique IDs
-                         "type": "function",
-                         "function": {
-                             "name": part.function_call.name,
-                             "arguments": json.dumps(args_dict),
-                         },
-                     }
-                 )
-         return text, tool_calls if tool_calls else None
-
-     def _private_request_sync(
-         self,
-         messages: List[Dict],
-         temperature: float = 0,
-         model_name: str = "gemini-1.5-flash",
-         reasoning_effort: str = "high",
-         tools: Optional[List[BaseTool]] = None,
-         lm_config: Optional[Dict[str, Any]] = None,
-     ) -> Tuple[str, Optional[List[Dict]]]:
-         generation_config = {
-             "temperature": temperature,
-         }
-         # Add max_output_tokens if max_tokens is in lm_config
-         if lm_config and "max_tokens" in lm_config:
-             generation_config["max_output_tokens"] = lm_config["max_tokens"]
-
-         tools_config = None
-         if tools:
-             tools_config = self._convert_tools_to_gemini_format(tools)
-
-         # Extract tool_config from lm_config if provided
-         tool_config = (
-             lm_config.get("tool_config")
-             if lm_config
-             else {"function_calling_config": {"mode": "any"}}
-         )
-
-         code_generation_model = genai.GenerativeModel(
-             model_name=model_name,
-             generation_config=generation_config,
-             tools=tools_config if tools_config else None,
-             tool_config=tool_config,
-         )
-
-         contents = self._convert_messages_to_contents(messages)
-         result = code_generation_model.generate_content(
-             contents=contents,
-             safety_settings=SAFETY_SETTINGS,
-         )
-
-         text = result.candidates[0].content.parts[0].text
-         tool_calls = []
-         for part in result.candidates[0].content.parts:
-             if part.function_call:
-                 # Convert complex objects to Python dictionaries recursively
-                 args_dict = self._convert_args_to_dict(part.function_call.args)
-                 # Ensure serializable arguments
-                 tool_calls.append(
-                     {
-                         "id": f"call_{len(tool_calls) + 1}", # Generate unique IDs
-                         "type": "function",
-                         "function": {
-                             "name": part.function_call.name,
-                             "arguments": json.dumps(args_dict),
-                         },
-                     }
-                 )
-         return text, tool_calls if tool_calls else None
-
-     @backoff.on_exception(
-         backoff.expo,
-         exceptions_to_retry,
-         max_tries=BACKOFF_TOLERANCE,
-         on_giveup=lambda e: print(e),
-     )
-     async def _hit_api_async(
-         self,
-         model: str,
-         messages: List[Dict[str, Any]],
-         lm_config: Dict[str, Any],
-         use_ephemeral_cache_only: bool = False,
-         reasoning_effort: str = "high",
-         tools: Optional[List[BaseTool]] = None,
-     ) -> BaseLMResponse:
-         assert (
-             lm_config.get("response_model", None) is None
-         ), "response_model is not supported for standard calls"
-         used_cache_handler = get_cache_handler(use_ephemeral_cache_only)
-         cache_result = used_cache_handler.hit_managed_cache(
-             model, messages, lm_config=lm_config, tools=tools, reasoning_effort=reasoning_effort
-         )
-         if cache_result:
-             return cache_result
-
-         raw_response, tool_calls = await self._private_request_async(
-             messages,
-             temperature=lm_config.get("temperature", SPECIAL_BASE_TEMPS.get(model, 0)),
-             reasoning_effort=reasoning_effort,
-             tools=tools,
-         )
-
-         lm_response = BaseLMResponse(
-             raw_response=raw_response,
-             structured_output=None,
-             tool_calls=tool_calls,
-         )
-
-         used_cache_handler.add_to_managed_cache(
-             model, messages, lm_config=lm_config, output=lm_response, tools=tools, reasoning_effort=reasoning_effort
-         )
-         return lm_response
-
-     @backoff.on_exception(
-         backoff.expo,
-         exceptions_to_retry,
-         max_tries=BACKOFF_TOLERANCE,
-         on_giveup=lambda e: print(e),
-     )
-     def _hit_api_sync(
-         self,
-         model: str,
-         messages: List[Dict[str, Any]],
-         lm_config: Dict[str, Any],
-         use_ephemeral_cache_only: bool = False,
-         reasoning_effort: str = "high",
-         tools: Optional[List[BaseTool]] = None,
-     ) -> BaseLMResponse:
-         assert (
-             lm_config.get("response_model", None) is None
-         ), "response_model is not supported for standard calls"
-         used_cache_handler = get_cache_handler(
-             use_ephemeral_cache_only=use_ephemeral_cache_only
-         )
-         cache_result = used_cache_handler.hit_managed_cache(
-             model, messages, lm_config=lm_config, tools=tools, reasoning_effort=reasoning_effort
-         )
-         if cache_result:
-             return cache_result
-
-         raw_response, tool_calls = self._private_request_sync(
-             messages,
-             temperature=lm_config.get("temperature", SPECIAL_BASE_TEMPS.get(model, 0)),
-             reasoning_effort=reasoning_effort,
-             tools=tools,
-         )
-
-         lm_response = BaseLMResponse(
-             raw_response=raw_response,
-             structured_output=None,
-             tool_calls=tool_calls,
-         )
-
-         used_cache_handler.add_to_managed_cache(
-             model, messages, lm_config=lm_config, output=lm_response, tools=tools, reasoning_effort=reasoning_effort
-         )
-         return lm_response
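
The heart of the deleted Gemini client is translating Gemini function calls into OpenAI-style tool-call dicts, which hinges on recursively coercing the proto map/list args into JSON-serializable Python values before json.dumps. A minimal standalone sketch of that coercion step, with plain Python containers standing in for Gemini's proto objects (the free function name is illustrative, not part of the package):

import json

def convert_args_to_dict(args):
    # Dict-like objects (e.g. proto maps) become plain dicts.
    try:
        return {k: convert_args_to_dict(v) for k, v in args.items()}
    except (AttributeError, TypeError):
        # List-like objects become plain lists; strings/bytes pass through unchanged.
        try:
            if isinstance(args, (str, bytes)):
                return args
            return [convert_args_to_dict(item) for item in args]
        except (TypeError, AttributeError):
            # Primitive values are returned as-is.
            return args

print(json.dumps(convert_args_to_dict({"city": "Paris", "days": [1, 2]})))
# {"city": "Paris", "days": [1, 2]}

The resulting JSON string is what the removed _private_request_* methods place under "function": {"arguments": ...} in each tool-call dict.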
--- a/synth_ai/zyk/lms/vendors/core/mistral_api.py
+++ /dev/null
@@ -1,327 +0,0 @@
- import json
- import os
- from typing import Any, Dict, List, Optional, Tuple, Type
-
- import pydantic
- from mistralai import Mistral # use Mistral as both sync and async client
- from pydantic import BaseModel
-
- from synth_ai.zyk.lms.caching.initialize import get_cache_handler
- from synth_ai.zyk.lms.tools.base import BaseTool
- from synth_ai.zyk.lms.vendors.base import BaseLMResponse, VendorBase
- from synth_ai.zyk.lms.vendors.constants import SPECIAL_BASE_TEMPS
- from synth_ai.zyk.lms.vendors.core.openai_api import OpenAIStructuredOutputClient
-
- # Since the mistralai package doesn't expose an exceptions module,
- # we fallback to catching all Exceptions for retry.
- MISTRAL_EXCEPTIONS_TO_RETRY: Tuple[Type[Exception], ...] = (Exception,)
-
-
- class MistralAPI(VendorBase):
-     used_for_structured_outputs: bool = True
-     exceptions_to_retry: Tuple = MISTRAL_EXCEPTIONS_TO_RETRY
-     _openai_fallback: Any
-
-     def __init__(
-         self,
-         exceptions_to_retry: Tuple[Type[Exception], ...] = MISTRAL_EXCEPTIONS_TO_RETRY,
-         used_for_structured_outputs: bool = False,
-     ):
-         self.used_for_structured_outputs = used_for_structured_outputs
-         self.exceptions_to_retry = exceptions_to_retry
-         self._openai_fallback = None
-
-     # @backoff.on_exception(
-     #     backoff.expo,
-     #     MISTRAL_EXCEPTIONS_TO_RETRY,
-     #     max_tries=BACKOFF_TOLERANCE,
-     #     on_giveup=lambda e: print(e),
-     # )
-     async def _hit_api_async(
-         self,
-         model: str,
-         messages: List[Dict[str, Any]],
-         lm_config: Dict[str, Any],
-         response_model: Optional[BaseModel] = None,
-         use_ephemeral_cache_only: bool = False,
-         reasoning_effort: str = "high",
-         tools: Optional[List[BaseTool]] = None,
-     ) -> BaseLMResponse:
-         assert (
-             lm_config.get("response_model", None) is None
-         ), "response_model is not supported for standard calls"
-         assert not (response_model and tools), "Cannot provide both response_model and tools"
-         used_cache_handler = get_cache_handler(use_ephemeral_cache_only)
-         cache_result = used_cache_handler.hit_managed_cache(
-             model, messages, lm_config=lm_config, tools=tools
-         )
-         if cache_result:
-             assert type(cache_result) in [
-                 BaseLMResponse,
-                 str,
-             ], f"Expected BaseLMResponse or str, got {type(cache_result)}"
-             return (
-                 cache_result
-                 if type(cache_result) == BaseLMResponse
-                 else BaseLMResponse(
-                     raw_response=cache_result, structured_output=None, tool_calls=None
-                 )
-             )
-
-         mistral_messages = [
-             {"role": msg["role"], "content": msg["content"]} for msg in messages
-         ]
-         functions = [tool.to_mistral_tool() for tool in tools] if tools else None
-         params = {
-             "model": model,
-             "messages": mistral_messages,
-             "max_tokens": lm_config.get("max_tokens", 4096),
-             "temperature": lm_config.get(
-                 "temperature", SPECIAL_BASE_TEMPS.get(model, 0)
-             ),
-             "stream": False,
-             "tool_choice": "auto" if functions else None,
-
-         }
-         if response_model:
-             params["response_format"] = response_model
-         elif tools:
-             params["tools"] = functions
-
-         async with Mistral(api_key=os.getenv("MISTRAL_API_KEY", "")) as client:
-             response = await client.chat.complete_async(**params)
-
-         message = response.choices[0].message
-         try:
-             raw_response = message.content
-         except AttributeError:
-             raw_response = ""
-
-         tool_calls = []
-         try:
-             if message.tool_calls:
-                 tool_calls = [
-                     {
-                         "id": call.id,
-                         "type": "function",
-                         "function": {
-                             "name": call.function.name,
-                             "arguments": call.function.arguments,
-                         },
-                     }
-                     for call in message.tool_calls
-                 ]
-         except AttributeError:
-             pass
-
-         lm_response = BaseLMResponse(
-             raw_response=raw_response,
-             structured_output=None,
-             tool_calls=tool_calls if tool_calls else None,
-         )
-         used_cache_handler.add_to_managed_cache(
-             model, messages, lm_config=lm_config, output=lm_response, tools=tools
-         )
-         return lm_response
-
-     # @backoff.on_exception(
-     #     backoff.expo,
-     #     MISTRAL_EXCEPTIONS_TO_RETRY,
-     #     max_tries=BACKOFF_TOLERANCE,
-     #     on_giveup=lambda e: print(e),
-     # )
-     def _hit_api_sync(
-         self,
-         model: str,
-         messages: List[Dict[str, Any]],
-         lm_config: Dict[str, Any],
-         response_model: Optional[BaseModel] = None,
-         use_ephemeral_cache_only: bool = False,
-         reasoning_effort: str = "high",
-         tools: Optional[List[BaseTool]] = None,
-     ) -> BaseLMResponse:
-         assert (
-             lm_config.get("response_model", None) is None
-         ), "response_model is not supported for standard calls"
-         assert not (response_model and tools), "Cannot provide both response_model and tools"
-
-         used_cache_handler = get_cache_handler(use_ephemeral_cache_only)
-         cache_result = used_cache_handler.hit_managed_cache(
-             model, messages, lm_config=lm_config, tools=tools
-         )
-         if cache_result:
-             assert type(cache_result) in [
-                 BaseLMResponse,
-                 str,
-             ], f"Expected BaseLMResponse or str, got {type(cache_result)}"
-             return (
-                 cache_result
-                 if type(cache_result) == BaseLMResponse
-                 else BaseLMResponse(
-                     raw_response=cache_result, structured_output=None, tool_calls=None
-                 )
-             )
-
-         mistral_messages = [
-             {"role": msg["role"], "content": msg["content"]} for msg in messages
-         ]
-         functions = [tool.to_mistral_tool() for tool in tools] if tools else None
-
-         params = {
-             "model": model,
-             "messages": mistral_messages,
-             "max_tokens": lm_config.get("max_tokens", 4096),
-             "temperature": lm_config.get(
-                 "temperature", SPECIAL_BASE_TEMPS.get(model, 0)
-             ),
-             "stream": False,
-             "tool_choice": "auto" if functions else None,
-             #"tools": functions,
-         }
-         if response_model:
-             params["response_format"] = response_model
-         elif tools:
-             params["tools"] = functions
-
-         with Mistral(api_key=os.getenv("MISTRAL_API_KEY", "")) as client:
-             response = client.chat.complete(**params)
-
-         message = response.choices[0].message
-         try:
-             raw_response = message.content
-         except AttributeError:
-             raw_response = ""
-
-         tool_calls = []
-         try:
-             if message.tool_calls:
-                 tool_calls = [
-                     {
-                         "id": call.id,
-                         "type": "function",
-                         "function": {
-                             "name": call.function.name,
-                             "arguments": call.function.arguments,
-                         },
-                     }
-                     for call in message.tool_calls
-                 ]
-         except AttributeError:
-             pass
-
-         lm_response = BaseLMResponse(
-             raw_response=raw_response,
-             structured_output=None,
-             tool_calls=tool_calls if tool_calls else None,
-         )
-         used_cache_handler.add_to_managed_cache(
-             model, messages, lm_config=lm_config, output=lm_response, tools=tools
-         )
-         return lm_response
-
-     async def _hit_api_async_structured_output(
-         self,
-         model: str,
-         messages: List[Dict[str, Any]],
-         response_model: BaseModel,
-         temperature: float,
-         use_ephemeral_cache_only: bool = False,
-     ) -> BaseLMResponse:
-         try:
-             mistral_messages = [
-                 {"role": msg["role"], "content": msg["content"]} for msg in messages
-             ]
-             async with Mistral(api_key=os.getenv("MISTRAL_API_KEY", "")) as client:
-                 response = await client.chat.complete_async(
-                     model=model,
-                     messages=mistral_messages,
-                     max_tokens=4096,
-                     temperature=temperature,
-                     stream=False,
-                 )
-             result = response.choices[0].message.content
-             parsed = json.loads(result)
-             lm_response = BaseLMResponse(
-                 raw_response="",
-                 structured_output=response_model(**parsed),
-                 tool_calls=None,
-             )
-             return lm_response
-         except (json.JSONDecodeError, pydantic.ValidationError):
-             if self._openai_fallback is None:
-                 self._openai_fallback = OpenAIStructuredOutputClient()
-             return await self._openai_fallback._hit_api_async_structured_output(
-                 model="gpt-4o",
-                 messages=messages,
-                 response_model=response_model,
-                 temperature=temperature,
-                 use_ephemeral_cache_only=use_ephemeral_cache_only,
-             )
-
-     def _hit_api_sync_structured_output(
-         self,
-         model: str,
-         messages: List[Dict[str, Any]],
-         response_model: BaseModel,
-         temperature: float,
-         use_ephemeral_cache_only: bool = False,
-     ) -> BaseLMResponse:
-         try:
-             mistral_messages = [
-                 {"role": msg["role"], "content": msg["content"]} for msg in messages
-             ]
-             with Mistral(api_key=os.getenv("MISTRAL_API_KEY", "")) as client:
-                 response = client.chat.complete(
-                     model=model,
-                     messages=mistral_messages,
-                     max_tokens=4096,
-                     temperature=temperature,
-                     stream=False,
-                 )
-             result = response.choices[0].message.content
-             parsed = json.loads(result)
-             lm_response = BaseLMResponse(
-                 raw_response="",
-                 structured_output=response_model(**parsed),
-                 tool_calls=None,
-             )
-             return lm_response
-         except (json.JSONDecodeError, pydantic.ValidationError):
-             print("WARNING - Falling back to OpenAI - THIS IS SLOW")
-             if self._openai_fallback is None:
-                 self._openai_fallback = OpenAIStructuredOutputClient()
-             return self._openai_fallback._hit_api_sync_structured_output(
-                 model="gpt-4o",
-                 messages=messages,
-                 response_model=response_model,
-                 temperature=temperature,
-                 use_ephemeral_cache_only=use_ephemeral_cache_only,
-             )
-
-
- if __name__ == "__main__":
-     import asyncio
-
-     from pydantic import BaseModel
-
-     class TestModel(BaseModel):
-         name: str
-
-     client = MistralAPI(used_for_structured_outputs=True, exceptions_to_retry=[])
-     import time
-
-     t = time.time()
-
-     async def run_async():
-         response = await client._hit_api_async_structured_output(
-             model="mistral-large-latest",
-             messages=[{"role": "user", "content": "What is the capital of the moon?"}],
-             response_model=TestModel,
-             temperature=0.0,
-         )
-         print(response)
-         return response
-
-     response = asyncio.run(run_async())
-     t2 = time.time()
-     print(f"Got {len(response.name)} chars in {t2-t} seconds")
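
The deleted MistralAPI's structured-output path reduces to: request a completion, JSON-decode the message content, validate it into the caller's pydantic response_model, and fall back to the OpenAI client when decoding or validation fails. A minimal sketch of just that parse-and-validate decision; the model name (Capital) and helper (parse_or_none) are illustrative and not part of the package:

import json
from typing import Optional, Type

from pydantic import BaseModel, ValidationError

class Capital(BaseModel):
    name: str

def parse_or_none(raw: str, response_model: Type[BaseModel]) -> Optional[BaseModel]:
    # Mirror the removed client's parse step: JSON-decode, then validate into the model.
    # Returning None marks the case where the removed code switched to its OpenAI fallback.
    try:
        return response_model(**json.loads(raw))
    except (json.JSONDecodeError, ValidationError):
        return None

print(parse_or_none('{"name": "Tycho"}', Capital))  # validated Capital instance
print(parse_or_none("not json", Capital))           # None -> would trigger the fallback

Catching only json.JSONDecodeError and pydantic.ValidationError, as the removed code does, keeps genuine API errors visible while routing malformed or schema-violating output to the fallback path.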