code-puppy 0.0.348__py3-none-any.whl → 0.0.372__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- code_puppy/agents/__init__.py +8 -0
- code_puppy/agents/agent_manager.py +272 -1
- code_puppy/agents/agent_pack_leader.py +383 -0
- code_puppy/agents/agent_qa_kitten.py +12 -7
- code_puppy/agents/agent_terminal_qa.py +323 -0
- code_puppy/agents/base_agent.py +11 -8
- code_puppy/agents/event_stream_handler.py +101 -8
- code_puppy/agents/pack/__init__.py +34 -0
- code_puppy/agents/pack/bloodhound.py +304 -0
- code_puppy/agents/pack/husky.py +321 -0
- code_puppy/agents/pack/retriever.py +393 -0
- code_puppy/agents/pack/shepherd.py +348 -0
- code_puppy/agents/pack/terrier.py +287 -0
- code_puppy/agents/pack/watchdog.py +367 -0
- code_puppy/agents/subagent_stream_handler.py +276 -0
- code_puppy/api/__init__.py +13 -0
- code_puppy/api/app.py +169 -0
- code_puppy/api/main.py +21 -0
- code_puppy/api/pty_manager.py +446 -0
- code_puppy/api/routers/__init__.py +12 -0
- code_puppy/api/routers/agents.py +36 -0
- code_puppy/api/routers/commands.py +217 -0
- code_puppy/api/routers/config.py +74 -0
- code_puppy/api/routers/sessions.py +232 -0
- code_puppy/api/templates/terminal.html +361 -0
- code_puppy/api/websocket.py +154 -0
- code_puppy/callbacks.py +73 -0
- code_puppy/chatgpt_codex_client.py +53 -0
- code_puppy/claude_cache_client.py +294 -41
- code_puppy/command_line/add_model_menu.py +13 -4
- code_puppy/command_line/agent_menu.py +662 -0
- code_puppy/command_line/core_commands.py +89 -112
- code_puppy/command_line/model_picker_completion.py +3 -20
- code_puppy/command_line/model_settings_menu.py +21 -3
- code_puppy/config.py +145 -70
- code_puppy/gemini_model.py +706 -0
- code_puppy/http_utils.py +6 -3
- code_puppy/messaging/__init__.py +15 -0
- code_puppy/messaging/messages.py +27 -0
- code_puppy/messaging/queue_console.py +1 -1
- code_puppy/messaging/rich_renderer.py +36 -1
- code_puppy/messaging/spinner/__init__.py +20 -2
- code_puppy/messaging/subagent_console.py +461 -0
- code_puppy/model_factory.py +50 -16
- code_puppy/model_switching.py +63 -0
- code_puppy/model_utils.py +27 -24
- code_puppy/models.json +12 -12
- code_puppy/plugins/antigravity_oauth/antigravity_model.py +206 -172
- code_puppy/plugins/antigravity_oauth/register_callbacks.py +15 -8
- code_puppy/plugins/antigravity_oauth/transport.py +236 -45
- code_puppy/plugins/chatgpt_oauth/register_callbacks.py +2 -2
- code_puppy/plugins/claude_code_oauth/register_callbacks.py +2 -30
- code_puppy/plugins/claude_code_oauth/utils.py +4 -1
- code_puppy/plugins/frontend_emitter/__init__.py +25 -0
- code_puppy/plugins/frontend_emitter/emitter.py +121 -0
- code_puppy/plugins/frontend_emitter/register_callbacks.py +261 -0
- code_puppy/prompts/antigravity_system_prompt.md +1 -0
- code_puppy/pydantic_patches.py +52 -0
- code_puppy/status_display.py +6 -2
- code_puppy/tools/__init__.py +37 -1
- code_puppy/tools/agent_tools.py +83 -33
- code_puppy/tools/browser/__init__.py +37 -0
- code_puppy/tools/browser/browser_control.py +6 -6
- code_puppy/tools/browser/browser_interactions.py +21 -20
- code_puppy/tools/browser/browser_locators.py +9 -9
- code_puppy/tools/browser/browser_manager.py +316 -0
- code_puppy/tools/browser/browser_navigation.py +7 -7
- code_puppy/tools/browser/browser_screenshot.py +78 -140
- code_puppy/tools/browser/browser_scripts.py +15 -13
- code_puppy/tools/browser/chromium_terminal_manager.py +259 -0
- code_puppy/tools/browser/terminal_command_tools.py +521 -0
- code_puppy/tools/browser/terminal_screenshot_tools.py +556 -0
- code_puppy/tools/browser/terminal_tools.py +525 -0
- code_puppy/tools/command_runner.py +292 -101
- code_puppy/tools/common.py +176 -1
- code_puppy/tools/display.py +84 -0
- code_puppy/tools/subagent_context.py +158 -0
- {code_puppy-0.0.348.data → code_puppy-0.0.372.data}/data/code_puppy/models.json +12 -12
- {code_puppy-0.0.348.dist-info → code_puppy-0.0.372.dist-info}/METADATA +17 -16
- {code_puppy-0.0.348.dist-info → code_puppy-0.0.372.dist-info}/RECORD +84 -51
- code_puppy/prompts/codex_system_prompt.md +0 -310
- code_puppy/tools/browser/camoufox_manager.py +0 -235
- code_puppy/tools/browser/vqa_agent.py +0 -90
- {code_puppy-0.0.348.data → code_puppy-0.0.372.data}/data/code_puppy/models_dev_api.json +0 -0
- {code_puppy-0.0.348.dist-info → code_puppy-0.0.372.dist-info}/WHEEL +0 -0
- {code_puppy-0.0.348.dist-info → code_puppy-0.0.372.dist-info}/entry_points.txt +0 -0
- {code_puppy-0.0.348.dist-info → code_puppy-0.0.372.dist-info}/licenses/LICENSE +0 -0
code_puppy/gemini_model.py (new file, 706 added lines):

@@ -0,0 +1,706 @@
"""Standalone Gemini Model for pydantic_ai - no google-genai dependency.

This module provides a custom Model implementation that uses Google's
Generative Language API directly via httpx, without the bloated google-genai
SDK dependency.
"""

from __future__ import annotations

import base64
import json
import logging
import uuid
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any

import httpx
from pydantic_ai._run_context import RunContext
from pydantic_ai.messages import (
    ModelMessage,
    ModelRequest,
    ModelResponse,
    ModelResponsePart,
    ModelResponseStreamEvent,
    RetryPromptPart,
    SystemPromptPart,
    TextPart,
    ThinkingPart,
    ToolCallPart,
    ToolReturnPart,
    UserPromptPart,
)
from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse
from pydantic_ai.settings import ModelSettings
from pydantic_ai.tools import ToolDefinition
from pydantic_ai.usage import RequestUsage

logger = logging.getLogger(__name__)


def generate_tool_call_id() -> str:
    """Generate a unique tool call ID."""
    return str(uuid.uuid4())


def _flatten_union_to_object_gemini(union_items: list, defs: dict, resolve_fn) -> dict:
    """Flatten a union of object types into a single object with all properties.

    For discriminated unions like EditFilePayload, we merge all object types
    into one with all properties (Gemini doesn't support anyOf/oneOf).
    """
    import copy as copy_module

    merged_properties = {}
    has_string_type = False

    for item in union_items:
        if not isinstance(item, dict):
            continue

        # Resolve $ref first
        if "$ref" in item:
            ref_path = item["$ref"]
            ref_name = None
            if ref_path.startswith("#/$defs/"):
                ref_name = ref_path[8:]
            elif ref_path.startswith("#/definitions/"):
                ref_name = ref_path[14:]
            if ref_name and ref_name in defs:
                item = copy_module.deepcopy(defs[ref_name])
            else:
                continue

        if item.get("type") == "string":
            has_string_type = True
            continue

        if item.get("type") == "null":
            continue

        if item.get("type") == "object" or "properties" in item:
            props = item.get("properties", {})
            for prop_name, prop_schema in props.items():
                if prop_name not in merged_properties:
                    merged_properties[prop_name] = resolve_fn(
                        copy_module.deepcopy(prop_schema)
                    )

    if not merged_properties:
        return {"type": "string"} if has_string_type else {"type": "object"}

    return {
        "type": "object",
        "properties": merged_properties,
    }


def _sanitize_schema_for_gemini(schema: dict) -> dict:
    """Sanitize JSON schema for Gemini API compatibility.

    Removes/transforms fields that Gemini doesn't support:
    - $defs, definitions, $schema, $id
    - additionalProperties
    - $ref (inlined)
    - anyOf/oneOf/allOf (flattened - Gemini doesn't support unions!)
    - For unions of objects: merges into single object with all properties
    - For simple unions (string | null): picks first non-null type
    """
    import copy

    if not isinstance(schema, dict):
        return schema

    # Make a deep copy to avoid modifying original
    schema = copy.deepcopy(schema)

    # Extract $defs for reference resolution
    defs = schema.pop("$defs", schema.pop("definitions", {}))

    def resolve_refs(obj):
        """Recursively resolve $ref references and clean schema."""
        if isinstance(obj, dict):
            # Handle anyOf/oneOf unions
            for union_key in ["anyOf", "oneOf"]:
                if union_key in obj:
                    union = obj[union_key]
                    if isinstance(union, list):
                        # Check if this is a complex union of objects
                        object_count = 0
                        has_refs = False
                        for item in union:
                            if isinstance(item, dict):
                                if "$ref" in item:
                                    has_refs = True
                                    object_count += 1
                                elif (
                                    item.get("type") == "object" or "properties" in item
                                ):
                                    object_count += 1

                        # If multiple objects or has refs, flatten to single object
                        if object_count > 1 or has_refs:
                            flattened = _flatten_union_to_object_gemini(
                                union, defs, resolve_refs
                            )
                            if "description" in obj:
                                flattened["description"] = obj["description"]
                            return flattened

                        # Simple union - pick first non-null type
                        for item in union:
                            if isinstance(item, dict) and item.get("type") != "null":
                                result = dict(item)
                                if "description" in obj:
                                    result["description"] = obj["description"]
                                return resolve_refs(result)

            # Handle allOf by merging all schemas
            if "allOf" in obj:
                all_of = obj["allOf"]
                if isinstance(all_of, list):
                    merged = {}
                    merged_properties = {}
                    for item in all_of:
                        if isinstance(item, dict):
                            resolved_item = resolve_refs(item)
                            if "properties" in resolved_item:
                                merged_properties.update(
                                    resolved_item.pop("properties")
                                )
                            merged.update(resolved_item)
                    if merged_properties:
                        merged["properties"] = merged_properties
                    for k, v in obj.items():
                        if k != "allOf":
                            merged[k] = v
                    return resolve_refs(merged)

            # Check for $ref
            if "$ref" in obj:
                ref_path = obj["$ref"]
                ref_name = None

                # Parse ref like "#/$defs/SomeType" or "#/definitions/SomeType"
                if ref_path.startswith("#/$defs/"):
                    ref_name = ref_path[8:]
                elif ref_path.startswith("#/definitions/"):
                    ref_name = ref_path[14:]

                if ref_name and ref_name in defs:
                    resolved = resolve_refs(copy.deepcopy(defs[ref_name]))
                    other_props = {k: v for k, v in obj.items() if k != "$ref"}
                    if other_props:
                        resolved.update(resolve_refs(other_props))
                    return resolved
                else:
                    return {"type": "object"}

            # Recursively process and transform
            result = {}
            for key, value in obj.items():
                # Skip unsupported fields
                if key in (
                    "$defs",
                    "definitions",
                    "$schema",
                    "$id",
                    "additionalProperties",
                    "default",
                    "examples",
                    "const",
                    "anyOf",  # Skip any remaining union types
                    "oneOf",
                    "allOf",
                ):
                    continue

                result[key] = resolve_refs(value)
            return result
        elif isinstance(obj, list):
            return [resolve_refs(item) for item in obj]
        else:
            return obj

    return resolve_refs(schema)


class GeminiModel(Model):
    """Standalone Model implementation for Google's Generative Language API.

    Uses httpx directly instead of google-genai SDK.
    """

    def __init__(
        self,
        model_name: str,
        api_key: str,
        base_url: str = "https://generativelanguage.googleapis.com/v1beta",
        http_client: httpx.AsyncClient | None = None,
    ):
        self._model_name = model_name
        self.api_key = api_key
        self._base_url = base_url.rstrip("/")
        self._http_client = http_client
        self._owns_client = http_client is None

    @property
    def model_name(self) -> str:
        """Return the model name."""
        return self._model_name

    @property
    def base_url(self) -> str:
        """Return the base URL for the API."""
        return self._base_url

    @property
    def system(self) -> str:
        """Return the provider system identifier."""
        return "google"

    def _get_instructions(
        self,
        messages: list,
        model_request_parameters,
    ) -> str | None:
        """Get additional instructions to prepend to system prompt.

        This is a compatibility method for pydantic-ai interface.
        Override in subclasses to inject custom instructions.
        """
        return None

    def prepare_request(
        self,
        model_settings: ModelSettings | None,
        model_request_parameters,
    ) -> tuple:
        """Prepare request by normalizing settings.

        This is a compatibility method for pydantic-ai interface.
        """
        return model_settings, model_request_parameters

    async def _get_client(self) -> httpx.AsyncClient:
        """Get or create HTTP client."""
        if self._http_client is None:
            self._http_client = httpx.AsyncClient(timeout=180)
        return self._http_client

    async def _close_client(self) -> None:
        """Close HTTP client if we own it."""
        if self._owns_client and self._http_client is not None:
            await self._http_client.aclose()
            self._http_client = None

    def _get_headers(self) -> dict[str, str]:
        """Get HTTP headers for the request."""
        return {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }

    async def _map_user_prompt(self, part: UserPromptPart) -> list[dict[str, Any]]:
        """Map a user prompt part to Gemini format."""
        parts = []

        if isinstance(part.content, str):
            parts.append({"text": part.content})
        elif isinstance(part.content, list):
            for item in part.content:
                if isinstance(item, str):
                    parts.append({"text": item})
                elif hasattr(item, "media_type") and hasattr(item, "data"):
                    # Handle file/image content
                    data = item.data
                    if isinstance(data, bytes):
                        data = base64.b64encode(data).decode("utf-8")
                    parts.append(
                        {
                            "inline_data": {
                                "mime_type": item.media_type,
                                "data": data,
                            }
                        }
                    )
                else:
                    parts.append({"text": str(item)})
        else:
            parts.append({"text": str(part.content)})

        return parts

    async def _map_messages(
        self,
        messages: list[ModelMessage],
        model_request_parameters: ModelRequestParameters,
    ) -> tuple[dict[str, Any] | None, list[dict[str, Any]]]:
        """Map pydantic-ai messages to Gemini API format."""
        contents: list[dict[str, Any]] = []
        system_parts: list[dict[str, Any]] = []

        for m in messages:
            if isinstance(m, ModelRequest):
                message_parts: list[dict[str, Any]] = []

                for part in m.parts:
                    if isinstance(part, SystemPromptPart):
                        system_parts.append({"text": part.content})
                    elif isinstance(part, UserPromptPart):
                        mapped_parts = await self._map_user_prompt(part)
                        message_parts.extend(mapped_parts)
                    elif isinstance(part, ToolReturnPart):
                        message_parts.append(
                            {
                                "function_response": {
                                    "name": part.tool_name,
                                    "response": part.model_response_object(),
                                    "id": part.tool_call_id,
                                }
                            }
                        )
                    elif isinstance(part, RetryPromptPart):
                        if part.tool_name is None:
                            message_parts.append({"text": part.model_response()})
                        else:
                            message_parts.append(
                                {
                                    "function_response": {
                                        "name": part.tool_name,
                                        "response": {"error": part.model_response()},
                                        "id": part.tool_call_id,
                                    }
                                }
                            )

                if message_parts:
                    # Merge with previous user message if exists
                    if contents and contents[-1].get("role") == "user":
                        contents[-1]["parts"].extend(message_parts)
                    else:
                        contents.append({"role": "user", "parts": message_parts})

            elif isinstance(m, ModelResponse):
                model_parts = self._map_model_response(m)
                if model_parts:
                    # Merge with previous model message if exists
                    if contents and contents[-1].get("role") == "model":
                        contents[-1]["parts"].extend(model_parts["parts"])
                    else:
                        contents.append(model_parts)

        # Ensure at least one content
        if not contents:
            contents = [{"role": "user", "parts": [{"text": ""}]}]

        # Get any injected instructions
        instructions = self._get_instructions(messages, model_request_parameters)
        if instructions:
            system_parts.insert(0, {"text": instructions})

        # Build system instruction
        system_instruction = None
        if system_parts:
            system_instruction = {"role": "user", "parts": system_parts}

        return system_instruction, contents

    def _map_model_response(self, m: ModelResponse) -> dict[str, Any] | None:
        """Map a ModelResponse to Gemini content format."""
        parts: list[dict[str, Any]] = []

        for item in m.parts:
            if isinstance(item, ToolCallPart):
                parts.append(
                    {
                        "function_call": {
                            "name": item.tool_name,
                            "args": item.args_as_dict(),
                            "id": item.tool_call_id,
                        }
                    }
                )
            elif isinstance(item, TextPart):
                parts.append({"text": item.content})
            elif isinstance(item, ThinkingPart):
                if item.content:
                    part_dict: dict[str, Any] = {"text": item.content, "thought": True}
                    if item.signature:
                        part_dict["thoughtSignature"] = item.signature
                    parts.append(part_dict)

        if not parts:
            return None
        return {"role": "model", "parts": parts}

    def _build_tools(self, tools: list[ToolDefinition]) -> list[dict[str, Any]]:
        """Build tool definitions for the API."""
        function_declarations = []

        for tool in tools:
            func_decl: dict[str, Any] = {
                "name": tool.name,
                "description": tool.description or "",
            }
            if tool.parameters_json_schema:
                # Sanitize schema for Gemini compatibility
                func_decl["parameters"] = _sanitize_schema_for_gemini(
                    tool.parameters_json_schema
                )
            function_declarations.append(func_decl)

        return [{"functionDeclarations": function_declarations}]

    def _build_generation_config(
        self, model_settings: ModelSettings | None
    ) -> dict[str, Any]:
        """Build generation config from model settings."""
        config: dict[str, Any] = {}

        if model_settings:
            if (
                hasattr(model_settings, "temperature")
                and model_settings.temperature is not None
            ):
                config["temperature"] = model_settings.temperature
            if hasattr(model_settings, "top_p") and model_settings.top_p is not None:
                config["topP"] = model_settings.top_p
            if (
                hasattr(model_settings, "max_tokens")
                and model_settings.max_tokens is not None
            ):
                config["maxOutputTokens"] = model_settings.max_tokens

        return config

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> ModelResponse:
        """Make a non-streaming request to the Gemini API."""
        system_instruction, contents = await self._map_messages(
            messages, model_request_parameters
        )

        # Build request body
        body: dict[str, Any] = {"contents": contents}

        gen_config = self._build_generation_config(model_settings)
        if gen_config:
            body["generationConfig"] = gen_config
        if system_instruction:
            body["systemInstruction"] = system_instruction

        # Add tools
        if model_request_parameters.function_tools:
            body["tools"] = self._build_tools(model_request_parameters.function_tools)

        # Make request
        client = await self._get_client()
        url = f"{self._base_url}/models/{self._model_name}:generateContent?key={self.api_key}"
        headers = self._get_headers()

        response = await client.post(url, json=body, headers=headers)

        if response.status_code != 200:
            raise RuntimeError(
                f"Gemini API error {response.status_code}: {response.text}"
            )

        data = response.json()
        return self._parse_response(data)

    def _parse_response(self, data: dict[str, Any]) -> ModelResponse:
        """Parse the Gemini API response."""
        candidates = data.get("candidates", [])
        if not candidates:
            return ModelResponse(
                parts=[TextPart(content="")],
                model_name=self._model_name,
                usage=RequestUsage(),
            )

        candidate = candidates[0]
        content = candidate.get("content", {})
        parts = content.get("parts", [])

        response_parts: list[ModelResponsePart] = []

        for part in parts:
            if part.get("thought") and part.get("text") is not None:
                # Thinking part
                signature = part.get("thoughtSignature")
                response_parts.append(
                    ThinkingPart(content=part["text"], signature=signature)
                )
            elif "text" in part:
                response_parts.append(TextPart(content=part["text"]))
            elif "functionCall" in part:
                fc = part["functionCall"]
                response_parts.append(
                    ToolCallPart(
                        tool_name=fc["name"],
                        args=fc.get("args", {}),
                        tool_call_id=fc.get("id") or generate_tool_call_id(),
                    )
                )

        # Extract usage
        usage_meta = data.get("usageMetadata", {})
        usage = RequestUsage(
            input_tokens=usage_meta.get("promptTokenCount", 0),
            output_tokens=usage_meta.get("candidatesTokenCount", 0),
        )

        return ModelResponse(
            parts=response_parts,
            model_name=self._model_name,
            usage=usage,
            provider_response_id=data.get("requestId"),
            provider_name=self.system,
        )

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
        run_context: RunContext[Any] | None = None,
    ) -> AsyncIterator[StreamedResponse]:
        """Make a streaming request to the Gemini API."""
        system_instruction, contents = await self._map_messages(
            messages, model_request_parameters
        )

        # Build request body
        body: dict[str, Any] = {"contents": contents}

        gen_config = self._build_generation_config(model_settings)
        if gen_config:
            body["generationConfig"] = gen_config
        if system_instruction:
            body["systemInstruction"] = system_instruction

        # Add tools
        if model_request_parameters.function_tools:
            body["tools"] = self._build_tools(model_request_parameters.function_tools)

        # Make streaming request
        client = await self._get_client()
        url = f"{self._base_url}/models/{self._model_name}:streamGenerateContent?alt=sse&key={self.api_key}"
        headers = self._get_headers()

        async def stream_chunks() -> AsyncIterator[dict[str, Any]]:
            async with client.stream(
                "POST", url, json=body, headers=headers
            ) as response:
                if response.status_code != 200:
                    text = await response.aread()
                    raise RuntimeError(
                        f"Gemini API error {response.status_code}: {text.decode()}"
                    )

                async for line in response.aiter_lines():
                    line = line.strip()
                    if not line:
                        continue
                    if line.startswith("data: "):
                        json_str = line[6:]
                        if json_str:
                            try:
                                yield json.loads(json_str)
                            except json.JSONDecodeError:
                                continue

        yield GeminiStreamingResponse(
            model_request_parameters=model_request_parameters,
            _chunks=stream_chunks(),
            _model_name_str=self._model_name,
            _provider_name_str=self.system,
        )


@dataclass
class GeminiStreamingResponse(StreamedResponse):
    """Streaming response handler for Gemini API."""

    _chunks: AsyncIterator[dict[str, Any]]
    _model_name_str: str
    _provider_name_str: str = "google"
    _timestamp_val: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        """Process streaming chunks and yield events."""
        async for chunk in self._chunks:
            # Extract usage
            usage_meta = chunk.get("usageMetadata", {})
            if usage_meta:
                self._usage = RequestUsage(
                    input_tokens=usage_meta.get("promptTokenCount", 0),
                    output_tokens=usage_meta.get("candidatesTokenCount", 0),
                )

            # Extract response ID
            if chunk.get("responseId"):
                self.provider_response_id = chunk["responseId"]

            candidates = chunk.get("candidates", [])
            if not candidates:
                continue

            candidate = candidates[0]
            content = candidate.get("content", {})
            parts = content.get("parts", [])

            for part in parts:
                # Handle thinking part
                if part.get("thought") and part.get("text") is not None:
                    event = self._parts_manager.handle_thinking_delta(
                        vendor_part_id=None,
                        content=part["text"],
                    )
                    if event:
                        yield event

                # Handle regular text
                elif part.get("text") is not None and not part.get("thought"):
                    text = part["text"]
                    if len(text) == 0:
                        continue
                    event = self._parts_manager.handle_text_delta(
                        vendor_part_id=None,
                        content=text,
                    )
                    if event:
                        yield event

                # Handle function call
                elif part.get("functionCall"):
                    fc = part["functionCall"]
                    event = self._parts_manager.handle_tool_call_delta(
                        vendor_part_id=uuid.uuid4(),
                        tool_name=fc.get("name"),
                        args=fc.get("args"),
                        tool_call_id=fc.get("id") or generate_tool_call_id(),
                    )
                    if event:
                        yield event

    @property
    def model_name(self) -> str:
        return self._model_name_str

    @property
    def provider_name(self) -> str | None:
        return self._provider_name_str

    @property
    def timestamp(self) -> datetime:
        return self._timestamp_val