synth-ai 0.1.0.dev28__py3-none-any.whl → 0.1.0.dev30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. public_tests/test_agent.py +11 -11
  2. public_tests/test_all_structured_outputs.py +32 -37
  3. public_tests/test_anthropic_structured_outputs.py +0 -0
  4. public_tests/test_deepseek_structured_outputs.py +0 -0
  5. public_tests/test_deepseek_tools.py +64 -0
  6. public_tests/test_gemini_structured_outputs.py +106 -0
  7. public_tests/test_models.py +27 -27
  8. public_tests/test_openai_structured_outputs.py +106 -0
  9. public_tests/test_reasoning_models.py +9 -7
  10. public_tests/test_recursive_structured_outputs.py +30 -30
  11. public_tests/test_structured.py +137 -0
  12. public_tests/test_structured_outputs.py +22 -13
  13. public_tests/test_text.py +160 -0
  14. public_tests/test_tools.py +300 -0
  15. synth_ai/__init__.py +1 -4
  16. synth_ai/zyk/__init__.py +2 -2
  17. synth_ai/zyk/lms/caching/ephemeral.py +54 -32
  18. synth_ai/zyk/lms/caching/handler.py +43 -15
  19. synth_ai/zyk/lms/caching/persistent.py +55 -27
  20. synth_ai/zyk/lms/core/main.py +29 -16
  21. synth_ai/zyk/lms/core/vendor_clients.py +1 -1
  22. synth_ai/zyk/lms/structured_outputs/handler.py +79 -45
  23. synth_ai/zyk/lms/structured_outputs/rehabilitate.py +3 -2
  24. synth_ai/zyk/lms/tools/base.py +104 -0
  25. synth_ai/zyk/lms/vendors/base.py +22 -6
  26. synth_ai/zyk/lms/vendors/core/anthropic_api.py +130 -95
  27. synth_ai/zyk/lms/vendors/core/gemini_api.py +153 -34
  28. synth_ai/zyk/lms/vendors/core/mistral_api.py +160 -54
  29. synth_ai/zyk/lms/vendors/core/openai_api.py +64 -53
  30. synth_ai/zyk/lms/vendors/openai_standard.py +197 -41
  31. synth_ai/zyk/lms/vendors/supported/deepseek.py +55 -0
  32. {synth_ai-0.1.0.dev28.dist-info → synth_ai-0.1.0.dev30.dist-info}/METADATA +2 -5
  33. synth_ai-0.1.0.dev30.dist-info/RECORD +65 -0
  34. public_tests/test_sonnet_thinking.py +0 -217
  35. synth_ai-0.1.0.dev28.dist-info/RECORD +0 -57
  36. {synth_ai-0.1.0.dev28.dist-info → synth_ai-0.1.0.dev30.dist-info}/WHEEL +0 -0
  37. {synth_ai-0.1.0.dev28.dist-info → synth_ai-0.1.0.dev30.dist-info}/licenses/LICENSE +0 -0
  38. {synth_ai-0.1.0.dev28.dist-info → synth_ai-0.1.0.dev30.dist-info}/top_level.txt +0 -0
--- a/synth_ai/zyk/lms/vendors/core/anthropic_api.py
+++ b/synth_ai/zyk/lms/vendors/core/anthropic_api.py
@@ -1,5 +1,5 @@
 import json
-from typing import Any, Dict, List, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 import anthropic
 import pydantic
@@ -8,14 +8,20 @@ from pydantic import BaseModel
 from synth_ai.zyk.lms.caching.initialize import (
     get_cache_handler,
 )
-from synth_ai.zyk.lms.vendors.base import VendorBase
+from synth_ai.zyk.lms.tools.base import BaseTool
+from synth_ai.zyk.lms.vendors.base import BaseLMResponse, VendorBase
 from synth_ai.zyk.lms.vendors.constants import SPECIAL_BASE_TEMPS
 from synth_ai.zyk.lms.vendors.core.openai_api import OpenAIStructuredOutputClient
-from synth_ai.zyk.lms.vendors.retries import BACKOFF_TOLERANCE, backoff
 
 ANTHROPIC_EXCEPTIONS_TO_RETRY: Tuple[Type[Exception], ...] = (anthropic.APIError,)
 
 
+sonnet_37_budgets = {
+    "high": 4000,
+    "medium": 2000,
+    "low": 1000,
+}
+
 class AnthropicAPI(VendorBase):
     used_for_structured_outputs: bool = True
     exceptions_to_retry: Tuple = ANTHROPIC_EXCEPTIONS_TO_RETRY
@@ -37,12 +43,12 @@ class AnthropicAPI(VendorBase):
         self._openai_fallback = None
         self.reasoning_effort = reasoning_effort
 
-    @backoff.on_exception(
-        backoff.expo,
-        exceptions_to_retry,
-        max_tries=BACKOFF_TOLERANCE,
-        on_giveup=lambda e: print(e),
-    )
+    # @backoff.on_exception(
+    #     backoff.expo,
+    #     exceptions_to_retry,
+    #     max_tries=BACKOFF_TOLERANCE,
+    #     on_giveup=lambda e: print(e),
+    # )
     async def _hit_api_async(
         self,
         model: str,
@@ -50,83 +56,90 @@ class AnthropicAPI(VendorBase):
         lm_config: Dict[str, Any],
         use_ephemeral_cache_only: bool = False,
         reasoning_effort: str = "high",
+        tools: Optional[List[BaseTool]] = None,
         **vendor_params: Dict[str, Any],
-    ) -> str:
+    ) -> BaseLMResponse:
         assert (
             lm_config.get("response_model", None) is None
         ), "response_model is not supported for standard calls"
         used_cache_handler = get_cache_handler(use_ephemeral_cache_only)
         cache_result = used_cache_handler.hit_managed_cache(
-            model, messages, lm_config=lm_config
+            model, messages, lm_config=lm_config, tools=tools
         )
         if cache_result:
-            return (
-                cache_result["response"]
-                if isinstance(cache_result, dict)
-                else cache_result
-            )
+            return cache_result
 
         # Common API parameters
         api_params = {
             "system": messages[0]["content"],
             "messages": messages[1:],
             "model": model,
-            "max_tokens": lm_config.get("max_tokens", 4096 * 2),
+            "max_tokens": lm_config.get("max_tokens", 4096),
             "temperature": lm_config.get(
                 "temperature", SPECIAL_BASE_TEMPS.get(model, 0)
             ),
         }
 
-        # Handle thinking budget for Claude 3.7
+        # Add tools if provided
+        if tools:
+            api_params["tools"] = [tool.to_anthropic_tool() for tool in tools]
+
+        # Only try to add thinking if supported by the SDK
         try:
             import inspect
 
             create_sig = inspect.signature(self.async_client.messages.create)
             if "thinking" in create_sig.parameters and "claude-3-7" in model:
                 if reasoning_effort in ["high", "medium"]:
-                    budgets = {
-                        "high": 4000,
-                        "medium": 2000,
-                        "low": 1000,
-                    }
-                    budget = budgets[reasoning_effort]
+                    budget = sonnet_37_budgets[reasoning_effort]
                     api_params["thinking"] = {
                         "type": "enabled",
                         "budget_tokens": budget,
                     }
-                    # Ensure max_tokens is greater than thinking budget
-                    api_params["max_tokens"] = max(
-                        api_params["max_tokens"], budget + 4096
-                    )
-                    # Set temperature to 1 for thinking, but only in API call
-                    api_params["temperature"] = 1.0
+                    api_params["max_tokens"] = budget+4096
+                    api_params["temperature"] = 1
         except (ImportError, AttributeError, TypeError):
             pass
 
         # Make the API call
        response = await self.async_client.messages.create(**api_params)
 
-        # Handle both regular and thinking responses
-        if hasattr(response.content[0], "text"):
-            api_result = response.content[0].text
-        else:
-            # For thinking responses, get the final output
-            thinking_blocks = [
-                block for block in response.content if block.type == "text"
-            ]
-            api_result = thinking_blocks[-1].text if thinking_blocks else ""
+        # Extract text content and tool calls
+        raw_response = ""
+        tool_calls = []
+
+        for content in response.content:
+            if content.type == "text":
+                raw_response += content.text
+            elif content.type == "tool_use":
+                tool_calls.append(
+                    {
+                        "id": content.id,
+                        "type": "function",
+                        "function": {
+                            "name": content.name,
+                            "arguments": json.dumps(content.input),
+                        },
+                    }
+                )
+
+        lm_response = BaseLMResponse(
+            raw_response=raw_response,
+            structured_output=None,
+            tool_calls=tool_calls if tool_calls else None,
+        )
 
         used_cache_handler.add_to_managed_cache(
-            model, messages, lm_config=lm_config, output=api_result
+            model, messages, lm_config=lm_config, output=lm_response, tools=tools
         )
-        return api_result
-
-    @backoff.on_exception(
-        backoff.expo,
-        exceptions_to_retry,
-        max_tries=BACKOFF_TOLERANCE,
-        on_giveup=lambda e: print(e),
-    )
+        return lm_response
+
+    # @backoff.on_exception(
+    #     backoff.expo,
+    #     exceptions_to_retry,
+    #     max_tries=BACKOFF_TOLERANCE,
+    #     on_giveup=lambda e: print(e),
+    # )
     def _hit_api_sync(
         self,
         model: str,
@@ -134,8 +147,9 @@ class AnthropicAPI(VendorBase):
         lm_config: Dict[str, Any],
         use_ephemeral_cache_only: bool = False,
         reasoning_effort: str = "high",
+        tools: Optional[List[BaseTool]] = None,
         **vendor_params: Dict[str, Any],
-    ) -> str:
+    ) -> BaseLMResponse:
         assert (
             lm_config.get("response_model", None) is None
         ), "response_model is not supported for standard calls"
@@ -143,14 +157,10 @@ class AnthropicAPI(VendorBase):
             use_ephemeral_cache_only=use_ephemeral_cache_only
         )
         cache_result = used_cache_handler.hit_managed_cache(
-            model, messages, lm_config=lm_config
+            model, messages, lm_config=lm_config, tools=tools
         )
         if cache_result:
-            return (
-                cache_result["response"]
-                if isinstance(cache_result, dict)
-                else cache_result
-            )
+            return cache_result
 
         # Common API parameters
         api_params = {
@@ -163,45 +173,61 @@ class AnthropicAPI(VendorBase):
             ),
         }
 
-        # Only try to add thinking if supported by the SDK (check if Claude 3.7 and if reasoning_effort is set)
-        # Try to detect capabilities without causing an error
+        # Add tools if provided
+        if tools:
+            api_params["tools"] = [tool.to_anthropic_tool() for tool in tools]
+
+        # Only try to add thinking if supported by the SDK
         try:
             import inspect
 
             create_sig = inspect.signature(self.sync_client.messages.create)
             if "thinking" in create_sig.parameters and "claude-3-7" in model:
+                api_params["temperature"] = 1
                 if reasoning_effort in ["high", "medium"]:
-                    budgets = {
-                        "high": 4000,
-                        "medium": 2000,
-                        "low": 1000,
-                    }
+                    budgets = sonnet_37_budgets
                     budget = budgets[reasoning_effort]
                     api_params["thinking"] = {
                         "type": "enabled",
                         "budget_tokens": budget,
                     }
+                    api_params["max_tokens"] = budget+4096
+                    api_params["temperature"] = 1
         except (ImportError, AttributeError, TypeError):
-            # If we can't inspect or the parameter doesn't exist, just continue without it
             pass
 
         # Make the API call
         response = self.sync_client.messages.create(**api_params)
 
-        # Handle both regular and thinking responses
-        if hasattr(response.content[0], "text"):
-            api_result = response.content[0].text
-        else:
-            # For thinking responses, get the final output
-            thinking_blocks = [
-                block for block in response.content if block.type == "text"
-            ]
-            api_result = thinking_blocks[-1].text if thinking_blocks else ""
+        # Extract text content and tool calls
+        raw_response = ""
+        tool_calls = []
+
+        for content in response.content:
+            if content.type == "text":
+                raw_response += content.text
+            elif content.type == "tool_use":
+                tool_calls.append(
+                    {
+                        "id": content.id,
+                        "type": "function",
+                        "function": {
+                            "name": content.name,
+                            "arguments": json.dumps(content.input),
+                        },
+                    }
+                )
+
+        lm_response = BaseLMResponse(
+            raw_response=raw_response,
+            structured_output=None,
+            tool_calls=tool_calls if tool_calls else None,
+        )
 
         used_cache_handler.add_to_managed_cache(
-            model, messages, lm_config=lm_config, output=api_result
+            model, messages, lm_config=lm_config, output=lm_response, tools=tools
        )
-        return api_result
+        return lm_response
 
     async def _hit_api_async_structured_output(
         self,
@@ -212,36 +238,42 @@ class AnthropicAPI(VendorBase):
         use_ephemeral_cache_only: bool = False,
         reasoning_effort: str = "high",
         **vendor_params: Dict[str, Any],
-    ) -> str:
+    ) -> BaseLMResponse:
         try:
             # First try with Anthropic
             reasoning_effort = vendor_params.get("reasoning_effort", reasoning_effort)
             if "claude-3-7" in model:
-                if reasoning_effort in ["high", "medium"]:
-                    budgets = {
-                        "high": 4000,
-                        "medium": 2000,
-                        "low": 1000,
-                    }
-                    budget = budgets[reasoning_effort]
+
+                #if reasoning_effort in ["high", "medium"]:
+                budgets = sonnet_37_budgets
+                budget = budgets[reasoning_effort]
+                max_tokens = budget+4096
+                temperature = 1
+
                 response = await self.async_client.messages.create(
                     system=messages[0]["content"],
                     messages=messages[1:],
                     model=model,
-                    max_tokens=4096,
+                    max_tokens=max_tokens,
                     thinking={"type": "enabled", "budget_tokens": budget},
+                    temperature=temperature,
                 )
             else:
                 response = await self.async_client.messages.create(
                     system=messages[0]["content"],
                     messages=messages[1:],
                     model=model,
-                    max_tokens=4096,
+                    max_tokens=max_tokens,
+                    temperature=temperature,
                 )
             result = response.content[0].text
-            # Try to parse the result as JSON
             parsed = json.loads(result)
-            return response_model(**parsed)
+            lm_response = BaseLMResponse(
+                raw_response="",
+                structured_output=response_model(**parsed),
+                tool_calls=None,
+            )
+            return lm_response
         except (json.JSONDecodeError, pydantic.ValidationError):
             # If Anthropic fails, fallback to OpenAI
             if self._openai_fallback is None:
@@ -263,7 +295,7 @@ class AnthropicAPI(VendorBase):
         use_ephemeral_cache_only: bool = False,
         reasoning_effort: str = "high",
         **vendor_params: Dict[str, Any],
-    ) -> str:
+    ) -> BaseLMResponse:
         try:
             # First try with Anthropic
             reasoning_effort = vendor_params.get("reasoning_effort", reasoning_effort)
@@ -271,17 +303,15 @@ class AnthropicAPI(VendorBase):
 
             if "claude-3-7" in model:
                 if reasoning_effort in ["high", "medium"]:
-                    budgets = {
-                        "high": 4000,
-                        "medium": 2000,
-                        "low": 1000,
-                    }
+                    budgets = sonnet_37_budgets
                     budget = budgets[reasoning_effort]
+                    max_tokens = budget+4096
+                    temperature = 1
                 response = self.sync_client.messages.create(
                     system=messages[0]["content"],
                     messages=messages[1:],
                     model=model,
-                    max_tokens=4096,
+                    max_tokens=max_tokens,
                     temperature=temperature,
                     thinking={"type": "enabled", "budget_tokens": budget},
                 )
@@ -290,14 +320,19 @@ class AnthropicAPI(VendorBase):
                     system=messages[0]["content"],
                     messages=messages[1:],
                     model=model,
-                    max_tokens=4096,
+                    max_tokens=max_tokens,
                     temperature=temperature,
                 )
             # print("Time taken for API call", time.time() - t)
             result = response.content[0].text
             # Try to parse the result as JSON
             parsed = json.loads(result)
-            return response_model(**parsed)
+            lm_response = BaseLMResponse(
+                raw_response="",
+                structured_output=response_model(**parsed),
+                tool_calls=None,
+            )
+            return lm_response
         except (json.JSONDecodeError, pydantic.ValidationError):
             # If Anthropic fails, fallback to OpenAI
             print("WARNING - Falling back to OpenAI - THIS IS SLOW")
--- a/synth_ai/zyk/lms/vendors/core/gemini_api.py
+++ b/synth_ai/zyk/lms/vendors/core/gemini_api.py
@@ -1,16 +1,18 @@
+import json
 import logging
 import os
 import warnings
-from typing import Any, Dict, List, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 import google.generativeai as genai
 from google.api_core.exceptions import ResourceExhausted
-from google.generativeai.types import HarmBlockThreshold, HarmCategory
+from google.generativeai.types import HarmBlockThreshold, HarmCategory, Tool
 
 from synth_ai.zyk.lms.caching.initialize import (
     get_cache_handler,
 )
-from synth_ai.zyk.lms.vendors.base import VendorBase
+from synth_ai.zyk.lms.tools.base import BaseTool
+from synth_ai.zyk.lms.vendors.base import BaseLMResponse, VendorBase
 from synth_ai.zyk.lms.vendors.constants import SPECIAL_BASE_TEMPS
 from synth_ai.zyk.lms.vendors.retries import BACKOFF_TOLERANCE, backoff
 
@@ -41,23 +43,89 @@ class GeminiAPI(VendorBase):
         self.used_for_structured_outputs = used_for_structured_outputs
         self.exceptions_to_retry = exceptions_to_retry
 
+    def _convert_messages_to_contents(
+        self, messages: List[Dict[str, Any]]
+    ) -> List[Dict[str, Any]]:
+        contents = []
+        system_instruction = None
+        for message in messages:
+            if message["role"] == "system":
+                system_instruction = (
+                    f"<instructions>\n{message['content']}\n</instructions>"
+                )
+                continue
+            elif system_instruction:
+                text = system_instruction + "\n" + message["content"]
+            else:
+                text = message["content"]
+            contents.append(
+                {
+                    "role": message["role"],
+                    "parts": [{"text": text}],
+                }
+            )
+        return contents
+
+    def _convert_tools_to_gemini_format(self, tools: List[BaseTool]) -> Tool:
+        function_declarations = []
+        for tool in tools:
+            function_declarations.append(tool.to_gemini_tool())
+        return Tool(function_declarations=function_declarations)
+
     async def _private_request_async(
         self,
         messages: List[Dict],
         temperature: float = 0,
         model_name: str = "gemini-1.5-flash",
         reasoning_effort: str = "high",
-    ) -> str:
+        tools: Optional[List[BaseTool]] = None,
+        lm_config: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[str, Optional[List[Dict]]]:
+        generation_config = {
+            "temperature": temperature,
+        }
+
+        tools_config = None
+        if tools:
+            tools_config = self._convert_tools_to_gemini_format(tools)
+
+        # Extract tool_config from lm_config if provided
+        tool_config = lm_config.get("tool_config") if lm_config else {
+            "function_calling_config": {
+                "mode": "any"
+            }
+        }
+
         code_generation_model = genai.GenerativeModel(
             model_name=model_name,
-            generation_config={"temperature": temperature},
-            system_instruction=messages[0]["content"],
+            generation_config=generation_config,
+            tools=tools_config if tools_config else None,
+            tool_config=tool_config,
         )
+
+        contents = self._convert_messages_to_contents(messages)
         result = await code_generation_model.generate_content_async(
-            messages[1]["content"],
+            contents=contents,
             safety_settings=SAFETY_SETTINGS,
         )
-        return result.text
+
+        text = result.candidates[0].content.parts[0].text
+        tool_calls = []
+        for part in result.candidates[0].content.parts:
+            if part.function_call:
+                # Convert MapComposite args to dict
+                args_dict = dict(part.function_call.args)
+                tool_calls.append(
+                    {
+                        "id": f"call_{len(tool_calls) + 1}",  # Generate unique IDs
+                        "type": "function",
+                        "function": {
+                            "name": part.function_call.name,
+                            "arguments": json.dumps(args_dict),
+                        },
+                    }
                )
+        return text, tool_calls if tool_calls else None
 
     def _private_request_sync(
         self,
@@ -65,17 +133,54 @@ class GeminiAPI(VendorBase):
         temperature: float = 0,
         model_name: str = "gemini-1.5-flash",
         reasoning_effort: str = "high",
-    ) -> str:
+        tools: Optional[List[BaseTool]] = None,
+        lm_config: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[str, Optional[List[Dict]]]:
+        generation_config = {
+            "temperature": temperature,
+        }
+
+        tools_config = None
+        if tools:
+            tools_config = self._convert_tools_to_gemini_format(tools)
+
+        # Extract tool_config from lm_config if provided
+        tool_config = lm_config.get("tool_config") if lm_config else {
+            "function_calling_config": {
+                "mode": "any"
+            }
+        }
+
         code_generation_model = genai.GenerativeModel(
             model_name=model_name,
-            generation_config={"temperature": temperature},
-            system_instruction=messages[0]["content"],
+            generation_config=generation_config,
+            tools=tools_config if tools_config else None,
+            tool_config=tool_config,
         )
+
+        contents = self._convert_messages_to_contents(messages)
         result = code_generation_model.generate_content(
-            messages[1]["content"],
+            contents=contents,
             safety_settings=SAFETY_SETTINGS,
         )
-        return result.text
+
+        text = result.candidates[0].content.parts[0].text
+        tool_calls = []
+        for part in result.candidates[0].content.parts:
+            if part.function_call:
+                # Convert MapComposite args to dict
+                args_dict = dict(part.function_call.args)
+                tool_calls.append(
+                    {
+                        "id": f"call_{len(tool_calls) + 1}",  # Generate unique IDs
+                        "type": "function",
+                        "function": {
+                            "name": part.function_call.name,
+                            "arguments": json.dumps(args_dict),
+                        },
+                    }
                )
+        return text, tool_calls if tool_calls else None
 
     @backoff.on_exception(
         backoff.expo,
@@ -90,29 +195,35 @@ class GeminiAPI(VendorBase):
         lm_config: Dict[str, Any],
         use_ephemeral_cache_only: bool = False,
         reasoning_effort: str = "high",
-    ) -> str:
+        tools: Optional[List[BaseTool]] = None,
+    ) -> BaseLMResponse:
         assert (
             lm_config.get("response_model", None) is None
         ), "response_model is not supported for standard calls"
         used_cache_handler = get_cache_handler(use_ephemeral_cache_only)
         cache_result = used_cache_handler.hit_managed_cache(
-            model, messages, lm_config=lm_config
+            model, messages, lm_config=lm_config, tools=tools
         )
         if cache_result:
-            return (
-                cache_result["response"]
-                if isinstance(cache_result, dict)
-                else cache_result
-            )
-        api_result = await self._private_request_async(
+            return cache_result
+
+        raw_response, tool_calls = await self._private_request_async(
             messages,
             temperature=lm_config.get("temperature", SPECIAL_BASE_TEMPS.get(model, 0)),
             reasoning_effort=reasoning_effort,
+            tools=tools,
+        )
+
+        lm_response = BaseLMResponse(
+            raw_response=raw_response,
+            structured_output=None,
+            tool_calls=tool_calls,
         )
+
         used_cache_handler.add_to_managed_cache(
-            model, messages, lm_config=lm_config, output=api_result
+            model, messages, lm_config=lm_config, output=lm_response, tools=tools
         )
-        return api_result
+        return lm_response
 
     @backoff.on_exception(
         backoff.expo,
@@ -127,26 +238,34 @@ class GeminiAPI(VendorBase):
         lm_config: Dict[str, Any],
         use_ephemeral_cache_only: bool = False,
         reasoning_effort: str = "high",
-    ) -> str:
+        tools: Optional[List[BaseTool]] = None,
+    ) -> BaseLMResponse:
         assert (
             lm_config.get("response_model", None) is None
         ), "response_model is not supported for standard calls"
-        used_cache_handler = get_cache_handler(use_ephemeral_cache_only)
+        used_cache_handler = get_cache_handler(
+            use_ephemeral_cache_only=use_ephemeral_cache_only
+        )
         cache_result = used_cache_handler.hit_managed_cache(
-            model, messages, lm_config=lm_config
+            model, messages, lm_config=lm_config, tools=tools
        )
         if cache_result:
-            return (
-                cache_result["response"]
-                if isinstance(cache_result, dict)
-                else cache_result
-            )
-        api_result = self._private_request_sync(
+            return cache_result
+
+        raw_response, tool_calls = self._private_request_sync(
             messages,
             temperature=lm_config.get("temperature", SPECIAL_BASE_TEMPS.get(model, 0)),
            reasoning_effort=reasoning_effort,
+            tools=tools,
         )
+
+        lm_response = BaseLMResponse(
+            raw_response=raw_response,
+            structured_output=None,
+            tool_calls=tool_calls,
+        )
+
         used_cache_handler.add_to_managed_cache(
-            model, messages, lm_config=lm_config, output=api_result
+            model, messages, lm_config=lm_config, output=lm_response, tools=tools
         )
-        return api_result
+        return lm_response
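
Note on the Gemini message handling: the system prompt is no longer passed as system_instruction; the new _convert_messages_to_contents wraps it in <instructions> tags and prepends it to the following message. The standalone re-implementation below is not shipped with the package; it only mirrors the diffed logic so the resulting contents shape is easy to see.

from typing import Any, Dict, List

# Standalone mirror of GeminiAPI._convert_messages_to_contents from the diff
# above, included only to illustrate the resulting "contents" structure.
def convert_messages_to_contents(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    contents = []
    system_instruction = None
    for message in messages:
        if message["role"] == "system":
            # The system prompt is wrapped in <instructions> tags instead of
            # being sent via the system_instruction parameter.
            system_instruction = f"<instructions>\n{message['content']}\n</instructions>"
            continue
        elif system_instruction:
            text = system_instruction + "\n" + message["content"]
        else:
            text = message["content"]
        contents.append({"role": message["role"], "parts": [{"text": text}]})
    return contents

print(convert_messages_to_contents([
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Ping?"},
]))
# [{'role': 'user', 'parts': [{'text': '<instructions>\nYou are terse.\n</instructions>\nPing?'}]}]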