MemoryOS 0.0.1__py3-none-any.whl → 0.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of MemoryOS might be problematic. Click here for more details.
- memoryos-0.1.12.dist-info/METADATA +257 -0
- memoryos-0.1.12.dist-info/RECORD +117 -0
- memos/__init__.py +20 -1
- memos/api/start_api.py +420 -0
- memos/chunkers/__init__.py +4 -0
- memos/chunkers/base.py +24 -0
- memos/chunkers/factory.py +22 -0
- memos/chunkers/sentence_chunker.py +35 -0
- memos/configs/__init__.py +0 -0
- memos/configs/base.py +82 -0
- memos/configs/chunker.py +45 -0
- memos/configs/embedder.py +53 -0
- memos/configs/graph_db.py +45 -0
- memos/configs/llm.py +71 -0
- memos/configs/mem_chat.py +81 -0
- memos/configs/mem_cube.py +89 -0
- memos/configs/mem_os.py +70 -0
- memos/configs/mem_reader.py +53 -0
- memos/configs/mem_scheduler.py +78 -0
- memos/configs/memory.py +190 -0
- memos/configs/parser.py +38 -0
- memos/configs/utils.py +8 -0
- memos/configs/vec_db.py +64 -0
- memos/deprecation.py +262 -0
- memos/embedders/__init__.py +0 -0
- memos/embedders/base.py +15 -0
- memos/embedders/factory.py +23 -0
- memos/embedders/ollama.py +74 -0
- memos/embedders/sentence_transformer.py +40 -0
- memos/exceptions.py +30 -0
- memos/graph_dbs/__init__.py +0 -0
- memos/graph_dbs/base.py +215 -0
- memos/graph_dbs/factory.py +21 -0
- memos/graph_dbs/neo4j.py +827 -0
- memos/hello_world.py +97 -0
- memos/llms/__init__.py +0 -0
- memos/llms/base.py +16 -0
- memos/llms/factory.py +25 -0
- memos/llms/hf.py +231 -0
- memos/llms/ollama.py +82 -0
- memos/llms/openai.py +34 -0
- memos/llms/utils.py +14 -0
- memos/log.py +78 -0
- memos/mem_chat/__init__.py +0 -0
- memos/mem_chat/base.py +30 -0
- memos/mem_chat/factory.py +21 -0
- memos/mem_chat/simple.py +200 -0
- memos/mem_cube/__init__.py +0 -0
- memos/mem_cube/base.py +29 -0
- memos/mem_cube/general.py +146 -0
- memos/mem_cube/utils.py +24 -0
- memos/mem_os/client.py +5 -0
- memos/mem_os/core.py +819 -0
- memos/mem_os/main.py +12 -0
- memos/mem_os/product.py +89 -0
- memos/mem_reader/__init__.py +0 -0
- memos/mem_reader/base.py +27 -0
- memos/mem_reader/factory.py +21 -0
- memos/mem_reader/memory.py +298 -0
- memos/mem_reader/simple_struct.py +241 -0
- memos/mem_scheduler/__init__.py +0 -0
- memos/mem_scheduler/base_scheduler.py +164 -0
- memos/mem_scheduler/general_scheduler.py +305 -0
- memos/mem_scheduler/modules/__init__.py +0 -0
- memos/mem_scheduler/modules/base.py +74 -0
- memos/mem_scheduler/modules/dispatcher.py +103 -0
- memos/mem_scheduler/modules/monitor.py +82 -0
- memos/mem_scheduler/modules/redis_service.py +146 -0
- memos/mem_scheduler/modules/retriever.py +41 -0
- memos/mem_scheduler/modules/schemas.py +146 -0
- memos/mem_scheduler/scheduler_factory.py +21 -0
- memos/mem_scheduler/utils.py +26 -0
- memos/mem_user/user_manager.py +478 -0
- memos/memories/__init__.py +0 -0
- memos/memories/activation/__init__.py +0 -0
- memos/memories/activation/base.py +42 -0
- memos/memories/activation/item.py +25 -0
- memos/memories/activation/kv.py +232 -0
- memos/memories/base.py +19 -0
- memos/memories/factory.py +34 -0
- memos/memories/parametric/__init__.py +0 -0
- memos/memories/parametric/base.py +19 -0
- memos/memories/parametric/item.py +11 -0
- memos/memories/parametric/lora.py +41 -0
- memos/memories/textual/__init__.py +0 -0
- memos/memories/textual/base.py +89 -0
- memos/memories/textual/general.py +286 -0
- memos/memories/textual/item.py +167 -0
- memos/memories/textual/naive.py +185 -0
- memos/memories/textual/tree.py +289 -0
- memos/memories/textual/tree_text_memory/__init__.py +0 -0
- memos/memories/textual/tree_text_memory/organize/__init__.py +0 -0
- memos/memories/textual/tree_text_memory/organize/manager.py +305 -0
- memos/memories/textual/tree_text_memory/retrieve/__init__.py +0 -0
- memos/memories/textual/tree_text_memory/retrieve/reasoner.py +64 -0
- memos/memories/textual/tree_text_memory/retrieve/recall.py +158 -0
- memos/memories/textual/tree_text_memory/retrieve/reranker.py +111 -0
- memos/memories/textual/tree_text_memory/retrieve/retrieval_mid_structs.py +13 -0
- memos/memories/textual/tree_text_memory/retrieve/searcher.py +166 -0
- memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +68 -0
- memos/memories/textual/tree_text_memory/retrieve/utils.py +48 -0
- memos/parsers/__init__.py +0 -0
- memos/parsers/base.py +15 -0
- memos/parsers/factory.py +19 -0
- memos/parsers/markitdown.py +22 -0
- memos/settings.py +8 -0
- memos/templates/__init__.py +0 -0
- memos/templates/mem_reader_prompts.py +98 -0
- memos/templates/mem_scheduler_prompts.py +65 -0
- memos/types.py +55 -0
- memos/vec_dbs/__init__.py +0 -0
- memos/vec_dbs/base.py +105 -0
- memos/vec_dbs/factory.py +21 -0
- memos/vec_dbs/item.py +43 -0
- memos/vec_dbs/qdrant.py +292 -0
- memoryos-0.0.1.dist-info/METADATA +0 -53
- memoryos-0.0.1.dist-info/RECORD +0 -5
- {memoryos-0.0.1.dist-info → memoryos-0.1.12.dist-info}/LICENSE +0 -0
- {memoryos-0.0.1.dist-info → memoryos-0.1.12.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,305 @@
|
|
|
1
|
+
import json
|
|
2
|
+
|
|
3
|
+
from datetime import datetime, timedelta
|
|
4
|
+
|
|
5
|
+
from memos.configs.mem_scheduler import GeneralSchedulerConfig
|
|
6
|
+
from memos.llms.base import BaseLLM
|
|
7
|
+
from memos.log import get_logger
|
|
8
|
+
from memos.mem_cube.general import GeneralMemCube
|
|
9
|
+
from memos.mem_scheduler.base_scheduler import BaseScheduler
|
|
10
|
+
from memos.mem_scheduler.modules.monitor import SchedulerMonitor
|
|
11
|
+
from memos.mem_scheduler.modules.retriever import SchedulerRetriever
|
|
12
|
+
from memos.mem_scheduler.modules.schemas import (
|
|
13
|
+
ANSWER_LABEL,
|
|
14
|
+
DEFAULT_ACT_MEM_DUMP_PATH,
|
|
15
|
+
DEFAULT_ACTIVATION_MEM_SIZE,
|
|
16
|
+
NOT_INITIALIZED,
|
|
17
|
+
QUERY_LABEL,
|
|
18
|
+
ScheduleLogForWebItem,
|
|
19
|
+
ScheduleMessageItem,
|
|
20
|
+
TextMemory_SEARCH_METHOD,
|
|
21
|
+
TreeTextMemory_SEARCH_METHOD,
|
|
22
|
+
)
|
|
23
|
+
from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory
|
|
24
|
+
from memos.templates.mem_scheduler_prompts import MEMORY_ASSEMBLY_TEMPLATE
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# Module-level logger, named after this module via __name__.
logger = get_logger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class GeneralScheduler(BaseScheduler):
    """Scheduler that handles query- and answer-labelled messages.

    Query messages run intent detection against the current working memory and,
    when retrieval is requested, search for new candidates and rerank the
    working memory. Answer messages update activation-memory frequency counts
    and periodically rebuild the activation memory.
    """

    def __init__(self, config: GeneralSchedulerConfig):
        """Initialize the scheduler with the given configuration."""
        super().__init__(config)
        self.top_k = self.config.get("top_k", 10)
        self.top_n = self.config.get("top_n", 5)
        self.act_mem_update_interval = self.config.get("act_mem_update_interval", 300)
        self.context_window_size = self.config.get("context_window_size", 5)
        self.activation_mem_size = self.config.get(
            "activation_mem_size", DEFAULT_ACTIVATION_MEM_SIZE
        )
        self.act_mem_dump_path = self.config.get("act_mem_dump_path", DEFAULT_ACT_MEM_DUMP_PATH)
        self.search_method = TextMemory_SEARCH_METHOD
        # Must be a datetime: it is subtracted from datetime.now() in the answer
        # handler. datetime.min makes the very first answer message "due".
        self._last_activation_mem_update_time = datetime.min
        self.query_list = []

        # register handlers
        # NOTE(review): self.dispatcher is presumably created by BaseScheduler.__init__ - confirm.
        handlers = {
            QUERY_LABEL: self._query_message_consume,
            ANSWER_LABEL: self._answer_message_consume,
        }
        self.dispatcher.register_handlers(handlers)

    def initialize_modules(self, chat_llm: BaseLLM):
        """Attach the chat LLM and build the monitor and retriever that use it.

        Args:
            chat_llm: LLM shared by the scheduler, the monitor and the retriever.
        """
        self.chat_llm = chat_llm
        self.monitor = SchedulerMonitor(
            chat_llm=self.chat_llm, activation_mem_size=self.activation_mem_size
        )
        self.retriever = SchedulerRetriever(chat_llm=self.chat_llm)
        logger.debug("GeneralScheduler has been initialized")

    def _answer_message_consume(self, messages: list[ScheduleMessageItem]) -> None:
        """
        Process and handle answer trigger messages from the queue.

        For each answer message the activation-memory frequency list is refreshed
        through the monitor; the activation memory itself is rebuilt only when
        ``act_mem_update_interval`` seconds have passed since the last rebuild.

        Args:
            messages: List of answer messages to process
        """
        # TODO: This handler is not ready yet
        logger.debug(f"Messages {messages} assigned to {ANSWER_LABEL} handler.")
        update_interval = timedelta(seconds=self.act_mem_update_interval)
        for msg in messages:
            # Compare labels by value: `is` tests object identity and fails for
            # equal strings that are distinct objects.
            if msg.label != ANSWER_LABEL:
                logger.error(f"_answer_message_consume is not designed for {msg.label}")
                continue
            answer = msg.content
            self._current_user_id = msg.user_id
            self._current_mem_cube_id = msg.mem_cube_id
            self._current_mem_cube = msg.mem_cube

            # Get current activation memory items (snapshot taken before the
            # frequency update below, as in the original flow)
            current_activation_mem = [
                item["memory"]
                for item in self.monitor.activation_memory_freq_list
                if item["memory"] is not None
            ]

            # Update memory frequencies based on the answer
            # TODO: not implemented
            self.monitor.activation_memory_freq_list = self.monitor.update_freq(
                answer=answer, activation_memory_freq_list=self.monitor.activation_memory_freq_list
            )

            # Check if it's time to update activation memory
            now = datetime.now()
            if (now - self._last_activation_mem_update_time) >= update_interval:
                # TODO: not implemented
                self.update_activation_memory(current_activation_mem)
                self._last_activation_mem_update_time = now

            # recording messages
            log_message = self.create_autofilled_log_item(
                log_title="memos answer triggers scheduling...",
                label=ANSWER_LABEL,
                log_content="activation_memory has been updated",
            )
            self._submit_web_logs(messages=log_message)

    def _query_message_consume(self, messages: list[ScheduleMessageItem]) -> None:
        """
        Process and handle query trigger messages from the queue.

        Args:
            messages: List of query messages to process
        """
        logger.debug(f"Messages {messages} assigned to {QUERY_LABEL} handler.")
        for msg in messages:
            # Value comparison, not identity (see _answer_message_consume).
            if msg.label != QUERY_LABEL:
                logger.error(f"_query_message_consume is not designed for {msg.label}")
                continue
            # Process the query in a session turn
            self._current_user_id = msg.user_id
            self._current_mem_cube_id = msg.mem_cube_id
            self._current_mem_cube = msg.mem_cube
            self.process_session_turn(query=msg.content, top_k=self.top_k, top_n=self.top_n)

    def process_session_turn(
        self,
        query: str,
        top_k: int = 10,
        top_n: int = 5,
    ) -> None:
        """
        Process a dialog turn:
        - Record the query and ask the monitor whether retrieval is needed,
          given the current working memory;
        - If retrieval is triggered, search once per missing-evidence item,
          rerank/replace the working memory and refresh the activation memory.

        Args:
            query: The user query for this turn.
            top_k: Total search budget, split evenly across missing-evidence items.
            top_n: Number of top reranked memories handed to the activation memory.
        """
        q_list = [query]
        self.query_list.append(query)
        text_mem_base = self.mem_cube.text_mem
        if not isinstance(text_mem_base, TreeTextMemory):
            logger.error("Not implemented!")
            return

        working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory()
        text_working_memory: list[str] = [w_m.memory for w_m in working_memory]
        intent_result = self.monitor.detect_intent(
            q_list=q_list, text_working_memory=text_working_memory
        )
        # The intent dict comes from an LLM response, so keys may be missing.
        if not isinstance(intent_result, dict):
            logger.warning(f"Unexpected intent detection result: {intent_result}")
            return
        if not intent_result.get("trigger_retrieval"):
            return

        missing_evidence = intent_result.get("missing_evidence") or []
        num_evidence = len(missing_evidence)
        k_per_evidence = max(1, top_k // max(1, num_evidence))
        new_candidates = []
        for item in missing_evidence:
            logger.debug(f"missing_evidence: {item}")
            results = self.search(query=item, top_k=k_per_evidence, method=self.search_method)
            logger.debug(f"search results for {item}: {results}")
            # search() returns None when the memory type/method is unsupported.
            if results:
                new_candidates.extend(results)

        # recording messages
        log_message = self.create_autofilled_log_item(
            log_title="user query triggers scheduling...",
            label=QUERY_LABEL,
            log_content=f"search new candidates for working memory: {len(new_candidates)}",
        )
        self._submit_web_logs(messages=log_message)
        new_order_working_memory = self.replace_working_memory(
            original_memory=working_memory, new_memory=new_candidates, top_k=top_k, top_n=top_n
        )
        # May be None/empty; update_activation_memory handles both.
        self.update_activation_memory(new_order_working_memory)

    def create_autofilled_log_item(
        self, log_title: str, log_content: str, label: str
    ) -> ScheduleLogForWebItem:
        """Build a web log item for the current user and memory cube.

        Current sizes are all reported as ``NOT_INITIALIZED``. Text-memory
        capacities are read from the tree memory manager when the text memory
        is a ``TreeTextMemory``; otherwise they are ``NOT_INITIALIZED`` as well.

        Args:
            log_title: Title of the log entry.
            log_content: Body of the log entry.
            label: Message label the log entry belongs to.

        Returns:
            The populated ``ScheduleLogForWebItem``.
        """
        # TODO: create the log iterm with real stats
        text_mem_base = self.mem_cube.text_mem
        current_memory_sizes = {
            "long_term_memory_size": NOT_INITIALIZED,
            "user_memory_size": NOT_INITIALIZED,
            "working_memory_size": NOT_INITIALIZED,
            "transformed_act_memory_size": NOT_INITIALIZED,
            "parameter_memory_size": NOT_INITIALIZED,
        }
        memory_capacities = {
            "long_term_memory_capacity": NOT_INITIALIZED,
            "user_memory_capacity": NOT_INITIALIZED,
            "working_memory_capacity": NOT_INITIALIZED,
            "transformed_act_memory_capacity": NOT_INITIALIZED,
            "parameter_memory_capacity": NOT_INITIALIZED,
        }
        # Only the tree text memory exposes memory_manager.memory_size.
        if isinstance(text_mem_base, TreeTextMemory):
            size_by_type = text_mem_base.memory_manager.memory_size
            memory_capacities["long_term_memory_capacity"] = size_by_type["LongTermMemory"]
            memory_capacities["user_memory_capacity"] = size_by_type["UserMemory"]
            memory_capacities["working_memory_capacity"] = size_by_type["WorkingMemory"]

        log_message = ScheduleLogForWebItem(
            user_id=self._current_user_id,
            mem_cube_id=self._current_mem_cube_id,
            label=label,
            log_title=log_title,
            log_content=log_content,
            current_memory_sizes=current_memory_sizes,
            memory_capacities=memory_capacities,
        )
        return log_message

    @property
    def mem_cube(self) -> GeneralMemCube:
        """The memory cube the scheduler is currently operating on."""
        return self._current_mem_cube

    @mem_cube.setter
    def mem_cube(self, value: GeneralMemCube) -> None:
        """Set the current memory cube and mirror it onto the retriever."""
        self._current_mem_cube = value
        # NOTE(review): within this class self.retriever is only assigned in
        # initialize_modules - confirm the setter is never used before that.
        self.retriever.mem_cube = value

    def replace_working_memory(
        self,
        original_memory: list[TextualMemoryItem],
        new_memory: list[TextualMemoryItem],
        top_k: int = 10,
        top_n: int = 5,
    ) -> None | list[TextualMemoryItem]:
        """Rerank old and new memories with the LLM and replace the working memory.

        The LLM is asked to order the de-duplicated memory texts; at most
        ``top_n + top_k`` ordered items are kept. Items after the first ``top_n``
        are written to the text memory's working memory; the first ``top_n``
        are returned.

        Args:
            original_memory: Items currently in working memory.
            new_memory: Newly retrieved candidate items.
            top_k: Together with ``top_n`` bounds how many reranked items are kept.
            top_n: Number of leading items returned to the caller.

        Returns:
            The first ``top_n`` reranked items, or ``None`` when the text memory
            is not a ``TreeTextMemory`` or the LLM response cannot be parsed
            (in which case the working memory is left untouched).
        """
        text_mem_base = self.mem_cube.text_mem
        if not isinstance(text_mem_base, TreeTextMemory):
            logger.error("memory_base is not supported")
            return None

        combined_memory = original_memory + new_memory
        # Keys keep first-seen order; for duplicate texts the later item wins.
        memory_map = {mem_obj.memory: mem_obj for mem_obj in combined_memory}
        unique_memory = list(memory_map)

        # NOTE(review): the reranking prompt is built with an empty query - confirm intended.
        prompt = self.build_prompt(
            "memory_reranking", query="", current_order=unique_memory, staging_buffer=[]
        )
        response = self.chat_llm.generate([{"role": "user", "content": prompt}])
        try:
            parsed = json.loads(response)
        except (ValueError, TypeError) as e:
            # Malformed LLM output must not take down the handler.
            logger.error(f"Failed to parse memory reranking response: {e}")
            return None
        if not isinstance(parsed, dict):
            logger.error(f"Unexpected memory reranking response: {parsed}")
            return None
        new_order_text_memory = parsed.get("new_order", [])[: top_n + top_k]

        new_order_memory = []
        for text in new_order_text_memory:
            if text in memory_map:
                new_order_memory.append(memory_map[text])
            else:
                logger.warning(
                    f"Memory text not found in memory map. text: {text}; memory_map: {memory_map}"
                )

        # NOTE(review): working memory receives the items *after* the first top_n,
        # while the first top_n are returned - kept as in the original; confirm.
        working_set = new_order_memory[top_n:]
        text_mem_base.replace_working_memory(working_set)
        logger.info(f"The working memory has been replaced with {len(working_set)} new memories.")
        return new_order_memory[:top_n]

    def search(self, query: str, top_k: int, method=TreeTextMemory_SEARCH_METHOD):
        """Search long-term and user memory of the current cube's text memory.

        Args:
            query: Text to search for.
            top_k: Number of results requested from each memory type.
            method: Search method identifier.

        Returns:
            Long-term results followed by user-memory results, or ``None`` when
            the text memory is not a ``TreeTextMemory`` or ``method`` is not
            ``TextMemory_SEARCH_METHOD``.
        """
        # NOTE(review): the default `method` is TreeTextMemory_SEARCH_METHOD but only
        # TextMemory_SEARCH_METHOD is handled below - confirm which is intended.
        text_mem_base = self.mem_cube.text_mem
        if isinstance(text_mem_base, TreeTextMemory) and method == TextMemory_SEARCH_METHOD:
            results_long_term = text_mem_base.search(
                query=query, top_k=top_k, memory_type="LongTermMemory"
            )
            results_user = text_mem_base.search(query=query, top_k=top_k, memory_type="UserMemory")
            results = results_long_term + results_user
        else:
            logger.error("Not implemented.")
            results = None
        return results

    def update_activation_memory(self, new_memory: list[str | TextualMemoryItem]) -> None:
        """
        Rebuild the activation memory from ``new_memory`` and dump it to disk.

        The memory texts are numbered, rendered through MEMORY_ASSEMBLY_TEMPLATE,
        passed to ``act_mem.extract`` and the resulting item replaces everything
        currently held by the cube's activation memory. Failures are logged as
        warnings and not re-raised.

        Args:
            new_memory: Memory texts or TextualMemoryItems; the element type is
                decided from the first element.
        """
        # TODO: The function of update activation memory is waiting to test
        if not new_memory:  # covers both None and an empty list
            logger.error("update_activation_memory: new_memory is empty.")
            return
        first_item = new_memory[0]
        if isinstance(first_item, TextualMemoryItem):
            new_text_memory = [mem.memory for mem in new_memory]
        elif isinstance(first_item, str):
            new_text_memory = new_memory
        else:
            logger.error("Not Implemented.")
            return

        try:
            act_mem = self.mem_cube.act_mem

            # Drop blank entries first so the numbering has no gaps.
            sentences = [s.strip() for s in new_text_memory if s.strip()]
            text_memory = MEMORY_ASSEMBLY_TEMPLATE.format(
                memory_text="".join(
                    f"{i}. {sentence}\n" for i, sentence in enumerate(sentences, 1)
                )
            )
            act_mem.delete_all()
            cache_item = act_mem.extract(text_memory)
            act_mem.add(cache_item)
            act_mem.dump(self.act_mem_dump_path)
        except Exception as e:
            # Deliberate best-effort: an activation-memory failure must not
            # abort message handling.
            logger.warning(f"MOS-based activation memory update failed: {e}")
|
|
File without changes
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
from memos.llms.base import BaseLLM
|
|
4
|
+
from memos.log import get_logger
|
|
5
|
+
from memos.mem_cube.general import GeneralMemCube
|
|
6
|
+
from memos.mem_scheduler.modules.schemas import BASE_DIR
|
|
7
|
+
from memos.templates.mem_scheduler_prompts import PROMPT_MAPPING
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Module-level logger, named after this module via __name__.
logger = get_logger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class BaseSchedulerModule:
    """Common state and prompt helpers shared by scheduler modules."""

    def __init__(self):
        """Set the base directory and start with no LLM and no memory cubes."""
        self.base_dir = Path(BASE_DIR)

        self._chat_llm = None
        self._current_mem_cube_id: str | None = None
        self._current_mem_cube: GeneralMemCube | None = None
        self.mem_cubes: dict[str, GeneralMemCube] = {}

    def load_template(self, template_name: str) -> str:
        """Look up a prompt template by name in PROMPT_MAPPING.

        Args:
            template_name: Key of the template in PROMPT_MAPPING.

        Returns:
            The template, or an empty string (after logging an error) when the
            name is unknown.
        """
        if template_name not in PROMPT_MAPPING:
            logger.error("Prompt template is not found!")
            # Return a falsy value so build_prompt can raise its own error
            # instead of this method failing with a KeyError.
            return ""
        return PROMPT_MAPPING[template_name]

    def build_prompt(self, template_name: str, **kwargs) -> str:
        """Load a template and fill it with ``kwargs`` via ``str.format``.

        Args:
            template_name: Key of the template in PROMPT_MAPPING.
            **kwargs: Values substituted into the template placeholders.

        Returns:
            The formatted prompt.

        Raises:
            FileNotFoundError: If no (non-empty) template exists for the name.
        """
        template = self.load_template(template_name)
        if not template:
            raise FileNotFoundError(f"Prompt template `{template_name}` not found.")
        return template.format(**kwargs)

    def _build_system_prompt(self, memories: list | None = None) -> str:
        """Build system prompt with optional memories context.

        Args:
            memories: Optional items exposing a ``memory`` attribute; each is
                appended as a numbered line under a context heading.

        Returns:
            The base prompt, followed by the numbered memories when provided.
        """
        base_prompt = (
            "You are a knowledgeable and helpful AI assistant. "
            "You have access to conversation memories that help you provide more personalized responses. "
            "Use the memories to understand the user's context, preferences, and past interactions. "
            "If memories are provided, reference them naturally when relevant, but don't explicitly mention having memories."
        )

        if not memories:
            return base_prompt

        memory_lines = "".join(
            f"{i}. {memory.memory}\n" for i, memory in enumerate(memories, 1)
        )
        return base_prompt + "\n\n## Conversation Context:\n" + memory_lines

    def get_mem_cube(self, mem_cube_id: str) -> GeneralMemCube | None:
        """Return the registered memory cube for ``mem_cube_id``.

        Args:
            mem_cube_id: Key into ``self.mem_cubes``.

        Returns:
            The cube, or ``None`` (after logging an error) when it is not registered.
        """
        mem_cube = self.mem_cubes.get(mem_cube_id, None)
        if mem_cube is None:
            # Only report an error when the lookup actually failed.
            logger.error(f"mem_cube {mem_cube_id} does not exists.")
        return mem_cube

    @property
    def chat_llm(self) -> BaseLLM:
        """The chat LLM used by this module."""
        return self._chat_llm

    @chat_llm.setter
    def chat_llm(self, value: BaseLLM) -> None:
        """Set the chat LLM used by this module."""
        self._chat_llm = value

    @property
    def mem_cube(self) -> GeneralMemCube:
        """The memory cube this module is currently working with."""
        return self._current_mem_cube

    @mem_cube.setter
    def mem_cube(self, value: GeneralMemCube) -> None:
        """Set the memory cube this module is currently working with."""
        self._current_mem_cube = value
|
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from collections.abc import Callable
|
|
3
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
4
|
+
|
|
5
|
+
from memos.log import get_logger
|
|
6
|
+
from memos.mem_scheduler.modules.base import BaseSchedulerModule
|
|
7
|
+
from memos.mem_scheduler.modules.schemas import ScheduleMessageItem
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Module-level logger, named after this module via __name__.
logger = get_logger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class SchedulerDispatcher(BaseSchedulerModule):
    """
    Message dispatcher that routes messages to handlers based on their labels.

    Features:
    - Optional shared thread pool (``enable_parallel_dispatch``); otherwise
      handlers run serially in the calling thread
    - Batch message processing (one handler call per label group)
    - Bulk handler registration
    - Fallback to a default handler for labels without a registered handler
    """

    def __init__(self, max_workers=3, enable_parallel_dispatch=False):
        """Create the dispatcher.

        Args:
            max_workers: Size of the thread pool used in parallel mode.
            enable_parallel_dispatch: When true, handlers are submitted to a
                thread pool instead of being called directly.
        """
        super().__init__()
        # Main dispatcher thread pool
        self.max_workers = max_workers
        # Only initialize thread pool if in parallel mode
        self.enable_parallel_dispatch = enable_parallel_dispatch
        if self.enable_parallel_dispatch:
            self.dispatcher_executor = ThreadPoolExecutor(
                max_workers=self.max_workers, thread_name_prefix="dispatcher"
            )
        else:
            self.dispatcher_executor = None
        logger.info(f"enable_parallel_dispatch is set to {self.enable_parallel_dispatch}")
        # Registered message handlers
        self.handlers: dict[str, Callable] = {}
        # Dispatcher running state
        self._running = False

    def register_handler(self, label: str, handler: Callable[[list[ScheduleMessageItem]], None]):
        """
        Register a handler function for a specific message label.

        An existing handler for the same label is replaced.

        Args:
            label: Message label to handle
            handler: Callable that processes messages of this label
        """
        self.handlers[label] = handler

    def register_handlers(
        self, handlers: dict[str, Callable[[list[ScheduleMessageItem]], None]]
    ) -> None:
        """
        Bulk register multiple handlers from a dictionary.

        Entries with a non-string label or a non-callable handler are logged
        and skipped.

        Args:
            handlers: Dictionary mapping labels to handler functions
                      Format: {label: handler_callable}
        """
        registered = 0
        for label, handler in handlers.items():
            if not isinstance(label, str):
                logger.error(f"Invalid label type: {type(label)}. Expected str.")
                continue
            if not callable(handler):
                logger.error(f"Handler for label '{label}' is not callable.")
                continue
            self.register_handler(label=label, handler=handler)
            registered += 1
        # Report what was actually registered, not the size of the input.
        logger.info(f"Registered {registered} handlers in bulk")

    def _default_message_handler(self, messages: list[ScheduleMessageItem]) -> None:
        """Fallback handler: only logs the messages it receives."""
        logger.debug(f"Using _default_message_handler to deal with messages: {messages}")

    def _log_handler_failure(self, label, future) -> None:
        """Log an exception raised by a handler that ran in the thread pool."""
        if future.cancelled():
            return
        error = future.exception()
        if error is not None:
            logger.error(f"Handler for label '{label}' raised an exception: {error}")

    def dispatch(self, msg_list: list[ScheduleMessageItem]):
        """
        Dispatch a list of messages to their respective handlers.

        Messages are grouped by label and each group is passed to its handler in
        one call. In parallel mode the call is submitted to the thread pool and
        failures are logged; in serial mode the handler runs inline and any
        exception propagates to the caller.

        Args:
            msg_list: List of ScheduleMessageItem objects to process
        """

        # Group messages by their labels
        label_groups = defaultdict(list)
        for message in msg_list:
            label_groups[message.label].append(message)

        # Process each label group
        for label, msgs in label_groups.items():
            if label not in self.handlers:
                logger.error(f"No handler registered for label: {label}")
                handler = self._default_message_handler
            else:
                handler = self.handlers[label]
            # dispatch to different handler
            logger.debug(f"Dispatch {len(msgs)} messages to {label} handler.")
            if self.enable_parallel_dispatch and self.dispatcher_executor is not None:
                future = self.dispatcher_executor.submit(handler, msgs)
                # Without a callback, an exception stays hidden in the Future.
                # `lbl=label` binds the current label (avoids late binding).
                future.add_done_callback(
                    lambda fut, lbl=label: self._log_handler_failure(lbl, fut)
                )
            else:
                handler(msgs)  # Direct serial execution
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
import json
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from memos.log import get_logger
|
|
6
|
+
from memos.mem_cube.general import GeneralMemCube
|
|
7
|
+
from memos.mem_scheduler.modules.base import BaseSchedulerModule
|
|
8
|
+
from memos.mem_scheduler.utils import extract_json_dict
|
|
9
|
+
from memos.memories.textual.tree import TreeTextMemory
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Module-level logger, named after this module via __name__.
logger = get_logger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class SchedulerMonitor(BaseSchedulerModule):
    """Tracks scheduler statistics and runs LLM-based intent/frequency detection."""

    def __init__(self, chat_llm, activation_mem_size=5):
        """Create the monitor.

        Args:
            chat_llm: LLM used by ``detect_intent`` and ``update_freq``.
            activation_mem_size: Number of slots in the activation-memory
                frequency list.
        """
        super().__init__()
        self.statistics = {}
        self.intent_history: list[str] = []
        self.activation_mem_size = activation_mem_size
        # One empty slot per activation-memory entry.
        self.activation_memory_freq_list = [
            {"memory": None, "count": 0} for _ in range(self.activation_mem_size)
        ]

        self._chat_llm = chat_llm

    def update_stats(self, mem_cube):
        """Refresh ``self.statistics`` with the activation size and cube info."""
        self.statistics["activation_mem_size"] = self.activation_mem_size
        mem_cube_info = self.get_mem_cube_info(mem_cube)
        self.statistics.update(mem_cube_info)

    def get_mem_cube_info(self, mem_cube: GeneralMemCube):
        """Collect size information about a memory cube.

        Args:
            mem_cube: Cube whose text memory is inspected.

        Returns:
            A dict with ``"text_mem"`` set to the memory manager's
            ``memory_size`` when the text memory is a ``TreeTextMemory``;
            otherwise an empty dict (an error is logged).
        """
        mem_cube_info = {}

        text_mem = mem_cube.text_mem
        if isinstance(text_mem, TreeTextMemory):
            memory_size_dict = text_mem.memory_manager.memory_size
            mem_cube_info["text_mem"] = memory_size_dict
        else:
            logger.error("Not Implemented")

        return mem_cube_info

    def detect_intent(
        self,
        q_list: list[str],
        text_working_memory: list[str],
        prompt_name="intent_recognizing",
    ) -> dict[str, Any]:
        """
        Detect the intent of the user input.

        Args:
            q_list: Recent user queries.
            text_working_memory: Texts of the current working memory.
            prompt_name: Name of the prompt template to use.

        Returns:
            The value extracted from the LLM response by ``extract_json_dict``.
        """
        prompt = self.build_prompt(
            template_name=prompt_name,
            q_list=q_list,
            working_memory_list=text_working_memory,
        )
        response = self._chat_llm.generate([{"role": "user", "content": prompt}])
        response = extract_json_dict(response)
        return response

    def update_freq(
        self,
        answer: str,
        activation_memory_freq_list: list[dict],
        prompt_name="freq_detecting",
    ) -> list[dict]:
        """
        Ask the LLM for an updated activation-memory frequency list.

        The prompt is built from the answer and the current list; the response
        is expected to be a JSON list.

        Args:
            answer: The answer text to check the memories against.
            activation_memory_freq_list: Current frequency entries.
            prompt_name: Name of the prompt template to use.

        Returns:
            The list parsed from the LLM response, or the unchanged input list
            when the response is not valid JSON or not a list.
        """
        prompt = self.build_prompt(
            template_name=prompt_name,
            answer=answer,
            activation_memory_freq_list=activation_memory_freq_list,
        )
        response = self._chat_llm.generate([{"role": "user", "content": prompt}])
        try:
            result = json.loads(response)
        except (ValueError, TypeError) as e:
            # JSONDecodeError is a ValueError; TypeError covers non-text responses.
            logger.warning(f"Failed to parse frequency update response: {e}")
            return activation_memory_freq_list
        if not isinstance(result, list):
            # Callers iterate the result as a list of dicts; keep the old list
            # rather than storing a value of the wrong shape.
            logger.warning(f"Unexpected frequency update response: {result}")
            return activation_memory_freq_list
        return result
|