MemoryOS 1.0.1-py3-none-any.whl → 1.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of MemoryOS might be problematic.

Files changed (82)
  1. {memoryos-1.0.1.dist-info → memoryos-1.1.2.dist-info}/METADATA +7 -2
  2. {memoryos-1.0.1.dist-info → memoryos-1.1.2.dist-info}/RECORD +79 -65
  3. {memoryos-1.0.1.dist-info → memoryos-1.1.2.dist-info}/WHEEL +1 -1
  4. memos/__init__.py +1 -1
  5. memos/api/client.py +109 -0
  6. memos/api/config.py +11 -9
  7. memos/api/context/dependencies.py +15 -55
  8. memos/api/middleware/request_context.py +9 -40
  9. memos/api/product_api.py +2 -3
  10. memos/api/product_models.py +91 -16
  11. memos/api/routers/product_router.py +23 -16
  12. memos/api/start_api.py +10 -0
  13. memos/configs/graph_db.py +4 -0
  14. memos/configs/mem_scheduler.py +38 -3
  15. memos/context/context.py +255 -0
  16. memos/embedders/factory.py +2 -0
  17. memos/graph_dbs/nebular.py +230 -232
  18. memos/graph_dbs/neo4j.py +35 -1
  19. memos/graph_dbs/neo4j_community.py +7 -0
  20. memos/llms/factory.py +2 -0
  21. memos/llms/openai.py +74 -2
  22. memos/log.py +27 -15
  23. memos/mem_cube/general.py +3 -1
  24. memos/mem_os/core.py +60 -22
  25. memos/mem_os/main.py +3 -6
  26. memos/mem_os/product.py +35 -11
  27. memos/mem_reader/factory.py +2 -0
  28. memos/mem_reader/simple_struct.py +127 -74
  29. memos/mem_scheduler/analyzer/__init__.py +0 -0
  30. memos/mem_scheduler/analyzer/mos_for_test_scheduler.py +569 -0
  31. memos/mem_scheduler/analyzer/scheduler_for_eval.py +280 -0
  32. memos/mem_scheduler/base_scheduler.py +126 -56
  33. memos/mem_scheduler/general_modules/dispatcher.py +2 -2
  34. memos/mem_scheduler/general_modules/misc.py +99 -1
  35. memos/mem_scheduler/general_modules/scheduler_logger.py +17 -11
  36. memos/mem_scheduler/general_scheduler.py +40 -88
  37. memos/mem_scheduler/memory_manage_modules/__init__.py +5 -0
  38. memos/mem_scheduler/memory_manage_modules/memory_filter.py +308 -0
  39. memos/mem_scheduler/{general_modules → memory_manage_modules}/retriever.py +34 -7
  40. memos/mem_scheduler/monitors/dispatcher_monitor.py +9 -8
  41. memos/mem_scheduler/monitors/general_monitor.py +119 -39
  42. memos/mem_scheduler/optimized_scheduler.py +124 -0
  43. memos/mem_scheduler/orm_modules/__init__.py +0 -0
  44. memos/mem_scheduler/orm_modules/base_model.py +635 -0
  45. memos/mem_scheduler/orm_modules/monitor_models.py +261 -0
  46. memos/mem_scheduler/scheduler_factory.py +2 -0
  47. memos/mem_scheduler/schemas/monitor_schemas.py +96 -29
  48. memos/mem_scheduler/utils/config_utils.py +100 -0
  49. memos/mem_scheduler/utils/db_utils.py +33 -0
  50. memos/mem_scheduler/utils/filter_utils.py +1 -1
  51. memos/mem_scheduler/webservice_modules/__init__.py +0 -0
  52. memos/memories/activation/kv.py +2 -1
  53. memos/memories/textual/item.py +95 -16
  54. memos/memories/textual/naive.py +1 -1
  55. memos/memories/textual/tree.py +27 -3
  56. memos/memories/textual/tree_text_memory/organize/handler.py +4 -2
  57. memos/memories/textual/tree_text_memory/organize/manager.py +28 -14
  58. memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +1 -2
  59. memos/memories/textual/tree_text_memory/organize/reorganizer.py +75 -23
  60. memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +7 -5
  61. memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py +6 -2
  62. memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +2 -0
  63. memos/memories/textual/tree_text_memory/retrieve/recall.py +70 -22
  64. memos/memories/textual/tree_text_memory/retrieve/searcher.py +101 -33
  65. memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py +5 -4
  66. memos/memos_tools/singleton.py +174 -0
  67. memos/memos_tools/thread_safe_dict.py +22 -0
  68. memos/memos_tools/thread_safe_dict_segment.py +382 -0
  69. memos/parsers/factory.py +2 -0
  70. memos/reranker/concat.py +59 -0
  71. memos/reranker/cosine_local.py +1 -0
  72. memos/reranker/factory.py +5 -0
  73. memos/reranker/http_bge.py +225 -12
  74. memos/templates/mem_scheduler_prompts.py +242 -0
  75. memos/types.py +4 -1
  76. memos/api/context/context.py +0 -147
  77. memos/api/context/context_thread.py +0 -96
  78. memos/mem_scheduler/mos_for_test_scheduler.py +0 -146
  79. {memoryos-1.0.1.dist-info → memoryos-1.1.2.dist-info}/entry_points.txt +0 -0
  80. {memoryos-1.0.1.dist-info → memoryos-1.1.2.dist-info/licenses}/LICENSE +0 -0
  81. /memos/mem_scheduler/{general_modules → webservice_modules}/rabbitmq_service.py +0 -0
  82. /memos/mem_scheduler/{general_modules → webservice_modules}/redis_service.py +0 -0
memos/graph_dbs/neo4j.py CHANGED
@@ -1,3 +1,4 @@
+import json
 import time
 
 from datetime import datetime
@@ -174,6 +175,12 @@ class Neo4jGraphDB(BaseGraphDB):
             n.updated_at = datetime($updated_at),
             n += $metadata
         """
+
+        # serialization
+        if metadata["sources"]:
+            for idx in range(len(metadata["sources"])):
+                metadata["sources"][idx] = json.dumps(metadata["sources"][idx])
+
         with self.driver.session(database=self.db_name) as session:
             session.run(
                 query,
@@ -606,6 +613,7 @@ class Neo4jGraphDB(BaseGraphDB):
         scope: str | None = None,
         status: str | None = None,
         threshold: float | None = None,
+        search_filter: dict | None = None,
         **kwargs,
     ) -> list[dict]:
         """
@@ -618,6 +626,8 @@ class Neo4jGraphDB(BaseGraphDB):
             status (str, optional): Node status filter (e.g., 'active', 'archived').
                 If provided, restricts results to nodes with matching status.
             threshold (float, optional): Minimum similarity score threshold (0 ~ 1).
+            search_filter (dict, optional): Additional metadata filters for search results.
+                Keys should match node properties, values are the expected values.
 
         Returns:
             list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
@@ -627,6 +637,7 @@ class Neo4jGraphDB(BaseGraphDB):
             - If scope is provided, it restricts results to nodes with matching memory_type.
             - If 'status' is provided, only nodes with the matching status will be returned.
             - If threshold is provided, only results with score >= threshold will be returned.
+            - If search_filter is provided, additional WHERE clauses will be added for metadata filtering.
             - Typical use case: restrict to 'status = activated' to avoid
               matching archived or merged nodes.
         """
@@ -639,6 +650,12 @@ class Neo4jGraphDB(BaseGraphDB):
         if not self.config.use_multi_db and self.config.user_name:
             where_clauses.append("node.user_name = $user_name")
 
+        # Add search_filter conditions
+        if search_filter:
+            for key, _ in search_filter.items():
+                param_name = f"filter_{key}"
+                where_clauses.append(f"node.{key} = ${param_name}")
+
         where_clause = ""
         if where_clauses:
             where_clause = "WHERE " + " AND ".join(where_clauses)
@@ -650,7 +667,8 @@ class Neo4jGraphDB(BaseGraphDB):
         RETURN node.id AS id, score
         """
 
-        parameters = {"embedding": vector, "k": top_k, "scope": scope}
+        parameters = {"embedding": vector, "k": top_k}
+
         if scope:
             parameters["scope"] = scope
         if status:
@@ -661,6 +679,12 @@ class Neo4jGraphDB(BaseGraphDB):
         else:
             parameters["user_name"] = self.config.user_name
 
+        # Add search_filter parameters
+        if search_filter:
+            for key, value in search_filter.items():
+                param_name = f"filter_{key}"
+                parameters[param_name] = value
+
         with self.driver.session(database=self.db_name) as session:
             result = session.run(query, parameters)
             records = [{"id": record["id"], "score": record["score"]} for record in result]
@@ -1111,4 +1135,14 @@ class Neo4jGraphDB(BaseGraphDB):
             node[time_field] = node[time_field].isoformat()
         node.pop("user_name", None)
 
+        # serialization
+        if node["sources"]:
+            for idx in range(len(node["sources"])):
+                if not (
+                    isinstance(node["sources"][idx], str)
+                    and node["sources"][idx][0] == "{"
+                    and node["sources"][idx][0] == "}"
+                ):
+                    break
+                node["sources"][idx] = json.loads(node["sources"][idx])
         return {"id": node.pop("id"), "memory": node.pop("memory", ""), "metadata": node}
memos/graph_dbs/neo4j_community.py CHANGED
@@ -129,6 +129,7 @@ class Neo4jCommunityGraphDB(Neo4jGraphDB):
         scope: str | None = None,
         status: str | None = None,
         threshold: float | None = None,
+        search_filter: dict | None = None,
         **kwargs,
     ) -> list[dict]:
         """
@@ -140,6 +141,7 @@ class Neo4jCommunityGraphDB(Neo4jGraphDB):
             scope (str, optional): Memory type filter (e.g., 'WorkingMemory', 'LongTermMemory').
             status (str, optional): Node status filter (e.g., 'activated', 'archived').
             threshold (float, optional): Minimum similarity score threshold (0 ~ 1).
+            search_filter (dict, optional): Additional metadata filters to apply.
 
         Returns:
             list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
@@ -149,6 +151,7 @@ class Neo4jCommunityGraphDB(Neo4jGraphDB):
             - If 'scope' is provided, it restricts results to nodes with matching memory_type.
             - If 'status' is provided, it further filters nodes by status.
             - If 'threshold' is provided, only results with score >= threshold will be returned.
+            - If 'search_filter' is provided, it applies additional metadata-based filtering.
             - The returned IDs can be used to fetch full node data from Neo4j if needed.
         """
         # Build VecDB filter
@@ -163,6 +166,10 @@ class Neo4jCommunityGraphDB(Neo4jGraphDB):
         else:
             vec_filter["user_name"] = self.config.user_name
 
+        # Add search_filter conditions
+        if search_filter:
+            vec_filter.update(search_filter)
+
         # Perform vector search
         results = self.vec_db.search(query_vector=vector, top_k=top_k, filter=vec_filter)
 
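Both vector-search backends now accept the same search_filter contract: every metadata key becomes a parameterized equality predicate (node.<key> = $filter_<key> in Cypher, or a plain key/value entry in the VecDB filter), so values are always bound as parameters while the property names themselves are still formatted into the query and are expected to come from trusted callers. The standalone Python sketch below mirrors that clause-building logic; it is an illustration, not the package's own code, and the function name is invented. Note also that the deserialization guard in the last neo4j.py hunk compares node["sources"][idx][0] against both "{" and "}"; the second comparison presumably targets the string's last character.

def build_filter_clause(search_filter: dict | None, user_name: str | None = None) -> tuple[str, dict]:
    """Build a Cypher WHERE clause plus bound parameters for a vector search (illustrative)."""
    where_clauses: list[str] = []
    parameters: dict = {}

    # Multi-tenant deployments scope results to the owning user first.
    if user_name:
        where_clauses.append("node.user_name = $user_name")
        parameters["user_name"] = user_name

    # Each metadata key becomes `node.<key> = $filter_<key>`; values are bound,
    # never interpolated into the query text.
    if search_filter:
        for key, value in search_filter.items():
            param_name = f"filter_{key}"
            where_clauses.append(f"node.{key} = ${param_name}")
            parameters[param_name] = value

    where_clause = "WHERE " + " AND ".join(where_clauses) if where_clauses else ""
    return where_clause, parameters


clause, params = build_filter_clause({"session_id": "sess-42"}, user_name="alice")
# clause -> WHERE node.user_name = $user_name AND node.session_id = $filter_session_id
# params -> {'user_name': 'alice', 'filter_session_id': 'sess-42'}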
memos/llms/factory.py CHANGED
@@ -9,6 +9,7 @@ from memos.llms.ollama import OllamaLLM
 from memos.llms.openai import AzureLLM, OpenAILLM
 from memos.llms.qwen import QwenLLM
 from memos.llms.vllm import VLLMLLM
+from memos.memos_tools.singleton import singleton_factory
 
 
 class LLMFactory(BaseLLM):
@@ -26,6 +27,7 @@ class LLMFactory(BaseLLM):
     }
 
     @classmethod
+    @singleton_factory()
     def from_config(cls, config_factory: LLMConfigFactory) -> BaseLLM:
         backend = config_factory.backend
         if backend not in cls.backend_to_class:
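singleton_factory comes from the new memos/memos_tools/singleton.py (file 66 in the list above), whose body is not shown in this diff. A hypothetical sketch of a config-keyed factory cache in that spirit, assuming the decorated method is a classmethod whose single argument is a pydantic-style config exposing model_dump():

import functools
import hashlib
import json


def singleton_factory(cache: dict | None = None):
    """Cache factory results keyed by a hash of the serialized config argument (illustrative)."""
    instances = cache if cache is not None else {}

    def decorator(func):
        @functools.wraps(func)
        def wrapper(cls, config_factory, *args, **kwargs):
            # Stable key: serialize the config deterministically, then hash it.
            key = hashlib.md5(
                json.dumps(config_factory.model_dump(), sort_keys=True).encode()
            ).hexdigest()
            if key not in instances:
                instances[key] = func(cls, config_factory, *args, **kwargs)
            return instances[key]

        return wrapper

    return decorator

With this shape, repeated LLMFactory.from_config(config) calls for an identical config return the same backend instance instead of reconstructing clients each time.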
memos/llms/openai.py CHANGED
@@ -1,4 +1,8 @@
+import hashlib
+import json
+
 from collections.abc import Generator
+from typing import ClassVar
 
 import openai
 
@@ -13,11 +17,44 @@ logger = get_logger(__name__)
 
 
 class OpenAILLM(BaseLLM):
-    """OpenAI LLM class."""
+    """OpenAI LLM class with singleton pattern."""
+
+    _instances: ClassVar[dict] = {}  # Class variable to store instances
+
+    def __new__(cls, config: OpenAILLMConfig) -> "OpenAILLM":
+        config_hash = cls._get_config_hash(config)
+
+        if config_hash not in cls._instances:
+            logger.info(f"Creating new OpenAI LLM instance for config hash: {config_hash}")
+            instance = super().__new__(cls)
+            cls._instances[config_hash] = instance
+        else:
+            logger.info(f"Reusing existing OpenAI LLM instance for config hash: {config_hash}")
+
+        return cls._instances[config_hash]
 
     def __init__(self, config: OpenAILLMConfig):
+        # Avoid duplicate initialization
+        if hasattr(self, "_initialized"):
+            return
+
         self.config = config
         self.client = openai.Client(api_key=config.api_key, base_url=config.api_base)
+        self._initialized = True
+        logger.info("OpenAI LLM instance initialized")
+
+    @classmethod
+    def _get_config_hash(cls, config: OpenAILLMConfig) -> str:
+        """Generate hash value of configuration"""
+        config_dict = config.model_dump()
+        config_str = json.dumps(config_dict, sort_keys=True)
+        return hashlib.md5(config_str.encode()).hexdigest()
+
+    @classmethod
+    def clear_cache(cls):
+        """Clear all cached instances"""
+        cls._instances.clear()
+        logger.info("OpenAI LLM instance cache cleared")
 
     def generate(self, messages: MessageList) -> str:
         """Generate a response from OpenAI LLM."""
@@ -71,15 +108,50 @@ class OpenAILLM(BaseLLM):
 
 
 class AzureLLM(BaseLLM):
-    """Azure OpenAI LLM class."""
+    """Azure OpenAI LLM class with singleton pattern."""
+
+    _instances: ClassVar[dict] = {}  # Class variable to store instances
+
+    def __new__(cls, config: AzureLLMConfig):
+        # Generate hash value of config as cache key
+        config_hash = cls._get_config_hash(config)
+
+        if config_hash not in cls._instances:
+            logger.info(f"Creating new Azure LLM instance for config hash: {config_hash}")
+            instance = super().__new__(cls)
+            cls._instances[config_hash] = instance
+        else:
+            logger.info(f"Reusing existing Azure LLM instance for config hash: {config_hash}")
+
+        return cls._instances[config_hash]
 
     def __init__(self, config: AzureLLMConfig):
+        # Avoid duplicate initialization
+        if hasattr(self, "_initialized"):
+            return
+
         self.config = config
         self.client = openai.AzureOpenAI(
             azure_endpoint=config.base_url,
             api_version=config.api_version,
             api_key=config.api_key,
         )
+        self._initialized = True
+        logger.info("Azure LLM instance initialized")
+
+    @classmethod
+    def _get_config_hash(cls, config: AzureLLMConfig) -> str:
+        """Generate hash value of configuration"""
+        # Convert config to dict and sort to ensure consistency
+        config_dict = config.model_dump()
+        config_str = json.dumps(config_dict, sort_keys=True)
+        return hashlib.md5(config_str.encode()).hexdigest()
+
+    @classmethod
+    def clear_cache(cls):
+        """Clear all cached instances"""
+        cls._instances.clear()
+        logger.info("Azure LLM instance cache cleared")
 
     def generate(self, messages: MessageList) -> str:
         """Generate a response from Azure OpenAI LLM."""
memos/log.py CHANGED
@@ -2,7 +2,9 @@ import atexit
 import logging
 import os
 import threading
+import time
 
+from concurrent.futures import ThreadPoolExecutor
 from logging.config import dictConfig
 from pathlib import Path
 from sys import stdout
@@ -12,8 +14,7 @@ import requests
 from dotenv import load_dotenv
 
 from memos import settings
-from memos.api.context.context import get_current_trace_id
-from memos.api.context.context_thread import ContextThreadPoolExecutor
+from memos.context.context import get_current_api_path, get_current_trace_id
 
 
 # Load environment variables
@@ -39,9 +40,9 @@ class TraceIDFilter(logging.Filter):
     def filter(self, record):
         try:
             trace_id = get_current_trace_id()
-            record.trace_id = trace_id if trace_id else "no-trace-id"
+            record.trace_id = trace_id if trace_id else "trace-id"
         except Exception:
-            record.trace_id = "no-trace-id"
+            record.trace_id = "trace-id"
         return True
 
 
@@ -65,7 +66,7 @@ class CustomLoggerRequestHandler(logging.Handler):
         if not self._initialized:
             super().__init__()
             workers = int(os.getenv("CUSTOM_LOGGER_WORKERS", "2"))
-            self._executor = ContextThreadPoolExecutor(
+            self._executor = ThreadPoolExecutor(
                 max_workers=workers, thread_name_prefix="log_sender"
             )
             self._is_shutting_down = threading.Event()
@@ -78,21 +79,32 @@ class CustomLoggerRequestHandler(logging.Handler):
         if os.getenv("CUSTOM_LOGGER_URL") is None or self._is_shutting_down.is_set():
             return
 
+        # Only process INFO and ERROR level logs
+        if record.levelno < logging.INFO:  # Skip DEBUG and lower
+            return
+
         try:
-            trace_id = get_current_trace_id() or "no-trace-id"
-            self._executor.submit(self._send_log_sync, record.getMessage(), trace_id)
+            trace_id = get_current_trace_id() or "trace-id"
+            api_path = get_current_api_path()
+            if api_path is not None:
+                self._executor.submit(self._send_log_sync, record.getMessage(), trace_id, api_path)
         except Exception as e:
             if not self._is_shutting_down.is_set():
                 print(f"Error sending log: {e}")
 
-    def _send_log_sync(self, message, trace_id):
+    def _send_log_sync(self, message, trace_id, api_path):
         """Send log message synchronously in a separate thread"""
         try:
             logger_url = os.getenv("CUSTOM_LOGGER_URL")
             token = os.getenv("CUSTOM_LOGGER_TOKEN")
 
             headers = {"Content-Type": "application/json"}
-            post_content = {"message": message, "trace_id": trace_id}
+            post_content = {
+                "message": message,
+                "trace_id": trace_id,
+                "action": api_path,
+                "current_time": round(time.time(), 3),
+            }
 
             # Add auth token if exists
             if token:
@@ -139,7 +151,7 @@ LOGGING_CONFIG = {
             "format": "[%(trace_id)s] - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s - %(message)s"
         },
         "simplified": {
-            "format": "%(asctime)s | %(trace_id)s | %(levelname)s | %(filename)s | %(message)s"
+            "format": "%(asctime)s | %(trace_id)s | %(levelname)s | %(filename)s:%(lineno)d: %(funcName)s | %(message)s"
         },
     },
     "filters": {
@@ -151,7 +163,7 @@ LOGGING_CONFIG = {
             "level": selected_log_level,
             "class": "logging.StreamHandler",
             "stream": stdout,
-            "formatter": "simplified",
+            "formatter": "no_datetime",
            "filters": ["package_tree_filter", "trace_id_filter"],
         },
         "file": {
@@ -160,18 +172,18 @@ LOGGING_CONFIG = {
             "filename": _setup_logfile(),
             "maxBytes": 1024**2 * 10,
             "backupCount": 10,
-            "formatter": "simplified",
+            "formatter": "standard",
             "filters": ["trace_id_filter"],
         },
         "custom_logger": {
-            "level": selected_log_level,
+            "level": "INFO",
             "class": "memos.log.CustomLoggerRequestHandler",
             "formatter": "simplified",
         },
     },
     "root": {  # Root logger handles all logs
-        "level": selected_log_level,
-        "handlers": ["console", "file", "custom_logger"],
+        "level": logging.DEBUG if settings.DEBUG else logging.INFO,
+        "handlers": ["console", "file"],
     },
     "loggers": {
         "memos": {
memos/mem_cube/general.py CHANGED
@@ -1,4 +1,5 @@
 import os
+import time
 
 from typing import Literal
 
@@ -23,11 +24,13 @@ class GeneralMemCube(BaseMemCube):
     def __init__(self, config: GeneralMemCubeConfig):
         """Initialize the MemCube with a configuration."""
         self.config = config
+        time_start = time.time()
         self._text_mem: BaseTextMemory | None = (
             MemoryFactory.from_config(config.text_mem)
             if config.text_mem.backend != "uninitialized"
             else None
         )
+        logger.info(f"init_text_mem in {time.time() - time_start} seconds")
         self._act_mem: BaseActMemory | None = (
             MemoryFactory.from_config(config.act_mem)
             if config.act_mem.backend != "uninitialized"
@@ -137,7 +140,6 @@ class GeneralMemCube(BaseMemCube):
         if default_config is not None:
             config = merge_config_with_default(config, default_config)
             logger.info(f"Applied default config to cube {config.cube_id}")
-
         mem_cube = GeneralMemCube(config)
         mem_cube.load(dir, memory_types)
         return mem_cube
memos/mem_os/core.py CHANGED
@@ -24,7 +24,7 @@ from memos.mem_user.user_manager import UserManager, UserRole
 from memos.memories.activation.item import ActivationMemoryItem
 from memos.memories.parametric.item import ParametricMemoryItem
 from memos.memories.textual.item import TextualMemoryItem, TextualMemoryMetadata
-from memos.memos_tools.thread_safe_dict import ThreadSafeDict
+from memos.memos_tools.thread_safe_dict_segment import OptimizedThreadSafeDict
 from memos.templates.mos_prompts import QUERY_REWRITING_PROMPT
 from memos.types import ChatHistory, MessageList, MOSSearchResult
 
@@ -47,8 +47,8 @@ class MOSCore:
         self.mem_reader = MemReaderFactory.from_config(config.mem_reader)
         self.chat_history_manager: dict[str, ChatHistory] = {}
         # use thread safe dict for multi-user product-server scenario
-        self.mem_cubes: ThreadSafeDict[str, GeneralMemCube] = (
-            ThreadSafeDict() if user_manager is not None else {}
+        self.mem_cubes: OptimizedThreadSafeDict[str, GeneralMemCube] = (
+            OptimizedThreadSafeDict() if user_manager is not None else {}
         )
         self._register_chat_history()
 
@@ -125,12 +125,16 @@ class MOSCore:
                     "missing required 'llm' attribute"
                 )
             self._mem_scheduler.initialize_modules(
-                chat_llm=self.chat_llm, process_llm=self.chat_llm
+                chat_llm=self.chat_llm,
+                process_llm=self.chat_llm,
+                db_engine=self.user_manager.engine,
             )
         else:
             # Configure scheduler general_modules
             self._mem_scheduler.initialize_modules(
-                chat_llm=self.chat_llm, process_llm=self.mem_reader.llm
+                chat_llm=self.chat_llm,
+                process_llm=self.mem_reader.llm,
+                db_engine=self.user_manager.engine,
             )
         self._mem_scheduler.start()
         return self._mem_scheduler
@@ -182,13 +186,13 @@ class MOSCore:
             logger.info(f"close reorganizer for {mem_cube.text_mem.config.cube_id}")
             mem_cube.text_mem.memory_manager.wait_reorganizer()
 
-    def _register_chat_history(self, user_id: str | None = None) -> None:
+    def _register_chat_history(
+        self, user_id: str | None = None, session_id: str | None = None
+    ) -> None:
         """Initialize chat history with user ID."""
-        if user_id is None:
-            user_id = self.user_id
         self.chat_history_manager[user_id] = ChatHistory(
-            user_id=user_id,
-            session_id=self.session_id,
+            user_id=user_id if user_id is not None else self.user_id,
+            session_id=session_id if session_id is not None else self.session_id,
             created_at=datetime.utcnow(),
             total_messages=0,
             chat_history=[],
@@ -483,14 +487,14 @@ class MOSCore:
                 self.mem_cubes[mem_cube_id] = mem_cube_name_or_path
                 logger.info(f"register new cube {mem_cube_id} for user {target_user_id}")
             elif os.path.exists(mem_cube_name_or_path):
-                self.mem_cubes[mem_cube_id] = GeneralMemCube.init_from_dir(mem_cube_name_or_path)
+                mem_cube_obj = GeneralMemCube.init_from_dir(mem_cube_name_or_path)
+                self.mem_cubes[mem_cube_id] = mem_cube_obj
             else:
                 logger.warning(
                     f"MemCube {mem_cube_name_or_path} does not exist, try to init from remote repo."
                 )
-                self.mem_cubes[mem_cube_id] = GeneralMemCube.init_from_remote_repo(
-                    mem_cube_name_or_path
-                )
+                mem_cube_obj = GeneralMemCube.init_from_remote_repo(mem_cube_name_or_path)
+                self.mem_cubes[mem_cube_id] = mem_cube_obj
         # Check if cube already exists in database
         existing_cube = self.user_manager.get_cube(mem_cube_id)
 
@@ -547,6 +551,7 @@ class MOSCore:
         mode: Literal["fast", "fine"] = "fast",
         internet_search: bool = False,
         moscube: bool = False,
+        session_id: str | None = None,
         **kwargs,
     ) -> MOSSearchResult:
         """
@@ -562,7 +567,9 @@ class MOSCore:
         Returns:
             MemoryResult: A dictionary containing the search results.
         """
+        target_session_id = session_id if session_id is not None else self.session_id
         target_user_id = user_id if user_id is not None else self.user_id
+
         self._validate_user_exists(target_user_id)
         # Get all cubes accessible by the target user
         accessible_cubes = self.user_manager.get_user_cubes(target_user_id)
@@ -575,6 +582,11 @@ class MOSCore:
             self._register_chat_history(target_user_id)
         chat_history = self.chat_history_manager[target_user_id]
 
+        # Create search filter if session_id is provided
+        search_filter = None
+        if session_id is not None:
+            search_filter = {"session_id": session_id}
+
         result: MOSSearchResult = {
             "text_mem": [],
             "act_mem": [],
@@ -584,9 +596,13 @@ class MOSCore:
         install_cube_ids = user_cube_ids
         # create exist dict in mem_cubes and avoid one search slow
         tmp_mem_cubes = {}
+        time_start_cube_get = time.time()
         for mem_cube_id in install_cube_ids:
            if mem_cube_id in self.mem_cubes:
                tmp_mem_cubes[mem_cube_id] = self.mem_cubes.get(mem_cube_id)
+        logger.info(
+            f"time search: transform cube time user_id: {target_user_id} time is: {time.time() - time_start_cube_get}"
+        )
 
         for mem_cube_id, mem_cube in tmp_mem_cubes.items():
             if (
@@ -602,10 +618,11 @@ class MOSCore:
                     manual_close_internet=not internet_search,
                     info={
                         "user_id": target_user_id,
-                        "session_id": self.session_id,
+                        "session_id": target_session_id,
                         "chat_history": chat_history.chat_history,
                     },
                     moscube=moscube,
+                    search_filter=search_filter,
                 )
                 result["text_mem"].append({"cube_id": mem_cube_id, "memories": memories})
                 logger.info(
@@ -624,6 +641,8 @@ class MOSCore:
         doc_path: str | None = None,
         mem_cube_id: str | None = None,
         user_id: str | None = None,
+        session_id: str | None = None,
+        **kwargs,
     ) -> None:
         """
         Add textual memories to a MemCube.
@@ -636,11 +655,16 @@ class MOSCore:
                 If None, the default MemCube for the user is used.
             user_id (str, optional): The identifier of the user to add the memories to.
                 If None, the default user is used.
+            session_id (str, optional): session_id
         """
         # user input messages
         assert (messages is not None) or (memory_content is not None) or (doc_path is not None), (
             "messages_or_doc_path or memory_content or doc_path must be provided."
         )
+        # TODO: asure that session_id is a valid string
+        time_start = time.time()
+
+        target_session_id = session_id if session_id else self.session_id
         target_user_id = user_id if user_id is not None else self.user_id
         if mem_cube_id is None:
             # Try to find a default cube for the user
@@ -652,18 +676,29 @@ class MOSCore:
                 mem_cube_id = accessible_cubes[0].cube_id  # TODO not only first
         else:
             self._validate_cube_access(target_user_id, mem_cube_id)
+        logger.info(
+            f"time add: get mem_cube_id time user_id: {target_user_id} time is: {time.time() - time_start}"
+        )
 
+        time_start_0 = time.time()
         if mem_cube_id not in self.mem_cubes:
             raise ValueError(f"MemCube '{mem_cube_id}' is not loaded. Please register.")
+        logger.info(
+            f"time add: get mem_cube_id check in mem_cubes time user_id: {target_user_id} time is: {time.time() - time_start_0}"
+        )
+        time_start_1 = time.time()
         if (
             (messages is not None)
             and self.config.enable_textual_memory
             and self.mem_cubes[mem_cube_id].text_mem
         ):
+            logger.info(
+                f"time add: messages is not None and enable_textual_memory and text_mem is not None time user_id: {target_user_id} time is: {time.time() - time_start_1}"
+            )
             if self.mem_cubes[mem_cube_id].config.text_mem.backend != "tree_text":
                 add_memory = []
                 metadata = TextualMemoryMetadata(
-                    user_id=target_user_id, session_id=self.session_id, source="conversation"
+                    user_id=target_user_id, session_id=target_session_id, source="conversation"
                 )
                 for message in messages:
                     add_memory.append(
@@ -672,12 +707,15 @@ class MOSCore:
                 self.mem_cubes[mem_cube_id].text_mem.add(add_memory)
             else:
                 messages_list = [messages]
+                time_start_2 = time.time()
                 memories = self.mem_reader.get_memory(
                     messages_list,
                     type="chat",
-                    info={"user_id": target_user_id, "session_id": self.session_id},
+                    info={"user_id": target_user_id, "session_id": target_session_id},
+                )
+                logger.info(
+                    f"time add: get mem_reader time user_id: {target_user_id} time is: {time.time() - time_start_2}"
                 )
-
                 mem_ids = []
                 for mem in memories:
                     mem_id_list: list[str] = self.mem_cubes[mem_cube_id].text_mem.add(mem)
@@ -707,7 +745,7 @@ class MOSCore:
         ):
             if self.mem_cubes[mem_cube_id].config.text_mem.backend != "tree_text":
                 metadata = TextualMemoryMetadata(
-                    user_id=self.user_id, session_id=self.session_id, source="conversation"
+                    user_id=target_user_id, session_id=target_session_id, source="conversation"
                 )
                 self.mem_cubes[mem_cube_id].text_mem.add(
                     [TextualMemoryItem(memory=memory_content, metadata=metadata)]
@@ -719,7 +757,7 @@ class MOSCore:
                 memories = self.mem_reader.get_memory(
                     messages_list,
                     type="chat",
-                    info={"user_id": target_user_id, "session_id": self.session_id},
+                    info={"user_id": target_user_id, "session_id": target_session_id},
                 )
 
                 mem_ids = []
@@ -753,7 +791,7 @@ class MOSCore:
                 doc_memories = self.mem_reader.get_memory(
                     documents,
                     type="doc",
-                    info={"user_id": target_user_id, "session_id": self.session_id},
+                    info={"user_id": target_user_id, "session_id": target_session_id},
                 )
 
                 mem_ids = []
@@ -986,7 +1024,7 @@ class MOSCore:
 
     def get_user_info(self) -> dict[str, Any]:
         """Get current user information including accessible cubes.
-
+        TODO: maybe input user_id
         Returns:
             dict: User information and accessible cubes.
         """
memos/mem_os/main.py CHANGED
@@ -5,6 +5,7 @@ import os
 from typing import Any
 
 from memos.configs.mem_os import MOSConfig
+from memos.context.context import ContextThreadPoolExecutor
 from memos.llms.factory import LLMFactory
 from memos.log import get_logger
 from memos.mem_os.core import MOSCore
@@ -487,9 +488,7 @@ class MOS(MOSCore):
 
         # Generate answers in parallel while maintaining order
         sub_answers = [None] * len(sub_questions)
-        with concurrent.futures.ThreadPoolExecutor(
-            max_workers=min(len(sub_questions), 10)
-        ) as executor:
+        with ContextThreadPoolExecutor(max_workers=min(len(sub_questions), 10)) as executor:
             # Submit all answer generation tasks
             future_to_index = {
                 executor.submit(generate_answer_for_question, i, question): i
@@ -552,9 +551,7 @@ class MOS(MOSCore):
 
         # Search in parallel while maintaining order
         all_memories = []
-        with concurrent.futures.ThreadPoolExecutor(
-            max_workers=min(len(sub_questions), 10)
-        ) as executor:
+        with ContextThreadPoolExecutor(max_workers=min(len(sub_questions), 10)) as executor:
             # Submit all search tasks and keep track of their order
             future_to_index = {
                 executor.submit(search_single_question, question): i
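Both parallel sections now use ContextThreadPoolExecutor from the relocated memos.context.context module (file 15 in the list above) instead of a bare ThreadPoolExecutor, presumably so that the request's trace id and API path survive the hop into worker threads. The module itself is not part of this diff; a common contextvars-based way to build such an executor looks like this (illustrative only):

import contextvars
from concurrent.futures import ThreadPoolExecutor


class ContextThreadPoolExecutor(ThreadPoolExecutor):
    """Run each submitted task inside a copy of the submitter's context."""

    def submit(self, fn, /, *args, **kwargs):
        ctx = contextvars.copy_context()  # snapshot trace id, API path, etc. at submit time
        return super().submit(ctx.run, fn, *args, **kwargs)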