MemoryOS 0.2.0-py3-none-any.whl → 0.2.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of MemoryOS might be problematic.
- {memoryos-0.2.0.dist-info → memoryos-0.2.1.dist-info}/METADATA +66 -26
- {memoryos-0.2.0.dist-info → memoryos-0.2.1.dist-info}/RECORD +80 -56
- memoryos-0.2.1.dist-info/entry_points.txt +3 -0
- memos/__init__.py +1 -1
- memos/api/config.py +471 -0
- memos/api/exceptions.py +28 -0
- memos/api/mcp_serve.py +502 -0
- memos/api/product_api.py +35 -0
- memos/api/product_models.py +159 -0
- memos/api/routers/__init__.py +1 -0
- memos/api/routers/product_router.py +358 -0
- memos/chunkers/sentence_chunker.py +8 -2
- memos/cli.py +113 -0
- memos/configs/embedder.py +27 -0
- memos/configs/graph_db.py +83 -2
- memos/configs/llm.py +47 -0
- memos/configs/mem_cube.py +1 -1
- memos/configs/mem_scheduler.py +91 -5
- memos/configs/memory.py +5 -4
- memos/dependency.py +52 -0
- memos/embedders/ark.py +92 -0
- memos/embedders/factory.py +4 -0
- memos/embedders/sentence_transformer.py +8 -2
- memos/embedders/universal_api.py +32 -0
- memos/graph_dbs/base.py +2 -2
- memos/graph_dbs/factory.py +2 -0
- memos/graph_dbs/neo4j.py +331 -122
- memos/graph_dbs/neo4j_community.py +300 -0
- memos/llms/base.py +9 -0
- memos/llms/deepseek.py +54 -0
- memos/llms/factory.py +10 -1
- memos/llms/hf.py +170 -13
- memos/llms/hf_singleton.py +114 -0
- memos/llms/ollama.py +4 -0
- memos/llms/openai.py +67 -1
- memos/llms/qwen.py +63 -0
- memos/llms/vllm.py +153 -0
- memos/mem_cube/general.py +77 -16
- memos/mem_cube/utils.py +102 -0
- memos/mem_os/core.py +131 -41
- memos/mem_os/main.py +93 -11
- memos/mem_os/product.py +1098 -35
- memos/mem_os/utils/default_config.py +352 -0
- memos/mem_os/utils/format_utils.py +1154 -0
- memos/mem_reader/simple_struct.py +5 -5
- memos/mem_scheduler/base_scheduler.py +467 -36
- memos/mem_scheduler/general_scheduler.py +125 -244
- memos/mem_scheduler/modules/base.py +9 -0
- memos/mem_scheduler/modules/dispatcher.py +68 -2
- memos/mem_scheduler/modules/misc.py +39 -0
- memos/mem_scheduler/modules/monitor.py +228 -49
- memos/mem_scheduler/modules/rabbitmq_service.py +317 -0
- memos/mem_scheduler/modules/redis_service.py +32 -22
- memos/mem_scheduler/modules/retriever.py +250 -23
- memos/mem_scheduler/modules/schemas.py +189 -7
- memos/mem_scheduler/mos_for_test_scheduler.py +143 -0
- memos/mem_scheduler/utils.py +51 -2
- memos/mem_user/persistent_user_manager.py +260 -0
- memos/memories/activation/item.py +25 -0
- memos/memories/activation/kv.py +10 -3
- memos/memories/activation/vllmkv.py +219 -0
- memos/memories/factory.py +2 -0
- memos/memories/textual/general.py +7 -5
- memos/memories/textual/tree.py +9 -5
- memos/memories/textual/tree_text_memory/organize/conflict.py +5 -3
- memos/memories/textual/tree_text_memory/organize/manager.py +26 -18
- memos/memories/textual/tree_text_memory/organize/redundancy.py +25 -44
- memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +11 -13
- memos/memories/textual/tree_text_memory/organize/reorganizer.py +73 -51
- memos/memories/textual/tree_text_memory/retrieve/recall.py +0 -1
- memos/memories/textual/tree_text_memory/retrieve/reranker.py +2 -2
- memos/memories/textual/tree_text_memory/retrieve/searcher.py +6 -5
- memos/parsers/markitdown.py +8 -2
- memos/templates/mem_reader_prompts.py +65 -23
- memos/templates/mem_scheduler_prompts.py +96 -47
- memos/templates/tree_reorganize_prompts.py +85 -30
- memos/vec_dbs/base.py +12 -0
- memos/vec_dbs/qdrant.py +46 -20
- {memoryos-0.2.0.dist-info → memoryos-0.2.1.dist-info}/LICENSE +0 -0
- {memoryos-0.2.0.dist-info → memoryos-0.2.1.dist-info}/WHEEL +0 -0
memos/mem_scheduler/modules/schemas.py
CHANGED
@@ -1,35 +1,65 @@
+import json
+
 from datetime import datetime
 from pathlib import Path
-from typing import ClassVar, TypeVar
+from typing import ClassVar, NewType, TypeVar
 from uuid import uuid4
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, computed_field
 from typing_extensions import TypedDict
 
+from memos.log import get_logger
 from memos.mem_cube.general import GeneralMemCube
 
 
+logger = get_logger(__name__)
+
+
 FILE_PATH = Path(__file__).absolute()
 BASE_DIR = FILE_PATH.parent.parent.parent.parent.parent
 
 QUERY_LABEL = "query"
 ANSWER_LABEL = "answer"
+ADD_LABEL = "add"
 
 TreeTextMemory_SEARCH_METHOD = "tree_text_memory_search"
 TextMemory_SEARCH_METHOD = "text_memory_search"
-
+DIRECT_EXCHANGE_TYPE = "direct"
+FANOUT_EXCHANGE_TYPE = "fanout"
+DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT = 20
+DEFAULT_ACTIVATION_MEM_MONITOR_SIZE_LIMIT = 5
 DEFAULT_ACT_MEM_DUMP_PATH = f"{BASE_DIR}/outputs/mem_scheduler/mem_cube_scheduler_test.kv_cache"
 DEFAULT_THREAD__POOL_MAX_WORKERS = 5
 DEFAULT_CONSUME_INTERVAL_SECONDS = 3
 NOT_INITIALIZED = -1
 BaseModelType = TypeVar("T", bound="BaseModel")
 
+# web log
+LONG_TERM_MEMORY_TYPE = "LongTermMemory"
+USER_MEMORY_TYPE = "UserMemory"
+WORKING_MEMORY_TYPE = "WorkingMemory"
+TEXT_MEMORY_TYPE = "TextMemory"
+ACTIVATION_MEMORY_TYPE = "ActivationMemory"
+PARAMETER_MEMORY_TYPE = "ParameterMemory"
+USER_INPUT_TYPE = "UserInput"
+NOT_APPLICABLE_TYPE = "NotApplicable"
+
+# monitors
+MONITOR_WORKING_MEMORY_TYPE = "MonitorWorkingMemoryType"
+MONITOR_ACTIVATION_MEMORY_TYPE = "MonitorActivationMemoryType"
+
+
+# new types
+UserID = NewType("UserID", str)
+MemCubeID = NewType("CubeID", str)
 
+
+# ************************* Public *************************
 class DictConversionMixin:
     def to_dict(self) -> dict:
         """Convert the instance to a dictionary."""
         return {
-            **self.dict()
+            **self.model_dump(),  # replaces self.dict()
             "timestamp": self.timestamp.isoformat() if hasattr(self, "timestamp") else None,
         }
@@ -40,10 +70,25 @@ class DictConversionMixin:
             data["timestamp"] = datetime.fromisoformat(data["timestamp"])
         return cls(**data)
 
+    def __str__(self) -> str:
+        """Convert the instance to a JSON string with indentation of 4 spaces.
+        This will be used when str() or print() is called on the instance.
+
+        Returns:
+            str: A JSON string representation of the instance with 4-space indentation.
+        """
+        return json.dumps(
+            self.to_dict(),
+            indent=4,
+            ensure_ascii=False,
+            default=str,  # handles objects that cannot be serialized
+        )
+
     class Config:
         json_encoders: ClassVar[dict[type, object]] = {datetime: lambda v: v.isoformat()}
 
 
+# ************************* Messages *************************
 class ScheduleMessageItem(BaseModel, DictConversionMixin):
     item_id: str = Field(description="uuid", default_factory=lambda: str(uuid4()))
     user_id: str = Field(..., description="user id")
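For context, here is a minimal usage sketch of the reworked DictConversionMixin shown above. It is not part of the diff: the Demo model and its fields are hypothetical, and it assumes from_dict is the classmethod counterpart implied by the return cls(**data) in the hunk.

from datetime import datetime

from pydantic import BaseModel

from memos.mem_scheduler.modules.schemas import DictConversionMixin


class Demo(BaseModel, DictConversionMixin):
    # Hypothetical model, used only to exercise the mixin.
    name: str
    timestamp: datetime


item = Demo(name="example", timestamp=datetime.now())
data = item.to_dict()            # model_dump() output plus an ISO-8601 "timestamp"
restored = Demo.from_dict(data)  # fromisoformat() turns the ISO string back into a datetime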
@@ -68,7 +113,6 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin):
             "item_id": self.item_id,
             "user_id": self.user_id,
             "cube_id": self.mem_cube_id,
-            "message_id": self.message_id,
             "label": self.label,
             "cube": "Not Applicable",  # Custom cube serialization
             "content": self.content,
@@ -82,7 +126,6 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin):
             item_id=data.get("item_id", str(uuid4())),
             user_id=data["user_id"],
             cube_id=data["cube_id"],
-            message_id=data.get("message_id", str(uuid4())),
             label=data["label"],
             cube="Not Applicable",  # Custom cube deserialization
             content=data["content"],
@@ -130,7 +173,8 @@ class ScheduleLogForWebItem(BaseModel, DictConversionMixin):
         ..., description="Identifier for the memcube associated with this log entry"
     )
     label: str = Field(..., description="Label categorizing the type of log")
-
+    from_memory_type: str = Field(..., description="Source memory type")
+    to_memory_type: str = Field(..., description="Destination memory type")
     log_content: str = Field(..., description="Detailed content of the log entry")
     current_memory_sizes: MemorySizes = Field(
         default_factory=lambda: dict(DEFAULT_MEMORY_SIZES),
@@ -144,3 +188,141 @@
         default_factory=datetime.now,
         description="Timestamp indicating when the log entry was created",
     )
+
+
+# ************************* Monitor *************************
+class MemoryMonitorItem(BaseModel, DictConversionMixin):
+    item_id: str = Field(
+        description="Unique identifier for the memory item", default_factory=lambda: str(uuid4())
+    )
+    memory_text: str = Field(
+        ...,
+        description="The actual content of the memory",
+        min_length=1,
+        max_length=10000,  # Prevent excessively large memory texts
+    )
+    importance_score: float = Field(
+        default=NOT_INITIALIZED,
+        description="Numerical score representing the memory's importance",
+        ge=NOT_INITIALIZED,  # Minimum value of 0
+    )
+    recording_count: int = Field(
+        default=1,
+        description="How many times this memory has been recorded",
+        ge=1,  # Greater than or equal to 1
+    )
+
+    def get_score(self) -> float:
+        """
+        Calculate the effective score for the memory item.
+
+        Returns:
+            float: The importance_score if it has been initialized (>=0),
+                otherwise the recording_count converted to float.
+
+        Note:
+            This method provides a unified way to retrieve a comparable score
+            for memory items, regardless of whether their importance has been explicitly set.
+        """
+        if self.importance_score == NOT_INITIALIZED:
+            # Return recording_count as float when importance_score is not initialized
+            return float(self.recording_count)
+        else:
+            # Return the initialized importance_score
+            return self.importance_score
+
+
+class MemoryMonitorManager(BaseModel, DictConversionMixin):
+    user_id: str = Field(..., description="Required user identifier", min_length=1)
+    mem_cube_id: str = Field(..., description="Required memory cube identifier", min_length=1)
+    memories: list[MemoryMonitorItem] = Field(
+        default_factory=list, description="Collection of memory items"
+    )
+    max_capacity: int | None = Field(
+        default=None, description="Maximum number of memories allowed (None for unlimited)", ge=1
+    )
+
+    @computed_field
+    @property
+    def memory_size(self) -> int:
+        """Automatically calculated count of memory items."""
+        return len(self.memories)
+
+    def update_memories(
+        self, text_working_memories: list[str], partial_retention_number: int
+    ) -> MemoryMonitorItem:
+        """
+        Update memories based on text_working_memories.
+
+        Args:
+            text_working_memories: List of memory texts to update
+            partial_retention_number: Number of top memories to keep by recording count
+
+        Returns:
+            List of added or updated MemoryMonitorItem instances
+        """
+
+        # Validate partial_retention_number
+        if partial_retention_number < 0:
+            raise ValueError("partial_retention_number must be non-negative")
+
+        # Create text lookup set
+        working_memory_set = set(text_working_memories)
+
+        # Step 1: Update existing memories or add new ones
+        added_or_updated = []
+        memory_text_map = {item.memory_text: item for item in self.memories}
+
+        for text in text_working_memories:
+            if text in memory_text_map:
+                # Update existing memory
+                memory = memory_text_map[text]
+                memory.recording_count += 1
+                added_or_updated.append(memory)
+            else:
+                # Add new memory
+                new_memory = MemoryMonitorItem(memory_text=text, recording_count=1)
+                self.memories.append(new_memory)
+                added_or_updated.append(new_memory)
+
+        # Step 2: Identify memories to remove
+        # Sort memories by recording_count in descending order
+        sorted_memories = sorted(self.memories, key=lambda item: item.recording_count, reverse=True)
+
+        # Keep the top N memories by recording_count
+        records_to_keep = {
+            memory.memory_text for memory in sorted_memories[:partial_retention_number]
+        }
+
+        # Collect memories to remove: not in current working memory and not in top N
+        memories_to_remove = [
+            memory
+            for memory in self.memories
+            if memory.memory_text not in working_memory_set
+            and memory.memory_text not in records_to_keep
+        ]
+
+        # Step 3: Remove identified memories
+        for memory in memories_to_remove:
+            self.memories.remove(memory)
+
+        # Step 4: Enforce max_capacity if set
+        if self.max_capacity is not None and len(self.memories) > self.max_capacity:
+            # Sort by importance and then recording count
+            sorted_memories = sorted(
+                self.memories,
+                key=lambda item: (item.importance_score, item.recording_count),
+                reverse=True,
+            )
+            # Keep only the top max_capacity memories
+            self.memories = sorted_memories[: self.max_capacity]
+
+        # Log the update result
+        logger.info(
+            f"Updated monitor manager for user {self.user_id}, mem_cube {self.mem_cube_id}: "
+            f"Total memories: {len(self.memories)}, "
+            f"Added/Updated: {len(added_or_updated)}, "
+            f"Removed: {len(memories_to_remove)} (excluding top {partial_retention_number} by recording_count)"
+        )
+
+        return added_or_updated
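To make the retention logic of MemoryMonitorManager.update_memories concrete, here is a hedged usage sketch (not taken from the package); the memory texts and limits are invented.

from memos.mem_scheduler.modules.schemas import MemoryMonitorManager

manager = MemoryMonitorManager(user_id="u1", mem_cube_id="cube1", max_capacity=20)

# First sync: both texts are new, so each gets recording_count == 1.
manager.update_memories(["likes tea", "lives in Berlin"], partial_retention_number=1)

# Second sync: "likes tea" is reinforced (recording_count == 2). "lives in Berlin"
# is no longer in working memory and does not rank in the top 1 by recording_count,
# so it is removed in Step 3.
manager.update_memories(["likes tea"], partial_retention_number=1)

print(manager.memory_size)  # 1: the computed_field simply returns len(manager.memories)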
memos/mem_scheduler/mos_for_test_scheduler.py
@@ -0,0 +1,143 @@
+from datetime import datetime
+
+from memos.configs.mem_os import MOSConfig
+from memos.log import get_logger
+from memos.mem_os.main import MOS
+from memos.mem_scheduler.modules.schemas import (
+    ANSWER_LABEL,
+    MONITOR_WORKING_MEMORY_TYPE,
+    QUERY_LABEL,
+    ScheduleMessageItem,
+)
+
+
+logger = get_logger(__name__)
+
+
+class MOSForTestScheduler(MOS):
+    """This class is only to test abilities of mem scheduler"""
+
+    def __init__(self, config: MOSConfig):
+        super().__init__(config)
+
+    def _str_memories(self, memories: list[str]) -> str:
+        """Format memories for display."""
+        if not memories:
+            return "No memories."
+        return "\n".join(f"{i + 1}. {memory}" for i, memory in enumerate(memories))
+
+    def chat(self, query: str, user_id: str | None = None) -> str:
+        """
+        Chat with the MOS.
+
+        Args:
+            query (str): The user's query.
+
+        Returns:
+            str: The response from the MOS.
+        """
+        target_user_id = user_id if user_id is not None else self.user_id
+        accessible_cubes = self.user_manager.get_user_cubes(target_user_id)
+        user_cube_ids = [cube.cube_id for cube in accessible_cubes]
+        if target_user_id not in self.chat_history_manager:
+            self._register_chat_history(target_user_id)
+
+        chat_history = self.chat_history_manager[target_user_id]
+
+        topk_for_scheduler = 2
+
+        if self.config.enable_textual_memory and self.mem_cubes:
+            memories_all = []
+            for mem_cube_id, mem_cube in self.mem_cubes.items():
+                if mem_cube_id not in user_cube_ids:
+                    continue
+                if not mem_cube.text_mem:
+                    continue
+
+                # submit message to scheduler
+                if self.enable_mem_scheduler and self.mem_scheduler is not None:
+                    message_item = ScheduleMessageItem(
+                        user_id=target_user_id,
+                        mem_cube_id=mem_cube_id,
+                        mem_cube=mem_cube,
+                        label=QUERY_LABEL,
+                        content=query,
+                        timestamp=datetime.now(),
+                    )
+                    self.mem_scheduler.submit_messages(messages=[message_item])
+
+                self.mem_scheduler.monitor.register_memory_manager_if_not_exists(
+                    user_id=user_id,
+                    mem_cube_id=mem_cube_id,
+                    memory_monitors=self.mem_scheduler.monitor.working_memory_monitors,
+                    max_capacity=self.mem_scheduler.monitor.working_mem_monitor_capacity,
+                )
+
+                # from scheduler
+                scheduler_memories = self.mem_scheduler.monitor.get_monitor_memories(
+                    user_id=target_user_id,
+                    mem_cube_id=mem_cube_id,
+                    memory_type=MONITOR_WORKING_MEMORY_TYPE,
+                    top_k=topk_for_scheduler,
+                )
+                memories_all.extend(scheduler_memories)
+
+                # from mem_cube
+                memories = mem_cube.text_mem.search(
+                    query, top_k=self.config.top_k - topk_for_scheduler
+                )
+                text_memories = [m.memory for m in memories]
+                memories_all.extend(text_memories)
+
+            memories_all = list(set(memories_all))
+
+            logger.info(f"🧠 [Memory] Searched memories:\n{self._str_memories(memories_all)}\n")
+            system_prompt = self._build_system_prompt(memories_all)
+        else:
+            system_prompt = self._build_system_prompt()
+        current_messages = [
+            {"role": "system", "content": system_prompt},
+            *chat_history.chat_history,
+            {"role": "user", "content": query},
+        ]
+        past_key_values = None
+
+        if self.config.enable_activation_memory:
+            assert self.config.chat_model.backend == "huggingface", (
+                "Activation memory only used for huggingface backend."
+            )
+            # TODO this only one cubes
+            for mem_cube_id, mem_cube in self.mem_cubes.items():
+                if mem_cube_id not in user_cube_ids:
+                    continue
+                if mem_cube.act_mem:
+                    kv_cache = next(iter(mem_cube.act_mem.get_all()), None)
+                    past_key_values = (
+                        kv_cache.memory if (kv_cache and hasattr(kv_cache, "memory")) else None
+                    )
+                    break
+            # Generate response
+            response = self.chat_llm.generate(current_messages, past_key_values=past_key_values)
+        else:
+            response = self.chat_llm.generate(current_messages)
+        logger.info(f"🤖 [Assistant] {response}\n")
+        chat_history.chat_history.append({"role": "user", "content": query})
+        chat_history.chat_history.append({"role": "assistant", "content": response})
+        self.chat_history_manager[user_id] = chat_history
+
+        # submit message to scheduler
+        for accessible_mem_cube in accessible_cubes:
+            mem_cube_id = accessible_mem_cube.cube_id
+            mem_cube = self.mem_cubes[mem_cube_id]
+            if self.enable_mem_scheduler and self.mem_scheduler is not None:
+                message_item = ScheduleMessageItem(
+                    user_id=target_user_id,
+                    mem_cube_id=mem_cube_id,
+                    mem_cube=mem_cube,
+                    label=ANSWER_LABEL,
+                    content=response,
+                    timestamp=datetime.now(),
+                )
+                self.mem_scheduler.submit_messages(messages=[message_item])
+
+        return response
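For orientation, a hedged sketch of how this test-only class might be driven. It is not part of the diff: the config path is a placeholder, and MOSConfig.from_json_file is assumed to work as in the MemOS examples.

from memos.configs.mem_os import MOSConfig
from memos.mem_scheduler.mos_for_test_scheduler import MOSForTestScheduler

# Placeholder config file; any MOSConfig with the memory scheduler enabled should do.
config = MOSConfig.from_json_file("examples/data/config/simple_memos_config.json")
mos = MOSForTestScheduler(config)

# chat() blends up to 2 memories from the scheduler monitor with the mem cube's own
# text search before building the system prompt, then reports the answer back to the
# scheduler under ANSWER_LABEL.
response = mos.chat("What did I plan for the weekend?")
print(response)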
memos/mem_scheduler/utils.py
CHANGED
@@ -1,4 +1,5 @@
 import json
+import re
 
 from pathlib import Path
 
@@ -7,13 +8,41 @@ import yaml
 
 def extract_json_dict(text: str):
     text = text.strip()
-    patterns_to_remove = ["json
+    patterns_to_remove = ["json```", "```json", "latex```", "```latex", "```"]
     for pattern in patterns_to_remove:
         text = text.replace(pattern, "")
-    res = json.loads(text)
+    res = json.loads(text.strip())
     return res
 
 
+def transform_name_to_key(name):
+    """
+    Normalize text by removing all punctuation marks, keeping only letters, numbers, and word characters.
+
+    Args:
+        name (str): Input text to be processed
+
+    Returns:
+        str: Processed text with all punctuation removed
+    """
+    # Match all characters that are NOT:
+    # \w - word characters (letters, digits, underscore)
+    # \u4e00-\u9fff - Chinese/Japanese/Korean characters
+    # \s - whitespace
+    pattern = r"[^\w\u4e00-\u9fff\s]"
+
+    # Substitute all matched punctuation marks with empty string
+    # re.UNICODE flag ensures proper handling of Unicode characters
+    normalized = re.sub(pattern, "", name, flags=re.UNICODE)
+
+    # Optional: Collapse multiple whitespaces into single space
+    normalized = "_".join(normalized.split())
+
+    normalized = normalized.lower()
+
+    return normalized
+
+
 def parse_yaml(yaml_file):
     yaml_path = Path(yaml_file)
     yaml_path = Path(yaml_file)
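Illustrative calls for the two helpers above (not part of the diff; the inputs are made up):

from memos.mem_scheduler.utils import extract_json_dict, transform_name_to_key

# Code fences around an LLM's JSON reply are stripped before json.loads().
print(extract_json_dict('```json\n{"answer": 42}\n```'))  # {'answer': 42}

# Punctuation is dropped, whitespace collapses to "_", and the result is lowercased.
print(transform_name_to_key("Working Memory: top-K!"))    # working_memory_topk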
@@ -24,3 +53,23 @@ def parse_yaml(yaml_file):
         data = yaml.safe_load(fr)
 
     return data
+
+
+def is_all_english(input_string: str) -> bool:
+    """Determine if the string consists entirely of English characters (including spaces)"""
+    return all(char.isascii() or char.isspace() for char in input_string)
+
+
+def is_all_chinese(input_string: str) -> bool:
+    """Determine if the string consists entirely of Chinese characters (including Chinese punctuation and spaces)"""
+    return all(
+        ("\u4e00" <= char <= "\u9fff")  # Basic Chinese characters
+        or ("\u3400" <= char <= "\u4dbf")  # Extension A
+        or ("\u20000" <= char <= "\u2a6df")  # Extension B
+        or ("\u2a700" <= char <= "\u2b73f")  # Extension C
+        or ("\u2b740" <= char <= "\u2b81f")  # Extension D
+        or ("\u2b820" <= char <= "\u2ceaf")  # Extension E
+        or ("\u2f800" <= char <= "\u2fa1f")  # Extension F
+        or char.isspace()  # Spaces
+        for char in input_string
+    )
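And a quick illustration of the two new language checks, again not part of the diff:

from memos.mem_scheduler.utils import is_all_chinese, is_all_english

print(is_all_english("memory scheduler"))  # True: ASCII characters and spaces only
print(is_all_chinese("记忆调度"))           # True: every character falls in the CJK ranges
print(is_all_english("记忆 scheduler"))    # False: mixed-script input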