swarms 7.7.7__py3-none-any.whl → 7.7.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,16 +1,21 @@
1
- import duckdb
2
- import json
3
1
  import datetime
4
- from typing import List, Optional, Union, Dict
5
- from pathlib import Path
6
- import threading
7
- from contextlib import contextmanager
2
+ import json
8
3
  import logging
9
- from dataclasses import dataclass
10
- from enum import Enum
4
+ import threading
11
5
  import uuid
6
+ from contextlib import contextmanager
7
+ from pathlib import Path
8
+ from typing import Any, Callable, Dict, List, Optional, Union
9
+
10
+ import duckdb
12
11
  import yaml
13
12
 
13
+ from swarms.communication.base_communication import (
14
+ BaseCommunication,
15
+ Message,
16
+ MessageType,
17
+ )
18
+
14
19
  try:
15
20
  from loguru import logger
16
21
 
@@ -19,31 +24,6 @@ except ImportError:
19
24
  LOGURU_AVAILABLE = False
20
25
 
21
26
 
22
- class MessageType(Enum):
23
- """Enum for different types of messages in the conversation."""
24
-
25
- SYSTEM = "system"
26
- USER = "user"
27
- ASSISTANT = "assistant"
28
- FUNCTION = "function"
29
- TOOL = "tool"
30
-
31
-
32
- @dataclass
33
- class Message:
34
- """Data class representing a message in the conversation."""
35
-
36
- role: str
37
- content: Union[str, dict, list]
38
- timestamp: Optional[str] = None
39
- message_type: Optional[MessageType] = None
40
- metadata: Optional[Dict] = None
41
- token_count: Optional[int] = None
42
-
43
- class Config:
44
- arbitrary_types_allowed = True
45
-
46
-
47
27
  class DateTimeEncoder(json.JSONEncoder):
48
28
  """Custom JSON encoder for handling datetime objects."""
49
29
 
@@ -53,7 +33,7 @@ class DateTimeEncoder(json.JSONEncoder):
53
33
  return super().default(obj)
54
34
 
55
35
 
56
- class DuckDBConversation:
36
+ class DuckDBConversation(BaseCommunication):
57
37
  """
58
38
  A production-grade DuckDB wrapper class for managing conversation history.
59
39
  This class provides persistent storage for conversations with various features
@@ -72,15 +52,55 @@ class DuckDBConversation:
72
52
 
73
53
  def __init__(
74
54
  self,
75
- db_path: Union[str, Path] = "conversations.duckdb",
55
+ system_prompt: Optional[str] = None,
56
+ time_enabled: bool = False,
57
+ autosave: bool = False,
58
+ save_filepath: str = None,
59
+ tokenizer: Any = None,
60
+ context_length: int = 8192,
61
+ rules: str = None,
62
+ custom_rules_prompt: str = None,
63
+ user: str = "User:",
64
+ auto_save: bool = True,
65
+ save_as_yaml: bool = True,
66
+ save_as_json_bool: bool = False,
67
+ token_count: bool = True,
68
+ cache_enabled: bool = True,
69
+ db_path: Union[str, Path] = None,
76
70
  table_name: str = "conversations",
77
71
  enable_timestamps: bool = True,
78
72
  enable_logging: bool = True,
79
73
  use_loguru: bool = True,
80
74
  max_retries: int = 3,
81
75
  connection_timeout: float = 5.0,
76
+ *args,
77
+ **kwargs,
82
78
  ):
79
+ super().__init__(
80
+ system_prompt=system_prompt,
81
+ time_enabled=time_enabled,
82
+ autosave=autosave,
83
+ save_filepath=save_filepath,
84
+ tokenizer=tokenizer,
85
+ context_length=context_length,
86
+ rules=rules,
87
+ custom_rules_prompt=custom_rules_prompt,
88
+ user=user,
89
+ auto_save=auto_save,
90
+ save_as_yaml=save_as_yaml,
91
+ save_as_json_bool=save_as_json_bool,
92
+ token_count=token_count,
93
+ cache_enabled=cache_enabled,
94
+ )
95
+
96
+ # Calculate default db_path if not provided
97
+ if db_path is None:
98
+ db_path = self.get_default_db_path("conversations.duckdb")
83
99
  self.db_path = Path(db_path)
100
+
101
+ # Ensure parent directory exists
102
+ self.db_path.parent.mkdir(parents=True, exist_ok=True)
103
+
84
104
  self.table_name = table_name
85
105
  self.enable_timestamps = enable_timestamps
86
106
  self.enable_logging = enable_logging
@@ -89,6 +109,7 @@ class DuckDBConversation:
89
109
  self.connection_timeout = connection_timeout
90
110
  self.current_conversation_id = None
91
111
  self._lock = threading.Lock()
112
+ self.tokenizer = tokenizer
92
113
 
93
114
  # Setup logging
94
115
  if self.enable_logging:
@@ -809,12 +830,7 @@ class DuckDBConversation:
809
830
  }
810
831
 
811
832
  def get_conversation_as_dict(self) -> Dict:
812
- """
813
- Get the entire conversation as a dictionary with messages and metadata.
814
-
815
- Returns:
816
- Dict: Dictionary containing conversation ID, messages, and metadata
817
- """
833
+ """Get the entire conversation as a dictionary with messages and metadata."""
818
834
  messages = self.get_messages()
819
835
  stats = self.get_statistics()
820
836
 
@@ -832,12 +848,7 @@ class DuckDBConversation:
832
848
  }
833
849
 
834
850
  def get_conversation_by_role_dict(self) -> Dict[str, List[Dict]]:
835
- """
836
- Get the conversation organized by roles.
837
-
838
- Returns:
839
- Dict[str, List[Dict]]: Dictionary with roles as keys and lists of messages as values
840
- """
851
+ """Get the conversation organized by roles."""
841
852
  with self._get_connection() as conn:
842
853
  result = conn.execute(
843
854
  f"""
@@ -926,12 +937,7 @@ class DuckDBConversation:
926
937
  return timeline_dict
927
938
 
928
939
  def get_conversation_metadata_dict(self) -> Dict:
929
- """
930
- Get detailed metadata about the conversation.
931
-
932
- Returns:
933
- Dict: Dictionary containing detailed conversation metadata
934
- """
940
+ """Get detailed metadata about the conversation."""
935
941
  with self._get_connection() as conn:
936
942
  # Get basic statistics
937
943
  stats = self.get_statistics()
@@ -975,7 +981,7 @@ class DuckDBConversation:
975
981
  "conversation_id": self.current_conversation_id,
976
982
  "basic_stats": stats,
977
983
  "message_type_distribution": {
978
- row[0]: row[1] for row in type_dist
984
+ row[0]: row[1] for row in type_dist if row[0]
979
985
  },
980
986
  "average_tokens_per_message": (
981
987
  avg_tokens[0] if avg_tokens[0] is not None else 0
@@ -987,15 +993,7 @@ class DuckDBConversation:
987
993
  }
988
994
 
989
995
  def save_as_yaml(self, filename: str) -> bool:
990
- """
991
- Save the current conversation to a YAML file.
992
-
993
- Args:
994
- filename (str): Path to save the YAML file
995
-
996
- Returns:
997
- bool: True if save was successful
998
- """
996
+ """Save the current conversation to a YAML file."""
999
997
  try:
1000
998
  with open(filename, "w") as f:
1001
999
  yaml.dump(self.to_dict(), f)
@@ -1008,15 +1006,7 @@ class DuckDBConversation:
1008
1006
  return False
1009
1007
 
1010
1008
  def load_from_yaml(self, filename: str) -> bool:
1011
- """
1012
- Load a conversation from a YAML file.
1013
-
1014
- Args:
1015
- filename (str): Path to the YAML file
1016
-
1017
- Returns:
1018
- bool: True if load was successful
1019
- """
1009
+ """Load a conversation from a YAML file."""
1020
1010
  try:
1021
1011
  with open(filename, "r") as f:
1022
1012
  messages = yaml.safe_load(f)
@@ -1044,3 +1034,310 @@ class DuckDBConversation:
1044
1034
  f"Failed to load conversation from YAML: {e}"
1045
1035
  )
1046
1036
  return False
1037
+
1038
def delete(self, index: str):
    """Remove a single message from the active conversation.

    Args:
        index (str): Value matched against the ``id`` column of the row to remove.
    """
    sql = f"DELETE FROM {self.table_name} WHERE id = ? AND conversation_id = ?"
    with self._get_connection() as conn:
        params = (index, self.current_conversation_id)
        conn.execute(sql, params)
1045
+
1046
def update(
    self, index: str, role: str, content: Union[str, dict]
):
    """Overwrite the role and content of an existing message.

    Dict and list content is JSON-encoded before being stored; any other
    value is passed to the database unchanged.

    Args:
        index (str): Value matched against the ``id`` column.
        role (str): New role for the message.
        content (Union[str, dict]): New content for the message.
    """
    stored_content = (
        json.dumps(content) if isinstance(content, (dict, list)) else content
    )
    statement = f"""
        UPDATE {self.table_name}
        SET role = ?, content = ?
        WHERE id = ? AND conversation_id = ?
        """
    with self._get_connection() as conn:
        conn.execute(
            statement,
            (role, stored_content, index, self.current_conversation_id),
        )
1062
+
1063
def query(self, index: str) -> Dict:
    """Fetch one message of the active conversation by its row id.

    Columns are read by position: 1=role, 2=content, 3=timestamp,
    4=message_type, 5=metadata, 6=token_count. The positions follow the
    ``SELECT *`` column order of the table, which is defined outside this
    view.

    Args:
        index (str): Value matched against the ``id`` column.

    Returns:
        Dict: The message fields, or ``{}`` when no row matches. Content is
        JSON-decoded when possible. Metadata is JSON-decoded when truthy and
        is ``None`` otherwise.
    """
    with self._get_connection() as conn:
        row = conn.execute(
            f"""
            SELECT * FROM {self.table_name}
            WHERE id = ? AND conversation_id = ?
            """,
            (index, self.current_conversation_id),
        ).fetchone()

    if not row:
        return {}

    raw_content = row[2]
    try:
        decoded_content = json.loads(raw_content)
    except json.JSONDecodeError:
        decoded_content = raw_content

    message = {"role": row[1], "content": decoded_content}
    message["timestamp"] = row[3]
    message["message_type"] = row[4]
    message["metadata"] = json.loads(row[5]) if row[5] else None
    message["token_count"] = row[6]
    return message
1093
+
1094
def search(self, keyword: str) -> List[Dict]:
    """Find messages matching *keyword*.

    This is an alias that forwards to ``search_messages``.

    Args:
        keyword (str): Text to search for.

    Returns:
        List[Dict]: Whatever ``search_messages`` returns.
    """
    matches = self.search_messages(keyword)
    return matches
1097
+
1098
def display_conversation(self, detailed: bool = False):
    """Print the conversation as rendered by ``get_str``.

    Args:
        detailed (bool): Accepted for interface compatibility; not used here.
    """
    rendered = self.get_str()
    print(rendered)
1101
+
1102
def export_conversation(self, filename: str):
    """Write the conversation to *filename*.

    Args:
        filename (str): Destination path, forwarded to ``save_as_json``.
    """
    writer = self.save_as_json
    writer(filename)
1105
+
1106
def import_conversation(self, filename: str):
    """Load a conversation from *filename*.

    Args:
        filename (str): Source path, forwarded to ``load_from_json``.
    """
    loader = self.load_from_json
    loader(filename)
1109
+
1110
def return_history_as_string(self) -> str:
    """Return the conversation history rendered by ``get_str``."""
    history_text = self.get_str()
    return history_text
1113
+
1114
def clear(self):
    """Delete every stored message that belongs to the active conversation."""
    statement = f"DELETE FROM {self.table_name} WHERE conversation_id = ?"
    with self._get_connection() as conn:
        conn.execute(statement, (self.current_conversation_id,))
1121
+
1122
def truncate_memory_with_tokenizer(self):
    """Trim the active conversation so its token total fits ``self.context_length``.

    Messages are walked oldest-first (ascending ``id``). The longest prefix
    whose summed token count stays within ``self.context_length`` is kept.
    Every message from the first one that would overflow the budget onwards
    is deleted, including the case where even the first message does not
    fit. A row's stored ``token_count`` is used when it is truthy; otherwise
    the count comes from ``self.tokenizer.count_tokens``.

    Does nothing when ``self.tokenizer`` is not set.
    """
    if not self.tokenizer:
        return

    with self._get_connection() as conn:
        rows = conn.execute(
            f"""
            SELECT id, content, token_count
            FROM {self.table_name}
            WHERE conversation_id = ?
            ORDER BY id ASC
            """,
            (self.current_conversation_id,),
        ).fetchall()

        total_tokens = 0
        cutoff_id = None  # id of the first message that no longer fits
        for row_id, content, stored_count in rows:
            token_count = stored_count or self.tokenizer.count_tokens(
                content
            )
            if total_tokens + token_count > self.context_length:
                cutoff_id = row_id
                break
            total_tokens += token_count

        if cutoff_id is None:
            # Everything fits within the budget; nothing to delete.
            return

        # Rows were read in ascending id order and the kept messages form a
        # prefix, so deleting id >= cutoff removes exactly the overflow. One
        # bound parameter replaces the old string-built NOT IN list, which
        # also skipped the delete entirely when no message fit.
        conn.execute(
            f"""
            DELETE FROM {self.table_name}
            WHERE conversation_id = ?
            AND id >= ?
            """,
            (self.current_conversation_id, cutoff_id),
        )
1161
+
1162
def get_visible_messages(
    self, agent: Callable, turn: int
) -> List[Dict]:
    """
    Collect the messages a given agent may see before a given turn.

    The SQL keeps only rows of the active conversation whose metadata
    ``turn`` (cast to INTEGER) is less than *turn*. A row is then visible when
    its metadata ``visible_to`` is ``"all"`` (the default when absent), or when
    *agent* is truthy and ``agent.agent_name`` is ``in`` that value.

    Args:
        agent (Agent): The agent whose ``agent_name`` is checked; may be falsy.
        turn (int): Exclusive upper bound on the stored turn number.

    Returns:
        List[Dict]: Dicts with ``role``, ``content`` (JSON-decoded when
        possible), ``visible_to`` and ``turn``, in ascending ``id`` order.
    """
    with self._get_connection() as conn:
        rows = conn.execute(
            f"""
            SELECT * FROM {self.table_name}
            WHERE conversation_id = ?
            AND CAST(json_extract(metadata, '$.turn') AS INTEGER) < ?
            ORDER BY id ASC
            """,
            (self.current_conversation_id, turn),
        ).fetchall()

    visible = []
    for row in rows:
        meta = json.loads(row[5]) if row[5] else {}
        audience = meta.get("visible_to", "all")

        # Skip rows that are neither public nor addressed to this agent.
        if audience != "all" and not (
            agent and agent.agent_name in audience
        ):
            continue

        raw_content = row[2]  # content column
        try:
            body = json.loads(raw_content)
        except json.JSONDecodeError:
            body = raw_content

        visible.append(
            {
                "role": row[1],
                "content": body,
                "visible_to": audience,
                "turn": meta.get("turn"),
            }
        )

    return visible
1209
+
1210
def return_messages_as_list(self) -> List[str]:
    """Return the conversation messages as a list of formatted strings.

    String content that starts with ``{`` is JSON-decoded before formatting.
    If that text is not valid JSON (for example, plain prose that happens to
    begin with a brace), the raw text is used instead of raising.

    Returns:
        list: List of messages formatted as 'role: content', in ascending
        ``id`` order.
    """
    with self._get_connection() as conn:
        result = conn.execute(
            f"""
            SELECT role, content FROM {self.table_name}
            WHERE conversation_id = ?
            ORDER BY id ASC
            """,
            (self.current_conversation_id,),
        ).fetchall()

    def _render(content):
        # Decode only object-looking strings, as before; fall back to the
        # raw text so one malformed row cannot break the whole listing.
        if isinstance(content, str) and content.startswith("{"):
            try:
                return json.loads(content)
            except json.JSONDecodeError:
                return content
        return content

    return [f"{row[0]}: {_render(row[1])}" for row in result]
1230
+
1231
def return_messages_as_dictionary(self) -> List[Dict]:
    """Return every message of the active conversation as a role/content dict.

    Stored content goes through ``json.loads``. Values that are not valid
    JSON are returned unchanged.

    Returns:
        list: One ``{"role": ..., "content": ...}`` dict per message, in
        ascending ``id`` order.
    """
    with self._get_connection() as conn:
        rows = conn.execute(
            f"""
            SELECT role, content FROM {self.table_name}
            WHERE conversation_id = ?
            ORDER BY id ASC
            """,
            (self.current_conversation_id,),
        ).fetchall()

    def _decode(raw):
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            return raw

    return [{"role": row[0], "content": _decode(row[1])} for row in rows]
1262
+
1263
def add_tool_output_to_agent(self, role: str, tool_output: dict):
    """Store a tool's output as a message tagged ``MessageType.TOOL``.

    Args:
        role (str): Role label recorded with the message.
        tool_output (dict): Tool output, passed to ``add`` as the content.
    """
    self.add(
        role,
        tool_output,
        message_type=MessageType.TOOL,
    )
1271
+
1272
def get_final_message(self) -> str:
    """Format the most recent message as ``'role: content'``.

    Returns:
        str: The formatted last message, or ``""`` when ``get_last_message``
        returns a falsy value.
    """
    last = self.get_last_message()
    return f"{last['role']}: {last['content']}" if last else ""
1282
+
1283
def get_final_message_content(self) -> Union[str, dict]:
    """Return only the content of the most recent message.

    Returns:
        Union[str, dict]: ``content`` of the value returned by
        ``get_last_message``, or ``""`` when that value is falsy.
    """
    last = self.get_last_message()
    return last["content"] if last else ""
1293
+
1294
def return_all_except_first(self) -> List[Dict]:
    """Return the active conversation's messages, skipping the first two rows.

    Despite the method name, the query uses ``OFFSET 2``: the first TWO
    messages (by ascending ``id``) are omitted.
    NOTE(review): confirm callers rely on skipping two rows before changing
    the offset.

    Content is JSON-decoded when possible. ``timestamp``, ``message_type``,
    ``metadata`` (JSON-decoded) and ``token_count`` are included only when
    truthy.

    Returns:
        list: Message dicts in ascending ``id`` order, minus the first two.
    """
    with self._get_connection() as conn:
        # OFFSET on its own is valid in DuckDB. The SQLite-style
        # "LIMIT -1" (meaning "no limit") is not portable, and DuckDB may
        # reject a negative LIMIT, so it is omitted.
        result = conn.execute(
            f"""
            SELECT role, content, timestamp, message_type, metadata, token_count
            FROM {self.table_name}
            WHERE conversation_id = ?
            ORDER BY id ASC
            OFFSET 2
            """,
            (self.current_conversation_id,),
        ).fetchall()

    messages = []
    for row in result:
        content = row[1]
        try:
            content = json.loads(content)
        except json.JSONDecodeError:
            pass  # not JSON: keep the stored text as-is

        message = {
            "role": row[0],
            "content": content,
        }
        if row[2]:  # timestamp
            message["timestamp"] = row[2]
        if row[3]:  # message_type
            message["message_type"] = row[3]
        if row[4]:  # metadata (stored as JSON text)
            message["metadata"] = json.loads(row[4])
        if row[5]:  # token_count
            message["token_count"] = row[5]

        messages.append(message)
    return messages
1335
+
1336
def return_all_except_first_string(self) -> str:
    """Join the contents of the messages from ``return_all_except_first``.

    Only each message's ``content`` is included (roles are dropped), one per
    line.

    Returns:
        str: Newline-separated message contents.
    """
    contents = [
        f"{message['content']}" for message in self.return_all_except_first()
    ]
    return "\n".join(contents)