MemoryOS 1.0.0-py3-none-any.whl → 1.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of MemoryOS might be problematic.
- {memoryos-1.0.0.dist-info → memoryos-1.1.1.dist-info}/METADATA +8 -2
- {memoryos-1.0.0.dist-info → memoryos-1.1.1.dist-info}/RECORD +92 -69
- {memoryos-1.0.0.dist-info → memoryos-1.1.1.dist-info}/WHEEL +1 -1
- memos/__init__.py +1 -1
- memos/api/client.py +109 -0
- memos/api/config.py +35 -8
- memos/api/context/dependencies.py +15 -66
- memos/api/middleware/request_context.py +63 -0
- memos/api/product_api.py +5 -2
- memos/api/product_models.py +107 -16
- memos/api/routers/product_router.py +62 -19
- memos/api/start_api.py +13 -0
- memos/configs/graph_db.py +4 -0
- memos/configs/mem_scheduler.py +38 -3
- memos/configs/memory.py +13 -0
- memos/configs/reranker.py +18 -0
- memos/context/context.py +255 -0
- memos/embedders/factory.py +2 -0
- memos/graph_dbs/base.py +4 -2
- memos/graph_dbs/nebular.py +368 -223
- memos/graph_dbs/neo4j.py +49 -13
- memos/graph_dbs/neo4j_community.py +13 -3
- memos/llms/factory.py +2 -0
- memos/llms/openai.py +74 -2
- memos/llms/vllm.py +2 -0
- memos/log.py +128 -4
- memos/mem_cube/general.py +3 -1
- memos/mem_os/core.py +89 -23
- memos/mem_os/main.py +3 -6
- memos/mem_os/product.py +418 -154
- memos/mem_os/utils/reference_utils.py +20 -0
- memos/mem_reader/factory.py +2 -0
- memos/mem_reader/simple_struct.py +204 -82
- memos/mem_scheduler/analyzer/__init__.py +0 -0
- memos/mem_scheduler/analyzer/mos_for_test_scheduler.py +569 -0
- memos/mem_scheduler/analyzer/scheduler_for_eval.py +280 -0
- memos/mem_scheduler/base_scheduler.py +126 -56
- memos/mem_scheduler/general_modules/dispatcher.py +2 -2
- memos/mem_scheduler/general_modules/misc.py +99 -1
- memos/mem_scheduler/general_modules/scheduler_logger.py +17 -11
- memos/mem_scheduler/general_scheduler.py +40 -88
- memos/mem_scheduler/memory_manage_modules/__init__.py +5 -0
- memos/mem_scheduler/memory_manage_modules/memory_filter.py +308 -0
- memos/mem_scheduler/{general_modules → memory_manage_modules}/retriever.py +34 -7
- memos/mem_scheduler/monitors/dispatcher_monitor.py +9 -8
- memos/mem_scheduler/monitors/general_monitor.py +119 -39
- memos/mem_scheduler/optimized_scheduler.py +124 -0
- memos/mem_scheduler/orm_modules/__init__.py +0 -0
- memos/mem_scheduler/orm_modules/base_model.py +635 -0
- memos/mem_scheduler/orm_modules/monitor_models.py +261 -0
- memos/mem_scheduler/scheduler_factory.py +2 -0
- memos/mem_scheduler/schemas/monitor_schemas.py +96 -29
- memos/mem_scheduler/utils/config_utils.py +100 -0
- memos/mem_scheduler/utils/db_utils.py +33 -0
- memos/mem_scheduler/utils/filter_utils.py +1 -1
- memos/mem_scheduler/webservice_modules/__init__.py +0 -0
- memos/mem_user/mysql_user_manager.py +4 -2
- memos/memories/activation/kv.py +2 -1
- memos/memories/textual/item.py +96 -17
- memos/memories/textual/naive.py +1 -1
- memos/memories/textual/tree.py +57 -3
- memos/memories/textual/tree_text_memory/organize/handler.py +4 -2
- memos/memories/textual/tree_text_memory/organize/manager.py +28 -14
- memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +1 -2
- memos/memories/textual/tree_text_memory/organize/reorganizer.py +75 -23
- memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +10 -6
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py +6 -2
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +2 -0
- memos/memories/textual/tree_text_memory/retrieve/recall.py +119 -21
- memos/memories/textual/tree_text_memory/retrieve/searcher.py +172 -44
- memos/memories/textual/tree_text_memory/retrieve/utils.py +6 -4
- memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py +5 -4
- memos/memos_tools/notification_utils.py +46 -0
- memos/memos_tools/singleton.py +174 -0
- memos/memos_tools/thread_safe_dict.py +22 -0
- memos/memos_tools/thread_safe_dict_segment.py +382 -0
- memos/parsers/factory.py +2 -0
- memos/reranker/__init__.py +4 -0
- memos/reranker/base.py +24 -0
- memos/reranker/concat.py +59 -0
- memos/reranker/cosine_local.py +96 -0
- memos/reranker/factory.py +48 -0
- memos/reranker/http_bge.py +312 -0
- memos/reranker/noop.py +16 -0
- memos/templates/mem_reader_prompts.py +289 -40
- memos/templates/mem_scheduler_prompts.py +242 -0
- memos/templates/mos_prompts.py +133 -60
- memos/types.py +4 -1
- memos/api/context/context.py +0 -147
- memos/mem_scheduler/mos_for_test_scheduler.py +0 -146
- {memoryos-1.0.0.dist-info → memoryos-1.1.1.dist-info}/entry_points.txt +0 -0
- {memoryos-1.0.0.dist-info → memoryos-1.1.1.dist-info/licenses}/LICENSE +0 -0
- /memos/mem_scheduler/{general_modules → webservice_modules}/rabbitmq_service.py +0 -0
- /memos/mem_scheduler/{general_modules → webservice_modules}/redis_service.py +0 -0
memos/api/middleware/request_context.py
ADDED
@@ -0,0 +1,63 @@
+"""
+Request context middleware for automatic trace_id injection.
+"""
+
+from collections.abc import Callable
+
+from starlette.middleware.base import BaseHTTPMiddleware
+from starlette.requests import Request
+from starlette.responses import Response
+
+import memos.log
+
+from memos.context.context import RequestContext, generate_trace_id, set_request_context
+
+
+logger = memos.log.get_logger(__name__)
+
+
+def extract_trace_id_from_headers(request: Request) -> str | None:
+    """Extract trace_id from various possible headers with priority: g-trace-id > x-trace-id > trace-id."""
+    for header in ["g-trace-id", "x-trace-id", "trace-id"]:
+        if trace_id := request.headers.get(header):
+            return trace_id
+    return None
+
+
+class RequestContextMiddleware(BaseHTTPMiddleware):
+    """
+    Middleware to automatically inject request context for every HTTP request.
+
+    This middleware:
+    1. Extracts trace_id from headers or generates a new one
+    2. Creates a RequestContext and sets it globally
+    3. Ensures the context is available throughout the request lifecycle
+    """
+
+    async def dispatch(self, request: Request, call_next: Callable) -> Response:
+        # Extract or generate trace_id
+        trace_id = extract_trace_id_from_headers(request) or generate_trace_id()
+
+        # Create and set request context
+        context = RequestContext(trace_id=trace_id, api_path=request.url.path)
+        set_request_context(context)
+
+        # Log request start with parameters
+        params_log = {}
+
+        # Get query parameters
+        if request.query_params:
+            params_log["query_params"] = dict(request.query_params)
+
+        logger.info(f"Request started: {request.method} {request.url.path}, {params_log}")
+
+        # Process the request
+        response = await call_next(request)
+
+        # Log request completion with output
+        logger.info(f"Request completed: {request.url.path}, status: {response.status_code}")
+
+        # Add trace_id to response headers for debugging
+        response.headers["x-trace-id"] = trace_id
+
+        return response
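The new middleware honors an incoming trace id (g-trace-id, then x-trace-id, then trace-id) and always echoes one back in the x-trace-id response header. A minimal client-side sketch of that round trip, assuming a MemOS API server is already running on localhost:8001 (the URL, port, and path are illustrative, not part of this diff):

# Illustrative only: exercises the trace-id behavior added by the middleware above.
# Requires the third-party `requests` package and a locally running server.
import uuid

import requests

trace_id = str(uuid.uuid4())
resp = requests.get(
    "http://localhost:8001/docs",      # any routed path works; /docs is FastAPI's default
    headers={"g-trace-id": trace_id},  # highest-priority header in extract_trace_id_from_headers
    timeout=10,
)
# The middleware copies the incoming (or freshly generated) trace id into the response.
print(resp.headers.get("x-trace-id"))  # expected to match trace_id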
memos/api/product_api.py
CHANGED
@@ -3,6 +3,7 @@ import logging
 from fastapi import FastAPI
 
 from memos.api.exceptions import APIExceptionHandler
+from memos.api.middleware.request_context import RequestContextMiddleware
 from memos.api.routers.product_router import router as product_router
 
 
@@ -13,9 +14,10 @@ logger = logging.getLogger(__name__)
 app = FastAPI(
     title="MemOS Product REST APIs",
     description="A REST API for managing multiple users with MemOS Product.",
-    version="1.0.0",
+    version="1.0.1",
 )
 
+app.add_middleware(RequestContextMiddleware)
 # Include routers
 app.include_router(product_router)
 
@@ -31,5 +33,6 @@ if __name__ == "__main__":
 
     parser = argparse.ArgumentParser()
     parser.add_argument("--port", type=int, default=8001)
+    parser.add_argument("--workers", type=int, default=32)
     args = parser.parse_args()
-    uvicorn.run(app, host="0.0.0.0", port=args.port)
+    uvicorn.run("memos.api.product_api:app", host="0.0.0.0", port=args.port, workers=args.workers)
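The __main__ change above swaps the app object for an import string in uvicorn.run, which is what makes workers greater than 1 possible, since each worker process re-imports the application. A short sketch of the equivalent programmatic launch (port and worker count are illustrative):

# Sketch: multi-worker serving needs the "module:attribute" import-string form.
import uvicorn

if __name__ == "__main__":
    uvicorn.run("memos.api.product_api:app", host="0.0.0.0", port=8001, workers=4)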
memos/api/product_models.py
CHANGED
@@ -1,26 +1,14 @@
 import uuid
 
-from typing import Generic, Literal, TypeAlias, TypeVar
+from typing import Generic, Literal, TypeVar
 
 from pydantic import BaseModel, Field
-from typing_extensions import TypedDict
 
+# Import message types from core types module
+from memos.types import MessageDict
 
-T = TypeVar("T")
-
-
-# ─── Message Types ──────────────────────────────────────────────────────────────
-
-# Chat message roles
-MessageRole: TypeAlias = Literal["user", "assistant", "system"]
 
-
-# Message structure
-class MessageDict(TypedDict):
-    """Typed dictionary for chat message dictionaries."""
-
-    role: MessageRole
-    content: str
+T = TypeVar("T")
 
 
 class BaseRequest(BaseModel):
@@ -42,6 +30,7 @@ class UserRegisterRequest(BaseRequest):
     user_id: str = Field(
         default_factory=lambda: str(uuid.uuid4()), description="User ID for registration"
     )
+    mem_cube_id: str | None = Field(None, description="Cube ID for registration")
     user_name: str | None = Field(None, description="User name for registration")
     interests: str | None = Field(None, description="User interests")
 
@@ -84,6 +73,23 @@ class ChatRequest(BaseRequest):
     mem_cube_id: str | None = Field(None, description="Cube ID to use for chat")
     history: list[MessageDict] | None = Field(None, description="Chat history")
     internet_search: bool = Field(True, description="Whether to use internet search")
+    moscube: bool = Field(False, description="Whether to use MemOSCube")
+    session_id: str | None = Field(None, description="Session ID for soft-filtering memories")
+
+
+class ChatCompleteRequest(BaseRequest):
+    """Request model for chat operations."""
+
+    user_id: str = Field(..., description="User ID")
+    query: str = Field(..., description="Chat query message")
+    mem_cube_id: str | None = Field(None, description="Cube ID to use for chat")
+    history: list[MessageDict] | None = Field(None, description="Chat history")
+    internet_search: bool = Field(False, description="Whether to use internet search")
+    moscube: bool = Field(False, description="Whether to use MemOSCube")
+    base_prompt: str | None = Field(None, description="Base prompt to use for chat")
+    top_k: int = Field(10, description="Number of results to return")
+    threshold: float = Field(0.5, description="Threshold for filtering references")
+    session_id: str | None = Field(None, description="Session ID for soft-filtering memories")
 
 
 class UserCreate(BaseRequest):
@@ -145,6 +151,7 @@ class MemoryCreateRequest(BaseRequest):
     mem_cube_id: str | None = Field(None, description="Cube ID")
     source: str | None = Field(None, description="Source of the memory")
     user_profile: bool = Field(False, description="User profile memory")
+    session_id: str | None = Field(None, description="Session id")
 
 
 class SearchRequest(BaseRequest):
@@ -154,6 +161,7 @@ class SearchRequest(BaseRequest):
     query: str = Field(..., description="Search query")
     mem_cube_id: str | None = Field(None, description="Cube ID to search in")
     top_k: int = Field(10, description="Number of results to return")
+    session_id: str | None = Field(None, description="Session ID for soft-filtering memories")
 
 
 class SuggestionRequest(BaseRequest):
@@ -161,3 +169,86 @@ class SuggestionRequest(BaseRequest):
 
     user_id: str = Field(..., description="User ID")
     language: Literal["zh", "en"] = Field("zh", description="Language for suggestions")
+    message: list[MessageDict] | None = Field(None, description="List of messages to store.")
+
+
+# ─── MemOS Client Response Models ──────────────────────────────────────────────
+
+
+class MessageDetail(BaseModel):
+    """Individual message detail model based on actual API response."""
+
+    model_config = {"extra": "allow"}
+
+
+class MemoryDetail(BaseModel):
+    """Individual memory detail model based on actual API response."""
+
+    model_config = {"extra": "allow"}
+
+
+class GetMessagesData(BaseModel):
+    """Data model for get messages response based on actual API."""
+
+    message_detail_list: list[MessageDetail] = Field(
+        default_factory=list, alias="memory_detail_list", description="List of message details"
+    )
+
+
+class SearchMemoryData(BaseModel):
+    """Data model for search memory response based on actual API."""
+
+    memory_detail_list: list[MemoryDetail] = Field(
+        default_factory=list, alias="memory_detail_list", description="List of memory details"
+    )
+    message_detail_list: list[MessageDetail] | None = Field(
+        None, alias="message_detail_list", description="List of message details (usually None)"
+    )
+
+
+class AddMessageData(BaseModel):
+    """Data model for add message response based on actual API."""
+
+    success: bool = Field(..., description="Operation success status")
+
+
+# ─── MemOS Response Models (Similar to OpenAI ChatCompletion) ──────────────────
+
+
+class MemOSGetMessagesResponse(BaseModel):
+    """Response model for get messages operation based on actual API."""
+
+    code: int = Field(..., description="Response status code")
+    message: str = Field(..., description="Response message")
+    data: GetMessagesData = Field(..., description="Messages data")
+
+    @property
+    def messages(self) -> list[MessageDetail]:
+        """Convenient access to message list."""
+        return self.data.message_detail_list
+
+
+class MemOSSearchResponse(BaseModel):
+    """Response model for search memory operation based on actual API."""
+
+    code: int = Field(..., description="Response status code")
+    message: str = Field(..., description="Response message")
+    data: SearchMemoryData = Field(..., description="Search results data")
+
+    @property
+    def memories(self) -> list[MemoryDetail]:
+        """Convenient access to memory list."""
+        return self.data.memory_detail_list
+
+
+class MemOSAddResponse(BaseModel):
+    """Response model for add message operation based on actual API."""
+
+    code: int = Field(..., description="Response status code")
+    message: str = Field(..., description="Response message")
+    data: AddMessageData = Field(..., description="Add operation data")
+
+    @property
+    def success(self) -> bool:
+        """Convenient access to success status."""
+        return self.data.success
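Because the new client response models declare extra = "allow" and field aliases, raw API payloads can be validated directly. A hedged sketch of how MemOSSearchResponse could be used; the payload shape is inferred from the field definitions above rather than from a captured response:

# Sketch: validate a search payload and use the convenience property.
# Assumes the MemoryOS wheel is installed; the payload values are invented.
from memos.api.product_models import MemOSSearchResponse

payload = {
    "code": 200,
    "message": "ok",
    "data": {
        "memory_detail_list": [{"memory": "user prefers green tea", "score": 0.87}],
        "message_detail_list": None,
    },
}

resp = MemOSSearchResponse.model_validate(payload)
print(resp.memories[0])  # MemoryDetail allows extra fields, so "memory" and "score" are kept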
memos/api/routers/product_router.py
CHANGED
@@ -1,17 +1,14 @@
 import json
-import
+import time
 import traceback
 
-from
-from typing import Annotated
-
-from fastapi import APIRouter, Depends, HTTPException
+from fastapi import APIRouter, HTTPException
 from fastapi.responses import StreamingResponse
 
 from memos.api.config import APIConfig
-from memos.api.context.dependencies import G, get_g_object
 from memos.api.product_models import (
     BaseResponse,
+    ChatCompleteRequest,
     ChatRequest,
     GetMemoryRequest,
     MemoryCreateRequest,
@@ -25,11 +22,12 @@ from memos.api.product_models import (
     UserRegisterResponse,
 )
 from memos.configs.mem_os import MOSConfig
+from memos.log import get_logger
 from memos.mem_os.product import MOSProduct
 from memos.memos_tools.notification_service import get_error_bot_function, get_online_bot_function
 
 
-logger =
+logger = get_logger(__name__)
 
 router = APIRouter(prefix="/product", tags=["Product API"])
 
@@ -78,24 +76,19 @@ def set_config(config):
 
 
 @router.post("/users/register", summary="Register a new user", response_model=UserRegisterResponse)
-def register_user(user_req: UserRegisterRequest, g: Annotated[G, Depends(get_g_object)]):
+def register_user(user_req: UserRegisterRequest):
     """Register a new user with configuration and default cube."""
     try:
-        # Set request-related information in g object
-        g.user_id = user_req.user_id
-        g.action = "user_register"
-        g.timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-
-        logger.info(f"Starting user registration for user_id: {user_req.user_id}")
-        logger.info(f"Request trace_id: {g.trace_id}")
-        logger.info(f"Request timestamp: {g.timestamp}")
-
         # Get configuration for the user
+        time_start_register = time.time()
         user_config, default_mem_cube = APIConfig.create_user_config(
             user_name=user_req.user_id, user_id=user_req.user_id
         )
         logger.info(f"user_config: {user_config.model_dump(mode='json')}")
         logger.info(f"default_mem_cube: {default_mem_cube.config.model_dump(mode='json')}")
+        logger.info(
+            f"time register api : create user config time user_id: {user_req.user_id} time is: {time.time() - time_start_register}"
+        )
         mos_product = get_mos_product_instance()
 
         # Register user with default config and mem cube
@@ -105,8 +98,11 @@ def register_user(user_req: UserRegisterRequest, g: Annotated[G, Depends(get_g_object)]
             interests=user_req.interests,
             config=user_config,
             default_mem_cube=default_mem_cube,
+            mem_cube_id=user_req.mem_cube_id,
+        )
+        logger.info(
+            f"time register api : register time user_id: {user_req.user_id} time is: {time.time() - time_start_register}"
         )
-
         if result["status"] == "success":
             return UserRegisterResponse(
                 message="User registered successfully",
@@ -148,7 +144,9 @@ def get_suggestion_queries_post(suggestion_req: SuggestionRequest):
     try:
        mos_product = get_mos_product_instance()
        suggestions = mos_product.get_suggestion_query(
-            user_id=suggestion_req.user_id,
+            user_id=suggestion_req.user_id,
+            language=suggestion_req.language,
+            message=suggestion_req.message,
        )
        return SuggestionResponse(
            message="Suggestions retrieved successfully", data={"query": suggestions}
@@ -191,6 +189,7 @@ def get_all_memories(memory_req: GetMemoryRequest):
 def create_memory(memory_req: MemoryCreateRequest):
     """Create a new memory for a specific user."""
     try:
+        time_start_add = time.time()
         mos_product = get_mos_product_instance()
         mos_product.add(
             user_id=memory_req.user_id,
@@ -200,6 +199,10 @@ def create_memory(memory_req: MemoryCreateRequest):
             mem_cube_id=memory_req.mem_cube_id,
             source=memory_req.source,
             user_profile=memory_req.user_profile,
+            session_id=memory_req.session_id,
+        )
+        logger.info(
+            f"time add api : add time user_id: {memory_req.user_id} time is: {time.time() - time_start_add}"
         )
         return SimpleResponse(message="Memory created successfully")
 
@@ -214,12 +217,17 @@ def create_memory(memory_req: MemoryCreateRequest):
 def search_memories(search_req: SearchRequest):
     """Search memories for a specific user."""
     try:
+        time_start_search = time.time()
         mos_product = get_mos_product_instance()
         result = mos_product.search(
             query=search_req.query,
             user_id=search_req.user_id,
             install_cube_ids=[search_req.mem_cube_id] if search_req.mem_cube_id else None,
             top_k=search_req.top_k,
+            session_id=search_req.session_id,
+        )
+        logger.info(
+            f"time search api : add time user_id: {search_req.user_id} time is: {time.time() - time_start_search}"
         )
         return SearchResponse(message="Search completed successfully", data=result)
 
@@ -246,6 +254,8 @@ def chat(chat_req: ChatRequest):
             cube_id=chat_req.mem_cube_id,
             history=chat_req.history,
             internet_search=chat_req.internet_search,
+            moscube=chat_req.moscube,
+            session_id=chat_req.session_id,
         )
 
     except Exception as e:
@@ -273,6 +283,39 @@ def chat(chat_req: ChatRequest):
         raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
 
 
+@router.post("/chat/complete", summary="Chat with MemOS (Complete Response)")
+def chat_complete(chat_req: ChatCompleteRequest):
+    """Chat with MemOS for a specific user. Returns complete response (non-streaming)."""
+    try:
+        mos_product = get_mos_product_instance()
+
+        # Collect all responses from the generator
+        content, references = mos_product.chat(
+            query=chat_req.query,
+            user_id=chat_req.user_id,
+            cube_id=chat_req.mem_cube_id,
+            history=chat_req.history,
+            internet_search=chat_req.internet_search,
+            moscube=chat_req.moscube,
+            base_prompt=chat_req.base_prompt,
+            top_k=chat_req.top_k,
+            threshold=chat_req.threshold,
+            session_id=chat_req.session_id,
+        )
+
+        # Return the complete response
+        return {
+            "message": "Chat completed successfully",
+            "data": {"response": content, "references": references},
+        }
+
+    except ValueError as err:
+        raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
+    except Exception as err:
+        logger.error(f"Failed to start chat: {traceback.format_exc()}")
+        raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
+
+
 @router.get("/users", summary="List all users", response_model=BaseResponse[list])
 def list_users():
     """List all registered users."""
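The new non-streaming endpoint sits under the router's /product prefix. A hedged example of calling it over HTTP; the base URL is illustrative, and the only fields ChatCompleteRequest itself marks as required are user_id and query:

# Illustrative request against the new endpoint; assumes a server on localhost:8001.
import requests

resp = requests.post(
    "http://localhost:8001/product/chat/complete",
    json={
        "user_id": "demo-user",
        "query": "What did I say about my travel plans?",
        "top_k": 5,         # optional, defaults to 10
        "threshold": 0.5,   # optional threshold for filtering references
    },
    timeout=60,
)
data = resp.json()["data"]
print(data["response"])    # the complete assistant answer
print(data["references"])  # memories used to ground the answer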
memos/api/start_api.py
CHANGED
@@ -9,6 +9,7 @@ from fastapi.requests import Request
 from fastapi.responses import JSONResponse, RedirectResponse
 from pydantic import BaseModel, Field
 
+from memos.api.middleware.request_context import RequestContextMiddleware
 from memos.configs.mem_os import MOSConfig
 from memos.mem_os.main import MOS
 from memos.mem_user.user_manager import UserManager, UserRole
@@ -78,6 +79,8 @@ app = FastAPI(
     version="1.0.0",
 )
 
+app.add_middleware(RequestContextMiddleware)
+
 
 class BaseRequest(BaseModel):
     """Base model for all requests."""
@@ -418,3 +421,13 @@ async def global_exception_handler(request: Request, exc: Exception):
         status_code=500,
         content={"code": 500, "message": str(exc), "data": None},
     )
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
+    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on")
+    parser.add_argument("--reload", action="store_true", help="Enable auto-reload for development")
+    args = parser.parse_args()
memos/configs/graph_db.py
CHANGED
@@ -140,6 +140,10 @@ class NebulaGraphDBConfig(BaseGraphDBConfig):
             "If False: use a single shared database with logical isolation by user_name."
         ),
     )
+    max_client: int = Field(
+        default=1000,
+        description=("max_client"),
+    )
     embedding_dimension: int = Field(default=3072, description="Dimension of vector embedding")
 
     @model_validator(mode="after")
memos/configs/mem_scheduler.py
CHANGED
@@ -6,7 +6,7 @@ from typing import Any, ClassVar
 from pydantic import ConfigDict, Field, field_validator, model_validator
 
 from memos.configs.base import BaseConfig
-from memos.mem_scheduler.general_modules.misc import DictConversionMixin
+from memos.mem_scheduler.general_modules.misc import DictConversionMixin, EnvConfigMixin
 from memos.mem_scheduler.schemas.general_schemas import (
     BASE_DIR,
     DEFAULT_ACT_MEM_DUMP_PATH,
@@ -64,6 +64,19 @@ class GeneralSchedulerConfig(BaseSchedulerConfig):
         default=20, description="Capacity of the activation memory monitor"
     )
 
+    # Database configuration for ORM persistence
+    db_path: str | None = Field(
+        default=None,
+        description="Path to SQLite database file for ORM persistence. If None, uses default scheduler_orm.db",
+    )
+    db_url: str | None = Field(
+        default=None,
+        description="Database URL for ORM persistence (e.g., mysql://user:pass@host/db). Takes precedence over db_path",
+    )
+    enable_orm_persistence: bool = Field(
+        default=True, description="Whether to enable ORM-based persistence for monitors"
+    )
+
 
 class SchedulerConfigFactory(BaseConfig):
     """Factory class for creating scheduler configurations."""
@@ -74,6 +87,7 @@ class SchedulerConfigFactory(BaseConfig):
     model_config = ConfigDict(extra="forbid", strict=True)
     backend_to_class: ClassVar[dict[str, Any]] = {
         "general_scheduler": GeneralSchedulerConfig,
+        "optimized_scheduler": GeneralSchedulerConfig,  # optimized_scheduler uses same config as general_scheduler
     }
 
     @field_validator("backend")
@@ -94,6 +108,8 @@ class SchedulerConfigFactory(BaseConfig):
 # ************************* Auth *************************
 class RabbitMQConfig(
     BaseConfig,
+    DictConversionMixin,
+    EnvConfigMixin,
 ):
     host_name: str = Field(default="", description="Endpoint for RabbitMQ instance access")
     user_name: str = Field(default="", description="Static username for RabbitMQ instance")
@@ -110,7 +126,7 @@ class RabbitMQConfig(
     )
 
 
-class GraphDBAuthConfig(BaseConfig):
+class GraphDBAuthConfig(BaseConfig, DictConversionMixin, EnvConfigMixin):
     uri: str = Field(
         default="bolt://localhost:7687",
         description="URI for graph database access (e.g., bolt://host:port)",
@@ -127,7 +143,7 @@ class GraphDBAuthConfig(BaseConfig):
     )
 
 
-class OpenAIConfig(BaseConfig):
+class OpenAIConfig(BaseConfig, DictConversionMixin, EnvConfigMixin):
     api_key: str = Field(default="", description="API key for OpenAI service")
     base_url: str = Field(default="", description="Base URL for API endpoint")
     default_model: str = Field(default="", description="Default model to use")
@@ -183,6 +199,25 @@ class AuthConfig(BaseConfig, DictConversionMixin):
                 "Please use YAML (.yaml, .yml) or JSON (.json) files."
             )
 
+    @classmethod
+    def from_local_env(cls) -> "AuthConfig":
+        """Creates an AuthConfig instance by loading configuration from environment variables.
+
+        This method loads configuration for all nested components (RabbitMQ, OpenAI, GraphDB)
+        from their respective environment variables using each component's specific prefix.
+
+        Returns:
+            AuthConfig: Configured instance with values from environment variables
+
+        Raises:
+            ValueError: If any required environment variables are missing
+        """
+        return cls(
+            rabbitmq=RabbitMQConfig.from_env(),
+            openai=OpenAIConfig.from_env(),
+            graph_db=GraphDBAuthConfig.from_env(),
+        )
+
     def set_openai_config_to_environment(self):
         # Set environment variables
         os.environ["OPENAI_API_KEY"] = self.openai.api_key
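With EnvConfigMixin mixed into the nested configs, the whole auth bundle can now be assembled from environment variables via the new AuthConfig.from_local_env. A minimal sketch; the exact variable names depend on each component's from_env prefix, which lives in general_modules/misc.py and is not shown in this diff:

# Sketch: build the auth bundle from the environment instead of a YAML/JSON file.
# The required environment variables must be exported beforehand; from_local_env
# raises ValueError if any of them are missing.
from memos.configs.mem_scheduler import AuthConfig

auth = AuthConfig.from_local_env()
print(auth.openai.default_model)
print(auth.graph_db.uri)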
memos/configs/memory.py
CHANGED
@@ -7,6 +7,7 @@ from memos.configs.embedder import EmbedderConfigFactory
 from memos.configs.graph_db import GraphDBConfigFactory
 from memos.configs.internet_retriever import InternetRetrieverConfigFactory
 from memos.configs.llm import LLMConfigFactory
+from memos.configs.reranker import RerankerConfigFactory
 from memos.configs.vec_db import VectorDBConfigFactory
 from memos.exceptions import ConfigurationError
 
@@ -151,6 +152,10 @@ class TreeTextMemoryConfig(BaseTextMemoryConfig):
         default_factory=EmbedderConfigFactory,
         description="Embedder configuration for the memory embedding",
     )
+    reranker: RerankerConfigFactory | None = Field(
+        None,
+        description="Reranker configuration (optional, defaults to cosine_local).",
+    )
     graph_db: GraphDBConfigFactory = Field(
         ...,
         default_factory=GraphDBConfigFactory,
@@ -166,6 +171,14 @@ class TreeTextMemoryConfig(BaseTextMemoryConfig):
         description="Optional description for this memory configuration.",
     )
 
+    memory_size: dict[str, Any] | None = Field(
+        default=None,
+        description=(
+            "Maximum item counts per memory bucket, e.g.: "
+            '{"WorkingMemory": 20, "LongTermMemory": 10000, "UserMemory": 10000}'
+        ),
+    )
+
 
 
 # ─── 3. Global Memory Config Factory ──────────────────────────────────────────
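Both additions are optional fields on TreeTextMemoryConfig, so existing configs keep working. A partial config fragment showing only the new keys, with values taken from the field descriptions above and the pre-existing required sections (graph_db, embedder, and so on) omitted:

# Fragment only: the new optional keys for a tree_text memory config.
tree_text_partial = {
    "reranker": {
        "backend": "cosine_local",  # the documented default when the field is omitted
        "config": {},
    },
    "memory_size": {
        "WorkingMemory": 20,
        "LongTermMemory": 10000,
        "UserMemory": 10000,
    },
}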
memos/configs/reranker.py
ADDED
@@ -0,0 +1,18 @@
+# memos/configs/reranker.py
+from __future__ import annotations
+
+from typing import Any
+
+from pydantic import BaseModel, Field
+
+
+class RerankerConfigFactory(BaseModel):
+    """
+    {
+      "backend": "http_bge" | "cosine_local" | "noop",
+      "config": { ... backend-specific ... }
+    }
+    """
+
+    backend: str = Field(..., description="Reranker backend id")
+    config: dict[str, Any] = Field(default_factory=dict, description="Backend-specific options")