MemoryOS 0.2.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of MemoryOS might be problematic. Click here for more details.
- {memoryos-0.2.1.dist-info → memoryos-1.0.0.dist-info}/METADATA +7 -1
- {memoryos-0.2.1.dist-info → memoryos-1.0.0.dist-info}/RECORD +87 -64
- memos/__init__.py +1 -1
- memos/api/config.py +158 -69
- memos/api/context/context.py +147 -0
- memos/api/context/dependencies.py +101 -0
- memos/api/product_models.py +5 -1
- memos/api/routers/product_router.py +54 -26
- memos/configs/graph_db.py +49 -1
- memos/configs/internet_retriever.py +19 -0
- memos/configs/mem_os.py +5 -0
- memos/configs/mem_reader.py +9 -0
- memos/configs/mem_scheduler.py +54 -18
- memos/configs/mem_user.py +58 -0
- memos/graph_dbs/base.py +38 -3
- memos/graph_dbs/factory.py +2 -0
- memos/graph_dbs/nebular.py +1612 -0
- memos/graph_dbs/neo4j.py +18 -9
- memos/log.py +6 -1
- memos/mem_cube/utils.py +13 -6
- memos/mem_os/core.py +157 -37
- memos/mem_os/main.py +2 -2
- memos/mem_os/product.py +252 -201
- memos/mem_os/utils/default_config.py +1 -1
- memos/mem_os/utils/format_utils.py +281 -70
- memos/mem_os/utils/reference_utils.py +133 -0
- memos/mem_reader/simple_struct.py +13 -5
- memos/mem_scheduler/base_scheduler.py +239 -266
- memos/mem_scheduler/{modules → general_modules}/base.py +4 -5
- memos/mem_scheduler/{modules → general_modules}/dispatcher.py +57 -21
- memos/mem_scheduler/general_modules/misc.py +104 -0
- memos/mem_scheduler/{modules → general_modules}/rabbitmq_service.py +12 -10
- memos/mem_scheduler/{modules → general_modules}/redis_service.py +1 -1
- memos/mem_scheduler/general_modules/retriever.py +199 -0
- memos/mem_scheduler/general_modules/scheduler_logger.py +261 -0
- memos/mem_scheduler/general_scheduler.py +243 -80
- memos/mem_scheduler/monitors/__init__.py +0 -0
- memos/mem_scheduler/monitors/dispatcher_monitor.py +305 -0
- memos/mem_scheduler/{modules/monitor.py → monitors/general_monitor.py} +106 -57
- memos/mem_scheduler/mos_for_test_scheduler.py +23 -20
- memos/mem_scheduler/schemas/__init__.py +0 -0
- memos/mem_scheduler/schemas/general_schemas.py +44 -0
- memos/mem_scheduler/schemas/message_schemas.py +149 -0
- memos/mem_scheduler/schemas/monitor_schemas.py +337 -0
- memos/mem_scheduler/utils/__init__.py +0 -0
- memos/mem_scheduler/utils/filter_utils.py +176 -0
- memos/mem_scheduler/utils/misc_utils.py +102 -0
- memos/mem_user/factory.py +94 -0
- memos/mem_user/mysql_persistent_user_manager.py +271 -0
- memos/mem_user/mysql_user_manager.py +500 -0
- memos/mem_user/persistent_factory.py +96 -0
- memos/mem_user/user_manager.py +4 -4
- memos/memories/activation/item.py +5 -1
- memos/memories/activation/kv.py +20 -8
- memos/memories/textual/base.py +2 -2
- memos/memories/textual/general.py +36 -92
- memos/memories/textual/item.py +5 -33
- memos/memories/textual/tree.py +13 -7
- memos/memories/textual/tree_text_memory/organize/{conflict.py → handler.py} +34 -50
- memos/memories/textual/tree_text_memory/organize/manager.py +8 -96
- memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +49 -43
- memos/memories/textual/tree_text_memory/organize/reorganizer.py +107 -142
- memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +229 -0
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py +6 -3
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +11 -0
- memos/memories/textual/tree_text_memory/retrieve/recall.py +15 -8
- memos/memories/textual/tree_text_memory/retrieve/reranker.py +1 -1
- memos/memories/textual/tree_text_memory/retrieve/retrieval_mid_structs.py +2 -0
- memos/memories/textual/tree_text_memory/retrieve/searcher.py +191 -116
- memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +47 -15
- memos/memories/textual/tree_text_memory/retrieve/utils.py +11 -7
- memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py +62 -58
- memos/memos_tools/dinding_report_bot.py +422 -0
- memos/memos_tools/lockfree_dict.py +120 -0
- memos/memos_tools/notification_service.py +44 -0
- memos/memos_tools/notification_utils.py +96 -0
- memos/memos_tools/thread_safe_dict.py +288 -0
- memos/settings.py +3 -1
- memos/templates/mem_reader_prompts.py +4 -1
- memos/templates/mem_scheduler_prompts.py +62 -15
- memos/templates/mos_prompts.py +116 -0
- memos/templates/tree_reorganize_prompts.py +24 -17
- memos/utils.py +19 -0
- memos/mem_scheduler/modules/misc.py +0 -39
- memos/mem_scheduler/modules/retriever.py +0 -268
- memos/mem_scheduler/modules/schemas.py +0 -328
- memos/mem_scheduler/utils.py +0 -75
- memos/memories/textual/tree_text_memory/organize/redundancy.py +0 -193
- {memoryos-0.2.1.dist-info → memoryos-1.0.0.dist-info}/LICENSE +0 -0
- {memoryos-0.2.1.dist-info → memoryos-1.0.0.dist-info}/WHEEL +0 -0
- {memoryos-0.2.1.dist-info → memoryos-1.0.0.dist-info}/entry_points.txt +0 -0
- /memos/mem_scheduler/{modules → general_modules}/__init__.py +0 -0
|
@@ -3,7 +3,7 @@ from pathlib import Path
|
|
|
3
3
|
from memos.llms.base import BaseLLM
|
|
4
4
|
from memos.log import get_logger
|
|
5
5
|
from memos.mem_cube.general import GeneralMemCube
|
|
6
|
-
from memos.mem_scheduler.
|
|
6
|
+
from memos.mem_scheduler.schemas.general_schemas import BASE_DIR
|
|
7
7
|
from memos.templates.mem_scheduler_prompts import PROMPT_MAPPING
|
|
8
8
|
|
|
9
9
|
|
|
@@ -17,8 +17,7 @@ class BaseSchedulerModule:
|
|
|
17
17
|
|
|
18
18
|
self._chat_llm = None
|
|
19
19
|
self._process_llm = None
|
|
20
|
-
|
|
21
|
-
self._current_mem_cube: GeneralMemCube | None = None
|
|
20
|
+
|
|
22
21
|
self.mem_cubes: dict[str, GeneralMemCube] = {}
|
|
23
22
|
|
|
24
23
|
def load_template(self, template_name: str) -> str:
|
|
@@ -75,9 +74,9 @@ class BaseSchedulerModule:
|
|
|
75
74
|
@property
|
|
76
75
|
def mem_cube(self) -> GeneralMemCube:
|
|
77
76
|
"""The memory cube associated with this MemChat."""
|
|
78
|
-
return self.
|
|
77
|
+
return self.current_mem_cube
|
|
79
78
|
|
|
80
79
|
@mem_cube.setter
|
|
81
80
|
def mem_cube(self, value: GeneralMemCube) -> None:
|
|
82
81
|
"""The memory cube associated with this MemChat."""
|
|
83
|
-
self.
|
|
82
|
+
self.current_mem_cube = value
|
|
@@ -1,10 +1,12 @@
|
|
|
1
|
+
import concurrent
|
|
2
|
+
|
|
1
3
|
from collections import defaultdict
|
|
2
4
|
from collections.abc import Callable
|
|
3
5
|
from concurrent.futures import ThreadPoolExecutor
|
|
4
6
|
|
|
5
7
|
from memos.log import get_logger
|
|
6
|
-
from memos.mem_scheduler.
|
|
7
|
-
from memos.mem_scheduler.
|
|
8
|
+
from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
|
|
9
|
+
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
|
|
8
10
|
|
|
9
11
|
|
|
10
12
|
logger = get_logger(__name__)
|
|
@@ -22,24 +24,31 @@ class SchedulerDispatcher(BaseSchedulerModule):
|
|
|
22
24
|
- Bulk handler registration
|
|
23
25
|
"""
|
|
24
26
|
|
|
25
|
-
def __init__(self, max_workers=
|
|
27
|
+
def __init__(self, max_workers=30, enable_parallel_dispatch=False):
|
|
26
28
|
super().__init__()
|
|
27
29
|
# Main dispatcher thread pool
|
|
28
30
|
self.max_workers = max_workers
|
|
31
|
+
|
|
29
32
|
# Only initialize thread pool if in parallel mode
|
|
30
33
|
self.enable_parallel_dispatch = enable_parallel_dispatch
|
|
34
|
+
self.thread_name_prefix = "dispatcher"
|
|
31
35
|
if self.enable_parallel_dispatch:
|
|
32
36
|
self.dispatcher_executor = ThreadPoolExecutor(
|
|
33
|
-
max_workers=self.max_workers, thread_name_prefix=
|
|
37
|
+
max_workers=self.max_workers, thread_name_prefix=self.thread_name_prefix
|
|
34
38
|
)
|
|
35
39
|
else:
|
|
36
40
|
self.dispatcher_executor = None
|
|
37
41
|
logger.info(f"enable_parallel_dispatch is set to {self.enable_parallel_dispatch}")
|
|
42
|
+
|
|
38
43
|
# Registered message handlers
|
|
39
44
|
self.handlers: dict[str, Callable] = {}
|
|
45
|
+
|
|
40
46
|
# Dispatcher running state
|
|
41
47
|
self._running = False
|
|
42
48
|
|
|
49
|
+
# Set to track active futures for monitoring purposes
|
|
50
|
+
self._futures = set()
|
|
51
|
+
|
|
43
52
|
def register_handler(self, label: str, handler: Callable[[list[ScheduleMessageItem]], None]):
|
|
44
53
|
"""
|
|
45
54
|
Register a handler function for a specific message label.
|
|
@@ -105,6 +114,13 @@ class SchedulerDispatcher(BaseSchedulerModule):
|
|
|
105
114
|
# Convert defaultdict to regular dict for cleaner output
|
|
106
115
|
return {user_id: dict(cube_groups) for user_id, cube_groups in grouped_dict.items()}
|
|
107
116
|
|
|
117
|
+
def _handle_future_result(self, future):
    """Done-callback for a dispatched future: stop tracking it and log failures.

    Args:
        future: The finished (completed or cancelled) future that was
            registered in ``self._futures`` by ``dispatch``.
    """
    # discard() instead of remove(): if the future is no longer tracked
    # (e.g. the set was cleared by shutdown), remove() would raise KeyError
    # from inside the callback.
    self._futures.discard(future)

    if future.cancelled():
        # result() would raise CancelledError here; a cancelled task is not
        # a handler failure, so do not report it as one.
        return

    try:
        future.result()  # re-raises whatever the handler raised
    except Exception as e:
        logger.error(f"Handler execution failed: {e!s}", exc_info=True)
|
|
123
|
+
|
|
108
124
|
def dispatch(self, msg_list: list[ScheduleMessageItem]):
|
|
109
125
|
"""
|
|
110
126
|
Dispatch a list of messages to their respective handlers.
|
|
@@ -112,32 +128,29 @@ class SchedulerDispatcher(BaseSchedulerModule):
|
|
|
112
128
|
Args:
|
|
113
129
|
msg_list: List of ScheduleMessageItem objects to process
|
|
114
130
|
"""
|
|
131
|
+
if not msg_list:
|
|
132
|
+
logger.debug("Received empty message list, skipping dispatch")
|
|
133
|
+
return
|
|
115
134
|
|
|
116
|
-
# Group messages by their labels
|
|
135
|
+
# Group messages by their labels, and organize messages by label
|
|
117
136
|
label_groups = defaultdict(list)
|
|
118
|
-
|
|
119
|
-
# Organize messages by label
|
|
120
137
|
for message in msg_list:
|
|
121
138
|
label_groups[message.label].append(message)
|
|
122
139
|
|
|
123
140
|
# Process each label group
|
|
124
141
|
for label, msgs in label_groups.items():
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
handler = self._default_message_handler
|
|
128
|
-
else:
|
|
129
|
-
handler = self.handlers[label]
|
|
142
|
+
handler = self.handlers.get(label, self._default_message_handler)
|
|
143
|
+
|
|
130
144
|
# dispatch to different handler
|
|
131
|
-
logger.debug(f"Dispatch {len(msgs)}
|
|
145
|
+
logger.debug(f"Dispatch {len(msgs)} message(s) to {label} handler.")
|
|
132
146
|
if self.enable_parallel_dispatch and self.dispatcher_executor is not None:
|
|
133
147
|
# Capture variables in lambda to avoid loop variable issues
|
|
134
|
-
# TODO check this
|
|
135
148
|
future = self.dispatcher_executor.submit(handler, msgs)
|
|
136
|
-
|
|
137
|
-
|
|
149
|
+
self._futures.add(future)
|
|
150
|
+
future.add_done_callback(self._handle_future_result)
|
|
151
|
+
logger.info(f"Dispatched {len(msgs)} message(s) as future task")
|
|
138
152
|
else:
|
|
139
153
|
handler(msgs)
|
|
140
|
-
return None
|
|
141
154
|
|
|
142
155
|
def join(self, timeout: float | None = None) -> bool:
|
|
143
156
|
"""Wait for all dispatched tasks to complete.
|
|
@@ -151,15 +164,38 @@ class SchedulerDispatcher(BaseSchedulerModule):
|
|
|
151
164
|
if not self.enable_parallel_dispatch or self.dispatcher_executor is None:
|
|
152
165
|
return True # Serial mode: there are no dispatched tasks to wait for
|
|
153
166
|
|
|
154
|
-
|
|
155
|
-
|
|
167
|
+
done, not_done = concurrent.futures.wait(
|
|
168
|
+
self._futures, timeout=timeout, return_when=concurrent.futures.ALL_COMPLETED
|
|
169
|
+
)
|
|
170
|
+
|
|
171
|
+
# Check for exceptions in completed tasks
|
|
172
|
+
for future in done:
|
|
173
|
+
try:
|
|
174
|
+
future.result()
|
|
175
|
+
except Exception:
|
|
176
|
+
logger.error("Handler failed during shutdown", exc_info=True)
|
|
177
|
+
|
|
178
|
+
return len(not_done) == 0
|
|
156
179
|
|
|
157
180
|
def shutdown(self) -> None:
    """Gracefully shutdown the dispatcher.

    Marks the dispatcher as not running, cancels tracked tasks that have not
    started yet, waits for the executor to finish, and finally clears the
    set of tracked futures. Does nothing beyond clearing the running flag
    when no executor was created.
    """
    self._running = False

    if self.dispatcher_executor is not None:
        # Cancel pending tasks. Iterate over a snapshot: a successful
        # cancel() invokes the future's done-callbacks synchronously, and
        # _handle_future_result removes the future from self._futures,
        # which would otherwise raise "Set changed size during iteration".
        tracked = list(self._futures)
        cancelled = sum(1 for future in tracked if future.cancel())
        logger.info(f"Cancelled {cancelled}/{len(tracked)} pending tasks")

        # Shutdown executor
        try:
            self.dispatcher_executor.shutdown(wait=True)
        except Exception as e:
            logger.error(f"Executor shutdown error: {e}", exc_info=True)
        finally:
            self._futures.clear()
|
|
163
199
|
|
|
164
200
|
def __enter__(self):
|
|
165
201
|
self._running = True
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
import json
|
|
2
|
+
|
|
3
|
+
from contextlib import suppress
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from queue import Empty, Full, Queue
|
|
6
|
+
from typing import TYPE_CHECKING, TypeVar
|
|
7
|
+
|
|
8
|
+
from pydantic import field_serializer
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from pydantic import BaseModel
|
|
13
|
+
|
|
14
|
+
T = TypeVar("T")
|
|
15
|
+
|
|
16
|
+
# TypeVar for "some BaseModel subclass"; the name passed to TypeVar must match
# the variable it is assigned to, otherwise type checkers reject it.
BaseModelType = TypeVar("BaseModelType", bound="BaseModel")
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class DictConversionMixin:
    """
    Mixin that converts Pydantic models to and from plain dictionaries,
    with ISO-format handling of an optional ``timestamp`` field.
    """

    @field_serializer("timestamp", check_fields=False)
    def serialize_datetime(self, dt: datetime | None, _info) -> str | None:
        """
        Serialize a datetime to its ISO-format string.
        - ``None`` is passed through unchanged
        - check_fields=False allows use on models without a timestamp field
        """
        return None if dt is None else dt.isoformat()

    def to_dict(self) -> dict:
        """
        Dump the model to a dictionary.
        - Starts from model_dump() so all fields are included
        - A non-None timestamp is always written as an ISO-format string
        """
        result = self.model_dump()
        timestamp = getattr(self, "timestamp", None)
        if timestamp is not None:
            result["timestamp"] = self.serialize_datetime(timestamp, None)
        return result

    @classmethod
    def from_dict(cls: type[BaseModelType], data: dict) -> BaseModelType:
        """
        Build a model instance from a dictionary.
        - A string timestamp is parsed with datetime.fromisoformat
        - An unparsable timestamp string is replaced by None
        """
        payload = data.copy()  # leave the caller's dictionary untouched
        raw_timestamp = payload.get("timestamp")
        if isinstance(raw_timestamp, str):
            try:
                payload["timestamp"] = datetime.fromisoformat(raw_timestamp)
            except ValueError:
                # Invalid time format: fall back to None
                payload["timestamp"] = None

        return cls(**payload)

    def __str__(self) -> str:
        """
        Render the model as an indented JSON string.
        - Values json cannot encode are converted with str()
        """
        return json.dumps(self.to_dict(), indent=4, ensure_ascii=False, default=str)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class AutoDroppingQueue(Queue[T]):
|
|
77
|
+
"""A thread-safe queue that automatically drops the oldest item when full."""
|
|
78
|
+
|
|
79
|
+
def __init__(self, maxsize: int = 0):
|
|
80
|
+
super().__init__(maxsize=maxsize)
|
|
81
|
+
|
|
82
|
+
def put(self, item: T, block: bool = False, timeout: float | None = None) -> None:
|
|
83
|
+
"""Put an item into the queue.
|
|
84
|
+
|
|
85
|
+
If the queue is full, the oldest item will be automatically removed to make space.
|
|
86
|
+
This operation is thread-safe.
|
|
87
|
+
|
|
88
|
+
Args:
|
|
89
|
+
item: The item to be put into the queue
|
|
90
|
+
block: Ignored (kept for compatibility with Queue interface)
|
|
91
|
+
timeout: Ignored (kept for compatibility with Queue interface)
|
|
92
|
+
"""
|
|
93
|
+
try:
|
|
94
|
+
# First try non-blocking put
|
|
95
|
+
super().put(item, block=block, timeout=timeout)
|
|
96
|
+
except Full:
|
|
97
|
+
with suppress(Empty):
|
|
98
|
+
self.get_nowait() # Remove oldest item
|
|
99
|
+
# Retry putting the new item
|
|
100
|
+
super().put(item, block=block, timeout=timeout)
|
|
101
|
+
|
|
102
|
+
def get_queue_content_without_pop(self) -> list[T]:
|
|
103
|
+
"""Return a copy of the queue's contents without modifying it."""
|
|
104
|
+
return list(self.queue)
|
|
@@ -4,13 +4,13 @@ import threading
|
|
|
4
4
|
import time
|
|
5
5
|
|
|
6
6
|
from pathlib import Path
|
|
7
|
-
from queue import Queue
|
|
8
7
|
|
|
9
8
|
from memos.configs.mem_scheduler import AuthConfig, RabbitMQConfig
|
|
10
9
|
from memos.dependency import require_python_package
|
|
11
10
|
from memos.log import get_logger
|
|
12
|
-
from memos.mem_scheduler.
|
|
13
|
-
from memos.mem_scheduler.
|
|
11
|
+
from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
|
|
12
|
+
from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue
|
|
13
|
+
from memos.mem_scheduler.schemas.general_schemas import DIRECT_EXCHANGE_TYPE, FANOUT_EXCHANGE_TYPE
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
logger = get_logger(__name__)
|
|
@@ -38,7 +38,9 @@ class RabbitMQSchedulerModule(BaseSchedulerModule):
|
|
|
38
38
|
|
|
39
39
|
# fixed params
|
|
40
40
|
self.rabbitmq_message_cache_max_size = 10 # Max 10 messages
|
|
41
|
-
self.rabbitmq_message_cache =
|
|
41
|
+
self.rabbitmq_message_cache = AutoDroppingQueue(
|
|
42
|
+
maxsize=self.rabbitmq_message_cache_max_size
|
|
43
|
+
)
|
|
42
44
|
self.rabbitmq_connection_attempts = 3 # Max retry attempts on connection failure
|
|
43
45
|
self.rabbitmq_retry_delay = 5 # Delay (seconds) between retries
|
|
44
46
|
self.rabbitmq_heartbeat = 60 # Heartbeat interval (seconds) for connection
|
|
@@ -69,9 +71,9 @@ class RabbitMQSchedulerModule(BaseSchedulerModule):
|
|
|
69
71
|
|
|
70
72
|
if config is None:
|
|
71
73
|
if config_path is None and AuthConfig.default_config_exists():
|
|
72
|
-
auth_config = AuthConfig.
|
|
74
|
+
auth_config = AuthConfig.from_local_config()
|
|
73
75
|
elif Path(config_path).exists():
|
|
74
|
-
auth_config = AuthConfig.
|
|
76
|
+
auth_config = AuthConfig.from_local_config(config_path=config_path)
|
|
75
77
|
else:
|
|
76
78
|
logger.error("Fail to initialize auth_config")
|
|
77
79
|
return
|
|
@@ -214,12 +216,12 @@ class RabbitMQSchedulerModule(BaseSchedulerModule):
|
|
|
214
216
|
def on_rabbitmq_message(self, channel, method, properties, body):
|
|
215
217
|
"""Handle incoming messages. Only for test."""
|
|
216
218
|
try:
|
|
217
|
-
print(f"Received message: {body.decode()}")
|
|
218
|
-
self.rabbitmq_message_cache.
|
|
219
|
-
print(f"message delivery_tag: {method.delivery_tag}")
|
|
219
|
+
print(f"Received message: {body.decode()}\n")
|
|
220
|
+
self.rabbitmq_message_cache.put({"properties": properties, "body": body})
|
|
221
|
+
print(f"message delivery_tag: {method.delivery_tag}\n")
|
|
220
222
|
channel.basic_ack(delivery_tag=method.delivery_tag)
|
|
221
223
|
except Exception as e:
|
|
222
|
-
logger.error(f"Message handling failed: {e}")
|
|
224
|
+
logger.error(f"Message handling failed: {e}", exc_info=True)
|
|
223
225
|
|
|
224
226
|
def wait_for_connection_ready(self):
|
|
225
227
|
start_time = time.time()
|
|
@@ -6,7 +6,7 @@ from typing import Any
|
|
|
6
6
|
|
|
7
7
|
from memos.dependency import require_python_package
|
|
8
8
|
from memos.log import get_logger
|
|
9
|
-
from memos.mem_scheduler.
|
|
9
|
+
from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
logger = get_logger(__name__)
|
|
@@ -0,0 +1,199 @@
|
|
|
1
|
+
from memos.configs.mem_scheduler import BaseSchedulerConfig
|
|
2
|
+
from memos.llms.base import BaseLLM
|
|
3
|
+
from memos.log import get_logger
|
|
4
|
+
from memos.mem_cube.general import GeneralMemCube
|
|
5
|
+
from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
|
|
6
|
+
from memos.mem_scheduler.schemas.general_schemas import (
|
|
7
|
+
TreeTextMemory_FINE_SEARCH_METHOD,
|
|
8
|
+
TreeTextMemory_SEARCH_METHOD,
|
|
9
|
+
)
|
|
10
|
+
from memos.mem_scheduler.utils.filter_utils import (
|
|
11
|
+
filter_similar_memories,
|
|
12
|
+
filter_too_short_memories,
|
|
13
|
+
transform_name_to_key,
|
|
14
|
+
)
|
|
15
|
+
from memos.mem_scheduler.utils.misc_utils import (
|
|
16
|
+
extract_json_dict,
|
|
17
|
+
)
|
|
18
|
+
from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
logger = get_logger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class SchedulerRetriever(BaseSchedulerModule):
    """Scheduler module that searches tree text memory and reranks memories with an LLM."""

    def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig):
        """
        Args:
            process_llm: LLM used to rerank memories
            config: Scheduler configuration kept on the instance
        """
        super().__init__()

        # hyper-parameters
        self.filter_similarity_threshold = 0.75
        self.filter_min_length_threshold = 6

        self.config: BaseSchedulerConfig = config
        self.process_llm = process_llm

    def search(
        self,
        query: str,
        mem_cube: GeneralMemCube,
        top_k: int,
        method: str = TreeTextMemory_SEARCH_METHOD,
        info: dict | None = None,
    ) -> list[TextualMemoryItem]:
        """Search in text memory with the given query.

        Both "LongTermMemory" and "UserMemory" are searched with the same
        query and ``top_k``, and the two result lists are concatenated.

        Args:
            query: The search query string
            mem_cube: Memory cube whose ``text_mem`` is searched
            top_k: Number of top results to request per memory type
            method: Search method to use
            info: Passed through to the tree search; a placeholder with empty
                ``user_id``/``session_id`` is used (with a warning) when omitted

        Returns:
            Search results, or an empty list if the search fails or the
            method is not supported (the error is logged, not raised)
        """
        text_mem_base = mem_cube.text_mem
        try:
            if method in [TreeTextMemory_SEARCH_METHOD, TreeTextMemory_FINE_SEARCH_METHOD]:
                # Explicit check instead of assert, which is stripped under -O.
                if not isinstance(text_mem_base, TreeTextMemory):
                    raise TypeError(f"Expected TreeTextMemory, got {type(text_mem_base)}")
                if info is None:
                    logger.warning(
                        "Please input 'info' when use tree.search so that "
                        "the database would store the consume history."
                    )
                    info = {"user_id": "", "session_id": ""}

                mode = "fast" if method == TreeTextMemory_SEARCH_METHOD else "fine"
                results_long_term = text_mem_base.search(
                    query=query, top_k=top_k, memory_type="LongTermMemory", mode=mode, info=info
                )
                results_user = text_mem_base.search(
                    query=query, top_k=top_k, memory_type="UserMemory", mode=mode, info=info
                )
                results = results_long_term + results_user
            else:
                raise NotImplementedError(str(type(text_mem_base)))
        except Exception as e:
            logger.error(f"Fail to search. The exeption is {e}.", exc_info=True)
            results = []
        return results

    def rerank_memories(
        self,
        queries: list[str],
        original_memories: list[str],
        top_k: int,
    ) -> tuple[list[str], bool]:
        """
        Rerank memories based on relevance to given queries using LLM.

        Args:
            queries: List of query strings to determine relevance
                (only the first query is placed in the prompt)
            original_memories: List of memory strings to be reranked
            top_k: Number of top memories to return after reranking

        Returns:
            A ``(memories, success_flag)`` tuple: the reranked memory strings
            (length <= top_k) and whether the LLM reranking succeeded

        Note:
            If LLM reranking fails, falls back to original order (truncated to top_k)
        """
        success_flag = False

        if not queries:
            # queries[0] below would raise IndexError; nothing to rank against.
            logger.warning("No queries provided for reranking; keeping the original memory order")
            return original_memories[:top_k], success_flag

        logger.info(f"Starting memory reranking for {len(original_memories)} memories")

        raw_response = None
        try:
            # Build LLM prompt for memory reranking
            prompt = self.build_prompt(
                "memory_reranking",
                queries=[f"[0] {queries[0]}"],
                current_order=[f"[{i}] {mem}" for i, mem in enumerate(original_memories)],
            )
            logger.debug(f"Generated reranking prompt: {prompt[:200]}...")  # Log first 200 chars

            # Get LLM response; kept inside the try so a failed call also
            # takes the documented fallback path.
            raw_response = self.process_llm.generate([{"role": "user", "content": prompt}])
            logger.debug(f"Received LLM response: {raw_response[:200]}...")  # Log first 200 chars

            # Parse JSON response
            parsed = extract_json_dict(raw_response)
            new_order = parsed["new_order"][:top_k]
            text_memories_with_new_order = [original_memories[idx] for idx in new_order]
            # 'reasoning' is only used for logging; a missing key must not
            # discard an otherwise valid ranking.
            logger.info(
                f"Successfully reranked memories. Returning top {len(text_memories_with_new_order)} items;"
                f"Ranking reasoning: {parsed.get('reasoning')}"
            )
            success_flag = True
        except Exception as e:
            logger.error(
                f"Failed to rerank memories with LLM. Exception: {e}. Raw response: {raw_response} ",
                exc_info=True,
            )
            text_memories_with_new_order = original_memories[:top_k]
            success_flag = False
        return text_memories_with_new_order, success_flag

    def process_and_rerank_memories(
        self,
        queries: list[str],
        original_memory: list[TextualMemoryItem],
        new_memory: list[TextualMemoryItem],
        top_k: int = 10,
    ) -> tuple[list[TextualMemoryItem], bool]:
        """
        Process and rerank memory items by combining original and new memories,
        applying filters, and then reranking based on relevance to queries.

        Args:
            queries: List of query strings to rerank memories against
            original_memory: List of original TextualMemoryItem objects
            new_memory: List of new TextualMemoryItem objects to merge
            top_k: Maximum number of memories to return after reranking

        Returns:
            A ``(memories, success_flag)`` tuple: the reranked TextualMemoryItem
            objects and the success flag reported by ``rerank_memories``
        """
        # Combine original and new memories into a single list
        combined_memory = original_memory + new_memory

        # Create a mapping from normalized text to memory objects
        memory_map = {
            transform_name_to_key(name=mem_obj.memory): mem_obj for mem_obj in combined_memory
        }

        # Extract the text of every memory item
        combined_text_memory = [m.memory for m in combined_memory]

        # Apply similarity filter to remove overly similar memories
        filtered_combined_text_memory = filter_similar_memories(
            text_memories=combined_text_memory,
            similarity_threshold=self.filter_similarity_threshold,
        )

        # Apply length filter to remove memories that are too short
        filtered_combined_text_memory = filter_too_short_memories(
            text_memories=filtered_combined_text_memory,
            min_length_threshold=self.filter_min_length_threshold,
        )

        # Ensure uniqueness of memory texts using dictionary keys (preserves order)
        unique_memory = list(dict.fromkeys(filtered_combined_text_memory))

        # Rerank the filtered memories based on relevance to the queries
        text_memories_with_new_order, success_flag = self.rerank_memories(
            queries=queries,
            original_memories=unique_memory,
            top_k=top_k,
        )

        # Map reranked text entries back to their original memory objects
        memories_with_new_order = []
        for text in text_memories_with_new_order:
            normalized_text = transform_name_to_key(name=text)
            if normalized_text in memory_map:  # Ensure correct key matching
                memories_with_new_order.append(memory_map[normalized_text])
            else:
                logger.warning(
                    f"Memory text not found in memory map. text: {text};\n"
                    f"Keys of memory_map: {memory_map.keys()}"
                )

        return memories_with_new_order, success_flag
|