swarms 7.7.8__py3-none-any.whl → 7.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- swarms/__init__.py +0 -1
- swarms/agents/cort_agent.py +206 -0
- swarms/agents/react_agent.py +173 -0
- swarms/agents/self_agent_builder.py +40 -0
- swarms/communication/base_communication.py +290 -0
- swarms/communication/duckdb_wrap.py +369 -72
- swarms/communication/pulsar_struct.py +691 -0
- swarms/communication/redis_wrap.py +1362 -0
- swarms/communication/sqlite_wrap.py +547 -44
- swarms/prompts/agent_self_builder_prompt.py +103 -0
- swarms/prompts/safety_prompt.py +50 -0
- swarms/schemas/__init__.py +6 -1
- swarms/schemas/agent_class_schema.py +91 -0
- swarms/schemas/agent_mcp_errors.py +18 -0
- swarms/schemas/agent_tool_schema.py +13 -0
- swarms/schemas/llm_agent_schema.py +92 -0
- swarms/schemas/mcp_schemas.py +43 -0
- swarms/structs/__init__.py +4 -0
- swarms/structs/agent.py +315 -267
- swarms/structs/aop.py +3 -1
- swarms/structs/batch_agent_execution.py +64 -0
- swarms/structs/conversation.py +261 -57
- swarms/structs/council_judge.py +542 -0
- swarms/structs/deep_research_swarm.py +19 -22
- swarms/structs/long_agent.py +424 -0
- swarms/structs/ma_utils.py +11 -8
- swarms/structs/malt.py +30 -28
- swarms/structs/multi_model_gpu_manager.py +1 -1
- swarms/structs/output_types.py +1 -1
- swarms/structs/swarm_router.py +70 -15
- swarms/tools/__init__.py +12 -0
- swarms/tools/base_tool.py +2840 -264
- swarms/tools/create_agent_tool.py +104 -0
- swarms/tools/mcp_client_call.py +504 -0
- swarms/tools/py_func_to_openai_func_str.py +45 -7
- swarms/tools/pydantic_to_json.py +10 -27
- swarms/utils/audio_processing.py +343 -0
- swarms/utils/history_output_formatter.py +5 -5
- swarms/utils/index.py +226 -0
- swarms/utils/litellm_wrapper.py +65 -67
- swarms/utils/try_except_wrapper.py +2 -2
- swarms/utils/xml_utils.py +42 -0
- {swarms-7.7.8.dist-info → swarms-7.8.0.dist-info}/METADATA +5 -4
- {swarms-7.7.8.dist-info → swarms-7.8.0.dist-info}/RECORD +47 -30
- {swarms-7.7.8.dist-info → swarms-7.8.0.dist-info}/WHEEL +1 -1
- swarms/client/__init__.py +0 -15
- swarms/client/main.py +0 -407
- swarms/tools/mcp_client.py +0 -246
- swarms/tools/mcp_integration.py +0 -340
- {swarms-7.7.8.dist-info → swarms-7.8.0.dist-info}/LICENSE +0 -0
- {swarms-7.7.8.dist-info → swarms-7.8.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,1362 @@
|
|
1
|
+
import datetime
|
2
|
+
import hashlib
|
3
|
+
import json
|
4
|
+
import threading
|
5
|
+
import subprocess
|
6
|
+
import tempfile
|
7
|
+
import os
|
8
|
+
import atexit
|
9
|
+
import time
|
10
|
+
from typing import Any, Dict, List, Optional, Union
|
11
|
+
|
12
|
+
import yaml
|
13
|
+
|
14
|
+
try:
|
15
|
+
import redis
|
16
|
+
from redis.exceptions import (
|
17
|
+
AuthenticationError,
|
18
|
+
BusyLoadingError,
|
19
|
+
ConnectionError,
|
20
|
+
RedisError,
|
21
|
+
TimeoutError,
|
22
|
+
)
|
23
|
+
|
24
|
+
REDIS_AVAILABLE = True
|
25
|
+
except ImportError:
|
26
|
+
REDIS_AVAILABLE = False
|
27
|
+
|
28
|
+
from loguru import logger
|
29
|
+
|
30
|
+
from swarms.structs.base_structure import BaseStructure
|
31
|
+
from swarms.utils.any_to_str import any_to_str
|
32
|
+
from swarms.utils.formatter import formatter
|
33
|
+
from swarms.utils.litellm_tokenizer import count_tokens
|
34
|
+
|
35
|
+
|
36
|
+
class RedisConnectionError(Exception):
    """Raised when the Redis backend cannot be reached or started."""
|
40
|
+
|
41
|
+
|
42
|
+
class RedisOperationError(Exception):
    """Raised when reading or writing conversation data in Redis fails."""
|
46
|
+
|
47
|
+
|
48
|
+
class EmbeddedRedisServer:
    """Run and supervise a local ``redis-server`` subprocess.

    The server is started from a ``redis.conf`` written into ``data_dir``.
    When both ``persist`` and ``auto_persist`` are true, ``data_dir``
    (default ``~/.swarms/redis``) is created up front and kept between runs.
    Otherwise ``start()`` creates a private temporary directory, and
    ``stop()`` deletes that temporary directory -- never a directory the
    caller supplied. ``stop()`` is registered with ``atexit`` so the
    subprocess is shut down when the interpreter exits.
    """

    def __init__(
        self,
        port: int = 6379,
        data_dir: Optional[str] = None,
        persist: bool = True,
        auto_persist: bool = True,
        startup_timeout: float = 5.0,
    ):
        """
        Args:
            port (int): TCP port the server listens on.
            data_dir (Optional[str]): Directory for ``redis.conf`` and the
                persistence files. Defaults to ``~/.swarms/redis``.
            persist (bool): Whether data should survive restarts.
            auto_persist (bool): Persistence is only active when this and
                ``persist`` are both true.
            startup_timeout (float): Seconds ``start()`` keeps retrying PING
                (after an initial 1 s grace period) before giving up.
        """
        self.port = port
        self.process = None
        self.data_dir = data_dir or os.path.expanduser(
            "~/.swarms/redis"
        )
        self.persist = persist
        self.auto_persist = auto_persist
        self.startup_timeout = startup_timeout
        # Directory created by start() for non-persistent runs. stop() only
        # ever deletes this one, so a caller-supplied or the default
        # persistent data_dir is never removed.
        self._temp_dir: Optional[str] = None
        # Unnamed temp file receiving the server's stdout/stderr (see start()).
        self._output = None

        # Only create the data directory up front if persistence is enabled
        if self.persist and self.auto_persist:
            os.makedirs(self.data_dir, exist_ok=True)
            # Create Redis configuration file
            self._create_redis_config()

        atexit.register(self.stop)

    def _create_redis_config(self):
        """Write ``redis.conf`` (port, data dir, RDB + AOF persistence) into ``data_dir``."""
        config_path = os.path.join(self.data_dir, "redis.conf")
        config_content = f"""
port {self.port}
dir {self.data_dir}
dbfilename dump.rdb
appendonly yes
appendfilename appendonly.aof
appendfsync everysec
save 1 1
rdbcompression yes
rdbchecksum yes
"""
        with open(config_path, "w") as f:
            f.write(config_content)
        logger.info(f"Created Redis configuration at {config_path}")

    def start(self) -> bool:
        """Launch the server and wait until it answers PING.

        Returns:
            bool: True if the server is up. False if it could not be
            started; the failure is logged and ``stop()`` is called.
        """
        try:
            if not (self.persist and self.auto_persist):
                # Non-persistent runs get a private scratch directory that
                # stop() is allowed to delete.
                self._temp_dir = tempfile.mkdtemp()
                self.data_dir = self._temp_dir
                self._create_redis_config()  # Create config even for temporary dir

            config_path = os.path.join(self.data_dir, "redis.conf")

            # Start Redis server with config file
            redis_args = [
                "redis-server",
                config_path,
                "--daemonize",
                "no",
            ]

            # Without a configured logfile Redis logs to stdout. An unread
            # PIPE fills up after ~64 KiB of log lines and then blocks the
            # server, so output goes to a temp file that is only read back
            # to report start-up failures.
            self._close_output()
            self._output = tempfile.TemporaryFile()
            self.process = subprocess.Popen(
                redis_args,
                stdout=self._output,
                stderr=subprocess.STDOUT,
            )

            self._wait_until_ready()

            logger.info(
                f"Started {'persistent' if (self.persist and self.auto_persist) else 'temporary'} Redis server on port {self.port}"
            )
            if self.persist and self.auto_persist:
                logger.info(f"Redis data directory: {self.data_dir}")
            return True
        except Exception as e:
            logger.error(
                f"Failed to start embedded Redis server: {str(e)}"
            )
            self.stop()
            return False

    def _wait_until_ready(self) -> None:
        """Wait for the subprocess to accept connections.

        Raises:
            Exception: If the process exits early, or does not answer PING
                within ``startup_timeout`` seconds after the grace period.
        """
        # Grace period so a server that dies immediately (e.g. port already
        # in use) is detected before we ping whatever else may listen there.
        time.sleep(1)
        deadline = time.monotonic() + self.startup_timeout
        while True:
            if self.process.poll() is not None:
                raise Exception(
                    f"Redis failed to start: {self._read_output()}"
                )
            client = redis.Redis(
                host="localhost",
                port=self.port,
                socket_connect_timeout=1.0,
                socket_timeout=1.0,
            )
            try:
                client.ping()
                return
            except (redis.ConnectionError, redis.TimeoutError) as e:
                # BusyLoadingError (a ConnectionError subclass) lands here
                # while a large dataset is still being loaded: keep waiting.
                if time.monotonic() >= deadline:
                    raise Exception(
                        f"Could not connect to Redis: {str(e)}"
                    )
            finally:
                client.close()
            time.sleep(0.1)

    def _read_output(self, limit: int = 4000) -> str:
        """Return the last ``limit`` characters of captured server output ('' if none)."""
        if self._output is None:
            return ""
        try:
            self._output.seek(0)
            data = self._output.read()
        except (OSError, ValueError):
            return ""
        return data.decode(errors="replace")[-limit:]

    def _close_output(self) -> None:
        """Close the captured-output temp file, if one is open."""
        if self._output is not None:
            try:
                self._output.close()
            finally:
                self._output = None

    def stop(self):
        """Shut down the server and release its resources.

        Persistent servers are asked to SAVE first. The temporary directory
        created by ``start()`` for non-persistent runs is deleted; a
        caller-supplied or persistent ``data_dir`` is left alone. Errors are
        logged rather than raised because this also runs from ``atexit``.
        """
        try:
            if self.process:
                if self.persist and self.auto_persist:
                    try:
                        client = redis.Redis(
                            host="localhost",
                            port=self.port,
                            socket_connect_timeout=5.0,
                            socket_timeout=30.0,
                        )
                        try:
                            # SAVE is synchronous: once it returns the
                            # snapshot is on disk, so no BGSAVE/sleep needed.
                            client.save()
                        finally:
                            client.close()
                    except Exception as e:
                        logger.warning(
                            f"Error during Redis save: {str(e)}"
                        )

                self.process.terminate()
                try:
                    self.process.wait(timeout=5)
                except subprocess.TimeoutExpired:
                    self.process.kill()
                    self.process.wait()
                self.process = None
                logger.info("Stopped Redis server")

            self._close_output()

            # Only delete the scratch directory that start() created itself.
            if self._temp_dir and os.path.exists(self._temp_dir):
                import shutil

                shutil.rmtree(self._temp_dir)
                if self.data_dir == self._temp_dir:
                    self.data_dir = None
            self._temp_dir = None
        except Exception as e:
            logger.error(f"Error stopping Redis server: {str(e)}")
|
192
|
+
|
193
|
+
|
194
|
+
class RedisConversation(BaseStructure):
|
195
|
+
"""
|
196
|
+
A Redis-based implementation of the Conversation class for managing conversation history.
|
197
|
+
This class provides the same interface as the memory-based Conversation class but uses
|
198
|
+
Redis as the storage backend.
|
199
|
+
|
200
|
+
Attributes:
|
201
|
+
system_prompt (Optional[str]): The system prompt for the conversation.
|
202
|
+
time_enabled (bool): Flag to enable time tracking for messages.
|
203
|
+
autosave (bool): Flag to enable automatic saving of conversation history.
|
204
|
+
save_filepath (str): File path for saving the conversation history.
|
205
|
+
tokenizer (Any): Tokenizer for counting tokens in messages.
|
206
|
+
context_length (int): Maximum number of tokens allowed in the conversation history.
|
207
|
+
rules (str): Rules for the conversation.
|
208
|
+
custom_rules_prompt (str): Custom prompt for rules.
|
209
|
+
user (str): The user identifier for messages.
|
210
|
+
auto_save (bool): Flag to enable auto-saving of conversation history.
|
211
|
+
save_as_yaml (bool): Flag to save conversation history as YAML.
|
212
|
+
save_as_json_bool (bool): Flag to save conversation history as JSON.
|
213
|
+
token_count (bool): Flag to enable token counting for messages.
|
214
|
+
cache_enabled (bool): Flag to enable prompt caching.
|
215
|
+
cache_stats (dict): Statistics about cache usage.
|
216
|
+
cache_lock (threading.Lock): Lock for thread-safe cache operations.
|
217
|
+
redis_client (redis.Redis): Redis client instance.
|
218
|
+
conversation_id (str): Unique identifier for the current conversation.
|
219
|
+
"""
|
220
|
+
|
221
|
+
def __init__(
|
222
|
+
self,
|
223
|
+
system_prompt: Optional[str] = None,
|
224
|
+
time_enabled: bool = False,
|
225
|
+
autosave: bool = False,
|
226
|
+
save_filepath: str = None,
|
227
|
+
tokenizer: Any = None,
|
228
|
+
context_length: int = 8192,
|
229
|
+
rules: str = None,
|
230
|
+
custom_rules_prompt: str = None,
|
231
|
+
user: str = "User:",
|
232
|
+
auto_save: bool = True,
|
233
|
+
save_as_yaml: bool = True,
|
234
|
+
save_as_json_bool: bool = False,
|
235
|
+
token_count: bool = True,
|
236
|
+
cache_enabled: bool = True,
|
237
|
+
redis_host: str = "localhost",
|
238
|
+
redis_port: int = 6379,
|
239
|
+
redis_db: int = 0,
|
240
|
+
redis_password: Optional[str] = None,
|
241
|
+
redis_ssl: bool = False,
|
242
|
+
redis_retry_attempts: int = 3,
|
243
|
+
redis_retry_delay: float = 1.0,
|
244
|
+
use_embedded_redis: bool = True,
|
245
|
+
persist_redis: bool = True,
|
246
|
+
auto_persist: bool = True,
|
247
|
+
redis_data_dir: Optional[str] = None,
|
248
|
+
conversation_id: Optional[str] = None,
|
249
|
+
name: Optional[str] = None,
|
250
|
+
*args,
|
251
|
+
**kwargs,
|
252
|
+
):
|
253
|
+
"""
|
254
|
+
Initialize the RedisConversation with Redis backend.
|
255
|
+
|
256
|
+
Args:
|
257
|
+
system_prompt (Optional[str]): The system prompt for the conversation.
|
258
|
+
time_enabled (bool): Flag to enable time tracking for messages.
|
259
|
+
autosave (bool): Flag to enable automatic saving of conversation history.
|
260
|
+
save_filepath (str): File path for saving the conversation history.
|
261
|
+
tokenizer (Any): Tokenizer for counting tokens in messages.
|
262
|
+
context_length (int): Maximum number of tokens allowed in the conversation history.
|
263
|
+
rules (str): Rules for the conversation.
|
264
|
+
custom_rules_prompt (str): Custom prompt for rules.
|
265
|
+
user (str): The user identifier for messages.
|
266
|
+
auto_save (bool): Flag to enable auto-saving of conversation history.
|
267
|
+
save_as_yaml (bool): Flag to save conversation history as YAML.
|
268
|
+
save_as_json_bool (bool): Flag to save conversation history as JSON.
|
269
|
+
token_count (bool): Flag to enable token counting for messages.
|
270
|
+
cache_enabled (bool): Flag to enable prompt caching.
|
271
|
+
redis_host (str): Redis server host.
|
272
|
+
redis_port (int): Redis server port.
|
273
|
+
redis_db (int): Redis database number.
|
274
|
+
redis_password (Optional[str]): Redis password for authentication.
|
275
|
+
redis_ssl (bool): Whether to use SSL for Redis connection.
|
276
|
+
redis_retry_attempts (int): Number of connection retry attempts.
|
277
|
+
redis_retry_delay (float): Delay between retry attempts in seconds.
|
278
|
+
use_embedded_redis (bool): Whether to start an embedded Redis server.
|
279
|
+
If True, redis_host and redis_port will be used for the embedded server.
|
280
|
+
persist_redis (bool): Whether to enable Redis persistence.
|
281
|
+
auto_persist (bool): Whether to automatically handle persistence.
|
282
|
+
If True, persistence will be managed automatically.
|
283
|
+
If False, persistence will be manual even if persist_redis is True.
|
284
|
+
redis_data_dir (Optional[str]): Directory for Redis data persistence.
|
285
|
+
conversation_id (Optional[str]): Specific conversation ID to use/restore.
|
286
|
+
If None, a new ID will be generated.
|
287
|
+
name (Optional[str]): A friendly name for the conversation.
|
288
|
+
If provided, this will be used to look up or create a conversation.
|
289
|
+
Takes precedence over conversation_id if both are provided.
|
290
|
+
|
291
|
+
Raises:
|
292
|
+
ImportError: If Redis package is not installed.
|
293
|
+
RedisConnectionError: If connection to Redis fails.
|
294
|
+
RedisOperationError: If Redis operations fail.
|
295
|
+
"""
|
296
|
+
if not REDIS_AVAILABLE:
|
297
|
+
logger.error(
|
298
|
+
"Redis package is not installed. Please install it with 'pip install redis'"
|
299
|
+
)
|
300
|
+
raise ImportError(
|
301
|
+
"Redis package is not installed. Please install it with 'pip install redis'"
|
302
|
+
)
|
303
|
+
|
304
|
+
super().__init__()
|
305
|
+
self.system_prompt = system_prompt
|
306
|
+
self.time_enabled = time_enabled
|
307
|
+
self.autosave = autosave
|
308
|
+
self.save_filepath = save_filepath
|
309
|
+
self.tokenizer = tokenizer
|
310
|
+
self.context_length = context_length
|
311
|
+
self.rules = rules
|
312
|
+
self.custom_rules_prompt = custom_rules_prompt
|
313
|
+
self.user = user
|
314
|
+
self.auto_save = auto_save
|
315
|
+
self.save_as_yaml = save_as_yaml
|
316
|
+
self.save_as_json_bool = save_as_json_bool
|
317
|
+
self.token_count = token_count
|
318
|
+
self.cache_enabled = cache_enabled
|
319
|
+
self.cache_stats = {
|
320
|
+
"hits": 0,
|
321
|
+
"misses": 0,
|
322
|
+
"cached_tokens": 0,
|
323
|
+
"total_tokens": 0,
|
324
|
+
}
|
325
|
+
self.cache_lock = threading.Lock()
|
326
|
+
|
327
|
+
# Initialize Redis server (embedded or external)
|
328
|
+
self.embedded_server = None
|
329
|
+
if use_embedded_redis:
|
330
|
+
self.embedded_server = EmbeddedRedisServer(
|
331
|
+
port=redis_port,
|
332
|
+
data_dir=redis_data_dir,
|
333
|
+
persist=persist_redis,
|
334
|
+
auto_persist=auto_persist,
|
335
|
+
)
|
336
|
+
if not self.embedded_server.start():
|
337
|
+
raise RedisConnectionError(
|
338
|
+
"Failed to start embedded Redis server"
|
339
|
+
)
|
340
|
+
|
341
|
+
# Initialize Redis client with retries
|
342
|
+
self.redis_client = None
|
343
|
+
self._initialize_redis_connection(
|
344
|
+
host=redis_host,
|
345
|
+
port=redis_port,
|
346
|
+
db=redis_db,
|
347
|
+
password=redis_password,
|
348
|
+
ssl=redis_ssl,
|
349
|
+
retry_attempts=redis_retry_attempts,
|
350
|
+
retry_delay=redis_retry_delay,
|
351
|
+
)
|
352
|
+
|
353
|
+
# Handle conversation name and ID
|
354
|
+
self.name = name
|
355
|
+
if name:
|
356
|
+
# Try to find existing conversation by name
|
357
|
+
existing_id = self._get_conversation_id_by_name(name)
|
358
|
+
if existing_id:
|
359
|
+
self.conversation_id = existing_id
|
360
|
+
logger.info(
|
361
|
+
f"Found existing conversation '{name}' with ID: {self.conversation_id}"
|
362
|
+
)
|
363
|
+
else:
|
364
|
+
# Create new conversation with name
|
365
|
+
self.conversation_id = f"conversation:{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"
|
366
|
+
self._save_conversation_name(name)
|
367
|
+
logger.info(
|
368
|
+
f"Created new conversation '{name}' with ID: {self.conversation_id}"
|
369
|
+
)
|
370
|
+
else:
|
371
|
+
# Use provided ID or generate new one
|
372
|
+
self.conversation_id = (
|
373
|
+
conversation_id
|
374
|
+
or f"conversation:{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"
|
375
|
+
)
|
376
|
+
logger.info(
|
377
|
+
f"Using conversation ID: {self.conversation_id}"
|
378
|
+
)
|
379
|
+
|
380
|
+
# Check if we have existing data
|
381
|
+
has_existing_data = self._load_existing_data()
|
382
|
+
|
383
|
+
if has_existing_data:
|
384
|
+
logger.info(
|
385
|
+
f"Restored conversation data for: {self.name or self.conversation_id}"
|
386
|
+
)
|
387
|
+
else:
|
388
|
+
logger.info(
|
389
|
+
f"Initialized new conversation: {self.name or self.conversation_id}"
|
390
|
+
)
|
391
|
+
# Initialize with prompts only for new conversations
|
392
|
+
try:
|
393
|
+
if self.system_prompt is not None:
|
394
|
+
self.add("System", self.system_prompt)
|
395
|
+
|
396
|
+
if self.rules is not None:
|
397
|
+
self.add("User", rules)
|
398
|
+
|
399
|
+
if custom_rules_prompt is not None:
|
400
|
+
self.add(user or "User", custom_rules_prompt)
|
401
|
+
except RedisError as e:
|
402
|
+
logger.error(
|
403
|
+
f"Failed to initialize conversation: {str(e)}"
|
404
|
+
)
|
405
|
+
raise RedisOperationError(
|
406
|
+
f"Failed to initialize conversation: {str(e)}"
|
407
|
+
)
|
408
|
+
|
409
|
+
def _initialize_redis_connection(
    self,
    host: str,
    port: int,
    db: int,
    password: Optional[str],
    ssl: bool,
    retry_attempts: int,
    retry_delay: float,
):
    """Connect ``self.redis_client`` to Redis, retrying transient failures.

    A client is created and PINGed up to ``retry_attempts`` times (at least
    once), sleeping ``retry_delay`` seconds between attempts. The server's
    configuration is deliberately NOT modified. The embedded server already
    gets ``dir``/``dbfilename`` from the ``redis.conf`` that
    ``EmbeddedRedisServer`` writes. Pointing ``dir`` at ``~/.swarms/redis``
    here would send a temporary or custom-directory server's snapshots into
    the persistent store, and would reconfigure external servers.

    Args:
        host (str): Redis host.
        port (int): Redis port.
        db (int): Redis database number.
        password (Optional[str]): Redis password.
        ssl (bool): Whether to use SSL.
        retry_attempts (int): Number of attempts; values below 1 mean 1.
        retry_delay (float): Delay between retries in seconds.

    Raises:
        RedisConnectionError: If connection fails after all retries.
    """
    attempts = max(1, retry_attempts)

    for attempt in range(attempts):
        client = redis.Redis(
            host=host,
            port=port,
            db=db,
            password=password,
            ssl=ssl,
            decode_responses=True,
            socket_timeout=5.0,
            socket_connect_timeout=5.0,
        )
        try:
            client.ping()
        except (
            ConnectionError,
            TimeoutError,
            AuthenticationError,
            BusyLoadingError,
        ) as e:
            # Release the failed client's connections before retrying.
            client.close()
            if attempt < attempts - 1:
                logger.warning(
                    f"Redis connection attempt {attempt + 1} failed: {str(e)}"
                )
                time.sleep(retry_delay)
            else:
                logger.error(
                    f"Failed to connect to Redis after {attempts} attempts"
                )
                raise RedisConnectionError(
                    f"Failed to connect to Redis: {str(e)}"
                ) from e
        else:
            self.redis_client = client
            logger.info(
                f"Successfully connected to Redis at {host}:{port}"
            )
            return
|
483
|
+
|
484
|
+
def _load_existing_data(self):
    """Tell whether this conversation already has stored messages.

    Returns:
        bool: True when the ``<conversation_id>:message_ids`` list in Redis
        is non-empty. False when it is empty, or when the lookup raised; in
        that case the error is logged as a warning instead of propagating.
    """
    try:
        stored_ids = self.redis_client.lrange(
            f"{self.conversation_id}:message_ids", 0, -1
        )
        if not stored_ids:
            return False
        logger.info(
            f"Found existing data for conversation {self.conversation_id}"
        )
        return True
    except Exception as err:
        logger.warning(
            f"Error checking for existing data: {str(err)}"
        )
        return False
|
502
|
+
|
503
|
+
def _safe_redis_operation(
|
504
|
+
self,
|
505
|
+
operation_name: str,
|
506
|
+
operation_func: callable,
|
507
|
+
*args,
|
508
|
+
**kwargs,
|
509
|
+
):
|
510
|
+
"""Execute Redis operation safely with error handling and logging.
|
511
|
+
|
512
|
+
Args:
|
513
|
+
operation_name (str): Name of the operation for logging.
|
514
|
+
operation_func (callable): Function to execute.
|
515
|
+
*args: Arguments for the function.
|
516
|
+
**kwargs: Keyword arguments for the function.
|
517
|
+
|
518
|
+
Returns:
|
519
|
+
Any: Result of the operation.
|
520
|
+
|
521
|
+
Raises:
|
522
|
+
RedisOperationError: If the operation fails.
|
523
|
+
"""
|
524
|
+
try:
|
525
|
+
return operation_func(*args, **kwargs)
|
526
|
+
except RedisError as e:
|
527
|
+
error_msg = (
|
528
|
+
f"Redis operation '{operation_name}' failed: {str(e)}"
|
529
|
+
)
|
530
|
+
logger.error(error_msg)
|
531
|
+
raise RedisOperationError(error_msg)
|
532
|
+
except Exception as e:
|
533
|
+
error_msg = f"Unexpected error during Redis operation '{operation_name}': {str(e)}"
|
534
|
+
logger.error(error_msg)
|
535
|
+
raise
|
536
|
+
|
537
|
+
def _generate_cache_key(
|
538
|
+
self, content: Union[str, dict, list]
|
539
|
+
) -> str:
|
540
|
+
"""Generate a cache key for the given content.
|
541
|
+
|
542
|
+
Args:
|
543
|
+
content (Union[str, dict, list]): The content to generate a cache key for.
|
544
|
+
|
545
|
+
Returns:
|
546
|
+
str: The cache key.
|
547
|
+
"""
|
548
|
+
try:
|
549
|
+
if isinstance(content, (dict, list)):
|
550
|
+
content = json.dumps(content, sort_keys=True)
|
551
|
+
return hashlib.md5(str(content).encode()).hexdigest()
|
552
|
+
except Exception as e:
|
553
|
+
logger.error(f"Failed to generate cache key: {str(e)}")
|
554
|
+
return hashlib.md5(
|
555
|
+
str(datetime.datetime.now()).encode()
|
556
|
+
).hexdigest()
|
557
|
+
|
558
|
+
def _get_cached_tokens(
|
559
|
+
self, content: Union[str, dict, list]
|
560
|
+
) -> Optional[int]:
|
561
|
+
"""Get the number of cached tokens for the given content.
|
562
|
+
|
563
|
+
Args:
|
564
|
+
content (Union[str, dict, list]): The content to check.
|
565
|
+
|
566
|
+
Returns:
|
567
|
+
Optional[int]: The number of cached tokens, or None if not cached.
|
568
|
+
"""
|
569
|
+
if not self.cache_enabled:
|
570
|
+
return None
|
571
|
+
|
572
|
+
with self.cache_lock:
|
573
|
+
try:
|
574
|
+
cache_key = self._generate_cache_key(content)
|
575
|
+
cached_value = self._safe_redis_operation(
|
576
|
+
"get_cached_tokens",
|
577
|
+
self.redis_client.hget,
|
578
|
+
f"{self.conversation_id}:cache",
|
579
|
+
cache_key,
|
580
|
+
)
|
581
|
+
if cached_value:
|
582
|
+
self.cache_stats["hits"] += 1
|
583
|
+
return int(cached_value)
|
584
|
+
self.cache_stats["misses"] += 1
|
585
|
+
return None
|
586
|
+
except Exception as e:
|
587
|
+
logger.warning(
|
588
|
+
f"Failed to get cached tokens: {str(e)}"
|
589
|
+
)
|
590
|
+
return None
|
591
|
+
|
592
|
+
def _update_cache_stats(
|
593
|
+
self, content: Union[str, dict, list], token_count: int
|
594
|
+
):
|
595
|
+
"""Update cache statistics for the given content.
|
596
|
+
|
597
|
+
Args:
|
598
|
+
content (Union[str, dict, list]): The content to update stats for.
|
599
|
+
token_count (int): The number of tokens in the content.
|
600
|
+
"""
|
601
|
+
if not self.cache_enabled:
|
602
|
+
return
|
603
|
+
|
604
|
+
with self.cache_lock:
|
605
|
+
try:
|
606
|
+
cache_key = self._generate_cache_key(content)
|
607
|
+
self._safe_redis_operation(
|
608
|
+
"update_cache",
|
609
|
+
self.redis_client.hset,
|
610
|
+
f"{self.conversation_id}:cache",
|
611
|
+
cache_key,
|
612
|
+
token_count,
|
613
|
+
)
|
614
|
+
self.cache_stats["cached_tokens"] += token_count
|
615
|
+
self.cache_stats["total_tokens"] += token_count
|
616
|
+
except Exception as e:
|
617
|
+
logger.warning(
|
618
|
+
f"Failed to update cache stats: {str(e)}"
|
619
|
+
)
|
620
|
+
|
621
|
+
def add(
    self,
    role: str,
    content: Union[str, dict, list],
    *args,
    **kwargs,
):
    """Append one message to the conversation stored in Redis.

    Storage layout (keys prefixed with ``conversation_id``): the counter
    ``:message_counter`` is INCRed to obtain the new message id, the
    message fields are written to the hash ``:message:<id>``, and ``<id>``
    is pushed onto the list ``:message_ids``.

    dict/list content is stored as JSON text. Other content is converted
    with ``str()``, and is prefixed with the current time when
    ``time_enabled`` is set. A token count found in the cache is stored
    with the message (``cached`` = ``"true"``). Otherwise, if
    ``token_count`` is True, counting starts in the background.

    Args:
        role (str): Speaker of the message, e.g. 'User' or 'System'.
        content (Union[str, dict, list]): Message payload.

    Raises:
        RedisOperationError: If any step fails; every exception raised
            here is logged and re-raised as this type.
    """
    try:
        record = {
            "role": role,
            "timestamp": datetime.datetime.now().isoformat(),
        }

        if isinstance(content, (dict, list)):
            record["content"] = json.dumps(content)
        elif not self.time_enabled:
            record["content"] = str(content)
        else:
            stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            record["content"] = f"Time: {stamp} \n {content}"

        # Reuse a cached token count when one exists for identical content.
        known_tokens = self._get_cached_tokens(content)
        hit = known_tokens is not None
        if hit:
            record["token_count"] = known_tokens
        record["cached"] = "true" if hit else "false"

        prefix = self.conversation_id
        new_id = self._safe_redis_operation(
            "increment_counter",
            self.redis_client.incr,
            f"{prefix}:message_counter",
        )
        self._safe_redis_operation(
            "store_message",
            self.redis_client.hset,
            f"{prefix}:message:{new_id}",
            mapping=record,
        )
        self._safe_redis_operation(
            "append_message_id",
            self.redis_client.rpush,
            f"{prefix}:message_ids",
            new_id,
        )

        if self.token_count is True and not hit:
            self._count_tokens(content, record, new_id)

        logger.debug(
            f"Added message with ID {new_id} to conversation {self.conversation_id}"
        )
    except Exception as err:
        error_msg = f"Failed to add message: {str(err)}"
        logger.error(error_msg)
        raise RedisOperationError(error_msg)
|
694
|
+
|
695
|
+
def _count_tokens(
    self, content: str, message: dict, message_id: int
):
    """Count a message's tokens asynchronously.

    A daemon thread is started and this method returns at once. The thread:

    1. counts tokens in ``any_to_str(content)``;
    2. stores the integer count in ``message["token_count"]`` and in the
       ``token_count`` field of the hash
       ``<conversation_id>:message:<message_id>``;
    3. passes the count to ``_update_cache_stats``;
    4. calls ``save_as_json(save_filepath)`` when ``autosave`` and
       ``save_filepath`` are both set.

    Any exception inside the thread is logged, not raised.

    Args:
        content (str): Content whose tokens are counted.
        message (dict): In-memory message dict that receives the count.
        message_id (int): Id of the stored message in Redis.
    """

    def _worker():
        try:
            counted = count_tokens(any_to_str(content))
            as_int = int(counted)
            message["token_count"] = as_int

            self._safe_redis_operation(
                "update_token_count",
                self.redis_client.hset,
                f"{self.conversation_id}:message:{message_id}",
                "token_count",
                as_int,
            )
            self._update_cache_stats(content, as_int)

            if self.autosave and self.save_filepath:
                self.save_as_json(self.save_filepath)

            logger.debug(
                f"Updated token count for message {message_id}: {counted} tokens"
            )
        except Exception as err:
            logger.error(
                f"Failed to count tokens for message {message_id}: {str(err)}"
            )

    threading.Thread(target=_worker, daemon=True).start()
|
737
|
+
|
738
|
+
def delete(self, index: int):
    """Remove the message at position ``index`` from the conversation.

    Both the message hash and its entry in
    ``<conversation_id>:message_ids`` are deleted.

    Args:
        index (int): Zero-based position in the message list.

    Raises:
        RedisOperationError: If any Redis operation fails. An out-of-range
            index is reported this way too: the ValueError raised for it is
            caught by this method's handler and re-wrapped.
    """
    ids_key = f"{self.conversation_id}:message_ids"
    try:
        all_ids = self._safe_redis_operation(
            "get_message_ids",
            self.redis_client.lrange,
            ids_key,
            0,
            -1,
        )

        if not 0 <= index < len(all_ids):
            raise ValueError(f"Invalid message index: {index}")

        target = all_ids[index]
        self._safe_redis_operation(
            "delete_message",
            self.redis_client.delete,
            f"{self.conversation_id}:message:{target}",
        )
        self._safe_redis_operation(
            "remove_message_id",
            self.redis_client.lrem,
            ids_key,
            1,
            target,
        )
        logger.info(
            f"Deleted message {target} from conversation {self.conversation_id}"
        )
    except Exception as err:
        error_msg = (
            f"Failed to delete message at index {index}: {str(err)}"
        )
        logger.error(error_msg)
        raise RedisOperationError(error_msg)
|
782
|
+
|
783
|
+
def update(
    self, index: int, role: str, content: Union[str, dict]
):
    """Replace the role and content of the message at ``index``.

    The stored hash receives the new ``role``, the new ``content`` (JSON
    text for dict/list, ``str()`` otherwise), a fresh ``timestamp`` and
    ``cached = "false"``. Any ``token_count`` recorded for the previous
    content is removed in the same MULTI/EXEC transaction, so the message
    never reports a count belonging to its old content. When
    ``token_count`` is enabled, a new count is computed in the background.

    Args:
        index (int): Zero-based position of the message to update.
        role (str): New speaker role.
        content (Union[str, dict]): New content (lists are accepted too).

    Raises:
        RedisOperationError: If the operation fails. An out-of-range index
            raises ValueError internally, which this method's handler
            re-wraps as RedisOperationError.
    """
    try:
        message_ids = self._safe_redis_operation(
            "get_message_ids",
            self.redis_client.lrange,
            f"{self.conversation_id}:message_ids",
            0,
            -1,
        )

        if not message_ids or not (0 <= index < len(message_ids)):
            raise ValueError(f"Invalid message index: {index}")

        message_id = message_ids[index]
        message_key = f"{self.conversation_id}:message:{message_id}"
        message = {
            "role": role,
            "content": (
                json.dumps(content)
                if isinstance(content, (dict, list))
                else str(content)
            ),
            "timestamp": datetime.datetime.now().isoformat(),
            "cached": "false",
        }

        # Drop the stale token count and write the new fields atomically
        # (redis-py pipelines are transactional by default).
        pipe = self.redis_client.pipeline()
        pipe.hdel(message_key, "token_count")
        pipe.hset(message_key, mapping=message)
        self._safe_redis_operation("update_message", pipe.execute)

        # Update token count if needed
        if self.token_count:
            self._count_tokens(content, message, message_id)

        logger.debug(
            f"Updated message {message_id} in conversation {self.conversation_id}"
        )
    except Exception as e:
        error_msg = (
            f"Failed to update message at index {index}: {str(e)}"
        )
        logger.error(error_msg)
        raise RedisOperationError(error_msg)
|
842
|
+
|
843
|
+
def query(self, index: int) -> dict:
    """Return a single message by its position in the conversation.

    Args:
        index (int): Zero-based position of the message.

    Returns:
        dict: The stored message hash (role, content, and any other stored
        fields). Content stored as a JSON object or array is decoded back to
        a dict or list. Returns an empty dict when the index is out of range.
    """
    message_ids = self.redis_client.lrange(
        f"{self.conversation_id}:message_ids", 0, -1
    )
    if not 0 <= index < len(message_ids):
        return {}

    message = self.redis_client.hgetall(
        f"{self.conversation_id}:message:{message_ids[index]}"
    )
    content = message.get("content")
    # update() JSON-encodes both dicts and lists, so decode both shapes.
    if isinstance(content, str) and content.startswith(("{", "[")):
        try:
            message["content"] = json.loads(content)
        except json.JSONDecodeError:
            pass  # Plain text that merely starts with a bracket.
    return message
|
871
|
+
|
872
|
+
def search(self, keyword: str) -> List[dict]:
    """Find messages whose stored content contains ``keyword``.

    Matching is a case-sensitive substring test against the raw stored
    content. For JSON-encoded messages, the keyword is therefore matched
    against the JSON text.

    Args:
        keyword (str): Substring to look for.

    Returns:
        List[dict]: Matching message hashes in conversation order. Content
        stored as a JSON object or array is decoded to a dict or list.
    """
    results = []
    message_ids = self.redis_client.lrange(
        f"{self.conversation_id}:message_ids", 0, -1
    )

    for message_id in message_ids:
        message = self.redis_client.hgetall(
            f"{self.conversation_id}:message:{message_id}"
        )
        content = message.get("content", "")
        if keyword not in content:
            continue
        # update() JSON-encodes both dicts and lists, so decode both shapes.
        if content.startswith(("{", "[")):
            try:
                message["content"] = json.loads(content)
            except json.JSONDecodeError:
                pass
        results.append(message)

    return results
|
901
|
+
|
902
|
+
def display_conversation(self, detailed: bool = False):
    """Print every message in the conversation as a formatted panel.

    Each panel shows ``role: content``. Content stored as a JSON object or
    array is decoded before display.

    Args:
        detailed (bool): When True, also show each message's stored
            ``timestamp`` and ``token_count`` fields, if present. Defaults
            to False (role and content only).
    """
    message_ids = self.redis_client.lrange(
        f"{self.conversation_id}:message_ids", 0, -1
    )
    for message_id in message_ids:
        message = self.redis_client.hgetall(
            f"{self.conversation_id}:message:{message_id}"
        )
        if not message:
            # hgetall returns an empty dict for a missing hash; skip it
            # rather than failing with KeyError on the role/content lookups.
            continue

        content = message.get("content", "")
        # update() JSON-encodes both dicts and lists, so decode both shapes.
        if content.startswith(("{", "[")):
            try:
                content = json.loads(content)
            except json.JSONDecodeError:
                pass

        panel_text = f"{message.get('role', '')}: {content}\n\n"
        if detailed:
            extras = [
                f"{label}: {message[field]}"
                for field, label in (
                    ("timestamp", "timestamp"),
                    ("token_count", "tokens"),
                )
                if field in message
            ]
            if extras:
                panel_text += "\n".join(extras) + "\n\n"
        formatter.print_panel(panel_text)
|
925
|
+
|
926
|
+
def export_conversation(self, filename: str):
    """Write the conversation to a plain-text file.

    The file gets one ``role: content`` line per message, in conversation
    order. The file is created or truncated before Redis is read.

    Args:
        filename (str): Path of the file to write.
    """
    with open(filename, "w") as out_file:
        prefix = self.conversation_id
        ids = self.redis_client.lrange(f"{prefix}:message_ids", 0, -1)
        records = (
            self.redis_client.hgetall(f"{prefix}:message:{mid}")
            for mid in ids
        )
        out_file.writelines(
            f"{rec['role']}: {rec['content']}\n" for rec in records
        )
|
941
|
+
|
942
|
+
def import_conversation(self, filename: str):
    """Append messages read from a ``role: content`` text file.

    This reads the format written by ``export_conversation``. A line of the
    form ``role: content`` (split on the first ``": "``) starts a new
    message. A line without that separator is treated as a continuation of
    the previous message's content, so multi-line content round-trips
    instead of aborting the import. Blank lines before the first message are
    ignored. Each message's content is stripped of surrounding whitespace
    and passed to ``self.add``.

    Args:
        filename (str): Path of the file to read.

    Raises:
        ValueError: If a non-blank line without a ``": "`` separator appears
            before any message has started.
    """
    role = None
    content_lines: List[str] = []

    def flush():
        # Emit the message accumulated so far (if any).
        if role is not None:
            self.add(role, "\n".join(content_lines).strip())

    with open(filename) as f:
        for raw_line in f:
            line = raw_line.rstrip("\n")
            head, sep, tail = line.partition(": ")
            if sep:
                flush()
                role, content_lines = head, [tail]
            elif role is not None:
                content_lines.append(line)
            elif line.strip():
                raise ValueError(
                    f"Malformed line in {filename!r}: {line!r}"
                )
    flush()
|
952
|
+
|
953
|
+
def count_messages_by_role(self) -> Dict[str, int]:
    """Tally messages for each of the four standard roles.

    Roles are compared case-insensitively. Only ``system``, ``user``,
    ``assistant`` and ``function`` are counted; any other role is ignored.

    Returns:
        Dict[str, int]: Count per role, always containing all four keys.
    """
    tally = dict.fromkeys(("system", "user", "assistant", "function"), 0)
    prefix = self.conversation_id
    for mid in self.redis_client.lrange(f"{prefix}:message_ids", 0, -1):
        record = self.redis_client.hgetall(f"{prefix}:message:{mid}")
        role_name = record["role"].lower()
        if role_name in tally:
            tally[role_name] += 1
    return tally
|
976
|
+
|
977
|
+
def return_history_as_string(self) -> str:
    """Render the whole conversation as one string.

    Each message becomes ``role: content`` followed by a blank line.

    Returns:
        str: The concatenated history (empty string if there are no messages).
    """
    prefix = self.conversation_id
    ids = self.redis_client.lrange(f"{prefix}:message_ids", 0, -1)
    rows = [
        self.redis_client.hgetall(f"{prefix}:message:{mid}") for mid in ids
    ]
    return "".join(f"{row['role']}: {row['content']}\n\n" for row in rows)
|
995
|
+
|
996
|
+
def get_str(self) -> str:
    """Render the conversation with per-message annotations.

    Each message is rendered as ``role: content``, followed by
    `` (tokens: N)`` when a ``token_count`` field is stored and by
    `` [cached]`` when its ``cached`` field equals ``"true"``. Messages are
    separated by newlines.

    Returns:
        str: The rendered history.
    """

    def render(record):
        text = f"{record['role']}: {record['content']}"
        if "token_count" in record:
            text = f"{text} (tokens: {record['token_count']})"
        if record.get("cached", "false") == "true":
            text = f"{text} [cached]"
        return text

    prefix = self.conversation_id
    ids = self.redis_client.lrange(f"{prefix}:message_ids", 0, -1)
    return "\n".join(
        render(self.redis_client.hgetall(f"{prefix}:message:{mid}"))
        for mid in ids
    )
|
1017
|
+
|
1018
|
+
def save_as_json(self, filename: str = None):
    """Write the conversation to a file as a JSON array of message hashes.

    Content stored as a JSON object or array is decoded before writing, so
    it appears as nested JSON rather than an escaped string. Nothing is
    written when ``filename`` is falsy (the default).

    Args:
        filename (str, optional): Destination path.
    """
    if not filename:
        return

    data = []
    message_ids = self.redis_client.lrange(
        f"{self.conversation_id}:message_ids", 0, -1
    )
    for message_id in message_ids:
        message = self.redis_client.hgetall(
            f"{self.conversation_id}:message:{message_id}"
        )
        content = message.get("content")
        # update() JSON-encodes both dicts and lists, so decode both shapes.
        if isinstance(content, str) and content.startswith(("{", "[")):
            try:
                message["content"] = json.loads(content)
            except json.JSONDecodeError:
                pass
        data.append(message)

    with open(filename, "w") as f:
        json.dump(data, f, indent=2)
|
1044
|
+
|
1045
|
+
def load_from_json(self, filename: str):
    """Replace the conversation with messages read from a JSON file.

    The file must hold a list of objects, each with ``role`` and ``content``
    keys. The file is parsed before the current conversation is cleared, so
    a file that fails to parse leaves the existing history untouched.

    Args:
        filename (str): Path of the JSON file to read.
    """
    with open(filename) as source:
        records = json.load(source)
    self.clear()
    for record in records:
        self.add(record["role"], record["content"])
|
1056
|
+
|
1057
|
+
def clear(self):
    """Delete every Redis key belonging to this conversation.

    First deletes each message hash, then the message-ID list, the cache
    key and the message-counter key.
    """
    prefix = self.conversation_id
    client = self.redis_client
    for mid in client.lrange(f"{prefix}:message_ids", 0, -1):
        client.delete(f"{prefix}:message:{mid}")
    # Conversation-level bookkeeping keys, deleted one at a time.
    for suffix in ("message_ids", "cache", "message_counter"):
        client.delete(f"{prefix}:{suffix}")
|
1082
|
+
|
1083
|
+
def to_dict(self) -> List[Dict]:
    """Return the conversation as a list of message dictionaries.

    Each entry is the full stored message hash. Content stored as a JSON
    object or array is decoded back to a dict or list.

    Returns:
        List[Dict]: Messages in conversation order.
    """
    data = []
    message_ids = self.redis_client.lrange(
        f"{self.conversation_id}:message_ids", 0, -1
    )
    for message_id in message_ids:
        message = self.redis_client.hgetall(
            f"{self.conversation_id}:message:{message_id}"
        )
        content = message.get("content")
        # update() JSON-encodes both dicts and lists, so decode both shapes.
        if isinstance(content, str) and content.startswith(("{", "[")):
            try:
                message["content"] = json.loads(content)
            except json.JSONDecodeError:
                pass
        data.append(message)
    return data
|
1106
|
+
|
1107
|
+
def to_json(self) -> str:
    """Serialize the conversation to an indented JSON string.

    Returns:
        str: The result of ``self.to_dict()`` dumped with a 2-space indent.
    """
    payload = self.to_dict()
    return json.dumps(payload, indent=2)
|
1114
|
+
|
1115
|
+
def to_yaml(self) -> str:
    """Serialize the conversation to a YAML string.

    Returns:
        str: The result of ``self.to_dict()`` rendered with ``yaml.dump``.
    """
    payload = self.to_dict()
    return yaml.dump(payload)
|
1122
|
+
|
1123
|
+
def get_last_message_as_string(self) -> str:
    """Return the most recent message as ``role: content``.

    Returns:
        str: The formatted last message, or an empty string when the
        conversation has no messages.
    """
    prefix = self.conversation_id
    tail = self.redis_client.lrange(f"{prefix}:message_ids", -1, -1)
    if not tail:
        return ""
    last = self.redis_client.hgetall(f"{prefix}:message:{tail[0]}")
    return "{}: {}".format(last["role"], last["content"])
|
1138
|
+
|
1139
|
+
def return_messages_as_list(self) -> List[str]:
    """Return every message formatted as ``role: content``.

    Returns:
        List[str]: One formatted string per message, in conversation order.
    """
    prefix = self.conversation_id
    ids = self.redis_client.lrange(f"{prefix}:message_ids", 0, -1)
    entries = (
        self.redis_client.hgetall(f"{prefix}:message:{mid}") for mid in ids
    )
    return ["{}: {}".format(e["role"], e["content"]) for e in entries]
|
1157
|
+
|
1158
|
+
def return_messages_as_dictionary(self) -> List[Dict]:
    """Return the conversation as a list of ``{"role", "content"}`` dicts.

    Only the role and content fields are included. Content stored as a
    JSON object or array is decoded back to a dict or list.

    Returns:
        List[Dict]: One dictionary per message, in conversation order.
    """
    messages = []
    message_ids = self.redis_client.lrange(
        f"{self.conversation_id}:message_ids", 0, -1
    )
    for message_id in message_ids:
        message = self.redis_client.hgetall(
            f"{self.conversation_id}:message:{message_id}"
        )
        content = message["content"]
        # update() JSON-encodes both dicts and lists, so decode both shapes.
        if content.startswith(("{", "[")):
            try:
                content = json.loads(content)
            except json.JSONDecodeError:
                pass
        messages.append(
            {
                "role": message["role"],
                "content": content,
            }
        )
    return messages
|
1186
|
+
|
1187
|
+
def get_cache_stats(self) -> Dict[str, Union[int, float]]:
    """Return a snapshot of cache hit/miss counters and token totals.

    The snapshot is taken while holding ``self.cache_lock``. ``hit_rate``
    is ``hits / (hits + misses)``, or 0 when no lookups have happened.

    Returns:
        Dict[str, Union[int, float]]: Keys ``hits``, ``misses``,
        ``cached_tokens``, ``total_tokens`` and ``hit_rate``.
    """
    with self.cache_lock:
        stats = self.cache_stats
        hits, misses = stats["hits"], stats["misses"]
        lookups = hits + misses
        return {
            "hits": hits,
            "misses": misses,
            "cached_tokens": stats["cached_tokens"],
            "total_tokens": stats["total_tokens"],
            "hit_rate": hits / lookups if lookups > 0 else 0,
        }
|
1207
|
+
|
1208
|
+
def truncate_memory_with_tokenizer(self):
    """Trim the conversation so its total token count fits ``self.context_length``.

    Messages are kept in order, starting from the oldest, until the next
    message would exceed the budget. That message and every later one are
    then deleted. Stopping at the first overflow keeps the retained history
    contiguous: the previous implementation kept scanning, so a later,
    smaller message could survive after an earlier one was dropped, leaving
    a hole mid-conversation.

    A message's token count is its stored ``token_count`` field when that
    is present and non-zero; otherwise ``count_tokens`` is called on its
    content. IDs whose message hash no longer exists are dropped from the
    ID list. Does nothing when ``self.tokenizer`` is falsy.
    """
    if not self.tokenizer:
        return

    ids_key = f"{self.conversation_id}:message_ids"
    message_ids = self.redis_client.lrange(ids_key, 0, -1)
    keep_message_ids = []
    total_tokens = 0
    budget_exhausted = False

    for message_id in message_ids:
        message_key = f"{self.conversation_id}:message:{message_id}"
        if not budget_exhausted:
            message = self.redis_client.hgetall(message_key)
            if not message:
                # Dangling ID (hash already gone): drop it from the list.
                continue
            tokens = int(message.get("token_count", 0)) or count_tokens(
                message.get("content", "")
            )
            if total_tokens + tokens <= self.context_length:
                total_tokens += tokens
                keep_message_ids.append(message_id)
                continue
            # First message that does not fit: it and everything after it go.
            budget_exhausted = True
        self.redis_client.delete(message_key)

    # Rebuild the ID list from the retained IDs, preserving their order.
    self.redis_client.delete(ids_key)
    if keep_message_ids:
        self.redis_client.rpush(ids_key, *keep_message_ids)
|
1245
|
+
|
1246
|
+
def get_final_message(self) -> str:
    """Return the final message of the conversation as ``role: content``.

    Returns:
        str: The formatted final message, or an empty string when the
        conversation is empty.
    """
    last_ids = self.redis_client.lrange(
        f"{self.conversation_id}:message_ids", -1, -1
    )
    if not last_ids:
        return ""
    final = self.redis_client.hgetall(
        f"{self.conversation_id}:message:{last_ids[0]}"
    )
    return "{}: {}".format(final["role"], final["content"])
|
1261
|
+
|
1262
|
+
def get_final_message_content(self) -> str:
    """Return only the stored content of the final message.

    Returns:
        str: The raw content field of the last message, or an empty string
        when the conversation is empty.
    """
    prefix = self.conversation_id
    last_ids = self.redis_client.lrange(f"{prefix}:message_ids", -1, -1)
    if not last_ids:
        return ""
    final = self.redis_client.hgetall(f"{prefix}:message:{last_ids[0]}")
    return final["content"]
|
1277
|
+
|
1278
|
+
def __del__(self):
    """Release resources when the object is garbage-collected.

    Closes the Redis client and stops the embedded server, if either is
    set. Any exception during cleanup is caught and logged as a warning.
    """
    try:
        client = getattr(self, "redis_client", None)
        if client:
            client.close()
            logger.debug(
                f"Closed Redis connection for conversation {self.conversation_id}"
            )

        server = getattr(self, "embedded_server", None)
        if server:
            server.stop()
    except Exception as e:
        logger.warning(f"Error during cleanup: {str(e)}")
|
1294
|
+
|
1295
|
+
def _get_conversation_id_by_name(
|
1296
|
+
self, name: str
|
1297
|
+
) -> Optional[str]:
|
1298
|
+
"""Get conversation ID for a given name.
|
1299
|
+
|
1300
|
+
Args:
|
1301
|
+
name (str): The conversation name to look up.
|
1302
|
+
|
1303
|
+
Returns:
|
1304
|
+
Optional[str]: The conversation ID if found, None otherwise.
|
1305
|
+
"""
|
1306
|
+
try:
|
1307
|
+
return self.redis_client.get(f"conversation_name:{name}")
|
1308
|
+
except Exception as e:
|
1309
|
+
logger.warning(
|
1310
|
+
f"Error looking up conversation name: {str(e)}"
|
1311
|
+
)
|
1312
|
+
return None
|
1313
|
+
|
1314
|
+
def _save_conversation_name(self, name: str):
|
1315
|
+
"""Save the mapping between conversation name and ID.
|
1316
|
+
|
1317
|
+
Args:
|
1318
|
+
name (str): The name to save.
|
1319
|
+
"""
|
1320
|
+
try:
|
1321
|
+
# Save name -> ID mapping
|
1322
|
+
self.redis_client.set(
|
1323
|
+
f"conversation_name:{name}", self.conversation_id
|
1324
|
+
)
|
1325
|
+
# Save ID -> name mapping
|
1326
|
+
self.redis_client.set(
|
1327
|
+
f"conversation_id:{self.conversation_id}:name", name
|
1328
|
+
)
|
1329
|
+
except Exception as e:
|
1330
|
+
logger.warning(
|
1331
|
+
f"Error saving conversation name: {str(e)}"
|
1332
|
+
)
|
1333
|
+
|
1334
|
+
def get_name(self) -> Optional[str]:
    """Return the conversation's friendly name.

    A truthy ``self.name`` takes precedence. Otherwise the name is read
    from ``conversation_id:<id>:name`` in Redis.

    Returns:
        Optional[str]: The name, or None if the Redis lookup fails.
    """
    local_name = getattr(self, "name", None)
    if local_name:
        return local_name
    try:
        return self.redis_client.get(
            f"conversation_id:{self.conversation_id}:name"
        )
    except Exception:
        return None
|
1348
|
+
|
1349
|
+
def set_name(self, name: str):
    """Rename the conversation.

    Deletes the ``conversation_name:<old>`` key for the previous name (if
    any), sets ``self.name``, and saves the new name mappings.

    Args:
        name (str): The new friendly name.
    """
    stale_name = self.get_name()
    if stale_name:
        # Drop the old name -> ID entry so the old name no longer resolves.
        stale_key = f"conversation_name:{stale_name}"
        self.redis_client.delete(stale_key)

    self.name = name
    self._save_conversation_name(name)
    logger.info(f"Set conversation name to: {name}")
|