lollms-client 0.20.10__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lollms-client might be problematic. Click here for more details.
- examples/console_discussion.py +207 -0
- examples/gradio_lollms_chat.py +259 -0
- examples/lollms_discussions_test.py +155 -0
- lollms_client/__init__.py +3 -3
- lollms_client/llm_bindings/ollama/__init__.py +1 -1
- lollms_client/lollms_core.py +83 -1
- lollms_client/lollms_discussion.py +578 -357
- lollms_client/lollms_types.py +19 -16
- lollms_client/lollms_utilities.py +71 -57
- lollms_client/mcp_bindings/remote_mcp/__init__.py +2 -1
- {lollms_client-0.20.10.dist-info → lollms_client-0.21.0.dist-info}/METADATA +1 -1
- {lollms_client-0.20.10.dist-info → lollms_client-0.21.0.dist-info}/RECORD +15 -15
- examples/personality_test/chat_test.py +0 -37
- examples/personality_test/chat_with_aristotle.py +0 -42
- examples/personality_test/tesks_test.py +0 -62
- {lollms_client-0.20.10.dist-info → lollms_client-0.21.0.dist-info}/WHEEL +0 -0
- {lollms_client-0.20.10.dist-info → lollms_client-0.21.0.dist-info}/licenses/LICENSE +0 -0
- {lollms_client-0.20.10.dist-info → lollms_client-0.21.0.dist-info}/top_level.txt +0 -0
|
@@ -1,412 +1,633 @@
|
|
|
1
|
-
# lollms_discussion.py
|
|
2
|
-
|
|
3
1
|
import yaml
|
|
4
|
-
|
|
5
|
-
|
|
2
|
+
import json
|
|
3
|
+
import base64
|
|
4
|
+
import os
|
|
6
5
|
import uuid
|
|
6
|
+
import shutil
|
|
7
7
|
from collections import defaultdict
|
|
8
|
+
from datetime import datetime
|
|
9
|
+
from typing import List, Dict, Optional, Union, Any, Type, Callable
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
|
|
12
|
+
from sqlalchemy import (create_engine, Column, String, Text, Integer, DateTime,
|
|
13
|
+
ForeignKey, JSON, Boolean, LargeBinary, Index)
|
|
14
|
+
from sqlalchemy.orm import sessionmaker, relationship, Session, declarative_base
|
|
15
|
+
from sqlalchemy.types import TypeDecorator
|
|
16
|
+
|
|
17
|
+
try:
|
|
18
|
+
from cryptography.fernet import Fernet, InvalidToken
|
|
19
|
+
from cryptography.hazmat.primitives import hashes
|
|
20
|
+
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
|
21
|
+
from cryptography.hazmat.backends import default_backend
|
|
22
|
+
ENCRYPTION_AVAILABLE = True
|
|
23
|
+
except ImportError:
|
|
24
|
+
ENCRYPTION_AVAILABLE = False
|
|
25
|
+
|
|
26
|
+
if False:
|
|
27
|
+
from lollms_client import LollmsClient
|
|
28
|
+
from lollms_client.lollms_types import MSG_TYPE
|
|
29
|
+
|
|
30
|
+
class EncryptedString(TypeDecorator):
    """SQLAlchemy column type that transparently Fernet-encrypts text values.

    Python ``str`` values are encrypted when written (stored through the
    ``LargeBinary`` impl) and decrypted when read back. The Fernet key is
    derived from the passphrase with PBKDF2-HMAC-SHA256 (480000 iterations).

    Args:
        key: passphrase the Fernet key is derived from.
        *args, **kwargs: forwarded to ``TypeDecorator`` / the ``LargeBinary`` impl.
        salt: PBKDF2 salt. ``None`` (default) uses the historical fixed salt so
            databases written by earlier versions remain readable. Passing a
            per-database salt means identical passphrases no longer yield
            identical keys across databases.

    Raises:
        ImportError: if the optional ``cryptography`` package is unavailable.
    """

    impl = LargeBinary
    cache_ok = True

    # Historical salt. Changing this default would make every existing
    # encrypted database undecryptable, so it is only overridable per instance.
    DEFAULT_SALT = b'lollms-fixed-salt-for-db-encryption'

    def __init__(self, key: str, *args, salt: Optional[bytes] = None, **kwargs):
        super().__init__(*args, **kwargs)
        if not ENCRYPTION_AVAILABLE:
            raise ImportError("'cryptography' is required for DB encryption.")
        self.salt = self.DEFAULT_SALT if salt is None else salt
        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(), length=32, salt=self.salt,
            iterations=480000, backend=default_backend(),
        )
        derived_key = base64.urlsafe_b64encode(kdf.derive(key.encode()))
        self.fernet = Fernet(derived_key)

    def process_bind_param(self, value: Optional[str], dialect) -> Optional[bytes]:
        """Encrypt a string on its way into the database; ``None`` passes through."""
        if value is None:
            return None
        return self.fernet.encrypt(value.encode('utf-8'))

    def process_result_value(self, value: Optional[bytes], dialect) -> Optional[str]:
        """Decrypt a stored value; ``None`` passes through.

        A wrong key or corrupted ciphertext does not raise: a placeholder string
        is returned instead so the rest of the data stays readable.
        """
        if value is None:
            return None
        try:
            return self.fernet.decrypt(value).decode('utf-8')
        except InvalidToken:
            return "<DECRYPTION_FAILED: Invalid Key or Corrupt Data>"
|
|
55
|
+
|
|
56
|
+
def create_dynamic_models(discussion_mixin: Optional[Type] = None, message_mixin: Optional[Type] = None, encryption_key: Optional[str] = None):
    """Build a fresh declarative Base plus ``Discussion`` and ``Message`` ORM models.

    Args:
        discussion_mixin: optional class used as an extra base of the Discussion
            model; its ``Column`` attributes and ``__table_args__`` are added to
            the ``discussions`` table.
        message_mixin: same, for the Message model / ``messages`` table.
        encryption_key: when given, ``system_prompt`` and message ``content``
            use ``EncryptedString`` instead of ``Text``.

    Returns:
        ``(Base, Discussion, Message)``. Every call creates its own Base (and
        therefore its own MetaData and Table objects).
    """
    from sqlalchemy.orm import declared_attr

    Base = declarative_base()
    EncryptedText = EncryptedString(encryption_key) if encryption_key else Text

    class DiscussionBase(Base):
        __abstract__ = True
        id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
        system_prompt = Column(EncryptedText, nullable=True)
        participants = Column(JSON, nullable=True, default=dict)
        active_branch_id = Column(String, nullable=True)
        discussion_metadata = Column(JSON, nullable=True, default=dict)
        created_at = Column(DateTime, default=datetime.utcnow)
        updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    class MessageBase(Base):
        __abstract__ = True
        id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))

        # ForeignKey columns on an abstract (mixin-like) base are declared via
        # declared_attr: SQLAlchemy 1.4 refuses to copy plain FK Columns from
        # mixins/abstract bases, and declared_attr works on every version.
        @declared_attr
        def discussion_id(cls):
            return Column(String, ForeignKey('discussions.id'), nullable=False)

        @declared_attr
        def parent_id(cls):
            return Column(String, ForeignKey('messages.id'), nullable=True)

        sender = Column(String, nullable=False)
        sender_type = Column(String, nullable=False)
        content = Column(EncryptedText, nullable=False)
        message_metadata = Column(JSON, nullable=True, default=dict)
        images = Column(JSON, nullable=True, default=list)
        created_at = Column(DateTime, default=datetime.utcnow)

    def _own_column(col):
        # A Column can belong to only one Table. The first use of a mixin keeps
        # its Column as before; once that Column is bound to a table (e.g. a
        # second DatabaseManager with the same mixin), hand out a copy instead.
        if getattr(col, "table", None) is None:
            return col
        copier = getattr(col, "_copy", None) or col.copy
        return copier()

    def _model_attrs(table_name, mixin):
        # Class namespace for a model: table name, optional __table_args__, and
        # the mixin's Column attributes.
        attrs = {'__tablename__': table_name}
        if hasattr(mixin, '__table_args__'):
            attrs['__table_args__'] = mixin.__table_args__
        if mixin:
            for attr, col in mixin.__dict__.items():
                if isinstance(col, Column):
                    attrs[attr] = _own_column(col)
        return attrs

    discussion_bases = (discussion_mixin, DiscussionBase) if discussion_mixin else (DiscussionBase,)
    DynamicDiscussion = type('Discussion', discussion_bases, _model_attrs('discussions', discussion_mixin))

    message_bases = (message_mixin, MessageBase) if message_mixin else (MessageBase,)
    DynamicMessage = type('Message', message_bases, _model_attrs('messages', message_mixin))

    DynamicDiscussion.messages = relationship(DynamicMessage, back_populates="discussion", cascade="all, delete-orphan", lazy="joined")
    DynamicMessage.discussion = relationship(DynamicDiscussion, back_populates="messages")

    return Base, DynamicDiscussion, DynamicMessage
|
|
39
108
|
|
|
109
|
+
class DatabaseManager:
    """SQLAlchemy storage for discussions and their messages.

    Builds the ORM models with ``create_dynamic_models``, creates the engine and
    session factory for ``db_path``, and creates any missing tables. Each helper
    that opens its own session closes it in a ``finally`` block, so a failing
    query does not leak a connection.
    """

    def __init__(self, db_path: str, discussion_mixin: Optional[Type] = None, message_mixin: Optional[Type] = None, encryption_key: Optional[str] = None):
        """
        Args:
            db_path: database URL passed verbatim to ``create_engine``.
            discussion_mixin: optional mixin extending the ``discussions`` table.
            message_mixin: optional mixin extending the ``messages`` table.
            encryption_key: if given, system prompts and message contents are
                stored encrypted.

        Raises:
            ValueError: if ``db_path`` is empty.
        """
        if not db_path:
            raise ValueError("Database path cannot be empty.")
        self.Base, self.DiscussionModel, self.MessageModel = create_dynamic_models(
            discussion_mixin, message_mixin, encryption_key
        )
        self.engine = create_engine(db_path)
        self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)
        self.create_tables()

    def create_tables(self):
        """Create the model tables that do not exist yet."""
        self.Base.metadata.create_all(bind=self.engine)

    def get_session(self) -> Session:
        """Return a new Session; the caller must close it."""
        return self.SessionLocal()

    @staticmethod
    def _row_to_dict(row) -> Dict:
        """Map every table column of an ORM row to its (decrypted) value."""
        return {c.name: getattr(row, c.name) for c in row.__table__.columns}

    def list_discussions(self) -> List[Dict]:
        """Return every discussion as a dict of its columns (messages not included)."""
        from sqlalchemy.orm import lazyload
        session = self.get_session()
        try:
            # The messages relationship is eager ("joined"); only discussion
            # columns are returned, so don't load (and decrypt) every message.
            rows = session.query(self.DiscussionModel).options(lazyload(self.DiscussionModel.messages)).all()
            return [self._row_to_dict(row) for row in rows]
        finally:
            session.close()

    def get_discussion(self, lollms_client: 'LollmsClient', discussion_id: str, **kwargs) -> Optional['LollmsDiscussion']:
        """Open the DB-backed discussion ``discussion_id``, or return None if it doesn't exist.

        Extra keyword arguments are forwarded to the ``LollmsDiscussion`` constructor.
        """
        session = self.get_session()
        try:
            # Existence check only: select the id instead of the row plus all messages.
            found = session.query(self.DiscussionModel.id).filter(self.DiscussionModel.id == discussion_id).first()
        finally:
            session.close()
        if found is None:
            return None
        return LollmsDiscussion(lollmsClient=lollms_client, discussion_id=discussion_id, db_manager=self, **kwargs)

    def search_discussions(self, **criteria) -> List[Dict]:
        """Return discussions whose attributes contain the given values (case-insensitive).

        Each keyword names a DiscussionModel attribute, matched with
        ``ILIKE '%value%'``; all criteria must match. An unknown attribute name
        raises AttributeError. Encrypted columns store ciphertext and will not
        match plaintext values.
        """
        from sqlalchemy.orm import lazyload
        session = self.get_session()
        try:
            query = session.query(self.DiscussionModel).options(lazyload(self.DiscussionModel.messages))
            for key, value in criteria.items():
                query = query.filter(getattr(self.DiscussionModel, key).ilike(f"%{value}%"))
            return [self._row_to_dict(row) for row in query.all()]
        finally:
            session.close()

    def delete_discussion(self, discussion_id: str):
        """Delete a discussion and, through the relationship cascade, its messages.

        Unknown ids are ignored.
        """
        session = self.get_session()
        try:
            db_disc = session.query(self.DiscussionModel).filter_by(id=discussion_id).first()
            if db_disc:
                session.delete(db_disc)
                session.commit()
        finally:
            session.close()
|
|
40
163
|
|
|
41
164
|
class LollmsDiscussion:
|
|
42
|
-
|
|
43
|
-
Manages a branching conversation tree, including system prompts, participants,
|
|
44
|
-
an internal knowledge scratchpad, and context pruning capabilities.
|
|
45
|
-
"""
|
|
46
|
-
|
|
47
|
-
def __init__(self, lollmsClient: 'LollmsClient'):
|
|
48
|
-
"""
|
|
49
|
-
Initializes a new LollmsDiscussion instance.
|
|
50
|
-
|
|
51
|
-
Args:
|
|
52
|
-
lollmsClient: An instance of LollmsClient, required for tokenization.
|
|
53
|
-
"""
|
|
165
|
+
def __init__(self, lollmsClient: 'LollmsClient', discussion_id: Optional[str] = None, db_manager: Optional[DatabaseManager] = None, autosave: bool = False, max_context_size: Optional[int] = None):
    """Create a discussion that lives in memory or is bound to a database row.

    Args:
        lollmsClient: client used for generation and token counting.
        discussion_id: required when ``db_manager`` is given (the row to load).
            Otherwise it is optional, and a fresh UUID4 string is used when omitted.
        db_manager: enables database-backed mode when provided.
        autosave: if True, ``touch()`` commits DB-backed changes immediately.
        max_context_size: when set, ``chat`` first calls ``summarize_and_prune``
            with this budget.

    Raises:
        ValueError: DB-backed mode without a ``discussion_id``.
        Any error raised while loading the row (e.g. an unknown id) propagates,
        after the session opened for it has been closed.
    """
    self.lollmsClient = lollmsClient
    self.db_manager = db_manager
    self.autosave = autosave
    self.max_context_size = max_context_size
    self._is_db_backed = db_manager is not None

    self.session = None
    self.db_discussion = None
    self._messages_to_delete = []

    self._reset_in_memory_state()

    if self._is_db_backed:
        if not discussion_id:
            raise ValueError("A discussion_id is required for database-backed discussions.")
        self.session = db_manager.get_session()
        try:
            self._load_from_db(discussion_id)
        except Exception:
            # Don't leak the session when the discussion cannot be loaded.
            self.session.close()
            self.session = None
            raise
    else:
        self.id = discussion_id or str(uuid.uuid4())
        self.created_at = datetime.utcnow()
        self.updated_at = self.created_at
|
|
89
186
|
|
|
90
|
-
def
|
|
91
|
-
|
|
92
|
-
self.
|
|
187
|
+
def _reset_in_memory_state(self):
    """Put every in-memory field of the discussion back to its blank default.

    New container objects are built on each call, so resets never share
    lists or dicts with earlier state or with other instances.
    """
    blank_fields = {
        "id": "",
        "system_prompt": None,
        "participants": {},
        "active_branch_id": None,
        "metadata": {},
        "scratchpad": "",
        "messages": [],
        "message_index": {},
        "created_at": None,
        "updated_at": None,
    }
    for field_name, blank_value in blank_fields.items():
        setattr(self, field_name, blank_value)
|
|
198
|
+
|
|
199
|
+
def _load_from_db(self, discussion_id: str):
    """Replace the in-memory state with discussion ``discussion_id`` and its messages.

    Uses ``Query.one()``, so an unknown id raises SQLAlchemy's NoResultFound.
    Each message becomes a plain dict of its columns, with ``message_metadata``
    exposed under the key ``metadata``. Messages are ordered by ``created_at``,
    because the relationship itself has no ORDER BY.
    """
    model = self.db_manager.DiscussionModel
    self.db_discussion = self.session.query(model).filter(model.id == discussion_id).one()
    row = self.db_discussion

    self.id = row.id
    self.system_prompt = row.system_prompt
    self.participants = row.participants or {}
    self.active_branch_id = row.active_branch_id
    self.metadata = row.discussion_metadata or {}
    # Previously left as None for DB-backed discussions.
    self.created_at = row.created_at
    self.updated_at = row.updated_at

    self.messages = []
    self.message_index = {}
    # Stable sort: rows without a timestamp go first, ties keep DB order.
    for msg in sorted(row.messages, key=lambda m: m.created_at or datetime.min):
        msg_dict = {c.name: getattr(msg, c.name) for c in msg.__table__.columns}
        if 'message_metadata' in msg_dict:
            msg_dict['metadata'] = msg_dict.pop('message_metadata')
        self.messages.append(msg_dict)
        self.message_index[msg.id] = msg_dict
|
|
216
|
+
|
|
217
|
+
def commit(self):
    """Write the in-memory discussion state to the database.

    This is a no-op for in-memory discussions. The steps are:

    1. Copy discussion-level fields onto the ORM row.
    2. Delete the messages queued in ``_messages_to_delete``.
    3. Insert or update every message in ``self.messages``, matched by ``id``.
    4. Commit the session.

    On failure the session is rolled back, the deletion queue is kept so a retry
    can re-apply it, and the exception propagates.
    """
    if not self._is_db_backed or not self.session:
        return

    from sqlalchemy.orm.attributes import flag_modified

    message_model = self.db_manager.MessageModel
    column_names = {c.name for c in message_model.__table__.columns}
    try:
        if self.db_discussion:
            self.db_discussion.system_prompt = self.system_prompt
            self.db_discussion.participants = self.participants
            self.db_discussion.active_branch_id = self.active_branch_id
            self.db_discussion.discussion_metadata = self.metadata
            self.db_discussion.updated_at = datetime.utcnow()
            # These dicts may be the very objects loaded from the row. Plain JSON
            # columns don't track in-place mutation, so force them to be written.
            flag_modified(self.db_discussion, 'participants')
            flag_modified(self.db_discussion, 'discussion_metadata')

        for msg_id in self._messages_to_delete:
            msg_to_del = self.session.query(message_model).filter_by(id=msg_id).first()
            if msg_to_del:
                self.session.delete(msg_to_del)

        for msg_data in self.messages:
            # Rename 'metadata' to the column name on a COPY. Mutating msg_data
            # would leave in-memory messages keyed differently from loaded ones.
            row_data = dict(msg_data)
            if 'metadata' in row_data:
                row_data['message_metadata'] = row_data.pop('metadata')

            msg_orm = self.session.query(message_model).filter_by(id=row_data['id']).first()
            if msg_orm is None:
                self.session.add(message_model(**{k: v for k, v in row_data.items() if k in column_names}))
            else:
                for key, value in row_data.items():
                    if hasattr(msg_orm, key):
                        setattr(msg_orm, key, value)

        self.session.commit()
    except Exception:
        self.session.rollback()
        raise
    self._messages_to_delete.clear()
|
|
251
|
+
|
|
252
|
+
def touch(self):
    """Stamp the discussion as modified now (UTC) and autosave when enabled.

    Only DB-backed discussions with ``autosave`` set are committed here.
    """
    self.updated_at = datetime.utcnow()
    if not (self._is_db_backed and self.autosave):
        return
    self.commit()
|
|
256
|
+
|
|
257
|
+
@classmethod
def create_new(cls, lollms_client: 'LollmsClient', db_manager: Optional[DatabaseManager] = None, **kwargs) -> 'LollmsDiscussion':
    """Create and return a brand-new discussion.

    ``autosave`` and ``max_context_size`` are taken out of ``kwargs`` and passed
    to the constructor.

    With a ``db_manager``: the remaining kwargs that name discussion-table
    columns initialise a new row (other keys are ignored). The row is committed,
    and the discussion is loaded from it.

    Without one: an in-memory discussion is returned, using
    ``kwargs['discussion_id']`` when present.
    """
    init_args = {
        'autosave': kwargs.pop('autosave', False),
        'max_context_size': kwargs.pop('max_context_size', None),
    }

    if not db_manager:
        return cls(lollmsClient=lollms_client, discussion_id=kwargs.get('discussion_id'), **init_args)

    session = db_manager.get_session()
    try:
        valid_keys = db_manager.DiscussionModel.__table__.columns.keys()
        db_creation_args = {k: v for k, v in kwargs.items() if k in valid_keys}
        db_discussion = db_manager.DiscussionModel(**db_creation_args)
        session.add(db_discussion)
        session.commit()
        # Read the id while the session is still open (commit expires attributes).
        new_id = db_discussion.id
    finally:
        # Previously this session was never closed.
        session.close()
    return cls(lollmsClient=lollms_client, discussion_id=new_id, db_manager=db_manager, **init_args)
|
|
93
275
|
|
|
94
|
-
# --- Configuration Methods ---
|
|
95
276
|
def set_system_prompt(self, prompt: str):
    """Replace the discussion's system prompt, then mark it modified via ``touch``."""
    self.system_prompt = prompt
    self.touch()
|
|
98
279
|
|
|
99
280
|
def set_participants(self, participants: Dict[str, str]):
    """Validate and store the sender-name → role mapping.

    Every role must be 'user', 'assistant' or 'system'. On the first invalid
    entry a ValueError naming it is raised and nothing is changed. Otherwise
    the mapping is stored as given and the discussion is touched.
    """
    allowed_roles = ("user", "assistant", "system")
    offending = next(
        ((name, role) for name, role in participants.items() if role not in allowed_roles),
        None,
    )
    if offending is not None:
        bad_name, bad_role = offending
        raise ValueError(f"Invalid role '{bad_role}' for participant '{bad_name}'")
    self.participants = participants
    self.touch()
|
|
110
286
|
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
self
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
message = LollmsMessage(
|
|
131
|
-
sender=sender, sender_type=sender_type, content=content,
|
|
132
|
-
parent_id=parent_id, metadata=str(metadata or {}), images=images or []
|
|
133
|
-
)
|
|
134
|
-
if override_id:
|
|
135
|
-
message.id = override_id
|
|
136
|
-
|
|
137
|
-
self.messages.append(message)
|
|
138
|
-
self.message_index[message.id] = message
|
|
139
|
-
self.children_index[parent_id].append(message.id)
|
|
140
|
-
self.active_branch_id = message.id
|
|
141
|
-
return message.id
|
|
142
|
-
|
|
143
|
-
def get_branch(self, leaf_id: str) -> List[LollmsMessage]:
|
|
144
|
-
"""Gets the full branch of messages from the root to the specified leaf."""
|
|
287
|
+
def add_message(self, **kwargs) -> Dict:
    """Append a message and make it the active branch tip.

    Every keyword becomes a field of the message dict. ``id`` defaults to a
    new UUID4 string and ``parent_id`` to the current active branch tip (or
    None). ``discussion_id`` and ``created_at`` are filled in, but explicit
    keywords override them.

    Returns:
        The stored message dict (the same object kept in ``messages``).
    """
    new_id = kwargs.get('id', str(uuid.uuid4()))
    parent = kwargs.get('parent_id', self.active_branch_id or None)

    message = dict(id=new_id, parent_id=parent, discussion_id=self.id, created_at=datetime.utcnow())
    message.update(kwargs)

    self.messages.append(message)
    self.message_index[new_id] = message
    self.active_branch_id = new_id
    self.touch()
    return message
|
|
302
|
+
|
|
303
|
+
def get_branch(self, leaf_id: Optional[str]) -> List[Dict]:
    """Return the messages from the root down to ``leaf_id`` (inclusive).

    Follows ``parent_id`` links through ``message_index``. The walk stops at a
    missing/None parent or an id not in the index. It also stops at an id it
    has already visited, so corrupt data (e.g. a message that is its own
    parent) can no longer cause an infinite loop. Returns [] for a falsy
    ``leaf_id``.
    """
    if not leaf_id:
        return []
    branch = []
    visited = set()
    current_id: Optional[str] = leaf_id
    while current_id and current_id in self.message_index and current_id not in visited:
        visited.add(current_id)
        msg = self.message_index[current_id]
        branch.append(msg)
        current_id = msg.get('parent_id')
    branch.reverse()
    return branch
|
|
312
|
+
|
|
313
|
+
def chat(self, user_message: str, show_thoughts: bool = False, **kwargs) -> Dict:
    """Add ``user_message`` (if non-empty), generate the reply, and store it.

    If ``max_context_size`` is set, ``summarize_and_prune`` runs first.

    With a non-None ``streaming_callback`` in ``kwargs``, generation streams:
    - Text inside ``<think>...</think>`` is forwarded as
      ``MSG_TYPE_THOUGHT_CHUNK``, and only when ``show_thoughts`` is True.
      It is never stored.
    - Other text is forwarded as ``MSG_TYPE_CHUNK`` and becomes the stored reply.
    - Tags split across chunks (e.g. "<thi" + "nk>") are recognised: a trailing
      fragment that could start a tag is held back until the next chunk, or
      until the stream ends.

    Without a callback, generation is non-streaming and think blocks are
    removed with ``remove_thinking_blocks``. For DB-backed discussions without
    autosave, the result is committed at the end.

    Returns:
        The stored assistant message dict.
    """
    if self.max_context_size is not None:
        self.summarize_and_prune(self.max_context_size)

    if user_message:
        self.add_message(sender="user", sender_type="user", content=user_message)

    from lollms_client.lollms_types import MSG_TYPE

    original_callback = kwargs.get("streaming_callback")

    if original_callback is not None:
        open_tag, close_tag = "<think>", "</think>"
        full_response_parts = []
        token_buffer = ""
        in_thought_block = False

        def partial_tag_length(text: str, tag: str) -> int:
            # Length of the longest proper prefix of `tag` that `text` ends with.
            for size in range(min(len(tag) - 1, len(text)), 0, -1):
                if text.endswith(tag[:size]):
                    return size
            return 0

        def emit(chunk: str, is_thought: bool) -> bool:
            # Forward one chunk; returns False if the user callback asked to stop.
            if not chunk:
                return True
            if is_thought:
                if not show_thoughts:
                    return True
                return bool(original_callback(chunk, MSG_TYPE.MSG_TYPE_THOUGHT_CHUNK))
            full_response_parts.append(chunk)
            return bool(original_callback(chunk, MSG_TYPE.MSG_TYPE_CHUNK))

        def accumulating_callback(token: str, msg_type: MSG_TYPE = MSG_TYPE.MSG_TYPE_CHUNK):
            nonlocal token_buffer, in_thought_block
            continue_streaming = True
            if token:
                token_buffer += token
            while token_buffer:
                tag = close_tag if in_thought_block else open_tag
                tag_pos = token_buffer.find(tag)
                if tag_pos != -1:
                    if not emit(token_buffer[:tag_pos], in_thought_block):
                        continue_streaming = False
                    token_buffer = token_buffer[tag_pos + len(tag):]
                    in_thought_block = not in_thought_block
                    continue
                # No complete tag: hold back a suffix that may be the start of one.
                held = partial_tag_length(token_buffer, tag)
                cut = len(token_buffer) - held
                ready, token_buffer = token_buffer[:cut], token_buffer[cut:]
                if not emit(ready, in_thought_block):
                    continue_streaming = False
                break
            return continue_streaming

        kwargs["streaming_callback"] = accumulating_callback
        kwargs["stream"] = True

        self.lollmsClient.chat(self, **kwargs)
        # Stream ended: a held-back fragment was not a tag after all, so release it.
        if token_buffer:
            emit(token_buffer, in_thought_block)
            token_buffer = ""
        ai_response = "".join(full_response_parts)
    else:
        kwargs["stream"] = False
        raw_response = self.lollmsClient.chat(self, **kwargs)
        ai_response = self.lollmsClient.remove_thinking_blocks(raw_response) if raw_response else ""

    ai_message_obj = self.add_message(sender="assistant", sender_type="assistant", content=ai_response)

    if self._is_db_backed and not self.autosave:
        self.commit()

    return ai_message_obj
|
|
385
|
+
|
|
386
|
+
def regenerate_branch(self, show_thoughts: bool = False, **kwargs) -> Dict:
    """Discard the assistant reply at the active tip and generate a new one.

    Steps:
    1. Remove the tip message from ``messages`` and queue its id for deletion
       on the next commit.
    2. Move the active branch to the tip's parent and rebuild the indexes.
    3. Call ``chat`` with an empty user message (and the given options).

    Raises:
        ValueError: if the active tip is missing or is not an assistant message.
    """
    tip = self.message_index.get(self.active_branch_id)
    if not tip or tip['sender_type'] != 'assistant':
        raise ValueError("Can only regenerate from an assistant's message.")

    tip_id = tip['id']
    self.active_branch_id = tip['parent_id']

    self.messages = [m for m in self.messages if m['id'] != tip_id]
    self._messages_to_delete.append(tip_id)
    self._rebuild_in_memory_indexes()

    return self.chat("", show_thoughts, **kwargs)
|
|
400
|
+
|
|
401
|
+
def delete_branch(self, message_id: str):
    """Delete ``message_id`` and every message descending from it, then reload.

    Only supported for database-backed discussions. The subtree is removed
    from ``messages`` and queued for deletion, with descendants queued before
    their ancestors. The active branch moves to the deleted message's parent
    (None for a root). The discussion is then committed and reloaded.

    Previously only the single message was deleted: its descendants stayed
    orphaned in the database, and the deleted message was still in
    ``messages`` when ``commit`` ran.

    Raises:
        NotImplementedError: for in-memory discussions.
        ValueError: if ``message_id`` is not part of this discussion.
    """
    if not self._is_db_backed:
        raise NotImplementedError("Branch deletion is only supported for database-backed discussions.")

    if message_id not in self.message_index:
        raise ValueError("Message not found.")

    children = defaultdict(list)
    for msg in self.messages:
        children[msg.get('parent_id')].append(msg['id'])

    # Depth-first pre-order walk of the subtree; `seen` also guards against cycles.
    subtree = []
    seen = set()
    pending = [message_id]
    while pending:
        current = pending.pop()
        if current in seen:
            continue
        seen.add(current)
        subtree.append(current)
        pending.extend(children.get(current, ()))

    parent_id = self.message_index[message_id].get('parent_id')
    self.messages = [m for m in self.messages if m['id'] not in seen]
    self._messages_to_delete.extend(reversed(subtree))
    self.active_branch_id = parent_id
    self.commit()
    self._load_from_db(self.id)
|
|
415
|
+
|
|
416
|
+
def switch_to_branch(self, message_id: str):
    """Make ``message_id`` the active branch tip.

    For DB-backed discussions the stored discussion row is updated as well,
    and committed right away when autosave is enabled.

    Raises:
        ValueError: if the id is not part of this discussion.
    """
    if message_id not in self.message_index:
        raise ValueError(f"Message ID '{message_id}' not found in the current discussion.")

    self.active_branch_id = message_id
    if not self._is_db_backed:
        return

    self.db_discussion.active_branch_id = message_id
    if self.autosave:
        self.commit()
|
|
423
|
+
|
|
424
|
+
def format_discussion(self, max_allowed_tokens: int, branch_tip_id: Optional[str] = None) -> str:
    """Render a branch as a lollms-format prompt string within a token budget.

    Convenience wrapper around ``export`` with the "lollms_text" format.
    """
    return self.export(
        format_type="lollms_text",
        branch_tip_id=branch_tip_id,
        max_allowed_tokens=max_allowed_tokens,
    )
|
|
158
426
|
|
|
159
|
-
# --- Persistence ---
|
|
160
|
-
def save_to_disk(self, file_path: str):
|
|
161
|
-
"""Saves the entire discussion state to a YAML file."""
|
|
162
|
-
data = {
|
|
163
|
-
'version': self.version, 'active_branch_id': self.active_branch_id,
|
|
164
|
-
'system_prompt': self.system_prompt, 'participants': self.participants,
|
|
165
|
-
'scratchpad': self.scratchpad, 'messages': [m.to_dict() for m in self.messages]
|
|
166
|
-
}
|
|
167
|
-
with open(file_path, 'w', encoding='utf-8') as file:
|
|
168
|
-
yaml.dump(data, file, allow_unicode=True, sort_keys=False)
|
|
169
|
-
|
|
170
|
-
def load_from_disk(self, file_path: str):
|
|
171
|
-
"""Loads a discussion state from a YAML file."""
|
|
172
|
-
with open(file_path, 'r', encoding='utf-8') as file:
|
|
173
|
-
data = yaml.safe_load(file)
|
|
174
|
-
|
|
175
|
-
self._reset_state()
|
|
176
|
-
version = data.get("version", 1)
|
|
177
|
-
if version > self.version:
|
|
178
|
-
raise ValueError(f"File version {version} is newer than supported version {self.version}.")
|
|
179
|
-
|
|
180
|
-
self.active_branch_id = data.get('active_branch_id')
|
|
181
|
-
self.system_prompt = data.get('system_prompt', None)
|
|
182
|
-
self.participants = data.get('participants', {})
|
|
183
|
-
self.scratchpad = data.get('scratchpad', None)
|
|
184
|
-
|
|
185
|
-
for msg_data in data.get('messages', []):
|
|
186
|
-
msg = LollmsMessage(
|
|
187
|
-
sender=msg_data['sender'], sender_type=msg_data.get('sender_type', 'user'),
|
|
188
|
-
content=msg_data['content'], parent_id=msg_data.get('parent_id'),
|
|
189
|
-
id=msg_data.get('id', str(uuid.uuid4())), metadata=msg_data.get('metadata', '{}'),
|
|
190
|
-
images=msg_data.get('images', [])
|
|
191
|
-
)
|
|
192
|
-
self.messages.append(msg)
|
|
193
|
-
self.message_index[msg.id] = msg
|
|
194
|
-
self.children_index[msg.parent_id].append(msg.id)
|
|
195
|
-
|
|
196
|
-
# --- Context Management and Formatting ---
|
|
197
427
|
def _get_full_system_prompt(self) -> Optional[str]:
    """Build the system text for the model: scratchpad section, then system prompt.

    Both parts are stripped, and blank (empty or whitespace-only) parts are
    skipped. A whitespace-only scratchpad used to produce an empty scratchpad
    section, unlike the system prompt. Parts are joined with blank lines.
    Returns None when both are blank.
    """
    parts = []
    scratchpad = (self.scratchpad or "").strip()
    if scratchpad:
        parts.extend(["--- KNOWLEDGE SCRATCHPAD ---", scratchpad, "--- END SCRATCHPAD ---"])

    system_prompt = (self.system_prompt or "").strip()
    if system_prompt:
        parts.append(system_prompt)

    return "\n\n".join(parts) if parts else None
|
|
208
438
|
|
|
209
|
-
def
|
|
210
|
-
"""
|
|
211
|
-
Checks context size and, if exceeded, summarizes the oldest messages
|
|
212
|
-
into the scratchpad and prunes them to free up token space.
|
|
213
|
-
"""
|
|
439
|
+
def export(self, format_type: str, branch_tip_id: Optional[str] = None, max_allowed_tokens: Optional[int] = None) -> Union[List[Dict], str]:
    """Export one branch of the discussion.

    Args:
        format_type: one of
            - "lollms_text": a single prompt string of "!@>sender:" sections.
            - "openai_chat" / "ollama_chat": lists of role/content message dicts.
        branch_tip_id: last message of the branch; defaults to the active tip.
        max_allowed_tokens: used by "lollms_text" only. The system text is kept
            only if it fits on its own. Messages are then added newest-first
            until the budget would be exceeded, and the kept messages are
            emitted in chronological order.

    Returns:
        A string for "lollms_text", a list of dicts otherwise. "" / [] when
        there is no branch tip.

    Raises:
        ValueError: for an unsupported ``format_type``. This is now checked up
            front; an empty branch previously let it pass silently.

    The chat role of a message is its sender's entry in ``participants``.
    Unmapped senders fall back to the message's ``sender_type`` when that is
    'user', 'assistant' or 'system', else 'user'. Previously unmapped senders,
    including the assistant replies stored by ``chat``, always became 'user'.
    """
    if format_type not in ("lollms_text", "openai_chat", "ollama_chat"):
        raise ValueError(f"Unsupported export format_type: {format_type}")

    if branch_tip_id is None:
        branch_tip_id = self.active_branch_id
    if not branch_tip_id:
        return "" if format_type == "lollms_text" else []

    branch = self.get_branch(branch_tip_id)
    full_system_prompt = self._get_full_system_prompt()
    participants = self.participants or {}

    if format_type == "lollms_text":
        header = []
        kept = []  # newest first; reversed before joining
        current_tokens = 0

        if full_system_prompt:
            sys_msg_text = f"!@>system:\n{full_system_prompt}\n"
            sys_tokens = self.lollmsClient.count_tokens(sys_msg_text)
            if max_allowed_tokens is None or sys_tokens <= max_allowed_tokens:
                header.append(sys_msg_text)
                current_tokens += sys_tokens

        for msg in reversed(branch):
            sender_str = msg['sender'].replace(':', '').replace('!@>', '')
            content = (msg.get('content') or '').strip()
            if msg.get('images'):
                content += f"\n({len(msg['images'])} image(s) attached)"
            msg_text = f"!@>{sender_str}:\n{content}\n"
            msg_tokens = self.lollmsClient.count_tokens(msg_text)

            if max_allowed_tokens is not None and current_tokens + msg_tokens > max_allowed_tokens:
                break
            kept.append(msg_text)
            current_tokens += msg_tokens

        kept.reverse()
        return "".join(header + kept).strip()

    valid_roles = ("user", "assistant", "system")
    messages = []
    if full_system_prompt:
        messages.append({"role": "system", "content": full_system_prompt})

    for msg in branch:
        if msg['sender'] in participants:
            role = participants[msg['sender']]
        else:
            sender_type = msg.get('sender_type')
            role = sender_type if sender_type in valid_roles else "user"
        content = (msg.get('content') or '').strip()
        images = msg.get('images') or []

        if format_type == "openai_chat":
            if images:
                content_parts = [{"type": "text", "text": content}] if content else []
                for img in images:
                    image_url = img['data'] if img['type'] == 'url' else f"data:image/jpeg;base64,{img['data']}"
                    content_parts.append({"type": "image_url", "image_url": {"url": image_url, "detail": "auto"}})
                messages.append({"role": role, "content": content_parts})
            else:
                messages.append({"role": role, "content": content})
        else:  # ollama_chat
            message_dict = {"role": role, "content": content}
            base64_images = [img['data'] for img in images if img['type'] == 'base64']
            if base64_images:
                message_dict["images"] = base64_images
            messages.append(message_dict)

    return messages
|
|
500
|
+
|
|
501
|
+
def summarize_and_prune(self, max_tokens: int, preserve_last_n: int = 4):
    """Summarize the oldest messages of the active branch and drop them.

    Nothing happens when there is no active branch, when the formatted branch
    already fits in ``max_tokens``, or when the branch holds no more than
    ``preserve_last_n`` messages.

    Args:
        max_tokens: Token budget the formatted discussion should fit in.
        preserve_last_n: Number of most recent messages kept verbatim. Values
            below 1 are treated as 1, because the branch tip referenced by
            ``self.active_branch_id`` must survive the pruning.

    Side effects:
        Appends the generated summary to ``self.scratchpad``, removes the
        pruned messages from ``self.messages``, queues their ids in
        ``self._messages_to_delete``, rebuilds the in-memory indexes and
        prints a status line. If summary generation fails, a warning is
        printed and nothing is modified.
    """
    branch_tip_id = self.active_branch_id
    if not branch_tip_id:
        return

    # Format with a practically unlimited budget so the count covers the whole
    # branch. branch_tip_id is passed by keyword because an earlier signature
    # of format_discussion took `splitter_text` as its second positional arg.
    current_prompt_text = self.format_discussion(999999, branch_tip_id=branch_tip_id)
    current_tokens = self.lollmsClient.count_tokens(current_prompt_text)
    if current_tokens <= max_tokens:
        return

    branch = self.get_branch(branch_tip_id)
    # Keep at least the tip: active_branch_id points at it.
    preserve_last_n = max(1, preserve_last_n)
    if len(branch) <= preserve_last_n:
        return

    # Explicit split index instead of `branch[:-preserve_last_n]`, which is
    # empty for a 0 count and then summarized nothing.
    split_at = len(branch) - preserve_last_n
    messages_to_prune = branch[:split_at]
    messages_to_keep = branch[split_at:]

    text_to_summarize = "\n\n".join(f"{m['sender']}: {m['content']}" for m in messages_to_prune)
    summary_prompt = f"Concisely summarize this conversation excerpt:\n---\n{text_to_summarize}\n---\nSUMMARY:"
    try:
        summary = self.lollmsClient.generate_text(summary_prompt, n_predict=300, temperature=0.1)
    except Exception as e:
        print(f"\n[WARNING] Pruning failed, couldn't generate summary: {e}")
        return

    # Without a usable text summary, pruning would silently lose history.
    # NOTE(review): guards against generate_text returning a non-string
    # (e.g. an error object) or an empty string — confirm its contract.
    if not isinstance(summary, str) or not summary.strip():
        print("\n[WARNING] Pruning failed, couldn't generate summary: empty or non-text result")
        return

    # The scratchpad may be None/empty; never prefix the summary with "None".
    existing_scratchpad = self.scratchpad or ""
    new_scratchpad_content = f"{existing_scratchpad}\n\n--- Summary of earlier conversation ---\n{summary.strip()}"
    self.scratchpad = new_scratchpad_content.strip()

    # Re-attach the first surviving message to whatever the oldest pruned
    # message hung from, so the kept branch does not reference a deleted parent.
    # NOTE(review): assumes message dicts store their parent under 'parent_id'
    # (only touched when that key exists) — confirm against the message schema.
    first_kept_id = messages_to_keep[0]['id']
    new_parent_id = messages_to_prune[0].get('parent_id')
    for message in self.messages:
        if message['id'] == first_kept_id and 'parent_id' in message:
            message['parent_id'] = new_parent_id

    pruned_ids = {msg['id'] for msg in messages_to_prune}
    self.messages = [m for m in self.messages if m['id'] not in pruned_ids]
    self._messages_to_delete.extend(list(pruned_ids))
    self._rebuild_in_memory_indexes()

    print(f"\n[INFO] Discussion auto-pruned. {len(messages_to_prune)} messages summarized.")
|
|
308
531
|
|
|
309
|
-
def
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
532
|
+
def to_dict(self):
    """Serialize the discussion into plain, JSON/YAML-friendly Python data.

    Every message dict is shallow-copied before it is adjusted, so the live
    entries in ``self.messages`` are not modified. In each copy, a
    ``datetime`` ``created_at`` is rendered as an ISO-8601 string and the
    ``message_metadata`` key is exposed as ``metadata``. The discussion's own
    ``created_at`` / ``updated_at`` become ISO strings, or None when unset.
    """
    serialized_messages = []
    for original in self.messages:
        entry = original.copy()
        stamp = entry.get('created_at')
        if isinstance(stamp, datetime):
            entry['created_at'] = stamp.isoformat()
        if 'message_metadata' in entry:
            entry['metadata'] = entry.pop('message_metadata')
        serialized_messages.append(entry)

    created_iso = self.created_at.isoformat() if self.created_at else None
    updated_iso = self.updated_at.isoformat() if self.updated_at else None

    return {
        "id": self.id,
        "system_prompt": self.system_prompt,
        "participants": self.participants,
        "active_branch_id": self.active_branch_id,
        "metadata": self.metadata,
        "scratchpad": self.scratchpad,
        "messages": serialized_messages,
        "created_at": created_iso,
        "updated_at": updated_iso,
    }
|
|
318
548
|
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
truncated_tokens = len(mock_client.binding.tokenize(truncated_prompt))
|
|
405
|
-
print(f"Truncated prompt tokens: {truncated_tokens}")
|
|
406
|
-
print("Truncated Prompt:\n" + "="*20 + f"\n{truncated_prompt}\n" + "="*20)
|
|
407
|
-
|
|
408
|
-
# Verification
|
|
409
|
-
assert truncated_tokens <= 80
|
|
410
|
-
# Check that it contains the newest message that fits
|
|
411
|
-
assert "Message #19" in truncated_prompt or "Message #20" in truncated_prompt
|
|
412
|
-
print("✅ format_discussion correctly truncated the prompt.")
|
|
549
|
+
def load_from_dict(self, data: Dict):
    """Replace this discussion's in-memory state with the contents of ``data``.

    ``data`` is expected to have the shape produced by ``to_dict``.

    Timestamps may be ISO-8601 strings or already-parsed ``datetime`` objects
    (PyYAML's ``safe_load`` turns unquoted timestamps into datetimes).
    Unparseable message timestamps become ``datetime.utcnow()``. A missing or
    unparseable discussion ``created_at`` also falls back to ``utcnow()``, and
    ``updated_at`` falls back to ``created_at``.

    Message dicts are shallow-copied, so the caller's ``data`` is not mutated.

    Args:
        data: Serialized discussion, e.g. loaded from a JSON or YAML file.
    """
    def _parse_timestamp(value):
        # datetime -> itself; ISO string -> parsed datetime; anything else
        # (None, malformed string, other types) -> None.
        if isinstance(value, datetime):
            return value
        if isinstance(value, str):
            try:
                return datetime.fromisoformat(value)
            except ValueError:
                return None
        return None

    self._reset_in_memory_state()
    # `or` also covers keys that are present but null in the serialized file.
    self.id = data.get("id") or str(uuid.uuid4())
    self.system_prompt = data.get("system_prompt")
    self.participants = data.get("participants") or {}
    self.active_branch_id = data.get("active_branch_id")
    self.metadata = data.get("metadata") or {}
    self.scratchpad = data.get("scratchpad") or ""

    # NOTE(review): to_dict renames 'message_metadata' to 'metadata'; this
    # loader keeps whatever key the input uses (migrate() maps it back).
    loaded_messages = []
    for raw_msg in data.get("messages") or []:
        msg = dict(raw_msg)  # copy: do not mutate the caller's data
        if isinstance(msg.get('created_at'), str):
            msg['created_at'] = _parse_timestamp(msg['created_at']) or datetime.utcnow()
        loaded_messages.append(msg)
    self.messages = loaded_messages

    self.created_at = _parse_timestamp(data.get('created_at')) or datetime.utcnow()
    self.updated_at = _parse_timestamp(data.get('updated_at')) or self.created_at
    self._rebuild_in_memory_indexes()
|
|
570
|
+
|
|
571
|
+
def _rebuild_in_memory_indexes(self):
|
|
572
|
+
self.message_index = {msg['id']: msg for msg in self.messages}
|
|
573
|
+
|
|
574
|
+
@staticmethod
def migrate(lollms_client: 'LollmsClient', db_manager: DatabaseManager, folder_path: Union[str, Path]):
    """Import legacy JSON/YAML discussion files from ``folder_path`` into the database.

    Each ``*.json``, ``*.yaml`` or ``*.yml`` file (processed in sorted order)
    is loaded through ``LollmsDiscussion.load_from_dict``. It is then written
    as one ``DiscussionModel`` row plus one ``MessageModel`` row per message.
    Discussions whose id already exists in the database are skipped.

    Every file is committed on its own. A failing file is rolled back and
    reported without discarding files that were already migrated. The
    session is always closed.

    Args:
        lollms_client: Client passed to ``LollmsDiscussion.create_new``.
        db_manager: Provides ``get_session()``, ``DiscussionModel`` and ``MessageModel``.
        folder_path: Directory containing the legacy discussion files.
    """
    folder = Path(folder_path)
    if not folder.is_dir():
        print(f"Error: Path '{folder}' is not a valid directory.")
        return

    print(f"\n--- Starting Migration from '{folder}' ---")
    discussion_files = sorted(
        path for pattern in ("*.json", "*.yaml", "*.yml") for path in folder.glob(pattern)
    )

    # Column names do not change during the run; compute them once.
    valid_disc_keys = {c.name for c in db_manager.DiscussionModel.__table__.columns}
    valid_msg_keys = {c.name for c in db_manager.MessageModel.__table__.columns}

    session = db_manager.get_session()
    try:
        for i, file_path in enumerate(discussion_files):
            print(f"Migrating file {i+1}/{len(discussion_files)}: {file_path.name} ... ", end="")
            try:
                in_memory_discussion = LollmsDiscussion.create_new(lollms_client=lollms_client)
                with open(file_path, 'r', encoding='utf-8') as f:
                    if file_path.suffix.lower() == ".json":
                        data = json.load(f)
                    else:
                        # safe_load only: these files are external input.
                        data = yaml.safe_load(f)
                if not isinstance(data, dict):
                    # e.g. an empty YAML file loads as None.
                    raise ValueError("file does not contain a discussion mapping")

                in_memory_discussion.load_from_dict(data)
                discussion_id = in_memory_discussion.id

                existing = session.query(db_manager.DiscussionModel).filter_by(id=discussion_id).first()
                if existing:
                    print("SKIPPED (already exists)")
                    continue

                discussion_data = {
                    'id': in_memory_discussion.id,
                    'system_prompt': in_memory_discussion.system_prompt,
                    'participants': in_memory_discussion.participants,
                    'active_branch_id': in_memory_discussion.active_branch_id,
                    'discussion_metadata': in_memory_discussion.metadata,
                    'created_at': in_memory_discussion.created_at,
                    'updated_at': in_memory_discussion.updated_at
                }
                if 'project_name' in valid_disc_keys:
                    metadata = in_memory_discussion.metadata or {}
                    discussion_data['project_name'] = metadata.get('project_name', file_path.stem)
                # Keep the scratchpad when the schema has a column for it.
                if 'scratchpad' in valid_disc_keys:
                    discussion_data['scratchpad'] = in_memory_discussion.scratchpad

                db_discussion = db_manager.DiscussionModel(**discussion_data)
                session.add(db_discussion)

                for msg_data in in_memory_discussion.messages:
                    msg_data['discussion_id'] = db_discussion.id
                    # Serialized files use 'metadata'; the ORM column is 'message_metadata'.
                    if 'metadata' in msg_data:
                        msg_data['message_metadata'] = msg_data.pop('metadata')
                    filtered_msg_data = {k: v for k, v in msg_data.items() if k in valid_msg_keys}
                    session.add(db_manager.MessageModel(**filtered_msg_data))

                # Commit per file so a later failure cannot roll this one back.
                session.commit()
                print("OK")
            except Exception as e:
                print(f"FAILED. Error: {e}")
                session.rollback()
    finally:
        session.close()
|