MemoryOS 1.0.1__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of MemoryOS might be problematic. Click here for more details.
- {memoryos-1.0.1.dist-info → memoryos-1.1.1.dist-info}/METADATA +7 -2
- {memoryos-1.0.1.dist-info → memoryos-1.1.1.dist-info}/RECORD +79 -65
- {memoryos-1.0.1.dist-info → memoryos-1.1.1.dist-info}/WHEEL +1 -1
- memos/__init__.py +1 -1
- memos/api/client.py +109 -0
- memos/api/config.py +11 -9
- memos/api/context/dependencies.py +15 -55
- memos/api/middleware/request_context.py +9 -40
- memos/api/product_api.py +2 -3
- memos/api/product_models.py +91 -16
- memos/api/routers/product_router.py +23 -16
- memos/api/start_api.py +10 -0
- memos/configs/graph_db.py +4 -0
- memos/configs/mem_scheduler.py +38 -3
- memos/context/context.py +255 -0
- memos/embedders/factory.py +2 -0
- memos/graph_dbs/nebular.py +230 -232
- memos/graph_dbs/neo4j.py +35 -1
- memos/graph_dbs/neo4j_community.py +7 -0
- memos/llms/factory.py +2 -0
- memos/llms/openai.py +74 -2
- memos/log.py +27 -15
- memos/mem_cube/general.py +3 -1
- memos/mem_os/core.py +60 -22
- memos/mem_os/main.py +3 -6
- memos/mem_os/product.py +35 -11
- memos/mem_reader/factory.py +2 -0
- memos/mem_reader/simple_struct.py +127 -74
- memos/mem_scheduler/analyzer/__init__.py +0 -0
- memos/mem_scheduler/analyzer/mos_for_test_scheduler.py +569 -0
- memos/mem_scheduler/analyzer/scheduler_for_eval.py +280 -0
- memos/mem_scheduler/base_scheduler.py +126 -56
- memos/mem_scheduler/general_modules/dispatcher.py +2 -2
- memos/mem_scheduler/general_modules/misc.py +99 -1
- memos/mem_scheduler/general_modules/scheduler_logger.py +17 -11
- memos/mem_scheduler/general_scheduler.py +40 -88
- memos/mem_scheduler/memory_manage_modules/__init__.py +5 -0
- memos/mem_scheduler/memory_manage_modules/memory_filter.py +308 -0
- memos/mem_scheduler/{general_modules → memory_manage_modules}/retriever.py +34 -7
- memos/mem_scheduler/monitors/dispatcher_monitor.py +9 -8
- memos/mem_scheduler/monitors/general_monitor.py +119 -39
- memos/mem_scheduler/optimized_scheduler.py +124 -0
- memos/mem_scheduler/orm_modules/__init__.py +0 -0
- memos/mem_scheduler/orm_modules/base_model.py +635 -0
- memos/mem_scheduler/orm_modules/monitor_models.py +261 -0
- memos/mem_scheduler/scheduler_factory.py +2 -0
- memos/mem_scheduler/schemas/monitor_schemas.py +96 -29
- memos/mem_scheduler/utils/config_utils.py +100 -0
- memos/mem_scheduler/utils/db_utils.py +33 -0
- memos/mem_scheduler/utils/filter_utils.py +1 -1
- memos/mem_scheduler/webservice_modules/__init__.py +0 -0
- memos/memories/activation/kv.py +2 -1
- memos/memories/textual/item.py +95 -16
- memos/memories/textual/naive.py +1 -1
- memos/memories/textual/tree.py +27 -3
- memos/memories/textual/tree_text_memory/organize/handler.py +4 -2
- memos/memories/textual/tree_text_memory/organize/manager.py +28 -14
- memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +1 -2
- memos/memories/textual/tree_text_memory/organize/reorganizer.py +75 -23
- memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +7 -5
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py +6 -2
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +2 -0
- memos/memories/textual/tree_text_memory/retrieve/recall.py +70 -22
- memos/memories/textual/tree_text_memory/retrieve/searcher.py +101 -33
- memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py +5 -4
- memos/memos_tools/singleton.py +174 -0
- memos/memos_tools/thread_safe_dict.py +22 -0
- memos/memos_tools/thread_safe_dict_segment.py +382 -0
- memos/parsers/factory.py +2 -0
- memos/reranker/concat.py +59 -0
- memos/reranker/cosine_local.py +1 -0
- memos/reranker/factory.py +5 -0
- memos/reranker/http_bge.py +225 -12
- memos/templates/mem_scheduler_prompts.py +242 -0
- memos/types.py +4 -1
- memos/api/context/context.py +0 -147
- memos/api/context/context_thread.py +0 -96
- memos/mem_scheduler/mos_for_test_scheduler.py +0 -146
- {memoryos-1.0.1.dist-info → memoryos-1.1.1.dist-info}/entry_points.txt +0 -0
- {memoryos-1.0.1.dist-info → memoryos-1.1.1.dist-info/licenses}/LICENSE +0 -0
- /memos/mem_scheduler/{general_modules → webservice_modules}/rabbitmq_service.py +0 -0
- /memos/mem_scheduler/{general_modules → webservice_modules}/redis_service.py +0 -0
|
@@ -2,40 +2,25 @@
|
|
|
2
2
|
Request context middleware for automatic trace_id injection.
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
import logging
|
|
6
|
-
import os
|
|
7
|
-
|
|
8
5
|
from collections.abc import Callable
|
|
9
6
|
|
|
10
7
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
11
8
|
from starlette.requests import Request
|
|
12
9
|
from starlette.responses import Response
|
|
13
10
|
|
|
14
|
-
|
|
15
|
-
|
|
11
|
+
import memos.log
|
|
16
12
|
|
|
17
|
-
|
|
13
|
+
from memos.context.context import RequestContext, generate_trace_id, set_request_context
|
|
18
14
|
|
|
19
15
|
|
|
20
|
-
|
|
21
|
-
"""Generate a random trace_id."""
|
|
22
|
-
return os.urandom(16).hex()
|
|
16
|
+
logger = memos.log.get_logger(__name__)
|
|
23
17
|
|
|
24
18
|
|
|
25
19
|
def extract_trace_id_from_headers(request: Request) -> str | None:
|
|
26
20
|
"""Extract trace_id from various possible headers with priority: g-trace-id > x-trace-id > trace-id."""
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
trace_id = request.headers.get("x-trace-id")
|
|
32
|
-
if trace_id:
|
|
33
|
-
return trace_id
|
|
34
|
-
|
|
35
|
-
trace_id = request.headers.get("trace-id")
|
|
36
|
-
if trace_id:
|
|
37
|
-
return trace_id
|
|
38
|
-
|
|
21
|
+
for header in ["g-trace-id", "x-trace-id", "trace-id"]:
|
|
22
|
+
if trace_id := request.headers.get(header):
|
|
23
|
+
return trace_id
|
|
39
24
|
return None
|
|
40
25
|
|
|
41
26
|
|
|
@@ -51,19 +36,12 @@ class RequestContextMiddleware(BaseHTTPMiddleware):
|
|
|
51
36
|
|
|
52
37
|
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
|
53
38
|
# Extract or generate trace_id
|
|
54
|
-
trace_id = extract_trace_id_from_headers(request)
|
|
55
|
-
if not trace_id:
|
|
56
|
-
trace_id = generate_trace_id()
|
|
39
|
+
trace_id = extract_trace_id_from_headers(request) or generate_trace_id()
|
|
57
40
|
|
|
58
41
|
# Create and set request context
|
|
59
|
-
context = RequestContext(trace_id=trace_id)
|
|
42
|
+
context = RequestContext(trace_id=trace_id, api_path=request.url.path)
|
|
60
43
|
set_request_context(context)
|
|
61
44
|
|
|
62
|
-
# Add request metadata to context
|
|
63
|
-
context.set("method", request.method)
|
|
64
|
-
context.set("path", request.url.path)
|
|
65
|
-
context.set("client_ip", request.client.host if request.client else None)
|
|
66
|
-
|
|
67
45
|
# Log request start with parameters
|
|
68
46
|
params_log = {}
|
|
69
47
|
|
|
@@ -71,16 +49,7 @@ class RequestContextMiddleware(BaseHTTPMiddleware):
|
|
|
71
49
|
if request.query_params:
|
|
72
50
|
params_log["query_params"] = dict(request.query_params)
|
|
73
51
|
|
|
74
|
-
|
|
75
|
-
try:
|
|
76
|
-
params_log = await request.json()
|
|
77
|
-
except Exception as e:
|
|
78
|
-
logger.error(f"Error getting request body: {e}")
|
|
79
|
-
# If body is not JSON or empty, ignore it
|
|
80
|
-
|
|
81
|
-
logger.info(
|
|
82
|
-
f"Request started: {request.method} {request.url.path} - Parameters: {params_log}"
|
|
83
|
-
)
|
|
52
|
+
logger.info(f"Request started: {request.method} {request.url.path}, {params_log}")
|
|
84
53
|
|
|
85
54
|
# Process the request
|
|
86
55
|
response = await call_next(request)
|
memos/api/product_api.py
CHANGED
|
@@ -17,9 +17,7 @@ app = FastAPI(
|
|
|
17
17
|
version="1.0.1",
|
|
18
18
|
)
|
|
19
19
|
|
|
20
|
-
# Add request context middleware (must be added first)
|
|
21
20
|
app.add_middleware(RequestContextMiddleware)
|
|
22
|
-
|
|
23
21
|
# Include routers
|
|
24
22
|
app.include_router(product_router)
|
|
25
23
|
|
|
@@ -35,5 +33,6 @@ if __name__ == "__main__":
|
|
|
35
33
|
|
|
36
34
|
parser = argparse.ArgumentParser()
|
|
37
35
|
parser.add_argument("--port", type=int, default=8001)
|
|
36
|
+
parser.add_argument("--workers", type=int, default=32)
|
|
38
37
|
args = parser.parse_args()
|
|
39
|
-
uvicorn.run(app, host="0.0.0.0", port=args.port)
|
|
38
|
+
uvicorn.run("memos.api.product_api:app", host="0.0.0.0", port=args.port, workers=args.workers)
|
memos/api/product_models.py
CHANGED
|
@@ -1,26 +1,14 @@
|
|
|
1
1
|
import uuid
|
|
2
2
|
|
|
3
|
-
from typing import Generic, Literal,
|
|
3
|
+
from typing import Generic, Literal, TypeVar
|
|
4
4
|
|
|
5
5
|
from pydantic import BaseModel, Field
|
|
6
|
-
from typing_extensions import TypedDict
|
|
7
6
|
|
|
7
|
+
# Import message types from core types module
|
|
8
|
+
from memos.types import MessageDict
|
|
8
9
|
|
|
9
|
-
T = TypeVar("T")
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
# ─── Message Types ──────────────────────────────────────────────────────────────
|
|
13
|
-
|
|
14
|
-
# Chat message roles
|
|
15
|
-
MessageRole: TypeAlias = Literal["user", "assistant", "system"]
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
# Message structure
|
|
19
|
-
class MessageDict(TypedDict):
|
|
20
|
-
"""Typed dictionary for chat message dictionaries."""
|
|
21
10
|
|
|
22
|
-
|
|
23
|
-
content: str
|
|
11
|
+
T = TypeVar("T")
|
|
24
12
|
|
|
25
13
|
|
|
26
14
|
class BaseRequest(BaseModel):
|
|
@@ -42,6 +30,7 @@ class UserRegisterRequest(BaseRequest):
|
|
|
42
30
|
user_id: str = Field(
|
|
43
31
|
default_factory=lambda: str(uuid.uuid4()), description="User ID for registration"
|
|
44
32
|
)
|
|
33
|
+
mem_cube_id: str | None = Field(None, description="Cube ID for registration")
|
|
45
34
|
user_name: str | None = Field(None, description="User name for registration")
|
|
46
35
|
interests: str | None = Field(None, description="User interests")
|
|
47
36
|
|
|
@@ -85,6 +74,7 @@ class ChatRequest(BaseRequest):
|
|
|
85
74
|
history: list[MessageDict] | None = Field(None, description="Chat history")
|
|
86
75
|
internet_search: bool = Field(True, description="Whether to use internet search")
|
|
87
76
|
moscube: bool = Field(False, description="Whether to use MemOSCube")
|
|
77
|
+
session_id: str | None = Field(None, description="Session ID for soft-filtering memories")
|
|
88
78
|
|
|
89
79
|
|
|
90
80
|
class ChatCompleteRequest(BaseRequest):
|
|
@@ -99,6 +89,7 @@ class ChatCompleteRequest(BaseRequest):
|
|
|
99
89
|
base_prompt: str | None = Field(None, description="Base prompt to use for chat")
|
|
100
90
|
top_k: int = Field(10, description="Number of results to return")
|
|
101
91
|
threshold: float = Field(0.5, description="Threshold for filtering references")
|
|
92
|
+
session_id: str | None = Field(None, description="Session ID for soft-filtering memories")
|
|
102
93
|
|
|
103
94
|
|
|
104
95
|
class UserCreate(BaseRequest):
|
|
@@ -160,6 +151,7 @@ class MemoryCreateRequest(BaseRequest):
|
|
|
160
151
|
mem_cube_id: str | None = Field(None, description="Cube ID")
|
|
161
152
|
source: str | None = Field(None, description="Source of the memory")
|
|
162
153
|
user_profile: bool = Field(False, description="User profile memory")
|
|
154
|
+
session_id: str | None = Field(None, description="Session id")
|
|
163
155
|
|
|
164
156
|
|
|
165
157
|
class SearchRequest(BaseRequest):
|
|
@@ -169,6 +161,7 @@ class SearchRequest(BaseRequest):
|
|
|
169
161
|
query: str = Field(..., description="Search query")
|
|
170
162
|
mem_cube_id: str | None = Field(None, description="Cube ID to search in")
|
|
171
163
|
top_k: int = Field(10, description="Number of results to return")
|
|
164
|
+
session_id: str | None = Field(None, description="Session ID for soft-filtering memories")
|
|
172
165
|
|
|
173
166
|
|
|
174
167
|
class SuggestionRequest(BaseRequest):
|
|
@@ -177,3 +170,85 @@ class SuggestionRequest(BaseRequest):
|
|
|
177
170
|
user_id: str = Field(..., description="User ID")
|
|
178
171
|
language: Literal["zh", "en"] = Field("zh", description="Language for suggestions")
|
|
179
172
|
message: list[MessageDict] | None = Field(None, description="List of messages to store.")
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
# ─── MemOS Client Response Models ──────────────────────────────────────────────
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
class MessageDetail(BaseModel):
|
|
179
|
+
"""Individual message detail model based on actual API response."""
|
|
180
|
+
|
|
181
|
+
model_config = {"extra": "allow"}
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
class MemoryDetail(BaseModel):
|
|
185
|
+
"""Individual memory detail model based on actual API response."""
|
|
186
|
+
|
|
187
|
+
model_config = {"extra": "allow"}
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
class GetMessagesData(BaseModel):
|
|
191
|
+
"""Data model for get messages response based on actual API."""
|
|
192
|
+
|
|
193
|
+
message_detail_list: list[MessageDetail] = Field(
|
|
194
|
+
default_factory=list, alias="memory_detail_list", description="List of message details"
|
|
195
|
+
)
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
class SearchMemoryData(BaseModel):
|
|
199
|
+
"""Data model for search memory response based on actual API."""
|
|
200
|
+
|
|
201
|
+
memory_detail_list: list[MemoryDetail] = Field(
|
|
202
|
+
default_factory=list, alias="memory_detail_list", description="List of memory details"
|
|
203
|
+
)
|
|
204
|
+
message_detail_list: list[MessageDetail] | None = Field(
|
|
205
|
+
None, alias="message_detail_list", description="List of message details (usually None)"
|
|
206
|
+
)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
class AddMessageData(BaseModel):
|
|
210
|
+
"""Data model for add message response based on actual API."""
|
|
211
|
+
|
|
212
|
+
success: bool = Field(..., description="Operation success status")
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
# ─── MemOS Response Models (Similar to OpenAI ChatCompletion) ──────────────────
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
class MemOSGetMessagesResponse(BaseModel):
|
|
219
|
+
"""Response model for get messages operation based on actual API."""
|
|
220
|
+
|
|
221
|
+
code: int = Field(..., description="Response status code")
|
|
222
|
+
message: str = Field(..., description="Response message")
|
|
223
|
+
data: GetMessagesData = Field(..., description="Messages data")
|
|
224
|
+
|
|
225
|
+
@property
|
|
226
|
+
def messages(self) -> list[MessageDetail]:
|
|
227
|
+
"""Convenient access to message list."""
|
|
228
|
+
return self.data.message_detail_list
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
class MemOSSearchResponse(BaseModel):
|
|
232
|
+
"""Response model for search memory operation based on actual API."""
|
|
233
|
+
|
|
234
|
+
code: int = Field(..., description="Response status code")
|
|
235
|
+
message: str = Field(..., description="Response message")
|
|
236
|
+
data: SearchMemoryData = Field(..., description="Search results data")
|
|
237
|
+
|
|
238
|
+
@property
|
|
239
|
+
def memories(self) -> list[MemoryDetail]:
|
|
240
|
+
"""Convenient access to memory list."""
|
|
241
|
+
return self.data.memory_detail_list
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
class MemOSAddResponse(BaseModel):
|
|
245
|
+
"""Response model for add message operation based on actual API."""
|
|
246
|
+
|
|
247
|
+
code: int = Field(..., description="Response status code")
|
|
248
|
+
message: str = Field(..., description="Response message")
|
|
249
|
+
data: AddMessageData = Field(..., description="Add operation data")
|
|
250
|
+
|
|
251
|
+
@property
|
|
252
|
+
def success(self) -> bool:
|
|
253
|
+
"""Convenient access to success status."""
|
|
254
|
+
return self.data.success
|
|
@@ -1,14 +1,11 @@
|
|
|
1
1
|
import json
|
|
2
|
+
import time
|
|
2
3
|
import traceback
|
|
3
4
|
|
|
4
|
-
from
|
|
5
|
-
from typing import Annotated
|
|
6
|
-
|
|
7
|
-
from fastapi import APIRouter, Depends, HTTPException
|
|
5
|
+
from fastapi import APIRouter, HTTPException
|
|
8
6
|
from fastapi.responses import StreamingResponse
|
|
9
7
|
|
|
10
8
|
from memos.api.config import APIConfig
|
|
11
|
-
from memos.api.context.dependencies import G, get_g_object
|
|
12
9
|
from memos.api.product_models import (
|
|
13
10
|
BaseResponse,
|
|
14
11
|
ChatCompleteRequest,
|
|
@@ -79,24 +76,19 @@ def set_config(config):
|
|
|
79
76
|
|
|
80
77
|
|
|
81
78
|
@router.post("/users/register", summary="Register a new user", response_model=UserRegisterResponse)
|
|
82
|
-
def register_user(user_req: UserRegisterRequest
|
|
79
|
+
def register_user(user_req: UserRegisterRequest):
|
|
83
80
|
"""Register a new user with configuration and default cube."""
|
|
84
81
|
try:
|
|
85
|
-
# Set request-related information in g object
|
|
86
|
-
g.user_id = user_req.user_id
|
|
87
|
-
g.action = "user_register"
|
|
88
|
-
g.timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
89
|
-
|
|
90
|
-
logger.info(f"Starting user registration for user_id: {user_req.user_id}")
|
|
91
|
-
logger.info(f"Request trace_id: {g.trace_id}")
|
|
92
|
-
logger.info(f"Request timestamp: {g.timestamp}")
|
|
93
|
-
|
|
94
82
|
# Get configuration for the user
|
|
83
|
+
time_start_register = time.time()
|
|
95
84
|
user_config, default_mem_cube = APIConfig.create_user_config(
|
|
96
85
|
user_name=user_req.user_id, user_id=user_req.user_id
|
|
97
86
|
)
|
|
98
87
|
logger.info(f"user_config: {user_config.model_dump(mode='json')}")
|
|
99
88
|
logger.info(f"default_mem_cube: {default_mem_cube.config.model_dump(mode='json')}")
|
|
89
|
+
logger.info(
|
|
90
|
+
f"time register api : create user config time user_id: {user_req.user_id} time is: {time.time() - time_start_register}"
|
|
91
|
+
)
|
|
100
92
|
mos_product = get_mos_product_instance()
|
|
101
93
|
|
|
102
94
|
# Register user with default config and mem cube
|
|
@@ -106,8 +98,11 @@ def register_user(user_req: UserRegisterRequest, g: Annotated[G, Depends(get_g_o
|
|
|
106
98
|
interests=user_req.interests,
|
|
107
99
|
config=user_config,
|
|
108
100
|
default_mem_cube=default_mem_cube,
|
|
101
|
+
mem_cube_id=user_req.mem_cube_id,
|
|
102
|
+
)
|
|
103
|
+
logger.info(
|
|
104
|
+
f"time register api : register time user_id: {user_req.user_id} time is: {time.time() - time_start_register}"
|
|
109
105
|
)
|
|
110
|
-
|
|
111
106
|
if result["status"] == "success":
|
|
112
107
|
return UserRegisterResponse(
|
|
113
108
|
message="User registered successfully",
|
|
@@ -194,6 +189,7 @@ def get_all_memories(memory_req: GetMemoryRequest):
|
|
|
194
189
|
def create_memory(memory_req: MemoryCreateRequest):
|
|
195
190
|
"""Create a new memory for a specific user."""
|
|
196
191
|
try:
|
|
192
|
+
time_start_add = time.time()
|
|
197
193
|
mos_product = get_mos_product_instance()
|
|
198
194
|
mos_product.add(
|
|
199
195
|
user_id=memory_req.user_id,
|
|
@@ -203,6 +199,10 @@ def create_memory(memory_req: MemoryCreateRequest):
|
|
|
203
199
|
mem_cube_id=memory_req.mem_cube_id,
|
|
204
200
|
source=memory_req.source,
|
|
205
201
|
user_profile=memory_req.user_profile,
|
|
202
|
+
session_id=memory_req.session_id,
|
|
203
|
+
)
|
|
204
|
+
logger.info(
|
|
205
|
+
f"time add api : add time user_id: {memory_req.user_id} time is: {time.time() - time_start_add}"
|
|
206
206
|
)
|
|
207
207
|
return SimpleResponse(message="Memory created successfully")
|
|
208
208
|
|
|
@@ -217,12 +217,17 @@ def create_memory(memory_req: MemoryCreateRequest):
|
|
|
217
217
|
def search_memories(search_req: SearchRequest):
|
|
218
218
|
"""Search memories for a specific user."""
|
|
219
219
|
try:
|
|
220
|
+
time_start_search = time.time()
|
|
220
221
|
mos_product = get_mos_product_instance()
|
|
221
222
|
result = mos_product.search(
|
|
222
223
|
query=search_req.query,
|
|
223
224
|
user_id=search_req.user_id,
|
|
224
225
|
install_cube_ids=[search_req.mem_cube_id] if search_req.mem_cube_id else None,
|
|
225
226
|
top_k=search_req.top_k,
|
|
227
|
+
session_id=search_req.session_id,
|
|
228
|
+
)
|
|
229
|
+
logger.info(
|
|
230
|
+
f"time search api : add time user_id: {search_req.user_id} time is: {time.time() - time_start_search}"
|
|
226
231
|
)
|
|
227
232
|
return SearchResponse(message="Search completed successfully", data=result)
|
|
228
233
|
|
|
@@ -250,6 +255,7 @@ def chat(chat_req: ChatRequest):
|
|
|
250
255
|
history=chat_req.history,
|
|
251
256
|
internet_search=chat_req.internet_search,
|
|
252
257
|
moscube=chat_req.moscube,
|
|
258
|
+
session_id=chat_req.session_id,
|
|
253
259
|
)
|
|
254
260
|
|
|
255
261
|
except Exception as e:
|
|
@@ -294,6 +300,7 @@ def chat_complete(chat_req: ChatCompleteRequest):
|
|
|
294
300
|
base_prompt=chat_req.base_prompt,
|
|
295
301
|
top_k=chat_req.top_k,
|
|
296
302
|
threshold=chat_req.threshold,
|
|
303
|
+
session_id=chat_req.session_id,
|
|
297
304
|
)
|
|
298
305
|
|
|
299
306
|
# Return the complete response
|
memos/api/start_api.py
CHANGED
|
@@ -421,3 +421,13 @@ async def global_exception_handler(request: Request, exc: Exception):
|
|
|
421
421
|
status_code=500,
|
|
422
422
|
content={"code": 500, "message": str(exc), "data": None},
|
|
423
423
|
)
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
if __name__ == "__main__":
|
|
427
|
+
import argparse
|
|
428
|
+
|
|
429
|
+
parser = argparse.ArgumentParser()
|
|
430
|
+
parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
|
|
431
|
+
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on")
|
|
432
|
+
parser.add_argument("--reload", action="store_true", help="Enable auto-reload for development")
|
|
433
|
+
args = parser.parse_args()
|
memos/configs/graph_db.py
CHANGED
|
@@ -140,6 +140,10 @@ class NebulaGraphDBConfig(BaseGraphDBConfig):
|
|
|
140
140
|
"If False: use a single shared database with logical isolation by user_name."
|
|
141
141
|
),
|
|
142
142
|
)
|
|
143
|
+
max_client: int = Field(
|
|
144
|
+
default=1000,
|
|
145
|
+
description=("max_client"),
|
|
146
|
+
)
|
|
143
147
|
embedding_dimension: int = Field(default=3072, description="Dimension of vector embedding")
|
|
144
148
|
|
|
145
149
|
@model_validator(mode="after")
|
memos/configs/mem_scheduler.py
CHANGED
|
@@ -6,7 +6,7 @@ from typing import Any, ClassVar
|
|
|
6
6
|
from pydantic import ConfigDict, Field, field_validator, model_validator
|
|
7
7
|
|
|
8
8
|
from memos.configs.base import BaseConfig
|
|
9
|
-
from memos.mem_scheduler.general_modules.misc import DictConversionMixin
|
|
9
|
+
from memos.mem_scheduler.general_modules.misc import DictConversionMixin, EnvConfigMixin
|
|
10
10
|
from memos.mem_scheduler.schemas.general_schemas import (
|
|
11
11
|
BASE_DIR,
|
|
12
12
|
DEFAULT_ACT_MEM_DUMP_PATH,
|
|
@@ -64,6 +64,19 @@ class GeneralSchedulerConfig(BaseSchedulerConfig):
|
|
|
64
64
|
default=20, description="Capacity of the activation memory monitor"
|
|
65
65
|
)
|
|
66
66
|
|
|
67
|
+
# Database configuration for ORM persistence
|
|
68
|
+
db_path: str | None = Field(
|
|
69
|
+
default=None,
|
|
70
|
+
description="Path to SQLite database file for ORM persistence. If None, uses default scheduler_orm.db",
|
|
71
|
+
)
|
|
72
|
+
db_url: str | None = Field(
|
|
73
|
+
default=None,
|
|
74
|
+
description="Database URL for ORM persistence (e.g., mysql://user:pass@host/db). Takes precedence over db_path",
|
|
75
|
+
)
|
|
76
|
+
enable_orm_persistence: bool = Field(
|
|
77
|
+
default=True, description="Whether to enable ORM-based persistence for monitors"
|
|
78
|
+
)
|
|
79
|
+
|
|
67
80
|
|
|
68
81
|
class SchedulerConfigFactory(BaseConfig):
|
|
69
82
|
"""Factory class for creating scheduler configurations."""
|
|
@@ -74,6 +87,7 @@ class SchedulerConfigFactory(BaseConfig):
|
|
|
74
87
|
model_config = ConfigDict(extra="forbid", strict=True)
|
|
75
88
|
backend_to_class: ClassVar[dict[str, Any]] = {
|
|
76
89
|
"general_scheduler": GeneralSchedulerConfig,
|
|
90
|
+
"optimized_scheduler": GeneralSchedulerConfig, # optimized_scheduler uses same config as general_scheduler
|
|
77
91
|
}
|
|
78
92
|
|
|
79
93
|
@field_validator("backend")
|
|
@@ -94,6 +108,8 @@ class SchedulerConfigFactory(BaseConfig):
|
|
|
94
108
|
# ************************* Auth *************************
|
|
95
109
|
class RabbitMQConfig(
|
|
96
110
|
BaseConfig,
|
|
111
|
+
DictConversionMixin,
|
|
112
|
+
EnvConfigMixin,
|
|
97
113
|
):
|
|
98
114
|
host_name: str = Field(default="", description="Endpoint for RabbitMQ instance access")
|
|
99
115
|
user_name: str = Field(default="", description="Static username for RabbitMQ instance")
|
|
@@ -110,7 +126,7 @@ class RabbitMQConfig(
|
|
|
110
126
|
)
|
|
111
127
|
|
|
112
128
|
|
|
113
|
-
class GraphDBAuthConfig(BaseConfig):
|
|
129
|
+
class GraphDBAuthConfig(BaseConfig, DictConversionMixin, EnvConfigMixin):
|
|
114
130
|
uri: str = Field(
|
|
115
131
|
default="bolt://localhost:7687",
|
|
116
132
|
description="URI for graph database access (e.g., bolt://host:port)",
|
|
@@ -127,7 +143,7 @@ class GraphDBAuthConfig(BaseConfig):
|
|
|
127
143
|
)
|
|
128
144
|
|
|
129
145
|
|
|
130
|
-
class OpenAIConfig(BaseConfig):
|
|
146
|
+
class OpenAIConfig(BaseConfig, DictConversionMixin, EnvConfigMixin):
|
|
131
147
|
api_key: str = Field(default="", description="API key for OpenAI service")
|
|
132
148
|
base_url: str = Field(default="", description="Base URL for API endpoint")
|
|
133
149
|
default_model: str = Field(default="", description="Default model to use")
|
|
@@ -183,6 +199,25 @@ class AuthConfig(BaseConfig, DictConversionMixin):
|
|
|
183
199
|
"Please use YAML (.yaml, .yml) or JSON (.json) files."
|
|
184
200
|
)
|
|
185
201
|
|
|
202
|
+
@classmethod
|
|
203
|
+
def from_local_env(cls) -> "AuthConfig":
|
|
204
|
+
"""Creates an AuthConfig instance by loading configuration from environment variables.
|
|
205
|
+
|
|
206
|
+
This method loads configuration for all nested components (RabbitMQ, OpenAI, GraphDB)
|
|
207
|
+
from their respective environment variables using each component's specific prefix.
|
|
208
|
+
|
|
209
|
+
Returns:
|
|
210
|
+
AuthConfig: Configured instance with values from environment variables
|
|
211
|
+
|
|
212
|
+
Raises:
|
|
213
|
+
ValueError: If any required environment variables are missing
|
|
214
|
+
"""
|
|
215
|
+
return cls(
|
|
216
|
+
rabbitmq=RabbitMQConfig.from_env(),
|
|
217
|
+
openai=OpenAIConfig.from_env(),
|
|
218
|
+
graph_db=GraphDBAuthConfig.from_env(),
|
|
219
|
+
)
|
|
220
|
+
|
|
186
221
|
def set_openai_config_to_environment(self):
|
|
187
222
|
# Set environment variables
|
|
188
223
|
os.environ["OPENAI_API_KEY"] = self.openai.api_key
|