appkit-assistant 0.8.0__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,10 +1,34 @@
1
+ import json
2
+ from datetime import UTC, datetime
1
3
  from enum import StrEnum
4
+ from typing import Any
2
5
 
3
6
  import reflex as rx
4
7
  from pydantic import BaseModel
5
- from sqlmodel import Field
8
+ from sqlalchemy.sql import func
9
+ from sqlmodel import Column, DateTime, Field
6
10
 
11
+ from appkit_commons.database.configuration import DatabaseConfig
7
12
  from appkit_commons.database.entities import EncryptedString
13
+ from appkit_commons.registry import service_registry
14
+
15
# Database configuration, resolved once at import time from the service registry.
db_config = service_registry().get(DatabaseConfig)
# Plain-text encryption key taken from the database configuration.
# NOTE(review): SECRET_VALUE is not referenced by any code visible here;
# confirm whether other modules import it before removing it.
SECRET_VALUE = db_config.encryption_key.get_secret_value()
17
+
18
+
19
class EncryptedJSON(EncryptedString):
    """Column type that stores JSON-serializable values encrypted.

    On write, the value is serialized with ``json.dumps`` and the resulting text
    is handed to ``EncryptedString`` for encryption. On read, the text decrypted
    by ``EncryptedString`` is parsed back with ``json.loads``. ``None`` passes
    through unchanged in both directions.
    """

    # SQLAlchemy looks up ``cache_ok`` on the TypeDecorator subclass itself
    # (it is not inherited), so without this line every statement using this
    # type triggers a warning and skips the compiled-statement cache.
    # This subclass adds no constructor state of its own.
    # NOTE(review): assumes EncryptedString is itself cache-safe (i.e. it sets
    # cache_ok = True) — confirm in appkit_commons.
    cache_ok = True

    def process_bind_param(self, value: Any, dialect: Any) -> str | None:
        """Serialize ``value`` to JSON text, then let the base class encrypt it.

        Args:
            value: Any ``json.dumps``-serializable value, or ``None``.
            dialect: The SQLAlchemy dialect in use.

        Returns:
            Whatever ``EncryptedString.process_bind_param`` returns for the
            JSON text (or for ``None``).
        """
        if value is not None:
            value = json.dumps(value)
        return super().process_bind_param(value, dialect)

    def process_result_value(self, value: Any, dialect: Any) -> Any | None:
        """Decrypt via the base class, then parse the JSON text.

        Args:
            value: The raw value read from the database.
            dialect: The SQLAlchemy dialect in use.

        Returns:
            The deserialized Python value, or ``None`` if the base class
            yields ``None``.
        """
        value = super().process_result_value(value, dialect)
        if value is not None:
            return json.loads(value)
        return value
8
32
 
9
33
 
10
34
  class ChunkType(StrEnum):
@@ -39,6 +63,7 @@ class ThreadStatus(StrEnum):
39
63
  ACTIVE = "active"
40
64
  IDLE = "idle"
41
65
  WAITING = "waiting"
66
+ ERROR = "error"
42
67
  DELETED = "deleted"
43
68
  ARCHIVED = "archived"
44
69
 
@@ -92,10 +117,18 @@ class ThreadModel(BaseModel):
92
117
  ai_model: str = ""
93
118
 
94
119
 
120
class MCPAuthType(StrEnum):
    """Enum for MCP server authentication types.

    Used as the default for ``MCPServer.auth_type``, which stores the member's
    string value.
    """

    NONE = "none"
    API_KEY = "api_key"
    # Pairs with MCPServer's discovery_url and cached oauth_* metadata fields.
    OAUTH_DISCOVERY = "oauth_discovery"
126
+
127
+
95
128
  class MCPServer(rx.Model, table=True):
96
129
  """Model for MCP (Model Context Protocol) server configuration."""
97
130
 
98
- __tablename__ = "mcp_server"
131
+ __tablename__ = "assistant_mcp_servers"
99
132
 
100
133
  id: int | None = Field(default=None, primary_key=True)
101
134
  name: str = Field(unique=True, max_length=100, nullable=False)
@@ -103,3 +136,61 @@ class MCPServer(rx.Model, table=True):
103
136
  url: str = Field(nullable=False)
104
137
  headers: str = Field(nullable=False, sa_type=EncryptedString)
105
138
  prompt: str = Field(default="", max_length=2000, nullable=True)
139
+
140
+ # Authentication type
141
+ auth_type: str = Field(default=MCPAuthType.NONE, nullable=False)
142
+
143
+ # Optional discovery URL override
144
+ discovery_url: str | None = Field(default=None, nullable=True)
145
+
146
+ # Cached OAuth/Discovery metadata (read-only for user mostly)
147
+ oauth_issuer: str | None = Field(default=None, nullable=True)
148
+ oauth_authorize_url: str | None = Field(default=None, nullable=True)
149
+ oauth_token_url: str | None = Field(default=None, nullable=True)
150
+ oauth_scopes: str | None = Field(
151
+ default=None, nullable=True
152
+ ) # Space separated scopes
153
+
154
+ # Timestamp when discovery was last successfully run
155
+ oauth_discovered_at: datetime | None = Field(
156
+ default=None, sa_column=Column(DateTime(timezone=True), nullable=True)
157
+ )
158
+
159
+
160
class SystemPrompt(rx.Model, table=True):
    """Model for system prompt versioning and management.

    Each save creates a new immutable version. Supports up to 20,000 characters.
    """

    __tablename__ = "assistant_system_prompt"

    id: int | None = Field(default=None, primary_key=True)
    # Human-readable label for this version.
    name: str = Field(max_length=200, nullable=False)
    # The prompt text itself (max 20,000 characters).
    prompt: str = Field(max_length=20000, nullable=False)
    # Version number; no uniqueness constraint is declared on this field.
    version: int = Field(nullable=False)
    # ID of the user who created this version.
    user_id: int = Field(nullable=False)
    # NOTE(review): the default is timezone-aware (UTC), but unlike the
    # timestamps on AssistantThread no DateTime(timezone=True) column is
    # declared here — confirm the migration's column type matches.
    created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
174
+
175
+
176
class AssistantThread(rx.Model, table=True):
    """Model for storing chat threads in the database.

    Each row holds one thread's metadata plus its full message history, which
    is stored JSON-encoded and encrypted via ``EncryptedJSON``.
    """

    __tablename__ = "assistant_thread"

    id: int | None = Field(default=None, primary_key=True)
    # Application-level thread identifier; unique and indexed for lookups.
    thread_id: str = Field(unique=True, index=True, nullable=False)
    # Owning user; indexed because threads are queried per user.
    user_id: int = Field(index=True, nullable=False)
    title: str = Field(default="", nullable=False)
    # Stored as the string value of a ThreadStatus member.
    state: str = Field(default=ThreadStatus.NEW, nullable=False)
    ai_model: str = Field(default="", nullable=False)
    active: bool = Field(default=False, nullable=False)
    # List of message dicts. default_factory gives every instance its own
    # fresh list instead of sharing one mutable literal as the default.
    messages: list[dict[str, Any]] = Field(
        default_factory=list, sa_column=Column(EncryptedJSON)
    )
    created_at: datetime = Field(
        default_factory=lambda: datetime.now(UTC),
        sa_column=Column(DateTime(timezone=True)),
    )
    # Set to the database's func.now() whenever the row is updated.
    updated_at: datetime = Field(
        default_factory=lambda: datetime.now(UTC),
        sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
    )
@@ -12,7 +12,7 @@ from appkit_assistant.backend.models import (
12
12
  MessageType,
13
13
  )
14
14
  from appkit_assistant.backend.processors.openai_base import BaseOpenAIProcessor
15
- from appkit_assistant.backend.system_prompt import SYSTEM_PROMPT
15
+ from appkit_assistant.backend.system_prompt_cache import get_system_prompt
16
16
 
17
17
  logger = logging.getLogger(__name__)
18
18
 
@@ -404,7 +404,7 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
404
404
  )
405
405
 
406
406
  # Convert messages to responses format with system message
407
- input_messages = self._convert_messages_to_responses_format(
407
+ input_messages = await self._convert_messages_to_responses_format(
408
408
  messages, mcp_prompt=mcp_prompt
409
409
  )
410
410
 
@@ -453,8 +453,11 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
453
453
  prompt_string = "\n".join(prompts) if prompts else ""
454
454
  return tools, prompt_string
455
455
 
456
- def _convert_messages_to_responses_format(
457
- self, messages: list[Message], mcp_prompt: str = ""
456
+ async def _convert_messages_to_responses_format(
457
+ self,
458
+ messages: list[Message],
459
+ mcp_prompt: str = "",
460
+ use_system_prompt: bool = True,
458
461
  ) -> list[dict[str, Any]]:
459
462
  """Convert messages to the responses API input format.
460
463
 
@@ -471,13 +474,15 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
471
474
  else:
472
475
  mcp_prompt = ""
473
476
 
474
- system_text = SYSTEM_PROMPT.format(mcp_prompts=mcp_prompt)
475
- input_messages.append(
476
- {
477
- "role": "system",
478
- "content": [{"type": "input_text", "text": system_text}],
479
- }
480
- )
477
+ if use_system_prompt:
478
+ system_prompt_template = await get_system_prompt()
479
+ system_text = system_prompt_template.format(mcp_prompts=mcp_prompt)
480
+ input_messages.append(
481
+ {
482
+ "role": "system",
483
+ "content": [{"type": "input_text", "text": system_text}],
484
+ }
485
+ )
481
486
 
482
487
  # Add conversation messages
483
488
  for msg in messages:
@@ -1,10 +1,19 @@
1
1
  """Repository for MCP server data access operations."""
2
2
 
3
3
  import logging
4
+ from datetime import UTC, datetime
4
5
 
5
6
  import reflex as rx
7
+ from sqlalchemy.orm import defer
6
8
 
7
- from appkit_assistant.backend.models import MCPServer
9
+ from appkit_assistant.backend.models import (
10
+ AssistantThread,
11
+ MCPServer,
12
+ Message,
13
+ SystemPrompt,
14
+ ThreadModel,
15
+ ThreadStatus,
16
+ )
8
17
 
9
18
  logger = logging.getLogger(__name__)
10
19
 
@@ -94,3 +103,221 @@ class MCPServerRepository:
94
103
  return True
95
104
  logger.warning("MCP server with ID %s not found for deletion", server_id)
96
105
  return False
106
+
107
+
108
class SystemPromptRepository:
    """Repository class for system prompt database operations.

    Versioning is append-only: ``create`` always inserts a new row with the
    next version number; existing rows are never modified, only deleted.
    """

    @staticmethod
    async def get_all() -> list[SystemPrompt]:
        """Retrieve all system prompt versions, newest version first."""
        async with rx.asession() as session:
            result = await session.exec(
                SystemPrompt.select().order_by(SystemPrompt.version.desc())
            )
            return list(result.all())

    @staticmethod
    async def get_latest() -> SystemPrompt | None:
        """Retrieve the system prompt with the highest version, if any."""
        async with rx.asession() as session:
            result = await session.exec(
                SystemPrompt.select().order_by(SystemPrompt.version.desc()).limit(1)
            )
            return result.first()

    @staticmethod
    async def get_by_id(prompt_id: int) -> SystemPrompt | None:
        """Retrieve a system prompt by its primary key, or ``None``."""
        async with rx.asession() as session:
            result = await session.exec(
                SystemPrompt.select().where(SystemPrompt.id == prompt_id)
            )
            return result.first()

    @staticmethod
    async def create(prompt: str, user_id: int) -> SystemPrompt:
        """Create a new system prompt version.

        The version is a sequential integer starting at 1 (latest version + 1),
        and the version's name is ``"Version <n>"``.

        NOTE(review): the next version is derived from a read of the current
        latest row; two concurrent calls can compute the same number, and no
        uniqueness constraint on ``version`` is visible in the model.

        Args:
            prompt: The prompt text to store.
            user_id: ID of the user creating this version.

        Returns:
            The persisted ``SystemPrompt``, refreshed from the database.
        """
        async with rx.asession() as session:
            result = await session.exec(
                SystemPrompt.select().order_by(SystemPrompt.version.desc()).limit(1)
            )
            latest = result.first()
            next_version = (latest.version + 1) if latest else 1

            system_prompt = SystemPrompt(
                name=f"Version {next_version}",
                prompt=prompt,
                version=next_version,
                user_id=user_id,
                created_at=datetime.now(UTC),
            )
            session.add(system_prompt)
            await session.commit()
            await session.refresh(system_prompt)

            logger.info(
                "Created system prompt version %s for user %s",
                next_version,
                user_id,
            )
            return system_prompt

    @staticmethod
    async def delete(prompt_id: int) -> bool:
        """Delete a system prompt version by ID.

        Returns:
            ``True`` if a row was deleted, ``False`` if no prompt has that ID.
        """
        async with rx.asession() as session:
            result = await session.exec(
                SystemPrompt.select().where(SystemPrompt.id == prompt_id)
            )
            prompt = result.first()
            if prompt is None:
                logger.warning(
                    "System prompt with ID %s not found for deletion",
                    prompt_id,
                )
                return False

            # Capture the version before deleting so logging does not read
            # attributes of an instance that was already deleted and committed.
            deleted_version = prompt.version
            await session.delete(prompt)
            await session.commit()
            logger.info("Deleted system prompt version: %s", deleted_version)
            return True
192
+
193
+
194
class ThreadRepository:
    """Repository class for thread database operations.

    Converts between ``AssistantThread`` rows and ``ThreadModel`` objects.
    Reads and deletes are filtered by ``user_id``; ``save_thread`` refuses to
    overwrite a thread that belongs to a different user.
    """

    @staticmethod
    def _to_thread_model(
        db_thread: AssistantThread, *, include_messages: bool = True
    ) -> ThreadModel:
        """Convert an ``AssistantThread`` row into a ``ThreadModel``.

        Args:
            db_thread: The database row to convert.
            include_messages: When ``False``, ``db_thread.messages`` is not
                accessed at all (``get_summaries_by_user`` defers that column)
                and the result gets an empty message list.

        Returns:
            The corresponding ``ThreadModel``.
        """
        messages: list[Message] = []
        if include_messages:
            # The messages column is not declared NOT NULL; treat NULL as
            # "no messages" instead of failing on iteration.
            messages = [Message(**m) for m in (db_thread.messages or [])]
        return ThreadModel(
            thread_id=db_thread.thread_id,
            title=db_thread.title,
            state=ThreadStatus(db_thread.state),
            ai_model=db_thread.ai_model,
            active=db_thread.active,
            messages=messages,
        )

    @staticmethod
    async def get_by_user(user_id: int) -> list[ThreadModel]:
        """Retrieve all threads, including messages, for a user.

        Threads are ordered by ``updated_at``, most recently updated first.
        """
        async with rx.asession() as session:
            result = await session.exec(
                AssistantThread.select()
                .where(AssistantThread.user_id == user_id)
                .order_by(AssistantThread.updated_at.desc())
            )
            return [ThreadRepository._to_thread_model(t) for t in result.all()]

    @staticmethod
    async def save_thread(thread: ThreadModel, user_id: int) -> None:
        """Insert a new thread or update the existing row with the same ID.

        If a row with ``thread.thread_id`` exists but belongs to another user,
        nothing is written and a warning is logged.

        Args:
            thread: The thread to persist, including all of its messages.
            user_id: ID of the user saving the thread.
        """
        async with rx.asession() as session:
            result = await session.exec(
                AssistantThread.select().where(
                    AssistantThread.thread_id == thread.thread_id
                )
            )
            db_thread = result.first()

            messages_dict = [m.dict() for m in thread.messages]

            if db_thread is None:
                db_thread = AssistantThread(
                    thread_id=thread.thread_id,
                    user_id=user_id,
                    title=thread.title,
                    state=thread.state.value,
                    ai_model=thread.ai_model,
                    active=thread.active,
                    messages=messages_dict,
                )
            elif db_thread.user_id != user_id:
                # thread_id is unique across all users, so a match owned by
                # someone else must not be overwritten.
                logger.warning(
                    "User %s tried to update thread %s belonging to user %s",
                    user_id,
                    thread.thread_id,
                    db_thread.user_id,
                )
                return
            else:
                db_thread.title = thread.title
                db_thread.state = thread.state.value
                db_thread.ai_model = thread.ai_model
                db_thread.active = thread.active
                db_thread.messages = messages_dict

            session.add(db_thread)
            await session.commit()

    @staticmethod
    async def delete_thread(thread_id: str, user_id: int) -> None:
        """Delete a thread if it exists and belongs to ``user_id``.

        Silently does nothing when no matching thread is found.
        """
        async with rx.asession() as session:
            result = await session.exec(
                AssistantThread.select().where(
                    AssistantThread.thread_id == thread_id,
                    AssistantThread.user_id == user_id,
                )
            )
            thread = result.first()
            if thread:
                await session.delete(thread)
                await session.commit()

    @staticmethod
    async def get_summaries_by_user(user_id: int) -> list[ThreadModel]:
        """Retrieve thread summaries (with empty message lists) for a user.

        The ``messages`` column is deferred so it is not loaded for the list
        view. Ordering matches ``get_by_user`` (most recently updated first).
        """
        async with rx.asession() as session:
            result = await session.exec(
                AssistantThread.select()
                .where(AssistantThread.user_id == user_id)
                .options(defer(AssistantThread.messages))
                .order_by(AssistantThread.updated_at.desc())
            )
            return [
                ThreadRepository._to_thread_model(t, include_messages=False)
                for t in result.all()
            ]

    @staticmethod
    async def get_thread_by_id(thread_id: str, user_id: int) -> ThreadModel | None:
        """Retrieve a full thread (with messages) by ID for its owning user.

        Returns:
            The thread, or ``None`` if it does not exist or belongs to a
            different user.
        """
        async with rx.asession() as session:
            result = await session.exec(
                AssistantThread.select().where(
                    AssistantThread.thread_id == thread_id,
                    AssistantThread.user_id == user_id,
                )
            )
            db_thread = result.first()
            if not db_thread:
                return None
            return ThreadRepository._to_thread_model(db_thread)
@@ -0,0 +1,161 @@
1
+ import asyncio
2
+ import logging
3
+ from datetime import UTC, datetime, timedelta
4
+ from typing import Final
5
+
6
+ from appkit_assistant.backend.repositories import SystemPromptRepository
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
# Cache TTL in seconds (default: 5 minutes). Initial TTL of SystemPromptCache;
# it can be changed at runtime with SystemPromptCache.set_ttl().
CACHE_TTL_SECONDS: Final[int] = 300
12
+
13
+
14
+ class SystemPromptCache:
15
+ """Singleton cache for system prompt with TTL-based invalidation.
16
+
17
+ Features:
18
+ - Lazy loading on first access
19
+ - Automatic cache invalidation after TTL expires
20
+ - Thread-safe with asyncio lock
21
+ - Manual invalidation support for immediate updates
22
+ """
23
+
24
+ _instance: "SystemPromptCache | None" = None
25
+ _lock: asyncio.Lock = asyncio.Lock()
26
+
27
+ def __new__(cls) -> "SystemPromptCache":
28
+ """Ensure singleton pattern."""
29
+ if cls._instance is None:
30
+ cls._instance = super().__new__(cls)
31
+ cls._instance._initialized = False # noqa: SLF001
32
+ return cls._instance
33
+
34
+ def __init__(self) -> None:
35
+ """Initialize cache state (only once due to singleton)."""
36
+ if self._initialized:
37
+ return
38
+
39
+ self._cached_prompt: str | None = None
40
+ self._cached_version: int | None = None
41
+ self._cache_timestamp: datetime | None = None
42
+ self._ttl_seconds: int = CACHE_TTL_SECONDS
43
+ self._initialized = True
44
+
45
+ logger.info(
46
+ "SystemPromptCache initialized with TTL=%d seconds",
47
+ self._ttl_seconds,
48
+ )
49
+
50
+ def _is_cache_valid(self) -> bool:
51
+ """Check if cached prompt is still valid based on TTL."""
52
+ if self._cached_prompt is None or self._cache_timestamp is None:
53
+ return False
54
+
55
+ elapsed = datetime.now(UTC) - self._cache_timestamp
56
+ is_valid = elapsed < timedelta(seconds=self._ttl_seconds)
57
+
58
+ if not is_valid:
59
+ logger.debug("Cache expired after %s seconds", elapsed.total_seconds())
60
+
61
+ return is_valid
62
+
63
+ async def get_prompt(self) -> str:
64
+ """Get the latest system prompt (from cache or database).
65
+
66
+ Returns:
67
+ The current system prompt text.
68
+
69
+ Raises:
70
+ ValueError: If no system prompt exists in database.
71
+ """
72
+ async with self._lock:
73
+ if self._is_cache_valid():
74
+ logger.debug(
75
+ "Cache hit: version=%d, age=%s seconds",
76
+ self._cached_version,
77
+ (datetime.now(UTC) - self._cache_timestamp).total_seconds(),
78
+ )
79
+ return self._cached_prompt
80
+
81
+ # Cache miss or expired - fetch from database
82
+ logger.info("Cache miss - fetching latest prompt from database")
83
+
84
+ latest_prompt = await SystemPromptRepository.get_latest()
85
+
86
+ if latest_prompt is None:
87
+ msg = "No system prompt found in database"
88
+ logger.error(msg)
89
+ raise ValueError(msg)
90
+
91
+ self._cached_prompt = latest_prompt.prompt
92
+ self._cached_version = latest_prompt.version
93
+ self._cache_timestamp = datetime.now(UTC)
94
+
95
+ logger.info(
96
+ "Cached prompt version %d (%d characters)",
97
+ self._cached_version,
98
+ len(self._cached_prompt),
99
+ )
100
+
101
+ return self._cached_prompt
102
+
103
+ async def invalidate(self) -> None:
104
+ """Manually invalidate the cache.
105
+
106
+ Use this when a new prompt version is created to force
107
+ immediate reload on next access.
108
+ """
109
+ async with self._lock:
110
+ if self._cached_prompt is not None:
111
+ logger.info(
112
+ "Cache invalidated (was version %d)",
113
+ self._cached_version,
114
+ )
115
+ self._cached_prompt = None
116
+ self._cached_version = None
117
+ self._cache_timestamp = None
118
+ else:
119
+ logger.debug("Cache invalidation called but cache was empty")
120
+
121
+ def set_ttl(self, seconds: int) -> None:
122
+ """Update cache TTL.
123
+
124
+ Args:
125
+ seconds: New TTL in seconds.
126
+ """
127
+ self._ttl_seconds = seconds
128
+ logger.info("Cache TTL updated to %d seconds", seconds)
129
+
130
+ @property
131
+ def is_cached(self) -> bool:
132
+ """Check if prompt is currently cached and valid."""
133
+ return self._is_cache_valid()
134
+
135
+ @property
136
+ def cached_version(self) -> int | None:
137
+ """Get the currently cached prompt version (if any)."""
138
+ return self._cached_version if self._is_cache_valid() else None
139
+
140
+
141
# Global cache instance. SystemPromptCache.__new__ makes the class a singleton,
# so any later SystemPromptCache() call returns this same object.
_prompt_cache = SystemPromptCache()
143
+
144
+
145
async def get_system_prompt() -> str:
    """Return the current system prompt text via the shared module cache.

    Raises:
        ValueError: Propagated from the cache when it must reload and the
            database contains no system prompt.
    """
    prompt_text = await _prompt_cache.get_prompt()
    return prompt_text
152
+
153
+
154
async def invalidate_prompt_cache() -> None:
    """Drop the shared cached prompt so the next read reloads it from the DB."""
    cache = _prompt_cache
    await cache.invalidate()
157
+
158
+
159
def get_cache_instance() -> SystemPromptCache:
    """Get the global cache instance for advanced usage.

    Returns the module-level singleton, e.g. to call ``set_ttl`` or read
    ``is_cached`` / ``cached_version``.
    """
    return _prompt_cache
@@ -12,10 +12,8 @@ from appkit_assistant.backend.models import (
12
12
  ThreadModel,
13
13
  ThreadStatus,
14
14
  )
15
- from appkit_assistant.state.thread_state import (
16
- ThreadState,
17
- ThreadListState,
18
- )
15
+ from appkit_assistant.state.thread_list_state import ThreadListState
16
+ from appkit_assistant.state.thread_state import ThreadState
19
17
  from appkit_assistant.components.mcp_server_table import mcp_servers_table
20
18
 
21
19
  __all__ = [
@@ -258,10 +258,15 @@ def add_mcp_server_button() -> rx.Component:
258
258
  rx.dialog.trigger(
259
259
  rx.button(
260
260
  rx.icon("plus"),
261
- rx.text("Neuen MCP Server anlegen", display=["none", "none", "block"]),
262
- size="3",
261
+ rx.text(
262
+ "Neuen MCP Server anlegen",
263
+ display=["none", "none", "block"],
264
+ size="2",
265
+ ),
266
+ size="2",
263
267
  variant="solid",
264
268
  on_click=[ValidationState.initialize(server=None)],
269
+ margin_bottom="15px",
265
270
  ),
266
271
  ),
267
272
  rx.dialog.content(
@@ -54,7 +54,7 @@ class MessageComponent:
54
54
 
55
55
  # Show thinking content only for the last assistant message
56
56
  should_show_thinking = (
57
- message.text == ThreadState.last_assistant_message_text
57
+ message.text == ThreadState.get_last_assistant_message_text
58
58
  ) & ThreadState.has_thinking_content
59
59
 
60
60
  # Main content area with all components
@@ -74,9 +74,9 @@ class MessageComponent:
74
74
  ),
75
75
  title="Denkprozess & Werkzeuge",
76
76
  info_text=(
77
- f"{ThreadState.unique_reasoning_sessions.length()} "
77
+ f"{ThreadState.get_unique_reasoning_sessions.length()} "
78
78
  f"Nachdenken, "
79
- f"{ThreadState.unique_tool_calls.length()} Werkzeuge"
79
+ f"{ThreadState.get_unique_tool_calls.length()} Werkzeuge"
80
80
  ),
81
81
  show_condition=should_show_thinking,
82
82
  expanded=ThreadState.thinking_expanded,