appkit-assistant 0.14.1-py3-none-any.whl → 0.15.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import logging
 import threading
+from typing import Self
 
 from appkit_assistant.backend.models import AIModel
 from appkit_assistant.backend.processor import Processor
@@ -18,7 +19,7 @@ class ModelManager:
         None  # Default model ID will be set to the first registered model
     )
 
-    def __new__(cls) -> ModelManager:
+    def __new__(cls) -> Self:
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
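
The return-type change from the concrete `ModelManager` to `typing.Self` (available since Python 3.11) makes the singleton subclass-friendly: a subclass's `__new__` is now typed as returning the subclass rather than the base. A minimal sketch of the same double-checked-locking pattern under that annotation (class names here are illustrative, not from the package):

import threading
from typing import Self

class SingletonBase:
    _instance = None
    _lock = threading.Lock()

    def __new__(cls) -> Self:
        # Double-checked locking: cheap unlocked read first,
        # take the lock only for first-time initialization
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

class Derived(SingletonBase):
    pass

d = Derived()  # type checkers now infer Derived, not SingletonBase
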
@@ -44,6 +44,7 @@ class ChunkType(StrEnum):
     TOOL_RESULT = "tool_result"  # result from a tool
     TOOL_CALL = "tool_call"  # calling a tool
     COMPLETION = "completion"  # when response generation is complete
+    AUTH_REQUIRED = "auth_required"  # user needs to authenticate (MCP)
     ERROR = "error"  # when an error occurs
     LIFECYCLE = "lifecycle"
 
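
The new `AUTH_REQUIRED` member gives stream consumers a dedicated signal to render an authorization prompt instead of a generic error. A hedged sketch of the consumer side, assuming hypothetical `show_auth_card` and `render_text` UI helpers; the metadata keys match those emitted by `_create_auth_required_chunk` further down in this diff:

async for chunk in processor.process(messages, model_id, user_id=user_id):
    if chunk.type == ChunkType.AUTH_REQUIRED:
        # "server_name" and "auth_url" are set by _create_auth_required_chunk
        show_auth_card(
            server=chunk.chunk_metadata["server_name"],
            url=chunk.chunk_metadata["auth_url"],
        )
    elif chunk.type == ChunkType.TEXT:
        render_text(chunk.text)
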
@@ -100,6 +101,7 @@ class AIModel(BaseModel):
     supports_attachments: bool = False
     keywords: list[str] = []
     disabled: bool = False
+    requires_role: str | None = None
 
 
 class Suggestion(BaseModel):
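
`requires_role` lets a deployment restrict individual models to users holding a given role. The diff does not show where the field is enforced; a plausible filter on the consuming side might look like this (helper name and role source are assumptions):

def visible_models(models: list[AIModel], user_roles: set[str]) -> list[AIModel]:
    """Hypothetical gate: hide disabled models and those whose role is missing."""
    return [
        m
        for m in models
        if not m.disabled
        and (m.requires_role is None or m.requires_role in user_roles)
    ]
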
@@ -143,6 +145,12 @@ class MCPServer(rx.Model, table=True):
     # Optional discovery URL override
     discovery_url: str | None = Field(default=None, nullable=True)
 
+    # OAuth client credentials (encrypted)
+    oauth_client_id: str | None = Field(default=None, nullable=True)
+    oauth_client_secret: str | None = Field(
+        default=None, nullable=True, sa_type=EncryptedString
+    )
+
     # Cached OAuth/Discovery metadata (read-only for user mostly)
     oauth_issuer: str | None = Field(default=None, nullable=True)
     oauth_authorize_url: str | None = Field(default=None, nullable=True)
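
`oauth_client_secret` is stored through an `EncryptedString` column type, which this diff references but does not define. A common shape for such a type is a SQLAlchemy `TypeDecorator` that encrypts on bind and decrypts on load; the sketch below uses Fernet and is an assumption about the implementation, not the package's actual code:

from cryptography.fernet import Fernet
from sqlalchemy.types import String, TypeDecorator

_FERNET = Fernet(Fernet.generate_key())  # a real deployment loads a configured key

class EncryptedString(TypeDecorator):
    """Transparently encrypt a string column at rest (illustrative only)."""

    impl = String
    cache_ok = True

    def process_bind_param(self, value, dialect):
        # Called on INSERT/UPDATE: plaintext in, ciphertext out
        return _FERNET.encrypt(value.encode()).decode() if value is not None else None

    def process_result_value(self, value, dialect):
        # Called on SELECT: ciphertext in, plaintext out
        return _FERNET.decrypt(value.encode()).decode() if value is not None else None
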
@@ -194,3 +202,38 @@ class AssistantThread(rx.Model, table=True):
         default_factory=lambda: datetime.now(UTC),
         sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
     )
+
+
+class AssistantMCPUserToken(rx.Model, table=True):
+    """Model for storing user-specific OAuth tokens for MCP servers.
+
+    Each user can have one token per MCP server. Tokens are encrypted at rest.
+    """
+
+    __tablename__ = "assistant_mcp_user_token"
+
+    id: int | None = Field(default=None, primary_key=True)
+    user_id: int = Field(index=True, nullable=False)
+    mcp_server_id: int = Field(
+        index=True, nullable=False, foreign_key="assistant_mcp_servers.id"
+    )
+
+    # Tokens are encrypted at rest
+    access_token: str = Field(nullable=False, sa_type=EncryptedString)
+    refresh_token: str | None = Field(
+        default=None, nullable=True, sa_type=EncryptedString
+    )
+
+    # Token expiry timestamp
+    expires_at: datetime = Field(
+        sa_column=Column(DateTime(timezone=True), nullable=False)
+    )
+
+    created_at: datetime = Field(
+        default_factory=lambda: datetime.now(UTC),
+        sa_column=Column(DateTime(timezone=True)),
+    )
+    updated_at: datetime = Field(
+        default_factory=lambda: datetime.now(UTC),
+        sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
+    )
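
The docstring promises one token per user per MCP server, but note the schema shown here enforces that only by convention: both columns are indexed individually and no unique constraint on the pair appears in this diff. A small hedged helper for the expiry check the processor needs later:

from datetime import UTC, datetime, timedelta

def token_is_fresh(
    token: AssistantMCPUserToken,
    skew: timedelta = timedelta(seconds=30),  # assumed clock-skew buffer
) -> bool:
    """True while the token is still valid; expires_at is timezone-aware."""
    return datetime.now(UTC) < token.expires_at - skew
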
@@ -3,16 +3,22 @@ import logging
 from collections.abc import AsyncGenerator
 from typing import Any
 
+import reflex as rx
+
+from appkit_assistant.backend.mcp_auth_service import MCPAuthService
 from appkit_assistant.backend.models import (
     AIModel,
+    AssistantMCPUserToken,
     Chunk,
     ChunkType,
+    MCPAuthType,
     MCPServer,
     Message,
     MessageType,
 )
 from appkit_assistant.backend.processors.openai_base import BaseOpenAIProcessor
 from appkit_assistant.backend.system_prompt_cache import get_system_prompt
+from appkit_commons.database.session import get_session_manager
 
 logger = logging.getLogger(__name__)
 
@@ -26,9 +32,13 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
         api_key: str | None = None,
         base_url: str | None = None,
         is_azure: bool = False,
+        oauth_redirect_uri: str = "",
     ) -> None:
         super().__init__(models, api_key, base_url, is_azure)
         self._current_reasoning_session: str | None = None
+        self._current_user_id: int | None = None
+        self._mcp_auth_service = MCPAuthService(redirect_uri=oauth_redirect_uri)
+        self._pending_auth_servers: list[MCPServer] = []
 
     async def process(
         self,
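
The processor now owns an `MCPAuthService` and threads the OAuth redirect URI through its constructor. Hypothetical wiring (the callback path is an assumption; it must match whatever the OAuth clients were registered with):

processor = OpenAIResponsesProcessor(
    models=models,
    api_key=api_key,
    oauth_redirect_uri="https://app.example.com/mcp/oauth/callback",  # assumed
)
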
@@ -37,6 +47,7 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
         files: list[str] | None = None,  # noqa: ARG002
         mcp_servers: list[MCPServer] | None = None,
         payload: dict[str, Any] | None = None,
+        user_id: int | None = None,
     ) -> AsyncGenerator[Chunk, None]:
         """Process messages using simplified content accumulator pattern."""
         if not self.client:
@@ -47,29 +58,45 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
             raise ValueError(msg)
 
         model = self.models[model_id]
+        self._current_user_id = user_id
+        self._pending_auth_servers = []
 
         try:
             session = await self._create_responses_request(
-                messages, model, mcp_servers, payload
+                messages, model, mcp_servers, payload, user_id
             )
 
-            if hasattr(session, "__aiter__"):  # Streaming
-                async for event in session:
-                    chunk = self._handle_event(event)
-                    if chunk:
-                        yield chunk
-            else:  # Non-streaming
-                content = self._extract_responses_content(session)
-                if content:
-                    yield Chunk(
-                        type=ChunkType.TEXT,
-                        text=content,
-                        chunk_metadata={
-                            "source": "responses_api",
-                            "streaming": "false",
-                        },
-                    )
+            try:
+                if hasattr(session, "__aiter__"):  # Streaming
+                    async for event in session:
+                        chunk = self._handle_event(event)
+                        if chunk:
+                            yield chunk
+                else:  # Non-streaming
+                    content = self._extract_responses_content(session)
+                    if content:
+                        yield Chunk(
+                            type=ChunkType.TEXT,
+                            text=content,
+                            chunk_metadata={
+                                "source": "responses_api",
+                                "streaming": "false",
+                            },
+                        )
+            except Exception as e:
+                logger.error("Error during response processing: %s", e)
+                # Continue to yield auth chunks if any
+
+            # After processing (or on error), yield any pending auth requirements
+            logger.debug(
+                "Processing pending auth servers: %d", len(self._pending_auth_servers)
+            )
+            for server in self._pending_auth_servers:
+                logger.debug("Yielding auth chunk for server: %s", server.name)
+                yield await self._create_auth_required_chunk(server)
+
         except Exception as e:
+            logger.error("Critical error in OpenAI processor: %s", e)
             raise e
 
     def _handle_event(self, event: Any) -> Chunk | None:
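
The restructure wraps streaming in an inner `try` so that a mid-stream failure (typically an MCP server rejecting an unauthenticated connection) no longer kills the generator before the queued auth chunks go out. The control flow in miniature, with hypothetical names:

async def stream_then_flush(upstream, pending):
    try:
        async for item in upstream:
            yield item
    except Exception:
        pass  # swallow so the flush below still runs
    for queued in pending:  # e.g. the _pending_auth_servers queue
        yield queued
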
@@ -78,9 +105,6 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
             return None
 
         event_type = event.type
-        logger.debug("Event: %s", event)
-
-        # Try different handlers in order
         handlers = [
             self._handle_lifecycle_events,
             lambda et: self._handle_text_events(et, event),
@@ -94,16 +118,15 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
         for handler in handlers:
             result = handler(event_type)
             if result:
-                content_preview = result.text[:50] if result.text else ""
-                logger.info(
-                    "Event %s → Chunk: type=%s, content=%s",
-                    event_type,
-                    result.type,
-                    content_preview,
-                )
+                # content_preview = result.text[:50] if result.text else ""
+                # logger.debug(
+                #     "Event %s → Chunk: type=%s, content=%s",
+                #     event_type,
+                #     result.type,
+                #     content_preview,
+                # )
                 return result
 
-        # Log unhandled events for debugging
         logger.debug("Unhandled event type: %s", event_type)
         return None
 
@@ -205,14 +228,36 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
         return None
 
     def _handle_mcp_call_done(self, item: Any) -> Chunk | None:
-        """Handle MCP call completion."""
+        """Handle MCP call completion.
+
+        Detects 401/403 authentication errors and marks servers for auth flow.
+        """
         tool_id = getattr(item, "id", "unknown_id")
         tool_name = getattr(item, "name", "unknown_tool")
+        server_label = getattr(item, "server_label", "unknown_server")
         error = getattr(item, "error", None)
         output = getattr(item, "output", None)
 
         if error:
             error_text = self._extract_error_text(error)
+
+            # Check for authentication errors (401/403)
+            if self._is_auth_error(error):
+                # Find the server config and queue for auth flow
+                return self._create_chunk(
+                    ChunkType.TOOL_RESULT,
+                    f"Authentifizierung erforderlich für {server_label}",
+                    {
+                        "tool_id": tool_id,
+                        "tool_name": tool_name,
+                        "server_label": server_label,
+                        "status": "auth_required",
+                        "error": True,
+                        "auth_required": True,
+                        "reasoning_session": self._current_reasoning_session,
+                    },
+                )
+
             return self._create_chunk(
                 ChunkType.TOOL_RESULT,
                 f"Werkzeugfehler: {error_text}",
@@ -238,6 +283,21 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
             },
         )
 
+    def _is_auth_error(self, error: Any) -> bool:
+        """Check if an error indicates authentication failure (401/403)."""
+        error_str = str(error).lower()
+        auth_indicators = [
+            "401",
+            "403",
+            "unauthorized",
+            "forbidden",
+            "authentication required",
+            "access denied",
+            "invalid token",
+            "token expired",
+        ]
+        return any(indicator in error_str for indicator in auth_indicators)
+
     def _extract_error_text(self, error: Any) -> str:
         """Extract readable error text from error object."""
         if isinstance(error, dict) and "content" in error:
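
The substring match in `_is_auth_error` is deliberately broad and can misfire on incidental digit runs; the processor tolerates this because an auth card is ultimately only surfaced for servers already queued as OAuth candidates. Illustrative inputs (`proc` stands for any processor instance):

proc._is_auth_error("HTTP 401 Unauthorized")     # True
proc._is_auth_error("ticket #4013 not found")    # True - false positive on "401"
proc._is_auth_error("connection reset by peer")  # False
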
@@ -246,7 +306,7 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
                 return content[0].get("text", str(error))
         return "Unknown error"
 
-    def _handle_mcp_events(self, event_type: str, event: Any) -> Chunk | None:
+    def _handle_mcp_events(self, event_type: str, event: Any) -> Chunk | None:  # noqa: PLR0911, PLR0912
         """Handle MCP-specific events."""
         if event_type == "response.mcp_call_arguments.delta":
             tool_id = getattr(event, "item_id", "unknown_id")
@@ -315,6 +375,70 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
 
         if event_type == "response.mcp_list_tools.failed":
             tool_id = getattr(event, "item_id", "unknown_id")
+            error = getattr(event, "error", None)
+
+            # Debugging: Log available attributes to help diagnosis
+            if logger.isEnabledFor(logging.DEBUG) and error:
+                logger.debug("Error object type: %s, content: %s", type(error), error)
+
+            # Extract error message safely
+            error_str = ""
+            if error:
+                if isinstance(error, dict):
+                    error_str = error.get("message", str(error))
+                elif hasattr(error, "message"):
+                    error_str = getattr(error, "message", str(error))
+                else:
+                    error_str = str(error)
+
+            # Check for authentication errors (401/403)
+            # OR if we have pending auth servers (strong signal we missed a token)
+            is_auth_error = self._is_auth_error(error_str)
+            pending_server = None
+
+            # 1. Try to find matching server by name in error message
+            for server in self._pending_auth_servers:
+                if server.name.lower() in error_str.lower():
+                    pending_server = server
+                    break
+
+            # 2. If no match but we have pending servers and it looks like an auth error
+            #    OR if we have pending servers and likely one of them failed (len=1)
+            #    We assume the failure belongs to the pending server if we can't be sure
+            if (
+                not pending_server
+                and self._pending_auth_servers
+                and (is_auth_error or len(self._pending_auth_servers) == 1)
+            ):
+                pending_server = self._pending_auth_servers[0]
+                logger.debug(
+                    "Assuming pending server %s for list_tools failure '%s'",
+                    pending_server.name,
+                    error_str,
+                )
+
+            if pending_server:
+                logger.debug(
+                    "Queuing Auth Card for server: %s (Error: %s)",
+                    pending_server.name,
+                    error_str,
+                )
+                # Queue for async processing in the main process loop
+                # The auth chunk will be yielded after event processing completes
+                if pending_server not in self._pending_auth_servers:
+                    self._pending_auth_servers.append(pending_server)
+                return self._create_chunk(
+                    ChunkType.TOOL_RESULT,
+                    f"Authentifizierung erforderlich für {pending_server.name}",
+                    {
+                        "tool_id": tool_id,
+                        "status": "auth_required",
+                        "server_name": pending_server.name,
+                        "auth_pending": True,
+                        "reasoning_session": self._current_reasoning_session,
+                    },
+                )
+
             logger.error("MCP tool listing failed for tool_id: %s", str(event))
             return self._create_chunk(
                 ChunkType.TOOL_RESULT,
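
The attribution heuristic in miniature: a `mcp_list_tools.failed` event carries no structured server reference, so the code first looks for a server name inside the error text, then falls back to the single pending OAuth server when the mapping is otherwise ambiguous. With hypothetical values:

pending = [jira_server]     # one server queued earlier without a token
error_str = "fetch failed"  # no name match, no auth keywords
is_auth = proc._is_auth_error(error_str)                       # False
attributed = bool(pending) and (is_auth or len(pending) == 1)  # True: len == 1 wins
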
@@ -396,11 +520,14 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
         model: AIModel,
         mcp_servers: list[MCPServer] | None = None,
         payload: dict[str, Any] | None = None,
+        user_id: int | None = None,
     ) -> Any:
         """Create a simplified responses API request."""
-        # Configure MCP tools if provided
+        # Configure MCP tools if provided (now async for token lookup)
         tools, mcp_prompt = (
-            self._configure_mcp_tools(mcp_servers) if mcp_servers else ([], "")
+            await self._configure_mcp_tools(mcp_servers, user_id)
+            if mcp_servers
+            else ([], "")
         )
 
         # Convert messages to responses format with system message
@@ -418,14 +545,17 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
             **(payload or {}),
         }
 
-        logger.debug("Responses API request params: %s", params)
         return await self.client.responses.create(**params)
 
-    def _configure_mcp_tools(
-        self, mcp_servers: list[MCPServer] | None
+    async def _configure_mcp_tools(
+        self,
+        mcp_servers: list[MCPServer] | None,
+        user_id: int | None = None,
     ) -> tuple[list[dict[str, Any]], str]:
         """Configure MCP servers as tools for the responses API.
 
+        Injects OAuth Bearer tokens for servers that require authentication.
+
         Returns:
             tuple: (tools list, concatenated prompts string)
         """
@@ -434,6 +564,7 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
 
         tools = []
         prompts = []
+
        for server in mcp_servers:
            tool_config = {
                "type": "mcp",
@@ -442,8 +573,28 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
                 "require_approval": "never",
             }
 
+            # Start with existing headers
+            headers = {}
             if server.headers and server.headers != "{}":
-                tool_config["headers"] = json.loads(server.headers)
+                headers = json.loads(server.headers)
+
+            # Inject OAuth token if server requires OAuth and user is authenticated
+            if server.auth_type == MCPAuthType.OAUTH_DISCOVERY and user_id is not None:
+                token = await self._get_valid_token_for_server(server, user_id)
+                if token:
+                    headers["Authorization"] = f"Bearer {token.access_token}"
+                    logger.debug("Injected OAuth token for server %s", server.name)
+                else:
+                    # No valid token - server will likely fail with 401
+                    # Track this server for potential auth flow
+                    self._pending_auth_servers.append(server)
+                    logger.debug(
+                        "No valid token for OAuth server %s, auth may be required",
+                        server.name,
+                    )
+
+            if headers:
+                tool_config["headers"] = headers
 
             tools.append(tool_config)
 
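
The net effect on the request payload: stored headers and the injected `Authorization` header end up merged into one `headers` dict per MCP tool entry. An illustrative entry (values invented; the `server_label`/`server_url` keys sit in lines elided between these hunks and match the OpenAI Responses API's MCP tool schema):

{
    "type": "mcp",
    "server_label": "jira",                         # assumed value
    "server_url": "https://mcp.example.com/sse",    # assumed value
    "require_approval": "never",
    "headers": {"Authorization": "Bearer eyJ..."},  # injected for OAUTH_DISCOVERY servers
}
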
@@ -511,3 +662,81 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
                 return first_output.content[0].get("text", "")
             return str(first_output.content)
         return None
+
+    async def _get_valid_token_for_server(
+        self,
+        server: MCPServer,
+        user_id: int,
+    ) -> AssistantMCPUserToken | None:
+        """Get a valid OAuth token for the given server and user.
+
+        Refreshes the token if expired and refresh token is available.
+
+        Args:
+            server: The MCP server configuration.
+            user_id: The user's ID.
+
+        Returns:
+            A valid token or None if not available.
+        """
+        if server.id is None:
+            return None
+
+        with rx.session() as session:
+            token = self._mcp_auth_service.get_user_token(session, user_id, server.id)
+
+            if token is None:
+                return None
+
+            # Check if token is valid or can be refreshed
+            return await self._mcp_auth_service.ensure_valid_token(
+                session, server, token
+            )
+
+    async def _create_auth_required_chunk(self, server: MCPServer) -> Chunk:
+        """Create an AUTH_REQUIRED chunk for a server that needs authentication.
+
+        Args:
+            server: The MCP server requiring authentication.
+
+        Returns:
+            A chunk signaling auth is required with the auth URL.
+        """
+        # Build the authorization URL
+        try:
+            # We use a session to store the PKCE state
+            # NOTE: rx.session() is for Reflex user session, not DB session.
+            # We use get_session_manager().session() for DB access required by PKCE.
+            with get_session_manager().session() as session:
+                # Use the async method that supports DCR
+                auth_service = self._mcp_auth_service
+                (
+                    auth_url,
+                    state,
+                ) = await auth_service.build_authorization_url_with_registration(
+                    server,
+                    session=session,
+                    user_id=self._current_user_id,
+                )
+            logger.info(
+                "Built auth URL for server %s, state=%s, url=%s",
+                server.name,
+                state,
+                auth_url[:100] if auth_url else "None",
+            )
+        except (ValueError, Exception) as e:
+            logger.error("Cannot build auth URL for server %s: %s", server.name, str(e))
+            auth_url = ""
+            state = ""
+
+        return Chunk(
+            type=ChunkType.AUTH_REQUIRED,
+            text=f"{server.name} benötigt Ihre Autorisierung",
+            chunk_metadata={
+                "server_id": str(server.id) if server.id else "",
+                "server_name": server.name,
+                "auth_url": auth_url,
+                "state": state,
+                "processor": "openai_responses",
+            },
+        )
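
`MCPAuthService` itself is outside this diff. Given the docstring above ("Refreshes the token if expired and refresh token is available"), `ensure_valid_token` plausibly follows this shape; this is an assumption, not the package's code:

from datetime import UTC, datetime

async def ensure_valid_token(self, session, server, token):
    """Sketch: return a usable token, or None so the auth-card flow triggers."""
    if datetime.now(UTC) < token.expires_at:
        return token                                               # still fresh
    if token.refresh_token:
        return await self._refresh_token(session, server, token)   # hypothetical helper
    return None                                                    # forces re-authentication
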
@@ -80,7 +80,7 @@ class SystemPromptRepository(BaseRepository[SystemPrompt, AsyncSession]):
         await session.flush()
         await session.refresh(system_prompt)
 
-        logger.info(
+        logger.debug(
             "Created system prompt version %s for user %s",
             next_version,
             user_id,
@@ -43,7 +43,7 @@ class SystemPromptCache:
         self._ttl_seconds: int = CACHE_TTL_SECONDS
         self._initialized = True
 
-        logger.info(
+        logger.debug(
             "SystemPromptCache initialized with TTL=%d seconds",
             self._ttl_seconds,
         )
@@ -80,7 +80,7 @@ class SystemPromptCache:
             return self._cached_prompt
 
         # Cache miss or expired - fetch from database
-        logger.info("Cache miss - fetching latest prompt from database")
+        logger.debug("Cache miss - fetching latest prompt from database")
 
         async with get_asyncdb_session() as session:
             latest_prompt = await system_prompt_repo.find_latest(session)
@@ -102,7 +102,7 @@ class SystemPromptCache:
         self._cached_version = prompt_version
         self._cache_timestamp = datetime.now(UTC)
 
-        logger.info(
+        logger.debug(
             "Cached prompt version %d (%d characters)",
             self._cached_version,
             len(self._cached_prompt),
@@ -118,7 +118,7 @@ class SystemPromptCache:
         """
         async with self._lock:
             if self._cached_prompt is not None:
-                logger.info(
+                logger.debug(
                     "Cache invalidated (was version %d)",
                     self._cached_version,
                 )
@@ -135,7 +135,7 @@ class SystemPromptCache:
             seconds: New TTL in seconds.
         """
         self._ttl_seconds = seconds
-        logger.info("Cache TTL updated to %d seconds", seconds)
+        logger.debug("Cache TTL updated to %d seconds", seconds)
 
     @property
     def is_cached(self) -> bool: