appkit-assistant 0.14.1__py3-none-any.whl → 0.15.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- appkit_assistant/backend/mcp_auth_service.py +796 -0
- appkit_assistant/backend/model_manager.py +2 -1
- appkit_assistant/backend/models.py +43 -0
- appkit_assistant/backend/processors/openai_responses_processor.py +265 -36
- appkit_assistant/backend/repositories.py +1 -1
- appkit_assistant/backend/system_prompt_cache.py +5 -5
- appkit_assistant/components/mcp_server_dialogs.py +327 -21
- appkit_assistant/components/message.py +62 -0
- appkit_assistant/components/thread.py +99 -1
- appkit_assistant/state/mcp_server_state.py +42 -1
- appkit_assistant/state/system_prompt_state.py +4 -4
- appkit_assistant/state/thread_list_state.py +5 -5
- appkit_assistant/state/thread_state.py +190 -28
- {appkit_assistant-0.14.1.dist-info → appkit_assistant-0.15.1.dist-info}/METADATA +1 -1
- appkit_assistant-0.15.1.dist-info/RECORD +29 -0
- appkit_assistant-0.14.1.dist-info/RECORD +0 -28
- {appkit_assistant-0.14.1.dist-info → appkit_assistant-0.15.1.dist-info}/WHEEL +0 -0
|
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
4
|
import threading
|
|
5
|
+
from typing import Self
|
|
5
6
|
|
|
6
7
|
from appkit_assistant.backend.models import AIModel
|
|
7
8
|
from appkit_assistant.backend.processor import Processor
|
|
@@ -18,7 +19,7 @@ class ModelManager:
|
|
|
18
19
|
None # Default model ID will be set to the first registered model
|
|
19
20
|
)
|
|
20
21
|
|
|
21
|
-
def __new__(cls) ->
|
|
22
|
+
def __new__(cls) -> Self:
|
|
22
23
|
if cls._instance is None:
|
|
23
24
|
with cls._lock:
|
|
24
25
|
if cls._instance is None:
|
|
@@ -44,6 +44,7 @@ class ChunkType(StrEnum):
|
|
|
44
44
|
TOOL_RESULT = "tool_result" # result from a tool
|
|
45
45
|
TOOL_CALL = "tool_call" # calling a tool
|
|
46
46
|
COMPLETION = "completion" # when response generation is complete
|
|
47
|
+
AUTH_REQUIRED = "auth_required" # user needs to authenticate (MCP)
|
|
47
48
|
ERROR = "error" # when an error occurs
|
|
48
49
|
LIFECYCLE = "lifecycle"
|
|
49
50
|
|
|
@@ -100,6 +101,7 @@ class AIModel(BaseModel):
|
|
|
100
101
|
supports_attachments: bool = False
|
|
101
102
|
keywords: list[str] = []
|
|
102
103
|
disabled: bool = False
|
|
104
|
+
requires_role: str | None = None
|
|
103
105
|
|
|
104
106
|
|
|
105
107
|
class Suggestion(BaseModel):
|
|
@@ -143,6 +145,12 @@ class MCPServer(rx.Model, table=True):
|
|
|
143
145
|
# Optional discovery URL override
|
|
144
146
|
discovery_url: str | None = Field(default=None, nullable=True)
|
|
145
147
|
|
|
148
|
+
# OAuth client credentials (encrypted)
|
|
149
|
+
oauth_client_id: str | None = Field(default=None, nullable=True)
|
|
150
|
+
oauth_client_secret: str | None = Field(
|
|
151
|
+
default=None, nullable=True, sa_type=EncryptedString
|
|
152
|
+
)
|
|
153
|
+
|
|
146
154
|
# Cached OAuth/Discovery metadata (read-only for user mostly)
|
|
147
155
|
oauth_issuer: str | None = Field(default=None, nullable=True)
|
|
148
156
|
oauth_authorize_url: str | None = Field(default=None, nullable=True)
|
|
@@ -194,3 +202,38 @@ class AssistantThread(rx.Model, table=True):
|
|
|
194
202
|
default_factory=lambda: datetime.now(UTC),
|
|
195
203
|
sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
|
|
196
204
|
)
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
class AssistantMCPUserToken(rx.Model, table=True):
|
|
208
|
+
"""Model for storing user-specific OAuth tokens for MCP servers.
|
|
209
|
+
|
|
210
|
+
Each user can have one token per MCP server. Tokens are encrypted at rest.
|
|
211
|
+
"""
|
|
212
|
+
|
|
213
|
+
__tablename__ = "assistant_mcp_user_token"
|
|
214
|
+
|
|
215
|
+
id: int | None = Field(default=None, primary_key=True)
|
|
216
|
+
user_id: int = Field(index=True, nullable=False)
|
|
217
|
+
mcp_server_id: int = Field(
|
|
218
|
+
index=True, nullable=False, foreign_key="assistant_mcp_servers.id"
|
|
219
|
+
)
|
|
220
|
+
|
|
221
|
+
# Tokens are encrypted at rest
|
|
222
|
+
access_token: str = Field(nullable=False, sa_type=EncryptedString)
|
|
223
|
+
refresh_token: str | None = Field(
|
|
224
|
+
default=None, nullable=True, sa_type=EncryptedString
|
|
225
|
+
)
|
|
226
|
+
|
|
227
|
+
# Token expiry timestamp
|
|
228
|
+
expires_at: datetime = Field(
|
|
229
|
+
sa_column=Column(DateTime(timezone=True), nullable=False)
|
|
230
|
+
)
|
|
231
|
+
|
|
232
|
+
created_at: datetime = Field(
|
|
233
|
+
default_factory=lambda: datetime.now(UTC),
|
|
234
|
+
sa_column=Column(DateTime(timezone=True)),
|
|
235
|
+
)
|
|
236
|
+
updated_at: datetime = Field(
|
|
237
|
+
default_factory=lambda: datetime.now(UTC),
|
|
238
|
+
sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
|
|
239
|
+
)
|
|
@@ -3,16 +3,22 @@ import logging
|
|
|
3
3
|
from collections.abc import AsyncGenerator
|
|
4
4
|
from typing import Any
|
|
5
5
|
|
|
6
|
+
import reflex as rx
|
|
7
|
+
|
|
8
|
+
from appkit_assistant.backend.mcp_auth_service import MCPAuthService
|
|
6
9
|
from appkit_assistant.backend.models import (
|
|
7
10
|
AIModel,
|
|
11
|
+
AssistantMCPUserToken,
|
|
8
12
|
Chunk,
|
|
9
13
|
ChunkType,
|
|
14
|
+
MCPAuthType,
|
|
10
15
|
MCPServer,
|
|
11
16
|
Message,
|
|
12
17
|
MessageType,
|
|
13
18
|
)
|
|
14
19
|
from appkit_assistant.backend.processors.openai_base import BaseOpenAIProcessor
|
|
15
20
|
from appkit_assistant.backend.system_prompt_cache import get_system_prompt
|
|
21
|
+
from appkit_commons.database.session import get_session_manager
|
|
16
22
|
|
|
17
23
|
logger = logging.getLogger(__name__)
|
|
18
24
|
|
|
@@ -26,9 +32,13 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
|
|
|
26
32
|
api_key: str | None = None,
|
|
27
33
|
base_url: str | None = None,
|
|
28
34
|
is_azure: bool = False,
|
|
35
|
+
oauth_redirect_uri: str = "",
|
|
29
36
|
) -> None:
|
|
30
37
|
super().__init__(models, api_key, base_url, is_azure)
|
|
31
38
|
self._current_reasoning_session: str | None = None
|
|
39
|
+
self._current_user_id: int | None = None
|
|
40
|
+
self._mcp_auth_service = MCPAuthService(redirect_uri=oauth_redirect_uri)
|
|
41
|
+
self._pending_auth_servers: list[MCPServer] = []
|
|
32
42
|
|
|
33
43
|
async def process(
|
|
34
44
|
self,
|
|
@@ -37,6 +47,7 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
|
|
|
37
47
|
files: list[str] | None = None, # noqa: ARG002
|
|
38
48
|
mcp_servers: list[MCPServer] | None = None,
|
|
39
49
|
payload: dict[str, Any] | None = None,
|
|
50
|
+
user_id: int | None = None,
|
|
40
51
|
) -> AsyncGenerator[Chunk, None]:
|
|
41
52
|
"""Process messages using simplified content accumulator pattern."""
|
|
42
53
|
if not self.client:
|
|
@@ -47,29 +58,45 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
|
|
|
47
58
|
raise ValueError(msg)
|
|
48
59
|
|
|
49
60
|
model = self.models[model_id]
|
|
61
|
+
self._current_user_id = user_id
|
|
62
|
+
self._pending_auth_servers = []
|
|
50
63
|
|
|
51
64
|
try:
|
|
52
65
|
session = await self._create_responses_request(
|
|
53
|
-
messages, model, mcp_servers, payload
|
|
66
|
+
messages, model, mcp_servers, payload, user_id
|
|
54
67
|
)
|
|
55
68
|
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
69
|
+
try:
|
|
70
|
+
if hasattr(session, "__aiter__"): # Streaming
|
|
71
|
+
async for event in session:
|
|
72
|
+
chunk = self._handle_event(event)
|
|
73
|
+
if chunk:
|
|
74
|
+
yield chunk
|
|
75
|
+
else: # Non-streaming
|
|
76
|
+
content = self._extract_responses_content(session)
|
|
77
|
+
if content:
|
|
78
|
+
yield Chunk(
|
|
79
|
+
type=ChunkType.TEXT,
|
|
80
|
+
text=content,
|
|
81
|
+
chunk_metadata={
|
|
82
|
+
"source": "responses_api",
|
|
83
|
+
"streaming": "false",
|
|
84
|
+
},
|
|
85
|
+
)
|
|
86
|
+
except Exception as e:
|
|
87
|
+
logger.error("Error during response processing: %s", e)
|
|
88
|
+
# Continue to yield auth chunks if any
|
|
89
|
+
|
|
90
|
+
# After processing (or on error), yield any pending auth requirements
|
|
91
|
+
logger.debug(
|
|
92
|
+
"Processing pending auth servers: %d", len(self._pending_auth_servers)
|
|
93
|
+
)
|
|
94
|
+
for server in self._pending_auth_servers:
|
|
95
|
+
logger.debug("Yielding auth chunk for server: %s", server.name)
|
|
96
|
+
yield await self._create_auth_required_chunk(server)
|
|
97
|
+
|
|
72
98
|
except Exception as e:
|
|
99
|
+
logger.error("Critical error in OpenAI processor: %s", e)
|
|
73
100
|
raise e
|
|
74
101
|
|
|
75
102
|
def _handle_event(self, event: Any) -> Chunk | None:
|
|
@@ -78,9 +105,6 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
|
|
|
78
105
|
return None
|
|
79
106
|
|
|
80
107
|
event_type = event.type
|
|
81
|
-
logger.debug("Event: %s", event)
|
|
82
|
-
|
|
83
|
-
# Try different handlers in order
|
|
84
108
|
handlers = [
|
|
85
109
|
self._handle_lifecycle_events,
|
|
86
110
|
lambda et: self._handle_text_events(et, event),
|
|
@@ -94,16 +118,15 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
|
|
|
94
118
|
for handler in handlers:
|
|
95
119
|
result = handler(event_type)
|
|
96
120
|
if result:
|
|
97
|
-
content_preview = result.text[:50] if result.text else ""
|
|
98
|
-
logger.
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
)
|
|
121
|
+
# content_preview = result.text[:50] if result.text else ""
|
|
122
|
+
# logger.debug(
|
|
123
|
+
# "Event %s → Chunk: type=%s, content=%s",
|
|
124
|
+
# event_type,
|
|
125
|
+
# result.type,
|
|
126
|
+
# content_preview,
|
|
127
|
+
# )
|
|
104
128
|
return result
|
|
105
129
|
|
|
106
|
-
# Log unhandled events for debugging
|
|
107
130
|
logger.debug("Unhandled event type: %s", event_type)
|
|
108
131
|
return None
|
|
109
132
|
|
|
@@ -205,14 +228,36 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
|
|
|
205
228
|
return None
|
|
206
229
|
|
|
207
230
|
def _handle_mcp_call_done(self, item: Any) -> Chunk | None:
|
|
208
|
-
"""Handle MCP call completion.
|
|
231
|
+
"""Handle MCP call completion.
|
|
232
|
+
|
|
233
|
+
Detects 401/403 authentication errors and marks servers for auth flow.
|
|
234
|
+
"""
|
|
209
235
|
tool_id = getattr(item, "id", "unknown_id")
|
|
210
236
|
tool_name = getattr(item, "name", "unknown_tool")
|
|
237
|
+
server_label = getattr(item, "server_label", "unknown_server")
|
|
211
238
|
error = getattr(item, "error", None)
|
|
212
239
|
output = getattr(item, "output", None)
|
|
213
240
|
|
|
214
241
|
if error:
|
|
215
242
|
error_text = self._extract_error_text(error)
|
|
243
|
+
|
|
244
|
+
# Check for authentication errors (401/403)
|
|
245
|
+
if self._is_auth_error(error):
|
|
246
|
+
# Find the server config and queue for auth flow
|
|
247
|
+
return self._create_chunk(
|
|
248
|
+
ChunkType.TOOL_RESULT,
|
|
249
|
+
f"Authentifizierung erforderlich für {server_label}",
|
|
250
|
+
{
|
|
251
|
+
"tool_id": tool_id,
|
|
252
|
+
"tool_name": tool_name,
|
|
253
|
+
"server_label": server_label,
|
|
254
|
+
"status": "auth_required",
|
|
255
|
+
"error": True,
|
|
256
|
+
"auth_required": True,
|
|
257
|
+
"reasoning_session": self._current_reasoning_session,
|
|
258
|
+
},
|
|
259
|
+
)
|
|
260
|
+
|
|
216
261
|
return self._create_chunk(
|
|
217
262
|
ChunkType.TOOL_RESULT,
|
|
218
263
|
f"Werkzeugfehler: {error_text}",
|
|
@@ -238,6 +283,21 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
|
|
|
238
283
|
},
|
|
239
284
|
)
|
|
240
285
|
|
|
286
|
+
def _is_auth_error(self, error: Any) -> bool:
|
|
287
|
+
"""Check if an error indicates authentication failure (401/403)."""
|
|
288
|
+
error_str = str(error).lower()
|
|
289
|
+
auth_indicators = [
|
|
290
|
+
"401",
|
|
291
|
+
"403",
|
|
292
|
+
"unauthorized",
|
|
293
|
+
"forbidden",
|
|
294
|
+
"authentication required",
|
|
295
|
+
"access denied",
|
|
296
|
+
"invalid token",
|
|
297
|
+
"token expired",
|
|
298
|
+
]
|
|
299
|
+
return any(indicator in error_str for indicator in auth_indicators)
|
|
300
|
+
|
|
241
301
|
def _extract_error_text(self, error: Any) -> str:
|
|
242
302
|
"""Extract readable error text from error object."""
|
|
243
303
|
if isinstance(error, dict) and "content" in error:
|
|
@@ -246,7 +306,7 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
|
|
|
246
306
|
return content[0].get("text", str(error))
|
|
247
307
|
return "Unknown error"
|
|
248
308
|
|
|
249
|
-
def _handle_mcp_events(self, event_type: str, event: Any) -> Chunk | None:
|
|
309
|
+
def _handle_mcp_events(self, event_type: str, event: Any) -> Chunk | None: # noqa: PLR0911, PLR0912
|
|
250
310
|
"""Handle MCP-specific events."""
|
|
251
311
|
if event_type == "response.mcp_call_arguments.delta":
|
|
252
312
|
tool_id = getattr(event, "item_id", "unknown_id")
|
|
@@ -315,6 +375,70 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
|
|
|
315
375
|
|
|
316
376
|
if event_type == "response.mcp_list_tools.failed":
|
|
317
377
|
tool_id = getattr(event, "item_id", "unknown_id")
|
|
378
|
+
error = getattr(event, "error", None)
|
|
379
|
+
|
|
380
|
+
# Debugging: Log available attributes to help diagnosis
|
|
381
|
+
if logger.isEnabledFor(logging.DEBUG) and error:
|
|
382
|
+
logger.debug("Error object type: %s, content: %s", type(error), error)
|
|
383
|
+
|
|
384
|
+
# Extract error message safely
|
|
385
|
+
error_str = ""
|
|
386
|
+
if error:
|
|
387
|
+
if isinstance(error, dict):
|
|
388
|
+
error_str = error.get("message", str(error))
|
|
389
|
+
elif hasattr(error, "message"):
|
|
390
|
+
error_str = getattr(error, "message", str(error))
|
|
391
|
+
else:
|
|
392
|
+
error_str = str(error)
|
|
393
|
+
|
|
394
|
+
# Check for authentication errors (401/403)
|
|
395
|
+
# OR if we have pending auth servers (strong signal we missed a token)
|
|
396
|
+
is_auth_error = self._is_auth_error(error_str)
|
|
397
|
+
pending_server = None
|
|
398
|
+
|
|
399
|
+
# 1. Try to find matching server by name in error message
|
|
400
|
+
for server in self._pending_auth_servers:
|
|
401
|
+
if server.name.lower() in error_str.lower():
|
|
402
|
+
pending_server = server
|
|
403
|
+
break
|
|
404
|
+
|
|
405
|
+
# 2. If no match but we have pending servers and it looks like an auth error
|
|
406
|
+
# OR if we have pending servers and likely one of them failed (len=1)
|
|
407
|
+
# We assume the failure belongs to the pending server if we can't be sure
|
|
408
|
+
if (
|
|
409
|
+
not pending_server
|
|
410
|
+
and self._pending_auth_servers
|
|
411
|
+
and (is_auth_error or len(self._pending_auth_servers) == 1)
|
|
412
|
+
):
|
|
413
|
+
pending_server = self._pending_auth_servers[0]
|
|
414
|
+
logger.debug(
|
|
415
|
+
"Assuming pending server %s for list_tools failure '%s'",
|
|
416
|
+
pending_server.name,
|
|
417
|
+
error_str,
|
|
418
|
+
)
|
|
419
|
+
|
|
420
|
+
if pending_server:
|
|
421
|
+
logger.debug(
|
|
422
|
+
"Queuing Auth Card for server: %s (Error: %s)",
|
|
423
|
+
pending_server.name,
|
|
424
|
+
error_str,
|
|
425
|
+
)
|
|
426
|
+
# Queue for async processing in the main process loop
|
|
427
|
+
# The auth chunk will be yielded after event processing completes
|
|
428
|
+
if pending_server not in self._pending_auth_servers:
|
|
429
|
+
self._pending_auth_servers.append(pending_server)
|
|
430
|
+
return self._create_chunk(
|
|
431
|
+
ChunkType.TOOL_RESULT,
|
|
432
|
+
f"Authentifizierung erforderlich für {pending_server.name}",
|
|
433
|
+
{
|
|
434
|
+
"tool_id": tool_id,
|
|
435
|
+
"status": "auth_required",
|
|
436
|
+
"server_name": pending_server.name,
|
|
437
|
+
"auth_pending": True,
|
|
438
|
+
"reasoning_session": self._current_reasoning_session,
|
|
439
|
+
},
|
|
440
|
+
)
|
|
441
|
+
|
|
318
442
|
logger.error("MCP tool listing failed for tool_id: %s", str(event))
|
|
319
443
|
return self._create_chunk(
|
|
320
444
|
ChunkType.TOOL_RESULT,
|
|
@@ -396,11 +520,14 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
|
|
|
396
520
|
model: AIModel,
|
|
397
521
|
mcp_servers: list[MCPServer] | None = None,
|
|
398
522
|
payload: dict[str, Any] | None = None,
|
|
523
|
+
user_id: int | None = None,
|
|
399
524
|
) -> Any:
|
|
400
525
|
"""Create a simplified responses API request."""
|
|
401
|
-
# Configure MCP tools if provided
|
|
526
|
+
# Configure MCP tools if provided (now async for token lookup)
|
|
402
527
|
tools, mcp_prompt = (
|
|
403
|
-
self._configure_mcp_tools(mcp_servers
|
|
528
|
+
await self._configure_mcp_tools(mcp_servers, user_id)
|
|
529
|
+
if mcp_servers
|
|
530
|
+
else ([], "")
|
|
404
531
|
)
|
|
405
532
|
|
|
406
533
|
# Convert messages to responses format with system message
|
|
@@ -418,14 +545,17 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
|
|
|
418
545
|
**(payload or {}),
|
|
419
546
|
}
|
|
420
547
|
|
|
421
|
-
logger.debug("Responses API request params: %s", params)
|
|
422
548
|
return await self.client.responses.create(**params)
|
|
423
549
|
|
|
424
|
-
def _configure_mcp_tools(
|
|
425
|
-
self,
|
|
550
|
+
async def _configure_mcp_tools(
|
|
551
|
+
self,
|
|
552
|
+
mcp_servers: list[MCPServer] | None,
|
|
553
|
+
user_id: int | None = None,
|
|
426
554
|
) -> tuple[list[dict[str, Any]], str]:
|
|
427
555
|
"""Configure MCP servers as tools for the responses API.
|
|
428
556
|
|
|
557
|
+
Injects OAuth Bearer tokens for servers that require authentication.
|
|
558
|
+
|
|
429
559
|
Returns:
|
|
430
560
|
tuple: (tools list, concatenated prompts string)
|
|
431
561
|
"""
|
|
@@ -434,6 +564,7 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
|
|
|
434
564
|
|
|
435
565
|
tools = []
|
|
436
566
|
prompts = []
|
|
567
|
+
|
|
437
568
|
for server in mcp_servers:
|
|
438
569
|
tool_config = {
|
|
439
570
|
"type": "mcp",
|
|
@@ -442,8 +573,28 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
|
|
|
442
573
|
"require_approval": "never",
|
|
443
574
|
}
|
|
444
575
|
|
|
576
|
+
# Start with existing headers
|
|
577
|
+
headers = {}
|
|
445
578
|
if server.headers and server.headers != "{}":
|
|
446
|
-
|
|
579
|
+
headers = json.loads(server.headers)
|
|
580
|
+
|
|
581
|
+
# Inject OAuth token if server requires OAuth and user is authenticated
|
|
582
|
+
if server.auth_type == MCPAuthType.OAUTH_DISCOVERY and user_id is not None:
|
|
583
|
+
token = await self._get_valid_token_for_server(server, user_id)
|
|
584
|
+
if token:
|
|
585
|
+
headers["Authorization"] = f"Bearer {token.access_token}"
|
|
586
|
+
logger.debug("Injected OAuth token for server %s", server.name)
|
|
587
|
+
else:
|
|
588
|
+
# No valid token - server will likely fail with 401
|
|
589
|
+
# Track this server for potential auth flow
|
|
590
|
+
self._pending_auth_servers.append(server)
|
|
591
|
+
logger.debug(
|
|
592
|
+
"No valid token for OAuth server %s, auth may be required",
|
|
593
|
+
server.name,
|
|
594
|
+
)
|
|
595
|
+
|
|
596
|
+
if headers:
|
|
597
|
+
tool_config["headers"] = headers
|
|
447
598
|
|
|
448
599
|
tools.append(tool_config)
|
|
449
600
|
|
|
@@ -511,3 +662,81 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
|
|
|
511
662
|
return first_output.content[0].get("text", "")
|
|
512
663
|
return str(first_output.content)
|
|
513
664
|
return None
|
|
665
|
+
|
|
666
|
+
async def _get_valid_token_for_server(
|
|
667
|
+
self,
|
|
668
|
+
server: MCPServer,
|
|
669
|
+
user_id: int,
|
|
670
|
+
) -> AssistantMCPUserToken | None:
|
|
671
|
+
"""Get a valid OAuth token for the given server and user.
|
|
672
|
+
|
|
673
|
+
Refreshes the token if expired and refresh token is available.
|
|
674
|
+
|
|
675
|
+
Args:
|
|
676
|
+
server: The MCP server configuration.
|
|
677
|
+
user_id: The user's ID.
|
|
678
|
+
|
|
679
|
+
Returns:
|
|
680
|
+
A valid token or None if not available.
|
|
681
|
+
"""
|
|
682
|
+
if server.id is None:
|
|
683
|
+
return None
|
|
684
|
+
|
|
685
|
+
with rx.session() as session:
|
|
686
|
+
token = self._mcp_auth_service.get_user_token(session, user_id, server.id)
|
|
687
|
+
|
|
688
|
+
if token is None:
|
|
689
|
+
return None
|
|
690
|
+
|
|
691
|
+
# Check if token is valid or can be refreshed
|
|
692
|
+
return await self._mcp_auth_service.ensure_valid_token(
|
|
693
|
+
session, server, token
|
|
694
|
+
)
|
|
695
|
+
|
|
696
|
+
async def _create_auth_required_chunk(self, server: MCPServer) -> Chunk:
|
|
697
|
+
"""Create an AUTH_REQUIRED chunk for a server that needs authentication.
|
|
698
|
+
|
|
699
|
+
Args:
|
|
700
|
+
server: The MCP server requiring authentication.
|
|
701
|
+
|
|
702
|
+
Returns:
|
|
703
|
+
A chunk signaling auth is required with the auth URL.
|
|
704
|
+
"""
|
|
705
|
+
# Build the authorization URL
|
|
706
|
+
try:
|
|
707
|
+
# We use a session to store the PKCE state
|
|
708
|
+
# NOTE: rx.session() is for Reflex user session, not DB session.
|
|
709
|
+
# We use get_session_manager().session() for DB access required by PKCE.
|
|
710
|
+
with get_session_manager().session() as session:
|
|
711
|
+
# Use the async method that supports DCR
|
|
712
|
+
auth_service = self._mcp_auth_service
|
|
713
|
+
(
|
|
714
|
+
auth_url,
|
|
715
|
+
state,
|
|
716
|
+
) = await auth_service.build_authorization_url_with_registration(
|
|
717
|
+
server,
|
|
718
|
+
session=session,
|
|
719
|
+
user_id=self._current_user_id,
|
|
720
|
+
)
|
|
721
|
+
logger.info(
|
|
722
|
+
"Built auth URL for server %s, state=%s, url=%s",
|
|
723
|
+
server.name,
|
|
724
|
+
state,
|
|
725
|
+
auth_url[:100] if auth_url else "None",
|
|
726
|
+
)
|
|
727
|
+
except (ValueError, Exception) as e:
|
|
728
|
+
logger.error("Cannot build auth URL for server %s: %s", server.name, str(e))
|
|
729
|
+
auth_url = ""
|
|
730
|
+
state = ""
|
|
731
|
+
|
|
732
|
+
return Chunk(
|
|
733
|
+
type=ChunkType.AUTH_REQUIRED,
|
|
734
|
+
text=f"{server.name} benötigt Ihre Autorisierung",
|
|
735
|
+
chunk_metadata={
|
|
736
|
+
"server_id": str(server.id) if server.id else "",
|
|
737
|
+
"server_name": server.name,
|
|
738
|
+
"auth_url": auth_url,
|
|
739
|
+
"state": state,
|
|
740
|
+
"processor": "openai_responses",
|
|
741
|
+
},
|
|
742
|
+
)
|
|
@@ -43,7 +43,7 @@ class SystemPromptCache:
|
|
|
43
43
|
self._ttl_seconds: int = CACHE_TTL_SECONDS
|
|
44
44
|
self._initialized = True
|
|
45
45
|
|
|
46
|
-
logger.
|
|
46
|
+
logger.debug(
|
|
47
47
|
"SystemPromptCache initialized with TTL=%d seconds",
|
|
48
48
|
self._ttl_seconds,
|
|
49
49
|
)
|
|
@@ -80,7 +80,7 @@ class SystemPromptCache:
|
|
|
80
80
|
return self._cached_prompt
|
|
81
81
|
|
|
82
82
|
# Cache miss or expired - fetch from database
|
|
83
|
-
logger.
|
|
83
|
+
logger.debug("Cache miss - fetching latest prompt from database")
|
|
84
84
|
|
|
85
85
|
async with get_asyncdb_session() as session:
|
|
86
86
|
latest_prompt = await system_prompt_repo.find_latest(session)
|
|
@@ -102,7 +102,7 @@ class SystemPromptCache:
|
|
|
102
102
|
self._cached_version = prompt_version
|
|
103
103
|
self._cache_timestamp = datetime.now(UTC)
|
|
104
104
|
|
|
105
|
-
logger.
|
|
105
|
+
logger.debug(
|
|
106
106
|
"Cached prompt version %d (%d characters)",
|
|
107
107
|
self._cached_version,
|
|
108
108
|
len(self._cached_prompt),
|
|
@@ -118,7 +118,7 @@ class SystemPromptCache:
|
|
|
118
118
|
"""
|
|
119
119
|
async with self._lock:
|
|
120
120
|
if self._cached_prompt is not None:
|
|
121
|
-
logger.
|
|
121
|
+
logger.debug(
|
|
122
122
|
"Cache invalidated (was version %d)",
|
|
123
123
|
self._cached_version,
|
|
124
124
|
)
|
|
@@ -135,7 +135,7 @@ class SystemPromptCache:
|
|
|
135
135
|
seconds: New TTL in seconds.
|
|
136
136
|
"""
|
|
137
137
|
self._ttl_seconds = seconds
|
|
138
|
-
logger.
|
|
138
|
+
logger.debug("Cache TTL updated to %d seconds", seconds)
|
|
139
139
|
|
|
140
140
|
@property
|
|
141
141
|
def is_cached(self) -> bool:
|