agentscope-runtime 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff compares two publicly released versions of the package as published to their public registry and is provided for informational purposes only.
- agentscope_runtime/engine/agents/agentscope_agent/agent.py +1 -0
- agentscope_runtime/engine/agents/agno_agent.py +1 -0
- agentscope_runtime/engine/agents/autogen_agent.py +245 -0
- agentscope_runtime/engine/schemas/agent_schemas.py +1 -1
- agentscope_runtime/engine/services/context_manager.py +28 -1
- agentscope_runtime/engine/services/memory_service.py +2 -2
- agentscope_runtime/engine/services/rag_service.py +101 -0
- agentscope_runtime/engine/services/redis_memory_service.py +187 -0
- agentscope_runtime/engine/services/redis_session_history_service.py +155 -0
- agentscope_runtime/sandbox/box/training_box/env_service.py +1 -1
- agentscope_runtime/sandbox/box/training_box/environments/bfcl/bfcl_dataprocess.py +216 -0
- agentscope_runtime/sandbox/box/training_box/environments/bfcl/bfcl_env.py +380 -0
- agentscope_runtime/sandbox/box/training_box/environments/bfcl/env_handler.py +934 -0
- agentscope_runtime/sandbox/box/training_box/training_box.py +139 -9
- agentscope_runtime/sandbox/build.py +1 -1
- agentscope_runtime/sandbox/custom/custom_sandbox.py +0 -1
- agentscope_runtime/sandbox/custom/example.py +0 -1
- agentscope_runtime/sandbox/enums.py +2 -0
- agentscope_runtime/sandbox/manager/container_clients/__init__.py +2 -0
- agentscope_runtime/sandbox/manager/container_clients/docker_client.py +263 -11
- agentscope_runtime/sandbox/manager/container_clients/kubernetes_client.py +605 -0
- agentscope_runtime/sandbox/manager/sandbox_manager.py +112 -113
- agentscope_runtime/sandbox/manager/server/app.py +96 -28
- agentscope_runtime/sandbox/manager/server/config.py +28 -16
- agentscope_runtime/sandbox/model/__init__.py +1 -5
- agentscope_runtime/sandbox/model/container.py +3 -1
- agentscope_runtime/sandbox/model/manager_config.py +21 -15
- agentscope_runtime/sandbox/tools/tool.py +111 -0
- agentscope_runtime/version.py +1 -1
- {agentscope_runtime-0.1.0.dist-info → agentscope_runtime-0.1.2.dist-info}/METADATA +79 -13
- {agentscope_runtime-0.1.0.dist-info → agentscope_runtime-0.1.2.dist-info}/RECORD +35 -28
- agentscope_runtime/sandbox/manager/utils.py +0 -78
- {agentscope_runtime-0.1.0.dist-info → agentscope_runtime-0.1.2.dist-info}/WHEEL +0 -0
- {agentscope_runtime-0.1.0.dist-info → agentscope_runtime-0.1.2.dist-info}/entry_points.txt +0 -0
- {agentscope_runtime-0.1.0.dist-info → agentscope_runtime-0.1.2.dist-info}/licenses/LICENSE +0 -0
- {agentscope_runtime-0.1.0.dist-info → agentscope_runtime-0.1.2.dist-info}/top_level.txt +0 -0

agentscope_runtime/engine/agents/autogen_agent.py (new file)

```diff
@@ -0,0 +1,245 @@
+# -*- coding: utf-8 -*-
+from typing import Optional, Type
+
+from autogen_core.models import ChatCompletionClient
+from autogen_core.tools import FunctionTool
+from autogen_agentchat.agents import AssistantAgent
+from autogen_agentchat.messages import (
+    TextMessage,
+    ToolCallExecutionEvent,
+    ToolCallRequestEvent,
+    ModelClientStreamingChunkEvent,
+)
+
+from ..agents import Agent
+from ..schemas.context import Context
+from ..schemas.agent_schemas import (
+    Message,
+    TextContent,
+    DataContent,
+    FunctionCall,
+    FunctionCallOutput,
+    MessageType,
+    RunStatus,
+)
+
+
+class AutogenContextAdapter:
+    def __init__(self, context: Context, attr: dict):
+        self.context = context
+        self.attr = attr
+
+        # Adapted attribute
+        self.toolkit = None
+        self.model = None
+        self.memory = None
+        self.new_message = None
+
+    async def initialize(self):
+        self.model = await self.adapt_model()
+        self.memory = await self.adapt_memory()
+        self.new_message = await self.adapt_new_message()
+        self.toolkit = await self.adapt_tools()
+
+    async def adapt_memory(self):
+        messages = []
+
+        # Build context
+        for msg in self.context.session.messages[:-1]:  # Exclude the last one
+            messages.append(AutogenContextAdapter.converter(msg))
+
+        return messages
+
+    @staticmethod
+    def converter(message: Message):
+        # TODO: support more message type
+        return TextMessage.load(
+            {
+                "id": message.id,
+                "source": message.role,
+                "content": message.content[0].text if message.content else "",
+            },
+        )
+
+    async def adapt_new_message(self):
+        last_message = self.context.session.messages[-1]
+
+        return AutogenContextAdapter.converter(last_message)
+
+    async def adapt_model(self):
+        return self.attr["model"]
+
+    async def adapt_tools(self):
+        toolkit = self.attr["agent_config"].get("toolkit", [])
+        tools = self.attr["tools"]
+
+        # in case, tools is None and tools == []
+        if not tools:
+            return toolkit
+
+        if self.context.activate_tools:
+            # Only add activated tool
+            activated_tools = self.context.activate_tools
+        else:
+            from ...sandbox.tools.utils import setup_tools
+
+            activated_tools = setup_tools(
+                tools=self.attr["tools"],
+                environment_manager=self.context.environment_manager,
+                session_id=self.context.session.id,
+                user_id=self.context.session.user_id,
+                include_schemas=False,
+            )
+
+        for tool in activated_tools:
+            func = FunctionTool(
+                func=tool.make_function(),
+                description=tool.schema["function"]["description"],
+                name=tool.name,
+            )
+            toolkit.append(func)
+
+        return toolkit
+
+
+class AutogenAgent(Agent):
+    def __init__(
+        self,
+        name: str,
+        model: ChatCompletionClient,
+        tools=None,
+        agent_config=None,
+        agent_builder: Optional[Type[AssistantAgent]] = AssistantAgent,
+    ):
+        super().__init__(name=name, agent_config=agent_config)
+
+        assert isinstance(
+            model,
+            ChatCompletionClient,
+        ), "model must be a subclass of ChatCompletionClient in autogen"
+
+        # Set default agent_builder
+        if agent_builder is None:
+            agent_builder = AssistantAgent
+
+        assert issubclass(
+            agent_builder,
+            AssistantAgent,
+        ), "agent_builder must be a subclass of AssistantAgent in autogen"
+
+        # Replace name if not exists
+        self.agent_config["name"] = self.agent_config.get("name") or name
+
+        self._attr = {
+            "model": model,
+            "tools": tools,
+            "agent_config": self.agent_config,
+            "agent_builder": agent_builder,
+        }
+        self._agent = None
+        self.tools = tools
+
+    def copy(self) -> "AutogenAgent":
+        return AutogenAgent(**self._attr)
+
+    def build(self, as_context):
+        self._agent = self._attr["agent_builder"](
+            **self._attr["agent_config"],
+            model_client=as_context.model,
+            tools=as_context.toolkit,
+        )
+
+        return self._agent
+
+    async def run(self, context):
+        ag_context = AutogenContextAdapter(context=context, attr=self._attr)
+        await ag_context.initialize()
+
+        # We should always build a new agent since the state is manage outside
+        # the agent
+        self._agent = self.build(ag_context)
+
+        resp = self._agent.run_stream(
+            task=ag_context.memory + [ag_context.new_message],
+        )
+
+        text_message = Message(
+            type=MessageType.MESSAGE,
+            role="assistant",
+            status=RunStatus.InProgress,
+        )
+        yield text_message
+
+        text_delta_content = TextContent(delta=True)
+        is_text_delta = False
+        stream_mode = False
+        async for event in resp:
+            if getattr(event, "source", "user") == "user":
+                continue
+
+            if isinstance(event, TextMessage):
+                if stream_mode:
+                    continue
+                is_text_delta = True
+                text_delta_content.text = event.content
+                text_delta_content = text_message.add_delta_content(
+                    new_content=text_delta_content,
+                )
+                yield text_delta_content
+            elif isinstance(event, ModelClientStreamingChunkEvent):
+                stream_mode = True
+                is_text_delta = True
+                text_delta_content.text = event.content
+                text_delta_content = text_message.add_delta_content(
+                    new_content=text_delta_content,
+                )
+                yield text_delta_content
+            elif isinstance(event, ToolCallRequestEvent):
+                data = DataContent(
+                    data=FunctionCall(
+                        call_id=event.id,
+                        name=event.content[0].name,
+                        arguments=event.content[0].arguments,
+                    ).model_dump(),
+                )
+                message = Message(
+                    type=MessageType.PLUGIN_CALL,
+                    role="assistant",
+                    status=RunStatus.Completed,
+                    content=[data],
+                )
+                yield message
+            elif isinstance(event, ToolCallExecutionEvent):
+                data = DataContent(
+                    data=FunctionCallOutput(
+                        call_id=event.id,
+                        output=event.content[0].content,
+                    ).model_dump(),
+                )
+                message = Message(
+                    type=MessageType.PLUGIN_CALL_OUTPUT,
+                    role="assistant",
+                    status=RunStatus.Completed,
+                    content=[data],
+                )
+                yield message
+
+                # Add to message
+                is_text_delta = True
+                text_delta_content.text = event.content[0].content
+                text_delta_content = text_message.add_delta_content(
+                    new_content=text_delta_content,
+                )
+                yield text_delta_content
+
+        if is_text_delta:
+            yield text_message.content_completed(text_delta_content.index)
+        yield text_message.completed()
+
+    async def run_async(
+        self,
+        context,
+        **kwargs,
+    ):
+        async for event in self.run(context):
+            yield event
```
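
The new `AutogenAgent` wraps an AutoGen `AssistantAgent` behind the runtime's `Agent` interface: the constructor captures the model client, tools, and `agent_config`, `build()` forwards `agent_config` verbatim to the `agent_builder`, and `run()` rebuilds the agent per request from an `AutogenContextAdapter`, re-emitting AutoGen stream events as runtime `Message` objects. A minimal construction sketch follows; it assumes the separate `autogen-ext` package for `OpenAIChatCompletionClient`, and the model name and `system_message` key are illustrative rather than taken from this diff.

```python
# Hedged sketch: constructing the AutogenAgent added in 0.1.2.
# OpenAIChatCompletionClient lives in the autogen-ext package (an assumption
# about the installed environment, not something this diff installs).
from autogen_ext.models.openai import OpenAIChatCompletionClient

from agentscope_runtime.engine.agents.autogen_agent import AutogenAgent

agent = AutogenAgent(
    name="assistant",
    # Any autogen_core ChatCompletionClient passes the isinstance check above.
    model=OpenAIChatCompletionClient(model="gpt-4o-mini"),
    # agent_config entries are splatted into AssistantAgent in build();
    # "system_message" is a standard AssistantAgent keyword argument.
    agent_config={"system_message": "You are a helpful assistant."},
)

# The engine drives agent.run_async(context) itself; the Context object is
# composed by ContextManager rather than constructed by hand here.
```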

agentscope_runtime/engine/services/context_manager.py

```diff
@@ -4,12 +4,19 @@ from typing import List
 
 from .manager import ServiceManager
 from .memory_service import MemoryService, InMemoryMemoryService
+from .rag_service import RAGService
 from .session_history_service import (
     SessionHistoryService,
     Session,
     InMemorySessionHistoryService,
 )
-from ..schemas.agent_schemas import
+from ..schemas.agent_schemas import (
+    Message,
+    MessageType,
+    Role,
+    TextContent,
+    ContentType,
+)
 
 
 class ContextComposer:
@@ -19,6 +26,7 @@ class ContextComposer:
         session: Session,  # session
         memory_service: MemoryService = None,
         session_history_service: SessionHistoryService = None,
+        rag_service: RAGService = None,
     ):
         # session
         if session_history_service:
@@ -42,6 +50,18 @@ class ContextComposer:
             )
             session.messages = memories + session.messages
 
+        # rag
+        if rag_service:
+            query = await rag_service.get_query_text(request_input[-1])
+            docs = await rag_service.retrieve(query=query, k=5)
+            cooked_doc = "\n".join(docs)
+            message = Message(
+                type=MessageType.MESSAGE,
+                role=Role.SYSTEM,
+                content=[TextContent(type=ContentType.TEXT, text=cooked_doc)],
+            )
+            session.messages.append(message)
+
 
 class ContextManager(ServiceManager):
     """
@@ -53,10 +73,12 @@ class ContextManager(ServiceManager):
         context_composer_cls=ContextComposer,
         session_history_service: SessionHistoryService = None,
         memory_service: MemoryService = None,
+        rag_service: RAGService = None,
     ):
         self._context_composer_cls = context_composer_cls
         self._session_history_service = session_history_service
         self._memory_service = memory_service
+        self._rag_service = rag_service
         super().__init__()
 
     def _register_default_services(self):
@@ -68,6 +90,8 @@ class ContextManager(ServiceManager):
 
         self.register_service("session", self._session_history_service)
         self.register_service("memory", self._memory_service)
+        if self._rag_service:
+            self.register_service("rag", self._rag_service)
 
     async def compose_context(
         self,
@@ -77,6 +101,7 @@ class ContextManager(ServiceManager):
         await self._context_composer_cls.compose(
             memory_service=self._memory_service,
             session_history_service=self._session_history_service,
+            rag_service=self._rag_service,
             session=session,
             request_input=request_input,
         )
@@ -119,10 +144,12 @@ class ContextManager(ServiceManager):
 async def create_context_manager(
     memory_service: MemoryService = None,
     session_history_service: SessionHistoryService = None,
+    rag_service: RAGService = None,
 ):
     manager = ContextManager(
         memory_service=memory_service,
         session_history_service=session_history_service,
+        rag_service=rag_service,
     )
 
     async with manager:
```
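
These hunks thread an optional `rag_service` through `ContextComposer.compose`, the `ContextManager` constructor, and `create_context_manager`; when set, the composer retrieves five documents for the latest input message, joins them with newlines, and appends them to the session as a system message. A hedged wiring sketch, using the `LangChainRAGService` introduced in the `rag_service.py` diff below (the Milvus URI is illustrative):

```python
# Hedged sketch: enabling RAG-augmented context composition in 0.1.2.
from agentscope_runtime.engine.services.context_manager import ContextManager
from agentscope_runtime.engine.services.rag_service import LangChainRAGService

manager = ContextManager(
    # session_history_service and memory_service keep their defaults (None);
    # rag_service is the new optional hook, registered as the "rag" service.
    rag_service=LangChainRAGService(uri="milvus_demo.db"),
)
```

`create_context_manager` accepts the same `rag_service` keyword and simply forwards it to `ContextManager`.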

agentscope_runtime/engine/services/memory_service.py

```diff
@@ -58,8 +58,8 @@ class MemoryService(ServiceWithLifecycleManager):
         Args:
             user_id: The user id.
             messages: The user query or the query with history messages,
-
-
+                both in the format of list of messages. If messages is a list,
+                the search will be based on the content of the last message.
             filters: The filters used to search memory
         """
 
```

agentscope_runtime/engine/services/rag_service.py (new file)

```diff
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+from typing import Optional
+
+from .base import ServiceWithLifecycleManager
+from ..schemas.agent_schemas import Message, MessageType
+
+
+class RAGService(ServiceWithLifecycleManager):
+    """
+    RAG Service
+    """
+
+    async def get_query_text(self, message: Message) -> str:
+        """
+        Gets the query text from the messages.
+
+        Args:
+            message: A list of messages.
+
+        Returns:
+            The query text.
+        """
+        if message:
+            if message.type == MessageType.MESSAGE:
+                for content in message.content:
+                    if content.type == "text":
+                        return content.text
+        return ""
+
+    async def retrieve(self, query: str, k: int = 1) -> list[str]:
+        raise NotImplementedError
+
+
+DEFAULT_URI = "milvus_demo.db"
+
+
+class LangChainRAGService(RAGService):
+    """
+    RAG Service using LangChain
+    """
+
+    def __init__(
+        self,
+        uri: Optional[str] = None,
+        docs: Optional[list[str]] = None,
+    ):
+        from langchain_community.embeddings import DashScopeEmbeddings
+        from langchain_milvus import Milvus
+
+        self.Milvus = Milvus
+        self.embeddings = DashScopeEmbeddings()
+        self.vectorstore = None
+
+        if uri:
+            self.uri = uri
+            self.from_db()
+        elif docs:
+            self.uri = DEFAULT_URI
+            self.from_docs(docs)
+        else:
+            docs = []
+            self.uri = DEFAULT_URI
+            self.from_docs(docs)
+
+    def from_docs(self, docs=None):
+        if docs is None:
+            docs = []
+
+        self.vectorstore = self.Milvus.from_documents(
+            documents=docs,
+            embedding=self.embeddings,
+            connection_args={
+                "uri": self.uri,
+            },
+            drop_old=False,
+        )
+
+    def from_db(self):
+        self.vectorstore = self.Milvus(
+            embedding_function=self.embeddings,
+            connection_args={"uri": self.uri},
+            index_params={"index_type": "FLAT", "metric_type": "L2"},
+        )
+
+    async def retrieve(self, query: str, k: int = 1) -> list[str]:
+        if self.vectorstore is None:
+            raise ValueError(
+                "Vector store not initialized. Call build_index first.",
+            )
+        docs = self.vectorstore.similarity_search(query, k=k)
+        return [doc.page_content for doc in docs]
+
+    async def start(self) -> None:
+        """Starts the service."""
+
+    async def stop(self) -> None:
+        """Stops the service."""
+
+    async def health(self) -> bool:
+        """Checks the health of the service."""
+        return True
```
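
`LangChainRAGService` lazily imports LangChain's Milvus integration and DashScope embeddings inside `__init__`: `from_db()` opens an existing Milvus (Lite) database at `uri`, `from_docs()` builds one from documents, and `retrieve()` wraps `similarity_search`. A hedged usage sketch, assuming `langchain-milvus`, `langchain-community`, and a valid `DASHSCOPE_API_KEY` are available; the database path and query text are illustrative.

```python
# Hedged sketch: querying the new LangChainRAGService against an existing
# Milvus Lite file. Requires the LangChain extras noted above.
import asyncio

from agentscope_runtime.engine.services.rag_service import LangChainRAGService


async def main() -> None:
    rag = LangChainRAGService(uri="milvus_demo.db")  # opened via from_db()
    docs = await rag.retrieve(query="How do I enable the browser sandbox?", k=3)
    print("\n".join(docs))


asyncio.run(main())
```

Note that `docs` is annotated `Optional[list[str]]` but is forwarded to `Milvus.from_documents`, which expects LangChain `Document` objects, so `Document` instances are likely what the constructor actually needs.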

agentscope_runtime/engine/services/redis_memory_service.py (new file)

```diff
@@ -0,0 +1,187 @@
+# -*- coding: utf-8 -*-
+from typing import Optional, Dict, Any
+import json
+import redis.asyncio as aioredis
+
+
+from .memory_service import MemoryService
+from ..schemas.agent_schemas import Message, MessageType
+
+
+class RedisMemoryService(MemoryService):
+    """
+    A Redis-based implementation of the memory service.
+    """
+
+    def __init__(
+        self,
+        redis_url: str = "redis://localhost:6379/0",
+        redis_client: Optional[aioredis.Redis] = None,
+    ):
+        self._redis_url = redis_url
+        self._redis = redis_client
+        self._DEFAULT_SESSION_ID = "default"
+
+    async def start(self) -> None:
+        """Starts the Redis connection."""
+        if self._redis is None:
+            self._redis = aioredis.from_url(
+                self._redis_url,
+                decode_responses=True,
+            )
+
+    async def stop(self) -> None:
+        """Closes the Redis connection."""
+        if self._redis:
+            await self._redis.close()
+            self._redis = None
+
+    async def health(self) -> bool:
+        """Checks the health of the service."""
+
+        if not self._redis:
+            return False
+        try:
+            pong = await self._redis.ping()
+            return pong is True or pong == "PONG"
+        except Exception:
+            return False
+
+    def _user_key(self, user_id):
+        # Each user is a Redis hash
+        return f"user_memory:{user_id}"
+
+    def _serialize(self, messages):
+        return json.dumps([msg.dict() for msg in messages])
+
+    def _deserialize(self, messages_json):
+        if not messages_json:
+            return []
+        return [Message.parse_obj(m) for m in json.loads(messages_json)]
+
+    async def add_memory(
+        self,
+        user_id: str,
+        messages: list,
+        session_id: Optional[str] = None,
+    ) -> None:
+        if not self._redis:
+            raise RuntimeError("Redis connection is not available")
+        key = self._user_key(user_id)
+        field = session_id if session_id else self._DEFAULT_SESSION_ID
+
+        existing_json = await self._redis.hget(key, field)
+        existing_msgs = self._deserialize(existing_json)
+        all_msgs = existing_msgs + messages
+        await self._redis.hset(key, field, self._serialize(all_msgs))
+
+    async def search_memory(
+        self,
+        user_id: str,
+        messages: list,
+        filters: Optional[Dict[str, Any]] = None,
+    ) -> list:
+        key = self._user_key(user_id)
+        if (
+            not messages
+            or not isinstance(messages, list)
+            or len(messages) == 0
+        ):
+            return []
+
+        message = messages[-1]
+        query = await self.get_query_text(message)
+        if not query:
+            return []
+
+        keywords = set(query.lower().split())
+
+        all_msgs = []
+        hash_keys = await self._redis.hkeys(key)
+        for session_id in hash_keys:
+            msgs_json = await self._redis.hget(key, session_id)
+            msgs = self._deserialize(msgs_json)
+            all_msgs.extend(msgs)
+
+        matched_messages = []
+        for msg in all_msgs:
+            candidate_content = await self.get_query_text(msg)
+            if candidate_content:
+                msg_content_lower = candidate_content.lower()
+                if any(keyword in msg_content_lower for keyword in keywords):
+                    matched_messages.append(msg)
+
+        if (
+            filters
+            and "top_k" in filters
+            and isinstance(filters["top_k"], int)
+        ):
+            return matched_messages[-filters["top_k"] :]
+
+        return matched_messages
+
+    async def get_query_text(self, message: Message) -> str:
+        if message:
+            if message.type == MessageType.MESSAGE:
+                for content in message.content:
+                    if content.type == "text":
+                        return content.text
+        return ""
+
+    async def list_memory(
+        self,
+        user_id: str,
+        filters: Optional[Dict[str, Any]] = None,
+    ) -> list:
+        key = self._user_key(user_id)
+        all_msgs = []
+        hash_keys = await self._redis.hkeys(key)
+        for session_id in sorted(hash_keys):
+            msgs_json = await self._redis.hget(key, session_id)
+            msgs = self._deserialize(msgs_json)
+            all_msgs.extend(msgs)
+
+        page_num = filters.get("page_num", 1) if filters else 1
+        page_size = filters.get("page_size", 10) if filters else 10
+
+        start_index = (page_num - 1) * page_size
+        end_index = start_index + page_size
+
+        return all_msgs[start_index:end_index]
+
+    async def delete_memory(
+        self,
+        user_id: str,
+        session_id: Optional[str] = None,
+    ) -> None:
+        key = self._user_key(user_id)
+        if session_id:
+            await self._redis.hdel(key, session_id)
+        else:
+            await self._redis.delete(key)
+
+    async def clear_all_memory(self) -> None:
+        """
+        Clears all memory data from Redis.
+        This method removes all user memory keys from the Redis database.
+        """
+        if not self._redis:
+            raise RuntimeError("Redis connection is not available")
+
+        keys = await self._redis.keys(self._user_key("*"))
+        if keys:
+            await self._redis.delete(*keys)
+
+    async def delete_user_memory(self, user_id: str) -> None:
+        """
+        Deletes all memory data for a specific user.
+
+        Args:
+            user_id (str): The ID of the user whose memory data should be
+                deleted
+        """
+        if not self._redis:
+            raise RuntimeError("Redis connection is not available")
+
+        key = self._user_key(user_id)
+        await self._redis.delete(key)
```
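
`RedisMemoryService` stores each user's history in a Redis hash keyed `user_memory:<user_id>`, with one field per session holding a JSON-serialized message list; `search_memory` keyword-matches the text of the last input message across all of the user's sessions. A hedged end-to-end sketch, assuming a local Redis server and reusing the `Message`/`TextContent` construction pattern shown in the `context_manager.py` hunk above:

```python
# Hedged sketch: exercising the new RedisMemoryService (0.1.2) against a
# local Redis instance. The Message construction mirrors the schema usage
# shown in the context_manager.py hunk and may differ in detail.
import asyncio

from agentscope_runtime.engine.schemas.agent_schemas import (
    ContentType,
    Message,
    MessageType,
    TextContent,
)
from agentscope_runtime.engine.services.redis_memory_service import RedisMemoryService


def text_message(role: str, text: str) -> Message:
    return Message(
        type=MessageType.MESSAGE,
        role=role,
        content=[TextContent(type=ContentType.TEXT, text=text)],
    )


async def main() -> None:
    svc = RedisMemoryService(redis_url="redis://localhost:6379/0")
    await svc.start()

    await svc.add_memory(
        user_id="user-1",
        messages=[text_message("user", "My favorite city is Hangzhou.")],
        session_id="session-1",
    )

    # Keyword search over all of the user's sessions, keyed on the last message.
    hits = await svc.search_memory(
        user_id="user-1",
        messages=[text_message("user", "Which city do I like?")],
        filters={"top_k": 5},
    )
    print(hits)

    await svc.stop()


asyncio.run(main())
```

The docstring change in `memory_service.py` above documents the same contract: when `messages` is a list, the search key is taken from the content of the last message.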