appkit-assistant 0.9.0__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in the public registry.
@@ -1,11 +1,34 @@
-from datetime import datetime
+import json
+from datetime import UTC, datetime
 from enum import StrEnum
+from typing import Any
 
 import reflex as rx
 from pydantic import BaseModel
+from sqlalchemy.sql import func
 from sqlmodel import Column, DateTime, Field
 
+from appkit_commons.database.configuration import DatabaseConfig
 from appkit_commons.database.entities import EncryptedString
+from appkit_commons.registry import service_registry
+
+db_config = service_registry().get(DatabaseConfig)
+SECRET_VALUE = db_config.encryption_key.get_secret_value()
+
+
+class EncryptedJSON(EncryptedString):
+    """Custom type for storing encrypted JSON data."""
+
+    def process_bind_param(self, value: Any, dialect: Any) -> str | None:
+        if value is not None:
+            value = json.dumps(value)
+        return super().process_bind_param(value, dialect)
+
+    def process_result_value(self, value: Any, dialect: Any) -> Any | None:
+        value = super().process_result_value(value, dialect)
+        if value is not None:
+            return json.loads(value)
+        return value
 
 
 class ChunkType(StrEnum):
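The new EncryptedJSON type layers JSON serialization on top of the existing EncryptedString column type: values are dumped to JSON before encryption and parsed back after decryption. A minimal sketch of the round trip it implies (illustrative only; the payload shape is assumed, and the actual encryption step is handled by EncryptedString):

    import json

    payload = [{"type": "user", "text": "Hallo"}]
    bound = json.dumps(payload)    # process_bind_param: Python value -> JSON string (then encrypted)
    restored = json.loads(bound)   # process_result_value: decrypted JSON string -> Python value
    assert restored == payload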
@@ -40,6 +63,7 @@ class ThreadStatus(StrEnum):
     ACTIVE = "active"
     IDLE = "idle"
     WAITING = "waiting"
+    ERROR = "error"
     DELETED = "deleted"
     ARCHIVED = "archived"
 
@@ -104,7 +128,7 @@ class MCPAuthType(StrEnum):
 class MCPServer(rx.Model, table=True):
     """Model for MCP (Model Context Protocol) server configuration."""
 
-    __tablename__ = "mcp_server"
+    __tablename__ = "assistant_mcp_servers"
 
     id: int | None = Field(default=None, primary_key=True)
     name: str = Field(unique=True, max_length=100, nullable=False)
@@ -139,7 +163,7 @@ class SystemPrompt(rx.Model, table=True):
     Each save creates a new immutable version. Supports up to 20,000 characters.
     """
 
-    __tablename__ = "system_prompt"
+    __tablename__ = "assistant_system_prompt"
 
    id: int | None = Field(default=None, primary_key=True)
    name: str = Field(max_length=200, nullable=False)
@@ -147,3 +171,26 @@ class SystemPrompt(rx.Model, table=True):
     version: int = Field(nullable=False)
     user_id: int = Field(nullable=False)
     created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
+
+
+class AssistantThread(rx.Model, table=True):
+    """Model for storing chat threads in the database."""
+
+    __tablename__ = "assistant_thread"
+
+    id: int | None = Field(default=None, primary_key=True)
+    thread_id: str = Field(unique=True, index=True, nullable=False)
+    user_id: int = Field(index=True, nullable=False)
+    title: str = Field(default="", nullable=False)
+    state: str = Field(default=ThreadStatus.NEW, nullable=False)
+    ai_model: str = Field(default="", nullable=False)
+    active: bool = Field(default=False, nullable=False)
+    messages: list[dict[str, Any]] = Field(default=[], sa_column=Column(EncryptedJSON))
+    created_at: datetime = Field(
+        default_factory=lambda: datetime.now(UTC),
+        sa_column=Column(DateTime(timezone=True)),
+    )
+    updated_at: datetime = Field(
+        default_factory=lambda: datetime.now(UTC),
+        sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
+    )
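The new assistant_thread table stores one row per chat thread; the message history lives in the EncryptedJSON messages column and updated_at is maintained via onupdate=func.now(). A hedged construction sketch (field values and the message dict shape are invented for illustration):

    # Illustrative only; values are made up.
    row = AssistantThread(
        thread_id="t-0001",
        user_id=42,
        title="Erster Chat",
        state=ThreadStatus.ACTIVE,
        ai_model="example-model",
        active=True,
        messages=[{"type": "user", "text": "Hallo"}],  # encrypted at rest via EncryptedJSON
    )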
@@ -454,7 +454,10 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
         return tools, prompt_string
 
     async def _convert_messages_to_responses_format(
-        self, messages: list[Message], mcp_prompt: str = ""
+        self,
+        messages: list[Message],
+        mcp_prompt: str = "",
+        use_system_prompt: bool = True,
     ) -> list[dict[str, Any]]:
         """Convert messages to the responses API input format.
 
@@ -471,14 +474,15 @@ class OpenAIResponsesProcessor(BaseOpenAIProcessor):
         else:
             mcp_prompt = ""
 
-        system_prompt_template = await get_system_prompt()
-        system_text = system_prompt_template.format(mcp_prompts=mcp_prompt)
-        input_messages.append(
-            {
-                "role": "system",
-                "content": [{"type": "input_text", "text": system_text}],
-            }
-        )
+        if use_system_prompt:
+            system_prompt_template = await get_system_prompt()
+            system_text = system_prompt_template.format(mcp_prompts=mcp_prompt)
+            input_messages.append(
+                {
+                    "role": "system",
+                    "content": [{"type": "input_text", "text": system_text}],
+                }
+            )
 
         # Add conversation messages
         for msg in messages:
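_convert_messages_to_responses_format now accepts a use_system_prompt flag, so callers can suppress the injected system message while keeping the rest of the conversion unchanged. A hedged sketch of the two call patterns (the processor and messages objects are assumed to exist):

    # Illustrative only.
    with_system = await processor._convert_messages_to_responses_format(messages)
    without_system = await processor._convert_messages_to_responses_format(
        messages, use_system_prompt=False
    )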
@@ -4,8 +4,16 @@ import logging
 from datetime import UTC, datetime
 
 import reflex as rx
+from sqlalchemy.orm import defer
 
-from appkit_assistant.backend.models import MCPServer, SystemPrompt
+from appkit_assistant.backend.models import (
+    AssistantThread,
+    MCPServer,
+    Message,
+    SystemPrompt,
+    ThreadModel,
+    ThreadStatus,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -181,3 +189,135 @@ class SystemPromptRepository:
                 prompt_id,
             )
             return False
+
+
+class ThreadRepository:
+    """Repository class for Thread database operations."""
+
+    @staticmethod
+    async def get_by_user(user_id: int) -> list[ThreadModel]:
+        """Retrieve all threads for a user."""
+        async with rx.asession() as session:
+            result = await session.exec(
+                AssistantThread.select()
+                .where(AssistantThread.user_id == user_id)
+                .order_by(AssistantThread.updated_at.desc())
+            )
+            threads = result.all()
+            return [
+                ThreadModel(
+                    thread_id=t.thread_id,
+                    title=t.title,
+                    state=ThreadStatus(t.state),
+                    ai_model=t.ai_model,
+                    active=t.active,
+                    messages=[Message(**m) for m in t.messages],
+                )
+                for t in threads
+            ]
+
+    @staticmethod
+    async def save_thread(thread: ThreadModel, user_id: int) -> None:
+        """Save or update a thread."""
+        async with rx.asession() as session:
+            result = await session.exec(
+                AssistantThread.select().where(
+                    AssistantThread.thread_id == thread.thread_id
+                )
+            )
+            db_thread = result.first()
+
+            messages_dict = [m.dict() for m in thread.messages]
+
+            if db_thread:
+                # Ensure user owns the thread or handle shared threads logic if needed
+                # For now, we assume thread_id is unique enough,
+                # but checking user_id is safer
+                if db_thread.user_id != user_id:
+                    logger.warning(
+                        "User %s tried to update thread %s belonging to user %s",
+                        user_id,
+                        thread.thread_id,
+                        db_thread.user_id,
+                    )
+                    return
+
+                db_thread.title = thread.title
+                db_thread.state = thread.state.value
+                db_thread.ai_model = thread.ai_model
+                db_thread.active = thread.active
+                db_thread.messages = messages_dict
+                session.add(db_thread)
+            else:
+                db_thread = AssistantThread(
+                    thread_id=thread.thread_id,
+                    user_id=user_id,
+                    title=thread.title,
+                    state=thread.state.value,
+                    ai_model=thread.ai_model,
+                    active=thread.active,
+                    messages=messages_dict,
+                )
+                session.add(db_thread)
+
+            await session.commit()
+
+    @staticmethod
+    async def delete_thread(thread_id: str, user_id: int) -> None:
+        """Delete a thread."""
+        async with rx.asession() as session:
+            result = await session.exec(
+                AssistantThread.select().where(
+                    AssistantThread.thread_id == thread_id,
+                    AssistantThread.user_id == user_id,
+                )
+            )
+            thread = result.first()
+            if thread:
+                await session.delete(thread)
+                await session.commit()
+
+    @staticmethod
+    async def get_summaries_by_user(user_id: int) -> list[ThreadModel]:
+        """Retrieve thread summaries (no messages) for a user."""
+        async with rx.asession() as session:
+            result = await session.exec(
+                AssistantThread.select()
+                .where(AssistantThread.user_id == user_id)
+                .options(defer(AssistantThread.messages))
+                .order_by(AssistantThread.updated_at.desc())
+            )
+            threads = result.all()
+            return [
+                ThreadModel(
+                    thread_id=t.thread_id,
+                    title=t.title,
+                    state=ThreadStatus(t.state),
+                    ai_model=t.ai_model,
+                    active=t.active,
+                    messages=[],  # Empty messages for summary
+                )
+                for t in threads
+            ]
+
+    @staticmethod
+    async def get_thread_by_id(thread_id: str, user_id: int) -> ThreadModel | None:
+        """Retrieve a full thread by ID."""
+        async with rx.asession() as session:
+            result = await session.exec(
+                AssistantThread.select().where(
+                    AssistantThread.thread_id == thread_id,
+                    AssistantThread.user_id == user_id,
+                )
+            )
+            t = result.first()
+            if not t:
+                return None
+            return ThreadModel(
+                thread_id=t.thread_id,
+                title=t.title,
+                state=ThreadStatus(t.state),
+                ai_model=t.ai_model,
+                active=t.active,
+                messages=[Message(**m) for m in t.messages],
+            )
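ThreadRepository centralizes AssistantThread persistence behind static async methods that are always scoped to a user_id; get_summaries_by_user additionally defers the encrypted messages column so the thread list can load without decrypting full histories. A hedged usage sketch (the ThreadModel instance is assumed to exist):

    # Illustrative only.
    async def example(thread: ThreadModel, user_id: int) -> list[ThreadModel]:
        await ThreadRepository.save_thread(thread, user_id)  # insert, or ownership-checked update
        restored = await ThreadRepository.get_thread_by_id(thread.thread_id, user_id)
        assert restored is not None and restored.thread_id == thread.thread_id
        return await ThreadRepository.get_summaries_by_user(user_id)  # messages column deferred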
@@ -12,10 +12,8 @@ from appkit_assistant.backend.models import (
     ThreadModel,
     ThreadStatus,
 )
-from appkit_assistant.state.thread_state import (
-    ThreadState,
-    ThreadListState,
-)
+from appkit_assistant.state.thread_list_state import ThreadListState
+from appkit_assistant.state.thread_state import ThreadState
 from appkit_assistant.components.mcp_server_table import mcp_servers_table
 
 __all__ = [
@@ -258,10 +258,15 @@ def add_mcp_server_button() -> rx.Component:
         rx.dialog.trigger(
             rx.button(
                 rx.icon("plus"),
-                rx.text("Neuen MCP Server anlegen", display=["none", "none", "block"]),
-                size="3",
+                rx.text(
+                    "Neuen MCP Server anlegen",
+                    display=["none", "none", "block"],
+                    size="2",
+                ),
+                size="2",
                 variant="solid",
                 on_click=[ValidationState.initialize(server=None)],
+                margin_bottom="15px",
             ),
         ),
         rx.dialog.content(
@@ -54,7 +54,7 @@ class MessageComponent:
 
         # Show thinking content only for the last assistant message
         should_show_thinking = (
-            message.text == ThreadState.last_assistant_message_text
+            message.text == ThreadState.get_last_assistant_message_text
         ) & ThreadState.has_thinking_content
 
         # Main content area with all components
@@ -74,9 +74,9 @@ class MessageComponent:
             ),
             title="Denkprozess & Werkzeuge",
             info_text=(
-                f"{ThreadState.unique_reasoning_sessions.length()} "
+                f"{ThreadState.get_unique_reasoning_sessions.length()} "
                 f"Nachdenken, "
-                f"{ThreadState.unique_tool_calls.length()} Werkzeuge"
+                f"{ThreadState.get_unique_tool_calls.length()} Werkzeuge"
             ),
             show_condition=should_show_thinking,
             expanded=ThreadState.thinking_expanded,
@@ -4,15 +4,11 @@ from collections.abc import Callable
 import reflex as rx
 
 import appkit_mantine as mn
+from appkit_assistant.backend.models import Message, MessageType
 from appkit_assistant.components import composer
 from appkit_assistant.components.message import MessageComponent
 from appkit_assistant.components.threadlist import ThreadList
-from appkit_assistant.state.thread_state import (
-    Message,
-    MessageType,
-    ThreadListState,
-    ThreadState,
-)
+from appkit_assistant.state.thread_state import ThreadState
 
 logger = logging.getLogger(__name__)
 
@@ -55,7 +51,7 @@ class Assistant:
                 lambda suggestion: Assistant.suggestion(
                     prompt=suggestion.prompt,
                     icon=suggestion.icon,
-                    update_prompt=ThreadState.update_prompt,
+                    update_prompt=ThreadState.set_prompt,
                 ),
             ),
             spacing="4",
@@ -122,7 +118,7 @@ class Assistant:
             ),
             rx.hstack(
                 composer.tools(
-                    show=with_tools and ThreadState.current_model_supports_tools
+                    show=with_tools and ThreadState.selected_model_supports_tools
                 ),
                 composer.add_attachment(show=with_attachments),
                 composer.clear(show=with_clear),
@@ -156,9 +152,6 @@ class Assistant:
         # if suggestions is not None:
         #     ThreadState.set_suggestions(suggestions)
 
-        if with_thread_list:
-            ThreadState.with_thread_list = with_thread_list
-
         return rx.flex(
             rx.cond(
                 ThreadState.messages,
@@ -207,7 +200,10 @@ class Assistant:
                 spacing="0",
                 flex_shrink=0,
                 z_index=1000,
-                on_mount=ThreadState.load_available_mcp_servers,
+                on_mount=[
+                    ThreadState.set_with_thread_list(with_thread_list),
+                    ThreadState.load_mcp_servers,
+                ],
             ),
             **props,
         )
@@ -216,12 +212,8 @@ class Assistant:
     def thread_list(
         *items,
         with_footer: bool = False,
-        default_model: str | None = None,
         **props,
     ) -> rx.Component:
-        if default_model:
-            ThreadListState.default_model = default_model
-
        return rx.flex(
            rx.flex(
                ThreadList.header(
@@ -1,6 +1,8 @@
 import reflex as rx
 
-from appkit_assistant.state.thread_state import ThreadListState, ThreadModel
+from appkit_assistant.backend.models import ThreadModel
+from appkit_assistant.state.thread_list_state import ThreadListState
+from appkit_assistant.state.thread_state import ThreadState
 
 
 class ThreadList:
@@ -13,7 +15,7 @@ class ThreadList:
                 rx.text(title),
                 size="2",
                 margin_right="28px",
-                on_click=ThreadListState.create_thread(),
+                on_click=ThreadState.new_thread(),
                 width="95%",
             ),
             content="Neuen Chat starten",
@@ -46,24 +48,28 @@ class ThreadList:
                 min_width="0",
                 title=thread.title,
             ),
-            rx.tooltip(
-                rx.button(
-                    rx.icon(
-                        "trash",
-                        size=13,
-                        stroke_width=1.5,
+            rx.cond(
+                ThreadListState.loading_thread_id == thread.thread_id,
+                rx.spinner(size="1", margin_left="6px", margin_right="6px"),
+                rx.tooltip(
+                    rx.button(
+                        rx.icon(
+                            "trash",
+                            size=13,
+                            stroke_width=1.5,
+                        ),
+                        variant="ghost",
+                        size="1",
+                        margin_left="0px",
+                        margin_right="0px",
+                        color_scheme="gray",
+                        on_click=ThreadListState.delete_thread(thread.thread_id),
                     ),
-                    variant="ghost",
-                    size="1",
-                    margin_left="0px",
-                    margin_right="0px",
-                    color_scheme="gray",
-                    on_click=ThreadListState.delete_thread(thread.thread_id),
+                    content="Chat löschen",
+                    flex_shrink=0,
                 ),
-                content="Chat löschen",
-                flex_shrink=0,
             ),
-            on_click=ThreadListState.select_thread(thread.thread_id),
+            on_click=ThreadState.load_thread(thread.thread_id),
             flex_direction=["row"],
             margin_right="10px",
             margin_bottom="8px",
@@ -113,18 +119,25 @@ class ThreadList:
                     ThreadListState.threads,
                     ThreadList.thread_list_item,
                 ),
-                rx.text(
-                    "Keine Chats vorhanden.",
-                    size="2",
-                    white_space="nowrap",
-                    overflow="hidden",
-                    text_overflow="ellipsis",
-                    flex_grow="1",
-                    min_width="0",
-                    margin_right="10px",
-                    margin_bottom="8px",
-                    padding="6px",
-                    align="center",
+                rx.cond(
+                    ThreadListState.loading,
+                    rx.vstack(
+                        rx.skeleton(height="34px", width="210px"),
+                        rx.skeleton(height="34px", width="210px"),
+                        rx.skeleton(height="34px", width="210px"),
+                        spacing="2",
+                    ),
+                    rx.text(
+                        "Noch keine Chats vorhanden. "
+                        "Klicke auf 'Neuer Chat', um zu beginnen.",
+                        size="2",
+                        flex_grow="1",
+                        min_width="0",
+                        margin_right="10px",
+                        margin_bottom="8px",
+                        padding="6px",
+                        align="center",
+                    ),
                 ),
             ),
             scrollbars="vertical",
@@ -92,6 +92,6 @@ def tools_popover() -> rx.Component:
            side="top",
        ),
        open=ThreadState.show_tools_modal,
-       on_open_change=ThreadState.set_show_tools_modal,
+       on_open_change=ThreadState.toogle_tools_modal,
        placement="bottom-start",
    )
@@ -8,3 +8,4 @@ class AssistantConfig(BaseConfig):
    openai_base_url: str | None = None
    openai_api_key: SecretStr | None = None
    google_api_key: SecretStr | None = None
+   azure_ai_projects_endpoint: str | None = None
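AssistantConfig picks up an optional Azure AI Projects endpoint alongside the existing OpenAI and Google settings. A minimal sketch of providing it, assuming the config object can be constructed directly (how BaseConfig actually sources its values is not shown in this diff):

    # Illustrative only; the endpoint URL is a placeholder.
    config = AssistantConfig(
        azure_ai_projects_endpoint="https://<your-resource>.services.ai.azure.com/api/projects/<project>",
    )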
@@ -25,7 +25,6 @@ class SystemPromptState(State):
    is_loading: bool = False
    error_message: str = ""
    char_count: int = 0
-   # Trigger to force textarea update when selecting a version
    textarea_key: int = 0

    async def load_versions(self) -> None:
@@ -45,12 +44,10 @@ class SystemPromptState(State):
                for p in prompts
            ]

-           # Populate map for fast switching
            self.prompt_map = {str(p.version): p.prompt for p in prompts}

            if prompts:
                latest = prompts[0]
-               # Automatically select the latest version
                self.selected_version_id = latest.version

                if not self.current_prompt:
@@ -64,8 +61,9 @@ class SystemPromptState(State):
                self.current_prompt = ""
                self.last_saved_prompt = self.current_prompt

-           # Zähler initial setzen
            self.char_count = len(self.current_prompt)
+           # Force textarea to re-render with loaded content
+           self.textarea_key += 1

            logger.info("Loaded %s system prompt versions", len(self.versions))
        except Exception as exc: