MemoryOS 0.2.0-py3-none-any.whl → 0.2.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of MemoryOS might be problematic.
- {memoryos-0.2.0.dist-info → memoryos-0.2.2.dist-info}/METADATA +67 -26
- memoryos-0.2.2.dist-info/RECORD +169 -0
- memoryos-0.2.2.dist-info/entry_points.txt +3 -0
- memos/__init__.py +1 -1
- memos/api/config.py +562 -0
- memos/api/context/context.py +147 -0
- memos/api/context/dependencies.py +90 -0
- memos/api/exceptions.py +28 -0
- memos/api/mcp_serve.py +502 -0
- memos/api/product_api.py +35 -0
- memos/api/product_models.py +163 -0
- memos/api/routers/__init__.py +1 -0
- memos/api/routers/product_router.py +386 -0
- memos/chunkers/sentence_chunker.py +8 -2
- memos/cli.py +113 -0
- memos/configs/embedder.py +27 -0
- memos/configs/graph_db.py +132 -3
- memos/configs/internet_retriever.py +6 -0
- memos/configs/llm.py +47 -0
- memos/configs/mem_cube.py +1 -1
- memos/configs/mem_os.py +5 -0
- memos/configs/mem_reader.py +9 -0
- memos/configs/mem_scheduler.py +107 -7
- memos/configs/mem_user.py +58 -0
- memos/configs/memory.py +5 -4
- memos/dependency.py +52 -0
- memos/embedders/ark.py +92 -0
- memos/embedders/factory.py +4 -0
- memos/embedders/sentence_transformer.py +8 -2
- memos/embedders/universal_api.py +32 -0
- memos/graph_dbs/base.py +11 -3
- memos/graph_dbs/factory.py +4 -0
- memos/graph_dbs/nebular.py +1364 -0
- memos/graph_dbs/neo4j.py +333 -124
- memos/graph_dbs/neo4j_community.py +300 -0
- memos/llms/base.py +9 -0
- memos/llms/deepseek.py +54 -0
- memos/llms/factory.py +10 -1
- memos/llms/hf.py +170 -13
- memos/llms/hf_singleton.py +114 -0
- memos/llms/ollama.py +4 -0
- memos/llms/openai.py +67 -1
- memos/llms/qwen.py +63 -0
- memos/llms/vllm.py +153 -0
- memos/log.py +1 -1
- memos/mem_cube/general.py +77 -16
- memos/mem_cube/utils.py +109 -0
- memos/mem_os/core.py +251 -51
- memos/mem_os/main.py +94 -12
- memos/mem_os/product.py +1220 -43
- memos/mem_os/utils/default_config.py +352 -0
- memos/mem_os/utils/format_utils.py +1401 -0
- memos/mem_reader/simple_struct.py +18 -10
- memos/mem_scheduler/base_scheduler.py +441 -40
- memos/mem_scheduler/general_scheduler.py +249 -248
- memos/mem_scheduler/modules/base.py +14 -5
- memos/mem_scheduler/modules/dispatcher.py +67 -4
- memos/mem_scheduler/modules/misc.py +104 -0
- memos/mem_scheduler/modules/monitor.py +240 -50
- memos/mem_scheduler/modules/rabbitmq_service.py +319 -0
- memos/mem_scheduler/modules/redis_service.py +32 -22
- memos/mem_scheduler/modules/retriever.py +167 -23
- memos/mem_scheduler/modules/scheduler_logger.py +255 -0
- memos/mem_scheduler/mos_for_test_scheduler.py +140 -0
- memos/mem_scheduler/schemas/__init__.py +0 -0
- memos/mem_scheduler/schemas/general_schemas.py +43 -0
- memos/mem_scheduler/{modules/schemas.py → schemas/message_schemas.py} +63 -61
- memos/mem_scheduler/schemas/monitor_schemas.py +329 -0
- memos/mem_scheduler/utils/__init__.py +0 -0
- memos/mem_scheduler/utils/filter_utils.py +176 -0
- memos/mem_scheduler/utils/misc_utils.py +61 -0
- memos/mem_user/factory.py +94 -0
- memos/mem_user/mysql_persistent_user_manager.py +271 -0
- memos/mem_user/mysql_user_manager.py +500 -0
- memos/mem_user/persistent_factory.py +96 -0
- memos/mem_user/persistent_user_manager.py +260 -0
- memos/mem_user/user_manager.py +4 -4
- memos/memories/activation/item.py +29 -0
- memos/memories/activation/kv.py +10 -3
- memos/memories/activation/vllmkv.py +219 -0
- memos/memories/factory.py +2 -0
- memos/memories/textual/base.py +1 -1
- memos/memories/textual/general.py +43 -97
- memos/memories/textual/item.py +5 -33
- memos/memories/textual/tree.py +22 -12
- memos/memories/textual/tree_text_memory/organize/conflict.py +9 -5
- memos/memories/textual/tree_text_memory/organize/manager.py +26 -18
- memos/memories/textual/tree_text_memory/organize/redundancy.py +25 -44
- memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +50 -48
- memos/memories/textual/tree_text_memory/organize/reorganizer.py +81 -56
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py +6 -3
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +2 -0
- memos/memories/textual/tree_text_memory/retrieve/recall.py +0 -1
- memos/memories/textual/tree_text_memory/retrieve/reranker.py +2 -2
- memos/memories/textual/tree_text_memory/retrieve/retrieval_mid_structs.py +2 -0
- memos/memories/textual/tree_text_memory/retrieve/searcher.py +52 -28
- memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +42 -15
- memos/memories/textual/tree_text_memory/retrieve/utils.py +11 -7
- memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py +62 -58
- memos/memos_tools/dinding_report_bot.py +422 -0
- memos/memos_tools/notification_service.py +44 -0
- memos/memos_tools/notification_utils.py +96 -0
- memos/parsers/markitdown.py +8 -2
- memos/settings.py +3 -1
- memos/templates/mem_reader_prompts.py +66 -23
- memos/templates/mem_scheduler_prompts.py +126 -43
- memos/templates/mos_prompts.py +87 -0
- memos/templates/tree_reorganize_prompts.py +85 -30
- memos/vec_dbs/base.py +12 -0
- memos/vec_dbs/qdrant.py +46 -20
- memoryos-0.2.0.dist-info/RECORD +0 -128
- memos/mem_scheduler/utils.py +0 -26
- {memoryos-0.2.0.dist-info → memoryos-0.2.2.dist-info}/LICENSE +0 -0
- {memoryos-0.2.0.dist-info → memoryos-0.2.2.dist-info}/WHEEL +0 -0
--- /dev/null
+++ b/memos/mem_scheduler/modules/rabbitmq_service.py
@@ -0,0 +1,319 @@
+import json
+import ssl
+import threading
+import time
+
+from pathlib import Path
+
+from memos.configs.mem_scheduler import AuthConfig, RabbitMQConfig
+from memos.dependency import require_python_package
+from memos.log import get_logger
+from memos.mem_scheduler.modules.base import BaseSchedulerModule
+from memos.mem_scheduler.modules.misc import AutoDroppingQueue
+from memos.mem_scheduler.schemas.general_schemas import DIRECT_EXCHANGE_TYPE, FANOUT_EXCHANGE_TYPE
+
+
+logger = get_logger(__name__)
+
+
+class RabbitMQSchedulerModule(BaseSchedulerModule):
+    @require_python_package(
+        import_name="pika",
+        install_command="pip install pika",
+        install_link="https://pika.readthedocs.io/en/stable/index.html",
+    )
+    def __init__(self):
+        """
+        Initialize RabbitMQ connection settings.
+        """
+        super().__init__()
+
+        # RabbitMQ settings
+        self.rabbitmq_config: RabbitMQConfig | None = None
+        self.rabbit_queue_name = "memos-scheduler"
+        self.rabbitmq_exchange_name = "memos-fanout"
+        self.rabbitmq_exchange_type = FANOUT_EXCHANGE_TYPE
+        self.rabbitmq_connection = None
+        self.rabbitmq_channel = None
+
+        # fixed params
+        self.rabbitmq_message_cache_max_size = 10  # Max 10 messages
+        self.rabbitmq_message_cache = AutoDroppingQueue(
+            maxsize=self.rabbitmq_message_cache_max_size
+        )
+        self.rabbitmq_connection_attempts = 3  # Max retry attempts on connection failure
+        self.rabbitmq_retry_delay = 5  # Delay (seconds) between retries
+        self.rabbitmq_heartbeat = 60  # Heartbeat interval (seconds) for connection
+        self.rabbitmq_conn_max_waiting_seconds = 30
+        self.rabbitmq_conn_sleep_seconds = 1
+
+        # Thread management
+        self._rabbitmq_io_loop_thread = None  # For IOLoop execution
+        self._rabbitmq_stop_flag = False  # Graceful shutdown flag
+        self._rabbitmq_lock = threading.Lock()  # Ensure thread safety
+
+    def is_rabbitmq_connected(self) -> bool:
+        """Check if RabbitMQ connection is alive"""
+        return (
+            self.rabbitmq_connection
+            and self.rabbitmq_connection.is_open
+            and self.rabbitmq_channel
+            and self.rabbitmq_channel.is_open
+        )
+
+    def initialize_rabbitmq(
+        self, config: dict | None | RabbitMQConfig = None, config_path: str | Path | None = None
+    ):
+        """
+        Establish connection to RabbitMQ using pika.
+        """
+        from pika.adapters.select_connection import SelectConnection
+
+        if config is None:
+            if config_path is None and AuthConfig.default_config_exists():
+                auth_config = AuthConfig.from_local_yaml()
+            elif Path(config_path).exists():
+                auth_config = AuthConfig.from_local_yaml(config_path=config_path)
+            else:
+                logger.error("Fail to initialize auth_config")
+                return
+            self.rabbitmq_config = auth_config.rabbitmq
+        elif isinstance(config, RabbitMQConfig):
+            self.rabbitmq_config = config
+        elif isinstance(config, dict):
+            self.rabbitmq_config = AuthConfig.from_dict(config).rabbitmq
+        else:
+            logger.error("Not implemented")
+
+        # Start connection process
+        parameters = self.get_rabbitmq_connection_param()
+        self.rabbitmq_connection = SelectConnection(
+            parameters,
+            on_open_callback=self.on_rabbitmq_connection_open,
+            on_open_error_callback=self.on_rabbitmq_connection_error,
+            on_close_callback=self.on_rabbitmq_connection_closed,
+        )
+
+        # Start IOLoop in dedicated thread
+        self._io_loop_thread = threading.Thread(
+            target=self.rabbitmq_connection.ioloop.start, daemon=True
+        )
+        self._io_loop_thread.start()
+        logger.info("RabbitMQ connection process started")
+
+    def get_rabbitmq_queue_size(self) -> int:
+        """Get the current number of messages in the queue.
+
+        Returns:
+            int: Number of messages in the queue.
+                Returns -1 if there's an error or no active connection.
+        """
+        if self.rabbitmq_exchange_type != DIRECT_EXCHANGE_TYPE:
+            logger.warning("Queue size can only be checked for direct exchanges")
+            return None
+
+        with self._rabbitmq_lock:
+            if not self.is_rabbitmq_connected():
+                logger.warning("No active connection to check queue size")
+                return -1
+
+            # Declare queue passively (only checks existence, doesn't create)
+            # Using passive=True prevents accidental queue creation
+            result = self.rabbitmq_channel.queue_declare(
+                queue=self.rabbit_queue_name,
+                durable=True,  # Match the original queue durability setting
+                passive=True,  # Only check queue existence, don't create
+            )
+
+            if result is None:
+                return 0
+            # Return the message count from the queue declaration result
+            return result.method.message_count
+
+    def get_rabbitmq_connection_param(self):
+        import pika
+
+        credentials = pika.PlainCredentials(
+            username=self.rabbitmq_config.user_name,
+            password=self.rabbitmq_config.password,
+            erase_on_connect=self.rabbitmq_config.erase_on_connect,
+        )
+        if self.rabbitmq_config.port == 5671:
+            context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+            context.check_hostname = False
+            context.verify_mode = False
+            return pika.ConnectionParameters(
+                host=self.rabbitmq_config.host_name,
+                port=self.rabbitmq_config.port,
+                virtual_host=self.rabbitmq_config.virtual_host,
+                credentials=credentials,
+                ssl_options=pika.SSLOptions(context),
+                connection_attempts=self.rabbitmq_connection_attempts,
+                retry_delay=self.rabbitmq_retry_delay,
+                heartbeat=self.rabbitmq_heartbeat,
+            )
+        else:
+            return pika.ConnectionParameters(
+                host=self.rabbitmq_config.host_name,
+                port=self.rabbitmq_config.port,
+                virtual_host=self.rabbitmq_config.virtual_host,
+                credentials=credentials,
+                connection_attempts=self.rabbitmq_connection_attempts,
+                retry_delay=self.rabbitmq_retry_delay,
+                heartbeat=self.rabbitmq_heartbeat,
+            )
+
+    # Connection lifecycle callbacks
+    def on_rabbitmq_connection_open(self, connection):
+        """Called when connection is established."""
+        logger.debug("Connection opened")
+        connection.channel(on_open_callback=self.on_rabbitmq_channel_open)
+
+    def on_rabbitmq_connection_error(self, connection, error):
+        """Called if connection fails to open."""
+        logger.error(f"Connection failed: {error}")
+        self.rabbit_reconnect()
+
+    def on_rabbitmq_connection_closed(self, connection, reason):
+        """Called when connection closes."""
+        logger.warning(f"Connection closed: {reason}")
+        if not self._rabbitmq_stop_flag:
+            self.rabbit_reconnect()
+
+    # Channel lifecycle callbacks
+    def on_rabbitmq_channel_open(self, channel):
+        """Called when channel is ready."""
+        self.rabbitmq_channel = channel
+        logger.debug("Channel opened")
+
+        # Setup exchange and queue
+        channel.exchange_declare(
+            exchange=self.rabbitmq_exchange_name,
+            exchange_type=self.rabbitmq_exchange_type,
+            durable=True,
+            callback=self.on_rabbitmq_exchange_declared,
+        )
+
+    def on_rabbitmq_exchange_declared(self, frame):
+        """Called when exchange is ready."""
+        self.rabbitmq_channel.queue_declare(
+            queue=self.rabbit_queue_name, durable=True, callback=self.on_rabbitmq_queue_declared
+        )
+
+    def on_rabbitmq_queue_declared(self, frame):
+        """Called when queue is ready."""
+        self.rabbitmq_channel.queue_bind(
+            exchange=self.rabbitmq_exchange_name,
+            queue=self.rabbit_queue_name,
+            routing_key=self.rabbit_queue_name,
+            callback=self.on_rabbitmq_bind_ok,
+        )
+
+    def on_rabbitmq_bind_ok(self, frame):
+        """Final setup step when bind is complete."""
+        logger.info("RabbitMQ setup completed")
+
+    def on_rabbitmq_message(self, channel, method, properties, body):
+        """Handle incoming messages. Only for test."""
+        try:
+            print(f"Received message: {body.decode()}\n")
+            self.rabbitmq_message_cache.put({"properties": properties, "body": body})
+            print(f"message delivery_tag: {method.delivery_tag}\n")
+            channel.basic_ack(delivery_tag=method.delivery_tag)
+        except Exception as e:
+            logger.error(f"Message handling failed: {e}", exc_info=True)
+
+    def wait_for_connection_ready(self):
+        start_time = time.time()
+        while not self.is_rabbitmq_connected():
+            delta_time = time.time() - start_time
+            if delta_time > self.rabbitmq_conn_max_waiting_seconds:
+                logger.error("Failed to start consuming: Connection timeout")
+                return False
+            self.rabbit_reconnect()
+            time.sleep(self.rabbitmq_conn_sleep_seconds)  # Reduced frequency of checks
+
+    # Message handling
+    def rabbitmq_start_consuming(self):
+        """Start consuming messages asynchronously."""
+        self.wait_for_connection_ready()
+
+        self.rabbitmq_channel.basic_consume(
+            queue=self.rabbit_queue_name,
+            on_message_callback=self.on_rabbitmq_message,
+            auto_ack=False,
+        )
+        logger.info("Started rabbitmq consuming messages")
+
+    def rabbitmq_publish_message(self, message: dict):
+        """
+        Publish a message to RabbitMQ.
+        """
+        import pika
+
+        with self._rabbitmq_lock:
+            if not self.is_rabbitmq_connected():
+                logger.error("Cannot publish - no active connection")
+                return False
+
+            try:
+                self.rabbitmq_channel.basic_publish(
+                    exchange=self.rabbitmq_exchange_name,
+                    routing_key=self.rabbit_queue_name,
+                    body=json.dumps(message),
+                    properties=pika.BasicProperties(
+                        delivery_mode=2,  # Persistent
+                    ),
+                    mandatory=True,
+                )
+                logger.debug(f"Published message: {message}")
+                return True
+            except Exception as e:
+                logger.error(f"Failed to publish message: {e}")
+                self.rabbit_reconnect()
+                return False
+
+    # Connection management
+    def rabbit_reconnect(self):
+        """Schedule reconnection attempt."""
+        logger.info("Attempting to reconnect...")
+        if self.rabbitmq_connection and not self.rabbitmq_connection.is_closed:
+            self.rabbitmq_connection.ioloop.stop()
+
+        # Reset connection state
+        self.rabbitmq_channel = None
+        self.initialize_rabbitmq()
+
+    def rabbitmq_close(self):
+        """Gracefully shutdown connection."""
+        with self._rabbitmq_lock:
+            self._rabbitmq_stop_flag = True
+
+            # Close channel if open
+            if self.rabbitmq_channel and self.rabbitmq_channel.is_open:
+                try:
+                    self.rabbitmq_channel.close()
+                except Exception as e:
+                    logger.warning(f"Error closing channel: {e}")
+
+            # Close connection if open
+            if self.rabbitmq_connection:
+                if self.rabbitmq_connection.is_open:
+                    try:
+                        self.rabbitmq_connection.close()
+                    except Exception as e:
+                        logger.warning(f"Error closing connection: {e}")
+
+                # Stop IOLoop if running
+                try:
+                    self.rabbitmq_connection.ioloop.stop()
+                except Exception as e:
+                    logger.warning(f"Error stopping IOLoop: {e}")
+
+            # Wait for IOLoop thread to finish
+            if self._io_loop_thread and self._io_loop_thread.is_alive():
+                self._io_loop_thread.join(timeout=5)
+                if self._io_loop_thread.is_alive():
+                    logger.warning("IOLoop thread did not terminate cleanly")
+
+        logger.info("RabbitMQ connection closed")
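
The new `rabbitmq_service.py` drives pika's callback-based `SelectConnection` from a daemon IOLoop thread, so callers only touch the synchronous wrappers (`initialize_rabbitmq`, `wait_for_connection_ready`, `rabbitmq_publish_message`, `rabbitmq_close`). A minimal usage sketch against a local broker follows; the config layout (a `"rabbitmq"` key holding the fields read in `get_rabbitmq_connection_param`) is an assumption inferred from `AuthConfig.from_dict(config).rabbitmq`, not a documented API.

```python
# Hypothetical usage of the new RabbitMQSchedulerModule; only methods that
# appear in the diff above are called. The config dict shape is assumed.
from memos.mem_scheduler.modules.rabbitmq_service import RabbitMQSchedulerModule

module = RabbitMQSchedulerModule()
module.initialize_rabbitmq(
    config={
        "rabbitmq": {  # assumed nesting, parsed via AuthConfig.from_dict
            "host_name": "localhost",
            "port": 5672,  # the 5671 branch would wrap the socket in SSLOptions
            "user_name": "guest",
            "password": "guest",
            "virtual_host": "/",
            "erase_on_connect": True,
        }
    }
)

# Spin until the on_open/on_channel_open callbacks have fired on the IOLoop
# thread, or give up after rabbitmq_conn_max_waiting_seconds (30 s).
module.wait_for_connection_ready()

# Publish a persistent (delivery_mode=2) JSON message to "memos-fanout".
module.rabbitmq_publish_message({"event": "mem_update", "user_id": "u1"})

# Test helper: acks incoming messages and buffers the last 10 of them
# in the bounded AutoDroppingQueue.
module.rabbitmq_start_consuming()

module.rabbitmq_close()
```

One sharp edge worth noting for reviewers: `__init__` stores the IOLoop thread as `_rabbitmq_io_loop_thread`, while `initialize_rabbitmq` and `rabbitmq_close` use `_io_loop_thread`, so the attribute initialized in the constructor is never the one joined on shutdown.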
--- a/memos/mem_scheduler/modules/redis_service.py
+++ b/memos/mem_scheduler/modules/redis_service.py
@@ -2,11 +2,9 @@ import asyncio
 import threading
 
 from collections.abc import Callable
+from typing import Any
 
-import
-
-from redis import Redis
-
+from memos.dependency import require_python_package
 from memos.log import get_logger
 from memos.mem_scheduler.modules.base import BaseSchedulerModule
 
@@ -15,6 +13,11 @@ logger = get_logger(__name__)
 
 
 class RedisSchedulerModule(BaseSchedulerModule):
+    @require_python_package(
+        import_name="redis",
+        install_command="pip install redis",
+        install_link="https://redis.readthedocs.io/en/stable/",
+    )
     def __init__(self):
         """
         intent_detector: Object used for intent recognition (such as the above IntentDetector)
@@ -35,23 +38,25 @@ class RedisSchedulerModule(BaseSchedulerModule):
         self._redis_listener_loop: asyncio.AbstractEventLoop | None = None
 
     @property
-    def redis(self) -> Redis:
+    def redis(self) -> Any:
         return self._redis_conn
 
     @redis.setter
-    def redis(self, value: Redis) -> None:
+    def redis(self, value: Any) -> None:
         self._redis_conn = value
 
     def initialize_redis(
        self, redis_host: str = "localhost", redis_port: int = 6379, redis_db: int = 0
     ):
+        import redis
+
         self.redis_host = redis_host
         self.redis_port = redis_port
         self.redis_db = redis_db
 
         try:
             logger.debug(f"Connecting to Redis at {redis_host}:{redis_port}/{redis_db}")
-            self._redis_conn = Redis(
+            self._redis_conn = redis.Redis(
                 host=self.redis_host, port=self.redis_port, db=self.redis_db, decode_responses=True
             )
             # test conn
@@ -63,21 +68,21 @@ class RedisSchedulerModule(BaseSchedulerModule):
             self._redis_conn.xtrim("user:queries:stream", self.query_list_capacity)
             return self._redis_conn
 
-    async def add_message_stream(self, message: dict):
+    async def redis_add_message_stream(self, message: dict):
         logger.debug(f"add_message_stream: {message}")
         return self._redis_conn.xadd("user:queries:stream", message)
 
-    async def consume_message_stream(self, message: dict):
+    async def redis_consume_message_stream(self, message: dict):
         logger.debug(f"consume_message_stream: {message}")
 
-    def _run_listener_async(self, handler: Callable):
+    def _redis_run_listener_async(self, handler: Callable):
         """Run the async listener in a separate thread"""
         self._redis_listener_loop = asyncio.new_event_loop()
         asyncio.set_event_loop(self._redis_listener_loop)
 
         async def listener_wrapper():
             try:
-                await self.__listen_query_stream(handler)
+                await self.__redis_listen_query_stream(handler)
             except Exception as e:
                 logger.error(f"Listener thread error: {e}")
             finally:
@@ -85,8 +90,12 @@ class RedisSchedulerModule(BaseSchedulerModule):
 
         self._redis_listener_loop.run_until_complete(listener_wrapper())
 
-    async def __listen_query_stream(self, handler=None, last_id: str = "$", block_time: int = 2000):
+    async def __redis_listen_query_stream(
+        self, handler=None, last_id: str = "$", block_time: int = 2000
+    ):
         """Internal async stream listener"""
+        import redis
+
         self._redis_listener_running = True
         while self._redis_listener_running:
             try:
@@ -99,6 +108,7 @@ class RedisSchedulerModule(BaseSchedulerModule):
                 for _, stream_messages in messages:
                     for message_id, message_data in stream_messages:
                         try:
+                            print(f"deal with message_data {message_data}")
                             await handler(message_data)
                             last_id = message_id
                         except Exception as e:
@@ -112,17 +122,17 @@ class RedisSchedulerModule(BaseSchedulerModule):
                 logger.error(f"Unexpected error: {e}")
                 await asyncio.sleep(1)
 
-    def start_listening(self, handler: Callable | None = None):
+    def redis_start_listening(self, handler: Callable | None = None):
         """Start the Redis stream listener in a background thread"""
         if self._redis_listener_thread and self._redis_listener_thread.is_alive():
             logger.warning("Listener is already running")
             return
 
         if handler is None:
-            handler = self.consume_message_stream
+            handler = self.redis_consume_message_stream
 
         self._redis_listener_thread = threading.Thread(
-            target=self._run_listener_async,
+            target=self._redis_run_listener_async,
             args=(handler,),
             daemon=True,
             name="RedisListenerThread",
@@ -130,13 +140,7 @@ class RedisSchedulerModule(BaseSchedulerModule):
         self._redis_listener_thread.start()
         logger.info("Started Redis stream listener thread")
 
-    def close(self):
-        """Close Redis connection"""
-        if self._redis_conn is not None:
-            self._redis_conn.close()
-            self._redis_conn = None
-
-    def stop_listening(self):
+    def redis_stop_listening(self):
         """Stop the listener thread gracefully"""
         self._redis_listener_running = False
         if self._redis_listener_thread and self._redis_listener_thread.is_alive():
@@ -144,3 +148,9 @@ class RedisSchedulerModule(BaseSchedulerModule):
             if self._redis_listener_thread.is_alive():
                 logger.warning("Listener thread did not stop gracefully")
         logger.info("Redis stream listener stopped")
+
+    def redis_close(self):
+        """Close Redis connection"""
+        if self._redis_conn is not None:
+            self._redis_conn.close()
+            self._redis_conn = None
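
`redis_service.py` gets the same optional-dependency treatment: the module-level `from redis import Redis` is gone, the client is imported lazily inside the methods that need it, and the new `require_python_package` decorator (from `memos/dependency.py`, +52 lines above, not shown in this diff) guards `__init__`. The public methods also gain a `redis_` prefix. Below is a sketch of what a decorator with this call signature could look like; it is an illustration under stated assumptions, not the package's actual implementation.

```python
import functools
import importlib.util


def require_python_package(import_name: str, install_command: str, install_link: str):
    """Illustrative stand-in for memos.dependency.require_python_package.

    Only the call signature is taken from the usage in this diff; the real
    implementation ships in memos/dependency.py and may differ.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Check availability without importing the package eagerly.
            if importlib.util.find_spec(import_name) is None:
                raise ImportError(
                    f"'{import_name}' is required by {func.__qualname__}. "
                    f"Install it with `{install_command}` (see {install_link})."
                )
            return func(*args, **kwargs)

        return wrapper

    return decorator
```

The effect either way: importing `memos` no longer hard-fails on machines without `redis` or `pika`; the error surfaces only when the optional scheduler module is actually constructed.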
--- a/memos/mem_scheduler/modules/retriever.py
+++ b/memos/mem_scheduler/modules/retriever.py
@@ -1,41 +1,185 @@
+from memos.configs.mem_scheduler import BaseSchedulerConfig
+from memos.llms.base import BaseLLM
 from memos.log import get_logger
+from memos.mem_cube.general import GeneralMemCube
 from memos.mem_scheduler.modules.base import BaseSchedulerModule
+from memos.mem_scheduler.schemas.general_schemas import (
+    TreeTextMemory_SEARCH_METHOD,
+)
+from memos.mem_scheduler.utils.filter_utils import (
+    filter_similar_memories,
+    filter_too_short_memories,
+    transform_name_to_key,
+)
+from memos.mem_scheduler.utils.misc_utils import (
+    extract_json_dict,
+)
+from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory
 
 
 logger = get_logger(__name__)
 
 
 class SchedulerRetriever(BaseSchedulerModule):
-    def __init__(self,
+    def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig):
+        super().__init__()
+
+        # hyper-parameters
+        self.filter_similarity_threshold = 0.75
+        self.filter_min_length_threshold = 6
+
+        self.config: BaseSchedulerConfig = config
+        self.process_llm = process_llm
+
+    def search(
+        self, query: str, mem_cube: GeneralMemCube, top_k: int, method=TreeTextMemory_SEARCH_METHOD
+    ) -> list[TextualMemoryItem]:
+        """Search in text memory with the given query.
+
+        Args:
+            query: The search query string
+            top_k: Number of top results to return
+            method: Search method to use
+
+        Returns:
+            Search results or None if not implemented
         """
-
-
-
+        text_mem_base = mem_cube.text_mem
+        try:
+            if method == TreeTextMemory_SEARCH_METHOD:
+                assert isinstance(text_mem_base, TreeTextMemory)
+                results_long_term = text_mem_base.search(
+                    query=query, top_k=top_k, memory_type="LongTermMemory"
+                )
+                results_user = text_mem_base.search(
+                    query=query, top_k=top_k, memory_type="UserMemory"
+                )
+                results = results_long_term + results_user
+            else:
+                raise NotImplementedError(str(type(text_mem_base)))
+        except Exception as e:
+            logger.error(f"Fail to search. The exeption is {e}.", exc_info=True)
+            results = []
+        return results
+
+    def rerank_memories(
+        self,
+        queries: list[str],
+        original_memories: list[str],
+        top_k: int,
+    ) -> (list[str], bool):
         """
-
+        Rerank memories based on relevance to given queries using LLM.
+
+        Args:
+            queries: List of query strings to determine relevance
+            original_memories: List of memory strings to be reranked
+            top_k: Number of top memories to return after reranking
 
-
-
+        Returns:
+            List of reranked memory strings (length <= top_k)
+
+        Note:
+            If LLM reranking fails, falls back to original order (truncated to top_k)
+        """
+        success_flag = False
 
-
-        self._current_mem_cube = None
+        logger.info(f"Starting memory reranking for {len(original_memories)} memories")
 
-
-
-
-
+        # Build LLM prompt for memory reranking
+        prompt = self.build_prompt(
+            "memory_reranking",
+            queries=[f"[0] {queries[0]}"],
+            current_order=[f"[{i}] {mem}" for i, mem in enumerate(original_memories)],
+        )
+        logger.debug(f"Generated reranking prompt: {prompt[:200]}...")  # Log first 200 chars
 
-
-
-        "
-        self._memory_text_list = value
+        # Get LLM response
+        response = self.process_llm.generate([{"role": "user", "content": prompt}])
+        logger.debug(f"Received LLM response: {response[:200]}...")  # Log first 200 chars
 
-
+        try:
+            # Parse JSON response
+            response = extract_json_dict(response)
+            new_order = response["new_order"][:top_k]
+            text_memories_with_new_order = [original_memories[idx] for idx in new_order]
+            logger.info(
+                f"Successfully reranked memories. Returning top {len(text_memories_with_new_order)} items;"
+                f"Ranking reasoning: {response['reasoning']}"
+            )
+            success_flag = True
+        except Exception as e:
+            logger.error(
+                f"Failed to rerank memories with LLM. Exception: {e}. Raw response: {response} ",
+                exc_info=True,
+            )
+            text_memories_with_new_order = original_memories[:top_k]
+            success_flag = False
+        return text_memories_with_new_order, success_flag
+
+    def process_and_rerank_memories(
+        self,
+        queries: list[str],
+        original_memory: list[TextualMemoryItem],
+        new_memory: list[TextualMemoryItem],
+        top_k: int = 10,
+    ) -> list[TextualMemoryItem] | None:
         """
-
-
+        Process and rerank memory items by combining original and new memories,
+        applying filters, and then reranking based on relevance to queries.
+
+        Args:
+            queries: List of query strings to rerank memories against
+            original_memory: List of original TextualMemoryItem objects
+            new_memory: List of new TextualMemoryItem objects to merge
+            top_k: Maximum number of memories to return after reranking
+
+        Returns:
+            List of reranked TextualMemoryItem objects, or None if processing fails
         """
-
+        # Combine original and new memories into a single list
+        combined_memory = original_memory + new_memory
+
+        # Create a mapping from normalized text to memory objects
+        memory_map = {
+            transform_name_to_key(name=mem_obj.memory): mem_obj for mem_obj in combined_memory
+        }
+
+        # Extract normalized text representations from all memory items
+        combined_text_memory = [m.memory for m in combined_memory]
+
+        # Apply similarity filter to remove overly similar memories
+        filtered_combined_text_memory = filter_similar_memories(
+            text_memories=combined_text_memory,
+            similarity_threshold=self.filter_similarity_threshold,
+        )
+
+        # Apply length filter to remove memories that are too short
+        filtered_combined_text_memory = filter_too_short_memories(
+            text_memories=filtered_combined_text_memory,
+            min_length_threshold=self.filter_min_length_threshold,
+        )
+
+        # Ensure uniqueness of memory texts using dictionary keys (preserves order)
+        unique_memory = list(dict.fromkeys(filtered_combined_text_memory))
+
+        # Rerank the filtered memories based on relevance to the queries
+        text_memories_with_new_order, success_flag = self.rerank_memories(
+            queries=queries,
+            original_memories=unique_memory,
+            top_k=top_k,
+        )
+
+        # Map reranked text entries back to their original memory objects
+        memories_with_new_order = []
+        for text in text_memories_with_new_order:
+            normalized_text = transform_name_to_key(name=text)
+            if normalized_text in memory_map:  # Ensure correct key matching
+                memories_with_new_order.append(memory_map[normalized_text])
+            else:
+                logger.warning(
+                    f"Memory text not found in memory map. text: {text};\n"
+                    f"Keys of memory_map: {memory_map.keys()}"
+                )
 
-
-        return None
+        return memories_with_new_order, success_flag
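
The reworked `SchedulerRetriever` turns reranking into a prompt/parse round trip: `build_prompt("memory_reranking", ...)` asks the LLM for a JSON object with a `new_order` index list and a `reasoning` string, and any parse failure falls back to the original order truncated to `top_k`. A self-contained sketch of that response contract, with a stubbed LLM reply in place of `process_llm.generate(...)` (plain `json.loads` stands in for `extract_json_dict`, which is assumed to also tolerate formatting around the JSON):

```python
import json

# Candidate memories in their pre-rerank order.
original_memories = [
    "User plays tennis on Saturdays.",
    "User dislikes spicy food.",
    "User hikes with friends on weekends.",
]
top_k = 2

# Stubbed LLM output matching the schema rerank_memories() parses.
raw_response = json.dumps(
    {"new_order": [2, 0, 1], "reasoning": "Weekend activities ranked first."}
)

try:
    parsed = json.loads(raw_response)  # stand-in for extract_json_dict
    new_order = parsed["new_order"][:top_k]
    reranked = [original_memories[i] for i in new_order]
except Exception:
    # The fallback path in the diff: keep the original order, truncated.
    reranked = original_memories[:top_k]

assert reranked == [
    "User hikes with friends on weekends.",
    "User plays tennis on Saturdays.",
]
```

Two reviewer notes visible in the hunk above: `rerank_memories` only ever feeds `queries[0]` into the prompt, and `process_and_rerank_memories` returns a `(list, bool)` tuple despite its `list[TextualMemoryItem] | None` annotation.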