synth-ai 0.1.0.dev28__py3-none-any.whl → 0.1.0.dev30__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (38)
  1. public_tests/test_agent.py +11 -11
  2. public_tests/test_all_structured_outputs.py +32 -37
  3. public_tests/test_anthropic_structured_outputs.py +0 -0
  4. public_tests/test_deepseek_structured_outputs.py +0 -0
  5. public_tests/test_deepseek_tools.py +64 -0
  6. public_tests/test_gemini_structured_outputs.py +106 -0
  7. public_tests/test_models.py +27 -27
  8. public_tests/test_openai_structured_outputs.py +106 -0
  9. public_tests/test_reasoning_models.py +9 -7
  10. public_tests/test_recursive_structured_outputs.py +30 -30
  11. public_tests/test_structured.py +137 -0
  12. public_tests/test_structured_outputs.py +22 -13
  13. public_tests/test_text.py +160 -0
  14. public_tests/test_tools.py +300 -0
  15. synth_ai/__init__.py +1 -4
  16. synth_ai/zyk/__init__.py +2 -2
  17. synth_ai/zyk/lms/caching/ephemeral.py +54 -32
  18. synth_ai/zyk/lms/caching/handler.py +43 -15
  19. synth_ai/zyk/lms/caching/persistent.py +55 -27
  20. synth_ai/zyk/lms/core/main.py +29 -16
  21. synth_ai/zyk/lms/core/vendor_clients.py +1 -1
  22. synth_ai/zyk/lms/structured_outputs/handler.py +79 -45
  23. synth_ai/zyk/lms/structured_outputs/rehabilitate.py +3 -2
  24. synth_ai/zyk/lms/tools/base.py +104 -0
  25. synth_ai/zyk/lms/vendors/base.py +22 -6
  26. synth_ai/zyk/lms/vendors/core/anthropic_api.py +130 -95
  27. synth_ai/zyk/lms/vendors/core/gemini_api.py +153 -34
  28. synth_ai/zyk/lms/vendors/core/mistral_api.py +160 -54
  29. synth_ai/zyk/lms/vendors/core/openai_api.py +64 -53
  30. synth_ai/zyk/lms/vendors/openai_standard.py +197 -41
  31. synth_ai/zyk/lms/vendors/supported/deepseek.py +55 -0
  32. {synth_ai-0.1.0.dev28.dist-info → synth_ai-0.1.0.dev30.dist-info}/METADATA +2 -5
  33. synth_ai-0.1.0.dev30.dist-info/RECORD +65 -0
  34. public_tests/test_sonnet_thinking.py +0 -217
  35. synth_ai-0.1.0.dev28.dist-info/RECORD +0 -57
  36. {synth_ai-0.1.0.dev28.dist-info → synth_ai-0.1.0.dev30.dist-info}/WHEEL +0 -0
  37. {synth_ai-0.1.0.dev28.dist-info → synth_ai-0.1.0.dev30.dist-info}/licenses/LICENSE +0 -0
  38. {synth_ai-0.1.0.dev28.dist-info → synth_ai-0.1.0.dev30.dist-info}/top_level.txt +0 -0
synth_ai/zyk/lms/vendors/core/mistral_api.py
@@ -1,16 +1,16 @@
 import json
 import os
-from typing import Any, Dict, List, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type

 import pydantic
 from mistralai import Mistral  # use Mistral as both sync and async client
 from pydantic import BaseModel

 from synth_ai.zyk.lms.caching.initialize import get_cache_handler
-from synth_ai.zyk.lms.vendors.base import VendorBase
+from synth_ai.zyk.lms.tools.base import BaseTool
+from synth_ai.zyk.lms.vendors.base import BaseLMResponse, VendorBase
 from synth_ai.zyk.lms.vendors.constants import SPECIAL_BASE_TEMPS
 from synth_ai.zyk.lms.vendors.core.openai_api import OpenAIStructuredOutputClient
-from synth_ai.zyk.lms.vendors.retries import BACKOFF_TOLERANCE, backoff

 # Since the mistralai package doesn't expose an exceptions module,
 # we fallback to catching all Exceptions for retry.
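The rewritten imports pull in BaseTool and BaseLMResponse, whose definitions are not shown in this diff (they live in the new synth_ai/zyk/lms/tools/base.py and in synth_ai/zyk/lms/vendors/base.py). Judging only from how BaseLMResponse is constructed in the hunks below, it behaves roughly like the following container; a hedged sketch inferred from usage, not the library's actual source:

# Sketch only: the field names come from constructor calls in this diff;
# the types and the dataclass form are assumptions.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from pydantic import BaseModel


@dataclass
class BaseLMResponse:
    raw_response: str                                   # plain-text completion, "" when absent
    structured_output: Optional[BaseModel] = None       # parsed response_model instance, if any
    tool_calls: Optional[List[Dict[str, Any]]] = None   # OpenAI-style tool-call dicts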
@@ -31,97 +31,193 @@ class MistralAPI(VendorBase):
         self.exceptions_to_retry = exceptions_to_retry
         self._openai_fallback = None

-    @backoff.on_exception(
-        backoff.expo,
-        MISTRAL_EXCEPTIONS_TO_RETRY,
-        max_tries=BACKOFF_TOLERANCE,
-        on_giveup=lambda e: print(e),
-    )
+    # @backoff.on_exception(
+    #     backoff.expo,
+    #     MISTRAL_EXCEPTIONS_TO_RETRY,
+    #     max_tries=BACKOFF_TOLERANCE,
+    #     on_giveup=lambda e: print(e),
+    # )
     async def _hit_api_async(
         self,
         model: str,
         messages: List[Dict[str, Any]],
         lm_config: Dict[str, Any],
+        response_model: Optional[BaseModel] = None,
         use_ephemeral_cache_only: bool = False,
-    ) -> str:
+        reasoning_effort: str = "high",
+        tools: Optional[List[BaseTool]] = None,
+    ) -> BaseLMResponse:
         assert (
             lm_config.get("response_model", None) is None
         ), "response_model is not supported for standard calls"
+        assert not (response_model and tools), "Cannot provide both response_model and tools"
         used_cache_handler = get_cache_handler(use_ephemeral_cache_only)
         cache_result = used_cache_handler.hit_managed_cache(
-            model, messages, lm_config=lm_config
+            model, messages, lm_config=lm_config, tools=tools
         )
         if cache_result:
+            assert type(cache_result) in [
+                BaseLMResponse,
+                str,
+            ], f"Expected BaseLMResponse or str, got {type(cache_result)}"
             return (
-                cache_result["response"]
-                if isinstance(cache_result, dict)
-                else cache_result
+                cache_result
+                if type(cache_result) == BaseLMResponse
+                else BaseLMResponse(
+                    raw_response=cache_result, structured_output=None, tool_calls=None
+                )
             )

         mistral_messages = [
             {"role": msg["role"], "content": msg["content"]} for msg in messages
         ]
+        functions = [tool.to_mistral_tool() for tool in tools] if tools else None
+        params = {
+            "model": model,
+            "messages": mistral_messages,
+            "max_tokens": lm_config.get("max_tokens", 4096),
+            "temperature": lm_config.get(
+                "temperature", SPECIAL_BASE_TEMPS.get(model, 0)
+            ),
+            "stream": False,
+            "tool_choice": "auto" if functions else None,
+
+        }
+        if response_model:
+            params["response_format"] = response_model
+        elif tools:
+            params["tools"] = functions
+
         async with Mistral(api_key=os.getenv("MISTRAL_API_KEY", "")) as client:
-            response = await client.chat.complete_async(
-                model=model,
-                messages=mistral_messages,
-                max_tokens=lm_config.get("max_tokens", 4096),
-                temperature=lm_config.get(
-                    "temperature", SPECIAL_BASE_TEMPS.get(model, 0)
-                ),
-                stream=False,
-            )
-        api_result = response.choices[0].message.content
+            response = await client.chat.complete_async(**params)
+
+        message = response.choices[0].message
+        try:
+            raw_response = message.content
+        except AttributeError:
+            raw_response = ""
+
+        tool_calls = []
+        try:
+            if message.tool_calls:
+                tool_calls = [
+                    {
+                        "id": call.id,
+                        "type": "function",
+                        "function": {
+                            "name": call.function.name,
+                            "arguments": call.function.arguments,
+                        },
+                    }
+                    for call in message.tool_calls
+                ]
+        except AttributeError:
+            pass
+
+        lm_response = BaseLMResponse(
+            raw_response=raw_response,
+            structured_output=None,
+            tool_calls=tool_calls if tool_calls else None,
+        )
         used_cache_handler.add_to_managed_cache(
-            model, messages, lm_config=lm_config, output=api_result
+            model, messages, lm_config=lm_config, output=lm_response, tools=tools
         )
-        return api_result
-
-    @backoff.on_exception(
-        backoff.expo,
-        MISTRAL_EXCEPTIONS_TO_RETRY,
-        max_tries=BACKOFF_TOLERANCE,
-        on_giveup=lambda e: print(e),
-    )
+        return lm_response
+
+    # @backoff.on_exception(
+    #     backoff.expo,
+    #     MISTRAL_EXCEPTIONS_TO_RETRY,
+    #     max_tries=BACKOFF_TOLERANCE,
+    #     on_giveup=lambda e: print(e),
+    # )
     def _hit_api_sync(
         self,
         model: str,
         messages: List[Dict[str, Any]],
         lm_config: Dict[str, Any],
+        response_model: Optional[BaseModel] = None,
         use_ephemeral_cache_only: bool = False,
-    ) -> str:
+        reasoning_effort: str = "high",
+        tools: Optional[List[BaseTool]] = None,
+    ) -> BaseLMResponse:
         assert (
             lm_config.get("response_model", None) is None
         ), "response_model is not supported for standard calls"
+        assert not (response_model and tools), "Cannot provide both response_model and tools"
+
         used_cache_handler = get_cache_handler(use_ephemeral_cache_only)
         cache_result = used_cache_handler.hit_managed_cache(
-            model, messages, lm_config=lm_config
+            model, messages, lm_config=lm_config, tools=tools
         )
         if cache_result:
+            assert type(cache_result) in [
+                BaseLMResponse,
+                str,
+            ], f"Expected BaseLMResponse or str, got {type(cache_result)}"
             return (
-                cache_result["response"]
-                if isinstance(cache_result, dict)
-                else cache_result
+                cache_result
+                if type(cache_result) == BaseLMResponse
+                else BaseLMResponse(
+                    raw_response=cache_result, structured_output=None, tool_calls=None
+                )
             )

         mistral_messages = [
             {"role": msg["role"], "content": msg["content"]} for msg in messages
         ]
+        functions = [tool.to_mistral_tool() for tool in tools] if tools else None
+
+        params = {
+            "model": model,
+            "messages": mistral_messages,
+            "max_tokens": lm_config.get("max_tokens", 4096),
+            "temperature": lm_config.get(
+                "temperature", SPECIAL_BASE_TEMPS.get(model, 0)
+            ),
+            "stream": False,
+            "tool_choice": "auto" if functions else None,
+            #"tools": functions,
+        }
+        if response_model:
+            params["response_format"] = response_model
+        elif tools:
+            params["tools"] = functions
+
         with Mistral(api_key=os.getenv("MISTRAL_API_KEY", "")) as client:
-            response = client.chat.complete(
-                model=model,
-                messages=mistral_messages,
-                max_tokens=lm_config.get("max_tokens", 4096),
-                temperature=lm_config.get(
-                    "temperature", SPECIAL_BASE_TEMPS.get(model, 0)
-                ),
-                stream=False,
-            )
-        api_result = response.choices[0].message.content
+            response = client.chat.complete(**params)
+
+        message = response.choices[0].message
+        try:
+            raw_response = message.content
+        except AttributeError:
+            raw_response = ""
+
+        tool_calls = []
+        try:
+            if message.tool_calls:
+                tool_calls = [
+                    {
+                        "id": call.id,
+                        "type": "function",
+                        "function": {
+                            "name": call.function.name,
+                            "arguments": call.function.arguments,
+                        },
+                    }
+                    for call in message.tool_calls
+                ]
+        except AttributeError:
+            pass
+
+        lm_response = BaseLMResponse(
+            raw_response=raw_response,
+            structured_output=None,
+            tool_calls=tool_calls if tool_calls else None,
+        )
         used_cache_handler.add_to_managed_cache(
-            model, messages, lm_config=lm_config, output=api_result
+            model, messages, lm_config=lm_config, output=lm_response, tools=tools
         )
-        return api_result
+        return lm_response

     async def _hit_api_async_structured_output(
         self,
@@ -130,7 +226,7 @@ class MistralAPI(VendorBase):
         response_model: BaseModel,
         temperature: float,
         use_ephemeral_cache_only: bool = False,
-    ) -> Any:
+    ) -> BaseLMResponse:
         try:
             mistral_messages = [
                 {"role": msg["role"], "content": msg["content"]} for msg in messages
@@ -145,7 +241,12 @@ class MistralAPI(VendorBase):
             )
             result = response.choices[0].message.content
             parsed = json.loads(result)
-            return response_model(**parsed)
+            lm_response = BaseLMResponse(
+                raw_response="",
+                structured_output=response_model(**parsed),
+                tool_calls=None,
+            )
+            return lm_response
         except (json.JSONDecodeError, pydantic.ValidationError):
             if self._openai_fallback is None:
                 self._openai_fallback = OpenAIStructuredOutputClient()
@@ -164,7 +265,7 @@ class MistralAPI(VendorBase):
         response_model: BaseModel,
         temperature: float,
         use_ephemeral_cache_only: bool = False,
-    ) -> Any:
+    ) -> BaseLMResponse:
         try:
             mistral_messages = [
                 {"role": msg["role"], "content": msg["content"]} for msg in messages
@@ -179,7 +280,12 @@ class MistralAPI(VendorBase):
             )
             result = response.choices[0].message.content
             parsed = json.loads(result)
-            return response_model(**parsed)
+            lm_response = BaseLMResponse(
+                raw_response="",
+                structured_output=response_model(**parsed),
+                tool_calls=None,
+            )
+            return lm_response
         except (json.JSONDecodeError, pydantic.ValidationError):
             print("WARNING - Falling back to OpenAI - THIS IS SLOW")
             if self._openai_fallback is None:
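Taken together, the MistralAPI hunks mean that tool definitions are converted with to_mistral_tool() and passed through the params dict, while any tool calls come back as OpenAI-style dicts on BaseLMResponse.tool_calls. A hedged usage sketch, assuming a hypothetical WeatherTool, a no-argument MistralAPI() constructor (not shown in this diff), a placeholder model name, and MISTRAL_API_KEY set in the environment:

from synth_ai.zyk.lms.vendors.core.mistral_api import MistralAPI


class WeatherTool:
    # Hypothetical tool: the BaseTool interface is not shown in this diff,
    # only that each tool must provide to_mistral_tool(); the returned shape
    # below is an assumption based on Mistral's function-calling format.
    def to_mistral_tool(self):
        return {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Look up the current weather for a city.",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }


client = MistralAPI()  # constructor arguments assumed; not part of this diff
response = client._hit_api_sync(
    model="mistral-small-latest",  # placeholder model name
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    lm_config={"temperature": 0},
    tools=[WeatherTool()],
)
print(response.raw_response)  # may be "" when the model elects to call a tool
print(response.tool_calls)    # e.g. [{"id": ..., "type": "function", "function": {...}}]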
synth_ai/zyk/lms/vendors/core/openai_api.py
@@ -1,5 +1,5 @@
 import json
-from typing import Any, Dict, List, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type

 import openai
 import pydantic_core
@@ -8,6 +8,8 @@ import pydantic_core
 from pydantic import BaseModel

 from synth_ai.zyk.lms.caching.initialize import get_cache_handler
+from synth_ai.zyk.lms.tools.base import BaseTool
+from synth_ai.zyk.lms.vendors.base import BaseLMResponse
 from synth_ai.zyk.lms.vendors.constants import SPECIAL_BASE_TEMPS
 from synth_ai.zyk.lms.vendors.openai_standard import OpenAIStandard

@@ -46,8 +48,11 @@ class OpenAIStructuredOutputClient(OpenAIStandard):
         response_model: BaseModel,
         temperature: float,
         use_ephemeral_cache_only: bool = False,
+        tools: Optional[List[BaseTool]] = None,
         reasoning_effort: str = "high",
     ) -> str:
+        if tools:
+            raise ValueError("Tools are not supported for async structured output")
         # "Hit client")
         lm_config = {"temperature": temperature, "response_model": response_model}
         used_cache_handler = get_cache_handler(
@@ -58,38 +63,40 @@ class OpenAIStructuredOutputClient(OpenAIStandard):
         )
         if cache_result:
             # print("Hit cache")
+            assert type(cache_result) in [
+                dict,
+                BaseLMResponse,
+            ], f"Expected dict or BaseLMResponse, got {type(cache_result)}"
             return (
-                cache_result["response"]
-                if isinstance(cache_result, dict)
-                else cache_result
+                cache_result["response"] if type(cache_result) == dict else cache_result
             )
-
-        # Common API call params
-        api_params = {
-            "model": model,
-            "messages": messages,
-            "response_format": response_model,
-        }
-
-        # Only add temperature for non o1/o3 models
-        if not any(prefix in model for prefix in ["o1-", "o3-"]):
-            api_params["temperature"] = lm_config.get(
-                "temperature", SPECIAL_BASE_TEMPS.get(model, 0)
+        if model in ["o3-mini", "o3", "o1-mini", "o1"]:
+            output = await self.async_client.beta.chat.completions.parse(
+                model=model,
+                messages=messages,
+                temperature=lm_config.get(
+                    "temperature", SPECIAL_BASE_TEMPS.get(model, 0)
+                ),
+                response_format=response_model,
+                reasoning_effort=reasoning_effort,
+            )
+        else:
+            output = await self.async_client.beta.chat.completions.parse(
+                model=model,
+                messages=messages,
+                response_format=response_model,
             )
-
-        # Add reasoning_effort only for o3-mini
-        if "o3-mini" in model:
-            #print("Reasoning effort:", reasoning_effort)
-            api_params["reasoning_effort"] = reasoning_effort
-
-        output = await self.async_client.beta.chat.completions.parse(**api_params)
-
         # "Output", output)
         api_result = response_model(**json.loads(output.choices[0].message.content))
+        lm_response = BaseLMResponse(
+            raw_response="",
+            structured_output=api_result,
+            tool_calls=None,
+        )
         used_cache_handler.add_to_managed_cache(
-            model, messages, lm_config, output=output.choices[0].message.content
+            model, messages, lm_config, output=lm_response
         )
-        return api_result
+        return lm_response

     def _hit_api_sync_structured_output(
         self,
@@ -98,8 +105,11 @@ class OpenAIStructuredOutputClient(OpenAIStandard):
         response_model: BaseModel,
         temperature: float,
         use_ephemeral_cache_only: bool = False,
+        tools: Optional[List[BaseTool]] = None,
         reasoning_effort: str = "high",
     ) -> str:
+        if tools:
+            raise ValueError("Tools are not supported for sync structured output")
         lm_config = {"temperature": temperature, "response_model": response_model}
         used_cache_handler = get_cache_handler(
             use_ephemeral_cache_only=use_ephemeral_cache_only
@@ -108,39 +118,40 @@ class OpenAIStructuredOutputClient(OpenAIStandard):
             model, messages, lm_config=lm_config
         )
         if cache_result:
+            assert type(cache_result) in [
+                dict,
+                BaseLMResponse,
+            ], f"Expected dict or BaseLMResponse, got {type(cache_result)}"
             return (
-                cache_result["response"]
-                if isinstance(cache_result, dict)
-                else cache_result
+                cache_result["response"] if type(cache_result) == dict else cache_result
             )
-
-        # Common API call params
-        api_params = {
-            "model": model,
-            "messages": messages,
-            "response_format": response_model,
-        }
-
-        # Only add temperature for non o1/o3 models
-        if not any(prefix in model for prefix in ["o1-", "o3-"]):
-            api_params["temperature"] = lm_config.get(
-                "temperature", SPECIAL_BASE_TEMPS.get(model, 0)
+        if model in ["o3-mini", "o3", "o1-mini", "o1"]:
+            output = self.sync_client.beta.chat.completions.parse(
+                model=model,
+                messages=messages,
+                temperature=lm_config.get(
+                    "temperature", SPECIAL_BASE_TEMPS.get(model, 0)
+                ),
+                response_format=response_model,
+                reasoning_effort=reasoning_effort,
+            )
+        else:
+            output = self.sync_client.beta.chat.completions.parse(
+                model=model,
+                messages=messages,
+                response_format=response_model,
             )
-
-        # Add reasoning_effort only for o3-mini
-        if model in ["o3-mini"]:
-            api_params["reasoning_effort"] = reasoning_effort
-
-        output = self.sync_client.beta.chat.completions.parse(**api_params)
-
         api_result = response_model(**json.loads(output.choices[0].message.content))
+
+        lm_response = BaseLMResponse(
+            raw_response="",
+            structured_output=api_result,
+            tool_calls=None,
+        )
         used_cache_handler.add_to_managed_cache(
-            model,
-            messages,
-            lm_config=lm_config,
-            output=output.choices[0].message.content,
+            model, messages, lm_config=lm_config, output=lm_response
         )
-        return api_result
+        return lm_response


 class OpenAIPrivate(OpenAIStandard):
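For the OpenAI client, both structured-output paths now return a BaseLMResponse whose structured_output carries the parsed pydantic model (the `-> str` annotations in the signatures are unchanged in this diff, but the bodies return BaseLMResponse). A hedged usage sketch with a placeholder model name and response model, assuming OPENAI_API_KEY is set; the no-argument constructor matches its use as the Mistral fallback above:

from pydantic import BaseModel

from synth_ai.zyk.lms.vendors.core.openai_api import OpenAIStructuredOutputClient


class Verdict(BaseModel):
    # Hypothetical response model for illustration only.
    answer: str
    confidence: float


client = OpenAIStructuredOutputClient()
response = client._hit_api_sync_structured_output(
    model="gpt-4o-mini",  # placeholder; o1/o3 models additionally receive reasoning_effort
    messages=[{"role": "user", "content": "Is water wet? Give an answer and a confidence."}],
    response_model=Verdict,
    temperature=0,
)
verdict = response.structured_output  # a Verdict instance; raw_response is ""
print(verdict.answer, verdict.confidence)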