MemoryOS 0.2.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of MemoryOS might be problematic. Click here for more details.
- {memoryos-0.2.1.dist-info → memoryos-1.0.0.dist-info}/METADATA +7 -1
- {memoryos-0.2.1.dist-info → memoryos-1.0.0.dist-info}/RECORD +87 -64
- memos/__init__.py +1 -1
- memos/api/config.py +158 -69
- memos/api/context/context.py +147 -0
- memos/api/context/dependencies.py +101 -0
- memos/api/product_models.py +5 -1
- memos/api/routers/product_router.py +54 -26
- memos/configs/graph_db.py +49 -1
- memos/configs/internet_retriever.py +19 -0
- memos/configs/mem_os.py +5 -0
- memos/configs/mem_reader.py +9 -0
- memos/configs/mem_scheduler.py +54 -18
- memos/configs/mem_user.py +58 -0
- memos/graph_dbs/base.py +38 -3
- memos/graph_dbs/factory.py +2 -0
- memos/graph_dbs/nebular.py +1612 -0
- memos/graph_dbs/neo4j.py +18 -9
- memos/log.py +6 -1
- memos/mem_cube/utils.py +13 -6
- memos/mem_os/core.py +157 -37
- memos/mem_os/main.py +2 -2
- memos/mem_os/product.py +252 -201
- memos/mem_os/utils/default_config.py +1 -1
- memos/mem_os/utils/format_utils.py +281 -70
- memos/mem_os/utils/reference_utils.py +133 -0
- memos/mem_reader/simple_struct.py +13 -5
- memos/mem_scheduler/base_scheduler.py +239 -266
- memos/mem_scheduler/{modules → general_modules}/base.py +4 -5
- memos/mem_scheduler/{modules → general_modules}/dispatcher.py +57 -21
- memos/mem_scheduler/general_modules/misc.py +104 -0
- memos/mem_scheduler/{modules → general_modules}/rabbitmq_service.py +12 -10
- memos/mem_scheduler/{modules → general_modules}/redis_service.py +1 -1
- memos/mem_scheduler/general_modules/retriever.py +199 -0
- memos/mem_scheduler/general_modules/scheduler_logger.py +261 -0
- memos/mem_scheduler/general_scheduler.py +243 -80
- memos/mem_scheduler/monitors/__init__.py +0 -0
- memos/mem_scheduler/monitors/dispatcher_monitor.py +305 -0
- memos/mem_scheduler/{modules/monitor.py → monitors/general_monitor.py} +106 -57
- memos/mem_scheduler/mos_for_test_scheduler.py +23 -20
- memos/mem_scheduler/schemas/__init__.py +0 -0
- memos/mem_scheduler/schemas/general_schemas.py +44 -0
- memos/mem_scheduler/schemas/message_schemas.py +149 -0
- memos/mem_scheduler/schemas/monitor_schemas.py +337 -0
- memos/mem_scheduler/utils/__init__.py +0 -0
- memos/mem_scheduler/utils/filter_utils.py +176 -0
- memos/mem_scheduler/utils/misc_utils.py +102 -0
- memos/mem_user/factory.py +94 -0
- memos/mem_user/mysql_persistent_user_manager.py +271 -0
- memos/mem_user/mysql_user_manager.py +500 -0
- memos/mem_user/persistent_factory.py +96 -0
- memos/mem_user/user_manager.py +4 -4
- memos/memories/activation/item.py +5 -1
- memos/memories/activation/kv.py +20 -8
- memos/memories/textual/base.py +2 -2
- memos/memories/textual/general.py +36 -92
- memos/memories/textual/item.py +5 -33
- memos/memories/textual/tree.py +13 -7
- memos/memories/textual/tree_text_memory/organize/{conflict.py → handler.py} +34 -50
- memos/memories/textual/tree_text_memory/organize/manager.py +8 -96
- memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +49 -43
- memos/memories/textual/tree_text_memory/organize/reorganizer.py +107 -142
- memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +229 -0
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py +6 -3
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +11 -0
- memos/memories/textual/tree_text_memory/retrieve/recall.py +15 -8
- memos/memories/textual/tree_text_memory/retrieve/reranker.py +1 -1
- memos/memories/textual/tree_text_memory/retrieve/retrieval_mid_structs.py +2 -0
- memos/memories/textual/tree_text_memory/retrieve/searcher.py +191 -116
- memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +47 -15
- memos/memories/textual/tree_text_memory/retrieve/utils.py +11 -7
- memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py +62 -58
- memos/memos_tools/dinding_report_bot.py +422 -0
- memos/memos_tools/lockfree_dict.py +120 -0
- memos/memos_tools/notification_service.py +44 -0
- memos/memos_tools/notification_utils.py +96 -0
- memos/memos_tools/thread_safe_dict.py +288 -0
- memos/settings.py +3 -1
- memos/templates/mem_reader_prompts.py +4 -1
- memos/templates/mem_scheduler_prompts.py +62 -15
- memos/templates/mos_prompts.py +116 -0
- memos/templates/tree_reorganize_prompts.py +24 -17
- memos/utils.py +19 -0
- memos/mem_scheduler/modules/misc.py +0 -39
- memos/mem_scheduler/modules/retriever.py +0 -268
- memos/mem_scheduler/modules/schemas.py +0 -328
- memos/mem_scheduler/utils.py +0 -75
- memos/memories/textual/tree_text_memory/organize/redundancy.py +0 -193
- {memoryos-0.2.1.dist-info → memoryos-1.0.0.dist-info}/LICENSE +0 -0
- {memoryos-0.2.1.dist-info → memoryos-1.0.0.dist-info}/WHEEL +0 -0
- {memoryos-0.2.1.dist-info → memoryos-1.0.0.dist-info}/entry_points.txt +0 -0
- /memos/mem_scheduler/{modules → general_modules}/__init__.py +0 -0
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
import re
|
|
2
|
+
|
|
3
|
+
from memos.dependency import require_python_package
|
|
4
|
+
from memos.log import get_logger
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
# Module-level logger for the memory filtering helpers defined in this module.
logger = get_logger(__name__)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def transform_name_to_key(name):
    """
    Normalize a name into a lowercase, underscore-separated key.

    Processing steps:
      1. Remove every character that is neither a word character (letters,
         digits, underscore; Unicode-aware, so CJK ideographs are kept) nor
         whitespace.
      2. Split on whitespace and re-join the pieces with a single underscore,
         which also drops leading/trailing whitespace.
      3. Lowercase the result.

    Args:
        name (str): Input text to be processed

    Returns:
        str: Normalized key, e.g. ``"Hello, World!"`` -> ``"hello_world"``
    """
    # Match all characters that are NOT:
    #   \w            - word characters (letters, digits, underscore; Unicode-aware)
    #   \u4e00-\u9fff - CJK Unified Ideographs (already covered by \w, kept explicit)
    #   \s            - whitespace
    pattern = r"[^\w\u4e00-\u9fff\s]"

    # Drop the matched punctuation/symbol characters.
    # re.UNICODE is already the default for str patterns; kept for explicitness.
    normalized = re.sub(pattern, "", name, flags=re.UNICODE)

    # Replace each run of whitespace with a single underscore (also trims the ends).
    normalized = "_".join(normalized.split())

    return normalized.lower()
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def is_all_english(input_string: str) -> bool:
    """Return True when every character is ASCII or whitespace.

    "English" here means plain ASCII: digits and ASCII punctuation count as
    English, any whitespace character (including non-ASCII whitespace) is
    accepted, and an empty string yields True.
    """
    for ch in input_string:
        if not (ch.isascii() or ch.isspace()):
            return False
    return True
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def is_all_chinese(input_string: str) -> bool:
    """Determine if the string consists entirely of Chinese characters.

    Accepted characters:
      * CJK Unified Ideographs (basic block and Extensions A-F),
      * CJK Compatibility Ideographs Supplement,
      * CJK Symbols and Punctuation (e.g. "。", "、", "「", "」"),
      * Halfwidth and Fullwidth Forms (e.g. "，", "！", "？", "："),
      * a few General Punctuation marks common in Chinese text (“ ” ‘ ’ — … ·),
      * any whitespace character.

    An empty string yields True.

    Args:
        input_string: Text to classify.

    Returns:
        True if every character is in one of the accepted sets.
    """
    # Code points above U+FFFF must use the 8-digit "\U" escape: "\u20000" is the
    # two-character string "\u2000" + "0", which previously made the extension
    # ranges accept unrelated symbols (arrows, math operators, ...) while
    # rejecting real Extension B+ ideographs.
    cjk_ranges = (
        ("\u4e00", "\u9fff"),  # CJK Unified Ideographs (basic)
        ("\u3400", "\u4dbf"),  # Extension A
        ("\U00020000", "\U0002a6df"),  # Extension B
        ("\U0002a700", "\U0002b73f"),  # Extension C
        ("\U0002b740", "\U0002b81f"),  # Extension D
        ("\U0002b820", "\U0002ceaf"),  # Extension E
        ("\U0002ceb0", "\U0002ebef"),  # Extension F
        ("\U0002f800", "\U0002fa1f"),  # CJK Compatibility Ideographs Supplement
        ("\u3000", "\u303f"),  # CJK Symbols and Punctuation
        ("\uff00", "\uffef"),  # Halfwidth and Fullwidth Forms
    )
    # General Punctuation marks routinely used in Chinese text.
    chinese_general_punct = frozenset("\u00b7\u2014\u2018\u2019\u201c\u201d\u2026")

    return all(
        char.isspace()
        or char in chinese_general_punct
        or any(low <= char <= high for low, high in cjk_ranges)
        for char in input_string
    )
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@require_python_package(
    import_name="sklearn",
    install_command="pip install scikit-learn",
    install_link="https://scikit-learn.org/stable/install.html",
)
def filter_similar_memories(
    text_memories: list[str], similarity_threshold: float = 0.75
) -> list[str]:
    """
    Filters out near-duplicate memories based on TF-IDF cosine similarity.

    Memories are compared pairwise. Whenever a later memory has a similarity
    >= ``similarity_threshold`` with an earlier, still-kept memory, the later
    one is dropped. Surviving memories keep their original relative order.

    Args:
        text_memories: List of text memories to filter. Non-string entries are
            converted with ``str()``; the caller's list is not modified.
        similarity_threshold: Similarity (0.0-1.0) at or above which two memories
            count as duplicates. Higher values treat only more similar texts as
            duplicates, so fewer memories are removed.

    Returns:
        List of filtered memories with duplicates removed. If vectorization or
        similarity computation fails, the (string-converted) input is returned
        unfiltered.
    """
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.metrics.pairwise import cosine_similarity

    if not text_memories:
        logger.warning("Received empty memories list - nothing to filter")
        return []

    # Build a string-only copy instead of mutating the caller's list in place.
    memories: list[str] = []
    for item in text_memories:
        if not isinstance(item, str):
            logger.error(
                f"{item} in memories is not a string,"
                f" and now has been transformed to be a string."
            )
            item = str(item)
        memories.append(item)

    try:
        # Step 1: Vectorize texts using TF-IDF
        tfidf_matrix = TfidfVectorizer().fit_transform(memories)

        # Step 2: Calculate pairwise similarity matrix
        similarity_matrix = cosine_similarity(tfidf_matrix)

        # Step 3: Greedily keep the first memory of each group of similar ones
        count = len(memories)
        to_keep = set(range(count))
        for i in range(count):
            if i not in to_keep:
                continue  # Already removed as a duplicate of an earlier memory
            duplicates = {
                j
                for j in range(i + 1, count)
                if j in to_keep and similarity_matrix[i][j] >= similarity_threshold
            }
            to_keep -= duplicates

        filtered_memories = [memories[i] for i in sorted(to_keep)]
        logger.debug(f"filtered_memories: {filtered_memories}")
        return filtered_memories

    except Exception as e:
        logger.error(f"Error filtering memories: {e!s}")
        return memories  # Return the unfiltered input if an error occurs
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def filter_too_short_memories(
    text_memories: list[str], min_length_threshold: int = 20
) -> list[str]:
    """
    Drop memories whose measured length is below ``min_length_threshold``.

    Length is measured on the stripped memory as:
      * the number of whitespace-separated words when the text is entirely
        ASCII/whitespace (per ``is_all_english``);
      * the number of characters otherwise (pure Chinese or mixed text).
    Empty or whitespace-only memories are always dropped.

    Args:
        text_memories: List of text memories to be filtered
        min_length_threshold: Minimum length required to keep a memory
            (words for ASCII text, characters otherwise).

    Returns:
        The kept memories, unmodified (not stripped) and in their original order.
    """
    if not text_memories:
        logger.debug("Empty memories list received in short memory filter")
        return []

    def measure(text: str) -> int:
        # Word count for pure-ASCII text, character count for everything else.
        if is_all_english(text):
            return len(text.split())
        if not is_all_chinese(text):
            logger.debug(f"Mixed-language memory, using character count: {text[:50]}...")
        return len(text)

    kept: list[str] = []
    dropped = 0
    for memory in text_memories:
        stripped = memory.strip()
        if stripped and measure(stripped) >= min_length_threshold:
            kept.append(memory)
        else:
            dropped += 1

    if dropped:
        logger.info(
            f"Filtered out {dropped} short memories "
            f"(below {min_length_threshold} units). "
            f"Total remaining: {len(kept)}"
        )

    return kept
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import re
|
|
3
|
+
|
|
4
|
+
from functools import wraps
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
import yaml
|
|
8
|
+
|
|
9
|
+
from memos.log import get_logger
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Module-level logger; also the default logger bound by ``log_exceptions``.
logger = get_logger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def extract_json_dict(text: str):
    """
    Safely extracts JSON from LLM response text with robust error handling.

    Strategy (the first one that parses wins):
      1. Strip common code-fence markers (```json, ```python, ``` ...) and parse
         the whole text.
      2. Extract the span from the first ``{``/``[`` to the last matching closing
         bracket with a regex and parse it.
      3. Quote bare object keys (``{key: 1}`` -> ``{"key": 1}``) in that span
         (or in the whole text if no span was found) and parse again.

    Args:
        text: Raw text response from LLM that may contain JSON

    Returns:
        Parsed JSON data (dict or list, despite the function name)

    Raises:
        ValueError: If the input is empty or no valid JSON can be extracted
    """
    if not text:
        raise ValueError("Empty input text")

    # Normalize the text and remove common code block markers.
    text = text.strip()
    for marker in ["json```", "```python", "```json", "latex```", "```latex", "```"]:
        text = text.replace(marker, "")
    text = text.strip()

    # Try: direct JSON parse first
    try:
        return json.loads(text)
    except json.JSONDecodeError as e:
        # Expected when the JSON is wrapped in prose; the fallbacks handle it.
        logger.warning(f"Direct JSON parse failed, trying fallbacks. Error: {e!s}")

    # Fallback 1: Extract the object/array span using regex
    match = re.search(r"\{[\s\S]*\}|\[[\s\S]*\]", text)
    candidate = match.group(0) if match else text
    if match:
        try:
            return json.loads(candidate)
        except json.JSONDecodeError as e:
            logger.warning(f"Failed to parse extracted JSON span. Error: {e!s}")

    # Fallback 2: Handle malformed JSON (common LLM issue: unquoted keys).
    # Repair only the extracted span so surrounding prose cannot break parsing.
    repaired = re.sub(r"([\{\s,])(\w+)(:)", r'\1"\2"\3', candidate)
    try:
        return json.loads(repaired)
    except json.JSONDecodeError as e:
        logger.error(f"Failed to parse JSON from text: {text}. Error: {e!s}", exc_info=True)
        raise ValueError(text) from e
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def parse_yaml(yaml_file: str | Path):
|
|
65
|
+
yaml_path = Path(yaml_file)
|
|
66
|
+
if not yaml_path.is_file():
|
|
67
|
+
raise FileNotFoundError(f"No such file: {yaml_file}")
|
|
68
|
+
|
|
69
|
+
with yaml_path.open("r", encoding="utf-8") as fr:
|
|
70
|
+
data = yaml.safe_load(fr)
|
|
71
|
+
|
|
72
|
+
return data
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def log_exceptions(logger=logger, *, reraise: bool = False):
    """
    Exception-catching decorator that logs errors (including stack traces).

    By default the exception is swallowed after logging, and the decorated
    function returns ``None`` in place of a result. Pass ``reraise=True`` to log
    the error and then propagate it to the caller.

    Args:
        logger: Optional logger object (default: module-level logger)
        reraise: If True, re-raise the exception after logging it.

    Example:
        @log_exceptions()
        def risky_function():
            raise ValueError("Oops!")

        @log_exceptions(logger=custom_logger, reraise=True)
        def another_risky_function():
            might_fail()
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                logger.error(f"Error in {func.__name__}: {e}", exc_info=True)
                if reraise:
                    raise
                # Swallowed: the caller receives None instead of a result.
                return None

        return wrapper

    return decorator
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
from typing import Any, ClassVar
|
|
2
|
+
|
|
3
|
+
from memos.configs.mem_user import UserManagerConfigFactory
|
|
4
|
+
from memos.mem_user.mysql_user_manager import MySQLUserManager
|
|
5
|
+
from memos.mem_user.user_manager import UserManager
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class UserManagerFactory:
    """Factory that builds user-manager instances for a configured backend.

    The supported backends are the keys of ``backend_to_class``:
    ``"sqlite"`` maps to ``UserManager`` and ``"mysql"`` to ``MySQLUserManager``.
    """

    backend_to_class: ClassVar[dict[str, Any]] = {
        "sqlite": UserManager,
        "mysql": MySQLUserManager,
    }

    @classmethod
    def from_config(
        cls, config_factory: UserManagerConfigFactory
    ) -> UserManager | MySQLUserManager:
        """Instantiate the manager class registered for ``config_factory.backend``.

        The backend-specific config model is dumped with ``model_dump()`` and its
        fields are passed to the manager's constructor as keyword arguments.

        Args:
            config_factory: Configuration factory containing backend and config

        Returns:
            User manager instance

        Raises:
            ValueError: If backend is not supported
        """
        registry = cls.backend_to_class
        backend_name = config_factory.backend
        if backend_name in registry:
            manager_cls = registry[backend_name]
            return manager_cls(**config_factory.config.model_dump())
        raise ValueError(f"Invalid user manager backend: {backend_name}")

    @classmethod
    def create_sqlite(cls, db_path: str | None = None, user_id: str = "root") -> UserManager:
        """Build a SQLite-backed user manager.

        Args:
            db_path: Path to SQLite database file (passed through as-is; may be None)
            user_id: Default user ID for initialization

        Returns:
            SQLite user manager instance
        """
        sqlite_settings = {"db_path": db_path, "user_id": user_id}
        return cls.from_config(
            UserManagerConfigFactory(backend="sqlite", config=sqlite_settings)
        )

    @classmethod
    def create_mysql(
        cls,
        user_id: str = "root",
        host: str = "localhost",
        port: int = 3306,
        username: str = "root",
        password: str = "",
        database: str = "memos_users",
        charset: str = "utf8mb4",
    ) -> MySQLUserManager:
        """Build a MySQL-backed user manager from explicit connection settings.

        Args:
            user_id: Default user ID for initialization
            host: MySQL server host
            port: MySQL server port
            username: MySQL username
            password: MySQL password
            database: MySQL database name
            charset: MySQL charset

        Returns:
            MySQL user manager instance
        """
        mysql_settings = dict(
            user_id=user_id,
            host=host,
            port=port,
            username=username,
            password=password,
            database=database,
            charset=charset,
        )
        factory_config = UserManagerConfigFactory(backend="mysql", config=mysql_settings)
        return cls.from_config(factory_config)
|
|
@@ -0,0 +1,271 @@
|
|
|
1
|
+
"""Persistent user management system for MemOS with configuration storage.

This module extends the MySQL-backed ``MySQLUserManager`` with a ``user_configs``
table so that each user's ``MOSConfig`` can be saved to, loaded from, listed in,
and deleted from the database as a JSON document.
"""
|
|
6
|
+
|
|
7
|
+
import json
|
|
8
|
+
|
|
9
|
+
from datetime import datetime
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
from sqlalchemy import Column, String, Text
|
|
13
|
+
|
|
14
|
+
from memos.configs.mem_os import MOSConfig
|
|
15
|
+
from memos.log import get_logger
|
|
16
|
+
from memos.mem_user.mysql_user_manager import Base, MySQLUserManager
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Module-level logger for the persistent user manager.
logger = get_logger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class UserConfig(Base):
    """User configuration model for the database.

    One row per user, holding that user's ``MOSConfig`` serialized as JSON text.
    """

    __tablename__ = "user_configs"

    # Primary key: at most one stored configuration per user.
    user_id = Column(String(255), primary_key=True)
    config_data = Column(Text, nullable=False)  # JSON string of MOSConfig
    # Timestamps are kept as ISO-format strings rather than DATETIME columns.
    created_at = Column(String(50), nullable=False)  # ISO format timestamp
    updated_at = Column(String(50), nullable=False)  # ISO format timestamp

    def __repr__(self):
        return f"<UserConfig(user_id='{self.user_id}')>"
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class MySQLPersistentUserManager(MySQLUserManager):
    """Extended MySQLUserManager with configuration persistence.

    In addition to the behavior inherited from ``MySQLUserManager``, each user's
    ``MOSConfig`` is stored as JSON in the ``user_configs`` table (``UserConfig``).
    """

    def __init__(
        self,
        user_id: str = "root",
        host: str = "localhost",
        port: int = 3306,
        username: str = "root",
        password: str = "",
        database: str = "memos_users",
        charset: str = "utf8mb4",
    ):
        """Initialize the persistent user manager.

        Args:
            user_id (str): User ID forwarded to ``MySQLUserManager``. Defaults to "root".
            host (str): MySQL server host. Defaults to "localhost".
            port (int): MySQL server port. Defaults to 3306.
            username (str): MySQL username. Defaults to "root".
            password (str): MySQL password. Defaults to "".
            database (str): MySQL database name. Defaults to "memos_users".
            charset (str): MySQL charset. Defaults to "utf8mb4".
        """
        super().__init__(user_id, host, port, username, password, database, charset)

        # Ensure the user_configs table exists (create_all only creates missing tables).
        Base.metadata.create_all(bind=self.engine)
        logger.info("MySQLPersistentUserManager initialized with configuration storage")

    def _convert_datetime_strings(self, obj: Any) -> Any:
        """Recursively convert ``created_at`` ISO strings back to datetime objects.

        Only string values stored under the key ``"created_at"`` (at any nesting
        depth inside dicts/lists) are converted. Strings that
        ``datetime.fromisoformat`` cannot parse are kept unchanged.

        Args:
            obj: The object to process (dict, list, or primitive type)

        Returns:
            A new dict/list with the conversions applied, or ``obj`` unchanged
            for other types.
        """
        if isinstance(obj, dict):
            result = {}
            for key, value in obj.items():
                if key == "created_at" and isinstance(value, str):
                    try:
                        result[key] = datetime.fromisoformat(value)
                    except ValueError:
                        # If parsing fails, keep the original string
                        result[key] = value
                else:
                    result[key] = self._convert_datetime_strings(value)
            return result
        elif isinstance(obj, list):
            return [self._convert_datetime_strings(item) for item in obj]
        else:
            return obj

    def save_user_config(self, user_id: str, config: MOSConfig) -> bool:
        """Insert or update the stored configuration for a user.

        Args:
            user_id (str): The user ID.
            config (MOSConfig): The user's MOS configuration.

        Returns:
            bool: True if successful, False otherwise (errors are logged and the
            transaction is rolled back).
        """
        session = self._get_session()
        try:
            # Convert config to JSON string with proper datetime handling
            config_dict = config.model_dump(mode="json")
            config_json = json.dumps(config_dict, indent=2)

            now = datetime.now().isoformat()

            # Check if config already exists
            existing_config = (
                session.query(UserConfig).filter(UserConfig.user_id == user_id).first()
            )

            if existing_config:
                # Update existing config; created_at is preserved
                existing_config.config_data = config_json
                existing_config.updated_at = now
                logger.info(f"Updated configuration for user {user_id}")
            else:
                # Create new config
                user_config = UserConfig(
                    user_id=user_id, config_data=config_json, created_at=now, updated_at=now
                )
                session.add(user_config)
                logger.info(f"Saved new configuration for user {user_id}")

            session.commit()
            return True

        except Exception as e:
            session.rollback()
            logger.error(f"Error saving user config for {user_id}: {e}")
            return False
        finally:
            session.close()

    def get_user_config(self, user_id: str) -> MOSConfig | None:
        """Get user configuration from database.

        Args:
            user_id (str): The user ID.

        Returns:
            MOSConfig | None: The user's configuration, or None if not found or
            if loading/parsing fails (the error is logged).
        """
        session = self._get_session()
        try:
            user_config = session.query(UserConfig).filter(UserConfig.user_id == user_id).first()

            if user_config:
                config_dict = json.loads(user_config.config_data)
                # Convert datetime strings back to datetime objects
                config_dict = self._convert_datetime_strings(config_dict)
                return MOSConfig(**config_dict)
            return None

        except Exception as e:
            logger.error(f"Error loading user config for {user_id}: {e}")
            return None
        finally:
            session.close()

    def delete_user_config(self, user_id: str) -> bool:
        """Delete user configuration from database.

        Args:
            user_id (str): The user ID.

        Returns:
            bool: True if a configuration row was deleted; False if none existed
            or an error occurred (the error is logged and rolled back).
        """
        session = self._get_session()
        try:
            user_config = session.query(UserConfig).filter(UserConfig.user_id == user_id).first()

            if user_config:
                session.delete(user_config)
                session.commit()
                logger.info(f"Deleted configuration for user {user_id}")
                return True
            return False

        except Exception as e:
            session.rollback()
            logger.error(f"Error deleting user config for {user_id}: {e}")
            return False
        finally:
            session.close()

    def list_user_configs(self) -> dict[str, MOSConfig]:
        """List all user configurations.

        Rows whose stored JSON cannot be turned into a ``MOSConfig`` are logged
        and skipped.

        Returns:
            Dict[str, MOSConfig]: Dictionary mapping user_id to MOSConfig
            (empty if the query itself fails).
        """
        session = self._get_session()
        try:
            user_configs = session.query(UserConfig).all()
            result = {}

            for user_config in user_configs:
                try:
                    config_dict = json.loads(user_config.config_data)
                    # Convert datetime strings back to datetime objects
                    config_dict = self._convert_datetime_strings(config_dict)
                    result[user_config.user_id] = MOSConfig(**config_dict)
                except Exception as e:
                    logger.error(f"Error parsing config for user {user_config.user_id}: {e}")
                    continue

            return result

        except Exception as e:
            logger.error(f"Error listing user configs: {e}")
            return {}
        finally:
            session.close()

    def create_user_with_config(
        self, user_name: str, config: MOSConfig, role=None, user_id: str | None = None
    ) -> str:
        """Create a new user and store their configuration.

        Note: if saving the configuration fails, the error is logged but the user
        is still created and its ID is returned.

        Args:
            user_name (str): Name of the user.
            config (MOSConfig): The user's configuration.
            role: User role (optional; forwarded to ``create_user``).
            user_id (str, optional): Custom user ID.

        Returns:
            str: The created user ID.

        Raises:
            ValueError: If user_name already exists.
        """
        # Create user using parent method
        created_user_id = self.create_user(user_name, role, user_id)

        # Save configuration
        if not self.save_user_config(created_user_id, config):
            logger.error(f"Failed to save configuration for user {created_user_id}")

        return created_user_id

    def delete_user(self, user_id: str) -> bool:
        """Delete a user and their configuration.

        The user is deleted first via the parent implementation; the stored
        configuration is removed only if that succeeds, so a refused or failed
        user deletion no longer wipes the user's configuration.

        Args:
            user_id (str): The user ID.

        Returns:
            bool: True if the user was deleted, False otherwise.
        """
        if not super().delete_user(user_id):
            return False

        # Clean up the stored configuration. A missing row is not an error, and
        # delete_user_config logs its own failures.
        self.delete_user_config(user_id)
        return True

    def get_user_cube_access(self, user_id: str) -> list[str]:
        """Get list of cube IDs that a user has access to.

        Args:
            user_id (str): The user ID.

        Returns:
            list[str]: List of cube IDs the user can access.
        """
        cubes = self.get_user_cubes(user_id)
        return [cube.cube_id for cube in cubes]
|