docent-python 0.1.20a0__py3-none-any.whl → 0.1.22a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of docent-python might be problematic. Click here for more details.
- docent/_llm_util/__init__.py +0 -0
- docent/_llm_util/data_models/__init__.py +0 -0
- docent/_llm_util/data_models/exceptions.py +48 -0
- docent/_llm_util/data_models/llm_output.py +320 -0
- docent/_llm_util/data_models/simple_svc.py +79 -0
- docent/_llm_util/llm_cache.py +193 -0
- docent/_llm_util/model_registry.py +126 -0
- docent/_llm_util/prod_llms.py +454 -0
- docent/_llm_util/providers/__init__.py +0 -0
- docent/_llm_util/providers/anthropic.py +537 -0
- docent/_llm_util/providers/common.py +41 -0
- docent/_llm_util/providers/google.py +530 -0
- docent/_llm_util/providers/openai.py +745 -0
- docent/_llm_util/providers/openrouter.py +375 -0
- docent/_llm_util/providers/preference_types.py +104 -0
- docent/_llm_util/providers/provider_registry.py +164 -0
- docent/data_models/transcript.py +2 -0
- docent/judges/__init__.py +21 -0
- docent/judges/impl.py +232 -0
- docent/judges/types.py +240 -0
- docent/judges/util/forgiving_json.py +108 -0
- docent/judges/util/meta_schema.json +84 -0
- docent/judges/util/meta_schema.py +29 -0
- docent/judges/util/parse_output.py +95 -0
- docent/judges/util/voting.py +114 -0
- docent/trace.py +19 -1
- {docent_python-0.1.20a0.dist-info → docent_python-0.1.22a0.dist-info}/METADATA +7 -1
- docent_python-0.1.22a0.dist-info/RECORD +58 -0
- docent_python-0.1.20a0.dist-info/RECORD +0 -34
- {docent_python-0.1.20a0.dist-info → docent_python-0.1.22a0.dist-info}/WHEEL +0 -0
- {docent_python-0.1.20a0.dist-info → docent_python-0.1.22a0.dist-info}/licenses/LICENSE.md +0 -0
|
@@ -0,0 +1,530 @@
|
|
|
1
|
+
from typing import Any, Literal, cast
|
|
2
|
+
|
|
3
|
+
import backoff
|
|
4
|
+
import requests
|
|
5
|
+
from backoff.types import Details
|
|
6
|
+
from google import genai
|
|
7
|
+
from google.genai import errors, types
|
|
8
|
+
from google.genai.client import AsyncClient as AsyncGoogle
|
|
9
|
+
|
|
10
|
+
from docent._llm_util.data_models.exceptions import (
|
|
11
|
+
CompletionTooLongException,
|
|
12
|
+
ContextWindowException,
|
|
13
|
+
NoResponseException,
|
|
14
|
+
RateLimitException,
|
|
15
|
+
)
|
|
16
|
+
from docent._llm_util.data_models.llm_output import (
|
|
17
|
+
AsyncSingleLLMOutputStreamingCallback,
|
|
18
|
+
LLMCompletion,
|
|
19
|
+
LLMOutput,
|
|
20
|
+
UsageMetrics,
|
|
21
|
+
)
|
|
22
|
+
from docent._llm_util.providers.common import (
|
|
23
|
+
async_timeout_ctx,
|
|
24
|
+
coerce_tool_args,
|
|
25
|
+
reasoning_budget,
|
|
26
|
+
)
|
|
27
|
+
from docent._log_util import get_logger
|
|
28
|
+
from docent.data_models.chat import ChatMessage, Content, ToolCall, ToolInfo
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def get_google_client_async(api_key: str | None = None) -> AsyncGoogle:
    """Build a Google GenAI client and return its async (``.aio``) interface.

    Args:
        api_key: Key forwarded to ``genai.Client``. When ``None`` or empty the
            client is constructed with no arguments, so the SDK applies its
            own defaults.

    Returns:
        The ``aio`` attribute of the newly constructed ``genai.Client``.
    """
    client = genai.Client(api_key=api_key) if api_key else genai.Client()
    return client.aio
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# Module-level logger used by the backoff hook and the completion parser below.
logger = get_logger(__name__)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _convert_google_error(e: errors.APIError):
    """Translate a Google ``APIError`` into a docent exception, if one applies.

    Args:
        e: The SDK error to classify.

    Returns:
        ``RateLimitException`` wrapping ``e`` for codes 429, 502, 503 and 504;
        ``ContextWindowException`` for a 400 whose message mentions the
        maximum number of tokens; ``None`` for everything else.
    """
    status = e.code
    if status in [429, 502, 503, 504]:
        return RateLimitException(e)

    # A context-window overflow is only recognisable from the message text.
    exceeded_context = status == 400 and "maximum number of tokens" in str(e).lower()
    return ContextWindowException() if exceeded_context else None
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _print_backoff_message(e: Details):
    """``on_backoff`` hook: warn with the upcoming wait and the exception class name.

    Args:
        e: The ``backoff`` details mapping; its ``"wait"`` and ``"exception"``
            entries are read.
    """
    wait_seconds = e["wait"]
    exc_name = e["exception"].__class__.__name__  # type: ignore
    logger.warning(f"Google backing off for {wait_seconds:.2f}s due to {exc_name}")
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _is_retryable_error(exception: BaseException) -> bool:
    """Decide whether a failed request should be attempted again.

    Args:
        exception: The exception raised by the wrapped call.

    Returns:
        For a Google ``APIError``: True only when its code is 429, 500, 502,
        503 or 504. Otherwise: True only for a ``requests`` connection error.
    """
    if isinstance(exception, errors.APIError):
        return exception.code in (429, 500, 502, 503, 504)
    return isinstance(exception, requests.exceptions.ConnectionError)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@backoff.on_exception(
    backoff.expo,
    exception=(Exception),
    giveup=lambda e: not _is_retryable_error(e),
    max_tries=3,
    factor=2.0,
    on_backoff=_print_backoff_message,
)
async def get_google_chat_completion_async(
    client: AsyncGoogle,
    messages: list[ChatMessage],
    model_name: str,
    tools: list[ToolInfo] | None = None,
    tool_choice: Literal["auto", "required"] | None = None,
    max_new_tokens: int = 32,
    temperature: float = 1.0,
    reasoning_effort: Literal["low", "medium", "high"] | None = None,
    logprobs: bool = False,
    top_logprobs: int | None = None,
    timeout: float = 5.0,
) -> LLMOutput:
    """Request a single (non-streaming) completion from a Gemini model.

    The call is wrapped in exponential backoff (at most 3 tries, factor 2.0)
    for exceptions that ``_is_retryable_error`` accepts.

    Args:
        client: Async Google GenAI client.
        messages: Conversation to send; converted by ``_parse_chat_messages``.
        model_name: Model identifier passed to the SDK.
        tools: Optional tool definitions exposed to the model.
        tool_choice: Optional tool-calling mode (``"auto"`` or ``"required"``).
        max_new_tokens: Sent as ``max_output_tokens``.
        temperature: Sampling temperature.
        reasoning_effort: When set, enables a thinking config whose budget comes
            from ``reasoning_budget(max_new_tokens, reasoning_effort)``.
        logprobs: Not supported; must be False.
        top_logprobs: Not supported; must be None.
        timeout: Value handed to ``async_timeout_ctx`` (presumably seconds).

    Returns:
        The parsed ``LLMOutput``.

    Raises:
        NotImplementedError: If ``logprobs`` or ``top_logprobs`` is requested.
        CompletionTooLongException: If the first completion finished with
            ``"length"`` and carries no text.
    """
    if logprobs or top_logprobs is not None:
        raise NotImplementedError(
            "We have not implemented logprobs or top_logprobs for Google yet."
        )

    system, input_messages = _parse_chat_messages(messages, tools_provided=bool(tools))

    async with async_timeout_ctx(timeout):
        thinking_cfg = (
            types.ThinkingConfig(
                include_thoughts=True,
                thinking_budget=reasoning_budget(max_new_tokens, reasoning_effort),
            )
            if reasoning_effort
            else None
        )
        gemini_tools = _parse_tools(tools) if tools else None
        tool_cfg = (
            types.ToolConfig(function_calling_config=_parse_tool_choice(tool_choice))
            if tool_choice is not None
            else None
        )
        request_config = types.GenerateContentConfig(
            temperature=temperature,
            thinking_config=thinking_cfg,
            max_output_tokens=max_new_tokens,
            system_instruction=system,
            tools=gemini_tools,
            tool_config=tool_cfg,
        )

        raw_output = await client.models.generate_content(  # type: ignore
            model=model_name,
            contents=input_messages,  # type: ignore
            config=request_config,
        )

        output = _parse_google_completion(raw_output, model_name)
        first = output.first
        # A truncated completion with no text at all is useless to callers.
        if first and first.finish_reason == "length" and first.no_text:
            raise CompletionTooLongException(
                f"Completion empty due to truncation. Consider increasing max_new_tokens (currently {max_new_tokens})."
            )

        return output
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
@backoff.on_exception(
    backoff.expo,
    exception=(Exception),
    giveup=lambda e: not _is_retryable_error(e),
    max_tries=3,
    factor=2.0,
    on_backoff=_print_backoff_message,
)
async def get_google_chat_completion_streaming_async(
    client: AsyncGoogle,
    streaming_callback: AsyncSingleLLMOutputStreamingCallback | None,
    messages: list[ChatMessage],
    model_name: str,
    tools: list[ToolInfo] | None = None,
    tool_choice: Literal["auto", "required"] | None = None,
    max_new_tokens: int = 32,
    temperature: float = 1.0,
    reasoning_effort: Literal["low", "medium", "high"] | None = None,
    logprobs: bool = False,
    top_logprobs: int | None = None,
    timeout: float = 5.0,
) -> LLMOutput:
    """Stream a completion from a Gemini model, reporting partial text as it arrives.

    Text parts that are not marked as thoughts are accumulated; function-call
    parts are collected as ``ToolCall`` objects. After every chunk the callback
    (if any) receives an ``LLMOutput`` holding the text accumulated so far.

    Args:
        client: Async Google GenAI client.
        streaming_callback: Awaited once per chunk with the partial output, or
            ``None`` to disable progress reporting.
        messages: Conversation to send; converted by ``_parse_chat_messages``.
        model_name: Model identifier passed to the SDK.
        tools: Optional tool definitions exposed to the model.
        tool_choice: Optional tool-calling mode (``"auto"`` or ``"required"``).
        max_new_tokens: Sent as ``max_output_tokens``.
        temperature: Sampling temperature.
        reasoning_effort: When set, enables a thinking config whose budget comes
            from ``reasoning_budget(max_new_tokens, reasoning_effort)``.
        logprobs: Not supported; must be False.
        top_logprobs: Not supported; must be None.
        timeout: Value handed to ``async_timeout_ctx`` (presumably seconds).

    Returns:
        An ``LLMOutput`` with one completion containing the full text, any tool
        calls, the finish reason and the last reported usage counts. The finish
        reason is ``"tool_calls"`` whenever tool calls were produced, matching
        ``_parse_google_completion``.

    Raises:
        NotImplementedError: If ``logprobs`` or ``top_logprobs`` is requested.
        Exception: Whatever ``_convert_google_error`` maps an ``APIError`` to;
            unmapped ``APIError`` instances are re-raised unchanged.
    """
    if logprobs or top_logprobs is not None:
        raise NotImplementedError(
            "We have not implemented logprobs or top_logprobs for Google yet."
        )

    system, input_messages = _parse_chat_messages(messages, tools_provided=bool(tools))

    try:
        async with async_timeout_ctx(timeout):
            thinking_cfg = None
            if reasoning_effort:
                thinking_cfg = types.ThinkingConfig(
                    include_thoughts=True,
                    thinking_budget=reasoning_budget(max_new_tokens, reasoning_effort),
                )

            stream = await client.models.generate_content_stream(  # type: ignore
                model=model_name,
                contents=input_messages,  # type: ignore
                config=types.GenerateContentConfig(
                    temperature=temperature,
                    thinking_config=thinking_cfg,
                    max_output_tokens=max_new_tokens,
                    system_instruction=system,
                    tools=_parse_tools(tools) if tools else None,
                    tool_config=(
                        types.ToolConfig(function_calling_config=_parse_tool_choice(tool_choice))
                        if tool_choice is not None
                        else None
                    ),
                ),
            )

            accumulated_text = ""
            accumulated_tool_calls: list[ToolCall] = []
            finish_reason: str | None = None
            usage = UsageMetrics()

            async for chunk in stream:
                candidate = chunk.candidates[0] if chunk.candidates else None
                if candidate and candidate.content and candidate.content.parts:
                    for part in candidate.content.parts:
                        if part.text is not None and not part.thought:
                            accumulated_text += part.text
                        elif part.function_call is not None:
                            fc = part.function_call
                            args = coerce_tool_args(getattr(fc, "args", {}))
                            accumulated_tool_calls.append(
                                ToolCall(
                                    # Synthesize an id when the SDK does not supply one.
                                    id=getattr(fc, "id", None)
                                    or f"{getattr(fc, 'name', 'tool')}_call",
                                    function=getattr(fc, "name", "unknown"),
                                    arguments=args,
                                    type="function",
                                )
                            )

                if candidate and candidate.finish_reason is not None:
                    if candidate.finish_reason == types.FinishReason.STOP:
                        finish_reason = "stop"
                    elif candidate.finish_reason == types.FinishReason.MAX_TOKENS:
                        finish_reason = "length"
                    else:
                        finish_reason = "error"

                # Check for usage metadata in the chunk; later chunks overwrite earlier counts.
                if usage_metadata := chunk.usage_metadata:
                    if usage_metadata.prompt_token_count is not None:
                        usage["input"] = int(usage_metadata.prompt_token_count)
                    if usage_metadata.candidates_token_count is not None:
                        usage["output"] = int(usage_metadata.candidates_token_count)

                if streaming_callback is not None:
                    await streaming_callback(
                        LLMOutput(
                            model=model_name,
                            completions=[LLMCompletion(text=accumulated_text)],
                        )
                    )

            return LLMOutput(
                model=model_name,
                completions=[
                    LLMCompletion(
                        text=accumulated_text,
                        tool_calls=(accumulated_tool_calls or None),
                        # Keep the streaming path consistent with
                        # _parse_google_completion: tool calls take precedence
                        # over the raw finish reason.
                        finish_reason=(
                            "tool_calls" if accumulated_tool_calls else finish_reason
                        ),
                    )
                ],
                usage=usage,
            )
    except errors.APIError as e:
        # NOTE(review): the converted exceptions are presumably not APIError
        # subclasses, in which case the backoff decorator will not retry them —
        # confirm that is intended.
        if e2 := _convert_google_error(e):
            raise e2 from e
        else:
            raise
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
def _parse_chat_messages(
    messages: list[ChatMessage],
    *,
    tools_provided: bool = False,
) -> tuple[str | None, list[types.Content]]:
    """Convert docent chat messages into Gemini ``Content`` entries plus a system prompt.

    Role handling:
        * ``user``: text parts under role ``"user"``; skipped when empty.
        * ``assistant``: text parts plus one function-call part per prior tool
          call, under role ``"model"``; skipped when there is neither.
        * ``tool``: with ``tools_provided`` a function-response part (JSON
          object content is passed through, anything else is wrapped as
          ``{"result": ...}``); without it the content is sent as plain user
          text.
        * ``system``: not added to the contents; the text of the last system
          message becomes the returned system prompt.

    Args:
        messages: Conversation in docent's message format.
        tools_provided: Whether tools are configured for this request.

    Returns:
        ``(system_prompt, contents)``.

    Raises:
        ValueError: For a message whose role is none of the above.
    """
    import json  # function-scope import; hoisted out of the per-message loop

    result: list[types.Content] = []
    system_prompt: str | None = None
    parts: list[types.Part]

    for message in messages:
        if message.role == "user":
            parts = _parse_message_content(message.content)
            if parts:  # Avoid sending empty text parts
                result.append(types.Content(role="user", parts=parts))
        elif message.role == "assistant":
            parts = _parse_message_content(message.content)
            # If assistant previously made tool calls, include them so the model has full context
            for tool_call in getattr(message, "tool_calls", []) or []:
                try:
                    parts.append(
                        types.Part.from_function_call(
                            name=tool_call.function,
                            args=tool_call.arguments,  # type: ignore[arg-type]
                            id=tool_call.id,  # type: ignore[call-arg]
                        )
                    )
                except Exception:
                    # Fallback without id if the SDK signature differs
                    parts.append(
                        types.Part.from_function_call(
                            name=tool_call.function,
                            args=tool_call.arguments,  # type: ignore[arg-type]
                        )
                    )
            # Every tool call either appends a part or raises, so `parts` is
            # empty only when the message has neither text nor tool calls.
            if parts:
                result.append(types.Content(role="model", parts=parts))
        elif message.role == "tool":
            # Represent tool result as a function_response part (Gemini tool execution result)
            if not tools_provided:
                # If no tools configured, pass through as plain text
                parts = _parse_message_content(message.content)
                if parts:
                    result.append(types.Content(role="user", parts=parts))
            else:
                tool_name = getattr(message, "function", None) or "unknown_tool"
                tool_id = getattr(message, "tool_call_id", None)
                # Parse tool content as JSON when possible; otherwise wrap the raw text
                tool_text = message.text or ""
                response_obj: dict[str, Any]
                try:
                    parsed = json.loads(tool_text)
                except (TypeError, ValueError, RecursionError):
                    # Not valid (or not parseable) JSON: send the text verbatim.
                    response_obj = {"result": tool_text}
                else:
                    if isinstance(parsed, dict):
                        response_obj = cast(dict[str, Any], parsed)
                    else:
                        response_obj = {"result": parsed}

                part = _make_function_response_part(name=tool_name, response=response_obj, id=tool_id)  # type: ignore[arg-type]
                result.append(types.Content(role="user", parts=[part]))
        elif message.role == "system":
            system_prompt = message.text
        else:
            raise ValueError(f"Unknown message role: {message.role}")

    return system_prompt, result
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
def _parse_message_content(content: str | list[Content]) -> list[types.Part]:
    """Turn message content into Gemini text parts, dropping blank text.

    Args:
        content: Either a plain string or a list of content items.

    Returns:
        One text part per non-blank (whitespace-stripped) piece of text; an
        empty list when nothing remains.

    Raises:
        ValueError: If a list item has a type other than ``"text"``.
    """
    if isinstance(content, str):
        stripped = content.strip()
        if not stripped:
            return []
        return [types.Part.from_text(text=stripped)]

    parts: list[types.Part] = []
    for item in content:
        if item.type != "text":
            raise ValueError(f"Unsupported content type: {item.type}")
        stripped = (item.text or "").strip()
        if stripped:
            parts.append(types.Part.from_text(text=stripped))
    return parts
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
def _parse_google_completion(message: types.GenerateContentResponse, model: str) -> LLMOutput:
|
|
339
|
+
if not message.candidates:
|
|
340
|
+
return LLMOutput(
|
|
341
|
+
model=model,
|
|
342
|
+
completions=[],
|
|
343
|
+
errors=[NoResponseException()],
|
|
344
|
+
)
|
|
345
|
+
|
|
346
|
+
candidate = message.candidates[0]
|
|
347
|
+
|
|
348
|
+
if candidate.finish_reason == types.FinishReason.STOP:
|
|
349
|
+
finish_reason = "stop"
|
|
350
|
+
elif candidate.finish_reason == types.FinishReason.MAX_TOKENS:
|
|
351
|
+
finish_reason = "length"
|
|
352
|
+
else:
|
|
353
|
+
finish_reason = "error"
|
|
354
|
+
|
|
355
|
+
text = ""
|
|
356
|
+
tool_calls: list[ToolCall] = []
|
|
357
|
+
content_parts = candidate.content.parts if candidate.content else []
|
|
358
|
+
content_parts = content_parts or []
|
|
359
|
+
for part in content_parts:
|
|
360
|
+
if part.text is not None and not part.thought:
|
|
361
|
+
text += part.text
|
|
362
|
+
elif part.thought:
|
|
363
|
+
logger.warning("Google returned thinking block; we should support this soon.")
|
|
364
|
+
elif getattr(part, "function_call", None) is not None:
|
|
365
|
+
fc = part.function_call
|
|
366
|
+
# Attempt to parse arguments as a dictionary
|
|
367
|
+
args = coerce_tool_args(getattr(fc, "args", {}))
|
|
368
|
+
tool_calls.append(
|
|
369
|
+
ToolCall(
|
|
370
|
+
id=getattr(fc, "id", None) or f"{getattr(fc, 'name', 'tool')}_call",
|
|
371
|
+
function=getattr(fc, "name", "unknown"),
|
|
372
|
+
arguments=args,
|
|
373
|
+
type="function",
|
|
374
|
+
)
|
|
375
|
+
)
|
|
376
|
+
else:
|
|
377
|
+
raise ValueError(f"Unknown content part: {part}")
|
|
378
|
+
|
|
379
|
+
# Extract usage metrics from the response
|
|
380
|
+
usage = UsageMetrics()
|
|
381
|
+
if usage_metadata := message.usage_metadata:
|
|
382
|
+
if usage_metadata.prompt_token_count is not None:
|
|
383
|
+
usage["input"] = int(usage_metadata.prompt_token_count)
|
|
384
|
+
if usage_metadata.candidates_token_count is not None:
|
|
385
|
+
usage["output"] = int(usage_metadata.candidates_token_count)
|
|
386
|
+
|
|
387
|
+
return LLMOutput(
|
|
388
|
+
model=model,
|
|
389
|
+
completions=[
|
|
390
|
+
LLMCompletion(
|
|
391
|
+
text=text,
|
|
392
|
+
finish_reason=("tool_calls" if tool_calls else finish_reason),
|
|
393
|
+
tool_calls=(tool_calls or None),
|
|
394
|
+
)
|
|
395
|
+
],
|
|
396
|
+
usage=usage,
|
|
397
|
+
)
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
def _parse_tools(tools: list[ToolInfo]) -> list[types.Tool]:
    """Wrap docent tool definitions in a single Gemini ``Tool``.

    Args:
        tools: Tool definitions; each contributes one ``FunctionDeclaration``
            built from its name, description and parameters.

    Returns:
        A one-element list: Gemini expects a list of ``Tool`` objects, and all
        declarations are grouped into one for simplicity.
    """
    declarations: list[types.FunctionDeclaration] = [
        types.FunctionDeclaration(
            name=tool.name,
            description=tool.description,
            parameters=_convert_toolparams_to_schema(tool.parameters),
        )
        for tool in tools
    ]
    return [types.Tool(function_declarations=declarations)]
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
def _parse_tool_choice(tool_choice: Literal["auto", "required"] | None):
|
|
416
|
+
if tool_choice is None:
|
|
417
|
+
return None
|
|
418
|
+
# Map our values to SDK enum; if unavailable, return None so default behavior applies
|
|
419
|
+
try:
|
|
420
|
+
if tool_choice == "auto":
|
|
421
|
+
return types.FunctionCallingConfig(mode=types.FunctionCallingConfigMode.AUTO)
|
|
422
|
+
elif tool_choice == "required":
|
|
423
|
+
return types.FunctionCallingConfig(mode=types.FunctionCallingConfigMode.ANY)
|
|
424
|
+
except Exception:
|
|
425
|
+
return None
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
def _convert_toolparams_to_schema(params: Any) -> types.Schema:
    """Build the top-level OBJECT schema for a tool's parameters.

    Args:
        params: Object whose ``properties`` attribute (if present) maps
            parameter names to objects carrying ``input_schema`` and
            ``description``, and whose ``required`` attribute may be a list.

    Returns:
        A ``types.Schema`` of type OBJECT. ``properties`` is ``None`` when
        there are none; ``required`` is ``None`` unless ``params.required`` is
        a list.
    """
    prop_schemas: dict[str, types.Schema] = {}
    declared_props: dict[str, Any] = getattr(params, "properties", {}) or {}
    for prop_name, prop in declared_props.items():
        schema = _convert_json_schema_to_gemini_schema(getattr(prop, "input_schema", {}) or {})
        # Use the parameter-level description only when the schema has none.
        fallback_desc: Any = getattr(prop, "description", None)
        if fallback_desc and schema.description is None:
            schema.description = fallback_desc
        prop_schemas[str(prop_name)] = schema

    declared_required: Any = getattr(params, "required", None)
    required_names: list[str] | None = (
        [str(item) for item in cast(list[Any], declared_required)]
        if isinstance(declared_required, list)
        else None
    )

    return types.Schema(
        type=types.Type.OBJECT,
        properties=prop_schemas or None,
        required=required_names,
    )
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
def _convert_json_schema_to_gemini_schema(js: dict[str, Any]) -> types.Schema:
    """Recursively convert a JSON-Schema fragment into a Gemini ``types.Schema``.

    Only ``type``, ``description``, ``enum``, ``properties``, ``required`` and
    ``items`` are read; other keywords are ignored.

    Args:
        js: JSON-Schema dictionary.

    Returns:
        The equivalent ``types.Schema``. ``type`` is ``None`` for a missing or
        unrecognised type name. For a type union (a list), the first non-null
        entry is used and ``nullable`` is set when ``"null"`` accompanies it, so
        ``["null", "string"]`` and ``["string", "null"]`` both become a nullable
        STRING. Enum values are stringified. Non-dict property values and
        non-dict ``items`` are skipped.
    """
    raw_type: Any = js.get("type")
    nullable: bool | None = None
    type_name: str
    if isinstance(raw_type, list):
        names = [str(v).lower() for v in cast(list[Any], raw_type)]
        concrete = [n for n in names if n != "null"]
        if concrete:
            type_name = concrete[0]
            # "null" alongside a concrete type means "optional value".
            if len(concrete) < len(names):
                nullable = True
        else:
            type_name = names[0] if names else ""
    elif raw_type is None:
        type_name = ""
    else:
        type_name = str(raw_type).lower()

    type_by_name: dict[str, types.Type] = {
        "string": types.Type.STRING,
        "number": types.Type.NUMBER,
        "integer": types.Type.INTEGER,
        "boolean": types.Type.BOOLEAN,
        "array": types.Type.ARRAY,
        "object": types.Type.OBJECT,
        "null": types.Type.NULL,
    }
    schema_type: types.Type | None = type_by_name.get(type_name)

    raw_enum: Any = js.get("enum")
    enum_vals: list[str] | None = None
    if isinstance(raw_enum, list):
        enum_vals = [str(v) for v in cast(list[Any], raw_enum)] or None

    raw_props: Any = js.get("properties") or {}
    props_out: dict[str, types.Schema] | None = None
    if isinstance(raw_props, dict):
        props_out = {
            str(key): _convert_json_schema_to_gemini_schema(cast(dict[str, Any], val))
            for key, val in cast(dict[str, Any], raw_props).items()
            if isinstance(val, dict)
        } or None

    raw_required: Any = js.get("required")
    required_out: list[str] | None = None
    if isinstance(raw_required, list):
        required_out = [str(item) for item in cast(list[Any], raw_required)] or None

    raw_items: Any = js.get("items")
    items_out: types.Schema | None = None
    if isinstance(raw_items, dict):
        items_out = _convert_json_schema_to_gemini_schema(cast(dict[str, Any], raw_items))

    return types.Schema(
        type=schema_type,
        description=js.get("description"),
        enum=enum_vals,
        properties=props_out,
        required=required_out,
        items=items_out,
        nullable=nullable,
    )
|
|
522
|
+
|
|
523
|
+
|
|
524
|
+
def _make_function_response_part(
    *, name: str, response: dict[str, object], id: str | None
) -> types.Part:
    """Create a Gemini function-response part, passing the call id when possible.

    Args:
        name: Name of the function the response belongs to.
        response: Response payload.
        id: Tool-call id; dropped if the SDK call that includes it raises.

    Returns:
        The part built by ``types.Part.from_function_response``.
    """
    try:
        part = types.Part.from_function_response(name=name, response=response, id=id)  # type: ignore[call-arg]
    except Exception:
        # The id-accepting call failed (e.g. the SDK signature differs): retry without it.
        part = types.Part.from_function_response(name=name, response=response)
    return part
|