appkit-assistant 0.16.3__py3-none-any.whl → 0.17.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,726 @@
+ """
+ Gemini responses processor for generating AI responses using Google's GenAI API.
+ """
+
+ import copy
+ import json
+ import logging
+ import uuid
+ from collections.abc import AsyncGenerator
+ from contextlib import AsyncExitStack, asynccontextmanager
+ from dataclasses import dataclass, field
+ from typing import Any, Final
+
+ import reflex as rx
+ from google.genai import types
+ from mcp import ClientSession
+ from mcp.client.streamable_http import streamablehttp_client
+
+ from appkit_assistant.backend.mcp_auth_service import MCPAuthService
+ from appkit_assistant.backend.models import (
+     AIModel,
+     AssistantMCPUserToken,
+     Chunk,
+     ChunkType,
+     MCPAuthType,
+     MCPServer,
+     Message,
+     MessageType,
+ )
+ from appkit_assistant.backend.processor import mcp_oauth_redirect_uri
+ from appkit_assistant.backend.processors.gemini_base import BaseGeminiProcessor
+ from appkit_assistant.backend.system_prompt_cache import get_system_prompt
+
+ logger = logging.getLogger(__name__)
+ default_oauth_redirect_uri: Final[str] = mcp_oauth_redirect_uri()
+
+ # Maximum characters to show in tool result preview
+ TOOL_RESULT_PREVIEW_LENGTH: Final[int] = 500
+
+
+ @dataclass
+ class MCPToolContext:
+     """Context for MCP tool execution."""
+
+     session: ClientSession
+     server_name: str
+     tools: dict[str, Any] = field(default_factory=dict)
+
+
+ class GeminiResponsesProcessor(BaseGeminiProcessor):
+     """Gemini processor using the GenAI API with native MCP support."""
+
+     def __init__(
+         self,
+         models: dict[str, AIModel],
+         api_key: str | None = None,
+         oauth_redirect_uri: str = default_oauth_redirect_uri,
+     ) -> None:
+         super().__init__(models, api_key)
+         self._current_reasoning_session: str | None = None
+         self._current_user_id: int | None = None
+         self._mcp_auth_service = MCPAuthService(redirect_uri=oauth_redirect_uri)
+         self._pending_auth_servers: list[MCPServer] = []
+
+         logger.debug("Using redirect URI for MCP OAuth: %s", oauth_redirect_uri)
+
+     async def process(
+         self,
+         messages: list[Message],
+         model_id: str,
+         files: list[str] | None = None,  # noqa: ARG002
+         mcp_servers: list[MCPServer] | None = None,
+         payload: dict[str, Any] | None = None,
+         user_id: int | None = None,
+     ) -> AsyncGenerator[Chunk, None]:
+         """Process messages using the Google GenAI API with native MCP support."""
+         if not self.client:
+             raise ValueError("Gemini Client not initialized.")
+
+         if model_id not in self.models:
+             msg = f"Model {model_id} not supported by Gemini processor"
+             raise ValueError(msg)
+
+         model = self.models[model_id]
+         self._current_user_id = user_id
+         self._pending_auth_servers = []
+         self._current_reasoning_session = None
+
+         # Prepare configuration
+         config = self._create_generation_config(model, payload)
+
+         # Connect to MCP servers and create sessions
+         mcp_sessions = []
+         mcp_prompt = ""
+         if mcp_servers:
+             sessions_result = await self._create_mcp_sessions(mcp_servers, user_id)
+             mcp_sessions = sessions_result["sessions"]
+             self._pending_auth_servers = sessions_result["auth_required"]
+             mcp_prompt = self._build_mcp_prompt(mcp_servers)
+
+         if mcp_sessions:
+             # Attach session wrappers as tools; _stream_with_mcp replaces them
+             # with explicit FunctionDeclarations once the sessions are initialized
+             config.tools = mcp_sessions
+
+         # Prepare messages with MCP prompts for tool selection
+         contents, system_instruction = await self._convert_messages_to_gemini_format(
+             messages, mcp_prompt
+         )
+
+         # Add system instruction to config if present
+         if system_instruction:
+             config.system_instruction = system_instruction
+
+         if mcp_sessions:
+             logger.info(
+                 "Connected to %d MCP servers for native tool support",
+                 len(mcp_sessions),
+             )
+
+         try:
+             # Generate content with MCP tools
+             async for chunk in self._stream_with_mcp(
+                 model.model, contents, config, mcp_sessions
+             ):
+                 yield chunk
+
+             # Handle any pending auth
+             for server in self._pending_auth_servers:
+                 yield await self._create_auth_required_chunk(server)
+
+         except Exception as e:
+             logger.exception("Error in Gemini processor: %s", str(e))
+             yield self._create_chunk(ChunkType.ERROR, f"Error: {e!s}")
+
+     async def _create_mcp_sessions(
+         self, servers: list[MCPServer], user_id: int | None
+     ) -> dict[str, Any]:
+         """Create MCP session wrappers for each server.
+
+         Returns:
+             Dict with 'sessions' and 'auth_required' lists
+         """
+         sessions = []
+         auth_required = []
+
+         for server in servers:
+             try:
+                 # Parse headers
+                 headers = self._parse_mcp_headers(server)
+
+                 # Handle OAuth - inject token
+                 if (
+                     server.auth_type == MCPAuthType.OAUTH_DISCOVERY
+                     and user_id is not None
+                 ):
+                     token = await self._get_valid_token_for_server(server, user_id)
+                     if token:
+                         headers["Authorization"] = f"Bearer {token.access_token}"
+                     else:
+                         auth_required.append(server)
+                         logger.warning(
+                             "Skipping MCP server %s - OAuth token required",
+                             server.name,
+                         )
+                         continue
+
+                 # Record the connection details; the streamable-HTTP connection
+                 # is opened later in _mcp_context_manager, using the URL exactly
+                 # as configured (the server determines the endpoint)
+                 logger.debug(
+                     "Connecting to MCP server %s at %s (headers: %s)",
+                     server.name,
+                     server.url,
+                     {
+                         k: "***" if k.lower() == "authorization" else v
+                         for k, v in headers.items()
+                     },
+                 )
+
+                 # Create a session wrapper with URL and headers
+                 session = MCPSessionWrapper(server.url, headers, server.name)
+                 sessions.append(session)
+
+             except Exception as e:
+                 logger.error(
+                     "Failed to connect to MCP server %s: %s", server.name, str(e)
+                 )
+
+         return {"sessions": sessions, "auth_required": auth_required}
+
+     async def _stream_with_mcp(
+         self,
+         model_name: str,
+         contents: list[types.Content],
+         config: types.GenerateContentConfig,
+         mcp_sessions: list[Any],
+     ) -> AsyncGenerator[Chunk, None]:
+         """Stream responses with MCP tool support."""
+         if not mcp_sessions:
+             # No MCP sessions, direct streaming
+             async for chunk in self._stream_generation(model_name, contents, config):
+                 yield chunk
+             return
+
+         # Enter all session contexts and fetch tools
+         async with self._mcp_context_manager(mcp_sessions) as tool_contexts:
+             if tool_contexts:
+                 # Convert MCP tools to Gemini FunctionDeclarations
+                 function_declarations = []
+                 for ctx in tool_contexts:
+                     for tool_name, tool_def in ctx.tools.items():
+                         func_decl = self._mcp_tool_to_gemini_function(
+                             tool_name, tool_def
+                         )
+                         if func_decl:
+                             function_declarations.append(func_decl)
+
+                 if function_declarations:
+                     config.tools = [
+                         types.Tool(function_declarations=function_declarations)
+                     ]
+                     logger.info(
+                         "Configured %d tools for Gemini from MCP",
+                         len(function_declarations),
+                     )
+
+             # Stream with automatic function calling loop
+             async for chunk in self._stream_with_tool_loop(
+                 model_name, contents, config, tool_contexts
+             ):
+                 yield chunk
+
+     async def _stream_with_tool_loop(
+         self,
+         model_name: str,
+         contents: list[types.Content],
+         config: types.GenerateContentConfig,
+         tool_contexts: list[MCPToolContext],
+     ) -> AsyncGenerator[Chunk, None]:
+         """Stream generation with a tool-call handling loop."""
+         max_tool_rounds = 10
+         current_contents = list(contents)
+
+         for _round_num in range(max_tool_rounds):
+             response = await self.client.aio.models.generate_content(
+                 model=model_name, contents=current_contents, config=config
+             )
+
+             if not response.candidates:
+                 return
+
+             candidate = response.candidates[0]
+             content = candidate.content
+
+             # Check for function calls
+             function_calls = [
+                 part.function_call
+                 for part in content.parts
+                 if part.function_call is not None
+             ]
+
+             if function_calls:
+                 # Add model response with function calls to the conversation
+                 current_contents.append(content)
+
+                 # Execute tool calls and collect results
+                 function_responses = []
+                 for fc in function_calls:
+                     # Find the server name for this tool
+                     server_name = "unknown"
+                     for ctx in tool_contexts:
+                         if fc.name in ctx.tools:
+                             server_name = ctx.server_name
+                             break
+
+                     # Generate a unique tool call ID
+                     tool_call_id = f"mcp_{uuid.uuid4().hex[:32]}"
+
+                     # Yield a TOOL_CALL chunk to show in the UI
+                     yield self._create_chunk(
+                         ChunkType.TOOL_CALL,
+                         f"Tool: {server_name}.{fc.name}",
+                         {
+                             "tool_name": fc.name,
+                             "tool_id": tool_call_id,
+                             "server_label": server_name,
+                             "arguments": json.dumps(fc.args),
+                             "status": "starting",
+                         },
+                     )
+
+                     result = await self._execute_mcp_tool(
+                         fc.name, fc.args, tool_contexts
+                     )
+
+                     # Yield a TOOL_RESULT chunk with a preview
+                     preview = (
+                         result[:TOOL_RESULT_PREVIEW_LENGTH]
+                         if len(result) > TOOL_RESULT_PREVIEW_LENGTH
+                         else result
+                     )
+                     yield self._create_chunk(
+                         ChunkType.TOOL_RESULT,
+                         preview,
+                         {
+                             "tool_name": fc.name,
+                             "tool_id": tool_call_id,
+                             "server_label": server_name,
+                             "status": "completed",
+                             "result_length": str(len(result)),
+                         },
+                     )
+
+                     function_responses.append(
+                         types.Part(
+                             function_response=types.FunctionResponse(
+                                 name=fc.name,
+                                 response={"result": result},
+                             )
+                         )
+                     )
+                     logger.debug(
+                         "Tool %s executed, result length: %d",
+                         fc.name,
+                         len(str(result)),
+                     )
+
+                 # Add function responses
+                 current_contents.append(
+                     types.Content(role="user", parts=function_responses)
+                 )
+
+                 # Continue to the next round
+                 continue
+
+             # No function calls - yield the text response
+             text_parts = [part.text for part in content.parts if part.text]
+             if text_parts:
+                 yield self._create_chunk(
+                     ChunkType.TEXT,
+                     "".join(text_parts),
+                     {"delta": "".join(text_parts)},
+                 )
+
+             # Done - no more function calls
+             return
+
+         logger.warning("Max tool rounds (%d) exceeded", max_tool_rounds)
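+
+     # A sketch (assumed, for illustration only) of how current_contents grows
+     # across rounds:
+     #   round 1: [user prompt] -> model emits function_call(search, {...})
+     #   round 2: [user prompt, model function_call, user function_response]
+     #            -> model emits final text and the loop exits via `return`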
+
+     async def _execute_mcp_tool(
+         self,
+         tool_name: str,
+         args: dict[str, Any],
+         tool_contexts: list[MCPToolContext],
+     ) -> str:
+         """Execute an MCP tool and return the result."""
+         # Find which context has this tool
+         for ctx in tool_contexts:
+             if tool_name in ctx.tools:
+                 try:
+                     logger.debug(
+                         "Executing tool %s on server %s with args: %s",
+                         tool_name,
+                         ctx.server_name,
+                         args,
+                     )
+                     result = await ctx.session.call_tool(tool_name, args)
+                     # Extract text from result
+                     if hasattr(result, "content") and result.content:
+                         texts = [
+                             item.text
+                             for item in result.content
+                             if hasattr(item, "text")
+                         ]
+                         return "\n".join(texts) if texts else str(result)
+                     return str(result)
+                 except Exception as e:
+                     logger.exception("Error executing tool %s: %s", tool_name, str(e))
+                     return f"Error executing tool: {e!s}"
+
+         return f"Tool {tool_name} not found in any MCP server"
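+
+     # Illustrative (assumed) shape of an MCP CallToolResult as consumed above:
+     #   result.content == [TextContent(type="text", text="42 results found")]
+     # The extraction joins all text items, so the plain string
+     # "42 results found" is returned to the model.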
+
+     def _mcp_tool_to_gemini_function(
+         self, name: str, tool_def: dict[str, Any]
+     ) -> types.FunctionDeclaration | None:
+         """Convert MCP tool definition to Gemini FunctionDeclaration."""
+         try:
+             description = tool_def.get("description", "")
+             input_schema = tool_def.get("inputSchema", {})
+
+             # Fix the schema for Gemini compatibility
+             fixed_schema = self._fix_schema_for_gemini(input_schema)
+
+             return types.FunctionDeclaration(
+                 name=name,
+                 description=description,
+                 parameters=fixed_schema if fixed_schema else None,
+             )
+         except Exception as e:
+             logger.warning("Failed to convert MCP tool %s: %s", name, str(e))
+             return None
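+
+     # Illustrative conversion (assumed MCP tool definition, not from the
+     # package):
+     #   {"description": "Search docs",
+     #    "inputSchema": {"type": "object",
+     #                    "properties": {"query": {"type": "string"}}}}
+     # becomes FunctionDeclaration(name="search", description="Search docs",
+     # parameters={"type": "object",
+     #             "properties": {"query": {"type": "string"}}})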
+
+     def _fix_schema_for_gemini(self, schema: dict[str, Any]) -> dict[str, Any]:
+         """Fix a JSON schema for Gemini API compatibility.
+
+         Gemini requires an 'items' field for array types and doesn't allow
+         certain JSON Schema fields like '$schema', '$id', 'definitions', etc.
+         This method recursively fixes the schema.
+         """
+         if not schema:
+             return schema
+
+         # Deep copy to avoid modifying the original
+         schema = copy.deepcopy(schema)
+
+         # Fields that Gemini doesn't allow in FunctionDeclaration parameters.
+         # Note: additionalProperties gets converted to additional_properties by the SDK.
+         forbidden_fields = {
+             "$schema",
+             "$id",
+             "$ref",
+             "$defs",
+             "definitions",
+             "$comment",
+             "examples",
+             "default",
+             "const",
+             "contentMediaType",
+             "contentEncoding",
+             "additionalProperties",
+             "additional_properties",
+             "patternProperties",
+             "unevaluatedProperties",
+             "unevaluatedItems",
+             "minItems",
+             "maxItems",
+             "minLength",
+             "maxLength",
+             "minimum",
+             "maximum",
+             "exclusiveMinimum",
+             "exclusiveMaximum",
+             "multipleOf",
+             "pattern",
+             "format",
+             "title",
+             # Composition keywords - Gemini doesn't support these
+             "allOf",
+             "oneOf",
+             "not",
+             "if",
+             "then",
+             "else",
+             "dependentSchemas",
+             "dependentRequired",
+         }
+
+         def fix_property(prop: dict[str, Any]) -> dict[str, Any]:
+             """Recursively fix a property schema."""
+             if not isinstance(prop, dict):
+                 return prop
+
+             # Remove forbidden fields
+             for forbidden in forbidden_fields:
+                 prop.pop(forbidden, None)
+
+             prop_type = prop.get("type")
+
+             # Fix arrays without items
+             if prop_type == "array" and "items" not in prop:
+                 prop["items"] = {"type": "string"}
+                 logger.debug("Added missing 'items' to array property")
+
+             # Recurse into items
+             if "items" in prop and isinstance(prop["items"], dict):
+                 prop["items"] = fix_property(prop["items"])
+
+             # Recurse into properties
+             if "properties" in prop and isinstance(prop["properties"], dict):
+                 for key, val in prop["properties"].items():
+                     prop["properties"][key] = fix_property(val)
+
+             # Recurse into anyOf/any_of arrays (Gemini accepts these, but not
+             # forbidden fields inside them)
+             for any_of_key in ("anyOf", "any_of"):
+                 if any_of_key in prop and isinstance(prop[any_of_key], list):
+                     prop[any_of_key] = [
+                         fix_property(item) if isinstance(item, dict) else item
+                         for item in prop[any_of_key]
+                     ]
+
+             return prop
+
+         return fix_property(schema)
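+
+     # Hedged example (assumed input, not taken from any real server) of what
+     # _fix_schema_for_gemini does:
+     #   {"type": "object", "title": "Args",
+     #    "properties": {"tags": {"type": "array"}}}
+     # becomes
+     #   {"type": "object",
+     #    "properties": {"tags": {"type": "array", "items": {"type": "string"}}}}
+     # 'title' is stripped as a forbidden field, and the bare array gains a
+     # default string 'items' entry.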
+
+     @asynccontextmanager
+     async def _mcp_context_manager(
+         self, session_wrappers: list[Any]
+     ) -> AsyncGenerator[list[MCPToolContext], None]:
+         """Context manager to enter all MCP session contexts and fetch tools."""
+         async with AsyncExitStack() as stack:
+             tool_contexts: list[MCPToolContext] = []
+
+             for wrapper in session_wrappers:
+                 try:
+                     logger.debug(
+                         "Connecting to MCP server %s via streamablehttp_client",
+                         wrapper.name,
+                     )
+                     read, write, _ = await stack.enter_async_context(
+                         streamablehttp_client(
+                             url=wrapper.url,
+                             headers=wrapper.headers,
+                             timeout=60.0,
+                         )
+                     )
+
+                     session = await stack.enter_async_context(
+                         ClientSession(read, write)
+                     )
+                     await session.initialize()
+
+                     # Fetch tools from this session
+                     tools_result = await session.list_tools()
+                     tools_dict = {}
+                     for tool in tools_result.tools:
+                         tools_dict[tool.name] = {
+                             "description": tool.description or "",
+                             "inputSchema": (
+                                 tool.inputSchema if hasattr(tool, "inputSchema") else {}
+                             ),
+                         }
+
+                     ctx = MCPToolContext(
+                         session=session,
+                         server_name=wrapper.name,
+                         tools=tools_dict,
+                     )
+                     tool_contexts.append(ctx)
+
+                     logger.info(
+                         "MCP session initialized for %s with %d tools",
+                         wrapper.name,
+                         len(tools_dict),
+                     )
+                 except Exception as e:
+                     logger.exception(
+                         "Failed to initialize MCP session for %s: %s",
+                         wrapper.name,
+                         str(e),
+                     )
+
+             try:
+                 yield tool_contexts
+             except Exception as e:
+                 logger.exception("Error during MCP session usage: %s", str(e))
+                 raise
+
+     async def _stream_generation(
+         self,
+         model_name: str,
+         contents: list[types.Content],
+         config: types.GenerateContentConfig,
+     ) -> AsyncGenerator[Chunk, None]:
+         """Stream generation from the Gemini model."""
+         # generate_content_stream returns an awaitable that resolves to an
+         # async iterator of response chunks
+         stream = await self.client.aio.models.generate_content_stream(
+             model=model_name, contents=contents, config=config
+         )
+         async for chunk in stream:
+             processed = self._handle_chunk(chunk)
+             if processed:
+                 yield processed
+
+     def _create_generation_config(
+         self, model: AIModel, payload: dict[str, Any] | None
+     ) -> types.GenerateContentConfig:
+         """Create a generation config from the model and payload."""
+         # The default thinking level depends on the model:
+         # "medium" is only supported by Flash; Pro uses "high" (default dynamic)
+         thinking_level = "high"
+         if "flash" in model.model.lower():
+             thinking_level = "medium"
+
+         # Override from the payload if present
+         if payload and "thinking_level" in payload:
+             thinking_level = payload.pop("thinking_level")
+
+         return types.GenerateContentConfig(
+             temperature=model.temperature,
+             thinking_config=types.ThinkingConfig(thinking_level=thinking_level),
+             **(payload or {}),
+             response_modalities=["TEXT"],
+         )
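+
+     # Hedged usage sketch (hypothetical payload values): a payload of
+     #   {"thinking_level": "low", "top_p": 0.9}
+     # pops "thinking_level" into the ThinkingConfig and forwards the remaining
+     # keys (here top_p) to GenerateContentConfig via **payload.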
+
+     def _build_mcp_prompt(self, mcp_servers: list[MCPServer]) -> str:
+         """Build MCP tool selection prompt from server prompts."""
+         prompts = [f"- {server.prompt}" for server in mcp_servers if server.prompt]
+         return "\n".join(prompts) if prompts else ""
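+
+     # Example (assumed server prompts, for illustration only): two servers
+     # with prompts "Use for Jira tickets" and "Use for wiki search" produce:
+     #   - Use for Jira tickets
+     #   - Use for wiki search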
+
+     async def _convert_messages_to_gemini_format(
+         self, messages: list[Message], mcp_prompt: str = ""
+     ) -> tuple[list[types.Content], str | None]:
+         """Convert app messages to Gemini Content objects."""
+         contents: list[types.Content] = []
+         system_instruction: str | None = None
+
+         # Build the MCP prompt section if tools are available
+         mcp_section = ""
+         if mcp_prompt:
+             mcp_section = (
+                 "\n\n### Tool selection guidelines (embedded external descriptions)\n"
+                 f"{mcp_prompt}"
+             )
+
+         # Get the system prompt content first
+         system_prompt_template = await get_system_prompt()
+         if system_prompt_template:
+             # Format with the MCP prompts placeholder
+             system_instruction = system_prompt_template.format(mcp_prompts=mcp_section)
+
+         for msg in messages:
+             if msg.type == MessageType.SYSTEM:
+                 # Append to the system instruction
+                 if system_instruction:
+                     system_instruction += f"\n{msg.text}"
+                 else:
+                     system_instruction = msg.text
+             elif msg.type in (MessageType.HUMAN, MessageType.ASSISTANT):
+                 role = "user" if msg.type == MessageType.HUMAN else "model"
+                 contents.append(
+                     types.Content(role=role, parts=[types.Part(text=msg.text)])
+                 )
+
+         return contents, system_instruction
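+
+     # Illustrative mapping (assumed message flow): SYSTEM text is folded into
+     # system_instruction, HUMAN becomes Content(role="user"), ASSISTANT becomes
+     # Content(role="model"), and any other message types are dropped.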
+
+     def _handle_chunk(self, chunk: Any) -> Chunk | None:
+         """Handle a single chunk from the Gemini stream."""
+         # Gemini chunks contain candidates; use the first one
+         if not chunk.candidates or not chunk.candidates[0].content:
+             return None
+
+         candidate = chunk.candidates[0]
+         content = candidate.content
+
+         if not content.parts:
+             return None
+
+         # Collect the text from all parts of the chunk
+         text_parts = [part.text for part in content.parts if part.text]
+
+         if text_parts:
+             return self._create_chunk(
+                 ChunkType.TEXT, "".join(text_parts), {"delta": "".join(text_parts)}
+             )
+
+         return None
+
+     def _create_chunk(
+         self,
+         chunk_type: ChunkType,
+         content: str,
+         extra_metadata: dict[str, str] | None = None,
+     ) -> Chunk:
+         """Create a Chunk."""
+         metadata = {
+             "processor": "gemini_responses",
+         }
+         if extra_metadata:
+             metadata.update(extra_metadata)
+
+         return Chunk(
+             type=chunk_type,
+             text=content,
+             chunk_metadata=metadata,
+         )
+
+     async def _create_auth_required_chunk(self, server: MCPServer) -> Chunk:
+         """Create an AUTH_REQUIRED chunk."""
+         # Simplified version of the auth-chunk logic used by the other processors
+         return Chunk(
+             type=ChunkType.AUTH_REQUIRED,
+             text=f"{server.name} authentication required",
+             chunk_metadata={"server_name": server.name},
+         )
+
+     def _parse_mcp_headers(self, server: MCPServer) -> dict[str, str]:
+         """Parse headers from server config.
+
+         Returns:
+             Dictionary of HTTP headers to send to the MCP server.
+         """
+         if not server.headers or server.headers == "{}":
+             return {}
+
+         try:
+             headers_dict = json.loads(server.headers)
+             return dict(headers_dict)
+         except json.JSONDecodeError:
+             logger.warning("Invalid headers JSON for server %s", server.name)
+             return {}
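+
+     # Example (hypothetical server config): headers stored as the JSON string
+     #   '{"X-API-Key": "abc123"}'
+     # parse to {"X-API-Key": "abc123"}; missing or malformed JSON yields {}.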
+
+     async def _get_valid_token_for_server(
+         self, server: MCPServer, user_id: int
+     ) -> AssistantMCPUserToken | None:
+         """Get a valid OAuth token for the server/user."""
+         if server.id is None:
+             return None
+
+         with rx.session() as session:
+             token = self._mcp_auth_service.get_user_token(session, user_id, server.id)
+
+             if token is None:
+                 return None
+
+             return await self._mcp_auth_service.ensure_valid_token(
+                 session, server, token
+             )
+
+
+ class MCPSessionWrapper:
+     """Wrapper to store MCP connection details before creating the actual session."""
+
+     def __init__(self, url: str, headers: dict[str, str], name: str) -> None:
+         self.url = url
+         self.headers = headers
+         self.name = name
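+
+
+ # Minimal usage sketch (assumed values, not part of the shipped module):
+ #   wrapper = MCPSessionWrapper(
+ #       "https://mcp.example.com/mcp", {"Authorization": "Bearer ..."}, "docs"
+ #   )
+ # _mcp_context_manager later opens the real streamable-HTTP connection and
+ # ClientSession from these stored details.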