solace-agent-mesh 0.1.2__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of solace-agent-mesh might be problematic. Click here for more details.
- solace_agent_mesh/agents/base_agent_component.py +2 -0
- solace_agent_mesh/agents/global/actions/plantuml_diagram.py +14 -2
- solace_agent_mesh/agents/global/actions/plotly_graph.py +49 -40
- solace_agent_mesh/agents/web_request/actions/do_web_request.py +34 -33
- solace_agent_mesh/cli/__init__.py +1 -1
- solace_agent_mesh/cli/commands/add/gateway.py +162 -9
- solace_agent_mesh/cli/commands/build.py +0 -1
- solace_agent_mesh/cli/commands/init/builtin_agent_step.py +1 -6
- solace_agent_mesh/cli/commands/init/create_config_file_step.py +5 -0
- solace_agent_mesh/cli/commands/init/create_other_project_files_step.py +52 -1
- solace_agent_mesh/cli/commands/init/init.py +1 -5
- solace_agent_mesh/cli/commands/init/project_structure_step.py +0 -29
- solace_agent_mesh/cli/commands/plugin/add.py +3 -1
- solace_agent_mesh/cli/commands/plugin/build.py +11 -2
- solace_agent_mesh/cli/commands/plugin/plugin.py +20 -5
- solace_agent_mesh/cli/commands/plugin/remove.py +3 -1
- solace_agent_mesh/cli/config.py +4 -0
- solace_agent_mesh/cli/utils.py +7 -2
- solace_agent_mesh/common/action_response.py +13 -0
- solace_agent_mesh/common/constants.py +12 -0
- solace_agent_mesh/common/postgres_database.py +11 -5
- solace_agent_mesh/common/utils.py +16 -11
- solace_agent_mesh/configs/monitor_stim_and_errors_to_slack.yaml +3 -0
- solace_agent_mesh/configs/service_embedding.yaml +1 -1
- solace_agent_mesh/configs/service_llm.yaml +1 -1
- solace_agent_mesh/gateway/components/gateway_base.py +7 -1
- solace_agent_mesh/gateway/components/gateway_input.py +8 -5
- solace_agent_mesh/gateway/components/gateway_output.py +12 -3
- solace_agent_mesh/orchestrator/action_manager.py +13 -1
- solace_agent_mesh/orchestrator/components/orchestrator_stimulus_processor_component.py +25 -5
- solace_agent_mesh/orchestrator/orchestrator_prompt.py +155 -35
- solace_agent_mesh/services/file_service/file_service.py +5 -0
- solace_agent_mesh/services/file_service/file_service_constants.py +1 -1
- solace_agent_mesh/services/file_service/file_transformations.py +11 -1
- solace_agent_mesh/services/file_service/file_utils.py +2 -0
- solace_agent_mesh/services/history_service/history_providers/base_history_provider.py +21 -45
- solace_agent_mesh/services/history_service/history_providers/file_history_provider.py +74 -0
- solace_agent_mesh/services/history_service/history_providers/index.py +40 -0
- solace_agent_mesh/services/history_service/history_providers/memory_history_provider.py +19 -153
- solace_agent_mesh/services/history_service/history_providers/mongodb_history_provider.py +66 -0
- solace_agent_mesh/services/history_service/history_providers/redis_history_provider.py +40 -137
- solace_agent_mesh/services/history_service/history_providers/sql_history_provider.py +93 -0
- solace_agent_mesh/services/history_service/history_service.py +315 -41
- solace_agent_mesh/services/history_service/long_term_memory/__init__.py +0 -0
- solace_agent_mesh/services/history_service/long_term_memory/long_term_memory.py +399 -0
- solace_agent_mesh/services/llm_service/components/llm_request_component.py +24 -0
- solace_agent_mesh/templates/gateway-config-template.yaml +2 -1
- solace_agent_mesh/templates/gateway-default-config.yaml +3 -3
- solace_agent_mesh/templates/plugin-gateway-default-config.yaml +29 -0
- solace_agent_mesh/templates/rest-api-default-config.yaml +2 -1
- solace_agent_mesh/templates/slack-default-config.yaml +1 -1
- solace_agent_mesh/templates/web-default-config.yaml +2 -1
- {solace_agent_mesh-0.1.2.dist-info → solace_agent_mesh-0.2.0.dist-info}/METADATA +38 -8
- {solace_agent_mesh-0.1.2.dist-info → solace_agent_mesh-0.2.0.dist-info}/RECORD +57 -52
- solace_agent_mesh/cli/commands/init/rest_api_step.py +0 -50
- solace_agent_mesh/cli/commands/init/web_ui_step.py +0 -40
- {solace_agent_mesh-0.1.2.dist-info → solace_agent_mesh-0.2.0.dist-info}/WHEEL +0 -0
- {solace_agent_mesh-0.1.2.dist-info → solace_agent_mesh-0.2.0.dist-info}/entry_points.txt +0 -0
- {solace_agent_mesh-0.1.2.dist-info → solace_agent_mesh-0.2.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,167 +1,33 @@
|
|
|
1
|
-
|
|
1
|
+
"""
|
|
2
|
+
Memory history provider
|
|
3
|
+
"""
|
|
4
|
+
|
|
2
5
|
from .base_history_provider import BaseHistoryProvider
|
|
3
6
|
|
|
4
7
|
|
|
5
8
|
class MemoryHistoryProvider(BaseHistoryProvider):
|
|
9
|
+
"""
|
|
10
|
+
A history provider that stores history in memory.
|
|
11
|
+
"""
|
|
12
|
+
|
|
6
13
|
def __init__(self, config=None):
|
|
7
14
|
super().__init__(config)
|
|
8
15
|
self.history = {}
|
|
9
16
|
|
|
10
|
-
def
|
|
11
|
-
"""
|
|
12
|
-
Store a new entry in the history.
|
|
13
|
-
|
|
14
|
-
:param session_id: The session identifier.
|
|
15
|
-
:param history_entry: The entry to be stored in the history.
|
|
16
|
-
"""
|
|
17
|
-
|
|
17
|
+
def store_session(self, session_id, data):
|
|
18
18
|
if session_id not in self.history:
|
|
19
|
-
self.history[session_id] = {
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
"last_active_time": time.time(),
|
|
23
|
-
"num_characters": 0,
|
|
24
|
-
"num_turns": 0,
|
|
25
|
-
}
|
|
26
|
-
|
|
27
|
-
# Check if adding another entry would exceed the max_turns
|
|
28
|
-
if self.history[session_id]["num_turns"] == self.max_turns:
|
|
29
|
-
# Remove the oldest entry
|
|
30
|
-
oldest_entry = self.history[session_id]["history"].pop(0)
|
|
31
|
-
# Subtract the length of the oldest entry from the total length
|
|
32
|
-
self.history[session_id]["num_characters"] -= len(
|
|
33
|
-
str(oldest_entry["content"])
|
|
34
|
-
)
|
|
35
|
-
self.history[session_id]["num_turns"] -= 1
|
|
36
|
-
|
|
37
|
-
if (
|
|
38
|
-
self.enforce_alternate_message_roles
|
|
39
|
-
and self.history[session_id]["num_turns"] > 0
|
|
40
|
-
# Check if the last entry was by the same role
|
|
41
|
-
and self.history[session_id]["history"]
|
|
42
|
-
and self.history[session_id]["history"][-1]["role"] == role
|
|
43
|
-
):
|
|
44
|
-
# Append to last entry
|
|
45
|
-
self.history[session_id]["history"][-1]["content"] += content
|
|
46
|
-
else:
|
|
47
|
-
# Add the new entry
|
|
48
|
-
self.history[session_id]["history"].append(
|
|
49
|
-
{"role": role, "content": content}
|
|
50
|
-
)
|
|
51
|
-
# Update the number of turns
|
|
52
|
-
self.history[session_id]["num_turns"] += 1
|
|
53
|
-
|
|
54
|
-
# Update the length
|
|
55
|
-
self.history[session_id]["num_characters"] += len(str(content))
|
|
56
|
-
# Update the last active time
|
|
57
|
-
self.history[session_id]["last_active_time"] = time.time()
|
|
19
|
+
self.history[session_id] = {}
|
|
20
|
+
|
|
21
|
+
self.history[session_id].update(data)
|
|
58
22
|
|
|
59
|
-
|
|
60
|
-
if self.max_characters:
|
|
61
|
-
while (
|
|
62
|
-
self.history[session_id]["num_characters"] > self.max_characters
|
|
63
|
-
and self.history[session_id]["num_turns"] > 0
|
|
64
|
-
):
|
|
65
|
-
# Remove the oldest entry
|
|
66
|
-
oldest_entry = self.history[session_id]["history"].pop(0)
|
|
67
|
-
# Subtract the length of the oldest entry from the total length
|
|
68
|
-
self.history[session_id]["num_characters"] -= len(
|
|
69
|
-
str(oldest_entry["content"])
|
|
70
|
-
)
|
|
71
|
-
self.history[session_id]["num_turns"] -= 1
|
|
72
|
-
|
|
73
|
-
def get_history(self, session_id: str):
|
|
74
|
-
"""
|
|
75
|
-
Retrieve the entire history for a session.
|
|
76
|
-
|
|
77
|
-
:param session_id: The session identifier.
|
|
78
|
-
:return: The complete history for the session.
|
|
79
|
-
"""
|
|
80
|
-
if session_id not in self.history:
|
|
81
|
-
return []
|
|
82
|
-
return self.history.get(session_id)["history"]
|
|
83
|
-
|
|
84
|
-
def store_file(self, session_id: str, file: dict):
|
|
85
|
-
"""
|
|
86
|
-
Store a file in the history.
|
|
87
|
-
|
|
88
|
-
:param session_id: The session identifier.
|
|
89
|
-
:param file: The file metadata to be stored in the history.
|
|
90
|
-
"""
|
|
23
|
+
def get_session(self, session_id):
|
|
91
24
|
if session_id not in self.history:
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
"files": [],
|
|
95
|
-
"last_active_time": time.time(),
|
|
96
|
-
"num_characters": 0,
|
|
97
|
-
"num_turns": 0,
|
|
98
|
-
}
|
|
99
|
-
self.history[session_id]["last_active_time"] = time.time()
|
|
100
|
-
|
|
101
|
-
# Check duplicate
|
|
102
|
-
for f in self.history[session_id]["files"]:
|
|
103
|
-
if f.get("url") and f.get("url") == file.get("url"):
|
|
104
|
-
return
|
|
105
|
-
|
|
106
|
-
self.history[session_id]["files"].append(file)
|
|
107
|
-
|
|
108
|
-
def get_files(self, session_id: str):
|
|
109
|
-
"""
|
|
110
|
-
Retrieve the files for a session.
|
|
111
|
-
|
|
112
|
-
:param session_id: The session identifier.
|
|
113
|
-
:return: The files for the session.
|
|
114
|
-
"""
|
|
115
|
-
if session_id not in self.history:
|
|
116
|
-
return []
|
|
117
|
-
files = []
|
|
118
|
-
current_time = time.time()
|
|
119
|
-
all_files = self.history.get(session_id)["files"].copy()
|
|
120
|
-
for file in all_files:
|
|
121
|
-
expiration_timestamp = file.get("expiration_timestamp")
|
|
122
|
-
if expiration_timestamp and current_time > expiration_timestamp:
|
|
123
|
-
self.history[session_id]["files"].remove(file)
|
|
124
|
-
continue
|
|
125
|
-
files.append(file)
|
|
126
|
-
return files
|
|
127
|
-
|
|
128
|
-
def clear_history(self, session_id: str, keep_levels=0):
|
|
129
|
-
"""
|
|
130
|
-
Clear the history for a session, optionally keeping a specified number of recent entries.
|
|
131
|
-
|
|
132
|
-
:param session_id: The session identifier.
|
|
133
|
-
:param keep_levels: Number of most recent history entries to keep. Default is 0 (clear all).
|
|
134
|
-
"""
|
|
135
|
-
if session_id in self.history:
|
|
136
|
-
if keep_levels <= 0:
|
|
137
|
-
del self.history[session_id]
|
|
138
|
-
else:
|
|
139
|
-
self.history[session_id]["history"] = self.history[session_id][
|
|
140
|
-
"history"
|
|
141
|
-
][-keep_levels:]
|
|
142
|
-
# Recalculate the length and num_turns
|
|
143
|
-
self.history[session_id]["num_characters"] = sum(
|
|
144
|
-
len(str(entry)) for entry in self.history[session_id]["history"]
|
|
145
|
-
)
|
|
146
|
-
self.history[session_id]["num_turns"] = len(
|
|
147
|
-
self.history[session_id]["history"]
|
|
148
|
-
)
|
|
149
|
-
|
|
150
|
-
def get_session_meta(self, session_id: str):
|
|
151
|
-
"""
|
|
152
|
-
Retrieve the session metadata.
|
|
153
|
-
|
|
154
|
-
:param session_id: The session identifier.
|
|
155
|
-
:return: The session metadata.
|
|
156
|
-
"""
|
|
157
|
-
if session_id in self.history:
|
|
158
|
-
session = self.history[session_id]
|
|
159
|
-
return {
|
|
160
|
-
"num_characters": session["num_characters"],
|
|
161
|
-
"num_turns": session["num_turns"],
|
|
162
|
-
"last_active_time": session["last_active_time"],
|
|
163
|
-
}
|
|
164
|
-
return None
|
|
25
|
+
return {}
|
|
26
|
+
return self.history[session_id]
|
|
165
27
|
|
|
166
28
|
def get_all_sessions(self) -> list[str]:
|
|
167
29
|
return list(self.history.keys())
|
|
30
|
+
|
|
31
|
+
def delete_session(self, session_id):
|
|
32
|
+
if session_id in self.history:
|
|
33
|
+
del self.history[session_id]
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
"""
|
|
2
|
+
MongoDB-based history provider for storing session data.
|
|
3
|
+
"""
|
|
4
|
+
from .base_history_provider import BaseHistoryProvider
|
|
5
|
+
|
|
6
|
+
class MongoDBHistoryProvider(BaseHistoryProvider):
|
|
7
|
+
"""
|
|
8
|
+
A MongoDB-based history provider for storing session data.
|
|
9
|
+
"""
|
|
10
|
+
def __init__(self, config=None):
|
|
11
|
+
super().__init__(config)
|
|
12
|
+
|
|
13
|
+
try:
|
|
14
|
+
from pymongo import MongoClient
|
|
15
|
+
except ImportError:
|
|
16
|
+
raise ImportError("Please install the pymongo package to use the MongoDBHistoryProvider.\n\t$ pip install pymongo")
|
|
17
|
+
|
|
18
|
+
if not self.config.get("mongodb_uri"):
|
|
19
|
+
raise ValueError("Missing required configuration for MongoDBHistoryProvider, Missing 'mongodb_uri' in 'store_config'.")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
self.client = MongoClient(self.config.get("mongodb_uri"))
|
|
23
|
+
self.db = self.client[self.config.get("mongodb_db", "history_db")]
|
|
24
|
+
self.collection = self.db[self.config.get("mongodb_collection", "sessions")]
|
|
25
|
+
|
|
26
|
+
def _get_key(self, session_id):
|
|
27
|
+
"""
|
|
28
|
+
Generate a document identifier for a session.
|
|
29
|
+
|
|
30
|
+
:param session_id: The session identifier.
|
|
31
|
+
:return: The session ID as the primary key.
|
|
32
|
+
"""
|
|
33
|
+
return {"_id": session_id}
|
|
34
|
+
|
|
35
|
+
def store_session(self, session_id: str, data: dict):
|
|
36
|
+
"""
|
|
37
|
+
Store the session metadata.
|
|
38
|
+
|
|
39
|
+
:param session_id: The session identifier.
|
|
40
|
+
:param data: The session data to be stored.
|
|
41
|
+
"""
|
|
42
|
+
self.collection.update_one(self._get_key(session_id), {"$set": {"data": data}}, upsert=True)
|
|
43
|
+
|
|
44
|
+
def get_session(self, session_id: str)->dict:
|
|
45
|
+
"""
|
|
46
|
+
Retrieve the session.
|
|
47
|
+
|
|
48
|
+
:param session_id: The session identifier.
|
|
49
|
+
:return: The session metadata as a dictionary.
|
|
50
|
+
"""
|
|
51
|
+
document = self.collection.find_one(self._get_key(session_id))
|
|
52
|
+
return document.get("data") if document else {}
|
|
53
|
+
|
|
54
|
+
def get_all_sessions(self) -> list[str]:
|
|
55
|
+
"""
|
|
56
|
+
Retrieve all session identifiers.
|
|
57
|
+
"""
|
|
58
|
+
return [doc["_id"] for doc in self.collection.find({}, {"_id": 1})]
|
|
59
|
+
|
|
60
|
+
def delete_session(self, session_id: str):
|
|
61
|
+
"""
|
|
62
|
+
Delete the session.
|
|
63
|
+
|
|
64
|
+
:param session_id: The session identifier.
|
|
65
|
+
"""
|
|
66
|
+
self.collection.delete_one(self._get_key(session_id))
|
|
@@ -1,8 +1,13 @@
|
|
|
1
|
+
"""
|
|
2
|
+
A history provider that stores history in Redis.
|
|
3
|
+
"""
|
|
1
4
|
import json
|
|
2
|
-
import time
|
|
3
5
|
from .base_history_provider import BaseHistoryProvider
|
|
4
6
|
|
|
5
7
|
class RedisHistoryProvider(BaseHistoryProvider):
|
|
8
|
+
"""
|
|
9
|
+
A history provider that stores history in Redis.
|
|
10
|
+
"""
|
|
6
11
|
def __init__(self, config=None):
|
|
7
12
|
super().__init__(config)
|
|
8
13
|
try:
|
|
@@ -14,150 +19,48 @@ class RedisHistoryProvider(BaseHistoryProvider):
|
|
|
14
19
|
host=self.config.get("redis_host", "localhost"),
|
|
15
20
|
port=self.config.get("redis_port", 6379),
|
|
16
21
|
db=self.config.get("redis_db", 0),
|
|
22
|
+
decode_responses=True # Ensures string output
|
|
17
23
|
)
|
|
24
|
+
|
|
25
|
+
def _get_key(self, session_id):
|
|
26
|
+
"""
|
|
27
|
+
Generate a Redis key with a specific prefix for a session.
|
|
18
28
|
|
|
19
|
-
|
|
20
|
-
return
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
return f"session:{session_id}:files"
|
|
24
|
-
|
|
25
|
-
def store_history(self, session_id: str, role: str, content: str | dict):
|
|
26
|
-
key = self._get_history_key(session_id)
|
|
27
|
-
entry = {"role": role, "content": content}
|
|
28
|
-
entry_json = json.dumps(entry)
|
|
29
|
-
|
|
30
|
-
# Check if session exists, if not initialize it
|
|
31
|
-
if not self.redis_client.exists(key):
|
|
32
|
-
self.redis_client.hset(session_id, mapping={
|
|
33
|
-
"num_characters": 0,
|
|
34
|
-
"num_turns": 0,
|
|
35
|
-
"last_active_time": time.time()
|
|
36
|
-
})
|
|
37
|
-
|
|
38
|
-
# Get current stats
|
|
39
|
-
session_meta = self.redis_client.hgetall(session_id)
|
|
40
|
-
num_characters = int(session_meta.get(b"num_characters", 0))
|
|
41
|
-
num_turns = int(session_meta.get(b"num_turns", 0))
|
|
42
|
-
|
|
43
|
-
# Add the new entry
|
|
44
|
-
if self.enforce_alternate_message_roles and num_turns > 0:
|
|
45
|
-
last_entry = json.loads(self.redis_client.lindex(key, -1))
|
|
46
|
-
if last_entry["role"] == role:
|
|
47
|
-
last_entry["content"] += content
|
|
48
|
-
self.redis_client.lset(key, -1, json.dumps(last_entry))
|
|
49
|
-
else:
|
|
50
|
-
self.redis_client.rpush(key, entry_json)
|
|
51
|
-
num_turns += 1
|
|
52
|
-
else:
|
|
53
|
-
self.redis_client.rpush(key, entry_json)
|
|
54
|
-
num_turns += 1
|
|
55
|
-
num_characters += len(str(content))
|
|
56
|
-
|
|
57
|
-
# Enforce max_turns by trimming the oldest entry if needed
|
|
58
|
-
if self.max_turns and num_turns > self.max_turns:
|
|
59
|
-
oldest_entry = json.loads(self.redis_client.lpop(key))
|
|
60
|
-
num_characters -= len(str(oldest_entry["content"]))
|
|
61
|
-
num_turns -= 1
|
|
62
|
-
|
|
63
|
-
# Enforce max_characters
|
|
64
|
-
if self.max_characters:
|
|
65
|
-
while num_characters > self.max_characters and num_turns > 0:
|
|
66
|
-
oldest_entry = json.loads(self.redis_client.lpop(key))
|
|
67
|
-
num_characters -= len(str(oldest_entry["content"]))
|
|
68
|
-
num_turns -= 1
|
|
69
|
-
|
|
70
|
-
# Update metadata and set expiration
|
|
71
|
-
self.redis_client.hset(session_id, mapping={
|
|
72
|
-
"num_characters": num_characters,
|
|
73
|
-
"num_turns": num_turns,
|
|
74
|
-
"last_active_time": time.time()
|
|
75
|
-
})
|
|
76
|
-
|
|
77
|
-
def get_history(self, session_id: str):
|
|
78
|
-
key = self._get_history_key(session_id)
|
|
79
|
-
history = self.redis_client.lrange(key, 0, -1)
|
|
80
|
-
|
|
81
|
-
# Decode JSON entries and return a list of dictionaries
|
|
82
|
-
return [json.loads(entry) for entry in history]
|
|
83
|
-
|
|
84
|
-
def store_file(self, session_id: str, file: dict):
|
|
85
|
-
key = self._get_files_key(session_id)
|
|
86
|
-
file_entry = json.dumps(file)
|
|
87
|
-
|
|
88
|
-
# Avoid duplicate files by checking existing URLs
|
|
89
|
-
existing_files = self.get_files(session_id)
|
|
90
|
-
if any(f.get("url") == file.get("url") for f in existing_files):
|
|
91
|
-
return
|
|
92
|
-
|
|
93
|
-
# Add the file and update metadata
|
|
94
|
-
self.redis_client.rpush(key, file_entry)
|
|
95
|
-
self.redis_client.hset(session_id, "last_active_time", time.time())
|
|
96
|
-
|
|
97
|
-
def get_files(self, session_id: str):
|
|
98
|
-
key = self._get_files_key(session_id)
|
|
99
|
-
current_time = time.time()
|
|
100
|
-
files = self.redis_client.lrange(key, 0, -1)
|
|
101
|
-
|
|
102
|
-
valid_files = []
|
|
103
|
-
for file_json in files:
|
|
104
|
-
file = json.loads(file_json)
|
|
105
|
-
expiration_timestamp = file.get("expiration_timestamp")
|
|
106
|
-
|
|
107
|
-
# Remove expired files
|
|
108
|
-
if expiration_timestamp and current_time > expiration_timestamp:
|
|
109
|
-
self.redis_client.lrem(key, 0, file_json)
|
|
110
|
-
else:
|
|
111
|
-
valid_files.append(file)
|
|
112
|
-
|
|
113
|
-
return valid_files
|
|
114
|
-
|
|
115
|
-
def clear_history(self, session_id: str, keep_levels=0):
|
|
116
|
-
history_key = self._get_history_key(session_id)
|
|
117
|
-
files_key = self._get_files_key(session_id)
|
|
118
|
-
|
|
119
|
-
if keep_levels > 0:
|
|
120
|
-
# Keep the latest `keep_levels` entries
|
|
121
|
-
self.redis_client.ltrim(history_key, -keep_levels, -1)
|
|
122
|
-
|
|
123
|
-
# Recalculate session metadata
|
|
124
|
-
remaining_entries = self.redis_client.lrange(history_key, 0, -1)
|
|
125
|
-
num_characters = sum(len(str(json.loads(entry)["content"])) for entry in remaining_entries)
|
|
126
|
-
num_turns = len(remaining_entries)
|
|
29
|
+
:param session_id: The session identifier.
|
|
30
|
+
:return: A formatted Redis key string.
|
|
31
|
+
"""
|
|
32
|
+
return f"sessions:{session_id}:history"
|
|
127
33
|
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
"num_turns": num_turns
|
|
132
|
-
})
|
|
133
|
-
else:
|
|
134
|
-
# Clear all history and files
|
|
135
|
-
self.redis_client.delete(history_key, files_key, session_id)
|
|
34
|
+
def store_session(self, session_id: str, data: dict):
|
|
35
|
+
"""
|
|
36
|
+
Store the session metadata.
|
|
136
37
|
|
|
38
|
+
:param session_id: The session identifier.
|
|
39
|
+
:param data: The session data to be stored.
|
|
40
|
+
"""
|
|
41
|
+
self.redis_client.set(self._get_key(session_id), json.dumps(data))
|
|
137
42
|
|
|
138
|
-
def
|
|
43
|
+
def get_session(self, session_id: str)->dict:
|
|
139
44
|
"""
|
|
140
|
-
Retrieve the session
|
|
45
|
+
Retrieve the session.
|
|
141
46
|
|
|
142
47
|
:param session_id: The session identifier.
|
|
143
|
-
:return: The session metadata.
|
|
48
|
+
:return: The session metadata as a dictionary.
|
|
144
49
|
"""
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
return None
|
|
148
|
-
# Get current stats
|
|
149
|
-
session_meta = self.redis_client.hgetall(session_id)
|
|
150
|
-
num_characters = int(session_meta.get(b"num_characters", 0))
|
|
151
|
-
num_turns = int(session_meta.get(b"num_turns", 0))
|
|
152
|
-
last_active_time = float(session_meta.get(b"last_active_time", 0))
|
|
153
|
-
return {
|
|
154
|
-
"num_characters": num_characters,
|
|
155
|
-
"num_turns": num_turns,
|
|
156
|
-
"last_active_time": last_active_time,
|
|
157
|
-
}
|
|
50
|
+
data = self.redis_client.get(self._get_key(session_id))
|
|
51
|
+
return json.loads(data) if data else {}
|
|
158
52
|
|
|
53
|
+
def get_all_sessions(self) -> list[str]:
|
|
54
|
+
"""
|
|
55
|
+
Retrieve all session identifiers.
|
|
56
|
+
"""
|
|
57
|
+
keys = self.redis_client.keys("sessions:*:history")
|
|
58
|
+
return [key.split(":")[1] for key in keys]
|
|
59
|
+
|
|
60
|
+
def delete_session(self, session_id: str):
|
|
61
|
+
"""
|
|
62
|
+
Delete the session.
|
|
159
63
|
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
return [key.decode().split(":")[1] for key in session_keys]
|
|
64
|
+
:param session_id: The session identifier.
|
|
65
|
+
"""
|
|
66
|
+
self.redis_client.delete(self._get_key(session_id))
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
import json
|
|
2
|
+
|
|
3
|
+
from .base_history_provider import BaseHistoryProvider
|
|
4
|
+
from ....common.postgres_database import PostgreSQLDatabase
|
|
5
|
+
from ....common.mysql_database import MySQLDatabase
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class DatabaseFactory:
|
|
9
|
+
"""
|
|
10
|
+
Factory class to create database instances.
|
|
11
|
+
"""
|
|
12
|
+
DATABASE_PROVIDERS = {
|
|
13
|
+
"postgres": PostgreSQLDatabase,
|
|
14
|
+
"mysql": MySQLDatabase,
|
|
15
|
+
}
|
|
16
|
+
|
|
17
|
+
@staticmethod
|
|
18
|
+
def get_database(db_type, **kwargs):
|
|
19
|
+
if db_type not in DatabaseFactory.DATABASE_PROVIDERS:
|
|
20
|
+
raise ValueError(f"Unsupported database type: {db_type}")
|
|
21
|
+
return DatabaseFactory.DATABASE_PROVIDERS[db_type](**kwargs)
|
|
22
|
+
|
|
23
|
+
class SQLHistoryProvider(BaseHistoryProvider):
|
|
24
|
+
"""
|
|
25
|
+
A history provider that stores session history in a SQL database.
|
|
26
|
+
"""
|
|
27
|
+
def __init__(self, config=None):
|
|
28
|
+
super().__init__(config)
|
|
29
|
+
self.db_type = self.config.get("db_type", "postgres")
|
|
30
|
+
self.table_name = self.config.get("table_name", "session_history")
|
|
31
|
+
self.db = DatabaseFactory.get_database(
|
|
32
|
+
self.db_type,
|
|
33
|
+
host=self.config.get("sql_host"),
|
|
34
|
+
user=self.config.get("sql_user"),
|
|
35
|
+
password=self.config.get("sql_password"),
|
|
36
|
+
database=self.config.get("sql_database"),
|
|
37
|
+
)
|
|
38
|
+
self._ensure_table_exists()
|
|
39
|
+
|
|
40
|
+
def _ensure_table_exists(self):
|
|
41
|
+
"""
|
|
42
|
+
Ensures the required table exists in the database.
|
|
43
|
+
"""
|
|
44
|
+
query = f"""
|
|
45
|
+
CREATE TABLE IF NOT EXISTS {self.table_name} (
|
|
46
|
+
session_id TEXT PRIMARY KEY,
|
|
47
|
+
data JSON
|
|
48
|
+
)
|
|
49
|
+
"""
|
|
50
|
+
self.db.execute(query)
|
|
51
|
+
|
|
52
|
+
def store_session(self, session_id: str, data: dict):
|
|
53
|
+
"""
|
|
54
|
+
Store or update session metadata.
|
|
55
|
+
"""
|
|
56
|
+
query = f"""
|
|
57
|
+
INSERT INTO {self.table_name} (session_id, data)
|
|
58
|
+
VALUES (%s, %s)
|
|
59
|
+
ON CONFLICT (session_id) DO UPDATE
|
|
60
|
+
SET data = EXCLUDED.data
|
|
61
|
+
""" if self.db_type == "postgres" else f"""
|
|
62
|
+
INSERT INTO {self.table_name} (session_id, data)
|
|
63
|
+
VALUES (%s, %s)
|
|
64
|
+
ON DUPLICATE KEY UPDATE data = VALUES(data)
|
|
65
|
+
"""
|
|
66
|
+
self.db.execute(query, (session_id, json.dumps(data)))
|
|
67
|
+
|
|
68
|
+
def get_session(self, session_id: str) -> dict:
|
|
69
|
+
"""
|
|
70
|
+
Retrieve a session by ID.
|
|
71
|
+
"""
|
|
72
|
+
query = f"SELECT data FROM {self.table_name} WHERE session_id = %s"
|
|
73
|
+
cursor = self.db.execute(query, (session_id,))
|
|
74
|
+
row = cursor.fetchone()
|
|
75
|
+
if not row.get("data"):
|
|
76
|
+
return {}
|
|
77
|
+
data = row["data"] if isinstance(row["data"], dict) else json.loads(row["data"])
|
|
78
|
+
return data
|
|
79
|
+
|
|
80
|
+
def get_all_sessions(self) -> list[str]:
|
|
81
|
+
"""
|
|
82
|
+
Retrieve all session identifiers.
|
|
83
|
+
"""
|
|
84
|
+
query = f"SELECT session_id FROM {self.table_name}"
|
|
85
|
+
cursor = self.db.execute(query)
|
|
86
|
+
return [row["session_id"] for row in cursor.fetchall()]
|
|
87
|
+
|
|
88
|
+
def delete_session(self, session_id: str):
|
|
89
|
+
"""
|
|
90
|
+
Delete a session by ID, ensuring only one row is deleted.
|
|
91
|
+
"""
|
|
92
|
+
query = f"DELETE FROM {self.table_name} WHERE session_id = %s LIMIT 1"
|
|
93
|
+
self.db.execute(query, (session_id,))
|