appkit-assistant 0.14.0__py3-none-any.whl → 0.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- appkit_assistant/backend/mcp_auth_service.py +796 -0
- appkit_assistant/backend/model_manager.py +2 -1
- appkit_assistant/backend/models.py +43 -0
- appkit_assistant/backend/processors/openai_responses_processor.py +259 -25
- appkit_assistant/backend/repositories.py +1 -1
- appkit_assistant/backend/system_prompt_cache.py +5 -5
- appkit_assistant/components/mcp_server_dialogs.py +327 -21
- appkit_assistant/components/message.py +62 -0
- appkit_assistant/components/thread.py +99 -1
- appkit_assistant/state/mcp_server_state.py +42 -1
- appkit_assistant/state/system_prompt_state.py +4 -4
- appkit_assistant/state/thread_list_state.py +3 -3
- appkit_assistant/state/thread_state.py +190 -28
- {appkit_assistant-0.14.0.dist-info → appkit_assistant-0.15.0.dist-info}/METADATA +1 -1
- appkit_assistant-0.15.0.dist-info/RECORD +29 -0
- appkit_assistant-0.14.0.dist-info/RECORD +0 -28
- {appkit_assistant-0.14.0.dist-info → appkit_assistant-0.15.0.dist-info}/WHEEL +0 -0
appkit_assistant/backend/model_manager.py

@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import logging
 import threading
+from typing import Self
 
 from appkit_assistant.backend.models import AIModel
 from appkit_assistant.backend.processor import Processor
@@ -18,7 +19,7 @@ class ModelManager:
             None  # Default model ID will be set to the first registered model
         )
 
-    def __new__(cls) ->
+    def __new__(cls) -> Self:
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
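The `__new__` return type now uses PEP 673's `typing.Self`, matching the new import above (the previous annotation is cut off in this diff view). For orientation, a minimal sketch of the double-checked-locking singleton this method implements; the `_instance` and `_lock` attributes are assumed from the surrounding context:

```python
import threading
from typing import Self


class ModelManager:
    """Thread-safe singleton; `Self` keeps the annotation correct in subclasses."""

    _instance: "ModelManager | None" = None
    _lock = threading.Lock()

    def __new__(cls) -> Self:
        if cls._instance is None:  # fast path: skip the lock once initialized
            with cls._lock:
                if cls._instance is None:  # re-check after acquiring the lock
                    cls._instance = super().__new__(cls)
        return cls._instance


assert ModelManager() is ModelManager()  # both calls yield the same object
```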
appkit_assistant/backend/models.py

@@ -44,6 +44,7 @@ class ChunkType(StrEnum):
     TOOL_RESULT = "tool_result"  # result from a tool
     TOOL_CALL = "tool_call"  # calling a tool
     COMPLETION = "completion"  # when response generation is complete
+    AUTH_REQUIRED = "auth_required"  # user needs to authenticate (MCP)
     ERROR = "error"  # when an error occurs
     LIFECYCLE = "lifecycle"
@@ -100,6 +101,7 @@ class AIModel(BaseModel):
     supports_attachments: bool = False
     keywords: list[str] = []
     disabled: bool = False
+    requires_role: str | None = None
 
 
 class Suggestion(BaseModel):
@@ -143,6 +145,12 @@ class MCPServer(rx.Model, table=True):
     # Optional discovery URL override
     discovery_url: str | None = Field(default=None, nullable=True)
 
+    # OAuth client credentials (encrypted)
+    oauth_client_id: str | None = Field(default=None, nullable=True)
+    oauth_client_secret: str | None = Field(
+        default=None, nullable=True, sa_type=EncryptedString
+    )
+
     # Cached OAuth/Discovery metadata (read-only for user mostly)
     oauth_issuer: str | None = Field(default=None, nullable=True)
     oauth_authorize_url: str | None = Field(default=None, nullable=True)
@@ -194,3 +202,38 @@ class AssistantThread(rx.Model, table=True):
         default_factory=lambda: datetime.now(UTC),
         sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
     )
+
+
+class AssistantMCPUserToken(rx.Model, table=True):
+    """Model for storing user-specific OAuth tokens for MCP servers.
+
+    Each user can have one token per MCP server. Tokens are encrypted at rest.
+    """
+
+    __tablename__ = "assistant_mcp_user_token"
+
+    id: int | None = Field(default=None, primary_key=True)
+    user_id: int = Field(index=True, nullable=False)
+    mcp_server_id: int = Field(
+        index=True, nullable=False, foreign_key="assistant_mcp_servers.id"
+    )
+
+    # Tokens are encrypted at rest
+    access_token: str = Field(nullable=False, sa_type=EncryptedString)
+    refresh_token: str | None = Field(
+        default=None, nullable=True, sa_type=EncryptedString
+    )
+
+    # Token expiry timestamp
+    expires_at: datetime = Field(
+        sa_column=Column(DateTime(timezone=True), nullable=False)
+    )
+
+    created_at: datetime = Field(
+        default_factory=lambda: datetime.now(UTC),
+        sa_column=Column(DateTime(timezone=True)),
+    )
+    updated_at: datetime = Field(
+        default_factory=lambda: datetime.now(UTC),
+        sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
+    )
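Both `oauth_client_secret` above and the token columns here use `sa_type=EncryptedString`, a custom column type defined elsewhere in the package (its implementation is not part of this diff). As a rough sketch of the usual pattern behind such a type, assuming a SQLAlchemy `TypeDecorator` backed by Fernet symmetric encryption (key handling and names are illustrative, not the package's actual code):

```python
import os

from cryptography.fernet import Fernet
from sqlalchemy.types import String, TypeDecorator

# Assumed key sourcing; real code would use proper secret storage.
_fernet = Fernet(os.environ.get("APP_ENCRYPTION_KEY", Fernet.generate_key()))


class EncryptedString(TypeDecorator):
    """Transparently encrypts on write and decrypts on read."""

    impl = String
    cache_ok = True

    def process_bind_param(self, value: str | None, dialect) -> str | None:
        # Called before the value is written to the database.
        if value is None:
            return None
        return _fernet.encrypt(value.encode()).decode()

    def process_result_value(self, value: str | None, dialect) -> str | None:
        # Called when the value is loaded from the database.
        if value is None:
            return None
        return _fernet.decrypt(value.encode()).decode()
```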
appkit_assistant/backend/processors/openai_responses_processor.py

@@ -3,16 +3,22 @@ import logging
 from collections.abc import AsyncGenerator
 from typing import Any
 
+import reflex as rx
+
+from appkit_assistant.backend.mcp_auth_service import MCPAuthService
 from appkit_assistant.backend.models import (
     AIModel,
+    AssistantMCPUserToken,
     Chunk,
     ChunkType,
+    MCPAuthType,
     MCPServer,
     Message,
     MessageType,
 )
 from appkit_assistant.backend.processors.openai_base import BaseOpenAIProcessor
 from appkit_assistant.backend.system_prompt_cache import get_system_prompt
+from appkit_commons.database.session import get_session_manager
 
 logger = logging.getLogger(__name__)
|
|
|
26
32
|
api_key: str | None = None,
|
|
27
33
|
base_url: str | None = None,
|
|
28
34
|
is_azure: bool = False,
|
|
35
|
+
oauth_redirect_uri: str = "",
|
|
29
36
|
) -> None:
|
|
30
37
|
super().__init__(models, api_key, base_url, is_azure)
|
|
31
38
|
self._current_reasoning_session: str | None = None
|
|
39
|
+
self._current_user_id: int | None = None
|
|
40
|
+
self._mcp_auth_service = MCPAuthService(redirect_uri=oauth_redirect_uri)
|
|
41
|
+
self._pending_auth_servers: list[MCPServer] = []
|
|
32
42
|
|
|
33
43
|
async def process(
|
|
34
44
|
self,
|
|
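The constructor now owns an `MCPAuthService` configured with the host app's OAuth callback. A hypothetical instantiation (all values illustrative):

```python
processor = OpenAIResponsesProcessor(
    models=models,  # dict of AIModel definitions, assumed available
    api_key="sk-...",
    oauth_redirect_uri="https://app.example.com/mcp/oauth/callback",
)
```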
@@ -37,6 +47,7 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
         files: list[str] | None = None,  # noqa: ARG002
         mcp_servers: list[MCPServer] | None = None,
         payload: dict[str, Any] | None = None,
+        user_id: int | None = None,
     ) -> AsyncGenerator[Chunk, None]:
         """Process messages using simplified content accumulator pattern."""
         if not self.client:
@@ -47,29 +58,45 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
             raise ValueError(msg)
 
         model = self.models[model_id]
+        self._current_user_id = user_id
+        self._pending_auth_servers = []
 
         try:
             session = await self._create_responses_request(
-                messages, model, mcp_servers, payload
+                messages, model, mcp_servers, payload, user_id
             )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            try:
+                if hasattr(session, "__aiter__"):  # Streaming
+                    async for event in session:
+                        chunk = self._handle_event(event)
+                        if chunk:
+                            yield chunk
+                else:  # Non-streaming
+                    content = self._extract_responses_content(session)
+                    if content:
+                        yield Chunk(
+                            type=ChunkType.TEXT,
+                            text=content,
+                            chunk_metadata={
+                                "source": "responses_api",
+                                "streaming": "false",
+                            },
+                        )
+            except Exception as e:
+                logger.error("Error during response processing: %s", e)
+                # Continue to yield auth chunks if any
+
+            # After processing (or on error), yield any pending auth requirements
+            logger.debug(
+                "Processing pending auth servers: %d", len(self._pending_auth_servers)
+            )
+            for server in self._pending_auth_servers:
+                logger.debug("Yielding auth chunk for server: %s", server.name)
+                yield await self._create_auth_required_chunk(server)
+
         except Exception as e:
+            logger.error("Critical error in OpenAI processor: %s", e)
             raise e
 
     def _handle_event(self, event: Any) -> Chunk | None:
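The rewritten `process()` wraps the streaming loop in an inner `try` so a failure no longer aborts the generator: queued `AUTH_REQUIRED` chunks are still emitted at the end. A hypothetical consumer of this stream (the positional arguments before `files` are not shown in this hunk and are assumed):

```python
from appkit_assistant.backend.models import ChunkType


async def consume(processor, messages, model_id, servers, user_id) -> None:
    """Hypothetical consumer: print text, surface auth links when the stream ends."""
    async for chunk in processor.process(
        messages, model_id, mcp_servers=servers, user_id=user_id
    ):
        if chunk.type == ChunkType.TEXT:
            print(chunk.text, end="")
        elif chunk.type == ChunkType.AUTH_REQUIRED:
            meta = chunk.chunk_metadata  # carries server_name, auth_url, state
            print(f"\n[{meta['server_name']}] sign in at: {meta['auth_url']}")
```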
@@ -95,7 +122,7 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
         result = handler(event_type)
         if result:
             content_preview = result.text[:50] if result.text else ""
-            logger.
+            logger.debug(
                 "Event %s → Chunk: type=%s, content=%s",
                 event_type,
                 result.type,
@@ -205,14 +232,36 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
         return None
 
     def _handle_mcp_call_done(self, item: Any) -> Chunk | None:
-        """Handle MCP call completion."""
+        """Handle MCP call completion.
+
+        Detects 401/403 authentication errors and marks servers for auth flow.
+        """
         tool_id = getattr(item, "id", "unknown_id")
         tool_name = getattr(item, "name", "unknown_tool")
+        server_label = getattr(item, "server_label", "unknown_server")
         error = getattr(item, "error", None)
         output = getattr(item, "output", None)
 
         if error:
             error_text = self._extract_error_text(error)
+
+            # Check for authentication errors (401/403)
+            if self._is_auth_error(error):
+                # Find the server config and queue for auth flow
+                return self._create_chunk(
+                    ChunkType.TOOL_RESULT,
+                    f"Authentifizierung erforderlich für {server_label}",
+                    {
+                        "tool_id": tool_id,
+                        "tool_name": tool_name,
+                        "server_label": server_label,
+                        "status": "auth_required",
+                        "error": True,
+                        "auth_required": True,
+                        "reasoning_session": self._current_reasoning_session,
+                    },
+                )
+
             return self._create_chunk(
                 ChunkType.TOOL_RESULT,
                 f"Werkzeugfehler: {error_text}",
@@ -238,6 +287,21 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
             },
         )
 
+    def _is_auth_error(self, error: Any) -> bool:
+        """Check if an error indicates authentication failure (401/403)."""
+        error_str = str(error).lower()
+        auth_indicators = [
+            "401",
+            "403",
+            "unauthorized",
+            "forbidden",
+            "authentication required",
+            "access denied",
+            "invalid token",
+            "token expired",
+        ]
+        return any(indicator in error_str for indicator in auth_indicators)
+
     def _extract_error_text(self, error: Any) -> str:
         """Extract readable error text from error object."""
         if isinstance(error, dict) and "content" in error:
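`_is_auth_error` is a plain substring heuristic over the stringified error, so it fires on any payload that merely contains one of these phrases. A standalone restatement for illustration:

```python
AUTH_INDICATORS = (
    "401", "403", "unauthorized", "forbidden",
    "authentication required", "access denied",
    "invalid token", "token expired",
)


def is_auth_error(error: object) -> bool:
    # Same logic as the method above, outside the class for easy testing.
    error_str = str(error).lower()
    return any(indicator in error_str for indicator in AUTH_INDICATORS)


assert is_auth_error({"code": 401, "message": "Unauthorized"})
assert is_auth_error("HTTP 403 Forbidden")
assert not is_auth_error("connection reset by peer")
```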
@@ -246,7 +310,7 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
                 return content[0].get("text", str(error))
         return "Unknown error"
 
-    def _handle_mcp_events(self, event_type: str, event: Any) -> Chunk | None:
+    def _handle_mcp_events(self, event_type: str, event: Any) -> Chunk | None:  # noqa: PLR0911, PLR0912
         """Handle MCP-specific events."""
         if event_type == "response.mcp_call_arguments.delta":
             tool_id = getattr(event, "item_id", "unknown_id")
@@ -315,6 +379,70 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
 
         if event_type == "response.mcp_list_tools.failed":
             tool_id = getattr(event, "item_id", "unknown_id")
+            error = getattr(event, "error", None)
+
+            # Debugging: Log available attributes to help diagnosis
+            if logger.isEnabledFor(logging.DEBUG) and error:
+                logger.debug("Error object type: %s, content: %s", type(error), error)
+
+            # Extract error message safely
+            error_str = ""
+            if error:
+                if isinstance(error, dict):
+                    error_str = error.get("message", str(error))
+                elif hasattr(error, "message"):
+                    error_str = getattr(error, "message", str(error))
+                else:
+                    error_str = str(error)
+
+            # Check for authentication errors (401/403)
+            # OR if we have pending auth servers (strong signal we missed a token)
+            is_auth_error = self._is_auth_error(error_str)
+            pending_server = None
+
+            # 1. Try to find matching server by name in error message
+            for server in self._pending_auth_servers:
+                if server.name.lower() in error_str.lower():
+                    pending_server = server
+                    break
+
+            # 2. If no match but we have pending servers and it looks like an auth error
+            #    OR if we have pending servers and likely one of them failed (len=1),
+            #    we assume the failure belongs to the pending server if we can't be sure
+            if (
+                not pending_server
+                and self._pending_auth_servers
+                and (is_auth_error or len(self._pending_auth_servers) == 1)
+            ):
+                pending_server = self._pending_auth_servers[0]
+                logger.debug(
+                    "Assuming pending server %s for list_tools failure '%s'",
+                    pending_server.name,
+                    error_str,
+                )
+
+            if pending_server:
+                logger.debug(
+                    "Queuing Auth Card for server: %s (Error: %s)",
+                    pending_server.name,
+                    error_str,
+                )
+                # Queue for async processing in the main process loop
+                # The auth chunk will be yielded after event processing completes
+                if pending_server not in self._pending_auth_servers:
+                    self._pending_auth_servers.append(pending_server)
+                return self._create_chunk(
+                    ChunkType.TOOL_RESULT,
+                    f"Authentifizierung erforderlich für {pending_server.name}",
+                    {
+                        "tool_id": tool_id,
+                        "status": "auth_required",
+                        "server_name": pending_server.name,
+                        "auth_pending": True,
+                        "reasoning_session": self._current_reasoning_session,
+                    },
+                )
+
             logger.error("MCP tool listing failed for tool_id: %s", str(event))
             return self._create_chunk(
                 ChunkType.TOOL_RESULT,
@@ -396,11 +524,14 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
         model: AIModel,
         mcp_servers: list[MCPServer] | None = None,
         payload: dict[str, Any] | None = None,
+        user_id: int | None = None,
     ) -> Any:
         """Create a simplified responses API request."""
-        # Configure MCP tools if provided
+        # Configure MCP tools if provided (now async for token lookup)
         tools, mcp_prompt = (
-            self._configure_mcp_tools(mcp_servers
+            await self._configure_mcp_tools(mcp_servers, user_id)
+            if mcp_servers
+            else ([], "")
         )
 
         # Convert messages to responses format with system message
@@ -421,11 +552,15 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
         logger.debug("Responses API request params: %s", params)
         return await self.client.responses.create(**params)
 
-    def _configure_mcp_tools(
-        self,
+    async def _configure_mcp_tools(
+        self,
+        mcp_servers: list[MCPServer] | None,
+        user_id: int | None = None,
     ) -> tuple[list[dict[str, Any]], str]:
         """Configure MCP servers as tools for the responses API.
 
+        Injects OAuth Bearer tokens for servers that require authentication.
+
         Returns:
             tuple: (tools list, concatenated prompts string)
         """
@@ -434,6 +569,7 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
 
         tools = []
         prompts = []
+
         for server in mcp_servers:
             tool_config = {
                 "type": "mcp",
@@ -442,8 +578,28 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
                 "require_approval": "never",
             }
 
+            # Start with existing headers
+            headers = {}
             if server.headers and server.headers != "{}":
-
+                headers = json.loads(server.headers)
+
+            # Inject OAuth token if server requires OAuth and user is authenticated
+            if server.auth_type == MCPAuthType.OAUTH_DISCOVERY and user_id is not None:
+                token = await self._get_valid_token_for_server(server, user_id)
+                if token:
+                    headers["Authorization"] = f"Bearer {token.access_token}"
+                    logger.debug("Injected OAuth token for server %s", server.name)
+                else:
+                    # No valid token - server will likely fail with 401
+                    # Track this server for potential auth flow
+                    self._pending_auth_servers.append(server)
+                    logger.debug(
+                        "No valid token for OAuth server %s, auth may be required",
+                        server.name,
+                    )
+
+            if headers:
+                tool_config["headers"] = headers
 
             tools.append(tool_config)
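After this step each entry in `tools` is a Responses-API MCP tool definition; when a valid token exists, the user's Bearer token rides along in `headers`. Roughly what one configured entry looks like (label, URL, and token values are illustrative):

```python
tool_config = {
    "type": "mcp",
    "server_label": "jira",                       # illustrative
    "server_url": "https://mcp.example.com/sse",  # illustrative
    "require_approval": "never",
    # Merged from the server's stored headers, plus the injected OAuth token:
    "headers": {"Authorization": "Bearer eyJhbGciOi..."},
}
```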
@@ -511,3 +667,81 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
                 return first_output.content[0].get("text", "")
             return str(first_output.content)
         return None
+
+    async def _get_valid_token_for_server(
+        self,
+        server: MCPServer,
+        user_id: int,
+    ) -> AssistantMCPUserToken | None:
+        """Get a valid OAuth token for the given server and user.
+
+        Refreshes the token if expired and refresh token is available.
+
+        Args:
+            server: The MCP server configuration.
+            user_id: The user's ID.
+
+        Returns:
+            A valid token or None if not available.
+        """
+        if server.id is None:
+            return None
+
+        with rx.session() as session:
+            token = self._mcp_auth_service.get_user_token(session, user_id, server.id)
+
+            if token is None:
+                return None
+
+            # Check if token is valid or can be refreshed
+            return await self._mcp_auth_service.ensure_valid_token(
+                session, server, token
+            )
+
+    async def _create_auth_required_chunk(self, server: MCPServer) -> Chunk:
+        """Create an AUTH_REQUIRED chunk for a server that needs authentication.
+
+        Args:
+            server: The MCP server requiring authentication.
+
+        Returns:
+            A chunk signaling auth is required with the auth URL.
+        """
+        # Build the authorization URL
+        try:
+            # We use a session to store the PKCE state
+            # NOTE: rx.session() is for Reflex user session, not DB session.
+            # We use get_session_manager().session() for DB access required by PKCE.
+            with get_session_manager().session() as session:
+                # Use the async method that supports DCR
+                auth_service = self._mcp_auth_service
+                (
+                    auth_url,
+                    state,
+                ) = await auth_service.build_authorization_url_with_registration(
+                    server,
+                    session=session,
+                    user_id=self._current_user_id,
+                )
+            logger.info(
+                "Built auth URL for server %s, state=%s, url=%s",
+                server.name,
+                state,
+                auth_url[:100] if auth_url else "None",
+            )
+        except (ValueError, Exception) as e:
+            logger.error("Cannot build auth URL for server %s: %s", server.name, str(e))
+            auth_url = ""
+            state = ""
+
+        return Chunk(
+            type=ChunkType.AUTH_REQUIRED,
+            text=f"{server.name} benötigt Ihre Autorisierung",
+            chunk_metadata={
+                "server_id": str(server.id) if server.id else "",
+                "server_name": server.name,
+                "auth_url": auth_url,
+                "state": state,
+                "processor": "openai_responses",
+            },
+        )
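`build_authorization_url_with_registration` lives in the new `mcp_auth_service.py` (796 added lines, not expanded in this diff). For orientation only, a minimal sketch of what assembling an OAuth authorization URL with PKCE typically involves; the function name, parameters, and endpoint layout here are assumptions, not the service's actual code:

```python
import base64
import hashlib
import secrets
from urllib.parse import urlencode


def build_authorization_url(
    authorize_url: str, client_id: str, redirect_uri: str
) -> tuple[str, str, str]:
    """Return (auth_url, state, code_verifier) for an OAuth + PKCE flow."""
    state = secrets.token_urlsafe(32)
    code_verifier = secrets.token_urlsafe(64)
    # PKCE S256 challenge: BASE64URL(SHA256(verifier)) with padding stripped.
    digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
    code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")

    params = {
        "response_type": "code",
        "client_id": client_id,
        "redirect_uri": redirect_uri,
        "state": state,
        "code_challenge": code_challenge,
        "code_challenge_method": "S256",
    }
    return f"{authorize_url}?{urlencode(params)}", state, code_verifier
```

The `state` and `code_verifier` must be persisted server-side (here, via the database session mentioned in the comments above) so the callback handler can validate the response and complete the token exchange.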
appkit_assistant/backend/system_prompt_cache.py

@@ -43,7 +43,7 @@ class SystemPromptCache:
         self._ttl_seconds: int = CACHE_TTL_SECONDS
         self._initialized = True
 
-        logger.
+        logger.debug(
             "SystemPromptCache initialized with TTL=%d seconds",
             self._ttl_seconds,
         )
@@ -80,7 +80,7 @@ class SystemPromptCache:
             return self._cached_prompt
 
         # Cache miss or expired - fetch from database
-        logger.
+        logger.debug("Cache miss - fetching latest prompt from database")
 
         async with get_asyncdb_session() as session:
             latest_prompt = await system_prompt_repo.find_latest(session)
@@ -102,7 +102,7 @@ class SystemPromptCache:
             self._cached_version = prompt_version
             self._cache_timestamp = datetime.now(UTC)
 
-            logger.
+            logger.debug(
                 "Cached prompt version %d (%d characters)",
                 self._cached_version,
                 len(self._cached_prompt),
@@ -118,7 +118,7 @@ class SystemPromptCache:
         """
         async with self._lock:
             if self._cached_prompt is not None:
-                logger.
+                logger.debug(
                     "Cache invalidated (was version %d)",
                     self._cached_version,
                 )
@@ -135,7 +135,7 @@ class SystemPromptCache:
             seconds: New TTL in seconds.
         """
         self._ttl_seconds = seconds
-        logger.
+        logger.debug("Cache TTL updated to %d seconds", seconds)
 
     @property
     def is_cached(self) -> bool: