synth-ai 0.1.0.dev28__py3-none-any.whl → 0.1.0.dev30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. public_tests/test_agent.py +11 -11
  2. public_tests/test_all_structured_outputs.py +32 -37
  3. public_tests/test_anthropic_structured_outputs.py +0 -0
  4. public_tests/test_deepseek_structured_outputs.py +0 -0
  5. public_tests/test_deepseek_tools.py +64 -0
  6. public_tests/test_gemini_structured_outputs.py +106 -0
  7. public_tests/test_models.py +27 -27
  8. public_tests/test_openai_structured_outputs.py +106 -0
  9. public_tests/test_reasoning_models.py +9 -7
  10. public_tests/test_recursive_structured_outputs.py +30 -30
  11. public_tests/test_structured.py +137 -0
  12. public_tests/test_structured_outputs.py +22 -13
  13. public_tests/test_text.py +160 -0
  14. public_tests/test_tools.py +300 -0
  15. synth_ai/__init__.py +1 -4
  16. synth_ai/zyk/__init__.py +2 -2
  17. synth_ai/zyk/lms/caching/ephemeral.py +54 -32
  18. synth_ai/zyk/lms/caching/handler.py +43 -15
  19. synth_ai/zyk/lms/caching/persistent.py +55 -27
  20. synth_ai/zyk/lms/core/main.py +29 -16
  21. synth_ai/zyk/lms/core/vendor_clients.py +1 -1
  22. synth_ai/zyk/lms/structured_outputs/handler.py +79 -45
  23. synth_ai/zyk/lms/structured_outputs/rehabilitate.py +3 -2
  24. synth_ai/zyk/lms/tools/base.py +104 -0
  25. synth_ai/zyk/lms/vendors/base.py +22 -6
  26. synth_ai/zyk/lms/vendors/core/anthropic_api.py +130 -95
  27. synth_ai/zyk/lms/vendors/core/gemini_api.py +153 -34
  28. synth_ai/zyk/lms/vendors/core/mistral_api.py +160 -54
  29. synth_ai/zyk/lms/vendors/core/openai_api.py +64 -53
  30. synth_ai/zyk/lms/vendors/openai_standard.py +197 -41
  31. synth_ai/zyk/lms/vendors/supported/deepseek.py +55 -0
  32. {synth_ai-0.1.0.dev28.dist-info → synth_ai-0.1.0.dev30.dist-info}/METADATA +2 -5
  33. synth_ai-0.1.0.dev30.dist-info/RECORD +65 -0
  34. public_tests/test_sonnet_thinking.py +0 -217
  35. synth_ai-0.1.0.dev28.dist-info/RECORD +0 -57
  36. {synth_ai-0.1.0.dev28.dist-info → synth_ai-0.1.0.dev30.dist-info}/WHEEL +0 -0
  37. {synth_ai-0.1.0.dev28.dist-info → synth_ai-0.1.0.dev30.dist-info}/licenses/LICENSE +0 -0
  38. {synth_ai-0.1.0.dev28.dist-info → synth_ai-0.1.0.dev30.dist-info}/top_level.txt +0 -0
synth_ai/zyk/lms/core/main.py

@@ -9,7 +9,8 @@ from synth_ai.zyk.lms.core.vendor_clients import (
     openai_naming_regexes,
 )
 from synth_ai.zyk.lms.structured_outputs.handler import StructuredOutputHandler
-from synth_ai.zyk.lms.vendors.base import VendorBase
+from synth_ai.zyk.lms.vendors.base import VendorBase, BaseLMResponse
+from synth_ai.zyk.lms.tools.base import BaseTool
 
 REASONING_MODELS = ["deepseek-reasoner", "o1-mini", "o1-preview", "o1", "o3"]
 
@@ -120,22 +121,23 @@ class LM:
         images_as_bytes: List[Any] = [],
         response_model: Optional[BaseModel] = None,
         use_ephemeral_cache_only: bool = False,
-    ):
+        tools: Optional[List[BaseTool]] = None,
+    ) -> BaseLMResponse:
         assert (system_message is None) == (
             user_message is None
         ), "Must provide both system_message and user_message or neither"
         assert (
             (messages is None) != (system_message is None)
         ), "Must provide either messages or system_message/user_message pair, but not both"
-
+        assert not (response_model and tools), "Cannot provide both response_model and tools"
         if messages is None:
             messages = build_messages(
                 system_message, user_message, images_as_bytes, self.model_name
             )
-
+        result = None
         if response_model:
             try:
-                return self.structured_output_handler.call_sync(
+                result = self.structured_output_handler.call_sync(
                     messages,
                     model=self.model_name,
                     lm_config=self.lm_config,
@@ -144,7 +146,7 @@ class LM:
                 )
             except StructuredOutputCoercionFailureException:
                 # print("Falling back to backup handler")
-                return self.backup_structured_output_handler.call_sync(
+                result = self.backup_structured_output_handler.call_sync(
                     messages,
                     model=self.model_name,
                     lm_config=self.lm_config,
@@ -152,12 +154,17 @@ class LM:
                     use_ephemeral_cache_only=use_ephemeral_cache_only,
                 )
         else:
-            return self.client._hit_api_sync(
+            result = self.client._hit_api_sync(
                 messages=messages,
                 model=self.model_name,
                 lm_config=self.lm_config,
                 use_ephemeral_cache_only=use_ephemeral_cache_only,
+                tools=tools,
             )
+        assert isinstance(result.raw_response, str), "Raw response must be a string"
+        assert (isinstance(result.structured_output, BaseModel) or result.structured_output is None), "Structured output must be a Pydantic model or None"
+        assert (isinstance(result.tool_calls, list) or result.tool_calls is None), "Tool calls must be a list or None"
+        return result
 
     async def respond_async(
         self,
@@ -167,7 +174,8 @@ class LM:
         images_as_bytes: List[Any] = [],
         response_model: Optional[BaseModel] = None,
         use_ephemeral_cache_only: bool = False,
-    ):
+        tools: Optional[List[BaseTool]] = None,
+    ) -> BaseLMResponse:
         # "In respond_async")
         assert (system_message is None) == (
             user_message is None
@@ -176,15 +184,16 @@ class LM:
             (messages is None) != (system_message is None)
         ), "Must provide either messages or system_message/user_message pair, but not both"
 
+        assert not (response_model and tools), "Cannot provide both response_model and tools"
         if messages is None:
             messages = build_messages(
                 system_message, user_message, images_as_bytes, self.model_name
             )
-
+        result = None
         if response_model:
             try:
-                # "Trying structured output handler")
-                return await self.structured_output_handler.call_async(
+                #print("Trying structured output handler")
+                result = await self.structured_output_handler.call_async(
                     messages,
                     model=self.model_name,
                     lm_config=self.lm_config,
@@ -192,8 +201,8 @@ class LM:
                     use_ephemeral_cache_only=use_ephemeral_cache_only,
                 )
             except StructuredOutputCoercionFailureException:
-                # print("Falling back to backup handler")
-                return await self.backup_structured_output_handler.call_async(
+                #print("Falling back to backup handler")
+                result = await self.backup_structured_output_handler.call_async(
                     messages,
                     model=self.model_name,
                     lm_config=self.lm_config,
@@ -201,14 +210,18 @@ class LM:
                     use_ephemeral_cache_only=use_ephemeral_cache_only,
                 )
         else:
-            # print("Calling API no response model")
-            return await self.client._hit_api_async(
+            #print("Calling API no response model")
+            result = await self.client._hit_api_async(
                 messages=messages,
                 model=self.model_name,
                 lm_config=self.lm_config,
                 use_ephemeral_cache_only=use_ephemeral_cache_only,
+                tools=tools,
            )
-
+        assert isinstance(result.raw_response, str), "Raw response must be a string"
+        assert (isinstance(result.structured_output, BaseModel) or result.structured_output is None), "Structured output must be a Pydantic model or None"
+        assert (isinstance(result.tool_calls, list) or result.tool_calls is None), "Tool calls must be a list or None"
+        return result
 
 if __name__ == "__main__":
     import asyncio
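
Taken together, the main.py changes give respond_sync and respond_async a uniform contract: both now accept an optional tools list (mutually exclusive with response_model) and always return a BaseLMResponse instead of a bare string or Pydantic object. Below is a minimal sketch of the new call shape; it assumes LM is exported from synth_ai.zyk with the constructor arguments shown (not confirmed by this diff), and the model name and weather tool are purely illustrative:

from pydantic import BaseModel, Field

from synth_ai.zyk import LM
from synth_ai.zyk.lms.tools.base import BaseTool

class WeatherArgs(BaseModel):
    city: str = Field(description="City to look up")

weather_tool = BaseTool(
    name="get_weather",
    arguments=WeatherArgs,
    description="Look up the current weather for a city",
)

# Constructor arguments are assumptions for this sketch.
lm = LM(model_name="gpt-4o-mini", formatting_model_name="gpt-4o-mini", temperature=0.0)

# Tool-calling path; passing response_model at the same time would trip the new assert.
result = lm.respond_sync(
    system_message="You are a helpful assistant.",
    user_message="What is the weather in Paris?",
    tools=[weather_tool],
)
print(result.raw_response)  # always a str
print(result.tool_calls)    # list of tool-call dicts, or None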
synth_ai/zyk/lms/core/vendor_clients.py

@@ -6,8 +6,8 @@ from synth_ai.zyk.lms.core.all import (
     DeepSeekClient,
     GeminiClient,
     GroqAPI,
-    # OpenAIClient,
     MistralAPI,
+    # OpenAIClient,
     OpenAIStructuredOutputClient,
     TogetherClient,
 )
synth_ai/zyk/lms/structured_outputs/handler.py

@@ -1,3 +1,4 @@
+import logging
 import time
 from abc import ABC, abstractmethod
 from typing import Any, Callable, Dict, List, Literal, Optional, Union
@@ -11,13 +12,13 @@ from synth_ai.zyk.lms.structured_outputs.inject import (
 from synth_ai.zyk.lms.structured_outputs.rehabilitate import (
     fix_errant_forced_async,
     fix_errant_forced_sync,
-    fix_errant_stringified_json_async,
-    fix_errant_stringified_json_sync,
     pull_out_structured_output,
 )
-from synth_ai.zyk.lms.vendors.base import VendorBase
+from synth_ai.zyk.lms.vendors.base import BaseLMResponse, VendorBase
 from synth_ai.zyk.lms.vendors.constants import SPECIAL_BASE_TEMPS
 
+logger = logging.getLogger(__name__)
+
 
 class StructuredHandlerBase(ABC):
     core_client: VendorBase
@@ -49,7 +50,7 @@ class StructuredHandlerBase(ABC):
         temperature: float = 0.0,
         use_ephemeral_cache_only: bool = False,
         reasoning_effort: str = "high",
-    ) -> BaseModel:
+    ) -> BaseLMResponse:
         if temperature == 0.0:
             temperature = SPECIAL_BASE_TEMPS.get(model, 0.0)
         # print("Calling from base")
@@ -73,7 +74,7 @@ class StructuredHandlerBase(ABC):
         temperature: float = 0.0,
         use_ephemeral_cache_only: bool = False,
         reasoning_effort: str = "high",
-    ) -> BaseModel:
+    ) -> BaseLMResponse:
         if temperature == 0.0:
             temperature = SPECIAL_BASE_TEMPS.get(model, 0.0)
         return self._process_call_sync(
@@ -97,7 +98,7 @@ class StructuredHandlerBase(ABC):
         api_call_method,
         use_ephemeral_cache_only: bool = False,
         reasoning_effort: str = "high",
-    ) -> BaseModel:
+    ) -> BaseLMResponse:
         pass
 
     @abstractmethod
@@ -109,7 +110,7 @@ class StructuredHandlerBase(ABC):
         api_call_method,
         use_ephemeral_cache_only: bool = False,
         reasoning_effort: str = "high",
-    ) -> BaseModel:
+    ) -> BaseLMResponse:
         pass
 
 
@@ -140,11 +141,9 @@ class StringifiedJSONHandler(StructuredHandlerBase):
         api_call_method: Callable,
         use_ephemeral_cache_only: bool = False,
         reasoning_effort: str = "high",
-    ) -> BaseModel:
-        # print("In _process_call_async")
-        assert isinstance(
-            api_call_method, Callable
-        ), "api_call_method must be a callable"
+    ) -> BaseLMResponse:
+        logger.info(f"Processing structured output call for model: {model}")
+        assert callable(api_call_method), "api_call_method must be a callable"
         assert (
             response_model is not None
         ), "Don't use this handler for unstructured outputs"
@@ -168,43 +167,63 @@ class StringifiedJSONHandler(StructuredHandlerBase):
                 use_ephemeral_cache_only=use_ephemeral_cache_only,
                 reasoning_effort=reasoning_effort,
             )
-            # print(f"Time to get response: {time.time() - t0}")
-            if not isinstance(raw_text_response_or_cached_hit, str):
-                return raw_text_response_or_cached_hit
+            logger.debug(f"Time to get response: {time.time() - t0:.2f}s")
+
+            # Check if we got a cached BaseLMResponse
+            assert (
+                type(raw_text_response_or_cached_hit) in [str, BaseLMResponse]
+            ), f"Expected str or BaseLMResponse, got {type(raw_text_response_or_cached_hit)}"
+            if type(raw_text_response_or_cached_hit) == BaseLMResponse:
+                print("Got cached hit, returning directly")
+                raw_text_response = raw_text_response_or_cached_hit.raw_response
             else:
                 raw_text_response = raw_text_response_or_cached_hit
+            logger.debug(f"Raw response from model:\n{raw_text_response}")
+
+            print("Trying to parse structured output")
             try:
                 structured_output = pull_out_structured_output(
                     raw_text_response, response_model
                 )
+
+                print("Successfully parsed structured output on first attempt")
                 break
-            # except Exception as e:
-            #     try:
-            #         structured_output = await fix_errant_stringified_json_async(raw_text_response, response_model)
-            #         break
             except Exception as e:
+                logger.warning(f"Failed to parse structured output: {str(e)}")
                 try:
-                    # t0 = time.time()
-                    # print(f"Got error {e}, attempting to fix")
+                    print("Attempting to fix with forced JSON parser")
                     structured_output = await fix_errant_forced_async(
                         messages_with_json_formatting_instructions,
                         raw_text_response,
                         response_model,
                         "gpt-4o-mini",
                     )
-
-                    # print(f"Time to fix: {time.time() - t0}")
+                    assert isinstance(structured_output, BaseModel), "Structured output must be a Pydantic model"
+                    assert not isinstance(structured_output, BaseLMResponse), "Got BaseLMResponse instead of Pydantic model"
+                    print("Successfully fixed and parsed structured output")
                     break
                 except Exception as e:
+                    logger.error(f"Failed to fix structured output: {str(e)}")
                     previously_failed_error_messages.append(
                         f"Generated attempt and got error. Attempt:\n\n{raw_text_response}\n\nError:\n\n{e}"
                     )
                     remaining_retries -= 1
+                    logger.warning(f"Retries remaining: {remaining_retries}")
+
         if structured_output is None:
+            logger.error("Failed to get structured output after all retries")
             raise StructuredOutputCoercionFailureException(
                 "Failed to get structured output"
            )
-        return structured_output
+        print("Successfully parsed structured output")
+        print(structured_output)
+        assert isinstance(structured_output, BaseModel), "Structured output must be a Pydantic model"
+        assert not isinstance(structured_output, BaseLMResponse), "Got BaseLMResponse instead of Pydantic model"
+        return BaseLMResponse(
+            raw_response=raw_text_response,
+            structured_output=structured_output,
+            tool_calls=None,
+        )
 
     def _process_call_sync(
         self,
@@ -215,10 +234,9 @@ class StringifiedJSONHandler(StructuredHandlerBase):
         api_call_method: Callable,
         use_ephemeral_cache_only: bool = False,
         reasoning_effort: str = "high",
-    ) -> BaseModel:
-        assert isinstance(
-            api_call_method, Callable
-        ), "api_call_method must be a callable"
+    ) -> BaseLMResponse:
+        logger.info(f"Processing structured output call for model: {model}")
+        assert callable(api_call_method), "api_call_method must be a callable"
         assert (
             response_model is not None
         ), "Don't use this handler for unstructured outputs"
@@ -234,7 +252,7 @@ class StringifiedJSONHandler(StructuredHandlerBase):
                 previously_failed_error_messages=previously_failed_error_messages,
             )
         )
-        # t0 = time.time()
+        t0 = time.time()
         raw_text_response_or_cached_hit = api_call_method(
             messages=messages_with_json_formatting_instructions,
             model=model,
@@ -242,39 +260,54 @@ class StringifiedJSONHandler(StructuredHandlerBase):
             use_ephemeral_cache_only=use_ephemeral_cache_only,
             reasoning_effort=reasoning_effort,
         )
-        # print(f"Time to get response: {time.time() - t0}")
-        if not isinstance(raw_text_response_or_cached_hit, str):
-            return raw_text_response_or_cached_hit
-        else:
+        logger.debug(f"Time to get response: {time.time() - t0:.2f}s")
+
+        # Check if we got a cached BaseLMResponse
+        assert (
+            type(raw_text_response_or_cached_hit) in [str, BaseLMResponse]
+        ), f"Expected str or BaseLMResponse, got {type(raw_text_response_or_cached_hit)}"
+        if type(raw_text_response_or_cached_hit) == BaseLMResponse:
+            logger.info("Got cached hit, returning directly")
+            raw_text_response = raw_text_response_or_cached_hit.raw_response
+        else:
             raw_text_response = raw_text_response_or_cached_hit
+        logger.debug(f"Raw response from model:\n{raw_text_response}")
+
         try:
             structured_output = pull_out_structured_output(
                 raw_text_response, response_model
             )
+            print("Successfully parsed structured output on first attempt")
             break
-        # except Exception:
-        #     try:
-        #         structured_output = fix_errant_stringified_json_sync(raw_text_response, response_model)
-        #         break
         except Exception as e:
+            logger.warning(f"Failed to parse structured output: {str(e)}")
             try:
-                # t0 = time.time()
-                # print(f"Got error {e}, attempting to fix")
+                print("Attempting to fix with forced JSON parser")
                 structured_output = fix_errant_forced_sync(
                     raw_text_response, response_model, "gpt-4o-mini"
                 )
+                print("Successfully fixed and parsed structured output")
                 break
-                # print(f"Time to fix: {time.time() - t0}")
             except Exception as e:
+                logger.error(f"Failed to fix structured output: {str(e)}")
                 previously_failed_error_messages.append(
                     f"Generated attempt and got error. Attempt:\n\n{raw_text_response}\n\nError:\n\n{e}"
                 )
                 remaining_retries -= 1
+                logger.warning(f"Retries remaining: {remaining_retries}")
+
+        print("Successfully parsed structured output")
+        print(structured_output)
         if structured_output is None:
+            logger.error("Failed to get structured output after all retries")
             raise StructuredOutputCoercionFailureException(
                 "Failed to get structured output"
             )
-        return structured_output
+        return BaseLMResponse(
+            raw_response=raw_text_response,
+            structured_output=structured_output,
+            tool_calls=None,
+        )
 
 
 class ForcedJSONHandler(StructuredHandlerBase):
@@ -296,6 +329,7 @@ class ForcedJSONHandler(StructuredHandlerBase):
             structured_output_mode="forced_json",
         )
         self.reasoning_effort = reasoning_effort
+
     async def _process_call_async(
         self,
         messages: List[Dict[str, Any]],
@@ -305,7 +339,7 @@ class ForcedJSONHandler(StructuredHandlerBase):
         temperature: float = 0.0,
         use_ephemeral_cache_only: bool = False,
         reasoning_effort: str = "high",
-    ) -> BaseModel:
+    ) -> BaseLMResponse:
         # print("Forced JSON")
         assert (
             response_model is not None
@@ -328,7 +362,7 @@ class ForcedJSONHandler(StructuredHandlerBase):
         temperature: float = 0.0,
         use_ephemeral_cache_only: bool = False,
         reasoning_effort: str = "high",
-    ) -> BaseModel:
+    ) -> BaseLMResponse:
         assert (
             response_model is not None
         ), "Don't use this handler for unstructured outputs"
@@ -373,7 +407,7 @@ class StructuredOutputHandler:
         use_ephemeral_cache_only: bool = False,
         lm_config: Dict[str, Any] = {},
         reasoning_effort: str = "high",
-    ) -> BaseModel:
+    ) -> BaseLMResponse:
         # print("Output handler call async")
         return await self.handler.call_async(
             messages=messages,
@@ -394,7 +428,7 @@ class StructuredOutputHandler:
         use_ephemeral_cache_only: bool = False,
         lm_config: Dict[str, Any] = {},
         reasoning_effort: str = "high",
-    ) -> BaseModel:
+    ) -> BaseLMResponse:
         return self.handler.call_sync(
             messages=messages,
             model=model,
synth_ai/zyk/lms/structured_outputs/rehabilitate.py

@@ -55,6 +55,7 @@ def pull_out_structured_output(
         raise ValueError(
             f"Failed to parse response as {response_model}: {e} - {response_prepared}"
         )
+    assert isinstance(final, BaseModel), "Structured output must be a Pydantic model"
     return final
 
 
@@ -157,7 +158,7 @@ async def fix_errant_forced_async(
     )
     # print("Fixed response:")
     # print(fixed_response)
-    return fixed_response
+    return fixed_response.structured_output
 
 
 def fix_errant_forced_sync(
@@ -183,4 +184,4 @@ def fix_errant_forced_sync(
     )
     # print("Fixed response:")
     # print(fixed_response)
-    return fixed_response
+    return fixed_response.structured_output
synth_ai/zyk/lms/tools/base.py (new file)

@@ -0,0 +1,104 @@
+from typing import Type
+
+from pydantic import BaseModel
+
+
+class BaseTool(BaseModel):
+    name: str
+    arguments: Type[BaseModel]
+    description: str = ""
+    strict: bool = True
+
+    def to_openai_tool(self):
+        schema = self.arguments.model_json_schema()
+        schema["additionalProperties"] = False
+
+        return {
+            "type": "function",
+            "function": {
+                "name": self.name,
+                "description": self.description,
+                "parameters": schema,
+                "strict": self.strict,
+            },
+        }
+
+    def to_anthropic_tool(self):
+        schema = self.arguments.model_json_schema()
+        schema["additionalProperties"] = False
+
+        return {
+            "name": self.name,
+            "description": self.description,
+            "input_schema": {
+                "type": "object",
+                "properties": schema["properties"],
+                "required": schema.get("required", []),
+            },
+        }
+
+    def to_mistral_tool(self):
+        schema = self.arguments.model_json_schema()
+        properties = {}
+        for prop_name, prop in schema.get("properties", {}).items():
+            prop_type = prop["type"]
+            if prop_type == "array" and "items" in prop:
+                properties[prop_name] = {
+                    "type": "array",
+                    "items": prop["items"],
+                    "description": prop.get("description", ""),
+                }
+                continue
+
+            properties[prop_name] = {
+                "type": prop_type,
+                "description": prop.get("description", ""),
+            }
+            if "enum" in prop:
+                properties[prop_name]["enum"] = prop["enum"]
+
+        parameters = {
+            "type": "object",
+            "properties": properties,
+            "required": schema.get("required", []),
+            "additionalProperties": False,
+        }
+        return {
+            "type": "function",
+            "function": {
+                "name": self.name,
+                "description": self.description,
+                "parameters": parameters,
+            },
+        }
+
+    def to_gemini_tool(self):
+        schema = self.arguments.model_json_schema()
+        # Convert Pydantic schema types to Gemini schema types
+        properties = {}
+        for name, prop in schema["properties"].items():
+            prop_type = prop.get("type", "string")
+            if prop_type == "array" and "items" in prop:
+                properties[name] = {
+                    "type": "array",
+                    "items": prop["items"],
+                    "description": prop.get("description", ""),
+                }
+                continue
+
+            properties[name] = {
+                "type": prop_type,
+                "description": prop.get("description", ""),
+            }
+            if "enum" in prop:
+                properties[name]["enum"] = prop["enum"]
+
+        return {
+            "name": self.name,
+            "description": self.description,
+            "parameters": {
+                "type": "object",
+                "properties": properties,
+                "required": schema.get("required", []),
+            },
+        }
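
tools/base.py is entirely new in dev30: one Pydantic-backed tool definition with per-vendor converters. A small sketch of what the converters produce, read directly from the code above; the SearchArgs model is illustrative:

from pydantic import BaseModel, Field

from synth_ai.zyk.lms.tools.base import BaseTool

class SearchArgs(BaseModel):
    query: str = Field(description="Search query")
    limit: int = Field(description="Maximum number of results")

search = BaseTool(name="search", arguments=SearchArgs, description="Search the web")

openai_tool = search.to_openai_tool()
# -> {"type": "function", "function": {"name": "search", "description": "Search the web",
#     "parameters": <JSON schema with additionalProperties=False>, "strict": True}}

anthropic_tool = search.to_anthropic_tool()
# -> {"name": "search", "description": "Search the web",
#     "input_schema": {"type": "object", "properties": ..., "required": [...]}}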
synth_ai/zyk/lms/vendors/base.py

@@ -1,15 +1,31 @@
 from abc import ABC, abstractmethod
-from typing import List, Dict, Any, Optional
+from typing import Any, Dict, List, Optional
+
 from pydantic import BaseModel
 
+
+class BaseLMResponse(BaseModel):
+    raw_response: str
+    structured_output: Optional[BaseModel] = None
+    tool_calls: Optional[List[Dict]] = None
+
+
 class VendorBase(ABC):
     used_for_structured_outputs: bool = False
     exceptions_to_retry: List[Exception] = []
-
+
     @abstractmethod
-    async def _hit_api_async(self, messages: List[Dict[str, Any]], response_model_override: Optional[BaseModel] = None) -> str:
+    async def _hit_api_async(
+        self,
+        messages: List[Dict[str, Any]],
+        response_model_override: Optional[BaseModel] = None,
+    ) -> str:
         pass
-
+
     @abstractmethod
-    def _hit_api_sync(self, messages: List[Dict[str, Any]], response_model_override: Optional[BaseModel] = None) -> str:
-        pass
+    def _hit_api_sync(
+        self,
+        messages: List[Dict[str, Any]],
+        response_model_override: Optional[BaseModel] = None,
+    ) -> str:
+        pass
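
For reference, BaseLMResponse is a plain Pydantic container; in practice the LM entry points populate at most one of structured_output or tool_calls, since they assert that response_model and tools are mutually exclusive. A minimal construction sketch; the tool-call dict shown is illustrative, as each vendor client defines its own dict shape:

from synth_ai.zyk.lms.vendors.base import BaseLMResponse

plain = BaseLMResponse(raw_response="hello")  # both optional fields default to None
assert plain.structured_output is None and plain.tool_calls is None

with_tools = BaseLMResponse(
    raw_response="",
    tool_calls=[{"name": "search", "arguments": {"query": "synth-ai"}}],
)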