appkit-assistant 0.16.3__py3-none-any.whl → 0.17.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- appkit_assistant/backend/file_manager.py +117 -0
- appkit_assistant/backend/models.py +12 -0
- appkit_assistant/backend/processors/claude_base.py +178 -0
- appkit_assistant/backend/processors/claude_responses_processor.py +923 -0
- appkit_assistant/backend/processors/gemini_base.py +84 -0
- appkit_assistant/backend/processors/gemini_responses_processor.py +726 -0
- appkit_assistant/backend/processors/lorem_ipsum_processor.py +2 -0
- appkit_assistant/backend/processors/openai_base.py +10 -10
- appkit_assistant/backend/processors/openai_chat_completion_processor.py +25 -8
- appkit_assistant/backend/processors/openai_responses_processor.py +22 -15
- appkit_assistant/{logic → backend}/response_accumulator.py +58 -11
- appkit_assistant/components/__init__.py +2 -0
- appkit_assistant/components/composer.py +99 -12
- appkit_assistant/components/message.py +218 -50
- appkit_assistant/components/thread.py +2 -1
- appkit_assistant/configuration.py +2 -0
- appkit_assistant/state/thread_state.py +239 -5
- {appkit_assistant-0.16.3.dist-info → appkit_assistant-0.17.1.dist-info}/METADATA +4 -1
- {appkit_assistant-0.16.3.dist-info → appkit_assistant-0.17.1.dist-info}/RECORD +20 -15
- {appkit_assistant-0.16.3.dist-info → appkit_assistant-0.17.1.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,726 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Gemini responses processor for generating AI responses using Google's GenAI API.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import copy
|
|
6
|
+
import json
|
|
7
|
+
import logging
|
|
8
|
+
import uuid
|
|
9
|
+
from collections.abc import AsyncGenerator
|
|
10
|
+
from contextlib import AsyncExitStack, asynccontextmanager
|
|
11
|
+
from dataclasses import dataclass, field
|
|
12
|
+
from typing import Any, Final
|
|
13
|
+
|
|
14
|
+
import reflex as rx
|
|
15
|
+
from google.genai import types
|
|
16
|
+
from mcp import ClientSession
|
|
17
|
+
from mcp.client.streamable_http import streamablehttp_client
|
|
18
|
+
|
|
19
|
+
from appkit_assistant.backend.mcp_auth_service import MCPAuthService
|
|
20
|
+
from appkit_assistant.backend.models import (
|
|
21
|
+
AIModel,
|
|
22
|
+
AssistantMCPUserToken,
|
|
23
|
+
Chunk,
|
|
24
|
+
ChunkType,
|
|
25
|
+
MCPAuthType,
|
|
26
|
+
MCPServer,
|
|
27
|
+
Message,
|
|
28
|
+
MessageType,
|
|
29
|
+
)
|
|
30
|
+
from appkit_assistant.backend.processor import mcp_oauth_redirect_uri
|
|
31
|
+
from appkit_assistant.backend.processors.gemini_base import BaseGeminiProcessor
|
|
32
|
+
from appkit_assistant.backend.system_prompt_cache import get_system_prompt
|
|
33
|
+
|
|
34
|
+
# Upper bound on the number of tool-result characters echoed to the UI preview.
TOOL_RESULT_PREVIEW_LENGTH: Final[int] = 500

logger = logging.getLogger(__name__)

# Default redirect URI for MCP OAuth flows, computed once at import time.
default_oauth_redirect_uri: Final[str] = mcp_oauth_redirect_uri()
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@dataclass
class MCPToolContext:
    """An initialized MCP session together with the tools it advertises."""

    # Initialized MCP client session used to invoke tools.
    session: ClientSession
    # Server name, used for logging and as the UI "server_label".
    server_name: str
    # Tool name -> {"description": str, "inputSchema": dict} as collected
    # when the session is opened.
    tools: dict[str, Any] = field(default_factory=dict)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class GeminiResponsesProcessor(BaseGeminiProcessor):
    """Gemini processor using the GenAI API with MCP tool support.

    For each request the configured MCP servers are connected over streamable
    HTTP. Their tools are exposed to Gemini as function declarations, and any
    function calls returned by the model are executed against the MCP sessions
    in a bounded request/response loop.
    """

    # Maximum number of model <-> tool round trips per request.
    _MAX_TOOL_ROUNDS = 10

    # JSON Schema keywords Gemini does not accept in FunctionDeclaration
    # parameters. Note: additionalProperties gets converted to
    # additional_properties by the SDK, so both spellings are listed.
    _FORBIDDEN_SCHEMA_FIELDS = frozenset(
        {
            "$schema",
            "$id",
            "$ref",
            "$defs",
            "definitions",
            "$comment",
            "examples",
            "default",
            "const",
            "contentMediaType",
            "contentEncoding",
            "additionalProperties",
            "additional_properties",
            "patternProperties",
            "unevaluatedProperties",
            "unevaluatedItems",
            "minItems",
            "maxItems",
            "minLength",
            "maxLength",
            "minimum",
            "maximum",
            "exclusiveMinimum",
            "exclusiveMaximum",
            "multipleOf",
            "pattern",
            "format",
            "title",
            # Composition keywords - Gemini doesn't support these
            "allOf",
            "oneOf",
            "not",
            "if",
            "then",
            "else",
            "dependentSchemas",
            "dependentRequired",
        }
    )

    def __init__(
        self,
        models: dict[str, AIModel],
        api_key: str | None = None,
        oauth_redirect_uri: str = default_oauth_redirect_uri,
    ) -> None:
        super().__init__(models, api_key)
        self._current_reasoning_session: str | None = None
        self._current_user_id: int | None = None
        self._mcp_auth_service = MCPAuthService(redirect_uri=oauth_redirect_uri)
        self._pending_auth_servers: list[MCPServer] = []

        logger.debug("Using redirect URI for MCP OAuth: %s", oauth_redirect_uri)

    async def process(
        self,
        messages: list[Message],
        model_id: str,
        files: list[str] | None = None,  # noqa: ARG002
        mcp_servers: list[MCPServer] | None = None,
        payload: dict[str, Any] | None = None,
        user_id: int | None = None,
    ) -> AsyncGenerator[Chunk, None]:
        """Process messages using the Google GenAI API with MCP tool support.

        Args:
            messages: Conversation history to send to the model.
            model_id: Key into ``self.models``.
            files: Unused by this processor.
            mcp_servers: MCP servers whose tools are offered to the model.
            payload: Extra ``GenerateContentConfig`` fields. ``thinking_level``
                is mapped onto the thinking config. The dict is not modified.
            user_id: Used to look up OAuth tokens for MCP servers.

        Yields:
            TEXT, TOOL_CALL and TOOL_RESULT chunks while generating, an ERROR
            chunk if generation fails, and then one AUTH_REQUIRED chunk per MCP
            server that still needs OAuth.

        Raises:
            ValueError: If the client is not initialized or the model is not
                supported by this processor.
        """
        if not self.client:
            raise ValueError("Gemini Client not initialized.")

        if model_id not in self.models:
            msg = f"Model {model_id} not supported by Gemini processor"
            raise ValueError(msg)

        model = self.models[model_id]
        self._current_user_id = user_id
        self._current_reasoning_session = None
        # Per-request auth state is kept in a local list. The attribute is
        # only mirrored for code that inspects it.
        pending_auth: list[MCPServer] = []
        self._pending_auth_servers = pending_auth

        config = self._create_generation_config(model, payload)

        # Prepare MCP connection wrappers. Connections are opened later in
        # _stream_with_mcp. config.tools is deliberately NOT set to the
        # wrappers: they are not valid Gemini tools.
        mcp_sessions: list[MCPSessionWrapper] = []
        mcp_prompt = ""
        if mcp_servers:
            sessions_result = await self._create_mcp_sessions(mcp_servers, user_id)
            mcp_sessions = sessions_result["sessions"]
            pending_auth = sessions_result["auth_required"]
            self._pending_auth_servers = pending_auth
            mcp_prompt = self._build_mcp_prompt(mcp_servers)

        # Prepare messages with MCP prompts for tool selection
        contents, system_instruction = await self._convert_messages_to_gemini_format(
            messages, mcp_prompt
        )

        if system_instruction:
            config.system_instruction = system_instruction

        if mcp_sessions:
            logger.info(
                "Connected to %d MCP servers for native tool support",
                len(mcp_sessions),
            )

        try:
            async for chunk in self._stream_with_mcp(
                model.model, contents, config, mcp_sessions
            ):
                yield chunk
        except Exception as e:
            logger.exception("Error in Gemini processor: %s", str(e))
            yield self._create_chunk(ChunkType.ERROR, f"Error: {e!s}")

        # Report servers needing OAuth even when generation failed, so the
        # user can still authenticate.
        for server in pending_auth:
            yield await self._create_auth_required_chunk(server)

    async def _create_mcp_sessions(
        self, servers: list[MCPServer], user_id: int | None
    ) -> dict[str, Any]:
        """Prepare connection wrappers for each MCP server.

        No network connection is opened here; connections are established in
        ``_mcp_context_manager``. OAuth-discovery servers without a valid
        token for *user_id* are skipped and reported in 'auth_required'.

        Returns:
            Dict with 'sessions' (MCPSessionWrapper list) and
            'auth_required' (MCPServer list).
        """
        sessions: list[MCPSessionWrapper] = []
        auth_required: list[MCPServer] = []

        for server in servers:
            try:
                headers = self._parse_mcp_headers(server)

                # Handle OAuth - inject token
                if (
                    server.auth_type == MCPAuthType.OAUTH_DISCOVERY
                    and user_id is not None
                ):
                    token = await self._get_valid_token_for_server(server, user_id)
                    if not token:
                        auth_required.append(server)
                        logger.warning(
                            "Skipping MCP server %s - OAuth token required",
                            server.name,
                        )
                        continue
                    headers["Authorization"] = f"Bearer {token.access_token}"

                # Record the streamable-HTTP connection details. The URL is
                # used as configured (the server determines the endpoint).
                logger.debug(
                    "Connecting to MCP server %s at %s (headers: %s)",
                    server.name,
                    server.url,
                    {
                        k: "***" if k.lower() == "authorization" else v
                        for k, v in headers.items()
                    },
                )
                sessions.append(MCPSessionWrapper(server.url, headers, server.name))

            except Exception as e:
                logger.error(
                    "Failed to connect to MCP server %s: %s", server.name, str(e)
                )

        return {"sessions": sessions, "auth_required": auth_required}

    async def _stream_with_mcp(
        self,
        model_name: str,
        contents: list[types.Content],
        config: types.GenerateContentConfig,
        mcp_sessions: list[Any],
    ) -> AsyncGenerator[Chunk, None]:
        """Stream a response, offering MCP tools to the model when available.

        Without sessions, or when no session yields a usable tool, the response
        is streamed directly with ``config`` unchanged. Otherwise the MCP tools
        replace ``config.tools`` and the tool-calling loop is run.
        """
        if not mcp_sessions:
            async for chunk in self._stream_generation(model_name, contents, config):
                yield chunk
            return

        # Enter all session contexts and fetch tools
        async with self._mcp_context_manager(mcp_sessions) as tool_contexts:
            function_declarations = self._collect_function_declarations(
                tool_contexts
            )

            if not function_declarations:
                # No usable MCP tools (connections failed or nothing exposed).
                async for chunk in self._stream_generation(
                    model_name, contents, config
                ):
                    yield chunk
                return

            config.tools = [types.Tool(function_declarations=function_declarations)]
            logger.info(
                "Configured %d tools for Gemini from MCP",
                len(function_declarations),
            )

            async for chunk in self._stream_with_tool_loop(
                model_name, contents, config, tool_contexts
            ):
                yield chunk

    def _collect_function_declarations(
        self, tool_contexts: list[MCPToolContext]
    ) -> list[types.FunctionDeclaration]:
        """Convert every MCP tool into a Gemini FunctionDeclaration.

        Function names must be unique within a request. When several servers
        expose the same name, only the first server's tool is considered.
        ``_find_tool_context`` routes calls to that same server.
        """
        declarations: list[types.FunctionDeclaration] = []
        seen: set[str] = set()
        for ctx in tool_contexts:
            for tool_name, tool_def in ctx.tools.items():
                if tool_name in seen:
                    logger.warning(
                        "Skipping duplicate MCP tool %s from server %s",
                        tool_name,
                        ctx.server_name,
                    )
                    continue
                seen.add(tool_name)
                func_decl = self._mcp_tool_to_gemini_function(tool_name, tool_def)
                if func_decl:
                    declarations.append(func_decl)
        return declarations

    @staticmethod
    def _find_tool_context(
        tool_name: str | None, tool_contexts: list[MCPToolContext]
    ) -> MCPToolContext | None:
        """Return the first context that provides *tool_name*, or None."""
        return next((ctx for ctx in tool_contexts if tool_name in ctx.tools), None)

    @staticmethod
    def _join_text_parts(parts: Any) -> str:
        """Concatenate the non-empty ``text`` of *parts*. None yields ""."""
        return "".join(part.text for part in parts or () if part.text)

    async def _stream_with_tool_loop(
        self,
        model_name: str,
        contents: list[types.Content],
        config: types.GenerateContentConfig,
        tool_contexts: list[MCPToolContext],
    ) -> AsyncGenerator[Chunk, None]:
        """Run generate_content rounds, executing MCP tools between them.

        Each round yields any model text. If the model requested function
        calls, TOOL_CALL and TOOL_RESULT chunks are yielded for each call, the
        results are appended to the conversation, and another round is
        started. The loop stops when a round contains no function calls, or
        after ``_MAX_TOOL_ROUNDS`` rounds.
        """
        current_contents = list(contents)

        for _round_num in range(self._MAX_TOOL_ROUNDS):
            response = await self.client.aio.models.generate_content(
                model=model_name, contents=current_contents, config=config
            )

            if not response.candidates:
                return

            content = response.candidates[0].content
            # content / parts can be None (e.g. blocked or truncated output).
            parts = (content.parts if content else None) or []

            # Text may accompany function calls; don't drop it.
            text = self._join_text_parts(parts)
            if text:
                yield self._create_chunk(ChunkType.TEXT, text, {"delta": text})

            function_calls = [
                part.function_call for part in parts if part.function_call is not None
            ]
            if not function_calls:
                # Done - no more function calls
                return

            # Echo the model turn back so it sees its own calls next round.
            current_contents.append(content)

            function_responses: list[types.Part] = []
            for fc in function_calls:
                ctx = self._find_tool_context(fc.name, tool_contexts)
                server_name = ctx.server_name if ctx else "unknown"
                tool_call_id = f"mcp_{uuid.uuid4().hex}"

                yield self._create_chunk(
                    ChunkType.TOOL_CALL,
                    f"Werkzeug: {server_name}.{fc.name}",
                    {
                        "tool_name": fc.name,
                        "tool_id": tool_call_id,
                        "server_label": server_name,
                        "arguments": json.dumps(fc.args),
                        "status": "starting",
                    },
                )

                result = await self._execute_mcp_tool(
                    fc.name, fc.args or {}, tool_contexts
                )

                yield self._create_chunk(
                    ChunkType.TOOL_RESULT,
                    result[:TOOL_RESULT_PREVIEW_LENGTH],
                    {
                        "tool_name": fc.name,
                        "tool_id": tool_call_id,
                        "server_label": server_name,
                        "status": "completed",
                        "result_length": str(len(result)),
                    },
                )

                function_responses.append(
                    types.Part(
                        function_response=types.FunctionResponse(
                            name=fc.name,
                            response={"result": result},
                        )
                    )
                )
                logger.debug(
                    "Tool %s executed, result length: %d",
                    fc.name,
                    len(result),
                )

            current_contents.append(
                types.Content(role="user", parts=function_responses)
            )

        logger.warning("Max tool rounds (%d) exceeded", self._MAX_TOOL_ROUNDS)

    async def _execute_mcp_tool(
        self,
        tool_name: str,
        args: dict[str, Any],
        tool_contexts: list[MCPToolContext],
    ) -> str:
        """Execute an MCP tool and return its result as text.

        Failures are reported as text rather than raised, so the model can
        react to them:
        - an unknown tool returns a "not found" message;
        - an exception during the call, or a result flagged ``isError``,
          returns an "Error executing tool: ..." message.
        """
        ctx = self._find_tool_context(tool_name, tool_contexts)
        if ctx is None:
            return f"Tool {tool_name} not found in any MCP server"

        try:
            logger.debug(
                "Executing tool %s on server %s with args: %s",
                tool_name,
                ctx.server_name,
                args,
            )
            result = await ctx.session.call_tool(tool_name, args)
        except Exception as e:
            logger.exception("Error executing tool %s: %s", tool_name, str(e))
            return f"Error executing tool: {e!s}"

        # Extract text items from the result content, if any.
        result_content = getattr(result, "content", None)
        texts = (
            [item.text for item in result_content if hasattr(item, "text")]
            if result_content
            else []
        )
        text = "\n".join(texts) if texts else str(result)

        if getattr(result, "isError", False):
            return f"Error executing tool: {text}"
        return text

    def _mcp_tool_to_gemini_function(
        self, name: str, tool_def: dict[str, Any]
    ) -> types.FunctionDeclaration | None:
        """Convert an MCP tool definition to a Gemini FunctionDeclaration.

        Returns None, with a warning logged, if the conversion fails.
        """
        try:
            description = tool_def.get("description", "")
            fixed_schema = self._fix_schema_for_gemini(
                tool_def.get("inputSchema") or {}
            )

            # Gemini rejects OBJECT schemas with empty "properties". Tools
            # without arguments are declared without parameters instead.
            if fixed_schema.get("type") == "object" and not fixed_schema.get(
                "properties"
            ):
                fixed_schema = {}

            return types.FunctionDeclaration(
                name=name,
                description=description,
                parameters=fixed_schema or None,
            )
        except Exception as e:
            logger.warning("Failed to convert MCP tool %s: %s", name, str(e))
            return None

    def _fix_schema_for_gemini(self, schema: dict[str, Any]) -> dict[str, Any]:
        """Return a copy of a JSON schema adjusted for Gemini compatibility.

        The following changes are applied recursively (through ``items``,
        ``properties`` and ``anyOf``/``any_of``):
        - keywords in ``_FORBIDDEN_SCHEMA_FIELDS`` are removed;
        - a type list such as ``["string", "null"]`` collapses to its first
          non-null type plus ``nullable: True``;
        - arrays without ``items`` get string items.

        The input schema is not modified.
        """
        if not schema:
            return schema

        forbidden_fields = self._FORBIDDEN_SCHEMA_FIELDS

        def fix_property(prop: dict[str, Any]) -> dict[str, Any]:
            """Recursively fix a property schema in place."""
            if not isinstance(prop, dict):
                return prop

            for forbidden in forbidden_fields:
                prop.pop(forbidden, None)

            prop_type = prop.get("type")

            # Gemini's Schema.type is a single value, not a list.
            if isinstance(prop_type, list):
                non_null = [t for t in prop_type if t != "null"]
                if len(non_null) < len(prop_type):
                    prop["nullable"] = True
                prop_type = non_null[0] if non_null else "string"
                prop["type"] = prop_type

            # Fix array without items
            if prop_type == "array" and "items" not in prop:
                prop["items"] = {"type": "string"}
                logger.debug("Added missing 'items' to array property")

            if isinstance(prop.get("items"), dict):
                prop["items"] = fix_property(prop["items"])

            if isinstance(prop.get("properties"), dict):
                for key, val in prop["properties"].items():
                    prop["properties"][key] = fix_property(val)

            # Gemini accepts anyOf, but not forbidden fields inside it.
            for any_of_key in ("anyOf", "any_of"):
                if isinstance(prop.get(any_of_key), list):
                    prop[any_of_key] = [
                        fix_property(item) if isinstance(item, dict) else item
                        for item in prop[any_of_key]
                    ]

            return prop

        # Deep copy to avoid modifying the original
        return fix_property(copy.deepcopy(schema))

    @asynccontextmanager
    async def _mcp_context_manager(
        self, session_wrappers: list[Any]
    ) -> AsyncGenerator[list[MCPToolContext], None]:
        """Open every MCP session, fetch its tools, and yield the contexts.

        Servers that fail to connect or initialize are logged and left out.
        All opened transports and sessions are closed when the block exits.
        """
        async with AsyncExitStack() as stack:
            tool_contexts: list[MCPToolContext] = []

            for wrapper in session_wrappers:
                try:
                    logger.debug(
                        "Connecting to MCP server %s via streamablehttp_client",
                        wrapper.name,
                    )
                    read, write, _ = await stack.enter_async_context(
                        streamablehttp_client(
                            url=wrapper.url,
                            headers=wrapper.headers,
                            timeout=60.0,
                        )
                    )

                    session = await stack.enter_async_context(
                        ClientSession(read, write)
                    )
                    await session.initialize()

                    tools_result = await session.list_tools()
                    tools_dict = {
                        tool.name: {
                            "description": tool.description or "",
                            "inputSchema": getattr(tool, "inputSchema", None) or {},
                        }
                        for tool in tools_result.tools
                    }

                    tool_contexts.append(
                        MCPToolContext(
                            session=session,
                            server_name=wrapper.name,
                            tools=tools_dict,
                        )
                    )

                    logger.info(
                        "MCP session initialized for %s with %d tools",
                        wrapper.name,
                        len(tools_dict),
                    )
                except Exception as e:
                    logger.exception(
                        "Failed to initialize MCP session for %s: %s",
                        wrapper.name,
                        str(e),
                    )

            try:
                yield tool_contexts
            except Exception as e:
                logger.exception("Error during MCP session usage: %s", str(e))
                raise

    async def _stream_generation(
        self,
        model_name: str,
        contents: list[types.Content],
        config: types.GenerateContentConfig,
    ) -> AsyncGenerator[Chunk, None]:
        """Stream generation from the Gemini model as TEXT chunks."""
        # generate_content_stream returns an awaitable that yields an async generator
        stream = await self.client.aio.models.generate_content_stream(
            model=model_name, contents=contents, config=config
        )
        async for chunk in stream:
            processed = self._handle_chunk(chunk)
            if processed:
                yield processed

    def _create_generation_config(
        self, model: AIModel, payload: dict[str, Any] | None
    ) -> types.GenerateContentConfig:
        """Create the generation config from the model and the payload.

        Defaults are taken from the model. Keys in *payload* override them; the
        caller's dict is not modified. ``thinking_level`` in the payload is
        mapped onto ``thinking_config``. The response modality is always text.
        """
        options = dict(payload or {})

        # "medium" is only supported by Flash; Pro uses "high" (default dynamic)
        default_level = "medium" if "flash" in model.model.lower() else "high"
        thinking_level = options.pop("thinking_level", default_level)

        # Merged as a dict so payload keys override defaults instead of
        # raising "multiple values for keyword argument".
        config_kwargs: dict[str, Any] = {
            "temperature": model.temperature,
            "thinking_config": types.ThinkingConfig(thinking_level=thinking_level),
            **options,
            "response_modalities": ["TEXT"],
        }
        return types.GenerateContentConfig(**config_kwargs)

    def _build_mcp_prompt(self, mcp_servers: list[MCPServer]) -> str:
        """Build the MCP tool-selection prompt, one bullet per server prompt."""
        return "\n".join(
            f"- {server.prompt}" for server in mcp_servers if server.prompt
        )

    async def _convert_messages_to_gemini_format(
        self, messages: list[Message], mcp_prompt: str = ""
    ) -> tuple[list[types.Content], str | None]:
        """Convert app messages to Gemini Content objects.

        SYSTEM messages are appended to the system instruction. That
        instruction is built from the cached system prompt template, with the
        MCP section substituted for ``{mcp_prompts}``. HUMAN and ASSISTANT
        messages become "user" and "model" turns. Messages with empty text
        are skipped, because the API rejects empty text parts.

        Returns:
            The list of contents and the system instruction (or None).
        """
        contents: list[types.Content] = []
        system_instruction: str | None = None

        mcp_section = ""
        if mcp_prompt:
            mcp_section = (
                "\n\n### Tool-Auswahlrichtlinien (Einbettung externer Beschreibungen)\n"
                f"{mcp_prompt}"
            )

        system_prompt_template = await get_system_prompt()
        if system_prompt_template:
            system_instruction = system_prompt_template.format(mcp_prompts=mcp_section)

        for msg in messages:
            if not msg.text:
                continue
            if msg.type == MessageType.SYSTEM:
                if system_instruction:
                    system_instruction += f"\n{msg.text}"
                else:
                    system_instruction = msg.text
            elif msg.type in (MessageType.HUMAN, MessageType.ASSISTANT):
                role = "user" if msg.type == MessageType.HUMAN else "model"
                contents.append(
                    types.Content(role=role, parts=[types.Part(text=msg.text)])
                )

        return contents, system_instruction

    def _handle_chunk(self, chunk: Any) -> Chunk | None:
        """Turn a streamed Gemini chunk into a TEXT Chunk.

        Only the first candidate is used. Returns None if there is no text.
        """
        if not chunk.candidates or not chunk.candidates[0].content:
            return None

        text = self._join_text_parts(chunk.candidates[0].content.parts)
        if not text:
            return None

        return self._create_chunk(ChunkType.TEXT, text, {"delta": text})

    def _create_chunk(
        self,
        chunk_type: ChunkType,
        content: str,
        extra_metadata: dict[str, str] | None = None,
    ) -> Chunk:
        """Create a Chunk tagged with this processor's name."""
        metadata = {"processor": "gemini_responses", **(extra_metadata or {})}
        return Chunk(
            type=chunk_type,
            text=content,
            chunk_metadata=metadata,
        )

    async def _create_auth_required_chunk(self, server: MCPServer) -> Chunk:
        """Create an AUTH_REQUIRED chunk for *server*."""
        return self._create_chunk(
            ChunkType.AUTH_REQUIRED,
            f"{server.name} authentication required",
            {"server_name": server.name},
        )

    def _parse_mcp_headers(self, server: MCPServer) -> dict[str, str]:
        """Parse the HTTP headers configured for *server*.

        ``server.headers`` may be a JSON object string or a dict. Invalid or
        non-object values are logged and treated as "no headers". Values are
        coerced to str and None values are dropped.

        Returns:
            A new dictionary of HTTP headers to send to the MCP server.
        """
        raw = server.headers
        if not raw or raw == "{}":
            return {}

        try:
            parsed = raw if isinstance(raw, dict) else json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            logger.warning("Invalid headers JSON for server %s", server.name)
            return {}

        if not isinstance(parsed, dict):
            logger.warning("Invalid headers JSON for server %s", server.name)
            return {}

        return {str(k): str(v) for k, v in parsed.items() if v is not None}

    async def _get_valid_token_for_server(
        self, server: MCPServer, user_id: int
    ) -> AssistantMCPUserToken | None:
        """Get a valid OAuth token for the server/user, or None."""
        if server.id is None:
            return None

        with rx.session() as session:
            token = self._mcp_auth_service.get_user_token(session, user_id, server.id)

            if token is None:
                return None

            return await self._mcp_auth_service.ensure_valid_token(
                session, server, token
            )
|
|
718
|
+
|
|
719
|
+
|
|
720
|
+
class MCPSessionWrapper:
    """Deferred description of an MCP connection.

    Stores what is needed to open a streamable-HTTP MCP session later,
    without connecting at construction time.
    """

    def __init__(self, url: str, headers: dict[str, str], name: str) -> None:
        # Server name used in logs and tool labels.
        self.name = name
        # Endpoint URL, used as configured.
        self.url = url
        # HTTP headers sent on connect (may include an Authorization header).
        self.headers = headers
|