appkit-assistant 0.17.3__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (57)
  1. appkit_assistant/backend/{models.py → database/models.py} +32 -132
  2. appkit_assistant/backend/{repositories.py → database/repositories.py} +93 -1
  3. appkit_assistant/backend/model_manager.py +5 -5
  4. appkit_assistant/backend/models/__init__.py +28 -0
  5. appkit_assistant/backend/models/anthropic.py +31 -0
  6. appkit_assistant/backend/models/google.py +27 -0
  7. appkit_assistant/backend/models/openai.py +50 -0
  8. appkit_assistant/backend/models/perplexity.py +56 -0
  9. appkit_assistant/backend/processors/__init__.py +29 -0
  10. appkit_assistant/backend/processors/claude_responses_processor.py +205 -387
  11. appkit_assistant/backend/processors/gemini_responses_processor.py +290 -352
  12. appkit_assistant/backend/processors/lorem_ipsum_processor.py +6 -4
  13. appkit_assistant/backend/processors/mcp_mixin.py +297 -0
  14. appkit_assistant/backend/processors/openai_base.py +11 -125
  15. appkit_assistant/backend/processors/openai_chat_completion_processor.py +5 -3
  16. appkit_assistant/backend/processors/openai_responses_processor.py +480 -402
  17. appkit_assistant/backend/processors/perplexity_processor.py +156 -79
  18. appkit_assistant/backend/{processor.py → processors/processor_base.py} +7 -2
  19. appkit_assistant/backend/processors/streaming_base.py +188 -0
  20. appkit_assistant/backend/schemas.py +138 -0
  21. appkit_assistant/backend/services/auth_error_detector.py +99 -0
  22. appkit_assistant/backend/services/chunk_factory.py +273 -0
  23. appkit_assistant/backend/services/citation_handler.py +292 -0
  24. appkit_assistant/backend/services/file_cleanup_service.py +316 -0
  25. appkit_assistant/backend/services/file_upload_service.py +903 -0
  26. appkit_assistant/backend/services/file_validation.py +138 -0
  27. appkit_assistant/backend/{mcp_auth_service.py → services/mcp_auth_service.py} +4 -2
  28. appkit_assistant/backend/services/mcp_token_service.py +61 -0
  29. appkit_assistant/backend/services/message_converter.py +289 -0
  30. appkit_assistant/backend/services/openai_client_service.py +120 -0
  31. appkit_assistant/backend/{response_accumulator.py → services/response_accumulator.py} +163 -1
  32. appkit_assistant/backend/services/system_prompt_builder.py +89 -0
  33. appkit_assistant/backend/services/thread_service.py +5 -3
  34. appkit_assistant/backend/system_prompt_cache.py +3 -3
  35. appkit_assistant/components/__init__.py +8 -4
  36. appkit_assistant/components/composer.py +59 -24
  37. appkit_assistant/components/file_manager.py +623 -0
  38. appkit_assistant/components/mcp_server_dialogs.py +12 -20
  39. appkit_assistant/components/mcp_server_table.py +12 -2
  40. appkit_assistant/components/message.py +119 -2
  41. appkit_assistant/components/thread.py +1 -1
  42. appkit_assistant/components/threadlist.py +4 -2
  43. appkit_assistant/components/tools_modal.py +37 -20
  44. appkit_assistant/configuration.py +12 -0
  45. appkit_assistant/state/file_manager_state.py +697 -0
  46. appkit_assistant/state/mcp_oauth_state.py +3 -3
  47. appkit_assistant/state/mcp_server_state.py +47 -2
  48. appkit_assistant/state/system_prompt_state.py +1 -1
  49. appkit_assistant/state/thread_list_state.py +99 -5
  50. appkit_assistant/state/thread_state.py +88 -9
  51. {appkit_assistant-0.17.3.dist-info → appkit_assistant-1.0.1.dist-info}/METADATA +8 -6
  52. appkit_assistant-1.0.1.dist-info/RECORD +58 -0
  53. appkit_assistant/backend/processors/claude_base.py +0 -178
  54. appkit_assistant/backend/processors/gemini_base.py +0 -84
  55. appkit_assistant-0.17.3.dist-info/RECORD +0 -39
  56. /appkit_assistant/backend/{file_manager.py → services/file_manager.py} +0 -0
  57. {appkit_assistant-0.17.3.dist-info → appkit_assistant-1.0.1.dist-info}/WHEEL +0 -0
@@ -4,33 +4,33 @@ Gemini responses processor for generating AI responses using Google's GenAI API.
 
 import asyncio
 import copy
-import json
 import logging
 import uuid
 from collections.abc import AsyncGenerator
 from contextlib import AsyncExitStack, asynccontextmanager
 from dataclasses import dataclass, field
-from typing import Any, Final
+from typing import Any, Final, NamedTuple
 
-import reflex as rx
+import httpx
+from google import genai
 from google.genai import types
 from mcp import ClientSession
-from mcp.client.streamable_http import streamablehttp_client
+from mcp.client.streamable_http import streamable_http_client
 
-from appkit_assistant.backend.mcp_auth_service import MCPAuthService
-from appkit_assistant.backend.models import (
+from appkit_assistant.backend.database.models import (
+    MCPServer,
+)
+from appkit_assistant.backend.processors.mcp_mixin import MCPCapabilities
+from appkit_assistant.backend.processors.processor_base import mcp_oauth_redirect_uri
+from appkit_assistant.backend.processors.streaming_base import StreamingProcessorBase
+from appkit_assistant.backend.schemas import (
     AIModel,
-    AssistantMCPUserToken,
     Chunk,
-    ChunkType,
     MCPAuthType,
-    MCPServer,
     Message,
     MessageType,
 )
-from appkit_assistant.backend.processor import mcp_oauth_redirect_uri
-from appkit_assistant.backend.processors.gemini_base import BaseGeminiProcessor
-from appkit_assistant.backend.system_prompt_cache import get_system_prompt
+from appkit_assistant.backend.services.system_prompt_builder import SystemPromptBuilder
 
 logger = logging.getLogger(__name__)
 default_oauth_redirect_uri: Final[str] = mcp_oauth_redirect_uri()
@@ -38,6 +38,45 @@ default_oauth_redirect_uri: Final[str] = mcp_oauth_redirect_uri()
 # Maximum characters to show in tool result preview
 TOOL_RESULT_PREVIEW_LENGTH: Final[int] = 500
 
+GEMINI_FORBIDDEN_SCHEMA_FIELDS: Final[set[str]] = {
+    "$schema",
+    "$id",
+    "$ref",
+    "$defs",
+    "definitions",
+    "$comment",
+    "examples",
+    "default",
+    "const",
+    "contentMediaType",
+    "contentEncoding",
+    "additionalProperties",
+    "additional_properties",
+    "patternProperties",
+    "unevaluatedProperties",
+    "unevaluatedItems",
+    "minItems",
+    "maxItems",
+    "minLength",
+    "maxLength",
+    "minimum",
+    "maximum",
+    "exclusiveMinimum",
+    "exclusiveMaximum",
+    "multipleOf",
+    "pattern",
+    "format",
+    "title",
+    "allOf",
+    "oneOf",
+    "not",
+    "if",
+    "then",
+    "else",
+    "dependentSchemas",
+    "dependentRequired",
+}
+
 
 @dataclass
 class MCPToolContext:
@@ -48,7 +87,15 @@ class MCPToolContext:
     tools: dict[str, Any] = field(default_factory=dict)
 
 
-class GeminiResponsesProcessor(BaseGeminiProcessor):
+class MCPSessionWrapper(NamedTuple):
+    """Wrapper to store MCP connection details before creating actual session."""
+
+    url: str
+    headers: dict[str, str]
+    name: str
+
+
+class GeminiResponsesProcessor(StreamingProcessorBase, MCPCapabilities):
     """Gemini processor using the GenAI API with native MCP support."""
 
     def __init__(
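Worth noting before the next hunk: the new wrapper is a NamedTuple rather than the plain class it replaces at the bottom of the old file, so instances are immutable and unpack like tuples. A quick standalone sketch of that behavior (class and values here are illustrative, not the package's own):

from typing import NamedTuple


class SessionWrapper(NamedTuple):
    """Connection details, resolved into a live session later."""

    url: str
    headers: dict[str, str]
    name: str


wrapper = SessionWrapper(
    url="https://mcp.example.com/stream",
    headers={"Authorization": "Bearer demo-token"},
    name="example",
)
url, headers, name = wrapper  # NamedTuple unpacks like a plain tuple
assert wrapper.name == "example"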
@@ -57,14 +104,27 @@ class GeminiResponsesProcessor(BaseGeminiProcessor):
         api_key: str | None = None,
         oauth_redirect_uri: str = default_oauth_redirect_uri,
     ) -> None:
-        super().__init__(models, api_key)
-        self._current_reasoning_session: str | None = None
-        self._current_user_id: int | None = None
-        self._mcp_auth_service = MCPAuthService(redirect_uri=oauth_redirect_uri)
-        self._pending_auth_servers: list[MCPServer] = []
+        StreamingProcessorBase.__init__(self, models, "gemini_responses")
+        MCPCapabilities.__init__(self, oauth_redirect_uri, "gemini_responses")
+        self.client: genai.Client | None = None
+        self._system_prompt_builder = SystemPromptBuilder()
+
+        if api_key:
+            try:
+                self.client = genai.Client(
+                    api_key=api_key, http_options={"api_version": "v1beta"}
+                )
+            except Exception as e:
+                logger.error("Failed to initialize Gemini client: %s", e)
+        else:
+            logger.warning("Gemini API key not found. Processor disabled.")
 
         logger.debug("Using redirect URI for MCP OAuth: %s", oauth_redirect_uri)
 
+    def _get_event_handlers(self) -> dict[str, Any]:
+        """Get event handlers (empty for Gemini - uses chunk-based handling)."""
+        return {}
+
     async def process(
         self,
         messages: list[Message],
@@ -84,9 +144,8 @@ class GeminiResponsesProcessor(BaseGeminiProcessor):
             raise ValueError(msg)
 
         model = self.models[model_id]
-        self._current_user_id = user_id
-        self._pending_auth_servers = []
-        self._current_reasoning_session = None
+        self.current_user_id = user_id
+        self.clear_pending_auth_servers()
 
         # Prepare configuration
         config = self._create_generation_config(model, payload)
@@ -95,14 +154,13 @@ class GeminiResponsesProcessor(BaseGeminiProcessor):
         mcp_sessions = []
         mcp_prompt = ""
         if mcp_servers:
-            sessions_result = await self._create_mcp_sessions(mcp_servers, user_id)
-            mcp_sessions = sessions_result["sessions"]
-            self._pending_auth_servers = sessions_result["auth_required"]
+            mcp_sessions, auth_required = await self._create_mcp_sessions(
+                mcp_servers, user_id
+            )
+            for server in auth_required:
+                self.add_pending_auth_server(server)
             mcp_prompt = self._build_mcp_prompt(mcp_servers)
-
-        if mcp_sessions:
-            # Pass sessions directly to tools - SDK handles everything!
-            config.tools = mcp_sessions
+            # Note: tools are configured in _stream_with_mcp after connecting to MCP
 
         # Prepare messages with MCP prompts for tool selection
         contents, system_instruction = await self._convert_messages_to_gemini_format(
@@ -127,35 +185,35 @@ class GeminiResponsesProcessor(BaseGeminiProcessor):
                 yield chunk
 
             # Handle any pending auth
-            for server in self._pending_auth_servers:
-                yield await self._create_auth_required_chunk(server)
+            async for auth_chunk in self.yield_pending_auth_chunks():
+                yield auth_chunk
 
         except Exception as e:
             logger.exception("Error in Gemini processor: %s", str(e))
-            yield self._create_chunk(ChunkType.ERROR, f"Error: {e!s}")
+            yield self.chunk_factory.error(f"Error: {e!s}")
 
     async def _create_mcp_sessions(
         self, servers: list[MCPServer], user_id: int | None
-    ) -> dict[str, Any]:
+    ) -> tuple[list[MCPSessionWrapper], list[MCPServer]]:
         """Create MCP ClientSession objects for each server.
 
         Returns:
-            Dict with 'sessions' and 'auth_required' lists
+            Tuple with (sessions, auth_required_servers)
         """
         sessions = []
         auth_required = []
 
         for server in servers:
             try:
-                # Parse headers
-                headers = self._parse_mcp_headers(server)
+                # Parse headers using MCPCapabilities
+                headers = self.parse_mcp_headers(server)
 
                 # Handle OAuth - inject token
                 if (
                     server.auth_type == MCPAuthType.OAUTH_DISCOVERY
                     and user_id is not None
                 ):
-                    token = await self._get_valid_token_for_server(server, user_id)
+                    token = await self.get_valid_token(server, user_id)
                     if token:
                         headers["Authorization"] = f"Bearer {token.access_token}"
                     else:
@@ -187,7 +245,7 @@ class GeminiResponsesProcessor(BaseGeminiProcessor):
                     "Failed to connect to MCP server %s: %s", server.name, str(e)
                 )
 
-        return {"sessions": sessions, "auth_required": auth_required}
+        return sessions, auth_required
 
     async def _stream_with_mcp(
         self,
@@ -210,11 +268,14 @@ class GeminiResponsesProcessor(BaseGeminiProcessor):
         async with self._mcp_context_manager(mcp_sessions) as tool_contexts:
             if tool_contexts:
                 # Convert MCP tools to Gemini FunctionDeclarations
+                # Use unique naming: server_name__tool_name to avoid duplicates
                 function_declarations = []
                 for ctx in tool_contexts:
                     for tool_name, tool_def in ctx.tools.items():
+                        # Create unique name: server_name__tool_name
+                        unique_name = f"{ctx.server_name}__{tool_name}"
                         func_decl = self._mcp_tool_to_gemini_function(
-                            tool_name, tool_def
+                            unique_name, tool_def, original_name=tool_name
                         )
                         if func_decl:
                             function_declarations.append(func_decl)
@@ -251,150 +312,164 @@ class GeminiResponsesProcessor(BaseGeminiProcessor):
                 logger.info("Processing cancelled by user")
                 break
 
-            response = await self.client.aio.models.generate_content(
+            # Use streaming API
+            stream = await self.client.aio.models.generate_content_stream(
                 model=model_name, contents=current_contents, config=config
             )
 
-            if not response.candidates:
+            # Collect function calls and stream text as it arrives
+            collected_function_calls: list[Any] = []
+            collected_parts: list[types.Part] = []
+            streamed_text = ""
+
+            async for chunk in stream:
+                if cancellation_token and cancellation_token.is_set():
+                    logger.info("Processing cancelled by user")
+                    return
+
+                if not chunk.candidates or not chunk.candidates[0].content:
+                    continue
+
+                for part in chunk.candidates[0].content.parts:
+                    # Stream text immediately
+                    if part.text:
+                        streamed_text += part.text
+                        yield self.chunk_factory.text(part.text, delta=part.text)
+                    # Collect function calls
+                    if part.function_call is not None:
+                        collected_function_calls.append(part.function_call)
+                        collected_parts.append(part)
+
+            # If no function calls, we're done (text was already streamed)
+            if not collected_function_calls:
                 return
 
-            candidate = response.candidates[0]
-            content = candidate.content
-
-            # Check for function calls
-            function_calls = [
-                part.function_call
-                for part in content.parts
-                if part.function_call is not None
-            ]
-
-            if function_calls:
-                # Add model response with function calls to conversation
-                current_contents.append(content)
-
-                # Execute tool calls and collect results
-                function_responses = []
-                for fc in function_calls:
-                    # Find server name for this tool
-                    server_name = "unknown"
-                    for ctx in tool_contexts:
-                        if fc.name in ctx.tools:
-                            server_name = ctx.server_name
-                            break
-
-                    # Generate a unique tool call ID
-                    tool_call_id = f"mcp_{uuid.uuid4().hex[:32]}"
-
-                    # Yield TOOL_CALL chunk to show in UI
-                    yield self._create_chunk(
-                        ChunkType.TOOL_CALL,
-                        f"Werkzeug: {server_name}.{fc.name}",
-                        {
-                            "tool_name": fc.name,
-                            "tool_id": tool_call_id,
-                            "server_label": server_name,
-                            "arguments": json.dumps(fc.args),
-                            "status": "starting",
-                        },
-                    )
+            # Build the model's response content with function calls
+            model_response_parts = []
+            if streamed_text:
+                model_response_parts.append(types.Part(text=streamed_text))
+            model_response_parts.extend(collected_parts)
+            current_contents.append(
+                types.Content(role="model", parts=model_response_parts)
+            )
 
-                    result = await self._execute_mcp_tool(
-                        fc.name, fc.args, tool_contexts
-                    )
+            # Execute tool calls and collect results
+            function_responses = []
+            for fc in collected_function_calls:
+                # Parse unique tool name: server_name__tool_name
+                server_name, original_tool_name = self._parse_unique_tool_name(fc.name)
+
+                # Generate a unique tool call ID
+                tool_call_id = f"mcp_{uuid.uuid4().hex[:32]}"
+
+                # Yield TOOL_CALL chunk to show in UI (use original name)
+                yield self.chunk_factory.tool_call(
+                    f"Benutze Werkzeug: {server_name}.{original_tool_name}",
+                    tool_name=original_tool_name,
+                    tool_id=tool_call_id,
+                    server_label=server_name,
+                    status="starting",
+                )
 
-                    # Yield TOOL_RESULT chunk with preview
-                    preview = (
-                        result[:TOOL_RESULT_PREVIEW_LENGTH]
-                        if len(result) > TOOL_RESULT_PREVIEW_LENGTH
-                        else result
-                    )
-                    yield self._create_chunk(
-                        ChunkType.TOOL_RESULT,
-                        preview,
-                        {
-                            "tool_name": fc.name,
-                            "tool_id": tool_call_id,
-                            "server_label": server_name,
-                            "status": "completed",
-                            "result_length": str(len(result)),
-                        },
-                    )
+                result = await self._execute_mcp_tool(fc.name, fc.args, tool_contexts)
 
-                    function_responses.append(
-                        types.Part(
-                            function_response=types.FunctionResponse(
-                                name=fc.name,
-                                response={"result": result},
-                            )
+                # Yield TOOL_RESULT chunk with preview
+                preview = (
+                    result[:TOOL_RESULT_PREVIEW_LENGTH]
+                    if len(result) > TOOL_RESULT_PREVIEW_LENGTH
+                    else result
+                )
+                yield self.chunk_factory.tool_result(
+                    preview,
+                    tool_id=tool_call_id,
+                    status="completed",
+                )
+
+                function_responses.append(
+                    types.Part(
+                        function_response=types.FunctionResponse(
+                            name=fc.name,
+                            response={"result": result},
                         )
                     )
-                    logger.debug(
-                        "Tool %s executed, result length: %d",
-                        fc.name,
-                        len(str(result)),
-                    )
-
-                # Add function responses
-                current_contents.append(
-                    types.Content(role="user", parts=function_responses)
                 )
-
-                # Continue to next round
-                continue
-
-            # No function calls - yield text response
-            text_parts = [part.text for part in content.parts if part.text]
-            if text_parts:
-                yield self._create_chunk(
-                    ChunkType.TEXT,
-                    "".join(text_parts),
-                    {"delta": "".join(text_parts)},
+                logger.debug(
+                    "Tool %s executed, result length: %d",
+                    fc.name,
+                    len(str(result)),
                 )
 
-            # Done - no more function calls
-            return
+            # Add function responses
+            current_contents.append(
+                types.Content(role="user", parts=function_responses)
+            )
+
+            # Continue to next round
 
         logger.warning("Max tool rounds (%d) exceeded", max_tool_rounds)
 
+    def _parse_unique_tool_name(self, unique_name: str) -> tuple[str, str]:
+        """Parse unique tool name back to server name and original tool name.
+
+        Args:
+            unique_name: Tool name in format 'server_name__tool_name'
+
+        Returns:
+            Tuple of (server_name, original_tool_name)
+        """
+        if "__" in unique_name:
+            parts = unique_name.split("__", 1)
+            return parts[0], parts[1]
+        # Fallback for tools without prefix
+        return "unknown", unique_name
+
     async def _execute_mcp_tool(
         self,
-        tool_name: str,
+        unique_tool_name: str,
         args: dict[str, Any],
         tool_contexts: list[MCPToolContext],
     ) -> str:
         """Execute an MCP tool and return the result."""
-        # Find which context has this tool
+        server_name, tool_name = self._parse_unique_tool_name(unique_tool_name)
+
         for ctx in tool_contexts:
-            if tool_name in ctx.tools:
+            if ctx.server_name == server_name and tool_name in ctx.tools:
                 try:
                     logger.debug(
                         "Executing tool %s on server %s with args: %s",
                         tool_name,
-                        ctx.server_name,
+                        server_name,
                         args,
                     )
                     result = await ctx.session.call_tool(tool_name, args)
-                    # Extract text from result
                     if hasattr(result, "content") and result.content:
-                        texts = [
-                            item.text
-                            for item in result.content
-                            if hasattr(item, "text")
-                        ]
+                        texts = [i.text for i in result.content if hasattr(i, "text")]
                         return "\n".join(texts) if texts else str(result)
                     return str(result)
                 except Exception as e:
-                    logger.exception("Error executing tool %s: %s", tool_name, str(e))
+                    logger.exception("Error executing tool %s: %s", tool_name, e)
                     return f"Error executing tool: {e!s}"
 
-        return f"Tool {tool_name} not found in any MCP server"
+        return f"Tool {tool_name} not found on server {server_name}"
 
     def _mcp_tool_to_gemini_function(
-        self, name: str, tool_def: dict[str, Any]
+        self,
+        name: str,
+        tool_def: dict[str, Any],
+        original_name: str | None = None,
     ) -> types.FunctionDeclaration | None:
-        """Convert MCP tool definition to Gemini FunctionDeclaration."""
+        """Convert MCP tool definition to Gemini FunctionDeclaration.
+
+        Args:
+            name: Unique function name (may include server prefix)
+            tool_def: MCP tool definition with description and inputSchema
+            original_name: Original tool name for description enhancement
+        """
        try:
             description = tool_def.get("description", "")
+            # Enhance description with original name if using prefixed naming
+            if original_name and original_name != name:
+                description = f"[{original_name}] {description}"
             input_schema = tool_def.get("inputSchema", {})
 
             # Fix the schema for Gemini compatibility
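The prefixed naming introduced in this hunk has to round-trip cleanly through _parse_unique_tool_name. A minimal standalone sketch of the same "__" convention (helper names here are hypothetical, not the package's):

def make_unique_name(server_name: str, tool_name: str) -> str:
    # Same convention as the diff: one flat Gemini namespace,
    # disambiguated by a server prefix.
    return f"{server_name}__{tool_name}"


def parse_unique_name(unique_name: str) -> tuple[str, str]:
    # split("__", 1) keeps any later "__" inside the tool name intact.
    if "__" in unique_name:
        server, tool = unique_name.split("__", 1)
        return server, tool
    return "unknown", unique_name


assert parse_unique_name(make_unique_name("github", "search_issues")) == (
    "github",
    "search_issues",
)
# Caveat: a server name that itself contains "__" would split at the first
# delimiter, so part of the server name would leak into the tool name.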
@@ -410,97 +485,37 @@ class GeminiResponsesProcessor(BaseGeminiProcessor):
             return None
 
     def _fix_schema_for_gemini(self, schema: dict[str, Any]) -> dict[str, Any]:
-        """Fix JSON schema for Gemini API compatibility.
-
-        Gemini requires 'items' field for array types and doesn't allow certain
-        JSON Schema fields like '$schema', '$id', 'definitions', etc.
-        This recursively fixes the schema.
-        """
-        if not schema:
+        """Fix JSON schema for Gemini API compatibility (recursive)."""
+        if not isinstance(schema, dict):
             return schema
 
-        # Deep copy to avoid modifying original
-        schema = copy.deepcopy(schema)
-
-        # Fields that Gemini doesn't allow in FunctionDeclaration parameters
-        # Note: additionalProperties gets converted to additional_properties by SDK
-        forbidden_fields = {
-            "$schema",
-            "$id",
-            "$ref",
-            "$defs",
-            "definitions",
-            "$comment",
-            "examples",
-            "default",
-            "const",
-            "contentMediaType",
-            "contentEncoding",
-            "additionalProperties",
-            "additional_properties",
-            "patternProperties",
-            "unevaluatedProperties",
-            "unevaluatedItems",
-            "minItems",
-            "maxItems",
-            "minLength",
-            "maxLength",
-            "minimum",
-            "maximum",
-            "exclusiveMinimum",
-            "exclusiveMaximum",
-            "multipleOf",
-            "pattern",
-            "format",
-            "title",
-            # Composition keywords - Gemini doesn't support these
-            "allOf",
-            "oneOf",
-            "not",
-            "if",
-            "then",
-            "else",
-            "dependentSchemas",
-            "dependentRequired",
-        }
-
-        def fix_property(prop: dict[str, Any]) -> dict[str, Any]:
-            """Recursively fix a property schema."""
-            if not isinstance(prop, dict):
-                return prop
-
-            # Remove forbidden fields
-            for forbidden in forbidden_fields:
-                prop.pop(forbidden, None)
-
-            prop_type = prop.get("type")
-
-            # Fix array without items
-            if prop_type == "array" and "items" not in prop:
-                prop["items"] = {"type": "string"}
-                logger.debug("Added missing 'items' to array property")
-
-            # Recurse into items
-            if "items" in prop and isinstance(prop["items"], dict):
-                prop["items"] = fix_property(prop["items"])
-
-            # Recurse into properties
-            if "properties" in prop and isinstance(prop["properties"], dict):
-                for key, val in prop["properties"].items():
-                    prop["properties"][key] = fix_property(val)
-
-            # Recurse into anyOf/any_of arrays (Gemini accepts these but not
-            # forbidden fields inside them)
-            for any_of_key in ("anyOf", "any_of"):
-                if any_of_key in prop and isinstance(prop[any_of_key], list):
-                    prop[any_of_key] = [
-                        fix_property(item) if isinstance(item, dict) else item
-                        for item in prop[any_of_key]
-                    ]
+        result = copy.deepcopy(schema)
+
+        # Remove forbidden fields
+        for key in GEMINI_FORBIDDEN_SCHEMA_FIELDS:
+            result.pop(key, None)
+
+        # Fix array without items
+        if result.get("type") == "array" and "items" not in result:
+            result["items"] = {"type": "string"}
+
+        # Recurse into nested schemas
+        if isinstance(result.get("items"), dict):
+            result["items"] = self._fix_schema_for_gemini(result["items"])
 
-            return prop
+        if isinstance(result.get("properties"), dict):
+            result["properties"] = {
+                k: self._fix_schema_for_gemini(v)
+                for k, v in result["properties"].items()
+            }
 
-        return fix_property(schema)
+        for key in ("anyOf", "any_of"):
+            if isinstance(result.get(key), list):
+                result[key] = [
+                    self._fix_schema_for_gemini(item) for item in result[key]
+                ]
+
+        return result
 
     @asynccontextmanager
     async def _mcp_context_manager(
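To see what the flattened _fix_schema_for_gemini does to a typical MCP inputSchema, here is a self-contained sketch of the same recursion, with the forbidden-field set abbreviated (the real code uses the full GEMINI_FORBIDDEN_SCHEMA_FIELDS defined earlier; the sample schema is made up):

import copy
from typing import Any

FORBIDDEN = {"$schema", "additionalProperties", "format", "title"}  # abbreviated


def fix_schema(schema: dict[str, Any]) -> dict[str, Any]:
    if not isinstance(schema, dict):
        return schema
    result = copy.deepcopy(schema)
    for key in FORBIDDEN:
        result.pop(key, None)
    if result.get("type") == "array" and "items" not in result:
        result["items"] = {"type": "string"}  # Gemini requires items on arrays
    if isinstance(result.get("properties"), dict):
        result["properties"] = {
            k: fix_schema(v) for k, v in result["properties"].items()
        }
    return result


before = {
    "$schema": "http://json-schema.org/draft-07/schema#",
    "type": "object",
    "properties": {"tags": {"type": "array", "title": "Tags"}},
    "additionalProperties": False,
}
after = fix_schema(before)
assert after == {
    "type": "object",
    "properties": {"tags": {"type": "array", "items": {"type": "string"}}},
}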
@@ -516,11 +531,18 @@ class GeminiResponsesProcessor(BaseGeminiProcessor):
                    "Connecting to MCP server %s via streamablehttp_client",
                    wrapper.name,
                )
+                # Create httpx client with headers and timeout
+                http_client = httpx.AsyncClient(
+                    headers=wrapper.headers,
+                    timeout=60.0,
+                )
+                # Register client for cleanup
+                await stack.enter_async_context(http_client)
+
                read, write, _ = await stack.enter_async_context(
-                    streamablehttp_client(
+                    streamable_http_client(
                        url=wrapper.url,
-                        headers=wrapper.headers,
-                        timeout=60.0,
+                        http_client=http_client,
                    )
                )
 
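The ordering in this hunk matters: the httpx.AsyncClient is entered on the same AsyncExitStack before the transport that borrows it, so the stack unwinds in reverse and closes the transport first, the client last. A minimal sketch of that pattern with a placeholder transport (nothing MCP-specific is assumed):

import asyncio
from contextlib import AsyncExitStack, asynccontextmanager

import httpx


@asynccontextmanager
async def fake_transport(http_client: httpx.AsyncClient):
    # Stand-in for a transport that borrows an externally managed client.
    yield http_client


async def main() -> None:
    async with AsyncExitStack() as stack:
        client = httpx.AsyncClient(headers={"X-Demo": "1"}, timeout=60.0)
        await stack.enter_async_context(client)  # exits last
        transport = await stack.enter_async_context(fake_transport(client))
        # ... use transport; the stack closes transport first, then client.
        assert not client.is_closed


asyncio.run(main())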
@@ -590,151 +612,67 @@ class GeminiResponsesProcessor(BaseGeminiProcessor):
     ) -> types.GenerateContentConfig:
         """Create generation config from model and payload."""
         # Default thinking level depends on model
-        # "medium" is only supported by Flash, Pro uses "high" (default dynamic)
-        thinking_level = "high"
-        if "flash" in model.model.lower():
-            thinking_level = "medium"
+        thinking_level = "medium" if "flash" in model.model.lower() else "high"
 
         # Override from payload if present
-        if payload and "thinking_level" in payload:
-            thinking_level = payload.pop("thinking_level")
+        if payload:
+            thinking_level = payload.get("thinking_level", thinking_level)
+
+        # Filter out fields not accepted by GenerateContentConfig
+        filtered_payload = {}
+        if payload:
+            ignored_fields = {"thread_uuid", "user_id", "thinking_level"}
+            filtered_payload = {
+                k: v for k, v in payload.items() if k not in ignored_fields
+            }
 
         return types.GenerateContentConfig(
             temperature=model.temperature,
             thinking_config=types.ThinkingConfig(thinking_level=thinking_level),
-            **(payload or {}),
+            **filtered_payload,
             response_modalities=["TEXT"],
         )
 
     def _build_mcp_prompt(self, mcp_servers: list[MCPServer]) -> str:
         """Build MCP tool selection prompt from server prompts."""
-        prompts = [f"- {server.prompt}" for server in mcp_servers if server.prompt]
-        return "\n".join(prompts) if prompts else ""
+        return "\n".join(f"- {s.prompt}" for s in mcp_servers if s.prompt)
 
     async def _convert_messages_to_gemini_format(
         self, messages: list[Message], mcp_prompt: str = ""
     ) -> tuple[list[types.Content], str | None]:
         """Convert app messages to Gemini Content objects."""
+        system_instruction = await self._system_prompt_builder.build(mcp_prompt)
         contents: list[types.Content] = []
-        system_instruction: str | None = None
-
-        # Build MCP prompt section if tools are available
-        mcp_section = ""
-        if mcp_prompt:
-            mcp_section = (
-                "\n\n### Tool-Auswahlrichtlinien (Einbettung externer Beschreibungen)\n"
-                f"{mcp_prompt}"
-            )
-
-        # Get system prompt content first
-        system_prompt_template = await get_system_prompt()
-        if system_prompt_template:
-            # Format with MCP prompts placeholder
-            system_instruction = system_prompt_template.format(mcp_prompts=mcp_section)
 
         for msg in messages:
-            if msg.type == MessageType.SYSTEM:
-                # Append to system instruction
-                if system_instruction:
-                    system_instruction += f"\n{msg.text}"
-                else:
-                    system_instruction = msg.text
-            elif msg.type in (MessageType.HUMAN, MessageType.ASSISTANT):
-                role = "user" if msg.type == MessageType.HUMAN else "model"
-                contents.append(
-                    types.Content(role=role, parts=[types.Part(text=msg.text)])
-                )
+            match msg.type:
+                case MessageType.SYSTEM:
+                    system_instruction = (
+                        f"{system_instruction}\n{msg.text}"
+                        if system_instruction
+                        else msg.text
+                    )
+                case MessageType.HUMAN | MessageType.ASSISTANT:
+                    role = "user" if msg.type == MessageType.HUMAN else "model"
+                    contents.append(
+                        types.Content(role=role, parts=[types.Part(text=msg.text)])
+                    )
 
         return contents, system_instruction
 
+    def _extract_text_from_parts(self, parts: list[Any]) -> str:
+        """Extract and join text from content parts."""
+        return "".join(p.text for p in parts if p.text)
+
     def _handle_chunk(self, chunk: Any) -> Chunk | None:
         """Handle a single chunk from Gemini stream."""
-        # Gemini chunks contain candidates. First candidate.
-        if not chunk.candidates or not chunk.candidates[0].content:
+        if (
+            not chunk.candidates
+            or not chunk.candidates[0].content
+            or not chunk.candidates[0].content.parts
+        ):
             return None
 
-        candidate = chunk.candidates[0]
-        content = candidate.content
-
-        # List comprehension for text parts
-        if not content.parts:
-            return None
-
-        text_parts = [part.text for part in content.parts if part.text]
-
-        if text_parts:
-            return self._create_chunk(
-                ChunkType.TEXT, "".join(text_parts), {"delta": "".join(text_parts)}
-            )
-
+        if text := self._extract_text_from_parts(chunk.candidates[0].content.parts):
+            return self.chunk_factory.text(text, delta=text)
         return None
-
-    def _create_chunk(
-        self,
-        chunk_type: ChunkType,
-        content: str,
-        extra_metadata: dict[str, str] | None = None,
-    ) -> Chunk:
-        """Create a Chunk."""
-        metadata = {
-            "processor": "gemini_responses",
-        }
-        if extra_metadata:
-            metadata.update(extra_metadata)
-
-        return Chunk(
-            type=chunk_type,
-            text=content,
-            chunk_metadata=metadata,
-        )
-
-    async def _create_auth_required_chunk(self, server: MCPServer) -> Chunk:
-        """Create an AUTH_REQUIRED chunk."""
-        # reusing logic from other processors, simplified here
-        return Chunk(
-            type=ChunkType.AUTH_REQUIRED,
-            text=f"{server.name} authentication required",
-            chunk_metadata={"server_name": server.name},
-        )
-
-    def _parse_mcp_headers(self, server: MCPServer) -> dict[str, str]:
-        """Parse headers from server config.
-
-        Returns:
-            Dictionary of HTTP headers to send to the MCP server.
-        """
-        if not server.headers or server.headers == "{}":
-            return {}
-
-        try:
-            headers_dict = json.loads(server.headers)
-            return dict(headers_dict)
-        except json.JSONDecodeError:
-            logger.warning("Invalid headers JSON for server %s", server.name)
-            return {}
-
-    async def _get_valid_token_for_server(
-        self, server: MCPServer, user_id: int
-    ) -> AssistantMCPUserToken | None:
-        """Get a valid OAuth token for the server/user."""
-        if server.id is None:
-            return None
-
-        with rx.session() as session:
-            token = self._mcp_auth_service.get_user_token(session, user_id, server.id)
-
-            if token is None:
-                return None
-
-            return await self._mcp_auth_service.ensure_valid_token(
-                session, server, token
-            )
-
-
-class MCPSessionWrapper:
-    """Wrapper to store MCP connection details before creating actual session."""
-
-    def __init__(self, url: str, headers: dict[str, str], name: str) -> None:
-        self.url = url
-        self.headers = headers
-        self.name = name
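The payload filtering near the top of this final hunk is a plain dict comprehension over a denylist; a quick standalone check of the behavior (the sample payload keys and values are hypothetical):

payload = {"thinking_level": "low", "thread_uuid": "abc", "top_p": 0.9}

ignored_fields = {"thread_uuid", "user_id", "thinking_level"}
filtered_payload = {k: v for k, v in payload.items() if k not in ignored_fields}

# thinking_level is now read with get() rather than popped, so the original
# payload dict is left untouched for other consumers.
thinking_level = payload.get("thinking_level", "high")
assert filtered_payload == {"top_p": 0.9}
assert thinking_level == "low"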