appkit-assistant 0.14.1__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,6 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import logging
4
4
  import threading
5
+ from typing import Self
5
6
 
6
7
  from appkit_assistant.backend.models import AIModel
7
8
  from appkit_assistant.backend.processor import Processor
@@ -18,7 +19,7 @@ class ModelManager:
18
19
  None # Default model ID will be set to the first registered model
19
20
  )
20
21
 
21
- def __new__(cls) -> ModelManager:
22
+ def __new__(cls) -> Self:
22
23
  if cls._instance is None:
23
24
  with cls._lock:
24
25
  if cls._instance is None:
@@ -44,6 +44,7 @@ class ChunkType(StrEnum):
44
44
  TOOL_RESULT = "tool_result" # result from a tool
45
45
  TOOL_CALL = "tool_call" # calling a tool
46
46
  COMPLETION = "completion" # when response generation is complete
47
+ AUTH_REQUIRED = "auth_required" # user needs to authenticate (MCP)
47
48
  ERROR = "error" # when an error occurs
48
49
  LIFECYCLE = "lifecycle"
49
50
 
@@ -100,6 +101,7 @@ class AIModel(BaseModel):
100
101
  supports_attachments: bool = False
101
102
  keywords: list[str] = []
102
103
  disabled: bool = False
104
+ requires_role: str | None = None
103
105
 
104
106
 
105
107
  class Suggestion(BaseModel):
@@ -143,6 +145,12 @@ class MCPServer(rx.Model, table=True):
143
145
  # Optional discovery URL override
144
146
  discovery_url: str | None = Field(default=None, nullable=True)
145
147
 
148
+ # OAuth client credentials (encrypted)
149
+ oauth_client_id: str | None = Field(default=None, nullable=True)
150
+ oauth_client_secret: str | None = Field(
151
+ default=None, nullable=True, sa_type=EncryptedString
152
+ )
153
+
146
154
  # Cached OAuth/Discovery metadata (read-only for user mostly)
147
155
  oauth_issuer: str | None = Field(default=None, nullable=True)
148
156
  oauth_authorize_url: str | None = Field(default=None, nullable=True)
@@ -194,3 +202,38 @@ class AssistantThread(rx.Model, table=True):
194
202
  default_factory=lambda: datetime.now(UTC),
195
203
  sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
196
204
  )
205
+
206
+
207
+ class AssistantMCPUserToken(rx.Model, table=True):
208
+ """Model for storing user-specific OAuth tokens for MCP servers.
209
+
210
+ Each user can have one token per MCP server. Tokens are encrypted at rest.
211
+ """
212
+
213
+ __tablename__ = "assistant_mcp_user_token"
214
+
215
+ id: int | None = Field(default=None, primary_key=True)
216
+ user_id: int = Field(index=True, nullable=False)
217
+ mcp_server_id: int = Field(
218
+ index=True, nullable=False, foreign_key="assistant_mcp_servers.id"
219
+ )
220
+
221
+ # Tokens are encrypted at rest
222
+ access_token: str = Field(nullable=False, sa_type=EncryptedString)
223
+ refresh_token: str | None = Field(
224
+ default=None, nullable=True, sa_type=EncryptedString
225
+ )
226
+
227
+ # Token expiry timestamp
228
+ expires_at: datetime = Field(
229
+ sa_column=Column(DateTime(timezone=True), nullable=False)
230
+ )
231
+
232
+ created_at: datetime = Field(
233
+ default_factory=lambda: datetime.now(UTC),
234
+ sa_column=Column(DateTime(timezone=True)),
235
+ )
236
+ updated_at: datetime = Field(
237
+ default_factory=lambda: datetime.now(UTC),
238
+ sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
239
+ )
@@ -3,16 +3,22 @@ import logging
3
3
  from collections.abc import AsyncGenerator
4
4
  from typing import Any
5
5
 
6
+ import reflex as rx
7
+
8
+ from appkit_assistant.backend.mcp_auth_service import MCPAuthService
6
9
  from appkit_assistant.backend.models import (
7
10
  AIModel,
11
+ AssistantMCPUserToken,
8
12
  Chunk,
9
13
  ChunkType,
14
+ MCPAuthType,
10
15
  MCPServer,
11
16
  Message,
12
17
  MessageType,
13
18
  )
14
19
  from appkit_assistant.backend.processors.openai_base import BaseOpenAIProcessor
15
20
  from appkit_assistant.backend.system_prompt_cache import get_system_prompt
21
+ from appkit_commons.database.session import get_session_manager
16
22
 
17
23
  logger = logging.getLogger(__name__)
18
24
 
@@ -26,9 +32,13 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
26
32
  api_key: str | None = None,
27
33
  base_url: str | None = None,
28
34
  is_azure: bool = False,
35
+ oauth_redirect_uri: str = "",
29
36
  ) -> None:
30
37
  super().__init__(models, api_key, base_url, is_azure)
31
38
  self._current_reasoning_session: str | None = None
39
+ self._current_user_id: int | None = None
40
+ self._mcp_auth_service = MCPAuthService(redirect_uri=oauth_redirect_uri)
41
+ self._pending_auth_servers: list[MCPServer] = []
32
42
 
33
43
  async def process(
34
44
  self,
@@ -37,6 +47,7 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
37
47
  files: list[str] | None = None, # noqa: ARG002
38
48
  mcp_servers: list[MCPServer] | None = None,
39
49
  payload: dict[str, Any] | None = None,
50
+ user_id: int | None = None,
40
51
  ) -> AsyncGenerator[Chunk, None]:
41
52
  """Process messages using simplified content accumulator pattern."""
42
53
  if not self.client:
@@ -47,29 +58,45 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
47
58
  raise ValueError(msg)
48
59
 
49
60
  model = self.models[model_id]
61
+ self._current_user_id = user_id
62
+ self._pending_auth_servers = []
50
63
 
51
64
  try:
52
65
  session = await self._create_responses_request(
53
- messages, model, mcp_servers, payload
66
+ messages, model, mcp_servers, payload, user_id
54
67
  )
55
68
 
56
- if hasattr(session, "__aiter__"): # Streaming
57
- async for event in session:
58
- chunk = self._handle_event(event)
59
- if chunk:
60
- yield chunk
61
- else: # Non-streaming
62
- content = self._extract_responses_content(session)
63
- if content:
64
- yield Chunk(
65
- type=ChunkType.TEXT,
66
- text=content,
67
- chunk_metadata={
68
- "source": "responses_api",
69
- "streaming": "false",
70
- },
71
- )
69
+ try:
70
+ if hasattr(session, "__aiter__"): # Streaming
71
+ async for event in session:
72
+ chunk = self._handle_event(event)
73
+ if chunk:
74
+ yield chunk
75
+ else: # Non-streaming
76
+ content = self._extract_responses_content(session)
77
+ if content:
78
+ yield Chunk(
79
+ type=ChunkType.TEXT,
80
+ text=content,
81
+ chunk_metadata={
82
+ "source": "responses_api",
83
+ "streaming": "false",
84
+ },
85
+ )
86
+ except Exception as e:
87
+ logger.error("Error during response processing: %s", e)
88
+ # Continue to yield auth chunks if any
89
+
90
+ # After processing (or on error), yield any pending auth requirements
91
+ logger.debug(
92
+ "Processing pending auth servers: %d", len(self._pending_auth_servers)
93
+ )
94
+ for server in self._pending_auth_servers:
95
+ logger.debug("Yielding auth chunk for server: %s", server.name)
96
+ yield await self._create_auth_required_chunk(server)
97
+
72
98
  except Exception as e:
99
+ logger.error("Critical error in OpenAI processor: %s", e)
73
100
  raise e
74
101
 
75
102
  def _handle_event(self, event: Any) -> Chunk | None:
@@ -95,7 +122,7 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
95
122
  result = handler(event_type)
96
123
  if result:
97
124
  content_preview = result.text[:50] if result.text else ""
98
- logger.info(
125
+ logger.debug(
99
126
  "Event %s → Chunk: type=%s, content=%s",
100
127
  event_type,
101
128
  result.type,
@@ -205,14 +232,36 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
205
232
  return None
206
233
 
207
234
  def _handle_mcp_call_done(self, item: Any) -> Chunk | None:
208
- """Handle MCP call completion."""
235
+ """Handle MCP call completion.
236
+
237
+ Detects 401/403 authentication errors and marks servers for auth flow.
238
+ """
209
239
  tool_id = getattr(item, "id", "unknown_id")
210
240
  tool_name = getattr(item, "name", "unknown_tool")
241
+ server_label = getattr(item, "server_label", "unknown_server")
211
242
  error = getattr(item, "error", None)
212
243
  output = getattr(item, "output", None)
213
244
 
214
245
  if error:
215
246
  error_text = self._extract_error_text(error)
247
+
248
+ # Check for authentication errors (401/403)
249
+ if self._is_auth_error(error):
250
+ # Find the server config and queue for auth flow
251
+ return self._create_chunk(
252
+ ChunkType.TOOL_RESULT,
253
+ f"Authentifizierung erforderlich für {server_label}",
254
+ {
255
+ "tool_id": tool_id,
256
+ "tool_name": tool_name,
257
+ "server_label": server_label,
258
+ "status": "auth_required",
259
+ "error": True,
260
+ "auth_required": True,
261
+ "reasoning_session": self._current_reasoning_session,
262
+ },
263
+ )
264
+
216
265
  return self._create_chunk(
217
266
  ChunkType.TOOL_RESULT,
218
267
  f"Werkzeugfehler: {error_text}",
@@ -238,6 +287,21 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
238
287
  },
239
288
  )
240
289
 
290
+ def _is_auth_error(self, error: Any) -> bool:
291
+ """Check if an error indicates authentication failure (401/403)."""
292
+ error_str = str(error).lower()
293
+ auth_indicators = [
294
+ "401",
295
+ "403",
296
+ "unauthorized",
297
+ "forbidden",
298
+ "authentication required",
299
+ "access denied",
300
+ "invalid token",
301
+ "token expired",
302
+ ]
303
+ return any(indicator in error_str for indicator in auth_indicators)
304
+
241
305
  def _extract_error_text(self, error: Any) -> str:
242
306
  """Extract readable error text from error object."""
243
307
  if isinstance(error, dict) and "content" in error:
@@ -246,7 +310,7 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
246
310
  return content[0].get("text", str(error))
247
311
  return "Unknown error"
248
312
 
249
- def _handle_mcp_events(self, event_type: str, event: Any) -> Chunk | None:
313
+ def _handle_mcp_events(self, event_type: str, event: Any) -> Chunk | None: # noqa: PLR0911, PLR0912
250
314
  """Handle MCP-specific events."""
251
315
  if event_type == "response.mcp_call_arguments.delta":
252
316
  tool_id = getattr(event, "item_id", "unknown_id")
@@ -315,6 +379,70 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
315
379
 
316
380
  if event_type == "response.mcp_list_tools.failed":
317
381
  tool_id = getattr(event, "item_id", "unknown_id")
382
+ error = getattr(event, "error", None)
383
+
384
+ # Debugging: Log available attributes to help diagnosis
385
+ if logger.isEnabledFor(logging.DEBUG) and error:
386
+ logger.debug("Error object type: %s, content: %s", type(error), error)
387
+
388
+ # Extract error message safely
389
+ error_str = ""
390
+ if error:
391
+ if isinstance(error, dict):
392
+ error_str = error.get("message", str(error))
393
+ elif hasattr(error, "message"):
394
+ error_str = getattr(error, "message", str(error))
395
+ else:
396
+ error_str = str(error)
397
+
398
+ # Check for authentication errors (401/403)
399
+ # OR if we have pending auth servers (strong signal we missed a token)
400
+ is_auth_error = self._is_auth_error(error_str)
401
+ pending_server = None
402
+
403
+ # 1. Try to find matching server by name in error message
404
+ for server in self._pending_auth_servers:
405
+ if server.name.lower() in error_str.lower():
406
+ pending_server = server
407
+ break
408
+
409
+ # 2. If no match but we have pending servers and it looks like an auth error
410
+ # OR if we have pending servers and likely one of them failed (len=1)
411
+ # We assume the failure belongs to the pending server if we can't be sure
412
+ if (
413
+ not pending_server
414
+ and self._pending_auth_servers
415
+ and (is_auth_error or len(self._pending_auth_servers) == 1)
416
+ ):
417
+ pending_server = self._pending_auth_servers[0]
418
+ logger.debug(
419
+ "Assuming pending server %s for list_tools failure '%s'",
420
+ pending_server.name,
421
+ error_str,
422
+ )
423
+
424
+ if pending_server:
425
+ logger.debug(
426
+ "Queuing Auth Card for server: %s (Error: %s)",
427
+ pending_server.name,
428
+ error_str,
429
+ )
430
+ # Queue for async processing in the main process loop
431
+ # The auth chunk will be yielded after event processing completes
432
+ if pending_server not in self._pending_auth_servers:
433
+ self._pending_auth_servers.append(pending_server)
434
+ return self._create_chunk(
435
+ ChunkType.TOOL_RESULT,
436
+ f"Authentifizierung erforderlich für {pending_server.name}",
437
+ {
438
+ "tool_id": tool_id,
439
+ "status": "auth_required",
440
+ "server_name": pending_server.name,
441
+ "auth_pending": True,
442
+ "reasoning_session": self._current_reasoning_session,
443
+ },
444
+ )
445
+
318
446
  logger.error("MCP tool listing failed for tool_id: %s", str(event))
319
447
  return self._create_chunk(
320
448
  ChunkType.TOOL_RESULT,
@@ -396,11 +524,14 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
396
524
  model: AIModel,
397
525
  mcp_servers: list[MCPServer] | None = None,
398
526
  payload: dict[str, Any] | None = None,
527
+ user_id: int | None = None,
399
528
  ) -> Any:
400
529
  """Create a simplified responses API request."""
401
- # Configure MCP tools if provided
530
+ # Configure MCP tools if provided (now async for token lookup)
402
531
  tools, mcp_prompt = (
403
- self._configure_mcp_tools(mcp_servers) if mcp_servers else ([], "")
532
+ await self._configure_mcp_tools(mcp_servers, user_id)
533
+ if mcp_servers
534
+ else ([], "")
404
535
  )
405
536
 
406
537
  # Convert messages to responses format with system message
@@ -421,11 +552,15 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
421
552
  logger.debug("Responses API request params: %s", params)
422
553
  return await self.client.responses.create(**params)
423
554
 
424
- def _configure_mcp_tools(
425
- self, mcp_servers: list[MCPServer] | None
555
+ async def _configure_mcp_tools(
556
+ self,
557
+ mcp_servers: list[MCPServer] | None,
558
+ user_id: int | None = None,
426
559
  ) -> tuple[list[dict[str, Any]], str]:
427
560
  """Configure MCP servers as tools for the responses API.
428
561
 
562
+ Injects OAuth Bearer tokens for servers that require authentication.
563
+
429
564
  Returns:
430
565
  tuple: (tools list, concatenated prompts string)
431
566
  """
@@ -434,6 +569,7 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
434
569
 
435
570
  tools = []
436
571
  prompts = []
572
+
437
573
  for server in mcp_servers:
438
574
  tool_config = {
439
575
  "type": "mcp",
@@ -442,8 +578,28 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
442
578
  "require_approval": "never",
443
579
  }
444
580
 
581
+ # Start with existing headers
582
+ headers = {}
445
583
  if server.headers and server.headers != "{}":
446
- tool_config["headers"] = json.loads(server.headers)
584
+ headers = json.loads(server.headers)
585
+
586
+ # Inject OAuth token if server requires OAuth and user is authenticated
587
+ if server.auth_type == MCPAuthType.OAUTH_DISCOVERY and user_id is not None:
588
+ token = await self._get_valid_token_for_server(server, user_id)
589
+ if token:
590
+ headers["Authorization"] = f"Bearer {token.access_token}"
591
+ logger.debug("Injected OAuth token for server %s", server.name)
592
+ else:
593
+ # No valid token - server will likely fail with 401
594
+ # Track this server for potential auth flow
595
+ self._pending_auth_servers.append(server)
596
+ logger.debug(
597
+ "No valid token for OAuth server %s, auth may be required",
598
+ server.name,
599
+ )
600
+
601
+ if headers:
602
+ tool_config["headers"] = headers
447
603
 
448
604
  tools.append(tool_config)
449
605
 
@@ -511,3 +667,81 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
511
667
  return first_output.content[0].get("text", "")
512
668
  return str(first_output.content)
513
669
  return None
670
+
671
+ async def _get_valid_token_for_server(
672
+ self,
673
+ server: MCPServer,
674
+ user_id: int,
675
+ ) -> AssistantMCPUserToken | None:
676
+ """Get a valid OAuth token for the given server and user.
677
+
678
+ Refreshes the token if expired and refresh token is available.
679
+
680
+ Args:
681
+ server: The MCP server configuration.
682
+ user_id: The user's ID.
683
+
684
+ Returns:
685
+ A valid token or None if not available.
686
+ """
687
+ if server.id is None:
688
+ return None
689
+
690
+ with rx.session() as session:
691
+ token = self._mcp_auth_service.get_user_token(session, user_id, server.id)
692
+
693
+ if token is None:
694
+ return None
695
+
696
+ # Check if token is valid or can be refreshed
697
+ return await self._mcp_auth_service.ensure_valid_token(
698
+ session, server, token
699
+ )
700
+
701
+ async def _create_auth_required_chunk(self, server: MCPServer) -> Chunk:
702
+ """Create an AUTH_REQUIRED chunk for a server that needs authentication.
703
+
704
+ Args:
705
+ server: The MCP server requiring authentication.
706
+
707
+ Returns:
708
+ A chunk signaling auth is required with the auth URL.
709
+ """
710
+ # Build the authorization URL
711
+ try:
712
+ # We use a session to store the PKCE state
713
+ # NOTE: rx.session() is for Reflex user session, not DB session.
714
+ # We use get_session_manager().session() for DB access required by PKCE.
715
+ with get_session_manager().session() as session:
716
+ # Use the async method that supports DCR
717
+ auth_service = self._mcp_auth_service
718
+ (
719
+ auth_url,
720
+ state,
721
+ ) = await auth_service.build_authorization_url_with_registration(
722
+ server,
723
+ session=session,
724
+ user_id=self._current_user_id,
725
+ )
726
+ logger.info(
727
+ "Built auth URL for server %s, state=%s, url=%s",
728
+ server.name,
729
+ state,
730
+ auth_url[:100] if auth_url else "None",
731
+ )
732
+ except (ValueError, Exception) as e:
733
+ logger.error("Cannot build auth URL for server %s: %s", server.name, str(e))
734
+ auth_url = ""
735
+ state = ""
736
+
737
+ return Chunk(
738
+ type=ChunkType.AUTH_REQUIRED,
739
+ text=f"{server.name} benötigt Ihre Autorisierung",
740
+ chunk_metadata={
741
+ "server_id": str(server.id) if server.id else "",
742
+ "server_name": server.name,
743
+ "auth_url": auth_url,
744
+ "state": state,
745
+ "processor": "openai_responses",
746
+ },
747
+ )
@@ -80,7 +80,7 @@ class SystemPromptRepository(BaseRepository[SystemPrompt, AsyncSession]):
80
80
  await session.flush()
81
81
  await session.refresh(system_prompt)
82
82
 
83
- logger.info(
83
+ logger.debug(
84
84
  "Created system prompt version %s for user %s",
85
85
  next_version,
86
86
  user_id,
@@ -43,7 +43,7 @@ class SystemPromptCache:
43
43
  self._ttl_seconds: int = CACHE_TTL_SECONDS
44
44
  self._initialized = True
45
45
 
46
- logger.info(
46
+ logger.debug(
47
47
  "SystemPromptCache initialized with TTL=%d seconds",
48
48
  self._ttl_seconds,
49
49
  )
@@ -80,7 +80,7 @@ class SystemPromptCache:
80
80
  return self._cached_prompt
81
81
 
82
82
  # Cache miss or expired - fetch from database
83
- logger.info("Cache miss - fetching latest prompt from database")
83
+ logger.debug("Cache miss - fetching latest prompt from database")
84
84
 
85
85
  async with get_asyncdb_session() as session:
86
86
  latest_prompt = await system_prompt_repo.find_latest(session)
@@ -102,7 +102,7 @@ class SystemPromptCache:
102
102
  self._cached_version = prompt_version
103
103
  self._cache_timestamp = datetime.now(UTC)
104
104
 
105
- logger.info(
105
+ logger.debug(
106
106
  "Cached prompt version %d (%d characters)",
107
107
  self._cached_version,
108
108
  len(self._cached_prompt),
@@ -118,7 +118,7 @@ class SystemPromptCache:
118
118
  """
119
119
  async with self._lock:
120
120
  if self._cached_prompt is not None:
121
- logger.info(
121
+ logger.debug(
122
122
  "Cache invalidated (was version %d)",
123
123
  self._cached_version,
124
124
  )
@@ -135,7 +135,7 @@ class SystemPromptCache:
135
135
  seconds: New TTL in seconds.
136
136
  """
137
137
  self._ttl_seconds = seconds
138
- logger.info("Cache TTL updated to %d seconds", seconds)
138
+ logger.debug("Cache TTL updated to %d seconds", seconds)
139
139
 
140
140
  @property
141
141
  def is_cached(self) -> bool: