lollms-client 0.20.10__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of lollms-client might be problematic.

@@ -1,412 +1,633 @@
- # lollms_discussion.py
-
  import yaml
- from dataclasses import dataclass, field
- from typing import List, Dict, Optional, Union, Any
+ import json
+ import base64
+ import os
  import uuid
+ import shutil
  from collections import defaultdict
+ from datetime import datetime
+ from typing import List, Dict, Optional, Union, Any, Type, Callable
+ from pathlib import Path
+
+ from sqlalchemy import (create_engine, Column, String, Text, Integer, DateTime,
+                         ForeignKey, JSON, Boolean, LargeBinary, Index)
+ from sqlalchemy.orm import sessionmaker, relationship, Session, declarative_base
+ from sqlalchemy.types import TypeDecorator
+
+ try:
+     from cryptography.fernet import Fernet, InvalidToken
+     from cryptography.hazmat.primitives import hashes
+     from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
+     from cryptography.hazmat.backends import default_backend
+     ENCRYPTION_AVAILABLE = True
+ except ImportError:
+     ENCRYPTION_AVAILABLE = False
+
+ if False:
+     from lollms_client import LollmsClient
+     from lollms_client.lollms_types import MSG_TYPE
+
+ class EncryptedString(TypeDecorator):
+     impl = LargeBinary
+     cache_ok = True
+
+     def __init__(self, key: str, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         if not ENCRYPTION_AVAILABLE: raise ImportError("'cryptography' is required for DB encryption.")
+         self.salt = b'lollms-fixed-salt-for-db-encryption'
+         kdf = PBKDF2HMAC(
+             algorithm=hashes.SHA256(), length=32, salt=self.salt,
+             iterations=480000, backend=default_backend()
+         )
+         derived_key = base64.urlsafe_b64encode(kdf.derive(key.encode()))
+         self.fernet = Fernet(derived_key)
 
- # It's good practice to forward-declare the type for the client to avoid circular imports.
- if False:
-     from lollms.client import LollmsClient
-
-
- @dataclass
- class LollmsMessage:
-     """
-     Represents a single message in a LollmsDiscussion, including its content,
-     sender, and relationship within the discussion tree.
-     """
-     sender: str
-     sender_type: str
-     content: str
-     id: str = field(default_factory=lambda: str(uuid.uuid4()))
-     parent_id: Optional[str] = None
-     metadata: str = "{}"
-     images: List[Dict[str, str]] = field(default_factory=list)
-
-     def to_dict(self) -> Dict[str, Any]:
-         """Serializes the message object to a dictionary."""
-         return {
-             'sender': self.sender,
-             'sender_type': self.sender_type,
-             'content': self.content,
-             'id': self.id,
-             'parent_id': self.parent_id,
-             'metadata': self.metadata,
-             'images': self.images
-         }
+     def process_bind_param(self, value: Optional[str], dialect) -> Optional[bytes]:
+         if value is None: return None
+         return self.fernet.encrypt(value.encode('utf-8'))
+
+     def process_result_value(self, value: Optional[bytes], dialect) -> Optional[str]:
+         if value is None: return None
+         try:
+             return self.fernet.decrypt(value).decode('utf-8')
+         except InvalidToken:
+             return "<DECRYPTION_FAILED: Invalid Key or Corrupt Data>"
+
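The new `EncryptedString` type is application-level encryption: values are run through Fernet before they reach the database, using a key stretched from the caller's passphrase. A minimal sketch of the same round trip outside SQLAlchemy (the passphrase is a placeholder); note that the hard-coded salt means a given passphrase always derives the same key, which is what lets the same key reopen an existing database:

```python
import base64
from cryptography.fernet import Fernet
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

# Same derivation as EncryptedString.__init__ above: PBKDF2-SHA256,
# fixed salt, 480,000 iterations, 32-byte key, urlsafe-base64 for Fernet.
salt = b'lollms-fixed-salt-for-db-encryption'
kdf = PBKDF2HMAC(algorithm=hashes.SHA256(), length=32, salt=salt,
                 iterations=480000, backend=default_backend())
key = base64.urlsafe_b64encode(kdf.derive(b"my-passphrase"))  # placeholder passphrase
fernet = Fernet(key)

token = fernet.encrypt("hello".encode("utf-8"))  # the bytes stored in the LargeBinary column
assert fernet.decrypt(token).decode("utf-8") == "hello"
```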
+ def create_dynamic_models(discussion_mixin: Optional[Type] = None, message_mixin: Optional[Type] = None, encryption_key: Optional[str] = None):
+     Base = declarative_base()
+     EncryptedText = EncryptedString(encryption_key) if encryption_key else Text
+
+     class DiscussionBase(Base):
+         __abstract__ = True
+         id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
+         system_prompt = Column(EncryptedText, nullable=True)
+         participants = Column(JSON, nullable=True, default=dict)
+         active_branch_id = Column(String, nullable=True)
+         discussion_metadata = Column(JSON, nullable=True, default=dict)
+         created_at = Column(DateTime, default=datetime.utcnow)
+         updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
+
+     class MessageBase(Base):
+         __abstract__ = True
+         id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
+         discussion_id = Column(String, ForeignKey('discussions.id'), nullable=False)
+         parent_id = Column(String, ForeignKey('messages.id'), nullable=True)
+         sender = Column(String, nullable=False)
+         sender_type = Column(String, nullable=False)
+         content = Column(EncryptedText, nullable=False)
+         message_metadata = Column(JSON, nullable=True, default=dict)
+         images = Column(JSON, nullable=True, default=list)
+         created_at = Column(DateTime, default=datetime.utcnow)
+
+     discussion_attrs = {'__tablename__': 'discussions'}
+     if hasattr(discussion_mixin, '__table_args__'):
+         discussion_attrs['__table_args__'] = discussion_mixin.__table_args__
+     if discussion_mixin:
+         for attr, col in discussion_mixin.__dict__.items():
+             if isinstance(col, Column):
+                 discussion_attrs[attr] = col
+
+     message_attrs = {'__tablename__': 'messages'}
+     if hasattr(message_mixin, '__table_args__'):
+         message_attrs['__table_args__'] = message_mixin.__table_args__
+     if message_mixin:
+         for attr, col in message_mixin.__dict__.items():
+             if isinstance(col, Column):
+                 message_attrs[attr] = col
+
+     discussion_bases = (discussion_mixin, DiscussionBase) if discussion_mixin else (DiscussionBase,)
+     DynamicDiscussion = type('Discussion', discussion_bases, discussion_attrs)
+
+     message_bases = (message_mixin, MessageBase) if message_mixin else (MessageBase,)
+     DynamicMessage = type('Message', message_bases, message_attrs)
+
+     DynamicDiscussion.messages = relationship(DynamicMessage, back_populates="discussion", cascade="all, delete-orphan", lazy="joined")
+     DynamicMessage.discussion = relationship(DynamicDiscussion, back_populates="messages")
+
+     return Base, DynamicDiscussion, DynamicMessage
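`create_dynamic_models` builds the ORM classes at runtime so callers can graft extra columns onto the two tables via plain mixin classes. A sketch of the intended call (the mixin and its column are illustrative, not part of the package):

```python
from sqlalchemy import Column, String

class ProjectMixin:  # hypothetical caller-defined mixin
    project_name = Column(String, nullable=True)

Base, Discussion, Message = create_dynamic_models(
    discussion_mixin=ProjectMixin,   # adds project_name to the 'discussions' table
    encryption_key="my-passphrase",  # omit to keep text columns as plain Text
)
```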
 
+ class DatabaseManager:
+     def __init__(self, db_path: str, discussion_mixin: Optional[Type] = None, message_mixin: Optional[Type] = None, encryption_key: Optional[str] = None):
+         if not db_path: raise ValueError("Database path cannot be empty.")
+         self.Base, self.DiscussionModel, self.MessageModel = create_dynamic_models(
+             discussion_mixin, message_mixin, encryption_key
+         )
+         self.engine = create_engine(db_path)
+         self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)
+         self.create_tables()
+
+     def create_tables(self):
+         self.Base.metadata.create_all(bind=self.engine)
+
+     def get_session(self) -> Session:
+         return self.SessionLocal()
+
+     def list_discussions(self) -> List[Dict]:
+         session = self.get_session()
+         discussions = session.query(self.DiscussionModel).all()
+         session.close()
+         discussion_list = []
+         for disc in discussions:
+             disc_dict = {c.name: getattr(disc, c.name) for c in disc.__table__.columns}
+             discussion_list.append(disc_dict)
+         return discussion_list
+
+     def get_discussion(self, lollms_client: 'LollmsClient', discussion_id: str, **kwargs) -> Optional['LollmsDiscussion']:
+         session = self.get_session()
+         db_disc = session.query(self.DiscussionModel).filter_by(id=discussion_id).first()
+         session.close()
+         if db_disc:
+             return LollmsDiscussion(lollmsClient=lollms_client, discussion_id=discussion_id, db_manager=self, **kwargs)
+         return None
+
+     def search_discussions(self, **criteria) -> List[Dict]:
+         session = self.get_session()
+         query = session.query(self.DiscussionModel)
+         for key, value in criteria.items():
+             query = query.filter(getattr(self.DiscussionModel, key).ilike(f"%{value}%"))
+         discussions = query.all()
+         session.close()
+         discussion_list = []
+         for disc in discussions:
+             disc_dict = {c.name: getattr(disc, c.name) for c in disc.__table__.columns}
+             discussion_list.append(disc_dict)
+         return discussion_list
+
+     def delete_discussion(self, discussion_id: str):
+         session = self.get_session()
+         db_disc = session.query(self.DiscussionModel).filter_by(id=discussion_id).first()
+         if db_disc:
+             session.delete(db_disc)
+             session.commit()
+         session.close()
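Note that `db_path` is handed directly to `create_engine`, so it must be a SQLAlchemy URL (e.g. `sqlite:///...`), not a bare file path. A usage sketch with an illustrative database name:

```python
db = DatabaseManager("sqlite:///discussions.db", encryption_key="my-passphrase")

for disc in db.list_discussions():
    print(disc["id"], disc["updated_at"])

# search_discussions() builds SQL ilike filters, so with encryption enabled a
# plaintext pattern cannot match the encrypted system_prompt/content columns.
hits = db.search_discussions(active_branch_id="1a2b")  # hypothetical id fragment
```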
 
  class LollmsDiscussion:
-     """
-     Manages a branching conversation tree, including system prompts, participants,
-     an internal knowledge scratchpad, and context pruning capabilities.
-     """
-
-     def __init__(self, lollmsClient: 'LollmsClient'):
-         """
-         Initializes a new LollmsDiscussion instance.
-
-         Args:
-             lollmsClient: An instance of LollmsClient, required for tokenization.
-         """
+     def __init__(self, lollmsClient: 'LollmsClient', discussion_id: Optional[str] = None, db_manager: Optional[DatabaseManager] = None, autosave: bool = False, max_context_size: Optional[int] = None):
          self.lollmsClient = lollmsClient
-         self.version: int = 3  # Current version of the format with scratchpad support
-         self._reset_state()
+         self.db_manager = db_manager
+         self.autosave = autosave
+         self.max_context_size = max_context_size
+         self._is_db_backed = db_manager is not None
+
+         self.session = None
+         self.db_discussion = None
+         self._messages_to_delete = []
 
-     def _reset_state(self):
-         """Helper to reset all discussion attributes to their defaults."""
-         self.messages: List[LollmsMessage] = []
-         self.active_branch_id: Optional[str] = None
-         self.message_index: Dict[str, LollmsMessage] = {}
-         self.children_index: Dict[Optional[str], List[str]] = defaultdict(list)
-         self.participants: Dict[str, str] = {}
-         self.system_prompt: Optional[str] = None
-         self.scratchpad: Optional[str] = None
-
-     # --- Scratchpad Management Methods ---
-     def set_scratchpad(self, content: str):
-         """Sets or replaces the entire content of the internal scratchpad."""
-         self.scratchpad = content
-
-     def update_scratchpad(self, new_content: str, append: bool = True):
-         """
-         Updates the scratchpad. By default, it appends with a newline separator.
-
-         Args:
-             new_content: The new text to add to the scratchpad.
-             append: If True, appends to existing content. If False, replaces it.
-         """
-         if append and self.scratchpad:
-             self.scratchpad += f"\n{new_content}"
-         else:
-             self.scratchpad = new_content
+         self._reset_in_memory_state()
 
-     def get_scratchpad(self) -> Optional[str]:
-         """Returns the current content of the scratchpad."""
-         return self.scratchpad
+         if self._is_db_backed:
+             if not discussion_id: raise ValueError("A discussion_id is required for database-backed discussions.")
+             self.session = db_manager.get_session()
+             self._load_from_db(discussion_id)
+         else:
+             self.id = discussion_id or str(uuid.uuid4())
+             self.created_at = datetime.utcnow()
+             self.updated_at = self.created_at
 
-     def clear_scratchpad(self):
-         """Clears the scratchpad content."""
-         self.scratchpad = None
+     def _reset_in_memory_state(self):
+         self.id: str = ""
+         self.system_prompt: Optional[str] = None
+         self.participants: Dict[str, str] = {}
+         self.active_branch_id: Optional[str] = None
+         self.metadata: Dict[str, Any] = {}
+         self.scratchpad: str = ""
+         self.messages: List[Dict] = []
+         self.message_index: Dict[str, Dict] = {}
+         self.created_at: Optional[datetime] = None
+         self.updated_at: Optional[datetime] = None
+
+     def _load_from_db(self, discussion_id: str):
+         self.db_discussion = self.session.query(self.db_manager.DiscussionModel).filter(self.db_manager.DiscussionModel.id == discussion_id).one()
+
+         self.id = self.db_discussion.id
+         self.system_prompt = self.db_discussion.system_prompt
+         self.participants = self.db_discussion.participants or {}
+         self.active_branch_id = self.db_discussion.active_branch_id
+         self.metadata = self.db_discussion.discussion_metadata or {}
+
+         self.messages = []
+         self.message_index = {}
+         for msg in self.db_discussion.messages:
+             msg_dict = {c.name: getattr(msg, c.name) for c in msg.__table__.columns}
+             if 'message_metadata' in msg_dict:
+                 msg_dict['metadata'] = msg_dict.pop('message_metadata')
+             self.messages.append(msg_dict)
+             self.message_index[msg.id] = msg_dict
+
+     def commit(self):
+         if not self._is_db_backed or not self.session: return
+
+         if self.db_discussion:
+             self.db_discussion.system_prompt = self.system_prompt
+             self.db_discussion.participants = self.participants
+             self.db_discussion.active_branch_id = self.active_branch_id
+             self.db_discussion.discussion_metadata = self.metadata
+             self.db_discussion.updated_at = datetime.utcnow()
+
+         for msg_id in self._messages_to_delete:
+             msg_to_del = self.session.query(self.db_manager.MessageModel).filter_by(id=msg_id).first()
+             if msg_to_del: self.session.delete(msg_to_del)
+         self._messages_to_delete.clear()
+
+         for msg_data in self.messages:
+             msg_id = msg_data['id']
+             msg_orm = self.session.query(self.db_manager.MessageModel).filter_by(id=msg_id).first()
+
+             if 'metadata' in msg_data:
+                 msg_data['message_metadata'] = msg_data.pop('metadata', None)
+
+             if not msg_orm:
+                 msg_data_copy = msg_data.copy()
+                 valid_keys = {c.name for c in self.db_manager.MessageModel.__table__.columns}
+                 filtered_msg_data = {k: v for k, v in msg_data_copy.items() if k in valid_keys}
+                 msg_orm = self.db_manager.MessageModel(**filtered_msg_data)
+                 self.session.add(msg_orm)
+             else:
+                 for key, value in msg_data.items():
+                     if hasattr(msg_orm, key):
+                         setattr(msg_orm, key, value)
+
+         self.session.commit()
+
+     def touch(self):
+         self.updated_at = datetime.utcnow()
+         if self._is_db_backed and self.autosave:
+             self.commit()
+
+     @classmethod
+     def create_new(cls, lollms_client: 'LollmsClient', db_manager: Optional[DatabaseManager] = None, **kwargs) -> 'LollmsDiscussion':
+         init_args = {
+             'autosave': kwargs.pop('autosave', False),
+             'max_context_size': kwargs.pop('max_context_size', None)
+         }
+
+         if db_manager:
+             session = db_manager.get_session()
+             valid_keys = db_manager.DiscussionModel.__table__.columns.keys()
+             db_creation_args = {k: v for k, v in kwargs.items() if k in valid_keys}
+             db_discussion = db_manager.DiscussionModel(**db_creation_args)
+             session.add(db_discussion)
+             session.commit()
+             return cls(lollmsClient=lollms_client, discussion_id=db_discussion.id, db_manager=db_manager, **init_args)
+         else:
+             discussion_id = kwargs.get('discussion_id')
+             return cls(lollmsClient=lollms_client, discussion_id=discussion_id, **init_args)
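`create_new` is the factory for both storage modes; extra keyword arguments that match discussion columns (such as `system_prompt`) are written straight into the new row. A sketch, assuming `client` is an existing `LollmsClient` and `db` is the manager from above:

```python
# Database-backed: the row is inserted first, then a live wrapper is returned.
discussion = LollmsDiscussion.create_new(
    lollms_client=client,
    db_manager=db,
    system_prompt="You are terse.",  # matches a column, so it is persisted
    autosave=True,                   # commit() after every touch()
    max_context_size=4096,           # chat() will summarize_and_prune() first
)

# In-memory only: nothing is persisted unless you serialize it yourself.
scratch = LollmsDiscussion.create_new(lollms_client=client)
```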
 
-     # --- Configuration Methods ---
      def set_system_prompt(self, prompt: str):
-         """Sets the main system prompt for the discussion."""
          self.system_prompt = prompt
+         self.touch()
 
      def set_participants(self, participants: Dict[str, str]):
-         """
-         Defines the participants and their roles ('user' or 'assistant').
-
-         Args:
-             participants: A dictionary mapping sender names to roles.
-         """
          for name, role in participants.items():
-             if role not in ["user", "assistant"]:
+             if role not in ["user", "assistant", "system"]:
                  raise ValueError(f"Invalid role '{role}' for participant '{name}'")
          self.participants = participants
+         self.touch()
 
-     # --- Core Message Tree Methods ---
-     def add_message(
-         self,
-         sender: str,
-         sender_type: str,
-         content: str,
-         metadata: Optional[Dict] = None,
-         parent_id: Optional[str] = None,
-         images: Optional[List[Dict[str, str]]] = None,
-         override_id: Optional[str] = None
-     ) -> str:
-         """
-         Adds a new message to the discussion tree.
-         """
-         if parent_id is None:
-             parent_id = self.active_branch_id
-         if parent_id is None:
-             parent_id = "main_root"
-
-         message = LollmsMessage(
-             sender=sender, sender_type=sender_type, content=content,
-             parent_id=parent_id, metadata=str(metadata or {}), images=images or []
-         )
-         if override_id:
-             message.id = override_id
-
-         self.messages.append(message)
-         self.message_index[message.id] = message
-         self.children_index[parent_id].append(message.id)
-         self.active_branch_id = message.id
-         return message.id
-
-     def get_branch(self, leaf_id: str) -> List[LollmsMessage]:
-         """Gets the full branch of messages from the root to the specified leaf."""
+     def add_message(self, **kwargs) -> Dict:
+         msg_id = kwargs.get('id', str(uuid.uuid4()))
+         parent_id = kwargs.get('parent_id', self.active_branch_id or None)
+
+         message_data = {
+             'id': msg_id, 'parent_id': parent_id,
+             'discussion_id': self.id, 'created_at': datetime.utcnow(),
+             **kwargs
+         }
+
+         self.messages.append(message_data)
+         self.message_index[msg_id] = message_data
+         self.active_branch_id = msg_id
+         self.touch()
+         return message_data
+
+     def get_branch(self, leaf_id: Optional[str]) -> List[Dict]:
+         if not leaf_id: return []
          branch = []
          current_id: Optional[str] = leaf_id
          while current_id and current_id in self.message_index:
              msg = self.message_index[current_id]
              branch.append(msg)
-             current_id = msg.parent_id
+             current_id = msg.get('parent_id')
          return list(reversed(branch))
+
+     def chat(self, user_message: str, show_thoughts: bool = False, **kwargs) -> Dict:
+         if self.max_context_size is not None:
+             self.summarize_and_prune(self.max_context_size)
+
+         if user_message:
+             self.add_message(sender="user", sender_type="user", content=user_message)
 
-     def set_active_branch(self, message_id: str):
-         """Sets the active message, effectively switching to a different branch."""
+         from lollms_client.lollms_types import MSG_TYPE
+
+         is_streaming = "streaming_callback" in kwargs and kwargs["streaming_callback"] is not None
+
+         if is_streaming:
+             full_response_parts = []
+             token_buffer = ""
+             in_thought_block = False
+             original_callback = kwargs.get("streaming_callback")
+
+             def accumulating_callback(token: str, msg_type: MSG_TYPE = MSG_TYPE.MSG_TYPE_CHUNK):
+                 nonlocal token_buffer, in_thought_block
+                 continue_streaming = True
+
+                 if token: token_buffer += token
+
+                 while True:
+                     if in_thought_block:
+                         end_tag_pos = token_buffer.find("</think>")
+                         if end_tag_pos != -1:
+                             thought_chunk = token_buffer[:end_tag_pos]
+                             if show_thoughts and original_callback and thought_chunk:
+                                 if not original_callback(thought_chunk, MSG_TYPE.MSG_TYPE_THOUGHT_CHUNK): continue_streaming = False
+                             in_thought_block = False
+                             token_buffer = token_buffer[end_tag_pos + len("</think>"):]
+                         else:
+                             if show_thoughts and original_callback and token_buffer:
+                                 if not original_callback(token_buffer, MSG_TYPE.MSG_TYPE_THOUGHT_CHUNK): continue_streaming = False
+                             token_buffer = ""
+                             break
+                     else:
+                         start_tag_pos = token_buffer.find("<think>")
+                         if start_tag_pos != -1:
+                             response_chunk = token_buffer[:start_tag_pos]
+                             if response_chunk:
+                                 full_response_parts.append(response_chunk)
+                                 if original_callback:
+                                     if not original_callback(response_chunk, MSG_TYPE.MSG_TYPE_CHUNK): continue_streaming = False
+                             in_thought_block = True
+                             token_buffer = token_buffer[start_tag_pos + len("<think>"):]
+                         else:
+                             if token_buffer:
+                                 full_response_parts.append(token_buffer)
+                                 if original_callback:
+                                     if not original_callback(token_buffer, MSG_TYPE.MSG_TYPE_CHUNK): continue_streaming = False
+                             token_buffer = ""
+                             break
+                 return continue_streaming
+
+             kwargs["streaming_callback"] = accumulating_callback
+             kwargs["stream"] = True
+
+             self.lollmsClient.chat(self, **kwargs)
+             ai_response = "".join(full_response_parts)
+         else:
+             kwargs["stream"] = False
+             raw_response = self.lollmsClient.chat(self, **kwargs)
+             ai_response = self.lollmsClient.remove_thinking_blocks(raw_response) if raw_response else ""
+
+         ai_message_obj = self.add_message(sender="assistant", sender_type="assistant", content=ai_response)
+
+         if self._is_db_backed and not self.autosave:
+             self.commit()
+
+         return ai_message_obj
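The streaming path wraps the caller's callback so that `<think>...</think>` spans are rerouted as thought chunks (forwarded only when `show_thoughts` is set) while everything else is accumulated into the stored reply. A sketch of wiring a callback, assuming the `discussion` object from above:

```python
from lollms_client.lollms_types import MSG_TYPE

def on_chunk(chunk: str, msg_type: MSG_TYPE = MSG_TYPE.MSG_TYPE_CHUNK) -> bool:
    # Thought text arrives as MSG_TYPE_THOUGHT_CHUNK only when show_thoughts=True;
    # returning False asks generation to stop early.
    print(chunk, end="", flush=True)
    return True

reply = discussion.chat("Status update, please.",
                        show_thoughts=True,
                        streaming_callback=on_chunk)
print("\nsaved as message:", reply["id"])  # chat() returns the stored message dict
```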
+
+     def regenerate_branch(self, show_thoughts: bool = False, **kwargs) -> Dict:
+         last_message = self.message_index.get(self.active_branch_id)
+         if not last_message or last_message['sender_type'] != 'assistant':
+             raise ValueError("Can only regenerate from an assistant's message.")
+
+         parent_id = last_message['parent_id']
+         self.active_branch_id = parent_id
+
+         self.messages = [m for m in self.messages if m['id'] != last_message['id']]
+         self._messages_to_delete.append(last_message['id'])
+         self._rebuild_in_memory_indexes()
+
+         new_ai_response_obj = self.chat("", show_thoughts, **kwargs)
+         return new_ai_response_obj
+
+     def delete_branch(self, message_id: str):
+         if not self._is_db_backed:
+             raise NotImplementedError("Branch deletion is only supported for database-backed discussions.")
+
+         if message_id not in self.message_index:
+             raise ValueError("Message not found.")
+
+         msg_to_delete = self.session.query(self.db_manager.MessageModel).filter_by(id=message_id).first()
+         if msg_to_delete:
+             parent_id = msg_to_delete.parent_id
+             self.session.delete(msg_to_delete)
+             self.active_branch_id = parent_id
+             self.commit()
+             self._load_from_db(self.id)
+
+     def switch_to_branch(self, message_id: str):
          if message_id not in self.message_index:
-             raise ValueError(f"Message ID {message_id} not found in discussion.")
+             raise ValueError(f"Message ID '{message_id}' not found in the current discussion.")
          self.active_branch_id = message_id
+         if self._is_db_backed:
+             self.db_discussion.active_branch_id = message_id
+             if self.autosave: self.commit()
+
+     def format_discussion(self, max_allowed_tokens: int, branch_tip_id: Optional[str] = None) -> str:
+         return self.export("lollms_text", branch_tip_id, max_allowed_tokens)
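Branch mechanics in practice, as a sketch: `regenerate_branch` replaces the current assistant tip with a freshly generated sibling (the old message is queued for deletion on the next commit), and `switch_to_branch` re-points the active tip at any known message id so the next `chat()` call grows a new branch. Assuming the active tip is an assistant reply:

```python
fresh = discussion.regenerate_branch()           # re-roll the last assistant reply

discussion.switch_to_branch(fresh["parent_id"])  # jump back to the prompting message
follow_up = discussion.chat("Let's try a different angle.")  # starts a sibling branch
```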
 
-     # --- Persistence ---
-     def save_to_disk(self, file_path: str):
-         """Saves the entire discussion state to a YAML file."""
-         data = {
-             'version': self.version, 'active_branch_id': self.active_branch_id,
-             'system_prompt': self.system_prompt, 'participants': self.participants,
-             'scratchpad': self.scratchpad, 'messages': [m.to_dict() for m in self.messages]
-         }
-         with open(file_path, 'w', encoding='utf-8') as file:
-             yaml.dump(data, file, allow_unicode=True, sort_keys=False)
-
-     def load_from_disk(self, file_path: str):
-         """Loads a discussion state from a YAML file."""
-         with open(file_path, 'r', encoding='utf-8') as file:
-             data = yaml.safe_load(file)
-
-         self._reset_state()
-         version = data.get("version", 1)
-         if version > self.version:
-             raise ValueError(f"File version {version} is newer than supported version {self.version}.")
-
-         self.active_branch_id = data.get('active_branch_id')
-         self.system_prompt = data.get('system_prompt', None)
-         self.participants = data.get('participants', {})
-         self.scratchpad = data.get('scratchpad', None)
-
-         for msg_data in data.get('messages', []):
-             msg = LollmsMessage(
-                 sender=msg_data['sender'], sender_type=msg_data.get('sender_type', 'user'),
-                 content=msg_data['content'], parent_id=msg_data.get('parent_id'),
-                 id=msg_data.get('id', str(uuid.uuid4())), metadata=msg_data.get('metadata', '{}'),
-                 images=msg_data.get('images', [])
-             )
-             self.messages.append(msg)
-             self.message_index[msg.id] = msg
-             self.children_index[msg.parent_id].append(msg.id)
-
-     # --- Context Management and Formatting ---
      def _get_full_system_prompt(self) -> Optional[str]:
-         """Combines the scratchpad and system prompt into a single string for the LLM."""
          full_sys_prompt_parts = []
-         if self.scratchpad and self.scratchpad.strip():
+         if self.scratchpad:
              full_sys_prompt_parts.append("--- KNOWLEDGE SCRATCHPAD ---")
             full_sys_prompt_parts.append(self.scratchpad.strip())
              full_sys_prompt_parts.append("--- END SCRATCHPAD ---")
 
          if self.system_prompt and self.system_prompt.strip():
              full_sys_prompt_parts.append(self.system_prompt.strip())
+
          return "\n\n".join(full_sys_prompt_parts) if full_sys_prompt_parts else None
 
-     def summarize_and_prune(self, max_tokens: int, preserve_last_n: int = 4, branch_tip_id: Optional[str] = None) -> Dict[str, Any]:
-         """
-         Checks context size and, if exceeded, summarizes the oldest messages
-         into the scratchpad and prunes them to free up token space.
-         """
+     def export(self, format_type: str, branch_tip_id: Optional[str] = None, max_allowed_tokens: Optional[int] = None) -> Union[List[Dict], str]:
          if branch_tip_id is None: branch_tip_id = self.active_branch_id
-         if not branch_tip_id: return {"pruned": False, "reason": "No active branch."}
+         if not branch_tip_id and format_type in ["lollms_text", "openai_chat", "ollama_chat"]:
+             return "" if format_type == "lollms_text" else []
 
-         full_prompt_text = self.export("lollms_text", branch_tip_id)
-         current_tokens = len(self.lollmsClient.binding.tokenize(full_prompt_text))
-         if current_tokens <= max_tokens: return {"pruned": False, "reason": "Token count within limit."}
+         branch = self.get_branch(branch_tip_id)
+         full_system_prompt = self._get_full_system_prompt()
+
+         participants = self.participants or {}
+
+         if format_type == "lollms_text":
+             prompt_parts = []
+             current_tokens = 0
+
+             if full_system_prompt:
+                 sys_msg_text = f"!@>system:\n{full_system_prompt}\n"
+                 sys_tokens = self.lollmsClient.count_tokens(sys_msg_text)
+                 if max_allowed_tokens is None or sys_tokens <= max_allowed_tokens:
+                     prompt_parts.append(sys_msg_text)
+                     current_tokens += sys_tokens
+
+             for msg in reversed(branch):
+                 sender_str = msg['sender'].replace(':', '').replace('!@>', '')
+                 content = msg['content'].strip()
+                 if msg.get('images'): content += f"\n({len(msg['images'])} image(s) attached)"
+                 msg_text = f"!@>{sender_str}:\n{content}\n"
+                 msg_tokens = self.lollmsClient.count_tokens(msg_text)
+
+                 if max_allowed_tokens is not None and current_tokens + msg_tokens > max_allowed_tokens: break
+                 prompt_parts.insert(1 if full_system_prompt else 0, msg_text)
+                 current_tokens += msg_tokens
+             return "".join(prompt_parts).strip()
+
+         messages = []
+         if full_system_prompt:
+             messages.append({"role": "system", "content": full_system_prompt})
+
+         for msg in branch:
+             role = participants.get(msg['sender'], "user")
+             content = msg.get('content', '').strip()
+             images = msg.get('images', [])
+
+             if format_type == "openai_chat":
+                 if images:
+                     content_parts = [{"type": "text", "text": content}] if content else []
+                     for img in images:
+                         image_url = img['data'] if img['type'] == 'url' else f"data:image/jpeg;base64,{img['data']}"
+                         content_parts.append({"type": "image_url", "image_url": {"url": image_url, "detail": "auto"}})
+                     messages.append({"role": role, "content": content_parts})
+                 else:
+                     messages.append({"role": role, "content": content})
+             elif format_type == "ollama_chat":
+                 message_dict = {"role": role, "content": content}
+                 base64_images = [img['data'] for img in images or [] if img['type'] == 'base64']
+                 if base64_images:
+                     message_dict["images"] = base64_images
+                 messages.append(message_dict)
+             else:
+                 raise ValueError(f"Unsupported export format_type: {format_type}")
+
+         return messages
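The consolidated `export` now drives every output format, with optional token budgeting for `lollms_text` (trimmed from the oldest side); `format_discussion` above is just a thin wrapper over it. For example:

```python
oai_messages = discussion.export("openai_chat")  # list of role/content dicts;
                                                 # images become image_url parts

prompt = discussion.export("lollms_text", max_allowed_tokens=2048)  # flat !@> prompt
```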
+
+     def summarize_and_prune(self, max_tokens: int, preserve_last_n: int = 4):
+         branch_tip_id = self.active_branch_id
+         if not branch_tip_id: return
+
+         current_prompt_text = self.format_discussion(999999, branch_tip_id)
+         current_tokens = self.lollmsClient.count_tokens(current_prompt_text)
+         if current_tokens <= max_tokens: return
 
          branch = self.get_branch(branch_tip_id)
-         if len(branch) <= preserve_last_n: return {"pruned": False, "reason": "Not enough messages to prune."}
+         if len(branch) <= preserve_last_n: return
 
          messages_to_prune = branch[:-preserve_last_n]
-         messages_to_keep = branch[-preserve_last_n:]
-         text_to_summarize = "\n\n".join([f"{self.participants.get(m.sender, 'user').capitalize()}: {m.content}" for m in messages_to_prune])
+         text_to_summarize = "\n\n".join([f"{m['sender']}: {m['content']}" for m in messages_to_prune])
 
-         summary_prompt = (
-             "You are a summarization expert. Read the following conversation excerpt and create a "
-             "concise, factual summary of all key information, decisions, and outcomes. This summary "
-             "will be placed in a knowledge scratchpad for future reference. Omit conversational filler.\n\n"
-             f"CONVERSATION EXCERPT:\n---\n{text_to_summarize}\n---\n\nCONCISE SUMMARY:"
-         )
+         summary_prompt = f"Concisely summarize this conversation excerpt:\n---\n{text_to_summarize}\n---\nSUMMARY:"
          try:
-             summary = self.lollmsClient.generate_text(summary_prompt, max_new_tokens=300, temperature=0.1)
+             summary = self.lollmsClient.generate_text(summary_prompt, n_predict=300, temperature=0.1)
          except Exception as e:
-             return {"pruned": False, "reason": f"Failed to generate summary: {e}"}
-
-         summary_block = f"--- Summary of earlier conversation (pruned on {uuid.uuid4().hex[:8]}) ---\n{summary.strip()}"
-         self.update_scratchpad(summary_block, append=True)
-
-         ids_to_prune = {msg.id for msg in messages_to_prune}
-         new_root_of_branch = messages_to_keep[0]
-         original_parent_id = messages_to_prune[0].parent_id
-
-         self.message_index[new_root_of_branch.id].parent_id = original_parent_id
-         if original_parent_id in self.children_index:
-             self.children_index[original_parent_id] = [mid for mid in self.children_index[original_parent_id] if mid != messages_to_prune[0].id]
-             self.children_index[original_parent_id].append(new_root_of_branch.id)
-
-         for msg_id in ids_to_prune:
-             self.message_index.pop(msg_id, None)
-             self.children_index.pop(msg_id, None)
-         self.messages = [m for m in self.messages if m.id not in ids_to_prune]
-
-         new_prompt_text = self.export("lollms_text", branch_tip_id)
-         new_tokens = len(self.lollmsClient.binding.tokenize(new_prompt_text))
-         return {"pruned": True, "tokens_saved": current_tokens - new_tokens, "summary_added": True}
-
-     def format_discussion(self, max_allowed_tokens: int, splitter_text: str = "!@>", branch_tip_id: Optional[str] = None) -> str:
-         """
-         Formats the discussion into a single string for instruct models,
-         truncating from the start to respect the token limit.
-
-         Args:
-             max_allowed_tokens: The maximum token limit for the final prompt.
-             splitter_text: The separator token to use (e.g., '!@>').
-             branch_tip_id: The ID of the branch to format. Defaults to active.
-
-         Returns:
-             A single, truncated prompt string.
-         """
-         if branch_tip_id is None:
-             branch_tip_id = self.active_branch_id
-
-         branch_msgs = self.get_branch(branch_tip_id) if branch_tip_id else []
-         full_system_prompt = self._get_full_system_prompt()
-
-         prompt_parts = []
-         current_tokens = 0
-
-         # Start with the system prompt if defined
-         if full_system_prompt:
-             sys_msg_text = f"{splitter_text}system:\n{full_system_prompt}\n"
-             sys_tokens = len(self.lollmsClient.binding.tokenize(sys_msg_text))
-             if sys_tokens <= max_allowed_tokens:
-                 prompt_parts.append(sys_msg_text)
-                 current_tokens += sys_tokens
-
-         # Iterate from newest to oldest to fill the remaining context
-         for msg in reversed(branch_msgs):
-             sender_str = msg.sender.replace(':', '').replace(splitter_text, '')
-             content = msg.content.strip()
-             if msg.images:
-                 content += f"\n({len(msg.images)} image(s) attached)"
-
-             msg_text = f"{splitter_text}{sender_str}:\n{content}\n"
-             msg_tokens = len(self.lollmsClient.binding.tokenize(msg_text))
+             print(f"\n[WARNING] Pruning failed, couldn't generate summary: {e}")
+             return
 
-             if current_tokens + msg_tokens > max_allowed_tokens:
-                 break  # Stop if adding the next message exceeds the limit
+         new_scratchpad_content = f"{self.scratchpad}\n\n--- Summary of earlier conversation ---\n{summary.strip()}"
+         self.scratchpad = new_scratchpad_content.strip()
 
-             prompt_parts.insert(1 if full_system_prompt else 0, msg_text)  # Prepend after system prompt
-             current_tokens += msg_tokens
-
-         return "".join(prompt_parts).strip()
+         pruned_ids = {msg['id'] for msg in messages_to_prune}
+         self.messages = [m for m in self.messages if m['id'] not in pruned_ids]
+         self._messages_to_delete.extend(list(pruned_ids))
+         self._rebuild_in_memory_indexes()
 
+         print(f"\n[INFO] Discussion auto-pruned. {len(messages_to_prune)} messages summarized.")
 
-     def export(self, format_type: str, branch_tip_id: Optional[str] = None) -> Union[List[Dict], str]:
-         """
-         Exports the full, untruncated discussion history in a specific format.
-         """
-         if branch_tip_id is None: branch_tip_id = self.active_branch_id
-         if branch_tip_id is None and not self._get_full_system_prompt(): return "" if format_type in ["lollms_text", "openai_completion"] else []
+     def to_dict(self):
+         messages_copy = [msg.copy() for msg in self.messages]
+         for msg in messages_copy:
+             if 'created_at' in msg and isinstance(msg['created_at'], datetime):
+                 msg['created_at'] = msg['created_at'].isoformat()
+             if 'message_metadata' in msg:
+                 msg['metadata'] = msg.pop('message_metadata')
 
-         branch = self.get_branch(branch_tip_id) if branch_tip_id else []
-         full_system_prompt = self._get_full_system_prompt()
+         return {
+             "id": self.id, "system_prompt": self.system_prompt,
+             "participants": self.participants, "active_branch_id": self.active_branch_id,
+             "metadata": self.metadata, "scratchpad": self.scratchpad,
+             "messages": messages_copy,
+             "created_at": self.created_at.isoformat() if self.created_at else None,
+             "updated_at": self.updated_at.isoformat() if self.updated_at else None
+         }
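`to_dict` and `load_from_dict` give a JSON-safe round trip (datetimes are converted to ISO strings and parsed back), which is also what the file migrator below relies on. A sketch, assuming the metadata is JSON-serializable:

```python
import json

snapshot = json.dumps(discussion.to_dict(), indent=2)

restored = LollmsDiscussion.create_new(lollms_client=client)  # in-memory shell
restored.load_from_dict(json.loads(snapshot))
assert restored.id == discussion.id
```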
 
-         if format_type == "openai_chat":
-             messages = []
-             if full_system_prompt: messages.append({"role": "system", "content": full_system_prompt})
-             def openai_image_block(image: Dict[str, str]) -> Dict:
-                 image_url = image['data'] if image['type'] == 'url' else f"data:image/jpeg;base64,{image['data']}"
-                 return {"type": "image_url", "image_url": {"url": image_url, "detail": "auto"}}
-             for msg in branch:
-                 role = self.participants.get(msg.sender, "user")
-                 if msg.images:
-                     content_parts = [{"type": "text", "text": msg.content.strip()}] if msg.content.strip() else []
-                     content_parts.extend(openai_image_block(img) for img in msg.images)
-                     messages.append({"role": role, "content": content_parts})
-                 else: messages.append({"role": role, "content": msg.content.strip()})
-             return messages
-
-         elif format_type == "ollama_chat":
-             messages = []
-             if full_system_prompt: messages.append({"role": "system", "content": full_system_prompt})
-             for msg in branch:
-                 role = self.participants.get(msg.sender, "user")
-                 message_dict = {"role": role, "content": msg.content.strip()}
-                 ollama_images = [img['data'] for img in msg.images if img['type'] == 'base64']
-                 if ollama_images: message_dict["images"] = ollama_images
-                 messages.append(message_dict)
-             return messages
-
-         elif format_type == "lollms_text":
-             full_prompt_parts = []
-             if full_system_prompt: full_prompt_parts.append(f"!@>system:\n{full_system_prompt}")
-             for msg in branch:
-                 sender_str = msg.sender.replace(':', '').replace('!@>', '')
-                 content = msg.content.strip()
-                 if msg.images: content += f"\n({len(msg.images)} image(s) attached)"
-                 full_prompt_parts.append(f"!@>{sender_str}:\n{content}")
-             return "\n".join(full_prompt_parts)
-
-         elif format_type == "openai_completion":
-             full_prompt_parts = []
-             if full_system_prompt: full_prompt_parts.append(f"System:\n{full_system_prompt}")
-             for msg in branch:
-                 role_label = self.participants.get(msg.sender, "user").capitalize()
-                 content = msg.content.strip()
-                 if msg.images: content += f"\n({len(msg.images)} image(s) attached)"
-                 full_prompt_parts.append(f"{role_label}:\n{content}")
-             return "\n\n".join(full_prompt_parts)
-
-         else: raise ValueError(f"Unsupported export format_type: {format_type}")
-
-
- if __name__ == "__main__":
-     class MockBinding:
-         def tokenize(self, text: str) -> List[int]: return text.split()
-     class MockLollmsClient:
-         def __init__(self): self.binding = MockBinding()
-         def generate(self, prompt: str, max_new_tokens: int, temperature: float) -> str: return "This is a generated summary."
-
-     print("--- Initializing Mock Client and Discussion ---")
-     mock_client = MockLollmsClient()
-     discussion = LollmsDiscussion(mock_client)
-     discussion.set_participants({"User": "user", "Project Lead": "assistant"})
-     discussion.set_system_prompt("This is a formal discussion about Project Phoenix.")
-     discussion.set_scratchpad("Initial State: Project Phoenix is in the planning phase.")
-
-     print("\n--- Creating a long discussion history ---")
-     parent_id = None
-     long_text = "extra text to increase token count"
-     for i in range(10):
-         user_msg = f"Message #{i*2+1}: Update on task {i+1}? {long_text}"
-         user_id = discussion.add_message("User", "user", user_msg, parent_id=parent_id)
-         assistant_msg = f"Message #{i*2+2}: Task {i+1} status is blocked. {long_text}"
-         assistant_id = discussion.add_message("Project Lead", "assistant", assistant_msg, parent_id=user_id)
-         parent_id = assistant_id
-
-     initial_tokens = len(mock_client.binding.tokenize(discussion.export("lollms_text")))
-     print(f"Initial message count: {len(discussion.messages)}, Initial tokens: {initial_tokens}")
-
-     print("\n--- Testing Pruning ---")
-     prune_result = discussion.summarize_and_prune(max_tokens=200, preserve_last_n=4)
-     if prune_result.get("pruned"):
-         print("✅ Pruning was successful!")
-         assert "Summary" in discussion.get_scratchpad()
-     else: print(f" Pruning failed: {prune_result.get('reason')}")
-
-     print("\n--- Testing format_discussion (Instruct Model Format) ---")
-     truncated_prompt = discussion.format_discussion(max_allowed_tokens=80)
-     truncated_tokens = len(mock_client.binding.tokenize(truncated_prompt))
-     print(f"Truncated prompt tokens: {truncated_tokens}")
-     print("Truncated Prompt:\n" + "="*20 + f"\n{truncated_prompt}\n" + "="*20)
-
-     # Verification
-     assert truncated_tokens <= 80
-     # Check that it contains the newest message that fits
-     assert "Message #19" in truncated_prompt or "Message #20" in truncated_prompt
-     print("✅ format_discussion correctly truncated the prompt.")
+     def load_from_dict(self, data: Dict):
+         self._reset_in_memory_state()
+         self.id = data.get("id", str(uuid.uuid4()))
+         self.system_prompt = data.get("system_prompt")
+         self.participants = data.get("participants", {})
+         self.active_branch_id = data.get("active_branch_id")
+         self.metadata = data.get("metadata", {})
+         self.scratchpad = data.get("scratchpad", "")
+
+         loaded_messages = data.get("messages", [])
+         for msg in loaded_messages:
+             if 'created_at' in msg and isinstance(msg['created_at'], str):
+                 try:
+                     msg['created_at'] = datetime.fromisoformat(msg['created_at'])
+                 except ValueError:
+                     msg['created_at'] = datetime.utcnow()
+         self.messages = loaded_messages
+
+         self.created_at = datetime.fromisoformat(data['created_at']) if data.get('created_at') else datetime.utcnow()
+         self.updated_at = datetime.fromisoformat(data['updated_at']) if data.get('updated_at') else self.created_at
+         self._rebuild_in_memory_indexes()
+
+     def _rebuild_in_memory_indexes(self):
+         self.message_index = {msg['id']: msg for msg in self.messages}
+
+     @staticmethod
+     def migrate(lollms_client: 'LollmsClient', db_manager: DatabaseManager, folder_path: Union[str, Path]):
+         folder = Path(folder_path)
+         if not folder.is_dir():
+             print(f"Error: Path '{folder}' is not a valid directory.")
+             return
+
+         print(f"\n--- Starting Migration from '{folder}' ---")
+         discussion_files = list(folder.glob("*.json")) + list(folder.glob("*.yaml"))
+         session = db_manager.get_session()
+         for i, file_path in enumerate(discussion_files):
+             print(f"Migrating file {i+1}/{len(discussion_files)}: {file_path.name} ... ", end="")
+             try:
+                 in_memory_discussion = LollmsDiscussion.create_new(lollms_client=lollms_client)
+                 if file_path.suffix.lower() == ".json":
+                     with open(file_path, 'r', encoding='utf-8') as f: data = json.load(f)
+                 else:
+                     with open(file_path, 'r', encoding='utf-8') as f: data = yaml.safe_load(f)
+
+                 in_memory_discussion.load_from_dict(data)
+                 discussion_id = in_memory_discussion.id
+
+                 existing = session.query(db_manager.DiscussionModel).filter_by(id=discussion_id).first()
+                 if existing:
+                     print("SKIPPED (already exists)")
+                     continue
+
+                 valid_disc_keys = {c.name for c in db_manager.DiscussionModel.__table__.columns}
+                 valid_msg_keys = {c.name for c in db_manager.MessageModel.__table__.columns}
+
+                 discussion_data = {
+                     'id': in_memory_discussion.id,
+                     'system_prompt': in_memory_discussion.system_prompt,
+                     'participants': in_memory_discussion.participants,
+                     'active_branch_id': in_memory_discussion.active_branch_id,
+                     'discussion_metadata': in_memory_discussion.metadata,
+                     'created_at': in_memory_discussion.created_at,
+                     'updated_at': in_memory_discussion.updated_at
+                 }
+                 project_name = in_memory_discussion.metadata.get('project_name', file_path.stem)
+                 if 'project_name' in valid_disc_keys:
+                     discussion_data['project_name'] = project_name
+
+                 db_discussion = db_manager.DiscussionModel(**discussion_data)
+                 session.add(db_discussion)
+
+                 for msg_data in in_memory_discussion.messages:
+                     msg_data['discussion_id'] = db_discussion.id
+                     if 'metadata' in msg_data:
+                         msg_data['message_metadata'] = msg_data.pop('metadata')
+                     filtered_msg_data = {k: v for k, v in msg_data.items() if k in valid_msg_keys}
+                     msg_orm = db_manager.MessageModel(**filtered_msg_data)
+                     session.add(msg_orm)
+
+                 print("OK")
+             except Exception as e:
+                 print(f"FAILED. Error: {e}")
+                 session.rollback()
+         session.commit()
+         session.close()
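`migrate` is a one-shot import of legacy JSON/YAML discussion files into the database, skipping ids that already exist. A usage sketch with an illustrative folder path:

```python
LollmsDiscussion.migrate(
    lollms_client=client,             # assumed existing LollmsClient
    db_manager=db,
    folder_path="./old_discussions",  # scans *.json and *.yaml in this folder
)
```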