synth-ai 0.1.0.dev27__py3-none-any.whl → 0.1.0.dev29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- public_tests/test_agent.py +11 -11
- public_tests/test_all_structured_outputs.py +32 -37
- public_tests/test_anthropic_structured_outputs.py +0 -0
- public_tests/test_deepseek_structured_outputs.py +0 -0
- public_tests/test_deepseek_tools.py +64 -0
- public_tests/test_gemini_structured_outputs.py +106 -0
- public_tests/test_models.py +27 -27
- public_tests/test_openai_structured_outputs.py +106 -0
- public_tests/test_reasoning_models.py +9 -7
- public_tests/test_recursive_structured_outputs.py +30 -30
- public_tests/test_structured.py +137 -0
- public_tests/test_structured_outputs.py +22 -13
- public_tests/test_text.py +160 -0
- public_tests/test_tools.py +300 -0
- synth_ai/__init__.py +1 -4
- synth_ai/zyk/__init__.py +2 -2
- synth_ai/zyk/lms/caching/ephemeral.py +54 -32
- synth_ai/zyk/lms/caching/handler.py +43 -15
- synth_ai/zyk/lms/caching/persistent.py +55 -27
- synth_ai/zyk/lms/core/main.py +26 -14
- synth_ai/zyk/lms/core/vendor_clients.py +1 -1
- synth_ai/zyk/lms/structured_outputs/handler.py +79 -45
- synth_ai/zyk/lms/structured_outputs/rehabilitate.py +3 -2
- synth_ai/zyk/lms/tools/base.py +104 -0
- synth_ai/zyk/lms/vendors/base.py +22 -6
- synth_ai/zyk/lms/vendors/core/anthropic_api.py +130 -95
- synth_ai/zyk/lms/vendors/core/gemini_api.py +153 -34
- synth_ai/zyk/lms/vendors/core/mistral_api.py +160 -54
- synth_ai/zyk/lms/vendors/core/openai_api.py +64 -53
- synth_ai/zyk/lms/vendors/openai_standard.py +197 -41
- synth_ai/zyk/lms/vendors/supported/deepseek.py +55 -0
- {synth_ai-0.1.0.dev27.dist-info → synth_ai-0.1.0.dev29.dist-info}/METADATA +2 -5
- synth_ai-0.1.0.dev29.dist-info/RECORD +65 -0
- public_tests/test_sonnet_thinking.py +0 -178
- synth_ai-0.1.0.dev27.dist-info/RECORD +0 -57
- {synth_ai-0.1.0.dev27.dist-info → synth_ai-0.1.0.dev29.dist-info}/WHEEL +0 -0
- {synth_ai-0.1.0.dev27.dist-info → synth_ai-0.1.0.dev29.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.1.0.dev27.dist-info → synth_ai-0.1.0.dev29.dist-info}/top_level.txt +0 -0
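The functional core of this release is tool-calling support in the vendor clients: `synth_ai/zyk/lms/tools/base.py` adds a `BaseTool` abstraction with per-vendor converters (`to_anthropic_tool()`, `to_gemini_tool()`), and the Anthropic/Gemini clients now return a `BaseLMResponse` carrying `raw_response`, `structured_output`, and OpenAI-style `tool_calls` dicts. The sketch below only consumes that response shape as it appears in the diff; the helper name and example values are illustrative, not part of the package.

```python
# Hedged sketch: consuming the BaseLMResponse fields added in this release.
# Assumes only what the diff shows: raw_response (str) plus tool_calls as
# OpenAI-style dicts with "function": {"name", "arguments"}.
import json
from typing import Any, Dict, List, Optional


def handle_response(raw_response: str,
                    tool_calls: Optional[List[Dict[str, Any]]]) -> None:
    """Print assistant text and dispatch any tool calls returned by a vendor client."""
    if raw_response:
        print("assistant text:", raw_response)
    for call in tool_calls or []:
        fn = call["function"]
        args = json.loads(fn["arguments"])  # arguments are JSON-encoded strings
        print(f"tool call {call['id']}: {fn['name']}({args})")


# Example with the dict shape produced by the new Anthropic/Gemini code paths:
handle_response(
    raw_response="Let me check the weather.",
    tool_calls=[{
        "id": "call_1",
        "type": "function",
        "function": {"name": "get_weather", "arguments": json.dumps({"city": "Paris"})},
    }],
)
```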
synth_ai/zyk/lms/vendors/core/anthropic_api.py

@@ -1,5 +1,5 @@
 import json
-from typing import Any, Dict, List, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 import anthropic
 import pydantic
@@ -8,14 +8,20 @@ from pydantic import BaseModel
 from synth_ai.zyk.lms.caching.initialize import (
     get_cache_handler,
 )
-from synth_ai.zyk.lms.vendors.base import VendorBase
+from synth_ai.zyk.lms.tools.base import BaseTool
+from synth_ai.zyk.lms.vendors.base import BaseLMResponse, VendorBase
 from synth_ai.zyk.lms.vendors.constants import SPECIAL_BASE_TEMPS
 from synth_ai.zyk.lms.vendors.core.openai_api import OpenAIStructuredOutputClient
-from synth_ai.zyk.lms.vendors.retries import BACKOFF_TOLERANCE, backoff
 
 ANTHROPIC_EXCEPTIONS_TO_RETRY: Tuple[Type[Exception], ...] = (anthropic.APIError,)
 
 
+sonnet_37_budgets = {
+    "high": 4000,
+    "medium": 2000,
+    "low": 1000,
+}
+
 class AnthropicAPI(VendorBase):
     used_for_structured_outputs: bool = True
     exceptions_to_retry: Tuple = ANTHROPIC_EXCEPTIONS_TO_RETRY
@@ -37,12 +43,12 @@ class AnthropicAPI(VendorBase):
         self._openai_fallback = None
         self.reasoning_effort = reasoning_effort
 
-    @backoff.on_exception(
-        backoff.expo,
-        exceptions_to_retry,
-        max_tries=BACKOFF_TOLERANCE,
-        on_giveup=lambda e: print(e),
-    )
+    # @backoff.on_exception(
+    #     backoff.expo,
+    #     exceptions_to_retry,
+    #     max_tries=BACKOFF_TOLERANCE,
+    #     on_giveup=lambda e: print(e),
+    # )
     async def _hit_api_async(
         self,
         model: str,
@@ -50,83 +56,90 @@ class AnthropicAPI(VendorBase):
         lm_config: Dict[str, Any],
         use_ephemeral_cache_only: bool = False,
         reasoning_effort: str = "high",
+        tools: Optional[List[BaseTool]] = None,
         **vendor_params: Dict[str, Any],
-    ) ->
+    ) -> BaseLMResponse:
         assert (
             lm_config.get("response_model", None) is None
         ), "response_model is not supported for standard calls"
         used_cache_handler = get_cache_handler(use_ephemeral_cache_only)
         cache_result = used_cache_handler.hit_managed_cache(
-            model, messages, lm_config=lm_config
+            model, messages, lm_config=lm_config, tools=tools
         )
         if cache_result:
-            return (
-                cache_result["response"]
-                if isinstance(cache_result, dict)
-                else cache_result
-            )
+            return cache_result
 
         # Common API parameters
         api_params = {
             "system": messages[0]["content"],
             "messages": messages[1:],
             "model": model,
-            "max_tokens": lm_config.get("max_tokens", 4096
+            "max_tokens": lm_config.get("max_tokens", 4096),
             "temperature": lm_config.get(
                 "temperature", SPECIAL_BASE_TEMPS.get(model, 0)
             ),
         }
 
-        #
+        # Add tools if provided
+        if tools:
+            api_params["tools"] = [tool.to_anthropic_tool() for tool in tools]
+
+        # Only try to add thinking if supported by the SDK
         try:
             import inspect
 
             create_sig = inspect.signature(self.async_client.messages.create)
             if "thinking" in create_sig.parameters and "claude-3-7" in model:
                 if reasoning_effort in ["high", "medium"]:
-                    budgets = {
-                        "high": 4000,
-                        "medium": 2000,
-                        "low": 1000,
-                    }
-                    budget = budgets[reasoning_effort]
+                    budget = sonnet_37_budgets[reasoning_effort]
                     api_params["thinking"] = {
                         "type": "enabled",
                         "budget_tokens": budget,
                     }
-
-                    api_params["
-                        api_params["max_tokens"], budget + 4096
-                    )
-                    # Set temperature to 1 for thinking, but only in API call
-                    api_params["temperature"] = 1.0
+                    api_params["max_tokens"] = budget+4096
+                    api_params["temperature"] = 1
         except (ImportError, AttributeError, TypeError):
             pass
 
         # Make the API call
         response = await self.async_client.messages.create(**api_params)
 
-        #
-
-
-
-
-
-
-
-
+        # Extract text content and tool calls
+        raw_response = ""
+        tool_calls = []
+
+        for content in response.content:
+            if content.type == "text":
+                raw_response += content.text
+            elif content.type == "tool_use":
+                tool_calls.append(
+                    {
+                        "id": content.id,
+                        "type": "function",
+                        "function": {
+                            "name": content.name,
+                            "arguments": json.dumps(content.input),
+                        },
+                    }
+                )
+
+        lm_response = BaseLMResponse(
+            raw_response=raw_response,
+            structured_output=None,
+            tool_calls=tool_calls if tool_calls else None,
+        )
 
         used_cache_handler.add_to_managed_cache(
-            model, messages, lm_config=lm_config, output=
+            model, messages, lm_config=lm_config, output=lm_response, tools=tools
         )
-        return
-
-    @backoff.on_exception(
-        backoff.expo,
-        exceptions_to_retry,
-        max_tries=BACKOFF_TOLERANCE,
-        on_giveup=lambda e: print(e),
-    )
+        return lm_response
+
+    # @backoff.on_exception(
+    #     backoff.expo,
+    #     exceptions_to_retry,
+    #     max_tries=BACKOFF_TOLERANCE,
+    #     on_giveup=lambda e: print(e),
+    # )
     def _hit_api_sync(
         self,
         model: str,
@@ -134,8 +147,9 @@ class AnthropicAPI(VendorBase):
         lm_config: Dict[str, Any],
         use_ephemeral_cache_only: bool = False,
         reasoning_effort: str = "high",
+        tools: Optional[List[BaseTool]] = None,
         **vendor_params: Dict[str, Any],
-    ) ->
+    ) -> BaseLMResponse:
         assert (
             lm_config.get("response_model", None) is None
         ), "response_model is not supported for standard calls"
@@ -143,14 +157,10 @@ class AnthropicAPI(VendorBase):
             use_ephemeral_cache_only=use_ephemeral_cache_only
         )
         cache_result = used_cache_handler.hit_managed_cache(
-            model, messages, lm_config=lm_config
+            model, messages, lm_config=lm_config, tools=tools
         )
         if cache_result:
-            return (
-                cache_result["response"]
-                if isinstance(cache_result, dict)
-                else cache_result
-            )
+            return cache_result
 
         # Common API parameters
         api_params = {
@@ -163,45 +173,61 @@ class AnthropicAPI(VendorBase):
             ),
         }
 
-        #
-
+        # Add tools if provided
+        if tools:
+            api_params["tools"] = [tool.to_anthropic_tool() for tool in tools]
+
+        # Only try to add thinking if supported by the SDK
         try:
             import inspect
 
             create_sig = inspect.signature(self.sync_client.messages.create)
             if "thinking" in create_sig.parameters and "claude-3-7" in model:
+                api_params["temperature"] = 1
                 if reasoning_effort in ["high", "medium"]:
-                    budgets = {
-                        "high": 4000,
-                        "medium": 2000,
-                        "low": 1000,
-                    }
+                    budgets = sonnet_37_budgets
                     budget = budgets[reasoning_effort]
                     api_params["thinking"] = {
                         "type": "enabled",
                         "budget_tokens": budget,
                     }
+                    api_params["max_tokens"] = budget+4096
+                    api_params["temperature"] = 1
         except (ImportError, AttributeError, TypeError):
-            # If we can't inspect or the parameter doesn't exist, just continue without it
             pass
 
         # Make the API call
         response = self.sync_client.messages.create(**api_params)
 
-        #
-
-
-
-
-
-
-
-
+        # Extract text content and tool calls
+        raw_response = ""
+        tool_calls = []
+
+        for content in response.content:
+            if content.type == "text":
+                raw_response += content.text
+            elif content.type == "tool_use":
+                tool_calls.append(
+                    {
+                        "id": content.id,
+                        "type": "function",
+                        "function": {
+                            "name": content.name,
+                            "arguments": json.dumps(content.input),
+                        },
+                    }
+                )
+
+        lm_response = BaseLMResponse(
+            raw_response=raw_response,
+            structured_output=None,
+            tool_calls=tool_calls if tool_calls else None,
+        )
 
         used_cache_handler.add_to_managed_cache(
-            model, messages, lm_config=lm_config, output=
+            model, messages, lm_config=lm_config, output=lm_response, tools=tools
         )
-        return
+        return lm_response
 
     async def _hit_api_async_structured_output(
         self,
@@ -212,36 +238,42 @@ class AnthropicAPI(VendorBase):
         use_ephemeral_cache_only: bool = False,
         reasoning_effort: str = "high",
         **vendor_params: Dict[str, Any],
-    ) ->
+    ) -> BaseLMResponse:
         try:
             # First try with Anthropic
             reasoning_effort = vendor_params.get("reasoning_effort", reasoning_effort)
             if "claude-3-7" in model:
-
-
-
-
-
-
-
+
+                #if reasoning_effort in ["high", "medium"]:
+                budgets = sonnet_37_budgets
+                budget = budgets[reasoning_effort]
+                max_tokens = budget+4096
+                temperature = 1
+
                 response = await self.async_client.messages.create(
                     system=messages[0]["content"],
                     messages=messages[1:],
                     model=model,
-                    max_tokens=
+                    max_tokens=max_tokens,
                     thinking={"type": "enabled", "budget_tokens": budget},
+                    temperature=temperature,
                 )
             else:
                 response = await self.async_client.messages.create(
                     system=messages[0]["content"],
                     messages=messages[1:],
                     model=model,
-                    max_tokens=
+                    max_tokens=max_tokens,
+                    temperature=temperature,
                 )
             result = response.content[0].text
-            # Try to parse the result as JSON
             parsed = json.loads(result)
-
+            lm_response = BaseLMResponse(
+                raw_response="",
+                structured_output=response_model(**parsed),
+                tool_calls=None,
+            )
+            return lm_response
         except (json.JSONDecodeError, pydantic.ValidationError):
             # If Anthropic fails, fallback to OpenAI
             if self._openai_fallback is None:
@@ -263,7 +295,7 @@ class AnthropicAPI(VendorBase):
         use_ephemeral_cache_only: bool = False,
         reasoning_effort: str = "high",
         **vendor_params: Dict[str, Any],
-    ) ->
+    ) -> BaseLMResponse:
         try:
             # First try with Anthropic
             reasoning_effort = vendor_params.get("reasoning_effort", reasoning_effort)
@@ -271,17 +303,15 @@ class AnthropicAPI(VendorBase):
 
             if "claude-3-7" in model:
                 if reasoning_effort in ["high", "medium"]:
-                    budgets = {
-                        "high": 4000,
-                        "medium": 2000,
-                        "low": 1000,
-                    }
+                    budgets = sonnet_37_budgets
                     budget = budgets[reasoning_effort]
+                    max_tokens = budget+4096
+                    temperature = 1
                 response = self.sync_client.messages.create(
                     system=messages[0]["content"],
                     messages=messages[1:],
                     model=model,
-                    max_tokens=
+                    max_tokens=max_tokens,
                     temperature=temperature,
                     thinking={"type": "enabled", "budget_tokens": budget},
                 )
@@ -290,14 +320,19 @@ class AnthropicAPI(VendorBase):
                     system=messages[0]["content"],
                     messages=messages[1:],
                     model=model,
-                    max_tokens=
+                    max_tokens=max_tokens,
                     temperature=temperature,
                 )
             # print("Time taken for API call", time.time() - t)
             result = response.content[0].text
             # Try to parse the result as JSON
             parsed = json.loads(result)
-
+            lm_response = BaseLMResponse(
+                raw_response="",
+                structured_output=response_model(**parsed),
+                tool_calls=None,
+            )
+            return lm_response
         except (json.JSONDecodeError, pydantic.ValidationError):
             # If Anthropic fails, fallback to OpenAI
             print("WARNING - Falling back to OpenAI - THIS IS SLOW")
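In the Anthropic client above, claude-3-7 requests derive their extended-thinking budget from `reasoning_effort` via the module-level `sonnet_37_budgets` map, pad `max_tokens` by 4096 above the budget, and force `temperature` to 1. A minimal restatement of that arithmetic follows; the helper name is illustrative and not part of the package.

```python
# Sketch of the thinking-budget selection used in the claude-3-7 branches above.
from typing import Any, Dict

sonnet_37_budgets = {"high": 4000, "medium": 2000, "low": 1000}


def thinking_params(reasoning_effort: str) -> Dict[str, Any]:
    """Mirror the diff: budget from the effort map, max_tokens = budget + 4096."""
    budget = sonnet_37_budgets[reasoning_effort]
    return {
        "thinking": {"type": "enabled", "budget_tokens": budget},
        "max_tokens": budget + 4096,
        "temperature": 1,  # these code paths pin temperature to 1 when thinking is enabled
    }


assert thinking_params("high")["max_tokens"] == 8096
assert thinking_params("low")["thinking"]["budget_tokens"] == 1000
```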
synth_ai/zyk/lms/vendors/core/gemini_api.py

@@ -1,16 +1,18 @@
+import json
 import logging
 import os
 import warnings
-from typing import Any, Dict, List, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 import google.generativeai as genai
 from google.api_core.exceptions import ResourceExhausted
-from google.generativeai.types import HarmBlockThreshold, HarmCategory
+from google.generativeai.types import HarmBlockThreshold, HarmCategory, Tool
 
 from synth_ai.zyk.lms.caching.initialize import (
     get_cache_handler,
 )
-from synth_ai.zyk.lms.vendors.base import VendorBase
+from synth_ai.zyk.lms.tools.base import BaseTool
+from synth_ai.zyk.lms.vendors.base import BaseLMResponse, VendorBase
 from synth_ai.zyk.lms.vendors.constants import SPECIAL_BASE_TEMPS
 from synth_ai.zyk.lms.vendors.retries import BACKOFF_TOLERANCE, backoff
 
@@ -41,23 +43,89 @@ class GeminiAPI(VendorBase):
         self.used_for_structured_outputs = used_for_structured_outputs
         self.exceptions_to_retry = exceptions_to_retry
 
+    def _convert_messages_to_contents(
+        self, messages: List[Dict[str, Any]]
+    ) -> List[Dict[str, Any]]:
+        contents = []
+        system_instruction = None
+        for message in messages:
+            if message["role"] == "system":
+                system_instruction = (
+                    f"<instructions>\n{message['content']}\n</instructions>"
+                )
+                continue
+            elif system_instruction:
+                text = system_instruction + "\n" + message["content"]
+            else:
+                text = message["content"]
+            contents.append(
+                {
+                    "role": message["role"],
+                    "parts": [{"text": text}],
+                }
+            )
+        return contents
+
+    def _convert_tools_to_gemini_format(self, tools: List[BaseTool]) -> Tool:
+        function_declarations = []
+        for tool in tools:
+            function_declarations.append(tool.to_gemini_tool())
+        return Tool(function_declarations=function_declarations)
+
     async def _private_request_async(
         self,
         messages: List[Dict],
         temperature: float = 0,
         model_name: str = "gemini-1.5-flash",
         reasoning_effort: str = "high",
-
+        tools: Optional[List[BaseTool]] = None,
+        lm_config: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[str, Optional[List[Dict]]]:
+        generation_config = {
+            "temperature": temperature,
+        }
+
+        tools_config = None
+        if tools:
+            tools_config = self._convert_tools_to_gemini_format(tools)
+
+        # Extract tool_config from lm_config if provided
+        tool_config = lm_config.get("tool_config") if lm_config else {
+            "function_calling_config": {
+                "mode": "any"
+            }
+        }
+
         code_generation_model = genai.GenerativeModel(
             model_name=model_name,
-            generation_config=
-
+            generation_config=generation_config,
+            tools=tools_config if tools_config else None,
+            tool_config=tool_config,
         )
+
+        contents = self._convert_messages_to_contents(messages)
         result = await code_generation_model.generate_content_async(
-
+            contents=contents,
             safety_settings=SAFETY_SETTINGS,
         )
-
+
+        text = result.candidates[0].content.parts[0].text
+        tool_calls = []
+        for part in result.candidates[0].content.parts:
+            if part.function_call:
+                # Convert MapComposite args to dict
+                args_dict = dict(part.function_call.args)
+                tool_calls.append(
+                    {
+                        "id": f"call_{len(tool_calls) + 1}",  # Generate unique IDs
+                        "type": "function",
+                        "function": {
+                            "name": part.function_call.name,
+                            "arguments": json.dumps(args_dict),
+                        },
+                    }
+                )
+        return text, tool_calls if tool_calls else None
 
     def _private_request_sync(
         self,
@@ -65,17 +133,54 @@ class GeminiAPI(VendorBase):
         temperature: float = 0,
         model_name: str = "gemini-1.5-flash",
         reasoning_effort: str = "high",
-
+        tools: Optional[List[BaseTool]] = None,
+        lm_config: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[str, Optional[List[Dict]]]:
+        generation_config = {
+            "temperature": temperature,
+        }
+
+        tools_config = None
+        if tools:
+            tools_config = self._convert_tools_to_gemini_format(tools)
+
+        # Extract tool_config from lm_config if provided
+        tool_config = lm_config.get("tool_config") if lm_config else {
+            "function_calling_config": {
+                "mode": "any"
+            }
+        }
+
         code_generation_model = genai.GenerativeModel(
             model_name=model_name,
-            generation_config=
-
+            generation_config=generation_config,
+            tools=tools_config if tools_config else None,
+            tool_config=tool_config,
         )
+
+        contents = self._convert_messages_to_contents(messages)
         result = code_generation_model.generate_content(
-
+            contents=contents,
             safety_settings=SAFETY_SETTINGS,
         )
-
+
+        text = result.candidates[0].content.parts[0].text
+        tool_calls = []
+        for part in result.candidates[0].content.parts:
+            if part.function_call:
+                # Convert MapComposite args to dict
+                args_dict = dict(part.function_call.args)
+                tool_calls.append(
+                    {
+                        "id": f"call_{len(tool_calls) + 1}",  # Generate unique IDs
+                        "type": "function",
+                        "function": {
+                            "name": part.function_call.name,
+                            "arguments": json.dumps(args_dict),
+                        },
+                    }
+                )
+        return text, tool_calls if tool_calls else None
 
     @backoff.on_exception(
         backoff.expo,
@@ -90,29 +195,35 @@ class GeminiAPI(VendorBase):
         lm_config: Dict[str, Any],
         use_ephemeral_cache_only: bool = False,
         reasoning_effort: str = "high",
-
+        tools: Optional[List[BaseTool]] = None,
+    ) -> BaseLMResponse:
         assert (
             lm_config.get("response_model", None) is None
         ), "response_model is not supported for standard calls"
         used_cache_handler = get_cache_handler(use_ephemeral_cache_only)
         cache_result = used_cache_handler.hit_managed_cache(
-            model, messages, lm_config=lm_config
+            model, messages, lm_config=lm_config, tools=tools
        )
         if cache_result:
-            return (
-                cache_result["response"]
-                if isinstance(cache_result, dict)
-                else cache_result
-            )
-        api_result = await self._private_request_async(
+            return cache_result
+
+        raw_response, tool_calls = await self._private_request_async(
             messages,
             temperature=lm_config.get("temperature", SPECIAL_BASE_TEMPS.get(model, 0)),
             reasoning_effort=reasoning_effort,
+            tools=tools,
+        )
+
+        lm_response = BaseLMResponse(
+            raw_response=raw_response,
+            structured_output=None,
+            tool_calls=tool_calls,
         )
+
         used_cache_handler.add_to_managed_cache(
-            model, messages, lm_config=lm_config, output=
+            model, messages, lm_config=lm_config, output=lm_response, tools=tools
         )
-        return
+        return lm_response
 
     @backoff.on_exception(
         backoff.expo,
@@ -127,26 +238,34 @@ class GeminiAPI(VendorBase):
         lm_config: Dict[str, Any],
         use_ephemeral_cache_only: bool = False,
         reasoning_effort: str = "high",
-
+        tools: Optional[List[BaseTool]] = None,
+    ) -> BaseLMResponse:
         assert (
             lm_config.get("response_model", None) is None
         ), "response_model is not supported for standard calls"
-        used_cache_handler = get_cache_handler(
+        used_cache_handler = get_cache_handler(
+            use_ephemeral_cache_only=use_ephemeral_cache_only
+        )
         cache_result = used_cache_handler.hit_managed_cache(
-            model, messages, lm_config=lm_config
+            model, messages, lm_config=lm_config, tools=tools
         )
         if cache_result:
-            return (
-                cache_result["response"]
-                if isinstance(cache_result, dict)
-                else cache_result
-            )
-        api_result = self._private_request_sync(
+            return cache_result
+
+        raw_response, tool_calls = self._private_request_sync(
             messages,
             temperature=lm_config.get("temperature", SPECIAL_BASE_TEMPS.get(model, 0)),
             reasoning_effort=reasoning_effort,
+            tools=tools,
         )
+
+        lm_response = BaseLMResponse(
+            raw_response=raw_response,
+            structured_output=None,
+            tool_calls=tool_calls,
+        )
+
         used_cache_handler.add_to_managed_cache(
-            model, messages, lm_config=lm_config, output=
+            model, messages, lm_config=lm_config, output=lm_response, tools=tools
         )
-        return
+        return lm_response